@@ -73,6 +73,7 @@ type Conn struct {
73
73
paddingLengthGenerator func (uint ) uint
74
74
75
75
handshakeCompletedSuccessfully atomic.Value
76
+ handshakeMutex sync.Mutex
76
77
77
78
encryptedPackets []addrPkt
78
79
@@ -94,9 +95,11 @@ type Conn struct {
94
95
fsm * handshakeFSM
95
96
96
97
replayProtectionWindow uint
98
+
99
+ handshakeConfig * handshakeConfig
97
100
}
98
101
99
- func createConn (nextConn net.PacketConn , rAddr net.Addr , config * Config , isClient bool ) (* Conn , error ) {
102
+ func createConn (nextConn net.PacketConn , rAddr net.Addr , config * Config , isClient bool , resumeState * State ) (* Conn , error ) {
100
103
if err := validateConfig (config ); err != nil {
101
104
return nil , err
102
105
}
@@ -127,42 +130,6 @@ func createConn(nextConn net.PacketConn, rAddr net.Addr, config *Config, isClien
127
130
paddingLengthGenerator = func (uint ) uint { return 0 }
128
131
}
129
132
130
- c := & Conn {
131
- rAddr : rAddr ,
132
- nextConn : netctx .NewPacketConn (nextConn ),
133
- fragmentBuffer : newFragmentBuffer (),
134
- handshakeCache : newHandshakeCache (),
135
- maximumTransmissionUnit : mtu ,
136
- paddingLengthGenerator : paddingLengthGenerator ,
137
-
138
- decrypted : make (chan interface {}, 1 ),
139
- log : logger ,
140
-
141
- readDeadline : deadline .New (),
142
- writeDeadline : deadline .New (),
143
-
144
- reading : make (chan struct {}, 1 ),
145
- handshakeRecv : make (chan recvHandshakeState ),
146
- closed : closer .NewCloser (),
147
- cancelHandshaker : func () {},
148
-
149
- replayProtectionWindow : uint (replayProtectionWindow ),
150
-
151
- state : State {
152
- isClient : isClient ,
153
- },
154
- }
155
-
156
- c .setRemoteEpoch (0 )
157
- c .setLocalEpoch (0 )
158
- return c , nil
159
- }
160
-
161
- func handshakeConn (ctx context.Context , conn * Conn , config * Config , isClient bool , initialState * State ) (* Conn , error ) {
162
- if conn == nil {
163
- return nil , errNilNextConn
164
- }
165
-
166
133
cipherSuites , err := parseCipherSuites (config .CipherSuites , config .CustomCipherSuites , config .includeCertificateSuites (), config .PSK != nil )
167
134
if err != nil {
168
135
return nil , err
@@ -190,7 +157,7 @@ func handshakeConn(ctx context.Context, conn *Conn, config *Config, isClient boo
190
157
curves = defaultCurves
191
158
}
192
159
193
- hsCfg := & handshakeConfig {
160
+ handshakeConfig := & handshakeConfig {
194
161
localPSKCallback : config .PSK ,
195
162
localPSKIdentityHint : config .PSKIdentityHint ,
196
163
localCipherSuites : cipherSuites ,
@@ -209,7 +176,7 @@ func handshakeConn(ctx context.Context, conn *Conn, config *Config, isClient boo
209
176
customCipherSuites : config .CustomCipherSuites ,
210
177
initialRetransmitInterval : workerInterval ,
211
178
disableRetransmitBackoff : config .DisableRetransmitBackoff ,
212
- log : conn . log ,
179
+ log : logger ,
213
180
initialEpoch : 0 ,
214
181
keyLogWriter : config .KeyLogWriter ,
215
182
sessionStore : config .SessionStore ,
@@ -222,82 +189,115 @@ func handshakeConn(ctx context.Context, conn *Conn, config *Config, isClient boo
222
189
clientHelloMessageHook : config .ClientHelloMessageHook ,
223
190
serverHelloMessageHook : config .ServerHelloMessageHook ,
224
191
certificateRequestMessageHook : config .CertificateRequestMessageHook ,
192
+ resumeState : resumeState ,
193
+ }
194
+
195
+ c := & Conn {
196
+ rAddr : rAddr ,
197
+ nextConn : netctx .NewPacketConn (nextConn ),
198
+ handshakeConfig : handshakeConfig ,
199
+ fragmentBuffer : newFragmentBuffer (),
200
+ handshakeCache : newHandshakeCache (),
201
+ maximumTransmissionUnit : mtu ,
202
+ paddingLengthGenerator : paddingLengthGenerator ,
203
+
204
+ decrypted : make (chan interface {}, 1 ),
205
+ log : logger ,
206
+
207
+ readDeadline : deadline .New (),
208
+ writeDeadline : deadline .New (),
209
+
210
+ reading : make (chan struct {}, 1 ),
211
+ handshakeRecv : make (chan recvHandshakeState ),
212
+ closed : closer .NewCloser (),
213
+ cancelHandshaker : func () {},
214
+ cancelHandshakeReader : func () {},
215
+
216
+ replayProtectionWindow : uint (replayProtectionWindow ),
217
+
218
+ state : State {
219
+ isClient : isClient ,
220
+ },
221
+ }
222
+
223
+ c .setRemoteEpoch (0 )
224
+ c .setLocalEpoch (0 )
225
+ return c , nil
226
+ }
227
+
228
+ // Handshake runs the client or server DTLS handshake
229
+ // protocol if it has not yet been run.
230
+ //
231
+ // Most uses of this package need not call Handshake explicitly: the
232
+ // first [Conn.Read] or [Conn.Write] will call it automatically.
233
+ //
234
+ // For control over canceling or setting a timeout on a handshake, use
235
+ // [Conn.HandshakeContext].
236
+ func (c * Conn ) Handshake () error {
237
+ return c .HandshakeContext (context .Background ())
238
+ }
239
+
240
+ // HandshakeContext runs the client or server DTLS handshake
241
+ // protocol if it has not yet been run.
242
+ //
243
+ // The provided Context must be non-nil. If the context is canceled before
244
+ // the handshake is complete, the handshake is interrupted and an error is returned.
245
+ // Once the handshake has completed, cancellation of the context will not affect the
246
+ // connection.
247
+ //
248
+ // Most uses of this package need not call HandshakeContext explicitly: the
249
+ // first [Conn.Read] or [Conn.Write] will call it automatically.
250
+ func (c * Conn ) HandshakeContext (ctx context.Context ) error {
251
+ c .handshakeMutex .Lock ()
252
+ defer c .handshakeMutex .Unlock ()
253
+
254
+ if c .isHandshakeCompletedSuccessfully () {
255
+ return nil
225
256
}
226
257
227
258
// rfc5246#section-7.4.3
228
259
// In addition, the hash and signature algorithms MUST be compatible
229
260
// with the key in the server's end-entity certificate.
230
- if ! isClient {
231
- cert , err := hsCfg .getCertificate (& ClientHelloInfo {})
261
+ if ! c . state . isClient {
262
+ cert , err := c . handshakeConfig .getCertificate (& ClientHelloInfo {})
232
263
if err != nil && ! errors .Is (err , errNoCertificates ) {
233
- return nil , err
264
+ return err
234
265
}
235
- hsCfg . localCipherSuites = filterCipherSuitesForCertificate (cert , cipherSuites )
266
+ c . handshakeConfig . localCipherSuites = filterCipherSuitesForCertificate (cert , c . handshakeConfig . localCipherSuites )
236
267
}
237
268
238
269
var initialFlight flightVal
239
270
var initialFSMState handshakeState
240
271
241
- if initialState != nil {
242
- if conn .state .isClient {
272
+ if c . handshakeConfig . resumeState != nil {
273
+ if c .state .isClient {
243
274
initialFlight = flight5
244
275
} else {
245
276
initialFlight = flight6
246
277
}
247
278
initialFSMState = handshakeFinished
248
279
249
- conn .state = * initialState
280
+ c .state = * c . handshakeConfig . resumeState
250
281
} else {
251
- if conn .state .isClient {
282
+ if c .state .isClient {
252
283
initialFlight = flight1
253
284
} else {
254
285
initialFlight = flight0
255
286
}
256
287
initialFSMState = handshakePreparing
257
288
}
258
289
// Do handshake
259
- if err := conn .handshake (ctx , hsCfg , initialFlight , initialFSMState ); err != nil {
260
- return nil , err
290
+ if err := c .handshake (ctx , c . handshakeConfig , initialFlight , initialFSMState ); err != nil {
291
+ return err
261
292
}
262
293
263
- conn .log .Trace ("Handshake Completed" )
294
+ c .log .Trace ("Handshake Completed" )
264
295
265
- return conn , nil
296
+ return nil
266
297
}
267
298
268
299
// Dial connects to the given network address and establishes a DTLS connection on top.
269
- // Connection handshake will timeout using ConnectContextMaker in the Config.
270
- // If you want to specify the timeout duration, use DialWithContext() instead.
271
300
func Dial (network string , rAddr * net.UDPAddr , config * Config ) (* Conn , error ) {
272
- ctx , cancel := config .connectContextMaker ()
273
- defer cancel ()
274
-
275
- return DialWithContext (ctx , network , rAddr , config )
276
- }
277
-
278
- // Client establishes a DTLS connection over an existing connection.
279
- // Connection handshake will timeout using ConnectContextMaker in the Config.
280
- // If you want to specify the timeout duration, use ClientWithContext() instead.
281
- func Client (conn net.PacketConn , rAddr net.Addr , config * Config ) (* Conn , error ) {
282
- ctx , cancel := config .connectContextMaker ()
283
- defer cancel ()
284
-
285
- return ClientWithContext (ctx , conn , rAddr , config )
286
- }
287
-
288
- // Server listens for incoming DTLS connections.
289
- // Connection handshake will timeout using ConnectContextMaker in the Config.
290
- // If you want to specify the timeout duration, use ServerWithContext() instead.
291
- func Server (conn net.PacketConn , rAddr net.Addr , config * Config ) (* Conn , error ) {
292
- ctx , cancel := config .connectContextMaker ()
293
- defer cancel ()
294
-
295
- return ServerWithContext (ctx , conn , rAddr , config )
296
- }
297
-
298
- // DialWithContext connects to the given network address and establishes a DTLS
299
- // connection on top.
300
- func DialWithContext (ctx context.Context , network string , rAddr * net.UDPAddr , config * Config ) (* Conn , error ) {
301
301
// net.ListenUDP is used rather than net.DialUDP as the latter prevents the
302
302
// use of net.PacketConn.WriteTo.
303
303
// https://github.com/golang/go/blob/ce5e37ec21442c6eb13a43e68ca20129102ebac0/src/net/udpsock_posix.go#L115
@@ -306,28 +306,23 @@ func DialWithContext(ctx context.Context, network string, rAddr *net.UDPAddr, co
306
306
return nil , err
307
307
}
308
308
309
- return ClientWithContext ( ctx , pConn , rAddr , config )
309
+ return Client ( pConn , rAddr , config )
310
310
}
311
311
312
- // ClientWithContext establishes a DTLS connection over an existing connection.
313
- func ClientWithContext ( ctx context. Context , conn net.PacketConn , rAddr net.Addr , config * Config ) (* Conn , error ) {
312
+ // Client establishes a DTLS connection over an existing connection.
313
+ func Client ( conn net.PacketConn , rAddr net.Addr , config * Config ) (* Conn , error ) {
314
314
switch {
315
315
case config == nil :
316
316
return nil , errNoConfigProvided
317
317
case config .PSK != nil && config .PSKIdentityHint == nil :
318
318
return nil , errPSKAndIdentityMustBeSetForClient
319
319
}
320
320
321
- dconn , err := createConn (conn , rAddr , config , true )
322
- if err != nil {
323
- return nil , err
324
- }
325
-
326
- return handshakeConn (ctx , dconn , config , true , nil )
321
+ return createConn (conn , rAddr , config , true , nil )
327
322
}
328
323
329
- // ServerWithContext listens for incoming DTLS connections.
330
- func ServerWithContext ( ctx context. Context , conn net.PacketConn , rAddr net.Addr , config * Config ) (* Conn , error ) {
324
+ // Server listens for incoming DTLS connections.
325
+ func Server ( conn net.PacketConn , rAddr net.Addr , config * Config ) (* Conn , error ) {
331
326
if config == nil {
332
327
return nil , errNoConfigProvided
333
328
}
@@ -336,17 +331,13 @@ func ServerWithContext(ctx context.Context, conn net.PacketConn, rAddr net.Addr,
336
331
return nil , err
337
332
}
338
333
}
339
- dconn , err := createConn (conn , rAddr , config , false )
340
- if err != nil {
341
- return nil , err
342
- }
343
- return handshakeConn (ctx , dconn , config , false , nil )
334
+ return createConn (conn , rAddr , config , false , nil )
344
335
}
345
336
346
337
// Read reads data from the connection.
347
338
func (c * Conn ) Read (p []byte ) (n int , err error ) {
348
- if ! c . isHandshakeCompletedSuccessfully () {
349
- return 0 , errHandshakeInProgress
339
+ if err := c . Handshake (); err != nil {
340
+ return 0 , err
350
341
}
351
342
352
343
select {
@@ -389,8 +380,8 @@ func (c *Conn) Write(p []byte) (int, error) {
389
380
default :
390
381
}
391
382
392
- if ! c . isHandshakeCompletedSuccessfully () {
393
- return 0 , errHandshakeInProgress
383
+ if err := c . Handshake (); err != nil {
384
+ return 0 , err
394
385
}
395
386
396
387
return len (p ), c .writePackets (c .writeDeadline , []* packet {
0 commit comments