diff --git a/client.go b/client.go index 1bcfec31..9818db6a 100644 --- a/client.go +++ b/client.go @@ -21,27 +21,42 @@ type Client struct { clientConnectionSignature []byte sessionKey []byte sequenceIDIn *Counter - sequenceIDOut *Counter + sequenceIDOutManager *SequenceIDManager pid uint32 stationURLs []*StationURL connectionID uint32 pingCheckTimer *time.Timer pingKickTimer *time.Timer connected bool + incomingPacketManager *PacketManager + outgoingResendManager *PacketResendManager } // Reset resets the Client to default values func (client *Client) Reset() error { + server := client.Server() + client.sequenceIDIn = NewCounter(0) - client.sequenceIDOut = NewCounter(0) + client.sequenceIDOutManager = NewSequenceIDManager() // TODO - Pass the server into here to get data for multiple substreams and the unreliable starting ID + client.incomingPacketManager = NewPacketManager() + + if client.outgoingResendManager != nil { + // * PacketResendManager makes use of time.Ticker structs. + // * These create new channels and goroutines which won't + // * close even if the objects are deleted. To free up + // * resources, time.Ticker MUST be stopped before reassigning + client.outgoingResendManager.Clear() + } + + client.outgoingResendManager = NewPacketResendManager(server.resendTimeout, server.resendTimeoutIncrement, server.resendMaxIterations) - client.UpdateAccessKey(client.Server().AccessKey()) + client.UpdateAccessKey(server.AccessKey()) err := client.UpdateRC4Key([]byte("CD&ML")) if err != nil { return fmt.Errorf("Failed to update client RC4 key. %s", err.Error()) } - if client.Server().PRUDPVersion() == 0 { + if server.PRUDPVersion() == 0 { client.SetServerConnectionSignature(make([]byte, 4)) client.SetClientConnectionSignature(make([]byte, 4)) } else { @@ -149,9 +164,9 @@ func (client *Client) ClientConnectionSignature() []byte { return client.clientConnectionSignature } -// SequenceIDCounterOut returns the clients packet SequenceID counter for out-going packets -func (client *Client) SequenceIDCounterOut() *Counter { - return client.sequenceIDOut +// SequenceIDOutManager returns the clients packet SequenceID manager for out-going packets +func (client *Client) SequenceIDOutManager() *SequenceIDManager { + return client.sequenceIDOutManager } // SequenceIDCounterIn returns the clients packet SequenceID counter for incoming packets @@ -233,6 +248,18 @@ func (client *Client) StartTimeoutTimer() { }) } +// StopTimeoutTimer stops the packet timeout timer +func (client *Client) StopTimeoutTimer() { + //Stop the kick timer + if client.pingKickTimer != nil { + client.pingKickTimer.Stop() + } + //and the check timer + if client.pingCheckTimer != nil { + client.pingCheckTimer.Stop() + } +} + // NewClient returns a new PRUDP client func NewClient(address *net.UDPAddr, server *Server) *Client { client := &Client{ diff --git a/go.mod b/go.mod index 46b78fb2..911601da 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.18 require ( github.com/PretendoNetwork/plogger-go v1.0.4 github.com/superwhiskers/crunch/v3 v3.5.7 + golang.org/x/exp v0.0.0-20230905200255-921286631fa9 golang.org/x/mod v0.12.0 ) @@ -13,6 +14,6 @@ require ( github.com/jwalton/go-supportscolor v1.2.0 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.19 // indirect - golang.org/x/sys v0.11.0 // indirect + golang.org/x/sys v0.12.0 // indirect golang.org/x/term v0.11.0 // indirect ) diff --git a/go.sum b/go.sum index 2c9138aa..ef5e34c6 100644 --- a/go.sum +++ b/go.sum @@ -2,7 +2,7 @@ github.com/PretendoNetwork/plogger-go v1.0.4 h1:PF7xHw9eDRHH+RsAP9tmAE7fG0N0p6H4 github.com/PretendoNetwork/plogger-go v1.0.4/go.mod h1:7kD6M4vPq1JL4LTuPg6kuB1OvUBOwQOtAvTaUwMbwvU= github.com/fatih/color v1.15.0 h1:kOqh6YHBtK8aywxGerMG2Eq3H6Qgoqeo13Bk2Mv/nBs= github.com/fatih/color v1.15.0/go.mod h1:0h5ZqXfHYED7Bhv2ZJamyIOUej9KtShiJESRwBDUSsw= -github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ= +github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= github.com/jwalton/go-supportscolor v1.2.0 h1:g6Ha4u7Vm3LIsQ5wmeBpS4gazu0UP1DRDE8y6bre4H8= github.com/jwalton/go-supportscolor v1.2.0/go.mod h1:hFVUAZV2cWg+WFFC4v8pT2X/S2qUUBYMioBD9AINXGs= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= @@ -12,14 +12,16 @@ github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APP github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/superwhiskers/crunch/v3 v3.5.7 h1:N9RLxaR65C36i26BUIpzPXGy2f6pQ7wisu2bawbKNqg= github.com/superwhiskers/crunch/v3 v3.5.7/go.mod h1:4ub2EKgF1MAhTjoOCTU4b9uLMsAweHEa89aRrfAypXA= +golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= +golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k= golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210220050731-9a76102bfb43/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM= -golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20210220032956-6a3ed077a48d/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.11.0 h1:F9tnn/DA/Im8nCwm+fX+1/eBwi4qFjRT++MhtVC4ZX0= golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= diff --git a/mutex_map.go b/mutex_map.go new file mode 100644 index 00000000..4a91d345 --- /dev/null +++ b/mutex_map.go @@ -0,0 +1,76 @@ +package nex + +import "sync" + +// MutexMap implements a map type with go routine safe accessors through mutex locks. Embeds sync.RWMutex +type MutexMap[K comparable, V any] struct { + *sync.RWMutex + real map[K]V +} + +// Set sets a key to a given value +func (m *MutexMap[K, V]) Set(key K, value V) { + m.Lock() + defer m.Unlock() + + m.real[key] = value +} + +// Get returns the given key value and a bool if found +func (m *MutexMap[K, V]) Get(key K) (V, bool) { + m.RLock() + defer m.RUnlock() + + value, ok := m.real[key] + + return value, ok +} + +// Delete removes a key from the internal map +func (m *MutexMap[K, V]) Delete(key K) { + m.Lock() + defer m.Unlock() + + delete(m.real, key) +} + +// Size returns the length of the internal map +func (m *MutexMap[K, V]) Size() int { + m.RLock() + defer m.RUnlock() + + return len(m.real) +} + +// Each runs a callback function for every item in the map +// The map should not be modified inside the callback function +func (m *MutexMap[K, V]) Each(callback func(key K, value V)) { + m.RLock() + defer m.RUnlock() + + for key, value := range m.real { + callback(key, value) + } +} + +// Clear removes all items from the `real` map +// Accepts an optional callback function ran for every item before it is deleted +func (m *MutexMap[K, V]) Clear(callback func(key K, value V)) { + m.Lock() + defer m.Unlock() + + for key, value := range m.real { + if callback != nil { + callback(key, value) + } + delete(m.real, key) + } +} + +// NewMutexMap returns a new instance of MutexMap with the provided key/value types +func NewMutexMap[K comparable, V any]() *MutexMap[K, V] { + return &MutexMap[K, V]{ + RWMutex: &sync.RWMutex{}, + real: make(map[K]V), + } +} diff --git a/packet_interface.go b/packet_interface.go index 1f522fe6..71aaf5b9 100644 --- a/packet_interface.go +++ b/packet_interface.go @@ -2,6 +2,7 @@ package nex // PacketInterface implements all Packet methods type PacketInterface interface { + Data() []byte Sender() *Client SetVersion(version uint8) Version() uint8 @@ -28,6 +29,7 @@ type PacketInterface interface { FragmentID() uint8 SetPayload(payload []byte) Payload() []byte + DecryptPayload() error RMCRequest() RMCRequest Bytes() []byte } diff --git a/packet_manager.go b/packet_manager.go new file mode 100644 index 00000000..ab49806c --- /dev/null +++ b/packet_manager.go @@ -0,0 +1,43 @@ +package nex + +// PacketManager implements an API for pushing/popping packets in the correct order +type PacketManager struct { + currentSequenceID *Counter + packets []PacketInterface +} + +// Next gets the next packet in the sequence. Returns nil if the next packet has not been sent yet +func (p *PacketManager) Next() PacketInterface { + var packet PacketInterface + + for i := 0; i < len(p.packets); i++ { + if p.currentSequenceID.Value() == uint32(p.packets[i].SequenceID()) { + packet = p.packets[i] + p.RemoveByIndex(i) + p.currentSequenceID.Increment() + break + } + } + + return packet +} + +// Push adds a packet to the pool to choose from in Next +func (p *PacketManager) Push(packet PacketInterface) { + p.packets = append(p.packets, packet) +} + +// RemoveByIndex removes a packet from the pool using it's index in the slice +func (p *PacketManager) RemoveByIndex(i int) { + // * https://stackoverflow.com/a/37335777 + p.packets[i] = p.packets[len(p.packets)-1] + p.packets = p.packets[:len(p.packets)-1] +} + +// NewPacketManager returns a new PacketManager +func NewPacketManager() *PacketManager { + return &PacketManager{ + currentSequenceID: NewCounter(0), + packets: make([]PacketInterface, 0), + } +} diff --git a/packet_resend_manager.go b/packet_resend_manager.go new file mode 100644 index 00000000..47b7028c --- /dev/null +++ b/packet_resend_manager.go @@ -0,0 +1,114 @@ +package nex + +import ( + "time" +) + +// PendingPacket represents a packet which the server has sent but not received an ACK for +// it handles it's own retransmission on a per-packet timer +type PendingPacket struct { + ticking bool + ticker *time.Ticker + quit chan struct{} + packet PacketInterface + iterations *Counter + timeout time.Duration + timeoutInc time.Duration + maxIterations int +} + +// BeginTimeoutTimer starts the pending packets timeout timer until it is either stopped or maxIterations is hit +func (p *PendingPacket) BeginTimeoutTimer() { + go func() { + for { + select { + case <-p.quit: + return + case <-p.ticker.C: + client := p.packet.Sender() + server := client.Server() + + if int(p.iterations.Increment()) > p.maxIterations { + // * Max iterations hit. Assume client is dead + server.TimeoutKick(client) + p.StopTimeoutTimer() + return + } else { + if p.timeoutInc != 0 { + p.timeout += p.timeoutInc + p.ticker.Reset(p.timeout) + } + + // * Resend the packet + server.SendRaw(client.Address(), p.packet.Bytes()) + } + } + } + }() +} + +// StopTimeoutTimer stops the packet retransmission timer +func (p *PendingPacket) StopTimeoutTimer() { + if p.ticking { + close(p.quit) + p.ticker.Stop() + p.ticking = false + } +} + +// NewPendingPacket returns a new PendingPacket +func NewPendingPacket(packet PacketInterface, timeoutTime time.Duration, timeoutIncrement time.Duration, maxIterations int) *PendingPacket { + p := &PendingPacket{ + ticking: true, + ticker: time.NewTicker(timeoutTime), + quit: make(chan struct{}), + packet: packet, + iterations: NewCounter(0), + timeout: timeoutTime, + timeoutInc: timeoutIncrement, + maxIterations: maxIterations, + } + + return p +} + +// PacketResendManager manages all the pending packets sent the client waiting to be ACKed +type PacketResendManager struct { + pending *MutexMap[uint16, *PendingPacket] + timeoutTime time.Duration + timeoutInc time.Duration + maxIterations int +} + +// Add creates a PendingPacket, adds it to the pool, and begins it's timeout timer +func (p *PacketResendManager) Add(packet PacketInterface) { + cached := NewPendingPacket(packet, p.timeoutTime, p.timeoutInc, p.maxIterations) + p.pending.Set(packet.SequenceID(), cached) + + cached.BeginTimeoutTimer() +} + +// Remove removes a packet from pool and stops it's timer +func (p *PacketResendManager) Remove(sequenceID uint16) { + if cached, ok := p.pending.Get(sequenceID); ok { + cached.StopTimeoutTimer() + p.pending.Delete(sequenceID) + } +} + +// Clear removes all packets from pool and stops their timers +func (p *PacketResendManager) Clear() { + p.pending.Clear(func(key uint16, value *PendingPacket) { + value.StopTimeoutTimer() + }) +} + +// NewPacketResendManager returns a new PacketResendManager +func NewPacketResendManager(timeoutTime time.Duration, timeoutIncrement time.Duration, maxIterations int) *PacketResendManager { + return &PacketResendManager{ + pending: NewMutexMap[uint16, *PendingPacket](), + timeoutTime: timeoutTime, + timeoutInc: timeoutIncrement, + maxIterations: maxIterations, + } +} diff --git a/packet_v0.go b/packet_v0.go index 351b7fde..3161d459 100644 --- a/packet_v0.go +++ b/packet_v0.go @@ -117,19 +117,6 @@ func (packet *PacketV0) Decode() error { payloadCrypted := stream.ReadBytesNext(int64(payloadSize)) packet.SetPayload(payloadCrypted) - - if packet.Type() == DataPacket { - ciphered := make([]byte, payloadSize) - packet.Sender().Decipher().XORKeyStream(ciphered, payloadCrypted) - - request := NewRMCRequest() - err := request.FromBytes(ciphered) - if err != nil { - return errors.New("[PRUDPv0] Error parsing RMC request: " + err.Error()) - } - - packet.rmcRequest = request - } } if len(packet.Data()[stream.ByteOffset():]) < int(checksumSize) { @@ -154,30 +141,27 @@ func (packet *PacketV0) Decode() error { return nil } -// Bytes encodes the packet and returns a byte array -func (packet *PacketV0) Bytes() []byte { - if packet.Type() == DataPacket { +// DecryptPayload decrypts the packets payload and sets the RMC request data +func (packet *PacketV0) DecryptPayload() error { + if packet.Type() == DataPacket && !packet.HasFlag(FlagAck) { + ciphered := make([]byte, len(packet.Payload())) - if packet.HasFlag(FlagAck) { - packet.SetPayload([]byte{}) - } else { - payload := packet.Payload() + packet.Sender().Decipher().XORKeyStream(ciphered, packet.Payload()) - if payload != nil || len(payload) > 0 { - payloadSize := len(payload) - - encrypted := make([]byte, payloadSize) - packet.Sender().Cipher().XORKeyStream(encrypted, payload) - - packet.SetPayload(encrypted) - } + request := NewRMCRequest() + err := request.FromBytes(ciphered) + if err != nil { + return fmt.Errorf("Failed to read PRUDPv0 RMC request. %s", err.Error()) } - if !packet.HasFlag(FlagHasSize) { - packet.AddFlag(FlagHasSize) - } + packet.rmcRequest = request } + return nil +} + +// Bytes encodes the packet and returns a byte array +func (packet *PacketV0) Bytes() []byte { var typeFlags uint16 = packet.Type() | packet.Flags()<<4 stream := NewStreamOut(packet.Sender().Server()) diff --git a/packet_v1.go b/packet_v1.go index df099962..4fd505b6 100644 --- a/packet_v1.go +++ b/packet_v1.go @@ -199,20 +199,6 @@ func (packet *PacketV1) Decode() error { payloadCrypted := stream.ReadBytesNext(int64(payloadSize)) packet.SetPayload(payloadCrypted) - - if packet.Type() == DataPacket && !packet.HasFlag(FlagMultiAck) { - ciphered := make([]byte, payloadSize) - - packet.Sender().Decipher().XORKeyStream(ciphered, payloadCrypted) - - request := NewRMCRequest() - err := request.FromBytes(ciphered) - if err != nil { - return fmt.Errorf("Failed to read PRUDPv1 RMC request. %s", err.Error()) - } - - packet.rmcRequest = request - } } calculatedSignature := packet.calculateSignature(packet.Data()[2:14], packet.Sender().ServerConnectionSignature(), options, packet.Payload()) @@ -224,27 +210,27 @@ func (packet *PacketV1) Decode() error { return nil } -// Bytes encodes the packet and returns a byte array -func (packet *PacketV1) Bytes() []byte { - if packet.Type() == DataPacket { - if !packet.HasFlag(FlagMultiAck) { - payload := packet.Payload() - - if payload != nil || len(payload) > 0 { - payloadSize := len(payload) +// DecryptPayload decrypts the packets payload and sets the RMC request data +func (packet *PacketV1) DecryptPayload() error { + if packet.Type() == DataPacket && !packet.HasFlag(FlagMultiAck) { + ciphered := make([]byte, len(packet.Payload())) - encrypted := make([]byte, payloadSize) - packet.Sender().Cipher().XORKeyStream(encrypted, payload) + packet.Sender().Decipher().XORKeyStream(ciphered, packet.Payload()) - packet.SetPayload(encrypted) - } + request := NewRMCRequest() + err := request.FromBytes(ciphered) + if err != nil { + return fmt.Errorf("Failed to read PRUDPv1 RMC request. %s", err.Error()) } - if !packet.HasFlag(FlagHasSize) { - packet.AddFlag(FlagHasSize) - } + packet.rmcRequest = request } + return nil +} + +// Bytes encodes the packet and returns a byte array +func (packet *PacketV1) Bytes() []byte { var typeFlags uint16 = packet.Type() | packet.Flags()<<4 stream := NewStreamOut(packet.Sender().Server()) diff --git a/rmc.go b/rmc.go index 2d0ce7cd..fb752517 100644 --- a/rmc.go +++ b/rmc.go @@ -5,6 +5,8 @@ import ( "fmt" ) +// TODO - We should probably combine RMCRequest and RMCResponse in a single RMCMessage for simpler packet payload setting/reading that supports both request and response payloads + // RMCRequest represets a RMC request type RMCRequest struct { protocolID uint8 diff --git a/sequence_id_manager.go b/sequence_id_manager.go new file mode 100644 index 00000000..ca80bde6 --- /dev/null +++ b/sequence_id_manager.go @@ -0,0 +1,29 @@ +package nex + +// SequenceIDManager implements an API for managing the sequence IDs of different packet streams on a client +type SequenceIDManager struct { + reliableCounter *Counter // TODO - NEX only uses one reliable stream, but Rendezvous supports many. This needs to be a slice! + pingCounter *Counter + // TODO - Unreliable packets for Rendezvous +} + +// Next gets the next sequence ID for the packet. Returns 0 for an unsupported packet +func (s *SequenceIDManager) Next(packet PacketInterface) uint32 { + if packet.HasFlag(FlagReliable) { + return s.reliableCounter.Increment() + } + + if packet.Type() == PingPacket { + return s.pingCounter.Increment() + } + + return 0 +} + +// NewSequenceIDManager returns a new SequenceIDManager +func NewSequenceIDManager() *SequenceIDManager { + return &SequenceIDManager{ + reliableCounter: NewCounter(0), + pingCounter: NewCounter(0), + } +} diff --git a/server.go b/server.go index 5a635b45..3d3e138e 100644 --- a/server.go +++ b/server.go @@ -11,19 +11,20 @@ package nex import ( "crypto/rand" "fmt" + mrand "math/rand" "net" "net/http" "runtime" "strconv" - "sync" "time" + + "golang.org/x/exp/slices" ) // Server represents a PRUDP server type Server struct { socket *net.UDPConn - clients map[string]*Client - clientMutex *sync.RWMutex + clients *MutexMap[string, *Client] genericEventHandles map[string][]func(PacketInterface) prudpV0EventHandles map[string][]func(*PacketV0) prudpV1EventHandles map[string][]func(*PacketV1) @@ -35,7 +36,9 @@ type Server struct { prudpProtocolMinorVersion int supportedFunctions int fragmentSize int16 - resendTimeout float32 + resendTimeout time.Duration + resendTimeoutIncrement time.Duration + resendMaxIterations int pingTimeout int kerberosPassword string kerberosKeySize int @@ -50,6 +53,8 @@ type Server struct { messagingProtocolVersion *NEXVersion utilityProtocolVersion *NEXVersion natTraversalProtocolVersion *NEXVersion + emuSendPacketDropPercent int + emuRecvPacketDropPercent int } // Listen starts a NEX server on a given address @@ -103,18 +108,19 @@ func (server *Server) handleSocketMessage() error { return err } + if server.shouldDropPacket(true) { + // Emulate packet drop for debugging + return nil + } + discriminator := addr.String() - server.clientMutex.RLock() - client, ok := server.clients[discriminator] - server.clientMutex.RUnlock() + client, ok := server.clients.Get(discriminator) if !ok { client = NewClient(addr, server) - server.clientMutex.Lock() - server.clients[discriminator] = client - server.clientMutex.Unlock() + server.clients.Set(discriminator, client) } data := buffer[0:length] @@ -136,6 +142,8 @@ func (server *Server) handleSocketMessage() error { client.IncreasePingTimeoutTime(server.PingTimeout()) if packet.HasFlag(FlagAck) || packet.HasFlag(FlagMultiAck) { + // TODO - Should this return an error? + server.handleAcknowledgement(packet) return nil } @@ -150,6 +158,49 @@ func (server *Server) handleSocketMessage() error { } } + switch packet.Type() { + case PingPacket: + err := server.processPacket(packet) + if err != nil { + // TODO - Should this return the error too? + logger.Error(err.Error()) + return nil + } + default: + // TODO - Make a better API in client to access incomingPacketManager? + client.incomingPacketManager.Push(packet) + + // TODO - Make this API smarter. Only track missing packets and not all packets? + // * Keep processing packets so long as the next one is in the pool, + // * this way if several packets came in out of order they all get + // * processed at once the moment the correct next packet comes in + for next := client.incomingPacketManager.Next(); next != nil; { + err := server.processPacket(next) + if err != nil { + // TODO - Should this return the error too? + logger.Error(err.Error()) + return nil + } + + next = client.incomingPacketManager.Next() + } + } + + return nil +} + +func (server *Server) processPacket(packet PacketInterface) error { + err := packet.DecryptPayload() + if err != nil { + return err + } + + client := packet.Sender() + + if packet.HasFlag(FlagAck) || packet.HasFlag(FlagMultiAck) { + return nil + } + switch packet.Type() { case SynPacket: // * PID should always be 0 when a fresh connection is made @@ -166,6 +217,10 @@ func (server *Server) handleSocketMessage() error { client.SetConnected(true) client.StartTimeoutTimer() + // TODO - Don't make this part suck ass? + // * Manually incrementing because the original manager gets destroyed in the reset + // * but we need to still track the SYN packet was sent + client.incomingPacketManager.currentSequenceID.Increment() server.Emit("Syn", packet) case ConnectPacket: packet.Sender().SetClientConnectionSignature(packet.ConnectionSignature()) @@ -186,6 +241,50 @@ func (server *Server) handleSocketMessage() error { return nil } +func (server *Server) handleAcknowledgement(packet PacketInterface) { + if packet.Version() == 0 || (packet.HasFlag(FlagAck) && !packet.HasFlag(FlagMultiAck)) { + packet.Sender().outgoingResendManager.Remove(packet.SequenceID()) + } else { + // TODO - Validate the aggregate packet is valid and can be processed + sequenceIDs := make([]uint16, 0) + stream := NewStreamIn(packet.Payload(), server) + var baseSequenceID uint16 + + // TODO - We should probably handle these errors lol + if server.PRUDPProtocolMinorVersion() >= 2 { + _, _ = stream.ReadUInt8() // * Substream ID. NEX always uses 0 + additionalIDsCount, _ := stream.ReadUInt8() + baseSequenceID, _ = stream.ReadUInt16LE() + + for i := 0; i < int(additionalIDsCount); i++ { + additionalID, _ := stream.ReadUInt16LE() + sequenceIDs = append(sequenceIDs, additionalID) + } + } else { + baseSequenceID = packet.SequenceID() + + for remaining := stream.Remaining(); remaining != 0; { + additionalID, _ := stream.ReadUInt16LE() + sequenceIDs = append(sequenceIDs, additionalID) + remaining = stream.Remaining() + } + } + + // * MutexMap.Each locks the mutex, can't remove while reading + // * Have to just loop again + packet.Sender().outgoingResendManager.pending.Each(func(sequenceID uint16, pending *PendingPacket) { + if sequenceID <= baseSequenceID && !slices.Contains(sequenceIDs, sequenceID) { + sequenceIDs = append(sequenceIDs, sequenceID) + } + }) + + // * Actually remove the packets from the pool + for _, sequenceID := range sequenceIDs { + packet.Sender().outgoingResendManager.Remove(sequenceID) + } + } +} + // HPPListen starts a NEX HPP server on a given address func (server *Server) HPPListen(address string) { hppHandler := func(w http.ResponseWriter, req *http.Request) { @@ -364,9 +463,7 @@ func (server *Server) Emit(event string, packet interface{}) { func (server *Server) ClientConnected(client *Client) bool { discriminator := client.Address().String() - server.clientMutex.RLock() - _, connected := server.clients[discriminator] - server.clientMutex.RUnlock() + _, connected := server.clients.Get(discriminator) return connected } @@ -400,9 +497,8 @@ func (server *Server) TimeoutKick(client *Client) { client.SetConnected(false) discriminator := client.Address().String() - server.clientMutex.Lock() - delete(server.clients, discriminator) - server.clientMutex.Unlock() + client.outgoingResendManager.Clear() + server.clients.Delete(discriminator) } // GracefulKick removes an active client from the server @@ -434,20 +530,21 @@ func (server *Server) GracefulKick(client *Client) { server.Emit("Kick", packet) client.SetConnected(false) + client.StopTimeoutTimer() discriminator := client.Address().String() - server.clientMutex.Lock() - delete(server.clients, discriminator) - server.clientMutex.Unlock() + client.outgoingResendManager.Clear() + server.clients.Delete(discriminator) } // GracefulKickAll removes all clients from the server func (server *Server) GracefulKickAll() { // * https://stackoverflow.com/a/40456170 - server.clientMutex.RLock() - defer server.clientMutex.RUnlock() - for _, client := range server.clients { - server.clientMutex.RUnlock() + server.clients.RLock() + defer server.clients.RUnlock() + // TODO - MAKE A BETTER API FOR RANGING OVER THIS DATA INSIDE MutexMap! + for _, client := range server.clients.real { + server.clients.RUnlock() var packet PacketInterface var err error @@ -462,7 +559,7 @@ func (server *Server) GracefulKickAll() { if err != nil { // TODO - Should this return the error too? logger.Error(err.Error()) - server.clientMutex.RLock() + server.clients.RLock() continue } @@ -478,11 +575,10 @@ func (server *Server) GracefulKickAll() { client.SetConnected(false) discriminator := client.Address().String() - server.clientMutex.Lock() - delete(server.clients, discriminator) - server.clientMutex.Unlock() + client.outgoingResendManager.Clear() + server.clients.Delete(discriminator) - server.clientMutex.RLock() + server.clients.RLock() } } @@ -785,6 +881,21 @@ func (server *Server) SetFragmentSize(fragmentSize int16) { server.fragmentSize = fragmentSize } +// SetResendTimeout sets the time that a packet should wait before resending to the client +func (server *Server) SetResendTimeout(resendTimeout time.Duration) { + server.resendTimeout = resendTimeout +} + +// SetResendTimeoutIncrement sets how much to increment the resendTimeout every time a packet is resent to the client +func (server *Server) SetResendTimeoutIncrement(resendTimeoutIncrement time.Duration) { + server.resendTimeoutIncrement = resendTimeoutIncrement +} + +// SetResendMaxIterations sets the max number of times a packet can try to resend before assuming the client is dead +func (server *Server) SetResendMaxIterations(resendMaxIterations int) { + server.resendMaxIterations = resendMaxIterations +} + // ConnectionIDCounter gets the server connection ID counter func (server *Server) ConnectionIDCounter() *Counter { return server.connectionIDCounter @@ -793,16 +904,17 @@ func (server *Server) ConnectionIDCounter() *Counter { // FindClientFromPID finds a client by their PID func (server *Server) FindClientFromPID(pid uint32) *Client { // * https://stackoverflow.com/a/40456170 - server.clientMutex.RLock() - for _, client := range server.clients { - server.clientMutex.RUnlock() + // TODO - MAKE A BETTER API FOR RANGING OVER THIS DATA INSIDE MutexMap! + server.clients.RLock() + for _, client := range server.clients.real { + server.clients.RUnlock() if client.pid == pid { return client } - server.clientMutex.RLock() + server.clients.RLock() } - server.clientMutex.RUnlock() + server.clients.RUnlock() return nil } @@ -810,16 +922,17 @@ func (server *Server) FindClientFromPID(pid uint32) *Client { // FindClientFromConnectionID finds a client by their Connection ID func (server *Server) FindClientFromConnectionID(rvcid uint32) *Client { // * https://stackoverflow.com/a/40456170 - server.clientMutex.RLock() - for _, client := range server.clients { - server.clientMutex.RUnlock() + // TODO - MAKE A BETTER API FOR RANGING OVER THIS DATA INSIDE MutexMap! + server.clients.RLock() + for _, client := range server.clients.real { + server.clients.RUnlock() if client.connectionID == rvcid { return client } - server.clientMutex.RLock() + server.clients.RLock() } - server.clientMutex.RUnlock() + server.clients.RUnlock() return nil } @@ -865,20 +978,51 @@ func (server *Server) Send(packet PacketInterface) { // SendFragment sends a packet fragment to the client func (server *Server) SendFragment(packet PacketInterface, fragmentID uint8) { - data := packet.Payload() client := packet.Sender() + payload := packet.Payload() + + if packet.Type() == DataPacket { + if packet.Version() == 0 && packet.HasFlag(FlagAck) { + // * v0 ACK payloads empty, ensure this + payload = []byte{} + } else if !packet.HasFlag(FlagMultiAck) { + if payload != nil || len(payload) > 0 { + payloadSize := len(payload) + + encrypted := make([]byte, payloadSize) + packet.Sender().Cipher().XORKeyStream(encrypted, payload) + + payload = encrypted + } + } + + // * Only add the HAS_SIZE flag if the payload exists + if !packet.HasFlag(FlagHasSize) && len(payload) > 0 { + packet.AddFlag(FlagHasSize) + } + } packet.SetFragmentID(fragmentID) - packet.SetPayload(data) - packet.SetSequenceID(uint16(client.SequenceIDCounterOut().Increment())) + + packet.SetPayload(payload) + packet.SetSequenceID(uint16(client.SequenceIDOutManager().Next(packet))) encodedPacket := packet.Bytes() server.SendRaw(client.Address(), encodedPacket) + + if (packet.HasFlag(FlagReliable) || packet.Type() == SynPacket) && packet.HasFlag(FlagNeedsAck) { + packet.Sender().outgoingResendManager.Add(packet) + } } // SendRaw writes raw packet data to the client socket func (server *Server) SendRaw(conn *net.UDPAddr, data []byte) { + if server.shouldDropPacket(false) { + // Emulate packet drop for debugging + return + } + _, err := server.Socket().WriteToUDP(data, conn) if err != nil { // TODO - Should this return the error too? @@ -886,23 +1030,43 @@ func (server *Server) SendRaw(conn *net.UDPAddr, data []byte) { } } +func (server *Server) shouldDropPacket(isRecv bool) bool { + if isRecv { + return server.emuRecvPacketDropPercent != 0 && mrand.Intn(100) < server.emuRecvPacketDropPercent + } else { + return server.emuSendPacketDropPercent != 0 && mrand.Intn(100) < server.emuSendPacketDropPercent + } +} + +// SetEmulatedPacketDropPercent sets the percentage of emulated sent and received dropped packets +func (server *Server) SetEmulatedPacketDropPercent(forRecv bool, percent int) { + if forRecv { + server.emuRecvPacketDropPercent = percent + } else { + server.emuSendPacketDropPercent = percent + } +} + // NewServer returns a new NEX server func NewServer() *Server { server := &Server{ - genericEventHandles: make(map[string][]func(PacketInterface)), - prudpV0EventHandles: make(map[string][]func(*PacketV0)), - prudpV1EventHandles: make(map[string][]func(*PacketV1)), - hppEventHandles: make(map[string][]func(*HPPPacket)), - hppClientResponses: make(map[*Client](chan []byte)), - clients: make(map[string]*Client), - clientMutex: &sync.RWMutex{}, - prudpVersion: 1, - fragmentSize: 1300, - resendTimeout: 1.5, - pingTimeout: 5, - kerberosKeySize: 32, - kerberosKeyDerivation: 0, - connectionIDCounter: NewCounter(10), + genericEventHandles: make(map[string][]func(PacketInterface)), + prudpV0EventHandles: make(map[string][]func(*PacketV0)), + prudpV1EventHandles: make(map[string][]func(*PacketV1)), + hppEventHandles: make(map[string][]func(*HPPPacket)), + hppClientResponses: make(map[*Client](chan []byte)), + clients: NewMutexMap[string, *Client](), + prudpVersion: 1, + fragmentSize: 1300, + resendTimeout: time.Second, + resendTimeoutIncrement: 0, + resendMaxIterations: 5, + pingTimeout: 5, + kerberosKeySize: 32, + kerberosKeyDerivation: 0, + connectionIDCounter: NewCounter(10), + emuSendPacketDropPercent: 0, + emuRecvPacketDropPercent: 0, } server.SetDefaultNEXVersion(NewNEXVersion(0, 0, 0)) diff --git a/stream_in.go b/stream_in.go index 046bdafe..795d0e87 100644 --- a/stream_in.go +++ b/stream_in.go @@ -15,9 +15,20 @@ type StreamIn struct { Server *Server } +// Remaining returns the amount of data left to be read in the buffer +func (stream *StreamIn) Remaining() int { + return len(stream.Bytes()[stream.ByteOffset():]) +} + +// ReadRemaining reads all the data left to be read in the buffer +func (stream *StreamIn) ReadRemaining() []byte { + // TODO - Should we do a bounds check here? Or just allow empty slices? + return stream.ReadBytesNext(int64(stream.Remaining())) +} + // ReadBool reads a bool func (stream *StreamIn) ReadBool() (bool, error) { - if len(stream.Bytes()[stream.ByteOffset():]) < 1 { + if stream.Remaining() < 1 { return false, errors.New("Not enough data to read bool") } @@ -26,7 +37,7 @@ func (stream *StreamIn) ReadBool() (bool, error) { // ReadUInt8 reads a uint8 func (stream *StreamIn) ReadUInt8() (uint8, error) { - if len(stream.Bytes()[stream.ByteOffset():]) < 1 { + if stream.Remaining() < 1 { return 0, errors.New("Not enough data to read uint8") } @@ -35,7 +46,7 @@ func (stream *StreamIn) ReadUInt8() (uint8, error) { // ReadInt8 reads a uint8 func (stream *StreamIn) ReadInt8() (int8, error) { - if len(stream.Bytes()[stream.ByteOffset():]) < 1 { + if stream.Remaining() < 1 { return 0, errors.New("Not enough data to read int8") } @@ -44,7 +55,7 @@ func (stream *StreamIn) ReadInt8() (int8, error) { // ReadUInt16LE reads a Little-Endian encoded uint16 func (stream *StreamIn) ReadUInt16LE() (uint16, error) { - if len(stream.Bytes()[stream.ByteOffset():]) < 2 { + if stream.Remaining() < 2 { return 0, errors.New("Not enough data to read uint16") } @@ -53,7 +64,7 @@ func (stream *StreamIn) ReadUInt16LE() (uint16, error) { // ReadUInt16BE reads a Big-Endian encoded uint16 func (stream *StreamIn) ReadUInt16BE() (uint16, error) { - if len(stream.Bytes()[stream.ByteOffset():]) < 2 { + if stream.Remaining() < 2 { return 0, errors.New("Not enough data to read uint16") } @@ -62,7 +73,7 @@ func (stream *StreamIn) ReadUInt16BE() (uint16, error) { // ReadInt16LE reads a Little-Endian encoded int16 func (stream *StreamIn) ReadInt16LE() (int16, error) { - if len(stream.Bytes()[stream.ByteOffset():]) < 2 { + if stream.Remaining() < 2 { return 0, errors.New("Not enough data to read int16") } @@ -71,7 +82,7 @@ func (stream *StreamIn) ReadInt16LE() (int16, error) { // ReadInt16BE reads a Big-Endian encoded int16 func (stream *StreamIn) ReadInt16BE() (int16, error) { - if len(stream.Bytes()[stream.ByteOffset():]) < 2 { + if stream.Remaining() < 2 { return 0, errors.New("Not enough data to read int16") } @@ -80,7 +91,7 @@ func (stream *StreamIn) ReadInt16BE() (int16, error) { // ReadUInt32LE reads a Little-Endian encoded uint32 func (stream *StreamIn) ReadUInt32LE() (uint32, error) { - if len(stream.Bytes()[stream.ByteOffset():]) < 4 { + if stream.Remaining() < 4 { return 0, errors.New("Not enough data to read uint32") } @@ -89,7 +100,7 @@ func (stream *StreamIn) ReadUInt32LE() (uint32, error) { // ReadUInt32BE reads a Big-Endian encoded uint32 func (stream *StreamIn) ReadUInt32BE() (uint32, error) { - if len(stream.Bytes()[stream.ByteOffset():]) < 4 { + if stream.Remaining() < 4 { return 0, errors.New("Not enough data to read uint32") } @@ -98,7 +109,7 @@ func (stream *StreamIn) ReadUInt32BE() (uint32, error) { // ReadInt32LE reads a Little-Endian encoded int32 func (stream *StreamIn) ReadInt32LE() (int32, error) { - if len(stream.Bytes()[stream.ByteOffset():]) < 4 { + if stream.Remaining() < 4 { return 0, errors.New("Not enough data to read int32") } @@ -107,7 +118,7 @@ func (stream *StreamIn) ReadInt32LE() (int32, error) { // ReadInt32BE reads a Big-Endian encoded int32 func (stream *StreamIn) ReadInt32BE() (int32, error) { - if len(stream.Bytes()[stream.ByteOffset():]) < 4 { + if stream.Remaining() < 4 { return 0, errors.New("Not enough data to read int32") } @@ -116,7 +127,7 @@ func (stream *StreamIn) ReadInt32BE() (int32, error) { // ReadUInt64LE reads a Little-Endian encoded uint64 func (stream *StreamIn) ReadUInt64LE() (uint64, error) { - if len(stream.Bytes()[stream.ByteOffset():]) < 8 { + if stream.Remaining() < 8 { return 0, errors.New("Not enough data to read uint64") } @@ -125,7 +136,7 @@ func (stream *StreamIn) ReadUInt64LE() (uint64, error) { // ReadUInt64BE reads a Big-Endian encoded uint64 func (stream *StreamIn) ReadUInt64BE() (uint64, error) { - if len(stream.Bytes()[stream.ByteOffset():]) < 8 { + if stream.Remaining() < 8 { return 0, errors.New("Not enough data to read uint64") } @@ -134,7 +145,7 @@ func (stream *StreamIn) ReadUInt64BE() (uint64, error) { // ReadInt64LE reads a Little-Endian encoded int64 func (stream *StreamIn) ReadInt64LE() (int64, error) { - if len(stream.Bytes()[stream.ByteOffset():]) < 8 { + if stream.Remaining() < 8 { return 0, errors.New("Not enough data to read int64") } @@ -143,7 +154,7 @@ func (stream *StreamIn) ReadInt64LE() (int64, error) { // ReadInt64BE reads a Big-Endian encoded int64 func (stream *StreamIn) ReadInt64BE() (int64, error) { - if len(stream.Bytes()[stream.ByteOffset():]) < 8 { + if stream.Remaining() < 8 { return 0, errors.New("Not enough data to read int64") } @@ -152,7 +163,7 @@ func (stream *StreamIn) ReadInt64BE() (int64, error) { // ReadFloat32LE reads a Little-Endian encoded float32 func (stream *StreamIn) ReadFloat32LE() (float32, error) { - if len(stream.Bytes()[stream.ByteOffset():]) < 4 { + if stream.Remaining() < 4 { return 0, errors.New("Not enough data to read float32") } @@ -161,7 +172,7 @@ func (stream *StreamIn) ReadFloat32LE() (float32, error) { // ReadFloat32BE reads a Big-Endian encoded float32 func (stream *StreamIn) ReadFloat32BE() (float32, error) { - if len(stream.Bytes()[stream.ByteOffset():]) < 4 { + if stream.Remaining() < 4 { return 0, errors.New("Not enough data to read float32") } @@ -170,7 +181,7 @@ func (stream *StreamIn) ReadFloat32BE() (float32, error) { // ReadFloat64LE reads a Little-Endian encoded float64 func (stream *StreamIn) ReadFloat64LE() (float64, error) { - if len(stream.Bytes()[stream.ByteOffset():]) < 8 { + if stream.Remaining() < 8 { return 0, errors.New("Not enough data to read float64") } @@ -179,7 +190,7 @@ func (stream *StreamIn) ReadFloat64LE() (float64, error) { // ReadFloat64BE reads a Big-Endian encoded float64 func (stream *StreamIn) ReadFloat64BE() (float64, error) { - if len(stream.Bytes()[stream.ByteOffset():]) < 8 { + if stream.Remaining() < 8 { return 0, errors.New("Not enough data to read float64") } @@ -193,7 +204,7 @@ func (stream *StreamIn) ReadString() (string, error) { return "", fmt.Errorf("Failed to read NEX string length. %s", err.Error()) } - if len(stream.Bytes()[stream.ByteOffset():]) < int(length) { + if stream.Remaining() < int(length) { return "", errors.New("NEX string length longer than data size") } @@ -210,7 +221,7 @@ func (stream *StreamIn) ReadBuffer() ([]byte, error) { return []byte{}, fmt.Errorf("Failed to read NEX buffer length. %s", err.Error()) } - if len(stream.Bytes()[stream.ByteOffset():]) < int(length) { + if stream.Remaining() < int(length) { return []byte{}, errors.New("NEX buffer length longer than data size") } @@ -226,7 +237,7 @@ func (stream *StreamIn) ReadQBuffer() ([]byte, error) { return []byte{}, fmt.Errorf("Failed to read NEX qBuffer length. %s", err.Error()) } - if len(stream.Bytes()[stream.ByteOffset():]) < int(length) { + if stream.Remaining() < int(length) { return []byte{}, errors.New("NEX qBuffer length longer than data size") } @@ -257,7 +268,7 @@ func (stream *StreamIn) ReadStructure(structure StructureInterface) (StructureIn return nil, fmt.Errorf("Failed to read NEX Structure content length. %s", err.Error()) } - if len(stream.Bytes()[stream.ByteOffset():]) < int(structureLength) { + if stream.Remaining() < int(structureLength) { return nil, errors.New("NEX Structure content length longer than data size") } @@ -373,7 +384,7 @@ func (stream *StreamIn) ReadListUInt8() ([]uint8, error) { return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) } - if len(stream.Bytes()[stream.ByteOffset():]) < int(length) { + if stream.Remaining() < int(length) { return nil, errors.New("NEX List length longer than data size") } @@ -398,7 +409,7 @@ func (stream *StreamIn) ReadListInt8() ([]int8, error) { return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) } - if len(stream.Bytes()[stream.ByteOffset():]) < int(length) { + if stream.Remaining() < int(length) { return nil, errors.New("NEX List length longer than data size") } @@ -423,7 +434,7 @@ func (stream *StreamIn) ReadListUInt16LE() ([]uint16, error) { return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) } - if len(stream.Bytes()[stream.ByteOffset():]) < int(length*2) { + if stream.Remaining() < int(length*2) { return nil, errors.New("NEX List length longer than data size") } @@ -448,7 +459,7 @@ func (stream *StreamIn) ReadListUInt16BE() ([]uint16, error) { return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) } - if len(stream.Bytes()[stream.ByteOffset():]) < int(length*2) { + if stream.Remaining() < int(length*2) { return nil, errors.New("NEX List length longer than data size") } @@ -473,7 +484,7 @@ func (stream *StreamIn) ReadListInt16LE() ([]int16, error) { return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) } - if len(stream.Bytes()[stream.ByteOffset():]) < int(length*2) { + if stream.Remaining() < int(length*2) { return nil, errors.New("NEX List length longer than data size") } @@ -498,7 +509,7 @@ func (stream *StreamIn) ReadListInt16BE() ([]int16, error) { return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) } - if len(stream.Bytes()[stream.ByteOffset():]) < int(length*2) { + if stream.Remaining() < int(length*2) { return nil, errors.New("NEX List length longer than data size") } @@ -523,7 +534,7 @@ func (stream *StreamIn) ReadListUInt32LE() ([]uint32, error) { return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) } - if len(stream.Bytes()[stream.ByteOffset():]) < int(length*4) { + if stream.Remaining() < int(length*4) { return nil, errors.New("NEX List length longer than data size") } @@ -548,7 +559,7 @@ func (stream *StreamIn) ReadListUInt32BE() ([]uint32, error) { return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) } - if len(stream.Bytes()[stream.ByteOffset():]) < int(length*4) { + if stream.Remaining() < int(length*4) { return nil, errors.New("NEX List length longer than data size") } @@ -573,7 +584,7 @@ func (stream *StreamIn) ReadListInt32LE() ([]int32, error) { return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) } - if len(stream.Bytes()[stream.ByteOffset():]) < int(length*4) { + if stream.Remaining() < int(length*4) { return nil, errors.New("NEX List length longer than data size") } @@ -598,7 +609,7 @@ func (stream *StreamIn) ReadListInt32BE() ([]int32, error) { return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) } - if len(stream.Bytes()[stream.ByteOffset():]) < int(length*4) { + if stream.Remaining() < int(length*4) { return nil, errors.New("NEX List length longer than data size") } @@ -623,7 +634,7 @@ func (stream *StreamIn) ReadListUInt64LE() ([]uint64, error) { return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) } - if len(stream.Bytes()[stream.ByteOffset():]) < int(length*8) { + if stream.Remaining() < int(length*8) { return nil, errors.New("NEX List length longer than data size") } @@ -648,7 +659,7 @@ func (stream *StreamIn) ReadListUInt64BE() ([]uint64, error) { return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) } - if len(stream.Bytes()[stream.ByteOffset():]) < int(length*8) { + if stream.Remaining() < int(length*8) { return nil, errors.New("NEX List length longer than data size") } @@ -673,7 +684,7 @@ func (stream *StreamIn) ReadListInt64LE() ([]int64, error) { return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) } - if len(stream.Bytes()[stream.ByteOffset():]) < int(length*8) { + if stream.Remaining() < int(length*8) { return nil, errors.New("NEX List length longer than data size") } @@ -698,7 +709,7 @@ func (stream *StreamIn) ReadListInt64BE() ([]int64, error) { return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) } - if len(stream.Bytes()[stream.ByteOffset():]) < int(length*8) { + if stream.Remaining() < int(length*8) { return nil, errors.New("NEX List length longer than data size") } @@ -723,7 +734,7 @@ func (stream *StreamIn) ReadListFloat32LE() ([]float32, error) { return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) } - if len(stream.Bytes()[stream.ByteOffset():]) < int(length*4) { + if stream.Remaining() < int(length*4) { return nil, errors.New("NEX List length longer than data size") } @@ -748,7 +759,7 @@ func (stream *StreamIn) ReadListFloat32BE() ([]float32, error) { return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) } - if len(stream.Bytes()[stream.ByteOffset():]) < int(length*4) { + if stream.Remaining() < int(length*4) { return nil, errors.New("NEX List length longer than data size") } @@ -773,7 +784,7 @@ func (stream *StreamIn) ReadListFloat64LE() ([]float64, error) { return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) } - if len(stream.Bytes()[stream.ByteOffset():]) < int(length*4) { + if stream.Remaining() < int(length*4) { return nil, errors.New("NEX List length longer than data size") } @@ -798,7 +809,7 @@ func (stream *StreamIn) ReadListFloat64BE() ([]float64, error) { return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) } - if len(stream.Bytes()[stream.ByteOffset():]) < int(length*4) { + if stream.Remaining() < int(length*4) { return nil, errors.New("NEX List length longer than data size") }