Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds internal/stats package #279

Merged
merged 8 commits into from
Mar 9, 2022
119 changes: 0 additions & 119 deletions circuitstats_test.go
Original file line number Diff line number Diff line change
@@ -1,120 +1 @@
package gnark

import (
"encoding/gob"
"os"
"sync"
"testing"

"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark/backend"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/frontend/cs/r1cs"
"github.com/consensys/gnark/frontend/cs/scs"
"github.com/consensys/gnark/internal/backend/circuits"
"github.com/consensys/gnark/test"
)

const (
fileStats = "init.stats"
generateNewStats = false
)

var statsM sync.Mutex

func TestCircuitStatistics(t *testing.T) {
assert := test.NewAssert(t)
for k := range circuits.Circuits {
for _, curve := range ecc.Implemented() {
for _, b := range backend.Implemented() {
curve := curve
backendID := b
name := k
// copy the circuit now in case assert calls t.Parallel()
tData := circuits.Circuits[k]
assert.Run(func(assert *test.Assert) {
var newCompiler frontend.NewBuilder

switch backendID {
case backend.GROTH16:
newCompiler = r1cs.NewBuilder
case backend.PLONK:
newCompiler = scs.NewBuilder
default:
panic("not implemented")
}

ccs, err := frontend.Compile(curve, newCompiler, tData.Circuit)
assert.NoError(err)

// ensure we didn't introduce regressions that make circuits less efficient
nbConstraints := ccs.GetNbConstraints()
internal, secret, public := ccs.GetNbVariables()
checkStats(assert, name, nbConstraints, internal, secret, public, curve, backendID)
}, name, curve.String(), backendID.String())
}
}

}

// serialize newStats
if generateNewStats {
fStats, err := os.Create(fileStats)
assert.NoError(err)

encoder := gob.NewEncoder(fStats)
err = encoder.Encode(mStats)
assert.NoError(err)
}
}

type circuitStats struct {
NbConstraints, Internal, Secret, Public int
}

var mStats map[string][backend.PLONK + 1][ecc.BW6_633 + 1]circuitStats

func checkStats(assert *test.Assert, circuitName string, nbConstraints, internal, secret, public int, curve ecc.ID, backendID backend.ID) {
statsM.Lock()
defer statsM.Unlock()
if generateNewStats {
rs := mStats[circuitName]
rs[backendID][curve] = circuitStats{nbConstraints, internal, secret, public}
mStats[circuitName] = rs
return
}
if referenceStats, ok := mStats[circuitName]; !ok {
assert.Log("warning: no stats for circuit", circuitName)
} else {
ref := referenceStats[backendID][curve]
if ref.NbConstraints != nbConstraints {
assert.Failf("unexpected constraint count", "expected %d nbConstraints (reference), got %d. %s, %s, %s", ref.NbConstraints, nbConstraints, circuitName, backendID.String(), curve.String())
}
if ref.Internal != internal {
assert.Failf("unexpected internal variable count", "expected %d internal (reference), got %d. %s, %s, %s", ref.Internal, internal, circuitName, backendID.String(), curve.String())
}
if ref.Secret != secret {
assert.Failf("unexpected secret variable count", "expected %d secret (reference), got %d. %s, %s, %s", ref.Secret, secret, circuitName, backendID.String(), curve.String())
}
if ref.Public != public {
assert.Failf("unexpected public variable count", "expected %d public (reference), got %d. %s, %s, %s", ref.Public, public, circuitName, backendID.String(), curve.String())
}
}
}

func init() {
mStats = make(map[string][backend.PLONK + 1][ecc.BW6_633 + 1]circuitStats)

if !generateNewStats {
fStats, err := os.Open(fileStats)
if err != nil {
panic(err)
}
decoder := gob.NewDecoder(fStats)
err = decoder.Decode(&mStats)
if err != nil {
panic(err)
}
}

}
Binary file removed init.stats
Binary file not shown.
76 changes: 76 additions & 0 deletions internal/stats/generate/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package main

import (
"flag"
"fmt"
"log"
"regexp"
"sync"

"github.com/consensys/gnark/backend"
"github.com/consensys/gnark/internal/stats"
)

var (
fSave = flag.Bool("s", false, "save new stats in file ")
fFilter = flag.String("run", "", "filter runs with regexp; example 'pairing*'")
)

func main() {
flag.Parse()

var r *regexp.Regexp
if *fFilter != "" {
r = regexp.MustCompile(*fFilter)
}

s := stats.NewGlobalStats()

// load reference objects
// for each circuit, on each curve, on each backend
// compare with reference stats
snippets := stats.GetSnippets()
var wg sync.WaitGroup
for name, c := range snippets {
if r != nil && !r.MatchString(name) {
continue
}
wg.Add(1)
go func(name string, circuit stats.Circuit) {
defer wg.Done()
for _, curve := range circuit.Curves {
for _, backendID := range backend.Implemented() {
cs, err := stats.NewSnippetStats(curve, backendID, circuit.Circuit)
if err != nil {
log.Fatalf("building stats for circuit %s %v", name, err)
}
s.Add(curve, backendID, cs, name)
}
}
}(name, c)
}
wg.Wait()

fmt.Println("id,curve,backend,nbConstraints,nbWires")
for name, c := range snippets {
if r != nil && !r.MatchString(name) {
continue
}
ss := s.Stats[name]
for _, curve := range c.Curves {
for _, backendID := range backend.Implemented() {
cs := ss[backendID][curve]
fmt.Printf("%s,%s,%s,%d,%d\n", name, curve, backendID, cs.NbConstraints, cs.NbInternalWires)
}
}
}

if *fSave {
const refPath = "../latest.stats"
if err := s.Save(refPath); err != nil {
log.Fatal(err)
}
log.Println("successfully saved new reference stats file", refPath)
}

}
Binary file added internal/stats/latest.stats
Binary file not shown.
147 changes: 147 additions & 0 deletions internal/stats/snippet.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package stats

import (
"math"
"sync"

"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/algebra/fields_bls12377"
"github.com/consensys/gnark/std/algebra/fields_bls24315"
"github.com/consensys/gnark/std/algebra/sw_bls12377"
"github.com/consensys/gnark/std/algebra/sw_bls24315"
"github.com/consensys/gnark/std/hash/mimc"
"github.com/consensys/gnark/std/math/bits"
)

var (
initOnce sync.Once
snippets = make(map[string]Circuit)
)

func GetSnippets() map[string]Circuit {
initOnce.Do(initSnippets)
return snippets
}

type snippet func(frontend.API, frontend.Variable)

func registerSnippet(name string, snippet snippet, curves ...ecc.ID) {
if _, ok := snippets[name]; ok {
panic("circuit " + name + " already registered")
}
if len(curves) == 0 {
curves = ecc.Implemented()
}
snippets[name] = Circuit{makeSnippetCircuit(snippet), curves}
}

func initSnippets() {
// add api snippets
registerSnippet("api/IsZero", func(api frontend.API, v frontend.Variable) {
_ = api.IsZero(v)
})

registerSnippet("api/Lookup2", func(api frontend.API, v frontend.Variable) {
_ = api.Lookup2(v, v, v, v, v, v)
})

registerSnippet("api/AssertIsLessOrEqual", func(api frontend.API, v frontend.Variable) {
api.AssertIsLessOrEqual(v, v)
})
registerSnippet("api/AssertIsLessOrEqual/constant_bound_64_bits", func(api frontend.API, v frontend.Variable) {
bound := uint64(math.MaxUint64)
api.AssertIsLessOrEqual(v, bound)
})

// add std snippets
registerSnippet("math/bits.ToBinary", func(api frontend.API, v frontend.Variable) {
_ = bits.ToBinary(api, v)
})
registerSnippet("math/bits.ToBinary/unconstrained", func(api frontend.API, v frontend.Variable) {
_ = bits.ToBinary(api, v, bits.WithUnconstrainedOutputs())
})
registerSnippet("math/bits.ToTernary", func(api frontend.API, v frontend.Variable) {
_ = bits.ToTernary(api, v)
})
registerSnippet("math/bits.ToTernary/unconstrained", func(api frontend.API, v frontend.Variable) {
_ = bits.ToTernary(api, v, bits.WithUnconstrainedOutputs())
})
registerSnippet("math/bits.ToNAF", func(api frontend.API, v frontend.Variable) {
_ = bits.ToNAF(api, v)
})
registerSnippet("math/bits.ToNAF/unconstrained", func(api frontend.API, v frontend.Variable) {
_ = bits.ToNAF(api, v, bits.WithUnconstrainedOutputs())
})

registerSnippet("hash/mimc", func(api frontend.API, v frontend.Variable) {
mimc, _ := mimc.NewMiMC(api)
mimc.Write(v)
_ = mimc.Sum()
})

registerSnippet("pairing_bls12377", func(api frontend.API, v frontend.Variable) {
ateLoop := uint64(9586122913090633729)
ext := fields_bls12377.GetBLS12377ExtensionFp12(api)
pairingInfo := sw_bls12377.PairingContext{AteLoop: ateLoop, Extension: ext}

var dummyG1 sw_bls12377.G1Affine
var dummyG2 sw_bls12377.G2Affine
dummyG1.X = v
dummyG1.Y = v
dummyG2.X.A0 = v
dummyG2.X.A1 = v
dummyG2.Y.A0 = v
dummyG2.Y.A1 = v

var resMillerLoop fields_bls12377.E12
// e(psi0, -gamma)*e(-πC, -δ)*e(πA, πB)
sw_bls12377.MillerLoop(api, dummyG1, dummyG2, &resMillerLoop, pairingInfo)

// performs the final expo
var resPairing fields_bls12377.E12
resPairing.FinalExponentiation(api, resMillerLoop, pairingInfo.AteLoop, pairingInfo.Extension)
}, ecc.BW6_761)

registerSnippet("pairing_bls24315", func(api frontend.API, v frontend.Variable) {
ateLoop := uint64(3218079743)
ext := fields_bls24315.GetBLS24315ExtensionFp24(api)
pairingInfo := sw_bls24315.PairingContext{AteLoop: ateLoop, Extension: ext}

var dummyG1 sw_bls24315.G1Affine
var dummyG2 sw_bls24315.G2Affine
dummyG1.X = v
dummyG1.Y = v
dummyG2.X.B0.A0 = v
dummyG2.X.B0.A1 = v
dummyG2.X.B1.A0 = v
dummyG2.X.B1.A1 = v
dummyG2.Y.B0.A0 = v
dummyG2.Y.B0.A1 = v
dummyG2.Y.B1.A0 = v
dummyG2.Y.B1.A1 = v

var resMillerLoop fields_bls24315.E24
// e(psi0, -gamma)*e(-πC, -δ)*e(πA, πB)
sw_bls24315.MillerLoop(api, dummyG1, dummyG2, &resMillerLoop, pairingInfo)

// performs the final expo
var resPairing fields_bls24315.E24
resPairing.FinalExponentiation(api, resMillerLoop, pairingInfo.AteLoop, pairingInfo.Extension)
}, ecc.BW6_633)

}

type snippetCircuit struct {
V frontend.Variable
s snippet
}

func (d *snippetCircuit) Define(api frontend.API) error {
d.s(api, d.V)
return nil
}

func makeSnippetCircuit(s snippet) frontend.Circuit {
return &snippetCircuit{s: s}
}
Loading