Skip to content

Commit cdf5bde

Browse files
authored
fix(counters): smt depth fix and test (#1695)
* fix(counters): smt depth fix and test * fix(smt): set depth and get depth correctly * fix(smt): depth update
1 parent e393448 commit cdf5bde

File tree

8 files changed

+317
-163
lines changed

8 files changed

+317
-163
lines changed

.github/workflows/test-unwinds.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ on:
88
workflow_dispatch:
99

1010
jobs:
11-
fixing-unwinds-tests:
11+
unwind-tests:
1212
runs-on: ubuntu-22.04
1313

1414
steps:

smt/pkg/db/mdbx.go

+7-4
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ const TableAccountValues = "HermezSmtAccountValues"
3232
const TableMetadata = "HermezSmtMetadata"
3333
const TableHashKey = "HermezSmtHashKey"
3434

35+
const MetaLastRoot = "lastRoot"
36+
const MetaDepth = "depth"
37+
3538
var HermezSmtTables = []string{TableSmt, TableStats, TableAccountValues, TableMetadata, TableHashKey}
3639

3740
type EriDb struct {
@@ -219,7 +222,7 @@ func (m *EriDb) RollbackBatch() {
219222
}
220223

221224
func (m *EriRoDb) GetLastRoot() (*big.Int, error) {
222-
data, err := m.kvTxRo.GetOne(TableStats, []byte("lastRoot"))
225+
data, err := m.kvTxRo.GetOne(TableStats, []byte(MetaLastRoot))
223226
if err != nil {
224227
return big.NewInt(0), err
225228
}
@@ -233,11 +236,11 @@ func (m *EriRoDb) GetLastRoot() (*big.Int, error) {
233236

234237
func (m *EriDb) SetLastRoot(r *big.Int) error {
235238
v := utils.ConvertBigIntToHex(r)
236-
return m.tx.Put(TableStats, []byte("lastRoot"), []byte(v))
239+
return m.tx.Put(TableStats, []byte(MetaLastRoot), []byte(v))
237240
}
238241

239242
func (m *EriRoDb) GetDepth() (uint8, error) {
240-
data, err := m.kvTxRo.GetOne(TableStats, []byte("depth"))
243+
data, err := m.kvTxRo.GetOne(TableStats, []byte(MetaDepth))
241244
if err != nil {
242245
return 0, err
243246
}
@@ -250,7 +253,7 @@ func (m *EriRoDb) GetDepth() (uint8, error) {
250253
}
251254

252255
func (m *EriDb) SetDepth(depth uint8) error {
253-
return m.tx.Put(TableStats, []byte("lastRoot"), []byte{depth})
256+
return m.tx.Put(TableStats, []byte(MetaDepth), []byte{depth})
254257
}
255258

256259
func (m *EriRoDb) Get(key utils.NodeKey) (utils.NodeValue12, error) {

smt/pkg/smt/smt_create.go

+12-9
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,14 @@ func (s *SMT) GenerateFromKVBulk(ctx context.Context, logPrefix string, nodeKeys
6161

6262
var buildSmtLoopErr error
6363
var rootNode *SmtNode
64+
var maxDepth int
6465
tempTreeBuildStart := time.Now()
6566
leafValueMap := sync.Map{}
6667
accountValuesReadChan := make(chan *utils.NodeValue8, 1024)
6768
go func() {
6869
defer wg.Done()
6970
defer deletesWorker.Stop()
70-
rootNode, buildSmtLoopErr = runBuildSmtLoop(s, logPrefix, nodeKeys, &leafValueMap, deletesWorker, accountValuesReadChan)
71+
rootNode, maxDepth, buildSmtLoopErr = runBuildSmtLoop(s, logPrefix, nodeKeys, &leafValueMap, deletesWorker, accountValuesReadChan)
7172
}()
7273

7374
// startBuildSmtLoopDbCompanionLoop is blocking operation. It continue only when the last result is saved
@@ -124,10 +125,14 @@ func (s *SMT) GenerateFromKVBulk(ctx context.Context, logPrefix string, nodeKeys
124125
return [4]uint64{}, err
125126
}
126127

128+
if err := s.updateDepth(maxDepth); err != nil {
129+
return [4]uint64{}, err
130+
}
131+
127132
return finalRoot, nil
128133
}
129134

130-
func runBuildSmtLoop(s *SMT, logPrefix string, nodeKeys []utils.NodeKey, leafValueMap *sync.Map, deletesWorker *utils.Worker, accountValuesReadChan <-chan *utils.NodeValue8) (*SmtNode, error) {
135+
func runBuildSmtLoop(s *SMT, logPrefix string, nodeKeys []utils.NodeKey, leafValueMap *sync.Map, deletesWorker *utils.Worker, accountValuesReadChan <-chan *utils.NodeValue8) (*SmtNode, int, error) {
131136
totalKeysCount := len(nodeKeys)
132137
insertedKeysCount := uint64(0)
133138
maxReachedLevel := 0
@@ -148,7 +153,7 @@ func runBuildSmtLoop(s *SMT, logPrefix string, nodeKeys []utils.NodeKey, leafVal
148153
keys := k.GetPath()
149154
vPointer := <-accountValuesReadChan
150155
if vPointer == nil {
151-
return nil, fmt.Errorf("the actual error is returned by main DB thread")
156+
return nil, 0, fmt.Errorf("the actual error is returned by main DB thread")
152157
}
153158
v := *vPointer
154159
leafValueMap.Store(k, &v)
@@ -202,7 +207,7 @@ func runBuildSmtLoop(s *SMT, logPrefix string, nodeKeys []utils.NodeKey, leafVal
202207
//sanity check - new leaf should be on the right side
203208
//otherwise something went wrong
204209
if leaf0.rKey[level2] != 0 || keys[level2+level] != 1 {
205-
return nil, fmt.Errorf(
210+
return nil, 0, fmt.Errorf(
206211
"leaf insert error. new leaf should be on the right of the old, oldLeaf: %v, newLeaf: %v",
207212
append(keys[:level+1], leaf0.rKey[level2:]...),
208213
keys,
@@ -264,7 +269,7 @@ func runBuildSmtLoop(s *SMT, logPrefix string, nodeKeys []utils.NodeKey, leafVal
264269
// this is case for 1 leaf inserted to the left of the root node
265270
if len(siblings) == 0 && keys[0] == 0 {
266271
if upperNode.node0 != nil {
267-
return nil, fmt.Errorf("tried to override left node")
272+
return nil, 0, fmt.Errorf("tried to override left node")
268273
}
269274
upperNode.node0 = newNode
270275
} else {
@@ -273,7 +278,7 @@ func runBuildSmtLoop(s *SMT, logPrefix string, nodeKeys []utils.NodeKey, leafVal
273278
//the new leaf should be on the right side
274279
//otherwise something went wrong
275280
if upperNode.node1 != nil || keys[level] != 1 {
276-
return nil, fmt.Errorf(
281+
return nil, 0, fmt.Errorf(
277282
"leaf insert error. new should be on the right of the found node, foundNode: %v, newLeafKey: %v",
278283
upperNode.node1,
279284
keys,
@@ -318,9 +323,7 @@ func runBuildSmtLoop(s *SMT, logPrefix string, nodeKeys []utils.NodeKey, leafVal
318323
progressChan <- uint64(totalKeysCount) + insertedKeysCount
319324
}
320325

321-
s.updateDepth(maxReachedLevel)
322-
323-
return &rootNode, nil
326+
return &rootNode, maxReachedLevel, nil
324327
}
325328

326329
func startBuildSmtLoopDbCompanionLoop(s *SMT, nodeKeys []utils.NodeKey, jobResultsChannel chan utils.JobResult, accountValuesReadChan chan *utils.NodeValue8) error {

turbo/jsonrpc/zkevm_counters.go

+15-3
Original file line numberDiff line numberDiff line change
@@ -299,16 +299,28 @@ func populateCounters(collected *vm.Counters, execResult *core.ExecutionResult,
299299
return resJson, nil
300300
}
301301

302-
func getSmtDepth(hermezDb *hermez_db.HermezDbReader, blockNum uint64, config *tracers.TraceConfig_ZkEvm) (int, error) {
303-
var smtDepth int
302+
type IDepthGetter interface {
303+
GetClosestSmtDepth(blockNum uint64) (depthBlockNum uint64, smtDepth uint64, err error)
304+
}
305+
306+
func getSmtDepth(
307+
hermezDb IDepthGetter,
308+
blockNum uint64,
309+
config *tracers.TraceConfig_ZkEvm,
310+
) (smtDepth int, err error) {
304311
if config != nil && config.SmtDepth != nil {
305312
smtDepth = *config.SmtDepth
306313
} else {
307-
depthBlockNum, smtDepth, err := hermezDb.GetClosestSmtDepth(blockNum)
314+
var depthBlockNum uint64
315+
var smtDepthUint64 uint64
316+
317+
depthBlockNum, smtDepthUint64, err = hermezDb.GetClosestSmtDepth(blockNum)
308318
if err != nil {
309319
return 0, err
310320
}
311321

322+
smtDepth = int(smtDepthUint64)
323+
312324
if depthBlockNum < blockNum {
313325
smtDepth += smtDepth / 10
314326
}

turbo/jsonrpc/zkevm_counters_test.go

+146
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
package jsonrpc
2+
3+
import (
4+
"testing"
5+
"errors"
6+
"github.com/ledgerwatch/erigon/eth/tracers"
7+
)
8+
9+
type MockDepthGetter struct {
10+
DepthBlockNum uint64
11+
SMTDepth uint64
12+
Err error
13+
}
14+
15+
func (m *MockDepthGetter) GetClosestSmtDepth(blockNum uint64) (uint64, uint64, error) {
16+
return m.DepthBlockNum, m.SMTDepth, m.Err
17+
}
18+
19+
func intPtr(i int) *int {
20+
return &i
21+
}
22+
23+
func TestGetSmtDepth(t *testing.T) {
24+
testCases := map[string]struct {
25+
blockNum uint64
26+
config *tracers.TraceConfig_ZkEvm
27+
mockSetup func(m *MockDepthGetter)
28+
expectedDepth int
29+
expectedErr error
30+
}{
31+
"Config provided with SmtDepth": {
32+
blockNum: 100,
33+
config: &tracers.TraceConfig_ZkEvm{
34+
SmtDepth: intPtr(128),
35+
},
36+
mockSetup: func(m *MockDepthGetter) {
37+
// No DB call expected.
38+
},
39+
expectedDepth: 128,
40+
expectedErr: nil,
41+
},
42+
"Config is nil, GetClosestSmtDepth returns depthBlockNum < blockNum": {
43+
blockNum: 100,
44+
config: nil,
45+
mockSetup: func(m *MockDepthGetter) {
46+
m.DepthBlockNum = 90
47+
m.SMTDepth = 100
48+
m.Err = nil
49+
},
50+
expectedDepth: 110, // 100 + 100/10
51+
expectedErr: nil,
52+
},
53+
"Config is nil, GetClosestSmtDepth returns depthBlockNum >= blockNum": {
54+
blockNum: 100,
55+
config: nil,
56+
mockSetup: func(m *MockDepthGetter) {
57+
m.DepthBlockNum = 100
58+
m.SMTDepth = 100
59+
m.Err = nil
60+
},
61+
expectedDepth: 100,
62+
expectedErr: nil,
63+
},
64+
"Config is nil, smtDepth after adjustment exceeds 256": {
65+
blockNum: 100,
66+
config: nil,
67+
mockSetup: func(m *MockDepthGetter) {
68+
m.DepthBlockNum = 90
69+
m.SMTDepth = 250
70+
m.Err = nil
71+
},
72+
expectedDepth: 256, // 250 + 25 = 275 -> capped to 256
73+
expectedErr: nil,
74+
},
75+
"Config is nil, smtDepth is 0": {
76+
blockNum: 100,
77+
config: nil,
78+
mockSetup: func(m *MockDepthGetter) {
79+
m.DepthBlockNum = 90
80+
m.SMTDepth = 0
81+
m.Err = nil
82+
},
83+
expectedDepth: 256, // 0 is invalid, set to 256
84+
expectedErr: nil,
85+
},
86+
"Config is nil, GetClosestSmtDepth returns error": {
87+
blockNum: 100,
88+
config: nil,
89+
mockSetup: func(m *MockDepthGetter) {
90+
m.DepthBlockNum = 0
91+
m.SMTDepth = 0
92+
m.Err = errors.New("database error")
93+
},
94+
expectedDepth: 0,
95+
expectedErr: errors.New("database error"),
96+
},
97+
"Config provided with SmtDepth exceeding 256": {
98+
blockNum: 100,
99+
config: &tracers.TraceConfig_ZkEvm{
100+
SmtDepth: intPtr(300),
101+
},
102+
mockSetup: func(m *MockDepthGetter) {
103+
// No DB call expected.
104+
},
105+
expectedDepth: 300, // As per the function logic, returned as-is.
106+
expectedErr: nil,
107+
},
108+
"Config provided with SmtDepth set to 0": {
109+
blockNum: 100,
110+
config: &tracers.TraceConfig_ZkEvm{
111+
SmtDepth: intPtr(0),
112+
},
113+
mockSetup: func(m *MockDepthGetter) {
114+
// No DB call expected.
115+
},
116+
expectedDepth: 0, // As per the function logic, returned as-is.
117+
expectedErr: nil,
118+
},
119+
}
120+
121+
for name, tc := range testCases {
122+
t.Run(name, func(t *testing.T) {
123+
mock := &MockDepthGetter{}
124+
tc.mockSetup(mock)
125+
126+
actualDepth, actualErr := getSmtDepth(mock, tc.blockNum, tc.config)
127+
128+
if tc.expectedErr != nil {
129+
if actualErr == nil {
130+
t.Fatalf("expected error '%v', but got nil", tc.expectedErr)
131+
}
132+
if actualErr.Error() != tc.expectedErr.Error() {
133+
t.Fatalf("expected error '%v', but got '%v'", tc.expectedErr, actualErr)
134+
}
135+
} else {
136+
if actualErr != nil {
137+
t.Fatalf("expected no error, but got '%v'", actualErr)
138+
}
139+
}
140+
141+
if actualDepth != tc.expectedDepth {
142+
t.Errorf("expected smtDepth %d, but got %d", tc.expectedDepth, actualDepth)
143+
}
144+
})
145+
}
146+
}

zk/smt/unwind_smt.go

+5-5
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,11 @@ func UnwindZkSMT(ctx context.Context, logPrefix string, from, to uint64, tx kv.R
3838
eridb.OpenBatch(quit)
3939
}
4040

41-
changesGetter := NewChangesGetter(tx)
42-
if err := changesGetter.openChangesGetter(from); err != nil {
41+
cg := NewChangesGetter(tx)
42+
if err := cg.openChangesGetter(from); err != nil {
4343
return trie.EmptyRoot, fmt.Errorf("OpenChangesGetter: %w", err)
4444
}
45-
defer changesGetter.closeChangesGetter()
45+
defer cg.closeChangesGetter()
4646

4747
total := uint64(math.Abs(float64(from) - float64(to) + 1))
4848
progressChan, stopPrinter := zk.ProgressPrinter(fmt.Sprintf("[%s] Progress unwinding", logPrefix), total, quiet)
@@ -58,7 +58,7 @@ func UnwindZkSMT(ctx context.Context, logPrefix string, from, to uint64, tx kv.R
5858
default:
5959
}
6060

61-
if err := changesGetter.getChangesForBlock(i); err != nil {
61+
if err := cg.getChangesForBlock(i); err != nil {
6262
return trie.EmptyRoot, fmt.Errorf("getChangesForBlock: %w", err)
6363
}
6464

@@ -67,7 +67,7 @@ func UnwindZkSMT(ctx context.Context, logPrefix string, from, to uint64, tx kv.R
6767

6868
stopPrinter()
6969

70-
if _, _, err := dbSmt.SetStorage(ctx, logPrefix, changesGetter.accChanges, changesGetter.codeChanges, changesGetter.storageChanges); err != nil {
70+
if _, _, err := dbSmt.SetStorage(ctx, logPrefix, cg.accChanges, cg.codeChanges, cg.storageChanges); err != nil {
7171
return trie.EmptyRoot, err
7272
}
7373

0 commit comments

Comments
 (0)