@@ -54,6 +54,11 @@ type addrPkt struct {
54
54
data []byte
55
55
}
56
56
57
+ type recvHandshakeState struct {
58
+ done chan struct {}
59
+ isRetransmit bool
60
+ }
61
+
57
62
// Conn represents a DTLS connection
58
63
type Conn struct {
59
64
lock sync.RWMutex // Internal lock (must not be public)
@@ -82,7 +87,7 @@ type Conn struct {
82
87
log logging.LeveledLogger
83
88
84
89
reading chan struct {}
85
- handshakeRecv chan chan struct {}
90
+ handshakeRecv chan recvHandshakeState
86
91
cancelHandshaker func ()
87
92
cancelHandshakeReader func ()
88
93
@@ -137,7 +142,7 @@ func createConn(nextConn net.PacketConn, rAddr net.Addr, config *Config, isClien
137
142
writeDeadline : deadline .New (),
138
143
139
144
reading : make (chan struct {}, 1 ),
140
- handshakeRecv : make (chan chan struct {} ),
145
+ handshakeRecv : make (chan recvHandshakeState ),
141
146
closed : closer .NewCloser (),
142
147
cancelHandshaker : func () {},
143
148
@@ -704,9 +709,9 @@ func (c *Conn) readAndBuffer(ctx context.Context) error {
704
709
return err
705
710
}
706
711
707
- var hasHandshake bool
712
+ var hasHandshake , isRetransmit bool
708
713
for _ , p := range pkts {
709
- hs , alert , err := c .handleIncomingPacket (ctx , p , rAddr , true )
714
+ hs , rtx , alert , err := c .handleIncomingPacket (ctx , p , rAddr , true )
710
715
if alert != nil {
711
716
if alertErr := c .notify (ctx , alert .Level , alert .Description ); alertErr != nil {
712
717
if err == nil {
@@ -725,14 +730,20 @@ func (c *Conn) readAndBuffer(ctx context.Context) error {
725
730
if hs {
726
731
hasHandshake = true
727
732
}
733
+ if rtx {
734
+ isRetransmit = true
735
+ }
728
736
}
729
737
if hasHandshake {
730
- done := make (chan struct {})
738
+ s := recvHandshakeState {
739
+ done : make (chan struct {}),
740
+ isRetransmit : isRetransmit ,
741
+ }
731
742
select {
732
- case c .handshakeRecv <- done :
743
+ case c .handshakeRecv <- s :
733
744
// If the other party may retransmit the flight,
734
745
// we should respond even if it not a new message.
735
- <- done
746
+ <- s . done
736
747
case <- c .fsm .Done ():
737
748
}
738
749
}
@@ -744,7 +755,7 @@ func (c *Conn) handleQueuedPackets(ctx context.Context) error {
744
755
c .encryptedPackets = nil
745
756
746
757
for _ , p := range pkts {
747
- _ , alert , err := c .handleIncomingPacket (ctx , p .data , p .rAddr , false ) // don't re-enqueue
758
+ _ , _ , alert , err := c .handleIncomingPacket (ctx , p .data , p .rAddr , false ) // don't re-enqueue
748
759
if alert != nil {
749
760
if alertErr := c .notify (ctx , alert .Level , alert .Description ); alertErr != nil {
750
761
if err == nil {
@@ -771,7 +782,7 @@ func (c *Conn) enqueueEncryptedPackets(packet addrPkt) bool {
771
782
return false
772
783
}
773
784
774
- func (c * Conn ) handleIncomingPacket (ctx context.Context , buf []byte , rAddr net.Addr , enqueue bool ) (bool , * alert.Alert , error ) { //nolint:gocognit
785
+ func (c * Conn ) handleIncomingPacket (ctx context.Context , buf []byte , rAddr net.Addr , enqueue bool ) (bool , bool , * alert.Alert , error ) { //nolint:gocognit
775
786
h := & recordlayer.Header {}
776
787
// Set connection ID size so that records of content type tls12_cid will
777
788
// be parsed correctly.
@@ -782,7 +793,7 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A
782
793
// Decode error must be silently discarded
783
794
// [RFC6347 Section-4.1.2.7]
784
795
c .log .Debugf ("discarded broken packet: %v" , err )
785
- return false , nil , nil
796
+ return false , false , nil , nil
786
797
}
787
798
// Validate epoch
788
799
remoteEpoch := c .state .getRemoteEpoch ()
@@ -791,14 +802,14 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A
791
802
c .log .Debugf ("discarded future packet (epoch: %d, seq: %d)" ,
792
803
h .Epoch , h .SequenceNumber ,
793
804
)
794
- return false , nil , nil
805
+ return false , false , nil , nil
795
806
}
796
807
if enqueue {
797
808
if ok := c .enqueueEncryptedPackets (addrPkt {rAddr , buf }); ok {
798
809
c .log .Debug ("received packet of next epoch, queuing packet" )
799
810
}
800
811
}
801
- return false , nil , nil
812
+ return false , false , nil , nil
802
813
}
803
814
804
815
// Anti-replay protection
@@ -812,7 +823,7 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A
812
823
c .log .Debugf ("discarded duplicated packet (epoch: %d, seq: %d)" ,
813
824
h .Epoch , h .SequenceNumber ,
814
825
)
815
- return false , nil , nil
826
+ return false , false , nil , nil
816
827
}
817
828
818
829
// originalCID indicates whether the original record had content type
@@ -827,14 +838,14 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A
827
838
c .log .Debug ("handshake not finished, queuing packet" )
828
839
}
829
840
}
830
- return false , nil , nil
841
+ return false , false , nil , nil
831
842
}
832
843
833
844
// If a connection identifier had been negotiated and encryption is
834
845
// enabled, the connection identifier MUST be sent.
835
846
if len (c .state .getLocalConnectionID ()) > 0 && h .ContentType != protocol .ContentTypeConnectionID {
836
847
c .log .Debug ("discarded packet missing connection ID after value negotiated" )
837
- return false , nil , nil
848
+ return false , false , nil , nil
838
849
}
839
850
840
851
var err error
@@ -845,7 +856,7 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A
845
856
buf , err = c .state .cipherSuite .Decrypt (hdr , buf )
846
857
if err != nil {
847
858
c .log .Debugf ("%s: decrypt failed: %s" , srvCliStr (c .state .isClient ), err )
848
- return false , nil , nil
859
+ return false , false , nil , nil
849
860
}
850
861
// If this is a connection ID record, make it look like a normal record for
851
862
// further processing.
@@ -854,7 +865,7 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A
854
865
ip := & recordlayer.InnerPlaintext {}
855
866
if err := ip .Unmarshal (buf [h .Size ():]); err != nil { //nolint:govet
856
867
c .log .Debugf ("unpacking inner plaintext failed: %s" , err )
857
- return false , nil , nil
868
+ return false , false , nil , nil
858
869
}
859
870
unpacked := & recordlayer.Header {
860
871
ContentType : ip .RealType ,
@@ -866,26 +877,27 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A
866
877
buf , err = unpacked .Marshal ()
867
878
if err != nil {
868
879
c .log .Debugf ("converting CID record to inner plaintext failed: %s" , err )
869
- return false , nil , nil
880
+ return false , false , nil , nil
870
881
}
871
882
buf = append (buf , ip .Content ... )
872
883
}
873
884
874
885
// If connection ID does not match discard the packet.
875
886
if ! bytes .Equal (c .state .getLocalConnectionID (), h .ConnectionID ) {
876
887
c .log .Debug ("unexpected connection ID" )
877
- return false , nil , nil
888
+ return false , false , nil , nil
878
889
}
879
890
}
880
891
881
- isHandshake , err := c .fragmentBuffer .push (append ([]byte {}, buf ... ))
892
+ isHandshake , isRetransmit , err := c .fragmentBuffer .push (append ([]byte {}, buf ... ))
882
893
if err != nil {
883
894
// Decode error must be silently discarded
884
895
// [RFC6347 Section-4.1.2.7]
885
896
c .log .Debugf ("defragment failed: %s" , err )
886
- return false , nil , nil
897
+ return false , false , nil , nil
887
898
} else if isHandshake {
888
899
markPacketAsValid ()
900
+
889
901
for out , epoch := c .fragmentBuffer .pop (); out != nil ; out , epoch = c .fragmentBuffer .pop () {
890
902
header := & handshake.Header {}
891
903
if err := header .Unmarshal (out ); err != nil {
@@ -895,12 +907,12 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A
895
907
c .handshakeCache .push (out , epoch , header .MessageSequence , header .Type , ! c .state .isClient )
896
908
}
897
909
898
- return true , nil , nil
910
+ return true , isRetransmit , nil , nil
899
911
}
900
912
901
913
r := & recordlayer.RecordLayer {}
902
914
if err := r .Unmarshal (buf ); err != nil {
903
- return false , & alert.Alert {Level : alert .Fatal , Description : alert .DecodeError }, err
915
+ return false , false , & alert.Alert {Level : alert .Fatal , Description : alert .DecodeError }, err
904
916
}
905
917
906
918
isLatestSeqNum := false
@@ -913,15 +925,15 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A
913
925
a = & alert.Alert {Level : alert .Warning , Description : alert .CloseNotify }
914
926
}
915
927
_ = markPacketAsValid ()
916
- return false , a , & alertError {content }
928
+ return false , false , a , & alertError {content }
917
929
case * protocol.ChangeCipherSpec :
918
930
if c .state .cipherSuite == nil || ! c .state .cipherSuite .IsInitialized () {
919
931
if enqueue {
920
932
if ok := c .enqueueEncryptedPackets (addrPkt {rAddr , buf }); ok {
921
933
c .log .Debugf ("CipherSuite not initialized, queuing packet" )
922
934
}
923
935
}
924
- return false , nil , nil
936
+ return false , false , nil , nil
925
937
}
926
938
927
939
newRemoteEpoch := h .Epoch + 1
@@ -933,7 +945,7 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A
933
945
}
934
946
case * protocol.ApplicationData :
935
947
if h .Epoch == 0 {
936
- return false , & alert.Alert {Level : alert .Fatal , Description : alert .UnexpectedMessage }, errApplicationDataEpochZero
948
+ return false , false , & alert.Alert {Level : alert .Fatal , Description : alert .UnexpectedMessage }, errApplicationDataEpochZero
937
949
}
938
950
939
951
isLatestSeqNum = markPacketAsValid ()
@@ -945,7 +957,7 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A
945
957
}
946
958
947
959
default :
948
- return false , & alert.Alert {Level : alert .Fatal , Description : alert .UnexpectedMessage }, fmt .Errorf ("%w: %d" , errUnhandledContextType , content .ContentType ())
960
+ return false , false , & alert.Alert {Level : alert .Fatal , Description : alert .UnexpectedMessage }, fmt .Errorf ("%w: %d" , errUnhandledContextType , content .ContentType ())
949
961
}
950
962
951
963
// Any valid connection ID record is a candidate for updating the remote
@@ -959,10 +971,10 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A
959
971
}
960
972
}
961
973
962
- return false , nil , nil
974
+ return false , false , nil , nil
963
975
}
964
976
965
- func (c * Conn ) recvHandshake () <- chan chan struct {} {
977
+ func (c * Conn ) recvHandshake () <- chan recvHandshakeState {
966
978
return c .handshakeRecv
967
979
}
968
980
0 commit comments