-
Notifications
You must be signed in to change notification settings - Fork 418
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #279 from ConsenSys/feat/statistics
Adds internal/stats package
- Loading branch information
Showing
8 changed files
with
373 additions
and
120 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,120 +1 @@ | ||
package gnark | ||
|
||
import ( | ||
"encoding/gob" | ||
"os" | ||
"sync" | ||
"testing" | ||
|
||
"github.com/consensys/gnark-crypto/ecc" | ||
"github.com/consensys/gnark/backend" | ||
"github.com/consensys/gnark/frontend" | ||
"github.com/consensys/gnark/frontend/cs/r1cs" | ||
"github.com/consensys/gnark/frontend/cs/scs" | ||
"github.com/consensys/gnark/internal/backend/circuits" | ||
"github.com/consensys/gnark/test" | ||
) | ||
|
||
const ( | ||
fileStats = "init.stats" | ||
generateNewStats = false | ||
) | ||
|
||
var statsM sync.Mutex | ||
|
||
func TestCircuitStatistics(t *testing.T) { | ||
assert := test.NewAssert(t) | ||
for k := range circuits.Circuits { | ||
for _, curve := range ecc.Implemented() { | ||
for _, b := range backend.Implemented() { | ||
curve := curve | ||
backendID := b | ||
name := k | ||
// copy the circuit now in case assert calls t.Parallel() | ||
tData := circuits.Circuits[k] | ||
assert.Run(func(assert *test.Assert) { | ||
var newCompiler frontend.NewBuilder | ||
|
||
switch backendID { | ||
case backend.GROTH16: | ||
newCompiler = r1cs.NewBuilder | ||
case backend.PLONK: | ||
newCompiler = scs.NewBuilder | ||
default: | ||
panic("not implemented") | ||
} | ||
|
||
ccs, err := frontend.Compile(curve, newCompiler, tData.Circuit) | ||
assert.NoError(err) | ||
|
||
// ensure we didn't introduce regressions that make circuits less efficient | ||
nbConstraints := ccs.GetNbConstraints() | ||
internal, secret, public := ccs.GetNbVariables() | ||
checkStats(assert, name, nbConstraints, internal, secret, public, curve, backendID) | ||
}, name, curve.String(), backendID.String()) | ||
} | ||
} | ||
|
||
} | ||
|
||
// serialize newStats | ||
if generateNewStats { | ||
fStats, err := os.Create(fileStats) | ||
assert.NoError(err) | ||
|
||
encoder := gob.NewEncoder(fStats) | ||
err = encoder.Encode(mStats) | ||
assert.NoError(err) | ||
} | ||
} | ||
|
||
type circuitStats struct { | ||
NbConstraints, Internal, Secret, Public int | ||
} | ||
|
||
var mStats map[string][backend.PLONK + 1][ecc.BW6_633 + 1]circuitStats | ||
|
||
func checkStats(assert *test.Assert, circuitName string, nbConstraints, internal, secret, public int, curve ecc.ID, backendID backend.ID) { | ||
statsM.Lock() | ||
defer statsM.Unlock() | ||
if generateNewStats { | ||
rs := mStats[circuitName] | ||
rs[backendID][curve] = circuitStats{nbConstraints, internal, secret, public} | ||
mStats[circuitName] = rs | ||
return | ||
} | ||
if referenceStats, ok := mStats[circuitName]; !ok { | ||
assert.Log("warning: no stats for circuit", circuitName) | ||
} else { | ||
ref := referenceStats[backendID][curve] | ||
if ref.NbConstraints != nbConstraints { | ||
assert.Failf("unexpected constraint count", "expected %d nbConstraints (reference), got %d. %s, %s, %s", ref.NbConstraints, nbConstraints, circuitName, backendID.String(), curve.String()) | ||
} | ||
if ref.Internal != internal { | ||
assert.Failf("unexpected internal variable count", "expected %d internal (reference), got %d. %s, %s, %s", ref.Internal, internal, circuitName, backendID.String(), curve.String()) | ||
} | ||
if ref.Secret != secret { | ||
assert.Failf("unexpected secret variable count", "expected %d secret (reference), got %d. %s, %s, %s", ref.Secret, secret, circuitName, backendID.String(), curve.String()) | ||
} | ||
if ref.Public != public { | ||
assert.Failf("unexpected public variable count", "expected %d public (reference), got %d. %s, %s, %s", ref.Public, public, circuitName, backendID.String(), curve.String()) | ||
} | ||
} | ||
} | ||
|
||
func init() { | ||
mStats = make(map[string][backend.PLONK + 1][ecc.BW6_633 + 1]circuitStats) | ||
|
||
if !generateNewStats { | ||
fStats, err := os.Open(fileStats) | ||
if err != nil { | ||
panic(err) | ||
} | ||
decoder := gob.NewDecoder(fStats) | ||
err = decoder.Decode(&mStats) | ||
if err != nil { | ||
panic(err) | ||
} | ||
} | ||
|
||
} |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
package main | ||
|
||
import ( | ||
"flag" | ||
"fmt" | ||
"log" | ||
"regexp" | ||
"sync" | ||
|
||
"github.com/consensys/gnark/backend" | ||
"github.com/consensys/gnark/internal/stats" | ||
) | ||
|
||
var ( | ||
fSave = flag.Bool("s", false, "save new stats in file ") | ||
fFilter = flag.String("run", "", "filter runs with regexp; example 'pairing*'") | ||
) | ||
|
||
func main() { | ||
flag.Parse() | ||
|
||
var r *regexp.Regexp | ||
if *fFilter != "" { | ||
r = regexp.MustCompile(*fFilter) | ||
} | ||
|
||
s := stats.NewGlobalStats() | ||
|
||
// load reference objects | ||
// for each circuit, on each curve, on each backend | ||
// compare with reference stats | ||
snippets := stats.GetSnippets() | ||
var wg sync.WaitGroup | ||
for name, c := range snippets { | ||
if r != nil && !r.MatchString(name) { | ||
continue | ||
} | ||
wg.Add(1) | ||
go func(name string, circuit stats.Circuit) { | ||
defer wg.Done() | ||
for _, curve := range circuit.Curves { | ||
for _, backendID := range backend.Implemented() { | ||
cs, err := stats.NewSnippetStats(curve, backendID, circuit.Circuit) | ||
if err != nil { | ||
log.Fatalf("building stats for circuit %s %v", name, err) | ||
} | ||
s.Add(curve, backendID, cs, name) | ||
} | ||
} | ||
}(name, c) | ||
} | ||
wg.Wait() | ||
|
||
fmt.Println("id,curve,backend,nbConstraints,nbWires") | ||
for name, c := range snippets { | ||
if r != nil && !r.MatchString(name) { | ||
continue | ||
} | ||
ss := s.Stats[name] | ||
for _, curve := range c.Curves { | ||
for _, backendID := range backend.Implemented() { | ||
cs := ss[backendID][curve] | ||
fmt.Printf("%s,%s,%s,%d,%d\n", name, curve, backendID, cs.NbConstraints, cs.NbInternalWires) | ||
} | ||
} | ||
} | ||
|
||
if *fSave { | ||
const refPath = "../latest.stats" | ||
if err := s.Save(refPath); err != nil { | ||
log.Fatal(err) | ||
} | ||
log.Println("successfully saved new reference stats file", refPath) | ||
} | ||
|
||
} |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
package stats | ||
|
||
import ( | ||
"math" | ||
"sync" | ||
|
||
"github.com/consensys/gnark-crypto/ecc" | ||
"github.com/consensys/gnark/frontend" | ||
"github.com/consensys/gnark/std/algebra/fields_bls12377" | ||
"github.com/consensys/gnark/std/algebra/fields_bls24315" | ||
"github.com/consensys/gnark/std/algebra/sw_bls12377" | ||
"github.com/consensys/gnark/std/algebra/sw_bls24315" | ||
"github.com/consensys/gnark/std/hash/mimc" | ||
"github.com/consensys/gnark/std/math/bits" | ||
) | ||
|
||
var ( | ||
initOnce sync.Once | ||
snippets = make(map[string]Circuit) | ||
) | ||
|
||
func GetSnippets() map[string]Circuit { | ||
initOnce.Do(initSnippets) | ||
return snippets | ||
} | ||
|
||
type snippet func(api frontend.API, newVariable func() frontend.Variable) | ||
|
||
func registerSnippet(name string, snippet snippet, curves ...ecc.ID) { | ||
if _, ok := snippets[name]; ok { | ||
panic("circuit " + name + " already registered") | ||
} | ||
if len(curves) == 0 { | ||
curves = ecc.Implemented() | ||
} | ||
snippets[name] = Circuit{makeSnippetCircuit(snippet), curves} | ||
} | ||
|
||
func initSnippets() { | ||
// add api snippets | ||
registerSnippet("api/IsZero", func(api frontend.API, newVariable func() frontend.Variable) { | ||
_ = api.IsZero(newVariable()) | ||
}) | ||
|
||
registerSnippet("api/Lookup2", func(api frontend.API, newVariable func() frontend.Variable) { | ||
_ = api.Lookup2(newVariable(), newVariable(), newVariable(), newVariable(), newVariable(), newVariable()) | ||
}) | ||
|
||
registerSnippet("api/AssertIsLessOrEqual", func(api frontend.API, newVariable func() frontend.Variable) { | ||
api.AssertIsLessOrEqual(newVariable(), newVariable()) | ||
}) | ||
registerSnippet("api/AssertIsLessOrEqual/constant_bound_64_bits", func(api frontend.API, newVariable func() frontend.Variable) { | ||
bound := uint64(math.MaxUint64) | ||
api.AssertIsLessOrEqual(newVariable(), bound) | ||
}) | ||
|
||
// add std snippets | ||
registerSnippet("math/bits.ToBinary", func(api frontend.API, newVariable func() frontend.Variable) { | ||
_ = bits.ToBinary(api, newVariable()) | ||
}) | ||
registerSnippet("math/bits.ToBinary/unconstrained", func(api frontend.API, newVariable func() frontend.Variable) { | ||
_ = bits.ToBinary(api, newVariable(), bits.WithUnconstrainedOutputs()) | ||
}) | ||
registerSnippet("math/bits.ToTernary", func(api frontend.API, newVariable func() frontend.Variable) { | ||
_ = bits.ToTernary(api, newVariable()) | ||
}) | ||
registerSnippet("math/bits.ToTernary/unconstrained", func(api frontend.API, newVariable func() frontend.Variable) { | ||
_ = bits.ToTernary(api, newVariable(), bits.WithUnconstrainedOutputs()) | ||
}) | ||
registerSnippet("math/bits.ToNAF", func(api frontend.API, newVariable func() frontend.Variable) { | ||
_ = bits.ToNAF(api, newVariable()) | ||
}) | ||
registerSnippet("math/bits.ToNAF/unconstrained", func(api frontend.API, newVariable func() frontend.Variable) { | ||
_ = bits.ToNAF(api, newVariable(), bits.WithUnconstrainedOutputs()) | ||
}) | ||
|
||
registerSnippet("hash/mimc", func(api frontend.API, newVariable func() frontend.Variable) { | ||
mimc, _ := mimc.NewMiMC(api) | ||
mimc.Write(newVariable()) | ||
_ = mimc.Sum() | ||
}) | ||
|
||
registerSnippet("pairing_bls12377", func(api frontend.API, newVariable func() frontend.Variable) { | ||
ateLoop := uint64(9586122913090633729) | ||
ext := fields_bls12377.GetBLS12377ExtensionFp12(api) | ||
pairingInfo := sw_bls12377.PairingContext{AteLoop: ateLoop, Extension: ext} | ||
|
||
var dummyG1 sw_bls12377.G1Affine | ||
var dummyG2 sw_bls12377.G2Affine | ||
dummyG1.X = newVariable() | ||
dummyG1.Y = newVariable() | ||
dummyG2.X.A0 = newVariable() | ||
dummyG2.X.A1 = newVariable() | ||
dummyG2.Y.A0 = newVariable() | ||
dummyG2.Y.A1 = newVariable() | ||
|
||
var resMillerLoop fields_bls12377.E12 | ||
// e(psi0, -gamma)*e(-πC, -δ)*e(πA, πB) | ||
sw_bls12377.MillerLoop(api, dummyG1, dummyG2, &resMillerLoop, pairingInfo) | ||
|
||
// performs the final expo | ||
var resPairing fields_bls12377.E12 | ||
resPairing.FinalExponentiation(api, resMillerLoop, pairingInfo.AteLoop, pairingInfo.Extension) | ||
}, ecc.BW6_761) | ||
|
||
registerSnippet("pairing_bls24315", func(api frontend.API, newVariable func() frontend.Variable) { | ||
ateLoop := uint64(3218079743) | ||
ext := fields_bls24315.GetBLS24315ExtensionFp24(api) | ||
pairingInfo := sw_bls24315.PairingContext{AteLoop: ateLoop, Extension: ext} | ||
|
||
var dummyG1 sw_bls24315.G1Affine | ||
var dummyG2 sw_bls24315.G2Affine | ||
dummyG1.X = newVariable() | ||
dummyG1.Y = newVariable() | ||
dummyG2.X.B0.A0 = newVariable() | ||
dummyG2.X.B0.A1 = newVariable() | ||
dummyG2.X.B1.A0 = newVariable() | ||
dummyG2.X.B1.A1 = newVariable() | ||
dummyG2.Y.B0.A0 = newVariable() | ||
dummyG2.Y.B0.A1 = newVariable() | ||
dummyG2.Y.B1.A0 = newVariable() | ||
dummyG2.Y.B1.A1 = newVariable() | ||
|
||
var resMillerLoop fields_bls24315.E24 | ||
// e(psi0, -gamma)*e(-πC, -δ)*e(πA, πB) | ||
sw_bls24315.MillerLoop(api, dummyG1, dummyG2, &resMillerLoop, pairingInfo) | ||
|
||
// performs the final expo | ||
var resPairing fields_bls24315.E24 | ||
resPairing.FinalExponentiation(api, resMillerLoop, pairingInfo.AteLoop, pairingInfo.Extension) | ||
}, ecc.BW6_633) | ||
|
||
} | ||
|
||
type snippetCircuit struct { | ||
V [1024]frontend.Variable | ||
s snippet | ||
vIndex int | ||
} | ||
|
||
func (d *snippetCircuit) Define(api frontend.API) error { | ||
d.s(api, d.newVariable) | ||
return nil | ||
} | ||
|
||
func (d *snippetCircuit) newVariable() frontend.Variable { | ||
d.vIndex++ | ||
return d.V[(d.vIndex-1)%len(d.V)] | ||
} | ||
|
||
func makeSnippetCircuit(s snippet) frontend.Circuit { | ||
return &snippetCircuit{s: s} | ||
} |
Oops, something went wrong.