Skip to content

Commit

Permalink
perf(kzg): remove folding and shrinked scalars options in MSM
Browse files Browse the repository at this point in the history
  • Loading branch information
yelhousni committed Mar 9, 2024
1 parent 9fc5c14 commit 1b7c6d0
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 69 deletions.
57 changes: 2 additions & 55 deletions std/algebra/emulated/sw_emulated/point.go
Original file line number Diff line number Diff line change
Expand Up @@ -1096,57 +1096,6 @@ func (c *Curve[B, S]) jointScalarMulGLVUnsafe(Q, R *AffinePoint[B], s, t *emulat

}

// scalarBitsMulGeneric computes [s]p and returns it where sBits is the bit decomposition of s. It doesn't modify p nor sBits.
// ⚠️ p must not be (0,0) and sBits not [0,...,0], unless [algopts.WithCompleteArithmetic] option is set.
func (c *Curve[B, S]) scalarBitsMulGeneric(p *AffinePoint[B], sBits []frontend.Variable, opts ...algopts.AlgebraOption) *AffinePoint[B] {
cfg, err := algopts.NewConfig(opts...)
if err != nil {
panic(fmt.Sprintf("parse opts: %v", err))
}
var selector frontend.Variable
if cfg.CompleteArithmetic {
// if p=(0,0) we assign a dummy (0,1) to p and continue
selector = c.api.And(c.baseApi.IsZero(&p.X), c.baseApi.IsZero(&p.Y))
one := c.baseApi.One()
p = c.Select(selector, &AffinePoint[B]{X: *one, Y: *one}, p)
}

var st S
n := st.Modulus().BitLen()
if cfg.NbScalarBits > 2 && cfg.NbScalarBits < n {
n = cfg.NbScalarBits
}

// i = 1
Rb := c.triple(p)
R0 := c.Select(sBits[1], Rb, p)
R1 := c.Select(sBits[1], p, Rb)

for i := 2; i < n-1; i++ {
Rb = c.doubleAndAddSelect(sBits[i], R0, R1)
R0 = c.Select(sBits[i], Rb, R0)
R1 = c.Select(sBits[i], R1, Rb)
}

// i = n-1
Rb = c.doubleAndAddSelect(sBits[n-1], R0, R1)
R0 = c.Select(sBits[n-1], Rb, R0)

// i = 0
// we use AddUnified instead of Add. This is because:
// - when s=0 then R0=P and AddUnified(P, -P) = (0,0). We return (0,0).
// - when s=1 then R0=P AddUnified(Q, -Q) is well defined. We return R0=P.
R0 = c.Select(sBits[0], R0, c.AddUnified(R0, c.Neg(p)))

if cfg.CompleteArithmetic {
// if p=(0,0), return (0,0)
zero := c.baseApi.Zero()
R0 = c.Select(selector, &AffinePoint[B]{X: *zero, Y: *zero}, R0)
}

return R0
}

// ScalarMulBase computes [s]g and returns it where g is the fixed curve generator. It doesn't modify p nor s.
//
// ScalarMul calls scalarMulBaseGeneric or scalarMulGLV depending on whether an efficient endomorphism is available.
Expand Down Expand Up @@ -1278,12 +1227,10 @@ func (c *Curve[B, S]) MultiScalarMul(p []*AffinePoint[B], s []*emulated.Element[
return nil, fmt.Errorf("need scalar for folding")
}
gamma := s[0]
gamma = c.scalarApi.Reduce(gamma)
gammaBits := c.scalarApi.ToBits(gamma)
res := c.scalarBitsMulGeneric(p[len(p)-1], gammaBits, opts...)
res := c.ScalarMul(p[len(p)-1], gamma, opts...)
for i := len(p) - 2; i > 0; i-- {
res = addFn(p[i], res)
res = c.scalarBitsMulGeneric(res, gammaBits, opts...)
res = c.ScalarMul(res, gamma, opts...)
}
res = addFn(p[0], res)
return res, nil
Expand Down
27 changes: 13 additions & 14 deletions std/commitments/kzg/verifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,6 @@ func (v *Verifier[FR, G1El, G2El, GTEl]) FoldProofsMultiPoint(digests []Commitme
seed := whSnark.Sum()
binSeed := bits.ToBinary(v.api, seed, bits.WithNbDigits(fr.Modulus().BitLen()))
randomNumbers[1] = v.scalarApi.FromBits(binSeed...)
nbScalarBits := ((v.api.Compiler().FieldBitLen()+7)/8 - 1) * 8

for i := 2; i < len(randomNumbers); i++ {
// TODO use real random numbers, follow the solidity smart contract to know which variables are used as seed
Expand All @@ -521,10 +520,11 @@ func (v *Verifier[FR, G1El, G2El, GTEl]) FoldProofsMultiPoint(digests []Commitme
for i := 0; i < len(randomNumbers); i++ {
quotients[i] = &proofs[i].Quotient
}
foldedQuotients, err := v.curve.MultiScalarMul(quotients, []*emulated.Element[FR]{randomNumbers[1]}, algopts.WithFoldingScalarMul(), algopts.WithNbScalarBits(nbScalarBits))
foldedQuotients, err := v.curve.MultiScalarMul(quotients[1:], randomNumbers[1:])
if err != nil {
return nil, nil, fmt.Errorf("fold quotients: %w", err)
}
foldedQuotients = v.curve.Add(foldedQuotients, quotients[0])
foldedPointsQuotients, err := v.curve.MultiScalarMul(quotients, randomPointNumbers)
if err != nil {
return nil, nil, fmt.Errorf("fold point quotients: %w", err)
Expand Down Expand Up @@ -594,7 +594,6 @@ func (v *Verifier[FR, G1El, G2El, GTEl]) FoldProof(digests []Commitment[G1El], b
var retP OpeningProof[FR, G1El]
var retC Commitment[G1El]
// we assume the short hash output size is full byte fitting into the modulus length.
nbScalarBits := ((v.api.Compiler().FieldBitLen()+7)/8 - 1) * 8
nbDigests := len(digests)

// check consistency between numbers of claims vs number of digests
Expand All @@ -607,27 +606,27 @@ func (v *Verifier[FR, G1El, G2El, GTEl]) FoldProof(digests []Commitment[G1El], b
if err != nil {
return retP, retC, fmt.Errorf("derive gamma: %w", err)
}

// gammai = [1,γ,γ²,..,γⁿ⁻¹]
gammai := make([]*emulated.Element[FR], nbDigests)
gammai[0] = v.scalarApi.One()
if nbDigests > 1 {
gammai[1] = gamma
}
for i := 2; i < nbDigests; i++ {
gammai[i] = v.scalarApi.Mul(gammai[i-1], gamma)
}
// fold the claimed values and digests
// compute ∑ᵢ γ^i C_i = C_0 + γ(C_1 + γ(C2 ...)), allowing to bound the scalar multiplication iterations
digestsP := make([]*G1El, len(digests))
for i := range digestsP {
digestsP[i] = &digests[i].G1El
}
foldedDigests, err := v.curve.MultiScalarMul(digestsP, []*emulated.Element[FR]{gamma}, algopts.WithNbScalarBits(nbScalarBits), algopts.WithFoldingScalarMul())
foldedDigests, err := v.curve.MultiScalarMul(digestsP[1:], gammai[1:])
if err != nil {
return retP, retC, fmt.Errorf("multi scalar mul: %w", err)
}
foldedDigests = v.curve.Add(foldedDigests, digestsP[0])

// gammai = [1,γ,γ²,..,γⁿ⁻¹]
gammai := make([]*emulated.Element[FR], nbDigests)
gammai[0] = v.scalarApi.One()
if nbDigests > 1 {
gammai[1] = gamma
}
for i := 2; i < nbDigests; i++ {
gammai[i] = v.scalarApi.Mul(gammai[i-1], gamma)
}
foldedEvaluations := &batchOpeningProof.ClaimedValues[0]
for i := 1; i < nbDigests; i++ {
tmp := v.scalarApi.Mul(&batchOpeningProof.ClaimedValues[i], gammai[i])
Expand Down

0 comments on commit 1b7c6d0

Please sign in to comment.