Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Perf: Poseidon2 GKR circuit #1410

Merged
merged 57 commits into from
Feb 24, 2025
Merged

Perf: Poseidon2 GKR circuit #1410

merged 57 commits into from
Feb 24, 2025

Conversation

Tabaie
Copy link
Contributor

@Tabaie Tabaie commented Feb 4, 2025

@Tabaie Tabaie marked this pull request as ready for review February 12, 2025 15:14
Copy link
Collaborator

@ivokub ivokub left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I merged the current Poseidon2 PR and fixed the conflicts. When that PR is done with the issues addressed then we can merge once again.

I checked the GKR gates and they seem correct. However, I'm not able to fully follow how the full permutation is implemented, the gate registrations etc. are inlined and Poseidon2 permutation steps are unrolled.

See the comments and I would come back to review the PR once again, maybe we can have a look at the unrolled methods in a call.

Currently the PR is only hardcoded for BLS12-377, which I think for start is fine as it is really difficult to support different native elements. I actually encountered the same issue when implementing non-native sumcheck where I implemented generic arithEngine interface which can be used to perform operations on different types (see https://github.com/Consensys/gnark/blob/master/std/recursion/sumcheck/arithengine.go), but it was exploratory approahc.

See also the suggested edit below to resolve one TODO in the test file.

@ivokub
Copy link
Collaborator

ivokub commented Feb 13, 2025

Suggested edit:

diff --git a/std/gkr/api_test.go b/std/gkr/api_test.go
index 3519f8e1..55dad8fb 100644
--- a/std/gkr/api_test.go
+++ b/std/gkr/api_test.go
@@ -16,8 +16,6 @@ import (
 	bw6761 "github.com/consensys/gnark/constraint/bw6-761"
 	"github.com/consensys/gnark/test"
 
-	"github.com/consensys/gnark-crypto/kzg"
-	"github.com/consensys/gnark/backend/plonk"
 	bn254 "github.com/consensys/gnark/constraint/bn254"
 	"github.com/stretchr/testify/require"
 
@@ -26,15 +24,12 @@ import (
 	"github.com/consensys/gnark-crypto/ecc/bn254/fr/gkr"
 	bn254MiMC "github.com/consensys/gnark-crypto/ecc/bn254/fr/mimc"
 	"github.com/consensys/gnark/backend/groth16"
-	"github.com/consensys/gnark/backend/witness"
 	"github.com/consensys/gnark/constraint"
 	"github.com/consensys/gnark/frontend"
 	"github.com/consensys/gnark/frontend/cs/r1cs"
-	"github.com/consensys/gnark/frontend/cs/scs"
 	stdHash "github.com/consensys/gnark/std/hash"
 	"github.com/consensys/gnark/std/hash/mimc"
 	test_vector_utils "github.com/consensys/gnark/std/internal/test_vectors_utils"
-	"github.com/consensys/gnark/test/unsafekzg"
 )
 
 // compressThreshold --> if linear expressions are larger than this, the frontend will introduce
@@ -69,6 +64,7 @@ func (c *doubleNoDependencyCircuit) Define(api frontend.API) error {
 }
 
 func TestDoubleNoDependencyCircuit(t *testing.T) {
+	assert := test.NewAssert(t)
 
 	xValuess := [][]frontend.Variable{
 		{1, 1},
@@ -77,12 +73,13 @@ func TestDoubleNoDependencyCircuit(t *testing.T) {
 
 	hashes := []string{"-1", "-20"}
 
-	for _, xValues := range xValuess {
+	for i, xValues := range xValuess {
 		for _, hashName := range hashes {
 			assignment := doubleNoDependencyCircuit{X: xValues}
 			circuit := doubleNoDependencyCircuit{X: make([]frontend.Variable, len(xValues)), hashName: hashName}
-
-			test.NewAssert(t).CheckCircuit(&circuit, test.WithValidAssignment(&assignment))
+			assert.Run(func(assert *test.Assert) {
+				assert.CheckCircuit(&circuit, test.WithValidAssignment(&assignment), test.WithCurves(ecc.BN254))
+			}, fmt.Sprintf("xValue=%d/hash=%s", i, hashName))
 
 		}
 	}
@@ -115,6 +112,7 @@ func (c *sqNoDependencyCircuit) Define(api frontend.API) error {
 }
 
 func TestSqNoDependencyCircuit(t *testing.T) {
+	assert := test.NewAssert(t)
 
 	xValuess := [][]frontend.Variable{
 		{1, 1},
@@ -123,12 +121,13 @@ func TestSqNoDependencyCircuit(t *testing.T) {
 
 	hashes := []string{"-1", "-20"}
 
-	for _, xValues := range xValuess {
+	for i, xValues := range xValuess {
 		for _, hashName := range hashes {
 			assignment := sqNoDependencyCircuit{X: xValues}
 			circuit := sqNoDependencyCircuit{X: make([]frontend.Variable, len(xValues)), hashName: hashName}
-			testGroth16(t, &circuit, &assignment)
-			testPlonk(t, &circuit, &assignment)
+			assert.Run(func(assert *test.Assert) {
+				assert.CheckCircuit(&circuit, test.WithValidAssignment(&assignment), test.WithCurves(ecc.BN254))
+			}, fmt.Sprintf("xValues=%d/hash=%s", i, hashName))
 		}
 	}
 }
@@ -168,6 +167,7 @@ func (c *mulNoDependencyCircuit) Define(api frontend.API) error {
 }
 
 func TestMulNoDependency(t *testing.T) {
+	assert := test.NewAssert(t)
 	xValuess := [][]frontend.Variable{
 		{1, 2},
 	}
@@ -189,9 +189,9 @@ func TestMulNoDependency(t *testing.T) {
 				Y:        make([]frontend.Variable, len(yValuess[i])),
 				hashName: hashName,
 			}
-
-			testGroth16(t, &circuit, &assignment)
-			testPlonk(t, &circuit, &assignment)
+			assert.Run(func(assert *test.Assert) {
+				assert.CheckCircuit(&circuit, test.WithValidAssignment(&assignment), test.WithCurves(ecc.BN254))
+			}, fmt.Sprintf("xValues=%d/hash=%s", i, hashName))
 		}
 	}
 }
@@ -240,14 +240,13 @@ func (c *mulWithDependencyCircuit) Define(api frontend.API) error {
 }
 
 func TestSolveMulWithDependency(t *testing.T) {
+	assert := test.NewAssert(t)
 	assignment := mulWithDependencyCircuit{
 		XLast: 1,
 		Y:     []frontend.Variable{3, 2},
 	}
 	circuit := mulWithDependencyCircuit{Y: make([]frontend.Variable, len(assignment.Y)), hashName: "-20"}
-
-	testGroth16(t, &circuit, &assignment)
-	testPlonk(t, &circuit, &assignment)
+	assert.CheckCircuit(&circuit, test.WithValidAssignment(&assignment), test.WithCurves(ecc.BN254))
 }
 
 func TestApiMul(t *testing.T) {
@@ -387,54 +386,6 @@ func (c *benchMiMCMerkleTreeCircuit) Define(api frontend.API) error {
 	return solution.Verify("-20", challenge)
 }
 
-// TODO @Tabaie just try using IsSolved instead?
-func testGroth16(t *testing.T, circuit, assignment frontend.Circuit) {
-	cs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, circuit, frontend.WithCompressThreshold(compressThreshold))
-	require.NoError(t, err)
-	var (
-		fullWitness   witness.Witness
-		publicWitness witness.Witness
-		pk            groth16.ProvingKey
-		vk            groth16.VerifyingKey
-		proof         groth16.Proof
-	)
-	fullWitness, err = frontend.NewWitness(assignment, ecc.BN254.ScalarField())
-	require.NoError(t, err)
-	publicWitness, err = fullWitness.Public()
-	require.NoError(t, err)
-	pk, vk, err = groth16.Setup(cs)
-	require.NoError(t, err)
-	proof, err = groth16.Prove(cs, pk, fullWitness)
-	require.NoError(t, err)
-	err = groth16.Verify(proof, vk, publicWitness)
-	require.NoError(t, err)
-}
-
-func testPlonk(t *testing.T, circuit, assignment frontend.Circuit) {
-	cs, err := frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, circuit, frontend.WithCompressThreshold(compressThreshold))
-	require.NoError(t, err)
-	var (
-		fullWitness   witness.Witness
-		publicWitness witness.Witness
-		pk            plonk.ProvingKey
-		vk            plonk.VerifyingKey
-		proof         plonk.Proof
-		kzgSrs        kzg.SRS
-	)
-	fullWitness, err = frontend.NewWitness(assignment, ecc.BN254.ScalarField())
-	require.NoError(t, err)
-	publicWitness, err = fullWitness.Public()
-	require.NoError(t, err)
-	kzgSrs, srsLagrange, err := unsafekzg.NewSRS(cs)
-	require.NoError(t, err)
-	pk, vk, err = plonk.Setup(cs, kzgSrs, srsLagrange)
-	require.NoError(t, err)
-	proof, err = plonk.Prove(cs, pk, fullWitness)
-	require.NoError(t, err)
-	err = plonk.Verify(proof, vk, publicWitness)
-	require.NoError(t, err)
-}
-
 func registerMiMC() {
 	bn254.RegisterHashBuilder("mimc", func() hash.Hash {
 		return bn254MiMC.NewMiMC()
@@ -646,19 +597,21 @@ func BenchmarkMiMCNoGkrFullDepthSolve(b *testing.B) {
 }
 
 func TestMiMCFullDepthNoDepSolve(t *testing.T) {
+	assert := test.NewAssert(t)
 	registerMiMC()
 	for i := 0; i < 100; i++ {
 		circuit, assignment := mimcNoDepCircuits(5, 1<<2, "-20")
-		testGroth16(t, circuit, assignment)
-		testPlonk(t, circuit, assignment)
+		assert.Run(func(assert *test.Assert) {
+			assert.CheckCircuit(circuit, test.WithValidAssignment(assignment), test.WithCurves(ecc.BN254))
+		}, fmt.Sprintf("i=%d", i))
 	}
 }
 
 func TestMiMCFullDepthNoDepSolveWithMiMCHash(t *testing.T) {
+	assert := test.NewAssert(t)
 	registerMiMC()
 	circuit, assignment := mimcNoDepCircuits(5, 1<<2, "mimc")
-	testGroth16(t, circuit, assignment)
-	testPlonk(t, circuit, assignment)
+	assert.CheckCircuit(circuit, test.WithValidAssignment(assignment), test.WithCurves(ecc.BN254))
 }
 
 func mimcNoGkrCircuits(mimcDepth, nbInstances int) (circuit, assignment frontend.Circuit) {
diff --git a/std/gkr/testing.go b/std/gkr/testing.go
index 5f824cac..5a3f7337 100644
--- a/std/gkr/testing.go
+++ b/std/gkr/testing.go
@@ -3,12 +3,13 @@ package gkr
 import (
 	"errors"
 	"fmt"
+	"math/big"
+
 	"github.com/consensys/gnark-crypto/ecc"
 	frBls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
 	gkrBls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/gkr"
 	hint "github.com/consensys/gnark/constraint/solver"
 	"github.com/consensys/gnark/frontend"
-	"math/big"
 )
 
 // SolveAll IS A TEST FUNCTION USED ONLY TO DEBUG a GKR circuit

Copy link
Collaborator

@ivokub ivokub left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The GKR circuit implementation is functionally correct, but imo still difficult to follow and not future-proof. For example:

  • it is hardcoded for the default parameters of BLS12-377 and wouldn't work with other curves.
  • the different grouping of rounds compared to naive in-circuit implementation requires manually thinking about first and last rounds.
  • currently is hardcoded for width=2, but when we start supporting compiling circuits to small fields, then this wouldn't allow using GKR poseidon2
  • the gate registration is inlined with gate evaluation
  • some utility methods are defined using function-local scope (fullRounds etc.)

But it seems to be correct, lets go ahead with it for now and I'll try to make a bigger refactor so that it would be more generic in a separate PR (also requires other changes in gnark-crypto to make Parameters in gnark-crypto interface not specific types etc.). Only things before merging - could you change how you test the circuit against real solver in addition to test solver and refactoring the BenchmarkGkrPoseidon method. See the inlined comments.

It would also nice to have some basic documentation for the NewGkrPermutation function to describe what it is and what is required to use (register hints etc.)

Copy link
Collaborator

@ivokub ivokub left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Looks good now! I have the notes for refactoring locally for now, will try to make a coherent issue out of it.

@Tabaie Tabaie merged commit aab575e into master Feb 24, 2025
5 checks passed
@Tabaie Tabaie deleted the perf/gkr-poseidon2-breakup-sbox branch February 24, 2025 14:38
lucasmenendez pushed a commit to vocdoni/gnark-no-assert that referenced this pull request Feb 27, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants