Skip to content

Commit c07572a

Browse files
rootulprach-id
authored andcommitted
feat: throw error if incorrect sequenceLength (celestiaorg#865)
Closes celestiaorg#839
1 parent 8c40c19 commit c07572a

10 files changed

+192
-50
lines changed

pkg/shares/compact_shares_test.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,8 @@ func TestCompactShareContainsInfoByte(t *testing.T) {
123123

124124
infoByte := shares[0][appconsts.NamespaceSize : appconsts.NamespaceSize+appconsts.ShareInfoBytes][0]
125125

126-
isMessageStart := true
127-
want, err := NewInfoByte(appconsts.ShareVersion, isMessageStart)
126+
isSequenceStart := true
127+
want, err := NewInfoByte(appconsts.ShareVersion, isSequenceStart)
128128

129129
require.NoError(t, err)
130130
assert.Equal(t, byte(want), infoByte)
@@ -143,8 +143,8 @@ func TestContiguousCompactShareContainsInfoByte(t *testing.T) {
143143

144144
infoByte := shares[1][appconsts.NamespaceSize : appconsts.NamespaceSize+appconsts.ShareInfoBytes][0]
145145

146-
isMessageStart := false
147-
want, err := NewInfoByte(appconsts.ShareVersion, isMessageStart)
146+
isSequenceStart := false
147+
want, err := NewInfoByte(appconsts.ShareVersion, isSequenceStart)
148148

149149
require.NoError(t, err)
150150
assert.Equal(t, byte(want), infoByte)

pkg/shares/info_byte.go

+8-8
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,17 @@ import (
88

99
// InfoByte is a byte with the following structure: the first 7 bits are
1010
// reserved for version information in big endian form (initially `0000000`).
11-
// The last bit is a "message start indicator", that is `1` if the share is at
12-
// the start of a message and `0` otherwise.
11+
// The last bit is a "sequence start indicator", that is `1` if this is the
12+
// first share of a sequence and `0` if this is a continuation share.
1313
type InfoByte byte
1414

15-
func NewInfoByte(version uint8, isMessageStart bool) (InfoByte, error) {
15+
func NewInfoByte(version uint8, isSequenceStart bool) (InfoByte, error) {
1616
if version > appconsts.MaxShareVersion {
1717
return 0, fmt.Errorf("version %d must be less than or equal to %d", version, appconsts.MaxShareVersion)
1818
}
1919

2020
prefix := version << 1
21-
if isMessageStart {
21+
if isSequenceStart {
2222
return InfoByte(prefix + 1), nil
2323
}
2424
return InfoByte(prefix), nil
@@ -31,13 +31,13 @@ func (i InfoByte) Version() uint8 {
3131
return version
3232
}
3333

34-
// IsMessageStart returns whether this share is the start of a message.
35-
func (i InfoByte) IsMessageStart() bool {
34+
// IsSequenceStart returns whether this share is the start of a message.
35+
func (i InfoByte) IsSequenceStart() bool {
3636
return uint(i)%2 == 1
3737
}
3838

3939
func ParseInfoByte(i byte) (InfoByte, error) {
40-
isMessageStart := i%2 == 1
40+
isSequenceStart := i%2 == 1
4141
version := uint8(i) >> 1
42-
return NewInfoByte(version, isMessageStart)
42+
return NewInfoByte(version, isSequenceStart)
4343
}

pkg/shares/info_byte_test.go

+14-14
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ func TestInfoByte(t *testing.T) {
77
notMessageStart := false
88

99
type testCase struct {
10-
version uint8
11-
isMessageStart bool
10+
version uint8
11+
isSequenceStart bool
1212
}
1313
tests := []testCase{
1414
{0, messageStart},
@@ -23,15 +23,15 @@ func TestInfoByte(t *testing.T) {
2323
}
2424

2525
for _, test := range tests {
26-
irb, err := NewInfoByte(test.version, test.isMessageStart)
26+
irb, err := NewInfoByte(test.version, test.isSequenceStart)
2727
if err != nil {
2828
t.Errorf("got %v want no error", err)
2929
}
3030
if got := irb.Version(); got != test.version {
3131
t.Errorf("got version %v want %v", got, test.version)
3232
}
33-
if got := irb.IsMessageStart(); got != test.isMessageStart {
34-
t.Errorf("got isMessageStart %v want %v", got, test.isMessageStart)
33+
if got := irb.IsSequenceStart(); got != test.isSequenceStart {
34+
t.Errorf("got IsSequenceStart %v want %v", got, test.isSequenceStart)
3535
}
3636
}
3737
}
@@ -41,8 +41,8 @@ func TestInfoByteErrors(t *testing.T) {
4141
notMessageStart := false
4242

4343
type testCase struct {
44-
version uint8
45-
isMessageStart bool
44+
version uint8
45+
isSequenceStart bool
4646
}
4747

4848
tests := []testCase{
@@ -61,11 +61,11 @@ func TestInfoByteErrors(t *testing.T) {
6161
}
6262

6363
func FuzzNewInfoByte(f *testing.F) {
64-
f.Fuzz(func(t *testing.T, version uint8, isMessageStart bool) {
64+
f.Fuzz(func(t *testing.T, version uint8, isSequenceStart bool) {
6565
if version > 127 {
6666
t.Skip()
6767
}
68-
_, err := NewInfoByte(version, isMessageStart)
68+
_, err := NewInfoByte(version, isSequenceStart)
6969
if err != nil {
7070
t.Errorf("got nil but want error when version > 127")
7171
}
@@ -74,9 +74,9 @@ func FuzzNewInfoByte(f *testing.F) {
7474

7575
func TestParseInfoByte(t *testing.T) {
7676
type testCase struct {
77-
b byte
78-
wantVersion uint8
79-
wantIsMessageStart bool
77+
b byte
78+
wantVersion uint8
79+
wantisSequenceStart bool
8080
}
8181

8282
tests := []testCase{
@@ -96,8 +96,8 @@ func TestParseInfoByte(t *testing.T) {
9696
if got.Version() != test.wantVersion {
9797
t.Errorf("got version %v want %v", got.Version(), test.wantVersion)
9898
}
99-
if got.IsMessageStart() != test.wantIsMessageStart {
100-
t.Errorf("got isMessageStart %v want %v", got.IsMessageStart(), test.wantIsMessageStart)
99+
if got.IsSequenceStart() != test.wantisSequenceStart {
100+
t.Errorf("got IsSequenceStart %v want %v", got.IsSequenceStart(), test.wantisSequenceStart)
101101
}
102102
}
103103
}

pkg/shares/parse_compact_shares.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ func (ss *shareStack) resolve() ([][]byte, error) {
5656
if err != nil {
5757
panic(err)
5858
}
59-
if !infoByte.IsMessageStart() {
59+
if !infoByte.IsSequenceStart() {
6060
return nil, errors.New("first share is not a message start")
6161
}
6262
err = ss.peel(ss.shares[0][appconsts.NamespaceSize+appconsts.ShareInfoBytes+appconsts.FirstCompactShareSequenceLengthBytes+appconsts.CompactShareReservedBytes:], true)

pkg/shares/parse_sparse_shares.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ func parseSparseShares(rawShares [][]byte, supportedShareVersions []uint8) ([]co
6161
if err != nil {
6262
panic(err)
6363
}
64-
if infoByte.IsMessageStart() != isNewMessage {
65-
return nil, fmt.Errorf("expected message start indicator to be %t but got %t", isNewMessage, infoByte.IsMessageStart())
64+
if infoByte.IsSequenceStart() != isNewMessage {
65+
return nil, fmt.Errorf("expected sequence start indicator to be %t but got %t", isNewMessage, infoByte.IsSequenceStart())
6666
}
6767
currentMsg = coretypes.Message{
6868
NamespaceID: nid,

pkg/shares/share_merging.go

+74-1
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ func ParseShares(rawShares [][]byte) ([]ShareSequence, error) {
156156
if err != nil {
157157
return sequences, err
158158
}
159-
if infoByte.IsMessageStart() {
159+
if infoByte.IsSequenceStart() {
160160
if len(currentSequence.Shares) > 0 {
161161
sequences = append(sequences, currentSequence)
162162
}
@@ -176,5 +176,78 @@ func ParseShares(rawShares [][]byte) ([]ShareSequence, error) {
176176
sequences = append(sequences, currentSequence)
177177
}
178178

179+
for _, sequence := range sequences {
180+
if err := sequence.validSequenceLength(); err != nil {
181+
return sequences, err
182+
}
183+
}
184+
179185
return sequences, nil
180186
}
187+
188+
// validSequenceLength extracts the sequenceLength written to the first share
189+
// and returns an error if the number of shares needed to store a sequence of
190+
// length sequenceLength doesn't match the number of shares in this share
191+
// sequence. Returns nil if there is no error.
192+
func (s ShareSequence) validSequenceLength() error {
193+
if len(s.Shares) == 0 {
194+
return fmt.Errorf("invalid sequence length because share sequence %v has no shares", s)
195+
}
196+
firstShare := s.Shares[0]
197+
sharesNeeded, err := numberOfSharesNeeded(firstShare)
198+
if err != nil {
199+
return err
200+
}
201+
202+
if len(s.Shares) != sharesNeeded {
203+
return fmt.Errorf("share sequence has %d shares but needed %d shares", len(s.Shares), sharesNeeded)
204+
}
205+
return nil
206+
}
207+
208+
// numberOfSharesNeeded extracts the sequenceLength written to the share
209+
// firstShare and returns the number of shares needed to store a sequence of
210+
// that length.
211+
func numberOfSharesNeeded(firstShare Share) (sharesUsed int, err error) {
212+
sequenceLength, err := firstShare.SequenceLength()
213+
if err != nil {
214+
return 0, err
215+
}
216+
217+
if firstShare.isCompactShare() {
218+
return compactSharesNeeded(int(sequenceLength)), nil
219+
}
220+
return sparseSharesNeeded(int(sequenceLength)), nil
221+
}
222+
223+
// compactSharesNeeded returns the number of compact shares needed to store a
224+
// sequence of length sequenceLength. The parameter sequenceLength is the number
225+
// of bytes of transaction, intermediate state root, or evidence data in a
226+
// sequence.
227+
func compactSharesNeeded(sequenceLength int) (sharesNeeded int) {
228+
if sequenceLength == 0 {
229+
return 0
230+
}
231+
232+
if sequenceLength < appconsts.FirstCompactShareContentSize {
233+
return 1
234+
}
235+
sequenceLength -= appconsts.FirstCompactShareContentSize
236+
sharesNeeded++
237+
238+
for sequenceLength > 0 {
239+
sequenceLength -= appconsts.ContinuationCompactShareContentSize
240+
sharesNeeded++
241+
}
242+
return sharesNeeded
243+
}
244+
245+
// sparseSharesNeeded returns the number of shares needed to store a sequence of
246+
// length sequenceLength.
247+
func sparseSharesNeeded(sequenceLength int) (sharesNeeded int) {
248+
sharesNeeded = sequenceLength / appconsts.SparseShareContentSize
249+
if sequenceLength%appconsts.SparseShareContentSize != 0 {
250+
sharesNeeded++
251+
}
252+
return sharesNeeded
253+
}

pkg/shares/share_merging_test.go

+79-10
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
package shares
22

33
import (
4+
"encoding/binary"
45
"math/rand"
56
"reflect"
67
"testing"
78

89
"github.com/celestiaorg/celestia-app/pkg/appconsts"
910
"github.com/celestiaorg/nmt/namespace"
11+
"github.com/stretchr/testify/assert"
1012
tmrand "github.com/tendermint/tendermint/libs/rand"
1113
"github.com/tendermint/tendermint/types"
1214
)
@@ -41,8 +43,14 @@ func TestParseShares(t *testing.T) {
4143
messageTwoStart := messageTwoShares[0]
4244
messageTwoContinuation := messageTwoShares[1]
4345

44-
invalidShare := generateRawShare(messageOneNamespace, start)
45-
invalidShare = append(invalidShare, []byte{0}...)
46+
invalidShare := generateRawShare(messageOneNamespace, start, 1)
47+
invalidShare = append(invalidShare, []byte{0}...) // invalidShare is now longer than the length of a valid share
48+
49+
largeSequenceLength := 1000 // it takes more than one share to store a sequence of 1000 bytes
50+
oneShareWithTooLargeSequenceLength := generateRawShare(messageOneNamespace, start, uint64(largeSequenceLength))
51+
52+
shortSequenceLength := 0
53+
oneShareWithTooShortSequenceLength := generateRawShare(messageOneNamespace, start, uint64(shortSequenceLength))
4654

4755
tests := []testCase{
4856
{
@@ -115,12 +123,25 @@ func TestParseShares(t *testing.T) {
115123
[]ShareSequence{},
116124
true,
117125
},
126+
{
127+
"one share with too large sequence length",
128+
[][]byte{oneShareWithTooLargeSequenceLength},
129+
[]ShareSequence{},
130+
true,
131+
},
132+
{
133+
"one share with too short sequence length",
134+
[][]byte{oneShareWithTooShortSequenceLength},
135+
[]ShareSequence{},
136+
true,
137+
},
118138
}
119139
for _, tt := range tests {
120140
t.Run(tt.name, func(t *testing.T) {
121141
got, err := ParseShares(tt.shares)
122-
if tt.expectErr && err == nil {
123-
t.Errorf("ParseShares() error %v, expectErr %v", err, tt.expectErr)
142+
if tt.expectErr {
143+
assert.Error(t, err)
144+
return
124145
}
125146
if !reflect.DeepEqual(got, tt.want) {
126147
t.Errorf("ParseShares() got %v, want %v", got, tt.want)
@@ -129,16 +150,64 @@ func TestParseShares(t *testing.T) {
129150
}
130151
}
131152

132-
func generateRawShare(namespace namespace.ID, isMessageStart bool) (rawShare []byte) {
133-
infoByte, _ := NewInfoByte(appconsts.ShareVersion, isMessageStart)
134-
rawData := make([]byte, appconsts.ShareSize-len(rawShare))
135-
rand.Read(rawData)
153+
func Test_compactSharesNeeded(t *testing.T) {
154+
type testCase struct {
155+
sequenceLength int
156+
want int
157+
}
158+
testCases := []testCase{
159+
{0, 0},
160+
{1, 1},
161+
{2, 1},
162+
{appconsts.FirstCompactShareContentSize, 1},
163+
{appconsts.FirstCompactShareContentSize + 1, 2},
164+
{appconsts.FirstCompactShareContentSize + appconsts.ContinuationCompactShareContentSize, 2},
165+
{appconsts.FirstCompactShareContentSize + appconsts.ContinuationCompactShareContentSize*100, 101},
166+
}
167+
for _, tc := range testCases {
168+
got := compactSharesNeeded(tc.sequenceLength)
169+
assert.Equal(t, tc.want, got)
170+
}
171+
}
172+
173+
func Test_sparseSharesNeeded(t *testing.T) {
174+
type testCase struct {
175+
sequenceLength int
176+
want int
177+
}
178+
testCases := []testCase{
179+
{0, 0},
180+
{1, 1},
181+
{2, 1},
182+
{appconsts.SparseShareContentSize, 1},
183+
{appconsts.SparseShareContentSize + 1, 2},
184+
{appconsts.SparseShareContentSize * 2, 2},
185+
{appconsts.SparseShareContentSize*100 + 1, 101},
186+
}
187+
for _, tc := range testCases {
188+
got := sparseSharesNeeded(tc.sequenceLength)
189+
assert.Equal(t, tc.want, got)
190+
}
191+
}
192+
193+
func generateRawShare(namespace namespace.ID, isSequenceStart bool, sequenceLength uint64) (rawShare []byte) {
194+
infoByte, _ := NewInfoByte(appconsts.ShareVersion, isSequenceStart)
195+
196+
sequenceLengthVarint := make([]byte, binary.MaxVarintLen64)
197+
numBytesWritten := binary.PutUvarint(sequenceLengthVarint, sequenceLength)
136198

137199
rawShare = append(rawShare, namespace...)
138200
rawShare = append(rawShare, byte(infoByte))
139-
rawShare = append(rawShare, rawData...)
201+
rawShare = append(rawShare, sequenceLengthVarint[:numBytesWritten]...)
202+
203+
return padWithRandomBytes(rawShare)
204+
}
140205

141-
return rawShare
206+
func padWithRandomBytes(partialShare Share) (paddedShare Share) {
207+
paddedShare = make([]byte, appconsts.ShareSize)
208+
copy(paddedShare, partialShare)
209+
rand.Read(paddedShare[len(partialShare):])
210+
return paddedShare
142211
}
143212

144213
func generateRandomTxs(count, size int) types.Txs {

0 commit comments

Comments
 (0)