Skip to content

Commit e406468

Browse files
kevmo314Sean-Der
authored andcommittedJul 21, 2024·
Perform handshake on first read/write
Updates the connection to perform a handshake on first read/write instead of on accept. Closes #279.
1 parent 6178064 commit e406468

File tree

18 files changed

+227
-249
lines changed

18 files changed

+227
-249
lines changed
 

‎bench_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ func TestSimpleReadWrite(t *testing.T) {
4040
return
4141
}
4242
buf := make([]byte, 1024)
43-
if _, sErr = server.Read(buf); sErr != nil {
43+
if _, sErr = server.Read(buf); sErr != nil { //nolint:contextcheck
4444
t.Error(sErr)
4545
}
4646
gotHello <- struct{}{}

‎config.go

-21
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
package dtls
55

66
import (
7-
"context"
87
"crypto/ecdsa"
98
"crypto/ed25519"
109
"crypto/rsa"
@@ -118,15 +117,6 @@ type Config struct {
118117

119118
LoggerFactory logging.LoggerFactory
120119

121-
// ConnectContextMaker is a function to make a context used in Dial(),
122-
// Client(), Server(), and Accept(). If nil, the default ConnectContextMaker
123-
// is used. It can be implemented as following.
124-
//
125-
// func ConnectContextMaker() (context.Context, func()) {
126-
// return context.WithTimeout(context.Background(), 30*time.Second)
127-
// }
128-
ConnectContextMaker func() (context.Context, func())
129-
130120
// MTU is the length at which handshake messages will be fragmented to
131121
// fit within the maximum transmission unit (default is 1200 bytes)
132122
MTU int
@@ -230,17 +220,6 @@ type Config struct {
230220
OnConnectionAttempt func(net.Addr) error
231221
}
232222

233-
func defaultConnectContextMaker() (context.Context, func()) {
234-
return context.WithTimeout(context.Background(), 30*time.Second)
235-
}
236-
237-
func (c *Config) connectContextMaker() (context.Context, func()) {
238-
if c.ConnectContextMaker == nil {
239-
return defaultConnectContextMaker()
240-
}
241-
return c.ConnectContextMaker()
242-
}
243-
244223
func (c *Config) includeCertificateSuites() bool {
245224
return c.PSK == nil || len(c.Certificates) > 0 || c.GetCertificate != nil || c.GetClientCertificate != nil
246225
}

‎conn.go

+93-102
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ type Conn struct {
7373
paddingLengthGenerator func(uint) uint
7474

7575
handshakeCompletedSuccessfully atomic.Value
76+
handshakeMutex sync.Mutex
7677

7778
encryptedPackets []addrPkt
7879

@@ -94,9 +95,11 @@ type Conn struct {
9495
fsm *handshakeFSM
9596

9697
replayProtectionWindow uint
98+
99+
handshakeConfig *handshakeConfig
97100
}
98101

99-
func createConn(nextConn net.PacketConn, rAddr net.Addr, config *Config, isClient bool) (*Conn, error) {
102+
func createConn(nextConn net.PacketConn, rAddr net.Addr, config *Config, isClient bool, resumeState *State) (*Conn, error) {
100103
if err := validateConfig(config); err != nil {
101104
return nil, err
102105
}
@@ -127,42 +130,6 @@ func createConn(nextConn net.PacketConn, rAddr net.Addr, config *Config, isClien
127130
paddingLengthGenerator = func(uint) uint { return 0 }
128131
}
129132

130-
c := &Conn{
131-
rAddr: rAddr,
132-
nextConn: netctx.NewPacketConn(nextConn),
133-
fragmentBuffer: newFragmentBuffer(),
134-
handshakeCache: newHandshakeCache(),
135-
maximumTransmissionUnit: mtu,
136-
paddingLengthGenerator: paddingLengthGenerator,
137-
138-
decrypted: make(chan interface{}, 1),
139-
log: logger,
140-
141-
readDeadline: deadline.New(),
142-
writeDeadline: deadline.New(),
143-
144-
reading: make(chan struct{}, 1),
145-
handshakeRecv: make(chan recvHandshakeState),
146-
closed: closer.NewCloser(),
147-
cancelHandshaker: func() {},
148-
149-
replayProtectionWindow: uint(replayProtectionWindow),
150-
151-
state: State{
152-
isClient: isClient,
153-
},
154-
}
155-
156-
c.setRemoteEpoch(0)
157-
c.setLocalEpoch(0)
158-
return c, nil
159-
}
160-
161-
func handshakeConn(ctx context.Context, conn *Conn, config *Config, isClient bool, initialState *State) (*Conn, error) {
162-
if conn == nil {
163-
return nil, errNilNextConn
164-
}
165-
166133
cipherSuites, err := parseCipherSuites(config.CipherSuites, config.CustomCipherSuites, config.includeCertificateSuites(), config.PSK != nil)
167134
if err != nil {
168135
return nil, err
@@ -190,7 +157,7 @@ func handshakeConn(ctx context.Context, conn *Conn, config *Config, isClient boo
190157
curves = defaultCurves
191158
}
192159

193-
hsCfg := &handshakeConfig{
160+
handshakeConfig := &handshakeConfig{
194161
localPSKCallback: config.PSK,
195162
localPSKIdentityHint: config.PSKIdentityHint,
196163
localCipherSuites: cipherSuites,
@@ -209,7 +176,7 @@ func handshakeConn(ctx context.Context, conn *Conn, config *Config, isClient boo
209176
customCipherSuites: config.CustomCipherSuites,
210177
initialRetransmitInterval: workerInterval,
211178
disableRetransmitBackoff: config.DisableRetransmitBackoff,
212-
log: conn.log,
179+
log: logger,
213180
initialEpoch: 0,
214181
keyLogWriter: config.KeyLogWriter,
215182
sessionStore: config.SessionStore,
@@ -222,82 +189,115 @@ func handshakeConn(ctx context.Context, conn *Conn, config *Config, isClient boo
222189
clientHelloMessageHook: config.ClientHelloMessageHook,
223190
serverHelloMessageHook: config.ServerHelloMessageHook,
224191
certificateRequestMessageHook: config.CertificateRequestMessageHook,
192+
resumeState: resumeState,
193+
}
194+
195+
c := &Conn{
196+
rAddr: rAddr,
197+
nextConn: netctx.NewPacketConn(nextConn),
198+
handshakeConfig: handshakeConfig,
199+
fragmentBuffer: newFragmentBuffer(),
200+
handshakeCache: newHandshakeCache(),
201+
maximumTransmissionUnit: mtu,
202+
paddingLengthGenerator: paddingLengthGenerator,
203+
204+
decrypted: make(chan interface{}, 1),
205+
log: logger,
206+
207+
readDeadline: deadline.New(),
208+
writeDeadline: deadline.New(),
209+
210+
reading: make(chan struct{}, 1),
211+
handshakeRecv: make(chan recvHandshakeState),
212+
closed: closer.NewCloser(),
213+
cancelHandshaker: func() {},
214+
cancelHandshakeReader: func() {},
215+
216+
replayProtectionWindow: uint(replayProtectionWindow),
217+
218+
state: State{
219+
isClient: isClient,
220+
},
221+
}
222+
223+
c.setRemoteEpoch(0)
224+
c.setLocalEpoch(0)
225+
return c, nil
226+
}
227+
228+
// Handshake runs the client or server DTLS handshake
229+
// protocol if it has not yet been run.
230+
//
231+
// Most uses of this package need not call Handshake explicitly: the
232+
// first [Conn.Read] or [Conn.Write] will call it automatically.
233+
//
234+
// For control over canceling or setting a timeout on a handshake, use
235+
// [Conn.HandshakeContext].
236+
func (c *Conn) Handshake() error {
237+
return c.HandshakeContext(context.Background())
238+
}
239+
240+
// HandshakeContext runs the client or server DTLS handshake
241+
// protocol if it has not yet been run.
242+
//
243+
// The provided Context must be non-nil. If the context is canceled before
244+
// the handshake is complete, the handshake is interrupted and an error is returned.
245+
// Once the handshake has completed, cancellation of the context will not affect the
246+
// connection.
247+
//
248+
// Most uses of this package need not call HandshakeContext explicitly: the
249+
// first [Conn.Read] or [Conn.Write] will call it automatically.
250+
func (c *Conn) HandshakeContext(ctx context.Context) error {
251+
c.handshakeMutex.Lock()
252+
defer c.handshakeMutex.Unlock()
253+
254+
if c.isHandshakeCompletedSuccessfully() {
255+
return nil
225256
}
226257

227258
// rfc5246#section-7.4.3
228259
// In addition, the hash and signature algorithms MUST be compatible
229260
// with the key in the server's end-entity certificate.
230-
if !isClient {
231-
cert, err := hsCfg.getCertificate(&ClientHelloInfo{})
261+
if !c.state.isClient {
262+
cert, err := c.handshakeConfig.getCertificate(&ClientHelloInfo{})
232263
if err != nil && !errors.Is(err, errNoCertificates) {
233-
return nil, err
264+
return err
234265
}
235-
hsCfg.localCipherSuites = filterCipherSuitesForCertificate(cert, cipherSuites)
266+
c.handshakeConfig.localCipherSuites = filterCipherSuitesForCertificate(cert, c.handshakeConfig.localCipherSuites)
236267
}
237268

238269
var initialFlight flightVal
239270
var initialFSMState handshakeState
240271

241-
if initialState != nil {
242-
if conn.state.isClient {
272+
if c.handshakeConfig.resumeState != nil {
273+
if c.state.isClient {
243274
initialFlight = flight5
244275
} else {
245276
initialFlight = flight6
246277
}
247278
initialFSMState = handshakeFinished
248279

249-
conn.state = *initialState
280+
c.state = *c.handshakeConfig.resumeState
250281
} else {
251-
if conn.state.isClient {
282+
if c.state.isClient {
252283
initialFlight = flight1
253284
} else {
254285
initialFlight = flight0
255286
}
256287
initialFSMState = handshakePreparing
257288
}
258289
// Do handshake
259-
if err := conn.handshake(ctx, hsCfg, initialFlight, initialFSMState); err != nil {
260-
return nil, err
290+
if err := c.handshake(ctx, c.handshakeConfig, initialFlight, initialFSMState); err != nil {
291+
return err
261292
}
262293

263-
conn.log.Trace("Handshake Completed")
294+
c.log.Trace("Handshake Completed")
264295

265-
return conn, nil
296+
return nil
266297
}
267298

268299
// Dial connects to the given network address and establishes a DTLS connection on top.
269-
// Connection handshake will timeout using ConnectContextMaker in the Config.
270-
// If you want to specify the timeout duration, use DialWithContext() instead.
271300
func Dial(network string, rAddr *net.UDPAddr, config *Config) (*Conn, error) {
272-
ctx, cancel := config.connectContextMaker()
273-
defer cancel()
274-
275-
return DialWithContext(ctx, network, rAddr, config)
276-
}
277-
278-
// Client establishes a DTLS connection over an existing connection.
279-
// Connection handshake will timeout using ConnectContextMaker in the Config.
280-
// If you want to specify the timeout duration, use ClientWithContext() instead.
281-
func Client(conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) {
282-
ctx, cancel := config.connectContextMaker()
283-
defer cancel()
284-
285-
return ClientWithContext(ctx, conn, rAddr, config)
286-
}
287-
288-
// Server listens for incoming DTLS connections.
289-
// Connection handshake will timeout using ConnectContextMaker in the Config.
290-
// If you want to specify the timeout duration, use ServerWithContext() instead.
291-
func Server(conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) {
292-
ctx, cancel := config.connectContextMaker()
293-
defer cancel()
294-
295-
return ServerWithContext(ctx, conn, rAddr, config)
296-
}
297-
298-
// DialWithContext connects to the given network address and establishes a DTLS
299-
// connection on top.
300-
func DialWithContext(ctx context.Context, network string, rAddr *net.UDPAddr, config *Config) (*Conn, error) {
301301
// net.ListenUDP is used rather than net.DialUDP as the latter prevents the
302302
// use of net.PacketConn.WriteTo.
303303
// https://github.com/golang/go/blob/ce5e37ec21442c6eb13a43e68ca20129102ebac0/src/net/udpsock_posix.go#L115
@@ -306,28 +306,23 @@ func DialWithContext(ctx context.Context, network string, rAddr *net.UDPAddr, co
306306
return nil, err
307307
}
308308

309-
return ClientWithContext(ctx, pConn, rAddr, config)
309+
return Client(pConn, rAddr, config)
310310
}
311311

312-
// ClientWithContext establishes a DTLS connection over an existing connection.
313-
func ClientWithContext(ctx context.Context, conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) {
312+
// Client establishes a DTLS connection over an existing connection.
313+
func Client(conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) {
314314
switch {
315315
case config == nil:
316316
return nil, errNoConfigProvided
317317
case config.PSK != nil && config.PSKIdentityHint == nil:
318318
return nil, errPSKAndIdentityMustBeSetForClient
319319
}
320320

321-
dconn, err := createConn(conn, rAddr, config, true)
322-
if err != nil {
323-
return nil, err
324-
}
325-
326-
return handshakeConn(ctx, dconn, config, true, nil)
321+
return createConn(conn, rAddr, config, true, nil)
327322
}
328323

329-
// ServerWithContext listens for incoming DTLS connections.
330-
func ServerWithContext(ctx context.Context, conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) {
324+
// Server listens for incoming DTLS connections.
325+
func Server(conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) {
331326
if config == nil {
332327
return nil, errNoConfigProvided
333328
}
@@ -336,17 +331,13 @@ func ServerWithContext(ctx context.Context, conn net.PacketConn, rAddr net.Addr,
336331
return nil, err
337332
}
338333
}
339-
dconn, err := createConn(conn, rAddr, config, false)
340-
if err != nil {
341-
return nil, err
342-
}
343-
return handshakeConn(ctx, dconn, config, false, nil)
334+
return createConn(conn, rAddr, config, false, nil)
344335
}
345336

346337
// Read reads data from the connection.
347338
func (c *Conn) Read(p []byte) (n int, err error) {
348-
if !c.isHandshakeCompletedSuccessfully() {
349-
return 0, errHandshakeInProgress
339+
if err := c.Handshake(); err != nil {
340+
return 0, err
350341
}
351342

352343
select {
@@ -389,8 +380,8 @@ func (c *Conn) Write(p []byte) (int, error) {
389380
default:
390381
}
391382

392-
if !c.isHandshakeCompletedSuccessfully() {
393-
return 0, errHandshakeInProgress
383+
if err := c.Handshake(); err != nil {
384+
return 0, err
394385
}
395386

396387
return len(p), c.writePackets(c.writeDeadline, []*packet{

0 commit comments

Comments
 (0)
Please sign in to comment.