Skip to content

Commit

Permalink
Merge pull request #279 from ConsenSys/feat/statistics
Browse files Browse the repository at this point in the history
Adds internal/stats package
  • Loading branch information
gbotrel authored Mar 9, 2022
2 parents d9b9239 + 2062350 commit bf6e081
Show file tree
Hide file tree
Showing 8 changed files with 373 additions and 120 deletions.
119 changes: 0 additions & 119 deletions circuitstats_test.go
Original file line number Diff line number Diff line change
@@ -1,120 +1 @@
package gnark

import (
"encoding/gob"
"os"
"sync"
"testing"

"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark/backend"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/frontend/cs/r1cs"
"github.com/consensys/gnark/frontend/cs/scs"
"github.com/consensys/gnark/internal/backend/circuits"
"github.com/consensys/gnark/test"
)

const (
fileStats = "init.stats"
generateNewStats = false
)

var statsM sync.Mutex

func TestCircuitStatistics(t *testing.T) {
assert := test.NewAssert(t)
for k := range circuits.Circuits {
for _, curve := range ecc.Implemented() {
for _, b := range backend.Implemented() {
curve := curve
backendID := b
name := k
// copy the circuit now in case assert calls t.Parallel()
tData := circuits.Circuits[k]
assert.Run(func(assert *test.Assert) {
var newCompiler frontend.NewBuilder

switch backendID {
case backend.GROTH16:
newCompiler = r1cs.NewBuilder
case backend.PLONK:
newCompiler = scs.NewBuilder
default:
panic("not implemented")
}

ccs, err := frontend.Compile(curve, newCompiler, tData.Circuit)
assert.NoError(err)

// ensure we didn't introduce regressions that make circuits less efficient
nbConstraints := ccs.GetNbConstraints()
internal, secret, public := ccs.GetNbVariables()
checkStats(assert, name, nbConstraints, internal, secret, public, curve, backendID)
}, name, curve.String(), backendID.String())
}
}

}

// serialize newStats
if generateNewStats {
fStats, err := os.Create(fileStats)
assert.NoError(err)

encoder := gob.NewEncoder(fStats)
err = encoder.Encode(mStats)
assert.NoError(err)
}
}

type circuitStats struct {
NbConstraints, Internal, Secret, Public int
}

var mStats map[string][backend.PLONK + 1][ecc.BW6_633 + 1]circuitStats

func checkStats(assert *test.Assert, circuitName string, nbConstraints, internal, secret, public int, curve ecc.ID, backendID backend.ID) {
statsM.Lock()
defer statsM.Unlock()
if generateNewStats {
rs := mStats[circuitName]
rs[backendID][curve] = circuitStats{nbConstraints, internal, secret, public}
mStats[circuitName] = rs
return
}
if referenceStats, ok := mStats[circuitName]; !ok {
assert.Log("warning: no stats for circuit", circuitName)
} else {
ref := referenceStats[backendID][curve]
if ref.NbConstraints != nbConstraints {
assert.Failf("unexpected constraint count", "expected %d nbConstraints (reference), got %d. %s, %s, %s", ref.NbConstraints, nbConstraints, circuitName, backendID.String(), curve.String())
}
if ref.Internal != internal {
assert.Failf("unexpected internal variable count", "expected %d internal (reference), got %d. %s, %s, %s", ref.Internal, internal, circuitName, backendID.String(), curve.String())
}
if ref.Secret != secret {
assert.Failf("unexpected secret variable count", "expected %d secret (reference), got %d. %s, %s, %s", ref.Secret, secret, circuitName, backendID.String(), curve.String())
}
if ref.Public != public {
assert.Failf("unexpected public variable count", "expected %d public (reference), got %d. %s, %s, %s", ref.Public, public, circuitName, backendID.String(), curve.String())
}
}
}

func init() {
mStats = make(map[string][backend.PLONK + 1][ecc.BW6_633 + 1]circuitStats)

if !generateNewStats {
fStats, err := os.Open(fileStats)
if err != nil {
panic(err)
}
decoder := gob.NewDecoder(fStats)
err = decoder.Decode(&mStats)
if err != nil {
panic(err)
}
}

}
Binary file removed init.stats
Binary file not shown.
76 changes: 76 additions & 0 deletions internal/stats/generate/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package main

import (
"flag"
"fmt"
"log"
"regexp"
"sync"

"github.com/consensys/gnark/backend"
"github.com/consensys/gnark/internal/stats"
)

var (
fSave = flag.Bool("s", false, "save new stats in file ")
fFilter = flag.String("run", "", "filter runs with regexp; example 'pairing*'")
)

func main() {
flag.Parse()

var r *regexp.Regexp
if *fFilter != "" {
r = regexp.MustCompile(*fFilter)
}

s := stats.NewGlobalStats()

// load reference objects
// for each circuit, on each curve, on each backend
// compare with reference stats
snippets := stats.GetSnippets()
var wg sync.WaitGroup
for name, c := range snippets {
if r != nil && !r.MatchString(name) {
continue
}
wg.Add(1)
go func(name string, circuit stats.Circuit) {
defer wg.Done()
for _, curve := range circuit.Curves {
for _, backendID := range backend.Implemented() {
cs, err := stats.NewSnippetStats(curve, backendID, circuit.Circuit)
if err != nil {
log.Fatalf("building stats for circuit %s %v", name, err)
}
s.Add(curve, backendID, cs, name)
}
}
}(name, c)
}
wg.Wait()

fmt.Println("id,curve,backend,nbConstraints,nbWires")
for name, c := range snippets {
if r != nil && !r.MatchString(name) {
continue
}
ss := s.Stats[name]
for _, curve := range c.Curves {
for _, backendID := range backend.Implemented() {
cs := ss[backendID][curve]
fmt.Printf("%s,%s,%s,%d,%d\n", name, curve, backendID, cs.NbConstraints, cs.NbInternalWires)
}
}
}

if *fSave {
const refPath = "../latest.stats"
if err := s.Save(refPath); err != nil {
log.Fatal(err)
}
log.Println("successfully saved new reference stats file", refPath)
}

}
Binary file added internal/stats/latest.stats
Binary file not shown.
153 changes: 153 additions & 0 deletions internal/stats/snippet.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
package stats

import (
"math"
"sync"

"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/algebra/fields_bls12377"
"github.com/consensys/gnark/std/algebra/fields_bls24315"
"github.com/consensys/gnark/std/algebra/sw_bls12377"
"github.com/consensys/gnark/std/algebra/sw_bls24315"
"github.com/consensys/gnark/std/hash/mimc"
"github.com/consensys/gnark/std/math/bits"
)

var (
initOnce sync.Once
snippets = make(map[string]Circuit)
)

func GetSnippets() map[string]Circuit {
initOnce.Do(initSnippets)
return snippets
}

type snippet func(api frontend.API, newVariable func() frontend.Variable)

func registerSnippet(name string, snippet snippet, curves ...ecc.ID) {
if _, ok := snippets[name]; ok {
panic("circuit " + name + " already registered")
}
if len(curves) == 0 {
curves = ecc.Implemented()
}
snippets[name] = Circuit{makeSnippetCircuit(snippet), curves}
}

func initSnippets() {
// add api snippets
registerSnippet("api/IsZero", func(api frontend.API, newVariable func() frontend.Variable) {
_ = api.IsZero(newVariable())
})

registerSnippet("api/Lookup2", func(api frontend.API, newVariable func() frontend.Variable) {
_ = api.Lookup2(newVariable(), newVariable(), newVariable(), newVariable(), newVariable(), newVariable())
})

registerSnippet("api/AssertIsLessOrEqual", func(api frontend.API, newVariable func() frontend.Variable) {
api.AssertIsLessOrEqual(newVariable(), newVariable())
})
registerSnippet("api/AssertIsLessOrEqual/constant_bound_64_bits", func(api frontend.API, newVariable func() frontend.Variable) {
bound := uint64(math.MaxUint64)
api.AssertIsLessOrEqual(newVariable(), bound)
})

// add std snippets
registerSnippet("math/bits.ToBinary", func(api frontend.API, newVariable func() frontend.Variable) {
_ = bits.ToBinary(api, newVariable())
})
registerSnippet("math/bits.ToBinary/unconstrained", func(api frontend.API, newVariable func() frontend.Variable) {
_ = bits.ToBinary(api, newVariable(), bits.WithUnconstrainedOutputs())
})
registerSnippet("math/bits.ToTernary", func(api frontend.API, newVariable func() frontend.Variable) {
_ = bits.ToTernary(api, newVariable())
})
registerSnippet("math/bits.ToTernary/unconstrained", func(api frontend.API, newVariable func() frontend.Variable) {
_ = bits.ToTernary(api, newVariable(), bits.WithUnconstrainedOutputs())
})
registerSnippet("math/bits.ToNAF", func(api frontend.API, newVariable func() frontend.Variable) {
_ = bits.ToNAF(api, newVariable())
})
registerSnippet("math/bits.ToNAF/unconstrained", func(api frontend.API, newVariable func() frontend.Variable) {
_ = bits.ToNAF(api, newVariable(), bits.WithUnconstrainedOutputs())
})

registerSnippet("hash/mimc", func(api frontend.API, newVariable func() frontend.Variable) {
mimc, _ := mimc.NewMiMC(api)
mimc.Write(newVariable())
_ = mimc.Sum()
})

registerSnippet("pairing_bls12377", func(api frontend.API, newVariable func() frontend.Variable) {
ateLoop := uint64(9586122913090633729)
ext := fields_bls12377.GetBLS12377ExtensionFp12(api)
pairingInfo := sw_bls12377.PairingContext{AteLoop: ateLoop, Extension: ext}

var dummyG1 sw_bls12377.G1Affine
var dummyG2 sw_bls12377.G2Affine
dummyG1.X = newVariable()
dummyG1.Y = newVariable()
dummyG2.X.A0 = newVariable()
dummyG2.X.A1 = newVariable()
dummyG2.Y.A0 = newVariable()
dummyG2.Y.A1 = newVariable()

var resMillerLoop fields_bls12377.E12
// e(psi0, -gamma)*e(-πC, -δ)*e(πA, πB)
sw_bls12377.MillerLoop(api, dummyG1, dummyG2, &resMillerLoop, pairingInfo)

// performs the final expo
var resPairing fields_bls12377.E12
resPairing.FinalExponentiation(api, resMillerLoop, pairingInfo.AteLoop, pairingInfo.Extension)
}, ecc.BW6_761)

registerSnippet("pairing_bls24315", func(api frontend.API, newVariable func() frontend.Variable) {
ateLoop := uint64(3218079743)
ext := fields_bls24315.GetBLS24315ExtensionFp24(api)
pairingInfo := sw_bls24315.PairingContext{AteLoop: ateLoop, Extension: ext}

var dummyG1 sw_bls24315.G1Affine
var dummyG2 sw_bls24315.G2Affine
dummyG1.X = newVariable()
dummyG1.Y = newVariable()
dummyG2.X.B0.A0 = newVariable()
dummyG2.X.B0.A1 = newVariable()
dummyG2.X.B1.A0 = newVariable()
dummyG2.X.B1.A1 = newVariable()
dummyG2.Y.B0.A0 = newVariable()
dummyG2.Y.B0.A1 = newVariable()
dummyG2.Y.B1.A0 = newVariable()
dummyG2.Y.B1.A1 = newVariable()

var resMillerLoop fields_bls24315.E24
// e(psi0, -gamma)*e(-πC, -δ)*e(πA, πB)
sw_bls24315.MillerLoop(api, dummyG1, dummyG2, &resMillerLoop, pairingInfo)

// performs the final expo
var resPairing fields_bls24315.E24
resPairing.FinalExponentiation(api, resMillerLoop, pairingInfo.AteLoop, pairingInfo.Extension)
}, ecc.BW6_633)

}

type snippetCircuit struct {
V [1024]frontend.Variable
s snippet
vIndex int
}

func (d *snippetCircuit) Define(api frontend.API) error {
d.s(api, d.newVariable)
return nil
}

func (d *snippetCircuit) newVariable() frontend.Variable {
d.vIndex++
return d.V[(d.vIndex-1)%len(d.V)]
}

func makeSnippetCircuit(s snippet) frontend.Circuit {
return &snippetCircuit{s: s}
}
Loading

0 comments on commit bf6e081

Please sign in to comment.