// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> // SPDX-License-Identifier: MIT package dtls import ( "testing" "time" "github.com/pion/dtls/v3/pkg/protocol" "github.com/pion/dtls/v3/pkg/protocol/extension" "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) func TestRandomConnectionIDGenerator(t *testing.T) { cases := map[string]struct { reason string size int }{ "LengthMatch": { reason: "Zero size should match length of generated CID.", size: 0, }, "LengthMatchSome": { reason: "Non-zero size should match length of generated CID with non-zero.", size: 8, }, } for name, tc := range cases { t.Run(name, func(t *testing.T) { if cidLen := len(RandomCIDGenerator(tc.size)()); cidLen != tc.size { t.Errorf("%s\nRandomCIDGenerator: expected CID length %d, but got %d.", tc.reason, tc.size, cidLen) } }) } } func TestOnlySendCIDGenerator(t *testing.T) { cases := map[string]struct { reason string }{ "LengthMatch": { reason: "CID length should always be zero.", }, } for name, tc := range cases { t.Run(name, func(t *testing.T) { if cidLen := len(OnlySendCIDGenerator()()); cidLen != 0 { t.Errorf("%s\nOnlySendCIDGenerator: expected CID length %d, but got %d.", tc.reason, 0, cidLen) } }) } } func TestCIDDatagramRouter(t *testing.T) { cid := []byte("abcd1234") cidLen := 8 appRecord, err := (&recordlayer.RecordLayer{ Header: recordlayer.Header{ Epoch: 1, Version: protocol.Version1_2, }, Content: &protocol.ApplicationData{ Data: []byte("application data"), }, }).Marshal() if err != nil { t.Fatal(err) } appData, err := (&protocol.ApplicationData{ Data: []byte("some data"), }).Marshal() if err != nil { t.Fatal(err) } inner, err := (&recordlayer.InnerPlaintext{ Content: appData, RealType: protocol.ContentTypeApplicationData, }).Marshal() if err != nil { t.Fatal(err) } cidHeader, err := (&recordlayer.Header{ Epoch: 1, Version: protocol.Version1_2, ContentType: protocol.ContentTypeConnectionID, ContentLen: uint16(len(inner)), //nolint:gosec // G115 ConnectionID: cid, SequenceNumber: 1, }).Marshal() if err != nil { t.Fatal(err) } cases := map[string]struct { reason string size int datagram []byte ok bool want string }{ "EmptyDatagram": { reason: "If datagram is empty, we cannot extract an identifier", size: cidLen, datagram: []byte{}, ok: false, want: "", }, "NotADTLSRecord": { reason: "If datagram is not a DTLS record, we cannot extract an identifier", size: cidLen, datagram: []byte("not a DTLS record"), ok: false, want: "", }, "NotAConnectionIDDatagram": { reason: "If datagram does not contain any Connection ID records, we cannot extract an identifier", size: cidLen, datagram: appRecord, ok: false, want: "", }, "OneRecordConnectionID": { reason: "If datagram contains one Connection ID record, we should be able to extract it.", size: cidLen, datagram: append(cidHeader, inner...), ok: true, want: string(cid), }, "OneRecordConnectionIDAltLength": { //nolint:lll reason: "If datagram contains one Connection ID record, but it has the wrong length we should not be able to extract it.", size: cidLen, datagram: func() []byte { altCIDHeader, err := (&recordlayer.Header{ Epoch: 1, Version: protocol.Version1_2, ContentType: protocol.ContentTypeConnectionID, ContentLen: uint16(len(inner)), //nolint:gosec // G115 ConnectionID: []byte("abcd"), SequenceNumber: 1, }).Marshal() if err != nil { t.Fatal(err) } return append(altCIDHeader, inner...) }(), ok: false, want: "", }, "MultipleRecordOneConnectionID": { //nolint:lll reason: "If datagram contains multiple records and one is a Connection ID record, we should be able to extract it.", size: 8, datagram: append(append(appRecord, cidHeader...), inner...), ok: true, want: string(cid), }, "MultipleRecordMultipleConnectionID": { //nolint:lll reason: "If datagram contains multiple records and multiple are Connection ID records, we should extract the first one.", size: 8, datagram: append(append(append(appRecord, func() []byte { altCIDHeader, err := (&recordlayer.Header{ Epoch: 1, Version: protocol.Version1_2, ContentType: protocol.ContentTypeConnectionID, ContentLen: uint16(len(inner)), //nolint:gosec // G115 ConnectionID: []byte("1234abcd"), SequenceNumber: 1, }).Marshal() if err != nil { t.Fatal(err) } return append(altCIDHeader, inner...) }()...), cidHeader...), inner...), ok: true, want: "1234abcd", }, } for name, tc := range cases { t.Run(name, func(t *testing.T) { cid, ok := cidDatagramRouter(tc.size)(tc.datagram) if ok != tc.ok { t.Errorf("%s\ncidDatagramRouter: expected ok %t, but got %t.", tc.reason, tc.ok, ok) } if cid != tc.want { t.Errorf("%s\ncidDatagramRouter: expected CID %s, but got %s.", tc.reason, tc.want, cid) } }) } } func TestCIDConnIdentifier(t *testing.T) { cid := []byte("abcd1234") cs := uint16(TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256) sh, err := (&recordlayer.RecordLayer{ Header: recordlayer.Header{ Epoch: 0, Version: protocol.Version1_2, }, Content: &handshake.Handshake{ Message: &handshake.MessageServerHello{ Version: protocol.Version1_2, Random: handshake.Random{GMTUnixTime: time.Unix(500, 0), RandomBytes: [28]byte{}}, SessionID: []byte("hello"), CipherSuiteID: &cs, CompressionMethod: defaultCompressionMethods()[0], Extensions: []extension.Extension{ &extension.ConnectionID{ CID: cid, }, }, }, }, }).Marshal() if err != nil { t.Fatal(err) } appRecord, err := (&recordlayer.RecordLayer{ Header: recordlayer.Header{ Epoch: 1, Version: protocol.Version1_2, }, Content: &protocol.ApplicationData{ Data: []byte("application data"), }, }).Marshal() if err != nil { t.Fatal(err) } cases := map[string]struct { reason string datagram []byte ok bool want string }{ "EmptyDatagram": { reason: "If datagram is empty, we cannot extract an identifier", datagram: []byte{}, ok: false, want: "", }, "NotADTLSRecord": { reason: "If datagram is not a DTLS record, we cannot extract an identifier", datagram: []byte("not a DTLS record"), ok: false, want: "", }, "NotAServerhelloDatagram": { reason: "If datagram does not contain any ServerHello record, we cannot extract an identifier", datagram: appRecord, ok: false, want: "", }, "OneRecordServerHello": { reason: "If datagram contains one ServerHello record, we should be able to extract an identifier.", datagram: sh, ok: true, want: string(cid), }, "MultipleRecordFirstServerHello": { //nolint:lll reason: "If datagram contains multiple records and the first is a ServerHello record, we should be able to extract an identifier.", datagram: append(sh, appRecord...), ok: true, want: string(cid), }, "MultipleRecordNotFirstServerHello": { //nolint:lll reason: "If datagram contains multiple records and the first is not a ServerHello record, we should not be able to extract an identifier.", datagram: append(appRecord, sh...), ok: false, want: "", }, } for name, tc := range cases { t.Run(name, func(t *testing.T) { cid, ok := cidConnIdentifier()(tc.datagram) if ok != tc.ok { t.Errorf("%s\ncidConnIdentifier: expected ok %t, but got %t.", tc.reason, tc.ok, ok) } if cid != tc.want { t.Errorf("%s\ncidConnIdentifier: expected CID %s, but got %s.", tc.reason, tc.want, cid) } }) } }