diff --git a/.github/workflows/generate-authors.yml b/.github/workflows/api.yaml similarity index 73% rename from .github/workflows/generate-authors.yml rename to .github/workflows/api.yaml index ec7446c84..1032179e3 100644 --- a/.github/workflows/generate-authors.yml +++ b/.github/workflows/api.yaml @@ -11,13 +11,10 @@ # SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> # SPDX-License-Identifier: MIT -name: Generate Authors - +name: API on: pull_request: jobs: - generate: - uses: pion/.goassets/.github/workflows/generate-authors.reusable.yml@master - secrets: - token: ${{ secrets.PIONBOT_PRIVATE_KEY }} + check: + uses: pion/.goassets/.github/workflows/api.reusable.yml@master diff --git a/.github/workflows/e2e.yaml b/.github/workflows/e2e.yaml index 51809cbf3..52a90f173 100644 --- a/.github/workflows/e2e.yaml +++ b/.github/workflows/e2e.yaml @@ -14,9 +14,10 @@ jobs: e2e-test: name: Test runs-on: ubuntu-latest + timeout-minutes: 10 steps: - name: checkout - uses: actions/checkout@v2 + uses: actions/checkout@v4 - name: test run: | docker build -t pion-dtls-e2e -f e2e/Dockerfile . diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 01227e2a5..0e72ea4d3 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -21,4 +21,4 @@ jobs: release: uses: pion/.goassets/.github/workflows/release.reusable.yml@master with: - go-version: '1.20' # auto-update/latest-go-version + go-version: "1.22" # auto-update/latest-go-version diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 31aada4af..b02428931 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -23,16 +23,17 @@ jobs: uses: pion/.goassets/.github/workflows/test.reusable.yml@master strategy: matrix: - go: ['1.20', '1.19'] # auto-update/supported-go-version-list + go: ["1.23", "1.22"] # auto-update/supported-go-version-list fail-fast: false with: go-version: ${{ matrix.go }} + secrets: inherit test-i386: uses: pion/.goassets/.github/workflows/test-i386.reusable.yml@master strategy: matrix: - go: ['1.20', '1.19'] # auto-update/supported-go-version-list + go: ["1.23", "1.22"] # auto-update/supported-go-version-list fail-fast: false with: go-version: ${{ matrix.go }} @@ -40,4 +41,5 @@ jobs: test-wasm: uses: pion/.goassets/.github/workflows/test-wasm.reusable.yml@master with: - go-version: '1.20' # auto-update/latest-go-version + go-version: "1.23" # auto-update/latest-go-version + secrets: inherit diff --git a/.github/workflows/tidy-check.yaml b/.github/workflows/tidy-check.yaml index 4d346d4fd..417e730a5 100644 --- a/.github/workflows/tidy-check.yaml +++ b/.github/workflows/tidy-check.yaml @@ -22,4 +22,4 @@ jobs: tidy: uses: pion/.goassets/.github/workflows/tidy-check.reusable.yml@master with: - go-version: '1.20' # auto-update/latest-go-version + go-version: "1.22" # auto-update/latest-go-version diff --git a/.golangci.yml b/.golangci.yml index 4e3eddf42..88cb4fbf9 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,9 +1,13 @@ # SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> # SPDX-License-Identifier: MIT +run: + timeout: 5m + linters-settings: govet: - check-shadowing: true + enable: + - shadow misspell: locale: US exhaustive: @@ -21,18 +25,32 @@ linters-settings: - ^os.Exit$ - ^panic$ - ^print(ln)?$ + varnamelen: + max-distance: 12 + min-name-length: 2 + ignore-type-assert-ok: true + ignore-map-index-ok: true + ignore-chan-recv-ok: true + ignore-decls: + - i int + - n int + - w io.Writer + - r io.Reader + - b []byte linters: enable: - asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers - bidichk # Checks for dangerous unicode character sequences - bodyclose # checks whether HTTP response body is closed successfully + - containedctx # containedctx is a linter that detects struct contained context.Context field - contextcheck # check the function whether use a non-inherited context + - cyclop # checks function and package cyclomatic complexity - decorder # check declaration order and count of types, constants, variables and functions - - depguard # Go linter that checks if package imports are in a list of acceptable packages - dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) - dupl # Tool for code clone detection - durationcheck # check for two durations multiplied together + - err113 # Golang linter to check the errors handling expressions - errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases - errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occations, where the check for the returned error can be omitted. - errname # Checks that sentinel errors are prefixed with the `Err` and error types are suffixed with the `Error`. @@ -43,18 +61,17 @@ linters: - forcetypeassert # finds forced type assertions - gci # Gci control golang package import order and make it always deterministic. - gochecknoglobals # Checks that no globals are present in Go code - - gochecknoinits # Checks that no init functions are present in Go code - gocognit # Computes and checks the cognitive complexity of functions - goconst # Finds repeated strings that could be replaced by a constant - gocritic # The most opinionated Go source code linter + - gocyclo # Computes and checks the cyclomatic complexity of functions + - godot # Check if comments end in a period - godox # Tool for detection of FIXME, TODO and other comment keywords - - goerr113 # Golang linter to check the errors handling expressions - gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification - gofumpt # Gofumpt checks whether code was gofumpt-ed. - goheader # Checks is file header matches to pattern - goimports # Goimports does everything that gofmt does. Additionally it checks unused imports - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod. - - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. - goprintffuncname # Checks that printf-like functions are named with `f` at the end - gosec # Inspects source code for security problems - gosimple # Linter for Go source code that specializes in simplifying a code @@ -62,10 +79,15 @@ linters: - grouper # An analyzer to analyze expression groups. - importas # Enforces consistent import aliases - ineffassign # Detects when assignments to existing variables are not used + - lll # Reports long lines + - maintidx # maintidx measures the maintainability index of each function. + - makezero # Finds slice declarations with non-zero initial length - misspell # Finds commonly misspelled English words in comments - nakedret # Finds naked returns in functions greater than a specified function length + - nestif # Reports deeply nested if statements - nilerr # Finds the code that returns nil even if it checks that the error is not nil. - nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value. + - nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity - noctx # noctx finds sending http request without context.Context - predeclared # find code that shadows one of Go's predeclared identifiers - revive # golint replacement, finds style mistakes @@ -73,29 +95,22 @@ linters: - stylecheck # Stylecheck is a replacement for golint - tagliatelle # Checks the struct tags. - tenv # tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17 - - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes + - thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers - typecheck # Like the front-end of a Go compiler, parses and type-checks Go code - unconvert # Remove unnecessary type conversions - unparam # Reports unused function parameters - unused # Checks Go code for unused constants, variables, functions and types + - varnamelen # checks that the length of a variable's name matches its scope - wastedassign # wastedassign finds wasted assignment statements - whitespace # Tool for detection of leading and trailing whitespace disable: - - containedctx # containedctx is a linter that detects struct contained context.Context field - - cyclop # checks function and package cyclomatic complexity - - exhaustivestruct # Checks if all struct's fields are initialized + - depguard # Go linter that checks if package imports are in a list of acceptable packages - funlen # Tool for detection of long functions - - gocyclo # Computes and checks the cyclomatic complexity of functions - - godot # Check if comments end in a period - - gomnd # An analyzer to detect magic numbers. - - ifshort # Checks that your code uses short syntax for if-statements whenever possible + - gochecknoinits # Checks that no init functions are present in Go code + - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. + - interfacebloat # A linter that checks length of interface. - ireturn # Accept Interfaces, Return Concrete Types - - lll # Reports long lines - - maintidx # maintidx measures the maintainability index of each function. - - makezero # Finds slice declarations with non-zero initial length - - maligned # Tool to detect Go structs that would take less memory if their fields were sorted - - nestif # Reports deeply nested if statements - - nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity + - mnd # An analyzer to detect magic numbers - nolintlint # Reports ill-formed or insufficient nolint directives - paralleltest # paralleltest detects missing usage of t.Parallel() method in your Go test - prealloc # Finds slice declarations that could potentially be preallocated @@ -103,35 +118,21 @@ linters: - rowserrcheck # checks whether Err of rows is checked successfully - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed. - testpackage # linter that makes you use a separate _test package - - thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers - - varnamelen # checks that the length of a variable's name matches its scope + - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes - wrapcheck # Checks that errors returned from external packages are wrapped - wsl # Whitespace Linter - Forces you to use empty lines! issues: exclude-use-default: false + exclude-dirs-use-default: false exclude-rules: - # Allow complex tests, better to be self contained - - path: _test\.go + # Allow complex tests and examples, better to be self contained + - path: (examples|main\.go|_test\.go) linters: - - gocognit - forbidigo - - # Allow complex main function in examples - - path: examples - text: "of func `main` is high" - linters: - gocognit - - # Allow forbidden identifiers in examples - - path: examples - linters: - - forbidigo # Allow forbidden identifiers in CLI commands - path: cmd linters: - forbidigo - -run: - skip-dirs-use-default: false diff --git a/.reuse/dep5 b/.reuse/dep5 index c8b3dfa09..4ce056940 100644 --- a/.reuse/dep5 +++ b/.reuse/dep5 @@ -2,10 +2,10 @@ Format: https://www.debian.org/doc/packaging-manuals/copyright-format/1.0/ Upstream-Name: Pion Source: https://github.com/pion/ -Files: README.md DESIGN.md **/README.md AUTHORS.txt renovate.json go.mod go.sum .eslintrc.json package.json examples/examples.json +Files: README.md DESIGN.md **/README.md AUTHORS.txt renovate.json go.mod go.sum **/go.mod **/go.sum .eslintrc.json package.json examples.json sfu-ws/flutter/.gitignore sfu-ws/flutter/pubspec.yaml c-data-channels/webrtc.h examples/examples.json Copyright: 2023 The Pion community <https://pion.ly> License: MIT -Files: testdata/fuzz/* **/testdata/fuzz/* api/*.txt +Files: testdata/seed/* testdata/fuzz/* **/testdata/fuzz/* api/*.txt Copyright: 2023 The Pion community <https://pion.ly> License: CC0-1.0 diff --git a/AUTHORS.txt b/AUTHORS.txt deleted file mode 100644 index e14fae4c0..000000000 --- a/AUTHORS.txt +++ /dev/null @@ -1,57 +0,0 @@ -# Thank you to everyone that made Pion possible. If you are interested in contributing -# we would love to have you https://github.com/pion/webrtc/wiki/Contributing -# -# This file is auto generated, using git to list all individuals contributors. -# see https://github.com/pion/.goassets/blob/master/scripts/generate-authors.sh for the scripting -Aleksandr Razumov <ar@gortc.io> -alvarowolfx <alvarowolfx@gmail.com> -Arlo Breault <arlolra@gmail.com> -Atsushi Watanabe <atsushi.w@ieee.org> -backkem <mail@backkem.me> -bjdgyc <bjdgyc@163.com> -boks1971 <raja.gobi@tutanota.com> -Bragadeesh <bragboy@gmail.com> -Carson Hoffman <c@rsonhoffman.com> -Cecylia Bocovich <cohosh@torproject.org> -Chris Hiszpanski <thinkski@users.noreply.github.com> -cnderrauber <zengjie9004@gmail.com> -Daniele Sluijters <daenney@users.noreply.github.com> -folbrich <frank.olbricht@gmail.com> -Hayden James <hayden.james@gmail.com> -Hugo Arregui <hugo.arregui@gmail.com> -Hugo Arregui <hugo@decentraland.org> -igolaizola <11333576+igolaizola@users.noreply.github.com> -Jeffrey Stoke <me@arhat.dev> -Jeroen de Bruijn <vidavidorra+jdbruijn@gmail.com> -Jeroen de Bruijn <vidavidorra@gmail.com> -Jim Wert <jimwert@gmail.com> -jinleileiking <jinleileiking@gmail.com> -Jozef Kralik <jojo.lwin@gmail.com> -Julien Salleyron <julien.salleyron@gmail.com> -Juliusz Chroboczek <jch@irif.fr> -Kegan Dougal <kegan@matrix.org> -Kevin Wang <kevmo314@gmail.com> -Lander Noterman <lander.noterman@basalte.be> -Len <len@hpcnt.com> -Lukas Lihotzki <lukas@lihotzki.de> -ManuelBk <26275612+ManuelBk@users.noreply.github.com> -Michael Zabka <zabka.michael@gmail.com> -Michiel De Backker <mail@backkem.me> -Rachel Chen <rachel@chens.email> -Robert Eperjesi <eperjesi@uber.com> -Ryan Gordon <ryan.gordon@getcruise.com> -Sam Lancia <sam.lancia@motorolasolutions.com> -Sean DuBois <duboisea@justin.tv> -Sean DuBois <seaduboi@amazon.com> -Sean DuBois <sean@siobud.com> -Shelikhoo <xiaokangwang@outlook.com> -Stefan Tatschner <stefan@rumpelsepp.org> -Steffen Vogel <post@steffenvogel.de> -Vadim <fffilimonov@yandex.ru> -Vadim Filimonov <fffilimonov@yandex.ru> -wmiao <wu.miao@viasat.com> -ZHENK <chengzhenyang@gmail.com> -吕海涛 <hi@taoshu.in> - -# List of contributors not appearing in Git history - diff --git a/README.md b/README.md index 0c0659593..fa00c95d8 100644 --- a/README.md +++ b/README.md @@ -10,9 +10,9 @@ <a href="https://pion.ly/slack"><img src="https://img.shields.io/badge/join-us%20on%20slack-gray.svg?longCache=true&logo=slack&colorB=brightgreen" alt="Slack Widget"></a> <br> <img alt="GitHub Workflow Status" src="https://img.shields.io/github/actions/workflow/status/pion/dtls/test.yaml"> - <a href="https://pkg.go.dev/github.com/pion/dtls/v2"><img src="https://pkg.go.dev/badge/github.com/pion/dtls/v2.svg" alt="Go Reference"></a> + <a href="https://pkg.go.dev/github.com/pion/dtls/v3"><img src="https://pkg.go.dev/badge/github.com/pion/dtls/v3.svg" alt="Go Reference"></a> <a href="https://codecov.io/gh/pion/dtls"><img src="https://codecov.io/gh/pion/dtls/branch/master/graph/badge.svg" alt="Coverage Status"></a> - <a href="https://goreportcard.com/report/github.com/pion/dtls/v2"><img src="https://goreportcard.com/badge/github.com/pion/dtls/v2" alt="Go Report Card"></a> + <a href="https://goreportcard.com/report/github.com/pion/dtls/v3"><img src="https://goreportcard.com/badge/github.com/pion/dtls/v3" alt="Go Report Card"></a> <a href="LICENSE"><img src="https://img.shields.io/badge/License-MIT-yellow.svg" alt="License: MIT"></a> </p> <br> @@ -145,7 +145,7 @@ We are always looking to support **your projects**. Please reach out if you have If you need commercial support or don't want to use public methods you can contact us at [team@pion.ly](mailto:team@pion.ly) ### Contributing -Check out the [contributing wiki](https://github.com/pion/webrtc/wiki/Contributing) to join the group of amazing people making this project possible: [AUTHORS.txt](./AUTHORS.txt) +Check out the [contributing wiki](https://github.com/pion/webrtc/wiki/Contributing) to join the group of amazing people making this project possible ### License MIT License - see [LICENSE](LICENSE) for full text diff --git a/bench_test.go b/bench_test.go index abec5a5d7..8d90786cb 100644 --- a/bench_test.go +++ b/bench_test.go @@ -10,10 +10,11 @@ import ( "testing" "time" - "github.com/pion/dtls/v2/pkg/crypto/selfsign" + "github.com/pion/dtls/v3/pkg/crypto/selfsign" + dtlsnet "github.com/pion/dtls/v3/pkg/net" "github.com/pion/logging" - "github.com/pion/transport/v2/dpipe" - "github.com/pion/transport/v2/test" + "github.com/pion/transport/v3/dpipe" + "github.com/pion/transport/v3/test" ) func TestSimpleReadWrite(t *testing.T) { @@ -30,16 +31,17 @@ func TestSimpleReadWrite(t *testing.T) { gotHello := make(chan struct{}) go func() { - server, sErr := testServer(ctx, cb, &Config{ + server, sErr := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ Certificates: []tls.Certificate{certificate}, LoggerFactory: logging.NewDefaultLoggerFactory(), }, false) if sErr != nil { t.Error(sErr) + return } buf := make([]byte, 1024) - if _, sErr = server.Read(buf); sErr != nil { + if _, sErr = server.Read(buf); sErr != nil { //nolint:contextcheck t.Error(sErr) } gotHello <- struct{}{} @@ -48,7 +50,7 @@ func TestSimpleReadWrite(t *testing.T) { } }() - client, err := testClient(ctx, ca, &Config{ + client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ LoggerFactory: logging.NewDefaultLoggerFactory(), InsecureSkipVerify: true, }, false) @@ -70,19 +72,22 @@ func TestSimpleReadWrite(t *testing.T) { } } -func benchmarkConn(b *testing.B, n int64) { - b.Run(fmt.Sprintf("%d", n), func(b *testing.B) { +func benchmarkConn(b *testing.B, payloadSize int64) { + b.Helper() + + b.Run(fmt.Sprintf("%d", payloadSize), func(b *testing.B) { ctx := context.Background() ca, cb := dpipe.Pipe() certificate, err := selfsign.GenerateSelfSigned() server := make(chan *Conn) go func() { - s, sErr := testServer(ctx, cb, &Config{ + s, sErr := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ Certificates: []tls.Certificate{certificate}, }, false) if err != nil { b.Error(sErr) + return } server <- s @@ -90,11 +95,13 @@ func benchmarkConn(b *testing.B, n int64) { if err != nil { b.Fatal(err) } - hw := make([]byte, n) + hw := make([]byte, payloadSize) b.ReportAllocs() b.SetBytes(int64(len(hw))) go func() { - client, cErr := testClient(ctx, ca, &Config{InsecureSkipVerify: true}, false) + client, cErr := testClient( + ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{InsecureSkipVerify: true}, false, + ) if cErr != nil { b.Error(err) } diff --git a/certificate.go b/certificate.go index 519fc875f..524b8e063 100644 --- a/certificate.go +++ b/certificate.go @@ -9,6 +9,8 @@ import ( "crypto/x509" "fmt" "strings" + + "github.com/pion/dtls/v3/pkg/protocol/handshake" ) // ClientHelloInfo contains information from a ClientHello message in order to @@ -22,6 +24,9 @@ type ClientHelloInfo struct { // CipherSuites lists the CipherSuites supported by the client (e.g. // TLS_AES_128_GCM_SHA256, TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256). CipherSuites []CipherSuiteID + + // RandomBytes stores the client hello random bytes + RandomBytes [handshake.RandomBytesLength]byte } // CertificateRequestInfo contains information from a server's @@ -38,7 +43,8 @@ type CertificateRequestInfo struct { // SupportsCertificate returns nil if the provided certificate is supported by // the server that sent the CertificateRequest. Otherwise, it returns an error // describing the reason for the incompatibility. -// NOTE: original src: https://github.com/golang/go/blob/29b9a328d268d53833d2cc063d1d8b4bf6852675/src/crypto/tls/common.go#L1273 +// NOTE: original src: +// https://github.com/golang/go/blob/29b9a328d268d53833d2cc063d1d8b4bf6852675/src/crypto/tls/common.go#L1273 func (cri *CertificateRequestInfo) SupportsCertificate(c *tls.Certificate) error { if len(cri.AcceptableCAs) == 0 { return nil @@ -61,6 +67,7 @@ func (cri *CertificateRequestInfo) SupportsCertificate(c *tls.Certificate) error } } } + return errNotAcceptableCertificateChain } @@ -86,6 +93,7 @@ func (c *handshakeConfig) setNameToCertificateLocked() { c.nameToCertificate = nameToCertificate } +//nolint:cyclop func (c *handshakeConfig) getCertificate(clientHelloInfo *ClientHelloInfo) (*tls.Certificate, error) { c.mu.Lock() defer c.mu.Unlock() @@ -136,7 +144,8 @@ func (c *handshakeConfig) getCertificate(clientHelloInfo *ClientHelloInfo) (*tls return &c.localCertificates[0], nil } -// NOTE: original src: https://github.com/golang/go/blob/29b9a328d268d53833d2cc063d1d8b4bf6852675/src/crypto/tls/handshake_client.go#L974 +// NOTE: original src: +// https://github.com/golang/go/blob/29b9a328d268d53833d2cc063d1d8b4bf6852675/src/crypto/tls/handshake_client.go#L974 func (c *handshakeConfig) getClientCertificate(cri *CertificateRequestInfo) (*tls.Certificate, error) { c.mu.Lock() defer c.mu.Unlock() @@ -149,6 +158,7 @@ func (c *handshakeConfig) getClientCertificate(cri *CertificateRequestInfo) (*tl if err := cri.SupportsCertificate(&chain); err != nil { continue } + return &chain, nil } diff --git a/certificate_test.go b/certificate_test.go index 5f2e87bb4..37598a639 100644 --- a/certificate_test.go +++ b/certificate_test.go @@ -8,7 +8,7 @@ import ( "reflect" "testing" - "github.com/pion/dtls/v2/pkg/crypto/selfsign" + "github.com/pion/dtls/v3/pkg/crypto/selfsign" ) func TestGetCertificate(t *testing.T) { @@ -77,7 +77,7 @@ func TestGetCertificate(t *testing.T) { }, { desc: "Get certificate from callback", - getCertificate: func(info *ClientHelloInfo) (*tls.Certificate, error) { + getCertificate: func(*ClientHelloInfo) (*tls.Certificate, error) { return &certificateTest, nil }, expectedCertificate: certificateTest, diff --git a/cipher_suite.go b/cipher_suite.go index 7a5bb4a58..2b29bf238 100644 --- a/cipher_suite.go +++ b/cipher_suite.go @@ -4,6 +4,7 @@ package dtls import ( + "crypto" "crypto/ecdsa" "crypto/ed25519" "crypto/rsa" @@ -11,54 +12,68 @@ import ( "fmt" "hash" - "github.com/pion/dtls/v2/internal/ciphersuite" - "github.com/pion/dtls/v2/pkg/crypto/clientcertificate" - "github.com/pion/dtls/v2/pkg/protocol/recordlayer" + "github.com/pion/dtls/v3/internal/ciphersuite" + "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) -// CipherSuiteID is an ID for our supported CipherSuites +// CipherSuiteID is an ID for our supported CipherSuites. type CipherSuiteID = ciphersuite.ID -// Supported Cipher Suites +// Supported Cipher Suites. const ( // AES-128-CCM - TLS_ECDHE_ECDSA_WITH_AES_128_CCM CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_128_CCM //nolint:revive,stylecheck - TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 //nolint:revive,stylecheck + //nolint:revive,stylecheck + TLS_ECDHE_ECDSA_WITH_AES_128_CCM CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_128_CCM + //nolint:revive,stylecheck + TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 // AES-128-GCM-SHA256 - TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 //nolint:revive,stylecheck - TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 CipherSuiteID = ciphersuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 //nolint:revive,stylecheck + //nolint:revive,stylecheck + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 + //nolint:revive,stylecheck + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 CipherSuiteID = ciphersuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 - TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 //nolint:revive,stylecheck - TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 CipherSuiteID = ciphersuite.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 //nolint:revive,stylecheck + //nolint:revive,stylecheck + TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 + //nolint:revive,stylecheck + TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 CipherSuiteID = ciphersuite.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 // AES-256-CBC-SHA - TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA //nolint:revive,stylecheck - TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA CipherSuiteID = ciphersuite.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA //nolint:revive,stylecheck - - TLS_PSK_WITH_AES_128_CCM CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_CCM //nolint:revive,stylecheck - TLS_PSK_WITH_AES_128_CCM_8 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_CCM_8 //nolint:revive,stylecheck - TLS_PSK_WITH_AES_256_CCM_8 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_256_CCM_8 //nolint:revive,stylecheck - TLS_PSK_WITH_AES_128_GCM_SHA256 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_GCM_SHA256 //nolint:revive,stylecheck - TLS_PSK_WITH_AES_128_CBC_SHA256 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_CBC_SHA256 //nolint:revive,stylecheck - - TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 CipherSuiteID = ciphersuite.TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 //nolint:revive,stylecheck + //nolint:revive,stylecheck + TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA + //nolint:revive,stylecheck + TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA CipherSuiteID = ciphersuite.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA + + //nolint:revive,stylecheck + TLS_PSK_WITH_AES_128_CCM CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_CCM + //nolint:revive,stylecheck + TLS_PSK_WITH_AES_128_CCM_8 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_CCM_8 + //nolint:revive,stylecheck + TLS_PSK_WITH_AES_256_CCM_8 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_256_CCM_8 + //nolint:revive,stylecheck + TLS_PSK_WITH_AES_128_GCM_SHA256 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_GCM_SHA256 + //nolint:revive,stylecheck + TLS_PSK_WITH_AES_128_CBC_SHA256 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_CBC_SHA256 + + //nolint:revive,stylecheck + TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 CipherSuiteID = ciphersuite.TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 ) -// CipherSuiteAuthenticationType controls what authentication method is using during the handshake for a CipherSuite +// CipherSuiteAuthenticationType controls what authentication method is using during the handshake for a CipherSuite. type CipherSuiteAuthenticationType = ciphersuite.AuthenticationType -// AuthenticationType Enums +// AuthenticationType Enums. const ( CipherSuiteAuthenticationTypeCertificate CipherSuiteAuthenticationType = ciphersuite.AuthenticationTypeCertificate CipherSuiteAuthenticationTypePreSharedKey CipherSuiteAuthenticationType = ciphersuite.AuthenticationTypePreSharedKey CipherSuiteAuthenticationTypeAnonymous CipherSuiteAuthenticationType = ciphersuite.AuthenticationTypeAnonymous ) -// CipherSuiteKeyExchangeAlgorithm controls what exchange algorithm is using during the handshake for a CipherSuite +// CipherSuiteKeyExchangeAlgorithm controls what exchange algorithm is using during the handshake for a CipherSuite. type CipherSuiteKeyExchangeAlgorithm = ciphersuite.KeyExchangeAlgorithm -// CipherSuiteKeyExchangeAlgorithm Bitmask +// CipherSuiteKeyExchangeAlgorithm Bitmask. const ( CipherSuiteKeyExchangeAlgorithmNone CipherSuiteKeyExchangeAlgorithm = ciphersuite.KeyExchangeAlgorithmNone CipherSuiteKeyExchangeAlgorithmPsk CipherSuiteKeyExchangeAlgorithm = ciphersuite.KeyExchangeAlgorithmPsk @@ -67,7 +82,7 @@ const ( var _ = allCipherSuites() // Necessary until this function isn't only used by Go 1.14 -// CipherSuite is an interface that all DTLS CipherSuites must satisfy +// CipherSuite is an interface that all DTLS CipherSuites must satisfy. type CipherSuite interface { // String of CipherSuite, only used for logging String() string @@ -95,7 +110,7 @@ type CipherSuite interface { Init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error IsInitialized() bool Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) - Decrypt(in []byte) ([]byte, error) + Decrypt(h recordlayer.Header, in []byte) ([]byte, error) } // CipherSuiteName provides the same functionality as tls.CipherSuiteName @@ -108,13 +123,14 @@ func CipherSuiteName(id CipherSuiteID) string { if suite != nil { return suite.String() } + return fmt.Sprintf("0x%04X", uint16(id)) } // Taken from https://www.iana.org/assignments/tls-parameters/tls-parameters.xml // A cipherSuite is a specific combination of key agreement, cipher and MAC // function. -func cipherSuiteForID(id CipherSuiteID, customCiphers func() []CipherSuite) CipherSuite { +func cipherSuiteForID(id CipherSuiteID, customCiphers func() []CipherSuite) CipherSuite { //nolint:cyclop switch id { //nolint:exhaustive case TLS_ECDHE_ECDSA_WITH_AES_128_CCM: return ciphersuite.NewTLSEcdheEcdsaWithAes128Ccm() @@ -157,7 +173,7 @@ func cipherSuiteForID(id CipherSuiteID, customCiphers func() []CipherSuite) Ciph return nil } -// CipherSuites we support in order of preference +// CipherSuites we support in order of preference. func defaultCipherSuites() []CipherSuite { return []CipherSuite{ &ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{}, @@ -191,10 +207,16 @@ func cipherSuiteIDs(cipherSuites []CipherSuite) []uint16 { for _, c := range cipherSuites { rtrn = append(rtrn, uint16(c.ID())) } + return rtrn } -func parseCipherSuites(userSelectedSuites []CipherSuiteID, customCipherSuites func() []CipherSuite, includeCertificateSuites, includePSKSuites bool) ([]CipherSuite, error) { +//nolint:cyclop +func parseCipherSuites( + userSelectedSuites []CipherSuiteID, + customCipherSuites func() []CipherSuite, + includeCertificateSuites, includePSKSuites bool, +) ([]CipherSuite, error) { cipherSuitesForIDs := func(ids []CipherSuiteID) ([]CipherSuite, error) { cipherSuites := []CipherSuite{} for _, id := range ids { @@ -204,6 +226,7 @@ func parseCipherSuites(userSelectedSuites []CipherSuiteID, customCipherSuites fu } cipherSuites = append(cipherSuites, c) } + return cipherSuites, nil } @@ -258,11 +281,16 @@ func filterCipherSuitesForCertificate(cert *tls.Certificate, cipherSuites []Ciph if cert == nil || cert.PrivateKey == nil { return cipherSuites } + signer, ok := cert.PrivateKey.(crypto.Signer) + if !ok { + return cipherSuites + } + var certType clientcertificate.Type - switch cert.PrivateKey.(type) { - case ed25519.PrivateKey, *ecdsa.PrivateKey: + switch signer.Public().(type) { + case ed25519.PublicKey, *ecdsa.PublicKey: certType = clientcertificate.ECDSASign - case *rsa.PrivateKey: + case *rsa.PublicKey: certType = clientcertificate.RSASign } @@ -272,5 +300,6 @@ func filterCipherSuitesForCertificate(cert *tls.Certificate, cipherSuites []Ciph filtered = append(filtered, c) } } + return filtered } diff --git a/cipher_suite_go114.go b/cipher_suite_go114.go index fd46d7bd9..e7a324147 100644 --- a/cipher_suite_go114.go +++ b/cipher_suite_go114.go @@ -11,10 +11,10 @@ import ( ) // VersionDTLS12 is the DTLS version in the same style as -// VersionTLSXX from crypto/tls +// VersionTLSXX from crypto/tls. const VersionDTLS12 = 0xfefd -// Convert from our cipherSuite interface to a tls.CipherSuite struct +// Convert from our cipherSuite interface to a tls.CipherSuite struct. func toTLSCipherSuite(c CipherSuite) *tls.CipherSuite { return &tls.CipherSuite{ ID: uint16(c.ID()), @@ -33,6 +33,7 @@ func CipherSuites() []*tls.CipherSuite { for i, c := range suites { res[i] = toTLSCipherSuite(c) } + return res } @@ -40,5 +41,6 @@ func CipherSuites() []*tls.CipherSuite { // this package and which have security issues. func InsecureCipherSuites() []*tls.CipherSuite { var res []*tls.CipherSuite + return res } diff --git a/cipher_suite_go114_test.go b/cipher_suite_go114_test.go index 35c4b1ef6..e93b760c5 100644 --- a/cipher_suite_go114_test.go +++ b/cipher_suite_go114_test.go @@ -30,25 +30,25 @@ func TestCipherSuites(t *testing.T) { i := i s := s t.Run(s.String(), func(t *testing.T) { - c := theirs[i] - if c.ID != uint16(s.ID()) { - t.Fatalf("Expected ID: 0x%04X, got 0x%04X", s.ID(), c.ID) + cipher := theirs[i] + if cipher.ID != uint16(s.ID()) { + t.Fatalf("Expected ID: 0x%04X, got 0x%04X", s.ID(), cipher.ID) } - if c.Name != s.String() { - t.Fatalf("Expected Name: %s, got %s", s.String(), c.Name) + if cipher.Name != s.String() { + t.Fatalf("Expected Name: %s, got %s", s.String(), cipher.Name) } - if len(c.SupportedVersions) != 1 { - t.Fatalf("Expected %d SupportedVersion, got %d", 1, len(c.SupportedVersions)) + if len(cipher.SupportedVersions) != 1 { + t.Fatalf("Expected %d SupportedVersion, got %d", 1, len(cipher.SupportedVersions)) } - if c.SupportedVersions[0] != VersionDTLS12 { - t.Fatalf("Expected SupportedVersions 0x%04X, got 0x%04X", VersionDTLS12, c.SupportedVersions[0]) + if cipher.SupportedVersions[0] != VersionDTLS12 { + t.Fatalf("Expected SupportedVersions 0x%04X, got 0x%04X", VersionDTLS12, cipher.SupportedVersions[0]) } - if c.Insecure { - t.Fatalf("Expected Insecure %t, got %t", false, c.Insecure) + if cipher.Insecure { + t.Fatalf("Expected Insecure %t, got %t", false, cipher.Insecure) } }) } diff --git a/cipher_suite_test.go b/cipher_suite_test.go index 655fe6717..c4fd4840a 100644 --- a/cipher_suite_test.go +++ b/cipher_suite_test.go @@ -8,9 +8,10 @@ import ( "testing" "time" - "github.com/pion/dtls/v2/internal/ciphersuite" - "github.com/pion/transport/v2/dpipe" - "github.com/pion/transport/v2/test" + "github.com/pion/dtls/v3/internal/ciphersuite" + dtlsnet "github.com/pion/dtls/v3/pkg/net" + "github.com/pion/transport/v3/dpipe" + "github.com/pion/transport/v3/test" ) func TestCipherSuiteName(t *testing.T) { @@ -37,7 +38,7 @@ func TestAllCipherSuites(t *testing.T) { } } -// CustomCipher that is just used to assert Custom IDs work +// CustomCipher that is just used to assert Custom IDs work. type testCustomCipherSuite struct { ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256 authenticationType CipherSuiteAuthenticationType @@ -51,7 +52,7 @@ func (t *testCustomCipherSuite) AuthenticationType() CipherSuiteAuthenticationTy return t.authenticationType } -// Assert that two connections that pass in a CipherSuite with a CustomID works +// Assert that two connections that pass in a CipherSuite with a CustomID works. func TestCustomCipherSuite(t *testing.T) { type result struct { c *Conn @@ -67,22 +68,22 @@ func TestCustomCipherSuite(t *testing.T) { defer cancel() ca, cb := dpipe.Pipe() - c := make(chan result) + resultCh := make(chan result) go func() { - client, err := testClient(ctx, ca, &Config{ + client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ CipherSuites: []CipherSuiteID{}, CustomCipherSuites: cipherFactory, }, true) - c <- result{client, err} + resultCh <- result{client, err} }() - server, err := testServer(ctx, cb, &Config{ + server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ CipherSuites: []CipherSuiteID{}, CustomCipherSuites: cipherFactory, }, true) - clientResult := <-c + clientResult := <-resultCh if err != nil { t.Error(err) @@ -97,13 +98,13 @@ func TestCustomCipherSuite(t *testing.T) { } } - t.Run("Custom ID", func(t *testing.T) { + t.Run("Custom ID", func(*testing.T) { runTest(func() []CipherSuite { return []CipherSuite{&testCustomCipherSuite{authenticationType: CipherSuiteAuthenticationTypeCertificate}} }) }) - t.Run("Anonymous Cipher", func(t *testing.T) { + t.Run("Anonymous Cipher", func(*testing.T) { runTest(func() []CipherSuite { return []CipherSuite{&testCustomCipherSuite{authenticationType: CipherSuiteAuthenticationTypeAnonymous}} }) diff --git a/compression_method.go b/compression_method.go index 7e44de009..49b11c718 100644 --- a/compression_method.go +++ b/compression_method.go @@ -3,7 +3,7 @@ package dtls -import "github.com/pion/dtls/v2/pkg/protocol" +import "github.com/pion/dtls/v3/pkg/protocol" func defaultCompressionMethods() []*protocol.CompressionMethod { return []*protocol.CompressionMethod{ diff --git a/config.go b/config.go index fbc3ee247..722335139 100644 --- a/config.go +++ b/config.go @@ -4,16 +4,18 @@ package dtls import ( - "context" + "crypto" "crypto/ecdsa" "crypto/ed25519" "crypto/rsa" "crypto/tls" "crypto/x509" "io" + "net" "time" - "github.com/pion/dtls/v2/pkg/crypto/elliptic" + "github.com/pion/dtls/v3/pkg/crypto/elliptic" + "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/logging" ) @@ -44,6 +46,10 @@ type Config struct { // Servers will assert that clients send one of these profiles and will respond as needed SRTPProtectionProfiles []SRTPProtectionProfile + // SRTPMasterKeyIdentifier value (if any) is sent via the use_srtp + // extension for Clients and Servers + SRTPMasterKeyIdentifier []byte + // ClientAuth determines the server's policy for // TLS Client Authentication. The default is NoClientCert. ClientAuth ClientAuthType @@ -56,6 +62,10 @@ type Config struct { // defaults to time.Second FlightInterval time.Duration + // DisableRetransmitBackoff can be used to the disable the backoff feature + // when sending outbound messages as specified in RFC 4347 4.2.4.1 + DisableRetransmitBackoff bool + // PSK sets the pre-shared key used by this DTLS connection // If PSK is non-nil only PSK CipherSuites will be used PSK PSKCallback @@ -112,15 +122,6 @@ type Config struct { LoggerFactory logging.LoggerFactory - // ConnectContextMaker is a function to make a context used in Dial(), - // Client(), Server(), and Accept(). If nil, the default ConnectContextMaker - // is used. It can be implemented as following. - // - // func ConnectContextMaker() (context.Context, func()) { - // return context.WithTimeout(context.Background(), 30*time.Second) - // } - ConnectContextMaker func() (context.Context, func()) - // MTU is the length at which handshake messages will be fragmented to // fit within the maximum transmission unit (default is 1200 bytes) MTU int @@ -176,17 +177,53 @@ type Config struct { // skip hello verify phase and receive ServerHello after initial ClientHello. // This have implication on DoS attack resistance. InsecureSkipVerifyHello bool -} -func defaultConnectContextMaker() (context.Context, func()) { - return context.WithTimeout(context.Background(), 30*time.Second) -} - -func (c *Config) connectContextMaker() (context.Context, func()) { - if c.ConnectContextMaker == nil { - return defaultConnectContextMaker() - } - return c.ConnectContextMaker() + // ConnectionIDGenerator generates connection identifiers that should be + // sent by the remote party if it supports the DTLS Connection Identifier + // extension, as determined during the handshake. Generated connection + // identifiers must always have the same length. Returning a zero-length + // connection identifier indicates that the local party supports sending + // connection identifiers but does not require the remote party to send + // them. A nil ConnectionIDGenerator indicates that connection identifiers + // are not supported. + // https://datatracker.ietf.org/doc/html/rfc9146 + ConnectionIDGenerator func() []byte + + // PaddingLengthGenerator generates the number of padding bytes used to + // inflate ciphertext size in order to obscure content size from observers. + // The length of the content is passed to the generator such that both + // deterministic and random padding schemes can be applied while not + // exceeding maximum record size. + // If no PaddingLengthGenerator is specified, padding will not be applied. + // https://datatracker.ietf.org/doc/html/rfc9146#section-4 + PaddingLengthGenerator func(uint) uint + + // HelloRandomBytesGenerator generates custom client hello random bytes. + HelloRandomBytesGenerator func() [handshake.RandomBytesLength]byte + + // Handshake hooks: hooks can be used for testing invalid messages, + // mimicking other implementations or randomizing fields, which is valuable + // for applications that need censorship-resistance by making + // fingerprinting more difficult. + + // ClientHelloMessageHook, if not nil, is called when a Client Hello message is sent + // from a client. The returned handshake message replaces the original message. + ClientHelloMessageHook func(handshake.MessageClientHello) handshake.Message + + // ServerHelloMessageHook, if not nil, is called when a Server Hello message is sent + // from a server. The returned handshake message replaces the original message. + ServerHelloMessageHook func(handshake.MessageServerHello) handshake.Message + + // CertificateRequestMessageHook, if not nil, is called when a Certificate Request + // message is sent from a server. The returned handshake message replaces the original message. + CertificateRequestMessageHook func(handshake.MessageCertificateRequest) handshake.Message + + // OnConnectionAttempt is fired Whenever a connection attempt is made, + // the server or application can call this callback function. + // The callback function can then implement logic to handle the connection attempt, such as logging the attempt, + // checking against a list of blocked IPs, or counting the attempts to prevent brute force attacks. + // If the callback function returns an error, the connection attempt will be aborted. + OnConnectionAttempt func(net.Addr) error } func (c *Config) includeCertificateSuites() bool { @@ -198,14 +235,14 @@ const defaultMTU = 1200 // bytes var defaultCurves = []elliptic.Curve{elliptic.X25519, elliptic.P256, elliptic.P384} //nolint:gochecknoglobals // PSKCallback is called once we have the remote's PSKIdentityHint. -// If the remote provided none it will be nil +// If the remote provided none it will be nil. type PSKCallback func([]byte) ([]byte, error) // ClientAuthType declares the policy the server will follow for // TLS Client Authentication. type ClientAuthType int -// ClientAuthType enums +// ClientAuthType enums. const ( NoClientCert ClientAuthType = iota RequestClientCert @@ -215,17 +252,17 @@ const ( ) // ExtendedMasterSecretType declares the policy the client and server -// will follow for the Extended Master Secret extension +// will follow for the Extended Master Secret extension. type ExtendedMasterSecretType int -// ExtendedMasterSecretType enums +// ExtendedMasterSecretType enums. const ( RequestExtendedMasterSecret ExtendedMasterSecretType = iota RequireExtendedMasterSecret DisableExtendedMasterSecret ) -func validateConfig(config *Config) error { +func validateConfig(config *Config) error { //nolint:cyclop switch { case config == nil: return errNoConfigProvided @@ -238,16 +275,23 @@ func validateConfig(config *Config) error { return errInvalidCertificate } if cert.PrivateKey != nil { - switch cert.PrivateKey.(type) { - case ed25519.PrivateKey: - case *ecdsa.PrivateKey: - case *rsa.PrivateKey: + signer, ok := cert.PrivateKey.(crypto.Signer) + if !ok { + return errInvalidPrivateKey + } + switch signer.Public().(type) { + case ed25519.PublicKey: + case *ecdsa.PublicKey: + case *rsa.PublicKey: default: return errInvalidPrivateKey } } } - _, err := parseCipherSuites(config.CipherSuites, config.CustomCipherSuites, config.includeCertificateSuites(), config.PSK != nil) + _, err := parseCipherSuites( + config.CipherSuites, config.CustomCipherSuites, config.includeCertificateSuites(), config.PSK != nil, + ) + return err } diff --git a/config_test.go b/config_test.go index 811427a0c..b01de1442 100644 --- a/config_test.go +++ b/config_test.go @@ -11,29 +11,33 @@ import ( "errors" "testing" - "github.com/pion/dtls/v2/pkg/crypto/selfsign" + "github.com/pion/dtls/v3/pkg/crypto/selfsign" ) -func TestValidateConfig(t *testing.T) { +func TestValidateConfig(t *testing.T) { //nolint:cyclop cert, err := selfsign.GenerateSelfSigned() if err != nil { t.Fatalf("TestValidateConfig: Config validation error(%v), self signed certificate not generated", err) + return } dsaPrivateKey := &dsa.PrivateKey{} err = dsa.GenerateParameters(&dsaPrivateKey.Parameters, rand.Reader, dsa.L1024N160) if err != nil { t.Fatalf("TestValidateConfig: Config validation error(%v), DSA parameters not generated", err) + return } err = dsa.GenerateKey(dsaPrivateKey, rand.Reader) if err != nil { t.Fatalf("TestValidateConfig: Config validation error(%v), DSA private key not generated", err) + return } rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { t.Fatalf("TestValidateConfig: Config validation error(%v), RSA private key not generated", err) + return } cases := map[string]struct { @@ -47,7 +51,7 @@ func TestValidateConfig(t *testing.T) { "PSK and Certificate, valid cipher suites": { config: &Config{ CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8, TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, - PSK: func(hint []byte) ([]byte, error) { + PSK: func([]byte) ([]byte, error) { return nil, nil }, Certificates: []tls.Certificate{cert}, @@ -56,7 +60,7 @@ func TestValidateConfig(t *testing.T) { "PSK and Certificate, no PSK cipher suite": { config: &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, - PSK: func(hint []byte) ([]byte, error) { + PSK: func([]byte) ([]byte, error) { return nil, nil }, Certificates: []tls.Certificate{cert}, @@ -66,7 +70,7 @@ func TestValidateConfig(t *testing.T) { "PSK and Certificate, no non-PSK cipher suite": { config: &Config{ CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8}, - PSK: func(hint []byte) ([]byte, error) { + PSK: func([]byte) ([]byte, error) { return nil, nil }, Certificates: []tls.Certificate{cert}, @@ -108,7 +112,7 @@ func TestValidateConfig(t *testing.T) { "Valid config with get certificate": { config: &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, - GetCertificate: func(chi *ClientHelloInfo) (*tls.Certificate, error) { + GetCertificate: func(*ClientHelloInfo) (*tls.Certificate, error) { return &tls.Certificate{Certificate: cert.Certificate, PrivateKey: rsaPrivateKey}, nil }, }, @@ -116,7 +120,7 @@ func TestValidateConfig(t *testing.T) { "Valid config with get client certificate": { config: &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, - GetClientCertificate: func(cri *CertificateRequestInfo) (*tls.Certificate, error) { + GetClientCertificate: func(*CertificateRequestInfo) (*tls.Certificate, error) { return &tls.Certificate{Certificate: cert.Certificate, PrivateKey: rsaPrivateKey}, nil }, }, diff --git a/conn.go b/conn.go index 2b7585108..c7cd9a3cb 100644 --- a/conn.go +++ b/conn.go @@ -4,6 +4,7 @@ package dtls import ( + "bytes" "context" "errors" "fmt" @@ -13,17 +14,17 @@ import ( "sync/atomic" "time" - "github.com/pion/dtls/v2/internal/closer" - "github.com/pion/dtls/v2/pkg/crypto/elliptic" - "github.com/pion/dtls/v2/pkg/crypto/signaturehash" - "github.com/pion/dtls/v2/pkg/protocol" - "github.com/pion/dtls/v2/pkg/protocol/alert" - "github.com/pion/dtls/v2/pkg/protocol/handshake" - "github.com/pion/dtls/v2/pkg/protocol/recordlayer" + "github.com/pion/dtls/v3/internal/closer" + "github.com/pion/dtls/v3/pkg/crypto/elliptic" + "github.com/pion/dtls/v3/pkg/crypto/signaturehash" + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/alert" + "github.com/pion/dtls/v3/pkg/protocol/handshake" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" "github.com/pion/logging" - "github.com/pion/transport/v2/connctx" - "github.com/pion/transport/v2/deadline" - "github.com/pion/transport/v2/replaydetector" + "github.com/pion/transport/v3/deadline" + "github.com/pion/transport/v3/netctx" + "github.com/pion/transport/v3/replaydetector" ) const ( @@ -32,8 +33,11 @@ const ( sessionLength = 32 defaultNamedCurve = elliptic.X25519 inboundBufferSize = 8192 - // Default replay protection window is specified by RFC 6347 Section 4.1.2.6 + // Default replay protection window is specified by RFC 6347 Section 4.1.2.6. defaultReplayProtectionWindow = 64 + // maxAppDataPacketQueueSize is the maximum number of app data packets we will. + // enqueue before the handshake is completed. + maxAppDataPacketQueueSize = 100 ) func invalidKeyingLabels() map[string]bool { @@ -45,26 +49,38 @@ func invalidKeyingLabels() map[string]bool { } } -// Conn represents a DTLS connection -type Conn struct { - lock sync.RWMutex // Internal lock (must not be public) - nextConn connctx.ConnCtx // Embedded Conn, typically a udpconn we read/write from - fragmentBuffer *fragmentBuffer // out-of-order and missing fragment handling - handshakeCache *handshakeCache // caching of handshake messages for verifyData generation - decrypted chan interface{} // Decrypted Application Data or error, pull by calling `Read` +type addrPkt struct { + rAddr net.Addr + data []byte +} - state State // Internal state +type recvHandshakeState struct { + done chan struct{} + isRetransmit bool +} + +// Conn represents a DTLS connection. +type Conn struct { + lock sync.RWMutex // Internal lock (must not be public) + nextConn netctx.PacketConn // Embedded Conn, typically a udpconn we read/write from + fragmentBuffer *fragmentBuffer // out-of-order and missing fragment handling + handshakeCache *handshakeCache // caching of handshake messages for verifyData generation + decrypted chan interface{} // Decrypted Application Data or error, pull by calling `Read` + rAddr net.Addr + state State // Internal state maximumTransmissionUnit int + paddingLengthGenerator func(uint) uint handshakeCompletedSuccessfully atomic.Value + handshakeMutex sync.Mutex + handshakeDone chan struct{} - encryptedPackets [][]byte + encryptedPackets []addrPkt connectionClosedByUser bool closeLock sync.Mutex closed *closer.Closer - handshakeLoopsFinished sync.WaitGroup readDeadline *deadline.Deadline writeDeadline *deadline.Deadline @@ -72,18 +88,26 @@ type Conn struct { log logging.LeveledLogger reading chan struct{} - handshakeRecv chan chan struct{} + handshakeRecv chan recvHandshakeState cancelHandshaker func() cancelHandshakeReader func() fsm *handshakeFSM replayProtectionWindow uint + + handshakeConfig *handshakeConfig } -func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient bool, initialState *State) (*Conn, error) { - err := validateConfig(config) - if err != nil { +//nolint:cyclop +func createConn( + nextConn net.PacketConn, + rAddr net.Addr, + config *Config, + isClient bool, + resumeState *State, +) (*Conn, error) { + if err := validateConfig(config); err != nil { return nil, err } @@ -91,7 +115,34 @@ func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient return nil, errNilNextConn } - cipherSuites, err := parseCipherSuites(config.CipherSuites, config.CustomCipherSuites, config.includeCertificateSuites(), config.PSK != nil) + loggerFactory := config.LoggerFactory + if loggerFactory == nil { + loggerFactory = logging.NewDefaultLoggerFactory() + } + + logger := loggerFactory.NewLogger("dtls") + + mtu := config.MTU + if mtu <= 0 { + mtu = defaultMTU + } + + replayProtectionWindow := config.ReplayProtectionWindow + if replayProtectionWindow <= 0 { + replayProtectionWindow = defaultReplayProtectionWindow + } + + paddingLengthGenerator := config.PaddingLengthGenerator + if paddingLengthGenerator == nil { + paddingLengthGenerator = func(uint) uint { return 0 } + } + + cipherSuites, err := parseCipherSuites( + config.CipherSuites, + config.CustomCipherSuites, + config.includeCertificateSuites(), + config.PSK != nil, + ) if err != nil { return nil, err } @@ -106,28 +157,62 @@ func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient workerInterval = config.FlightInterval } - loggerFactory := config.LoggerFactory - if loggerFactory == nil { - loggerFactory = logging.NewDefaultLoggerFactory() - } - - logger := loggerFactory.NewLogger("dtls") - - mtu := config.MTU - if mtu <= 0 { - mtu = defaultMTU + serverName := config.ServerName + // Do not allow the use of an IP address literal as an SNI value. + // See RFC 6066, Section 3. + if net.ParseIP(serverName) != nil { + serverName = "" } - replayProtectionWindow := config.ReplayProtectionWindow - if replayProtectionWindow <= 0 { - replayProtectionWindow = defaultReplayProtectionWindow + curves := config.EllipticCurves + if len(curves) == 0 { + curves = defaultCurves } - c := &Conn{ - nextConn: connctx.New(nextConn), + handshakeConfig := &handshakeConfig{ + localPSKCallback: config.PSK, + localPSKIdentityHint: config.PSKIdentityHint, + localCipherSuites: cipherSuites, + localSignatureSchemes: signatureSchemes, + extendedMasterSecret: config.ExtendedMasterSecret, + localSRTPProtectionProfiles: config.SRTPProtectionProfiles, + localSRTPMasterKeyIdentifier: config.SRTPMasterKeyIdentifier, + serverName: serverName, + supportedProtocols: config.SupportedProtocols, + clientAuth: config.ClientAuth, + localCertificates: config.Certificates, + insecureSkipVerify: config.InsecureSkipVerify, + verifyPeerCertificate: config.VerifyPeerCertificate, + verifyConnection: config.VerifyConnection, + rootCAs: config.RootCAs, + clientCAs: config.ClientCAs, + customCipherSuites: config.CustomCipherSuites, + initialRetransmitInterval: workerInterval, + disableRetransmitBackoff: config.DisableRetransmitBackoff, + log: logger, + initialEpoch: 0, + keyLogWriter: config.KeyLogWriter, + sessionStore: config.SessionStore, + ellipticCurves: curves, + localGetCertificate: config.GetCertificate, + localGetClientCertificate: config.GetClientCertificate, + insecureSkipHelloVerify: config.InsecureSkipVerifyHello, + connectionIDGenerator: config.ConnectionIDGenerator, + helloRandomBytesGenerator: config.HelloRandomBytesGenerator, + clientHelloMessageHook: config.ClientHelloMessageHook, + serverHelloMessageHook: config.ServerHelloMessageHook, + certificateRequestMessageHook: config.CertificateRequestMessageHook, + resumeState: resumeState, + } + + conn := &Conn{ + rAddr: rAddr, + nextConn: netctx.NewPacketConn(nextConn), + handshakeConfig: handshakeConfig, fragmentBuffer: newFragmentBuffer(), handshakeCache: newHandshakeCache(), maximumTransmissionUnit: mtu, + paddingLengthGenerator: paddingLengthGenerator, decrypted: make(chan interface{}, 1), log: logger, @@ -135,76 +220,76 @@ func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient readDeadline: deadline.New(), writeDeadline: deadline.New(), - reading: make(chan struct{}, 1), - handshakeRecv: make(chan chan struct{}), - closed: closer.NewCloser(), - cancelHandshaker: func() {}, + reading: make(chan struct{}, 1), + handshakeRecv: make(chan recvHandshakeState), + closed: closer.NewCloser(), + cancelHandshaker: func() {}, + cancelHandshakeReader: func() {}, - replayProtectionWindow: uint(replayProtectionWindow), + replayProtectionWindow: uint(replayProtectionWindow), //nolint:gosec // G115 state: State{ isClient: isClient, }, } - c.setRemoteEpoch(0) - c.setLocalEpoch(0) + conn.setRemoteEpoch(0) + conn.setLocalEpoch(0) - serverName := config.ServerName - // Do not allow the use of an IP address literal as an SNI value. - // See RFC 6066, Section 3. - if net.ParseIP(serverName) != nil { - serverName = "" - } + return conn, nil +} - curves := config.EllipticCurves - if len(curves) == 0 { - curves = defaultCurves - } +// Handshake runs the client or server DTLS handshake +// protocol if it has not yet been run. +// +// Most uses of this package need not call Handshake explicitly: the +// first [Conn.Read] or [Conn.Write] will call it automatically. +// +// For control over canceling or setting a timeout on a handshake, use +// [Conn.HandshakeContext]. +func (c *Conn) Handshake() error { + return c.HandshakeContext(context.Background()) +} - hsCfg := &handshakeConfig{ - localPSKCallback: config.PSK, - localPSKIdentityHint: config.PSKIdentityHint, - localCipherSuites: cipherSuites, - localSignatureSchemes: signatureSchemes, - extendedMasterSecret: config.ExtendedMasterSecret, - localSRTPProtectionProfiles: config.SRTPProtectionProfiles, - serverName: serverName, - supportedProtocols: config.SupportedProtocols, - clientAuth: config.ClientAuth, - localCertificates: config.Certificates, - insecureSkipVerify: config.InsecureSkipVerify, - verifyPeerCertificate: config.VerifyPeerCertificate, - verifyConnection: config.VerifyConnection, - rootCAs: config.RootCAs, - clientCAs: config.ClientCAs, - customCipherSuites: config.CustomCipherSuites, - retransmitInterval: workerInterval, - log: logger, - initialEpoch: 0, - keyLogWriter: config.KeyLogWriter, - sessionStore: config.SessionStore, - ellipticCurves: curves, - localGetCertificate: config.GetCertificate, - localGetClientCertificate: config.GetClientCertificate, - insecureSkipHelloVerify: config.InsecureSkipVerifyHello, +// HandshakeContext runs the client or server DTLS handshake +// protocol if it has not yet been run. +// +// The provided Context must be non-nil. If the context is canceled before +// the handshake is complete, the handshake is interrupted and an error is returned. +// Once the handshake has completed, cancellation of the context will not affect the +// connection. +// +// Most uses of this package need not call HandshakeContext explicitly: the +// first [Conn.Read] or [Conn.Write] will call it automatically. +func (c *Conn) HandshakeContext(ctx context.Context) error { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + + if c.isHandshakeCompletedSuccessfully() { + return nil } + handshakeDone := make(chan struct{}) + defer close(handshakeDone) + c.closeLock.Lock() + c.handshakeDone = handshakeDone + c.closeLock.Unlock() + // rfc5246#section-7.4.3 // In addition, the hash and signature algorithms MUST be compatible // with the key in the server's end-entity certificate. - if !isClient { - cert, err := hsCfg.getCertificate(&ClientHelloInfo{}) + if !c.state.isClient { + cert, err := c.handshakeConfig.getCertificate(&ClientHelloInfo{}) if err != nil && !errors.Is(err, errNoCertificates) { - return nil, err + return err } - hsCfg.localCipherSuites = filterCipherSuitesForCertificate(cert, cipherSuites) + c.handshakeConfig.localCipherSuites = filterCipherSuitesForCertificate(cert, c.handshakeConfig.localCipherSuites) } var initialFlight flightVal var initialFSMState handshakeState - if initialState != nil { + if c.handshakeConfig.resumeState != nil { //nolint:nestif if c.state.isClient { initialFlight = flight5 } else { @@ -212,7 +297,7 @@ func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient } initialFSMState = handshakeFinished - c.state = *initialState + c.state = *c.handshakeConfig.resumeState } else { if c.state.isClient { initialFlight = flight1 @@ -222,56 +307,30 @@ func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient initialFSMState = handshakePreparing } // Do handshake - if err := c.handshake(ctx, hsCfg, initialFlight, initialFSMState); err != nil { - return nil, err + if err := c.handshake(ctx, c.handshakeConfig, initialFlight, initialFSMState); err != nil { + return err } c.log.Trace("Handshake Completed") - return c, nil + return nil } // Dial connects to the given network address and establishes a DTLS connection on top. -// Connection handshake will timeout using ConnectContextMaker in the Config. -// If you want to specify the timeout duration, use DialWithContext() instead. -func Dial(network string, raddr *net.UDPAddr, config *Config) (*Conn, error) { - ctx, cancel := config.connectContextMaker() - defer cancel() - - return DialWithContext(ctx, network, raddr, config) -} - -// Client establishes a DTLS connection over an existing connection. -// Connection handshake will timeout using ConnectContextMaker in the Config. -// If you want to specify the timeout duration, use ClientWithContext() instead. -func Client(conn net.Conn, config *Config) (*Conn, error) { - ctx, cancel := config.connectContextMaker() - defer cancel() - - return ClientWithContext(ctx, conn, config) -} - -// Server listens for incoming DTLS connections. -// Connection handshake will timeout using ConnectContextMaker in the Config. -// If you want to specify the timeout duration, use ServerWithContext() instead. -func Server(conn net.Conn, config *Config) (*Conn, error) { - ctx, cancel := config.connectContextMaker() - defer cancel() - - return ServerWithContext(ctx, conn, config) -} - -// DialWithContext connects to the given network address and establishes a DTLS connection on top. -func DialWithContext(ctx context.Context, network string, raddr *net.UDPAddr, config *Config) (*Conn, error) { - pConn, err := net.DialUDP(network, nil, raddr) +func Dial(network string, rAddr *net.UDPAddr, config *Config) (*Conn, error) { + // net.ListenUDP is used rather than net.DialUDP as the latter prevents the + // use of net.PacketConn.WriteTo. + // https://github.com/golang/go/blob/ce5e37ec21442c6eb13a43e68ca20129102ebac0/src/net/udpsock_posix.go#L115 + pConn, err := net.ListenUDP(network, nil) if err != nil { return nil, err } - return ClientWithContext(ctx, pConn, config) + + return Client(pConn, rAddr, config) } -// ClientWithContext establishes a DTLS connection over an existing connection. -func ClientWithContext(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) { +// Client establishes a DTLS connection over an existing connection. +func Client(conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) { switch { case config == nil: return nil, errNoConfigProvided @@ -279,22 +338,27 @@ func ClientWithContext(ctx context.Context, conn net.Conn, config *Config) (*Con return nil, errPSKAndIdentityMustBeSetForClient } - return createConn(ctx, conn, config, true, nil) + return createConn(conn, rAddr, config, true, nil) } -// ServerWithContext listens for incoming DTLS connections. -func ServerWithContext(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) { +// Server listens for incoming DTLS connections. +func Server(conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) { if config == nil { return nil, errNoConfigProvided } + if config.OnConnectionAttempt != nil { + if err := config.OnConnectionAttempt(rAddr); err != nil { + return nil, err + } + } - return createConn(ctx, conn, config, false, nil) + return createConn(conn, rAddr, config, false, nil) } // Read reads data from the connection. -func (c *Conn) Read(p []byte) (n int, err error) { - if !c.isHandshakeCompletedSuccessfully() { - return 0, errHandshakeInProgress +func (c *Conn) Read(buff []byte) (n int, err error) { //nolint:cyclop + if err := c.Handshake(); err != nil { + return 0, err } select { @@ -313,10 +377,11 @@ func (c *Conn) Read(p []byte) (n int, err error) { } switch val := out.(type) { case ([]byte): - if len(p) < len(val) { + if len(buff) < len(val) { return 0, errBufferTooSmall } - copy(p, val) + copy(buff, val) + return len(val), nil case (error): return 0, val @@ -325,8 +390,8 @@ func (c *Conn) Read(p []byte) (n int, err error) { } } -// Write writes len(p) bytes from p to the DTLS connection -func (c *Conn) Write(p []byte) (int, error) { +// Write writes len(payload) bytes from payload to the DTLS connection. +func (c *Conn) Write(payload []byte) (int, error) { if c.isConnectionClosed() { return 0, ErrConnClosed } @@ -337,11 +402,11 @@ func (c *Conn) Write(p []byte) (int, error) { default: } - if !c.isHandshakeCompletedSuccessfully() { - return 0, errHandshakeInProgress + if err := c.Handshake(); err != nil { + return 0, err } - return len(p), c.writePackets(c.writeDeadline, []*packet{ + return len(payload), c.writePackets(c.writeDeadline, []*packet{ { record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ @@ -349,9 +414,10 @@ func (c *Conn) Write(p []byte) (int, error) { Version: protocol.Version1_2, }, Content: &protocol.ApplicationData{ - Data: p, + Data: payload, }, }, + shouldWrapCID: len(c.state.remoteConnectionID) > 0, shouldEncrypt: true, }, }) @@ -360,28 +426,46 @@ func (c *Conn) Write(p []byte) (int, error) { // Close closes the connection. func (c *Conn) Close() error { err := c.close(true) //nolint:contextcheck - c.handshakeLoopsFinished.Wait() + c.closeLock.Lock() + handshakeDone := c.handshakeDone + c.closeLock.Unlock() + if handshakeDone != nil { + <-handshakeDone + } + return err } // ConnectionState returns basic DTLS details about the connection. // Note that this replaced the `Export` function of v1. -func (c *Conn) ConnectionState() State { +func (c *Conn) ConnectionState() (State, bool) { c.lock.RLock() defer c.lock.RUnlock() - return *c.state.clone() + stateClone, err := c.state.clone() + if err != nil { + return State{}, false + } + + return *stateClone, true } -// SelectedSRTPProtectionProfile returns the selected SRTPProtectionProfile +// SelectedSRTPProtectionProfile returns the selected SRTPProtectionProfile. func (c *Conn) SelectedSRTPProtectionProfile() (SRTPProtectionProfile, bool) { - c.lock.RLock() - defer c.lock.RUnlock() - - if c.state.srtpProtectionProfile == 0 { + profile := c.state.getSRTPProtectionProfile() + if profile == 0 { return 0, false } - return c.state.srtpProtectionProfile, true + return profile, true +} + +// RemoteSRTPMasterKeyIdentifier returns the MasterKeyIdentifier value from the use_srtp. +func (c *Conn) RemoteSRTPMasterKeyIdentifier() ([]byte, bool) { + if profile := c.state.getSRTPProtectionProfile(); profile == 0 { + return nil, false + } + + return c.state.remoteSRTPMasterKeyIdentifier, true } func (c *Conn) writePackets(ctx context.Context, pkts []*packet) error { @@ -390,25 +474,32 @@ func (c *Conn) writePackets(ctx context.Context, pkts []*packet) error { var rawPackets [][]byte - for _, p := range pkts { - if h, ok := p.record.Content.(*handshake.Handshake); ok { - handshakeRaw, err := p.record.Marshal() + for _, pkt := range pkts { + if dtlsHandshake, ok := pkt.record.Content.(*handshake.Handshake); ok { + handshakeRaw, err := pkt.record.Marshal() if err != nil { return err } c.log.Tracef("[handshake:%v] -> %s (epoch: %d, seq: %d)", - srvCliStr(c.state.isClient), h.Header.Type.String(), - p.record.Header.Epoch, h.Header.MessageSequence) - c.handshakeCache.push(handshakeRaw[recordlayer.HeaderSize:], p.record.Header.Epoch, h.Header.MessageSequence, h.Header.Type, c.state.isClient) + srvCliStr(c.state.isClient), dtlsHandshake.Header.Type.String(), + pkt.record.Header.Epoch, dtlsHandshake.Header.MessageSequence) + + c.handshakeCache.push( + handshakeRaw[recordlayer.FixedHeaderSize:], + pkt.record.Header.Epoch, + dtlsHandshake.Header.MessageSequence, + dtlsHandshake.Header.Type, + c.state.isClient, + ) - rawHandshakePackets, err := c.processHandshakePacket(p, h) + rawHandshakePackets, err := c.processHandshakePacket(pkt, dtlsHandshake) if err != nil { return err } rawPackets = append(rawPackets, rawHandshakePackets...) } else { - rawPacket, err := c.processPacket(p) + rawPacket, err := c.processPacket(pkt) if err != nil { return err } @@ -421,7 +512,7 @@ func (c *Conn) writePackets(ctx context.Context, pkts []*packet) error { compactedRawPackets := c.compactRawPackets(rawPackets) for _, compactedRawPackets := range compactedRawPackets { - if _, err := c.nextConn.WriteContext(ctx, compactedRawPackets); err != nil { + if _, err := c.nextConn.WriteToContext(ctx, compactedRawPackets, c.rAddr); err != nil { return netError(err) } } @@ -451,8 +542,8 @@ func (c *Conn) compactRawPackets(rawPackets [][]byte) [][]byte { return combinedRawPackets } -func (c *Conn) processPacket(p *packet) ([]byte, error) { - epoch := p.record.Header.Epoch +func (c *Conn) processPacket(pkt *packet) ([]byte, error) { //nolint:cyclop + epoch := pkt.record.Header.Epoch for len(c.state.localSequenceNumber) <= int(epoch) { c.state.localSequenceNumber = append(c.state.localSequenceNumber, uint64(0)) } @@ -463,16 +554,51 @@ func (c *Conn) processPacket(p *packet) ([]byte, error) { // prior to allowing the sequence number to wrap. return nil, errSequenceNumberOverflow } - p.record.Header.SequenceNumber = seq + pkt.record.Header.SequenceNumber = seq - rawPacket, err := p.record.Marshal() - if err != nil { - return nil, err + var rawPacket []byte + if pkt.shouldWrapCID { //nolint:nestif + // Record must be marshaled to populate fields used in inner plaintext. + if _, err := pkt.record.Marshal(); err != nil { + return nil, err + } + content, err := pkt.record.Content.Marshal() + if err != nil { + return nil, err + } + inner := &recordlayer.InnerPlaintext{ + Content: content, + RealType: pkt.record.Header.ContentType, + } + rawInner, err := inner.Marshal() //nolint:govet + if err != nil { + return nil, err + } + cidHeader := &recordlayer.Header{ + Version: pkt.record.Header.Version, + ContentType: protocol.ContentTypeConnectionID, + Epoch: pkt.record.Header.Epoch, + ContentLen: uint16(len(rawInner)), //nolint:gosec //G115 + ConnectionID: c.state.remoteConnectionID, + SequenceNumber: pkt.record.Header.SequenceNumber, + } + rawPacket, err = cidHeader.Marshal() + if err != nil { + return nil, err + } + pkt.record.Header = *cidHeader + rawPacket = append(rawPacket, rawInner...) + } else { + var err error + rawPacket, err = pkt.record.Marshal() + if err != nil { + return nil, err + } } - if p.shouldEncrypt { + if pkt.shouldEncrypt { var err error - rawPacket, err = c.state.cipherSuite.Encrypt(p.record, rawPacket) + rawPacket, err = c.state.cipherSuite.Encrypt(pkt.record, rawPacket) if err != nil { return nil, err } @@ -481,14 +607,15 @@ func (c *Conn) processPacket(p *packet) ([]byte, error) { return rawPacket, nil } -func (c *Conn) processHandshakePacket(p *packet, h *handshake.Handshake) ([][]byte, error) { +//nolint:cyclop +func (c *Conn) processHandshakePacket(pkt *packet, dtlsHandshake *handshake.Handshake) ([][]byte, error) { rawPackets := make([][]byte, 0) - handshakeFragments, err := c.fragmentHandshake(h) + handshakeFragments, err := c.fragmentHandshake(dtlsHandshake) if err != nil { return nil, err } - epoch := p.record.Header.Epoch + epoch := pkt.record.Header.Epoch for len(c.state.localSequenceNumber) <= int(epoch) { c.state.localSequenceNumber = append(c.state.localSequenceNumber, uint64(0)) } @@ -499,25 +626,52 @@ func (c *Conn) processHandshakePacket(p *packet, h *handshake.Handshake) ([][]by return nil, errSequenceNumberOverflow } - recordlayerHeader := &recordlayer.Header{ - Version: p.record.Header.Version, - ContentType: p.record.Header.ContentType, - ContentLen: uint16(len(handshakeFragment)), - Epoch: p.record.Header.Epoch, - SequenceNumber: seq, - } + var rawPacket []byte + if pkt.shouldWrapCID { + inner := &recordlayer.InnerPlaintext{ + Content: handshakeFragment, + RealType: protocol.ContentTypeHandshake, + Zeros: c.paddingLengthGenerator(uint(len(handshakeFragment))), + } + rawInner, err := inner.Marshal() //nolint:govet + if err != nil { + return nil, err + } + cidHeader := &recordlayer.Header{ + Version: pkt.record.Header.Version, + ContentType: protocol.ContentTypeConnectionID, + Epoch: pkt.record.Header.Epoch, + ContentLen: uint16(len(rawInner)), //nolint:gosec //G115 + ConnectionID: c.state.remoteConnectionID, + SequenceNumber: pkt.record.Header.SequenceNumber, + } + rawPacket, err = cidHeader.Marshal() + if err != nil { + return nil, err + } + pkt.record.Header = *cidHeader + rawPacket = append(rawPacket, rawInner...) + } else { + recordlayerHeader := &recordlayer.Header{ + Version: pkt.record.Header.Version, + ContentType: pkt.record.Header.ContentType, + ContentLen: uint16(len(handshakeFragment)), //nolint:gosec // G115 + Epoch: pkt.record.Header.Epoch, + SequenceNumber: seq, + } - rawPacket, err := recordlayerHeader.Marshal() - if err != nil { - return nil, err - } + rawPacket, err = recordlayerHeader.Marshal() + if err != nil { + return nil, err + } - p.record.Header = *recordlayerHeader + pkt.record.Header = *recordlayerHeader + rawPacket = append(rawPacket, handshakeFragment...) + } - rawPacket = append(rawPacket, handshakeFragment...) - if p.shouldEncrypt { + if pkt.shouldEncrypt { var err error - rawPacket, err = c.state.cipherSuite.Encrypt(p.record, rawPacket) + rawPacket, err = c.state.cipherSuite.Encrypt(pkt.record, rawPacket) if err != nil { return nil, err } @@ -529,8 +683,8 @@ func (c *Conn) processHandshakePacket(p *packet, h *handshake.Handshake) ([][]by return rawPackets, nil } -func (c *Conn) fragmentHandshake(h *handshake.Handshake) ([][]byte, error) { - content, err := h.Message.Marshal() +func (c *Conn) fragmentHandshake(dtlsHandshake *handshake.Handshake) ([][]byte, error) { + content, err := dtlsHandshake.Message.Marshal() if err != nil { return nil, err } @@ -549,11 +703,11 @@ func (c *Conn) fragmentHandshake(h *handshake.Handshake) ([][]byte, error) { contentFragmentLen := len(contentFragment) headerFragment := &handshake.Header{ - Type: h.Header.Type, - Length: h.Header.Length, - MessageSequence: h.Header.MessageSequence, - FragmentOffset: uint32(offset), - FragmentLength: uint32(contentFragmentLen), + Type: dtlsHandshake.Header.Type, + Length: dtlsHandshake.Header.Length, + MessageSequence: dtlsHandshake.Header.MessageSequence, + FragmentOffset: uint32(offset), //nolint:gosec // G115 + FragmentLength: uint32(contentFragmentLen), //nolint:gosec // G115 } offset += contentFragmentLen @@ -573,11 +727,12 @@ func (c *Conn) fragmentHandshake(h *handshake.Handshake) ([][]byte, error) { var poolReadBuffer = sync.Pool{ //nolint:gochecknoglobals New: func() interface{} { b := make([]byte, inboundBufferSize) + return &b }, } -func (c *Conn) readAndBuffer(ctx context.Context) error { +func (c *Conn) readAndBuffer(ctx context.Context) error { //nolint:cyclop bufptr, ok := poolReadBuffer.Get().(*[]byte) if !ok { return errFailedToAccessPoolReadBuffer @@ -585,19 +740,19 @@ func (c *Conn) readAndBuffer(ctx context.Context) error { defer poolReadBuffer.Put(bufptr) b := *bufptr - i, err := c.nextConn.ReadContext(ctx, b) + i, rAddr, err := c.nextConn.ReadFromContext(ctx, b) if err != nil { return netError(err) } - pkts, err := recordlayer.UnpackDatagram(b[:i]) + pkts, err := recordlayer.ContentAwareUnpackDatagram(b[:i], len(c.state.getLocalConnectionID())) if err != nil { return err } - var hasHandshake bool + var hasHandshake, isRetransmit bool for _, p := range pkts { - hs, alert, err := c.handleIncomingPacket(ctx, p, true) + hs, rtx, alert, err := c.handleIncomingPacket(ctx, p, rAddr, true) if alert != nil { if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil { if err == nil { @@ -605,29 +760,35 @@ func (c *Conn) readAndBuffer(ctx context.Context) error { } } } - if hs { - hasHandshake = true - } var e *alertError - if errors.As(err, &e) { - if e.IsFatalOrCloseNotify() { - return e - } - } else if err != nil { + if errors.As(err, &e) && e.IsFatalOrCloseNotify() { return e } + if err != nil { + return err + } + if hs { + hasHandshake = true + } + if rtx { + isRetransmit = true + } } if hasHandshake { - done := make(chan struct{}) + s := recvHandshakeState{ + done: make(chan struct{}), + isRetransmit: isRetransmit, + } select { - case c.handshakeRecv <- done: + case c.handshakeRecv <- s: // If the other party may retransmit the flight, // we should respond even if it not a new message. - <-done + <-s.done case <-c.fsm.Done(): } } + return nil } @@ -636,7 +797,7 @@ func (c *Conn) handleQueuedPackets(ctx context.Context) error { c.encryptedPackets = nil for _, p := range pkts { - _, alert, err := c.handleIncomingPacket(ctx, p, false) // don't re-enqueue + _, _, alert, err := c.handleIncomingPacket(ctx, p.data, p.rAddr, false) // don't re-enqueue if alert != nil { if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil { if err == nil { @@ -645,99 +806,179 @@ func (c *Conn) handleQueuedPackets(ctx context.Context) error { } } var e *alertError - if errors.As(err, &e) { - if e.IsFatalOrCloseNotify() { - return e - } - } else if err != nil { + if errors.As(err, &e) && e.IsFatalOrCloseNotify() { return e } + if err != nil { + return err + } } + return nil } -func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, enqueue bool) (bool, *alert.Alert, error) { //nolint:gocognit - h := &recordlayer.Header{} - if err := h.Unmarshal(buf); err != nil { +func (c *Conn) enqueueEncryptedPackets(packet addrPkt) bool { + if len(c.encryptedPackets) < maxAppDataPacketQueueSize { + c.encryptedPackets = append(c.encryptedPackets, packet) + + return true + } + + return false +} + +//nolint:gocognit,gocyclo,cyclop,maintidx +func (c *Conn) handleIncomingPacket( + ctx context.Context, + buf []byte, + rAddr net.Addr, + enqueue bool, +) (bool, bool, *alert.Alert, error) { + header := &recordlayer.Header{} + // Set connection ID size so that records of content type tls12_cid will + // be parsed correctly. + if len(c.state.getLocalConnectionID()) > 0 { + header.ConnectionID = make([]byte, len(c.state.getLocalConnectionID())) + } + if err := header.Unmarshal(buf); err != nil { // Decode error must be silently discarded // [RFC6347 Section-4.1.2.7] c.log.Debugf("discarded broken packet: %v", err) - return false, nil, nil - } + return false, false, nil, nil + } // Validate epoch remoteEpoch := c.state.getRemoteEpoch() - if h.Epoch > remoteEpoch { - if h.Epoch > remoteEpoch+1 { + if header.Epoch > remoteEpoch { + if header.Epoch > remoteEpoch+1 { c.log.Debugf("discarded future packet (epoch: %d, seq: %d)", - h.Epoch, h.SequenceNumber, + header.Epoch, header.SequenceNumber, ) - return false, nil, nil + + return false, false, nil, nil } if enqueue { - c.log.Debug("received packet of next epoch, queuing packet") - c.encryptedPackets = append(c.encryptedPackets, buf) + if ok := c.enqueueEncryptedPackets(addrPkt{rAddr, buf}); ok { + c.log.Debug("received packet of next epoch, queuing packet") + } } - return false, nil, nil + + return false, false, nil, nil } // Anti-replay protection - for len(c.state.replayDetector) <= int(h.Epoch) { + for len(c.state.replayDetector) <= int(header.Epoch) { c.state.replayDetector = append(c.state.replayDetector, replaydetector.New(c.replayProtectionWindow, recordlayer.MaxSequenceNumber), ) } - markPacketAsValid, ok := c.state.replayDetector[int(h.Epoch)].Check(h.SequenceNumber) + markPacketAsValid, ok := c.state.replayDetector[int(header.Epoch)].Check(header.SequenceNumber) if !ok { c.log.Debugf("discarded duplicated packet (epoch: %d, seq: %d)", - h.Epoch, h.SequenceNumber, + header.Epoch, header.SequenceNumber, ) - return false, nil, nil + + return false, false, nil, nil } + // originalCID indicates whether the original record had content type + // Connection ID. + originalCID := false + // Decrypt - if h.Epoch != 0 { + if header.Epoch != 0 { //nolint:nestif if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() { if enqueue { - c.encryptedPackets = append(c.encryptedPackets, buf) - c.log.Debug("handshake not finished, queuing packet") + if ok := c.enqueueEncryptedPackets(addrPkt{rAddr, buf}); ok { + c.log.Debug("handshake not finished, queuing packet") + } } - return false, nil, nil + + return false, false, nil, nil + } + + // If a connection identifier had been negotiated and encryption is + // enabled, the connection identifier MUST be sent. + if len(c.state.getLocalConnectionID()) > 0 && header.ContentType != protocol.ContentTypeConnectionID { + c.log.Debug("discarded packet missing connection ID after value negotiated") + + return false, false, nil, nil } var err error - buf, err = c.state.cipherSuite.Decrypt(buf) + var hdr recordlayer.Header + if header.ContentType == protocol.ContentTypeConnectionID { + hdr.ConnectionID = make([]byte, len(c.state.getLocalConnectionID())) + } + buf, err = c.state.cipherSuite.Decrypt(hdr, buf) if err != nil { c.log.Debugf("%s: decrypt failed: %s", srvCliStr(c.state.isClient), err) - return false, nil, nil + + return false, false, nil, nil + } + // If this is a connection ID record, make it look like a normal record for + // further processing. + if header.ContentType == protocol.ContentTypeConnectionID { + originalCID = true + ip := &recordlayer.InnerPlaintext{} + if err := ip.Unmarshal(buf[header.Size():]); err != nil { //nolint:govet + c.log.Debugf("unpacking inner plaintext failed: %s", err) + + return false, false, nil, nil + } + unpacked := &recordlayer.Header{ + ContentType: ip.RealType, + ContentLen: uint16(len(ip.Content)), //nolint:gosec // G115 + Version: header.Version, + Epoch: header.Epoch, + SequenceNumber: header.SequenceNumber, + } + buf, err = unpacked.Marshal() + if err != nil { + c.log.Debugf("converting CID record to inner plaintext failed: %s", err) + + return false, false, nil, nil + } + buf = append(buf, ip.Content...) + } + + // If connection ID does not match discard the packet. + if !bytes.Equal(c.state.getLocalConnectionID(), header.ConnectionID) { + c.log.Debug("unexpected connection ID") + + return false, false, nil, nil } } - isHandshake, err := c.fragmentBuffer.push(append([]byte{}, buf...)) + isHandshake, isRetransmit, err := c.fragmentBuffer.push(append([]byte{}, buf...)) if err != nil { // Decode error must be silently discarded // [RFC6347 Section-4.1.2.7] c.log.Debugf("defragment failed: %s", err) - return false, nil, nil + + return false, false, nil, nil } else if isHandshake { markPacketAsValid() + for out, epoch := c.fragmentBuffer.pop(); out != nil; out, epoch = c.fragmentBuffer.pop() { header := &handshake.Header{} if err := header.Unmarshal(out); err != nil { c.log.Debugf("%s: handshake parse failed: %s", srvCliStr(c.state.isClient), err) + continue } c.handshakeCache.push(out, epoch, header.MessageSequence, header.Type, !c.state.isClient) } - return true, nil, nil + return true, isRetransmit, nil, nil } r := &recordlayer.RecordLayer{} if err := r.Unmarshal(buf); err != nil { - return false, &alert.Alert{Level: alert.Fatal, Description: alert.DecodeError}, err + return false, false, &alert.Alert{Level: alert.Fatal, Description: alert.DecodeError}, err } + isLatestSeqNum := false switch content := r.Content.(type) { case *alert.Alert: c.log.Tracef("%s: <- %s", srvCliStr(c.state.isClient), content.String()) @@ -746,30 +987,35 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, enqueue boo // Respond with a close_notify [RFC5246 Section 7.2.1] a = &alert.Alert{Level: alert.Warning, Description: alert.CloseNotify} } - markPacketAsValid() - return false, a, &alertError{content} + _ = markPacketAsValid() + + return false, false, a, &alertError{content} case *protocol.ChangeCipherSpec: if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() { if enqueue { - c.encryptedPackets = append(c.encryptedPackets, buf) - c.log.Debugf("CipherSuite not initialized, queuing packet") + if ok := c.enqueueEncryptedPackets(addrPkt{rAddr, buf}); ok { + c.log.Debugf("CipherSuite not initialized, queuing packet") + } } - return false, nil, nil + + return false, false, nil, nil } - newRemoteEpoch := h.Epoch + 1 + newRemoteEpoch := header.Epoch + 1 c.log.Tracef("%s: <- ChangeCipherSpec (epoch: %d)", srvCliStr(c.state.isClient), newRemoteEpoch) if c.state.getRemoteEpoch()+1 == newRemoteEpoch { c.setRemoteEpoch(newRemoteEpoch) - markPacketAsValid() + isLatestSeqNum = markPacketAsValid() } case *protocol.ApplicationData: - if h.Epoch == 0 { - return false, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, errApplicationDataEpochZero + if header.Epoch == 0 { + return false, false, &alert.Alert{ + Level: alert.Fatal, Description: alert.UnexpectedMessage, + }, errApplicationDataEpochZero } - markPacketAsValid() + isLatestSeqNum = markPacketAsValid() select { case c.decrypted <- content.Data: @@ -778,12 +1024,26 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, enqueue boo } default: - return false, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, fmt.Errorf("%w: %d", errUnhandledContextType, content.ContentType()) + return false, false, &alert.Alert{ + Level: alert.Fatal, Description: alert.UnexpectedMessage, + }, fmt.Errorf("%w: %d", errUnhandledContextType, content.ContentType()) + } + + // Any valid connection ID record is a candidate for updating the remote + // address if it is the latest record received. + // https://datatracker.ietf.org/doc/html/rfc9146#peer-address-update + if originalCID && isLatestSeqNum { + if rAddr != c.RemoteAddr() { + c.lock.Lock() + c.rAddr = rAddr + c.lock.Unlock() + } } - return false, nil, nil + + return false, false, nil, nil } -func (c *Conn) recvHandshake() <-chan chan struct{} { +func (c *Conn) recvHandshake() <-chan recvHandshakeState { return c.handshakeRecv } @@ -798,6 +1058,7 @@ func (c *Conn) notify(ctx context.Context, level alert.Level, desc alert.Descrip } } } + return c.writePackets(ctx, []*packet{ { record: &recordlayer.RecordLayer{ @@ -810,6 +1071,7 @@ func (c *Conn) notify(ctx context.Context, level alert.Level, desc alert.Descrip Description: desc, }, }, + shouldWrapCID: len(c.state.remoteConnectionID) > 0, shouldEncrypt: c.isHandshakeCompletedSuccessfully(), }, }) @@ -821,16 +1083,22 @@ func (c *Conn) setHandshakeCompletedSuccessfully() { func (c *Conn) isHandshakeCompletedSuccessfully() bool { boolean, _ := c.handshakeCompletedSuccessfully.Load().(struct{ bool }) + return boolean.bool } -func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFlight flightVal, initialState handshakeState) error { //nolint:gocognit +//nolint:cyclop,gocognit,contextcheck +func (c *Conn) handshake( + ctx context.Context, + cfg *handshakeConfig, + initialFlight flightVal, + initialState handshakeState, +) error { c.fsm = newHandshakeFSM(&c.state, c.handshakeCache, cfg, initialFlight) done := make(chan struct{}) ctxRead, cancelRead := context.WithCancel(context.Background()) - c.cancelHandshakeReader = cancelRead - cfg.onFlightState = func(f flightVal, s handshakeState) { + cfg.onFlightState = func(_ flightVal, s handshakeState) { if s == handshakeFinished && !c.isHandshakeCompletedSuccessfully() { c.setHandshakeCompletedSuccessfully() close(done) @@ -838,16 +1106,21 @@ func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFligh } ctxHs, cancel := context.WithCancel(context.Background()) + + c.closeLock.Lock() c.cancelHandshaker = cancel + c.cancelHandshakeReader = cancelRead + c.closeLock.Unlock() firstErr := make(chan error, 1) - c.handshakeLoopsFinished.Add(2) + var handshakeLoopsFinished sync.WaitGroup + handshakeLoopsFinished.Add(2) // Handshake routine should be live until close. // The other party may request retransmission of the last flight to cope with packet drop. go func() { - defer c.handshakeLoopsFinished.Done() + defer handshakeLoopsFinished.Done() err := c.fsm.Run(ctxHs, c, initialState) if !errors.Is(err, context.Canceled) { select { @@ -858,19 +1131,21 @@ func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFligh }() go func() { defer func() { - // Escaping read loop. - // It's safe to close decrypted channnel now. - close(c.decrypted) + if c.isHandshakeCompletedSuccessfully() { + // Escaping read loop. + // It's safe to close decrypted channnel now. + close(c.decrypted) + } // Force stop handshaker when the underlying connection is closed. cancel() }() - defer c.handshakeLoopsFinished.Done() + defer handshakeLoopsFinished.Done() for { - if err := c.readAndBuffer(ctxRead); err != nil { - var e *alertError - if errors.As(err, &e) { - if !e.IsFatalOrCloseNotify() { + if err := c.readAndBuffer(ctxRead); err != nil { //nolint:nestif + var alertErr *alertError + if errors.As(err, &alertErr) { + if !alertErr.IsFatalOrCloseNotify() { if c.isHandshakeCompletedSuccessfully() { // Pass the error to Read() select { @@ -879,11 +1154,19 @@ func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFligh case <-ctxRead.Done(): } } + continue // non-fatal alert must not stop read loop } } else { switch { - case errors.Is(err, context.DeadlineExceeded), errors.Is(err, context.Canceled), errors.Is(err, io.EOF): + case errors.Is(err, context.DeadlineExceeded), + errors.Is(err, context.Canceled), + errors.Is(err, io.EOF), + errors.Is(err, net.ErrClosed): + case errors.Is(err, recordlayer.ErrInvalidPacketLength): + // Decode error must be silently discarded + // [RFC6347 Section-4.1.2.7] + continue default: if c.isHandshakeCompletedSuccessfully() { // Keep read loop and pass the read error to Read() @@ -892,6 +1175,7 @@ func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFligh case <-c.closed.Done(): case <-ctxRead.Done(): } + continue // non-fatal alert must not stop read loop } } @@ -902,8 +1186,8 @@ func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFligh default: } - if e != nil { - if e.IsFatalOrCloseNotify() { + if alertErr != nil { + if alertErr.IsFatalOrCloseNotify() { _ = c.close(false) //nolint:contextcheck } } @@ -911,6 +1195,7 @@ func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFligh c.log.Trace("handshake timeouts - closing underline connection") _ = c.close(false) //nolint:contextcheck } + return } } @@ -920,12 +1205,14 @@ func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFligh case err := <-firstErr: cancelRead() cancel() - c.handshakeLoopsFinished.Wait() + handshakeLoopsFinished.Wait() + return c.translateHandshakeCtxError(err) case <-ctx.Done(): cancelRead() cancel() - c.handshakeLoopsFinished.Wait() + handshakeLoopsFinished.Wait() + return c.translateHandshakeCtxError(ctx.Err()) case <-done: return nil @@ -939,12 +1226,18 @@ func (c *Conn) translateHandshakeCtxError(err error) error { if errors.Is(err, context.Canceled) && c.isHandshakeCompletedSuccessfully() { return nil } + return &HandshakeError{Err: err} } func (c *Conn) close(byUser bool) error { - c.cancelHandshaker() - c.cancelHandshakeReader() + c.closeLock.Lock() + cancelHandshaker := c.cancelHandshaker + cancelHandshakeReader := c.cancelHandshakeReader + c.closeLock.Unlock() + + cancelHandshaker() + cancelHandshakeReader() if c.isHandshakeCompletedSuccessfully() && byUser { // Discard error from notify() to return non-error on the first user call of Close() @@ -990,14 +1283,17 @@ func (c *Conn) setRemoteEpoch(epoch uint16) { c.state.remoteEpoch.Store(epoch) } -// LocalAddr implements net.Conn.LocalAddr +// LocalAddr implements net.Conn.LocalAddr. func (c *Conn) LocalAddr() net.Addr { return c.nextConn.LocalAddr() } -// RemoteAddr implements net.Conn.RemoteAddr +// RemoteAddr implements net.Conn.RemoteAddr. func (c *Conn) RemoteAddr() net.Addr { - return c.nextConn.RemoteAddr() + c.lock.RLock() + defer c.lock.RUnlock() + + return c.rAddr } func (c *Conn) sessionKey() []byte { @@ -1005,18 +1301,20 @@ func (c *Conn) sessionKey() []byte { // As ServerName can be like 0.example.com, it's better to add // delimiter character which is not allowed to be in // neither address or domain name. - return []byte(c.nextConn.RemoteAddr().String() + "_" + c.fsm.cfg.serverName) + return []byte(c.rAddr.String() + "_" + c.fsm.cfg.serverName) } + return c.state.SessionID } -// SetDeadline implements net.Conn.SetDeadline +// SetDeadline implements net.Conn.SetDeadline. func (c *Conn) SetDeadline(t time.Time) error { c.readDeadline.Set(t) + return c.SetWriteDeadline(t) } -// SetReadDeadline implements net.Conn.SetReadDeadline +// SetReadDeadline implements net.Conn.SetReadDeadline. func (c *Conn) SetReadDeadline(t time.Time) error { c.readDeadline.Set(t) // Read deadline is fully managed by this layer. @@ -1024,7 +1322,7 @@ func (c *Conn) SetReadDeadline(t time.Time) error { return nil } -// SetWriteDeadline implements net.Conn.SetWriteDeadline +// SetWriteDeadline implements net.Conn.SetWriteDeadline. func (c *Conn) SetWriteDeadline(t time.Time) error { c.writeDeadline.Set(t) // Write deadline is also fully managed by this layer. diff --git a/conn_go_test.go b/conn_go_test.go index 99e6f74c4..b22e7c71e 100644 --- a/conn_go_test.go +++ b/conn_go_test.go @@ -15,12 +15,13 @@ import ( "testing" "time" - "github.com/pion/dtls/v2/pkg/crypto/selfsign" - "github.com/pion/transport/v2/dpipe" - "github.com/pion/transport/v2/test" + "github.com/pion/dtls/v3/pkg/crypto/selfsign" + dtlsnet "github.com/pion/dtls/v3/pkg/net" + "github.com/pion/transport/v3/dpipe" + "github.com/pion/transport/v3/test" ) -func TestContextConfig(t *testing.T) { +func TestContextConfig(t *testing.T) { //nolint:cyclop // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() @@ -51,9 +52,6 @@ func TestContextConfig(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } config := &Config{ - ConnectContextMaker: func() (context.Context, func()) { - return context.WithTimeout(context.Background(), 40*time.Millisecond) - }, Certificates: []tls.Certificate{cert}, } @@ -63,71 +61,58 @@ func TestContextConfig(t *testing.T) { }{ "Dial": { f: func() (func() (net.Conn, error), func()) { + ctx, cancel := context.WithTimeout(context.Background(), 40*time.Millisecond) + return func() (net.Conn, error) { - return Dial("udp", addr, config) - }, func() { - } - }, - order: []byte{0, 1, 2}, - }, - "DialWithContext": { - f: func() (func() (net.Conn, error), func()) { - ctx, cancel := context.WithTimeout(context.Background(), 80*time.Millisecond) - return func() (net.Conn, error) { - return DialWithContext(ctx, "udp", addr, config) + conn, err := Dial("udp", addr, config) + if err != nil { + return nil, err + } + + return conn, conn.HandshakeContext(ctx) }, func() { cancel() } }, - order: []byte{0, 2, 1}, + order: []byte{0, 1, 2}, }, "Client": { f: func() (func() (net.Conn, error), func()) { ca, _ := dpipe.Pipe() + ctx, cancel := context.WithTimeout(context.Background(), 40*time.Millisecond) + return func() (net.Conn, error) { - return Client(ca, config) + conn, err := Client(dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), config) + if err != nil { + return nil, err + } + + return conn, conn.HandshakeContext(ctx) }, func() { _ = ca.Close() - } - }, - order: []byte{0, 1, 2}, - }, - "ClientWithContext": { - f: func() (func() (net.Conn, error), func()) { - ctx, cancel := context.WithTimeout(context.Background(), 80*time.Millisecond) - ca, _ := dpipe.Pipe() - return func() (net.Conn, error) { - return ClientWithContext(ctx, ca, config) - }, func() { cancel() - _ = ca.Close() } }, - order: []byte{0, 2, 1}, + order: []byte{0, 1, 2}, }, "Server": { f: func() (func() (net.Conn, error), func()) { ca, _ := dpipe.Pipe() + ctx, cancel := context.WithTimeout(context.Background(), 40*time.Millisecond) + return func() (net.Conn, error) { - return Server(ca, config) + conn, err := Server(dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), config) + if err != nil { + return nil, err + } + + return conn, conn.HandshakeContext(ctx) }, func() { _ = ca.Close() - } - }, - order: []byte{0, 1, 2}, - }, - "ServerWithContext": { - f: func() (func() (net.Conn, error), func()) { - ctx, cancel := context.WithTimeout(context.Background(), 80*time.Millisecond) - ca, _ := dpipe.Pipe() - return func() (net.Conn, error) { - return ServerWithContext(ctx, ca, config) - }, func() { cancel() - _ = ca.Close() } }, - order: []byte{0, 2, 1}, + order: []byte{0, 1, 2}, }, } @@ -144,6 +129,7 @@ func TestContextConfig(t *testing.T) { if !errors.As(err, &netError) || !netError.Temporary() { //nolint:staticcheck t.Errorf("Client error exp(Temporary network error) failed(%v)", err) close(done) + return } done <- struct{}{} diff --git a/conn_test.go b/conn_test.go index 946e4ab43..ed821a57c 100644 --- a/conn_test.go +++ b/conn_test.go @@ -24,20 +24,21 @@ import ( "testing" "time" - "github.com/pion/dtls/v2/internal/ciphersuite" - "github.com/pion/dtls/v2/pkg/crypto/elliptic" - "github.com/pion/dtls/v2/pkg/crypto/hash" - "github.com/pion/dtls/v2/pkg/crypto/selfsign" - "github.com/pion/dtls/v2/pkg/crypto/signature" - "github.com/pion/dtls/v2/pkg/crypto/signaturehash" - "github.com/pion/dtls/v2/pkg/protocol" - "github.com/pion/dtls/v2/pkg/protocol/alert" - "github.com/pion/dtls/v2/pkg/protocol/extension" - "github.com/pion/dtls/v2/pkg/protocol/handshake" - "github.com/pion/dtls/v2/pkg/protocol/recordlayer" + "github.com/pion/dtls/v3/internal/ciphersuite" + "github.com/pion/dtls/v3/pkg/crypto/elliptic" + "github.com/pion/dtls/v3/pkg/crypto/hash" + "github.com/pion/dtls/v3/pkg/crypto/selfsign" + "github.com/pion/dtls/v3/pkg/crypto/signature" + "github.com/pion/dtls/v3/pkg/crypto/signaturehash" + dtlsnet "github.com/pion/dtls/v3/pkg/net" + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/alert" + "github.com/pion/dtls/v3/pkg/protocol/extension" + "github.com/pion/dtls/v3/pkg/protocol/handshake" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" "github.com/pion/logging" - "github.com/pion/transport/v2/dpipe" - "github.com/pion/transport/v2/test" + "github.com/pion/transport/v3/dpipe" + "github.com/pion/transport/v3/test" ) var ( @@ -62,6 +63,8 @@ func TestStressDuplex(t *testing.T) { } func stressDuplex(t *testing.T) { + t.Helper() + ca, cb, err := pipeMemory() if err != nil { t.Fatal(err) @@ -116,7 +119,7 @@ func TestRoutineLeakOnClose(t *testing.T) { // inboundLoop routine should not be leaked. } -func TestReadWriteDeadline(t *testing.T) { +func TestReadWriteDeadline(t *testing.T) { //nolint:cyclop // Limit runtime in case of deadlocks lim := test.TimeOut(5 * time.Second) defer lim.Stop() @@ -125,7 +128,7 @@ func TestReadWriteDeadline(t *testing.T) { report := test.CheckRoutines(t) defer report() - var e net.Error + var netErr net.Error ca, cb, err := pipeMemory() if err != nil { @@ -136,22 +139,22 @@ func TestReadWriteDeadline(t *testing.T) { t.Fatal(err) } _, werr := ca.Write(make([]byte, 100)) - if errors.As(werr, &e) { - if !e.Timeout() { + if errors.As(werr, &netErr) { + if !netErr.Timeout() { t.Error("Deadline exceeded Write must return Timeout error") } - if !e.Temporary() { //nolint:staticcheck + if !netErr.Temporary() { //nolint:staticcheck t.Error("Deadline exceeded Write must return Temporary error") } } else { t.Error("Write must return net.Error error") } _, rerr := ca.Read(make([]byte, 100)) - if errors.As(rerr, &e) { - if !e.Timeout() { + if errors.As(rerr, &netErr) { + if !netErr.Timeout() { t.Error("Deadline exceeded Read must return Timeout error") } - if !e.Temporary() { //nolint:staticcheck + if !netErr.Temporary() { //nolint:staticcheck t.Error("Deadline exceeded Read must return Temporary error") } } else { @@ -250,6 +253,7 @@ func TestSequenceNumberOverflow(t *testing.T) { func pipeMemory() (*Conn, *Conn, error) { // In memory pipe ca, cb := dpipe.Pipe() + return pipeConn(ca, cb) } @@ -259,33 +263,44 @@ func pipeConn(ca, cb net.Conn) (*Conn, *Conn, error) { err error } - c := make(chan result) + resultCh := make(chan result) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() // Setup client go func() { - client, err := testClient(ctx, ca, &Config{SRTPProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}}, true) - c <- result{client, err} + client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ + SRTPProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}, + }, true) + resultCh <- result{client, err} }() // Setup server - server, err := testServer(ctx, cb, &Config{SRTPProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}}, true) + server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ + SRTPProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}, + }, true) if err != nil { return nil, nil, err } // Receive client - res := <-c + res := <-resultCh if res.err != nil { _ = server.Close() + return nil, nil, res.err } return res.c, server, nil } -func testClient(ctx context.Context, c net.Conn, cfg *Config, generateCertificate bool) (*Conn, error) { +func testClient( + ctx context.Context, + pktConn net.PacketConn, + rAddr net.Addr, + cfg *Config, + generateCertificate bool, +) (*Conn, error) { if generateCertificate { clientCert, err := selfsign.GenerateSelfSigned() if err != nil { @@ -294,10 +309,21 @@ func testClient(ctx context.Context, c net.Conn, cfg *Config, generateCertificat cfg.Certificates = []tls.Certificate{clientCert} } cfg.InsecureSkipVerify = true - return ClientWithContext(ctx, c, cfg) + conn, err := Client(pktConn, rAddr, cfg) + if err != nil { + return nil, err + } + + return conn, conn.HandshakeContext(ctx) } -func testServer(ctx context.Context, c net.Conn, cfg *Config, generateCertificate bool) (*Conn, error) { +func testServer( + ctx context.Context, + c net.PacketConn, + rAddr net.Addr, + cfg *Config, + generateCertificate bool, +) (*Conn, error) { if generateCertificate { serverCert, err := selfsign.GenerateSelfSigned() if err != nil { @@ -305,7 +331,12 @@ func testServer(ctx context.Context, c net.Conn, cfg *Config, generateCertificat } cfg.Certificates = []tls.Certificate{serverCert} } - return ServerWithContext(ctx, c, cfg) + conn, err := Server(c, rAddr, cfg) + if err != nil { + return nil, err + } + + return conn, conn.HandshakeContext(ctx) } func sendClientHello(cookie []byte, ca net.Conn, sequenceNumber uint64, extensions []extension.Extension) error { @@ -316,7 +347,7 @@ func sendClientHello(cookie []byte, ca net.Conn, sequenceNumber uint64, extensio }, Content: &handshake.Handshake{ Header: handshake.Header{ - MessageSequence: uint16(sequenceNumber), + MessageSequence: uint16(sequenceNumber), //nolint:gosec // G115 }, Message: &handshake.MessageClientHello{ Version: protocol.Version1_2, @@ -334,6 +365,7 @@ func sendClientHello(cookie []byte, ca net.Conn, sequenceNumber uint64, extensio if _, err = ca.Write(packet); err != nil { return err } + return nil } @@ -384,11 +416,11 @@ func TestHandshakeWithAlert(t *testing.T) { ca, cb := dpipe.Pipe() go func() { - _, err := testClient(ctx, ca, testCase.configClient, true) + _, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), testCase.configClient, true) clientErr <- err }() - _, errServer := testServer(ctx, cb, testCase.configServer, true) + _, errServer := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), testCase.configServer, true) if !errors.Is(errServer, testCase.errServer) { t.Fatalf("Server error exp(%v) failed(%v)", testCase.errServer, errServer) } @@ -401,7 +433,76 @@ func TestHandshakeWithAlert(t *testing.T) { } } -func TestExportKeyingMaterial(t *testing.T) { +func TestHandshakeWithInvalidRecord(t *testing.T) { + // Limit runtime in case of deadlocks + lim := test.TimeOut(time.Second * 20) + defer lim.Stop() + + // Check for leaking routines + report := test.CheckRoutines(t) + defer report() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + type result struct { + c *Conn + err error + } + clientErr := make(chan result, 1) + ca, cb := dpipe.Pipe() + caWithInvalidRecord := &connWithCallback{Conn: ca} + + var msgSeq atomic.Int32 + // Send invalid record after first message + caWithInvalidRecord.onWrite = func([]byte) { + if msgSeq.Add(1) == 2 { + if _, err := ca.Write([]byte{0x01, 0x02}); err != nil { + t.Fatal(err) + } + } + } + go func() { + client, err := testClient( + ctx, + dtlsnet.PacketConnFromConn(caWithInvalidRecord), + caWithInvalidRecord.RemoteAddr(), + &Config{CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}}, + true, + ) + clientErr <- result{client, err} + }() + + server, errServer := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ + CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, + }, true) + + errClient := <-clientErr + + defer func() { + if server != nil { + if err := server.Close(); err != nil { + t.Fatal(err) + } + } + + if errClient.c != nil { + if err := errClient.c.Close(); err != nil { + t.Fatal(err) + } + } + }() + + if errServer != nil { + t.Fatalf("Server failed(%v)", errServer) + } + + if errClient.err != nil { + t.Fatalf("Client failed(%v)", errClient.err) + } +} + +func TestExportKeyingMaterial(t *testing.T) { //nolint:cyclop // Check for leaking routines report := test.CheckRoutines(t) defer report() @@ -412,7 +513,7 @@ func TestExportKeyingMaterial(t *testing.T) { expectedServerKey := []byte{0x61, 0x09, 0x9d, 0x7d, 0xcb, 0x08, 0x52, 0x2c, 0xe7, 0x7b} expectedClientKey := []byte{0x87, 0xf0, 0x40, 0x02, 0xf6, 0x1c, 0xf1, 0xfe, 0x8c, 0x77} - c := &Conn{ + conn := &Conn{ state: State{ localRandom: handshake.Random{GMTUnixTime: time.Unix(500, 0), RandomBytes: rand}, remoteRandom: handshake.Random{GMTUnixTime: time.Unix(1000, 0), RandomBytes: rand}, @@ -420,31 +521,43 @@ func TestExportKeyingMaterial(t *testing.T) { cipherSuite: &ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{}, }, } - c.setLocalEpoch(0) - c.setRemoteEpoch(0) + conn.setLocalEpoch(0) + conn.setRemoteEpoch(0) - state := c.ConnectionState() + state, ok := conn.ConnectionState() + if !ok { + t.Fatal("ConnectionState failed") + } _, err := state.ExportKeyingMaterial(exportLabel, nil, 0) if !errors.Is(err, errHandshakeInProgress) { t.Errorf("ExportKeyingMaterial when epoch == 0: expected '%s' actual '%s'", errHandshakeInProgress, err) } - c.setLocalEpoch(1) - state = c.ConnectionState() + conn.setLocalEpoch(1) + state, ok = conn.ConnectionState() + if !ok { + t.Fatal("ConnectionState failed") + } _, err = state.ExportKeyingMaterial(exportLabel, []byte{0x00}, 0) if !errors.Is(err, errContextUnsupported) { t.Errorf("ExportKeyingMaterial with context: expected '%s' actual '%s'", errContextUnsupported, err) } for k := range invalidKeyingLabels() { - state = c.ConnectionState() + state, ok = conn.ConnectionState() + if !ok { + t.Fatal("ConnectionState failed") + } _, err = state.ExportKeyingMaterial(k, nil, 0) if !errors.Is(err, errReservedExportKeyingMaterial) { t.Errorf("ExportKeyingMaterial reserved label: expected '%s' actual '%s'", errReservedExportKeyingMaterial, err) } } - state = c.ConnectionState() + state, ok = conn.ConnectionState() + if !ok { + t.Fatal("ConnectionState failed") + } keyingMaterial, err := state.ExportKeyingMaterial(exportLabel, nil, 10) if err != nil { t.Errorf("ExportKeyingMaterial as server: unexpected error '%s'", err) @@ -452,8 +565,11 @@ func TestExportKeyingMaterial(t *testing.T) { t.Errorf("ExportKeyingMaterial client export: expected (% 02x) actual (% 02x)", expectedServerKey, keyingMaterial) } - c.state.isClient = true - state = c.ConnectionState() + conn.state.isClient = true + state, ok = conn.ConnectionState() + if !ok { + t.Fatal("ConnectionState failed") + } keyingMaterial, err = state.ExportKeyingMaterial(exportLabel, nil, 10) if err != nil { t.Errorf("ExportKeyingMaterial as server: unexpected error '%s'", err) @@ -462,7 +578,7 @@ func TestExportKeyingMaterial(t *testing.T) { } } -func TestPSK(t *testing.T) { +func TestPSK(t *testing.T) { //nolint:cyclop // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() @@ -473,6 +589,7 @@ func TestPSK(t *testing.T) { for _, test := range []struct { Name string + ClientIdentity []byte ServerIdentity []byte CipherSuites []CipherSuiteID ClientVerifyConnection func(*State) error @@ -484,13 +601,15 @@ func TestPSK(t *testing.T) { { Name: "Server identity specified", ServerIdentity: []byte("Test Identity"), + ClientIdentity: []byte("Client Identity"), CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8}, }, { Name: "Server identity specified - Server verify connection fails", ServerIdentity: []byte("Test Identity"), + ClientIdentity: []byte("Client Identity"), CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8}, - ServerVerifyConnection: func(s *State) error { + ServerVerifyConnection: func(*State) error { return errExample }, WantFail: true, @@ -500,8 +619,9 @@ func TestPSK(t *testing.T) { { Name: "Server identity specified - Client verify connection fails", ServerIdentity: []byte("Test Identity"), + ClientIdentity: []byte("Client Identity"), CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8}, - ClientVerifyConnection: func(s *State) error { + ClientVerifyConnection: func(*State) error { return errExample }, WantFail: true, @@ -511,25 +631,33 @@ func TestPSK(t *testing.T) { { Name: "Server identity nil", ServerIdentity: nil, + ClientIdentity: []byte("Client Identity"), CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8}, }, { Name: "TLS_PSK_WITH_AES_128_CBC_SHA256", ServerIdentity: nil, + ClientIdentity: []byte("Client Identity"), CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CBC_SHA256}, }, { Name: "TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256", ServerIdentity: nil, + ClientIdentity: []byte("Client Identity"), CipherSuites: []CipherSuiteID{TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256}, }, + { + Name: "Client identity empty", + ServerIdentity: nil, + ClientIdentity: []byte{}, + CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8}, + }, } { test := test t.Run(test.Name, func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - clientIdentity := []byte("Client Identity") type result struct { c *Conn err error @@ -541,25 +669,30 @@ func TestPSK(t *testing.T) { conf := &Config{ PSK: func(hint []byte) ([]byte, error) { if !bytes.Equal(test.ServerIdentity, hint) { - return nil, fmt.Errorf("TestPSK: Client got invalid identity expected(% 02x) actual(% 02x)", test.ServerIdentity, hint) //nolint:goerr113 + return nil, fmt.Errorf( //nolint:goerr113 + "TestPSK: Client got invalid identity expected(% 02x) actual(% 02x)", + test.ServerIdentity, hint, + ) } return []byte{0xAB, 0xC1, 0x23}, nil }, - PSKIdentityHint: clientIdentity, + PSKIdentityHint: test.ClientIdentity, CipherSuites: test.CipherSuites, VerifyConnection: test.ClientVerifyConnection, } - c, err := testClient(ctx, ca, conf, false) + c, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), conf, false) clientRes <- result{c, err} }() config := &Config{ PSK: func(hint []byte) ([]byte, error) { - if !bytes.Equal(clientIdentity, hint) { - return nil, fmt.Errorf("%w: expected(% 02x) actual(% 02x)", errTestPSKInvalidIdentity, clientIdentity, hint) + fmt.Println(hint) + if !bytes.Equal(test.ClientIdentity, hint) { + return nil, fmt.Errorf("%w: expected(% 02x) actual(% 02x)", errTestPSKInvalidIdentity, test.ClientIdentity, hint) } + return []byte{0xAB, 0xC1, 0x23}, nil }, PSKIdentityHint: test.ServerIdentity, @@ -567,7 +700,7 @@ func TestPSK(t *testing.T) { VerifyConnection: test.ServerVerifyConnection, } - server, err := testServer(ctx, cb, config, false) + server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, false) if test.WantFail { res := <-clientRes if err == nil || !strings.Contains(err.Error(), test.ExpectedServerErr) { @@ -576,15 +709,23 @@ func TestPSK(t *testing.T) { if res.err == nil || !strings.Contains(res.err.Error(), test.ExpectedClientErr) { t.Fatalf("TestPSK: Client expected(%v) actual(%v)", test.ExpectedClientErr, res.err) } + return } if err != nil { t.Fatalf("TestPSK: Server failed(%v)", err) } - actualPSKIdentityHint := server.ConnectionState().IdentityHint - if !bytes.Equal(actualPSKIdentityHint, clientIdentity) { - t.Errorf("TestPSK: Server ClientPSKIdentity Mismatch '%s': expected(%v) actual(%v)", test.Name, clientIdentity, actualPSKIdentityHint) + state, ok := server.ConnectionState() + if !ok { + t.Fatalf("TestPSK: Server ConnectionState failed") + } + actualPSKIdentityHint := state.IdentityHint + if !bytes.Equal(actualPSKIdentityHint, test.ClientIdentity) { + t.Errorf( + "TestPSK: Server ClientPSKIdentity Mismatch '%s': expected(%v) actual(%v)", + test.Name, test.ClientIdentity, actualPSKIdentityHint, + ) } defer func() { @@ -619,26 +760,28 @@ func TestPSKHintFail(t *testing.T) { ca, cb := dpipe.Pipe() go func() { conf := &Config{ - PSK: func(hint []byte) ([]byte, error) { + PSK: func([]byte) ([]byte, error) { return nil, pskRejected }, PSKIdentityHint: []byte{}, CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8}, } - _, err := testClient(ctx, ca, conf, false) + _, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), conf, false) clientErr <- err }() config := &Config{ - PSK: func(hint []byte) ([]byte, error) { + PSK: func([]byte) ([]byte, error) { return nil, pskRejected }, PSKIdentityHint: []byte{}, CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8}, } - if _, err := testServer(ctx, cb, config, false); !errors.Is(err, serverAlertError) { + if _, err := testServer( + ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, false, + ); !errors.Is(err, serverAlertError) { t.Fatalf("TestPSK: Server error exp(%v) failed(%v)", serverAlertError, err) } @@ -647,6 +790,104 @@ func TestPSKHintFail(t *testing.T) { } } +// Assert that ServerKeyExchange is only sent if Identity is set on server side. +func TestPSKServerKeyExchange(t *testing.T) { //nolint:cyclop + // Limit runtime in case of deadlocks + lim := test.TimeOut(time.Second * 20) + defer lim.Stop() + + // Check for leaking routines + report := test.CheckRoutines(t) + defer report() + + for _, test := range []struct { + Name string + SetIdentity bool + }{ + { + Name: "Server Identity Set", + SetIdentity: true, + }, + { + Name: "Server Not Identity Set", + SetIdentity: false, + }, + } { + test := test + t.Run(test.Name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + gotServerKeyExchange := false + + clientErr := make(chan error, 1) + ca, cb := dpipe.Pipe() + cbAnalyzer := &connWithCallback{Conn: cb} + cbAnalyzer.onWrite = func(in []byte) { + messages, err := recordlayer.UnpackDatagram(in) + if err != nil { + t.Fatal(err) + } + + for i := range messages { + h := &handshake.Handshake{} + _ = h.Unmarshal(messages[i][recordlayer.FixedHeaderSize:]) + + if h.Header.Type == handshake.TypeServerKeyExchange { + gotServerKeyExchange = true + } + } + } + + go func() { + conf := &Config{ + PSK: func([]byte) ([]byte, error) { + return []byte{0xAB, 0xC1, 0x23}, nil + }, + PSKIdentityHint: []byte{0xAB, 0xC1, 0x23}, + CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8}, + } + + if client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), conf, false); err != nil { + clientErr <- err + } else { + clientErr <- client.Close() //nolint + } + }() + + config := &Config{ + PSK: func([]byte) ([]byte, error) { + return []byte{0xAB, 0xC1, 0x23}, nil + }, + CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8}, + } + if test.SetIdentity { + config.PSKIdentityHint = []byte{0xAB, 0xC1, 0x23} + } + + if server, err := testServer( + ctx, dtlsnet.PacketConnFromConn(cbAnalyzer), cbAnalyzer.RemoteAddr(), config, false, + ); err != nil { + t.Fatalf("TestPSK: Server error %v", err) + } else { + if err = server.Close(); err != nil { + t.Fatal(err) + } + } + + if err := <-clientErr; err != nil { + t.Fatalf("TestPSK: Client error %v", err) + } + + if gotServerKeyExchange != test.SetIdentity { + t.Fatalf( + "Mismatch between setting Identity and getting a ServerKeyExchange exp(%t) actual(%t)", + test.SetIdentity, gotServerKeyExchange, + ) + } + }) + } +} + func TestClientTimeout(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) @@ -665,7 +906,7 @@ func TestClientTimeout(t *testing.T) { go func() { conf := &Config{} - c, err := testClient(ctx, ca, conf, true) + c, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), conf, true) if err == nil { _ = c.Close() //nolint:contextcheck } @@ -680,18 +921,20 @@ func TestClientTimeout(t *testing.T) { } } -func TestSRTPConfiguration(t *testing.T) { +func TestSRTPConfiguration(t *testing.T) { //nolint:cyclop // Check for leaking routines report := test.CheckRoutines(t) defer report() for _, test := range []struct { - Name string - ClientSRTP []SRTPProtectionProfile - ServerSRTP []SRTPProtectionProfile - ExpectedProfile SRTPProtectionProfile - WantClientError error - WantServerError error + Name string + ClientSRTP []SRTPProtectionProfile + ServerSRTP []SRTPProtectionProfile + ClientSRTPMasterKeyIdentifier []byte + ServerSRTPMasterKeyIdentifier []byte + ExpectedProfile SRTPProtectionProfile + WantClientError error + WantServerError error }{ { Name: "No SRTP in use", @@ -702,12 +945,14 @@ func TestSRTPConfiguration(t *testing.T) { WantServerError: nil, }, { - Name: "SRTP both ends", - ClientSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}, - ServerSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}, - ExpectedProfile: SRTP_AES128_CM_HMAC_SHA1_80, - WantClientError: nil, - WantServerError: nil, + Name: "SRTP both ends", + ClientSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}, + ServerSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}, + ExpectedProfile: SRTP_AES128_CM_HMAC_SHA1_80, + ClientSRTPMasterKeyIdentifier: []byte("ClientSRTPMKI"), + ServerSRTPMasterKeyIdentifier: []byte("ServerSRTPMKI"), + WantClientError: nil, + WantServerError: nil, }, { Name: "SRTP client only", @@ -750,16 +995,23 @@ func TestSRTPConfiguration(t *testing.T) { c *Conn err error } - c := make(chan result) + resultCh := make(chan result) go func() { - client, err := testClient(ctx, ca, &Config{SRTPProtectionProfiles: test.ClientSRTP}, true) - c <- result{client, err} + client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ + SRTPProtectionProfiles: test.ClientSRTP, SRTPMasterKeyIdentifier: test.ServerSRTPMasterKeyIdentifier, + }, true) + resultCh <- result{client, err} }() - server, err := testServer(ctx, cb, &Config{SRTPProtectionProfiles: test.ServerSRTP}, true) + server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ + SRTPProtectionProfiles: test.ServerSRTP, SRTPMasterKeyIdentifier: test.ClientSRTPMasterKeyIdentifier, + }, true) if !errors.Is(err, test.WantServerError) { - t.Errorf("TestSRTPConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantServerError, err) + t.Errorf( + "TestSRTPConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", + test.Name, test.WantServerError, err, + ) } if err == nil { defer func() { @@ -767,14 +1019,17 @@ func TestSRTPConfiguration(t *testing.T) { }() } - res := <-c + res := <-resultCh if res.err == nil { defer func() { _ = res.c.Close() }() } if !errors.Is(res.err, test.WantClientError) { - t.Fatalf("TestSRTPConfiguration: Client Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantClientError, res.err) + t.Fatalf( + "TestSRTPConfiguration: Client Error Mismatch '%s': expected(%v) actual(%v)", + test.Name, test.WantClientError, res.err, + ) } if res.c == nil { return @@ -782,17 +1037,39 @@ func TestSRTPConfiguration(t *testing.T) { actualClientSRTP, _ := res.c.SelectedSRTPProtectionProfile() if actualClientSRTP != test.ExpectedProfile { - t.Errorf("TestSRTPConfiguration: Client SRTPProtectionProfile Mismatch '%s': expected(%v) actual(%v)", test.Name, test.ExpectedProfile, actualClientSRTP) + t.Errorf( + "TestSRTPConfiguration: Client SRTPProtectionProfile Mismatch '%s': expected(%v) actual(%v)", + test.Name, test.ExpectedProfile, actualClientSRTP, + ) } actualServerSRTP, _ := server.SelectedSRTPProtectionProfile() if actualServerSRTP != test.ExpectedProfile { - t.Errorf("TestSRTPConfiguration: Server SRTPProtectionProfile Mismatch '%s': expected(%v) actual(%v)", test.Name, test.ExpectedProfile, actualServerSRTP) + t.Errorf( + "TestSRTPConfiguration: Server SRTPProtectionProfile Mismatch '%s': expected(%v) actual(%v)", + test.Name, test.ExpectedProfile, actualServerSRTP, + ) + } + + actualServerMKI, _ := server.RemoteSRTPMasterKeyIdentifier() + if !bytes.Equal(actualServerMKI, test.ServerSRTPMasterKeyIdentifier) { + t.Errorf( + "TestSRTPConfiguration: Server SRTPMKI Mismatch '%s': expected(%v) actual(%v)", + test.Name, test.ServerSRTPMasterKeyIdentifier, actualServerMKI, + ) + } + + actualClientMKI, _ := res.c.RemoteSRTPMasterKeyIdentifier() + if !bytes.Equal(actualClientMKI, test.ClientSRTPMasterKeyIdentifier) { + t.Errorf( + "TestSRTPConfiguration: Client SRTPMKI Mismatch '%s': expected(%v) actual(%v)", + test.Name, test.ClientSRTPMasterKeyIdentifier, actualClientMKI, + ) } } } -func TestClientCertificate(t *testing.T) { +func TestClientCertificate(t *testing.T) { //nolint:gocyclo,cyclop,maintidx // Check for leaking routines report := test.CheckRoutines(t) defer report() @@ -839,14 +1116,14 @@ func TestClientCertificate(t *testing.T) { Certificates: []tls.Certificate{srvCert}, ClientAuth: NoClientCert, ClientCAs: caPool, - VerifyConnection: func(s *State) error { + VerifyConnection: func(*State) error { return errExample }, }, wantErr: true, }, "NoClientCert_ClientVerifyConnectionFails": { - clientCfg: &Config{RootCAs: srvCAPool, VerifyConnection: func(s *State) error { + clientCfg: &Config{RootCAs: srvCAPool, VerifyConnection: func(*State) error { return errExample }}, serverCfg: &Config{ @@ -863,6 +1140,14 @@ func TestClientCertificate(t *testing.T) { ClientAuth: RequireAnyClientCert, }, }, + "RequestClientCert_cert_sigscheme": { // specify signature algorithm + clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}}, + serverCfg: &Config{ + SignatureSchemes: []tls.SignatureScheme{tls.ECDSAWithP521AndSHA512}, + Certificates: []tls.Certificate{srvCert}, + ClientAuth: RequestClientCert, + }, + }, "RequestClientCert_cert": { clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}}, serverCfg: &Config{ @@ -918,12 +1203,17 @@ func TestClientCertificate(t *testing.T) { wantErr: true, }, "RequireAndVerifyClientCert": { - clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}, VerifyConnection: func(s *State) error { - if ok := bytes.Equal(s.PeerCertificates[0], srvCertificate.Raw); !ok { - return errExample - } - return nil - }}, + clientCfg: &Config{ + RootCAs: srvCAPool, + Certificates: []tls.Certificate{cert}, + VerifyConnection: func(s *State) error { + if ok := bytes.Equal(s.PeerCertificates[0], srvCertificate.Raw); !ok { + return errExample + } + + return nil + }, + }, serverCfg: &Config{ Certificates: []tls.Certificate{srvCert}, ClientAuth: RequireAndVerifyClientCert, @@ -932,6 +1222,7 @@ func TestClientCertificate(t *testing.T) { if ok := bytes.Equal(s.PeerCertificates[0], certificate.Raw); !ok { return errExample } + return nil }, }, @@ -940,10 +1231,10 @@ func TestClientCertificate(t *testing.T) { clientCfg: &Config{ RootCAs: srvCAPool, // Certificates: []tls.Certificate{cert}, - GetClientCertificate: func(cri *CertificateRequestInfo) (*tls.Certificate, error) { return &cert, nil }, + GetClientCertificate: func(*CertificateRequestInfo) (*tls.Certificate, error) { return &cert, nil }, }, serverCfg: &Config{ - GetCertificate: func(chi *ClientHelloInfo) (*tls.Certificate, error) { return &srvCert, nil }, + GetCertificate: func(*ClientHelloInfo) (*tls.Certificate, error) { return &srvCert, nil }, // Certificates: []tls.Certificate{srvCert}, ClientAuth: RequireAndVerifyClientCert, ClientCAs: caPool, @@ -955,17 +1246,18 @@ func TestClientCertificate(t *testing.T) { t.Run(name, func(t *testing.T) { ca, cb := dpipe.Pipe() type result struct { - c *Conn - err error + c *Conn + err, hserr error } c := make(chan result) go func() { - client, err := Client(ca, tt.clientCfg) - c <- result{client, err} + client, err := Client(dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), tt.clientCfg) + c <- result{client, err, client.Handshake()} }() - server, err := Server(cb, tt.serverCfg) + server, err := Server(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), tt.serverCfg) + hserr := server.Handshake() res := <-c defer func() { if err == nil { @@ -977,7 +1269,7 @@ func TestClientCertificate(t *testing.T) { }() if tt.wantErr { - if err != nil { + if err != nil || hserr != nil { // Error expected, test succeeded return } @@ -991,8 +1283,14 @@ func TestClientCertificate(t *testing.T) { t.Errorf("Client failed(%v)", res.err) } - actualClientCert := server.ConnectionState().PeerCertificates - if tt.serverCfg.ClientAuth == RequireAnyClientCert || tt.serverCfg.ClientAuth == RequireAndVerifyClientCert { + state, ok := server.ConnectionState() + if !ok { + t.Error("Server connection state not available") + } + actualClientCert := state.PeerCertificates + //nolint:nestif + if tt.serverCfg.ClientAuth == RequireAnyClientCert || + tt.serverCfg.ClientAuth == RequireAndVerifyClientCert { if actualClientCert == nil { t.Errorf("Client did not provide a certificate") } @@ -1018,7 +1316,11 @@ func TestClientCertificate(t *testing.T) { } } - actualServerCert := res.c.ConnectionState().PeerCertificates + clientState, ok := res.c.ConnectionState() + if !ok { + t.Error("Client connection state not available") + } + actualServerCert := clientState.PeerCertificates if actualServerCert == nil { t.Errorf("Server did not provide a certificate") } @@ -1041,6 +1343,132 @@ func TestClientCertificate(t *testing.T) { }) } +func TestConnectionID(t *testing.T) { + // Check for leaking routines + report := test.CheckRoutines(t) + defer report() + + clientCID := []byte{5, 77, 33, 24, 93, 27, 45, 81} + serverCID := []byte{64, 24, 73, 2, 17, 96, 38, 59} + cidEcho := func(echo []byte) func() []byte { + return func() []byte { + return echo + } + } + tests := map[string]struct { + clientCfg *Config + serverCfg *Config + clientConnectionID []byte + serverConnectionID []byte + }{ + "BidirectionalConnectionIDs": { + clientCfg: &Config{ + ConnectionIDGenerator: cidEcho(clientCID), + }, + serverCfg: &Config{ + ConnectionIDGenerator: cidEcho(serverCID), + }, + clientConnectionID: clientCID, + serverConnectionID: serverCID, + }, + "BothSupportOnlyClientSends": { + clientCfg: &Config{ + ConnectionIDGenerator: cidEcho(nil), + }, + serverCfg: &Config{ + ConnectionIDGenerator: cidEcho(serverCID), + }, + serverConnectionID: serverCID, + }, + "BothSupportOnlyServerSends": { + clientCfg: &Config{ + ConnectionIDGenerator: cidEcho(clientCID), + }, + serverCfg: &Config{ + ConnectionIDGenerator: cidEcho(nil), + }, + clientConnectionID: clientCID, + }, + "ClientDoesNotSupport": { + clientCfg: &Config{}, + serverCfg: &Config{ + ConnectionIDGenerator: cidEcho(serverCID), + }, + }, + "ServerDoesNotSupport": { + clientCfg: &Config{ + ConnectionIDGenerator: cidEcho(clientCID), + }, + serverCfg: &Config{}, + }, + "NeitherSupport": { + clientCfg: &Config{}, + serverCfg: &Config{}, + }, + } + for name, tt := range tests { + tt := tt + t.Run(name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + ca, cb := dpipe.Pipe() + type result struct { + c *Conn + err error + } + c := make(chan result) + + go func() { + client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), tt.clientCfg, true) + c <- result{client, err} + }() + + server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), tt.serverCfg, true) + if err != nil { + t.Fatalf("Unexpected server error: %v", err) + } + res := <-c + if res.err != nil { + t.Fatalf("Unexpected client error: %v", res.err) + } + defer func() { + if err == nil { + _ = server.Close() + } + if res.err == nil { + _ = res.c.Close() + } + }() + + if !bytes.Equal(res.c.state.getLocalConnectionID(), tt.clientConnectionID) { + t.Errorf( + "Unexpected client local connection ID\nwant: %v\ngot:%v", + tt.clientConnectionID, res.c.state.localConnectionID, + ) + } + if !bytes.Equal(res.c.state.remoteConnectionID, tt.serverConnectionID) { + t.Errorf( + "Unexpected client remote connection ID\nwant: %v\ngot:%v", + tt.serverConnectionID, res.c.state.remoteConnectionID, + ) + } + if !bytes.Equal(server.state.getLocalConnectionID(), tt.serverConnectionID) { + t.Errorf( + "Unexpected server local connection ID\nwant: %v\ngot:%v", + tt.serverConnectionID, server.state.localConnectionID, + ) + } + if !bytes.Equal(server.state.remoteConnectionID, tt.clientConnectionID) { + t.Errorf( + "Unexpected server remote connection ID\nwant: %v\ngot:%v", + tt.clientConnectionID, server.state.remoteConnectionID, + ) + } + }) + } +} + func TestExtendedMasterSecret(t *testing.T) { // Check for leaking routines report := test.CheckRoutines(t) @@ -1157,11 +1585,11 @@ func TestExtendedMasterSecret(t *testing.T) { c := make(chan result) go func() { - client, err := testClient(ctx, ca, tt.clientCfg, true) + client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), tt.clientCfg, true) c <- result{client, err} }() - server, err := testServer(ctx, cb, tt.serverCfg, true) + server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), tt.serverCfg, true) res := <-c defer func() { if err == nil { @@ -1183,7 +1611,7 @@ func TestExtendedMasterSecret(t *testing.T) { } } -func TestServerCertificate(t *testing.T) { +func TestServerCertificate(t *testing.T) { //nolint:cyclop // Check for leaking routines report := test.CheckRoutines(t) defer report() @@ -1220,21 +1648,32 @@ func TestServerCertificate(t *testing.T) { }, "good_ca_skip_verify_custom_verify_peer": { clientCfg: &Config{RootCAs: caPool, Certificates: []tls.Certificate{cert}}, - serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: RequireAnyClientCert, VerifyPeerCertificate: func(cert [][]byte, chain [][]*x509.Certificate) error { - if len(chain) != 0 { - return errNotExpectedChain - } - return nil - }}, + serverCfg: &Config{ + Certificates: []tls.Certificate{cert}, + ClientAuth: RequireAnyClientCert, + VerifyPeerCertificate: func(_ [][]byte, chain [][]*x509.Certificate) error { + if len(chain) != 0 { + return errNotExpectedChain + } + + return nil + }, + }, }, "good_ca_verify_custom_verify_peer": { clientCfg: &Config{RootCAs: caPool, Certificates: []tls.Certificate{cert}}, - serverCfg: &Config{ClientCAs: caPool, Certificates: []tls.Certificate{cert}, ClientAuth: RequireAndVerifyClientCert, VerifyPeerCertificate: func(cert [][]byte, chain [][]*x509.Certificate) error { - if len(chain) == 0 { - return errExpecedChain - } - return nil - }}, + serverCfg: &Config{ + ClientCAs: caPool, + Certificates: []tls.Certificate{cert}, + ClientAuth: RequireAndVerifyClientCert, + VerifyPeerCertificate: func(_ [][]byte, chain [][]*x509.Certificate) error { + if len(chain) == 0 { + return errExpecedChain + } + + return nil + }, + }, }, "good_ca_custom_verify_peer": { clientCfg: &Config{ @@ -1262,23 +1701,24 @@ func TestServerCertificate(t *testing.T) { ca, cb := dpipe.Pipe() type result struct { - c *Conn - err error + c *Conn + err, hserr error } srvCh := make(chan result) go func() { - s, err := Server(cb, tt.serverCfg) - srvCh <- result{s, err} + s, err := Server(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), tt.serverCfg) + srvCh <- result{s, err, s.Handshake()} }() - cli, err := Client(ca, tt.clientCfg) + cli, err := Client(dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), tt.clientCfg) + hserr := cli.Handshake() if err == nil { _ = cli.Close() } - if !tt.wantErr && err != nil { - t.Errorf("Client failed(%v)", err) + if !tt.wantErr && (err != nil || hserr != nil) { + t.Errorf("Client failed(%v, %v)", err, hserr) } - if tt.wantErr && err == nil { + if tt.wantErr && err == nil && hserr == nil { t.Fatal("Error expected") } @@ -1350,8 +1790,10 @@ func TestCipherSuiteConfiguration(t *testing.T) { WantSelectedCipherSuite: TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8, }, { - Name: "Server supports subset of client suites", - ClientCipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA}, + Name: "Server supports subset of client suites", + ClientCipherSuites: []CipherSuiteID{ + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, + }, ServerCipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA}, WantClientError: nil, WantServerError: nil, @@ -1368,33 +1810,46 @@ func TestCipherSuiteConfiguration(t *testing.T) { c *Conn err error } - c := make(chan result) + resultCh := make(chan result) go func() { - client, err := testClient(ctx, ca, &Config{CipherSuites: test.ClientCipherSuites}, true) - c <- result{client, err} + client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ + CipherSuites: test.ClientCipherSuites, + }, true) + resultCh <- result{client, err} }() - server, err := testServer(ctx, cb, &Config{CipherSuites: test.ServerCipherSuites}, true) + server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ + CipherSuites: test.ServerCipherSuites, + }, true) if err == nil { defer func() { _ = server.Close() }() } if !errors.Is(err, test.WantServerError) { - t.Errorf("TestCipherSuiteConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantServerError, err) + t.Errorf( + "TestCipherSuiteConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", + test.Name, test.WantServerError, err, + ) } - res := <-c + res := <-resultCh if res.err == nil { _ = server.Close() _ = res.c.Close() } if !errors.Is(res.err, test.WantClientError) { - t.Errorf("TestSRTPConfiguration: Client Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantClientError, res.err) + t.Errorf( + "TestSRTPConfiguration: Client Error Mismatch '%s': expected(%v) actual(%v)", + test.Name, test.WantClientError, res.err, + ) } if test.WantSelectedCipherSuite != 0x00 && res.c.state.cipherSuite.ID() != test.WantSelectedCipherSuite { - t.Errorf("TestCipherSuiteConfiguration: Server Selected Bad Cipher Suite '%s': expected(%v) actual(%v)", test.Name, test.WantSelectedCipherSuite, res.c.state.cipherSuite.ID()) + t.Errorf( + "TestCipherSuiteConfiguration: Server Selected Bad Cipher Suite '%s': expected(%v) actual(%v)", + test.Name, test.WantSelectedCipherSuite, res.c.state.cipherSuite.ID(), + ) } }) } @@ -1428,7 +1883,7 @@ func TestCertificateAndPSKServer(t *testing.T) { c *Conn err error } - c := make(chan result) + resultCh := make(chan result) go func() { config := &Config{CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}} @@ -1440,8 +1895,8 @@ func TestCertificateAndPSKServer(t *testing.T) { config.CipherSuites = []CipherSuiteID{TLS_PSK_WITH_AES_128_GCM_SHA256} } - client, err := testClient(ctx, ca, config, false) - c <- result{client, err} + client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), config, false) + resultCh <- result{client, err} }() config := &Config{ @@ -1451,7 +1906,7 @@ func TestCertificateAndPSKServer(t *testing.T) { }, } - server, err := testServer(ctx, cb, config, true) + server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, true) if err == nil { defer func() { _ = server.Close() @@ -1460,18 +1915,21 @@ func TestCertificateAndPSKServer(t *testing.T) { t.Errorf("TestCertificateAndPSKServer: Server Error Mismatch '%s': expected(%v) actual(%v)", test.Name, nil, err) } - res := <-c + res := <-resultCh if res.err == nil { _ = server.Close() _ = res.c.Close() } else { - t.Errorf("TestCertificateAndPSKServer: Client Error Mismatch '%s': expected(%v) actual(%v)", test.Name, nil, res.err) + t.Errorf( + "TestCertificateAndPSKServer: Client Error Mismatch '%s': expected(%v) actual(%v)", + test.Name, nil, res.err, + ) } }) } } -func TestPSKConfiguration(t *testing.T) { +func TestPSKConfiguration(t *testing.T) { //nolint:cyclop // Check for leaking routines report := test.CheckRoutines(t) defer report() @@ -1540,30 +1998,50 @@ func TestPSKConfiguration(t *testing.T) { c *Conn err error } - c := make(chan result) + resultCh := make(chan result) go func() { - client, err := testClient(ctx, ca, &Config{PSK: test.ClientPSK, PSKIdentityHint: test.ClientPSKIdentity}, test.ClientHasCertificate) - c <- result{client, err} + client, err := testClient( + ctx, + dtlsnet.PacketConnFromConn(ca), + ca.RemoteAddr(), + &Config{PSK: test.ClientPSK, PSKIdentityHint: test.ClientPSKIdentity}, + test.ClientHasCertificate, + ) + resultCh <- result{client, err} }() - _, err := testServer(ctx, cb, &Config{PSK: test.ServerPSK, PSKIdentityHint: test.ServerPSKIdentity}, test.ServerHasCertificate) + _, err := testServer( + ctx, + dtlsnet.PacketConnFromConn(cb), + cb.RemoteAddr(), + &Config{PSK: test.ServerPSK, PSKIdentityHint: test.ServerPSKIdentity}, + test.ServerHasCertificate, + ) if err != nil || test.WantServerError != nil { if !(err != nil && test.WantServerError != nil && err.Error() == test.WantServerError.Error()) { - t.Fatalf("TestPSKConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantServerError, err) + t.Fatalf( + "TestPSKConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", + test.Name, test.WantServerError, err, + ) } } - res := <-c + res := <-resultCh if res.err != nil || test.WantClientError != nil { if !(res.err != nil && test.WantClientError != nil && res.err.Error() == test.WantClientError.Error()) { - t.Fatalf("TestPSKConfiguration: Client Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantClientError, res.err) + t.Fatalf( + "TestPSKConfiguration: Client Error Mismatch '%s': expected(%v) actual(%v)", + test.Name, + test.WantClientError, + res.err, + ) } } } } -func TestServerTimeout(t *testing.T) { +func TestServerTimeout(t *testing.T) { //nolint:cyclop // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() @@ -1677,7 +2155,7 @@ func TestServerTimeout(t *testing.T) { FlightInterval: 100 * time.Millisecond, } - _, serverErr := testServer(ctx, cb, config, true) + _, serverErr := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, true) var netErr net.Error if !errors.As(serverErr, &netErr) || !netErr.Timeout() { t.Fatalf("Client error exp(Temporary network error) failed(%v)", serverErr) @@ -1692,7 +2170,7 @@ func TestServerTimeout(t *testing.T) { } } -func TestProtocolVersionValidation(t *testing.T) { +func TestProtocolVersionValidation(t *testing.T) { //nolint:cyclop,maintidx // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() @@ -1773,8 +2251,8 @@ func TestProtocolVersionValidation(t *testing.T) { }, }, } - for name, c := range serverCases { - c := c + for name, serverCase := range serverCases { + serverCase := serverCase t.Run(name, func(t *testing.T) { ca, cb := dpipe.Pipe() defer func() { @@ -1792,7 +2270,13 @@ func TestProtocolVersionValidation(t *testing.T) { defer wg.Wait() go func() { defer wg.Done() - if _, err := testServer(ctx, cb, config, true); !errors.Is(err, errUnsupportedProtocolVersion) { + if _, err := testServer( + ctx, + dtlsnet.PacketConnFromConn(cb), + cb.RemoteAddr(), + config, + true, + ); !errors.Is(err, errUnsupportedProtocolVersion) { t.Errorf("Client error exp(%v) failed(%v)", errUnsupportedProtocolVersion, err) } }() @@ -1800,7 +2284,7 @@ func TestProtocolVersionValidation(t *testing.T) { time.Sleep(50 * time.Millisecond) resp := make([]byte, 1024) - for _, record := range c.records { + for _, record := range serverCase.records { packet, err := record.Marshal() if err != nil { t.Fatal(err) @@ -1853,9 +2337,13 @@ func TestProtocolVersionValidation(t *testing.T) { MessageSequence: 1, }, Message: &handshake.MessageServerHello{ - Version: protocol.Version{Major: 0xfe, Minor: 0xff}, // try to downgrade - Random: random, - CipherSuiteID: func() *uint16 { id := uint16(TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256); return &id }(), + Version: protocol.Version{Major: 0xfe, Minor: 0xff}, // try to downgrade + Random: random, + CipherSuiteID: func() *uint16 { + id := uint16(TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256) + + return &id + }(), CompressionMethod: defaultCompressionMethods()[0], }, }, @@ -1863,8 +2351,8 @@ func TestProtocolVersionValidation(t *testing.T) { }, }, } - for name, c := range clientCases { - c := c + for name, clientCase := range clientCases { + clientCase := clientCase t.Run(name, func(t *testing.T) { ca, cb := dpipe.Pipe() defer func() { @@ -1882,14 +2370,16 @@ func TestProtocolVersionValidation(t *testing.T) { defer wg.Wait() go func() { defer wg.Done() - if _, err := testClient(ctx, cb, config, true); !errors.Is(err, errUnsupportedProtocolVersion) { + if _, err := testClient(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, true); !errors.Is( + err, errUnsupportedProtocolVersion, + ) { t.Errorf("Server error exp(%v) failed(%v)", errUnsupportedProtocolVersion, err) } }() time.Sleep(50 * time.Millisecond) - for _, record := range c.records { + for _, record := range clientCase.records { if _, err := ca.Read(make([]byte, 1024)); err != nil { t.Fatal(err) } @@ -1921,7 +2411,7 @@ func TestProtocolVersionValidation(t *testing.T) { }) } -func TestMultipleHelloVerifyRequest(t *testing.T) { +func TestMultipleHelloVerifyRequest(t *testing.T) { //nolint:cyclop // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() @@ -1949,7 +2439,7 @@ func TestMultipleHelloVerifyRequest(t *testing.T) { }, Content: &handshake.Handshake{ Header: handshake.Header{ - MessageSequence: uint16(i), + MessageSequence: uint16(i), //nolint:gosec // G115 }, Message: &handshake.MessageHelloVerifyRequest{ Version: protocol.Version1_2, @@ -1980,7 +2470,7 @@ func TestMultipleHelloVerifyRequest(t *testing.T) { defer wg.Wait() go func() { defer wg.Done() - _, _ = testClient(ctx, ca, &Config{}, false) + _, _ = testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{}, false) }() for i, cookie := range cookies { @@ -2014,8 +2504,8 @@ func TestMultipleHelloVerifyRequest(t *testing.T) { } // Assert that a DTLS Server always responds with RenegotiationInfo if -// a ClientHello contained that extension or not -func TestRenegotationInfo(t *testing.T) { +// a ClientHello contained that extension or not. +func TestRenegotationInfo(t *testing.T) { //nolint:cyclop // Limit runtime in case of deadlocks lim := test.TimeOut(10 * time.Second) defer lim.Stop() @@ -2052,7 +2542,13 @@ func TestRenegotationInfo(t *testing.T) { defer cancel() go func() { - if _, err := testServer(ctx, cb, &Config{}, true); !errors.Is(err, context.Canceled) { + if _, err := testServer( + ctx, + dtlsnet.PacketConnFromConn(cb), + cb.RemoteAddr(), + &Config{}, + true, + ); !errors.Is(err, context.Canceled) { t.Error(err) } }() @@ -2073,12 +2569,12 @@ func TestRenegotationInfo(t *testing.T) { if err != nil { t.Fatal(err) } - r := &recordlayer.RecordLayer{} - if err = r.Unmarshal(resp[:n]); err != nil { + record := &recordlayer.RecordLayer{} + if err = record.Unmarshal(resp[:n]); err != nil { t.Fatal(err) } - helloVerifyRequest, ok := r.Content.(*handshake.Handshake).Message.(*handshake.MessageHelloVerifyRequest) + helloVerifyRequest, ok := record.Content.(*handshake.Handshake).Message.(*handshake.MessageHelloVerifyRequest) if !ok { t.Fatal("Failed to cast MessageHelloVerifyRequest") } @@ -2096,11 +2592,11 @@ func TestRenegotationInfo(t *testing.T) { t.Fatal(err) } - if err := r.Unmarshal(messages[0]); err != nil { + if err := record.Unmarshal(messages[0]); err != nil { t.Fatal(err) } - serverHello, ok := r.Content.(*handshake.Handshake).Message.(*handshake.MessageServerHello) + serverHello, ok := record.Content.(*handshake.Handshake).Message.(*handshake.MessageServerHello) if !ok { t.Fatal("Failed to cast MessageServerHello") } @@ -2164,7 +2660,7 @@ func TestServerNameIndicationExtension(t *testing.T) { ServerName: test.ServerName, } - _, _ = testClient(ctx, ca, conf, false) + _, _ = testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), conf, false) }() // Receive ClientHello @@ -2208,7 +2704,7 @@ func TestServerNameIndicationExtension(t *testing.T) { } } -func TestALPNExtension(t *testing.T) { +func TestALPNExtension(t *testing.T) { //nolint:cyclop,maintidx // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() @@ -2282,7 +2778,7 @@ func TestALPNExtension(t *testing.T) { conf := &Config{ SupportedProtocols: test.ClientProtocolNameList, } - _, _ = testClient(ctx, ca, conf, false) + _, _ = testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), conf, false) }() // Receive ClientHello @@ -2300,7 +2796,9 @@ func TestALPNExtension(t *testing.T) { conf := &Config{ SupportedProtocols: test.ServerProtocolNameList, } - if _, err2 := testServer(ctx2, cb2, conf, true); !errors.Is(err2, context.Canceled) { + if _, err2 := testServer(ctx2, dtlsnet.PacketConnFromConn(cb2), cb2.RemoteAddr(), conf, true); !errors.Is( + err2, context.Canceled, + ) { if test.ExpectAlertFromServer { //nolint // Assert the error type? } else { @@ -2352,13 +2850,13 @@ func TestALPNExtension(t *testing.T) { t.Fatal(err) } - r := &recordlayer.RecordLayer{} - if err := r.Unmarshal(messages[0]); err != nil { + record := &recordlayer.RecordLayer{} + if err := record.Unmarshal(messages[0]); err != nil { t.Fatal(err) } - if test.ExpectAlertFromServer { - a, ok := r.Content.(*alert.Alert) + if test.ExpectAlertFromServer { //nolint:nestif + a, ok := record.Content.(*alert.Alert) if !ok { t.Fatal("Failed to cast alert.Alert") } @@ -2367,7 +2865,7 @@ func TestALPNExtension(t *testing.T) { t.Errorf("ALPN %v: expected(%v) actual(%v)", test.Name, test.Alert, a.Description) } } else { - serverHello, ok := r.Content.(*handshake.Handshake).Message.(*handshake.MessageServerHello) + serverHello, ok := record.Content.(*handshake.Handshake).Message.(*handshake.MessageServerHello) if !ok { t.Fatal("Failed to cast handshake.MessageServerHello") } @@ -2393,7 +2891,7 @@ func TestALPNExtension(t *testing.T) { t.Errorf("ALPN %v: expected(%v) actual(%v)", test.Name, test.ExpectedProtocol, negotiatedProtocol) } - s, err := r.Marshal() + s, err := record.Marshal() if err != nil { t.Fatal(err) } @@ -2431,8 +2929,8 @@ func TestALPNExtension(t *testing.T) { } } -// Make sure the supported_groups extension is not included in the ServerHello -func TestSupportedGroupsExtension(t *testing.T) { +// Make sure the supported_groups extension is not included in the ServerHello. +func TestSupportedGroupsExtension(t *testing.T) { //nolint:cyclop // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() @@ -2447,7 +2945,9 @@ func TestSupportedGroupsExtension(t *testing.T) { ca, cb := dpipe.Pipe() go func() { - if _, err := testServer(ctx, cb, &Config{}, true); !errors.Is(err, context.Canceled) { + if _, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{}, true); !errors.Is( + err, context.Canceled, + ) { t.Error(err) } }() @@ -2473,12 +2973,12 @@ func TestSupportedGroupsExtension(t *testing.T) { if err != nil { t.Fatal(err) } - r := &recordlayer.RecordLayer{} - if err = r.Unmarshal(resp[:n]); err != nil { + record := &recordlayer.RecordLayer{} + if err = record.Unmarshal(resp[:n]); err != nil { t.Fatal(err) } - helloVerifyRequest, ok := r.Content.(*handshake.Handshake).Message.(*handshake.MessageHelloVerifyRequest) + helloVerifyRequest, ok := record.Content.(*handshake.Handshake).Message.(*handshake.MessageHelloVerifyRequest) if !ok { t.Fatal("Failed to cast MessageHelloVerifyRequest") } @@ -2496,11 +2996,11 @@ func TestSupportedGroupsExtension(t *testing.T) { t.Fatal(err) } - if err := r.Unmarshal(messages[0]); err != nil { + if err := record.Unmarshal(messages[0]); err != nil { t.Fatal(err) } - serverHello, ok := r.Content.(*handshake.Handshake).Message.(*handshake.MessageServerHello) + serverHello, ok := record.Content.(*handshake.Handshake).Message.(*handshake.MessageServerHello) if !ok { t.Fatal("Failed to cast MessageServerHello") } @@ -2518,7 +3018,7 @@ func TestSupportedGroupsExtension(t *testing.T) { }) } -func TestSessionResume(t *testing.T) { +func TestSessionResume(t *testing.T) { //nolint:cyclop // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() @@ -2540,7 +3040,9 @@ func TestSessionResume(t *testing.T) { ss := &memSessStore{} id, _ := hex.DecodeString("9b9fc92255634d9fb109febed42166717bb8ded8c738ba71bc7f2a0d9dae0306") - secret, _ := hex.DecodeString("2e942a37aca5241deb2295b5fcedac221c7078d2503d2b62aeb48c880d7da73c001238b708559686b9da6e829c05ead7") + secret, _ := hex.DecodeString( + "2e942a37aca5241deb2295b5fcedac221c7078d2503d2b62aeb48c880d7da73c001238b708559686b9da6e829c05ead7", + ) s := Session{ID: id, Secret: secret} @@ -2556,7 +3058,7 @@ func TestSessionResume(t *testing.T) { SessionStore: ss, MTU: 100, } - c, err := testClient(ctx, ca, config, false) + c, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), config, false) clientRes <- result{c, err} }() @@ -2566,13 +3068,17 @@ func TestSessionResume(t *testing.T) { SessionStore: ss, MTU: 100, } - server, err := testServer(ctx, cb, config, true) + server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, true) if err != nil { t.Fatalf("TestSessionResume: Server failed(%v)", err) } - actualSessionID := server.ConnectionState().SessionID - actualMasterSecret := server.ConnectionState().masterSecret + state, ok := server.ConnectionState() + if !ok { + t.Fatal("TestSessionResume: ConnectionState failed") + } + actualSessionID := state.SessionID + actualMasterSecret := state.masterSecret if !bytes.Equal(actualSessionID, id) { t.Errorf("TestSessionResumetion: SessionID Mismatch: expected(%v) actual(%v)", id, actualSessionID) } @@ -2610,20 +3116,24 @@ func TestSessionResume(t *testing.T) { ServerName: "example.com", SessionStore: s1, } - c, err := testClient(ctx, ca, config, false) + c, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), config, false) clientRes <- result{c, err} }() config := &Config{ SessionStore: s2, } - server, err := testServer(ctx, cb, config, true) + server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, true) if err != nil { t.Fatalf("TestSessionResumetion: Server failed(%v)", err) } - actualSessionID := server.ConnectionState().SessionID - actualMasterSecret := server.ConnectionState().masterSecret + state, ok := server.ConnectionState() + if !ok { + t.Fatal("TestSessionResumetion: ConnectionState failed") + } + actualSessionID := state.SessionID + actualMasterSecret := state.masterSecret ss, _ := s2.Get(actualSessionID) if !bytes.Equal(actualMasterSecret, ss.Secret) { t.Errorf("TestSessionResumetion: masterSecret Mismatch: expected(%v) actual(%v)", ss.Secret, actualMasterSecret) @@ -2681,7 +3191,8 @@ func (ms *memSessStore) Del(key []byte) error { // Assert that the server only uses CipherSuites with a hash+signature that matches // the certificate. As specified in rfc5246#section-7.4.3 -func TestCipherSuiteMatchesCertificateType(t *testing.T) { +// . +func TestCipherSuiteMatchesCertificateType(t *testing.T) { //nolint:cyclop // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() @@ -2715,32 +3226,34 @@ func TestCipherSuiteMatchesCertificateType(t *testing.T) { ca, cb := dpipe.Pipe() go func() { - c, err := testClient(context.TODO(), ca, &Config{CipherSuites: test.cipherList}, false) + c, err := testClient(context.TODO(), dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ + CipherSuites: test.cipherList, + }, false) clientErr <- err client <- c }() var ( - priv crypto.PrivateKey - err error + signer crypto.Signer + err error ) if test.generateRSA { - if priv, err = rsa.GenerateKey(rand.Reader, 2048); err != nil { + if signer, err = rsa.GenerateKey(rand.Reader, 2048); err != nil { t.Fatal(err) } } else { - if priv, err = ecdsa.GenerateKey(cryptoElliptic.P256(), rand.Reader); err != nil { + if signer, err = ecdsa.GenerateKey(cryptoElliptic.P256(), rand.Reader); err != nil { t.Fatal(err) } } - serverCert, err := selfsign.SelfSign(priv) + serverCert, err := selfsign.SelfSign(signer) if err != nil { t.Fatal(err) } - if s, err := testServer(context.TODO(), cb, &Config{ + if s, err := testServer(context.TODO(), dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ CipherSuites: test.cipherList, Certificates: []tls.Certificate{serverCert}, }, false); err != nil { @@ -2753,15 +3266,15 @@ func TestCipherSuiteMatchesCertificateType(t *testing.T) { t.Fatal(err) } else if err := c.Close(); err != nil { t.Fatal(err) - } else if c.ConnectionState().cipherSuite.ID() != test.expectedCipher { - t.Fatalf("Expected(%s) and Actual(%s) CipherSuite do not match", test.expectedCipher, c.ConnectionState().cipherSuite.ID()) + } else if state, ok := c.ConnectionState(); !ok || state.cipherSuite.ID() != test.expectedCipher { + t.Fatalf("Expected(%s) and Actual(%s) CipherSuite do not match", test.expectedCipher, state.cipherSuite.ID()) } }) } } -// Test that we return the proper certificate if we are serving multiple ServerNames on a single Server -func TestMultipleServerCertificates(t *testing.T) { +// Test that we return the proper certificate if we are serving multiple ServerNames on a single Server. +func TestMultipleServerCertificates(t *testing.T) { //nolint:cyclop fooCert, err := selfsign.GenerateSelfSignedWithDNS("foo") if err != nil { t.Fatal(err) @@ -2805,10 +3318,10 @@ func TestMultipleServerCertificates(t *testing.T) { ca, cb := dpipe.Pipe() go func() { - c, err := testClient(context.TODO(), ca, &Config{ + clientConn, err := testClient(context.TODO(), dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ RootCAs: caPool, ServerName: test.RequestServerName, - VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { + VerifyPeerCertificate: func(rawCerts [][]byte, _ [][]*x509.Certificate) error { certificate, err := x509.ParseCertificate(rawCerts[0]) if err != nil { return err @@ -2822,10 +3335,12 @@ func TestMultipleServerCertificates(t *testing.T) { }, }, false) clientErr <- err - client <- c + client <- clientConn }() - if s, err := testServer(context.TODO(), cb, &Config{Certificates: []tls.Certificate{fooCert, barCert}}, false); err != nil { + if s, err := testServer(context.TODO(), dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ + Certificates: []tls.Certificate{fooCert, barCert}, + }, false); err != nil { t.Fatal(err) } else if err = s.Close(); err != nil { t.Fatal(err) @@ -2840,7 +3355,7 @@ func TestMultipleServerCertificates(t *testing.T) { } } -func TestEllipticCurveConfiguration(t *testing.T) { +func TestEllipticCurveConfiguration(t *testing.T) { //nolint:cyclop // Check for leaking routines report := test.CheckRoutines(t) defer report() @@ -2848,22 +3363,22 @@ func TestEllipticCurveConfiguration(t *testing.T) { for _, test := range []struct { Name string ConfigCurves []elliptic.Curve - HadnshakeCurves []elliptic.Curve + HandshakeCurves []elliptic.Curve }{ { Name: "Curve defaulting", ConfigCurves: nil, - HadnshakeCurves: defaultCurves, + HandshakeCurves: defaultCurves, }, { Name: "Single curve", ConfigCurves: []elliptic.Curve{elliptic.X25519}, - HadnshakeCurves: []elliptic.Curve{elliptic.X25519}, + HandshakeCurves: []elliptic.Curve{elliptic.X25519}, }, { Name: "Multiple curves", ConfigCurves: []elliptic.Curve{elliptic.P384, elliptic.X25519}, - HadnshakeCurves: []elliptic.Curve{elliptic.P384, elliptic.X25519}, + HandshakeCurves: []elliptic.Curve{elliptic.P384, elliptic.X25519}, }, } { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -2874,25 +3389,39 @@ func TestEllipticCurveConfiguration(t *testing.T) { c *Conn err error } - c := make(chan result) + resultCh := make(chan result) go func() { - client, err := testClient(ctx, ca, &Config{CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, EllipticCurves: test.ConfigCurves}, true) - c <- result{client, err} + client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ + CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, + EllipticCurves: test.ConfigCurves, + }, true) + resultCh <- result{client, err} }() - server, err := testServer(ctx, cb, &Config{CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, EllipticCurves: test.ConfigCurves}, true) + server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ + CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, + EllipticCurves: test.ConfigCurves, + }, true) if err != nil { t.Fatalf("Server error: %v", err) } - if len(test.ConfigCurves) == 0 && len(test.HadnshakeCurves) != len(server.fsm.cfg.ellipticCurves) { - t.Fatalf("Failed to default Elliptic curves, expected %d, got: %d", len(test.HadnshakeCurves), len(server.fsm.cfg.ellipticCurves)) + if len(test.ConfigCurves) == 0 && len(test.HandshakeCurves) != len(server.fsm.cfg.ellipticCurves) { + t.Fatalf( + "Failed to default Elliptic curves, expected %d, got: %d", + len(test.HandshakeCurves), + len(server.fsm.cfg.ellipticCurves), + ) } if len(test.ConfigCurves) != 0 { - if len(test.HadnshakeCurves) != len(server.fsm.cfg.ellipticCurves) { - t.Fatalf("Failed to configure Elliptic curves, expect %d, got %d", len(test.HadnshakeCurves), len(server.fsm.cfg.ellipticCurves)) + if len(test.HandshakeCurves) != len(server.fsm.cfg.ellipticCurves) { + t.Fatalf( + "Failed to configure Elliptic curves, expect %d, got %d", + len(test.HandshakeCurves), + len(server.fsm.cfg.ellipticCurves), + ) } for i, c := range test.ConfigCurves { if c != server.fsm.cfg.ellipticCurves[i] { @@ -2901,7 +3430,7 @@ func TestEllipticCurveConfiguration(t *testing.T) { } } - res := <-c + res := <-resultCh if res.err != nil { t.Fatalf("Client error; %v", err) } @@ -2933,17 +3462,18 @@ func TestSkipHelloVerify(t *testing.T) { gotHello := make(chan struct{}) go func() { - server, sErr := testServer(ctx, cb, &Config{ + server, sErr := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ Certificates: []tls.Certificate{certificate}, LoggerFactory: logging.NewDefaultLoggerFactory(), InsecureSkipVerifyHello: true, }, false) if sErr != nil { t.Error(sErr) + return } buf := make([]byte, 1024) - if _, sErr = server.Read(buf); sErr != nil { + if _, sErr = server.Read(buf); sErr != nil { //nolint:contextcheck t.Error(sErr) } gotHello <- struct{}{} @@ -2952,7 +3482,7 @@ func TestSkipHelloVerify(t *testing.T) { } }() - client, err := testClient(ctx, ca, &Config{ + client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ LoggerFactory: logging.NewDefaultLoggerFactory(), InsecureSkipVerify: true, }, false) @@ -2973,3 +3503,399 @@ func TestSkipHelloVerify(t *testing.T) { t.Error(err) } } + +type connWithCallback struct { + net.Conn + onWrite func([]byte) +} + +func (c *connWithCallback) Write(b []byte) (int, error) { + if c.onWrite != nil { + c.onWrite(b) + } + + return c.Conn.Write(b) +} + +func TestApplicationDataQueueLimited(t *testing.T) { + // Limit runtime in case of deadlocks + lim := test.TimeOut(time.Second * 20) + defer lim.Stop() + + // Check for leaking routines + report := test.CheckRoutines(t) + defer report() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + ca, cb := dpipe.Pipe() + defer ca.Close() //nolint:errcheck + defer cb.Close() //nolint:errcheck + + done := make(chan struct{}) + go func() { + serverCert, err := selfsign.GenerateSelfSigned() + if err != nil { + t.Error(err) + + return + } + cfg := &Config{} + cfg.Certificates = []tls.Certificate{serverCert} + + dconn, err := createConn(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), cfg, false, nil) + if err != nil { + t.Error(err) + + return + } + go func() { + for i := 0; i < 5; i++ { + dconn.lock.RLock() + qlen := len(dconn.encryptedPackets) + dconn.lock.RUnlock() + if qlen > maxAppDataPacketQueueSize { + t.Error("too many encrypted packets enqueued", len(dconn.encryptedPackets)) + } + time.Sleep(1 * time.Second) + } + }() + if err := dconn.HandshakeContext(ctx); err == nil { + t.Error("expected handshake to fail") + } + close(done) + }() + extensions := []extension.Extension{} + + time.Sleep(50 * time.Millisecond) + + err := sendClientHello([]byte{}, ca, 0, extensions) + if err != nil { + t.Fatal(err) + } + + time.Sleep(50 * time.Millisecond) + + for i := 0; i < 1000; i++ { + // Send an application data packet + packet, err := (&recordlayer.RecordLayer{ + Header: recordlayer.Header{ + Version: protocol.Version1_2, + SequenceNumber: uint64(3), + Epoch: 1, // use an epoch greater than 0 + }, + Content: &protocol.ApplicationData{ + Data: []byte{1, 2, 3, 4}, + }, + }).Marshal() + if err != nil { + t.Fatal(err) + } + ca.Write(packet) // nolint + if i%100 == 0 { + time.Sleep(10 * time.Millisecond) + } + } + time.Sleep(1 * time.Second) + ca.Close() // nolint + <-done +} + +func TestHelloRandom(t *testing.T) { //nolint:cyclop + report := test.CheckRoutines(t) + defer report() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + ca, cb := dpipe.Pipe() + certificate, err := selfsign.GenerateSelfSigned() + if err != nil { + t.Fatal(err) + } + gotHello := make(chan struct{}) + + chRandom := [handshake.RandomBytesLength]byte{} + _, err = rand.Read(chRandom[:]) + if err != nil { + t.Fatal(err) + } + + go func() { + server, sErr := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ + GetCertificate: func(chi *ClientHelloInfo) (*tls.Certificate, error) { + if len(chi.CipherSuites) == 0 { + return &certificate, nil + } + + if !bytes.Equal(chi.RandomBytes[:], chRandom[:]) { + t.Error("client hello random differs") + } + + return &certificate, nil + }, + LoggerFactory: logging.NewDefaultLoggerFactory(), + }, false) + if sErr != nil { + t.Error(sErr) + + return + } + buf := make([]byte, 1024) + if _, sErr = server.Read(buf); sErr != nil { //nolint:contextcheck + t.Error(sErr) + } + gotHello <- struct{}{} + if sErr = server.Close(); sErr != nil { //nolint:contextcheck + t.Error(sErr) + } + }() + + client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ + LoggerFactory: logging.NewDefaultLoggerFactory(), + HelloRandomBytesGenerator: func() [handshake.RandomBytesLength]byte { + return chRandom + }, + InsecureSkipVerify: true, + }, false) + if err != nil { + t.Fatal(err) + } + if _, err = client.Write([]byte("hello")); err != nil { + t.Error(err) + } + select { + case <-gotHello: + // OK + case <-time.After(time.Second * 5): + t.Error("timeout") + } + + if err = client.Close(); err != nil { + t.Error(err) + } +} + +func TestOnConnectionAttempt(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*20) + defer cancel() + + var clientOnConnectionAttempt, serverOnConnectionAttempt atomic.Int32 + + ca, cb := dpipe.Pipe() + clientErr := make(chan error, 1) + go func() { + _, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ + OnConnectionAttempt: func(in net.Addr) error { + clientOnConnectionAttempt.Store(1) + if in == nil { + t.Fatal("net.Addr is nil") //nolint: govet + } + + return nil + }, + }, true) + clientErr <- err + }() + + expectedErr := &FatalError{} + if _, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ + OnConnectionAttempt: func(in net.Addr) error { + serverOnConnectionAttempt.Store(1) + if in == nil { + t.Fatal("net.Addr is nil") //nolint: govet + } + + return expectedErr + }, + }, true); !errors.Is(err, expectedErr) { + t.Fatal(err) + } + + if err := <-clientErr; err == nil { + t.Fatal(err) + } + + if v := serverOnConnectionAttempt.Load(); v != 1 { + t.Fatal("OnConnectionAttempt did not fire for server") + } + + if v := clientOnConnectionAttempt.Load(); v != 0 { + t.Fatal("OnConnectionAttempt fired for client") + } +} + +func TestFragmentBuffer_Retransmission(t *testing.T) { + fragmentBuffer := newFragmentBuffer() + frag := []byte{ + 0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x30, 0x03, 0x00, + 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0xfe, 0xff, 0x01, 0x01, + } + + if _, isRetransmission, err := fragmentBuffer.push(frag); err != nil { + t.Fatal(err) + } else if isRetransmission { + t.Fatal("fragment should not be retransmission") + } + + if v, _ := fragmentBuffer.pop(); v == nil { + t.Fatal("Failed to pop fragment") + } + + if _, isRetransmission, err := fragmentBuffer.push(frag); err != nil { + t.Fatal(err) + } else if !isRetransmission { + t.Fatal("fragment should be retransmission") + } +} + +func TestConnectionState(t *testing.T) { + ca, cb := dpipe.Pipe() + + // Setup client + clientCfg := &Config{} + clientCert, err := selfsign.GenerateSelfSigned() + if err != nil { + t.Fatal(err) + } + clientCfg.Certificates = []tls.Certificate{clientCert} + clientCfg.InsecureSkipVerify = true + client, err := Client(dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), clientCfg) + if err != nil { + t.Fatal(err) + } + defer func() { + _ = client.Close() + }() + + _, ok := client.ConnectionState() + if ok { + t.Fatal("ConnectionState should be nil") + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + errorChannel := make(chan error) + go func() { + errC := client.HandshakeContext(ctx) + errorChannel <- errC + }() + + // Setup server + server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{}, true) + if err != nil { + t.Fatal(err) + } + defer func() { + _ = server.Close() + }() + + err = <-errorChannel + if err != nil { + t.Fatal(err) + } + + _, ok = client.ConnectionState() + if !ok { + t.Fatal("ConnectionState should not be nil") + } +} + +func TestMultiHandshake(t *testing.T) { + defer test.CheckRoutines(t)() + defer test.TimeOut(time.Second * 10).Stop() + + ca, cb := dpipe.Pipe() + serverCert, err := selfsign.GenerateSelfSigned() + if err != nil { + t.Fatal(err) + } + server, err := Server(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ + Certificates: []tls.Certificate{serverCert}, + }) + if err != nil { + t.Fatal(err) + } + + go func() { + _ = server.Handshake() + }() + + clientCert, err := selfsign.GenerateSelfSigned() + if err != nil { + t.Fatal(err) + } + client, err := Client(dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ + Certificates: []tls.Certificate{clientCert}, + }) + if err != nil { + t.Fatal(err) + } + + if err = client.Handshake(); err == nil { + t.Fatal(err) + } + + if err = client.Handshake(); err == nil { + t.Fatal(err) + } + + if err = server.Close(); err != nil { + t.Fatal(err) + } + + if err = client.Close(); err != nil { + t.Fatal(err) + } +} + +func TestCloseDuringHandshake(t *testing.T) { + defer test.CheckRoutines(t)() + defer test.TimeOut(time.Second * 10).Stop() + + serverCert, err := selfsign.GenerateSelfSigned() + if err != nil { + t.Fatal(err) + } + + for i := 0; i < 100; i++ { + _, cb := dpipe.Pipe() + server, err := Server(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ + Certificates: []tls.Certificate{serverCert}, + }) + if err != nil { + t.Fatal(err) + } + + waitChan := make(chan struct{}) + go func() { + close(waitChan) + _ = server.Handshake() + }() + + <-waitChan + if err = server.Close(); err != nil { + t.Fatal(err) + } + } +} + +func TestCloseWithoutHandshake(t *testing.T) { + defer test.CheckRoutines(t)() + defer test.TimeOut(time.Second * 10).Stop() + + serverCert, err := selfsign.GenerateSelfSigned() + if err != nil { + t.Fatal(err) + } + _, cb := dpipe.Pipe() + server, err := Server(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ + Certificates: []tls.Certificate{serverCert}, + }) + if err != nil { + t.Fatal(err) + } + if err = server.Close(); err != nil { + t.Fatal(err) + } +} diff --git a/connection_id.go b/connection_id.go new file mode 100644 index 000000000..c590499b4 --- /dev/null +++ b/connection_id.go @@ -0,0 +1,105 @@ +// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> +// SPDX-License-Identifier: MIT + +package dtls + +import ( + "crypto/rand" + + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/extension" + "github.com/pion/dtls/v3/pkg/protocol/handshake" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" +) + +// RandomCIDGenerator is a random Connection ID generator where CID is the +// specified size. Specifying a size of 0 will indicate to peers that sending a +// Connection ID is not necessary. +func RandomCIDGenerator(size int) func() []byte { + return func() []byte { + cid := make([]byte, size) + if _, err := rand.Read(cid); err != nil { + panic(err) //nolint -- nonrecoverable + } + + return cid + } +} + +// OnlySendCIDGenerator enables sending Connection IDs negotiated with a peer, +// but indicates to the peer that sending Connection IDs in return is not +// necessary. +func OnlySendCIDGenerator() func() []byte { + return func() []byte { + return nil + } +} + +// cidDatagramRouter extracts connection IDs from incoming datagram payloads and +// uses them to route to the proper connection. +// NOTE: properly routing datagrams based on connection IDs requires using +// constant size connection IDs. +func cidDatagramRouter(size int) func([]byte) (string, bool) { + return func(packet []byte) (string, bool) { + pkts, err := recordlayer.ContentAwareUnpackDatagram(packet, size) + if err != nil || len(pkts) < 1 { + return "", false + } + for _, pkt := range pkts { + h := &recordlayer.Header{ + ConnectionID: make([]byte, size), + } + if err := h.Unmarshal(pkt); err != nil { + continue + } + if h.ContentType != protocol.ContentTypeConnectionID { + continue + } + + return string(h.ConnectionID), true + } + + return "", false + } +} + +// cidConnIdentifier extracts connection IDs from outgoing ServerHello records +// and associates them with the associated connection. +// NOTE: a ServerHello should always be the first record in a datagram if +// multiple are present, so we avoid iterating through all packets if the first +// is not a ServerHello. +func cidConnIdentifier() func([]byte) (string, bool) { //nolint:cyclop + return func(packet []byte) (string, bool) { + pkts, err := recordlayer.UnpackDatagram(packet) + if err != nil || len(pkts) < 1 { + return "", false + } + var h recordlayer.Header + if hErr := h.Unmarshal(pkts[0]); hErr != nil { + return "", false + } + if h.ContentType != protocol.ContentTypeHandshake { + return "", false + } + var hh handshake.Header + var sh handshake.MessageServerHello + for _, pkt := range pkts { + if hhErr := hh.Unmarshal(pkt[recordlayer.FixedHeaderSize:]); hhErr != nil { + continue + } + if err = sh.Unmarshal(pkt[recordlayer.FixedHeaderSize+handshake.HeaderLength:]); err == nil { + break + } + } + if err != nil { + return "", false + } + for _, ext := range sh.Extensions { + if e, ok := ext.(*extension.ConnectionID); ok { + return string(e.CID), true + } + } + + return "", false + } +} diff --git a/connection_id_test.go b/connection_id_test.go new file mode 100644 index 000000000..aba5f72d0 --- /dev/null +++ b/connection_id_test.go @@ -0,0 +1,290 @@ +// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> +// SPDX-License-Identifier: MIT + +package dtls + +import ( + "testing" + "time" + + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/extension" + "github.com/pion/dtls/v3/pkg/protocol/handshake" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" +) + +func TestRandomConnectionIDGenerator(t *testing.T) { + cases := map[string]struct { + reason string + size int + }{ + "LengthMatch": { + reason: "Zero size should match length of generated CID.", + size: 0, + }, + "LengthMatchSome": { + reason: "Non-zero size should match length of generated CID with non-zero.", + size: 8, + }, + } + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + if cidLen := len(RandomCIDGenerator(tc.size)()); cidLen != tc.size { + t.Errorf("%s\nRandomCIDGenerator: expected CID length %d, but got %d.", tc.reason, tc.size, cidLen) + } + }) + } +} + +func TestOnlySendCIDGenerator(t *testing.T) { + cases := map[string]struct { + reason string + }{ + "LengthMatch": { + reason: "CID length should always be zero.", + }, + } + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + if cidLen := len(OnlySendCIDGenerator()()); cidLen != 0 { + t.Errorf("%s\nOnlySendCIDGenerator: expected CID length %d, but got %d.", tc.reason, 0, cidLen) + } + }) + } +} + +func TestCIDDatagramRouter(t *testing.T) { + cid := []byte("abcd1234") + cidLen := 8 + appRecord, err := (&recordlayer.RecordLayer{ + Header: recordlayer.Header{ + Epoch: 1, + Version: protocol.Version1_2, + }, + Content: &protocol.ApplicationData{ + Data: []byte("application data"), + }, + }).Marshal() + if err != nil { + t.Fatal(err) + } + appData, err := (&protocol.ApplicationData{ + Data: []byte("some data"), + }).Marshal() + if err != nil { + t.Fatal(err) + } + inner, err := (&recordlayer.InnerPlaintext{ + Content: appData, + RealType: protocol.ContentTypeApplicationData, + }).Marshal() + if err != nil { + t.Fatal(err) + } + cidHeader, err := (&recordlayer.Header{ + Epoch: 1, + Version: protocol.Version1_2, + ContentType: protocol.ContentTypeConnectionID, + ContentLen: uint16(len(inner)), //nolint:gosec // G115 + ConnectionID: cid, + SequenceNumber: 1, + }).Marshal() + if err != nil { + t.Fatal(err) + } + cases := map[string]struct { + reason string + size int + datagram []byte + ok bool + want string + }{ + "EmptyDatagram": { + reason: "If datagram is empty, we cannot extract an identifier", + size: cidLen, + datagram: []byte{}, + ok: false, + want: "", + }, + "NotADTLSRecord": { + reason: "If datagram is not a DTLS record, we cannot extract an identifier", + size: cidLen, + datagram: []byte("not a DTLS record"), + ok: false, + want: "", + }, + "NotAConnectionIDDatagram": { + reason: "If datagram does not contain any Connection ID records, we cannot extract an identifier", + size: cidLen, + datagram: appRecord, + ok: false, + want: "", + }, + "OneRecordConnectionID": { + reason: "If datagram contains one Connection ID record, we should be able to extract it.", + size: cidLen, + datagram: append(cidHeader, inner...), + ok: true, + want: string(cid), + }, + "OneRecordConnectionIDAltLength": { + //nolint:lll + reason: "If datagram contains one Connection ID record, but it has the wrong length we should not be able to extract it.", + size: cidLen, + datagram: func() []byte { + altCIDHeader, err := (&recordlayer.Header{ + Epoch: 1, + Version: protocol.Version1_2, + ContentType: protocol.ContentTypeConnectionID, + ContentLen: uint16(len(inner)), //nolint:gosec // G115 + ConnectionID: []byte("abcd"), + SequenceNumber: 1, + }).Marshal() + if err != nil { + t.Fatal(err) + } + + return append(altCIDHeader, inner...) + }(), + ok: false, + want: "", + }, + "MultipleRecordOneConnectionID": { + //nolint:lll + reason: "If datagram contains multiple records and one is a Connection ID record, we should be able to extract it.", + size: 8, + datagram: append(append(appRecord, cidHeader...), inner...), + ok: true, + want: string(cid), + }, + "MultipleRecordMultipleConnectionID": { + //nolint:lll + reason: "If datagram contains multiple records and multiple are Connection ID records, we should extract the first one.", + size: 8, + datagram: append(append(append(appRecord, func() []byte { + altCIDHeader, err := (&recordlayer.Header{ + Epoch: 1, + Version: protocol.Version1_2, + ContentType: protocol.ContentTypeConnectionID, + ContentLen: uint16(len(inner)), //nolint:gosec // G115 + ConnectionID: []byte("1234abcd"), + SequenceNumber: 1, + }).Marshal() + if err != nil { + t.Fatal(err) + } + + return append(altCIDHeader, inner...) + }()...), cidHeader...), inner...), + ok: true, + want: "1234abcd", + }, + } + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + cid, ok := cidDatagramRouter(tc.size)(tc.datagram) + if ok != tc.ok { + t.Errorf("%s\ncidDatagramRouter: expected ok %t, but got %t.", tc.reason, tc.ok, ok) + } + if cid != tc.want { + t.Errorf("%s\ncidDatagramRouter: expected CID %s, but got %s.", tc.reason, tc.want, cid) + } + }) + } +} + +func TestCIDConnIdentifier(t *testing.T) { + cid := []byte("abcd1234") + cs := uint16(TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256) + sh, err := (&recordlayer.RecordLayer{ + Header: recordlayer.Header{ + Epoch: 0, + Version: protocol.Version1_2, + }, + Content: &handshake.Handshake{ + Message: &handshake.MessageServerHello{ + Version: protocol.Version1_2, + Random: handshake.Random{GMTUnixTime: time.Unix(500, 0), RandomBytes: [28]byte{}}, + SessionID: []byte("hello"), + CipherSuiteID: &cs, + CompressionMethod: defaultCompressionMethods()[0], + Extensions: []extension.Extension{ + &extension.ConnectionID{ + CID: cid, + }, + }, + }, + }, + }).Marshal() + if err != nil { + t.Fatal(err) + } + appRecord, err := (&recordlayer.RecordLayer{ + Header: recordlayer.Header{ + Epoch: 1, + Version: protocol.Version1_2, + }, + Content: &protocol.ApplicationData{ + Data: []byte("application data"), + }, + }).Marshal() + if err != nil { + t.Fatal(err) + } + cases := map[string]struct { + reason string + datagram []byte + ok bool + want string + }{ + "EmptyDatagram": { + reason: "If datagram is empty, we cannot extract an identifier", + datagram: []byte{}, + ok: false, + want: "", + }, + "NotADTLSRecord": { + reason: "If datagram is not a DTLS record, we cannot extract an identifier", + datagram: []byte("not a DTLS record"), + ok: false, + want: "", + }, + "NotAServerhelloDatagram": { + reason: "If datagram does not contain any ServerHello record, we cannot extract an identifier", + datagram: appRecord, + ok: false, + want: "", + }, + "OneRecordServerHello": { + reason: "If datagram contains one ServerHello record, we should be able to extract an identifier.", + datagram: sh, + ok: true, + want: string(cid), + }, + "MultipleRecordFirstServerHello": { + //nolint:lll + reason: "If datagram contains multiple records and the first is a ServerHello record, we should be able to extract an identifier.", + datagram: append(sh, appRecord...), + ok: true, + want: string(cid), + }, + "MultipleRecordNotFirstServerHello": { + //nolint:lll + reason: "If datagram contains multiple records and the first is not a ServerHello record, we should not be able to extract an identifier.", + datagram: append(appRecord, sh...), + ok: false, + want: "", + }, + } + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + cid, ok := cidConnIdentifier()(tc.datagram) + if ok != tc.ok { + t.Errorf("%s\ncidConnIdentifier: expected ok %t, but got %t.", tc.reason, tc.ok, ok) + } + if cid != tc.want { + t.Errorf("%s\ncidConnIdentifier: expected CID %s, but got %s.", tc.reason, tc.want, cid) + } + }) + } +} diff --git a/crypto.go b/crypto.go index 968910c7e..dae47731c 100644 --- a/crypto.go +++ b/crypto.go @@ -9,15 +9,14 @@ import ( "crypto/ed25519" "crypto/rand" "crypto/rsa" - "crypto/sha256" "crypto/x509" "encoding/asn1" "encoding/binary" "math/big" "time" - "github.com/pion/dtls/v2/pkg/crypto/elliptic" - "github.com/pion/dtls/v2/pkg/crypto/hash" + "github.com/pion/dtls/v3/pkg/crypto/elliptic" + "github.com/pion/dtls/v3/pkg/crypto/hash" ) type ecdsaSignature struct { @@ -44,24 +43,36 @@ func valueKeyMessage(clientRandom, serverRandom, publicKey []byte, namedCurve el // hash/signature algorithm pair that appears in that extension // // https://tools.ietf.org/html/rfc5246#section-7.4.2 -func generateKeySignature(clientRandom, serverRandom, publicKey []byte, namedCurve elliptic.Curve, privateKey crypto.PrivateKey, hashAlgorithm hash.Algorithm) ([]byte, error) { +func generateKeySignature( + clientRandom, serverRandom, publicKey []byte, + namedCurve elliptic.Curve, + signer crypto.Signer, + hashAlgorithm hash.Algorithm, +) ([]byte, error) { msg := valueKeyMessage(clientRandom, serverRandom, publicKey, namedCurve) - switch p := privateKey.(type) { - case ed25519.PrivateKey: + switch signer.Public().(type) { + case ed25519.PublicKey: // https://crypto.stackexchange.com/a/55483 - return p.Sign(rand.Reader, msg, crypto.Hash(0)) - case *ecdsa.PrivateKey: + return signer.Sign(rand.Reader, msg, crypto.Hash(0)) + case *ecdsa.PublicKey: hashed := hashAlgorithm.Digest(msg) - return p.Sign(rand.Reader, hashed, hashAlgorithm.CryptoHash()) - case *rsa.PrivateKey: + + return signer.Sign(rand.Reader, hashed, hashAlgorithm.CryptoHash()) + case *rsa.PublicKey: hashed := hashAlgorithm.Digest(msg) - return p.Sign(rand.Reader, hashed, hashAlgorithm.CryptoHash()) + + return signer.Sign(rand.Reader, hashed, hashAlgorithm.CryptoHash()) } return nil, errKeySignatureGenerateUnimplemented } -func verifyKeySignature(message, remoteKeySignature []byte, hashAlgorithm hash.Algorithm, rawCertificates [][]byte) error { //nolint:dupl +//nolint:dupl,cyclop +func verifyKeySignature( + message, remoteKeySignature []byte, + hashAlgorithm hash.Algorithm, + rawCertificates [][]byte, +) error { if len(rawCertificates) == 0 { return errLengthMismatch } @@ -70,11 +81,12 @@ func verifyKeySignature(message, remoteKeySignature []byte, hashAlgorithm hash.A return err } - switch p := certificate.PublicKey.(type) { + switch pubKey := certificate.PublicKey.(type) { case ed25519.PublicKey: - if ok := ed25519.Verify(p, message, remoteKeySignature); !ok { + if ok := ed25519.Verify(pubKey, message, remoteKeySignature); !ok { return errKeySignatureMismatch } + return nil case *ecdsa.PublicKey: ecdsaSig := &ecdsaSignature{} @@ -85,18 +97,18 @@ func verifyKeySignature(message, remoteKeySignature []byte, hashAlgorithm hash.A return errInvalidECDSASignature } hashed := hashAlgorithm.Digest(message) - if !ecdsa.Verify(p, hashed, ecdsaSig.R, ecdsaSig.S) { + if !ecdsa.Verify(pubKey, hashed, ecdsaSig.R, ecdsaSig.S) { return errKeySignatureMismatch } + return nil case *rsa.PublicKey: - switch certificate.SignatureAlgorithm { - case x509.SHA1WithRSA, x509.SHA256WithRSA, x509.SHA384WithRSA, x509.SHA512WithRSA: - hashed := hashAlgorithm.Digest(message) - return rsa.VerifyPKCS1v15(p, hashAlgorithm.CryptoHash(), hashed, remoteKeySignature) - default: - return errKeySignatureVerifyUnimplemented + hashed := hashAlgorithm.Digest(message) + if rsa.VerifyPKCS1v15(pubKey, hashAlgorithm.CryptoHash(), hashed, remoteKeySignature) != nil { + return errKeySignatureMismatch } + + return nil } return errKeySignatureVerifyUnimplemented @@ -110,31 +122,37 @@ func verifyKeySignature(message, remoteKeySignature []byte, hashAlgorithm hash.A // CertificateVerify message is sent to explicitly verify possession of // the private key in the certificate. // https://tools.ietf.org/html/rfc5246#section-7.3 -func generateCertificateVerify(handshakeBodies []byte, privateKey crypto.PrivateKey, hashAlgorithm hash.Algorithm) ([]byte, error) { - if p, ok := privateKey.(ed25519.PrivateKey); ok { +func generateCertificateVerify( + handshakeBodies []byte, + signer crypto.Signer, + hashAlgorithm hash.Algorithm, +) ([]byte, error) { + if _, ok := signer.Public().(ed25519.PublicKey); ok { // https://pkg.go.dev/crypto/ed25519#PrivateKey.Sign // Sign signs the given message with priv. Ed25519 performs two passes over // messages to be signed and therefore cannot handle pre-hashed messages. - return p.Sign(rand.Reader, handshakeBodies, crypto.Hash(0)) + return signer.Sign(rand.Reader, handshakeBodies, crypto.Hash(0)) } - h := sha256.New() - if _, err := h.Write(handshakeBodies); err != nil { - return nil, err - } - hashed := h.Sum(nil) + hashed := hashAlgorithm.Digest(handshakeBodies) - switch p := privateKey.(type) { - case *ecdsa.PrivateKey: - return p.Sign(rand.Reader, hashed, hashAlgorithm.CryptoHash()) - case *rsa.PrivateKey: - return p.Sign(rand.Reader, hashed, hashAlgorithm.CryptoHash()) + switch signer.Public().(type) { + case *ecdsa.PublicKey: + return signer.Sign(rand.Reader, hashed, hashAlgorithm.CryptoHash()) + case *rsa.PublicKey: + return signer.Sign(rand.Reader, hashed, hashAlgorithm.CryptoHash()) } return nil, errInvalidSignatureAlgorithm } -func verifyCertificateVerify(handshakeBodies []byte, hashAlgorithm hash.Algorithm, remoteKeySignature []byte, rawCertificates [][]byte) error { //nolint:dupl +//nolint:dupl,cyclop +func verifyCertificateVerify( + handshakeBodies []byte, + hashAlgorithm hash.Algorithm, + remoteKeySignature []byte, + rawCertificates [][]byte, +) error { if len(rawCertificates) == 0 { return errLengthMismatch } @@ -143,11 +161,12 @@ func verifyCertificateVerify(handshakeBodies []byte, hashAlgorithm hash.Algorith return err } - switch p := certificate.PublicKey.(type) { + switch pubKey := certificate.PublicKey.(type) { case ed25519.PublicKey: - if ok := ed25519.Verify(p, handshakeBodies, remoteKeySignature); !ok { + if ok := ed25519.Verify(pubKey, handshakeBodies, remoteKeySignature); !ok { return errKeySignatureMismatch } + return nil case *ecdsa.PublicKey: ecdsaSig := &ecdsaSignature{} @@ -158,18 +177,18 @@ func verifyCertificateVerify(handshakeBodies []byte, hashAlgorithm hash.Algorith return errInvalidECDSASignature } hash := hashAlgorithm.Digest(handshakeBodies) - if !ecdsa.Verify(p, hash, ecdsaSig.R, ecdsaSig.S) { + if !ecdsa.Verify(pubKey, hash, ecdsaSig.R, ecdsaSig.S) { return errKeySignatureMismatch } + return nil case *rsa.PublicKey: - switch certificate.SignatureAlgorithm { - case x509.SHA1WithRSA, x509.SHA256WithRSA, x509.SHA384WithRSA, x509.SHA512WithRSA: - hash := hashAlgorithm.Digest(handshakeBodies) - return rsa.VerifyPKCS1v15(p, hashAlgorithm.CryptoHash(), hash, remoteKeySignature) - default: - return errKeySignatureVerifyUnimplemented + hash := hashAlgorithm.Digest(handshakeBodies) + if rsa.VerifyPKCS1v15(pubKey, hashAlgorithm.CryptoHash(), hash, remoteKeySignature) != nil { + return errKeySignatureMismatch } + + return nil } return errKeySignatureVerifyUnimplemented @@ -188,6 +207,7 @@ func loadCerts(rawCertificates [][]byte) ([]*x509.Certificate, error) { } certs = append(certs, cert) } + return certs, nil } @@ -206,10 +226,15 @@ func verifyClientCert(rawCertificates [][]byte, roots *x509.CertPool) (chains [] Intermediates: intermediateCAPool, KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, } + return certificate[0].Verify(opts) } -func verifyServerCert(rawCertificates [][]byte, roots *x509.CertPool, serverName string) (chains [][]*x509.Certificate, err error) { +func verifyServerCert( + rawCertificates [][]byte, + roots *x509.CertPool, + serverName string, +) (chains [][]*x509.Certificate, err error) { certificate, err := loadCerts(rawCertificates) if err != nil { return nil, err @@ -224,5 +249,6 @@ func verifyServerCert(rawCertificates [][]byte, roots *x509.CertPool, serverName DNSName: serverName, Intermediates: intermediateCAPool, } + return certificate[0].Verify(opts) } diff --git a/crypto_test.go b/crypto_test.go index 771ea3afa..249ca2cdc 100644 --- a/crypto_test.go +++ b/crypto_test.go @@ -9,10 +9,11 @@ import ( "encoding/pem" "testing" - "github.com/pion/dtls/v2/pkg/crypto/elliptic" - "github.com/pion/dtls/v2/pkg/crypto/hash" + "github.com/pion/dtls/v3/pkg/crypto/elliptic" + "github.com/pion/dtls/v3/pkg/crypto/hash" ) +// nolint: gosec const rawPrivateKey = ` -----BEGIN RSA PRIVATE KEY----- MIIEowIBAAKCAQEAxIA2BrrnR2sIlATsp7aRBD/3krwZ7vt9dNeoDQAee0s6SuYP @@ -50,21 +51,34 @@ func TestGenerateKeySignature(t *testing.T) { t.Error(err) } - clientRandom := []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f} - serverRandom := []byte{0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f} - publicKey := []byte{0x20, 0x9f, 0xd7, 0xad, 0x6d, 0xcf, 0xf4, 0x29, 0x8d, 0xd3, 0xf9, 0x6d, 0x5b, 0x1b, 0x2a, 0xf9, 0x10, 0xa0, 0x53, 0x5b, 0x14, 0x88, 0xd7, 0xf8, 0xfa, 0xbb, 0x34, 0x9a, 0x98, 0x28, 0x80, 0xb6, 0x15} + clientRandom := []byte{ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, + } + serverRandom := []byte{ + 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, + 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f, + } + publicKey := []byte{ + 0x20, 0x9f, 0xd7, 0xad, 0x6d, 0xcf, 0xf4, 0x29, 0x8d, 0xd3, 0xf9, 0x6d, 0x5b, 0x1b, 0x2a, 0xf9, 0x10, + 0xa0, 0x53, 0x5b, 0x14, 0x88, 0xd7, 0xf8, 0xfa, 0xbb, 0x34, 0x9a, 0x98, 0x28, 0x80, 0xb6, 0x15, + } expectedSignature := []byte{ - 0x6f, 0x47, 0x97, 0x85, 0xcc, 0x76, 0x50, 0x93, 0xbd, 0xe2, 0x6a, 0x69, 0x0b, 0xc3, 0x03, 0xd1, 0xb7, 0xe4, 0xab, 0x88, 0x7b, 0xa6, 0x52, 0x80, 0xdf, - 0xaa, 0x25, 0x7a, 0xdb, 0x29, 0x32, 0xe4, 0xd8, 0x28, 0x28, 0xb3, 0xe8, 0x04, 0x3c, 0x38, 0x16, 0xfc, 0x78, 0xe9, 0x15, 0x7b, 0xc5, 0xbd, 0x7d, 0xfc, - 0xcd, 0x83, 0x00, 0x57, 0x4a, 0x3c, 0x23, 0x85, 0x75, 0x6b, 0x37, 0xd5, 0x89, 0x72, 0x73, 0xf0, 0x44, 0x8c, 0x00, 0x70, 0x1f, 0x6e, 0xa2, 0x81, 0xd0, - 0x09, 0xc5, 0x20, 0x36, 0xab, 0x23, 0x09, 0x40, 0x1f, 0x4d, 0x45, 0x96, 0x62, 0xbb, 0x81, 0xb0, 0x30, 0x72, 0xad, 0x3a, 0x0a, 0xac, 0x31, 0x63, 0x40, - 0x52, 0x0a, 0x27, 0xf3, 0x34, 0xde, 0x27, 0x7d, 0xb7, 0x54, 0xff, 0x0f, 0x9f, 0x5a, 0xfe, 0x07, 0x0f, 0x4e, 0x9f, 0x53, 0x04, 0x34, 0x62, 0xf4, 0x30, - 0x74, 0x83, 0x35, 0xfc, 0xe4, 0x7e, 0xbf, 0x5a, 0xc4, 0x52, 0xd0, 0xea, 0xf9, 0x61, 0x4e, 0xf5, 0x1c, 0x0e, 0x58, 0x02, 0x71, 0xfb, 0x1f, 0x34, 0x55, - 0xe8, 0x36, 0x70, 0x3c, 0xc1, 0xcb, 0xc9, 0xb7, 0xbb, 0xb5, 0x1c, 0x44, 0x9a, 0x6d, 0x88, 0x78, 0x98, 0xd4, 0x91, 0x2e, 0xeb, 0x98, 0x81, 0x23, 0x30, - 0x73, 0x39, 0x43, 0xd5, 0xbb, 0x70, 0x39, 0xba, 0x1f, 0xdb, 0x70, 0x9f, 0x91, 0x83, 0x56, 0xc2, 0xde, 0xed, 0x17, 0x6d, 0x2c, 0x3e, 0x21, 0xea, 0x36, - 0xb4, 0x91, 0xd8, 0x31, 0x05, 0x60, 0x90, 0xfd, 0xc6, 0x74, 0xa9, 0x7b, 0x18, 0xfc, 0x1c, 0x6a, 0x1c, 0x6e, 0xec, 0xd3, 0xc1, 0xc0, 0x0d, 0x11, 0x25, - 0x48, 0x37, 0x3d, 0x45, 0x11, 0xa2, 0x31, 0x14, 0x0a, 0x66, 0x9f, 0xd8, 0xac, 0x74, 0xa2, 0xcd, 0xc8, 0x79, 0xb3, 0x9e, 0xc6, 0x66, 0x25, 0xcf, 0x2c, - 0x87, 0x5e, 0x5c, 0x36, 0x75, 0x86, + 0x6f, 0x47, 0x97, 0x85, 0xcc, 0x76, 0x50, 0x93, 0xbd, 0xe2, 0x6a, 0x69, 0x0b, 0xc3, 0x03, 0xd1, 0xb7, 0xe4, + 0xab, 0x88, 0x7b, 0xa6, 0x52, 0x80, 0xdf, 0xaa, 0x25, 0x7a, 0xdb, 0x29, 0x32, 0xe4, 0xd8, 0x28, 0x28, 0xb3, + 0xe8, 0x04, 0x3c, 0x38, 0x16, 0xfc, 0x78, 0xe9, 0x15, 0x7b, 0xc5, 0xbd, 0x7d, 0xfc, 0xcd, 0x83, 0x00, 0x57, + 0x4a, 0x3c, 0x23, 0x85, 0x75, 0x6b, 0x37, 0xd5, 0x89, 0x72, 0x73, 0xf0, 0x44, 0x8c, 0x00, 0x70, 0x1f, 0x6e, + 0xa2, 0x81, 0xd0, 0x09, 0xc5, 0x20, 0x36, 0xab, 0x23, 0x09, 0x40, 0x1f, 0x4d, 0x45, 0x96, 0x62, 0xbb, 0x81, + 0xb0, 0x30, 0x72, 0xad, 0x3a, 0x0a, 0xac, 0x31, 0x63, 0x40, 0x52, 0x0a, 0x27, 0xf3, 0x34, 0xde, 0x27, 0x7d, + 0xb7, 0x54, 0xff, 0x0f, 0x9f, 0x5a, 0xfe, 0x07, 0x0f, 0x4e, 0x9f, 0x53, 0x04, 0x34, 0x62, 0xf4, 0x30, 0x74, + 0x83, 0x35, 0xfc, 0xe4, 0x7e, 0xbf, 0x5a, 0xc4, 0x52, 0xd0, 0xea, 0xf9, 0x61, 0x4e, 0xf5, 0x1c, 0x0e, 0x58, + 0x02, 0x71, 0xfb, 0x1f, 0x34, 0x55, 0xe8, 0x36, 0x70, 0x3c, 0xc1, 0xcb, 0xc9, 0xb7, 0xbb, 0xb5, 0x1c, 0x44, + 0x9a, 0x6d, 0x88, 0x78, 0x98, 0xd4, 0x91, 0x2e, 0xeb, 0x98, 0x81, 0x23, 0x30, 0x73, 0x39, 0x43, 0xd5, 0xbb, + 0x70, 0x39, 0xba, 0x1f, 0xdb, 0x70, 0x9f, 0x91, 0x83, 0x56, 0xc2, 0xde, 0xed, 0x17, 0x6d, 0x2c, 0x3e, 0x21, + 0xea, 0x36, 0xb4, 0x91, 0xd8, 0x31, 0x05, 0x60, 0x90, 0xfd, 0xc6, 0x74, 0xa9, 0x7b, 0x18, 0xfc, 0x1c, 0x6a, + 0x1c, 0x6e, 0xec, 0xd3, 0xc1, 0xc0, 0x0d, 0x11, 0x25, 0x48, 0x37, 0x3d, 0x45, 0x11, 0xa2, 0x31, 0x14, 0x0a, + 0x66, 0x9f, 0xd8, 0xac, 0x74, 0xa2, 0xcd, 0xc8, 0x79, 0xb3, 0x9e, 0xc6, 0x66, 0x25, 0xcf, 0x2c, 0x87, 0x5e, + 0x5c, 0x36, 0x75, 0x86, } signature, err := generateKeySignature(clientRandom, serverRandom, publicKey, elliptic.X25519, key, hash.SHA256) diff --git a/e2e/Dockerfile b/e2e/Dockerfile index 68440e526..a6f9eedce 100644 --- a/e2e/Dockerfile +++ b/e2e/Dockerfile @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> # SPDX-License-Identifier: MIT -FROM docker.io/library/golang:1.18-bullseye +FROM docker.io/library/golang:1.24-bullseye COPY . /go/src/github.com/pion/dtls WORKDIR /go/src/github.com/pion/dtls/e2e diff --git a/e2e/e2e_lossy_test.go b/e2e/e2e_lossy_test.go index 2789ec3e9..f49cb78f4 100644 --- a/e2e/e2e_lossy_test.go +++ b/e2e/e2e_lossy_test.go @@ -10,9 +10,10 @@ import ( "testing" "time" - "github.com/pion/dtls/v2" - "github.com/pion/dtls/v2/pkg/crypto/selfsign" - transportTest "github.com/pion/transport/v2/test" + "github.com/pion/dtls/v3" + "github.com/pion/dtls/v3/pkg/crypto/selfsign" + dtlsnet "github.com/pion/dtls/v3/pkg/net" + transportTest "github.com/pion/transport/v3/test" ) const ( @@ -20,10 +21,9 @@ const ( lossyTestTimeout = 30 * time.Second ) -/* -DTLS Client/Server over a lossy transport, just asserts it can handle at increasing increments -*/ -func TestPionE2ELossy(t *testing.T) { +// DTLS Client/Server over a lossy transport, just asserts it can handle at increasing increments + +func TestPionE2ELossy(t *testing.T) { //nolint:cyclop // Check for leaking routines report := transportTest.CheckRoutines(t) defer report() @@ -44,10 +44,11 @@ func TestPionE2ELossy(t *testing.T) { } for _, test := range []struct { - LossChanceRange int - DoClientAuth bool - CipherSuites []dtls.CipherSuiteID - MTU int + LossChanceRange int + DoClientAuth bool + CipherSuites []dtls.CipherSuiteID + MTU int + DisableServerFlightInterval bool }{ { LossChanceRange: 0, @@ -108,6 +109,20 @@ func TestPionE2ELossy(t *testing.T) { MTU: 100, DoClientAuth: true, }, + // Incoming retransmitted handshakes should cause us to retransmit. Disabling the FlightInterval on one side + // means that a incoming re-transmissions causes the retransmission to be fired + { + LossChanceRange: 10, + DisableServerFlightInterval: true, + }, + { + LossChanceRange: 20, + DisableServerFlightInterval: true, + }, + { + LossChanceRange: 50, + DisableServerFlightInterval: true, + }, } { name := fmt.Sprintf("Loss%d_MTU%d", test.LossChanceRange, test.MTU) if test.DoClientAuth { @@ -116,13 +131,16 @@ func TestPionE2ELossy(t *testing.T) { for _, ciph := range test.CipherSuites { name += "_With" + ciph.String() } + if test.DisableServerFlightInterval { + name += "_WithNoServerFlightInterval" + } + test := test t.Run(name, func(t *testing.T) { // Limit runtime in case of deadlocks lim := transportTest.TimeOut(lossyTestTimeout + time.Second) defer lim.Stop() - rand.Seed(time.Now().UTC().UnixNano()) chosenLoss := rand.Intn(9) + test.LossChanceRange //nolint:gosec serverDone := make(chan runResult) clientDone := make(chan runResult) @@ -134,32 +152,38 @@ func TestPionE2ELossy(t *testing.T) { go func() { cfg := &dtls.Config{ - FlightInterval: flightInterval, - CipherSuites: test.CipherSuites, - InsecureSkipVerify: true, - MTU: test.MTU, + FlightInterval: flightInterval, + CipherSuites: test.CipherSuites, + InsecureSkipVerify: true, + MTU: test.MTU, + DisableRetransmitBackoff: true, } if test.DoClientAuth { cfg.Certificates = []tls.Certificate{clientCert} } - client, startupErr := dtls.Client(br.GetConn0(), cfg) + client, startupErr := dtls.Client(dtlsnet.PacketConnFromConn(br.GetConn0()), br.GetConn0().RemoteAddr(), cfg) clientDone <- runResult{client, startupErr} }() go func() { cfg := &dtls.Config{ - Certificates: []tls.Certificate{serverCert}, - FlightInterval: flightInterval, - MTU: test.MTU, + Certificates: []tls.Certificate{serverCert}, + FlightInterval: flightInterval, + MTU: test.MTU, + DisableRetransmitBackoff: true, } if test.DoClientAuth { cfg.ClientAuth = dtls.RequireAnyClientCert } - server, startupErr := dtls.Server(br.GetConn1(), cfg) + if test.DisableServerFlightInterval { + cfg.FlightInterval = time.Hour + } + + server, startupErr := dtls.Server(dtlsnet.PacketConnFromConn(br.GetConn1()), br.GetConn1().RemoteAddr(), cfg) serverDone <- runResult{server, startupErr} }() @@ -187,20 +211,32 @@ func TestPionE2ELossy(t *testing.T) { select { case serverResult := <-serverDone: if serverResult.err != nil { - t.Errorf("Fail, serverError: clientComplete(%t) serverComplete(%t) LossChance(%d) error(%v)", clientConn != nil, serverConn != nil, chosenLoss, serverResult.err) + t.Errorf( + "Fail, serverError: clientComplete(%t) serverComplete(%t) LossChance(%d) error(%v)", + clientConn != nil, serverConn != nil, chosenLoss, serverResult.err, + ) + return } serverConn = serverResult.dtlsConn case clientResult := <-clientDone: if clientResult.err != nil { - t.Errorf("Fail, clientError: clientComplete(%t) serverComplete(%t) LossChance(%d) error(%v)", clientConn != nil, serverConn != nil, chosenLoss, clientResult.err) + t.Errorf( + "Fail, clientError: clientComplete(%t) serverComplete(%t) LossChance(%d) error(%v)", + clientConn != nil, serverConn != nil, chosenLoss, clientResult.err, + ) + return } clientConn = clientResult.dtlsConn case <-testTimer.C: - t.Errorf("Test expired: clientComplete(%t) serverComplete(%t) LossChance(%d)", clientConn != nil, serverConn != nil, chosenLoss) + t.Errorf( + "Test expired: clientComplete(%t) serverComplete(%t) LossChance(%d)", + clientConn != nil, serverConn != nil, chosenLoss, + ) + return case <-time.After(10 * time.Millisecond): } diff --git a/e2e/e2e_openssl_test.go b/e2e/e2e_openssl_test.go index 25bffb35a..97da68e88 100644 --- a/e2e/e2e_openssl_test.go +++ b/e2e/e2e_openssl_test.go @@ -20,7 +20,7 @@ import ( "testing" "time" - "github.com/pion/dtls/v2" + "github.com/pion/dtls/v3" ) func serverOpenSSL(c *comm) { @@ -78,7 +78,7 @@ func serverOpenSSL(c *comm) { // launch command // #nosec G204 - cmd := exec.CommandContext(c.ctx, "openssl", args...) + cmd := exec.Command("openssl", args...) var inner net.Conn inner, c.serverConn = net.Pipe() cmd.Stdin = inner @@ -95,6 +95,8 @@ func serverOpenSSL(c *comm) { c.serverReady <- struct{}{} simpleReadWrite(c.errChan, c.serverChan, c.serverConn, c.messageRecvCount) + c.serverDone <- cmd.Process.Kill() + close(c.serverDone) }() } @@ -155,7 +157,7 @@ func clientOpenSSL(c *comm) { // launch command // #nosec G204 - cmd := exec.CommandContext(c.ctx, "openssl", args...) + cmd := exec.Command("openssl", args...) var inner net.Conn inner, c.clientConn = net.Pipe() cmd.Stdin = inner @@ -168,6 +170,8 @@ func clientOpenSSL(c *comm) { } simpleReadWrite(c.errChan, c.clientChan, c.clientConn, c.messageRecvCount) + c.clientDone <- cmd.Process.Kill() + close(c.clientDone) } func ciphersOpenSSL(cfg *dtls.Config) string { diff --git a/e2e/e2e_test.go b/e2e/e2e_test.go index 25514eff8..f02d1507f 100644 --- a/e2e/e2e_test.go +++ b/e2e/e2e_test.go @@ -22,9 +22,11 @@ import ( "testing" "time" - "github.com/pion/dtls/v2" - "github.com/pion/dtls/v2/pkg/crypto/selfsign" - "github.com/pion/transport/v2/test" + "github.com/pion/dtls/v3" + "github.com/pion/dtls/v3/pkg/crypto/selfsign" + "github.com/pion/dtls/v3/pkg/protocol/extension" + "github.com/pion/dtls/v3/pkg/protocol/handshake" + "github.com/pion/transport/v3/test" ) const ( @@ -33,13 +35,17 @@ const ( messageRetry = 200 * time.Millisecond ) -var errServerTimeout = errors.New("waiting on serverReady err: timeout") +var ( + errServerTimeout = errors.New("waiting on serverReady err: timeout") + errHookCiphersFailed = errors.New("hook failed to modify cipherlist") + errHookAPLNFailed = errors.New("hook failed to modify APLN extension") +) -func randomPort(t testing.TB) int { - t.Helper() +func randomPort(tb testing.TB) int { + tb.Helper() conn, err := net.ListenPacket("udp4", "127.0.0.1:0") if err != nil { - t.Fatalf("failed to pickPort: %v", err) + tb.Fatalf("failed to pickPort: %v", err) } defer func() { _ = conn.Close() @@ -48,7 +54,8 @@ func randomPort(t testing.TB) int { case *net.UDPAddr: return addr.Port default: - t.Fatalf("unknown addr type %T", addr) + tb.Fatalf("unknown addr type %T", addr) + return 0 } } @@ -59,6 +66,7 @@ func simpleReadWrite(errChan chan error, outChan chan string, conn io.ReadWriter n, err := conn.Read(buffer) if err != nil { errChan <- err + return } @@ -71,6 +79,7 @@ func simpleReadWrite(errChan chan error, outChan chan string, conn io.ReadWriter break } else if _, err := conn.Write([]byte(testMessage)); err != nil { errChan <- err + break } @@ -79,16 +88,18 @@ func simpleReadWrite(errChan chan error, outChan chan string, conn io.ReadWriter } type comm struct { - ctx context.Context + ctx context.Context //nolint:containedctx clientConfig, serverConfig *dtls.Config serverPort int messageRecvCount *uint64 // Counter to make sure both sides got a message clientMutex *sync.Mutex clientConn net.Conn + clientDone chan error serverMutex *sync.Mutex serverConn net.Conn serverListener net.Listener serverReady chan struct{} + serverDone chan error errChan chan error clientChan chan string serverChan chan string @@ -96,9 +107,15 @@ type comm struct { server func(*comm) } -func newComm(ctx context.Context, clientConfig, serverConfig *dtls.Config, serverPort int, server, client func(*comm)) *comm { +func newComm( + ctx context.Context, + clientConfig, serverConfig *dtls.Config, + serverPort int, + server, client func(*comm), +) *comm { messageRecvCount := uint64(0) - c := &comm{ + + com := &comm{ ctx: ctx, clientConfig: clientConfig, serverConfig: serverConfig, @@ -107,16 +124,21 @@ func newComm(ctx context.Context, clientConfig, serverConfig *dtls.Config, serve clientMutex: &sync.Mutex{}, serverMutex: &sync.Mutex{}, serverReady: make(chan struct{}), + serverDone: make(chan error), + clientDone: make(chan error), errChan: make(chan error), clientChan: make(chan string), serverChan: make(chan string), server: server, client: client, } - return c + + return com } -func (c *comm) assert(t *testing.T) { +func (c *comm) assert(t *testing.T) { //nolint:cyclop + t.Helper() + // DTLS Client go c.client(c) @@ -172,7 +194,35 @@ func (c *comm) assert(t *testing.T) { }() } -func clientPion(c *comm) { +func (c *comm) cleanup(t *testing.T) { //nolint:cyclop + t.Helper() + + clientDone, serverDone := false, false + for { + select { + case err := <-c.clientDone: + if err != nil { + t.Fatal(err) + } + clientDone = true + if clientDone && serverDone { + return + } + case err := <-c.serverDone: + if err != nil { + t.Fatal(err) + } + serverDone = true + if clientDone && serverDone { + return + } + case <-time.After(testTimeLimit): + t.Fatalf("Test timeout waiting for server shutdown") + } + } +} + +func clientPion(c *comm) { //nolint:varnamelen select { case <-c.serverReady: // OK @@ -183,20 +233,30 @@ func clientPion(c *comm) { c.clientMutex.Lock() defer c.clientMutex.Unlock() - var err error - c.clientConn, err = dtls.DialWithContext(c.ctx, "udp", + conn, err := dtls.Dial("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: c.serverPort}, c.clientConfig, ) if err != nil { c.errChan <- err + + return + } + + if err := conn.HandshakeContext(c.ctx); err != nil { + c.errChan <- err + return } + c.clientConn = conn + simpleReadWrite(c.errChan, c.clientChan, c.clientConn, c.messageRecvCount) + c.clientDone <- nil + close(c.clientDone) } -func serverPion(c *comm) { +func serverPion(c *comm) { //nolint:varnamelen c.serverMutex.Lock() defer c.serverMutex.Unlock() @@ -207,25 +267,45 @@ func serverPion(c *comm) { ) if err != nil { c.errChan <- err + return } c.serverReady <- struct{}{} c.serverConn, err = c.serverListener.Accept() if err != nil { c.errChan <- err + return } + dtlsConn, ok := c.serverConn.(*dtls.Conn) + if ok { + if err := dtlsConn.HandshakeContext(c.ctx); err != nil { + c.errChan <- err + + return + } + } + simpleReadWrite(c.errChan, c.serverChan, c.serverConn, c.messageRecvCount) + c.serverDone <- nil + close(c.serverDone) +} + +type dtlsConfOpts func(*dtls.Config) + +func withConnectionIDGenerator(g func() []byte) dtlsConfOpts { + return func(c *dtls.Config) { + c.ConnectionIDGenerator = g + } } -/* - Simple DTLS Client/Server can communicate - - Assert that you can send messages both ways - - Assert that Close() on both ends work - - Assert that no Goroutines are leaked -*/ -func testPionE2ESimple(t *testing.T, server, client func(*comm)) { +// Simple DTLS Client/Server can communicate +// - Assert that you can send messages both ways +// - Assert that Close() on both ends work +// - Assert that no Goroutines are leaked +func testPionE2ESimple(t *testing.T, server, client func(*comm), opts ...dtlsConfOpts) { + t.Helper() lim := test.TimeOut(time.Second * 30) defer lim.Stop() @@ -252,14 +332,20 @@ func testPionE2ESimple(t *testing.T, server, client func(*comm)) { CipherSuites: []dtls.CipherSuiteID{cipherSuite}, InsecureSkipVerify: true, } + for _, o := range opts { + o(cfg) + } serverPort := randomPort(t) comm := newComm(ctx, cfg, cfg, serverPort, server, client) + defer comm.cleanup(t) comm.assert(t) }) } } -func testPionE2ESimplePSK(t *testing.T, server, client func(*comm)) { +func testPionE2ESimplePSK(t *testing.T, server, client func(*comm), opts ...dtlsConfOpts) { + t.Helper() + lim := test.TimeOut(time.Second * 30) defer lim.Stop() @@ -279,20 +365,26 @@ func testPionE2ESimplePSK(t *testing.T, server, client func(*comm)) { defer cancel() cfg := &dtls.Config{ - PSK: func(hint []byte) ([]byte, error) { + PSK: func([]byte) ([]byte, error) { return []byte{0xAB, 0xC1, 0x23}, nil }, PSKIdentityHint: []byte{0x01, 0x02, 0x03, 0x04, 0x05}, CipherSuites: []dtls.CipherSuiteID{cipherSuite}, } + for _, o := range opts { + o(cfg) + } serverPort := randomPort(t) comm := newComm(ctx, cfg, cfg, serverPort, server, client) + defer comm.cleanup(t) comm.assert(t) }) } } -func testPionE2EMTUs(t *testing.T, server, client func(*comm)) { +func testPionE2EMTUs(t *testing.T, server, client func(*comm), opts ...dtlsConfOpts) { + t.Helper() + lim := test.TimeOut(time.Second * 30) defer lim.Stop() @@ -320,14 +412,20 @@ func testPionE2EMTUs(t *testing.T, server, client func(*comm)) { InsecureSkipVerify: true, MTU: mtu, } + for _, o := range opts { + o(cfg) + } serverPort := randomPort(t) comm := newComm(ctx, cfg, cfg, serverPort, server, client) + defer comm.cleanup(t) comm.assert(t) }) } } -func testPionE2ESimpleED25519(t *testing.T, server, client func(*comm)) { +func testPionE2ESimpleED25519(t *testing.T, server, client func(*comm), opts ...dtlsConfOpts) { + t.Helper() + lim := test.TimeOut(time.Second * 30) defer lim.Stop() @@ -360,14 +458,20 @@ func testPionE2ESimpleED25519(t *testing.T, server, client func(*comm)) { CipherSuites: []dtls.CipherSuiteID{cipherSuite}, InsecureSkipVerify: true, } + for _, o := range opts { + o(cfg) + } serverPort := randomPort(t) comm := newComm(ctx, cfg, cfg, serverPort, server, client) + defer comm.cleanup(t) comm.assert(t) }) } } -func testPionE2ESimpleED25519ClientCert(t *testing.T, server, client func(*comm)) { +func testPionE2ESimpleED25519ClientCert(t *testing.T, server, client func(*comm), opts ...dtlsConfOpts) { + t.Helper() + lim := test.TimeOut(time.Second * 30) defer lim.Stop() @@ -405,12 +509,19 @@ func testPionE2ESimpleED25519ClientCert(t *testing.T, server, client func(*comm) CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, InsecureSkipVerify: true, } + for _, o := range opts { + o(scfg) + o(ccfg) + } serverPort := randomPort(t) comm := newComm(ctx, ccfg, scfg, serverPort, server, client) + defer comm.cleanup(t) comm.assert(t) } -func testPionE2ESimpleECDSAClientCert(t *testing.T, server, client func(*comm)) { +func testPionE2ESimpleECDSAClientCert(t *testing.T, server, client func(*comm), opts ...dtlsConfOpts) { + t.Helper() + lim := test.TimeOut(time.Second * 30) defer lim.Stop() @@ -448,12 +559,19 @@ func testPionE2ESimpleECDSAClientCert(t *testing.T, server, client func(*comm)) CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, InsecureSkipVerify: true, } + for _, o := range opts { + o(scfg) + o(ccfg) + } serverPort := randomPort(t) comm := newComm(ctx, ccfg, scfg, serverPort, server, client) + defer comm.cleanup(t) comm.assert(t) } -func testPionE2ESimpleRSAClientCert(t *testing.T, server, client func(*comm)) { +func testPionE2ESimpleRSAClientCert(t *testing.T, server, client func(*comm), opts ...dtlsConfOpts) { + t.Helper() + lim := test.TimeOut(time.Second * 30) defer lim.Stop() @@ -491,11 +609,134 @@ func testPionE2ESimpleRSAClientCert(t *testing.T, server, client func(*comm)) { CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}, InsecureSkipVerify: true, } + for _, o := range opts { + o(scfg) + o(ccfg) + } serverPort := randomPort(t) comm := newComm(ctx, ccfg, scfg, serverPort, server, client) + defer comm.cleanup(t) comm.assert(t) } +func testPionE2ESimpleClientHelloHook(t *testing.T, server, client func(*comm), opts ...dtlsConfOpts) { + t.Helper() + + lim := test.TimeOut(time.Second * 30) + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + t.Run("ClientHello hook", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + cert, err := selfsign.GenerateSelfSignedWithDNS("localhost") + if err != nil { + t.Fatal(err) + } + + modifiedCipher := dtls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA + supportedList := []dtls.CipherSuiteID{ + dtls.TLS_ECDHE_ECDSA_WITH_AES_128_CCM, + modifiedCipher, + } + + ccfg := &dtls.Config{ + Certificates: []tls.Certificate{cert}, + VerifyConnection: func(s *dtls.State) error { + if s.CipherSuiteID != modifiedCipher { + return errHookCiphersFailed + } + + return nil + }, + CipherSuites: supportedList, + ClientHelloMessageHook: func(ch handshake.MessageClientHello) handshake.Message { + ch.CipherSuiteIDs = []uint16{uint16(modifiedCipher)} + + return &ch + }, + InsecureSkipVerify: true, + } + + scfg := &dtls.Config{ + Certificates: []tls.Certificate{cert}, + CipherSuites: supportedList, + InsecureSkipVerify: true, + } + + for _, o := range opts { + o(ccfg) + o(scfg) + } + serverPort := randomPort(t) + comm := newComm(ctx, ccfg, scfg, serverPort, server, client) + defer comm.cleanup(t) + comm.assert(t) + }) +} + +func testPionE2ESimpleServerHelloHook(t *testing.T, server, client func(*comm), opts ...dtlsConfOpts) { + t.Helper() + + lim := test.TimeOut(time.Second * 30) + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + t.Run("ServerHello hook", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + cert, err := selfsign.GenerateSelfSignedWithDNS("localhost") + if err != nil { + t.Fatal(err) + } + + supportedList := []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_CCM} + + apln := "APLN" + + ccfg := &dtls.Config{ + Certificates: []tls.Certificate{cert}, + VerifyConnection: func(s *dtls.State) error { + if s.NegotiatedProtocol != apln { + return errHookAPLNFailed + } + + return nil + }, + CipherSuites: supportedList, + InsecureSkipVerify: true, + } + + scfg := &dtls.Config{ + Certificates: []tls.Certificate{cert}, + CipherSuites: supportedList, + ServerHelloMessageHook: func(sh handshake.MessageServerHello) handshake.Message { + sh.Extensions = append(sh.Extensions, &extension.ALPN{ + ProtocolNameList: []string{apln}, + }) + + return &sh + }, + InsecureSkipVerify: true, + } + + for _, o := range opts { + o(ccfg) + o(scfg) + } + serverPort := randomPort(t) + comm := newComm(ctx, ccfg, scfg, serverPort, server, client) + defer comm.cleanup(t) + comm.assert(t) + }) +} + func TestPionE2ESimple(t *testing.T) { testPionE2ESimple(t, serverPion, clientPion) } @@ -523,3 +764,39 @@ func TestPionE2ESimpleECDSAClientCert(t *testing.T) { func TestPionE2ESimpleRSAClientCert(t *testing.T) { testPionE2ESimpleRSAClientCert(t, serverPion, clientPion) } + +func TestPionE2ESimpleCID(t *testing.T) { + testPionE2ESimple(t, serverPion, clientPion, withConnectionIDGenerator(dtls.RandomCIDGenerator(8))) +} + +func TestPionE2ESimplePSKCID(t *testing.T) { + testPionE2ESimplePSK(t, serverPion, clientPion, withConnectionIDGenerator(dtls.RandomCIDGenerator(8))) +} + +func TestPionE2EMTUsCID(t *testing.T) { + testPionE2EMTUs(t, serverPion, clientPion, withConnectionIDGenerator(dtls.RandomCIDGenerator(8))) +} + +func TestPionE2ESimpleED25519CID(t *testing.T) { + testPionE2ESimpleED25519(t, serverPion, clientPion, withConnectionIDGenerator(dtls.RandomCIDGenerator(8))) +} + +func TestPionE2ESimpleED25519ClientCertCID(t *testing.T) { + testPionE2ESimpleED25519ClientCert(t, serverPion, clientPion, withConnectionIDGenerator(dtls.RandomCIDGenerator(8))) +} + +func TestPionE2ESimpleECDSAClientCertCID(t *testing.T) { + testPionE2ESimpleECDSAClientCert(t, serverPion, clientPion, withConnectionIDGenerator(dtls.RandomCIDGenerator(8))) +} + +func TestPionE2ESimpleRSAClientCertCID(t *testing.T) { + testPionE2ESimpleRSAClientCert(t, serverPion, clientPion, withConnectionIDGenerator(dtls.RandomCIDGenerator(8))) +} + +func TestPionE2ESimpleClientHelloHook(t *testing.T) { + testPionE2ESimpleClientHelloHook(t, serverPion, clientPion) +} + +func TestPionE2ESimpleServerHelloHook(t *testing.T) { + testPionE2ESimpleServerHelloHook(t, serverPion, clientPion) +} diff --git a/errors.go b/errors.go index 025d8645e..b7f93b7b4 100644 --- a/errors.go +++ b/errors.go @@ -11,69 +11,137 @@ import ( "net" "os" - "github.com/pion/dtls/v2/pkg/protocol" - "github.com/pion/dtls/v2/pkg/protocol/alert" + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/alert" ) -// Typed errors +// Typed errors. var ( ErrConnClosed = &FatalError{Err: errors.New("conn is closed")} //nolint:goerr113 errDeadlineExceeded = &TimeoutError{Err: fmt.Errorf("read/write timeout: %w", context.DeadlineExceeded)} errInvalidContentType = &TemporaryError{Err: errors.New("invalid content type")} //nolint:goerr113 - errBufferTooSmall = &TemporaryError{Err: errors.New("buffer is too small")} //nolint:goerr113 - errContextUnsupported = &TemporaryError{Err: errors.New("context is not supported for ExportKeyingMaterial")} //nolint:goerr113 - errHandshakeInProgress = &TemporaryError{Err: errors.New("handshake is in progress")} //nolint:goerr113 - errReservedExportKeyingMaterial = &TemporaryError{Err: errors.New("ExportKeyingMaterial can not be used with a reserved label")} //nolint:goerr113 - errApplicationDataEpochZero = &TemporaryError{Err: errors.New("ApplicationData with epoch of 0")} //nolint:goerr113 - errUnhandledContextType = &TemporaryError{Err: errors.New("unhandled contentType")} //nolint:goerr113 - - errCertificateVerifyNoCertificate = &FatalError{Err: errors.New("client sent certificate verify but we have no certificate to verify")} //nolint:goerr113 - errCipherSuiteNoIntersection = &FatalError{Err: errors.New("client+server do not support any shared cipher suites")} //nolint:goerr113 - errClientCertificateNotVerified = &FatalError{Err: errors.New("client sent certificate but did not verify it")} //nolint:goerr113 - errClientCertificateRequired = &FatalError{Err: errors.New("server required client verification, but got none")} //nolint:goerr113 - errClientNoMatchingSRTPProfile = &FatalError{Err: errors.New("server responded with SRTP Profile we do not support")} //nolint:goerr113 - errClientRequiredButNoServerEMS = &FatalError{Err: errors.New("client required Extended Master Secret extension, but server does not support it")} //nolint:goerr113 - errCookieMismatch = &FatalError{Err: errors.New("client+server cookie does not match")} //nolint:goerr113 - errIdentityNoPSK = &FatalError{Err: errors.New("PSK Identity Hint provided but PSK is nil")} //nolint:goerr113 - errInvalidCertificate = &FatalError{Err: errors.New("no certificate provided")} //nolint:goerr113 - errInvalidCipherSuite = &FatalError{Err: errors.New("invalid or unknown cipher suite")} //nolint:goerr113 - errInvalidECDSASignature = &FatalError{Err: errors.New("ECDSA signature contained zero or negative values")} //nolint:goerr113 - errInvalidPrivateKey = &FatalError{Err: errors.New("invalid private key type")} //nolint:goerr113 - errInvalidSignatureAlgorithm = &FatalError{Err: errors.New("invalid signature algorithm")} //nolint:goerr113 - errKeySignatureMismatch = &FatalError{Err: errors.New("expected and actual key signature do not match")} //nolint:goerr113 - errNilNextConn = &FatalError{Err: errors.New("Conn can not be created with a nil nextConn")} //nolint:goerr113 - errNoAvailableCipherSuites = &FatalError{Err: errors.New("connection can not be created, no CipherSuites satisfy this Config")} //nolint:goerr113 - errNoAvailablePSKCipherSuite = &FatalError{Err: errors.New("connection can not be created, pre-shared key present but no compatible CipherSuite")} //nolint:goerr113 - errNoAvailableCertificateCipherSuite = &FatalError{Err: errors.New("connection can not be created, certificate present but no compatible CipherSuite")} //nolint:goerr113 - errNoAvailableSignatureSchemes = &FatalError{Err: errors.New("connection can not be created, no SignatureScheme satisfy this Config")} //nolint:goerr113 - errNoCertificates = &FatalError{Err: errors.New("no certificates configured")} //nolint:goerr113 - errNoConfigProvided = &FatalError{Err: errors.New("no config provided")} //nolint:goerr113 - errNoSupportedEllipticCurves = &FatalError{Err: errors.New("client requested zero or more elliptic curves that are not supported by the server")} //nolint:goerr113 - errUnsupportedProtocolVersion = &FatalError{Err: errors.New("unsupported protocol version")} //nolint:goerr113 - errPSKAndIdentityMustBeSetForClient = &FatalError{Err: errors.New("PSK and PSK Identity Hint must both be set for client")} //nolint:goerr113 - errRequestedButNoSRTPExtension = &FatalError{Err: errors.New("SRTP support was requested but server did not respond with use_srtp extension")} //nolint:goerr113 - errServerNoMatchingSRTPProfile = &FatalError{Err: errors.New("client requested SRTP but we have no matching profiles")} //nolint:goerr113 - errServerRequiredButNoClientEMS = &FatalError{Err: errors.New("server requires the Extended Master Secret extension, but the client does not support it")} //nolint:goerr113 - errVerifyDataMismatch = &FatalError{Err: errors.New("expected and actual verify data does not match")} //nolint:goerr113 - errNotAcceptableCertificateChain = &FatalError{Err: errors.New("certificate chain is not signed by an acceptable CA")} //nolint:goerr113 - - errInvalidFlight = &InternalError{Err: errors.New("invalid flight number")} //nolint:goerr113 - errKeySignatureGenerateUnimplemented = &InternalError{Err: errors.New("unable to generate key signature, unimplemented")} //nolint:goerr113 - errKeySignatureVerifyUnimplemented = &InternalError{Err: errors.New("unable to verify key signature, unimplemented")} //nolint:goerr113 - errLengthMismatch = &InternalError{Err: errors.New("data length and declared length do not match")} //nolint:goerr113 - errSequenceNumberOverflow = &InternalError{Err: errors.New("sequence number overflow")} //nolint:goerr113 - errInvalidFSMTransition = &InternalError{Err: errors.New("invalid state machine transition")} //nolint:goerr113 - errFailedToAccessPoolReadBuffer = &InternalError{Err: errors.New("failed to access pool read buffer")} //nolint:goerr113 - errFragmentBufferOverflow = &InternalError{Err: errors.New("fragment buffer overflow")} //nolint:goerr113 + //nolint:goerr113 + errBufferTooSmall = &TemporaryError{Err: errors.New("buffer is too small")} + //nolint:goerr113 + errContextUnsupported = &TemporaryError{Err: errors.New("context is not supported for ExportKeyingMaterial")} + //nolint:goerr113 + errHandshakeInProgress = &TemporaryError{Err: errors.New("handshake is in progress")} + //nolint:goerr113 + errReservedExportKeyingMaterial = &TemporaryError{ + Err: errors.New("ExportKeyingMaterial can not be used with a reserved label"), + } + //nolint:goerr113 + errApplicationDataEpochZero = &TemporaryError{Err: errors.New("ApplicationData with epoch of 0")} + //nolint:goerr113 + errUnhandledContextType = &TemporaryError{Err: errors.New("unhandled contentType")} + + //nolint:goerr113 + errCertificateVerifyNoCertificate = &FatalError{ + Err: errors.New("client sent certificate verify but we have no certificate to verify"), + } + //nolint:goerr113 + errCipherSuiteNoIntersection = &FatalError{Err: errors.New("client+server do not support any shared cipher suites")} + //nolint:goerr113 + errClientCertificateNotVerified = &FatalError{Err: errors.New("client sent certificate but did not verify it")} + //nolint:goerr113 + errClientCertificateRequired = &FatalError{Err: errors.New("server required client verification, but got none")} + //nolint:goerr113 + errClientNoMatchingSRTPProfile = &FatalError{Err: errors.New("server responded with SRTP Profile we do not support")} + //nolint:goerr113 + errClientRequiredButNoServerEMS = &FatalError{ + Err: errors.New("client required Extended Master Secret extension, but server does not support it"), + } + //nolint:goerr113 + errCookieMismatch = &FatalError{Err: errors.New("client+server cookie does not match")} + //nolint:goerr113 + errIdentityNoPSK = &FatalError{Err: errors.New("PSK Identity Hint provided but PSK is nil")} + //nolint:goerr113 + errInvalidCertificate = &FatalError{Err: errors.New("no certificate provided")} + //nolint:goerr113 + errInvalidCipherSuite = &FatalError{Err: errors.New("invalid or unknown cipher suite")} + //nolint:goerr113 + errInvalidECDSASignature = &FatalError{Err: errors.New("ECDSA signature contained zero or negative values")} + //nolint:goerr113 + errInvalidPrivateKey = &FatalError{Err: errors.New("invalid private key type")} + //nolint:goerr113 + errInvalidSignatureAlgorithm = &FatalError{Err: errors.New("invalid signature algorithm")} + //nolint:goerr113 + errKeySignatureMismatch = &FatalError{Err: errors.New("expected and actual key signature do not match")} + //nolint:goerr113 + errNilNextConn = &FatalError{Err: errors.New("Conn can not be created with a nil nextConn")} + //nolint:goerr113 + errNoAvailableCipherSuites = &FatalError{ + Err: errors.New("connection can not be created, no CipherSuites satisfy this Config"), + } + //nolint:goerr113 + errNoAvailablePSKCipherSuite = &FatalError{ + Err: errors.New("connection can not be created, pre-shared key present but no compatible CipherSuite"), + } + //nolint:goerr113 + errNoAvailableCertificateCipherSuite = &FatalError{ + Err: errors.New("connection can not be created, certificate present but no compatible CipherSuite"), + } + //nolint:goerr113 + errNoAvailableSignatureSchemes = &FatalError{ + Err: errors.New("connection can not be created, no SignatureScheme satisfy this Config"), + } + //nolint:goerr113 + errNoCertificates = &FatalError{Err: errors.New("no certificates configured")} + //nolint:goerr113 + errNoConfigProvided = &FatalError{Err: errors.New("no config provided")} + //nolint:goerr113 + errNoSupportedEllipticCurves = &FatalError{ + Err: errors.New("client requested zero or more elliptic curves that are not supported by the server"), + } + //nolint:goerr113 + errUnsupportedProtocolVersion = &FatalError{Err: errors.New("unsupported protocol version")} + //nolint:goerr113 + errPSKAndIdentityMustBeSetForClient = &FatalError{ + Err: errors.New("PSK and PSK Identity Hint must both be set for client"), + } + //nolint:goerr113 + errRequestedButNoSRTPExtension = &FatalError{ + Err: errors.New("SRTP support was requested but server did not respond with use_srtp extension"), + } + //nolint:goerr113 + errServerNoMatchingSRTPProfile = &FatalError{Err: errors.New("client requested SRTP but we have no matching profiles")} + //nolint:goerr113 + errServerRequiredButNoClientEMS = &FatalError{ + Err: errors.New("server requires the Extended Master Secret extension, but the client does not support it"), + } + //nolint:goerr113 + errVerifyDataMismatch = &FatalError{Err: errors.New("expected and actual verify data does not match")} + //nolint:goerr113 + errNotAcceptableCertificateChain = &FatalError{Err: errors.New("certificate chain is not signed by an acceptable CA")} + + //nolint:goerr113 + errInvalidFlight = &InternalError{Err: errors.New("invalid flight number")} + //nolint:goerr113 + errKeySignatureGenerateUnimplemented = &InternalError{ + Err: errors.New("unable to generate key signature, unimplemented"), + } + //nolint:goerr113 + errKeySignatureVerifyUnimplemented = &InternalError{Err: errors.New("unable to verify key signature, unimplemented")} + //nolint:goerr113 + errLengthMismatch = &InternalError{Err: errors.New("data length and declared length do not match")} + //nolint:goerr113 + errSequenceNumberOverflow = &InternalError{Err: errors.New("sequence number overflow")} + //nolint:goerr113 + errInvalidFSMTransition = &InternalError{Err: errors.New("invalid state machine transition")} + //nolint:goerr113 + errFailedToAccessPoolReadBuffer = &InternalError{Err: errors.New("failed to access pool read buffer")} + //nolint:goerr113 + errFragmentBufferOverflow = &InternalError{Err: errors.New("fragment buffer overflow")} ) // FatalError indicates that the DTLS connection is no longer available. // It is mainly caused by wrong configuration of server or client. type FatalError = protocol.FatalError -// InternalError indicates and internal error caused by the implementation, and the DTLS connection is no longer available. +// InternalError indicates and internal error caused by the implementation, +// and the DTLS connection is no longer available. // It is mainly caused by bugs or tried to use unimplemented features. type InternalError = protocol.InternalError @@ -100,10 +168,11 @@ func (e *invalidCipherSuiteError) Is(err error) bool { if errors.As(err, &other) { return e.id == other.id } + return false } -// errAlert wraps DTLS alert notification as an error +// errAlert wraps DTLS alert notification as an error. type alertError struct { *alert.Alert } @@ -121,6 +190,7 @@ func (e *alertError) Is(err error) bool { if errors.As(err, &other) { return e.Level == other.Level && e.Description == other.Description } + return false } @@ -138,7 +208,7 @@ func netError(err error) error { se *os.SyscallError ) - if errors.As(err, &opError) { + if errors.As(err, &opError) { //nolint:nestif if errors.As(opError, &se) { if se.Timeout() { return &TimeoutError{Err: err} diff --git a/errors_errno_test.go b/errors_errno_test.go index 5b4b209d1..e2957a691 100644 --- a/errors_errno_test.go +++ b/errors_errno_test.go @@ -16,18 +16,32 @@ import ( ) func TestErrorsTemporary(t *testing.T) { - addrListen, errListen := net.ResolveUDPAddr("udp", "localhost:0") - if errListen != nil { - t.Fatalf("Unexpected error: %v", errListen) + // Allocate a UDP port no one is listening on. + addrListen, err := net.ResolveUDPAddr("udp", "localhost:0") + if err != nil { + t.Fatalf("Unexpected failure to resolve: %v", err) } + listener, err := net.ListenUDP("udp", addrListen) + if err != nil { + t.Fatalf("Unexpected failure to listen: %v", err) + } + raddr, ok := listener.LocalAddr().(*net.UDPAddr) + if !ok { + t.Fatal("Unexpedted type assertion error") + } + err = listener.Close() + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + // Server is not listening. - conn, errDial := net.DialUDP("udp", nil, addrListen) + conn, errDial := net.DialUDP("udp", nil, raddr) if errDial != nil { t.Fatalf("Unexpected error: %v", errDial) } _, _ = conn.Write([]byte{0x00}) // trigger - _, err := conn.Read(make([]byte, 10)) + _, err = conn.Read(make([]byte, 10)) _ = conn.Close() if err == nil { diff --git a/errors_test.go b/errors_test.go index 05c2c2745..db3bffc59 100644 --- a/errors_test.go +++ b/errors_test.go @@ -65,21 +65,21 @@ func TestErrorNetError(t *testing.T) { {&HandshakeError{Err: errExample}, "handshake error: an example error", false, false}, {&HandshakeError{Err: &TimeoutError{Err: errExample}}, "handshake error: dtls timeout: an example error", true, true}, } - for _, c := range cases { - c := c - t.Run(fmt.Sprintf("%T", c.err), func(t *testing.T) { + for _, testCase := range cases { + testCase := testCase + t.Run(fmt.Sprintf("%T", testCase.err), func(t *testing.T) { var ne net.Error - if !errors.As(c.err, &ne) { - t.Fatalf("%T doesn't implement net.Error", c.err) + if !errors.As(testCase.err, &ne) { + t.Fatalf("%T doesn't implement net.Error", testCase.err) } - if ne.Timeout() != c.timeout { - t.Errorf("%T.Timeout() should be %v", c.err, c.timeout) + if ne.Timeout() != testCase.timeout { + t.Errorf("%T.Timeout() should be %v", testCase.err, testCase.timeout) } - if ne.Temporary() != c.temporary { //nolint:staticcheck - t.Errorf("%T.Temporary() should be %v", c.err, c.temporary) + if ne.Temporary() != testCase.temporary { //nolint:staticcheck + t.Errorf("%T.Temporary() should be %v", testCase.err, testCase.temporary) } - if ne.Error() != c.str { - t.Errorf("%T.Error() should be %v", c.err, c.str) + if ne.Error() != testCase.str { + t.Errorf("%T.Error() should be %v", testCase.err, testCase.str) } }) } diff --git a/examples/dial/cid/main.go b/examples/dial/cid/main.go new file mode 100644 index 000000000..15f316137 --- /dev/null +++ b/examples/dial/cid/main.go @@ -0,0 +1,57 @@ +// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> +// SPDX-License-Identifier: MIT + +// Package main implements an example DTLS client using a pre-shared key. +package main + +import ( + "context" + "fmt" + "net" + "time" + + "github.com/pion/dtls/v3" + "github.com/pion/dtls/v3/examples/util" +) + +func main() { + // Prepare the IP to connect to + addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444} + + // + // Everything below is the pion-DTLS API! Thanks for using it ❤️. + // + + // Prepare the configuration of the DTLS connection + config := &dtls.Config{ + PSK: func(hint []byte) ([]byte, error) { + fmt.Printf("Server's hint: %s \n", hint) + + return []byte{0xAB, 0xC1, 0x23}, nil + }, + PSKIdentityHint: []byte("Pion DTLS Client"), + CipherSuites: []dtls.CipherSuiteID{dtls.TLS_PSK_WITH_AES_128_CCM_8}, + ExtendedMasterSecret: dtls.RequireExtendedMasterSecret, + ConnectionIDGenerator: dtls.OnlySendCIDGenerator(), + } + + // Connect to a DTLS server + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + dtlsConn, err := dtls.Dial("udp", addr, config) + util.Check(err) + defer func() { + util.Check(dtlsConn.Close()) + }() + + if err := dtlsConn.HandshakeContext(ctx); err != nil { + fmt.Printf("Failed to handshake with server: %v\n", err) + + return + } + + fmt.Println("Connected; type 'exit' to shutdown gracefully") + + // Simulate a chat session + util.Chat(dtlsConn) +} diff --git a/examples/dial/psk/main.go b/examples/dial/psk/main.go index 1f4440d41..2847d40c7 100644 --- a/examples/dial/psk/main.go +++ b/examples/dial/psk/main.go @@ -10,8 +10,8 @@ import ( "net" "time" - "github.com/pion/dtls/v2" - "github.com/pion/dtls/v2/examples/util" + "github.com/pion/dtls/v3" + "github.com/pion/dtls/v3/examples/util" ) func main() { @@ -26,9 +26,10 @@ func main() { config := &dtls.Config{ PSK: func(hint []byte) ([]byte, error) { fmt.Printf("Server's hint: %s \n", hint) + return []byte{0xAB, 0xC1, 0x23}, nil }, - PSKIdentityHint: []byte("Pion DTLS Server"), + PSKIdentityHint: []byte{}, CipherSuites: []dtls.CipherSuiteID{dtls.TLS_PSK_WITH_AES_128_CCM_8}, ExtendedMasterSecret: dtls.RequireExtendedMasterSecret, } @@ -36,12 +37,18 @@ func main() { // Connect to a DTLS server ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - dtlsConn, err := dtls.DialWithContext(ctx, "udp", addr, config) + dtlsConn, err := dtls.Dial("udp", addr, config) util.Check(err) defer func() { util.Check(dtlsConn.Close()) }() + if err := dtlsConn.HandshakeContext(ctx); err != nil { + fmt.Printf("Failed to handshake with server: %v\n", err) + + return + } + fmt.Println("Connected; type 'exit' to shutdown gracefully") // Simulate a chat session diff --git a/examples/dial/selfsign/main.go b/examples/dial/selfsign/main.go index 5fa25a923..66ff80a70 100644 --- a/examples/dial/selfsign/main.go +++ b/examples/dial/selfsign/main.go @@ -11,9 +11,9 @@ import ( "net" "time" - "github.com/pion/dtls/v2" - "github.com/pion/dtls/v2/examples/util" - "github.com/pion/dtls/v2/pkg/crypto/selfsign" + "github.com/pion/dtls/v3" + "github.com/pion/dtls/v3/examples/util" + "github.com/pion/dtls/v3/pkg/crypto/selfsign" ) func main() { @@ -38,12 +38,18 @@ func main() { // Connect to a DTLS server ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - dtlsConn, err := dtls.DialWithContext(ctx, "udp", addr, config) + dtlsConn, err := dtls.Dial("udp", addr, config) util.Check(err) defer func() { util.Check(dtlsConn.Close()) }() + if err := dtlsConn.HandshakeContext(ctx); err != nil { + fmt.Printf("Failed to handshake with server: %v\n", err) + + return + } + fmt.Println("Connected; type 'exit' to shutdown gracefully") // Simulate a chat session diff --git a/examples/dial/verify/main.go b/examples/dial/verify/main.go index 07501954d..32992a9d1 100644 --- a/examples/dial/verify/main.go +++ b/examples/dial/verify/main.go @@ -12,8 +12,8 @@ import ( "net" "time" - "github.com/pion/dtls/v2" - "github.com/pion/dtls/v2/examples/util" + "github.com/pion/dtls/v3" + "github.com/pion/dtls/v3/examples/util" ) func main() { @@ -45,12 +45,18 @@ func main() { // Connect to a DTLS server ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - dtlsConn, err := dtls.DialWithContext(ctx, "udp", addr, config) + dtlsConn, err := dtls.Dial("udp", addr, config) util.Check(err) defer func() { util.Check(dtlsConn.Close()) }() + if err := dtlsConn.HandshakeContext(ctx); err != nil { + fmt.Printf("Failed to handshake with server: %v\n", err) + + return + } + fmt.Println("Connected; type 'exit' to shutdown gracefully") // Simulate a chat session diff --git a/examples/listen/cid/main.go b/examples/listen/cid/main.go new file mode 100644 index 000000000..2c0b41ee6 --- /dev/null +++ b/examples/listen/cid/main.go @@ -0,0 +1,78 @@ +// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> +// SPDX-License-Identifier: MIT + +// Package main implements a DTLS server using a pre-shared key. +package main + +import ( + "context" + "fmt" + "net" + "time" + + "github.com/pion/dtls/v3" + "github.com/pion/dtls/v3/examples/util" +) + +func main() { + // Prepare the IP to connect to + addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444} + + // + // Everything below is the pion-DTLS API! Thanks for using it ❤️. + // + + // Prepare the configuration of the DTLS connection + config := &dtls.Config{ + PSK: func(hint []byte) ([]byte, error) { + fmt.Printf("Client's hint: %s \n", hint) + + return []byte{0xAB, 0xC1, 0x23}, nil + }, + PSKIdentityHint: []byte("Pion DTLS Server"), + CipherSuites: []dtls.CipherSuiteID{dtls.TLS_PSK_WITH_AES_128_CCM_8}, + ExtendedMasterSecret: dtls.RequireExtendedMasterSecret, + ConnectionIDGenerator: dtls.RandomCIDGenerator(8), + } + + // Connect to a DTLS server + listener, err := dtls.Listen("udp", addr, config) + util.Check(err) + defer func() { + util.Check(listener.Close()) + }() + + fmt.Println("Listening") + + // Simulate a chat session + hub := util.NewHub() + + go func() { + for { + // Wait for a connection. + conn, err := listener.Accept() + util.Check(err) + // defer conn.Close() // TODO: graceful shutdown + + // `conn` is of type `net.Conn` but may be casted to `dtls.Conn` + // using `dtlsConn := conn.(*dtls.Conn)` in order to to expose + // functions like `ConnectionState` etc. + + // Perform the handshake with a 30-second timeout + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + dtlsConn, ok := conn.(*dtls.Conn) + if ok { + util.Check(dtlsConn.HandshakeContext(ctx)) + } + cancel() + + // Register the connection with the chat hub + if err == nil { + hub.Register(conn) + } + } + }() + + // Start chatting + hub.Chat() +} diff --git a/examples/listen/psk/main.go b/examples/listen/psk/main.go index ff45b7d86..4098e0dfe 100644 --- a/examples/listen/psk/main.go +++ b/examples/listen/psk/main.go @@ -10,18 +10,14 @@ import ( "net" "time" - "github.com/pion/dtls/v2" - "github.com/pion/dtls/v2/examples/util" + "github.com/pion/dtls/v3" + "github.com/pion/dtls/v3/examples/util" ) func main() { // Prepare the IP to connect to addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444} - // Create parent context to cleanup handshaking connections on exit. - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - // // Everything below is the pion-DTLS API! Thanks for using it ❤️. // @@ -30,15 +26,12 @@ func main() { config := &dtls.Config{ PSK: func(hint []byte) ([]byte, error) { fmt.Printf("Client's hint: %s \n", hint) + return []byte{0xAB, 0xC1, 0x23}, nil }, - PSKIdentityHint: []byte("Pion DTLS Client"), + PSKIdentityHint: []byte("Pion DTLS Server"), CipherSuites: []dtls.CipherSuiteID{dtls.TLS_PSK_WITH_AES_128_CCM_8}, ExtendedMasterSecret: dtls.RequireExtendedMasterSecret, - // Create timeout context for accepted connection. - ConnectContextMaker: func() (context.Context, func()) { - return context.WithTimeout(ctx, 30*time.Second) - }, } // Connect to a DTLS server @@ -64,6 +57,14 @@ func main() { // using `dtlsConn := conn.(*dtls.Conn)` in order to to expose // functions like `ConnectionState` etc. + // Perform the handshake with a 30-second timeout + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + dtlsConn, ok := conn.(*dtls.Conn) + if ok { + util.Check(dtlsConn.HandshakeContext(ctx)) + } + cancel() + // Register the connection with the chat hub if err == nil { hub.Register(conn) diff --git a/examples/listen/selfsign/main.go b/examples/listen/selfsign/main.go index 025b667e4..2243d722e 100644 --- a/examples/listen/selfsign/main.go +++ b/examples/listen/selfsign/main.go @@ -11,9 +11,9 @@ import ( "net" "time" - "github.com/pion/dtls/v2" - "github.com/pion/dtls/v2/examples/util" - "github.com/pion/dtls/v2/pkg/crypto/selfsign" + "github.com/pion/dtls/v3" + "github.com/pion/dtls/v3/examples/util" + "github.com/pion/dtls/v3/pkg/crypto/selfsign" ) func main() { @@ -24,10 +24,6 @@ func main() { certificate, genErr := selfsign.GenerateSelfSigned() util.Check(genErr) - // Create parent context to cleanup handshaking connections on exit. - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - // // Everything below is the pion-DTLS API! Thanks for using it ❤️. // @@ -36,10 +32,6 @@ func main() { config := &dtls.Config{ Certificates: []tls.Certificate{certificate}, ExtendedMasterSecret: dtls.RequireExtendedMasterSecret, - // Create timeout context for accepted connection. - ConnectContextMaker: func() (context.Context, func()) { - return context.WithTimeout(ctx, 30*time.Second) - }, } // Connect to a DTLS server @@ -65,6 +57,14 @@ func main() { // using `dtlsConn := conn.(*dtls.Conn)` in order to to expose // functions like `ConnectionState` etc. + // Perform the handshake with a 30-second timeout + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + dtlsConn, ok := conn.(*dtls.Conn) + if ok { + util.Check(dtlsConn.HandshakeContext(ctx)) + } + cancel() + // Register the connection with the chat hub if err == nil { hub.Register(conn) diff --git a/examples/listen/verify-brute-force-protection/main.go b/examples/listen/verify-brute-force-protection/main.go new file mode 100644 index 000000000..606179b34 --- /dev/null +++ b/examples/listen/verify-brute-force-protection/main.go @@ -0,0 +1,134 @@ +// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> +// SPDX-License-Identifier: MIT + +// Package main implements an example DTLS server which verifies client certificates. +// It also implements a basic Brute Force Attack protection. +package main + +import ( + "context" + "crypto/tls" + "crypto/x509" + "fmt" + "net" + "sync" + "time" + + "github.com/pion/dtls/v3" + "github.com/pion/dtls/v3/examples/util" +) + +func main() { + // Prepare the IP to connect to + addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444} + + // + // Everything below is the pion-DTLS API! Thanks for using it ❤️. + // + + // ************ Variables used to implement a basic Brute Force Attack protection ************* + var ( + attempts = make(map[string]int) // Map of attempts for each IP address. + attemptsMutex sync.Mutex // Mutex for the map of attempts. + attemptsCleaner = time.Now() // Time to be able to clean the map of attempts every X minutes. + ) + + certificate, err := util.LoadKeyAndCertificate("examples/certificates/server.pem", + "examples/certificates/server.pub.pem") + util.Check(err) + + rootCertificate, err := util.LoadCertificate("examples/certificates/server.pub.pem") + util.Check(err) + certPool := x509.NewCertPool() + cert, err := x509.ParseCertificate(rootCertificate.Certificate[0]) + util.Check(err) + certPool.AddCert(cert) + + // Prepare the configuration of the DTLS connection + config := &dtls.Config{ + Certificates: []tls.Certificate{certificate}, + ExtendedMasterSecret: dtls.RequireExtendedMasterSecret, + ClientAuth: dtls.RequireAndVerifyClientCert, + ClientCAs: certPool, + // This function will be called on each connection attempt. + OnConnectionAttempt: func(addr net.Addr) error { + // *************** Brute Force Attack protection *************** + // Check if the IP address is in the map, and if the IP address has exceeded the limit + attemptsMutex.Lock() + defer attemptsMutex.Unlock() + // Here I implement a time cleaner for the map of attempts, every 5 minutes I will + // decrement by 1 the number of attempts for each IP address. + if time.Now().After(attemptsCleaner.Add(time.Minute * 5)) { + attemptsCleaner = time.Now() + for k, v := range attempts { + if v > 0 { + attempts[k]-- + } + if attempts[k] == 0 { + delete(attempts, k) + } + } + } + // Check if the IP address is in the map, and the IP address has exceeded the limit (Brute Force Attack protection) + attemptIP := addr.(*net.UDPAddr).IP.String() //nolint + if attempts[attemptIP] > 10 { + return fmt.Errorf("too many attempts from this IP address") //nolint + } + // Here I increment the number of attempts for this IP address (Brute Force Attack protection) + attempts[attemptIP]++ + // *************** END Brute Force Attack protection END *************** + return nil + }, + } + + // Connect to a DTLS server + listener, err := dtls.Listen("udp", addr, config) + util.Check(err) + defer func() { + util.Check(listener.Close()) + }() + + fmt.Println("Listening") + + // Simulate a chat session + hub := util.NewHub() + + go func() { + for { + // Wait for a connection. + conn, err := listener.Accept() + util.Check(err) + // defer conn.Close() // TODO: graceful shutdown + + // `conn` is of type `net.Conn` but may be casted to `dtls.Conn` + // using `dtlsConn := conn.(*dtls.Conn)` in order to to expose + // functions like `ConnectionState` etc. + + // *************** Brute Force Attack protection *************** + // Here I decrease the number of attempts for this IP address + attemptsMutex.Lock() + attemptIP := conn.(*dtls.Conn).RemoteAddr().(*net.UDPAddr).IP.String() //nolint + attempts[attemptIP]-- + // If the number of attempts for this IP address is 0, I delete the IP address from the map + if attempts[attemptIP] == 0 { + delete(attempts, attemptIP) + } + attemptsMutex.Unlock() + // *************** END Brute Force Attack protection END *************** + + // Perform the handshake with a 30-second timeout + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + dtlsConn, ok := conn.(*dtls.Conn) + if ok { + util.Check(dtlsConn.HandshakeContext(ctx)) + } + cancel() + + // Register the connection with the chat hub + hub.Register(conn) + } + }() + + // Start chatting + hub.Chat() +} diff --git a/examples/listen/verify/main.go b/examples/listen/verify/main.go index a02211e15..7d7067a27 100644 --- a/examples/listen/verify/main.go +++ b/examples/listen/verify/main.go @@ -12,18 +12,14 @@ import ( "net" "time" - "github.com/pion/dtls/v2" - "github.com/pion/dtls/v2/examples/util" + "github.com/pion/dtls/v3" + "github.com/pion/dtls/v3/examples/util" ) func main() { // Prepare the IP to connect to addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444} - // Create parent context to cleanup handshaking connections on exit. - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - // // Everything below is the pion-DTLS API! Thanks for using it ❤️. // @@ -45,10 +41,6 @@ func main() { ExtendedMasterSecret: dtls.RequireExtendedMasterSecret, ClientAuth: dtls.RequireAndVerifyClientCert, ClientCAs: certPool, - // Create timeout context for accepted connection. - ConnectContextMaker: func() (context.Context, func()) { - return context.WithTimeout(ctx, 30*time.Second) - }, } // Connect to a DTLS server @@ -74,6 +66,14 @@ func main() { // using `dtlsConn := conn.(*dtls.Conn)` in order to to expose // functions like `ConnectionState` etc. + // Perform the handshake with a 30-second timeout + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + dtlsConn, ok := conn.(*dtls.Conn) + if ok { + util.Check(dtlsConn.HandshakeContext(ctx)) + } + cancel() + // Register the connection with the chat hub hub.Register(conn) } diff --git a/examples/util/hub.go b/examples/util/hub.go index 19e33db0c..4e656a21f 100644 --- a/examples/util/hub.go +++ b/examples/util/hub.go @@ -12,18 +12,18 @@ import ( "sync" ) -// Hub is a helper to handle one to many chat +// Hub is a helper to handle one to many chat. type Hub struct { conns map[string]net.Conn lock sync.RWMutex } -// NewHub builds a new hub +// NewHub builds a new hub. func NewHub() *Hub { return &Hub{conns: make(map[string]net.Conn)} } -// Register adds a new conn to the Hub +// Register adds a new conn to the Hub. func (h *Hub) Register(conn net.Conn) { fmt.Printf("Connected to %s\n", conn.RemoteAddr()) h.lock.Lock() @@ -40,6 +40,7 @@ func (h *Hub) readLoop(conn net.Conn) { n, err := conn.Read(b) if err != nil { h.unregister(conn) + return } fmt.Printf("Got message: %s\n", string(b[:n])) @@ -69,7 +70,7 @@ func (h *Hub) broadcast(msg []byte) { } } -// Chat starts the stdin readloop to dispatch messages to the hub +// Chat starts the stdin readloop to dispatch messages to the hub. func (h *Hub) Chat() { reader := bufio.NewReader(os.Stdin) for { diff --git a/examples/util/util.go b/examples/util/util.go index f6a9642c0..e3a85cb0e 100644 --- a/examples/util/util.go +++ b/examples/util/util.go @@ -11,7 +11,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "net" "os" "path/filepath" @@ -25,7 +24,7 @@ var ( errNoCertificateFound = errors.New("no certificate found, unable to load certificates") ) -// Chat simulates a simple text chat session over the connection +// Chat simulates a simple text chat session over the connection. func Chat(conn io.ReadWriter) { go func() { b := make([]byte, bufSize) @@ -52,7 +51,7 @@ func Chat(conn io.ReadWriter) { } } -// Check is a helper to throw errors in the examples +// Check is a helper to throw errors in the examples. func Check(err error) { var netError net.Error if errors.As(err, &netError) && netError.Temporary() { //nolint:staticcheck @@ -63,14 +62,14 @@ func Check(err error) { } } -// LoadKeyAndCertificate reads certificates or key from file +// LoadKeyAndCertificate reads certificates or key from file. func LoadKeyAndCertificate(keyPath string, certificatePath string) (tls.Certificate, error) { return tls.LoadX509KeyPair(certificatePath, keyPath) } -// LoadCertificate Load/read certificate(s) from file +// LoadCertificate Load/read certificate(s) from file. func LoadCertificate(path string) (*tls.Certificate, error) { - rawData, err := ioutil.ReadFile(filepath.Clean(path)) + rawData, err := os.ReadFile(filepath.Clean(path)) if err != nil { return nil, err } diff --git a/flight.go b/flight.go index cfa58c574..7ecc9489d 100644 --- a/flight.go +++ b/flight.go @@ -70,7 +70,7 @@ const ( flight6 ) -func (f flightVal) String() string { +func (f flightVal) String() string { //nolint:cyclop switch f { case flight0: return "Flight 0" diff --git a/flight0handler.go b/flight0handler.go index ec766ddff..ce6ad2031 100644 --- a/flight0handler.go +++ b/flight0handler.go @@ -7,14 +7,21 @@ import ( "context" "crypto/rand" - "github.com/pion/dtls/v2/pkg/crypto/elliptic" - "github.com/pion/dtls/v2/pkg/protocol" - "github.com/pion/dtls/v2/pkg/protocol/alert" - "github.com/pion/dtls/v2/pkg/protocol/extension" - "github.com/pion/dtls/v2/pkg/protocol/handshake" + "github.com/pion/dtls/v3/pkg/crypto/elliptic" + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/alert" + "github.com/pion/dtls/v3/pkg/protocol/extension" + "github.com/pion/dtls/v3/pkg/protocol/handshake" ) -func flight0Parse(_ context.Context, _ flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { +//nolint:cyclop +func flight0Parse( + _ context.Context, + _ flightConn, + state *State, + cache *handshakeCache, + cfg *handshakeConfig, +) (flightVal, *alert.Alert, error) { seq, msgs, ok := cache.fullPullMap(0, state.cipherSuite, handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, ) @@ -22,6 +29,12 @@ func flight0Parse(_ context.Context, _ flightConn, state *State, cache *handshak // No valid message received. Keep reading return 0, nil, nil } + + // Connection Identifiers must be negotiated afresh on session resumption. + // https://datatracker.ietf.org/doc/html/rfc9146#name-the-connection_id-extension + state.setLocalConnectionID(nil) + state.remoteConnectionID = nil + state.handshakeRecvSequence = seq var clientHello *handshake.MessageClientHello @@ -49,29 +62,42 @@ func flight0Parse(_ context.Context, _ flightConn, state *State, cache *handshak } for _, val := range clientHello.Extensions { - switch e := val.(type) { + switch ext := val.(type) { case *extension.SupportedEllipticCurves: - if len(e.EllipticCurves) == 0 { + if len(ext.EllipticCurves) == 0 { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errNoSupportedEllipticCurves } - state.namedCurve = e.EllipticCurves[0] + state.namedCurve = ext.EllipticCurves[0] case *extension.UseSRTP: - profile, ok := findMatchingSRTPProfile(e.ProtectionProfiles, cfg.localSRTPProtectionProfiles) + profile, ok := findMatchingSRTPProfile(ext.ProtectionProfiles, cfg.localSRTPProtectionProfiles) if !ok { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errServerNoMatchingSRTPProfile } - state.srtpProtectionProfile = profile + state.setSRTPProtectionProfile(profile) + state.remoteSRTPMasterKeyIdentifier = ext.MasterKeyIdentifier case *extension.UseExtendedMasterSecret: if cfg.extendedMasterSecret != DisableExtendedMasterSecret { state.extendedMasterSecret = true } case *extension.ServerName: - state.serverName = e.ServerName // remote server name + state.serverName = ext.ServerName // remote server name case *extension.ALPN: - state.peerSupportedProtocols = e.ProtocolNameList + state.peerSupportedProtocols = ext.ProtocolNameList + case *extension.ConnectionID: + // Only set connection ID to be sent if server supports connection + // IDs. + if cfg.connectionIDGenerator != nil { + state.remoteConnectionID = ext.CID + } } } + // If the client doesn't support connection IDs, the server should not + // expect one to be sent. + if state.remoteConnectionID == nil { + state.setLocalConnectionID(nil) + } + if cfg.extendedMasterSecret == RequireExtendedMasterSecret && !state.extendedMasterSecret { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errServerRequiredButNoClientEMS } @@ -93,7 +119,12 @@ func flight0Parse(_ context.Context, _ flightConn, state *State, cache *handshak return handleHelloResume(clientHello.SessionID, state, cfg, nextFlight) } -func handleHelloResume(sessionID []byte, state *State, cfg *handshakeConfig, next flightVal) (flightVal, *alert.Alert, error) { +func handleHelloResume( + sessionID []byte, + state *State, + cfg *handshakeConfig, + next flightVal, +) (flightVal, *alert.Alert, error) { if len(sessionID) > 0 && cfg.sessionStore != nil { if s, err := cfg.sessionStore.Get(sessionID); err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err @@ -113,10 +144,16 @@ func handleHelloResume(sessionID []byte, state *State, cfg *handshakeConfig, nex return flight4b, nil, nil } } + return next, nil, nil } -func flight0Generate(_ flightConn, state *State, _ *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) { +func flight0Generate( + _ flightConn, + state *State, + _ *handshakeCache, + cfg *handshakeConfig, +) ([]*packet, *alert.Alert, error) { // Initialize if !cfg.insecureSkipHelloVerify { state.cookie = make([]byte, cookieLength) diff --git a/flight1handler.go b/flight1handler.go index 94fdc222d..6c55a6430 100644 --- a/flight1handler.go +++ b/flight1handler.go @@ -6,15 +6,21 @@ package dtls import ( "context" - "github.com/pion/dtls/v2/pkg/crypto/elliptic" - "github.com/pion/dtls/v2/pkg/protocol" - "github.com/pion/dtls/v2/pkg/protocol/alert" - "github.com/pion/dtls/v2/pkg/protocol/extension" - "github.com/pion/dtls/v2/pkg/protocol/handshake" - "github.com/pion/dtls/v2/pkg/protocol/recordlayer" + "github.com/pion/dtls/v3/pkg/crypto/elliptic" + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/alert" + "github.com/pion/dtls/v3/pkg/protocol/extension" + "github.com/pion/dtls/v3/pkg/protocol/handshake" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) -func flight1Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { +func flight1Parse( + ctx context.Context, + conn flightConn, + state *State, + cache *handshakeCache, + cfg *handshakeConfig, +) (flightVal, *alert.Alert, error) { // HelloVerifyRequest can be skipped by the server, // so allow ServerHello during flight1 also seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite, @@ -29,7 +35,7 @@ func flight1Parse(ctx context.Context, c flightConn, state *State, cache *handsh if _, ok := msgs[handshake.TypeServerHello]; ok { // Flight1 and flight2 were skipped. // Parse as flight3. - return flight3Parse(ctx, c, state, cache, cfg) + return flight3Parse(ctx, conn, state, cache, cfg) } if h, ok := msgs[handshake.TypeHelloVerifyRequest].(*handshake.MessageHelloVerifyRequest); ok { @@ -40,13 +46,20 @@ func flight1Parse(ctx context.Context, c flightConn, state *State, cache *handsh } state.cookie = append([]byte{}, h.Cookie...) state.handshakeRecvSequence = seq + return flight3, nil, nil } return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil } -func flight1Generate(c flightConn, state *State, _ *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) { +//nolint:cyclop +func flight1Generate( + conn flightConn, + state *State, + _ *handshakeCache, + cfg *handshakeConfig, +) ([]*packet, *alert.Alert, error) { var zeroEpoch uint16 state.localEpoch.Store(zeroEpoch) state.remoteEpoch.Store(zeroEpoch) @@ -57,6 +70,10 @@ func flight1Generate(c flightConn, state *State, _ *handshakeCache, cfg *handsha return nil, nil, err } + if cfg.helloRandomBytesGenerator != nil { + state.localRandom.RandomBytes = cfg.helloRandomBytesGenerator() + } + extensions := []extension.Extension{ &extension.SupportedSignatureAlgorithms{ SignatureHashAlgorithms: cfg.localSignatureSchemes, @@ -70,6 +87,7 @@ func flight1Generate(c flightConn, state *State, _ *handshakeCache, cfg *handsha for _, c := range cfg.localCipherSuites { if c.ECC() { setEllipticCurveCryptographyClientHelloExtensions = true + break } } @@ -87,7 +105,8 @@ func flight1Generate(c flightConn, state *State, _ *handshakeCache, cfg *handsha if len(cfg.localSRTPProtectionProfiles) > 0 { extensions = append(extensions, &extension.UseSRTP{ - ProtectionProfiles: cfg.localSRTPProtectionProfiles, + ProtectionProfiles: cfg.localSRTPProtectionProfiles, + MasterKeyIdentifier: cfg.localSRTPMasterKeyIdentifier, }) } @@ -108,7 +127,7 @@ func flight1Generate(c flightConn, state *State, _ *handshakeCache, cfg *handsha if cfg.sessionStore != nil { cfg.log.Tracef("[handshake] try to resume session") - if s, err := cfg.sessionStore.Get(c.sessionKey()); err != nil { + if s, err := cfg.sessionStore.Get(conn.sessionKey()); err != nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } else if s.ID != nil { cfg.log.Tracef("[handshake] get saved session: %x", s.ID) @@ -118,23 +137,46 @@ func flight1Generate(c flightConn, state *State, _ *handshakeCache, cfg *handsha } } + // If we have a connection ID generator, use it. The CID may be zero length, + // in which case we are just requesting that the server send us a CID to + // use. + if cfg.connectionIDGenerator != nil { + state.setLocalConnectionID(cfg.connectionIDGenerator()) + // The presence of a generator indicates support for connection IDs. We + // use the presence of a non-nil local CID in flight 3 to determine + // whether we send a CID in the second ClientHello, so we convert any + // nil CID returned by a generator to []byte{}. + if state.getLocalConnectionID() == nil { + state.setLocalConnectionID([]byte{}) + } + extensions = append(extensions, &extension.ConnectionID{CID: state.getLocalConnectionID()}) + } + + clientHello := &handshake.MessageClientHello{ + Version: protocol.Version1_2, + SessionID: state.SessionID, + Cookie: state.cookie, + Random: state.localRandom, + CipherSuiteIDs: cipherSuiteIDs(cfg.localCipherSuites), + CompressionMethods: defaultCompressionMethods(), + Extensions: extensions, + } + + var content handshake.Handshake + + if cfg.clientHelloMessageHook != nil { + content = handshake.Handshake{Message: cfg.clientHelloMessageHook(*clientHello)} + } else { + content = handshake.Handshake{Message: clientHello} + } + return []*packet{ { record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, }, - Content: &handshake.Handshake{ - Message: &handshake.MessageClientHello{ - Version: protocol.Version1_2, - SessionID: state.SessionID, - Cookie: state.cookie, - Random: state.localRandom, - CipherSuiteIDs: cipherSuiteIDs(cfg.localCipherSuites), - CompressionMethods: defaultCompressionMethods(), - Extensions: extensions, - }, - }, + Content: &content, }, }, }, nil, nil diff --git a/flight1handler_test.go b/flight1handler_test.go new file mode 100644 index 000000000..457ee413b --- /dev/null +++ b/flight1handler_test.go @@ -0,0 +1,285 @@ +// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> +// SPDX-License-Identifier: MIT + +package dtls + +import ( + "context" + "testing" + "time" + + "github.com/pion/dtls/v3/internal/ciphersuite" + "github.com/pion/dtls/v3/pkg/protocol/alert" + "github.com/pion/dtls/v3/pkg/protocol/handshake" + "github.com/pion/logging" + "github.com/pion/transport/v3/test" +) + +type flight1TestMockFlightConn struct{} + +func (f *flight1TestMockFlightConn) notify(context.Context, alert.Level, alert.Description) error { + return nil +} +func (f *flight1TestMockFlightConn) writePackets(context.Context, []*packet) error { return nil } +func (f *flight1TestMockFlightConn) recvHandshake() <-chan recvHandshakeState { return nil } +func (f *flight1TestMockFlightConn) setLocalEpoch(uint16) {} +func (f *flight1TestMockFlightConn) handleQueuedPackets(context.Context) error { return nil } +func (f *flight1TestMockFlightConn) sessionKey() []byte { return nil } + +type flight1TestMockCipherSuite struct { + ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256 + + t *testing.T +} + +func (f *flight1TestMockCipherSuite) IsInitialized() bool { + f.t.Fatal("IsInitialized called with Certificate but not CertificateVerify") + + return true +} + +// When "server hello" arrives later than "certificate", +// "server key exchange", "certificate request", "server hello done", +// is it normal for the flight1Parse method to handle it. +func TestFlight1_Process_ServerHelloLateArrival(t *testing.T) { //nolint:maintidx + // Limit runtime in case of deadlocks + lim := test.TimeOut(5 * time.Second) + defer lim.Stop() + + // Check for leaking routines + report := test.CheckRoutines(t) + defer report() + + mockConn := &flight1TestMockFlightConn{} + state := &State{ + cipherSuite: &flight1TestMockCipherSuite{t: t}, + } + cache := newHandshakeCache() + cfg := &handshakeConfig{ + localSRTPProtectionProfiles: []SRTPProtectionProfile{SRTP_AEAD_AES_128_GCM}, + localCipherSuites: []CipherSuite{}, + } + cfg.localCipherSuites = []CipherSuite{cipherSuiteForID(TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, nil)} + cfg.log = logging.NewDefaultLoggerFactory().NewLogger("dtls") + + serverHello := []byte{ + 0x02, 0x00, 0x00, 0x62, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x62, 0xfe, 0xfd, 0x07, 0x46, 0xb7, 0xbf, 0xde, 0x78, + 0xab, 0x38, 0x69, 0x36, 0x74, 0x10, 0xa6, 0x50, 0x67, 0x7b, + 0x4b, 0x85, 0xdf, 0x71, 0x71, 0x62, 0x3a, 0xb1, 0xd7, 0xa4, + 0x79, 0x6a, 0x38, 0x13, 0x5e, 0xa1, 0x20, 0xbd, 0x64, 0xaf, + 0xb3, 0x36, 0x77, 0x73, 0x8a, 0x62, 0x75, 0xb2, 0x64, 0xbe, + 0xf6, 0x2a, 0xb1, 0x6e, 0x7b, 0xf6, 0x00, 0xd6, 0x24, 0xd5, + 0xb1, 0x1e, 0x54, 0xa3, 0x76, 0xb3, 0xac, 0x76, 0x8f, 0xc0, + 0x2f, 0x00, 0x00, 0x1a, 0xff, 0x01, 0x00, 0x01, 0x00, 0x00, + 0x0b, 0x00, 0x04, 0x03, 0x00, 0x01, 0x02, 0x00, 0x0e, 0x00, + 0x05, 0x00, 0x02, 0x00, 0x07, 0x00, 0x00, 0x17, 0x00, 0x00, + } + certificate1 := []byte{ + 0x0b, 0x00, 0x05, 0x5b, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, + 0x04, 0xe4, 0x00, 0x05, 0x58, 0x00, 0x05, 0x55, 0x30, 0x82, + 0x05, 0x51, 0x30, 0x82, 0x04, 0x39, 0xa0, 0x03, 0x02, 0x01, + 0x02, 0x02, 0x0c, 0x56, 0x8b, 0xb4, 0x68, 0xed, 0x70, 0xce, + 0xb6, 0x8d, 0x44, 0x65, 0x4b, 0x30, 0x0d, 0x06, 0x09, 0x2a, + 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x0b, 0x05, 0x00, + 0x30, 0x66, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, + 0x06, 0x13, 0x02, 0x42, 0x45, 0x31, 0x19, 0x30, 0x17, 0x06, + 0x03, 0x55, 0x04, 0x0a, 0x13, 0x10, 0x47, 0x6c, 0x6f, 0x62, + 0x61, 0x6c, 0x53, 0x69, 0x67, 0x6e, 0x20, 0x6e, 0x76, 0x2d, + 0x73, 0x61, 0x31, 0x3c, 0x30, 0x3a, 0x06, 0x03, 0x55, 0x04, + 0x03, 0x13, 0x33, 0x47, 0x6c, 0x6f, 0x62, 0x61, 0x6c, 0x53, + 0x69, 0x67, 0x6e, 0x20, 0x4f, 0x72, 0x67, 0x61, 0x6e, 0x69, + 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x20, 0x56, 0x61, 0x6c, + 0x69, 0x64, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x20, 0x43, 0x41, + 0x20, 0x2d, 0x20, 0x53, 0x48, 0x41, 0x32, 0x35, 0x36, 0x20, + 0x2d, 0x20, 0x47, 0x32, 0x30, 0x1e, 0x17, 0x0d, 0x31, 0x37, + 0x30, 0x34, 0x32, 0x30, 0x31, 0x31, 0x31, 0x39, 0x35, 0x39, + 0x5a, 0x17, 0x0d, 0x31, 0x38, 0x30, 0x34, 0x32, 0x31, 0x31, + 0x31, 0x31, 0x39, 0x35, 0x39, 0x5a, 0x30, 0x81, 0x84, 0x31, + 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13, 0x02, + 0x43, 0x4e, 0x31, 0x12, 0x30, 0x10, 0x06, 0x03, 0x55, 0x04, + 0x08, 0x13, 0x09, 0x67, 0x75, 0x61, 0x6e, 0x67, 0x64, 0x6f, + 0x6e, 0x67, 0x31, 0x11, 0x30, 0x0f, 0x06, 0x03, 0x55, 0x04, + 0x07, 0x13, 0x08, 0x73, 0x68, 0x65, 0x6e, 0x7a, 0x68, 0x65, + 0x6e, 0x31, 0x36, 0x30, 0x34, 0x06, 0x03, 0x55, 0x04, 0x0a, + 0x13, 0x2d, 0x54, 0x65, 0x6e, 0x63, 0x65, 0x6e, 0x74, 0x20, + 0x54, 0x65, 0x63, 0x68, 0x6e, 0x6f, 0x6c, 0x6f, 0x67, 0x79, + 0x20, 0x28, 0x53, 0x68, 0x65, 0x6e, 0x7a, 0x68, 0x65, 0x6e, + 0x29, 0x20, 0x43, 0x6f, 0x6d, 0x70, 0x61, 0x6e, 0x79, 0x20, + 0x4c, 0x69, 0x6d, 0x69, 0x74, 0x65, 0x64, 0x31, 0x16, 0x30, + 0x14, 0x06, 0x03, 0x55, 0x04, 0x03, 0x13, 0x0d, 0x77, 0x65, + 0x62, 0x72, 0x74, 0x63, 0x2e, 0x71, 0x71, 0x2e, 0x63, 0x6f, + 0x6d, 0x30, 0x82, 0x01, 0x22, 0x30, 0x0d, 0x06, 0x09, 0x2a, + 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x01, 0x05, 0x00, + 0x03, 0x82, 0x01, 0x0f, 0x00, 0x30, 0x82, 0x01, 0x0a, 0x02, + 0x82, 0x01, 0x01, 0x00, 0xb6, 0x00, 0xa7, 0x09, 0x0a, 0xc4, + 0x96, 0x24, 0x72, 0xa0, 0x09, 0xda, 0xac, 0x63, 0xe4, 0x9a, + 0xfe, 0x8b, 0x9b, 0x99, 0x8c, 0xe3, 0xab, 0x4b, 0x7c, 0xbd, + 0x4f, 0x31, 0x1e, 0x2f, 0xff, 0x34, 0x54, 0xb5, 0xb0, 0x99, + 0xcd, 0x00, 0x7c, 0x5b, 0x12, 0x96, 0xfa, 0x9b, 0x6b, 0x79, + 0xc7, 0xfb, 0x00, 0x53, 0xaf, 0xb6, 0x00, 0x45, 0x46, 0x20, + 0x7d, 0x95, 0xca, 0x86, 0xcc, 0x4b, 0xe8, 0x25, 0x52, 0x5b, + 0x9c, 0xe7, 0x58, 0xcd, 0xd0, 0x8f, 0x4a, 0xd8, 0x77, 0x7d, + 0x45, 0xa0, 0x70, 0xe8, 0x16, 0x45, 0x23, 0xfb, 0xbc, 0x43, + 0x36, 0xdd, 0x5b, 0x8f, 0x01, 0xc3, 0xc0, 0xa2, 0xab, 0x80, + 0xf1, 0x97, 0x72, 0x38, 0xab, 0x6f, 0xa1, 0x28, 0x09, 0xdd, + 0x31, 0x7e, 0x50, 0xc8, 0x51, 0xde, 0x8d, 0x05, 0xbc, 0x72, + 0x79, 0x94, 0x6e, 0xd4, 0xb7, 0xf0, 0x97, 0xd0, 0x76, 0x9c, + 0x9d, 0xb4, 0x34, 0xf1, 0x8a, 0x82, 0x20, 0x9b, 0x24, 0x4b, + 0x38, 0xc9, 0x63, 0xe6, 0x02, 0xf5, 0xb2, 0x9b, 0x70, 0xa4, + 0x97, 0x9f, 0xaa, 0x1f, 0x36, 0x9c, 0xfd, 0x81, 0x93, 0x81, + 0xd7, 0x4e, 0xca, 0xd2, 0xa7, 0x7c, 0x29, 0x9d, 0x28, 0xf2, + 0x3e, 0x3b, 0xea, 0xe6, 0x22, 0x51, 0x8f, 0x0b, 0xe7, 0x65, + 0xa1, 0x28, 0xdd, 0x55, 0x6a, 0x59, 0x53, 0x67, 0xb6, 0xb3, + 0xd2, 0x4c, 0x90, 0x69, 0xd1, 0x1e, 0x62, 0xab, 0x33, 0x47, + 0x29, 0x45, 0x18, 0x1f, 0xeb, 0x6d, 0x13, 0xb4, 0x61, 0xf5, + 0x15, 0x03, 0xf7, 0x4f, 0x9c, 0x4c, 0x2c, 0xae, 0x5e, 0xde, + 0xd2, 0x11, 0x32, 0xb5, 0x17, 0xb5, 0xe8, 0xa3, 0xb2, 0x1f, + 0xc3, 0x9f, 0x78, 0xa1, 0xf5, 0x80, 0xb4, 0x96, 0x90, 0x6b, + 0x77, 0x9e, 0xe9, 0x39, 0x61, 0x2c, 0x18, 0xf5, 0x7b, 0xab, + 0x1e, 0x09, 0x88, 0x7d, 0xc3, 0x75, 0x5e, 0x4d, 0xcf, 0xf3, + 0x02, 0x03, 0x01, 0x00, 0x01, 0xa3, 0x82, 0x01, 0xde, 0x30, + 0x82, 0x01, 0xda, 0x30, 0x0e, 0x06, 0x03, 0x55, 0x1d, 0x0f, + 0x01, 0x01, 0xff, 0x04, 0x04, 0x03, 0x02, 0x05, 0xa0, 0x30, + 0x81, 0xa0, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, + 0x01, 0x01, 0x04, 0x81, 0x93, 0x30, 0x81, 0x90, 0x30, 0x4d, + 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x30, 0x02, + 0x86, 0x41, 0x68, 0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x73, + 0x65, 0x63, 0x75, 0x72, 0x65, 0x2e, 0x67, 0x6c, 0x6f, 0x62, + 0x61, 0x6c, 0x73, 0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d, + 0x2f, 0x63, 0x61, 0x63, 0x65, 0x72, 0x74, 0x2f, 0x67, 0x73, + 0x6f, 0x72, 0x67, 0x61, 0x6e, 0x69, 0x7a, 0x61, 0x74, 0x69, + 0x6f, 0x6e, 0x76, 0x61, 0x6c, 0x73, 0x68, 0x61, 0x32, 0x67, + 0x32, 0x72, 0x31, 0x2e, 0x63, 0x72, 0x74, 0x30, 0x3f, 0x06, + 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x30, 0x01, 0x86, + 0x33, 0x68, 0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x6f, 0x63, + 0x73, 0x70, 0x32, 0x2e, 0x67, 0x6c, 0x6f, 0x62, 0x61, 0x6c, + 0x73, 0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x67, + 0x73, 0x6f, 0x72, 0x67, 0x61, 0x6e, 0x69, 0x7a, 0x61, 0x74, + 0x69, 0x6f, 0x6e, 0x76, 0x61, 0x6c, 0x73, 0x68, 0x61, 0x32, + 0x67, 0x32, 0x30, 0x56, 0x06, 0x03, 0x55, 0x1d, 0x20, 0x04, + 0x4f, 0x30, 0x4d, 0x30, 0x41, 0x06, 0x09, 0x2b, 0x06, 0x01, + 0x04, 0x01, 0xa0, 0x32, 0x01, 0x14, 0x30, 0x34, 0x30, 0x32, + 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x02, 0x01, + 0x16, 0x26, 0x68, 0x74, 0x74, 0x70, 0x73, 0x3a, 0x2f, 0x2f, + 0x77, 0x77, 0x77, 0x2e, 0x67, 0x6c, 0x6f, 0x62, 0x61, 0x6c, + 0x73, 0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x72, + 0x65, 0x70, 0x6f, 0x73, 0x69, 0x74, 0x6f, 0x72, 0x79, 0x2f, + 0x30, 0x08, 0x06, 0x06, 0x67, 0x81, 0x0c, 0x01, 0x02, 0x02, + 0x30, 0x09, 0x06, 0x03, 0x55, 0x1d, 0x13, 0x04, 0x02, 0x30, + 0x00, 0x30, 0x49, 0x06, 0x03, 0x55, 0x1d, 0x1f, 0x04, 0x42, + 0x30, 0x40, 0x30, 0x3e, 0xa0, 0x3c, 0xa0, 0x3a, 0x86, 0x38, + 0x68, 0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x63, 0x72, 0x6c, + 0x2e, 0x67, 0x6c, 0x6f, 0x62, 0x61, 0x6c, 0x73, 0x69, 0x67, + 0x6e, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x67, 0x73, 0x2f, 0x67, + 0x73, 0x6f, 0x72, 0x67, 0x61, 0x6e, 0x69, 0x7a, 0x61, 0x74, + 0x69, 0x6f, 0x6e, 0x76, 0x61, 0x6c, 0x73, 0x68, 0x61, 0x32, + 0x67, 0x32, 0x2e, 0x63, 0x72, 0x6c, 0x30, 0x18, 0x06, 0x03, + 0x55, 0x1d, 0x11, 0x04, 0x11, 0x30, 0x0f, 0x82, 0x0d, 0x77, + 0x65, 0x62, 0x72, 0x74, 0x63, 0x2e, 0x71, 0x71, 0x2e, 0x63, + 0x6f, 0x6d, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x1d, 0x25, 0x04, + 0x16, 0x30, 0x14, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, + 0x07, 0x03, 0x01, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, + 0x07, 0x03, 0x02, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x1d, 0x0e, + 0x04, 0x16, 0x04, 0x14, 0x28, 0xff, 0xe2, 0x97, 0xf3, 0x6f, + 0x2a, 0xef, 0x0f, 0xbc, 0x4c, 0x61, 0x9b, 0xd9, 0x23, 0x7b, + 0x3a, 0xef, 0xc2, 0xe7, 0x30, 0x1f, 0x06, 0x03, 0x55, 0x1d, + 0x23, 0x04, 0x18, 0x30, 0x16, 0x80, 0x14, 0x96, 0xde, 0x61, + 0xf1, 0xbd, 0x1c, 0x16, 0x29, 0x53, 0x1c, 0xc0, 0xcc, 0x7d, + 0x3b, 0x83, 0x00, 0x40, 0xe6, 0x1a, 0x7c, 0x30, 0x0d, 0x06, + 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x0b, + 0x05, 0x00, 0x03, 0x82, 0x01, 0x01, 0x00, 0x30, 0xc1, 0xcc, + 0xd6, 0x97, 0xf7, 0xf5, 0xa7, 0x93, 0xa5, 0x78, 0xc8, 0xcb, + 0x81, 0x44, 0xd4, 0x1f, 0x2a, 0xa6, 0xc1, 0x48, 0xa8, 0x1a, + 0xbd, 0x17, 0x10, 0x0e, 0xdf, 0x21, 0xea, 0x02, 0x3e, 0xb3, + 0xbd, 0x45, 0x1e, 0x64, 0x85, 0x3f, 0x04, 0x9a, 0xc0, 0x78, + 0xf4, 0x81, 0x2e, 0x38, 0x39, 0x3a, 0x04, 0x2d, 0x5f, 0xec, + 0xc4, 0x10, 0x57, 0xfb, 0x1b, 0x32, 0xe0, 0x8e, 0xfc, 0xe3, + 0x6d, 0x4b, 0xc6, 0xf0, 0x07, 0xb7, 0xc6, 0x19, 0xd7, 0x99, + 0x93, 0xbd, 0x60, 0x58, 0xad, 0xbb, 0x94, 0xcf, 0xd8, 0x05, + 0x5c, 0x14, 0x70, 0xec, 0x2e, 0xb7, 0x60, 0x52, 0x3c, 0xd3, + 0x03, 0xf8, 0xcd, 0xe5, 0x4e, 0x84, 0xcf, 0xef, 0x2f, 0x12, + 0xdd, 0x74, 0xfd, 0x95, 0x9d, 0x03, 0xa9, 0x81, 0x18, 0x3a, + 0x6e, 0xe6, 0xc2, 0xdd, 0x07, 0x1e, 0xea, 0x8c, 0xe6, 0xd9, + 0x31, 0x72, 0x63, 0x25, 0xcd, 0xf2, 0x19, 0xf2, 0x4e, 0x3c, + 0x18, 0xfb, 0xb2, 0x74, + } + certificate2 := []byte{ + 0x0b, 0x00, 0x05, 0x5b, 0x00, 0x01, 0x00, 0x04, 0xe4, 0x00, + 0x00, 0x77, 0xc1, 0x6b, 0x67, 0xec, 0x34, 0x05, 0xe8, 0x63, + 0xfc, 0x74, 0x4b, 0x11, 0x3f, 0x3a, 0xe4, 0x4e, 0x06, 0x89, + 0x96, 0x24, 0x3c, 0x15, 0x83, 0xc5, 0x1d, 0xeb, 0xc0, 0x19, + 0x71, 0x35, 0x6c, 0xfa, 0xf1, 0x51, 0x06, 0x0e, 0x8e, 0xfb, + 0x9b, 0x4e, 0xaa, 0x50, 0x24, 0x77, 0xac, 0x86, 0x14, 0x50, + 0x52, 0x35, 0x68, 0x15, 0x9b, 0xdd, 0x8b, 0xdb, 0x83, 0x1d, + 0xed, 0x45, 0x05, 0x78, 0x53, 0xd6, 0xc4, 0x21, 0xaf, 0x68, + 0x45, 0x91, 0xe7, 0x30, 0x36, 0x4c, 0xb1, 0xfb, 0xf1, 0x65, + 0x9a, 0xe4, 0x49, 0x90, 0x1c, 0x0c, 0xa8, 0x63, 0xe9, 0x04, + 0xe3, 0x17, 0x61, 0x8d, 0x20, 0x29, 0xca, 0x41, 0xa6, 0x8b, + 0x32, 0x53, 0xa5, 0x84, 0x29, 0x5a, 0x62, 0xe7, 0x84, 0x38, + 0x32, 0x56, 0xbb, 0x8b, 0xbc, 0x25, 0xc7, 0xa3, 0x28, 0x3b, + 0x35, + } + serverKeyExchange := []byte{ + 0x0c, 0x00, 0x01, 0x28, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, + 0x01, 0x28, 0x03, 0x00, 0x1d, 0x20, 0x59, 0xa2, 0x0f, 0xc4, + 0x7b, 0xd8, 0x03, 0xf6, 0xb0, 0xcf, 0x5d, 0xf0, 0x45, 0x7f, + 0x7e, 0xf2, 0x98, 0xab, 0xc0, 0x24, 0xf1, 0xdf, 0xba, 0x63, + 0x3e, 0xfb, 0xe5, 0x02, 0x31, 0xcf, 0xd1, 0x05, 0x04, 0x01, + 0x01, 0x00, 0x7b, 0x52, 0x9c, 0xe7, 0x54, 0x8b, 0xb0, 0xc9, + 0xfd, 0xaf, 0xe2, 0x91, 0x19, 0x9d, 0x6c, 0xb8, 0xbe, 0xa5, + 0xe1, 0x48, 0xa0, 0xfd, 0xc5, 0x76, 0x62, 0x47, 0xf2, 0xd1, + 0x35, 0x76, 0x4e, 0x33, 0xf4, 0xa1, 0xf1, 0x58, 0xdc, 0xd5, + 0x45, 0x3f, 0x76, 0x64, 0x40, 0xba, 0x32, 0xe3, 0x07, 0xb7, + 0x4b, 0xbe, 0xe2, 0x77, 0x99, 0xad, 0x11, 0x73, 0x54, 0xe6, + 0xbb, 0xfb, 0xd4, 0xb1, 0x83, 0x9f, 0xc6, 0x50, 0xc6, 0xd8, + 0xbb, 0x92, 0x0d, 0x93, 0xf9, 0x63, 0x29, 0xf9, 0xc3, 0xce, + 0x24, 0x40, 0x29, 0x95, 0x43, 0xf0, 0x32, 0x00, 0x21, 0xde, + 0xdf, 0x64, 0xfe, 0xb6, 0x11, 0xa0, 0x11, 0x44, 0x12, 0x2a, + 0x1c, 0x96, 0x44, 0x4b, 0x79, 0x31, 0x23, 0x46, 0x4e, 0xe8, + 0x16, 0x5b, 0xf5, 0x9a, 0x5f, 0x51, 0x10, 0x5b, 0x11, 0xa3, + 0xb8, 0x1f, 0xb7, 0xf1, 0x11, 0xad, 0x05, 0x82, 0x2b, 0xc3, + 0x65, 0x8c, 0x41, 0xb4, 0x8e, 0x60, 0x42, 0x89, 0x92, 0xd1, + 0x83, 0x73, 0xe7, 0x35, 0xb4, 0xc9, 0xd1, 0xbc, 0x5c, 0x84, + 0x5b, 0xdb, 0x44, 0x34, 0xea, 0xd8, 0x06, 0xe4, 0xfb, 0xbd, + 0x40, 0x35, 0x18, 0x60, 0x33, 0xb6, 0xed, 0xbc, 0x9b, 0x3a, + 0xff, 0x2f, 0xa1, 0xe8, 0x5d, 0x5c, 0xbb, 0xe8, 0xe1, 0xa6, + 0xbb, 0x84, 0x0f, 0x50, 0x51, 0x0d, 0xa5, 0x8f, 0x96, 0xb6, + 0x35, 0x37, 0x7b, 0x58, 0xaf, 0x4f, 0x77, 0x9d, 0x5d, 0xb2, + 0xff, 0x5f, 0xd6, 0xb8, 0x82, 0x64, 0x5f, 0x79, 0xd0, 0x06, + 0x44, 0x6d, 0x3a, 0x82, 0x25, 0x21, 0xca, 0xbb, 0xa0, 0x79, + 0xdd, 0x6e, 0x15, 0xb6, 0x57, 0x9b, 0x04, 0x84, 0x63, 0x88, + 0x1d, 0x41, 0xff, 0xe1, 0x20, 0x61, 0xd5, 0x3f, 0xc7, 0xca, + 0x0c, 0xd9, 0xe0, 0x74, 0x86, 0x78, 0xed, 0x60, 0x18, 0x2d, + 0x9e, 0x69, 0x66, 0x77, 0xf7, 0xd0, 0xe9, 0x9c, + } + certificateRequest := []byte{ + 0x0d, 0x00, 0x00, 0x26, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x26, 0x03, 0x01, 0x02, 0x40, 0x00, 0x1e, 0x06, 0x01, + 0x06, 0x02, 0x06, 0x03, 0x05, 0x01, 0x05, 0x02, 0x05, 0x03, + 0x04, 0x01, 0x04, 0x02, 0x04, 0x03, 0x03, 0x01, 0x03, 0x02, + 0x03, 0x03, 0x02, 0x01, 0x02, 0x02, 0x02, 0x03, 0x00, 0x00, + } + serverHelloDone := []byte{ + 0x0e, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, + } + cache.push(certificate2, 0, 2, handshake.TypeCertificate, false) + cache.push(serverKeyExchange, 0, 3, handshake.TypeServerKeyExchange, false) + cache.push(certificateRequest, 0, 4, handshake.TypeCertificateRequest, false) + cache.push(serverHelloDone, 0, 5, handshake.TypeServerHelloDone, false) + + if _, alt, err := flight1Parse(context.TODO(), mockConn, state, cache, cfg); err != nil { + t.Fatal(err) + } else if alt != nil { + t.Fatal(alt.String()) + } + + cache.push(serverHello, 0, 0, handshake.TypeServerHello, false) + cache.push(certificate1, 0, 1, handshake.TypeCertificate, false) + if _, alt, err := flight1Parse(context.TODO(), mockConn, state, cache, cfg); err != nil { + t.Fatal(err) + } else if alt != nil { + t.Fatal(alt.String()) + } +} diff --git a/flight2handler.go b/flight2handler.go index 26e57d2f2..8d50befba 100644 --- a/flight2handler.go +++ b/flight2handler.go @@ -7,13 +7,19 @@ import ( "bytes" "context" - "github.com/pion/dtls/v2/pkg/protocol" - "github.com/pion/dtls/v2/pkg/protocol/alert" - "github.com/pion/dtls/v2/pkg/protocol/handshake" - "github.com/pion/dtls/v2/pkg/protocol/recordlayer" + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/alert" + "github.com/pion/dtls/v3/pkg/protocol/handshake" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) -func flight2Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { +func flight2Parse( + ctx context.Context, + c flightConn, + state *State, + cache *handshakeCache, + cfg *handshakeConfig, +) (flightVal, *alert.Alert, error) { seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite, handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, ) @@ -41,11 +47,18 @@ func flight2Parse(ctx context.Context, c flightConn, state *State, cache *handsh if !bytes.Equal(state.cookie, clientHello.Cookie) { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.AccessDenied}, errCookieMismatch } + return flight4, nil, nil } -func flight2Generate(_ flightConn, state *State, _ *handshakeCache, _ *handshakeConfig) ([]*packet, *alert.Alert, error) { +func flight2Generate( + _ flightConn, + state *State, + _ *handshakeCache, + _ *handshakeConfig, +) ([]*packet, *alert.Alert, error) { state.handshakeSendSequence = 0 + return []*packet{ { record: &recordlayer.RecordLayer{ diff --git a/flight3handler.go b/flight3handler.go index 5a763dc08..7301e34b6 100644 --- a/flight3handler.go +++ b/flight3handler.go @@ -7,17 +7,24 @@ import ( "bytes" "context" - "github.com/pion/dtls/v2/internal/ciphersuite/types" - "github.com/pion/dtls/v2/pkg/crypto/elliptic" - "github.com/pion/dtls/v2/pkg/crypto/prf" - "github.com/pion/dtls/v2/pkg/protocol" - "github.com/pion/dtls/v2/pkg/protocol/alert" - "github.com/pion/dtls/v2/pkg/protocol/extension" - "github.com/pion/dtls/v2/pkg/protocol/handshake" - "github.com/pion/dtls/v2/pkg/protocol/recordlayer" + "github.com/pion/dtls/v3/internal/ciphersuite/types" + "github.com/pion/dtls/v3/pkg/crypto/elliptic" + "github.com/pion/dtls/v3/pkg/crypto/prf" + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/alert" + "github.com/pion/dtls/v3/pkg/protocol/extension" + "github.com/pion/dtls/v3/pkg/protocol/handshake" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) -func flight3Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { //nolint:gocognit +//nolint:gocognit,gocyclo,maintidx,cyclop +func flight3Parse( + ctx context.Context, + conn flightConn, + state *State, + cache *handshakeCache, + cfg *handshakeConfig, +) (flightVal, *alert.Alert, error) { // Clients may receive multiple HelloVerifyRequest messages with different cookies. // Clients SHOULD handle this by sending a new ClientHello with a cookie in response // to the new HelloVerifyRequest. RFC 6347 Section 4.2.1 @@ -33,6 +40,7 @@ func flight3Parse(ctx context.Context, c flightConn, state *State, cache *handsh } state.cookie = append([]byte{}, h.Cookie...) state.handshakeRecvSequence = seq + return flight3, nil, nil } } @@ -45,37 +53,53 @@ func flight3Parse(ctx context.Context, c flightConn, state *State, cache *handsh return 0, nil, nil } - if h, msgOk := msgs[handshake.TypeServerHello].(*handshake.MessageServerHello); msgOk { - if !h.Version.Equal(protocol.Version1_2) { + if serverHelloMsg, msgOk := msgs[handshake.TypeServerHello].(*handshake.MessageServerHello); msgOk { //nolint:nestif + if !serverHelloMsg.Version.Equal(protocol.Version1_2) { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion } - for _, v := range h.Extensions { - switch e := v.(type) { + for _, v := range serverHelloMsg.Extensions { + switch ext := v.(type) { case *extension.UseSRTP: - profile, found := findMatchingSRTPProfile(e.ProtectionProfiles, cfg.localSRTPProtectionProfiles) + profile, found := findMatchingSRTPProfile(ext.ProtectionProfiles, cfg.localSRTPProtectionProfiles) if !found { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, errClientNoMatchingSRTPProfile } - state.srtpProtectionProfile = profile + state.setSRTPProtectionProfile(profile) + state.remoteSRTPMasterKeyIdentifier = ext.MasterKeyIdentifier case *extension.UseExtendedMasterSecret: if cfg.extendedMasterSecret != DisableExtendedMasterSecret { state.extendedMasterSecret = true } case *extension.ALPN: - if len(e.ProtocolNameList) > 1 { // This should be exactly 1, the zero case is handle when unmarshalling - return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, extension.ErrALPNInvalidFormat // Meh, internal error? + if len(ext.ProtocolNameList) > 1 { // This should be exactly 1, the zero case is handle when unmarshalling + return 0, &alert.Alert{ + Level: alert.Fatal, + Description: alert.InternalError, + }, extension.ErrALPNInvalidFormat // Meh, internal error? + } + state.NegotiatedProtocol = ext.ProtocolNameList[0] + case *extension.ConnectionID: + // Only set connection ID to be sent if client supports connection + // IDs. + if cfg.connectionIDGenerator != nil { + state.remoteConnectionID = ext.CID } - state.NegotiatedProtocol = e.ProtocolNameList[0] } } + // If the server doesn't support connection IDs, the client should not + // expect one to be sent. + if state.remoteConnectionID == nil { + state.setLocalConnectionID(nil) + } + if cfg.extendedMasterSecret == RequireExtendedMasterSecret && !state.extendedMasterSecret { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errClientRequiredButNoServerEMS } - if len(cfg.localSRTPProtectionProfiles) > 0 && state.srtpProtectionProfile == 0 { + if len(cfg.localSRTPProtectionProfiles) > 0 && state.getSRTPProtectionProfile() == 0 { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errRequestedButNoSRTPExtension } - remoteCipherSuite := cipherSuiteForID(CipherSuiteID(*h.CipherSuiteID), cfg.customCipherSuites) + remoteCipherSuite := cipherSuiteForID(CipherSuiteID(*serverHelloMsg.CipherSuiteID), cfg.customCipherSuites) if remoteCipherSuite == nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errCipherSuiteNoIntersection } @@ -86,11 +110,11 @@ func flight3Parse(ctx context.Context, c flightConn, state *State, cache *handsh } state.cipherSuite = selectedCipherSuite - state.remoteRandom = h.Random + state.remoteRandom = serverHelloMsg.Random cfg.log.Tracef("[handshake] use cipher suite: %s", selectedCipherSuite.String()) - if len(h.SessionID) > 0 && bytes.Equal(state.SessionID, h.SessionID) { - return handleResumption(ctx, c, state, cache, cfg) + if len(serverHelloMsg.SessionID) > 0 && bytes.Equal(state.SessionID, serverHelloMsg.SessionID) { + return handleResumption(ctx, conn, state, cache, cfg) } if len(state.SessionID) > 0 { @@ -103,7 +127,7 @@ func flight3Parse(ctx context.Context, c flightConn, state *State, cache *handsh if cfg.sessionStore == nil { state.SessionID = []byte{} } else { - state.SessionID = h.SessionID + state.SessionID = serverHelloMsg.SessionID } state.masterSecret = []byte{} @@ -135,20 +159,27 @@ func flight3Parse(ctx context.Context, c flightConn, state *State, cache *handsh } if h, ok := msgs[handshake.TypeServerKeyExchange].(*handshake.MessageServerKeyExchange); ok { - alertPtr, err := handleServerKeyExchange(c, state, cfg, h) + alertPtr, err := handleServerKeyExchange(conn, state, cfg, h) if err != nil { return 0, alertPtr, err } } - if _, ok := msgs[handshake.TypeCertificateRequest].(*handshake.MessageCertificateRequest); ok { + if creq, ok := msgs[handshake.TypeCertificateRequest].(*handshake.MessageCertificateRequest); ok { + state.remoteCertRequestAlgs = creq.SignatureHashAlgorithms state.remoteRequestedCertificate = true } return flight5, nil, nil } -func handleResumption(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { +func handleResumption( + ctx context.Context, + c flightConn, + state *State, + cache *handshakeCache, + cfg *handshakeConfig, +) (flightVal, *alert.Alert, error) { if err := state.initCipherSuite(); err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } @@ -189,25 +220,36 @@ func handleResumption(ctx context.Context, c flightConn, state *State, cache *ha return flight5b, nil, nil } -func handleServerKeyExchange(_ flightConn, state *State, cfg *handshakeConfig, h *handshake.MessageServerKeyExchange) (*alert.Alert, error) { +//nolint:cyclop +func handleServerKeyExchange( + _ flightConn, + state *State, + cfg *handshakeConfig, + keyExchangeMessage *handshake.MessageServerKeyExchange, +) (*alert.Alert, error) { var err error if state.cipherSuite == nil { return &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errInvalidCipherSuite } - if cfg.localPSKCallback != nil { + if cfg.localPSKCallback != nil { //nolint:nestif var psk []byte - if psk, err = cfg.localPSKCallback(h.IdentityHint); err != nil { + if psk, err = cfg.localPSKCallback(keyExchangeMessage.IdentityHint); err != nil { return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } - state.IdentityHint = h.IdentityHint + state.IdentityHint = keyExchangeMessage.IdentityHint switch state.cipherSuite.KeyExchangeAlgorithm() { case types.KeyExchangeAlgorithmPsk: state.preMasterSecret = prf.PSKPreMasterSecret(psk) case (types.KeyExchangeAlgorithmEcdhe | types.KeyExchangeAlgorithmPsk): - if state.localKeypair, err = elliptic.GenerateKeypair(h.NamedCurve); err != nil { + if state.localKeypair, err = elliptic.GenerateKeypair(keyExchangeMessage.NamedCurve); err != nil { return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } - state.preMasterSecret, err = prf.EcdhePSKPreMasterSecret(psk, h.PublicKey, state.localKeypair.PrivateKey, state.localKeypair.Curve) + state.preMasterSecret, err = prf.EcdhePSKPreMasterSecret( + psk, + keyExchangeMessage.PublicKey, + state.localKeypair.PrivateKey, + state.localKeypair.Curve, + ) if err != nil { return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } @@ -215,11 +257,15 @@ func handleServerKeyExchange(_ flightConn, state *State, cfg *handshakeConfig, h return &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errInvalidCipherSuite } } else { - if state.localKeypair, err = elliptic.GenerateKeypair(h.NamedCurve); err != nil { + if state.localKeypair, err = elliptic.GenerateKeypair(keyExchangeMessage.NamedCurve); err != nil { return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } - if state.preMasterSecret, err = prf.PreMasterSecret(h.PublicKey, state.localKeypair.PrivateKey, state.localKeypair.Curve); err != nil { + if state.preMasterSecret, err = prf.PreMasterSecret( + keyExchangeMessage.PublicKey, + state.localKeypair.PrivateKey, + state.localKeypair.Curve, + ); err != nil { return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } } @@ -227,7 +273,12 @@ func handleServerKeyExchange(_ flightConn, state *State, cfg *handshakeConfig, h return nil, nil //nolint:nilnil } -func flight3Generate(_ flightConn, state *State, _ *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) { +func flight3Generate( + _ flightConn, + state *State, + _ *handshakeCache, + cfg *handshakeConfig, +) ([]*packet, *alert.Alert, error) { extensions := []extension.Extension{ &extension.SupportedSignatureAlgorithms{ SignatureHashAlgorithms: cfg.localSignatureSchemes, @@ -236,10 +287,11 @@ func flight3Generate(_ flightConn, state *State, _ *handshakeCache, cfg *handsha RenegotiatedConnection: 0, }, } + if state.namedCurve != 0 { extensions = append(extensions, []extension.Extension{ &extension.SupportedEllipticCurves{ - EllipticCurves: []elliptic.Curve{elliptic.X25519, elliptic.P256, elliptic.P384}, + EllipticCurves: cfg.ellipticCurves, }, &extension.SupportedPointFormats{ PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed}, @@ -268,23 +320,37 @@ func flight3Generate(_ flightConn, state *State, _ *handshakeCache, cfg *handsha extensions = append(extensions, &extension.ALPN{ProtocolNameList: cfg.supportedProtocols}) } + // If we sent a connection ID on the first ClientHello, send it on the + // second. + if state.getLocalConnectionID() != nil { + extensions = append(extensions, &extension.ConnectionID{CID: state.getLocalConnectionID()}) + } + + clientHello := &handshake.MessageClientHello{ + Version: protocol.Version1_2, + SessionID: state.SessionID, + Cookie: state.cookie, + Random: state.localRandom, + CipherSuiteIDs: cipherSuiteIDs(cfg.localCipherSuites), + CompressionMethods: defaultCompressionMethods(), + Extensions: extensions, + } + + var content handshake.Handshake + + if cfg.clientHelloMessageHook != nil { + content = handshake.Handshake{Message: cfg.clientHelloMessageHook(*clientHello)} + } else { + content = handshake.Handshake{Message: clientHello} + } + return []*packet{ { record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, }, - Content: &handshake.Handshake{ - Message: &handshake.MessageClientHello{ - Version: protocol.Version1_2, - SessionID: state.SessionID, - Cookie: state.cookie, - Random: state.localRandom, - CipherSuiteIDs: cipherSuiteIDs(cfg.localCipherSuites), - CompressionMethods: defaultCompressionMethods(), - Extensions: extensions, - }, - }, + Content: &content, }, }, }, nil, nil diff --git a/flight3handler_test.go b/flight3handler_test.go new file mode 100644 index 000000000..af7374d6d --- /dev/null +++ b/flight3handler_test.go @@ -0,0 +1,115 @@ +// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> +// SPDX-License-Identifier: MIT + +package dtls + +import ( + "context" + "math/rand" + "testing" + "time" + + "github.com/pion/dtls/v3/pkg/crypto/elliptic" + dtlsnet "github.com/pion/dtls/v3/pkg/net" + "github.com/pion/dtls/v3/pkg/protocol/extension" + "github.com/pion/dtls/v3/pkg/protocol/handshake" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" + "github.com/pion/transport/v3/dpipe" + "github.com/pion/transport/v3/test" +) + +// Assert that SupportedEllipticCurves is only sent when a ECC CipherSuite is available. +func TestSupportedEllipticCurves(t *testing.T) { //nolint:cyclop + // Limit runtime in case of deadlocks + lim := test.TimeOut(time.Second * 20) + defer lim.Stop() + + // Check for leaking routines + report := test.CheckRoutines(t) + defer report() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + expectedCurves := defaultCurves + var actualCurves []elliptic.Curve + + rand.Shuffle(len(expectedCurves), func(i, j int) { + expectedCurves[i], expectedCurves[j] = expectedCurves[j], expectedCurves[i] + }) + + clientErr := make(chan error, 1) + ca, cb := dpipe.Pipe() + caAnalyzer := &connWithCallback{Conn: ca} + caAnalyzer.onWrite = func(in []byte) { + messages, err := recordlayer.UnpackDatagram(in) + if err != nil { + t.Fatal(err) + } + + for i := range messages { + h := &handshake.Handshake{} + _ = h.Unmarshal(messages[i][recordlayer.FixedHeaderSize:]) + + if h.Header.Type == handshake.TypeClientHello { //nolint:nestif + clientHello := &handshake.MessageClientHello{} + msg, err := h.Message.Marshal() + + if err != nil { + t.Fatal(err) + } else if err = clientHello.Unmarshal(msg); err != nil { + t.Fatal(err) + } + + for _, e := range clientHello.Extensions { + if e.TypeValue() == extension.SupportedEllipticCurvesTypeValue { + if c, ok := e.(*extension.SupportedEllipticCurves); ok { + actualCurves = c.EllipticCurves + } + } + } + } + } + } + + go func() { + conf := &Config{ + CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, + EllipticCurves: expectedCurves, + } + + if client, err := testClient( + ctx, + dtlsnet.PacketConnFromConn(caAnalyzer), + caAnalyzer.RemoteAddr(), + conf, + false, + ); err != nil { + clientErr <- err + } else { + clientErr <- client.Close() //nolint + } + }() + + config := &Config{ + CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, + } + + if server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, true); err != nil { + t.Fatalf("Server error %v", err) + } else { + if err = server.Close(); err != nil { + t.Fatal(err) + } + } + + if err := <-clientErr; err != nil { + t.Fatalf("Client error %v", err) + } + + for i := range expectedCurves { + if expectedCurves[i] != actualCurves[i] { + t.Fatal("List of curves in SupportedEllipticCurves does not match config") + } + } +} diff --git a/flight4bhandler.go b/flight4bhandler.go index 6bbbc5972..681533b0e 100644 --- a/flight4bhandler.go +++ b/flight4bhandler.go @@ -7,15 +7,21 @@ import ( "bytes" "context" - "github.com/pion/dtls/v2/pkg/crypto/prf" - "github.com/pion/dtls/v2/pkg/protocol" - "github.com/pion/dtls/v2/pkg/protocol/alert" - "github.com/pion/dtls/v2/pkg/protocol/extension" - "github.com/pion/dtls/v2/pkg/protocol/handshake" - "github.com/pion/dtls/v2/pkg/protocol/recordlayer" + "github.com/pion/dtls/v3/pkg/crypto/prf" + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/alert" + "github.com/pion/dtls/v3/pkg/protocol/extension" + "github.com/pion/dtls/v3/pkg/protocol/handshake" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) -func flight4bParse(_ context.Context, _ flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { +func flight4bParse( + _ context.Context, + _ flightConn, + state *State, + cache *handshakeCache, + cfg *handshakeConfig, +) (flightVal, *alert.Alert, error) { _, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite, handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false}, ) @@ -47,7 +53,13 @@ func flight4bParse(_ context.Context, _ flightConn, state *State, cache *handsha return flight4b, nil, nil } -func flight4bGenerate(_ flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) { +//nolint:cyclop +func flight4bGenerate( + _ flightConn, + state *State, + cache *handshakeCache, + cfg *handshakeConfig, +) ([]*packet, *alert.Alert, error) { var pkts []*packet extensions := []extension.Extension{&extension.RenegotiationInfo{ @@ -59,9 +71,10 @@ func flight4bGenerate(_ flightConn, state *State, cache *handshakeCache, cfg *ha Supported: true, }) } - if state.srtpProtectionProfile != 0 { + if state.getSRTPProtectionProfile() != 0 { extensions = append(extensions, &extension.UseSRTP{ - ProtectionProfiles: []SRTPProtectionProfile{state.srtpProtectionProfile}, + ProtectionProfiles: []SRTPProtectionProfile{state.getSRTPProtectionProfile()}, + MasterKeyIdentifier: cfg.localSRTPMasterKeyIdentifier, }) } @@ -77,18 +90,24 @@ func flight4bGenerate(_ flightConn, state *State, cache *handshakeCache, cfg *ha } cipherSuiteID := uint16(state.cipherSuite.ID()) - serverHello := &handshake.Handshake{ - Message: &handshake.MessageServerHello{ - Version: protocol.Version1_2, - Random: state.localRandom, - SessionID: state.SessionID, - CipherSuiteID: &cipherSuiteID, - CompressionMethod: defaultCompressionMethods()[0], - Extensions: extensions, - }, + var serverHello handshake.Handshake + + serverHelloMessage := &handshake.MessageServerHello{ + Version: protocol.Version1_2, + Random: state.localRandom, + SessionID: state.SessionID, + CipherSuiteID: &cipherSuiteID, + CompressionMethod: defaultCompressionMethods()[0], + Extensions: extensions, + } + + if cfg.serverHelloMessageHook != nil { + serverHello = handshake.Handshake{Message: cfg.serverHelloMessageHook(*serverHelloMessage)} + } else { + serverHello = handshake.Handshake{Message: serverHelloMessage} } - serverHello.Header.MessageSequence = uint16(state.handshakeSendSequence) + serverHello.Header.MessageSequence = uint16(state.handshakeSendSequence) //nolint:gosec // G115 if len(state.localVerifyData) == 0 { plainText := cache.pullAndMerge( @@ -112,7 +131,7 @@ func flight4bGenerate(_ flightConn, state *State, cache *handshakeCache, cfg *ha Header: recordlayer.Header{ Version: protocol.Version1_2, }, - Content: serverHello, + Content: &serverHello, }, }, &packet{ diff --git a/flight4handler.go b/flight4handler.go index 67a486461..f75373fc2 100644 --- a/flight4handler.go +++ b/flight4handler.go @@ -5,22 +5,30 @@ package dtls import ( "context" + "crypto" "crypto/rand" "crypto/x509" - "github.com/pion/dtls/v2/internal/ciphersuite" - "github.com/pion/dtls/v2/pkg/crypto/clientcertificate" - "github.com/pion/dtls/v2/pkg/crypto/elliptic" - "github.com/pion/dtls/v2/pkg/crypto/prf" - "github.com/pion/dtls/v2/pkg/crypto/signaturehash" - "github.com/pion/dtls/v2/pkg/protocol" - "github.com/pion/dtls/v2/pkg/protocol/alert" - "github.com/pion/dtls/v2/pkg/protocol/extension" - "github.com/pion/dtls/v2/pkg/protocol/handshake" - "github.com/pion/dtls/v2/pkg/protocol/recordlayer" + "github.com/pion/dtls/v3/internal/ciphersuite" + "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" + "github.com/pion/dtls/v3/pkg/crypto/elliptic" + "github.com/pion/dtls/v3/pkg/crypto/prf" + "github.com/pion/dtls/v3/pkg/crypto/signaturehash" + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/alert" + "github.com/pion/dtls/v3/pkg/protocol/extension" + "github.com/pion/dtls/v3/pkg/protocol/handshake" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) -func flight4Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { //nolint:gocognit +//nolint:gocognit,gocyclo,lll,cyclop,maintidx +func flight4Parse( + ctx context.Context, + conn flightConn, + state *State, + cache *handshakeCache, + cfg *handshakeConfig, +) (flightVal, *alert.Alert, error) { seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite, handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, true}, handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false}, @@ -47,7 +55,8 @@ func flight4Parse(ctx context.Context, c flightConn, state *State, cache *handsh state.SessionID = nil } - if h, hasCertVerify := msgs[handshake.TypeCertificateVerify].(*handshake.MessageCertificateVerify); hasCertVerify { + //nolint:nestif + if verify, hasVerify := msgs[handshake.TypeCertificateVerify].(*handshake.MessageCertificateVerify); hasVerify { if state.PeerCertificates == nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.NoCertificate}, errCertificateVerifyNoCertificate } @@ -66,8 +75,9 @@ func flight4Parse(ctx context.Context, c flightConn, state *State, cache *handsh // Verify that the pair of hash algorithm and signiture is listed. var validSignatureScheme bool for _, ss := range cfg.localSignatureSchemes { - if ss.Hash == h.HashAlgorithm && ss.Signature == h.SignatureAlgorithm { + if ss.Hash == verify.HashAlgorithm && ss.Signature == verify.SignatureAlgorithm { validSignatureScheme = true + break } } @@ -75,7 +85,12 @@ func flight4Parse(ctx context.Context, c flightConn, state *State, cache *handsh return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errNoAvailableSignatureSchemes } - if err := verifyCertificateVerify(plainText, h.HashAlgorithm, h.Signature, state.PeerCertificates); err != nil { + if err := verifyCertificateVerify( + plainText, + verify.HashAlgorithm, + verify.Signature, + state.PeerCertificates, + ); err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err } var chains [][]*x509.Certificate @@ -99,7 +114,7 @@ func flight4Parse(ctx context.Context, c flightConn, state *State, cache *handsh return 0, nil, nil } - if !state.cipherSuite.IsInitialized() { + if !state.cipherSuite.IsInitialized() { //nolint:nestif serverRandom := state.localRandom.MarshalFixed() clientRandom := state.remoteRandom.MarshalFixed() @@ -115,14 +130,23 @@ func flight4Parse(ctx context.Context, c flightConn, state *State, cache *handsh case CipherSuiteKeyExchangeAlgorithmPsk: preMasterSecret = prf.PSKPreMasterSecret(psk) case (CipherSuiteKeyExchangeAlgorithmPsk | CipherSuiteKeyExchangeAlgorithmEcdhe): - if preMasterSecret, err = prf.EcdhePSKPreMasterSecret(psk, clientKeyExchange.PublicKey, state.localKeypair.PrivateKey, state.localKeypair.Curve); err != nil { + if preMasterSecret, err = prf.EcdhePSKPreMasterSecret( + psk, + clientKeyExchange.PublicKey, + state.localKeypair.PrivateKey, + state.localKeypair.Curve, + ); err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } default: return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, errInvalidCipherSuite } } else { - preMasterSecret, err = prf.PreMasterSecret(clientKeyExchange.PublicKey, state.localKeypair.PrivateKey, state.localKeypair.Curve) + preMasterSecret, err = prf.PreMasterSecret( + clientKeyExchange.PublicKey, + state.localKeypair.PrivateKey, + state.localKeypair.Curve, + ) if err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, err } @@ -140,7 +164,12 @@ func flight4Parse(ctx context.Context, c flightConn, state *State, cache *handsh return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } } else { - state.masterSecret, err = prf.MasterSecret(preMasterSecret, clientRandom[:], serverRandom[:], state.cipherSuite.HashFunc()) + state.masterSecret, err = prf.MasterSecret( + preMasterSecret, + clientRandom[:], + serverRandom[:], + state.cipherSuite.HashFunc(), + ) if err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } @@ -164,7 +193,7 @@ func flight4Parse(ctx context.Context, c flightConn, state *State, cache *handsh } // Now, encrypted packets can be handled - if err := c.handleQueuedPackets(ctx); err != nil { + if err := conn.handleQueuedPackets(ctx); err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } @@ -181,12 +210,17 @@ func flight4Parse(ctx context.Context, c flightConn, state *State, cache *handsh return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil } - if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeAnonymous { + if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeAnonymous { //nolint:nestif if cfg.verifyConnection != nil { - if err := cfg.verifyConnection(state.clone()); err != nil { + stateClone, err := state.clone() + if err != nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + if err := cfg.verifyConnection(stateClone); err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err } } + return flight6, nil, nil } @@ -210,7 +244,11 @@ func flight4Parse(ctx context.Context, c flightConn, state *State, cache *handsh // go to flight6 } if cfg.verifyConnection != nil { - if err := cfg.verifyConnection(state.clone()); err != nil { + stateClone, err := state.clone() + if err != nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + if err := cfg.verifyConnection(stateClone); err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err } } @@ -218,7 +256,13 @@ func flight4Parse(ctx context.Context, c flightConn, state *State, cache *handsh return flight6, nil, nil } -func flight4Generate(_ flightConn, state *State, _ *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) { +//nolint:gocognit,cyclop,maintidx +func flight4Generate( + _ flightConn, + state *State, + _ *handshakeCache, + cfg *handshakeConfig, +) ([]*packet, *alert.Alert, error) { extensions := []extension.Extension{&extension.RenegotiationInfo{ RenegotiatedConnection: 0, }} @@ -228,9 +272,10 @@ func flight4Generate(_ flightConn, state *State, _ *handshakeCache, cfg *handsha Supported: true, }) } - if state.srtpProtectionProfile != 0 { + if state.getSRTPProtectionProfile() != 0 { extensions = append(extensions, &extension.UseSRTP{ - ProtectionProfiles: []SRTPProtectionProfile{state.srtpProtectionProfile}, + ProtectionProfiles: []SRTPProtectionProfile{state.getSRTPProtectionProfile()}, + MasterKeyIdentifier: cfg.localSRTPMasterKeyIdentifier, }) } if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate { @@ -250,6 +295,15 @@ func flight4Generate(_ flightConn, state *State, _ *handshakeCache, cfg *handsha state.NegotiatedProtocol = selectedProto } + // If we have a connection ID generator, we are willing to use connection + // IDs. We already know whether the client supports connection IDs from + // parsing the ClientHello, so avoid setting local connection ID if the + // client won't send it. + if cfg.connectionIDGenerator != nil && state.remoteConnectionID != nil { + state.setLocalConnectionID(cfg.connectionIDGenerator()) + extensions = append(extensions, &extension.ConnectionID{CID: state.getLocalConnectionID()}) + } + var pkts []*packet cipherSuiteID := uint16(state.cipherSuite.ID()) @@ -260,21 +314,29 @@ func flight4Generate(_ flightConn, state *State, _ *handshakeCache, cfg *handsha } } + serverHello := &handshake.MessageServerHello{ + Version: protocol.Version1_2, + Random: state.localRandom, + SessionID: state.SessionID, + CipherSuiteID: &cipherSuiteID, + CompressionMethod: defaultCompressionMethods()[0], + Extensions: extensions, + } + + var content handshake.Handshake + + if cfg.serverHelloMessageHook != nil { + content = handshake.Handshake{Message: cfg.serverHelloMessageHook(*serverHello)} + } else { + content = handshake.Handshake{Message: serverHello} + } + pkts = append(pkts, &packet{ record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, }, - Content: &handshake.Handshake{ - Message: &handshake.MessageServerHello{ - Version: protocol.Version1_2, - Random: state.localRandom, - SessionID: state.SessionID, - CipherSuiteID: &cipherSuiteID, - CompressionMethod: defaultCompressionMethods()[0], - Extensions: extensions, - }, - }, + Content: &content, }, }) @@ -283,6 +345,7 @@ func flight4Generate(_ flightConn, state *State, _ *handshakeCache, cfg *handsha certificate, err := cfg.getCertificate(&ClientHelloInfo{ ServerName: state.serverName, CipherSuites: []ciphersuite.ID{state.cipherSuite.ID()}, + RandomBytes: state.remoteRandom.RandomBytes, }) if err != nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, err @@ -304,13 +367,25 @@ func flight4Generate(_ flightConn, state *State, _ *handshakeCache, cfg *handsha serverRandom := state.localRandom.MarshalFixed() clientRandom := state.remoteRandom.MarshalFixed() + signer, ok := certificate.PrivateKey.(crypto.Signer) + if !ok { + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, errInvalidPrivateKey + } + // Find compatible signature scheme - signatureHashAlgo, err := signaturehash.SelectSignatureScheme(cfg.localSignatureSchemes, certificate.PrivateKey) + signatureHashAlgo, err := signaturehash.SelectSignatureScheme(cfg.localSignatureSchemes, signer) if err != nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, err } - signature, err := generateKeySignature(clientRandom[:], serverRandom[:], state.localKeypair.PublicKey, state.namedCurve, certificate.PrivateKey, signatureHashAlgo.Hash) + signature, err := generateKeySignature( + clientRandom[:], + serverRandom[:], + state.localKeypair.PublicKey, + state.namedCurve, + signer, + signatureHashAlgo.Hash, + ) if err != nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } @@ -342,25 +417,37 @@ func flight4Generate(_ flightConn, state *State, _ *handshakeCache, cfg *handsha // an appropriate certificate to give to us. var certificateAuthorities [][]byte if cfg.clientCAs != nil { - // nolint:staticcheck // ignoring tlsCert.RootCAs.Subjects is deprecated ERR because cert does not come from SystemCertPool and it's ok if certificate authorities is empty. + // nolint:staticcheck // ignoring tlsCert.RootCAs.Subjects is deprecated ERR + // because cert does not come from SystemCertPool and it's ok if certificate + // authorities is empty. certificateAuthorities = cfg.clientCAs.Subjects() } + + certReq := &handshake.MessageCertificateRequest{ + CertificateTypes: []clientcertificate.Type{clientcertificate.RSASign, clientcertificate.ECDSASign}, + SignatureHashAlgorithms: cfg.localSignatureSchemes, + CertificateAuthoritiesNames: certificateAuthorities, + } + + var content handshake.Handshake + + if cfg.certificateRequestMessageHook != nil { + content = handshake.Handshake{Message: cfg.certificateRequestMessageHook(*certReq)} + } else { + content = handshake.Handshake{Message: certReq} + } + pkts = append(pkts, &packet{ record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, }, - Content: &handshake.Handshake{ - Message: &handshake.MessageCertificateRequest{ - CertificateTypes: []clientcertificate.Type{clientcertificate.RSASign, clientcertificate.ECDSASign}, - SignatureHashAlgorithms: cfg.localSignatureSchemes, - CertificateAuthoritiesNames: certificateAuthorities, - }, - }, + Content: &content, }, }) } - case cfg.localPSKIdentityHint != nil || state.cipherSuite.KeyExchangeAlgorithm().Has(CipherSuiteKeyExchangeAlgorithmEcdhe): + case cfg.localPSKIdentityHint != nil || + state.cipherSuite.KeyExchangeAlgorithm().Has(CipherSuiteKeyExchangeAlgorithmEcdhe): // To help the client in selecting which identity to use, the server // can provide a "PSK identity hint" in the ServerKeyExchange message. // If no hint is provided and cipher suite doesn't use elliptic curve, diff --git a/flight4handler_test.go b/flight4handler_test.go index 318a05826..458292b69 100644 --- a/flight4handler_test.go +++ b/flight4handler_test.go @@ -5,22 +5,29 @@ package dtls import ( "context" + "crypto/tls" + "errors" "testing" "time" - "github.com/pion/dtls/v2/internal/ciphersuite" - "github.com/pion/dtls/v2/pkg/protocol/alert" - "github.com/pion/dtls/v2/pkg/protocol/handshake" - "github.com/pion/transport/v2/test" + "github.com/pion/dtls/v3/internal/ciphersuite" + "github.com/pion/dtls/v3/pkg/crypto/elliptic" + "github.com/pion/dtls/v3/pkg/crypto/selfsign" + "github.com/pion/dtls/v3/pkg/crypto/signaturehash" + "github.com/pion/dtls/v3/pkg/protocol/alert" + "github.com/pion/dtls/v3/pkg/protocol/handshake" + "github.com/pion/transport/v3/test" ) type flight4TestMockFlightConn struct{} +var errHookCertReqFailed = errors.New("hook failed to modify SignatureHashAlgorithms") + func (f *flight4TestMockFlightConn) notify(context.Context, alert.Level, alert.Description) error { return nil } func (f *flight4TestMockFlightConn) writePackets(context.Context, []*packet) error { return nil } -func (f *flight4TestMockFlightConn) recvHandshake() <-chan chan struct{} { return nil } +func (f *flight4TestMockFlightConn) recvHandshake() <-chan recvHandshakeState { return nil } func (f *flight4TestMockFlightConn) setLocalEpoch(uint16) {} func (f *flight4TestMockFlightConn) handleQueuedPackets(context.Context) error { return nil } func (f *flight4TestMockFlightConn) sessionKey() []byte { return nil } @@ -33,13 +40,14 @@ type flight4TestMockCipherSuite struct { func (f *flight4TestMockCipherSuite) IsInitialized() bool { f.t.Fatal("IsInitialized called with Certificate but not CertificateVerify") + return true } // Assert that if a Client sends a certificate they // must also send a CertificateVerify message. // The flight4handler must not interact with the CipherSuite -// if the CertificateVerify is missing +// if the CertificateVerify is missing. func TestFlight4_Process_CertificateVerify(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(5 * time.Second) @@ -117,3 +125,65 @@ func TestFlight4_Process_CertificateVerify(t *testing.T) { t.Fatal(err) } } + +func TestFlight4_CertificateRequestHook(t *testing.T) { + // Limit runtime in case of deadlocks + lim := test.TimeOut(5 * time.Second) + defer lim.Stop() + + // Check for leaking routines + report := test.CheckRoutines(t) + defer report() + + localKeypair, err := elliptic.GenerateKeypair(elliptic.P256) + if err != nil { + t.Fatal(err) + } + + mockConn := &flight4TestMockFlightConn{} + state := &State{ + cipherSuite: &flight4TestMockCipherSuite{t: t}, + localKeypair: localKeypair, + } + + cert, err := selfsign.GenerateSelfSignedWithDNS("localhost") + if err != nil { + t.Fatal(err) + } + + cfg := &handshakeConfig{ + localCertificates: []tls.Certificate{cert}, + localSignatureSchemes: signaturehash.Algorithms(), + clientAuth: 1, + certificateRequestMessageHook: func(mcr handshake.MessageCertificateRequest) handshake.Message { + mcr.SignatureHashAlgorithms = []signaturehash.Algorithm{} + + return &mcr + }, + } + + pkts, _, err := flight4Generate(mockConn, state, nil, cfg) + if err != nil { + t.Fatal(err) + } + + for _, p := range pkts { + if h, ok := p.record.Content.(*handshake.Handshake); ok { //nolint:nestif + if h.Message.Type() == handshake.TypeCertificateRequest { + mcr := &handshake.MessageCertificateRequest{} + msg, err := h.Message.Marshal() + if err != nil { + t.Fatal(err) + } + err = mcr.Unmarshal(msg) + if err != nil { + t.Fatal(err) + } + if len(mcr.SignatureHashAlgorithms) == 0 { + return + } + } + } + } + t.Fatal(errHookCertReqFailed) +} diff --git a/flight5bhandler.go b/flight5bhandler.go index ddd37324c..db6de367c 100644 --- a/flight5bhandler.go +++ b/flight5bhandler.go @@ -6,14 +6,20 @@ package dtls import ( "context" - "github.com/pion/dtls/v2/pkg/crypto/prf" - "github.com/pion/dtls/v2/pkg/protocol" - "github.com/pion/dtls/v2/pkg/protocol/alert" - "github.com/pion/dtls/v2/pkg/protocol/handshake" - "github.com/pion/dtls/v2/pkg/protocol/recordlayer" + "github.com/pion/dtls/v3/pkg/crypto/prf" + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/alert" + "github.com/pion/dtls/v3/pkg/protocol/handshake" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) -func flight5bParse(_ context.Context, _ flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { +func flight5bParse( + _ context.Context, + _ flightConn, + state *State, + cache *handshakeCache, + cfg *handshakeConfig, +) (flightVal, *alert.Alert, error) { _, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence-1, state.cipherSuite, handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, false, false}, ) @@ -30,7 +36,12 @@ func flight5bParse(_ context.Context, _ flightConn, state *State, cache *handsha return flight5b, nil, nil } -func flight5bGenerate(_ flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) { //nolint:gocognit +func flight5bGenerate( + _ flightConn, + state *State, + cache *handshakeCache, + cfg *handshakeConfig, +) ([]*packet, *alert.Alert, error) { //nolint:gocognit var pkts []*packet pkts = append(pkts, diff --git a/flight5handler.go b/flight5handler.go index e8adf4f36..1b85b06b9 100644 --- a/flight5handler.go +++ b/flight5handler.go @@ -9,15 +9,21 @@ import ( "crypto" "crypto/x509" - "github.com/pion/dtls/v2/pkg/crypto/prf" - "github.com/pion/dtls/v2/pkg/crypto/signaturehash" - "github.com/pion/dtls/v2/pkg/protocol" - "github.com/pion/dtls/v2/pkg/protocol/alert" - "github.com/pion/dtls/v2/pkg/protocol/handshake" - "github.com/pion/dtls/v2/pkg/protocol/recordlayer" + "github.com/pion/dtls/v3/pkg/crypto/prf" + "github.com/pion/dtls/v3/pkg/crypto/signaturehash" + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/alert" + "github.com/pion/dtls/v3/pkg/protocol/handshake" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) -func flight5Parse(_ context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { +func flight5Parse( + _ context.Context, + conn flightConn, + state *State, + cache *handshakeCache, + cfg *handshakeConfig, +) (flightVal, *alert.Alert, error) { _, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite, handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, false, false}, ) @@ -57,7 +63,7 @@ func flight5Parse(_ context.Context, c flightConn, state *State, cache *handshak Secret: state.masterSecret, } cfg.log.Tracef("[handshake] save new session: %x", s.ID) - if err := cfg.sessionStore.Set(c.sessionKey(), s); err != nil { + if err := cfg.sessionStore.Set(conn.sessionKey(), s); err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } } @@ -65,17 +71,23 @@ func flight5Parse(_ context.Context, c flightConn, state *State, cache *handshak return flight5, nil, nil } -func flight5Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) { //nolint:gocognit - var privateKey crypto.PrivateKey +//nolint:gocognit,cyclop,maintidx +func flight5Generate( + conn flightConn, + state *State, + cache *handshakeCache, + cfg *handshakeConfig, +) ([]*packet, *alert.Alert, error) { + var signer crypto.Signer var pkts []*packet - if state.remoteRequestedCertificate { + if state.remoteRequestedCertificate { //nolint:nestif _, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence-2, state.cipherSuite, handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false}) if !ok { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errClientCertificateRequired } reqInfo := CertificateRequestInfo{} - if r, ok := msgs[handshake.TypeCertificateRequest].(*handshake.MessageCertificateRequest); ok { + if r, ok2 := msgs[handshake.TypeCertificateRequest].(*handshake.MessageCertificateRequest); ok2 { reqInfo.AcceptableCAs = r.CertificateAuthoritiesNames } else { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errClientCertificateRequired @@ -88,7 +100,10 @@ func flight5Generate(c flightConn, state *State, cache *handshakeCache, cfg *han return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errNotAcceptableCertificateChain } if certificate.Certificate != nil { - privateKey = certificate.PrivateKey + signer, ok = certificate.PrivateKey.(crypto.Signer) + if !ok { + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errInvalidPrivateKey + } } pkts = append(pkts, &packet{ @@ -135,7 +150,7 @@ func flight5Generate(c flightConn, state *State, cache *handshakeCache, cfg *han // handshakeMessageServerKeyExchange is optional for PSK if len(serverKeyExchangeData) == 0 { - alertPtr, err := handleServerKeyExchange(c, state, cfg, &handshake.MessageServerKeyExchange{}) + alertPtr, err := handleServerKeyExchange(conn, state, cfg, &handshake.MessageServerKeyExchange{}) if err != nil { return nil, alertPtr, err } @@ -158,7 +173,7 @@ func flight5Generate(c flightConn, state *State, cache *handshakeCache, cfg *han // Append not-yet-sent packets merged := []byte{} - seqPred := uint16(state.handshakeSendSequence) + seqPred := uint16(state.handshakeSendSequence) //nolint:gosec // G115 for _, p := range pkts { h, ok := p.record.Content.(*handshake.Handshake) if !ok { @@ -173,14 +188,14 @@ func flight5Generate(c flightConn, state *State, cache *handshakeCache, cfg *han merged = append(merged, raw...) } - if alertPtr, err := initalizeCipherSuite(state, cache, cfg, serverKeyExchange, merged); err != nil { + if alertPtr, err := initializeCipherSuite(state, cache, cfg, serverKeyExchange, merged); err != nil { return nil, alertPtr, err } // If the client has sent a certificate with signing ability, a digitally-signed // CertificateVerify message is sent to explicitly verify possession of the // private key in the certificate. - if state.remoteRequestedCertificate && privateKey != nil { + if state.remoteRequestedCertificate && signer != nil { plainText := append(cache.pullAndMerge( handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false}, @@ -193,18 +208,19 @@ func flight5Generate(c flightConn, state *State, cache *handshakeCache, cfg *han ), merged...) // Find compatible signature scheme - signatureHashAlgo, err := signaturehash.SelectSignatureScheme(cfg.localSignatureSchemes, privateKey) + + signatureHashAlgo, err := signaturehash.SelectSignatureScheme(state.remoteCertRequestAlgs, signer) if err != nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, err } - certVerify, err := generateCertificateVerify(plainText, privateKey, signatureHashAlgo.Hash) + certVerify, err := generateCertificateVerify(plainText, signer, signatureHashAlgo.Hash) if err != nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } state.localCertificatesVerify = certVerify - p := &packet{ + pkt := &packet{ record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, @@ -218,9 +234,9 @@ func flight5Generate(c flightConn, state *State, cache *handshakeCache, cfg *han }, }, } - pkts = append(pkts, p) + pkts = append(pkts, pkt) - h, ok := p.record.Content.(*handshake.Handshake) + h, ok := pkt.record.Content.(*handshake.Handshake) if !ok { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, errInvalidContentType } @@ -258,7 +274,11 @@ func flight5Generate(c flightConn, state *State, cache *handshakeCache, cfg *han ) var err error - state.localVerifyData, err = prf.VerifyDataClient(state.masterSecret, append(plainText, merged...), state.cipherSuite.HashFunc()) + state.localVerifyData, err = prf.VerifyDataClient( + state.masterSecret, + append(plainText, merged...), + state.cipherSuite.HashFunc(), + ) if err != nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } @@ -277,6 +297,7 @@ func flight5Generate(c flightConn, state *State, cache *handshakeCache, cfg *han }, }, }, + shouldWrapCID: len(state.remoteConnectionID) > 0, shouldEncrypt: true, resetLocalSequenceNumber: true, }) @@ -284,7 +305,14 @@ func flight5Generate(c flightConn, state *State, cache *handshakeCache, cfg *han return pkts, nil, nil } -func initalizeCipherSuite(state *State, cache *handshakeCache, cfg *handshakeConfig, h *handshake.MessageServerKeyExchange, sendingPlainText []byte) (*alert.Alert, error) { //nolint:gocognit +//nolint:gocognit,cyclop +func initializeCipherSuite( + state *State, + cache *handshakeCache, + cfg *handshakeConfig, + handshakeKeyExchange *handshake.MessageServerKeyExchange, + sendingPlainText []byte, +) (*alert.Alert, error) { if state.cipherSuite.IsInitialized() { return nil, nil //nolint } @@ -306,18 +334,24 @@ func initalizeCipherSuite(state *State, cache *handshakeCache, cfg *handshakeCon return &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, err } } else { - state.masterSecret, err = prf.MasterSecret(state.preMasterSecret, clientRandom[:], serverRandom[:], state.cipherSuite.HashFunc()) + state.masterSecret, err = prf.MasterSecret( + state.preMasterSecret, + clientRandom[:], + serverRandom[:], + state.cipherSuite.HashFunc(), + ) if err != nil { return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } } - if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate { + if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate { //nolint:nestif // Verify that the pair of hash algorithm and signiture is listed. var validSignatureScheme bool for _, ss := range cfg.localSignatureSchemes { - if ss.Hash == h.HashAlgorithm && ss.Signature == h.SignatureAlgorithm { + if ss.Hash == handshakeKeyExchange.HashAlgorithm && ss.Signature == handshakeKeyExchange.SignatureAlgorithm { validSignatureScheme = true + break } } @@ -325,8 +359,19 @@ func initalizeCipherSuite(state *State, cache *handshakeCache, cfg *handshakeCon return &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errNoAvailableSignatureSchemes } - expectedMsg := valueKeyMessage(clientRandom[:], serverRandom[:], h.PublicKey, h.NamedCurve) - if err = verifyKeySignature(expectedMsg, h.Signature, h.HashAlgorithm, state.PeerCertificates); err != nil { + expectedMsg := valueKeyMessage( + clientRandom[:], + serverRandom[:], + handshakeKeyExchange.PublicKey, + handshakeKeyExchange.NamedCurve, + ) + if err = verifyKeySignature( + expectedMsg, + handshakeKeyExchange. + Signature, + handshakeKeyExchange.HashAlgorithm, + state.PeerCertificates, + ); err != nil { return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err } var chains [][]*x509.Certificate @@ -342,8 +387,12 @@ func initalizeCipherSuite(state *State, cache *handshakeCache, cfg *handshakeCon } } if cfg.verifyConnection != nil { - if err = cfg.verifyConnection(state.clone()); err != nil { - return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err + stateClone, errC := state.clone() + if errC != nil { + return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, errC + } + if errC = cfg.verifyConnection(stateClone); errC != nil { + return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, errC } } diff --git a/flight6handler.go b/flight6handler.go index 57ac14360..d7828e749 100644 --- a/flight6handler.go +++ b/flight6handler.go @@ -6,14 +6,20 @@ package dtls import ( "context" - "github.com/pion/dtls/v2/pkg/crypto/prf" - "github.com/pion/dtls/v2/pkg/protocol" - "github.com/pion/dtls/v2/pkg/protocol/alert" - "github.com/pion/dtls/v2/pkg/protocol/handshake" - "github.com/pion/dtls/v2/pkg/protocol/recordlayer" + "github.com/pion/dtls/v3/pkg/crypto/prf" + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/alert" + "github.com/pion/dtls/v3/pkg/protocol/handshake" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) -func flight6Parse(_ context.Context, _ flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { +func flight6Parse( + _ context.Context, + _ flightConn, + state *State, + cache *handshakeCache, + cfg *handshakeConfig, +) (flightVal, *alert.Alert, error) { _, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence-1, state.cipherSuite, handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false}, ) @@ -30,7 +36,12 @@ func flight6Parse(_ context.Context, _ flightConn, state *State, cache *handshak return flight6, nil, nil } -func flight6Generate(_ flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) { +func flight6Generate( + _ flightConn, + state *State, + cache *handshakeCache, + cfg *handshakeConfig, +) ([]*packet, *alert.Alert, error) { var pkts []*packet pkts = append(pkts, @@ -77,9 +88,11 @@ func flight6Generate(_ flightConn, state *State, cache *handshakeCache, cfg *han }, }, }, + shouldWrapCID: len(state.remoteConnectionID) > 0, shouldEncrypt: true, resetLocalSequenceNumber: true, }, ) + return pkts, nil, nil } diff --git a/flighthandler.go b/flighthandler.go index ceb4a992b..b90cebd3b 100644 --- a/flighthandler.go +++ b/flighthandler.go @@ -6,16 +6,22 @@ package dtls import ( "context" - "github.com/pion/dtls/v2/pkg/protocol/alert" + "github.com/pion/dtls/v3/pkg/protocol/alert" ) -// Parse received handshakes and return next flightVal -type flightParser func(context.Context, flightConn, *State, *handshakeCache, *handshakeConfig) (flightVal, *alert.Alert, error) +// Parse received handshakes and return next flightVal. +type flightParser func( + context.Context, + flightConn, + *State, + *handshakeCache, + *handshakeConfig, +) (flightVal, *alert.Alert, error) -// Generate flights +// Generate flights. type flightGenerator func(flightConn, *State, *handshakeCache, *handshakeConfig) ([]*packet, *alert.Alert, error) -func (f flightVal) getFlightParser() (flightParser, error) { +func (f flightVal) getFlightParser() (flightParser, error) { //nolint:cyclop switch f { case flight0: return flight0Parse, nil @@ -40,7 +46,7 @@ func (f flightVal) getFlightParser() (flightParser, error) { } } -func (f flightVal) getFlightGenerator() (gen flightGenerator, retransmit bool, err error) { +func (f flightVal) getFlightGenerator() (gen flightGenerator, retransmit bool, err error) { //nolint:cyclop switch f { case flight0: return flight0Generate, true, nil diff --git a/fragment_buffer.go b/fragment_buffer.go index f20033758..497d97107 100644 --- a/fragment_buffer.go +++ b/fragment_buffer.go @@ -4,12 +4,12 @@ package dtls import ( - "github.com/pion/dtls/v2/pkg/protocol" - "github.com/pion/dtls/v2/pkg/protocol/handshake" - "github.com/pion/dtls/v2/pkg/protocol/recordlayer" + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/handshake" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) -// 2 megabytes +// 2 megabytes. const fragmentBufferMaxSize = 2000000 type fragment struct { @@ -29,7 +29,7 @@ func newFragmentBuffer() *fragmentBuffer { return &fragmentBuffer{cache: map[uint16][]*fragment{}} } -// current total size of buffer +// current total size of buffer. func (f *fragmentBuffer) size() int { size := 0 for i := range f.cache { @@ -37,32 +37,37 @@ func (f *fragmentBuffer) size() int { size += len(f.cache[i][j].data) } } + return size } // Attempts to push a DTLS packet to the fragmentBuffer // when it returns true it means the fragmentBuffer has inserted and the buffer shouldn't be handled -// when an error returns it is fatal, and the DTLS connection should be stopped -func (f *fragmentBuffer) push(buf []byte) (bool, error) { +// when an error returns it is fatal, and the DTLS connection should be stopped. +func (f *fragmentBuffer) push(buf []byte) (isHandshake, isRetransmit bool, err error) { if f.size()+len(buf) >= fragmentBufferMaxSize { - return false, errFragmentBufferOverflow + return false, false, errFragmentBufferOverflow } frag := new(fragment) if err := frag.recordLayerHeader.Unmarshal(buf); err != nil { - return false, err + return false, false, err } // fragment isn't a handshake, we don't need to handle it if frag.recordLayerHeader.ContentType != protocol.ContentTypeHandshake { - return false, nil + return false, false, nil } - for buf = buf[recordlayer.HeaderSize:]; len(buf) != 0; frag = new(fragment) { + for buf = buf[recordlayer.FixedHeaderSize:]; len(buf) != 0; frag = new(fragment) { if err := frag.handshakeHeader.Unmarshal(buf); err != nil { - return false, err + return false, false, err } + // Fragment is a retransmission. We have already assembled it before successfully + isRetransmit = frag.handshakeHeader.FragmentOffset == 0 && + frag.handshakeHeader.MessageSequence < f.currentMessageSequenceNumber + if _, ok := f.cache[frag.handshakeHeader.MessageSequence]; !ok { f.cache[frag.handshakeHeader.MessageSequence] = []*fragment{} } @@ -80,7 +85,7 @@ func (f *fragmentBuffer) push(buf []byte) (bool, error) { buf = buf[end:] } - return true, nil + return true, isRetransmit, nil } func (f *fragmentBuffer) pop() (content []byte, epoch uint16) { @@ -104,9 +109,11 @@ func (f *fragmentBuffer) pop() (content []byte, epoch uint16) { } rawMessage = append(f.data, rawMessage...) + return true } } + return false } @@ -128,5 +135,6 @@ func (f *fragmentBuffer) pop() (content []byte, epoch uint16) { delete(f.cache, f.currentMessageSequenceNumber) f.currentMessageSequenceNumber++ + return append(rawHeader, rawMessage...), messageEpoch } diff --git a/fragment_buffer_test.go b/fragment_buffer_test.go index ad8834e71..9e842b0a2 100644 --- a/fragment_buffer_test.go +++ b/fragment_buffer_test.go @@ -19,7 +19,10 @@ func TestFragmentBuffer(t *testing.T) { { Name: "Single Fragment", In: [][]byte{ - {0x16, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0F, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xfe, 0xff, 0x00}, + { + 0x16, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0F, 0x03, + 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xfe, 0xff, 0x00, + }, }, Expected: [][]byte{ {0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xfe, 0xff, 0x00}, @@ -29,7 +32,10 @@ func TestFragmentBuffer(t *testing.T) { { Name: "Single Fragment Epoch 3", In: [][]byte{ - {0x16, 0xfe, 0xff, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0F, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xfe, 0xff, 0x00}, + { + 0x16, 0xfe, 0xff, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0F, 0x03, + 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xfe, 0xff, 0x00, + }, }, Expected: [][]byte{ {0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xfe, 0xff, 0x00}, @@ -39,24 +45,48 @@ func TestFragmentBuffer(t *testing.T) { { Name: "Multiple Fragments", In: [][]byte{ - {0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x81, 0x0b, 0x00, 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x01, 0x02, 0x03, 0x04}, - {0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x81, 0x0b, 0x00, 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x05, 0x05, 0x06, 0x07, 0x08, 0x09}, - {0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x81, 0x0b, 0x00, 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x0A, 0x00, 0x00, 0x05, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E}, + { + 0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x81, 0x0b, 0x00, + 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x01, 0x02, 0x03, 0x04, + }, + { + 0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x81, 0x0b, 0x00, + 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x05, 0x05, 0x06, 0x07, 0x08, 0x09, + }, + { + 0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x81, 0x0b, 0x00, + 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x0A, 0x00, 0x00, 0x05, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, + }, }, Expected: [][]byte{ - {0x0b, 0x00, 0x00, 0x0f, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0f, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e}, + { + 0x0b, 0x00, 0x00, 0x0f, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0f, 0x00, 0x01, + 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, + }, }, Epoch: 0, }, { Name: "Multiple Unordered Fragments", In: [][]byte{ - {0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x81, 0x0b, 0x00, 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x01, 0x02, 0x03, 0x04}, - {0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x81, 0x0b, 0x00, 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x0A, 0x00, 0x00, 0x05, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E}, - {0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x81, 0x0b, 0x00, 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x05, 0x05, 0x06, 0x07, 0x08, 0x09}, + { + 0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x81, 0x0b, 0x00, + 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x01, 0x02, 0x03, 0x04, + }, + { + 0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x81, 0x0b, 0x00, + 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x0A, 0x00, 0x00, 0x05, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, + }, + { + 0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x81, 0x0b, 0x00, + 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x05, 0x05, 0x06, 0x07, 0x08, 0x09, + }, }, Expected: [][]byte{ - {0x0b, 0x00, 0x00, 0x0f, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0f, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e}, + { + 0x0b, 0x00, 0x00, 0x0f, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0f, 0x00, 0x01, + 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, + }, }, Epoch: 0, }, @@ -94,7 +124,7 @@ func TestFragmentBuffer(t *testing.T) { } { fragmentBuffer := newFragmentBuffer() for _, frag := range test.In { - status, err := fragmentBuffer.push(frag) + status, _, err := fragmentBuffer.push(frag) if err != nil { t.Error(err) } else if !status { @@ -122,13 +152,16 @@ func TestFragmentBuffer_Overflow(t *testing.T) { fragmentBuffer := newFragmentBuffer() // Push a buffer that doesn't exceed size limits - if _, err := fragmentBuffer.push([]byte{0x16, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0F, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xfe, 0xff, 0x00}); err != nil { + if _, _, err := fragmentBuffer.push([]byte{ + 0x16, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0F, 0x03, + 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xfe, 0xff, 0x00, + }); err != nil { t.Fatal(err) } // Allocate a buffer that exceeds cache size largeBuffer := make([]byte, fragmentBufferMaxSize) - if _, err := fragmentBuffer.push(largeBuffer); !errors.Is(err, errFragmentBufferOverflow) { + if _, _, err := fragmentBuffer.push(largeBuffer); !errors.Is(err, errFragmentBufferOverflow) { t.Fatalf("Pushing a large buffer returned (%s) expected(%s)", err, errFragmentBufferOverflow) } } diff --git a/fuzz_test.go b/fuzz_test.go new file mode 100644 index 000000000..c91ec2d5e --- /dev/null +++ b/fuzz_test.go @@ -0,0 +1,27 @@ +// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> +// SPDX-License-Identifier: MIT +package dtls + +import ( + "os" + "testing" +) + +func FuzzUnmarshalBinary(f *testing.F) { + TestResumeClient, err := os.ReadFile("testdata/seed/TestResumeClient.raw") + if err != nil { + return + } + f.Add(TestResumeClient) + + TestResumeServer, err := os.ReadFile("testdata/seed/TestResumeServer.raw") + if err != nil { + return + } + f.Add(TestResumeServer) + + f.Fuzz(func(_ *testing.T, data []byte) { + deserialized := &State{} + _ = deserialized.UnmarshalBinary(data) + }) +} diff --git a/go.mod b/go.mod index 9d0d65615..6c9fd66ee 100644 --- a/go.mod +++ b/go.mod @@ -1,10 +1,12 @@ -module github.com/pion/dtls/v2 +module github.com/pion/dtls/v3 require ( - github.com/pion/logging v0.2.2 - github.com/pion/transport/v2 v2.2.1 - golang.org/x/crypto v0.8.0 - golang.org/x/net v0.9.0 + github.com/pion/logging v0.2.3 + github.com/pion/transport/v3 v3.0.7 + golang.org/x/crypto v0.36.0 + golang.org/x/net v0.37.0 ) -go 1.13 +go 1.23.0 + +toolchain go1.24.1 diff --git a/go.sum b/go.sum index 8144e06dd..dfb9901d2 100644 --- a/go.sum +++ b/go.sum @@ -1,58 +1,8 @@ -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY= -github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= -github.com/pion/transport/v2 v2.2.1 h1:7qYnCBlpgSJNYMbLCKuSY9KbQdBFoETvPNETv0y4N7c= -github.com/pion/transport/v2 v2.2.1/go.mod h1:cXXWavvCnFF6McHTft3DWS9iic2Mftcz1Aq29pGcU5g= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= -github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY= -github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.8.0 h1:pd9TJtTueMTVQXzk8E2XESSMQDj/U7OUu0PqJqPXQjQ= -golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE= -golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM= -golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= -golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= -golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +github.com/pion/logging v0.2.3 h1:gHuf0zpoh1GW67Nr6Gj4cv5Z9ZscU7g/EaoC/Ke/igI= +github.com/pion/logging v0.2.3/go.mod h1:z8YfknkquMe1csOrxK5kc+5/ZPAzMxbKLX5aXpbpC90= +github.com/pion/transport/v3 v3.0.7 h1:iRbMH05BzSNwhILHoBoAPxoB9xQgOaJk+591KC9P1o0= +github.com/pion/transport/v3 v3.0.7/go.mod h1:YleKiTZ4vqNxVwh77Z0zytYi7rXHl7j6uPLGhhz9rwo= +golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= +golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= +golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c= +golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= diff --git a/handshake_cache.go b/handshake_cache.go index 8d5960568..95f20953f 100644 --- a/handshake_cache.go +++ b/handshake_cache.go @@ -6,8 +6,8 @@ package dtls import ( "sync" - "github.com/pion/dtls/v2/pkg/crypto/prf" - "github.com/pion/dtls/v2/pkg/protocol/handshake" + "github.com/pion/dtls/v3/pkg/crypto/prf" + "github.com/pion/dtls/v3/pkg/protocol/handshake" ) type handshakeCacheItem struct { @@ -49,7 +49,7 @@ func (h *handshakeCache) push(data []byte, epoch, messageSequence uint16, typ ha // returns a list handshakes that match the requested rules // the list will contain null entries for rules that can't be satisfied -// multiple entries may match a rule, but only the last match is returned (ie ClientHello with cookies) +// multiple entries may match a rule, but only the last match is returned (ie ClientHello with cookies). func (h *handshakeCache) pull(rules ...handshakeCachePullRule) []*handshakeCacheItem { h.mu.Lock() defer h.mu.Unlock() @@ -72,15 +72,21 @@ func (h *handshakeCache) pull(rules ...handshakeCachePullRule) []*handshakeCache } // fullPullMap pulls all handshakes between rules[0] to rules[len(rules)-1] as map. -func (h *handshakeCache) fullPullMap(startSeq int, cipherSuite CipherSuite, rules ...handshakeCachePullRule) (int, map[handshake.Type]handshake.Message, bool) { +// +//nolint:cyclop +func (h *handshakeCache) fullPullMap( + startSeq int, + cipherSuite CipherSuite, + rules ...handshakeCachePullRule, +) (int, map[handshake.Type]handshake.Message, bool) { h.mu.Lock() defer h.mu.Unlock() ci := make(map[handshake.Type]*handshakeCacheItem) - for _, r := range rules { + for _, rule := range rules { var item *handshakeCacheItem for _, c := range h.cache { - if c.typ == r.typ && c.isClient == r.isClient && c.epoch == r.epoch { + if c.typ == rule.typ && c.isClient == rule.isClient && c.epoch == rule.epoch { switch { case item == nil: item = c @@ -89,17 +95,18 @@ func (h *handshakeCache) fullPullMap(startSeq int, cipherSuite CipherSuite, rule } } } - if !r.optional && item == nil { + if !rule.optional && item == nil { // Missing mandatory message. return startSeq, nil, false } - ci[r.typ] = item + ci[rule.typ] = item } out := make(map[handshake.Type]handshake.Message) seq := startSeq + ok := false for _, r := range rules { - t := r.typ - i := ci[t] + typ := r.typ + i := ci[typ] if i == nil { continue } @@ -113,17 +120,22 @@ func (h *handshakeCache) fullPullMap(startSeq int, cipherSuite CipherSuite, rule if err := rawHandshake.Unmarshal(i.data); err != nil { return startSeq, nil, false } - if uint16(seq) != rawHandshake.Header.MessageSequence { + if uint16(seq) != rawHandshake.Header.MessageSequence { //nolint:gosec // G115 // There is a gap. Some messages are not arrived. return startSeq, nil, false } seq++ - out[t] = rawHandshake.Message + ok = true + out[typ] = rawHandshake.Message + } + if !ok { + return seq, nil, false } + return seq, out, true } -// pullAndMerge calls pull and then merges the results, ignoring any null entries +// pullAndMerge calls pull and then merges the results, ignoring any null entries. func (h *handshakeCache) pullAndMerge(rules ...handshakeCachePullRule) []byte { merged := []byte{} @@ -132,6 +144,7 @@ func (h *handshakeCache) pullAndMerge(rules ...handshakeCachePullRule) []byte { merged = append(merged, p.data...) } } + return merged } diff --git a/handshake_cache_test.go b/handshake_cache_test.go index 44a15b587..b655ac166 100644 --- a/handshake_cache_test.go +++ b/handshake_cache_test.go @@ -7,8 +7,8 @@ import ( "bytes" "testing" - "github.com/pion/dtls/v2/internal/ciphersuite" - "github.com/pion/dtls/v2/pkg/protocol/handshake" + "github.com/pion/dtls/v3/internal/ciphersuite" + "github.com/pion/dtls/v3/pkg/protocol/handshake" ) func TestHandshakeCacheSinglePush(t *testing.T) { @@ -144,7 +144,10 @@ func TestHandshakeCacheSessionHash(t *testing.T) { {handshake.TypeServerHelloDone, false, 0, 4, []byte{0x04}}, {handshake.TypeClientKeyExchange, true, 0, 5, []byte{0x05}}, }, - Expected: []byte{0x17, 0xe8, 0x8d, 0xb1, 0x87, 0xaf, 0xd6, 0x2c, 0x16, 0xe5, 0xde, 0xbf, 0x3e, 0x65, 0x27, 0xcd, 0x00, 0x6b, 0xc0, 0x12, 0xbc, 0x90, 0xb5, 0x1a, 0x81, 0x0c, 0xd8, 0x0c, 0x2d, 0x51, 0x1f, 0x43}, + Expected: []byte{ + 0x17, 0xe8, 0x8d, 0xb1, 0x87, 0xaf, 0xd6, 0x2c, 0x16, 0xe5, 0xde, 0xbf, 0x3e, 0x65, 0x27, 0xcd, + 0x00, 0x6b, 0xc0, 0x12, 0xbc, 0x90, 0xb5, 0x1a, 0x81, 0x0c, 0xd8, 0x0c, 0x2d, 0x51, 0x1f, 0x43, + }, }, { Name: "Handshake With Client Cert Request", @@ -157,7 +160,10 @@ func TestHandshakeCacheSessionHash(t *testing.T) { {handshake.TypeServerHelloDone, false, 0, 5, []byte{0x05}}, {handshake.TypeClientKeyExchange, true, 0, 6, []byte{0x06}}, }, - Expected: []byte{0x57, 0x35, 0x5a, 0xc3, 0x30, 0x3c, 0x14, 0x8f, 0x11, 0xae, 0xf7, 0xcb, 0x17, 0x94, 0x56, 0xb9, 0x23, 0x2c, 0xde, 0x33, 0xa8, 0x18, 0xdf, 0xda, 0x2c, 0x2f, 0xcb, 0x93, 0x25, 0x74, 0x9a, 0x6b}, + Expected: []byte{ + 0x57, 0x35, 0x5a, 0xc3, 0x30, 0x3c, 0x14, 0x8f, 0x11, 0xae, 0xf7, 0xcb, 0x17, 0x94, 0x56, 0xb9, + 0x23, 0x2c, 0xde, 0x33, 0xa8, 0x18, 0xdf, 0xda, 0x2c, 0x2f, 0xcb, 0x93, 0x25, 0x74, 0x9a, 0x6b, + }, }, { Name: "Handshake Ignores after ClientKeyExchange", @@ -173,7 +179,10 @@ func TestHandshakeCacheSessionHash(t *testing.T) { {handshake.TypeFinished, true, 1, 7, []byte{0x08}}, {handshake.TypeFinished, false, 1, 7, []byte{0x09}}, }, - Expected: []byte{0x57, 0x35, 0x5a, 0xc3, 0x30, 0x3c, 0x14, 0x8f, 0x11, 0xae, 0xf7, 0xcb, 0x17, 0x94, 0x56, 0xb9, 0x23, 0x2c, 0xde, 0x33, 0xa8, 0x18, 0xdf, 0xda, 0x2c, 0x2f, 0xcb, 0x93, 0x25, 0x74, 0x9a, 0x6b}, + Expected: []byte{ + 0x57, 0x35, 0x5a, 0xc3, 0x30, 0x3c, 0x14, 0x8f, 0x11, 0xae, 0xf7, 0xcb, 0x17, 0x94, 0x56, 0xb9, + 0x23, 0x2c, 0xde, 0x33, 0xa8, 0x18, 0xdf, 0xda, 0x2c, 0x2f, 0xcb, 0x93, 0x25, 0x74, 0x9a, 0x6b, + }, }, { Name: "Handshake Ignores wrong epoch", @@ -193,7 +202,10 @@ func TestHandshakeCacheSessionHash(t *testing.T) { {handshake.TypeFinished, true, 0, 7, []byte{0xf0}}, {handshake.TypeFinished, false, 0, 7, []byte{0xf1}}, }, - Expected: []byte{0x57, 0x35, 0x5a, 0xc3, 0x30, 0x3c, 0x14, 0x8f, 0x11, 0xae, 0xf7, 0xcb, 0x17, 0x94, 0x56, 0xb9, 0x23, 0x2c, 0xde, 0x33, 0xa8, 0x18, 0xdf, 0xda, 0x2c, 0x2f, 0xcb, 0x93, 0x25, 0x74, 0x9a, 0x6b}, + Expected: []byte{ + 0x57, 0x35, 0x5a, 0xc3, 0x30, 0x3c, 0x14, 0x8f, 0x11, 0xae, 0xf7, 0xcb, 0x17, 0x94, 0x56, 0xb9, + 0x23, 0x2c, 0xde, 0x33, 0xa8, 0x18, 0xdf, 0xda, 0x2c, 0x2f, 0xcb, 0x93, 0x25, 0x74, 0x9a, 0x6b, + }, }, } { h := newHandshakeCache() diff --git a/handshake_test.go b/handshake_test.go index 5bba7f812..8c97d20b2 100644 --- a/handshake_test.go +++ b/handshake_test.go @@ -8,9 +8,9 @@ import ( "testing" "time" - "github.com/pion/dtls/v2/pkg/protocol" - "github.com/pion/dtls/v2/pkg/protocol/extension" - "github.com/pion/dtls/v2/pkg/protocol/handshake" + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/extension" + "github.com/pion/dtls/v3/pkg/protocol/handshake" ) func TestHandshakeMessage(t *testing.T) { @@ -30,7 +30,10 @@ func TestHandshakeMessage(t *testing.T) { Version: protocol.Version{Major: 0xFE, Minor: 0xFD}, Random: handshake.Random{ GMTUnixTime: time.Unix(3056586332, 0), - RandomBytes: [28]byte{0x42, 0x54, 0xff, 0x86, 0xe1, 0x24, 0x41, 0x91, 0x42, 0x62, 0x15, 0xad, 0x16, 0xc9, 0x15, 0x8d, 0x95, 0x71, 0x8a, 0xbb, 0x22, 0xd7, 0x47, 0xec, 0xd8, 0x3d, 0xdc, 0x4b}, + RandomBytes: [28]byte{ + 0x42, 0x54, 0xff, 0x86, 0xe1, 0x24, 0x41, 0x91, 0x42, 0x62, 0x15, 0xad, 0x16, 0xc9, + 0x15, 0x8d, 0x95, 0x71, 0x8a, 0xbb, 0x22, 0xd7, 0x47, 0xec, 0xd8, 0x3d, 0xdc, 0x4b, + }, }, SessionID: []byte{}, Cookie: []byte{}, diff --git a/handshaker.go b/handshaker.go index 1c6d58fe9..e8b09b6c4 100644 --- a/handshaker.go +++ b/handshaker.go @@ -12,10 +12,10 @@ import ( "sync" "time" - "github.com/pion/dtls/v2/pkg/crypto/elliptic" - "github.com/pion/dtls/v2/pkg/crypto/signaturehash" - "github.com/pion/dtls/v2/pkg/protocol/alert" - "github.com/pion/dtls/v2/pkg/protocol/handshake" + "github.com/pion/dtls/v3/pkg/crypto/elliptic" + "github.com/pion/dtls/v3/pkg/crypto/signaturehash" + "github.com/pion/dtls/v3/pkg/protocol/alert" + "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/logging" ) @@ -82,37 +82,42 @@ func (s handshakeState) String() string { } type handshakeFSM struct { - currentFlight flightVal - flights []*packet - retransmit bool - state *State - cache *handshakeCache - cfg *handshakeConfig - closed chan struct{} + currentFlight flightVal + flights []*packet + retransmit bool + retransmitInterval time.Duration + state *State + cache *handshakeCache + cfg *handshakeConfig + closed chan struct{} } type handshakeConfig struct { - localPSKCallback PSKCallback - localPSKIdentityHint []byte - localCipherSuites []CipherSuite // Available CipherSuites - localSignatureSchemes []signaturehash.Algorithm // Available signature schemes - extendedMasterSecret ExtendedMasterSecretType // Policy for the Extended Master Support extension - localSRTPProtectionProfiles []SRTPProtectionProfile // Available SRTPProtectionProfiles, if empty no SRTP support - serverName string - supportedProtocols []string - clientAuth ClientAuthType // If we are a client should we request a client certificate - localCertificates []tls.Certificate - nameToCertificate map[string]*tls.Certificate - insecureSkipVerify bool - verifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error - verifyConnection func(*State) error - sessionStore SessionStore - rootCAs *x509.CertPool - clientCAs *x509.CertPool - retransmitInterval time.Duration - customCipherSuites func() []CipherSuite - ellipticCurves []elliptic.Curve - insecureSkipHelloVerify bool + localPSKCallback PSKCallback + localPSKIdentityHint []byte + localCipherSuites []CipherSuite // Available CipherSuites + localSignatureSchemes []signaturehash.Algorithm // Available signature schemes + extendedMasterSecret ExtendedMasterSecretType // Policy for the Extended Master Support extension + localSRTPProtectionProfiles []SRTPProtectionProfile // Available SRTPProtectionProfiles, if empty no SRTP support + localSRTPMasterKeyIdentifier []byte + serverName string + supportedProtocols []string + clientAuth ClientAuthType // If we are a client should we request a client certificate + localCertificates []tls.Certificate + nameToCertificate map[string]*tls.Certificate + insecureSkipVerify bool + verifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error + verifyConnection func(*State) error + sessionStore SessionStore + rootCAs *x509.CertPool + clientCAs *x509.CertPool + initialRetransmitInterval time.Duration + disableRetransmitBackoff bool + customCipherSuites func() []CipherSuite + ellipticCurves []elliptic.Curve + insecureSkipHelloVerify bool + connectionIDGenerator func() []byte + helloRandomBytesGenerator func() [handshake.RandomBytesLength]byte onFlightState func(flightVal, handshakeState) log logging.LeveledLogger @@ -124,12 +129,18 @@ type handshakeConfig struct { initialEpoch uint16 mu sync.Mutex + + clientHelloMessageHook func(handshake.MessageClientHello) handshake.Message + serverHelloMessageHook func(handshake.MessageServerHello) handshake.Message + certificateRequestMessageHook func(handshake.MessageCertificateRequest) handshake.Message + + resumeState *State } type flightConn interface { notify(ctx context.Context, level alert.Level, desc alert.Description) error writePackets(context.Context, []*packet) error - recvHandshake() <-chan chan struct{} + recvHandshake() <-chan recvHandshakeState setLocalEpoch(epoch uint16) handleQueuedPackets(context.Context) error sessionKey() []byte @@ -151,6 +162,7 @@ func srvCliStr(isClient bool) string { if isClient { return "client" } + return "server" } @@ -159,15 +171,16 @@ func newHandshakeFSM( initialFlight flightVal, ) *handshakeFSM { return &handshakeFSM{ - currentFlight: initialFlight, - state: s, - cache: cache, - cfg: cfg, - closed: make(chan struct{}), + currentFlight: initialFlight, + state: s, + cache: cache, + cfg: cfg, + retransmitInterval: cfg.initialRetransmitInterval, + closed: make(chan struct{}), } } -func (s *handshakeFSM) Run(ctx context.Context, c flightConn, initialState handshakeState) error { +func (s *handshakeFSM) Run(ctx context.Context, conn flightConn, initialState handshakeState) error { state := initialState defer func() { close(s.closed) @@ -180,13 +193,13 @@ func (s *handshakeFSM) Run(ctx context.Context, c flightConn, initialState hands var err error switch state { case handshakePreparing: - state, err = s.prepare(ctx, c) + state, err = s.prepare(ctx, conn) case handshakeSending: - state, err = s.send(ctx, c) + state, err = s.send(ctx, conn) case handshakeWaiting: - state, err = s.wait(ctx, c) + state, err = s.wait(ctx, conn) case handshakeFinished: - state, err = s.finish(ctx, c) + state, err = s.finish(ctx, conn) default: return errInvalidFSMTransition } @@ -200,24 +213,24 @@ func (s *handshakeFSM) Done() <-chan struct{} { return s.closed } -func (s *handshakeFSM) prepare(ctx context.Context, c flightConn) (handshakeState, error) { +func (s *handshakeFSM) prepare(ctx context.Context, conn flightConn) (handshakeState, error) { s.flights = nil // Prepare flights var ( - a *alert.Alert - err error - pkts []*packet + dtlsAlert *alert.Alert + err error + pkts []*packet ) gen, retransmit, errFlight := s.currentFlight.getFlightGenerator() if errFlight != nil { err = errFlight - a = &alert.Alert{Level: alert.Fatal, Description: alert.InternalError} + dtlsAlert = &alert.Alert{Level: alert.Fatal, Description: alert.InternalError} } else { - pkts, a, err = gen(c, s.state, s.cache, s.cfg) + pkts, dtlsAlert, err = gen(conn, s.state, s.cache, s.cfg) s.retransmit = retransmit } - if a != nil { - if alertErr := c.notify(ctx, a.Level, a.Description); alertErr != nil { + if dtlsAlert != nil { + if alertErr := conn.notify(ctx, dtlsAlert.Level, dtlsAlert.Description); alertErr != nil { if err != nil { err = alertErr } @@ -236,14 +249,15 @@ func (s *handshakeFSM) prepare(ctx context.Context, c flightConn) (handshakeStat nextEpoch = p.record.Header.Epoch } if h, ok := p.record.Content.(*handshake.Handshake); ok { - h.Header.MessageSequence = uint16(s.state.handshakeSendSequence) + h.Header.MessageSequence = uint16(s.state.handshakeSendSequence) //nolint:gosec // G115 s.state.handshakeSendSequence++ } } if epoch != nextEpoch { s.cfg.log.Tracef("[handshake:%s] -> changeCipherSpec (epoch: %d)", srvCliStr(s.state.isClient), nextEpoch) - c.setLocalEpoch(nextEpoch) + conn.setLocalEpoch(nextEpoch) } + return handshakeSending, nil } @@ -256,28 +270,35 @@ func (s *handshakeFSM) send(ctx context.Context, c flightConn) (handshakeState, if s.currentFlight.isLastSendFlight() { return handshakeFinished, nil } + return handshakeWaiting, nil } -func (s *handshakeFSM) wait(ctx context.Context, c flightConn) (handshakeState, error) { //nolint:gocognit +func (s *handshakeFSM) wait(ctx context.Context, conn flightConn) (handshakeState, error) { //nolint:gocognit,cyclop parse, errFlight := s.currentFlight.getFlightParser() if errFlight != nil { - if alertErr := c.notify(ctx, alert.Fatal, alert.InternalError); alertErr != nil { - if errFlight != nil { - return handshakeErrored, alertErr - } + if alertErr := conn.notify(ctx, alert.Fatal, alert.InternalError); alertErr != nil { + return handshakeErrored, alertErr } + return handshakeErrored, errFlight } - retransmitTimer := time.NewTimer(s.cfg.retransmitInterval) + retransmitTimer := time.NewTimer(s.retransmitInterval) for { select { - case done := <-c.recvHandshake(): - nextFlight, alert, err := parse(ctx, c, s.state, s.cache, s.cfg) - close(done) + case state := <-conn.recvHandshake(): + if state.isRetransmit { + close(state.done) + + return handshakeSending, nil + } + + nextFlight, alert, err := parse(ctx, conn, s.state, s.cache, s.cfg) + s.retransmitInterval = s.cfg.initialRetransmitInterval + close(state.done) if alert != nil { - if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil { + if alertErr := conn.notify(ctx, alert.Level, alert.Description); alertErr != nil { if err != nil { err = alertErr } @@ -289,62 +310,53 @@ func (s *handshakeFSM) wait(ctx context.Context, c flightConn) (handshakeState, if nextFlight == 0 { break } - s.cfg.log.Tracef("[handshake:%s] %s -> %s", srvCliStr(s.state.isClient), s.currentFlight.String(), nextFlight.String()) + s.cfg.log.Tracef( + "[handshake:%s] %s -> %s", + srvCliStr(s.state.isClient), + s.currentFlight.String(), + nextFlight.String(), + ) if nextFlight.isLastRecvFlight() && s.currentFlight == nextFlight { return handshakeFinished, nil } s.currentFlight = nextFlight + return handshakePreparing, nil case <-retransmitTimer.C: if !s.retransmit { return handshakeWaiting, nil } + + // RFC 4347 4.2.4.1: + // Implementations SHOULD use an initial timer value of 1 second (the minimum defined in RFC 2988 [RFC2988]) + // and double the value at each retransmission, up to no less than the RFC 2988 maximum of 60 seconds. + if !s.cfg.disableRetransmitBackoff { + s.retransmitInterval *= 2 + } + if s.retransmitInterval > time.Second*60 { + s.retransmitInterval = time.Second * 60 + } + return handshakeSending, nil case <-ctx.Done(): + s.retransmitInterval = s.cfg.initialRetransmitInterval + return handshakeErrored, ctx.Err() } } } func (s *handshakeFSM) finish(ctx context.Context, c flightConn) (handshakeState, error) { - parse, errFlight := s.currentFlight.getFlightParser() - if errFlight != nil { - if alertErr := c.notify(ctx, alert.Fatal, alert.InternalError); alertErr != nil { - if errFlight != nil { - return handshakeErrored, alertErr - } - } - return handshakeErrored, errFlight - } - - retransmitTimer := time.NewTimer(s.cfg.retransmitInterval) select { - case done := <-c.recvHandshake(): - nextFlight, alert, err := parse(ctx, c, s.state, s.cache, s.cfg) - close(done) - if alert != nil { - if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil { - if err != nil { - err = alertErr - } - } - } - if err != nil { - return handshakeErrored, err - } - if nextFlight == 0 { - break - } - if nextFlight.isLastRecvFlight() && s.currentFlight == nextFlight { + case state := <-c.recvHandshake(): + close(state.done) + if s.state.isClient { return handshakeFinished, nil + } else { + return handshakeSending, nil } - <-retransmitTimer.C - // Retransmit last flight - return handshakeSending, nil - case <-ctx.Done(): return handshakeErrored, ctx.Err() } - return handshakeFinished, nil } diff --git a/handshaker_test.go b/handshaker_test.go index 6cf7cd3cf..88e69ee0c 100644 --- a/handshaker_test.go +++ b/handshaker_test.go @@ -12,13 +12,13 @@ import ( "testing" "time" - "github.com/pion/dtls/v2/pkg/crypto/selfsign" - "github.com/pion/dtls/v2/pkg/crypto/signaturehash" - "github.com/pion/dtls/v2/pkg/protocol/alert" - "github.com/pion/dtls/v2/pkg/protocol/handshake" - "github.com/pion/dtls/v2/pkg/protocol/recordlayer" + "github.com/pion/dtls/v3/pkg/crypto/selfsign" + "github.com/pion/dtls/v3/pkg/crypto/signaturehash" + "github.com/pion/dtls/v3/pkg/protocol/alert" + "github.com/pion/dtls/v3/pkg/protocol/handshake" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" "github.com/pion/logging" - "github.com/pion/transport/v2/test" + "github.com/pion/transport/v3/test" ) const nonZeroRetransmitInterval = 100 * time.Millisecond @@ -44,7 +44,7 @@ func TestWriteKeyLog(t *testing.T) { cfg.writeKeyLog("LABEL", []byte{0xAA, 0xBB, 0xCC}, []byte{0xDD, 0xEE, 0xFF}) } -func TestHandshaker(t *testing.T) { +func TestHandshaker(t *testing.T) { //nolint:gocyclo,cyclop,maintidx // Check for leaking routines report := test.CheckRoutines(t) defer report() @@ -84,6 +84,7 @@ func TestHandshaker(t *testing.T) { cntClientHelloNoCookie++ } } + return true }, } @@ -96,17 +97,26 @@ func TestHandshaker(t *testing.T) { } if _, ok := h.Message.(*handshake.MessageHelloVerifyRequest); ok { cntHelloVerifyRequest++ + return cntHelloVerifyRequest > helloVerifyDrop } + return true }, } report := func(t *testing.T) { + t.Helper() + if cntHelloVerifyRequest != helloVerifyDrop+1 { - t.Errorf("Number of HelloVerifyRequest retransmit is wrong, expected: %d times, got: %d times", helloVerifyDrop+1, cntHelloVerifyRequest) + t.Errorf( + "Number of HelloVerifyRequest retransmit is wrong, expected: %d times, got: %d times", + helloVerifyDrop+1, + cntHelloVerifyRequest, + ) } if cntClientHelloNoCookie != cntHelloVerifyRequest { + ///nolint:lll t.Errorf( "HelloVerifyRequest must be triggered only by ClientHello, but HelloVerifyRequest was sent %d times and ClientHello was sent %d times", cntHelloVerifyRequest, cntClientHelloNoCookie, @@ -132,6 +142,7 @@ func TestHandshaker(t *testing.T) { if _, ok := h.Message.(*handshake.MessageFinished); ok { cntClientFinished++ } + return true }, } @@ -145,11 +156,14 @@ func TestHandshaker(t *testing.T) { if _, ok := h.Message.(*handshake.MessageFinished); ok { cntServerFinished++ } + return true }, } report := func(t *testing.T) { + t.Helper() + if cntClientFinished != 1 { t.Errorf("Number of client finished is wrong, expected: %d times, got: %d times", 1, cntClientFinished) } @@ -184,6 +198,7 @@ func TestHandshaker(t *testing.T) { cntClientFinished++ } } + return true }, Delay: 0, @@ -206,6 +221,7 @@ func TestHandshaker(t *testing.T) { cntServerFinished++ } } + return true }, Delay: 1000 * time.Millisecond, @@ -216,17 +232,24 @@ func TestHandshaker(t *testing.T) { } report := func(t *testing.T) { - // with one second server delay and 100 ms retransmit, there should be close to 10 `Finished` from client - // using a range of 9 - 11 for checking - if cntClientFinished < 8 || cntClientFinished > 11 { - t.Errorf("Number of client finished is wrong, expected: %d - %d times, got: %d times", 9, 11, cntClientFinished) + t.Helper() + + // with one second server delay and 100 ms retransmit (+ exponential backoff), + // there should be close to 4 `Finished` from client + // using a range of 3 - 5 for checking. + if cntClientFinished < 3 || cntClientFinished > 5 { + t.Errorf("Number of client finished is wrong, expected: %d - %d times, got: %d times", 3, 5, cntClientFinished) } if !isClientFinished { t.Errorf("Client is not finished") } // there should be no `Finished` last retransmit from client if cntClientFinishedLastRetransmit != 0 { - t.Errorf("Number of client finished last retransmit is wrong, expected: %d times, got: %d times", 0, cntClientFinishedLastRetransmit) + t.Errorf( + "Number of client finished last retransmit is wrong, expected: %d times, got: %d times", + 0, + cntClientFinishedLastRetransmit, + ) } if cntServerFinished < 1 { t.Errorf("Number of server finished is wrong, expected: at least %d times, got: %d times", 1, cntServerFinished) @@ -234,9 +257,14 @@ func TestHandshaker(t *testing.T) { if !isServerFinished { t.Errorf("Server is not finished") } - // there should be `Finished` last retransmit from server. Because of slow server, client would have sent several `Finished`. + // there should be `Finished` last retransmit from server. + // Because of slow server, client would have sent several `Finished`. if cntServerFinishedLastRetransmit < 1 { - t.Errorf("Number of server finished last retransmit is wrong, expected: at least %d times, got: %d times", 1, cntServerFinishedLastRetransmit) + t.Errorf( + "Number of server finished last retransmit is wrong, expected: at least %d times, got: %d times", + 1, + cntServerFinishedLastRetransmit, + ) } } @@ -271,7 +299,7 @@ func TestHandshaker(t *testing.T) { localSignatureSchemes: signaturehash.Algorithms(), insecureSkipVerify: true, log: logger, - onFlightState: func(f flightVal, s handshakeState) { + onFlightState: func(_ flightVal, s handshakeState) { if s == handshakeFinished { if clientEndpoint.OnFinished != nil { clientEndpoint.OnFinished() @@ -281,7 +309,7 @@ func TestHandshaker(t *testing.T) { }) } }, - retransmitInterval: nonZeroRetransmitInterval, + initialRetransmitInterval: nonZeroRetransmitInterval, } fsm := newHandshakeFSM(&ca.state, ca.handshakeCache, cfg, flight1) @@ -304,7 +332,7 @@ func TestHandshaker(t *testing.T) { localSignatureSchemes: signaturehash.Algorithms(), insecureSkipVerify: true, log: logger, - onFlightState: func(f flightVal, s handshakeState) { + onFlightState: func(_ flightVal, s handshakeState) { if s == handshakeFinished { if serverEndpoint.OnFinished != nil { serverEndpoint.OnFinished() @@ -314,7 +342,7 @@ func TestHandshaker(t *testing.T) { }) } }, - retransmitInterval: nonZeroRetransmitInterval, + initialRetransmitInterval: nonZeroRetransmitInterval, } fsm := newHandshakeFSM(&cb.state, cb.handshakeCache, cfg, flight0) @@ -346,11 +374,16 @@ type TestEndpoint struct { FinishWait time.Duration } -func flightTestPipe(ctx context.Context, clientEndpoint TestEndpoint, serverEndpoint TestEndpoint) (*flightTestConn, *flightTestConn) { +func flightTestPipe( + ctx context.Context, + clientEndpoint TestEndpoint, + serverEndpoint TestEndpoint, +) (*flightTestConn, *flightTestConn) { ca := newHandshakeCache() cb := newHandshakeCache() - chA := make(chan chan struct{}) - chB := make(chan chan struct{}) + chA := make(chan recvHandshakeState) + chB := make(chan recvHandshakeState) + return &flightTestConn{ handshakeCache: ca, otherEndCache: cb, @@ -373,7 +406,7 @@ func flightTestPipe(ctx context.Context, clientEndpoint TestEndpoint, serverEndp type flightTestConn struct { state State handshakeCache *handshakeCache - recv chan chan struct{} + recv chan recvHandshakeState done <-chan struct{} epoch uint16 @@ -382,10 +415,10 @@ type flightTestConn struct { delay time.Duration otherEndCache *handshakeCache - otherEndRecv chan chan struct{} + otherEndRecv chan recvHandshakeState } -func (c *flightTestConn) recvHandshake() <-chan chan struct{} { +func (c *flightTestConn) recvHandshake() <-chan recvHandshakeState { return c.recv } @@ -399,35 +432,46 @@ func (c *flightTestConn) notify(context.Context, alert.Level, alert.Description) func (c *flightTestConn) writePackets(_ context.Context, pkts []*packet) error { time.Sleep(c.delay) - for _, p := range pkts { - if c.filter != nil && !c.filter(p) { + for _, pkt := range pkts { + if c.filter != nil && !c.filter(pkt) { continue } - if h, ok := p.record.Content.(*handshake.Handshake); ok { - handshakeRaw, err := p.record.Marshal() + if handshake, ok := pkt.record.Content.(*handshake.Handshake); ok { + handshakeRaw, err := pkt.record.Marshal() if err != nil { return err } - c.handshakeCache.push(handshakeRaw[recordlayer.HeaderSize:], p.record.Header.Epoch, h.Header.MessageSequence, h.Header.Type, c.state.isClient) + c.handshakeCache.push( + handshakeRaw[recordlayer.FixedHeaderSize:], + pkt.record.Header.Epoch, + handshake.Header.MessageSequence, + handshake.Header.Type, + c.state.isClient, + ) - content, err := h.Message.Marshal() + content, err := handshake.Message.Marshal() if err != nil { return err } - h.Header.Length = uint32(len(content)) - h.Header.FragmentLength = uint32(len(content)) - hdr, err := h.Header.Marshal() + handshake.Header.Length = uint32(len(content)) //nolint:gosec // G115 + handshake.Header.FragmentLength = uint32(len(content)) //nolint:gosec // G115 + hdr, err := handshake.Header.Marshal() if err != nil { return err } c.otherEndCache.push( - append(hdr, content...), p.record.Header.Epoch, h.Header.MessageSequence, h.Header.Type, c.state.isClient) + append(hdr, content...), + pkt.record.Header.Epoch, + handshake.Header.MessageSequence, + handshake.Header.Type, + c.state.isClient, + ) } } go func() { select { - case c.otherEndRecv <- make(chan struct{}): + case c.otherEndRecv <- recvHandshakeState{done: make(chan struct{})}: case <-c.done: } }() diff --git a/internal/ciphersuite/aes_128_ccm.go b/internal/ciphersuite/aes_128_ccm.go index f78b6dc2c..9805f36e8 100644 --- a/internal/ciphersuite/aes_128_ccm.go +++ b/internal/ciphersuite/aes_128_ccm.go @@ -4,16 +4,23 @@ package ciphersuite import ( - "github.com/pion/dtls/v2/pkg/crypto/ciphersuite" - "github.com/pion/dtls/v2/pkg/crypto/clientcertificate" + "github.com/pion/dtls/v3/pkg/crypto/ciphersuite" + "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" ) -// Aes128Ccm is a base class used by multiple AES-CCM Ciphers +// Aes128Ccm is a base class used by multiple AES-CCM Ciphers. type Aes128Ccm struct { AesCcm } -func newAes128Ccm(clientCertificateType clientcertificate.Type, id ID, psk bool, cryptoCCMTagLen ciphersuite.CCMTagLen, keyExchangeAlgorithm KeyExchangeAlgorithm, ecc bool) *Aes128Ccm { +func newAes128Ccm( + clientCertificateType clientcertificate.Type, + id ID, + psk bool, + cryptoCCMTagLen ciphersuite.CCMTagLen, + keyExchangeAlgorithm KeyExchangeAlgorithm, + ecc bool, +) *Aes128Ccm { return &Aes128Ccm{ AesCcm: AesCcm{ clientCertificateType: clientCertificateType, @@ -26,8 +33,9 @@ func newAes128Ccm(clientCertificateType clientcertificate.Type, id ID, psk bool, } } -// Init initializes the internal Cipher with keying material +// Init initializes the internal Cipher with keying material. func (c *Aes128Ccm) Init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error { const prfKeyLen = 16 + return c.AesCcm.Init(masterSecret, clientRandom, serverRandom, isClient, prfKeyLen) } diff --git a/internal/ciphersuite/aes_256_ccm.go b/internal/ciphersuite/aes_256_ccm.go index bb8128627..58d5e0cee 100644 --- a/internal/ciphersuite/aes_256_ccm.go +++ b/internal/ciphersuite/aes_256_ccm.go @@ -4,16 +4,23 @@ package ciphersuite import ( - "github.com/pion/dtls/v2/pkg/crypto/ciphersuite" - "github.com/pion/dtls/v2/pkg/crypto/clientcertificate" + "github.com/pion/dtls/v3/pkg/crypto/ciphersuite" + "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" ) -// Aes256Ccm is a base class used by multiple AES-CCM Ciphers +// Aes256Ccm is a base class used by multiple AES-CCM Ciphers. type Aes256Ccm struct { AesCcm } -func newAes256Ccm(clientCertificateType clientcertificate.Type, id ID, psk bool, cryptoCCMTagLen ciphersuite.CCMTagLen, keyExchangeAlgorithm KeyExchangeAlgorithm, ecc bool) *Aes256Ccm { +func newAes256Ccm( + clientCertificateType clientcertificate.Type, + id ID, + psk bool, + cryptoCCMTagLen ciphersuite.CCMTagLen, + keyExchangeAlgorithm KeyExchangeAlgorithm, + ecc bool, +) *Aes256Ccm { return &Aes256Ccm{ AesCcm: AesCcm{ clientCertificateType: clientCertificateType, @@ -26,8 +33,9 @@ func newAes256Ccm(clientCertificateType clientcertificate.Type, id ID, psk bool, } } -// Init initializes the internal Cipher with keying material +// Init initializes the internal Cipher with keying material. func (c *Aes256Ccm) Init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error { const prfKeyLen = 32 + return c.AesCcm.Init(masterSecret, clientRandom, serverRandom, isClient, prfKeyLen) } diff --git a/internal/ciphersuite/aes_ccm.go b/internal/ciphersuite/aes_ccm.go index dc5119823..ddda1c8e7 100644 --- a/internal/ciphersuite/aes_ccm.go +++ b/internal/ciphersuite/aes_ccm.go @@ -9,13 +9,13 @@ import ( "hash" "sync/atomic" - "github.com/pion/dtls/v2/pkg/crypto/ciphersuite" - "github.com/pion/dtls/v2/pkg/crypto/clientcertificate" - "github.com/pion/dtls/v2/pkg/crypto/prf" - "github.com/pion/dtls/v2/pkg/protocol/recordlayer" + "github.com/pion/dtls/v3/pkg/crypto/ciphersuite" + "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" + "github.com/pion/dtls/v3/pkg/crypto/prf" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) -// AesCcm is a base class used by multiple AES-CCM Ciphers +// AesCcm is a base class used by multiple AES-CCM Ciphers. type AesCcm struct { ccm atomic.Value // *cryptoCCM clientCertificateType clientcertificate.Type @@ -26,12 +26,12 @@ type AesCcm struct { ecc bool } -// CertificateType returns what type of certificate this CipherSuite exchanges +// CertificateType returns what type of certificate this CipherSuite exchanges. func (c *AesCcm) CertificateType() clientcertificate.Type { return c.clientCertificateType } -// ID returns the ID of the CipherSuite +// ID returns the ID of the CipherSuite. func (c *AesCcm) ID() ID { return c.id } @@ -40,59 +40,66 @@ func (c *AesCcm) String() string { return c.id.String() } -// ECC uses Elliptic Curve Cryptography +// ECC uses Elliptic Curve Cryptography. func (c *AesCcm) ECC() bool { return c.ecc } -// KeyExchangeAlgorithm controls what key exchange algorithm is using during the handshake +// KeyExchangeAlgorithm controls what key exchange algorithm is using during the handshake. func (c *AesCcm) KeyExchangeAlgorithm() KeyExchangeAlgorithm { return c.keyExchangeAlgorithm } -// HashFunc returns the hashing func for this CipherSuite +// HashFunc returns the hashing func for this CipherSuite. func (c *AesCcm) HashFunc() func() hash.Hash { return sha256.New } -// AuthenticationType controls what authentication method is using during the handshake +// AuthenticationType controls what authentication method is using during the handshake. func (c *AesCcm) AuthenticationType() AuthenticationType { if c.psk { return AuthenticationTypePreSharedKey } + return AuthenticationTypeCertificate } // IsInitialized returns if the CipherSuite has keying material and can -// encrypt/decrypt packets +// encrypt/decrypt packets. func (c *AesCcm) IsInitialized() bool { return c.ccm.Load() != nil } -// Init initializes the internal Cipher with keying material +// Init initializes the internal Cipher with keying material. func (c *AesCcm) Init(masterSecret, clientRandom, serverRandom []byte, isClient bool, prfKeyLen int) error { const ( prfMacLen = 0 prfIvLen = 4 ) - keys, err := prf.GenerateEncryptionKeys(masterSecret, clientRandom, serverRandom, prfMacLen, prfKeyLen, prfIvLen, c.HashFunc()) + keys, err := prf.GenerateEncryptionKeys( + masterSecret, clientRandom, serverRandom, prfMacLen, prfKeyLen, prfIvLen, c.HashFunc(), + ) if err != nil { return err } var ccm *ciphersuite.CCM if isClient { - ccm, err = ciphersuite.NewCCM(c.cryptoCCMTagLen, keys.ClientWriteKey, keys.ClientWriteIV, keys.ServerWriteKey, keys.ServerWriteIV) + ccm, err = ciphersuite.NewCCM( + c.cryptoCCMTagLen, keys.ClientWriteKey, keys.ClientWriteIV, keys.ServerWriteKey, keys.ServerWriteIV, + ) } else { - ccm, err = ciphersuite.NewCCM(c.cryptoCCMTagLen, keys.ServerWriteKey, keys.ServerWriteIV, keys.ClientWriteKey, keys.ClientWriteIV) + ccm, err = ciphersuite.NewCCM( + c.cryptoCCMTagLen, keys.ServerWriteKey, keys.ServerWriteIV, keys.ClientWriteKey, keys.ClientWriteIV, + ) } c.ccm.Store(ccm) return err } -// Encrypt encrypts a single TLS RecordLayer +// Encrypt encrypts a single TLS RecordLayer. func (c *AesCcm) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) { cipherSuite, ok := c.ccm.Load().(*ciphersuite.CCM) if !ok { @@ -102,12 +109,12 @@ func (c *AesCcm) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, erro return cipherSuite.Encrypt(pkt, raw) } -// Decrypt decrypts a single TLS RecordLayer -func (c *AesCcm) Decrypt(raw []byte) ([]byte, error) { +// Decrypt decrypts a single TLS RecordLayer. +func (c *AesCcm) Decrypt(h recordlayer.Header, raw []byte) ([]byte, error) { cipherSuite, ok := c.ccm.Load().(*ciphersuite.CCM) if !ok { return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit) } - return cipherSuite.Decrypt(raw) + return cipherSuite.Decrypt(h, raw) } diff --git a/internal/ciphersuite/ciphersuite.go b/internal/ciphersuite/ciphersuite.go index f44f29fd3..27b7d57ce 100644 --- a/internal/ciphersuite/ciphersuite.go +++ b/internal/ciphersuite/ciphersuite.go @@ -1,23 +1,25 @@ // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> // SPDX-License-Identifier: MIT -// Package ciphersuite provides TLS Ciphers as registered with the IANA https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-4 +// Package ciphersuite provides TLS Ciphers as registered with the IANA +// https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-4 package ciphersuite import ( "errors" "fmt" - "github.com/pion/dtls/v2/internal/ciphersuite/types" - "github.com/pion/dtls/v2/pkg/protocol" + "github.com/pion/dtls/v3/internal/ciphersuite/types" + "github.com/pion/dtls/v3/pkg/protocol" ) -var errCipherSuiteNotInit = &protocol.TemporaryError{Err: errors.New("CipherSuite has not been initialized")} //nolint:goerr113 +//nolint:goerr113 +var errCipherSuiteNotInit = &protocol.TemporaryError{Err: errors.New("CipherSuite has not been initialized")} -// ID is an ID for our supported CipherSuites +// ID is an ID for our supported CipherSuites. type ID uint16 -func (i ID) String() string { +func (i ID) String() string { //nolint:cyclop switch i { case TLS_ECDHE_ECDSA_WITH_AES_128_CCM: return "TLS_ECDHE_ECDSA_WITH_AES_128_CCM" @@ -52,19 +54,19 @@ func (i ID) String() string { } } -// Supported Cipher Suites +// Supported Cipher Suites. const ( - // AES-128-CCM + // AES-128-CCM. TLS_ECDHE_ECDSA_WITH_AES_128_CCM ID = 0xc0ac //nolint:revive,stylecheck TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 ID = 0xc0ae //nolint:revive,stylecheck - // AES-128-GCM-SHA256 + // AES-128-GCM-SHA256. TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 ID = 0xc02b //nolint:revive,stylecheck TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 ID = 0xc02f //nolint:revive,stylecheck TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 ID = 0xc02c //nolint:revive,stylecheck TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 ID = 0xc030 //nolint:revive,stylecheck - // AES-256-CBC-SHA + // AES-256-CBC-SHA. TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA ID = 0xc00a //nolint:revive,stylecheck TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA ID = 0xc014 //nolint:revive,stylecheck @@ -77,10 +79,10 @@ const ( TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 ID = 0xC037 //nolint:revive,stylecheck ) -// AuthenticationType controls what authentication method is using during the handshake +// AuthenticationType controls what authentication method is using during the handshake. type AuthenticationType = types.AuthenticationType -// AuthenticationType Enums +// AuthenticationType Enums. const ( AuthenticationTypeCertificate AuthenticationType = types.AuthenticationTypeCertificate AuthenticationTypePreSharedKey AuthenticationType = types.AuthenticationTypePreSharedKey @@ -90,7 +92,7 @@ const ( // KeyExchangeAlgorithm controls what exchange algorithm was chosen. type KeyExchangeAlgorithm = types.KeyExchangeAlgorithm -// KeyExchangeAlgorithm Bitmask +// KeyExchangeAlgorithm Bitmask. const ( KeyExchangeAlgorithmNone KeyExchangeAlgorithm = types.KeyExchangeAlgorithmNone KeyExchangeAlgorithmPsk KeyExchangeAlgorithm = types.KeyExchangeAlgorithmPsk diff --git a/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_128_ccm.go b/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_128_ccm.go index 8367b2c6d..04a6ca40d 100644 --- a/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_128_ccm.go +++ b/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_128_ccm.go @@ -4,11 +4,18 @@ package ciphersuite import ( - "github.com/pion/dtls/v2/pkg/crypto/ciphersuite" - "github.com/pion/dtls/v2/pkg/crypto/clientcertificate" + "github.com/pion/dtls/v3/pkg/crypto/ciphersuite" + "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" ) -// NewTLSEcdheEcdsaWithAes128Ccm constructs a TLS_ECDHE_ECDSA_WITH_AES_128_CCM Cipher +// NewTLSEcdheEcdsaWithAes128Ccm constructs a TLS_ECDHE_ECDSA_WITH_AES_128_CCM Cipher. func NewTLSEcdheEcdsaWithAes128Ccm() *Aes128Ccm { - return newAes128Ccm(clientcertificate.ECDSASign, TLS_ECDHE_ECDSA_WITH_AES_128_CCM, false, ciphersuite.CCMTagLength, KeyExchangeAlgorithmEcdhe, true) + return newAes128Ccm( + clientcertificate.ECDSASign, + TLS_ECDHE_ECDSA_WITH_AES_128_CCM, + false, + ciphersuite.CCMTagLength, + KeyExchangeAlgorithmEcdhe, + true, + ) } diff --git a/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_128_ccm8.go b/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_128_ccm8.go index 11b687327..38a166fad 100644 --- a/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_128_ccm8.go +++ b/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_128_ccm8.go @@ -4,11 +4,18 @@ package ciphersuite import ( - "github.com/pion/dtls/v2/pkg/crypto/ciphersuite" - "github.com/pion/dtls/v2/pkg/crypto/clientcertificate" + "github.com/pion/dtls/v3/pkg/crypto/ciphersuite" + "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" ) -// NewTLSEcdheEcdsaWithAes128Ccm8 creates a new TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 CipherSuite +// NewTLSEcdheEcdsaWithAes128Ccm8 creates a new TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 CipherSuite. func NewTLSEcdheEcdsaWithAes128Ccm8() *Aes128Ccm { - return newAes128Ccm(clientcertificate.ECDSASign, TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8, false, ciphersuite.CCMTagLength8, KeyExchangeAlgorithmEcdhe, true) + return newAes128Ccm( + clientcertificate.ECDSASign, + TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8, + false, + ciphersuite.CCMTagLength8, + KeyExchangeAlgorithmEcdhe, + true, + ) } diff --git a/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_128_gcm_sha256.go b/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_128_gcm_sha256.go index 0c919fe47..f47a67497 100644 --- a/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_128_gcm_sha256.go +++ b/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_128_gcm_sha256.go @@ -9,33 +9,33 @@ import ( "hash" "sync/atomic" - "github.com/pion/dtls/v2/pkg/crypto/ciphersuite" - "github.com/pion/dtls/v2/pkg/crypto/clientcertificate" - "github.com/pion/dtls/v2/pkg/crypto/prf" - "github.com/pion/dtls/v2/pkg/protocol/recordlayer" + "github.com/pion/dtls/v3/pkg/crypto/ciphersuite" + "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" + "github.com/pion/dtls/v3/pkg/crypto/prf" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) -// TLSEcdheEcdsaWithAes128GcmSha256 represents a TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 CipherSuite +// TLSEcdheEcdsaWithAes128GcmSha256 represents a TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 CipherSuite. type TLSEcdheEcdsaWithAes128GcmSha256 struct { gcm atomic.Value // *cryptoGCM } -// CertificateType returns what type of certficate this CipherSuite exchanges +// CertificateType returns what type of certficate this CipherSuite exchanges. func (c *TLSEcdheEcdsaWithAes128GcmSha256) CertificateType() clientcertificate.Type { return clientcertificate.ECDSASign } -// KeyExchangeAlgorithm controls what key exchange algorithm is using during the handshake +// KeyExchangeAlgorithm controls what key exchange algorithm is using during the handshake. func (c *TLSEcdheEcdsaWithAes128GcmSha256) KeyExchangeAlgorithm() KeyExchangeAlgorithm { return KeyExchangeAlgorithmEcdhe } -// ECC uses Elliptic Curve Cryptography +// ECC uses Elliptic Curve Cryptography. func (c *TLSEcdheEcdsaWithAes128GcmSha256) ECC() bool { return true } -// ID returns the ID of the CipherSuite +// ID returns the ID of the CipherSuite. func (c *TLSEcdheEcdsaWithAes128GcmSha256) ID() ID { return TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 } @@ -44,24 +44,31 @@ func (c *TLSEcdheEcdsaWithAes128GcmSha256) String() string { return "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" } -// HashFunc returns the hashing func for this CipherSuite +// HashFunc returns the hashing func for this CipherSuite. func (c *TLSEcdheEcdsaWithAes128GcmSha256) HashFunc() func() hash.Hash { return sha256.New } -// AuthenticationType controls what authentication method is using during the handshake +// AuthenticationType controls what authentication method is using during the handshake. func (c *TLSEcdheEcdsaWithAes128GcmSha256) AuthenticationType() AuthenticationType { return AuthenticationTypeCertificate } // IsInitialized returns if the CipherSuite has keying material and can -// encrypt/decrypt packets +// encrypt/decrypt packets. func (c *TLSEcdheEcdsaWithAes128GcmSha256) IsInitialized() bool { return c.gcm.Load() != nil } -func (c *TLSEcdheEcdsaWithAes128GcmSha256) init(masterSecret, clientRandom, serverRandom []byte, isClient bool, prfMacLen, prfKeyLen, prfIvLen int, hashFunc func() hash.Hash) error { - keys, err := prf.GenerateEncryptionKeys(masterSecret, clientRandom, serverRandom, prfMacLen, prfKeyLen, prfIvLen, hashFunc) +func (c *TLSEcdheEcdsaWithAes128GcmSha256) init( + masterSecret, clientRandom, serverRandom []byte, + isClient bool, + prfMacLen, prfKeyLen, prfIvLen int, + hashFunc func() hash.Hash, +) error { + keys, err := prf.GenerateEncryptionKeys( + masterSecret, clientRandom, serverRandom, prfMacLen, prfKeyLen, prfIvLen, hashFunc, + ) if err != nil { return err } @@ -73,10 +80,11 @@ func (c *TLSEcdheEcdsaWithAes128GcmSha256) init(masterSecret, clientRandom, serv gcm, err = ciphersuite.NewGCM(keys.ServerWriteKey, keys.ServerWriteIV, keys.ClientWriteKey, keys.ClientWriteIV) } c.gcm.Store(gcm) + return err } -// Init initializes the internal Cipher with keying material +// Init initializes the internal Cipher with keying material. func (c *TLSEcdheEcdsaWithAes128GcmSha256) Init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error { const ( prfMacLen = 0 @@ -87,7 +95,7 @@ func (c *TLSEcdheEcdsaWithAes128GcmSha256) Init(masterSecret, clientRandom, serv return c.init(masterSecret, clientRandom, serverRandom, isClient, prfMacLen, prfKeyLen, prfIvLen, c.HashFunc()) } -// Encrypt encrypts a single TLS RecordLayer +// Encrypt encrypts a single TLS RecordLayer. func (c *TLSEcdheEcdsaWithAes128GcmSha256) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) { cipherSuite, ok := c.gcm.Load().(*ciphersuite.GCM) if !ok { @@ -97,12 +105,12 @@ func (c *TLSEcdheEcdsaWithAes128GcmSha256) Encrypt(pkt *recordlayer.RecordLayer, return cipherSuite.Encrypt(pkt, raw) } -// Decrypt decrypts a single TLS RecordLayer -func (c *TLSEcdheEcdsaWithAes128GcmSha256) Decrypt(raw []byte) ([]byte, error) { +// Decrypt decrypts a single TLS RecordLayer. +func (c *TLSEcdheEcdsaWithAes128GcmSha256) Decrypt(h recordlayer.Header, raw []byte) ([]byte, error) { cipherSuite, ok := c.gcm.Load().(*ciphersuite.GCM) if !ok { return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit) } - return cipherSuite.Decrypt(raw) + return cipherSuite.Decrypt(h, raw) } diff --git a/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_256_cbc_sha.go b/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_256_cbc_sha.go index 577192c89..6eeb91811 100644 --- a/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_256_cbc_sha.go +++ b/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_256_cbc_sha.go @@ -10,33 +10,33 @@ import ( "hash" "sync/atomic" - "github.com/pion/dtls/v2/pkg/crypto/ciphersuite" - "github.com/pion/dtls/v2/pkg/crypto/clientcertificate" - "github.com/pion/dtls/v2/pkg/crypto/prf" - "github.com/pion/dtls/v2/pkg/protocol/recordlayer" + "github.com/pion/dtls/v3/pkg/crypto/ciphersuite" + "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" + "github.com/pion/dtls/v3/pkg/crypto/prf" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) -// TLSEcdheEcdsaWithAes256CbcSha represents a TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA CipherSuite +// TLSEcdheEcdsaWithAes256CbcSha represents a TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA CipherSuite. type TLSEcdheEcdsaWithAes256CbcSha struct { cbc atomic.Value // *cryptoCBC } -// CertificateType returns what type of certficate this CipherSuite exchanges +// CertificateType returns what type of certficate this CipherSuite exchanges. func (c *TLSEcdheEcdsaWithAes256CbcSha) CertificateType() clientcertificate.Type { return clientcertificate.ECDSASign } -// KeyExchangeAlgorithm controls what key exchange algorithm is using during the handshake +// KeyExchangeAlgorithm controls what key exchange algorithm is using during the handshake. func (c *TLSEcdheEcdsaWithAes256CbcSha) KeyExchangeAlgorithm() KeyExchangeAlgorithm { return KeyExchangeAlgorithmEcdhe } -// ECC uses Elliptic Curve Cryptography +// ECC uses Elliptic Curve Cryptography. func (c *TLSEcdheEcdsaWithAes256CbcSha) ECC() bool { return true } -// ID returns the ID of the CipherSuite +// ID returns the ID of the CipherSuite. func (c *TLSEcdheEcdsaWithAes256CbcSha) ID() ID { return TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA } @@ -45,23 +45,23 @@ func (c *TLSEcdheEcdsaWithAes256CbcSha) String() string { return "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA" } -// HashFunc returns the hashing func for this CipherSuite +// HashFunc returns the hashing func for this CipherSuite. func (c *TLSEcdheEcdsaWithAes256CbcSha) HashFunc() func() hash.Hash { return sha256.New } -// AuthenticationType controls what authentication method is using during the handshake +// AuthenticationType controls what authentication method is using during the handshake. func (c *TLSEcdheEcdsaWithAes256CbcSha) AuthenticationType() AuthenticationType { return AuthenticationTypeCertificate } // IsInitialized returns if the CipherSuite has keying material and can -// encrypt/decrypt packets +// encrypt/decrypt packets. func (c *TLSEcdheEcdsaWithAes256CbcSha) IsInitialized() bool { return c.cbc.Load() != nil } -// Init initializes the internal Cipher with keying material +// Init initializes the internal Cipher with keying material. func (c *TLSEcdheEcdsaWithAes256CbcSha) Init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error { const ( prfMacLen = 20 @@ -69,7 +69,9 @@ func (c *TLSEcdheEcdsaWithAes256CbcSha) Init(masterSecret, clientRandom, serverR prfIvLen = 16 ) - keys, err := prf.GenerateEncryptionKeys(masterSecret, clientRandom, serverRandom, prfMacLen, prfKeyLen, prfIvLen, c.HashFunc()) + keys, err := prf.GenerateEncryptionKeys( + masterSecret, clientRandom, serverRandom, prfMacLen, prfKeyLen, prfIvLen, c.HashFunc(), + ) if err != nil { return err } @@ -93,7 +95,7 @@ func (c *TLSEcdheEcdsaWithAes256CbcSha) Init(masterSecret, clientRandom, serverR return err } -// Encrypt encrypts a single TLS RecordLayer +// Encrypt encrypts a single TLS RecordLayer. func (c *TLSEcdheEcdsaWithAes256CbcSha) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) { cipherSuite, ok := c.cbc.Load().(*ciphersuite.CBC) if !ok { @@ -103,12 +105,12 @@ func (c *TLSEcdheEcdsaWithAes256CbcSha) Encrypt(pkt *recordlayer.RecordLayer, ra return cipherSuite.Encrypt(pkt, raw) } -// Decrypt decrypts a single TLS RecordLayer -func (c *TLSEcdheEcdsaWithAes256CbcSha) Decrypt(raw []byte) ([]byte, error) { +// Decrypt decrypts a single TLS RecordLayer. +func (c *TLSEcdheEcdsaWithAes256CbcSha) Decrypt(h recordlayer.Header, raw []byte) ([]byte, error) { cipherSuite, ok := c.cbc.Load().(*ciphersuite.CBC) if !ok { return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit) } - return cipherSuite.Decrypt(raw) + return cipherSuite.Decrypt(h, raw) } diff --git a/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_256_gcm_sha384.go b/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_256_gcm_sha384.go index 2a3cfa4f5..bf6f6c444 100644 --- a/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_256_gcm_sha384.go +++ b/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_256_gcm_sha384.go @@ -8,12 +8,12 @@ import ( "hash" ) -// TLSEcdheEcdsaWithAes256GcmSha384 represents a TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 CipherSuite +// TLSEcdheEcdsaWithAes256GcmSha384 represents a TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 CipherSuite. type TLSEcdheEcdsaWithAes256GcmSha384 struct { TLSEcdheEcdsaWithAes128GcmSha256 } -// ID returns the ID of the CipherSuite +// ID returns the ID of the CipherSuite. func (c *TLSEcdheEcdsaWithAes256GcmSha384) ID() ID { return TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 } @@ -22,12 +22,12 @@ func (c *TLSEcdheEcdsaWithAes256GcmSha384) String() string { return "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" } -// HashFunc returns the hashing func for this CipherSuite +// HashFunc returns the hashing func for this CipherSuite. func (c *TLSEcdheEcdsaWithAes256GcmSha384) HashFunc() func() hash.Hash { return sha512.New384 } -// Init initializes the internal Cipher with keying material +// Init initializes the internal Cipher with keying material. func (c *TLSEcdheEcdsaWithAes256GcmSha384) Init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error { const ( prfMacLen = 0 diff --git a/internal/ciphersuite/tls_ecdhe_psk_with_aes_128_cbc_sha256.go b/internal/ciphersuite/tls_ecdhe_psk_with_aes_128_cbc_sha256.go index 75a25633a..24f51e1eb 100644 --- a/internal/ciphersuite/tls_ecdhe_psk_with_aes_128_cbc_sha256.go +++ b/internal/ciphersuite/tls_ecdhe_psk_with_aes_128_cbc_sha256.go @@ -9,13 +9,13 @@ import ( "hash" "sync/atomic" - "github.com/pion/dtls/v2/pkg/crypto/ciphersuite" - "github.com/pion/dtls/v2/pkg/crypto/clientcertificate" - "github.com/pion/dtls/v2/pkg/crypto/prf" - "github.com/pion/dtls/v2/pkg/protocol/recordlayer" + "github.com/pion/dtls/v3/pkg/crypto/ciphersuite" + "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" + "github.com/pion/dtls/v3/pkg/crypto/prf" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) -// TLSEcdhePskWithAes128CbcSha256 implements the TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 CipherSuite +// TLSEcdhePskWithAes128CbcSha256 implements the TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 CipherSuite. type TLSEcdhePskWithAes128CbcSha256 struct { cbc atomic.Value // *cryptoCBC } @@ -25,22 +25,22 @@ func NewTLSEcdhePskWithAes128CbcSha256() *TLSEcdhePskWithAes128CbcSha256 { return &TLSEcdhePskWithAes128CbcSha256{} } -// CertificateType returns what type of certificate this CipherSuite exchanges +// CertificateType returns what type of certificate this CipherSuite exchanges. func (c *TLSEcdhePskWithAes128CbcSha256) CertificateType() clientcertificate.Type { return clientcertificate.Type(0) } -// KeyExchangeAlgorithm controls what key exchange algorithm is using during the handshake +// KeyExchangeAlgorithm controls what key exchange algorithm is using during the handshake. func (c *TLSEcdhePskWithAes128CbcSha256) KeyExchangeAlgorithm() KeyExchangeAlgorithm { return (KeyExchangeAlgorithmPsk | KeyExchangeAlgorithmEcdhe) } -// ECC uses Elliptic Curve Cryptography +// ECC uses Elliptic Curve Cryptography. func (c *TLSEcdhePskWithAes128CbcSha256) ECC() bool { return true } -// ID returns the ID of the CipherSuite +// ID returns the ID of the CipherSuite. func (c *TLSEcdhePskWithAes128CbcSha256) ID() ID { return TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 } @@ -49,23 +49,23 @@ func (c *TLSEcdhePskWithAes128CbcSha256) String() string { return "TLS-ECDHE-PSK-WITH-AES-128-CBC-SHA256" } -// HashFunc returns the hashing func for this CipherSuite +// HashFunc returns the hashing func for this CipherSuite. func (c *TLSEcdhePskWithAes128CbcSha256) HashFunc() func() hash.Hash { return sha256.New } -// AuthenticationType controls what authentication method is using during the handshake +// AuthenticationType controls what authentication method is using during the handshake. func (c *TLSEcdhePskWithAes128CbcSha256) AuthenticationType() AuthenticationType { return AuthenticationTypePreSharedKey } // IsInitialized returns if the CipherSuite has keying material and can -// encrypt/decrypt packets +// encrypt/decrypt packets. func (c *TLSEcdhePskWithAes128CbcSha256) IsInitialized() bool { return c.cbc.Load() != nil } -// Init initializes the internal Cipher with keying material +// Init initializes the internal Cipher with keying material. func (c *TLSEcdhePskWithAes128CbcSha256) Init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error { const ( prfMacLen = 32 @@ -73,7 +73,9 @@ func (c *TLSEcdhePskWithAes128CbcSha256) Init(masterSecret, clientRandom, server prfIvLen = 16 ) - keys, err := prf.GenerateEncryptionKeys(masterSecret, clientRandom, serverRandom, prfMacLen, prfKeyLen, prfIvLen, c.HashFunc()) + keys, err := prf.GenerateEncryptionKeys( + masterSecret, clientRandom, serverRandom, prfMacLen, prfKeyLen, prfIvLen, c.HashFunc(), + ) if err != nil { return err } @@ -97,7 +99,7 @@ func (c *TLSEcdhePskWithAes128CbcSha256) Init(masterSecret, clientRandom, server return err } -// Encrypt encrypts a single TLS RecordLayer +// Encrypt encrypts a single TLS RecordLayer. func (c *TLSEcdhePskWithAes128CbcSha256) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) { cipherSuite, ok := c.cbc.Load().(*ciphersuite.CBC) if !ok { // !c.isInitialized() @@ -107,12 +109,12 @@ func (c *TLSEcdhePskWithAes128CbcSha256) Encrypt(pkt *recordlayer.RecordLayer, r return cipherSuite.Encrypt(pkt, raw) } -// Decrypt decrypts a single TLS RecordLayer -func (c *TLSEcdhePskWithAes128CbcSha256) Decrypt(raw []byte) ([]byte, error) { +// Decrypt decrypts a single TLS RecordLayer. +func (c *TLSEcdhePskWithAes128CbcSha256) Decrypt(h recordlayer.Header, raw []byte) ([]byte, error) { cipherSuite, ok := c.cbc.Load().(*ciphersuite.CBC) if !ok { // !c.isInitialized() return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit) } - return cipherSuite.Decrypt(raw) + return cipherSuite.Decrypt(h, raw) } diff --git a/internal/ciphersuite/tls_ecdhe_rsa_with_aes_128_gcm_sha256.go b/internal/ciphersuite/tls_ecdhe_rsa_with_aes_128_gcm_sha256.go index 478a2e0dc..b78969111 100644 --- a/internal/ciphersuite/tls_ecdhe_rsa_with_aes_128_gcm_sha256.go +++ b/internal/ciphersuite/tls_ecdhe_rsa_with_aes_128_gcm_sha256.go @@ -3,19 +3,19 @@ package ciphersuite -import "github.com/pion/dtls/v2/pkg/crypto/clientcertificate" +import "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" -// TLSEcdheRsaWithAes128GcmSha256 implements the TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 CipherSuite +// TLSEcdheRsaWithAes128GcmSha256 implements the TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 CipherSuite. type TLSEcdheRsaWithAes128GcmSha256 struct { TLSEcdheEcdsaWithAes128GcmSha256 } -// CertificateType returns what type of certificate this CipherSuite exchanges +// CertificateType returns what type of certificate this CipherSuite exchanges. func (c *TLSEcdheRsaWithAes128GcmSha256) CertificateType() clientcertificate.Type { return clientcertificate.RSASign } -// ID returns the ID of the CipherSuite +// ID returns the ID of the CipherSuite. func (c *TLSEcdheRsaWithAes128GcmSha256) ID() ID { return TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 } diff --git a/internal/ciphersuite/tls_ecdhe_rsa_with_aes_256_cbc_sha.go b/internal/ciphersuite/tls_ecdhe_rsa_with_aes_256_cbc_sha.go index 8e88ee639..deb20dd94 100644 --- a/internal/ciphersuite/tls_ecdhe_rsa_with_aes_256_cbc_sha.go +++ b/internal/ciphersuite/tls_ecdhe_rsa_with_aes_256_cbc_sha.go @@ -3,19 +3,19 @@ package ciphersuite -import "github.com/pion/dtls/v2/pkg/crypto/clientcertificate" +import "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" -// TLSEcdheRsaWithAes256CbcSha implements the TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA CipherSuite +// TLSEcdheRsaWithAes256CbcSha implements the TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA CipherSuite. type TLSEcdheRsaWithAes256CbcSha struct { TLSEcdheEcdsaWithAes256CbcSha } -// CertificateType returns what type of certificate this CipherSuite exchanges +// CertificateType returns what type of certificate this CipherSuite exchanges. func (c *TLSEcdheRsaWithAes256CbcSha) CertificateType() clientcertificate.Type { return clientcertificate.RSASign } -// ID returns the ID of the CipherSuite +// ID returns the ID of the CipherSuite. func (c *TLSEcdheRsaWithAes256CbcSha) ID() ID { return TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA } diff --git a/internal/ciphersuite/tls_ecdhe_rsa_with_aes_256_gcm_sha384.go b/internal/ciphersuite/tls_ecdhe_rsa_with_aes_256_gcm_sha384.go index 752fb529c..f7d7049a8 100644 --- a/internal/ciphersuite/tls_ecdhe_rsa_with_aes_256_gcm_sha384.go +++ b/internal/ciphersuite/tls_ecdhe_rsa_with_aes_256_gcm_sha384.go @@ -3,19 +3,19 @@ package ciphersuite -import "github.com/pion/dtls/v2/pkg/crypto/clientcertificate" +import "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" -// TLSEcdheRsaWithAes256GcmSha384 implements the TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 CipherSuite +// TLSEcdheRsaWithAes256GcmSha384 implements the TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 CipherSuite. type TLSEcdheRsaWithAes256GcmSha384 struct { TLSEcdheEcdsaWithAes256GcmSha384 } -// CertificateType returns what type of certificate this CipherSuite exchanges +// CertificateType returns what type of certificate this CipherSuite exchanges. func (c *TLSEcdheRsaWithAes256GcmSha384) CertificateType() clientcertificate.Type { return clientcertificate.RSASign } -// ID returns the ID of the CipherSuite +// ID returns the ID of the CipherSuite. func (c *TLSEcdheRsaWithAes256GcmSha384) ID() ID { return TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 } diff --git a/internal/ciphersuite/tls_psk_with_aes_128_cbc_sha256.go b/internal/ciphersuite/tls_psk_with_aes_128_cbc_sha256.go index 7336ad946..32507cdc3 100644 --- a/internal/ciphersuite/tls_psk_with_aes_128_cbc_sha256.go +++ b/internal/ciphersuite/tls_psk_with_aes_128_cbc_sha256.go @@ -9,33 +9,33 @@ import ( "hash" "sync/atomic" - "github.com/pion/dtls/v2/pkg/crypto/ciphersuite" - "github.com/pion/dtls/v2/pkg/crypto/clientcertificate" - "github.com/pion/dtls/v2/pkg/crypto/prf" - "github.com/pion/dtls/v2/pkg/protocol/recordlayer" + "github.com/pion/dtls/v3/pkg/crypto/ciphersuite" + "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" + "github.com/pion/dtls/v3/pkg/crypto/prf" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) -// TLSPskWithAes128CbcSha256 implements the TLS_PSK_WITH_AES_128_CBC_SHA256 CipherSuite +// TLSPskWithAes128CbcSha256 implements the TLS_PSK_WITH_AES_128_CBC_SHA256 CipherSuite. type TLSPskWithAes128CbcSha256 struct { cbc atomic.Value // *cryptoCBC } -// CertificateType returns what type of certificate this CipherSuite exchanges +// CertificateType returns what type of certificate this CipherSuite exchanges. func (c *TLSPskWithAes128CbcSha256) CertificateType() clientcertificate.Type { return clientcertificate.Type(0) } -// KeyExchangeAlgorithm controls what key exchange algorithm is using during the handshake +// KeyExchangeAlgorithm controls what key exchange algorithm is using during the handshake. func (c *TLSPskWithAes128CbcSha256) KeyExchangeAlgorithm() KeyExchangeAlgorithm { return KeyExchangeAlgorithmPsk } -// ECC uses Elliptic Curve Cryptography +// ECC uses Elliptic Curve Cryptography. func (c *TLSPskWithAes128CbcSha256) ECC() bool { return false } -// ID returns the ID of the CipherSuite +// ID returns the ID of the CipherSuite. func (c *TLSPskWithAes128CbcSha256) ID() ID { return TLS_PSK_WITH_AES_128_CBC_SHA256 } @@ -44,23 +44,23 @@ func (c *TLSPskWithAes128CbcSha256) String() string { return "TLS_PSK_WITH_AES_128_CBC_SHA256" } -// HashFunc returns the hashing func for this CipherSuite +// HashFunc returns the hashing func for this CipherSuite. func (c *TLSPskWithAes128CbcSha256) HashFunc() func() hash.Hash { return sha256.New } -// AuthenticationType controls what authentication method is using during the handshake +// AuthenticationType controls what authentication method is using during the handshake. func (c *TLSPskWithAes128CbcSha256) AuthenticationType() AuthenticationType { return AuthenticationTypePreSharedKey } // IsInitialized returns if the CipherSuite has keying material and can -// encrypt/decrypt packets +// encrypt/decrypt packets. func (c *TLSPskWithAes128CbcSha256) IsInitialized() bool { return c.cbc.Load() != nil } -// Init initializes the internal Cipher with keying material +// Init initializes the internal Cipher with keying material. func (c *TLSPskWithAes128CbcSha256) Init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error { const ( prfMacLen = 32 @@ -68,7 +68,9 @@ func (c *TLSPskWithAes128CbcSha256) Init(masterSecret, clientRandom, serverRando prfIvLen = 16 ) - keys, err := prf.GenerateEncryptionKeys(masterSecret, clientRandom, serverRandom, prfMacLen, prfKeyLen, prfIvLen, c.HashFunc()) + keys, err := prf.GenerateEncryptionKeys( + masterSecret, clientRandom, serverRandom, prfMacLen, prfKeyLen, prfIvLen, c.HashFunc(), + ) if err != nil { return err } @@ -92,7 +94,7 @@ func (c *TLSPskWithAes128CbcSha256) Init(masterSecret, clientRandom, serverRando return err } -// Encrypt encrypts a single TLS RecordLayer +// Encrypt encrypts a single TLS RecordLayer. func (c *TLSPskWithAes128CbcSha256) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) { cipherSuite, ok := c.cbc.Load().(*ciphersuite.CBC) if !ok { @@ -102,12 +104,12 @@ func (c *TLSPskWithAes128CbcSha256) Encrypt(pkt *recordlayer.RecordLayer, raw [] return cipherSuite.Encrypt(pkt, raw) } -// Decrypt decrypts a single TLS RecordLayer -func (c *TLSPskWithAes128CbcSha256) Decrypt(raw []byte) ([]byte, error) { +// Decrypt decrypts a single TLS RecordLayer. +func (c *TLSPskWithAes128CbcSha256) Decrypt(h recordlayer.Header, raw []byte) ([]byte, error) { cipherSuite, ok := c.cbc.Load().(*ciphersuite.CBC) if !ok { return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit) } - return cipherSuite.Decrypt(raw) + return cipherSuite.Decrypt(h, raw) } diff --git a/internal/ciphersuite/tls_psk_with_aes_128_ccm.go b/internal/ciphersuite/tls_psk_with_aes_128_ccm.go index 1ded09b88..0b802fc8c 100644 --- a/internal/ciphersuite/tls_psk_with_aes_128_ccm.go +++ b/internal/ciphersuite/tls_psk_with_aes_128_ccm.go @@ -4,11 +4,18 @@ package ciphersuite import ( - "github.com/pion/dtls/v2/pkg/crypto/ciphersuite" - "github.com/pion/dtls/v2/pkg/crypto/clientcertificate" + "github.com/pion/dtls/v3/pkg/crypto/ciphersuite" + "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" ) -// NewTLSPskWithAes128Ccm returns the TLS_PSK_WITH_AES_128_CCM CipherSuite +// NewTLSPskWithAes128Ccm returns the TLS_PSK_WITH_AES_128_CCM CipherSuite. func NewTLSPskWithAes128Ccm() *Aes128Ccm { - return newAes128Ccm(clientcertificate.Type(0), TLS_PSK_WITH_AES_128_CCM, true, ciphersuite.CCMTagLength, KeyExchangeAlgorithmPsk, false) + return newAes128Ccm( + clientcertificate.Type(0), + TLS_PSK_WITH_AES_128_CCM, + true, + ciphersuite.CCMTagLength, + KeyExchangeAlgorithmPsk, + false, + ) } diff --git a/internal/ciphersuite/tls_psk_with_aes_128_ccm8.go b/internal/ciphersuite/tls_psk_with_aes_128_ccm8.go index 478197074..c6bf6dc59 100644 --- a/internal/ciphersuite/tls_psk_with_aes_128_ccm8.go +++ b/internal/ciphersuite/tls_psk_with_aes_128_ccm8.go @@ -4,11 +4,18 @@ package ciphersuite import ( - "github.com/pion/dtls/v2/pkg/crypto/ciphersuite" - "github.com/pion/dtls/v2/pkg/crypto/clientcertificate" + "github.com/pion/dtls/v3/pkg/crypto/ciphersuite" + "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" ) -// NewTLSPskWithAes128Ccm8 returns the TLS_PSK_WITH_AES_128_CCM_8 CipherSuite +// NewTLSPskWithAes128Ccm8 returns the TLS_PSK_WITH_AES_128_CCM_8 CipherSuite. func NewTLSPskWithAes128Ccm8() *Aes128Ccm { - return newAes128Ccm(clientcertificate.Type(0), TLS_PSK_WITH_AES_128_CCM_8, true, ciphersuite.CCMTagLength8, KeyExchangeAlgorithmPsk, false) + return newAes128Ccm( + clientcertificate.Type(0), + TLS_PSK_WITH_AES_128_CCM_8, + true, + ciphersuite.CCMTagLength8, + KeyExchangeAlgorithmPsk, + false, + ) } diff --git a/internal/ciphersuite/tls_psk_with_aes_128_gcm_sha256.go b/internal/ciphersuite/tls_psk_with_aes_128_gcm_sha256.go index 8ab5b89a8..bc50d562b 100644 --- a/internal/ciphersuite/tls_psk_with_aes_128_gcm_sha256.go +++ b/internal/ciphersuite/tls_psk_with_aes_128_gcm_sha256.go @@ -3,24 +3,24 @@ package ciphersuite -import "github.com/pion/dtls/v2/pkg/crypto/clientcertificate" +import "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" -// TLSPskWithAes128GcmSha256 implements the TLS_PSK_WITH_AES_128_GCM_SHA256 CipherSuite +// TLSPskWithAes128GcmSha256 implements the TLS_PSK_WITH_AES_128_GCM_SHA256 CipherSuite. type TLSPskWithAes128GcmSha256 struct { TLSEcdheEcdsaWithAes128GcmSha256 } -// CertificateType returns what type of certificate this CipherSuite exchanges +// CertificateType returns what type of certificate this CipherSuite exchanges. func (c *TLSPskWithAes128GcmSha256) CertificateType() clientcertificate.Type { return clientcertificate.Type(0) } -// KeyExchangeAlgorithm controls what key exchange algorithm is using during the handshake +// KeyExchangeAlgorithm controls what key exchange algorithm is using during the handshake. func (c *TLSPskWithAes128GcmSha256) KeyExchangeAlgorithm() KeyExchangeAlgorithm { return KeyExchangeAlgorithmPsk } -// ID returns the ID of the CipherSuite +// ID returns the ID of the CipherSuite. func (c *TLSPskWithAes128GcmSha256) ID() ID { return TLS_PSK_WITH_AES_128_GCM_SHA256 } @@ -29,7 +29,7 @@ func (c *TLSPskWithAes128GcmSha256) String() string { return "TLS_PSK_WITH_AES_128_GCM_SHA256" } -// AuthenticationType controls what authentication method is using during the handshake +// AuthenticationType controls what authentication method is using during the handshake. func (c *TLSPskWithAes128GcmSha256) AuthenticationType() AuthenticationType { return AuthenticationTypePreSharedKey } diff --git a/internal/ciphersuite/tls_psk_with_aes_256_ccm8.go b/internal/ciphersuite/tls_psk_with_aes_256_ccm8.go index 32d503018..771a1d42e 100644 --- a/internal/ciphersuite/tls_psk_with_aes_256_ccm8.go +++ b/internal/ciphersuite/tls_psk_with_aes_256_ccm8.go @@ -4,11 +4,18 @@ package ciphersuite import ( - "github.com/pion/dtls/v2/pkg/crypto/ciphersuite" - "github.com/pion/dtls/v2/pkg/crypto/clientcertificate" + "github.com/pion/dtls/v3/pkg/crypto/ciphersuite" + "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" ) -// NewTLSPskWithAes256Ccm8 returns the TLS_PSK_WITH_AES_256_CCM_8 CipherSuite +// NewTLSPskWithAes256Ccm8 returns the TLS_PSK_WITH_AES_256_CCM_8 CipherSuite. func NewTLSPskWithAes256Ccm8() *Aes256Ccm { - return newAes256Ccm(clientcertificate.Type(0), TLS_PSK_WITH_AES_256_CCM_8, true, ciphersuite.CCMTagLength8, KeyExchangeAlgorithmPsk, false) + return newAes256Ccm( + clientcertificate.Type(0), + TLS_PSK_WITH_AES_256_CCM_8, + true, + ciphersuite.CCMTagLength8, + KeyExchangeAlgorithmPsk, + false, + ) } diff --git a/internal/ciphersuite/types/authentication_type.go b/internal/ciphersuite/types/authentication_type.go index 2da21e642..09681cec5 100644 --- a/internal/ciphersuite/types/authentication_type.go +++ b/internal/ciphersuite/types/authentication_type.go @@ -3,10 +3,10 @@ package types -// AuthenticationType controls what authentication method is using during the handshake +// AuthenticationType controls what authentication method is using during the handshake. type AuthenticationType int -// AuthenticationType Enums +// AuthenticationType Enums. const ( AuthenticationTypeCertificate AuthenticationType = iota + 1 AuthenticationTypePreSharedKey diff --git a/internal/ciphersuite/types/key_exchange_algorithm.go b/internal/ciphersuite/types/key_exchange_algorithm.go index c2c39113a..5b59f2410 100644 --- a/internal/ciphersuite/types/key_exchange_algorithm.go +++ b/internal/ciphersuite/types/key_exchange_algorithm.go @@ -7,7 +7,7 @@ package types // KeyExchangeAlgorithm controls what exchange algorithm was chosen. type KeyExchangeAlgorithm int -// KeyExchangeAlgorithm Bitmask +// KeyExchangeAlgorithm Bitmask. const ( KeyExchangeAlgorithmNone KeyExchangeAlgorithm = 0 KeyExchangeAlgorithmPsk KeyExchangeAlgorithm = iota << 1 diff --git a/internal/closer/closer.go b/internal/closer/closer.go index bfa171cda..a1c25f379 100644 --- a/internal/closer/closer.go +++ b/internal/closer/closer.go @@ -8,41 +8,43 @@ import ( "context" ) -// Closer allows for each signaling a channel for shutdown +// Closer allows for each signaling a channel for shutdown. type Closer struct { - ctx context.Context + ctx context.Context //nolint:containedctx closeFunc func() } -// NewCloser creates a new instance of Closer +// NewCloser creates a new instance of Closer. func NewCloser() *Closer { ctx, closeFunc := context.WithCancel(context.Background()) + return &Closer{ ctx: ctx, closeFunc: closeFunc, } } -// NewCloserWithParent creates a new instance of Closer with a parent context +// NewCloserWithParent creates a new instance of Closer with a parent context. func NewCloserWithParent(ctx context.Context) *Closer { ctx, closeFunc := context.WithCancel(ctx) + return &Closer{ ctx: ctx, closeFunc: closeFunc, } } -// Done returns a channel signaling when it is done +// Done returns a channel signaling when it is done. func (c *Closer) Done() <-chan struct{} { return c.ctx.Done() } -// Err returns an error of the context +// Err returns an error of the context. func (c *Closer) Err() error { return c.ctx.Err() } -// Close sends a signal to trigger the ctx done channel +// Close sends a signal to trigger the ctx done channel. func (c *Closer) Close() { c.closeFunc() } diff --git a/internal/net/buffer.go b/internal/net/buffer.go new file mode 100644 index 000000000..c763f15e4 --- /dev/null +++ b/internal/net/buffer.go @@ -0,0 +1,242 @@ +// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> +// SPDX-License-Identifier: MIT + +// Package net implements DTLS specific networking primitives. +// NOTE: this package is an adaption of pion/transport/packetio that allows for +// storing a remote address alongside each packet in the buffer and implements +// relevant methods of net.PacketConn. If possible, the updates made in this +// repository will be reflected back upstream. If not, it is likely that this +// will be moved to a public package in this repository. +// +// This package was migrated from pion/transport/packetio at +// https://github.com/pion/transport/commit/6890c795c807a617c054149eee40a69d7fdfbfdb +package net + +import ( + "bytes" + "errors" + "io" + "net" + "sync" + "time" + + "github.com/pion/transport/v3/deadline" +) + +// ErrTimeout indicates that deadline was reached before operation could be +// completed. +var ErrTimeout = errors.New("buffer: i/o timeout") + +// AddrPacket is a packet payload and the associated remote address from which +// it was received. +type AddrPacket struct { + addr net.Addr + data bytes.Buffer +} + +// PacketBuffer is a circular buffer for network packets. Each slot in the +// buffer contains the remote address from which the packet was received, as +// well as the packet data. +type PacketBuffer struct { + mutex sync.Mutex + + packets []AddrPacket + write, read int + + // full indicates whether the buffer is full, which is needed to distinguish + // when the write pointer and read pointer are at the same index. + full bool + + notify chan struct{} + closed bool + + readDeadline *deadline.Deadline +} + +// NewPacketBuffer creates a new PacketBuffer. +func NewPacketBuffer() *PacketBuffer { + return &PacketBuffer{ + readDeadline: deadline.New(), + // In the narrow context in which this package is currently used, there + // will always be at least one packet written to the buffer. Therefore, + // we opt to allocate with size of 1 during construction, rather than + // waiting until that first packet is written. + packets: make([]AddrPacket, 1), + full: false, + } +} + +// WriteTo writes a single packet to the buffer. The supplied address will +// remain associated with the packet. +func (b *PacketBuffer) WriteTo(pkt []byte, addr net.Addr) (int, error) { + b.mutex.Lock() + + if b.closed { + b.mutex.Unlock() + + return 0, io.ErrClosedPipe + } + + var notify chan struct{} + if b.notify != nil { + notify = b.notify + b.notify = nil + } + + // Check to see if we are full. + if b.full { + // If so, grow AddrPacket buffer. + var newSize int + if len(b.packets) < 128 { + // Double the number of packets. + newSize = len(b.packets) * 2 + } else { + // Increase the number of packets by 25%. + newSize = 5 * len(b.packets) / 4 + } + newBuf := make([]AddrPacket, newSize) + var n int + if b.read < b.write { + n = copy(newBuf, b.packets[b.read:b.write]) + } else { + n = copy(newBuf, b.packets[b.read:]) + n += copy(newBuf[n:], b.packets[:b.write]) + } + + b.packets = newBuf + + // Update write pointer to point to new location and mark buffer as not + // full. + b.write = n + b.full = false + } + + // Store the packet at the write pointer. + packet := &b.packets[b.write] + packet.data.Reset() + n, err := packet.data.Write(pkt) + if err != nil { + b.mutex.Unlock() + + return n, err + } + packet.addr = addr + + // Increment write pointer. + b.write++ + + // If the write pointer is equal to the length of the buffer, wrap around. + if len(b.packets) == b.write { + b.write = 0 + } + + // If a write resulted in making write and read pointers equivalent, then we + // are full. + if b.write == b.read { + b.full = true + } + + b.mutex.Unlock() + + if notify != nil { + close(notify) + } + + return n, nil +} + +// ReadFrom reads a single packet from the buffer, or blocks until one is +// available. +func (b *PacketBuffer) ReadFrom(packet []byte) (n int, addr net.Addr, err error) { //nolint:cyclop + select { + case <-b.readDeadline.Done(): + return 0, nil, ErrTimeout + default: + } + + for { + b.mutex.Lock() + + if b.read != b.write || b.full { + ap := b.packets[b.read] + if len(packet) < ap.data.Len() { + b.mutex.Unlock() + + return 0, nil, io.ErrShortBuffer + } + + // Copy packet data from buffer. + n, err := ap.data.Read(packet) + if err != nil { + b.mutex.Unlock() + + return n, nil, err + } + + // Advance read pointer. + b.read++ + if len(b.packets) == b.read { + b.read = 0 + } + + // If we were full before reading and have successfully read, we are + // no longer full. + if b.full { + b.full = false + } + + b.mutex.Unlock() + + return n, ap.addr, nil + } + + if b.closed { + b.mutex.Unlock() + + return 0, nil, io.EOF + } + + if b.notify == nil { + b.notify = make(chan struct{}) + } + notify := b.notify + b.mutex.Unlock() + + select { + case <-b.readDeadline.Done(): + return 0, nil, ErrTimeout + case <-notify: + } + } +} + +// Close closes the buffer, allowing unread packets to be read, but erroring on +// any new writes. +func (b *PacketBuffer) Close() (err error) { + b.mutex.Lock() + + if b.closed { + b.mutex.Unlock() + + return nil + } + + notify := b.notify + b.notify = nil + b.closed = true + + b.mutex.Unlock() + + if notify != nil { + close(notify) + } + + return nil +} + +// SetReadDeadline sets the read deadline for the buffer. +func (b *PacketBuffer) SetReadDeadline(t time.Time) error { + b.readDeadline.Set(t) + + return nil +} diff --git a/internal/net/buffer_test.go b/internal/net/buffer_test.go new file mode 100644 index 000000000..87a2a6ce1 --- /dev/null +++ b/internal/net/buffer_test.go @@ -0,0 +1,423 @@ +// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> +// SPDX-License-Identifier: MIT + +// Package net implements DTLS specific networking primitives. +package net + +import ( + "bytes" + "errors" + "fmt" + "io" + "net" + "testing" + "time" +) + +func equalInt(t *testing.T, expected, actual int) { + t.Helper() + + if expected != actual { + t.Errorf("Expected %d got %d", expected, actual) + } +} + +func equalUDPAddr(t *testing.T, expected, actual net.Addr) { + t.Helper() + + if expected == nil && actual == nil { + return + } + if expected.String() != actual.String() { + t.Errorf("Expected %v got %v", expected, actual) + } +} + +func equalBytes(t *testing.T, expected, actual []byte) { + t.Helper() + + if !bytes.Equal(expected, actual) { + t.Errorf("Expected %v got %v", expected, actual) + } +} + +func TestBuffer(t *testing.T) { //nolint:cyclop + buffer := NewPacketBuffer() + packet := make([]byte, 4) + addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:5684") + if err != nil { + t.Fatal(err) + } + + // Write once. + n, err := buffer.WriteTo([]byte{0, 1}, addr) + if err != nil { + t.Fatal(err) + } + equalInt(t, 2, n) + + // Read once. + var raddr net.Addr + if n, raddr, err = buffer.ReadFrom(packet); err != nil { + t.Fatal(err) + } + equalInt(t, 2, n) + equalBytes(t, []byte{0, 1}, packet[:n]) + equalUDPAddr(t, addr, raddr) + + // Read deadline. + if err = buffer.SetReadDeadline(time.Unix(0, 1)); err != nil { + t.Fatal(err) + } + n, raddr, err = buffer.ReadFrom(packet) + if !errors.Is(ErrTimeout, err) { + t.Fatalf("Unexpected err %v wanted ErrTimeout", err) + } + equalInt(t, 0, n) + equalUDPAddr(t, nil, raddr) + + // Reset deadline. + if err = buffer.SetReadDeadline(time.Time{}); err != nil { + t.Fatal(err) + } + + // Write twice. + if n, err = buffer.WriteTo([]byte{2, 3, 4}, addr); err != nil { + t.Fatal(err) + } + equalInt(t, 3, n) + + if n, err = buffer.WriteTo([]byte{5, 6, 7}, addr); err != nil { + t.Fatal(err) + } + equalInt(t, 3, n) + + // Read twice. + if n, raddr, err = buffer.ReadFrom(packet); err != nil { + t.Fatal(err) + } + equalInt(t, 3, n) + equalBytes(t, []byte{2, 3, 4}, packet[:n]) + equalUDPAddr(t, addr, raddr) + + if n, raddr, err = buffer.ReadFrom(packet); err != nil { + t.Fatal(err) + } + equalInt(t, 3, n) + equalBytes(t, []byte{5, 6, 7}, packet[:n]) + equalUDPAddr(t, addr, raddr) + + // Write once prior to close. + if _, err = buffer.WriteTo([]byte{3}, addr); err != nil { + t.Fatal(err) + } + + // Close. + if err = buffer.Close(); err != nil { + t.Fatal(err) + } + + // Future writes will error. + if _, err = buffer.WriteTo([]byte{4}, addr); err == nil { + t.Fatal("Expected error") + } + + // But we can read the remaining data. + if n, raddr, err = buffer.ReadFrom(packet); err != nil { + t.Fatal(err) + } + equalInt(t, 1, n) + equalBytes(t, []byte{3}, packet[:n]) + equalUDPAddr(t, addr, raddr) + + // Until EOF. + if _, _, err = buffer.ReadFrom(packet); !errors.Is(err, io.EOF) { + t.Fatalf("Unexpected err %v wanted io.EOF", err) + } +} + +func TestShortBuffer(t *testing.T) { + buffer := NewPacketBuffer() + addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:5684") + if err != nil { + t.Fatal(err) + } + + // Write once. + n, err := buffer.WriteTo([]byte{0, 1, 2, 3}, addr) + if err != nil { + t.Fatal(err) + } + equalInt(t, 4, n) + + // Try to read with a short buffer. + packet := make([]byte, 3) + var raddr net.Addr + n, raddr, err = buffer.ReadFrom(packet) + if !errors.Is(err, io.ErrShortBuffer) { + t.Fatalf("Unexpected err %v wanted io.ErrShortBuffer", err) + } + equalUDPAddr(t, nil, raddr) + equalInt(t, 0, n) + + // Close. + if err = buffer.Close(); err != nil { + t.Fatal(err) + } + + // Make sure you can Close twice. + if err = buffer.Close(); err != nil { + t.Fatal(err) + } +} + +func TestWraparound(t *testing.T) { + buffer := NewPacketBuffer() + addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:5684") + if err != nil { + t.Fatal(err) + } + + // Write multiple. + n, err := buffer.WriteTo([]byte{0, 1, 2, 3}, addr) + if err != nil { + t.Fatal(err) + } + equalInt(t, 4, n) + + if n, err = buffer.WriteTo([]byte{4, 5}, addr); err != nil { + t.Fatal(err) + } + equalInt(t, 2, n) + + if n, err = buffer.WriteTo([]byte{6, 7, 8}, addr); err != nil { + t.Fatal(err) + } + equalInt(t, 3, n) + + // Verify underlying buffer length. + // Packet 1: buffer does not grow. + // Packet 2: buffer doubles from 1 to 2. + // Packet 3: buffer doubles from 2 to 4. + equalInt(t, 4, len(buffer.packets)) + + // Read once. + packet := make([]byte, 4) + var raddr net.Addr + if n, raddr, err = buffer.ReadFrom(packet); err != nil { + t.Fatal(err) + } + equalInt(t, 4, n) + equalBytes(t, []byte{0, 1, 2, 3}, packet[:n]) + equalUDPAddr(t, addr, raddr) + + // Write again. + if n, err = buffer.WriteTo([]byte{9, 10, 11}, addr); err != nil { + t.Fatal(err) + } + equalInt(t, 3, n) + + // Verify underlying buffer length. + // No change in buffer size. + equalInt(t, 4, len(buffer.packets)) + + // Write again and verify buffer grew. + if n, err = buffer.WriteTo([]byte{12, 13, 14, 15, 16, 17, 18, 19}, addr); err != nil { + t.Fatal(err) + } + equalInt(t, 8, n) + equalInt(t, 4, len(buffer.packets)) + + // Close. + if err = buffer.Close(); err != nil { + t.Fatal(err) + } +} + +func TestBufferAsync(t *testing.T) { + buffer := NewPacketBuffer() + addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:5684") + if err != nil { + t.Fatal(err) + } + + // Start up a goroutine to start a blocking read. + done := make(chan string) + go func() { + packet := make([]byte, 4) + + n, raddr, rErr := buffer.ReadFrom(packet) + if rErr != nil { + done <- rErr.Error() + + return + } + + equalInt(t, 2, n) + equalBytes(t, []byte{0, 1}, packet[:n]) + equalUDPAddr(t, addr, raddr) + + _, _, readErr := buffer.ReadFrom(packet) + if !errors.Is(readErr, io.EOF) { + done <- fmt.Sprintf("Unexpected err %v wanted io.EOF", readErr) + } else { + close(done) + } + }() + + // Wait for the reader to start reading. + time.Sleep(time.Millisecond) + + // Write once + n, err := buffer.WriteTo([]byte{0, 1}, addr) + if err != nil { + t.Fatal(err) + } + equalInt(t, 2, n) + + // Wait for the reader to start reading again. + time.Sleep(time.Millisecond) + + // Close will unblock the reader. + if err = buffer.Close(); err != nil { + t.Fatal(err) + } + + if routineFail, ok := <-done; ok { + t.Fatal(routineFail) + } +} + +func benchmarkBufferWR(b *testing.B, size int64, write bool, grow int) { // nolint:unparam,cyclop + b.Helper() + + addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:5684") + if err != nil { + b.Fatalf("net.ResolveUDPAddr: %v", err) + } + buffer := NewPacketBuffer() + packet := make([]byte, size) + + // Grow the buffer first + pad := make([]byte, 1022) + for len(buffer.packets) < grow { + if _, err := buffer.WriteTo(pad, addr); err != nil { + b.Fatalf("Write: %v", err) + } + } + for buffer.read != buffer.write { + if _, _, err := buffer.ReadFrom(pad); err != nil { + b.Fatalf("ReadFrom: %v", err) + } + } + + if write { + if _, err := buffer.WriteTo(packet, addr); err != nil { + b.Fatalf("Write: %v", err) + } + } + + b.SetBytes(size) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + if _, err := buffer.WriteTo(packet, addr); err != nil { + b.Fatalf("Write: %v", err) + } + if _, _, err := buffer.ReadFrom(packet); err != nil { + b.Fatalf("Write: %v", err) + } + } +} + +// In this benchmark, the buffer is often empty, which is hopefully +// typical of real usage. +func BenchmarkBufferWR14(b *testing.B) { + benchmarkBufferWR(b, 14, false, 128) +} + +func BenchmarkBufferWR140(b *testing.B) { + benchmarkBufferWR(b, 140, false, 128) +} + +func BenchmarkBufferWR1400(b *testing.B) { + benchmarkBufferWR(b, 1400, false, 128) +} + +// Here, the buffer never becomes empty, which forces wraparound. +func BenchmarkBufferWWR14(b *testing.B) { + benchmarkBufferWR(b, 14, true, 128) +} + +func BenchmarkBufferWWR140(b *testing.B) { + benchmarkBufferWR(b, 140, true, 128) +} + +func BenchmarkBufferWWR1400(b *testing.B) { + benchmarkBufferWR(b, 1400, true, 128) +} + +func benchmarkBuffer(b *testing.B, size int64) { + b.Helper() + + addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:5684") + if err != nil { + b.Fatalf("net.ResolveUDPAddr: %v", err) + } + buffer := NewPacketBuffer() + b.SetBytes(size) + + done := make(chan struct{}) + go func() { + packet := make([]byte, size) + + for { + _, _, err := buffer.ReadFrom(packet) + if errors.Is(err, io.EOF) { + break + } else if err != nil { + b.Error(err) + + break + } + } + + close(done) + }() + + packet := make([]byte, size) + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + var err error + for { + _, err = buffer.WriteTo(packet, addr) + if !errors.Is(err, bytes.ErrTooLarge) { + break + } + time.Sleep(time.Microsecond) + } + if err != nil { + b.Fatal(err) + } + } + + if err := buffer.Close(); err != nil { + b.Fatal(err) + } + + <-done +} + +func BenchmarkBuffer14(b *testing.B) { + benchmarkBuffer(b, 14) +} + +func BenchmarkBuffer140(b *testing.B) { + benchmarkBuffer(b, 140) +} + +func BenchmarkBuffer1400(b *testing.B) { + benchmarkBuffer(b, 1400) +} diff --git a/internal/net/udp/packet_conn.go b/internal/net/udp/packet_conn.go new file mode 100644 index 000000000..e3e214ce9 --- /dev/null +++ b/internal/net/udp/packet_conn.go @@ -0,0 +1,413 @@ +// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> +// SPDX-License-Identifier: MIT + +// Package udp implements DTLS specific UDP networking primitives. +// NOTE: this package is an adaption of pion/transport/udp that allows for +// routing datagrams based on identifiers other than the remote address. The +// primary use case for this functionality is routing based on DTLS connection +// IDs. In order to allow for consumers of this package to treat connections as +// generic net.PackageConn, routing and identitier establishment is based on +// custom introspecion of datagrams, rather than direct intervention by +// consumers. If possible, the updates made in this repository will be reflected +// back upstream. If not, it is likely that this will be moved to a public +// package in this repository. +// +// This package was migrated from pion/transport/udp at +// https://github.com/pion/transport/commit/6890c795c807a617c054149eee40a69d7fdfbfdb +package udp + +import ( + "context" + "errors" + "net" + "sync" + "sync/atomic" + "time" + + idtlsnet "github.com/pion/dtls/v3/internal/net" + dtlsnet "github.com/pion/dtls/v3/pkg/net" + "github.com/pion/transport/v3/deadline" +) + +const ( + receiveMTU = 8192 + defaultListenBacklog = 128 // same as Linux default +) + +// Typed errors. +var ( + ErrClosedListener = errors.New("udp: listener closed") + ErrListenQueueExceeded = errors.New("udp: listen queue exceeded") +) + +// listener augments a connection-oriented Listener over a UDP PacketConn. +type listener struct { + pConn *net.UDPConn + + accepting atomic.Value // bool + acceptCh chan *PacketConn + doneCh chan struct{} + doneOnce sync.Once + acceptFilter func([]byte) bool + datagramRouter func([]byte) (string, bool) + connIdentifier func([]byte) (string, bool) + + connLock sync.Mutex + conns map[string]*PacketConn + connWG sync.WaitGroup + + readWG sync.WaitGroup + errClose atomic.Value // error + + readDoneCh chan struct{} + errRead atomic.Value // error +} + +// Accept waits for and returns the next connection to the listener. +func (l *listener) Accept() (net.PacketConn, net.Addr, error) { + select { + case c := <-l.acceptCh: + l.connWG.Add(1) + + return c, c.raddr, nil + + case <-l.readDoneCh: + err, _ := l.errRead.Load().(error) + + return nil, nil, err + + case <-l.doneCh: + return nil, nil, ErrClosedListener + } +} + +// Close closes the listener. +// Any blocked Accept operations will be unblocked and return errors. +func (l *listener) Close() error { + var err error + l.doneOnce.Do(func() { + l.accepting.Store(false) + close(l.doneCh) + + l.connLock.Lock() + // Close unaccepted connections + lclose: + for { + select { + case c := <-l.acceptCh: + close(c.doneCh) + // If we have an alternate identifier, remove it from the connection + // map. + if id := c.id.Load(); id != nil { + delete(l.conns, id.(string)) //nolint:forcetypeassert + } + // If we haven't already removed the remote address, remove it + // from the connection map. + if c.rmraddr.Load() == nil { + delete(l.conns, c.raddr.String()) + c.rmraddr.Store(true) + } + default: + break lclose + } + } + nConns := len(l.conns) + l.connLock.Unlock() + + l.connWG.Done() + + if nConns == 0 { + // Wait if this is the final connection. + l.readWG.Wait() + if errClose, ok := l.errClose.Load().(error); ok { + err = errClose + } + } else { + err = nil + } + }) + + return err +} + +// Addr returns the listener's network address. +func (l *listener) Addr() net.Addr { + return l.pConn.LocalAddr() +} + +// ListenConfig stores options for listening to an address. +type ListenConfig struct { + // Backlog defines the maximum length of the queue of pending + // connections. It is equivalent of the backlog argument of + // POSIX listen function. + // If a connection request arrives when the queue is full, + // the request will be silently discarded, unlike TCP. + // Set zero to use default value 128 which is same as Linux default. + Backlog int + + // AcceptFilter determines whether the new conn should be made for + // the incoming packet. If not set, any packet creates new conn. + AcceptFilter func([]byte) bool + + // DatagramRouter routes an incoming datagram to a connection by extracting + // an identifier from the its paylod + DatagramRouter func([]byte) (string, bool) + + // ConnectionIdentifier extracts an identifier from an outgoing packet. If + // the identifier is not already associated with the connection, it will be + // added. + ConnectionIdentifier func([]byte) (string, bool) +} + +// Listen creates a new listener based on the ListenConfig. +func (lc *ListenConfig) Listen(network string, laddr *net.UDPAddr) (dtlsnet.PacketListener, error) { + if lc.Backlog == 0 { + lc.Backlog = defaultListenBacklog + } + + conn, err := net.ListenUDP(network, laddr) + if err != nil { + return nil, err + } + + packetListener := &listener{ + pConn: conn, + acceptCh: make(chan *PacketConn, lc.Backlog), + conns: make(map[string]*PacketConn), + doneCh: make(chan struct{}), + acceptFilter: lc.AcceptFilter, + datagramRouter: lc.DatagramRouter, + connIdentifier: lc.ConnectionIdentifier, + readDoneCh: make(chan struct{}), + } + + packetListener.accepting.Store(true) + packetListener.connWG.Add(1) + packetListener.readWG.Add(2) // wait readLoop and Close execution routine + + go packetListener.readLoop() + go func() { + packetListener.connWG.Wait() + if err := packetListener.pConn.Close(); err != nil { + packetListener.errClose.Store(err) + } + packetListener.readWG.Done() + }() + + return packetListener, nil +} + +// Listen creates a new listener using default ListenConfig. +func Listen(network string, laddr *net.UDPAddr) (dtlsnet.PacketListener, error) { + return (&ListenConfig{}).Listen(network, laddr) +} + +// readLoop dispatches packets to the proper connection, creating a new one if +// necessary, until all connections are closed. +func (l *listener) readLoop() { + defer l.readWG.Done() + defer close(l.readDoneCh) + + buf := make([]byte, receiveMTU) + + for { + n, raddr, err := l.pConn.ReadFrom(buf) + if err != nil { + l.errRead.Store(err) + + return + } + conn, ok, err := l.getConn(raddr, buf[:n]) + if err != nil { + continue + } + if ok { + _, _ = conn.buffer.WriteTo(buf[:n], raddr) + } + } +} + +// getConn gets an existing connection or creates a new one. +func (l *listener) getConn(raddr net.Addr, buf []byte) (*PacketConn, bool, error) { //nolint:cyclop + l.connLock.Lock() + defer l.connLock.Unlock() + // If we have a custom resolver, use it. + if l.datagramRouter != nil { + if id, ok := l.datagramRouter(buf); ok { + if conn, ok := l.conns[id]; ok { + return conn, true, nil + } + } + } + + // If we don't have a custom resolver, or we were unable to find an + // associated connection, fall back to remote address. + conn, ok := l.conns[raddr.String()] + if !ok { + if isAccepting, ok := l.accepting.Load().(bool); !isAccepting || !ok { + return nil, false, ErrClosedListener + } + if l.acceptFilter != nil { + if !l.acceptFilter(buf) { + return nil, false, nil + } + } + conn = l.newPacketConn(raddr) + select { + case l.acceptCh <- conn: + l.conns[raddr.String()] = conn + default: + return nil, false, ErrListenQueueExceeded + } + } + + return conn, true, nil +} + +// PacketConn is a net.PacketConn implementation that is able to dictate its +// routing ID via an alternate identifier from its remote address. Internal +// buffering is performed for reads, and writes are passed through to the +// underlying net.PacketConn. +type PacketConn struct { + listener *listener + + raddr net.Addr + rmraddr atomic.Value // bool + id atomic.Value // string + + buffer *idtlsnet.PacketBuffer + + doneCh chan struct{} + doneOnce sync.Once + + writeDeadline *deadline.Deadline +} + +// newPacketConn constructs a new PacketConn. +func (l *listener) newPacketConn(raddr net.Addr) *PacketConn { + return &PacketConn{ + listener: l, + raddr: raddr, + buffer: idtlsnet.NewPacketBuffer(), + doneCh: make(chan struct{}), + writeDeadline: deadline.New(), + } +} + +// ReadFrom reads a single packet payload and its associated remote address from +// the underlying buffer. +func (c *PacketConn) ReadFrom(buff []byte) (int, net.Addr, error) { + return c.buffer.ReadFrom(buff) +} + +// WriteTo writes len(payload) bytes from payload to the specified address. +func (c *PacketConn) WriteTo(payload []byte, addr net.Addr) (n int, err error) { + // If we have a connection identifier, check to see if the outgoing packet + // sets it. + if c.listener.connIdentifier != nil { + id := c.id.Load() + // Only update establish identifier if we haven't already done so. + if id == nil { + candidate, ok := c.listener.connIdentifier(payload) + // If we have an identifier, add entry to connection map. + if ok { + c.listener.connLock.Lock() + c.listener.conns[candidate] = c + c.listener.connLock.Unlock() + c.id.Store(candidate) + } + } + // If we are writing to a remote address that differs from the initial, + // we have an alternate identifier established, and we haven't already + // freed the remote address, free the remote address to be used by + // another connection. + // Note: this strategy results in holding onto a remote address after it + // is potentially no longer in use by the client. However, releasing + // earlier means that we could miss some packets that should have been + // routed to this connection. Ideally, we would drop the connection + // entry for the remote address as soon as the client starts sending + // using an alternate identifier, but in practice this proves + // challenging because any client could spoof a connection identifier, + // resulting in the remote address entry being dropped prior to the + // "real" client transitioning to sending using the alternate + // identifier. + if id != nil && c.rmraddr.Load() == nil && addr.String() != c.raddr.String() { + c.listener.connLock.Lock() + delete(c.listener.conns, c.raddr.String()) + c.rmraddr.Store(true) + c.listener.connLock.Unlock() + } + } + + select { + case <-c.writeDeadline.Done(): + return 0, context.DeadlineExceeded + default: + } + + return c.listener.pConn.WriteTo(payload, addr) +} + +// Close closes the conn and releases any Read calls. +func (c *PacketConn) Close() error { + var err error + c.doneOnce.Do(func() { + c.listener.connWG.Done() + close(c.doneCh) + c.listener.connLock.Lock() + // If we have an alternate identifier, remove it from the connection + // map. + if id := c.id.Load(); id != nil { + delete(c.listener.conns, id.(string)) //nolint:forcetypeassert + } + // If we haven't already removed the remote address, remove it from the + // connection map. + if c.rmraddr.Load() == nil { + delete(c.listener.conns, c.raddr.String()) + c.rmraddr.Store(true) + } + nConns := len(c.listener.conns) + c.listener.connLock.Unlock() + + if isAccepting, ok := c.listener.accepting.Load().(bool); nConns == 0 && !isAccepting && ok { + // Wait if this is the final connection + c.listener.readWG.Wait() + if errClose, ok := c.listener.errClose.Load().(error); ok { + err = errClose + } + } else { + err = nil + } + + if errBuf := c.buffer.Close(); errBuf != nil && err == nil { + err = errBuf + } + }) + + return err +} + +// LocalAddr implements net.PacketConn.LocalAddr. +func (c *PacketConn) LocalAddr() net.Addr { + return c.listener.pConn.LocalAddr() +} + +// SetDeadline implements net.PacketConn.SetDeadline. +func (c *PacketConn) SetDeadline(t time.Time) error { + c.writeDeadline.Set(t) + + return c.SetReadDeadline(t) +} + +// SetReadDeadline implements net.PacketConn.SetReadDeadline. +func (c *PacketConn) SetReadDeadline(t time.Time) error { + return c.buffer.SetReadDeadline(t) +} + +// SetWriteDeadline implements net.PacketConn.SetWriteDeadline. +func (c *PacketConn) SetWriteDeadline(t time.Time) error { + c.writeDeadline.Set(t) + // Write deadline of underlying connection should not be changed + // since the connection can be shared. + return nil +} diff --git a/internal/net/udp/packet_conn_test.go b/internal/net/udp/packet_conn_test.go new file mode 100644 index 000000000..53b3c06d5 --- /dev/null +++ b/internal/net/udp/packet_conn_test.go @@ -0,0 +1,742 @@ +// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> +// SPDX-License-Identifier: MIT + +//go:build !js +// +build !js + +// Package udp implements DTLS specific UDP networking primitives. +package udp + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "sync" + "testing" + "time" + + dtlsnet "github.com/pion/dtls/v3/pkg/net" + "github.com/pion/transport/v3/test" +) + +var errHandshakeFailed = errors.New("handshake failed") + +func TestStressDuplex(t *testing.T) { + // Limit runtime in case of deadlocks + lim := test.TimeOut(time.Second * 20) + defer lim.Stop() + + // Check for leaking routines + report := test.CheckRoutines(t) + defer report() + + // Run the test + stressDuplex(t) +} + +type rw struct { + p net.PacketConn + raddr net.Addr +} + +func fromPC(p net.PacketConn, raddr net.Addr) *rw { + return &rw{ + p: p, + raddr: raddr, + } +} + +func (r *rw) Read(p []byte) (int, error) { + n, _, err := r.p.ReadFrom(p) + + return n, err +} + +func (r *rw) Write(p []byte) (int, error) { + return r.p.WriteTo(p, r.raddr) +} + +func stressDuplex(t *testing.T) { + t.Helper() + + listener, ca, cb, err := pipe() + if err != nil { + t.Fatal(err) + } + + defer func() { + if ca.Close() != nil { + t.Fatal(err) + } + if cb.Close() != nil { + t.Fatal(err) + } + if listener.Close() != nil { + t.Fatal(err) + } + }() + + opt := test.Options{ + MsgSize: 2048, + MsgCount: 1, // Can't rely on UDP message order in CI + } + + if err := test.StressDuplex(fromPC(ca, cb.LocalAddr()), cb, opt); err != nil { + t.Fatal(err) + } +} + +func TestListenerCloseTimeout(t *testing.T) { + // Limit runtime in case of deadlocks + lim := test.TimeOut(time.Second * 5) + defer lim.Stop() + + // Check for leaking routines + report := test.CheckRoutines(t) + defer report() + + listener, ca, _, err := pipe() + if err != nil { + t.Fatal(err) + } + + err = listener.Close() + if err != nil { + t.Fatal(err) + } + + // Close client after server closes to cleanup + err = ca.Close() + if err != nil { + t.Fatal(err) + } +} + +func TestListenerCloseUnaccepted(t *testing.T) { + // Limit runtime in case of deadlocks + lim := test.TimeOut(time.Second * 20) + defer lim.Stop() + + // Check for leaking routines + report := test.CheckRoutines(t) + defer report() + + const backlog = 2 + + network, addr := getConfig() + listener, err := (&ListenConfig{ + Backlog: backlog, + }).Listen(network, addr) + if err != nil { + t.Fatal(err) + } + + for i := 0; i < backlog; i++ { + conn, dErr := net.DialUDP(network, nil, listener.Addr().(*net.UDPAddr)) + if dErr != nil { + t.Error(dErr) + + continue + } + if _, wErr := conn.Write([]byte{byte(i)}); wErr != nil { + t.Error(wErr) + } + if cErr := conn.Close(); cErr != nil { + t.Error(cErr) + } + } + + time.Sleep(100 * time.Millisecond) // Wait all packets being processed by readLoop + + // Unaccepted connections must be closed by listener.Close() + if err = listener.Close(); err != nil { + t.Fatal(err) + } +} + +func TestListenerAcceptFilter(t *testing.T) { //nolint:cyclop + // Limit runtime in case of deadlocks + lim := test.TimeOut(time.Second * 20) + defer lim.Stop() + + // Check for leaking routines + report := test.CheckRoutines(t) + defer report() + + testCases := map[string]struct { + packet []byte + accept bool + }{ + "CreateConn": { + packet: []byte{0xAA}, + accept: true, + }, + "Discarded": { + packet: []byte{0x00}, + accept: false, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + network, addr := getConfig() + listener, err := (&ListenConfig{ + AcceptFilter: func(pkt []byte) bool { + return pkt[0] == 0xAA + }, + }).Listen(network, addr) + if err != nil { + t.Fatal(err) + } + + var wgAcceptLoop sync.WaitGroup + wgAcceptLoop.Add(1) + defer func() { + if lErr := listener.Close(); lErr != nil { + t.Fatal(lErr) + } + wgAcceptLoop.Wait() + }() + + conn, err := net.DialUDP(network, nil, listener.Addr().(*net.UDPAddr)) + if err != nil { + t.Fatal(err) + } + if _, err := conn.Write(testCase.packet); err != nil { + t.Fatal(err) + } + defer func() { + if err := conn.Close(); err != nil { + t.Error(err) + } + }() + + chAccepted := make(chan struct{}) + go func() { + defer wgAcceptLoop.Done() + + conn, _, aArr := listener.Accept() + if aArr != nil { + if !errors.Is(aArr, ErrClosedListener) { + t.Error(aArr) + } + + return + } + close(chAccepted) + if err := conn.Close(); err != nil { + t.Error(err) + } + }() + + var accepted bool + select { + case <-chAccepted: + accepted = true + case <-time.After(10 * time.Millisecond): + } + + if accepted != testCase.accept { + if testCase.accept { + t.Error("Packet should create new conn") + } else { + t.Error("Packet should not create new conn") + } + } + }) + } +} + +func TestListenerConcurrent(t *testing.T) { //nolint:gocyclo,cyclop + // Limit runtime in case of deadlocks + lim := test.TimeOut(time.Second * 20) + defer lim.Stop() + + // Check for leaking routines + report := test.CheckRoutines(t) + defer report() + + const backlog = 2 + + network, addr := getConfig() + listener, err := (&ListenConfig{ + Backlog: backlog, + }).Listen(network, addr) + if err != nil { + t.Fatal(err) + } + + for i := 0; i < backlog+1; i++ { + conn, dErr := net.DialUDP(network, nil, listener.Addr().(*net.UDPAddr)) + if dErr != nil { + t.Error(dErr) + + continue + } + if _, wErr := conn.Write([]byte{byte(i)}); wErr != nil { + t.Error(wErr) + } + if cErr := conn.Close(); cErr != nil { + t.Error(cErr) + } + } + + time.Sleep(100 * time.Millisecond) // Wait all packets being processed by readLoop + + for i := 0; i < backlog; i++ { + conn, _, lErr := listener.Accept() + if lErr != nil { + t.Error(lErr) + + continue + } + b := make([]byte, 1) + n, _, lErr := conn.ReadFrom(b) + if lErr != nil { + t.Error(lErr) + } else if !bytes.Equal([]byte{byte(i)}, b[:n]) { + t.Errorf("Packet from connection %d is wrong, expected: [%d], got: %v", i, i, b[:n]) + } + if lErr = conn.Close(); lErr != nil { + t.Error(lErr) + } + } + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + if conn, _, lErr := listener.Accept(); !errors.Is(lErr, ErrClosedListener) { + t.Errorf("Connection exceeding backlog limit must be discarded: %v", lErr) + if lErr == nil { + _ = conn.Close() + } + } + }() + + time.Sleep(100 * time.Millisecond) // Last Accept should be discarded + err = listener.Close() + if err != nil { + t.Fatal(err) + } + + wg.Wait() +} + +func pipe() (dtlsnet.PacketListener, net.PacketConn, *net.UDPConn, error) { + // Start listening + network, addr := getConfig() + listener, err := Listen(network, addr) + if err != nil { + return nil, nil, nil, fmt.Errorf("failed to listen: %w", err) + } + + // Open a connection + var dConn *net.UDPConn + dConn, err = net.DialUDP(network, nil, listener.Addr().(*net.UDPAddr)) + if err != nil { + return nil, nil, nil, fmt.Errorf("failed to dial: %w", err) + } + + // Write to the connection to initiate it + handshake := "hello" + _, err = dConn.Write([]byte(handshake)) + if err != nil { + return nil, nil, nil, fmt.Errorf("failed to write to dialed Conn: %w", err) + } + + // Accept the connection + var lConn net.PacketConn + lConn, _, err = listener.Accept() + if err != nil { + return nil, nil, nil, fmt.Errorf("failed to accept Conn: %w", err) + } + + var n int + buf := make([]byte, len(handshake)) + if n, _, err = lConn.ReadFrom(buf); err != nil { + return nil, nil, nil, fmt.Errorf("failed to read handshake: %w", err) + } + + result := string(buf[:n]) + if handshake != result { + return nil, nil, nil, fmt.Errorf("%w: %s != %s", errHandshakeFailed, handshake, result) + } + + return listener, lConn, dConn, nil +} + +func getConfig() (string, *net.UDPAddr) { + return "udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0} +} + +func TestConnClose(t *testing.T) { //nolint:cyclop + lim := test.TimeOut(time.Second * 5) + defer lim.Stop() + + t.Run("Close", func(t *testing.T) { + // Check for leaking routines + report := test.CheckRoutines(t) + defer report() + + udpListener, ca, cb, errPipe := pipe() + if errPipe != nil { + t.Fatal(errPipe) + } + if err := ca.Close(); err != nil { + t.Errorf("Failed to close A side: %v", err) + } + if err := cb.Close(); err != nil { + t.Errorf("Failed to close B side: %v", err) + } + if err := udpListener.Close(); err != nil { + t.Errorf("Failed to close listener: %v", err) + } + }) + t.Run("CloseError1", func(t *testing.T) { + // Check for leaking routines + report := test.CheckRoutines(t) + defer report() + + udpListener, ca, cb, errPipe := pipe() + if errPipe != nil { + t.Fatal(errPipe) + } + // Close l.pConn to inject error. + if err := udpListener.(*listener).pConn.Close(); err != nil { //nolint:forcetypeassert + t.Error(err) + } + + if err := cb.Close(); err != nil { + t.Errorf("Failed to close A side: %v", err) + } + if err := ca.Close(); err != nil { + t.Errorf("Failed to close B side: %v", err) + } + if err := udpListener.Close(); err == nil { + t.Errorf("Error is not propagated to Listener.Close") + } + }) + t.Run("CloseError2", func(t *testing.T) { + // Check for leaking routines + report := test.CheckRoutines(t) + defer report() + + l, ca, cb, errPipe := pipe() + if errPipe != nil { + t.Fatal(errPipe) + } + // Close l.pConn to inject error. + if err := l.(*listener).pConn.Close(); err != nil { //nolint:forcetypeassert + t.Error(err) + } + + if err := cb.Close(); err != nil { + t.Errorf("Failed to close A side: %v", err) + } + if err := l.Close(); err != nil { + t.Errorf("Failed to close listener: %v", err) + } + if err := ca.Close(); err == nil { + t.Errorf("Error is not propagated to Conn.Close") + } + }) + t.Run("CancelRead", func(t *testing.T) { + // Limit runtime in case of deadlocks + lim := test.TimeOut(time.Second * 5) + defer lim.Stop() + + // Check for leaking routines + report := test.CheckRoutines(t) + defer report() + + listener, ca, cb, errPipe := pipe() + if errPipe != nil { + t.Fatal(errPipe) + } + + errC := make(chan error, 1) + go func() { + buf := make([]byte, 1024) + // This read will block because we don't write on the other side. + // Calling Close must unblock the call. + _, _, err := ca.ReadFrom(buf) + errC <- err + }() + + if err := ca.Close(); err != nil { // Trigger Read cancellation. + t.Errorf("Failed to close B side: %v", err) + } + + // Main test condition, Read should return + // after ca.Close() by closing the buffer. + if err := <-errC; !errors.Is(err, io.EOF) { + t.Errorf("expected err to be io.EOF but got %v", err) + } + + if err := cb.Close(); err != nil { + t.Errorf("Failed to close A side: %v", err) + } + if err := listener.Close(); err != nil { + t.Errorf("Failed to close listener: %v", err) + } + }) +} + +func TestListenerCustomConnIDs(t *testing.T) { //nolint:gocyclo,cyclop,maintidx + const helloPayload, setPayload = "hello", "set" + const serverCount, clientCount = 5, 20 + // Limit runtime in case of deadlocks. + lim := test.TimeOut(time.Second * 20) + defer lim.Stop() + + // Check for leaking routines. + report := test.CheckRoutines(t) + defer report() + + type pkt struct { + ID int + Payload string + } + network, addr := getConfig() + listener, err := (&ListenConfig{ + // For all datagrams other than the initial "hello" packet, use the ID + // to route. + DatagramRouter: func(buf []byte) (string, bool) { + var p pkt + if err := json.Unmarshal(buf, &p); err != nil { + return "", false + } + if p.Payload == helloPayload { + return "", false + } + + return fmt.Sprint(p.ID), true + }, + // Use the outgoing "set" payload to add an identifier for a connection. + ConnectionIdentifier: func(buf []byte) (string, bool) { + var p pkt + if err := json.Unmarshal(buf, &p); err != nil { + return "", false + } + if p.Payload == setPayload { + return fmt.Sprint(p.ID), true + } + + return "", false + }, + }).Listen(network, addr) + if err != nil { + t.Fatal(err) + } + + var clientWg sync.WaitGroup + var phaseOne [5]chan struct{} + for i := range phaseOne { + phaseOne[i] = make(chan struct{}) + } + var serverWg sync.WaitGroup + clientMap := map[string]struct{}{} + var clientMapMu sync.Mutex + // Start servers. + for i := 0; i < serverCount; i++ { + serverWg.Add(1) + go func() { + defer serverWg.Done() + // The first payload from the accepted connection should inform + // which connection this server is. + conn, _, err := listener.Accept() + if err != nil { + t.Error(err) + + return + } + buf := make([]byte, 100) + n, raddr, rErr := conn.ReadFrom(buf) + if rErr != nil { + t.Error(err) + + return + } + var udpPkt pkt + if uErr := json.Unmarshal(buf[:n], &udpPkt); uErr != nil { + t.Error(err) + + return + } + // First message should be a hello and custom connection + // ID function will use remote address as identifier. + if udpPkt.Payload != helloPayload { + t.Error("Expected hello message") + + return + } + connID := udpPkt.ID + + // Send set message to associate ID with this connection. + buf, err = json.Marshal(&pkt{ + ID: connID, + Payload: "set", + }) + if err != nil { + t.Error(err) + + return + } + if _, wErr := conn.WriteTo(buf, raddr); wErr != nil { + t.Error(wErr) + + return + } + // Signal to the corresponding clients that connection ID has been + // set. + close(phaseOne[connID]) + // Receive packets, ensuring that each one came from a different + // client remote address and has a unique payload. + for j := 0; j < clientCount/serverCount; j++ { + buf := make([]byte, 100) + n, _, err := conn.ReadFrom(buf) + if err != nil { + t.Error(err) + + return + } + var udpPkt pkt + if err := json.Unmarshal(buf[:n], &udpPkt); err != nil { + t.Error(err) + + return + } + if udpPkt.ID != connID { + t.Errorf("Expected connection ID %d, but got %d", connID, udpPkt.ID) + + return + } + // Ensure we only ever receive one message from + // a given client. + clientMapMu.Lock() + if _, ok := clientMap[udpPkt.Payload]; ok { + t.Errorf("Multiple messages from single client %s", udpPkt.Payload) + + return + } + clientMap[udpPkt.Payload] = struct{}{} + clientMapMu.Unlock() + } + if err := conn.Close(); err != nil { + t.Error(err) + } + }() + } + + // Start a client per server to send initial "hello" message and receive a + // "set" message. + for i := 0; i < serverCount; i++ { + clientWg.Add(1) + go func(connID int) { + defer clientWg.Done() + conn, dErr := net.DialUDP(network, nil, listener.Addr().(*net.UDPAddr)) + if dErr != nil { + t.Error(dErr) + + return + } + hbuf, err := json.Marshal(&pkt{ + ID: connID, + Payload: helloPayload, + }) + if err != nil { + t.Error(err) + + return + } + if _, wErr := conn.Write(hbuf); wErr != nil { + t.Error(wErr) + + return + } + + var udpPacket pkt + buf := make([]byte, 100) + n, err := conn.Read(buf) + if err != nil { + t.Error(err) + + return + } + if err := json.Unmarshal(buf[:n], &udpPacket); err != nil { + t.Error(err) + + return + } + // Second message should be a set and custom connection identifier + // function will update the connection ID from remote address to the + // supplied ID. + if udpPacket.Payload != "set" { + t.Error("Expected set message") + + return + } + // Ensure the connection ID matches what the "hello" message + // indicated. + if udpPacket.ID != connID { + t.Errorf("Expected connection ID %d, but got %d", connID, udpPacket.ID) + + return + } + // Close connection. We will reconnect from a different remote + // address using the same connection ID. + if cErr := conn.Close(); cErr != nil { + t.Error(cErr) + } + }(i) + } + + // Spawn clients sending to server connections. + for i := 1; i <= clientCount; i++ { + clientWg.Add(1) + go func(connID int) { + defer clientWg.Done() + // Ensure that we are using a connection ID for packet + // routing prior to sending any messages. + <-phaseOne[connID] + conn, dErr := net.DialUDP(network, nil, listener.Addr().(*net.UDPAddr)) + if dErr != nil { + t.Error(dErr) + + return + } + // Send a packet with a connection ID and this client's local + // address. The latter is used to identify this client as unique. + buf, err := json.Marshal(&pkt{ + ID: connID, + Payload: conn.LocalAddr().String(), + }) + if err != nil { + t.Error(err) + + return + } + if _, wErr := conn.Write(buf); wErr != nil { + t.Error(wErr) + + return + } + if cErr := conn.Close(); cErr != nil { + t.Error(cErr) + } + }(i % serverCount) + } + + // Wait for clients to exit. + clientWg.Wait() + // Wait for servers to exit. + serverWg.Wait() + if err := listener.Close(); err != nil { + t.Fatal(err) + } +} diff --git a/internal/util/util.go b/internal/util/util.go index 685910fc2..8ebbcd44f 100644 --- a/internal/util/util.go +++ b/internal/util/util.go @@ -6,9 +6,11 @@ package util import ( "encoding/binary" + + "golang.org/x/crypto/cryptobyte" ) -// BigEndianUint24 returns the value of a big endian uint24 +// BigEndianUint24 returns the value of a big endian uint24. func BigEndianUint24(raw []byte) uint32 { if len(raw) < 3 { return 0 @@ -16,27 +18,36 @@ func BigEndianUint24(raw []byte) uint32 { rawCopy := make([]byte, 4) copy(rawCopy[1:], raw) + return binary.BigEndian.Uint32(rawCopy) } -// PutBigEndianUint24 encodes a uint24 and places into out +// PutBigEndianUint24 encodes a uint24 and places into out. func PutBigEndianUint24(out []byte, in uint32) { tmp := make([]byte, 4) binary.BigEndian.PutUint32(tmp, in) copy(out, tmp[1:]) } -// PutBigEndianUint48 encodes a uint64 and places into out +// PutBigEndianUint48 encodes a uint64 and places into out. func PutBigEndianUint48(out []byte, in uint64) { tmp := make([]byte, 8) binary.BigEndian.PutUint64(tmp, in) copy(out, tmp[2:]) } -// Max returns the larger value +// Max returns the larger value. func Max(a, b int) int { if a > b { return a } + return b } + +// AddUint48 appends a big-endian, 48-bit value to the byte string. +// Remove if / when https://github.com/golang/crypto/pull/265 is merged +// upstream. +func AddUint48(b *cryptobyte.Builder, v uint64) { + b.AddBytes([]byte{byte(v >> 40), byte(v >> 32), byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}) +} diff --git a/internal/util/util_test.go b/internal/util/util_test.go new file mode 100644 index 000000000..41127c0b8 --- /dev/null +++ b/internal/util/util_test.go @@ -0,0 +1,66 @@ +// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> +// SPDX-License-Identifier: MIT + +package util + +import ( + "bytes" + "testing" + + "golang.org/x/crypto/cryptobyte" +) + +func TestAddUint48(t *testing.T) { + cases := map[string]struct { + reason string + builder *cryptobyte.Builder + postAdd func(*cryptobyte.Builder) + in uint64 + want []byte + }{ + "OnlyUint48": { + reason: "Adding only a 48-bit unsigned integer should yield expected result.", + builder: &cryptobyte.Builder{}, + in: 0xfefcff3cfdfc, + want: []byte{254, 252, 255, 60, 253, 252}, + }, + "ExistingAddUint48": { + reason: "Adding a 48-bit unsigned integer to a builder with existing bytes should yield expected result.", + builder: func() *cryptobyte.Builder { + var b cryptobyte.Builder + b.AddUint64(0xffffffffffffffff) + + return &b + }(), + in: 0xfefcff3cfdfc, + want: []byte{255, 255, 255, 255, 255, 255, 255, 255, 254, 252, 255, 60, 253, 252}, + }, + "ExistingAddUint48AndMore": { + //nolint:lll + reason: "Adding a 48-bit unsigned integer to a builder with existing bytes, then adding more bytes, should yield expected result.", + builder: func() *cryptobyte.Builder { + var b cryptobyte.Builder + b.AddUint64(0xffffffffffffffff) + + return &b + }(), + postAdd: func(b *cryptobyte.Builder) { + b.AddUint32(0xffffffff) + }, + in: 0xfefcff3cfdfc, + want: []byte{255, 255, 255, 255, 255, 255, 255, 255, 254, 252, 255, 60, 253, 252, 255, 255, 255, 255}, + }, + } + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + AddUint48(tc.builder, tc.in) + if tc.postAdd != nil { + tc.postAdd(tc.builder) + } + got := tc.builder.BytesOrPanic() + if !bytes.Equal(got, tc.want) { + t.Errorf("Bytes() = %v, want %v", got, tc.want) + } + }) + } +} diff --git a/listener.go b/listener.go index 190d236c7..3583d0308 100644 --- a/listener.go +++ b/listener.go @@ -6,12 +6,13 @@ package dtls import ( "net" - "github.com/pion/dtls/v2/pkg/protocol" - "github.com/pion/dtls/v2/pkg/protocol/recordlayer" - "github.com/pion/transport/v2/udp" + "github.com/pion/dtls/v3/internal/net/udp" + dtlsnet "github.com/pion/dtls/v3/pkg/net" + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) -// Listen creates a DTLS listener +// Listen creates a DTLS listener. func Listen(network string, laddr *net.UDPAddr, config *Config) (net.Listener, error) { if err := validateConfig(config); err != nil { return nil, err @@ -27,13 +28,21 @@ func Listen(network string, laddr *net.UDPAddr, config *Config) (net.Listener, e if err := h.Unmarshal(pkts[0]); err != nil { return false } + return h.ContentType == protocol.ContentTypeHandshake }, } + // If connection ID support is enabled, then they must be supported in + // routing. + if config.ConnectionIDGenerator != nil { + lc.DatagramRouter = cidDatagramRouter(len(config.ConnectionIDGenerator())) + lc.ConnectionIdentifier = cidConnIdentifier() + } parent, err := lc.Listen(network, laddr) if err != nil { return nil, err } + return &listener{ config: config, parent: parent, @@ -41,7 +50,7 @@ func Listen(network string, laddr *net.UDPAddr, config *Config) (net.Listener, e } // NewListener creates a DTLS listener which accepts connections from an inner Listener. -func NewListener(inner net.Listener, config *Config) (net.Listener, error) { +func NewListener(inner dtlsnet.PacketListener, config *Config) (net.Listener, error) { if err := validateConfig(config); err != nil { return nil, err } @@ -52,22 +61,21 @@ func NewListener(inner net.Listener, config *Config) (net.Listener, error) { }, nil } -// listener represents a DTLS listener +// listener represents a DTLS listener. type listener struct { config *Config - parent net.Listener + parent dtlsnet.PacketListener } // Accept waits for and returns the next connection to the listener. // You have to either close or read on all connection that are created. -// Connection handshake will timeout using ConnectContextMaker in the Config. -// If you want to specify the timeout duration, set ConnectContextMaker. func (l *listener) Accept() (net.Conn, error) { - c, err := l.parent.Accept() + c, raddr, err := l.parent.Accept() if err != nil { return nil, err } - return Server(c, l.config) + + return Server(c, raddr, l.config) } // Close closes the listener. diff --git a/nettest_test.go b/nettest_test.go index e3cb4eb35..3a712434f 100644 --- a/nettest_test.go +++ b/nettest_test.go @@ -11,7 +11,7 @@ import ( "testing" "time" - "github.com/pion/transport/v2/test" + "github.com/pion/transport/v3/test" "golang.org/x/net/nettest" ) @@ -28,6 +28,7 @@ func TestNetTest(t *testing.T) { _ = c1.Close() _ = c2.Close() } + return }) } diff --git a/packet.go b/packet.go index 55d6272ee..c224eff1b 100644 --- a/packet.go +++ b/packet.go @@ -3,10 +3,13 @@ package dtls -import "github.com/pion/dtls/v2/pkg/protocol/recordlayer" +import ( + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" +) type packet struct { record *recordlayer.RecordLayer shouldEncrypt bool + shouldWrapCID bool resetLocalSequenceNumber bool } diff --git a/pkg/crypto/ccm/ccm.go b/pkg/crypto/ccm/ccm.go index d6e6fc479..bc268deeb 100644 --- a/pkg/crypto/ccm/ccm.go +++ b/pkg/crypto/ccm/ccm.go @@ -65,7 +65,8 @@ func NewCCM(b cipher.Block, tagsize, noncesize int) (CCM, error) { if lensize < 2 || lensize > 8 { return nil, errInvalidNonceSize } - c := &ccm{b: b, M: uint8(tagsize), L: uint8(lensize)} + c := &ccm{b: b, M: uint8(tagsize), L: uint8(lensize)} //nolint:gosec // G114 + return c, nil } @@ -74,14 +75,15 @@ func (c *ccm) Overhead() int { return int(c.M) } func (c *ccm) MaxLength() int { return maxlen(c.L, c.Overhead()) } func maxlen(l uint8, tagsize int) int { - max := (uint64(1) << (8 * l)) - 1 - if m64 := uint64(math.MaxInt64) - uint64(tagsize); l > 8 || max > m64 { - max = m64 // The maximum lentgh on a 64bit arch + mLen := (uint64(1) << (8 * l)) - 1 + if m64 := uint64(math.MaxInt64) - uint64(tagsize); l > 8 || mLen > m64 { //nolint:gosec // G114 + mLen = m64 // The maximum lentgh on a 64bit arch } - if max != uint64(int(max)) { + if mLen != uint64(int(mLen)) { //nolint:gosec // G114 return math.MaxInt32 - tagsize // We have only 32bit int's } - return int(max) + + return int(mLen) //nolint:gosec // G114 } // MaxNonceLength returns the maximum nonce length for a given plaintext length. @@ -90,10 +92,11 @@ func maxlen(l uint8, tagsize int) int { func MaxNonceLength(pdatalen int) int { const tagsize = 16 for L := 2; L <= 8; L++ { - if maxlen(uint8(L), tagsize) >= pdatalen { + if maxlen(uint8(L), tagsize) >= pdatalen { //nolint:gosec // G115 return 15 - L } } + return 0 } @@ -137,20 +140,20 @@ func (c *ccm) tag(nonce, plaintext, adata []byte) ([]byte, error) { c.b.Encrypt(mac[:], mac[:]) var block [ccmBlockSize]byte - if n := uint64(len(adata)); n > 0 { + if adataLength := uint64(len(adata)); adataLength > 0 { //nolint:nestif // First adata block includes adata length i := 2 - if n <= 0xfeff { - binary.BigEndian.PutUint16(block[:i], uint16(n)) + if adataLength <= 0xfeff { + binary.BigEndian.PutUint16(block[:i], uint16(adataLength)) } else { block[0] = 0xfe block[1] = 0xff - if n < uint64(1<<32) { + if adataLength < uint64(1<<32) { i = 2 + 4 - binary.BigEndian.PutUint32(block[2:i], uint32(n)) + binary.BigEndian.PutUint32(block[2:i], uint32(adataLength)) //nolint:gosec // G115 } else { i = 2 + 8 - binary.BigEndian.PutUint64(block[2:i], n) + binary.BigEndian.PutUint64(block[2:i], adataLength) } } i = copy(block[i:], adata) @@ -170,6 +173,7 @@ func (c *ccm) tag(nonce, plaintext, adata []byte) ([]byte, error) { // second slice that aliases into it and contains only the extra bytes. If the // original slice has sufficient capacity then no allocation is performed. // From crypto/cipher/gcm.go +// . func sliceForAppend(in []byte, n int) (head, tail []byte) { if total := len(in) + n; cap(in) >= total { head = in[:total] @@ -178,6 +182,7 @@ func sliceForAppend(in []byte, n int) (head, tail []byte) { copy(head, in) } tail = head[len(in):] + return } @@ -207,6 +212,7 @@ func (c *ccm) Seal(dst, nonce, plaintext, adata []byte) []byte { ret, out := sliceForAppend(dst, len(plaintext)+int(c.M)) stream.XORKeyStream(out, plaintext) copy(out[len(plaintext):], tag) + return ret } @@ -250,5 +256,6 @@ func (c *ccm) Open(dst, nonce, ciphertext, adata []byte) ([]byte, error) { if subtle.ConstantTimeCompare(tag, expectedTag) != 1 { return nil, errOpen } + return append(dst, plaintext...), nil } diff --git a/pkg/crypto/ccm/ccm_test.go b/pkg/crypto/ccm/ccm_test.go index da88f2a05..c8056c721 100644 --- a/pkg/crypto/ccm/ccm_test.go +++ b/pkg/crypto/ccm/ccm_test.go @@ -19,6 +19,7 @@ func mustHexDecode(s string) []byte { if err != nil { panic(err) } + return r } @@ -32,7 +33,7 @@ var ( // ClearHeaderOctets: Input with X cleartext header octets // Data: Input with X cleartext header octets // M: length(CBC-MAC) -// Nonce: Nonce +// Nonce: Nonce. type vector struct { AESKey []byte CipherText []byte @@ -42,7 +43,7 @@ type vector struct { Nonce []byte } -func TestRFC3610Vectors(t *testing.T) { +func TestRFC3610Vectors(t *testing.T) { //nolint:maintidx cases := []vector{ // Vectors 1-12 { @@ -62,8 +63,10 @@ func TestRFC3610Vectors(t *testing.T) { Nonce: mustHexDecode("00000004030201a0a1a2a3a4a5"), }, { - AESKey: aesKey1to12, - CipherText: mustHexDecode("000102030405060751b1e5f44a197d1da46b0f8e2d282ae871e838bb64da8596574adaa76fbd9fb0c5"), + AESKey: aesKey1to12, + CipherText: mustHexDecode( + "000102030405060751b1e5f44a197d1da46b0f8e2d282ae871e838bb64da8596574adaa76fbd9fb0c5", + ), ClearHeaderOctets: 8, Data: mustHexDecode("000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20"), M: 8, @@ -86,56 +89,70 @@ func TestRFC3610Vectors(t *testing.T) { Nonce: mustHexDecode("00000007060504a0a1a2a3a4a5"), }, { - AESKey: aesKey1to12, - CipherText: mustHexDecode("000102030405060708090a0b6fc1b011f006568b5171a42d953d469b2570a4bd87405a0443ac91cb94"), + AESKey: aesKey1to12, + CipherText: mustHexDecode( + "000102030405060708090a0b6fc1b011f006568b5171a42d953d469b2570a4bd87405a0443ac91cb94", + ), ClearHeaderOctets: 12, Data: mustHexDecode("000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20"), M: 8, Nonce: mustHexDecode("00000008070605a0a1a2a3a4a5"), }, { - AESKey: aesKey1to12, - CipherText: mustHexDecode("00010203040506070135d1b2c95f41d5d1d4fec185d166b8094e999dfed96c048c56602c97acbb7490"), + AESKey: aesKey1to12, + CipherText: mustHexDecode( + "00010203040506070135d1b2c95f41d5d1d4fec185d166b8094e999dfed96c048c56602c97acbb7490", + ), ClearHeaderOctets: 8, Data: mustHexDecode("000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e"), M: 10, Nonce: mustHexDecode("00000009080706a0a1a2a3a4a5"), }, { - AESKey: aesKey1to12, - CipherText: mustHexDecode("00010203040506077b75399ac0831dd2f0bbd75879a2fd8f6cae6b6cd9b7db24c17b4433f434963f34b4"), + AESKey: aesKey1to12, + CipherText: mustHexDecode( + "00010203040506077b75399ac0831dd2f0bbd75879a2fd8f6cae6b6cd9b7db24c17b4433f434963f34b4", + ), ClearHeaderOctets: 8, Data: mustHexDecode("000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f"), M: 10, Nonce: mustHexDecode("0000000a090807a0a1a2a3a4a5"), }, { - AESKey: aesKey1to12, - CipherText: mustHexDecode("000102030405060782531a60cc24945a4b8279181ab5c84df21ce7f9b73f42e197ea9c07e56b5eb17e5f4e"), + AESKey: aesKey1to12, + CipherText: mustHexDecode( + "000102030405060782531a60cc24945a4b8279181ab5c84df21ce7f9b73f42e197ea9c07e56b5eb17e5f4e", + ), ClearHeaderOctets: 8, Data: mustHexDecode("000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20"), M: 10, Nonce: mustHexDecode("0000000b0a0908a0a1a2a3a4a5"), }, { - AESKey: aesKey1to12, - CipherText: mustHexDecode("000102030405060708090a0b07342594157785152b074098330abb141b947b566aa9406b4d999988dd"), + AESKey: aesKey1to12, + CipherText: mustHexDecode( + "000102030405060708090a0b07342594157785152b074098330abb141b947b566aa9406b4d999988dd", + ), ClearHeaderOctets: 12, Data: mustHexDecode("000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e"), M: 10, Nonce: mustHexDecode("0000000c0b0a09a0a1a2a3a4a5"), }, { - AESKey: aesKey1to12, - CipherText: mustHexDecode("000102030405060708090a0b676bb20380b0e301e8ab79590a396da78b834934f53aa2e9107a8b6c022c"), + AESKey: aesKey1to12, + CipherText: mustHexDecode( + "000102030405060708090a0b676bb20380b0e301e8ab79590a396da78b834934f53aa2e9107a8b6c022c", + ), ClearHeaderOctets: 12, Data: mustHexDecode("000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f"), M: 10, Nonce: mustHexDecode("0000000d0c0b0aa0a1a2a3a4a5"), }, { - AESKey: aesKey1to12, - CipherText: mustHexDecode("000102030405060708090a0bc0ffa0d6f05bdb67f24d43a4338d2aa4bed7b20e43cd1aa31662e7ad65d6db"), + AESKey: aesKey1to12, + CipherText: mustHexDecode( + "000102030405060708090a0bc0ffa0d6f05bdb67f24d43a4338d2aa4bed7b20e43cd1aa31662e7ad65d6db", + ), ClearHeaderOctets: 12, Data: mustHexDecode("000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20"), M: 10, @@ -159,8 +176,10 @@ func TestRFC3610Vectors(t *testing.T) { Nonce: mustHexDecode("0033568ef7b2633c9696766cfa"), }, { - AESKey: aesKey13to24, - CipherText: mustHexDecode("aa6cfa36cae86b40b1d23a2220ddc0ac900d9aa03c61fcf4a559a4417767089708a776796edb723506"), + AESKey: aesKey13to24, + CipherText: mustHexDecode( + "aa6cfa36cae86b40b1d23a2220ddc0ac900d9aa03c61fcf4a559a4417767089708a776796edb723506", + ), ClearHeaderOctets: 8, Data: mustHexDecode("aa6cfa36cae86b40b916e0eacc1c00d7dcec68ec0b3bbb1a02de8a2d1aa346132e"), M: 8, @@ -183,56 +202,70 @@ func TestRFC3610Vectors(t *testing.T) { Nonce: mustHexDecode("00f8b678094e3b3c9696766cfa"), }, { - AESKey: aesKey13to24, - CipherText: mustHexDecode("cd9044d2b71fdb8120ea60c0009769ecabdf48625594c59251e6035722675e04c847099e5ae0704551"), + AESKey: aesKey13to24, + CipherText: mustHexDecode( + "cd9044d2b71fdb8120ea60c0009769ecabdf48625594c59251e6035722675e04c847099e5ae0704551", + ), ClearHeaderOctets: 12, Data: mustHexDecode("cd9044d2b71fdb8120ea60c06435acbafb11a82e2f071d7ca4a5ebd93a803ba87f"), M: 8, Nonce: mustHexDecode("00d560912d3f703c9696766cfa"), }, { - AESKey: aesKey13to24, - CipherText: mustHexDecode("d85bc7e69f944fb8bc218daa947427b6db386a99ac1aef23ade0b52939cb6a637cf9bec2408897c6ba"), + AESKey: aesKey13to24, + CipherText: mustHexDecode( + "d85bc7e69f944fb8bc218daa947427b6db386a99ac1aef23ade0b52939cb6a637cf9bec2408897c6ba", + ), ClearHeaderOctets: 8, Data: mustHexDecode("d85bc7e69f944fb88a19b950bcf71a018e5e6701c91787659809d67dbedd18"), M: 10, Nonce: mustHexDecode("0042fff8f1951c3c9696766cfa"), }, { - AESKey: aesKey13to24, - CipherText: mustHexDecode("74a0ebc9069f5b375810e6fd25874022e80361a478e3e9cf484ab04f447efff6f0a477cc2fc9bf548944"), + AESKey: aesKey13to24, + CipherText: mustHexDecode( + "74a0ebc9069f5b375810e6fd25874022e80361a478e3e9cf484ab04f447efff6f0a477cc2fc9bf548944", + ), ClearHeaderOctets: 8, Data: mustHexDecode("74a0ebc9069f5b371761433c37c5a35fc1f39f406302eb907c6163be38c98437"), M: 10, Nonce: mustHexDecode("00920f40e56cdc3c9696766cfa"), }, { - AESKey: aesKey13to24, - CipherText: mustHexDecode("44a3aa3aae6475caf2beed7bc5098e83feb5b31608f8e29c38819a89c8e776f1544d4151a4ed3a8b87b9ce"), + AESKey: aesKey13to24, + CipherText: mustHexDecode( + "44a3aa3aae6475caf2beed7bc5098e83feb5b31608f8e29c38819a89c8e776f1544d4151a4ed3a8b87b9ce", + ), ClearHeaderOctets: 8, Data: mustHexDecode("44a3aa3aae6475caa434a8e58500c6e41530538862d686ea9e81301b5ae4226bfa"), M: 10, Nonce: mustHexDecode("0027ca0c7120bc3c9696766cfa"), }, { - AESKey: aesKey13to24, - CipherText: mustHexDecode("ec46bb63b02520c33c49fd7031d750a09da3ed7fddd49a2032aabf17ec8ebf7d22c8088c666be5c197"), + AESKey: aesKey13to24, + CipherText: mustHexDecode( + "ec46bb63b02520c33c49fd7031d750a09da3ed7fddd49a2032aabf17ec8ebf7d22c8088c666be5c197", + ), ClearHeaderOctets: 12, Data: mustHexDecode("ec46bb63b02520c33c49fd70b96b49e21d621741632875db7f6c9243d2d7c2"), M: 10, Nonce: mustHexDecode("005b8ccbcd9af83c9696766cfa"), }, { - AESKey: aesKey13to24, - CipherText: mustHexDecode("47a65ac78b3d594227e85e71e882f1dbd38ce3eda7c23f04dd65071eb41342acdf7e00dccec7ae52987d"), + AESKey: aesKey13to24, + CipherText: mustHexDecode( + "47a65ac78b3d594227e85e71e882f1dbd38ce3eda7c23f04dd65071eb41342acdf7e00dccec7ae52987d", + ), ClearHeaderOctets: 12, Data: mustHexDecode("47a65ac78b3d594227e85e71e2fcfbb880442c731bf95167c8ffd7895e337076"), M: 10, Nonce: mustHexDecode("003ebe94044b9a3c9696766cfa"), }, { - AESKey: aesKey13to24, - CipherText: mustHexDecode("6e37a6ef546d955d34ab6059f32905b88a641b04b9c9ffb58cc390900f3da12ab16dce9e82efa16da62059"), + AESKey: aesKey13to24, + CipherText: mustHexDecode( + "6e37a6ef546d955d34ab6059f32905b88a641b04b9c9ffb58cc390900f3da12ab16dce9e82efa16da62059", + ), ClearHeaderOctets: 12, Data: mustHexDecode("6e37a6ef546d955d34ab6059abf21c0b02feb88f856df4a37381bce3cc128517d4"), M: 10, @@ -245,37 +278,47 @@ func TestRFC3610Vectors(t *testing.T) { t.FailNow() //nolint:revive } - for idx, c := range cases { - c := c + for idx, testCase := range cases { + testCase := testCase t.Run(fmt.Sprintf("packet vector #%d", idx+1), func(t *testing.T) { - blk, err := aes.NewCipher(c.AESKey) + blk, err := aes.NewCipher(testCase.AESKey) if err != nil { t.Fatalf("could not initialize AES block cipher from key: %v", err) } - lccm, err := NewCCM(blk, c.M, len(c.Nonce)) + lccm, err := NewCCM(blk, testCase.M, len(testCase.Nonce)) if err != nil { t.Fatalf("could not create CCM: %v", err) } t.Run("seal", func(t *testing.T) { var dst []byte - dst = lccm.Seal(dst, c.Nonce, c.Data[c.ClearHeaderOctets:], c.Data[:c.ClearHeaderOctets]) - if !bytes.Equal(c.CipherText[c.ClearHeaderOctets:], dst) { + dst = lccm.Seal( + dst, + testCase.Nonce, + testCase.Data[testCase.ClearHeaderOctets:], + testCase.Data[:testCase.ClearHeaderOctets], + ) + if !bytes.Equal(testCase.CipherText[testCase.ClearHeaderOctets:], dst) { t.Fatalf("ciphertext does not match, wanted %v, got %v", - c.CipherText[c.ClearHeaderOctets:], dst) + testCase.CipherText[testCase.ClearHeaderOctets:], dst) } }) t.Run("open", func(t *testing.T) { var dst []byte - dst, err = lccm.Open(dst, c.Nonce, c.CipherText[c.ClearHeaderOctets:], c.CipherText[:c.ClearHeaderOctets]) + dst, err = lccm.Open( + dst, + testCase.Nonce, + testCase.CipherText[testCase.ClearHeaderOctets:], + testCase.CipherText[:testCase.ClearHeaderOctets], + ) if err != nil { t.Fatalf("failed to unseal: %v", err) } - if !bytes.Equal(c.Data[c.ClearHeaderOctets:], dst) { + if !bytes.Equal(testCase.Data[testCase.ClearHeaderOctets:], dst) { t.Fatalf("plaintext does not match, wanted %v, got %v", - c.Data[c.ClearHeaderOctets:], dst) + testCase.Data[testCase.ClearHeaderOctets:], dst) } }) }) @@ -363,21 +406,26 @@ func TestSealError(t *testing.T) { t.Fatalf("could not create CCM: %v", err) } - for name, c := range cases { - c := c + for name, testCase := range cases { + testCase := testCase t.Run(name, func(t *testing.T) { defer func() { err, ok := recover().(error) if !ok { - t.Errorf("expected panic '%v', got '%v'", c.err, err) + t.Errorf("expected panic '%v', got '%v'", testCase.err, err) } - if !errors.Is(err, c.err) { - t.Errorf("expected panic '%v', got '%v'", c.err, err) + if !errors.Is(err, testCase.err) { + t.Errorf("expected panic '%v', got '%v'", testCase.err, err) } }() var dst []byte - _ = lccm.Seal(dst, c.Nonce, c.Data[c.ClearHeaderOctets:], c.Data[:c.ClearHeaderOctets]) + _ = lccm.Seal( + dst, + testCase.Nonce, + testCase.Data[testCase.ClearHeaderOctets:], + testCase.Data[:testCase.ClearHeaderOctets], + ) }) } } diff --git a/pkg/crypto/ciphersuite/cbc.go b/pkg/crypto/ciphersuite/cbc.go index 460fb1437..ab2588f9f 100644 --- a/pkg/crypto/ciphersuite/cbc.go +++ b/pkg/crypto/ciphersuite/cbc.go @@ -11,10 +11,11 @@ import ( //nolint:gci "encoding/binary" "hash" - "github.com/pion/dtls/v2/internal/util" - "github.com/pion/dtls/v2/pkg/crypto/prf" - "github.com/pion/dtls/v2/pkg/protocol" - "github.com/pion/dtls/v2/pkg/protocol/recordlayer" + "github.com/pion/dtls/v3/internal/util" + "github.com/pion/dtls/v3/pkg/crypto/prf" + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" + "golang.org/x/crypto/cryptobyte" ) // block ciphers using cipher block chaining. @@ -23,15 +24,18 @@ type cbcMode interface { SetIV([]byte) } -// CBC Provides an API to Encrypt/Decrypt DTLS 1.2 Packets +// CBC Provides an API to Encrypt/Decrypt DTLS 1.2 Packets. type CBC struct { writeCBC, readCBC cbcMode writeMac, readMac []byte h prf.HashFunc } -// NewCBC creates a DTLS CBC Cipher -func NewCBC(localKey, localWriteIV, localMac, remoteKey, remoteWriteIV, remoteMac []byte, h prf.HashFunc) (*CBC, error) { +// NewCBC creates a DTLS CBC Cipher. +func NewCBC( + localKey, localWriteIV, localMac, remoteKey, remoteWriteIV, remoteMac []byte, + hashFunc prf.HashFunc, +) (*CBC, error) { writeBlock, err := aes.NewCipher(localKey) if err != nil { return nil, err @@ -58,24 +62,30 @@ func NewCBC(localKey, localWriteIV, localMac, remoteKey, remoteWriteIV, remoteMa readCBC: readCBC, readMac: remoteMac, - h: h, + h: hashFunc, }, nil } -// Encrypt encrypt a DTLS RecordLayer message +// Encrypt encrypt a DTLS RecordLayer message. func (c *CBC) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) { - payload := raw[recordlayer.HeaderSize:] - raw = raw[:recordlayer.HeaderSize] + payload := raw[pkt.Header.Size():] + raw = raw[:pkt.Header.Size()] blockSize := c.writeCBC.BlockSize() // Generate + Append MAC h := pkt.Header - MAC, err := c.hmac(h.Epoch, h.SequenceNumber, h.ContentType, h.Version, payload, c.writeMac, c.h) + var err error + var mac []byte + if h.ContentType == protocol.ContentTypeConnectionID { + mac, err = c.hmacCID(h.Epoch, h.SequenceNumber, h.Version, payload, c.writeMac, c.h, h.ConnectionID) + } else { + mac, err = c.hmac(h.Epoch, h.SequenceNumber, h.ContentType, h.Version, payload, c.writeMac, c.h) + } if err != nil { return nil, err } - payload = append(payload, MAC...) + payload = append(payload, mac...) // Generate + Append padding padding := make([]byte, blockSize-len(payload)%blockSize) @@ -94,29 +104,29 @@ func (c *CBC) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) // Set IV + Encrypt + Prepend IV c.writeCBC.SetIV(iv) c.writeCBC.CryptBlocks(payload, payload) - payload = append(iv, payload...) + payload = append(iv, payload...) //nolint:makezero // todo: FIX - // Prepend unencrypte header with encrypted payload + // Prepend unencrypted header with encrypted payload raw = append(raw, payload...) // Update recordLayer size to include IV+MAC+Padding - binary.BigEndian.PutUint16(raw[recordlayer.HeaderSize-2:], uint16(len(raw)-recordlayer.HeaderSize)) + binary.BigEndian.PutUint16(raw[pkt.Header.Size()-2:], uint16(len(raw)-pkt.Header.Size())) //nolint:gosec //G115 return raw, nil } -// Decrypt decrypts a DTLS RecordLayer message -func (c *CBC) Decrypt(in []byte) ([]byte, error) { - body := in[recordlayer.HeaderSize:] +// Decrypt decrypts a DTLS RecordLayer message. +func (c *CBC) Decrypt(header recordlayer.Header, in []byte) ([]byte, error) { blockSize := c.readCBC.BlockSize() mac := c.h() - var h recordlayer.Header - err := h.Unmarshal(in) - switch { - case err != nil: + if err := header.Unmarshal(in); err != nil { return nil, err - case h.ContentType == protocol.ContentTypeChangeCipherSpec: + } + body := in[header.Size():] + + switch { + case header.ContentType == protocol.ContentTypeChangeCipherSpec: // Nothing to encrypt with ChangeCipherSpec return in, nil case len(body)%blockSize != 0 || len(body) < blockSize+util.Max(mac.Size()+1, blockSize): @@ -145,18 +155,35 @@ func (c *CBC) Decrypt(in []byte) ([]byte, error) { dataEnd := len(body) - macSize - paddingLen expectedMAC := body[dataEnd : dataEnd+macSize] - actualMAC, err := c.hmac(h.Epoch, h.SequenceNumber, h.ContentType, h.Version, body[:dataEnd], c.readMac, c.h) - + var err error + var actualMAC []byte + if header.ContentType == protocol.ContentTypeConnectionID { + actualMAC, err = c.hmacCID( + header.Epoch, header.SequenceNumber, header.Version, body[:dataEnd], c.readMac, c.h, header.ConnectionID, + ) + } else { + actualMAC, err = c.hmac( + header.Epoch, header.SequenceNumber, header.ContentType, header.Version, body[:dataEnd], c.readMac, c.h, + ) + } // Compute Local MAC and compare if err != nil || !hmac.Equal(actualMAC, expectedMAC) { return nil, errInvalidMAC } - return append(in[:recordlayer.HeaderSize], body[:dataEnd]...), nil + return append(in[:header.Size()], body[:dataEnd]...), nil } -func (c *CBC) hmac(epoch uint16, sequenceNumber uint64, contentType protocol.ContentType, protocolVersion protocol.Version, payload []byte, key []byte, hf func() hash.Hash) ([]byte, error) { - h := hmac.New(hf, key) +func (c *CBC) hmac( + epoch uint16, + sequenceNumber uint64, + contentType protocol.ContentType, + protocolVersion protocol.Version, + payload []byte, + key []byte, + hf func() hash.Hash, +) ([]byte, error) { + hmacHash := hmac.New(hf, key) msg := make([]byte, 13) @@ -165,13 +192,59 @@ func (c *CBC) hmac(epoch uint16, sequenceNumber uint64, contentType protocol.Con msg[8] = byte(contentType) msg[9] = protocolVersion.Major msg[10] = protocolVersion.Minor - binary.BigEndian.PutUint16(msg[11:], uint16(len(payload))) + binary.BigEndian.PutUint16(msg[11:], uint16(len(payload))) //nolint:gosec //G115 + + if _, err := hmacHash.Write(msg); err != nil { + return nil, err + } + if _, err := hmacHash.Write(payload); err != nil { + return nil, err + } + + return hmacHash.Sum(nil), nil +} - if _, err := h.Write(msg); err != nil { +// hmacCID calculates a MAC according to +// https://datatracker.ietf.org/doc/html/rfc9146#section-5.1 +func (c *CBC) hmacCID( + epoch uint16, + sequenceNumber uint64, + protocolVersion protocol.Version, + payload []byte, + key []byte, + hf func() hash.Hash, + cid []byte, +) ([]byte, error) { + // Must unmarshal inner plaintext in orde to perform MAC. + ip := &recordlayer.InnerPlaintext{} + if err := ip.Unmarshal(payload); err != nil { return nil, err - } else if _, err := h.Write(payload); err != nil { + } + + hmacHash := hmac.New(hf, key) + + var msg cryptobyte.Builder + + msg.AddUint64(seqNumPlaceholder) + msg.AddUint8(uint8(protocol.ContentTypeConnectionID)) + msg.AddUint8(uint8(len(cid))) //nolint:gosec //G115 + msg.AddUint8(uint8(protocol.ContentTypeConnectionID)) + msg.AddUint8(protocolVersion.Major) + msg.AddUint8(protocolVersion.Minor) + msg.AddUint16(epoch) + util.AddUint48(&msg, sequenceNumber) + msg.AddBytes(cid) + msg.AddUint16(uint16(len(payload))) //nolint:gosec //G115 + msg.AddBytes(ip.Content) + msg.AddUint8(uint8(ip.RealType)) + msg.AddBytes(make([]byte, ip.Zeros)) + + if _, err := hmacHash.Write(msg.BytesOrPanic()); err != nil { + return nil, err + } + if _, err := hmacHash.Write(payload); err != nil { return nil, err } - return h.Sum(nil), nil + return hmacHash.Sum(nil), nil } diff --git a/pkg/crypto/ciphersuite/ccm.go b/pkg/crypto/ciphersuite/ccm.go index 24050dc92..9a40cae8f 100644 --- a/pkg/crypto/ciphersuite/ccm.go +++ b/pkg/crypto/ciphersuite/ccm.go @@ -9,29 +9,29 @@ import ( "encoding/binary" "fmt" - "github.com/pion/dtls/v2/pkg/crypto/ccm" - "github.com/pion/dtls/v2/pkg/protocol" - "github.com/pion/dtls/v2/pkg/protocol/recordlayer" + "github.com/pion/dtls/v3/pkg/crypto/ccm" + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) -// CCMTagLen is the length of Authentication Tag +// CCMTagLen is the length of Authentication Tag. type CCMTagLen int -// CCM Enums +// CCM Enums. const ( CCMTagLength8 CCMTagLen = 8 CCMTagLength CCMTagLen = 16 ccmNonceLength = 12 ) -// CCM Provides an API to Encrypt/Decrypt DTLS 1.2 Packets +// CCM Provides an API to Encrypt/Decrypt DTLS 1.2 Packets. type CCM struct { localCCM, remoteCCM ccm.CCM localWriteIV, remoteWriteIV []byte tagLen CCMTagLen } -// NewCCM creates a DTLS GCM Cipher +// NewCCM creates a DTLS GCM Cipher. func NewCCM(tagLen CCMTagLen, localKey, localWriteIV, remoteKey, remoteWriteIV []byte) (*CCM, error) { localBlock, err := aes.NewCipher(localKey) if err != nil { @@ -60,48 +60,60 @@ func NewCCM(tagLen CCMTagLen, localKey, localWriteIV, remoteKey, remoteWriteIV [ }, nil } -// Encrypt encrypt a DTLS RecordLayer message +// Encrypt encrypt a DTLS RecordLayer message. func (c *CCM) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) { - payload := raw[recordlayer.HeaderSize:] - raw = raw[:recordlayer.HeaderSize] + payload := raw[pkt.Header.Size():] + raw = raw[:pkt.Header.Size()] nonce := append(append([]byte{}, c.localWriteIV[:4]...), make([]byte, 8)...) if _, err := rand.Read(nonce[4:]); err != nil { return nil, err } - additionalData := generateAEADAdditionalData(&pkt.Header, len(payload)) + var additionalData []byte + if pkt.Header.ContentType == protocol.ContentTypeConnectionID { + additionalData = generateAEADAdditionalDataCID(&pkt.Header, len(payload)) + } else { + additionalData = generateAEADAdditionalData(&pkt.Header, len(payload)) + } encryptedPayload := c.localCCM.Seal(nil, nonce, payload, additionalData) encryptedPayload = append(nonce[4:], encryptedPayload...) raw = append(raw, encryptedPayload...) // Update recordLayer size to include explicit nonce - binary.BigEndian.PutUint16(raw[recordlayer.HeaderSize-2:], uint16(len(raw)-recordlayer.HeaderSize)) + binary.BigEndian.PutUint16(raw[pkt.Header.Size()-2:], uint16(len(raw)-pkt.Header.Size())) //nolint:gosec //G115 + return raw, nil } -// Decrypt decrypts a DTLS RecordLayer message -func (c *CCM) Decrypt(in []byte) ([]byte, error) { - var h recordlayer.Header - err := h.Unmarshal(in) - switch { - case err != nil: +// Decrypt decrypts a DTLS RecordLayer message. +func (c *CCM) Decrypt(header recordlayer.Header, in []byte) ([]byte, error) { + if err := header.Unmarshal(in); err != nil { return nil, err - case h.ContentType == protocol.ContentTypeChangeCipherSpec: + } + switch { + case header.ContentType == protocol.ContentTypeChangeCipherSpec: // Nothing to encrypt with ChangeCipherSpec return in, nil - case len(in) <= (8 + recordlayer.HeaderSize): + case len(in) <= (8 + header.Size()): return nil, errNotEnoughRoomForNonce } - nonce := append(append([]byte{}, c.remoteWriteIV[:4]...), in[recordlayer.HeaderSize:recordlayer.HeaderSize+8]...) - out := in[recordlayer.HeaderSize+8:] + nonce := append(append([]byte{}, c.remoteWriteIV[:4]...), in[header.Size():header.Size()+8]...) + out := in[header.Size()+8:] - additionalData := generateAEADAdditionalData(&h, len(out)-int(c.tagLen)) + var additionalData []byte + if header.ContentType == protocol.ContentTypeConnectionID { + additionalData = generateAEADAdditionalDataCID(&header, len(out)-int(c.tagLen)) + } else { + additionalData = generateAEADAdditionalData(&header, len(out)-int(c.tagLen)) + } + var err error out, err = c.remoteCCM.Open(out[:0], nonce, out, additionalData) if err != nil { return nil, fmt.Errorf("%w: %v", errDecryptPacket, err) //nolint:errorlint } - return append(in[:recordlayer.HeaderSize], out...), nil + + return append(in[:header.Size()], out...), nil } diff --git a/pkg/crypto/ciphersuite/ciphersuite.go b/pkg/crypto/ciphersuite/ciphersuite.go index 9d9fb7418..5c01de580 100644 --- a/pkg/crypto/ciphersuite/ciphersuite.go +++ b/pkg/crypto/ciphersuite/ciphersuite.go @@ -8,19 +8,32 @@ import ( "encoding/binary" "errors" - "github.com/pion/dtls/v2/pkg/protocol" - "github.com/pion/dtls/v2/pkg/protocol/recordlayer" + "github.com/pion/dtls/v3/internal/util" + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" + "golang.org/x/crypto/cryptobyte" +) + +const ( + // 8 bytes of 0xff. + // https://datatracker.ietf.org/doc/html/rfc9146#name-record-payload-protection + seqNumPlaceholder = 0xffffffffffffffff ) var ( - errNotEnoughRoomForNonce = &protocol.InternalError{Err: errors.New("buffer not long enough to contain nonce")} //nolint:goerr113 - errDecryptPacket = &protocol.TemporaryError{Err: errors.New("failed to decrypt packet")} //nolint:goerr113 - errInvalidMAC = &protocol.TemporaryError{Err: errors.New("invalid mac")} //nolint:goerr113 - errFailedToCast = &protocol.FatalError{Err: errors.New("failed to cast")} //nolint:goerr113 + //nolint:goerr113 + errNotEnoughRoomForNonce = &protocol.InternalError{Err: errors.New("buffer not long enough to contain nonce")} + //nolint:goerr113 + errDecryptPacket = &protocol.TemporaryError{Err: errors.New("failed to decrypt packet")} + //nolint:goerr113 + errInvalidMAC = &protocol.TemporaryError{Err: errors.New("invalid mac")} + //nolint:goerr113 + errFailedToCast = &protocol.FatalError{Err: errors.New("failed to cast")} ) func generateAEADAdditionalData(h *recordlayer.Header, payloadLen int) []byte { var additionalData [13]byte + // SequenceNumber MUST be set first // we only want uint48, clobbering an extra 2 (using uint64, Golang doesn't have uint48) binary.BigEndian.PutUint64(additionalData[:], h.SequenceNumber) @@ -28,11 +41,31 @@ func generateAEADAdditionalData(h *recordlayer.Header, payloadLen int) []byte { additionalData[8] = byte(h.ContentType) additionalData[9] = h.Version.Major additionalData[10] = h.Version.Minor + //nolint:gosec //G115 binary.BigEndian.PutUint16(additionalData[len(additionalData)-2:], uint16(payloadLen)) return additionalData[:] } +// generateAEADAdditionalDataCID generates additional data for AEAD ciphers +// according to https://datatracker.ietf.org/doc/html/rfc9146#name-aead-ciphers +func generateAEADAdditionalDataCID(h *recordlayer.Header, payloadLen int) []byte { + var builder cryptobyte.Builder + + builder.AddUint64(seqNumPlaceholder) + builder.AddUint8(uint8(protocol.ContentTypeConnectionID)) + builder.AddUint8(uint8(len(h.ConnectionID))) //nolint:gosec //G115 + builder.AddUint8(uint8(protocol.ContentTypeConnectionID)) + builder.AddUint8(h.Version.Major) + builder.AddUint8(h.Version.Minor) + builder.AddUint16(h.Epoch) + util.AddUint48(&builder, h.SequenceNumber) + builder.AddBytes(h.ConnectionID) + builder.AddUint16(uint16(payloadLen)) //nolint:gosec //G115 + + return builder.BytesOrPanic() +} + // examinePadding returns, in constant time, the length of the padding to remove // from the end of payload. It also returns a byte which is equal to 255 if the // padding was valid and 0 otherwise. See RFC 2246, Section 6.2.3.2. @@ -44,9 +77,9 @@ func examinePadding(payload []byte) (toRemove int, good byte) { } paddingLen := payload[len(payload)-1] - t := uint(len(payload)-1) - uint(paddingLen) + t := uint(len(payload)-1) - uint(paddingLen) //nolint:gosec //G115 // if len(payload) >= (paddingLen - 1) then the MSB of t is zero - good = byte(int32(^t) >> 31) + good = byte(int32(^t) >> 31) //nolint:gosec //G115 // The maximum possible padding length plus the actual length field toCheck := 256 @@ -56,9 +89,9 @@ func examinePadding(payload []byte) (toRemove int, good byte) { } for i := 0; i < toCheck; i++ { - t := uint(paddingLen) - uint(i) + t := uint(paddingLen) - uint(i) //nolint:gosec //G115 // if i <= paddingLen then the MSB of t is zero - mask := byte(int32(^t) >> 31) + mask := byte(int32(^t) >> 31) //nolint:gosec //G115 b := payload[len(payload)-1-i] good &^= mask&paddingLen ^ mask&b } @@ -68,7 +101,7 @@ func examinePadding(payload []byte) (toRemove int, good byte) { good &= good << 4 good &= good << 2 good &= good << 1 - good = uint8(int8(good) >> 7) + good = uint8(int8(good) >> 7) //nolint:gosec //G115 toRemove = int(paddingLen) + 1 diff --git a/pkg/crypto/ciphersuite/ciphersuite_test.go b/pkg/crypto/ciphersuite/ciphersuite_test.go new file mode 100644 index 000000000..3767d9661 --- /dev/null +++ b/pkg/crypto/ciphersuite/ciphersuite_test.go @@ -0,0 +1,61 @@ +// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> +// SPDX-License-Identifier: MIT + +// Package ciphersuite provides the crypto operations needed for a DTLS CipherSuite +package ciphersuite + +import ( + "bytes" + "testing" + + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" +) + +func TestGenerateAEADAdditionalDataCID(t *testing.T) { + cases := map[string]struct { + reason string + header *recordlayer.Header + payloadLen int + expected []byte + }{ + "WithConnectionID": { + reason: "Should successfully generate additional data with valid header", + header: &recordlayer.Header{ + ContentType: protocol.ContentTypeConnectionID, + ConnectionID: []byte{1, 2, 3, 4, 5, 6, 7, 8}, + Version: protocol.Version1_2, + Epoch: 2, + SequenceNumber: 277, + }, + payloadLen: 1784, + expected: []byte{ + 255, 255, 255, 255, 255, 255, 255, 255, 25, 8, 25, 254, 253, + 0, 2, 0, 0, 0, 0, 1, 21, 1, 2, 3, 4, 5, 6, 7, 8, 6, 248, + }, + }, + "IgnoreContentType": { + reason: "Should use Connection ID content type regardless of header content type.", + header: &recordlayer.Header{ + ContentType: protocol.ContentTypeAlert, + ConnectionID: []byte{1, 2, 3, 4, 5, 6, 7, 8}, + Version: protocol.Version1_2, + Epoch: 2, + SequenceNumber: 277, + }, + payloadLen: 1784, + expected: []byte{ + 255, 255, 255, 255, 255, 255, 255, 255, 25, 8, 25, 254, 253, + 0, 2, 0, 0, 0, 0, 1, 21, 1, 2, 3, 4, 5, 6, 7, 8, 6, 248, + }, + }, + } + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + data := generateAEADAdditionalDataCID(tc.header, tc.payloadLen) + if !bytes.Equal(data, tc.expected) { + t.Errorf("%s\nUnexpected additional data\nwant: %v\ngot: %v", tc.reason, tc.expected, data) + } + }) + } +} diff --git a/pkg/crypto/ciphersuite/gcm.go b/pkg/crypto/ciphersuite/gcm.go index c0fd1f76f..1c50dd967 100644 --- a/pkg/crypto/ciphersuite/gcm.go +++ b/pkg/crypto/ciphersuite/gcm.go @@ -10,8 +10,8 @@ import ( "encoding/binary" "fmt" - "github.com/pion/dtls/v2/pkg/protocol" - "github.com/pion/dtls/v2/pkg/protocol/recordlayer" + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) const ( @@ -19,13 +19,13 @@ const ( gcmNonceLength = 12 ) -// GCM Provides an API to Encrypt/Decrypt DTLS 1.2 Packets +// GCM Provides an API to Encrypt/Decrypt DTLS 1.2 Packets. type GCM struct { localGCM, remoteGCM cipher.AEAD localWriteIV, remoteWriteIV []byte } -// NewGCM creates a DTLS GCM Cipher +// NewGCM creates a DTLS GCM Cipher. func NewGCM(localKey, localWriteIV, remoteKey, remoteWriteIV []byte) (*GCM, error) { localBlock, err := aes.NewCipher(localKey) if err != nil { @@ -53,10 +53,10 @@ func NewGCM(localKey, localWriteIV, remoteKey, remoteWriteIV []byte) (*GCM, erro }, nil } -// Encrypt encrypt a DTLS RecordLayer message +// Encrypt encrypt a DTLS RecordLayer message. func (g *GCM) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) { - payload := raw[recordlayer.HeaderSize:] - raw = raw[:recordlayer.HeaderSize] + payload := raw[pkt.Header.Size():] + raw = raw[:pkt.Header.Size()] nonce := make([]byte, gcmNonceLength) copy(nonce, g.localWriteIV[:4]) @@ -64,7 +64,12 @@ func (g *GCM) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) return nil, err } - additionalData := generateAEADAdditionalData(&pkt.Header, len(payload)) + var additionalData []byte + if pkt.Header.ContentType == protocol.ContentTypeConnectionID { + additionalData = generateAEADAdditionalDataCID(&pkt.Header, len(payload)) + } else { + additionalData = generateAEADAdditionalData(&pkt.Header, len(payload)) + } encryptedPayload := g.localGCM.Seal(nil, nonce, payload, additionalData) r := make([]byte, len(raw)+len(nonce[4:])+len(encryptedPayload)) copy(r, raw) @@ -72,32 +77,38 @@ func (g *GCM) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) copy(r[len(raw)+len(nonce[4:]):], encryptedPayload) // Update recordLayer size to include explicit nonce - binary.BigEndian.PutUint16(r[recordlayer.HeaderSize-2:], uint16(len(r)-recordlayer.HeaderSize)) + binary.BigEndian.PutUint16(r[pkt.Header.Size()-2:], uint16(len(r)-pkt.Header.Size())) //nolint:gosec //G115 + return r, nil } -// Decrypt decrypts a DTLS RecordLayer message -func (g *GCM) Decrypt(in []byte) ([]byte, error) { - var h recordlayer.Header - err := h.Unmarshal(in) +// Decrypt decrypts a DTLS RecordLayer message. +func (g *GCM) Decrypt(header recordlayer.Header, in []byte) ([]byte, error) { + err := header.Unmarshal(in) switch { case err != nil: return nil, err - case h.ContentType == protocol.ContentTypeChangeCipherSpec: + case header.ContentType == protocol.ContentTypeChangeCipherSpec: // Nothing to encrypt with ChangeCipherSpec return in, nil - case len(in) <= (8 + recordlayer.HeaderSize): + case len(in) <= (8 + header.Size()): return nil, errNotEnoughRoomForNonce } nonce := make([]byte, 0, gcmNonceLength) - nonce = append(append(nonce, g.remoteWriteIV[:4]...), in[recordlayer.HeaderSize:recordlayer.HeaderSize+8]...) - out := in[recordlayer.HeaderSize+8:] + nonce = append(append(nonce, g.remoteWriteIV[:4]...), in[header.Size():header.Size()+8]...) + out := in[header.Size()+8:] - additionalData := generateAEADAdditionalData(&h, len(out)-gcmTagLength) + var additionalData []byte + if header.ContentType == protocol.ContentTypeConnectionID { + additionalData = generateAEADAdditionalDataCID(&header, len(out)-gcmTagLength) + } else { + additionalData = generateAEADAdditionalData(&header, len(out)-gcmTagLength) + } out, err = g.remoteGCM.Open(out[:0], nonce, out, additionalData) if err != nil { return nil, fmt.Errorf("%w: %v", errDecryptPacket, err) //nolint:errorlint } - return append(in[:recordlayer.HeaderSize], out...), nil + + return append(in[:header.Size()], out...), nil } diff --git a/pkg/crypto/clientcertificate/client_certificate.go b/pkg/crypto/clientcertificate/client_certificate.go index ddfa39ebe..0a510d4d4 100644 --- a/pkg/crypto/clientcertificate/client_certificate.go +++ b/pkg/crypto/clientcertificate/client_certificate.go @@ -10,13 +10,13 @@ package clientcertificate // https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-2 type Type byte -// ClientCertificateType enums +// ClientCertificateType enums. const ( RSASign Type = 1 ECDSASign Type = 64 ) -// Types returns all valid ClientCertificate Types +// Types returns all valid ClientCertificate Types. func Types() map[Type]bool { return map[Type]bool{ RSASign: true, diff --git a/pkg/crypto/elliptic/elliptic.go b/pkg/crypto/elliptic/elliptic.go index 126523872..b98fb4f9c 100644 --- a/pkg/crypto/elliptic/elliptic.go +++ b/pkg/crypto/elliptic/elliptic.go @@ -20,12 +20,12 @@ var errInvalidNamedCurve = errors.New("invalid named curve") // https://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-9 type CurvePointFormat byte -// CurvePointFormat enums +// CurvePointFormat enums. const ( CurvePointFormatUncompressed CurvePointFormat = 0 ) -// Keypair is a Curve with a Private/Public Keypair +// Keypair is a Curve with a Private/Public Keypair. type Keypair struct { Curve Curve PublicKey []byte @@ -37,12 +37,12 @@ type Keypair struct { // https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-10 type CurveType byte -// CurveType enums +// CurveType enums. const ( CurveTypeNamedCurve CurveType = 0x03 ) -// CurveTypes returns all known curves +// CurveTypes returns all known curves. func CurveTypes() map[CurveType]struct{} { return map[CurveType]struct{}{ CurveTypeNamedCurve: {}, @@ -54,7 +54,7 @@ func CurveTypes() map[CurveType]struct{} { // https://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-8 type Curve uint16 -// Curve enums +// Curve enums. const ( P256 Curve = 0x0017 P384 Curve = 0x0018 @@ -70,10 +70,11 @@ func (c Curve) String() string { case X25519: return "X25519" } + return fmt.Sprintf("%#x", uint16(c)) } -// Curves returns all curves we implement +// Curves returns all curves we implement. func Curves() map[Curve]bool { return map[Curve]bool{ X25519: true, @@ -82,7 +83,7 @@ func Curves() map[Curve]bool { } } -// GenerateKeypair generates a keypair for the given Curve +// GenerateKeypair generates a keypair for the given Curve. func GenerateKeypair(c Curve) (*Keypair, error) { switch c { //nolint:revive case X25519: @@ -95,6 +96,7 @@ func GenerateKeypair(c Curve) (*Keypair, error) { copy(private[:], tmp) curve25519.ScalarBaseMult(&public, &private) + return &Keypair{X25519, public[:], private[:]}, nil case P256: return ellipticCurveKeypair(P256, elliptic.P256(), elliptic.P256()) diff --git a/pkg/crypto/fingerprint/fingerprint.go b/pkg/crypto/fingerprint/fingerprint.go index 7c66265c7..43d4a4000 100644 --- a/pkg/crypto/fingerprint/fingerprint.go +++ b/pkg/crypto/fingerprint/fingerprint.go @@ -16,7 +16,7 @@ var ( errInvalidFingerprintLength = errors.New("fingerprint: invalid fingerprint length") ) -// Fingerprint creates a fingerprint for a certificate using the specified hash algorithm +// Fingerprint creates a fingerprint for a certificate using the specified hash algorithm. func Fingerprint(cert *x509.Certificate, algo crypto.Hash) (string, error) { if !algo.Available() { return "", errHashUnavailable diff --git a/pkg/crypto/fingerprint/fingerprint_test.go b/pkg/crypto/fingerprint/fingerprint_test.go index 3266d1153..0a22d8b6c 100644 --- a/pkg/crypto/fingerprint/fingerprint_test.go +++ b/pkg/crypto/fingerprint/fingerprint_test.go @@ -12,23 +12,29 @@ import ( func TestFingerprint(t *testing.T) { rawCertificate := []byte{ - 0x30, 0x82, 0x01, 0x98, 0x30, 0x82, 0x01, 0x3d, 0xa0, 0x03, 0x02, 0x01, 0x02, 0x02, 0x11, 0x00, 0xa9, 0x91, 0x76, 0x0a, 0xcd, 0x97, 0x4c, 0x36, 0xba, - 0xc9, 0xc2, 0x66, 0x91, 0x47, 0x6c, 0xac, 0x30, 0x0a, 0x06, 0x08, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x04, 0x03, 0x02, 0x30, 0x2b, 0x31, 0x29, 0x30, 0x27, - 0x06, 0x03, 0x55, 0x04, 0x03, 0x13, 0x20, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, - 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x1e, 0x17, 0x0d, 0x31, 0x39, 0x31, 0x31, 0x31, 0x30, 0x30, - 0x39, 0x30, 0x34, 0x32, 0x33, 0x5a, 0x17, 0x0d, 0x31, 0x39, 0x31, 0x32, 0x31, 0x30, 0x30, 0x39, 0x30, 0x34, 0x32, 0x33, 0x5a, 0x30, 0x2b, 0x31, 0x29, - 0x30, 0x27, 0x06, 0x03, 0x55, 0x04, 0x03, 0x13, 0x20, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, - 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x59, 0x30, 0x13, 0x06, 0x07, 0x2a, 0x86, 0x48, - 0xce, 0x3d, 0x02, 0x01, 0x06, 0x08, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x03, 0x01, 0x07, 0x03, 0x42, 0x00, 0x04, 0x9c, 0x12, 0x8e, 0xb5, 0x21, 0x23, 0x9f, - 0x35, 0x5d, 0x39, 0x64, 0xc3, 0x75, 0x81, 0xa4, 0xc8, 0xc8, 0x08, 0x8a, 0xa8, 0x42, 0x30, 0x30, 0x65, 0xb8, 0xb1, 0x3e, 0x4a, 0x51, 0x86, 0xeb, 0xad, - 0x03, 0x02, 0x35, 0x83, 0xc4, 0x19, 0x3a, 0x5b, 0x79, 0x83, 0xec, 0x59, 0x0e, 0x4f, 0x99, 0xb1, 0xd2, 0xf0, 0x50, 0xfa, 0xb8, 0x5f, 0xfc, 0x88, 0xf3, - 0x15, 0xed, 0xb8, 0x14, 0xf0, 0xba, 0xcd, 0xa3, 0x42, 0x30, 0x40, 0x30, 0x0e, 0x06, 0x03, 0x55, 0x1d, 0x0f, 0x01, 0x01, 0xff, 0x04, 0x04, 0x03, 0x02, - 0x05, 0xa0, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x1d, 0x25, 0x04, 0x16, 0x30, 0x14, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x02, 0x06, 0x08, - 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x01, 0x30, 0x0f, 0x06, 0x03, 0x55, 0x1d, 0x13, 0x01, 0x01, 0xff, 0x04, 0x05, 0x30, 0x03, 0x01, 0x01, 0xff, - 0x30, 0x0a, 0x06, 0x08, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x04, 0x03, 0x02, 0x03, 0x49, 0x00, 0x30, 0x46, 0x02, 0x21, 0x00, 0xcd, 0x44, 0xb1, 0xf2, 0x09, - 0xe5, 0xf1, 0xf4, 0xc9, 0x26, 0x95, 0x9a, 0x2d, 0x6d, 0xf3, 0x0c, 0xb8, 0xeb, 0x27, 0x2d, 0x81, 0x19, 0xe9, 0x51, 0xf7, 0xad, 0x64, 0x7d, 0x42, 0x32, - 0x9e, 0xf8, 0x02, 0x21, 0x00, 0xee, 0xad, 0x96, 0x41, 0xf1, 0x12, 0xd0, 0x6b, 0xcd, 0x09, 0xf0, 0x3c, 0x67, 0xb3, 0xdd, 0xed, 0x0a, 0xf1, 0xd8, 0x41, - 0x4f, 0x61, 0xfd, 0x53, 0x1d, 0xf5, 0x27, 0xbe, 0x6d, 0x0b, 0xe2, 0x0d, + 0x30, 0x82, 0x01, 0x98, 0x30, 0x82, 0x01, 0x3d, 0xa0, 0x03, 0x02, 0x01, 0x02, 0x02, 0x11, 0x00, 0xa9, 0x91, + 0x76, 0x0a, 0xcd, 0x97, 0x4c, 0x36, 0xba, 0xc9, 0xc2, 0x66, 0x91, 0x47, 0x6c, 0xac, 0x30, 0x0a, 0x06, 0x08, + 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x04, 0x03, 0x02, 0x30, 0x2b, 0x31, 0x29, 0x30, 0x27, 0x06, 0x03, 0x55, 0x04, + 0x03, 0x13, 0x20, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, + 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, + 0x1e, 0x17, 0x0d, 0x31, 0x39, 0x31, 0x31, 0x31, 0x30, 0x30, 0x39, 0x30, 0x34, 0x32, 0x33, 0x5a, 0x17, 0x0d, + 0x31, 0x39, 0x31, 0x32, 0x31, 0x30, 0x30, 0x39, 0x30, 0x34, 0x32, 0x33, 0x5a, 0x30, 0x2b, 0x31, 0x29, 0x30, + 0x27, 0x06, 0x03, 0x55, 0x04, 0x03, 0x13, 0x20, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, + 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, + 0x30, 0x30, 0x30, 0x30, 0x30, 0x59, 0x30, 0x13, 0x06, 0x07, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x02, 0x01, 0x06, + 0x08, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x03, 0x01, 0x07, 0x03, 0x42, 0x00, 0x04, 0x9c, 0x12, 0x8e, 0xb5, 0x21, + 0x23, 0x9f, 0x35, 0x5d, 0x39, 0x64, 0xc3, 0x75, 0x81, 0xa4, 0xc8, 0xc8, 0x08, 0x8a, 0xa8, 0x42, 0x30, 0x30, + 0x65, 0xb8, 0xb1, 0x3e, 0x4a, 0x51, 0x86, 0xeb, 0xad, 0x03, 0x02, 0x35, 0x83, 0xc4, 0x19, 0x3a, 0x5b, 0x79, + 0x83, 0xec, 0x59, 0x0e, 0x4f, 0x99, 0xb1, 0xd2, 0xf0, 0x50, 0xfa, 0xb8, 0x5f, 0xfc, 0x88, 0xf3, 0x15, 0xed, + 0xb8, 0x14, 0xf0, 0xba, 0xcd, 0xa3, 0x42, 0x30, 0x40, 0x30, 0x0e, 0x06, 0x03, 0x55, 0x1d, 0x0f, 0x01, 0x01, + 0xff, 0x04, 0x04, 0x03, 0x02, 0x05, 0xa0, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x1d, 0x25, 0x04, 0x16, 0x30, 0x14, + 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x02, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, + 0x03, 0x01, 0x30, 0x0f, 0x06, 0x03, 0x55, 0x1d, 0x13, 0x01, 0x01, 0xff, 0x04, 0x05, 0x30, 0x03, 0x01, 0x01, + 0xff, 0x30, 0x0a, 0x06, 0x08, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x04, 0x03, 0x02, 0x03, 0x49, 0x00, 0x30, 0x46, + 0x02, 0x21, 0x00, 0xcd, 0x44, 0xb1, 0xf2, 0x09, 0xe5, 0xf1, 0xf4, 0xc9, 0x26, 0x95, 0x9a, 0x2d, 0x6d, 0xf3, + 0x0c, 0xb8, 0xeb, 0x27, 0x2d, 0x81, 0x19, 0xe9, 0x51, 0xf7, 0xad, 0x64, 0x7d, 0x42, 0x32, 0x9e, 0xf8, 0x02, + 0x21, 0x00, 0xee, 0xad, 0x96, 0x41, 0xf1, 0x12, 0xd0, 0x6b, 0xcd, 0x09, 0xf0, 0x3c, 0x67, 0xb3, 0xdd, 0xed, + 0x0a, 0xf1, 0xd8, 0x41, 0x4f, 0x61, 0xfd, 0x53, 0x1d, 0xf5, 0x27, 0xbe, 0x6d, 0x0b, 0xe2, 0x0d, } cert, err := x509.ParseCertificate(rawCertificate) @@ -36,6 +42,7 @@ func TestFingerprint(t *testing.T) { t.Fatal(err) } + //nolint:lll const expectedSHA256 = "60:ef:f5:79:ad:8d:3e:d7:e8:4d:5a:5a:d6:1e:71:2d:47:52:a5:cb:df:34:37:87:10:a5:4e:d7:2a:2c:37:34" actualSHA256, err := Fingerprint(cert, crypto.SHA256) if err != nil { diff --git a/pkg/crypto/fingerprint/hash.go b/pkg/crypto/fingerprint/hash.go index 3f988ffb7..8aacd673d 100644 --- a/pkg/crypto/fingerprint/hash.go +++ b/pkg/crypto/fingerprint/hash.go @@ -22,11 +22,12 @@ func nameToHash() map[string]crypto.Hash { } } -// HashFromString allows looking up a hash algorithm by it's string representation +// HashFromString allows looking up a hash algorithm by it's string representation. func HashFromString(s string) (crypto.Hash, error) { if h, ok := nameToHash()[strings.ToLower(s)]; ok { return h, nil } + return 0, errInvalidHashAlgorithm } @@ -37,5 +38,6 @@ func StringFromHash(hash crypto.Hash) (string, error) { return s, nil } } + return "", errInvalidHashAlgorithm } diff --git a/pkg/crypto/fingerprint/hash_test.go b/pkg/crypto/fingerprint/hash_test.go index b71a7a363..51a29bfce 100644 --- a/pkg/crypto/fingerprint/hash_test.go +++ b/pkg/crypto/fingerprint/hash_test.go @@ -22,7 +22,7 @@ func TestHashFromString(t *testing.T) { t.Fatalf("Unexpected error for valid hash name, got '%v'", err) } if h != crypto.SHA512 { - t.Errorf("Expected hash ID of %d, got %d", int(crypto.SHA512), int(h)) + t.Errorf("Expected hash ID of %d, got %d", int(crypto.SHA512), int(h)) //nolint:gosec //G115 } }) t.Run("ValidCaseInsensitiveHashAlgorithm", func(t *testing.T) { @@ -31,6 +31,7 @@ func TestHashFromString(t *testing.T) { t.Fatalf("Unexpected error for valid hash name, got '%v'", err) } if h != crypto.SHA512 { + //nolint:gosec // G115 t.Errorf("Expected hash ID of %d, got %d", int(crypto.SHA512), int(h)) } }) diff --git a/pkg/crypto/hash/hash.go b/pkg/crypto/hash/hash.go index 9966626e3..a390170fe 100644 --- a/pkg/crypto/hash/hash.go +++ b/pkg/crypto/hash/hash.go @@ -16,7 +16,7 @@ import ( //nolint:gci // https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-18 type Algorithm uint16 -// Supported hash algorithms +// Supported hash algorithms. const ( None Algorithm = 0 // Blacklisted MD5 Algorithm = 1 // Blacklisted @@ -28,7 +28,7 @@ const ( Ed25519 Algorithm = 8 ) -// String makes hashAlgorithm printable +// String makes hashAlgorithm printable. func (a Algorithm) String() string { switch a { case None: @@ -52,28 +52,34 @@ func (a Algorithm) String() string { } } -// Digest performs a digest on the passed value +// Digest performs a digest on the passed value. func (a Algorithm) Digest(b []byte) []byte { switch a { case None: return nil case MD5: hash := md5.Sum(b) // #nosec + return hash[:] case SHA1: hash := sha1.Sum(b) // #nosec + return hash[:] case SHA224: hash := sha256.Sum224(b) + return hash[:] case SHA256: hash := sha256.Sum256(b) + return hash[:] case SHA384: hash := sha512.Sum384(b) + return hash[:] case SHA512: hash := sha512.Sum512(b) + return hash[:] default: return nil @@ -81,6 +87,7 @@ func (a Algorithm) Digest(b []byte) []byte { } // Insecure returns if the given HashAlgorithm is considered secure in DTLS 1.2 +// . func (a Algorithm) Insecure() bool { switch a { case None, MD5, SHA1: @@ -90,7 +97,7 @@ func (a Algorithm) Insecure() bool { } } -// CryptoHash returns the crypto.Hash implementation for the given HashAlgorithm +// CryptoHash returns the crypto.Hash implementation for the given HashAlgorithm. func (a Algorithm) CryptoHash() crypto.Hash { switch a { case None: @@ -114,7 +121,7 @@ func (a Algorithm) CryptoHash() crypto.Hash { } } -// Algorithms returns all the supported Hash Algorithms +// Algorithms returns all the supported Hash Algorithms. func Algorithms() map[Algorithm]struct{} { return map[Algorithm]struct{}{ None: {}, diff --git a/pkg/crypto/hash/hash_test.go b/pkg/crypto/hash/hash_test.go index e6711c69e..c6ba906ae 100644 --- a/pkg/crypto/hash/hash_test.go +++ b/pkg/crypto/hash/hash_test.go @@ -6,7 +6,7 @@ package hash import ( "testing" - "github.com/pion/dtls/v2/pkg/crypto/fingerprint" + "github.com/pion/dtls/v3/pkg/crypto/fingerprint" ) func TestHashAlgorithm_StringRoundtrip(t *testing.T) { @@ -22,7 +22,10 @@ func TestHashAlgorithm_StringRoundtrip(t *testing.T) { t.Fatalf("fingerprint.HashFromString failed: %v", err) } if hash1 != hash2 { - t.Errorf("Hash algorithm mismatch, input: %d, after roundtrip: %d", int(hash1), int(hash2)) + t.Errorf( + "Hash algorithm mismatch, input: %d, after roundtrip: %d", + int(hash1), int(hash2), //nolint:gosec // G115 + ) } } } diff --git a/pkg/crypto/prf/prf.go b/pkg/crypto/prf/prf.go index 6e7b3ecba..9eace83b7 100644 --- a/pkg/crypto/prf/prf.go +++ b/pkg/crypto/prf/prf.go @@ -13,8 +13,8 @@ import ( //nolint:gci "hash" "math" - "github.com/pion/dtls/v2/pkg/crypto/elliptic" - "github.com/pion/dtls/v2/pkg/protocol" + "github.com/pion/dtls/v3/pkg/crypto/elliptic" + "github.com/pion/dtls/v3/pkg/protocol" "golang.org/x/crypto/curve25519" ) @@ -26,10 +26,10 @@ const ( verifyDataServerLabel = "server finished" ) -// HashFunc allows callers to decide what hash is used in PRF +// HashFunc allows callers to decide what hash is used in PRF. type HashFunc func() hash.Hash -// EncryptionKeys is all the state needed for a TLS CipherSuite +// EncryptionKeys is all the state needed for a TLS CipherSuite. type EncryptionKeys struct { MasterSecret []byte ClientMACKey []byte @@ -68,7 +68,7 @@ func (e *EncryptionKeys) String() string { // // https://tools.ietf.org/html/rfc4279#section-2 func PSKPreMasterSecret(psk []byte) []byte { - pskLen := uint16(len(psk)) + pskLen := uint16(len(psk)) //nolint:gosec // G115 out := append(make([]byte, 2+pskLen+2), psk...) binary.BigEndian.PutUint16(out, pskLen) @@ -89,7 +89,7 @@ func EcdhePSKPreMasterSecret(psk, publicKey, privateKey []byte, curve elliptic.C // write preMasterSecret length offset := 0 - binary.BigEndian.PutUint16(out[offset:], uint16(len(preMasterSecret))) + binary.BigEndian.PutUint16(out[offset:], uint16(len(preMasterSecret))) //nolint:gosec // G115 offset += 2 // write preMasterSecret @@ -97,15 +97,16 @@ func EcdhePSKPreMasterSecret(psk, publicKey, privateKey []byte, curve elliptic.C offset += len(preMasterSecret) // write psk length - binary.BigEndian.PutUint16(out[offset:], uint16(len(psk))) + binary.BigEndian.PutUint16(out[offset:], uint16(len(psk))) //nolint:gosec // G115 offset += 2 // write psk copy(out[offset:], psk) + return out, nil } -// PreMasterSecret implements TLS 1.2 Premaster Secret generation given a keypair and a curve +// PreMasterSecret implements TLS 1.2 Premaster Secret generation given a keypair and a curve. func PreMasterSecret(publicKey, privateKey []byte, curve elliptic.Curve) ([]byte, error) { switch curve { case elliptic.X25519: @@ -129,6 +130,7 @@ func ellipticCurvePreMasterSecret(publicKey, privateKey []byte, c1, c2 ellipticS preMasterSecret := make([]byte, (c2.Params().BitSize+7)>>3) resultBytes := result.Bytes() copy(preMasterSecret[len(preMasterSecret)-len(resultBytes):], resultBytes) + return preMasterSecret, nil } @@ -155,12 +157,13 @@ func ellipticCurvePreMasterSecret(publicKey, privateKey []byte, c1, c2 ellipticS // output data. // // https://tools.ietf.org/html/rfc4346w -func PHash(secret, seed []byte, requestedLength int, h HashFunc) ([]byte, error) { +func PHash(secret, seed []byte, requestedLength int, hashFunc HashFunc) ([]byte, error) { hmacSHA256 := func(key, data []byte) ([]byte, error) { - mac := hmac.New(h, key) + mac := hmac.New(hashFunc, key) if _, err := mac.Write(data); err != nil { return nil, err } + return mac.Sum(nil), nil } @@ -168,7 +171,7 @@ func PHash(secret, seed []byte, requestedLength int, h HashFunc) ([]byte, error) lastRound := seed out := []byte{} - iterations := int(math.Ceil(float64(requestedLength) / float64(h().Size()))) + iterations := int(math.Ceil(float64(requestedLength) / float64(hashFunc().Size()))) for i := 0; i < iterations; i++ { lastRound, err = hmacSHA256(secret, lastRound) if err != nil { @@ -188,18 +191,24 @@ func PHash(secret, seed []byte, requestedLength int, h HashFunc) ([]byte, error) // https://tools.ietf.org/html/rfc7627 func ExtendedMasterSecret(preMasterSecret, sessionHash []byte, h HashFunc) ([]byte, error) { seed := append([]byte(extendedMasterSecretLabel), sessionHash...) + return PHash(preMasterSecret, seed, 48, h) } -// MasterSecret generates a TLS 1.2 MasterSecret +// MasterSecret generates a TLS 1.2 MasterSecret. func MasterSecret(preMasterSecret, clientRandom, serverRandom []byte, h HashFunc) ([]byte, error) { seed := append(append([]byte(masterSecretLabel), clientRandom...), serverRandom...) + return PHash(preMasterSecret, seed, 48, h) } // GenerateEncryptionKeys is the final step TLS 1.2 PRF. Given all state generated so far generates -// the final keys need for encryption -func GenerateEncryptionKeys(masterSecret, clientRandom, serverRandom []byte, macLen, keyLen, ivLen int, h HashFunc) (*EncryptionKeys, error) { +// the final keys need for encryption. +func GenerateEncryptionKeys( + masterSecret, clientRandom, serverRandom []byte, + macLen, keyLen, ivLen int, + h HashFunc, +) (*EncryptionKeys, error) { seed := append(append([]byte(keyExpansionLabel), serverRandom...), clientRandom...) keyMaterial, err := PHash(masterSecret, seed, (2*macLen)+(2*keyLen)+(2*ivLen), h) if err != nil { @@ -241,15 +250,16 @@ func prfVerifyData(masterSecret, handshakeBodies []byte, label string, hashFunc } seed := append([]byte(label), h.Sum(nil)...) + return PHash(masterSecret, seed, 12, hashFunc) } -// VerifyDataClient is caled on the Client Side to either verify or generate the VerifyData message +// VerifyDataClient is caled on the Client Side to either verify or generate the VerifyData message. func VerifyDataClient(masterSecret, handshakeBodies []byte, h HashFunc) ([]byte, error) { return prfVerifyData(masterSecret, handshakeBodies, verifyDataClientLabel, h) } -// VerifyDataServer is caled on the Server Side to either verify or generate the VerifyData message +// VerifyDataServer is caled on the Server Side to either verify or generate the VerifyData message. func VerifyDataServer(masterSecret, handshakeBodies []byte, h HashFunc) ([]byte, error) { return prfVerifyData(masterSecret, handshakeBodies, verifyDataServerLabel, h) } diff --git a/pkg/crypto/prf/prf_test.go b/pkg/crypto/prf/prf_test.go index bd375b9aa..9d2f90451 100644 --- a/pkg/crypto/prf/prf_test.go +++ b/pkg/crypto/prf/prf_test.go @@ -9,13 +9,22 @@ import ( "reflect" "testing" - "github.com/pion/dtls/v2/pkg/crypto/elliptic" + "github.com/pion/dtls/v3/pkg/crypto/elliptic" ) func TestPreMasterSecret(t *testing.T) { - privateKey := []byte{0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f} - publicKey := []byte{0x9f, 0xd7, 0xad, 0x6d, 0xcf, 0xf4, 0x29, 0x8d, 0xd3, 0xf9, 0x6d, 0x5b, 0x1b, 0x2a, 0xf9, 0x10, 0xa0, 0x53, 0x5b, 0x14, 0x88, 0xd7, 0xf8, 0xfa, 0xbb, 0x34, 0x9a, 0x98, 0x28, 0x80, 0xb6, 0x15} - expectedPreMasterSecret := []byte{0xdf, 0x4a, 0x29, 0x1b, 0xaa, 0x1e, 0xb7, 0xcf, 0xa6, 0x93, 0x4b, 0x29, 0xb4, 0x74, 0xba, 0xad, 0x26, 0x97, 0xe2, 0x9f, 0x1f, 0x92, 0x0d, 0xcc, 0x77, 0xc8, 0xa0, 0xa0, 0x88, 0x44, 0x76, 0x24} + privateKey := []byte{ + 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, + 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f, + } + publicKey := []byte{ + 0x9f, 0xd7, 0xad, 0x6d, 0xcf, 0xf4, 0x29, 0x8d, 0xd3, 0xf9, 0x6d, 0x5b, 0x1b, 0x2a, 0xf9, 0x10, + 0xa0, 0x53, 0x5b, 0x14, 0x88, 0xd7, 0xf8, 0xfa, 0xbb, 0x34, 0x9a, 0x98, 0x28, 0x80, 0xb6, 0x15, + } + expectedPreMasterSecret := []byte{ + 0xdf, 0x4a, 0x29, 0x1b, 0xaa, 0x1e, 0xb7, 0xcf, 0xa6, 0x93, 0x4b, 0x29, 0xb4, 0x74, 0xba, 0xad, + 0x26, 0x97, 0xe2, 0x9f, 0x1f, 0x92, 0x0d, 0xcc, 0x77, 0xc8, 0xa0, 0xa0, 0x88, 0x44, 0x76, 0x24, + } preMasterSecret, err := PreMasterSecret(publicKey, privateKey, elliptic.X25519) if err != nil { @@ -26,10 +35,23 @@ func TestPreMasterSecret(t *testing.T) { } func TestMasterSecret(t *testing.T) { - preMasterSecret := []byte{0xdf, 0x4a, 0x29, 0x1b, 0xaa, 0x1e, 0xb7, 0xcf, 0xa6, 0x93, 0x4b, 0x29, 0xb4, 0x74, 0xba, 0xad, 0x26, 0x97, 0xe2, 0x9f, 0x1f, 0x92, 0x0d, 0xcc, 0x77, 0xc8, 0xa0, 0xa0, 0x88, 0x44, 0x76, 0x24} - clientRandom := []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f} - serverRandom := []byte{0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f} - expectedMasterSecret := []byte{0x91, 0x6a, 0xbf, 0x9d, 0xa5, 0x59, 0x73, 0xe1, 0x36, 0x14, 0xae, 0x0a, 0x3f, 0x5d, 0x3f, 0x37, 0xb0, 0x23, 0xba, 0x12, 0x9a, 0xee, 0x02, 0xcc, 0x91, 0x34, 0x33, 0x81, 0x27, 0xcd, 0x70, 0x49, 0x78, 0x1c, 0x8e, 0x19, 0xfc, 0x1e, 0xb2, 0xa7, 0x38, 0x7a, 0xc0, 0x6a, 0xe2, 0x37, 0x34, 0x4c} + preMasterSecret := []byte{ + 0xdf, 0x4a, 0x29, 0x1b, 0xaa, 0x1e, 0xb7, 0xcf, 0xa6, 0x93, 0x4b, 0x29, 0xb4, 0x74, 0xba, 0xad, + 0x26, 0x97, 0xe2, 0x9f, 0x1f, 0x92, 0x0d, 0xcc, 0x77, 0xc8, 0xa0, 0xa0, 0x88, 0x44, 0x76, 0x24, + } + clientRandom := []byte{ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, + } + serverRandom := []byte{ + 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, + 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f, + } + expectedMasterSecret := []byte{ + 0x91, 0x6a, 0xbf, 0x9d, 0xa5, 0x59, 0x73, 0xe1, 0x36, 0x14, 0xae, 0x0a, 0x3f, 0x5d, 0x3f, 0x37, + 0xb0, 0x23, 0xba, 0x12, 0x9a, 0xee, 0x02, 0xcc, 0x91, 0x34, 0x33, 0x81, 0x27, 0xcd, 0x70, 0x49, + 0x78, 0x1c, 0x8e, 0x19, 0xfc, 0x1e, 0xb2, 0xa7, 0x38, 0x7a, 0xc0, 0x6a, 0xe2, 0x37, 0x34, 0x4c, + } masterSecret, err := MasterSecret(preMasterSecret, clientRandom, serverRandom, sha256.New) if err != nil { @@ -40,18 +62,32 @@ func TestMasterSecret(t *testing.T) { } func TestEncryptionKeys(t *testing.T) { - clientRandom := []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f} - serverRandom := []byte{0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f} - masterSecret := []byte{0x91, 0x6a, 0xbf, 0x9d, 0xa5, 0x59, 0x73, 0xe1, 0x36, 0x14, 0xae, 0x0a, 0x3f, 0x5d, 0x3f, 0x37, 0xb0, 0x23, 0xba, 0x12, 0x9a, 0xee, 0x02, 0xcc, 0x91, 0x34, 0x33, 0x81, 0x27, 0xcd, 0x70, 0x49, 0x78, 0x1c, 0x8e, 0x19, 0xfc, 0x1e, 0xb2, 0xa7, 0x38, 0x7a, 0xc0, 0x6a, 0xe2, 0x37, 0x34, 0x4c} + clientRandom := []byte{ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, + } + serverRandom := []byte{ + 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, + 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f, + } + masterSecret := []byte{ + 0x91, 0x6a, 0xbf, 0x9d, 0xa5, 0x59, 0x73, 0xe1, 0x36, 0x14, 0xae, 0x0a, 0x3f, 0x5d, 0x3f, 0x37, + 0xb0, 0x23, 0xba, 0x12, 0x9a, 0xee, 0x02, 0xcc, 0x91, 0x34, 0x33, 0x81, 0x27, 0xcd, 0x70, 0x49, + 0x78, 0x1c, 0x8e, 0x19, 0xfc, 0x1e, 0xb2, 0xa7, 0x38, 0x7a, 0xc0, 0x6a, 0xe2, 0x37, 0x34, 0x4c, + } expectedEncryptionKeys := &EncryptionKeys{ - MasterSecret: masterSecret, - ClientMACKey: []byte{}, - ServerMACKey: []byte{}, - ClientWriteKey: []byte{0x1b, 0x7d, 0x11, 0x7c, 0x7d, 0x5f, 0x69, 0x0b, 0xc2, 0x63, 0xca, 0xe8, 0xef, 0x60, 0xaf, 0x0f}, - ServerWriteKey: []byte{0x18, 0x78, 0xac, 0xc2, 0x2a, 0xd8, 0xbd, 0xd8, 0xc6, 0x01, 0xa6, 0x17, 0x12, 0x6f, 0x63, 0x54}, - ClientWriteIV: []byte{0x0e, 0xb2, 0x09, 0x06}, - ServerWriteIV: []byte{0xf7, 0x81, 0xfa, 0xd2}, + MasterSecret: masterSecret, + ClientMACKey: []byte{}, + ServerMACKey: []byte{}, + ClientWriteKey: []byte{ + 0x1b, 0x7d, 0x11, 0x7c, 0x7d, 0x5f, 0x69, 0x0b, 0xc2, 0x63, 0xca, 0xe8, 0xef, 0x60, 0xaf, 0x0f, + }, + ServerWriteKey: []byte{ + 0x18, 0x78, 0xac, 0xc2, 0x2a, 0xd8, 0xbd, 0xd8, 0xc6, 0x01, 0xa6, 0x17, 0x12, 0x6f, 0x63, 0x54, + }, + ClientWriteIV: []byte{0x0e, 0xb2, 0x09, 0x06}, + ServerWriteIV: []byte{0xf7, 0x81, 0xfa, 0xd2}, } keys, err := GenerateEncryptionKeys(masterSecret, clientRandom, serverRandom, 0, 16, 4, sha256.New) @@ -63,15 +99,123 @@ func TestEncryptionKeys(t *testing.T) { } func TestVerifyData(t *testing.T) { - clientHello := []byte{0x01, 0x00, 0x00, 0xa1, 0x03, 0x03, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x00, 0x00, 0x20, 0xcc, 0xa8, 0xcc, 0xa9, 0xc0, 0x2f, 0xc0, 0x30, 0xc0, 0x2b, 0xc0, 0x2c, 0xc0, 0x13, 0xc0, 0x09, 0xc0, 0x14, 0xc0, 0x0a, 0x00, 0x9c, 0x00, 0x9d, 0x00, 0x2f, 0x00, 0x35, 0xc0, 0x12, 0x00, 0x0a, 0x01, 0x00, 0x00, 0x58, 0x00, 0x00, 0x00, 0x18, 0x00, 0x16, 0x00, 0x00, 0x13, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x75, 0x6c, 0x66, 0x68, 0x65, 0x69, 0x6d, 0x2e, 0x6e, 0x65, 0x74, 0x00, 0x05, 0x00, 0x05, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x0a, 0x00, 0x08, 0x00, 0x1d, 0x00, 0x17, 0x00, 0x18, 0x00, 0x19, 0x00, 0x0b, 0x00, 0x02, 0x01, 0x00, 0x00, 0x0d, 0x00, 0x12, 0x00, 0x10, 0x04, 0x01, 0x04, 0x03, 0x05, 0x01, 0x05, 0x03, 0x06, 0x01, 0x06, 0x03, 0x02, 0x01, 0x02, 0x03, 0xff, 0x01, 0x00, 0x01, 0x00, 0x00, 0x12, 0x00, 0x00} - serverHello := []byte{0x02, 0x00, 0x00, 0x2d, 0x03, 0x03, 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f, 0x00, 0xc0, 0x13, 0x00, 0x00, 0x05, 0xff, 0x01, 0x00, 0x01, 0x00} - serverCertificate := []byte{0x0b, 0x00, 0x03, 0x2b, 0x00, 0x03, 0x28, 0x00, 0x03, 0x25, 0x30, 0x82, 0x03, 0x21, 0x30, 0x82, 0x02, 0x09, 0xa0, 0x03, 0x02, 0x01, 0x02, 0x02, 0x08, 0x15, 0x5a, 0x92, 0xad, 0xc2, 0x04, 0x8f, 0x90, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x0b, 0x05, 0x00, 0x30, 0x22, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13, 0x02, 0x55, 0x53, 0x31, 0x13, 0x30, 0x11, 0x06, 0x03, 0x55, 0x04, 0x0a, 0x13, 0x0a, 0x45, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x20, 0x43, 0x41, 0x30, 0x1e, 0x17, 0x0d, 0x31, 0x38, 0x31, 0x30, 0x30, 0x35, 0x30, 0x31, 0x33, 0x38, 0x31, 0x37, 0x5a, 0x17, 0x0d, 0x31, 0x39, 0x31, 0x30, 0x30, 0x35, 0x30, 0x31, 0x33, 0x38, 0x31, 0x37, 0x5a, 0x30, 0x2b, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13, 0x02, 0x55, 0x53, 0x31, 0x1c, 0x30, 0x1a, 0x06, 0x03, 0x55, 0x04, 0x03, 0x13, 0x13, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x75, 0x6c, 0x66, 0x68, 0x65, 0x69, 0x6d, 0x2e, 0x6e, 0x65, 0x74, 0x30, 0x82, 0x01, 0x22, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x01, 0x05, 0x00, 0x03, 0x82, 0x01, 0x0f, 0x00, 0x30, 0x82, 0x01, 0x0a, 0x02, 0x82, 0x01, 0x01, 0x00, 0xc4, 0x80, 0x36, 0x06, 0xba, 0xe7, 0x47, 0x6b, 0x08, 0x94, 0x04, 0xec, 0xa7, 0xb6, 0x91, 0x04, 0x3f, 0xf7, 0x92, 0xbc, 0x19, 0xee, 0xfb, 0x7d, 0x74, 0xd7, 0xa8, 0x0d, 0x00, 0x1e, 0x7b, 0x4b, 0x3a, 0x4a, 0xe6, 0x0f, 0xe8, 0xc0, 0x71, 0xfc, 0x73, 0xe7, 0x02, 0x4c, 0x0d, 0xbc, 0xf4, 0xbd, 0xd1, 0x1d, 0x39, 0x6b, 0xba, 0x70, 0x46, 0x4a, 0x13, 0xe9, 0x4a, 0xf8, 0x3d, 0xf3, 0xe1, 0x09, 0x59, 0x54, 0x7b, 0xc9, 0x55, 0xfb, 0x41, 0x2d, 0xa3, 0x76, 0x52, 0x11, 0xe1, 0xf3, 0xdc, 0x77, 0x6c, 0xaa, 0x53, 0x37, 0x6e, 0xca, 0x3a, 0xec, 0xbe, 0xc3, 0xaa, 0xb7, 0x3b, 0x31, 0xd5, 0x6c, 0xb6, 0x52, 0x9c, 0x80, 0x98, 0xbc, 0xc9, 0xe0, 0x28, 0x18, 0xe2, 0x0b, 0xf7, 0xf8, 0xa0, 0x3a, 0xfd, 0x17, 0x04, 0x50, 0x9e, 0xce, 0x79, 0xbd, 0x9f, 0x39, 0xf1, 0xea, 0x69, 0xec, 0x47, 0x97, 0x2e, 0x83, 0x0f, 0xb5, 0xca, 0x95, 0xde, 0x95, 0xa1, 0xe6, 0x04, 0x22, 0xd5, 0xee, 0xbe, 0x52, 0x79, 0x54, 0xa1, 0xe7, 0xbf, 0x8a, 0x86, 0xf6, 0x46, 0x6d, 0x0d, 0x9f, 0x16, 0x95, 0x1a, 0x4c, 0xf7, 0xa0, 0x46, 0x92, 0x59, 0x5c, 0x13, 0x52, 0xf2, 0x54, 0x9e, 0x5a, 0xfb, 0x4e, 0xbf, 0xd7, 0x7a, 0x37, 0x95, 0x01, 0x44, 0xe4, 0xc0, 0x26, 0x87, 0x4c, 0x65, 0x3e, 0x40, 0x7d, 0x7d, 0x23, 0x07, 0x44, 0x01, 0xf4, 0x84, 0xff, 0xd0, 0x8f, 0x7a, 0x1f, 0xa0, 0x52, 0x10, 0xd1, 0xf4, 0xf0, 0xd5, 0xce, 0x79, 0x70, 0x29, 0x32, 0xe2, 0xca, 0xbe, 0x70, 0x1f, 0xdf, 0xad, 0x6b, 0x4b, 0xb7, 0x11, 0x01, 0xf4, 0x4b, 0xad, 0x66, 0x6a, 0x11, 0x13, 0x0f, 0xe2, 0xee, 0x82, 0x9e, 0x4d, 0x02, 0x9d, 0xc9, 0x1c, 0xdd, 0x67, 0x16, 0xdb, 0xb9, 0x06, 0x18, 0x86, 0xed, 0xc1, 0xba, 0x94, 0x21, 0x02, 0x03, 0x01, 0x00, 0x01, 0xa3, 0x52, 0x30, 0x50, 0x30, 0x0e, 0x06, 0x03, 0x55, 0x1d, 0x0f, 0x01, 0x01, 0xff, 0x04, 0x04, 0x03, 0x02, 0x05, 0xa0, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x1d, 0x25, 0x04, 0x16, 0x30, 0x14, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x02, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x01, 0x30, 0x1f, 0x06, 0x03, 0x55, 0x1d, 0x23, 0x04, 0x18, 0x30, 0x16, 0x80, 0x14, 0x89, 0x4f, 0xde, 0x5b, 0xcc, 0x69, 0xe2, 0x52, 0xcf, 0x3e, 0xa3, 0x00, 0xdf, 0xb1, 0x97, 0xb8, 0x1d, 0xe1, 0xc1, 0x46, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x0b, 0x05, 0x00, 0x03, 0x82, 0x01, 0x01, 0x00, 0x59, 0x16, 0x45, 0xa6, 0x9a, 0x2e, 0x37, 0x79, 0xe4, 0xf6, 0xdd, 0x27, 0x1a, 0xba, 0x1c, 0x0b, 0xfd, 0x6c, 0xd7, 0x55, 0x99, 0xb5, 0xe7, 0xc3, 0x6e, 0x53, 0x3e, 0xff, 0x36, 0x59, 0x08, 0x43, 0x24, 0xc9, 0xe7, 0xa5, 0x04, 0x07, 0x9d, 0x39, 0xe0, 0xd4, 0x29, 0x87, 0xff, 0xe3, 0xeb, 0xdd, 0x09, 0xc1, 0xcf, 0x1d, 0x91, 0x44, 0x55, 0x87, 0x0b, 0x57, 0x1d, 0xd1, 0x9b, 0xdf, 0x1d, 0x24, 0xf8, 0xbb, 0x9a, 0x11, 0xfe, 0x80, 0xfd, 0x59, 0x2b, 0xa0, 0x39, 0x8c, 0xde, 0x11, 0xe2, 0x65, 0x1e, 0x61, 0x8c, 0xe5, 0x98, 0xfa, 0x96, 0xe5, 0x37, 0x2e, 0xef, 0x3d, 0x24, 0x8a, 0xfd, 0xe1, 0x74, 0x63, 0xeb, 0xbf, 0xab, 0xb8, 0xe4, 0xd1, 0xab, 0x50, 0x2a, 0x54, 0xec, 0x00, 0x64, 0xe9, 0x2f, 0x78, 0x19, 0x66, 0x0d, 0x3f, 0x27, 0xcf, 0x20, 0x9e, 0x66, 0x7f, 0xce, 0x5a, 0xe2, 0xe4, 0xac, 0x99, 0xc7, 0xc9, 0x38, 0x18, 0xf8, 0xb2, 0x51, 0x07, 0x22, 0xdf, 0xed, 0x97, 0xf3, 0x2e, 0x3e, 0x93, 0x49, 0xd4, 0xc6, 0x6c, 0x9e, 0xa6, 0x39, 0x6d, 0x74, 0x44, 0x62, 0xa0, 0x6b, 0x42, 0xc6, 0xd5, 0xba, 0x68, 0x8e, 0xac, 0x3a, 0x01, 0x7b, 0xdd, 0xfc, 0x8e, 0x2c, 0xfc, 0xad, 0x27, 0xcb, 0x69, 0xd3, 0xcc, 0xdc, 0xa2, 0x80, 0x41, 0x44, 0x65, 0xd3, 0xae, 0x34, 0x8c, 0xe0, 0xf3, 0x4a, 0xb2, 0xfb, 0x9c, 0x61, 0x83, 0x71, 0x31, 0x2b, 0x19, 0x10, 0x41, 0x64, 0x1c, 0x23, 0x7f, 0x11, 0xa5, 0xd6, 0x5c, 0x84, 0x4f, 0x04, 0x04, 0x84, 0x99, 0x38, 0x71, 0x2b, 0x95, 0x9e, 0xd6, 0x85, 0xbc, 0x5c, 0x5d, 0xd6, 0x45, 0xed, 0x19, 0x90, 0x94, 0x73, 0x40, 0x29, 0x26, 0xdc, 0xb4, 0x0e, 0x34, 0x69, 0xa1, 0x59, 0x41, 0xe8, 0xe2, 0xcc, 0xa8, 0x4b, 0xb6, 0x08, 0x46, 0x36, 0xa0} - serverKeyExchange := []byte{0x0c, 0x00, 0x01, 0x28, 0x03, 0x00, 0x1d, 0x20, 0x9f, 0xd7, 0xad, 0x6d, 0xcf, 0xf4, 0x29, 0x8d, 0xd3, 0xf9, 0x6d, 0x5b, 0x1b, 0x2a, 0xf9, 0x10, 0xa0, 0x53, 0x5b, 0x14, 0x88, 0xd7, 0xf8, 0xfa, 0xbb, 0x34, 0x9a, 0x98, 0x28, 0x80, 0xb6, 0x15, 0x04, 0x01, 0x01, 0x00, 0x04, 0x02, 0xb6, 0x61, 0xf7, 0xc1, 0x91, 0xee, 0x59, 0xbe, 0x45, 0x37, 0x66, 0x39, 0xbd, 0xc3, 0xd4, 0xbb, 0x81, 0xe1, 0x15, 0xca, 0x73, 0xc8, 0x34, 0x8b, 0x52, 0x5b, 0x0d, 0x23, 0x38, 0xaa, 0x14, 0x46, 0x67, 0xed, 0x94, 0x31, 0x02, 0x14, 0x12, 0xcd, 0x9b, 0x84, 0x4c, 0xba, 0x29, 0x93, 0x4a, 0xaa, 0xcc, 0xe8, 0x73, 0x41, 0x4e, 0xc1, 0x1c, 0xb0, 0x2e, 0x27, 0x2d, 0x0a, 0xd8, 0x1f, 0x76, 0x7d, 0x33, 0x07, 0x67, 0x21, 0xf1, 0x3b, 0xf3, 0x60, 0x20, 0xcf, 0x0b, 0x1f, 0xd0, 0xec, 0xb0, 0x78, 0xde, 0x11, 0x28, 0xbe, 0xba, 0x09, 0x49, 0xeb, 0xec, 0xe1, 0xa1, 0xf9, 0x6e, 0x20, 0x9d, 0xc3, 0x6e, 0x4f, 0xff, 0xd3, 0x6b, 0x67, 0x3a, 0x7d, 0xdc, 0x15, 0x97, 0xad, 0x44, 0x08, 0xe4, 0x85, 0xc4, 0xad, 0xb2, 0xc8, 0x73, 0x84, 0x12, 0x49, 0x37, 0x25, 0x23, 0x80, 0x9e, 0x43, 0x12, 0xd0, 0xc7, 0xb3, 0x52, 0x2e, 0xf9, 0x83, 0xca, 0xc1, 0xe0, 0x39, 0x35, 0xff, 0x13, 0xa8, 0xe9, 0x6b, 0xa6, 0x81, 0xa6, 0x2e, 0x40, 0xd3, 0xe7, 0x0a, 0x7f, 0xf3, 0x58, 0x66, 0xd3, 0xd9, 0x99, 0x3f, 0x9e, 0x26, 0xa6, 0x34, 0xc8, 0x1b, 0x4e, 0x71, 0x38, 0x0f, 0xcd, 0xd6, 0xf4, 0xe8, 0x35, 0xf7, 0x5a, 0x64, 0x09, 0xc7, 0xdc, 0x2c, 0x07, 0x41, 0x0e, 0x6f, 0x87, 0x85, 0x8c, 0x7b, 0x94, 0xc0, 0x1c, 0x2e, 0x32, 0xf2, 0x91, 0x76, 0x9e, 0xac, 0xca, 0x71, 0x64, 0x3b, 0x8b, 0x98, 0xa9, 0x63, 0xdf, 0x0a, 0x32, 0x9b, 0xea, 0x4e, 0xd6, 0x39, 0x7e, 0x8c, 0xd0, 0x1a, 0x11, 0x0a, 0xb3, 0x61, 0xac, 0x5b, 0xad, 0x1c, 0xcd, 0x84, 0x0a, 0x6c, 0x8a, 0x6e, 0xaa, 0x00, 0x1a, 0x9d, 0x7d, 0x87, 0xdc, 0x33, 0x18, 0x64, 0x35, 0x71, 0x22, 0x6c, 0x4d, 0xd2, 0xc2, 0xac, 0x41, 0xfb} + clientHello := []byte{ + 0x01, 0x00, 0x00, 0xa1, 0x03, 0x03, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, + 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, + 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x00, 0x00, 0x20, 0xcc, 0xa8, 0xcc, 0xa9, 0xc0, 0x2f, 0xc0, + 0x30, 0xc0, 0x2b, 0xc0, 0x2c, 0xc0, 0x13, 0xc0, 0x09, 0xc0, 0x14, 0xc0, 0x0a, 0x00, 0x9c, 0x00, + 0x9d, 0x00, 0x2f, 0x00, 0x35, 0xc0, 0x12, 0x00, 0x0a, 0x01, 0x00, 0x00, 0x58, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x16, 0x00, 0x00, 0x13, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x75, 0x6c, + 0x66, 0x68, 0x65, 0x69, 0x6d, 0x2e, 0x6e, 0x65, 0x74, 0x00, 0x05, 0x00, 0x05, 0x01, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x0a, 0x00, 0x0a, 0x00, 0x08, 0x00, 0x1d, 0x00, 0x17, 0x00, 0x18, 0x00, 0x19, + 0x00, 0x0b, 0x00, 0x02, 0x01, 0x00, 0x00, 0x0d, 0x00, 0x12, 0x00, 0x10, 0x04, 0x01, 0x04, 0x03, + 0x05, 0x01, 0x05, 0x03, 0x06, 0x01, 0x06, 0x03, 0x02, 0x01, 0x02, 0x03, 0xff, 0x01, 0x00, 0x01, + 0x00, 0x00, 0x12, 0x00, 0x00, + } + serverHello := []byte{ + 0x02, 0x00, 0x00, 0x2d, 0x03, 0x03, 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, + 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, + 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f, 0x00, 0xc0, 0x13, 0x00, 0x00, 0x05, 0xff, 0x01, 0x00, 0x01, + 0x00, + } + serverCertificate := []byte{ + 0x0b, 0x00, 0x03, 0x2b, 0x00, 0x03, 0x28, 0x00, 0x03, 0x25, 0x30, 0x82, 0x03, 0x21, 0x30, 0x82, + 0x02, 0x09, 0xa0, 0x03, 0x02, 0x01, 0x02, 0x02, 0x08, 0x15, 0x5a, 0x92, 0xad, 0xc2, 0x04, 0x8f, + 0x90, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x0b, 0x05, 0x00, + 0x30, 0x22, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13, 0x02, 0x55, 0x53, 0x31, + 0x13, 0x30, 0x11, 0x06, 0x03, 0x55, 0x04, 0x0a, 0x13, 0x0a, 0x45, 0x78, 0x61, 0x6d, 0x70, 0x6c, + 0x65, 0x20, 0x43, 0x41, 0x30, 0x1e, 0x17, 0x0d, 0x31, 0x38, 0x31, 0x30, 0x30, 0x35, 0x30, 0x31, + 0x33, 0x38, 0x31, 0x37, 0x5a, 0x17, 0x0d, 0x31, 0x39, 0x31, 0x30, 0x30, 0x35, 0x30, 0x31, 0x33, + 0x38, 0x31, 0x37, 0x5a, 0x30, 0x2b, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13, + 0x02, 0x55, 0x53, 0x31, 0x1c, 0x30, 0x1a, 0x06, 0x03, 0x55, 0x04, 0x03, 0x13, 0x13, 0x65, 0x78, + 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x75, 0x6c, 0x66, 0x68, 0x65, 0x69, 0x6d, 0x2e, 0x6e, 0x65, + 0x74, 0x30, 0x82, 0x01, 0x22, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, + 0x01, 0x01, 0x05, 0x00, 0x03, 0x82, 0x01, 0x0f, 0x00, 0x30, 0x82, 0x01, 0x0a, 0x02, 0x82, 0x01, + 0x01, 0x00, 0xc4, 0x80, 0x36, 0x06, 0xba, 0xe7, 0x47, 0x6b, 0x08, 0x94, 0x04, 0xec, 0xa7, 0xb6, + 0x91, 0x04, 0x3f, 0xf7, 0x92, 0xbc, 0x19, 0xee, 0xfb, 0x7d, 0x74, 0xd7, 0xa8, 0x0d, 0x00, 0x1e, + 0x7b, 0x4b, 0x3a, 0x4a, 0xe6, 0x0f, 0xe8, 0xc0, 0x71, 0xfc, 0x73, 0xe7, 0x02, 0x4c, 0x0d, 0xbc, + 0xf4, 0xbd, 0xd1, 0x1d, 0x39, 0x6b, 0xba, 0x70, 0x46, 0x4a, 0x13, 0xe9, 0x4a, 0xf8, 0x3d, 0xf3, + 0xe1, 0x09, 0x59, 0x54, 0x7b, 0xc9, 0x55, 0xfb, 0x41, 0x2d, 0xa3, 0x76, 0x52, 0x11, 0xe1, 0xf3, + 0xdc, 0x77, 0x6c, 0xaa, 0x53, 0x37, 0x6e, 0xca, 0x3a, 0xec, 0xbe, 0xc3, 0xaa, 0xb7, 0x3b, 0x31, + 0xd5, 0x6c, 0xb6, 0x52, 0x9c, 0x80, 0x98, 0xbc, 0xc9, 0xe0, 0x28, 0x18, 0xe2, 0x0b, 0xf7, 0xf8, + 0xa0, 0x3a, 0xfd, 0x17, 0x04, 0x50, 0x9e, 0xce, 0x79, 0xbd, 0x9f, 0x39, 0xf1, 0xea, 0x69, 0xec, + 0x47, 0x97, 0x2e, 0x83, 0x0f, 0xb5, 0xca, 0x95, 0xde, 0x95, 0xa1, 0xe6, 0x04, 0x22, 0xd5, 0xee, + 0xbe, 0x52, 0x79, 0x54, 0xa1, 0xe7, 0xbf, 0x8a, 0x86, 0xf6, 0x46, 0x6d, 0x0d, 0x9f, 0x16, 0x95, + 0x1a, 0x4c, 0xf7, 0xa0, 0x46, 0x92, 0x59, 0x5c, 0x13, 0x52, 0xf2, 0x54, 0x9e, 0x5a, 0xfb, 0x4e, + 0xbf, 0xd7, 0x7a, 0x37, 0x95, 0x01, 0x44, 0xe4, 0xc0, 0x26, 0x87, 0x4c, 0x65, 0x3e, 0x40, 0x7d, + 0x7d, 0x23, 0x07, 0x44, 0x01, 0xf4, 0x84, 0xff, 0xd0, 0x8f, 0x7a, 0x1f, 0xa0, 0x52, 0x10, 0xd1, + 0xf4, 0xf0, 0xd5, 0xce, 0x79, 0x70, 0x29, 0x32, 0xe2, 0xca, 0xbe, 0x70, 0x1f, 0xdf, 0xad, 0x6b, + 0x4b, 0xb7, 0x11, 0x01, 0xf4, 0x4b, 0xad, 0x66, 0x6a, 0x11, 0x13, 0x0f, 0xe2, 0xee, 0x82, 0x9e, + 0x4d, 0x02, 0x9d, 0xc9, 0x1c, 0xdd, 0x67, 0x16, 0xdb, 0xb9, 0x06, 0x18, 0x86, 0xed, 0xc1, 0xba, + 0x94, 0x21, 0x02, 0x03, 0x01, 0x00, 0x01, 0xa3, 0x52, 0x30, 0x50, 0x30, 0x0e, 0x06, 0x03, 0x55, + 0x1d, 0x0f, 0x01, 0x01, 0xff, 0x04, 0x04, 0x03, 0x02, 0x05, 0xa0, 0x30, 0x1d, 0x06, 0x03, 0x55, + 0x1d, 0x25, 0x04, 0x16, 0x30, 0x14, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x02, + 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x01, 0x30, 0x1f, 0x06, 0x03, 0x55, 0x1d, + 0x23, 0x04, 0x18, 0x30, 0x16, 0x80, 0x14, 0x89, 0x4f, 0xde, 0x5b, 0xcc, 0x69, 0xe2, 0x52, 0xcf, + 0x3e, 0xa3, 0x00, 0xdf, 0xb1, 0x97, 0xb8, 0x1d, 0xe1, 0xc1, 0x46, 0x30, 0x0d, 0x06, 0x09, 0x2a, + 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x0b, 0x05, 0x00, 0x03, 0x82, 0x01, 0x01, 0x00, 0x59, + 0x16, 0x45, 0xa6, 0x9a, 0x2e, 0x37, 0x79, 0xe4, 0xf6, 0xdd, 0x27, 0x1a, 0xba, 0x1c, 0x0b, 0xfd, + 0x6c, 0xd7, 0x55, 0x99, 0xb5, 0xe7, 0xc3, 0x6e, 0x53, 0x3e, 0xff, 0x36, 0x59, 0x08, 0x43, 0x24, + 0xc9, 0xe7, 0xa5, 0x04, 0x07, 0x9d, 0x39, 0xe0, 0xd4, 0x29, 0x87, 0xff, 0xe3, 0xeb, 0xdd, 0x09, + 0xc1, 0xcf, 0x1d, 0x91, 0x44, 0x55, 0x87, 0x0b, 0x57, 0x1d, 0xd1, 0x9b, 0xdf, 0x1d, 0x24, 0xf8, + 0xbb, 0x9a, 0x11, 0xfe, 0x80, 0xfd, 0x59, 0x2b, 0xa0, 0x39, 0x8c, 0xde, 0x11, 0xe2, 0x65, 0x1e, + 0x61, 0x8c, 0xe5, 0x98, 0xfa, 0x96, 0xe5, 0x37, 0x2e, 0xef, 0x3d, 0x24, 0x8a, 0xfd, 0xe1, 0x74, + 0x63, 0xeb, 0xbf, 0xab, 0xb8, 0xe4, 0xd1, 0xab, 0x50, 0x2a, 0x54, 0xec, 0x00, 0x64, 0xe9, 0x2f, + 0x78, 0x19, 0x66, 0x0d, 0x3f, 0x27, 0xcf, 0x20, 0x9e, 0x66, 0x7f, 0xce, 0x5a, 0xe2, 0xe4, 0xac, + 0x99, 0xc7, 0xc9, 0x38, 0x18, 0xf8, 0xb2, 0x51, 0x07, 0x22, 0xdf, 0xed, 0x97, 0xf3, 0x2e, 0x3e, + 0x93, 0x49, 0xd4, 0xc6, 0x6c, 0x9e, 0xa6, 0x39, 0x6d, 0x74, 0x44, 0x62, 0xa0, 0x6b, 0x42, 0xc6, + 0xd5, 0xba, 0x68, 0x8e, 0xac, 0x3a, 0x01, 0x7b, 0xdd, 0xfc, 0x8e, 0x2c, 0xfc, 0xad, 0x27, 0xcb, + 0x69, 0xd3, 0xcc, 0xdc, 0xa2, 0x80, 0x41, 0x44, 0x65, 0xd3, 0xae, 0x34, 0x8c, 0xe0, 0xf3, 0x4a, + 0xb2, 0xfb, 0x9c, 0x61, 0x83, 0x71, 0x31, 0x2b, 0x19, 0x10, 0x41, 0x64, 0x1c, 0x23, 0x7f, 0x11, + 0xa5, 0xd6, 0x5c, 0x84, 0x4f, 0x04, 0x04, 0x84, 0x99, 0x38, 0x71, 0x2b, 0x95, 0x9e, 0xd6, 0x85, + 0xbc, 0x5c, 0x5d, 0xd6, 0x45, 0xed, 0x19, 0x90, 0x94, 0x73, 0x40, 0x29, 0x26, 0xdc, 0xb4, 0x0e, + 0x34, 0x69, 0xa1, 0x59, 0x41, 0xe8, 0xe2, 0xcc, 0xa8, 0x4b, 0xb6, 0x08, 0x46, 0x36, 0xa0, + } + serverKeyExchange := []byte{ + 0x0c, 0x00, 0x01, 0x28, 0x03, 0x00, 0x1d, 0x20, 0x9f, 0xd7, 0xad, 0x6d, 0xcf, 0xf4, 0x29, 0x8d, + 0xd3, 0xf9, 0x6d, 0x5b, 0x1b, 0x2a, 0xf9, 0x10, 0xa0, 0x53, 0x5b, 0x14, 0x88, 0xd7, 0xf8, 0xfa, + 0xbb, 0x34, 0x9a, 0x98, 0x28, 0x80, 0xb6, 0x15, 0x04, 0x01, 0x01, 0x00, 0x04, 0x02, 0xb6, 0x61, + 0xf7, 0xc1, 0x91, 0xee, 0x59, 0xbe, 0x45, 0x37, 0x66, 0x39, 0xbd, 0xc3, 0xd4, 0xbb, 0x81, 0xe1, + 0x15, 0xca, 0x73, 0xc8, 0x34, 0x8b, 0x52, 0x5b, 0x0d, 0x23, 0x38, 0xaa, 0x14, 0x46, 0x67, 0xed, + 0x94, 0x31, 0x02, 0x14, 0x12, 0xcd, 0x9b, 0x84, 0x4c, 0xba, 0x29, 0x93, 0x4a, 0xaa, 0xcc, 0xe8, + 0x73, 0x41, 0x4e, 0xc1, 0x1c, 0xb0, 0x2e, 0x27, 0x2d, 0x0a, 0xd8, 0x1f, 0x76, 0x7d, 0x33, 0x07, + 0x67, 0x21, 0xf1, 0x3b, 0xf3, 0x60, 0x20, 0xcf, 0x0b, 0x1f, 0xd0, 0xec, 0xb0, 0x78, 0xde, 0x11, + 0x28, 0xbe, 0xba, 0x09, 0x49, 0xeb, 0xec, 0xe1, 0xa1, 0xf9, 0x6e, 0x20, 0x9d, 0xc3, 0x6e, 0x4f, + 0xff, 0xd3, 0x6b, 0x67, 0x3a, 0x7d, 0xdc, 0x15, 0x97, 0xad, 0x44, 0x08, 0xe4, 0x85, 0xc4, 0xad, + 0xb2, 0xc8, 0x73, 0x84, 0x12, 0x49, 0x37, 0x25, 0x23, 0x80, 0x9e, 0x43, 0x12, 0xd0, 0xc7, 0xb3, + 0x52, 0x2e, 0xf9, 0x83, 0xca, 0xc1, 0xe0, 0x39, 0x35, 0xff, 0x13, 0xa8, 0xe9, 0x6b, 0xa6, 0x81, + 0xa6, 0x2e, 0x40, 0xd3, 0xe7, 0x0a, 0x7f, 0xf3, 0x58, 0x66, 0xd3, 0xd9, 0x99, 0x3f, 0x9e, 0x26, + 0xa6, 0x34, 0xc8, 0x1b, 0x4e, 0x71, 0x38, 0x0f, 0xcd, 0xd6, 0xf4, 0xe8, 0x35, 0xf7, 0x5a, 0x64, + 0x09, 0xc7, 0xdc, 0x2c, 0x07, 0x41, 0x0e, 0x6f, 0x87, 0x85, 0x8c, 0x7b, 0x94, 0xc0, 0x1c, 0x2e, + 0x32, 0xf2, 0x91, 0x76, 0x9e, 0xac, 0xca, 0x71, 0x64, 0x3b, 0x8b, 0x98, 0xa9, 0x63, 0xdf, 0x0a, + 0x32, 0x9b, 0xea, 0x4e, 0xd6, 0x39, 0x7e, 0x8c, 0xd0, 0x1a, 0x11, 0x0a, 0xb3, 0x61, 0xac, 0x5b, + 0xad, 0x1c, 0xcd, 0x84, 0x0a, 0x6c, 0x8a, 0x6e, 0xaa, 0x00, 0x1a, 0x9d, 0x7d, 0x87, 0xdc, 0x33, + 0x18, 0x64, 0x35, 0x71, 0x22, 0x6c, 0x4d, 0xd2, 0xc2, 0xac, 0x41, 0xfb, + } serverHelloDone := []byte{0x0e, 0x00, 0x00, 0x00} - clientKeyExchange := []byte{0x10, 0x00, 0x00, 0x21, 0x20, 0x35, 0x80, 0x72, 0xd6, 0x36, 0x58, 0x80, 0xd1, 0xae, 0xea, 0x32, 0x9a, 0xdf, 0x91, 0x21, 0x38, 0x38, 0x51, 0xed, 0x21, 0xa2, 0x8e, 0x3b, 0x75, 0xe9, 0x65, 0xd0, 0xd2, 0xcd, 0x16, 0x62, 0x54} + clientKeyExchange := []byte{ + 0x10, 0x00, 0x00, 0x21, 0x20, 0x35, 0x80, 0x72, 0xd6, 0x36, 0x58, 0x80, 0xd1, 0xae, 0xea, + 0x32, 0x9a, 0xdf, 0x91, 0x21, 0x38, 0x38, 0x51, 0xed, 0x21, 0xa2, 0x8e, 0x3b, 0x75, 0xe9, + 0x65, 0xd0, 0xd2, 0xcd, 0x16, 0x62, 0x54, + } - finalMsg := append(append(append(append(append(clientHello, serverHello...), serverCertificate...), serverKeyExchange...), serverHelloDone...), clientKeyExchange...) - masterSecret := []byte{0x91, 0x6a, 0xbf, 0x9d, 0xa5, 0x59, 0x73, 0xe1, 0x36, 0x14, 0xae, 0x0a, 0x3f, 0x5d, 0x3f, 0x37, 0xb0, 0x23, 0xba, 0x12, 0x9a, 0xee, 0x02, 0xcc, 0x91, 0x34, 0x33, 0x81, 0x27, 0xcd, 0x70, 0x49, 0x78, 0x1c, 0x8e, 0x19, 0xfc, 0x1e, 0xb2, 0xa7, 0x38, 0x7a, 0xc0, 0x6a, 0xe2, 0x37, 0x34, 0x4c} + finalMsg := append( + append( + append( + append( + append( + clientHello, serverHello..., + ), serverCertificate..., + ), serverKeyExchange..., + ), serverHelloDone..., + ), clientKeyExchange..., + ) + masterSecret := []byte{ + 0x91, 0x6a, 0xbf, 0x9d, 0xa5, 0x59, 0x73, 0xe1, 0x36, 0x14, 0xae, 0x0a, 0x3f, 0x5d, 0x3f, + 0x37, 0xb0, 0x23, 0xba, 0x12, 0x9a, 0xee, 0x02, 0xcc, 0x91, 0x34, 0x33, 0x81, 0x27, 0xcd, + 0x70, 0x49, 0x78, 0x1c, 0x8e, 0x19, 0xfc, 0x1e, 0xb2, 0xa7, 0x38, 0x7a, 0xc0, 0x6a, 0xe2, + 0x37, 0x34, 0x4c, + } expectedVerifyData := []byte{0xcf, 0x91, 0x96, 0x26, 0xf1, 0x36, 0x0c, 0x53, 0x6a, 0xaa, 0xd7, 0x3a} verifyData, err := VerifyDataClient(masterSecret, finalMsg, sha256.New) diff --git a/pkg/crypto/selfsign/selfsign.go b/pkg/crypto/selfsign/selfsign.go index 6ef016724..b3bf850a4 100644 --- a/pkg/crypto/selfsign/selfsign.go +++ b/pkg/crypto/selfsign/selfsign.go @@ -21,7 +21,7 @@ import ( var errInvalidPrivateKey = errors.New("selfsign: invalid private key type") -// GenerateSelfSigned creates a self-signed certificate +// GenerateSelfSigned creates a self-signed certificate. func GenerateSelfSigned() (tls.Certificate, error) { priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { @@ -31,7 +31,7 @@ func GenerateSelfSigned() (tls.Certificate, error) { return SelfSign(priv) } -// GenerateSelfSignedWithDNS creates a self-signed certificate +// GenerateSelfSignedWithDNS creates a self-signed certificate. func GenerateSelfSignedWithDNS(cn string, sans ...string) (tls.Certificate, error) { priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { @@ -41,25 +41,30 @@ func GenerateSelfSignedWithDNS(cn string, sans ...string) (tls.Certificate, erro return WithDNS(priv, cn, sans...) } -// SelfSign creates a self-signed certificate from a elliptic curve key +// SelfSign creates a self-signed certificate from a elliptic curve key. func SelfSign(key crypto.PrivateKey) (tls.Certificate, error) { return WithDNS(key, "self-signed cert") } -// WithDNS creates a self-signed certificate from a elliptic curve key +// WithDNS creates a self-signed certificate from a elliptic curve key. func WithDNS(key crypto.PrivateKey, cn string, sans ...string) (tls.Certificate, error) { var ( pubKey crypto.PublicKey maxBigInt = new(big.Int) // Max random value, a 130-bits integer, i.e 2^130 - 1 ) - switch k := key.(type) { - case ed25519.PrivateKey: - pubKey = k.Public() - case *ecdsa.PrivateKey: - pubKey = k.Public() - case *rsa.PrivateKey: - pubKey = k.Public() + signer, ok := key.(crypto.Signer) + if !ok { + return tls.Certificate{}, errInvalidPrivateKey + } + + switch k := signer.Public().(type) { + case ed25519.PublicKey: + pubKey = k + case *ecdsa.PublicKey: + pubKey = k + case *rsa.PublicKey: + pubKey = k default: return tls.Certificate{}, errInvalidPrivateKey } @@ -76,7 +81,7 @@ func WithDNS(key crypto.PrivateKey, cn string, sans ...string) (tls.Certificate, names = append(names, sans...) keyUsage := x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign - if _, isRSA := key.(*rsa.PrivateKey); isRSA { + if _, isRSA := signer.Public().(*rsa.PublicKey); isRSA { keyUsage |= x509.KeyUsageKeyEncipherment } @@ -98,7 +103,7 @@ func WithDNS(key crypto.PrivateKey, cn string, sans ...string) (tls.Certificate, }, } - raw, err := x509.CreateCertificate(rand.Reader, &template, &template, pubKey, key) + raw, err := x509.CreateCertificate(rand.Reader, &template, &template, pubKey, signer) if err != nil { return tls.Certificate{}, err } @@ -110,7 +115,7 @@ func WithDNS(key crypto.PrivateKey, cn string, sans ...string) (tls.Certificate, return tls.Certificate{ Certificate: [][]byte{raw}, - PrivateKey: key, + PrivateKey: signer, Leaf: leaf, }, nil } diff --git a/pkg/crypto/signature/signature.go b/pkg/crypto/signature/signature.go index fec7fba3b..53fb7c952 100644 --- a/pkg/crypto/signature/signature.go +++ b/pkg/crypto/signature/signature.go @@ -8,7 +8,7 @@ package signature // https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-16 type Algorithm uint16 -// SignatureAlgorithm enums +// SignatureAlgorithm enums. const ( Anonymous Algorithm = 0 RSA Algorithm = 1 @@ -16,7 +16,7 @@ const ( Ed25519 Algorithm = 7 ) -// Algorithms returns all implemented Signature Algorithms +// Algorithms returns all implemented Signature Algorithms. func Algorithms() map[Algorithm]struct{} { return map[Algorithm]struct{}{ Anonymous: {}, diff --git a/pkg/crypto/signaturehash/errors.go b/pkg/crypto/signaturehash/errors.go index 4aeb3e40a..2e2b72bb8 100644 --- a/pkg/crypto/signaturehash/errors.go +++ b/pkg/crypto/signaturehash/errors.go @@ -9,4 +9,5 @@ var ( errNoAvailableSignatureSchemes = errors.New("connection can not be created, no SignatureScheme satisfy this Config") errInvalidSignatureAlgorithm = errors.New("invalid signature algorithm") errInvalidHashAlgorithm = errors.New("invalid hash algorithm") + errInvalidPrivateKey = errors.New("invalid private key type") ) diff --git a/pkg/crypto/signaturehash/signaturehash.go b/pkg/crypto/signaturehash/signaturehash.go index 2561accd1..2c72f8ca8 100644 --- a/pkg/crypto/signaturehash/signaturehash.go +++ b/pkg/crypto/signaturehash/signaturehash.go @@ -12,8 +12,8 @@ import ( "crypto/tls" "fmt" - "github.com/pion/dtls/v2/pkg/crypto/hash" - "github.com/pion/dtls/v2/pkg/crypto/signature" + "github.com/pion/dtls/v3/pkg/crypto/hash" + "github.com/pion/dtls/v3/pkg/crypto/signature" ) // Algorithm is a signature/hash algorithm pairs which may be used in @@ -25,7 +25,7 @@ type Algorithm struct { Signature signature.Algorithm } -// Algorithms are all the know SignatureHash Algorithms +// Algorithms are all the know SignatureHash Algorithms. func Algorithms() []Algorithm { return []Algorithm{ {hash.SHA256, signature.ECDSA}, @@ -40,22 +40,27 @@ func Algorithms() []Algorithm { // SelectSignatureScheme returns most preferred and compatible scheme. func SelectSignatureScheme(sigs []Algorithm, privateKey crypto.PrivateKey) (Algorithm, error) { + signer, ok := privateKey.(crypto.Signer) + if !ok { + return Algorithm{}, errInvalidPrivateKey + } for _, ss := range sigs { - if ss.isCompatible(privateKey) { + if ss.isCompatible(signer) { return ss, nil } } + return Algorithm{}, errNoAvailableSignatureSchemes } // isCompatible checks that given private key is compatible with the signature scheme. -func (a *Algorithm) isCompatible(privateKey crypto.PrivateKey) bool { - switch privateKey.(type) { - case ed25519.PrivateKey: +func (a *Algorithm) isCompatible(signer crypto.Signer) bool { + switch signer.Public().(type) { + case ed25519.PublicKey: return a.Signature == signature.Ed25519 - case *ecdsa.PrivateKey: + case *ecdsa.PublicKey: return a.Signature == signature.ECDSA - case *rsa.PrivateKey: + case *rsa.PublicKey: return a.Signature == signature.RSA default: return false diff --git a/pkg/crypto/signaturehash/signaturehash_test.go b/pkg/crypto/signaturehash/signaturehash_test.go index 6df4dbcab..54afce324 100644 --- a/pkg/crypto/signaturehash/signaturehash_test.go +++ b/pkg/crypto/signaturehash/signaturehash_test.go @@ -9,8 +9,8 @@ import ( "reflect" "testing" - "github.com/pion/dtls/v2/pkg/crypto/hash" - "github.com/pion/dtls/v2/pkg/crypto/signature" + "github.com/pion/dtls/v3/pkg/crypto/hash" + "github.com/pion/dtls/v3/pkg/crypto/signature" ) func TestParseSignatureSchemes(t *testing.T) { diff --git a/pkg/net/net.go b/pkg/net/net.go new file mode 100644 index 000000000..3db604777 --- /dev/null +++ b/pkg/net/net.go @@ -0,0 +1,111 @@ +// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> +// SPDX-License-Identifier: MIT + +// Package net defines packet-oriented primitives that are compatible with net +// in the standard library. +package net + +import ( + "net" + "time" +) + +// A PacketListener is the same as net.Listener but returns a net.PacketConn on +// Accept() rather than a net.Conn. +// +// Multiple goroutines may invoke methods on a PacketListener simultaneously. +type PacketListener interface { + // Accept waits for and returns the next connection to the listener. + Accept() (net.PacketConn, net.Addr, error) + + // Close closes the listener. + // Any blocked Accept operations will be unblocked and return errors. + Close() error + + // Addr returns the listener's network address. + Addr() net.Addr +} + +// PacketListenerFromListener converts a net.Listener into a +// dtlsnet.PacketListener. +func PacketListenerFromListener(l net.Listener) PacketListener { + return &packetListenerWrapper{ + l: l, + } +} + +// packetListenerWrapper wraps a net.Listener and implements +// dtlsnet.PacketListener. +type packetListenerWrapper struct { + l net.Listener +} + +// Accept calls Accept on the underlying net.Listener and converts the returned +// net.Conn into a net.PacketConn. +func (p *packetListenerWrapper) Accept() (net.PacketConn, net.Addr, error) { + c, err := p.l.Accept() + if err != nil { + return PacketConnFromConn(c), nil, err + } + + return PacketConnFromConn(c), c.RemoteAddr(), nil +} + +// Close closes the underlying net.Listener. +func (p *packetListenerWrapper) Close() error { + return p.l.Close() +} + +// Addr returns the address of the underlying net.Listener. +func (p *packetListenerWrapper) Addr() net.Addr { + return p.l.Addr() +} + +// PacketConnFromConn converts a net.Conn into a net.PacketConn. +func PacketConnFromConn(conn net.Conn) net.PacketConn { + return &packetConnWrapper{conn} +} + +// packetConnWrapper wraps a net.Conn and implements net.PacketConn. +type packetConnWrapper struct { + conn net.Conn +} + +// ReadFrom reads from the underlying net.Conn and returns its remote address. +func (p *packetConnWrapper) ReadFrom(b []byte) (int, net.Addr, error) { + n, err := p.conn.Read(b) + + return n, p.conn.RemoteAddr(), err +} + +// WriteTo writes to the underlying net.Conn. +func (p *packetConnWrapper) WriteTo(b []byte, _ net.Addr) (int, error) { + n, err := p.conn.Write(b) + + return n, err +} + +// Close closes the underlying net.Conn. +func (p *packetConnWrapper) Close() error { + return p.conn.Close() +} + +// LocalAddr returns the local address of the underlying net.Conn. +func (p *packetConnWrapper) LocalAddr() net.Addr { + return p.conn.LocalAddr() +} + +// SetDeadline sets the deadline on the underlying net.Conn. +func (p *packetConnWrapper) SetDeadline(t time.Time) error { + return p.conn.SetDeadline(t) +} + +// SetReadDeadline sets the read deadline on the underlying net.Conn. +func (p *packetConnWrapper) SetReadDeadline(t time.Time) error { + return p.conn.SetReadDeadline(t) +} + +// SetWriteDeadline sets the write deadline on the underlying net.Conn. +func (p *packetConnWrapper) SetWriteDeadline(t time.Time) error { + return p.conn.SetWriteDeadline(t) +} diff --git a/pkg/protocol/alert/alert.go b/pkg/protocol/alert/alert.go index 91e9f4d60..8fac65962 100644 --- a/pkg/protocol/alert/alert.go +++ b/pkg/protocol/alert/alert.go @@ -8,15 +8,15 @@ import ( "errors" "fmt" - "github.com/pion/dtls/v2/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol" ) var errBufferTooSmall = &protocol.TemporaryError{Err: errors.New("buffer is too small")} //nolint:goerr113 -// Level is the level of the TLS Alert +// Level is the level of the TLS Alert. type Level byte -// Level enums +// Level enums. const ( Warning Level = 1 Fatal Level = 2 @@ -33,10 +33,10 @@ func (l Level) String() string { } } -// Description is the extended info of the TLS Alert +// Description is the extended info of the TLS Alert. type Description byte -// Description enums +// Description enums. const ( CloseNotify Description = 0 UnexpectedMessage Description = 10 @@ -66,7 +66,7 @@ const ( NoApplicationProtocol Description = 120 ) -func (d Description) String() string { +func (d Description) String() string { //nolint:cyclop switch d { case CloseNotify: return "CloseNotify" @@ -140,17 +140,17 @@ type Alert struct { Description Description } -// ContentType returns the ContentType of this Content +// ContentType returns the ContentType of this Content. func (a Alert) ContentType() protocol.ContentType { return protocol.ContentTypeAlert } -// Marshal returns the encoded alert +// Marshal returns the encoded alert. func (a *Alert) Marshal() ([]byte, error) { return []byte{byte(a.Level), byte(a.Description)}, nil } -// Unmarshal populates the alert from binary data +// Unmarshal populates the alert from binary data. func (a *Alert) Unmarshal(data []byte) error { if len(data) != 2 { return errBufferTooSmall @@ -158,6 +158,7 @@ func (a *Alert) Unmarshal(data []byte) error { a.Level = Level(data[0]) a.Description = Description(data[1]) + return nil } diff --git a/pkg/protocol/application_data.go b/pkg/protocol/application_data.go index f42211511..f478c4231 100644 --- a/pkg/protocol/application_data.go +++ b/pkg/protocol/application_data.go @@ -12,18 +12,19 @@ type ApplicationData struct { Data []byte } -// ContentType returns the ContentType of this content +// ContentType returns the ContentType of this content. func (a ApplicationData) ContentType() ContentType { return ContentTypeApplicationData } -// Marshal encodes the ApplicationData to binary +// Marshal encodes the ApplicationData to binary. func (a *ApplicationData) Marshal() ([]byte, error) { return append([]byte{}, a.Data...), nil } -// Unmarshal populates the ApplicationData from binary +// Unmarshal populates the ApplicationData from binary. func (a *ApplicationData) Unmarshal(data []byte) error { a.Data = append([]byte{}, data...) + return nil } diff --git a/pkg/protocol/change_cipher_spec.go b/pkg/protocol/change_cipher_spec.go index 87f28bc37..4813cd564 100644 --- a/pkg/protocol/change_cipher_spec.go +++ b/pkg/protocol/change_cipher_spec.go @@ -10,17 +10,17 @@ package protocol // https://tools.ietf.org/html/rfc5246#section-7.1 type ChangeCipherSpec struct{} -// ContentType returns the ContentType of this content +// ContentType returns the ContentType of this content. func (c ChangeCipherSpec) ContentType() ContentType { return ContentTypeChangeCipherSpec } -// Marshal encodes the ChangeCipherSpec to binary +// Marshal encodes the ChangeCipherSpec to binary. func (c *ChangeCipherSpec) Marshal() ([]byte, error) { return []byte{0x01}, nil } -// Unmarshal populates the ChangeCipherSpec from binary +// Unmarshal populates the ChangeCipherSpec from binary. func (c *ChangeCipherSpec) Unmarshal(data []byte) error { if len(data) == 1 && data[0] == 0x01 { return nil diff --git a/pkg/protocol/compression_method.go b/pkg/protocol/compression_method.go index 3478ee38c..0fb99a51b 100644 --- a/pkg/protocol/compression_method.go +++ b/pkg/protocol/compression_method.go @@ -3,26 +3,26 @@ package protocol -// CompressionMethodID is the ID for a CompressionMethod +// CompressionMethodID is the ID for a CompressionMethod. type CompressionMethodID byte const ( compressionMethodNull CompressionMethodID = 0 ) -// CompressionMethod represents a TLS Compression Method +// CompressionMethod represents a TLS Compression Method. type CompressionMethod struct { ID CompressionMethodID } -// CompressionMethods returns all supported CompressionMethods +// CompressionMethods returns all supported CompressionMethods. func CompressionMethods() map[CompressionMethodID]*CompressionMethod { return map[CompressionMethodID]*CompressionMethod{ compressionMethodNull: {ID: compressionMethodNull}, } } -// DecodeCompressionMethods the given compression methods +// DecodeCompressionMethods the given compression methods. func DecodeCompressionMethods(buf []byte) ([]*CompressionMethod, error) { if len(buf) < 1 { return nil, errBufferTooSmall @@ -38,14 +38,16 @@ func DecodeCompressionMethods(buf []byte) ([]*CompressionMethod, error) { c = append(c, compressionMethod) } } + return c, nil } -// EncodeCompressionMethods the given compression methods +// EncodeCompressionMethods the given compression methods. func EncodeCompressionMethods(c []*CompressionMethod) []byte { out := []byte{byte(len(c))} for i := len(c); i > 0; i-- { out = append(out, byte(c[i-1].ID)) } + return out } diff --git a/pkg/protocol/content.go b/pkg/protocol/content.go index 92c9db2bf..9b6daa51f 100644 --- a/pkg/protocol/content.go +++ b/pkg/protocol/content.go @@ -8,15 +8,16 @@ package protocol // https://tools.ietf.org/html/rfc4346#section-6.2.1 type ContentType uint8 -// ContentType enums +// ContentType enums. const ( ContentTypeChangeCipherSpec ContentType = 20 ContentTypeAlert ContentType = 21 ContentTypeHandshake ContentType = 22 ContentTypeApplicationData ContentType = 23 + ContentTypeConnectionID ContentType = 25 ) -// Content is the top level distinguisher for a DTLS Datagram +// Content is the top level distinguisher for a DTLS Datagram. type Content interface { ContentType() ContentType Marshal() ([]byte, error) diff --git a/pkg/protocol/errors.go b/pkg/protocol/errors.go index d87aff7fb..dc091bff3 100644 --- a/pkg/protocol/errors.go +++ b/pkg/protocol/errors.go @@ -20,7 +20,8 @@ type FatalError struct { Err error } -// InternalError indicates and internal error caused by the implementation, and the DTLS connection is no longer available. +// InternalError indicates and internal error caused by the implementation, +// and the DTLS connection is no longer available. // It is mainly caused by bugs or tried to use unimplemented features. type InternalError struct { Err error @@ -41,10 +42,10 @@ type HandshakeError struct { Err error } -// Timeout implements net.Error.Timeout() +// Timeout implements net.Error.Timeout(). func (*FatalError) Timeout() bool { return false } -// Temporary implements net.Error.Temporary() +// Temporary implements net.Error.Temporary(). func (*FatalError) Temporary() bool { return false } // Unwrap implements Go1.13 error unwrapper. @@ -52,10 +53,10 @@ func (e *FatalError) Unwrap() error { return e.Err } func (e *FatalError) Error() string { return fmt.Sprintf("dtls fatal: %v", e.Err) } -// Timeout implements net.Error.Timeout() +// Timeout implements net.Error.Timeout(). func (*InternalError) Timeout() bool { return false } -// Temporary implements net.Error.Temporary() +// Temporary implements net.Error.Temporary(). func (*InternalError) Temporary() bool { return false } // Unwrap implements Go1.13 error unwrapper. @@ -63,10 +64,10 @@ func (e *InternalError) Unwrap() error { return e.Err } func (e *InternalError) Error() string { return fmt.Sprintf("dtls internal: %v", e.Err) } -// Timeout implements net.Error.Timeout() +// Timeout implements net.Error.Timeout(). func (*TemporaryError) Timeout() bool { return false } -// Temporary implements net.Error.Temporary() +// Temporary implements net.Error.Temporary(). func (*TemporaryError) Temporary() bool { return true } // Unwrap implements Go1.13 error unwrapper. @@ -74,10 +75,10 @@ func (e *TemporaryError) Unwrap() error { return e.Err } func (e *TemporaryError) Error() string { return fmt.Sprintf("dtls temporary: %v", e.Err) } -// Timeout implements net.Error.Timeout() +// Timeout implements net.Error.Timeout(). func (*TimeoutError) Timeout() bool { return true } -// Temporary implements net.Error.Temporary() +// Temporary implements net.Error.Temporary(). func (*TimeoutError) Temporary() bool { return true } // Unwrap implements Go1.13 error unwrapper. @@ -85,21 +86,23 @@ func (e *TimeoutError) Unwrap() error { return e.Err } func (e *TimeoutError) Error() string { return fmt.Sprintf("dtls timeout: %v", e.Err) } -// Timeout implements net.Error.Timeout() +// Timeout implements net.Error.Timeout(). func (e *HandshakeError) Timeout() bool { var netErr net.Error if errors.As(e.Err, &netErr) { return netErr.Timeout() } + return false } -// Temporary implements net.Error.Temporary() +// Temporary implements net.Error.Temporary(). func (e *HandshakeError) Temporary() bool { var netErr net.Error if errors.As(e.Err, &netErr) { return netErr.Temporary() //nolint } + return false } diff --git a/pkg/protocol/extension/alpn.go b/pkg/protocol/extension/alpn.go index e780dc9e1..719428601 100644 --- a/pkg/protocol/extension/alpn.go +++ b/pkg/protocol/extension/alpn.go @@ -15,16 +15,16 @@ type ALPN struct { ProtocolNameList []string } -// TypeValue returns the extension TypeValue +// TypeValue returns the extension TypeValue. func (a ALPN) TypeValue() TypeValue { return ALPNTypeValue } -// Marshal encodes the extension +// Marshal encodes the extension. func (a *ALPN) Marshal() ([]byte, error) { - var b cryptobyte.Builder - b.AddUint16(uint16(a.TypeValue())) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + var builder cryptobyte.Builder + builder.AddUint16(uint16(a.TypeValue())) + builder.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { for _, proto := range a.ProtocolNameList { p := proto // Satisfy range scope lint @@ -34,10 +34,11 @@ func (a *ALPN) Marshal() ([]byte, error) { } }) }) - return b.Bytes() + + return builder.Bytes() } -// Unmarshal populates the extension from encoded data +// Unmarshal populates the extension from encoded data. func (a *ALPN) Unmarshal(data []byte) error { val := cryptobyte.String(data) @@ -61,10 +62,11 @@ func (a *ALPN) Unmarshal(data []byte) error { } a.ProtocolNameList = append(a.ProtocolNameList, string(proto)) } + return nil } -// ALPNProtocolSelection negotiates a shared protocol according to #3.2 of rfc7301 +// ALPNProtocolSelection negotiates a shared protocol according to #3.2 of rfc7301. func ALPNProtocolSelection(supportedProtocols, peerSupportedProtocols []string) (string, error) { if len(supportedProtocols) == 0 || len(peerSupportedProtocols) == 0 { return "", nil @@ -76,5 +78,6 @@ func ALPNProtocolSelection(supportedProtocols, peerSupportedProtocols []string) } } } + return "", errALPNNoAppProto } diff --git a/pkg/protocol/extension/alpn_test.go b/pkg/protocol/extension/alpn_test.go index 6b12af0f7..11468fac0 100644 --- a/pkg/protocol/extension/alpn_test.go +++ b/pkg/protocol/extension/alpn_test.go @@ -31,22 +31,22 @@ func TestALPN(t *testing.T) { } func TestALPNProtocolSelection(t *testing.T) { - s, err := ALPNProtocolSelection([]string{"http/1.1", "spd/1"}, []string{"spd/1"}) + selectedProtocol, err := ALPNProtocolSelection([]string{"http/1.1", "spd/1"}, []string{"spd/1"}) if err != nil { t.Fatal(err) } - if s != "spd/1" { - t.Errorf("expected: spd/1, got: %v", s) + if selectedProtocol != "spd/1" { + t.Errorf("expected: spd/1, got: %v", selectedProtocol) } _, err = ALPNProtocolSelection([]string{"http/1.1"}, []string{"spd/1"}) if !errors.Is(err, errALPNNoAppProto) { t.Fatal("expected to fail negotiating an application protocol") } - s, err = ALPNProtocolSelection([]string{"http/1.1", "spd/1"}, []string{}) + selectedProtocol, err = ALPNProtocolSelection([]string{"http/1.1", "spd/1"}, []string{}) if err != nil { t.Fatal(err) } - if s != "" { - t.Errorf("expected not to negotiate a protocol, got: %v", s) + if selectedProtocol != "" { + t.Errorf("expected not to negotiate a protocol, got: %v", selectedProtocol) } } diff --git a/pkg/protocol/extension/connection_id.go b/pkg/protocol/extension/connection_id.go new file mode 100644 index 000000000..6c8a7f566 --- /dev/null +++ b/pkg/protocol/extension/connection_id.go @@ -0,0 +1,61 @@ +// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> +// SPDX-License-Identifier: MIT + +package extension + +import ( + "golang.org/x/crypto/cryptobyte" +) + +// ConnectionID is a DTLS extension that provides an alternative to IP address +// and port for session association. +// +// https://tools.ietf.org/html/rfc9146 +type ConnectionID struct { + // A zero-length connection ID indicates for a client or server that + // negotiated connection IDs from the peer will be sent but there is no need + // to respond with one + CID []byte // variable length +} + +// TypeValue returns the extension TypeValue. +func (c ConnectionID) TypeValue() TypeValue { + return ConnectionIDTypeValue +} + +// Marshal encodes the extension. +func (c *ConnectionID) Marshal() ([]byte, error) { + var b cryptobyte.Builder + b.AddUint16(uint16(c.TypeValue())) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(c.CID) + }) + }) + + return b.Bytes() +} + +// Unmarshal populates the extension from encoded data. +func (c *ConnectionID) Unmarshal(data []byte) error { + val := cryptobyte.String(data) + var extension uint16 + val.ReadUint16(&extension) + if TypeValue(extension) != c.TypeValue() { + return errInvalidExtensionType + } + + var extData cryptobyte.String + val.ReadUint16LengthPrefixed(&extData) + + var cid cryptobyte.String + if !extData.ReadUint8LengthPrefixed(&cid) { + return errInvalidCIDFormat + } + c.CID = make([]byte, len(cid)) + if !cid.CopyBytes(c.CID) { + return errInvalidCIDFormat + } + + return nil +} diff --git a/pkg/protocol/extension/connection_id_test.go b/pkg/protocol/extension/connection_id_test.go new file mode 100644 index 000000000..e9736172d --- /dev/null +++ b/pkg/protocol/extension/connection_id_test.go @@ -0,0 +1,30 @@ +// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> +// SPDX-License-Identifier: MIT + +package extension + +import ( + "reflect" + "testing" +) + +func TestExtensionConnectionID(t *testing.T) { + rawExtensionConnectionID := []byte{1, 6, 8, 3, 88, 12, 2, 47} + parsedExtensionConnectionID := &ConnectionID{ + CID: rawExtensionConnectionID, + } + + raw, err := parsedExtensionConnectionID.Marshal() + if err != nil { + t.Fatal(err) + } + + roundtrip := &ConnectionID{} + if err := roundtrip.Unmarshal(raw); err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(roundtrip, parsedExtensionConnectionID) { + t.Errorf("parsedExtensionConnectionID unmarshal: got %#v, want %#v", roundtrip, parsedExtensionConnectionID) + } +} diff --git a/pkg/protocol/extension/errors.go b/pkg/protocol/extension/errors.go index c5e954ce5..424ae5b1a 100644 --- a/pkg/protocol/extension/errors.go +++ b/pkg/protocol/extension/errors.go @@ -6,15 +6,33 @@ package extension import ( "errors" - "github.com/pion/dtls/v2/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol" ) var ( - // ErrALPNInvalidFormat is raised when the ALPN format is invalid - ErrALPNInvalidFormat = &protocol.FatalError{Err: errors.New("invalid alpn format")} //nolint:goerr113 - errALPNNoAppProto = &protocol.FatalError{Err: errors.New("no application protocol")} //nolint:goerr113 - errBufferTooSmall = &protocol.TemporaryError{Err: errors.New("buffer is too small")} //nolint:goerr113 - errInvalidExtensionType = &protocol.FatalError{Err: errors.New("invalid extension type")} //nolint:goerr113 - errInvalidSNIFormat = &protocol.FatalError{Err: errors.New("invalid server name format")} //nolint:goerr113 - errLengthMismatch = &protocol.InternalError{Err: errors.New("data length and declared length do not match")} //nolint:goerr113 + // ErrALPNInvalidFormat is raised when the ALPN format is invalid. + ErrALPNInvalidFormat = &protocol.FatalError{ + Err: errors.New("invalid alpn format"), //nolint:goerr113 + } + errALPNNoAppProto = &protocol.FatalError{ + Err: errors.New("no application protocol"), //nolint:goerr113 + } + errBufferTooSmall = &protocol.TemporaryError{ + Err: errors.New("buffer is too small"), //nolint:goerr113 + } + errInvalidExtensionType = &protocol.FatalError{ + Err: errors.New("invalid extension type"), //nolint:goerr113 + } + errInvalidSNIFormat = &protocol.FatalError{ + Err: errors.New("invalid server name format"), //nolint:goerr113 + } + errInvalidCIDFormat = &protocol.FatalError{ + Err: errors.New("invalid connection ID format"), //nolint:goerr113 + } + errLengthMismatch = &protocol.InternalError{ + Err: errors.New("data length and declared length do not match"), //nolint:goerr113 + } + errMasterKeyIdentifierTooLarge = &protocol.FatalError{ + Err: errors.New("master key identifier is over 255 bytes"), //nolint:goerr113 + } ) diff --git a/pkg/protocol/extension/extension.go b/pkg/protocol/extension/extension.go index 5173a5863..ba82beea3 100644 --- a/pkg/protocol/extension/extension.go +++ b/pkg/protocol/extension/extension.go @@ -11,7 +11,7 @@ import "encoding/binary" // https://www.iana.org/assignments/tls-extensiontype-values/tls-extensiontype-values.xhtml type TypeValue uint16 -// TypeValue constants +// TypeValue constants. const ( ServerNameTypeValue TypeValue = 0 SupportedEllipticCurvesTypeValue TypeValue = 10 @@ -20,18 +20,19 @@ const ( UseSRTPTypeValue TypeValue = 14 ALPNTypeValue TypeValue = 16 UseExtendedMasterSecretTypeValue TypeValue = 23 + ConnectionIDTypeValue TypeValue = 54 RenegotiationInfoTypeValue TypeValue = 65281 ) -// Extension represents a single TLS extension +// Extension represents a single TLS extension. type Extension interface { Marshal() ([]byte, error) Unmarshal(data []byte) error TypeValue() TypeValue } -// Unmarshal many extensions at once -func Unmarshal(buf []byte) ([]Extension, error) { +// Unmarshal many extensions at once. +func Unmarshal(buf []byte) ([]Extension, error) { //nolint:cyclop switch { case len(buf) == 0: return []Extension{}, nil @@ -51,6 +52,7 @@ func Unmarshal(buf []byte) ([]Extension, error) { return err } extensions = append(extensions, e) + return nil } @@ -64,6 +66,10 @@ func Unmarshal(buf []byte) ([]Extension, error) { err = unmarshalAndAppend(buf[offset:], &ServerName{}) case SupportedEllipticCurvesTypeValue: err = unmarshalAndAppend(buf[offset:], &SupportedEllipticCurves{}) + case SupportedPointFormatsTypeValue: + err = unmarshalAndAppend(buf[offset:], &SupportedPointFormats{}) + case SupportedSignatureAlgorithmsTypeValue: + err = unmarshalAndAppend(buf[offset:], &SupportedSignatureAlgorithms{}) case UseSRTPTypeValue: err = unmarshalAndAppend(buf[offset:], &UseSRTP{}) case ALPNTypeValue: @@ -72,6 +78,8 @@ func Unmarshal(buf []byte) ([]Extension, error) { err = unmarshalAndAppend(buf[offset:], &UseExtendedMasterSecret{}) case RenegotiationInfoTypeValue: err = unmarshalAndAppend(buf[offset:], &RenegotiationInfo{}) + case ConnectionIDTypeValue: + err = unmarshalAndAppend(buf[offset:], &ConnectionID{}) default: } if err != nil { @@ -83,10 +91,11 @@ func Unmarshal(buf []byte) ([]Extension, error) { extensionLength := binary.BigEndian.Uint16(buf[offset+2:]) offset += (4 + int(extensionLength)) } + return extensions, nil } -// Marshal many extensions at once +// Marshal many extensions at once. func Marshal(e []Extension) ([]byte, error) { extensions := []byte{} for _, e := range e { @@ -97,6 +106,7 @@ func Marshal(e []Extension) ([]byte, error) { extensions = append(extensions, raw...) } out := []byte{0x00, 0x00} - binary.BigEndian.PutUint16(out, uint16(len(extensions))) + binary.BigEndian.PutUint16(out, uint16(len(extensions))) //nolint:gosec // G115 + return append(out, extensions...), nil } diff --git a/pkg/protocol/extension/fuzz_test.go b/pkg/protocol/extension/fuzz_test.go new file mode 100644 index 000000000..9266a8031 --- /dev/null +++ b/pkg/protocol/extension/fuzz_test.go @@ -0,0 +1,19 @@ +// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> +// SPDX-License-Identifier: MIT + +package extension + +import "testing" + +func FuzzUnmarshal(f *testing.F) { + f.Add([]byte{0x00}) + f.Add([]byte{1, 6, 8, 3, 88, 12, 2, 47}) + f.Add([]byte{0x0, 0xa, 0x0, 0x4, 0x0, 0x2, 0x0, 0x1d}) + f.Add([]byte{0x00, 0x0b, 0x00, 0x02, 0x01, 0x00}) + f.Add([]byte{0x00, 0x0d, 0x00, 0x08, 0x00, 0x06, 0x04, 0x03, 0x05, 0x03, 0x06, 0x03}) + f.Add([]byte{0x00, 0x0e, 0x00, 0x05, 0x00, 0x02, 0x00, 0x01, 0x00}) + + f.Fuzz(func(_ *testing.T, data []byte) { + _, _ = Unmarshal(data) + }) +} diff --git a/pkg/protocol/extension/renegotiation_info.go b/pkg/protocol/extension/renegotiation_info.go index c5092a7db..57432fd0b 100644 --- a/pkg/protocol/extension/renegotiation_info.go +++ b/pkg/protocol/extension/renegotiation_info.go @@ -17,22 +17,23 @@ type RenegotiationInfo struct { RenegotiatedConnection uint8 } -// TypeValue returns the extension TypeValue +// TypeValue returns the extension TypeValue. func (r RenegotiationInfo) TypeValue() TypeValue { return RenegotiationInfoTypeValue } -// Marshal encodes the extension +// Marshal encodes the extension. func (r *RenegotiationInfo) Marshal() ([]byte, error) { out := make([]byte, renegotiationInfoHeaderSize) binary.BigEndian.PutUint16(out, uint16(r.TypeValue())) binary.BigEndian.PutUint16(out[2:], uint16(1)) // length out[4] = r.RenegotiatedConnection + return out, nil } -// Unmarshal populates the extension from encoded data +// Unmarshal populates the extension from encoded data. func (r *RenegotiationInfo) Unmarshal(data []byte) error { if len(data) < renegotiationInfoHeaderSize { return errBufferTooSmall diff --git a/pkg/protocol/extension/renegotiation_info_test.go b/pkg/protocol/extension/renegotiation_info_test.go index 63b9609d1..252144332 100644 --- a/pkg/protocol/extension/renegotiation_info_test.go +++ b/pkg/protocol/extension/renegotiation_info_test.go @@ -20,6 +20,9 @@ func TestRenegotiationInfo(t *testing.T) { } if newExtension.RenegotiatedConnection != extension.RenegotiatedConnection { - t.Errorf("extensionRenegotiationInfo marshal: got %d expected %d", newExtension.RenegotiatedConnection, extension.RenegotiatedConnection) + t.Errorf( + "extensionRenegotiationInfo marshal: got %d expected %d", + newExtension.RenegotiatedConnection, extension.RenegotiatedConnection, + ) } } diff --git a/pkg/protocol/extension/server_name.go b/pkg/protocol/extension/server_name.go index 183e08e6e..31e6327d3 100644 --- a/pkg/protocol/extension/server_name.go +++ b/pkg/protocol/extension/server_name.go @@ -20,12 +20,12 @@ type ServerName struct { ServerName string } -// TypeValue returns the extension TypeValue +// TypeValue returns the extension TypeValue. func (s ServerName) TypeValue() TypeValue { return ServerNameTypeValue } -// Marshal encodes the extension +// Marshal encodes the extension. func (s *ServerName) Marshal() ([]byte, error) { var b cryptobyte.Builder b.AddUint16(uint16(s.TypeValue())) @@ -37,11 +37,12 @@ func (s *ServerName) Marshal() ([]byte, error) { }) }) }) + return b.Bytes() } -// Unmarshal populates the extension from encoded data -func (s *ServerName) Unmarshal(data []byte) error { +// Unmarshal populates the extension from encoded data. +func (s *ServerName) Unmarshal(data []byte) error { //nolint:cyclop val := cryptobyte.String(data) var extension uint16 val.ReadUint16(&extension) @@ -77,5 +78,6 @@ func (s *ServerName) Unmarshal(data []byte) error { return errInvalidSNIFormat } } + return nil } diff --git a/pkg/protocol/extension/srtp_protection_profile.go b/pkg/protocol/extension/srtp_protection_profile.go index 2966966dd..75a8c2ee3 100644 --- a/pkg/protocol/extension/srtp_protection_profile.go +++ b/pkg/protocol/extension/srtp_protection_profile.go @@ -10,6 +10,10 @@ type SRTPProtectionProfile uint16 const ( SRTP_AES128_CM_HMAC_SHA1_80 SRTPProtectionProfile = 0x0001 // nolint SRTP_AES128_CM_HMAC_SHA1_32 SRTPProtectionProfile = 0x0002 // nolint + SRTP_AES256_CM_SHA1_80 SRTPProtectionProfile = 0x0003 // nolint + SRTP_AES256_CM_SHA1_32 SRTPProtectionProfile = 0x0004 // nolint + SRTP_NULL_HMAC_SHA1_80 SRTPProtectionProfile = 0x0005 // nolint + SRTP_NULL_HMAC_SHA1_32 SRTPProtectionProfile = 0x0006 // nolint SRTP_AEAD_AES_128_GCM SRTPProtectionProfile = 0x0007 // nolint SRTP_AEAD_AES_256_GCM SRTPProtectionProfile = 0x0008 // nolint ) @@ -18,6 +22,10 @@ func srtpProtectionProfiles() map[SRTPProtectionProfile]bool { return map[SRTPProtectionProfile]bool{ SRTP_AES128_CM_HMAC_SHA1_80: true, SRTP_AES128_CM_HMAC_SHA1_32: true, + SRTP_AES256_CM_SHA1_80: true, + SRTP_AES256_CM_SHA1_32: true, + SRTP_NULL_HMAC_SHA1_80: true, + SRTP_NULL_HMAC_SHA1_32: true, SRTP_AEAD_AES_128_GCM: true, SRTP_AEAD_AES_256_GCM: true, } diff --git a/pkg/protocol/extension/supported_elliptic_curves.go b/pkg/protocol/extension/supported_elliptic_curves.go index dd9b54f0d..e3e87634b 100644 --- a/pkg/protocol/extension/supported_elliptic_curves.go +++ b/pkg/protocol/extension/supported_elliptic_curves.go @@ -6,7 +6,7 @@ package extension import ( "encoding/binary" - "github.com/pion/dtls/v2/pkg/crypto/elliptic" + "github.com/pion/dtls/v3/pkg/crypto/elliptic" ) const ( @@ -21,28 +21,28 @@ type SupportedEllipticCurves struct { EllipticCurves []elliptic.Curve } -// TypeValue returns the extension TypeValue +// TypeValue returns the extension TypeValue. func (s SupportedEllipticCurves) TypeValue() TypeValue { return SupportedEllipticCurvesTypeValue } -// Marshal encodes the extension +// Marshal encodes the extension. func (s *SupportedEllipticCurves) Marshal() ([]byte, error) { out := make([]byte, supportedGroupsHeaderSize) binary.BigEndian.PutUint16(out, uint16(s.TypeValue())) - binary.BigEndian.PutUint16(out[2:], uint16(2+(len(s.EllipticCurves)*2))) - binary.BigEndian.PutUint16(out[4:], uint16(len(s.EllipticCurves)*2)) + binary.BigEndian.PutUint16(out[2:], uint16(2+(len(s.EllipticCurves)*2))) //nolint:gosec // G115 + binary.BigEndian.PutUint16(out[4:], uint16(len(s.EllipticCurves)*2)) //nolint:gosec // G115 for _, v := range s.EllipticCurves { - out = append(out, []byte{0x00, 0x00}...) + out = append(out, []byte{0x00, 0x00}...) //nolint:makezero // todo: fix binary.BigEndian.PutUint16(out[len(out)-2:], uint16(v)) } return out, nil } -// Unmarshal populates the extension from encoded data +// Unmarshal populates the extension from encoded data. func (s *SupportedEllipticCurves) Unmarshal(data []byte) error { if len(data) <= supportedGroupsHeaderSize { return errBufferTooSmall @@ -61,5 +61,6 @@ func (s *SupportedEllipticCurves) Unmarshal(data []byte) error { s.EllipticCurves = append(s.EllipticCurves, supportedGroupID) } } + return nil } diff --git a/pkg/protocol/extension/supported_elliptic_curves_test.go b/pkg/protocol/extension/supported_elliptic_curves_test.go index c00554be1..620246a81 100644 --- a/pkg/protocol/extension/supported_elliptic_curves_test.go +++ b/pkg/protocol/extension/supported_elliptic_curves_test.go @@ -7,7 +7,7 @@ import ( "reflect" "testing" - "github.com/pion/dtls/v2/pkg/crypto/elliptic" + "github.com/pion/dtls/v3/pkg/crypto/elliptic" ) func TestExtensionSupportedGroups(t *testing.T) { @@ -18,8 +18,15 @@ func TestExtensionSupportedGroups(t *testing.T) { raw, err := parsedSupportedGroups.Marshal() if err != nil { - t.Error(err) + t.Fatal(err) } else if !reflect.DeepEqual(raw, rawSupportedGroups) { - t.Errorf("extensionSupportedGroups marshal: got %#v, want %#v", raw, rawSupportedGroups) + t.Fatalf("extensionSupportedGroups marshal: got %#v, want %#v", raw, rawSupportedGroups) + } + + roundtrip := &SupportedEllipticCurves{} + if err := roundtrip.Unmarshal(raw); err != nil { + t.Error(err) + } else if !reflect.DeepEqual(roundtrip, parsedSupportedGroups) { + t.Errorf("extensionSupportedGroups unmarshal: got %#v, want %#v", roundtrip, parsedSupportedGroups) } } diff --git a/pkg/protocol/extension/supported_point_formats.go b/pkg/protocol/extension/supported_point_formats.go index 9c2543e6e..77dc4fd50 100644 --- a/pkg/protocol/extension/supported_point_formats.go +++ b/pkg/protocol/extension/supported_point_formats.go @@ -6,7 +6,7 @@ package extension import ( "encoding/binary" - "github.com/pion/dtls/v2/pkg/crypto/elliptic" + "github.com/pion/dtls/v3/pkg/crypto/elliptic" ) const ( @@ -21,35 +21,38 @@ type SupportedPointFormats struct { PointFormats []elliptic.CurvePointFormat } -// TypeValue returns the extension TypeValue +// TypeValue returns the extension TypeValue. func (s SupportedPointFormats) TypeValue() TypeValue { return SupportedPointFormatsTypeValue } -// Marshal encodes the extension +// Marshal encodes the extension. func (s *SupportedPointFormats) Marshal() ([]byte, error) { out := make([]byte, supportedPointFormatsSize) binary.BigEndian.PutUint16(out, uint16(s.TypeValue())) - binary.BigEndian.PutUint16(out[2:], uint16(1+(len(s.PointFormats)))) + binary.BigEndian.PutUint16(out[2:], uint16(1+(len(s.PointFormats)))) //nolint:gosec // G115 out[4] = byte(len(s.PointFormats)) for _, v := range s.PointFormats { - out = append(out, byte(v)) + out = append(out, byte(v)) //nolint:makezero // todo: fix } + return out, nil } -// Unmarshal populates the extension from encoded data +// Unmarshal populates the extension from encoded data. func (s *SupportedPointFormats) Unmarshal(data []byte) error { if len(data) <= supportedPointFormatsSize { return errBufferTooSmall - } else if TypeValue(binary.BigEndian.Uint16(data)) != s.TypeValue() { + } + + if TypeValue(binary.BigEndian.Uint16(data)) != s.TypeValue() { return errInvalidExtensionType } - pointFormatCount := int(binary.BigEndian.Uint16(data[4:])) - if supportedGroupsHeaderSize+(pointFormatCount) > len(data) { + pointFormatCount := int(data[4]) + if supportedPointFormatsSize+pointFormatCount > len(data) { return errLengthMismatch } @@ -61,5 +64,6 @@ func (s *SupportedPointFormats) Unmarshal(data []byte) error { default: } } + return nil } diff --git a/pkg/protocol/extension/supported_point_formats_test.go b/pkg/protocol/extension/supported_point_formats_test.go index 7db3f0135..f57bdd0ca 100644 --- a/pkg/protocol/extension/supported_point_formats_test.go +++ b/pkg/protocol/extension/supported_point_formats_test.go @@ -7,7 +7,7 @@ import ( "reflect" "testing" - "github.com/pion/dtls/v2/pkg/crypto/elliptic" + "github.com/pion/dtls/v3/pkg/crypto/elliptic" ) func TestExtensionSupportedPointFormats(t *testing.T) { @@ -18,8 +18,18 @@ func TestExtensionSupportedPointFormats(t *testing.T) { raw, err := parsedExtensionSupportedPointFormats.Marshal() if err != nil { - t.Error(err) + t.Fatal(err) } else if !reflect.DeepEqual(raw, rawExtensionSupportedPointFormats) { - t.Errorf("extensionSupportedPointFormats marshal: got %#v, want %#v", raw, rawExtensionSupportedPointFormats) + t.Fatalf("extensionSupportedPointFormats marshal: got %#v, want %#v", raw, rawExtensionSupportedPointFormats) + } + + roundtrip := &SupportedPointFormats{} + if err := roundtrip.Unmarshal(raw); err != nil { + t.Error(err) + } else if !reflect.DeepEqual(roundtrip, parsedExtensionSupportedPointFormats) { + t.Errorf( + "extensionSupportedPointFormats unmarshal: got %#v, want %#v", + roundtrip, parsedExtensionSupportedPointFormats, + ) } } diff --git a/pkg/protocol/extension/supported_signature_algorithms.go b/pkg/protocol/extension/supported_signature_algorithms.go index 2ff4b90b6..e7ad0d422 100644 --- a/pkg/protocol/extension/supported_signature_algorithms.go +++ b/pkg/protocol/extension/supported_signature_algorithms.go @@ -6,9 +6,9 @@ package extension import ( "encoding/binary" - "github.com/pion/dtls/v2/pkg/crypto/hash" - "github.com/pion/dtls/v2/pkg/crypto/signature" - "github.com/pion/dtls/v2/pkg/crypto/signaturehash" + "github.com/pion/dtls/v3/pkg/crypto/hash" + "github.com/pion/dtls/v3/pkg/crypto/signature" + "github.com/pion/dtls/v3/pkg/crypto/signaturehash" ) const ( @@ -23,20 +23,20 @@ type SupportedSignatureAlgorithms struct { SignatureHashAlgorithms []signaturehash.Algorithm } -// TypeValue returns the extension TypeValue +// TypeValue returns the extension TypeValue. func (s SupportedSignatureAlgorithms) TypeValue() TypeValue { return SupportedSignatureAlgorithmsTypeValue } -// Marshal encodes the extension +// Marshal encodes the extension. func (s *SupportedSignatureAlgorithms) Marshal() ([]byte, error) { out := make([]byte, supportedSignatureAlgorithmsHeaderSize) binary.BigEndian.PutUint16(out, uint16(s.TypeValue())) - binary.BigEndian.PutUint16(out[2:], uint16(2+(len(s.SignatureHashAlgorithms)*2))) - binary.BigEndian.PutUint16(out[4:], uint16(len(s.SignatureHashAlgorithms)*2)) + binary.BigEndian.PutUint16(out[2:], uint16(2+(len(s.SignatureHashAlgorithms)*2))) //nolint:gosec // G115 + binary.BigEndian.PutUint16(out[4:], uint16(len(s.SignatureHashAlgorithms)*2)) //nolint:gosec // G115 for _, v := range s.SignatureHashAlgorithms { - out = append(out, []byte{0x00, 0x00}...) + out = append(out, []byte{0x00, 0x00}...) //nolint:makezero // todo: fix out[len(out)-2] = byte(v.Hash) out[len(out)-1] = byte(v.Signature) } @@ -44,7 +44,7 @@ func (s *SupportedSignatureAlgorithms) Marshal() ([]byte, error) { return out, nil } -// Unmarshal populates the extension from encoded data +// Unmarshal populates the extension from encoded data. func (s *SupportedSignatureAlgorithms) Unmarshal(data []byte) error { if len(data) <= supportedSignatureAlgorithmsHeaderSize { return errBufferTooSmall diff --git a/pkg/protocol/extension/supported_signature_algorithms_test.go b/pkg/protocol/extension/supported_signature_algorithms_test.go index dbea9c126..653d4b558 100644 --- a/pkg/protocol/extension/supported_signature_algorithms_test.go +++ b/pkg/protocol/extension/supported_signature_algorithms_test.go @@ -7,9 +7,9 @@ import ( "reflect" "testing" - "github.com/pion/dtls/v2/pkg/crypto/hash" - "github.com/pion/dtls/v2/pkg/crypto/signature" - "github.com/pion/dtls/v2/pkg/crypto/signaturehash" + "github.com/pion/dtls/v3/pkg/crypto/hash" + "github.com/pion/dtls/v3/pkg/crypto/signature" + "github.com/pion/dtls/v3/pkg/crypto/signaturehash" ) func TestExtensionSupportedSignatureAlgorithms(t *testing.T) { @@ -31,8 +31,21 @@ func TestExtensionSupportedSignatureAlgorithms(t *testing.T) { raw, err := parsedExtensionSupportedSignatureAlgorithms.Marshal() if err != nil { - t.Error(err) + t.Fatal(err) } else if !reflect.DeepEqual(raw, rawExtensionSupportedSignatureAlgorithms) { - t.Errorf("extensionSupportedSignatureAlgorithms marshal: got %#v, want %#v", raw, rawExtensionSupportedSignatureAlgorithms) + t.Fatalf( + "extensionSupportedSignatureAlgorithms marshal: got %#v, want %#v", + raw, rawExtensionSupportedSignatureAlgorithms, + ) + } + + roundtrip := &SupportedSignatureAlgorithms{} + if err := roundtrip.Unmarshal(raw); err != nil { + t.Error(err) + } else if !reflect.DeepEqual(roundtrip, parsedExtensionSupportedSignatureAlgorithms) { + t.Errorf( + "extensionSupportedSignatureAlgorithms unmarshal: got %#v, want %#v", + roundtrip, parsedExtensionSupportedSignatureAlgorithms, + ) } } diff --git a/pkg/protocol/extension/use_master_secret.go b/pkg/protocol/extension/use_master_secret.go index d0b70cafb..fcf5dd289 100644 --- a/pkg/protocol/extension/use_master_secret.go +++ b/pkg/protocol/extension/use_master_secret.go @@ -16,12 +16,12 @@ type UseExtendedMasterSecret struct { Supported bool } -// TypeValue returns the extension TypeValue +// TypeValue returns the extension TypeValue. func (u UseExtendedMasterSecret) TypeValue() TypeValue { return UseExtendedMasterSecretTypeValue } -// Marshal encodes the extension +// Marshal encodes the extension. func (u *UseExtendedMasterSecret) Marshal() ([]byte, error) { if !u.Supported { return []byte{}, nil @@ -31,10 +31,11 @@ func (u *UseExtendedMasterSecret) Marshal() ([]byte, error) { binary.BigEndian.PutUint16(out, uint16(u.TypeValue())) binary.BigEndian.PutUint16(out[2:], uint16(0)) // length + return out, nil } -// Unmarshal populates the extension from encoded data +// Unmarshal populates the extension from encoded data. func (u *UseExtendedMasterSecret) Unmarshal(data []byte) error { if len(data) < useExtendedMasterSecretHeaderSize { return errBufferTooSmall diff --git a/pkg/protocol/extension/use_srtp.go b/pkg/protocol/extension/use_srtp.go index ea9f10872..4e0410cae 100644 --- a/pkg/protocol/extension/use_srtp.go +++ b/pkg/protocol/extension/use_srtp.go @@ -3,7 +3,9 @@ package extension -import "encoding/binary" +import ( + "encoding/binary" +) const ( useSRTPHeaderSize = 6 @@ -14,32 +16,42 @@ const ( // // https://tools.ietf.org/html/rfc8422 type UseSRTP struct { - ProtectionProfiles []SRTPProtectionProfile + ProtectionProfiles []SRTPProtectionProfile + MasterKeyIdentifier []byte } -// TypeValue returns the extension TypeValue +// TypeValue returns the extension TypeValue. func (u UseSRTP) TypeValue() TypeValue { return UseSRTPTypeValue } -// Marshal encodes the extension +// Marshal encodes the extension. func (u *UseSRTP) Marshal() ([]byte, error) { out := make([]byte, useSRTPHeaderSize) binary.BigEndian.PutUint16(out, uint16(u.TypeValue())) - binary.BigEndian.PutUint16(out[2:], uint16(2+(len(u.ProtectionProfiles)*2)+ /* MKI Length */ 1)) - binary.BigEndian.PutUint16(out[4:], uint16(len(u.ProtectionProfiles)*2)) + //nolint:gosec // G115 + binary.BigEndian.PutUint16( + out[2:], + uint16(2+(len(u.ProtectionProfiles)*2)+ /* MKI Length */ 1+len(u.MasterKeyIdentifier)), + ) + binary.BigEndian.PutUint16(out[4:], uint16(len(u.ProtectionProfiles)*2)) //nolint:gosec // G115 for _, v := range u.ProtectionProfiles { - out = append(out, []byte{0x00, 0x00}...) + out = append(out, []byte{0x00, 0x00}...) //nolint:makezero // todo: fix binary.BigEndian.PutUint16(out[len(out)-2:], uint16(v)) } + if len(u.MasterKeyIdentifier) > 255 { + return nil, errMasterKeyIdentifierTooLarge + } + + out = append(out, byte(len(u.MasterKeyIdentifier))) //nolint:makezero // todo: fix + out = append(out, u.MasterKeyIdentifier...) //nolint:makezero // todo: fix - out = append(out, 0x00) /* MKI Length */ return out, nil } -// Unmarshal populates the extension from encoded data +// Unmarshal populates the extension from encoded data. func (u *UseSRTP) Unmarshal(data []byte) error { if len(data) <= useSRTPHeaderSize { return errBufferTooSmall @@ -48,7 +60,8 @@ func (u *UseSRTP) Unmarshal(data []byte) error { } profileCount := int(binary.BigEndian.Uint16(data[4:]) / 2) - if supportedGroupsHeaderSize+(profileCount*2) > len(data) { + masterKeyIdentifierIndex := supportedGroupsHeaderSize + (profileCount * 2) + if masterKeyIdentifierIndex+1 > len(data) { return errLengthMismatch } @@ -58,5 +71,16 @@ func (u *UseSRTP) Unmarshal(data []byte) error { u.ProtectionProfiles = append(u.ProtectionProfiles, supportedProfile) } } + + masterKeyIdentifierLen := int(data[masterKeyIdentifierIndex]) + if masterKeyIdentifierIndex+masterKeyIdentifierLen >= len(data) { + return errLengthMismatch + } + + u.MasterKeyIdentifier = append( + []byte{}, + data[masterKeyIdentifierIndex+1:masterKeyIdentifierIndex+1+masterKeyIdentifierLen]..., + ) + return nil } diff --git a/pkg/protocol/extension/use_srtp_test.go b/pkg/protocol/extension/use_srtp_test.go index 25b7b9e12..c88c61c15 100644 --- a/pkg/protocol/extension/use_srtp_test.go +++ b/pkg/protocol/extension/use_srtp_test.go @@ -4,20 +4,76 @@ package extension import ( + "errors" "reflect" "testing" ) -func TestExtensionUseSRTP(t *testing.T) { - rawUseSRTP := []byte{0x00, 0x0e, 0x00, 0x05, 0x00, 0x02, 0x00, 0x01, 0x00} - parsedUseSRTP := &UseSRTP{ - ProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}, - } - - raw, err := parsedUseSRTP.Marshal() - if err != nil { - t.Error(err) - } else if !reflect.DeepEqual(raw, rawUseSRTP) { - t.Errorf("extensionUseSRTP marshal: got %#v, want %#v", raw, rawUseSRTP) - } +func TestExtensionUseSRTP(t *testing.T) { //nolint:cyclop + t.Run("No MasterKeyIdentifier", func(t *testing.T) { + rawUseSRTP := []byte{0x00, 0x0e, 0x00, 0x05, 0x00, 0x02, 0x00, 0x01, 0x00} + parsedUseSRTP := &UseSRTP{ + ProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}, + MasterKeyIdentifier: []byte{}, + } + + marshaled, err := parsedUseSRTP.Marshal() + if err != nil { + t.Error(err) + } else if !reflect.DeepEqual(marshaled, rawUseSRTP) { + t.Errorf("extensionUseSRTP marshal: got %#v, want %#v", marshaled, rawUseSRTP) + } + + unmarshaled := &UseSRTP{} + if err := unmarshaled.Unmarshal(rawUseSRTP); err != nil { + t.Error(err) + } else if !reflect.DeepEqual(unmarshaled, parsedUseSRTP) { + t.Errorf("extensionUseSRTP unmarshal: got %#v, want %#v", unmarshaled, parsedUseSRTP) + } + }) + + t.Run("With MasterKeyIdentifier", func(t *testing.T) { + rawUseSRTP := []byte{0x00, 0x0e, 0x00, 0x0a, 0x00, 0x02, 0x00, 0x01, 0x05, 0xA, 0xB, 0xC, 0xD, 0xE} + parsedUseSRTP := &UseSRTP{ + ProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}, + MasterKeyIdentifier: []byte{0xA, 0xB, 0xC, 0xD, 0xE}, + } + + marshaled, err := parsedUseSRTP.Marshal() + if err != nil { + t.Error(err) + } else if !reflect.DeepEqual(marshaled, rawUseSRTP) { + t.Errorf("extensionUseSRTP marshal: got %#v, want %#v", marshaled, rawUseSRTP) + } + + unmarshaled := &UseSRTP{} + if err := unmarshaled.Unmarshal(rawUseSRTP); err != nil { + t.Error(err) + } else if !reflect.DeepEqual(unmarshaled, parsedUseSRTP) { + t.Errorf("extensionUseSRTP unmarshal: got %#v, want %#v", unmarshaled, parsedUseSRTP) + } + }) + + t.Run("Invalid Lengths", func(t *testing.T) { + unmarshaled := &UseSRTP{} + + if err := unmarshaled.Unmarshal( + []byte{0x00, 0x0e, 0x00, 0x05, 0x00, 0x04, 0x00, 0x01, 0x00}, + ); !errors.Is(errLengthMismatch, err) { + t.Error(err) + } + + if err := unmarshaled.Unmarshal( + []byte{0x00, 0x0e, 0x00, 0x0a, 0x00, 0x02, 0x00, 0x01, 0x01}, + ); !errors.Is(errLengthMismatch, err) { + t.Error(err) + } + + if _, err := (&UseSRTP{ + ProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}, + MasterKeyIdentifier: make([]byte, 500), + }).Marshal(); !errors.Is(errMasterKeyIdentifierTooLarge, err) { + panic(err) + } + }) } diff --git a/pkg/protocol/handshake/cipher_suite.go b/pkg/protocol/handshake/cipher_suite.go index b29629717..49d2b7407 100644 --- a/pkg/protocol/handshake/cipher_suite.go +++ b/pkg/protocol/handshake/cipher_suite.go @@ -18,15 +18,17 @@ func decodeCipherSuiteIDs(buf []byte) ([]uint16, error) { rtrn[i] = binary.BigEndian.Uint16(buf[(i*2)+2:]) } + return rtrn, nil } func encodeCipherSuiteIDs(cipherSuiteIDs []uint16) []byte { out := []byte{0x00, 0x00} - binary.BigEndian.PutUint16(out[len(out)-2:], uint16(len(cipherSuiteIDs)*2)) + binary.BigEndian.PutUint16(out[len(out)-2:], uint16(len(cipherSuiteIDs)*2)) //nolint:gosec // G115 for _, id := range cipherSuiteIDs { out = append(out, []byte{0x00, 0x00}...) binary.BigEndian.PutUint16(out[len(out)-2:], id) } + return out } diff --git a/pkg/protocol/handshake/errors.go b/pkg/protocol/handshake/errors.go index 1354300c4..20794f78c 100644 --- a/pkg/protocol/handshake/errors.go +++ b/pkg/protocol/handshake/errors.go @@ -6,23 +6,51 @@ package handshake import ( "errors" - "github.com/pion/dtls/v2/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol" ) -// Typed errors +// Typed errors. var ( - errUnableToMarshalFragmented = &protocol.InternalError{Err: errors.New("unable to marshal fragmented handshakes")} //nolint:goerr113 - errHandshakeMessageUnset = &protocol.InternalError{Err: errors.New("handshake message unset, unable to marshal")} //nolint:goerr113 - errBufferTooSmall = &protocol.TemporaryError{Err: errors.New("buffer is too small")} //nolint:goerr113 - errLengthMismatch = &protocol.InternalError{Err: errors.New("data length and declared length do not match")} //nolint:goerr113 - errInvalidClientKeyExchange = &protocol.FatalError{Err: errors.New("unable to determine if ClientKeyExchange is a public key or PSK Identity")} //nolint:goerr113 - errInvalidHashAlgorithm = &protocol.FatalError{Err: errors.New("invalid hash algorithm")} //nolint:goerr113 - errInvalidSignatureAlgorithm = &protocol.FatalError{Err: errors.New("invalid signature algorithm")} //nolint:goerr113 - errCookieTooLong = &protocol.FatalError{Err: errors.New("cookie must not be longer then 255 bytes")} //nolint:goerr113 - errInvalidEllipticCurveType = &protocol.FatalError{Err: errors.New("invalid or unknown elliptic curve type")} //nolint:goerr113 - errInvalidNamedCurve = &protocol.FatalError{Err: errors.New("invalid named curve")} //nolint:goerr113 - errCipherSuiteUnset = &protocol.FatalError{Err: errors.New("server hello can not be created without a cipher suite")} //nolint:goerr113 - errCompressionMethodUnset = &protocol.FatalError{Err: errors.New("server hello can not be created without a compression method")} //nolint:goerr113 - errInvalidCompressionMethod = &protocol.FatalError{Err: errors.New("invalid or unknown compression method")} //nolint:goerr113 - errNotImplemented = &protocol.InternalError{Err: errors.New("feature has not been implemented yet")} //nolint:goerr113 + errUnableToMarshalFragmented = &protocol.InternalError{ + Err: errors.New("unable to marshal fragmented handshakes"), //nolint:err113 + } + errHandshakeMessageUnset = &protocol.InternalError{ + Err: errors.New("handshake message unset, unable to marshal"), //nolint:err113 + } + errBufferTooSmall = &protocol.TemporaryError{ + Err: errors.New("buffer is too small"), //nolint:err113 + } + errLengthMismatch = &protocol.InternalError{ + Err: errors.New("data length and declared length do not match"), //nolint:err113 + } + errInvalidClientKeyExchange = &protocol.FatalError{ + Err: errors.New("unable to determine if ClientKeyExchange is a public key or PSK Identity"), //nolint:err113 + } + errInvalidHashAlgorithm = &protocol.FatalError{ + Err: errors.New("invalid hash algorithm"), //nolint:err113 + } + errInvalidSignatureAlgorithm = &protocol.FatalError{ + Err: errors.New("invalid signature algorithm"), //nolint:err113 + } + errCookieTooLong = &protocol.FatalError{ + Err: errors.New("cookie must not be longer then 255 bytes"), //nolint:err113 + } + errInvalidEllipticCurveType = &protocol.FatalError{ + Err: errors.New("invalid or unknown elliptic curve type"), //nolint:err113 + } + errInvalidNamedCurve = &protocol.FatalError{ + Err: errors.New("invalid named curve"), //nolint:err113 + } + errCipherSuiteUnset = &protocol.FatalError{ + Err: errors.New("server hello can not be created without a cipher suite"), //nolint:err113 + } + errCompressionMethodUnset = &protocol.FatalError{ + Err: errors.New("server hello can not be created without a compression method"), //nolint:err113 + } + errInvalidCompressionMethod = &protocol.FatalError{ + Err: errors.New("invalid or unknown compression method"), //nolint:err113 + } + errNotImplemented = &protocol.InternalError{ + Err: errors.New("feature has not been implemented yet"), //nolint:err113 + } ) diff --git a/pkg/protocol/handshake/fuzz_test.go b/pkg/protocol/handshake/fuzz_test.go index 837459f85..bce66ff3d 100644 --- a/pkg/protocol/handshake/fuzz_test.go +++ b/pkg/protocol/handshake/fuzz_test.go @@ -8,17 +8,64 @@ import ( ) func FuzzDtlsHandshake(f *testing.F) { - f.Fuzz(func(t *testing.T, data []byte) { + rawCertificateRequest := []byte{ + 0x02, 0x01, 0x40, 0x00, 0x0C, 0x04, 0x03, 0x04, 0x01, 0x05, 0x03, + 0x05, 0x01, 0x06, 0x01, 0x02, 0x01, 0x00, 0x06, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, + } + rawCertificateVerify := []byte{ + 0x04, 0x03, 0x00, 0x47, 0x30, 0x45, 0x02, 0x20, 0x6b, 0x63, 0x17, 0xad, 0xbe, 0xb7, 0x7b, 0x0f, + 0x86, 0x73, 0x39, 0x1e, 0xba, 0xb3, 0x50, 0x9c, 0xce, 0x9c, 0xe4, 0x8b, 0xe5, 0x13, 0x07, 0x59, + 0x18, 0x1f, 0xe5, 0xa0, 0x2b, 0xca, 0xa6, 0xad, 0x02, 0x21, 0x00, 0xd3, 0xb5, 0x01, 0xbe, 0x87, + 0x6c, 0x04, 0xa1, 0xdc, 0x28, 0xaa, 0x5f, 0xf7, 0x1e, 0x9c, 0xc0, 0x1e, 0x00, 0x2c, 0xe5, 0x94, + 0xbb, 0x03, 0x0e, 0xf1, 0xcb, 0x28, 0x22, 0x33, 0x23, 0x88, 0xad, + } + rawClientHello := []byte{ + 0xfe, 0xfd, 0xb6, 0x2f, 0xce, 0x5c, 0x42, 0x54, 0xff, 0x86, 0xe1, 0x24, 0x41, 0x91, 0x42, + 0x62, 0x15, 0xad, 0x16, 0xc9, 0x15, 0x8d, 0x95, 0x71, 0x8a, 0xbb, 0x22, 0xd7, 0x47, 0xec, + 0xd8, 0x3d, 0xdc, 0x4b, 0x00, 0x14, 0xe6, 0x14, 0x3a, 0x1b, 0x04, 0xea, 0x9e, 0x7a, 0x14, + 0xd6, 0x6c, 0x57, 0xd0, 0x0e, 0x32, 0x85, 0x76, 0x18, 0xde, 0xd8, 0x00, 0x04, 0xc0, 0x2b, + 0xc0, 0x0a, 0x01, 0x00, 0x00, 0x08, 0x00, 0x0a, 0x00, 0x04, 0x00, 0x02, 0x00, 0x1d, + } + rawClientKeyExchange := []byte{ + 0x20, 0x26, 0x78, 0x4a, 0x78, 0x70, 0xc1, 0xf9, 0x71, 0xea, 0x50, 0x4a, 0xb5, 0xbb, 0x00, 0x76, + 0x02, 0x05, 0xda, 0xf7, 0xd0, 0x3f, 0xe3, 0xf7, 0x4e, 0x8a, 0x14, 0x6f, 0xb7, 0xe0, 0xc0, 0xff, + 0x54, + } + rawFinished := []byte{ + 0x01, 0x01, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, + } + rawHelloVerifyRequest := []byte{ + 0xfe, 0xff, 0x14, 0x25, 0xfb, 0xee, 0xb3, 0x7c, 0x95, 0xcf, 0x00, + 0xeb, 0xad, 0xe2, 0xef, 0xc7, 0xfd, 0xbb, 0xed, 0xf7, 0x1f, 0x6c, 0xcd, + } + rawServerHello := []byte{ + 0xfe, 0xfd, 0x21, 0x63, 0x32, 0x21, 0x81, 0x0e, 0x98, 0x6c, + 0x85, 0x3d, 0xa4, 0x39, 0xaf, 0x5f, 0xd6, 0x5c, 0xcc, 0x20, + 0x7f, 0x7c, 0x78, 0xf1, 0x5f, 0x7e, 0x1c, 0xb7, 0xa1, 0x1e, + 0xcf, 0x63, 0x84, 0x28, 0x00, 0xc0, 0x2b, 0x00, 0x00, 0x00, + } + rawServerKeyExchange := []byte{ + 0x03, 0x00, 0x1d, 0x41, 0x04, 0x0c, 0xb9, 0xa3, 0xb9, 0x90, 0x71, 0x35, 0x4a, 0x08, 0x66, 0xaf, + 0xd6, 0x88, 0x58, 0x29, 0x69, 0x98, 0xf1, 0x87, 0x0f, 0xb5, 0xa8, 0xcd, 0x92, 0xf6, 0x2b, 0x08, + 0x0c, 0xd4, 0x16, 0x5b, 0xcc, 0x81, 0xf2, 0x58, 0x91, 0x8e, 0x62, 0xdf, 0xc1, 0xec, 0x72, 0xe8, + 0x47, 0x24, 0x42, 0x96, 0xb8, 0x7b, 0xee, 0xe7, 0x0d, 0xdc, 0x44, 0xec, 0xf3, 0x97, 0x6b, 0x1b, + 0x45, 0x28, 0xac, 0x3f, 0x35, 0x02, 0x03, 0x00, 0x47, 0x30, 0x45, 0x02, 0x21, 0x00, 0xb2, 0x0b, + 0x22, 0x95, 0x3d, 0x56, 0x57, 0x6a, 0x3f, 0x85, 0x30, 0x6f, 0x55, 0xc3, 0xf4, 0x24, 0x1b, 0x21, + 0x07, 0xe5, 0xdf, 0xba, 0x24, 0x02, 0x68, 0x95, 0x1f, 0x6e, 0x13, 0xbd, 0x9f, 0xaa, 0x02, 0x20, + 0x49, 0x9c, 0x9d, 0xdf, 0x84, 0x60, 0x33, 0x27, 0x96, 0x9e, 0x58, 0x6d, 0x72, 0x13, 0xe7, 0x3a, + 0xe8, 0xdf, 0x43, 0x75, 0xc7, 0xb9, 0x37, 0x6e, 0x90, 0xe5, 0x3b, 0x81, 0xd4, 0xda, 0x68, 0xcd, + } + f.Add(rawCertificateRequest) + f.Add(rawCertificateVerify) + f.Add(rawClientHello) + f.Add(rawClientKeyExchange) + f.Add(rawFinished) + f.Add(rawHelloVerifyRequest) + f.Add(rawServerHello) + f.Add(rawServerKeyExchange) + + f.Fuzz(func(_ *testing.T, data []byte) { h := &Handshake{} - if err := h.Unmarshal(data); err != nil { - return - } - buf, err := h.Marshal() - if err != nil { - t.Fatal(err) - } - if len(buf) == 0 { - t.Fatal("Zero buff") - } + _ = h.Unmarshal(data) }) } diff --git a/pkg/protocol/handshake/handshake.go b/pkg/protocol/handshake/handshake.go index b1f682bf5..9c6187711 100644 --- a/pkg/protocol/handshake/handshake.go +++ b/pkg/protocol/handshake/handshake.go @@ -5,16 +5,16 @@ package handshake import ( - "github.com/pion/dtls/v2/internal/ciphersuite/types" - "github.com/pion/dtls/v2/internal/util" - "github.com/pion/dtls/v2/pkg/protocol" + "github.com/pion/dtls/v3/internal/ciphersuite/types" + "github.com/pion/dtls/v3/internal/util" + "github.com/pion/dtls/v3/pkg/protocol" ) // Type is the unique identifier for each handshake message // https://tools.ietf.org/html/rfc5246#section-7.4 type Type uint8 -// Types of DTLS Handshake messages we know about +// Types of DTLS Handshake messages we know about. const ( TypeHelloRequest Type = 0 TypeClientHello Type = 1 @@ -29,8 +29,8 @@ const ( TypeFinished Type = 20 ) -// String returns the string representation of this type -func (t Type) String() string { +// String returns the string representation of this type. +func (t Type) String() string { //nolint:cyclop switch t { case TypeHelloRequest: return "HelloRequest" @@ -55,10 +55,11 @@ func (t Type) String() string { case TypeFinished: return "Finished" } + return "" } -// Message is the body of a Handshake datagram +// Message is the body of a Handshake datagram. type Message interface { Marshal() ([]byte, error) Unmarshal(data []byte) error @@ -78,12 +79,12 @@ type Handshake struct { KeyExchangeAlgorithm types.KeyExchangeAlgorithm } -// ContentType returns what kind of content this message is carying +// ContentType returns what kind of content this message is carying. func (h Handshake) ContentType() protocol.ContentType { return protocol.ContentTypeHandshake } -// Marshal encodes a handshake into a binary message +// Marshal encodes a handshake into a binary message. func (h *Handshake) Marshal() ([]byte, error) { if h.Message == nil { return nil, errHandshakeMessageUnset @@ -96,7 +97,7 @@ func (h *Handshake) Marshal() ([]byte, error) { return nil, err } - h.Header.Length = uint32(len(msg)) + h.Header.Length = uint32(len(msg)) //nolint:gosec // G115 h.Header.FragmentLength = h.Header.Length h.Header.Type = h.Message.Type() header, err := h.Header.Marshal() @@ -107,14 +108,14 @@ func (h *Handshake) Marshal() ([]byte, error) { return append(header, msg...), nil } -// Unmarshal decodes a handshake from a binary message -func (h *Handshake) Unmarshal(data []byte) error { +// Unmarshal decodes a handshake from a binary message. +func (h *Handshake) Unmarshal(data []byte) error { //nolint:cyclop if err := h.Header.Unmarshal(data); err != nil { return err } reportedLen := util.BigEndianUint24(data[1:]) - if uint32(len(data)-HeaderLength) != reportedLen { + if uint32(len(data)-HeaderLength) != reportedLen { //nolint:gosec // G115 return errLengthMismatch } else if reportedLen != h.Header.FragmentLength { return errLengthMismatch @@ -146,5 +147,6 @@ func (h *Handshake) Unmarshal(data []byte) error { default: return errNotImplemented } + return h.Message.Unmarshal(data[HeaderLength:]) } diff --git a/pkg/protocol/handshake/header.go b/pkg/protocol/handshake/header.go index 4f9a96287..4e909de54 100644 --- a/pkg/protocol/handshake/header.go +++ b/pkg/protocol/handshake/header.go @@ -6,11 +6,11 @@ package handshake import ( "encoding/binary" - "github.com/pion/dtls/v2/internal/util" + "github.com/pion/dtls/v3/internal/util" ) // HeaderLength msg_len for Handshake messages assumes an extra -// 12 bytes for sequence, fragment and version information vs TLS +// 12 bytes for sequence, fragment and version information vs TLS. const HeaderLength = 12 // Header is the static first 12 bytes of each RecordLayer @@ -26,7 +26,7 @@ type Header struct { FragmentLength uint32 // uint24 in spec } -// Marshal encodes the Header +// Marshal encodes the Header. func (h *Header) Marshal() ([]byte, error) { out := make([]byte, HeaderLength) @@ -35,10 +35,11 @@ func (h *Header) Marshal() ([]byte, error) { binary.BigEndian.PutUint16(out[4:], h.MessageSequence) util.PutBigEndianUint24(out[6:], h.FragmentOffset) util.PutBigEndianUint24(out[9:], h.FragmentLength) + return out, nil } -// Unmarshal populates the header from encoded data +// Unmarshal populates the header from encoded data. func (h *Header) Unmarshal(data []byte) error { if len(data) < HeaderLength { return errBufferTooSmall @@ -49,5 +50,6 @@ func (h *Header) Unmarshal(data []byte) error { h.MessageSequence = binary.BigEndian.Uint16(data[4:]) h.FragmentOffset = util.BigEndianUint24(data[6:]) h.FragmentLength = util.BigEndianUint24(data[9:]) + return nil } diff --git a/pkg/protocol/handshake/message_certificate.go b/pkg/protocol/handshake/message_certificate.go index d5c861d90..27d2ea99a 100644 --- a/pkg/protocol/handshake/message_certificate.go +++ b/pkg/protocol/handshake/message_certificate.go @@ -4,7 +4,7 @@ package handshake import ( - "github.com/pion/dtls/v2/internal/util" + "github.com/pion/dtls/v3/internal/util" ) // MessageCertificate is a DTLS Handshake Message @@ -15,7 +15,7 @@ type MessageCertificate struct { Certificate [][]byte } -// Type returns the Handshake Type +// Type returns the Handshake Type. func (m MessageCertificate) Type() Type { return TypeCertificate } @@ -24,31 +24,36 @@ const ( handshakeMessageCertificateLengthFieldSize = 3 ) -// Marshal encodes the Handshake +// Marshal encodes the Handshake. func (m *MessageCertificate) Marshal() ([]byte, error) { out := make([]byte, handshakeMessageCertificateLengthFieldSize) for _, r := range m.Certificate { // Certificate Length + //nolint:makezero // todo: fix out = append(out, make([]byte, handshakeMessageCertificateLengthFieldSize)...) + //nolint:gosec // G115 util.PutBigEndianUint24(out[len(out)-handshakeMessageCertificateLengthFieldSize:], uint32(len(r))) // Certificate body - out = append(out, append([]byte{}, r...)...) + out = append(out, append([]byte{}, r...)...) //nolint:makezero // todo: fix } // Total Payload Size - util.PutBigEndianUint24(out[0:], uint32(len(out[handshakeMessageCertificateLengthFieldSize:]))) + util.PutBigEndianUint24(out[0:], uint32(len(out[handshakeMessageCertificateLengthFieldSize:]))) //nolint:gosec //G115 + return out, nil } -// Unmarshal populates the message from encoded data +// Unmarshal populates the message from encoded data. func (m *MessageCertificate) Unmarshal(data []byte) error { if len(data) < handshakeMessageCertificateLengthFieldSize { return errBufferTooSmall } - if certificateBodyLen := int(util.BigEndianUint24(data)); certificateBodyLen+handshakeMessageCertificateLengthFieldSize != len(data) { + if certificateBodyLen := int(util.BigEndianUint24( + data, + )); certificateBodyLen+handshakeMessageCertificateLengthFieldSize != len(data) { return errLengthMismatch } diff --git a/pkg/protocol/handshake/message_certificate_request.go b/pkg/protocol/handshake/message_certificate_request.go index 11a44d440..28dabf35a 100644 --- a/pkg/protocol/handshake/message_certificate_request.go +++ b/pkg/protocol/handshake/message_certificate_request.go @@ -6,10 +6,10 @@ package handshake import ( "encoding/binary" - "github.com/pion/dtls/v2/pkg/crypto/clientcertificate" - "github.com/pion/dtls/v2/pkg/crypto/hash" - "github.com/pion/dtls/v2/pkg/crypto/signature" - "github.com/pion/dtls/v2/pkg/crypto/signaturehash" + "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" + "github.com/pion/dtls/v3/pkg/crypto/hash" + "github.com/pion/dtls/v3/pkg/crypto/signature" + "github.com/pion/dtls/v3/pkg/crypto/signaturehash" ) /* @@ -31,12 +31,12 @@ const ( messageCertificateRequestMinLength = 5 ) -// Type returns the Handshake Type +// Type returns the Handshake Type. func (m MessageCertificateRequest) Type() Type { return TypeCertificateRequest } -// Marshal encodes the Handshake +// Marshal encodes the Handshake. func (m *MessageCertificateRequest) Marshal() ([]byte, error) { out := []byte{byte(len(m.CertificateTypes))} for _, v := range m.CertificateTypes { @@ -44,7 +44,7 @@ func (m *MessageCertificateRequest) Marshal() ([]byte, error) { } out = append(out, []byte{0x00, 0x00}...) - binary.BigEndian.PutUint16(out[len(out)-2:], uint16(len(m.SignatureHashAlgorithms)*2)) + binary.BigEndian.PutUint16(out[len(out)-2:], uint16(len(m.SignatureHashAlgorithms)*2)) //nolint:gosec //G115 for _, v := range m.SignatureHashAlgorithms { out = append(out, byte(v.Hash)) out = append(out, byte(v.Signature)) @@ -56,19 +56,20 @@ func (m *MessageCertificateRequest) Marshal() ([]byte, error) { casLength += len(ca) + 2 } out = append(out, []byte{0x00, 0x00}...) - binary.BigEndian.PutUint16(out[len(out)-2:], uint16(casLength)) + binary.BigEndian.PutUint16(out[len(out)-2:], uint16(casLength)) //nolint:gosec //G115 if casLength > 0 { for _, ca := range m.CertificateAuthoritiesNames { out = append(out, []byte{0x00, 0x00}...) - binary.BigEndian.PutUint16(out[len(out)-2:], uint16(len(ca))) + binary.BigEndian.PutUint16(out[len(out)-2:], uint16(len(ca))) //nolint:gosec //G115 out = append(out, ca...) } } + return out, nil } -// Unmarshal populates the message from encoded data -func (m *MessageCertificateRequest) Unmarshal(data []byte) error { +// Unmarshal populates the message from encoded data. +func (m *MessageCertificateRequest) Unmarshal(data []byte) error { //nolint:cyclop if len(data) < messageCertificateRequestMinLength { return errBufferTooSmall } diff --git a/pkg/protocol/handshake/message_certificate_request_test.go b/pkg/protocol/handshake/message_certificate_request_test.go index 99b360f80..78bcb5b41 100644 --- a/pkg/protocol/handshake/message_certificate_request_test.go +++ b/pkg/protocol/handshake/message_certificate_request_test.go @@ -8,10 +8,10 @@ import ( "reflect" "testing" - "github.com/pion/dtls/v2/pkg/crypto/clientcertificate" - "github.com/pion/dtls/v2/pkg/crypto/hash" - "github.com/pion/dtls/v2/pkg/crypto/signature" - "github.com/pion/dtls/v2/pkg/crypto/signaturehash" + "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" + "github.com/pion/dtls/v3/pkg/crypto/hash" + "github.com/pion/dtls/v3/pkg/crypto/signature" + "github.com/pion/dtls/v3/pkg/crypto/signaturehash" ) func TestHandshakeMessageCertificateRequest(t *testing.T) { diff --git a/pkg/protocol/handshake/message_certificate_test.go b/pkg/protocol/handshake/message_certificate_test.go index 760d5efdf..8abe42578 100644 --- a/pkg/protocol/handshake/message_certificate_test.go +++ b/pkg/protocol/handshake/message_certificate_test.go @@ -60,21 +60,21 @@ func TestHandshakeMessageCertificate(t *testing.T) { Version: 1, } - c := &MessageCertificate{} - if err := c.Unmarshal(rawCertificate); err != nil { + certMessage := &MessageCertificate{} + if err := certMessage.Unmarshal(rawCertificate); err != nil { t.Error(err) } else { - certificate, err := x509.ParseCertificate(c.Certificate[0]) + certificate, err := x509.ParseCertificate(certMessage.Certificate[0]) if err != nil { t.Error(err) } copyCertificatePrivateMembers(certificate, parsedCertificate) if !reflect.DeepEqual(certificate, parsedCertificate) { - t.Errorf("handshakeMessageCertificate unmarshal: got %#v, want %#v", c, parsedCertificate) + t.Errorf("handshakeMessageCertificate unmarshal: got %#v, want %#v", certMessage, parsedCertificate) } } - raw, err := c.Marshal() + raw, err := certMessage.Marshal() if err != nil { t.Error(err) } else if !reflect.DeepEqual(raw, rawCertificate) { diff --git a/pkg/protocol/handshake/message_certificate_verify.go b/pkg/protocol/handshake/message_certificate_verify.go index 9e02a9c11..d10ffa035 100644 --- a/pkg/protocol/handshake/message_certificate_verify.go +++ b/pkg/protocol/handshake/message_certificate_verify.go @@ -6,8 +6,8 @@ package handshake import ( "encoding/binary" - "github.com/pion/dtls/v2/pkg/crypto/hash" - "github.com/pion/dtls/v2/pkg/crypto/signature" + "github.com/pion/dtls/v3/pkg/crypto/hash" + "github.com/pion/dtls/v3/pkg/crypto/signature" ) // MessageCertificateVerify provide explicit verification of a @@ -22,23 +22,24 @@ type MessageCertificateVerify struct { const handshakeMessageCertificateVerifyMinLength = 4 -// Type returns the Handshake Type +// Type returns the Handshake Type. func (m MessageCertificateVerify) Type() Type { return TypeCertificateVerify } -// Marshal encodes the Handshake +// Marshal encodes the Handshake. func (m *MessageCertificateVerify) Marshal() ([]byte, error) { out := make([]byte, 1+1+2+len(m.Signature)) out[0] = byte(m.HashAlgorithm) out[1] = byte(m.SignatureAlgorithm) - binary.BigEndian.PutUint16(out[2:], uint16(len(m.Signature))) + binary.BigEndian.PutUint16(out[2:], uint16(len(m.Signature))) //nolint:gosec // G115 copy(out[4:], m.Signature) + return out, nil } -// Unmarshal populates the message from encoded data +// Unmarshal populates the message from encoded data. func (m *MessageCertificateVerify) Unmarshal(data []byte) error { if len(data) < handshakeMessageCertificateVerifyMinLength { return errBufferTooSmall @@ -60,5 +61,6 @@ func (m *MessageCertificateVerify) Unmarshal(data []byte) error { } m.Signature = append([]byte{}, data[4:]...) + return nil } diff --git a/pkg/protocol/handshake/message_certificate_verify_test.go b/pkg/protocol/handshake/message_certificate_verify_test.go index ea9cdce7c..3144e5c78 100644 --- a/pkg/protocol/handshake/message_certificate_verify_test.go +++ b/pkg/protocol/handshake/message_certificate_verify_test.go @@ -7,8 +7,8 @@ import ( "reflect" "testing" - "github.com/pion/dtls/v2/pkg/crypto/hash" - "github.com/pion/dtls/v2/pkg/crypto/signature" + "github.com/pion/dtls/v3/pkg/crypto/hash" + "github.com/pion/dtls/v3/pkg/crypto/signature" ) func TestHandshakeMessageCertificateVerify(t *testing.T) { diff --git a/pkg/protocol/handshake/message_client_hello.go b/pkg/protocol/handshake/message_client_hello.go index bea6dd969..e7aa5e397 100644 --- a/pkg/protocol/handshake/message_client_hello.go +++ b/pkg/protocol/handshake/message_client_hello.go @@ -6,8 +6,8 @@ package handshake import ( "encoding/binary" - "github.com/pion/dtls/v2/pkg/protocol" - "github.com/pion/dtls/v2/pkg/protocol/extension" + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/extension" ) /* @@ -31,12 +31,12 @@ type MessageClientHello struct { const handshakeMessageClientHelloVariableWidthStart = 34 -// Type returns the Handshake Type +// Type returns the Handshake Type. func (m MessageClientHello) Type() Type { return TypeClientHello } -// Marshal encodes the Handshake +// Marshal encodes the Handshake. func (m *MessageClientHello) Marshal() ([]byte, error) { if len(m.Cookie) > 255 { return nil, errCookieTooLong @@ -49,24 +49,24 @@ func (m *MessageClientHello) Marshal() ([]byte, error) { rand := m.Random.MarshalFixed() copy(out[2:], rand[:]) - out = append(out, byte(len(m.SessionID))) - out = append(out, m.SessionID...) + out = append(out, byte(len(m.SessionID))) //nolint:makezero // todo: fix + out = append(out, m.SessionID...) //nolint:makezero // todo: fix - out = append(out, byte(len(m.Cookie))) - out = append(out, m.Cookie...) - out = append(out, encodeCipherSuiteIDs(m.CipherSuiteIDs)...) - out = append(out, protocol.EncodeCompressionMethods(m.CompressionMethods)...) + out = append(out, byte(len(m.Cookie))) //nolint:makezero // todo: fix + out = append(out, m.Cookie...) //nolint:makezero // todo: fix + out = append(out, encodeCipherSuiteIDs(m.CipherSuiteIDs)...) //nolint:makezero // todo: fix + out = append(out, protocol.EncodeCompressionMethods(m.CompressionMethods)...) //nolint:makezero // todo: fix extensions, err := extension.Marshal(m.Extensions) if err != nil { return nil, err } - return append(out, extensions...), nil + return append(out, extensions...), nil //nolint:makezero // todo: fix } -// Unmarshal populates the message from encoded data -func (m *MessageClientHello) Unmarshal(data []byte) error { +// Unmarshal populates the message from encoded data. +func (m *MessageClientHello) Unmarshal(data []byte) error { //nolint:cyclop if len(data) < 2+RandomLength { return errBufferTooSmall } @@ -137,5 +137,6 @@ func (m *MessageClientHello) Unmarshal(data []byte) error { return err } m.Extensions = extensions + return nil } diff --git a/pkg/protocol/handshake/message_client_hello_test.go b/pkg/protocol/handshake/message_client_hello_test.go index eeb73b81e..f32fe1dff 100644 --- a/pkg/protocol/handshake/message_client_hello_test.go +++ b/pkg/protocol/handshake/message_client_hello_test.go @@ -9,9 +9,9 @@ import ( "testing" "time" - "github.com/pion/dtls/v2/pkg/crypto/elliptic" - "github.com/pion/dtls/v2/pkg/protocol" - "github.com/pion/dtls/v2/pkg/protocol/extension" + "github.com/pion/dtls/v3/pkg/crypto/elliptic" + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/extension" ) func TestHandshakeMessageClientHello(t *testing.T) { @@ -26,10 +26,16 @@ func TestHandshakeMessageClientHello(t *testing.T) { Version: protocol.Version{Major: 0xFE, Minor: 0xFD}, Random: Random{ GMTUnixTime: time.Unix(3056586332, 0), - RandomBytes: [28]byte{0x42, 0x54, 0xff, 0x86, 0xe1, 0x24, 0x41, 0x91, 0x42, 0x62, 0x15, 0xad, 0x16, 0xc9, 0x15, 0x8d, 0x95, 0x71, 0x8a, 0xbb, 0x22, 0xd7, 0x47, 0xec, 0xd8, 0x3d, 0xdc, 0x4b}, + RandomBytes: [28]byte{ + 0x42, 0x54, 0xff, 0x86, 0xe1, 0x24, 0x41, 0x91, 0x42, 0x62, 0x15, 0xad, 0x16, 0xc9, + 0x15, 0x8d, 0x95, 0x71, 0x8a, 0xbb, 0x22, 0xd7, 0x47, 0xec, 0xd8, 0x3d, 0xdc, 0x4b, + }, }, SessionID: []byte{}, - Cookie: []byte{0xe6, 0x14, 0x3a, 0x1b, 0x04, 0xea, 0x9e, 0x7a, 0x14, 0xd6, 0x6c, 0x57, 0xd0, 0x0e, 0x32, 0x85, 0x76, 0x18, 0xde, 0xd8}, + Cookie: []byte{ + 0xe6, 0x14, 0x3a, 0x1b, 0x04, 0xea, 0x9e, 0x7a, 0x14, 0xd6, + 0x6c, 0x57, 0xd0, 0x0e, 0x32, 0x85, 0x76, 0x18, 0xde, 0xd8, + }, CipherSuiteIDs: []uint16{ 0xc02b, 0xc00a, diff --git a/pkg/protocol/handshake/message_client_key_exchange.go b/pkg/protocol/handshake/message_client_key_exchange.go index 2abcd5bf7..60361a94a 100644 --- a/pkg/protocol/handshake/message_client_key_exchange.go +++ b/pkg/protocol/handshake/message_client_key_exchange.go @@ -6,7 +6,7 @@ package handshake import ( "encoding/binary" - "github.com/pion/dtls/v2/internal/ciphersuite/types" + "github.com/pion/dtls/v3/internal/ciphersuite/types" ) // MessageClientKeyExchange is a DTLS Handshake Message @@ -24,12 +24,12 @@ type MessageClientKeyExchange struct { KeyExchangeAlgorithm types.KeyExchangeAlgorithm } -// Type returns the Handshake Type +// Type returns the Handshake Type. func (m MessageClientKeyExchange) Type() Type { return TypeClientKeyExchange } -// Marshal encodes the Handshake +// Marshal encodes the Handshake. func (m *MessageClientKeyExchange) Marshal() (out []byte, err error) { if m.IdentityHint == nil && m.PublicKey == nil { return nil, errInvalidClientKeyExchange @@ -37,7 +37,7 @@ func (m *MessageClientKeyExchange) Marshal() (out []byte, err error) { if m.IdentityHint != nil { out = append([]byte{0x00, 0x00}, m.IdentityHint...) - binary.BigEndian.PutUint16(out, uint16(len(out)-2)) + binary.BigEndian.PutUint16(out, uint16(len(out)-2)) //nolint:gosec // G115 } if m.PublicKey != nil { @@ -48,7 +48,7 @@ func (m *MessageClientKeyExchange) Marshal() (out []byte, err error) { return out, nil } -// Unmarshal populates the message from encoded data +// Unmarshal populates the message from encoded data. func (m *MessageClientKeyExchange) Unmarshal(data []byte) error { switch { case len(data) < 2: diff --git a/pkg/protocol/handshake/message_client_key_exchange_test.go b/pkg/protocol/handshake/message_client_key_exchange_test.go index 88f31af8b..c7d569542 100644 --- a/pkg/protocol/handshake/message_client_key_exchange_test.go +++ b/pkg/protocol/handshake/message_client_key_exchange_test.go @@ -7,7 +7,7 @@ import ( "reflect" "testing" - "github.com/pion/dtls/v2/internal/ciphersuite/types" + "github.com/pion/dtls/v3/internal/ciphersuite/types" ) func TestHandshakeMessageClientKeyExchange(t *testing.T) { diff --git a/pkg/protocol/handshake/message_finished.go b/pkg/protocol/handshake/message_finished.go index 255aedd7e..f7187d88a 100644 --- a/pkg/protocol/handshake/message_finished.go +++ b/pkg/protocol/handshake/message_finished.go @@ -13,18 +13,19 @@ type MessageFinished struct { VerifyData []byte } -// Type returns the Handshake Type +// Type returns the Handshake Type. func (m MessageFinished) Type() Type { return TypeFinished } -// Marshal encodes the Handshake +// Marshal encodes the Handshake. func (m *MessageFinished) Marshal() ([]byte, error) { return append([]byte{}, m.VerifyData...), nil } -// Unmarshal populates the message from encoded data +// Unmarshal populates the message from encoded data. func (m *MessageFinished) Unmarshal(data []byte) error { m.VerifyData = append([]byte{}, data...) + return nil } diff --git a/pkg/protocol/handshake/message_hello_verify_request.go b/pkg/protocol/handshake/message_hello_verify_request.go index 398e59cc3..7f5bc95aa 100644 --- a/pkg/protocol/handshake/message_hello_verify_request.go +++ b/pkg/protocol/handshake/message_hello_verify_request.go @@ -4,7 +4,7 @@ package handshake import ( - "github.com/pion/dtls/v2/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol" ) // MessageHelloVerifyRequest is as follows: @@ -27,12 +27,12 @@ type MessageHelloVerifyRequest struct { Cookie []byte } -// Type returns the Handshake Type +// Type returns the Handshake Type. func (m MessageHelloVerifyRequest) Type() Type { return TypeHelloVerifyRequest } -// Marshal encodes the Handshake +// Marshal encodes the Handshake. func (m *MessageHelloVerifyRequest) Marshal() ([]byte, error) { if len(m.Cookie) > 255 { return nil, errCookieTooLong @@ -47,7 +47,7 @@ func (m *MessageHelloVerifyRequest) Marshal() ([]byte, error) { return out, nil } -// Unmarshal populates the message from encoded data +// Unmarshal populates the message from encoded data. func (m *MessageHelloVerifyRequest) Unmarshal(data []byte) error { if len(data) < 3 { return errBufferTooSmall @@ -61,5 +61,6 @@ func (m *MessageHelloVerifyRequest) Unmarshal(data []byte) error { m.Cookie = make([]byte, cookieLength) copy(m.Cookie, data[3:3+cookieLength]) + return nil } diff --git a/pkg/protocol/handshake/message_hello_verify_request_test.go b/pkg/protocol/handshake/message_hello_verify_request_test.go index 3513fc8ee..0cfd24b70 100644 --- a/pkg/protocol/handshake/message_hello_verify_request_test.go +++ b/pkg/protocol/handshake/message_hello_verify_request_test.go @@ -7,7 +7,7 @@ import ( "reflect" "testing" - "github.com/pion/dtls/v2/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol" ) func TestHandshakeMessageHelloVerifyRequest(t *testing.T) { @@ -17,7 +17,10 @@ func TestHandshakeMessageHelloVerifyRequest(t *testing.T) { } parsedHelloVerifyRequest := &MessageHelloVerifyRequest{ Version: protocol.Version{Major: 0xFE, Minor: 0xFF}, - Cookie: []byte{0x25, 0xfb, 0xee, 0xb3, 0x7c, 0x95, 0xcf, 0x00, 0xeb, 0xad, 0xe2, 0xef, 0xc7, 0xfd, 0xbb, 0xed, 0xf7, 0x1f, 0x6c, 0xcd}, + Cookie: []byte{ + 0x25, 0xfb, 0xee, 0xb3, 0x7c, 0x95, 0xcf, 0x00, 0xeb, 0xad, + 0xe2, 0xef, 0xc7, 0xfd, 0xbb, 0xed, 0xf7, 0x1f, 0x6c, 0xcd, + }, } h := &MessageHelloVerifyRequest{} diff --git a/pkg/protocol/handshake/message_server_hello.go b/pkg/protocol/handshake/message_server_hello.go index caf186da8..e2f19c1d9 100644 --- a/pkg/protocol/handshake/message_server_hello.go +++ b/pkg/protocol/handshake/message_server_hello.go @@ -6,8 +6,8 @@ package handshake import ( "encoding/binary" - "github.com/pion/dtls/v2/pkg/protocol" - "github.com/pion/dtls/v2/pkg/protocol/extension" + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/extension" ) // MessageServerHello is sent in response to a ClientHello @@ -29,12 +29,12 @@ type MessageServerHello struct { const messageServerHelloVariableWidthStart = 2 + RandomLength -// Type returns the Handshake Type +// Type returns the Handshake Type. func (m MessageServerHello) Type() Type { return TypeServerHello } -// Marshal encodes the Handshake +// Marshal encodes the Handshake. func (m *MessageServerHello) Marshal() ([]byte, error) { if m.CipherSuiteID == nil { return nil, errCipherSuiteUnset @@ -49,23 +49,23 @@ func (m *MessageServerHello) Marshal() ([]byte, error) { rand := m.Random.MarshalFixed() copy(out[2:], rand[:]) - out = append(out, byte(len(m.SessionID))) - out = append(out, m.SessionID...) + out = append(out, byte(len(m.SessionID))) //nolint:makezero // todo: fix + out = append(out, m.SessionID...) //nolint:makezero // todo: fix - out = append(out, []byte{0x00, 0x00}...) + out = append(out, []byte{0x00, 0x00}...) //nolint:makezero // todo: fix binary.BigEndian.PutUint16(out[len(out)-2:], *m.CipherSuiteID) - out = append(out, byte(m.CompressionMethod.ID)) + out = append(out, byte(m.CompressionMethod.ID)) //nolint:makezero // todo: fix extensions, err := extension.Marshal(m.Extensions) if err != nil { return nil, err } - return append(out, extensions...), nil + return append(out, extensions...), nil //nolint:makezero // todo: fix } -// Unmarshal populates the message from encoded data +// Unmarshal populates the message from encoded data. func (m *MessageServerHello) Unmarshal(data []byte) error { if len(data) < 2+RandomLength { return errBufferTooSmall @@ -110,6 +110,7 @@ func (m *MessageServerHello) Unmarshal(data []byte) error { if len(data) <= currOffset { m.Extensions = []extension.Extension{} + return nil } @@ -118,5 +119,6 @@ func (m *MessageServerHello) Unmarshal(data []byte) error { return err } m.Extensions = extensions + return nil } diff --git a/pkg/protocol/handshake/message_server_hello_done.go b/pkg/protocol/handshake/message_server_hello_done.go index b187dd417..49a830c2a 100644 --- a/pkg/protocol/handshake/message_server_hello_done.go +++ b/pkg/protocol/handshake/message_server_hello_done.go @@ -5,20 +5,20 @@ package handshake // MessageServerHelloDone is final non-encrypted message from server // this communicates server has sent all its handshake messages and next -// should be MessageFinished +// should be MessageFinished. type MessageServerHelloDone struct{} -// Type returns the Handshake Type +// Type returns the Handshake Type. func (m MessageServerHelloDone) Type() Type { return TypeServerHelloDone } -// Marshal encodes the Handshake +// Marshal encodes the Handshake. func (m *MessageServerHelloDone) Marshal() ([]byte, error) { return []byte{}, nil } -// Unmarshal populates the message from encoded data +// Unmarshal populates the message from encoded data. func (m *MessageServerHelloDone) Unmarshal([]byte) error { return nil } diff --git a/pkg/protocol/handshake/message_server_hello_test.go b/pkg/protocol/handshake/message_server_hello_test.go index 14b713fcd..810dcd348 100644 --- a/pkg/protocol/handshake/message_server_hello_test.go +++ b/pkg/protocol/handshake/message_server_hello_test.go @@ -9,8 +9,8 @@ import ( "testing" "time" - "github.com/pion/dtls/v2/pkg/protocol" - "github.com/pion/dtls/v2/pkg/protocol/extension" + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/extension" ) func TestHandshakeMessageServerHello(t *testing.T) { @@ -27,7 +27,10 @@ func TestHandshakeMessageServerHello(t *testing.T) { Version: protocol.Version{Major: 0xFE, Minor: 0xFD}, Random: Random{ GMTUnixTime: time.Unix(560149025, 0), - RandomBytes: [28]byte{0x81, 0x0e, 0x98, 0x6c, 0x85, 0x3d, 0xa4, 0x39, 0xaf, 0x5f, 0xd6, 0x5c, 0xcc, 0x20, 0x7f, 0x7c, 0x78, 0xf1, 0x5f, 0x7e, 0x1c, 0xb7, 0xa1, 0x1e, 0xcf, 0x63, 0x84, 0x28}, + RandomBytes: [28]byte{ + 0x81, 0x0e, 0x98, 0x6c, 0x85, 0x3d, 0xa4, 0x39, 0xaf, 0x5f, 0xd6, 0x5c, 0xcc, 0x20, + 0x7f, 0x7c, 0x78, 0xf1, 0x5f, 0x7e, 0x1c, 0xb7, 0xa1, 0x1e, 0xcf, 0x63, 0x84, 0x28, + }, }, SessionID: []byte{}, CipherSuiteID: &cipherSuiteID, diff --git a/pkg/protocol/handshake/message_server_key_exchange.go b/pkg/protocol/handshake/message_server_key_exchange.go index 82abbe0d4..59a5392f3 100644 --- a/pkg/protocol/handshake/message_server_key_exchange.go +++ b/pkg/protocol/handshake/message_server_key_exchange.go @@ -6,13 +6,13 @@ package handshake import ( "encoding/binary" - "github.com/pion/dtls/v2/internal/ciphersuite/types" - "github.com/pion/dtls/v2/pkg/crypto/elliptic" - "github.com/pion/dtls/v2/pkg/crypto/hash" - "github.com/pion/dtls/v2/pkg/crypto/signature" + "github.com/pion/dtls/v3/internal/ciphersuite/types" + "github.com/pion/dtls/v3/pkg/crypto/elliptic" + "github.com/pion/dtls/v3/pkg/crypto/hash" + "github.com/pion/dtls/v3/pkg/crypto/signature" ) -// MessageServerKeyExchange supports ECDH and PSK +// MessageServerKeyExchange supports ECDH and PSK. type MessageServerKeyExchange struct { IdentityHint []byte @@ -27,17 +27,17 @@ type MessageServerKeyExchange struct { KeyExchangeAlgorithm types.KeyExchangeAlgorithm } -// Type returns the Handshake Type +// Type returns the Handshake Type. func (m MessageServerKeyExchange) Type() Type { return TypeServerKeyExchange } -// Marshal encodes the Handshake -func (m *MessageServerKeyExchange) Marshal() ([]byte, error) { +// Marshal encodes the Handshake. +func (m *MessageServerKeyExchange) Marshal() ([]byte, error) { //nolint:cyclop var out []byte if m.IdentityHint != nil { out = append([]byte{0x00, 0x00}, m.IdentityHint...) - binary.BigEndian.PutUint16(out, uint16(len(out)-2)) + binary.BigEndian.PutUint16(out, uint16(len(out)-2)) //nolint:gosec //G115 } if m.EllipticCurveType == 0 || len(m.PublicKey) == 0 { @@ -60,14 +60,14 @@ func (m *MessageServerKeyExchange) Marshal() ([]byte, error) { } out = append(out, []byte{byte(m.HashAlgorithm), byte(m.SignatureAlgorithm), 0x00, 0x00}...) - binary.BigEndian.PutUint16(out[len(out)-2:], uint16(len(m.Signature))) + binary.BigEndian.PutUint16(out[len(out)-2:], uint16(len(m.Signature))) //nolint:gosec // G115 out = append(out, m.Signature...) return out, nil } -// Unmarshal populates the message from encoded data -func (m *MessageServerKeyExchange) Unmarshal(data []byte) error { +// Unmarshal populates the message from encoded data. +func (m *MessageServerKeyExchange) Unmarshal(data []byte) error { //nolint:cyclop switch { case len(data) < 2: return errBufferTooSmall @@ -84,6 +84,7 @@ func (m *MessageServerKeyExchange) Unmarshal(data []byte) error { if len(data) == 0 { return nil } + return errLengthMismatch } @@ -144,5 +145,6 @@ func (m *MessageServerKeyExchange) Unmarshal(data []byte) error { return errBufferTooSmall } m.Signature = append([]byte{}, data[offset:offset+signatureLength]...) + return nil } diff --git a/pkg/protocol/handshake/message_server_key_exchange_test.go b/pkg/protocol/handshake/message_server_key_exchange_test.go index acfc17964..fc4573cf3 100644 --- a/pkg/protocol/handshake/message_server_key_exchange_test.go +++ b/pkg/protocol/handshake/message_server_key_exchange_test.go @@ -7,10 +7,10 @@ import ( "reflect" "testing" - "github.com/pion/dtls/v2/internal/ciphersuite/types" - "github.com/pion/dtls/v2/pkg/crypto/elliptic" - "github.com/pion/dtls/v2/pkg/crypto/hash" - "github.com/pion/dtls/v2/pkg/crypto/signature" + "github.com/pion/dtls/v3/internal/ciphersuite/types" + "github.com/pion/dtls/v3/pkg/crypto/elliptic" + "github.com/pion/dtls/v3/pkg/crypto/hash" + "github.com/pion/dtls/v3/pkg/crypto/signature" ) func TestHandshakeMessageServerKeyExchange(t *testing.T) { @@ -32,7 +32,7 @@ func TestHandshakeMessageServerKeyExchange(t *testing.T) { } } - t.Run("Hash+Signature", func(t *testing.T) { + t.Run("Hash+Signature", func(*testing.T) { rawServerKeyExchange := []byte{ 0x03, 0x00, 0x1d, 0x41, 0x04, 0x0c, 0xb9, 0xa3, 0xb9, 0x90, 0x71, 0x35, 0x4a, 0x08, 0x66, 0xaf, 0xd6, 0x88, 0x58, 0x29, 0x69, 0x98, 0xf1, 0x87, 0x0f, 0xb5, 0xa8, 0xcd, 0x92, 0xf6, 0x2b, 0x08, @@ -57,7 +57,7 @@ func TestHandshakeMessageServerKeyExchange(t *testing.T) { test(rawServerKeyExchange, parsedServerKeyExchange) }) - t.Run("Anonymous", func(t *testing.T) { + t.Run("Anonymous", func(*testing.T) { rawServerKeyExchange := []byte{ 0x03, 0x00, 0x1d, 0x41, 0x04, 0x0c, 0xb9, 0xa3, 0xb9, 0x90, 0x71, 0x35, 0x4a, 0x08, 0x66, 0xaf, 0xd6, 0x88, 0x58, 0x29, 0x69, 0x98, 0xf1, 0x87, 0x0f, 0xb5, 0xa8, 0xcd, 0x92, 0xf6, 0x2b, 0x08, diff --git a/pkg/protocol/handshake/random.go b/pkg/protocol/handshake/random.go index 56f37569b..6eb2815f4 100644 --- a/pkg/protocol/handshake/random.go +++ b/pkg/protocol/handshake/random.go @@ -9,7 +9,7 @@ import ( "time" ) -// Consts for Random in Handshake +// Consts for Random in Handshake. const ( RandomBytesLength = 28 RandomLength = RandomBytesLength + 4 @@ -23,24 +23,24 @@ type Random struct { RandomBytes [RandomBytesLength]byte } -// MarshalFixed encodes the Handshake +// MarshalFixed encodes the Handshake. func (r *Random) MarshalFixed() [RandomLength]byte { var out [RandomLength]byte - binary.BigEndian.PutUint32(out[0:], uint32(r.GMTUnixTime.Unix())) + binary.BigEndian.PutUint32(out[0:], uint32(r.GMTUnixTime.Unix())) //nolint:gosec // G115 copy(out[4:], r.RandomBytes[:]) return out } -// UnmarshalFixed populates the message from encoded data +// UnmarshalFixed populates the message from encoded data. func (r *Random) UnmarshalFixed(data [RandomLength]byte) { r.GMTUnixTime = time.Unix(int64(binary.BigEndian.Uint32(data[0:])), 0) copy(r.RandomBytes[:], data[4:]) } // Populate fills the handshakeRandom with random values -// may be called multiple times +// may be called multiple times. func (r *Random) Populate() error { r.GMTUnixTime = time.Now() diff --git a/pkg/protocol/recordlayer/errors.go b/pkg/protocol/recordlayer/errors.go index cd4cb60a5..09599249b 100644 --- a/pkg/protocol/recordlayer/errors.go +++ b/pkg/protocol/recordlayer/errors.go @@ -7,13 +7,18 @@ package recordlayer import ( "errors" - "github.com/pion/dtls/v2/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol" ) var ( - errBufferTooSmall = &protocol.TemporaryError{Err: errors.New("buffer is too small")} //nolint:goerr113 - errInvalidPacketLength = &protocol.TemporaryError{Err: errors.New("packet length and declared length do not match")} //nolint:goerr113 - errSequenceNumberOverflow = &protocol.InternalError{Err: errors.New("sequence number overflow")} //nolint:goerr113 - errUnsupportedProtocolVersion = &protocol.FatalError{Err: errors.New("unsupported protocol version")} //nolint:goerr113 - errInvalidContentType = &protocol.TemporaryError{Err: errors.New("invalid content type")} //nolint:goerr113 + // ErrInvalidPacketLength is returned when the packet length too small + // or declared length do not match. + ErrInvalidPacketLength = &protocol.TemporaryError{ + Err: errors.New("packet length and declared length do not match"), //nolint:goerr113 + } + + errBufferTooSmall = &protocol.TemporaryError{Err: errors.New("buffer is too small")} //nolint:goerr113 + errSequenceNumberOverflow = &protocol.InternalError{Err: errors.New("sequence number overflow")} //nolint:goerr113 + errUnsupportedProtocolVersion = &protocol.FatalError{Err: errors.New("unsupported protocol version")} //nolint:goerr113 + errInvalidContentType = &protocol.TemporaryError{Err: errors.New("invalid content type")} //nolint:goerr113 ) diff --git a/pkg/protocol/recordlayer/fuzz_test.go b/pkg/protocol/recordlayer/fuzz_test.go index 6379d9d17..24918d2ee 100644 --- a/pkg/protocol/recordlayer/fuzz_test.go +++ b/pkg/protocol/recordlayer/fuzz_test.go @@ -7,35 +7,30 @@ import ( "testing" ) -func partialHeaderMismatch(a, b Header) bool { - // Ignoring content length for now. - a.ContentLen = b.ContentLen - return a != b -} - func FuzzRecordLayer(f *testing.F) { - f.Fuzz(func(t *testing.T, data []byte) { - var r RecordLayer - if err := r.Unmarshal(data); err != nil { - return - } - - buf, err := r.Marshal() - if err != nil { - return - } + Data := []byte{ + 0x14, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, 0x00, 0x01, 0x01, + } + f.Add(Data) - if len(buf) == 0 { - t.Fatal("Zero buff") - } - - var nr RecordLayer - if err = nr.Unmarshal(data); err != nil { - t.Fatal(err) - } + f.Fuzz(func(_ *testing.T, data []byte) { + var r RecordLayer + _ = r.Unmarshal(data) + }) +} - if partialHeaderMismatch(nr.Header, r.Header) { - t.Fatalf("Header mismatch: %+v != %+v", nr.Header, r.Header) - } +func FuzzUnpackDatagram(f *testing.F) { + Datasingle := []byte{ + 0x14, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, 0x00, 0x01, 0x01, + } + Datamulti := []byte{ + 0x14, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, 0x00, 0x01, 0x01, + 0x14, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x13, 0x00, 0x01, 0x01, + } + f.Add(Datasingle) + f.Add(Datamulti) + + f.Fuzz(func(_ *testing.T, data []byte) { + _, _ = UnpackDatagram(data) }) } diff --git a/pkg/protocol/recordlayer/header.go b/pkg/protocol/recordlayer/header.go index 92252502b..47af855df 100644 --- a/pkg/protocol/recordlayer/header.go +++ b/pkg/protocol/recordlayer/header.go @@ -6,47 +6,64 @@ package recordlayer import ( "encoding/binary" - "github.com/pion/dtls/v2/internal/util" - "github.com/pion/dtls/v2/pkg/protocol" + "github.com/pion/dtls/v3/internal/util" + "github.com/pion/dtls/v3/pkg/protocol" ) -// Header implements a TLS RecordLayer header +// Header implements a TLS RecordLayer header. type Header struct { ContentType protocol.ContentType ContentLen uint16 Version protocol.Version Epoch uint16 SequenceNumber uint64 // uint48 in spec + + // Optional Fields + ConnectionID []byte } -// RecordLayer enums +// RecordLayer enums. const ( - HeaderSize = 13 + // FixedHeaderSize is the size of a DTLS record header when connection IDs + // are not in use. + FixedHeaderSize = 13 MaxSequenceNumber = 0x0000FFFFFFFFFFFF ) -// Marshal encodes a TLS RecordLayer Header to binary +// Marshal encodes a TLS RecordLayer Header to binary. func (h *Header) Marshal() ([]byte, error) { if h.SequenceNumber > MaxSequenceNumber { return nil, errSequenceNumberOverflow } - out := make([]byte, HeaderSize) + hs := FixedHeaderSize + len(h.ConnectionID) + + out := make([]byte, hs) out[0] = byte(h.ContentType) out[1] = h.Version.Major out[2] = h.Version.Minor binary.BigEndian.PutUint16(out[3:], h.Epoch) util.PutBigEndianUint48(out[5:], h.SequenceNumber) - binary.BigEndian.PutUint16(out[HeaderSize-2:], h.ContentLen) + copy(out[11:11+len(h.ConnectionID)], h.ConnectionID) + binary.BigEndian.PutUint16(out[hs-2:], h.ContentLen) + return out, nil } -// Unmarshal populates a TLS RecordLayer Header from binary +// Unmarshal populates a TLS RecordLayer Header from binary. func (h *Header) Unmarshal(data []byte) error { - if len(data) < HeaderSize { + if len(data) < FixedHeaderSize { return errBufferTooSmall } h.ContentType = protocol.ContentType(data[0]) + if h.ContentType == protocol.ContentTypeConnectionID { + // If a CID was expected the ConnectionID should have been initialized. + if len(data) < FixedHeaderSize+len(h.ConnectionID) { + return errBufferTooSmall + } + h.ConnectionID = data[11 : 11+len(h.ConnectionID)] + } + h.Version.Major = data[1] h.Version.Minor = data[2] h.Epoch = binary.BigEndian.Uint16(data[3:]) @@ -62,3 +79,8 @@ func (h *Header) Unmarshal(data []byte) error { return nil } + +// Size returns the total size of the header. +func (h *Header) Size() int { + return FixedHeaderSize + len(h.ConnectionID) +} diff --git a/pkg/protocol/recordlayer/inner_plaintext.go b/pkg/protocol/recordlayer/inner_plaintext.go new file mode 100644 index 000000000..2c67c86f2 --- /dev/null +++ b/pkg/protocol/recordlayer/inner_plaintext.go @@ -0,0 +1,49 @@ +// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> +// SPDX-License-Identifier: MIT + +package recordlayer + +import ( + "github.com/pion/dtls/v3/pkg/protocol" + "golang.org/x/crypto/cryptobyte" +) + +// InnerPlaintext implements DTLSInnerPlaintext +// +// https://datatracker.ietf.org/doc/html/rfc9146#name-record-layer-extensions +type InnerPlaintext struct { + Content []byte + RealType protocol.ContentType + Zeros uint +} + +// Marshal encodes a DTLS InnerPlaintext to binary. +func (p *InnerPlaintext) Marshal() ([]byte, error) { + var out cryptobyte.Builder + out.AddBytes(p.Content) + out.AddUint8(uint8(p.RealType)) + out.AddBytes(make([]byte, p.Zeros)) + + return out.Bytes() +} + +// Unmarshal populates a DTLS InnerPlaintext from binary. +func (p *InnerPlaintext) Unmarshal(data []byte) error { + // Process in reverse + i := len(data) - 1 + for i >= 0 { + if data[i] != 0 { + p.Zeros = uint(len(data) - 1 - i) //nolint:gosec // G115 + + break + } + i-- + } + if i == 0 { + return errBufferTooSmall + } + p.RealType = protocol.ContentType(data[i]) + p.Content = append([]byte{}, data[:i]...) + + return nil +} diff --git a/pkg/protocol/recordlayer/recordlayer.go b/pkg/protocol/recordlayer/recordlayer.go index 02325fd2d..95113da4f 100644 --- a/pkg/protocol/recordlayer/recordlayer.go +++ b/pkg/protocol/recordlayer/recordlayer.go @@ -6,11 +6,28 @@ package recordlayer import ( "encoding/binary" - "github.com/pion/dtls/v2/pkg/protocol" - "github.com/pion/dtls/v2/pkg/protocol/alert" - "github.com/pion/dtls/v2/pkg/protocol/handshake" + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/alert" + "github.com/pion/dtls/v3/pkg/protocol/handshake" ) +// DTLS fixed size record layer header when Connection IDs are not in-use. + +// --------------------------------- +// | Type | Version | Epoch | +// --------------------------------- +// | Epoch | Sequence Number | +// --------------------------------- +// | Sequence Number | Length | +// --------------------------------- +// | Length | Fragment... | +// --------------------------------- + +// fixedHeaderLenIdx is the index at which the record layer content length is +// specified in a fixed length header (i.e. one that does not include a +// Connection ID). +const fixedHeaderLenIdx = 11 + // RecordLayer which handles all data transport. // The record layer is assumed to sit directly on top of some // reliable transport such as TCP. The record layer can carry four types of content: @@ -31,14 +48,14 @@ type RecordLayer struct { Content protocol.Content } -// Marshal encodes the RecordLayer to binary +// Marshal encodes the RecordLayer to binary. func (r *RecordLayer) Marshal() ([]byte, error) { contentRaw, err := r.Content.Marshal() if err != nil { return nil, err } - r.Header.ContentLen = uint16(len(contentRaw)) + r.Header.ContentLen = uint16(len(contentRaw)) //nolint:gosec // G115 r.Header.ContentType = r.Content.ContentType() headerRaw, err := r.Header.Marshal() @@ -49,16 +66,13 @@ func (r *RecordLayer) Marshal() ([]byte, error) { return append(headerRaw, contentRaw...), nil } -// Unmarshal populates the RecordLayer from binary +// Unmarshal populates the RecordLayer from binary. func (r *RecordLayer) Unmarshal(data []byte) error { - if len(data) < HeaderSize { - return errBufferTooSmall - } if err := r.Header.Unmarshal(data); err != nil { return err } - switch protocol.ContentType(data[0]) { + switch r.Header.ContentType { case protocol.ContentTypeChangeCipherSpec: r.Content = &protocol.ChangeCipherSpec{} case protocol.ContentTypeAlert: @@ -71,7 +85,7 @@ func (r *RecordLayer) Unmarshal(data []byte) error { return errInvalidContentType } - return r.Content.Unmarshal(data[HeaderSize:]) + return r.Content.Unmarshal(data[r.Header.Size()+len(r.Header.ConnectionID):]) } // UnpackDatagram extracts all RecordLayer messages from a single datagram. @@ -85,13 +99,42 @@ func UnpackDatagram(buf []byte) ([][]byte, error) { out := [][]byte{} for offset := 0; len(buf) != offset; { - if len(buf)-offset <= HeaderSize { - return nil, errInvalidPacketLength + if len(buf)-offset <= FixedHeaderSize { + return nil, ErrInvalidPacketLength + } + + pktLen := (FixedHeaderSize + int(binary.BigEndian.Uint16(buf[offset+11:]))) + if offset+pktLen > len(buf) { + return nil, ErrInvalidPacketLength + } + + out = append(out, buf[offset:offset+pktLen]) + offset += pktLen + } + + return out, nil +} + +// ContentAwareUnpackDatagram is the same as UnpackDatagram but considers the +// presence of a connection identifier if the record is of content type +// tls12_cid. +func ContentAwareUnpackDatagram(buf []byte, cidLength int) ([][]byte, error) { + out := [][]byte{} + + for offset := 0; len(buf) != offset; { + headerSize := FixedHeaderSize + lenIdx := fixedHeaderLenIdx + if protocol.ContentType(buf[offset]) == protocol.ContentTypeConnectionID { + headerSize += cidLength + lenIdx += cidLength + } + if len(buf)-offset <= headerSize { + return nil, ErrInvalidPacketLength } - pktLen := (HeaderSize + int(binary.BigEndian.Uint16(buf[offset+11:]))) + pktLen := (headerSize + int(binary.BigEndian.Uint16(buf[offset+lenIdx:]))) if offset+pktLen > len(buf) { - return nil, errInvalidPacketLength + return nil, ErrInvalidPacketLength } out = append(out, buf[offset:offset+pktLen]) diff --git a/pkg/protocol/recordlayer/recordlayer_test.go b/pkg/protocol/recordlayer/recordlayer_test.go index 2e16c0104..c568c9147 100644 --- a/pkg/protocol/recordlayer/recordlayer_test.go +++ b/pkg/protocol/recordlayer/recordlayer_test.go @@ -8,7 +8,7 @@ import ( "reflect" "testing" - "github.com/pion/dtls/v2/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol" ) func TestUDPDecode(t *testing.T) { @@ -39,12 +39,12 @@ func TestUDPDecode(t *testing.T) { { Name: "Invalid packet length", Data: []byte{0x14, 0xfe}, - WantError: errInvalidPacketLength, + WantError: ErrInvalidPacketLength, }, { Name: "Packet declared invalid length", Data: []byte{0x14, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, 0x00, 0xFF, 0x01}, - WantError: errInvalidPacketLength, + WantError: ErrInvalidPacketLength, }, } { dtlsPkts, err := UnpackDatagram(test.Data) diff --git a/pkg/protocol/version.go b/pkg/protocol/version.go index c4d94ac3a..3943c1504 100644 --- a/pkg/protocol/version.go +++ b/pkg/protocol/version.go @@ -4,7 +4,7 @@ // Package protocol provides the DTLS wire format package protocol -// Version enums +// Version enums. var ( Version1_0 = Version{Major: 0xfe, Minor: 0xff} //nolint:gochecknoglobals Version1_2 = Version{Major: 0xfe, Minor: 0xfd} //nolint:gochecknoglobals @@ -18,7 +18,7 @@ type Version struct { Major, Minor uint8 } -// Equal determines if two protocol versions are equal +// Equal determines if two protocol versions are equal. func (v Version) Equal(x Version) bool { return v.Major == x.Major && v.Minor == x.Minor } diff --git a/replayprotection_test.go b/replayprotection_test.go index 3f3a8b4f0..e0984ef62 100644 --- a/replayprotection_test.go +++ b/replayprotection_test.go @@ -12,11 +12,11 @@ import ( "testing" "time" - "github.com/pion/transport/v2/dpipe" - "github.com/pion/transport/v2/test" + "github.com/pion/transport/v3/dpipe" + "github.com/pion/transport/v3/test" ) -func TestReplayProtection(t *testing.T) { +func TestReplayProtection(t *testing.T) { //nolint:cyclop // Limit runtime in case of deadlocks lim := test.TimeOut(5 * time.Second) defer lim.Stop() @@ -52,6 +52,7 @@ func TestReplayProtection(t *testing.T) { } if _, werr := cb.Write(b[:n]); werr != nil { t.Error(werr) + return } @@ -109,10 +110,12 @@ func TestReplayProtection(t *testing.T) { sent = append(sent, data) if _, werr := ca.Write(data); werr != nil { t.Error(werr) + return } if _, werr := cb.Write(data); werr != nil { t.Error(werr) + return } } diff --git a/resume.go b/resume.go index c470d856b..954907dd0 100644 --- a/resume.go +++ b/resume.go @@ -4,19 +4,14 @@ package dtls import ( - "context" "net" ) -// Resume imports an already established dtls connection using a specific dtls state -func Resume(state *State, conn net.Conn, config *Config) (*Conn, error) { +// Resume imports an already established dtls connection using a specific dtls state. +func Resume(state *State, conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) { if err := state.initCipherSuite(); err != nil { return nil, err } - c, err := createConn(context.Background(), conn, config, state.isClient, state) - if err != nil { - return nil, err - } - return c, nil + return createConn(conn, rAddr, config, state.isClient, state) } diff --git a/resume_test.go b/resume_test.go index c8c231b86..a3ce26cec 100644 --- a/resume_test.go +++ b/resume_test.go @@ -13,11 +13,15 @@ import ( "testing" "time" - "github.com/pion/dtls/v2/pkg/crypto/selfsign" - "github.com/pion/transport/v2/test" + "github.com/pion/dtls/v3/pkg/crypto/selfsign" + dtlsnet "github.com/pion/dtls/v3/pkg/net" + "github.com/pion/transport/v3/test" ) -var errMessageMissmatch = errors.New("messages missmatch") +var ( + errMessageMissmatch = errors.New("messages missmatch") + errInvalidConnectionState = errors.New("failed to get connection state") +) func TestResumeClient(t *testing.T) { DoTestResume(t, Client, Server) @@ -28,11 +32,20 @@ func TestResumeServer(t *testing.T) { } func fatal(t *testing.T, errChan chan error, err error) { + t.Helper() + close(errChan) t.Fatal(err) } -func DoTestResume(t *testing.T, newLocal, newRemote func(net.Conn, *Config) (*Conn, error)) { +//nolint:cyclop +func DoTestResume( + t *testing.T, + newLocal, + newRemote func(net.PacketConn, net.Addr, *Config) (*Conn, error), +) { + t.Helper() + // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() @@ -67,7 +80,7 @@ func DoTestResume(t *testing.T, newLocal, newRemote func(net.Conn, *Config) (*Co go func() { var remote *Conn var errR error - remote, errR = newRemote(remoteConn, config) + remote, errR = newRemote(dtlsnet.PacketConnFromConn(remoteConn), remoteConn.RemoteAddr(), config) if errR != nil { errChan <- errR } @@ -89,7 +102,7 @@ func DoTestResume(t *testing.T, newLocal, newRemote func(net.Conn, *Config) (*Co }() var local *Conn - local, err = newLocal(localConn1, config) + local, err = newLocal(dtlsnet.PacketConnFromConn(localConn1), localConn1.RemoteAddr(), config) if err != nil { fatal(t, errChan, err) } @@ -119,7 +132,10 @@ func DoTestResume(t *testing.T, newLocal, newRemote func(net.Conn, *Config) (*Co } // Serialize and deserialize state - state := local.ConnectionState() + state, ok := local.ConnectionState() + if !ok { + fatal(t, errChan, errInvalidConnectionState) + } var b []byte b, err = state.MarshalBinary() if err != nil { @@ -132,7 +148,7 @@ func DoTestResume(t *testing.T, newLocal, newRemote func(net.Conn, *Config) (*Co // Resume dtls connection var resumed net.Conn - resumed, err = Resume(deserialized, localConn2, config) + resumed, err = Resume(deserialized, dtlsnet.PacketConnFromConn(localConn2), localConn2.RemoteAddr(), config) if err != nil { fatal(t, errChan, err) } @@ -169,8 +185,10 @@ func (b *backupConn) Read(data []byte) (n int, err error) { b.curr = b.next b.next = nil b.mux.Unlock() + return b.Read(data) } + return n, err } @@ -181,8 +199,10 @@ func (b *backupConn) Write(data []byte) (n int, err error) { b.curr = b.next b.next = nil b.mux.Unlock() + return b.Write(data) } + return n, err } diff --git a/session.go b/session.go index 99bf5a499..912a5997a 100644 --- a/session.go +++ b/session.go @@ -3,7 +3,7 @@ package dtls -// Session store data needed in resumption +// Session store data needed in resumption. type Session struct { // ID store session id ID []byte diff --git a/srtp_protection_profile.go b/srtp_protection_profile.go index e306e9e6a..bc242095b 100644 --- a/srtp_protection_profile.go +++ b/srtp_protection_profile.go @@ -3,7 +3,7 @@ package dtls -import "github.com/pion/dtls/v2/pkg/protocol/extension" +import "github.com/pion/dtls/v3/pkg/protocol/extension" // SRTPProtectionProfile defines the parameters and options that are in effect for the SRTP processing // https://tools.ietf.org/html/rfc5764#section-4.1.2 @@ -12,6 +12,10 @@ type SRTPProtectionProfile = extension.SRTPProtectionProfile const ( SRTP_AES128_CM_HMAC_SHA1_80 SRTPProtectionProfile = extension.SRTP_AES128_CM_HMAC_SHA1_80 // nolint:revive,stylecheck SRTP_AES128_CM_HMAC_SHA1_32 SRTPProtectionProfile = extension.SRTP_AES128_CM_HMAC_SHA1_32 // nolint:revive,stylecheck + SRTP_AES256_CM_SHA1_80 SRTPProtectionProfile = extension.SRTP_AES256_CM_SHA1_80 // nolint:revive,stylecheck + SRTP_AES256_CM_SHA1_32 SRTPProtectionProfile = extension.SRTP_AES256_CM_SHA1_32 // nolint:revive,stylecheck + SRTP_NULL_HMAC_SHA1_80 SRTPProtectionProfile = extension.SRTP_NULL_HMAC_SHA1_80 // nolint:revive,stylecheck + SRTP_NULL_HMAC_SHA1_32 SRTPProtectionProfile = extension.SRTP_NULL_HMAC_SHA1_32 // nolint:revive,stylecheck SRTP_AEAD_AES_128_GCM SRTPProtectionProfile = extension.SRTP_AEAD_AES_128_GCM // nolint:revive,stylecheck SRTP_AEAD_AES_256_GCM SRTPProtectionProfile = extension.SRTP_AEAD_AES_256_GCM // nolint:revive,stylecheck ) diff --git a/state.go b/state.go index e9f86a80b..364f7dcc1 100644 --- a/state.go +++ b/state.go @@ -6,26 +6,46 @@ package dtls import ( "bytes" "encoding/gob" + "errors" "sync/atomic" - "github.com/pion/dtls/v2/pkg/crypto/elliptic" - "github.com/pion/dtls/v2/pkg/crypto/prf" - "github.com/pion/dtls/v2/pkg/protocol/handshake" - "github.com/pion/transport/v2/replaydetector" + "github.com/pion/dtls/v3/pkg/crypto/elliptic" + "github.com/pion/dtls/v3/pkg/crypto/prf" + "github.com/pion/dtls/v3/pkg/crypto/signaturehash" + "github.com/pion/dtls/v3/pkg/protocol/handshake" + "github.com/pion/transport/v3/replaydetector" ) -// State holds the dtls connection state and implements both encoding.BinaryMarshaler and encoding.BinaryUnmarshaler +// State holds the dtls connection state and implements both encoding.BinaryMarshaler and +// encoding.BinaryUnmarshaler. type State struct { localEpoch, remoteEpoch atomic.Value localSequenceNumber []uint64 // uint48 localRandom, remoteRandom handshake.Random masterSecret []byte cipherSuite CipherSuite // nil if a cipherSuite hasn't been chosen + CipherSuiteID CipherSuiteID - srtpProtectionProfile SRTPProtectionProfile // Negotiated SRTPProtectionProfile - PeerCertificates [][]byte - IdentityHint []byte - SessionID []byte + srtpProtectionProfile atomic.Value // Negotiated SRTPProtectionProfile + remoteSRTPMasterKeyIdentifier []byte + + PeerCertificates [][]byte + IdentityHint []byte + SessionID []byte + + // Connection Identifiers must be negotiated afresh on session resumption. + // https://datatracker.ietf.org/doc/html/rfc9146#name-the-connection_id-extension + + // localConnectionID is the locally generated connection ID that is expected + // to be received from the remote endpoint. + // For a server, this is the connection ID sent in ServerHello. + // For a client, this is the connection ID sent in the ClientHello. + localConnectionID atomic.Value + // remoteConnectionID is the connection ID that the remote endpoint + // specifies should be sent. + // For a server, this is the connection ID received in the ClientHello. + // For a client, this is the connection ID received in the ServerHello. + remoteConnectionID []byte isClient bool @@ -38,6 +58,7 @@ type State struct { handshakeSendSequence int handshakeRecvSequence int serverName string + remoteCertRequestAlgs []signaturehash.Algorithm remoteRequestedCertificate bool // Did we get a CertificateRequest localCertificatesVerify []byte // cache CertificateVerify localVerifyData []byte // cached VerifyData @@ -62,37 +83,54 @@ type serializedState struct { PeerCertificates [][]byte IdentityHint []byte SessionID []byte + LocalConnectionID []byte + RemoteConnectionID []byte IsClient bool + NegotiatedProtocol string } -func (s *State) clone() *State { - serialized := s.serialize() +var errCipherSuiteNotSet = &InternalError{Err: errors.New("cipher suite not set")} //nolint:goerr113 + +func (s *State) clone() (*State, error) { + serialized, err := s.serialize() + if err != nil { + return nil, err + } state := &State{} state.deserialize(*serialized) - return state + return state, err } -func (s *State) serialize() *serializedState { +func (s *State) serialize() (*serializedState, error) { + if s.cipherSuite == nil { + return nil, errCipherSuiteNotSet + } + cipherSuiteID := uint16(s.cipherSuite.ID()) + // Marshal random values localRnd := s.localRandom.MarshalFixed() remoteRnd := s.remoteRandom.MarshalFixed() epoch := s.getLocalEpoch() + return &serializedState{ LocalEpoch: s.getLocalEpoch(), RemoteEpoch: s.getRemoteEpoch(), - CipherSuiteID: uint16(s.cipherSuite.ID()), + CipherSuiteID: cipherSuiteID, MasterSecret: s.masterSecret, SequenceNumber: atomic.LoadUint64(&s.localSequenceNumber[epoch]), LocalRandom: localRnd, RemoteRandom: remoteRnd, - SRTPProtectionProfile: uint16(s.srtpProtectionProfile), + SRTPProtectionProfile: uint16(s.getSRTPProtectionProfile()), PeerCertificates: s.PeerCertificates, IdentityHint: s.IdentityHint, SessionID: s.SessionID, + LocalConnectionID: s.getLocalConnectionID(), + RemoteConnectionID: s.remoteConnectionID, IsClient: s.isClient, - } + NegotiatedProtocol: s.NegotiatedProtocol, + }, nil } func (s *State) deserialize(serialized serializedState) { @@ -120,15 +158,24 @@ func (s *State) deserialize(serialized serializedState) { s.masterSecret = serialized.MasterSecret // Set cipher suite - s.cipherSuite = cipherSuiteForID(CipherSuiteID(serialized.CipherSuiteID), nil) + s.CipherSuiteID = CipherSuiteID(serialized.CipherSuiteID) + s.cipherSuite = cipherSuiteForID(s.CipherSuiteID, nil) atomic.StoreUint64(&s.localSequenceNumber[epoch], serialized.SequenceNumber) - s.srtpProtectionProfile = SRTPProtectionProfile(serialized.SRTPProtectionProfile) + s.setSRTPProtectionProfile(SRTPProtectionProfile(serialized.SRTPProtectionProfile)) // Set remote certificate s.PeerCertificates = serialized.PeerCertificates + s.IdentityHint = serialized.IdentityHint + + // Set local and remote connection IDs + s.setLocalConnectionID(serialized.LocalConnectionID) + s.remoteConnectionID = serialized.RemoteConnectionID + s.SessionID = serialized.SessionID + + s.NegotiatedProtocol = serialized.NegotiatedProtocol } func (s *State) initCipherSuite() error { @@ -148,22 +195,27 @@ func (s *State) initCipherSuite() error { if err != nil { return err } + return nil } -// MarshalBinary is a binary.BinaryMarshaler.MarshalBinary implementation +// MarshalBinary is a binary.BinaryMarshaler.MarshalBinary implementation. func (s *State) MarshalBinary() ([]byte, error) { - serialized := s.serialize() + serialized, err := s.serialize() + if err != nil { + return nil, err + } var buf bytes.Buffer enc := gob.NewEncoder(&buf) if err := enc.Encode(*serialized); err != nil { return nil, err } + return buf.Bytes(), nil } -// UnmarshalBinary is a binary.BinaryUnmarshaler.UnmarshalBinary implementation +// UnmarshalBinary is a binary.BinaryUnmarshaler.UnmarshalBinary implementation. func (s *State) UnmarshalBinary(data []byte) error { enc := gob.NewDecoder(bytes.NewBuffer(data)) var serialized serializedState @@ -179,7 +231,7 @@ func (s *State) UnmarshalBinary(data []byte) error { // ExportKeyingMaterial returns length bytes of exported key material in a new // slice as defined in RFC 5705. // This allows protocols to use DTLS for key establishment, but -// then use some of the keying material for their own purposes +// then use some of the keying material for their own purposes. func (s *State) ExportKeyingMaterial(label string, context []byte, length int) ([]byte, error) { if s.getLocalEpoch() == 0 { return nil, errHandshakeInProgress @@ -198,6 +250,7 @@ func (s *State) ExportKeyingMaterial(label string, context []byte, length int) ( } else { seed = append(append(seed, remoteRandom[:]...), localRandom[:]...) } + return prf.PHash(s.masterSecret, seed, length, s.cipherSuite.HashFunc()) } @@ -205,6 +258,7 @@ func (s *State) getRemoteEpoch() uint16 { if remoteEpoch, ok := s.remoteEpoch.Load().(uint16); ok { return remoteEpoch } + return 0 } @@ -212,5 +266,35 @@ func (s *State) getLocalEpoch() uint16 { if localEpoch, ok := s.localEpoch.Load().(uint16); ok { return localEpoch } + return 0 } + +func (s *State) setSRTPProtectionProfile(profile SRTPProtectionProfile) { + s.srtpProtectionProfile.Store(profile) +} + +func (s *State) getSRTPProtectionProfile() SRTPProtectionProfile { + if val, ok := s.srtpProtectionProfile.Load().(SRTPProtectionProfile); ok { + return val + } + + return 0 +} + +func (s *State) getLocalConnectionID() []byte { + if val, ok := s.localConnectionID.Load().([]byte); ok { + return val + } + + return nil +} + +func (s *State) setLocalConnectionID(v []byte) { + s.localConnectionID.Store(v) +} + +// RemoteRandomBytes returns the remote client hello random bytes. +func (s *State) RemoteRandomBytes() [handshake.RandomBytesLength]byte { + return s.remoteRandom.RandomBytes +} diff --git a/testdata/seed/TestResumeClient.raw b/testdata/seed/TestResumeClient.raw new file mode 100644 index 000000000..c2f9cb04e Binary files /dev/null and b/testdata/seed/TestResumeClient.raw differ diff --git a/testdata/seed/TestResumeServer.raw b/testdata/seed/TestResumeServer.raw new file mode 100644 index 000000000..29b6f4d00 --- /dev/null +++ b/testdata/seed/TestResumeServer.raw @@ -0,0 +1 @@ +���serializedState��LocalEpochRemoteEpochLocalRandom��RemoteRandom��CipherSuiteIDMasterSecretSequenceNumberSRTPProtectionProfilePeerCertificates��IdentityHintSessionIDLocalConnectionIDRemoteConnectionIDIsClient��[32]uint8��@��[][]uint8������ e&M����K��'����/��kZO6g��MP��I����������} e&M����������@��Cs.=��5Z��{����rg:[����Le����+0y���*��[�\B�GEv�5�ٕ=�u�CwP9�r�48r�y��;k.E \ No newline at end of file diff --git a/util.go b/util.go index 663c4437c..3d9b0bc85 100644 --- a/util.go +++ b/util.go @@ -11,6 +11,7 @@ func findMatchingSRTPProfile(a, b []SRTPProtectionProfile) (SRTPProtectionProfil } } } + return 0, false } @@ -22,6 +23,7 @@ func findMatchingCipherSuite(a, b []CipherSuite) (CipherSuite, bool) { } } } + return nil, false }