Skip to content

Commit

Permalink
use pool for unmarshal packets
Browse files Browse the repository at this point in the history
  • Loading branch information
tyohan committed Sep 19, 2024
1 parent 26bcda5 commit e5596f0
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 21 deletions.
138 changes: 117 additions & 21 deletions packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@

package rtcp

import (
"bytes"
"sync"
)

// Packet represents an RTCP packet, a protocol used for out-of-band statistics and control information for an RTP session
type Packet interface {
// DestinationSSRC returns an array of SSRC values that this packet refers to.
Expand All @@ -11,16 +16,38 @@ type Packet interface {
Marshal() ([]byte, error)
Unmarshal(rawPacket []byte) error
MarshalSize() int

// Release returns the packet to its pool
Release()
}

var (
senderReportPool = sync.Pool{New: func() interface{} { return new(SenderReport) }}
receiverReportPool = sync.Pool{New: func() interface{} { return new(ReceiverReport) }}
sourceDescriptionPool = sync.Pool{New: func() interface{} { return new(SourceDescription) }}
goodbyePool = sync.Pool{New: func() interface{} { return new(Goodbye) }}
transportLayerNackPool = sync.Pool{New: func() interface{} { return new(TransportLayerNack) }}
rapidResynchronizationRequestPool = sync.Pool{New: func() interface{} { return new(RapidResynchronizationRequest) }}
transportLayerCCPool = sync.Pool{New: func() interface{} { return new(TransportLayerCC) }}
ccFeedbackReportPool = sync.Pool{New: func() interface{} { return new(CCFeedbackReport) }}
pictureLossIndicationPool = sync.Pool{New: func() interface{} { return new(PictureLossIndication) }}
sliceLossIndicationPool = sync.Pool{New: func() interface{} { return new(SliceLossIndication) }}
receiverEstimatedMaximumBitratePool = sync.Pool{New: func() interface{} { return new(ReceiverEstimatedMaximumBitrate) }}
fullIntraRequestPool = sync.Pool{New: func() interface{} { return new(FullIntraRequest) }}
extendedReportPool = sync.Pool{New: func() interface{} { return new(ExtendedReport) }}
applicationDefinedPool = sync.Pool{New: func() interface{} { return new(ApplicationDefined) }}
rawPacketPool = sync.Pool{New: func() interface{} { return new(RawPacket) }}
)

// Unmarshal takes an entire udp datagram (which may consist of multiple RTCP packets) and
// returns the unmarshaled packets it contains.
//
// If this is a reduced-size RTCP packet a feedback packet (Goodbye, SliceLossIndication, etc)
// will be returned. Otherwise, the underlying type of the returned packet will be
// CompoundPacket.
func Unmarshal(rawData []byte) ([]Packet, error) {
var packets []Packet
estimatedPackets := len(rawData) / 100 // Estimate based on average packet size
packets := make([]Packet, 0, estimatedPackets)
for len(rawData) != 0 {
p, processed, err := unmarshal(rawData)
if err != nil {
Expand All @@ -43,15 +70,16 @@ func Unmarshal(rawData []byte) ([]Packet, error) {

// Marshal takes an array of Packets and serializes them to a single buffer
func Marshal(packets []Packet) ([]byte, error) {
out := make([]byte, 0)
var buf bytes.Buffer
for _, p := range packets {
data, err := p.Marshal()
if err != nil {
return nil, err
}
out = append(out, data...)
buf.Write(data)
p.Release()
}
return out, nil
return buf.Bytes(), nil
}

// unmarshal is a factory which pulls the first RTCP packet from a bytestream,
Expand All @@ -72,56 +100,124 @@ func unmarshal(rawData []byte) (packet Packet, bytesprocessed int, err error) {

switch h.Type {
case TypeSenderReport:
packet = new(SenderReport)
packet = senderReportPool.Get().(*SenderReport)

case TypeReceiverReport:
packet = new(ReceiverReport)
packet = receiverReportPool.Get().(*ReceiverReport)

case TypeSourceDescription:
packet = new(SourceDescription)
packet = sourceDescriptionPool.Get().(*SourceDescription)

case TypeGoodbye:
packet = new(Goodbye)
packet = goodbyePool.Get().(*Goodbye)

case TypeTransportSpecificFeedback:
switch h.Count {
case FormatTLN:
packet = new(TransportLayerNack)
packet = transportLayerNackPool.Get().(*TransportLayerNack)
case FormatRRR:
packet = new(RapidResynchronizationRequest)
packet = rapidResynchronizationRequestPool.Get().(*RapidResynchronizationRequest)
case FormatTCC:
packet = new(TransportLayerCC)
packet = transportLayerCCPool.Get().(*TransportLayerCC)
case FormatCCFB:
packet = new(CCFeedbackReport)
packet = ccFeedbackReportPool.Get().(*CCFeedbackReport)
default:
packet = new(RawPacket)
packet = rawPacketPool.Get().(*RawPacket)
}

case TypePayloadSpecificFeedback:
switch h.Count {
case FormatPLI:
packet = new(PictureLossIndication)
packet = pictureLossIndicationPool.Get().(*PictureLossIndication)
case FormatSLI:
packet = new(SliceLossIndication)
packet = sliceLossIndicationPool.Get().(*SliceLossIndication)
case FormatREMB:
packet = new(ReceiverEstimatedMaximumBitrate)
packet = receiverEstimatedMaximumBitratePool.Get().(*ReceiverEstimatedMaximumBitrate)
case FormatFIR:
packet = new(FullIntraRequest)
packet = fullIntraRequestPool.Get().(*FullIntraRequest)
default:
packet = new(RawPacket)
packet = rawPacketPool.Get().(*RawPacket)
}

case TypeExtendedReport:
packet = new(ExtendedReport)
packet = extendedReportPool.Get().(*ExtendedReport)

case TypeApplicationDefined:
packet = new(ApplicationDefined)
packet = applicationDefinedPool.Get().(*ApplicationDefined)

default:
packet = new(RawPacket)
packet = rawPacketPool.Get().(*RawPacket)
}

err = packet.Unmarshal(inPacket)

return packet, bytesprocessed, err
}

// Implement the Release method for each concrete packet type
func (p *SenderReport) Release() {
senderReportPool.Put(p)
}

func (p *ReceiverReport) Release() {
receiverReportPool.Put(p)
}

func (p *SourceDescription) Release() {
sourceDescriptionPool.Put(p)
}

func (p *Goodbye) Release() {
goodbyePool.Put(p)
}

func (p *TransportLayerNack) Release() {
transportLayerNackPool.Put(p)
}

func (p *RapidResynchronizationRequest) Release() {
rapidResynchronizationRequestPool.Put(p)
}

func (p *TransportLayerCC) Release() {
transportLayerCCPool.Put(p)
}

func (p *CCFeedbackReport) Release() {
ccFeedbackReportPool.Put(p)
}

func (p *PictureLossIndication) Release() {
pictureLossIndicationPool.Put(p)
}

func (p *SliceLossIndication) Release() {
sliceLossIndicationPool.Put(p)
}

func (p *ReceiverEstimatedMaximumBitrate) Release() {
receiverEstimatedMaximumBitratePool.Put(p)
}

func (p *FullIntraRequest) Release() {
fullIntraRequestPool.Put(p)
}

func (p *ExtendedReport) Release() {
extendedReportPool.Put(p)
}

func (p *ApplicationDefined) Release() {
applicationDefinedPool.Put(p)
}

func (p *CompoundPacket) Release() {
// CompoundPacket is a slice of pointers, so we need to release each one
for _, packet := range *p {
packet.Release()
}
}

func (p *RawPacket) Release() {
rawPacketPool.Put(p)
}
10 changes: 10 additions & 0 deletions packet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,16 @@ func realPacket() []byte {
}
}

func BenchmarkUnmarshal(b *testing.B) {
packetData := realPacket()
for i := 0; i < b.N; i++ {
_, err := Unmarshal(packetData)
if err != nil {
b.Fatalf("Error unmarshalling packets: %s", err)
}
}
}

func TestUnmarshal(t *testing.T) {
packet, err := Unmarshal(realPacket())
if err != nil {
Expand Down

0 comments on commit e5596f0

Please sign in to comment.