// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> // SPDX-License-Identifier: MIT package dtls import ( "crypto/rand" "github.com/pion/dtls/v3/pkg/protocol" "github.com/pion/dtls/v3/pkg/protocol/extension" "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) // RandomCIDGenerator is a random Connection ID generator where CID is the // specified size. Specifying a size of 0 will indicate to peers that sending a // Connection ID is not necessary. func RandomCIDGenerator(size int) func() []byte { return func() []byte { cid := make([]byte, size) if _, err := rand.Read(cid); err != nil { panic(err) //nolint -- nonrecoverable } return cid } } // OnlySendCIDGenerator enables sending Connection IDs negotiated with a peer, // but indicates to the peer that sending Connection IDs in return is not // necessary. func OnlySendCIDGenerator() func() []byte { return func() []byte { return nil } } // cidDatagramRouter extracts connection IDs from incoming datagram payloads and // uses them to route to the proper connection. // NOTE: properly routing datagrams based on connection IDs requires using // constant size connection IDs. func cidDatagramRouter(size int) func([]byte) (string, bool) { return func(packet []byte) (string, bool) { pkts, err := recordlayer.ContentAwareUnpackDatagram(packet, size) if err != nil || len(pkts) < 1 { return "", false } for _, pkt := range pkts { h := &recordlayer.Header{ ConnectionID: make([]byte, size), } if err := h.Unmarshal(pkt); err != nil { continue } if h.ContentType != protocol.ContentTypeConnectionID { continue } return string(h.ConnectionID), true } return "", false } } // cidConnIdentifier extracts connection IDs from outgoing ServerHello records // and associates them with the associated connection. // NOTE: a ServerHello should always be the first record in a datagram if // multiple are present, so we avoid iterating through all packets if the first // is not a ServerHello. func cidConnIdentifier() func([]byte) (string, bool) { //nolint:cyclop return func(packet []byte) (string, bool) { pkts, err := recordlayer.UnpackDatagram(packet) if err != nil || len(pkts) < 1 { return "", false } var h recordlayer.Header if hErr := h.Unmarshal(pkts[0]); hErr != nil { return "", false } if h.ContentType != protocol.ContentTypeHandshake { return "", false } var hh handshake.Header var sh handshake.MessageServerHello for _, pkt := range pkts { if hhErr := hh.Unmarshal(pkt[recordlayer.FixedHeaderSize:]); hhErr != nil { continue } if err = sh.Unmarshal(pkt[recordlayer.FixedHeaderSize+handshake.HeaderLength:]); err == nil { break } } if err != nil { return "", false } for _, ext := range sh.Extensions { if e, ok := ext.(*extension.ConnectionID); ok { return string(e.CID), true } } return "", false } }