@@ -3,12 +3,19 @@ package pq
3
3
// This file contains SSL tests
4
4
5
5
import (
6
+ "bytes"
6
7
_ "crypto/sha256"
8
+ "crypto/tls"
7
9
"crypto/x509"
8
10
"database/sql"
11
+ "fmt"
12
+ "io"
13
+ "net"
9
14
"os"
10
15
"path/filepath"
16
+ "strings"
11
17
"testing"
18
+ "time"
12
19
)
13
20
14
21
func maybeSkipSSLTests (t * testing.T ) {
@@ -280,3 +287,135 @@ func TestSSLClientCertificates(t *testing.T) {
280
287
}
281
288
}
282
289
}
290
+
291
+ // Check that clint sends SNI data when `sslsni` is not disabled
292
+ func TestSNISupport (t * testing.T ) {
293
+ t .Parallel ()
294
+ tests := []struct {
295
+ name string
296
+ conn_param string
297
+ hostname string
298
+ expected_sni string
299
+ }{
300
+ {
301
+ name : "SNI is set by default" ,
302
+ conn_param : "" ,
303
+ hostname : "localhost" ,
304
+ expected_sni : "localhost" ,
305
+ },
306
+ {
307
+ name : "SNI is passed when asked for" ,
308
+ conn_param : "sslsni=1" ,
309
+ hostname : "localhost" ,
310
+ expected_sni : "localhost" ,
311
+ },
312
+ {
313
+ name : "SNI is not passed when disabled" ,
314
+ conn_param : "sslsni=0" ,
315
+ hostname : "localhost" ,
316
+ expected_sni : "" ,
317
+ },
318
+ {
319
+ name : "SNI is not set for IPv4" ,
320
+ conn_param : "" ,
321
+ hostname : "127.0.0.1" ,
322
+ expected_sni : "" ,
323
+ },
324
+ }
325
+ for _ , tt := range tests {
326
+ tt := tt
327
+ t .Run (tt .name , func (t * testing.T ) {
328
+ t .Parallel ()
329
+
330
+ // Start mock postgres server on OS-provided port
331
+ listener , err := net .Listen ("tcp" , "127.0.0.1:" )
332
+ if err != nil {
333
+ t .Fatal (err )
334
+ }
335
+ serverErrChan := make (chan error , 1 )
336
+ serverSNINameChan := make (chan string , 1 )
337
+ go mockPostgresSSL (listener , serverErrChan , serverSNINameChan )
338
+
339
+ defer listener .Close ()
340
+ defer close (serverErrChan )
341
+ defer close (serverSNINameChan )
342
+
343
+ // Try to establish a connection with the mock server. Connection will error out after TLS
344
+ // clientHello, but it is enough to catch SNI data on the server side
345
+ port := strings .Split (listener .Addr ().String (), ":" )[1 ]
346
+ connStr := fmt .Sprintf ("sslmode=require host=%s port=%s %s" , tt .hostname , port , tt .conn_param )
347
+
348
+ // We are okay to skip this error as we are polling serverErrChan and we'll get an error
349
+ // or timeout from the server side in case of problems here.
350
+ db , _ := sql .Open ("postgres" , connStr )
351
+ _ , _ = db .Exec ("SELECT 1" )
352
+
353
+ // Check SNI data
354
+ select {
355
+ case sniHost := <- serverSNINameChan :
356
+ if sniHost != tt .expected_sni {
357
+ t .Fatalf ("Expected SNI to be 'localhost', got '%+v' instead" , sniHost )
358
+ }
359
+ case err = <- serverErrChan :
360
+ t .Fatalf ("mock server failed with error: %+v" , err )
361
+ case <- time .After (time .Second ):
362
+ t .Fatal ("exceeded connection timeout without erroring out" )
363
+ }
364
+ })
365
+ }
366
+ }
367
+
368
+ // Make a postgres mock server to test TLS SNI
369
+ //
370
+ // Accepts postgres StartupMessage and handles TLS clientHello, then closes a connection.
371
+ // While reading clientHello catch passed SNI data and report it to nameChan.
372
+ func mockPostgresSSL (listener net.Listener , errChan chan error , nameChan chan string ) {
373
+ var sniHost string
374
+
375
+ conn , err := listener .Accept ()
376
+ if err != nil {
377
+ errChan <- err
378
+ return
379
+ }
380
+ defer conn .Close ()
381
+
382
+ err = conn .SetDeadline (time .Now ().Add (time .Second ))
383
+ if err != nil {
384
+ errChan <- err
385
+ return
386
+ }
387
+
388
+ // Receive StartupMessage with SSL Request
389
+ startupMessage := make ([]byte , 8 )
390
+ if _ , err := io .ReadFull (conn , startupMessage ); err != nil {
391
+ errChan <- err
392
+ return
393
+ }
394
+ // StartupMessage: first four bytes -- total len = 8, last four bytes SslRequestNumber
395
+ if ! bytes .Equal (startupMessage , []byte {0 , 0 , 0 , 0x8 , 0x4 , 0xd2 , 0x16 , 0x2f }) {
396
+ errChan <- fmt .Errorf ("unexpected startup message: %#v" , startupMessage )
397
+ return
398
+ }
399
+
400
+ // Respond with SSLOk
401
+ _ , err = conn .Write ([]byte ("S" ))
402
+ if err != nil {
403
+ errChan <- err
404
+ return
405
+ }
406
+
407
+ // Set up TLS context to catch clientHello. It will always error out during handshake
408
+ // as no certificate is set.
409
+ srv := tls .Server (conn , & tls.Config {
410
+ GetConfigForClient : func (argHello * tls.ClientHelloInfo ) (* tls.Config , error ) {
411
+ sniHost = argHello .ServerName
412
+ return nil , nil
413
+ },
414
+ })
415
+ defer srv .Close ()
416
+
417
+ // Do the TLS handshake ignoring errors
418
+ _ = srv .Handshake ()
419
+
420
+ nameChan <- sniHost
421
+ }
0 commit comments