Skip to content

Commit d5affd5

Browse files
authoredSep 6, 2022
Merge pull request #1088 from kelvich/sni_support
2 parents d65e6ae + 957fc0b commit d5affd5

File tree

3 files changed

+153
-1
lines changed

3 files changed

+153
-1
lines changed
 

‎conn.go

+3-1
Original file line numberDiff line numberDiff line change
@@ -1127,7 +1127,7 @@ func isDriverSetting(key string) bool {
11271127
return true
11281128
case "password":
11291129
return true
1130-
case "sslmode", "sslcert", "sslkey", "sslrootcert", "sslinline":
1130+
case "sslmode", "sslcert", "sslkey", "sslrootcert", "sslinline", "sslsni":
11311131
return true
11321132
case "fallback_application_name":
11331133
return true
@@ -2020,6 +2020,8 @@ func parseEnviron(env []string) (out map[string]string) {
20202020
accrue("sslkey")
20212021
case "PGSSLROOTCERT":
20222022
accrue("sslrootcert")
2023+
case "PGSSLSNI":
2024+
accrue("sslsni")
20232025
case "PGREQUIRESSL", "PGSSLCRL":
20242026
unsupported()
20252027
case "PGREQUIREPEER":

‎ssl.go

+11
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"os"
99
"os/user"
1010
"path/filepath"
11+
"strings"
1112
)
1213

1314
// ssl generates a function to upgrade a net.Conn based on the "sslmode" and
@@ -50,6 +51,16 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) {
5051
return nil, fmterrorf(`unsupported sslmode %q; only "require" (default), "verify-full", "verify-ca", and "disable" supported`, mode)
5152
}
5253

54+
// Set Server Name Indication (SNI), if enabled by connection parameters.
55+
// By default SNI is on, any value which is not starting with "1" disables
56+
// SNI -- that is the same check vanilla libpq uses.
57+
if sslsni := o["sslsni"]; sslsni == "" || strings.HasPrefix(sslsni, "1") {
58+
// RFC 6066 asks to not set SNI if the host is a literal IP address (IPv4
59+
// or IPv6). This check is coded already crypto.tls.hostnameInSNI, so
60+
// just always set ServerName here and let crypto/tls do the filtering.
61+
tlsConf.ServerName = o["host"]
62+
}
63+
5364
err := sslClientCertificates(&tlsConf, o)
5465
if err != nil {
5566
return nil, err

‎ssl_test.go

+139
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,19 @@ package pq
33
// This file contains SSL tests
44

55
import (
6+
"bytes"
67
_ "crypto/sha256"
8+
"crypto/tls"
79
"crypto/x509"
810
"database/sql"
11+
"fmt"
12+
"io"
13+
"net"
914
"os"
1015
"path/filepath"
16+
"strings"
1117
"testing"
18+
"time"
1219
)
1320

1421
func maybeSkipSSLTests(t *testing.T) {
@@ -280,3 +287,135 @@ func TestSSLClientCertificates(t *testing.T) {
280287
}
281288
}
282289
}
290+
291+
// Check that clint sends SNI data when `sslsni` is not disabled
292+
func TestSNISupport(t *testing.T) {
293+
t.Parallel()
294+
tests := []struct {
295+
name string
296+
conn_param string
297+
hostname string
298+
expected_sni string
299+
}{
300+
{
301+
name: "SNI is set by default",
302+
conn_param: "",
303+
hostname: "localhost",
304+
expected_sni: "localhost",
305+
},
306+
{
307+
name: "SNI is passed when asked for",
308+
conn_param: "sslsni=1",
309+
hostname: "localhost",
310+
expected_sni: "localhost",
311+
},
312+
{
313+
name: "SNI is not passed when disabled",
314+
conn_param: "sslsni=0",
315+
hostname: "localhost",
316+
expected_sni: "",
317+
},
318+
{
319+
name: "SNI is not set for IPv4",
320+
conn_param: "",
321+
hostname: "127.0.0.1",
322+
expected_sni: "",
323+
},
324+
}
325+
for _, tt := range tests {
326+
tt := tt
327+
t.Run(tt.name, func(t *testing.T) {
328+
t.Parallel()
329+
330+
// Start mock postgres server on OS-provided port
331+
listener, err := net.Listen("tcp", "127.0.0.1:")
332+
if err != nil {
333+
t.Fatal(err)
334+
}
335+
serverErrChan := make(chan error, 1)
336+
serverSNINameChan := make(chan string, 1)
337+
go mockPostgresSSL(listener, serverErrChan, serverSNINameChan)
338+
339+
defer listener.Close()
340+
defer close(serverErrChan)
341+
defer close(serverSNINameChan)
342+
343+
// Try to establish a connection with the mock server. Connection will error out after TLS
344+
// clientHello, but it is enough to catch SNI data on the server side
345+
port := strings.Split(listener.Addr().String(), ":")[1]
346+
connStr := fmt.Sprintf("sslmode=require host=%s port=%s %s", tt.hostname, port, tt.conn_param)
347+
348+
// We are okay to skip this error as we are polling serverErrChan and we'll get an error
349+
// or timeout from the server side in case of problems here.
350+
db, _ := sql.Open("postgres", connStr)
351+
_, _ = db.Exec("SELECT 1")
352+
353+
// Check SNI data
354+
select {
355+
case sniHost := <-serverSNINameChan:
356+
if sniHost != tt.expected_sni {
357+
t.Fatalf("Expected SNI to be 'localhost', got '%+v' instead", sniHost)
358+
}
359+
case err = <-serverErrChan:
360+
t.Fatalf("mock server failed with error: %+v", err)
361+
case <-time.After(time.Second):
362+
t.Fatal("exceeded connection timeout without erroring out")
363+
}
364+
})
365+
}
366+
}
367+
368+
// Make a postgres mock server to test TLS SNI
369+
//
370+
// Accepts postgres StartupMessage and handles TLS clientHello, then closes a connection.
371+
// While reading clientHello catch passed SNI data and report it to nameChan.
372+
func mockPostgresSSL(listener net.Listener, errChan chan error, nameChan chan string) {
373+
var sniHost string
374+
375+
conn, err := listener.Accept()
376+
if err != nil {
377+
errChan <- err
378+
return
379+
}
380+
defer conn.Close()
381+
382+
err = conn.SetDeadline(time.Now().Add(time.Second))
383+
if err != nil {
384+
errChan <- err
385+
return
386+
}
387+
388+
// Receive StartupMessage with SSL Request
389+
startupMessage := make([]byte, 8)
390+
if _, err := io.ReadFull(conn, startupMessage); err != nil {
391+
errChan <- err
392+
return
393+
}
394+
// StartupMessage: first four bytes -- total len = 8, last four bytes SslRequestNumber
395+
if !bytes.Equal(startupMessage, []byte{0, 0, 0, 0x8, 0x4, 0xd2, 0x16, 0x2f}) {
396+
errChan <- fmt.Errorf("unexpected startup message: %#v", startupMessage)
397+
return
398+
}
399+
400+
// Respond with SSLOk
401+
_, err = conn.Write([]byte("S"))
402+
if err != nil {
403+
errChan <- err
404+
return
405+
}
406+
407+
// Set up TLS context to catch clientHello. It will always error out during handshake
408+
// as no certificate is set.
409+
srv := tls.Server(conn, &tls.Config{
410+
GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) {
411+
sniHost = argHello.ServerName
412+
return nil, nil
413+
},
414+
})
415+
defer srv.Close()
416+
417+
// Do the TLS handshake ignoring errors
418+
_ = srv.Handshake()
419+
420+
nameChan <- sniHost
421+
}

0 commit comments

Comments
 (0)
Please sign in to comment.