From 23a403e52f880e4e4390fe105764ea9169ca26ed Mon Sep 17 00:00:00 2001 From: Ingo Oppermann Date: Wed, 28 Aug 2024 16:27:08 +0200 Subject: [PATCH] Add congestion control extension --- conn_request.go | 12 +++++++ packet/packet.go | 75 ++++++++++++++++++++++++++++++++++++++----- packet/packet_test.go | 8 +++-- 3 files changed, 85 insertions(+), 10 deletions(-) diff --git a/conn_request.go b/conn_request.go index 8057695..bc5bcd5 100644 --- a/conn_request.go +++ b/conn_request.go @@ -206,6 +206,18 @@ func newConnRequest(ln *listener, p packet.Packet) *connRequest { return nil } + + // We only support live congestion control + if cif.HasCongestionCtl && cif.CongestionCtl != "live" { + cif.HandshakeType = packet.HandshakeType(REJ_CONGESTION) + ln.log("handshake:recv:error", func() string { return "only live congestion control is supported" }) + p.MarshalCIF(cif) + ln.log("handshake:send:dump", func() string { return p.Dump() }) + ln.log("handshake:send:cif", func() string { return cif.String() }) + ln.send(p) + + return nil + } } else { cif.HandshakeType = packet.HandshakeType(REJ_ROGUE) ln.log("handshake:recv:error", func() string { return fmt.Sprintf("only HSv4 and HSv5 are supported (got HSv%d)", cif.Version) }) diff --git a/packet/packet.go b/packet/packet.go index 1ecfea4..6a72b50 100644 --- a/packet/packet.go +++ b/packet/packet.go @@ -140,7 +140,7 @@ const ( EXTTYPE_KMREQ CtrlSubType = 3 EXTTYPE_KMRSP CtrlSubType = 4 EXTTYPE_SID CtrlSubType = 5 - EXTTYPE_CONGESTION CtrlSubType = 6 // unimplemented + EXTTYPE_CONGESTION CtrlSubType = 6 EXTTYPE_FILTER CtrlSubType = 7 // unimplemented EXTTYPE_GROUP CtrlSubType = 8 // unimplemented ) @@ -484,7 +484,7 @@ type CIFHandshake struct { Version uint32 // A base protocol version number. Currently used values are 4 and 5. Values greater than 5 are reserved for future use. EncryptionField uint16 // Block cipher family and key size. The values of this field are described in Table 2. The default value is AES-128. - ExtensionField uint16 // This field is message specific extension related to Handshake Type field. The value MUST be set to 0 except for the following cases. (1) If the handshake control packet is the INDUCTION message, this field is sent back by the Listener. (2) In the case of a CONCLUSION message, this field value should contain a combination of Extension Type values. For more details, see Section 4.3.1. + ExtensionField uint16 // This field is a message specific extension related to Handshake Type field. The value MUST be set to 0 except for the following cases. (1) If the handshake control packet is the INDUCTION message, this field is sent back by the Listener. (2) In the case of a CONCLUSION message, this field value should contain a combination of Extension Type values. For more details, see Section 4.3.1. InitialPacketSequenceNumber circular.Number // The sequence number of the very first data packet to be sent. MaxTransmissionUnitSize uint32 // This value is typically set to 1500, which is the default Maximum Transmission Unit (MTU) size for Ethernet, but can be less. MaxFlowWindowSize uint32 // The value of this field is the maximum number of data packets allowed to be "in flight" (i.e. the number of sent packets for which an ACK control packet has not yet been received). @@ -493,9 +493,10 @@ type CIFHandshake struct { SynCookie uint32 // Randomized value for processing a handshake. The value of this field is specified by the handshake message type. See Section 4.3. PeerIP srtnet.IP // IPv4 or IPv6 address of the packet's sender. The value consists of four 32-bit fields. In the case of IPv4 addresses, fields 2, 3 and 4 are filled with zeroes. - HasHS bool - HasKM bool - HasSID bool + HasHS bool + HasKM bool + HasSID bool + HasCongestionCtl bool // 3.2.1.1. Handshake Extension Message SRTHS *CIFHandshakeExtension @@ -505,6 +506,9 @@ type CIFHandshake struct { // 3.2.1.3. Stream ID Extension Message StreamId string + + // ??? Congestion Control Extension message (handshake.md #### Congestion controller) + CongestionCtl string } func (c CIFHandshake) String() string { @@ -537,6 +541,12 @@ func (c CIFHandshake) String() string { fmt.Fprintf(&b, " streamId : %s\n", c.StreamId) fmt.Fprintf(&b, "--- /SIDExt ---\n") } + + if c.HasCongestionCtl { + fmt.Fprintf(&b, "--- CongestionExt ---\n") + fmt.Fprintf(&b, " congestion : %s\n", c.CongestionCtl) + fmt.Fprintf(&b, "--- /CongestionExt ---\n") + } } fmt.Fprintf(&b, "--- /handshake ---") @@ -599,7 +609,7 @@ func (c *CIFHandshake) Unmarshal(data []byte) error { if extensionType == EXTTYPE_HSREQ || extensionType == EXTTYPE_HSRSP { // 3.2.1.1. Handshake Extension Message if extensionLength != 12 || len(pivot) < extensionLength { - return fmt.Errorf("invalid extension length") + return fmt.Errorf("invalid extension length of %d bytes (%s)", extensionLength, extensionType.String()) } c.HasHS = true @@ -612,7 +622,7 @@ func (c *CIFHandshake) Unmarshal(data []byte) error { } else if extensionType == EXTTYPE_KMREQ || extensionType == EXTTYPE_KMRSP { // 3.2.1.2. Key Material Extension Message if len(pivot) < extensionLength { - return fmt.Errorf("invalid extension length") + return fmt.Errorf("invalid extension length of %d bytes (%s)", extensionLength, extensionType.String()) } c.HasKM = true @@ -638,7 +648,7 @@ func (c *CIFHandshake) Unmarshal(data []byte) error { } else if extensionType == EXTTYPE_SID { // 3.2.1.3. Stream ID Extension Message if extensionLength > 512 || len(pivot) < extensionLength { - return fmt.Errorf("invalid extension length") + return fmt.Errorf("invalid extension length of %d bytes (%s)", extensionLength, extensionType.String()) } c.HasSID = true @@ -653,6 +663,24 @@ func (c *CIFHandshake) Unmarshal(data []byte) error { } c.StreamId = strings.TrimRight(b.String(), "\x00") + } else if extensionType == EXTTYPE_CONGESTION { + // ??? Congestion Control Extension message (handshake.md #### Congestion controller) + if extensionLength > 4 || len(pivot) < extensionLength { + return fmt.Errorf("invalid extension length of %d bytes (%s)", extensionLength, extensionType.String()) + } + + c.HasCongestionCtl = true + + var b strings.Builder + + for i := 0; i < extensionLength; i += 4 { + b.WriteByte(pivot[i+3]) + b.WriteByte(pivot[i+2]) + b.WriteByte(pivot[i+1]) + b.WriteByte(pivot[i+0]) + } + + c.CongestionCtl = strings.TrimRight(b.String(), "\x00") } else { return fmt.Errorf("unimplemented extension (%d)", extensionType) } @@ -695,6 +723,10 @@ func (c *CIFHandshake) Marshal(w io.Writer) { if c.HasSID { c.ExtensionField = c.ExtensionField | 4 } + + if c.HasCongestionCtl { + c.ExtensionField = c.ExtensionField | 4 + } } else { c.EncryptionField = 0 c.ExtensionField = 2 @@ -773,6 +805,33 @@ func (c *CIFHandshake) Marshal(w io.Writer) { w.Write(buffer[:4]) } } + + if c.HasCongestionCtl && c.CongestionCtl != "live" { + congestion := bytes.NewBufferString(c.CongestionCtl) + + missing := (4 - congestion.Len()%4) + if missing < 4 { + for i := 0; i < missing; i++ { + congestion.WriteByte(0) + } + } + + binary.BigEndian.PutUint16(buffer[0:], EXTTYPE_CONGESTION.Value()) + binary.BigEndian.PutUint16(buffer[2:], uint16(congestion.Len()/4)) + + w.Write(buffer[:4]) + + b := congestion.Bytes() + + for i := 0; i < len(b); i += 4 { + buffer[0] = b[i+3] + buffer[1] = b[i+2] + buffer[2] = b[i+1] + buffer[3] = b[i+0] + + w.Write(buffer[:4]) + } + } } // 3.2.1.1.1. Handshake Extension Message Flags diff --git a/packet/packet_test.go b/packet/packet_test.go index 0db64ed..9d494d2 100644 --- a/packet/packet_test.go +++ b/packet/packet_test.go @@ -137,6 +137,7 @@ func TestHandshakeV4(t *testing.T) { HasHS: false, HasKM: false, HasSID: false, + HasCongestionCtl: false, } var buf bytes.Buffer @@ -174,6 +175,7 @@ func TestHandshakeV5(t *testing.T) { HasHS: true, HasKM: true, HasSID: true, + HasCongestionCtl: true, SRTHS: &CIFHandshakeExtension{ SRTVersion: 0x010402, SRTFlags: CIFHandshakeExtensionFlags{ @@ -207,7 +209,8 @@ func TestHandshakeV5(t *testing.T) { Salt: []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10}, Wrap: []byte{0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20}, }, - StreamId: "/live/stream.foobar", + StreamId: "/live/stream.foobar", + CongestionCtl: "foob", } var buf bytes.Buffer @@ -216,7 +219,7 @@ func TestHandshakeV5(t *testing.T) { data := hex.EncodeToString(buf.Bytes()) - require.Equal(t, "00000005000200070000002a000005dc00000064ffffffff00274921001234560100007f00000000000000000000000000020003000104020000003f006400640004000e122029010000000002000200000004040102030405060708090a0b0c0d0e0f10f0f1f2f3f4f5f6f71112131415161718191a1b1c1d1e1f200005000576696c2f74732f656d6165726f6f662e00726162", data) + require.Equal(t, "00000005000200070000002a000005dc00000064ffffffff00274921001234560100007f00000000000000000000000000020003000104020000003f006400640004000e122029010000000002000200000004040102030405060708090a0b0c0d0e0f10f0f1f2f3f4f5f6f71112131415161718191a1b1c1d1e1f200005000576696c2f74732f656d6165726f6f662e0072616200060001626f6f66", data) cif2 := &CIFHandshake{} @@ -245,6 +248,7 @@ func TestHandshakeString(t *testing.T) { HasHS: true, HasKM: false, HasSID: true, + HasCongestionCtl: false, SRTHS: &CIFHandshakeExtension{ SRTVersion: 0x010402, SRTFlags: CIFHandshakeExtensionFlags{