// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> // SPDX-License-Identifier: MIT package dtls import ( "bytes" "context" "crypto/tls" "errors" "sync" "testing" "time" "github.com/pion/dtls/v3/pkg/crypto/selfsign" "github.com/pion/dtls/v3/pkg/crypto/signaturehash" "github.com/pion/dtls/v3/pkg/protocol/alert" "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/dtls/v3/pkg/protocol/recordlayer" "github.com/pion/logging" "github.com/pion/transport/v3/test" ) const nonZeroRetransmitInterval = 100 * time.Millisecond // Test that writes to the key log are in the correct format and only applies // when a key log writer is given. func TestWriteKeyLog(t *testing.T) { var buf bytes.Buffer cfg := handshakeConfig{ keyLogWriter: &buf, } cfg.writeKeyLog("LABEL", []byte{0xAA, 0xBB, 0xCC}, []byte{0xDD, 0xEE, 0xFF}) // Secrets follow the format <Label> <space> <ClientRandom> <space> <Secret> // https://developer.mozilla.org/en-US/docs/Mozilla/Projects/NSS/Key_Log_Format want := "LABEL aabbcc ddeeff\n" if buf.String() != want { t.Fatalf("Got %s want %s", buf.String(), want) } // no key log writer = no writes cfg = handshakeConfig{} cfg.writeKeyLog("LABEL", []byte{0xAA, 0xBB, 0xCC}, []byte{0xDD, 0xEE, 0xFF}) } func TestHandshaker(t *testing.T) { //nolint:gocyclo,cyclop,maintidx // Check for leaking routines report := test.CheckRoutines(t) defer report() loggerFactory := logging.NewDefaultLoggerFactory() logger := loggerFactory.NewLogger("dtls") cipherSuites, err := parseCipherSuites(nil, nil, true, false) if err != nil { t.Fatal(err) } clientCert, err := selfsign.GenerateSelfSigned() if err != nil { t.Fatal(err) } genFilters := map[string]func() (TestEndpoint, TestEndpoint, func(t *testing.T)){ "PassThrough": func() (TestEndpoint, TestEndpoint, func(t *testing.T)) { return TestEndpoint{}, TestEndpoint{}, nil }, "HelloVerifyRequestLost": func() (TestEndpoint, TestEndpoint, func(t *testing.T)) { var ( cntHelloVerifyRequest = 0 cntClientHelloNoCookie = 0 ) const helloVerifyDrop = 5 clientEndpoint := TestEndpoint{ Filter: func(p *packet) bool { h, ok := p.record.Content.(*handshake.Handshake) if !ok { return true } if hmch, ok := h.Message.(*handshake.MessageClientHello); ok { if len(hmch.Cookie) == 0 { cntClientHelloNoCookie++ } } return true }, } serverEndpoint := TestEndpoint{ Filter: func(p *packet) bool { h, ok := p.record.Content.(*handshake.Handshake) if !ok { return true } if _, ok := h.Message.(*handshake.MessageHelloVerifyRequest); ok { cntHelloVerifyRequest++ return cntHelloVerifyRequest > helloVerifyDrop } return true }, } report := func(t *testing.T) { t.Helper() if cntHelloVerifyRequest != helloVerifyDrop+1 { t.Errorf( "Number of HelloVerifyRequest retransmit is wrong, expected: %d times, got: %d times", helloVerifyDrop+1, cntHelloVerifyRequest, ) } if cntClientHelloNoCookie != cntHelloVerifyRequest { ///nolint:lll t.Errorf( "HelloVerifyRequest must be triggered only by ClientHello, but HelloVerifyRequest was sent %d times and ClientHello was sent %d times", cntHelloVerifyRequest, cntClientHelloNoCookie, ) } } return clientEndpoint, serverEndpoint, report }, "NoLatencyTest": func() (TestEndpoint, TestEndpoint, func(t *testing.T)) { var ( cntClientFinished = 0 cntServerFinished = 0 ) clientEndpoint := TestEndpoint{ Filter: func(p *packet) bool { h, ok := p.record.Content.(*handshake.Handshake) if !ok { return true } if _, ok := h.Message.(*handshake.MessageFinished); ok { cntClientFinished++ } return true }, } serverEndpoint := TestEndpoint{ Filter: func(p *packet) bool { h, ok := p.record.Content.(*handshake.Handshake) if !ok { return true } if _, ok := h.Message.(*handshake.MessageFinished); ok { cntServerFinished++ } return true }, } report := func(t *testing.T) { t.Helper() if cntClientFinished != 1 { t.Errorf("Number of client finished is wrong, expected: %d times, got: %d times", 1, cntClientFinished) } if cntServerFinished != 1 { t.Errorf("Number of server finished is wrong, expected: %d times, got: %d times", 1, cntServerFinished) } } return clientEndpoint, serverEndpoint, report }, "SlowServerTest": func() (TestEndpoint, TestEndpoint, func(t *testing.T)) { var ( cntClientFinished = 0 isClientFinished = false cntClientFinishedLastRetransmit = 0 cntServerFinished = 0 isServerFinished = false cntServerFinishedLastRetransmit = 0 ) clientEndpoint := TestEndpoint{ Filter: func(p *packet) bool { h, ok := p.record.Content.(*handshake.Handshake) if !ok { return true } if _, ok := h.Message.(*handshake.MessageFinished); ok { if isClientFinished { cntClientFinishedLastRetransmit++ } else { cntClientFinished++ } } return true }, Delay: 0, OnFinished: func() { isClientFinished = true }, FinishWait: 2000 * time.Millisecond, } serverEndpoint := TestEndpoint{ Filter: func(p *packet) bool { h, ok := p.record.Content.(*handshake.Handshake) if !ok { return true } if _, ok := h.Message.(*handshake.MessageFinished); ok { if isServerFinished { cntServerFinishedLastRetransmit++ } else { cntServerFinished++ } } return true }, Delay: 1000 * time.Millisecond, OnFinished: func() { isServerFinished = true }, FinishWait: 2000 * time.Millisecond, } report := func(t *testing.T) { t.Helper() // with one second server delay and 100 ms retransmit (+ exponential backoff), // there should be close to 4 `Finished` from client // using a range of 3 - 5 for checking. if cntClientFinished < 3 || cntClientFinished > 5 { t.Errorf("Number of client finished is wrong, expected: %d - %d times, got: %d times", 3, 5, cntClientFinished) } if !isClientFinished { t.Errorf("Client is not finished") } // there should be no `Finished` last retransmit from client if cntClientFinishedLastRetransmit != 0 { t.Errorf( "Number of client finished last retransmit is wrong, expected: %d times, got: %d times", 0, cntClientFinishedLastRetransmit, ) } if cntServerFinished < 1 { t.Errorf("Number of server finished is wrong, expected: at least %d times, got: %d times", 1, cntServerFinished) } if !isServerFinished { t.Errorf("Server is not finished") } // there should be `Finished` last retransmit from server. // Because of slow server, client would have sent several `Finished`. if cntServerFinishedLastRetransmit < 1 { t.Errorf( "Number of server finished last retransmit is wrong, expected: at least %d times, got: %d times", 1, cntServerFinishedLastRetransmit, ) } } return clientEndpoint, serverEndpoint, report }, } for name, filters := range genFilters { clientEndpoint, serverEndpoint, report := filters() t.Run(name, func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) defer cancel() if report != nil { defer report(t) } ca, cb := flightTestPipe(ctx, clientEndpoint, serverEndpoint) ca.state.isClient = true var wg sync.WaitGroup wg.Add(2) ctxCliFinished, cancelCli := context.WithCancel(ctx) ctxSrvFinished, cancelSrv := context.WithCancel(ctx) go func() { defer wg.Done() cfg := &handshakeConfig{ localCipherSuites: cipherSuites, localCertificates: []tls.Certificate{clientCert}, ellipticCurves: defaultCurves, localSignatureSchemes: signaturehash.Algorithms(), insecureSkipVerify: true, log: logger, onFlightState: func(_ flightVal, s handshakeState) { if s == handshakeFinished { if clientEndpoint.OnFinished != nil { clientEndpoint.OnFinished() } time.AfterFunc(clientEndpoint.FinishWait, func() { cancelCli() }) } }, initialRetransmitInterval: nonZeroRetransmitInterval, } fsm := newHandshakeFSM(&ca.state, ca.handshakeCache, cfg, flight1) err := fsm.Run(ctx, ca, handshakePreparing) switch { case errors.Is(err, context.Canceled): case errors.Is(err, context.DeadlineExceeded): t.Error("Timeout") default: t.Error(err) } }() go func() { defer wg.Done() cfg := &handshakeConfig{ localCipherSuites: cipherSuites, localCertificates: []tls.Certificate{clientCert}, ellipticCurves: defaultCurves, localSignatureSchemes: signaturehash.Algorithms(), insecureSkipVerify: true, log: logger, onFlightState: func(_ flightVal, s handshakeState) { if s == handshakeFinished { if serverEndpoint.OnFinished != nil { serverEndpoint.OnFinished() } time.AfterFunc(serverEndpoint.FinishWait, func() { cancelSrv() }) } }, initialRetransmitInterval: nonZeroRetransmitInterval, } fsm := newHandshakeFSM(&cb.state, cb.handshakeCache, cfg, flight0) err := fsm.Run(ctx, cb, handshakePreparing) switch { case errors.Is(err, context.Canceled): case errors.Is(err, context.DeadlineExceeded): t.Error("Timeout") default: t.Error(err) } }() <-ctxCliFinished.Done() <-ctxSrvFinished.Done() cancel() wg.Wait() }) } } type packetFilter func(p *packet) bool type TestEndpoint struct { Filter packetFilter Delay time.Duration OnFinished func() FinishWait time.Duration } func flightTestPipe( ctx context.Context, clientEndpoint TestEndpoint, serverEndpoint TestEndpoint, ) (*flightTestConn, *flightTestConn) { ca := newHandshakeCache() cb := newHandshakeCache() chA := make(chan recvHandshakeState) chB := make(chan recvHandshakeState) return &flightTestConn{ handshakeCache: ca, otherEndCache: cb, recv: chA, otherEndRecv: chB, done: ctx.Done(), filter: clientEndpoint.Filter, delay: clientEndpoint.Delay, }, &flightTestConn{ handshakeCache: cb, otherEndCache: ca, recv: chB, otherEndRecv: chA, done: ctx.Done(), filter: serverEndpoint.Filter, delay: serverEndpoint.Delay, } } type flightTestConn struct { state State handshakeCache *handshakeCache recv chan recvHandshakeState done <-chan struct{} epoch uint16 filter packetFilter delay time.Duration otherEndCache *handshakeCache otherEndRecv chan recvHandshakeState } func (c *flightTestConn) recvHandshake() <-chan recvHandshakeState { return c.recv } func (c *flightTestConn) setLocalEpoch(epoch uint16) { c.epoch = epoch } func (c *flightTestConn) notify(context.Context, alert.Level, alert.Description) error { return nil } func (c *flightTestConn) writePackets(_ context.Context, pkts []*packet) error { time.Sleep(c.delay) for _, pkt := range pkts { if c.filter != nil && !c.filter(pkt) { continue } if handshake, ok := pkt.record.Content.(*handshake.Handshake); ok { handshakeRaw, err := pkt.record.Marshal() if err != nil { return err } c.handshakeCache.push( handshakeRaw[recordlayer.FixedHeaderSize:], pkt.record.Header.Epoch, handshake.Header.MessageSequence, handshake.Header.Type, c.state.isClient, ) content, err := handshake.Message.Marshal() if err != nil { return err } handshake.Header.Length = uint32(len(content)) //nolint:gosec // G115 handshake.Header.FragmentLength = uint32(len(content)) //nolint:gosec // G115 hdr, err := handshake.Header.Marshal() if err != nil { return err } c.otherEndCache.push( append(hdr, content...), pkt.record.Header.Epoch, handshake.Header.MessageSequence, handshake.Header.Type, c.state.isClient, ) } } go func() { select { case c.otherEndRecv <- recvHandshakeState{done: make(chan struct{})}: case <-c.done: } }() // Avoid deadlock on JS/WASM environment due to context switch problem. time.Sleep(10 * time.Millisecond) return nil } func (c *flightTestConn) handleQueuedPackets(context.Context) error { return nil } func (c *flightTestConn) sessionKey() []byte { return nil }