Skip to content

Commit f702adf

Browse files
mmatczukHenrik Johansson
and
Henrik Johansson
committed
scylla: dedicated token aware ConnPicker
This is an extension to tokenAwareHostPolicy supported by the Scylla 2.3 and onwards. It allows driver to select a connection to shard on a host based on the token. The protocol extension spec is available at [1]. [1] https://github.com/scylladb/scylla/blob/master/docs/protocol-extensions.md Co-authored-by: Henrik Johansson <[email protected]> Co-authored-by: Michał Matczuk <[email protected]>
1 parent 608dba8 commit f702adf

File tree

4 files changed

+10303
-0
lines changed

4 files changed

+10303
-0
lines changed

connectionpool.go

+5
Original file line numberDiff line numberDiff line change
@@ -517,6 +517,11 @@ func (pool *hostConnPool) initConnPicker(conn *Conn) {
517517
return
518518
}
519519

520+
if isScyllaConn(conn) {
521+
pool.connPicker = newScyllaConnPicker(conn)
522+
return
523+
}
524+
520525
pool.connPicker = newDefaultConnPicker(pool.size)
521526
}
522527

scylla.go

+191
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
package gocql
2+
3+
import (
4+
"fmt"
5+
"math"
6+
"strconv"
7+
"sync/atomic"
8+
)
9+
10+
// scyllaSupported represents Scylla connection options as sent in SUPPORTED
11+
// frame.
12+
type scyllaSupported struct {
13+
shard int
14+
nrShards int
15+
msbIgnore uint64
16+
}
17+
18+
func parseSupported(supported map[string][]string) scyllaSupported {
19+
const (
20+
scyllaShard = "SCYLLA_SHARD"
21+
scyllaNrShards = "SCYLLA_NR_SHARDS"
22+
scyllaPartitioner = "SCYLLA_PARTITIONER"
23+
scyllaShardingAlgorithm = "SCYLLA_SHARDING_ALGORITHM"
24+
scyllaShardingIgnoreMSB = "SCYLLA_SHARDING_IGNORE_MSB"
25+
)
26+
27+
var (
28+
si scyllaSupported
29+
err error
30+
)
31+
32+
if s, ok := supported[scyllaShard]; ok {
33+
if si.shard, err = strconv.Atoi(s[0]); err != nil {
34+
if gocqlDebug {
35+
Logger.Printf("scylla: failed to parse %s value %v: %s", scyllaShard, s, err)
36+
}
37+
}
38+
}
39+
if s, ok := supported[scyllaNrShards]; ok {
40+
if si.nrShards, err = strconv.Atoi(s[0]); err != nil {
41+
if gocqlDebug {
42+
Logger.Printf("scylla: failed to parse %s value %v: %s", scyllaNrShards, s, err)
43+
}
44+
}
45+
}
46+
if s, ok := supported[scyllaShardingIgnoreMSB]; ok {
47+
if si.msbIgnore, err = strconv.ParseUint(s[0], 10, 64); err != nil {
48+
if gocqlDebug {
49+
Logger.Printf("scylla: failed to parse %s value %v: %s", scyllaShardingIgnoreMSB, s, err)
50+
}
51+
}
52+
}
53+
54+
var (
55+
partitioner string
56+
algorithm string
57+
)
58+
if s, ok := supported[scyllaPartitioner]; ok {
59+
partitioner = s[0]
60+
}
61+
if s, ok := supported[scyllaShardingAlgorithm]; ok {
62+
algorithm = s[0]
63+
}
64+
65+
if partitioner != "org.apache.cassandra.dht.Murmur3Partitioner" || algorithm != "biased-token-round-robin" || si.nrShards == 0 || si.msbIgnore == 0 {
66+
if gocqlDebug {
67+
Logger.Printf("scylla: unsupported sharding configuration")
68+
}
69+
return scyllaSupported{}
70+
}
71+
72+
return si
73+
}
74+
75+
// isScyllaConn checks if conn is suitable for scyllaConnPicker.
76+
func isScyllaConn(conn *Conn) bool {
77+
s := parseSupported(conn.supported)
78+
return s.nrShards != 0
79+
}
80+
81+
// scyllaConnPicker is a specialised ConnPicker that selects connections based
82+
// on token trying to get connection to a shard containing the given token.
83+
type scyllaConnPicker struct {
84+
conns []*Conn
85+
nrConns int
86+
nrShards int
87+
msbIgnore uint64
88+
pos int32
89+
}
90+
91+
func newScyllaConnPicker(conn *Conn) *scyllaConnPicker {
92+
s := parseSupported(conn.supported)
93+
if s.nrShards == 0 {
94+
panic(fmt.Sprintf("scylla: %s not a sharded connection", conn.Address()))
95+
}
96+
97+
if gocqlDebug {
98+
Logger.Printf("scylla: %s sharding options %+v", conn.Address(), s)
99+
}
100+
101+
return &scyllaConnPicker{
102+
nrShards: s.nrShards,
103+
msbIgnore: s.msbIgnore,
104+
}
105+
}
106+
107+
func (p *scyllaConnPicker) Remove(conn *Conn) {
108+
s := parseSupported(conn.supported)
109+
if s.nrShards == 0 {
110+
panic(fmt.Sprintf("scylla: %s not a sharded connection", conn.Address()))
111+
}
112+
if gocqlDebug {
113+
Logger.Printf("scylla: %s remove shard %d connection", conn.Address(), s.shard)
114+
}
115+
p.conns[s.shard] = nil
116+
}
117+
118+
func (p *scyllaConnPicker) Close() {
119+
conns := p.conns
120+
p.conns = nil
121+
for _, conn := range conns {
122+
if conn != nil {
123+
conn.Close()
124+
}
125+
}
126+
}
127+
128+
func (p *scyllaConnPicker) Size() (int, int) {
129+
return p.nrConns, p.nrShards - p.nrConns
130+
}
131+
132+
func (p *scyllaConnPicker) Pick(t token) *Conn {
133+
if len(p.conns) == 0 {
134+
return nil
135+
}
136+
137+
if t == nil {
138+
idx := int(atomic.AddInt32(&p.pos, 1))
139+
for i := 0; i < len(p.conns); i++ {
140+
if conn := p.conns[(idx+i)%len(p.conns)]; conn != nil {
141+
return conn
142+
}
143+
}
144+
return nil
145+
}
146+
147+
mmt, ok := t.(murmur3Token)
148+
// double check if that's murmur3 token
149+
if !ok {
150+
return nil
151+
}
152+
153+
idx := p.shardOf(mmt)
154+
return p.conns[idx]
155+
}
156+
157+
func (p *scyllaConnPicker) shardOf(token murmur3Token) int {
158+
shards := uint64(p.nrShards)
159+
z := uint64(token+math.MinInt64) << p.msbIgnore
160+
lo := z & 0xffffffff
161+
hi := (z >> 32) & 0xffffffff
162+
mul1 := lo * shards
163+
mul2 := hi * shards
164+
sum := (mul1 >> 32) + mul2
165+
return int(sum >> 32)
166+
}
167+
168+
func (p *scyllaConnPicker) Put(conn *Conn) {
169+
s := parseSupported(conn.supported)
170+
if s.nrShards == 0 {
171+
panic(fmt.Sprintf("scylla: %s not a sharded connection", conn.Address()))
172+
}
173+
174+
if s.nrShards != len(p.conns) {
175+
if s.nrShards != p.nrShards {
176+
panic(fmt.Sprintf("scylla: %s invalid number of shards", conn.Address()))
177+
}
178+
conns := p.conns
179+
p.conns = make([]*Conn, s.nrShards, s.nrShards)
180+
copy(p.conns, conns)
181+
}
182+
if c := p.conns[s.shard]; c != nil {
183+
conn.Close()
184+
return
185+
}
186+
p.conns[s.shard] = conn
187+
p.nrConns++
188+
if gocqlDebug {
189+
Logger.Printf("scylla: %s put shard %d connection total: %d missing: %d", conn.Address(), s.shard, p.nrConns, p.nrShards-p.nrConns)
190+
}
191+
}

scylla_test.go

+100
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
package gocql
2+
3+
import (
4+
"runtime"
5+
"sync"
6+
"testing"
7+
)
8+
9+
func TestScyllaConnPickerPickNilToken(t *testing.T) {
10+
t.Parallel()
11+
12+
s := scyllaConnPicker{
13+
nrShards: 4,
14+
msbIgnore: 12,
15+
}
16+
17+
t.Run("no conns", func(t *testing.T) {
18+
s.conns = []*Conn{{}}
19+
if s.Pick(token(nil)) != s.conns[0] {
20+
t.Fatal("expected connection")
21+
}
22+
})
23+
24+
t.Run("one shard", func(t *testing.T) {
25+
s.conns = []*Conn{{}}
26+
if s.Pick(token(nil)) != s.conns[0] {
27+
t.Fatal("expected connection")
28+
}
29+
})
30+
31+
t.Run("multiple shards", func(t *testing.T) {
32+
s.conns = []*Conn{nil, {}}
33+
if s.Pick(token(nil)) != s.conns[1] {
34+
t.Fatal("expected connection")
35+
}
36+
if s.Pick(token(nil)) != s.conns[1] {
37+
t.Fatal("expected connection")
38+
}
39+
})
40+
41+
t.Run("multiple shards no conns", func(t *testing.T) {
42+
s.conns = []*Conn{nil, nil}
43+
if s.Pick(token(nil)) != nil {
44+
t.Fatal("expected nil")
45+
}
46+
if s.Pick(token(nil)) != nil {
47+
t.Fatal("expected nil")
48+
}
49+
})
50+
}
51+
52+
func hammerConnPicker(t *testing.T, wg *sync.WaitGroup, s *scyllaConnPicker, loops int) {
53+
t.Helper()
54+
for i := 0; i < loops; i++ {
55+
if c := s.Pick(nil); c == nil {
56+
t.Error("unexpected nil")
57+
}
58+
}
59+
wg.Done()
60+
}
61+
62+
func TestScyllaConnPickerHammerPickNilToken(t *testing.T) {
63+
t.Parallel()
64+
65+
s := scyllaConnPicker{
66+
nrShards: 4,
67+
msbIgnore: 12,
68+
}
69+
s.conns = make([]*Conn, 100)
70+
for i := range s.conns {
71+
if i%7 == 0 {
72+
continue
73+
}
74+
s.conns[i] = &Conn{}
75+
}
76+
77+
n := runtime.GOMAXPROCS(0)
78+
loops := 10000 / n
79+
80+
var wg sync.WaitGroup
81+
wg.Add(n)
82+
for i := 0; i < n; i++ {
83+
go hammerConnPicker(t, &wg, &s, loops)
84+
}
85+
wg.Wait()
86+
}
87+
88+
func TestScyllaConnPickerShardOf(t *testing.T) {
89+
t.Parallel()
90+
91+
s := scyllaConnPicker{
92+
nrShards: 4,
93+
msbIgnore: 12,
94+
}
95+
for _, test := range scyllaShardOfTests {
96+
if shard := s.shardOf(murmur3Token(test.token)); shard != test.shard {
97+
t.Errorf("wrong scylla shard calculated for token %d, expected %d, got %d", test.token, test.shard, shard)
98+
}
99+
}
100+
}

0 commit comments

Comments
 (0)