-
Notifications
You must be signed in to change notification settings - Fork 418
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Perf: Poseidon2 GKR circuit #1410
Conversation
…kr-poseidon2-breakup-sbox
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I merged the current Poseidon2 PR and fixed the conflicts. When that PR is done with the issues addressed then we can merge once again.
I checked the GKR gates and they seem correct. However, I'm not able to fully follow how the full permutation is implemented, the gate registrations etc. are inlined and Poseidon2 permutation steps are unrolled.
See the comments and I would come back to review the PR once again, maybe we can have a look at the unrolled methods in a call.
Currently the PR is only hardcoded for BLS12-377, which I think for start is fine as it is really difficult to support different native elements. I actually encountered the same issue when implementing non-native sumcheck where I implemented generic arithEngine
interface which can be used to perform operations on different types (see https://github.com/Consensys/gnark/blob/master/std/recursion/sumcheck/arithengine.go), but it was exploratory approahc.
See also the suggested edit below to resolve one TODO in the test file.
Suggested edit: diff --git a/std/gkr/api_test.go b/std/gkr/api_test.go
index 3519f8e1..55dad8fb 100644
--- a/std/gkr/api_test.go
+++ b/std/gkr/api_test.go
@@ -16,8 +16,6 @@ import (
bw6761 "github.com/consensys/gnark/constraint/bw6-761"
"github.com/consensys/gnark/test"
- "github.com/consensys/gnark-crypto/kzg"
- "github.com/consensys/gnark/backend/plonk"
bn254 "github.com/consensys/gnark/constraint/bn254"
"github.com/stretchr/testify/require"
@@ -26,15 +24,12 @@ import (
"github.com/consensys/gnark-crypto/ecc/bn254/fr/gkr"
bn254MiMC "github.com/consensys/gnark-crypto/ecc/bn254/fr/mimc"
"github.com/consensys/gnark/backend/groth16"
- "github.com/consensys/gnark/backend/witness"
"github.com/consensys/gnark/constraint"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/frontend/cs/r1cs"
- "github.com/consensys/gnark/frontend/cs/scs"
stdHash "github.com/consensys/gnark/std/hash"
"github.com/consensys/gnark/std/hash/mimc"
test_vector_utils "github.com/consensys/gnark/std/internal/test_vectors_utils"
- "github.com/consensys/gnark/test/unsafekzg"
)
// compressThreshold --> if linear expressions are larger than this, the frontend will introduce
@@ -69,6 +64,7 @@ func (c *doubleNoDependencyCircuit) Define(api frontend.API) error {
}
func TestDoubleNoDependencyCircuit(t *testing.T) {
+ assert := test.NewAssert(t)
xValuess := [][]frontend.Variable{
{1, 1},
@@ -77,12 +73,13 @@ func TestDoubleNoDependencyCircuit(t *testing.T) {
hashes := []string{"-1", "-20"}
- for _, xValues := range xValuess {
+ for i, xValues := range xValuess {
for _, hashName := range hashes {
assignment := doubleNoDependencyCircuit{X: xValues}
circuit := doubleNoDependencyCircuit{X: make([]frontend.Variable, len(xValues)), hashName: hashName}
-
- test.NewAssert(t).CheckCircuit(&circuit, test.WithValidAssignment(&assignment))
+ assert.Run(func(assert *test.Assert) {
+ assert.CheckCircuit(&circuit, test.WithValidAssignment(&assignment), test.WithCurves(ecc.BN254))
+ }, fmt.Sprintf("xValue=%d/hash=%s", i, hashName))
}
}
@@ -115,6 +112,7 @@ func (c *sqNoDependencyCircuit) Define(api frontend.API) error {
}
func TestSqNoDependencyCircuit(t *testing.T) {
+ assert := test.NewAssert(t)
xValuess := [][]frontend.Variable{
{1, 1},
@@ -123,12 +121,13 @@ func TestSqNoDependencyCircuit(t *testing.T) {
hashes := []string{"-1", "-20"}
- for _, xValues := range xValuess {
+ for i, xValues := range xValuess {
for _, hashName := range hashes {
assignment := sqNoDependencyCircuit{X: xValues}
circuit := sqNoDependencyCircuit{X: make([]frontend.Variable, len(xValues)), hashName: hashName}
- testGroth16(t, &circuit, &assignment)
- testPlonk(t, &circuit, &assignment)
+ assert.Run(func(assert *test.Assert) {
+ assert.CheckCircuit(&circuit, test.WithValidAssignment(&assignment), test.WithCurves(ecc.BN254))
+ }, fmt.Sprintf("xValues=%d/hash=%s", i, hashName))
}
}
}
@@ -168,6 +167,7 @@ func (c *mulNoDependencyCircuit) Define(api frontend.API) error {
}
func TestMulNoDependency(t *testing.T) {
+ assert := test.NewAssert(t)
xValuess := [][]frontend.Variable{
{1, 2},
}
@@ -189,9 +189,9 @@ func TestMulNoDependency(t *testing.T) {
Y: make([]frontend.Variable, len(yValuess[i])),
hashName: hashName,
}
-
- testGroth16(t, &circuit, &assignment)
- testPlonk(t, &circuit, &assignment)
+ assert.Run(func(assert *test.Assert) {
+ assert.CheckCircuit(&circuit, test.WithValidAssignment(&assignment), test.WithCurves(ecc.BN254))
+ }, fmt.Sprintf("xValues=%d/hash=%s", i, hashName))
}
}
}
@@ -240,14 +240,13 @@ func (c *mulWithDependencyCircuit) Define(api frontend.API) error {
}
func TestSolveMulWithDependency(t *testing.T) {
+ assert := test.NewAssert(t)
assignment := mulWithDependencyCircuit{
XLast: 1,
Y: []frontend.Variable{3, 2},
}
circuit := mulWithDependencyCircuit{Y: make([]frontend.Variable, len(assignment.Y)), hashName: "-20"}
-
- testGroth16(t, &circuit, &assignment)
- testPlonk(t, &circuit, &assignment)
+ assert.CheckCircuit(&circuit, test.WithValidAssignment(&assignment), test.WithCurves(ecc.BN254))
}
func TestApiMul(t *testing.T) {
@@ -387,54 +386,6 @@ func (c *benchMiMCMerkleTreeCircuit) Define(api frontend.API) error {
return solution.Verify("-20", challenge)
}
-// TODO @Tabaie just try using IsSolved instead?
-func testGroth16(t *testing.T, circuit, assignment frontend.Circuit) {
- cs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, circuit, frontend.WithCompressThreshold(compressThreshold))
- require.NoError(t, err)
- var (
- fullWitness witness.Witness
- publicWitness witness.Witness
- pk groth16.ProvingKey
- vk groth16.VerifyingKey
- proof groth16.Proof
- )
- fullWitness, err = frontend.NewWitness(assignment, ecc.BN254.ScalarField())
- require.NoError(t, err)
- publicWitness, err = fullWitness.Public()
- require.NoError(t, err)
- pk, vk, err = groth16.Setup(cs)
- require.NoError(t, err)
- proof, err = groth16.Prove(cs, pk, fullWitness)
- require.NoError(t, err)
- err = groth16.Verify(proof, vk, publicWitness)
- require.NoError(t, err)
-}
-
-func testPlonk(t *testing.T, circuit, assignment frontend.Circuit) {
- cs, err := frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, circuit, frontend.WithCompressThreshold(compressThreshold))
- require.NoError(t, err)
- var (
- fullWitness witness.Witness
- publicWitness witness.Witness
- pk plonk.ProvingKey
- vk plonk.VerifyingKey
- proof plonk.Proof
- kzgSrs kzg.SRS
- )
- fullWitness, err = frontend.NewWitness(assignment, ecc.BN254.ScalarField())
- require.NoError(t, err)
- publicWitness, err = fullWitness.Public()
- require.NoError(t, err)
- kzgSrs, srsLagrange, err := unsafekzg.NewSRS(cs)
- require.NoError(t, err)
- pk, vk, err = plonk.Setup(cs, kzgSrs, srsLagrange)
- require.NoError(t, err)
- proof, err = plonk.Prove(cs, pk, fullWitness)
- require.NoError(t, err)
- err = plonk.Verify(proof, vk, publicWitness)
- require.NoError(t, err)
-}
-
func registerMiMC() {
bn254.RegisterHashBuilder("mimc", func() hash.Hash {
return bn254MiMC.NewMiMC()
@@ -646,19 +597,21 @@ func BenchmarkMiMCNoGkrFullDepthSolve(b *testing.B) {
}
func TestMiMCFullDepthNoDepSolve(t *testing.T) {
+ assert := test.NewAssert(t)
registerMiMC()
for i := 0; i < 100; i++ {
circuit, assignment := mimcNoDepCircuits(5, 1<<2, "-20")
- testGroth16(t, circuit, assignment)
- testPlonk(t, circuit, assignment)
+ assert.Run(func(assert *test.Assert) {
+ assert.CheckCircuit(circuit, test.WithValidAssignment(assignment), test.WithCurves(ecc.BN254))
+ }, fmt.Sprintf("i=%d", i))
}
}
func TestMiMCFullDepthNoDepSolveWithMiMCHash(t *testing.T) {
+ assert := test.NewAssert(t)
registerMiMC()
circuit, assignment := mimcNoDepCircuits(5, 1<<2, "mimc")
- testGroth16(t, circuit, assignment)
- testPlonk(t, circuit, assignment)
+ assert.CheckCircuit(circuit, test.WithValidAssignment(assignment), test.WithCurves(ecc.BN254))
}
func mimcNoGkrCircuits(mimcDepth, nbInstances int) (circuit, assignment frontend.Circuit) {
diff --git a/std/gkr/testing.go b/std/gkr/testing.go
index 5f824cac..5a3f7337 100644
--- a/std/gkr/testing.go
+++ b/std/gkr/testing.go
@@ -3,12 +3,13 @@ package gkr
import (
"errors"
"fmt"
+ "math/big"
+
"github.com/consensys/gnark-crypto/ecc"
frBls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
gkrBls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/gkr"
hint "github.com/consensys/gnark/constraint/solver"
"github.com/consensys/gnark/frontend"
- "math/big"
)
// SolveAll IS A TEST FUNCTION USED ONLY TO DEBUG a GKR circuit
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The GKR circuit implementation is functionally correct, but imo still difficult to follow and not future-proof. For example:
- it is hardcoded for the default parameters of BLS12-377 and wouldn't work with other curves.
- the different grouping of rounds compared to naive in-circuit implementation requires manually thinking about first and last rounds.
- currently is hardcoded for width=2, but when we start supporting compiling circuits to small fields, then this wouldn't allow using GKR poseidon2
- the gate registration is inlined with gate evaluation
- some utility methods are defined using function-local scope (
fullRounds
etc.)
But it seems to be correct, lets go ahead with it for now and I'll try to make a bigger refactor so that it would be more generic in a separate PR (also requires other changes in gnark-crypto to make Parameters
in gnark-crypto interface not specific types etc.). Only things before merging - could you change how you test the circuit against real solver in addition to test solver and refactoring the BenchmarkGkrPoseidon
method. See the inlined comments.
It would also nice to have some basic documentation for the NewGkrPermutation
function to describe what it is and what is required to use (register hints etc.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! Looks good now! I have the notes for refactoring locally for now, will try to make a coherent issue out of it.
Co-authored-by: Ivo Kubjas <[email protected]>
Companion to Consensys/gnark-crypto#628