From f9f4d3bca43baf5494e3ea9c9e369aac1a576229 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Fri, 10 Nov 2023 01:28:25 -0500 Subject: [PATCH 001/178] total rewrite --- client.go | 278 ------ client_interface.go | 11 + compression.go | 27 - counter.go | 29 +- go.mod | 2 +- hpp_packet.go | 119 --- hpp_server.go | 9 + init.go | 4 - kerberos.go | 237 ++--- library_version.go | 78 ++ md5.go | 11 - mutex_map.go | 11 +- nex_version.go | 78 -- packet.go | 171 ---- packet_interface.go | 35 +- packet_manager.go | 43 - packet_resend_manager.go | 113 --- packet_v0.go | 326 ------- packet_v1.go | 389 -------- prudp_client.go | 180 ++++ prudp_packet.go | 164 ++++ packet_flags.go => prudp_packet_flags.go | 6 +- prudp_packet_interface.go | 40 + packet_types.go => prudp_packet_types.go | 2 +- prudp_packet_v0.go | 341 +++++++ prudp_packet_v1.go | 351 +++++++ prudp_server.go | 697 ++++++++++++++ reliable_packet_substream_manager.go | 94 ++ resend_scheduler.go | 124 +++ rmc.go | 307 +++--- sequence_id_manager.go | 29 - server.go | 1075 ---------------------- server_interface.go | 17 + stream_in.go | 10 +- stream_out.go | 10 +- sum.go | 13 +- test/auth.go | 154 ++++ test/generate_ticket.go | 34 + test/main.go | 14 + test/secure.go | 254 +++++ types.go | 23 +- 41 files changed, 2838 insertions(+), 3072 deletions(-) delete mode 100644 client.go create mode 100644 client_interface.go delete mode 100644 compression.go delete mode 100644 hpp_packet.go create mode 100644 hpp_server.go create mode 100644 library_version.go delete mode 100644 md5.go delete mode 100644 nex_version.go delete mode 100644 packet.go delete mode 100644 packet_manager.go delete mode 100644 packet_resend_manager.go delete mode 100644 packet_v0.go delete mode 100644 packet_v1.go create mode 100644 prudp_client.go create mode 100644 prudp_packet.go rename packet_flags.go => prudp_packet_flags.go (87%) create mode 100644 prudp_packet_interface.go rename packet_types.go => prudp_packet_types.go (93%) create mode 100644 prudp_packet_v0.go create mode 100644 prudp_packet_v1.go create mode 100644 prudp_server.go create mode 100644 reliable_packet_substream_manager.go create mode 100644 resend_scheduler.go delete mode 100644 sequence_id_manager.go delete mode 100644 server.go create mode 100644 server_interface.go create mode 100644 test/auth.go create mode 100644 test/generate_ticket.go create mode 100644 test/main.go create mode 100644 test/secure.go diff --git a/client.go b/client.go deleted file mode 100644 index 9818db6a..00000000 --- a/client.go +++ /dev/null @@ -1,278 +0,0 @@ -package nex - -import ( - "crypto/rc4" - "fmt" - "net" - "time" -) - -// Client represents a connected or non-connected PRUDP client -type Client struct { - address *net.UDPAddr - server *Server - cipher *rc4.Cipher - decipher *rc4.Cipher - prudpProtocolMinorVersion int - supportedFunctions int - signatureKey []byte - signatureBase int - serverConnectionSignature []byte - clientConnectionSignature []byte - sessionKey []byte - sequenceIDIn *Counter - sequenceIDOutManager *SequenceIDManager - pid uint32 - stationURLs []*StationURL - connectionID uint32 - pingCheckTimer *time.Timer - pingKickTimer *time.Timer - connected bool - incomingPacketManager *PacketManager - outgoingResendManager *PacketResendManager -} - -// Reset resets the Client to default values -func (client *Client) Reset() error { - server := client.Server() - - client.sequenceIDIn = NewCounter(0) - client.sequenceIDOutManager = NewSequenceIDManager() // TODO - Pass the server into here to get data for multiple substreams and the unreliable starting ID - client.incomingPacketManager = NewPacketManager() - - if client.outgoingResendManager != nil { - // * PacketResendManager makes use of time.Ticker structs. - // * These create new channels and goroutines which won't - // * close even if the objects are deleted. To free up - // * resources, time.Ticker MUST be stopped before reassigning - client.outgoingResendManager.Clear() - } - - client.outgoingResendManager = NewPacketResendManager(server.resendTimeout, server.resendTimeoutIncrement, server.resendMaxIterations) - - client.UpdateAccessKey(server.AccessKey()) - err := client.UpdateRC4Key([]byte("CD&ML")) - if err != nil { - return fmt.Errorf("Failed to update client RC4 key. %s", err.Error()) - } - - if server.PRUDPVersion() == 0 { - client.SetServerConnectionSignature(make([]byte, 4)) - client.SetClientConnectionSignature(make([]byte, 4)) - } else { - client.SetServerConnectionSignature([]byte{}) - client.SetClientConnectionSignature([]byte{}) - } - - client.SetConnected(false) - - return nil -} - -// Address returns the clients UDP address -func (client *Client) Address() *net.UDPAddr { - return client.address -} - -// Server returns the server the client is currently connected to -func (client *Client) Server() *Server { - return client.server -} - -// PRUDPProtocolMinorVersion returns the client PRUDP minor version -func (client *Client) PRUDPProtocolMinorVersion() int { - return client.prudpProtocolMinorVersion -} - -// SetPRUDPProtocolMinorVersion sets the client PRUDP minor -func (client *Client) SetPRUDPProtocolMinorVersion(prudpProtocolMinorVersion int) { - client.prudpProtocolMinorVersion = prudpProtocolMinorVersion -} - -// SupportedFunctions returns the supported PRUDP functions by the client -func (client *Client) SupportedFunctions() int { - return client.supportedFunctions -} - -// SetSupportedFunctions sets the supported PRUDP functions by the client -func (client *Client) SetSupportedFunctions(supportedFunctions int) { - client.supportedFunctions = supportedFunctions -} - -// UpdateRC4Key sets the client RC4 stream key -func (client *Client) UpdateRC4Key(key []byte) error { - cipher, err := rc4.NewCipher(key) - if err != nil { - return fmt.Errorf("Failed to create RC4 cipher. %s", err.Error()) - } - - client.cipher = cipher - - decipher, err := rc4.NewCipher(key) - if err != nil { - return fmt.Errorf("Failed to create RC4 decipher. %s", err.Error()) - } - - client.decipher = decipher - - return nil -} - -// Cipher returns the RC4 cipher stream for out-bound packets -func (client *Client) Cipher() *rc4.Cipher { - return client.cipher -} - -// Decipher returns the RC4 cipher stream for in-bound packets -func (client *Client) Decipher() *rc4.Cipher { - return client.decipher -} - -// UpdateAccessKey sets the client signature base and signature key -func (client *Client) UpdateAccessKey(accessKey string) { - client.signatureBase = sum([]byte(accessKey)) - client.signatureKey = MD5Hash([]byte(accessKey)) -} - -// SignatureBase returns the v0 checksum signature base -func (client *Client) SignatureBase() int { - return client.signatureBase -} - -// SignatureKey returns signature key -func (client *Client) SignatureKey() []byte { - return client.signatureKey -} - -// SetServerConnectionSignature sets the clients server-side connection signature -func (client *Client) SetServerConnectionSignature(serverConnectionSignature []byte) { - client.serverConnectionSignature = serverConnectionSignature -} - -// ServerConnectionSignature returns the clients server-side connection signature -func (client *Client) ServerConnectionSignature() []byte { - return client.serverConnectionSignature -} - -// SetClientConnectionSignature sets the clients client-side connection signature -func (client *Client) SetClientConnectionSignature(clientConnectionSignature []byte) { - client.clientConnectionSignature = clientConnectionSignature -} - -// ClientConnectionSignature returns the clients client-side connection signature -func (client *Client) ClientConnectionSignature() []byte { - return client.clientConnectionSignature -} - -// SequenceIDOutManager returns the clients packet SequenceID manager for out-going packets -func (client *Client) SequenceIDOutManager() *SequenceIDManager { - return client.sequenceIDOutManager -} - -// SequenceIDCounterIn returns the clients packet SequenceID counter for incoming packets -func (client *Client) SequenceIDCounterIn() *Counter { - return client.sequenceIDIn -} - -// SetSessionKey sets the clients session key -func (client *Client) SetSessionKey(sessionKey []byte) { - client.sessionKey = sessionKey -} - -// SessionKey returns the clients session key -func (client *Client) SessionKey() []byte { - return client.sessionKey -} - -// SetPID sets the clients NEX PID -func (client *Client) SetPID(pid uint32) { - client.pid = pid -} - -// PID returns the clients NEX PID -func (client *Client) PID() uint32 { - return client.pid -} - -// SetStationURLs sets the clients Station URLs -func (client *Client) SetStationURLs(stationURLs []*StationURL) { - client.stationURLs = stationURLs -} - -// AddStationURL adds the StationURL to the clients StationURLs -func (client *Client) AddStationURL(stationURL *StationURL) { - client.stationURLs = append(client.stationURLs, stationURL) -} - -// StationURLs returns the clients Station URLs -func (client *Client) StationURLs() []*StationURL { - return client.stationURLs -} - -// SetConnectionID sets the clients Connection ID -func (client *Client) SetConnectionID(connectionID uint32) { - client.connectionID = connectionID -} - -// ConnectionID returns the clients Connection ID -func (client *Client) ConnectionID() uint32 { - return client.connectionID -} - -// SetConnected sets the clients connection status -func (client *Client) SetConnected(connected bool) { - client.connected = connected -} - -// IncreasePingTimeoutTime adds a number of seconds to the check timer -func (client *Client) IncreasePingTimeoutTime(seconds int) { - //Stop the kick timer if we get something back - if client.pingKickTimer != nil { - client.pingKickTimer.Stop() - } - //and reset the check timer - if client.pingCheckTimer != nil { - client.pingCheckTimer.Reset(time.Second * time.Duration(seconds)) - } -} - -// StartTimeoutTimer begins the packet timeout timer -func (client *Client) StartTimeoutTimer() { - //if we haven't gotten a ping *from* the client, send them one to check all is well - client.pingCheckTimer = time.AfterFunc(time.Second*time.Duration(client.server.PingTimeout()), func() { - client.server.SendPing(client) - //if we *still* get nothing, they're gone - client.pingKickTimer = time.AfterFunc(time.Second*time.Duration(client.server.PingTimeout()), func() { - client.server.TimeoutKick(client) - }) - }) -} - -// StopTimeoutTimer stops the packet timeout timer -func (client *Client) StopTimeoutTimer() { - //Stop the kick timer - if client.pingKickTimer != nil { - client.pingKickTimer.Stop() - } - //and the check timer - if client.pingCheckTimer != nil { - client.pingCheckTimer.Stop() - } -} - -// NewClient returns a new PRUDP client -func NewClient(address *net.UDPAddr, server *Server) *Client { - client := &Client{ - address: address, - server: server, - } - - err := client.Reset() - if err != nil { - // TODO - Should this return the error too? - logger.Error(err.Error()) - return nil - } - - return client -} diff --git a/client_interface.go b/client_interface.go new file mode 100644 index 00000000..2f496f7b --- /dev/null +++ b/client_interface.go @@ -0,0 +1,11 @@ +package nex + +import "net" + +// ClientInterface defines all the methods a client should have regardless of server type +type ClientInterface interface { + Server() ServerInterface + Address() net.Addr + PID() uint32 + SetPID(pid uint32) +} diff --git a/compression.go b/compression.go deleted file mode 100644 index aebd2b38..00000000 --- a/compression.go +++ /dev/null @@ -1,27 +0,0 @@ -package nex - -// DummyCompression represents no compression -type DummyCompression struct{} - -// Compress returns the data as-is -func (compression *DummyCompression) Compress(data []byte) []byte { - return data -} - -// Decompress returns the data as-is -func (compression *DummyCompression) Decompress(data []byte) []byte { - return data -} - -// ZLibCompression represents ZLib compression -type ZLibCompression struct{} - -// Compress returns the data as-is (needs to be updated to return ZLib compressed data) -func (compression *ZLibCompression) Compress(data []byte) []byte { - return data -} - -// Decompress returns the data as-is (needs to be updated to return ZLib decompressed data) -func (compression *ZLibCompression) Decompress(data []byte) []byte { - return data -} diff --git a/counter.go b/counter.go index c17093e1..4f55a7a6 100644 --- a/counter.go +++ b/counter.go @@ -1,24 +1,25 @@ package nex -// Counter represents an incremental counter -type Counter struct { - value uint32 +import ( + "golang.org/x/exp/constraints" +) + +type numeric interface { + constraints.Integer | constraints.Float | constraints.Complex } -// Value returns the counters current value -func (counter Counter) Value() uint32 { - return counter.value +// Counter represents an incremental counter of a specific numeric type +type Counter[T numeric] struct { + Value T } -// Increment increments the counter by 1 and returns the value -func (counter *Counter) Increment() uint32 { - counter.value++ - return counter.Value() +// Next increments the counter by 1 and returns the new value +func (c *Counter[T]) Next() T { + c.Value++ + return c.Value } // NewCounter returns a new Counter, with a starting number -func NewCounter(start uint32) *Counter { - counter := &Counter{value: start} - - return counter +func NewCounter[T numeric](start T) *Counter[T] { + return &Counter[T]{Value: start} } diff --git a/go.mod b/go.mod index 911601da..173c8657 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/PretendoNetwork/nex-go -go 1.18 +go 1.21 require ( github.com/PretendoNetwork/plogger-go v1.0.4 diff --git a/hpp_packet.go b/hpp_packet.go deleted file mode 100644 index 7c42ef3b..00000000 --- a/hpp_packet.go +++ /dev/null @@ -1,119 +0,0 @@ -package nex - -import ( - "bytes" - "crypto/hmac" - "crypto/md5" - "encoding/hex" - "errors" -) - -// HPPPacket represents an HPP packet -type HPPPacket struct { - Packet - accessKeySignature []byte - passwordSignature []byte -} - -// SetAccessKeySignature sets the packet access key signature -func (packet *HPPPacket) SetAccessKeySignature(accessKeySignature string) { - accessKeySignatureBytes, err := hex.DecodeString(accessKeySignature) - if err != nil { - logger.Error("[HPP] Failed to convert AccessKeySignature to bytes") - } - - packet.accessKeySignature = accessKeySignatureBytes -} - -// AccessKeySignature returns the packet access key signature -func (packet *HPPPacket) AccessKeySignature() []byte { - return packet.accessKeySignature -} - -// SetPasswordSignature sets the packet password signature -func (packet *HPPPacket) SetPasswordSignature(passwordSignature string) { - passwordSignatureBytes, err := hex.DecodeString(passwordSignature) - if err != nil { - logger.Error("[HPP] Failed to convert PasswordSignature to bytes") - } - - packet.passwordSignature = passwordSignatureBytes -} - -// PasswordSignature returns the packet password signature -func (packet *HPPPacket) PasswordSignature() []byte { - return packet.passwordSignature -} - -// ValidateAccessKey checks if the access key signature is valid -func (packet *HPPPacket) ValidateAccessKey() error { - accessKey := packet.Sender().Server().AccessKey() - buffer := packet.rmcRequest.Bytes() - - accessKeyBytes, err := hex.DecodeString(accessKey) - if err != nil { - return err - } - - calculatedAccessKeySignature := packet.calculateSignature(buffer, accessKeyBytes) - if !bytes.Equal(calculatedAccessKeySignature, packet.accessKeySignature) { - return errors.New("[HPP] Access key signature is not valid") - } - - return nil -} - -// ValidatePassword checks if the password signature is valid -func (packet *HPPPacket) ValidatePassword() error { - if packet.sender.server.passwordFromPIDHandler == nil { - return errors.New("[HPP] Missing passwordFromPIDHandler!") - } - - pid := packet.Sender().PID() - buffer := packet.rmcRequest.Bytes() - - password, _ := packet.sender.server.passwordFromPIDHandler(pid) - if password == "" { - return errors.New("[HPP] PID does not exist") - } - - passwordBytes := []byte(password) - - passwordSignatureKey := DeriveKerberosKey(pid, passwordBytes) - - calculatedPasswordSignature := packet.calculateSignature(buffer, passwordSignatureKey) - if !bytes.Equal(calculatedPasswordSignature, packet.passwordSignature) { - return errors.New("[HPP] Password signature is invalid") - } - - return nil -} - -func (packet *HPPPacket) calculateSignature(buffer []byte, key []byte) []byte { - mac := hmac.New(md5.New, key) - mac.Write(buffer) - hmac := mac.Sum(nil) - - return hmac -} - -// NewHPPPacket returns a new HPP packet -func NewHPPPacket(client *Client, data []byte) (*HPPPacket, error) { - packet := NewPacket(client, data) - - hppPacket := HPPPacket{Packet: packet} - - if data != nil { - hppPacket.payload = data - - rmcRequest := NewRMCRequest() - err := rmcRequest.FromBytes(data) - if err != nil { - return &HPPPacket{}, errors.New("[HPP] Error parsing RMC request: " + err.Error()) - } - - hppPacket.rmcRequest = rmcRequest - } - - return &hppPacket, nil -} diff --git a/hpp_server.go b/hpp_server.go new file mode 100644 index 00000000..c9d11ef9 --- /dev/null +++ b/hpp_server.go @@ -0,0 +1,9 @@ +package nex + +// HPPServer represents a bare-bones HPP server +type HPPServer struct{} + +// NewHPPServer returns a new HPP server +func NewHPPServer() *HPPServer { + return &HPPServer{} +} diff --git a/init.go b/init.go index c69cf42c..178e8e3f 100644 --- a/init.go +++ b/init.go @@ -1,9 +1,5 @@ package nex -import "github.com/PretendoNetwork/plogger-go" - -var logger = plogger.NewLogger() - func init() { initErrorsData() } diff --git a/kerberos.go b/kerberos.go index 175c550f..f76fdc77 100644 --- a/kerberos.go +++ b/kerberos.go @@ -1,194 +1,122 @@ package nex import ( - "bytes" "crypto/hmac" "crypto/md5" "crypto/rand" "crypto/rc4" + "errors" "fmt" ) -// KerberosEncryption is used to encrypt/decrypt using Kerberos +// KerberosEncryption is a struct representing a Kerberos encryption utility type KerberosEncryption struct { - key []byte - cipher *rc4.Cipher + key []byte } -// Encrypt will encrypt the given data using Kerberos -func (encryption *KerberosEncryption) Encrypt(buffer []byte) []byte { - encrypted := make([]byte, len(buffer)) - encryption.cipher.XORKeyStream(encrypted, buffer) +// Validate checks the integrity of the given buffer by verifying the HMAC checksum +func (ke *KerberosEncryption) Validate(buffer []byte) bool { + data := buffer[:len(buffer)-0x10] + checksum := buffer[len(buffer)-0x10:] + mac := hmac.New(md5.New, ke.key) - mac := hmac.New(md5.New, []byte(encryption.key)) - mac.Write(encrypted) - hmac := mac.Sum(nil) + mac.Write(data) - return append(encrypted, hmac...) + return hmac.Equal(checksum, mac.Sum(nil)) } -// Decrypt will decrypt the given data using Kerberos -func (encryption *KerberosEncryption) Decrypt(buffer []byte) []byte { - if !encryption.Validate(buffer) { - logger.Error("Keberos hmac validation failed") +// Decrypt decrypts the provided buffer if it passes the integrity check +func (ke *KerberosEncryption) Decrypt(buffer []byte) ([]byte, error) { + if !ke.Validate(buffer) { + return nil, errors.New("Invalid Kerberos checksum (incorrect password)") } - offset := len(buffer) - offset = offset + -0x10 - - encrypted := buffer[:offset] - - decrypted := make([]byte, len(encrypted)) - encryption.cipher.XORKeyStream(decrypted, encrypted) - - return decrypted -} - -// Validate will check the HMAC of the encrypted data -func (encryption *KerberosEncryption) Validate(buffer []byte) bool { - offset := len(buffer) - offset = offset + -0x10 + cipher, err := rc4.NewCipher(ke.key) + if err != nil { + return nil, err + } - data := buffer[:offset] - checksum := buffer[offset:] + decrypted := make([]byte, len(buffer)-0x10) - cipher := hmac.New(md5.New, []byte(encryption.key)) - cipher.Write(data) - mac := cipher.Sum(nil) + cipher.XORKeyStream(decrypted, buffer[:len(buffer)-0x10]) - return bytes.Equal(mac, checksum) + return decrypted, nil } -// NewKerberosEncryption returns a new KerberosEncryption instance -func NewKerberosEncryption(key []byte) (*KerberosEncryption, error) { - cipher, err := rc4.NewCipher(key) - if err != nil { - return nil, fmt.Errorf("Failed to create Kerberos RC4 cipher. %s", err.Error()) - } - - return &KerberosEncryption{key: key, cipher: cipher}, nil -} +// Encrypt encrypts the given buffer and appends an HMAC checksum for integrity +func (ke *KerberosEncryption) Encrypt(buffer []byte) []byte { + cipher, _ := rc4.NewCipher(ke.key) + encrypted := make([]byte, len(buffer)) -// Ticket represents a Kerberos authentication ticket -type Ticket struct { - sessionKey []byte - targetPID uint32 - internalData []byte -} + cipher.XORKeyStream(encrypted, buffer) -// SessionKey returns the Tickets session key -func (ticket *Ticket) SessionKey() []byte { - return ticket.sessionKey -} + mac := hmac.New(md5.New, ke.key) -// SetSessionKey sets the Tickets session key -func (ticket *Ticket) SetSessionKey(sessionKey []byte) { - ticket.sessionKey = sessionKey -} + mac.Write(encrypted) -// TargetPID returns the Tickets target PID -func (ticket *Ticket) TargetPID() uint32 { - return ticket.targetPID -} + checksum := mac.Sum(nil) -// SetTargetPID sets the Tickets target PID -func (ticket *Ticket) SetTargetPID(targetPID uint32) { - ticket.targetPID = targetPID + return append(encrypted, checksum...) } -// InternalData returns the Tickets internal data buffer -func (ticket *Ticket) InternalData() []byte { - return ticket.internalData +// NewKerberosEncryption creates a new KerberosEncryption instance with the given key. +func NewKerberosEncryption(key []byte) *KerberosEncryption { + return &KerberosEncryption{key: key} } -// SetInternalData sets the Tickets internal data buffer -func (ticket *Ticket) SetInternalData(internalData []byte) { - ticket.internalData = internalData +// KerberosTicket represents a ticket granting a user access to a secure server +type KerberosTicket struct { + SessionKey []byte + TargetPID uint32 + InternalData []byte } // Encrypt writes the ticket data to the provided stream and returns the encrypted byte slice -func (ticket *Ticket) Encrypt(key []byte, stream *StreamOut) ([]byte, error) { - encryption, err := NewKerberosEncryption(key) - if err != nil { - return nil, fmt.Errorf("Failed to create Kerberos ticket encryption instance. %s", err.Error()) - } +func (kt *KerberosTicket) Encrypt(key []byte, stream *StreamOut) ([]byte, error) { + encryption := NewKerberosEncryption(key) - // Session key is not a NEX buffer type - stream.Grow(int64(len(ticket.sessionKey))) - stream.WriteBytesNext(ticket.sessionKey) + stream.Grow(int64(len(kt.SessionKey))) + stream.WriteBytesNext(kt.SessionKey) - stream.WriteUInt32LE(ticket.targetPID) - stream.WriteBuffer(ticket.internalData) + stream.WriteUInt32LE(kt.TargetPID) + stream.WriteBuffer(kt.InternalData) return encryption.Encrypt(stream.Bytes()), nil } // NewKerberosTicket returns a new Ticket instance -func NewKerberosTicket() *Ticket { - return &Ticket{} +func NewKerberosTicket() *KerberosTicket { + return &KerberosTicket{} } -// TicketInternalData contains information sent to the secure server -type TicketInternalData struct { - timestamp *DateTime - userPID uint32 - sessionKey []byte -} - -// Timestamp returns the TicketInternalDatas timestamp -func (ticketInternalData *TicketInternalData) Timestamp() *DateTime { - return ticketInternalData.timestamp -} - -// SetTimestamp sets the TicketInternalDatas timestamp -func (ticketInternalData *TicketInternalData) SetTimestamp(timestamp *DateTime) { - ticketInternalData.timestamp = timestamp -} - -// UserPID returns the TicketInternalDatas user PID -func (ticketInternalData *TicketInternalData) UserPID() uint32 { - return ticketInternalData.userPID -} - -// SetUserPID sets the TicketInternalDatas user PID -func (ticketInternalData *TicketInternalData) SetUserPID(userPID uint32) { - ticketInternalData.userPID = userPID -} - -// SessionKey returns the TicketInternalDatas session key -func (ticketInternalData *TicketInternalData) SessionKey() []byte { - return ticketInternalData.sessionKey -} - -// SetSessionKey sets the TicketInternalDatas session key -func (ticketInternalData *TicketInternalData) SetSessionKey(sessionKey []byte) { - ticketInternalData.sessionKey = sessionKey +// KerberosTicketInternalData holds the internal data for a kerberos ticket to be processed by the server +type KerberosTicketInternalData struct { + Issued *DateTime + SourcePID uint32 + SessionKey []byte } // Encrypt writes the ticket data to the provided stream and returns the encrypted byte slice -func (ticketInternalData *TicketInternalData) Encrypt(key []byte, stream *StreamOut) ([]byte, error) { - stream.WriteDateTime(ticketInternalData.timestamp) - stream.WriteUInt32LE(ticketInternalData.userPID) +func (ti *KerberosTicketInternalData) Encrypt(key []byte, stream *StreamOut) ([]byte, error) { + stream.WriteDateTime(ti.Issued) + stream.WriteUInt32LE(ti.SourcePID) - // Session key is not a NEX buffer type - stream.Grow(int64(len(ticketInternalData.sessionKey))) - stream.WriteBytesNext(ticketInternalData.sessionKey) + stream.Grow(int64(len(ti.SessionKey))) + stream.WriteBytesNext(ti.SessionKey) data := stream.Bytes() - if stream.Server.KerberosTicketVersion() == 1 { + if stream.Server.(*PRUDPServer).kerberosTicketVersion == 1 { ticketKey := make([]byte, 16) _, err := rand.Read(ticketKey) if err != nil { return nil, fmt.Errorf("Failed to generate ticket key. %s", err.Error()) } - finalKey := MD5Hash(append(key, ticketKey...)) + hash := md5.Sum(append(key, ticketKey...)) + finalKey := hash[:] - encryption, err := NewKerberosEncryption(finalKey) - if err != nil { - return nil, fmt.Errorf("Failed to create Kerberos ticket internal data encryption instance. %s", err.Error()) - } + encryption := NewKerberosEncryption(finalKey) encrypted := encryption.Encrypt(data) @@ -198,19 +126,16 @@ func (ticketInternalData *TicketInternalData) Encrypt(key []byte, stream *Stream finalStream.WriteBuffer(encrypted) return finalStream.Bytes(), nil - } else { - encryption, err := NewKerberosEncryption([]byte(key)) - if err != nil { - return nil, fmt.Errorf("Failed to create Kerberos ticket internal data encryption instance. %s", err.Error()) - } - - return encryption.Encrypt(data), nil } + + encryption := NewKerberosEncryption([]byte(key)) + + return encryption.Encrypt(data), nil } // Decrypt decrypts the given data and populates the struct -func (ticketInternalData *TicketInternalData) Decrypt(stream *StreamIn, key []byte) error { - if stream.Server.KerberosTicketVersion() == 1 { +func (ti *KerberosTicketInternalData) Decrypt(stream *StreamIn, key []byte) error { + if stream.Server.(*PRUDPServer).kerberosTicketVersion == 1 { ticketKey, err := stream.ReadBuffer() if err != nil { return fmt.Errorf("Failed to read Kerberos ticket internal data key. %s", err.Error()) @@ -221,18 +146,19 @@ func (ticketInternalData *TicketInternalData) Decrypt(stream *StreamIn, key []by return fmt.Errorf("Failed to read Kerberos ticket internal data. %s", err.Error()) } - key = MD5Hash(append(key, ticketKey...)) + hash := md5.Sum(append(key, ticketKey...)) + key = hash[:] stream = NewStreamIn(data, stream.Server) } - encryption, err := NewKerberosEncryption(key) + encryption := NewKerberosEncryption(key) + + decrypted, err := encryption.Decrypt(stream.Bytes()) if err != nil { - return fmt.Errorf("Failed to create Kerberos ticket internal data encryption instance. %s", err.Error()) + return fmt.Errorf("Failed to decrypt Kerberos ticket internal data. %s", err.Error()) } - decrypted := encryption.Decrypt(stream.Bytes()) - stream = NewStreamIn(decrypted, stream.Server) timestamp, err := stream.ReadDateTime() @@ -245,23 +171,26 @@ func (ticketInternalData *TicketInternalData) Decrypt(stream *StreamIn, key []by return fmt.Errorf("Failed to read Kerberos ticket internal data user PID %s", err.Error()) } - ticketInternalData.SetTimestamp(timestamp) - ticketInternalData.SetUserPID(userPID) - ticketInternalData.SetSessionKey(stream.ReadBytesNext(int64(stream.Server.KerberosKeySize()))) + ti.Issued = timestamp + ti.SourcePID = userPID + ti.SessionKey = stream.ReadBytesNext(int64(stream.Server.(*PRUDPServer).kerberosKeySize)) return nil } -// NewKerberosTicketInternalData returns a new TicketInternalData instance -func NewKerberosTicketInternalData() *TicketInternalData { - return &TicketInternalData{} +// NewKerberosTicketInternalData returns a new KerberosTicketInternalData instance +func NewKerberosTicketInternalData() *KerberosTicketInternalData { + return &KerberosTicketInternalData{} } // DeriveKerberosKey derives a users kerberos encryption key based on their PID and password func DeriveKerberosKey(pid uint32, password []byte) []byte { + key := password + for i := 0; i < 65000+int(pid)%1024; i++ { - password = MD5Hash(password) + hash := md5.Sum(key) + key = hash[:] } - return password + return key } diff --git a/library_version.go b/library_version.go new file mode 100644 index 00000000..adb0dcf7 --- /dev/null +++ b/library_version.go @@ -0,0 +1,78 @@ +package nex + +import ( + "fmt" + "strings" + + "golang.org/x/mod/semver" +) + +// LibraryVersion represents a NEX library version +type LibraryVersion struct { + Major int + Minor int + Patch int + GameSpecificPatch string + semver string +} + +// Copy returns a new copied instance of LibraryVersion +func (lv *LibraryVersion) Copy() *LibraryVersion { + return &LibraryVersion{ + Major: lv.Major, + Minor: lv.Minor, + Patch: lv.Patch, + GameSpecificPatch: lv.GameSpecificPatch, + semver: fmt.Sprintf("v%d.%d.%d", lv.Major, lv.Minor, lv.Patch), + } +} + +func (lv *LibraryVersion) semverCompare(compare string) int { + if !strings.HasPrefix(compare, "v") { + // * Faster than doing "v" + string(compare) + var b strings.Builder + + b.WriteString("v") + b.WriteString(compare) + + compare = b.String() + } + + if !semver.IsValid(compare) { + // * The semver package returns 0 (equal) for invalid semvers in semver.Compare + return 0 + } + + return semver.Compare(lv.semver, compare) +} + +// GreaterOrEqual compares if the given semver is greater than or equal to the current version +func (lv *LibraryVersion) GreaterOrEqual(compare string) bool { + return lv.semverCompare(compare) != -1 +} + +// LessOrEqual compares if the given semver is lesser than or equal to the current version +func (lv *LibraryVersion) LessOrEqual(compare string) bool { + return lv.semverCompare(compare) != 1 +} + +// NewPatchedLibraryVersion returns a new LibraryVersion with a game specific patch +func NewPatchedLibraryVersion(major, minor, patch int, gameSpecificPatch string) *LibraryVersion { + return &LibraryVersion{ + Major: major, + Minor: minor, + Patch: patch, + GameSpecificPatch: gameSpecificPatch, + semver: fmt.Sprintf("v%d.%d.%d", major, minor, patch), + } +} + +// NewLibraryVersion returns a new LibraryVersion +func NewLibraryVersion(major, minor, patch int) *LibraryVersion { + return &LibraryVersion{ + Major: major, + Minor: minor, + Patch: patch, + semver: fmt.Sprintf("v%d.%d.%d", major, minor, patch), + } +} diff --git a/md5.go b/md5.go deleted file mode 100644 index 206618d4..00000000 --- a/md5.go +++ /dev/null @@ -1,11 +0,0 @@ -package nex - -import "crypto/md5" - - -// MD5Hash returns the MD5 hash of the input -func MD5Hash(text []byte) []byte { - hasher := md5.New() - hasher.Write(text) - return hasher.Sum(nil) -} \ No newline at end of file diff --git a/mutex_map.go b/mutex_map.go index 26284746..0324d379 100644 --- a/mutex_map.go +++ b/mutex_map.go @@ -26,6 +26,15 @@ func (m *MutexMap[K, V]) Get(key K) (V, bool) { return value, ok } +// Has checks if a key exists in the map +func (m *MutexMap[K, V]) Has(key K) bool { + m.RLock() + defer m.RUnlock() + + _, ok := m.real[key] + return ok +} + // Delete removes a key from the internal map func (m *MutexMap[K, V]) Delete(key K) { m.Lock() @@ -38,7 +47,7 @@ func (m *MutexMap[K, V]) Delete(key K) { func (m *MutexMap[K, V]) RunAndDelete(key K, callback func(key K, value V)) { m.Lock() defer m.Unlock() - + if value, ok := m.real[key]; ok { callback(key, value) delete(m.real, key) diff --git a/nex_version.go b/nex_version.go deleted file mode 100644 index dd1e9630..00000000 --- a/nex_version.go +++ /dev/null @@ -1,78 +0,0 @@ -package nex - -import ( - "fmt" - "strings" - - "golang.org/x/mod/semver" -) - -// NEXVersion represents a NEX library version -type NEXVersion struct { - Major int - Minor int - Patch int - GameSpecificPatch string - semver string -} - -// Copy returns a new copied instance of NEXVersion -func (nexVersion *NEXVersion) Copy() *NEXVersion { - return &NEXVersion{ - Major: nexVersion.Major, - Minor: nexVersion.Minor, - Patch: nexVersion.Patch, - GameSpecificPatch: nexVersion.GameSpecificPatch, - semver: fmt.Sprintf("v%d.%d.%d", nexVersion.Major, nexVersion.Minor, nexVersion.Patch), - } -} - -func (nexVersion *NEXVersion) semverCompare(compare string) int { - if !strings.HasPrefix(compare, "v") { - // * Faster than doing "v" + string(compare) - var b strings.Builder - - b.WriteString("v") - b.WriteString(compare) - - compare = b.String() - } - - if !semver.IsValid(compare) { - // * The semver package returns 0 (equal) for invalid semvers in semver.Compare - return 0 - } - - return semver.Compare(nexVersion.semver, compare) -} - -// GreaterOrEqual compares if the given semver is greater than or equal to the current version -func (nexVersion *NEXVersion) GreaterOrEqual(compare string) bool { - return nexVersion.semverCompare(compare) != -1 -} - -// LessOrEqual compares if the given semver is lesser than or equal to the current version -func (nexVersion *NEXVersion) LessOrEqual(compare string) bool { - return nexVersion.semverCompare(compare) != 1 -} - -// NewPatchedNEXVersion returns a new NEXVersion with a game specific patch -func NewPatchedNEXVersion(major, minor, patch int, gameSpecificPatch string) *NEXVersion { - return &NEXVersion{ - Major: major, - Minor: minor, - Patch: patch, - GameSpecificPatch: gameSpecificPatch, - semver: fmt.Sprintf("v%d.%d.%d", major, minor, patch), - } -} - -// NewNEXVersion returns a new NEXVersion -func NewNEXVersion(major, minor, patch int) *NEXVersion { - return &NEXVersion{ - Major: major, - Minor: minor, - Patch: patch, - semver: fmt.Sprintf("v%d.%d.%d", major, minor, patch), - } -} diff --git a/packet.go b/packet.go deleted file mode 100644 index 53dcd787..00000000 --- a/packet.go +++ /dev/null @@ -1,171 +0,0 @@ -package nex - -// Packet represents a generic PRUDP packet -type Packet struct { - sender *Client - data []byte - version uint8 - source uint8 - destination uint8 - packetType uint16 - flags uint16 - sessionID uint8 - signature []byte - sequenceID uint16 - connectionSignature []byte - fragmentID uint8 - payload []byte - rmcRequest RMCRequest - PacketInterface -} - -// Data returns bytes used to create the packet (this is not the same as Bytes()) -func (packet *Packet) Data() []byte { - return packet.data -} - -// Sender returns the packet sender -func (packet *Packet) Sender() *Client { - return packet.sender -} - -// SetVersion sets the packet PRUDP version -func (packet *Packet) SetVersion(version uint8) { - packet.version = version -} - -// Version gets the packet PRUDP version -func (packet *Packet) Version() uint8 { - return packet.version -} - -// SetSource sets the packet source -func (packet *Packet) SetSource(source uint8) { - packet.source = source -} - -// Source returns the packet source -func (packet *Packet) Source() uint8 { - return packet.source -} - -// SetDestination sets the packet destination -func (packet *Packet) SetDestination(destination uint8) { - packet.destination = destination -} - -// Destination returns the packet destination -func (packet *Packet) Destination() uint8 { - return packet.destination -} - -// SetType sets the packet type -func (packet *Packet) SetType(packetType uint16) { - packet.packetType = packetType -} - -// Type returns the packet type -func (packet *Packet) Type() uint16 { - return packet.packetType -} - -// SetFlags sets the packet flag bitmask -func (packet *Packet) SetFlags(bitmask uint16) { - packet.flags = bitmask -} - -// Flags returns the packet flag bitmask -func (packet *Packet) Flags() uint16 { - return packet.flags -} - -// HasFlag checks if the packet has the given flag -func (packet *Packet) HasFlag(flag uint16) bool { - return packet.flags&flag != 0 -} - -// AddFlag adds the given flag to the packet flag bitmask -func (packet *Packet) AddFlag(flag uint16) { - packet.flags |= flag -} - -// ClearFlag removes the given flag from the packet bitmask -func (packet *Packet) ClearFlag(flag uint16) { - packet.flags &^= flag -} - -// SetSessionID sets the packet sessionID -func (packet *Packet) SetSessionID(sessionID uint8) { - packet.sessionID = sessionID -} - -// SessionID returns the packet sessionID -func (packet *Packet) SessionID() uint8 { - return packet.sessionID -} - -// SetSignature sets the packet signature -func (packet *Packet) SetSignature(signature []byte) { - packet.signature = signature -} - -// Signature returns the packet signature -func (packet *Packet) Signature() []byte { - return packet.signature -} - -// SetSequenceID sets the packet sequenceID -func (packet *Packet) SetSequenceID(sequenceID uint16) { - packet.sequenceID = sequenceID -} - -// SequenceID returns the packet sequenceID -func (packet *Packet) SequenceID() uint16 { - return packet.sequenceID -} - -// SetConnectionSignature sets the packet connection signature -func (packet *Packet) SetConnectionSignature(connectionSignature []byte) { - packet.connectionSignature = connectionSignature -} - -// ConnectionSignature returns the packet connection signature -func (packet *Packet) ConnectionSignature() []byte { - return packet.connectionSignature -} - -// SetFragmentID sets the packet fragmentID -func (packet *Packet) SetFragmentID(fragmentID uint8) { - packet.fragmentID = fragmentID -} - -// FragmentID returns the packet fragmentID -func (packet *Packet) FragmentID() uint8 { - return packet.fragmentID -} - -// SetPayload sets the packet payload -func (packet *Packet) SetPayload(payload []byte) { - packet.payload = payload -} - -// Payload returns the packet payload -func (packet *Packet) Payload() []byte { - return packet.payload -} - -// RMCRequest returns the packet RMC request -func (packet *Packet) RMCRequest() RMCRequest { - return packet.rmcRequest -} - -// NewPacket returns a new PRUDP packet generic -func NewPacket(client *Client, data []byte) Packet { - packet := Packet{ - sender: client, - data: data, - payload: []byte{}, - } - - return packet -} diff --git a/packet_interface.go b/packet_interface.go index 71aaf5b9..a2bcc397 100644 --- a/packet_interface.go +++ b/packet_interface.go @@ -1,35 +1,10 @@ package nex -// PacketInterface implements all Packet methods +// PacketInterface defines all the methods a packet for both PRUDP and HPP should have type PacketInterface interface { - Data() []byte - Sender() *Client - SetVersion(version uint8) - Version() uint8 - SetSource(source uint8) - Source() uint8 - SetDestination(destination uint8) - Destination() uint8 - SetType(packetType uint16) - Type() uint16 - SetFlags(bitmask uint16) - Flags() uint16 - HasFlag(flag uint16) bool - AddFlag(flag uint16) - ClearFlag(flag uint16) - SetSessionID(sessionID uint8) - SessionID() uint8 - SetSignature(signature []byte) - Signature() []byte - SetSequenceID(sequenceID uint16) - SequenceID() uint16 - SetConnectionSignature(connectionSignature []byte) - ConnectionSignature() []byte - SetFragmentID(fragmentID uint8) - FragmentID() uint8 - SetPayload(payload []byte) + Sender() ClientInterface Payload() []byte - DecryptPayload() error - RMCRequest() RMCRequest - Bytes() []byte + SetPayload(payload []byte) + RMCMessage() *RMCMessage + SetRMCMessage(message *RMCMessage) } diff --git a/packet_manager.go b/packet_manager.go deleted file mode 100644 index ab49806c..00000000 --- a/packet_manager.go +++ /dev/null @@ -1,43 +0,0 @@ -package nex - -// PacketManager implements an API for pushing/popping packets in the correct order -type PacketManager struct { - currentSequenceID *Counter - packets []PacketInterface -} - -// Next gets the next packet in the sequence. Returns nil if the next packet has not been sent yet -func (p *PacketManager) Next() PacketInterface { - var packet PacketInterface - - for i := 0; i < len(p.packets); i++ { - if p.currentSequenceID.Value() == uint32(p.packets[i].SequenceID()) { - packet = p.packets[i] - p.RemoveByIndex(i) - p.currentSequenceID.Increment() - break - } - } - - return packet -} - -// Push adds a packet to the pool to choose from in Next -func (p *PacketManager) Push(packet PacketInterface) { - p.packets = append(p.packets, packet) -} - -// RemoveByIndex removes a packet from the pool using it's index in the slice -func (p *PacketManager) RemoveByIndex(i int) { - // * https://stackoverflow.com/a/37335777 - p.packets[i] = p.packets[len(p.packets)-1] - p.packets = p.packets[:len(p.packets)-1] -} - -// NewPacketManager returns a new PacketManager -func NewPacketManager() *PacketManager { - return &PacketManager{ - currentSequenceID: NewCounter(0), - packets: make([]PacketInterface, 0), - } -} diff --git a/packet_resend_manager.go b/packet_resend_manager.go deleted file mode 100644 index efa89d0a..00000000 --- a/packet_resend_manager.go +++ /dev/null @@ -1,113 +0,0 @@ -package nex - -import ( - "time" -) - -// PendingPacket represents a packet which the server has sent but not received an ACK for -// it handles it's own retransmission on a per-packet timer -type PendingPacket struct { - ticking bool - ticker *time.Ticker - quit chan struct{} - packet PacketInterface - iterations *Counter - timeout time.Duration - timeoutInc time.Duration - maxIterations int -} - -// BeginTimeoutTimer starts the pending packets timeout timer until it is either stopped or maxIterations is hit -func (p *PendingPacket) BeginTimeoutTimer() { - go func() { - for { - select { - case <-p.quit: - return - case <-p.ticker.C: - client := p.packet.Sender() - server := client.Server() - - if int(p.iterations.Increment()) > p.maxIterations { - // * Max iterations hit. Assume client is dead - server.TimeoutKick(client) - p.StopTimeoutTimer() - return - } else { - if p.timeoutInc != 0 { - p.timeout += p.timeoutInc - p.ticker.Reset(p.timeout) - } - - // * Resend the packet - server.SendRaw(client.Address(), p.packet.Bytes()) - } - } - } - }() -} - -// StopTimeoutTimer stops the packet retransmission timer -func (p *PendingPacket) StopTimeoutTimer() { - if p.ticking { - close(p.quit) - p.ticker.Stop() - p.ticking = false - } -} - -// NewPendingPacket returns a new PendingPacket -func NewPendingPacket(packet PacketInterface, timeoutTime time.Duration, timeoutIncrement time.Duration, maxIterations int) *PendingPacket { - p := &PendingPacket{ - ticking: true, - ticker: time.NewTicker(timeoutTime), - quit: make(chan struct{}), - packet: packet, - iterations: NewCounter(0), - timeout: timeoutTime, - timeoutInc: timeoutIncrement, - maxIterations: maxIterations, - } - - return p -} - -// PacketResendManager manages all the pending packets sent the client waiting to be ACKed -type PacketResendManager struct { - pending *MutexMap[uint16, *PendingPacket] - timeoutTime time.Duration - timeoutInc time.Duration - maxIterations int -} - -// Add creates a PendingPacket, adds it to the pool, and begins it's timeout timer -func (p *PacketResendManager) Add(packet PacketInterface) { - cached := NewPendingPacket(packet, p.timeoutTime, p.timeoutInc, p.maxIterations) - p.pending.Set(packet.SequenceID(), cached) - - cached.BeginTimeoutTimer() -} - -// Remove removes a packet from pool and stops it's timer -func (p *PacketResendManager) Remove(sequenceID uint16) { - p.pending.RunAndDelete(sequenceID, func(key uint16, value *PendingPacket){ - value.StopTimeoutTimer() - }) -} - -// Clear removes all packets from pool and stops their timers -func (p *PacketResendManager) Clear() { - p.pending.Clear(func(key uint16, value *PendingPacket) { - value.StopTimeoutTimer() - }) -} - -// NewPacketResendManager returns a new PacketResendManager -func NewPacketResendManager(timeoutTime time.Duration, timeoutIncrement time.Duration, maxIterations int) *PacketResendManager { - return &PacketResendManager{ - pending: NewMutexMap[uint16, *PendingPacket](), - timeoutTime: timeoutTime, - timeoutInc: timeoutIncrement, - maxIterations: maxIterations, - } -} diff --git a/packet_v0.go b/packet_v0.go deleted file mode 100644 index 3161d459..00000000 --- a/packet_v0.go +++ /dev/null @@ -1,326 +0,0 @@ -package nex - -import ( - "crypto/hmac" - "crypto/md5" - "encoding/binary" - "errors" - "fmt" -) - -// PacketV0 reresents a PRUDPv0 packet -type PacketV0 struct { - Packet - checksum uint8 -} - -// SetChecksum sets the packet checksum -func (packet *PacketV0) SetChecksum(checksum uint8) { - packet.checksum = checksum -} - -// Checksum returns the packet checksum -func (packet *PacketV0) Checksum() uint8 { - return packet.checksum -} - -// Decode decodes the packet -func (packet *PacketV0) Decode() error { - checksumSize := 1 - var payloadSize uint16 - var typeFlags uint16 - - stream := NewStreamIn(packet.Data(), packet.Sender().Server()) - - source, err := stream.ReadUInt8() - if err != nil { - return fmt.Errorf("Failed to read PRUDPv0 source. %s", err.Error()) - } - - packet.SetSource(source) - - destination, err := stream.ReadUInt8() - if err != nil { - return fmt.Errorf("Failed to read PRUDPv0 destination. %s", err.Error()) - } - - packet.SetDestination(destination) - - typeFlags, err = stream.ReadUInt16LE() - if err != nil { - return fmt.Errorf("Failed to read PRUDPv0 type-flags. %s", err.Error()) - } - - packet.SetType(typeFlags & 0xF) - - if _, ok := validTypes[packet.Type()]; !ok { - return errors.New("Invalid PRUDP packet type") - } - - packet.SetFlags(typeFlags >> 4) - - sessionID, err := stream.ReadUInt8() - if err != nil { - return fmt.Errorf("Failed to read PRUDPv0 session ID. %s", err.Error()) - } - - packet.SetSessionID(sessionID) - - if len(stream.Bytes()[stream.ByteOffset():]) < 4 { - return errors.New("Failed to read PRUDPv0 packet signature. Not have enough data") - } - - packet.SetSignature(stream.ReadBytesNext(4)) - - sequenceID, err := stream.ReadUInt16LE() - if err != nil { - return fmt.Errorf("Failed to read PRUDPv0 sequence ID. %s", err.Error()) - } - - packet.SetSequenceID(sequenceID) - - if packet.Type() == SynPacket || packet.Type() == ConnectPacket { - if len(packet.Data()[stream.ByteOffset():]) < 4 { - return errors.New("Failed to read PRUDPv0 connection signature. Not enough data") - } - - packet.SetConnectionSignature(stream.ReadBytesNext(4)) - } - - if packet.Type() == DataPacket { - fragmentID, err := stream.ReadUInt8() - if err != nil { - return fmt.Errorf("Failed to read PRUDPv0 connection signature. %s", err.Error()) - } - - packet.SetFragmentID(fragmentID) - } - - if packet.HasFlag(FlagHasSize) { - if len(packet.Data()[stream.ByteOffset():]) < 2 { - return errors.New("[PRUDPv0] Packet specific data not large enough for payload size") - } - - payloadSize, err = stream.ReadUInt16LE() - if err != nil { - return fmt.Errorf("Failed to read PRUDPv0 payload size. %s", err.Error()) - } - } else { - payloadSize = uint16(len(packet.data) - int(stream.ByteOffset()) - checksumSize) - } - - if payloadSize > 0 { - if len(packet.Data()[stream.ByteOffset():]) < int(payloadSize) { - return errors.New("[PRUDPv0] Packet data length less than payload length") - } - - payloadCrypted := stream.ReadBytesNext(int64(payloadSize)) - - packet.SetPayload(payloadCrypted) - } - - if len(packet.Data()[stream.ByteOffset():]) < int(checksumSize) { - return errors.New("[PRUDPv0] Packet data length less than checksum length") - } - - checksum, err := stream.ReadUInt8() - if err != nil { - return fmt.Errorf("Failed to read PRUDPv0 packet checksum. %s", err.Error()) - } - - packet.SetChecksum(checksum) - - packetBody := stream.Bytes() - - calculatedChecksum := packet.calculateChecksum(packetBody[:len(packetBody)-checksumSize]) - - if calculatedChecksum != packet.Checksum() { - logger.Error("PRUDPv0 packet calculated checksum did not match") - } - - return nil -} - -// DecryptPayload decrypts the packets payload and sets the RMC request data -func (packet *PacketV0) DecryptPayload() error { - if packet.Type() == DataPacket && !packet.HasFlag(FlagAck) { - ciphered := make([]byte, len(packet.Payload())) - - packet.Sender().Decipher().XORKeyStream(ciphered, packet.Payload()) - - request := NewRMCRequest() - err := request.FromBytes(ciphered) - if err != nil { - return fmt.Errorf("Failed to read PRUDPv0 RMC request. %s", err.Error()) - } - - packet.rmcRequest = request - } - - return nil -} - -// Bytes encodes the packet and returns a byte array -func (packet *PacketV0) Bytes() []byte { - var typeFlags uint16 = packet.Type() | packet.Flags()<<4 - - stream := NewStreamOut(packet.Sender().Server()) - packetSignature := packet.calculateSignature() - - stream.WriteUInt8(packet.Source()) - stream.WriteUInt8(packet.Destination()) - stream.WriteUInt16LE(typeFlags) - stream.WriteUInt8(packet.SessionID()) - stream.Grow(int64(len(packetSignature))) - stream.WriteBytesNext(packetSignature) - stream.WriteUInt16LE(packet.SequenceID()) - - options := packet.encodeOptions() - optionsLength := len(options) - - if optionsLength > 0 { - stream.Grow(int64(optionsLength)) - stream.WriteBytesNext(options) - } - - payload := packet.Payload() - - if len(payload) > 0 { - stream.Grow(int64(len(payload))) - stream.WriteBytesNext(payload) - } - - checksum := packet.calculateChecksum(stream.Bytes()) - - stream.WriteUInt8(checksum) - - return stream.Bytes() -} - -func (packet *PacketV0) calculateSignature() []byte { - // Friends server handles signatures differently, so check for the Friends server access key - if packet.Sender().Server().AccessKey() == "ridfebb9" { - if packet.Type() == DataPacket { - payload := packet.Payload() - - if payload == nil || len(payload) <= 0 { - signature := NewStreamOut(packet.Sender().Server()) - signature.WriteUInt32LE(0x12345678) - - return signature.Bytes() - } - - key := packet.Sender().SignatureKey() - cipher := hmac.New(md5.New, key) - cipher.Write(payload) - - return cipher.Sum(nil)[:4] - } - - clientConnectionSignature := packet.Sender().ClientConnectionSignature() - - if clientConnectionSignature != nil { - return clientConnectionSignature - } - - return []byte{0x0, 0x0, 0x0, 0x0} - } else { // Normal signature handling - if packet.Type() == DataPacket || packet.Type() == DisconnectPacket { - payload := NewStreamOut(packet.Sender().Server()) - sessionKey := packet.Sender().SessionKey() - if sessionKey != nil { - payload.Grow(int64(len(sessionKey))) - payload.WriteBytesNext(sessionKey) - } - payload.WriteUInt16LE(packet.sequenceID) - payload.Grow(1) - payload.WriteByteNext(packet.fragmentID) - pktpay := packet.Payload() - if len(pktpay) > 0 { - payload.Grow(int64(len(pktpay))) - payload.WriteBytesNext(pktpay) - } - - key := packet.Sender().SignatureKey() - cipher := hmac.New(md5.New, key) - cipher.Write(payload.Bytes()) - - return cipher.Sum(nil)[:4] - } else { - clientConnectionSignature := packet.Sender().ClientConnectionSignature() - - if clientConnectionSignature != nil { - return clientConnectionSignature - } - } - } - - return []byte{} -} - -func (packet *PacketV0) encodeOptions() []byte { - stream := NewStreamOut(packet.Sender().Server()) - - if packet.Type() == SynPacket { - stream.Grow(4) - stream.WriteBytesNext(packet.Sender().ServerConnectionSignature()) - } - - if packet.Type() == ConnectPacket { - stream.Grow(4) - stream.WriteBytesNext([]byte{0x00, 0x00, 0x00, 0x00}) - } - - if packet.Type() == DataPacket { - stream.WriteUInt8(packet.FragmentID()) - } - - if packet.HasFlag(FlagHasSize) { - payload := packet.Payload() - - if payload != nil { - stream.WriteUInt16LE(uint16(len(payload))) - } else { - stream.WriteUInt16LE(0) - } - } - - return stream.Bytes() -} - -func (packet *PacketV0) calculateChecksum(data []byte) uint8 { - signatureBase := packet.Sender().SignatureBase() - steps := len(data) / 4 - var temp uint32 - - for i := 0; i < steps; i++ { - offset := i * 4 - temp += binary.LittleEndian.Uint32(data[offset : offset+4]) - } - - temp &= 0xFFFFFFFF - - buff := make([]byte, 4) - binary.LittleEndian.PutUint32(buff, temp) - - checksum := signatureBase - checksum += sum(data[len(data) & ^3:]) - checksum += sum(buff) - - return uint8(checksum & 0xFF) -} - -// NewPacketV0 returns a new PRUDPv0 packet -func NewPacketV0(client *Client, data []byte) (*PacketV0, error) { - packet := NewPacket(client, data) - packetv0 := PacketV0{Packet: packet} - - if data != nil { - err := packetv0.Decode() - if err != nil { - return &PacketV0{}, errors.New("[PRUDPv0] Error decoding packet data: " + err.Error()) - } - } - - return &packetv0, nil -} diff --git a/packet_v1.go b/packet_v1.go deleted file mode 100644 index 4fd505b6..00000000 --- a/packet_v1.go +++ /dev/null @@ -1,389 +0,0 @@ -package nex - -import ( - "bytes" - "crypto/hmac" - "crypto/md5" - "encoding/binary" - "errors" - "fmt" -) - -// Magic is the expected PRUDPv1 magic number -var Magic = []byte{0xEA, 0xD0} - -// OptionAllFunctions is used with OptionSupportedFunctions to support all methods -var OptionAllFunctions = 0xFFFFFFFF - -// OptionSupportedFunctions is the ID for the Supported Functions option in PRUDP v1 packets -var OptionSupportedFunctions uint8 = 0 - -// OptionConnectionSignature is the ID for the Connection Signature option in PRUDP v1 packets -var OptionConnectionSignature uint8 = 1 - -// OptionFragmentID is the ID for the Fragment ID option in PRUDP v1 packets -var OptionFragmentID uint8 = 2 - -// OptionInitialSequenceID is the ID for the initial sequence ID option in PRUDP v1 packets -var OptionInitialSequenceID uint8 = 3 - -// OptionMaxSubstreamID is the ID for the max substream ID option in PRUDP v1 packets -var OptionMaxSubstreamID uint8 = 4 - -// PacketV1 reresents a PRUDPv1 packet -type PacketV1 struct { - Packet - magic []byte - substreamID uint8 - prudpProtocolMinorVersion int - supportedFunctions int - initialSequenceID uint16 - maximumSubstreamID uint8 -} - -// SetSubstreamID sets the packet substream ID -func (packet *PacketV1) SetSubstreamID(substreamID uint8) { - packet.substreamID = substreamID -} - -// SubstreamID returns the packet substream ID -func (packet *PacketV1) SubstreamID() uint8 { - return packet.substreamID -} - -// PRUDPProtocolMinorVersion returns the packet PRUDP minor version -func (packet *PacketV1) PRUDPProtocolMinorVersion() int { - return packet.prudpProtocolMinorVersion -} - -// SetPRUDPProtocolMinorVersion sets the packet PRUDP minor version -func (packet *PacketV1) SetPRUDPProtocolMinorVersion(prudpProtocolMinorVersion int) { - packet.prudpProtocolMinorVersion = prudpProtocolMinorVersion -} - -// SetSupportedFunctions sets the packet supported functions flags -func (packet *PacketV1) SetSupportedFunctions(supportedFunctions int) { - packet.supportedFunctions = supportedFunctions -} - -// SupportedFunctions returns the packet supported functions flags -func (packet *PacketV1) SupportedFunctions() int { - return packet.supportedFunctions -} - -// SetInitialSequenceID sets the packet initial sequence ID for unreliable packets -func (packet *PacketV1) SetInitialSequenceID(initialSequenceID uint16) { - packet.initialSequenceID = initialSequenceID -} - -// InitialSequenceID returns the packet initial sequence ID for unreliable packets -func (packet *PacketV1) InitialSequenceID() uint16 { - return packet.initialSequenceID -} - -// SetMaximumSubstreamID sets the packet maximum substream ID -func (packet *PacketV1) SetMaximumSubstreamID(maximumSubstreamID uint8) { - packet.maximumSubstreamID = maximumSubstreamID -} - -// MaximumSubstreamID returns the packet maximum substream ID -func (packet *PacketV1) MaximumSubstreamID() uint8 { - return packet.maximumSubstreamID -} - -// Decode decodes the packet -func (packet *PacketV1) Decode() error { - stream := NewStreamIn(packet.Data(), packet.Sender().Server()) - - if len(stream.Bytes()[stream.ByteOffset():]) < 2 { - return errors.New("Failed to read PRUDPv1 magic. Not have enough data") - } - - packet.magic = stream.ReadBytesNext(2) - - if !bytes.Equal(packet.magic, Magic) { - return fmt.Errorf("Invalid PRUDPv1 magic. Expected %x, got %x", Magic, packet.magic) - } - - version, err := stream.ReadUInt8() - if err != nil { - return fmt.Errorf("Failed to read PRUDPv1 version. %s", err.Error()) - } - - packet.SetVersion(version) - - if packet.Version() != 1 { - return fmt.Errorf("Invalid PRUDPv1 version. Expected 1, got %d", packet.Version()) - } - - optionsLength, err := stream.ReadUInt8() - if err != nil { - return fmt.Errorf("Failed to read PRUDPv1 options length. %s", err.Error()) - } - - payloadSize, err := stream.ReadUInt16LE() - if err != nil { - return fmt.Errorf("Failed to read PRUDPv1 payload size. %s", err.Error()) - } - - source, err := stream.ReadUInt8() - if err != nil { - return fmt.Errorf("Failed to read PRUDPv1 source. %s", err.Error()) - } - - packet.SetSource(source) - - destination, err := stream.ReadUInt8() - if err != nil { - return fmt.Errorf("Failed to read PRUDPv1 destination. %s", err.Error()) - } - - packet.SetDestination(destination) - - typeFlags, err := stream.ReadUInt16LE() - if err != nil { - return fmt.Errorf("Failed to read PRUDPv1 type-flags. %s", err.Error()) - } - - packet.SetType(typeFlags & 0xF) - - if _, ok := validTypes[packet.Type()]; !ok { - return errors.New("Invalid PRUDP packet type") - } - - packet.SetFlags(typeFlags >> 4) - - sessionID, err := stream.ReadUInt8() - if err != nil { - return fmt.Errorf("Failed to read PRUDPv1 session ID. %s", err.Error()) - } - - packet.SetSessionID(sessionID) - - substreamID, err := stream.ReadUInt8() - if err != nil { - return fmt.Errorf("Failed to read PRUDPv1 substream ID. %s", err.Error()) - } - - packet.SetSubstreamID(substreamID) - - sequenceID, err := stream.ReadUInt16LE() - if err != nil { - return fmt.Errorf("Failed to read PRUDPv1 sequence ID. %s", err.Error()) - } - - packet.SetSequenceID(sequenceID) - - if len(stream.Bytes()[stream.ByteOffset():]) < 16 { - return errors.New("Failed to read PRUDPv1 packet signature. Not have enough data") - } - - packet.SetSignature(stream.ReadBytesNext(16)) - - if len(packet.Data()[stream.ByteOffset():]) < int(optionsLength) { - return errors.New("[PRUDPv1] Packet specific data size does not match") - } - - options := stream.ReadBytesNext(int64(optionsLength)) - - err = packet.decodeOptions(options) - if err != nil { - return fmt.Errorf("Failed to read PRUDPv1 options. %s", err.Error()) - } - - if payloadSize > 0 { - if len(packet.Data()[stream.ByteOffset():]) < int(payloadSize) { - return errors.New("Failed to read PRUDPv1 packet payload. Not enough data") - } - - payloadCrypted := stream.ReadBytesNext(int64(payloadSize)) - - packet.SetPayload(payloadCrypted) - } - - calculatedSignature := packet.calculateSignature(packet.Data()[2:14], packet.Sender().ServerConnectionSignature(), options, packet.Payload()) - - if !bytes.Equal(calculatedSignature, packet.Signature()) { - logger.Error("PRUDPv1 calculated signature did not match") - } - - return nil -} - -// DecryptPayload decrypts the packets payload and sets the RMC request data -func (packet *PacketV1) DecryptPayload() error { - if packet.Type() == DataPacket && !packet.HasFlag(FlagMultiAck) { - ciphered := make([]byte, len(packet.Payload())) - - packet.Sender().Decipher().XORKeyStream(ciphered, packet.Payload()) - - request := NewRMCRequest() - err := request.FromBytes(ciphered) - if err != nil { - return fmt.Errorf("Failed to read PRUDPv1 RMC request. %s", err.Error()) - } - - packet.rmcRequest = request - } - - return nil -} - -// Bytes encodes the packet and returns a byte array -func (packet *PacketV1) Bytes() []byte { - var typeFlags uint16 = packet.Type() | packet.Flags()<<4 - - stream := NewStreamOut(packet.Sender().Server()) - - stream.WriteUInt16LE(0xD0EA) // v1 magic - stream.WriteUInt8(1) - - options := packet.encodeOptions() - optionsLength := len(options) - - stream.WriteUInt8(uint8(optionsLength)) - stream.WriteUInt16LE(uint16(len(packet.Payload()))) - stream.WriteUInt8(packet.Source()) - stream.WriteUInt8(packet.Destination()) - stream.WriteUInt16LE(typeFlags) - stream.WriteUInt8(packet.SessionID()) - stream.WriteUInt8(packet.SubstreamID()) - stream.WriteUInt16LE(packet.SequenceID()) - - signature := packet.calculateSignature(stream.Bytes()[2:14], packet.Sender().ClientConnectionSignature(), options, packet.Payload()) - - stream.Grow(int64(len(signature))) - stream.WriteBytesNext(signature) - - if optionsLength > 0 { - stream.Grow(int64(optionsLength)) - stream.WriteBytesNext(options) - } - - payload := packet.Payload() - payloadLength := len(payload) - - if payload != nil && payloadLength > 0 { - stream.Grow(int64(payloadLength)) - stream.WriteBytesNext(payload) - } - - return stream.Bytes() -} - -func (packet *PacketV1) decodeOptions(options []byte) error { - optionsStream := NewStreamIn(options, packet.Sender().Server()) - - for optionsStream.ByteOffset() != optionsStream.ByteCapacity() { - optionID, err := optionsStream.ReadUInt8() - if err != nil { - return fmt.Errorf("Failed to read PRUDPv1 option ID. %s", err.Error()) - } - - optionSize, err := optionsStream.ReadUInt8() - if err != nil { - return fmt.Errorf("Failed to read PRUDPv1 option size for option ID %d. %s", optionID, err.Error()) - } - - switch optionID { - case OptionSupportedFunctions: - supportedFunctions, err := optionsStream.ReadUInt32LE() - if err != nil { - return fmt.Errorf("Failed to read PRUDPv1 option supported functions. %s", err.Error()) - } - - packet.sender.SetPRUDPProtocolMinorVersion(int(supportedFunctions & 0xFF)) - packet.sender.SetSupportedFunctions(int(supportedFunctions >> 8)) - case OptionConnectionSignature: - packet.SetConnectionSignature(optionsStream.ReadBytesNext(int64(optionSize))) - case OptionFragmentID: - fragmentID, err := optionsStream.ReadUInt8() - if err != nil { - return fmt.Errorf("Failed to read PRUDPv1 option fragment ID. %s", err.Error()) - } - - packet.SetFragmentID(fragmentID) - case OptionInitialSequenceID: - sequenceID, err := optionsStream.ReadUInt16LE() - if err != nil { - return fmt.Errorf("Failed to read PRUDPv1 option sequence ID. %s", err.Error()) - } - - packet.SetInitialSequenceID(sequenceID) - case OptionMaxSubstreamID: - maximumSubstreamID, err := optionsStream.ReadUInt8() - if err != nil { - return fmt.Errorf("Failed to read PRUDPv1 option maximum substream ID. %s", err.Error()) - } - - packet.SetMaximumSubstreamID(maximumSubstreamID) - } - } - - return nil -} - -func (packet *PacketV1) encodeOptions() []byte { - stream := NewStreamOut(packet.Sender().Server()) - - if packet.Type() == SynPacket || packet.Type() == ConnectPacket { - stream.WriteUInt8(OptionSupportedFunctions) - stream.WriteUInt8(4) - stream.WriteUInt32LE(uint32(packet.prudpProtocolMinorVersion) | uint32(packet.supportedFunctions<<8)) - - stream.WriteUInt8(OptionConnectionSignature) - stream.WriteUInt8(16) - stream.Grow(16) - stream.WriteBytesNext(packet.ConnectionSignature()) - - if packet.Type() == ConnectPacket { - stream.WriteUInt8(OptionInitialSequenceID) - stream.WriteUInt8(2) - stream.WriteUInt16LE(packet.initialSequenceID) - } - - stream.WriteUInt8(OptionMaxSubstreamID) - stream.WriteUInt8(1) - stream.WriteUInt8(packet.maximumSubstreamID) - } else if packet.Type() == DataPacket { - stream.WriteUInt8(OptionFragmentID) - stream.WriteUInt8(1) - stream.WriteUInt8(packet.FragmentID()) - } - - return stream.Bytes() -} - -func (packet *PacketV1) calculateSignature(header []byte, connectionSignature []byte, options []byte, payload []byte) []byte { - key := packet.Sender().SignatureKey() - sessionKey := packet.Sender().SessionKey() - - signatureBase := make([]byte, 4) - binary.LittleEndian.PutUint32(signatureBase, uint32(packet.Sender().SignatureBase())) - - mac := hmac.New(md5.New, key) - - mac.Write(header[4:]) - mac.Write(sessionKey) - mac.Write(signatureBase) - mac.Write(connectionSignature) - mac.Write(options) - mac.Write(payload) - - return mac.Sum(nil) -} - -// NewPacketV1 returns a new PRUDPv1 packet -func NewPacketV1(client *Client, data []byte) (*PacketV1, error) { - packet := NewPacket(client, data) - packetv1 := PacketV1{Packet: packet} - - if data != nil { - err := packetv1.Decode() - if err != nil { - return &PacketV1{}, fmt.Errorf("Failed to decode PRUDPv1 packet. %s", err.Error()) - } - } - - return &packetv1, nil -} diff --git a/prudp_client.go b/prudp_client.go new file mode 100644 index 00000000..83a21b74 --- /dev/null +++ b/prudp_client.go @@ -0,0 +1,180 @@ +package nex + +import ( + "net" + "time" +) + +// PRUDPClient represents a single PRUDP client +type PRUDPClient struct { + address *net.UDPAddr + server *PRUDPServer + pid uint32 + clientConnectionSignature []byte + serverConnectionSignature []byte + sessionKey []byte + reliableSubstreams []*ReliablePacketSubstreamManager + outgoingUnreliableSequenceIDCounter *Counter[uint16] + outgoingPingSequenceIDCounter *Counter[uint16] + heartbeatTimer *time.Timer + pingKickTimer *time.Timer + sourceStreamType uint8 + sourcePort uint8 + destinationStreamType uint8 + destinationPort uint8 + minorVersion uint32 // * Not currently used for anything, but maybe useful later? + supportedFunctions uint32 // * Not currently used for anything, but maybe useful later? +} + +// Reset sets the client back to it's default state +func (c *PRUDPClient) reset() { + for _, substream := range c.reliableSubstreams { + substream.ResendScheduler.Stop() + } + + c.clientConnectionSignature = make([]byte, 0) + c.serverConnectionSignature = make([]byte, 0) + c.sessionKey = make([]byte, 0) + c.reliableSubstreams = make([]*ReliablePacketSubstreamManager, 0) + c.outgoingUnreliableSequenceIDCounter = NewCounter[uint16](0) + c.outgoingPingSequenceIDCounter = NewCounter[uint16](0) + c.sourceStreamType = 0 + c.sourcePort = 0 + c.destinationStreamType = 0 + c.destinationPort = 0 +} + +// Cleanup cleans up any resources the client may be using +// +// This is similar to Client.Reset(), with the key difference +// being that Cleanup does not care about the state the client +// is currently in, or will be in, after execution. It only +// frees resources that are not easily garbage collected +func (c *PRUDPClient) cleanup() { + for _, substream := range c.reliableSubstreams { + substream.ResendScheduler.Stop() + } + + c.reliableSubstreams = make([]*ReliablePacketSubstreamManager, 0) + c.stopHeartbeatTimers() +} + +// Server returns the server the client is connecting to +func (c *PRUDPClient) Server() ServerInterface { + return c.server +} + +// Address returns the clients address as a net.Addr +func (c *PRUDPClient) Address() net.Addr { + return c.address +} + +// PID returns the clients NEX PID +func (c *PRUDPClient) PID() uint32 { + return c.pid +} + +// SetPID sets the clients NEX PID +func (c *PRUDPClient) SetPID(pid uint32) { + c.pid = pid +} + +// SetSessionKey sets the clients session key used for reliable RC4 ciphers +func (c *PRUDPClient) setSessionKey(sessionKey []byte) { + c.sessionKey = sessionKey + + c.reliableSubstreams[0].SetCipherKey(sessionKey) + + // * Only the first substream uses the session key directly. + // * All other substreams modify the key before it so that + // * all substreams have a unique cipher key + for _, substream := range c.reliableSubstreams[1:] { + modifier := len(sessionKey)/2 + 1 + + // * Create a new slice to avoid modifying past keys + sessionKey = append(make([]byte, 0), sessionKey...) + + // * Only the first half of the key is modified + for i := 0; i < len(sessionKey)/2; i++ { + sessionKey[i] = (sessionKey[i] + byte(modifier-i)) & 0xFF + } + + substream.SetCipherKey(sessionKey) + } +} + +// ReliableSubstream returns the clients reliable substream ID +func (c *PRUDPClient) reliableSubstream(substreamID uint8) *ReliablePacketSubstreamManager { + return c.reliableSubstreams[substreamID] +} + +// CreateReliableSubstreams creates the list of substreams used for reliable PRUDP packets +func (c *PRUDPClient) createReliableSubstreams(maxSubstreamID uint8) { + substreams := maxSubstreamID + 1 + + c.reliableSubstreams = make([]*ReliablePacketSubstreamManager, substreams) + + for i := 0; i < len(c.reliableSubstreams); i++ { + // * First DATA packet from the client has sequence ID 2 + // * First DATA packet from the server has sequence ID 1 (starts counter at 0 and is incremeneted) + c.reliableSubstreams[i] = NewReliablePacketSubstreamManager(2, 0) + } +} + +func (c *PRUDPClient) nextOutgoingUnreliableSequenceID() uint16 { + return c.outgoingUnreliableSequenceIDCounter.Next() +} + +func (c *PRUDPClient) nextOutgoingPingSequenceID() uint16 { + return c.outgoingPingSequenceIDCounter.Next() +} + +func (c *PRUDPClient) resetHeartbeat() { + if c.pingKickTimer != nil { + c.pingKickTimer.Stop() + } + + if c.heartbeatTimer != nil { + c.heartbeatTimer.Reset(c.server.pingTimeout) + } +} + +func (c *PRUDPClient) startHeartbeat() { + server := c.server + + // * Every time a packet is sent, client.resetHeartbeat() + // * is called which resets this timer. If this function + // * ever executes, it means we haven't seen the client + // * in the expected time frame. If this happens, send + // * the client a PING packet to try and kick start the + // * heartbeat again + c.heartbeatTimer = time.AfterFunc(server.pingTimeout, func() { + server.sendPing(c) + + // * If the heartbeat still did not restart, assume the + // * client is dead and clean up + c.pingKickTimer = time.AfterFunc(server.pingTimeout, func() { + c.cleanup() + c.server.clients.Delete(c.address.String()) + }) + }) +} + +func (c *PRUDPClient) stopHeartbeatTimers() { + if c.pingKickTimer != nil { + c.pingKickTimer.Stop() + } + + if c.heartbeatTimer != nil { + c.heartbeatTimer.Stop() + } +} + +// NewPRUDPClient creates and returns a new Client using the provided UDP address and server +func NewPRUDPClient(address *net.UDPAddr, server *PRUDPServer) *PRUDPClient { + return &PRUDPClient{ + address: address, + server: server, + outgoingPingSequenceIDCounter: NewCounter[uint16](0), + } +} diff --git a/prudp_packet.go b/prudp_packet.go new file mode 100644 index 00000000..513f0b57 --- /dev/null +++ b/prudp_packet.go @@ -0,0 +1,164 @@ +package nex + +// PRUDPPacket holds all the fields each packet should have in all PRUDP versions +type PRUDPPacket struct { + sender *PRUDPClient + readStream *StreamIn + sourceStreamType uint8 + sourcePort uint8 + destinationStreamType uint8 + destinationPort uint8 + packetType uint16 + flags uint16 + sessionID uint8 + substreamID uint8 + signature []byte + sequenceID uint16 + connectionSignature []byte + fragmentID uint8 + payload []byte + message *RMCMessage +} + +// Sender returns the Client who sent the packet +func (p *PRUDPPacket) Sender() ClientInterface { + return p.sender +} + +// HasFlag checks if the packet has the given flag +func (p *PRUDPPacket) HasFlag(flag uint16) bool { + return p.flags&flag != 0 +} + +// AddFlag adds the given flag to the packet flag bitmask +func (p *PRUDPPacket) AddFlag(flag uint16) { + p.flags |= flag +} + +// SetType sets the packets type +func (p *PRUDPPacket) SetType(packetType uint16) { + p.packetType = packetType +} + +// Type returns the packets type +func (p *PRUDPPacket) Type() uint16 { + return p.packetType +} + +// SetSourceStreamType sets the packet virtual source stream type +func (p *PRUDPPacket) SetSourceStreamType(sourceStreamType uint8) { + p.sourceStreamType = sourceStreamType +} + +// SourceStreamType returns the packet virtual source stream type +func (p *PRUDPPacket) SourceStreamType() uint8 { + return p.sourceStreamType +} + +// SetSourcePort sets the packet virtual source stream type +func (p *PRUDPPacket) SetSourcePort(sourcePort uint8) { + p.sourcePort = sourcePort +} + +// SourcePort returns the packet virtual source stream type +func (p *PRUDPPacket) SourcePort() uint8 { + return p.sourcePort +} + +// SetDestinationStreamType sets the packet virtual destination stream type +func (p *PRUDPPacket) SetDestinationStreamType(destinationStreamType uint8) { + p.destinationStreamType = destinationStreamType +} + +// DestinationStreamType returns the packet virtual destination stream type +func (p *PRUDPPacket) DestinationStreamType() uint8 { + return p.destinationStreamType +} + +// SetDestinationPort sets the packet virtual destination port +func (p *PRUDPPacket) SetDestinationPort(destinationPort uint8) { + p.destinationPort = destinationPort +} + +// DestinationPort returns the packet virtual destination port +func (p *PRUDPPacket) DestinationPort() uint8 { + return p.destinationPort +} + +// SetSessionID sets the packets session ID +func (p *PRUDPPacket) SetSessionID(sessionID uint8) { + p.sessionID = sessionID +} + +// SubstreamID returns the packets substream ID +func (p *PRUDPPacket) SubstreamID() uint8 { + return p.substreamID +} + +// SetSubstreamID sets the packets substream ID +func (p *PRUDPPacket) SetSubstreamID(substreamID uint8) { + p.substreamID = substreamID +} + +func (p *PRUDPPacket) setSignature(signature []byte) { + p.signature = signature +} + +// SequenceID returns the packets sequenc ID +func (p *PRUDPPacket) SequenceID() uint16 { + return p.sequenceID +} + +// SetSequenceID sets the packets sequenc ID +func (p *PRUDPPacket) SetSequenceID(sequenceID uint16) { + p.sequenceID = sequenceID +} + +// Payload returns the packets payload +func (p *PRUDPPacket) Payload() []byte { + return p.payload +} + +// SetPayload sets the packets payload +func (p *PRUDPPacket) SetPayload(payload []byte) { + p.payload = payload +} + +func (p *PRUDPPacket) decryptPayload() []byte { + payload := p.payload + + // TODO - This assumes a reliable DATA packet. Handle unreliable here? Or do that in a different method? + if p.packetType == DataPacket { + substream := p.sender.reliableSubstream(p.SubstreamID()) + + payload = substream.Decrypt(payload) + } + + return payload +} + +func (p *PRUDPPacket) getConnectionSignature() []byte { + return p.connectionSignature +} + +func (p *PRUDPPacket) setConnectionSignature(connectionSignature []byte) { + p.connectionSignature = connectionSignature +} + +func (p *PRUDPPacket) getFragmentID() uint8 { + return p.fragmentID +} + +func (p *PRUDPPacket) setFragmentID(fragmentID uint8) { + p.fragmentID = fragmentID +} + +// RMCMessage returns the packets RMC Message +func (p *PRUDPPacket) RMCMessage() *RMCMessage { + return p.message +} + +// SetRMCMessage sets the packets RMC Message +func (p *PRUDPPacket) SetRMCMessage(message *RMCMessage) { + p.message = message +} diff --git a/packet_flags.go b/prudp_packet_flags.go similarity index 87% rename from packet_flags.go rename to prudp_packet_flags.go index 1832cc3c..147dd7b1 100644 --- a/packet_flags.go +++ b/prudp_packet_flags.go @@ -2,7 +2,7 @@ package nex const ( // FlagAck is the ID for the PRUDP Ack Flag - FlagAck uint16 = 0x1 + FlagAck uint16 = 0x1 // FlagReliable is the ID for the PRUDP Reliable Flag FlagReliable uint16 = 0x2 @@ -11,8 +11,8 @@ const ( FlagNeedsAck uint16 = 0x4 // FlagHasSize is the ID for the PRUDP HasSize Flag - FlagHasSize uint16 = 0x8 + FlagHasSize uint16 = 0x8 // FlagMultiAck is the ID for the PRUDP MultiAck Flag FlagMultiAck uint16 = 0x200 -) \ No newline at end of file +) diff --git a/prudp_packet_interface.go b/prudp_packet_interface.go new file mode 100644 index 00000000..5703cd8c --- /dev/null +++ b/prudp_packet_interface.go @@ -0,0 +1,40 @@ +package nex + +import "net" + +// PRUDPPacketInterface defines all the methods a PRUDP packet should have +type PRUDPPacketInterface interface { + Version() int + Bytes() []byte + Sender() ClientInterface + HasFlag(flag uint16) bool + AddFlag(flag uint16) + SetType(packetType uint16) + Type() uint16 + SetSourceStreamType(sourceStreamType uint8) + SourceStreamType() uint8 + SetSourcePort(sourcePort uint8) + SourcePort() uint8 + SetDestinationStreamType(destinationStreamType uint8) + DestinationStreamType() uint8 + SetDestinationPort(destinationPort uint8) + DestinationPort() uint8 + SetSessionID(sessionID uint8) + SubstreamID() uint8 + SetSubstreamID(substreamID uint8) + SequenceID() uint16 + SetSequenceID(sequenceID uint16) + Payload() []byte + SetPayload(payload []byte) + RMCMessage() *RMCMessage + SetRMCMessage(message *RMCMessage) + decode() error + setSignature(signature []byte) + calculateConnectionSignature(addr net.Addr) ([]byte, error) + calculateSignature(sessionKey, connectionSignature []byte) []byte + decryptPayload() []byte + getConnectionSignature() []byte + setConnectionSignature(connectionSignature []byte) + getFragmentID() uint8 + setFragmentID(fragmentID uint8) +} diff --git a/packet_types.go b/prudp_packet_types.go similarity index 93% rename from packet_types.go rename to prudp_packet_types.go index d33a21ad..fc690866 100644 --- a/packet_types.go +++ b/prudp_packet_types.go @@ -17,7 +17,7 @@ const ( PingPacket uint16 = 0x4 ) -var validTypes = map[uint16]bool{ +var validPacketTypes = map[uint16]bool{ SynPacket: true, ConnectPacket: true, DataPacket: true, diff --git a/prudp_packet_v0.go b/prudp_packet_v0.go new file mode 100644 index 00000000..3af9496a --- /dev/null +++ b/prudp_packet_v0.go @@ -0,0 +1,341 @@ +package nex + +import ( + "crypto/hmac" + "crypto/md5" + "encoding/binary" + "errors" + "fmt" + "net" + "slices" +) + +// PRUDPPacketV0 represents a PRUDPv0 packet +type PRUDPPacketV0 struct { + PRUDPPacket +} + +// Version returns the packets PRUDP version +func (p *PRUDPPacketV0) Version() int { + return 0 +} + +func (p *PRUDPPacketV0) decode() error { + // * Header is technically 11 bytes but checking for 12 includes the checksum + if p.readStream.Remaining() < 12 { + return errors.New("Failed to read PRUDPv0 header. Not have enough data") + } + + server := p.sender.server + start := p.readStream.ByteOffset() + + source, err := p.readStream.ReadUInt8() + if err != nil { + return fmt.Errorf("Failed to read PRUDPv0 source. %s", err.Error()) + } + + destination, err := p.readStream.ReadUInt8() + if err != nil { + return fmt.Errorf("Failed to read PRUDPv0 destination. %s", err.Error()) + } + + p.sourceStreamType = source >> 4 + p.sourcePort = source & 0xF + p.destinationStreamType = destination >> 4 + p.destinationPort = destination & 0xF + + if server.IsQuazalMode { + typeAndFlags, err := p.readStream.ReadUInt8() + if err != nil { + return fmt.Errorf("Failed to read PRUDPv0 type and flags. %s", err.Error()) + } + + p.flags = uint16(typeAndFlags >> 3) + p.packetType = uint16(typeAndFlags & 7) + } else { + typeAndFlags, err := p.readStream.ReadUInt16LE() + if err != nil { + return fmt.Errorf("Failed to read PRUDPv0 type and flags. %s", err.Error()) + } + + p.flags = typeAndFlags >> 4 + p.packetType = typeAndFlags & 0xF + } + + if _, ok := validPacketTypes[p.packetType]; !ok { + return errors.New("Invalid PRUDPv0 packet type") + } + + p.sessionID, err = p.readStream.ReadUInt8() + if err != nil { + return fmt.Errorf("Failed to read PRUDPv0 session ID. %s", err.Error()) + } + + p.signature = p.readStream.ReadBytesNext(4) + + p.sequenceID, err = p.readStream.ReadUInt16LE() + if err != nil { + return fmt.Errorf("Failed to read PRUDPv0 sequence ID. %s", err.Error()) + } + + if p.packetType == SynPacket || p.packetType == ConnectPacket { + if p.readStream.Remaining() < 4 { + return errors.New("Failed to read PRUDPv0 connection signature. Not have enough data") + } + + p.connectionSignature = p.readStream.ReadBytesNext(4) + } + + if p.packetType == DataPacket { + if p.readStream.Remaining() < 1 { + return errors.New("Failed to read PRUDPv0 fragment ID. Not have enough data") + } + + p.fragmentID, err = p.readStream.ReadUInt8() + if err != nil { + return fmt.Errorf("Failed to read PRUDPv0 fragment ID. %s", err.Error()) + } + } + + var payloadSize uint16 + + if p.HasFlag(FlagHasSize) { + if p.readStream.Remaining() < 2 { + return errors.New("Failed to read PRUDPv0 payload size. Not have enough data") + } + + payloadSize, err = p.readStream.ReadUInt16LE() + if err != nil { + return fmt.Errorf("Failed to read PRUDPv0 payload size. %s", err.Error()) + } + } else { + // * Quazal used a 4 byte checksum. NEX uses 1 byte + if server.IsQuazalMode { + payloadSize = uint16(p.readStream.Remaining() - 4) + } else { + payloadSize = uint16(p.readStream.Remaining() - 1) + } + } + + if p.readStream.Remaining() < int(payloadSize) { + return errors.New("Failed to read PRUDPv0 payload. Not have enough data") + } + + p.payload = p.readStream.ReadBytesNext(int64(payloadSize)) + + if server.IsQuazalMode && p.readStream.Remaining() < 4 { + return errors.New("Failed to read PRUDPv0 checksum. Not have enough data") + } else if p.readStream.Remaining() < 1 { + return errors.New("Failed to read PRUDPv0 checksum. Not have enough data") + } + + checksumData := p.readStream.Bytes()[start:p.readStream.ByteOffset()] + + var checksum uint32 + var checksumU8 uint8 + + if server.IsQuazalMode { + checksum, err = p.readStream.ReadUInt32LE() + } else { + checksumU8, err = p.readStream.ReadUInt8() + checksum = uint32(checksumU8) + } + + if err != nil { + return fmt.Errorf("Failed to read PRUDPv0 checksum. %s", err.Error()) + } + + calculatedChecksum := p.calculateChecksum(checksumData) + + if checksum != calculatedChecksum { + return errors.New("Invalid PRUDPv0 checksum") + } + + return nil +} + +// Bytes encodes a PRUDPv0 packet into a byte slice +func (p *PRUDPPacketV0) Bytes() []byte { + server := p.sender.server + stream := NewStreamOut(server) + + stream.WriteUInt8(p.sourcePort | (p.sourceStreamType << 4)) + stream.WriteUInt8(p.destinationPort | (p.destinationStreamType << 4)) + + if server.IsQuazalMode { + stream.WriteUInt8(uint8(p.packetType | (p.flags << 3))) + } else { + stream.WriteUInt16LE(p.packetType | (p.flags << 4)) + } + + stream.WriteUInt8(p.sessionID) + stream.Grow(int64(len(p.signature))) + stream.WriteBytesNext(p.signature) + stream.WriteUInt16LE(p.sequenceID) + + if p.packetType == SynPacket || p.packetType == ConnectPacket { + stream.Grow(int64(len(p.connectionSignature))) + stream.WriteBytesNext(p.connectionSignature) + } + + if p.packetType == DataPacket { + stream.WriteUInt8(p.fragmentID) + } + + if p.HasFlag(FlagHasSize) { + stream.WriteUInt16LE(uint16(len(p.payload))) + } + + if len(p.payload) > 0 { + stream.Grow(int64(len(p.payload))) + stream.WriteBytesNext(p.payload) + } + + checksum := p.calculateChecksum(stream.Bytes()) + + if server.IsQuazalMode { + stream.WriteUInt32LE(checksum) + } else { + stream.WriteUInt8(uint8(checksum)) + } + + return stream.Bytes() +} + +func (p *PRUDPPacketV0) calculateConnectionSignature(addr net.Addr) ([]byte, error) { + var ip net.IP + var port int + + switch v := addr.(type) { + case *net.UDPAddr: + ip = v.IP.To4() + port = v.Port + default: + return nil, fmt.Errorf("Unsupported network type: %T", addr) + } + + portBytes := make([]byte, 2) + binary.BigEndian.PutUint16(portBytes, uint16(port)) + + data := append(ip, portBytes...) + hash := md5.Sum(data) + signatureBytes := hash[:4] + + slices.Reverse(signatureBytes) + + return signatureBytes, nil +} + +func (p *PRUDPPacketV0) calculateSignature(sessionKey, connectionSignature []byte) []byte { + if p.packetType == DataPacket { + return p.calculateDataSignature(sessionKey) + } + + if p.packetType == DisconnectPacket && p.sender.server.accessKey != "ridfebb9" { + return p.calculateDataSignature(sessionKey) + } + + if len(connectionSignature) != 0 { + return connectionSignature + } + + return make([]byte, 4) +} + +func (p *PRUDPPacketV0) calculateDataSignature(sessionKey []byte) []byte { + server := p.sender.server + data := p.payload + + if server.AccessKey() != "ridfebb9" { + header := []byte{0, 0, p.fragmentID} + binary.LittleEndian.PutUint16(header[:2], p.sequenceID) + + data = append(sessionKey, header...) + data = append(data, p.payload...) + } + + if len(data) > 0 { + key := md5.Sum([]byte(server.AccessKey())) + mac := hmac.New(md5.New, key[:]) + + mac.Write(data) + + digest := mac.Sum(nil) + + return digest[:4] + } + + return []byte{0x78, 0x56, 0x34, 0x12} +} + +func (p *PRUDPPacketV0) calculateChecksum(data []byte) uint32 { + server := p.sender.server + checksum := sum[byte, uint32]([]byte(server.AccessKey())) + + if server.IsQuazalMode { + padSize := (len(data) + 3) &^ 3 + data = append(data, make([]byte, padSize-len(data))...) + words := make([]uint32, len(data)/4) + + for i := 0; i < len(data)/4; i++ { + words[i] = binary.LittleEndian.Uint32(data[i*4 : (i+1)*4]) + } + + result := (checksum & 0xFF) + sum[uint32, uint32](words) + + return result & 0xFFFFFFFF + } else { + words := make([]uint32, len(data)/4) + + for i := 0; i < len(data)/4; i++ { + words[i] = binary.LittleEndian.Uint32(data[i*4 : (i+1)*4]) + } + + temp := sum[uint32, uint32](words) & 0xFFFFFFFF + + checksum += sum[byte, uint32](data[len(data)&^3:]) + + tempBytes := make([]byte, 4) + + binary.LittleEndian.PutUint32(tempBytes, temp) + + checksum += sum[byte, uint32](tempBytes) + + return checksum & 0xFF + } +} + +// NewPRUDPPacketV0 creates and returns a new PacketV0 using the provided Client and stream +func NewPRUDPPacketV0(client *PRUDPClient, readStream *StreamIn) (*PRUDPPacketV0, error) { + packet := &PRUDPPacketV0{ + PRUDPPacket: PRUDPPacket{ + sender: client, + readStream: readStream, + }, + } + + if readStream != nil { + err := packet.decode() + if err != nil { + return nil, fmt.Errorf("Failed to decode PRUDPv0 packet. %s", err.Error()) + } + } + + return packet, nil +} + +// NewPRUDPPacketsV0 reads all possible PRUDPv0 packets from the stream +func NewPRUDPPacketsV0(client *PRUDPClient, readStream *StreamIn) ([]PRUDPPacketInterface, error) { + packets := make([]PRUDPPacketInterface, 0) + + for readStream.Remaining() > 0 { + packet, err := NewPRUDPPacketV0(client, readStream) + if err != nil { + return packets, err + } + + packets = append(packets, packet) + } + + return packets, nil +} diff --git a/prudp_packet_v1.go b/prudp_packet_v1.go new file mode 100644 index 00000000..c606cf3e --- /dev/null +++ b/prudp_packet_v1.go @@ -0,0 +1,351 @@ +package nex + +import ( + "bytes" + "crypto/hmac" + "crypto/md5" + "encoding/binary" + "errors" + "fmt" + "net" +) + +// PRUDPPacketV1 represents a PRUDPv1 packet +type PRUDPPacketV1 struct { + PRUDPPacket + optionsLength uint8 + payloadLength uint16 + minorVersion uint32 + supportedFunctions uint32 + maximumSubstreamID uint8 + initialUnreliableSequenceID uint16 +} + +// Version returns the packets PRUDP version +func (p *PRUDPPacketV1) Version() int { + return 1 +} + +// Decode parses the packets data +func (p *PRUDPPacketV1) decode() error { + if p.readStream.Remaining() < 2 { + return errors.New("Failed to read PRUDPv1 magic. Not have enough data") + } + + magic := p.readStream.ReadBytesNext(2) + + if !bytes.Equal(magic, []byte{0xEA, 0xD0}) { + return fmt.Errorf("Invalid PRUDPv1 magic. Expected 0xEAD0, got 0x%x", magic) + } + + err := p.decodeHeader() + if err != nil { + return fmt.Errorf("Failed to decode PRUDPv1 header. %s", err.Error()) + } + + p.signature = p.readStream.ReadBytesNext(16) + + err = p.decodeOptions() + if err != nil { + return fmt.Errorf("Failed to decode PRUDPv1 options. %s", err.Error()) + } + + p.payload = p.readStream.ReadBytesNext(int64(p.payloadLength)) + + return nil +} + +// Bytes encodes a PRUDPv1 packet into a byte slice +func (p *PRUDPPacketV1) Bytes() []byte { + options := p.encodeOptions() + + p.optionsLength = uint8(len(options)) + + header := p.encodeHeader() + + stream := NewStreamOut(nil) + + stream.Grow(2) + stream.WriteBytesNext([]byte{0xEA, 0xD0}) + + stream.Grow(12) + stream.WriteBytesNext(header) + + stream.Grow(16) + stream.WriteBytesNext(p.signature) + + stream.Grow(int64(p.optionsLength)) + stream.WriteBytesNext(options) + + stream.Grow(int64(len(p.payload))) + stream.WriteBytesNext(p.payload) + + return stream.Bytes() +} + +func (p *PRUDPPacketV1) decodeHeader() error { + if p.readStream.Remaining() < 12 { + return errors.New("Failed to read PRUDPv1 magic. Not have enough data") + } + + version, err := p.readStream.ReadUInt8() + if err != nil { + return fmt.Errorf("Failed to decode PRUDPv1 version. %s", err.Error()) + } + + if version != 1 { + return fmt.Errorf("Invalid PRUDPv1 version. Expected 1, got %d", version) + } + + p.optionsLength, err = p.readStream.ReadUInt8() + if err != nil { + return fmt.Errorf("Failed to decode PRUDPv1 options length. %s", err.Error()) + } + + p.payloadLength, err = p.readStream.ReadUInt16LE() + if err != nil { + return fmt.Errorf("Failed to decode PRUDPv1 payload length. %s", err.Error()) + } + + source, err := p.readStream.ReadUInt8() + if err != nil { + return fmt.Errorf("Failed to read PRUDPv1 source. %s", err.Error()) + } + + destination, err := p.readStream.ReadUInt8() + if err != nil { + return fmt.Errorf("Failed to read PRUDPv1 destination. %s", err.Error()) + } + + p.sourceStreamType = source >> 4 + p.sourcePort = source & 0xF + p.destinationStreamType = destination >> 4 + p.destinationPort = destination & 0xF + + // TODO - Does QRV also encode it this way in PRUDPv1? + typeAndFlags, err := p.readStream.ReadUInt16LE() + if err != nil { + return fmt.Errorf("Failed to read PRUDPv1 type and flags. %s", err.Error()) + } + + p.flags = typeAndFlags >> 4 + p.packetType = typeAndFlags & 0xF + + if _, ok := validPacketTypes[p.packetType]; !ok { + return errors.New("Invalid PRUDPv1 packet type") + } + + p.sessionID, err = p.readStream.ReadUInt8() + if err != nil { + return fmt.Errorf("Failed to read PRUDPv1 session ID. %s", err.Error()) + } + + p.substreamID, err = p.readStream.ReadUInt8() + if err != nil { + return fmt.Errorf("Failed to read PRUDPv1 substream ID. %s", err.Error()) + } + + p.sequenceID, err = p.readStream.ReadUInt16LE() + if err != nil { + return fmt.Errorf("Failed to read PRUDPv1 sequence ID. %s", err.Error()) + } + + return nil +} + +func (p *PRUDPPacketV1) encodeHeader() []byte { + stream := NewStreamOut(nil) + + stream.WriteUInt8(1) // * Version + stream.WriteUInt8(p.optionsLength) + stream.WriteUInt16LE(uint16(len(p.payload))) + stream.WriteUInt8(p.sourcePort | (p.sourceStreamType << 4)) + stream.WriteUInt8(p.destinationPort | (p.destinationStreamType << 4)) + stream.WriteUInt16LE(p.packetType | (p.flags << 4)) // TODO - Does QRV also encode it this way in PRUDPv1? + stream.WriteUInt8(p.sessionID) + stream.WriteUInt8(p.substreamID) + stream.WriteUInt16LE(p.sequenceID) + + return stream.Bytes() +} + +func (p *PRUDPPacketV1) decodeOptions() error { + data := p.readStream.ReadBytesNext(int64(p.optionsLength)) + optionsStream := NewStreamIn(data, nil) + + for optionsStream.Remaining() > 0 { + optionID, err := optionsStream.ReadUInt8() + if err != nil { + return err + } + + _, err = optionsStream.ReadUInt8() // * Options size. We already know the size based on the ID, though + if err != nil { + return err + } + + if p.packetType == SynPacket || p.packetType == ConnectPacket { + if optionID == 0 { + p.supportedFunctions, err = optionsStream.ReadUInt32LE() + + p.minorVersion = p.supportedFunctions & 0xFF + p.supportedFunctions = p.supportedFunctions >> 8 + } + + if optionID == 1 { + p.connectionSignature = optionsStream.ReadBytesNext(16) + } + + if optionID == 4 { + p.maximumSubstreamID, err = optionsStream.ReadUInt8() + } + } + + if p.packetType == ConnectPacket { + if optionID == 3 { + p.initialUnreliableSequenceID, err = optionsStream.ReadUInt16LE() + } + } + + if p.packetType == DataPacket { + if optionID == 2 { + p.fragmentID, err = optionsStream.ReadUInt8() + } + } + + // * Only one option is processed at a time, so we can + // * just check for errors here rather than after EVERY + // * read + if err != nil { + return err + } + } + + return nil +} + +func (p *PRUDPPacketV1) encodeOptions() []byte { + optionsStream := NewStreamOut(nil) + + if p.packetType == SynPacket || p.packetType == ConnectPacket { + optionsStream.WriteUInt8(0) + optionsStream.WriteUInt8(4) + optionsStream.WriteUInt32LE(p.minorVersion | (p.supportedFunctions << 8)) + + optionsStream.WriteUInt8(1) + optionsStream.WriteUInt8(16) + optionsStream.Grow(16) + optionsStream.WriteBytesNext(p.connectionSignature) + + // * Encoded here for NintendoClients compatibility. + // * The order of these options should not matter, + // * however when NintendoClients calculates the + // * signature it does NOT use the original options + // * section, and instead re-encodes the data in a + // * specific order. Due to how this section is + // * parsed, though, order REALLY doesn't matter. + // * NintendoClients expects option 3 before 4, though + if p.packetType == ConnectPacket { + optionsStream.WriteUInt8(3) + optionsStream.WriteUInt8(2) + optionsStream.WriteUInt16LE(p.initialUnreliableSequenceID) + } + + optionsStream.WriteUInt8(4) + optionsStream.WriteUInt8(1) + optionsStream.WriteUInt8(p.maximumSubstreamID) + } + + if p.packetType == DataPacket { + optionsStream.WriteUInt8(2) + optionsStream.WriteUInt8(1) + optionsStream.WriteUInt8(p.fragmentID) + } + + return optionsStream.Bytes() +} + +func (p *PRUDPPacketV1) calculateConnectionSignature(addr net.Addr) ([]byte, error) { + var ip net.IP + var port int + + switch v := addr.(type) { + case *net.UDPAddr: + ip = v.IP.To4() + port = v.Port + default: + return nil, fmt.Errorf("Unsupported network type: %T", addr) + } + + // * The real client seems to not care about this. The original + // * server just used rand.Read here. This is done to implement + // * compatibility with NintendoClients, as this is how it + // * calculates PRUDPv1 connection signatures + key := []byte{0x26, 0xc3, 0x1f, 0x38, 0x1e, 0x46, 0xd6, 0xeb, 0x38, 0xe1, 0xaf, 0x6a, 0xb7, 0x0d, 0x11} + + portBytes := make([]byte, 2) + binary.BigEndian.PutUint16(portBytes, uint16(port)) + + data := append(ip, portBytes...) + hash := hmac.New(md5.New, key) + hash.Write(data) + + return hash.Sum(nil), nil +} + +func (p *PRUDPPacketV1) calculateSignature(sessionKey, connectionSignature []byte) []byte { + accessKeyBytes := []byte(p.sender.server.accessKey) + options := p.encodeOptions() + header := p.encodeHeader() + + accessKeySum := sum[byte, uint32](accessKeyBytes) + accessKeySumBytes := make([]byte, 4) + binary.LittleEndian.PutUint32(accessKeySumBytes, accessKeySum) + + key := md5.Sum(accessKeyBytes) + mac := hmac.New(md5.New, key[:]) + + mac.Write(header[4:]) + mac.Write(sessionKey) + mac.Write(accessKeySumBytes) + mac.Write(connectionSignature) + mac.Write(options) + mac.Write(p.payload) + + return mac.Sum(nil) +} + +// NewPRUDPPacketV1 creates and returns a new PacketV1 using the provided Client and stream +func NewPRUDPPacketV1(client *PRUDPClient, readStream *StreamIn) (*PRUDPPacketV1, error) { + packet := &PRUDPPacketV1{ + PRUDPPacket: PRUDPPacket{ + sender: client, + readStream: readStream, + }, + } + + if readStream != nil { + err := packet.decode() + if err != nil { + return nil, fmt.Errorf("Failed to decode PRUDPv1 packet. %s", err.Error()) + } + } + + return packet, nil +} + +// NewPRUDPPacketsV1 reads all possible PRUDPv1 packets from the stream +func NewPRUDPPacketsV1(client *PRUDPClient, readStream *StreamIn) ([]PRUDPPacketInterface, error) { + packets := make([]PRUDPPacketInterface, 0) + + for readStream.Remaining() > 0 { + packet, err := NewPRUDPPacketV1(client, readStream) + if err != nil { + return packets, err + } + + packets = append(packets, packet) + } + + return packets, nil +} diff --git a/prudp_server.go b/prudp_server.go new file mode 100644 index 00000000..3bfac681 --- /dev/null +++ b/prudp_server.go @@ -0,0 +1,697 @@ +package nex + +import ( + "bytes" + "errors" + "fmt" + "net" + "runtime" + "slices" + "time" +) + +// PRUDPServer represents a bare-bones PRUDP server +type PRUDPServer struct { + udpSocket *net.UDPConn + clients *MutexMap[string, *PRUDPClient] + PRUDPVersion int + IsQuazalMode bool + IsSecureServer bool + accessKey string + kerberosPassword []byte + kerberosTicketVersion int + kerberosKeySize int + FragmentSize int + version *LibraryVersion + datastoreProtocolVersion *LibraryVersion + matchMakingProtocolVersion *LibraryVersion + rankingProtocolVersion *LibraryVersion + ranking2ProtocolVersion *LibraryVersion + messagingProtocolVersion *LibraryVersion + utilityProtocolVersion *LibraryVersion + natTraversalProtocolVersion *LibraryVersion + eventHandlers map[string][]func(PacketInterface) + connectionIDCounter *Counter[uint32] + pingTimeout time.Duration +} + +// OnReliableData adds an event handler which is fired when a new reliable DATA packet is received +func (s *PRUDPServer) OnReliableData(handler func(PacketInterface)) { + s.on("reliable-data", handler) +} + +func (s *PRUDPServer) on(name string, handler func(PacketInterface)) { + if _, ok := s.eventHandlers[name]; !ok { + s.eventHandlers[name] = make([]func(PacketInterface), 0) + } + + s.eventHandlers[name] = append(s.eventHandlers[name], handler) +} + +func (s *PRUDPServer) emit(name string, packet PRUDPPacketInterface) { + if handlers, ok := s.eventHandlers[name]; ok { + for _, handler := range handlers { + go handler(packet) + } + } +} + +// Listen starts a PRUDP server on a given port +func (s *PRUDPServer) Listen(port int) { + udpAddress, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", port)) + if err != nil { + panic(err) + } + + socket, err := net.ListenUDP("udp", udpAddress) + if err != nil { + panic(err) + } + + s.udpSocket = socket + + quit := make(chan struct{}) + + for i := 0; i < runtime.NumCPU(); i++ { + go s.listenDatagram(quit) + } + + <-quit +} + +func (s *PRUDPServer) listenDatagram(quit chan struct{}) { + err := error(nil) + + for err == nil { + err = s.handleSocketMessage() + } + + quit <- struct{}{} + + panic(err) +} + +func (s *PRUDPServer) handleSocketMessage() error { + buffer := make([]byte, 64000) + + read, addr, err := s.udpSocket.ReadFromUDP(buffer) + if err != nil { + return err + } + + discriminator := addr.String() + + client, ok := s.clients.Get(discriminator) + + if !ok { + client = NewPRUDPClient(addr, s) + client.startHeartbeat() + + s.clients.Set(discriminator, client) + } + + packetData := buffer[:read] + readStream := NewStreamIn(packetData, s) + + var packets []PRUDPPacketInterface + + // * Support any packet type the client sends and respond + // * with that same type. Also keep reading from the stream + // * until no more data is left, to account for multiple + // * packets being sent at once + if bytes.Equal(packetData[:2], []byte{0xEA, 0xD0}) { + packets, _ = NewPRUDPPacketsV1(client, readStream) + } else { + packets, _ = NewPRUDPPacketsV0(client, readStream) + } + + for _, packet := range packets { + s.processPacket(packet) + } + + return nil +} + +func (s *PRUDPServer) processPacket(packet PRUDPPacketInterface) { + packet.Sender().(*PRUDPClient).resetHeartbeat() + + if packet.HasFlag(FlagAck) || packet.HasFlag(FlagMultiAck) { + s.handleAcknowledgment(packet) + return + } + + switch packet.Type() { + case SynPacket: + s.handleSyn(packet) + case ConnectPacket: + s.handleConnect(packet) + case DataPacket: + s.handleData(packet) + case DisconnectPacket: + s.handleDisconnect(packet) + case PingPacket: + s.handlePing(packet) + } +} + +func (s *PRUDPServer) handleAcknowledgment(packet PRUDPPacketInterface) { + if packet.HasFlag(FlagMultiAck) { + s.handleMultiAcknowledgment(packet) + return + } + + client := packet.Sender().(*PRUDPClient) + + substream := client.reliableSubstream(packet.SubstreamID()) + substream.ResendScheduler.AcknowledgePacket(packet.SequenceID()) +} + +func (s *PRUDPServer) handleMultiAcknowledgment(packet PRUDPPacketInterface) { + client := packet.Sender().(*PRUDPClient) + stream := NewStreamIn(packet.Payload(), s) + sequenceIDs := make([]uint16, 0) + var baseSequenceID uint16 + var substream *ReliablePacketSubstreamManager + + if packet.SubstreamID() == 1 { + // * New aggregate acknowledgment packets set this to 1 + // * and encode the real substream ID in in the payload + substreamID, _ := stream.ReadUInt8() + additionalIDsCount, _ := stream.ReadUInt8() + baseSequenceID, _ = stream.ReadUInt16LE() + substream = client.reliableSubstream(substreamID) + + for i := 0; i < int(additionalIDsCount); i++ { + additionalID, _ := stream.ReadUInt16LE() + sequenceIDs = append(sequenceIDs, additionalID) + } + } else { + // TODO - This is how Kinnay's client handles this, but it doesn't make sense for QRV? Since it can have multiple reliable substreams? + // * Old aggregate acknowledgment packets always use + // * substream 0 + substream = client.reliableSubstream(0) + baseSequenceID = packet.SequenceID() + + for stream.Remaining() > 0 { + additionalID, _ := stream.ReadUInt16LE() + sequenceIDs = append(sequenceIDs, additionalID) + } + } + + // * MutexMap.Each locks the mutex, can't remove while reading. + // * Have to just loop again + substream.ResendScheduler.packets.Each(func(sequenceID uint16, pending *PendingPacket) { + if sequenceID <= baseSequenceID && !slices.Contains(sequenceIDs, sequenceID) { + sequenceIDs = append(sequenceIDs, sequenceID) + } + }) + + // * Actually remove the packets from the pool + for _, sequenceID := range sequenceIDs { + substream.ResendScheduler.AcknowledgePacket(sequenceID) + } +} + +func (s *PRUDPServer) handleSyn(packet PRUDPPacketInterface) { + client := packet.Sender().(*PRUDPClient) + + var ack PRUDPPacketInterface + + if packet.Version() == 0 { + ack, _ = NewPRUDPPacketV0(client, nil) + } else { + ack, _ = NewPRUDPPacketV1(client, nil) + } + + connectionSignature, _ := packet.calculateConnectionSignature(client.address) + + client.reset() + client.clientConnectionSignature = connectionSignature + client.sourceStreamType = packet.SourceStreamType() + client.sourcePort = packet.SourcePort() + client.destinationStreamType = packet.DestinationStreamType() + client.destinationPort = packet.DestinationPort() + + ack.SetType(SynPacket) + ack.AddFlag(FlagAck) + ack.AddFlag(FlagHasSize) + ack.SetSourceStreamType(packet.DestinationStreamType()) + ack.SetSourcePort(packet.DestinationPort()) + ack.SetDestinationStreamType(packet.SourceStreamType()) + ack.SetDestinationPort(packet.SourcePort()) + ack.setConnectionSignature(connectionSignature) + ack.setSignature(ack.calculateSignature([]byte{}, []byte{})) + + s.emit("syn", ack) + + s.sendRaw(client.address, ack.Bytes()) +} + +func (s *PRUDPServer) handleConnect(packet PRUDPPacketInterface) { + client := packet.Sender().(*PRUDPClient) + + var ack PRUDPPacketInterface + + if packet.Version() == 0 { + ack, _ = NewPRUDPPacketV0(client, nil) + } else { + ack, _ = NewPRUDPPacketV1(client, nil) + } + + client.serverConnectionSignature = packet.getConnectionSignature() + + connectionSignature, _ := packet.calculateConnectionSignature(client.address) + + ack.SetType(ConnectPacket) + ack.AddFlag(FlagAck) + ack.AddFlag(FlagHasSize) + ack.SetSourceStreamType(packet.DestinationStreamType()) + ack.SetSourcePort(packet.DestinationPort()) + ack.SetDestinationStreamType(packet.SourceStreamType()) + ack.SetDestinationPort(packet.SourcePort()) + ack.setConnectionSignature(make([]byte, len(connectionSignature))) + ack.SetSessionID(0) + ack.SetSequenceID(1) + + if ack, ok := ack.(*PRUDPPacketV1); ok { + // * Just tell the client we support exactly what it wants + ack.maximumSubstreamID = packet.(*PRUDPPacketV1).maximumSubstreamID + ack.minorVersion = packet.(*PRUDPPacketV1).minorVersion + ack.supportedFunctions = packet.(*PRUDPPacketV1).supportedFunctions + + client.minorVersion = ack.minorVersion + client.supportedFunctions = ack.supportedFunctions + client.createReliableSubstreams(ack.maximumSubstreamID) + } else { + client.createReliableSubstreams(0) + } + + var payload []byte + + if s.IsSecureServer { + sessionKey, pid, checkValue, err := s.readKerberosTicket(packet.Payload()) + if err != nil { + fmt.Println(err) + } + + client.SetPID(pid) + client.setSessionKey(sessionKey) + + stream := NewStreamOut(s) + + // * The response value is a Buffer whose data contains + // * checkValue+1. This is just a lazy way of encoding + // * a Buffer type + stream.WriteUInt32LE(4) // * Buffer length + stream.WriteUInt32LE(checkValue + 1) // * Buffer data + + payload = stream.Bytes() + } else { + payload = make([]byte, 0) + } + + ack.SetPayload(payload) + ack.setSignature(ack.calculateSignature([]byte{}, packet.getConnectionSignature())) + + s.emit("connect", ack) + + s.sendRaw(client.address, ack.Bytes()) +} + +func (s *PRUDPServer) handleData(packet PRUDPPacketInterface) { + if packet.HasFlag(FlagReliable) { + s.handleReliable(packet) + } else { + s.handleUnreliable(packet) + } +} + +func (s *PRUDPServer) handleDisconnect(packet PRUDPPacketInterface) { + if packet.HasFlag(FlagNeedsAck) { + s.acknowledgePacket(packet) + } + + client := packet.Sender().(*PRUDPClient) + + client.cleanup() + s.clients.Delete(client.address.String()) + + s.emit("disconnect", packet) +} + +func (s *PRUDPServer) handlePing(packet PRUDPPacketInterface) { + if packet.HasFlag(FlagNeedsAck) { + s.acknowledgePacket(packet) + } +} + +func (s *PRUDPServer) readKerberosTicket(payload []byte) ([]byte, uint32, uint32, error) { + stream := NewStreamIn(payload, s) + + ticketData, err := stream.ReadBuffer() + if err != nil { + return nil, 0, 0, err + } + + requestData, err := stream.ReadBuffer() + if err != nil { + return nil, 0, 0, err + } + + serverKey := DeriveKerberosKey(2, s.kerberosPassword) + + ticket := NewKerberosTicketInternalData() + err = ticket.Decrypt(NewStreamIn(ticketData, s), serverKey) + if err != nil { + return nil, 0, 0, err + } + + ticketTime := ticket.Issued.Standard() + serverTime := time.Now().UTC() + + timeLimit := ticketTime.Add(time.Minute * 2) + if serverTime.After(timeLimit) { + return nil, 0, 0, errors.New("Kerberos ticket expired") + } + + sessionKey := ticket.SessionKey + kerberos := NewKerberosEncryption(sessionKey) + + decryptedRequestData, err := kerberos.Decrypt(requestData) + if err != nil { + return nil, 0, 0, err + } + + checkDataStream := NewStreamIn(decryptedRequestData, s) + + userPID, err := checkDataStream.ReadUInt32LE() + if err != nil { + return nil, 0, 0, err + } + + _, err = checkDataStream.ReadUInt32LE() // * CID of secure server station url + if err != nil { + return nil, 0, 0, err + } + + responseCheck, err := checkDataStream.ReadUInt32LE() + if err != nil { + return nil, 0, 0, err + } + + return sessionKey, userPID, responseCheck, nil +} + +func (s *PRUDPServer) acknowledgePacket(packet PRUDPPacketInterface) { + var ack PRUDPPacketInterface + + if packet.Version() == 0 { + ack, _ = NewPRUDPPacketV0(packet.Sender().(*PRUDPClient), nil) + } else { + ack, _ = NewPRUDPPacketV1(packet.Sender().(*PRUDPClient), nil) + } + + ack.SetType(packet.Type()) + ack.AddFlag(FlagAck) + ack.SetSourceStreamType(packet.DestinationStreamType()) + ack.SetSourcePort(packet.DestinationPort()) + ack.SetDestinationStreamType(packet.SourceStreamType()) + ack.SetDestinationPort(packet.SourcePort()) + ack.SetSequenceID(packet.SequenceID()) + ack.setFragmentID(packet.getFragmentID()) + ack.SetSubstreamID(packet.SubstreamID()) + + s.sendPacket(ack) + + // * Servers send the DISCONNECT ACK 3 times + if packet.Type() == DisconnectPacket { + s.sendPacket(ack) + s.sendPacket(ack) + } +} + +func (s *PRUDPServer) handleReliable(packet PRUDPPacketInterface) { + if packet.HasFlag(FlagNeedsAck) { + s.acknowledgePacket(packet) + } + + substream := packet.Sender().(*PRUDPClient).reliableSubstream(packet.SubstreamID()) + + for _, pendingPacket := range substream.Update(packet) { + if packet.Type() == DataPacket { + payload := substream.AddFragment(pendingPacket.decryptPayload()) + + if packet.getFragmentID() == 0 { + message := NewRMCMessage() + message.FromBytes(payload) + + substream.ResetFragmentedPayload() + + packet.SetRMCMessage(message) + + s.emit("reliable-data", packet) + } + } + } +} + +func (s *PRUDPServer) handleUnreliable(packet PRUDPPacketInterface) {} + +func (s *PRUDPServer) sendPing(client *PRUDPClient) { + var ping PRUDPPacketInterface + + if s.PRUDPVersion == 0 { + ping, _ = NewPRUDPPacketV0(client, nil) + } else { + ping, _ = NewPRUDPPacketV1(client, nil) + } + + ping.SetType(PingPacket) + ping.AddFlag(FlagNeedsAck) + ping.SetSourceStreamType(client.destinationStreamType) + ping.SetSourcePort(client.destinationPort) + ping.SetDestinationStreamType(client.sourceStreamType) + ping.SetDestinationPort(client.sourcePort) + ping.SetSubstreamID(0) + + s.sendPacket(ping) +} + +// Send sends the packet to the packets sender +func (s *PRUDPServer) Send(packet PacketInterface) { + if packet, ok := packet.(PRUDPPacketInterface); ok { + data := packet.Payload() + fragments := int(len(data) / s.FragmentSize) + + var fragmentID uint8 = 1 + for i := 0; i <= fragments; i++ { + if len(data) < s.FragmentSize { + packet.SetPayload(data) + packet.setFragmentID(0) + } else { + packet.SetPayload(data[:s.FragmentSize]) + packet.setFragmentID(fragmentID) + + data = data[s.FragmentSize:] + fragmentID++ + } + + s.sendPacket(packet) + } + } +} + +func (s *PRUDPServer) sendPacket(packet PRUDPPacketInterface) { + client := packet.Sender().(*PRUDPClient) + + if !packet.HasFlag(FlagAck) && !packet.HasFlag(FlagMultiAck) { + if packet.HasFlag(FlagReliable) { + substream := client.reliableSubstream(packet.SubstreamID()) + packet.SetSequenceID(substream.NextOutgoingSequenceID()) + } else if packet.Type() == DataPacket { + packet.SetSequenceID(client.nextOutgoingUnreliableSequenceID()) + } else if packet.Type() == PingPacket { + packet.SetSequenceID(client.nextOutgoingPingSequenceID()) + } else { + packet.SetSequenceID(0) + } + } + + if packet.Type() == DataPacket && !packet.HasFlag(FlagAck) && !packet.HasFlag(FlagMultiAck) { + if packet.HasFlag(FlagReliable) { + substream := client.reliableSubstream(packet.SubstreamID()) + packet.SetPayload(substream.Encrypt(packet.Payload())) + } + // TODO - Unreliable crypto + } + + packet.setSignature(packet.calculateSignature(client.sessionKey, client.serverConnectionSignature)) + + if packet.HasFlag(FlagReliable) && packet.HasFlag(FlagNeedsAck) { + substream := client.reliableSubstream(packet.SubstreamID()) + substream.ResendScheduler.AddPacket(packet) + } + + s.sendRaw(packet.Sender().Address(), packet.Bytes()) +} + +// sendRaw will send the given address the provided packet +func (s *PRUDPServer) sendRaw(conn net.Addr, data []byte) { + s.udpSocket.WriteToUDP(data, conn.(*net.UDPAddr)) +} + +// AccessKey returns the servers sandbox access key +func (s *PRUDPServer) AccessKey() string { + return s.accessKey +} + +// SetAccessKey sets the servers sandbox access key +func (s *PRUDPServer) SetAccessKey(accessKey string) { + s.accessKey = accessKey +} + +// KerberosPassword returns the server kerberos password +func (s *PRUDPServer) KerberosPassword() []byte { + return s.kerberosPassword +} + +// SetKerberosPassword sets the server kerberos password +func (s *PRUDPServer) SetKerberosPassword(kerberosPassword []byte) { + s.kerberosPassword = kerberosPassword +} + +// SetFragmentSize sets the max size for a packets payload +func (s *PRUDPServer) SetFragmentSize(fragmentSize int) { + // TODO - Derive this value from the MTU + // * From the wiki: + // * + // * The fragment size depends on the implementation. + // * It is generally set to the MTU minus the packet overhead. + // * + // * In old NEX versions, which only support PRUDP v0, the MTU is + // * hardcoded to 1000 and the maximum payload size seems to be 962 bytes. + // * + // * Later, the MTU was increased to 1364, and the maximum payload + // * size is seems to be 1300 bytes, unless PRUDP v0 is used, in which case it’s 1264 bytes. + s.FragmentSize = fragmentSize +} + +// SetKerberosTicketVersion sets the version used when handling kerberos tickets +func (s *PRUDPServer) SetKerberosTicketVersion(kerberosTicketVersion int) { + s.kerberosTicketVersion = kerberosTicketVersion +} + +// KerberosKeySize gets the size for the kerberos session key +func (s *PRUDPServer) KerberosKeySize() int { + return s.kerberosKeySize +} + +// SetKerberosKeySize sets the size for the kerberos session key +func (s *PRUDPServer) SetKerberosKeySize(kerberosKeySize int) { + s.kerberosKeySize = kerberosKeySize +} + +// LibraryVersion returns the server NEX version +func (s *PRUDPServer) LibraryVersion() *LibraryVersion { + return s.version +} + +// SetDefaultLibraryVersion sets the default NEX protocol versions +func (s *PRUDPServer) SetDefaultLibraryVersion(version *LibraryVersion) { + s.version = version + s.datastoreProtocolVersion = version.Copy() + s.matchMakingProtocolVersion = version.Copy() + s.rankingProtocolVersion = version.Copy() + s.ranking2ProtocolVersion = version.Copy() + s.messagingProtocolVersion = version.Copy() + s.utilityProtocolVersion = version.Copy() + s.natTraversalProtocolVersion = version.Copy() +} + +// DataStoreProtocolVersion returns the servers DataStore protocol version +func (s *PRUDPServer) DataStoreProtocolVersion() *LibraryVersion { + return s.datastoreProtocolVersion +} + +// SetDataStoreProtocolVersion sets the servers DataStore protocol version +func (s *PRUDPServer) SetDataStoreProtocolVersion(version *LibraryVersion) { + s.datastoreProtocolVersion = version +} + +// MatchMakingProtocolVersion returns the servers MatchMaking protocol version +func (s *PRUDPServer) MatchMakingProtocolVersion() *LibraryVersion { + return s.matchMakingProtocolVersion +} + +// SetMatchMakingProtocolVersion sets the servers MatchMaking protocol version +func (s *PRUDPServer) SetMatchMakingProtocolVersion(version *LibraryVersion) { + s.matchMakingProtocolVersion = version +} + +// RankingProtocolVersion returns the servers Ranking protocol version +func (s *PRUDPServer) RankingProtocolVersion() *LibraryVersion { + return s.rankingProtocolVersion +} + +// SetRankingProtocolVersion sets the servers Ranking protocol version +func (s *PRUDPServer) SetRankingProtocolVersion(version *LibraryVersion) { + s.rankingProtocolVersion = version +} + +// Ranking2ProtocolVersion returns the servers Ranking2 protocol version +func (s *PRUDPServer) Ranking2ProtocolVersion() *LibraryVersion { + return s.ranking2ProtocolVersion +} + +// SetRanking2ProtocolVersion sets the servers Ranking2 protocol version +func (s *PRUDPServer) SetRanking2ProtocolVersion(version *LibraryVersion) { + s.ranking2ProtocolVersion = version +} + +// MessagingProtocolVersion returns the servers Messaging protocol version +func (s *PRUDPServer) MessagingProtocolVersion() *LibraryVersion { + return s.messagingProtocolVersion +} + +// SetMessagingProtocolVersion sets the servers Messaging protocol version +func (s *PRUDPServer) SetMessagingProtocolVersion(version *LibraryVersion) { + s.messagingProtocolVersion = version +} + +// UtilityProtocolVersion returns the servers Utility protocol version +func (s *PRUDPServer) UtilityProtocolVersion() *LibraryVersion { + return s.utilityProtocolVersion +} + +// SetUtilityProtocolVersion sets the servers Utility protocol version +func (s *PRUDPServer) SetUtilityProtocolVersion(version *LibraryVersion) { + s.utilityProtocolVersion = version +} + +// SetNATTraversalProtocolVersion sets the servers NAT Traversal protocol version +func (s *PRUDPServer) SetNATTraversalProtocolVersion(version *LibraryVersion) { + s.natTraversalProtocolVersion = version +} + +// NATTraversalProtocolVersion returns the servers NAT Traversal protocol version +func (s *PRUDPServer) NATTraversalProtocolVersion() *LibraryVersion { + return s.natTraversalProtocolVersion +} + +// ConnectionIDCounter returns the servers CID counter +func (s *PRUDPServer) ConnectionIDCounter() *Counter[uint32] { + return s.connectionIDCounter +} + +// NewPRUDPServer will return a new PRUDP server +func NewPRUDPServer() *PRUDPServer { + return &PRUDPServer{ + clients: NewMutexMap[string, *PRUDPClient](), + IsQuazalMode: false, + kerberosKeySize: 32, + FragmentSize: 1300, + eventHandlers: make(map[string][]func(PacketInterface)), + connectionIDCounter: NewCounter[uint32](10), + pingTimeout: time.Second * 15, + } +} diff --git a/reliable_packet_substream_manager.go b/reliable_packet_substream_manager.go new file mode 100644 index 00000000..972df322 --- /dev/null +++ b/reliable_packet_substream_manager.go @@ -0,0 +1,94 @@ +package nex + +import ( + "crypto/rc4" + "time" +) + +// ReliablePacketSubstreamManager represents a substream manager for reliable PRUDP packets +type ReliablePacketSubstreamManager struct { + packetMap *MutexMap[uint16, PRUDPPacketInterface] + incomingSequenceIDCounter *Counter[uint16] + outgoingSequenceIDCounter *Counter[uint16] + cipher *rc4.Cipher + decipher *rc4.Cipher + fragmentedPayload []byte + ResendScheduler *ResendScheduler +} + +// Update adds an incoming packet to the list of known packets and returns a list of packets to be processed in order +func (psm *ReliablePacketSubstreamManager) Update(packet PRUDPPacketInterface) []PRUDPPacketInterface { + packets := make([]PRUDPPacketInterface, 0) + + if packet.SequenceID() >= psm.incomingSequenceIDCounter.Value && !psm.packetMap.Has(packet.SequenceID()) { + psm.packetMap.Set(packet.SequenceID(), packet) + + for psm.packetMap.Has(psm.incomingSequenceIDCounter.Value) { + storedPacket, _ := psm.packetMap.Get(psm.incomingSequenceIDCounter.Value) + packets = append(packets, storedPacket) + psm.packetMap.Delete(psm.incomingSequenceIDCounter.Value) + psm.incomingSequenceIDCounter.Next() + } + } + + return packets +} + +// SetCipherKey sets the reliable substreams RC4 cipher keys +func (psm *ReliablePacketSubstreamManager) SetCipherKey(key []byte) { + cipher, _ := rc4.NewCipher(key) + decipher, _ := rc4.NewCipher(key) + + psm.cipher = cipher + psm.decipher = decipher +} + +// NextOutgoingSequenceID sets the reliable substreams RC4 cipher keys +func (psm *ReliablePacketSubstreamManager) NextOutgoingSequenceID() uint16 { + return psm.outgoingSequenceIDCounter.Next() +} + +// Decrypt decrypts the provided data with the substreams decipher +func (psm *ReliablePacketSubstreamManager) Decrypt(data []byte) []byte { + ciphered := make([]byte, len(data)) + + psm.decipher.XORKeyStream(ciphered, data) + + return ciphered +} + +// Encrypt encrypts the provided data with the substreams cipher +func (psm *ReliablePacketSubstreamManager) Encrypt(data []byte) []byte { + ciphered := make([]byte, len(data)) + + psm.cipher.XORKeyStream(ciphered, data) + + return ciphered +} + +// AddFragment adds the given fragment to the substreams fragmented payload +// Returns the current fragmented payload +func (psm *ReliablePacketSubstreamManager) AddFragment(fragment []byte) []byte { + psm.fragmentedPayload = append(psm.fragmentedPayload, fragment...) + + return psm.fragmentedPayload +} + +// ResetFragmentedPayload resets the substreams fragmented payload +func (psm *ReliablePacketSubstreamManager) ResetFragmentedPayload() { + psm.fragmentedPayload = make([]byte, 0) +} + +// NewReliablePacketSubstreamManager initializes a new ReliablePacketSubstreamManager with a starting counter value. +func NewReliablePacketSubstreamManager(startingIncomingSequenceID, startingOutgoingSequenceID uint16) *ReliablePacketSubstreamManager { + psm := &ReliablePacketSubstreamManager{ + packetMap: NewMutexMap[uint16, PRUDPPacketInterface](), + incomingSequenceIDCounter: NewCounter[uint16](startingIncomingSequenceID), + outgoingSequenceIDCounter: NewCounter[uint16](startingOutgoingSequenceID), + ResendScheduler: NewResendScheduler(5, time.Second, 0), + } + + psm.SetCipherKey([]byte("CD&ML")) + + return psm +} diff --git a/resend_scheduler.go b/resend_scheduler.go new file mode 100644 index 00000000..c4842759 --- /dev/null +++ b/resend_scheduler.go @@ -0,0 +1,124 @@ +package nex + +import ( + "time" +) + +// PendingPacket represends a packet scheduled to be resent +type PendingPacket struct { + packet PRUDPPacketInterface + lastSendTime time.Time + resendCount int + isAcknowledged bool + interval time.Duration + ticker *time.Ticker + rs *ResendScheduler +} + +func (pi *PendingPacket) startResendTimer() { + pi.lastSendTime = time.Now() + pi.ticker = time.NewTicker(pi.interval) + + for range pi.ticker.C { + if pi.isAcknowledged { + pi.ticker.Stop() + pi.rs.packets.Delete(pi.packet.SequenceID()) + } else { + pi.rs.resendPacket(pi) + } + } +} + +// ResendScheduler manages the resending of reliable PRUDP packets +type ResendScheduler struct { + packets *MutexMap[uint16, *PendingPacket] + MaxResendCount int + Interval time.Duration + Increase time.Duration +} + +// Stop kills the resend scheduler and stops all pending packets +func (rs *ResendScheduler) Stop() { + stillPending := make([]uint16, rs.packets.Size()) + + rs.packets.Each(func(sequenceID uint16, packet *PendingPacket) { + if !packet.isAcknowledged { + stillPending = append(stillPending, sequenceID) + } + }) + + for _, sequenceID := range stillPending { + if pendingPacket, ok := rs.packets.Get(sequenceID); ok { + pendingPacket.isAcknowledged = true // * Prevent an edge case where the ticker is already being processed + pendingPacket.ticker.Stop() + rs.packets.Delete(sequenceID) + } + } +} + +// AddPacket adds a packet to the scheduler and begins it's timer +func (rs *ResendScheduler) AddPacket(packet PRUDPPacketInterface) { + pendingPacket := &PendingPacket{ + packet: packet, + rs: rs, + interval: rs.Interval, + } + + rs.packets.Set(packet.SequenceID(), pendingPacket) + + go pendingPacket.startResendTimer() +} + +// AcknowledgePacket marks a pending packet as acknowledged. It will be ignored at the next resend attempt +func (rs *ResendScheduler) AcknowledgePacket(sequenceID uint16) { + if pendingPacket, ok := rs.packets.Get(sequenceID); ok { + pendingPacket.isAcknowledged = true + } +} + +func (rs *ResendScheduler) resendPacket(pendingPacket *PendingPacket) { + if pendingPacket.isAcknowledged { + // * Prevent a race condition where resendPacket may be called + // * at the same time a packet is acknowledged. Packet will be + // * handled properly at the next tick + return + } + + packet := pendingPacket.packet + client := packet.Sender().(*PRUDPClient) + + if pendingPacket.resendCount >= rs.MaxResendCount { + // * The maximum resend count has been reached, consider the client dead. + pendingPacket.ticker.Stop() + rs.packets.Delete(packet.SequenceID()) + client.cleanup() + client.server.clients.Delete(client.address.String()) + return + } + + if time.Since(pendingPacket.lastSendTime) >= rs.Interval { + // * Resend the packet to the client + server := client.server + data := packet.Bytes() + server.sendRaw(client.Address(), data) + + pendingPacket.interval += rs.Increase + pendingPacket.ticker.Reset(pendingPacket.interval) + pendingPacket.resendCount++ + pendingPacket.lastSendTime = time.Now() + } +} + +// NewResendScheduler creates a new ResendScheduler with the provided max resend count and interval and increase durations +// +// If increase is non-zero then every resend will have it's duration increased by that amount. For example an interval of +// 1 second and an increase of 5 seconds. The 1st resend happens after 1 second, the 2nd will take place 6 seconds +// after the 1st, and the 3rd will take place 11 seconds after the 2nd +func NewResendScheduler(maxResendCount int, interval, increase time.Duration) *ResendScheduler { + return &ResendScheduler{ + packets: NewMutexMap[uint16, *PendingPacket](), + MaxResendCount: maxResendCount, + Interval: interval, + Increase: increase, + } +} diff --git a/rmc.go b/rmc.go index fb752517..93075b19 100644 --- a/rmc.go +++ b/rmc.go @@ -5,223 +5,182 @@ import ( "fmt" ) -// TODO - We should probably combine RMCRequest and RMCResponse in a single RMCMessage for simpler packet payload setting/reading that supports both request and response payloads - -// RMCRequest represets a RMC request -type RMCRequest struct { - protocolID uint8 - customID uint16 - callID uint32 - methodID uint32 - parameters []byte -} - -// ProtocolID sets the RMC request protocolID -func (request *RMCRequest) ProtocolID() uint8 { - return request.protocolID -} - -// CustomID returns the RMC request custom ID -func (request *RMCRequest) CustomID() uint16 { - return request.customID -} - -// CallID sets the RMC request callID -func (request *RMCRequest) CallID() uint32 { - return request.callID -} - -// MethodID sets the RMC request methodID -func (request *RMCRequest) MethodID() uint32 { - return request.methodID -} - -// Parameters sets the RMC request parameters -func (request *RMCRequest) Parameters() []byte { - return request.parameters -} - -// SetCustomID sets the RMC request custom ID -func (request *RMCRequest) SetCustomID(customID uint16) { - request.customID = customID -} - -// SetProtocolID sets the RMC request protocol ID -func (request *RMCRequest) SetProtocolID(protocolID uint8) { - request.protocolID = protocolID -} - -// SetCallID sets the RMC request call ID -func (request *RMCRequest) SetCallID(callID uint32) { - request.callID = callID -} - -// SetMethodID sets the RMC request method ID -func (request *RMCRequest) SetMethodID(methodID uint32) { - request.methodID = methodID -} - -// SetParameters sets the RMC request parameters -func (request *RMCRequest) SetParameters(parameters []byte) { - request.parameters = parameters -} - -// NewRMCRequest returns a new blank RMCRequest -func NewRMCRequest() RMCRequest { - return RMCRequest{} -} - -// FromBytes converts a byte slice into a RMCRequest -func (request *RMCRequest) FromBytes(data []byte) error { - if len(data) < 13 { - return errors.New("[RMC] Data size less than minimum") - } - +// RMCMessage represents a message in the RMC (Remote Method Call) protocol +type RMCMessage struct { + IsRequest bool // * Indicates if the message is a request message (true) or response message (false) + IsSuccess bool // * Indicates if the message is a success message (true) for a response message + ProtocolID uint16 // * Protocol ID of the message + CallID uint32 // * Call ID associated with the message + MethodID uint32 // * Method ID in the requested protocol + ErrorCode uint32 // * Error code for a response message + Parameters []byte // * Input for the method +} + +// FromBytes decodes an RMCMessage from the given byte slice. +func (rmc *RMCMessage) FromBytes(data []byte) error { stream := NewStreamIn(data, nil) - size, err := stream.ReadUInt32LE() + length, err := stream.ReadUInt32LE() if err != nil { - return fmt.Errorf("Failed to read RMC Request size. %s", err.Error()) + return fmt.Errorf("Failed to read RMC Message size. %s", err.Error()) } - if int(size) != (len(data) - 4) { - return errors.New("RMC Request size does not match length of buffer") + if stream.Remaining() != int(length) { + return errors.New("RMC Message has unexpected size") } protocolID, err := stream.ReadUInt8() if err != nil { - return fmt.Errorf("Failed to read RMC Request protocol ID. %s", err.Error()) + return fmt.Errorf("Failed to read RMC Message protocol ID. %s", err.Error()) } - request.SetProtocolID(protocolID ^ 0x80) + rmc.ProtocolID = uint16(protocolID & ^byte(0x80)) - if request.ProtocolID() == 0x7f { - customID, err := stream.ReadUInt16LE() + if rmc.ProtocolID == 0x7F { + rmc.ProtocolID, err = stream.ReadUInt16LE() if err != nil { - return fmt.Errorf("Failed to read RMC Request custom protocol ID. %s", err.Error()) + return fmt.Errorf("Failed to read RMC Message extended protocol ID. %s", err.Error()) } - - request.SetCustomID(customID) } - callID, err := stream.ReadUInt32LE() - if err != nil { - return fmt.Errorf("Failed to read RMC Request call ID. %s", err.Error()) - } + if protocolID&0x80 != 0 { + rmc.IsRequest = true + rmc.CallID, err = stream.ReadUInt32LE() + if err != nil { + return fmt.Errorf("Failed to read RMC Message (request) call ID. %s", err.Error()) + } - request.SetCallID(callID) + rmc.MethodID, err = stream.ReadUInt32LE() + if err != nil { + return fmt.Errorf("Failed to read RMC Message (request) method ID. %s", err.Error()) + } - methodID, err := stream.ReadUInt32LE() - if err != nil { - return fmt.Errorf("Failed to read RMC Request method ID. %s", err.Error()) - } + rmc.Parameters = stream.ReadRemaining() + if err != nil { + return fmt.Errorf("Failed to read RMC Message (request) parameters. %s", err.Error()) + } + } else { + rmc.IsRequest = false + rmc.IsSuccess, err = stream.ReadBool() + if err != nil { + return fmt.Errorf("Failed to read RMC Message (response) error check. %s", err.Error()) + } - request.SetMethodID(methodID) - request.SetParameters(data[stream.ByteOffset():]) + if rmc.IsSuccess { + rmc.CallID, err = stream.ReadUInt32LE() + if err != nil { + return fmt.Errorf("Failed to read RMC Message (response) call ID. %s", err.Error()) + } + + rmc.MethodID, err = stream.ReadUInt32LE() + if err != nil { + return fmt.Errorf("Failed to read RMC Message (response) method ID. %s", err.Error()) + } + + rmc.MethodID = rmc.MethodID & ^uint32(0x8000) + if err != nil { + return fmt.Errorf("Failed to read RMC Message (response) method ID. %s", err.Error()) + } + + rmc.Parameters = stream.ReadRemaining() + if err != nil { + return fmt.Errorf("Failed to read RMC Message (response) parameters. %s", err.Error()) + } + + } else { + rmc.ErrorCode, err = stream.ReadUInt32LE() + if err != nil { + return fmt.Errorf("Failed to read RMC Message (response) error code. %s", err.Error()) + } + + rmc.CallID, err = stream.ReadUInt32LE() + if err != nil { + return fmt.Errorf("Failed to read RMC Message (response) call ID. %s", err.Error()) + } + + } + } return nil } -// Bytes converts a RMCRequest struct into a usable byte array -func (request *RMCRequest) Bytes() []byte { - body := NewStreamOut(nil) +// Bytes serializes the RMCMessage to a byte slice. +func (rmc *RMCMessage) Bytes() []byte { + stream := NewStreamOut(nil) - body.WriteUInt8(request.protocolID | 0x80) - if request.protocolID == 0x7f { - body.WriteUInt16LE(request.customID) + // * RMC requests have their protocol IDs ORed with 0x80 + var protocolIDFlag uint16 = 0x80 + if !rmc.IsRequest { + protocolIDFlag = 0 } - body.WriteUInt32LE(request.callID) - body.WriteUInt32LE(request.methodID) + if rmc.ProtocolID < 0x80 { + stream.WriteUInt8(uint8(rmc.ProtocolID | protocolIDFlag)) + } else { + stream.WriteUInt8(uint8(0x7F | protocolIDFlag)) + stream.WriteUInt16LE(rmc.ProtocolID) + } - if request.parameters != nil && len(request.parameters) > 0 { - body.Grow(int64(len(request.parameters))) - body.WriteBytesNext(request.parameters) + if rmc.IsRequest { + stream.WriteUInt32LE(rmc.CallID) + stream.WriteUInt32LE(rmc.MethodID) + stream.Grow(int64(len(rmc.Parameters))) + stream.WriteBytesNext(rmc.Parameters) + } else { + if rmc.IsSuccess { + stream.WriteBool(true) + stream.WriteUInt32LE(rmc.CallID) + stream.WriteUInt32LE(rmc.MethodID | 0x8000) + stream.Grow(int64(len(rmc.Parameters))) + stream.WriteBytesNext(rmc.Parameters) + } else { + stream.WriteBool(false) + stream.WriteUInt32LE(uint32(rmc.ErrorCode)) + stream.WriteUInt32LE(rmc.CallID) + } } - data := NewStreamOut(nil) + serialized := stream.Bytes() - data.WriteBuffer(body.Bytes()) + message := NewStreamOut(nil) - return data.Bytes() -} + message.WriteUInt32LE(uint32(len(serialized))) + message.Grow(int64(len(serialized))) + message.WriteBytesNext(serialized) -// RMCResponse represents a RMC response -type RMCResponse struct { - protocolID uint8 - customID uint16 - success uint8 - callID uint32 - methodID uint32 - data []byte - errorCode uint32 + return message.Bytes() } -// CustomID returns the RMC response customID -func (response *RMCResponse) CustomID() uint16 { - return response.customID +// NewRMCMessage returns a new generic RMC Message +func NewRMCMessage() *RMCMessage { + return &RMCMessage{} } -// SetCustomID sets the RMC response customID -func (response *RMCResponse) SetCustomID(customID uint16) { - response.customID = customID +// NewRMCRequest returns a new blank RMCRequest +func NewRMCRequest() RMCMessage { + return RMCMessage{IsRequest: true} } -// SetSuccess sets the RMCResponse payload to an instance of RMCSuccess -func (response *RMCResponse) SetSuccess(methodID uint32, data []byte) { - response.success = 1 - response.methodID = methodID - response.data = data +// NewRMCSuccess returns a new RMC Message configured as a success response +func NewRMCSuccess(parameters []byte) *RMCMessage { + message := NewRMCMessage() + message.IsRequest = false + message.IsSuccess = true + message.Parameters = parameters + + return message } -// SetError sets the RMCResponse payload to an instance of RMCError -func (response *RMCResponse) SetError(errorCode uint32) { +// NewRMCError returns a new RMC Message configured as a error response +func NewRMCError(errorCode uint32) *RMCMessage { if int(errorCode)&errorMask == 0 { errorCode = uint32(int(errorCode) | errorMask) } - response.success = 0 - response.errorCode = errorCode -} - -// Bytes converts a RMCResponse struct into a usable byte array -func (response *RMCResponse) Bytes() []byte { - body := NewStreamOut(nil) - - if response.protocolID > 0 { - body.WriteUInt8(response.protocolID) - if response.protocolID == 0x7f { - body.WriteUInt16LE(response.customID) - } - } - body.WriteUInt8(response.success) - - if response.success == 1 { - body.WriteUInt32LE(response.callID) - body.WriteUInt32LE(response.methodID | 0x8000) - - if response.data != nil && len(response.data) > 0 { - body.Grow(int64(len(response.data))) - body.WriteBytesNext(response.data) - } - } else { - body.WriteUInt32LE(response.errorCode) - body.WriteUInt32LE(response.callID) - } - - data := NewStreamOut(nil) - - data.WriteBuffer(body.Bytes()) - - return data.Bytes() -} - -// NewRMCResponse returns a new RMCResponse -func NewRMCResponse(protocolID uint8, callID uint32) RMCResponse { - response := RMCResponse{ - protocolID: protocolID, - callID: callID, - } + message := NewRMCMessage() + message.IsRequest = false + message.IsSuccess = false + message.ErrorCode = errorCode - return response + return message } diff --git a/sequence_id_manager.go b/sequence_id_manager.go deleted file mode 100644 index ca80bde6..00000000 --- a/sequence_id_manager.go +++ /dev/null @@ -1,29 +0,0 @@ -package nex - -// SequenceIDManager implements an API for managing the sequence IDs of different packet streams on a client -type SequenceIDManager struct { - reliableCounter *Counter // TODO - NEX only uses one reliable stream, but Rendezvous supports many. This needs to be a slice! - pingCounter *Counter - // TODO - Unreliable packets for Rendezvous -} - -// Next gets the next sequence ID for the packet. Returns 0 for an unsupported packet -func (s *SequenceIDManager) Next(packet PacketInterface) uint32 { - if packet.HasFlag(FlagReliable) { - return s.reliableCounter.Increment() - } - - if packet.Type() == PingPacket { - return s.pingCounter.Increment() - } - - return 0 -} - -// NewSequenceIDManager returns a new SequenceIDManager -func NewSequenceIDManager() *SequenceIDManager { - return &SequenceIDManager{ - reliableCounter: NewCounter(0), - pingCounter: NewCounter(0), - } -} diff --git a/server.go b/server.go deleted file mode 100644 index 3d3e138e..00000000 --- a/server.go +++ /dev/null @@ -1,1075 +0,0 @@ -// Package nex implements an API for creating bare-bones -// NEX servers and clients and provides the underlying -// PRUDP implementation -// -// No NEX protocols are implemented in this package. For -// NEX protocols see https://github.com/PretendoNetwork/nex-protocols-go -// -// No PIA code is implemented in this package -package nex - -import ( - "crypto/rand" - "fmt" - mrand "math/rand" - "net" - "net/http" - "runtime" - "strconv" - "time" - - "golang.org/x/exp/slices" -) - -// Server represents a PRUDP server -type Server struct { - socket *net.UDPConn - clients *MutexMap[string, *Client] - genericEventHandles map[string][]func(PacketInterface) - prudpV0EventHandles map[string][]func(*PacketV0) - prudpV1EventHandles map[string][]func(*PacketV1) - hppEventHandles map[string][]func(*HPPPacket) - hppClientResponses map[*Client](chan []byte) - passwordFromPIDHandler func(pid uint32) (string, uint32) - accessKey string - prudpVersion int - prudpProtocolMinorVersion int - supportedFunctions int - fragmentSize int16 - resendTimeout time.Duration - resendTimeoutIncrement time.Duration - resendMaxIterations int - pingTimeout int - kerberosPassword string - kerberosKeySize int - kerberosKeyDerivation int - kerberosTicketVersion int - connectionIDCounter *Counter - nexVersion *NEXVersion - datastoreProtocolVersion *NEXVersion - matchMakingProtocolVersion *NEXVersion - rankingProtocolVersion *NEXVersion - ranking2ProtocolVersion *NEXVersion - messagingProtocolVersion *NEXVersion - utilityProtocolVersion *NEXVersion - natTraversalProtocolVersion *NEXVersion - emuSendPacketDropPercent int - emuRecvPacketDropPercent int -} - -// Listen starts a NEX server on a given address -func (server *Server) Listen(address string) { - protocol := "udp" - - udpAddress, err := net.ResolveUDPAddr(protocol, address) - if err != nil { - panic(err) - } - - socket, err := net.ListenUDP(protocol, udpAddress) - if err != nil { - panic(err) - } - - server.SetSocket(socket) - - quit := make(chan struct{}) - - for i := 0; i < runtime.NumCPU(); i++ { - go server.listenDatagram(quit) - } - - logger.Success(fmt.Sprintf("PRUDP server listening on address - %s", udpAddress.String())) - - server.Emit("Listening", nil) - - <-quit -} - -func (server *Server) listenDatagram(quit chan struct{}) { - err := error(nil) - - for err == nil { - err = server.handleSocketMessage() - } - - quit <- struct{}{} - - panic(err) -} - -func (server *Server) handleSocketMessage() error { - var buffer [64000]byte - - socket := server.Socket() - - length, addr, err := socket.ReadFromUDP(buffer[0:]) - if err != nil { - return err - } - - if server.shouldDropPacket(true) { - // Emulate packet drop for debugging - return nil - } - - discriminator := addr.String() - - client, ok := server.clients.Get(discriminator) - - if !ok { - client = NewClient(addr, server) - - server.clients.Set(discriminator, client) - } - - data := buffer[0:length] - - var packet PacketInterface - - if server.PRUDPVersion() == 0 { - packet, err = NewPacketV0(client, data) - } else { - packet, err = NewPacketV1(client, data) - } - - if err != nil { - // TODO - Should this return the error too? - logger.Error(err.Error()) - return nil - } - - client.IncreasePingTimeoutTime(server.PingTimeout()) - - if packet.HasFlag(FlagAck) || packet.HasFlag(FlagMultiAck) { - // TODO - Should this return an error? - server.handleAcknowledgement(packet) - return nil - } - - if packet.HasFlag(FlagNeedsAck) { - if packet.Type() != ConnectPacket || (packet.Type() == ConnectPacket && len(packet.Payload()) <= 0) { - go server.AcknowledgePacket(packet, nil) - } - - if packet.Type() == DisconnectPacket { - go server.AcknowledgePacket(packet, nil) - go server.AcknowledgePacket(packet, nil) - } - } - - switch packet.Type() { - case PingPacket: - err := server.processPacket(packet) - if err != nil { - // TODO - Should this return the error too? - logger.Error(err.Error()) - return nil - } - default: - // TODO - Make a better API in client to access incomingPacketManager? - client.incomingPacketManager.Push(packet) - - // TODO - Make this API smarter. Only track missing packets and not all packets? - // * Keep processing packets so long as the next one is in the pool, - // * this way if several packets came in out of order they all get - // * processed at once the moment the correct next packet comes in - for next := client.incomingPacketManager.Next(); next != nil; { - err := server.processPacket(next) - if err != nil { - // TODO - Should this return the error too? - logger.Error(err.Error()) - return nil - } - - next = client.incomingPacketManager.Next() - } - } - - return nil -} - -func (server *Server) processPacket(packet PacketInterface) error { - err := packet.DecryptPayload() - if err != nil { - return err - } - - client := packet.Sender() - - if packet.HasFlag(FlagAck) || packet.HasFlag(FlagMultiAck) { - return nil - } - - switch packet.Type() { - case SynPacket: - // * PID should always be 0 when a fresh connection is made - if client.PID() != 0 { - // * Was connected before on the same device, using a different account - server.Emit("Disconnect", packet) // * Disconnect the old connection - } - err := client.Reset() - if err != nil { - // TODO - Should this return the error too? - logger.Error(err.Error()) - return nil - } - - client.SetConnected(true) - client.StartTimeoutTimer() - // TODO - Don't make this part suck ass? - // * Manually incrementing because the original manager gets destroyed in the reset - // * but we need to still track the SYN packet was sent - client.incomingPacketManager.currentSequenceID.Increment() - server.Emit("Syn", packet) - case ConnectPacket: - packet.Sender().SetClientConnectionSignature(packet.ConnectionSignature()) - - server.Emit("Connect", packet) - case DataPacket: - server.Emit("Data", packet) - case DisconnectPacket: - server.Emit("Disconnect", packet) - server.GracefulKick(client) - case PingPacket: - //server.SendPing(client) - server.Emit("Ping", packet) - } - - server.Emit("Packet", packet) - - return nil -} - -func (server *Server) handleAcknowledgement(packet PacketInterface) { - if packet.Version() == 0 || (packet.HasFlag(FlagAck) && !packet.HasFlag(FlagMultiAck)) { - packet.Sender().outgoingResendManager.Remove(packet.SequenceID()) - } else { - // TODO - Validate the aggregate packet is valid and can be processed - sequenceIDs := make([]uint16, 0) - stream := NewStreamIn(packet.Payload(), server) - var baseSequenceID uint16 - - // TODO - We should probably handle these errors lol - if server.PRUDPProtocolMinorVersion() >= 2 { - _, _ = stream.ReadUInt8() // * Substream ID. NEX always uses 0 - additionalIDsCount, _ := stream.ReadUInt8() - baseSequenceID, _ = stream.ReadUInt16LE() - - for i := 0; i < int(additionalIDsCount); i++ { - additionalID, _ := stream.ReadUInt16LE() - sequenceIDs = append(sequenceIDs, additionalID) - } - } else { - baseSequenceID = packet.SequenceID() - - for remaining := stream.Remaining(); remaining != 0; { - additionalID, _ := stream.ReadUInt16LE() - sequenceIDs = append(sequenceIDs, additionalID) - remaining = stream.Remaining() - } - } - - // * MutexMap.Each locks the mutex, can't remove while reading - // * Have to just loop again - packet.Sender().outgoingResendManager.pending.Each(func(sequenceID uint16, pending *PendingPacket) { - if sequenceID <= baseSequenceID && !slices.Contains(sequenceIDs, sequenceID) { - sequenceIDs = append(sequenceIDs, sequenceID) - } - }) - - // * Actually remove the packets from the pool - for _, sequenceID := range sequenceIDs { - packet.Sender().outgoingResendManager.Remove(sequenceID) - } - } -} - -// HPPListen starts a NEX HPP server on a given address -func (server *Server) HPPListen(address string) { - hppHandler := func(w http.ResponseWriter, req *http.Request) { - if req.Method != "POST" { - w.WriteHeader(http.StatusBadRequest) - return - } - - pidValue := req.Header.Get("pid") - if pidValue == "" { - logger.Error("[HPP] PID is empty") - w.WriteHeader(http.StatusBadRequest) - return - } - - token := req.Header.Get("token") - if token == "" { - logger.Error("[HPP] Token is empty") - w.WriteHeader(http.StatusBadRequest) - return - } - - accessKeySignature := req.Header.Get("signature1") - if accessKeySignature == "" { - logger.Error("[HPP] Access key signature is empty") - w.WriteHeader(http.StatusBadRequest) - return - } - - passwordSignature := req.Header.Get("signature2") - if passwordSignature == "" { - logger.Error("[HPP] Password signature is empty") - w.WriteHeader(http.StatusBadRequest) - return - } - - pid, err := strconv.Atoi(pidValue) - if err != nil { - // TODO - Should this return the error too? - logger.Error(err.Error()) - w.WriteHeader(http.StatusBadRequest) - return - } - - rmcRequestString := req.FormValue("file") - - rmcRequestBytes := []byte(rmcRequestString) - - client := NewClient(nil, server) - client.SetPID(uint32(pid)) - - hppPacket, err := NewHPPPacket(client, rmcRequestBytes) - if err != nil { - logger.Error(fmt.Sprintf("Failed to create new HPPPacket instance. %s", err.Error())) - w.WriteHeader(http.StatusBadRequest) - return - } - - hppPacket.SetAccessKeySignature(accessKeySignature) - hppPacket.SetPasswordSignature(passwordSignature) - - err = hppPacket.ValidateAccessKey() - if err != nil { - // TODO - Should this return the error too? - logger.Error(err.Error()) - w.WriteHeader(http.StatusBadRequest) - return - } - - err = hppPacket.ValidatePassword() - if err != nil { - logger.Error(err.Error()) - rmcRequest := hppPacket.RMCRequest() - callID := rmcRequest.CallID() - - errorResponse := NewRMCResponse(0, callID) - // * HPP returns PythonCore::ValidationError if password is missing or invalid - errorResponse.SetError(Errors.PythonCore.ValidationError) - - _, err = w.Write(errorResponse.Bytes()) - if err != nil { - // TODO - Should this return the error too? - logger.Error(err.Error()) - } - - return - } - - server.hppClientResponses[client] = make(chan []byte) - - server.Emit("Data", hppPacket) - - rmcResponseBytes := <-server.hppClientResponses[client] - - if len(rmcResponseBytes) > 0 { - _, err = w.Write(rmcResponseBytes) - if err != nil { - // TODO - Should this return the error too? - logger.Error(err.Error()) - } - } - - delete(server.hppClientResponses, client) - } - - http.HandleFunc("/hpp/", hppHandler) - - quit := make(chan struct{}) - - go server.handleHTTP(address, quit) - - logger.Success(fmt.Sprintf("HPP server listening on address - %s", address)) - - <-quit -} - -func (server *Server) handleHTTP(address string, quit chan struct{}) { - err := http.ListenAndServe(address, nil) - - quit <- struct{}{} - - panic(err) -} - -// On sets the data event handler -func (server *Server) On(event string, handler interface{}) { - // Check if the handler type matches one of the allowed types, and store the handler in it's allowed property - // Need to cast the handler to the correct function type before storing - switch handler := handler.(type) { - case func(PacketInterface): - server.genericEventHandles[event] = append(server.genericEventHandles[event], handler) - case func(*PacketV0): - server.prudpV0EventHandles[event] = append(server.prudpV0EventHandles[event], handler) - case func(*PacketV1): - server.prudpV1EventHandles[event] = append(server.prudpV1EventHandles[event], handler) - case func(*HPPPacket): - server.hppEventHandles[event] = append(server.hppEventHandles[event], handler) - } -} - -// Emit runs the given event handle -func (server *Server) Emit(event string, packet interface{}) { - - eventName := server.genericEventHandles[event] - for i := 0; i < len(eventName); i++ { - handler := eventName[i] - packet := packet.(PacketInterface) - go handler(packet) - } - - // Check if the packet type matches one of the allowed types and run the given handler - - switch packet := packet.(type) { - case *PacketV0: - eventName := server.prudpV0EventHandles[event] - for i := 0; i < len(eventName); i++ { - handler := eventName[i] - go handler(packet) - } - case *PacketV1: - eventName := server.prudpV1EventHandles[event] - for i := 0; i < len(eventName); i++ { - handler := eventName[i] - go handler(packet) - } - case *HPPPacket: - eventName := server.hppEventHandles[event] - for i := 0; i < len(eventName); i++ { - handler := eventName[i] - go handler(packet) - } - } -} - -// ClientConnected checks if a given client is stored on the server -func (server *Server) ClientConnected(client *Client) bool { - discriminator := client.Address().String() - - _, connected := server.clients.Get(discriminator) - - return connected -} - -// TimeoutKick removes a client from the server for inactivity -func (server *Server) TimeoutKick(client *Client) { - var packet PacketInterface - var err error - - if server.PRUDPVersion() == 0 { - packet, err = NewPacketV0(client, nil) - packet.SetVersion(0) - } else { - packet, err = NewPacketV1(client, nil) - packet.SetVersion(1) - } - - if err != nil { - // TODO - Should this return the error too? - logger.Error(err.Error()) - return - } - - packet.SetSource(0xA1) - packet.SetDestination(0xAF) - packet.SetType(DisconnectPacket) - - server.Send(packet) - - server.Emit("Kick", packet) - client.SetConnected(false) - discriminator := client.Address().String() - - client.outgoingResendManager.Clear() - server.clients.Delete(discriminator) -} - -// GracefulKick removes an active client from the server -func (server *Server) GracefulKick(client *Client) { - var packet PacketInterface - var err error - - if server.PRUDPVersion() == 0 { - packet, err = NewPacketV0(client, nil) - packet.SetVersion(0) - } else { - packet, err = NewPacketV1(client, nil) - packet.SetVersion(1) - } - - if err != nil { - // TODO - Should this return the error too? - logger.Error(err.Error()) - return - } - - packet.SetSource(0xA1) - packet.SetDestination(0xAF) - packet.SetType(DisconnectPacket) - - packet.AddFlag(FlagReliable) - - server.Send(packet) - - server.Emit("Kick", packet) - client.SetConnected(false) - client.StopTimeoutTimer() - discriminator := client.Address().String() - - client.outgoingResendManager.Clear() - server.clients.Delete(discriminator) -} - -// GracefulKickAll removes all clients from the server -func (server *Server) GracefulKickAll() { - // * https://stackoverflow.com/a/40456170 - server.clients.RLock() - defer server.clients.RUnlock() - // TODO - MAKE A BETTER API FOR RANGING OVER THIS DATA INSIDE MutexMap! - for _, client := range server.clients.real { - server.clients.RUnlock() - - var packet PacketInterface - var err error - if server.PRUDPVersion() == 0 { - packet, err = NewPacketV0(client, nil) - packet.SetVersion(0) - } else { - packet, err = NewPacketV1(client, nil) - packet.SetVersion(1) - } - - if err != nil { - // TODO - Should this return the error too? - logger.Error(err.Error()) - server.clients.RLock() - continue - } - - packet.SetSource(0xA1) - packet.SetDestination(0xAF) - packet.SetType(DisconnectPacket) - - packet.AddFlag(FlagReliable) - - server.Send(packet) - - server.Emit("Kick", packet) - client.SetConnected(false) - discriminator := client.Address().String() - - client.outgoingResendManager.Clear() - server.clients.Delete(discriminator) - - server.clients.RLock() - } -} - -// SendPing sends a ping packet to the given client -func (server *Server) SendPing(client *Client) { - var pingPacket PacketInterface - var err error - - if server.PRUDPVersion() == 0 { - pingPacket, err = NewPacketV0(client, nil) - } else { - pingPacket, err = NewPacketV1(client, nil) - } - - if err != nil { - // TODO - Should this return the error too? - logger.Error(err.Error()) - return - } - - pingPacket.SetSource(0xA1) - pingPacket.SetDestination(0xAF) - pingPacket.SetType(PingPacket) - pingPacket.AddFlag(FlagNeedsAck) - pingPacket.AddFlag(FlagReliable) - - server.Send(pingPacket) -} - -// AcknowledgePacket acknowledges that the given packet was recieved -func (server *Server) AcknowledgePacket(packet PacketInterface, payload []byte) { - sender := packet.Sender() - - var ackPacket PacketInterface - var err error - - if server.PRUDPVersion() == 0 { - ackPacket, err = NewPacketV0(sender, nil) - } else { - ackPacket, err = NewPacketV1(sender, nil) - } - - if err != nil { - // TODO - Should this return the error too? - logger.Error(err.Error()) - return - } - - ackPacket.SetSource(packet.Destination()) - ackPacket.SetDestination(packet.Source()) - ackPacket.SetType(packet.Type()) - ackPacket.SetSequenceID(packet.SequenceID()) - ackPacket.SetFragmentID(packet.FragmentID()) - ackPacket.AddFlag(FlagAck) - ackPacket.AddFlag(FlagHasSize) - - if payload != nil { - ackPacket.SetPayload(payload) - } - - if server.PRUDPVersion() == 1 { - packet := packet.(*PacketV1) - ackPacket := ackPacket.(*PacketV1) - - ackPacket.SetVersion(1) - ackPacket.SetSubstreamID(0) - ackPacket.AddFlag(FlagHasSize) - - if packet.Type() == SynPacket || packet.Type() == ConnectPacket { - ackPacket.SetPRUDPProtocolMinorVersion(packet.sender.PRUDPProtocolMinorVersion()) - //Going to leave this note here in case this causes issues later on, but for now, the below line breaks Splatoon and Minecraft Wii U (and probs other later games). - //ackPacket.SetSupportedFunctions(packet.sender.SupportedFunctions()) - ackPacket.SetMaximumSubstreamID(0) - } - - if packet.Type() == SynPacket { - serverConnectionSignature := make([]byte, 16) - _, err := rand.Read(serverConnectionSignature) - if err != nil { - // TODO - Should this return the error too? - logger.Error(err.Error()) - return - } - - ackPacket.Sender().SetServerConnectionSignature(serverConnectionSignature) - ackPacket.SetConnectionSignature(serverConnectionSignature) - } - - if packet.Type() == ConnectPacket { - ackPacket.SetConnectionSignature(make([]byte, 16)) - ackPacket.SetInitialSequenceID(10000) - } - - if packet.Type() == DataPacket { - // Aggregate acknowledgement - ackPacket.ClearFlag(FlagAck) - ackPacket.AddFlag(FlagMultiAck) - - payloadStream := NewStreamOut(server) - - // New version - if server.PRUDPProtocolMinorVersion() >= 2 { - ackPacket.SetSequenceID(0) - ackPacket.SetSubstreamID(1) - - // I'm lazy so just ack one packet - payloadStream.WriteUInt8(0) // substream ID - payloadStream.WriteUInt8(0) // length of additional sequence ids - payloadStream.WriteUInt16LE(packet.SequenceID()) // Sequence id - } - - ackPacket.SetPayload(payloadStream.Bytes()) - } - } - - data := ackPacket.Bytes() - - server.SendRaw(sender.Address(), data) -} - -// Socket returns the underlying server UDP socket -func (server *Server) Socket() *net.UDPConn { - return server.socket -} - -// SetSocket sets the underlying UDP socket -func (server *Server) SetSocket(socket *net.UDPConn) { - server.socket = socket -} - -// PRUDPVersion returns the server PRUDP version -func (server *Server) PRUDPVersion() int { - return server.prudpVersion -} - -// SetPRUDPVersion sets the server PRUDP version -func (server *Server) SetPRUDPVersion(prudpVersion int) { - server.prudpVersion = prudpVersion -} - -// PRUDPProtocolMinorVersion returns the server PRUDP minor version -func (server *Server) PRUDPProtocolMinorVersion() int { - return server.prudpProtocolMinorVersion -} - -// SetPRUDPProtocolMinorVersion sets the server PRUDP minor -func (server *Server) SetPRUDPProtocolMinorVersion(prudpProtocolMinorVersion int) { - server.prudpProtocolMinorVersion = prudpProtocolMinorVersion -} - -// NEXVersion returns the server NEX version -func (server *Server) NEXVersion() *NEXVersion { - return server.nexVersion -} - -// SetDefaultNEXVersion sets the default NEX protocol versions -func (server *Server) SetDefaultNEXVersion(nexVersion *NEXVersion) { - server.nexVersion = nexVersion - server.datastoreProtocolVersion = nexVersion.Copy() - server.matchMakingProtocolVersion = nexVersion.Copy() - server.rankingProtocolVersion = nexVersion.Copy() - server.ranking2ProtocolVersion = nexVersion.Copy() - server.messagingProtocolVersion = nexVersion.Copy() - server.utilityProtocolVersion = nexVersion.Copy() - server.natTraversalProtocolVersion = nexVersion.Copy() -} - -// DataStoreProtocolVersion returns the servers DataStore protocol version -func (server *Server) DataStoreProtocolVersion() *NEXVersion { - return server.datastoreProtocolVersion -} - -// SetDataStoreProtocolVersion sets the servers DataStore protocol version -func (server *Server) SetDataStoreProtocolVersion(nexVersion *NEXVersion) { - server.datastoreProtocolVersion = nexVersion -} - -// MatchMakingProtocolVersion returns the servers MatchMaking protocol version -func (server *Server) MatchMakingProtocolVersion() *NEXVersion { - return server.matchMakingProtocolVersion -} - -// SetMatchMakingProtocolVersion sets the servers MatchMaking protocol version -func (server *Server) SetMatchMakingProtocolVersion(nexVersion *NEXVersion) { - server.matchMakingProtocolVersion = nexVersion -} - -// RankingProtocolVersion returns the servers Ranking protocol version -func (server *Server) RankingProtocolVersion() *NEXVersion { - return server.rankingProtocolVersion -} - -// SetRankingProtocolVersion sets the servers Ranking protocol version -func (server *Server) SetRankingProtocolVersion(nexVersion *NEXVersion) { - server.rankingProtocolVersion = nexVersion -} - -// Ranking2ProtocolVersion returns the servers Ranking2 protocol version -func (server *Server) Ranking2ProtocolVersion() *NEXVersion { - return server.ranking2ProtocolVersion -} - -// SetRanking2ProtocolVersion sets the servers Ranking2 protocol version -func (server *Server) SetRanking2ProtocolVersion(nexVersion *NEXVersion) { - server.ranking2ProtocolVersion = nexVersion -} - -// MessagingProtocolVersion returns the servers Messaging protocol version -func (server *Server) MessagingProtocolVersion() *NEXVersion { - return server.messagingProtocolVersion -} - -// SetMessagingProtocolVersion sets the servers Messaging protocol version -func (server *Server) SetMessagingProtocolVersion(nexVersion *NEXVersion) { - server.messagingProtocolVersion = nexVersion -} - -// UtilityProtocolVersion returns the servers Utility protocol version -func (server *Server) UtilityProtocolVersion() *NEXVersion { - return server.utilityProtocolVersion -} - -// SetUtilityProtocolVersion sets the servers Utility protocol version -func (server *Server) SetUtilityProtocolVersion(nexVersion *NEXVersion) { - server.utilityProtocolVersion = nexVersion -} - -// SetNATTraversalProtocolVersion sets the servers NAT Traversal protocol version -func (server *Server) SetNATTraversalProtocolVersion(nexVersion *NEXVersion) { - server.natTraversalProtocolVersion = nexVersion -} - -// NATTraversalProtocolVersion returns the servers NAT Traversal protocol version -func (server *Server) NATTraversalProtocolVersion() *NEXVersion { - return server.natTraversalProtocolVersion -} - -// SupportedFunctions returns the supported PRUDP functions by the server -func (server *Server) SupportedFunctions() int { - return server.supportedFunctions -} - -// SetSupportedFunctions sets the supported PRUDP functions by the server -func (server *Server) SetSupportedFunctions(supportedFunctions int) { - server.supportedFunctions = supportedFunctions -} - -// AccessKey returns the server access key -func (server *Server) AccessKey() string { - return server.accessKey -} - -// SetAccessKey sets the server access key -func (server *Server) SetAccessKey(accessKey string) { - server.accessKey = accessKey -} - -// KerberosPassword returns the server kerberos password -func (server *Server) KerberosPassword() string { - return server.kerberosPassword -} - -// SetKerberosPassword sets the server kerberos password -func (server *Server) SetKerberosPassword(kerberosPassword string) { - server.kerberosPassword = kerberosPassword -} - -// KerberosKeySize returns the server kerberos key size -func (server *Server) KerberosKeySize() int { - return server.kerberosKeySize -} - -// SetKerberosKeySize sets the server kerberos key size -func (server *Server) SetKerberosKeySize(kerberosKeySize int) { - server.kerberosKeySize = kerberosKeySize -} - -// KerberosTicketVersion returns the server kerberos ticket contents version -func (server *Server) KerberosTicketVersion() int { - return server.kerberosTicketVersion -} - -// SetKerberosTicketVersion sets the server kerberos ticket contents version -func (server *Server) SetKerberosTicketVersion(ticketVersion int) { - server.kerberosTicketVersion = ticketVersion -} - -// PingTimeout returns the server ping timeout time in seconds -func (server *Server) PingTimeout() int { - return server.pingTimeout -} - -// SetPingTimeout sets the server ping timeout time in seconds -func (server *Server) SetPingTimeout(pingTimeout int) { - server.pingTimeout = pingTimeout -} - -// SetFragmentSize sets the packet fragment size -func (server *Server) SetFragmentSize(fragmentSize int16) { - server.fragmentSize = fragmentSize -} - -// SetResendTimeout sets the time that a packet should wait before resending to the client -func (server *Server) SetResendTimeout(resendTimeout time.Duration) { - server.resendTimeout = resendTimeout -} - -// SetResendTimeoutIncrement sets how much to increment the resendTimeout every time a packet is resent to the client -func (server *Server) SetResendTimeoutIncrement(resendTimeoutIncrement time.Duration) { - server.resendTimeoutIncrement = resendTimeoutIncrement -} - -// SetResendMaxIterations sets the max number of times a packet can try to resend before assuming the client is dead -func (server *Server) SetResendMaxIterations(resendMaxIterations int) { - server.resendMaxIterations = resendMaxIterations -} - -// ConnectionIDCounter gets the server connection ID counter -func (server *Server) ConnectionIDCounter() *Counter { - return server.connectionIDCounter -} - -// FindClientFromPID finds a client by their PID -func (server *Server) FindClientFromPID(pid uint32) *Client { - // * https://stackoverflow.com/a/40456170 - // TODO - MAKE A BETTER API FOR RANGING OVER THIS DATA INSIDE MutexMap! - server.clients.RLock() - for _, client := range server.clients.real { - server.clients.RUnlock() - if client.pid == pid { - return client - } - server.clients.RLock() - } - - server.clients.RUnlock() - - return nil -} - -// FindClientFromConnectionID finds a client by their Connection ID -func (server *Server) FindClientFromConnectionID(rvcid uint32) *Client { - // * https://stackoverflow.com/a/40456170 - // TODO - MAKE A BETTER API FOR RANGING OVER THIS DATA INSIDE MutexMap! - server.clients.RLock() - for _, client := range server.clients.real { - server.clients.RUnlock() - if client.connectionID == rvcid { - return client - } - server.clients.RLock() - } - - server.clients.RUnlock() - - return nil -} - -// SetPasswordFromPIDFunction sets the function for HPP or the auth server to get a NEX password using the PID -func (server *Server) SetPasswordFromPIDFunction(handler func(pid uint32) (string, uint32)) { - server.passwordFromPIDHandler = handler -} - -// PasswordFromPIDFunction returns the function for HPP or the auth server to get a NEX password using the PID -func (server *Server) PasswordFromPIDFunction() func(pid uint32) (string, uint32) { - return server.passwordFromPIDHandler -} - -// Send writes data to client -func (server *Server) Send(packet PacketInterface) { - switch packet := packet.(type) { - case *HPPPacket: - client := packet.Sender() - payload := packet.Payload() - server.hppClientResponses[client] <- payload - default: - data := packet.Payload() - fragments := int(int16(len(data)) / server.fragmentSize) - - var fragmentID uint8 = 1 - for i := 0; i <= fragments; i++ { - time.Sleep(time.Second / 2) - if int16(len(data)) < server.fragmentSize { - packet.SetPayload(data) - server.SendFragment(packet, 0) - } else { - packet.SetPayload(data[:server.fragmentSize]) - server.SendFragment(packet, fragmentID) - - data = data[server.fragmentSize:] - fragmentID++ - } - } - } - -} - -// SendFragment sends a packet fragment to the client -func (server *Server) SendFragment(packet PacketInterface, fragmentID uint8) { - client := packet.Sender() - payload := packet.Payload() - - if packet.Type() == DataPacket { - if packet.Version() == 0 && packet.HasFlag(FlagAck) { - // * v0 ACK payloads empty, ensure this - payload = []byte{} - } else if !packet.HasFlag(FlagMultiAck) { - if payload != nil || len(payload) > 0 { - payloadSize := len(payload) - - encrypted := make([]byte, payloadSize) - packet.Sender().Cipher().XORKeyStream(encrypted, payload) - - payload = encrypted - } - } - - // * Only add the HAS_SIZE flag if the payload exists - if !packet.HasFlag(FlagHasSize) && len(payload) > 0 { - packet.AddFlag(FlagHasSize) - } - } - - packet.SetFragmentID(fragmentID) - - packet.SetPayload(payload) - packet.SetSequenceID(uint16(client.SequenceIDOutManager().Next(packet))) - - encodedPacket := packet.Bytes() - - server.SendRaw(client.Address(), encodedPacket) - - if (packet.HasFlag(FlagReliable) || packet.Type() == SynPacket) && packet.HasFlag(FlagNeedsAck) { - packet.Sender().outgoingResendManager.Add(packet) - } -} - -// SendRaw writes raw packet data to the client socket -func (server *Server) SendRaw(conn *net.UDPAddr, data []byte) { - if server.shouldDropPacket(false) { - // Emulate packet drop for debugging - return - } - - _, err := server.Socket().WriteToUDP(data, conn) - if err != nil { - // TODO - Should this return the error too? - logger.Error(err.Error()) - } -} - -func (server *Server) shouldDropPacket(isRecv bool) bool { - if isRecv { - return server.emuRecvPacketDropPercent != 0 && mrand.Intn(100) < server.emuRecvPacketDropPercent - } else { - return server.emuSendPacketDropPercent != 0 && mrand.Intn(100) < server.emuSendPacketDropPercent - } -} - -// SetEmulatedPacketDropPercent sets the percentage of emulated sent and received dropped packets -func (server *Server) SetEmulatedPacketDropPercent(forRecv bool, percent int) { - if forRecv { - server.emuRecvPacketDropPercent = percent - } else { - server.emuSendPacketDropPercent = percent - } -} - -// NewServer returns a new NEX server -func NewServer() *Server { - server := &Server{ - genericEventHandles: make(map[string][]func(PacketInterface)), - prudpV0EventHandles: make(map[string][]func(*PacketV0)), - prudpV1EventHandles: make(map[string][]func(*PacketV1)), - hppEventHandles: make(map[string][]func(*HPPPacket)), - hppClientResponses: make(map[*Client](chan []byte)), - clients: NewMutexMap[string, *Client](), - prudpVersion: 1, - fragmentSize: 1300, - resendTimeout: time.Second, - resendTimeoutIncrement: 0, - resendMaxIterations: 5, - pingTimeout: 5, - kerberosKeySize: 32, - kerberosKeyDerivation: 0, - connectionIDCounter: NewCounter(10), - emuSendPacketDropPercent: 0, - emuRecvPacketDropPercent: 0, - } - - server.SetDefaultNEXVersion(NewNEXVersion(0, 0, 0)) - - return server -} diff --git a/server_interface.go b/server_interface.go new file mode 100644 index 00000000..b992d18e --- /dev/null +++ b/server_interface.go @@ -0,0 +1,17 @@ +package nex + +// ServerInterface defines all the methods a server should have regardless of type +type ServerInterface interface { + AccessKey() string + SetAccessKey(accessKey string) + LibraryVersion() *LibraryVersion + DataStoreProtocolVersion() *LibraryVersion + MatchMakingProtocolVersion() *LibraryVersion + RankingProtocolVersion() *LibraryVersion + Ranking2ProtocolVersion() *LibraryVersion + MessagingProtocolVersion() *LibraryVersion + UtilityProtocolVersion() *LibraryVersion + NATTraversalProtocolVersion() *LibraryVersion + SetDefaultLibraryVersion(version *LibraryVersion) + Send(packet PacketInterface) +} diff --git a/stream_in.go b/stream_in.go index 795d0e87..535c66e3 100644 --- a/stream_in.go +++ b/stream_in.go @@ -9,10 +9,10 @@ import ( crunch "github.com/superwhiskers/crunch/v3" ) -// StreamIn is an input stream abstraction of github.com/superwhiskers/crunch with nex type support +// StreamIn is an input stream abstraction of github.com/superwhiskers/crunch/v3 with nex type support type StreamIn struct { *crunch.Buffer - Server *Server + Server ServerInterface } // Remaining returns the amount of data left to be read in the buffer @@ -255,9 +255,7 @@ func (stream *StreamIn) ReadStructure(structure StructureInterface) (StructureIn } } - nexVersion := stream.Server.NEXVersion() - - if nexVersion.GreaterOrEqual("3.5.0") { + if stream.Server.LibraryVersion().GreaterOrEqual("3.5.0") { version, err := stream.ReadUInt8() if err != nil { return nil, fmt.Errorf("Failed to read NEX Structure version. %s", err.Error()) @@ -957,7 +955,7 @@ func (stream *StreamIn) ReadListDataHolder() ([]*DataHolder, error) { } // NewStreamIn returns a new NEX input stream -func NewStreamIn(data []byte, server *Server) *StreamIn { +func NewStreamIn(data []byte, server ServerInterface) *StreamIn { return &StreamIn{ Buffer: crunch.NewBuffer(data), Server: server, diff --git a/stream_out.go b/stream_out.go index cf1ee05d..50e66979 100644 --- a/stream_out.go +++ b/stream_out.go @@ -9,7 +9,7 @@ import ( // StreamOut is an abstraction of github.com/superwhiskers/crunch with nex type support type StreamOut struct { *crunch.Buffer - Server *Server + Server ServerInterface } // WriteBool writes a bool @@ -166,7 +166,7 @@ func (stream *StreamOut) WriteQBuffer(data []byte) { // WriteResult writes a NEX Result type func (stream *StreamOut) WriteResult(result *Result) { - stream.WriteUInt32LE(result.code) + stream.WriteUInt32LE(result.Code) } // WriteStructure writes a nex Structure type @@ -177,9 +177,7 @@ func (stream *StreamOut) WriteStructure(structure StructureInterface) { content := structure.Bytes(NewStreamOut(stream.Server)) - nexVersion := stream.Server.NEXVersion() - - if nexVersion.GreaterOrEqual("3.5.0") { + if stream.Server.LibraryVersion().GreaterOrEqual("3.5.0") { stream.WriteUInt8(structure.StructureVersion()) stream.WriteUInt32LE(uint32(len(content))) } @@ -480,7 +478,7 @@ func (stream *StreamOut) WriteMap(mapType interface{}) { } // NewStreamOut returns a new nex output stream -func NewStreamOut(server *Server) *StreamOut { +func NewStreamOut(server ServerInterface) *StreamOut { return &StreamOut{ Buffer: crunch.NewBuffer(), Server: server, diff --git a/sum.go b/sum.go index 116cd294..3255ff13 100644 --- a/sum.go +++ b/sum.go @@ -1,10 +1,11 @@ package nex -func sum(slice []byte) int { - total := 0 - for _, value := range slice { - total += int(value) - } +import "golang.org/x/exp/constraints" - return total +func sum[T, O constraints.Integer](data []T) O { + var result O + for _, b := range data { + result += O(b) + } + return result } diff --git a/test/auth.go b/test/auth.go new file mode 100644 index 00000000..b0ec6ebf --- /dev/null +++ b/test/auth.go @@ -0,0 +1,154 @@ +package main + +import ( + "fmt" + "strconv" + + "github.com/PretendoNetwork/nex-go" +) + +var authServer *nex.PRUDPServer + +func startAuthenticationServer() { + fmt.Println("Starting auth") + + authServer = nex.NewPRUDPServer() + + authServer.OnReliableData(func(packet nex.PacketInterface) { + if packet, ok := packet.(nex.PRUDPPacketInterface); ok { + request := packet.RMCMessage() + + fmt.Println("[AUTH]", request.ProtocolID, request.MethodID) + + if request.ProtocolID == 0xA { // * Ticket Granting + if request.MethodID == 0x1 { + login(packet) + } + + if request.MethodID == 0x3 { + requestTicket(packet) + } + } + } + }) + + authServer.SetFragmentSize(962) + //authServer.PRUDPVersion = 1 + authServer.SetDefaultLibraryVersion(nex.NewLibraryVersion(1, 1, 0)) + authServer.SetKerberosPassword([]byte("password")) + authServer.SetKerberosKeySize(16) + authServer.SetAccessKey("ridfebb9") + authServer.Listen(60000) +} + +func login(packet nex.PRUDPPacketInterface) { + request := packet.RMCMessage() + response := nex.NewRMCMessage() + + parameters := request.Parameters + + parametersStream := nex.NewStreamIn(parameters, authServer) + + strUserName, err := parametersStream.ReadString() + if err != nil { + panic(err) + } + + converted, err := strconv.Atoi(strUserName) + if err != nil { + panic(err) + } + + retval := nex.NewResultSuccess(0x00010001) + pidPrincipal := uint32(converted) + pbufResponse := generateTicket(pidPrincipal, 2) + pConnectionData := nex.NewRVConnectionData() + strReturnMsg := "Test Build" + + pConnectionData.SetStationURL("prudps:/address=192.168.1.98;port=60001;CID=1;PID=2;sid=1;stream=10;type=2") + pConnectionData.SetSpecialProtocols([]byte{}) + pConnectionData.SetStationURLSpecialProtocols("") + serverTime := nex.NewDateTime(0) + pConnectionData.SetTime(nex.NewDateTime(serverTime.UTC())) + + responseStream := nex.NewStreamOut(authServer) + + responseStream.WriteResult(retval) + responseStream.WriteUInt32LE(pidPrincipal) + responseStream.WriteBuffer(pbufResponse) + responseStream.WriteStructure(pConnectionData) + responseStream.WriteString(strReturnMsg) + + response.IsSuccess = true + response.IsRequest = false + response.ErrorCode = 0x00010001 + response.ProtocolID = request.ProtocolID + response.CallID = request.CallID + response.MethodID = request.MethodID + response.Parameters = responseStream.Bytes() + + responsePacket, _ := nex.NewPRUDPPacketV0(packet.Sender().(*nex.PRUDPClient), nil) + + responsePacket.SetType(packet.Type()) + responsePacket.AddFlag(nex.FlagHasSize) + responsePacket.AddFlag(nex.FlagReliable) + responsePacket.AddFlag(nex.FlagNeedsAck) + responsePacket.SetSourceStreamType(packet.DestinationStreamType()) + responsePacket.SetSourcePort(packet.DestinationPort()) + responsePacket.SetDestinationStreamType(packet.SourceStreamType()) + responsePacket.SetDestinationPort(packet.SourcePort()) + responsePacket.SetSubstreamID(packet.SubstreamID()) + responsePacket.SetPayload(response.Bytes()) + + authServer.Send(responsePacket) +} + +func requestTicket(packet nex.PRUDPPacketInterface) { + request := packet.RMCMessage() + response := nex.NewRMCMessage() + + parameters := request.Parameters + + parametersStream := nex.NewStreamIn(parameters, authServer) + + idSource, err := parametersStream.ReadUInt32LE() + if err != nil { + panic(err) + } + + idTarget, err := parametersStream.ReadUInt32LE() + if err != nil { + panic(err) + } + + retval := nex.NewResultSuccess(0x00010001) + pbufResponse := generateTicket(idSource, idTarget) + + responseStream := nex.NewStreamOut(authServer) + + responseStream.WriteResult(retval) + responseStream.WriteBuffer(pbufResponse) + + response.IsSuccess = true + response.IsRequest = false + response.ErrorCode = 0x00010001 + response.ProtocolID = request.ProtocolID + response.CallID = request.CallID + response.MethodID = request.MethodID + response.Parameters = responseStream.Bytes() + + responsePacket, _ := nex.NewPRUDPPacketV0(packet.Sender().(*nex.PRUDPClient), nil) + + responsePacket.SetType(packet.Type()) + responsePacket.AddFlag(nex.FlagHasSize) + responsePacket.AddFlag(nex.FlagReliable) + responsePacket.AddFlag(nex.FlagNeedsAck) + responsePacket.SetSourceStreamType(packet.DestinationStreamType()) + responsePacket.SetSourcePort(packet.DestinationPort()) + responsePacket.SetDestinationStreamType(packet.SourceStreamType()) + responsePacket.SetDestinationPort(packet.SourcePort()) + responsePacket.SetSubstreamID(packet.SubstreamID()) + responsePacket.SetPayload(response.Bytes()) + + authServer.Send(responsePacket) +} diff --git a/test/generate_ticket.go b/test/generate_ticket.go new file mode 100644 index 00000000..95951e7f --- /dev/null +++ b/test/generate_ticket.go @@ -0,0 +1,34 @@ +package main + +import ( + "crypto/rand" + + "github.com/PretendoNetwork/nex-go" +) + +func generateTicket(userPID uint32, targetPID uint32) []byte { + userKey := nex.DeriveKerberosKey(userPID, []byte("abcdefghijklmnop")) + targetKey := nex.DeriveKerberosKey(targetPID, []byte("password")) + sessionKey := make([]byte, authServer.KerberosKeySize()) + + rand.Read(sessionKey) + + ticketInternalData := nex.NewKerberosTicketInternalData() + serverTime := nex.NewDateTime(0) + serverTime.UTC() + + ticketInternalData.Issued = serverTime + ticketInternalData.SourcePID = userPID + ticketInternalData.SessionKey = sessionKey + + encryptedTicketInternalData, _ := ticketInternalData.Encrypt(targetKey, nex.NewStreamOut(authServer)) + + ticket := nex.NewKerberosTicket() + ticket.SessionKey = sessionKey + ticket.TargetPID = targetPID + ticket.InternalData = encryptedTicketInternalData + + encryptedTicket, _ := ticket.Encrypt(userKey, nex.NewStreamOut(authServer)) + + return encryptedTicket +} diff --git a/test/main.go b/test/main.go new file mode 100644 index 00000000..d559d002 --- /dev/null +++ b/test/main.go @@ -0,0 +1,14 @@ +package main + +import "sync" + +var wg sync.WaitGroup + +func main() { + wg.Add(2) + + go startAuthenticationServer() + go startSecureServer() + + wg.Wait() +} diff --git a/test/secure.go b/test/secure.go new file mode 100644 index 00000000..b587bf76 --- /dev/null +++ b/test/secure.go @@ -0,0 +1,254 @@ +package main + +import ( + "fmt" + "net" + + "github.com/PretendoNetwork/nex-go" +) + +var secureServer *nex.PRUDPServer + +// * Took these structs out of the protocols lib for convenience + +type PrincipalPreference struct { + nex.Structure + *nex.Data + ShowOnlinePresence bool + ShowCurrentTitle bool + BlockFriendRequests bool +} + +func (pp *PrincipalPreference) Bytes(stream *nex.StreamOut) []byte { + stream.WriteBool(pp.ShowOnlinePresence) + stream.WriteBool(pp.ShowCurrentTitle) + stream.WriteBool(pp.BlockFriendRequests) + + return stream.Bytes() +} + +type Comment struct { + nex.Structure + *nex.Data + Unknown uint8 + Contents string + LastChanged *nex.DateTime +} + +func (c *Comment) Bytes(stream *nex.StreamOut) []byte { + stream.WriteUInt8(c.Unknown) + stream.WriteString(c.Contents) + stream.WriteDateTime(c.LastChanged) + + return stream.Bytes() +} + +func startSecureServer() { + fmt.Println("Starting secure") + + secureServer = nex.NewPRUDPServer() + + secureServer.OnReliableData(func(packet nex.PacketInterface) { + if packet, ok := packet.(nex.PRUDPPacketInterface); ok { + request := packet.RMCMessage() + + fmt.Println("[SECR]", request.ProtocolID, request.MethodID) + + if request.ProtocolID == 0xB { // * Secure Connection + if request.MethodID == 0x4 { + registerEx(packet) + } + } + + if request.ProtocolID == 0x66 { // * Friends (WiiU) + if request.MethodID == 1 { + updateAndGetAllInformation(packet) + } else if request.MethodID == 19 { + checkSettingStatus(packet) + } else if request.MethodID == 13 { + updatePresence(packet) + } else { + panic(fmt.Sprintf("Unknown method %d", request.MethodID)) + } + } + } + }) + + secureServer.IsSecureServer = true + //secureServer.PRUDPVersion = 1 + secureServer.SetFragmentSize(962) + secureServer.SetDefaultLibraryVersion(nex.NewLibraryVersion(1, 1, 0)) + secureServer.SetKerberosPassword([]byte("password")) + secureServer.SetKerberosKeySize(16) + secureServer.SetAccessKey("ridfebb9") + secureServer.Listen(60001) +} + +func registerEx(packet nex.PRUDPPacketInterface) { + request := packet.RMCMessage() + response := nex.NewRMCMessage() + + parameters := request.Parameters + + parametersStream := nex.NewStreamIn(parameters, authServer) + + vecMyURLs, err := parametersStream.ReadListStationURL() + if err != nil { + panic(err) + } + + _, err = parametersStream.ReadDataHolder() + if err != nil { + fmt.Println(err) + } + + localStation := vecMyURLs[0] + + address := packet.Sender().Address().(*net.UDPAddr).IP.String() + + localStation.SetAddress(address) + localStation.SetPort(uint32(packet.Sender().Address().(*net.UDPAddr).Port)) + + retval := nex.NewResultSuccess(0x00010001) + localStationURL := localStation.EncodeToString() + + responseStream := nex.NewStreamOut(authServer) + + responseStream.WriteResult(retval) + responseStream.WriteUInt32LE(secureServer.ConnectionIDCounter().Next()) + responseStream.WriteString(localStationURL) + + response.IsSuccess = true + response.IsRequest = false + response.ErrorCode = 0x00010001 + response.ProtocolID = request.ProtocolID + response.CallID = request.CallID + response.MethodID = request.MethodID + response.Parameters = responseStream.Bytes() + + responsePacket, _ := nex.NewPRUDPPacketV0(packet.Sender().(*nex.PRUDPClient), nil) + + responsePacket.SetType(packet.Type()) + responsePacket.AddFlag(nex.FlagHasSize) + responsePacket.AddFlag(nex.FlagReliable) + responsePacket.AddFlag(nex.FlagNeedsAck) + responsePacket.SetSourceStreamType(packet.DestinationStreamType()) + responsePacket.SetSourcePort(packet.DestinationPort()) + responsePacket.SetDestinationStreamType(packet.SourceStreamType()) + responsePacket.SetDestinationPort(packet.SourcePort()) + responsePacket.SetSubstreamID(packet.SubstreamID()) + responsePacket.SetPayload(response.Bytes()) + + secureServer.Send(responsePacket) +} + +func updateAndGetAllInformation(packet nex.PRUDPPacketInterface) { + request := packet.RMCMessage() + response := nex.NewRMCMessage() + + principalPreference := &PrincipalPreference{ + ShowOnlinePresence: true, + ShowCurrentTitle: true, + BlockFriendRequests: false, + } + + comment := &Comment{ + Unknown: 0, + Contents: "Rewrite Test", + LastChanged: nex.NewDateTime(0), + } + + responseStream := nex.NewStreamOut(authServer) + + responseStream.WriteStructure(principalPreference) + responseStream.WriteStructure(comment) + responseStream.WriteUInt32LE(0) // * Stubbed empty list. responseStream.WriteListStructure(friendList) + responseStream.WriteUInt32LE(0) // * Stubbed empty list. responseStream.WriteListStructure(friendRequestsOut) + responseStream.WriteUInt32LE(0) // * Stubbed empty list. responseStream.WriteListStructure(friendRequestsIn) + responseStream.WriteUInt32LE(0) // * Stubbed empty list. responseStream.WriteListStructure(blockList) + responseStream.WriteBool(false) // * Unknown + responseStream.WriteUInt32LE(0) // * Stubbed empty list. responseStream.WriteListStructure(notifications) + responseStream.WriteBool(false) // * Unknown + + response.IsSuccess = true + response.IsRequest = false + response.ErrorCode = 0x00010001 + response.ProtocolID = request.ProtocolID + response.CallID = request.CallID + response.MethodID = request.MethodID + response.Parameters = responseStream.Bytes() + + responsePacket, _ := nex.NewPRUDPPacketV0(packet.Sender().(*nex.PRUDPClient), nil) + + responsePacket.SetType(packet.Type()) + responsePacket.AddFlag(nex.FlagHasSize) + responsePacket.AddFlag(nex.FlagReliable) + responsePacket.AddFlag(nex.FlagNeedsAck) + responsePacket.SetSourceStreamType(packet.DestinationStreamType()) + responsePacket.SetSourcePort(packet.DestinationPort()) + responsePacket.SetDestinationStreamType(packet.SourceStreamType()) + responsePacket.SetDestinationPort(packet.SourcePort()) + responsePacket.SetSubstreamID(packet.SubstreamID()) + responsePacket.SetPayload(response.Bytes()) + + secureServer.Send(responsePacket) +} + +func checkSettingStatus(packet nex.PRUDPPacketInterface) { + request := packet.RMCMessage() + response := nex.NewRMCMessage() + + responseStream := nex.NewStreamOut(authServer) + + responseStream.WriteUInt8(0) // * Unknown + + response.IsSuccess = true + response.IsRequest = false + response.ErrorCode = 0x00010001 + response.ProtocolID = request.ProtocolID + response.CallID = request.CallID + response.MethodID = request.MethodID + response.Parameters = responseStream.Bytes() + + responsePacket, _ := nex.NewPRUDPPacketV0(packet.Sender().(*nex.PRUDPClient), nil) + + responsePacket.SetType(packet.Type()) + responsePacket.AddFlag(nex.FlagHasSize) + responsePacket.AddFlag(nex.FlagReliable) + responsePacket.AddFlag(nex.FlagNeedsAck) + responsePacket.SetSourceStreamType(packet.DestinationStreamType()) + responsePacket.SetSourcePort(packet.DestinationPort()) + responsePacket.SetDestinationStreamType(packet.SourceStreamType()) + responsePacket.SetDestinationPort(packet.SourcePort()) + responsePacket.SetSubstreamID(packet.SubstreamID()) + responsePacket.SetPayload(response.Bytes()) + + secureServer.Send(responsePacket) +} + +func updatePresence(packet nex.PRUDPPacketInterface) { + request := packet.RMCMessage() + response := nex.NewRMCMessage() + + response.IsSuccess = true + response.IsRequest = false + response.ErrorCode = 0x00010001 + response.ProtocolID = request.ProtocolID + response.CallID = request.CallID + response.MethodID = request.MethodID + + responsePacket, _ := nex.NewPRUDPPacketV0(packet.Sender().(*nex.PRUDPClient), nil) + + responsePacket.SetType(packet.Type()) + responsePacket.AddFlag(nex.FlagHasSize) + responsePacket.AddFlag(nex.FlagReliable) + responsePacket.AddFlag(nex.FlagNeedsAck) + responsePacket.SetSourceStreamType(packet.DestinationStreamType()) + responsePacket.SetSourcePort(packet.DestinationPort()) + responsePacket.SetDestinationStreamType(packet.SourceStreamType()) + responsePacket.SetDestinationPort(packet.SourcePort()) + responsePacket.SetSubstreamID(packet.SubstreamID()) + responsePacket.SetPayload(response.Bytes()) + + secureServer.Send(responsePacket) +} diff --git a/types.go b/types.go index a1bb8f5d..d08e03f1 100644 --- a/types.go +++ b/types.go @@ -163,7 +163,6 @@ func (dataHolder *DataHolder) ExtractFromStream(stream *StreamIn) error { if dataType == nil { // TODO - Should we really log this here, or just pass the error to the caller? message := fmt.Sprintf("UNKNOWN DATAHOLDER TYPE: %s", dataHolder.typeName) - logger.Critical(message) return errors.New(message) } @@ -299,13 +298,11 @@ func (rvConnectionData *RVConnectionData) SetTime(time *DateTime) { // Bytes encodes the RVConnectionData and returns a byte array func (rvConnectionData *RVConnectionData) Bytes(stream *StreamOut) []byte { - nexVersion := stream.Server.NEXVersion() - stream.WriteString(rvConnectionData.stationURL) stream.WriteListUInt8(rvConnectionData.specialProtocols) stream.WriteString(rvConnectionData.stationURLSpecialProtocols) - if nexVersion.GreaterOrEqual("3.5.0") { + if stream.Server.LibraryVersion().GreaterOrEqual("3.5.0") { rvConnectionData.SetStructureVersion(1) stream.WriteDateTime(rvConnectionData.time) } @@ -958,17 +955,17 @@ func NewStationURL(str string) *StationURL { // Result is sent in methods which query large objects type Result struct { - code uint32 + Code uint32 } // IsSuccess returns true if the Result is a success func (result *Result) IsSuccess() bool { - return int(result.code)&errorMask == 0 + return int(result.Code)&errorMask == 0 } // IsError returns true if the Result is a error func (result *Result) IsError() bool { - return int(result.code)&errorMask != 0 + return int(result.Code)&errorMask != 0 } // ExtractFromStream extracts a Result structure from a stream @@ -978,26 +975,26 @@ func (result *Result) ExtractFromStream(stream *StreamIn) error { return fmt.Errorf("Failed to read Result code. %s", err.Error()) } - result.code = code + result.Code = code return nil } // Bytes encodes the Result and returns a byte array func (result *Result) Bytes(stream *StreamOut) []byte { - stream.WriteUInt32LE(result.code) + stream.WriteUInt32LE(result.Code) return stream.Bytes() } // Copy returns a new copied instance of Result func (result *Result) Copy() *Result { - return NewResult(result.code) + return NewResult(result.Code) } // Equals checks if the passed Structure contains the same data as the current instance func (result *Result) Equals(other *Result) bool { - return result.code == other.code + return result.Code == other.Code } // String returns a string representation of the struct @@ -1015,9 +1012,9 @@ func (result *Result) FormatToString(indentationLevel int) string { b.WriteString("Result{\n") if result.IsSuccess() { - b.WriteString(fmt.Sprintf("%scode: %d (success)\n", indentationValues, result.code)) + b.WriteString(fmt.Sprintf("%scode: %d (success)\n", indentationValues, result.Code)) } else { - b.WriteString(fmt.Sprintf("%scode: %d (error)\n", indentationValues, result.code)) + b.WriteString(fmt.Sprintf("%scode: %d (error)\n", indentationValues, result.Code)) } b.WriteString(fmt.Sprintf("%s}", indentationEnd)) From b4b55e217b911c84e6d4bc163458b6acd52dfce5 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Fri, 10 Nov 2023 21:18:15 -0500 Subject: [PATCH 002/178] added OnReliableData to server interface --- server_interface.go | 1 + 1 file changed, 1 insertion(+) diff --git a/server_interface.go b/server_interface.go index b992d18e..304b28ab 100644 --- a/server_interface.go +++ b/server_interface.go @@ -14,4 +14,5 @@ type ServerInterface interface { NATTraversalProtocolVersion() *LibraryVersion SetDefaultLibraryVersion(version *LibraryVersion) Send(packet PacketInterface) + OnReliableData(handler func(PacketInterface)) } From 10febb70ccc4f98459ed71f9dbc40dd42d93b069 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Fri, 10 Nov 2023 21:20:04 -0500 Subject: [PATCH 003/178] Added PasswordFromPID to PRUDP server --- prudp_server.go | 1 + 1 file changed, 1 insertion(+) diff --git a/prudp_server.go b/prudp_server.go index 3bfac681..792a058e 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -33,6 +33,7 @@ type PRUDPServer struct { eventHandlers map[string][]func(PacketInterface) connectionIDCounter *Counter[uint32] pingTimeout time.Duration + PasswordFromPID func(pid uint32) (string, uint32) } // OnReliableData adds an event handler which is fired when a new reliable DATA packet is received From 6985ea422a3a81d9ff0bdfb2c04208b5bc49627d Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Fri, 10 Nov 2023 22:48:23 -0500 Subject: [PATCH 004/178] export source and destination data on PRUDP client --- prudp_client.go | 16 ++++++++-------- prudp_server.go | 16 ++++++++-------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/prudp_client.go b/prudp_client.go index 83a21b74..bf67c436 100644 --- a/prudp_client.go +++ b/prudp_client.go @@ -18,10 +18,10 @@ type PRUDPClient struct { outgoingPingSequenceIDCounter *Counter[uint16] heartbeatTimer *time.Timer pingKickTimer *time.Timer - sourceStreamType uint8 - sourcePort uint8 - destinationStreamType uint8 - destinationPort uint8 + SourceStreamType uint8 + SourcePort uint8 + DestinationStreamType uint8 + DestinationPort uint8 minorVersion uint32 // * Not currently used for anything, but maybe useful later? supportedFunctions uint32 // * Not currently used for anything, but maybe useful later? } @@ -38,10 +38,10 @@ func (c *PRUDPClient) reset() { c.reliableSubstreams = make([]*ReliablePacketSubstreamManager, 0) c.outgoingUnreliableSequenceIDCounter = NewCounter[uint16](0) c.outgoingPingSequenceIDCounter = NewCounter[uint16](0) - c.sourceStreamType = 0 - c.sourcePort = 0 - c.destinationStreamType = 0 - c.destinationPort = 0 + c.SourceStreamType = 0 + c.SourcePort = 0 + c.DestinationStreamType = 0 + c.DestinationPort = 0 } // Cleanup cleans up any resources the client may be using diff --git a/prudp_server.go b/prudp_server.go index 792a058e..783e1bdc 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -228,10 +228,10 @@ func (s *PRUDPServer) handleSyn(packet PRUDPPacketInterface) { client.reset() client.clientConnectionSignature = connectionSignature - client.sourceStreamType = packet.SourceStreamType() - client.sourcePort = packet.SourcePort() - client.destinationStreamType = packet.DestinationStreamType() - client.destinationPort = packet.DestinationPort() + client.SourceStreamType = packet.SourceStreamType() + client.SourcePort = packet.SourcePort() + client.DestinationStreamType = packet.DestinationStreamType() + client.DestinationPort = packet.DestinationPort() ack.SetType(SynPacket) ack.AddFlag(FlagAck) @@ -469,10 +469,10 @@ func (s *PRUDPServer) sendPing(client *PRUDPClient) { ping.SetType(PingPacket) ping.AddFlag(FlagNeedsAck) - ping.SetSourceStreamType(client.destinationStreamType) - ping.SetSourcePort(client.destinationPort) - ping.SetDestinationStreamType(client.sourceStreamType) - ping.SetDestinationPort(client.sourcePort) + ping.SetSourceStreamType(client.DestinationStreamType) + ping.SetSourcePort(client.DestinationPort) + ping.SetDestinationStreamType(client.SourceStreamType) + ping.SetDestinationPort(client.SourcePort) ping.SetSubstreamID(0) s.sendPacket(ping) From f6782dab70465be52e01e771054f6f3ecd97222d Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Fri, 10 Nov 2023 22:50:23 -0500 Subject: [PATCH 005/178] only write RMC parameters if they exist --- rmc.go | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/rmc.go b/rmc.go index 93075b19..80474e97 100644 --- a/rmc.go +++ b/rmc.go @@ -124,15 +124,19 @@ func (rmc *RMCMessage) Bytes() []byte { if rmc.IsRequest { stream.WriteUInt32LE(rmc.CallID) stream.WriteUInt32LE(rmc.MethodID) - stream.Grow(int64(len(rmc.Parameters))) - stream.WriteBytesNext(rmc.Parameters) + if rmc.Parameters != nil && len(rmc.Parameters) > 0 { + stream.Grow(int64(len(rmc.Parameters))) + stream.WriteBytesNext(rmc.Parameters) + } } else { if rmc.IsSuccess { stream.WriteBool(true) stream.WriteUInt32LE(rmc.CallID) stream.WriteUInt32LE(rmc.MethodID | 0x8000) - stream.Grow(int64(len(rmc.Parameters))) - stream.WriteBytesNext(rmc.Parameters) + if rmc.Parameters != nil && len(rmc.Parameters) > 0 { + stream.Grow(int64(len(rmc.Parameters))) + stream.WriteBytesNext(rmc.Parameters) + } } else { stream.WriteBool(false) stream.WriteUInt32LE(uint32(rmc.ErrorCode)) From abee5158148f9999fe7adbafea2c993c86c3ab16 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sat, 11 Nov 2023 12:53:26 -0500 Subject: [PATCH 006/178] fixed casing in godoc comments --- prudp_client.go | 14 +++++++------- prudp_packet_v1.go | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/prudp_client.go b/prudp_client.go index bf67c436..9dac45e7 100644 --- a/prudp_client.go +++ b/prudp_client.go @@ -26,7 +26,7 @@ type PRUDPClient struct { supportedFunctions uint32 // * Not currently used for anything, but maybe useful later? } -// Reset sets the client back to it's default state +// reset sets the client back to it's default state func (c *PRUDPClient) reset() { for _, substream := range c.reliableSubstreams { substream.ResendScheduler.Stop() @@ -44,10 +44,10 @@ func (c *PRUDPClient) reset() { c.DestinationPort = 0 } -// Cleanup cleans up any resources the client may be using +// cleanup cleans up any resources the client may be using // -// This is similar to Client.Reset(), with the key difference -// being that Cleanup does not care about the state the client +// This is similar to Client.reset(), with the key difference +// being that cleanup does not care about the state the client // is currently in, or will be in, after execution. It only // frees resources that are not easily garbage collected func (c *PRUDPClient) cleanup() { @@ -79,7 +79,7 @@ func (c *PRUDPClient) SetPID(pid uint32) { c.pid = pid } -// SetSessionKey sets the clients session key used for reliable RC4 ciphers +// setSessionKey sets the clients session key used for reliable RC4 ciphers func (c *PRUDPClient) setSessionKey(sessionKey []byte) { c.sessionKey = sessionKey @@ -103,12 +103,12 @@ func (c *PRUDPClient) setSessionKey(sessionKey []byte) { } } -// ReliableSubstream returns the clients reliable substream ID +// reliableSubstream returns the clients reliable substream ID func (c *PRUDPClient) reliableSubstream(substreamID uint8) *ReliablePacketSubstreamManager { return c.reliableSubstreams[substreamID] } -// CreateReliableSubstreams creates the list of substreams used for reliable PRUDP packets +// createReliableSubstreams creates the list of substreams used for reliable PRUDP packets func (c *PRUDPClient) createReliableSubstreams(maxSubstreamID uint8) { substreams := maxSubstreamID + 1 diff --git a/prudp_packet_v1.go b/prudp_packet_v1.go index c606cf3e..9f6b1398 100644 --- a/prudp_packet_v1.go +++ b/prudp_packet_v1.go @@ -26,7 +26,7 @@ func (p *PRUDPPacketV1) Version() int { return 1 } -// Decode parses the packets data +// decode parses the packets data func (p *PRUDPPacketV1) decode() error { if p.readStream.Remaining() < 2 { return errors.New("Failed to read PRUDPv1 magic. Not have enough data") From 0bf62394cffb23dfb4129df1d086d1b8a56ab18f Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sat, 11 Nov 2023 13:26:17 -0500 Subject: [PATCH 007/178] fix PRUDPv1 configuration negotiation --- prudp_server.go | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/prudp_server.go b/prudp_server.go index 783e1bdc..5bffa41f 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -17,6 +17,7 @@ type PRUDPServer struct { PRUDPVersion int IsQuazalMode bool IsSecureServer bool + SupportedFunctions uint32 accessKey string kerberosPassword []byte kerberosTicketVersion int @@ -243,6 +244,13 @@ func (s *PRUDPServer) handleSyn(packet PRUDPPacketInterface) { ack.setConnectionSignature(connectionSignature) ack.setSignature(ack.calculateSignature([]byte{}, []byte{})) + if ack, ok := ack.(*PRUDPPacketV1); ok { + // * Negotiate with the client what we support + ack.maximumSubstreamID = packet.(*PRUDPPacketV1).maximumSubstreamID // * No change needed, we can just support what the client wants + ack.minorVersion = packet.(*PRUDPPacketV1).minorVersion // * No change needed, we can just support what the client wants + ack.supportedFunctions = s.SupportedFunctions & packet.(*PRUDPPacketV1).supportedFunctions + } + s.emit("syn", ack) s.sendRaw(client.address, ack.Bytes()) @@ -275,7 +283,10 @@ func (s *PRUDPServer) handleConnect(packet PRUDPPacketInterface) { ack.SetSequenceID(1) if ack, ok := ack.(*PRUDPPacketV1); ok { - // * Just tell the client we support exactly what it wants + // * At this stage the client and server have already + // * negotiated what they each can support, so configure + // * the client now and just send the client back the + // * negotiated configuration ack.maximumSubstreamID = packet.(*PRUDPPacketV1).maximumSubstreamID ack.minorVersion = packet.(*PRUDPPacketV1).minorVersion ack.supportedFunctions = packet.(*PRUDPPacketV1).supportedFunctions From 3a41b79b63ad4658660d6173b59dfca0b722cfc9 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sat, 11 Nov 2023 14:34:13 -0500 Subject: [PATCH 008/178] use protocol minor version in streams --- prudp_server.go | 11 +++++++++++ server_interface.go | 2 ++ stream_in.go | 2 +- stream_out.go | 2 +- 4 files changed, 15 insertions(+), 2 deletions(-) diff --git a/prudp_server.go b/prudp_server.go index 5bffa41f..ce810ded 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -15,6 +15,7 @@ type PRUDPServer struct { udpSocket *net.UDPConn clients *MutexMap[string, *PRUDPClient] PRUDPVersion int + protocolMinorVersion uint32 IsQuazalMode bool IsSecureServer bool SupportedFunctions uint32 @@ -695,6 +696,16 @@ func (s *PRUDPServer) ConnectionIDCounter() *Counter[uint32] { return s.connectionIDCounter } +// SetNATTraversalProtocolVersion sets the servers NAT Traversal protocol version +func (s *PRUDPServer) SetProtocolMinorVersion(protocolMinorVersion uint32) { + s.protocolMinorVersion = protocolMinorVersion +} + +// ProtocolMinorVersion returns the servers PRUDP protocol minor version +func (s *PRUDPServer) ProtocolMinorVersion() uint32 { + return s.protocolMinorVersion +} + // NewPRUDPServer will return a new PRUDP server func NewPRUDPServer() *PRUDPServer { return &PRUDPServer{ diff --git a/server_interface.go b/server_interface.go index 304b28ab..1a5a7de4 100644 --- a/server_interface.go +++ b/server_interface.go @@ -4,6 +4,8 @@ package nex type ServerInterface interface { AccessKey() string SetAccessKey(accessKey string) + SetProtocolMinorVersion(protocolMinorVersion uint32) + ProtocolMinorVersion() uint32 LibraryVersion() *LibraryVersion DataStoreProtocolVersion() *LibraryVersion MatchMakingProtocolVersion() *LibraryVersion diff --git a/stream_in.go b/stream_in.go index 535c66e3..0c86bc1c 100644 --- a/stream_in.go +++ b/stream_in.go @@ -255,7 +255,7 @@ func (stream *StreamIn) ReadStructure(structure StructureInterface) (StructureIn } } - if stream.Server.LibraryVersion().GreaterOrEqual("3.5.0") { + if stream.Server.ProtocolMinorVersion() >= 3 { version, err := stream.ReadUInt8() if err != nil { return nil, fmt.Errorf("Failed to read NEX Structure version. %s", err.Error()) diff --git a/stream_out.go b/stream_out.go index 50e66979..f1ce9de8 100644 --- a/stream_out.go +++ b/stream_out.go @@ -177,7 +177,7 @@ func (stream *StreamOut) WriteStructure(structure StructureInterface) { content := structure.Bytes(NewStreamOut(stream.Server)) - if stream.Server.LibraryVersion().GreaterOrEqual("3.5.0") { + if stream.Server.ProtocolMinorVersion() >= 3 { stream.WriteUInt8(structure.StructureVersion()) stream.WriteUInt32LE(uint32(len(content))) } From deb29eebfd1ade42ad2717cd6c479d4033327114 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sat, 11 Nov 2023 14:47:09 -0500 Subject: [PATCH 009/178] removed hard coded connection signature key --- prudp_packet_v1.go | 8 +----- prudp_server.go | 61 +++++++++++++++++++++++++++------------------- 2 files changed, 37 insertions(+), 32 deletions(-) diff --git a/prudp_packet_v1.go b/prudp_packet_v1.go index 9f6b1398..bcf60536 100644 --- a/prudp_packet_v1.go +++ b/prudp_packet_v1.go @@ -277,17 +277,11 @@ func (p *PRUDPPacketV1) calculateConnectionSignature(addr net.Addr) ([]byte, err return nil, fmt.Errorf("Unsupported network type: %T", addr) } - // * The real client seems to not care about this. The original - // * server just used rand.Read here. This is done to implement - // * compatibility with NintendoClients, as this is how it - // * calculates PRUDPv1 connection signatures - key := []byte{0x26, 0xc3, 0x1f, 0x38, 0x1e, 0x46, 0xd6, 0xeb, 0x38, 0xe1, 0xaf, 0x6a, 0xb7, 0x0d, 0x11} - portBytes := make([]byte, 2) binary.BigEndian.PutUint16(portBytes, uint16(port)) data := append(ip, portBytes...) - hash := hmac.New(md5.New, key) + hash := hmac.New(md5.New, p.sender.server.PRUDPv1ConnectionSignatureKey) hash.Write(data) return hash.Sum(nil), nil diff --git a/prudp_server.go b/prudp_server.go index ce810ded..eb15bc2d 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -2,6 +2,7 @@ package nex import ( "bytes" + "crypto/rand" "errors" "fmt" "net" @@ -12,30 +13,31 @@ import ( // PRUDPServer represents a bare-bones PRUDP server type PRUDPServer struct { - udpSocket *net.UDPConn - clients *MutexMap[string, *PRUDPClient] - PRUDPVersion int - protocolMinorVersion uint32 - IsQuazalMode bool - IsSecureServer bool - SupportedFunctions uint32 - accessKey string - kerberosPassword []byte - kerberosTicketVersion int - kerberosKeySize int - FragmentSize int - version *LibraryVersion - datastoreProtocolVersion *LibraryVersion - matchMakingProtocolVersion *LibraryVersion - rankingProtocolVersion *LibraryVersion - ranking2ProtocolVersion *LibraryVersion - messagingProtocolVersion *LibraryVersion - utilityProtocolVersion *LibraryVersion - natTraversalProtocolVersion *LibraryVersion - eventHandlers map[string][]func(PacketInterface) - connectionIDCounter *Counter[uint32] - pingTimeout time.Duration - PasswordFromPID func(pid uint32) (string, uint32) + udpSocket *net.UDPConn + clients *MutexMap[string, *PRUDPClient] + PRUDPVersion int + protocolMinorVersion uint32 + IsQuazalMode bool + IsSecureServer bool + SupportedFunctions uint32 + accessKey string + kerberosPassword []byte + kerberosTicketVersion int + kerberosKeySize int + FragmentSize int + version *LibraryVersion + datastoreProtocolVersion *LibraryVersion + matchMakingProtocolVersion *LibraryVersion + rankingProtocolVersion *LibraryVersion + ranking2ProtocolVersion *LibraryVersion + messagingProtocolVersion *LibraryVersion + utilityProtocolVersion *LibraryVersion + natTraversalProtocolVersion *LibraryVersion + eventHandlers map[string][]func(PacketInterface) + connectionIDCounter *Counter[uint32] + pingTimeout time.Duration + PasswordFromPID func(pid uint32) (string, uint32) + PRUDPv1ConnectionSignatureKey []byte } // OnReliableData adds an event handler which is fired when a new reliable DATA packet is received @@ -61,6 +63,15 @@ func (s *PRUDPServer) emit(name string, packet PRUDPPacketInterface) { // Listen starts a PRUDP server on a given port func (s *PRUDPServer) Listen(port int) { + // * Ensure the server has a key for PRUDPv1 connection signatures + if len(s.PRUDPv1ConnectionSignatureKey) != 16 { + s.PRUDPv1ConnectionSignatureKey = make([]byte, 16) + _, err := rand.Read(s.PRUDPv1ConnectionSignatureKey) + if err != nil { + panic(err) + } + } + udpAddress, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", port)) if err != nil { panic(err) @@ -696,7 +707,7 @@ func (s *PRUDPServer) ConnectionIDCounter() *Counter[uint32] { return s.connectionIDCounter } -// SetNATTraversalProtocolVersion sets the servers NAT Traversal protocol version +// SetNATTraversalProtocolVersion sets the servers PRUDP protocol minor version func (s *PRUDPServer) SetProtocolMinorVersion(protocolMinorVersion uint32) { s.protocolMinorVersion = protocolMinorVersion } From 8d404a5066e9bdbb611cc474043c135541d68d69 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sat, 11 Nov 2023 20:15:34 -0500 Subject: [PATCH 010/178] renamed OnReliableData to just OnData --- prudp_server.go | 8 ++++---- server_interface.go | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/prudp_server.go b/prudp_server.go index eb15bc2d..1097760e 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -40,9 +40,9 @@ type PRUDPServer struct { PRUDPv1ConnectionSignatureKey []byte } -// OnReliableData adds an event handler which is fired when a new reliable DATA packet is received -func (s *PRUDPServer) OnReliableData(handler func(PacketInterface)) { - s.on("reliable-data", handler) +// OnData adds an event handler which is fired when a new DATA packet is received +func (s *PRUDPServer) OnData(handler func(PacketInterface)) { + s.on("data", handler) } func (s *PRUDPServer) on(name string, handler func(PacketInterface)) { @@ -473,7 +473,7 @@ func (s *PRUDPServer) handleReliable(packet PRUDPPacketInterface) { packet.SetRMCMessage(message) - s.emit("reliable-data", packet) + s.emit("data", packet) } } } diff --git a/server_interface.go b/server_interface.go index 1a5a7de4..baa2ab54 100644 --- a/server_interface.go +++ b/server_interface.go @@ -16,5 +16,5 @@ type ServerInterface interface { NATTraversalProtocolVersion() *LibraryVersion SetDefaultLibraryVersion(version *LibraryVersion) Send(packet PacketInterface) - OnReliableData(handler func(PacketInterface)) + OnData(handler func(packet PacketInterface)) } From 571a25d155cc2938acb996928aaaf8cda52d9573 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sat, 11 Nov 2023 21:11:32 -0500 Subject: [PATCH 011/178] added ConnectionID field on PRUDPClient --- prudp_client.go | 1 + 1 file changed, 1 insertion(+) diff --git a/prudp_client.go b/prudp_client.go index 9dac45e7..181587a0 100644 --- a/prudp_client.go +++ b/prudp_client.go @@ -24,6 +24,7 @@ type PRUDPClient struct { DestinationPort uint8 minorVersion uint32 // * Not currently used for anything, but maybe useful later? supportedFunctions uint32 // * Not currently used for anything, but maybe useful later? + ConnectionID uint32 } // reset sets the client back to it's default state From 9f875c6e5ed3f999597aa5e25f852b5c7c9d5e4b Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sat, 11 Nov 2023 21:28:34 -0500 Subject: [PATCH 012/178] added StationURLs field on PRUDPClient --- prudp_client.go | 1 + 1 file changed, 1 insertion(+) diff --git a/prudp_client.go b/prudp_client.go index 181587a0..fb3745ac 100644 --- a/prudp_client.go +++ b/prudp_client.go @@ -25,6 +25,7 @@ type PRUDPClient struct { minorVersion uint32 // * Not currently used for anything, but maybe useful later? supportedFunctions uint32 // * Not currently used for anything, but maybe useful later? ConnectionID uint32 + StationURLs []*StationURL } // reset sets the client back to it's default state From 0e4ff2bec9f7cade0304d9f6500cc9e74c869e80 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sun, 12 Nov 2023 01:28:44 -0500 Subject: [PATCH 013/178] added FindClientByConnectionID field on PRUDPServer --- mutex_map.go | 9 +++++++-- prudp_server.go | 20 +++++++++++++++++++- resend_scheduler.go | 4 +++- 3 files changed, 29 insertions(+), 4 deletions(-) diff --git a/mutex_map.go b/mutex_map.go index 0324d379..b2d076fc 100644 --- a/mutex_map.go +++ b/mutex_map.go @@ -64,13 +64,18 @@ func (m *MutexMap[K, V]) Size() int { // Each runs a callback function for every item in the map // The map should not be modified inside the callback function -func (m *MutexMap[K, V]) Each(callback func(key K, value V)) { +// Returns true if the loop was terminated early +func (m *MutexMap[K, V]) Each(callback func(key K, value V) bool) bool { m.RLock() defer m.RUnlock() for key, value := range m.real { - callback(key, value) + if callback(key, value) { + return true + } } + + return false } // Clear removes all items from the `real` map diff --git a/prudp_server.go b/prudp_server.go index 1097760e..5f31cee0 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -214,10 +214,12 @@ func (s *PRUDPServer) handleMultiAcknowledgment(packet PRUDPPacketInterface) { // * MutexMap.Each locks the mutex, can't remove while reading. // * Have to just loop again - substream.ResendScheduler.packets.Each(func(sequenceID uint16, pending *PendingPacket) { + substream.ResendScheduler.packets.Each(func(sequenceID uint16, pending *PendingPacket) bool { if sequenceID <= baseSequenceID && !slices.Contains(sequenceIDs, sequenceID) { sequenceIDs = append(sequenceIDs, sequenceID) } + + return false }) // * Actually remove the packets from the pool @@ -717,6 +719,22 @@ func (s *PRUDPServer) ProtocolMinorVersion() uint32 { return s.protocolMinorVersion } +// FindClientByConnectionID returns the PRUDP client connected with the given connection ID +func (s *PRUDPServer) FindClientByConnectionID(connectedID uint32) *PRUDPClient { + var client *PRUDPClient + + s.clients.Each(func(discriminator string, c *PRUDPClient) bool { + if c.ConnectionID == connectedID { + client = c + return true + } + + return false + }) + + return client +} + // NewPRUDPServer will return a new PRUDP server func NewPRUDPServer() *PRUDPServer { return &PRUDPServer{ diff --git a/resend_scheduler.go b/resend_scheduler.go index c4842759..17e32334 100644 --- a/resend_scheduler.go +++ b/resend_scheduler.go @@ -41,10 +41,12 @@ type ResendScheduler struct { func (rs *ResendScheduler) Stop() { stillPending := make([]uint16, rs.packets.Size()) - rs.packets.Each(func(sequenceID uint16, packet *PendingPacket) { + rs.packets.Each(func(sequenceID uint16, packet *PendingPacket) bool { if !packet.isAcknowledged { stillPending = append(stillPending, sequenceID) } + + return false }) for _, sequenceID := range stillPending { From 463ce86d0eefef083f3b95a4a28d8952352bd527 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sun, 12 Nov 2023 01:29:37 -0500 Subject: [PATCH 014/178] fixed SetProtocolMinorVersion godoc comment --- prudp_server.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prudp_server.go b/prudp_server.go index 5f31cee0..ede9c5f8 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -709,7 +709,7 @@ func (s *PRUDPServer) ConnectionIDCounter() *Counter[uint32] { return s.connectionIDCounter } -// SetNATTraversalProtocolVersion sets the servers PRUDP protocol minor version +// SetProtocolMinorVersion sets the servers PRUDP protocol minor version func (s *PRUDPServer) SetProtocolMinorVersion(protocolMinorVersion uint32) { s.protocolMinorVersion = protocolMinorVersion } From de80be2bc2123910001b96c2ed8cc426a3181a3f Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sun, 12 Nov 2023 02:08:49 -0500 Subject: [PATCH 015/178] added FindClientByPID field on PRUDPServer --- prudp_server.go | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/prudp_server.go b/prudp_server.go index ede9c5f8..1aa531cf 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -735,6 +735,22 @@ func (s *PRUDPServer) FindClientByConnectionID(connectedID uint32) *PRUDPClient return client } +// FindClientByPID returns the PRUDP client connected with the given PID +func (s *PRUDPServer) FindClientByPID(pid uint32) *PRUDPClient { + var client *PRUDPClient + + s.clients.Each(func(discriminator string, c *PRUDPClient) bool { + if c.pid == pid { + client = c + return true + } + + return false + }) + + return client +} + // NewPRUDPServer will return a new PRUDP server func NewPRUDPServer() *PRUDPServer { return &PRUDPServer{ From dfbbc192ba6d301dad05e75f873fc98221050ebf Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sun, 12 Nov 2023 03:01:05 -0500 Subject: [PATCH 016/178] added OnClientRemoved and OnDisconnect on PRUDPServer --- prudp_client.go | 4 +++- prudp_server.go | 40 +++++++++++++++++++++++++++++++--------- resend_scheduler.go | 2 +- 3 files changed, 35 insertions(+), 11 deletions(-) diff --git a/prudp_client.go b/prudp_client.go index fb3745ac..c6c1ffab 100644 --- a/prudp_client.go +++ b/prudp_client.go @@ -59,6 +59,8 @@ func (c *PRUDPClient) cleanup() { c.reliableSubstreams = make([]*ReliablePacketSubstreamManager, 0) c.stopHeartbeatTimers() + + c.server.emitRemoved(c) } // Server returns the server the client is connecting to @@ -156,7 +158,7 @@ func (c *PRUDPClient) startHeartbeat() { // * If the heartbeat still did not restart, assume the // * client is dead and clean up c.pingKickTimer = time.AfterFunc(server.pingTimeout, func() { - c.cleanup() + c.cleanup() // * "removed" event is dispatched here c.server.clients.Delete(c.address.String()) }) }) diff --git a/prudp_server.go b/prudp_server.go index 1aa531cf..3e4a172a 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -33,7 +33,8 @@ type PRUDPServer struct { messagingProtocolVersion *LibraryVersion utilityProtocolVersion *LibraryVersion natTraversalProtocolVersion *LibraryVersion - eventHandlers map[string][]func(PacketInterface) + prudpEventHandlers map[string][]func(packet PacketInterface) + clientRemovedEventHandlers []func(client *PRUDPClient) connectionIDCounter *Counter[uint32] pingTimeout time.Duration PasswordFromPID func(pid uint32) (string, uint32) @@ -41,26 +42,47 @@ type PRUDPServer struct { } // OnData adds an event handler which is fired when a new DATA packet is received -func (s *PRUDPServer) OnData(handler func(PacketInterface)) { +func (s *PRUDPServer) OnData(handler func(packet PacketInterface)) { s.on("data", handler) } -func (s *PRUDPServer) on(name string, handler func(PacketInterface)) { - if _, ok := s.eventHandlers[name]; !ok { - s.eventHandlers[name] = make([]func(PacketInterface), 0) +// OnDisconnect adds an event handler which is fired when a new DISCONNECT packet is received +// +// To handle a client being removed from the server, see OnClientRemoved which fires on more cases +func (s *PRUDPServer) OnDisconnect(handler func(packet PacketInterface)) { + s.on("disconnect", handler) +} + +// OnClientRemoved adds an event handler which is fired when a client is removed from the server +// +// Fires both on a natural disconnect and from a timeout +func (s *PRUDPServer) OnClientRemoved(handler func(client *PRUDPClient)) { + // * "removed" events are a special case, so handle them separately + s.clientRemovedEventHandlers = append(s.clientRemovedEventHandlers, handler) +} + +func (s *PRUDPServer) on(name string, handler func(packet PacketInterface)) { + if _, ok := s.prudpEventHandlers[name]; !ok { + s.prudpEventHandlers[name] = make([]func(packet PacketInterface), 0) } - s.eventHandlers[name] = append(s.eventHandlers[name], handler) + s.prudpEventHandlers[name] = append(s.prudpEventHandlers[name], handler) } func (s *PRUDPServer) emit(name string, packet PRUDPPacketInterface) { - if handlers, ok := s.eventHandlers[name]; ok { + if handlers, ok := s.prudpEventHandlers[name]; ok { for _, handler := range handlers { go handler(packet) } } } +func (s *PRUDPServer) emitRemoved(client *PRUDPClient) { + for _, handler := range s.clientRemovedEventHandlers { + go handler(client) + } +} + // Listen starts a PRUDP server on a given port func (s *PRUDPServer) Listen(port int) { // * Ensure the server has a key for PRUDPv1 connection signatures @@ -359,7 +381,7 @@ func (s *PRUDPServer) handleDisconnect(packet PRUDPPacketInterface) { client := packet.Sender().(*PRUDPClient) - client.cleanup() + client.cleanup() // * "removed" event is dispatched here s.clients.Delete(client.address.String()) s.emit("disconnect", packet) @@ -758,7 +780,7 @@ func NewPRUDPServer() *PRUDPServer { IsQuazalMode: false, kerberosKeySize: 32, FragmentSize: 1300, - eventHandlers: make(map[string][]func(PacketInterface)), + prudpEventHandlers: make(map[string][]func(PacketInterface)), connectionIDCounter: NewCounter[uint32](10), pingTimeout: time.Second * 15, } diff --git a/resend_scheduler.go b/resend_scheduler.go index 17e32334..3a0a3893 100644 --- a/resend_scheduler.go +++ b/resend_scheduler.go @@ -93,7 +93,7 @@ func (rs *ResendScheduler) resendPacket(pendingPacket *PendingPacket) { // * The maximum resend count has been reached, consider the client dead. pendingPacket.ticker.Stop() rs.packets.Delete(packet.SequenceID()) - client.cleanup() + client.cleanup() // * "removed" event is dispatched here client.server.clients.Delete(client.address.String()) return } From 18c6d73fc6e601e72cc1ef5cdaa14a08dd0b224b Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sun, 12 Nov 2023 06:49:29 -0500 Subject: [PATCH 017/178] added debug logs. delete later --- prudp_server.go | 32 +++++++++++++++++++++++++++++++- resend_scheduler.go | 29 +++++++++++++++++++++++++++-- 2 files changed, 58 insertions(+), 3 deletions(-) diff --git a/prudp_server.go b/prudp_server.go index 3e4a172a..77529645 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -151,6 +151,12 @@ func (s *PRUDPServer) handleSocketMessage() error { var packets []PRUDPPacketInterface + if s.IsSecureServer { + fmt.Printf("[SECR] Got packet data %x\n", packetData) + } else { + fmt.Printf("[AUTH] Got packet data %x\n", packetData) + } + // * Support any packet type the client sends and respond // * with that same type. Also keep reading from the stream // * until no more data is left, to account for multiple @@ -162,7 +168,7 @@ func (s *PRUDPServer) handleSocketMessage() error { } for _, packet := range packets { - s.processPacket(packet) + go s.processPacket(packet) } return nil @@ -196,6 +202,12 @@ func (s *PRUDPServer) handleAcknowledgment(packet PRUDPPacketInterface) { return } + if s.IsSecureServer { + fmt.Println("[SECR] Got ACK for SequenceID", packet.SequenceID()) + } else { + fmt.Println("[AUTH] Got ACK for SequenceID", packet.SequenceID()) + } + client := packet.Sender().(*PRUDPClient) substream := client.reliableSubstream(packet.SubstreamID()) @@ -580,6 +592,24 @@ func (s *PRUDPServer) sendPacket(packet PRUDPPacketInterface) { substream.ResendScheduler.AddPacket(packet) } + if packet.Type() == DataPacket && packet.RMCMessage() != nil { + if s.IsSecureServer { + fmt.Println("[SECR] ======= SENDING =======") + fmt.Println("[SECR] ProtocolID:", packet.RMCMessage().ProtocolID) + fmt.Println("[SECR] MethodID:", packet.RMCMessage().MethodID) + fmt.Println("[SECR] FragmentID:", packet.getFragmentID()) + fmt.Println("[SECR] SequenceID:", packet.SequenceID()) + fmt.Println("[SECR] =======================") + } else { + fmt.Println("[AUTH] ======= SENDING =======") + fmt.Println("[AUTH] ProtocolID:", packet.RMCMessage().ProtocolID) + fmt.Println("[AUTH] MethodID:", packet.RMCMessage().MethodID) + fmt.Println("[AUTH] FragmentID:", packet.getFragmentID()) + fmt.Println("[AUTH] SequenceID:", packet.SequenceID()) + fmt.Println("[AUTH] =======================") + } + } + s.sendRaw(packet.Sender().Address(), packet.Bytes()) } diff --git a/resend_scheduler.go b/resend_scheduler.go index 3a0a3893..d5fbe536 100644 --- a/resend_scheduler.go +++ b/resend_scheduler.go @@ -1,6 +1,7 @@ package nex import ( + "fmt" "time" ) @@ -66,6 +67,12 @@ func (rs *ResendScheduler) AddPacket(packet PRUDPPacketInterface) { interval: rs.Interval, } + if packet.Sender().Server().(*PRUDPServer).IsSecureServer { + fmt.Println("[SECR] Adding packet", packet.SequenceID(), "to resend queue") + } else { + fmt.Println("[AUTH] Adding packet", packet.SequenceID(), "to resend queue") + } + rs.packets.Set(packet.SequenceID(), pendingPacket) go pendingPacket.startResendTimer() @@ -74,6 +81,12 @@ func (rs *ResendScheduler) AddPacket(packet PRUDPPacketInterface) { // AcknowledgePacket marks a pending packet as acknowledged. It will be ignored at the next resend attempt func (rs *ResendScheduler) AcknowledgePacket(sequenceID uint16) { if pendingPacket, ok := rs.packets.Get(sequenceID); ok { + if pendingPacket.packet.Sender().Server().(*PRUDPServer).IsSecureServer { + fmt.Println("[SECR] Acknowledged", sequenceID) + } else { + fmt.Println("[AUTH] Acknowledged", sequenceID) + } + pendingPacket.isAcknowledged = true } } @@ -81,8 +94,7 @@ func (rs *ResendScheduler) AcknowledgePacket(sequenceID uint16) { func (rs *ResendScheduler) resendPacket(pendingPacket *PendingPacket) { if pendingPacket.isAcknowledged { // * Prevent a race condition where resendPacket may be called - // * at the same time a packet is acknowledged. Packet will be - // * handled properly at the next tick + // * at the same time a packet is acknowledged return } @@ -91,6 +103,12 @@ func (rs *ResendScheduler) resendPacket(pendingPacket *PendingPacket) { if pendingPacket.resendCount >= rs.MaxResendCount { // * The maximum resend count has been reached, consider the client dead. + if pendingPacket.packet.Sender().Server().(*PRUDPServer).IsSecureServer { + fmt.Println("[SECR] Max resends hit for", pendingPacket.packet.SequenceID()) + } else { + fmt.Println("[AUTH] Max resends hit for", pendingPacket.packet.SequenceID()) + } + pendingPacket.ticker.Stop() rs.packets.Delete(packet.SequenceID()) client.cleanup() // * "removed" event is dispatched here @@ -99,6 +117,13 @@ func (rs *ResendScheduler) resendPacket(pendingPacket *PendingPacket) { } if time.Since(pendingPacket.lastSendTime) >= rs.Interval { + if pendingPacket.packet.Sender().Server().(*PRUDPServer).IsSecureServer { + fmt.Println("[SECR] Resending packet", pendingPacket.packet.SequenceID()) + fmt.Println("[SECR]", rs.packets.real) + } else { + fmt.Println("[AUTH] Resending packet", pendingPacket.packet.SequenceID()) + } + // * Resend the packet to the client server := client.server data := packet.Bytes() From 94b4f8343a1f8b164dc192e5787313acb758a6ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20L=C3=B3pez=20Guimaraes?= Date: Sun, 12 Nov 2023 20:50:30 +0000 Subject: [PATCH 018/178] test: Rename OnReliableData to OnData --- test/auth.go | 2 +- test/secure.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/auth.go b/test/auth.go index b0ec6ebf..5fffda40 100644 --- a/test/auth.go +++ b/test/auth.go @@ -14,7 +14,7 @@ func startAuthenticationServer() { authServer = nex.NewPRUDPServer() - authServer.OnReliableData(func(packet nex.PacketInterface) { + authServer.OnData(func(packet nex.PacketInterface) { if packet, ok := packet.(nex.PRUDPPacketInterface); ok { request := packet.RMCMessage() diff --git a/test/secure.go b/test/secure.go index b587bf76..d7fef980 100644 --- a/test/secure.go +++ b/test/secure.go @@ -48,7 +48,7 @@ func startSecureServer() { secureServer = nex.NewPRUDPServer() - secureServer.OnReliableData(func(packet nex.PacketInterface) { + secureServer.OnData(func(packet nex.PacketInterface) { if packet, ok := packet.(nex.PRUDPPacketInterface); ok { request := packet.RMCMessage() From b39018bb1905d72e8ca46e1aabded0ae324b69db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20L=C3=B3pez=20Guimaraes?= Date: Sun, 12 Nov 2023 20:52:48 +0000 Subject: [PATCH 019/178] prudp: Create copies of packets for fragments This makes the resend scheduler to not have the same pointer to a packet when sending fragments. --- prudp_packet.go | 10 +++++ prudp_packet_interface.go | 2 + prudp_server.go | 77 +++++++++++++++++++++++++-------------- resend_scheduler.go | 1 - 4 files changed, 62 insertions(+), 28 deletions(-) diff --git a/prudp_packet.go b/prudp_packet.go index 513f0b57..81f24359 100644 --- a/prudp_packet.go +++ b/prudp_packet.go @@ -25,6 +25,11 @@ func (p *PRUDPPacket) Sender() ClientInterface { return p.sender } +// Flags returns the packet flags +func (p *PRUDPPacket) Flags() uint16 { + return p.flags +} + // HasFlag checks if the packet has the given flag func (p *PRUDPPacket) HasFlag(flag uint16) bool { return p.flags&flag != 0 @@ -85,6 +90,11 @@ func (p *PRUDPPacket) DestinationPort() uint8 { return p.destinationPort } +// SessionID returns the packets session ID +func (p *PRUDPPacket) SessionID() uint8 { + return p.sessionID +} + // SetSessionID sets the packets session ID func (p *PRUDPPacket) SetSessionID(sessionID uint8) { p.sessionID = sessionID diff --git a/prudp_packet_interface.go b/prudp_packet_interface.go index 5703cd8c..1be2a3fe 100644 --- a/prudp_packet_interface.go +++ b/prudp_packet_interface.go @@ -7,6 +7,7 @@ type PRUDPPacketInterface interface { Version() int Bytes() []byte Sender() ClientInterface + Flags() uint16 HasFlag(flag uint16) bool AddFlag(flag uint16) SetType(packetType uint16) @@ -19,6 +20,7 @@ type PRUDPPacketInterface interface { DestinationStreamType() uint8 SetDestinationPort(destinationPort uint8) DestinationPort() uint8 + SessionID() uint8 SetSessionID(sessionID uint8) SubstreamID() uint8 SetSubstreamID(substreamID uint8) diff --git a/prudp_server.go b/prudp_server.go index 77529645..b624431c 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -564,53 +564,76 @@ func (s *PRUDPServer) Send(packet PacketInterface) { func (s *PRUDPServer) sendPacket(packet PRUDPPacketInterface) { client := packet.Sender().(*PRUDPClient) - if !packet.HasFlag(FlagAck) && !packet.HasFlag(FlagMultiAck) { - if packet.HasFlag(FlagReliable) { - substream := client.reliableSubstream(packet.SubstreamID()) - packet.SetSequenceID(substream.NextOutgoingSequenceID()) - } else if packet.Type() == DataPacket { - packet.SetSequenceID(client.nextOutgoingUnreliableSequenceID()) - } else if packet.Type() == PingPacket { - packet.SetSequenceID(client.nextOutgoingPingSequenceID()) + // TODO - Add packet.Copy() + var packetCopy PRUDPPacketInterface + if packet.Version() == 1 { + packetCopy, _ = NewPRUDPPacketV1(client, nil) + } else { + packetCopy, _ = NewPRUDPPacketV0(client, nil) + } + + packetCopy.SetSourceStreamType(packet.SourceStreamType()) + packetCopy.SetSourcePort(packet.SourcePort()) + + packetCopy.SetDestinationStreamType(packet.DestinationStreamType()) + packetCopy.SetDestinationPort(packet.DestinationPort()) + + packetCopy.SetType(packet.Type()) + packetCopy.AddFlag(packet.Flags()) + packetCopy.SetSessionID(packet.SessionID()) + packetCopy.SetSubstreamID(packet.SubstreamID()) + packetCopy.SetSequenceID(packet.SequenceID()) + packetCopy.SetPayload(packet.Payload()) + packetCopy.setConnectionSignature(packet.getConnectionSignature()) + packetCopy.setFragmentID(packet.getFragmentID()) + + if !packetCopy.HasFlag(FlagAck) && !packetCopy.HasFlag(FlagMultiAck) { + if packetCopy.HasFlag(FlagReliable) { + substream := client.reliableSubstream(packetCopy.SubstreamID()) + packetCopy.SetSequenceID(substream.NextOutgoingSequenceID()) + } else if packetCopy.Type() == DataPacket { + packetCopy.SetSequenceID(client.nextOutgoingUnreliableSequenceID()) + } else if packetCopy.Type() == PingPacket { + packetCopy.SetSequenceID(client.nextOutgoingPingSequenceID()) } else { - packet.SetSequenceID(0) + packetCopy.SetSequenceID(0) } } - if packet.Type() == DataPacket && !packet.HasFlag(FlagAck) && !packet.HasFlag(FlagMultiAck) { - if packet.HasFlag(FlagReliable) { - substream := client.reliableSubstream(packet.SubstreamID()) - packet.SetPayload(substream.Encrypt(packet.Payload())) + if packetCopy.Type() == DataPacket && !packetCopy.HasFlag(FlagAck) && !packetCopy.HasFlag(FlagMultiAck) { + if packetCopy.HasFlag(FlagReliable) { + substream := client.reliableSubstream(packetCopy.SubstreamID()) + packetCopy.SetPayload(substream.Encrypt(packetCopy.Payload())) } // TODO - Unreliable crypto } - packet.setSignature(packet.calculateSignature(client.sessionKey, client.serverConnectionSignature)) + packetCopy.setSignature(packetCopy.calculateSignature(client.sessionKey, client.serverConnectionSignature)) - if packet.HasFlag(FlagReliable) && packet.HasFlag(FlagNeedsAck) { - substream := client.reliableSubstream(packet.SubstreamID()) - substream.ResendScheduler.AddPacket(packet) + if packetCopy.HasFlag(FlagReliable) && packetCopy.HasFlag(FlagNeedsAck) { + substream := client.reliableSubstream(packetCopy.SubstreamID()) + substream.ResendScheduler.AddPacket(packetCopy) } - if packet.Type() == DataPacket && packet.RMCMessage() != nil { + if packetCopy.Type() == DataPacket && packetCopy.RMCMessage() != nil { if s.IsSecureServer { fmt.Println("[SECR] ======= SENDING =======") - fmt.Println("[SECR] ProtocolID:", packet.RMCMessage().ProtocolID) - fmt.Println("[SECR] MethodID:", packet.RMCMessage().MethodID) - fmt.Println("[SECR] FragmentID:", packet.getFragmentID()) - fmt.Println("[SECR] SequenceID:", packet.SequenceID()) + fmt.Println("[SECR] ProtocolID:", packetCopy.RMCMessage().ProtocolID) + fmt.Println("[SECR] MethodID:", packetCopy.RMCMessage().MethodID) + fmt.Println("[SECR] FragmentID:", packetCopy.getFragmentID()) + fmt.Println("[SECR] SequenceID:", packetCopy.SequenceID()) fmt.Println("[SECR] =======================") } else { fmt.Println("[AUTH] ======= SENDING =======") - fmt.Println("[AUTH] ProtocolID:", packet.RMCMessage().ProtocolID) - fmt.Println("[AUTH] MethodID:", packet.RMCMessage().MethodID) - fmt.Println("[AUTH] FragmentID:", packet.getFragmentID()) - fmt.Println("[AUTH] SequenceID:", packet.SequenceID()) + fmt.Println("[AUTH] ProtocolID:", packetCopy.RMCMessage().ProtocolID) + fmt.Println("[AUTH] MethodID:", packetCopy.RMCMessage().MethodID) + fmt.Println("[AUTH] FragmentID:", packetCopy.getFragmentID()) + fmt.Println("[AUTH] SequenceID:", packetCopy.SequenceID()) fmt.Println("[AUTH] =======================") } } - s.sendRaw(packet.Sender().Address(), packet.Bytes()) + s.sendRaw(packetCopy.Sender().Address(), packetCopy.Bytes()) } // sendRaw will send the given address the provided packet diff --git a/resend_scheduler.go b/resend_scheduler.go index d5fbe536..c2010a41 100644 --- a/resend_scheduler.go +++ b/resend_scheduler.go @@ -119,7 +119,6 @@ func (rs *ResendScheduler) resendPacket(pendingPacket *PendingPacket) { if time.Since(pendingPacket.lastSendTime) >= rs.Interval { if pendingPacket.packet.Sender().Server().(*PRUDPServer).IsSecureServer { fmt.Println("[SECR] Resending packet", pendingPacket.packet.SequenceID()) - fmt.Println("[SECR]", rs.packets.real) } else { fmt.Println("[AUTH] Resending packet", pendingPacket.packet.SequenceID()) } From ffc71a97b7fb97148493263cfcf593aabd08eb25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20L=C3=B3pez=20Guimaraes?= Date: Sun, 12 Nov 2023 20:54:14 +0000 Subject: [PATCH 020/178] stream: NEX version checks when not using PRUDP --- server_interface.go | 2 -- stream_in.go | 10 +++++++++- stream_out.go | 10 +++++++++- 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/server_interface.go b/server_interface.go index baa2ab54..bccb2bc7 100644 --- a/server_interface.go +++ b/server_interface.go @@ -4,8 +4,6 @@ package nex type ServerInterface interface { AccessKey() string SetAccessKey(accessKey string) - SetProtocolMinorVersion(protocolMinorVersion uint32) - ProtocolMinorVersion() uint32 LibraryVersion() *LibraryVersion DataStoreProtocolVersion() *LibraryVersion MatchMakingProtocolVersion() *LibraryVersion diff --git a/stream_in.go b/stream_in.go index 0c86bc1c..ac97ed34 100644 --- a/stream_in.go +++ b/stream_in.go @@ -255,7 +255,15 @@ func (stream *StreamIn) ReadStructure(structure StructureInterface) (StructureIn } } - if stream.Server.ProtocolMinorVersion() >= 3 { + var useStructures bool + switch server := stream.Server.(type) { + case *PRUDPServer: // * Support QRV versions + useStructures = server.ProtocolMinorVersion() >= 3 + default: + useStructures = server.LibraryVersion().GreaterOrEqual("3.5.0") + } + + if useStructures { version, err := stream.ReadUInt8() if err != nil { return nil, fmt.Errorf("Failed to read NEX Structure version. %s", err.Error()) diff --git a/stream_out.go b/stream_out.go index f1ce9de8..171ce751 100644 --- a/stream_out.go +++ b/stream_out.go @@ -177,7 +177,15 @@ func (stream *StreamOut) WriteStructure(structure StructureInterface) { content := structure.Bytes(NewStreamOut(stream.Server)) - if stream.Server.ProtocolMinorVersion() >= 3 { + var useStructures bool + switch server := stream.Server.(type) { + case *PRUDPServer: // * Support QRV versions + useStructures = server.ProtocolMinorVersion() >= 3 + default: + useStructures = server.LibraryVersion().GreaterOrEqual("3.5.0") + } + + if useStructures { stream.WriteUInt8(structure.StructureVersion()) stream.WriteUInt32LE(uint32(len(content))) } From 1194ad3960030d74a2467726f29ff95ac425214f Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sun, 12 Nov 2023 16:56:07 -0500 Subject: [PATCH 021/178] prudp: added Copy method to packets --- prudp_packet_interface.go | 1 + prudp_packet_v0.go | 38 +++++++++++++++++++++++++++++++++ prudp_packet_v1.go | 45 +++++++++++++++++++++++++++++++++++++++ prudp_server.go | 30 ++++++-------------------- rmc.go => rmc_message.go | 17 +++++++++++++++ 5 files changed, 107 insertions(+), 24 deletions(-) rename rmc.go => rmc_message.go (92%) diff --git a/prudp_packet_interface.go b/prudp_packet_interface.go index 1be2a3fe..1af61274 100644 --- a/prudp_packet_interface.go +++ b/prudp_packet_interface.go @@ -4,6 +4,7 @@ import "net" // PRUDPPacketInterface defines all the methods a PRUDP packet should have type PRUDPPacketInterface interface { + Copy() PRUDPPacketInterface Version() int Bytes() []byte Sender() ClientInterface diff --git a/prudp_packet_v0.go b/prudp_packet_v0.go index 3af9496a..9b93080a 100644 --- a/prudp_packet_v0.go +++ b/prudp_packet_v0.go @@ -15,6 +15,44 @@ type PRUDPPacketV0 struct { PRUDPPacket } +// Copy copies the packet into a new PRUDPPacketV0 +// +// Retains the same PRUDPClient pointer +func (p *PRUDPPacketV0) Copy() PRUDPPacketInterface { + copied, _ := NewPRUDPPacketV0(p.sender, nil) + + copied.sourceStreamType = p.sourceStreamType + copied.sourcePort = p.sourcePort + copied.destinationStreamType = p.destinationStreamType + copied.destinationPort = p.destinationPort + copied.packetType = p.packetType + copied.flags = p.flags + copied.sessionID = p.sessionID + copied.substreamID = p.substreamID + + if p.signature != nil { + copied.signature = append([]byte(nil), p.signature...) + } + + copied.sequenceID = p.sequenceID + + if p.connectionSignature != nil { + copied.connectionSignature = append([]byte(nil), p.connectionSignature...) + } + + copied.fragmentID = p.fragmentID + + if p.payload != nil { + copied.payload = append([]byte(nil), p.payload...) + } + + if p.message != nil { + copied.message = p.message.Copy() + } + + return copied +} + // Version returns the packets PRUDP version func (p *PRUDPPacketV0) Version() int { return 0 diff --git a/prudp_packet_v1.go b/prudp_packet_v1.go index bcf60536..e46f5c56 100644 --- a/prudp_packet_v1.go +++ b/prudp_packet_v1.go @@ -21,6 +21,51 @@ type PRUDPPacketV1 struct { initialUnreliableSequenceID uint16 } +// Copy copies the packet into a new PRUDPPacketV1 +// +// Retains the same PRUDPClient pointer +func (p *PRUDPPacketV1) Copy() PRUDPPacketInterface { + copied, _ := NewPRUDPPacketV1(p.sender, nil) + + copied.sourceStreamType = p.sourceStreamType + copied.sourcePort = p.sourcePort + copied.destinationStreamType = p.destinationStreamType + copied.destinationPort = p.destinationPort + copied.packetType = p.packetType + copied.flags = p.flags + copied.sessionID = p.sessionID + copied.substreamID = p.substreamID + + if p.signature != nil { + copied.signature = append([]byte(nil), p.signature...) + } + + copied.sequenceID = p.sequenceID + + if p.connectionSignature != nil { + copied.connectionSignature = append([]byte(nil), p.connectionSignature...) + } + + copied.fragmentID = p.fragmentID + + if p.payload != nil { + copied.payload = append([]byte(nil), p.payload...) + } + + if p.message != nil { + copied.message = p.message.Copy() + } + + copied.optionsLength = p.optionsLength + copied.payloadLength = p.payloadLength + copied.minorVersion = p.minorVersion + copied.supportedFunctions = p.supportedFunctions + copied.maximumSubstreamID = p.maximumSubstreamID + copied.initialUnreliableSequenceID = p.initialUnreliableSequenceID + + return copied +} + // Version returns the packets PRUDP version func (p *PRUDPPacketV1) Version() int { return 1 diff --git a/prudp_server.go b/prudp_server.go index b624431c..0563447b 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -562,30 +562,12 @@ func (s *PRUDPServer) Send(packet PacketInterface) { } func (s *PRUDPServer) sendPacket(packet PRUDPPacketInterface) { - client := packet.Sender().(*PRUDPClient) - - // TODO - Add packet.Copy() - var packetCopy PRUDPPacketInterface - if packet.Version() == 1 { - packetCopy, _ = NewPRUDPPacketV1(client, nil) - } else { - packetCopy, _ = NewPRUDPPacketV0(client, nil) - } - - packetCopy.SetSourceStreamType(packet.SourceStreamType()) - packetCopy.SetSourcePort(packet.SourcePort()) - - packetCopy.SetDestinationStreamType(packet.DestinationStreamType()) - packetCopy.SetDestinationPort(packet.DestinationPort()) - - packetCopy.SetType(packet.Type()) - packetCopy.AddFlag(packet.Flags()) - packetCopy.SetSessionID(packet.SessionID()) - packetCopy.SetSubstreamID(packet.SubstreamID()) - packetCopy.SetSequenceID(packet.SequenceID()) - packetCopy.SetPayload(packet.Payload()) - packetCopy.setConnectionSignature(packet.getConnectionSignature()) - packetCopy.setFragmentID(packet.getFragmentID()) + // * PRUDPServer.Send will send fragments as the same packet, + // * just with different fields. In order to prevent modifying + // * multiple packets at once, due to the same pointer being + // * reused, we must make a copy of the packet being sent + packetCopy := packet.Copy() + client := packetCopy.Sender().(*PRUDPClient) if !packetCopy.HasFlag(FlagAck) && !packetCopy.HasFlag(FlagMultiAck) { if packetCopy.HasFlag(FlagReliable) { diff --git a/rmc.go b/rmc_message.go similarity index 92% rename from rmc.go rename to rmc_message.go index 80474e97..30b0eb31 100644 --- a/rmc.go +++ b/rmc_message.go @@ -16,6 +16,23 @@ type RMCMessage struct { Parameters []byte // * Input for the method } +func (rmc *RMCMessage) Copy() *RMCMessage { + copied := NewRMCMessage() + + copied.IsRequest = copied.IsRequest + copied.IsSuccess = copied.IsSuccess + copied.ProtocolID = copied.ProtocolID + copied.CallID = copied.CallID + copied.MethodID = copied.MethodID + copied.ErrorCode = copied.ErrorCode + + if rmc.Parameters != nil { + copied.Parameters = append([]byte(nil), rmc.Parameters...) + } + + return copied +} + // FromBytes decodes an RMCMessage from the given byte slice. func (rmc *RMCMessage) FromBytes(data []byte) error { stream := NewStreamIn(data, nil) From bdffdc360497f06ee3ecdb9b409deae28693e389 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Mon, 13 Nov 2023 01:21:45 -0500 Subject: [PATCH 022/178] added PID type --- client_interface.go | 4 ++-- prudp_client.go | 6 ++--- prudp_server.go | 24 +++++++++---------- stream_in.go | 34 ++++++++++++++++++++------- stream_out.go | 28 ++++++++++++++-------- types.go | 57 +++++++++++++++++++++++++++++++++++---------- 6 files changed, 105 insertions(+), 48 deletions(-) diff --git a/client_interface.go b/client_interface.go index 2f496f7b..20cdece3 100644 --- a/client_interface.go +++ b/client_interface.go @@ -6,6 +6,6 @@ import "net" type ClientInterface interface { Server() ServerInterface Address() net.Addr - PID() uint32 - SetPID(pid uint32) + PID() *PID + SetPID(pid *PID) } diff --git a/prudp_client.go b/prudp_client.go index c6c1ffab..160b1de2 100644 --- a/prudp_client.go +++ b/prudp_client.go @@ -9,7 +9,7 @@ import ( type PRUDPClient struct { address *net.UDPAddr server *PRUDPServer - pid uint32 + pid *PID clientConnectionSignature []byte serverConnectionSignature []byte sessionKey []byte @@ -74,12 +74,12 @@ func (c *PRUDPClient) Address() net.Addr { } // PID returns the clients NEX PID -func (c *PRUDPClient) PID() uint32 { +func (c *PRUDPClient) PID() *PID { return c.pid } // SetPID sets the clients NEX PID -func (c *PRUDPClient) SetPID(pid uint32) { +func (c *PRUDPClient) SetPID(pid *PID) { c.pid = pid } diff --git a/prudp_server.go b/prudp_server.go index 0563447b..da044d2e 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -405,17 +405,17 @@ func (s *PRUDPServer) handlePing(packet PRUDPPacketInterface) { } } -func (s *PRUDPServer) readKerberosTicket(payload []byte) ([]byte, uint32, uint32, error) { +func (s *PRUDPServer) readKerberosTicket(payload []byte) ([]byte, *PID, uint32, error) { stream := NewStreamIn(payload, s) ticketData, err := stream.ReadBuffer() if err != nil { - return nil, 0, 0, err + return nil, nil, 0, err } requestData, err := stream.ReadBuffer() if err != nil { - return nil, 0, 0, err + return nil, nil, 0, err } serverKey := DeriveKerberosKey(2, s.kerberosPassword) @@ -423,7 +423,7 @@ func (s *PRUDPServer) readKerberosTicket(payload []byte) ([]byte, uint32, uint32 ticket := NewKerberosTicketInternalData() err = ticket.Decrypt(NewStreamIn(ticketData, s), serverKey) if err != nil { - return nil, 0, 0, err + return nil, nil, 0, err } ticketTime := ticket.Issued.Standard() @@ -431,7 +431,7 @@ func (s *PRUDPServer) readKerberosTicket(payload []byte) ([]byte, uint32, uint32 timeLimit := ticketTime.Add(time.Minute * 2) if serverTime.After(timeLimit) { - return nil, 0, 0, errors.New("Kerberos ticket expired") + return nil, nil, 0, errors.New("Kerberos ticket expired") } sessionKey := ticket.SessionKey @@ -439,24 +439,24 @@ func (s *PRUDPServer) readKerberosTicket(payload []byte) ([]byte, uint32, uint32 decryptedRequestData, err := kerberos.Decrypt(requestData) if err != nil { - return nil, 0, 0, err + return nil, nil, 0, err } checkDataStream := NewStreamIn(decryptedRequestData, s) - userPID, err := checkDataStream.ReadUInt32LE() + userPID, err := checkDataStream.ReadPID() if err != nil { - return nil, 0, 0, err + return nil, nil, 0, err } _, err = checkDataStream.ReadUInt32LE() // * CID of secure server station url if err != nil { - return nil, 0, 0, err + return nil, nil, 0, err } responseCheck, err := checkDataStream.ReadUInt32LE() if err != nil { - return nil, 0, 0, err + return nil, nil, 0, err } return sessionKey, userPID, responseCheck, nil @@ -793,11 +793,11 @@ func (s *PRUDPServer) FindClientByConnectionID(connectedID uint32) *PRUDPClient } // FindClientByPID returns the PRUDP client connected with the given PID -func (s *PRUDPServer) FindClientByPID(pid uint32) *PRUDPClient { +func (s *PRUDPServer) FindClientByPID(pid uint64) *PRUDPClient { var client *PRUDPClient s.clients.Each(func(discriminator string, c *PRUDPClient) bool { - if c.pid == pid { + if c.pid.pid == pid { client = c return true } diff --git a/stream_in.go b/stream_in.go index ac97ed34..a95498d5 100644 --- a/stream_in.go +++ b/stream_in.go @@ -26,15 +26,6 @@ func (stream *StreamIn) ReadRemaining() []byte { return stream.ReadBytesNext(int64(stream.Remaining())) } -// ReadBool reads a bool -func (stream *StreamIn) ReadBool() (bool, error) { - if stream.Remaining() < 1 { - return false, errors.New("Not enough data to read bool") - } - - return stream.ReadByteNext() == 1, nil -} - // ReadUInt8 reads a uint8 func (stream *StreamIn) ReadUInt8() (uint8, error) { if stream.Remaining() < 1 { @@ -197,6 +188,31 @@ func (stream *StreamIn) ReadFloat64BE() (float64, error) { return stream.ReadF64BENext(1)[0], nil } +// ReadBool reads a bool +func (stream *StreamIn) ReadBool() (bool, error) { + if stream.Remaining() < 1 { + return false, errors.New("Not enough data to read bool") + } + + return stream.ReadByteNext() == 1, nil +} + +// ReadPID reads a PID. The size depends on the server type +func (stream *StreamIn) ReadPID() (*PID, error) { + if _, ok := stream.Server.(*PRUDPServer); ok { + // * Assume all UDP servers use the legacy size + if stream.Remaining() < 4 { + return nil, errors.New("Not enough data to read legacy PID") + } + + pid, _ := stream.ReadUInt32LE() + + return NewPID(pid), nil + } + + return nil, errors.New("Unknown PID size. Server type could not be determined") +} + // ReadString reads and returns a nex string type func (stream *StreamIn) ReadString() (string, error) { length, err := stream.ReadUInt16LE() diff --git a/stream_out.go b/stream_out.go index 171ce751..0391d211 100644 --- a/stream_out.go +++ b/stream_out.go @@ -12,16 +12,6 @@ type StreamOut struct { Server ServerInterface } -// WriteBool writes a bool -func (stream *StreamOut) WriteBool(b bool) { - var bVar uint8 - if b { - bVar = 1 - } - stream.Grow(1) - stream.WriteByteNext(byte(bVar)) -} - // WriteUInt8 writes a uint8 func (stream *StreamOut) WriteUInt8(u8 uint8) { stream.Grow(1) @@ -130,6 +120,24 @@ func (stream *StreamOut) WriteFloat64BE(f64 float64) { stream.WriteF64BENext([]float64{f64}) } +// WriteBool writes a bool +func (stream *StreamOut) WriteBool(b bool) { + var bVar uint8 + if b { + bVar = 1 + } + stream.Grow(1) + stream.WriteByteNext(byte(bVar)) +} + +// WritePID writes a NEX PID. The size depends on the server type +func (stream *StreamOut) WritePID(pid *PID) { + if _, ok := stream.Server.(*PRUDPServer); ok { + // * Assume all UDP servers use the legacy size + stream.WriteUInt32LE(uint32(pid.pid)) + } +} + // WriteString writes a NEX string type func (stream *StreamOut) WriteString(str string) { str = str + "\x00" diff --git a/types.go b/types.go index d08e03f1..7f86b663 100644 --- a/types.go +++ b/types.go @@ -9,6 +9,43 @@ import ( "time" ) +// PID represents a unique number to identify a user +// +// The true size of this value depends on the client version. +// Legacy clients (WiiU/3DS) use a uint32, whereas new clients (Nintendo Switch) use a uint64. +// Value is always stored as the higher uint64, the consuming API should assert accordingly +type PID struct { + pid uint64 +} + +// Value returns the numeric value of the PID as a uint64 regardless of client version +func (p *PID) Value() uint64 { + return p.pid +} + +// LegacyValue returns the numeric value of the PID as a uint32, for legacy clients +func (p *PID) LegacyValue() uint32 { + return uint32(p.pid) +} + +// NewPID returns a PID instance. The size of PID depends on the client version +func NewPID[T uint32 | uint64](pid T) *PID { + switch v := any(pid).(type) { + case uint32: + return &PID{pid: uint64(v)} + case uint64: + return &PID{pid: v} + } + + // * This will never happen because Go will + // * not compile any code where "pid" is not + // * a uint32/uint64, so it will ALWAYS get + // * caught by the above switch-case. This + // * return is only here because Go won't + // * compile without a default return + return nil +} + // StructureInterface implements all Structure methods type StructureInterface interface { SetParentType(StructureInterface) @@ -521,7 +558,7 @@ func NewDateTime(value uint64) *DateTime { } // StationURL contains the data for a NEX station URL. -// Uses uint32 pointers to check for nil, 0 is valid +// Uses pointers to check for nil, 0 is valid type StationURL struct { local bool // * Not part of the data structure. Used for easier lookups elsewhere public bool // * Not part of the data structure. Used for easier lookups elsewhere @@ -532,7 +569,7 @@ type StationURL struct { stream *uint32 sid *uint32 cid *uint32 - pid *uint32 + pid *PID transportType *uint32 rvcid *uint32 natm *uint32 @@ -601,8 +638,8 @@ func (stationURL *StationURL) SetCID(cid uint32) { } // SetPID sets the StationURL PID -func (stationURL *StationURL) SetPID(pid uint32) { - stationURL.pid = &pid +func (stationURL *StationURL) SetPID(pid *PID) { + stationURL.pid = pid } // SetType sets the StationURL transportType @@ -701,12 +738,8 @@ func (stationURL *StationURL) CID() uint32 { } // PID returns the StationURL PID value -func (stationURL *StationURL) PID() uint32 { - if stationURL.pid == nil { - return 0 - } else { - return *stationURL.pid - } +func (stationURL *StationURL) PID() *PID { + return stationURL.pid } // Type returns the StationURL type @@ -817,7 +850,7 @@ func (stationURL *StationURL) FromString(str string) { stationURL.SetCID(uint32(ui64)) case "PID": ui64, _ := strconv.ParseUint(value, 10, 32) - stationURL.SetPID(uint32(ui64)) + stationURL.SetPID(NewPID(ui64)) case "type": ui64, _ := strconv.ParseUint(value, 10, 32) stationURL.SetType(uint32(ui64)) @@ -875,7 +908,7 @@ func (stationURL *StationURL) EncodeToString() string { } if stationURL.pid != nil { - fields = append(fields, "PID="+strconv.FormatUint(uint64(stationURL.PID()), 10)) + fields = append(fields, "PID="+strconv.FormatUint(uint64(stationURL.PID().pid), 10)) } if stationURL.transportType != nil { From 59febaaf432b684e6466ec887082bc4f0d9bc8c4 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Mon, 13 Nov 2023 15:14:33 -0500 Subject: [PATCH 023/178] DateTime type methods no longer return underlying value --- test/auth.go | 3 +-- types.go | 76 ++++++++++++++++++++++++++-------------------------- 2 files changed, 39 insertions(+), 40 deletions(-) diff --git a/test/auth.go b/test/auth.go index 5fffda40..af302f36 100644 --- a/test/auth.go +++ b/test/auth.go @@ -68,8 +68,7 @@ func login(packet nex.PRUDPPacketInterface) { pConnectionData.SetStationURL("prudps:/address=192.168.1.98;port=60001;CID=1;PID=2;sid=1;stream=10;type=2") pConnectionData.SetSpecialProtocols([]byte{}) pConnectionData.SetStationURLSpecialProtocols("") - serverTime := nex.NewDateTime(0) - pConnectionData.SetTime(nex.NewDateTime(serverTime.UTC())) + pConnectionData.SetTime(nex.NewDateTime(0).UTC()) responseStream := nex.NewStreamOut(authServer) diff --git a/types.go b/types.go index 7f86b663..544a5e7e 100644 --- a/types.go +++ b/types.go @@ -446,14 +446,14 @@ type DateTime struct { } // Make initilizes a DateTime with the input data -func (datetime *DateTime) Make(year, month, day, hour, minute, second int) uint64 { - datetime.value = uint64(second | (minute << 6) | (hour << 12) | (day << 17) | (month << 22) | (year << 26)) +func (dt *DateTime) Make(year, month, day, hour, minute, second int) *DateTime { + dt.value = uint64(second | (minute << 6) | (hour << 12) | (day << 17) | (month << 22) | (year << 26)) - return datetime.value + return dt } // FromTimestamp converts a Time timestamp into a NEX DateTime -func (datetime *DateTime) FromTimestamp(timestamp time.Time) uint64 { +func (dt *DateTime) FromTimestamp(timestamp time.Time) *DateTime { year := timestamp.Year() month := int(timestamp.Month()) day := timestamp.Day() @@ -461,92 +461,92 @@ func (datetime *DateTime) FromTimestamp(timestamp time.Time) uint64 { minute := timestamp.Minute() second := timestamp.Second() - return datetime.Make(year, month, day, hour, minute, second) + return dt.Make(year, month, day, hour, minute, second) } // Now converts the current Time timestamp to a NEX DateTime -func (datetime *DateTime) Now() uint64 { - return datetime.FromTimestamp(time.Now()) +func (dt *DateTime) Now() *DateTime { + return dt.FromTimestamp(time.Now()) } // UTC returns a NEX DateTime value of the current UTC time -func (datetime *DateTime) UTC() uint64 { - return datetime.FromTimestamp(time.Now().UTC()) +func (dt *DateTime) UTC() *DateTime { + return dt.FromTimestamp(time.Now().UTC()) } // Value returns the stored DateTime time -func (datetime *DateTime) Value() uint64 { - return datetime.value +func (dt *DateTime) Value() uint64 { + return dt.value } // Second returns the seconds value stored in the DateTime -func (datetime *DateTime) Second() int { - return int(datetime.value & 63) +func (dt *DateTime) Second() int { + return int(dt.value & 63) } // Minute returns the minutes value stored in the DateTime -func (datetime *DateTime) Minute() int { - return int((datetime.value >> 6) & 63) +func (dt *DateTime) Minute() int { + return int((dt.value >> 6) & 63) } // Hour returns the hours value stored in the DateTime -func (datetime *DateTime) Hour() int { - return int((datetime.value >> 12) & 31) +func (dt *DateTime) Hour() int { + return int((dt.value >> 12) & 31) } // Day returns the day value stored in the DateTime -func (datetime *DateTime) Day() int { - return int((datetime.value >> 17) & 31) +func (dt *DateTime) Day() int { + return int((dt.value >> 17) & 31) } // Month returns the month value stored in the DateTime -func (datetime *DateTime) Month() time.Month { - return time.Month((datetime.value >> 22) & 15) +func (dt *DateTime) Month() time.Month { + return time.Month((dt.value >> 22) & 15) } // Year returns the year value stored in the DateTime -func (datetime *DateTime) Year() int { - return int(datetime.value >> 26) +func (dt *DateTime) Year() int { + return int(dt.value >> 26) } // Standard returns the DateTime as a standard time.Time -func (datetime *DateTime) Standard() time.Time { +func (dt *DateTime) Standard() time.Time { return time.Date( - datetime.Year(), - datetime.Month(), - datetime.Day(), - datetime.Hour(), - datetime.Minute(), - datetime.Second(), + dt.Year(), + dt.Month(), + dt.Day(), + dt.Hour(), + dt.Minute(), + dt.Second(), 0, time.UTC, ) } // Copy returns a new copied instance of DateTime -func (datetime *DateTime) Copy() *DateTime { - return NewDateTime(datetime.value) +func (dt *DateTime) Copy() *DateTime { + return NewDateTime(dt.value) } // Equals checks if the passed Structure contains the same data as the current instance -func (datetime *DateTime) Equals(other *DateTime) bool { - return datetime.value == other.value +func (dt *DateTime) Equals(other *DateTime) bool { + return dt.value == other.value } // String returns a string representation of the struct -func (datetime *DateTime) String() string { - return datetime.FormatToString(0) +func (dt *DateTime) String() string { + return dt.FormatToString(0) } // FormatToString pretty-prints the struct data using the provided indentation level -func (datetime *DateTime) FormatToString(indentationLevel int) string { +func (dt *DateTime) FormatToString(indentationLevel int) string { indentationValues := strings.Repeat("\t", indentationLevel+1) indentationEnd := strings.Repeat("\t", indentationLevel) var b strings.Builder b.WriteString("DateTime{\n") - b.WriteString(fmt.Sprintf("%svalue: %d (%s)\n", indentationValues, datetime.value, datetime.Standard().Format("2006-01-02 15:04:05"))) + b.WriteString(fmt.Sprintf("%svalue: %d (%s)\n", indentationValues, dt.value, dt.Standard().Format("2006-01-02 15:04:05"))) b.WriteString(fmt.Sprintf("%s}", indentationEnd)) return b.String() From 4405b7e485d08ff435b4ea70406cc0924fb50c96 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Mon, 13 Nov 2023 15:16:33 -0500 Subject: [PATCH 024/178] removed DateTime UTC() method. Now() now always returns UTC --- test/auth.go | 2 +- test/generate_ticket.go | 3 +-- types.go | 7 +------ 3 files changed, 3 insertions(+), 9 deletions(-) diff --git a/test/auth.go b/test/auth.go index af302f36..468040bc 100644 --- a/test/auth.go +++ b/test/auth.go @@ -68,7 +68,7 @@ func login(packet nex.PRUDPPacketInterface) { pConnectionData.SetStationURL("prudps:/address=192.168.1.98;port=60001;CID=1;PID=2;sid=1;stream=10;type=2") pConnectionData.SetSpecialProtocols([]byte{}) pConnectionData.SetStationURLSpecialProtocols("") - pConnectionData.SetTime(nex.NewDateTime(0).UTC()) + pConnectionData.SetTime(nex.NewDateTime(0).Now()) responseStream := nex.NewStreamOut(authServer) diff --git a/test/generate_ticket.go b/test/generate_ticket.go index 95951e7f..19ed91ff 100644 --- a/test/generate_ticket.go +++ b/test/generate_ticket.go @@ -14,8 +14,7 @@ func generateTicket(userPID uint32, targetPID uint32) []byte { rand.Read(sessionKey) ticketInternalData := nex.NewKerberosTicketInternalData() - serverTime := nex.NewDateTime(0) - serverTime.UTC() + serverTime := nex.NewDateTime(0).Now() ticketInternalData.Issued = serverTime ticketInternalData.SourcePID = userPID diff --git a/types.go b/types.go index 544a5e7e..1a241654 100644 --- a/types.go +++ b/types.go @@ -464,13 +464,8 @@ func (dt *DateTime) FromTimestamp(timestamp time.Time) *DateTime { return dt.Make(year, month, day, hour, minute, second) } -// Now converts the current Time timestamp to a NEX DateTime +// Now returns a NEX DateTime value of the current UTC time func (dt *DateTime) Now() *DateTime { - return dt.FromTimestamp(time.Now()) -} - -// UTC returns a NEX DateTime value of the current UTC time -func (dt *DateTime) UTC() *DateTime { return dt.FromTimestamp(time.Now().UTC()) } From 825f13b3531f98087791fbfe40eb15c75d8bce3d Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Mon, 13 Nov 2023 19:52:31 -0500 Subject: [PATCH 025/178] added PID list methods to streams --- stream_in.go | 21 +++++++++++++++++++++ stream_out.go | 11 +++++++++++ 2 files changed, 32 insertions(+) diff --git a/stream_in.go b/stream_in.go index a95498d5..3c576f05 100644 --- a/stream_in.go +++ b/stream_in.go @@ -849,6 +849,27 @@ func (stream *StreamIn) ReadListFloat64BE() ([]float64, error) { return list, nil } +// ReadListPID reads a list of NEX PIDs +func (stream *StreamIn) ReadListPID() ([]*PID, error) { + length, err := stream.ReadUInt32LE() + if err != nil { + return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) + } + + list := make([]*PID, 0, length) + + for i := 0; i < int(length); i++ { + value, err := stream.ReadPID() + if err != nil { + return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) + } + + list = append(list, value) + } + + return list, nil +} + // ReadListString reads a list of NEX String types func (stream *StreamIn) ReadListString() ([]string, error) { length, err := stream.ReadUInt32LE() diff --git a/stream_out.go b/stream_out.go index 0391d211..e2a2c4e5 100644 --- a/stream_out.go +++ b/stream_out.go @@ -380,6 +380,17 @@ func (stream *StreamOut) WriteListStructure(structures interface{}) { } } +// WriteListPID writes a list of NEX PIDs +func (stream *StreamOut) WriteListPID(pids []*PID) { + length := len(pids) + + stream.WriteUInt32LE(uint32(length)) + + for i := 0; i < length; i++ { + stream.WritePID(pids[i]) + } +} + // WriteListString writes a list of NEX String types func (stream *StreamOut) WriteListString(strings []string) { length := len(strings) From a647b16c25fa1630bb29bb781fa1650630d647ef Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Mon, 13 Nov 2023 19:55:27 -0500 Subject: [PATCH 026/178] streams now check server version for PID size --- stream_in.go | 15 ++++++++++----- stream_out.go | 7 ++++--- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/stream_in.go b/stream_in.go index 3c576f05..46c8f0db 100644 --- a/stream_in.go +++ b/stream_in.go @@ -197,10 +197,17 @@ func (stream *StreamIn) ReadBool() (bool, error) { return stream.ReadByteNext() == 1, nil } -// ReadPID reads a PID. The size depends on the server type +// ReadPID reads a PID. The size depends on the server version func (stream *StreamIn) ReadPID() (*PID, error) { - if _, ok := stream.Server.(*PRUDPServer); ok { - // * Assume all UDP servers use the legacy size + if stream.Server.LibraryVersion().GreaterOrEqual("4.0.0") { + if stream.Remaining() < 8 { + return nil, errors.New("Not enough data to read PID") + } + + pid, _ := stream.ReadUInt64LE() + + return NewPID(pid), nil + } else { if stream.Remaining() < 4 { return nil, errors.New("Not enough data to read legacy PID") } @@ -209,8 +216,6 @@ func (stream *StreamIn) ReadPID() (*PID, error) { return NewPID(pid), nil } - - return nil, errors.New("Unknown PID size. Server type could not be determined") } // ReadString reads and returns a nex string type diff --git a/stream_out.go b/stream_out.go index e2a2c4e5..0dd364f0 100644 --- a/stream_out.go +++ b/stream_out.go @@ -130,10 +130,11 @@ func (stream *StreamOut) WriteBool(b bool) { stream.WriteByteNext(byte(bVar)) } -// WritePID writes a NEX PID. The size depends on the server type +// WritePID writes a NEX PID. The size depends on the server version func (stream *StreamOut) WritePID(pid *PID) { - if _, ok := stream.Server.(*PRUDPServer); ok { - // * Assume all UDP servers use the legacy size + if stream.Server.LibraryVersion().GreaterOrEqual("4.0.0") { + stream.WriteUInt64LE(pid.pid) + } else { stream.WriteUInt32LE(uint32(pid.pid)) } } From 234b0a9f8c8d04c6b3cd947ec5d436dd87823fa0 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Mon, 13 Nov 2023 20:08:43 -0500 Subject: [PATCH 027/178] added Copy and string methods to PID type --- types.go | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/types.go b/types.go index 1a241654..e18141b7 100644 --- a/types.go +++ b/types.go @@ -28,6 +28,37 @@ func (p *PID) LegacyValue() uint32 { return uint32(p.pid) } +// Copy returns a copy of the current PID +func (p *PID) Copy() *PID { + return NewPID(p.pid) +} + +// String returns a string representation of the struct +func (p *PID) String() string { + return p.FormatToString(0) +} + +// FormatToString pretty-prints the struct data using the provided indentation level +func (p *PID) FormatToString(indentationLevel int) string { + indentationValues := strings.Repeat("\t", indentationLevel+1) + indentationEnd := strings.Repeat("\t", indentationLevel) + + var b strings.Builder + + b.WriteString("PID{\n") + + switch v := any(p.pid).(type) { + case uint32: + b.WriteString(fmt.Sprintf("%spid: %d (legacy)\n", indentationValues, v)) + case uint64: + b.WriteString(fmt.Sprintf("%spid: %d (modern)\n", indentationValues, v)) + } + + b.WriteString(fmt.Sprintf("%s}", indentationEnd)) + + return b.String() +} + // NewPID returns a PID instance. The size of PID depends on the client version func NewPID[T uint32 | uint64](pid T) *PID { switch v := any(pid).(type) { From e77c51991ca98191f3b400f2936b3732cf2d29bc Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Tue, 14 Nov 2023 01:45:50 -0500 Subject: [PATCH 028/178] added Equals method to PID type --- types.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/types.go b/types.go index e18141b7..4a3fd10f 100644 --- a/types.go +++ b/types.go @@ -28,6 +28,11 @@ func (p *PID) LegacyValue() uint32 { return uint32(p.pid) } +// Equals checks if the two structs are equal +func (p *PID) Equals(other *PID) bool { + return p.pid == other.pid +} + // Copy returns a copy of the current PID func (p *PID) Copy() *PID { return NewPID(p.pid) From 8c35f5e1a12444db090e53bef8edef96a168a9c3 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Tue, 14 Nov 2023 22:45:36 -0500 Subject: [PATCH 029/178] prudp: use PID type in kerberos and PasswordFromPID --- kerberos.go | 14 +++++++------- prudp_server.go | 4 ++-- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/kerberos.go b/kerberos.go index f76fdc77..4671d685 100644 --- a/kerberos.go +++ b/kerberos.go @@ -67,7 +67,7 @@ func NewKerberosEncryption(key []byte) *KerberosEncryption { // KerberosTicket represents a ticket granting a user access to a secure server type KerberosTicket struct { SessionKey []byte - TargetPID uint32 + TargetPID *PID InternalData []byte } @@ -78,7 +78,7 @@ func (kt *KerberosTicket) Encrypt(key []byte, stream *StreamOut) ([]byte, error) stream.Grow(int64(len(kt.SessionKey))) stream.WriteBytesNext(kt.SessionKey) - stream.WriteUInt32LE(kt.TargetPID) + stream.WritePID(kt.TargetPID) stream.WriteBuffer(kt.InternalData) return encryption.Encrypt(stream.Bytes()), nil @@ -92,14 +92,14 @@ func NewKerberosTicket() *KerberosTicket { // KerberosTicketInternalData holds the internal data for a kerberos ticket to be processed by the server type KerberosTicketInternalData struct { Issued *DateTime - SourcePID uint32 + SourcePID *PID SessionKey []byte } // Encrypt writes the ticket data to the provided stream and returns the encrypted byte slice func (ti *KerberosTicketInternalData) Encrypt(key []byte, stream *StreamOut) ([]byte, error) { stream.WriteDateTime(ti.Issued) - stream.WriteUInt32LE(ti.SourcePID) + stream.WritePID(ti.SourcePID) stream.Grow(int64(len(ti.SessionKey))) stream.WriteBytesNext(ti.SessionKey) @@ -166,7 +166,7 @@ func (ti *KerberosTicketInternalData) Decrypt(stream *StreamIn, key []byte) erro return fmt.Errorf("Failed to read Kerberos ticket internal data timestamp %s", err.Error()) } - userPID, err := stream.ReadUInt32LE() + userPID, err := stream.ReadPID() if err != nil { return fmt.Errorf("Failed to read Kerberos ticket internal data user PID %s", err.Error()) } @@ -184,10 +184,10 @@ func NewKerberosTicketInternalData() *KerberosTicketInternalData { } // DeriveKerberosKey derives a users kerberos encryption key based on their PID and password -func DeriveKerberosKey(pid uint32, password []byte) []byte { +func DeriveKerberosKey(pid *PID, password []byte) []byte { key := password - for i := 0; i < 65000+int(pid)%1024; i++ { + for i := 0; i < 65000+int(pid.Value())%1024; i++ { hash := md5.Sum(key) key = hash[:] } diff --git a/prudp_server.go b/prudp_server.go index da044d2e..12a2751e 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -37,7 +37,7 @@ type PRUDPServer struct { clientRemovedEventHandlers []func(client *PRUDPClient) connectionIDCounter *Counter[uint32] pingTimeout time.Duration - PasswordFromPID func(pid uint32) (string, uint32) + PasswordFromPID func(pid *PID) (string, uint32) PRUDPv1ConnectionSignatureKey []byte } @@ -418,7 +418,7 @@ func (s *PRUDPServer) readKerberosTicket(payload []byte) ([]byte, *PID, uint32, return nil, nil, 0, err } - serverKey := DeriveKerberosKey(2, s.kerberosPassword) + serverKey := DeriveKerberosKey(NewPID[uint64](2), s.kerberosPassword) ticket := NewKerberosTicketInternalData() err = ticket.Decrypt(NewStreamIn(ticketData, s), serverKey) From d901566ae5312028a1338f508ea022ac694ca2c6 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Fri, 17 Nov 2023 08:31:06 -0500 Subject: [PATCH 030/178] prudp: PRUDPClient now initializes with PID 0 --- prudp_client.go | 1 + 1 file changed, 1 insertion(+) diff --git a/prudp_client.go b/prudp_client.go index 160b1de2..452a5b9d 100644 --- a/prudp_client.go +++ b/prudp_client.go @@ -180,5 +180,6 @@ func NewPRUDPClient(address *net.UDPAddr, server *PRUDPServer) *PRUDPClient { address: address, server: server, outgoingPingSequenceIDCounter: NewCounter[uint16](0), + pid: NewPID[uint32](0), } } From 029a8db2f8b09e92d40f3929f7ab47cf17123f0d Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sat, 18 Nov 2023 13:15:12 -0500 Subject: [PATCH 031/178] removed Structure methods from StreamIn. Moved to generics --- stream_in.go | 150 +++++++++++++++++++++++++-------------------------- types.go | 2 +- 2 files changed, 76 insertions(+), 76 deletions(-) diff --git a/stream_in.go b/stream_in.go index 46c8f0db..d5f37954 100644 --- a/stream_in.go +++ b/stream_in.go @@ -3,7 +3,6 @@ package nex import ( "errors" "fmt" - "reflect" "strings" crunch "github.com/superwhiskers/crunch/v3" @@ -267,49 +266,6 @@ func (stream *StreamIn) ReadQBuffer() ([]byte, error) { return data, nil } -// ReadStructure reads a nex Structure type -func (stream *StreamIn) ReadStructure(structure StructureInterface) (StructureInterface, error) { - if structure.ParentType() != nil { - _, err := stream.ReadStructure(structure.ParentType()) - if err != nil { - return nil, fmt.Errorf("Failed to read structure parent. %s", err.Error()) - } - } - - var useStructures bool - switch server := stream.Server.(type) { - case *PRUDPServer: // * Support QRV versions - useStructures = server.ProtocolMinorVersion() >= 3 - default: - useStructures = server.LibraryVersion().GreaterOrEqual("3.5.0") - } - - if useStructures { - version, err := stream.ReadUInt8() - if err != nil { - return nil, fmt.Errorf("Failed to read NEX Structure version. %s", err.Error()) - } - - structureLength, err := stream.ReadUInt32LE() - if err != nil { - return nil, fmt.Errorf("Failed to read NEX Structure content length. %s", err.Error()) - } - - if stream.Remaining() < int(structureLength) { - return nil, errors.New("NEX Structure content length longer than data size") - } - - structure.SetStructureVersion(version) - } - - err := structure.ExtractFromStream(stream) - if err != nil { - return nil, fmt.Errorf("Failed to read structure from stream. %s", err.Error()) - } - - return structure, nil -} - // ReadVariant reads a Variant type. This type can hold 7 different types func (stream *StreamIn) ReadVariant() (*Variant, error) { variant := NewVariant() @@ -324,13 +280,10 @@ func (stream *StreamIn) ReadVariant() (*Variant, error) { // ReadMap reads a Map type with the given key and value types func (stream *StreamIn) ReadMap(keyFunction interface{}, valueFunction interface{}) (map[interface{}]interface{}, error) { - /* - TODO: Make this not suck - - Map types can have any type as the key and any type as the value - Due to strict typing we cannot just pass stream functions as these values and call them - At the moment this just reads what type you want from the interface{} function type - */ + // TODO - Make this not suck + // * Map types can have any type as the key and any type as the value + // * Due to strict typing we cannot just pass stream functions as these values and call them + // * At the moment this just reads what type you want from the interface{} function type length, err := stream.ReadUInt32LE() if err != nil { @@ -959,30 +912,6 @@ func (stream *StreamIn) ReadListStationURL() ([]*StationURL, error) { return list, nil } -// ReadListStructure reads and returns a list structure types -func (stream *StreamIn) ReadListStructure(structure StructureInterface) (interface{}, error) { - length, err := stream.ReadUInt32LE() - if err != nil { - return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) - } - - structureType := reflect.TypeOf(structure) - structureSlice := reflect.MakeSlice(reflect.SliceOf(structureType), 0, int(length)) - - for i := 0; i < int(length); i++ { - newStructure := structure.Copy() - - extractedStructure, err := stream.ReadStructure(newStructure) - if err != nil { - return nil, err - } - - structureSlice = reflect.Append(structureSlice, reflect.ValueOf(extractedStructure)) - } - - return structureSlice.Interface(), nil -} - // ReadListDataHolder reads a list of NEX DataHolder types func (stream *StreamIn) ReadListDataHolder() ([]*DataHolder, error) { length, err := stream.ReadUInt32LE() @@ -1011,3 +940,74 @@ func NewStreamIn(data []byte, server ServerInterface) *StreamIn { Server: server, } } + +// StreamReadStructure reads a Structure type from a StreamIn +// +// Implemented as a separate function to utilize generics +func StreamReadStructure[T StructureInterface](stream *StreamIn, structure T) (T, error) { + if structure.ParentType() != nil { + //_, err := stream.ReadStructure(structure.ParentType()) + _, err := StreamReadStructure(stream, structure.ParentType()) + if err != nil { + return structure, fmt.Errorf("Failed to read structure parent. %s", err.Error()) + } + } + + var useStructureHeader bool + switch server := stream.Server.(type) { + case *PRUDPServer: // * Support QRV versions + useStructureHeader = server.ProtocolMinorVersion() >= 3 + default: + useStructureHeader = server.LibraryVersion().GreaterOrEqual("3.5.0") + } + + if useStructureHeader { + version, err := stream.ReadUInt8() + if err != nil { + return structure, fmt.Errorf("Failed to read NEX Structure version. %s", err.Error()) + } + + structureLength, err := stream.ReadUInt32LE() + if err != nil { + return structure, fmt.Errorf("Failed to read NEX Structure content length. %s", err.Error()) + } + + if stream.Remaining() < int(structureLength) { + return structure, errors.New("NEX Structure content length longer than data size") + } + + structure.SetStructureVersion(version) + } + + err := structure.ExtractFromStream(stream) + if err != nil { + return structure, fmt.Errorf("Failed to read structure from stream. %s", err.Error()) + } + + return structure, nil +} + +// StreamReadListStructure reads and returns a list structure types from a StreamIn +// +// Implemented as a separate function to utilize generics +func StreamReadListStructure[T StructureInterface](stream *StreamIn, structure T) ([]T, error) { + length, err := stream.ReadUInt32LE() + if err != nil { + return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) + } + + structures := make([]T, 0, int(length)) + + for i := 0; i < int(length); i++ { + newStructure := structure.Copy() + + extracted, err := StreamReadStructure[T](stream, newStructure.(T)) + if err != nil { + return nil, err + } + + structures = append(structures, extracted) + } + + return structures, nil +} diff --git a/types.go b/types.go index 4a3fd10f..0fc41833 100644 --- a/types.go +++ b/types.go @@ -241,7 +241,7 @@ func (dataHolder *DataHolder) ExtractFromStream(stream *StreamIn) error { newObjectInstance := dataType.Copy() - dataHolder.objectData, err = stream.ReadStructure(newObjectInstance) + dataHolder.objectData, err = StreamReadStructure(stream, newObjectInstance) if err != nil { return fmt.Errorf("Failed to read DataHolder object data. %s", err.Error()) } From 8207ccdf4cabd9ea1011f37959d14d5619bd6b06 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sat, 18 Nov 2023 14:21:53 -0500 Subject: [PATCH 032/178] removed ReadMap from StreamIn. Moved to generics --- stream_in.go | 76 +++++++++++++++++++--------------------------------- 1 file changed, 28 insertions(+), 48 deletions(-) diff --git a/stream_in.go b/stream_in.go index d5f37954..de130d0b 100644 --- a/stream_in.go +++ b/stream_in.go @@ -278,54 +278,6 @@ func (stream *StreamIn) ReadVariant() (*Variant, error) { return variant, nil } -// ReadMap reads a Map type with the given key and value types -func (stream *StreamIn) ReadMap(keyFunction interface{}, valueFunction interface{}) (map[interface{}]interface{}, error) { - // TODO - Make this not suck - // * Map types can have any type as the key and any type as the value - // * Due to strict typing we cannot just pass stream functions as these values and call them - // * At the moment this just reads what type you want from the interface{} function type - - length, err := stream.ReadUInt32LE() - if err != nil { - return nil, fmt.Errorf("Failed to read Map length. %s", err.Error()) - } - - typeReader := func(function interface{}) (interface{}, error) { - var value interface{} - var err error - - switch function.(type) { - case func() (string, error): - value, err = stream.ReadString() - case func() (*Variant, error): - value, err = stream.ReadVariant() - default: - value = nil - err = errors.New("Unsupported type in ReadMap") - } - - return value, err - } - - newMap := make(map[interface{}]interface{}) - - for i := 0; i < int(length); i++ { - key, err := typeReader(keyFunction) - if err != nil { - return nil, fmt.Errorf("Failed to read Map key. %s", err.Error()) - } - - value, err := typeReader(valueFunction) - if err != nil { - return nil, fmt.Errorf("Failed to read Map value. %s", err.Error()) - } - - newMap[key] = value - } - - return newMap, nil -} - // ReadDateTime reads a DateTime type func (stream *StreamIn) ReadDateTime() (*DateTime, error) { value, err := stream.ReadUInt64LE() @@ -1011,3 +963,31 @@ func StreamReadListStructure[T StructureInterface](stream *StreamIn, structure T return structures, nil } + +// StreamReadMap reads a Map type with the given key and value types from a StreamIn +// +// Implemented as a separate function to utilize generics +func StreamReadMap[K comparable, V any](stream *StreamIn, keyReader func() (K, error), valueReader func() (V, error)) (map[K]V, error) { + length, err := stream.ReadUInt32LE() + if err != nil { + return nil, fmt.Errorf("Failed to read Map length. %s", err.Error()) + } + + m := make(map[K]V) + + for i := 0; i < int(length); i++ { + key, err := keyReader() + if err != nil { + return nil, err + } + + value, err := valueReader() + if err != nil { + return nil, err + } + + m[key] = value + } + + return m, nil +} From 226d1048552925c6a42d3160f8821fa01e59f6ba Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sat, 18 Nov 2023 15:05:43 -0500 Subject: [PATCH 033/178] removed reflect from StreamOut as well --- stream_in.go | 2 +- stream_out.go | 65 +++++++++++++++++++++++++++++++++++++-------------- 2 files changed, 49 insertions(+), 18 deletions(-) diff --git a/stream_in.go b/stream_in.go index de130d0b..7de138a4 100644 --- a/stream_in.go +++ b/stream_in.go @@ -939,7 +939,7 @@ func StreamReadStructure[T StructureInterface](stream *StreamIn, structure T) (T return structure, nil } -// StreamReadListStructure reads and returns a list structure types from a StreamIn +// StreamReadListStructure reads and returns a list of structure types from a StreamIn // // Implemented as a separate function to utilize generics func StreamReadListStructure[T StructureInterface](stream *StreamIn, structure T) ([]T, error) { diff --git a/stream_out.go b/stream_out.go index 0dd364f0..cef0c2f6 100644 --- a/stream_out.go +++ b/stream_out.go @@ -1,7 +1,7 @@ package nex import ( - "reflect" + "fmt" crunch "github.com/superwhiskers/crunch/v3" ) @@ -365,22 +365,6 @@ func (stream *StreamOut) WriteListFloat64BE(list []float64) { } } -// WriteListStructure writes a list of NEX Structure types -func (stream *StreamOut) WriteListStructure(structures interface{}) { - // TODO: - // Find a better solution that doesn't use reflect - - slice := reflect.ValueOf(structures) - count := slice.Len() - - stream.WriteUInt32LE(uint32(count)) - - for i := 0; i < count; i++ { - structure := slice.Index(i).Interface().(StructureInterface) - stream.WriteStructure(structure) - } -} - // WriteListPID writes a list of NEX PIDs func (stream *StreamOut) WriteListPID(pids []*PID) { length := len(pids) @@ -477,6 +461,7 @@ func (stream *StreamOut) WriteVariant(variant *Variant) { stream.WriteBytesNext(content) } +/* // WriteMap writes a Map type with the given key and value types func (stream *StreamOut) WriteMap(mapType interface{}) { // TODO: @@ -504,6 +489,7 @@ func (stream *StreamOut) WriteMap(mapType interface{}) { } } } +*/ // NewStreamOut returns a new nex output stream func NewStreamOut(server ServerInterface) *StreamOut { @@ -512,3 +498,48 @@ func NewStreamOut(server ServerInterface) *StreamOut { Server: server, } } + +// StreamWriteListStructure writes a list of structure types to a StreamOut +// +// Implemented as a separate function to utilize generics +func StreamWriteListStructure[T StructureInterface](stream *StreamOut, structures []T) { + count := len(structures) + + stream.WriteUInt32LE(uint32(count)) + + for i := 0; i < count; i++ { + stream.WriteStructure(structures[i]) + } +} + +func mapTypeWriter[T any](stream *StreamOut, t T) { + // * Map types in NEX can have any type for the + // * key and value. So we need to just check the + // * type each time and call the right function + switch v := any(t).(type) { + case string: + stream.WriteString(v) + case *Variant: + stream.WriteVariant(v) + default: + // * Writer functions don't return errors so just log here. + // * The client will disconnect but the server won't die, + // * that way other clients stay connected, but we still + // * have a log of what the error was + fmt.Printf("Unsupported Map type trying to be written: %T\n", v) + } +} + +// StreamWriteMap writes a Map type to a StreamOut +// +// Implemented as a separate function to utilize generics +func StreamWriteMap[K comparable, V any](stream *StreamOut, m map[K]V) { + count := len(m) + + stream.WriteUInt32LE(uint32(count)) + + for key, value := range m { + mapTypeWriter(stream, key) + mapTypeWriter(stream, value) + } +} From abb84681093dfadfaf38be946855f782a0638829 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20L=C3=B3pez=20Guimaraes?= Date: Sat, 18 Nov 2023 21:43:08 +0000 Subject: [PATCH 034/178] test: Update new functions --- test/auth.go | 10 +++++----- test/generate_ticket.go | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/test/auth.go b/test/auth.go index 468040bc..1100c0cf 100644 --- a/test/auth.go +++ b/test/auth.go @@ -60,8 +60,8 @@ func login(packet nex.PRUDPPacketInterface) { } retval := nex.NewResultSuccess(0x00010001) - pidPrincipal := uint32(converted) - pbufResponse := generateTicket(pidPrincipal, 2) + pidPrincipal := nex.NewPID(uint32(converted)) + pbufResponse := generateTicket(pidPrincipal, nex.NewPID[uint32](2)) pConnectionData := nex.NewRVConnectionData() strReturnMsg := "Test Build" @@ -73,7 +73,7 @@ func login(packet nex.PRUDPPacketInterface) { responseStream := nex.NewStreamOut(authServer) responseStream.WriteResult(retval) - responseStream.WriteUInt32LE(pidPrincipal) + responseStream.WritePID(pidPrincipal) responseStream.WriteBuffer(pbufResponse) responseStream.WriteStructure(pConnectionData) responseStream.WriteString(strReturnMsg) @@ -110,12 +110,12 @@ func requestTicket(packet nex.PRUDPPacketInterface) { parametersStream := nex.NewStreamIn(parameters, authServer) - idSource, err := parametersStream.ReadUInt32LE() + idSource, err := parametersStream.ReadPID() if err != nil { panic(err) } - idTarget, err := parametersStream.ReadUInt32LE() + idTarget, err := parametersStream.ReadPID() if err != nil { panic(err) } diff --git a/test/generate_ticket.go b/test/generate_ticket.go index 19ed91ff..6ceb2839 100644 --- a/test/generate_ticket.go +++ b/test/generate_ticket.go @@ -6,7 +6,7 @@ import ( "github.com/PretendoNetwork/nex-go" ) -func generateTicket(userPID uint32, targetPID uint32) []byte { +func generateTicket(userPID *nex.PID, targetPID *nex.PID) []byte { userKey := nex.DeriveKerberosKey(userPID, []byte("abcdefghijklmnop")) targetKey := nex.DeriveKerberosKey(targetPID, []byte("password")) sessionKey := make([]byte, authServer.KerberosKeySize()) From bb7c025cd89fe1c858785b5914c1cc6b9635c31f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20L=C3=B3pez=20Guimaraes?= Date: Sat, 18 Nov 2023 22:26:52 +0000 Subject: [PATCH 035/178] rmc: Minor fixes --- rmc_message.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/rmc_message.go b/rmc_message.go index 30b0eb31..75d303f5 100644 --- a/rmc_message.go +++ b/rmc_message.go @@ -19,12 +19,12 @@ type RMCMessage struct { func (rmc *RMCMessage) Copy() *RMCMessage { copied := NewRMCMessage() - copied.IsRequest = copied.IsRequest - copied.IsSuccess = copied.IsSuccess - copied.ProtocolID = copied.ProtocolID - copied.CallID = copied.CallID - copied.MethodID = copied.MethodID - copied.ErrorCode = copied.ErrorCode + copied.IsRequest = rmc.IsRequest + copied.IsSuccess = rmc.IsSuccess + copied.ProtocolID = rmc.ProtocolID + copied.CallID = rmc.CallID + copied.MethodID = rmc.MethodID + copied.ErrorCode = rmc.ErrorCode if rmc.Parameters != nil { copied.Parameters = append([]byte(nil), rmc.Parameters...) @@ -178,8 +178,8 @@ func NewRMCMessage() *RMCMessage { } // NewRMCRequest returns a new blank RMCRequest -func NewRMCRequest() RMCMessage { - return RMCMessage{IsRequest: true} +func NewRMCRequest() *RMCMessage { + return &RMCMessage{IsRequest: true} } // NewRMCSuccess returns a new RMC Message configured as a success response From cc34633ae6f1cc10b5fc405aab47338db4b6297d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20L=C3=B3pez=20Guimaraes?= Date: Sat, 18 Nov 2023 22:28:54 +0000 Subject: [PATCH 036/178] hpp: Initial support We aren't logging errors at the moment. --- hpp_client.go | 38 ++++++++ hpp_packet.go | 156 ++++++++++++++++++++++++++++++++ hpp_server.go | 240 ++++++++++++++++++++++++++++++++++++++++++++++++- rmc_message.go | 14 ++- test/hpp.go | 113 +++++++++++++++++++++++ test/main.go | 1 + 6 files changed, 556 insertions(+), 6 deletions(-) create mode 100644 hpp_client.go create mode 100644 hpp_packet.go create mode 100644 test/hpp.go diff --git a/hpp_client.go b/hpp_client.go new file mode 100644 index 00000000..9f47432f --- /dev/null +++ b/hpp_client.go @@ -0,0 +1,38 @@ +package nex + +import "net" + +// HPPClient represents a single HPP client +type HPPClient struct { + address *net.TCPAddr + server *HPPServer + pid *PID +} + +// Server returns the server the client is connecting to +func (c *HPPClient) Server() ServerInterface { + return c.server +} + +// Address returns the clients address as a net.Addr +func (c *HPPClient) Address() net.Addr { + return c.address +} + +// PID returns the clients NEX PID +func (c *HPPClient) PID() *PID { + return c.pid +} + +// SetPID sets the clients NEX PID +func (c *HPPClient) SetPID(pid *PID) { + c.pid = pid +} + +// NewHPPClient creates and returns a new Client using the provided IP address and server +func NewHPPClient(address *net.TCPAddr, server *HPPServer) *HPPClient { + return &HPPClient{ + address: address, + server: server, + } +} diff --git a/hpp_packet.go b/hpp_packet.go new file mode 100644 index 00000000..4c9e95fa --- /dev/null +++ b/hpp_packet.go @@ -0,0 +1,156 @@ +package nex + +import ( + "bytes" + "crypto/hmac" + "crypto/md5" + "encoding/hex" + "errors" + "fmt" +) + +// HPPPacket holds all the data about an HPP request +type HPPPacket struct { + sender *HPPClient + accessKeySignature []byte + passwordSignature []byte + payload []byte + message *RMCMessage + processed chan bool +} + +// Sender returns the Client who sent the packet +func (p *HPPPacket) Sender() ClientInterface { + return p.sender +} + +// Payload returns the packets payload +func (p *HPPPacket) Payload() []byte { + return p.payload +} + +// SetPayload sets the packets payload +func (p *HPPPacket) SetPayload(payload []byte) { + p.payload = payload +} + +func (p *HPPPacket) validateAccessKeySignature(signature string) error { + signatureBytes, err := hex.DecodeString(signature) + if err != nil { + return fmt.Errorf("Failed to decode access key signature. %s", err) + } + + p.accessKeySignature = signatureBytes + + calculatedSignature, err := p.calculateAccessKeySignature() + if err != nil { + return fmt.Errorf("Failed to calculate access key signature. %s", err) + } + + if !bytes.Equal(calculatedSignature, p.accessKeySignature) { + return errors.New("Access key signature does not match") + } + + return nil +} + +func (p *HPPPacket) calculateAccessKeySignature() ([]byte, error) { + accessKey := p.Sender().Server().AccessKey() + + accessKeyBytes, err := hex.DecodeString(accessKey) + if err != nil { + return nil, err + } + + signature, err := p.calculateSignature(p.payload, accessKeyBytes) + if err != nil { + return nil, err + } + + return signature, nil +} + +func (p *HPPPacket) validatePasswordSignature(signature string) error { + signatureBytes, err := hex.DecodeString(signature) + if err != nil { + return fmt.Errorf("Failed to decode password signature. %s", err) + } + + p.passwordSignature = signatureBytes + + calculatedSignature, err := p.calculatePasswordSignature() + if err != nil { + return fmt.Errorf("Failed to calculate password signature. %s", err) + } + + if !bytes.Equal(calculatedSignature, p.passwordSignature) { + return errors.New("Password signature does not match") + } + + return nil +} + +func (p *HPPPacket) calculatePasswordSignature() ([]byte, error) { + passwordFromPID := p.Sender().Server().(*HPPServer).PasswordFromPID + if passwordFromPID == nil { + return nil, errors.New("Missing PasswordFromPID") + } + + pid := p.Sender().PID() + password, _ := passwordFromPID(pid) + if password == "" { + return nil, errors.New("PID does not exist") + } + + key := DeriveKerberosKey(pid, []byte(password)) + + signature, err := p.calculateSignature(p.payload, key) + if err != nil { + return nil, err + } + + return signature, nil +} + +func (p *HPPPacket) calculateSignature(buffer []byte, key []byte) ([]byte, error) { + mac := hmac.New(md5.New, key) + + _, err := mac.Write(buffer) + if err != nil { + return nil, err + } + + hmac := mac.Sum(nil) + + return hmac, nil +} + +// RMCMessage returns the packets RMC Message +func (p *HPPPacket) RMCMessage() *RMCMessage { + return p.message +} + +// SetRMCMessage sets the packets RMC Message +func (p *HPPPacket) SetRMCMessage(message *RMCMessage) { + p.message = message +} + +func NewHPPPacket(client *HPPClient, payload []byte) (*HPPPacket, error) { + hppPacket := &HPPPacket{ + sender: client, + payload: payload, + processed: make(chan bool), + } + + if payload != nil { + rmcMessage := NewRMCRequest() + err := rmcMessage.FromBytes(payload) + if err != nil { + return nil, fmt.Errorf("Failed to decode HPP request. %s", err) + } + + hppPacket.SetRMCMessage(rmcMessage) + } + + return hppPacket, nil +} diff --git a/hpp_server.go b/hpp_server.go index c9d11ef9..0a8809a6 100644 --- a/hpp_server.go +++ b/hpp_server.go @@ -1,9 +1,245 @@ package nex +import ( + "fmt" + "net" + "net/http" + "strconv" +) + // HPPServer represents a bare-bones HPP server -type HPPServer struct{} +type HPPServer struct { + accessKey string + version *LibraryVersion + datastoreProtocolVersion *LibraryVersion + matchMakingProtocolVersion *LibraryVersion + rankingProtocolVersion *LibraryVersion + ranking2ProtocolVersion *LibraryVersion + messagingProtocolVersion *LibraryVersion + utilityProtocolVersion *LibraryVersion + natTraversalProtocolVersion *LibraryVersion + dataHandlers []func(packet PacketInterface) + PasswordFromPID func(pid *PID) (string, uint32) +} + +// OnData adds an event handler which is fired when a new HPP request is received +func (s *HPPServer) OnData(handler func(packet PacketInterface)) { + s.dataHandlers = append(s.dataHandlers, handler) +} + +func (s *HPPServer) handleRequest(w http.ResponseWriter, req *http.Request) { + if req.Method != "POST" { + w.WriteHeader(http.StatusBadRequest) + return + } + + pidValue := req.Header.Get("pid") + if pidValue == "" { + w.WriteHeader(http.StatusBadRequest) + return + } + + // * The server checks that the header exists, but doesn't verify the value + token := req.Header.Get("token") + if token == "" { + w.WriteHeader(http.StatusBadRequest) + return + } + + accessKeySignature := req.Header.Get("signature1") + if accessKeySignature == "" { + w.WriteHeader(http.StatusBadRequest) + return + } + + passwordSignature := req.Header.Get("signature2") + if passwordSignature == "" { + w.WriteHeader(http.StatusBadRequest) + return + } + + pid, err := strconv.Atoi(pidValue) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + + rmcRequestString := req.FormValue("file") + + rmcRequestBytes := []byte(rmcRequestString) + + tcpAddr, err := net.ResolveTCPAddr("tcp", req.RemoteAddr) + if err != nil { + // * Should never happen? + w.WriteHeader(http.StatusBadRequest) + return + } + + client := NewHPPClient(tcpAddr, s) + client.SetPID(NewPID(uint32(pid))) + + hppPacket, err := NewHPPPacket(client, rmcRequestBytes) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + + err = hppPacket.validateAccessKeySignature(accessKeySignature) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + + err = hppPacket.validatePasswordSignature(passwordSignature) + if err != nil { + rmcMessage := hppPacket.RMCMessage() + + // HPP returns PythonCore::ValidationError if password is missing or invalid + errorResponse := NewRMCError(Errors.PythonCore.ValidationError) + errorResponse.CallID = rmcMessage.CallID + + _, _ = w.Write(errorResponse.Bytes()) + // if err != nil { + // logger.Error(err.Error()) + // } + + return + } + + for _, dataHandler := range s.dataHandlers { + go dataHandler(hppPacket) + } + + <- hppPacket.processed + + if len(hppPacket.payload) > 0 { + _, _ = w.Write(hppPacket.payload) + // if err != nil { + // logger.Error(err.Error()) + // } + } +} + +// Listen starts a HPP server on a given port +func (s *HPPServer) Listen(port int) { + http.HandleFunc("/hpp/", s.handleRequest) + + err := http.ListenAndServe(fmt.Sprintf(":%d", port), nil) + if err != nil { + panic(err) + } +} + +// Send sends the packet to the packets sender +func (s *HPPServer) Send(packet PacketInterface) { + if packet, ok := packet.(*HPPPacket); ok { + packet.message.IsHPP = true + packet.payload = packet.message.Bytes() + + packet.processed <- true + } +} + +// AccessKey returns the servers sandbox access key +func (s *HPPServer) AccessKey() string { + return s.accessKey +} + +// SetAccessKey sets the servers sandbox access key +func (s *HPPServer) SetAccessKey(accessKey string) { + s.accessKey = accessKey +} + +// LibraryVersion returns the server NEX version +func (s *HPPServer) LibraryVersion() *LibraryVersion { + return s.version +} + +// SetDefaultLibraryVersion sets the default NEX protocol versions +func (s *HPPServer) SetDefaultLibraryVersion(version *LibraryVersion) { + s.version = version + s.datastoreProtocolVersion = version.Copy() + s.matchMakingProtocolVersion = version.Copy() + s.rankingProtocolVersion = version.Copy() + s.ranking2ProtocolVersion = version.Copy() + s.messagingProtocolVersion = version.Copy() + s.utilityProtocolVersion = version.Copy() + s.natTraversalProtocolVersion = version.Copy() +} + +// DataStoreProtocolVersion returns the servers DataStore protocol version +func (s *HPPServer) DataStoreProtocolVersion() *LibraryVersion { + return s.datastoreProtocolVersion +} + +// SetDataStoreProtocolVersion sets the servers DataStore protocol version +func (s *HPPServer) SetDataStoreProtocolVersion(version *LibraryVersion) { + s.datastoreProtocolVersion = version +} + +// MatchMakingProtocolVersion returns the servers MatchMaking protocol version +func (s *HPPServer) MatchMakingProtocolVersion() *LibraryVersion { + return s.matchMakingProtocolVersion +} + +// SetMatchMakingProtocolVersion sets the servers MatchMaking protocol version +func (s *HPPServer) SetMatchMakingProtocolVersion(version *LibraryVersion) { + s.matchMakingProtocolVersion = version +} + +// RankingProtocolVersion returns the servers Ranking protocol version +func (s *HPPServer) RankingProtocolVersion() *LibraryVersion { + return s.rankingProtocolVersion +} + +// SetRankingProtocolVersion sets the servers Ranking protocol version +func (s *HPPServer) SetRankingProtocolVersion(version *LibraryVersion) { + s.rankingProtocolVersion = version +} + +// Ranking2ProtocolVersion returns the servers Ranking2 protocol version +func (s *HPPServer) Ranking2ProtocolVersion() *LibraryVersion { + return s.ranking2ProtocolVersion +} + +// SetRanking2ProtocolVersion sets the servers Ranking2 protocol version +func (s *HPPServer) SetRanking2ProtocolVersion(version *LibraryVersion) { + s.ranking2ProtocolVersion = version +} + +// MessagingProtocolVersion returns the servers Messaging protocol version +func (s *HPPServer) MessagingProtocolVersion() *LibraryVersion { + return s.messagingProtocolVersion +} + +// SetMessagingProtocolVersion sets the servers Messaging protocol version +func (s *HPPServer) SetMessagingProtocolVersion(version *LibraryVersion) { + s.messagingProtocolVersion = version +} + +// UtilityProtocolVersion returns the servers Utility protocol version +func (s *HPPServer) UtilityProtocolVersion() *LibraryVersion { + return s.utilityProtocolVersion +} + +// SetUtilityProtocolVersion sets the servers Utility protocol version +func (s *HPPServer) SetUtilityProtocolVersion(version *LibraryVersion) { + s.utilityProtocolVersion = version +} + +// SetNATTraversalProtocolVersion sets the servers NAT Traversal protocol version +func (s *HPPServer) SetNATTraversalProtocolVersion(version *LibraryVersion) { + s.natTraversalProtocolVersion = version +} + +// NATTraversalProtocolVersion returns the servers NAT Traversal protocol version +func (s *HPPServer) NATTraversalProtocolVersion() *LibraryVersion { + return s.natTraversalProtocolVersion +} // NewHPPServer returns a new HPP server func NewHPPServer() *HPPServer { - return &HPPServer{} + return &HPPServer{ + dataHandlers: make([]func(packet PacketInterface), 0), + } } diff --git a/rmc_message.go b/rmc_message.go index 75d303f5..4cc24f3e 100644 --- a/rmc_message.go +++ b/rmc_message.go @@ -9,6 +9,7 @@ import ( type RMCMessage struct { IsRequest bool // * Indicates if the message is a request message (true) or response message (false) IsSuccess bool // * Indicates if the message is a success message (true) for a response message + IsHPP bool // * Indicates if the message is an HPP message ProtocolID uint16 // * Protocol ID of the message CallID uint32 // * Call ID associated with the message MethodID uint32 // * Method ID in the requested protocol @@ -131,11 +132,16 @@ func (rmc *RMCMessage) Bytes() []byte { protocolIDFlag = 0 } - if rmc.ProtocolID < 0x80 { + // * HPP does not include the protocol ID on the response. We technically + // * don't have to support converting HPP requests to bytes but we'll + // * do it for accuracy. + if !rmc.IsHPP || (rmc.IsHPP && rmc.IsRequest) { + if rmc.ProtocolID < 0x80 { stream.WriteUInt8(uint8(rmc.ProtocolID | protocolIDFlag)) - } else { - stream.WriteUInt8(uint8(0x7F | protocolIDFlag)) - stream.WriteUInt16LE(rmc.ProtocolID) + } else { + stream.WriteUInt8(uint8(0x7F | protocolIDFlag)) + stream.WriteUInt16LE(rmc.ProtocolID) + } } if rmc.IsRequest { diff --git a/test/hpp.go b/test/hpp.go new file mode 100644 index 00000000..1f00a0d9 --- /dev/null +++ b/test/hpp.go @@ -0,0 +1,113 @@ +package main + +import ( + "fmt" + + "github.com/PretendoNetwork/nex-go" +) + +var hppServer *nex.HPPServer + +// * Took these structs out of the protocols lib for convenience + +type DataStoreGetNotificationURLParam struct { + nex.Structure + PreviousURL string +} + +func (dataStoreGetNotificationURLParam *DataStoreGetNotificationURLParam) ExtractFromStream(stream *nex.StreamIn) error { + var err error + + dataStoreGetNotificationURLParam.PreviousURL, err = stream.ReadString() + if err != nil { + return fmt.Errorf("Failed to extract DataStoreGetNotificationURLParam.PreviousURL. %s", err.Error()) + } + + return nil +} + +type DataStoreReqGetNotificationURLInfo struct { + nex.Structure + URL string + Key string + Query string + RootCACert []byte +} + +func (dataStoreReqGetNotificationURLInfo *DataStoreReqGetNotificationURLInfo) Bytes(stream *nex.StreamOut) []byte { + stream.WriteString(dataStoreReqGetNotificationURLInfo.URL) + stream.WriteString(dataStoreReqGetNotificationURLInfo.Key) + stream.WriteString(dataStoreReqGetNotificationURLInfo.Query) + stream.WriteBuffer(dataStoreReqGetNotificationURLInfo.RootCACert) + + return stream.Bytes() +} + +func passwordFromPID(pid *nex.PID) (string, uint32) { + return "notmypassword", 0 +} + +func startHPPServer() { + fmt.Println("Starting HPP") + + hppServer = nex.NewHPPServer() + + hppServer.OnData(func(packet nex.PacketInterface) { + if packet, ok := packet.(*nex.HPPPacket); ok { + request := packet.RMCMessage() + + fmt.Println("[HPP]", request.ProtocolID, request.MethodID) + + if request.ProtocolID == 0x73 { // * DataStore + if request.MethodID == 0xD { + getNotificationURL(packet) + } + } + } + }) + + hppServer.SetDefaultLibraryVersion(nex.NewLibraryVersion(2, 4, 1)) + hppServer.SetAccessKey("76f26496") + hppServer.PasswordFromPID = passwordFromPID + + hppServer.Listen(8085) +} + +func getNotificationURL(packet *nex.HPPPacket) { + request := packet.RMCMessage() + response := nex.NewRMCMessage() + + parameters := request.Parameters + + parametersStream := nex.NewStreamIn(parameters, hppServer) + + param, err := nex.StreamReadStructure(parametersStream, &DataStoreGetNotificationURLParam{}) + if err != nil { + fmt.Println("[HPP]", err) + return + } + + fmt.Println("[HPP]", param.PreviousURL) + + responseStream := nex.NewStreamOut(hppServer) + + info := &DataStoreReqGetNotificationURLInfo{} + info.URL = "https://example.com" + info.Key = "whatever/key" + info.Query = "?pretendo=1" + + responseStream.WriteStructure(info) + + response.IsSuccess = true + response.IsRequest = false + response.ErrorCode = 0x00010001 + response.ProtocolID = request.ProtocolID + response.CallID = request.CallID + response.MethodID = request.MethodID + response.Parameters = responseStream.Bytes() + + // * We replace the RMC message so that it can be delivered back + packet.SetRMCMessage(response) + + hppServer.Send(packet) +} diff --git a/test/main.go b/test/main.go index d559d002..12dd8a03 100644 --- a/test/main.go +++ b/test/main.go @@ -9,6 +9,7 @@ func main() { go startAuthenticationServer() go startSecureServer() + go startHPPServer() wg.Wait() } From d2d40aad485aaeb05f1addf39e5b68f9908e63a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20L=C3=B3pez=20Guimaraes?= Date: Sat, 18 Nov 2023 22:32:05 +0000 Subject: [PATCH 037/178] stream_out: Remove commented code --- stream_out.go | 30 ------------------------------ 1 file changed, 30 deletions(-) diff --git a/stream_out.go b/stream_out.go index cef0c2f6..8bb6af18 100644 --- a/stream_out.go +++ b/stream_out.go @@ -461,36 +461,6 @@ func (stream *StreamOut) WriteVariant(variant *Variant) { stream.WriteBytesNext(content) } -/* -// WriteMap writes a Map type with the given key and value types -func (stream *StreamOut) WriteMap(mapType interface{}) { - // TODO: - // Find a better solution that doesn't use reflect - - mapValue := reflect.ValueOf(mapType) - count := mapValue.Len() - - stream.WriteUInt32LE(uint32(count)) - - mapIter := mapValue.MapRange() - - for mapIter.Next() { - key := mapIter.Key().Interface() - value := mapIter.Value().Interface() - - switch key := key.(type) { - case string: - stream.WriteString(key) - } - - switch value := value.(type) { - case *Variant: - stream.WriteVariant(value) - } - } -} -*/ - // NewStreamOut returns a new nex output stream func NewStreamOut(server ServerInterface) *StreamOut { return &StreamOut{ From 8a6fb8356b4c5608946f2213a93d3f9f70a8588b Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sat, 18 Nov 2023 22:27:08 -0500 Subject: [PATCH 038/178] added back in plogger --- hpp_server.go | 23 ++++++++++++++--------- init.go | 4 ++++ prudp_server.go | 12 +++++++++--- 3 files changed, 27 insertions(+), 12 deletions(-) diff --git a/hpp_server.go b/hpp_server.go index 0a8809a6..b2727537 100644 --- a/hpp_server.go +++ b/hpp_server.go @@ -71,6 +71,7 @@ func (s *HPPServer) handleRequest(w http.ResponseWriter, req *http.Request) { tcpAddr, err := net.ResolveTCPAddr("tcp", req.RemoteAddr) if err != nil { // * Should never happen? + logger.Error(err.Error()) w.WriteHeader(http.StatusBadRequest) return } @@ -80,28 +81,32 @@ func (s *HPPServer) handleRequest(w http.ResponseWriter, req *http.Request) { hppPacket, err := NewHPPPacket(client, rmcRequestBytes) if err != nil { + logger.Error(err.Error()) w.WriteHeader(http.StatusBadRequest) return } err = hppPacket.validateAccessKeySignature(accessKeySignature) if err != nil { + logger.Error(err.Error()) w.WriteHeader(http.StatusBadRequest) return } err = hppPacket.validatePasswordSignature(passwordSignature) if err != nil { + logger.Error(err.Error()) + rmcMessage := hppPacket.RMCMessage() // HPP returns PythonCore::ValidationError if password is missing or invalid errorResponse := NewRMCError(Errors.PythonCore.ValidationError) errorResponse.CallID = rmcMessage.CallID - _, _ = w.Write(errorResponse.Bytes()) - // if err != nil { - // logger.Error(err.Error()) - // } + _, err = w.Write(errorResponse.Bytes()) + if err != nil { + logger.Error(err.Error()) + } return } @@ -110,13 +115,13 @@ func (s *HPPServer) handleRequest(w http.ResponseWriter, req *http.Request) { go dataHandler(hppPacket) } - <- hppPacket.processed + <-hppPacket.processed if len(hppPacket.payload) > 0 { - _, _ = w.Write(hppPacket.payload) - // if err != nil { - // logger.Error(err.Error()) - // } + _, err = w.Write(hppPacket.payload) + if err != nil { + logger.Error(err.Error()) + } } } diff --git a/init.go b/init.go index 178e8e3f..c69cf42c 100644 --- a/init.go +++ b/init.go @@ -1,5 +1,9 @@ package nex +import "github.com/PretendoNetwork/plogger-go" + +var logger = plogger.NewLogger() + func init() { initErrorsData() } diff --git a/prudp_server.go b/prudp_server.go index 12a2751e..e805ba8e 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -273,7 +273,10 @@ func (s *PRUDPServer) handleSyn(packet PRUDPPacketInterface) { ack, _ = NewPRUDPPacketV1(client, nil) } - connectionSignature, _ := packet.calculateConnectionSignature(client.address) + connectionSignature, err := packet.calculateConnectionSignature(client.address) + if err != nil { + logger.Error(err.Error()) + } client.reset() client.clientConnectionSignature = connectionSignature @@ -317,7 +320,10 @@ func (s *PRUDPServer) handleConnect(packet PRUDPPacketInterface) { client.serverConnectionSignature = packet.getConnectionSignature() - connectionSignature, _ := packet.calculateConnectionSignature(client.address) + connectionSignature, err := packet.calculateConnectionSignature(client.address) + if err != nil { + logger.Error(err.Error()) + } ack.SetType(ConnectPacket) ack.AddFlag(FlagAck) @@ -351,7 +357,7 @@ func (s *PRUDPServer) handleConnect(packet PRUDPPacketInterface) { if s.IsSecureServer { sessionKey, pid, checkValue, err := s.readKerberosTicket(packet.Payload()) if err != nil { - fmt.Println(err) + logger.Error(err.Error()) } client.SetPID(pid) From d885afa5f7ff8154d92d5dfe14e775d227a5bf54 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sat, 18 Nov 2023 22:29:23 -0500 Subject: [PATCH 039/178] added logger to StreamOut --- stream_out.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/stream_out.go b/stream_out.go index 8bb6af18..26a66712 100644 --- a/stream_out.go +++ b/stream_out.go @@ -1,8 +1,6 @@ package nex import ( - "fmt" - crunch "github.com/superwhiskers/crunch/v3" ) @@ -496,7 +494,7 @@ func mapTypeWriter[T any](stream *StreamOut, t T) { // * The client will disconnect but the server won't die, // * that way other clients stay connected, but we still // * have a log of what the error was - fmt.Printf("Unsupported Map type trying to be written: %T\n", v) + logger.Warningf("Unsupported Map type trying to be written: %T\n", v) } } From a880ec15663d3fe78b09295d983e19dc164e368a Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sat, 18 Nov 2023 22:30:42 -0500 Subject: [PATCH 040/178] removed debug fmt logs --- prudp_server.go | 30 ------------------------------ resend_scheduler.go | 25 ------------------------- 2 files changed, 55 deletions(-) diff --git a/prudp_server.go b/prudp_server.go index e805ba8e..6e859998 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -151,12 +151,6 @@ func (s *PRUDPServer) handleSocketMessage() error { var packets []PRUDPPacketInterface - if s.IsSecureServer { - fmt.Printf("[SECR] Got packet data %x\n", packetData) - } else { - fmt.Printf("[AUTH] Got packet data %x\n", packetData) - } - // * Support any packet type the client sends and respond // * with that same type. Also keep reading from the stream // * until no more data is left, to account for multiple @@ -202,12 +196,6 @@ func (s *PRUDPServer) handleAcknowledgment(packet PRUDPPacketInterface) { return } - if s.IsSecureServer { - fmt.Println("[SECR] Got ACK for SequenceID", packet.SequenceID()) - } else { - fmt.Println("[AUTH] Got ACK for SequenceID", packet.SequenceID()) - } - client := packet.Sender().(*PRUDPClient) substream := client.reliableSubstream(packet.SubstreamID()) @@ -603,24 +591,6 @@ func (s *PRUDPServer) sendPacket(packet PRUDPPacketInterface) { substream.ResendScheduler.AddPacket(packetCopy) } - if packetCopy.Type() == DataPacket && packetCopy.RMCMessage() != nil { - if s.IsSecureServer { - fmt.Println("[SECR] ======= SENDING =======") - fmt.Println("[SECR] ProtocolID:", packetCopy.RMCMessage().ProtocolID) - fmt.Println("[SECR] MethodID:", packetCopy.RMCMessage().MethodID) - fmt.Println("[SECR] FragmentID:", packetCopy.getFragmentID()) - fmt.Println("[SECR] SequenceID:", packetCopy.SequenceID()) - fmt.Println("[SECR] =======================") - } else { - fmt.Println("[AUTH] ======= SENDING =======") - fmt.Println("[AUTH] ProtocolID:", packetCopy.RMCMessage().ProtocolID) - fmt.Println("[AUTH] MethodID:", packetCopy.RMCMessage().MethodID) - fmt.Println("[AUTH] FragmentID:", packetCopy.getFragmentID()) - fmt.Println("[AUTH] SequenceID:", packetCopy.SequenceID()) - fmt.Println("[AUTH] =======================") - } - } - s.sendRaw(packetCopy.Sender().Address(), packetCopy.Bytes()) } diff --git a/resend_scheduler.go b/resend_scheduler.go index c2010a41..f351bacf 100644 --- a/resend_scheduler.go +++ b/resend_scheduler.go @@ -1,7 +1,6 @@ package nex import ( - "fmt" "time" ) @@ -67,12 +66,6 @@ func (rs *ResendScheduler) AddPacket(packet PRUDPPacketInterface) { interval: rs.Interval, } - if packet.Sender().Server().(*PRUDPServer).IsSecureServer { - fmt.Println("[SECR] Adding packet", packet.SequenceID(), "to resend queue") - } else { - fmt.Println("[AUTH] Adding packet", packet.SequenceID(), "to resend queue") - } - rs.packets.Set(packet.SequenceID(), pendingPacket) go pendingPacket.startResendTimer() @@ -81,12 +74,6 @@ func (rs *ResendScheduler) AddPacket(packet PRUDPPacketInterface) { // AcknowledgePacket marks a pending packet as acknowledged. It will be ignored at the next resend attempt func (rs *ResendScheduler) AcknowledgePacket(sequenceID uint16) { if pendingPacket, ok := rs.packets.Get(sequenceID); ok { - if pendingPacket.packet.Sender().Server().(*PRUDPServer).IsSecureServer { - fmt.Println("[SECR] Acknowledged", sequenceID) - } else { - fmt.Println("[AUTH] Acknowledged", sequenceID) - } - pendingPacket.isAcknowledged = true } } @@ -103,12 +90,6 @@ func (rs *ResendScheduler) resendPacket(pendingPacket *PendingPacket) { if pendingPacket.resendCount >= rs.MaxResendCount { // * The maximum resend count has been reached, consider the client dead. - if pendingPacket.packet.Sender().Server().(*PRUDPServer).IsSecureServer { - fmt.Println("[SECR] Max resends hit for", pendingPacket.packet.SequenceID()) - } else { - fmt.Println("[AUTH] Max resends hit for", pendingPacket.packet.SequenceID()) - } - pendingPacket.ticker.Stop() rs.packets.Delete(packet.SequenceID()) client.cleanup() // * "removed" event is dispatched here @@ -117,12 +98,6 @@ func (rs *ResendScheduler) resendPacket(pendingPacket *PendingPacket) { } if time.Since(pendingPacket.lastSendTime) >= rs.Interval { - if pendingPacket.packet.Sender().Server().(*PRUDPServer).IsSecureServer { - fmt.Println("[SECR] Resending packet", pendingPacket.packet.SequenceID()) - } else { - fmt.Println("[AUTH] Resending packet", pendingPacket.packet.SequenceID()) - } - // * Resend the packet to the client server := client.server data := packet.Bytes() From b897a9216865ceb11ca1f1e55173e971f9040c55 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sun, 19 Nov 2023 15:23:30 -0500 Subject: [PATCH 041/178] prudp: add client substream fail-safes --- prudp_client.go | 15 ++++++++++++++- prudp_server.go | 21 +++++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/prudp_client.go b/prudp_client.go index 452a5b9d..3d955eb9 100644 --- a/prudp_client.go +++ b/prudp_client.go @@ -109,11 +109,24 @@ func (c *PRUDPClient) setSessionKey(sessionKey []byte) { // reliableSubstream returns the clients reliable substream ID func (c *PRUDPClient) reliableSubstream(substreamID uint8) *ReliablePacketSubstreamManager { - return c.reliableSubstreams[substreamID] + // * Fail-safe. The client may not always have + // * the correct number of substreams. See the + // * comment in handleSocketMessage of PRUDPServer + // * for more details + if int(substreamID) >= len(c.reliableSubstreams) { + return c.reliableSubstreams[0] + } else { + return c.reliableSubstreams[substreamID] + } } // createReliableSubstreams creates the list of substreams used for reliable PRUDP packets func (c *PRUDPClient) createReliableSubstreams(maxSubstreamID uint8) { + // * Kill any existing substreams + for _, substream := range c.reliableSubstreams { + substream.ResendScheduler.Stop() + } + substreams := maxSubstreamID + 1 c.reliableSubstreams = make([]*ReliablePacketSubstreamManager, substreams) diff --git a/prudp_server.go b/prudp_server.go index 6e859998..299ec49a 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -143,6 +143,27 @@ func (s *PRUDPServer) handleSocketMessage() error { client = NewPRUDPClient(addr, s) client.startHeartbeat() + // * Fail-safe. If the server reboots, then + // * s.clients has no record of old clients. + // * An existing client which has not killed + // * the connection on it's end MAY still send + // * DATA packets once the server is back + // * online, assuming it reboots fast enough. + // * Since the client did NOT redo the SYN + // * and CONNECT packets, it's reliable + // * substreams never got remade. This is put + // * in place to ensure there is always AT + // * LEAST one substream in place, so the client + // * can naturally error out due to the RC4 + // * errors. + // * + // * NOTE: THE CLIENT MAY NOT HAVE THE REAL + // * CORRECT NUMBER OF SUBSTREAMS HERE. THIS + // * IS ONLY DONE TO PREVENT A SERVER CRASH, + // * NOT TO SAVE THE CLIENT. THE CLIENT IS + // * EXPECTED TO NATURALLY DIE HERE + client.createReliableSubstreams(0) + s.clients.Set(discriminator, client) } From eb23cb7776aa8a9584cdea43b242bd1266a4bfba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20L=C3=B3pez=20Guimaraes?= Date: Mon, 20 Nov 2023 22:48:45 +0000 Subject: [PATCH 042/178] types: Get 64 bits for PID on StationURL --- types.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/types.go b/types.go index 0fc41833..f245193b 100644 --- a/types.go +++ b/types.go @@ -880,7 +880,7 @@ func (stationURL *StationURL) FromString(str string) { ui64, _ := strconv.ParseUint(value, 10, 32) stationURL.SetCID(uint32(ui64)) case "PID": - ui64, _ := strconv.ParseUint(value, 10, 32) + ui64, _ := strconv.ParseUint(value, 10, 64) stationURL.SetPID(NewPID(ui64)) case "type": ui64, _ := strconv.ParseUint(value, 10, 32) From 47fd1e1be673db54d2691df5fcd0f24921bb93cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20L=C3=B3pez=20Guimaraes?= Date: Mon, 20 Nov 2023 22:50:18 +0000 Subject: [PATCH 043/178] Fix some linting issues --- README.md | 22 ++++++++++++---------- hpp_packet.go | 1 + prudp_server.go | 18 ++++++------------ rmc_message.go | 2 ++ stream_in.go | 2 +- stream_out.go | 2 +- 6 files changed, 23 insertions(+), 24 deletions(-) diff --git a/README.md b/README.md index 775ec6ba..32c4c42d 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ ### Usage note -This module provides a barebones PRUDP server for use with titles using the Nintendo NEX library. It does not provide any support for titles using the original Rendez-Vous library developed by Quazal. This library only provides the low level packet data, as such it is recommended to use [NEX Protocols Go](https://github.com/PretendoNetwork/nex-protocols-go) to develop servers. +This module provides a barebones PRUDP server for use with titles using the Nintendo NEX library. It provides some support for titles using the original Rendez-Vous library developed by Quazal. This library only provides the low level packet data, as such it is recommended to use [NEX Protocols Go](https://github.com/PretendoNetwork/nex-protocols-go) to develop servers. ### Usage @@ -28,21 +28,23 @@ import ( ) func main() { - nexServer := nex.NewServer() - nexServer.SetPrudpVersion(0) - nexServer.SetSignatureVersion(1) + nexServer := nex.NewPRUDPServer() + nexServer.PRUDPVersion = 0 + nexServer.SetFragmentSize(962) + nexServer.SetDefaultLibraryVersion(nex.NewLibraryVersion(1, 1, 0)) + nexServer.SetKerberosPassword([]byte("password")) nexServer.SetKerberosKeySize(16) nexServer.SetAccessKey("ridfebb9") - nexServer.On("Data", func(packet *nex.PacketV0) { - request := packet.RMCRequest() + nexServer.OnData(func(packet nex.PacketInterface) { + request := packet.RMCMessage() fmt.Println("==Friends - Auth==") - fmt.Printf("Protocol ID: %#v\n", request.ProtocolID()) - fmt.Printf("Method ID: %#v\n", request.MethodID()) + fmt.Printf("Protocol ID: %#v\n", request.ProtocolID) + fmt.Printf("Method ID: %#v\n", request.MethodID) fmt.Println("==================") }) - nexServer.Listen(":60000") + nexServer.Listen(60000) } -``` \ No newline at end of file +``` diff --git a/hpp_packet.go b/hpp_packet.go index 4c9e95fa..94caf5ab 100644 --- a/hpp_packet.go +++ b/hpp_packet.go @@ -135,6 +135,7 @@ func (p *HPPPacket) SetRMCMessage(message *RMCMessage) { p.message = message } +// NewHPPPacket creates and returns a new HPPPacket using the provided Client and payload func NewHPPPacket(client *HPPClient, payload []byte) (*HPPPacket, error) { hppPacket := &HPPPacket{ sender: client, diff --git a/prudp_server.go b/prudp_server.go index 299ec49a..f3f6cc32 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -16,7 +16,7 @@ type PRUDPServer struct { udpSocket *net.UDPConn clients *MutexMap[string, *PRUDPClient] PRUDPVersion int - protocolMinorVersion uint32 + PRUDPMinorVersion uint32 IsQuazalMode bool IsSecureServer bool SupportedFunctions uint32 @@ -617,7 +617,11 @@ func (s *PRUDPServer) sendPacket(packet PRUDPPacketInterface) { // sendRaw will send the given address the provided packet func (s *PRUDPServer) sendRaw(conn net.Addr, data []byte) { - s.udpSocket.WriteToUDP(data, conn.(*net.UDPAddr)) + _, err := s.udpSocket.WriteToUDP(data, conn.(*net.UDPAddr)) + if err != nil { + // TODO - Should this return the error too? + logger.Error(err.Error()) + } } // AccessKey returns the servers sandbox access key @@ -763,16 +767,6 @@ func (s *PRUDPServer) ConnectionIDCounter() *Counter[uint32] { return s.connectionIDCounter } -// SetProtocolMinorVersion sets the servers PRUDP protocol minor version -func (s *PRUDPServer) SetProtocolMinorVersion(protocolMinorVersion uint32) { - s.protocolMinorVersion = protocolMinorVersion -} - -// ProtocolMinorVersion returns the servers PRUDP protocol minor version -func (s *PRUDPServer) ProtocolMinorVersion() uint32 { - return s.protocolMinorVersion -} - // FindClientByConnectionID returns the PRUDP client connected with the given connection ID func (s *PRUDPServer) FindClientByConnectionID(connectedID uint32) *PRUDPClient { var client *PRUDPClient diff --git a/rmc_message.go b/rmc_message.go index 4cc24f3e..f9ac68ba 100644 --- a/rmc_message.go +++ b/rmc_message.go @@ -17,11 +17,13 @@ type RMCMessage struct { Parameters []byte // * Input for the method } +// Copy copies the message into a new RMCMessage func (rmc *RMCMessage) Copy() *RMCMessage { copied := NewRMCMessage() copied.IsRequest = rmc.IsRequest copied.IsSuccess = rmc.IsSuccess + copied.IsHPP = rmc.IsHPP copied.ProtocolID = rmc.ProtocolID copied.CallID = rmc.CallID copied.MethodID = rmc.MethodID diff --git a/stream_in.go b/stream_in.go index 7de138a4..d0a46d93 100644 --- a/stream_in.go +++ b/stream_in.go @@ -908,7 +908,7 @@ func StreamReadStructure[T StructureInterface](stream *StreamIn, structure T) (T var useStructureHeader bool switch server := stream.Server.(type) { case *PRUDPServer: // * Support QRV versions - useStructureHeader = server.ProtocolMinorVersion() >= 3 + useStructureHeader = server.PRUDPMinorVersion >= 3 default: useStructureHeader = server.LibraryVersion().GreaterOrEqual("3.5.0") } diff --git a/stream_out.go b/stream_out.go index 26a66712..b3a2e6d0 100644 --- a/stream_out.go +++ b/stream_out.go @@ -187,7 +187,7 @@ func (stream *StreamOut) WriteStructure(structure StructureInterface) { var useStructures bool switch server := stream.Server.(type) { case *PRUDPServer: // * Support QRV versions - useStructures = server.ProtocolMinorVersion() >= 3 + useStructures = server.PRUDPMinorVersion >= 3 default: useStructures = server.LibraryVersion().GreaterOrEqual("3.5.0") } From ca1be76ffc15d10ca7c45ebdbb2b9aefc6819823 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Mon, 20 Nov 2023 18:25:39 -0500 Subject: [PATCH 044/178] removed remaining linter issues --- client_interface.go | 1 + prudp_server.go | 6 +++++- test/auth.go | 1 + test/generate_ticket.go | 5 ++++- test/hpp.go | 22 +++++++++++----------- test/secure.go | 24 ++++++++++-------------- 6 files changed, 32 insertions(+), 27 deletions(-) diff --git a/client_interface.go b/client_interface.go index 20cdece3..041f7555 100644 --- a/client_interface.go +++ b/client_interface.go @@ -1,3 +1,4 @@ +// Package nex provides a collection of utility structs, functions, and data types for making NEX/QRV servers package nex import "net" diff --git a/prudp_server.go b/prudp_server.go index f3f6cc32..9937fe6c 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -518,7 +518,11 @@ func (s *PRUDPServer) handleReliable(packet PRUDPPacketInterface) { if packet.getFragmentID() == 0 { message := NewRMCMessage() - message.FromBytes(payload) + err := message.FromBytes(payload) + if err != nil { + // TODO - Should this return the error too? + logger.Error(err.Error()) + } substream.ResetFragmentedPayload() diff --git a/test/auth.go b/test/auth.go index 1100c0cf..35fe872e 100644 --- a/test/auth.go +++ b/test/auth.go @@ -1,3 +1,4 @@ +// Package main implements a test server package main import ( diff --git a/test/generate_ticket.go b/test/generate_ticket.go index 6ceb2839..1e49ee46 100644 --- a/test/generate_ticket.go +++ b/test/generate_ticket.go @@ -11,7 +11,10 @@ func generateTicket(userPID *nex.PID, targetPID *nex.PID) []byte { targetKey := nex.DeriveKerberosKey(targetPID, []byte("password")) sessionKey := make([]byte, authServer.KerberosKeySize()) - rand.Read(sessionKey) + _, err := rand.Read(sessionKey) + if err != nil { + panic(err) + } ticketInternalData := nex.NewKerberosTicketInternalData() serverTime := nex.NewDateTime(0).Now() diff --git a/test/hpp.go b/test/hpp.go index 1f00a0d9..4bcd4146 100644 --- a/test/hpp.go +++ b/test/hpp.go @@ -10,15 +10,15 @@ var hppServer *nex.HPPServer // * Took these structs out of the protocols lib for convenience -type DataStoreGetNotificationURLParam struct { +type dataStoreGetNotificationURLParam struct { nex.Structure PreviousURL string } -func (dataStoreGetNotificationURLParam *DataStoreGetNotificationURLParam) ExtractFromStream(stream *nex.StreamIn) error { +func (d *dataStoreGetNotificationURLParam) ExtractFromStream(stream *nex.StreamIn) error { var err error - dataStoreGetNotificationURLParam.PreviousURL, err = stream.ReadString() + d.PreviousURL, err = stream.ReadString() if err != nil { return fmt.Errorf("Failed to extract DataStoreGetNotificationURLParam.PreviousURL. %s", err.Error()) } @@ -26,7 +26,7 @@ func (dataStoreGetNotificationURLParam *DataStoreGetNotificationURLParam) Extrac return nil } -type DataStoreReqGetNotificationURLInfo struct { +type dataStoreReqGetNotificationURLInfo struct { nex.Structure URL string Key string @@ -34,11 +34,11 @@ type DataStoreReqGetNotificationURLInfo struct { RootCACert []byte } -func (dataStoreReqGetNotificationURLInfo *DataStoreReqGetNotificationURLInfo) Bytes(stream *nex.StreamOut) []byte { - stream.WriteString(dataStoreReqGetNotificationURLInfo.URL) - stream.WriteString(dataStoreReqGetNotificationURLInfo.Key) - stream.WriteString(dataStoreReqGetNotificationURLInfo.Query) - stream.WriteBuffer(dataStoreReqGetNotificationURLInfo.RootCACert) +func (d *dataStoreReqGetNotificationURLInfo) Bytes(stream *nex.StreamOut) []byte { + stream.WriteString(d.URL) + stream.WriteString(d.Key) + stream.WriteString(d.Query) + stream.WriteBuffer(d.RootCACert) return stream.Bytes() } @@ -81,7 +81,7 @@ func getNotificationURL(packet *nex.HPPPacket) { parametersStream := nex.NewStreamIn(parameters, hppServer) - param, err := nex.StreamReadStructure(parametersStream, &DataStoreGetNotificationURLParam{}) + param, err := nex.StreamReadStructure(parametersStream, &dataStoreGetNotificationURLParam{}) if err != nil { fmt.Println("[HPP]", err) return @@ -91,7 +91,7 @@ func getNotificationURL(packet *nex.HPPPacket) { responseStream := nex.NewStreamOut(hppServer) - info := &DataStoreReqGetNotificationURLInfo{} + info := &dataStoreReqGetNotificationURLInfo{} info.URL = "https://example.com" info.Key = "whatever/key" info.Query = "?pretendo=1" diff --git a/test/secure.go b/test/secure.go index d7fef980..cf92c4d6 100644 --- a/test/secure.go +++ b/test/secure.go @@ -11,7 +11,7 @@ var secureServer *nex.PRUDPServer // * Took these structs out of the protocols lib for convenience -type PrincipalPreference struct { +type principalPreference struct { nex.Structure *nex.Data ShowOnlinePresence bool @@ -19,7 +19,7 @@ type PrincipalPreference struct { BlockFriendRequests bool } -func (pp *PrincipalPreference) Bytes(stream *nex.StreamOut) []byte { +func (pp *principalPreference) Bytes(stream *nex.StreamOut) []byte { stream.WriteBool(pp.ShowOnlinePresence) stream.WriteBool(pp.ShowCurrentTitle) stream.WriteBool(pp.BlockFriendRequests) @@ -27,7 +27,7 @@ func (pp *PrincipalPreference) Bytes(stream *nex.StreamOut) []byte { return stream.Bytes() } -type Comment struct { +type comment struct { nex.Structure *nex.Data Unknown uint8 @@ -35,7 +35,7 @@ type Comment struct { LastChanged *nex.DateTime } -func (c *Comment) Bytes(stream *nex.StreamOut) []byte { +func (c *comment) Bytes(stream *nex.StreamOut) []byte { stream.WriteUInt8(c.Unknown) stream.WriteString(c.Contents) stream.WriteDateTime(c.LastChanged) @@ -146,22 +146,18 @@ func updateAndGetAllInformation(packet nex.PRUDPPacketInterface) { request := packet.RMCMessage() response := nex.NewRMCMessage() - principalPreference := &PrincipalPreference{ + responseStream := nex.NewStreamOut(authServer) + + responseStream.WriteStructure(&principalPreference{ ShowOnlinePresence: true, ShowCurrentTitle: true, BlockFriendRequests: false, - } - - comment := &Comment{ + }) + responseStream.WriteStructure(&comment{ Unknown: 0, Contents: "Rewrite Test", LastChanged: nex.NewDateTime(0), - } - - responseStream := nex.NewStreamOut(authServer) - - responseStream.WriteStructure(principalPreference) - responseStream.WriteStructure(comment) + }) responseStream.WriteUInt32LE(0) // * Stubbed empty list. responseStream.WriteListStructure(friendList) responseStream.WriteUInt32LE(0) // * Stubbed empty list. responseStream.WriteListStructure(friendRequestsOut) responseStream.WriteUInt32LE(0) // * Stubbed empty list. responseStream.WriteListStructure(friendRequestsIn) From 3f3c11597ee2eeedee94798ddf579f9a69e50feb Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Tue, 21 Nov 2023 12:08:40 -0500 Subject: [PATCH 045/178] qrv: packet payload compression (needs testing) --- prudp_server.go | 93 +++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 91 insertions(+), 2 deletions(-) diff --git a/prudp_server.go b/prudp_server.go index 9937fe6c..b7a2d228 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -2,6 +2,7 @@ package nex import ( "bytes" + "compress/zlib" "crypto/rand" "errors" "fmt" @@ -39,6 +40,7 @@ type PRUDPServer struct { pingTimeout time.Duration PasswordFromPID func(pid *PID) (string, uint32) PRUDPv1ConnectionSignatureKey []byte + CompressionEnabled bool } // OnData adds an event handler which is fired when a new DATA packet is received @@ -514,7 +516,13 @@ func (s *PRUDPServer) handleReliable(packet PRUDPPacketInterface) { for _, pendingPacket := range substream.Update(packet) { if packet.Type() == DataPacket { - payload := substream.AddFragment(pendingPacket.decryptPayload()) + decryptedPayload := pendingPacket.decryptPayload() + decompressedPayload, err := s.decompressPayload(decryptedPayload) + if err != nil { + logger.Error(err.Error()) + } + + payload := substream.AddFragment(decompressedPayload) if packet.getFragmentID() == 0 { message := NewRMCMessage() @@ -603,8 +611,15 @@ func (s *PRUDPServer) sendPacket(packet PRUDPPacketInterface) { if packetCopy.Type() == DataPacket && !packetCopy.HasFlag(FlagAck) && !packetCopy.HasFlag(FlagMultiAck) { if packetCopy.HasFlag(FlagReliable) { + payload := packetCopy.Payload() + compressedPayload, err := s.compressPayload(payload) + if err != nil { + logger.Error(err.Error()) + } + substream := client.reliableSubstream(packetCopy.SubstreamID()) - packetCopy.SetPayload(substream.Encrypt(packetCopy.Payload())) + + packetCopy.SetPayload(substream.Encrypt(compressedPayload)) } // TODO - Unreliable crypto } @@ -628,6 +643,80 @@ func (s *PRUDPServer) sendRaw(conn net.Addr, data []byte) { } } +func (s *PRUDPServer) decompressPayload(payload []byte) ([]byte, error) { + if !s.CompressionEnabled { + return payload, nil + } + + compressionRatio := payload[0] + compressed := payload[1:] + + if compressionRatio == 0 { + // * Compression ratio of 0 means no compression + return compressed, nil + } + + reader := bytes.NewReader(compressed) + decompressed := bytes.Buffer{} + + // * Create a zlib reader + zlibReader, err := zlib.NewReader(reader) + if err != nil { + return []byte{}, err + } + defer zlibReader.Close() + + // * Copy the decompressed payload into a buffer + _, err = decompressed.ReadFrom(zlibReader) + if err != nil { + return []byte{}, err + } + + decompressedBytes := decompressed.Bytes() + + ratioCheck := len(decompressedBytes)/len(compressed) + 1 + + if ratioCheck != int(compressionRatio) { + return []byte{}, fmt.Errorf("Failed to decompress payload. Got bad ratio. Expected %d, got %d", compressionRatio, ratioCheck) + } + + return decompressedBytes, nil +} + +func (s *PRUDPServer) compressPayload(payload []byte) ([]byte, error) { + if !s.CompressionEnabled { + return payload, nil + } + + compressed := bytes.Buffer{} + + // * Create a zlib writer with default compression level + zlibWriter := zlib.NewWriter(&compressed) + + _, err := zlibWriter.Write(payload) + if err != nil { + return []byte{}, err + } + + // * Close the zlib writer to flush any remaining data + err = zlibWriter.Close() + if err != nil { + return []byte{}, err + } + + compressedBytes := compressed.Bytes() + + compressionRatio := len(payload)/len(compressedBytes) + 1 + + stream := NewStreamOut(s) + + stream.WriteUInt8(uint8(compressionRatio)) + stream.Grow(int64(len(compressedBytes))) + stream.WriteBytesNext(compressedBytes) + + return stream.Bytes(), nil +} + // AccessKey returns the servers sandbox access key func (s *PRUDPServer) AccessKey() string { return s.accessKey From 4967d1c7c87ed11a5f2edeb27af3711140f01ef2 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Tue, 21 Nov 2023 12:19:01 -0500 Subject: [PATCH 046/178] qrv: check err from zlibReader --- prudp_server.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/prudp_server.go b/prudp_server.go index b7a2d228..7e0c73e1 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -664,7 +664,6 @@ func (s *PRUDPServer) decompressPayload(payload []byte) ([]byte, error) { if err != nil { return []byte{}, err } - defer zlibReader.Close() // * Copy the decompressed payload into a buffer _, err = decompressed.ReadFrom(zlibReader) @@ -672,6 +671,12 @@ func (s *PRUDPServer) decompressPayload(payload []byte) ([]byte, error) { return []byte{}, err } + // * Close the zlib reader to flush any remaining data + err = zlibReader.Close() + if err != nil { + return []byte{}, err + } + decompressedBytes := decompressed.Bytes() ratioCheck := len(decompressedBytes)/len(compressed) + 1 From 73921dbc3684b3f5cbb2814cad2982c41af5e70f Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Tue, 21 Nov 2023 16:34:41 -0500 Subject: [PATCH 047/178] qrv: handle unreliable DATA packets --- prudp_client.go | 16 +++++++++- prudp_packet.go | 18 +++++++++++ prudp_packet_interface.go | 1 + prudp_server.go | 65 +++++++++++++++++++++++++++++++++++++-- 4 files changed, 97 insertions(+), 3 deletions(-) diff --git a/prudp_client.go b/prudp_client.go index 3d955eb9..419f57fe 100644 --- a/prudp_client.go +++ b/prudp_client.go @@ -1,6 +1,7 @@ package nex import ( + "crypto/md5" "net" "time" ) @@ -26,6 +27,7 @@ type PRUDPClient struct { supportedFunctions uint32 // * Not currently used for anything, but maybe useful later? ConnectionID uint32 StationURLs []*StationURL + unreliableBaseKey []byte } // reset sets the client back to it's default state @@ -38,7 +40,7 @@ func (c *PRUDPClient) reset() { c.serverConnectionSignature = make([]byte, 0) c.sessionKey = make([]byte, 0) c.reliableSubstreams = make([]*ReliablePacketSubstreamManager, 0) - c.outgoingUnreliableSequenceIDCounter = NewCounter[uint16](0) + c.outgoingUnreliableSequenceIDCounter = NewCounter[uint16](1) c.outgoingPingSequenceIDCounter = NewCounter[uint16](0) c.SourceStreamType = 0 c.SourcePort = 0 @@ -105,6 +107,17 @@ func (c *PRUDPClient) setSessionKey(sessionKey []byte) { substream.SetCipherKey(sessionKey) } + + // * Init the base key used for unreliable DATA packets. + // * + // * Since unreliable DATA packets can come in out of + // * order, each packet uses a dedicated RC4 stream. The + // * key of each RC4 stream is made up by using this base + // * key, modified using the packets sequence/session IDs + unreliableBaseKeyPart1 := md5.Sum(append(sessionKey, []byte{0x18, 0xD8, 0x23, 0x34, 0x37, 0xE4, 0xE3, 0xFE}...)) + unreliableBaseKeyPart2 := md5.Sum(append(sessionKey, []byte{0x23, 0x3E, 0x60, 0x01, 0x23, 0xCD, 0xAB, 0x80}...)) + + c.unreliableBaseKey = append(unreliableBaseKeyPart1[:], unreliableBaseKeyPart2[:]...) } // reliableSubstream returns the clients reliable substream ID @@ -194,5 +207,6 @@ func NewPRUDPClient(address *net.UDPAddr, server *PRUDPServer) *PRUDPClient { server: server, outgoingPingSequenceIDCounter: NewCounter[uint16](0), pid: NewPID[uint32](0), + unreliableBaseKey: make([]byte, 0x20), } } diff --git a/prudp_packet.go b/prudp_packet.go index 81f24359..7c4a9f77 100644 --- a/prudp_packet.go +++ b/prudp_packet.go @@ -1,5 +1,7 @@ package nex +import "crypto/rc4" + // PRUDPPacket holds all the fields each packet should have in all PRUDP versions type PRUDPPacket struct { sender *PRUDPClient @@ -172,3 +174,19 @@ func (p *PRUDPPacket) RMCMessage() *RMCMessage { func (p *PRUDPPacket) SetRMCMessage(message *RMCMessage) { p.message = message } + +func (p *PRUDPPacket) processUnreliableCrypto() []byte { + // * Since unreliable DATA packets can come in out of + // * order, each packet uses a dedicated RC4 stream + uniqueKey := p.sender.unreliableBaseKey[:] + uniqueKey[0] = byte((uint16(uniqueKey[0]) + p.sequenceID) & 0xFF) + uniqueKey[1] = byte((uint16(uniqueKey[1]) + (p.sequenceID >> 8)) & 0xFF) + uniqueKey[31] = byte((uniqueKey[31] + p.sessionID) & 0xFF) + + cipher, _ := rc4.NewCipher(uniqueKey) + ciphered := make([]byte, len(p.payload)) + + cipher.XORKeyStream(ciphered, p.payload) + + return ciphered +} diff --git a/prudp_packet_interface.go b/prudp_packet_interface.go index 1af61274..00046876 100644 --- a/prudp_packet_interface.go +++ b/prudp_packet_interface.go @@ -40,4 +40,5 @@ type PRUDPPacketInterface interface { setConnectionSignature(connectionSignature []byte) getFragmentID() uint8 setFragmentID(fragmentID uint8) + processUnreliableCrypto() []byte } diff --git a/prudp_server.go b/prudp_server.go index 7e0c73e1..4bee9ff4 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -359,6 +359,7 @@ func (s *PRUDPServer) handleConnect(packet PRUDPPacketInterface) { client.minorVersion = ack.minorVersion client.supportedFunctions = ack.supportedFunctions client.createReliableSubstreams(ack.maximumSubstreamID) + client.outgoingUnreliableSequenceIDCounter = NewCounter[uint16](packet.(*PRUDPPacketV1).initialUnreliableSequenceID) } else { client.createReliableSubstreams(0) } @@ -542,7 +543,66 @@ func (s *PRUDPServer) handleReliable(packet PRUDPPacketInterface) { } } -func (s *PRUDPServer) handleUnreliable(packet PRUDPPacketInterface) {} +func (s *PRUDPServer) handleUnreliable(packet PRUDPPacketInterface) { + if packet.HasFlag(FlagNeedsAck) { + s.acknowledgePacket(packet) + } + + // * Since unreliable DATA packets can in theory reach the + // * server in any order, and they lack a substream, it's + // * not actually possible to know what order they should + // * be processed in for each request. So assume all packets + // * MUST be fragment 0 (unreliable packets do not have frags) + // * + // * Example - + // * + // * Say there is 2 requests to the same protocol, methods 1 + // * and 2. The starting unreliable sequence ID is 10. If both + // * method 1 and 2 are called at the same time, but method 1 + // * has a fragmented payload, the packets could, in theory, reach + // * the server like so: + // * + // * - Method1 - Sequence 10, Fragment 1 + // * - Method1 - Sequence 13, Fragment 3 + // * - Method2 - Sequence 12, Fragment 0 + // * - Method1 - Sequence 11, Fragment 2 + // * - Method1 - Sequence 14, Fragment 0 + // * + // * If we reorder these to the proper order, like so: + // * + // * - Method1 - Sequence 10, Fragment 1 + // * - Method1 - Sequence 11, Fragment 2 + // * - Method2 - Sequence 12, Fragment 0 + // * - Method1 - Sequence 13, Fragment 3 + // * - Method1 - Sequence 14, Fragment 0 + // * + // * We still have a gap where Method2 was called. It's not + // * possible to know if the packet with sequence ID 12 belongs + // * to the Method1 calls or not. We don't even know which methods + // * the packets are for at this stage yet, since the RMC data + // * can't be checked until all the fragments are collected and + // * the payload decrypted. In this case, we would see fragment 0 + // * and assume that's the end of fragments, losing the real last + // * fragments and resulting in a bad decryption + // TODO - Is this actually true? I'm just assuming, based on common sense, tbh. Kinnay also does not implement fragmented unreliable packets? + if packet.getFragmentID() != 0 { + logger.Warningf("Unexpected unreliable fragment ID. Expected 0, got %d", packet.getFragmentID()) + return + } + + payload := packet.processUnreliableCrypto() + + message := NewRMCMessage() + err := message.FromBytes(payload) + if err != nil { + // TODO - Should this return the error too? + logger.Error(err.Error()) + } + + packet.SetRMCMessage(message) + + s.emit("data", packet) +} func (s *PRUDPServer) sendPing(client *PRUDPClient) { var ping PRUDPPacketInterface @@ -620,8 +680,9 @@ func (s *PRUDPServer) sendPacket(packet PRUDPPacketInterface) { substream := client.reliableSubstream(packetCopy.SubstreamID()) packetCopy.SetPayload(substream.Encrypt(compressedPayload)) + } else { + packetCopy.SetPayload(packetCopy.processUnreliableCrypto()) } - // TODO - Unreliable crypto } packetCopy.setSignature(packetCopy.calculateSignature(client.sessionKey, client.serverConnectionSignature)) From 5c3039ec7efa60f8d8466058b7902076b587d984 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Thu, 23 Nov 2023 11:19:24 -0500 Subject: [PATCH 048/178] qrv: support insecure PRUDP station crypto --- prudp_packet.go | 9 +++++++++ prudp_server.go | 11 +++++++++++ 2 files changed, 20 insertions(+) diff --git a/prudp_packet.go b/prudp_packet.go index 7c4a9f77..c9ec7105 100644 --- a/prudp_packet.go +++ b/prudp_packet.go @@ -143,6 +143,15 @@ func (p *PRUDPPacket) decryptPayload() []byte { if p.packetType == DataPacket { substream := p.sender.reliableSubstream(p.SubstreamID()) + if !p.sender.server.UseSecurePRUDP { + // * Servers which use the "prudp" scheme instead of + // * the secure "prudps" scheme in their station URL + // * don't use a session key for packet encryption. + // * Instead they use a per-packet RC4 stream using + // * the default key + substream.SetCipherKey([]byte("CD&ML")) + } + payload = substream.Decrypt(payload) } diff --git a/prudp_server.go b/prudp_server.go index 4bee9ff4..e17c4f32 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -41,6 +41,7 @@ type PRUDPServer struct { PasswordFromPID func(pid *PID) (string, uint32) PRUDPv1ConnectionSignatureKey []byte CompressionEnabled bool + UseSecurePRUDP bool } // OnData adds an event handler which is fired when a new DATA packet is received @@ -679,6 +680,15 @@ func (s *PRUDPServer) sendPacket(packet PRUDPPacketInterface) { substream := client.reliableSubstream(packetCopy.SubstreamID()) + if !s.UseSecurePRUDP { + // * Servers which use the "prudp" scheme instead of + // * the secure "prudps" scheme in their station URL + // * don't use a session key for packet encryption. + // * Instead they use a per-packet RC4 stream using + // * the default key + substream.SetCipherKey([]byte("CD&ML")) + } + packetCopy.SetPayload(substream.Encrypt(compressedPayload)) } else { packetCopy.SetPayload(packetCopy.processUnreliableCrypto()) @@ -968,5 +978,6 @@ func NewPRUDPServer() *PRUDPServer { prudpEventHandlers: make(map[string][]func(PacketInterface)), connectionIDCounter: NewCounter[uint32](10), pingTimeout: time.Second * 15, + UseSecurePRUDP: true, } } From 030945bd7511749a9cc00df0e48eb0843775f4d7 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Fri, 24 Nov 2023 17:40:37 -0500 Subject: [PATCH 049/178] StationURL type now uses a MutexMap for fields --- types.go | 410 ++++++------------------------------------------------- 1 file changed, 41 insertions(+), 369 deletions(-) diff --git a/types.go b/types.go index f245193b..35df9f41 100644 --- a/types.go +++ b/types.go @@ -4,7 +4,6 @@ import ( "bytes" "errors" "fmt" - "strconv" "strings" "time" ) @@ -591,416 +590,87 @@ func NewDateTime(value uint64) *DateTime { // StationURL contains the data for a NEX station URL. // Uses pointers to check for nil, 0 is valid type StationURL struct { - local bool // * Not part of the data structure. Used for easier lookups elsewhere - public bool // * Not part of the data structure. Used for easier lookups elsewhere - scheme string - address string - port *uint32 - pl *uint32 // * Seen in Minecraft - stream *uint32 - sid *uint32 - cid *uint32 - pid *PID - transportType *uint32 - rvcid *uint32 - natm *uint32 - natf *uint32 - upnp *uint32 - pmp *uint32 - probeinit *uint32 - prid *uint32 + local bool // * Not part of the data structure. Used for easier lookups elsewhere + public bool // * Not part of the data structure. Used for easier lookups elsewhere + Scheme string + Fields *MutexMap[string, string] } // SetLocal marks the StationURL as an local URL -func (stationURL *StationURL) SetLocal() { - stationURL.local = true - stationURL.public = false +func (s *StationURL) SetLocal() { + s.local = true + s.public = false } // SetPublic marks the StationURL as an public URL -func (stationURL *StationURL) SetPublic() { - stationURL.local = false - stationURL.public = true +func (s *StationURL) SetPublic() { + s.local = false + s.public = true } // IsLocal checks if the StationURL is a local URL -func (stationURL *StationURL) IsLocal() bool { - return stationURL.local +func (s *StationURL) IsLocal() bool { + return s.local } // IsPublic checks if the StationURL is a public URL -func (stationURL *StationURL) IsPublic() bool { - return stationURL.public +func (s *StationURL) IsPublic() bool { + return s.public } -// SetScheme sets the StationURL scheme -func (stationURL *StationURL) SetScheme(scheme string) { - stationURL.scheme = scheme -} - -// SetAddress sets the StationURL address -func (stationURL *StationURL) SetAddress(address string) { - stationURL.address = address -} - -// SetPort sets the StationURL port -func (stationURL *StationURL) SetPort(port uint32) { - stationURL.port = &port -} - -// SetPL sets the StationURL Pl -func (stationURL *StationURL) SetPL(pl uint32) { - stationURL.pl = &pl -} - -// SetStream sets the StationURL stream -func (stationURL *StationURL) SetStream(stream uint32) { - stationURL.stream = &stream -} - -// SetSID sets the StationURL SID -func (stationURL *StationURL) SetSID(sid uint32) { - stationURL.sid = &sid -} - -// SetCID sets the StationURL CID -func (stationURL *StationURL) SetCID(cid uint32) { - stationURL.cid = &cid -} - -// SetPID sets the StationURL PID -func (stationURL *StationURL) SetPID(pid *PID) { - stationURL.pid = pid -} - -// SetType sets the StationURL transportType -func (stationURL *StationURL) SetType(transportType uint32) { - stationURL.transportType = &transportType -} - -// SetRVCID sets the StationURL RVCID -func (stationURL *StationURL) SetRVCID(rvcid uint32) { - stationURL.rvcid = &rvcid -} - -// SetNatm sets the StationURL Natm -func (stationURL *StationURL) SetNatm(natm uint32) { - stationURL.natm = &natm -} - -// SetNatf sets the StationURL Natf -func (stationURL *StationURL) SetNatf(natf uint32) { - stationURL.natf = &natf -} - -// SetUpnp sets the StationURL Upnp -func (stationURL *StationURL) SetUpnp(upnp uint32) { - stationURL.upnp = &upnp -} - -// SetPmp sets the StationURL Pmp -func (stationURL *StationURL) SetPmp(pmp uint32) { - stationURL.pmp = &pmp -} - -// SetProbeInit sets the StationURL ProbeInit -func (stationURL *StationURL) SetProbeInit(probeinit uint32) { - stationURL.probeinit = &probeinit -} - -// SetPRID sets the StationURL PRID -func (stationURL *StationURL) SetPRID(prid uint32) { - stationURL.prid = &prid -} - -// Scheme returns the StationURL scheme type -func (stationURL *StationURL) Scheme() string { - return stationURL.address -} - -// Address returns the StationURL address -func (stationURL *StationURL) Address() string { - return stationURL.address -} - -// Port returns the StationURL port -func (stationURL *StationURL) Port() uint32 { - if stationURL.port == nil { - return 0 - } else { - return *stationURL.port - } -} - -// PL returns the StationURL Pl -func (stationURL *StationURL) PL() uint32 { - if stationURL.pl == nil { - return 0 - } else { - return *stationURL.pl - } -} - -// Stream returns the StationURL stream value -func (stationURL *StationURL) Stream() uint32 { - if stationURL.stream == nil { - return 0 - } else { - return *stationURL.stream - } -} - -// SID returns the StationURL SID value -func (stationURL *StationURL) SID() uint32 { - if stationURL.sid == nil { - return 0 - } else { - return *stationURL.sid - } -} - -// CID returns the StationURL CID value -func (stationURL *StationURL) CID() uint32 { - if stationURL.cid == nil { - return 0 - } else { - return *stationURL.cid - } -} - -// PID returns the StationURL PID value -func (stationURL *StationURL) PID() *PID { - return stationURL.pid -} - -// Type returns the StationURL type -func (stationURL *StationURL) Type() uint32 { - if stationURL.transportType == nil { - return 0 - } else { - return *stationURL.transportType - } -} - -// RVCID returns the StationURL RVCID -func (stationURL *StationURL) RVCID() uint32 { - if stationURL.rvcid == nil { - return 0 - } else { - return *stationURL.rvcid - } -} - -// Natm returns the StationURL Natm value -func (stationURL *StationURL) Natm() uint32 { - if stationURL.natm == nil { - return 0 - } else { - return *stationURL.natm - } -} - -// Natf returns the StationURL Natf value -func (stationURL *StationURL) Natf() uint32 { - if stationURL.natf == nil { - return 0 - } else { - return *stationURL.natf - } -} - -// Upnp returns the StationURL Upnp value -func (stationURL *StationURL) Upnp() uint32 { - if stationURL.upnp == nil { - return 0 - } else { - return *stationURL.upnp - } -} - -// Pmp returns the StationURL Pmp value -func (stationURL *StationURL) Pmp() uint32 { - if stationURL.pmp == nil { - return 0 - } else { - return *stationURL.pmp - } -} +// FromString parses the StationURL data from a string +func (s *StationURL) FromString(str string) { + split := strings.Split(str, ":/") -// ProbeInit returns the StationURL ProbeInit value -func (stationURL *StationURL) ProbeInit() uint32 { - if stationURL.probeinit == nil { - return 0 - } else { - return *stationURL.probeinit - } -} + s.Scheme = split[0] + fields := strings.Split(split[1], ";") -// PRID returns the StationURL PRID value -func (stationURL *StationURL) PRID() uint32 { - if stationURL.prid == nil { - return 0 - } else { - return *stationURL.prid - } -} + for i := 0; i < len(fields); i++ { + field := strings.Split(fields[i], "=") -// FromString parses the StationURL data from a string -func (stationURL *StationURL) FromString(str string) { - split := strings.Split(str, ":/") + key := field[0] + value := field[1] - stationURL.scheme = split[0] - fields := split[1] - - params := strings.Split(fields, ";") - - for i := 0; i < len(params); i++ { - param := params[i] - split = strings.Split(param, "=") - - name := split[0] - value := split[1] - - switch name { - case "address": - stationURL.address = value - case "port": - ui64, _ := strconv.ParseUint(value, 10, 32) - stationURL.SetPort(uint32(ui64)) - case "Pl": - ui64, _ := strconv.ParseUint(value, 10, 32) - stationURL.SetPL(uint32(ui64)) - case "stream": - ui64, _ := strconv.ParseUint(value, 10, 32) - stationURL.SetStream(uint32(ui64)) - case "sid": - ui64, _ := strconv.ParseUint(value, 10, 32) - stationURL.SetSID(uint32(ui64)) - case "CID": - ui64, _ := strconv.ParseUint(value, 10, 32) - stationURL.SetCID(uint32(ui64)) - case "PID": - ui64, _ := strconv.ParseUint(value, 10, 64) - stationURL.SetPID(NewPID(ui64)) - case "type": - ui64, _ := strconv.ParseUint(value, 10, 32) - stationURL.SetType(uint32(ui64)) - case "RVCID": - ui64, _ := strconv.ParseUint(value, 10, 32) - stationURL.SetRVCID(uint32(ui64)) - case "natm": - ui64, _ := strconv.ParseUint(value, 10, 32) - stationURL.SetNatm(uint32(ui64)) - case "natf": - ui64, _ := strconv.ParseUint(value, 10, 32) - stationURL.SetNatf(uint32(ui64)) - case "upnp": - ui64, _ := strconv.ParseUint(value, 10, 32) - stationURL.SetUpnp(uint32(ui64)) - case "pmp": - ui64, _ := strconv.ParseUint(value, 10, 32) - stationURL.SetPmp(uint32(ui64)) - case "probeinit": - ui64, _ := strconv.ParseUint(value, 10, 32) - stationURL.SetProbeInit(uint32(ui64)) - case "PRID": - ui64, _ := strconv.ParseUint(value, 10, 32) - stationURL.SetPRID(uint32(ui64)) - } + s.Fields.Set(key, value) } } // EncodeToString encodes the StationURL into a string -func (stationURL *StationURL) EncodeToString() string { +func (s *StationURL) EncodeToString() string { fields := []string{} - if stationURL.address != "" { - fields = append(fields, "address="+stationURL.address) - } - - if stationURL.port != nil { - fields = append(fields, "port="+strconv.FormatUint(uint64(stationURL.Port()), 10)) - } - - if stationURL.pl != nil { - fields = append(fields, "Pl="+strconv.FormatUint(uint64(stationURL.PL()), 10)) - } - - if stationURL.stream != nil { - fields = append(fields, "stream="+strconv.FormatUint(uint64(stationURL.Stream()), 10)) - } - - if stationURL.sid != nil { - fields = append(fields, "sid="+strconv.FormatUint(uint64(stationURL.SID()), 10)) - } - - if stationURL.cid != nil { - fields = append(fields, "CID="+strconv.FormatUint(uint64(stationURL.CID()), 10)) - } - - if stationURL.pid != nil { - fields = append(fields, "PID="+strconv.FormatUint(uint64(stationURL.PID().pid), 10)) - } - - if stationURL.transportType != nil { - fields = append(fields, "type="+strconv.FormatUint(uint64(stationURL.Type()), 10)) - } - - if stationURL.rvcid != nil { - fields = append(fields, "RVCID="+strconv.FormatUint(uint64(stationURL.RVCID()), 10)) - } - - if stationURL.natm != nil { - fields = append(fields, "natm="+strconv.FormatUint(uint64(stationURL.Natm()), 10)) - } - - if stationURL.natf != nil { - fields = append(fields, "natf="+strconv.FormatUint(uint64(stationURL.Natf()), 10)) - } - - if stationURL.upnp != nil { - fields = append(fields, "upnp="+strconv.FormatUint(uint64(stationURL.Upnp()), 10)) - } - - if stationURL.pmp != nil { - fields = append(fields, "pmp="+strconv.FormatUint(uint64(stationURL.Pmp()), 10)) - } - - if stationURL.probeinit != nil { - fields = append(fields, "probeinit="+strconv.FormatUint(uint64(stationURL.ProbeInit()), 10)) - } - - if stationURL.prid != nil { - fields = append(fields, "PRID="+strconv.FormatUint(uint64(stationURL.PRID()), 10)) - } + s.Fields.Each(func(key, value string) bool { + fields = append(fields, fmt.Sprintf("%s=%s", key, value)) + return false + }) - return stationURL.scheme + ":/" + strings.Join(fields, ";") + return s.Scheme + ":/" + strings.Join(fields, ";") } // Copy returns a new copied instance of StationURL -func (stationURL *StationURL) Copy() *StationURL { - return NewStationURL(stationURL.EncodeToString()) +func (s *StationURL) Copy() *StationURL { + return NewStationURL(s.EncodeToString()) } // Equals checks if the passed Structure contains the same data as the current instance -func (stationURL *StationURL) Equals(other *StationURL) bool { - return stationURL.EncodeToString() == other.EncodeToString() +func (s *StationURL) Equals(other *StationURL) bool { + return s.EncodeToString() == other.EncodeToString() } // String returns a string representation of the struct -func (stationURL *StationURL) String() string { - return stationURL.FormatToString(0) +func (s *StationURL) String() string { + return s.FormatToString(0) } // FormatToString pretty-prints the struct data using the provided indentation level -func (stationURL *StationURL) FormatToString(indentationLevel int) string { +func (s *StationURL) FormatToString(indentationLevel int) string { indentationValues := strings.Repeat("\t", indentationLevel+1) indentationEnd := strings.Repeat("\t", indentationLevel) var b strings.Builder b.WriteString("StationURL{\n") - b.WriteString(fmt.Sprintf("%surl: %q\n", indentationValues, stationURL.EncodeToString())) + b.WriteString(fmt.Sprintf("%surl: %q\n", indentationValues, s.EncodeToString())) b.WriteString(fmt.Sprintf("%s}", indentationEnd)) return b.String() @@ -1008,7 +678,9 @@ func (stationURL *StationURL) FormatToString(indentationLevel int) string { // NewStationURL returns a new StationURL func NewStationURL(str string) *StationURL { - stationURL := &StationURL{} + stationURL := &StationURL{ + Fields: NewMutexMap[string, string](), + } if str != "" { stationURL.FromString(str) From a12c70e2001d113e26a9409e44a065b649eb7324 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Fri, 1 Dec 2023 01:03:45 -0500 Subject: [PATCH 050/178] added all basic types to mapTypeWriter --- stream_out.go | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/stream_out.go b/stream_out.go index b3a2e6d0..8a66a83f 100644 --- a/stream_out.go +++ b/stream_out.go @@ -485,8 +485,38 @@ func mapTypeWriter[T any](stream *StreamOut, t T) { // * key and value. So we need to just check the // * type each time and call the right function switch v := any(t).(type) { + case uint8: + stream.WriteUInt8(v) + case int8: + stream.WriteInt8(v) + case uint16: + stream.WriteUInt16LE(v) + case int16: + stream.WriteInt16LE(v) + case uint32: + stream.WriteUInt32LE(v) + case int32: + stream.WriteInt32LE(v) + case uint64: + stream.WriteUInt64LE(v) + case int64: + stream.WriteInt64LE(v) + case float32: + stream.WriteFloat32LE(v) + case float64: + stream.WriteFloat64LE(v) case string: stream.WriteString(v) + case bool: + stream.WriteBool(v) + case []byte: + // * This actually isn't a good situation, since a byte slice can be either + // * a Buffer or qBuffer. The only known official case is a qBuffer, inside + // * UserAccountManagement::LookupSceNpIds, which is why it's implemented + // * as a qBuffer + stream.WriteQBuffer(v) // TODO - Maybe we should make Buffer and qBuffer real types? + case StructureInterface: + stream.WriteStructure(v) case *Variant: stream.WriteVariant(v) default: From 9d16a4121772807b8dc9a90345f2ebece5ccc2b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20L=C3=B3pez=20Guimaraes?= Date: Wed, 6 Dec 2023 16:45:01 +0000 Subject: [PATCH 051/178] qrv: Set packet signature as connection signature This behavior has been confirmed on other Quazal server implementations. --- prudp_packet_v0.go | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/prudp_packet_v0.go b/prudp_packet_v0.go index 9b93080a..8f0f62d2 100644 --- a/prudp_packet_v0.go +++ b/prudp_packet_v0.go @@ -265,12 +265,14 @@ func (p *PRUDPPacketV0) calculateConnectionSignature(addr net.Addr) ([]byte, err } func (p *PRUDPPacketV0) calculateSignature(sessionKey, connectionSignature []byte) []byte { - if p.packetType == DataPacket { - return p.calculateDataSignature(sessionKey) - } + if !p.sender.server.IsQuazalMode { + if p.packetType == DataPacket { + return p.calculateDataSignature(sessionKey) + } - if p.packetType == DisconnectPacket && p.sender.server.accessKey != "ridfebb9" { - return p.calculateDataSignature(sessionKey) + if p.packetType == DisconnectPacket && p.sender.server.accessKey != "ridfebb9" { + return p.calculateDataSignature(sessionKey) + } } if len(connectionSignature) != 0 { From 1b91bb1c220adc58318d6af8e56c312c78c291b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20L=C3=B3pez=20Guimaraes?= Date: Wed, 6 Dec 2023 16:48:40 +0000 Subject: [PATCH 052/178] qrv: Extend insecure crypto to all servers All Quazal servers seem to use insecure encryption, including "secure" servers. --- prudp_packet.go | 11 +++++------ prudp_server.go | 13 +++++-------- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/prudp_packet.go b/prudp_packet.go index c9ec7105..0081d46a 100644 --- a/prudp_packet.go +++ b/prudp_packet.go @@ -143,12 +143,11 @@ func (p *PRUDPPacket) decryptPayload() []byte { if p.packetType == DataPacket { substream := p.sender.reliableSubstream(p.SubstreamID()) - if !p.sender.server.UseSecurePRUDP { - // * Servers which use the "prudp" scheme instead of - // * the secure "prudps" scheme in their station URL - // * don't use a session key for packet encryption. - // * Instead they use a per-packet RC4 stream using - // * the default key + // * According to other Quazal server implementations, + // * the RC4 stream is always reset to the default key + // * regardless if the client is connecting to a secure + // * server (prudps) or not + if p.sender.server.IsQuazalMode { substream.SetCipherKey([]byte("CD&ML")) } diff --git a/prudp_server.go b/prudp_server.go index e17c4f32..4332f258 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -41,7 +41,6 @@ type PRUDPServer struct { PasswordFromPID func(pid *PID) (string, uint32) PRUDPv1ConnectionSignatureKey []byte CompressionEnabled bool - UseSecurePRUDP bool } // OnData adds an event handler which is fired when a new DATA packet is received @@ -680,12 +679,11 @@ func (s *PRUDPServer) sendPacket(packet PRUDPPacketInterface) { substream := client.reliableSubstream(packetCopy.SubstreamID()) - if !s.UseSecurePRUDP { - // * Servers which use the "prudp" scheme instead of - // * the secure "prudps" scheme in their station URL - // * don't use a session key for packet encryption. - // * Instead they use a per-packet RC4 stream using - // * the default key + // * According to other Quazal server implementations, + // * the RC4 stream is always reset to the default key + // * regardless if the client is connecting to a secure + // * server (prudps) or not + if s.IsQuazalMode { substream.SetCipherKey([]byte("CD&ML")) } @@ -978,6 +976,5 @@ func NewPRUDPServer() *PRUDPServer { prudpEventHandlers: make(map[string][]func(PacketInterface)), connectionIDCounter: NewCounter[uint32](10), pingTimeout: time.Second * 15, - UseSecurePRUDP: true, } } From e8049a2bdb489fae7766dc6485632eeb03f9f067 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20L=C3=B3pez=20Guimaraes?= Date: Wed, 6 Dec 2023 16:50:05 +0000 Subject: [PATCH 053/178] test: Make tests compile again --- test/secure.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/secure.go b/test/secure.go index cf92c4d6..45796500 100644 --- a/test/secure.go +++ b/test/secure.go @@ -3,6 +3,7 @@ package main import ( "fmt" "net" + "strconv" "github.com/PretendoNetwork/nex-go" ) @@ -106,8 +107,8 @@ func registerEx(packet nex.PRUDPPacketInterface) { address := packet.Sender().Address().(*net.UDPAddr).IP.String() - localStation.SetAddress(address) - localStation.SetPort(uint32(packet.Sender().Address().(*net.UDPAddr).Port)) + localStation.Fields.Set("address", address) + localStation.Fields.Set("port", strconv.Itoa(packet.Sender().Address().(*net.UDPAddr).Port)) retval := nex.NewResultSuccess(0x00010001) localStationURL := localStation.EncodeToString() From d7c8c154595aa6eb48bfd4d6e113fc4bc1aa073d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20L=C3=B3pez=20Guimaraes?= Date: Wed, 6 Dec 2023 23:04:57 +0000 Subject: [PATCH 054/178] hpp: Add IsHPP to RMC password signature error --- hpp_server.go | 1 + 1 file changed, 1 insertion(+) diff --git a/hpp_server.go b/hpp_server.go index b2727537..3591fd7e 100644 --- a/hpp_server.go +++ b/hpp_server.go @@ -102,6 +102,7 @@ func (s *HPPServer) handleRequest(w http.ResponseWriter, req *http.Request) { // HPP returns PythonCore::ValidationError if password is missing or invalid errorResponse := NewRMCError(Errors.PythonCore.ValidationError) errorResponse.CallID = rmcMessage.CallID + errorResponse.IsHPP = true _, err = w.Write(errorResponse.Bytes()) if err != nil { From 7b7b9c59d7b498c7acc27915c672626efeaefcef Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Fri, 8 Dec 2023 01:24:57 -0500 Subject: [PATCH 055/178] prudp: add support for virtual connections --- prudp_client.go | 3 +- prudp_packet.go | 6 ++ prudp_packet_interface.go | 1 + prudp_packet_v0.go | 18 ++-- prudp_packet_v1.go | 10 ++- prudp_server.go | 94 +++++++++++---------- prudp_virtual_connection_manager.go | 126 ++++++++++++++++++++++++++++ resend_scheduler.go | 5 +- 8 files changed, 211 insertions(+), 52 deletions(-) create mode 100644 prudp_virtual_connection_manager.go diff --git a/prudp_client.go b/prudp_client.go index 419f57fe..3f986218 100644 --- a/prudp_client.go +++ b/prudp_client.go @@ -185,7 +185,8 @@ func (c *PRUDPClient) startHeartbeat() { // * client is dead and clean up c.pingKickTimer = time.AfterFunc(server.pingTimeout, func() { c.cleanup() // * "removed" event is dispatched here - c.server.clients.Delete(c.address.String()) + virtualStream := c.server.virtualConnectionManager.Get(c.DestinationPort, c.DestinationStreamType) + virtualStream.clients.Delete(c.address.String()) }) }) } diff --git a/prudp_packet.go b/prudp_packet.go index 0081d46a..2a4463e4 100644 --- a/prudp_packet.go +++ b/prudp_packet.go @@ -4,6 +4,7 @@ import "crypto/rc4" // PRUDPPacket holds all the fields each packet should have in all PRUDP versions type PRUDPPacket struct { + server *PRUDPServer sender *PRUDPClient readStream *StreamIn sourceStreamType uint8 @@ -22,6 +23,11 @@ type PRUDPPacket struct { message *RMCMessage } +// SetSender sets the Client who sent the packet +func (p *PRUDPPacket) SetSender(sender ClientInterface) { + p.sender = sender.(*PRUDPClient) +} + // Sender returns the Client who sent the packet func (p *PRUDPPacket) Sender() ClientInterface { return p.sender diff --git a/prudp_packet_interface.go b/prudp_packet_interface.go index 00046876..94f441aa 100644 --- a/prudp_packet_interface.go +++ b/prudp_packet_interface.go @@ -7,6 +7,7 @@ type PRUDPPacketInterface interface { Copy() PRUDPPacketInterface Version() int Bytes() []byte + SetSender(sender ClientInterface) Sender() ClientInterface Flags() uint16 HasFlag(flag uint16) bool diff --git a/prudp_packet_v0.go b/prudp_packet_v0.go index 8f0f62d2..df09c4a8 100644 --- a/prudp_packet_v0.go +++ b/prudp_packet_v0.go @@ -21,6 +21,7 @@ type PRUDPPacketV0 struct { func (p *PRUDPPacketV0) Copy() PRUDPPacketInterface { copied, _ := NewPRUDPPacketV0(p.sender, nil) + copied.server = p.server copied.sourceStreamType = p.sourceStreamType copied.sourcePort = p.sourcePort copied.destinationStreamType = p.destinationStreamType @@ -64,7 +65,7 @@ func (p *PRUDPPacketV0) decode() error { return errors.New("Failed to read PRUDPv0 header. Not have enough data") } - server := p.sender.server + server := p.server start := p.readStream.ByteOffset() source, err := p.readStream.ReadUInt8() @@ -194,7 +195,7 @@ func (p *PRUDPPacketV0) decode() error { // Bytes encodes a PRUDPv0 packet into a byte slice func (p *PRUDPPacketV0) Bytes() []byte { - server := p.sender.server + server := p.server stream := NewStreamOut(server) stream.WriteUInt8(p.sourcePort | (p.sourceStreamType << 4)) @@ -265,12 +266,12 @@ func (p *PRUDPPacketV0) calculateConnectionSignature(addr net.Addr) ([]byte, err } func (p *PRUDPPacketV0) calculateSignature(sessionKey, connectionSignature []byte) []byte { - if !p.sender.server.IsQuazalMode { + if !p.server.IsQuazalMode { if p.packetType == DataPacket { return p.calculateDataSignature(sessionKey) } - if p.packetType == DisconnectPacket && p.sender.server.accessKey != "ridfebb9" { + if p.packetType == DisconnectPacket && p.server.accessKey != "ridfebb9" { return p.calculateDataSignature(sessionKey) } } @@ -283,7 +284,7 @@ func (p *PRUDPPacketV0) calculateSignature(sessionKey, connectionSignature []byt } func (p *PRUDPPacketV0) calculateDataSignature(sessionKey []byte) []byte { - server := p.sender.server + server := p.server data := p.payload if server.AccessKey() != "ridfebb9" { @@ -309,7 +310,7 @@ func (p *PRUDPPacketV0) calculateDataSignature(sessionKey []byte) []byte { } func (p *PRUDPPacketV0) calculateChecksum(data []byte) uint32 { - server := p.sender.server + server := p.server checksum := sum[byte, uint32]([]byte(server.AccessKey())) if server.IsQuazalMode { @@ -355,12 +356,17 @@ func NewPRUDPPacketV0(client *PRUDPClient, readStream *StreamIn) (*PRUDPPacketV0 } if readStream != nil { + packet.server = readStream.Server.(*PRUDPServer) err := packet.decode() if err != nil { return nil, fmt.Errorf("Failed to decode PRUDPv0 packet. %s", err.Error()) } } + if client != nil { + packet.server = client.server + } + return packet, nil } diff --git a/prudp_packet_v1.go b/prudp_packet_v1.go index e46f5c56..7deacd87 100644 --- a/prudp_packet_v1.go +++ b/prudp_packet_v1.go @@ -27,6 +27,7 @@ type PRUDPPacketV1 struct { func (p *PRUDPPacketV1) Copy() PRUDPPacketInterface { copied, _ := NewPRUDPPacketV1(p.sender, nil) + copied.server = p.server copied.sourceStreamType = p.sourceStreamType copied.sourcePort = p.sourcePort copied.destinationStreamType = p.destinationStreamType @@ -326,14 +327,14 @@ func (p *PRUDPPacketV1) calculateConnectionSignature(addr net.Addr) ([]byte, err binary.BigEndian.PutUint16(portBytes, uint16(port)) data := append(ip, portBytes...) - hash := hmac.New(md5.New, p.sender.server.PRUDPv1ConnectionSignatureKey) + hash := hmac.New(md5.New, p.server.PRUDPv1ConnectionSignatureKey) hash.Write(data) return hash.Sum(nil), nil } func (p *PRUDPPacketV1) calculateSignature(sessionKey, connectionSignature []byte) []byte { - accessKeyBytes := []byte(p.sender.server.accessKey) + accessKeyBytes := []byte(p.server.accessKey) options := p.encodeOptions() header := p.encodeHeader() @@ -364,12 +365,17 @@ func NewPRUDPPacketV1(client *PRUDPClient, readStream *StreamIn) (*PRUDPPacketV1 } if readStream != nil { + packet.server = readStream.Server.(*PRUDPServer) err := packet.decode() if err != nil { return nil, fmt.Errorf("Failed to decode PRUDPv1 packet. %s", err.Error()) } } + if client != nil { + packet.server = client.server + } + return packet, nil } diff --git a/prudp_server.go b/prudp_server.go index 4332f258..9fa7325a 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -15,9 +15,10 @@ import ( // PRUDPServer represents a bare-bones PRUDP server type PRUDPServer struct { udpSocket *net.UDPConn - clients *MutexMap[string, *PRUDPClient] PRUDPVersion int PRUDPMinorVersion uint32 + MaxPRUDPVirtualPorts uint8 + virtualConnectionManager *PRUDPVirtualConnectionManager IsQuazalMode bool IsSecureServer bool SupportedFunctions uint32 @@ -108,6 +109,9 @@ func (s *PRUDPServer) Listen(port int) { s.udpSocket = socket + s.virtualConnectionManager = NewPRUDPVirtualConnectionManager(s.MaxPRUDPVirtualPorts) + logger.Success("Virtual ports created") + quit := make(chan struct{}) for i := 0; i < runtime.NumCPU(); i++ { @@ -137,16 +141,41 @@ func (s *PRUDPServer) handleSocketMessage() error { return err } - discriminator := addr.String() + packetData := buffer[:read] + readStream := NewStreamIn(packetData, s) + + var packets []PRUDPPacketInterface + + // * Support any packet type the client sends and respond + // * with that same type. Also keep reading from the stream + // * until no more data is left, to account for multiple + // * packets being sent at once + if bytes.Equal(packetData[:2], []byte{0xEA, 0xD0}) { + packets, _ = NewPRUDPPacketsV1(nil, readStream) + } else { + packets, _ = NewPRUDPPacketsV0(nil, readStream) + } + + for _, packet := range packets { + go s.processPacket(packet, addr) + } + + return nil +} + +func (s *PRUDPServer) processPacket(packet PRUDPPacketInterface, address *net.UDPAddr) { + virtualStream := s.virtualConnectionManager.Get(packet.DestinationPort(), packet.DestinationStreamType()) - client, ok := s.clients.Get(discriminator) + clientSocketDiscriminator := address.String() + + client, ok := virtualStream.clients.Get(clientSocketDiscriminator) if !ok { - client = NewPRUDPClient(addr, s) + client = NewPRUDPClient(address, s) client.startHeartbeat() // * Fail-safe. If the server reboots, then - // * s.clients has no record of old clients. + // * clients has no record of old clients. // * An existing client which has not killed // * the connection on it's end MAY still send // * DATA packets once the server is back @@ -166,33 +195,11 @@ func (s *PRUDPServer) handleSocketMessage() error { // * EXPECTED TO NATURALLY DIE HERE client.createReliableSubstreams(0) - s.clients.Set(discriminator, client) - } - - packetData := buffer[:read] - readStream := NewStreamIn(packetData, s) - - var packets []PRUDPPacketInterface - - // * Support any packet type the client sends and respond - // * with that same type. Also keep reading from the stream - // * until no more data is left, to account for multiple - // * packets being sent at once - if bytes.Equal(packetData[:2], []byte{0xEA, 0xD0}) { - packets, _ = NewPRUDPPacketsV1(client, readStream) - } else { - packets, _ = NewPRUDPPacketsV0(client, readStream) - } - - for _, packet := range packets { - go s.processPacket(packet) + virtualStream.clients.Set(clientSocketDiscriminator, client) } - return nil -} - -func (s *PRUDPServer) processPacket(packet PRUDPPacketInterface) { - packet.Sender().(*PRUDPClient).resetHeartbeat() + packet.SetSender(client) + client.resetHeartbeat() if packet.HasFlag(FlagAck) || packet.HasFlag(FlagMultiAck) { s.handleAcknowledgment(packet) @@ -409,10 +416,11 @@ func (s *PRUDPServer) handleDisconnect(packet PRUDPPacketInterface) { s.acknowledgePacket(packet) } + virtualStream := s.virtualConnectionManager.Get(packet.DestinationPort(), packet.DestinationStreamType()) client := packet.Sender().(*PRUDPClient) client.cleanup() // * "removed" event is dispatched here - s.clients.Delete(client.address.String()) + virtualStream.clients.Delete(client.address.String()) s.emit("disconnect", packet) } @@ -935,10 +943,11 @@ func (s *PRUDPServer) ConnectionIDCounter() *Counter[uint32] { } // FindClientByConnectionID returns the PRUDP client connected with the given connection ID -func (s *PRUDPServer) FindClientByConnectionID(connectedID uint32) *PRUDPClient { +func (s *PRUDPServer) FindClientByConnectionID(port, streamType uint8, connectedID uint32) *PRUDPClient { var client *PRUDPClient - s.clients.Each(func(discriminator string, c *PRUDPClient) bool { + virtualStream := s.virtualConnectionManager.Get(port, streamType) + virtualStream.clients.Each(func(discriminator string, c *PRUDPClient) bool { if c.ConnectionID == connectedID { client = c return true @@ -951,10 +960,11 @@ func (s *PRUDPServer) FindClientByConnectionID(connectedID uint32) *PRUDPClient } // FindClientByPID returns the PRUDP client connected with the given PID -func (s *PRUDPServer) FindClientByPID(pid uint64) *PRUDPClient { +func (s *PRUDPServer) FindClientByPID(port, streamType uint8, pid uint64) *PRUDPClient { var client *PRUDPClient - s.clients.Each(func(discriminator string, c *PRUDPClient) bool { + virtualStream := s.virtualConnectionManager.Get(port, streamType) + virtualStream.clients.Each(func(discriminator string, c *PRUDPClient) bool { if c.pid.pid == pid { client = c return true @@ -969,12 +979,12 @@ func (s *PRUDPServer) FindClientByPID(pid uint64) *PRUDPClient { // NewPRUDPServer will return a new PRUDP server func NewPRUDPServer() *PRUDPServer { return &PRUDPServer{ - clients: NewMutexMap[string, *PRUDPClient](), - IsQuazalMode: false, - kerberosKeySize: 32, - FragmentSize: 1300, - prudpEventHandlers: make(map[string][]func(PacketInterface)), - connectionIDCounter: NewCounter[uint32](10), - pingTimeout: time.Second * 15, + MaxPRUDPVirtualPorts: 16, // * UDP PRUDP servers use 16 virtual ports + IsQuazalMode: false, + kerberosKeySize: 32, + FragmentSize: 1300, + prudpEventHandlers: make(map[string][]func(PacketInterface)), + connectionIDCounter: NewCounter[uint32](10), + pingTimeout: time.Second * 15, } } diff --git a/prudp_virtual_connection_manager.go b/prudp_virtual_connection_manager.go new file mode 100644 index 00000000..108585aa --- /dev/null +++ b/prudp_virtual_connection_manager.go @@ -0,0 +1,126 @@ +package nex + +const ( + // VirtualStreamTypeDO represents the DO PRUDP virtual connection stream type + VirtualStreamTypeDO uint8 = 1 + + // VirtualStreamTypeRV represents the RV PRUDP virtual connection stream type + VirtualStreamTypeRV uint8 = 2 + + // VirtualStreamTypeOldRVSec represents the OldRVSec PRUDP virtual connection stream type + VirtualStreamTypeOldRVSec uint8 = 3 + + // VirtualStreamTypeSBMGMT represents the SBMGMT PRUDP virtual connection stream type + VirtualStreamTypeSBMGMT uint8 = 4 + + // VirtualStreamTypeNAT represents the NAT PRUDP virtual connection stream type + VirtualStreamTypeNAT uint8 = 5 + + // VirtualStreamTypeSessionDiscovery represents the SessionDiscovery PRUDP virtual connection stream type + VirtualStreamTypeSessionDiscovery uint8 = 6 + + // VirtualStreamTypeNATEcho represents the NATEcho PRUDP virtual connection stream type + VirtualStreamTypeNATEcho uint8 = 7 + + // VirtualStreamTypeRouting represents the Routing PRUDP virtual connection stream type + VirtualStreamTypeRouting uint8 = 8 + + // VirtualStreamTypeGame represents the Game PRUDP virtual connection stream type + VirtualStreamTypeGame uint8 = 9 + + // VirtualStreamTypeRVSecure represents the RVSecure PRUDP virtual connection stream type + VirtualStreamTypeRVSecure uint8 = 10 + + // VirtualStreamTypeRelay represents the Relay PRUDP virtual connection stream type + VirtualStreamTypeRelay uint8 = 11 +) + +// PRUDPVirtualStream represents a PRUDP virtual stream +type PRUDPVirtualStream struct { + clients *MutexMap[string, *PRUDPClient] +} + +// PRUDPVirtualPort represents a PRUDP virtual connections virtual port +type PRUDPVirtualPort struct { + streams *MutexMap[uint8, *PRUDPVirtualStream] +} + +func (vp *PRUDPVirtualPort) init() { + vp.initStream(VirtualStreamTypeDO) + vp.initStream(VirtualStreamTypeRV) + vp.initStream(VirtualStreamTypeOldRVSec) + vp.initStream(VirtualStreamTypeSBMGMT) + vp.initStream(VirtualStreamTypeNAT) + vp.initStream(VirtualStreamTypeSessionDiscovery) + vp.initStream(VirtualStreamTypeNATEcho) + vp.initStream(VirtualStreamTypeRouting) + vp.initStream(VirtualStreamTypeGame) + vp.initStream(VirtualStreamTypeRVSecure) + vp.initStream(VirtualStreamTypeRelay) +} + +func (vp *PRUDPVirtualPort) initStream(streamType uint8) *PRUDPVirtualStream { + virtualStream := &PRUDPVirtualStream{ + clients: NewMutexMap[string, *PRUDPClient](), + } + + vp.streams.Set(streamType, virtualStream) + + return virtualStream +} + +// PRUDPVirtualConnectionManager manages virtual ports used by PRUDP connections +// +// PRUDP uses a single UDP connection to establish multiple "connections" through virtual ports +type PRUDPVirtualConnectionManager struct { + ports *MutexMap[uint8, *PRUDPVirtualPort] +} + +func (vcm *PRUDPVirtualConnectionManager) init(numberOfPorts uint8) { + for i := 0; i < int(numberOfPorts); i++ { + vcm.createVirtualPort(uint8(i)) + } +} + +func (vcm *PRUDPVirtualConnectionManager) createVirtualPort(port uint8) *PRUDPVirtualPort { + virtualPort := &PRUDPVirtualPort{ + streams: NewMutexMap[uint8, *PRUDPVirtualStream](), + } + + virtualPort.init() + + vcm.ports.Set(port, virtualPort) + + return virtualPort +} + +// Get returns PRUDPVirtualStream for the given port and stream type. +// If either the virtual port or stream type do not exist, new ones are created +func (vcm *PRUDPVirtualConnectionManager) Get(port, streamType uint8) *PRUDPVirtualStream { + virtualPort, ok := vcm.ports.Get(port) + if !ok { + // * Just force the port to exist + virtualPort = vcm.createVirtualPort(port) + logger.Warningf("Invalid virtual port %d trying to be accessed. Creating new one to prevent crash", port) + } + + virtualStream, ok := virtualPort.streams.Get(streamType) + if !ok { + // * Just force the stream to exist + virtualStream = virtualPort.initStream(streamType) + logger.Warningf("Invalid virtual stream type %d trying to be accessed. Creating new one to prevent crash", streamType) + } + + return virtualStream +} + +// NewPRUDPVirtualConnectionManager creates a new PRUDPVirtualConnectionManager with the given number of virtual ports +func NewPRUDPVirtualConnectionManager(numberOfPorts uint8) *PRUDPVirtualConnectionManager { + virtualConnectionManager := &PRUDPVirtualConnectionManager{ + ports: NewMutexMap[uint8, *PRUDPVirtualPort](), + } + + virtualConnectionManager.init(numberOfPorts) + + return virtualConnectionManager +} diff --git a/resend_scheduler.go b/resend_scheduler.go index f351bacf..54c46fde 100644 --- a/resend_scheduler.go +++ b/resend_scheduler.go @@ -93,7 +93,10 @@ func (rs *ResendScheduler) resendPacket(pendingPacket *PendingPacket) { pendingPacket.ticker.Stop() rs.packets.Delete(packet.SequenceID()) client.cleanup() // * "removed" event is dispatched here - client.server.clients.Delete(client.address.String()) + + virtualStream := client.server.virtualConnectionManager.Get(client.DestinationPort, client.DestinationStreamType) + virtualStream.clients.Delete(client.address.String()) + return } From 109fd47a19568963109a188b87873eb399032f5a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20L=C3=B3pez=20Guimaraes?= Date: Fri, 8 Dec 2023 11:06:21 +0000 Subject: [PATCH 056/178] prudp: Virtual connections now use source port --- prudp_server.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/prudp_server.go b/prudp_server.go index 9fa7325a..ff652bcb 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -164,7 +164,7 @@ func (s *PRUDPServer) handleSocketMessage() error { } func (s *PRUDPServer) processPacket(packet PRUDPPacketInterface, address *net.UDPAddr) { - virtualStream := s.virtualConnectionManager.Get(packet.DestinationPort(), packet.DestinationStreamType()) + virtualStream := s.virtualConnectionManager.Get(packet.SourcePort(), packet.SourceStreamType()) clientSocketDiscriminator := address.String() @@ -416,7 +416,7 @@ func (s *PRUDPServer) handleDisconnect(packet PRUDPPacketInterface) { s.acknowledgePacket(packet) } - virtualStream := s.virtualConnectionManager.Get(packet.DestinationPort(), packet.DestinationStreamType()) + virtualStream := s.virtualConnectionManager.Get(packet.SourcePort(), packet.SourceStreamType()) client := packet.Sender().(*PRUDPClient) client.cleanup() // * "removed" event is dispatched here From 1970b2a73e4ccf8ef251559fcbee0b19af1a2bca Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Fri, 8 Dec 2023 16:08:27 -0500 Subject: [PATCH 057/178] prudp: support server vports and move client vport to discriminator --- prudp_client.go | 10 +- prudp_server.go | 153 +++++++++++++++++++--------- prudp_virtual_connection_manager.go | 126 ----------------------- prudp_virtual_stream_types.go | 36 +++++++ resend_scheduler.go | 9 +- 5 files changed, 155 insertions(+), 179 deletions(-) delete mode 100644 prudp_virtual_connection_manager.go create mode 100644 prudp_virtual_stream_types.go diff --git a/prudp_client.go b/prudp_client.go index 3f986218..c9036f46 100644 --- a/prudp_client.go +++ b/prudp_client.go @@ -2,6 +2,7 @@ package nex import ( "crypto/md5" + "fmt" "net" "time" ) @@ -185,8 +186,13 @@ func (c *PRUDPClient) startHeartbeat() { // * client is dead and clean up c.pingKickTimer = time.AfterFunc(server.pingTimeout, func() { c.cleanup() // * "removed" event is dispatched here - virtualStream := c.server.virtualConnectionManager.Get(c.DestinationPort, c.DestinationStreamType) - virtualStream.clients.Delete(c.address.String()) + + virtualServer, _ := c.server.virtualServers.Get(c.DestinationPort) + virtualServerStream, _ := virtualServer.Get(c.DestinationStreamType) + + discriminator := fmt.Sprintf("%s-%d-%d", c.address.String(), c.SourcePort, c.SourceStreamType) + + virtualServerStream.Delete(discriminator) }) }) } diff --git a/prudp_server.go b/prudp_server.go index ff652bcb..69e056ec 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -14,34 +14,36 @@ import ( // PRUDPServer represents a bare-bones PRUDP server type PRUDPServer struct { - udpSocket *net.UDPConn - PRUDPVersion int - PRUDPMinorVersion uint32 - MaxPRUDPVirtualPorts uint8 - virtualConnectionManager *PRUDPVirtualConnectionManager - IsQuazalMode bool - IsSecureServer bool - SupportedFunctions uint32 - accessKey string - kerberosPassword []byte - kerberosTicketVersion int - kerberosKeySize int - FragmentSize int - version *LibraryVersion - datastoreProtocolVersion *LibraryVersion - matchMakingProtocolVersion *LibraryVersion - rankingProtocolVersion *LibraryVersion - ranking2ProtocolVersion *LibraryVersion - messagingProtocolVersion *LibraryVersion - utilityProtocolVersion *LibraryVersion - natTraversalProtocolVersion *LibraryVersion - prudpEventHandlers map[string][]func(packet PacketInterface) - clientRemovedEventHandlers []func(client *PRUDPClient) - connectionIDCounter *Counter[uint32] - pingTimeout time.Duration - PasswordFromPID func(pid *PID) (string, uint32) - PRUDPv1ConnectionSignatureKey []byte - CompressionEnabled bool + udpSocket *net.UDPConn + PRUDPVersion int + PRUDPMinorVersion uint32 + MaxPRUDPVirtualPorts uint8 + virtualServers *MutexMap[uint8, *MutexMap[uint8, *MutexMap[string, *PRUDPClient]]] + IsQuazalMode bool + IsSecureServer bool + AuthenticationVirtualServerPort uint8 + SecureVirtualServerPort uint8 + SupportedFunctions uint32 + accessKey string + kerberosPassword []byte + kerberosTicketVersion int + kerberosKeySize int + FragmentSize int + version *LibraryVersion + datastoreProtocolVersion *LibraryVersion + matchMakingProtocolVersion *LibraryVersion + rankingProtocolVersion *LibraryVersion + ranking2ProtocolVersion *LibraryVersion + messagingProtocolVersion *LibraryVersion + utilityProtocolVersion *LibraryVersion + natTraversalProtocolVersion *LibraryVersion + prudpEventHandlers map[string][]func(packet PacketInterface) + clientRemovedEventHandlers []func(client *PRUDPClient) + connectionIDCounter *Counter[uint32] + pingTimeout time.Duration + PasswordFromPID func(pid *PID) (string, uint32) + PRUDPv1ConnectionSignatureKey []byte + CompressionEnabled bool } // OnData adds an event handler which is fired when a new DATA packet is received @@ -109,7 +111,23 @@ func (s *PRUDPServer) Listen(port int) { s.udpSocket = socket - s.virtualConnectionManager = NewPRUDPVirtualConnectionManager(s.MaxPRUDPVirtualPorts) + for i := 0; i < int(s.MaxPRUDPVirtualPorts); i++ { + virtualServer := NewMutexMap[uint8, *MutexMap[string, *PRUDPClient]]() + virtualServer.Set(VirtualStreamTypeDO, NewMutexMap[string, *PRUDPClient]()) + virtualServer.Set(VirtualStreamTypeRV, NewMutexMap[string, *PRUDPClient]()) + virtualServer.Set(VirtualStreamTypeOldRVSec, NewMutexMap[string, *PRUDPClient]()) + virtualServer.Set(VirtualStreamTypeSBMGMT, NewMutexMap[string, *PRUDPClient]()) + virtualServer.Set(VirtualStreamTypeNAT, NewMutexMap[string, *PRUDPClient]()) + virtualServer.Set(VirtualStreamTypeSessionDiscovery, NewMutexMap[string, *PRUDPClient]()) + virtualServer.Set(VirtualStreamTypeNATEcho, NewMutexMap[string, *PRUDPClient]()) + virtualServer.Set(VirtualStreamTypeRouting, NewMutexMap[string, *PRUDPClient]()) + virtualServer.Set(VirtualStreamTypeGame, NewMutexMap[string, *PRUDPClient]()) + virtualServer.Set(VirtualStreamTypeRVSecure, NewMutexMap[string, *PRUDPClient]()) + virtualServer.Set(VirtualStreamTypeRelay, NewMutexMap[string, *PRUDPClient]()) + + s.virtualServers.Set(uint8(i), virtualServer) + } + logger.Success("Virtual ports created") quit := make(chan struct{}) @@ -164,11 +182,12 @@ func (s *PRUDPServer) handleSocketMessage() error { } func (s *PRUDPServer) processPacket(packet PRUDPPacketInterface, address *net.UDPAddr) { - virtualStream := s.virtualConnectionManager.Get(packet.SourcePort(), packet.SourceStreamType()) + virtualServer, _ := s.virtualServers.Get(packet.DestinationPort()) + virtualServerStream, _ := virtualServer.Get(packet.DestinationStreamType()) - clientSocketDiscriminator := address.String() + discriminator := fmt.Sprintf("%s-%d-%d", address.String(), packet.SourcePort(), packet.SourceStreamType()) - client, ok := virtualStream.clients.Get(clientSocketDiscriminator) + client, ok := virtualServerStream.Get(discriminator) if !ok { client = NewPRUDPClient(address, s) @@ -195,7 +214,7 @@ func (s *PRUDPServer) processPacket(packet PRUDPPacketInterface, address *net.UD // * EXPECTED TO NATURALLY DIE HERE client.createReliableSubstreams(0) - virtualStream.clients.Set(clientSocketDiscriminator, client) + virtualServerStream.Set(discriminator, client) } packet.SetSender(client) @@ -373,7 +392,7 @@ func (s *PRUDPServer) handleConnect(packet PRUDPPacketInterface) { var payload []byte - if s.IsSecureServer { + if s.isSecurePort(packet.DestinationPort()) { sessionKey, pid, checkValue, err := s.readKerberosTicket(packet.Payload()) if err != nil { logger.Error(err.Error()) @@ -416,11 +435,14 @@ func (s *PRUDPServer) handleDisconnect(packet PRUDPPacketInterface) { s.acknowledgePacket(packet) } - virtualStream := s.virtualConnectionManager.Get(packet.SourcePort(), packet.SourceStreamType()) + virtualServer, _ := s.virtualServers.Get(packet.DestinationPort()) + virtualServerStream, _ := virtualServer.Get(packet.DestinationStreamType()) + client := packet.Sender().(*PRUDPClient) + discriminator := fmt.Sprintf("%s-%d-%d", client.address.String(), packet.SourcePort(), packet.SourceStreamType()) client.cleanup() // * "removed" event is dispatched here - virtualStream.clients.Delete(client.address.String()) + virtualServerStream.Delete(discriminator) s.emit("disconnect", packet) } @@ -799,6 +821,32 @@ func (s *PRUDPServer) compressPayload(payload []byte) ([]byte, error) { return stream.Bytes(), nil } +func (s *PRUDPServer) isSecurePort(port uint8) bool { + // * We have to support cases where 2 physical servers exist + // * and cases where one physical server exists with multiple + // * virtual ports + + // * If marked as true, we can assume that 2 physical servers exist + // * and that this is always the secure server + if s.IsSecureServer { + return true + } + + // * If not marked true, we have to check if multiple virtual + // * ports are set. Any number of virtual ports can be defined, + // * all with different ports, so long as the "secure" port is + // * not the same as the "authentication" port + authPort := s.AuthenticationVirtualServerPort + securePort := s.SecureVirtualServerPort + if authPort != securePort && securePort == port { + return true + } + + // TODO - Are there cases where both RVSecure and OldRVSec are used on the same server, with different ports? + + return false // * Assume not the secure port +} + // AccessKey returns the servers sandbox access key func (s *PRUDPServer) AccessKey() string { return s.accessKey @@ -943,11 +991,13 @@ func (s *PRUDPServer) ConnectionIDCounter() *Counter[uint32] { } // FindClientByConnectionID returns the PRUDP client connected with the given connection ID -func (s *PRUDPServer) FindClientByConnectionID(port, streamType uint8, connectedID uint32) *PRUDPClient { +func (s *PRUDPServer) FindClientByConnectionID(serverPort, serverStreamType uint8, connectedID uint32) *PRUDPClient { var client *PRUDPClient - virtualStream := s.virtualConnectionManager.Get(port, streamType) - virtualStream.clients.Each(func(discriminator string, c *PRUDPClient) bool { + virtualServer, _ := s.virtualServers.Get(serverPort) + virtualServerStream, _ := virtualServer.Get(serverStreamType) + + virtualServerStream.Each(func(discriminator string, c *PRUDPClient) bool { if c.ConnectionID == connectedID { client = c return true @@ -960,11 +1010,13 @@ func (s *PRUDPServer) FindClientByConnectionID(port, streamType uint8, connected } // FindClientByPID returns the PRUDP client connected with the given PID -func (s *PRUDPServer) FindClientByPID(port, streamType uint8, pid uint64) *PRUDPClient { +func (s *PRUDPServer) FindClientByPID(serverPort, serverStreamType uint8, pid uint64) *PRUDPClient { var client *PRUDPClient - virtualStream := s.virtualConnectionManager.Get(port, streamType) - virtualStream.clients.Each(func(discriminator string, c *PRUDPClient) bool { + virtualServer, _ := s.virtualServers.Get(serverPort) + virtualServerStream, _ := virtualServer.Get(serverStreamType) + + virtualServerStream.Each(func(discriminator string, c *PRUDPClient) bool { if c.pid.pid == pid { client = c return true @@ -979,12 +1031,15 @@ func (s *PRUDPServer) FindClientByPID(port, streamType uint8, pid uint64) *PRUDP // NewPRUDPServer will return a new PRUDP server func NewPRUDPServer() *PRUDPServer { return &PRUDPServer{ - MaxPRUDPVirtualPorts: 16, // * UDP PRUDP servers use 16 virtual ports - IsQuazalMode: false, - kerberosKeySize: 32, - FragmentSize: 1300, - prudpEventHandlers: make(map[string][]func(PacketInterface)), - connectionIDCounter: NewCounter[uint32](10), - pingTimeout: time.Second * 15, + MaxPRUDPVirtualPorts: 16, // * UDP PRUDP servers use 16 virtual ports + AuthenticationVirtualServerPort: 1, // * Server ports default to 1 + SecureVirtualServerPort: 1, // * Server ports default to 1 + virtualServers: NewMutexMap[uint8, *MutexMap[uint8, *MutexMap[string, *PRUDPClient]]](), + IsQuazalMode: false, + kerberosKeySize: 32, + FragmentSize: 1300, + prudpEventHandlers: make(map[string][]func(PacketInterface)), + connectionIDCounter: NewCounter[uint32](10), + pingTimeout: time.Second * 15, } } diff --git a/prudp_virtual_connection_manager.go b/prudp_virtual_connection_manager.go deleted file mode 100644 index 108585aa..00000000 --- a/prudp_virtual_connection_manager.go +++ /dev/null @@ -1,126 +0,0 @@ -package nex - -const ( - // VirtualStreamTypeDO represents the DO PRUDP virtual connection stream type - VirtualStreamTypeDO uint8 = 1 - - // VirtualStreamTypeRV represents the RV PRUDP virtual connection stream type - VirtualStreamTypeRV uint8 = 2 - - // VirtualStreamTypeOldRVSec represents the OldRVSec PRUDP virtual connection stream type - VirtualStreamTypeOldRVSec uint8 = 3 - - // VirtualStreamTypeSBMGMT represents the SBMGMT PRUDP virtual connection stream type - VirtualStreamTypeSBMGMT uint8 = 4 - - // VirtualStreamTypeNAT represents the NAT PRUDP virtual connection stream type - VirtualStreamTypeNAT uint8 = 5 - - // VirtualStreamTypeSessionDiscovery represents the SessionDiscovery PRUDP virtual connection stream type - VirtualStreamTypeSessionDiscovery uint8 = 6 - - // VirtualStreamTypeNATEcho represents the NATEcho PRUDP virtual connection stream type - VirtualStreamTypeNATEcho uint8 = 7 - - // VirtualStreamTypeRouting represents the Routing PRUDP virtual connection stream type - VirtualStreamTypeRouting uint8 = 8 - - // VirtualStreamTypeGame represents the Game PRUDP virtual connection stream type - VirtualStreamTypeGame uint8 = 9 - - // VirtualStreamTypeRVSecure represents the RVSecure PRUDP virtual connection stream type - VirtualStreamTypeRVSecure uint8 = 10 - - // VirtualStreamTypeRelay represents the Relay PRUDP virtual connection stream type - VirtualStreamTypeRelay uint8 = 11 -) - -// PRUDPVirtualStream represents a PRUDP virtual stream -type PRUDPVirtualStream struct { - clients *MutexMap[string, *PRUDPClient] -} - -// PRUDPVirtualPort represents a PRUDP virtual connections virtual port -type PRUDPVirtualPort struct { - streams *MutexMap[uint8, *PRUDPVirtualStream] -} - -func (vp *PRUDPVirtualPort) init() { - vp.initStream(VirtualStreamTypeDO) - vp.initStream(VirtualStreamTypeRV) - vp.initStream(VirtualStreamTypeOldRVSec) - vp.initStream(VirtualStreamTypeSBMGMT) - vp.initStream(VirtualStreamTypeNAT) - vp.initStream(VirtualStreamTypeSessionDiscovery) - vp.initStream(VirtualStreamTypeNATEcho) - vp.initStream(VirtualStreamTypeRouting) - vp.initStream(VirtualStreamTypeGame) - vp.initStream(VirtualStreamTypeRVSecure) - vp.initStream(VirtualStreamTypeRelay) -} - -func (vp *PRUDPVirtualPort) initStream(streamType uint8) *PRUDPVirtualStream { - virtualStream := &PRUDPVirtualStream{ - clients: NewMutexMap[string, *PRUDPClient](), - } - - vp.streams.Set(streamType, virtualStream) - - return virtualStream -} - -// PRUDPVirtualConnectionManager manages virtual ports used by PRUDP connections -// -// PRUDP uses a single UDP connection to establish multiple "connections" through virtual ports -type PRUDPVirtualConnectionManager struct { - ports *MutexMap[uint8, *PRUDPVirtualPort] -} - -func (vcm *PRUDPVirtualConnectionManager) init(numberOfPorts uint8) { - for i := 0; i < int(numberOfPorts); i++ { - vcm.createVirtualPort(uint8(i)) - } -} - -func (vcm *PRUDPVirtualConnectionManager) createVirtualPort(port uint8) *PRUDPVirtualPort { - virtualPort := &PRUDPVirtualPort{ - streams: NewMutexMap[uint8, *PRUDPVirtualStream](), - } - - virtualPort.init() - - vcm.ports.Set(port, virtualPort) - - return virtualPort -} - -// Get returns PRUDPVirtualStream for the given port and stream type. -// If either the virtual port or stream type do not exist, new ones are created -func (vcm *PRUDPVirtualConnectionManager) Get(port, streamType uint8) *PRUDPVirtualStream { - virtualPort, ok := vcm.ports.Get(port) - if !ok { - // * Just force the port to exist - virtualPort = vcm.createVirtualPort(port) - logger.Warningf("Invalid virtual port %d trying to be accessed. Creating new one to prevent crash", port) - } - - virtualStream, ok := virtualPort.streams.Get(streamType) - if !ok { - // * Just force the stream to exist - virtualStream = virtualPort.initStream(streamType) - logger.Warningf("Invalid virtual stream type %d trying to be accessed. Creating new one to prevent crash", streamType) - } - - return virtualStream -} - -// NewPRUDPVirtualConnectionManager creates a new PRUDPVirtualConnectionManager with the given number of virtual ports -func NewPRUDPVirtualConnectionManager(numberOfPorts uint8) *PRUDPVirtualConnectionManager { - virtualConnectionManager := &PRUDPVirtualConnectionManager{ - ports: NewMutexMap[uint8, *PRUDPVirtualPort](), - } - - virtualConnectionManager.init(numberOfPorts) - - return virtualConnectionManager -} diff --git a/prudp_virtual_stream_types.go b/prudp_virtual_stream_types.go new file mode 100644 index 00000000..b639f0c0 --- /dev/null +++ b/prudp_virtual_stream_types.go @@ -0,0 +1,36 @@ +package nex + +const ( + // VirtualStreamTypeDO represents the DO PRUDP virtual connection stream type + VirtualStreamTypeDO uint8 = 1 + + // VirtualStreamTypeRV represents the RV PRUDP virtual connection stream type + VirtualStreamTypeRV uint8 = 2 + + // VirtualStreamTypeOldRVSec represents the OldRVSec PRUDP virtual connection stream type + VirtualStreamTypeOldRVSec uint8 = 3 + + // VirtualStreamTypeSBMGMT represents the SBMGMT PRUDP virtual connection stream type + VirtualStreamTypeSBMGMT uint8 = 4 + + // VirtualStreamTypeNAT represents the NAT PRUDP virtual connection stream type + VirtualStreamTypeNAT uint8 = 5 + + // VirtualStreamTypeSessionDiscovery represents the SessionDiscovery PRUDP virtual connection stream type + VirtualStreamTypeSessionDiscovery uint8 = 6 + + // VirtualStreamTypeNATEcho represents the NATEcho PRUDP virtual connection stream type + VirtualStreamTypeNATEcho uint8 = 7 + + // VirtualStreamTypeRouting represents the Routing PRUDP virtual connection stream type + VirtualStreamTypeRouting uint8 = 8 + + // VirtualStreamTypeGame represents the Game PRUDP virtual connection stream type + VirtualStreamTypeGame uint8 = 9 + + // VirtualStreamTypeRVSecure represents the RVSecure PRUDP virtual connection stream type + VirtualStreamTypeRVSecure uint8 = 10 + + // VirtualStreamTypeRelay represents the Relay PRUDP virtual connection stream type + VirtualStreamTypeRelay uint8 = 11 +) diff --git a/resend_scheduler.go b/resend_scheduler.go index 54c46fde..cad8e1a6 100644 --- a/resend_scheduler.go +++ b/resend_scheduler.go @@ -1,6 +1,7 @@ package nex import ( + "fmt" "time" ) @@ -94,8 +95,12 @@ func (rs *ResendScheduler) resendPacket(pendingPacket *PendingPacket) { rs.packets.Delete(packet.SequenceID()) client.cleanup() // * "removed" event is dispatched here - virtualStream := client.server.virtualConnectionManager.Get(client.DestinationPort, client.DestinationStreamType) - virtualStream.clients.Delete(client.address.String()) + virtualServer, _ := client.server.virtualServers.Get(client.DestinationPort) + virtualServerStream, _ := virtualServer.Get(client.DestinationStreamType) + + discriminator := fmt.Sprintf("%s-%d-%d", client.address.String(), client.SourcePort, client.SourceStreamType) + + virtualServerStream.Delete(discriminator) return } From 0ac073d2c8e00985aead2a8e7bc0f0163d7e6220 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Fri, 8 Dec 2023 16:25:54 -0500 Subject: [PATCH 058/178] prudp: update default number of server vports --- prudp_server.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/prudp_server.go b/prudp_server.go index 69e056ec..9635ba06 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -17,7 +17,7 @@ type PRUDPServer struct { udpSocket *net.UDPConn PRUDPVersion int PRUDPMinorVersion uint32 - MaxPRUDPVirtualPorts uint8 + MaxVirtualServerPorts uint8 virtualServers *MutexMap[uint8, *MutexMap[uint8, *MutexMap[string, *PRUDPClient]]] IsQuazalMode bool IsSecureServer bool @@ -111,7 +111,7 @@ func (s *PRUDPServer) Listen(port int) { s.udpSocket = socket - for i := 0; i < int(s.MaxPRUDPVirtualPorts); i++ { + for i := 0; i < int(s.MaxVirtualServerPorts); i++ { virtualServer := NewMutexMap[uint8, *MutexMap[string, *PRUDPClient]]() virtualServer.Set(VirtualStreamTypeDO, NewMutexMap[string, *PRUDPClient]()) virtualServer.Set(VirtualStreamTypeRV, NewMutexMap[string, *PRUDPClient]()) @@ -125,7 +125,7 @@ func (s *PRUDPServer) Listen(port int) { virtualServer.Set(VirtualStreamTypeRVSecure, NewMutexMap[string, *PRUDPClient]()) virtualServer.Set(VirtualStreamTypeRelay, NewMutexMap[string, *PRUDPClient]()) - s.virtualServers.Set(uint8(i), virtualServer) + s.virtualServers.Set(uint8(i+1), virtualServer) // * Don't allow 0 as a vport } logger.Success("Virtual ports created") @@ -1031,9 +1031,9 @@ func (s *PRUDPServer) FindClientByPID(serverPort, serverStreamType uint8, pid ui // NewPRUDPServer will return a new PRUDP server func NewPRUDPServer() *PRUDPServer { return &PRUDPServer{ - MaxPRUDPVirtualPorts: 16, // * UDP PRUDP servers use 16 virtual ports - AuthenticationVirtualServerPort: 1, // * Server ports default to 1 - SecureVirtualServerPort: 1, // * Server ports default to 1 + MaxVirtualServerPorts: 1, // * Assume only 1 port per server by default (NEX 3 and below style) + AuthenticationVirtualServerPort: 1, // * Server ports default to 1 + SecureVirtualServerPort: 1, // * Server ports default to 1 virtualServers: NewMutexMap[uint8, *MutexMap[uint8, *MutexMap[string, *PRUDPClient]]](), IsQuazalMode: false, kerberosKeySize: 32, From a7564f55555112061fe4a27f909d67c379dd67e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20L=C3=B3pez=20Guimaraes?= Date: Sat, 9 Dec 2023 21:03:31 +0000 Subject: [PATCH 059/178] Add PasswordFromPID to ServerInterface This function handler is common to all servers. --- hpp_packet.go | 2 +- hpp_server.go | 12 +++++++++++- prudp_server.go | 12 +++++++++++- server_interface.go | 2 ++ test/hpp.go | 2 +- 5 files changed, 26 insertions(+), 4 deletions(-) diff --git a/hpp_packet.go b/hpp_packet.go index 94caf5ab..defc9af6 100644 --- a/hpp_packet.go +++ b/hpp_packet.go @@ -91,7 +91,7 @@ func (p *HPPPacket) validatePasswordSignature(signature string) error { } func (p *HPPPacket) calculatePasswordSignature() ([]byte, error) { - passwordFromPID := p.Sender().Server().(*HPPServer).PasswordFromPID + passwordFromPID := p.Sender().Server().PasswordFromPIDFunction() if passwordFromPID == nil { return nil, errors.New("Missing PasswordFromPID") } diff --git a/hpp_server.go b/hpp_server.go index 3591fd7e..fdf5bac8 100644 --- a/hpp_server.go +++ b/hpp_server.go @@ -19,7 +19,7 @@ type HPPServer struct { utilityProtocolVersion *LibraryVersion natTraversalProtocolVersion *LibraryVersion dataHandlers []func(packet PacketInterface) - PasswordFromPID func(pid *PID) (string, uint32) + passwordFromPIDHandler func(pid *PID) (string, uint32) } // OnData adds an event handler which is fired when a new HPP request is received @@ -243,6 +243,16 @@ func (s *HPPServer) NATTraversalProtocolVersion() *LibraryVersion { return s.natTraversalProtocolVersion } +// PasswordFromPIDFunction returns the function for HPP to get a NEX password using the PID +func (s *HPPServer) PasswordFromPIDFunction() func(pid *PID) (string, uint32) { + return s.passwordFromPIDHandler +} + +// SetPasswordFromPIDFunction sets the function for HPP to get a NEX password using the PID +func (s *HPPServer) SetPasswordFromPIDFunction(handler func(pid *PID) (string, uint32)) { + s.passwordFromPIDHandler = handler +} + // NewHPPServer returns a new HPP server func NewHPPServer() *HPPServer { return &HPPServer{ diff --git a/prudp_server.go b/prudp_server.go index 9635ba06..cbb2f090 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -41,7 +41,7 @@ type PRUDPServer struct { clientRemovedEventHandlers []func(client *PRUDPClient) connectionIDCounter *Counter[uint32] pingTimeout time.Duration - PasswordFromPID func(pid *PID) (string, uint32) + passwordFromPIDHandler func(pid *PID) (string, uint32) PRUDPv1ConnectionSignatureKey []byte CompressionEnabled bool } @@ -1028,6 +1028,16 @@ func (s *PRUDPServer) FindClientByPID(serverPort, serverStreamType uint8, pid ui return client } +// PasswordFromPIDFunction returns the function for the auth server to get a NEX password using the PID +func (s *PRUDPServer) PasswordFromPIDFunction() func(pid *PID) (string, uint32) { + return s.passwordFromPIDHandler +} + +// SetPasswordFromPIDFunction sets the function for the auth server to get a NEX password using the PID +func (s *PRUDPServer) SetPasswordFromPIDFunction(handler func(pid *PID) (string, uint32)) { + s.passwordFromPIDHandler = handler +} + // NewPRUDPServer will return a new PRUDP server func NewPRUDPServer() *PRUDPServer { return &PRUDPServer{ diff --git a/server_interface.go b/server_interface.go index bccb2bc7..58d095d7 100644 --- a/server_interface.go +++ b/server_interface.go @@ -15,4 +15,6 @@ type ServerInterface interface { SetDefaultLibraryVersion(version *LibraryVersion) Send(packet PacketInterface) OnData(handler func(packet PacketInterface)) + PasswordFromPIDFunction() func(pid *PID) (string, uint32) + SetPasswordFromPIDFunction(handler func(pid *PID) (string, uint32)) } diff --git a/test/hpp.go b/test/hpp.go index 4bcd4146..36e3cee4 100644 --- a/test/hpp.go +++ b/test/hpp.go @@ -68,7 +68,7 @@ func startHPPServer() { hppServer.SetDefaultLibraryVersion(nex.NewLibraryVersion(2, 4, 1)) hppServer.SetAccessKey("76f26496") - hppServer.PasswordFromPID = passwordFromPID + hppServer.SetPasswordFromPIDFunction(passwordFromPID) hppServer.Listen(8085) } From 3ee001f8786d96d1b1991347526b414731e66481 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sat, 9 Dec 2023 19:35:46 -0500 Subject: [PATCH 060/178] prudp: rename PasswordFromPIDFunction to PasswordFromPID --- hpp_packet.go | 7 +------ hpp_server.go | 6 +++--- prudp_server.go | 11 ++++++++--- server_interface.go | 2 +- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/hpp_packet.go b/hpp_packet.go index defc9af6..bf84a6c2 100644 --- a/hpp_packet.go +++ b/hpp_packet.go @@ -91,13 +91,8 @@ func (p *HPPPacket) validatePasswordSignature(signature string) error { } func (p *HPPPacket) calculatePasswordSignature() ([]byte, error) { - passwordFromPID := p.Sender().Server().PasswordFromPIDFunction() - if passwordFromPID == nil { - return nil, errors.New("Missing PasswordFromPID") - } - pid := p.Sender().PID() - password, _ := passwordFromPID(pid) + password, _ := p.Sender().Server().PasswordFromPID(pid) if password == "" { return nil, errors.New("PID does not exist") } diff --git a/hpp_server.go b/hpp_server.go index fdf5bac8..590facfb 100644 --- a/hpp_server.go +++ b/hpp_server.go @@ -243,9 +243,9 @@ func (s *HPPServer) NATTraversalProtocolVersion() *LibraryVersion { return s.natTraversalProtocolVersion } -// PasswordFromPIDFunction returns the function for HPP to get a NEX password using the PID -func (s *HPPServer) PasswordFromPIDFunction() func(pid *PID) (string, uint32) { - return s.passwordFromPIDHandler +// PasswordFromPID calls the function set with SetPasswordFromPIDFunction and returns the result +func (s *HPPServer) PasswordFromPID(pid *PID) (string, uint32) { + return s.passwordFromPIDHandler(pid) } // SetPasswordFromPIDFunction sets the function for HPP to get a NEX password using the PID diff --git a/prudp_server.go b/prudp_server.go index cbb2f090..f8ae681a 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -1028,9 +1028,14 @@ func (s *PRUDPServer) FindClientByPID(serverPort, serverStreamType uint8, pid ui return client } -// PasswordFromPIDFunction returns the function for the auth server to get a NEX password using the PID -func (s *PRUDPServer) PasswordFromPIDFunction() func(pid *PID) (string, uint32) { - return s.passwordFromPIDHandler +// PasswordFromPID calls the function set with SetPasswordFromPIDFunction and returns the result +func (s *PRUDPServer) PasswordFromPID(pid *PID) (string, uint32) { + if s.passwordFromPIDHandler == nil { + logger.Errorf("Missing PasswordFromPID handler. Set with SetPasswordFromPIDFunction") + return "", Errors.Core.InvalidHandle + } + + return s.passwordFromPIDHandler(pid) } // SetPasswordFromPIDFunction sets the function for the auth server to get a NEX password using the PID diff --git a/server_interface.go b/server_interface.go index 58d095d7..03a86a7e 100644 --- a/server_interface.go +++ b/server_interface.go @@ -15,6 +15,6 @@ type ServerInterface interface { SetDefaultLibraryVersion(version *LibraryVersion) Send(packet PacketInterface) OnData(handler func(packet PacketInterface)) - PasswordFromPIDFunction() func(pid *PID) (string, uint32) + PasswordFromPID(pid *PID) (string, uint32) SetPasswordFromPIDFunction(handler func(pid *PID) (string, uint32)) } From 35c2afce1a20bcbf4554b9252806345fe781cfdb Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sat, 9 Dec 2023 19:46:05 -0500 Subject: [PATCH 061/178] prudp: handle edge cases for virtual ports --- prudp_server.go | 109 ++++++++++++++++++------------------------------ test/secure.go | 2 +- 2 files changed, 41 insertions(+), 70 deletions(-) diff --git a/prudp_server.go b/prudp_server.go index f8ae681a..f8772dee 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -14,36 +14,34 @@ import ( // PRUDPServer represents a bare-bones PRUDP server type PRUDPServer struct { - udpSocket *net.UDPConn - PRUDPVersion int - PRUDPMinorVersion uint32 - MaxVirtualServerPorts uint8 - virtualServers *MutexMap[uint8, *MutexMap[uint8, *MutexMap[string, *PRUDPClient]]] - IsQuazalMode bool - IsSecureServer bool - AuthenticationVirtualServerPort uint8 - SecureVirtualServerPort uint8 - SupportedFunctions uint32 - accessKey string - kerberosPassword []byte - kerberosTicketVersion int - kerberosKeySize int - FragmentSize int - version *LibraryVersion - datastoreProtocolVersion *LibraryVersion - matchMakingProtocolVersion *LibraryVersion - rankingProtocolVersion *LibraryVersion - ranking2ProtocolVersion *LibraryVersion - messagingProtocolVersion *LibraryVersion - utilityProtocolVersion *LibraryVersion - natTraversalProtocolVersion *LibraryVersion - prudpEventHandlers map[string][]func(packet PacketInterface) - clientRemovedEventHandlers []func(client *PRUDPClient) - connectionIDCounter *Counter[uint32] - pingTimeout time.Duration - passwordFromPIDHandler func(pid *PID) (string, uint32) - PRUDPv1ConnectionSignatureKey []byte - CompressionEnabled bool + udpSocket *net.UDPConn + PRUDPVersion int + PRUDPMinorVersion uint32 + virtualServers *MutexMap[uint8, *MutexMap[uint8, *MutexMap[string, *PRUDPClient]]] + IsQuazalMode bool + VirtualServerPorts []uint8 + SecureVirtualServerPorts []uint8 + SupportedFunctions uint32 + accessKey string + kerberosPassword []byte + kerberosTicketVersion int + kerberosKeySize int + FragmentSize int + version *LibraryVersion + datastoreProtocolVersion *LibraryVersion + matchMakingProtocolVersion *LibraryVersion + rankingProtocolVersion *LibraryVersion + ranking2ProtocolVersion *LibraryVersion + messagingProtocolVersion *LibraryVersion + utilityProtocolVersion *LibraryVersion + natTraversalProtocolVersion *LibraryVersion + prudpEventHandlers map[string][]func(packet PacketInterface) + clientRemovedEventHandlers []func(client *PRUDPClient) + connectionIDCounter *Counter[uint32] + pingTimeout time.Duration + passwordFromPIDHandler func(pid *PID) (string, uint32) + PRUDPv1ConnectionSignatureKey []byte + CompressionEnabled bool } // OnData adds an event handler which is fired when a new DATA packet is received @@ -111,7 +109,7 @@ func (s *PRUDPServer) Listen(port int) { s.udpSocket = socket - for i := 0; i < int(s.MaxVirtualServerPorts); i++ { + for _, port := range s.VirtualServerPorts { virtualServer := NewMutexMap[uint8, *MutexMap[string, *PRUDPClient]]() virtualServer.Set(VirtualStreamTypeDO, NewMutexMap[string, *PRUDPClient]()) virtualServer.Set(VirtualStreamTypeRV, NewMutexMap[string, *PRUDPClient]()) @@ -125,7 +123,7 @@ func (s *PRUDPServer) Listen(port int) { virtualServer.Set(VirtualStreamTypeRVSecure, NewMutexMap[string, *PRUDPClient]()) virtualServer.Set(VirtualStreamTypeRelay, NewMutexMap[string, *PRUDPClient]()) - s.virtualServers.Set(uint8(i+1), virtualServer) // * Don't allow 0 as a vport + s.virtualServers.Set(port, virtualServer) } logger.Success("Virtual ports created") @@ -392,7 +390,7 @@ func (s *PRUDPServer) handleConnect(packet PRUDPPacketInterface) { var payload []byte - if s.isSecurePort(packet.DestinationPort()) { + if slices.Contains(s.SecureVirtualServerPorts, packet.DestinationPort()) { sessionKey, pid, checkValue, err := s.readKerberosTicket(packet.Payload()) if err != nil { logger.Error(err.Error()) @@ -821,32 +819,6 @@ func (s *PRUDPServer) compressPayload(payload []byte) ([]byte, error) { return stream.Bytes(), nil } -func (s *PRUDPServer) isSecurePort(port uint8) bool { - // * We have to support cases where 2 physical servers exist - // * and cases where one physical server exists with multiple - // * virtual ports - - // * If marked as true, we can assume that 2 physical servers exist - // * and that this is always the secure server - if s.IsSecureServer { - return true - } - - // * If not marked true, we have to check if multiple virtual - // * ports are set. Any number of virtual ports can be defined, - // * all with different ports, so long as the "secure" port is - // * not the same as the "authentication" port - authPort := s.AuthenticationVirtualServerPort - securePort := s.SecureVirtualServerPort - if authPort != securePort && securePort == port { - return true - } - - // TODO - Are there cases where both RVSecure and OldRVSec are used on the same server, with different ports? - - return false // * Assume not the secure port -} - // AccessKey returns the servers sandbox access key func (s *PRUDPServer) AccessKey() string { return s.accessKey @@ -1046,15 +1018,14 @@ func (s *PRUDPServer) SetPasswordFromPIDFunction(handler func(pid *PID) (string, // NewPRUDPServer will return a new PRUDP server func NewPRUDPServer() *PRUDPServer { return &PRUDPServer{ - MaxVirtualServerPorts: 1, // * Assume only 1 port per server by default (NEX 3 and below style) - AuthenticationVirtualServerPort: 1, // * Server ports default to 1 - SecureVirtualServerPort: 1, // * Server ports default to 1 - virtualServers: NewMutexMap[uint8, *MutexMap[uint8, *MutexMap[string, *PRUDPClient]]](), - IsQuazalMode: false, - kerberosKeySize: 32, - FragmentSize: 1300, - prudpEventHandlers: make(map[string][]func(PacketInterface)), - connectionIDCounter: NewCounter[uint32](10), - pingTimeout: time.Second * 15, + VirtualServerPorts: []uint8{1}, + SecureVirtualServerPorts: make([]uint8, 0), + virtualServers: NewMutexMap[uint8, *MutexMap[uint8, *MutexMap[string, *PRUDPClient]]](), + IsQuazalMode: false, + kerberosKeySize: 32, + FragmentSize: 1300, + prudpEventHandlers: make(map[string][]func(PacketInterface)), + connectionIDCounter: NewCounter[uint32](10), + pingTimeout: time.Second * 15, } } diff --git a/test/secure.go b/test/secure.go index 45796500..c1907841 100644 --- a/test/secure.go +++ b/test/secure.go @@ -75,7 +75,7 @@ func startSecureServer() { } }) - secureServer.IsSecureServer = true + secureServer.SecureVirtualServerPorts = []uint8{1} //secureServer.PRUDPVersion = 1 secureServer.SetFragmentSize(962) secureServer.SetDefaultLibraryVersion(nex.NewLibraryVersion(1, 1, 0)) From 240c3161e5cc0ff92fad6a1e2090681606818e6b Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sat, 9 Dec 2023 19:48:25 -0500 Subject: [PATCH 062/178] hpp: add check for passwordFromPIDHandler --- hpp_server.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/hpp_server.go b/hpp_server.go index 590facfb..56ceff41 100644 --- a/hpp_server.go +++ b/hpp_server.go @@ -245,6 +245,11 @@ func (s *HPPServer) NATTraversalProtocolVersion() *LibraryVersion { // PasswordFromPID calls the function set with SetPasswordFromPIDFunction and returns the result func (s *HPPServer) PasswordFromPID(pid *PID) (string, uint32) { + if s.passwordFromPIDHandler == nil { + logger.Errorf("Missing PasswordFromPID handler. Set with SetPasswordFromPIDFunction") + return "", Errors.Core.InvalidHandle + } + return s.passwordFromPIDHandler(pid) } From 3bb9c3fca24338902851929d3df5747af6c86a90 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20L=C3=B3pez=20Guimaraes?= Date: Sun, 10 Dec 2023 11:32:26 +0000 Subject: [PATCH 063/178] Update PasswordFromPID error to NotImplemented --- hpp_server.go | 2 +- prudp_server.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/hpp_server.go b/hpp_server.go index 56ceff41..eed0ecc0 100644 --- a/hpp_server.go +++ b/hpp_server.go @@ -247,7 +247,7 @@ func (s *HPPServer) NATTraversalProtocolVersion() *LibraryVersion { func (s *HPPServer) PasswordFromPID(pid *PID) (string, uint32) { if s.passwordFromPIDHandler == nil { logger.Errorf("Missing PasswordFromPID handler. Set with SetPasswordFromPIDFunction") - return "", Errors.Core.InvalidHandle + return "", Errors.Core.NotImplemented } return s.passwordFromPIDHandler(pid) diff --git a/prudp_server.go b/prudp_server.go index f8772dee..88dd9c2f 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -1004,7 +1004,7 @@ func (s *PRUDPServer) FindClientByPID(serverPort, serverStreamType uint8, pid ui func (s *PRUDPServer) PasswordFromPID(pid *PID) (string, uint32) { if s.passwordFromPIDHandler == nil { logger.Errorf("Missing PasswordFromPID handler. Set with SetPasswordFromPIDFunction") - return "", Errors.Core.InvalidHandle + return "", Errors.Core.NotImplemented } return s.passwordFromPIDHandler(pid) From 5c4004413c4d1daf24114f81fce3fa1666f95ff2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20L=C3=B3pez=20Guimaraes?= Date: Sun, 10 Dec 2023 14:09:40 +0000 Subject: [PATCH 064/178] qrv: Add EnhancedChecksum option Quazal Rendez-Vous supports both 1 byte and 4 byte checksums. Add a toggle to enable the 4 byte checksum. --- prudp_packet_v0.go | 12 ++++++------ prudp_server.go | 1 + 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/prudp_packet_v0.go b/prudp_packet_v0.go index df09c4a8..5acb9651 100644 --- a/prudp_packet_v0.go +++ b/prudp_packet_v0.go @@ -148,8 +148,8 @@ func (p *PRUDPPacketV0) decode() error { return fmt.Errorf("Failed to read PRUDPv0 payload size. %s", err.Error()) } } else { - // * Quazal used a 4 byte checksum. NEX uses 1 byte - if server.IsQuazalMode { + // * Some Quazal games use a 4 byte checksum. NEX uses 1 byte + if server.EnhancedChecksum { payloadSize = uint16(p.readStream.Remaining() - 4) } else { payloadSize = uint16(p.readStream.Remaining() - 1) @@ -162,7 +162,7 @@ func (p *PRUDPPacketV0) decode() error { p.payload = p.readStream.ReadBytesNext(int64(payloadSize)) - if server.IsQuazalMode && p.readStream.Remaining() < 4 { + if server.EnhancedChecksum && p.readStream.Remaining() < 4 { return errors.New("Failed to read PRUDPv0 checksum. Not have enough data") } else if p.readStream.Remaining() < 1 { return errors.New("Failed to read PRUDPv0 checksum. Not have enough data") @@ -173,7 +173,7 @@ func (p *PRUDPPacketV0) decode() error { var checksum uint32 var checksumU8 uint8 - if server.IsQuazalMode { + if server.EnhancedChecksum { checksum, err = p.readStream.ReadUInt32LE() } else { checksumU8, err = p.readStream.ReadUInt8() @@ -232,7 +232,7 @@ func (p *PRUDPPacketV0) Bytes() []byte { checksum := p.calculateChecksum(stream.Bytes()) - if server.IsQuazalMode { + if server.EnhancedChecksum { stream.WriteUInt32LE(checksum) } else { stream.WriteUInt8(uint8(checksum)) @@ -313,7 +313,7 @@ func (p *PRUDPPacketV0) calculateChecksum(data []byte) uint32 { server := p.server checksum := sum[byte, uint32]([]byte(server.AccessKey())) - if server.IsQuazalMode { + if server.EnhancedChecksum { padSize := (len(data) + 3) &^ 3 data = append(data, make([]byte, padSize-len(data))...) words := make([]uint32, len(data)/4) diff --git a/prudp_server.go b/prudp_server.go index 88dd9c2f..25e468c0 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -41,6 +41,7 @@ type PRUDPServer struct { pingTimeout time.Duration passwordFromPIDHandler func(pid *PID) (string, uint32) PRUDPv1ConnectionSignatureKey []byte + EnhancedChecksum bool CompressionEnabled bool } From 58e6343339dd7037b97e26a2c201edacdf72c58f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20L=C3=B3pez=20Guimaraes?= Date: Sun, 10 Dec 2023 14:20:54 +0000 Subject: [PATCH 065/178] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 32c4c42d..4c021b34 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ ### Usage note -This module provides a barebones PRUDP server for use with titles using the Nintendo NEX library. It provides some support for titles using the original Rendez-Vous library developed by Quazal. This library only provides the low level packet data, as such it is recommended to use [NEX Protocols Go](https://github.com/PretendoNetwork/nex-protocols-go) to develop servers. +This module provides a barebones PRUDP server for use with titles using the Nintendo NEX library. It also provides support for titles using the original Rendez-Vous library developed by Quazal. This library only provides the low level packet data, as such it is recommended to use [NEX Protocols Go](https://github.com/PretendoNetwork/nex-protocols-go) to develop servers. ### Usage From a210ae62c3ec3a1b41aa085ba4d14cdc0020d772 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20L=C3=B3pez=20Guimaraes?= Date: Sun, 10 Dec 2023 19:31:15 +0000 Subject: [PATCH 066/178] prudp: Use client session ID as server session ID --- prudp_client.go | 2 ++ prudp_server.go | 7 ++++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/prudp_client.go b/prudp_client.go index c9036f46..e348a1ff 100644 --- a/prudp_client.go +++ b/prudp_client.go @@ -14,6 +14,8 @@ type PRUDPClient struct { pid *PID clientConnectionSignature []byte serverConnectionSignature []byte + clientSessionID uint8 + serverSessionID uint8 sessionKey []byte reliableSubstreams []*ReliablePacketSubstreamManager outgoingUnreliableSequenceIDCounter *Counter[uint16] diff --git a/prudp_server.go b/prudp_server.go index 25e468c0..a2235233 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -355,12 +355,15 @@ func (s *PRUDPServer) handleConnect(packet PRUDPPacketInterface) { } client.serverConnectionSignature = packet.getConnectionSignature() + client.clientSessionID = packet.SessionID() connectionSignature, err := packet.calculateConnectionSignature(client.address) if err != nil { logger.Error(err.Error()) } + client.serverSessionID = packet.SessionID() + ack.SetType(ConnectPacket) ack.AddFlag(FlagAck) ack.AddFlag(FlagHasSize) @@ -369,7 +372,7 @@ func (s *PRUDPServer) handleConnect(packet PRUDPPacketInterface) { ack.SetDestinationStreamType(packet.SourceStreamType()) ack.SetDestinationPort(packet.SourcePort()) ack.setConnectionSignature(make([]byte, len(connectionSignature))) - ack.SetSessionID(0) + ack.SetSessionID(client.serverSessionID) ack.SetSequenceID(1) if ack, ok := ack.(*PRUDPPacketV1); ok { @@ -698,6 +701,8 @@ func (s *PRUDPServer) sendPacket(packet PRUDPPacketInterface) { } } + packetCopy.SetSessionID(client.serverSessionID) + if packetCopy.Type() == DataPacket && !packetCopy.HasFlag(FlagAck) && !packetCopy.HasFlag(FlagMultiAck) { if packetCopy.HasFlag(FlagReliable) { payload := packetCopy.Payload() From 1ccfd3ae423f7de662a3584da45edcd52186b4e8 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sun, 10 Dec 2023 19:31:55 -0500 Subject: [PATCH 067/178] update README --- README.md | 46 ++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 38 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 4c021b34..825dba34 100644 --- a/README.md +++ b/README.md @@ -1,22 +1,52 @@ # NEX Go -## Barebones PRUDP/NEX server library written in Go [![GoDoc](https://godoc.org/github.com/PretendoNetwork/nex-go?status.svg)](https://godoc.org/github.com/PretendoNetwork/nex-go) -### Other NEX libraries -[nex-protocols-go](https://github.com/PretendoNetwork/nex-protocols-go) - NEX protocol definitions +### Overview +NEX is the networking library used by all 1st party, and many 3rd party, games on the Nintendo Wii U, 3DS, and Switch which have online features. The NEX library has many different parts, ranging from low level packet transport to higher level service implementations -[nex-protocols-common-go](https://github.com/PretendoNetwork/nex-protocols-common-go) - NEX protocols used by many games with premade handlers and a high level API +This library implements the lowest level parts of NEX, the transport protocols. For other parts of the NEX stack, see the below libraries. For detailed information on NEX as a whole, see our wiki docs https://nintendo-wiki.pretendo.network/docs/nex ### Install -`go get github.com/PretendoNetwork/nex-go` +``` +go get github.com/PretendoNetwork/nex-go +``` + +### Other NEX libraries +- [nex-protocols-go](https://github.com/PretendoNetwork/nex-protocols-go) - NEX protocol definitions +- [nex-protocols-common-go](https://github.com/PretendoNetwork/nex-protocols-common-go) - Implementations of common NEX protocols which can be reused on many servers + +### Quazal Rendez-Vous +Nintendo did not make NEX from scratch. NEX is largely based on an existing library called Rendez-Vous (QRV), made by Canadian software company Quazal. Quazal licensed Rendez-Vous out to many other companies, and was eventually bought out by Ubisoft. Because of this, QRV is seen in many many other games on all major platforms, especially Ubisoft + +Nintendo modified Rendez-Vous somewhat heavily, simplifying the library/transport protocol quite a bit, and adding several custom services -### Usage note +While the main goal of this library is to support games which use the NEX variant of Rendez-Vous made by Nintendo, we also aim to be compatible with games using the original Rendez-Vous library. Due to the extensible nature of Rendez-Vous, many games may feature customizations much like NEX and have non-standard features/behavior. We do our best to support these cases, but there may be times where supporting all variations becomes untenable. In those cases, a fork of these libraries should be made instead if they require heavy modifications -This module provides a barebones PRUDP server for use with titles using the Nintendo NEX library. It also provides support for titles using the original Rendez-Vous library developed by Quazal. This library only provides the low level packet data, as such it is recommended to use [NEX Protocols Go](https://github.com/PretendoNetwork/nex-protocols-go) to develop servers. +### Supported features +- [x] [HPP servers](https://nintendo-wiki.pretendo.network/docs/hpp) (NEX over HTTP) +- [ ] [PRUDP servers](https://nintendo-wiki.pretendo.network/docs/prudp) + - [x] UDP transport + - [ ] WebSocket transport + - [x] PRUDPv0 packets + - [x] PRUDPv1 packets + - [ ] PRUDPLite packets +- [x] Fragmented packet payloads +- [x] Packet retransmission +- [x] Reliable packets +- [x] Unreliable packets +- [x] [Virtual ports](https://nintendo-wiki.pretendo.network/docs/prudp#virtual-ports) +- [x] Packet compression +- [x] [RMC](https://nintendo-wiki.pretendo.network/docs/rmc) + - [x] Request messages + - [x] Response messages + - [x] "Packed" encoded messages + - [x] "Packed" (extended) encoded messages + - [x] "Verbose" encoded messages +- [x] [Kerberos authentication](https://nintendo-wiki.pretendo.network/docs/nex/kerberos) -### Usage +### Example ```go package main From 99c37d3df2bfb64064763690fd8c7e2ccdfb5a90 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Mon, 11 Dec 2023 00:22:19 -0500 Subject: [PATCH 068/178] prudp: experimental WSS and PRUDPLite support --- go.mod | 6 +- go.sum | 13 +- prudp_client.go | 16 ++- prudp_packet_lite.go | 316 +++++++++++++++++++++++++++++++++++++++++++ prudp_server.go | 185 ++++++++++++++++++------- websocket_server.go | 58 ++++++++ 6 files changed, 532 insertions(+), 62 deletions(-) create mode 100644 prudp_packet_lite.go create mode 100644 websocket_server.go diff --git a/go.mod b/go.mod index 173c8657..e0a48777 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.21 require ( github.com/PretendoNetwork/plogger-go v1.0.4 + github.com/gorilla/websocket v1.5.1 github.com/superwhiskers/crunch/v3 v3.5.7 golang.org/x/exp v0.0.0-20230905200255-921286631fa9 golang.org/x/mod v0.12.0 @@ -14,6 +15,7 @@ require ( github.com/jwalton/go-supportscolor v1.2.0 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.19 // indirect - golang.org/x/sys v0.12.0 // indirect - golang.org/x/term v0.11.0 // indirect + golang.org/x/net v0.19.0 // indirect + golang.org/x/sys v0.15.0 // indirect + golang.org/x/term v0.15.0 // indirect ) diff --git a/go.sum b/go.sum index ef5e34c6..79cde014 100644 --- a/go.sum +++ b/go.sum @@ -3,6 +3,9 @@ github.com/PretendoNetwork/plogger-go v1.0.4/go.mod h1:7kD6M4vPq1JL4LTuPg6kuB1Ov github.com/fatih/color v1.15.0 h1:kOqh6YHBtK8aywxGerMG2Eq3H6Qgoqeo13Bk2Mv/nBs= github.com/fatih/color v1.15.0/go.mod h1:0h5ZqXfHYED7Bhv2ZJamyIOUej9KtShiJESRwBDUSsw= github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= +github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY= +github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= github.com/jwalton/go-supportscolor v1.2.0 h1:g6Ha4u7Vm3LIsQ5wmeBpS4gazu0UP1DRDE8y6bre4H8= github.com/jwalton/go-supportscolor v1.2.0/go.mod h1:hFVUAZV2cWg+WFFC4v8pT2X/S2qUUBYMioBD9AINXGs= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= @@ -16,12 +19,14 @@ golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjs golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k= golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= +golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210220050731-9a76102bfb43/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= -golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= +golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20210220032956-6a3ed077a48d/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.11.0 h1:F9tnn/DA/Im8nCwm+fX+1/eBwi4qFjRT++MhtVC4ZX0= -golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= +golang.org/x/term v0.15.0 h1:y/Oo/a/q3IXu26lQgl04j/gjuBDOBlx7X6Om1j2CPW4= +golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0= diff --git a/prudp_client.go b/prudp_client.go index e348a1ff..4f142f75 100644 --- a/prudp_client.go +++ b/prudp_client.go @@ -5,12 +5,15 @@ import ( "fmt" "net" "time" + + "github.com/gorilla/websocket" ) // PRUDPClient represents a single PRUDP client type PRUDPClient struct { - address *net.UDPAddr server *PRUDPServer + address net.Addr + webSocketConnection *websocket.Conn pid *PID clientConnectionSignature []byte serverConnectionSignature []byte @@ -65,6 +68,10 @@ func (c *PRUDPClient) cleanup() { c.reliableSubstreams = make([]*ReliablePacketSubstreamManager, 0) c.stopHeartbeatTimers() + if c.webSocketConnection != nil { + c.webSocketConnection.Close() + } + c.server.emitRemoved(c) } @@ -209,11 +216,12 @@ func (c *PRUDPClient) stopHeartbeatTimers() { } } -// NewPRUDPClient creates and returns a new Client using the provided UDP address and server -func NewPRUDPClient(address *net.UDPAddr, server *PRUDPServer) *PRUDPClient { +// NewPRUDPClient creates and returns a new PRUDPClient +func NewPRUDPClient(server *PRUDPServer, address net.Addr, webSocketConnection *websocket.Conn) *PRUDPClient { return &PRUDPClient{ - address: address, server: server, + address: address, + webSocketConnection: webSocketConnection, outgoingPingSequenceIDCounter: NewCounter[uint16](0), pid: NewPID[uint32](0), unreliableBaseKey: make([]byte, 0x20), diff --git a/prudp_packet_lite.go b/prudp_packet_lite.go new file mode 100644 index 00000000..ed1d3f10 --- /dev/null +++ b/prudp_packet_lite.go @@ -0,0 +1,316 @@ +package nex + +import ( + "crypto/hmac" + "crypto/md5" + "encoding/binary" + "fmt" + "net" +) + +// PRUDPPacketLite represents a PRUDPLite packet +type PRUDPPacketLite struct { + PRUDPPacket + optionsLength uint8 + minorVersion uint32 + supportedFunctions uint32 + maximumSubstreamID uint8 + initialUnreliableSequenceID uint16 + liteSignature []byte +} + +// Copy copies the packet into a new PRUDPPacketLite +// +// Retains the same PRUDPClient pointer +func (p *PRUDPPacketLite) Copy() PRUDPPacketInterface { + copied, _ := NewPRUDPPacketLite(p.sender, nil) + + copied.server = p.server + copied.sourceStreamType = p.sourceStreamType + copied.sourcePort = p.sourcePort + copied.destinationStreamType = p.destinationStreamType + copied.destinationPort = p.destinationPort + copied.packetType = p.packetType + copied.flags = p.flags + copied.sessionID = p.sessionID + copied.substreamID = p.substreamID + + if p.signature != nil { + copied.signature = append([]byte(nil), p.signature...) + } + + copied.sequenceID = p.sequenceID + + if p.connectionSignature != nil { + copied.connectionSignature = append([]byte(nil), p.connectionSignature...) + } + + copied.fragmentID = p.fragmentID + + if p.payload != nil { + copied.payload = append([]byte(nil), p.payload...) + } + + if p.message != nil { + copied.message = p.message.Copy() + } + + copied.optionsLength = p.optionsLength + copied.minorVersion = p.minorVersion + copied.supportedFunctions = p.supportedFunctions + copied.maximumSubstreamID = p.maximumSubstreamID + copied.initialUnreliableSequenceID = p.initialUnreliableSequenceID + + return copied +} + +// Version returns the packets PRUDP version +func (p *PRUDPPacketLite) Version() int { + return 2 +} + +// decode parses the packets data +func (p *PRUDPPacketLite) decode() error { + magic, err := p.readStream.ReadUInt8() + if err != nil { + return fmt.Errorf("Failed to read PRUDPLite magic. %s", err.Error()) + } + + if magic != 0x80 { + return fmt.Errorf("Invalid PRUDPLite magic. Expected 0x80, got 0x%x", magic) + } + + p.optionsLength, err = p.readStream.ReadUInt8() + if err != nil { + return fmt.Errorf("Failed to decode PRUDPLite options length. %s", err.Error()) + } + + payloadLength, err := p.readStream.ReadUInt16LE() + if err != nil { + return fmt.Errorf("Failed to decode PRUDPLite payload length. %s", err.Error()) + } + + streamTypes, err := p.readStream.ReadUInt8() + if err != nil { + return fmt.Errorf("Failed to decode PRUDPLite virtual ports stream types. %s", err.Error()) + } + + p.sourceStreamType = streamTypes >> 4 + p.destinationStreamType = streamTypes & 0xF + + p.sourcePort, err = p.readStream.ReadUInt8() + if err != nil { + return fmt.Errorf("Failed to decode PRUDPLite virtual source port. %s", err.Error()) + } + + p.destinationPort, err = p.readStream.ReadUInt8() + if err != nil { + return fmt.Errorf("Failed to decode PRUDPLite virtual destination port. %s", err.Error()) + } + + p.fragmentID, err = p.readStream.ReadUInt8() + if err != nil { + return fmt.Errorf("Failed to decode PRUDPLite fragment ID. %s", err.Error()) + } + + typeAndFlags, err := p.readStream.ReadUInt16LE() + if err != nil { + return fmt.Errorf("Failed to read PRUDPLite type and flags. %s", err.Error()) + } + + p.flags = typeAndFlags >> 4 + p.packetType = typeAndFlags & 0xF + + p.sequenceID, err = p.readStream.ReadUInt16LE() + if err != nil { + return fmt.Errorf("Failed to decode PRUDPLite sequence ID. %s", err.Error()) + } + + err = p.decodeOptions() + if err != nil { + return fmt.Errorf("Failed to decode PRUDPLite options. %s", err.Error()) + } + + p.payload = p.readStream.ReadBytesNext(int64(payloadLength)) + + return nil +} + +// Bytes encodes a PRUDPLite packet into a byte slice +func (p *PRUDPPacketLite) Bytes() []byte { + options := p.encodeOptions() + + stream := NewStreamOut(nil) + + stream.WriteUInt8(0x80) + stream.WriteUInt8(uint8(len(options))) + stream.WriteUInt16LE(uint16(len(p.payload))) + stream.WriteUInt8((p.sourceStreamType << 4) | p.destinationStreamType) + stream.WriteUInt8(p.sourcePort) + stream.WriteUInt8(p.destinationPort) + stream.WriteUInt8(p.fragmentID) + stream.WriteUInt16LE(p.packetType | (p.flags << 4)) + stream.WriteUInt16LE(p.sequenceID) + + stream.Grow(int64(len(options))) + stream.WriteBytesNext(options) + + stream.Grow(int64(len(p.payload))) + stream.WriteBytesNext(p.payload) + + return stream.Bytes() +} + +func (p *PRUDPPacketLite) decodeOptions() error { + data := p.readStream.ReadBytesNext(int64(p.optionsLength)) + optionsStream := NewStreamIn(data, nil) + + for optionsStream.Remaining() > 0 { + optionID, err := optionsStream.ReadUInt8() + if err != nil { + return err + } + + optionSize, err := optionsStream.ReadUInt8() // * Options size. We already know the size based on the ID, though + if err != nil { + return err + } + + if p.packetType == SynPacket || p.packetType == ConnectPacket { + if optionID == 0 { + p.supportedFunctions, err = optionsStream.ReadUInt32LE() + + p.minorVersion = p.supportedFunctions & 0xFF + p.supportedFunctions = p.supportedFunctions >> 8 + } + + if optionID == 1 { + p.connectionSignature = optionsStream.ReadBytesNext(int64(optionSize)) + } + + if optionID == 4 { + p.maximumSubstreamID, err = optionsStream.ReadUInt8() + } + } + + if p.packetType == ConnectPacket { + if optionID == 3 { + p.initialUnreliableSequenceID, err = optionsStream.ReadUInt16LE() + } + } + + if p.packetType == DataPacket { + if optionID == 2 { + p.fragmentID, err = optionsStream.ReadUInt8() + } + } + + if p.packetType == ConnectPacket && !p.HasFlag(FlagAck) { + if optionID == 0x80 { + p.liteSignature = optionsStream.ReadBytesNext(int64(optionSize)) + } + } + + // * Only one option is processed at a time, so we can + // * just check for errors here rather than after EVERY + // * read + if err != nil { + return err + } + } + + return nil +} + +func (p *PRUDPPacketLite) encodeOptions() []byte { + optionsStream := NewStreamOut(nil) + + if p.packetType == SynPacket || p.packetType == ConnectPacket { + optionsStream.WriteUInt8(0) + optionsStream.WriteUInt8(4) + optionsStream.WriteUInt32LE(p.minorVersion | (p.supportedFunctions << 8)) + + if p.packetType == SynPacket && p.HasFlag(FlagAck) { + optionsStream.WriteUInt8(1) + optionsStream.WriteUInt8(16) + optionsStream.Grow(16) + optionsStream.WriteBytesNext(p.connectionSignature) + } + + if p.packetType == ConnectPacket && !p.HasFlag(FlagAck) { + optionsStream.WriteUInt8(1) + optionsStream.WriteUInt8(16) + optionsStream.Grow(16) + optionsStream.WriteBytesNext(p.liteSignature) + } + } + + return optionsStream.Bytes() +} + +func (p *PRUDPPacketLite) calculateConnectionSignature(addr net.Addr) ([]byte, error) { + var ip net.IP + var port int + + switch v := addr.(type) { + case *net.TCPAddr: + ip = v.IP.To4() + port = v.Port + default: + return nil, fmt.Errorf("Unsupported network type: %T", addr) + } + + portBytes := make([]byte, 2) + binary.BigEndian.PutUint16(portBytes, uint16(port)) + + data := append(ip, portBytes...) + hash := hmac.New(md5.New, p.server.PRUDPv1ConnectionSignatureKey) + hash.Write(data) + + return hash.Sum(nil), nil +} + +func (p *PRUDPPacketLite) calculateSignature(sessionKey, connectionSignature []byte) []byte { + // * PRUDPLite has no signature + return make([]byte, 0) +} + +// NewPRUDPPacketLite creates and returns a new PacketLite using the provided Client and stream +func NewPRUDPPacketLite(client *PRUDPClient, readStream *StreamIn) (*PRUDPPacketLite, error) { + packet := &PRUDPPacketLite{ + PRUDPPacket: PRUDPPacket{ + sender: client, + readStream: readStream, + }, + } + + if readStream != nil { + packet.server = readStream.Server.(*PRUDPServer) + err := packet.decode() + if err != nil { + return nil, fmt.Errorf("Failed to decode PRUDPLite packet. %s", err.Error()) + } + } + + if client != nil { + packet.server = client.server + } + + return packet, nil +} + +// NewPRUDPPacketsLite reads all possible PRUDPLite packets from the stream +func NewPRUDPPacketsLite(client *PRUDPClient, readStream *StreamIn) ([]PRUDPPacketInterface, error) { + packets := make([]PRUDPPacketInterface, 0) + + for readStream.Remaining() > 0 { + packet, err := NewPRUDPPacketLite(client, readStream) + if err != nil { + return packets, err + } + + packets = append(packets, packet) + } + + return packets, nil +} diff --git a/prudp_server.go b/prudp_server.go index a2235233..0f0bbcc7 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -4,17 +4,21 @@ import ( "bytes" "compress/zlib" "crypto/rand" + "encoding/hex" "errors" "fmt" "net" "runtime" "slices" "time" + + "github.com/gorilla/websocket" ) // PRUDPServer represents a bare-bones PRUDP server type PRUDPServer struct { udpSocket *net.UDPConn + websocketServer *WebSocketServer PRUDPVersion int PRUDPMinorVersion uint32 virtualServers *MutexMap[uint8, *MutexMap[uint8, *MutexMap[string, *PRUDPClient]]] @@ -87,16 +91,15 @@ func (s *PRUDPServer) emitRemoved(client *PRUDPClient) { } } -// Listen starts a PRUDP server on a given port +// Listen is an alias of ListenUDP. Implemented to conform to the ServerInterface func (s *PRUDPServer) Listen(port int) { - // * Ensure the server has a key for PRUDPv1 connection signatures - if len(s.PRUDPv1ConnectionSignatureKey) != 16 { - s.PRUDPv1ConnectionSignatureKey = make([]byte, 16) - _, err := rand.Read(s.PRUDPv1ConnectionSignatureKey) - if err != nil { - panic(err) - } - } + s.ListenUDP(port) +} + +// ListenUDP starts a PRUDP server on a given port using a UDP server +func (s *PRUDPServer) ListenUDP(port int) { + s.initPRUDPv1ConnectionSignatureKey() + s.initVirtualPorts() udpAddress, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", port)) if err != nil { @@ -110,6 +113,60 @@ func (s *PRUDPServer) Listen(port int) { s.udpSocket = socket + quit := make(chan struct{}) + + for i := 0; i < runtime.NumCPU(); i++ { + go s.listenDatagram(quit) + } + + <-quit +} + +// ListenWebSocket starts a PRUDP server on a given port using a WebSocket server +func (s *PRUDPServer) ListenWebSocket(port int) { + + s.initPRUDPv1ConnectionSignatureKey() + s.initVirtualPorts() + + s.websocketServer = &WebSocketServer{ + upgrader: websocket.Upgrader{ + ReadBufferSize: 64000, + WriteBufferSize: 64000, + }, + handleSocketMessage: s.handleSocketMessage, + } + + s.websocketServer.listen(port) +} + +// ListenWebSocketSecure starts a PRUDP server on a given port using a secure (TLS) WebSocket server +func (s *PRUDPServer) ListenWebSocketSecure(port int, certFile, keyFile string) { + s.initPRUDPv1ConnectionSignatureKey() + s.initVirtualPorts() + + s.websocketServer = &WebSocketServer{ + upgrader: websocket.Upgrader{ + ReadBufferSize: 64000, + WriteBufferSize: 64000, + }, + handleSocketMessage: s.handleSocketMessage, + } + + s.websocketServer.listenSecure(port, certFile, keyFile) +} + +func (s *PRUDPServer) initPRUDPv1ConnectionSignatureKey() { + // * Ensure the server has a key for PRUDPv1 connection signatures + if len(s.PRUDPv1ConnectionSignatureKey) != 16 { + s.PRUDPv1ConnectionSignatureKey = make([]byte, 16) + _, err := rand.Read(s.PRUDPv1ConnectionSignatureKey) + if err != nil { + panic(err) + } + } +} + +func (s *PRUDPServer) initVirtualPorts() { for _, port := range s.VirtualServerPorts { virtualServer := NewMutexMap[uint8, *MutexMap[string, *PRUDPClient]]() virtualServer.Set(VirtualStreamTypeDO, NewMutexMap[string, *PRUDPClient]()) @@ -128,21 +185,20 @@ func (s *PRUDPServer) Listen(port int) { } logger.Success("Virtual ports created") - - quit := make(chan struct{}) - - for i := 0; i < runtime.NumCPU(); i++ { - go s.listenDatagram(quit) - } - - <-quit } func (s *PRUDPServer) listenDatagram(quit chan struct{}) { - err := error(nil) + var err error for err == nil { - err = s.handleSocketMessage() + buffer := make([]byte, 64000) + var read int + var addr *net.UDPAddr + + read, addr, err = s.udpSocket.ReadFromUDP(buffer) + packetData := buffer[:read] + + err = s.handleSocketMessage(packetData, addr, nil) } quit <- struct{}{} @@ -150,15 +206,7 @@ func (s *PRUDPServer) listenDatagram(quit chan struct{}) { panic(err) } -func (s *PRUDPServer) handleSocketMessage() error { - buffer := make([]byte, 64000) - - read, addr, err := s.udpSocket.ReadFromUDP(buffer) - if err != nil { - return err - } - - packetData := buffer[:read] +func (s *PRUDPServer) handleSocketMessage(packetData []byte, address net.Addr, webSocketConnection *websocket.Conn) error { readStream := NewStreamIn(packetData, s) var packets []PRUDPPacketInterface @@ -167,20 +215,22 @@ func (s *PRUDPServer) handleSocketMessage() error { // * with that same type. Also keep reading from the stream // * until no more data is left, to account for multiple // * packets being sent at once - if bytes.Equal(packetData[:2], []byte{0xEA, 0xD0}) { + if s.websocketServer != nil && packetData[0] == 0x80 { + packets, _ = NewPRUDPPacketsLite(nil, readStream) + } else if bytes.Equal(packetData[:2], []byte{0xEA, 0xD0}) { packets, _ = NewPRUDPPacketsV1(nil, readStream) } else { packets, _ = NewPRUDPPacketsV0(nil, readStream) } for _, packet := range packets { - go s.processPacket(packet, addr) + go s.processPacket(packet, address, webSocketConnection) } return nil } -func (s *PRUDPServer) processPacket(packet PRUDPPacketInterface, address *net.UDPAddr) { +func (s *PRUDPServer) processPacket(packet PRUDPPacketInterface, address net.Addr, webSocketConnection *websocket.Conn) { virtualServer, _ := s.virtualServers.Get(packet.DestinationPort()) virtualServerStream, _ := virtualServer.Get(packet.DestinationStreamType()) @@ -189,7 +239,7 @@ func (s *PRUDPServer) processPacket(packet PRUDPPacketInterface, address *net.UD client, ok := virtualServerStream.Get(discriminator) if !ok { - client = NewPRUDPClient(address, s) + client = NewPRUDPClient(s, address, webSocketConnection) client.startHeartbeat() // * Fail-safe. If the server reboots, then @@ -303,10 +353,12 @@ func (s *PRUDPServer) handleSyn(packet PRUDPPacketInterface) { var ack PRUDPPacketInterface - if packet.Version() == 0 { - ack, _ = NewPRUDPPacketV0(client, nil) - } else { + if packet.Version() == 2 { + ack, _ = NewPRUDPPacketLite(client, nil) + } else if packet.Version() == 1 { ack, _ = NewPRUDPPacketV1(client, nil) + } else { + ack, _ = NewPRUDPPacketV0(client, nil) } connectionSignature, err := packet.calculateConnectionSignature(client.address) @@ -340,7 +392,7 @@ func (s *PRUDPServer) handleSyn(packet PRUDPPacketInterface) { s.emit("syn", ack) - s.sendRaw(client.address, ack.Bytes()) + s.sendRaw(client, ack.Bytes()) } func (s *PRUDPServer) handleConnect(packet PRUDPPacketInterface) { @@ -348,10 +400,12 @@ func (s *PRUDPServer) handleConnect(packet PRUDPPacketInterface) { var ack PRUDPPacketInterface - if packet.Version() == 0 { - ack, _ = NewPRUDPPacketV0(client, nil) - } else { + if packet.Version() == 2 { + ack, _ = NewPRUDPPacketLite(client, nil) + } else if packet.Version() == 1 { ack, _ = NewPRUDPPacketV1(client, nil) + } else { + ack, _ = NewPRUDPPacketV0(client, nil) } client.serverConnectionSignature = packet.getConnectionSignature() @@ -421,7 +475,7 @@ func (s *PRUDPServer) handleConnect(packet PRUDPPacketInterface) { s.emit("connect", ack) - s.sendRaw(client.address, ack.Bytes()) + s.sendRaw(client, ack.Bytes()) } func (s *PRUDPServer) handleData(packet PRUDPPacketInterface) { @@ -515,10 +569,12 @@ func (s *PRUDPServer) readKerberosTicket(payload []byte) ([]byte, *PID, uint32, func (s *PRUDPServer) acknowledgePacket(packet PRUDPPacketInterface) { var ack PRUDPPacketInterface - if packet.Version() == 0 { - ack, _ = NewPRUDPPacketV0(packet.Sender().(*PRUDPClient), nil) - } else { + if packet.Version() == 2 { + ack, _ = NewPRUDPPacketLite(packet.Sender().(*PRUDPClient), nil) + } else if packet.Version() == 1 { ack, _ = NewPRUDPPacketV1(packet.Sender().(*PRUDPClient), nil) + } else { + ack, _ = NewPRUDPPacketV0(packet.Sender().(*PRUDPClient), nil) } ack.SetType(packet.Type()) @@ -549,7 +605,15 @@ func (s *PRUDPServer) handleReliable(packet PRUDPPacketInterface) { for _, pendingPacket := range substream.Update(packet) { if packet.Type() == DataPacket { - decryptedPayload := pendingPacket.decryptPayload() + var decryptedPayload []byte + + if packet.Version() != 2 { + decryptedPayload = pendingPacket.decryptPayload() + } else { + // * PRUDPLite does not encrypt payloads + decryptedPayload = pendingPacket.Payload() + } + decompressedPayload, err := s.decompressPayload(decryptedPayload) if err != nil { logger.Error(err.Error()) @@ -560,6 +624,7 @@ func (s *PRUDPServer) handleReliable(packet PRUDPPacketInterface) { if packet.getFragmentID() == 0 { message := NewRMCMessage() err := message.FromBytes(payload) + fmt.Println(hex.EncodeToString(payload)) if err != nil { // TODO - Should this return the error too? logger.Error(err.Error()) @@ -639,7 +704,9 @@ func (s *PRUDPServer) handleUnreliable(packet PRUDPPacketInterface) { func (s *PRUDPServer) sendPing(client *PRUDPClient) { var ping PRUDPPacketInterface - if s.PRUDPVersion == 0 { + if s.websocketServer != nil { + ping, _ = NewPRUDPPacketLite(client, nil) + } else if s.PRUDPVersion == 0 { ping, _ = NewPRUDPPacketV0(client, nil) } else { ping, _ = NewPRUDPPacketV1(client, nil) @@ -721,9 +788,15 @@ func (s *PRUDPServer) sendPacket(packet PRUDPPacketInterface) { substream.SetCipherKey([]byte("CD&ML")) } - packetCopy.SetPayload(substream.Encrypt(compressedPayload)) + // * PRUDPLite packet. No RC4 + if packetCopy.Version() != 2 { + packetCopy.SetPayload(substream.Encrypt(compressedPayload)) + } } else { - packetCopy.SetPayload(packetCopy.processUnreliableCrypto()) + // * PRUDPLite packet. No RC4 + if packetCopy.Version() != 2 { + packetCopy.SetPayload(packetCopy.processUnreliableCrypto()) + } } } @@ -734,14 +807,22 @@ func (s *PRUDPServer) sendPacket(packet PRUDPPacketInterface) { substream.ResendScheduler.AddPacket(packetCopy) } - s.sendRaw(packetCopy.Sender().Address(), packetCopy.Bytes()) + s.sendRaw(packetCopy.Sender().(*PRUDPClient), packetCopy.Bytes()) } -// sendRaw will send the given address the provided packet -func (s *PRUDPServer) sendRaw(conn net.Addr, data []byte) { - _, err := s.udpSocket.WriteToUDP(data, conn.(*net.UDPAddr)) +// sendRaw will send the given client the provided packet +func (s *PRUDPServer) sendRaw(client *PRUDPClient, data []byte) { + // TODO - Should this return the error too? + + var err error + + if s.udpSocket != nil { + _, err = s.udpSocket.WriteToUDP(data, client.address.(*net.UDPAddr)) + } else if client.webSocketConnection != nil { + err = client.webSocketConnection.WriteMessage(websocket.BinaryMessage, data) + } + if err != nil { - // TODO - Should this return the error too? logger.Error(err.Error()) } } diff --git a/websocket_server.go b/websocket_server.go new file mode 100644 index 00000000..75cd3c5e --- /dev/null +++ b/websocket_server.go @@ -0,0 +1,58 @@ +package nex + +import ( + "fmt" + "net" + "net/http" + + "github.com/gorilla/websocket" +) + +// WebSocketServer wraps a WebSocket server to create an easier API to consume +type WebSocketServer struct { + mux *http.ServeMux + upgrader websocket.Upgrader + handleSocketMessage func(packetData []byte, address net.Addr, webSocketConnection *websocket.Conn) error +} + +func (ws *WebSocketServer) handleConnection(conn *websocket.Conn) { + defer conn.Close() + + conn.RemoteAddr() + + for { + _, data, err := conn.ReadMessage() + if err != nil { + logger.Error(err.Error()) + return + } + + ws.handleSocketMessage(data, conn.RemoteAddr(), conn) + } +} + +func (ws *WebSocketServer) initMux() { + ws.mux = http.NewServeMux() + ws.mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + conn, err := ws.upgrader.Upgrade(w, r, nil) + if err != nil { + logger.Error(err.Error()) + return + } + defer conn.Close() + + ws.handleConnection(conn) + }) +} + +func (ws *WebSocketServer) listen(port int) { + ws.initMux() + + http.ListenAndServe(fmt.Sprintf(":%d", port), ws.mux) +} + +func (ws *WebSocketServer) listenSecure(port int, certFile, keyFile string) { + ws.initMux() + + http.ListenAndServeTLS(fmt.Sprintf(":%d", port), certFile, keyFile, ws.mux) +} From 0272e257233bd6dc85f91c0983a9f078b2253e58 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Mon, 11 Dec 2023 14:01:08 -0500 Subject: [PATCH 069/178] prudp: forgot to commit resend changes last night --- resend_scheduler.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/resend_scheduler.go b/resend_scheduler.go index cad8e1a6..df174edc 100644 --- a/resend_scheduler.go +++ b/resend_scheduler.go @@ -109,7 +109,7 @@ func (rs *ResendScheduler) resendPacket(pendingPacket *PendingPacket) { // * Resend the packet to the client server := client.server data := packet.Bytes() - server.sendRaw(client.Address(), data) + server.sendRaw(client, data) pendingPacket.interval += rs.Increase pendingPacket.ticker.Reset(pendingPacket.interval) From a64a555520a27df6ac92fb5ef0b8e860abad47af Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Tue, 12 Dec 2023 04:32:34 -0500 Subject: [PATCH 070/178] prudp: swap WebSocket libraries --- go.mod | 4 +-- go.sum | 16 ++++++--- prudp_client.go | 8 ++--- prudp_server.go | 22 ++++-------- rmc_message.go | 2 +- websocket_server.go | 88 +++++++++++++++++++++++++++++++-------------- 6 files changed, 87 insertions(+), 53 deletions(-) diff --git a/go.mod b/go.mod index e0a48777..414359b2 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.21 require ( github.com/PretendoNetwork/plogger-go v1.0.4 - github.com/gorilla/websocket v1.5.1 + github.com/lxzan/gws v1.7.0 github.com/superwhiskers/crunch/v3 v3.5.7 golang.org/x/exp v0.0.0-20230905200255-921286631fa9 golang.org/x/mod v0.12.0 @@ -13,9 +13,9 @@ require ( require ( github.com/fatih/color v1.15.0 // indirect github.com/jwalton/go-supportscolor v1.2.0 // indirect + github.com/klauspost/compress v1.16.5 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.19 // indirect - golang.org/x/net v0.19.0 // indirect golang.org/x/sys v0.15.0 // indirect golang.org/x/term v0.15.0 // indirect ) diff --git a/go.sum b/go.sum index 79cde014..6b888962 100644 --- a/go.sum +++ b/go.sum @@ -1,26 +1,32 @@ github.com/PretendoNetwork/plogger-go v1.0.4 h1:PF7xHw9eDRHH+RsAP9tmAE7fG0N0p6H4iPwHKnsoXwc= github.com/PretendoNetwork/plogger-go v1.0.4/go.mod h1:7kD6M4vPq1JL4LTuPg6kuB1OvUBOwQOtAvTaUwMbwvU= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fatih/color v1.15.0 h1:kOqh6YHBtK8aywxGerMG2Eq3H6Qgoqeo13Bk2Mv/nBs= github.com/fatih/color v1.15.0/go.mod h1:0h5ZqXfHYED7Bhv2ZJamyIOUej9KtShiJESRwBDUSsw= github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY= -github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= github.com/jwalton/go-supportscolor v1.2.0 h1:g6Ha4u7Vm3LIsQ5wmeBpS4gazu0UP1DRDE8y6bre4H8= github.com/jwalton/go-supportscolor v1.2.0/go.mod h1:hFVUAZV2cWg+WFFC4v8pT2X/S2qUUBYMioBD9AINXGs= +github.com/klauspost/compress v1.16.5 h1:IFV2oUNUzZaz+XyusxpLzpzS8Pt5rh0Z16For/djlyI= +github.com/klauspost/compress v1.16.5/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= +github.com/lxzan/gws v1.7.0 h1:/yy5/+3eccMy61/scXM57fTDvucN/t7/0t5wLTwL+qY= +github.com/lxzan/gws v1.7.0/go.mod h1:dsC6S7kJNh+iWqqu2HiO8tnNCji04HwyJCYfTOS+6iY= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/superwhiskers/crunch/v3 v3.5.7 h1:N9RLxaR65C36i26BUIpzPXGy2f6pQ7wisu2bawbKNqg= github.com/superwhiskers/crunch/v3 v3.5.7/go.mod h1:4ub2EKgF1MAhTjoOCTU4b9uLMsAweHEa89aRrfAypXA= golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k= golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= -golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210220050731-9a76102bfb43/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -30,3 +36,5 @@ golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20210220032956-6a3ed077a48d/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.15.0 h1:y/Oo/a/q3IXu26lQgl04j/gjuBDOBlx7X6Om1j2CPW4= golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/prudp_client.go b/prudp_client.go index 4f142f75..53c30370 100644 --- a/prudp_client.go +++ b/prudp_client.go @@ -6,14 +6,14 @@ import ( "net" "time" - "github.com/gorilla/websocket" + "github.com/lxzan/gws" ) // PRUDPClient represents a single PRUDP client type PRUDPClient struct { server *PRUDPServer address net.Addr - webSocketConnection *websocket.Conn + webSocketConnection *gws.Conn pid *PID clientConnectionSignature []byte serverConnectionSignature []byte @@ -69,7 +69,7 @@ func (c *PRUDPClient) cleanup() { c.stopHeartbeatTimers() if c.webSocketConnection != nil { - c.webSocketConnection.Close() + c.webSocketConnection.NetConn().Close() // TODO - Swap this out for WriteClose() to send a close frame? } c.server.emitRemoved(c) @@ -217,7 +217,7 @@ func (c *PRUDPClient) stopHeartbeatTimers() { } // NewPRUDPClient creates and returns a new PRUDPClient -func NewPRUDPClient(server *PRUDPServer, address net.Addr, webSocketConnection *websocket.Conn) *PRUDPClient { +func NewPRUDPClient(server *PRUDPServer, address net.Addr, webSocketConnection *gws.Conn) *PRUDPClient { return &PRUDPClient{ server: server, address: address, diff --git a/prudp_server.go b/prudp_server.go index 0f0bbcc7..fc7e1fe3 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -4,7 +4,6 @@ import ( "bytes" "compress/zlib" "crypto/rand" - "encoding/hex" "errors" "fmt" "net" @@ -12,7 +11,7 @@ import ( "slices" "time" - "github.com/gorilla/websocket" + "github.com/lxzan/gws" ) // PRUDPServer represents a bare-bones PRUDP server @@ -129,11 +128,7 @@ func (s *PRUDPServer) ListenWebSocket(port int) { s.initVirtualPorts() s.websocketServer = &WebSocketServer{ - upgrader: websocket.Upgrader{ - ReadBufferSize: 64000, - WriteBufferSize: 64000, - }, - handleSocketMessage: s.handleSocketMessage, + prudpServer: s, } s.websocketServer.listen(port) @@ -145,11 +140,7 @@ func (s *PRUDPServer) ListenWebSocketSecure(port int, certFile, keyFile string) s.initVirtualPorts() s.websocketServer = &WebSocketServer{ - upgrader: websocket.Upgrader{ - ReadBufferSize: 64000, - WriteBufferSize: 64000, - }, - handleSocketMessage: s.handleSocketMessage, + prudpServer: s, } s.websocketServer.listenSecure(port, certFile, keyFile) @@ -206,7 +197,7 @@ func (s *PRUDPServer) listenDatagram(quit chan struct{}) { panic(err) } -func (s *PRUDPServer) handleSocketMessage(packetData []byte, address net.Addr, webSocketConnection *websocket.Conn) error { +func (s *PRUDPServer) handleSocketMessage(packetData []byte, address net.Addr, webSocketConnection *gws.Conn) error { readStream := NewStreamIn(packetData, s) var packets []PRUDPPacketInterface @@ -230,7 +221,7 @@ func (s *PRUDPServer) handleSocketMessage(packetData []byte, address net.Addr, w return nil } -func (s *PRUDPServer) processPacket(packet PRUDPPacketInterface, address net.Addr, webSocketConnection *websocket.Conn) { +func (s *PRUDPServer) processPacket(packet PRUDPPacketInterface, address net.Addr, webSocketConnection *gws.Conn) { virtualServer, _ := s.virtualServers.Get(packet.DestinationPort()) virtualServerStream, _ := virtualServer.Get(packet.DestinationStreamType()) @@ -624,7 +615,6 @@ func (s *PRUDPServer) handleReliable(packet PRUDPPacketInterface) { if packet.getFragmentID() == 0 { message := NewRMCMessage() err := message.FromBytes(payload) - fmt.Println(hex.EncodeToString(payload)) if err != nil { // TODO - Should this return the error too? logger.Error(err.Error()) @@ -819,7 +809,7 @@ func (s *PRUDPServer) sendRaw(client *PRUDPClient, data []byte) { if s.udpSocket != nil { _, err = s.udpSocket.WriteToUDP(data, client.address.(*net.UDPAddr)) } else if client.webSocketConnection != nil { - err = client.webSocketConnection.WriteMessage(websocket.BinaryMessage, data) + err = client.webSocketConnection.WriteMessage(gws.OpcodeBinary, data) } if err != nil { diff --git a/rmc_message.go b/rmc_message.go index f9ac68ba..0d5933fe 100644 --- a/rmc_message.go +++ b/rmc_message.go @@ -139,7 +139,7 @@ func (rmc *RMCMessage) Bytes() []byte { // * do it for accuracy. if !rmc.IsHPP || (rmc.IsHPP && rmc.IsRequest) { if rmc.ProtocolID < 0x80 { - stream.WriteUInt8(uint8(rmc.ProtocolID | protocolIDFlag)) + stream.WriteUInt8(uint8(rmc.ProtocolID | protocolIDFlag)) } else { stream.WriteUInt8(uint8(0x7F | protocolIDFlag)) stream.WriteUInt16LE(rmc.ProtocolID) diff --git a/websocket_server.go b/websocket_server.go index 75cd3c5e..525282bf 100644 --- a/websocket_server.go +++ b/websocket_server.go @@ -2,57 +2,93 @@ package nex import ( "fmt" - "net" "net/http" + "time" - "github.com/gorilla/websocket" + "github.com/lxzan/gws" ) -// WebSocketServer wraps a WebSocket server to create an easier API to consume -type WebSocketServer struct { - mux *http.ServeMux - upgrader websocket.Upgrader - handleSocketMessage func(packetData []byte, address net.Addr, webSocketConnection *websocket.Conn) error +const ( + pingInterval = 5 * time.Second + pingWait = 10 * time.Second +) + +type wsEventHandler struct { + prudpServer *PRUDPServer } -func (ws *WebSocketServer) handleConnection(conn *websocket.Conn) { - defer conn.Close() +func (wseh *wsEventHandler) OnOpen(socket *gws.Conn) { + _ = socket.SetDeadline(time.Now().Add(pingInterval + pingWait)) +} - conn.RemoteAddr() +func (wseh *wsEventHandler) OnClose(socket *gws.Conn, err error) { + // TODO - Client clean up +} - for { - _, data, err := conn.ReadMessage() - if err != nil { - logger.Error(err.Error()) - return - } +func (wseh *wsEventHandler) OnPing(socket *gws.Conn, payload []byte) { + _ = socket.SetDeadline(time.Now().Add(pingInterval + pingWait)) + _ = socket.WritePong(nil) +} - ws.handleSocketMessage(data, conn.RemoteAddr(), conn) +func (wseh *wsEventHandler) OnPong(socket *gws.Conn, payload []byte) {} + +func (wseh *wsEventHandler) OnMessage(socket *gws.Conn, message *gws.Message) { + defer message.Close() + + // * Create a COPY of the underlying *bytes.Buffer bytes. + // * If this is not done, then the byte slice sometimes + // * gets modified in unexpected places + packetData := append([]byte(nil), message.Bytes()...) + err := wseh.prudpServer.handleSocketMessage(packetData, socket.RemoteAddr(), socket) + if err != nil { + logger.Error(err.Error()) } } -func (ws *WebSocketServer) initMux() { +// WebSocketServer wraps a WebSocket server to create an easier API to consume +type WebSocketServer struct { + mux *http.ServeMux + upgrader *gws.Upgrader + prudpServer *PRUDPServer +} + +func (ws *WebSocketServer) init() { + ws.upgrader = gws.NewUpgrader(&wsEventHandler{ + prudpServer: ws.prudpServer, + }, &gws.ServerOption{ + ReadAsyncEnabled: true, // * Parallel message processing + Recovery: gws.Recovery, // * Exception recovery + ReadBufferSize: 64000, + WriteBufferSize: 64000, + }) + ws.mux = http.NewServeMux() ws.mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - conn, err := ws.upgrader.Upgrade(w, r, nil) + socket, err := ws.upgrader.Upgrade(w, r) if err != nil { - logger.Error(err.Error()) return } - defer conn.Close() - ws.handleConnection(conn) + go func() { + socket.ReadLoop() // * Blocking prevents the context from being GC + }() }) } func (ws *WebSocketServer) listen(port int) { - ws.initMux() + ws.init() - http.ListenAndServe(fmt.Sprintf(":%d", port), ws.mux) + err := http.ListenAndServe(fmt.Sprintf(":%d", port), ws.mux) + if err != nil { + logger.Error(err.Error()) + } } func (ws *WebSocketServer) listenSecure(port int, certFile, keyFile string) { - ws.initMux() + ws.init() - http.ListenAndServeTLS(fmt.Sprintf(":%d", port), certFile, keyFile, ws.mux) + err := http.ListenAndServeTLS(fmt.Sprintf(":%d", port), certFile, keyFile, ws.mux) + if err != nil { + logger.Error(err.Error()) + } } From d12b98e09fdeeeb2594133b2449d684adf090572 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Tue, 12 Dec 2023 04:34:51 -0500 Subject: [PATCH 071/178] update supported features in README --- README.md | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 825dba34..0c5effbf 100644 --- a/README.md +++ b/README.md @@ -25,13 +25,14 @@ Nintendo modified Rendez-Vous somewhat heavily, simplifying the library/transpor While the main goal of this library is to support games which use the NEX variant of Rendez-Vous made by Nintendo, we also aim to be compatible with games using the original Rendez-Vous library. Due to the extensible nature of Rendez-Vous, many games may feature customizations much like NEX and have non-standard features/behavior. We do our best to support these cases, but there may be times where supporting all variations becomes untenable. In those cases, a fork of these libraries should be made instead if they require heavy modifications ### Supported features +- [x] Quazal compatibility mode/settings - [x] [HPP servers](https://nintendo-wiki.pretendo.network/docs/hpp) (NEX over HTTP) -- [ ] [PRUDP servers](https://nintendo-wiki.pretendo.network/docs/prudp) +- [x] [PRUDP servers](https://nintendo-wiki.pretendo.network/docs/prudp) - [x] UDP transport - - [ ] WebSocket transport + - [x] WebSocket transport (Experimental, largely untested) - [x] PRUDPv0 packets - [x] PRUDPv1 packets - - [ ] PRUDPLite packets + - [x] PRUDPLite packets - [x] Fragmented packet payloads - [x] Packet retransmission - [x] Reliable packets From 89af63e3d41efc4847256307ea0bfeaa318753c7 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Wed, 13 Dec 2023 19:51:05 -0500 Subject: [PATCH 072/178] prudp: validate virtual port being connected to --- prudp_server.go | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/prudp_server.go b/prudp_server.go index fc7e1fe3..ec4936fa 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -123,7 +123,6 @@ func (s *PRUDPServer) ListenUDP(port int) { // ListenWebSocket starts a PRUDP server on a given port using a WebSocket server func (s *PRUDPServer) ListenWebSocket(port int) { - s.initPRUDPv1ConnectionSignatureKey() s.initVirtualPorts() @@ -222,6 +221,16 @@ func (s *PRUDPServer) handleSocketMessage(packetData []byte, address net.Addr, w } func (s *PRUDPServer) processPacket(packet PRUDPPacketInterface, address net.Addr, webSocketConnection *gws.Conn) { + if !slices.Contains(s.VirtualServerPorts, packet.DestinationPort()) { + logger.Warningf("Client %s trying to connect to unbound server vport %d", address.String(), packet.DestinationPort()) + return + } + + if packet.DestinationStreamType() > VirtualStreamTypeRelay { + logger.Warningf("Client %s trying to use invalid to server stream type %d", address.String(), packet.DestinationStreamType()) + return + } + virtualServer, _ := s.virtualServers.Get(packet.DestinationPort()) virtualServerStream, _ := virtualServer.Get(packet.DestinationStreamType()) From 5648479445d52c2f0234501a1db4919945fd306f Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Wed, 13 Dec 2023 19:51:33 -0500 Subject: [PATCH 073/178] prudp: clean up websocket client disconnects --- websocket_server.go | 38 +++++++++++++++++++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/websocket_server.go b/websocket_server.go index 525282bf..75146fde 100644 --- a/websocket_server.go +++ b/websocket_server.go @@ -3,6 +3,7 @@ package nex import ( "fmt" "net/http" + "strings" "time" "github.com/lxzan/gws" @@ -22,7 +23,42 @@ func (wseh *wsEventHandler) OnOpen(socket *gws.Conn) { } func (wseh *wsEventHandler) OnClose(socket *gws.Conn, err error) { - // TODO - Client clean up + clientsToCleanup := make([]*PRUDPClient, 0) + + // * Loop over all bound ports, and each ports stream types + // * to look for clients connecting from this WebSocket + // TODO - This kinda sucks tbh. Unsure how much this effects performance. Test more and refactor? + wseh.prudpServer.virtualServers.Each(func(port uint8, stream *MutexMap[uint8, *MutexMap[string, *PRUDPClient]]) bool { + stream.Each(func(streamType uint8, clients *MutexMap[string, *PRUDPClient]) bool { + clients.Each(func(discriminator string, client *PRUDPClient) bool { + if strings.HasPrefix(discriminator, socket.RemoteAddr().String()) { + clientsToCleanup = append(clientsToCleanup, client) + return true // * Assume only one client connected per server port per stream type + } + + return false + }) + + return false + }) + + return false + }) + + // * We cannot modify a MutexMap while looping over it + // * since the mutex is locked. We first need to grab + // * the entries we want to delete, and then loop over + // * them here to actually clean them up + for _, client := range clientsToCleanup { + client.cleanup() // * "removed" event is dispatched here + + virtualServer, _ := wseh.prudpServer.virtualServers.Get(client.DestinationPort) + virtualServerStream, _ := virtualServer.Get(client.DestinationStreamType) + + discriminator := fmt.Sprintf("%s-%d-%d", client.address.String(), client.SourcePort, client.SourceStreamType) + + virtualServerStream.Delete(discriminator) + } } func (wseh *wsEventHandler) OnPing(socket *gws.Conn, payload []byte) { From 71f8d8629634483dc7fb83429131413b85be3e77 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Wed, 13 Dec 2023 19:55:50 -0500 Subject: [PATCH 074/178] prudp: added PRUDPv0CustomChecksumCalculator to PRUDPServer --- prudp_packet_v0.go | 8 +++++- prudp_server.go | 61 +++++++++++++++++++++++----------------------- 2 files changed, 38 insertions(+), 31 deletions(-) diff --git a/prudp_packet_v0.go b/prudp_packet_v0.go index 5acb9651..adfe43c1 100644 --- a/prudp_packet_v0.go +++ b/prudp_packet_v0.go @@ -230,7 +230,13 @@ func (p *PRUDPPacketV0) Bytes() []byte { stream.WriteBytesNext(p.payload) } - checksum := p.calculateChecksum(stream.Bytes()) + var checksum uint32 + + if p.server.PRUDPv0CustomChecksumCalculator != nil { + checksum = p.server.PRUDPv0CustomChecksumCalculator(p, stream.Bytes()) + } else { + checksum = p.calculateChecksum(stream.Bytes()) + } if server.EnhancedChecksum { stream.WriteUInt32LE(checksum) diff --git a/prudp_server.go b/prudp_server.go index ec4936fa..8d350276 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -16,36 +16,37 @@ import ( // PRUDPServer represents a bare-bones PRUDP server type PRUDPServer struct { - udpSocket *net.UDPConn - websocketServer *WebSocketServer - PRUDPVersion int - PRUDPMinorVersion uint32 - virtualServers *MutexMap[uint8, *MutexMap[uint8, *MutexMap[string, *PRUDPClient]]] - IsQuazalMode bool - VirtualServerPorts []uint8 - SecureVirtualServerPorts []uint8 - SupportedFunctions uint32 - accessKey string - kerberosPassword []byte - kerberosTicketVersion int - kerberosKeySize int - FragmentSize int - version *LibraryVersion - datastoreProtocolVersion *LibraryVersion - matchMakingProtocolVersion *LibraryVersion - rankingProtocolVersion *LibraryVersion - ranking2ProtocolVersion *LibraryVersion - messagingProtocolVersion *LibraryVersion - utilityProtocolVersion *LibraryVersion - natTraversalProtocolVersion *LibraryVersion - prudpEventHandlers map[string][]func(packet PacketInterface) - clientRemovedEventHandlers []func(client *PRUDPClient) - connectionIDCounter *Counter[uint32] - pingTimeout time.Duration - passwordFromPIDHandler func(pid *PID) (string, uint32) - PRUDPv1ConnectionSignatureKey []byte - EnhancedChecksum bool - CompressionEnabled bool + udpSocket *net.UDPConn + websocketServer *WebSocketServer + PRUDPVersion int + PRUDPMinorVersion uint32 + virtualServers *MutexMap[uint8, *MutexMap[uint8, *MutexMap[string, *PRUDPClient]]] + IsQuazalMode bool + VirtualServerPorts []uint8 + SecureVirtualServerPorts []uint8 + SupportedFunctions uint32 + accessKey string + kerberosPassword []byte + kerberosTicketVersion int + kerberosKeySize int + FragmentSize int + version *LibraryVersion + datastoreProtocolVersion *LibraryVersion + matchMakingProtocolVersion *LibraryVersion + rankingProtocolVersion *LibraryVersion + ranking2ProtocolVersion *LibraryVersion + messagingProtocolVersion *LibraryVersion + utilityProtocolVersion *LibraryVersion + natTraversalProtocolVersion *LibraryVersion + prudpEventHandlers map[string][]func(packet PacketInterface) + clientRemovedEventHandlers []func(client *PRUDPClient) + connectionIDCounter *Counter[uint32] + pingTimeout time.Duration + passwordFromPIDHandler func(pid *PID) (string, uint32) + PRUDPv1ConnectionSignatureKey []byte + EnhancedChecksum bool + CompressionEnabled bool + PRUDPv0CustomChecksumCalculator func(packet *PRUDPPacketV0, data []byte) uint32 } // OnData adds an event handler which is fired when a new DATA packet is received From 30d9545c6184b9d4f3f65b5defea7fe11996e19a Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Wed, 13 Dec 2023 20:05:11 -0500 Subject: [PATCH 075/178] configure the size of the Quazal::String length field --- hpp_server.go | 14 +++++++++++++- prudp_server.go | 12 ++++++++++++ server_interface.go | 2 ++ stream_in.go | 17 +++++++++++++++-- stream_out.go | 7 ++++++- 5 files changed, 48 insertions(+), 4 deletions(-) diff --git a/hpp_server.go b/hpp_server.go index eed0ecc0..b172dbc0 100644 --- a/hpp_server.go +++ b/hpp_server.go @@ -20,6 +20,7 @@ type HPPServer struct { natTraversalProtocolVersion *LibraryVersion dataHandlers []func(packet PacketInterface) passwordFromPIDHandler func(pid *PID) (string, uint32) + stringLengthSize int } // OnData adds an event handler which is fired when a new HPP request is received @@ -258,9 +259,20 @@ func (s *HPPServer) SetPasswordFromPIDFunction(handler func(pid *PID) (string, u s.passwordFromPIDHandler = handler } +// StringLengthSize returns the size of the length field used for Quazal::String types +func (s *HPPServer) StringLengthSize() int { + return s.stringLengthSize +} + +// SetStringLengthSize sets the size of the length field used for Quazal::String types +func (s *HPPServer) SetStringLengthSize(size int) { + s.stringLengthSize = size +} + // NewHPPServer returns a new HPP server func NewHPPServer() *HPPServer { return &HPPServer{ - dataHandlers: make([]func(packet PacketInterface), 0), + dataHandlers: make([]func(packet PacketInterface), 0), + stringLengthSize: 2, } } diff --git a/prudp_server.go b/prudp_server.go index 8d350276..1500f6ef 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -47,6 +47,7 @@ type PRUDPServer struct { EnhancedChecksum bool CompressionEnabled bool PRUDPv0CustomChecksumCalculator func(packet *PRUDPPacketV0, data []byte) uint32 + stringLengthSize int } // OnData adds an event handler which is fired when a new DATA packet is received @@ -1102,6 +1103,16 @@ func (s *PRUDPServer) SetPasswordFromPIDFunction(handler func(pid *PID) (string, s.passwordFromPIDHandler = handler } +// StringLengthSize returns the size of the length field used for Quazal::String types +func (s *PRUDPServer) StringLengthSize() int { + return s.stringLengthSize +} + +// SetStringLengthSize sets the size of the length field used for Quazal::String types +func (s *PRUDPServer) SetStringLengthSize(size int) { + s.stringLengthSize = size +} + // NewPRUDPServer will return a new PRUDP server func NewPRUDPServer() *PRUDPServer { return &PRUDPServer{ @@ -1114,5 +1125,6 @@ func NewPRUDPServer() *PRUDPServer { prudpEventHandlers: make(map[string][]func(PacketInterface)), connectionIDCounter: NewCounter[uint32](10), pingTimeout: time.Second * 15, + stringLengthSize: 2, } } diff --git a/server_interface.go b/server_interface.go index 03a86a7e..f61fca74 100644 --- a/server_interface.go +++ b/server_interface.go @@ -17,4 +17,6 @@ type ServerInterface interface { OnData(handler func(packet PacketInterface)) PasswordFromPID(pid *PID) (string, uint32) SetPasswordFromPIDFunction(handler func(pid *PID) (string, uint32)) + StringLengthSize() int + SetStringLengthSize(size int) } diff --git a/stream_in.go b/stream_in.go index d0a46d93..cbe0794a 100644 --- a/stream_in.go +++ b/stream_in.go @@ -219,7 +219,20 @@ func (stream *StreamIn) ReadPID() (*PID, error) { // ReadString reads and returns a nex string type func (stream *StreamIn) ReadString() (string, error) { - length, err := stream.ReadUInt16LE() + var length int64 + var err error + + // TODO - These variable names kinda suck? + if stream.Server.StringLengthSize() == 4 { + l, e := stream.ReadUInt32LE() + length = int64(l) + err = e + } else { + l, e := stream.ReadUInt16LE() + length = int64(l) + err = e + } + if err != nil { return "", fmt.Errorf("Failed to read NEX string length. %s", err.Error()) } @@ -228,7 +241,7 @@ func (stream *StreamIn) ReadString() (string, error) { return "", errors.New("NEX string length longer than data size") } - stringData := stream.ReadBytesNext(int64(length)) + stringData := stream.ReadBytesNext(length) str := string(stringData) return strings.TrimRight(str, "\x00"), nil diff --git a/stream_out.go b/stream_out.go index 8a66a83f..9dd4cb28 100644 --- a/stream_out.go +++ b/stream_out.go @@ -142,8 +142,13 @@ func (stream *StreamOut) WriteString(str string) { str = str + "\x00" strLength := len(str) + if stream.Server.StringLengthSize() == 4 { + stream.WriteUInt32LE(uint32(strLength)) + } else { + stream.WriteUInt16LE(uint16(strLength)) + } + stream.Grow(int64(strLength)) - stream.WriteUInt16LE(uint16(strLength)) stream.WriteBytesNext([]byte(str)) } From add4019b444d1eb2a6907b58ef5d9702ce392057 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Wed, 13 Dec 2023 20:09:00 -0500 Subject: [PATCH 076/178] add WriteStationURL to StreamOut --- stream_out.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/stream_out.go b/stream_out.go index 9dd4cb28..c0dedb1a 100644 --- a/stream_out.go +++ b/stream_out.go @@ -206,6 +206,11 @@ func (stream *StreamOut) WriteStructure(structure StructureInterface) { stream.WriteBytesNext(content) } +// WriteStationURL writes a StationURL type +func (stream *StreamOut) WriteStationURL(stationURL *StationURL) { + stream.WriteString(stationURL.EncodeToString()) +} + // WriteListUInt8 writes a list of uint8 types func (stream *StreamOut) WriteListUInt8(list []uint8) { stream.WriteUInt32LE(uint32(len(list))) From b056f3492680b3ed3906a6203729489c1e4992b9 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Wed, 13 Dec 2023 20:15:24 -0500 Subject: [PATCH 077/178] update RVConnectionData to use new StationURL stream method --- types.go | 74 +++++++++++++++++++++----------------------------------- 1 file changed, 27 insertions(+), 47 deletions(-) diff --git a/types.go b/types.go index 35df9f41..f99fe962 100644 --- a/types.go +++ b/types.go @@ -342,41 +342,21 @@ func NewDataHolder() *DataHolder { // RVConnectionData represents a nex RVConnectionData type type RVConnectionData struct { Structure - stationURL string - specialProtocols []byte - stationURLSpecialProtocols string - time *DateTime -} - -// SetStationURL sets the RVConnectionData station URL -func (rvConnectionData *RVConnectionData) SetStationURL(stationURL string) { - rvConnectionData.stationURL = stationURL -} - -// SetSpecialProtocols sets the RVConnectionData special protocol list (unused by Nintendo) -func (rvConnectionData *RVConnectionData) SetSpecialProtocols(specialProtocols []byte) { - rvConnectionData.specialProtocols = specialProtocols -} - -// SetStationURLSpecialProtocols sets the RVConnectionData special station URL (unused by Nintendo) -func (rvConnectionData *RVConnectionData) SetStationURLSpecialProtocols(stationURLSpecialProtocols string) { - rvConnectionData.stationURLSpecialProtocols = stationURLSpecialProtocols -} - -// SetTime sets the RVConnectionData time -func (rvConnectionData *RVConnectionData) SetTime(time *DateTime) { - rvConnectionData.time = time + StationURL *StationURL + SpecialProtocols []byte + StationURLSpecialProtocols *StationURL + Time *DateTime } // Bytes encodes the RVConnectionData and returns a byte array func (rvConnectionData *RVConnectionData) Bytes(stream *StreamOut) []byte { - stream.WriteString(rvConnectionData.stationURL) - stream.WriteListUInt8(rvConnectionData.specialProtocols) - stream.WriteString(rvConnectionData.stationURLSpecialProtocols) + stream.WriteStationURL(rvConnectionData.StationURL) + stream.WriteListUInt8(rvConnectionData.SpecialProtocols) + stream.WriteStationURL(rvConnectionData.StationURLSpecialProtocols) if stream.Server.LibraryVersion().GreaterOrEqual("3.5.0") { rvConnectionData.SetStructureVersion(1) - stream.WriteDateTime(rvConnectionData.time) + stream.WriteDateTime(rvConnectionData.Time) } return stream.Bytes() @@ -388,15 +368,15 @@ func (rvConnectionData *RVConnectionData) Copy() StructureInterface { copied.SetStructureVersion(rvConnectionData.StructureVersion()) copied.parentType = rvConnectionData.parentType - copied.stationURL = rvConnectionData.stationURL - copied.specialProtocols = make([]byte, len(rvConnectionData.specialProtocols)) + copied.StationURL = rvConnectionData.StationURL.Copy() + copied.SpecialProtocols = make([]byte, len(rvConnectionData.SpecialProtocols)) - copy(copied.specialProtocols, rvConnectionData.specialProtocols) + copy(copied.SpecialProtocols, rvConnectionData.SpecialProtocols) - copied.stationURLSpecialProtocols = rvConnectionData.stationURLSpecialProtocols + copied.StationURLSpecialProtocols = rvConnectionData.StationURLSpecialProtocols.Copy() - if rvConnectionData.time != nil { - copied.time = rvConnectionData.time.Copy() + if rvConnectionData.Time != nil { + copied.Time = rvConnectionData.Time.Copy() } return copied @@ -410,28 +390,28 @@ func (rvConnectionData *RVConnectionData) Equals(structure StructureInterface) b return false } - if rvConnectionData.stationURL != other.stationURL { + if !rvConnectionData.StationURL.Equals(other.StationURL) { return false } - if !bytes.Equal(rvConnectionData.specialProtocols, other.specialProtocols) { + if !bytes.Equal(rvConnectionData.SpecialProtocols, other.SpecialProtocols) { return false } - if rvConnectionData.stationURLSpecialProtocols != other.stationURLSpecialProtocols { + if !rvConnectionData.StationURLSpecialProtocols.Equals(other.StationURLSpecialProtocols) { return false } - if rvConnectionData.time != nil && other.time == nil { + if rvConnectionData.Time != nil && other.Time == nil { return false } - if rvConnectionData.time == nil && other.time != nil { + if rvConnectionData.Time == nil && other.Time != nil { return false } - if rvConnectionData.time != nil && other.time != nil { - if !rvConnectionData.time.Equals(other.time) { + if rvConnectionData.Time != nil && other.Time != nil { + if !rvConnectionData.Time.Equals(other.Time) { return false } } @@ -453,14 +433,14 @@ func (rvConnectionData *RVConnectionData) FormatToString(indentationLevel int) s b.WriteString("RVConnectionData{\n") b.WriteString(fmt.Sprintf("%sstructureVersion: %d,\n", indentationValues, rvConnectionData.structureVersion)) - b.WriteString(fmt.Sprintf("%sstationURL: %q,\n", indentationValues, rvConnectionData.stationURL)) - b.WriteString(fmt.Sprintf("%sspecialProtocols: %v,\n", indentationValues, rvConnectionData.specialProtocols)) - b.WriteString(fmt.Sprintf("%sstationURLSpecialProtocols: %q,\n", indentationValues, rvConnectionData.stationURLSpecialProtocols)) + b.WriteString(fmt.Sprintf("%sStationURL: %q,\n", indentationValues, rvConnectionData.StationURL.FormatToString(indentationLevel+1))) + b.WriteString(fmt.Sprintf("%sSpecialProtocols: %v,\n", indentationValues, rvConnectionData.SpecialProtocols)) + b.WriteString(fmt.Sprintf("%sStationURLSpecialProtocols: %q,\n", indentationValues, rvConnectionData.StationURLSpecialProtocols.FormatToString(indentationLevel+1))) - if rvConnectionData.time != nil { - b.WriteString(fmt.Sprintf("%stime: %s\n", indentationValues, rvConnectionData.time.FormatToString(indentationLevel+1))) + if rvConnectionData.Time != nil { + b.WriteString(fmt.Sprintf("%sTime: %s\n", indentationValues, rvConnectionData.Time.FormatToString(indentationLevel+1))) } else { - b.WriteString(fmt.Sprintf("%stime: nil\n", indentationValues)) + b.WriteString(fmt.Sprintf("%sTime: nil\n", indentationValues)) } b.WriteString(fmt.Sprintf("%s}", indentationEnd)) From e2b443aa191b6ab9345ce7d2551bd9fa17608f6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20L=C3=B3pez=20Guimaraes?= Date: Thu, 14 Dec 2023 19:36:39 +0000 Subject: [PATCH 078/178] StationURL: return empty string when no scheme is set --- types.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/types.go b/types.go index f99fe962..4ac959f4 100644 --- a/types.go +++ b/types.go @@ -617,6 +617,11 @@ func (s *StationURL) FromString(str string) { // EncodeToString encodes the StationURL into a string func (s *StationURL) EncodeToString() string { + // * Don't return anything if no scheme is set + if s.Scheme == "" { + return "" + } + fields := []string{} s.Fields.Each(func(key, value string) bool { From 40a4dfa1ab79b912d037ade42c50c3f9b17bb840 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20L=C3=B3pez=20Guimaraes?= Date: Thu, 14 Dec 2023 19:37:31 +0000 Subject: [PATCH 079/178] test: Update with RVConnectionData changes --- test/auth.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/auth.go b/test/auth.go index 35fe872e..08f8904e 100644 --- a/test/auth.go +++ b/test/auth.go @@ -66,10 +66,10 @@ func login(packet nex.PRUDPPacketInterface) { pConnectionData := nex.NewRVConnectionData() strReturnMsg := "Test Build" - pConnectionData.SetStationURL("prudps:/address=192.168.1.98;port=60001;CID=1;PID=2;sid=1;stream=10;type=2") - pConnectionData.SetSpecialProtocols([]byte{}) - pConnectionData.SetStationURLSpecialProtocols("") - pConnectionData.SetTime(nex.NewDateTime(0).Now()) + pConnectionData.StationURL = nex.NewStationURL("prudps:/address=192.168.1.98;port=60001;CID=1;PID=2;sid=1;stream=10;type=2") + pConnectionData.SpecialProtocols = []byte{} + pConnectionData.StationURLSpecialProtocols = nex.NewStationURL("") + pConnectionData.Time = nex.NewDateTime(0).Now() responseStream := nex.NewStreamOut(authServer) From 5d741dc85f2ca5d29518a652e5d77a5c8a4df265 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20L=C3=B3pez=20Guimaraes?= Date: Thu, 14 Dec 2023 19:58:51 +0000 Subject: [PATCH 080/178] hpp: Add ListenSecure function This one uses TLS by itself, instead of depending on the implementation to make the HTTPS server. --- hpp_server.go | 32 +++++++++++++++++++++++++++++--- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/hpp_server.go b/hpp_server.go index b172dbc0..4a285f9f 100644 --- a/hpp_server.go +++ b/hpp_server.go @@ -1,6 +1,7 @@ package nex import ( + "crypto/tls" "fmt" "net" "net/http" @@ -9,6 +10,7 @@ import ( // HPPServer represents a bare-bones HPP server type HPPServer struct { + server *http.Server accessKey string version *LibraryVersion datastoreProtocolVersion *LibraryVersion @@ -129,9 +131,19 @@ func (s *HPPServer) handleRequest(w http.ResponseWriter, req *http.Request) { // Listen starts a HPP server on a given port func (s *HPPServer) Listen(port int) { - http.HandleFunc("/hpp/", s.handleRequest) + s.server.Addr = fmt.Sprintf(":%d", port) - err := http.ListenAndServe(fmt.Sprintf(":%d", port), nil) + err := s.server.ListenAndServe() + if err != nil { + panic(err) + } +} + +// ListenSecure starts a HPP server on a given port using a secure (TLS) server +func (s *HPPServer) ListenSecure(port int, certFile, keyFile string) { + s.server.Addr = fmt.Sprintf(":%d", port) + + err := s.server.ListenAndServeTLS(certFile, keyFile) if err != nil { panic(err) } @@ -271,8 +283,22 @@ func (s *HPPServer) SetStringLengthSize(size int) { // NewHPPServer returns a new HPP server func NewHPPServer() *HPPServer { - return &HPPServer{ + s := &HPPServer{ dataHandlers: make([]func(packet PacketInterface), 0), stringLengthSize: 2, } + + mux := http.NewServeMux() + mux.HandleFunc("/hpp/", s.handleRequest) + + httpServer := &http.Server{ + Handler: mux, + TLSConfig: &tls.Config{ + MinVersion: tls.VersionTLS11, // * The 3DS and Wii U only support up to TLS 1.1 natively + }, + } + + s.server = httpServer + + return s } From 9e5f8bf781b9722a8200792bfd5af4fe8c49fd72 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20L=C3=B3pez=20Guimaraes?= Date: Thu, 14 Dec 2023 20:20:00 +0000 Subject: [PATCH 081/178] websocket: Panic if ListenAndServe fails --- websocket_server.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/websocket_server.go b/websocket_server.go index 75146fde..8a3f1587 100644 --- a/websocket_server.go +++ b/websocket_server.go @@ -116,7 +116,7 @@ func (ws *WebSocketServer) listen(port int) { err := http.ListenAndServe(fmt.Sprintf(":%d", port), ws.mux) if err != nil { - logger.Error(err.Error()) + panic(err) } } @@ -125,6 +125,6 @@ func (ws *WebSocketServer) listenSecure(port int, certFile, keyFile string) { err := http.ListenAndServeTLS(fmt.Sprintf(":%d", port), certFile, keyFile, ws.mux) if err != nil { - logger.Error(err.Error()) + panic(err) } } From f318d31ed8bc7e12dd7c0b79c4f0ec698f67b92e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20L=C3=B3pez=20Guimaraes?= Date: Thu, 14 Dec 2023 20:20:46 +0000 Subject: [PATCH 082/178] README: Uncheck support for "verbose" RMC The verbose RMC is not supported at the moment. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 0c5effbf..0c1600b6 100644 --- a/README.md +++ b/README.md @@ -44,7 +44,7 @@ While the main goal of this library is to support games which use the NEX varian - [x] Response messages - [x] "Packed" encoded messages - [x] "Packed" (extended) encoded messages - - [x] "Verbose" encoded messages + - [ ] "Verbose" encoded messages - [x] [Kerberos authentication](https://nintendo-wiki.pretendo.network/docs/nex/kerberos) ### Example From 816f8e2edddac246335ff4fbac51512ff0de2cba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20L=C3=B3pez=20Guimaraes?= Date: Thu, 14 Dec 2023 21:21:57 +0000 Subject: [PATCH 083/178] StationURL: More checks on FromString --- types.go | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/types.go b/types.go index 4ac959f4..571e6203 100644 --- a/types.go +++ b/types.go @@ -600,9 +600,19 @@ func (s *StationURL) IsPublic() bool { // FromString parses the StationURL data from a string func (s *StationURL) FromString(str string) { + if str == "" { + return + } + split := strings.Split(str, ":/") s.Scheme = split[0] + + // * Return if there are no fields + if split[1] == "" { + return + } + fields := strings.Split(split[1], ";") for i := 0; i < len(fields); i++ { @@ -667,9 +677,7 @@ func NewStationURL(str string) *StationURL { Fields: NewMutexMap[string, string](), } - if str != "" { - stationURL.FromString(str) - } + stationURL.FromString(str) return stationURL } From ffaf7278eb84b7b44f74e03281e37cfee92d2d1c Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Thu, 14 Dec 2023 17:17:44 -0500 Subject: [PATCH 084/178] prudp: export payload compression methods to support custom compression --- prudp_server.go | 39 +++++++++++++++++++++------------------ 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/prudp_server.go b/prudp_server.go index 1500f6ef..61059568 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -616,12 +616,16 @@ func (s *PRUDPServer) handleReliable(packet PRUDPPacketInterface) { decryptedPayload = pendingPacket.Payload() } - decompressedPayload, err := s.decompressPayload(decryptedPayload) - if err != nil { - logger.Error(err.Error()) + if s.CompressionEnabled { + decompressedPayload, err := s.DecompressPayload(decryptedPayload) + if err != nil { + logger.Error(err.Error()) + } + + decryptedPayload = decompressedPayload } - payload := substream.AddFragment(decompressedPayload) + payload := substream.AddFragment(decryptedPayload) if packet.getFragmentID() == 0 { message := NewRMCMessage() @@ -774,9 +778,14 @@ func (s *PRUDPServer) sendPacket(packet PRUDPPacketInterface) { if packetCopy.Type() == DataPacket && !packetCopy.HasFlag(FlagAck) && !packetCopy.HasFlag(FlagMultiAck) { if packetCopy.HasFlag(FlagReliable) { payload := packetCopy.Payload() - compressedPayload, err := s.compressPayload(payload) - if err != nil { - logger.Error(err.Error()) + + if s.CompressionEnabled { + compressedPayload, err := s.CompressPayload(payload) + if err != nil { + logger.Error(err.Error()) + } + + payload = compressedPayload } substream := client.reliableSubstream(packetCopy.SubstreamID()) @@ -791,7 +800,7 @@ func (s *PRUDPServer) sendPacket(packet PRUDPPacketInterface) { // * PRUDPLite packet. No RC4 if packetCopy.Version() != 2 { - packetCopy.SetPayload(substream.Encrypt(compressedPayload)) + packetCopy.SetPayload(substream.Encrypt(payload)) } } else { // * PRUDPLite packet. No RC4 @@ -828,11 +837,8 @@ func (s *PRUDPServer) sendRaw(client *PRUDPClient, data []byte) { } } -func (s *PRUDPServer) decompressPayload(payload []byte) ([]byte, error) { - if !s.CompressionEnabled { - return payload, nil - } - +// DecompressPayload handles the decompression of DATA packet payloads. By default this uses zlib compression +func (s *PRUDPServer) DecompressPayload(payload []byte) ([]byte, error) { compressionRatio := payload[0] compressed := payload[1:] @@ -873,11 +879,8 @@ func (s *PRUDPServer) decompressPayload(payload []byte) ([]byte, error) { return decompressedBytes, nil } -func (s *PRUDPServer) compressPayload(payload []byte) ([]byte, error) { - if !s.CompressionEnabled { - return payload, nil - } - +// CompressPayload handles the compression of DATA packet payloads. By default this uses zlib compression +func (s *PRUDPServer) CompressPayload(payload []byte) ([]byte, error) { compressed := bytes.Buffer{} // * Create a zlib writer with default compression level From 95e8f79957982011a585fce2f2bd191793344048 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Thu, 14 Dec 2023 17:39:52 -0500 Subject: [PATCH 085/178] prudp: fix payload compression exports --- prudp_server.go | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/prudp_server.go b/prudp_server.go index 61059568..a735a14e 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -48,6 +48,8 @@ type PRUDPServer struct { CompressionEnabled bool PRUDPv0CustomChecksumCalculator func(packet *PRUDPPacketV0, data []byte) uint32 stringLengthSize int + CustomPayloadCompressor func(payload []byte) ([]byte, error) + CustomPayloadDecompressor func(payload []byte) ([]byte, error) } // OnData adds an event handler which is fired when a new DATA packet is received @@ -617,7 +619,7 @@ func (s *PRUDPServer) handleReliable(packet PRUDPPacketInterface) { } if s.CompressionEnabled { - decompressedPayload, err := s.DecompressPayload(decryptedPayload) + decompressedPayload, err := s.decompressPayload(decryptedPayload) if err != nil { logger.Error(err.Error()) } @@ -780,7 +782,7 @@ func (s *PRUDPServer) sendPacket(packet PRUDPPacketInterface) { payload := packetCopy.Payload() if s.CompressionEnabled { - compressedPayload, err := s.CompressPayload(payload) + compressedPayload, err := s.compressPayload(payload) if err != nil { logger.Error(err.Error()) } @@ -837,8 +839,11 @@ func (s *PRUDPServer) sendRaw(client *PRUDPClient, data []byte) { } } -// DecompressPayload handles the decompression of DATA packet payloads. By default this uses zlib compression -func (s *PRUDPServer) DecompressPayload(payload []byte) ([]byte, error) { +func (s *PRUDPServer) decompressPayload(payload []byte) ([]byte, error) { + if s.CustomPayloadDecompressor != nil { + return s.CustomPayloadDecompressor(payload) + } + compressionRatio := payload[0] compressed := payload[1:] @@ -879,8 +884,11 @@ func (s *PRUDPServer) DecompressPayload(payload []byte) ([]byte, error) { return decompressedBytes, nil } -// CompressPayload handles the compression of DATA packet payloads. By default this uses zlib compression -func (s *PRUDPServer) CompressPayload(payload []byte) ([]byte, error) { +func (s *PRUDPServer) compressPayload(payload []byte) ([]byte, error) { + if s.CustomPayloadCompressor != nil { + return s.CustomPayloadCompressor(payload) + } + compressed := bytes.Buffer{} // * Create a zlib writer with default compression level From 49761ddbf9159f62f245c2a8e5412b76eb391e36 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Fri, 15 Dec 2023 01:55:35 -0500 Subject: [PATCH 086/178] rmc: support verbose encoded messages --- README.md | 2 +- rmc_message.go | 174 ++++++++++++++++++++++++++++++++++++++++++++++--- stream_in.go | 21 ++++-- stream_out.go | 19 ++++-- types.go | 104 +++++++++++++++++++++++++++++ 5 files changed, 295 insertions(+), 25 deletions(-) diff --git a/README.md b/README.md index 0c1600b6..0c5effbf 100644 --- a/README.md +++ b/README.md @@ -44,7 +44,7 @@ While the main goal of this library is to support games which use the NEX varian - [x] Response messages - [x] "Packed" encoded messages - [x] "Packed" (extended) encoded messages - - [ ] "Verbose" encoded messages + - [x] "Verbose" encoded messages - [x] [Kerberos authentication](https://nintendo-wiki.pretendo.network/docs/nex/kerberos) ### Example diff --git a/rmc_message.go b/rmc_message.go index 0d5933fe..20a66c28 100644 --- a/rmc_message.go +++ b/rmc_message.go @@ -7,14 +7,19 @@ import ( // RMCMessage represents a message in the RMC (Remote Method Call) protocol type RMCMessage struct { - IsRequest bool // * Indicates if the message is a request message (true) or response message (false) - IsSuccess bool // * Indicates if the message is a success message (true) for a response message - IsHPP bool // * Indicates if the message is an HPP message - ProtocolID uint16 // * Protocol ID of the message - CallID uint32 // * Call ID associated with the message - MethodID uint32 // * Method ID in the requested protocol - ErrorCode uint32 // * Error code for a response message - Parameters []byte // * Input for the method + VerboseMode bool // * Determines whether or not to encode the message using the "verbose" encoding method + IsRequest bool // * Indicates if the message is a request message (true) or response message (false) + IsSuccess bool // * Indicates if the message is a success message (true) for a response message + IsHPP bool // * Indicates if the message is an HPP message + ProtocolID uint16 // * Protocol ID of the message. Only present in "packed" variations + ProtocolName string // * Protocol name of the message. Only present in "verbose" variations + CallID uint32 // * Call ID associated with the message + MethodID uint32 // * Method ID in the requested protocol. Only present in "packed" variations + MethodName string // * Method name in the requested protocol. Only present in "verbose" variations + ErrorCode uint32 // * Error code for a response message + ClassVersionContainer *ClassVersionContainer // * Contains version info for Structures in the request. Only present in "verbose" variations + Parameters []byte // * Input for the method + // TODO - Verbose messages suffix response method names with "*". Should we have a "HasResponsePointer" sort of field? } // Copy copies the message into a new RMCMessage @@ -25,8 +30,10 @@ func (rmc *RMCMessage) Copy() *RMCMessage { copied.IsSuccess = rmc.IsSuccess copied.IsHPP = rmc.IsHPP copied.ProtocolID = rmc.ProtocolID + copied.ProtocolName = rmc.ProtocolName copied.CallID = rmc.CallID copied.MethodID = rmc.MethodID + copied.MethodName = rmc.MethodName copied.ErrorCode = rmc.ErrorCode if rmc.Parameters != nil { @@ -38,6 +45,14 @@ func (rmc *RMCMessage) Copy() *RMCMessage { // FromBytes decodes an RMCMessage from the given byte slice. func (rmc *RMCMessage) FromBytes(data []byte) error { + if rmc.VerboseMode { + return rmc.decodeVerbose(data) + } else { + return rmc.decodePacked(data) + } +} + +func (rmc *RMCMessage) decodePacked(data []byte) error { stream := NewStreamIn(data, nil) length, err := stream.ReadUInt32LE() @@ -124,8 +139,96 @@ func (rmc *RMCMessage) FromBytes(data []byte) error { return nil } +func (rmc *RMCMessage) decodeVerbose(data []byte) error { + stream := NewStreamIn(data, nil) + + length, err := stream.ReadUInt32LE() + if err != nil { + return fmt.Errorf("Failed to read RMC Message size. %s", err.Error()) + } + + if stream.Remaining() != int(length) { + return errors.New("RMC Message has unexpected size") + } + + rmc.ProtocolName, err = stream.ReadString() + if err != nil { + return fmt.Errorf("Failed to read RMC Message protocol name. %s", err.Error()) + } + + rmc.IsRequest, err = stream.ReadBool() + if err != nil { + return fmt.Errorf("Failed to read RMC Message \"is request\" bool. %s", err.Error()) + } + + if rmc.IsRequest { + rmc.CallID, err = stream.ReadUInt32LE() + if err != nil { + return fmt.Errorf("Failed to read RMC Message (request) call ID. %s", err.Error()) + } + + rmc.MethodName, err = stream.ReadString() + if err != nil { + return fmt.Errorf("Failed to read RMC Message (request) method name. %s", err.Error()) + } + + rmc.ClassVersionContainer, err = StreamReadStructure(stream, NewClassVersionContainer()) + if err != nil { + return fmt.Errorf("Failed to read RMC Message ClassVersionContainer. %s", err.Error()) + } + + rmc.Parameters = stream.ReadRemaining() + if err != nil { + return fmt.Errorf("Failed to read RMC Message (request) parameters. %s", err.Error()) + } + } else { + rmc.IsSuccess, err = stream.ReadBool() + if err != nil { + return fmt.Errorf("Failed to read RMC Message (response) error check. %s", err.Error()) + } + + if rmc.IsSuccess { + rmc.CallID, err = stream.ReadUInt32LE() + if err != nil { + return fmt.Errorf("Failed to read RMC Message (response) call ID. %s", err.Error()) + } + + rmc.MethodName, err = stream.ReadString() + if err != nil { + return fmt.Errorf("Failed to read RMC Message (response) method name. %s", err.Error()) + } + + rmc.Parameters = stream.ReadRemaining() + if err != nil { + return fmt.Errorf("Failed to read RMC Message (response) parameters. %s", err.Error()) + } + + } else { + rmc.ErrorCode, err = stream.ReadUInt32LE() + if err != nil { + return fmt.Errorf("Failed to read RMC Message (response) error code. %s", err.Error()) + } + + rmc.CallID, err = stream.ReadUInt32LE() + if err != nil { + return fmt.Errorf("Failed to read RMC Message (response) call ID. %s", err.Error()) + } + } + } + + return nil +} + // Bytes serializes the RMCMessage to a byte slice. func (rmc *RMCMessage) Bytes() []byte { + if rmc.VerboseMode { + return rmc.encodeVerbose() + } else { + return rmc.encodePacked() + } +} + +func (rmc *RMCMessage) encodePacked() []byte { stream := NewStreamOut(nil) // * RMC requests have their protocol IDs ORed with 0x80 @@ -149,21 +252,72 @@ func (rmc *RMCMessage) Bytes() []byte { if rmc.IsRequest { stream.WriteUInt32LE(rmc.CallID) stream.WriteUInt32LE(rmc.MethodID) + if rmc.Parameters != nil && len(rmc.Parameters) > 0 { stream.Grow(int64(len(rmc.Parameters))) stream.WriteBytesNext(rmc.Parameters) } } else { + stream.WriteBool(rmc.IsSuccess) + if rmc.IsSuccess { - stream.WriteBool(true) stream.WriteUInt32LE(rmc.CallID) stream.WriteUInt32LE(rmc.MethodID | 0x8000) + + if rmc.Parameters != nil && len(rmc.Parameters) > 0 { + stream.Grow(int64(len(rmc.Parameters))) + stream.WriteBytesNext(rmc.Parameters) + } + } else { + stream.WriteUInt32LE(uint32(rmc.ErrorCode)) + stream.WriteUInt32LE(rmc.CallID) + } + } + + serialized := stream.Bytes() + + message := NewStreamOut(nil) + + message.WriteUInt32LE(uint32(len(serialized))) + message.Grow(int64(len(serialized))) + message.WriteBytesNext(serialized) + + return message.Bytes() +} + +func (rmc *RMCMessage) encodeVerbose() []byte { + stream := NewStreamOut(nil) + + stream.WriteString(rmc.ProtocolName) + stream.WriteBool(rmc.IsRequest) + + if rmc.IsRequest { + stream.WriteUInt32LE(rmc.CallID) + stream.WriteString(rmc.MethodName) + + if rmc.ClassVersionContainer != nil { + stream.WriteStructure(rmc.ClassVersionContainer) + } else { + // * Fail safe. This is always present even if no structures are used + stream.WriteUInt32LE(0) + } + + if rmc.Parameters != nil && len(rmc.Parameters) > 0 { + stream.Grow(int64(len(rmc.Parameters))) + stream.WriteBytesNext(rmc.Parameters) + } + } else { + stream.WriteBool(rmc.IsSuccess) + + if rmc.IsSuccess { + stream.WriteUInt32LE(rmc.CallID) + stream.WriteString(rmc.MethodName) + if rmc.Parameters != nil && len(rmc.Parameters) > 0 { stream.Grow(int64(len(rmc.Parameters))) stream.WriteBytesNext(rmc.Parameters) } } else { - stream.WriteBool(false) stream.WriteUInt32LE(uint32(rmc.ErrorCode)) stream.WriteUInt32LE(rmc.CallID) } diff --git a/stream_in.go b/stream_in.go index cbe0794a..c698ddf1 100644 --- a/stream_in.go +++ b/stream_in.go @@ -223,7 +223,11 @@ func (stream *StreamIn) ReadString() (string, error) { var err error // TODO - These variable names kinda suck? - if stream.Server.StringLengthSize() == 4 { + if stream.Server == nil { + l, e := stream.ReadUInt16LE() + length = int64(l) + err = e + } else if stream.Server.StringLengthSize() == 4 { l, e := stream.ReadUInt32LE() length = int64(l) err = e @@ -918,12 +922,15 @@ func StreamReadStructure[T StructureInterface](stream *StreamIn, structure T) (T } } - var useStructureHeader bool - switch server := stream.Server.(type) { - case *PRUDPServer: // * Support QRV versions - useStructureHeader = server.PRUDPMinorVersion >= 3 - default: - useStructureHeader = server.LibraryVersion().GreaterOrEqual("3.5.0") + useStructureHeader := false + + if stream.Server != nil { + switch server := stream.Server.(type) { + case *PRUDPServer: // * Support QRV versions + useStructureHeader = server.PRUDPMinorVersion >= 3 + default: + useStructureHeader = server.LibraryVersion().GreaterOrEqual("3.5.0") + } } if useStructureHeader { diff --git a/stream_out.go b/stream_out.go index c0dedb1a..151e6938 100644 --- a/stream_out.go +++ b/stream_out.go @@ -142,7 +142,9 @@ func (stream *StreamOut) WriteString(str string) { str = str + "\x00" strLength := len(str) - if stream.Server.StringLengthSize() == 4 { + if stream.Server == nil { + stream.WriteUInt16LE(uint16(strLength)) + } else if stream.Server.StringLengthSize() == 4 { stream.WriteUInt32LE(uint32(strLength)) } else { stream.WriteUInt16LE(uint16(strLength)) @@ -189,12 +191,15 @@ func (stream *StreamOut) WriteStructure(structure StructureInterface) { content := structure.Bytes(NewStreamOut(stream.Server)) - var useStructures bool - switch server := stream.Server.(type) { - case *PRUDPServer: // * Support QRV versions - useStructures = server.PRUDPMinorVersion >= 3 - default: - useStructures = server.LibraryVersion().GreaterOrEqual("3.5.0") + useStructures := false + + if stream.Server != nil { + switch server := stream.Server.(type) { + case *PRUDPServer: // * Support QRV versions + useStructures = server.PRUDPMinorVersion >= 3 + default: + useStructures = server.LibraryVersion().GreaterOrEqual("3.5.0") + } } if useStructures { diff --git a/types.go b/types.go index 571e6203..b987e4b2 100644 --- a/types.go +++ b/types.go @@ -1008,3 +1008,107 @@ func (variant *Variant) FormatToString(indentationLevel int) string { func NewVariant() *Variant { return &Variant{} } + +// ClassVersionContainer contains version info for structurs used in verbose RMC messages +type ClassVersionContainer struct { + Structure + ClassVersions map[string]uint16 +} + +// ExtractFromStream extracts a ClassVersionContainer structure from a stream +func (cvc *ClassVersionContainer) ExtractFromStream(stream *StreamIn) error { + length, err := stream.ReadUInt32LE() + if err != nil { + return fmt.Errorf("Failed to read ClassVersionContainer length. %s", err.Error()) + } + + for i := 0; i < int(length); i++ { + name, err := stream.ReadString() + if err != nil { + return fmt.Errorf("Failed to read ClassVersionContainer Structure name. %s", err.Error()) + } + + version, err := stream.ReadUInt16LE() + if err != nil { + return fmt.Errorf("Failed to read ClassVersionContainer %s version. %s", name, err.Error()) + } + + cvc.ClassVersions[name] = version + } + + return nil +} + +// Bytes encodes the ClassVersionContainer and returns a byte array +func (cvc *ClassVersionContainer) Bytes(stream *StreamOut) []byte { + stream.WriteUInt32LE(uint32(len(cvc.ClassVersions))) + + for name, version := range cvc.ClassVersions { + stream.WriteString(name) + stream.WriteUInt16LE(version) + } + + return stream.Bytes() +} + +// Copy returns a new copied instance of ClassVersionContainer +func (cvc *ClassVersionContainer) Copy() StructureInterface { + copied := NewClassVersionContainer() + + for name, version := range cvc.ClassVersions { + copied.ClassVersions[name] = version + } + + return copied +} + +// Equals checks if the passed Structure contains the same data as the current instance +func (cvc *ClassVersionContainer) Equals(structure StructureInterface) bool { + other := structure.(*ClassVersionContainer) + + if len(cvc.ClassVersions) != len(other.ClassVersions) { + return false + } + + for name, version1 := range cvc.ClassVersions { + version2, ok := other.ClassVersions[name] + if !ok || version1 != version2 { + return false + } + } + + return true +} + +// String returns a string representation of the struct +func (cvc *ClassVersionContainer) String() string { + return cvc.FormatToString(0) +} + +// FormatToString pretty-prints the struct data using the provided indentation level +func (cvc *ClassVersionContainer) FormatToString(indentationLevel int) string { + indentationValues := strings.Repeat("\t", indentationLevel+1) + indentationListValues := strings.Repeat("\t", indentationLevel+2) + indentationEnd := strings.Repeat("\t", indentationLevel) + + var b strings.Builder + + b.WriteString("ClassVersionContainer{\n") + b.WriteString(fmt.Sprintf("%sClassVersions: {\n", indentationValues)) + + for name, version := range cvc.ClassVersions { + b.WriteString(fmt.Sprintf("%s%s: %d\n", indentationListValues, name, version)) + } + + b.WriteString(fmt.Sprintf("%s}\n", indentationValues)) + b.WriteString(fmt.Sprintf("%s}", indentationEnd)) + + return b.String() +} + +// NewClassVersionContainer returns a new ClassVersionContainer +func NewClassVersionContainer() *ClassVersionContainer { + return &ClassVersionContainer{ + ClassVersions: make(map[string]uint16), + } +} From ce2e746e96dfaf7df3a7fd63f94b62517caed5d3 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Fri, 15 Dec 2023 16:54:52 -0500 Subject: [PATCH 087/178] prudp: support LZO compressed payloads --- compression/algorithm_interface.go | 9 +++ compression/dummy.go | 19 +++++ compression/lzo.go | 81 +++++++++++++++++++++ compression/zlib.go | 83 ++++++++++++++++++++++ go.mod | 1 + go.sum | 2 + prudp_server.go | 110 +++-------------------------- 7 files changed, 206 insertions(+), 99 deletions(-) create mode 100644 compression/algorithm_interface.go create mode 100644 compression/dummy.go create mode 100644 compression/lzo.go create mode 100644 compression/zlib.go diff --git a/compression/algorithm_interface.go b/compression/algorithm_interface.go new file mode 100644 index 00000000..6a09124a --- /dev/null +++ b/compression/algorithm_interface.go @@ -0,0 +1,9 @@ +// Package compression provides a set of compression algorithms found +// in several versions of Rendez-Vous for compressing large payloads +package compression + +// Algorithm defines all the methods a compression algorithm should have +type Algorithm interface { + Compress(payload []byte) ([]byte, error) + Decompress(payload []byte) ([]byte, error) +} diff --git a/compression/dummy.go b/compression/dummy.go new file mode 100644 index 00000000..33a1080b --- /dev/null +++ b/compression/dummy.go @@ -0,0 +1,19 @@ +package compression + +// Dummy does no compression. Payloads are returned as-is +type Dummy struct{} + +// Compress does nothing +func (d *Dummy) Compress(payload []byte) ([]byte, error) { + return payload, nil +} + +// Decompress does nothing +func (d *Dummy) Decompress(payload []byte) ([]byte, error) { + return payload, nil +} + +// NewDummyCompression returns a new instance of the Dummy compression +func NewDummyCompression() *Dummy { + return &Dummy{} +} diff --git a/compression/lzo.go b/compression/lzo.go new file mode 100644 index 00000000..55c4a284 --- /dev/null +++ b/compression/lzo.go @@ -0,0 +1,81 @@ +package compression + +import ( + "bytes" + "fmt" + + "github.com/cyberdelia/lzo" +) + +// TODO - Untested. I think this works. Maybe. Verify and remove this comment + +// LZO implements packet payload compression using LZO +type LZO struct{} + +// Compress compresses the payload using LZO +func (l *LZO) Compress(payload []byte) ([]byte, error) { + var compressed bytes.Buffer + + lzoWriter := lzo.NewWriter(&compressed) + + _, err := lzoWriter.Write(payload) + if err != nil { + return []byte{}, err + } + + err = lzoWriter.Close() + if err != nil { + return []byte{}, err + } + + compressedBytes := compressed.Bytes() + + compressionRatio := len(payload)/len(compressedBytes) + 1 + + result := make([]byte, len(compressedBytes)+1) + + result[0] = byte(compressionRatio) + + copy(result[1:], compressedBytes) + + return result, nil +} + +// Decompress decompresses the payload using LZO +func (l *LZO) Decompress(payload []byte) ([]byte, error) { + compressionRatio := payload[0] + compressed := payload[1:] + + if compressionRatio == 0 { + // * Compression ratio of 0 means no compression + return compressed, nil + } + + reader := bytes.NewReader(compressed) + decompressed := bytes.Buffer{} + + lzoReader, err := lzo.NewReader(reader) + if err != nil { + return []byte{}, err + } + + _, err = decompressed.ReadFrom(lzoReader) + if err != nil { + return []byte{}, err + } + + err = lzoReader.Close() + if err != nil { + return []byte{}, err + } + + decompressedBytes := decompressed.Bytes() + + ratioCheck := len(decompressedBytes)/len(compressed) + 1 + + if ratioCheck != int(compressionRatio) { + return []byte{}, fmt.Errorf("Failed to decompress payload. Got bad ratio. Expected %d, got %d", compressionRatio, ratioCheck) + } + + return decompressedBytes, nil +} diff --git a/compression/zlib.go b/compression/zlib.go new file mode 100644 index 00000000..82dd096a --- /dev/null +++ b/compression/zlib.go @@ -0,0 +1,83 @@ +package compression + +import ( + "bytes" + "compress/zlib" + "fmt" +) + +// Zlib implements packet payload compression using zlib +type Zlib struct{} + +// Compress compresses the payload using zlib +func (z *Zlib) Compress(payload []byte) ([]byte, error) { + compressed := bytes.Buffer{} + + zlibWriter := zlib.NewWriter(&compressed) + + _, err := zlibWriter.Write(payload) + if err != nil { + return []byte{}, err + } + + err = zlibWriter.Close() + if err != nil { + return []byte{}, err + } + + compressedBytes := compressed.Bytes() + + compressionRatio := len(payload)/len(compressedBytes) + 1 + + result := make([]byte, len(compressedBytes)+1) + + result[0] = byte(compressionRatio) + + copy(result[1:], compressedBytes) + + return result, nil +} + +// Decompress decompresses the payload using zlib +func (z *Zlib) Decompress(payload []byte) ([]byte, error) { + compressionRatio := payload[0] + compressed := payload[1:] + + if compressionRatio == 0 { + // * Compression ratio of 0 means no compression + return compressed, nil + } + + reader := bytes.NewReader(compressed) + decompressed := bytes.Buffer{} + + zlibReader, err := zlib.NewReader(reader) + if err != nil { + return []byte{}, err + } + + _, err = decompressed.ReadFrom(zlibReader) + if err != nil { + return []byte{}, err + } + + err = zlibReader.Close() + if err != nil { + return []byte{}, err + } + + decompressedBytes := decompressed.Bytes() + + ratioCheck := len(decompressedBytes)/len(compressed) + 1 + + if ratioCheck != int(compressionRatio) { + return []byte{}, fmt.Errorf("Failed to decompress payload. Got bad ratio. Expected %d, got %d", compressionRatio, ratioCheck) + } + + return decompressedBytes, nil +} + +// NewZlibCompression returns a new instance of the Zlib compression +func NewZlibCompression() *Zlib { + return &Zlib{} +} diff --git a/go.mod b/go.mod index 414359b2..cca3b394 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.21 require ( github.com/PretendoNetwork/plogger-go v1.0.4 + github.com/cyberdelia/lzo v1.0.0 github.com/lxzan/gws v1.7.0 github.com/superwhiskers/crunch/v3 v3.5.7 golang.org/x/exp v0.0.0-20230905200255-921286631fa9 diff --git a/go.sum b/go.sum index 6b888962..6d57e1bc 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/PretendoNetwork/plogger-go v1.0.4 h1:PF7xHw9eDRHH+RsAP9tmAE7fG0N0p6H4iPwHKnsoXwc= github.com/PretendoNetwork/plogger-go v1.0.4/go.mod h1:7kD6M4vPq1JL4LTuPg6kuB1OvUBOwQOtAvTaUwMbwvU= +github.com/cyberdelia/lzo v1.0.0 h1:smmvcahczwI/VWSzZ7iikt50lubari5py3qL4hAEHII= +github.com/cyberdelia/lzo v1.0.0/go.mod h1:UVNk6eM6Sozt1wx17TECJKuqmIY58TJOVeJxjlGGAGs= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fatih/color v1.15.0 h1:kOqh6YHBtK8aywxGerMG2Eq3H6Qgoqeo13Bk2Mv/nBs= diff --git a/prudp_server.go b/prudp_server.go index a735a14e..dbd362f8 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -2,7 +2,6 @@ package nex import ( "bytes" - "compress/zlib" "crypto/rand" "errors" "fmt" @@ -11,6 +10,7 @@ import ( "slices" "time" + "github.com/PretendoNetwork/nex-go/compression" "github.com/lxzan/gws" ) @@ -45,11 +45,9 @@ type PRUDPServer struct { passwordFromPIDHandler func(pid *PID) (string, uint32) PRUDPv1ConnectionSignatureKey []byte EnhancedChecksum bool - CompressionEnabled bool PRUDPv0CustomChecksumCalculator func(packet *PRUDPPacketV0, data []byte) uint32 stringLengthSize int - CustomPayloadCompressor func(payload []byte) ([]byte, error) - CustomPayloadDecompressor func(payload []byte) ([]byte, error) + CompressionAlgorithm compression.Algorithm } // OnData adds an event handler which is fired when a new DATA packet is received @@ -618,16 +616,12 @@ func (s *PRUDPServer) handleReliable(packet PRUDPPacketInterface) { decryptedPayload = pendingPacket.Payload() } - if s.CompressionEnabled { - decompressedPayload, err := s.decompressPayload(decryptedPayload) - if err != nil { - logger.Error(err.Error()) - } - - decryptedPayload = decompressedPayload + decompressedPayload, err := s.CompressionAlgorithm.Decompress(decryptedPayload) + if err != nil { + logger.Error(err.Error()) } - payload := substream.AddFragment(decryptedPayload) + payload := substream.AddFragment(decompressedPayload) if packet.getFragmentID() == 0 { message := NewRMCMessage() @@ -781,13 +775,9 @@ func (s *PRUDPServer) sendPacket(packet PRUDPPacketInterface) { if packetCopy.HasFlag(FlagReliable) { payload := packetCopy.Payload() - if s.CompressionEnabled { - compressedPayload, err := s.compressPayload(payload) - if err != nil { - logger.Error(err.Error()) - } - - payload = compressedPayload + compressedPayload, err := s.CompressionAlgorithm.Compress(payload) + if err != nil { + logger.Error(err.Error()) } substream := client.reliableSubstream(packetCopy.SubstreamID()) @@ -802,7 +792,7 @@ func (s *PRUDPServer) sendPacket(packet PRUDPPacketInterface) { // * PRUDPLite packet. No RC4 if packetCopy.Version() != 2 { - packetCopy.SetPayload(substream.Encrypt(payload)) + packetCopy.SetPayload(substream.Encrypt(compressedPayload)) } } else { // * PRUDPLite packet. No RC4 @@ -839,85 +829,6 @@ func (s *PRUDPServer) sendRaw(client *PRUDPClient, data []byte) { } } -func (s *PRUDPServer) decompressPayload(payload []byte) ([]byte, error) { - if s.CustomPayloadDecompressor != nil { - return s.CustomPayloadDecompressor(payload) - } - - compressionRatio := payload[0] - compressed := payload[1:] - - if compressionRatio == 0 { - // * Compression ratio of 0 means no compression - return compressed, nil - } - - reader := bytes.NewReader(compressed) - decompressed := bytes.Buffer{} - - // * Create a zlib reader - zlibReader, err := zlib.NewReader(reader) - if err != nil { - return []byte{}, err - } - - // * Copy the decompressed payload into a buffer - _, err = decompressed.ReadFrom(zlibReader) - if err != nil { - return []byte{}, err - } - - // * Close the zlib reader to flush any remaining data - err = zlibReader.Close() - if err != nil { - return []byte{}, err - } - - decompressedBytes := decompressed.Bytes() - - ratioCheck := len(decompressedBytes)/len(compressed) + 1 - - if ratioCheck != int(compressionRatio) { - return []byte{}, fmt.Errorf("Failed to decompress payload. Got bad ratio. Expected %d, got %d", compressionRatio, ratioCheck) - } - - return decompressedBytes, nil -} - -func (s *PRUDPServer) compressPayload(payload []byte) ([]byte, error) { - if s.CustomPayloadCompressor != nil { - return s.CustomPayloadCompressor(payload) - } - - compressed := bytes.Buffer{} - - // * Create a zlib writer with default compression level - zlibWriter := zlib.NewWriter(&compressed) - - _, err := zlibWriter.Write(payload) - if err != nil { - return []byte{}, err - } - - // * Close the zlib writer to flush any remaining data - err = zlibWriter.Close() - if err != nil { - return []byte{}, err - } - - compressedBytes := compressed.Bytes() - - compressionRatio := len(payload)/len(compressedBytes) + 1 - - stream := NewStreamOut(s) - - stream.WriteUInt8(uint8(compressionRatio)) - stream.Grow(int64(len(compressedBytes))) - stream.WriteBytesNext(compressedBytes) - - return stream.Bytes(), nil -} - // AccessKey returns the servers sandbox access key func (s *PRUDPServer) AccessKey() string { return s.accessKey @@ -1137,5 +1048,6 @@ func NewPRUDPServer() *PRUDPServer { connectionIDCounter: NewCounter[uint32](10), pingTimeout: time.Second * 15, stringLengthSize: 2, + CompressionAlgorithm: compression.NewDummyCompression(), } } From f2562bd26cafbbaaea330659ab895ce70f2305c9 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Fri, 15 Dec 2023 17:02:19 -0500 Subject: [PATCH 088/178] rmc: pass server to RMC messages --- hpp_packet.go | 2 +- hpp_server.go | 2 +- prudp_server.go | 4 ++-- rmc_message.go | 36 +++++++++++++++++++++--------------- test/auth.go | 4 ++-- test/hpp.go | 2 +- test/secure.go | 16 ++++++++-------- 7 files changed, 36 insertions(+), 30 deletions(-) diff --git a/hpp_packet.go b/hpp_packet.go index bf84a6c2..fbfb9c3e 100644 --- a/hpp_packet.go +++ b/hpp_packet.go @@ -139,7 +139,7 @@ func NewHPPPacket(client *HPPClient, payload []byte) (*HPPPacket, error) { } if payload != nil { - rmcMessage := NewRMCRequest() + rmcMessage := NewRMCRequest(client.Server()) err := rmcMessage.FromBytes(payload) if err != nil { return nil, fmt.Errorf("Failed to decode HPP request. %s", err) diff --git a/hpp_server.go b/hpp_server.go index 4a285f9f..59e56d7c 100644 --- a/hpp_server.go +++ b/hpp_server.go @@ -103,7 +103,7 @@ func (s *HPPServer) handleRequest(w http.ResponseWriter, req *http.Request) { rmcMessage := hppPacket.RMCMessage() // HPP returns PythonCore::ValidationError if password is missing or invalid - errorResponse := NewRMCError(Errors.PythonCore.ValidationError) + errorResponse := NewRMCError(s, Errors.PythonCore.ValidationError) errorResponse.CallID = rmcMessage.CallID errorResponse.IsHPP = true diff --git a/prudp_server.go b/prudp_server.go index dbd362f8..f17db16a 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -624,7 +624,7 @@ func (s *PRUDPServer) handleReliable(packet PRUDPPacketInterface) { payload := substream.AddFragment(decompressedPayload) if packet.getFragmentID() == 0 { - message := NewRMCMessage() + message := NewRMCMessage(s) err := message.FromBytes(payload) if err != nil { // TODO - Should this return the error too? @@ -690,7 +690,7 @@ func (s *PRUDPServer) handleUnreliable(packet PRUDPPacketInterface) { payload := packet.processUnreliableCrypto() - message := NewRMCMessage() + message := NewRMCMessage(s) err := message.FromBytes(payload) if err != nil { // TODO - Should this return the error too? diff --git a/rmc_message.go b/rmc_message.go index 20a66c28..66984bb6 100644 --- a/rmc_message.go +++ b/rmc_message.go @@ -7,6 +7,7 @@ import ( // RMCMessage represents a message in the RMC (Remote Method Call) protocol type RMCMessage struct { + Server ServerInterface VerboseMode bool // * Determines whether or not to encode the message using the "verbose" encoding method IsRequest bool // * Indicates if the message is a request message (true) or response message (false) IsSuccess bool // * Indicates if the message is a success message (true) for a response message @@ -24,7 +25,7 @@ type RMCMessage struct { // Copy copies the message into a new RMCMessage func (rmc *RMCMessage) Copy() *RMCMessage { - copied := NewRMCMessage() + copied := NewRMCMessage(rmc.Server) copied.IsRequest = rmc.IsRequest copied.IsSuccess = rmc.IsSuccess @@ -53,7 +54,7 @@ func (rmc *RMCMessage) FromBytes(data []byte) error { } func (rmc *RMCMessage) decodePacked(data []byte) error { - stream := NewStreamIn(data, nil) + stream := NewStreamIn(data, rmc.Server) length, err := stream.ReadUInt32LE() if err != nil { @@ -140,7 +141,7 @@ func (rmc *RMCMessage) decodePacked(data []byte) error { } func (rmc *RMCMessage) decodeVerbose(data []byte) error { - stream := NewStreamIn(data, nil) + stream := NewStreamIn(data, rmc.Server) length, err := stream.ReadUInt32LE() if err != nil { @@ -229,7 +230,7 @@ func (rmc *RMCMessage) Bytes() []byte { } func (rmc *RMCMessage) encodePacked() []byte { - stream := NewStreamOut(nil) + stream := NewStreamOut(rmc.Server) // * RMC requests have their protocol IDs ORed with 0x80 var protocolIDFlag uint16 = 0x80 @@ -276,7 +277,7 @@ func (rmc *RMCMessage) encodePacked() []byte { serialized := stream.Bytes() - message := NewStreamOut(nil) + message := NewStreamOut(rmc.Server) message.WriteUInt32LE(uint32(len(serialized))) message.Grow(int64(len(serialized))) @@ -286,7 +287,7 @@ func (rmc *RMCMessage) encodePacked() []byte { } func (rmc *RMCMessage) encodeVerbose() []byte { - stream := NewStreamOut(nil) + stream := NewStreamOut(rmc.Server) stream.WriteString(rmc.ProtocolName) stream.WriteBool(rmc.IsRequest) @@ -325,7 +326,7 @@ func (rmc *RMCMessage) encodeVerbose() []byte { serialized := stream.Bytes() - message := NewStreamOut(nil) + message := NewStreamOut(rmc.Server) message.WriteUInt32LE(uint32(len(serialized))) message.Grow(int64(len(serialized))) @@ -335,18 +336,23 @@ func (rmc *RMCMessage) encodeVerbose() []byte { } // NewRMCMessage returns a new generic RMC Message -func NewRMCMessage() *RMCMessage { - return &RMCMessage{} +func NewRMCMessage(server ServerInterface) *RMCMessage { + return &RMCMessage{ + Server: server, + } } // NewRMCRequest returns a new blank RMCRequest -func NewRMCRequest() *RMCMessage { - return &RMCMessage{IsRequest: true} +func NewRMCRequest(server ServerInterface) *RMCMessage { + return &RMCMessage{ + Server: server, + IsRequest: true, + } } // NewRMCSuccess returns a new RMC Message configured as a success response -func NewRMCSuccess(parameters []byte) *RMCMessage { - message := NewRMCMessage() +func NewRMCSuccess(server ServerInterface, parameters []byte) *RMCMessage { + message := NewRMCMessage(server) message.IsRequest = false message.IsSuccess = true message.Parameters = parameters @@ -355,12 +361,12 @@ func NewRMCSuccess(parameters []byte) *RMCMessage { } // NewRMCError returns a new RMC Message configured as a error response -func NewRMCError(errorCode uint32) *RMCMessage { +func NewRMCError(server ServerInterface, errorCode uint32) *RMCMessage { if int(errorCode)&errorMask == 0 { errorCode = uint32(int(errorCode) | errorMask) } - message := NewRMCMessage() + message := NewRMCMessage(server) message.IsRequest = false message.IsSuccess = false message.ErrorCode = errorCode diff --git a/test/auth.go b/test/auth.go index 08f8904e..dd253465 100644 --- a/test/auth.go +++ b/test/auth.go @@ -44,7 +44,7 @@ func startAuthenticationServer() { func login(packet nex.PRUDPPacketInterface) { request := packet.RMCMessage() - response := nex.NewRMCMessage() + response := nex.NewRMCMessage(authServer) parameters := request.Parameters @@ -105,7 +105,7 @@ func login(packet nex.PRUDPPacketInterface) { func requestTicket(packet nex.PRUDPPacketInterface) { request := packet.RMCMessage() - response := nex.NewRMCMessage() + response := nex.NewRMCMessage(authServer) parameters := request.Parameters diff --git a/test/hpp.go b/test/hpp.go index 36e3cee4..adf445ed 100644 --- a/test/hpp.go +++ b/test/hpp.go @@ -75,7 +75,7 @@ func startHPPServer() { func getNotificationURL(packet *nex.HPPPacket) { request := packet.RMCMessage() - response := nex.NewRMCMessage() + response := nex.NewRMCMessage(hppServer) parameters := request.Parameters diff --git a/test/secure.go b/test/secure.go index c1907841..4ad384c3 100644 --- a/test/secure.go +++ b/test/secure.go @@ -87,11 +87,11 @@ func startSecureServer() { func registerEx(packet nex.PRUDPPacketInterface) { request := packet.RMCMessage() - response := nex.NewRMCMessage() + response := nex.NewRMCMessage(secureServer) parameters := request.Parameters - parametersStream := nex.NewStreamIn(parameters, authServer) + parametersStream := nex.NewStreamIn(parameters, secureServer) vecMyURLs, err := parametersStream.ReadListStationURL() if err != nil { @@ -113,7 +113,7 @@ func registerEx(packet nex.PRUDPPacketInterface) { retval := nex.NewResultSuccess(0x00010001) localStationURL := localStation.EncodeToString() - responseStream := nex.NewStreamOut(authServer) + responseStream := nex.NewStreamOut(secureServer) responseStream.WriteResult(retval) responseStream.WriteUInt32LE(secureServer.ConnectionIDCounter().Next()) @@ -145,9 +145,9 @@ func registerEx(packet nex.PRUDPPacketInterface) { func updateAndGetAllInformation(packet nex.PRUDPPacketInterface) { request := packet.RMCMessage() - response := nex.NewRMCMessage() + response := nex.NewRMCMessage(secureServer) - responseStream := nex.NewStreamOut(authServer) + responseStream := nex.NewStreamOut(secureServer) responseStream.WriteStructure(&principalPreference{ ShowOnlinePresence: true, @@ -193,9 +193,9 @@ func updateAndGetAllInformation(packet nex.PRUDPPacketInterface) { func checkSettingStatus(packet nex.PRUDPPacketInterface) { request := packet.RMCMessage() - response := nex.NewRMCMessage() + response := nex.NewRMCMessage(secureServer) - responseStream := nex.NewStreamOut(authServer) + responseStream := nex.NewStreamOut(secureServer) responseStream.WriteUInt8(0) // * Unknown @@ -225,7 +225,7 @@ func checkSettingStatus(packet nex.PRUDPPacketInterface) { func updatePresence(packet nex.PRUDPPacketInterface) { request := packet.RMCMessage() - response := nex.NewRMCMessage() + response := nex.NewRMCMessage(secureServer) response.IsSuccess = true response.IsRequest = false From 4885f237400082a8528743dc9499218c7d19c50c Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Fri, 22 Dec 2023 00:02:09 -0500 Subject: [PATCH 089/178] types: added qUUID type --- stream_in.go | 403 +++++++++++++++++++++++++++----------------------- stream_out.go | 376 ++++++++++++++++++++++++---------------------- types.go | 279 ++++++++++++++++++++++++++++------ 3 files changed, 646 insertions(+), 412 deletions(-) diff --git a/stream_in.go b/stream_in.go index c698ddf1..b48df3d9 100644 --- a/stream_in.go +++ b/stream_in.go @@ -15,224 +15,224 @@ type StreamIn struct { } // Remaining returns the amount of data left to be read in the buffer -func (stream *StreamIn) Remaining() int { - return len(stream.Bytes()[stream.ByteOffset():]) +func (s *StreamIn) Remaining() int { + return len(s.Bytes()[s.ByteOffset():]) } // ReadRemaining reads all the data left to be read in the buffer -func (stream *StreamIn) ReadRemaining() []byte { +func (s *StreamIn) ReadRemaining() []byte { // TODO - Should we do a bounds check here? Or just allow empty slices? - return stream.ReadBytesNext(int64(stream.Remaining())) + return s.ReadBytesNext(int64(s.Remaining())) } // ReadUInt8 reads a uint8 -func (stream *StreamIn) ReadUInt8() (uint8, error) { - if stream.Remaining() < 1 { +func (s *StreamIn) ReadUInt8() (uint8, error) { + if s.Remaining() < 1 { return 0, errors.New("Not enough data to read uint8") } - return uint8(stream.ReadByteNext()), nil + return uint8(s.ReadByteNext()), nil } // ReadInt8 reads a uint8 -func (stream *StreamIn) ReadInt8() (int8, error) { - if stream.Remaining() < 1 { +func (s *StreamIn) ReadInt8() (int8, error) { + if s.Remaining() < 1 { return 0, errors.New("Not enough data to read int8") } - return int8(stream.ReadByteNext()), nil + return int8(s.ReadByteNext()), nil } // ReadUInt16LE reads a Little-Endian encoded uint16 -func (stream *StreamIn) ReadUInt16LE() (uint16, error) { - if stream.Remaining() < 2 { +func (s *StreamIn) ReadUInt16LE() (uint16, error) { + if s.Remaining() < 2 { return 0, errors.New("Not enough data to read uint16") } - return stream.ReadU16LENext(1)[0], nil + return s.ReadU16LENext(1)[0], nil } // ReadUInt16BE reads a Big-Endian encoded uint16 -func (stream *StreamIn) ReadUInt16BE() (uint16, error) { - if stream.Remaining() < 2 { +func (s *StreamIn) ReadUInt16BE() (uint16, error) { + if s.Remaining() < 2 { return 0, errors.New("Not enough data to read uint16") } - return stream.ReadU16BENext(1)[0], nil + return s.ReadU16BENext(1)[0], nil } // ReadInt16LE reads a Little-Endian encoded int16 -func (stream *StreamIn) ReadInt16LE() (int16, error) { - if stream.Remaining() < 2 { +func (s *StreamIn) ReadInt16LE() (int16, error) { + if s.Remaining() < 2 { return 0, errors.New("Not enough data to read int16") } - return int16(stream.ReadU16LENext(1)[0]), nil + return int16(s.ReadU16LENext(1)[0]), nil } // ReadInt16BE reads a Big-Endian encoded int16 -func (stream *StreamIn) ReadInt16BE() (int16, error) { - if stream.Remaining() < 2 { +func (s *StreamIn) ReadInt16BE() (int16, error) { + if s.Remaining() < 2 { return 0, errors.New("Not enough data to read int16") } - return int16(stream.ReadU16BENext(1)[0]), nil + return int16(s.ReadU16BENext(1)[0]), nil } // ReadUInt32LE reads a Little-Endian encoded uint32 -func (stream *StreamIn) ReadUInt32LE() (uint32, error) { - if stream.Remaining() < 4 { +func (s *StreamIn) ReadUInt32LE() (uint32, error) { + if s.Remaining() < 4 { return 0, errors.New("Not enough data to read uint32") } - return stream.ReadU32LENext(1)[0], nil + return s.ReadU32LENext(1)[0], nil } // ReadUInt32BE reads a Big-Endian encoded uint32 -func (stream *StreamIn) ReadUInt32BE() (uint32, error) { - if stream.Remaining() < 4 { +func (s *StreamIn) ReadUInt32BE() (uint32, error) { + if s.Remaining() < 4 { return 0, errors.New("Not enough data to read uint32") } - return stream.ReadU32BENext(1)[0], nil + return s.ReadU32BENext(1)[0], nil } // ReadInt32LE reads a Little-Endian encoded int32 -func (stream *StreamIn) ReadInt32LE() (int32, error) { - if stream.Remaining() < 4 { +func (s *StreamIn) ReadInt32LE() (int32, error) { + if s.Remaining() < 4 { return 0, errors.New("Not enough data to read int32") } - return int32(stream.ReadU32LENext(1)[0]), nil + return int32(s.ReadU32LENext(1)[0]), nil } // ReadInt32BE reads a Big-Endian encoded int32 -func (stream *StreamIn) ReadInt32BE() (int32, error) { - if stream.Remaining() < 4 { +func (s *StreamIn) ReadInt32BE() (int32, error) { + if s.Remaining() < 4 { return 0, errors.New("Not enough data to read int32") } - return int32(stream.ReadU32BENext(1)[0]), nil + return int32(s.ReadU32BENext(1)[0]), nil } // ReadUInt64LE reads a Little-Endian encoded uint64 -func (stream *StreamIn) ReadUInt64LE() (uint64, error) { - if stream.Remaining() < 8 { +func (s *StreamIn) ReadUInt64LE() (uint64, error) { + if s.Remaining() < 8 { return 0, errors.New("Not enough data to read uint64") } - return stream.ReadU64LENext(1)[0], nil + return s.ReadU64LENext(1)[0], nil } // ReadUInt64BE reads a Big-Endian encoded uint64 -func (stream *StreamIn) ReadUInt64BE() (uint64, error) { - if stream.Remaining() < 8 { +func (s *StreamIn) ReadUInt64BE() (uint64, error) { + if s.Remaining() < 8 { return 0, errors.New("Not enough data to read uint64") } - return stream.ReadU64BENext(1)[0], nil + return s.ReadU64BENext(1)[0], nil } // ReadInt64LE reads a Little-Endian encoded int64 -func (stream *StreamIn) ReadInt64LE() (int64, error) { - if stream.Remaining() < 8 { +func (s *StreamIn) ReadInt64LE() (int64, error) { + if s.Remaining() < 8 { return 0, errors.New("Not enough data to read int64") } - return int64(stream.ReadU64LENext(1)[0]), nil + return int64(s.ReadU64LENext(1)[0]), nil } // ReadInt64BE reads a Big-Endian encoded int64 -func (stream *StreamIn) ReadInt64BE() (int64, error) { - if stream.Remaining() < 8 { +func (s *StreamIn) ReadInt64BE() (int64, error) { + if s.Remaining() < 8 { return 0, errors.New("Not enough data to read int64") } - return int64(stream.ReadU64BENext(1)[0]), nil + return int64(s.ReadU64BENext(1)[0]), nil } // ReadFloat32LE reads a Little-Endian encoded float32 -func (stream *StreamIn) ReadFloat32LE() (float32, error) { - if stream.Remaining() < 4 { +func (s *StreamIn) ReadFloat32LE() (float32, error) { + if s.Remaining() < 4 { return 0, errors.New("Not enough data to read float32") } - return stream.ReadF32LENext(1)[0], nil + return s.ReadF32LENext(1)[0], nil } // ReadFloat32BE reads a Big-Endian encoded float32 -func (stream *StreamIn) ReadFloat32BE() (float32, error) { - if stream.Remaining() < 4 { +func (s *StreamIn) ReadFloat32BE() (float32, error) { + if s.Remaining() < 4 { return 0, errors.New("Not enough data to read float32") } - return stream.ReadF32BENext(1)[0], nil + return s.ReadF32BENext(1)[0], nil } // ReadFloat64LE reads a Little-Endian encoded float64 -func (stream *StreamIn) ReadFloat64LE() (float64, error) { - if stream.Remaining() < 8 { +func (s *StreamIn) ReadFloat64LE() (float64, error) { + if s.Remaining() < 8 { return 0, errors.New("Not enough data to read float64") } - return stream.ReadF64LENext(1)[0], nil + return s.ReadF64LENext(1)[0], nil } // ReadFloat64BE reads a Big-Endian encoded float64 -func (stream *StreamIn) ReadFloat64BE() (float64, error) { - if stream.Remaining() < 8 { +func (s *StreamIn) ReadFloat64BE() (float64, error) { + if s.Remaining() < 8 { return 0, errors.New("Not enough data to read float64") } - return stream.ReadF64BENext(1)[0], nil + return s.ReadF64BENext(1)[0], nil } // ReadBool reads a bool -func (stream *StreamIn) ReadBool() (bool, error) { - if stream.Remaining() < 1 { +func (s *StreamIn) ReadBool() (bool, error) { + if s.Remaining() < 1 { return false, errors.New("Not enough data to read bool") } - return stream.ReadByteNext() == 1, nil + return s.ReadByteNext() == 1, nil } // ReadPID reads a PID. The size depends on the server version -func (stream *StreamIn) ReadPID() (*PID, error) { - if stream.Server.LibraryVersion().GreaterOrEqual("4.0.0") { - if stream.Remaining() < 8 { +func (s *StreamIn) ReadPID() (*PID, error) { + if s.Server.LibraryVersion().GreaterOrEqual("4.0.0") { + if s.Remaining() < 8 { return nil, errors.New("Not enough data to read PID") } - pid, _ := stream.ReadUInt64LE() + pid, _ := s.ReadUInt64LE() return NewPID(pid), nil } else { - if stream.Remaining() < 4 { + if s.Remaining() < 4 { return nil, errors.New("Not enough data to read legacy PID") } - pid, _ := stream.ReadUInt32LE() + pid, _ := s.ReadUInt32LE() return NewPID(pid), nil } } // ReadString reads and returns a nex string type -func (stream *StreamIn) ReadString() (string, error) { +func (s *StreamIn) ReadString() (string, error) { var length int64 var err error // TODO - These variable names kinda suck? - if stream.Server == nil { - l, e := stream.ReadUInt16LE() + if s.Server == nil { + l, e := s.ReadUInt16LE() length = int64(l) err = e - } else if stream.Server.StringLengthSize() == 4 { - l, e := stream.ReadUInt32LE() + } else if s.Server.StringLengthSize() == 4 { + l, e := s.ReadUInt32LE() length = int64(l) err = e } else { - l, e := stream.ReadUInt16LE() + l, e := s.ReadUInt16LE() length = int64(l) err = e } @@ -241,53 +241,53 @@ func (stream *StreamIn) ReadString() (string, error) { return "", fmt.Errorf("Failed to read NEX string length. %s", err.Error()) } - if stream.Remaining() < int(length) { + if s.Remaining() < int(length) { return "", errors.New("NEX string length longer than data size") } - stringData := stream.ReadBytesNext(length) + stringData := s.ReadBytesNext(length) str := string(stringData) return strings.TrimRight(str, "\x00"), nil } // ReadBuffer reads a nex Buffer type -func (stream *StreamIn) ReadBuffer() ([]byte, error) { - length, err := stream.ReadUInt32LE() +func (s *StreamIn) ReadBuffer() ([]byte, error) { + length, err := s.ReadUInt32LE() if err != nil { return []byte{}, fmt.Errorf("Failed to read NEX buffer length. %s", err.Error()) } - if stream.Remaining() < int(length) { + if s.Remaining() < int(length) { return []byte{}, errors.New("NEX buffer length longer than data size") } - data := stream.ReadBytesNext(int64(length)) + data := s.ReadBytesNext(int64(length)) return data, nil } // ReadQBuffer reads a nex qBuffer type -func (stream *StreamIn) ReadQBuffer() ([]byte, error) { - length, err := stream.ReadUInt16LE() +func (s *StreamIn) ReadQBuffer() ([]byte, error) { + length, err := s.ReadUInt16LE() if err != nil { return []byte{}, fmt.Errorf("Failed to read NEX qBuffer length. %s", err.Error()) } - if stream.Remaining() < int(length) { + if s.Remaining() < int(length) { return []byte{}, errors.New("NEX qBuffer length longer than data size") } - data := stream.ReadBytesNext(int64(length)) + data := s.ReadBytesNext(int64(length)) return data, nil } // ReadVariant reads a Variant type. This type can hold 7 different types -func (stream *StreamIn) ReadVariant() (*Variant, error) { +func (s *StreamIn) ReadVariant() (*Variant, error) { variant := NewVariant() - err := variant.ExtractFromStream(stream) + err := variant.ExtractFromStream(s) if err != nil { return nil, fmt.Errorf("Failed to read Variant. %s", err.Error()) } @@ -296,8 +296,8 @@ func (stream *StreamIn) ReadVariant() (*Variant, error) { } // ReadDateTime reads a DateTime type -func (stream *StreamIn) ReadDateTime() (*DateTime, error) { - value, err := stream.ReadUInt64LE() +func (s *StreamIn) ReadDateTime() (*DateTime, error) { + value, err := s.ReadUInt64LE() if err != nil { return nil, fmt.Errorf("Failed to read DateTime value. %s", err.Error()) } @@ -306,9 +306,9 @@ func (stream *StreamIn) ReadDateTime() (*DateTime, error) { } // ReadDataHolder reads a DataHolder type -func (stream *StreamIn) ReadDataHolder() (*DataHolder, error) { +func (s *StreamIn) ReadDataHolder() (*DataHolder, error) { dataHolder := NewDataHolder() - err := dataHolder.ExtractFromStream(stream) + err := dataHolder.ExtractFromStream(s) if err != nil { return nil, fmt.Errorf("Failed to read DateHolder. %s", err.Error()) } @@ -317,8 +317,8 @@ func (stream *StreamIn) ReadDataHolder() (*DataHolder, error) { } // ReadStationURL reads a StationURL type -func (stream *StreamIn) ReadStationURL() (*StationURL, error) { - stationString, err := stream.ReadString() +func (s *StreamIn) ReadStationURL() (*StationURL, error) { + stationString, err := s.ReadString() if err != nil { return nil, fmt.Errorf("Failed to read StationURL. %s", err.Error()) } @@ -326,21 +326,33 @@ func (stream *StreamIn) ReadStationURL() (*StationURL, error) { return NewStationURL(stationString), nil } +// ReadQUUID reads a qUUID type +func (s *StreamIn) ReadQUUID() (*QUUID, error) { + qUUID := NewQUUID() + + err := qUUID.ExtractFromStream(s) + if err != nil { + return nil, fmt.Errorf("Failed to read qUUID. %s", err.Error()) + } + + return qUUID, nil +} + // ReadListUInt8 reads a list of uint8 types -func (stream *StreamIn) ReadListUInt8() ([]uint8, error) { - length, err := stream.ReadUInt32LE() +func (s *StreamIn) ReadListUInt8() ([]uint8, error) { + length, err := s.ReadUInt32LE() if err != nil { return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) } - if stream.Remaining() < int(length) { + if s.Remaining() < int(length) { return nil, errors.New("NEX List length longer than data size") } list := make([]uint8, 0, length) for i := 0; i < int(length); i++ { - value, err := stream.ReadUInt8() + value, err := s.ReadUInt8() if err != nil { return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) } @@ -352,20 +364,20 @@ func (stream *StreamIn) ReadListUInt8() ([]uint8, error) { } // ReadListInt8 reads a list of int8 types -func (stream *StreamIn) ReadListInt8() ([]int8, error) { - length, err := stream.ReadUInt32LE() +func (s *StreamIn) ReadListInt8() ([]int8, error) { + length, err := s.ReadUInt32LE() if err != nil { return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) } - if stream.Remaining() < int(length) { + if s.Remaining() < int(length) { return nil, errors.New("NEX List length longer than data size") } list := make([]int8, 0, length) for i := 0; i < int(length); i++ { - value, err := stream.ReadInt8() + value, err := s.ReadInt8() if err != nil { return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) } @@ -377,20 +389,20 @@ func (stream *StreamIn) ReadListInt8() ([]int8, error) { } // ReadListUInt16LE reads a list of Little-Endian encoded uint16 types -func (stream *StreamIn) ReadListUInt16LE() ([]uint16, error) { - length, err := stream.ReadUInt32LE() +func (s *StreamIn) ReadListUInt16LE() ([]uint16, error) { + length, err := s.ReadUInt32LE() if err != nil { return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) } - if stream.Remaining() < int(length*2) { + if s.Remaining() < int(length*2) { return nil, errors.New("NEX List length longer than data size") } list := make([]uint16, 0, length) for i := 0; i < int(length); i++ { - value, err := stream.ReadUInt16LE() + value, err := s.ReadUInt16LE() if err != nil { return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) } @@ -402,20 +414,20 @@ func (stream *StreamIn) ReadListUInt16LE() ([]uint16, error) { } // ReadListUInt16BE reads a list of Big-Endian encoded uint16 types -func (stream *StreamIn) ReadListUInt16BE() ([]uint16, error) { - length, err := stream.ReadUInt32LE() +func (s *StreamIn) ReadListUInt16BE() ([]uint16, error) { + length, err := s.ReadUInt32LE() if err != nil { return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) } - if stream.Remaining() < int(length*2) { + if s.Remaining() < int(length*2) { return nil, errors.New("NEX List length longer than data size") } list := make([]uint16, 0, length) for i := 0; i < int(length); i++ { - value, err := stream.ReadUInt16BE() + value, err := s.ReadUInt16BE() if err != nil { return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) } @@ -427,20 +439,20 @@ func (stream *StreamIn) ReadListUInt16BE() ([]uint16, error) { } // ReadListInt16LE reads a list of Little-Endian encoded int16 types -func (stream *StreamIn) ReadListInt16LE() ([]int16, error) { - length, err := stream.ReadUInt32LE() +func (s *StreamIn) ReadListInt16LE() ([]int16, error) { + length, err := s.ReadUInt32LE() if err != nil { return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) } - if stream.Remaining() < int(length*2) { + if s.Remaining() < int(length*2) { return nil, errors.New("NEX List length longer than data size") } list := make([]int16, 0, length) for i := 0; i < int(length); i++ { - value, err := stream.ReadInt16LE() + value, err := s.ReadInt16LE() if err != nil { return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) } @@ -452,20 +464,20 @@ func (stream *StreamIn) ReadListInt16LE() ([]int16, error) { } // ReadListInt16BE reads a list of Big-Endian encoded uint16 types -func (stream *StreamIn) ReadListInt16BE() ([]int16, error) { - length, err := stream.ReadUInt32LE() +func (s *StreamIn) ReadListInt16BE() ([]int16, error) { + length, err := s.ReadUInt32LE() if err != nil { return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) } - if stream.Remaining() < int(length*2) { + if s.Remaining() < int(length*2) { return nil, errors.New("NEX List length longer than data size") } list := make([]int16, 0, length) for i := 0; i < int(length); i++ { - value, err := stream.ReadInt16BE() + value, err := s.ReadInt16BE() if err != nil { return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) } @@ -477,20 +489,20 @@ func (stream *StreamIn) ReadListInt16BE() ([]int16, error) { } // ReadListUInt32LE reads a list of Little-Endian encoded uint32 types -func (stream *StreamIn) ReadListUInt32LE() ([]uint32, error) { - length, err := stream.ReadUInt32LE() +func (s *StreamIn) ReadListUInt32LE() ([]uint32, error) { + length, err := s.ReadUInt32LE() if err != nil { return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) } - if stream.Remaining() < int(length*4) { + if s.Remaining() < int(length*4) { return nil, errors.New("NEX List length longer than data size") } list := make([]uint32, 0, length) for i := 0; i < int(length); i++ { - value, err := stream.ReadUInt32LE() + value, err := s.ReadUInt32LE() if err != nil { return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) } @@ -502,20 +514,20 @@ func (stream *StreamIn) ReadListUInt32LE() ([]uint32, error) { } // ReadListUInt32BE reads a list of Big-Endian encoded uint32 types -func (stream *StreamIn) ReadListUInt32BE() ([]uint32, error) { - length, err := stream.ReadUInt32LE() +func (s *StreamIn) ReadListUInt32BE() ([]uint32, error) { + length, err := s.ReadUInt32LE() if err != nil { return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) } - if stream.Remaining() < int(length*4) { + if s.Remaining() < int(length*4) { return nil, errors.New("NEX List length longer than data size") } list := make([]uint32, 0, length) for i := 0; i < int(length); i++ { - value, err := stream.ReadUInt32BE() + value, err := s.ReadUInt32BE() if err != nil { return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) } @@ -527,20 +539,20 @@ func (stream *StreamIn) ReadListUInt32BE() ([]uint32, error) { } // ReadListInt32LE reads a list of Little-Endian encoded int32 types -func (stream *StreamIn) ReadListInt32LE() ([]int32, error) { - length, err := stream.ReadUInt32LE() +func (s *StreamIn) ReadListInt32LE() ([]int32, error) { + length, err := s.ReadUInt32LE() if err != nil { return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) } - if stream.Remaining() < int(length*4) { + if s.Remaining() < int(length*4) { return nil, errors.New("NEX List length longer than data size") } list := make([]int32, 0, length) for i := 0; i < int(length); i++ { - value, err := stream.ReadInt32LE() + value, err := s.ReadInt32LE() if err != nil { return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) } @@ -552,20 +564,20 @@ func (stream *StreamIn) ReadListInt32LE() ([]int32, error) { } // ReadListInt32BE reads a list of Big-Endian encoded int32 types -func (stream *StreamIn) ReadListInt32BE() ([]int32, error) { - length, err := stream.ReadUInt32LE() +func (s *StreamIn) ReadListInt32BE() ([]int32, error) { + length, err := s.ReadUInt32LE() if err != nil { return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) } - if stream.Remaining() < int(length*4) { + if s.Remaining() < int(length*4) { return nil, errors.New("NEX List length longer than data size") } list := make([]int32, 0, length) for i := 0; i < int(length); i++ { - value, err := stream.ReadInt32BE() + value, err := s.ReadInt32BE() if err != nil { return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) } @@ -577,20 +589,20 @@ func (stream *StreamIn) ReadListInt32BE() ([]int32, error) { } // ReadListUInt64LE reads a list of Little-Endian encoded uint64 types -func (stream *StreamIn) ReadListUInt64LE() ([]uint64, error) { - length, err := stream.ReadUInt32LE() +func (s *StreamIn) ReadListUInt64LE() ([]uint64, error) { + length, err := s.ReadUInt32LE() if err != nil { return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) } - if stream.Remaining() < int(length*8) { + if s.Remaining() < int(length*8) { return nil, errors.New("NEX List length longer than data size") } list := make([]uint64, 0, length) for i := 0; i < int(length); i++ { - value, err := stream.ReadUInt64LE() + value, err := s.ReadUInt64LE() if err != nil { return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) } @@ -602,20 +614,20 @@ func (stream *StreamIn) ReadListUInt64LE() ([]uint64, error) { } // ReadListUInt64BE reads a list of Big-Endian encoded uint64 types -func (stream *StreamIn) ReadListUInt64BE() ([]uint64, error) { - length, err := stream.ReadUInt32LE() +func (s *StreamIn) ReadListUInt64BE() ([]uint64, error) { + length, err := s.ReadUInt32LE() if err != nil { return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) } - if stream.Remaining() < int(length*8) { + if s.Remaining() < int(length*8) { return nil, errors.New("NEX List length longer than data size") } list := make([]uint64, 0, length) for i := 0; i < int(length); i++ { - value, err := stream.ReadUInt64BE() + value, err := s.ReadUInt64BE() if err != nil { return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) } @@ -627,20 +639,20 @@ func (stream *StreamIn) ReadListUInt64BE() ([]uint64, error) { } // ReadListInt64LE reads a list of Little-Endian encoded int64 types -func (stream *StreamIn) ReadListInt64LE() ([]int64, error) { - length, err := stream.ReadUInt32LE() +func (s *StreamIn) ReadListInt64LE() ([]int64, error) { + length, err := s.ReadUInt32LE() if err != nil { return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) } - if stream.Remaining() < int(length*8) { + if s.Remaining() < int(length*8) { return nil, errors.New("NEX List length longer than data size") } list := make([]int64, 0, length) for i := 0; i < int(length); i++ { - value, err := stream.ReadInt64LE() + value, err := s.ReadInt64LE() if err != nil { return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) } @@ -652,20 +664,20 @@ func (stream *StreamIn) ReadListInt64LE() ([]int64, error) { } // ReadListInt64BE reads a list of Big-Endian encoded int64 types -func (stream *StreamIn) ReadListInt64BE() ([]int64, error) { - length, err := stream.ReadUInt32LE() +func (s *StreamIn) ReadListInt64BE() ([]int64, error) { + length, err := s.ReadUInt32LE() if err != nil { return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) } - if stream.Remaining() < int(length*8) { + if s.Remaining() < int(length*8) { return nil, errors.New("NEX List length longer than data size") } list := make([]int64, 0, length) for i := 0; i < int(length); i++ { - value, err := stream.ReadInt64BE() + value, err := s.ReadInt64BE() if err != nil { return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) } @@ -677,20 +689,20 @@ func (stream *StreamIn) ReadListInt64BE() ([]int64, error) { } // ReadListFloat32LE reads a list of Little-Endian encoded float32 types -func (stream *StreamIn) ReadListFloat32LE() ([]float32, error) { - length, err := stream.ReadUInt32LE() +func (s *StreamIn) ReadListFloat32LE() ([]float32, error) { + length, err := s.ReadUInt32LE() if err != nil { return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) } - if stream.Remaining() < int(length*4) { + if s.Remaining() < int(length*4) { return nil, errors.New("NEX List length longer than data size") } list := make([]float32, 0, length) for i := 0; i < int(length); i++ { - value, err := stream.ReadFloat32LE() + value, err := s.ReadFloat32LE() if err != nil { return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) } @@ -702,20 +714,20 @@ func (stream *StreamIn) ReadListFloat32LE() ([]float32, error) { } // ReadListFloat32BE reads a list of Big-Endian encoded float32 types -func (stream *StreamIn) ReadListFloat32BE() ([]float32, error) { - length, err := stream.ReadUInt32LE() +func (s *StreamIn) ReadListFloat32BE() ([]float32, error) { + length, err := s.ReadUInt32LE() if err != nil { return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) } - if stream.Remaining() < int(length*4) { + if s.Remaining() < int(length*4) { return nil, errors.New("NEX List length longer than data size") } list := make([]float32, 0, length) for i := 0; i < int(length); i++ { - value, err := stream.ReadFloat32BE() + value, err := s.ReadFloat32BE() if err != nil { return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) } @@ -727,20 +739,20 @@ func (stream *StreamIn) ReadListFloat32BE() ([]float32, error) { } // ReadListFloat64LE reads a list of Little-Endian encoded float64 types -func (stream *StreamIn) ReadListFloat64LE() ([]float64, error) { - length, err := stream.ReadUInt32LE() +func (s *StreamIn) ReadListFloat64LE() ([]float64, error) { + length, err := s.ReadUInt32LE() if err != nil { return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) } - if stream.Remaining() < int(length*4) { + if s.Remaining() < int(length*4) { return nil, errors.New("NEX List length longer than data size") } list := make([]float64, 0, length) for i := 0; i < int(length); i++ { - value, err := stream.ReadFloat64LE() + value, err := s.ReadFloat64LE() if err != nil { return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) } @@ -752,20 +764,20 @@ func (stream *StreamIn) ReadListFloat64LE() ([]float64, error) { } // ReadListFloat64BE reads a list of Big-Endian encoded float64 types -func (stream *StreamIn) ReadListFloat64BE() ([]float64, error) { - length, err := stream.ReadUInt32LE() +func (s *StreamIn) ReadListFloat64BE() ([]float64, error) { + length, err := s.ReadUInt32LE() if err != nil { return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) } - if stream.Remaining() < int(length*4) { + if s.Remaining() < int(length*4) { return nil, errors.New("NEX List length longer than data size") } list := make([]float64, 0, length) for i := 0; i < int(length); i++ { - value, err := stream.ReadFloat64BE() + value, err := s.ReadFloat64BE() if err != nil { return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) } @@ -777,8 +789,8 @@ func (stream *StreamIn) ReadListFloat64BE() ([]float64, error) { } // ReadListPID reads a list of NEX PIDs -func (stream *StreamIn) ReadListPID() ([]*PID, error) { - length, err := stream.ReadUInt32LE() +func (s *StreamIn) ReadListPID() ([]*PID, error) { + length, err := s.ReadUInt32LE() if err != nil { return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) } @@ -786,7 +798,7 @@ func (stream *StreamIn) ReadListPID() ([]*PID, error) { list := make([]*PID, 0, length) for i := 0; i < int(length); i++ { - value, err := stream.ReadPID() + value, err := s.ReadPID() if err != nil { return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) } @@ -798,8 +810,8 @@ func (stream *StreamIn) ReadListPID() ([]*PID, error) { } // ReadListString reads a list of NEX String types -func (stream *StreamIn) ReadListString() ([]string, error) { - length, err := stream.ReadUInt32LE() +func (s *StreamIn) ReadListString() ([]string, error) { + length, err := s.ReadUInt32LE() if err != nil { return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) } @@ -807,7 +819,7 @@ func (stream *StreamIn) ReadListString() ([]string, error) { list := make([]string, 0, length) for i := 0; i < int(length); i++ { - value, err := stream.ReadString() + value, err := s.ReadString() if err != nil { return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) } @@ -819,8 +831,8 @@ func (stream *StreamIn) ReadListString() ([]string, error) { } // ReadListBuffer reads a list of NEX Buffer types -func (stream *StreamIn) ReadListBuffer() ([][]byte, error) { - length, err := stream.ReadUInt32LE() +func (s *StreamIn) ReadListBuffer() ([][]byte, error) { + length, err := s.ReadUInt32LE() if err != nil { return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) } @@ -828,7 +840,7 @@ func (stream *StreamIn) ReadListBuffer() ([][]byte, error) { list := make([][]byte, 0, length) for i := 0; i < int(length); i++ { - value, err := stream.ReadBuffer() + value, err := s.ReadBuffer() if err != nil { return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) } @@ -840,8 +852,8 @@ func (stream *StreamIn) ReadListBuffer() ([][]byte, error) { } // ReadListQBuffer reads a list of NEX qBuffer types -func (stream *StreamIn) ReadListQBuffer() ([][]byte, error) { - length, err := stream.ReadUInt32LE() +func (s *StreamIn) ReadListQBuffer() ([][]byte, error) { + length, err := s.ReadUInt32LE() if err != nil { return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) } @@ -849,7 +861,7 @@ func (stream *StreamIn) ReadListQBuffer() ([][]byte, error) { list := make([][]byte, 0, length) for i := 0; i < int(length); i++ { - value, err := stream.ReadQBuffer() + value, err := s.ReadQBuffer() if err != nil { return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) } @@ -861,8 +873,8 @@ func (stream *StreamIn) ReadListQBuffer() ([][]byte, error) { } // ReadListStationURL reads a list of NEX Station URL types -func (stream *StreamIn) ReadListStationURL() ([]*StationURL, error) { - length, err := stream.ReadUInt32LE() +func (s *StreamIn) ReadListStationURL() ([]*StationURL, error) { + length, err := s.ReadUInt32LE() if err != nil { return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) } @@ -870,7 +882,7 @@ func (stream *StreamIn) ReadListStationURL() ([]*StationURL, error) { list := make([]*StationURL, 0, length) for i := 0; i < int(length); i++ { - value, err := stream.ReadStationURL() + value, err := s.ReadStationURL() if err != nil { return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) } @@ -882,8 +894,8 @@ func (stream *StreamIn) ReadListStationURL() ([]*StationURL, error) { } // ReadListDataHolder reads a list of NEX DataHolder types -func (stream *StreamIn) ReadListDataHolder() ([]*DataHolder, error) { - length, err := stream.ReadUInt32LE() +func (s *StreamIn) ReadListDataHolder() ([]*DataHolder, error) { + length, err := s.ReadUInt32LE() if err != nil { return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) } @@ -891,7 +903,7 @@ func (stream *StreamIn) ReadListDataHolder() ([]*DataHolder, error) { list := make([]*DataHolder, 0, length) for i := 0; i < int(length); i++ { - value, err := stream.ReadDataHolder() + value, err := s.ReadDataHolder() if err != nil { return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) } @@ -902,6 +914,27 @@ func (stream *StreamIn) ReadListDataHolder() ([]*DataHolder, error) { return list, nil } +// ReadListQUUID reads a list of NEX qUUID types +func (s *StreamIn) ReadListQUUID() ([]*QUUID, error) { + length, err := s.ReadUInt32LE() + if err != nil { + return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) + } + + list := make([]*QUUID, 0, length) + + for i := 0; i < int(length); i++ { + value, err := s.ReadQUUID() + if err != nil { + return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) + } + + list = append(list, value) + } + + return list, nil +} + // NewStreamIn returns a new NEX input stream func NewStreamIn(data []byte, server ServerInterface) *StreamIn { return &StreamIn{ @@ -915,7 +948,7 @@ func NewStreamIn(data []byte, server ServerInterface) *StreamIn { // Implemented as a separate function to utilize generics func StreamReadStructure[T StructureInterface](stream *StreamIn, structure T) (T, error) { if structure.ParentType() != nil { - //_, err := stream.ReadStructure(structure.ParentType()) + //_, err := s.ReadStructure(structure.ParentType()) _, err := StreamReadStructure(stream, structure.ParentType()) if err != nil { return structure, fmt.Errorf("Failed to read structure parent. %s", err.Error()) @@ -953,7 +986,7 @@ func StreamReadStructure[T StructureInterface](stream *StreamIn, structure T) (T err := structure.ExtractFromStream(stream) if err != nil { - return structure, fmt.Errorf("Failed to read structure from stream. %s", err.Error()) + return structure, fmt.Errorf("Failed to read structure from s. %s", err.Error()) } return structure, nil @@ -987,8 +1020,8 @@ func StreamReadListStructure[T StructureInterface](stream *StreamIn, structure T // StreamReadMap reads a Map type with the given key and value types from a StreamIn // // Implemented as a separate function to utilize generics -func StreamReadMap[K comparable, V any](stream *StreamIn, keyReader func() (K, error), valueReader func() (V, error)) (map[K]V, error) { - length, err := stream.ReadUInt32LE() +func StreamReadMap[K comparable, V any](s *StreamIn, keyReader func() (K, error), valueReader func() (V, error)) (map[K]V, error) { + length, err := s.ReadUInt32LE() if err != nil { return nil, fmt.Errorf("Failed to read Map length. %s", err.Error()) } diff --git a/stream_out.go b/stream_out.go index 151e6938..0b8bef46 100644 --- a/stream_out.go +++ b/stream_out.go @@ -11,190 +11,190 @@ type StreamOut struct { } // WriteUInt8 writes a uint8 -func (stream *StreamOut) WriteUInt8(u8 uint8) { - stream.Grow(1) - stream.WriteByteNext(byte(u8)) +func (s *StreamOut) WriteUInt8(u8 uint8) { + s.Grow(1) + s.WriteByteNext(byte(u8)) } // WriteInt8 writes a int8 -func (stream *StreamOut) WriteInt8(s8 int8) { - stream.Grow(1) - stream.WriteByteNext(byte(s8)) +func (s *StreamOut) WriteInt8(s8 int8) { + s.Grow(1) + s.WriteByteNext(byte(s8)) } // WriteUInt16LE writes a uint16 as LE -func (stream *StreamOut) WriteUInt16LE(u16 uint16) { - stream.Grow(2) - stream.WriteU16LENext([]uint16{u16}) +func (s *StreamOut) WriteUInt16LE(u16 uint16) { + s.Grow(2) + s.WriteU16LENext([]uint16{u16}) } // WriteUInt16BE writes a uint16 as BE -func (stream *StreamOut) WriteUInt16BE(u16 uint16) { - stream.Grow(2) - stream.WriteU16BENext([]uint16{u16}) +func (s *StreamOut) WriteUInt16BE(u16 uint16) { + s.Grow(2) + s.WriteU16BENext([]uint16{u16}) } // WriteInt16LE writes a uint16 as LE -func (stream *StreamOut) WriteInt16LE(s16 int16) { - stream.Grow(2) - stream.WriteU16LENext([]uint16{uint16(s16)}) +func (s *StreamOut) WriteInt16LE(s16 int16) { + s.Grow(2) + s.WriteU16LENext([]uint16{uint16(s16)}) } // WriteInt16BE writes a uint16 as BE -func (stream *StreamOut) WriteInt16BE(s16 int16) { - stream.Grow(2) - stream.WriteU16BENext([]uint16{uint16(s16)}) +func (s *StreamOut) WriteInt16BE(s16 int16) { + s.Grow(2) + s.WriteU16BENext([]uint16{uint16(s16)}) } // WriteUInt32LE writes a uint32 as LE -func (stream *StreamOut) WriteUInt32LE(u32 uint32) { - stream.Grow(4) - stream.WriteU32LENext([]uint32{u32}) +func (s *StreamOut) WriteUInt32LE(u32 uint32) { + s.Grow(4) + s.WriteU32LENext([]uint32{u32}) } // WriteUInt32BE writes a uint32 as BE -func (stream *StreamOut) WriteUInt32BE(u32 uint32) { - stream.Grow(4) - stream.WriteU32BENext([]uint32{u32}) +func (s *StreamOut) WriteUInt32BE(u32 uint32) { + s.Grow(4) + s.WriteU32BENext([]uint32{u32}) } // WriteInt32LE writes a int32 as LE -func (stream *StreamOut) WriteInt32LE(s32 int32) { - stream.Grow(4) - stream.WriteU32LENext([]uint32{uint32(s32)}) +func (s *StreamOut) WriteInt32LE(s32 int32) { + s.Grow(4) + s.WriteU32LENext([]uint32{uint32(s32)}) } // WriteInt32BE writes a int32 as BE -func (stream *StreamOut) WriteInt32BE(s32 int32) { - stream.Grow(4) - stream.WriteU32BENext([]uint32{uint32(s32)}) +func (s *StreamOut) WriteInt32BE(s32 int32) { + s.Grow(4) + s.WriteU32BENext([]uint32{uint32(s32)}) } // WriteUInt64LE writes a uint64 as LE -func (stream *StreamOut) WriteUInt64LE(u64 uint64) { - stream.Grow(8) - stream.WriteU64LENext([]uint64{u64}) +func (s *StreamOut) WriteUInt64LE(u64 uint64) { + s.Grow(8) + s.WriteU64LENext([]uint64{u64}) } // WriteUInt64BE writes a uint64 as BE -func (stream *StreamOut) WriteUInt64BE(u64 uint64) { - stream.Grow(8) - stream.WriteU64BENext([]uint64{u64}) +func (s *StreamOut) WriteUInt64BE(u64 uint64) { + s.Grow(8) + s.WriteU64BENext([]uint64{u64}) } // WriteInt64LE writes a int64 as LE -func (stream *StreamOut) WriteInt64LE(s64 int64) { - stream.Grow(8) - stream.WriteU64LENext([]uint64{uint64(s64)}) +func (s *StreamOut) WriteInt64LE(s64 int64) { + s.Grow(8) + s.WriteU64LENext([]uint64{uint64(s64)}) } // WriteInt64BE writes a int64 as BE -func (stream *StreamOut) WriteInt64BE(s64 int64) { - stream.Grow(8) - stream.WriteU64BENext([]uint64{uint64(s64)}) +func (s *StreamOut) WriteInt64BE(s64 int64) { + s.Grow(8) + s.WriteU64BENext([]uint64{uint64(s64)}) } // WriteFloat32LE writes a float32 as LE -func (stream *StreamOut) WriteFloat32LE(f32 float32) { - stream.Grow(4) - stream.WriteF32LENext([]float32{f32}) +func (s *StreamOut) WriteFloat32LE(f32 float32) { + s.Grow(4) + s.WriteF32LENext([]float32{f32}) } // WriteFloat32BE writes a float32 as BE -func (stream *StreamOut) WriteFloat32BE(f32 float32) { - stream.Grow(4) - stream.WriteF32BENext([]float32{f32}) +func (s *StreamOut) WriteFloat32BE(f32 float32) { + s.Grow(4) + s.WriteF32BENext([]float32{f32}) } // WriteFloat64LE writes a float64 as LE -func (stream *StreamOut) WriteFloat64LE(f64 float64) { - stream.Grow(8) - stream.WriteF64LENext([]float64{f64}) +func (s *StreamOut) WriteFloat64LE(f64 float64) { + s.Grow(8) + s.WriteF64LENext([]float64{f64}) } // WriteFloat64BE writes a float64 as BE -func (stream *StreamOut) WriteFloat64BE(f64 float64) { - stream.Grow(8) - stream.WriteF64BENext([]float64{f64}) +func (s *StreamOut) WriteFloat64BE(f64 float64) { + s.Grow(8) + s.WriteF64BENext([]float64{f64}) } // WriteBool writes a bool -func (stream *StreamOut) WriteBool(b bool) { +func (s *StreamOut) WriteBool(b bool) { var bVar uint8 if b { bVar = 1 } - stream.Grow(1) - stream.WriteByteNext(byte(bVar)) + s.Grow(1) + s.WriteByteNext(byte(bVar)) } // WritePID writes a NEX PID. The size depends on the server version -func (stream *StreamOut) WritePID(pid *PID) { - if stream.Server.LibraryVersion().GreaterOrEqual("4.0.0") { - stream.WriteUInt64LE(pid.pid) +func (s *StreamOut) WritePID(pid *PID) { + if s.Server.LibraryVersion().GreaterOrEqual("4.0.0") { + s.WriteUInt64LE(pid.pid) } else { - stream.WriteUInt32LE(uint32(pid.pid)) + s.WriteUInt32LE(uint32(pid.pid)) } } // WriteString writes a NEX string type -func (stream *StreamOut) WriteString(str string) { +func (s *StreamOut) WriteString(str string) { str = str + "\x00" strLength := len(str) - if stream.Server == nil { - stream.WriteUInt16LE(uint16(strLength)) - } else if stream.Server.StringLengthSize() == 4 { - stream.WriteUInt32LE(uint32(strLength)) + if s.Server == nil { + s.WriteUInt16LE(uint16(strLength)) + } else if s.Server.StringLengthSize() == 4 { + s.WriteUInt32LE(uint32(strLength)) } else { - stream.WriteUInt16LE(uint16(strLength)) + s.WriteUInt16LE(uint16(strLength)) } - stream.Grow(int64(strLength)) - stream.WriteBytesNext([]byte(str)) + s.Grow(int64(strLength)) + s.WriteBytesNext([]byte(str)) } // WriteBuffer writes a NEX Buffer type -func (stream *StreamOut) WriteBuffer(data []byte) { +func (s *StreamOut) WriteBuffer(data []byte) { dataLength := len(data) - stream.WriteUInt32LE(uint32(dataLength)) + s.WriteUInt32LE(uint32(dataLength)) if dataLength > 0 { - stream.Grow(int64(dataLength)) - stream.WriteBytesNext(data) + s.Grow(int64(dataLength)) + s.WriteBytesNext(data) } } // WriteQBuffer writes a NEX qBuffer type -func (stream *StreamOut) WriteQBuffer(data []byte) { +func (s *StreamOut) WriteQBuffer(data []byte) { dataLength := len(data) - stream.WriteUInt16LE(uint16(dataLength)) + s.WriteUInt16LE(uint16(dataLength)) if dataLength > 0 { - stream.Grow(int64(dataLength)) - stream.WriteBytesNext(data) + s.Grow(int64(dataLength)) + s.WriteBytesNext(data) } } // WriteResult writes a NEX Result type -func (stream *StreamOut) WriteResult(result *Result) { - stream.WriteUInt32LE(result.Code) +func (s *StreamOut) WriteResult(result *Result) { + s.WriteUInt32LE(result.Code) } // WriteStructure writes a nex Structure type -func (stream *StreamOut) WriteStructure(structure StructureInterface) { +func (s *StreamOut) WriteStructure(structure StructureInterface) { if structure.ParentType() != nil { - stream.WriteStructure(structure.ParentType()) + s.WriteStructure(structure.ParentType()) } - content := structure.Bytes(NewStreamOut(stream.Server)) + content := structure.Bytes(NewStreamOut(s.Server)) useStructures := false - if stream.Server != nil { - switch server := stream.Server.(type) { + if s.Server != nil { + switch server := s.Server.(type) { case *PRUDPServer: // * Support QRV versions useStructures = server.PRUDPMinorVersion >= 3 default: @@ -203,275 +203,291 @@ func (stream *StreamOut) WriteStructure(structure StructureInterface) { } if useStructures { - stream.WriteUInt8(structure.StructureVersion()) - stream.WriteUInt32LE(uint32(len(content))) + s.WriteUInt8(structure.StructureVersion()) + s.WriteUInt32LE(uint32(len(content))) } - stream.Grow(int64(len(content))) - stream.WriteBytesNext(content) + s.Grow(int64(len(content))) + s.WriteBytesNext(content) } // WriteStationURL writes a StationURL type -func (stream *StreamOut) WriteStationURL(stationURL *StationURL) { - stream.WriteString(stationURL.EncodeToString()) +func (s *StreamOut) WriteStationURL(stationURL *StationURL) { + s.WriteString(stationURL.EncodeToString()) +} + +// WriteDataHolder writes a NEX DataHolder type +func (s *StreamOut) WriteDataHolder(dataholder *DataHolder) { + content := dataholder.Bytes(NewStreamOut(s.Server)) + s.Grow(int64(len(content))) + s.WriteBytesNext(content) +} + +// WriteDateTime writes a NEX DateTime type +func (s *StreamOut) WriteDateTime(datetime *DateTime) { + s.WriteUInt64LE(datetime.value) +} + +// WriteVariant writes a Variant type +func (s *StreamOut) WriteVariant(variant *Variant) { + content := variant.Bytes(NewStreamOut(s.Server)) + s.Grow(int64(len(content))) + s.WriteBytesNext(content) +} + +// WriteQUUID writes a qUUID type +func (s *StreamOut) WriteQUUID(qUUID *QUUID) { + qUUID.Bytes(s) } // WriteListUInt8 writes a list of uint8 types -func (stream *StreamOut) WriteListUInt8(list []uint8) { - stream.WriteUInt32LE(uint32(len(list))) +func (s *StreamOut) WriteListUInt8(list []uint8) { + s.WriteUInt32LE(uint32(len(list))) for i := 0; i < len(list); i++ { - stream.WriteUInt8(list[i]) + s.WriteUInt8(list[i]) } } // WriteListInt8 writes a list of int8 types -func (stream *StreamOut) WriteListInt8(list []int8) { - stream.WriteUInt32LE(uint32(len(list))) +func (s *StreamOut) WriteListInt8(list []int8) { + s.WriteUInt32LE(uint32(len(list))) for i := 0; i < len(list); i++ { - stream.WriteInt8(list[i]) + s.WriteInt8(list[i]) } } // WriteListUInt16LE writes a list of Little-Endian encoded uint16 types -func (stream *StreamOut) WriteListUInt16LE(list []uint16) { - stream.WriteUInt32LE(uint32(len(list))) +func (s *StreamOut) WriteListUInt16LE(list []uint16) { + s.WriteUInt32LE(uint32(len(list))) for i := 0; i < len(list); i++ { - stream.WriteUInt16LE(list[i]) + s.WriteUInt16LE(list[i]) } } // WriteListUInt16BE writes a list of Big-Endian encoded uint16 types -func (stream *StreamOut) WriteListUInt16BE(list []uint16) { - stream.WriteUInt32LE(uint32(len(list))) +func (s *StreamOut) WriteListUInt16BE(list []uint16) { + s.WriteUInt32LE(uint32(len(list))) for i := 0; i < len(list); i++ { - stream.WriteUInt16BE(list[i]) + s.WriteUInt16BE(list[i]) } } // WriteListInt16LE writes a list of Little-Endian encoded int16 types -func (stream *StreamOut) WriteListInt16LE(list []int16) { - stream.WriteUInt32LE(uint32(len(list))) +func (s *StreamOut) WriteListInt16LE(list []int16) { + s.WriteUInt32LE(uint32(len(list))) for i := 0; i < len(list); i++ { - stream.WriteInt16LE(list[i]) + s.WriteInt16LE(list[i]) } } // WriteListInt16BE writes a list of Big-Endian encoded int16 types -func (stream *StreamOut) WriteListInt16BE(list []int16) { - stream.WriteUInt32LE(uint32(len(list))) +func (s *StreamOut) WriteListInt16BE(list []int16) { + s.WriteUInt32LE(uint32(len(list))) for i := 0; i < len(list); i++ { - stream.WriteInt16BE(list[i]) + s.WriteInt16BE(list[i]) } } // WriteListUInt32LE writes a list of Little-Endian encoded uint32 types -func (stream *StreamOut) WriteListUInt32LE(list []uint32) { - stream.WriteUInt32LE(uint32(len(list))) +func (s *StreamOut) WriteListUInt32LE(list []uint32) { + s.WriteUInt32LE(uint32(len(list))) for i := 0; i < len(list); i++ { - stream.WriteUInt32LE(list[i]) + s.WriteUInt32LE(list[i]) } } // WriteListUInt32BE writes a list of Big-Endian encoded uint32 types -func (stream *StreamOut) WriteListUInt32BE(list []uint32) { - stream.WriteUInt32LE(uint32(len(list))) +func (s *StreamOut) WriteListUInt32BE(list []uint32) { + s.WriteUInt32LE(uint32(len(list))) for i := 0; i < len(list); i++ { - stream.WriteUInt32BE(list[i]) + s.WriteUInt32BE(list[i]) } } // WriteListInt32LE writes a list of Little-Endian encoded int32 types -func (stream *StreamOut) WriteListInt32LE(list []int32) { - stream.WriteUInt32LE(uint32(len(list))) +func (s *StreamOut) WriteListInt32LE(list []int32) { + s.WriteUInt32LE(uint32(len(list))) for i := 0; i < len(list); i++ { - stream.WriteInt32LE(list[i]) + s.WriteInt32LE(list[i]) } } // WriteListInt32BE writes a list of Big-Endian encoded int32 types -func (stream *StreamOut) WriteListInt32BE(list []int32) { - stream.WriteUInt32LE(uint32(len(list))) +func (s *StreamOut) WriteListInt32BE(list []int32) { + s.WriteUInt32LE(uint32(len(list))) for i := 0; i < len(list); i++ { - stream.WriteInt32BE(list[i]) + s.WriteInt32BE(list[i]) } } // WriteListUInt64LE writes a list of Little-Endian encoded uint64 types -func (stream *StreamOut) WriteListUInt64LE(list []uint64) { - stream.WriteUInt32LE(uint32(len(list))) +func (s *StreamOut) WriteListUInt64LE(list []uint64) { + s.WriteUInt32LE(uint32(len(list))) for i := 0; i < len(list); i++ { - stream.WriteUInt64LE(list[i]) + s.WriteUInt64LE(list[i]) } } // WriteListUInt64BE writes a list of Big-Endian encoded uint64 types -func (stream *StreamOut) WriteListUInt64BE(list []uint64) { - stream.WriteUInt32LE(uint32(len(list))) +func (s *StreamOut) WriteListUInt64BE(list []uint64) { + s.WriteUInt32LE(uint32(len(list))) for i := 0; i < len(list); i++ { - stream.WriteUInt64BE(list[i]) + s.WriteUInt64BE(list[i]) } } // WriteListInt64LE writes a list of Little-Endian encoded int64 types -func (stream *StreamOut) WriteListInt64LE(list []int64) { - stream.WriteUInt32LE(uint32(len(list))) +func (s *StreamOut) WriteListInt64LE(list []int64) { + s.WriteUInt32LE(uint32(len(list))) for i := 0; i < len(list); i++ { - stream.WriteInt64LE(list[i]) + s.WriteInt64LE(list[i]) } } // WriteListInt64BE writes a list of Big-Endian encoded int64 types -func (stream *StreamOut) WriteListInt64BE(list []int64) { - stream.WriteUInt32LE(uint32(len(list))) +func (s *StreamOut) WriteListInt64BE(list []int64) { + s.WriteUInt32LE(uint32(len(list))) for i := 0; i < len(list); i++ { - stream.WriteInt64BE(list[i]) + s.WriteInt64BE(list[i]) } } // WriteListFloat32LE writes a list of Little-Endian encoded float32 types -func (stream *StreamOut) WriteListFloat32LE(list []float32) { - stream.WriteUInt32LE(uint32(len(list))) +func (s *StreamOut) WriteListFloat32LE(list []float32) { + s.WriteUInt32LE(uint32(len(list))) for i := 0; i < len(list); i++ { - stream.WriteFloat32LE(list[i]) + s.WriteFloat32LE(list[i]) } } // WriteListFloat32BE writes a list of Big-Endian encoded float32 types -func (stream *StreamOut) WriteListFloat32BE(list []float32) { - stream.WriteUInt32LE(uint32(len(list))) +func (s *StreamOut) WriteListFloat32BE(list []float32) { + s.WriteUInt32LE(uint32(len(list))) for i := 0; i < len(list); i++ { - stream.WriteFloat32BE(list[i]) + s.WriteFloat32BE(list[i]) } } // WriteListFloat64LE writes a list of Little-Endian encoded float64 types -func (stream *StreamOut) WriteListFloat64LE(list []float64) { - stream.WriteUInt32LE(uint32(len(list))) +func (s *StreamOut) WriteListFloat64LE(list []float64) { + s.WriteUInt32LE(uint32(len(list))) for i := 0; i < len(list); i++ { - stream.WriteFloat64LE(list[i]) + s.WriteFloat64LE(list[i]) } } // WriteListFloat64BE writes a list of Big-Endian encoded float64 types -func (stream *StreamOut) WriteListFloat64BE(list []float64) { - stream.WriteUInt32LE(uint32(len(list))) +func (s *StreamOut) WriteListFloat64BE(list []float64) { + s.WriteUInt32LE(uint32(len(list))) for i := 0; i < len(list); i++ { - stream.WriteFloat64BE(list[i]) + s.WriteFloat64BE(list[i]) } } // WriteListPID writes a list of NEX PIDs -func (stream *StreamOut) WriteListPID(pids []*PID) { +func (s *StreamOut) WriteListPID(pids []*PID) { length := len(pids) - stream.WriteUInt32LE(uint32(length)) + s.WriteUInt32LE(uint32(length)) for i := 0; i < length; i++ { - stream.WritePID(pids[i]) + s.WritePID(pids[i]) } } // WriteListString writes a list of NEX String types -func (stream *StreamOut) WriteListString(strings []string) { +func (s *StreamOut) WriteListString(strings []string) { length := len(strings) - stream.WriteUInt32LE(uint32(length)) + s.WriteUInt32LE(uint32(length)) for i := 0; i < length; i++ { - stream.WriteString(strings[i]) + s.WriteString(strings[i]) } } // WriteListBuffer writes a list of NEX Buffer types -func (stream *StreamOut) WriteListBuffer(buffers [][]byte) { +func (s *StreamOut) WriteListBuffer(buffers [][]byte) { length := len(buffers) - stream.WriteUInt32LE(uint32(length)) + s.WriteUInt32LE(uint32(length)) for i := 0; i < length; i++ { - stream.WriteBuffer(buffers[i]) + s.WriteBuffer(buffers[i]) } } // WriteListQBuffer writes a list of NEX qBuffer types -func (stream *StreamOut) WriteListQBuffer(buffers [][]byte) { +func (s *StreamOut) WriteListQBuffer(buffers [][]byte) { length := len(buffers) - stream.WriteUInt32LE(uint32(length)) + s.WriteUInt32LE(uint32(length)) for i := 0; i < length; i++ { - stream.WriteQBuffer(buffers[i]) + s.WriteQBuffer(buffers[i]) } } // WriteListResult writes a list of NEX Result types -func (stream *StreamOut) WriteListResult(results []*Result) { +func (s *StreamOut) WriteListResult(results []*Result) { length := len(results) - stream.WriteUInt32LE(uint32(length)) + s.WriteUInt32LE(uint32(length)) for i := 0; i < length; i++ { - stream.WriteResult(results[i]) + s.WriteResult(results[i]) } } // WriteListStationURL writes a list of NEX StationURL types -func (stream *StreamOut) WriteListStationURL(stationURLs []*StationURL) { +func (s *StreamOut) WriteListStationURL(stationURLs []*StationURL) { length := len(stationURLs) - stream.WriteUInt32LE(uint32(length)) + s.WriteUInt32LE(uint32(length)) for i := 0; i < length; i++ { - stream.WriteString(stationURLs[i].EncodeToString()) + s.WriteString(stationURLs[i].EncodeToString()) } } // WriteListDataHolder writes a NEX DataHolder type -func (stream *StreamOut) WriteListDataHolder(dataholders []*DataHolder) { +func (s *StreamOut) WriteListDataHolder(dataholders []*DataHolder) { length := len(dataholders) - stream.WriteUInt32LE(uint32(length)) + s.WriteUInt32LE(uint32(length)) for i := 0; i < length; i++ { - stream.WriteDataHolder(dataholders[i]) + s.WriteDataHolder(dataholders[i]) } } -// WriteDataHolder writes a NEX DataHolder type -func (stream *StreamOut) WriteDataHolder(dataholder *DataHolder) { - content := dataholder.Bytes(NewStreamOut(stream.Server)) - stream.Grow(int64(len(content))) - stream.WriteBytesNext(content) -} +// WriteListQUUID writes a NEX qUUID type +func (s *StreamOut) WriteListQUUID(qUUIDs []*QUUID) { + length := len(qUUIDs) -// WriteDateTime writes a NEX DateTime type -func (stream *StreamOut) WriteDateTime(datetime *DateTime) { - stream.WriteUInt64LE(datetime.value) -} + s.WriteUInt32LE(uint32(length)) -// WriteVariant writes a Variant type -func (stream *StreamOut) WriteVariant(variant *Variant) { - content := variant.Bytes(NewStreamOut(stream.Server)) - stream.Grow(int64(len(content))) - stream.WriteBytesNext(content) + for i := 0; i < length; i++ { + s.WriteQUUID(qUUIDs[i]) + } } // NewStreamOut returns a new nex output stream diff --git a/types.go b/types.go index b987e4b2..852281a4 100644 --- a/types.go +++ b/types.go @@ -2,8 +2,10 @@ package nex import ( "bytes" + "encoding/hex" "errors" "fmt" + "slices" "strings" "time" ) @@ -860,31 +862,34 @@ type Variant struct { Str string DateTime *DateTime UInt64 uint64 + QUUID *QUUID } // ExtractFromStream extracts a Variant structure from a stream -func (variant *Variant) ExtractFromStream(stream *StreamIn) error { +func (v *Variant) ExtractFromStream(stream *StreamIn) error { var err error - variant.TypeID, err = stream.ReadUInt8() + v.TypeID, err = stream.ReadUInt8() if err != nil { return fmt.Errorf("Failed to read Variant type ID. %s", err.Error()) } // * A type ID of 0 means no value - switch variant.TypeID { + switch v.TypeID { case 1: // * sint64 - variant.Int64, err = stream.ReadInt64LE() + v.Int64, err = stream.ReadInt64LE() case 2: // * double - variant.Float64, err = stream.ReadFloat64LE() + v.Float64, err = stream.ReadFloat64LE() case 3: // * bool - variant.Bool, err = stream.ReadBool() + v.Bool, err = stream.ReadBool() case 4: // * string - variant.Str, err = stream.ReadString() + v.Str, err = stream.ReadString() case 5: // * datetime - variant.DateTime, err = stream.ReadDateTime() + v.DateTime, err = stream.ReadDateTime() case 6: // * uint64 - variant.UInt64, err = stream.ReadUInt64LE() + v.UInt64, err = stream.ReadUInt64LE() + case 7: // * qUUID + v.QUUID, err = stream.ReadQUUID() } // * These errors contain details about each of the values type @@ -897,104 +902,114 @@ func (variant *Variant) ExtractFromStream(stream *StreamIn) error { } // Bytes encodes the Variant and returns a byte array -func (variant *Variant) Bytes(stream *StreamOut) []byte { - stream.WriteUInt8(variant.TypeID) +func (v *Variant) Bytes(stream *StreamOut) []byte { + stream.WriteUInt8(v.TypeID) // * A type ID of 0 means no value - switch variant.TypeID { + switch v.TypeID { case 1: // * sint64 - stream.WriteInt64LE(variant.Int64) + stream.WriteInt64LE(v.Int64) case 2: // * double - stream.WriteFloat64LE(variant.Float64) + stream.WriteFloat64LE(v.Float64) case 3: // * bool - stream.WriteBool(variant.Bool) + stream.WriteBool(v.Bool) case 4: // * string - stream.WriteString(variant.Str) + stream.WriteString(v.Str) case 5: // * datetime - stream.WriteDateTime(variant.DateTime) + stream.WriteDateTime(v.DateTime) case 6: // * uint64 - stream.WriteUInt64LE(variant.UInt64) + stream.WriteUInt64LE(v.UInt64) + case 7: // * qUUID + stream.WriteQUUID(v.QUUID) } return stream.Bytes() } // Copy returns a new copied instance of Variant -func (variant *Variant) Copy() *Variant { +func (v *Variant) Copy() *Variant { copied := NewVariant() - copied.TypeID = variant.TypeID - copied.Int64 = variant.Int64 - copied.Float64 = variant.Float64 - copied.Bool = variant.Bool - copied.Str = variant.Str + copied.TypeID = v.TypeID + copied.Int64 = v.Int64 + copied.Float64 = v.Float64 + copied.Bool = v.Bool + copied.Str = v.Str - if variant.DateTime != nil { - copied.DateTime = variant.DateTime.Copy() + if v.DateTime != nil { + copied.DateTime = v.DateTime.Copy() } - copied.UInt64 = variant.UInt64 + copied.UInt64 = v.UInt64 + + if v.QUUID != nil { + copied.QUUID = v.QUUID.Copy() + } return copied } // Equals checks if the passed Structure contains the same data as the current instance -func (variant *Variant) Equals(other *Variant) bool { - if variant.TypeID != other.TypeID { +func (v *Variant) Equals(other *Variant) bool { + if v.TypeID != other.TypeID { return false } // * A type ID of 0 means no value - switch variant.TypeID { + switch v.TypeID { case 0: // * no value, always equal return true case 1: // * sint64 - return variant.Int64 == other.Int64 + return v.Int64 == other.Int64 case 2: // * double - return variant.Float64 == other.Float64 + return v.Float64 == other.Float64 case 3: // * bool - return variant.Bool == other.Bool + return v.Bool == other.Bool case 4: // * string - return variant.Str == other.Str + return v.Str == other.Str case 5: // * datetime - return variant.DateTime.Equals(other.DateTime) + return v.DateTime.Equals(other.DateTime) case 6: // * uint64 - return variant.UInt64 == other.UInt64 + return v.UInt64 == other.UInt64 + case 7: // * qUUID + return v.QUUID.Equals(other.QUUID) default: // * Something went horribly wrong return false } } // String returns a string representation of the struct -func (variant *Variant) String() string { - return variant.FormatToString(0) +func (v *Variant) String() string { + return v.FormatToString(0) } // FormatToString pretty-prints the struct data using the provided indentation level -func (variant *Variant) FormatToString(indentationLevel int) string { +func (v *Variant) FormatToString(indentationLevel int) string { indentationValues := strings.Repeat("\t", indentationLevel+1) indentationEnd := strings.Repeat("\t", indentationLevel) var b strings.Builder b.WriteString("Variant{\n") - b.WriteString(fmt.Sprintf("%sTypeID: %d\n", indentationValues, variant.TypeID)) + b.WriteString(fmt.Sprintf("%sTypeID: %d\n", indentationValues, v.TypeID)) - switch variant.TypeID { + switch v.TypeID { case 0: // * no value b.WriteString(fmt.Sprintf("%svalue: nil\n", indentationValues)) case 1: // * sint64 - b.WriteString(fmt.Sprintf("%svalue: %d\n", indentationValues, variant.Int64)) + b.WriteString(fmt.Sprintf("%svalue: %d\n", indentationValues, v.Int64)) case 2: // * double - b.WriteString(fmt.Sprintf("%svalue: %g\n", indentationValues, variant.Float64)) + b.WriteString(fmt.Sprintf("%svalue: %g\n", indentationValues, v.Float64)) case 3: // * bool - b.WriteString(fmt.Sprintf("%svalue: %t\n", indentationValues, variant.Bool)) + b.WriteString(fmt.Sprintf("%svalue: %t\n", indentationValues, v.Bool)) case 4: // * string - b.WriteString(fmt.Sprintf("%svalue: %q\n", indentationValues, variant.Str)) + b.WriteString(fmt.Sprintf("%svalue: %q\n", indentationValues, v.Str)) case 5: // * datetime - b.WriteString(fmt.Sprintf("%svalue: %s\n", indentationValues, variant.DateTime.FormatToString(indentationLevel+1))) + b.WriteString(fmt.Sprintf("%svalue: %s\n", indentationValues, v.DateTime.FormatToString(indentationLevel+1))) case 6: // * uint64 - b.WriteString(fmt.Sprintf("%svalue: %d\n", indentationValues, variant.UInt64)) + b.WriteString(fmt.Sprintf("%svalue: %d\n", indentationValues, v.UInt64)) + case 7: // * qUUID + b.WriteString(fmt.Sprintf("%svalue: %s\n", indentationValues, v.QUUID.FormatToString(indentationLevel+1))) default: b.WriteString(fmt.Sprintf("%svalue: Unknown\n", indentationValues)) } @@ -1112,3 +1127,173 @@ func NewClassVersionContainer() *ClassVersionContainer { ClassVersions: make(map[string]uint16), } } + +// QUUID represents a QRV qUUID type. This type encodes a UUID in little-endian byte order +type QUUID struct { + Data []byte +} + +// ExtractFromStream extracts a qUUID structure from a stream +func (qu *QUUID) ExtractFromStream(stream *StreamIn) error { + if stream.Remaining() < int(16) { + return errors.New("Not enough data left to read qUUID") + } + + qu.Data = stream.ReadBytesNext(16) + + return nil +} + +// Bytes encodes the qUUID and returns a byte array +func (qu *QUUID) Bytes(stream *StreamOut) []byte { + stream.Grow(int64(len(qu.Data))) + stream.WriteBytesNext(qu.Data) + + return stream.Bytes() +} + +// Copy returns a new copied instance of qUUID +func (qu *QUUID) Copy() *QUUID { + copied := NewQUUID() + + copied.Data = make([]byte, len(qu.Data)) + + copy(copied.Data, qu.Data) + + return copied +} + +// Equals checks if the passed Structure contains the same data as the current instance +func (qu *QUUID) Equals(other *QUUID) bool { + return qu.GetStringValue() == other.GetStringValue() +} + +// String returns a string representation of the struct +func (qu *QUUID) String() string { + return qu.FormatToString(0) +} + +// FormatToString pretty-prints the struct data using the provided indentation level +func (qu *QUUID) FormatToString(indentationLevel int) string { + indentationValues := strings.Repeat("\t", indentationLevel+1) + indentationEnd := strings.Repeat("\t", indentationLevel) + + var b strings.Builder + + b.WriteString("qUUID{\n") + b.WriteString(fmt.Sprintf("%sUUID: %s\n", indentationValues, qu.GetStringValue())) + b.WriteString(fmt.Sprintf("%s}", indentationEnd)) + + return b.String() +} + +// GetStringValue returns the UUID encoded in the qUUID +func (qu *QUUID) GetStringValue() string { + // * Create copy of the data since slices.Reverse modifies the slice in-line + data := make([]byte, len(qu.Data)) + copy(data, qu.Data) + + if len(data) != 16 { + // * Default dummy UUID as found in WATCH_DOGS + return "00000000-0000-0000-0000-000000000002" + } + + section1 := data[0:4] + section2 := data[4:6] + section3 := data[6:8] + section4 := data[8:10] + section5_1 := data[10:12] + section5_2 := data[12:14] + section5_3 := data[14:16] + + slices.Reverse(section1) + slices.Reverse(section2) + slices.Reverse(section3) + slices.Reverse(section4) + slices.Reverse(section5_1) + slices.Reverse(section5_2) + slices.Reverse(section5_3) + + var b strings.Builder + + b.WriteString(hex.EncodeToString(section1)) + b.WriteString("-") + b.WriteString(hex.EncodeToString(section2)) + b.WriteString("-") + b.WriteString(hex.EncodeToString(section3)) + b.WriteString("-") + b.WriteString(hex.EncodeToString(section4)) + b.WriteString("-") + b.WriteString(hex.EncodeToString(section5_1)) + b.WriteString(hex.EncodeToString(section5_2)) + b.WriteString(hex.EncodeToString(section5_3)) + + return b.String() +} + +// FromString converts a UUID string to a qUUID +func (qu *QUUID) FromString(uuid string) error { + + sections := strings.Split(uuid, "-") + if len(sections) != 5 { + return fmt.Errorf("Invalid UUID. Not enough sections. Expected 5, got %d", len(sections)) + } + + data := make([]byte, 0, 16) + + var appendSection = func(section string, expectedSize int) error { + sectionBytes, err := hex.DecodeString(section) + if err != nil { + return err + } + + if len(sectionBytes) != expectedSize { + return fmt.Errorf("Unexpected section size. Expected %d, got %d", expectedSize, len(sectionBytes)) + } + + data = append(data, sectionBytes...) + + return nil + } + + if err := appendSection(sections[0], 4); err != nil { + return fmt.Errorf("Failed to read UUID section 1. %s", err.Error()) + } + + if err := appendSection(sections[1], 2); err != nil { + return fmt.Errorf("Failed to read UUID section 2. %s", err.Error()) + } + + if err := appendSection(sections[2], 2); err != nil { + return fmt.Errorf("Failed to read UUID section 3. %s", err.Error()) + } + + if err := appendSection(sections[3], 2); err != nil { + return fmt.Errorf("Failed to read UUID section 4. %s", err.Error()) + } + + if err := appendSection(sections[4], 6); err != nil { + return fmt.Errorf("Failed to read UUID section 5. %s", err.Error()) + } + + slices.Reverse(data[0:4]) + slices.Reverse(data[4:6]) + slices.Reverse(data[6:8]) + slices.Reverse(data[8:10]) + slices.Reverse(data[10:12]) + slices.Reverse(data[12:14]) + slices.Reverse(data[14:16]) + + qu.Data = make([]byte, 0, 16) + + copy(qu.Data, data) + + return nil +} + +// NewQUUID returns a new qUUID +func NewQUUID() *QUUID { + return &QUUID{ + Data: make([]byte, 0, 16), + } +} From 715e3182cdc93ee810b0d2d06314cac967fd09bf Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sat, 23 Dec 2023 13:03:32 -0500 Subject: [PATCH 090/178] expanded all types into dedicated files and made the stream API generic --- client_interface.go | 10 +- .../{algorithm_interface.go => algorithm.go} | 0 hpp_client.go | 12 +- hpp_server.go | 10 +- init.go | 12 +- kerberos.go | 47 +- prudp_client.go | 11 +- prudp_packet_lite.go | 62 +- prudp_packet_v0.go | 42 +- prudp_packet_v1.go | 76 +- prudp_server.go | 48 +- rmc_message.go | 126 +- server_interface.go | 6 +- stream_in.go | 1026 +------------ stream_out.go | 580 +------- test/auth.go | 51 +- test/generate_ticket.go | 13 +- test/main.go | 2 +- test/secure.go | 85 +- types.go | 1299 ----------------- types/any_data_holder.go | 109 ++ types/buffer.go | 61 + types/class_version_container.go | 42 + types/datetime.go | 141 ++ types/list.go | 97 ++ types/map.go | 152 ++ types/null_data.go | 84 ++ types/pid.go | 116 ++ types/primitive_bool.go | 44 + types/primitive_float32.go | 44 + types/primitive_float64.go | 44 + types/primitive_s16.go | 44 + types/primitive_s32.go | 44 + types/primitive_s64.go | 44 + types/primitive_s8.go | 44 + types/primitive_u16.go | 44 + types/primitive_u32.go | 44 + types/primitive_u64.go | 44 + types/primitive_u8.go | 44 + types/qbuffer.go | 61 + types/quuid.go | 181 +++ types/readable.go | 22 + types/result.go | 97 ++ types/result_range.go | 102 ++ types/rv_connection_data.go | 143 ++ types/rv_type.go | 10 + types/station_url.go | 169 +++ types/string.go | 82 ++ types/structure.go | 39 + types/variant.go | 74 + types/writable.go | 22 + 51 files changed, 2764 insertions(+), 3042 deletions(-) rename compression/{algorithm_interface.go => algorithm.go} (100%) delete mode 100644 types.go create mode 100644 types/any_data_holder.go create mode 100644 types/buffer.go create mode 100644 types/class_version_container.go create mode 100644 types/datetime.go create mode 100644 types/list.go create mode 100644 types/map.go create mode 100644 types/null_data.go create mode 100644 types/pid.go create mode 100644 types/primitive_bool.go create mode 100644 types/primitive_float32.go create mode 100644 types/primitive_float64.go create mode 100644 types/primitive_s16.go create mode 100644 types/primitive_s32.go create mode 100644 types/primitive_s64.go create mode 100644 types/primitive_s8.go create mode 100644 types/primitive_u16.go create mode 100644 types/primitive_u32.go create mode 100644 types/primitive_u64.go create mode 100644 types/primitive_u8.go create mode 100644 types/qbuffer.go create mode 100644 types/quuid.go create mode 100644 types/readable.go create mode 100644 types/result.go create mode 100644 types/result_range.go create mode 100644 types/rv_connection_data.go create mode 100644 types/rv_type.go create mode 100644 types/station_url.go create mode 100644 types/string.go create mode 100644 types/structure.go create mode 100644 types/variant.go create mode 100644 types/writable.go diff --git a/client_interface.go b/client_interface.go index 041f7555..7788468a 100644 --- a/client_interface.go +++ b/client_interface.go @@ -1,12 +1,16 @@ // Package nex provides a collection of utility structs, functions, and data types for making NEX/QRV servers package nex -import "net" +import ( + "net" + + "github.com/PretendoNetwork/nex-go/types" +) // ClientInterface defines all the methods a client should have regardless of server type type ClientInterface interface { Server() ServerInterface Address() net.Addr - PID() *PID - SetPID(pid *PID) + PID() *types.PID + SetPID(pid *types.PID) } diff --git a/compression/algorithm_interface.go b/compression/algorithm.go similarity index 100% rename from compression/algorithm_interface.go rename to compression/algorithm.go diff --git a/hpp_client.go b/hpp_client.go index 9f47432f..480b453e 100644 --- a/hpp_client.go +++ b/hpp_client.go @@ -1,12 +1,16 @@ package nex -import "net" +import ( + "net" + + "github.com/PretendoNetwork/nex-go/types" +) // HPPClient represents a single HPP client type HPPClient struct { address *net.TCPAddr server *HPPServer - pid *PID + pid *types.PID } // Server returns the server the client is connecting to @@ -20,12 +24,12 @@ func (c *HPPClient) Address() net.Addr { } // PID returns the clients NEX PID -func (c *HPPClient) PID() *PID { +func (c *HPPClient) PID() *types.PID { return c.pid } // SetPID sets the clients NEX PID -func (c *HPPClient) SetPID(pid *PID) { +func (c *HPPClient) SetPID(pid *types.PID) { c.pid = pid } diff --git a/hpp_server.go b/hpp_server.go index 59e56d7c..0b0c9fbd 100644 --- a/hpp_server.go +++ b/hpp_server.go @@ -6,6 +6,8 @@ import ( "net" "net/http" "strconv" + + "github.com/PretendoNetwork/nex-go/types" ) // HPPServer represents a bare-bones HPP server @@ -21,7 +23,7 @@ type HPPServer struct { utilityProtocolVersion *LibraryVersion natTraversalProtocolVersion *LibraryVersion dataHandlers []func(packet PacketInterface) - passwordFromPIDHandler func(pid *PID) (string, uint32) + passwordFromPIDHandler func(pid *types.PID) (string, uint32) stringLengthSize int } @@ -80,7 +82,7 @@ func (s *HPPServer) handleRequest(w http.ResponseWriter, req *http.Request) { } client := NewHPPClient(tcpAddr, s) - client.SetPID(NewPID(uint32(pid))) + client.SetPID(types.NewPID(uint32(pid))) hppPacket, err := NewHPPPacket(client, rmcRequestBytes) if err != nil { @@ -257,7 +259,7 @@ func (s *HPPServer) NATTraversalProtocolVersion() *LibraryVersion { } // PasswordFromPID calls the function set with SetPasswordFromPIDFunction and returns the result -func (s *HPPServer) PasswordFromPID(pid *PID) (string, uint32) { +func (s *HPPServer) PasswordFromPID(pid *types.PID) (string, uint32) { if s.passwordFromPIDHandler == nil { logger.Errorf("Missing PasswordFromPID handler. Set with SetPasswordFromPIDFunction") return "", Errors.Core.NotImplemented @@ -267,7 +269,7 @@ func (s *HPPServer) PasswordFromPID(pid *PID) (string, uint32) { } // SetPasswordFromPIDFunction sets the function for HPP to get a NEX password using the PID -func (s *HPPServer) SetPasswordFromPIDFunction(handler func(pid *PID) (string, uint32)) { +func (s *HPPServer) SetPasswordFromPIDFunction(handler func(pid *types.PID) (string, uint32)) { s.passwordFromPIDHandler = handler } diff --git a/init.go b/init.go index c69cf42c..10dd3496 100644 --- a/init.go +++ b/init.go @@ -1,9 +1,19 @@ package nex -import "github.com/PretendoNetwork/plogger-go" +import ( + "github.com/PretendoNetwork/nex-go/types" + "github.com/PretendoNetwork/plogger-go" +) var logger = plogger.NewLogger() func init() { initErrorsData() + + types.RegisterVariantType(1, types.NewPrimitiveS64()) + types.RegisterVariantType(2, types.NewPrimitiveF64()) + types.RegisterVariantType(3, types.NewPrimitiveBool()) + types.RegisterVariantType(4, types.NewString()) + types.RegisterVariantType(5, types.NewDateTime(0)) + types.RegisterVariantType(6, types.NewPrimitiveU64()) } diff --git a/kerberos.go b/kerberos.go index 4671d685..0bdedb78 100644 --- a/kerberos.go +++ b/kerberos.go @@ -7,6 +7,8 @@ import ( "crypto/rc4" "errors" "fmt" + + "github.com/PretendoNetwork/nex-go/types" ) // KerberosEncryption is a struct representing a Kerberos encryption utility @@ -67,8 +69,8 @@ func NewKerberosEncryption(key []byte) *KerberosEncryption { // KerberosTicket represents a ticket granting a user access to a secure server type KerberosTicket struct { SessionKey []byte - TargetPID *PID - InternalData []byte + TargetPID *types.PID + InternalData *types.Buffer } // Encrypt writes the ticket data to the provided stream and returns the encrypted byte slice @@ -78,8 +80,8 @@ func (kt *KerberosTicket) Encrypt(key []byte, stream *StreamOut) ([]byte, error) stream.Grow(int64(len(kt.SessionKey))) stream.WriteBytesNext(kt.SessionKey) - stream.WritePID(kt.TargetPID) - stream.WriteBuffer(kt.InternalData) + kt.TargetPID.WriteTo(stream) + kt.InternalData.WriteTo(stream) return encryption.Encrypt(stream.Bytes()), nil } @@ -91,15 +93,15 @@ func NewKerberosTicket() *KerberosTicket { // KerberosTicketInternalData holds the internal data for a kerberos ticket to be processed by the server type KerberosTicketInternalData struct { - Issued *DateTime - SourcePID *PID + Issued *types.DateTime + SourcePID *types.PID SessionKey []byte } // Encrypt writes the ticket data to the provided stream and returns the encrypted byte slice func (ti *KerberosTicketInternalData) Encrypt(key []byte, stream *StreamOut) ([]byte, error) { - stream.WriteDateTime(ti.Issued) - stream.WritePID(ti.SourcePID) + ti.Issued.WriteTo(stream) + ti.SourcePID.WriteTo(stream) stream.Grow(int64(len(ti.SessionKey))) stream.WriteBytesNext(ti.SessionKey) @@ -122,8 +124,11 @@ func (ti *KerberosTicketInternalData) Encrypt(key []byte, stream *StreamOut) ([] finalStream := NewStreamOut(stream.Server) - finalStream.WriteBuffer(ticketKey) - finalStream.WriteBuffer(encrypted) + var ticketBuffer types.Buffer = ticketKey + var encryptedBuffer types.Buffer = encrypted + + ticketBuffer.WriteTo(finalStream) + encryptedBuffer.WriteTo(finalStream) return finalStream.Bytes(), nil } @@ -136,20 +141,20 @@ func (ti *KerberosTicketInternalData) Encrypt(key []byte, stream *StreamOut) ([] // Decrypt decrypts the given data and populates the struct func (ti *KerberosTicketInternalData) Decrypt(stream *StreamIn, key []byte) error { if stream.Server.(*PRUDPServer).kerberosTicketVersion == 1 { - ticketKey, err := stream.ReadBuffer() - if err != nil { + ticketKey := types.NewBuffer() + if err := ticketKey.ExtractFrom(stream); err != nil { return fmt.Errorf("Failed to read Kerberos ticket internal data key. %s", err.Error()) } - data, err := stream.ReadBuffer() - if err != nil { + data := types.NewBuffer() + if err := ticketKey.ExtractFrom(stream); err != nil { return fmt.Errorf("Failed to read Kerberos ticket internal data. %s", err.Error()) } - hash := md5.Sum(append(key, ticketKey...)) + hash := md5.Sum(append(key, *ticketKey...)) key = hash[:] - stream = NewStreamIn(data, stream.Server) + stream = NewStreamIn(*data, stream.Server) } encryption := NewKerberosEncryption(key) @@ -161,13 +166,13 @@ func (ti *KerberosTicketInternalData) Decrypt(stream *StreamIn, key []byte) erro stream = NewStreamIn(decrypted, stream.Server) - timestamp, err := stream.ReadDateTime() - if err != nil { + timestamp := types.NewDateTime(0) + if err := timestamp.ExtractFrom(stream); err != nil { return fmt.Errorf("Failed to read Kerberos ticket internal data timestamp %s", err.Error()) } - userPID, err := stream.ReadPID() - if err != nil { + userPID := types.NewPID[uint64](0) + if err := userPID.ExtractFrom(stream); err != nil { return fmt.Errorf("Failed to read Kerberos ticket internal data user PID %s", err.Error()) } @@ -184,7 +189,7 @@ func NewKerberosTicketInternalData() *KerberosTicketInternalData { } // DeriveKerberosKey derives a users kerberos encryption key based on their PID and password -func DeriveKerberosKey(pid *PID, password []byte) []byte { +func DeriveKerberosKey(pid *types.PID, password []byte) []byte { key := password for i := 0; i < 65000+int(pid.Value())%1024; i++ { diff --git a/prudp_client.go b/prudp_client.go index 53c30370..abb538e4 100644 --- a/prudp_client.go +++ b/prudp_client.go @@ -6,6 +6,7 @@ import ( "net" "time" + "github.com/PretendoNetwork/nex-go/types" "github.com/lxzan/gws" ) @@ -14,7 +15,7 @@ type PRUDPClient struct { server *PRUDPServer address net.Addr webSocketConnection *gws.Conn - pid *PID + pid *types.PID clientConnectionSignature []byte serverConnectionSignature []byte clientSessionID uint8 @@ -32,7 +33,7 @@ type PRUDPClient struct { minorVersion uint32 // * Not currently used for anything, but maybe useful later? supportedFunctions uint32 // * Not currently used for anything, but maybe useful later? ConnectionID uint32 - StationURLs []*StationURL + StationURLs []*types.StationURL unreliableBaseKey []byte } @@ -86,12 +87,12 @@ func (c *PRUDPClient) Address() net.Addr { } // PID returns the clients NEX PID -func (c *PRUDPClient) PID() *PID { +func (c *PRUDPClient) PID() *types.PID { return c.pid } // SetPID sets the clients NEX PID -func (c *PRUDPClient) SetPID(pid *PID) { +func (c *PRUDPClient) SetPID(pid *types.PID) { c.pid = pid } @@ -223,7 +224,7 @@ func NewPRUDPClient(server *PRUDPServer, address net.Addr, webSocketConnection * address: address, webSocketConnection: webSocketConnection, outgoingPingSequenceIDCounter: NewCounter[uint16](0), - pid: NewPID[uint32](0), + pid: types.NewPID[uint32](0), unreliableBaseKey: make([]byte, 0x20), } } diff --git a/prudp_packet_lite.go b/prudp_packet_lite.go index ed1d3f10..349f0872 100644 --- a/prudp_packet_lite.go +++ b/prudp_packet_lite.go @@ -71,7 +71,7 @@ func (p *PRUDPPacketLite) Version() int { // decode parses the packets data func (p *PRUDPPacketLite) decode() error { - magic, err := p.readStream.ReadUInt8() + magic, err := p.readStream.ReadPrimitiveUInt8() if err != nil { return fmt.Errorf("Failed to read PRUDPLite magic. %s", err.Error()) } @@ -80,17 +80,17 @@ func (p *PRUDPPacketLite) decode() error { return fmt.Errorf("Invalid PRUDPLite magic. Expected 0x80, got 0x%x", magic) } - p.optionsLength, err = p.readStream.ReadUInt8() + p.optionsLength, err = p.readStream.ReadPrimitiveUInt8() if err != nil { return fmt.Errorf("Failed to decode PRUDPLite options length. %s", err.Error()) } - payloadLength, err := p.readStream.ReadUInt16LE() + payloadLength, err := p.readStream.ReadPrimitiveUInt16LE() if err != nil { return fmt.Errorf("Failed to decode PRUDPLite payload length. %s", err.Error()) } - streamTypes, err := p.readStream.ReadUInt8() + streamTypes, err := p.readStream.ReadPrimitiveUInt8() if err != nil { return fmt.Errorf("Failed to decode PRUDPLite virtual ports stream types. %s", err.Error()) } @@ -98,22 +98,22 @@ func (p *PRUDPPacketLite) decode() error { p.sourceStreamType = streamTypes >> 4 p.destinationStreamType = streamTypes & 0xF - p.sourcePort, err = p.readStream.ReadUInt8() + p.sourcePort, err = p.readStream.ReadPrimitiveUInt8() if err != nil { return fmt.Errorf("Failed to decode PRUDPLite virtual source port. %s", err.Error()) } - p.destinationPort, err = p.readStream.ReadUInt8() + p.destinationPort, err = p.readStream.ReadPrimitiveUInt8() if err != nil { return fmt.Errorf("Failed to decode PRUDPLite virtual destination port. %s", err.Error()) } - p.fragmentID, err = p.readStream.ReadUInt8() + p.fragmentID, err = p.readStream.ReadPrimitiveUInt8() if err != nil { return fmt.Errorf("Failed to decode PRUDPLite fragment ID. %s", err.Error()) } - typeAndFlags, err := p.readStream.ReadUInt16LE() + typeAndFlags, err := p.readStream.ReadPrimitiveUInt16LE() if err != nil { return fmt.Errorf("Failed to read PRUDPLite type and flags. %s", err.Error()) } @@ -121,7 +121,7 @@ func (p *PRUDPPacketLite) decode() error { p.flags = typeAndFlags >> 4 p.packetType = typeAndFlags & 0xF - p.sequenceID, err = p.readStream.ReadUInt16LE() + p.sequenceID, err = p.readStream.ReadPrimitiveUInt16LE() if err != nil { return fmt.Errorf("Failed to decode PRUDPLite sequence ID. %s", err.Error()) } @@ -142,15 +142,15 @@ func (p *PRUDPPacketLite) Bytes() []byte { stream := NewStreamOut(nil) - stream.WriteUInt8(0x80) - stream.WriteUInt8(uint8(len(options))) - stream.WriteUInt16LE(uint16(len(p.payload))) - stream.WriteUInt8((p.sourceStreamType << 4) | p.destinationStreamType) - stream.WriteUInt8(p.sourcePort) - stream.WriteUInt8(p.destinationPort) - stream.WriteUInt8(p.fragmentID) - stream.WriteUInt16LE(p.packetType | (p.flags << 4)) - stream.WriteUInt16LE(p.sequenceID) + stream.WritePrimitiveUInt8(0x80) + stream.WritePrimitiveUInt8(uint8(len(options))) + stream.WritePrimitiveUInt16LE(uint16(len(p.payload))) + stream.WritePrimitiveUInt8((p.sourceStreamType << 4) | p.destinationStreamType) + stream.WritePrimitiveUInt8(p.sourcePort) + stream.WritePrimitiveUInt8(p.destinationPort) + stream.WritePrimitiveUInt8(p.fragmentID) + stream.WritePrimitiveUInt16LE(p.packetType | (p.flags << 4)) + stream.WritePrimitiveUInt16LE(p.sequenceID) stream.Grow(int64(len(options))) stream.WriteBytesNext(options) @@ -166,19 +166,19 @@ func (p *PRUDPPacketLite) decodeOptions() error { optionsStream := NewStreamIn(data, nil) for optionsStream.Remaining() > 0 { - optionID, err := optionsStream.ReadUInt8() + optionID, err := optionsStream.ReadPrimitiveUInt8() if err != nil { return err } - optionSize, err := optionsStream.ReadUInt8() // * Options size. We already know the size based on the ID, though + optionSize, err := optionsStream.ReadPrimitiveUInt8() // * Options size. We already know the size based on the ID, though if err != nil { return err } if p.packetType == SynPacket || p.packetType == ConnectPacket { if optionID == 0 { - p.supportedFunctions, err = optionsStream.ReadUInt32LE() + p.supportedFunctions, err = optionsStream.ReadPrimitiveUInt32LE() p.minorVersion = p.supportedFunctions & 0xFF p.supportedFunctions = p.supportedFunctions >> 8 @@ -189,19 +189,19 @@ func (p *PRUDPPacketLite) decodeOptions() error { } if optionID == 4 { - p.maximumSubstreamID, err = optionsStream.ReadUInt8() + p.maximumSubstreamID, err = optionsStream.ReadPrimitiveUInt8() } } if p.packetType == ConnectPacket { if optionID == 3 { - p.initialUnreliableSequenceID, err = optionsStream.ReadUInt16LE() + p.initialUnreliableSequenceID, err = optionsStream.ReadPrimitiveUInt16LE() } } if p.packetType == DataPacket { if optionID == 2 { - p.fragmentID, err = optionsStream.ReadUInt8() + p.fragmentID, err = optionsStream.ReadPrimitiveUInt8() } } @@ -226,20 +226,20 @@ func (p *PRUDPPacketLite) encodeOptions() []byte { optionsStream := NewStreamOut(nil) if p.packetType == SynPacket || p.packetType == ConnectPacket { - optionsStream.WriteUInt8(0) - optionsStream.WriteUInt8(4) - optionsStream.WriteUInt32LE(p.minorVersion | (p.supportedFunctions << 8)) + optionsStream.WritePrimitiveUInt8(0) + optionsStream.WritePrimitiveUInt8(4) + optionsStream.WritePrimitiveUInt32LE(p.minorVersion | (p.supportedFunctions << 8)) if p.packetType == SynPacket && p.HasFlag(FlagAck) { - optionsStream.WriteUInt8(1) - optionsStream.WriteUInt8(16) + optionsStream.WritePrimitiveUInt8(1) + optionsStream.WritePrimitiveUInt8(16) optionsStream.Grow(16) optionsStream.WriteBytesNext(p.connectionSignature) } if p.packetType == ConnectPacket && !p.HasFlag(FlagAck) { - optionsStream.WriteUInt8(1) - optionsStream.WriteUInt8(16) + optionsStream.WritePrimitiveUInt8(1) + optionsStream.WritePrimitiveUInt8(16) optionsStream.Grow(16) optionsStream.WriteBytesNext(p.liteSignature) } diff --git a/prudp_packet_v0.go b/prudp_packet_v0.go index adfe43c1..0a738dbf 100644 --- a/prudp_packet_v0.go +++ b/prudp_packet_v0.go @@ -68,12 +68,12 @@ func (p *PRUDPPacketV0) decode() error { server := p.server start := p.readStream.ByteOffset() - source, err := p.readStream.ReadUInt8() + source, err := p.readStream.ReadPrimitiveUInt8() if err != nil { return fmt.Errorf("Failed to read PRUDPv0 source. %s", err.Error()) } - destination, err := p.readStream.ReadUInt8() + destination, err := p.readStream.ReadPrimitiveUInt8() if err != nil { return fmt.Errorf("Failed to read PRUDPv0 destination. %s", err.Error()) } @@ -84,7 +84,7 @@ func (p *PRUDPPacketV0) decode() error { p.destinationPort = destination & 0xF if server.IsQuazalMode { - typeAndFlags, err := p.readStream.ReadUInt8() + typeAndFlags, err := p.readStream.ReadPrimitiveUInt8() if err != nil { return fmt.Errorf("Failed to read PRUDPv0 type and flags. %s", err.Error()) } @@ -92,7 +92,7 @@ func (p *PRUDPPacketV0) decode() error { p.flags = uint16(typeAndFlags >> 3) p.packetType = uint16(typeAndFlags & 7) } else { - typeAndFlags, err := p.readStream.ReadUInt16LE() + typeAndFlags, err := p.readStream.ReadPrimitiveUInt16LE() if err != nil { return fmt.Errorf("Failed to read PRUDPv0 type and flags. %s", err.Error()) } @@ -105,14 +105,14 @@ func (p *PRUDPPacketV0) decode() error { return errors.New("Invalid PRUDPv0 packet type") } - p.sessionID, err = p.readStream.ReadUInt8() + p.sessionID, err = p.readStream.ReadPrimitiveUInt8() if err != nil { return fmt.Errorf("Failed to read PRUDPv0 session ID. %s", err.Error()) } p.signature = p.readStream.ReadBytesNext(4) - p.sequenceID, err = p.readStream.ReadUInt16LE() + p.sequenceID, err = p.readStream.ReadPrimitiveUInt16LE() if err != nil { return fmt.Errorf("Failed to read PRUDPv0 sequence ID. %s", err.Error()) } @@ -130,7 +130,7 @@ func (p *PRUDPPacketV0) decode() error { return errors.New("Failed to read PRUDPv0 fragment ID. Not have enough data") } - p.fragmentID, err = p.readStream.ReadUInt8() + p.fragmentID, err = p.readStream.ReadPrimitiveUInt8() if err != nil { return fmt.Errorf("Failed to read PRUDPv0 fragment ID. %s", err.Error()) } @@ -143,7 +143,7 @@ func (p *PRUDPPacketV0) decode() error { return errors.New("Failed to read PRUDPv0 payload size. Not have enough data") } - payloadSize, err = p.readStream.ReadUInt16LE() + payloadSize, err = p.readStream.ReadPrimitiveUInt16LE() if err != nil { return fmt.Errorf("Failed to read PRUDPv0 payload size. %s", err.Error()) } @@ -156,7 +156,7 @@ func (p *PRUDPPacketV0) decode() error { } } - if p.readStream.Remaining() < int(payloadSize) { + if p.readStream.Remaining() < uint64(payloadSize) { return errors.New("Failed to read PRUDPv0 payload. Not have enough data") } @@ -174,9 +174,9 @@ func (p *PRUDPPacketV0) decode() error { var checksumU8 uint8 if server.EnhancedChecksum { - checksum, err = p.readStream.ReadUInt32LE() + checksum, err = p.readStream.ReadPrimitiveUInt32LE() } else { - checksumU8, err = p.readStream.ReadUInt8() + checksumU8, err = p.readStream.ReadPrimitiveUInt8() checksum = uint32(checksumU8) } @@ -198,19 +198,19 @@ func (p *PRUDPPacketV0) Bytes() []byte { server := p.server stream := NewStreamOut(server) - stream.WriteUInt8(p.sourcePort | (p.sourceStreamType << 4)) - stream.WriteUInt8(p.destinationPort | (p.destinationStreamType << 4)) + stream.WritePrimitiveUInt8(p.sourcePort | (p.sourceStreamType << 4)) + stream.WritePrimitiveUInt8(p.destinationPort | (p.destinationStreamType << 4)) if server.IsQuazalMode { - stream.WriteUInt8(uint8(p.packetType | (p.flags << 3))) + stream.WritePrimitiveUInt8(uint8(p.packetType | (p.flags << 3))) } else { - stream.WriteUInt16LE(p.packetType | (p.flags << 4)) + stream.WritePrimitiveUInt16LE(p.packetType | (p.flags << 4)) } - stream.WriteUInt8(p.sessionID) + stream.WritePrimitiveUInt8(p.sessionID) stream.Grow(int64(len(p.signature))) stream.WriteBytesNext(p.signature) - stream.WriteUInt16LE(p.sequenceID) + stream.WritePrimitiveUInt16LE(p.sequenceID) if p.packetType == SynPacket || p.packetType == ConnectPacket { stream.Grow(int64(len(p.connectionSignature))) @@ -218,11 +218,11 @@ func (p *PRUDPPacketV0) Bytes() []byte { } if p.packetType == DataPacket { - stream.WriteUInt8(p.fragmentID) + stream.WritePrimitiveUInt8(p.fragmentID) } if p.HasFlag(FlagHasSize) { - stream.WriteUInt16LE(uint16(len(p.payload))) + stream.WritePrimitiveUInt16LE(uint16(len(p.payload))) } if len(p.payload) > 0 { @@ -239,9 +239,9 @@ func (p *PRUDPPacketV0) Bytes() []byte { } if server.EnhancedChecksum { - stream.WriteUInt32LE(checksum) + stream.WritePrimitiveUInt32LE(checksum) } else { - stream.WriteUInt8(uint8(checksum)) + stream.WritePrimitiveUInt8(uint8(checksum)) } return stream.Bytes() diff --git a/prudp_packet_v1.go b/prudp_packet_v1.go index 7deacd87..b708e30e 100644 --- a/prudp_packet_v1.go +++ b/prudp_packet_v1.go @@ -134,7 +134,7 @@ func (p *PRUDPPacketV1) decodeHeader() error { return errors.New("Failed to read PRUDPv1 magic. Not have enough data") } - version, err := p.readStream.ReadUInt8() + version, err := p.readStream.ReadPrimitiveUInt8() if err != nil { return fmt.Errorf("Failed to decode PRUDPv1 version. %s", err.Error()) } @@ -143,22 +143,22 @@ func (p *PRUDPPacketV1) decodeHeader() error { return fmt.Errorf("Invalid PRUDPv1 version. Expected 1, got %d", version) } - p.optionsLength, err = p.readStream.ReadUInt8() + p.optionsLength, err = p.readStream.ReadPrimitiveUInt8() if err != nil { return fmt.Errorf("Failed to decode PRUDPv1 options length. %s", err.Error()) } - p.payloadLength, err = p.readStream.ReadUInt16LE() + p.payloadLength, err = p.readStream.ReadPrimitiveUInt16LE() if err != nil { return fmt.Errorf("Failed to decode PRUDPv1 payload length. %s", err.Error()) } - source, err := p.readStream.ReadUInt8() + source, err := p.readStream.ReadPrimitiveUInt8() if err != nil { return fmt.Errorf("Failed to read PRUDPv1 source. %s", err.Error()) } - destination, err := p.readStream.ReadUInt8() + destination, err := p.readStream.ReadPrimitiveUInt8() if err != nil { return fmt.Errorf("Failed to read PRUDPv1 destination. %s", err.Error()) } @@ -169,7 +169,7 @@ func (p *PRUDPPacketV1) decodeHeader() error { p.destinationPort = destination & 0xF // TODO - Does QRV also encode it this way in PRUDPv1? - typeAndFlags, err := p.readStream.ReadUInt16LE() + typeAndFlags, err := p.readStream.ReadPrimitiveUInt16LE() if err != nil { return fmt.Errorf("Failed to read PRUDPv1 type and flags. %s", err.Error()) } @@ -181,17 +181,17 @@ func (p *PRUDPPacketV1) decodeHeader() error { return errors.New("Invalid PRUDPv1 packet type") } - p.sessionID, err = p.readStream.ReadUInt8() + p.sessionID, err = p.readStream.ReadPrimitiveUInt8() if err != nil { return fmt.Errorf("Failed to read PRUDPv1 session ID. %s", err.Error()) } - p.substreamID, err = p.readStream.ReadUInt8() + p.substreamID, err = p.readStream.ReadPrimitiveUInt8() if err != nil { return fmt.Errorf("Failed to read PRUDPv1 substream ID. %s", err.Error()) } - p.sequenceID, err = p.readStream.ReadUInt16LE() + p.sequenceID, err = p.readStream.ReadPrimitiveUInt16LE() if err != nil { return fmt.Errorf("Failed to read PRUDPv1 sequence ID. %s", err.Error()) } @@ -202,15 +202,15 @@ func (p *PRUDPPacketV1) decodeHeader() error { func (p *PRUDPPacketV1) encodeHeader() []byte { stream := NewStreamOut(nil) - stream.WriteUInt8(1) // * Version - stream.WriteUInt8(p.optionsLength) - stream.WriteUInt16LE(uint16(len(p.payload))) - stream.WriteUInt8(p.sourcePort | (p.sourceStreamType << 4)) - stream.WriteUInt8(p.destinationPort | (p.destinationStreamType << 4)) - stream.WriteUInt16LE(p.packetType | (p.flags << 4)) // TODO - Does QRV also encode it this way in PRUDPv1? - stream.WriteUInt8(p.sessionID) - stream.WriteUInt8(p.substreamID) - stream.WriteUInt16LE(p.sequenceID) + stream.WritePrimitiveUInt8(1) // * Version + stream.WritePrimitiveUInt8(p.optionsLength) + stream.WritePrimitiveUInt16LE(uint16(len(p.payload))) + stream.WritePrimitiveUInt8(p.sourcePort | (p.sourceStreamType << 4)) + stream.WritePrimitiveUInt8(p.destinationPort | (p.destinationStreamType << 4)) + stream.WritePrimitiveUInt16LE(p.packetType | (p.flags << 4)) // TODO - Does QRV also encode it this way in PRUDPv1? + stream.WritePrimitiveUInt8(p.sessionID) + stream.WritePrimitiveUInt8(p.substreamID) + stream.WritePrimitiveUInt16LE(p.sequenceID) return stream.Bytes() } @@ -220,19 +220,19 @@ func (p *PRUDPPacketV1) decodeOptions() error { optionsStream := NewStreamIn(data, nil) for optionsStream.Remaining() > 0 { - optionID, err := optionsStream.ReadUInt8() + optionID, err := optionsStream.ReadPrimitiveUInt8() if err != nil { return err } - _, err = optionsStream.ReadUInt8() // * Options size. We already know the size based on the ID, though + _, err = optionsStream.ReadPrimitiveUInt8() // * Options size. We already know the size based on the ID, though if err != nil { return err } if p.packetType == SynPacket || p.packetType == ConnectPacket { if optionID == 0 { - p.supportedFunctions, err = optionsStream.ReadUInt32LE() + p.supportedFunctions, err = optionsStream.ReadPrimitiveUInt32LE() p.minorVersion = p.supportedFunctions & 0xFF p.supportedFunctions = p.supportedFunctions >> 8 @@ -243,19 +243,19 @@ func (p *PRUDPPacketV1) decodeOptions() error { } if optionID == 4 { - p.maximumSubstreamID, err = optionsStream.ReadUInt8() + p.maximumSubstreamID, err = optionsStream.ReadPrimitiveUInt8() } } if p.packetType == ConnectPacket { if optionID == 3 { - p.initialUnreliableSequenceID, err = optionsStream.ReadUInt16LE() + p.initialUnreliableSequenceID, err = optionsStream.ReadPrimitiveUInt16LE() } } if p.packetType == DataPacket { if optionID == 2 { - p.fragmentID, err = optionsStream.ReadUInt8() + p.fragmentID, err = optionsStream.ReadPrimitiveUInt8() } } @@ -274,12 +274,12 @@ func (p *PRUDPPacketV1) encodeOptions() []byte { optionsStream := NewStreamOut(nil) if p.packetType == SynPacket || p.packetType == ConnectPacket { - optionsStream.WriteUInt8(0) - optionsStream.WriteUInt8(4) - optionsStream.WriteUInt32LE(p.minorVersion | (p.supportedFunctions << 8)) + optionsStream.WritePrimitiveUInt8(0) + optionsStream.WritePrimitiveUInt8(4) + optionsStream.WritePrimitiveUInt32LE(p.minorVersion | (p.supportedFunctions << 8)) - optionsStream.WriteUInt8(1) - optionsStream.WriteUInt8(16) + optionsStream.WritePrimitiveUInt8(1) + optionsStream.WritePrimitiveUInt8(16) optionsStream.Grow(16) optionsStream.WriteBytesNext(p.connectionSignature) @@ -292,20 +292,20 @@ func (p *PRUDPPacketV1) encodeOptions() []byte { // * parsed, though, order REALLY doesn't matter. // * NintendoClients expects option 3 before 4, though if p.packetType == ConnectPacket { - optionsStream.WriteUInt8(3) - optionsStream.WriteUInt8(2) - optionsStream.WriteUInt16LE(p.initialUnreliableSequenceID) + optionsStream.WritePrimitiveUInt8(3) + optionsStream.WritePrimitiveUInt8(2) + optionsStream.WritePrimitiveUInt16LE(p.initialUnreliableSequenceID) } - optionsStream.WriteUInt8(4) - optionsStream.WriteUInt8(1) - optionsStream.WriteUInt8(p.maximumSubstreamID) + optionsStream.WritePrimitiveUInt8(4) + optionsStream.WritePrimitiveUInt8(1) + optionsStream.WritePrimitiveUInt8(p.maximumSubstreamID) } if p.packetType == DataPacket { - optionsStream.WriteUInt8(2) - optionsStream.WriteUInt8(1) - optionsStream.WriteUInt8(p.fragmentID) + optionsStream.WritePrimitiveUInt8(2) + optionsStream.WritePrimitiveUInt8(1) + optionsStream.WritePrimitiveUInt8(p.fragmentID) } return optionsStream.Bytes() diff --git a/prudp_server.go b/prudp_server.go index f17db16a..94e3d615 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -11,6 +11,7 @@ import ( "time" "github.com/PretendoNetwork/nex-go/compression" + "github.com/PretendoNetwork/nex-go/types" "github.com/lxzan/gws" ) @@ -42,7 +43,7 @@ type PRUDPServer struct { clientRemovedEventHandlers []func(client *PRUDPClient) connectionIDCounter *Counter[uint32] pingTimeout time.Duration - passwordFromPIDHandler func(pid *PID) (string, uint32) + passwordFromPIDHandler func(pid *types.PID) (string, uint32) PRUDPv1ConnectionSignatureKey []byte EnhancedChecksum bool PRUDPv0CustomChecksumCalculator func(packet *PRUDPPacketV0, data []byte) uint32 @@ -312,13 +313,13 @@ func (s *PRUDPServer) handleMultiAcknowledgment(packet PRUDPPacketInterface) { if packet.SubstreamID() == 1 { // * New aggregate acknowledgment packets set this to 1 // * and encode the real substream ID in in the payload - substreamID, _ := stream.ReadUInt8() - additionalIDsCount, _ := stream.ReadUInt8() - baseSequenceID, _ = stream.ReadUInt16LE() + substreamID, _ := stream.ReadPrimitiveUInt8() + additionalIDsCount, _ := stream.ReadPrimitiveUInt8() + baseSequenceID, _ = stream.ReadPrimitiveUInt16LE() substream = client.reliableSubstream(substreamID) for i := 0; i < int(additionalIDsCount); i++ { - additionalID, _ := stream.ReadUInt16LE() + additionalID, _ := stream.ReadPrimitiveUInt16LE() sequenceIDs = append(sequenceIDs, additionalID) } } else { @@ -329,7 +330,7 @@ func (s *PRUDPServer) handleMultiAcknowledgment(packet PRUDPPacketInterface) { baseSequenceID = packet.SequenceID() for stream.Remaining() > 0 { - additionalID, _ := stream.ReadUInt16LE() + additionalID, _ := stream.ReadPrimitiveUInt16LE() sequenceIDs = append(sequenceIDs, additionalID) } } @@ -464,8 +465,8 @@ func (s *PRUDPServer) handleConnect(packet PRUDPPacketInterface) { // * The response value is a Buffer whose data contains // * checkValue+1. This is just a lazy way of encoding // * a Buffer type - stream.WriteUInt32LE(4) // * Buffer length - stream.WriteUInt32LE(checkValue + 1) // * Buffer data + stream.WritePrimitiveUInt32LE(4) // * Buffer length + stream.WritePrimitiveUInt32LE(checkValue + 1) // * Buffer data payload = stream.Bytes() } else { @@ -511,24 +512,23 @@ func (s *PRUDPServer) handlePing(packet PRUDPPacketInterface) { } } -func (s *PRUDPServer) readKerberosTicket(payload []byte) ([]byte, *PID, uint32, error) { +func (s *PRUDPServer) readKerberosTicket(payload []byte) ([]byte, *types.PID, uint32, error) { stream := NewStreamIn(payload, s) - ticketData, err := stream.ReadBuffer() - if err != nil { + ticketData := types.NewBuffer() + if err := ticketData.ExtractFrom(stream); err != nil { return nil, nil, 0, err } - requestData, err := stream.ReadBuffer() - if err != nil { + requestData := types.NewBuffer() + if err := requestData.ExtractFrom(stream); err != nil { return nil, nil, 0, err } - serverKey := DeriveKerberosKey(NewPID[uint64](2), s.kerberosPassword) + serverKey := DeriveKerberosKey(types.NewPID[uint64](2), s.kerberosPassword) ticket := NewKerberosTicketInternalData() - err = ticket.Decrypt(NewStreamIn(ticketData, s), serverKey) - if err != nil { + if err := ticket.Decrypt(NewStreamIn([]byte(*ticketData), s), serverKey); err != nil { return nil, nil, 0, err } @@ -543,24 +543,24 @@ func (s *PRUDPServer) readKerberosTicket(payload []byte) ([]byte, *PID, uint32, sessionKey := ticket.SessionKey kerberos := NewKerberosEncryption(sessionKey) - decryptedRequestData, err := kerberos.Decrypt(requestData) + decryptedRequestData, err := kerberos.Decrypt(*requestData) if err != nil { return nil, nil, 0, err } checkDataStream := NewStreamIn(decryptedRequestData, s) - userPID, err := checkDataStream.ReadPID() - if err != nil { + userPID := types.NewPID[uint64](0) + if err := userPID.ExtractFrom(checkDataStream); err != nil { return nil, nil, 0, err } - _, err = checkDataStream.ReadUInt32LE() // * CID of secure server station url + _, err = checkDataStream.ReadPrimitiveUInt32LE() // * CID of secure server station url if err != nil { return nil, nil, 0, err } - responseCheck, err := checkDataStream.ReadUInt32LE() + responseCheck, err := checkDataStream.ReadPrimitiveUInt32LE() if err != nil { return nil, nil, 0, err } @@ -999,7 +999,7 @@ func (s *PRUDPServer) FindClientByPID(serverPort, serverStreamType uint8, pid ui virtualServerStream, _ := virtualServer.Get(serverStreamType) virtualServerStream.Each(func(discriminator string, c *PRUDPClient) bool { - if c.pid.pid == pid { + if c.pid.Value() == pid { client = c return true } @@ -1011,7 +1011,7 @@ func (s *PRUDPServer) FindClientByPID(serverPort, serverStreamType uint8, pid ui } // PasswordFromPID calls the function set with SetPasswordFromPIDFunction and returns the result -func (s *PRUDPServer) PasswordFromPID(pid *PID) (string, uint32) { +func (s *PRUDPServer) PasswordFromPID(pid *types.PID) (string, uint32) { if s.passwordFromPIDHandler == nil { logger.Errorf("Missing PasswordFromPID handler. Set with SetPasswordFromPIDFunction") return "", Errors.Core.NotImplemented @@ -1021,7 +1021,7 @@ func (s *PRUDPServer) PasswordFromPID(pid *PID) (string, uint32) { } // SetPasswordFromPIDFunction sets the function for the auth server to get a NEX password using the PID -func (s *PRUDPServer) SetPasswordFromPIDFunction(handler func(pid *PID) (string, uint32)) { +func (s *PRUDPServer) SetPasswordFromPIDFunction(handler func(pid *types.PID) (string, uint32)) { s.passwordFromPIDHandler = handler } diff --git a/rmc_message.go b/rmc_message.go index 66984bb6..14c5ba5a 100644 --- a/rmc_message.go +++ b/rmc_message.go @@ -3,23 +3,25 @@ package nex import ( "errors" "fmt" + + "github.com/PretendoNetwork/nex-go/types" ) // RMCMessage represents a message in the RMC (Remote Method Call) protocol type RMCMessage struct { Server ServerInterface - VerboseMode bool // * Determines whether or not to encode the message using the "verbose" encoding method - IsRequest bool // * Indicates if the message is a request message (true) or response message (false) - IsSuccess bool // * Indicates if the message is a success message (true) for a response message - IsHPP bool // * Indicates if the message is an HPP message - ProtocolID uint16 // * Protocol ID of the message. Only present in "packed" variations - ProtocolName string // * Protocol name of the message. Only present in "verbose" variations - CallID uint32 // * Call ID associated with the message - MethodID uint32 // * Method ID in the requested protocol. Only present in "packed" variations - MethodName string // * Method name in the requested protocol. Only present in "verbose" variations - ErrorCode uint32 // * Error code for a response message - ClassVersionContainer *ClassVersionContainer // * Contains version info for Structures in the request. Only present in "verbose" variations - Parameters []byte // * Input for the method + VerboseMode bool // * Determines whether or not to encode the message using the "verbose" encoding method + IsRequest bool // * Indicates if the message is a request message (true) or response message (false) + IsSuccess bool // * Indicates if the message is a success message (true) for a response message + IsHPP bool // * Indicates if the message is an HPP message + ProtocolID uint16 // * Protocol ID of the message. Only present in "packed" variations + ProtocolName *types.String // * Protocol name of the message. Only present in "verbose" variations + CallID uint32 // * Call ID associated with the message + MethodID uint32 // * Method ID in the requested protocol. Only present in "packed" variations + MethodName *types.String // * Method name in the requested protocol. Only present in "verbose" variations + ErrorCode uint32 // * Error code for a response message + ClassVersionContainer *types.ClassVersionContainer // * Contains version info for Structures in the request. Only present in "verbose" variations + Parameters []byte // * Input for the method // TODO - Verbose messages suffix response method names with "*". Should we have a "HasResponsePointer" sort of field? } @@ -56,16 +58,16 @@ func (rmc *RMCMessage) FromBytes(data []byte) error { func (rmc *RMCMessage) decodePacked(data []byte) error { stream := NewStreamIn(data, rmc.Server) - length, err := stream.ReadUInt32LE() + length, err := stream.ReadPrimitiveUInt32LE() if err != nil { return fmt.Errorf("Failed to read RMC Message size. %s", err.Error()) } - if stream.Remaining() != int(length) { + if stream.Remaining() != uint64(length) { return errors.New("RMC Message has unexpected size") } - protocolID, err := stream.ReadUInt8() + protocolID, err := stream.ReadPrimitiveUInt8() if err != nil { return fmt.Errorf("Failed to read RMC Message protocol ID. %s", err.Error()) } @@ -73,7 +75,7 @@ func (rmc *RMCMessage) decodePacked(data []byte) error { rmc.ProtocolID = uint16(protocolID & ^byte(0x80)) if rmc.ProtocolID == 0x7F { - rmc.ProtocolID, err = stream.ReadUInt16LE() + rmc.ProtocolID, err = stream.ReadPrimitiveUInt16LE() if err != nil { return fmt.Errorf("Failed to read RMC Message extended protocol ID. %s", err.Error()) } @@ -81,12 +83,12 @@ func (rmc *RMCMessage) decodePacked(data []byte) error { if protocolID&0x80 != 0 { rmc.IsRequest = true - rmc.CallID, err = stream.ReadUInt32LE() + rmc.CallID, err = stream.ReadPrimitiveUInt32LE() if err != nil { return fmt.Errorf("Failed to read RMC Message (request) call ID. %s", err.Error()) } - rmc.MethodID, err = stream.ReadUInt32LE() + rmc.MethodID, err = stream.ReadPrimitiveUInt32LE() if err != nil { return fmt.Errorf("Failed to read RMC Message (request) method ID. %s", err.Error()) } @@ -97,18 +99,18 @@ func (rmc *RMCMessage) decodePacked(data []byte) error { } } else { rmc.IsRequest = false - rmc.IsSuccess, err = stream.ReadBool() + rmc.IsSuccess, err = stream.ReadPrimitiveBool() if err != nil { return fmt.Errorf("Failed to read RMC Message (response) error check. %s", err.Error()) } if rmc.IsSuccess { - rmc.CallID, err = stream.ReadUInt32LE() + rmc.CallID, err = stream.ReadPrimitiveUInt32LE() if err != nil { return fmt.Errorf("Failed to read RMC Message (response) call ID. %s", err.Error()) } - rmc.MethodID, err = stream.ReadUInt32LE() + rmc.MethodID, err = stream.ReadPrimitiveUInt32LE() if err != nil { return fmt.Errorf("Failed to read RMC Message (response) method ID. %s", err.Error()) } @@ -124,12 +126,12 @@ func (rmc *RMCMessage) decodePacked(data []byte) error { } } else { - rmc.ErrorCode, err = stream.ReadUInt32LE() + rmc.ErrorCode, err = stream.ReadPrimitiveUInt32LE() if err != nil { return fmt.Errorf("Failed to read RMC Message (response) error code. %s", err.Error()) } - rmc.CallID, err = stream.ReadUInt32LE() + rmc.CallID, err = stream.ReadPrimitiveUInt32LE() if err != nil { return fmt.Errorf("Failed to read RMC Message (response) call ID. %s", err.Error()) } @@ -143,38 +145,38 @@ func (rmc *RMCMessage) decodePacked(data []byte) error { func (rmc *RMCMessage) decodeVerbose(data []byte) error { stream := NewStreamIn(data, rmc.Server) - length, err := stream.ReadUInt32LE() + length, err := stream.ReadPrimitiveUInt32LE() if err != nil { return fmt.Errorf("Failed to read RMC Message size. %s", err.Error()) } - if stream.Remaining() != int(length) { + if stream.Remaining() != uint64(length) { return errors.New("RMC Message has unexpected size") } - rmc.ProtocolName, err = stream.ReadString() - if err != nil { + rmc.ProtocolName = types.NewString() + if err := rmc.ProtocolName.ExtractFrom(stream); err != nil { return fmt.Errorf("Failed to read RMC Message protocol name. %s", err.Error()) } - rmc.IsRequest, err = stream.ReadBool() + rmc.IsRequest, err = stream.ReadPrimitiveBool() if err != nil { return fmt.Errorf("Failed to read RMC Message \"is request\" bool. %s", err.Error()) } if rmc.IsRequest { - rmc.CallID, err = stream.ReadUInt32LE() + rmc.CallID, err = stream.ReadPrimitiveUInt32LE() if err != nil { return fmt.Errorf("Failed to read RMC Message (request) call ID. %s", err.Error()) } - rmc.MethodName, err = stream.ReadString() - if err != nil { + rmc.MethodName = types.NewString() + if err := rmc.MethodName.ExtractFrom(stream); err != nil { return fmt.Errorf("Failed to read RMC Message (request) method name. %s", err.Error()) } - rmc.ClassVersionContainer, err = StreamReadStructure(stream, NewClassVersionContainer()) - if err != nil { + rmc.ClassVersionContainer = types.NewClassVersionContainer() + if err := rmc.ClassVersionContainer.ExtractFrom(stream); err != nil { return fmt.Errorf("Failed to read RMC Message ClassVersionContainer. %s", err.Error()) } @@ -183,19 +185,19 @@ func (rmc *RMCMessage) decodeVerbose(data []byte) error { return fmt.Errorf("Failed to read RMC Message (request) parameters. %s", err.Error()) } } else { - rmc.IsSuccess, err = stream.ReadBool() + rmc.IsSuccess, err = stream.ReadPrimitiveBool() if err != nil { return fmt.Errorf("Failed to read RMC Message (response) error check. %s", err.Error()) } if rmc.IsSuccess { - rmc.CallID, err = stream.ReadUInt32LE() + rmc.CallID, err = stream.ReadPrimitiveUInt32LE() if err != nil { return fmt.Errorf("Failed to read RMC Message (response) call ID. %s", err.Error()) } - rmc.MethodName, err = stream.ReadString() - if err != nil { + rmc.MethodName = types.NewString() + if err := rmc.MethodName.ExtractFrom(stream); err != nil { return fmt.Errorf("Failed to read RMC Message (response) method name. %s", err.Error()) } @@ -205,12 +207,12 @@ func (rmc *RMCMessage) decodeVerbose(data []byte) error { } } else { - rmc.ErrorCode, err = stream.ReadUInt32LE() + rmc.ErrorCode, err = stream.ReadPrimitiveUInt32LE() if err != nil { return fmt.Errorf("Failed to read RMC Message (response) error code. %s", err.Error()) } - rmc.CallID, err = stream.ReadUInt32LE() + rmc.CallID, err = stream.ReadPrimitiveUInt32LE() if err != nil { return fmt.Errorf("Failed to read RMC Message (response) call ID. %s", err.Error()) } @@ -243,35 +245,35 @@ func (rmc *RMCMessage) encodePacked() []byte { // * do it for accuracy. if !rmc.IsHPP || (rmc.IsHPP && rmc.IsRequest) { if rmc.ProtocolID < 0x80 { - stream.WriteUInt8(uint8(rmc.ProtocolID | protocolIDFlag)) + stream.WritePrimitiveUInt8(uint8(rmc.ProtocolID | protocolIDFlag)) } else { - stream.WriteUInt8(uint8(0x7F | protocolIDFlag)) - stream.WriteUInt16LE(rmc.ProtocolID) + stream.WritePrimitiveUInt8(uint8(0x7F | protocolIDFlag)) + stream.WritePrimitiveUInt16LE(rmc.ProtocolID) } } if rmc.IsRequest { - stream.WriteUInt32LE(rmc.CallID) - stream.WriteUInt32LE(rmc.MethodID) + stream.WritePrimitiveUInt32LE(rmc.CallID) + stream.WritePrimitiveUInt32LE(rmc.MethodID) if rmc.Parameters != nil && len(rmc.Parameters) > 0 { stream.Grow(int64(len(rmc.Parameters))) stream.WriteBytesNext(rmc.Parameters) } } else { - stream.WriteBool(rmc.IsSuccess) + stream.WritePrimitiveBool(rmc.IsSuccess) if rmc.IsSuccess { - stream.WriteUInt32LE(rmc.CallID) - stream.WriteUInt32LE(rmc.MethodID | 0x8000) + stream.WritePrimitiveUInt32LE(rmc.CallID) + stream.WritePrimitiveUInt32LE(rmc.MethodID | 0x8000) if rmc.Parameters != nil && len(rmc.Parameters) > 0 { stream.Grow(int64(len(rmc.Parameters))) stream.WriteBytesNext(rmc.Parameters) } } else { - stream.WriteUInt32LE(uint32(rmc.ErrorCode)) - stream.WriteUInt32LE(rmc.CallID) + stream.WritePrimitiveUInt32LE(uint32(rmc.ErrorCode)) + stream.WritePrimitiveUInt32LE(rmc.CallID) } } @@ -279,7 +281,7 @@ func (rmc *RMCMessage) encodePacked() []byte { message := NewStreamOut(rmc.Server) - message.WriteUInt32LE(uint32(len(serialized))) + message.WritePrimitiveUInt32LE(uint32(len(serialized))) message.Grow(int64(len(serialized))) message.WriteBytesNext(serialized) @@ -289,18 +291,18 @@ func (rmc *RMCMessage) encodePacked() []byte { func (rmc *RMCMessage) encodeVerbose() []byte { stream := NewStreamOut(rmc.Server) - stream.WriteString(rmc.ProtocolName) - stream.WriteBool(rmc.IsRequest) + rmc.ProtocolName.WriteTo(stream) + stream.WritePrimitiveBool(rmc.IsRequest) if rmc.IsRequest { - stream.WriteUInt32LE(rmc.CallID) - stream.WriteString(rmc.MethodName) + stream.WritePrimitiveUInt32LE(rmc.CallID) + rmc.MethodName.WriteTo(stream) if rmc.ClassVersionContainer != nil { - stream.WriteStructure(rmc.ClassVersionContainer) + rmc.ClassVersionContainer.WriteTo(stream) } else { // * Fail safe. This is always present even if no structures are used - stream.WriteUInt32LE(0) + stream.WritePrimitiveUInt32LE(0) } if rmc.Parameters != nil && len(rmc.Parameters) > 0 { @@ -308,19 +310,19 @@ func (rmc *RMCMessage) encodeVerbose() []byte { stream.WriteBytesNext(rmc.Parameters) } } else { - stream.WriteBool(rmc.IsSuccess) + stream.WritePrimitiveBool(rmc.IsSuccess) if rmc.IsSuccess { - stream.WriteUInt32LE(rmc.CallID) - stream.WriteString(rmc.MethodName) + stream.WritePrimitiveUInt32LE(rmc.CallID) + rmc.MethodName.WriteTo(stream) if rmc.Parameters != nil && len(rmc.Parameters) > 0 { stream.Grow(int64(len(rmc.Parameters))) stream.WriteBytesNext(rmc.Parameters) } } else { - stream.WriteUInt32LE(uint32(rmc.ErrorCode)) - stream.WriteUInt32LE(rmc.CallID) + stream.WritePrimitiveUInt32LE(uint32(rmc.ErrorCode)) + stream.WritePrimitiveUInt32LE(rmc.CallID) } } @@ -328,7 +330,7 @@ func (rmc *RMCMessage) encodeVerbose() []byte { message := NewStreamOut(rmc.Server) - message.WriteUInt32LE(uint32(len(serialized))) + message.WritePrimitiveUInt32LE(uint32(len(serialized))) message.Grow(int64(len(serialized))) message.WriteBytesNext(serialized) diff --git a/server_interface.go b/server_interface.go index f61fca74..8733c804 100644 --- a/server_interface.go +++ b/server_interface.go @@ -1,5 +1,7 @@ package nex +import "github.com/PretendoNetwork/nex-go/types" + // ServerInterface defines all the methods a server should have regardless of type type ServerInterface interface { AccessKey() string @@ -15,8 +17,8 @@ type ServerInterface interface { SetDefaultLibraryVersion(version *LibraryVersion) Send(packet PacketInterface) OnData(handler func(packet PacketInterface)) - PasswordFromPID(pid *PID) (string, uint32) - SetPasswordFromPIDFunction(handler func(pid *PID) (string, uint32)) + PasswordFromPID(pid *types.PID) (string, uint32) + SetPasswordFromPIDFunction(handler func(pid *types.PID) (string, uint32)) StringLengthSize() int SetStringLengthSize(size int) } diff --git a/stream_in.go b/stream_in.go index b48df3d9..ab826517 100644 --- a/stream_in.go +++ b/stream_in.go @@ -2,8 +2,6 @@ package nex import ( "errors" - "fmt" - "strings" crunch "github.com/superwhiskers/crunch/v3" ) @@ -14,109 +12,95 @@ type StreamIn struct { Server ServerInterface } -// Remaining returns the amount of data left to be read in the buffer -func (s *StreamIn) Remaining() int { - return len(s.Bytes()[s.ByteOffset():]) -} +// StringLengthSize returns the expected size of String length fields +func (s *StreamIn) StringLengthSize() int { + size := 2 -// ReadRemaining reads all the data left to be read in the buffer -func (s *StreamIn) ReadRemaining() []byte { - // TODO - Should we do a bounds check here? Or just allow empty slices? - return s.ReadBytesNext(int64(s.Remaining())) -} - -// ReadUInt8 reads a uint8 -func (s *StreamIn) ReadUInt8() (uint8, error) { - if s.Remaining() < 1 { - return 0, errors.New("Not enough data to read uint8") + if s.Server != nil { + size = s.Server.StringLengthSize() } - return uint8(s.ReadByteNext()), nil + return size } -// ReadInt8 reads a uint8 -func (s *StreamIn) ReadInt8() (int8, error) { - if s.Remaining() < 1 { - return 0, errors.New("Not enough data to read int8") - } - - return int8(s.ReadByteNext()), nil -} +// PIDSize returns the size of PID types +func (s *StreamIn) PIDSize() int { + size := 4 -// ReadUInt16LE reads a Little-Endian encoded uint16 -func (s *StreamIn) ReadUInt16LE() (uint16, error) { - if s.Remaining() < 2 { - return 0, errors.New("Not enough data to read uint16") + if s.Server != nil && s.Server.LibraryVersion().GreaterOrEqual("4.0.0") { + size = 8 } - return s.ReadU16LENext(1)[0], nil + return size } -// ReadUInt16BE reads a Big-Endian encoded uint16 -func (s *StreamIn) ReadUInt16BE() (uint16, error) { - if s.Remaining() < 2 { - return 0, errors.New("Not enough data to read uint16") +// UseStructureHeader determines if Structure headers should be used +func (s *StreamIn) UseStructureHeader() bool { + useStructureHeader := false + + if s.Server != nil { + switch server := s.Server.(type) { + case *PRUDPServer: // * Support QRV versions + useStructureHeader = server.PRUDPMinorVersion >= 3 + default: + useStructureHeader = server.LibraryVersion().GreaterOrEqual("3.5.0") + } } - return s.ReadU16BENext(1)[0], nil + return useStructureHeader } -// ReadInt16LE reads a Little-Endian encoded int16 -func (s *StreamIn) ReadInt16LE() (int16, error) { - if s.Remaining() < 2 { - return 0, errors.New("Not enough data to read int16") - } - - return int16(s.ReadU16LENext(1)[0]), nil +// Remaining returns the amount of data left to be read in the buffer +func (s *StreamIn) Remaining() uint64 { + return uint64(len(s.Bytes()[s.ByteOffset():])) } -// ReadInt16BE reads a Big-Endian encoded int16 -func (s *StreamIn) ReadInt16BE() (int16, error) { - if s.Remaining() < 2 { - return 0, errors.New("Not enough data to read int16") - } +// ReadRemaining reads all the data left to be read in the buffer +func (s *StreamIn) ReadRemaining() []byte { + // * Can safely ignore this error, since s.Remaining() will never be less than itself + remaining, _ := s.Read(uint64(s.Remaining())) - return int16(s.ReadU16BENext(1)[0]), nil + return remaining } -// ReadUInt32LE reads a Little-Endian encoded uint32 -func (s *StreamIn) ReadUInt32LE() (uint32, error) { - if s.Remaining() < 4 { - return 0, errors.New("Not enough data to read uint32") +// Read reads the specified number of bytes. Returns an error if OOB +func (s *StreamIn) Read(length uint64) ([]byte, error) { + if s.Remaining() < length { + return []byte{}, errors.New("Read is OOB") } - return s.ReadU32LENext(1)[0], nil + return s.ReadBytesNext(int64(length)), nil } -// ReadUInt32BE reads a Big-Endian encoded uint32 -func (s *StreamIn) ReadUInt32BE() (uint32, error) { - if s.Remaining() < 4 { - return 0, errors.New("Not enough data to read uint32") +// ReadPrimitiveUInt8 reads a uint8 +func (s *StreamIn) ReadPrimitiveUInt8() (uint8, error) { + if s.Remaining() < 1 { + return 0, errors.New("Not enough data to read uint8") } - return s.ReadU32BENext(1)[0], nil + return uint8(s.ReadByteNext()), nil } -// ReadInt32LE reads a Little-Endian encoded int32 -func (s *StreamIn) ReadInt32LE() (int32, error) { - if s.Remaining() < 4 { - return 0, errors.New("Not enough data to read int32") +// ReadPrimitiveUInt16LE reads a Little-Endian encoded uint16 +func (s *StreamIn) ReadPrimitiveUInt16LE() (uint16, error) { + if s.Remaining() < 2 { + return 0, errors.New("Not enough data to read uint16") } - return int32(s.ReadU32LENext(1)[0]), nil + return s.ReadU16LENext(1)[0], nil } -// ReadInt32BE reads a Big-Endian encoded int32 -func (s *StreamIn) ReadInt32BE() (int32, error) { +// ReadPrimitiveUInt32LE reads a Little-Endian encoded uint32 +func (s *StreamIn) ReadPrimitiveUInt32LE() (uint32, error) { if s.Remaining() < 4 { - return 0, errors.New("Not enough data to read int32") + return 0, errors.New("Not enough data to read uint32") } - return int32(s.ReadU32BENext(1)[0]), nil + return s.ReadU32LENext(1)[0], nil } -// ReadUInt64LE reads a Little-Endian encoded uint64 -func (s *StreamIn) ReadUInt64LE() (uint64, error) { +// ReadPrimitiveUInt64LE reads a Little-Endian encoded uint64 +func (s *StreamIn) ReadPrimitiveUInt64LE() (uint64, error) { if s.Remaining() < 8 { return 0, errors.New("Not enough data to read uint64") } @@ -124,53 +108,53 @@ func (s *StreamIn) ReadUInt64LE() (uint64, error) { return s.ReadU64LENext(1)[0], nil } -// ReadUInt64BE reads a Big-Endian encoded uint64 -func (s *StreamIn) ReadUInt64BE() (uint64, error) { - if s.Remaining() < 8 { - return 0, errors.New("Not enough data to read uint64") +// ReadPrimitiveInt8 reads a uint8 +func (s *StreamIn) ReadPrimitiveInt8() (int8, error) { + if s.Remaining() < 1 { + return 0, errors.New("Not enough data to read int8") } - return s.ReadU64BENext(1)[0], nil + return int8(s.ReadByteNext()), nil } -// ReadInt64LE reads a Little-Endian encoded int64 -func (s *StreamIn) ReadInt64LE() (int64, error) { - if s.Remaining() < 8 { - return 0, errors.New("Not enough data to read int64") +// ReadPrimitiveInt16LE reads a Little-Endian encoded int16 +func (s *StreamIn) ReadPrimitiveInt16LE() (int16, error) { + if s.Remaining() < 2 { + return 0, errors.New("Not enough data to read int16") } - return int64(s.ReadU64LENext(1)[0]), nil + return int16(s.ReadU16LENext(1)[0]), nil } -// ReadInt64BE reads a Big-Endian encoded int64 -func (s *StreamIn) ReadInt64BE() (int64, error) { - if s.Remaining() < 8 { - return 0, errors.New("Not enough data to read int64") +// ReadPrimitiveInt32LE reads a Little-Endian encoded int32 +func (s *StreamIn) ReadPrimitiveInt32LE() (int32, error) { + if s.Remaining() < 4 { + return 0, errors.New("Not enough data to read int32") } - return int64(s.ReadU64BENext(1)[0]), nil + return int32(s.ReadU32LENext(1)[0]), nil } -// ReadFloat32LE reads a Little-Endian encoded float32 -func (s *StreamIn) ReadFloat32LE() (float32, error) { - if s.Remaining() < 4 { - return 0, errors.New("Not enough data to read float32") +// ReadPrimitiveInt64LE reads a Little-Endian encoded int64 +func (s *StreamIn) ReadPrimitiveInt64LE() (int64, error) { + if s.Remaining() < 8 { + return 0, errors.New("Not enough data to read int64") } - return s.ReadF32LENext(1)[0], nil + return int64(s.ReadU64LENext(1)[0]), nil } -// ReadFloat32BE reads a Big-Endian encoded float32 -func (s *StreamIn) ReadFloat32BE() (float32, error) { +// ReadPrimitiveFloat32LE reads a Little-Endian encoded float32 +func (s *StreamIn) ReadPrimitiveFloat32LE() (float32, error) { if s.Remaining() < 4 { return 0, errors.New("Not enough data to read float32") } - return s.ReadF32BENext(1)[0], nil + return s.ReadF32LENext(1)[0], nil } -// ReadFloat64LE reads a Little-Endian encoded float64 -func (s *StreamIn) ReadFloat64LE() (float64, error) { +// ReadPrimitiveFloat64LE reads a Little-Endian encoded float64 +func (s *StreamIn) ReadPrimitiveFloat64LE() (float64, error) { if s.Remaining() < 8 { return 0, errors.New("Not enough data to read float64") } @@ -178,17 +162,8 @@ func (s *StreamIn) ReadFloat64LE() (float64, error) { return s.ReadF64LENext(1)[0], nil } -// ReadFloat64BE reads a Big-Endian encoded float64 -func (s *StreamIn) ReadFloat64BE() (float64, error) { - if s.Remaining() < 8 { - return 0, errors.New("Not enough data to read float64") - } - - return s.ReadF64BENext(1)[0], nil -} - -// ReadBool reads a bool -func (s *StreamIn) ReadBool() (bool, error) { +// ReadPrimitiveBool reads a bool +func (s *StreamIn) ReadPrimitiveBool() (bool, error) { if s.Remaining() < 1 { return false, errors.New("Not enough data to read bool") } @@ -196,745 +171,6 @@ func (s *StreamIn) ReadBool() (bool, error) { return s.ReadByteNext() == 1, nil } -// ReadPID reads a PID. The size depends on the server version -func (s *StreamIn) ReadPID() (*PID, error) { - if s.Server.LibraryVersion().GreaterOrEqual("4.0.0") { - if s.Remaining() < 8 { - return nil, errors.New("Not enough data to read PID") - } - - pid, _ := s.ReadUInt64LE() - - return NewPID(pid), nil - } else { - if s.Remaining() < 4 { - return nil, errors.New("Not enough data to read legacy PID") - } - - pid, _ := s.ReadUInt32LE() - - return NewPID(pid), nil - } -} - -// ReadString reads and returns a nex string type -func (s *StreamIn) ReadString() (string, error) { - var length int64 - var err error - - // TODO - These variable names kinda suck? - if s.Server == nil { - l, e := s.ReadUInt16LE() - length = int64(l) - err = e - } else if s.Server.StringLengthSize() == 4 { - l, e := s.ReadUInt32LE() - length = int64(l) - err = e - } else { - l, e := s.ReadUInt16LE() - length = int64(l) - err = e - } - - if err != nil { - return "", fmt.Errorf("Failed to read NEX string length. %s", err.Error()) - } - - if s.Remaining() < int(length) { - return "", errors.New("NEX string length longer than data size") - } - - stringData := s.ReadBytesNext(length) - str := string(stringData) - - return strings.TrimRight(str, "\x00"), nil -} - -// ReadBuffer reads a nex Buffer type -func (s *StreamIn) ReadBuffer() ([]byte, error) { - length, err := s.ReadUInt32LE() - if err != nil { - return []byte{}, fmt.Errorf("Failed to read NEX buffer length. %s", err.Error()) - } - - if s.Remaining() < int(length) { - return []byte{}, errors.New("NEX buffer length longer than data size") - } - - data := s.ReadBytesNext(int64(length)) - - return data, nil -} - -// ReadQBuffer reads a nex qBuffer type -func (s *StreamIn) ReadQBuffer() ([]byte, error) { - length, err := s.ReadUInt16LE() - if err != nil { - return []byte{}, fmt.Errorf("Failed to read NEX qBuffer length. %s", err.Error()) - } - - if s.Remaining() < int(length) { - return []byte{}, errors.New("NEX qBuffer length longer than data size") - } - - data := s.ReadBytesNext(int64(length)) - - return data, nil -} - -// ReadVariant reads a Variant type. This type can hold 7 different types -func (s *StreamIn) ReadVariant() (*Variant, error) { - variant := NewVariant() - - err := variant.ExtractFromStream(s) - if err != nil { - return nil, fmt.Errorf("Failed to read Variant. %s", err.Error()) - } - - return variant, nil -} - -// ReadDateTime reads a DateTime type -func (s *StreamIn) ReadDateTime() (*DateTime, error) { - value, err := s.ReadUInt64LE() - if err != nil { - return nil, fmt.Errorf("Failed to read DateTime value. %s", err.Error()) - } - - return NewDateTime(value), nil -} - -// ReadDataHolder reads a DataHolder type -func (s *StreamIn) ReadDataHolder() (*DataHolder, error) { - dataHolder := NewDataHolder() - err := dataHolder.ExtractFromStream(s) - if err != nil { - return nil, fmt.Errorf("Failed to read DateHolder. %s", err.Error()) - } - - return dataHolder, nil -} - -// ReadStationURL reads a StationURL type -func (s *StreamIn) ReadStationURL() (*StationURL, error) { - stationString, err := s.ReadString() - if err != nil { - return nil, fmt.Errorf("Failed to read StationURL. %s", err.Error()) - } - - return NewStationURL(stationString), nil -} - -// ReadQUUID reads a qUUID type -func (s *StreamIn) ReadQUUID() (*QUUID, error) { - qUUID := NewQUUID() - - err := qUUID.ExtractFromStream(s) - if err != nil { - return nil, fmt.Errorf("Failed to read qUUID. %s", err.Error()) - } - - return qUUID, nil -} - -// ReadListUInt8 reads a list of uint8 types -func (s *StreamIn) ReadListUInt8() ([]uint8, error) { - length, err := s.ReadUInt32LE() - if err != nil { - return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) - } - - if s.Remaining() < int(length) { - return nil, errors.New("NEX List length longer than data size") - } - - list := make([]uint8, 0, length) - - for i := 0; i < int(length); i++ { - value, err := s.ReadUInt8() - if err != nil { - return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) - } - - list = append(list, value) - } - - return list, nil -} - -// ReadListInt8 reads a list of int8 types -func (s *StreamIn) ReadListInt8() ([]int8, error) { - length, err := s.ReadUInt32LE() - if err != nil { - return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) - } - - if s.Remaining() < int(length) { - return nil, errors.New("NEX List length longer than data size") - } - - list := make([]int8, 0, length) - - for i := 0; i < int(length); i++ { - value, err := s.ReadInt8() - if err != nil { - return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) - } - - list = append(list, value) - } - - return list, nil -} - -// ReadListUInt16LE reads a list of Little-Endian encoded uint16 types -func (s *StreamIn) ReadListUInt16LE() ([]uint16, error) { - length, err := s.ReadUInt32LE() - if err != nil { - return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) - } - - if s.Remaining() < int(length*2) { - return nil, errors.New("NEX List length longer than data size") - } - - list := make([]uint16, 0, length) - - for i := 0; i < int(length); i++ { - value, err := s.ReadUInt16LE() - if err != nil { - return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) - } - - list = append(list, value) - } - - return list, nil -} - -// ReadListUInt16BE reads a list of Big-Endian encoded uint16 types -func (s *StreamIn) ReadListUInt16BE() ([]uint16, error) { - length, err := s.ReadUInt32LE() - if err != nil { - return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) - } - - if s.Remaining() < int(length*2) { - return nil, errors.New("NEX List length longer than data size") - } - - list := make([]uint16, 0, length) - - for i := 0; i < int(length); i++ { - value, err := s.ReadUInt16BE() - if err != nil { - return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) - } - - list = append(list, value) - } - - return list, nil -} - -// ReadListInt16LE reads a list of Little-Endian encoded int16 types -func (s *StreamIn) ReadListInt16LE() ([]int16, error) { - length, err := s.ReadUInt32LE() - if err != nil { - return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) - } - - if s.Remaining() < int(length*2) { - return nil, errors.New("NEX List length longer than data size") - } - - list := make([]int16, 0, length) - - for i := 0; i < int(length); i++ { - value, err := s.ReadInt16LE() - if err != nil { - return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) - } - - list = append(list, value) - } - - return list, nil -} - -// ReadListInt16BE reads a list of Big-Endian encoded uint16 types -func (s *StreamIn) ReadListInt16BE() ([]int16, error) { - length, err := s.ReadUInt32LE() - if err != nil { - return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) - } - - if s.Remaining() < int(length*2) { - return nil, errors.New("NEX List length longer than data size") - } - - list := make([]int16, 0, length) - - for i := 0; i < int(length); i++ { - value, err := s.ReadInt16BE() - if err != nil { - return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) - } - - list = append(list, value) - } - - return list, nil -} - -// ReadListUInt32LE reads a list of Little-Endian encoded uint32 types -func (s *StreamIn) ReadListUInt32LE() ([]uint32, error) { - length, err := s.ReadUInt32LE() - if err != nil { - return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) - } - - if s.Remaining() < int(length*4) { - return nil, errors.New("NEX List length longer than data size") - } - - list := make([]uint32, 0, length) - - for i := 0; i < int(length); i++ { - value, err := s.ReadUInt32LE() - if err != nil { - return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) - } - - list = append(list, value) - } - - return list, nil -} - -// ReadListUInt32BE reads a list of Big-Endian encoded uint32 types -func (s *StreamIn) ReadListUInt32BE() ([]uint32, error) { - length, err := s.ReadUInt32LE() - if err != nil { - return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) - } - - if s.Remaining() < int(length*4) { - return nil, errors.New("NEX List length longer than data size") - } - - list := make([]uint32, 0, length) - - for i := 0; i < int(length); i++ { - value, err := s.ReadUInt32BE() - if err != nil { - return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) - } - - list = append(list, value) - } - - return list, nil -} - -// ReadListInt32LE reads a list of Little-Endian encoded int32 types -func (s *StreamIn) ReadListInt32LE() ([]int32, error) { - length, err := s.ReadUInt32LE() - if err != nil { - return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) - } - - if s.Remaining() < int(length*4) { - return nil, errors.New("NEX List length longer than data size") - } - - list := make([]int32, 0, length) - - for i := 0; i < int(length); i++ { - value, err := s.ReadInt32LE() - if err != nil { - return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) - } - - list = append(list, value) - } - - return list, nil -} - -// ReadListInt32BE reads a list of Big-Endian encoded int32 types -func (s *StreamIn) ReadListInt32BE() ([]int32, error) { - length, err := s.ReadUInt32LE() - if err != nil { - return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) - } - - if s.Remaining() < int(length*4) { - return nil, errors.New("NEX List length longer than data size") - } - - list := make([]int32, 0, length) - - for i := 0; i < int(length); i++ { - value, err := s.ReadInt32BE() - if err != nil { - return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) - } - - list = append(list, value) - } - - return list, nil -} - -// ReadListUInt64LE reads a list of Little-Endian encoded uint64 types -func (s *StreamIn) ReadListUInt64LE() ([]uint64, error) { - length, err := s.ReadUInt32LE() - if err != nil { - return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) - } - - if s.Remaining() < int(length*8) { - return nil, errors.New("NEX List length longer than data size") - } - - list := make([]uint64, 0, length) - - for i := 0; i < int(length); i++ { - value, err := s.ReadUInt64LE() - if err != nil { - return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) - } - - list = append(list, value) - } - - return list, nil -} - -// ReadListUInt64BE reads a list of Big-Endian encoded uint64 types -func (s *StreamIn) ReadListUInt64BE() ([]uint64, error) { - length, err := s.ReadUInt32LE() - if err != nil { - return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) - } - - if s.Remaining() < int(length*8) { - return nil, errors.New("NEX List length longer than data size") - } - - list := make([]uint64, 0, length) - - for i := 0; i < int(length); i++ { - value, err := s.ReadUInt64BE() - if err != nil { - return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) - } - - list = append(list, value) - } - - return list, nil -} - -// ReadListInt64LE reads a list of Little-Endian encoded int64 types -func (s *StreamIn) ReadListInt64LE() ([]int64, error) { - length, err := s.ReadUInt32LE() - if err != nil { - return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) - } - - if s.Remaining() < int(length*8) { - return nil, errors.New("NEX List length longer than data size") - } - - list := make([]int64, 0, length) - - for i := 0; i < int(length); i++ { - value, err := s.ReadInt64LE() - if err != nil { - return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) - } - - list = append(list, value) - } - - return list, nil -} - -// ReadListInt64BE reads a list of Big-Endian encoded int64 types -func (s *StreamIn) ReadListInt64BE() ([]int64, error) { - length, err := s.ReadUInt32LE() - if err != nil { - return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) - } - - if s.Remaining() < int(length*8) { - return nil, errors.New("NEX List length longer than data size") - } - - list := make([]int64, 0, length) - - for i := 0; i < int(length); i++ { - value, err := s.ReadInt64BE() - if err != nil { - return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) - } - - list = append(list, value) - } - - return list, nil -} - -// ReadListFloat32LE reads a list of Little-Endian encoded float32 types -func (s *StreamIn) ReadListFloat32LE() ([]float32, error) { - length, err := s.ReadUInt32LE() - if err != nil { - return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) - } - - if s.Remaining() < int(length*4) { - return nil, errors.New("NEX List length longer than data size") - } - - list := make([]float32, 0, length) - - for i := 0; i < int(length); i++ { - value, err := s.ReadFloat32LE() - if err != nil { - return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) - } - - list = append(list, value) - } - - return list, nil -} - -// ReadListFloat32BE reads a list of Big-Endian encoded float32 types -func (s *StreamIn) ReadListFloat32BE() ([]float32, error) { - length, err := s.ReadUInt32LE() - if err != nil { - return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) - } - - if s.Remaining() < int(length*4) { - return nil, errors.New("NEX List length longer than data size") - } - - list := make([]float32, 0, length) - - for i := 0; i < int(length); i++ { - value, err := s.ReadFloat32BE() - if err != nil { - return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) - } - - list = append(list, value) - } - - return list, nil -} - -// ReadListFloat64LE reads a list of Little-Endian encoded float64 types -func (s *StreamIn) ReadListFloat64LE() ([]float64, error) { - length, err := s.ReadUInt32LE() - if err != nil { - return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) - } - - if s.Remaining() < int(length*4) { - return nil, errors.New("NEX List length longer than data size") - } - - list := make([]float64, 0, length) - - for i := 0; i < int(length); i++ { - value, err := s.ReadFloat64LE() - if err != nil { - return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) - } - - list = append(list, value) - } - - return list, nil -} - -// ReadListFloat64BE reads a list of Big-Endian encoded float64 types -func (s *StreamIn) ReadListFloat64BE() ([]float64, error) { - length, err := s.ReadUInt32LE() - if err != nil { - return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) - } - - if s.Remaining() < int(length*4) { - return nil, errors.New("NEX List length longer than data size") - } - - list := make([]float64, 0, length) - - for i := 0; i < int(length); i++ { - value, err := s.ReadFloat64BE() - if err != nil { - return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) - } - - list = append(list, value) - } - - return list, nil -} - -// ReadListPID reads a list of NEX PIDs -func (s *StreamIn) ReadListPID() ([]*PID, error) { - length, err := s.ReadUInt32LE() - if err != nil { - return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) - } - - list := make([]*PID, 0, length) - - for i := 0; i < int(length); i++ { - value, err := s.ReadPID() - if err != nil { - return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) - } - - list = append(list, value) - } - - return list, nil -} - -// ReadListString reads a list of NEX String types -func (s *StreamIn) ReadListString() ([]string, error) { - length, err := s.ReadUInt32LE() - if err != nil { - return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) - } - - list := make([]string, 0, length) - - for i := 0; i < int(length); i++ { - value, err := s.ReadString() - if err != nil { - return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) - } - - list = append(list, value) - } - - return list, nil -} - -// ReadListBuffer reads a list of NEX Buffer types -func (s *StreamIn) ReadListBuffer() ([][]byte, error) { - length, err := s.ReadUInt32LE() - if err != nil { - return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) - } - - list := make([][]byte, 0, length) - - for i := 0; i < int(length); i++ { - value, err := s.ReadBuffer() - if err != nil { - return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) - } - - list = append(list, value) - } - - return list, nil -} - -// ReadListQBuffer reads a list of NEX qBuffer types -func (s *StreamIn) ReadListQBuffer() ([][]byte, error) { - length, err := s.ReadUInt32LE() - if err != nil { - return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) - } - - list := make([][]byte, 0, length) - - for i := 0; i < int(length); i++ { - value, err := s.ReadQBuffer() - if err != nil { - return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) - } - - list = append(list, value) - } - - return list, nil -} - -// ReadListStationURL reads a list of NEX Station URL types -func (s *StreamIn) ReadListStationURL() ([]*StationURL, error) { - length, err := s.ReadUInt32LE() - if err != nil { - return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) - } - - list := make([]*StationURL, 0, length) - - for i := 0; i < int(length); i++ { - value, err := s.ReadStationURL() - if err != nil { - return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) - } - - list = append(list, value) - } - - return list, nil -} - -// ReadListDataHolder reads a list of NEX DataHolder types -func (s *StreamIn) ReadListDataHolder() ([]*DataHolder, error) { - length, err := s.ReadUInt32LE() - if err != nil { - return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) - } - - list := make([]*DataHolder, 0, length) - - for i := 0; i < int(length); i++ { - value, err := s.ReadDataHolder() - if err != nil { - return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) - } - - list = append(list, value) - } - - return list, nil -} - -// ReadListQUUID reads a list of NEX qUUID types -func (s *StreamIn) ReadListQUUID() ([]*QUUID, error) { - length, err := s.ReadUInt32LE() - if err != nil { - return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) - } - - list := make([]*QUUID, 0, length) - - for i := 0; i < int(length); i++ { - value, err := s.ReadQUUID() - if err != nil { - return nil, fmt.Errorf("Failed to read List value at index %d. %s", i, err.Error()) - } - - list = append(list, value) - } - - return list, nil -} - // NewStreamIn returns a new NEX input stream func NewStreamIn(data []byte, server ServerInterface) *StreamIn { return &StreamIn{ @@ -942,105 +178,3 @@ func NewStreamIn(data []byte, server ServerInterface) *StreamIn { Server: server, } } - -// StreamReadStructure reads a Structure type from a StreamIn -// -// Implemented as a separate function to utilize generics -func StreamReadStructure[T StructureInterface](stream *StreamIn, structure T) (T, error) { - if structure.ParentType() != nil { - //_, err := s.ReadStructure(structure.ParentType()) - _, err := StreamReadStructure(stream, structure.ParentType()) - if err != nil { - return structure, fmt.Errorf("Failed to read structure parent. %s", err.Error()) - } - } - - useStructureHeader := false - - if stream.Server != nil { - switch server := stream.Server.(type) { - case *PRUDPServer: // * Support QRV versions - useStructureHeader = server.PRUDPMinorVersion >= 3 - default: - useStructureHeader = server.LibraryVersion().GreaterOrEqual("3.5.0") - } - } - - if useStructureHeader { - version, err := stream.ReadUInt8() - if err != nil { - return structure, fmt.Errorf("Failed to read NEX Structure version. %s", err.Error()) - } - - structureLength, err := stream.ReadUInt32LE() - if err != nil { - return structure, fmt.Errorf("Failed to read NEX Structure content length. %s", err.Error()) - } - - if stream.Remaining() < int(structureLength) { - return structure, errors.New("NEX Structure content length longer than data size") - } - - structure.SetStructureVersion(version) - } - - err := structure.ExtractFromStream(stream) - if err != nil { - return structure, fmt.Errorf("Failed to read structure from s. %s", err.Error()) - } - - return structure, nil -} - -// StreamReadListStructure reads and returns a list of structure types from a StreamIn -// -// Implemented as a separate function to utilize generics -func StreamReadListStructure[T StructureInterface](stream *StreamIn, structure T) ([]T, error) { - length, err := stream.ReadUInt32LE() - if err != nil { - return nil, fmt.Errorf("Failed to read List length. %s", err.Error()) - } - - structures := make([]T, 0, int(length)) - - for i := 0; i < int(length); i++ { - newStructure := structure.Copy() - - extracted, err := StreamReadStructure[T](stream, newStructure.(T)) - if err != nil { - return nil, err - } - - structures = append(structures, extracted) - } - - return structures, nil -} - -// StreamReadMap reads a Map type with the given key and value types from a StreamIn -// -// Implemented as a separate function to utilize generics -func StreamReadMap[K comparable, V any](s *StreamIn, keyReader func() (K, error), valueReader func() (V, error)) (map[K]V, error) { - length, err := s.ReadUInt32LE() - if err != nil { - return nil, fmt.Errorf("Failed to read Map length. %s", err.Error()) - } - - m := make(map[K]V) - - for i := 0; i < int(length); i++ { - key, err := keyReader() - if err != nil { - return nil, err - } - - value, err := valueReader() - if err != nil { - return nil, err - } - - m[key] = value - } - - return m, nil -} diff --git a/stream_out.go b/stream_out.go index 0b8bef46..aec2b1fe 100644 --- a/stream_out.go +++ b/stream_out.go @@ -1,6 +1,7 @@ package nex import ( + "github.com/PretendoNetwork/nex-go/types" crunch "github.com/superwhiskers/crunch/v3" ) @@ -10,484 +11,124 @@ type StreamOut struct { Server ServerInterface } -// WriteUInt8 writes a uint8 -func (s *StreamOut) WriteUInt8(u8 uint8) { - s.Grow(1) - s.WriteByteNext(byte(u8)) -} - -// WriteInt8 writes a int8 -func (s *StreamOut) WriteInt8(s8 int8) { - s.Grow(1) - s.WriteByteNext(byte(s8)) -} - -// WriteUInt16LE writes a uint16 as LE -func (s *StreamOut) WriteUInt16LE(u16 uint16) { - s.Grow(2) - s.WriteU16LENext([]uint16{u16}) -} - -// WriteUInt16BE writes a uint16 as BE -func (s *StreamOut) WriteUInt16BE(u16 uint16) { - s.Grow(2) - s.WriteU16BENext([]uint16{u16}) -} - -// WriteInt16LE writes a uint16 as LE -func (s *StreamOut) WriteInt16LE(s16 int16) { - s.Grow(2) - s.WriteU16LENext([]uint16{uint16(s16)}) -} - -// WriteInt16BE writes a uint16 as BE -func (s *StreamOut) WriteInt16BE(s16 int16) { - s.Grow(2) - s.WriteU16BENext([]uint16{uint16(s16)}) -} - -// WriteUInt32LE writes a uint32 as LE -func (s *StreamOut) WriteUInt32LE(u32 uint32) { - s.Grow(4) - s.WriteU32LENext([]uint32{u32}) -} - -// WriteUInt32BE writes a uint32 as BE -func (s *StreamOut) WriteUInt32BE(u32 uint32) { - s.Grow(4) - s.WriteU32BENext([]uint32{u32}) -} - -// WriteInt32LE writes a int32 as LE -func (s *StreamOut) WriteInt32LE(s32 int32) { - s.Grow(4) - s.WriteU32LENext([]uint32{uint32(s32)}) -} - -// WriteInt32BE writes a int32 as BE -func (s *StreamOut) WriteInt32BE(s32 int32) { - s.Grow(4) - s.WriteU32BENext([]uint32{uint32(s32)}) -} - -// WriteUInt64LE writes a uint64 as LE -func (s *StreamOut) WriteUInt64LE(u64 uint64) { - s.Grow(8) - s.WriteU64LENext([]uint64{u64}) -} - -// WriteUInt64BE writes a uint64 as BE -func (s *StreamOut) WriteUInt64BE(u64 uint64) { - s.Grow(8) - s.WriteU64BENext([]uint64{u64}) -} - -// WriteInt64LE writes a int64 as LE -func (s *StreamOut) WriteInt64LE(s64 int64) { - s.Grow(8) - s.WriteU64LENext([]uint64{uint64(s64)}) -} - -// WriteInt64BE writes a int64 as BE -func (s *StreamOut) WriteInt64BE(s64 int64) { - s.Grow(8) - s.WriteU64BENext([]uint64{uint64(s64)}) -} - -// WriteFloat32LE writes a float32 as LE -func (s *StreamOut) WriteFloat32LE(f32 float32) { - s.Grow(4) - s.WriteF32LENext([]float32{f32}) -} - -// WriteFloat32BE writes a float32 as BE -func (s *StreamOut) WriteFloat32BE(f32 float32) { - s.Grow(4) - s.WriteF32BENext([]float32{f32}) -} - -// WriteFloat64LE writes a float64 as LE -func (s *StreamOut) WriteFloat64LE(f64 float64) { - s.Grow(8) - s.WriteF64LENext([]float64{f64}) -} - -// WriteFloat64BE writes a float64 as BE -func (s *StreamOut) WriteFloat64BE(f64 float64) { - s.Grow(8) - s.WriteF64BENext([]float64{f64}) -} - -// WriteBool writes a bool -func (s *StreamOut) WriteBool(b bool) { - var bVar uint8 - if b { - bVar = 1 - } - s.Grow(1) - s.WriteByteNext(byte(bVar)) -} +// StringLengthSize returns the expected size of String length fields +func (s *StreamOut) StringLengthSize() int { + size := 2 -// WritePID writes a NEX PID. The size depends on the server version -func (s *StreamOut) WritePID(pid *PID) { - if s.Server.LibraryVersion().GreaterOrEqual("4.0.0") { - s.WriteUInt64LE(pid.pid) - } else { - s.WriteUInt32LE(uint32(pid.pid)) - } -} - -// WriteString writes a NEX string type -func (s *StreamOut) WriteString(str string) { - str = str + "\x00" - strLength := len(str) - - if s.Server == nil { - s.WriteUInt16LE(uint16(strLength)) - } else if s.Server.StringLengthSize() == 4 { - s.WriteUInt32LE(uint32(strLength)) - } else { - s.WriteUInt16LE(uint16(strLength)) + if s.Server != nil { + size = s.Server.StringLengthSize() } - s.Grow(int64(strLength)) - s.WriteBytesNext([]byte(str)) + return size } -// WriteBuffer writes a NEX Buffer type -func (s *StreamOut) WriteBuffer(data []byte) { - dataLength := len(data) +// PIDSize returns the size of PID types +func (s *StreamOut) PIDSize() int { + size := 4 - s.WriteUInt32LE(uint32(dataLength)) - - if dataLength > 0 { - s.Grow(int64(dataLength)) - s.WriteBytesNext(data) + if s.Server != nil && s.Server.LibraryVersion().GreaterOrEqual("4.0.0") { + size = 8 } -} - -// WriteQBuffer writes a NEX qBuffer type -func (s *StreamOut) WriteQBuffer(data []byte) { - dataLength := len(data) - - s.WriteUInt16LE(uint16(dataLength)) - if dataLength > 0 { - s.Grow(int64(dataLength)) - s.WriteBytesNext(data) - } + return size } -// WriteResult writes a NEX Result type -func (s *StreamOut) WriteResult(result *Result) { - s.WriteUInt32LE(result.Code) -} - -// WriteStructure writes a nex Structure type -func (s *StreamOut) WriteStructure(structure StructureInterface) { - if structure.ParentType() != nil { - s.WriteStructure(structure.ParentType()) - } - - content := structure.Bytes(NewStreamOut(s.Server)) - - useStructures := false +// UseStructureHeader determines if Structure headers should be used +func (s *StreamOut) UseStructureHeader() bool { + useStructureHeader := false if s.Server != nil { switch server := s.Server.(type) { case *PRUDPServer: // * Support QRV versions - useStructures = server.PRUDPMinorVersion >= 3 + useStructureHeader = server.PRUDPMinorVersion >= 3 default: - useStructures = server.LibraryVersion().GreaterOrEqual("3.5.0") + useStructureHeader = server.LibraryVersion().GreaterOrEqual("3.5.0") } } - if useStructures { - s.WriteUInt8(structure.StructureVersion()) - s.WriteUInt32LE(uint32(len(content))) - } - - s.Grow(int64(len(content))) - s.WriteBytesNext(content) -} - -// WriteStationURL writes a StationURL type -func (s *StreamOut) WriteStationURL(stationURL *StationURL) { - s.WriteString(stationURL.EncodeToString()) -} - -// WriteDataHolder writes a NEX DataHolder type -func (s *StreamOut) WriteDataHolder(dataholder *DataHolder) { - content := dataholder.Bytes(NewStreamOut(s.Server)) - s.Grow(int64(len(content))) - s.WriteBytesNext(content) -} - -// WriteDateTime writes a NEX DateTime type -func (s *StreamOut) WriteDateTime(datetime *DateTime) { - s.WriteUInt64LE(datetime.value) -} - -// WriteVariant writes a Variant type -func (s *StreamOut) WriteVariant(variant *Variant) { - content := variant.Bytes(NewStreamOut(s.Server)) - s.Grow(int64(len(content))) - s.WriteBytesNext(content) -} - -// WriteQUUID writes a qUUID type -func (s *StreamOut) WriteQUUID(qUUID *QUUID) { - qUUID.Bytes(s) -} - -// WriteListUInt8 writes a list of uint8 types -func (s *StreamOut) WriteListUInt8(list []uint8) { - s.WriteUInt32LE(uint32(len(list))) - - for i := 0; i < len(list); i++ { - s.WriteUInt8(list[i]) - } -} - -// WriteListInt8 writes a list of int8 types -func (s *StreamOut) WriteListInt8(list []int8) { - s.WriteUInt32LE(uint32(len(list))) - - for i := 0; i < len(list); i++ { - s.WriteInt8(list[i]) - } -} - -// WriteListUInt16LE writes a list of Little-Endian encoded uint16 types -func (s *StreamOut) WriteListUInt16LE(list []uint16) { - s.WriteUInt32LE(uint32(len(list))) - - for i := 0; i < len(list); i++ { - s.WriteUInt16LE(list[i]) - } -} - -// WriteListUInt16BE writes a list of Big-Endian encoded uint16 types -func (s *StreamOut) WriteListUInt16BE(list []uint16) { - s.WriteUInt32LE(uint32(len(list))) - - for i := 0; i < len(list); i++ { - s.WriteUInt16BE(list[i]) - } -} - -// WriteListInt16LE writes a list of Little-Endian encoded int16 types -func (s *StreamOut) WriteListInt16LE(list []int16) { - s.WriteUInt32LE(uint32(len(list))) - - for i := 0; i < len(list); i++ { - s.WriteInt16LE(list[i]) - } -} - -// WriteListInt16BE writes a list of Big-Endian encoded int16 types -func (s *StreamOut) WriteListInt16BE(list []int16) { - s.WriteUInt32LE(uint32(len(list))) - - for i := 0; i < len(list); i++ { - s.WriteInt16BE(list[i]) - } -} - -// WriteListUInt32LE writes a list of Little-Endian encoded uint32 types -func (s *StreamOut) WriteListUInt32LE(list []uint32) { - s.WriteUInt32LE(uint32(len(list))) - - for i := 0; i < len(list); i++ { - s.WriteUInt32LE(list[i]) - } -} - -// WriteListUInt32BE writes a list of Big-Endian encoded uint32 types -func (s *StreamOut) WriteListUInt32BE(list []uint32) { - s.WriteUInt32LE(uint32(len(list))) - - for i := 0; i < len(list); i++ { - s.WriteUInt32BE(list[i]) - } -} - -// WriteListInt32LE writes a list of Little-Endian encoded int32 types -func (s *StreamOut) WriteListInt32LE(list []int32) { - s.WriteUInt32LE(uint32(len(list))) - - for i := 0; i < len(list); i++ { - s.WriteInt32LE(list[i]) - } + return useStructureHeader } -// WriteListInt32BE writes a list of Big-Endian encoded int32 types -func (s *StreamOut) WriteListInt32BE(list []int32) { - s.WriteUInt32LE(uint32(len(list))) - - for i := 0; i < len(list); i++ { - s.WriteInt32BE(list[i]) - } +// CopyNew returns a copy of the StreamOut but with a blank internal buffer. Returns as types.Writable +func (s *StreamOut) CopyNew() types.Writable { + return NewStreamOut(s.Server) } -// WriteListUInt64LE writes a list of Little-Endian encoded uint64 types -func (s *StreamOut) WriteListUInt64LE(list []uint64) { - s.WriteUInt32LE(uint32(len(list))) - - for i := 0; i < len(list); i++ { - s.WriteUInt64LE(list[i]) - } +// Writes the input data to the end of the StreamOut +func (s *StreamOut) Write(data []byte) { + s.Grow(int64(len(data))) + s.WriteBytesNext(data) } -// WriteListUInt64BE writes a list of Big-Endian encoded uint64 types -func (s *StreamOut) WriteListUInt64BE(list []uint64) { - s.WriteUInt32LE(uint32(len(list))) - - for i := 0; i < len(list); i++ { - s.WriteUInt64BE(list[i]) - } -} - -// WriteListInt64LE writes a list of Little-Endian encoded int64 types -func (s *StreamOut) WriteListInt64LE(list []int64) { - s.WriteUInt32LE(uint32(len(list))) - - for i := 0; i < len(list); i++ { - s.WriteInt64LE(list[i]) - } -} - -// WriteListInt64BE writes a list of Big-Endian encoded int64 types -func (s *StreamOut) WriteListInt64BE(list []int64) { - s.WriteUInt32LE(uint32(len(list))) - - for i := 0; i < len(list); i++ { - s.WriteInt64BE(list[i]) - } -} - -// WriteListFloat32LE writes a list of Little-Endian encoded float32 types -func (s *StreamOut) WriteListFloat32LE(list []float32) { - s.WriteUInt32LE(uint32(len(list))) - - for i := 0; i < len(list); i++ { - s.WriteFloat32LE(list[i]) - } +// WritePrimitiveUInt8 writes a uint8 +func (s *StreamOut) WritePrimitiveUInt8(u8 uint8) { + s.Grow(1) + s.WriteByteNext(byte(u8)) } -// WriteListFloat32BE writes a list of Big-Endian encoded float32 types -func (s *StreamOut) WriteListFloat32BE(list []float32) { - s.WriteUInt32LE(uint32(len(list))) - - for i := 0; i < len(list); i++ { - s.WriteFloat32BE(list[i]) - } +// WritePrimitiveUInt16LE writes a uint16 as LE +func (s *StreamOut) WritePrimitiveUInt16LE(u16 uint16) { + s.Grow(2) + s.WriteU16LENext([]uint16{u16}) } -// WriteListFloat64LE writes a list of Little-Endian encoded float64 types -func (s *StreamOut) WriteListFloat64LE(list []float64) { - s.WriteUInt32LE(uint32(len(list))) - - for i := 0; i < len(list); i++ { - s.WriteFloat64LE(list[i]) - } +// WritePrimitiveUInt32LE writes a uint32 as LE +func (s *StreamOut) WritePrimitiveUInt32LE(u32 uint32) { + s.Grow(4) + s.WriteU32LENext([]uint32{u32}) } -// WriteListFloat64BE writes a list of Big-Endian encoded float64 types -func (s *StreamOut) WriteListFloat64BE(list []float64) { - s.WriteUInt32LE(uint32(len(list))) - - for i := 0; i < len(list); i++ { - s.WriteFloat64BE(list[i]) - } +// WritePrimitiveUInt64LE writes a uint64 as LE +func (s *StreamOut) WritePrimitiveUInt64LE(u64 uint64) { + s.Grow(8) + s.WriteU64LENext([]uint64{u64}) } -// WriteListPID writes a list of NEX PIDs -func (s *StreamOut) WriteListPID(pids []*PID) { - length := len(pids) - - s.WriteUInt32LE(uint32(length)) - - for i := 0; i < length; i++ { - s.WritePID(pids[i]) - } +// WritePrimitiveInt8 writes a int8 +func (s *StreamOut) WritePrimitiveInt8(s8 int8) { + s.Grow(1) + s.WriteByteNext(byte(s8)) } -// WriteListString writes a list of NEX String types -func (s *StreamOut) WriteListString(strings []string) { - length := len(strings) - - s.WriteUInt32LE(uint32(length)) - - for i := 0; i < length; i++ { - s.WriteString(strings[i]) - } +// WritePrimitiveInt16LE writes a uint16 as LE +func (s *StreamOut) WritePrimitiveInt16LE(s16 int16) { + s.Grow(2) + s.WriteU16LENext([]uint16{uint16(s16)}) } -// WriteListBuffer writes a list of NEX Buffer types -func (s *StreamOut) WriteListBuffer(buffers [][]byte) { - length := len(buffers) - - s.WriteUInt32LE(uint32(length)) - - for i := 0; i < length; i++ { - s.WriteBuffer(buffers[i]) - } +// WritePrimitiveInt32LE writes a int32 as LE +func (s *StreamOut) WritePrimitiveInt32LE(s32 int32) { + s.Grow(4) + s.WriteU32LENext([]uint32{uint32(s32)}) } -// WriteListQBuffer writes a list of NEX qBuffer types -func (s *StreamOut) WriteListQBuffer(buffers [][]byte) { - length := len(buffers) - - s.WriteUInt32LE(uint32(length)) - - for i := 0; i < length; i++ { - s.WriteQBuffer(buffers[i]) - } +// WritePrimitiveInt64LE writes a int64 as LE +func (s *StreamOut) WritePrimitiveInt64LE(s64 int64) { + s.Grow(8) + s.WriteU64LENext([]uint64{uint64(s64)}) } -// WriteListResult writes a list of NEX Result types -func (s *StreamOut) WriteListResult(results []*Result) { - length := len(results) - - s.WriteUInt32LE(uint32(length)) - - for i := 0; i < length; i++ { - s.WriteResult(results[i]) - } +// WritePrimitiveFloat32LE writes a float32 as LE +func (s *StreamOut) WritePrimitiveFloat32LE(f32 float32) { + s.Grow(4) + s.WriteF32LENext([]float32{f32}) } -// WriteListStationURL writes a list of NEX StationURL types -func (s *StreamOut) WriteListStationURL(stationURLs []*StationURL) { - length := len(stationURLs) - - s.WriteUInt32LE(uint32(length)) - - for i := 0; i < length; i++ { - s.WriteString(stationURLs[i].EncodeToString()) - } +// WritePrimitiveFloat64LE writes a float64 as LE +func (s *StreamOut) WritePrimitiveFloat64LE(f64 float64) { + s.Grow(8) + s.WriteF64LENext([]float64{f64}) } -// WriteListDataHolder writes a NEX DataHolder type -func (s *StreamOut) WriteListDataHolder(dataholders []*DataHolder) { - length := len(dataholders) - - s.WriteUInt32LE(uint32(length)) - - for i := 0; i < length; i++ { - s.WriteDataHolder(dataholders[i]) +// WritePrimitiveBool writes a bool +func (s *StreamOut) WritePrimitiveBool(b bool) { + var bVar uint8 + if b { + bVar = 1 } -} - -// WriteListQUUID writes a NEX qUUID type -func (s *StreamOut) WriteListQUUID(qUUIDs []*QUUID) { - length := len(qUUIDs) - - s.WriteUInt32LE(uint32(length)) - for i := 0; i < length; i++ { - s.WriteQUUID(qUUIDs[i]) - } + s.Grow(1) + s.WriteByteNext(byte(bVar)) } // NewStreamOut returns a new nex output stream @@ -497,78 +138,3 @@ func NewStreamOut(server ServerInterface) *StreamOut { Server: server, } } - -// StreamWriteListStructure writes a list of structure types to a StreamOut -// -// Implemented as a separate function to utilize generics -func StreamWriteListStructure[T StructureInterface](stream *StreamOut, structures []T) { - count := len(structures) - - stream.WriteUInt32LE(uint32(count)) - - for i := 0; i < count; i++ { - stream.WriteStructure(structures[i]) - } -} - -func mapTypeWriter[T any](stream *StreamOut, t T) { - // * Map types in NEX can have any type for the - // * key and value. So we need to just check the - // * type each time and call the right function - switch v := any(t).(type) { - case uint8: - stream.WriteUInt8(v) - case int8: - stream.WriteInt8(v) - case uint16: - stream.WriteUInt16LE(v) - case int16: - stream.WriteInt16LE(v) - case uint32: - stream.WriteUInt32LE(v) - case int32: - stream.WriteInt32LE(v) - case uint64: - stream.WriteUInt64LE(v) - case int64: - stream.WriteInt64LE(v) - case float32: - stream.WriteFloat32LE(v) - case float64: - stream.WriteFloat64LE(v) - case string: - stream.WriteString(v) - case bool: - stream.WriteBool(v) - case []byte: - // * This actually isn't a good situation, since a byte slice can be either - // * a Buffer or qBuffer. The only known official case is a qBuffer, inside - // * UserAccountManagement::LookupSceNpIds, which is why it's implemented - // * as a qBuffer - stream.WriteQBuffer(v) // TODO - Maybe we should make Buffer and qBuffer real types? - case StructureInterface: - stream.WriteStructure(v) - case *Variant: - stream.WriteVariant(v) - default: - // * Writer functions don't return errors so just log here. - // * The client will disconnect but the server won't die, - // * that way other clients stay connected, but we still - // * have a log of what the error was - logger.Warningf("Unsupported Map type trying to be written: %T\n", v) - } -} - -// StreamWriteMap writes a Map type to a StreamOut -// -// Implemented as a separate function to utilize generics -func StreamWriteMap[K comparable, V any](stream *StreamOut, m map[K]V) { - count := len(m) - - stream.WriteUInt32LE(uint32(count)) - - for key, value := range m { - mapTypeWriter(stream, key) - mapTypeWriter(stream, value) - } -} diff --git a/test/auth.go b/test/auth.go index dd253465..9917a133 100644 --- a/test/auth.go +++ b/test/auth.go @@ -6,6 +6,7 @@ import ( "strconv" "github.com/PretendoNetwork/nex-go" + "github.com/PretendoNetwork/nex-go/types" ) var authServer *nex.PRUDPServer @@ -50,34 +51,34 @@ func login(packet nex.PRUDPPacketInterface) { parametersStream := nex.NewStreamIn(parameters, authServer) - strUserName, err := parametersStream.ReadString() - if err != nil { + strUserName := types.NewString() + if err := strUserName.ExtractFrom(parametersStream); err != nil { panic(err) } - converted, err := strconv.Atoi(strUserName) + converted, err := strconv.Atoi(string(*strUserName)) if err != nil { panic(err) } - retval := nex.NewResultSuccess(0x00010001) - pidPrincipal := nex.NewPID(uint32(converted)) - pbufResponse := generateTicket(pidPrincipal, nex.NewPID[uint32](2)) - pConnectionData := nex.NewRVConnectionData() - strReturnMsg := "Test Build" + retval := types.NewResultSuccess(0x00010001) + pidPrincipal := types.NewPID(uint32(converted)) + pbufResponse := types.Buffer(generateTicket(pidPrincipal, types.NewPID[uint32](2))) + pConnectionData := types.NewRVConnectionData() + strReturnMsg := types.String("Test Build") - pConnectionData.StationURL = nex.NewStationURL("prudps:/address=192.168.1.98;port=60001;CID=1;PID=2;sid=1;stream=10;type=2") - pConnectionData.SpecialProtocols = []byte{} - pConnectionData.StationURLSpecialProtocols = nex.NewStationURL("") - pConnectionData.Time = nex.NewDateTime(0).Now() + pConnectionData.StationURL = types.NewStationURL("prudps:/address=192.168.1.98;port=60001;CID=1;PID=2;sid=1;stream=10;type=2") + pConnectionData.SpecialProtocols = types.NewList(types.NewPrimitiveU8()) + pConnectionData.StationURLSpecialProtocols = types.NewStationURL("") + pConnectionData.Time = types.NewDateTime(0).Now() responseStream := nex.NewStreamOut(authServer) - responseStream.WriteResult(retval) - responseStream.WritePID(pidPrincipal) - responseStream.WriteBuffer(pbufResponse) - responseStream.WriteStructure(pConnectionData) - responseStream.WriteString(strReturnMsg) + retval.WriteTo(responseStream) + pidPrincipal.WriteTo(responseStream) + pbufResponse.WriteTo(responseStream) + pConnectionData.WriteTo(responseStream) + strReturnMsg.WriteTo(responseStream) response.IsSuccess = true response.IsRequest = false @@ -111,23 +112,23 @@ func requestTicket(packet nex.PRUDPPacketInterface) { parametersStream := nex.NewStreamIn(parameters, authServer) - idSource, err := parametersStream.ReadPID() - if err != nil { + idSource := types.NewPID[uint64](0) + if err := idSource.ExtractFrom(parametersStream); err != nil { panic(err) } - idTarget, err := parametersStream.ReadPID() - if err != nil { + idTarget := types.NewPID[uint64](0) + if err := idTarget.ExtractFrom(parametersStream); err != nil { panic(err) } - retval := nex.NewResultSuccess(0x00010001) - pbufResponse := generateTicket(idSource, idTarget) + retval := types.NewResultSuccess(0x00010001) + pbufResponse := types.Buffer(generateTicket(idSource, idTarget)) responseStream := nex.NewStreamOut(authServer) - responseStream.WriteResult(retval) - responseStream.WriteBuffer(pbufResponse) + retval.WriteTo(responseStream) + pbufResponse.WriteTo(responseStream) response.IsSuccess = true response.IsRequest = false diff --git a/test/generate_ticket.go b/test/generate_ticket.go index 1e49ee46..78ae45fb 100644 --- a/test/generate_ticket.go +++ b/test/generate_ticket.go @@ -4,10 +4,11 @@ import ( "crypto/rand" "github.com/PretendoNetwork/nex-go" + "github.com/PretendoNetwork/nex-go/types" ) -func generateTicket(userPID *nex.PID, targetPID *nex.PID) []byte { - userKey := nex.DeriveKerberosKey(userPID, []byte("abcdefghijklmnop")) +func generateTicket(userPID *types.PID, targetPID *types.PID) []byte { + userKey := nex.DeriveKerberosKey(userPID, []byte("z5sykuHnX0q5SCJN")) targetKey := nex.DeriveKerberosKey(targetPID, []byte("password")) sessionKey := make([]byte, authServer.KerberosKeySize()) @@ -17,7 +18,7 @@ func generateTicket(userPID *nex.PID, targetPID *nex.PID) []byte { } ticketInternalData := nex.NewKerberosTicketInternalData() - serverTime := nex.NewDateTime(0).Now() + serverTime := types.NewDateTime(0).Now() ticketInternalData.Issued = serverTime ticketInternalData.SourcePID = userPID @@ -25,10 +26,14 @@ func generateTicket(userPID *nex.PID, targetPID *nex.PID) []byte { encryptedTicketInternalData, _ := ticketInternalData.Encrypt(targetKey, nex.NewStreamOut(authServer)) + encryptedTicketInternalDataBuffer := types.NewBuffer() + + *encryptedTicketInternalDataBuffer = encryptedTicketInternalData + ticket := nex.NewKerberosTicket() ticket.SessionKey = sessionKey ticket.TargetPID = targetPID - ticket.InternalData = encryptedTicketInternalData + ticket.InternalData = encryptedTicketInternalDataBuffer encryptedTicket, _ := ticket.Encrypt(userKey, nex.NewStreamOut(authServer)) diff --git a/test/main.go b/test/main.go index 12dd8a03..39ac9b22 100644 --- a/test/main.go +++ b/test/main.go @@ -5,7 +5,7 @@ import "sync" var wg sync.WaitGroup func main() { - wg.Add(2) + wg.Add(3) go startAuthenticationServer() go startSecureServer() diff --git a/test/secure.go b/test/secure.go index 4ad384c3..0fea3443 100644 --- a/test/secure.go +++ b/test/secure.go @@ -6,6 +6,7 @@ import ( "strconv" "github.com/PretendoNetwork/nex-go" + "github.com/PretendoNetwork/nex-go/types" ) var secureServer *nex.PRUDPServer @@ -13,35 +14,31 @@ var secureServer *nex.PRUDPServer // * Took these structs out of the protocols lib for convenience type principalPreference struct { - nex.Structure - *nex.Data + types.Structure + *types.NullData ShowOnlinePresence bool ShowCurrentTitle bool BlockFriendRequests bool } -func (pp *principalPreference) Bytes(stream *nex.StreamOut) []byte { - stream.WriteBool(pp.ShowOnlinePresence) - stream.WriteBool(pp.ShowCurrentTitle) - stream.WriteBool(pp.BlockFriendRequests) - - return stream.Bytes() +func (pp *principalPreference) WriteTo(stream *nex.StreamOut) { + stream.WritePrimitiveBool(pp.ShowOnlinePresence) + stream.WritePrimitiveBool(pp.ShowCurrentTitle) + stream.WritePrimitiveBool(pp.BlockFriendRequests) } type comment struct { - nex.Structure - *nex.Data + types.Structure + *types.NullData Unknown uint8 - Contents string - LastChanged *nex.DateTime + Contents *types.String + LastChanged *types.DateTime } -func (c *comment) Bytes(stream *nex.StreamOut) []byte { - stream.WriteUInt8(c.Unknown) - stream.WriteString(c.Contents) - stream.WriteDateTime(c.LastChanged) - - return stream.Bytes() +func (c *comment) WriteTo(stream *nex.StreamOut) { + stream.WritePrimitiveUInt8(c.Unknown) + c.Contents.WriteTo(stream) + c.LastChanged.WriteTo(stream) } func startSecureServer() { @@ -93,31 +90,31 @@ func registerEx(packet nex.PRUDPPacketInterface) { parametersStream := nex.NewStreamIn(parameters, secureServer) - vecMyURLs, err := parametersStream.ReadListStationURL() - if err != nil { + vecMyURLs := types.NewList(types.NewStationURL("")) + if err := vecMyURLs.ExtractFrom(parametersStream); err != nil { panic(err) } - _, err = parametersStream.ReadDataHolder() - if err != nil { + hCustomData := types.NewAnyDataHolder() + if err := hCustomData.ExtractFrom(parametersStream); err != nil { fmt.Println(err) } - localStation := vecMyURLs[0] + localStation, _ := vecMyURLs.Get(0) address := packet.Sender().Address().(*net.UDPAddr).IP.String() - localStation.Fields.Set("address", address) - localStation.Fields.Set("port", strconv.Itoa(packet.Sender().Address().(*net.UDPAddr).Port)) + localStation.Fields["address"] = address + localStation.Fields["port"] = strconv.Itoa(packet.Sender().Address().(*net.UDPAddr).Port) - retval := nex.NewResultSuccess(0x00010001) - localStationURL := localStation.EncodeToString() + retval := types.NewResultSuccess(0x00010001) + localStationURL := types.String(localStation.EncodeToString()) responseStream := nex.NewStreamOut(secureServer) - responseStream.WriteResult(retval) - responseStream.WriteUInt32LE(secureServer.ConnectionIDCounter().Next()) - responseStream.WriteString(localStationURL) + retval.WriteTo(responseStream) + responseStream.WritePrimitiveUInt32LE(secureServer.ConnectionIDCounter().Next()) + localStationURL.WriteTo(responseStream) response.IsSuccess = true response.IsRequest = false @@ -149,23 +146,23 @@ func updateAndGetAllInformation(packet nex.PRUDPPacketInterface) { responseStream := nex.NewStreamOut(secureServer) - responseStream.WriteStructure(&principalPreference{ + (&principalPreference{ ShowOnlinePresence: true, ShowCurrentTitle: true, BlockFriendRequests: false, - }) - responseStream.WriteStructure(&comment{ + }).WriteTo(responseStream) + (&comment{ Unknown: 0, - Contents: "Rewrite Test", - LastChanged: nex.NewDateTime(0), - }) - responseStream.WriteUInt32LE(0) // * Stubbed empty list. responseStream.WriteListStructure(friendList) - responseStream.WriteUInt32LE(0) // * Stubbed empty list. responseStream.WriteListStructure(friendRequestsOut) - responseStream.WriteUInt32LE(0) // * Stubbed empty list. responseStream.WriteListStructure(friendRequestsIn) - responseStream.WriteUInt32LE(0) // * Stubbed empty list. responseStream.WriteListStructure(blockList) - responseStream.WriteBool(false) // * Unknown - responseStream.WriteUInt32LE(0) // * Stubbed empty list. responseStream.WriteListStructure(notifications) - responseStream.WriteBool(false) // * Unknown + Contents: types.NewString(), + LastChanged: types.NewDateTime(0), + }).WriteTo(responseStream) + responseStream.WritePrimitiveUInt32LE(0) // * Stubbed empty list. responseStream.WriteListStructure(friendList) + responseStream.WritePrimitiveUInt32LE(0) // * Stubbed empty list. responseStream.WriteListStructure(friendRequestsOut) + responseStream.WritePrimitiveUInt32LE(0) // * Stubbed empty list. responseStream.WriteListStructure(friendRequestsIn) + responseStream.WritePrimitiveUInt32LE(0) // * Stubbed empty list. responseStream.WriteListStructure(blockList) + responseStream.WritePrimitiveBool(false) // * Unknown + responseStream.WritePrimitiveUInt32LE(0) // * Stubbed empty list. responseStream.WriteListStructure(notifications) + responseStream.WritePrimitiveBool(false) // * Unknown response.IsSuccess = true response.IsRequest = false @@ -197,7 +194,7 @@ func checkSettingStatus(packet nex.PRUDPPacketInterface) { responseStream := nex.NewStreamOut(secureServer) - responseStream.WriteUInt8(0) // * Unknown + responseStream.WritePrimitiveUInt8(0) // * Unknown response.IsSuccess = true response.IsRequest = false diff --git a/types.go b/types.go deleted file mode 100644 index 852281a4..00000000 --- a/types.go +++ /dev/null @@ -1,1299 +0,0 @@ -package nex - -import ( - "bytes" - "encoding/hex" - "errors" - "fmt" - "slices" - "strings" - "time" -) - -// PID represents a unique number to identify a user -// -// The true size of this value depends on the client version. -// Legacy clients (WiiU/3DS) use a uint32, whereas new clients (Nintendo Switch) use a uint64. -// Value is always stored as the higher uint64, the consuming API should assert accordingly -type PID struct { - pid uint64 -} - -// Value returns the numeric value of the PID as a uint64 regardless of client version -func (p *PID) Value() uint64 { - return p.pid -} - -// LegacyValue returns the numeric value of the PID as a uint32, for legacy clients -func (p *PID) LegacyValue() uint32 { - return uint32(p.pid) -} - -// Equals checks if the two structs are equal -func (p *PID) Equals(other *PID) bool { - return p.pid == other.pid -} - -// Copy returns a copy of the current PID -func (p *PID) Copy() *PID { - return NewPID(p.pid) -} - -// String returns a string representation of the struct -func (p *PID) String() string { - return p.FormatToString(0) -} - -// FormatToString pretty-prints the struct data using the provided indentation level -func (p *PID) FormatToString(indentationLevel int) string { - indentationValues := strings.Repeat("\t", indentationLevel+1) - indentationEnd := strings.Repeat("\t", indentationLevel) - - var b strings.Builder - - b.WriteString("PID{\n") - - switch v := any(p.pid).(type) { - case uint32: - b.WriteString(fmt.Sprintf("%spid: %d (legacy)\n", indentationValues, v)) - case uint64: - b.WriteString(fmt.Sprintf("%spid: %d (modern)\n", indentationValues, v)) - } - - b.WriteString(fmt.Sprintf("%s}", indentationEnd)) - - return b.String() -} - -// NewPID returns a PID instance. The size of PID depends on the client version -func NewPID[T uint32 | uint64](pid T) *PID { - switch v := any(pid).(type) { - case uint32: - return &PID{pid: uint64(v)} - case uint64: - return &PID{pid: v} - } - - // * This will never happen because Go will - // * not compile any code where "pid" is not - // * a uint32/uint64, so it will ALWAYS get - // * caught by the above switch-case. This - // * return is only here because Go won't - // * compile without a default return - return nil -} - -// StructureInterface implements all Structure methods -type StructureInterface interface { - SetParentType(StructureInterface) - ParentType() StructureInterface - SetStructureVersion(uint8) - StructureVersion() uint8 - ExtractFromStream(*StreamIn) error - Bytes(*StreamOut) []byte - Copy() StructureInterface - Equals(StructureInterface) bool - FormatToString(int) string -} - -// Structure represents a nex Structure type -type Structure struct { - parentType StructureInterface - structureVersion uint8 - StructureInterface -} - -// SetParentType sets the Structures parent type -func (structure *Structure) SetParentType(parentType StructureInterface) { - structure.parentType = parentType -} - -// ParentType returns the Structures parent type. nil if the type does not inherit another Structure -func (structure *Structure) ParentType() StructureInterface { - return structure.parentType -} - -// SetStructureVersion sets the structures version. Only used in NEX 3.5+ -func (structure *Structure) SetStructureVersion(version uint8) { - structure.structureVersion = version -} - -// StructureVersion returns the structures version. Only used in NEX 3.5+ -func (structure *Structure) StructureVersion() uint8 { - return structure.structureVersion -} - -// Data represents a structure with no data -type Data struct { - Structure -} - -// ExtractFromStream does nothing for Data -func (data *Data) ExtractFromStream(stream *StreamIn) error { - // Basically do nothing. Does a relative seek with 0 - stream.SeekByte(0, true) - - return nil -} - -// Bytes does nothing for Data -func (data *Data) Bytes(stream *StreamOut) []byte { - return stream.Bytes() -} - -// Copy returns a new copied instance of Data -func (data *Data) Copy() StructureInterface { - copied := NewData() - - copied.SetStructureVersion(data.StructureVersion()) - - return copied -} - -// Equals checks if the passed Structure contains the same data as the current instance -func (data *Data) Equals(structure StructureInterface) bool { - return data.StructureVersion() == structure.StructureVersion() -} - -// String returns a string representation of the struct -func (data *Data) String() string { - return data.FormatToString(0) -} - -// FormatToString pretty-prints the struct data using the provided indentation level -func (data *Data) FormatToString(indentationLevel int) string { - indentationValues := strings.Repeat("\t", indentationLevel+1) - indentationEnd := strings.Repeat("\t", indentationLevel) - - var b strings.Builder - - b.WriteString("Data{\n") - b.WriteString(fmt.Sprintf("%sstructureVersion: %d\n", indentationValues, data.structureVersion)) - b.WriteString(fmt.Sprintf("%s}", indentationEnd)) - - return b.String() -} - -// NewData returns a new Data Structure -func NewData() *Data { - return &Data{} -} - -var dataHolderKnownObjects = make(map[string]StructureInterface) - -// RegisterDataHolderType registers a structure to be a valid type in the DataHolder structure -func RegisterDataHolderType(name string, structure StructureInterface) { - dataHolderKnownObjects[name] = structure -} - -// DataHolder represents a structure which can hold any other structure -type DataHolder struct { - typeName string - length1 uint32 // length of data including length2 - length2 uint32 // length of the actual structure - objectData StructureInterface -} - -// TypeName returns the DataHolder type name -func (dataHolder *DataHolder) TypeName() string { - return dataHolder.typeName -} - -// SetTypeName sets the DataHolder type name -func (dataHolder *DataHolder) SetTypeName(typeName string) { - dataHolder.typeName = typeName -} - -// ObjectData returns the DataHolder internal object data -func (dataHolder *DataHolder) ObjectData() StructureInterface { - return dataHolder.objectData -} - -// SetObjectData sets the DataHolder internal object data -func (dataHolder *DataHolder) SetObjectData(objectData StructureInterface) { - dataHolder.objectData = objectData -} - -// ExtractFromStream extracts a DataHolder structure from a stream -func (dataHolder *DataHolder) ExtractFromStream(stream *StreamIn) error { - var err error - - dataHolder.typeName, err = stream.ReadString() - if err != nil { - return fmt.Errorf("Failed to read DataHolder type name. %s", err.Error()) - } - - dataHolder.length1, err = stream.ReadUInt32LE() - if err != nil { - return fmt.Errorf("Failed to read DataHolder length 1. %s", err.Error()) - } - - dataHolder.length2, err = stream.ReadUInt32LE() - if err != nil { - return fmt.Errorf("Failed to read DataHolder length 2. %s", err.Error()) - } - - dataType := dataHolderKnownObjects[dataHolder.typeName] - if dataType == nil { - // TODO - Should we really log this here, or just pass the error to the caller? - message := fmt.Sprintf("UNKNOWN DATAHOLDER TYPE: %s", dataHolder.typeName) - return errors.New(message) - } - - newObjectInstance := dataType.Copy() - - dataHolder.objectData, err = StreamReadStructure(stream, newObjectInstance) - if err != nil { - return fmt.Errorf("Failed to read DataHolder object data. %s", err.Error()) - } - - return nil -} - -// Bytes encodes the DataHolder and returns a byte array -func (dataHolder *DataHolder) Bytes(stream *StreamOut) []byte { - contentStream := NewStreamOut(stream.Server) - contentStream.WriteStructure(dataHolder.objectData) - content := contentStream.Bytes() - - /* - Technically this way of encoding a DataHolder is "wrong". - It implies the structure of DataHolder is: - - - Name (string) - - Length+4 (uint32) - - Content (Buffer) - - However the structure as defined by the official NEX library is: - - - Name (string) - - Length+4 (uint32) - - Length (uint32) - - Content (bytes) - - It is convenient to treat the last 2 fields as a Buffer type, but - it should be noted that this is not actually the case. - */ - stream.WriteString(dataHolder.typeName) - stream.WriteUInt32LE(uint32(len(content) + 4)) - stream.WriteBuffer(content) - - return stream.Bytes() -} - -// Copy returns a new copied instance of DataHolder -func (dataHolder *DataHolder) Copy() *DataHolder { - copied := NewDataHolder() - - copied.typeName = dataHolder.typeName - copied.length1 = dataHolder.length1 - copied.length2 = dataHolder.length2 - copied.objectData = dataHolder.objectData.Copy() - - return copied -} - -// Equals checks if the passed Structure contains the same data as the current instance -func (dataHolder *DataHolder) Equals(other *DataHolder) bool { - if dataHolder.typeName != other.typeName { - return false - } - - if dataHolder.length1 != other.length1 { - return false - } - - if dataHolder.length2 != other.length2 { - return false - } - - if !dataHolder.objectData.Equals(other.objectData) { - return false - } - - return true -} - -// String returns a string representation of the struct -func (dataHolder *DataHolder) String() string { - return dataHolder.FormatToString(0) -} - -// FormatToString pretty-prints the struct data using the provided indentation level -func (dataHolder *DataHolder) FormatToString(indentationLevel int) string { - indentationValues := strings.Repeat("\t", indentationLevel+1) - indentationEnd := strings.Repeat("\t", indentationLevel) - - var b strings.Builder - - b.WriteString("DataHolder{\n") - b.WriteString(fmt.Sprintf("%stypeName: %s,\n", indentationValues, dataHolder.typeName)) - b.WriteString(fmt.Sprintf("%slength1: %d,\n", indentationValues, dataHolder.length1)) - b.WriteString(fmt.Sprintf("%slength2: %d,\n", indentationValues, dataHolder.length2)) - b.WriteString(fmt.Sprintf("%sobjectData: %s\n", indentationValues, dataHolder.objectData.FormatToString(indentationLevel+1))) - b.WriteString(fmt.Sprintf("%s}", indentationEnd)) - - return b.String() -} - -// NewDataHolder returns a new DataHolder -func NewDataHolder() *DataHolder { - return &DataHolder{} -} - -// RVConnectionData represents a nex RVConnectionData type -type RVConnectionData struct { - Structure - StationURL *StationURL - SpecialProtocols []byte - StationURLSpecialProtocols *StationURL - Time *DateTime -} - -// Bytes encodes the RVConnectionData and returns a byte array -func (rvConnectionData *RVConnectionData) Bytes(stream *StreamOut) []byte { - stream.WriteStationURL(rvConnectionData.StationURL) - stream.WriteListUInt8(rvConnectionData.SpecialProtocols) - stream.WriteStationURL(rvConnectionData.StationURLSpecialProtocols) - - if stream.Server.LibraryVersion().GreaterOrEqual("3.5.0") { - rvConnectionData.SetStructureVersion(1) - stream.WriteDateTime(rvConnectionData.Time) - } - - return stream.Bytes() -} - -// Copy returns a new copied instance of RVConnectionData -func (rvConnectionData *RVConnectionData) Copy() StructureInterface { - copied := NewRVConnectionData() - - copied.SetStructureVersion(rvConnectionData.StructureVersion()) - copied.parentType = rvConnectionData.parentType - copied.StationURL = rvConnectionData.StationURL.Copy() - copied.SpecialProtocols = make([]byte, len(rvConnectionData.SpecialProtocols)) - - copy(copied.SpecialProtocols, rvConnectionData.SpecialProtocols) - - copied.StationURLSpecialProtocols = rvConnectionData.StationURLSpecialProtocols.Copy() - - if rvConnectionData.Time != nil { - copied.Time = rvConnectionData.Time.Copy() - } - - return copied -} - -// Equals checks if the passed Structure contains the same data as the current instance -func (rvConnectionData *RVConnectionData) Equals(structure StructureInterface) bool { - other := structure.(*RVConnectionData) - - if rvConnectionData.StructureVersion() == other.StructureVersion() { - return false - } - - if !rvConnectionData.StationURL.Equals(other.StationURL) { - return false - } - - if !bytes.Equal(rvConnectionData.SpecialProtocols, other.SpecialProtocols) { - return false - } - - if !rvConnectionData.StationURLSpecialProtocols.Equals(other.StationURLSpecialProtocols) { - return false - } - - if rvConnectionData.Time != nil && other.Time == nil { - return false - } - - if rvConnectionData.Time == nil && other.Time != nil { - return false - } - - if rvConnectionData.Time != nil && other.Time != nil { - if !rvConnectionData.Time.Equals(other.Time) { - return false - } - } - - return true -} - -// String returns a string representation of the struct -func (rvConnectionData *RVConnectionData) String() string { - return rvConnectionData.FormatToString(0) -} - -// FormatToString pretty-prints the struct data using the provided indentation level -func (rvConnectionData *RVConnectionData) FormatToString(indentationLevel int) string { - indentationValues := strings.Repeat("\t", indentationLevel+1) - indentationEnd := strings.Repeat("\t", indentationLevel) - - var b strings.Builder - - b.WriteString("RVConnectionData{\n") - b.WriteString(fmt.Sprintf("%sstructureVersion: %d,\n", indentationValues, rvConnectionData.structureVersion)) - b.WriteString(fmt.Sprintf("%sStationURL: %q,\n", indentationValues, rvConnectionData.StationURL.FormatToString(indentationLevel+1))) - b.WriteString(fmt.Sprintf("%sSpecialProtocols: %v,\n", indentationValues, rvConnectionData.SpecialProtocols)) - b.WriteString(fmt.Sprintf("%sStationURLSpecialProtocols: %q,\n", indentationValues, rvConnectionData.StationURLSpecialProtocols.FormatToString(indentationLevel+1))) - - if rvConnectionData.Time != nil { - b.WriteString(fmt.Sprintf("%sTime: %s\n", indentationValues, rvConnectionData.Time.FormatToString(indentationLevel+1))) - } else { - b.WriteString(fmt.Sprintf("%sTime: nil\n", indentationValues)) - } - - b.WriteString(fmt.Sprintf("%s}", indentationEnd)) - - return b.String() -} - -// NewRVConnectionData returns a new RVConnectionData -func NewRVConnectionData() *RVConnectionData { - rvConnectionData := &RVConnectionData{} - - return rvConnectionData -} - -// DateTime represents a NEX DateTime type -type DateTime struct { - value uint64 -} - -// Make initilizes a DateTime with the input data -func (dt *DateTime) Make(year, month, day, hour, minute, second int) *DateTime { - dt.value = uint64(second | (minute << 6) | (hour << 12) | (day << 17) | (month << 22) | (year << 26)) - - return dt -} - -// FromTimestamp converts a Time timestamp into a NEX DateTime -func (dt *DateTime) FromTimestamp(timestamp time.Time) *DateTime { - year := timestamp.Year() - month := int(timestamp.Month()) - day := timestamp.Day() - hour := timestamp.Hour() - minute := timestamp.Minute() - second := timestamp.Second() - - return dt.Make(year, month, day, hour, minute, second) -} - -// Now returns a NEX DateTime value of the current UTC time -func (dt *DateTime) Now() *DateTime { - return dt.FromTimestamp(time.Now().UTC()) -} - -// Value returns the stored DateTime time -func (dt *DateTime) Value() uint64 { - return dt.value -} - -// Second returns the seconds value stored in the DateTime -func (dt *DateTime) Second() int { - return int(dt.value & 63) -} - -// Minute returns the minutes value stored in the DateTime -func (dt *DateTime) Minute() int { - return int((dt.value >> 6) & 63) -} - -// Hour returns the hours value stored in the DateTime -func (dt *DateTime) Hour() int { - return int((dt.value >> 12) & 31) -} - -// Day returns the day value stored in the DateTime -func (dt *DateTime) Day() int { - return int((dt.value >> 17) & 31) -} - -// Month returns the month value stored in the DateTime -func (dt *DateTime) Month() time.Month { - return time.Month((dt.value >> 22) & 15) -} - -// Year returns the year value stored in the DateTime -func (dt *DateTime) Year() int { - return int(dt.value >> 26) -} - -// Standard returns the DateTime as a standard time.Time -func (dt *DateTime) Standard() time.Time { - return time.Date( - dt.Year(), - dt.Month(), - dt.Day(), - dt.Hour(), - dt.Minute(), - dt.Second(), - 0, - time.UTC, - ) -} - -// Copy returns a new copied instance of DateTime -func (dt *DateTime) Copy() *DateTime { - return NewDateTime(dt.value) -} - -// Equals checks if the passed Structure contains the same data as the current instance -func (dt *DateTime) Equals(other *DateTime) bool { - return dt.value == other.value -} - -// String returns a string representation of the struct -func (dt *DateTime) String() string { - return dt.FormatToString(0) -} - -// FormatToString pretty-prints the struct data using the provided indentation level -func (dt *DateTime) FormatToString(indentationLevel int) string { - indentationValues := strings.Repeat("\t", indentationLevel+1) - indentationEnd := strings.Repeat("\t", indentationLevel) - - var b strings.Builder - - b.WriteString("DateTime{\n") - b.WriteString(fmt.Sprintf("%svalue: %d (%s)\n", indentationValues, dt.value, dt.Standard().Format("2006-01-02 15:04:05"))) - b.WriteString(fmt.Sprintf("%s}", indentationEnd)) - - return b.String() -} - -// NewDateTime returns a new DateTime instance -func NewDateTime(value uint64) *DateTime { - return &DateTime{value: value} -} - -// StationURL contains the data for a NEX station URL. -// Uses pointers to check for nil, 0 is valid -type StationURL struct { - local bool // * Not part of the data structure. Used for easier lookups elsewhere - public bool // * Not part of the data structure. Used for easier lookups elsewhere - Scheme string - Fields *MutexMap[string, string] -} - -// SetLocal marks the StationURL as an local URL -func (s *StationURL) SetLocal() { - s.local = true - s.public = false -} - -// SetPublic marks the StationURL as an public URL -func (s *StationURL) SetPublic() { - s.local = false - s.public = true -} - -// IsLocal checks if the StationURL is a local URL -func (s *StationURL) IsLocal() bool { - return s.local -} - -// IsPublic checks if the StationURL is a public URL -func (s *StationURL) IsPublic() bool { - return s.public -} - -// FromString parses the StationURL data from a string -func (s *StationURL) FromString(str string) { - if str == "" { - return - } - - split := strings.Split(str, ":/") - - s.Scheme = split[0] - - // * Return if there are no fields - if split[1] == "" { - return - } - - fields := strings.Split(split[1], ";") - - for i := 0; i < len(fields); i++ { - field := strings.Split(fields[i], "=") - - key := field[0] - value := field[1] - - s.Fields.Set(key, value) - } -} - -// EncodeToString encodes the StationURL into a string -func (s *StationURL) EncodeToString() string { - // * Don't return anything if no scheme is set - if s.Scheme == "" { - return "" - } - - fields := []string{} - - s.Fields.Each(func(key, value string) bool { - fields = append(fields, fmt.Sprintf("%s=%s", key, value)) - return false - }) - - return s.Scheme + ":/" + strings.Join(fields, ";") -} - -// Copy returns a new copied instance of StationURL -func (s *StationURL) Copy() *StationURL { - return NewStationURL(s.EncodeToString()) -} - -// Equals checks if the passed Structure contains the same data as the current instance -func (s *StationURL) Equals(other *StationURL) bool { - return s.EncodeToString() == other.EncodeToString() -} - -// String returns a string representation of the struct -func (s *StationURL) String() string { - return s.FormatToString(0) -} - -// FormatToString pretty-prints the struct data using the provided indentation level -func (s *StationURL) FormatToString(indentationLevel int) string { - indentationValues := strings.Repeat("\t", indentationLevel+1) - indentationEnd := strings.Repeat("\t", indentationLevel) - - var b strings.Builder - - b.WriteString("StationURL{\n") - b.WriteString(fmt.Sprintf("%surl: %q\n", indentationValues, s.EncodeToString())) - b.WriteString(fmt.Sprintf("%s}", indentationEnd)) - - return b.String() -} - -// NewStationURL returns a new StationURL -func NewStationURL(str string) *StationURL { - stationURL := &StationURL{ - Fields: NewMutexMap[string, string](), - } - - stationURL.FromString(str) - - return stationURL -} - -// Result is sent in methods which query large objects -type Result struct { - Code uint32 -} - -// IsSuccess returns true if the Result is a success -func (result *Result) IsSuccess() bool { - return int(result.Code)&errorMask == 0 -} - -// IsError returns true if the Result is a error -func (result *Result) IsError() bool { - return int(result.Code)&errorMask != 0 -} - -// ExtractFromStream extracts a Result structure from a stream -func (result *Result) ExtractFromStream(stream *StreamIn) error { - code, err := stream.ReadUInt32LE() - if err != nil { - return fmt.Errorf("Failed to read Result code. %s", err.Error()) - } - - result.Code = code - - return nil -} - -// Bytes encodes the Result and returns a byte array -func (result *Result) Bytes(stream *StreamOut) []byte { - stream.WriteUInt32LE(result.Code) - - return stream.Bytes() -} - -// Copy returns a new copied instance of Result -func (result *Result) Copy() *Result { - return NewResult(result.Code) -} - -// Equals checks if the passed Structure contains the same data as the current instance -func (result *Result) Equals(other *Result) bool { - return result.Code == other.Code -} - -// String returns a string representation of the struct -func (result *Result) String() string { - return result.FormatToString(0) -} - -// FormatToString pretty-prints the struct data using the provided indentation level -func (result *Result) FormatToString(indentationLevel int) string { - indentationValues := strings.Repeat("\t", indentationLevel+1) - indentationEnd := strings.Repeat("\t", indentationLevel) - - var b strings.Builder - - b.WriteString("Result{\n") - - if result.IsSuccess() { - b.WriteString(fmt.Sprintf("%scode: %d (success)\n", indentationValues, result.Code)) - } else { - b.WriteString(fmt.Sprintf("%scode: %d (error)\n", indentationValues, result.Code)) - } - - b.WriteString(fmt.Sprintf("%s}", indentationEnd)) - - return b.String() -} - -// NewResult returns a new Result -func NewResult(code uint32) *Result { - return &Result{code} -} - -// NewResultSuccess returns a new Result set as a success -func NewResultSuccess(code uint32) *Result { - return NewResult(uint32(int(code) & ^errorMask)) -} - -// NewResultError returns a new Result set as an error -func NewResultError(code uint32) *Result { - return NewResult(uint32(int(code) | errorMask)) -} - -// ResultRange is sent in methods which query large objects -type ResultRange struct { - Structure - Offset uint32 - Length uint32 -} - -// ExtractFromStream extracts a ResultRange structure from a stream -func (resultRange *ResultRange) ExtractFromStream(stream *StreamIn) error { - offset, err := stream.ReadUInt32LE() - if err != nil { - return fmt.Errorf("Failed to read ResultRange offset. %s", err.Error()) - } - - length, err := stream.ReadUInt32LE() - if err != nil { - return fmt.Errorf("Failed to read ResultRange length. %s", err.Error()) - } - - resultRange.Offset = offset - resultRange.Length = length - - return nil -} - -// Copy returns a new copied instance of ResultRange -func (resultRange *ResultRange) Copy() StructureInterface { - copied := NewResultRange() - - copied.SetStructureVersion(resultRange.StructureVersion()) - copied.Offset = resultRange.Offset - copied.Length = resultRange.Length - - return copied -} - -// Equals checks if the passed Structure contains the same data as the current instance -func (resultRange *ResultRange) Equals(structure StructureInterface) bool { - other := structure.(*ResultRange) - - if resultRange.StructureVersion() == other.StructureVersion() { - return false - } - - if resultRange.Offset != other.Offset { - return false - } - - if resultRange.Length != other.Length { - return false - } - - return true -} - -// String returns a string representation of the struct -func (resultRange *ResultRange) String() string { - return resultRange.FormatToString(0) -} - -// FormatToString pretty-prints the struct data using the provided indentation level -func (resultRange *ResultRange) FormatToString(indentationLevel int) string { - indentationValues := strings.Repeat("\t", indentationLevel+1) - indentationEnd := strings.Repeat("\t", indentationLevel) - - var b strings.Builder - - b.WriteString("ResultRange{\n") - b.WriteString(fmt.Sprintf("%sstructureVersion: %d,\n", indentationValues, resultRange.structureVersion)) - b.WriteString(fmt.Sprintf("%sOffset: %d,\n", indentationValues, resultRange.Offset)) - b.WriteString(fmt.Sprintf("%sLength: %d\n", indentationValues, resultRange.Length)) - b.WriteString(fmt.Sprintf("%s}", indentationEnd)) - - return b.String() -} - -// NewResultRange returns a new ResultRange -func NewResultRange() *ResultRange { - return &ResultRange{} -} - -// Variant can hold one of 7 types; nil, int64, float64, bool, string, DateTime, or uint64 -type Variant struct { - TypeID uint8 - // * In reality this type does not have this many fields - // * It only stores the type ID and then the value - // * However to get better typing, we opt to store each possible - // * type as it's own field and just check typeID to know which it has - Int64 int64 - Float64 float64 - Bool bool - Str string - DateTime *DateTime - UInt64 uint64 - QUUID *QUUID -} - -// ExtractFromStream extracts a Variant structure from a stream -func (v *Variant) ExtractFromStream(stream *StreamIn) error { - var err error - - v.TypeID, err = stream.ReadUInt8() - if err != nil { - return fmt.Errorf("Failed to read Variant type ID. %s", err.Error()) - } - - // * A type ID of 0 means no value - switch v.TypeID { - case 1: // * sint64 - v.Int64, err = stream.ReadInt64LE() - case 2: // * double - v.Float64, err = stream.ReadFloat64LE() - case 3: // * bool - v.Bool, err = stream.ReadBool() - case 4: // * string - v.Str, err = stream.ReadString() - case 5: // * datetime - v.DateTime, err = stream.ReadDateTime() - case 6: // * uint64 - v.UInt64, err = stream.ReadUInt64LE() - case 7: // * qUUID - v.QUUID, err = stream.ReadQUUID() - } - - // * These errors contain details about each of the values type - // * No need to return special errors for each value type - if err != nil { - return fmt.Errorf("Failed to read Variant value. %s", err.Error()) - } - - return nil -} - -// Bytes encodes the Variant and returns a byte array -func (v *Variant) Bytes(stream *StreamOut) []byte { - stream.WriteUInt8(v.TypeID) - - // * A type ID of 0 means no value - switch v.TypeID { - case 1: // * sint64 - stream.WriteInt64LE(v.Int64) - case 2: // * double - stream.WriteFloat64LE(v.Float64) - case 3: // * bool - stream.WriteBool(v.Bool) - case 4: // * string - stream.WriteString(v.Str) - case 5: // * datetime - stream.WriteDateTime(v.DateTime) - case 6: // * uint64 - stream.WriteUInt64LE(v.UInt64) - case 7: // * qUUID - stream.WriteQUUID(v.QUUID) - } - - return stream.Bytes() -} - -// Copy returns a new copied instance of Variant -func (v *Variant) Copy() *Variant { - copied := NewVariant() - - copied.TypeID = v.TypeID - copied.Int64 = v.Int64 - copied.Float64 = v.Float64 - copied.Bool = v.Bool - copied.Str = v.Str - - if v.DateTime != nil { - copied.DateTime = v.DateTime.Copy() - } - - copied.UInt64 = v.UInt64 - - if v.QUUID != nil { - copied.QUUID = v.QUUID.Copy() - } - - return copied -} - -// Equals checks if the passed Structure contains the same data as the current instance -func (v *Variant) Equals(other *Variant) bool { - if v.TypeID != other.TypeID { - return false - } - - // * A type ID of 0 means no value - switch v.TypeID { - case 0: // * no value, always equal - return true - case 1: // * sint64 - return v.Int64 == other.Int64 - case 2: // * double - return v.Float64 == other.Float64 - case 3: // * bool - return v.Bool == other.Bool - case 4: // * string - return v.Str == other.Str - case 5: // * datetime - return v.DateTime.Equals(other.DateTime) - case 6: // * uint64 - return v.UInt64 == other.UInt64 - case 7: // * qUUID - return v.QUUID.Equals(other.QUUID) - default: // * Something went horribly wrong - return false - } -} - -// String returns a string representation of the struct -func (v *Variant) String() string { - return v.FormatToString(0) -} - -// FormatToString pretty-prints the struct data using the provided indentation level -func (v *Variant) FormatToString(indentationLevel int) string { - indentationValues := strings.Repeat("\t", indentationLevel+1) - indentationEnd := strings.Repeat("\t", indentationLevel) - - var b strings.Builder - - b.WriteString("Variant{\n") - b.WriteString(fmt.Sprintf("%sTypeID: %d\n", indentationValues, v.TypeID)) - - switch v.TypeID { - case 0: // * no value - b.WriteString(fmt.Sprintf("%svalue: nil\n", indentationValues)) - case 1: // * sint64 - b.WriteString(fmt.Sprintf("%svalue: %d\n", indentationValues, v.Int64)) - case 2: // * double - b.WriteString(fmt.Sprintf("%svalue: %g\n", indentationValues, v.Float64)) - case 3: // * bool - b.WriteString(fmt.Sprintf("%svalue: %t\n", indentationValues, v.Bool)) - case 4: // * string - b.WriteString(fmt.Sprintf("%svalue: %q\n", indentationValues, v.Str)) - case 5: // * datetime - b.WriteString(fmt.Sprintf("%svalue: %s\n", indentationValues, v.DateTime.FormatToString(indentationLevel+1))) - case 6: // * uint64 - b.WriteString(fmt.Sprintf("%svalue: %d\n", indentationValues, v.UInt64)) - case 7: // * qUUID - b.WriteString(fmt.Sprintf("%svalue: %s\n", indentationValues, v.QUUID.FormatToString(indentationLevel+1))) - default: - b.WriteString(fmt.Sprintf("%svalue: Unknown\n", indentationValues)) - } - - b.WriteString(fmt.Sprintf("%s}", indentationEnd)) - - return b.String() -} - -// NewVariant returns a new Variant -func NewVariant() *Variant { - return &Variant{} -} - -// ClassVersionContainer contains version info for structurs used in verbose RMC messages -type ClassVersionContainer struct { - Structure - ClassVersions map[string]uint16 -} - -// ExtractFromStream extracts a ClassVersionContainer structure from a stream -func (cvc *ClassVersionContainer) ExtractFromStream(stream *StreamIn) error { - length, err := stream.ReadUInt32LE() - if err != nil { - return fmt.Errorf("Failed to read ClassVersionContainer length. %s", err.Error()) - } - - for i := 0; i < int(length); i++ { - name, err := stream.ReadString() - if err != nil { - return fmt.Errorf("Failed to read ClassVersionContainer Structure name. %s", err.Error()) - } - - version, err := stream.ReadUInt16LE() - if err != nil { - return fmt.Errorf("Failed to read ClassVersionContainer %s version. %s", name, err.Error()) - } - - cvc.ClassVersions[name] = version - } - - return nil -} - -// Bytes encodes the ClassVersionContainer and returns a byte array -func (cvc *ClassVersionContainer) Bytes(stream *StreamOut) []byte { - stream.WriteUInt32LE(uint32(len(cvc.ClassVersions))) - - for name, version := range cvc.ClassVersions { - stream.WriteString(name) - stream.WriteUInt16LE(version) - } - - return stream.Bytes() -} - -// Copy returns a new copied instance of ClassVersionContainer -func (cvc *ClassVersionContainer) Copy() StructureInterface { - copied := NewClassVersionContainer() - - for name, version := range cvc.ClassVersions { - copied.ClassVersions[name] = version - } - - return copied -} - -// Equals checks if the passed Structure contains the same data as the current instance -func (cvc *ClassVersionContainer) Equals(structure StructureInterface) bool { - other := structure.(*ClassVersionContainer) - - if len(cvc.ClassVersions) != len(other.ClassVersions) { - return false - } - - for name, version1 := range cvc.ClassVersions { - version2, ok := other.ClassVersions[name] - if !ok || version1 != version2 { - return false - } - } - - return true -} - -// String returns a string representation of the struct -func (cvc *ClassVersionContainer) String() string { - return cvc.FormatToString(0) -} - -// FormatToString pretty-prints the struct data using the provided indentation level -func (cvc *ClassVersionContainer) FormatToString(indentationLevel int) string { - indentationValues := strings.Repeat("\t", indentationLevel+1) - indentationListValues := strings.Repeat("\t", indentationLevel+2) - indentationEnd := strings.Repeat("\t", indentationLevel) - - var b strings.Builder - - b.WriteString("ClassVersionContainer{\n") - b.WriteString(fmt.Sprintf("%sClassVersions: {\n", indentationValues)) - - for name, version := range cvc.ClassVersions { - b.WriteString(fmt.Sprintf("%s%s: %d\n", indentationListValues, name, version)) - } - - b.WriteString(fmt.Sprintf("%s}\n", indentationValues)) - b.WriteString(fmt.Sprintf("%s}", indentationEnd)) - - return b.String() -} - -// NewClassVersionContainer returns a new ClassVersionContainer -func NewClassVersionContainer() *ClassVersionContainer { - return &ClassVersionContainer{ - ClassVersions: make(map[string]uint16), - } -} - -// QUUID represents a QRV qUUID type. This type encodes a UUID in little-endian byte order -type QUUID struct { - Data []byte -} - -// ExtractFromStream extracts a qUUID structure from a stream -func (qu *QUUID) ExtractFromStream(stream *StreamIn) error { - if stream.Remaining() < int(16) { - return errors.New("Not enough data left to read qUUID") - } - - qu.Data = stream.ReadBytesNext(16) - - return nil -} - -// Bytes encodes the qUUID and returns a byte array -func (qu *QUUID) Bytes(stream *StreamOut) []byte { - stream.Grow(int64(len(qu.Data))) - stream.WriteBytesNext(qu.Data) - - return stream.Bytes() -} - -// Copy returns a new copied instance of qUUID -func (qu *QUUID) Copy() *QUUID { - copied := NewQUUID() - - copied.Data = make([]byte, len(qu.Data)) - - copy(copied.Data, qu.Data) - - return copied -} - -// Equals checks if the passed Structure contains the same data as the current instance -func (qu *QUUID) Equals(other *QUUID) bool { - return qu.GetStringValue() == other.GetStringValue() -} - -// String returns a string representation of the struct -func (qu *QUUID) String() string { - return qu.FormatToString(0) -} - -// FormatToString pretty-prints the struct data using the provided indentation level -func (qu *QUUID) FormatToString(indentationLevel int) string { - indentationValues := strings.Repeat("\t", indentationLevel+1) - indentationEnd := strings.Repeat("\t", indentationLevel) - - var b strings.Builder - - b.WriteString("qUUID{\n") - b.WriteString(fmt.Sprintf("%sUUID: %s\n", indentationValues, qu.GetStringValue())) - b.WriteString(fmt.Sprintf("%s}", indentationEnd)) - - return b.String() -} - -// GetStringValue returns the UUID encoded in the qUUID -func (qu *QUUID) GetStringValue() string { - // * Create copy of the data since slices.Reverse modifies the slice in-line - data := make([]byte, len(qu.Data)) - copy(data, qu.Data) - - if len(data) != 16 { - // * Default dummy UUID as found in WATCH_DOGS - return "00000000-0000-0000-0000-000000000002" - } - - section1 := data[0:4] - section2 := data[4:6] - section3 := data[6:8] - section4 := data[8:10] - section5_1 := data[10:12] - section5_2 := data[12:14] - section5_3 := data[14:16] - - slices.Reverse(section1) - slices.Reverse(section2) - slices.Reverse(section3) - slices.Reverse(section4) - slices.Reverse(section5_1) - slices.Reverse(section5_2) - slices.Reverse(section5_3) - - var b strings.Builder - - b.WriteString(hex.EncodeToString(section1)) - b.WriteString("-") - b.WriteString(hex.EncodeToString(section2)) - b.WriteString("-") - b.WriteString(hex.EncodeToString(section3)) - b.WriteString("-") - b.WriteString(hex.EncodeToString(section4)) - b.WriteString("-") - b.WriteString(hex.EncodeToString(section5_1)) - b.WriteString(hex.EncodeToString(section5_2)) - b.WriteString(hex.EncodeToString(section5_3)) - - return b.String() -} - -// FromString converts a UUID string to a qUUID -func (qu *QUUID) FromString(uuid string) error { - - sections := strings.Split(uuid, "-") - if len(sections) != 5 { - return fmt.Errorf("Invalid UUID. Not enough sections. Expected 5, got %d", len(sections)) - } - - data := make([]byte, 0, 16) - - var appendSection = func(section string, expectedSize int) error { - sectionBytes, err := hex.DecodeString(section) - if err != nil { - return err - } - - if len(sectionBytes) != expectedSize { - return fmt.Errorf("Unexpected section size. Expected %d, got %d", expectedSize, len(sectionBytes)) - } - - data = append(data, sectionBytes...) - - return nil - } - - if err := appendSection(sections[0], 4); err != nil { - return fmt.Errorf("Failed to read UUID section 1. %s", err.Error()) - } - - if err := appendSection(sections[1], 2); err != nil { - return fmt.Errorf("Failed to read UUID section 2. %s", err.Error()) - } - - if err := appendSection(sections[2], 2); err != nil { - return fmt.Errorf("Failed to read UUID section 3. %s", err.Error()) - } - - if err := appendSection(sections[3], 2); err != nil { - return fmt.Errorf("Failed to read UUID section 4. %s", err.Error()) - } - - if err := appendSection(sections[4], 6); err != nil { - return fmt.Errorf("Failed to read UUID section 5. %s", err.Error()) - } - - slices.Reverse(data[0:4]) - slices.Reverse(data[4:6]) - slices.Reverse(data[6:8]) - slices.Reverse(data[8:10]) - slices.Reverse(data[10:12]) - slices.Reverse(data[12:14]) - slices.Reverse(data[14:16]) - - qu.Data = make([]byte, 0, 16) - - copy(qu.Data, data) - - return nil -} - -// NewQUUID returns a new qUUID -func NewQUUID() *QUUID { - return &QUUID{ - Data: make([]byte, 0, 16), - } -} diff --git a/types/any_data_holder.go b/types/any_data_holder.go new file mode 100644 index 00000000..9964d523 --- /dev/null +++ b/types/any_data_holder.go @@ -0,0 +1,109 @@ +package types + +import ( + "fmt" +) + +// AnyDataHolderObjects holds a mapping of RVTypes that are accessible in a AnyDataHolder +var AnyDataHolderObjects = make(map[string]RVType) + +// RegisterDataHolderType registers a RVType to be accessible in a AnyDataHolder +func RegisterDataHolderType(name string, rvType RVType) { + AnyDataHolderObjects[name] = rvType +} + +// AnyDataHolder is a class which can contain any Structure +type AnyDataHolder struct { + TypeName string // TODO - Replace this with String? + Length1 uint32 // TODO - Replace this with PrimitiveU32? + Length2 uint32 // TODO - Replace this with PrimitiveU32? + ObjectData RVType +} + +// WriteTo writes the AnyDataholder to the given writable +func (adh *AnyDataHolder) WriteTo(writable Writable) { + contentWritable := writable.CopyNew() + + adh.ObjectData.WriteTo(contentWritable) + + objectData := contentWritable.Bytes() + typeName := String(adh.TypeName) + length1 := uint32(len(objectData) + 4) + length2 := uint32(len(objectData)) + + typeName.WriteTo(writable) + writable.WritePrimitiveUInt32LE(length1) + writable.WritePrimitiveUInt32LE(length2) + writable.Write(objectData) +} + +// ExtractFrom extracts the AnyDataholder to the given readable +func (adh *AnyDataHolder) ExtractFrom(readable Readable) error { + var typeName String + + err := typeName.ExtractFrom(readable) + if err != nil { + return fmt.Errorf("Failed to read DanyDataHolder type name. %s", err.Error()) + } + + length1, err := readable.ReadPrimitiveUInt32LE() + if err != nil { + return fmt.Errorf("Failed to read DanyDataHolder length 1. %s", err.Error()) + } + + length2, err := readable.ReadPrimitiveUInt32LE() + if err != nil { + return fmt.Errorf("Failed to read DanyDataHolder length 2. %s", err.Error()) + } + + if _, ok := AnyDataHolderObjects[string(typeName)]; !ok { + return fmt.Errorf("Unknown AnyDataHolder type: %s", string(typeName)) + } + + adh.ObjectData = AnyDataHolderObjects[string(typeName)].Copy() + + if err := adh.ObjectData.ExtractFrom(readable); err != nil { + return fmt.Errorf("Failed to read DanyDataHolder object data. %s", err.Error()) + } + + adh.TypeName = string(typeName) + adh.Length1 = length1 + adh.Length2 = length2 + + return nil +} + +// Copy returns a new copied instance of DataHolder +func (adh *AnyDataHolder) Copy() *AnyDataHolder { + copied := NewAnyDataHolder() + + copied.TypeName = adh.TypeName + copied.Length1 = adh.Length1 + copied.Length2 = adh.Length2 + copied.ObjectData = adh.ObjectData.Copy() + + return copied +} + +// Equals checks if the passed Structure contains the same data as the current instance +func (adh *AnyDataHolder) Equals(other *AnyDataHolder) bool { + if adh.TypeName != other.TypeName { + return false + } + + if adh.Length1 != other.Length1 { + return false + } + + if adh.Length2 != other.Length2 { + return false + } + + return adh.ObjectData.Equals(other.ObjectData) +} + +// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? +// NewAnyDataHolder returns a new AnyDataHolder +func NewAnyDataHolder() *AnyDataHolder { + return &AnyDataHolder{} +} diff --git a/types/buffer.go b/types/buffer.go new file mode 100644 index 00000000..db0cb65b --- /dev/null +++ b/types/buffer.go @@ -0,0 +1,61 @@ +package types + +import ( + "bytes" + "fmt" +) + +// TODO - Should this have a "Value"-kind of method to get the original value? + +// Buffer is a type alias of []byte with receiver methods to conform to RVType +type Buffer []byte // TODO - Should we make this a struct instead of a type alias? + +// WriteTo writes the []byte to the given writable +func (b *Buffer) WriteTo(writable Writable) { + data := *b + length := len(data) + + writable.WritePrimitiveUInt32LE(uint32(length)) + + if length > 0 { + writable.Write([]byte(data)) + } +} + +// ExtractFrom extracts the Buffer to the given readable +func (b *Buffer) ExtractFrom(readable Readable) error { + length, err := readable.ReadPrimitiveUInt32LE() + if err != nil { + return fmt.Errorf("Failed to read NEX Buffer length. %s", err.Error()) + } + + data, err := readable.Read(uint64(length)) + if err != nil { + return fmt.Errorf("Failed to read NEX Buffer data. %s", err.Error()) + } + + *b = Buffer(data) + + return nil +} + +// Copy returns a pointer to a copy of the Buffer. Requires type assertion when used +func (b Buffer) Copy() RVType { + return &b +} + +// Equals checks if the input is equal in value to the current instance +func (b *Buffer) Equals(o RVType) bool { + if _, ok := o.(*Buffer); !ok { + return false + } + + return bytes.Equal([]byte(*b), []byte(*o.(*Buffer))) +} + +// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? +// NewBuffer returns a new Buffer +func NewBuffer() *Buffer { + var b Buffer + return &b +} diff --git a/types/class_version_container.go b/types/class_version_container.go new file mode 100644 index 00000000..56799228 --- /dev/null +++ b/types/class_version_container.go @@ -0,0 +1,42 @@ +package types + +// ClassVersionContainer contains version info for Structures used in verbose RMC messages +type ClassVersionContainer struct { + Structure + ClassVersions *Map[*String, *PrimitiveU16] +} + +// WriteTo writes the ClassVersionContainer to the given writable +func (cvc *ClassVersionContainer) WriteTo(writable Writable) { + cvc.ClassVersions.WriteTo(writable) +} + +// ExtractFrom extracts the ClassVersionContainer to the given readable +func (cvc *ClassVersionContainer) ExtractFrom(readable Readable) error { + cvc.ClassVersions = NewMap(NewString(), NewPrimitiveU16()) + + return cvc.ClassVersions.ExtractFrom(readable) +} + +// Copy returns a pointer to a copy of the ClassVersionContainer. Requires type assertion when used +func (cvc ClassVersionContainer) Copy() RVType { + copied := NewClassVersionContainer() + copied.ClassVersions = cvc.ClassVersions.Copy().(*Map[*String, *PrimitiveU16]) + + return copied +} + +// Equals checks if the input is equal in value to the current instance +func (cvc *ClassVersionContainer) Equals(o RVType) bool { + if _, ok := o.(*ClassVersionContainer); !ok { + return false + } + + return cvc.ClassVersions.Equals(o) +} + +// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? +// NewClassVersionContainer returns a new ClassVersionContainer +func NewClassVersionContainer() *ClassVersionContainer { + return &ClassVersionContainer{} +} diff --git a/types/datetime.go b/types/datetime.go new file mode 100644 index 00000000..68c988b0 --- /dev/null +++ b/types/datetime.go @@ -0,0 +1,141 @@ +package types + +import ( + "fmt" + "strings" + "time" +) + +// DateTime represents a NEX DateTime type +type DateTime struct { + value uint64 // TODO - Replace this with PrimitiveU64? +} + +// WriteTo writes the DateTime to the given writable +func (dt *DateTime) WriteTo(writable Writable) { + writable.WritePrimitiveUInt64LE(dt.value) +} + +// ExtractFrom extracts the DateTime to the given readable +func (dt *DateTime) ExtractFrom(readable Readable) error { + value, err := readable.ReadPrimitiveUInt64LE() + if err != nil { + return fmt.Errorf("Failed to read DateTime value. %s", err.Error()) + } + + dt.value = value + + return nil +} + +// Copy returns a new copied instance of DateTime +func (dt DateTime) Copy() RVType { + return NewDateTime(dt.value) +} + +// Equals checks if the input is equal in value to the current instance +func (dt *DateTime) Equals(o RVType) bool { + if _, ok := o.(*DateTime); !ok { + return false + } + + return dt.value == o.(*DateTime).value +} + +// Make initilizes a DateTime with the input data +func (dt *DateTime) Make(year, month, day, hour, minute, second int) *DateTime { + dt.value = uint64(second | (minute << 6) | (hour << 12) | (day << 17) | (month << 22) | (year << 26)) + + return dt +} + +// FromTimestamp converts a Time timestamp into a NEX DateTime +func (dt *DateTime) FromTimestamp(timestamp time.Time) *DateTime { + year := timestamp.Year() + month := int(timestamp.Month()) + day := timestamp.Day() + hour := timestamp.Hour() + minute := timestamp.Minute() + second := timestamp.Second() + + return dt.Make(year, month, day, hour, minute, second) +} + +// Now returns a NEX DateTime value of the current UTC time +func (dt *DateTime) Now() *DateTime { + return dt.FromTimestamp(time.Now().UTC()) +} + +// Value returns the stored DateTime time +func (dt *DateTime) Value() uint64 { + return dt.value +} + +// Second returns the seconds value stored in the DateTime +func (dt *DateTime) Second() int { + return int(dt.value & 63) +} + +// Minute returns the minutes value stored in the DateTime +func (dt *DateTime) Minute() int { + return int((dt.value >> 6) & 63) +} + +// Hour returns the hours value stored in the DateTime +func (dt *DateTime) Hour() int { + return int((dt.value >> 12) & 31) +} + +// Day returns the day value stored in the DateTime +func (dt *DateTime) Day() int { + return int((dt.value >> 17) & 31) +} + +// Month returns the month value stored in the DateTime +func (dt *DateTime) Month() time.Month { + return time.Month((dt.value >> 22) & 15) +} + +// Year returns the year value stored in the DateTime +func (dt *DateTime) Year() int { + return int(dt.value >> 26) +} + +// Standard returns the DateTime as a standard time.Time +func (dt *DateTime) Standard() time.Time { + return time.Date( + dt.Year(), + dt.Month(), + dt.Day(), + dt.Hour(), + dt.Minute(), + dt.Second(), + 0, + time.UTC, + ) +} + +// String returns a string representation of the struct +func (dt *DateTime) String() string { + return dt.FormatToString(0) +} + +// FormatToString pretty-prints the struct data using the provided indentation level +func (dt *DateTime) FormatToString(indentationLevel int) string { + indentationValues := strings.Repeat("\t", indentationLevel+1) + indentationEnd := strings.Repeat("\t", indentationLevel) + + var b strings.Builder + + b.WriteString("DateTime{\n") + b.WriteString(fmt.Sprintf("%svalue: %d (%s)\n", indentationValues, dt.value, dt.Standard().Format("2006-01-02 15:04:05"))) + b.WriteString(fmt.Sprintf("%s}", indentationEnd)) + + return b.String() +} + +// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? +// NewDateTime returns a new DateTime instance +func NewDateTime(value uint64) *DateTime { + return &DateTime{value: value} +} diff --git a/types/list.go b/types/list.go new file mode 100644 index 00000000..77fc9e03 --- /dev/null +++ b/types/list.go @@ -0,0 +1,97 @@ +package types + +import "errors" + +// List represents a Quazal Rendez-Vous/NEX List type +type List[T RVType] struct { + real []T + rvType T +} + +// WriteTo writes the bool to the given writable +func (l *List[T]) WriteTo(writable Writable) { + writable.WritePrimitiveUInt32LE(uint32(len(l.real))) + + for _, v := range l.real { + v.WriteTo(writable) + } +} + +// ExtractFrom extracts the bool to the given readable +func (l *List[T]) ExtractFrom(readable Readable) error { + length, err := readable.ReadPrimitiveUInt32LE() + if err != nil { + return err + } + + slice := make([]T, 0, length) + + for i := 0; i < int(length); i++ { + value := l.rvType.Copy() + if err := value.ExtractFrom(readable); err != nil { + return err + } + + slice = append(slice, value.(T)) + } + + l.real = slice + + return nil +} + +// Copy returns a pointer to a copy of the List[T]. Requires type assertion when used +func (l List[T]) Copy() RVType { + copied := NewList(l.rvType) + copied.real = make([]T, len(l.real)) + + for i, v := range l.real { + copied.real[i] = v.Copy().(T) + } + + return copied +} + +// Equals checks if the input is equal in value to the current instance +func (l *List[T]) Equals(o RVType) bool { + if _, ok := o.(*List[T]); !ok { + return false + } + + other := o.(*List[T]) + + if len(l.real) != len(other.real) { + return false + } + + for i := 0; i < len(l.real); i++ { + if !l.real[i].Equals(other.real[i]) { + return false + } + } + + return true +} + +// Append appends an element to the List internal slice +func (l *List[T]) Append(value T) { + l.real = append(l.real, value) +} + +// Get returns an element at the given index. Returns an error if the index is OOB +func (l *List[T]) Get(index int) (T, error) { + if index < 0 || index >= len(l.real) { + return l.rvType.Copy().(T), errors.New("Index out of bounds") + } + + return l.real[index], nil +} + +// TODO - Should this take in a default value, or take in nothing and have a "SetType"-kind of method? +// NewList returns a new List of the provided type +func NewList[T RVType](rvType T) *List[T] { + return &List[T]{ + real: make([]T, 0), + rvType: rvType.Copy().(T), + } +} diff --git a/types/map.go b/types/map.go new file mode 100644 index 00000000..c798b6d4 --- /dev/null +++ b/types/map.go @@ -0,0 +1,152 @@ +package types + +// Map represents a Quazal Rendez-Vous/NEX Map type +type Map[K RVType, V RVType] struct { + // * Rendez-Vous/NEX MapMap types can have ANY value for the key, but Go requires + // * map keys to implement the "comparable" constraint. This is not possible with + // * RVTypes. We have to either break spec and only allow primitives as Map keys, + // * or store the key/value types indirectly + keys []K + values []V + keyType K + valueType V +} + +// WriteTo writes the bool to the given writable +func (m *Map[K, V]) WriteTo(writable Writable) { + writable.WritePrimitiveUInt32LE(uint32(m.Size())) + + for i := 0; i < len(m.keys); i++ { + m.keys[i].WriteTo(writable) + m.values[i].WriteTo(writable) + } +} + +// ExtractFrom extracts the bool to the given readable +func (m *Map[K, V]) ExtractFrom(readable Readable) error { + length, err := readable.ReadPrimitiveUInt32LE() + if err != nil { + return err + } + + keys := make([]K, 0, length) + values := make([]V, 0, length) + + for i := 0; i < int(length); i++ { + key := m.keyType.Copy() + if err := key.ExtractFrom(readable); err != nil { + return err + } + + value := m.valueType.Copy() + if err := value.ExtractFrom(readable); err != nil { + return err + } + + keys = append(keys, value.(K)) + values = append(values, value.(V)) + } + + m.keys = keys + m.values = values + + return nil +} + +// Copy returns a pointer to a copy of the Map[K, V]. Requires type assertion when used +func (m Map[K, V]) Copy() RVType { + copied := NewMap(m.keyType, m.valueType) + copied.keys = make([]K, len(m.keys)) + copied.values = make([]V, len(m.values)) + + for i := 0; i < len(m.keys); i++ { + copied.keys[i] = m.keys[i].Copy().(K) + copied.values[i] = m.values[i].Copy().(V) + } + + return copied +} + +// Equals checks if the input is equal in value to the current instance +func (m *Map[K, V]) Equals(o RVType) bool { + if _, ok := o.(*Map[K, V]); !ok { + return false + } + + other := o.(*Map[K, V]) + + if len(m.keys) != len(other.keys) { + return false + } + + if len(m.values) != len(other.values) { + return false + } + + for i := 0; i < len(m.keys); i++ { + if !m.keys[i].Equals(other.keys[i]) { + return false + } + + if !m.values[i].Equals(other.values[i]) { + return false + } + } + + return true +} + +// Set sets an element to the Map internal slices +func (m *Map[K, V]) Set(key K, value V) { + var index int = -1 + + for i := 0; i < len(m.keys); i++ { + if m.keys[i].Equals(key) { + index = i + break + } + } + + // * Replace the element if exists, otherwise push new + if index != -1 { + m.keys[index] = key + m.values[index] = value + } else { + m.keys = append(m.keys, key) + m.values = append(m.values, value) + } +} + +// Get returns an element from the Map. If not found, "ok" is false +func (m *Map[K, V]) Get(key K) (V, bool) { + var index int = -1 + + for i := 0; i < len(m.keys); i++ { + if m.keys[i].Equals(key) { + index = i + break + } + } + + if index != -1 { + return m.values[index], true + } + + return m.valueType.Copy().(V), false +} + +// Size returns the length of the Map +func (m *Map[K, V]) Size() int { + return len(m.keys) +} + +// TODO - Should this take in a default value, or take in nothing and have a "SetKeyType"/"SetValueType" kind of methods? +// NewMap returns a new Map of the provided type +func NewMap[K RVType, V RVType](keyType K, valueType V) *Map[K, V] { + return &Map[K, V]{ + keys: make([]K, 0), + values: make([]V, 0), + keyType: keyType.Copy().(K), + valueType: valueType.Copy().(V), + } +} diff --git a/types/null_data.go b/types/null_data.go new file mode 100644 index 00000000..9d9c0428 --- /dev/null +++ b/types/null_data.go @@ -0,0 +1,84 @@ +package types + +import ( + "errors" + "fmt" + "strings" +) + +// NullData is a Structure with no fields +type NullData struct { + Structure +} + +// WriteTo writes the NullData to the given writable +func (nd *NullData) WriteTo(writable Writable) { + if writable.UseStructureHeader() { + writable.WritePrimitiveUInt8(nd.StructureVersion()) + writable.WritePrimitiveUInt32LE(0) + } +} + +// ExtractFrom extracts the NullData to the given readable +func (nd *NullData) ExtractFrom(readable Readable) error { + if readable.UseStructureHeader() { + version, err := readable.ReadPrimitiveUInt8() + if err != nil { + return fmt.Errorf("Failed to read NullData version. %s", err.Error()) + } + + contentLength, err := readable.ReadPrimitiveUInt32LE() + if err != nil { + return fmt.Errorf("Failed to read NullData content length. %s", err.Error()) + } + + if readable.Remaining() < uint64(contentLength) { + return errors.New("NullData content length longer than data size") + } + + nd.SetStructureVersion(version) + } + + return nil +} + +// Copy returns a pointer to a copy of the NullData. Requires type assertion when used +func (nd NullData) Copy() RVType { + copied := NewNullData() + copied.structureVersion = nd.structureVersion + + return copied +} + +// Equals checks if the input is equal in value to the current instance +func (nd *NullData) Equals(o RVType) bool { + if _, ok := o.(*NullData); !ok { + return false + } + + return (*nd).structureVersion == (*o.(*NullData)).structureVersion +} + +// String returns a string representation of the struct +func (nd *NullData) String() string { + return nd.FormatToString(0) +} + +// FormatToString pretty-prints the struct data using the provided indentation level +func (nd *NullData) FormatToString(indentationLevel int) string { + indentationValues := strings.Repeat("\t", indentationLevel+1) + indentationEnd := strings.Repeat("\t", indentationLevel) + + var b strings.Builder + + b.WriteString("NullData{\n") + b.WriteString(fmt.Sprintf("%sstructureVersion: %d\n", indentationValues, nd.structureVersion)) + b.WriteString(fmt.Sprintf("%s}", indentationEnd)) + + return b.String() +} + +// NewNullData returns a new NullData Structure +func NewNullData() *NullData { + return &NullData{} +} diff --git a/types/pid.go b/types/pid.go new file mode 100644 index 00000000..046aec55 --- /dev/null +++ b/types/pid.go @@ -0,0 +1,116 @@ +package types + +import ( + "fmt" + "strings" +) + +// PID represents a unique number to identify a user +// +// The true size of this value depends on the client version. +// Legacy clients (WiiU/3DS) use a uint32, whereas modern clients (Nintendo Switch) use a uint64. +// Value is always stored as the higher uint64, the consuming API should assert accordingly +type PID struct { + pid uint64 // TODO - Replace this with PrimitiveU64? +} + +// WriteTo writes the bool to the given writable +func (p *PID) WriteTo(writable Writable) { + if writable.PIDSize() == 8 { + writable.WritePrimitiveUInt64LE(p.pid) + } else { + writable.WritePrimitiveUInt32LE(uint32(p.pid)) + } +} + +// ExtractFrom extracts the bool to the given readable +func (p *PID) ExtractFrom(readable Readable) error { + var pid uint64 + var err error + + if readable.PIDSize() == 8 { + pid, err = readable.ReadPrimitiveUInt64LE() + } else { + p, e := readable.ReadPrimitiveUInt32LE() + + pid = uint64(p) + err = e + } + + if err != nil { + return err + } + + p.pid = pid + + return nil +} + +// Copy returns a pointer to a copy of the PID. Requires type assertion when used +func (p PID) Copy() RVType { + return NewPID(p.pid) +} + +// Equals checks if the input is equal in value to the current instance +func (p *PID) Equals(o RVType) bool { + if _, ok := o.(*PID); !ok { + return false + } + + return p.pid == o.(*PID).pid +} + +// Value returns the numeric value of the PID as a uint64 regardless of client version +func (p *PID) Value() uint64 { + return p.pid +} + +// LegacyValue returns the numeric value of the PID as a uint32, for legacy clients +func (p *PID) LegacyValue() uint32 { + return uint32(p.pid) +} + +// String returns a string representation of the struct +func (p *PID) String() string { + return p.FormatToString(0) +} + +// FormatToString pretty-prints the struct data using the provided indentation level +func (p *PID) FormatToString(indentationLevel int) string { + indentationValues := strings.Repeat("\t", indentationLevel+1) + indentationEnd := strings.Repeat("\t", indentationLevel) + + var b strings.Builder + + b.WriteString("PID{\n") + + switch v := any(p.pid).(type) { + case uint32: + b.WriteString(fmt.Sprintf("%spid: %d (legacy)\n", indentationValues, v)) + case uint64: + b.WriteString(fmt.Sprintf("%spid: %d (modern)\n", indentationValues, v)) + } + + b.WriteString(fmt.Sprintf("%s}", indentationEnd)) + + return b.String() +} + +// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? +// NewPID returns a PID instance. The size of PID depends on the client version +func NewPID[T uint32 | uint64](pid T) *PID { + switch v := any(pid).(type) { + case uint32: + return &PID{pid: uint64(v)} + case uint64: + return &PID{pid: v} + } + + // * This will never happen because Go will + // * not compile any code where "pid" is not + // * a uint32/uint64, so it will ALWAYS get + // * caught by the above switch-case. This + // * return is only here because Go won't + // * compile without a default return + return nil +} diff --git a/types/primitive_bool.go b/types/primitive_bool.go new file mode 100644 index 00000000..9a0453df --- /dev/null +++ b/types/primitive_bool.go @@ -0,0 +1,44 @@ +package types + +// TODO - Should this have a "Value"-kind of method to get the original value? + +// PrimitiveBool is a type alias of bool with receiver methods to conform to RVType +type PrimitiveBool bool // TODO - Should we make this a struct instead of a type alias? + +// WriteTo writes the bool to the given writable +func (b *PrimitiveBool) WriteTo(writable Writable) { + writable.WritePrimitiveBool(bool(*b)) +} + +// ExtractFrom extracts the bool to the given readable +func (b *PrimitiveBool) ExtractFrom(readable Readable) error { + value, err := readable.ReadPrimitiveBool() + if err != nil { + return err + } + + *b = PrimitiveBool(value) + + return nil +} + +// Copy returns a pointer to a copy of the PrimitiveBool. Requires type assertion when used +func (b PrimitiveBool) Copy() RVType { + return &b +} + +// Equals checks if the input is equal in value to the current instance +func (b *PrimitiveBool) Equals(o RVType) bool { + if _, ok := o.(*PrimitiveBool); !ok { + return false + } + + return *b == *o.(*PrimitiveBool) +} + +// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? +// NewPrimitiveBool returns a new PrimitiveBool +func NewPrimitiveBool() *PrimitiveBool { + var b PrimitiveBool + return &b +} diff --git a/types/primitive_float32.go b/types/primitive_float32.go new file mode 100644 index 00000000..f11854d5 --- /dev/null +++ b/types/primitive_float32.go @@ -0,0 +1,44 @@ +package types + +// TODO - Should this have a "Value"-kind of method to get the original value? + +// PrimitiveF32 is a type alias of float32 with receiver methods to conform to RVType +type PrimitiveF32 float32 // TODO - Should we make this a struct instead of a type alias? + +// WriteTo writes the float32 to the given writable +func (f32 *PrimitiveF32) WriteTo(writable Writable) { + writable.WritePrimitiveFloat32LE(float32(*f32)) +} + +// ExtractFrom extracts the float32 to the given readable +func (f32 *PrimitiveF32) ExtractFrom(readable Readable) error { + value, err := readable.ReadPrimitiveFloat32LE() + if err != nil { + return err + } + + *f32 = PrimitiveF32(value) + + return nil +} + +// Copy returns a pointer to a copy of the float32. Requires type assertion when used +func (f32 PrimitiveF32) Copy() RVType { + return &f32 +} + +// Equals checks if the input is equal in value to the current instance +func (f32 *PrimitiveF32) Equals(o RVType) bool { + if _, ok := o.(*PrimitiveF32); !ok { + return false + } + + return *f32 == *o.(*PrimitiveF32) +} + +// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? +// NewPrimitiveF32 returns a new PrimitiveF32 +func NewPrimitiveF32() *PrimitiveF32 { + var f32 PrimitiveF32 + return &f32 +} diff --git a/types/primitive_float64.go b/types/primitive_float64.go new file mode 100644 index 00000000..656edf7f --- /dev/null +++ b/types/primitive_float64.go @@ -0,0 +1,44 @@ +package types + +// TODO - Should this have a "Value"-kind of method to get the original value? + +// PrimitiveF64 is a type alias of float64 with receiver methods to conform to RVType +type PrimitiveF64 float64 // TODO - Should we make this a struct instead of a type alias? + +// WriteTo writes the float64 to the given writable +func (f64 *PrimitiveF64) WriteTo(writable Writable) { + writable.WritePrimitiveFloat64LE(float64(*f64)) +} + +// ExtractFrom extracts the float64 to the given readable +func (f64 *PrimitiveF64) ExtractFrom(readable Readable) error { + value, err := readable.ReadPrimitiveFloat64LE() + if err != nil { + return err + } + + *f64 = PrimitiveF64(value) + + return nil +} + +// Copy returns a pointer to a copy of the float64. Requires type assertion when used +func (f64 PrimitiveF64) Copy() RVType { + return &f64 +} + +// Equals checks if the input is equal in value to the current instance +func (f64 *PrimitiveF64) Equals(o RVType) bool { + if _, ok := o.(*PrimitiveF64); !ok { + return false + } + + return *f64 == *o.(*PrimitiveF64) +} + +// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? +// NewPrimitiveF64 returns a new PrimitiveF64 +func NewPrimitiveF64() *PrimitiveF64 { + var f64 PrimitiveF64 + return &f64 +} diff --git a/types/primitive_s16.go b/types/primitive_s16.go new file mode 100644 index 00000000..384bc888 --- /dev/null +++ b/types/primitive_s16.go @@ -0,0 +1,44 @@ +package types + +// TODO - Should this have a "Value"-kind of method to get the original value? + +// PrimitiveS16 is a type alias of int16 with receiver methods to conform to RVType +type PrimitiveS16 int16 // TODO - Should we make this a struct instead of a type alias? + +// WriteTo writes the int16 to the given writable +func (s16 *PrimitiveS16) WriteTo(writable Writable) { + writable.WritePrimitiveInt16LE(int16(*s16)) +} + +// ExtractFrom extracts the int16 to the given readable +func (s16 *PrimitiveS16) ExtractFrom(readable Readable) error { + value, err := readable.ReadPrimitiveInt16LE() + if err != nil { + return err + } + + *s16 = PrimitiveS16(value) + + return nil +} + +// Copy returns a pointer to a copy of the int16. Requires type assertion when used +func (s16 PrimitiveS16) Copy() RVType { + return &s16 +} + +// Equals checks if the input is equal in value to the current instance +func (s16 *PrimitiveS16) Equals(o RVType) bool { + if _, ok := o.(*PrimitiveS16); !ok { + return false + } + + return *s16 == *o.(*PrimitiveS16) +} + +// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? +// NewPrimitiveS16 returns a new PrimitiveS16 +func NewPrimitiveS16() *PrimitiveS16 { + var s16 PrimitiveS16 + return &s16 +} diff --git a/types/primitive_s32.go b/types/primitive_s32.go new file mode 100644 index 00000000..b9158872 --- /dev/null +++ b/types/primitive_s32.go @@ -0,0 +1,44 @@ +package types + +// TODO - Should this have a "Value"-kind of method to get the original value? + +// PrimitiveS32 is a type alias of int32 with receiver methods to conform to RVType +type PrimitiveS32 int32 // TODO - Should we make this a struct instead of a type alias? + +// WriteTo writes the int32 to the given writable +func (s32 *PrimitiveS32) WriteTo(writable Writable) { + writable.WritePrimitiveInt32LE(int32(*s32)) +} + +// ExtractFrom extracts the int32 to the given readable +func (s32 *PrimitiveS32) ExtractFrom(readable Readable) error { + value, err := readable.ReadPrimitiveInt32LE() + if err != nil { + return err + } + + *s32 = PrimitiveS32(value) + + return nil +} + +// Copy returns a pointer to a copy of the int32. Requires type assertion when used +func (s32 PrimitiveS32) Copy() RVType { + return &s32 +} + +// Equals checks if the input is equal in value to the current instance +func (s32 *PrimitiveS32) Equals(o RVType) bool { + if _, ok := o.(*PrimitiveS32); !ok { + return false + } + + return *s32 == *o.(*PrimitiveS32) +} + +// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? +// NewPrimitiveS32 returns a new PrimitiveS32 +func NewPrimitiveS32() *PrimitiveS32 { + var s32 PrimitiveS32 + return &s32 +} diff --git a/types/primitive_s64.go b/types/primitive_s64.go new file mode 100644 index 00000000..4799d423 --- /dev/null +++ b/types/primitive_s64.go @@ -0,0 +1,44 @@ +package types + +// TODO - Should this have a "Value"-kind of method to get the original value? + +// PrimitiveS64 is a type alias of int64 with receiver methods to conform to RVType +type PrimitiveS64 int64 // TODO - Should we make this a struct instead of a type alias? + +// WriteTo writes the int64 to the given writable +func (s64 *PrimitiveS64) WriteTo(writable Writable) { + writable.WritePrimitiveInt64LE(int64(*s64)) +} + +// ExtractFrom extracts the int64 to the given readable +func (s64 *PrimitiveS64) ExtractFrom(readable Readable) error { + value, err := readable.ReadPrimitiveInt64LE() + if err != nil { + return err + } + + *s64 = PrimitiveS64(value) + + return nil +} + +// Copy returns a pointer to a copy of the int64. Requires type assertion when used +func (s64 PrimitiveS64) Copy() RVType { + return &s64 +} + +// Equals checks if the input is equal in value to the current instance +func (s64 *PrimitiveS64) Equals(o RVType) bool { + if _, ok := o.(*PrimitiveS64); !ok { + return false + } + + return *s64 == *o.(*PrimitiveS64) +} + +// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? +// NewPrimitiveS64 returns a new PrimitiveS64 +func NewPrimitiveS64() *PrimitiveS64 { + var s64 PrimitiveS64 + return &s64 +} diff --git a/types/primitive_s8.go b/types/primitive_s8.go new file mode 100644 index 00000000..1189d122 --- /dev/null +++ b/types/primitive_s8.go @@ -0,0 +1,44 @@ +package types + +// TODO - Should this have a "Value"-kind of method to get the original value? + +// PrimitiveS8 is a type alias of int8 with receiver methods to conform to RVType +type PrimitiveS8 int8 // TODO - Should we make this a struct instead of a type alias? + +// WriteTo writes the int8 to the given writable +func (s8 *PrimitiveS8) WriteTo(writable Writable) { + writable.WritePrimitiveInt8(int8(*s8)) +} + +// ExtractFrom extracts the int8 to the given readable +func (s8 *PrimitiveS8) ExtractFrom(readable Readable) error { + value, err := readable.ReadPrimitiveInt8() + if err != nil { + return err + } + + *s8 = PrimitiveS8(value) + + return nil +} + +// Copy returns a pointer to a copy of the int8. Requires type assertion when used +func (s8 PrimitiveS8) Copy() RVType { + return &s8 +} + +// Equals checks if the input is equal in value to the current instance +func (s8 *PrimitiveS8) Equals(o RVType) bool { + if _, ok := o.(*PrimitiveS8); !ok { + return false + } + + return *s8 == *o.(*PrimitiveS8) +} + +// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? +// NewPrimitiveS8 returns a new PrimitiveS8 +func NewPrimitiveS8() *PrimitiveS8 { + var s8 PrimitiveS8 + return &s8 +} diff --git a/types/primitive_u16.go b/types/primitive_u16.go new file mode 100644 index 00000000..cafd070e --- /dev/null +++ b/types/primitive_u16.go @@ -0,0 +1,44 @@ +package types + +// TODO - Should this have a "Value"-kind of method to get the original value? + +// PrimitiveU16 is a type alias of uint16 with receiver methods to conform to RVType +type PrimitiveU16 uint16 // TODO - Should we make this a struct instead of a type alias? + +// WriteTo writes the uint16 to the given writable +func (u16 *PrimitiveU16) WriteTo(writable Writable) { + writable.WritePrimitiveUInt16LE(uint16(*u16)) +} + +// ExtractFrom extracts the uint16 to the given readable +func (u16 *PrimitiveU16) ExtractFrom(readable Readable) error { + value, err := readable.ReadPrimitiveUInt16LE() + if err != nil { + return err + } + + *u16 = PrimitiveU16(value) + + return nil +} + +// Copy returns a pointer to a copy of the uint16. Requires type assertion when used +func (u16 PrimitiveU16) Copy() RVType { + return &u16 +} + +// Equals checks if the input is equal in value to the current instance +func (u16 *PrimitiveU16) Equals(o RVType) bool { + if _, ok := o.(*PrimitiveU16); !ok { + return false + } + + return *u16 == *o.(*PrimitiveU16) +} + +// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? +// NewPrimitiveU16 returns a new PrimitiveU16 +func NewPrimitiveU16() *PrimitiveU16 { + var u16 PrimitiveU16 + return &u16 +} diff --git a/types/primitive_u32.go b/types/primitive_u32.go new file mode 100644 index 00000000..fb151914 --- /dev/null +++ b/types/primitive_u32.go @@ -0,0 +1,44 @@ +package types + +// TODO - Should this have a "Value"-kind of method to get the original value? + +// PrimitiveU32 is a type alias of uint32 with receiver methods to conform to RVType +type PrimitiveU32 uint32 // TODO - Should we make this a struct instead of a type alias? + +// WriteTo writes the uint32 to the given writable +func (u32 *PrimitiveU32) WriteTo(writable Writable) { + writable.WritePrimitiveUInt32LE(uint32(*u32)) +} + +// ExtractFrom extracts the uint32 to the given readable +func (u32 *PrimitiveU32) ExtractFrom(readable Readable) error { + value, err := readable.ReadPrimitiveUInt32LE() + if err != nil { + return err + } + + *u32 = PrimitiveU32(value) + + return nil +} + +// Copy returns a pointer to a copy of the uint32. Requires type assertion when used +func (u32 PrimitiveU32) Copy() RVType { + return &u32 +} + +// Equals checks if the input is equal in value to the current instance +func (u32 *PrimitiveU32) Equals(o RVType) bool { + if _, ok := o.(*PrimitiveU32); !ok { + return false + } + + return *u32 == *o.(*PrimitiveU32) +} + +// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? +// NewPrimitiveU32 returns a new PrimitiveU32 +func NewPrimitiveU32() *PrimitiveU32 { + var u32 PrimitiveU32 + return &u32 +} diff --git a/types/primitive_u64.go b/types/primitive_u64.go new file mode 100644 index 00000000..8bd9d880 --- /dev/null +++ b/types/primitive_u64.go @@ -0,0 +1,44 @@ +package types + +// TODO - Should this have a "Value"-kind of method to get the original value? + +// PrimitiveU64 is a type alias of uint64 with receiver methods to conform to RVType +type PrimitiveU64 uint64 // TODO - Should we make this a struct instead of a type alias? + +// WriteTo writes the uint64 to the given writable +func (u64 *PrimitiveU64) WriteTo(writable Writable) { + writable.WritePrimitiveUInt64LE(uint64(*u64)) +} + +// ExtractFrom extracts the uint64 to the given readable +func (u64 *PrimitiveU64) ExtractFrom(readable Readable) error { + value, err := readable.ReadPrimitiveUInt64LE() + if err != nil { + return err + } + + *u64 = PrimitiveU64(value) + + return nil +} + +// Copy returns a pointer to a copy of the uint64. Requires type assertion when used +func (u64 PrimitiveU64) Copy() RVType { + return &u64 +} + +// Equals checks if the input is equal in value to the current instance +func (u64 *PrimitiveU64) Equals(o RVType) bool { + if _, ok := o.(*PrimitiveU64); !ok { + return false + } + + return *u64 == *o.(*PrimitiveU64) +} + +// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? +// NewPrimitiveU64 returns a new PrimitiveU64 +func NewPrimitiveU64() *PrimitiveU64 { + var u32 PrimitiveU64 + return &u32 +} diff --git a/types/primitive_u8.go b/types/primitive_u8.go new file mode 100644 index 00000000..52dbeee5 --- /dev/null +++ b/types/primitive_u8.go @@ -0,0 +1,44 @@ +package types + +// TODO - Should this have a "Value"-kind of method to get the original value? + +// PrimitiveU8 is a type alias of uint8 with receiver methods to conform to RVType +type PrimitiveU8 uint8 // TODO - Should we make this a struct instead of a type alias? + +// WriteTo writes the uint8 to the given writable +func (u8 *PrimitiveU8) WriteTo(writable Writable) { + writable.WritePrimitiveUInt8(uint8(*u8)) +} + +// ExtractFrom extracts the uint8 to the given readable +func (u8 *PrimitiveU8) ExtractFrom(readable Readable) error { + value, err := readable.ReadPrimitiveUInt8() + if err != nil { + return err + } + + *u8 = PrimitiveU8(value) + + return nil +} + +// Copy returns a pointer to a copy of the uint8. Requires type assertion when used +func (u8 PrimitiveU8) Copy() RVType { + return &u8 +} + +// Equals checks if the input is equal in value to the current instance +func (u8 *PrimitiveU8) Equals(o RVType) bool { + if _, ok := o.(*PrimitiveU8); !ok { + return false + } + + return *u8 == *o.(*PrimitiveU8) +} + +// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? +// NewPrimitiveU8 returns a new PrimitiveU8 +func NewPrimitiveU8() *PrimitiveU8 { + var u8 PrimitiveU8 + return &u8 +} diff --git a/types/qbuffer.go b/types/qbuffer.go new file mode 100644 index 00000000..5124069d --- /dev/null +++ b/types/qbuffer.go @@ -0,0 +1,61 @@ +package types + +// TODO - Should this have a "Value"-kind of method to get the original value? + +import ( + "bytes" + "fmt" +) + +// QBuffer is a type alias of []byte with receiver methods to conform to RVType +type QBuffer []byte // TODO - Should we make this a struct instead of a type alias? + +// WriteTo writes the []byte to the given writable +func (qb *QBuffer) WriteTo(writable Writable) { + data := *qb + length := len(data) + + writable.WritePrimitiveUInt16LE(uint16(length)) + + if length > 0 { + writable.Write([]byte(data)) + } +} + +// ExtractFrom extracts the QBuffer to the given readable +func (qb *QBuffer) ExtractFrom(readable Readable) error { + length, err := readable.ReadPrimitiveUInt16LE() + if err != nil { + return fmt.Errorf("Failed to read NEX qBuffer length. %s", err.Error()) + } + + data, err := readable.Read(uint64(length)) + if err != nil { + return fmt.Errorf("Failed to read NEX qBuffer data. %s", err.Error()) + } + + *qb = QBuffer(data) + + return nil +} + +// Copy returns a pointer to a copy of the qBuffer. Requires type assertion when used +func (qb QBuffer) Copy() RVType { + return &qb +} + +// Equals checks if the input is equal in value to the current instance +func (qb *QBuffer) Equals(o RVType) bool { + if _, ok := o.(*QBuffer); !ok { + return false + } + + return bytes.Equal([]byte(*qb), []byte(*o.(*Buffer))) +} + +// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? +// NewQBuffer returns a new QBuffer +func NewQBuffer() *QBuffer { + var qb QBuffer + return &qb +} diff --git a/types/quuid.go b/types/quuid.go new file mode 100644 index 00000000..41fe5218 --- /dev/null +++ b/types/quuid.go @@ -0,0 +1,181 @@ +package types + +import ( + "encoding/hex" + "errors" + "fmt" + "slices" + "strings" +) + +// QUUID represents a QRV qUUID type. This type encodes a UUID in little-endian byte order +type QUUID struct { + Data []byte +} + +// WriteTo writes the QUUID to the given writable +func (qu *QUUID) WriteTo(writable Writable) { + writable.Write(qu.Data) +} + +// ExtractFrom extracts the QUUID to the given readable +func (qu *QUUID) ExtractFrom(readable Readable) error { + if readable.Remaining() < uint64(16) { + return errors.New("Not enough data left to read qUUID") + } + + qu.Data, _ = readable.Read(16) + + return nil +} + +// Copy returns a new copied instance of qUUID +func (qu *QUUID) Copy() RVType { + copied := NewQUUID() + + copied.Data = make([]byte, len(qu.Data)) + + copy(copied.Data, qu.Data) + + return copied +} + +// Equals checks if the passed Structure contains the same data as the current instance +func (qu *QUUID) Equals(o RVType) bool { + if _, ok := o.(*QUUID); !ok { + return false + } + + return qu.GetStringValue() == (o.(*QUUID)).GetStringValue() +} + +// String returns a string representation of the struct +func (qu *QUUID) String() string { + return qu.FormatToString(0) +} + +// FormatToString pretty-prints the struct data using the provided indentation level +func (qu *QUUID) FormatToString(indentationLevel int) string { + indentationValues := strings.Repeat("\t", indentationLevel+1) + indentationEnd := strings.Repeat("\t", indentationLevel) + + var b strings.Builder + + b.WriteString("qUUID{\n") + b.WriteString(fmt.Sprintf("%sUUID: %s\n", indentationValues, qu.GetStringValue())) + b.WriteString(fmt.Sprintf("%s}", indentationEnd)) + + return b.String() +} + +// GetStringValue returns the UUID encoded in the qUUID +func (qu *QUUID) GetStringValue() string { + // * Create copy of the data since slices.Reverse modifies the slice in-line + data := make([]byte, len(qu.Data)) + copy(data, qu.Data) + + if len(data) != 16 { + // * Default dummy UUID as found in WATCH_DOGS + return "00000000-0000-0000-0000-000000000002" + } + + section1 := data[0:4] + section2 := data[4:6] + section3 := data[6:8] + section4 := data[8:10] + section5_1 := data[10:12] + section5_2 := data[12:14] + section5_3 := data[14:16] + + slices.Reverse(section1) + slices.Reverse(section2) + slices.Reverse(section3) + slices.Reverse(section4) + slices.Reverse(section5_1) + slices.Reverse(section5_2) + slices.Reverse(section5_3) + + var b strings.Builder + + b.WriteString(hex.EncodeToString(section1)) + b.WriteString("-") + b.WriteString(hex.EncodeToString(section2)) + b.WriteString("-") + b.WriteString(hex.EncodeToString(section3)) + b.WriteString("-") + b.WriteString(hex.EncodeToString(section4)) + b.WriteString("-") + b.WriteString(hex.EncodeToString(section5_1)) + b.WriteString(hex.EncodeToString(section5_2)) + b.WriteString(hex.EncodeToString(section5_3)) + + return b.String() +} + +// FromString converts a UUID string to a qUUID +func (qu *QUUID) FromString(uuid string) error { + + sections := strings.Split(uuid, "-") + if len(sections) != 5 { + return fmt.Errorf("Invalid UUID. Not enough sections. Expected 5, got %d", len(sections)) + } + + data := make([]byte, 0, 16) + + var appendSection = func(section string, expectedSize int) error { + sectionBytes, err := hex.DecodeString(section) + if err != nil { + return err + } + + if len(sectionBytes) != expectedSize { + return fmt.Errorf("Unexpected section size. Expected %d, got %d", expectedSize, len(sectionBytes)) + } + + data = append(data, sectionBytes...) + + return nil + } + + if err := appendSection(sections[0], 4); err != nil { + return fmt.Errorf("Failed to read UUID section 1. %s", err.Error()) + } + + if err := appendSection(sections[1], 2); err != nil { + return fmt.Errorf("Failed to read UUID section 2. %s", err.Error()) + } + + if err := appendSection(sections[2], 2); err != nil { + return fmt.Errorf("Failed to read UUID section 3. %s", err.Error()) + } + + if err := appendSection(sections[3], 2); err != nil { + return fmt.Errorf("Failed to read UUID section 4. %s", err.Error()) + } + + if err := appendSection(sections[4], 6); err != nil { + return fmt.Errorf("Failed to read UUID section 5. %s", err.Error()) + } + + slices.Reverse(data[0:4]) + slices.Reverse(data[4:6]) + slices.Reverse(data[6:8]) + slices.Reverse(data[8:10]) + slices.Reverse(data[10:12]) + slices.Reverse(data[12:14]) + slices.Reverse(data[14:16]) + + qu.Data = make([]byte, 0, 16) + + copy(qu.Data, data) + + return nil +} + +// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? +// NewQUUID returns a new qUUID +func NewQUUID() *QUUID { + return &QUUID{ + Data: make([]byte, 0, 16), + } +} diff --git a/types/readable.go b/types/readable.go new file mode 100644 index 00000000..c4dfb709 --- /dev/null +++ b/types/readable.go @@ -0,0 +1,22 @@ +package types + +// Readable represents a struct that types can read from +type Readable interface { + StringLengthSize() int + PIDSize() int + UseStructureHeader() bool + Remaining() uint64 + ReadRemaining() []byte + Read(length uint64) ([]byte, error) + ReadPrimitiveUInt8() (uint8, error) + ReadPrimitiveUInt16LE() (uint16, error) + ReadPrimitiveUInt32LE() (uint32, error) + ReadPrimitiveUInt64LE() (uint64, error) + ReadPrimitiveInt8() (int8, error) + ReadPrimitiveInt16LE() (int16, error) + ReadPrimitiveInt32LE() (int32, error) + ReadPrimitiveInt64LE() (int64, error) + ReadPrimitiveFloat32LE() (float32, error) + ReadPrimitiveFloat64LE() (float64, error) + ReadPrimitiveBool() (bool, error) +} diff --git a/types/result.go b/types/result.go new file mode 100644 index 00000000..ca9ce815 --- /dev/null +++ b/types/result.go @@ -0,0 +1,97 @@ +package types + +import ( + "fmt" + "strings" +) + +var errorMask = 1 << 31 + +// Result is sent in methods which query large objects +type Result struct { + Code uint32 // TODO - Replace this with PrimitiveU32? +} + +// WriteTo writes the Result to the given writable +func (r *Result) WriteTo(writable Writable) { + writable.WritePrimitiveUInt32LE(r.Code) +} + +// ExtractFrom extracts the Result to the given readable +func (r *Result) ExtractFrom(readable Readable) error { + code, err := readable.ReadPrimitiveUInt32LE() + if err != nil { + return fmt.Errorf("Failed to read Result code. %s", err.Error()) + } + + r.Code = code + + return nil +} + +// Copy returns a pointer to a copy of the Result. Requires type assertion when used +func (r Result) Copy() RVType { + return NewResult(r.Code) +} + +// Equals checks if the input is equal in value to the current instance +func (r *Result) Equals(o RVType) bool { + if _, ok := o.(*Result); !ok { + return false + } + + return r.Code == o.(*Result).Code +} + +// IsSuccess returns true if the Result is a success +func (r *Result) IsSuccess() bool { + return int(r.Code)&errorMask == 0 +} + +// IsError returns true if the Result is a error +func (r *Result) IsError() bool { + return int(r.Code)&errorMask != 0 +} + +// String returns a string representation of the struct +func (r *Result) String() string { + return r.FormatToString(0) +} + +// FormatToString pretty-prints the struct data using the provided indentation level +func (r *Result) FormatToString(indentationLevel int) string { + indentationValues := strings.Repeat("\t", indentationLevel+1) + indentationEnd := strings.Repeat("\t", indentationLevel) + + var b strings.Builder + + b.WriteString("Result{\n") + + if r.IsSuccess() { + b.WriteString(fmt.Sprintf("%scode: %d (success)\n", indentationValues, r.Code)) + } else { + b.WriteString(fmt.Sprintf("%scode: %d (error)\n", indentationValues, r.Code)) + } + + b.WriteString(fmt.Sprintf("%s}", indentationEnd)) + + return b.String() +} + +// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? +// NewResult returns a new Result +func NewResult(code uint32) *Result { + return &Result{code} +} + +// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? +// NewResultSuccess returns a new Result set as a success +func NewResultSuccess(code uint32) *Result { + return NewResult(uint32(int(code) & ^errorMask)) +} + +// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? +// NewResultError returns a new Result set as an error +func NewResultError(code uint32) *Result { + return NewResult(uint32(int(code) | errorMask)) +} diff --git a/types/result_range.go b/types/result_range.go new file mode 100644 index 00000000..fecef3b2 --- /dev/null +++ b/types/result_range.go @@ -0,0 +1,102 @@ +package types + +import ( + "errors" + "fmt" +) + +// ResultRange class which holds information about how to make queries +type ResultRange struct { + Structure + Offset uint32 // TODO - Replace this with PrimitiveU32? + Length uint32 // TODO - Replace this with PrimitiveU32? +} + +// WriteTo writes the ResultRange to the given writable +func (rr *ResultRange) WriteTo(writable Writable) { + contentWritable := writable.CopyNew() + + contentWritable.WritePrimitiveUInt32LE(rr.Offset) + contentWritable.WritePrimitiveUInt32LE(rr.Length) + + content := contentWritable.Bytes() + + if writable.UseStructureHeader() { + writable.WritePrimitiveUInt8(rr.StructureVersion()) + writable.WritePrimitiveUInt32LE(uint32(len(content))) + } + + writable.Write(content) +} + +// ExtractFrom extracts the ResultRange to the given readable +func (rr *ResultRange) ExtractFrom(readable Readable) error { + if readable.UseStructureHeader() { + version, err := readable.ReadPrimitiveUInt8() + if err != nil { + return fmt.Errorf("Failed to read ResultRange version. %s", err.Error()) + } + + contentLength, err := readable.ReadPrimitiveUInt32LE() + if err != nil { + return fmt.Errorf("Failed to read ResultRange content length. %s", err.Error()) + } + + if readable.Remaining() < uint64(contentLength) { + return errors.New("ResultRange content length longer than data size") + } + + rr.SetStructureVersion(version) + } + + offset, err := readable.ReadPrimitiveUInt32LE() + if err != nil { + return fmt.Errorf("Failed to read ResultRange offset. %s", err.Error()) + } + + length, err := readable.ReadPrimitiveUInt32LE() + if err != nil { + return fmt.Errorf("Failed to read ResultRange length. %s", err.Error()) + } + + rr.Offset = offset + rr.Length = length + + return nil +} + +// Copy returns a new copied instance of ResultRange +func (rr *ResultRange) Copy() RVType { + copied := NewResultRange() + + copied.structureVersion = rr.structureVersion + copied.Offset = rr.Offset + copied.Length = rr.Length + + return copied +} + +// Equals checks if the input is equal in value to the current instance +func (rr *ResultRange) Equals(o RVType) bool { + if _, ok := o.(*ResultRange); !ok { + return false + } + + other := o.(*ResultRange) + + if rr.structureVersion != other.structureVersion { + return false + } + + if rr.Offset != other.Offset { + return false + } + + return rr.Length == other.Length +} + +// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? +// NewResultRange returns a new ResultRange +func NewResultRange() *ResultRange { + return &ResultRange{} +} diff --git a/types/rv_connection_data.go b/types/rv_connection_data.go new file mode 100644 index 00000000..9ba35c50 --- /dev/null +++ b/types/rv_connection_data.go @@ -0,0 +1,143 @@ +package types + +import ( + "errors" + "fmt" +) + +// RVConnectionData is a class which holds data about a Rendez-Vous connection +type RVConnectionData struct { + Structure + StationURL *StationURL + SpecialProtocols *List[*PrimitiveU8] + StationURLSpecialProtocols *StationURL + Time *DateTime +} + +// WriteTo writes the RVConnectionData to the given writable +func (rvcd *RVConnectionData) WriteTo(writable Writable) { + contentWritable := writable.CopyNew() + + rvcd.StationURL.WriteTo(contentWritable) + rvcd.SpecialProtocols.WriteTo(contentWritable) + rvcd.StationURLSpecialProtocols.WriteTo(contentWritable) + + if rvcd.structureVersion >= 1 { + rvcd.Time.WriteTo(contentWritable) + } + + content := contentWritable.Bytes() + + if writable.UseStructureHeader() { + writable.WritePrimitiveUInt8(rvcd.StructureVersion()) + writable.WritePrimitiveUInt32LE(uint32(len(content))) + } + + writable.Write(content) +} + +// ExtractFrom extracts the RVConnectionData to the given readable +func (rvcd *RVConnectionData) ExtractFrom(readable Readable) error { + if readable.UseStructureHeader() { + version, err := readable.ReadPrimitiveUInt8() + if err != nil { + return fmt.Errorf("Failed to read RVConnectionData version. %s", err.Error()) + } + + contentLength, err := readable.ReadPrimitiveUInt32LE() + if err != nil { + return fmt.Errorf("Failed to read RVConnectionData content length. %s", err.Error()) + } + + if readable.Remaining() < uint64(contentLength) { + return errors.New("RVConnectionData content length longer than data size") + } + + rvcd.SetStructureVersion(version) + } + + var stationURL *StationURL + specialProtocols := NewList(NewPrimitiveU8()) + var stationURLSpecialProtocols *StationURL + var time *DateTime + + if err := stationURL.ExtractFrom(readable); err != nil { + return fmt.Errorf("Failed to read RVConnectionData StationURL. %s", err.Error()) + } + + if err := specialProtocols.ExtractFrom(readable); err != nil { + return fmt.Errorf("Failed to read SpecialProtocols StationURL. %s", err.Error()) + } + + if err := stationURLSpecialProtocols.ExtractFrom(readable); err != nil { + return fmt.Errorf("Failed to read StationURLSpecialProtocols StationURL. %s", err.Error()) + } + + if rvcd.structureVersion >= 1 { + if err := time.ExtractFrom(readable); err != nil { + return fmt.Errorf("Failed to read Time StationURL. %s", err.Error()) + } + } + + rvcd.StationURL = stationURL + rvcd.SpecialProtocols = specialProtocols + rvcd.StationURLSpecialProtocols = stationURLSpecialProtocols + rvcd.Time = time + + return nil +} + +// Copy returns a new copied instance of RVConnectionData +func (rvcd *RVConnectionData) Copy() RVType { + copied := NewRVConnectionData() + + copied.structureVersion = rvcd.structureVersion + copied.StationURL = rvcd.StationURL.Copy().(*StationURL) + copied.SpecialProtocols = rvcd.SpecialProtocols.Copy().(*List[*PrimitiveU8]) + copied.StationURLSpecialProtocols = rvcd.StationURLSpecialProtocols.Copy().(*StationURL) + + if rvcd.structureVersion >= 1 { + copied.Time = rvcd.Time.Copy().(*DateTime) + } + + return copied +} + +// Equals checks if the input is equal in value to the current instance +func (rvcd *RVConnectionData) Equals(o RVType) bool { + if _, ok := o.(*RVConnectionData); !ok { + return false + } + + other := o.(*RVConnectionData) + + if rvcd.structureVersion != other.structureVersion { + return false + } + + if !rvcd.StationURL.Equals(other.StationURL) { + return false + } + + if !rvcd.SpecialProtocols.Equals(other.SpecialProtocols) { + return false + } + + if !rvcd.StationURLSpecialProtocols.Equals(other.StationURLSpecialProtocols) { + return false + } + + if rvcd.structureVersion >= 1 { + if !rvcd.Time.Equals(other.Time) { + return false + } + } + + return true +} + +// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? +// NewRVConnectionData returns a new RVConnectionData +func NewRVConnectionData() *RVConnectionData { + return &RVConnectionData{} +} diff --git a/types/rv_type.go b/types/rv_type.go new file mode 100644 index 00000000..dea5be61 --- /dev/null +++ b/types/rv_type.go @@ -0,0 +1,10 @@ +// Package types provides types used in Quazal Rendez-Vous/NEX +package types + +// RVType represents a Quazal Rendez-Vous/NEX type. This includes primitives and custom types +type RVType interface { + WriteTo(writable Writable) + ExtractFrom(readable Readable) error + Copy() RVType + Equals(other RVType) bool +} diff --git a/types/station_url.go b/types/station_url.go new file mode 100644 index 00000000..348546ea --- /dev/null +++ b/types/station_url.go @@ -0,0 +1,169 @@ +package types + +import ( + "fmt" + "strings" +) + +// StationURL contains the data for a NEX station URL +type StationURL struct { + local bool // * Not part of the data structure. Used for easier lookups elsewhere + public bool // * Not part of the data structure. Used for easier lookups elsewhere + Scheme string + Fields map[string]string +} + +// WriteTo writes the StationURL to the given writable +func (s *StationURL) WriteTo(writable Writable) { + str := String(s.EncodeToString()) + + str.WriteTo(writable) +} + +// ExtractFrom extracts the StationURL to the given readable +func (s *StationURL) ExtractFrom(readable Readable) error { + var str String + + if err := str.ExtractFrom(readable); err != nil { + return fmt.Errorf("Failed to read StationURL. %s", err.Error()) + } + + s.FromString(string(str)) + + return nil +} + +// Copy returns a new copied instance of StationURL +func (s *StationURL) Copy() RVType { + return NewStationURL(s.EncodeToString()) +} + +// Equals checks if the input is equal in value to the current instance +func (s *StationURL) Equals(o RVType) bool { + if _, ok := o.(*StationURL); !ok { + return false + } + + other := o.(*StationURL) + + if s.local != other.local { + return false + } + + if s.public != other.public { + return false + } + + if s.Scheme != other.Scheme { + return false + } + + if len(s.Fields) != len(other.Fields) { + return false + } + + for key, value1 := range s.Fields { + value2, ok := other.Fields[key] + if !ok || value1 != value2 { + return false + } + } + + return true +} + +// SetLocal marks the StationURL as an local URL +func (s *StationURL) SetLocal() { + s.local = true + s.public = false +} + +// SetPublic marks the StationURL as an public URL +func (s *StationURL) SetPublic() { + s.local = false + s.public = true +} + +// IsLocal checks if the StationURL is a local URL +func (s *StationURL) IsLocal() bool { + return s.local +} + +// IsPublic checks if the StationURL is a public URL +func (s *StationURL) IsPublic() bool { + return s.public +} + +// FromString parses the StationURL data from a string +func (s *StationURL) FromString(str string) { + if str == "" { + return + } + + split := strings.Split(str, ":/") + + s.Scheme = split[0] + + // * Return if there are no fields + if split[1] == "" { + return + } + + fields := strings.Split(split[1], ";") + + for i := 0; i < len(fields); i++ { + field := strings.Split(fields[i], "=") + + key := field[0] + value := field[1] + + s.Fields[key] = value + } +} + +// EncodeToString encodes the StationURL into a string +func (s *StationURL) EncodeToString() string { + // * Don't return anything if no scheme is set + if s.Scheme == "" { + return "" + } + + fields := make([]string, 0) + + for key, value := range s.Fields { + fields = append(fields, fmt.Sprintf("%s=%s", key, value)) + } + + return s.Scheme + ":/" + strings.Join(fields, ";") +} + +// String returns a string representation of the struct +func (s *StationURL) String() string { + return s.FormatToString(0) +} + +// FormatToString pretty-prints the struct data using the provided indentation level +func (s *StationURL) FormatToString(indentationLevel int) string { + indentationValues := strings.Repeat("\t", indentationLevel+1) + indentationEnd := strings.Repeat("\t", indentationLevel) + + var b strings.Builder + + b.WriteString("StationURL{\n") + b.WriteString(fmt.Sprintf("%surl: %q\n", indentationValues, s.EncodeToString())) + b.WriteString(fmt.Sprintf("%s}", indentationEnd)) + + return b.String() +} + +// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? +// NewStationURL returns a new StationURL +func NewStationURL(str string) *StationURL { + stationURL := &StationURL{ + Fields: make(map[string]string), + } + + stationURL.FromString(str) + + return stationURL +} diff --git a/types/string.go b/types/string.go new file mode 100644 index 00000000..4be52276 --- /dev/null +++ b/types/string.go @@ -0,0 +1,82 @@ +package types + +// TODO - Should this have a "Value"-kind of method to get the original value? + +import ( + "errors" + "fmt" + "strings" +) + +// String is a type alias of string with receiver methods to conform to RVType +type String string // TODO - Should we make this a struct instead of a type alias? + +// WriteTo writes the String to the given writable +func (s *String) WriteTo(writable Writable) { + str := *s + "\x00" + strLength := len(str) + + if writable.StringLengthSize() == 4 { + writable.WritePrimitiveUInt32LE(uint32(strLength)) + } else { + writable.WritePrimitiveUInt16LE(uint16(strLength)) + } + + writable.Write([]byte(str)) +} + +// ExtractFrom extracts the String to the given readable +func (s *String) ExtractFrom(readable Readable) error { + var length uint64 + var err error + + if readable.StringLengthSize() == 4 { + l, e := readable.ReadPrimitiveUInt32LE() + length = uint64(l) + err = e + } else { + l, e := readable.ReadPrimitiveUInt16LE() + length = uint64(l) + err = e + } + + if err != nil { + return fmt.Errorf("Failed to read NEX string length. %s", err.Error()) + } + + if readable.Remaining() < length { + return errors.New("NEX string length longer than data size") + } + + stringData, err := readable.Read(length) + if err != nil { + return fmt.Errorf("Failed to read NEX string length. %s", err.Error()) + } + + str := strings.TrimRight(string(stringData), "\x00") + + *s = String(str) + + return nil +} + +// Copy returns a pointer to a copy of the String. Requires type assertion when used +func (s String) Copy() RVType { + return &s +} + +// Equals checks if the input is equal in value to the current instance +func (s *String) Equals(o RVType) bool { + if _, ok := o.(*String); !ok { + return false + } + + return *s == *o.(*String) +} + +// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? +// NewString returns a new String +func NewString() *String { + var s String + return &s +} diff --git a/types/structure.go b/types/structure.go new file mode 100644 index 00000000..5c4b3006 --- /dev/null +++ b/types/structure.go @@ -0,0 +1,39 @@ +package types + +// StructureInterface implements all Structure methods +type StructureInterface interface { + SetParentType(parentType StructureInterface) + ParentType() StructureInterface + SetStructureVersion(version uint8) + StructureVersion() uint8 + Copy() StructureInterface + Equals(other StructureInterface) bool + FormatToString(indentationLevel int) string +} + +// Structure represents a Quazal Rendez-Vous/NEX Structure (custom class) base struct +type Structure struct { + parentType StructureInterface + structureVersion uint8 + StructureInterface +} + +// SetParentType sets the Structures parent type +func (s *Structure) SetParentType(parentType StructureInterface) { + s.parentType = parentType +} + +// ParentType returns the Structures parent type. nil if the Structure does not inherit another Structure +func (s *Structure) ParentType() StructureInterface { + return s.parentType +} + +// SetStructureVersion sets the structures version. Only used in NEX 3.5+ +func (s *Structure) SetStructureVersion(version uint8) { + s.structureVersion = version +} + +// StructureVersion returns the structures version. Only used in NEX 3.5+ +func (s *Structure) StructureVersion() uint8 { + return s.structureVersion +} diff --git a/types/variant.go b/types/variant.go new file mode 100644 index 00000000..fc39571c --- /dev/null +++ b/types/variant.go @@ -0,0 +1,74 @@ +package types + +import ( + "fmt" +) + +// VariantTypes holds a mapping of RVTypes that are accessible in a Variant +var VariantTypes = make(map[uint8]RVType) + +// RegisterVariantType registers a RVType to be accessible in a Variant +func RegisterVariantType(id uint8, rvType RVType) { + VariantTypes[id] = rvType +} + +// Variant is a type which can old many other types +type Variant struct { + TypeID uint8 // TODO - Replace this with PrimitiveU8? + Type RVType +} + +// WriteTo writes the Variant to the given writable +func (v *Variant) WriteTo(writable Writable) { + writable.WritePrimitiveUInt8(v.TypeID) + v.Type.WriteTo(writable) +} + +// ExtractFrom extracts the Variant to the given readable +func (v *Variant) ExtractFrom(readable Readable) error { + typeID, err := readable.ReadPrimitiveUInt8() + if err != nil { + return fmt.Errorf("Failed to read Variant type ID. %s", err.Error()) + } + + v.TypeID = typeID + + if _, ok := VariantTypes[v.TypeID]; !ok { + return fmt.Errorf("Invalid Variant type ID %d", v.TypeID) + } + + v.Type = VariantTypes[v.TypeID].Copy() + + return v.Type.ExtractFrom(readable) +} + +// Copy returns a pointer to a copy of the Variant. Requires type assertion when used +func (v *Variant) Copy() RVType { + copied := NewVariant() + + copied.TypeID = v.TypeID + copied.Type = v.Type.Copy() + + return copied +} + +// Equals checks if the input is equal in value to the current instance +func (v *Variant) Equals(o RVType) bool { + if _, ok := o.(*Variant); !ok { + return false + } + + other := o.(*Variant) + + if v.TypeID != other.TypeID { + return false + } + + return v.Type.Equals(other.Type) +} + +// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? +// NewVariant returns a new Variant +func NewVariant() *Variant { + return &Variant{} +} diff --git a/types/writable.go b/types/writable.go new file mode 100644 index 00000000..0d3b584b --- /dev/null +++ b/types/writable.go @@ -0,0 +1,22 @@ +package types + +// Writable represents a struct that types can write to +type Writable interface { + StringLengthSize() int + PIDSize() int + UseStructureHeader() bool + CopyNew() Writable + Write(data []byte) + WritePrimitiveUInt8(value uint8) + WritePrimitiveUInt16LE(value uint16) + WritePrimitiveUInt32LE(value uint32) + WritePrimitiveUInt64LE(value uint64) + WritePrimitiveInt8(value int8) + WritePrimitiveInt16LE(value int16) + WritePrimitiveInt32LE(value int32) + WritePrimitiveInt64LE(value int64) + WritePrimitiveFloat32LE(value float32) + WritePrimitiveFloat64LE(value float64) + WritePrimitiveBool(value bool) + Bytes() []byte +} From 86d13513dabb2925e4a02c565da7611209c806af Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sat, 23 Dec 2023 14:39:25 -0500 Subject: [PATCH 091/178] types: remove Map constructor types --- types/class_version_container.go | 5 +++-- types/map.go | 23 +++++++++++------------ 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/types/class_version_container.go b/types/class_version_container.go index 56799228..ca7a1cbe 100644 --- a/types/class_version_container.go +++ b/types/class_version_container.go @@ -13,7 +13,9 @@ func (cvc *ClassVersionContainer) WriteTo(writable Writable) { // ExtractFrom extracts the ClassVersionContainer to the given readable func (cvc *ClassVersionContainer) ExtractFrom(readable Readable) error { - cvc.ClassVersions = NewMap(NewString(), NewPrimitiveU16()) + cvc.ClassVersions = NewMap[*String, *PrimitiveU16]() + cvc.ClassVersions.KeyType = NewString() + cvc.ClassVersions.ValueType = NewPrimitiveU16() return cvc.ClassVersions.ExtractFrom(readable) } @@ -35,7 +37,6 @@ func (cvc *ClassVersionContainer) Equals(o RVType) bool { return cvc.ClassVersions.Equals(o) } -// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? // NewClassVersionContainer returns a new ClassVersionContainer func NewClassVersionContainer() *ClassVersionContainer { return &ClassVersionContainer{} diff --git a/types/map.go b/types/map.go index c798b6d4..c2888769 100644 --- a/types/map.go +++ b/types/map.go @@ -8,8 +8,8 @@ type Map[K RVType, V RVType] struct { // * or store the key/value types indirectly keys []K values []V - keyType K - valueType V + KeyType K + ValueType V } // WriteTo writes the bool to the given writable @@ -33,12 +33,12 @@ func (m *Map[K, V]) ExtractFrom(readable Readable) error { values := make([]V, 0, length) for i := 0; i < int(length); i++ { - key := m.keyType.Copy() + key := m.KeyType.Copy() if err := key.ExtractFrom(readable); err != nil { return err } - value := m.valueType.Copy() + value := m.ValueType.Copy() if err := value.ExtractFrom(readable); err != nil { return err } @@ -55,9 +55,11 @@ func (m *Map[K, V]) ExtractFrom(readable Readable) error { // Copy returns a pointer to a copy of the Map[K, V]. Requires type assertion when used func (m Map[K, V]) Copy() RVType { - copied := NewMap(m.keyType, m.valueType) + copied := NewMap[K, V]() copied.keys = make([]K, len(m.keys)) copied.values = make([]V, len(m.values)) + copied.KeyType = m.KeyType.Copy().(K) + copied.ValueType = m.ValueType.Copy().(V) for i := 0; i < len(m.keys); i++ { copied.keys[i] = m.keys[i].Copy().(K) @@ -132,7 +134,7 @@ func (m *Map[K, V]) Get(key K) (V, bool) { return m.values[index], true } - return m.valueType.Copy().(V), false + return m.ValueType.Copy().(V), false } // Size returns the length of the Map @@ -140,13 +142,10 @@ func (m *Map[K, V]) Size() int { return len(m.keys) } -// TODO - Should this take in a default value, or take in nothing and have a "SetKeyType"/"SetValueType" kind of methods? // NewMap returns a new Map of the provided type -func NewMap[K RVType, V RVType](keyType K, valueType V) *Map[K, V] { +func NewMap[K RVType, V RVType]() *Map[K, V] { return &Map[K, V]{ - keys: make([]K, 0), - values: make([]V, 0), - keyType: keyType.Copy().(K), - valueType: valueType.Copy().(V), + keys: make([]K, 0), + values: make([]V, 0), } } From ba29ad816a422eba226211ec594c78744d8f304c Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sat, 23 Dec 2023 14:45:56 -0500 Subject: [PATCH 092/178] types: remove List constructor type --- test/auth.go | 3 ++- test/secure.go | 3 ++- types/list.go | 24 +++++++++++++----------- types/rv_connection_data.go | 4 +++- 4 files changed, 20 insertions(+), 14 deletions(-) diff --git a/test/auth.go b/test/auth.go index 9917a133..54c599a8 100644 --- a/test/auth.go +++ b/test/auth.go @@ -68,7 +68,8 @@ func login(packet nex.PRUDPPacketInterface) { strReturnMsg := types.String("Test Build") pConnectionData.StationURL = types.NewStationURL("prudps:/address=192.168.1.98;port=60001;CID=1;PID=2;sid=1;stream=10;type=2") - pConnectionData.SpecialProtocols = types.NewList(types.NewPrimitiveU8()) + pConnectionData.SpecialProtocols = types.NewList[*types.PrimitiveU8]() + pConnectionData.SpecialProtocols.Type = types.NewPrimitiveU8() pConnectionData.StationURLSpecialProtocols = types.NewStationURL("") pConnectionData.Time = types.NewDateTime(0).Now() diff --git a/test/secure.go b/test/secure.go index 0fea3443..50b067ee 100644 --- a/test/secure.go +++ b/test/secure.go @@ -90,7 +90,8 @@ func registerEx(packet nex.PRUDPPacketInterface) { parametersStream := nex.NewStreamIn(parameters, secureServer) - vecMyURLs := types.NewList(types.NewStationURL("")) + vecMyURLs := types.NewList[*types.StationURL]() + vecMyURLs.Type = types.NewStationURL("") if err := vecMyURLs.ExtractFrom(parametersStream); err != nil { panic(err) } diff --git a/types/list.go b/types/list.go index 77fc9e03..4d838b9d 100644 --- a/types/list.go +++ b/types/list.go @@ -4,8 +4,8 @@ import "errors" // List represents a Quazal Rendez-Vous/NEX List type type List[T RVType] struct { - real []T - rvType T + real []T + Type T } // WriteTo writes the bool to the given writable @@ -27,7 +27,7 @@ func (l *List[T]) ExtractFrom(readable Readable) error { slice := make([]T, 0, length) for i := 0; i < int(length); i++ { - value := l.rvType.Copy() + value := l.Type.Copy() if err := value.ExtractFrom(readable); err != nil { return err } @@ -42,8 +42,9 @@ func (l *List[T]) ExtractFrom(readable Readable) error { // Copy returns a pointer to a copy of the List[T]. Requires type assertion when used func (l List[T]) Copy() RVType { - copied := NewList(l.rvType) + copied := NewList[T]() copied.real = make([]T, len(l.real)) + copied.Type = l.Type.Copy().(T) for i, v := range l.real { copied.real[i] = v.Copy().(T) @@ -81,17 +82,18 @@ func (l *List[T]) Append(value T) { // Get returns an element at the given index. Returns an error if the index is OOB func (l *List[T]) Get(index int) (T, error) { if index < 0 || index >= len(l.real) { - return l.rvType.Copy().(T), errors.New("Index out of bounds") + return l.Type.Copy().(T), errors.New("Index out of bounds") } return l.real[index], nil } -// TODO - Should this take in a default value, or take in nothing and have a "SetType"-kind of method? +// SetFromData sets the List's internal slice to the input data +func (l *List[T]) SetFromData(data []T) { + l.real = data +} + // NewList returns a new List of the provided type -func NewList[T RVType](rvType T) *List[T] { - return &List[T]{ - real: make([]T, 0), - rvType: rvType.Copy().(T), - } +func NewList[T RVType]() *List[T] { + return &List[T]{real: make([]T, 0)} } diff --git a/types/rv_connection_data.go b/types/rv_connection_data.go index 9ba35c50..84964e7b 100644 --- a/types/rv_connection_data.go +++ b/types/rv_connection_data.go @@ -57,10 +57,12 @@ func (rvcd *RVConnectionData) ExtractFrom(readable Readable) error { } var stationURL *StationURL - specialProtocols := NewList(NewPrimitiveU8()) + specialProtocols := NewList[*PrimitiveU8]() var stationURLSpecialProtocols *StationURL var time *DateTime + specialProtocols.Type = NewPrimitiveU8() + if err := stationURL.ExtractFrom(readable); err != nil { return fmt.Errorf("Failed to read RVConnectionData StationURL. %s", err.Error()) } From 919cabda7dffa08b404a306fe60a668c6109f30e Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sat, 23 Dec 2023 15:01:55 -0500 Subject: [PATCH 093/178] types: added constructor values to primitive types --- init.go | 10 +++++----- kerberos.go | 4 ++-- prudp_server.go | 4 ++-- rmc_message.go | 6 +++--- test/auth.go | 4 ++-- test/generate_ticket.go | 6 +----- test/secure.go | 2 +- types/buffer.go | 6 +++--- types/class_version_container.go | 4 ++-- types/datetime.go | 1 - types/pid.go | 1 - types/primitive_bool.go | 6 +++--- types/primitive_float32.go | 6 +++--- types/primitive_float64.go | 6 +++--- types/primitive_s16.go | 6 +++--- types/primitive_s32.go | 6 +++--- types/primitive_s64.go | 6 +++--- types/primitive_s8.go | 6 +++--- types/primitive_u16.go | 6 +++--- types/primitive_u32.go | 6 +++--- types/primitive_u64.go | 7 ++++--- types/primitive_u8.go | 6 +++--- types/qbuffer.go | 6 +++--- types/quuid.go | 1 - types/result.go | 3 --- types/rv_connection_data.go | 2 +- types/string.go | 6 +++--- 27 files changed, 62 insertions(+), 71 deletions(-) diff --git a/init.go b/init.go index 10dd3496..e3027b1d 100644 --- a/init.go +++ b/init.go @@ -10,10 +10,10 @@ var logger = plogger.NewLogger() func init() { initErrorsData() - types.RegisterVariantType(1, types.NewPrimitiveS64()) - types.RegisterVariantType(2, types.NewPrimitiveF64()) - types.RegisterVariantType(3, types.NewPrimitiveBool()) - types.RegisterVariantType(4, types.NewString()) + types.RegisterVariantType(1, types.NewPrimitiveS64(0)) + types.RegisterVariantType(2, types.NewPrimitiveF64(0)) + types.RegisterVariantType(3, types.NewPrimitiveBool(false)) + types.RegisterVariantType(4, types.NewString("")) types.RegisterVariantType(5, types.NewDateTime(0)) - types.RegisterVariantType(6, types.NewPrimitiveU64()) + types.RegisterVariantType(6, types.NewPrimitiveU64(0)) } diff --git a/kerberos.go b/kerberos.go index 0bdedb78..14af53d6 100644 --- a/kerberos.go +++ b/kerberos.go @@ -141,12 +141,12 @@ func (ti *KerberosTicketInternalData) Encrypt(key []byte, stream *StreamOut) ([] // Decrypt decrypts the given data and populates the struct func (ti *KerberosTicketInternalData) Decrypt(stream *StreamIn, key []byte) error { if stream.Server.(*PRUDPServer).kerberosTicketVersion == 1 { - ticketKey := types.NewBuffer() + ticketKey := types.NewBuffer([]byte{}) if err := ticketKey.ExtractFrom(stream); err != nil { return fmt.Errorf("Failed to read Kerberos ticket internal data key. %s", err.Error()) } - data := types.NewBuffer() + data := types.NewBuffer([]byte{}) if err := ticketKey.ExtractFrom(stream); err != nil { return fmt.Errorf("Failed to read Kerberos ticket internal data. %s", err.Error()) } diff --git a/prudp_server.go b/prudp_server.go index 94e3d615..823f7843 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -515,12 +515,12 @@ func (s *PRUDPServer) handlePing(packet PRUDPPacketInterface) { func (s *PRUDPServer) readKerberosTicket(payload []byte) ([]byte, *types.PID, uint32, error) { stream := NewStreamIn(payload, s) - ticketData := types.NewBuffer() + ticketData := types.NewBuffer([]byte{}) if err := ticketData.ExtractFrom(stream); err != nil { return nil, nil, 0, err } - requestData := types.NewBuffer() + requestData := types.NewBuffer([]byte{}) if err := requestData.ExtractFrom(stream); err != nil { return nil, nil, 0, err } diff --git a/rmc_message.go b/rmc_message.go index 14c5ba5a..78674f62 100644 --- a/rmc_message.go +++ b/rmc_message.go @@ -154,7 +154,7 @@ func (rmc *RMCMessage) decodeVerbose(data []byte) error { return errors.New("RMC Message has unexpected size") } - rmc.ProtocolName = types.NewString() + rmc.ProtocolName = types.NewString("") if err := rmc.ProtocolName.ExtractFrom(stream); err != nil { return fmt.Errorf("Failed to read RMC Message protocol name. %s", err.Error()) } @@ -170,7 +170,7 @@ func (rmc *RMCMessage) decodeVerbose(data []byte) error { return fmt.Errorf("Failed to read RMC Message (request) call ID. %s", err.Error()) } - rmc.MethodName = types.NewString() + rmc.MethodName = types.NewString("") if err := rmc.MethodName.ExtractFrom(stream); err != nil { return fmt.Errorf("Failed to read RMC Message (request) method name. %s", err.Error()) } @@ -196,7 +196,7 @@ func (rmc *RMCMessage) decodeVerbose(data []byte) error { return fmt.Errorf("Failed to read RMC Message (response) call ID. %s", err.Error()) } - rmc.MethodName = types.NewString() + rmc.MethodName = types.NewString("") if err := rmc.MethodName.ExtractFrom(stream); err != nil { return fmt.Errorf("Failed to read RMC Message (response) method name. %s", err.Error()) } diff --git a/test/auth.go b/test/auth.go index 54c599a8..6c8cdff5 100644 --- a/test/auth.go +++ b/test/auth.go @@ -51,7 +51,7 @@ func login(packet nex.PRUDPPacketInterface) { parametersStream := nex.NewStreamIn(parameters, authServer) - strUserName := types.NewString() + strUserName := types.NewString("") if err := strUserName.ExtractFrom(parametersStream); err != nil { panic(err) } @@ -69,7 +69,7 @@ func login(packet nex.PRUDPPacketInterface) { pConnectionData.StationURL = types.NewStationURL("prudps:/address=192.168.1.98;port=60001;CID=1;PID=2;sid=1;stream=10;type=2") pConnectionData.SpecialProtocols = types.NewList[*types.PrimitiveU8]() - pConnectionData.SpecialProtocols.Type = types.NewPrimitiveU8() + pConnectionData.SpecialProtocols.Type = types.NewPrimitiveU8(0) pConnectionData.StationURLSpecialProtocols = types.NewStationURL("") pConnectionData.Time = types.NewDateTime(0).Now() diff --git a/test/generate_ticket.go b/test/generate_ticket.go index 78ae45fb..9dbeca5e 100644 --- a/test/generate_ticket.go +++ b/test/generate_ticket.go @@ -26,14 +26,10 @@ func generateTicket(userPID *types.PID, targetPID *types.PID) []byte { encryptedTicketInternalData, _ := ticketInternalData.Encrypt(targetKey, nex.NewStreamOut(authServer)) - encryptedTicketInternalDataBuffer := types.NewBuffer() - - *encryptedTicketInternalDataBuffer = encryptedTicketInternalData - ticket := nex.NewKerberosTicket() ticket.SessionKey = sessionKey ticket.TargetPID = targetPID - ticket.InternalData = encryptedTicketInternalDataBuffer + ticket.InternalData = types.NewBuffer(encryptedTicketInternalData) encryptedTicket, _ := ticket.Encrypt(userKey, nex.NewStreamOut(authServer)) diff --git a/test/secure.go b/test/secure.go index 50b067ee..4e226793 100644 --- a/test/secure.go +++ b/test/secure.go @@ -154,7 +154,7 @@ func updateAndGetAllInformation(packet nex.PRUDPPacketInterface) { }).WriteTo(responseStream) (&comment{ Unknown: 0, - Contents: types.NewString(), + Contents: types.NewString("Rewrite Test"), LastChanged: types.NewDateTime(0), }).WriteTo(responseStream) responseStream.WritePrimitiveUInt32LE(0) // * Stubbed empty list. responseStream.WriteListStructure(friendList) diff --git a/types/buffer.go b/types/buffer.go index db0cb65b..2a3bc090 100644 --- a/types/buffer.go +++ b/types/buffer.go @@ -53,9 +53,9 @@ func (b *Buffer) Equals(o RVType) bool { return bytes.Equal([]byte(*b), []byte(*o.(*Buffer))) } -// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? // NewBuffer returns a new Buffer -func NewBuffer() *Buffer { - var b Buffer +func NewBuffer(data []byte) *Buffer { + var b Buffer = data + return &b } diff --git a/types/class_version_container.go b/types/class_version_container.go index ca7a1cbe..d3d8dd76 100644 --- a/types/class_version_container.go +++ b/types/class_version_container.go @@ -14,8 +14,8 @@ func (cvc *ClassVersionContainer) WriteTo(writable Writable) { // ExtractFrom extracts the ClassVersionContainer to the given readable func (cvc *ClassVersionContainer) ExtractFrom(readable Readable) error { cvc.ClassVersions = NewMap[*String, *PrimitiveU16]() - cvc.ClassVersions.KeyType = NewString() - cvc.ClassVersions.ValueType = NewPrimitiveU16() + cvc.ClassVersions.KeyType = NewString("") + cvc.ClassVersions.ValueType = NewPrimitiveU16(0) return cvc.ClassVersions.ExtractFrom(readable) } diff --git a/types/datetime.go b/types/datetime.go index 68c988b0..7ded5d73 100644 --- a/types/datetime.go +++ b/types/datetime.go @@ -134,7 +134,6 @@ func (dt *DateTime) FormatToString(indentationLevel int) string { return b.String() } -// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? // NewDateTime returns a new DateTime instance func NewDateTime(value uint64) *DateTime { return &DateTime{value: value} diff --git a/types/pid.go b/types/pid.go index 046aec55..4ba04675 100644 --- a/types/pid.go +++ b/types/pid.go @@ -96,7 +96,6 @@ func (p *PID) FormatToString(indentationLevel int) string { return b.String() } -// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? // NewPID returns a PID instance. The size of PID depends on the client version func NewPID[T uint32 | uint64](pid T) *PID { switch v := any(pid).(type) { diff --git a/types/primitive_bool.go b/types/primitive_bool.go index 9a0453df..b4fea9c9 100644 --- a/types/primitive_bool.go +++ b/types/primitive_bool.go @@ -36,9 +36,9 @@ func (b *PrimitiveBool) Equals(o RVType) bool { return *b == *o.(*PrimitiveBool) } -// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? // NewPrimitiveBool returns a new PrimitiveBool -func NewPrimitiveBool() *PrimitiveBool { - var b PrimitiveBool +func NewPrimitiveBool(boolean bool) *PrimitiveBool { + b := PrimitiveBool(boolean) + return &b } diff --git a/types/primitive_float32.go b/types/primitive_float32.go index f11854d5..a071038a 100644 --- a/types/primitive_float32.go +++ b/types/primitive_float32.go @@ -36,9 +36,9 @@ func (f32 *PrimitiveF32) Equals(o RVType) bool { return *f32 == *o.(*PrimitiveF32) } -// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? // NewPrimitiveF32 returns a new PrimitiveF32 -func NewPrimitiveF32() *PrimitiveF32 { - var f32 PrimitiveF32 +func NewPrimitiveF32(float float32) *PrimitiveF32 { + f32 := PrimitiveF32(float) + return &f32 } diff --git a/types/primitive_float64.go b/types/primitive_float64.go index 656edf7f..663f0313 100644 --- a/types/primitive_float64.go +++ b/types/primitive_float64.go @@ -36,9 +36,9 @@ func (f64 *PrimitiveF64) Equals(o RVType) bool { return *f64 == *o.(*PrimitiveF64) } -// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? // NewPrimitiveF64 returns a new PrimitiveF64 -func NewPrimitiveF64() *PrimitiveF64 { - var f64 PrimitiveF64 +func NewPrimitiveF64(float float64) *PrimitiveF64 { + f64 := PrimitiveF64(float) + return &f64 } diff --git a/types/primitive_s16.go b/types/primitive_s16.go index 384bc888..a1a1ad8d 100644 --- a/types/primitive_s16.go +++ b/types/primitive_s16.go @@ -36,9 +36,9 @@ func (s16 *PrimitiveS16) Equals(o RVType) bool { return *s16 == *o.(*PrimitiveS16) } -// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? // NewPrimitiveS16 returns a new PrimitiveS16 -func NewPrimitiveS16() *PrimitiveS16 { - var s16 PrimitiveS16 +func NewPrimitiveS16(i16 int16) *PrimitiveS16 { + s16 := PrimitiveS16(i16) + return &s16 } diff --git a/types/primitive_s32.go b/types/primitive_s32.go index b9158872..16115dc8 100644 --- a/types/primitive_s32.go +++ b/types/primitive_s32.go @@ -36,9 +36,9 @@ func (s32 *PrimitiveS32) Equals(o RVType) bool { return *s32 == *o.(*PrimitiveS32) } -// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? // NewPrimitiveS32 returns a new PrimitiveS32 -func NewPrimitiveS32() *PrimitiveS32 { - var s32 PrimitiveS32 +func NewPrimitiveS32(i32 int32) *PrimitiveS32 { + s32 := PrimitiveS32(i32) + return &s32 } diff --git a/types/primitive_s64.go b/types/primitive_s64.go index 4799d423..40488af5 100644 --- a/types/primitive_s64.go +++ b/types/primitive_s64.go @@ -36,9 +36,9 @@ func (s64 *PrimitiveS64) Equals(o RVType) bool { return *s64 == *o.(*PrimitiveS64) } -// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? // NewPrimitiveS64 returns a new PrimitiveS64 -func NewPrimitiveS64() *PrimitiveS64 { - var s64 PrimitiveS64 +func NewPrimitiveS64(i64 int64) *PrimitiveS64 { + s64 := PrimitiveS64(i64) + return &s64 } diff --git a/types/primitive_s8.go b/types/primitive_s8.go index 1189d122..70520359 100644 --- a/types/primitive_s8.go +++ b/types/primitive_s8.go @@ -36,9 +36,9 @@ func (s8 *PrimitiveS8) Equals(o RVType) bool { return *s8 == *o.(*PrimitiveS8) } -// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? // NewPrimitiveS8 returns a new PrimitiveS8 -func NewPrimitiveS8() *PrimitiveS8 { - var s8 PrimitiveS8 +func NewPrimitiveS8(i8 int8) *PrimitiveS8 { + s8 := PrimitiveS8(i8) + return &s8 } diff --git a/types/primitive_u16.go b/types/primitive_u16.go index cafd070e..5ab151d2 100644 --- a/types/primitive_u16.go +++ b/types/primitive_u16.go @@ -36,9 +36,9 @@ func (u16 *PrimitiveU16) Equals(o RVType) bool { return *u16 == *o.(*PrimitiveU16) } -// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? // NewPrimitiveU16 returns a new PrimitiveU16 -func NewPrimitiveU16() *PrimitiveU16 { - var u16 PrimitiveU16 +func NewPrimitiveU16(ui16 uint16) *PrimitiveU16 { + u16 := PrimitiveU16(ui16) + return &u16 } diff --git a/types/primitive_u32.go b/types/primitive_u32.go index fb151914..9180be91 100644 --- a/types/primitive_u32.go +++ b/types/primitive_u32.go @@ -36,9 +36,9 @@ func (u32 *PrimitiveU32) Equals(o RVType) bool { return *u32 == *o.(*PrimitiveU32) } -// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? // NewPrimitiveU32 returns a new PrimitiveU32 -func NewPrimitiveU32() *PrimitiveU32 { - var u32 PrimitiveU32 +func NewPrimitiveU32(ui32 uint32) *PrimitiveU32 { + u32 := PrimitiveU32(ui32) + return &u32 } diff --git a/types/primitive_u64.go b/types/primitive_u64.go index 8bd9d880..a044fa1d 100644 --- a/types/primitive_u64.go +++ b/types/primitive_u64.go @@ -38,7 +38,8 @@ func (u64 *PrimitiveU64) Equals(o RVType) bool { // TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? // NewPrimitiveU64 returns a new PrimitiveU64 -func NewPrimitiveU64() *PrimitiveU64 { - var u32 PrimitiveU64 - return &u32 +func NewPrimitiveU64(ui64 uint64) *PrimitiveU64 { + u64 := PrimitiveU64(ui64) + + return &u64 } diff --git a/types/primitive_u8.go b/types/primitive_u8.go index 52dbeee5..4bd94e34 100644 --- a/types/primitive_u8.go +++ b/types/primitive_u8.go @@ -36,9 +36,9 @@ func (u8 *PrimitiveU8) Equals(o RVType) bool { return *u8 == *o.(*PrimitiveU8) } -// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? // NewPrimitiveU8 returns a new PrimitiveU8 -func NewPrimitiveU8() *PrimitiveU8 { - var u8 PrimitiveU8 +func NewPrimitiveU8(ui8 uint8) *PrimitiveU8 { + u8 := PrimitiveU8(ui8) + return &u8 } diff --git a/types/qbuffer.go b/types/qbuffer.go index 5124069d..72821c9b 100644 --- a/types/qbuffer.go +++ b/types/qbuffer.go @@ -53,9 +53,9 @@ func (qb *QBuffer) Equals(o RVType) bool { return bytes.Equal([]byte(*qb), []byte(*o.(*Buffer))) } -// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? // NewQBuffer returns a new QBuffer -func NewQBuffer() *QBuffer { - var qb QBuffer +func NewQBuffer(data []byte) *QBuffer { + var qb QBuffer = data + return &qb } diff --git a/types/quuid.go b/types/quuid.go index 41fe5218..8b513644 100644 --- a/types/quuid.go +++ b/types/quuid.go @@ -172,7 +172,6 @@ func (qu *QUUID) FromString(uuid string) error { return nil } -// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? // NewQUUID returns a new qUUID func NewQUUID() *QUUID { return &QUUID{ diff --git a/types/result.go b/types/result.go index ca9ce815..c10f4f91 100644 --- a/types/result.go +++ b/types/result.go @@ -78,19 +78,16 @@ func (r *Result) FormatToString(indentationLevel int) string { return b.String() } -// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? // NewResult returns a new Result func NewResult(code uint32) *Result { return &Result{code} } -// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? // NewResultSuccess returns a new Result set as a success func NewResultSuccess(code uint32) *Result { return NewResult(uint32(int(code) & ^errorMask)) } -// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? // NewResultError returns a new Result set as an error func NewResultError(code uint32) *Result { return NewResult(uint32(int(code) | errorMask)) diff --git a/types/rv_connection_data.go b/types/rv_connection_data.go index 84964e7b..4b8c405a 100644 --- a/types/rv_connection_data.go +++ b/types/rv_connection_data.go @@ -61,7 +61,7 @@ func (rvcd *RVConnectionData) ExtractFrom(readable Readable) error { var stationURLSpecialProtocols *StationURL var time *DateTime - specialProtocols.Type = NewPrimitiveU8() + specialProtocols.Type = NewPrimitiveU8(0) if err := stationURL.ExtractFrom(readable); err != nil { return fmt.Errorf("Failed to read RVConnectionData StationURL. %s", err.Error()) diff --git a/types/string.go b/types/string.go index 4be52276..fcf82e49 100644 --- a/types/string.go +++ b/types/string.go @@ -74,9 +74,9 @@ func (s *String) Equals(o RVType) bool { return *s == *o.(*String) } -// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? // NewString returns a new String -func NewString() *String { - var s String +func NewString(str string) *String { + s := String(str) + return &s } From 93965aa836df8999909131d7595cb1a1943ad96d Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sat, 23 Dec 2023 15:02:41 -0500 Subject: [PATCH 094/178] types: removed TODO comment for NewPrimitiveU64 --- types/primitive_u64.go | 1 - 1 file changed, 1 deletion(-) diff --git a/types/primitive_u64.go b/types/primitive_u64.go index a044fa1d..c7390e86 100644 --- a/types/primitive_u64.go +++ b/types/primitive_u64.go @@ -36,7 +36,6 @@ func (u64 *PrimitiveU64) Equals(o RVType) bool { return *u64 == *o.(*PrimitiveU64) } -// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? // NewPrimitiveU64 returns a new PrimitiveU64 func NewPrimitiveU64(ui64 uint64) *PrimitiveU64 { u64 := PrimitiveU64(ui64) From c6edc9108c1b6353de8d73e579a9e50b02c5fffc Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sat, 23 Dec 2023 15:04:41 -0500 Subject: [PATCH 095/178] types: rename NullData to Empty --- test/secure.go | 4 +-- types/empty.go | 84 ++++++++++++++++++++++++++++++++++++++++++++++ types/null_data.go | 84 ---------------------------------------------- 3 files changed, 86 insertions(+), 86 deletions(-) create mode 100644 types/empty.go delete mode 100644 types/null_data.go diff --git a/test/secure.go b/test/secure.go index 4e226793..e9c7ca99 100644 --- a/test/secure.go +++ b/test/secure.go @@ -15,7 +15,7 @@ var secureServer *nex.PRUDPServer type principalPreference struct { types.Structure - *types.NullData + *types.Empty ShowOnlinePresence bool ShowCurrentTitle bool BlockFriendRequests bool @@ -29,7 +29,7 @@ func (pp *principalPreference) WriteTo(stream *nex.StreamOut) { type comment struct { types.Structure - *types.NullData + *types.Empty Unknown uint8 Contents *types.String LastChanged *types.DateTime diff --git a/types/empty.go b/types/empty.go new file mode 100644 index 00000000..985e664d --- /dev/null +++ b/types/empty.go @@ -0,0 +1,84 @@ +package types + +import ( + "errors" + "fmt" + "strings" +) + +// Empty is a Structure with no fields +type Empty struct { + Structure +} + +// WriteTo writes the Empty to the given writable +func (e *Empty) WriteTo(writable Writable) { + if writable.UseStructureHeader() { + writable.WritePrimitiveUInt8(e.StructureVersion()) + writable.WritePrimitiveUInt32LE(0) + } +} + +// ExtractFrom extracts the Empty to the given readable +func (e *Empty) ExtractFrom(readable Readable) error { + if readable.UseStructureHeader() { + version, err := readable.ReadPrimitiveUInt8() + if err != nil { + return fmt.Errorf("Failed to read Empty version. %s", err.Error()) + } + + contentLength, err := readable.ReadPrimitiveUInt32LE() + if err != nil { + return fmt.Errorf("Failed to read Empty content length. %s", err.Error()) + } + + if readable.Remaining() < uint64(contentLength) { + return errors.New("Empty content length longer than data size") + } + + e.SetStructureVersion(version) + } + + return nil +} + +// Copy returns a pointer to a copy of the Empty. Requires type assertion when used +func (e Empty) Copy() RVType { + copied := NewEmpty() + copied.structureVersion = e.structureVersion + + return copied +} + +// Equals checks if the input is equal in value to the current instance +func (e *Empty) Equals(o RVType) bool { + if _, ok := o.(*Empty); !ok { + return false + } + + return (*e).structureVersion == (*o.(*Empty)).structureVersion +} + +// String returns a string representation of the struct +func (e *Empty) String() string { + return e.FormatToString(0) +} + +// FormatToString pretty-prints the struct data using the provided indentation level +func (e *Empty) FormatToString(indentationLevel int) string { + indentationValues := strings.Repeat("\t", indentationLevel+1) + indentationEnd := strings.Repeat("\t", indentationLevel) + + var b strings.Builder + + b.WriteString("Empty{\n") + b.WriteString(fmt.Sprintf("%sstructureVersion: %d\n", indentationValues, e.structureVersion)) + b.WriteString(fmt.Sprintf("%s}", indentationEnd)) + + return b.String() +} + +// NewEmpty returns a new Empty Structure +func NewEmpty() *Empty { + return &Empty{} +} diff --git a/types/null_data.go b/types/null_data.go deleted file mode 100644 index 9d9c0428..00000000 --- a/types/null_data.go +++ /dev/null @@ -1,84 +0,0 @@ -package types - -import ( - "errors" - "fmt" - "strings" -) - -// NullData is a Structure with no fields -type NullData struct { - Structure -} - -// WriteTo writes the NullData to the given writable -func (nd *NullData) WriteTo(writable Writable) { - if writable.UseStructureHeader() { - writable.WritePrimitiveUInt8(nd.StructureVersion()) - writable.WritePrimitiveUInt32LE(0) - } -} - -// ExtractFrom extracts the NullData to the given readable -func (nd *NullData) ExtractFrom(readable Readable) error { - if readable.UseStructureHeader() { - version, err := readable.ReadPrimitiveUInt8() - if err != nil { - return fmt.Errorf("Failed to read NullData version. %s", err.Error()) - } - - contentLength, err := readable.ReadPrimitiveUInt32LE() - if err != nil { - return fmt.Errorf("Failed to read NullData content length. %s", err.Error()) - } - - if readable.Remaining() < uint64(contentLength) { - return errors.New("NullData content length longer than data size") - } - - nd.SetStructureVersion(version) - } - - return nil -} - -// Copy returns a pointer to a copy of the NullData. Requires type assertion when used -func (nd NullData) Copy() RVType { - copied := NewNullData() - copied.structureVersion = nd.structureVersion - - return copied -} - -// Equals checks if the input is equal in value to the current instance -func (nd *NullData) Equals(o RVType) bool { - if _, ok := o.(*NullData); !ok { - return false - } - - return (*nd).structureVersion == (*o.(*NullData)).structureVersion -} - -// String returns a string representation of the struct -func (nd *NullData) String() string { - return nd.FormatToString(0) -} - -// FormatToString pretty-prints the struct data using the provided indentation level -func (nd *NullData) FormatToString(indentationLevel int) string { - indentationValues := strings.Repeat("\t", indentationLevel+1) - indentationEnd := strings.Repeat("\t", indentationLevel) - - var b strings.Builder - - b.WriteString("NullData{\n") - b.WriteString(fmt.Sprintf("%sstructureVersion: %d\n", indentationValues, nd.structureVersion)) - b.WriteString(fmt.Sprintf("%s}", indentationEnd)) - - return b.String() -} - -// NewNullData returns a new NullData Structure -func NewNullData() *NullData { - return &NullData{} -} From 793664257ad8134ccf8b2b89bdbd443d4aa70e9d Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sat, 23 Dec 2023 15:21:22 -0500 Subject: [PATCH 096/178] types: remove unused generic from PID type --- hpp_server.go | 2 +- kerberos.go | 2 +- prudp_client.go | 2 +- prudp_server.go | 4 ++-- test/auth.go | 8 ++++---- types/pid.go | 28 ++++------------------------ 6 files changed, 13 insertions(+), 33 deletions(-) diff --git a/hpp_server.go b/hpp_server.go index 0b0c9fbd..923c2074 100644 --- a/hpp_server.go +++ b/hpp_server.go @@ -82,7 +82,7 @@ func (s *HPPServer) handleRequest(w http.ResponseWriter, req *http.Request) { } client := NewHPPClient(tcpAddr, s) - client.SetPID(types.NewPID(uint32(pid))) + client.SetPID(types.NewPID(uint64(pid))) hppPacket, err := NewHPPPacket(client, rmcRequestBytes) if err != nil { diff --git a/kerberos.go b/kerberos.go index 14af53d6..65326287 100644 --- a/kerberos.go +++ b/kerberos.go @@ -171,7 +171,7 @@ func (ti *KerberosTicketInternalData) Decrypt(stream *StreamIn, key []byte) erro return fmt.Errorf("Failed to read Kerberos ticket internal data timestamp %s", err.Error()) } - userPID := types.NewPID[uint64](0) + userPID := types.NewPID(0) if err := userPID.ExtractFrom(stream); err != nil { return fmt.Errorf("Failed to read Kerberos ticket internal data user PID %s", err.Error()) } diff --git a/prudp_client.go b/prudp_client.go index abb538e4..8e0d9eda 100644 --- a/prudp_client.go +++ b/prudp_client.go @@ -224,7 +224,7 @@ func NewPRUDPClient(server *PRUDPServer, address net.Addr, webSocketConnection * address: address, webSocketConnection: webSocketConnection, outgoingPingSequenceIDCounter: NewCounter[uint16](0), - pid: types.NewPID[uint32](0), + pid: types.NewPID(0), unreliableBaseKey: make([]byte, 0x20), } } diff --git a/prudp_server.go b/prudp_server.go index 823f7843..eff50c39 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -525,7 +525,7 @@ func (s *PRUDPServer) readKerberosTicket(payload []byte) ([]byte, *types.PID, ui return nil, nil, 0, err } - serverKey := DeriveKerberosKey(types.NewPID[uint64](2), s.kerberosPassword) + serverKey := DeriveKerberosKey(types.NewPID(2), s.kerberosPassword) ticket := NewKerberosTicketInternalData() if err := ticket.Decrypt(NewStreamIn([]byte(*ticketData), s), serverKey); err != nil { @@ -550,7 +550,7 @@ func (s *PRUDPServer) readKerberosTicket(payload []byte) ([]byte, *types.PID, ui checkDataStream := NewStreamIn(decryptedRequestData, s) - userPID := types.NewPID[uint64](0) + userPID := types.NewPID(0) if err := userPID.ExtractFrom(checkDataStream); err != nil { return nil, nil, 0, err } diff --git a/test/auth.go b/test/auth.go index 6c8cdff5..fcd1b202 100644 --- a/test/auth.go +++ b/test/auth.go @@ -62,8 +62,8 @@ func login(packet nex.PRUDPPacketInterface) { } retval := types.NewResultSuccess(0x00010001) - pidPrincipal := types.NewPID(uint32(converted)) - pbufResponse := types.Buffer(generateTicket(pidPrincipal, types.NewPID[uint32](2))) + pidPrincipal := types.NewPID(uint64(converted)) + pbufResponse := types.Buffer(generateTicket(pidPrincipal, types.NewPID(2))) pConnectionData := types.NewRVConnectionData() strReturnMsg := types.String("Test Build") @@ -113,12 +113,12 @@ func requestTicket(packet nex.PRUDPPacketInterface) { parametersStream := nex.NewStreamIn(parameters, authServer) - idSource := types.NewPID[uint64](0) + idSource := types.NewPID(0) if err := idSource.ExtractFrom(parametersStream); err != nil { panic(err) } - idTarget := types.NewPID[uint64](0) + idTarget := types.NewPID(0) if err := idTarget.ExtractFrom(parametersStream); err != nil { panic(err) } diff --git a/types/pid.go b/types/pid.go index 4ba04675..b4171bca 100644 --- a/types/pid.go +++ b/types/pid.go @@ -83,33 +83,13 @@ func (p *PID) FormatToString(indentationLevel int) string { var b strings.Builder b.WriteString("PID{\n") - - switch v := any(p.pid).(type) { - case uint32: - b.WriteString(fmt.Sprintf("%spid: %d (legacy)\n", indentationValues, v)) - case uint64: - b.WriteString(fmt.Sprintf("%spid: %d (modern)\n", indentationValues, v)) - } - + b.WriteString(fmt.Sprintf("%spid: %d\n", indentationValues, p.pid)) b.WriteString(fmt.Sprintf("%s}", indentationEnd)) return b.String() } -// NewPID returns a PID instance. The size of PID depends on the client version -func NewPID[T uint32 | uint64](pid T) *PID { - switch v := any(pid).(type) { - case uint32: - return &PID{pid: uint64(v)} - case uint64: - return &PID{pid: v} - } - - // * This will never happen because Go will - // * not compile any code where "pid" is not - // * a uint32/uint64, so it will ALWAYS get - // * caught by the above switch-case. This - // * return is only here because Go won't - // * compile without a default return - return nil +// NewPID returns a PID instance. The real size of PID depends on the client version +func NewPID(pid uint64) *PID { + return &PID{pid: pid} } From 9a1db6b090c8f397b7696330f6250cc9c7abb0fa Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sat, 23 Dec 2023 15:50:33 -0500 Subject: [PATCH 097/178] types: fix types Copy methods --- types/any_data_holder.go | 10 ++++++++-- types/buffer.go | 6 ++++-- types/class_version_container.go | 2 +- types/datetime.go | 2 +- types/empty.go | 2 +- types/list.go | 2 +- types/map.go | 2 +- types/pid.go | 2 +- types/primitive_bool.go | 6 ++++-- types/primitive_float32.go | 6 ++++-- types/primitive_float64.go | 6 ++++-- types/primitive_s16.go | 6 ++++-- types/primitive_s32.go | 6 ++++-- types/primitive_s64.go | 6 ++++-- types/primitive_s8.go | 6 ++++-- types/primitive_u16.go | 6 ++++-- types/primitive_u32.go | 6 ++++-- types/primitive_u64.go | 6 ++++-- types/primitive_u8.go | 6 ++++-- types/qbuffer.go | 6 ++++-- types/result.go | 2 +- types/string.go | 6 ++++-- 22 files changed, 71 insertions(+), 37 deletions(-) diff --git a/types/any_data_holder.go b/types/any_data_holder.go index 9964d523..222adac6 100644 --- a/types/any_data_holder.go +++ b/types/any_data_holder.go @@ -74,7 +74,7 @@ func (adh *AnyDataHolder) ExtractFrom(readable Readable) error { } // Copy returns a new copied instance of DataHolder -func (adh *AnyDataHolder) Copy() *AnyDataHolder { +func (adh *AnyDataHolder) Copy() RVType { copied := NewAnyDataHolder() copied.TypeName = adh.TypeName @@ -86,7 +86,13 @@ func (adh *AnyDataHolder) Copy() *AnyDataHolder { } // Equals checks if the passed Structure contains the same data as the current instance -func (adh *AnyDataHolder) Equals(other *AnyDataHolder) bool { +func (adh *AnyDataHolder) Equals(o RVType) bool { + if _, ok := o.(*AnyDataHolder); !ok { + return false + } + + other := o.(*AnyDataHolder) + if adh.TypeName != other.TypeName { return false } diff --git a/types/buffer.go b/types/buffer.go index 2a3bc090..632d531d 100644 --- a/types/buffer.go +++ b/types/buffer.go @@ -40,8 +40,10 @@ func (b *Buffer) ExtractFrom(readable Readable) error { } // Copy returns a pointer to a copy of the Buffer. Requires type assertion when used -func (b Buffer) Copy() RVType { - return &b +func (b *Buffer) Copy() RVType { + copied := Buffer(*b) + + return &copied } // Equals checks if the input is equal in value to the current instance diff --git a/types/class_version_container.go b/types/class_version_container.go index d3d8dd76..66f4f0af 100644 --- a/types/class_version_container.go +++ b/types/class_version_container.go @@ -21,7 +21,7 @@ func (cvc *ClassVersionContainer) ExtractFrom(readable Readable) error { } // Copy returns a pointer to a copy of the ClassVersionContainer. Requires type assertion when used -func (cvc ClassVersionContainer) Copy() RVType { +func (cvc *ClassVersionContainer) Copy() RVType { copied := NewClassVersionContainer() copied.ClassVersions = cvc.ClassVersions.Copy().(*Map[*String, *PrimitiveU16]) diff --git a/types/datetime.go b/types/datetime.go index 7ded5d73..c144bf77 100644 --- a/types/datetime.go +++ b/types/datetime.go @@ -29,7 +29,7 @@ func (dt *DateTime) ExtractFrom(readable Readable) error { } // Copy returns a new copied instance of DateTime -func (dt DateTime) Copy() RVType { +func (dt *DateTime) Copy() RVType { return NewDateTime(dt.value) } diff --git a/types/empty.go b/types/empty.go index 985e664d..1fa74aa6 100644 --- a/types/empty.go +++ b/types/empty.go @@ -43,7 +43,7 @@ func (e *Empty) ExtractFrom(readable Readable) error { } // Copy returns a pointer to a copy of the Empty. Requires type assertion when used -func (e Empty) Copy() RVType { +func (e *Empty) Copy() RVType { copied := NewEmpty() copied.structureVersion = e.structureVersion diff --git a/types/list.go b/types/list.go index 4d838b9d..8c9ddfd5 100644 --- a/types/list.go +++ b/types/list.go @@ -41,7 +41,7 @@ func (l *List[T]) ExtractFrom(readable Readable) error { } // Copy returns a pointer to a copy of the List[T]. Requires type assertion when used -func (l List[T]) Copy() RVType { +func (l *List[T]) Copy() RVType { copied := NewList[T]() copied.real = make([]T, len(l.real)) copied.Type = l.Type.Copy().(T) diff --git a/types/map.go b/types/map.go index c2888769..32dc6373 100644 --- a/types/map.go +++ b/types/map.go @@ -54,7 +54,7 @@ func (m *Map[K, V]) ExtractFrom(readable Readable) error { } // Copy returns a pointer to a copy of the Map[K, V]. Requires type assertion when used -func (m Map[K, V]) Copy() RVType { +func (m *Map[K, V]) Copy() RVType { copied := NewMap[K, V]() copied.keys = make([]K, len(m.keys)) copied.values = make([]V, len(m.values)) diff --git a/types/pid.go b/types/pid.go index b4171bca..0eb1436a 100644 --- a/types/pid.go +++ b/types/pid.go @@ -47,7 +47,7 @@ func (p *PID) ExtractFrom(readable Readable) error { } // Copy returns a pointer to a copy of the PID. Requires type assertion when used -func (p PID) Copy() RVType { +func (p *PID) Copy() RVType { return NewPID(p.pid) } diff --git a/types/primitive_bool.go b/types/primitive_bool.go index b4fea9c9..9e6dbc1d 100644 --- a/types/primitive_bool.go +++ b/types/primitive_bool.go @@ -23,8 +23,10 @@ func (b *PrimitiveBool) ExtractFrom(readable Readable) error { } // Copy returns a pointer to a copy of the PrimitiveBool. Requires type assertion when used -func (b PrimitiveBool) Copy() RVType { - return &b +func (b *PrimitiveBool) Copy() RVType { + copied := PrimitiveBool(*b) + + return &copied } // Equals checks if the input is equal in value to the current instance diff --git a/types/primitive_float32.go b/types/primitive_float32.go index a071038a..b0486551 100644 --- a/types/primitive_float32.go +++ b/types/primitive_float32.go @@ -23,8 +23,10 @@ func (f32 *PrimitiveF32) ExtractFrom(readable Readable) error { } // Copy returns a pointer to a copy of the float32. Requires type assertion when used -func (f32 PrimitiveF32) Copy() RVType { - return &f32 +func (f32 *PrimitiveF32) Copy() RVType { + copied := PrimitiveF32(*f32) + + return &copied } // Equals checks if the input is equal in value to the current instance diff --git a/types/primitive_float64.go b/types/primitive_float64.go index 663f0313..41be8b1b 100644 --- a/types/primitive_float64.go +++ b/types/primitive_float64.go @@ -23,8 +23,10 @@ func (f64 *PrimitiveF64) ExtractFrom(readable Readable) error { } // Copy returns a pointer to a copy of the float64. Requires type assertion when used -func (f64 PrimitiveF64) Copy() RVType { - return &f64 +func (f64 *PrimitiveF64) Copy() RVType { + copied := PrimitiveF64(*f64) + + return &copied } // Equals checks if the input is equal in value to the current instance diff --git a/types/primitive_s16.go b/types/primitive_s16.go index a1a1ad8d..0511e8ed 100644 --- a/types/primitive_s16.go +++ b/types/primitive_s16.go @@ -23,8 +23,10 @@ func (s16 *PrimitiveS16) ExtractFrom(readable Readable) error { } // Copy returns a pointer to a copy of the int16. Requires type assertion when used -func (s16 PrimitiveS16) Copy() RVType { - return &s16 +func (s16 *PrimitiveS16) Copy() RVType { + copied := PrimitiveS16(*s16) + + return &copied } // Equals checks if the input is equal in value to the current instance diff --git a/types/primitive_s32.go b/types/primitive_s32.go index 16115dc8..919f8bb5 100644 --- a/types/primitive_s32.go +++ b/types/primitive_s32.go @@ -23,8 +23,10 @@ func (s32 *PrimitiveS32) ExtractFrom(readable Readable) error { } // Copy returns a pointer to a copy of the int32. Requires type assertion when used -func (s32 PrimitiveS32) Copy() RVType { - return &s32 +func (s32 *PrimitiveS32) Copy() RVType { + copied := PrimitiveS32(*s32) + + return &copied } // Equals checks if the input is equal in value to the current instance diff --git a/types/primitive_s64.go b/types/primitive_s64.go index 40488af5..cc91c33a 100644 --- a/types/primitive_s64.go +++ b/types/primitive_s64.go @@ -23,8 +23,10 @@ func (s64 *PrimitiveS64) ExtractFrom(readable Readable) error { } // Copy returns a pointer to a copy of the int64. Requires type assertion when used -func (s64 PrimitiveS64) Copy() RVType { - return &s64 +func (s64 *PrimitiveS64) Copy() RVType { + copied := PrimitiveS64(*s64) + + return &copied } // Equals checks if the input is equal in value to the current instance diff --git a/types/primitive_s8.go b/types/primitive_s8.go index 70520359..5e15f6fd 100644 --- a/types/primitive_s8.go +++ b/types/primitive_s8.go @@ -23,8 +23,10 @@ func (s8 *PrimitiveS8) ExtractFrom(readable Readable) error { } // Copy returns a pointer to a copy of the int8. Requires type assertion when used -func (s8 PrimitiveS8) Copy() RVType { - return &s8 +func (s8 *PrimitiveS8) Copy() RVType { + copied := PrimitiveS8(*s8) + + return &copied } // Equals checks if the input is equal in value to the current instance diff --git a/types/primitive_u16.go b/types/primitive_u16.go index 5ab151d2..c70abfe8 100644 --- a/types/primitive_u16.go +++ b/types/primitive_u16.go @@ -23,8 +23,10 @@ func (u16 *PrimitiveU16) ExtractFrom(readable Readable) error { } // Copy returns a pointer to a copy of the uint16. Requires type assertion when used -func (u16 PrimitiveU16) Copy() RVType { - return &u16 +func (u16 *PrimitiveU16) Copy() RVType { + copied := PrimitiveU16(*u16) + + return &copied } // Equals checks if the input is equal in value to the current instance diff --git a/types/primitive_u32.go b/types/primitive_u32.go index 9180be91..26d08a87 100644 --- a/types/primitive_u32.go +++ b/types/primitive_u32.go @@ -23,8 +23,10 @@ func (u32 *PrimitiveU32) ExtractFrom(readable Readable) error { } // Copy returns a pointer to a copy of the uint32. Requires type assertion when used -func (u32 PrimitiveU32) Copy() RVType { - return &u32 +func (u32 *PrimitiveU32) Copy() RVType { + copied := PrimitiveU32(*u32) + + return &copied } // Equals checks if the input is equal in value to the current instance diff --git a/types/primitive_u64.go b/types/primitive_u64.go index c7390e86..7c3c654d 100644 --- a/types/primitive_u64.go +++ b/types/primitive_u64.go @@ -23,8 +23,10 @@ func (u64 *PrimitiveU64) ExtractFrom(readable Readable) error { } // Copy returns a pointer to a copy of the uint64. Requires type assertion when used -func (u64 PrimitiveU64) Copy() RVType { - return &u64 +func (u64 *PrimitiveU64) Copy() RVType { + copied := PrimitiveU64(*u64) + + return &copied } // Equals checks if the input is equal in value to the current instance diff --git a/types/primitive_u8.go b/types/primitive_u8.go index 4bd94e34..aef73f7d 100644 --- a/types/primitive_u8.go +++ b/types/primitive_u8.go @@ -23,8 +23,10 @@ func (u8 *PrimitiveU8) ExtractFrom(readable Readable) error { } // Copy returns a pointer to a copy of the uint8. Requires type assertion when used -func (u8 PrimitiveU8) Copy() RVType { - return &u8 +func (u8 *PrimitiveU8) Copy() RVType { + copied := PrimitiveU8(*u8) + + return &copied } // Equals checks if the input is equal in value to the current instance diff --git a/types/qbuffer.go b/types/qbuffer.go index 72821c9b..e9b38eec 100644 --- a/types/qbuffer.go +++ b/types/qbuffer.go @@ -40,8 +40,10 @@ func (qb *QBuffer) ExtractFrom(readable Readable) error { } // Copy returns a pointer to a copy of the qBuffer. Requires type assertion when used -func (qb QBuffer) Copy() RVType { - return &qb +func (qb *QBuffer) Copy() RVType { + copied := QBuffer(*qb) + + return &copied } // Equals checks if the input is equal in value to the current instance diff --git a/types/result.go b/types/result.go index c10f4f91..a1cf91d7 100644 --- a/types/result.go +++ b/types/result.go @@ -30,7 +30,7 @@ func (r *Result) ExtractFrom(readable Readable) error { } // Copy returns a pointer to a copy of the Result. Requires type assertion when used -func (r Result) Copy() RVType { +func (r *Result) Copy() RVType { return NewResult(r.Code) } diff --git a/types/string.go b/types/string.go index fcf82e49..06c095aa 100644 --- a/types/string.go +++ b/types/string.go @@ -61,8 +61,10 @@ func (s *String) ExtractFrom(readable Readable) error { } // Copy returns a pointer to a copy of the String. Requires type assertion when used -func (s String) Copy() RVType { - return &s +func (s *String) Copy() RVType { + copied := String(*s) + + return &copied } // Equals checks if the input is equal in value to the current instance From e61cddf06e18aa9a66daca8f2312f6ebe9dcc748 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Mon, 25 Dec 2023 22:16:31 -0500 Subject: [PATCH 098/178] types: rename Empty to Data and update AnyDataHolder docs --- types/any_data_holder.go | 4 +++- types/{empty.go => data.go} | 42 ++++++++++++++++++------------------- 2 files changed, 24 insertions(+), 22 deletions(-) rename types/{empty.go => data.go} (53%) diff --git a/types/any_data_holder.go b/types/any_data_holder.go index 222adac6..bcfc8688 100644 --- a/types/any_data_holder.go +++ b/types/any_data_holder.go @@ -12,7 +12,9 @@ func RegisterDataHolderType(name string, rvType RVType) { AnyDataHolderObjects[name] = rvType } -// AnyDataHolder is a class which can contain any Structure +// AnyDataHolder is a class which can contain any Structure. These Structures usually inherit from at least one +// other Structure. Typically this base class is the empty `Data` Structure, but this is not always the case. +// The contained Structures name & length are sent with the Structure body, so the receiver can properly decode it type AnyDataHolder struct { TypeName string // TODO - Replace this with String? Length1 uint32 // TODO - Replace this with PrimitiveU32? diff --git a/types/empty.go b/types/data.go similarity index 53% rename from types/empty.go rename to types/data.go index 1fa74aa6..6d72523b 100644 --- a/types/empty.go +++ b/types/data.go @@ -6,34 +6,34 @@ import ( "strings" ) -// Empty is a Structure with no fields -type Empty struct { +// Data is the base class for many other structures. The structure itself has no fields +type Data struct { Structure } -// WriteTo writes the Empty to the given writable -func (e *Empty) WriteTo(writable Writable) { +// WriteTo writes the Data to the given writable +func (e *Data) WriteTo(writable Writable) { if writable.UseStructureHeader() { writable.WritePrimitiveUInt8(e.StructureVersion()) writable.WritePrimitiveUInt32LE(0) } } -// ExtractFrom extracts the Empty to the given readable -func (e *Empty) ExtractFrom(readable Readable) error { +// ExtractFrom extracts the Data to the given readable +func (e *Data) ExtractFrom(readable Readable) error { if readable.UseStructureHeader() { version, err := readable.ReadPrimitiveUInt8() if err != nil { - return fmt.Errorf("Failed to read Empty version. %s", err.Error()) + return fmt.Errorf("Failed to read Data version. %s", err.Error()) } contentLength, err := readable.ReadPrimitiveUInt32LE() if err != nil { - return fmt.Errorf("Failed to read Empty content length. %s", err.Error()) + return fmt.Errorf("Failed to read Data content length. %s", err.Error()) } if readable.Remaining() < uint64(contentLength) { - return errors.New("Empty content length longer than data size") + return errors.New("Data content length longer than data size") } e.SetStructureVersion(version) @@ -42,43 +42,43 @@ func (e *Empty) ExtractFrom(readable Readable) error { return nil } -// Copy returns a pointer to a copy of the Empty. Requires type assertion when used -func (e *Empty) Copy() RVType { - copied := NewEmpty() +// Copy returns a pointer to a copy of the Data. Requires type assertion when used +func (e *Data) Copy() RVType { + copied := NewData() copied.structureVersion = e.structureVersion return copied } // Equals checks if the input is equal in value to the current instance -func (e *Empty) Equals(o RVType) bool { - if _, ok := o.(*Empty); !ok { +func (e *Data) Equals(o RVType) bool { + if _, ok := o.(*Data); !ok { return false } - return (*e).structureVersion == (*o.(*Empty)).structureVersion + return (*e).structureVersion == (*o.(*Data)).structureVersion } // String returns a string representation of the struct -func (e *Empty) String() string { +func (e *Data) String() string { return e.FormatToString(0) } // FormatToString pretty-prints the struct data using the provided indentation level -func (e *Empty) FormatToString(indentationLevel int) string { +func (e *Data) FormatToString(indentationLevel int) string { indentationValues := strings.Repeat("\t", indentationLevel+1) indentationEnd := strings.Repeat("\t", indentationLevel) var b strings.Builder - b.WriteString("Empty{\n") + b.WriteString("Data{\n") b.WriteString(fmt.Sprintf("%sstructureVersion: %d\n", indentationValues, e.structureVersion)) b.WriteString(fmt.Sprintf("%s}", indentationEnd)) return b.String() } -// NewEmpty returns a new Empty Structure -func NewEmpty() *Empty { - return &Empty{} +// NewData returns a new Data Structure +func NewData() *Data { + return &Data{} } From 34fa2d53ded322383692b9c1c658110029335acc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20L=C3=B3pez=20Guimaraes?= Date: Wed, 27 Dec 2023 22:16:06 +0000 Subject: [PATCH 099/178] types: Replace type aliases with struct wrappers Also remove StructureInterface, and make all structures initialize all their fields. The structure header is now added as a common function `ExtractHeaderFrom` and `WriteHeaderTo`. --- kerberos.go | 12 ++--- prudp_server.go | 8 ++-- test/auth.go | 8 ++-- test/hpp.go | 58 ++++++++++++++--------- test/secure.go | 38 +++++++-------- types/any_data_holder.go | 18 +++---- types/buffer.go | 27 +++++------ types/class_version_container.go | 13 +++-- types/data.go | 30 +++--------- types/datetime.go | 2 +- types/pid.go | 2 +- types/primitive_bool.go | 22 ++++----- types/primitive_float32.go | 22 ++++----- types/primitive_float64.go | 20 ++++---- types/primitive_s16.go | 22 ++++----- types/primitive_s32.go | 22 ++++----- types/primitive_s64.go | 22 ++++----- types/primitive_s8.go | 22 ++++----- types/primitive_u16.go | 22 ++++----- types/primitive_u32.go | 22 ++++----- types/primitive_u64.go | 22 ++++----- types/primitive_u8.go | 21 ++++----- types/qbuffer.go | 25 ++++------ types/quuid.go | 1 - types/result_range.go | 63 +++++++++---------------- types/rv_connection_data.go | 81 +++++++++++++------------------- types/station_url.go | 6 +-- types/string.go | 22 ++++----- types/structure.go | 58 ++++++++++++----------- types/variant.go | 2 +- 30 files changed, 312 insertions(+), 401 deletions(-) diff --git a/kerberos.go b/kerberos.go index 65326287..aad27ad5 100644 --- a/kerberos.go +++ b/kerberos.go @@ -124,8 +124,8 @@ func (ti *KerberosTicketInternalData) Encrypt(key []byte, stream *StreamOut) ([] finalStream := NewStreamOut(stream.Server) - var ticketBuffer types.Buffer = ticketKey - var encryptedBuffer types.Buffer = encrypted + ticketBuffer := types.NewBuffer(ticketKey) + encryptedBuffer := types.NewBuffer(encrypted) ticketBuffer.WriteTo(finalStream) encryptedBuffer.WriteTo(finalStream) @@ -141,20 +141,20 @@ func (ti *KerberosTicketInternalData) Encrypt(key []byte, stream *StreamOut) ([] // Decrypt decrypts the given data and populates the struct func (ti *KerberosTicketInternalData) Decrypt(stream *StreamIn, key []byte) error { if stream.Server.(*PRUDPServer).kerberosTicketVersion == 1 { - ticketKey := types.NewBuffer([]byte{}) + ticketKey := types.NewBuffer(nil) if err := ticketKey.ExtractFrom(stream); err != nil { return fmt.Errorf("Failed to read Kerberos ticket internal data key. %s", err.Error()) } - data := types.NewBuffer([]byte{}) + data := types.NewBuffer(nil) if err := ticketKey.ExtractFrom(stream); err != nil { return fmt.Errorf("Failed to read Kerberos ticket internal data. %s", err.Error()) } - hash := md5.Sum(append(key, *ticketKey...)) + hash := md5.Sum(append(key, ticketKey.Value...)) key = hash[:] - stream = NewStreamIn(*data, stream.Server) + stream = NewStreamIn(data.Value, stream.Server) } encryption := NewKerberosEncryption(key) diff --git a/prudp_server.go b/prudp_server.go index eff50c39..a081e3c2 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -515,12 +515,12 @@ func (s *PRUDPServer) handlePing(packet PRUDPPacketInterface) { func (s *PRUDPServer) readKerberosTicket(payload []byte) ([]byte, *types.PID, uint32, error) { stream := NewStreamIn(payload, s) - ticketData := types.NewBuffer([]byte{}) + ticketData := types.NewBuffer(nil) if err := ticketData.ExtractFrom(stream); err != nil { return nil, nil, 0, err } - requestData := types.NewBuffer([]byte{}) + requestData := types.NewBuffer(nil) if err := requestData.ExtractFrom(stream); err != nil { return nil, nil, 0, err } @@ -528,7 +528,7 @@ func (s *PRUDPServer) readKerberosTicket(payload []byte) ([]byte, *types.PID, ui serverKey := DeriveKerberosKey(types.NewPID(2), s.kerberosPassword) ticket := NewKerberosTicketInternalData() - if err := ticket.Decrypt(NewStreamIn([]byte(*ticketData), s), serverKey); err != nil { + if err := ticket.Decrypt(NewStreamIn(ticketData.Value, s), serverKey); err != nil { return nil, nil, 0, err } @@ -543,7 +543,7 @@ func (s *PRUDPServer) readKerberosTicket(payload []byte) ([]byte, *types.PID, ui sessionKey := ticket.SessionKey kerberos := NewKerberosEncryption(sessionKey) - decryptedRequestData, err := kerberos.Decrypt(*requestData) + decryptedRequestData, err := kerberos.Decrypt(requestData.Value) if err != nil { return nil, nil, 0, err } diff --git a/test/auth.go b/test/auth.go index fcd1b202..97410325 100644 --- a/test/auth.go +++ b/test/auth.go @@ -56,16 +56,16 @@ func login(packet nex.PRUDPPacketInterface) { panic(err) } - converted, err := strconv.Atoi(string(*strUserName)) + converted, err := strconv.Atoi(strUserName.Value) if err != nil { panic(err) } retval := types.NewResultSuccess(0x00010001) pidPrincipal := types.NewPID(uint64(converted)) - pbufResponse := types.Buffer(generateTicket(pidPrincipal, types.NewPID(2))) + pbufResponse := types.NewBuffer(generateTicket(pidPrincipal, types.NewPID(2))) pConnectionData := types.NewRVConnectionData() - strReturnMsg := types.String("Test Build") + strReturnMsg := types.NewString("Test Build") pConnectionData.StationURL = types.NewStationURL("prudps:/address=192.168.1.98;port=60001;CID=1;PID=2;sid=1;stream=10;type=2") pConnectionData.SpecialProtocols = types.NewList[*types.PrimitiveU8]() @@ -124,7 +124,7 @@ func requestTicket(packet nex.PRUDPPacketInterface) { } retval := types.NewResultSuccess(0x00010001) - pbufResponse := types.Buffer(generateTicket(idSource, idTarget)) + pbufResponse := types.NewBuffer(generateTicket(idSource, idTarget)) responseStream := nex.NewStreamOut(authServer) diff --git a/test/hpp.go b/test/hpp.go index adf445ed..125374d0 100644 --- a/test/hpp.go +++ b/test/hpp.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/PretendoNetwork/nex-go" + "github.com/PretendoNetwork/nex-go/types" ) var hppServer *nex.HPPServer @@ -11,14 +12,18 @@ var hppServer *nex.HPPServer // * Took these structs out of the protocols lib for convenience type dataStoreGetNotificationURLParam struct { - nex.Structure - PreviousURL string + types.Structure + PreviousURL *types.String } -func (d *dataStoreGetNotificationURLParam) ExtractFromStream(stream *nex.StreamIn) error { +func (d *dataStoreGetNotificationURLParam) ExtractFrom(readable types.Readable) error { var err error - d.PreviousURL, err = stream.ReadString() + if err = d.ExtractHeaderFrom(readable); err != nil { + return fmt.Errorf("Failed to extract DataStoreGetNotificationURLParam header. %s", err.Error()) + } + + err = d.PreviousURL.ExtractFrom(readable) if err != nil { return fmt.Errorf("Failed to extract DataStoreGetNotificationURLParam.PreviousURL. %s", err.Error()) } @@ -27,23 +32,29 @@ func (d *dataStoreGetNotificationURLParam) ExtractFromStream(stream *nex.StreamI } type dataStoreReqGetNotificationURLInfo struct { - nex.Structure - URL string - Key string - Query string - RootCACert []byte + types.Structure + URL *types.String + Key *types.String + Query *types.String + RootCACert *types.Buffer } -func (d *dataStoreReqGetNotificationURLInfo) Bytes(stream *nex.StreamOut) []byte { - stream.WriteString(d.URL) - stream.WriteString(d.Key) - stream.WriteString(d.Query) - stream.WriteBuffer(d.RootCACert) +func (d *dataStoreReqGetNotificationURLInfo) WriteTo(writable types.Writable) { + contentWritable := writable.CopyNew() + + d.URL.WriteTo(contentWritable) + d.Key.WriteTo(contentWritable) + d.Query.WriteTo(contentWritable) + d.RootCACert.WriteTo(contentWritable) + + content := contentWritable.Bytes() + + d.WriteHeaderTo(writable, uint32(len(content))) - return stream.Bytes() + writable.Write(content) } -func passwordFromPID(pid *nex.PID) (string, uint32) { +func passwordFromPID(pid *types.PID) (string, uint32) { return "notmypassword", 0 } @@ -70,7 +81,7 @@ func startHPPServer() { hppServer.SetAccessKey("76f26496") hppServer.SetPasswordFromPIDFunction(passwordFromPID) - hppServer.Listen(8085) + hppServer.Listen(12345) } func getNotificationURL(packet *nex.HPPPacket) { @@ -81,7 +92,9 @@ func getNotificationURL(packet *nex.HPPPacket) { parametersStream := nex.NewStreamIn(parameters, hppServer) - param, err := nex.StreamReadStructure(parametersStream, &dataStoreGetNotificationURLParam{}) + param := &dataStoreGetNotificationURLParam{} + param.PreviousURL = types.NewString("") + err := param.ExtractFrom(parametersStream) if err != nil { fmt.Println("[HPP]", err) return @@ -92,11 +105,12 @@ func getNotificationURL(packet *nex.HPPPacket) { responseStream := nex.NewStreamOut(hppServer) info := &dataStoreReqGetNotificationURLInfo{} - info.URL = "https://example.com" - info.Key = "whatever/key" - info.Query = "?pretendo=1" + info.URL = types.NewString("https://example.com") + info.Key = types.NewString("whatever/key") + info.Query = types.NewString("?pretendo=1") + info.RootCACert = types.NewBuffer(nil) - responseStream.WriteStructure(info) + info.WriteTo(responseStream) response.IsSuccess = true response.IsRequest = false diff --git a/test/secure.go b/test/secure.go index e9c7ca99..36c289ed 100644 --- a/test/secure.go +++ b/test/secure.go @@ -15,30 +15,30 @@ var secureServer *nex.PRUDPServer type principalPreference struct { types.Structure - *types.Empty - ShowOnlinePresence bool - ShowCurrentTitle bool - BlockFriendRequests bool + *types.Data + ShowOnlinePresence *types.PrimitiveBool + ShowCurrentTitle *types.PrimitiveBool + BlockFriendRequests *types.PrimitiveBool } -func (pp *principalPreference) WriteTo(stream *nex.StreamOut) { - stream.WritePrimitiveBool(pp.ShowOnlinePresence) - stream.WritePrimitiveBool(pp.ShowCurrentTitle) - stream.WritePrimitiveBool(pp.BlockFriendRequests) +func (pp *principalPreference) WriteTo(writable types.Writable) { + pp.ShowOnlinePresence.WriteTo(writable) + pp.ShowCurrentTitle.WriteTo(writable) + pp.BlockFriendRequests.WriteTo(writable) } type comment struct { types.Structure - *types.Empty - Unknown uint8 + *types.Data + Unknown *types.PrimitiveU8 Contents *types.String LastChanged *types.DateTime } -func (c *comment) WriteTo(stream *nex.StreamOut) { - stream.WritePrimitiveUInt8(c.Unknown) - c.Contents.WriteTo(stream) - c.LastChanged.WriteTo(stream) +func (c *comment) WriteTo(writable types.Writable) { + c.Unknown.WriteTo(writable) + c.Contents.WriteTo(writable) + c.LastChanged.WriteTo(writable) } func startSecureServer() { @@ -109,7 +109,7 @@ func registerEx(packet nex.PRUDPPacketInterface) { localStation.Fields["port"] = strconv.Itoa(packet.Sender().Address().(*net.UDPAddr).Port) retval := types.NewResultSuccess(0x00010001) - localStationURL := types.String(localStation.EncodeToString()) + localStationURL := types.NewString(localStation.EncodeToString()) responseStream := nex.NewStreamOut(secureServer) @@ -148,12 +148,12 @@ func updateAndGetAllInformation(packet nex.PRUDPPacketInterface) { responseStream := nex.NewStreamOut(secureServer) (&principalPreference{ - ShowOnlinePresence: true, - ShowCurrentTitle: true, - BlockFriendRequests: false, + ShowOnlinePresence: types.NewPrimitiveBool(true), + ShowCurrentTitle: types.NewPrimitiveBool(true), + BlockFriendRequests: types.NewPrimitiveBool(false), }).WriteTo(responseStream) (&comment{ - Unknown: 0, + Unknown: types.NewPrimitiveU8(0), Contents: types.NewString("Rewrite Test"), LastChanged: types.NewDateTime(0), }).WriteTo(responseStream) diff --git a/types/any_data_holder.go b/types/any_data_holder.go index bcfc8688..89fa487e 100644 --- a/types/any_data_holder.go +++ b/types/any_data_holder.go @@ -16,9 +16,9 @@ func RegisterDataHolderType(name string, rvType RVType) { // other Structure. Typically this base class is the empty `Data` Structure, but this is not always the case. // The contained Structures name & length are sent with the Structure body, so the receiver can properly decode it type AnyDataHolder struct { - TypeName string // TODO - Replace this with String? - Length1 uint32 // TODO - Replace this with PrimitiveU32? - Length2 uint32 // TODO - Replace this with PrimitiveU32? + TypeName string + Length1 uint32 + Length2 uint32 ObjectData RVType } @@ -29,7 +29,7 @@ func (adh *AnyDataHolder) WriteTo(writable Writable) { adh.ObjectData.WriteTo(contentWritable) objectData := contentWritable.Bytes() - typeName := String(adh.TypeName) + typeName := NewString(adh.TypeName) length1 := uint32(len(objectData) + 4) length2 := uint32(len(objectData)) @@ -41,7 +41,7 @@ func (adh *AnyDataHolder) WriteTo(writable Writable) { // ExtractFrom extracts the AnyDataholder to the given readable func (adh *AnyDataHolder) ExtractFrom(readable Readable) error { - var typeName String + typeName := NewString("") err := typeName.ExtractFrom(readable) if err != nil { @@ -58,17 +58,17 @@ func (adh *AnyDataHolder) ExtractFrom(readable Readable) error { return fmt.Errorf("Failed to read DanyDataHolder length 2. %s", err.Error()) } - if _, ok := AnyDataHolderObjects[string(typeName)]; !ok { - return fmt.Errorf("Unknown AnyDataHolder type: %s", string(typeName)) + if _, ok := AnyDataHolderObjects[typeName.Value]; !ok { + return fmt.Errorf("Unknown AnyDataHolder type: %s", typeName.Value) } - adh.ObjectData = AnyDataHolderObjects[string(typeName)].Copy() + adh.ObjectData = AnyDataHolderObjects[typeName.Value].Copy() if err := adh.ObjectData.ExtractFrom(readable); err != nil { return fmt.Errorf("Failed to read DanyDataHolder object data. %s", err.Error()) } - adh.TypeName = string(typeName) + adh.TypeName = typeName.Value adh.Length1 = length1 adh.Length2 = length2 diff --git a/types/buffer.go b/types/buffer.go index 632d531d..db149003 100644 --- a/types/buffer.go +++ b/types/buffer.go @@ -5,20 +5,19 @@ import ( "fmt" ) -// TODO - Should this have a "Value"-kind of method to get the original value? - -// Buffer is a type alias of []byte with receiver methods to conform to RVType -type Buffer []byte // TODO - Should we make this a struct instead of a type alias? +// Buffer is a struct of []byte with receiver methods to conform to RVType +type Buffer struct { + Value []byte +} // WriteTo writes the []byte to the given writable func (b *Buffer) WriteTo(writable Writable) { - data := *b - length := len(data) + length := len(b.Value) writable.WritePrimitiveUInt32LE(uint32(length)) if length > 0 { - writable.Write([]byte(data)) + writable.Write(b.Value) } } @@ -29,21 +28,19 @@ func (b *Buffer) ExtractFrom(readable Readable) error { return fmt.Errorf("Failed to read NEX Buffer length. %s", err.Error()) } - data, err := readable.Read(uint64(length)) + value, err := readable.Read(uint64(length)) if err != nil { return fmt.Errorf("Failed to read NEX Buffer data. %s", err.Error()) } - *b = Buffer(data) + b.Value = value return nil } // Copy returns a pointer to a copy of the Buffer. Requires type assertion when used func (b *Buffer) Copy() RVType { - copied := Buffer(*b) - - return &copied + return NewBuffer(b.Value) } // Equals checks if the input is equal in value to the current instance @@ -52,12 +49,10 @@ func (b *Buffer) Equals(o RVType) bool { return false } - return bytes.Equal([]byte(*b), []byte(*o.(*Buffer))) + return bytes.Equal(b.Value, o.(*Buffer).Value) } // NewBuffer returns a new Buffer func NewBuffer(data []byte) *Buffer { - var b Buffer = data - - return &b + return &Buffer{Value: data} } diff --git a/types/class_version_container.go b/types/class_version_container.go index 66f4f0af..a15e8292 100644 --- a/types/class_version_container.go +++ b/types/class_version_container.go @@ -13,10 +13,6 @@ func (cvc *ClassVersionContainer) WriteTo(writable Writable) { // ExtractFrom extracts the ClassVersionContainer to the given readable func (cvc *ClassVersionContainer) ExtractFrom(readable Readable) error { - cvc.ClassVersions = NewMap[*String, *PrimitiveU16]() - cvc.ClassVersions.KeyType = NewString("") - cvc.ClassVersions.ValueType = NewPrimitiveU16(0) - return cvc.ClassVersions.ExtractFrom(readable) } @@ -39,5 +35,12 @@ func (cvc *ClassVersionContainer) Equals(o RVType) bool { // NewClassVersionContainer returns a new ClassVersionContainer func NewClassVersionContainer() *ClassVersionContainer { - return &ClassVersionContainer{} + cvc := &ClassVersionContainer{ + ClassVersions: NewMap[*String, *PrimitiveU16](), + } + + cvc.ClassVersions.KeyType = NewString("") + cvc.ClassVersions.ValueType = NewPrimitiveU16(0) + + return cvc } diff --git a/types/data.go b/types/data.go index 6d72523b..3f399327 100644 --- a/types/data.go +++ b/types/data.go @@ -1,7 +1,6 @@ package types import ( - "errors" "fmt" "strings" ) @@ -13,30 +12,13 @@ type Data struct { // WriteTo writes the Data to the given writable func (e *Data) WriteTo(writable Writable) { - if writable.UseStructureHeader() { - writable.WritePrimitiveUInt8(e.StructureVersion()) - writable.WritePrimitiveUInt32LE(0) - } + e.WriteHeaderTo(writable, 0) } // ExtractFrom extracts the Data to the given readable func (e *Data) ExtractFrom(readable Readable) error { - if readable.UseStructureHeader() { - version, err := readable.ReadPrimitiveUInt8() - if err != nil { - return fmt.Errorf("Failed to read Data version. %s", err.Error()) - } - - contentLength, err := readable.ReadPrimitiveUInt32LE() - if err != nil { - return fmt.Errorf("Failed to read Data content length. %s", err.Error()) - } - - if readable.Remaining() < uint64(contentLength) { - return errors.New("Data content length longer than data size") - } - - e.SetStructureVersion(version) + if err := e.ExtractHeaderFrom(readable); err != nil { + return fmt.Errorf("Failed to read Data header. %s", err.Error()) } return nil @@ -45,7 +27,7 @@ func (e *Data) ExtractFrom(readable Readable) error { // Copy returns a pointer to a copy of the Data. Requires type assertion when used func (e *Data) Copy() RVType { copied := NewData() - copied.structureVersion = e.structureVersion + copied.StructureVersion = e.StructureVersion return copied } @@ -56,7 +38,7 @@ func (e *Data) Equals(o RVType) bool { return false } - return (*e).structureVersion == (*o.(*Data)).structureVersion + return (*e).StructureVersion == (*o.(*Data)).StructureVersion } // String returns a string representation of the struct @@ -72,7 +54,7 @@ func (e *Data) FormatToString(indentationLevel int) string { var b strings.Builder b.WriteString("Data{\n") - b.WriteString(fmt.Sprintf("%sstructureVersion: %d\n", indentationValues, e.structureVersion)) + b.WriteString(fmt.Sprintf("%sStructureVersion: %d\n", indentationValues, e.StructureVersion)) b.WriteString(fmt.Sprintf("%s}", indentationEnd)) return b.String() diff --git a/types/datetime.go b/types/datetime.go index c144bf77..fb890873 100644 --- a/types/datetime.go +++ b/types/datetime.go @@ -8,7 +8,7 @@ import ( // DateTime represents a NEX DateTime type type DateTime struct { - value uint64 // TODO - Replace this with PrimitiveU64? + value uint64 } // WriteTo writes the DateTime to the given writable diff --git a/types/pid.go b/types/pid.go index 0eb1436a..d214ddef 100644 --- a/types/pid.go +++ b/types/pid.go @@ -11,7 +11,7 @@ import ( // Legacy clients (WiiU/3DS) use a uint32, whereas modern clients (Nintendo Switch) use a uint64. // Value is always stored as the higher uint64, the consuming API should assert accordingly type PID struct { - pid uint64 // TODO - Replace this with PrimitiveU64? + pid uint64 } // WriteTo writes the bool to the given writable diff --git a/types/primitive_bool.go b/types/primitive_bool.go index 9e6dbc1d..17cc41de 100644 --- a/types/primitive_bool.go +++ b/types/primitive_bool.go @@ -1,13 +1,13 @@ package types -// TODO - Should this have a "Value"-kind of method to get the original value? - -// PrimitiveBool is a type alias of bool with receiver methods to conform to RVType -type PrimitiveBool bool // TODO - Should we make this a struct instead of a type alias? +// PrimitiveBool is a struct of bool with receiver methods to conform to RVType +type PrimitiveBool struct { + Value bool +} // WriteTo writes the bool to the given writable func (b *PrimitiveBool) WriteTo(writable Writable) { - writable.WritePrimitiveBool(bool(*b)) + writable.WritePrimitiveBool(b.Value) } // ExtractFrom extracts the bool to the given readable @@ -17,16 +17,14 @@ func (b *PrimitiveBool) ExtractFrom(readable Readable) error { return err } - *b = PrimitiveBool(value) + b.Value = value return nil } // Copy returns a pointer to a copy of the PrimitiveBool. Requires type assertion when used func (b *PrimitiveBool) Copy() RVType { - copied := PrimitiveBool(*b) - - return &copied + return NewPrimitiveBool(b.Value) } // Equals checks if the input is equal in value to the current instance @@ -35,12 +33,10 @@ func (b *PrimitiveBool) Equals(o RVType) bool { return false } - return *b == *o.(*PrimitiveBool) + return b.Value == o.(*PrimitiveBool).Value } // NewPrimitiveBool returns a new PrimitiveBool func NewPrimitiveBool(boolean bool) *PrimitiveBool { - b := PrimitiveBool(boolean) - - return &b + return &PrimitiveBool{Value: boolean} } diff --git a/types/primitive_float32.go b/types/primitive_float32.go index b0486551..b3293702 100644 --- a/types/primitive_float32.go +++ b/types/primitive_float32.go @@ -1,13 +1,13 @@ package types -// TODO - Should this have a "Value"-kind of method to get the original value? - -// PrimitiveF32 is a type alias of float32 with receiver methods to conform to RVType -type PrimitiveF32 float32 // TODO - Should we make this a struct instead of a type alias? +// PrimitiveF32 is a struct of float32 with receiver methods to conform to RVType +type PrimitiveF32 struct { + Value float32 +} // WriteTo writes the float32 to the given writable func (f32 *PrimitiveF32) WriteTo(writable Writable) { - writable.WritePrimitiveFloat32LE(float32(*f32)) + writable.WritePrimitiveFloat32LE(f32.Value) } // ExtractFrom extracts the float32 to the given readable @@ -17,16 +17,14 @@ func (f32 *PrimitiveF32) ExtractFrom(readable Readable) error { return err } - *f32 = PrimitiveF32(value) + f32.Value = value return nil } // Copy returns a pointer to a copy of the float32. Requires type assertion when used func (f32 *PrimitiveF32) Copy() RVType { - copied := PrimitiveF32(*f32) - - return &copied + return NewPrimitiveF32(f32.Value) } // Equals checks if the input is equal in value to the current instance @@ -35,12 +33,10 @@ func (f32 *PrimitiveF32) Equals(o RVType) bool { return false } - return *f32 == *o.(*PrimitiveF32) + return f32.Value == o.(*PrimitiveF32).Value } // NewPrimitiveF32 returns a new PrimitiveF32 func NewPrimitiveF32(float float32) *PrimitiveF32 { - f32 := PrimitiveF32(float) - - return &f32 + return &PrimitiveF32{Value: float} } diff --git a/types/primitive_float64.go b/types/primitive_float64.go index 41be8b1b..f988dbed 100644 --- a/types/primitive_float64.go +++ b/types/primitive_float64.go @@ -1,13 +1,13 @@ package types -// TODO - Should this have a "Value"-kind of method to get the original value? - -// PrimitiveF64 is a type alias of float64 with receiver methods to conform to RVType -type PrimitiveF64 float64 // TODO - Should we make this a struct instead of a type alias? +// PrimitiveF64 is a struct of float64 with receiver methods to conform to RVType +type PrimitiveF64 struct { + Value float64 +} // WriteTo writes the float64 to the given writable func (f64 *PrimitiveF64) WriteTo(writable Writable) { - writable.WritePrimitiveFloat64LE(float64(*f64)) + writable.WritePrimitiveFloat64LE(f64.Value) } // ExtractFrom extracts the float64 to the given readable @@ -17,16 +17,14 @@ func (f64 *PrimitiveF64) ExtractFrom(readable Readable) error { return err } - *f64 = PrimitiveF64(value) + f64.Value = value return nil } // Copy returns a pointer to a copy of the float64. Requires type assertion when used func (f64 *PrimitiveF64) Copy() RVType { - copied := PrimitiveF64(*f64) - - return &copied + return NewPrimitiveF64(f64.Value) } // Equals checks if the input is equal in value to the current instance @@ -40,7 +38,5 @@ func (f64 *PrimitiveF64) Equals(o RVType) bool { // NewPrimitiveF64 returns a new PrimitiveF64 func NewPrimitiveF64(float float64) *PrimitiveF64 { - f64 := PrimitiveF64(float) - - return &f64 + return &PrimitiveF64{Value: float} } diff --git a/types/primitive_s16.go b/types/primitive_s16.go index 0511e8ed..956e8143 100644 --- a/types/primitive_s16.go +++ b/types/primitive_s16.go @@ -1,13 +1,13 @@ package types -// TODO - Should this have a "Value"-kind of method to get the original value? - -// PrimitiveS16 is a type alias of int16 with receiver methods to conform to RVType -type PrimitiveS16 int16 // TODO - Should we make this a struct instead of a type alias? +// PrimitiveS16 is a struct of int16 with receiver methods to conform to RVType +type PrimitiveS16 struct { + Value int16 +} // WriteTo writes the int16 to the given writable func (s16 *PrimitiveS16) WriteTo(writable Writable) { - writable.WritePrimitiveInt16LE(int16(*s16)) + writable.WritePrimitiveInt16LE(s16.Value) } // ExtractFrom extracts the int16 to the given readable @@ -17,16 +17,14 @@ func (s16 *PrimitiveS16) ExtractFrom(readable Readable) error { return err } - *s16 = PrimitiveS16(value) + s16.Value = value return nil } // Copy returns a pointer to a copy of the int16. Requires type assertion when used func (s16 *PrimitiveS16) Copy() RVType { - copied := PrimitiveS16(*s16) - - return &copied + return NewPrimitiveS16(s16.Value) } // Equals checks if the input is equal in value to the current instance @@ -35,12 +33,10 @@ func (s16 *PrimitiveS16) Equals(o RVType) bool { return false } - return *s16 == *o.(*PrimitiveS16) + return s16.Value == o.(*PrimitiveS16).Value } // NewPrimitiveS16 returns a new PrimitiveS16 func NewPrimitiveS16(i16 int16) *PrimitiveS16 { - s16 := PrimitiveS16(i16) - - return &s16 + return &PrimitiveS16{Value: i16} } diff --git a/types/primitive_s32.go b/types/primitive_s32.go index 919f8bb5..a7fe4b67 100644 --- a/types/primitive_s32.go +++ b/types/primitive_s32.go @@ -1,13 +1,13 @@ package types -// TODO - Should this have a "Value"-kind of method to get the original value? - -// PrimitiveS32 is a type alias of int32 with receiver methods to conform to RVType -type PrimitiveS32 int32 // TODO - Should we make this a struct instead of a type alias? +// PrimitiveS32 is a struct of int32 with receiver methods to conform to RVType +type PrimitiveS32 struct { + Value int32 +} // WriteTo writes the int32 to the given writable func (s32 *PrimitiveS32) WriteTo(writable Writable) { - writable.WritePrimitiveInt32LE(int32(*s32)) + writable.WritePrimitiveInt32LE(s32.Value) } // ExtractFrom extracts the int32 to the given readable @@ -17,16 +17,14 @@ func (s32 *PrimitiveS32) ExtractFrom(readable Readable) error { return err } - *s32 = PrimitiveS32(value) + s32.Value = value return nil } // Copy returns a pointer to a copy of the int32. Requires type assertion when used func (s32 *PrimitiveS32) Copy() RVType { - copied := PrimitiveS32(*s32) - - return &copied + return NewPrimitiveS32(s32.Value) } // Equals checks if the input is equal in value to the current instance @@ -35,12 +33,10 @@ func (s32 *PrimitiveS32) Equals(o RVType) bool { return false } - return *s32 == *o.(*PrimitiveS32) + return s32.Value == o.(*PrimitiveS32).Value } // NewPrimitiveS32 returns a new PrimitiveS32 func NewPrimitiveS32(i32 int32) *PrimitiveS32 { - s32 := PrimitiveS32(i32) - - return &s32 + return &PrimitiveS32{Value: i32} } diff --git a/types/primitive_s64.go b/types/primitive_s64.go index cc91c33a..39ec0119 100644 --- a/types/primitive_s64.go +++ b/types/primitive_s64.go @@ -1,13 +1,13 @@ package types -// TODO - Should this have a "Value"-kind of method to get the original value? - -// PrimitiveS64 is a type alias of int64 with receiver methods to conform to RVType -type PrimitiveS64 int64 // TODO - Should we make this a struct instead of a type alias? +// PrimitiveS64 is a struct of int64 with receiver methods to conform to RVType +type PrimitiveS64 struct { + Value int64 +} // WriteTo writes the int64 to the given writable func (s64 *PrimitiveS64) WriteTo(writable Writable) { - writable.WritePrimitiveInt64LE(int64(*s64)) + writable.WritePrimitiveInt64LE(s64.Value) } // ExtractFrom extracts the int64 to the given readable @@ -17,16 +17,14 @@ func (s64 *PrimitiveS64) ExtractFrom(readable Readable) error { return err } - *s64 = PrimitiveS64(value) + s64.Value = value return nil } // Copy returns a pointer to a copy of the int64. Requires type assertion when used func (s64 *PrimitiveS64) Copy() RVType { - copied := PrimitiveS64(*s64) - - return &copied + return NewPrimitiveS64(s64.Value) } // Equals checks if the input is equal in value to the current instance @@ -35,12 +33,10 @@ func (s64 *PrimitiveS64) Equals(o RVType) bool { return false } - return *s64 == *o.(*PrimitiveS64) + return s64.Value == o.(*PrimitiveS64).Value } // NewPrimitiveS64 returns a new PrimitiveS64 func NewPrimitiveS64(i64 int64) *PrimitiveS64 { - s64 := PrimitiveS64(i64) - - return &s64 + return &PrimitiveS64{Value: i64} } diff --git a/types/primitive_s8.go b/types/primitive_s8.go index 5e15f6fd..432a9677 100644 --- a/types/primitive_s8.go +++ b/types/primitive_s8.go @@ -1,13 +1,13 @@ package types -// TODO - Should this have a "Value"-kind of method to get the original value? - -// PrimitiveS8 is a type alias of int8 with receiver methods to conform to RVType -type PrimitiveS8 int8 // TODO - Should we make this a struct instead of a type alias? +// PrimitiveS8 is a struct of int8 with receiver methods to conform to RVType +type PrimitiveS8 struct { + Value int8 +} // WriteTo writes the int8 to the given writable func (s8 *PrimitiveS8) WriteTo(writable Writable) { - writable.WritePrimitiveInt8(int8(*s8)) + writable.WritePrimitiveInt8(s8.Value) } // ExtractFrom extracts the int8 to the given readable @@ -17,16 +17,14 @@ func (s8 *PrimitiveS8) ExtractFrom(readable Readable) error { return err } - *s8 = PrimitiveS8(value) + s8.Value = value return nil } // Copy returns a pointer to a copy of the int8. Requires type assertion when used func (s8 *PrimitiveS8) Copy() RVType { - copied := PrimitiveS8(*s8) - - return &copied + return NewPrimitiveS8(s8.Value) } // Equals checks if the input is equal in value to the current instance @@ -35,12 +33,10 @@ func (s8 *PrimitiveS8) Equals(o RVType) bool { return false } - return *s8 == *o.(*PrimitiveS8) + return s8.Value == o.(*PrimitiveS8).Value } // NewPrimitiveS8 returns a new PrimitiveS8 func NewPrimitiveS8(i8 int8) *PrimitiveS8 { - s8 := PrimitiveS8(i8) - - return &s8 + return &PrimitiveS8{Value: i8} } diff --git a/types/primitive_u16.go b/types/primitive_u16.go index c70abfe8..83ba608c 100644 --- a/types/primitive_u16.go +++ b/types/primitive_u16.go @@ -1,13 +1,13 @@ package types -// TODO - Should this have a "Value"-kind of method to get the original value? - -// PrimitiveU16 is a type alias of uint16 with receiver methods to conform to RVType -type PrimitiveU16 uint16 // TODO - Should we make this a struct instead of a type alias? +// PrimitiveU16 is a struct of uint16 with receiver methods to conform to RVType +type PrimitiveU16 struct { + Value uint16 +} // WriteTo writes the uint16 to the given writable func (u16 *PrimitiveU16) WriteTo(writable Writable) { - writable.WritePrimitiveUInt16LE(uint16(*u16)) + writable.WritePrimitiveUInt16LE(u16.Value) } // ExtractFrom extracts the uint16 to the given readable @@ -17,16 +17,14 @@ func (u16 *PrimitiveU16) ExtractFrom(readable Readable) error { return err } - *u16 = PrimitiveU16(value) + u16.Value = value return nil } // Copy returns a pointer to a copy of the uint16. Requires type assertion when used func (u16 *PrimitiveU16) Copy() RVType { - copied := PrimitiveU16(*u16) - - return &copied + return NewPrimitiveU16(u16.Value) } // Equals checks if the input is equal in value to the current instance @@ -35,12 +33,10 @@ func (u16 *PrimitiveU16) Equals(o RVType) bool { return false } - return *u16 == *o.(*PrimitiveU16) + return u16.Value == o.(*PrimitiveU16).Value } // NewPrimitiveU16 returns a new PrimitiveU16 func NewPrimitiveU16(ui16 uint16) *PrimitiveU16 { - u16 := PrimitiveU16(ui16) - - return &u16 + return &PrimitiveU16{Value: ui16} } diff --git a/types/primitive_u32.go b/types/primitive_u32.go index 26d08a87..9b0eb244 100644 --- a/types/primitive_u32.go +++ b/types/primitive_u32.go @@ -1,13 +1,13 @@ package types -// TODO - Should this have a "Value"-kind of method to get the original value? - -// PrimitiveU32 is a type alias of uint32 with receiver methods to conform to RVType -type PrimitiveU32 uint32 // TODO - Should we make this a struct instead of a type alias? +// PrimitiveU32 is a struct of uint32 with receiver methods to conform to RVType +type PrimitiveU32 struct { + Value uint32 +} // WriteTo writes the uint32 to the given writable func (u32 *PrimitiveU32) WriteTo(writable Writable) { - writable.WritePrimitiveUInt32LE(uint32(*u32)) + writable.WritePrimitiveUInt32LE(u32.Value) } // ExtractFrom extracts the uint32 to the given readable @@ -17,16 +17,14 @@ func (u32 *PrimitiveU32) ExtractFrom(readable Readable) error { return err } - *u32 = PrimitiveU32(value) + u32.Value = value return nil } // Copy returns a pointer to a copy of the uint32. Requires type assertion when used func (u32 *PrimitiveU32) Copy() RVType { - copied := PrimitiveU32(*u32) - - return &copied + return NewPrimitiveU32(u32.Value) } // Equals checks if the input is equal in value to the current instance @@ -35,12 +33,10 @@ func (u32 *PrimitiveU32) Equals(o RVType) bool { return false } - return *u32 == *o.(*PrimitiveU32) + return u32.Value == o.(*PrimitiveU32).Value } // NewPrimitiveU32 returns a new PrimitiveU32 func NewPrimitiveU32(ui32 uint32) *PrimitiveU32 { - u32 := PrimitiveU32(ui32) - - return &u32 + return &PrimitiveU32{Value: ui32} } diff --git a/types/primitive_u64.go b/types/primitive_u64.go index 7c3c654d..8f28c27d 100644 --- a/types/primitive_u64.go +++ b/types/primitive_u64.go @@ -1,13 +1,13 @@ package types -// TODO - Should this have a "Value"-kind of method to get the original value? - -// PrimitiveU64 is a type alias of uint64 with receiver methods to conform to RVType -type PrimitiveU64 uint64 // TODO - Should we make this a struct instead of a type alias? +// PrimitiveU64 is a struct of uint64 with receiver methods to conform to RVType +type PrimitiveU64 struct { + Value uint64 +} // WriteTo writes the uint64 to the given writable func (u64 *PrimitiveU64) WriteTo(writable Writable) { - writable.WritePrimitiveUInt64LE(uint64(*u64)) + writable.WritePrimitiveUInt64LE(u64.Value) } // ExtractFrom extracts the uint64 to the given readable @@ -17,16 +17,14 @@ func (u64 *PrimitiveU64) ExtractFrom(readable Readable) error { return err } - *u64 = PrimitiveU64(value) + u64.Value = value return nil } // Copy returns a pointer to a copy of the uint64. Requires type assertion when used func (u64 *PrimitiveU64) Copy() RVType { - copied := PrimitiveU64(*u64) - - return &copied + return NewPrimitiveU64(u64.Value) } // Equals checks if the input is equal in value to the current instance @@ -35,12 +33,10 @@ func (u64 *PrimitiveU64) Equals(o RVType) bool { return false } - return *u64 == *o.(*PrimitiveU64) + return u64.Value == o.(*PrimitiveU64).Value } // NewPrimitiveU64 returns a new PrimitiveU64 func NewPrimitiveU64(ui64 uint64) *PrimitiveU64 { - u64 := PrimitiveU64(ui64) - - return &u64 + return &PrimitiveU64{Value: ui64} } diff --git a/types/primitive_u8.go b/types/primitive_u8.go index aef73f7d..9492b8d4 100644 --- a/types/primitive_u8.go +++ b/types/primitive_u8.go @@ -1,13 +1,14 @@ package types -// TODO - Should this have a "Value"-kind of method to get the original value? -// PrimitiveU8 is a type alias of uint8 with receiver methods to conform to RVType -type PrimitiveU8 uint8 // TODO - Should we make this a struct instead of a type alias? +// PrimitiveU8 is a struct of uint8 with receiver methods to conform to RVType +type PrimitiveU8 struct { + Value uint8 +} // WriteTo writes the uint8 to the given writable func (u8 *PrimitiveU8) WriteTo(writable Writable) { - writable.WritePrimitiveUInt8(uint8(*u8)) + writable.WritePrimitiveUInt8(u8.Value) } // ExtractFrom extracts the uint8 to the given readable @@ -17,16 +18,14 @@ func (u8 *PrimitiveU8) ExtractFrom(readable Readable) error { return err } - *u8 = PrimitiveU8(value) + u8.Value = value return nil } // Copy returns a pointer to a copy of the uint8. Requires type assertion when used func (u8 *PrimitiveU8) Copy() RVType { - copied := PrimitiveU8(*u8) - - return &copied + return NewPrimitiveU8(u8.Value) } // Equals checks if the input is equal in value to the current instance @@ -35,12 +34,10 @@ func (u8 *PrimitiveU8) Equals(o RVType) bool { return false } - return *u8 == *o.(*PrimitiveU8) + return u8.Value == o.(*PrimitiveU8).Value } // NewPrimitiveU8 returns a new PrimitiveU8 func NewPrimitiveU8(ui8 uint8) *PrimitiveU8 { - u8 := PrimitiveU8(ui8) - - return &u8 + return &PrimitiveU8{Value: ui8} } diff --git a/types/qbuffer.go b/types/qbuffer.go index e9b38eec..0da095e0 100644 --- a/types/qbuffer.go +++ b/types/qbuffer.go @@ -1,24 +1,23 @@ package types -// TODO - Should this have a "Value"-kind of method to get the original value? - import ( "bytes" "fmt" ) -// QBuffer is a type alias of []byte with receiver methods to conform to RVType -type QBuffer []byte // TODO - Should we make this a struct instead of a type alias? +// QBuffer is a struct of []byte with receiver methods to conform to RVType +type QBuffer struct { + Value []byte +} // WriteTo writes the []byte to the given writable func (qb *QBuffer) WriteTo(writable Writable) { - data := *qb - length := len(data) + length := len(qb.Value) writable.WritePrimitiveUInt16LE(uint16(length)) if length > 0 { - writable.Write([]byte(data)) + writable.Write(qb.Value) } } @@ -34,16 +33,14 @@ func (qb *QBuffer) ExtractFrom(readable Readable) error { return fmt.Errorf("Failed to read NEX qBuffer data. %s", err.Error()) } - *qb = QBuffer(data) + qb.Value = data return nil } // Copy returns a pointer to a copy of the qBuffer. Requires type assertion when used func (qb *QBuffer) Copy() RVType { - copied := QBuffer(*qb) - - return &copied + return NewQBuffer(qb.Value) } // Equals checks if the input is equal in value to the current instance @@ -52,12 +49,10 @@ func (qb *QBuffer) Equals(o RVType) bool { return false } - return bytes.Equal([]byte(*qb), []byte(*o.(*Buffer))) + return bytes.Equal(qb.Value, o.(*QBuffer).Value) } // NewQBuffer returns a new QBuffer func NewQBuffer(data []byte) *QBuffer { - var qb QBuffer = data - - return &qb + return &QBuffer{Value: data} } diff --git a/types/quuid.go b/types/quuid.go index 8b513644..d4a99e57 100644 --- a/types/quuid.go +++ b/types/quuid.go @@ -114,7 +114,6 @@ func (qu *QUUID) GetStringValue() string { // FromString converts a UUID string to a qUUID func (qu *QUUID) FromString(uuid string) error { - sections := strings.Split(uuid, "-") if len(sections) != 5 { return fmt.Errorf("Invalid UUID. Not enough sections. Expected 5, got %d", len(sections)) diff --git a/types/result_range.go b/types/result_range.go index fecef3b2..f8fb12e7 100644 --- a/types/result_range.go +++ b/types/result_range.go @@ -1,67 +1,48 @@ package types import ( - "errors" "fmt" ) // ResultRange class which holds information about how to make queries type ResultRange struct { Structure - Offset uint32 // TODO - Replace this with PrimitiveU32? - Length uint32 // TODO - Replace this with PrimitiveU32? + Offset *PrimitiveU32 + Length *PrimitiveU32 } // WriteTo writes the ResultRange to the given writable func (rr *ResultRange) WriteTo(writable Writable) { contentWritable := writable.CopyNew() - contentWritable.WritePrimitiveUInt32LE(rr.Offset) - contentWritable.WritePrimitiveUInt32LE(rr.Length) + rr.Offset.WriteTo(contentWritable) + rr.Length.WriteTo(contentWritable) content := contentWritable.Bytes() - if writable.UseStructureHeader() { - writable.WritePrimitiveUInt8(rr.StructureVersion()) - writable.WritePrimitiveUInt32LE(uint32(len(content))) - } + rr.WriteHeaderTo(writable, uint32(len(content))) writable.Write(content) } // ExtractFrom extracts the ResultRange to the given readable func (rr *ResultRange) ExtractFrom(readable Readable) error { - if readable.UseStructureHeader() { - version, err := readable.ReadPrimitiveUInt8() - if err != nil { - return fmt.Errorf("Failed to read ResultRange version. %s", err.Error()) - } - - contentLength, err := readable.ReadPrimitiveUInt32LE() - if err != nil { - return fmt.Errorf("Failed to read ResultRange content length. %s", err.Error()) - } - - if readable.Remaining() < uint64(contentLength) { - return errors.New("ResultRange content length longer than data size") - } - - rr.SetStructureVersion(version) + var err error + + if err = rr.ExtractHeaderFrom(readable); err != nil { + return fmt.Errorf("Failed to read ResultRange header. %s", err.Error()) } - offset, err := readable.ReadPrimitiveUInt32LE() + err = rr.Offset.ExtractFrom(readable) if err != nil { - return fmt.Errorf("Failed to read ResultRange offset. %s", err.Error()) + return fmt.Errorf("Failed to read ResultRange.Offset. %s", err.Error()) } - length, err := readable.ReadPrimitiveUInt32LE() + err = rr.Length.ExtractFrom(readable) if err != nil { - return fmt.Errorf("Failed to read ResultRange length. %s", err.Error()) + return fmt.Errorf("Failed to read ResultRange.Length. %s", err.Error()) } - rr.Offset = offset - rr.Length = length - return nil } @@ -69,9 +50,9 @@ func (rr *ResultRange) ExtractFrom(readable Readable) error { func (rr *ResultRange) Copy() RVType { copied := NewResultRange() - copied.structureVersion = rr.structureVersion - copied.Offset = rr.Offset - copied.Length = rr.Length + copied.StructureVersion = rr.StructureVersion + copied.Offset = rr.Offset.Copy().(*PrimitiveU32) + copied.Length = rr.Length.Copy().(*PrimitiveU32) return copied } @@ -84,19 +65,21 @@ func (rr *ResultRange) Equals(o RVType) bool { other := o.(*ResultRange) - if rr.structureVersion != other.structureVersion { + if rr.StructureVersion != other.StructureVersion { return false } - if rr.Offset != other.Offset { + if !rr.Offset.Equals(other.Offset) { return false } - return rr.Length == other.Length + return rr.Length.Equals(other.Length) } -// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? // NewResultRange returns a new ResultRange func NewResultRange() *ResultRange { - return &ResultRange{} + return &ResultRange{ + Offset: NewPrimitiveU32(0), + Length: NewPrimitiveU32(0), + } } diff --git a/types/rv_connection_data.go b/types/rv_connection_data.go index 4b8c405a..eec97487 100644 --- a/types/rv_connection_data.go +++ b/types/rv_connection_data.go @@ -1,7 +1,6 @@ package types import ( - "errors" "fmt" ) @@ -22,70 +21,46 @@ func (rvcd *RVConnectionData) WriteTo(writable Writable) { rvcd.SpecialProtocols.WriteTo(contentWritable) rvcd.StationURLSpecialProtocols.WriteTo(contentWritable) - if rvcd.structureVersion >= 1 { + if rvcd.StructureVersion >= 1 { rvcd.Time.WriteTo(contentWritable) } content := contentWritable.Bytes() - if writable.UseStructureHeader() { - writable.WritePrimitiveUInt8(rvcd.StructureVersion()) - writable.WritePrimitiveUInt32LE(uint32(len(content))) - } + rvcd.WriteHeaderTo(writable, uint32(len(content))) writable.Write(content) } // ExtractFrom extracts the RVConnectionData to the given readable func (rvcd *RVConnectionData) ExtractFrom(readable Readable) error { - if readable.UseStructureHeader() { - version, err := readable.ReadPrimitiveUInt8() - if err != nil { - return fmt.Errorf("Failed to read RVConnectionData version. %s", err.Error()) - } - - contentLength, err := readable.ReadPrimitiveUInt32LE() - if err != nil { - return fmt.Errorf("Failed to read RVConnectionData content length. %s", err.Error()) - } - - if readable.Remaining() < uint64(contentLength) { - return errors.New("RVConnectionData content length longer than data size") - } - - rvcd.SetStructureVersion(version) + var err error + if err = rvcd.ExtractHeaderFrom(readable); err != nil { + return fmt.Errorf("Failed to read RVConnectionData header. %s", err.Error()) } - var stationURL *StationURL - specialProtocols := NewList[*PrimitiveU8]() - var stationURLSpecialProtocols *StationURL - var time *DateTime - - specialProtocols.Type = NewPrimitiveU8(0) - - if err := stationURL.ExtractFrom(readable); err != nil { - return fmt.Errorf("Failed to read RVConnectionData StationURL. %s", err.Error()) + err = rvcd.StationURL.ExtractFrom(readable) + if err != nil { + return fmt.Errorf("Failed to read RVConnectionData.StationURL. %s", err.Error()) } - if err := specialProtocols.ExtractFrom(readable); err != nil { - return fmt.Errorf("Failed to read SpecialProtocols StationURL. %s", err.Error()) + err = rvcd.SpecialProtocols.ExtractFrom(readable) + if err != nil { + return fmt.Errorf("Failed to read RVConnectionData.SpecialProtocols. %s", err.Error()) } - if err := stationURLSpecialProtocols.ExtractFrom(readable); err != nil { - return fmt.Errorf("Failed to read StationURLSpecialProtocols StationURL. %s", err.Error()) + err = rvcd.StationURLSpecialProtocols.ExtractFrom(readable) + if err != nil { + return fmt.Errorf("Failed to read RVConnectionData.StationURLSpecialProtocols. %s", err.Error()) } - if rvcd.structureVersion >= 1 { - if err := time.ExtractFrom(readable); err != nil { - return fmt.Errorf("Failed to read Time StationURL. %s", err.Error()) + if rvcd.StructureVersion >= 1 { + err := rvcd.Time.ExtractFrom(readable) + if err != nil { + return fmt.Errorf("Failed to read RVConnectionData.Time. %s", err.Error()) } } - rvcd.StationURL = stationURL - rvcd.SpecialProtocols = specialProtocols - rvcd.StationURLSpecialProtocols = stationURLSpecialProtocols - rvcd.Time = time - return nil } @@ -93,12 +68,12 @@ func (rvcd *RVConnectionData) ExtractFrom(readable Readable) error { func (rvcd *RVConnectionData) Copy() RVType { copied := NewRVConnectionData() - copied.structureVersion = rvcd.structureVersion + copied.StructureVersion = rvcd.StructureVersion copied.StationURL = rvcd.StationURL.Copy().(*StationURL) copied.SpecialProtocols = rvcd.SpecialProtocols.Copy().(*List[*PrimitiveU8]) copied.StationURLSpecialProtocols = rvcd.StationURLSpecialProtocols.Copy().(*StationURL) - if rvcd.structureVersion >= 1 { + if rvcd.StructureVersion >= 1 { copied.Time = rvcd.Time.Copy().(*DateTime) } @@ -113,7 +88,7 @@ func (rvcd *RVConnectionData) Equals(o RVType) bool { other := o.(*RVConnectionData) - if rvcd.structureVersion != other.structureVersion { + if rvcd.StructureVersion != other.StructureVersion { return false } @@ -129,7 +104,7 @@ func (rvcd *RVConnectionData) Equals(o RVType) bool { return false } - if rvcd.structureVersion >= 1 { + if rvcd.StructureVersion >= 1 { if !rvcd.Time.Equals(other.Time) { return false } @@ -138,8 +113,16 @@ func (rvcd *RVConnectionData) Equals(o RVType) bool { return true } -// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? // NewRVConnectionData returns a new RVConnectionData func NewRVConnectionData() *RVConnectionData { - return &RVConnectionData{} + rvcd := &RVConnectionData{ + StationURL: NewStationURL(""), + SpecialProtocols: NewList[*PrimitiveU8](), + StationURLSpecialProtocols: NewStationURL(""), + Time: NewDateTime(0), + } + + rvcd.SpecialProtocols.Type = NewPrimitiveU8(0) + + return rvcd } diff --git a/types/station_url.go b/types/station_url.go index 348546ea..c588e8a3 100644 --- a/types/station_url.go +++ b/types/station_url.go @@ -15,20 +15,20 @@ type StationURL struct { // WriteTo writes the StationURL to the given writable func (s *StationURL) WriteTo(writable Writable) { - str := String(s.EncodeToString()) + str := NewString(s.EncodeToString()) str.WriteTo(writable) } // ExtractFrom extracts the StationURL to the given readable func (s *StationURL) ExtractFrom(readable Readable) error { - var str String + str := NewString("") if err := str.ExtractFrom(readable); err != nil { return fmt.Errorf("Failed to read StationURL. %s", err.Error()) } - s.FromString(string(str)) + s.FromString(str.Value) return nil } diff --git a/types/string.go b/types/string.go index 06c095aa..e9e201a5 100644 --- a/types/string.go +++ b/types/string.go @@ -1,19 +1,19 @@ package types -// TODO - Should this have a "Value"-kind of method to get the original value? - import ( "errors" "fmt" "strings" ) -// String is a type alias of string with receiver methods to conform to RVType -type String string // TODO - Should we make this a struct instead of a type alias? +// String is a struct of string with receiver methods to conform to RVType +type String struct { + Value string +} // WriteTo writes the String to the given writable func (s *String) WriteTo(writable Writable) { - str := *s + "\x00" + str := s.Value + "\x00" strLength := len(str) if writable.StringLengthSize() == 4 { @@ -55,16 +55,14 @@ func (s *String) ExtractFrom(readable Readable) error { str := strings.TrimRight(string(stringData), "\x00") - *s = String(str) + s.Value = str return nil } // Copy returns a pointer to a copy of the String. Requires type assertion when used func (s *String) Copy() RVType { - copied := String(*s) - - return &copied + return NewString(s.Value) } // Equals checks if the input is equal in value to the current instance @@ -73,12 +71,10 @@ func (s *String) Equals(o RVType) bool { return false } - return *s == *o.(*String) + return s.Value == o.(*String).Value } // NewString returns a new String func NewString(str string) *String { - s := String(str) - - return &s + return &String{Value: str} } diff --git a/types/structure.go b/types/structure.go index 5c4b3006..4d834c95 100644 --- a/types/structure.go +++ b/types/structure.go @@ -1,39 +1,43 @@ package types -// StructureInterface implements all Structure methods -type StructureInterface interface { - SetParentType(parentType StructureInterface) - ParentType() StructureInterface - SetStructureVersion(version uint8) - StructureVersion() uint8 - Copy() StructureInterface - Equals(other StructureInterface) bool - FormatToString(indentationLevel int) string -} +import ( + "errors" + "fmt" +) // Structure represents a Quazal Rendez-Vous/NEX Structure (custom class) base struct type Structure struct { - parentType StructureInterface - structureVersion uint8 - StructureInterface + ParentType RVType + StructureVersion uint8 } -// SetParentType sets the Structures parent type -func (s *Structure) SetParentType(parentType StructureInterface) { - s.parentType = parentType -} +// ExtractHeaderFrom extracts the structure header from the given readable +func (s *Structure) ExtractHeaderFrom(readable Readable) error { + if readable.UseStructureHeader() { + version, err := readable.ReadPrimitiveUInt8() + if err != nil { + return fmt.Errorf("Failed to read Structure version. %s", err.Error()) + } -// ParentType returns the Structures parent type. nil if the Structure does not inherit another Structure -func (s *Structure) ParentType() StructureInterface { - return s.parentType -} + contentLength, err := readable.ReadPrimitiveUInt32LE() + if err != nil { + return fmt.Errorf("Failed to read Structure content length. %s", err.Error()) + } + + if readable.Remaining() < uint64(contentLength) { + return errors.New("Structure content length longer than data size") + } + + s.StructureVersion = version + } -// SetStructureVersion sets the structures version. Only used in NEX 3.5+ -func (s *Structure) SetStructureVersion(version uint8) { - s.structureVersion = version + return nil } -// StructureVersion returns the structures version. Only used in NEX 3.5+ -func (s *Structure) StructureVersion() uint8 { - return s.structureVersion +// WriteHeaderTo writes the structure header to the given writable +func (s *Structure) WriteHeaderTo(writable Writable, contentLength uint32) { + if writable.UseStructureHeader() { + writable.WritePrimitiveUInt8(s.StructureVersion) + writable.WritePrimitiveUInt32LE(contentLength) + } } diff --git a/types/variant.go b/types/variant.go index fc39571c..13b6ac1a 100644 --- a/types/variant.go +++ b/types/variant.go @@ -14,7 +14,7 @@ func RegisterVariantType(id uint8, rvType RVType) { // Variant is a type which can old many other types type Variant struct { - TypeID uint8 // TODO - Replace this with PrimitiveU8? + TypeID uint8 Type RVType } From 03434ad7772debbc76b2c1e6a5e635121ea26c81 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20L=C3=B3pez=20Guimaraes?= Date: Wed, 27 Dec 2023 23:07:51 +0000 Subject: [PATCH 100/178] types: Use RV types for AnyDataHolder and Variant Also fix a typo in the `ExtractFrom` funcions. --- types/any_data_holder.go | 57 ++++++++++++++++---------------- types/buffer.go | 2 +- types/class_version_container.go | 2 +- types/data.go | 2 +- types/datetime.go | 2 +- types/list.go | 2 +- types/map.go | 2 +- types/pid.go | 2 +- types/primitive_bool.go | 2 +- types/primitive_float32.go | 2 +- types/primitive_float64.go | 2 +- types/primitive_s16.go | 2 +- types/primitive_s32.go | 2 +- types/primitive_s64.go | 2 +- types/primitive_s8.go | 2 +- types/primitive_u16.go | 2 +- types/primitive_u32.go | 2 +- types/primitive_u64.go | 2 +- types/primitive_u8.go | 2 +- types/qbuffer.go | 2 +- types/quuid.go | 2 +- types/result.go | 2 +- types/result_range.go | 2 +- types/rv_connection_data.go | 2 +- types/station_url.go | 2 +- types/string.go | 2 +- types/variant.go | 22 ++++++------ 27 files changed, 64 insertions(+), 65 deletions(-) diff --git a/types/any_data_holder.go b/types/any_data_holder.go index 89fa487e..e036f841 100644 --- a/types/any_data_holder.go +++ b/types/any_data_holder.go @@ -16,62 +16,57 @@ func RegisterDataHolderType(name string, rvType RVType) { // other Structure. Typically this base class is the empty `Data` Structure, but this is not always the case. // The contained Structures name & length are sent with the Structure body, so the receiver can properly decode it type AnyDataHolder struct { - TypeName string - Length1 uint32 - Length2 uint32 + TypeName *String + Length1 *PrimitiveU32 + Length2 *PrimitiveU32 ObjectData RVType } -// WriteTo writes the AnyDataholder to the given writable +// WriteTo writes the AnyDataHolder to the given writable func (adh *AnyDataHolder) WriteTo(writable Writable) { contentWritable := writable.CopyNew() adh.ObjectData.WriteTo(contentWritable) objectData := contentWritable.Bytes() - typeName := NewString(adh.TypeName) length1 := uint32(len(objectData) + 4) length2 := uint32(len(objectData)) - typeName.WriteTo(writable) + adh.TypeName.WriteTo(writable) writable.WritePrimitiveUInt32LE(length1) writable.WritePrimitiveUInt32LE(length2) writable.Write(objectData) } -// ExtractFrom extracts the AnyDataholder to the given readable +// ExtractFrom extracts the AnyDataHolder from the given readable func (adh *AnyDataHolder) ExtractFrom(readable Readable) error { - typeName := NewString("") + var err error - err := typeName.ExtractFrom(readable) + err = adh.TypeName.ExtractFrom(readable) if err != nil { - return fmt.Errorf("Failed to read DanyDataHolder type name. %s", err.Error()) + return fmt.Errorf("Failed to read AnyDataHolder type name. %s", err.Error()) } - length1, err := readable.ReadPrimitiveUInt32LE() + err = adh.Length1.ExtractFrom(readable) if err != nil { - return fmt.Errorf("Failed to read DanyDataHolder length 1. %s", err.Error()) + return fmt.Errorf("Failed to read AnyDataHolder length 1. %s", err.Error()) } - length2, err := readable.ReadPrimitiveUInt32LE() + err = adh.Length2.ExtractFrom(readable) if err != nil { - return fmt.Errorf("Failed to read DanyDataHolder length 2. %s", err.Error()) + return fmt.Errorf("Failed to read AnyDataHolder length 2. %s", err.Error()) } - if _, ok := AnyDataHolderObjects[typeName.Value]; !ok { - return fmt.Errorf("Unknown AnyDataHolder type: %s", typeName.Value) + if _, ok := AnyDataHolderObjects[adh.TypeName.Value]; !ok { + return fmt.Errorf("Unknown AnyDataHolder type: %s", adh.TypeName.Value) } - adh.ObjectData = AnyDataHolderObjects[typeName.Value].Copy() + adh.ObjectData = AnyDataHolderObjects[adh.TypeName.Value].Copy() if err := adh.ObjectData.ExtractFrom(readable); err != nil { - return fmt.Errorf("Failed to read DanyDataHolder object data. %s", err.Error()) + return fmt.Errorf("Failed to read AnyDataHolder object data. %s", err.Error()) } - adh.TypeName = typeName.Value - adh.Length1 = length1 - adh.Length2 = length2 - return nil } @@ -79,9 +74,9 @@ func (adh *AnyDataHolder) ExtractFrom(readable Readable) error { func (adh *AnyDataHolder) Copy() RVType { copied := NewAnyDataHolder() - copied.TypeName = adh.TypeName - copied.Length1 = adh.Length1 - copied.Length2 = adh.Length2 + copied.TypeName = adh.TypeName.Copy().(*String) + copied.Length1 = adh.Length1.Copy().(*PrimitiveU32) + copied.Length2 = adh.Length2.Copy().(*PrimitiveU32) copied.ObjectData = adh.ObjectData.Copy() return copied @@ -95,15 +90,15 @@ func (adh *AnyDataHolder) Equals(o RVType) bool { other := o.(*AnyDataHolder) - if adh.TypeName != other.TypeName { + if !adh.TypeName.Equals(other.TypeName) { return false } - if adh.Length1 != other.Length1 { + if !adh.Length1.Equals(other.Length1) { return false } - if adh.Length2 != other.Length2 { + if !adh.Length2.Equals(other.Length2) { return false } @@ -113,5 +108,9 @@ func (adh *AnyDataHolder) Equals(o RVType) bool { // TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? // NewAnyDataHolder returns a new AnyDataHolder func NewAnyDataHolder() *AnyDataHolder { - return &AnyDataHolder{} + return &AnyDataHolder{ + TypeName: NewString(""), + Length1: NewPrimitiveU32(0), + Length2: NewPrimitiveU32(0), + } } diff --git a/types/buffer.go b/types/buffer.go index db149003..d0c4682d 100644 --- a/types/buffer.go +++ b/types/buffer.go @@ -21,7 +21,7 @@ func (b *Buffer) WriteTo(writable Writable) { } } -// ExtractFrom extracts the Buffer to the given readable +// ExtractFrom extracts the Buffer from the given readable func (b *Buffer) ExtractFrom(readable Readable) error { length, err := readable.ReadPrimitiveUInt32LE() if err != nil { diff --git a/types/class_version_container.go b/types/class_version_container.go index a15e8292..9e2e18d4 100644 --- a/types/class_version_container.go +++ b/types/class_version_container.go @@ -11,7 +11,7 @@ func (cvc *ClassVersionContainer) WriteTo(writable Writable) { cvc.ClassVersions.WriteTo(writable) } -// ExtractFrom extracts the ClassVersionContainer to the given readable +// ExtractFrom extracts the ClassVersionContainer from the given readable func (cvc *ClassVersionContainer) ExtractFrom(readable Readable) error { return cvc.ClassVersions.ExtractFrom(readable) } diff --git a/types/data.go b/types/data.go index 3f399327..3fc2fd51 100644 --- a/types/data.go +++ b/types/data.go @@ -15,7 +15,7 @@ func (e *Data) WriteTo(writable Writable) { e.WriteHeaderTo(writable, 0) } -// ExtractFrom extracts the Data to the given readable +// ExtractFrom extracts the Data from the given readable func (e *Data) ExtractFrom(readable Readable) error { if err := e.ExtractHeaderFrom(readable); err != nil { return fmt.Errorf("Failed to read Data header. %s", err.Error()) diff --git a/types/datetime.go b/types/datetime.go index fb890873..9b4bef98 100644 --- a/types/datetime.go +++ b/types/datetime.go @@ -16,7 +16,7 @@ func (dt *DateTime) WriteTo(writable Writable) { writable.WritePrimitiveUInt64LE(dt.value) } -// ExtractFrom extracts the DateTime to the given readable +// ExtractFrom extracts the DateTime from the given readable func (dt *DateTime) ExtractFrom(readable Readable) error { value, err := readable.ReadPrimitiveUInt64LE() if err != nil { diff --git a/types/list.go b/types/list.go index 8c9ddfd5..4f4861e4 100644 --- a/types/list.go +++ b/types/list.go @@ -17,7 +17,7 @@ func (l *List[T]) WriteTo(writable Writable) { } } -// ExtractFrom extracts the bool to the given readable +// ExtractFrom extracts the bool from the given readable func (l *List[T]) ExtractFrom(readable Readable) error { length, err := readable.ReadPrimitiveUInt32LE() if err != nil { diff --git a/types/map.go b/types/map.go index 32dc6373..fc321212 100644 --- a/types/map.go +++ b/types/map.go @@ -22,7 +22,7 @@ func (m *Map[K, V]) WriteTo(writable Writable) { } } -// ExtractFrom extracts the bool to the given readable +// ExtractFrom extracts the bool from the given readable func (m *Map[K, V]) ExtractFrom(readable Readable) error { length, err := readable.ReadPrimitiveUInt32LE() if err != nil { diff --git a/types/pid.go b/types/pid.go index d214ddef..a3310fcd 100644 --- a/types/pid.go +++ b/types/pid.go @@ -23,7 +23,7 @@ func (p *PID) WriteTo(writable Writable) { } } -// ExtractFrom extracts the bool to the given readable +// ExtractFrom extracts the bool from the given readable func (p *PID) ExtractFrom(readable Readable) error { var pid uint64 var err error diff --git a/types/primitive_bool.go b/types/primitive_bool.go index 17cc41de..8a37a28b 100644 --- a/types/primitive_bool.go +++ b/types/primitive_bool.go @@ -10,7 +10,7 @@ func (b *PrimitiveBool) WriteTo(writable Writable) { writable.WritePrimitiveBool(b.Value) } -// ExtractFrom extracts the bool to the given readable +// ExtractFrom extracts the bool from the given readable func (b *PrimitiveBool) ExtractFrom(readable Readable) error { value, err := readable.ReadPrimitiveBool() if err != nil { diff --git a/types/primitive_float32.go b/types/primitive_float32.go index b3293702..3215a655 100644 --- a/types/primitive_float32.go +++ b/types/primitive_float32.go @@ -10,7 +10,7 @@ func (f32 *PrimitiveF32) WriteTo(writable Writable) { writable.WritePrimitiveFloat32LE(f32.Value) } -// ExtractFrom extracts the float32 to the given readable +// ExtractFrom extracts the float32 from the given readable func (f32 *PrimitiveF32) ExtractFrom(readable Readable) error { value, err := readable.ReadPrimitiveFloat32LE() if err != nil { diff --git a/types/primitive_float64.go b/types/primitive_float64.go index f988dbed..1477800b 100644 --- a/types/primitive_float64.go +++ b/types/primitive_float64.go @@ -10,7 +10,7 @@ func (f64 *PrimitiveF64) WriteTo(writable Writable) { writable.WritePrimitiveFloat64LE(f64.Value) } -// ExtractFrom extracts the float64 to the given readable +// ExtractFrom extracts the float64 from the given readable func (f64 *PrimitiveF64) ExtractFrom(readable Readable) error { value, err := readable.ReadPrimitiveFloat64LE() if err != nil { diff --git a/types/primitive_s16.go b/types/primitive_s16.go index 956e8143..406bc5ae 100644 --- a/types/primitive_s16.go +++ b/types/primitive_s16.go @@ -10,7 +10,7 @@ func (s16 *PrimitiveS16) WriteTo(writable Writable) { writable.WritePrimitiveInt16LE(s16.Value) } -// ExtractFrom extracts the int16 to the given readable +// ExtractFrom extracts the int16 from the given readable func (s16 *PrimitiveS16) ExtractFrom(readable Readable) error { value, err := readable.ReadPrimitiveInt16LE() if err != nil { diff --git a/types/primitive_s32.go b/types/primitive_s32.go index a7fe4b67..2919963c 100644 --- a/types/primitive_s32.go +++ b/types/primitive_s32.go @@ -10,7 +10,7 @@ func (s32 *PrimitiveS32) WriteTo(writable Writable) { writable.WritePrimitiveInt32LE(s32.Value) } -// ExtractFrom extracts the int32 to the given readable +// ExtractFrom extracts the int32 from the given readable func (s32 *PrimitiveS32) ExtractFrom(readable Readable) error { value, err := readable.ReadPrimitiveInt32LE() if err != nil { diff --git a/types/primitive_s64.go b/types/primitive_s64.go index 39ec0119..c0a07d80 100644 --- a/types/primitive_s64.go +++ b/types/primitive_s64.go @@ -10,7 +10,7 @@ func (s64 *PrimitiveS64) WriteTo(writable Writable) { writable.WritePrimitiveInt64LE(s64.Value) } -// ExtractFrom extracts the int64 to the given readable +// ExtractFrom extracts the int64 from the given readable func (s64 *PrimitiveS64) ExtractFrom(readable Readable) error { value, err := readable.ReadPrimitiveInt64LE() if err != nil { diff --git a/types/primitive_s8.go b/types/primitive_s8.go index 432a9677..c63502ea 100644 --- a/types/primitive_s8.go +++ b/types/primitive_s8.go @@ -10,7 +10,7 @@ func (s8 *PrimitiveS8) WriteTo(writable Writable) { writable.WritePrimitiveInt8(s8.Value) } -// ExtractFrom extracts the int8 to the given readable +// ExtractFrom extracts the int8 from the given readable func (s8 *PrimitiveS8) ExtractFrom(readable Readable) error { value, err := readable.ReadPrimitiveInt8() if err != nil { diff --git a/types/primitive_u16.go b/types/primitive_u16.go index 83ba608c..6cc39ac3 100644 --- a/types/primitive_u16.go +++ b/types/primitive_u16.go @@ -10,7 +10,7 @@ func (u16 *PrimitiveU16) WriteTo(writable Writable) { writable.WritePrimitiveUInt16LE(u16.Value) } -// ExtractFrom extracts the uint16 to the given readable +// ExtractFrom extracts the uint16 from the given readable func (u16 *PrimitiveU16) ExtractFrom(readable Readable) error { value, err := readable.ReadPrimitiveUInt16LE() if err != nil { diff --git a/types/primitive_u32.go b/types/primitive_u32.go index 9b0eb244..866ce724 100644 --- a/types/primitive_u32.go +++ b/types/primitive_u32.go @@ -10,7 +10,7 @@ func (u32 *PrimitiveU32) WriteTo(writable Writable) { writable.WritePrimitiveUInt32LE(u32.Value) } -// ExtractFrom extracts the uint32 to the given readable +// ExtractFrom extracts the uint32 from the given readable func (u32 *PrimitiveU32) ExtractFrom(readable Readable) error { value, err := readable.ReadPrimitiveUInt32LE() if err != nil { diff --git a/types/primitive_u64.go b/types/primitive_u64.go index 8f28c27d..53a71e7f 100644 --- a/types/primitive_u64.go +++ b/types/primitive_u64.go @@ -10,7 +10,7 @@ func (u64 *PrimitiveU64) WriteTo(writable Writable) { writable.WritePrimitiveUInt64LE(u64.Value) } -// ExtractFrom extracts the uint64 to the given readable +// ExtractFrom extracts the uint64 from the given readable func (u64 *PrimitiveU64) ExtractFrom(readable Readable) error { value, err := readable.ReadPrimitiveUInt64LE() if err != nil { diff --git a/types/primitive_u8.go b/types/primitive_u8.go index 9492b8d4..e88ee46b 100644 --- a/types/primitive_u8.go +++ b/types/primitive_u8.go @@ -11,7 +11,7 @@ func (u8 *PrimitiveU8) WriteTo(writable Writable) { writable.WritePrimitiveUInt8(u8.Value) } -// ExtractFrom extracts the uint8 to the given readable +// ExtractFrom extracts the uint8 from the given readable func (u8 *PrimitiveU8) ExtractFrom(readable Readable) error { value, err := readable.ReadPrimitiveUInt8() if err != nil { diff --git a/types/qbuffer.go b/types/qbuffer.go index 0da095e0..97948aae 100644 --- a/types/qbuffer.go +++ b/types/qbuffer.go @@ -21,7 +21,7 @@ func (qb *QBuffer) WriteTo(writable Writable) { } } -// ExtractFrom extracts the QBuffer to the given readable +// ExtractFrom extracts the QBuffer from the given readable func (qb *QBuffer) ExtractFrom(readable Readable) error { length, err := readable.ReadPrimitiveUInt16LE() if err != nil { diff --git a/types/quuid.go b/types/quuid.go index d4a99e57..ea49bfc3 100644 --- a/types/quuid.go +++ b/types/quuid.go @@ -18,7 +18,7 @@ func (qu *QUUID) WriteTo(writable Writable) { writable.Write(qu.Data) } -// ExtractFrom extracts the QUUID to the given readable +// ExtractFrom extracts the QUUID from the given readable func (qu *QUUID) ExtractFrom(readable Readable) error { if readable.Remaining() < uint64(16) { return errors.New("Not enough data left to read qUUID") diff --git a/types/result.go b/types/result.go index a1cf91d7..cd5a8b33 100644 --- a/types/result.go +++ b/types/result.go @@ -17,7 +17,7 @@ func (r *Result) WriteTo(writable Writable) { writable.WritePrimitiveUInt32LE(r.Code) } -// ExtractFrom extracts the Result to the given readable +// ExtractFrom extracts the Result from the given readable func (r *Result) ExtractFrom(readable Readable) error { code, err := readable.ReadPrimitiveUInt32LE() if err != nil { diff --git a/types/result_range.go b/types/result_range.go index f8fb12e7..4cd1d75e 100644 --- a/types/result_range.go +++ b/types/result_range.go @@ -25,7 +25,7 @@ func (rr *ResultRange) WriteTo(writable Writable) { writable.Write(content) } -// ExtractFrom extracts the ResultRange to the given readable +// ExtractFrom extracts the ResultRange from the given readable func (rr *ResultRange) ExtractFrom(readable Readable) error { var err error diff --git a/types/rv_connection_data.go b/types/rv_connection_data.go index eec97487..1113a8b1 100644 --- a/types/rv_connection_data.go +++ b/types/rv_connection_data.go @@ -32,7 +32,7 @@ func (rvcd *RVConnectionData) WriteTo(writable Writable) { writable.Write(content) } -// ExtractFrom extracts the RVConnectionData to the given readable +// ExtractFrom extracts the RVConnectionData from the given readable func (rvcd *RVConnectionData) ExtractFrom(readable Readable) error { var err error if err = rvcd.ExtractHeaderFrom(readable); err != nil { diff --git a/types/station_url.go b/types/station_url.go index c588e8a3..93e252d2 100644 --- a/types/station_url.go +++ b/types/station_url.go @@ -20,7 +20,7 @@ func (s *StationURL) WriteTo(writable Writable) { str.WriteTo(writable) } -// ExtractFrom extracts the StationURL to the given readable +// ExtractFrom extracts the StationURL from the given readable func (s *StationURL) ExtractFrom(readable Readable) error { str := NewString("") diff --git a/types/string.go b/types/string.go index e9e201a5..067ab765 100644 --- a/types/string.go +++ b/types/string.go @@ -25,7 +25,7 @@ func (s *String) WriteTo(writable Writable) { writable.Write([]byte(str)) } -// ExtractFrom extracts the String to the given readable +// ExtractFrom extracts the String from the given readable func (s *String) ExtractFrom(readable Readable) error { var length uint64 var err error diff --git a/types/variant.go b/types/variant.go index 13b6ac1a..0888ef74 100644 --- a/types/variant.go +++ b/types/variant.go @@ -14,30 +14,28 @@ func RegisterVariantType(id uint8, rvType RVType) { // Variant is a type which can old many other types type Variant struct { - TypeID uint8 + TypeID *PrimitiveU8 Type RVType } // WriteTo writes the Variant to the given writable func (v *Variant) WriteTo(writable Writable) { - writable.WritePrimitiveUInt8(v.TypeID) + v.TypeID.WriteTo(writable) v.Type.WriteTo(writable) } -// ExtractFrom extracts the Variant to the given readable +// ExtractFrom extracts the Variant from the given readable func (v *Variant) ExtractFrom(readable Readable) error { - typeID, err := readable.ReadPrimitiveUInt8() + err := v.TypeID.ExtractFrom(readable) if err != nil { return fmt.Errorf("Failed to read Variant type ID. %s", err.Error()) } - v.TypeID = typeID - - if _, ok := VariantTypes[v.TypeID]; !ok { + if _, ok := VariantTypes[v.TypeID.Value]; !ok { return fmt.Errorf("Invalid Variant type ID %d", v.TypeID) } - v.Type = VariantTypes[v.TypeID].Copy() + v.Type = VariantTypes[v.TypeID.Value].Copy() return v.Type.ExtractFrom(readable) } @@ -46,7 +44,7 @@ func (v *Variant) ExtractFrom(readable Readable) error { func (v *Variant) Copy() RVType { copied := NewVariant() - copied.TypeID = v.TypeID + copied.TypeID = v.TypeID.Copy().(*PrimitiveU8) copied.Type = v.Type.Copy() return copied @@ -60,7 +58,7 @@ func (v *Variant) Equals(o RVType) bool { other := o.(*Variant) - if v.TypeID != other.TypeID { + if !v.TypeID.Equals(other.TypeID) { return false } @@ -70,5 +68,7 @@ func (v *Variant) Equals(o RVType) bool { // TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? // NewVariant returns a new Variant func NewVariant() *Variant { - return &Variant{} + return &Variant{ + TypeID: NewPrimitiveU8(0), + } } From 6933349e917653d7e2aae9d7628f77dde51d22bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20L=C3=B3pez=20Guimaraes?= Date: Thu, 28 Dec 2023 00:39:06 +0000 Subject: [PATCH 101/178] structure: Remove ParentType As we know a structure's parent type inside the child structure, this isn't needed anymore now that we are writing all the structure from itself. --- types/structure.go | 1 - 1 file changed, 1 deletion(-) diff --git a/types/structure.go b/types/structure.go index 4d834c95..19ad2388 100644 --- a/types/structure.go +++ b/types/structure.go @@ -7,7 +7,6 @@ import ( // Structure represents a Quazal Rendez-Vous/NEX Structure (custom class) base struct type Structure struct { - ParentType RVType StructureVersion uint8 } From 059da89429791d8853d5c6e65f811f831166a220 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sun, 31 Dec 2023 22:42:00 -0500 Subject: [PATCH 102/178] types: Data e to d --- types/data.go | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/types/data.go b/types/data.go index 3fc2fd51..2262f744 100644 --- a/types/data.go +++ b/types/data.go @@ -11,13 +11,13 @@ type Data struct { } // WriteTo writes the Data to the given writable -func (e *Data) WriteTo(writable Writable) { - e.WriteHeaderTo(writable, 0) +func (d *Data) WriteTo(writable Writable) { + d.WriteHeaderTo(writable, 0) } // ExtractFrom extracts the Data from the given readable -func (e *Data) ExtractFrom(readable Readable) error { - if err := e.ExtractHeaderFrom(readable); err != nil { +func (d *Data) ExtractFrom(readable Readable) error { + if err := d.ExtractHeaderFrom(readable); err != nil { return fmt.Errorf("Failed to read Data header. %s", err.Error()) } @@ -25,36 +25,36 @@ func (e *Data) ExtractFrom(readable Readable) error { } // Copy returns a pointer to a copy of the Data. Requires type assertion when used -func (e *Data) Copy() RVType { +func (d *Data) Copy() RVType { copied := NewData() - copied.StructureVersion = e.StructureVersion + copied.StructureVersion = d.StructureVersion return copied } // Equals checks if the input is equal in value to the current instance -func (e *Data) Equals(o RVType) bool { +func (d *Data) Equals(o RVType) bool { if _, ok := o.(*Data); !ok { return false } - return (*e).StructureVersion == (*o.(*Data)).StructureVersion + return (*d).StructureVersion == (*o.(*Data)).StructureVersion } // String returns a string representation of the struct -func (e *Data) String() string { - return e.FormatToString(0) +func (d *Data) String() string { + return d.FormatToString(0) } // FormatToString pretty-prints the struct data using the provided indentation level -func (e *Data) FormatToString(indentationLevel int) string { +func (d *Data) FormatToString(indentationLevel int) string { indentationValues := strings.Repeat("\t", indentationLevel+1) indentationEnd := strings.Repeat("\t", indentationLevel) var b strings.Builder b.WriteString("Data{\n") - b.WriteString(fmt.Sprintf("%sStructureVersion: %d\n", indentationValues, e.StructureVersion)) + b.WriteString(fmt.Sprintf("%sStructureVersion: %d\n", indentationValues, d.StructureVersion)) b.WriteString(fmt.Sprintf("%s}", indentationEnd)) return b.String() From 1093a614db3c977554e42b19230fcf4244c81b82 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sun, 31 Dec 2023 23:22:41 -0500 Subject: [PATCH 103/178] streams: rename streams to ByteStreams --- byte_stream_in.go | 180 ++++++++++++++++++++++++++++++++++++++++ byte_stream_out.go | 140 +++++++++++++++++++++++++++++++ kerberos.go | 12 +-- prudp_packet.go | 2 +- prudp_packet_lite.go | 10 +-- prudp_packet_v0.go | 6 +- prudp_packet_v1.go | 12 +-- prudp_server.go | 12 +-- rmc_message.go | 12 +-- stream_in.go | 180 ---------------------------------------- stream_out.go | 140 ------------------------------- test/auth.go | 8 +- test/generate_ticket.go | 4 +- test/hpp.go | 4 +- test/secure.go | 8 +- 15 files changed, 365 insertions(+), 365 deletions(-) create mode 100644 byte_stream_in.go create mode 100644 byte_stream_out.go delete mode 100644 stream_in.go delete mode 100644 stream_out.go diff --git a/byte_stream_in.go b/byte_stream_in.go new file mode 100644 index 00000000..6266bc3b --- /dev/null +++ b/byte_stream_in.go @@ -0,0 +1,180 @@ +package nex + +import ( + "errors" + + crunch "github.com/superwhiskers/crunch/v3" +) + +// ByteStreamIn is an input stream abstraction of github.com/superwhiskers/crunch/v3 with nex type support +type ByteStreamIn struct { + *crunch.Buffer + Server ServerInterface +} + +// StringLengthSize returns the expected size of String length fields +func (bsi *ByteStreamIn) StringLengthSize() int { + size := 2 + + if bsi.Server != nil { + size = bsi.Server.StringLengthSize() + } + + return size +} + +// PIDSize returns the size of PID types +func (bsi *ByteStreamIn) PIDSize() int { + size := 4 + + if bsi.Server != nil && bsi.Server.LibraryVersion().GreaterOrEqual("4.0.0") { + size = 8 + } + + return size +} + +// UseStructureHeader determines if Structure headers should be used +func (bsi *ByteStreamIn) UseStructureHeader() bool { + useStructureHeader := false + + if bsi.Server != nil { + switch server := bsi.Server.(type) { + case *PRUDPServer: // * Support QRV versions + useStructureHeader = server.PRUDPMinorVersion >= 3 + default: + useStructureHeader = server.LibraryVersion().GreaterOrEqual("3.5.0") + } + } + + return useStructureHeader +} + +// Remaining returns the amount of data left to be read in the buffer +func (bsi *ByteStreamIn) Remaining() uint64 { + return uint64(len(bsi.Bytes()[bsi.ByteOffset():])) +} + +// ReadRemaining reads all the data left to be read in the buffer +func (bsi *ByteStreamIn) ReadRemaining() []byte { + // * Can safely ignore this error, since bsi.Remaining() will never be less than itself + remaining, _ := bsi.Read(uint64(bsi.Remaining())) + + return remaining +} + +// Read reads the specified number of bytes. Returns an error if OOB +func (bsi *ByteStreamIn) Read(length uint64) ([]byte, error) { + if bsi.Remaining() < length { + return []byte{}, errors.New("Read is OOB") + } + + return bsi.ReadBytesNext(int64(length)), nil +} + +// ReadPrimitiveUInt8 reads a uint8 +func (bsi *ByteStreamIn) ReadPrimitiveUInt8() (uint8, error) { + if bsi.Remaining() < 1 { + return 0, errors.New("Not enough data to read uint8") + } + + return uint8(bsi.ReadByteNext()), nil +} + +// ReadPrimitiveUInt16LE reads a Little-Endian encoded uint16 +func (bsi *ByteStreamIn) ReadPrimitiveUInt16LE() (uint16, error) { + if bsi.Remaining() < 2 { + return 0, errors.New("Not enough data to read uint16") + } + + return bsi.ReadU16LENext(1)[0], nil +} + +// ReadPrimitiveUInt32LE reads a Little-Endian encoded uint32 +func (bsi *ByteStreamIn) ReadPrimitiveUInt32LE() (uint32, error) { + if bsi.Remaining() < 4 { + return 0, errors.New("Not enough data to read uint32") + } + + return bsi.ReadU32LENext(1)[0], nil +} + +// ReadPrimitiveUInt64LE reads a Little-Endian encoded uint64 +func (bsi *ByteStreamIn) ReadPrimitiveUInt64LE() (uint64, error) { + if bsi.Remaining() < 8 { + return 0, errors.New("Not enough data to read uint64") + } + + return bsi.ReadU64LENext(1)[0], nil +} + +// ReadPrimitiveInt8 reads a uint8 +func (bsi *ByteStreamIn) ReadPrimitiveInt8() (int8, error) { + if bsi.Remaining() < 1 { + return 0, errors.New("Not enough data to read int8") + } + + return int8(bsi.ReadByteNext()), nil +} + +// ReadPrimitiveInt16LE reads a Little-Endian encoded int16 +func (bsi *ByteStreamIn) ReadPrimitiveInt16LE() (int16, error) { + if bsi.Remaining() < 2 { + return 0, errors.New("Not enough data to read int16") + } + + return int16(bsi.ReadU16LENext(1)[0]), nil +} + +// ReadPrimitiveInt32LE reads a Little-Endian encoded int32 +func (bsi *ByteStreamIn) ReadPrimitiveInt32LE() (int32, error) { + if bsi.Remaining() < 4 { + return 0, errors.New("Not enough data to read int32") + } + + return int32(bsi.ReadU32LENext(1)[0]), nil +} + +// ReadPrimitiveInt64LE reads a Little-Endian encoded int64 +func (bsi *ByteStreamIn) ReadPrimitiveInt64LE() (int64, error) { + if bsi.Remaining() < 8 { + return 0, errors.New("Not enough data to read int64") + } + + return int64(bsi.ReadU64LENext(1)[0]), nil +} + +// ReadPrimitiveFloat32LE reads a Little-Endian encoded float32 +func (bsi *ByteStreamIn) ReadPrimitiveFloat32LE() (float32, error) { + if bsi.Remaining() < 4 { + return 0, errors.New("Not enough data to read float32") + } + + return bsi.ReadF32LENext(1)[0], nil +} + +// ReadPrimitiveFloat64LE reads a Little-Endian encoded float64 +func (bsi *ByteStreamIn) ReadPrimitiveFloat64LE() (float64, error) { + if bsi.Remaining() < 8 { + return 0, errors.New("Not enough data to read float64") + } + + return bsi.ReadF64LENext(1)[0], nil +} + +// ReadPrimitiveBool reads a bool +func (bsi *ByteStreamIn) ReadPrimitiveBool() (bool, error) { + if bsi.Remaining() < 1 { + return false, errors.New("Not enough data to read bool") + } + + return bsi.ReadByteNext() == 1, nil +} + +// NewByteStreamIn returns a new NEX input byte stream +func NewByteStreamIn(data []byte, server ServerInterface) *ByteStreamIn { + return &ByteStreamIn{ + Buffer: crunch.NewBuffer(data), + Server: server, + } +} diff --git a/byte_stream_out.go b/byte_stream_out.go new file mode 100644 index 00000000..18fc7036 --- /dev/null +++ b/byte_stream_out.go @@ -0,0 +1,140 @@ +package nex + +import ( + "github.com/PretendoNetwork/nex-go/types" + crunch "github.com/superwhiskers/crunch/v3" +) + +// ByteStreamOut is an abstraction of github.com/superwhiskers/crunch with nex type support +type ByteStreamOut struct { + *crunch.Buffer + Server ServerInterface +} + +// StringLengthSize returns the expected size of String length fields +func (bso *ByteStreamOut) StringLengthSize() int { + size := 2 + + if bso.Server != nil { + size = bso.Server.StringLengthSize() + } + + return size +} + +// PIDSize returns the size of PID types +func (bso *ByteStreamOut) PIDSize() int { + size := 4 + + if bso.Server != nil && bso.Server.LibraryVersion().GreaterOrEqual("4.0.0") { + size = 8 + } + + return size +} + +// UseStructureHeader determines if Structure headers should be used +func (bso *ByteStreamOut) UseStructureHeader() bool { + useStructureHeader := false + + if bso.Server != nil { + switch server := bso.Server.(type) { + case *PRUDPServer: // * Support QRV versions + useStructureHeader = server.PRUDPMinorVersion >= 3 + default: + useStructureHeader = server.LibraryVersion().GreaterOrEqual("3.5.0") + } + } + + return useStructureHeader +} + +// CopyNew returns a copy of the StreamOut but with a blank internal buffer. Returns as types.Writable +func (bso *ByteStreamOut) CopyNew() types.Writable { + return NewByteStreamOut(bso.Server) +} + +// Writes the input data to the end of the StreamOut +func (bso *ByteStreamOut) Write(data []byte) { + bso.Grow(int64(len(data))) + bso.WriteBytesNext(data) +} + +// WritePrimitiveUInt8 writes a uint8 +func (bso *ByteStreamOut) WritePrimitiveUInt8(u8 uint8) { + bso.Grow(1) + bso.WriteByteNext(byte(u8)) +} + +// WritePrimitiveUInt16LE writes a uint16 as LE +func (bso *ByteStreamOut) WritePrimitiveUInt16LE(u16 uint16) { + bso.Grow(2) + bso.WriteU16LENext([]uint16{u16}) +} + +// WritePrimitiveUInt32LE writes a uint32 as LE +func (bso *ByteStreamOut) WritePrimitiveUInt32LE(u32 uint32) { + bso.Grow(4) + bso.WriteU32LENext([]uint32{u32}) +} + +// WritePrimitiveUInt64LE writes a uint64 as LE +func (bso *ByteStreamOut) WritePrimitiveUInt64LE(u64 uint64) { + bso.Grow(8) + bso.WriteU64LENext([]uint64{u64}) +} + +// WritePrimitiveInt8 writes a int8 +func (bso *ByteStreamOut) WritePrimitiveInt8(s8 int8) { + bso.Grow(1) + bso.WriteByteNext(byte(s8)) +} + +// WritePrimitiveInt16LE writes a uint16 as LE +func (bso *ByteStreamOut) WritePrimitiveInt16LE(s16 int16) { + bso.Grow(2) + bso.WriteU16LENext([]uint16{uint16(s16)}) +} + +// WritePrimitiveInt32LE writes a int32 as LE +func (bso *ByteStreamOut) WritePrimitiveInt32LE(s32 int32) { + bso.Grow(4) + bso.WriteU32LENext([]uint32{uint32(s32)}) +} + +// WritePrimitiveInt64LE writes a int64 as LE +func (bso *ByteStreamOut) WritePrimitiveInt64LE(s64 int64) { + bso.Grow(8) + bso.WriteU64LENext([]uint64{uint64(s64)}) +} + +// WritePrimitiveFloat32LE writes a float32 as LE +func (bso *ByteStreamOut) WritePrimitiveFloat32LE(f32 float32) { + bso.Grow(4) + bso.WriteF32LENext([]float32{f32}) +} + +// WritePrimitiveFloat64LE writes a float64 as LE +func (bso *ByteStreamOut) WritePrimitiveFloat64LE(f64 float64) { + bso.Grow(8) + bso.WriteF64LENext([]float64{f64}) +} + +// WritePrimitiveBool writes a bool +func (bso *ByteStreamOut) WritePrimitiveBool(b bool) { + var bVar uint8 + if b { + bVar = 1 + } + + bso.Grow(1) + bso.WriteByteNext(byte(bVar)) +} + +// NewByteStreamOut returns a new NEX writable byte stream +func NewByteStreamOut(server ServerInterface) *ByteStreamOut { + return &ByteStreamOut{ + Buffer: crunch.NewBuffer(), + Server: server, + } +} diff --git a/kerberos.go b/kerberos.go index aad27ad5..22ac1345 100644 --- a/kerberos.go +++ b/kerberos.go @@ -74,7 +74,7 @@ type KerberosTicket struct { } // Encrypt writes the ticket data to the provided stream and returns the encrypted byte slice -func (kt *KerberosTicket) Encrypt(key []byte, stream *StreamOut) ([]byte, error) { +func (kt *KerberosTicket) Encrypt(key []byte, stream *ByteStreamOut) ([]byte, error) { encryption := NewKerberosEncryption(key) stream.Grow(int64(len(kt.SessionKey))) @@ -99,7 +99,7 @@ type KerberosTicketInternalData struct { } // Encrypt writes the ticket data to the provided stream and returns the encrypted byte slice -func (ti *KerberosTicketInternalData) Encrypt(key []byte, stream *StreamOut) ([]byte, error) { +func (ti *KerberosTicketInternalData) Encrypt(key []byte, stream *ByteStreamOut) ([]byte, error) { ti.Issued.WriteTo(stream) ti.SourcePID.WriteTo(stream) @@ -122,7 +122,7 @@ func (ti *KerberosTicketInternalData) Encrypt(key []byte, stream *StreamOut) ([] encrypted := encryption.Encrypt(data) - finalStream := NewStreamOut(stream.Server) + finalStream := NewByteStreamOut(stream.Server) ticketBuffer := types.NewBuffer(ticketKey) encryptedBuffer := types.NewBuffer(encrypted) @@ -139,7 +139,7 @@ func (ti *KerberosTicketInternalData) Encrypt(key []byte, stream *StreamOut) ([] } // Decrypt decrypts the given data and populates the struct -func (ti *KerberosTicketInternalData) Decrypt(stream *StreamIn, key []byte) error { +func (ti *KerberosTicketInternalData) Decrypt(stream *ByteStreamIn, key []byte) error { if stream.Server.(*PRUDPServer).kerberosTicketVersion == 1 { ticketKey := types.NewBuffer(nil) if err := ticketKey.ExtractFrom(stream); err != nil { @@ -154,7 +154,7 @@ func (ti *KerberosTicketInternalData) Decrypt(stream *StreamIn, key []byte) erro hash := md5.Sum(append(key, ticketKey.Value...)) key = hash[:] - stream = NewStreamIn(data.Value, stream.Server) + stream = NewByteStreamIn(data.Value, stream.Server) } encryption := NewKerberosEncryption(key) @@ -164,7 +164,7 @@ func (ti *KerberosTicketInternalData) Decrypt(stream *StreamIn, key []byte) erro return fmt.Errorf("Failed to decrypt Kerberos ticket internal data. %s", err.Error()) } - stream = NewStreamIn(decrypted, stream.Server) + stream = NewByteStreamIn(decrypted, stream.Server) timestamp := types.NewDateTime(0) if err := timestamp.ExtractFrom(stream); err != nil { diff --git a/prudp_packet.go b/prudp_packet.go index 2a4463e4..2f84674b 100644 --- a/prudp_packet.go +++ b/prudp_packet.go @@ -6,7 +6,7 @@ import "crypto/rc4" type PRUDPPacket struct { server *PRUDPServer sender *PRUDPClient - readStream *StreamIn + readStream *ByteStreamIn sourceStreamType uint8 sourcePort uint8 destinationStreamType uint8 diff --git a/prudp_packet_lite.go b/prudp_packet_lite.go index 349f0872..20506dc0 100644 --- a/prudp_packet_lite.go +++ b/prudp_packet_lite.go @@ -140,7 +140,7 @@ func (p *PRUDPPacketLite) decode() error { func (p *PRUDPPacketLite) Bytes() []byte { options := p.encodeOptions() - stream := NewStreamOut(nil) + stream := NewByteStreamOut(nil) stream.WritePrimitiveUInt8(0x80) stream.WritePrimitiveUInt8(uint8(len(options))) @@ -163,7 +163,7 @@ func (p *PRUDPPacketLite) Bytes() []byte { func (p *PRUDPPacketLite) decodeOptions() error { data := p.readStream.ReadBytesNext(int64(p.optionsLength)) - optionsStream := NewStreamIn(data, nil) + optionsStream := NewByteStreamIn(data, nil) for optionsStream.Remaining() > 0 { optionID, err := optionsStream.ReadPrimitiveUInt8() @@ -223,7 +223,7 @@ func (p *PRUDPPacketLite) decodeOptions() error { } func (p *PRUDPPacketLite) encodeOptions() []byte { - optionsStream := NewStreamOut(nil) + optionsStream := NewByteStreamOut(nil) if p.packetType == SynPacket || p.packetType == ConnectPacket { optionsStream.WritePrimitiveUInt8(0) @@ -276,7 +276,7 @@ func (p *PRUDPPacketLite) calculateSignature(sessionKey, connectionSignature []b } // NewPRUDPPacketLite creates and returns a new PacketLite using the provided Client and stream -func NewPRUDPPacketLite(client *PRUDPClient, readStream *StreamIn) (*PRUDPPacketLite, error) { +func NewPRUDPPacketLite(client *PRUDPClient, readStream *ByteStreamIn) (*PRUDPPacketLite, error) { packet := &PRUDPPacketLite{ PRUDPPacket: PRUDPPacket{ sender: client, @@ -300,7 +300,7 @@ func NewPRUDPPacketLite(client *PRUDPClient, readStream *StreamIn) (*PRUDPPacket } // NewPRUDPPacketsLite reads all possible PRUDPLite packets from the stream -func NewPRUDPPacketsLite(client *PRUDPClient, readStream *StreamIn) ([]PRUDPPacketInterface, error) { +func NewPRUDPPacketsLite(client *PRUDPClient, readStream *ByteStreamIn) ([]PRUDPPacketInterface, error) { packets := make([]PRUDPPacketInterface, 0) for readStream.Remaining() > 0 { diff --git a/prudp_packet_v0.go b/prudp_packet_v0.go index 0a738dbf..66ad02c9 100644 --- a/prudp_packet_v0.go +++ b/prudp_packet_v0.go @@ -196,7 +196,7 @@ func (p *PRUDPPacketV0) decode() error { // Bytes encodes a PRUDPv0 packet into a byte slice func (p *PRUDPPacketV0) Bytes() []byte { server := p.server - stream := NewStreamOut(server) + stream := NewByteStreamOut(server) stream.WritePrimitiveUInt8(p.sourcePort | (p.sourceStreamType << 4)) stream.WritePrimitiveUInt8(p.destinationPort | (p.destinationStreamType << 4)) @@ -353,7 +353,7 @@ func (p *PRUDPPacketV0) calculateChecksum(data []byte) uint32 { } // NewPRUDPPacketV0 creates and returns a new PacketV0 using the provided Client and stream -func NewPRUDPPacketV0(client *PRUDPClient, readStream *StreamIn) (*PRUDPPacketV0, error) { +func NewPRUDPPacketV0(client *PRUDPClient, readStream *ByteStreamIn) (*PRUDPPacketV0, error) { packet := &PRUDPPacketV0{ PRUDPPacket: PRUDPPacket{ sender: client, @@ -377,7 +377,7 @@ func NewPRUDPPacketV0(client *PRUDPClient, readStream *StreamIn) (*PRUDPPacketV0 } // NewPRUDPPacketsV0 reads all possible PRUDPv0 packets from the stream -func NewPRUDPPacketsV0(client *PRUDPClient, readStream *StreamIn) ([]PRUDPPacketInterface, error) { +func NewPRUDPPacketsV0(client *PRUDPClient, readStream *ByteStreamIn) ([]PRUDPPacketInterface, error) { packets := make([]PRUDPPacketInterface, 0) for readStream.Remaining() > 0 { diff --git a/prudp_packet_v1.go b/prudp_packet_v1.go index b708e30e..9c86b6fa 100644 --- a/prudp_packet_v1.go +++ b/prudp_packet_v1.go @@ -109,7 +109,7 @@ func (p *PRUDPPacketV1) Bytes() []byte { header := p.encodeHeader() - stream := NewStreamOut(nil) + stream := NewByteStreamOut(nil) stream.Grow(2) stream.WriteBytesNext([]byte{0xEA, 0xD0}) @@ -200,7 +200,7 @@ func (p *PRUDPPacketV1) decodeHeader() error { } func (p *PRUDPPacketV1) encodeHeader() []byte { - stream := NewStreamOut(nil) + stream := NewByteStreamOut(nil) stream.WritePrimitiveUInt8(1) // * Version stream.WritePrimitiveUInt8(p.optionsLength) @@ -217,7 +217,7 @@ func (p *PRUDPPacketV1) encodeHeader() []byte { func (p *PRUDPPacketV1) decodeOptions() error { data := p.readStream.ReadBytesNext(int64(p.optionsLength)) - optionsStream := NewStreamIn(data, nil) + optionsStream := NewByteStreamIn(data, nil) for optionsStream.Remaining() > 0 { optionID, err := optionsStream.ReadPrimitiveUInt8() @@ -271,7 +271,7 @@ func (p *PRUDPPacketV1) decodeOptions() error { } func (p *PRUDPPacketV1) encodeOptions() []byte { - optionsStream := NewStreamOut(nil) + optionsStream := NewByteStreamOut(nil) if p.packetType == SynPacket || p.packetType == ConnectPacket { optionsStream.WritePrimitiveUInt8(0) @@ -356,7 +356,7 @@ func (p *PRUDPPacketV1) calculateSignature(sessionKey, connectionSignature []byt } // NewPRUDPPacketV1 creates and returns a new PacketV1 using the provided Client and stream -func NewPRUDPPacketV1(client *PRUDPClient, readStream *StreamIn) (*PRUDPPacketV1, error) { +func NewPRUDPPacketV1(client *PRUDPClient, readStream *ByteStreamIn) (*PRUDPPacketV1, error) { packet := &PRUDPPacketV1{ PRUDPPacket: PRUDPPacket{ sender: client, @@ -380,7 +380,7 @@ func NewPRUDPPacketV1(client *PRUDPClient, readStream *StreamIn) (*PRUDPPacketV1 } // NewPRUDPPacketsV1 reads all possible PRUDPv1 packets from the stream -func NewPRUDPPacketsV1(client *PRUDPClient, readStream *StreamIn) ([]PRUDPPacketInterface, error) { +func NewPRUDPPacketsV1(client *PRUDPClient, readStream *ByteStreamIn) ([]PRUDPPacketInterface, error) { packets := make([]PRUDPPacketInterface, 0) for readStream.Remaining() > 0 { diff --git a/prudp_server.go b/prudp_server.go index a081e3c2..a8d621d2 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -200,7 +200,7 @@ func (s *PRUDPServer) listenDatagram(quit chan struct{}) { } func (s *PRUDPServer) handleSocketMessage(packetData []byte, address net.Addr, webSocketConnection *gws.Conn) error { - readStream := NewStreamIn(packetData, s) + readStream := NewByteStreamIn(packetData, s) var packets []PRUDPPacketInterface @@ -305,7 +305,7 @@ func (s *PRUDPServer) handleAcknowledgment(packet PRUDPPacketInterface) { func (s *PRUDPServer) handleMultiAcknowledgment(packet PRUDPPacketInterface) { client := packet.Sender().(*PRUDPClient) - stream := NewStreamIn(packet.Payload(), s) + stream := NewByteStreamIn(packet.Payload(), s) sequenceIDs := make([]uint16, 0) var baseSequenceID uint16 var substream *ReliablePacketSubstreamManager @@ -460,7 +460,7 @@ func (s *PRUDPServer) handleConnect(packet PRUDPPacketInterface) { client.SetPID(pid) client.setSessionKey(sessionKey) - stream := NewStreamOut(s) + stream := NewByteStreamOut(s) // * The response value is a Buffer whose data contains // * checkValue+1. This is just a lazy way of encoding @@ -513,7 +513,7 @@ func (s *PRUDPServer) handlePing(packet PRUDPPacketInterface) { } func (s *PRUDPServer) readKerberosTicket(payload []byte) ([]byte, *types.PID, uint32, error) { - stream := NewStreamIn(payload, s) + stream := NewByteStreamIn(payload, s) ticketData := types.NewBuffer(nil) if err := ticketData.ExtractFrom(stream); err != nil { @@ -528,7 +528,7 @@ func (s *PRUDPServer) readKerberosTicket(payload []byte) ([]byte, *types.PID, ui serverKey := DeriveKerberosKey(types.NewPID(2), s.kerberosPassword) ticket := NewKerberosTicketInternalData() - if err := ticket.Decrypt(NewStreamIn(ticketData.Value, s), serverKey); err != nil { + if err := ticket.Decrypt(NewByteStreamIn(ticketData.Value, s), serverKey); err != nil { return nil, nil, 0, err } @@ -548,7 +548,7 @@ func (s *PRUDPServer) readKerberosTicket(payload []byte) ([]byte, *types.PID, ui return nil, nil, 0, err } - checkDataStream := NewStreamIn(decryptedRequestData, s) + checkDataStream := NewByteStreamIn(decryptedRequestData, s) userPID := types.NewPID(0) if err := userPID.ExtractFrom(checkDataStream); err != nil { diff --git a/rmc_message.go b/rmc_message.go index 78674f62..57108fa7 100644 --- a/rmc_message.go +++ b/rmc_message.go @@ -56,7 +56,7 @@ func (rmc *RMCMessage) FromBytes(data []byte) error { } func (rmc *RMCMessage) decodePacked(data []byte) error { - stream := NewStreamIn(data, rmc.Server) + stream := NewByteStreamIn(data, rmc.Server) length, err := stream.ReadPrimitiveUInt32LE() if err != nil { @@ -143,7 +143,7 @@ func (rmc *RMCMessage) decodePacked(data []byte) error { } func (rmc *RMCMessage) decodeVerbose(data []byte) error { - stream := NewStreamIn(data, rmc.Server) + stream := NewByteStreamIn(data, rmc.Server) length, err := stream.ReadPrimitiveUInt32LE() if err != nil { @@ -232,7 +232,7 @@ func (rmc *RMCMessage) Bytes() []byte { } func (rmc *RMCMessage) encodePacked() []byte { - stream := NewStreamOut(rmc.Server) + stream := NewByteStreamOut(rmc.Server) // * RMC requests have their protocol IDs ORed with 0x80 var protocolIDFlag uint16 = 0x80 @@ -279,7 +279,7 @@ func (rmc *RMCMessage) encodePacked() []byte { serialized := stream.Bytes() - message := NewStreamOut(rmc.Server) + message := NewByteStreamOut(rmc.Server) message.WritePrimitiveUInt32LE(uint32(len(serialized))) message.Grow(int64(len(serialized))) @@ -289,7 +289,7 @@ func (rmc *RMCMessage) encodePacked() []byte { } func (rmc *RMCMessage) encodeVerbose() []byte { - stream := NewStreamOut(rmc.Server) + stream := NewByteStreamOut(rmc.Server) rmc.ProtocolName.WriteTo(stream) stream.WritePrimitiveBool(rmc.IsRequest) @@ -328,7 +328,7 @@ func (rmc *RMCMessage) encodeVerbose() []byte { serialized := stream.Bytes() - message := NewStreamOut(rmc.Server) + message := NewByteStreamOut(rmc.Server) message.WritePrimitiveUInt32LE(uint32(len(serialized))) message.Grow(int64(len(serialized))) diff --git a/stream_in.go b/stream_in.go deleted file mode 100644 index ab826517..00000000 --- a/stream_in.go +++ /dev/null @@ -1,180 +0,0 @@ -package nex - -import ( - "errors" - - crunch "github.com/superwhiskers/crunch/v3" -) - -// StreamIn is an input stream abstraction of github.com/superwhiskers/crunch/v3 with nex type support -type StreamIn struct { - *crunch.Buffer - Server ServerInterface -} - -// StringLengthSize returns the expected size of String length fields -func (s *StreamIn) StringLengthSize() int { - size := 2 - - if s.Server != nil { - size = s.Server.StringLengthSize() - } - - return size -} - -// PIDSize returns the size of PID types -func (s *StreamIn) PIDSize() int { - size := 4 - - if s.Server != nil && s.Server.LibraryVersion().GreaterOrEqual("4.0.0") { - size = 8 - } - - return size -} - -// UseStructureHeader determines if Structure headers should be used -func (s *StreamIn) UseStructureHeader() bool { - useStructureHeader := false - - if s.Server != nil { - switch server := s.Server.(type) { - case *PRUDPServer: // * Support QRV versions - useStructureHeader = server.PRUDPMinorVersion >= 3 - default: - useStructureHeader = server.LibraryVersion().GreaterOrEqual("3.5.0") - } - } - - return useStructureHeader -} - -// Remaining returns the amount of data left to be read in the buffer -func (s *StreamIn) Remaining() uint64 { - return uint64(len(s.Bytes()[s.ByteOffset():])) -} - -// ReadRemaining reads all the data left to be read in the buffer -func (s *StreamIn) ReadRemaining() []byte { - // * Can safely ignore this error, since s.Remaining() will never be less than itself - remaining, _ := s.Read(uint64(s.Remaining())) - - return remaining -} - -// Read reads the specified number of bytes. Returns an error if OOB -func (s *StreamIn) Read(length uint64) ([]byte, error) { - if s.Remaining() < length { - return []byte{}, errors.New("Read is OOB") - } - - return s.ReadBytesNext(int64(length)), nil -} - -// ReadPrimitiveUInt8 reads a uint8 -func (s *StreamIn) ReadPrimitiveUInt8() (uint8, error) { - if s.Remaining() < 1 { - return 0, errors.New("Not enough data to read uint8") - } - - return uint8(s.ReadByteNext()), nil -} - -// ReadPrimitiveUInt16LE reads a Little-Endian encoded uint16 -func (s *StreamIn) ReadPrimitiveUInt16LE() (uint16, error) { - if s.Remaining() < 2 { - return 0, errors.New("Not enough data to read uint16") - } - - return s.ReadU16LENext(1)[0], nil -} - -// ReadPrimitiveUInt32LE reads a Little-Endian encoded uint32 -func (s *StreamIn) ReadPrimitiveUInt32LE() (uint32, error) { - if s.Remaining() < 4 { - return 0, errors.New("Not enough data to read uint32") - } - - return s.ReadU32LENext(1)[0], nil -} - -// ReadPrimitiveUInt64LE reads a Little-Endian encoded uint64 -func (s *StreamIn) ReadPrimitiveUInt64LE() (uint64, error) { - if s.Remaining() < 8 { - return 0, errors.New("Not enough data to read uint64") - } - - return s.ReadU64LENext(1)[0], nil -} - -// ReadPrimitiveInt8 reads a uint8 -func (s *StreamIn) ReadPrimitiveInt8() (int8, error) { - if s.Remaining() < 1 { - return 0, errors.New("Not enough data to read int8") - } - - return int8(s.ReadByteNext()), nil -} - -// ReadPrimitiveInt16LE reads a Little-Endian encoded int16 -func (s *StreamIn) ReadPrimitiveInt16LE() (int16, error) { - if s.Remaining() < 2 { - return 0, errors.New("Not enough data to read int16") - } - - return int16(s.ReadU16LENext(1)[0]), nil -} - -// ReadPrimitiveInt32LE reads a Little-Endian encoded int32 -func (s *StreamIn) ReadPrimitiveInt32LE() (int32, error) { - if s.Remaining() < 4 { - return 0, errors.New("Not enough data to read int32") - } - - return int32(s.ReadU32LENext(1)[0]), nil -} - -// ReadPrimitiveInt64LE reads a Little-Endian encoded int64 -func (s *StreamIn) ReadPrimitiveInt64LE() (int64, error) { - if s.Remaining() < 8 { - return 0, errors.New("Not enough data to read int64") - } - - return int64(s.ReadU64LENext(1)[0]), nil -} - -// ReadPrimitiveFloat32LE reads a Little-Endian encoded float32 -func (s *StreamIn) ReadPrimitiveFloat32LE() (float32, error) { - if s.Remaining() < 4 { - return 0, errors.New("Not enough data to read float32") - } - - return s.ReadF32LENext(1)[0], nil -} - -// ReadPrimitiveFloat64LE reads a Little-Endian encoded float64 -func (s *StreamIn) ReadPrimitiveFloat64LE() (float64, error) { - if s.Remaining() < 8 { - return 0, errors.New("Not enough data to read float64") - } - - return s.ReadF64LENext(1)[0], nil -} - -// ReadPrimitiveBool reads a bool -func (s *StreamIn) ReadPrimitiveBool() (bool, error) { - if s.Remaining() < 1 { - return false, errors.New("Not enough data to read bool") - } - - return s.ReadByteNext() == 1, nil -} - -// NewStreamIn returns a new NEX input stream -func NewStreamIn(data []byte, server ServerInterface) *StreamIn { - return &StreamIn{ - Buffer: crunch.NewBuffer(data), - Server: server, - } -} diff --git a/stream_out.go b/stream_out.go deleted file mode 100644 index aec2b1fe..00000000 --- a/stream_out.go +++ /dev/null @@ -1,140 +0,0 @@ -package nex - -import ( - "github.com/PretendoNetwork/nex-go/types" - crunch "github.com/superwhiskers/crunch/v3" -) - -// StreamOut is an abstraction of github.com/superwhiskers/crunch with nex type support -type StreamOut struct { - *crunch.Buffer - Server ServerInterface -} - -// StringLengthSize returns the expected size of String length fields -func (s *StreamOut) StringLengthSize() int { - size := 2 - - if s.Server != nil { - size = s.Server.StringLengthSize() - } - - return size -} - -// PIDSize returns the size of PID types -func (s *StreamOut) PIDSize() int { - size := 4 - - if s.Server != nil && s.Server.LibraryVersion().GreaterOrEqual("4.0.0") { - size = 8 - } - - return size -} - -// UseStructureHeader determines if Structure headers should be used -func (s *StreamOut) UseStructureHeader() bool { - useStructureHeader := false - - if s.Server != nil { - switch server := s.Server.(type) { - case *PRUDPServer: // * Support QRV versions - useStructureHeader = server.PRUDPMinorVersion >= 3 - default: - useStructureHeader = server.LibraryVersion().GreaterOrEqual("3.5.0") - } - } - - return useStructureHeader -} - -// CopyNew returns a copy of the StreamOut but with a blank internal buffer. Returns as types.Writable -func (s *StreamOut) CopyNew() types.Writable { - return NewStreamOut(s.Server) -} - -// Writes the input data to the end of the StreamOut -func (s *StreamOut) Write(data []byte) { - s.Grow(int64(len(data))) - s.WriteBytesNext(data) -} - -// WritePrimitiveUInt8 writes a uint8 -func (s *StreamOut) WritePrimitiveUInt8(u8 uint8) { - s.Grow(1) - s.WriteByteNext(byte(u8)) -} - -// WritePrimitiveUInt16LE writes a uint16 as LE -func (s *StreamOut) WritePrimitiveUInt16LE(u16 uint16) { - s.Grow(2) - s.WriteU16LENext([]uint16{u16}) -} - -// WritePrimitiveUInt32LE writes a uint32 as LE -func (s *StreamOut) WritePrimitiveUInt32LE(u32 uint32) { - s.Grow(4) - s.WriteU32LENext([]uint32{u32}) -} - -// WritePrimitiveUInt64LE writes a uint64 as LE -func (s *StreamOut) WritePrimitiveUInt64LE(u64 uint64) { - s.Grow(8) - s.WriteU64LENext([]uint64{u64}) -} - -// WritePrimitiveInt8 writes a int8 -func (s *StreamOut) WritePrimitiveInt8(s8 int8) { - s.Grow(1) - s.WriteByteNext(byte(s8)) -} - -// WritePrimitiveInt16LE writes a uint16 as LE -func (s *StreamOut) WritePrimitiveInt16LE(s16 int16) { - s.Grow(2) - s.WriteU16LENext([]uint16{uint16(s16)}) -} - -// WritePrimitiveInt32LE writes a int32 as LE -func (s *StreamOut) WritePrimitiveInt32LE(s32 int32) { - s.Grow(4) - s.WriteU32LENext([]uint32{uint32(s32)}) -} - -// WritePrimitiveInt64LE writes a int64 as LE -func (s *StreamOut) WritePrimitiveInt64LE(s64 int64) { - s.Grow(8) - s.WriteU64LENext([]uint64{uint64(s64)}) -} - -// WritePrimitiveFloat32LE writes a float32 as LE -func (s *StreamOut) WritePrimitiveFloat32LE(f32 float32) { - s.Grow(4) - s.WriteF32LENext([]float32{f32}) -} - -// WritePrimitiveFloat64LE writes a float64 as LE -func (s *StreamOut) WritePrimitiveFloat64LE(f64 float64) { - s.Grow(8) - s.WriteF64LENext([]float64{f64}) -} - -// WritePrimitiveBool writes a bool -func (s *StreamOut) WritePrimitiveBool(b bool) { - var bVar uint8 - if b { - bVar = 1 - } - - s.Grow(1) - s.WriteByteNext(byte(bVar)) -} - -// NewStreamOut returns a new nex output stream -func NewStreamOut(server ServerInterface) *StreamOut { - return &StreamOut{ - Buffer: crunch.NewBuffer(), - Server: server, - } -} diff --git a/test/auth.go b/test/auth.go index 97410325..bdf30c2b 100644 --- a/test/auth.go +++ b/test/auth.go @@ -49,7 +49,7 @@ func login(packet nex.PRUDPPacketInterface) { parameters := request.Parameters - parametersStream := nex.NewStreamIn(parameters, authServer) + parametersStream := nex.NewByteStreamIn(parameters, authServer) strUserName := types.NewString("") if err := strUserName.ExtractFrom(parametersStream); err != nil { @@ -73,7 +73,7 @@ func login(packet nex.PRUDPPacketInterface) { pConnectionData.StationURLSpecialProtocols = types.NewStationURL("") pConnectionData.Time = types.NewDateTime(0).Now() - responseStream := nex.NewStreamOut(authServer) + responseStream := nex.NewByteStreamOut(authServer) retval.WriteTo(responseStream) pidPrincipal.WriteTo(responseStream) @@ -111,7 +111,7 @@ func requestTicket(packet nex.PRUDPPacketInterface) { parameters := request.Parameters - parametersStream := nex.NewStreamIn(parameters, authServer) + parametersStream := nex.NewByteStreamIn(parameters, authServer) idSource := types.NewPID(0) if err := idSource.ExtractFrom(parametersStream); err != nil { @@ -126,7 +126,7 @@ func requestTicket(packet nex.PRUDPPacketInterface) { retval := types.NewResultSuccess(0x00010001) pbufResponse := types.NewBuffer(generateTicket(idSource, idTarget)) - responseStream := nex.NewStreamOut(authServer) + responseStream := nex.NewByteStreamOut(authServer) retval.WriteTo(responseStream) pbufResponse.WriteTo(responseStream) diff --git a/test/generate_ticket.go b/test/generate_ticket.go index 9dbeca5e..36079156 100644 --- a/test/generate_ticket.go +++ b/test/generate_ticket.go @@ -24,14 +24,14 @@ func generateTicket(userPID *types.PID, targetPID *types.PID) []byte { ticketInternalData.SourcePID = userPID ticketInternalData.SessionKey = sessionKey - encryptedTicketInternalData, _ := ticketInternalData.Encrypt(targetKey, nex.NewStreamOut(authServer)) + encryptedTicketInternalData, _ := ticketInternalData.Encrypt(targetKey, nex.NewByteStreamOut(authServer)) ticket := nex.NewKerberosTicket() ticket.SessionKey = sessionKey ticket.TargetPID = targetPID ticket.InternalData = types.NewBuffer(encryptedTicketInternalData) - encryptedTicket, _ := ticket.Encrypt(userKey, nex.NewStreamOut(authServer)) + encryptedTicket, _ := ticket.Encrypt(userKey, nex.NewByteStreamOut(authServer)) return encryptedTicket } diff --git a/test/hpp.go b/test/hpp.go index 125374d0..1464e060 100644 --- a/test/hpp.go +++ b/test/hpp.go @@ -90,7 +90,7 @@ func getNotificationURL(packet *nex.HPPPacket) { parameters := request.Parameters - parametersStream := nex.NewStreamIn(parameters, hppServer) + parametersStream := nex.NewByteStreamIn(parameters, hppServer) param := &dataStoreGetNotificationURLParam{} param.PreviousURL = types.NewString("") @@ -102,7 +102,7 @@ func getNotificationURL(packet *nex.HPPPacket) { fmt.Println("[HPP]", param.PreviousURL) - responseStream := nex.NewStreamOut(hppServer) + responseStream := nex.NewByteStreamOut(hppServer) info := &dataStoreReqGetNotificationURLInfo{} info.URL = types.NewString("https://example.com") diff --git a/test/secure.go b/test/secure.go index 36c289ed..a9eddc0c 100644 --- a/test/secure.go +++ b/test/secure.go @@ -88,7 +88,7 @@ func registerEx(packet nex.PRUDPPacketInterface) { parameters := request.Parameters - parametersStream := nex.NewStreamIn(parameters, secureServer) + parametersStream := nex.NewByteStreamIn(parameters, secureServer) vecMyURLs := types.NewList[*types.StationURL]() vecMyURLs.Type = types.NewStationURL("") @@ -111,7 +111,7 @@ func registerEx(packet nex.PRUDPPacketInterface) { retval := types.NewResultSuccess(0x00010001) localStationURL := types.NewString(localStation.EncodeToString()) - responseStream := nex.NewStreamOut(secureServer) + responseStream := nex.NewByteStreamOut(secureServer) retval.WriteTo(responseStream) responseStream.WritePrimitiveUInt32LE(secureServer.ConnectionIDCounter().Next()) @@ -145,7 +145,7 @@ func updateAndGetAllInformation(packet nex.PRUDPPacketInterface) { request := packet.RMCMessage() response := nex.NewRMCMessage(secureServer) - responseStream := nex.NewStreamOut(secureServer) + responseStream := nex.NewByteStreamOut(secureServer) (&principalPreference{ ShowOnlinePresence: types.NewPrimitiveBool(true), @@ -193,7 +193,7 @@ func checkSettingStatus(packet nex.PRUDPPacketInterface) { request := packet.RMCMessage() response := nex.NewRMCMessage(secureServer) - responseStream := nex.NewStreamOut(secureServer) + responseStream := nex.NewByteStreamOut(secureServer) responseStream.WritePrimitiveUInt8(0) // * Unknown From fc1da5b833128fa8078cdc990dfa873b95074ad3 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sun, 31 Dec 2023 23:29:16 -0500 Subject: [PATCH 104/178] prudp: use real Buffer type in connection check --- prudp_server.go | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/prudp_server.go b/prudp_server.go index a8d621d2..5aa174b8 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -3,6 +3,7 @@ package nex import ( "bytes" "crypto/rand" + "encoding/binary" "errors" "fmt" "net" @@ -449,7 +450,7 @@ func (s *PRUDPServer) handleConnect(packet PRUDPPacketInterface) { client.createReliableSubstreams(0) } - var payload []byte + payload := make([]byte, 0) if slices.Contains(s.SecureVirtualServerPorts, packet.DestinationPort()) { sessionKey, pid, checkValue, err := s.readKerberosTicket(packet.Payload()) @@ -460,17 +461,17 @@ func (s *PRUDPServer) handleConnect(packet PRUDPPacketInterface) { client.SetPID(pid) client.setSessionKey(sessionKey) + responseCheckValue := checkValue + 1 + responseCheckValueBytes := make([]byte, 4) + + binary.LittleEndian.PutUint32(responseCheckValueBytes, responseCheckValue) + + checkValueResponse := types.NewBuffer(responseCheckValueBytes) stream := NewByteStreamOut(s) - // * The response value is a Buffer whose data contains - // * checkValue+1. This is just a lazy way of encoding - // * a Buffer type - stream.WritePrimitiveUInt32LE(4) // * Buffer length - stream.WritePrimitiveUInt32LE(checkValue + 1) // * Buffer data + checkValueResponse.WriteTo(stream) payload = stream.Bytes() - } else { - payload = make([]byte, 0) } ack.SetPayload(payload) From bdfeb0b5a86d97e175eb00503fc68d947db45ff5 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sun, 31 Dec 2023 23:51:18 -0500 Subject: [PATCH 105/178] streams: added ByteStreamSettings --- byte_stream_in.go | 17 ++++++----------- byte_stream_out.go | 17 ++++++----------- byte_stream_settings.go | 17 +++++++++++++++++ hpp_server.go | 18 +++++++++--------- prudp_server.go | 16 ++++++++-------- server_interface.go | 4 ++-- 6 files changed, 48 insertions(+), 41 deletions(-) create mode 100644 byte_stream_settings.go diff --git a/byte_stream_in.go b/byte_stream_in.go index 6266bc3b..192f549b 100644 --- a/byte_stream_in.go +++ b/byte_stream_in.go @@ -16,8 +16,8 @@ type ByteStreamIn struct { func (bsi *ByteStreamIn) StringLengthSize() int { size := 2 - if bsi.Server != nil { - size = bsi.Server.StringLengthSize() + if bsi.Server != nil && bsi.Server.ByteStreamSettings() != nil { + size = bsi.Server.ByteStreamSettings().StringLengthSize } return size @@ -27,8 +27,8 @@ func (bsi *ByteStreamIn) StringLengthSize() int { func (bsi *ByteStreamIn) PIDSize() int { size := 4 - if bsi.Server != nil && bsi.Server.LibraryVersion().GreaterOrEqual("4.0.0") { - size = 8 + if bsi.Server != nil && bsi.Server.ByteStreamSettings() != nil { + size = bsi.Server.ByteStreamSettings().PIDSize } return size @@ -38,13 +38,8 @@ func (bsi *ByteStreamIn) PIDSize() int { func (bsi *ByteStreamIn) UseStructureHeader() bool { useStructureHeader := false - if bsi.Server != nil { - switch server := bsi.Server.(type) { - case *PRUDPServer: // * Support QRV versions - useStructureHeader = server.PRUDPMinorVersion >= 3 - default: - useStructureHeader = server.LibraryVersion().GreaterOrEqual("3.5.0") - } + if bsi.Server != nil && bsi.Server.ByteStreamSettings() != nil { + useStructureHeader = bsi.Server.ByteStreamSettings().UseStructureHeader } return useStructureHeader diff --git a/byte_stream_out.go b/byte_stream_out.go index 18fc7036..27b338d8 100644 --- a/byte_stream_out.go +++ b/byte_stream_out.go @@ -15,8 +15,8 @@ type ByteStreamOut struct { func (bso *ByteStreamOut) StringLengthSize() int { size := 2 - if bso.Server != nil { - size = bso.Server.StringLengthSize() + if bso.Server != nil && bso.Server.ByteStreamSettings() != nil { + size = bso.Server.ByteStreamSettings().StringLengthSize } return size @@ -26,8 +26,8 @@ func (bso *ByteStreamOut) StringLengthSize() int { func (bso *ByteStreamOut) PIDSize() int { size := 4 - if bso.Server != nil && bso.Server.LibraryVersion().GreaterOrEqual("4.0.0") { - size = 8 + if bso.Server != nil && bso.Server.ByteStreamSettings() != nil { + size = bso.Server.ByteStreamSettings().PIDSize } return size @@ -37,13 +37,8 @@ func (bso *ByteStreamOut) PIDSize() int { func (bso *ByteStreamOut) UseStructureHeader() bool { useStructureHeader := false - if bso.Server != nil { - switch server := bso.Server.(type) { - case *PRUDPServer: // * Support QRV versions - useStructureHeader = server.PRUDPMinorVersion >= 3 - default: - useStructureHeader = server.LibraryVersion().GreaterOrEqual("3.5.0") - } + if bso.Server != nil && bso.Server.ByteStreamSettings() != nil { + useStructureHeader = bso.Server.ByteStreamSettings().UseStructureHeader } return useStructureHeader diff --git a/byte_stream_settings.go b/byte_stream_settings.go new file mode 100644 index 00000000..328c1e05 --- /dev/null +++ b/byte_stream_settings.go @@ -0,0 +1,17 @@ +package nex + +// ByteStreamSettings defines some settings for how a ByteStream should handle certain data types +type ByteStreamSettings struct { + StringLengthSize int + PIDSize int + UseStructureHeader bool +} + +// NewByteStreamSettings returns a new ByteStreamSettings +func NewByteStreamSettings() *ByteStreamSettings { + return &ByteStreamSettings{ + StringLengthSize: 2, + PIDSize: 4, + UseStructureHeader: false, + } +} diff --git a/hpp_server.go b/hpp_server.go index 923c2074..9e8a3614 100644 --- a/hpp_server.go +++ b/hpp_server.go @@ -24,7 +24,7 @@ type HPPServer struct { natTraversalProtocolVersion *LibraryVersion dataHandlers []func(packet PacketInterface) passwordFromPIDHandler func(pid *types.PID) (string, uint32) - stringLengthSize int + byteStreamSettings *ByteStreamSettings } // OnData adds an event handler which is fired when a new HPP request is received @@ -273,21 +273,21 @@ func (s *HPPServer) SetPasswordFromPIDFunction(handler func(pid *types.PID) (str s.passwordFromPIDHandler = handler } -// StringLengthSize returns the size of the length field used for Quazal::String types -func (s *HPPServer) StringLengthSize() int { - return s.stringLengthSize +// ByteStreamSettings returns the settings to be used for ByteStreams +func (s *HPPServer) ByteStreamSettings() *ByteStreamSettings { + return s.byteStreamSettings } -// SetStringLengthSize sets the size of the length field used for Quazal::String types -func (s *HPPServer) SetStringLengthSize(size int) { - s.stringLengthSize = size +// SetByteStreamSettings sets the settings to be used for ByteStreams +func (s *HPPServer) SetByteStreamSettings(byteStreamSettings *ByteStreamSettings) { + s.byteStreamSettings = byteStreamSettings } // NewHPPServer returns a new HPP server func NewHPPServer() *HPPServer { s := &HPPServer{ - dataHandlers: make([]func(packet PacketInterface), 0), - stringLengthSize: 2, + dataHandlers: make([]func(packet PacketInterface), 0), + byteStreamSettings: NewByteStreamSettings(), } mux := http.NewServeMux() diff --git a/prudp_server.go b/prudp_server.go index 5aa174b8..8ce51609 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -48,8 +48,8 @@ type PRUDPServer struct { PRUDPv1ConnectionSignatureKey []byte EnhancedChecksum bool PRUDPv0CustomChecksumCalculator func(packet *PRUDPPacketV0, data []byte) uint32 - stringLengthSize int CompressionAlgorithm compression.Algorithm + byteStreamSettings *ByteStreamSettings } // OnData adds an event handler which is fired when a new DATA packet is received @@ -1026,14 +1026,14 @@ func (s *PRUDPServer) SetPasswordFromPIDFunction(handler func(pid *types.PID) (s s.passwordFromPIDHandler = handler } -// StringLengthSize returns the size of the length field used for Quazal::String types -func (s *PRUDPServer) StringLengthSize() int { - return s.stringLengthSize +// ByteStreamSettings returns the settings to be used for ByteStreams +func (s *PRUDPServer) ByteStreamSettings() *ByteStreamSettings { + return s.byteStreamSettings } -// SetStringLengthSize sets the size of the length field used for Quazal::String types -func (s *PRUDPServer) SetStringLengthSize(size int) { - s.stringLengthSize = size +// SetByteStreamSettings sets the settings to be used for ByteStreams +func (s *PRUDPServer) SetByteStreamSettings(byteStreamSettings *ByteStreamSettings) { + s.byteStreamSettings = byteStreamSettings } // NewPRUDPServer will return a new PRUDP server @@ -1048,7 +1048,7 @@ func NewPRUDPServer() *PRUDPServer { prudpEventHandlers: make(map[string][]func(PacketInterface)), connectionIDCounter: NewCounter[uint32](10), pingTimeout: time.Second * 15, - stringLengthSize: 2, CompressionAlgorithm: compression.NewDummyCompression(), + byteStreamSettings: NewByteStreamSettings(), } } diff --git a/server_interface.go b/server_interface.go index 8733c804..a3ae1a2a 100644 --- a/server_interface.go +++ b/server_interface.go @@ -19,6 +19,6 @@ type ServerInterface interface { OnData(handler func(packet PacketInterface)) PasswordFromPID(pid *types.PID) (string, uint32) SetPasswordFromPIDFunction(handler func(pid *types.PID) (string, uint32)) - StringLengthSize() int - SetStringLengthSize(size int) + ByteStreamSettings() *ByteStreamSettings + SetByteStreamSettings(settings *ByteStreamSettings) } From 84dd859e17f9b6f752a7d3657d35e202246ac859 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Mon, 1 Jan 2024 00:50:45 -0500 Subject: [PATCH 106/178] prudp: added PRUDPV0Settings --- prudp_packet.go | 3 +- prudp_packet_v0.go | 143 ++++++++++++++++++++++--------------------- prudp_packet_v1.go | 1 + prudp_server.go | 66 ++++++++++---------- prudp_v0_settings.go | 27 ++++++++ 5 files changed, 135 insertions(+), 105 deletions(-) create mode 100644 prudp_v0_settings.go diff --git a/prudp_packet.go b/prudp_packet.go index 2f84674b..db6b6bd2 100644 --- a/prudp_packet.go +++ b/prudp_packet.go @@ -7,6 +7,7 @@ type PRUDPPacket struct { server *PRUDPServer sender *PRUDPClient readStream *ByteStreamIn + version uint8 sourceStreamType uint8 sourcePort uint8 destinationStreamType uint8 @@ -153,7 +154,7 @@ func (p *PRUDPPacket) decryptPayload() []byte { // * the RC4 stream is always reset to the default key // * regardless if the client is connecting to a secure // * server (prudps) or not - if p.sender.server.IsQuazalMode { + if p.version == 0 && p.sender.server.PRUDPV0Settings.IsQuazalMode { substream.SetCipherKey([]byte("CD&ML")) } diff --git a/prudp_packet_v0.go b/prudp_packet_v0.go index 66ad02c9..984906f0 100644 --- a/prudp_packet_v0.go +++ b/prudp_packet_v0.go @@ -56,7 +56,7 @@ func (p *PRUDPPacketV0) Copy() PRUDPPacketInterface { // Version returns the packets PRUDP version func (p *PRUDPPacketV0) Version() int { - return 0 + return int(p.version) } func (p *PRUDPPacketV0) decode() error { @@ -83,7 +83,7 @@ func (p *PRUDPPacketV0) decode() error { p.destinationStreamType = destination >> 4 p.destinationPort = destination & 0xF - if server.IsQuazalMode { + if server.PRUDPV0Settings.IsQuazalMode { typeAndFlags, err := p.readStream.ReadPrimitiveUInt8() if err != nil { return fmt.Errorf("Failed to read PRUDPv0 type and flags. %s", err.Error()) @@ -149,7 +149,7 @@ func (p *PRUDPPacketV0) decode() error { } } else { // * Some Quazal games use a 4 byte checksum. NEX uses 1 byte - if server.EnhancedChecksum { + if server.PRUDPV0Settings.UseEnhancedChecksum { payloadSize = uint16(p.readStream.Remaining() - 4) } else { payloadSize = uint16(p.readStream.Remaining() - 1) @@ -162,7 +162,7 @@ func (p *PRUDPPacketV0) decode() error { p.payload = p.readStream.ReadBytesNext(int64(payloadSize)) - if server.EnhancedChecksum && p.readStream.Remaining() < 4 { + if server.PRUDPV0Settings.UseEnhancedChecksum && p.readStream.Remaining() < 4 { return errors.New("Failed to read PRUDPv0 checksum. Not have enough data") } else if p.readStream.Remaining() < 1 { return errors.New("Failed to read PRUDPv0 checksum. Not have enough data") @@ -173,7 +173,7 @@ func (p *PRUDPPacketV0) decode() error { var checksum uint32 var checksumU8 uint8 - if server.EnhancedChecksum { + if server.PRUDPV0Settings.UseEnhancedChecksum { checksum, err = p.readStream.ReadPrimitiveUInt32LE() } else { checksumU8, err = p.readStream.ReadPrimitiveUInt8() @@ -184,7 +184,7 @@ func (p *PRUDPPacketV0) decode() error { return fmt.Errorf("Failed to read PRUDPv0 checksum. %s", err.Error()) } - calculatedChecksum := p.calculateChecksum(checksumData) + calculatedChecksum := p.server.PRUDPV0Settings.ChecksumCalculator(p, checksumData) if checksum != calculatedChecksum { return errors.New("Invalid PRUDPv0 checksum") @@ -201,7 +201,7 @@ func (p *PRUDPPacketV0) Bytes() []byte { stream.WritePrimitiveUInt8(p.sourcePort | (p.sourceStreamType << 4)) stream.WritePrimitiveUInt8(p.destinationPort | (p.destinationStreamType << 4)) - if server.IsQuazalMode { + if server.PRUDPV0Settings.IsQuazalMode { stream.WritePrimitiveUInt8(uint8(p.packetType | (p.flags << 3))) } else { stream.WritePrimitiveUInt16LE(p.packetType | (p.flags << 4)) @@ -230,15 +230,9 @@ func (p *PRUDPPacketV0) Bytes() []byte { stream.WriteBytesNext(p.payload) } - var checksum uint32 - - if p.server.PRUDPv0CustomChecksumCalculator != nil { - checksum = p.server.PRUDPv0CustomChecksumCalculator(p, stream.Bytes()) - } else { - checksum = p.calculateChecksum(stream.Bytes()) - } + checksum := p.server.PRUDPV0Settings.ChecksumCalculator(p, stream.Bytes()) - if server.EnhancedChecksum { + if server.PRUDPV0Settings.UseEnhancedChecksum { stream.WritePrimitiveUInt32LE(checksum) } else { stream.WritePrimitiveUInt8(uint8(checksum)) @@ -248,6 +242,55 @@ func (p *PRUDPPacketV0) Bytes() []byte { } func (p *PRUDPPacketV0) calculateConnectionSignature(addr net.Addr) ([]byte, error) { + return p.server.PRUDPV0Settings.ConnectionSignatureCalculator(p, addr) +} + +func (p *PRUDPPacketV0) calculateSignature(sessionKey, connectionSignature []byte) []byte { + return p.server.PRUDPV0Settings.SignatureCalculator(p, sessionKey, connectionSignature) +} + +// NewPRUDPPacketV0 creates and returns a new PacketV0 using the provided Client and stream +func NewPRUDPPacketV0(client *PRUDPClient, readStream *ByteStreamIn) (*PRUDPPacketV0, error) { + packet := &PRUDPPacketV0{ + PRUDPPacket: PRUDPPacket{ + sender: client, + readStream: readStream, + version: 0, + }, + } + + if readStream != nil { + packet.server = readStream.Server.(*PRUDPServer) + err := packet.decode() + if err != nil { + return nil, fmt.Errorf("Failed to decode PRUDPv0 packet. %s", err.Error()) + } + } + + if client != nil { + packet.server = client.server + } + + return packet, nil +} + +// NewPRUDPPacketsV0 reads all possible PRUDPv0 packets from the stream +func NewPRUDPPacketsV0(client *PRUDPClient, readStream *ByteStreamIn) ([]PRUDPPacketInterface, error) { + packets := make([]PRUDPPacketInterface, 0) + + for readStream.Remaining() > 0 { + packet, err := NewPRUDPPacketV0(client, readStream) + if err != nil { + return packets, err + } + + packets = append(packets, packet) + } + + return packets, nil +} + +func defaultPRUDPv0ConnectionSignature(packet *PRUDPPacketV0, addr net.Addr) ([]byte, error) { var ip net.IP var port int @@ -271,14 +314,14 @@ func (p *PRUDPPacketV0) calculateConnectionSignature(addr net.Addr) ([]byte, err return signatureBytes, nil } -func (p *PRUDPPacketV0) calculateSignature(sessionKey, connectionSignature []byte) []byte { - if !p.server.IsQuazalMode { - if p.packetType == DataPacket { - return p.calculateDataSignature(sessionKey) +func defaultPRUDPv0CalculateSignature(packet *PRUDPPacketV0, sessionKey, connectionSignature []byte) []byte { + if !packet.server.PRUDPV0Settings.IsQuazalMode { + if packet.packetType == DataPacket { + return packet.server.PRUDPV0Settings.DataSignatureCalculator(packet, sessionKey) } - if p.packetType == DisconnectPacket && p.server.accessKey != "ridfebb9" { - return p.calculateDataSignature(sessionKey) + if packet.packetType == DisconnectPacket && packet.server.accessKey != "ridfebb9" { + return packet.server.PRUDPV0Settings.DataSignatureCalculator(packet, sessionKey) } } @@ -289,16 +332,16 @@ func (p *PRUDPPacketV0) calculateSignature(sessionKey, connectionSignature []byt return make([]byte, 4) } -func (p *PRUDPPacketV0) calculateDataSignature(sessionKey []byte) []byte { - server := p.server - data := p.payload +func defaultPRUDPv0CalculateDataSignature(packet *PRUDPPacketV0, sessionKey []byte) []byte { + server := packet.server + data := packet.payload if server.AccessKey() != "ridfebb9" { - header := []byte{0, 0, p.fragmentID} - binary.LittleEndian.PutUint16(header[:2], p.sequenceID) + header := []byte{0, 0, packet.fragmentID} + binary.LittleEndian.PutUint16(header[:2], packet.sequenceID) data = append(sessionKey, header...) - data = append(data, p.payload...) + data = append(data, packet.payload...) } if len(data) > 0 { @@ -315,11 +358,11 @@ func (p *PRUDPPacketV0) calculateDataSignature(sessionKey []byte) []byte { return []byte{0x78, 0x56, 0x34, 0x12} } -func (p *PRUDPPacketV0) calculateChecksum(data []byte) uint32 { - server := p.server +func defaultPRUDPv0CalculateChecksum(packet *PRUDPPacketV0, data []byte) uint32 { + server := packet.server checksum := sum[byte, uint32]([]byte(server.AccessKey())) - if server.EnhancedChecksum { + if server.PRUDPV0Settings.UseEnhancedChecksum { padSize := (len(data) + 3) &^ 3 data = append(data, make([]byte, padSize-len(data))...) words := make([]uint32, len(data)/4) @@ -351,43 +394,3 @@ func (p *PRUDPPacketV0) calculateChecksum(data []byte) uint32 { return checksum & 0xFF } } - -// NewPRUDPPacketV0 creates and returns a new PacketV0 using the provided Client and stream -func NewPRUDPPacketV0(client *PRUDPClient, readStream *ByteStreamIn) (*PRUDPPacketV0, error) { - packet := &PRUDPPacketV0{ - PRUDPPacket: PRUDPPacket{ - sender: client, - readStream: readStream, - }, - } - - if readStream != nil { - packet.server = readStream.Server.(*PRUDPServer) - err := packet.decode() - if err != nil { - return nil, fmt.Errorf("Failed to decode PRUDPv0 packet. %s", err.Error()) - } - } - - if client != nil { - packet.server = client.server - } - - return packet, nil -} - -// NewPRUDPPacketsV0 reads all possible PRUDPv0 packets from the stream -func NewPRUDPPacketsV0(client *PRUDPClient, readStream *ByteStreamIn) ([]PRUDPPacketInterface, error) { - packets := make([]PRUDPPacketInterface, 0) - - for readStream.Remaining() > 0 { - packet, err := NewPRUDPPacketV0(client, readStream) - if err != nil { - return packets, err - } - - packets = append(packets, packet) - } - - return packets, nil -} diff --git a/prudp_packet_v1.go b/prudp_packet_v1.go index 9c86b6fa..12ad7ab6 100644 --- a/prudp_packet_v1.go +++ b/prudp_packet_v1.go @@ -361,6 +361,7 @@ func NewPRUDPPacketV1(client *PRUDPClient, readStream *ByteStreamIn) (*PRUDPPack PRUDPPacket: PRUDPPacket{ sender: client, readStream: readStream, + version: 1, }, } diff --git a/prudp_server.go b/prudp_server.go index 8ce51609..4ab62d0e 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -18,38 +18,36 @@ import ( // PRUDPServer represents a bare-bones PRUDP server type PRUDPServer struct { - udpSocket *net.UDPConn - websocketServer *WebSocketServer - PRUDPVersion int - PRUDPMinorVersion uint32 - virtualServers *MutexMap[uint8, *MutexMap[uint8, *MutexMap[string, *PRUDPClient]]] - IsQuazalMode bool - VirtualServerPorts []uint8 - SecureVirtualServerPorts []uint8 - SupportedFunctions uint32 - accessKey string - kerberosPassword []byte - kerberosTicketVersion int - kerberosKeySize int - FragmentSize int - version *LibraryVersion - datastoreProtocolVersion *LibraryVersion - matchMakingProtocolVersion *LibraryVersion - rankingProtocolVersion *LibraryVersion - ranking2ProtocolVersion *LibraryVersion - messagingProtocolVersion *LibraryVersion - utilityProtocolVersion *LibraryVersion - natTraversalProtocolVersion *LibraryVersion - prudpEventHandlers map[string][]func(packet PacketInterface) - clientRemovedEventHandlers []func(client *PRUDPClient) - connectionIDCounter *Counter[uint32] - pingTimeout time.Duration - passwordFromPIDHandler func(pid *types.PID) (string, uint32) - PRUDPv1ConnectionSignatureKey []byte - EnhancedChecksum bool - PRUDPv0CustomChecksumCalculator func(packet *PRUDPPacketV0, data []byte) uint32 - CompressionAlgorithm compression.Algorithm - byteStreamSettings *ByteStreamSettings + udpSocket *net.UDPConn + websocketServer *WebSocketServer + PRUDPVersion int + PRUDPMinorVersion uint32 + virtualServers *MutexMap[uint8, *MutexMap[uint8, *MutexMap[string, *PRUDPClient]]] + VirtualServerPorts []uint8 + SecureVirtualServerPorts []uint8 + SupportedFunctions uint32 + accessKey string + kerberosPassword []byte + kerberosTicketVersion int + kerberosKeySize int + FragmentSize int + version *LibraryVersion + datastoreProtocolVersion *LibraryVersion + matchMakingProtocolVersion *LibraryVersion + rankingProtocolVersion *LibraryVersion + ranking2ProtocolVersion *LibraryVersion + messagingProtocolVersion *LibraryVersion + utilityProtocolVersion *LibraryVersion + natTraversalProtocolVersion *LibraryVersion + prudpEventHandlers map[string][]func(packet PacketInterface) + clientRemovedEventHandlers []func(client *PRUDPClient) + connectionIDCounter *Counter[uint32] + pingTimeout time.Duration + passwordFromPIDHandler func(pid *types.PID) (string, uint32) + PRUDPv1ConnectionSignatureKey []byte + CompressionAlgorithm compression.Algorithm + byteStreamSettings *ByteStreamSettings + PRUDPV0Settings *PRUDPV0Settings } // OnData adds an event handler which is fired when a new DATA packet is received @@ -787,7 +785,7 @@ func (s *PRUDPServer) sendPacket(packet PRUDPPacketInterface) { // * the RC4 stream is always reset to the default key // * regardless if the client is connecting to a secure // * server (prudps) or not - if s.IsQuazalMode { + if packet.Version() == 0 && s.PRUDPV0Settings.IsQuazalMode { substream.SetCipherKey([]byte("CD&ML")) } @@ -1042,7 +1040,6 @@ func NewPRUDPServer() *PRUDPServer { VirtualServerPorts: []uint8{1}, SecureVirtualServerPorts: make([]uint8, 0), virtualServers: NewMutexMap[uint8, *MutexMap[uint8, *MutexMap[string, *PRUDPClient]]](), - IsQuazalMode: false, kerberosKeySize: 32, FragmentSize: 1300, prudpEventHandlers: make(map[string][]func(PacketInterface)), @@ -1050,5 +1047,6 @@ func NewPRUDPServer() *PRUDPServer { pingTimeout: time.Second * 15, CompressionAlgorithm: compression.NewDummyCompression(), byteStreamSettings: NewByteStreamSettings(), + PRUDPV0Settings: NewPRUDPV0Settings(), } } diff --git a/prudp_v0_settings.go b/prudp_v0_settings.go new file mode 100644 index 00000000..b76d2df5 --- /dev/null +++ b/prudp_v0_settings.go @@ -0,0 +1,27 @@ +package nex + +import "net" + +// TODO - We can also breakout the decoding/encoding functions here too, but that would require getters and setters for all packet fields + +// PRUDPV0Settings defines settings for how to handle aspects of PRUDPv0 packets +type PRUDPV0Settings struct { + IsQuazalMode bool + UseEnhancedChecksum bool + ConnectionSignatureCalculator func(packet *PRUDPPacketV0, addr net.Addr) ([]byte, error) + SignatureCalculator func(packet *PRUDPPacketV0, sessionKey, connectionSignature []byte) []byte + DataSignatureCalculator func(packet *PRUDPPacketV0, sessionKey []byte) []byte + ChecksumCalculator func(packet *PRUDPPacketV0, data []byte) uint32 +} + +// NewPRUDPV0Settings returns a new PRUDPV0Settings +func NewPRUDPV0Settings() *PRUDPV0Settings { + return &PRUDPV0Settings{ + IsQuazalMode: false, + UseEnhancedChecksum: false, + ConnectionSignatureCalculator: defaultPRUDPv0ConnectionSignature, + SignatureCalculator: defaultPRUDPv0CalculateSignature, + DataSignatureCalculator: defaultPRUDPv0CalculateDataSignature, + ChecksumCalculator: defaultPRUDPv0CalculateChecksum, + } +} From 04220dfcd18bc5ebc3ccea3ddc375c77c2d2d367 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20L=C3=B3pez=20Guimaraes?= Date: Wed, 3 Jan 2024 00:29:15 +0000 Subject: [PATCH 107/178] types: Add String printing overrides to all types --- types/any_data_holder.go | 24 +++++++++++++++++++++++ types/buffer.go | 5 +++++ types/class_version_container.go | 25 ++++++++++++++++++++++++ types/list.go | 10 +++++++++- types/map.go | 33 ++++++++++++++++++++++++++++++++ types/primitive_bool.go | 7 +++++++ types/primitive_float32.go | 7 +++++++ types/primitive_float64.go | 7 +++++++ types/primitive_s16.go | 7 +++++++ types/primitive_s32.go | 7 +++++++ types/primitive_s64.go | 7 +++++++ types/primitive_s8.go | 7 +++++++ types/primitive_u16.go | 7 +++++++ types/primitive_u32.go | 7 +++++++ types/primitive_u64.go | 7 +++++++ types/primitive_u8.go | 7 +++++++ types/qbuffer.go | 5 +++++ types/result_range.go | 22 +++++++++++++++++++++ types/rv_connection_data.go | 24 +++++++++++++++++++++++ types/string.go | 5 +++++ types/variant.go | 21 ++++++++++++++++++++ 21 files changed, 250 insertions(+), 1 deletion(-) diff --git a/types/any_data_holder.go b/types/any_data_holder.go index e036f841..9f2a5bd2 100644 --- a/types/any_data_holder.go +++ b/types/any_data_holder.go @@ -2,6 +2,7 @@ package types import ( "fmt" + "strings" ) // AnyDataHolderObjects holds a mapping of RVTypes that are accessible in a AnyDataHolder @@ -105,6 +106,29 @@ func (adh *AnyDataHolder) Equals(o RVType) bool { return adh.ObjectData.Equals(other.ObjectData) } +// String returns a string representation of the struct +func (adh *AnyDataHolder) String() string { + return adh.FormatToString(0) +} + +// FormatToString pretty-prints the struct data using the provided indentation level +func (adh *AnyDataHolder) FormatToString(indentationLevel int) string { + indentationValues := strings.Repeat("\t", indentationLevel+1) + indentationEnd := strings.Repeat("\t", indentationLevel) + + var b strings.Builder + + b.WriteString("AnyDataHolder{\n") + b.WriteString(fmt.Sprintf("%sTypeName: %s,\n", indentationValues, adh.TypeName)) + b.WriteString(fmt.Sprintf("%sLength1: %s,\n", indentationValues, adh.Length1)) + b.WriteString(fmt.Sprintf("%sLength2: %s,\n", indentationValues, adh.Length2)) + b.WriteString(fmt.Sprintf("%sObjectData: %s\n", indentationValues, adh.ObjectData)) + + b.WriteString(fmt.Sprintf("%s}", indentationEnd)) + + return b.String() +} + // TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? // NewAnyDataHolder returns a new AnyDataHolder func NewAnyDataHolder() *AnyDataHolder { diff --git a/types/buffer.go b/types/buffer.go index d0c4682d..8501d4ce 100644 --- a/types/buffer.go +++ b/types/buffer.go @@ -52,6 +52,11 @@ func (b *Buffer) Equals(o RVType) bool { return bytes.Equal(b.Value, o.(*Buffer).Value) } +// String returns a string representation of the struct +func (b *Buffer) String() string { + return fmt.Sprintf("%x", b.Value) +} + // NewBuffer returns a new Buffer func NewBuffer(data []byte) *Buffer { return &Buffer{Value: data} diff --git a/types/class_version_container.go b/types/class_version_container.go index 9e2e18d4..2e6050e2 100644 --- a/types/class_version_container.go +++ b/types/class_version_container.go @@ -1,5 +1,10 @@ package types +import ( + "fmt" + "strings" +) + // ClassVersionContainer contains version info for Structures used in verbose RMC messages type ClassVersionContainer struct { Structure @@ -33,6 +38,26 @@ func (cvc *ClassVersionContainer) Equals(o RVType) bool { return cvc.ClassVersions.Equals(o) } +// String returns a string representation of the struct +func (cvc *ClassVersionContainer) String() string { + return cvc.FormatToString(0) +} + +// FormatToString pretty-prints the struct data using the provided indentation level +func (cvc *ClassVersionContainer) FormatToString(indentationLevel int) string { + indentationValues := strings.Repeat("\t", indentationLevel+1) + indentationEnd := strings.Repeat("\t", indentationLevel) + + var b strings.Builder + + b.WriteString("ClassVersionContainer{\n") + b.WriteString(fmt.Sprintf("%sStructureVersion: %d,\n", indentationValues, cvc.StructureVersion)) + b.WriteString(fmt.Sprintf("%sClassVersions: %s\n", indentationValues, cvc.ClassVersions)) + b.WriteString(fmt.Sprintf("%s}", indentationEnd)) + + return b.String() +} + // NewClassVersionContainer returns a new ClassVersionContainer func NewClassVersionContainer() *ClassVersionContainer { cvc := &ClassVersionContainer{ diff --git a/types/list.go b/types/list.go index 4f4861e4..5c720283 100644 --- a/types/list.go +++ b/types/list.go @@ -1,6 +1,9 @@ package types -import "errors" +import ( + "errors" + "fmt" +) // List represents a Quazal Rendez-Vous/NEX List type type List[T RVType] struct { @@ -93,6 +96,11 @@ func (l *List[T]) SetFromData(data []T) { l.real = data } +// String returns a string representation of the struct +func (l *List[T]) String() string { + return fmt.Sprintf("%v", l.real) +} + // NewList returns a new List of the provided type func NewList[T RVType]() *List[T] { return &List[T]{real: make([]T, 0)} diff --git a/types/map.go b/types/map.go index fc321212..19a46e27 100644 --- a/types/map.go +++ b/types/map.go @@ -1,5 +1,10 @@ package types +import ( + "fmt" + "strings" +) + // Map represents a Quazal Rendez-Vous/NEX Map type type Map[K RVType, V RVType] struct { // * Rendez-Vous/NEX MapMap types can have ANY value for the key, but Go requires @@ -142,6 +147,34 @@ func (m *Map[K, V]) Size() int { return len(m.keys) } +// String returns a string representation of the struct +func (m *Map[K, V]) String() string { + return m.FormatToString(0) +} + +// FormatToString pretty-prints the struct data using the provided indentation level +func (m *Map[K, V]) FormatToString(indentationLevel int) string { + indentationValues := strings.Repeat("\t", indentationLevel+1) + indentationEnd := strings.Repeat("\t", indentationLevel) + + var b strings.Builder + + if len(m.keys) == 0 { + b.WriteString(fmt.Sprintf("{}\n")) + } else { + b.WriteString(fmt.Sprintf("{\n")) + + for i := 0; i < len(m.keys); i++ { + // TODO - Special handle the the last item to not add the comma on last item + b.WriteString(fmt.Sprintf("%s%v: %v,\n", indentationValues, m.keys[i], m.values[i])) + } + + b.WriteString(fmt.Sprintf("%s}\n", indentationEnd)) + } + + return b.String() +} + // NewMap returns a new Map of the provided type func NewMap[K RVType, V RVType]() *Map[K, V] { return &Map[K, V]{ diff --git a/types/primitive_bool.go b/types/primitive_bool.go index 8a37a28b..46163664 100644 --- a/types/primitive_bool.go +++ b/types/primitive_bool.go @@ -1,5 +1,7 @@ package types +import "fmt" + // PrimitiveBool is a struct of bool with receiver methods to conform to RVType type PrimitiveBool struct { Value bool @@ -36,6 +38,11 @@ func (b *PrimitiveBool) Equals(o RVType) bool { return b.Value == o.(*PrimitiveBool).Value } +// String returns a string representation of the struct +func (b *PrimitiveBool) String() string { + return fmt.Sprintf("%t", b.Value) +} + // NewPrimitiveBool returns a new PrimitiveBool func NewPrimitiveBool(boolean bool) *PrimitiveBool { return &PrimitiveBool{Value: boolean} diff --git a/types/primitive_float32.go b/types/primitive_float32.go index 3215a655..d3aa57c7 100644 --- a/types/primitive_float32.go +++ b/types/primitive_float32.go @@ -1,5 +1,7 @@ package types +import "fmt" + // PrimitiveF32 is a struct of float32 with receiver methods to conform to RVType type PrimitiveF32 struct { Value float32 @@ -36,6 +38,11 @@ func (f32 *PrimitiveF32) Equals(o RVType) bool { return f32.Value == o.(*PrimitiveF32).Value } +// String returns a string representation of the struct +func (f32 *PrimitiveF32) String() string { + return fmt.Sprintf("%f", f32.Value) +} + // NewPrimitiveF32 returns a new PrimitiveF32 func NewPrimitiveF32(float float32) *PrimitiveF32 { return &PrimitiveF32{Value: float} diff --git a/types/primitive_float64.go b/types/primitive_float64.go index 1477800b..529f3b75 100644 --- a/types/primitive_float64.go +++ b/types/primitive_float64.go @@ -1,5 +1,7 @@ package types +import "fmt" + // PrimitiveF64 is a struct of float64 with receiver methods to conform to RVType type PrimitiveF64 struct { Value float64 @@ -36,6 +38,11 @@ func (f64 *PrimitiveF64) Equals(o RVType) bool { return *f64 == *o.(*PrimitiveF64) } +// String returns a string representation of the struct +func (f64 *PrimitiveF64) String() string { + return fmt.Sprintf("%f", f64.Value) +} + // NewPrimitiveF64 returns a new PrimitiveF64 func NewPrimitiveF64(float float64) *PrimitiveF64 { return &PrimitiveF64{Value: float} diff --git a/types/primitive_s16.go b/types/primitive_s16.go index 406bc5ae..c7ac682b 100644 --- a/types/primitive_s16.go +++ b/types/primitive_s16.go @@ -1,5 +1,7 @@ package types +import "fmt" + // PrimitiveS16 is a struct of int16 with receiver methods to conform to RVType type PrimitiveS16 struct { Value int16 @@ -36,6 +38,11 @@ func (s16 *PrimitiveS16) Equals(o RVType) bool { return s16.Value == o.(*PrimitiveS16).Value } +// String returns a string representation of the struct +func (s16 *PrimitiveS16) String() string { + return fmt.Sprintf("%d", s16.Value) +} + // NewPrimitiveS16 returns a new PrimitiveS16 func NewPrimitiveS16(i16 int16) *PrimitiveS16 { return &PrimitiveS16{Value: i16} diff --git a/types/primitive_s32.go b/types/primitive_s32.go index 2919963c..08b3b13f 100644 --- a/types/primitive_s32.go +++ b/types/primitive_s32.go @@ -1,5 +1,7 @@ package types +import "fmt" + // PrimitiveS32 is a struct of int32 with receiver methods to conform to RVType type PrimitiveS32 struct { Value int32 @@ -36,6 +38,11 @@ func (s32 *PrimitiveS32) Equals(o RVType) bool { return s32.Value == o.(*PrimitiveS32).Value } +// String returns a string representation of the struct +func (s32 *PrimitiveS32) String() string { + return fmt.Sprintf("%d", s32.Value) +} + // NewPrimitiveS32 returns a new PrimitiveS32 func NewPrimitiveS32(i32 int32) *PrimitiveS32 { return &PrimitiveS32{Value: i32} diff --git a/types/primitive_s64.go b/types/primitive_s64.go index c0a07d80..f90a8d0b 100644 --- a/types/primitive_s64.go +++ b/types/primitive_s64.go @@ -1,5 +1,7 @@ package types +import "fmt" + // PrimitiveS64 is a struct of int64 with receiver methods to conform to RVType type PrimitiveS64 struct { Value int64 @@ -36,6 +38,11 @@ func (s64 *PrimitiveS64) Equals(o RVType) bool { return s64.Value == o.(*PrimitiveS64).Value } +// String returns a string representation of the struct +func (s64 *PrimitiveS64) String() string { + return fmt.Sprintf("%d", s64.Value) +} + // NewPrimitiveS64 returns a new PrimitiveS64 func NewPrimitiveS64(i64 int64) *PrimitiveS64 { return &PrimitiveS64{Value: i64} diff --git a/types/primitive_s8.go b/types/primitive_s8.go index c63502ea..23a9ec49 100644 --- a/types/primitive_s8.go +++ b/types/primitive_s8.go @@ -1,5 +1,7 @@ package types +import "fmt" + // PrimitiveS8 is a struct of int8 with receiver methods to conform to RVType type PrimitiveS8 struct { Value int8 @@ -36,6 +38,11 @@ func (s8 *PrimitiveS8) Equals(o RVType) bool { return s8.Value == o.(*PrimitiveS8).Value } +// String returns a string representation of the struct +func (s8 *PrimitiveS8) String() string { + return fmt.Sprintf("%d", s8.Value) +} + // NewPrimitiveS8 returns a new PrimitiveS8 func NewPrimitiveS8(i8 int8) *PrimitiveS8 { return &PrimitiveS8{Value: i8} diff --git a/types/primitive_u16.go b/types/primitive_u16.go index 6cc39ac3..04b949a8 100644 --- a/types/primitive_u16.go +++ b/types/primitive_u16.go @@ -1,5 +1,7 @@ package types +import "fmt" + // PrimitiveU16 is a struct of uint16 with receiver methods to conform to RVType type PrimitiveU16 struct { Value uint16 @@ -36,6 +38,11 @@ func (u16 *PrimitiveU16) Equals(o RVType) bool { return u16.Value == o.(*PrimitiveU16).Value } +// String returns a string representation of the struct +func (u16 *PrimitiveU16) String() string { + return fmt.Sprintf("%d", u16.Value) +} + // NewPrimitiveU16 returns a new PrimitiveU16 func NewPrimitiveU16(ui16 uint16) *PrimitiveU16 { return &PrimitiveU16{Value: ui16} diff --git a/types/primitive_u32.go b/types/primitive_u32.go index 866ce724..9ea72bf8 100644 --- a/types/primitive_u32.go +++ b/types/primitive_u32.go @@ -1,5 +1,7 @@ package types +import "fmt" + // PrimitiveU32 is a struct of uint32 with receiver methods to conform to RVType type PrimitiveU32 struct { Value uint32 @@ -36,6 +38,11 @@ func (u32 *PrimitiveU32) Equals(o RVType) bool { return u32.Value == o.(*PrimitiveU32).Value } +// String returns a string representation of the struct +func (u32 *PrimitiveU32) String() string { + return fmt.Sprintf("%d", u32.Value) +} + // NewPrimitiveU32 returns a new PrimitiveU32 func NewPrimitiveU32(ui32 uint32) *PrimitiveU32 { return &PrimitiveU32{Value: ui32} diff --git a/types/primitive_u64.go b/types/primitive_u64.go index 53a71e7f..8ac31a85 100644 --- a/types/primitive_u64.go +++ b/types/primitive_u64.go @@ -1,5 +1,7 @@ package types +import "fmt" + // PrimitiveU64 is a struct of uint64 with receiver methods to conform to RVType type PrimitiveU64 struct { Value uint64 @@ -36,6 +38,11 @@ func (u64 *PrimitiveU64) Equals(o RVType) bool { return u64.Value == o.(*PrimitiveU64).Value } +// String returns a string representation of the struct +func (u64 *PrimitiveU64) String() string { + return fmt.Sprintf("%d", u64.Value) +} + // NewPrimitiveU64 returns a new PrimitiveU64 func NewPrimitiveU64(ui64 uint64) *PrimitiveU64 { return &PrimitiveU64{Value: ui64} diff --git a/types/primitive_u8.go b/types/primitive_u8.go index e88ee46b..94a2ace4 100644 --- a/types/primitive_u8.go +++ b/types/primitive_u8.go @@ -1,5 +1,7 @@ package types +import "fmt" + // PrimitiveU8 is a struct of uint8 with receiver methods to conform to RVType type PrimitiveU8 struct { @@ -37,6 +39,11 @@ func (u8 *PrimitiveU8) Equals(o RVType) bool { return u8.Value == o.(*PrimitiveU8).Value } +// String returns a string representation of the struct +func (u8 *PrimitiveU8) String() string { + return fmt.Sprintf("%d", u8.Value) +} + // NewPrimitiveU8 returns a new PrimitiveU8 func NewPrimitiveU8(ui8 uint8) *PrimitiveU8 { return &PrimitiveU8{Value: ui8} diff --git a/types/qbuffer.go b/types/qbuffer.go index 97948aae..a3b2c7e3 100644 --- a/types/qbuffer.go +++ b/types/qbuffer.go @@ -52,6 +52,11 @@ func (qb *QBuffer) Equals(o RVType) bool { return bytes.Equal(qb.Value, o.(*QBuffer).Value) } +// String returns a string representation of the struct +func (qb *QBuffer) String() string { + return fmt.Sprintf("%x", qb.Value) +} + // NewQBuffer returns a new QBuffer func NewQBuffer(data []byte) *QBuffer { return &QBuffer{Value: data} diff --git a/types/result_range.go b/types/result_range.go index 4cd1d75e..f702ec62 100644 --- a/types/result_range.go +++ b/types/result_range.go @@ -2,6 +2,7 @@ package types import ( "fmt" + "strings" ) // ResultRange class which holds information about how to make queries @@ -76,6 +77,27 @@ func (rr *ResultRange) Equals(o RVType) bool { return rr.Length.Equals(other.Length) } +// String returns a string representation of the struct +func (rr *ResultRange) String() string { + return rr.FormatToString(0) +} + +// FormatToString pretty-prints the struct data using the provided indentation level +func (rr *ResultRange) FormatToString(indentationLevel int) string { + indentationValues := strings.Repeat("\t", indentationLevel+1) + indentationEnd := strings.Repeat("\t", indentationLevel) + + var b strings.Builder + + b.WriteString("ResultRange{\n") + b.WriteString(fmt.Sprintf("%sStructureVersion: %d,\n", indentationValues, rr.StructureVersion)) + b.WriteString(fmt.Sprintf("%sOffset: %s,\n", indentationValues, rr.Offset)) + b.WriteString(fmt.Sprintf("%sLength: %s\n", indentationValues, rr.Length)) + b.WriteString(fmt.Sprintf("%s}", indentationEnd)) + + return b.String() +} + // NewResultRange returns a new ResultRange func NewResultRange() *ResultRange { return &ResultRange{ diff --git a/types/rv_connection_data.go b/types/rv_connection_data.go index 1113a8b1..5f26401d 100644 --- a/types/rv_connection_data.go +++ b/types/rv_connection_data.go @@ -2,6 +2,7 @@ package types import ( "fmt" + "strings" ) // RVConnectionData is a class which holds data about a Rendez-Vous connection @@ -113,6 +114,29 @@ func (rvcd *RVConnectionData) Equals(o RVType) bool { return true } +// String returns a string representation of the struct +func (rvcd *RVConnectionData) String() string { + return rvcd.FormatToString(0) +} + +// FormatToString pretty-prints the struct data using the provided indentation level +func (rvcd *RVConnectionData) FormatToString(indentationLevel int) string { + indentationValues := strings.Repeat("\t", indentationLevel+1) + indentationEnd := strings.Repeat("\t", indentationLevel) + + var b strings.Builder + + b.WriteString("RVConnectionData{\n") + b.WriteString(fmt.Sprintf("%sStructureVersion: %d,\n", indentationValues, rvcd.StructureVersion)) + b.WriteString(fmt.Sprintf("%sStationURL: %s,\n", indentationValues, rvcd.StationURL.FormatToString(indentationLevel+1))) + b.WriteString(fmt.Sprintf("%sSpecialProtocols: %s,\n", indentationValues, rvcd.SpecialProtocols)) + b.WriteString(fmt.Sprintf("%sStationURLSpecialProtocols: %s,\n", indentationValues, rvcd.StationURLSpecialProtocols.FormatToString(indentationLevel+1))) + b.WriteString(fmt.Sprintf("%sTime: %s\n", indentationValues, rvcd.Time.FormatToString(indentationLevel+1))) + b.WriteString(fmt.Sprintf("%s}", indentationEnd)) + + return b.String() +} + // NewRVConnectionData returns a new RVConnectionData func NewRVConnectionData() *RVConnectionData { rvcd := &RVConnectionData{ diff --git a/types/string.go b/types/string.go index 067ab765..677f1792 100644 --- a/types/string.go +++ b/types/string.go @@ -74,6 +74,11 @@ func (s *String) Equals(o RVType) bool { return s.Value == o.(*String).Value } +// String returns a string representation of the struct +func (s *String) String() string { + return fmt.Sprintf("%q", s.Value) +} + // NewString returns a new String func NewString(str string) *String { return &String{Value: str} diff --git a/types/variant.go b/types/variant.go index 0888ef74..d5c19719 100644 --- a/types/variant.go +++ b/types/variant.go @@ -2,6 +2,7 @@ package types import ( "fmt" + "strings" ) // VariantTypes holds a mapping of RVTypes that are accessible in a Variant @@ -65,6 +66,26 @@ func (v *Variant) Equals(o RVType) bool { return v.Type.Equals(other.Type) } +// String returns a string representation of the struct +func (v *Variant) String() string { + return v.FormatToString(0) +} + +// FormatToString pretty-prints the struct data using the provided indentation level +func (v *Variant) FormatToString(indentationLevel int) string { + indentationValues := strings.Repeat("\t", indentationLevel+1) + indentationEnd := strings.Repeat("\t", indentationLevel) + + var b strings.Builder + + b.WriteString("Variant{\n") + b.WriteString(fmt.Sprintf("%TypeID: %s,\n", indentationValues, v.TypeID)) + b.WriteString(fmt.Sprintf("%Type: %s\n", indentationValues, v.Type)) + b.WriteString(fmt.Sprintf("%s}", indentationEnd)) + + return b.String() +} + // TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? // NewVariant returns a new Variant func NewVariant() *Variant { From 2c4d6aa662ece726546afe6a5c2ba6b8546dae93 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sun, 7 Jan 2024 16:18:32 -0500 Subject: [PATCH 108/178] types: update Godoc comments --- types/any_data_holder.go | 18 ++++++++--------- types/buffer.go | 6 ++++-- types/class_version_container.go | 3 ++- types/data.go | 3 ++- types/datetime.go | 3 ++- types/list.go | 12 +++++++---- types/map.go | 11 +++++++---- types/pid.go | 7 ++++--- types/primitive_bool.go | 2 +- types/primitive_float32.go | 2 +- types/primitive_float64.go | 2 +- types/primitive_s16.go | 2 +- types/primitive_s32.go | 2 +- types/primitive_s64.go | 2 +- types/primitive_s8.go | 2 +- types/primitive_u16.go | 2 +- types/primitive_u32.go | 2 +- types/primitive_u64.go | 2 +- types/primitive_u8.go | 3 +-- types/qbuffer.go | 4 +++- types/quuid.go | 3 ++- types/readable.go | 34 ++++++++++++++++---------------- types/result.go | 6 ++++-- types/result_range.go | 7 ++++--- types/rv_connection_data.go | 9 +++++---- types/rv_type.go | 3 ++- types/station_url.go | 4 ++-- types/string.go | 3 ++- types/structure.go | 2 +- types/variant.go | 4 ++-- types/writable.go | 34 ++++++++++++++++---------------- 31 files changed, 110 insertions(+), 89 deletions(-) diff --git a/types/any_data_holder.go b/types/any_data_holder.go index 9f2a5bd2..288a7f90 100644 --- a/types/any_data_holder.go +++ b/types/any_data_holder.go @@ -13,13 +13,14 @@ func RegisterDataHolderType(name string, rvType RVType) { AnyDataHolderObjects[name] = rvType } -// AnyDataHolder is a class which can contain any Structure. These Structures usually inherit from at least one -// other Structure. Typically this base class is the empty `Data` Structure, but this is not always the case. -// The contained Structures name & length are sent with the Structure body, so the receiver can properly decode it +// AnyDataHolder is a class which can contain any Structure. The official type name and namespace is unknown. +// These Structures usually inherit from at least one other Structure. Typically this base class is the empty +// `Data` Structure, but this is not always the case. The contained Structures name & length are sent with the +// Structure body, so the receiver can properly decode it. type AnyDataHolder struct { TypeName *String - Length1 *PrimitiveU32 - Length2 *PrimitiveU32 + Length1 *PrimitiveU32 // Length of ObjectData + Length2 + Length2 *PrimitiveU32 // Length of ObjectData ObjectData RVType } @@ -71,7 +72,7 @@ func (adh *AnyDataHolder) ExtractFrom(readable Readable) error { return nil } -// Copy returns a new copied instance of DataHolder +// Copy returns a new copied instance of AnyDataHolder func (adh *AnyDataHolder) Copy() RVType { copied := NewAnyDataHolder() @@ -129,12 +130,11 @@ func (adh *AnyDataHolder) FormatToString(indentationLevel int) string { return b.String() } -// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? // NewAnyDataHolder returns a new AnyDataHolder func NewAnyDataHolder() *AnyDataHolder { return &AnyDataHolder{ TypeName: NewString(""), - Length1: NewPrimitiveU32(0), - Length2: NewPrimitiveU32(0), + Length1: NewPrimitiveU32(0), + Length2: NewPrimitiveU32(0), } } diff --git a/types/buffer.go b/types/buffer.go index 8501d4ce..d8e5d89a 100644 --- a/types/buffer.go +++ b/types/buffer.go @@ -5,12 +5,14 @@ import ( "fmt" ) -// Buffer is a struct of []byte with receiver methods to conform to RVType +// Buffer is an implementation of rdv::Buffer. +// Wraps a primitive Go byte slice. +// Same as QBuffer but with a uint32 length field. type Buffer struct { Value []byte } -// WriteTo writes the []byte to the given writable +// WriteTo writes the Buffer to the given writable func (b *Buffer) WriteTo(writable Writable) { length := len(b.Value) diff --git a/types/class_version_container.go b/types/class_version_container.go index 2e6050e2..a9f3cf22 100644 --- a/types/class_version_container.go +++ b/types/class_version_container.go @@ -5,7 +5,8 @@ import ( "strings" ) -// ClassVersionContainer contains version info for Structures used in verbose RMC messages +// ClassVersionContainer is an implementation of rdv::ClassVersionContainer. +// Contains version info for Structures used in verbose RMC messages. type ClassVersionContainer struct { Structure ClassVersions *Map[*String, *PrimitiveU16] diff --git a/types/data.go b/types/data.go index 2262f744..5fe9ed38 100644 --- a/types/data.go +++ b/types/data.go @@ -5,7 +5,8 @@ import ( "strings" ) -// Data is the base class for many other structures. The structure itself has no fields +// Data is an implementation of rdv::Data. +// This structure has no data, and instead acts as the base class for many other structures. type Data struct { Structure } diff --git a/types/datetime.go b/types/datetime.go index 9b4bef98..a8acc022 100644 --- a/types/datetime.go +++ b/types/datetime.go @@ -6,7 +6,8 @@ import ( "time" ) -// DateTime represents a NEX DateTime type +// DateTime is an implementation of rdv::DateTime. +// The underlying value is a uint64 bit field containing date and time information. type DateTime struct { value uint64 } diff --git a/types/list.go b/types/list.go index 5c720283..39c26c54 100644 --- a/types/list.go +++ b/types/list.go @@ -5,13 +5,17 @@ import ( "fmt" ) -// List represents a Quazal Rendez-Vous/NEX List type +// List is an implementation of rdv::qList. +// This data type holds an array of other types. +// +// Unlike Buffer and qBuffer, which use the same data type with differing size field lengths, +// there does not seem to be an official rdv::List type type List[T RVType] struct { real []T Type T } -// WriteTo writes the bool to the given writable +// WriteTo writes the List to the given writable func (l *List[T]) WriteTo(writable Writable) { writable.WritePrimitiveUInt32LE(uint32(len(l.real))) @@ -20,7 +24,7 @@ func (l *List[T]) WriteTo(writable Writable) { } } -// ExtractFrom extracts the bool from the given readable +// ExtractFrom extracts the List from the given readable func (l *List[T]) ExtractFrom(readable Readable) error { length, err := readable.ReadPrimitiveUInt32LE() if err != nil { @@ -43,7 +47,7 @@ func (l *List[T]) ExtractFrom(readable Readable) error { return nil } -// Copy returns a pointer to a copy of the List[T]. Requires type assertion when used +// Copy returns a pointer to a copy of the List. Requires type assertion when used func (l *List[T]) Copy() RVType { copied := NewList[T]() copied.real = make([]T, len(l.real)) diff --git a/types/map.go b/types/map.go index 19a46e27..b192f6a5 100644 --- a/types/map.go +++ b/types/map.go @@ -5,7 +5,10 @@ import ( "strings" ) -// Map represents a Quazal Rendez-Vous/NEX Map type +// Map represents a Quazal Rendez-Vous/NEX Map type. +// +// There is not an official type in either the rdv or nn::nex namespaces. +// The data is stored as an array of key-value pairs. type Map[K RVType, V RVType] struct { // * Rendez-Vous/NEX MapMap types can have ANY value for the key, but Go requires // * map keys to implement the "comparable" constraint. This is not possible with @@ -17,7 +20,7 @@ type Map[K RVType, V RVType] struct { ValueType V } -// WriteTo writes the bool to the given writable +// WriteTo writes the Map to the given writable func (m *Map[K, V]) WriteTo(writable Writable) { writable.WritePrimitiveUInt32LE(uint32(m.Size())) @@ -27,7 +30,7 @@ func (m *Map[K, V]) WriteTo(writable Writable) { } } -// ExtractFrom extracts the bool from the given readable +// ExtractFrom extracts the Map from the given readable func (m *Map[K, V]) ExtractFrom(readable Readable) error { length, err := readable.ReadPrimitiveUInt32LE() if err != nil { @@ -58,7 +61,7 @@ func (m *Map[K, V]) ExtractFrom(readable Readable) error { return nil } -// Copy returns a pointer to a copy of the Map[K, V]. Requires type assertion when used +// Copy returns a pointer to a copy of the Map. Requires type assertion when used func (m *Map[K, V]) Copy() RVType { copied := NewMap[K, V]() copied.keys = make([]K, len(m.keys)) diff --git a/types/pid.go b/types/pid.go index a3310fcd..5c6c8113 100644 --- a/types/pid.go +++ b/types/pid.go @@ -5,7 +5,8 @@ import ( "strings" ) -// PID represents a unique number to identify a user +// PID represents a unique number to identify a user. +// The official library treats this as a primitive integer. // // The true size of this value depends on the client version. // Legacy clients (WiiU/3DS) use a uint32, whereas modern clients (Nintendo Switch) use a uint64. @@ -14,7 +15,7 @@ type PID struct { pid uint64 } -// WriteTo writes the bool to the given writable +// WriteTo writes the PID to the given writable func (p *PID) WriteTo(writable Writable) { if writable.PIDSize() == 8 { writable.WritePrimitiveUInt64LE(p.pid) @@ -23,7 +24,7 @@ func (p *PID) WriteTo(writable Writable) { } } -// ExtractFrom extracts the bool from the given readable +// ExtractFrom extracts the PID from the given readable func (p *PID) ExtractFrom(readable Readable) error { var pid uint64 var err error diff --git a/types/primitive_bool.go b/types/primitive_bool.go index 46163664..717e9a17 100644 --- a/types/primitive_bool.go +++ b/types/primitive_bool.go @@ -2,7 +2,7 @@ package types import "fmt" -// PrimitiveBool is a struct of bool with receiver methods to conform to RVType +// PrimitiveBool is wrapper around a Go primitive bool with receiver methods to conform to RVType type PrimitiveBool struct { Value bool } diff --git a/types/primitive_float32.go b/types/primitive_float32.go index d3aa57c7..1f1e7ef9 100644 --- a/types/primitive_float32.go +++ b/types/primitive_float32.go @@ -2,7 +2,7 @@ package types import "fmt" -// PrimitiveF32 is a struct of float32 with receiver methods to conform to RVType +// PrimitiveF32 is wrapper around a Go primitive float32 with receiver methods to conform to RVType type PrimitiveF32 struct { Value float32 } diff --git a/types/primitive_float64.go b/types/primitive_float64.go index 529f3b75..981a6827 100644 --- a/types/primitive_float64.go +++ b/types/primitive_float64.go @@ -2,7 +2,7 @@ package types import "fmt" -// PrimitiveF64 is a struct of float64 with receiver methods to conform to RVType +// PrimitiveF64 is wrapper around a Go primitive float64 with receiver methods to conform to RVType type PrimitiveF64 struct { Value float64 } diff --git a/types/primitive_s16.go b/types/primitive_s16.go index c7ac682b..6c79085d 100644 --- a/types/primitive_s16.go +++ b/types/primitive_s16.go @@ -2,7 +2,7 @@ package types import "fmt" -// PrimitiveS16 is a struct of int16 with receiver methods to conform to RVType +// PrimitiveS16 is wrapper around a Go primitive int16 with receiver methods to conform to RVType type PrimitiveS16 struct { Value int16 } diff --git a/types/primitive_s32.go b/types/primitive_s32.go index 08b3b13f..b2a2cb54 100644 --- a/types/primitive_s32.go +++ b/types/primitive_s32.go @@ -2,7 +2,7 @@ package types import "fmt" -// PrimitiveS32 is a struct of int32 with receiver methods to conform to RVType +// PrimitiveS32 is wrapper around a Go primitive int32 with receiver methods to conform to RVType type PrimitiveS32 struct { Value int32 } diff --git a/types/primitive_s64.go b/types/primitive_s64.go index f90a8d0b..c97e7d12 100644 --- a/types/primitive_s64.go +++ b/types/primitive_s64.go @@ -2,7 +2,7 @@ package types import "fmt" -// PrimitiveS64 is a struct of int64 with receiver methods to conform to RVType +// PrimitiveS64 is wrapper around a Go primitive int64 with receiver methods to conform to RVType type PrimitiveS64 struct { Value int64 } diff --git a/types/primitive_s8.go b/types/primitive_s8.go index 23a9ec49..3b066c79 100644 --- a/types/primitive_s8.go +++ b/types/primitive_s8.go @@ -2,7 +2,7 @@ package types import "fmt" -// PrimitiveS8 is a struct of int8 with receiver methods to conform to RVType +// PrimitiveS8 is wrapper around a Go primitive int8 with receiver methods to conform to RVType type PrimitiveS8 struct { Value int8 } diff --git a/types/primitive_u16.go b/types/primitive_u16.go index 04b949a8..a83f78eb 100644 --- a/types/primitive_u16.go +++ b/types/primitive_u16.go @@ -2,7 +2,7 @@ package types import "fmt" -// PrimitiveU16 is a struct of uint16 with receiver methods to conform to RVType +// PrimitiveU16 is wrapper around a Go primitive uint16 with receiver methods to conform to RVType type PrimitiveU16 struct { Value uint16 } diff --git a/types/primitive_u32.go b/types/primitive_u32.go index 9ea72bf8..e98a7702 100644 --- a/types/primitive_u32.go +++ b/types/primitive_u32.go @@ -2,7 +2,7 @@ package types import "fmt" -// PrimitiveU32 is a struct of uint32 with receiver methods to conform to RVType +// PrimitiveU32 is wrapper around a Go primitive uint32 with receiver methods to conform to RVType type PrimitiveU32 struct { Value uint32 } diff --git a/types/primitive_u64.go b/types/primitive_u64.go index 8ac31a85..f2d4baca 100644 --- a/types/primitive_u64.go +++ b/types/primitive_u64.go @@ -2,7 +2,7 @@ package types import "fmt" -// PrimitiveU64 is a struct of uint64 with receiver methods to conform to RVType +// PrimitiveU64 is wrapper around a Go primitive uint64 with receiver methods to conform to RVType type PrimitiveU64 struct { Value uint64 } diff --git a/types/primitive_u8.go b/types/primitive_u8.go index 94a2ace4..ccccabc2 100644 --- a/types/primitive_u8.go +++ b/types/primitive_u8.go @@ -2,8 +2,7 @@ package types import "fmt" - -// PrimitiveU8 is a struct of uint8 with receiver methods to conform to RVType +// PrimitiveU8 is wrapper around a Go primitive uint8 with receiver methods to conform to RVType type PrimitiveU8 struct { Value uint8 } diff --git a/types/qbuffer.go b/types/qbuffer.go index a3b2c7e3..edf0755b 100644 --- a/types/qbuffer.go +++ b/types/qbuffer.go @@ -5,7 +5,9 @@ import ( "fmt" ) -// QBuffer is a struct of []byte with receiver methods to conform to RVType +// QBuffer is an implementation of rdv::qBuffer. +// Wraps a primitive Go byte slice. +// Same as Buffer but with a uint16 length field. type QBuffer struct { Value []byte } diff --git a/types/quuid.go b/types/quuid.go index ea49bfc3..c8d581e7 100644 --- a/types/quuid.go +++ b/types/quuid.go @@ -8,7 +8,8 @@ import ( "strings" ) -// QUUID represents a QRV qUUID type. This type encodes a UUID in little-endian byte order +// QUUID is an implementation of rdv::qUUID. +// Encodes a UUID in little-endian byte order. type QUUID struct { Data []byte } diff --git a/types/readable.go b/types/readable.go index c4dfb709..8398bf6c 100644 --- a/types/readable.go +++ b/types/readable.go @@ -2,21 +2,21 @@ package types // Readable represents a struct that types can read from type Readable interface { - StringLengthSize() int - PIDSize() int - UseStructureHeader() bool - Remaining() uint64 - ReadRemaining() []byte - Read(length uint64) ([]byte, error) - ReadPrimitiveUInt8() (uint8, error) - ReadPrimitiveUInt16LE() (uint16, error) - ReadPrimitiveUInt32LE() (uint32, error) - ReadPrimitiveUInt64LE() (uint64, error) - ReadPrimitiveInt8() (int8, error) - ReadPrimitiveInt16LE() (int16, error) - ReadPrimitiveInt32LE() (int32, error) - ReadPrimitiveInt64LE() (int64, error) - ReadPrimitiveFloat32LE() (float32, error) - ReadPrimitiveFloat64LE() (float64, error) - ReadPrimitiveBool() (bool, error) + StringLengthSize() int // Returns the size of the length field for rdv::String types. Only 2 and 4 are valid + PIDSize() int // Returns the size of the length fields for nn::nex::PID types. Only 4 and 8 are valid + UseStructureHeader() bool // Returns whether or not Structure types should use a header + Remaining() uint64 // Returns the number of bytes left unread in the buffer + ReadRemaining() []byte // Reads the remaining data from the buffer + Read(length uint64) ([]byte, error) // Reads up to length bytes of data from the buffer. Returns an error if the read failed, such as if there was not enough data to read + ReadPrimitiveUInt8() (uint8, error) // Reads a primitive Go uint8. Returns an error if the read failed, such as if there was not enough data to read + ReadPrimitiveUInt16LE() (uint16, error) // Reads a primitive Go uint16. Returns an error if the read failed, such as if there was not enough data to read + ReadPrimitiveUInt32LE() (uint32, error) // Reads a primitive Go uint32. Returns an error if the read failed, such as if there was not enough data to read + ReadPrimitiveUInt64LE() (uint64, error) // Reads a primitive Go uint64. Returns an error if the read failed, such as if there was not enough data to read + ReadPrimitiveInt8() (int8, error) // Reads a primitive Go int8. Returns an error if the read failed, such as if there was not enough data to read + ReadPrimitiveInt16LE() (int16, error) // Reads a primitive Go int16. Returns an error if the read failed, such as if there was not enough data to read + ReadPrimitiveInt32LE() (int32, error) // Reads a primitive Go int32. Returns an error if the read failed, such as if there was not enough data to read + ReadPrimitiveInt64LE() (int64, error) // Reads a primitive Go int64. Returns an error if the read failed, such as if there was not enough data to read + ReadPrimitiveFloat32LE() (float32, error) // Reads a primitive Go float32. Returns an error if the read failed, such as if there was not enough data to read + ReadPrimitiveFloat64LE() (float64, error) // Reads a primitive Go float64. Returns an error if the read failed, such as if there was not enough data to read + ReadPrimitiveBool() (bool, error) // Reads a primitive Go bool. Returns an error if the read failed, such as if there was not enough data to read } diff --git a/types/result.go b/types/result.go index cd5a8b33..c813ade8 100644 --- a/types/result.go +++ b/types/result.go @@ -7,9 +7,11 @@ import ( var errorMask = 1 << 31 -// Result is sent in methods which query large objects +// Result is an implementation of nn::Result. +// Determines the result of an operation. +// If the MSB is set the result is an error, otherwise success type Result struct { - Code uint32 // TODO - Replace this with PrimitiveU32? + Code uint32 } // WriteTo writes the Result to the given writable diff --git a/types/result_range.go b/types/result_range.go index f702ec62..c7b76ca6 100644 --- a/types/result_range.go +++ b/types/result_range.go @@ -5,11 +5,12 @@ import ( "strings" ) -// ResultRange class which holds information about how to make queries +// ResultRange is an implementation of rdv::ResultRange. +// Holds information about how to make queries which may return large data. type ResultRange struct { Structure - Offset *PrimitiveU32 - Length *PrimitiveU32 + Offset *PrimitiveU32 // Offset into the dataset + Length *PrimitiveU32 // Number of items to return } // WriteTo writes the ResultRange to the given writable diff --git a/types/rv_connection_data.go b/types/rv_connection_data.go index 5f26401d..4cd57595 100644 --- a/types/rv_connection_data.go +++ b/types/rv_connection_data.go @@ -5,7 +5,8 @@ import ( "strings" ) -// RVConnectionData is a class which holds data about a Rendez-Vous connection +// RVConnectionData is an implementation of rdv::RVConnectionData. +// Contains the locations and data of Rendez-Vous connection. type RVConnectionData struct { Structure StationURL *StationURL @@ -140,10 +141,10 @@ func (rvcd *RVConnectionData) FormatToString(indentationLevel int) string { // NewRVConnectionData returns a new RVConnectionData func NewRVConnectionData() *RVConnectionData { rvcd := &RVConnectionData{ - StationURL: NewStationURL(""), - SpecialProtocols: NewList[*PrimitiveU8](), + StationURL: NewStationURL(""), + SpecialProtocols: NewList[*PrimitiveU8](), StationURLSpecialProtocols: NewStationURL(""), - Time: NewDateTime(0), + Time: NewDateTime(0), } rvcd.SpecialProtocols.Type = NewPrimitiveU8(0) diff --git a/types/rv_type.go b/types/rv_type.go index dea5be61..675df1d0 100644 --- a/types/rv_type.go +++ b/types/rv_type.go @@ -1,7 +1,8 @@ // Package types provides types used in Quazal Rendez-Vous/NEX package types -// RVType represents a Quazal Rendez-Vous/NEX type. This includes primitives and custom types +// RVType represents a Quazal Rendez-Vous/NEX type. +// This includes primitives and custom types. type RVType interface { WriteTo(writable Writable) ExtractFrom(readable Readable) error diff --git a/types/station_url.go b/types/station_url.go index 93e252d2..3a86a906 100644 --- a/types/station_url.go +++ b/types/station_url.go @@ -5,7 +5,8 @@ import ( "strings" ) -// StationURL contains the data for a NEX station URL +// StationURL is an implementation of rdv::StationURL. +// Contains location of a station to connect to, with data about how to connect. type StationURL struct { local bool // * Not part of the data structure. Used for easier lookups elsewhere public bool // * Not part of the data structure. Used for easier lookups elsewhere @@ -156,7 +157,6 @@ func (s *StationURL) FormatToString(indentationLevel int) string { return b.String() } -// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? // NewStationURL returns a new StationURL func NewStationURL(str string) *StationURL { stationURL := &StationURL{ diff --git a/types/string.go b/types/string.go index 677f1792..ef52f1c1 100644 --- a/types/string.go +++ b/types/string.go @@ -6,7 +6,8 @@ import ( "strings" ) -// String is a struct of string with receiver methods to conform to RVType +// String is an implementation of rdv::String. +// Wraps a primitive Go string. type String struct { Value string } diff --git a/types/structure.go b/types/structure.go index 19ad2388..ce3936f2 100644 --- a/types/structure.go +++ b/types/structure.go @@ -5,7 +5,7 @@ import ( "fmt" ) -// Structure represents a Quazal Rendez-Vous/NEX Structure (custom class) base struct +// Structure represents a Quazal Rendez-Vous/NEX Structure (custom class) base struct. type Structure struct { StructureVersion uint8 } diff --git a/types/variant.go b/types/variant.go index d5c19719..8ff0e4eb 100644 --- a/types/variant.go +++ b/types/variant.go @@ -13,7 +13,8 @@ func RegisterVariantType(id uint8, rvType RVType) { VariantTypes[id] = rvType } -// Variant is a type which can old many other types +// Variant is an implementation of rdv::Variant. +// This type can hold many other types, denoted by a type ID. type Variant struct { TypeID *PrimitiveU8 Type RVType @@ -86,7 +87,6 @@ func (v *Variant) FormatToString(indentationLevel int) string { return b.String() } -// TODO - Should this take in a default value, or take in nothing and have a "SetFromData"-kind of method? // NewVariant returns a new Variant func NewVariant() *Variant { return &Variant{ diff --git a/types/writable.go b/types/writable.go index 0d3b584b..29f20664 100644 --- a/types/writable.go +++ b/types/writable.go @@ -2,21 +2,21 @@ package types // Writable represents a struct that types can write to type Writable interface { - StringLengthSize() int - PIDSize() int - UseStructureHeader() bool - CopyNew() Writable - Write(data []byte) - WritePrimitiveUInt8(value uint8) - WritePrimitiveUInt16LE(value uint16) - WritePrimitiveUInt32LE(value uint32) - WritePrimitiveUInt64LE(value uint64) - WritePrimitiveInt8(value int8) - WritePrimitiveInt16LE(value int16) - WritePrimitiveInt32LE(value int32) - WritePrimitiveInt64LE(value int64) - WritePrimitiveFloat32LE(value float32) - WritePrimitiveFloat64LE(value float64) - WritePrimitiveBool(value bool) - Bytes() []byte + StringLengthSize() int // Returns the size of the length field for rdv::String types. Only 2 and 4 are valid + PIDSize() int // Returns the size of the length fields for nn::nex::PID types. Only 4 and 8 are valid + UseStructureHeader() bool // Returns whether or not Structure types should use a header + CopyNew() Writable // Returns a new Writable with the same settings, but an empty buffer + Write(data []byte) // Writes the provided data to the buffer + WritePrimitiveUInt8(value uint8) // Writes a primitive Go uint8 + WritePrimitiveUInt16LE(value uint16) // Writes a primitive Go uint16 + WritePrimitiveUInt32LE(value uint32) // Writes a primitive Go uint32 + WritePrimitiveUInt64LE(value uint64) // Writes a primitive Go uint64 + WritePrimitiveInt8(value int8) // Writes a primitive Go int8 + WritePrimitiveInt16LE(value int16) // Writes a primitive Go int16 + WritePrimitiveInt32LE(value int32) // Writes a primitive Go int32 + WritePrimitiveInt64LE(value int64) // Writes a primitive Go int64 + WritePrimitiveFloat32LE(value float32) // Writes a primitive Go float32 + WritePrimitiveFloat64LE(value float64) // Writes a primitive Go float64 + WritePrimitiveBool(value bool) // Writes a primitive Go bool + Bytes() []byte // Returns the data written t othe buffer } From 7c38a411fed7c978e6517a5a86da52ee36f5c7d0 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sun, 7 Jan 2024 16:19:24 -0500 Subject: [PATCH 109/178] types: update Map FormatToString to remove linter warning --- types/map.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/types/map.go b/types/map.go index b192f6a5..5f0a6c49 100644 --- a/types/map.go +++ b/types/map.go @@ -163,9 +163,9 @@ func (m *Map[K, V]) FormatToString(indentationLevel int) string { var b strings.Builder if len(m.keys) == 0 { - b.WriteString(fmt.Sprintf("{}\n")) + b.WriteString("{}\n") } else { - b.WriteString(fmt.Sprintf("{\n")) + b.WriteString("{\n") for i := 0; i < len(m.keys); i++ { // TODO - Special handle the the last item to not add the comma on last item From b82f545a039f86295c43ba41da16f8fd03e1484f Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sun, 7 Jan 2024 16:52:30 -0500 Subject: [PATCH 110/178] types: update some bad Equals methods --- types/data.go | 8 +++++++- types/primitive_float64.go | 2 +- types/structure.go | 4 +++- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/types/data.go b/types/data.go index 5fe9ed38..f20ef34e 100644 --- a/types/data.go +++ b/types/data.go @@ -39,7 +39,13 @@ func (d *Data) Equals(o RVType) bool { return false } - return (*d).StructureVersion == (*o.(*Data)).StructureVersion + other := o.(*Data) + + if d.StructureVersion == other.StructureVersion { + return false + } + + return d.StructureContentLength == other.StructureContentLength } // String returns a string representation of the struct diff --git a/types/primitive_float64.go b/types/primitive_float64.go index 981a6827..54e0b527 100644 --- a/types/primitive_float64.go +++ b/types/primitive_float64.go @@ -35,7 +35,7 @@ func (f64 *PrimitiveF64) Equals(o RVType) bool { return false } - return *f64 == *o.(*PrimitiveF64) + return f64.Value == o.(*PrimitiveF64).Value } // String returns a string representation of the struct diff --git a/types/structure.go b/types/structure.go index ce3936f2..aa521c63 100644 --- a/types/structure.go +++ b/types/structure.go @@ -7,7 +7,8 @@ import ( // Structure represents a Quazal Rendez-Vous/NEX Structure (custom class) base struct. type Structure struct { - StructureVersion uint8 + StructureVersion uint8 + StructureContentLength uint32 } // ExtractHeaderFrom extracts the structure header from the given readable @@ -28,6 +29,7 @@ func (s *Structure) ExtractHeaderFrom(readable Readable) error { } s.StructureVersion = version + s.StructureContentLength = contentLength } return nil From 9cec5d9e9ccc9f6686b5812f26bf873c72b83dad Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Mon, 15 Jan 2024 15:01:26 -0500 Subject: [PATCH 111/178] prudp: completely redo virtual connections --- compression/algorithm.go | 1 + compression/dummy.go | 5 + compression/lzo.go | 10 + compression/zlib.go | 5 + encryption/algorithm.go | 12 + encryption/dummy.go | 44 ++ encryption/rc4.go | 95 +++++ prudp_client.go | 230 ----------- prudp_connection.go | 218 ++++++++++ prudp_endpoint.go | 603 +++++++++++++++++++++++++++ prudp_packet.go | 94 +++-- prudp_packet_interface.go | 16 +- prudp_packet_lite.go | 92 +++-- prudp_packet_v0.go | 31 +- prudp_packet_v1.go | 31 +- prudp_server.go | 738 +++++----------------------------- prudp_virtual_stream_types.go | 36 -- resend_scheduler.go | 42 +- server_interface.go | 1 - sliding_window.go | 81 ++++ socket_connection.go | 26 ++ stream_settings.go | 70 ++++ stream_type.go | 49 +++ test/auth.go | 26 +- test/secure.go | 51 +-- virtual_port.go | 31 ++ websocket_server.go | 39 +- 27 files changed, 1563 insertions(+), 1114 deletions(-) create mode 100644 encryption/algorithm.go create mode 100644 encryption/dummy.go create mode 100644 encryption/rc4.go delete mode 100644 prudp_client.go create mode 100644 prudp_connection.go create mode 100644 prudp_endpoint.go delete mode 100644 prudp_virtual_stream_types.go create mode 100644 sliding_window.go create mode 100644 socket_connection.go create mode 100644 stream_settings.go create mode 100644 stream_type.go create mode 100644 virtual_port.go diff --git a/compression/algorithm.go b/compression/algorithm.go index 6a09124a..7e527690 100644 --- a/compression/algorithm.go +++ b/compression/algorithm.go @@ -6,4 +6,5 @@ package compression type Algorithm interface { Compress(payload []byte) ([]byte, error) Decompress(payload []byte) ([]byte, error) + Copy() Algorithm } diff --git a/compression/dummy.go b/compression/dummy.go index 33a1080b..8cca9f37 100644 --- a/compression/dummy.go +++ b/compression/dummy.go @@ -13,6 +13,11 @@ func (d *Dummy) Decompress(payload []byte) ([]byte, error) { return payload, nil } +// Copy returns a copy of the algorithm +func (d *Dummy) Copy() Algorithm { + return NewDummyCompression() +} + // NewDummyCompression returns a new instance of the Dummy compression func NewDummyCompression() *Dummy { return &Dummy{} diff --git a/compression/lzo.go b/compression/lzo.go index 55c4a284..5e93d3eb 100644 --- a/compression/lzo.go +++ b/compression/lzo.go @@ -79,3 +79,13 @@ func (l *LZO) Decompress(payload []byte) ([]byte, error) { return decompressedBytes, nil } + +// Copy returns a copy of the algorithm +func (l *LZO) Copy() Algorithm { + return NewLZOCompression() +} + +// NewLZOCompression returns a new instance of the LZO compression +func NewLZOCompression() *LZO { + return &LZO{} +} diff --git a/compression/zlib.go b/compression/zlib.go index 82dd096a..f570d84e 100644 --- a/compression/zlib.go +++ b/compression/zlib.go @@ -77,6 +77,11 @@ func (z *Zlib) Decompress(payload []byte) ([]byte, error) { return decompressedBytes, nil } +// Copy returns a copy of the algorithm +func (z *Zlib) Copy() Algorithm { + return NewZlibCompression() +} + // NewZlibCompression returns a new instance of the Zlib compression func NewZlibCompression() *Zlib { return &Zlib{} diff --git a/encryption/algorithm.go b/encryption/algorithm.go new file mode 100644 index 00000000..8bf94310 --- /dev/null +++ b/encryption/algorithm.go @@ -0,0 +1,12 @@ +// Package encryption provides a set of encryption algorithms found +// in several versions of Rendez-Vous for encrypting payloads +package encryption + +// Algorithm defines all the methods a compression algorithm should have +type Algorithm interface { + Key() []byte + SetKey(key []byte) error + Encrypt(payload []byte) ([]byte, error) + Decrypt(payload []byte) ([]byte, error) + Copy() Algorithm +} diff --git a/encryption/dummy.go b/encryption/dummy.go new file mode 100644 index 00000000..02625e99 --- /dev/null +++ b/encryption/dummy.go @@ -0,0 +1,44 @@ +package encryption + +// Dummy does no encryption. Payloads are returned as-is +type Dummy struct { + key []byte +} + +// Key returns the crypto key +func (d *Dummy) Key() []byte { + return d.key +} + +// SetKey sets the crypto key +func (d *Dummy) SetKey(key []byte) error { + d.key = key + + return nil +} + +// Encrypt does nothing +func (d *Dummy) Encrypt(payload []byte) ([]byte, error) { + return payload, nil +} + +// Decrypt does nothing +func (d *Dummy) Decrypt(payload []byte) ([]byte, error) { + return payload, nil +} + +// Copy returns a copy of the algorithm while retaining it's state +func (d *Dummy) Copy() Algorithm { + copied := NewDummyEncryption() + + copied.key = d.key + + return copied +} + +// NewDummyEncryption returns a new instance of the Dummy encryption +func NewDummyEncryption() *Dummy { + return &Dummy{ + key: make([]byte, 0), + } +} diff --git a/encryption/rc4.go b/encryption/rc4.go new file mode 100644 index 00000000..c4b16b9a --- /dev/null +++ b/encryption/rc4.go @@ -0,0 +1,95 @@ +package encryption + +import ( + "crypto/rc4" +) + +// RC4 does no encryption. Payloads are returned as-is +type RC4 struct { + key []byte + cipher *rc4.Cipher + decipher *rc4.Cipher + cipheredCount uint64 + decipheredCount uint64 +} + +// Key returns the crypto key +func (r *RC4) Key() []byte { + return r.key +} + +// SetKey sets the crypto key and updates the ciphers +func (r *RC4) SetKey(key []byte) error { + r.key = key + + cipher, err := rc4.NewCipher(key) + if err != nil { + return err + } + + decipher, err := rc4.NewCipher(key) + if err != nil { + return err + } + + r.cipher = cipher + r.decipher = decipher + + return nil +} + +// Encrypt encrypts the payload with the outgoing RC4 stream +func (r *RC4) Encrypt(payload []byte) ([]byte, error) { + ciphered := make([]byte, len(payload)) + + r.cipher.XORKeyStream(ciphered, payload) + + r.cipheredCount += uint64(len(payload)) + + return ciphered, nil +} + +// Decrypt decrypts the payload with the incoming RC4 stream +func (r *RC4) Decrypt(payload []byte) ([]byte, error) { + deciphered := make([]byte, len(payload)) + + r.decipher.XORKeyStream(deciphered, payload) + + r.decipheredCount += uint64(len(payload)) + + return deciphered, nil +} + +// Copy returns a copy of the algorithm while retaining it's state +func (r *RC4) Copy() Algorithm { + copied := NewRC4Encryption() + + copied.SetKey(r.key) + + // * crypto/rc4 does not expose a way to directly copy streams and retain their state. + // * This just discards the number of iterations done in the original ciphers to sync + // * the copied ciphers states to the original + for i := 0; i < int(r.cipheredCount); i++ { + copied.cipher.XORKeyStream([]byte{0}, []byte{0}) + } + + for i := 0; i < int(r.decipheredCount); i++ { + copied.decipher.XORKeyStream([]byte{0}, []byte{0}) + } + + copied.cipheredCount = r.cipheredCount + copied.decipheredCount = r.decipheredCount + + return copied +} + +// NewRC4Encryption returns a new instance of the RC4 encryption +func NewRC4Encryption() *RC4 { + encryption := &RC4{ + key: make([]byte, 0), + } + + encryption.SetKey([]byte("CD&ML")) // TODO - Make this configurable? + + return encryption +} diff --git a/prudp_client.go b/prudp_client.go deleted file mode 100644 index 8e0d9eda..00000000 --- a/prudp_client.go +++ /dev/null @@ -1,230 +0,0 @@ -package nex - -import ( - "crypto/md5" - "fmt" - "net" - "time" - - "github.com/PretendoNetwork/nex-go/types" - "github.com/lxzan/gws" -) - -// PRUDPClient represents a single PRUDP client -type PRUDPClient struct { - server *PRUDPServer - address net.Addr - webSocketConnection *gws.Conn - pid *types.PID - clientConnectionSignature []byte - serverConnectionSignature []byte - clientSessionID uint8 - serverSessionID uint8 - sessionKey []byte - reliableSubstreams []*ReliablePacketSubstreamManager - outgoingUnreliableSequenceIDCounter *Counter[uint16] - outgoingPingSequenceIDCounter *Counter[uint16] - heartbeatTimer *time.Timer - pingKickTimer *time.Timer - SourceStreamType uint8 - SourcePort uint8 - DestinationStreamType uint8 - DestinationPort uint8 - minorVersion uint32 // * Not currently used for anything, but maybe useful later? - supportedFunctions uint32 // * Not currently used for anything, but maybe useful later? - ConnectionID uint32 - StationURLs []*types.StationURL - unreliableBaseKey []byte -} - -// reset sets the client back to it's default state -func (c *PRUDPClient) reset() { - for _, substream := range c.reliableSubstreams { - substream.ResendScheduler.Stop() - } - - c.clientConnectionSignature = make([]byte, 0) - c.serverConnectionSignature = make([]byte, 0) - c.sessionKey = make([]byte, 0) - c.reliableSubstreams = make([]*ReliablePacketSubstreamManager, 0) - c.outgoingUnreliableSequenceIDCounter = NewCounter[uint16](1) - c.outgoingPingSequenceIDCounter = NewCounter[uint16](0) - c.SourceStreamType = 0 - c.SourcePort = 0 - c.DestinationStreamType = 0 - c.DestinationPort = 0 -} - -// cleanup cleans up any resources the client may be using -// -// This is similar to Client.reset(), with the key difference -// being that cleanup does not care about the state the client -// is currently in, or will be in, after execution. It only -// frees resources that are not easily garbage collected -func (c *PRUDPClient) cleanup() { - for _, substream := range c.reliableSubstreams { - substream.ResendScheduler.Stop() - } - - c.reliableSubstreams = make([]*ReliablePacketSubstreamManager, 0) - c.stopHeartbeatTimers() - - if c.webSocketConnection != nil { - c.webSocketConnection.NetConn().Close() // TODO - Swap this out for WriteClose() to send a close frame? - } - - c.server.emitRemoved(c) -} - -// Server returns the server the client is connecting to -func (c *PRUDPClient) Server() ServerInterface { - return c.server -} - -// Address returns the clients address as a net.Addr -func (c *PRUDPClient) Address() net.Addr { - return c.address -} - -// PID returns the clients NEX PID -func (c *PRUDPClient) PID() *types.PID { - return c.pid -} - -// SetPID sets the clients NEX PID -func (c *PRUDPClient) SetPID(pid *types.PID) { - c.pid = pid -} - -// setSessionKey sets the clients session key used for reliable RC4 ciphers -func (c *PRUDPClient) setSessionKey(sessionKey []byte) { - c.sessionKey = sessionKey - - c.reliableSubstreams[0].SetCipherKey(sessionKey) - - // * Only the first substream uses the session key directly. - // * All other substreams modify the key before it so that - // * all substreams have a unique cipher key - for _, substream := range c.reliableSubstreams[1:] { - modifier := len(sessionKey)/2 + 1 - - // * Create a new slice to avoid modifying past keys - sessionKey = append(make([]byte, 0), sessionKey...) - - // * Only the first half of the key is modified - for i := 0; i < len(sessionKey)/2; i++ { - sessionKey[i] = (sessionKey[i] + byte(modifier-i)) & 0xFF - } - - substream.SetCipherKey(sessionKey) - } - - // * Init the base key used for unreliable DATA packets. - // * - // * Since unreliable DATA packets can come in out of - // * order, each packet uses a dedicated RC4 stream. The - // * key of each RC4 stream is made up by using this base - // * key, modified using the packets sequence/session IDs - unreliableBaseKeyPart1 := md5.Sum(append(sessionKey, []byte{0x18, 0xD8, 0x23, 0x34, 0x37, 0xE4, 0xE3, 0xFE}...)) - unreliableBaseKeyPart2 := md5.Sum(append(sessionKey, []byte{0x23, 0x3E, 0x60, 0x01, 0x23, 0xCD, 0xAB, 0x80}...)) - - c.unreliableBaseKey = append(unreliableBaseKeyPart1[:], unreliableBaseKeyPart2[:]...) -} - -// reliableSubstream returns the clients reliable substream ID -func (c *PRUDPClient) reliableSubstream(substreamID uint8) *ReliablePacketSubstreamManager { - // * Fail-safe. The client may not always have - // * the correct number of substreams. See the - // * comment in handleSocketMessage of PRUDPServer - // * for more details - if int(substreamID) >= len(c.reliableSubstreams) { - return c.reliableSubstreams[0] - } else { - return c.reliableSubstreams[substreamID] - } -} - -// createReliableSubstreams creates the list of substreams used for reliable PRUDP packets -func (c *PRUDPClient) createReliableSubstreams(maxSubstreamID uint8) { - // * Kill any existing substreams - for _, substream := range c.reliableSubstreams { - substream.ResendScheduler.Stop() - } - - substreams := maxSubstreamID + 1 - - c.reliableSubstreams = make([]*ReliablePacketSubstreamManager, substreams) - - for i := 0; i < len(c.reliableSubstreams); i++ { - // * First DATA packet from the client has sequence ID 2 - // * First DATA packet from the server has sequence ID 1 (starts counter at 0 and is incremeneted) - c.reliableSubstreams[i] = NewReliablePacketSubstreamManager(2, 0) - } -} - -func (c *PRUDPClient) nextOutgoingUnreliableSequenceID() uint16 { - return c.outgoingUnreliableSequenceIDCounter.Next() -} - -func (c *PRUDPClient) nextOutgoingPingSequenceID() uint16 { - return c.outgoingPingSequenceIDCounter.Next() -} - -func (c *PRUDPClient) resetHeartbeat() { - if c.pingKickTimer != nil { - c.pingKickTimer.Stop() - } - - if c.heartbeatTimer != nil { - c.heartbeatTimer.Reset(c.server.pingTimeout) - } -} - -func (c *PRUDPClient) startHeartbeat() { - server := c.server - - // * Every time a packet is sent, client.resetHeartbeat() - // * is called which resets this timer. If this function - // * ever executes, it means we haven't seen the client - // * in the expected time frame. If this happens, send - // * the client a PING packet to try and kick start the - // * heartbeat again - c.heartbeatTimer = time.AfterFunc(server.pingTimeout, func() { - server.sendPing(c) - - // * If the heartbeat still did not restart, assume the - // * client is dead and clean up - c.pingKickTimer = time.AfterFunc(server.pingTimeout, func() { - c.cleanup() // * "removed" event is dispatched here - - virtualServer, _ := c.server.virtualServers.Get(c.DestinationPort) - virtualServerStream, _ := virtualServer.Get(c.DestinationStreamType) - - discriminator := fmt.Sprintf("%s-%d-%d", c.address.String(), c.SourcePort, c.SourceStreamType) - - virtualServerStream.Delete(discriminator) - }) - }) -} - -func (c *PRUDPClient) stopHeartbeatTimers() { - if c.pingKickTimer != nil { - c.pingKickTimer.Stop() - } - - if c.heartbeatTimer != nil { - c.heartbeatTimer.Stop() - } -} - -// NewPRUDPClient creates and returns a new PRUDPClient -func NewPRUDPClient(server *PRUDPServer, address net.Addr, webSocketConnection *gws.Conn) *PRUDPClient { - return &PRUDPClient{ - server: server, - address: address, - webSocketConnection: webSocketConnection, - outgoingPingSequenceIDCounter: NewCounter[uint16](0), - pid: types.NewPID(0), - unreliableBaseKey: make([]byte, 0x20), - } -} diff --git a/prudp_connection.go b/prudp_connection.go new file mode 100644 index 00000000..f0db7fc0 --- /dev/null +++ b/prudp_connection.go @@ -0,0 +1,218 @@ +package nex + +import ( + "crypto/md5" + "fmt" + "net" + "time" + + "github.com/PretendoNetwork/nex-go/types" +) + +// PRUDPConnection implements an individual PRUDP virtual connection. +// Does not necessarily represent a socket connection. +// A single network socket may be used to open multiple PRUDP virtual connections +type PRUDPConnection struct { + Socket *SocketConnection // * The connections parent socket + Endpoint *PRUDPEndPoint // * The PRUDP endpoint the connection is connected to + ID uint32 // * Connection ID + SessionID uint8 // * Random value generated at the start of the session. Client and server IDs do not need to match + ServerSessionID uint8 // * Random value generated at the start of the session. Client and server IDs do not need to match + SessionKey []byte // * Secret key generated at the start of the session. Used for encrypting packets to the secure server + pid *types.PID // * PID of the user + DefaultPRUDPVersion int // * The PRUDP version the connection was established with. Used for sending PING packets + StreamType StreamType // * rdv::Stream::Type used in this connection + StreamID uint8 // * rdv::Stream ID, also called the "port number", used in this connection. 0-15 on PRUDPv0/v1, and 0-31 on PRUDPLite + StreamSettings *StreamSettings // * Settings for this virtual connection + Signature []byte // * Connection signature for packets coming from the client, as seen by the server + ServerConnectionSignature []byte // * Connection signature for packets coming from the server, as seen by the client + UnreliablePacketBaseKey []byte // * The base key used for encrypting unreliable DATA packets + slidingWindows *MutexMap[uint8, *SlidingWindow] // * Reliable packet substreams + outgoingUnreliableSequenceIDCounter *Counter[uint16] + outgoingPingSequenceIDCounter *Counter[uint16] + heartbeatTimer *time.Timer + pingKickTimer *time.Timer +} + +// Server returns the PRUDP server the connections socket is connected to +func (pc *PRUDPConnection) Server() ServerInterface { + return pc.Socket.Server +} + +// Address returns the socket address of the connection +func (pc *PRUDPConnection) Address() net.Addr { + return pc.Socket.Address +} + +// PID returns the clients unique PID +func (pc *PRUDPConnection) PID() *types.PID { + return pc.pid +} + +// SetPID sets the clients unique PID +func (pc *PRUDPConnection) SetPID(pid *types.PID) { + pc.pid = pid +} + +// reset resets the connection state to all zero values +func (pc *PRUDPConnection) reset() { + pc.slidingWindows.Clear(func(_ uint8, slidingWindow *SlidingWindow) { + slidingWindow.ResendScheduler.Stop() + }) + + pc.Signature = make([]byte, 0) + pc.ServerConnectionSignature = make([]byte, 0) + pc.SessionKey = make([]byte, 0) + pc.outgoingUnreliableSequenceIDCounter = NewCounter[uint16](1) + pc.outgoingPingSequenceIDCounter = NewCounter[uint16](0) +} + +// cleanup resets the connection state and cleans up some resources. Used when a client is considered dead and to be removed from the endpoint +func (pc *PRUDPConnection) cleanup() { + pc.reset() + + pc.stopHeartbeatTimers() + + pc.Socket.Connections.Delete(pc.SessionID) + + pc.Endpoint.emitConnectionEnded(pc) + + if pc.Socket.Connections.Size() == 0 { + // * No more PRUDP connections, assume the socket connection is also closed + pc.Endpoint.Server.Connections.Delete(pc.Socket.Address.String()) + // TODO - Is there any other cleanup that needs to happen here? + // TODO - Should we add an event for when a socket closes too? + } +} + +// InitializeSlidingWindows returns the InitializeSlidingWindows for the given substream +func (pc *PRUDPConnection) InitializeSlidingWindows(maxSubstreamID uint8) { + // * Nuke any existing SlidingWindows + pc.slidingWindows = NewMutexMap[uint8, *SlidingWindow]() + + for i := 0; i < int(maxSubstreamID+1); i++ { + pc.CreateSlidingWindow(uint8(i)) + } +} + +// CreateSlidingWindow returns the CreateSlidingWindow for the given substream +func (pc *PRUDPConnection) CreateSlidingWindow(substreamID uint8) *SlidingWindow { + slidingWindow := NewSlidingWindow() + slidingWindow.incomingSequenceIDCounter = NewCounter[uint16](2) // * First DATA packet from the client has sequence ID 2 + slidingWindow.outgoingSequenceIDCounter = NewCounter[uint16](0) // * First DATA packet from the server has sequence ID 1 (start counter at 0 and is incremeneted) + slidingWindow.streamSettings = pc.StreamSettings.Copy() + + pc.slidingWindows.Set(substreamID, slidingWindow) + + return slidingWindow +} + +// SlidingWindow returns the SlidingWindow for the given substream +func (pc *PRUDPConnection) SlidingWindow(substreamID uint8) *SlidingWindow { + slidingWindow, ok := pc.slidingWindows.Get(substreamID) + if !ok { + // * Fail-safe. The connection may not always have + // * the correct number of substreams. See the + // * comment in handleSocketMessage of PRUDPEndPoint + // * for more details + slidingWindow = pc.CreateSlidingWindow(substreamID) + } + + return slidingWindow +} + +// setSessionKey sets the connection's session key and updates the SlidingWindows +func (pc *PRUDPConnection) setSessionKey(sessionKey []byte) { + pc.SessionKey = sessionKey + + pc.slidingWindows.Each(func(substreamID uint8, slidingWindow *SlidingWindow) bool { + // * Only the first substream uses the session key directly. + // * All other substreams modify the key before it so that + // * all substreams have a unique cipher key + + if substreamID == 0 { + slidingWindow.SetCipherKey(sessionKey) + } else { + modifier := len(sessionKey)/2 + 1 + + // * Create a new slice to avoid modifying past keys + sessionKey = append(make([]byte, 0), sessionKey...) + + // * Only the first half of the key is modified + for i := 0; i < len(sessionKey)/2; i++ { + sessionKey[i] = (sessionKey[i] + byte(modifier-i)) & 0xFF + } + + slidingWindow.SetCipherKey(sessionKey) + } + + return false + }) + + // * Init the base key used for unreliable DATA packets. + // * + // * Since unreliable DATA packets can come in out of + // * order, each packet uses a dedicated RC4 stream. The + // * key of each RC4 stream is made up by using this base + // * key, modified using the packets sequence/session IDs + unreliableBaseKeyPart1 := md5.Sum(append(sessionKey, []byte{0x18, 0xD8, 0x23, 0x34, 0x37, 0xE4, 0xE3, 0xFE}...)) + unreliableBaseKeyPart2 := md5.Sum(append(sessionKey, []byte{0x23, 0x3E, 0x60, 0x01, 0x23, 0xCD, 0xAB, 0x80}...)) + + pc.UnreliablePacketBaseKey = append(unreliableBaseKeyPart1[:], unreliableBaseKeyPart2[:]...) +} + +func (pc *PRUDPConnection) resetHeartbeat() { + if pc.pingKickTimer != nil { + pc.pingKickTimer.Stop() + } + + if pc.heartbeatTimer != nil { + pc.heartbeatTimer.Reset(pc.Endpoint.Server.pingTimeout) // TODO - This is part of StreamSettings + } +} + +func (pc *PRUDPConnection) startHeartbeat() { + endpoint := pc.Endpoint + server := endpoint.Server + + // * Every time a packet is sent, connection.resetHeartbeat() + // * is called which resets this timer. If this function + // * ever executes, it means we haven't seen the client + // * in the expected time frame. If this happens, send + // * the client a PING packet to try and kick start the + // * heartbeat again + pc.heartbeatTimer = time.AfterFunc(server.pingTimeout, func() { + endpoint.sendPing(pc) + + // * If the heartbeat still did not restart, assume the + // * connection is dead and clean up + pc.pingKickTimer = time.AfterFunc(server.pingTimeout, func() { + pc.cleanup() // * "removed" event is dispatched here + + discriminator := fmt.Sprintf("%s-%d-%d", pc.Socket.Address.String(), pc.StreamType, pc.StreamID) + + endpoint.Connections.Delete(discriminator) + }) + }) +} + +func (pc *PRUDPConnection) stopHeartbeatTimers() { + if pc.pingKickTimer != nil { + pc.pingKickTimer.Stop() + } + + if pc.heartbeatTimer != nil { + pc.heartbeatTimer.Stop() + } +} + +// NewPRUDPConnection creates a new PRUDPConnection for a given socket +func NewPRUDPConnection(socket *SocketConnection) *PRUDPConnection { + return &PRUDPConnection{ + Socket: socket, + pid: types.NewPID(0), + slidingWindows: NewMutexMap[uint8, *SlidingWindow](), + outgoingUnreliableSequenceIDCounter: NewCounter[uint16](1), + outgoingPingSequenceIDCounter: NewCounter[uint16](0), + } +} diff --git a/prudp_endpoint.go b/prudp_endpoint.go new file mode 100644 index 00000000..1502e417 --- /dev/null +++ b/prudp_endpoint.go @@ -0,0 +1,603 @@ +package nex + +import ( + "encoding/binary" + "errors" + "fmt" + "slices" + "time" + + "github.com/PretendoNetwork/nex-go/types" +) + +// PRUDPEndPoint is an implementation of rdv::PRUDPEndPoint. +// A PRUDPEndPoint represents a remote server location the client may connect to using a given remote stream ID. +// Each PRUDPEndPoint handles it's own set of PRUDPConnections, state, and events. +type PRUDPEndPoint struct { + Server *PRUDPServer + StreamID uint8 + DefaultstreamSettings *StreamSettings + Connections *MutexMap[string, *PRUDPConnection] + packetEventHandlers map[string][]func(packet PacketInterface) + connectionEndedEventHandlers []func(connection *PRUDPConnection) + ConnectionIDCounter *Counter[uint32] + IsSecureEndpoint bool // TODO - Remove this? Assume if CONNECT packet has a body, it's the secure server? +} + +// OnData adds an event handler which is fired when a new DATA packet is received +func (pep *PRUDPEndPoint) OnData(handler func(packet PacketInterface)) { + pep.on("data", handler) +} + +// OnDisconnect adds an event handler which is fired when a new DISCONNECT packet is received +// +// To handle a connection being removed from the server, see OnConnectionEnded which fires on more cases +func (pep *PRUDPEndPoint) OnDisconnect(handler func(packet PacketInterface)) { + pep.on("disconnect", handler) +} + +// OnConnectionEnded adds an event handler which is fired when a connection is removed from the server +// +// Fires both on a natural disconnect and from a timeout +func (pep *PRUDPEndPoint) OnConnectionEnded(handler func(connection *PRUDPConnection)) { + // * "Ended" events are a special case, so handle them separately + pep.connectionEndedEventHandlers = append(pep.connectionEndedEventHandlers, handler) +} + +func (pep *PRUDPEndPoint) on(name string, handler func(packet PacketInterface)) { + if _, ok := pep.packetEventHandlers[name]; !ok { + pep.packetEventHandlers[name] = make([]func(packet PacketInterface), 0) + } + + pep.packetEventHandlers[name] = append(pep.packetEventHandlers[name], handler) +} + +func (pep *PRUDPEndPoint) emit(name string, packet PRUDPPacketInterface) { + if handlers, ok := pep.packetEventHandlers[name]; ok { + for _, handler := range handlers { + go handler(packet) + } + } +} + +func (pep *PRUDPEndPoint) emitConnectionEnded(connection *PRUDPConnection) { + for _, handler := range pep.connectionEndedEventHandlers { + go handler(connection) + } +} + +func (pep *PRUDPEndPoint) processPacket(packet PRUDPPacketInterface, socket *SocketConnection) { + streamType := packet.SourceVirtualPortStreamType() + streamID := packet.SourceVirtualPortStreamID() + discriminator := fmt.Sprintf("%s-%d-%d", socket.Address.String(), streamType, streamID) + connection, ok := pep.Connections.Get(discriminator) + + if !ok { + connection = NewPRUDPConnection(socket) + connection.Endpoint = pep + connection.ID = pep.ConnectionIDCounter.Next() + connection.DefaultPRUDPVersion = packet.Version() + connection.StreamType = streamType + connection.StreamID = streamID + connection.StreamSettings = pep.DefaultstreamSettings.Copy() + connection.startHeartbeat() + + // * Fail-safe. If the server reboots, then + // * connection has no record of old connections. + // * An existing client which has not killed + // * the connection on it's end MAY still send + // * DATA packets once the server is back + // * online, assuming it reboots fast enough. + // * Since the client did NOT redo the SYN + // * and CONNECT packets, it's reliable + // * substreams never got remade. This is put + // * in place to ensure there is always AT + // * LEAST one substream in place, so the client + // * can naturally error out due to the RC4 + // * errors. + // * + // * NOTE: THE CLIENT MAY NOT HAVE THE REAL + // * CORRECT NUMBER OF SUBSTREAMS HERE. THIS + // * IS ONLY DONE TO PREVENT A SERVER CRASH, + // * NOT TO SAVE THE CLIENT. THE CLIENT IS + // * EXPECTED TO NATURALLY DIE HERE + connection.InitializeSlidingWindows(0) + + pep.Connections.Set(discriminator, connection) + } + + packet.SetSender(connection) + connection.resetHeartbeat() + + if packet.HasFlag(FlagAck) || packet.HasFlag(FlagMultiAck) { + pep.handleAcknowledgment(packet) + return + } + + switch packet.Type() { + case SynPacket: + pep.handleSyn(packet) + case ConnectPacket: + pep.handleConnect(packet) + case DataPacket: + pep.handleData(packet) + case DisconnectPacket: + pep.handleDisconnect(packet) + case PingPacket: + pep.handlePing(packet) + } +} + +func (pep *PRUDPEndPoint) handleAcknowledgment(packet PRUDPPacketInterface) { + if packet.HasFlag(FlagMultiAck) { + pep.handleMultiAcknowledgment(packet) + return + } + + connection := packet.Sender().(*PRUDPConnection) + + slidingWindow := connection.SlidingWindow(packet.SubstreamID()) + slidingWindow.ResendScheduler.AcknowledgePacket(packet.SequenceID()) +} + +func (pep *PRUDPEndPoint) handleMultiAcknowledgment(packet PRUDPPacketInterface) { + connection := packet.Sender().(*PRUDPConnection) + stream := NewByteStreamIn(packet.Payload(), pep.Server) + sequenceIDs := make([]uint16, 0) + var baseSequenceID uint16 + var slidingWindow *SlidingWindow + + if packet.SubstreamID() == 1 { + // * New aggregate acknowledgment packets set this to 1 + // * and encode the real substream ID in in the payload + substreamID, _ := stream.ReadPrimitiveUInt8() + additionalIDsCount, _ := stream.ReadPrimitiveUInt8() + baseSequenceID, _ = stream.ReadPrimitiveUInt16LE() + slidingWindow = connection.SlidingWindow(substreamID) + + for i := 0; i < int(additionalIDsCount); i++ { + additionalID, _ := stream.ReadPrimitiveUInt16LE() + sequenceIDs = append(sequenceIDs, additionalID) + } + } else { + // TODO - This is how Kinnay's client handles this, but it doesn't make sense for QRV? Since it can have multiple reliable substreams? + // * Old aggregate acknowledgment packets always use + // * substream 0 + slidingWindow = connection.SlidingWindow(0) + baseSequenceID = packet.SequenceID() + + for stream.Remaining() > 0 { + additionalID, _ := stream.ReadPrimitiveUInt16LE() + sequenceIDs = append(sequenceIDs, additionalID) + } + } + + // * MutexMap.Each locks the mutex, can't remove while reading. + // * Have to just loop again + slidingWindow.ResendScheduler.packets.Each(func(sequenceID uint16, pending *PendingPacket) bool { + if sequenceID <= baseSequenceID && !slices.Contains(sequenceIDs, sequenceID) { + sequenceIDs = append(sequenceIDs, sequenceID) + } + + return false + }) + + // * Actually remove the packets from the pool + for _, sequenceID := range sequenceIDs { + slidingWindow.ResendScheduler.AcknowledgePacket(sequenceID) + } +} + +func (pep *PRUDPEndPoint) handleSyn(packet PRUDPPacketInterface) { + connection := packet.Sender().(*PRUDPConnection) + + var ack PRUDPPacketInterface + + if packet.Version() == 2 { + ack, _ = NewPRUDPPacketLite(connection, nil) + } else if packet.Version() == 1 { + ack, _ = NewPRUDPPacketV1(connection, nil) + } else { + ack, _ = NewPRUDPPacketV0(connection, nil) + } + + connectionSignature, err := packet.calculateConnectionSignature(connection.Socket.Address) + if err != nil { + logger.Error(err.Error()) + } + + connection.reset() + connection.Signature = connectionSignature + + ack.SetType(SynPacket) + ack.AddFlag(FlagAck) + ack.AddFlag(FlagHasSize) + ack.SetSourceVirtualPortStreamType(packet.DestinationVirtualPortStreamType()) + ack.SetSourceVirtualPortStreamID(packet.DestinationVirtualPortStreamID()) + ack.SetDestinationVirtualPortStreamType(packet.SourceVirtualPortStreamType()) + ack.SetDestinationVirtualPortStreamID(packet.SourceVirtualPortStreamID()) + ack.setConnectionSignature(connectionSignature) + ack.setSignature(ack.calculateSignature([]byte{}, []byte{})) + + if ack, ok := ack.(*PRUDPPacketV1); ok { + // * Negotiate with the client what we support + ack.maximumSubstreamID = packet.(*PRUDPPacketV1).maximumSubstreamID // * No change needed, we can just support what the client wants + ack.minorVersion = packet.(*PRUDPPacketV1).minorVersion // * No change needed, we can just support what the client wants + ack.supportedFunctions = pep.Server.SupportedFunctions & packet.(*PRUDPPacketV1).supportedFunctions + } + + pep.emit("syn", ack) + + pep.Server.sendRaw(connection.Socket, ack.Bytes()) +} + +func (pep *PRUDPEndPoint) handleConnect(packet PRUDPPacketInterface) { + connection := packet.Sender().(*PRUDPConnection) + + var ack PRUDPPacketInterface + + if packet.Version() == 2 { + ack, _ = NewPRUDPPacketLite(connection, nil) + } else if packet.Version() == 1 { + ack, _ = NewPRUDPPacketV1(connection, nil) + } else { + ack, _ = NewPRUDPPacketV0(connection, nil) + } + + connection.ServerConnectionSignature = packet.getConnectionSignature() + connection.SessionID = packet.SessionID() + + connectionSignature, err := packet.calculateConnectionSignature(connection.Socket.Address) + if err != nil { + logger.Error(err.Error()) + } + + connection.ServerSessionID = packet.SessionID() + + ack.SetType(ConnectPacket) + ack.AddFlag(FlagAck) + ack.AddFlag(FlagHasSize) + ack.SetSourceVirtualPortStreamType(packet.DestinationVirtualPortStreamType()) + ack.SetSourceVirtualPortStreamID(packet.DestinationVirtualPortStreamID()) + ack.SetDestinationVirtualPortStreamType(packet.SourceVirtualPortStreamType()) + ack.SetDestinationVirtualPortStreamID(packet.SourceVirtualPortStreamID()) + ack.setConnectionSignature(make([]byte, len(connectionSignature))) + ack.SetSessionID(connection.ServerSessionID) + ack.SetSequenceID(1) + + if ack, ok := ack.(*PRUDPPacketV1); ok { + // * At this stage the client and server have already + // * negotiated what they each can support, so configure + // * the client now and just send the client back the + // * negotiated configuration + ack.maximumSubstreamID = packet.(*PRUDPPacketV1).maximumSubstreamID + ack.minorVersion = packet.(*PRUDPPacketV1).minorVersion + ack.supportedFunctions = packet.(*PRUDPPacketV1).supportedFunctions + + connection.InitializeSlidingWindows(ack.maximumSubstreamID) + connection.outgoingUnreliableSequenceIDCounter = NewCounter[uint16](packet.(*PRUDPPacketV1).initialUnreliableSequenceID) + } else { + connection.InitializeSlidingWindows(0) + } + + payload := make([]byte, 0) + + if pep.IsSecureEndpoint { + sessionKey, pid, checkValue, err := pep.readKerberosTicket(packet.Payload()) + if err != nil { + logger.Error(err.Error()) + } + + connection.SetPID(pid) + connection.setSessionKey(sessionKey) + + responseCheckValue := checkValue + 1 + responseCheckValueBytes := make([]byte, 4) + + binary.LittleEndian.PutUint32(responseCheckValueBytes, responseCheckValue) + + checkValueResponse := types.NewBuffer(responseCheckValueBytes) + stream := NewByteStreamOut(pep.Server) + + checkValueResponse.WriteTo(stream) + + payload = stream.Bytes() + } + + ack.SetPayload(payload) + ack.setSignature(ack.calculateSignature([]byte{}, packet.getConnectionSignature())) + + pep.emit("connect", ack) + + pep.Server.sendRaw(connection.Socket, ack.Bytes()) +} + +func (pep *PRUDPEndPoint) handleData(packet PRUDPPacketInterface) { + if packet.HasFlag(FlagReliable) { + pep.handleReliable(packet) + } else { + pep.handleUnreliable(packet) + } +} + +func (pep *PRUDPEndPoint) handleDisconnect(packet PRUDPPacketInterface) { + if packet.HasFlag(FlagNeedsAck) { + pep.acknowledgePacket(packet) + } + + streamType := packet.SourceVirtualPortStreamType() + streamID := packet.SourceVirtualPortStreamID() + discriminator := fmt.Sprintf("%s-%d-%d", packet.Sender().Address().String(), streamType, streamID) + if connection, ok := pep.Connections.Get(discriminator); ok { + connection.cleanup() + pep.Connections.Delete(discriminator) + } + + pep.emit("disconnect", packet) +} + +func (pep *PRUDPEndPoint) handlePing(packet PRUDPPacketInterface) { + if packet.HasFlag(FlagNeedsAck) { + pep.acknowledgePacket(packet) + } +} + +func (pep *PRUDPEndPoint) readKerberosTicket(payload []byte) ([]byte, *types.PID, uint32, error) { + stream := NewByteStreamIn(payload, pep.Server) + + ticketData := types.NewBuffer(nil) + if err := ticketData.ExtractFrom(stream); err != nil { + return nil, nil, 0, err + } + + requestData := types.NewBuffer(nil) + if err := requestData.ExtractFrom(stream); err != nil { + return nil, nil, 0, err + } + + serverKey := DeriveKerberosKey(types.NewPID(2), pep.Server.kerberosPassword) + + ticket := NewKerberosTicketInternalData() + if err := ticket.Decrypt(NewByteStreamIn(ticketData.Value, pep.Server), serverKey); err != nil { + return nil, nil, 0, err + } + + ticketTime := ticket.Issued.Standard() + serverTime := time.Now().UTC() + + timeLimit := ticketTime.Add(time.Minute * 2) + if serverTime.After(timeLimit) { + return nil, nil, 0, errors.New("Kerberos ticket expired") + } + + sessionKey := ticket.SessionKey + kerberos := NewKerberosEncryption(sessionKey) + + decryptedRequestData, err := kerberos.Decrypt(requestData.Value) + if err != nil { + return nil, nil, 0, err + } + + checkDataStream := NewByteStreamIn(decryptedRequestData, pep.Server) + + userPID := types.NewPID(0) + if err := userPID.ExtractFrom(checkDataStream); err != nil { + return nil, nil, 0, err + } + + _, err = checkDataStream.ReadPrimitiveUInt32LE() // * CID of secure server station url + if err != nil { + return nil, nil, 0, err + } + + responseCheck, err := checkDataStream.ReadPrimitiveUInt32LE() + if err != nil { + return nil, nil, 0, err + } + + return sessionKey, userPID, responseCheck, nil +} + +func (pep *PRUDPEndPoint) acknowledgePacket(packet PRUDPPacketInterface) { + var ack PRUDPPacketInterface + + if packet.Version() == 2 { + ack, _ = NewPRUDPPacketLite(packet.Sender().(*PRUDPConnection), nil) + } else if packet.Version() == 1 { + ack, _ = NewPRUDPPacketV1(packet.Sender().(*PRUDPConnection), nil) + } else { + ack, _ = NewPRUDPPacketV0(packet.Sender().(*PRUDPConnection), nil) + } + + ack.SetType(packet.Type()) + ack.AddFlag(FlagAck) + ack.SetSourceVirtualPortStreamType(packet.DestinationVirtualPortStreamType()) + ack.SetSourceVirtualPortStreamID(packet.DestinationVirtualPortStreamID()) + ack.SetDestinationVirtualPortStreamType(packet.SourceVirtualPortStreamType()) + ack.SetDestinationVirtualPortStreamID(packet.SourceVirtualPortStreamID()) + ack.SetSequenceID(packet.SequenceID()) + ack.setFragmentID(packet.getFragmentID()) + ack.SetSubstreamID(packet.SubstreamID()) + + pep.Server.sendPacket(ack) + + // * Servers send the DISCONNECT ACK 3 times + if packet.Type() == DisconnectPacket { + pep.Server.sendPacket(ack) + pep.Server.sendPacket(ack) + } +} + +func (pep *PRUDPEndPoint) handleReliable(packet PRUDPPacketInterface) { + if packet.HasFlag(FlagNeedsAck) { + pep.acknowledgePacket(packet) + } + + connection := packet.Sender().(*PRUDPConnection) + + slidingWindow := packet.Sender().(*PRUDPConnection).SlidingWindow(packet.SubstreamID()) + + for _, pendingPacket := range slidingWindow.Update(packet) { + if packet.Type() == DataPacket { + var decryptedPayload []byte + + if packet.Version() != 2 { + decryptedPayload = pendingPacket.decryptPayload() + } else { + // * PRUDPLite does not encrypt payloads + decryptedPayload = pendingPacket.Payload() + } + + decompressedPayload, err := connection.StreamSettings.CompressionAlgorithm.Decompress(decryptedPayload) + if err != nil { + logger.Error(err.Error()) + } + + payload := slidingWindow.AddFragment(decompressedPayload) + + if packet.getFragmentID() == 0 { + message := NewRMCMessage(pep.Server) + err := message.FromBytes(payload) + if err != nil { + // TODO - Should this return the error too? + logger.Error(err.Error()) + } + + slidingWindow.ResetFragmentedPayload() + + packet.SetRMCMessage(message) + + pep.emit("data", packet) + } + } + } +} + +func (pep *PRUDPEndPoint) handleUnreliable(packet PRUDPPacketInterface) { + if packet.HasFlag(FlagNeedsAck) { + pep.acknowledgePacket(packet) + } + + // * Since unreliable DATA packets can in theory reach the + // * server in any order, and they lack a subsslidingWindowtream, it's + // * not actually possible to know what order they should + // * be processed in for each request. So assume all packets + // * MUST be fragment 0 (unreliable packets do not have frags) + // * + // * Example - + // * + // * Say there is 2 requests to the same protocol, methods 1 + // * and 2. The starting unreliable sequence ID is 10. If both + // * method 1 and 2 are called at the same time, but method 1 + // * has a fragmented payload, the packets could, in theory, reach + // * the server like so: + // * + // * - Method1 - Sequence 10, Fragment 1 + // * - Method1 - Sequence 13, Fragment 3 + // * - Method2 - Sequence 12, Fragment 0 + // * - Method1 - Sequence 11, Fragment 2 + // * - Method1 - Sequence 14, Fragment 0 + // * + // * If we reorder these to the proper order, like so: + // * + // * - Method1 - Sequence 10, Fragment 1 + // * - Method1 - Sequence 11, Fragment 2 + // * - Method2 - Sequence 12, Fragment 0 + // * - Method1 - Sequence 13, Fragment 3 + // * - Method1 - Sequence 14, Fragment 0 + // * + // * We still have a gap where Method2 was called. It's not + // * possible to know if the packet with sequence ID 12 belongs + // * to the Method1 calls or not. We don't even know which methods + // * the packets are for at this stage yet, since the RMC data + // * can't be checked until all the fragments are collected and + // * the payload decrypted. In this case, we would see fragment 0 + // * and assume that's the end of fragments, losing the real last + // * fragments and resulting in a bad decryption + // TODO - Is this actually true? I'm just assuming, based on common sense, tbh. Kinnay also does not implement fragmented unreliable packets? + if packet.getFragmentID() != 0 { + logger.Warningf("Unexpected unreliable fragment ID. Expected 0, got %d", packet.getFragmentID()) + return + } + + payload := packet.processUnreliableCrypto() + + message := NewRMCMessage(pep.Server) + err := message.FromBytes(payload) + if err != nil { + // TODO - Should this return the error too? + logger.Error(err.Error()) + } + + packet.SetRMCMessage(message) + + pep.emit("data", packet) +} + +func (pep *PRUDPEndPoint) sendPing(connection *PRUDPConnection) { + var ping PRUDPPacketInterface + + switch connection.DefaultPRUDPVersion { + case 0: + ping, _ = NewPRUDPPacketV0(connection, nil) + case 1: + ping, _ = NewPRUDPPacketV1(connection, nil) + case 2: + ping, _ = NewPRUDPPacketLite(connection, nil) + } + + ping.SetType(PingPacket) + ping.AddFlag(FlagNeedsAck) + ping.SetSourceVirtualPortStreamType(connection.StreamType) + ping.SetSourceVirtualPortStreamID(pep.StreamID) + ping.SetDestinationVirtualPortStreamType(connection.StreamType) + ping.SetDestinationVirtualPortStreamID(connection.StreamID) + ping.SetSubstreamID(0) + + pep.Server.sendPacket(ping) +} + +// FindConnectionByID returns the PRUDP client connected with the given connection ID +func (pep *PRUDPEndPoint) FindConnectionByID(serverPort, serverStreamType uint8, connectedID uint32) *PRUDPConnection { + var connection *PRUDPConnection + + pep.Connections.Each(func(discriminator string, pc *PRUDPConnection) bool { + if pc.ID == connectedID { + connection = pc + return true + } + + return false + }) + + return connection +} + +// FindConnectionByPID returns the PRUDP client connected with the given PID +func (pep *PRUDPEndPoint) FindConnectionByPID(serverPort, serverStreamType uint8, pid uint64) *PRUDPConnection { + var connection *PRUDPConnection + + pep.Connections.Each(func(discriminator string, pc *PRUDPConnection) bool { + if pc.pid.Value() == pid { + connection = pc + return true + } + + return false + }) + + return connection +} + +// NewPRUDPEndPoint returns a new PRUDPEndPoint for a server on the provided stream ID +func NewPRUDPEndPoint(streamID uint8) *PRUDPEndPoint { + return &PRUDPEndPoint{ + StreamID: streamID, + DefaultstreamSettings: NewStreamSettings(), + Connections: NewMutexMap[string, *PRUDPConnection](), + packetEventHandlers: make(map[string][]func(PacketInterface)), + connectionEndedEventHandlers: make([]func(connection *PRUDPConnection), 0), + ConnectionIDCounter: NewCounter[uint32](0), + IsSecureEndpoint: false, + } +} diff --git a/prudp_packet.go b/prudp_packet.go index db6b6bd2..a29fa418 100644 --- a/prudp_packet.go +++ b/prudp_packet.go @@ -4,29 +4,27 @@ import "crypto/rc4" // PRUDPPacket holds all the fields each packet should have in all PRUDP versions type PRUDPPacket struct { - server *PRUDPServer - sender *PRUDPClient - readStream *ByteStreamIn - version uint8 - sourceStreamType uint8 - sourcePort uint8 - destinationStreamType uint8 - destinationPort uint8 - packetType uint16 - flags uint16 - sessionID uint8 - substreamID uint8 - signature []byte - sequenceID uint16 - connectionSignature []byte - fragmentID uint8 - payload []byte - message *RMCMessage + server *PRUDPServer + sender *PRUDPConnection + readStream *ByteStreamIn + version uint8 + sourceVirtualPort VirtualPort + destinationVirtualPort VirtualPort + packetType uint16 + flags uint16 + sessionID uint8 + substreamID uint8 + signature []byte + sequenceID uint16 + connectionSignature []byte + fragmentID uint8 + payload []byte + message *RMCMessage } // SetSender sets the Client who sent the packet func (p *PRUDPPacket) SetSender(sender ClientInterface) { - p.sender = sender.(*PRUDPClient) + p.sender = sender.(*PRUDPConnection) } // Sender returns the Client who sent the packet @@ -59,44 +57,44 @@ func (p *PRUDPPacket) Type() uint16 { return p.packetType } -// SetSourceStreamType sets the packet virtual source stream type -func (p *PRUDPPacket) SetSourceStreamType(sourceStreamType uint8) { - p.sourceStreamType = sourceStreamType +// SetSourceVirtualPortStreamType sets the packets source VirtualPort StreamType +func (p *PRUDPPacket) SetSourceVirtualPortStreamType(streamType StreamType) { + p.sourceVirtualPort.SetStreamType(streamType) } -// SourceStreamType returns the packet virtual source stream type -func (p *PRUDPPacket) SourceStreamType() uint8 { - return p.sourceStreamType +// SourceVirtualPortStreamType returns the packets source VirtualPort StreamType +func (p *PRUDPPacket) SourceVirtualPortStreamType() StreamType { + return p.sourceVirtualPort.StreamType() } -// SetSourcePort sets the packet virtual source stream type -func (p *PRUDPPacket) SetSourcePort(sourcePort uint8) { - p.sourcePort = sourcePort +// SetSourceVirtualPortStreamID sets the packets source VirtualPort port number +func (p *PRUDPPacket) SetSourceVirtualPortStreamID(port uint8) { + p.sourceVirtualPort.SetStreamID(port) } -// SourcePort returns the packet virtual source stream type -func (p *PRUDPPacket) SourcePort() uint8 { - return p.sourcePort +// SourceVirtualPortStreamID returns the packets source VirtualPort port number +func (p *PRUDPPacket) SourceVirtualPortStreamID() uint8 { + return p.sourceVirtualPort.StreamID() } -// SetDestinationStreamType sets the packet virtual destination stream type -func (p *PRUDPPacket) SetDestinationStreamType(destinationStreamType uint8) { - p.destinationStreamType = destinationStreamType +// SetDestinationVirtualPortStreamType sets the packets destination VirtualPort StreamType +func (p *PRUDPPacket) SetDestinationVirtualPortStreamType(streamType StreamType) { + p.destinationVirtualPort.SetStreamType(streamType) } -// DestinationStreamType returns the packet virtual destination stream type -func (p *PRUDPPacket) DestinationStreamType() uint8 { - return p.destinationStreamType +// DestinationVirtualPortStreamType returns the packets destination VirtualPort StreamType +func (p *PRUDPPacket) DestinationVirtualPortStreamType() StreamType { + return p.destinationVirtualPort.StreamType() } -// SetDestinationPort sets the packet virtual destination port -func (p *PRUDPPacket) SetDestinationPort(destinationPort uint8) { - p.destinationPort = destinationPort +// SetDestinationVirtualPortStreamID sets the packets destination VirtualPort port number +func (p *PRUDPPacket) SetDestinationVirtualPortStreamID(port uint8) { + p.destinationVirtualPort.SetStreamID(port) } -// DestinationPort returns the packet virtual destination port -func (p *PRUDPPacket) DestinationPort() uint8 { - return p.destinationPort +// DestinationVirtualPortStreamID returns the packets destination VirtualPort port number +func (p *PRUDPPacket) DestinationVirtualPortStreamID() uint8 { + return p.destinationVirtualPort.StreamID() } // SessionID returns the packets session ID @@ -148,17 +146,17 @@ func (p *PRUDPPacket) decryptPayload() []byte { // TODO - This assumes a reliable DATA packet. Handle unreliable here? Or do that in a different method? if p.packetType == DataPacket { - substream := p.sender.reliableSubstream(p.SubstreamID()) + slidingWindow := p.sender.SlidingWindow(p.SubstreamID()) // * According to other Quazal server implementations, // * the RC4 stream is always reset to the default key // * regardless if the client is connecting to a secure // * server (prudps) or not - if p.version == 0 && p.sender.server.PRUDPV0Settings.IsQuazalMode { - substream.SetCipherKey([]byte("CD&ML")) + if p.version == 0 && p.sender.Endpoint.Server.PRUDPV0Settings.IsQuazalMode { + slidingWindow.SetCipherKey([]byte("CD&ML")) } - payload = substream.Decrypt(payload) + payload, _ = slidingWindow.streamSettings.EncryptionAlgorithm.Decrypt(payload) } return payload @@ -193,7 +191,7 @@ func (p *PRUDPPacket) SetRMCMessage(message *RMCMessage) { func (p *PRUDPPacket) processUnreliableCrypto() []byte { // * Since unreliable DATA packets can come in out of // * order, each packet uses a dedicated RC4 stream - uniqueKey := p.sender.unreliableBaseKey[:] + uniqueKey := p.sender.UnreliablePacketBaseKey[:] uniqueKey[0] = byte((uint16(uniqueKey[0]) + p.sequenceID) & 0xFF) uniqueKey[1] = byte((uint16(uniqueKey[1]) + (p.sequenceID >> 8)) & 0xFF) uniqueKey[31] = byte((uniqueKey[31] + p.sessionID) & 0xFF) diff --git a/prudp_packet_interface.go b/prudp_packet_interface.go index 94f441aa..0349a002 100644 --- a/prudp_packet_interface.go +++ b/prudp_packet_interface.go @@ -14,14 +14,14 @@ type PRUDPPacketInterface interface { AddFlag(flag uint16) SetType(packetType uint16) Type() uint16 - SetSourceStreamType(sourceStreamType uint8) - SourceStreamType() uint8 - SetSourcePort(sourcePort uint8) - SourcePort() uint8 - SetDestinationStreamType(destinationStreamType uint8) - DestinationStreamType() uint8 - SetDestinationPort(destinationPort uint8) - DestinationPort() uint8 + SetSourceVirtualPortStreamType(streamType StreamType) + SourceVirtualPortStreamType() StreamType + SetSourceVirtualPortStreamID(port uint8) + SourceVirtualPortStreamID() uint8 + SetDestinationVirtualPortStreamType(streamType StreamType) + DestinationVirtualPortStreamType() StreamType + SetDestinationVirtualPortStreamID(port uint8) + DestinationVirtualPortStreamID() uint8 SessionID() uint8 SetSessionID(sessionID uint8) SubstreamID() uint8 diff --git a/prudp_packet_lite.go b/prudp_packet_lite.go index 20506dc0..0c595b00 100644 --- a/prudp_packet_lite.go +++ b/prudp_packet_lite.go @@ -11,25 +11,69 @@ import ( // PRUDPPacketLite represents a PRUDPLite packet type PRUDPPacketLite struct { PRUDPPacket - optionsLength uint8 - minorVersion uint32 - supportedFunctions uint32 - maximumSubstreamID uint8 - initialUnreliableSequenceID uint16 - liteSignature []byte + sourceVirtualPortStreamType StreamType + sourceVirtualPortStreamID uint8 + destinationVirtualPortStreamType StreamType + destinationVirtualPortStreamID uint8 + optionsLength uint8 + minorVersion uint32 + supportedFunctions uint32 + maximumSubstreamID uint8 + initialUnreliableSequenceID uint16 + liteSignature []byte +} + +// SetSourceVirtualPortStreamType sets the packets source VirtualPort StreamType +func (p *PRUDPPacketLite) SetSourceVirtualPortStreamType(streamType StreamType) { + p.sourceVirtualPortStreamType = streamType +} + +// SourceVirtualPortStreamType returns the packets source VirtualPort StreamType +func (p *PRUDPPacketLite) SourceVirtualPortStreamType() StreamType { + return p.sourceVirtualPortStreamType +} + +// SetSourceVirtualPortStreamID sets the packets source VirtualPort port number +func (p *PRUDPPacketLite) SetSourceVirtualPortStreamID(port uint8) { + p.sourceVirtualPortStreamID = port +} + +// SourceVirtualPortStreamID returns the packets source VirtualPort port number +func (p *PRUDPPacketLite) SourceVirtualPortStreamID() uint8 { + return p.sourceVirtualPort.StreamID() +} + +// SetDestinationVirtualPortStreamType sets the packets destination VirtualPort StreamType +func (p *PRUDPPacketLite) SetDestinationVirtualPortStreamType(streamType StreamType) { + p.destinationVirtualPortStreamType = streamType +} + +// DestinationVirtualPortStreamType returns the packets destination VirtualPort StreamType +func (p *PRUDPPacketLite) DestinationVirtualPortStreamType() StreamType { + return p.destinationVirtualPortStreamType +} + +// SetDestinationVirtualPortStreamID sets the packets destination VirtualPort port number +func (p *PRUDPPacketLite) SetDestinationVirtualPortStreamID(port uint8) { + p.destinationVirtualPortStreamID = port +} + +// DestinationVirtualPortStreamID returns the packets destination VirtualPort port number +func (p *PRUDPPacketLite) DestinationVirtualPortStreamID() uint8 { + return p.destinationVirtualPortStreamID } // Copy copies the packet into a new PRUDPPacketLite // -// Retains the same PRUDPClient pointer +// Retains the same PRUDPConnection pointer func (p *PRUDPPacketLite) Copy() PRUDPPacketInterface { copied, _ := NewPRUDPPacketLite(p.sender, nil) copied.server = p.server - copied.sourceStreamType = p.sourceStreamType - copied.sourcePort = p.sourcePort - copied.destinationStreamType = p.destinationStreamType - copied.destinationPort = p.destinationPort + copied.sourceVirtualPortStreamType = p.sourceVirtualPortStreamType + copied.sourceVirtualPortStreamID = p.sourceVirtualPortStreamID + copied.destinationVirtualPortStreamType = p.destinationVirtualPortStreamType + copied.destinationVirtualPortStreamID = p.destinationVirtualPortStreamID copied.packetType = p.packetType copied.flags = p.flags copied.sessionID = p.sessionID @@ -95,15 +139,15 @@ func (p *PRUDPPacketLite) decode() error { return fmt.Errorf("Failed to decode PRUDPLite virtual ports stream types. %s", err.Error()) } - p.sourceStreamType = streamTypes >> 4 - p.destinationStreamType = streamTypes & 0xF + p.sourceVirtualPortStreamType = StreamType(streamTypes >> 4) + p.destinationVirtualPortStreamType = StreamType(streamTypes & 0xF) - p.sourcePort, err = p.readStream.ReadPrimitiveUInt8() + p.sourceVirtualPortStreamID, err = p.readStream.ReadPrimitiveUInt8() if err != nil { return fmt.Errorf("Failed to decode PRUDPLite virtual source port. %s", err.Error()) } - p.destinationPort, err = p.readStream.ReadPrimitiveUInt8() + p.destinationVirtualPortStreamID, err = p.readStream.ReadPrimitiveUInt8() if err != nil { return fmt.Errorf("Failed to decode PRUDPLite virtual destination port. %s", err.Error()) } @@ -145,9 +189,9 @@ func (p *PRUDPPacketLite) Bytes() []byte { stream.WritePrimitiveUInt8(0x80) stream.WritePrimitiveUInt8(uint8(len(options))) stream.WritePrimitiveUInt16LE(uint16(len(p.payload))) - stream.WritePrimitiveUInt8((p.sourceStreamType << 4) | p.destinationStreamType) - stream.WritePrimitiveUInt8(p.sourcePort) - stream.WritePrimitiveUInt8(p.destinationPort) + stream.WritePrimitiveUInt8(uint8((p.sourceVirtualPortStreamType << 4) | p.destinationVirtualPortStreamType)) + stream.WritePrimitiveUInt8(p.sourceVirtualPortStreamID) + stream.WritePrimitiveUInt8(p.destinationVirtualPortStreamID) stream.WritePrimitiveUInt8(p.fragmentID) stream.WritePrimitiveUInt16LE(p.packetType | (p.flags << 4)) stream.WritePrimitiveUInt16LE(p.sequenceID) @@ -276,10 +320,10 @@ func (p *PRUDPPacketLite) calculateSignature(sessionKey, connectionSignature []b } // NewPRUDPPacketLite creates and returns a new PacketLite using the provided Client and stream -func NewPRUDPPacketLite(client *PRUDPClient, readStream *ByteStreamIn) (*PRUDPPacketLite, error) { +func NewPRUDPPacketLite(connection *PRUDPConnection, readStream *ByteStreamIn) (*PRUDPPacketLite, error) { packet := &PRUDPPacketLite{ PRUDPPacket: PRUDPPacket{ - sender: client, + sender: connection, readStream: readStream, }, } @@ -292,19 +336,19 @@ func NewPRUDPPacketLite(client *PRUDPClient, readStream *ByteStreamIn) (*PRUDPPa } } - if client != nil { - packet.server = client.server + if connection != nil { + packet.server = connection.Endpoint.Server } return packet, nil } // NewPRUDPPacketsLite reads all possible PRUDPLite packets from the stream -func NewPRUDPPacketsLite(client *PRUDPClient, readStream *ByteStreamIn) ([]PRUDPPacketInterface, error) { +func NewPRUDPPacketsLite(connection *PRUDPConnection, readStream *ByteStreamIn) ([]PRUDPPacketInterface, error) { packets := make([]PRUDPPacketInterface, 0) for readStream.Remaining() > 0 { - packet, err := NewPRUDPPacketLite(client, readStream) + packet, err := NewPRUDPPacketLite(connection, readStream) if err != nil { return packets, err } diff --git a/prudp_packet_v0.go b/prudp_packet_v0.go index 984906f0..3b2f05af 100644 --- a/prudp_packet_v0.go +++ b/prudp_packet_v0.go @@ -17,15 +17,13 @@ type PRUDPPacketV0 struct { // Copy copies the packet into a new PRUDPPacketV0 // -// Retains the same PRUDPClient pointer +// Retains the same PRUDPConnection pointer func (p *PRUDPPacketV0) Copy() PRUDPPacketInterface { copied, _ := NewPRUDPPacketV0(p.sender, nil) copied.server = p.server - copied.sourceStreamType = p.sourceStreamType - copied.sourcePort = p.sourcePort - copied.destinationStreamType = p.destinationStreamType - copied.destinationPort = p.destinationPort + copied.sourceVirtualPort = p.sourceVirtualPort + copied.destinationVirtualPort = p.destinationVirtualPort copied.packetType = p.packetType copied.flags = p.flags copied.sessionID = p.sessionID @@ -73,15 +71,14 @@ func (p *PRUDPPacketV0) decode() error { return fmt.Errorf("Failed to read PRUDPv0 source. %s", err.Error()) } + p.sourceVirtualPort = VirtualPort(source) + destination, err := p.readStream.ReadPrimitiveUInt8() if err != nil { return fmt.Errorf("Failed to read PRUDPv0 destination. %s", err.Error()) } - p.sourceStreamType = source >> 4 - p.sourcePort = source & 0xF - p.destinationStreamType = destination >> 4 - p.destinationPort = destination & 0xF + p.destinationVirtualPort = VirtualPort(destination) if server.PRUDPV0Settings.IsQuazalMode { typeAndFlags, err := p.readStream.ReadPrimitiveUInt8() @@ -198,8 +195,8 @@ func (p *PRUDPPacketV0) Bytes() []byte { server := p.server stream := NewByteStreamOut(server) - stream.WritePrimitiveUInt8(p.sourcePort | (p.sourceStreamType << 4)) - stream.WritePrimitiveUInt8(p.destinationPort | (p.destinationStreamType << 4)) + stream.WritePrimitiveUInt8(uint8(p.sourceVirtualPort)) + stream.WritePrimitiveUInt8(uint8(p.destinationVirtualPort)) if server.PRUDPV0Settings.IsQuazalMode { stream.WritePrimitiveUInt8(uint8(p.packetType | (p.flags << 3))) @@ -250,10 +247,10 @@ func (p *PRUDPPacketV0) calculateSignature(sessionKey, connectionSignature []byt } // NewPRUDPPacketV0 creates and returns a new PacketV0 using the provided Client and stream -func NewPRUDPPacketV0(client *PRUDPClient, readStream *ByteStreamIn) (*PRUDPPacketV0, error) { +func NewPRUDPPacketV0(connection *PRUDPConnection, readStream *ByteStreamIn) (*PRUDPPacketV0, error) { packet := &PRUDPPacketV0{ PRUDPPacket: PRUDPPacket{ - sender: client, + sender: connection, readStream: readStream, version: 0, }, @@ -267,19 +264,19 @@ func NewPRUDPPacketV0(client *PRUDPClient, readStream *ByteStreamIn) (*PRUDPPack } } - if client != nil { - packet.server = client.server + if connection != nil { + packet.server = connection.Endpoint.Server } return packet, nil } // NewPRUDPPacketsV0 reads all possible PRUDPv0 packets from the stream -func NewPRUDPPacketsV0(client *PRUDPClient, readStream *ByteStreamIn) ([]PRUDPPacketInterface, error) { +func NewPRUDPPacketsV0(connection *PRUDPConnection, readStream *ByteStreamIn) ([]PRUDPPacketInterface, error) { packets := make([]PRUDPPacketInterface, 0) for readStream.Remaining() > 0 { - packet, err := NewPRUDPPacketV0(client, readStream) + packet, err := NewPRUDPPacketV0(connection, readStream) if err != nil { return packets, err } diff --git a/prudp_packet_v1.go b/prudp_packet_v1.go index 12ad7ab6..24de698c 100644 --- a/prudp_packet_v1.go +++ b/prudp_packet_v1.go @@ -23,15 +23,13 @@ type PRUDPPacketV1 struct { // Copy copies the packet into a new PRUDPPacketV1 // -// Retains the same PRUDPClient pointer +// Retains the same PRUDPConnection pointer func (p *PRUDPPacketV1) Copy() PRUDPPacketInterface { copied, _ := NewPRUDPPacketV1(p.sender, nil) copied.server = p.server - copied.sourceStreamType = p.sourceStreamType - copied.sourcePort = p.sourcePort - copied.destinationStreamType = p.destinationStreamType - copied.destinationPort = p.destinationPort + copied.sourceVirtualPort = p.sourceVirtualPort + copied.destinationVirtualPort = p.destinationVirtualPort copied.packetType = p.packetType copied.flags = p.flags copied.sessionID = p.sessionID @@ -158,15 +156,14 @@ func (p *PRUDPPacketV1) decodeHeader() error { return fmt.Errorf("Failed to read PRUDPv1 source. %s", err.Error()) } + p.sourceVirtualPort = VirtualPort(source) + destination, err := p.readStream.ReadPrimitiveUInt8() if err != nil { return fmt.Errorf("Failed to read PRUDPv1 destination. %s", err.Error()) } - p.sourceStreamType = source >> 4 - p.sourcePort = source & 0xF - p.destinationStreamType = destination >> 4 - p.destinationPort = destination & 0xF + p.destinationVirtualPort = VirtualPort(destination) // TODO - Does QRV also encode it this way in PRUDPv1? typeAndFlags, err := p.readStream.ReadPrimitiveUInt16LE() @@ -205,8 +202,8 @@ func (p *PRUDPPacketV1) encodeHeader() []byte { stream.WritePrimitiveUInt8(1) // * Version stream.WritePrimitiveUInt8(p.optionsLength) stream.WritePrimitiveUInt16LE(uint16(len(p.payload))) - stream.WritePrimitiveUInt8(p.sourcePort | (p.sourceStreamType << 4)) - stream.WritePrimitiveUInt8(p.destinationPort | (p.destinationStreamType << 4)) + stream.WritePrimitiveUInt8(uint8(p.sourceVirtualPort)) + stream.WritePrimitiveUInt8(uint8(p.destinationVirtualPort)) stream.WritePrimitiveUInt16LE(p.packetType | (p.flags << 4)) // TODO - Does QRV also encode it this way in PRUDPv1? stream.WritePrimitiveUInt8(p.sessionID) stream.WritePrimitiveUInt8(p.substreamID) @@ -356,10 +353,10 @@ func (p *PRUDPPacketV1) calculateSignature(sessionKey, connectionSignature []byt } // NewPRUDPPacketV1 creates and returns a new PacketV1 using the provided Client and stream -func NewPRUDPPacketV1(client *PRUDPClient, readStream *ByteStreamIn) (*PRUDPPacketV1, error) { +func NewPRUDPPacketV1(connection *PRUDPConnection, readStream *ByteStreamIn) (*PRUDPPacketV1, error) { packet := &PRUDPPacketV1{ PRUDPPacket: PRUDPPacket{ - sender: client, + sender: connection, readStream: readStream, version: 1, }, @@ -373,19 +370,19 @@ func NewPRUDPPacketV1(client *PRUDPClient, readStream *ByteStreamIn) (*PRUDPPack } } - if client != nil { - packet.server = client.server + if connection != nil { + packet.server = connection.Endpoint.Server } return packet, nil } // NewPRUDPPacketsV1 reads all possible PRUDPv1 packets from the stream -func NewPRUDPPacketsV1(client *PRUDPClient, readStream *ByteStreamIn) ([]PRUDPPacketInterface, error) { +func NewPRUDPPacketsV1(connection *PRUDPConnection, readStream *ByteStreamIn) ([]PRUDPPacketInterface, error) { packets := make([]PRUDPPacketInterface, 0) for readStream.Remaining() > 0 { - packet, err := NewPRUDPPacketV1(client, readStream) + packet, err := NewPRUDPPacketV1(connection, readStream) if err != nil { return packets, err } diff --git a/prudp_server.go b/prudp_server.go index 4ab62d0e..7cfa9712 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -3,15 +3,11 @@ package nex import ( "bytes" "crypto/rand" - "encoding/binary" - "errors" "fmt" "net" "runtime" - "slices" "time" - "github.com/PretendoNetwork/nex-go/compression" "github.com/PretendoNetwork/nex-go/types" "github.com/lxzan/gws" ) @@ -20,11 +16,8 @@ import ( type PRUDPServer struct { udpSocket *net.UDPConn websocketServer *WebSocketServer - PRUDPVersion int - PRUDPMinorVersion uint32 - virtualServers *MutexMap[uint8, *MutexMap[uint8, *MutexMap[string, *PRUDPClient]]] - VirtualServerPorts []uint8 - SecureVirtualServerPorts []uint8 + Endpoints *MutexMap[uint8, *PRUDPEndPoint] + Connections *MutexMap[string, *SocketConnection] SupportedFunctions uint32 accessKey string kerberosPassword []byte @@ -39,57 +32,22 @@ type PRUDPServer struct { messagingProtocolVersion *LibraryVersion utilityProtocolVersion *LibraryVersion natTraversalProtocolVersion *LibraryVersion - prudpEventHandlers map[string][]func(packet PacketInterface) - clientRemovedEventHandlers []func(client *PRUDPClient) - connectionIDCounter *Counter[uint32] pingTimeout time.Duration passwordFromPIDHandler func(pid *types.PID) (string, uint32) PRUDPv1ConnectionSignatureKey []byte - CompressionAlgorithm compression.Algorithm byteStreamSettings *ByteStreamSettings PRUDPV0Settings *PRUDPV0Settings } -// OnData adds an event handler which is fired when a new DATA packet is received -func (s *PRUDPServer) OnData(handler func(packet PacketInterface)) { - s.on("data", handler) -} - -// OnDisconnect adds an event handler which is fired when a new DISCONNECT packet is received -// -// To handle a client being removed from the server, see OnClientRemoved which fires on more cases -func (s *PRUDPServer) OnDisconnect(handler func(packet PacketInterface)) { - s.on("disconnect", handler) -} - -// OnClientRemoved adds an event handler which is fired when a client is removed from the server -// -// Fires both on a natural disconnect and from a timeout -func (s *PRUDPServer) OnClientRemoved(handler func(client *PRUDPClient)) { - // * "removed" events are a special case, so handle them separately - s.clientRemovedEventHandlers = append(s.clientRemovedEventHandlers, handler) -} - -func (s *PRUDPServer) on(name string, handler func(packet PacketInterface)) { - if _, ok := s.prudpEventHandlers[name]; !ok { - s.prudpEventHandlers[name] = make([]func(packet PacketInterface), 0) - } - - s.prudpEventHandlers[name] = append(s.prudpEventHandlers[name], handler) -} - -func (s *PRUDPServer) emit(name string, packet PRUDPPacketInterface) { - if handlers, ok := s.prudpEventHandlers[name]; ok { - for _, handler := range handlers { - go handler(packet) - } +// BindPRUDPEndPoint binds a provided PRUDPEndPoint to the server +func (s *PRUDPServer) BindPRUDPEndPoint(endpoint *PRUDPEndPoint) { + if s.Endpoints.Has(endpoint.StreamID) { + logger.Warningf("Tried to bind already existing PRUDPEndPoint %d", endpoint.StreamID) + return } -} -func (s *PRUDPServer) emitRemoved(client *PRUDPClient) { - for _, handler := range s.clientRemovedEventHandlers { - go handler(client) - } + endpoint.Server = s + s.Endpoints.Set(endpoint.StreamID, endpoint) } // Listen is an alias of ListenUDP. Implemented to conform to the ServerInterface @@ -100,7 +58,6 @@ func (s *PRUDPServer) Listen(port int) { // ListenUDP starts a PRUDP server on a given port using a UDP server func (s *PRUDPServer) ListenUDP(port int) { s.initPRUDPv1ConnectionSignatureKey() - s.initVirtualPorts() udpAddress, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", port)) if err != nil { @@ -123,10 +80,29 @@ func (s *PRUDPServer) ListenUDP(port int) { <-quit } +func (s *PRUDPServer) listenDatagram(quit chan struct{}) { + var err error + + for err == nil { + buffer := make([]byte, 64000) + var read int + var addr *net.UDPAddr + + read, addr, err = s.udpSocket.ReadFromUDP(buffer) + packetData := buffer[:read] + + err = s.handleSocketMessage(packetData, addr, nil) + } + + quit <- struct{}{} + + panic(err) +} + // ListenWebSocket starts a PRUDP server on a given port using a WebSocket server func (s *PRUDPServer) ListenWebSocket(port int) { s.initPRUDPv1ConnectionSignatureKey() - s.initVirtualPorts() + //s.initVirtualPorts() s.websocketServer = &WebSocketServer{ prudpServer: s, @@ -138,7 +114,7 @@ func (s *PRUDPServer) ListenWebSocket(port int) { // ListenWebSocketSecure starts a PRUDP server on a given port using a secure (TLS) WebSocket server func (s *PRUDPServer) ListenWebSocketSecure(port int, certFile, keyFile string) { s.initPRUDPv1ConnectionSignatureKey() - s.initVirtualPorts() + //s.initVirtualPorts() s.websocketServer = &WebSocketServer{ prudpServer: s, @@ -158,46 +134,6 @@ func (s *PRUDPServer) initPRUDPv1ConnectionSignatureKey() { } } -func (s *PRUDPServer) initVirtualPorts() { - for _, port := range s.VirtualServerPorts { - virtualServer := NewMutexMap[uint8, *MutexMap[string, *PRUDPClient]]() - virtualServer.Set(VirtualStreamTypeDO, NewMutexMap[string, *PRUDPClient]()) - virtualServer.Set(VirtualStreamTypeRV, NewMutexMap[string, *PRUDPClient]()) - virtualServer.Set(VirtualStreamTypeOldRVSec, NewMutexMap[string, *PRUDPClient]()) - virtualServer.Set(VirtualStreamTypeSBMGMT, NewMutexMap[string, *PRUDPClient]()) - virtualServer.Set(VirtualStreamTypeNAT, NewMutexMap[string, *PRUDPClient]()) - virtualServer.Set(VirtualStreamTypeSessionDiscovery, NewMutexMap[string, *PRUDPClient]()) - virtualServer.Set(VirtualStreamTypeNATEcho, NewMutexMap[string, *PRUDPClient]()) - virtualServer.Set(VirtualStreamTypeRouting, NewMutexMap[string, *PRUDPClient]()) - virtualServer.Set(VirtualStreamTypeGame, NewMutexMap[string, *PRUDPClient]()) - virtualServer.Set(VirtualStreamTypeRVSecure, NewMutexMap[string, *PRUDPClient]()) - virtualServer.Set(VirtualStreamTypeRelay, NewMutexMap[string, *PRUDPClient]()) - - s.virtualServers.Set(port, virtualServer) - } - - logger.Success("Virtual ports created") -} - -func (s *PRUDPServer) listenDatagram(quit chan struct{}) { - var err error - - for err == nil { - buffer := make([]byte, 64000) - var read int - var addr *net.UDPAddr - - read, addr, err = s.udpSocket.ReadFromUDP(buffer) - packetData := buffer[:read] - - err = s.handleSocketMessage(packetData, addr, nil) - } - - quit <- struct{}{} - - panic(err) -} - func (s *PRUDPServer) handleSocketMessage(packetData []byte, address net.Addr, webSocketConnection *gws.Conn) error { readStream := NewByteStreamIn(packetData, s) @@ -223,504 +159,56 @@ func (s *PRUDPServer) handleSocketMessage(packetData []byte, address net.Addr, w } func (s *PRUDPServer) processPacket(packet PRUDPPacketInterface, address net.Addr, webSocketConnection *gws.Conn) { - if !slices.Contains(s.VirtualServerPorts, packet.DestinationPort()) { - logger.Warningf("Client %s trying to connect to unbound server vport %d", address.String(), packet.DestinationPort()) - return - } - - if packet.DestinationStreamType() > VirtualStreamTypeRelay { - logger.Warningf("Client %s trying to use invalid to server stream type %d", address.String(), packet.DestinationStreamType()) + if !s.Endpoints.Has(packet.DestinationVirtualPortStreamID()) { + logger.Warningf("Client %s trying to connect to unbound PRUDPEndPoint %d", address.String(), packet.DestinationVirtualPortStreamID()) return } - virtualServer, _ := s.virtualServers.Get(packet.DestinationPort()) - virtualServerStream, _ := virtualServer.Get(packet.DestinationStreamType()) - - discriminator := fmt.Sprintf("%s-%d-%d", address.String(), packet.SourcePort(), packet.SourceStreamType()) - - client, ok := virtualServerStream.Get(discriminator) - + endpoint, ok := s.Endpoints.Get(packet.DestinationVirtualPortStreamID()) if !ok { - client = NewPRUDPClient(s, address, webSocketConnection) - client.startHeartbeat() - - // * Fail-safe. If the server reboots, then - // * clients has no record of old clients. - // * An existing client which has not killed - // * the connection on it's end MAY still send - // * DATA packets once the server is back - // * online, assuming it reboots fast enough. - // * Since the client did NOT redo the SYN - // * and CONNECT packets, it's reliable - // * substreams never got remade. This is put - // * in place to ensure there is always AT - // * LEAST one substream in place, so the client - // * can naturally error out due to the RC4 - // * errors. - // * - // * NOTE: THE CLIENT MAY NOT HAVE THE REAL - // * CORRECT NUMBER OF SUBSTREAMS HERE. THIS - // * IS ONLY DONE TO PREVENT A SERVER CRASH, - // * NOT TO SAVE THE CLIENT. THE CLIENT IS - // * EXPECTED TO NATURALLY DIE HERE - client.createReliableSubstreams(0) - - virtualServerStream.Set(discriminator, client) - } - - packet.SetSender(client) - client.resetHeartbeat() - - if packet.HasFlag(FlagAck) || packet.HasFlag(FlagMultiAck) { - s.handleAcknowledgment(packet) + logger.Warningf("Client %s trying to connect to unbound PRUDPEndPoint %d", address.String(), packet.DestinationVirtualPortStreamID()) return } - switch packet.Type() { - case SynPacket: - s.handleSyn(packet) - case ConnectPacket: - s.handleConnect(packet) - case DataPacket: - s.handleData(packet) - case DisconnectPacket: - s.handleDisconnect(packet) - case PingPacket: - s.handlePing(packet) - } -} - -func (s *PRUDPServer) handleAcknowledgment(packet PRUDPPacketInterface) { - if packet.HasFlag(FlagMultiAck) { - s.handleMultiAcknowledgment(packet) + if packet.DestinationVirtualPortStreamType() != packet.SourceVirtualPortStreamType() { + logger.Warningf("Client %s trying to use non matching destination and source stream types %d and %d", address.String(), packet.DestinationVirtualPortStreamType(), packet.SourceVirtualPortStreamType()) return } - client := packet.Sender().(*PRUDPClient) - - substream := client.reliableSubstream(packet.SubstreamID()) - substream.ResendScheduler.AcknowledgePacket(packet.SequenceID()) -} - -func (s *PRUDPServer) handleMultiAcknowledgment(packet PRUDPPacketInterface) { - client := packet.Sender().(*PRUDPClient) - stream := NewByteStreamIn(packet.Payload(), s) - sequenceIDs := make([]uint16, 0) - var baseSequenceID uint16 - var substream *ReliablePacketSubstreamManager - - if packet.SubstreamID() == 1 { - // * New aggregate acknowledgment packets set this to 1 - // * and encode the real substream ID in in the payload - substreamID, _ := stream.ReadPrimitiveUInt8() - additionalIDsCount, _ := stream.ReadPrimitiveUInt8() - baseSequenceID, _ = stream.ReadPrimitiveUInt16LE() - substream = client.reliableSubstream(substreamID) - - for i := 0; i < int(additionalIDsCount); i++ { - additionalID, _ := stream.ReadPrimitiveUInt16LE() - sequenceIDs = append(sequenceIDs, additionalID) - } - } else { - // TODO - This is how Kinnay's client handles this, but it doesn't make sense for QRV? Since it can have multiple reliable substreams? - // * Old aggregate acknowledgment packets always use - // * substream 0 - substream = client.reliableSubstream(0) - baseSequenceID = packet.SequenceID() - - for stream.Remaining() > 0 { - additionalID, _ := stream.ReadPrimitiveUInt16LE() - sequenceIDs = append(sequenceIDs, additionalID) - } - } - - // * MutexMap.Each locks the mutex, can't remove while reading. - // * Have to just loop again - substream.ResendScheduler.packets.Each(func(sequenceID uint16, pending *PendingPacket) bool { - if sequenceID <= baseSequenceID && !slices.Contains(sequenceIDs, sequenceID) { - sequenceIDs = append(sequenceIDs, sequenceID) - } - - return false - }) - - // * Actually remove the packets from the pool - for _, sequenceID := range sequenceIDs { - substream.ResendScheduler.AcknowledgePacket(sequenceID) - } -} - -func (s *PRUDPServer) handleSyn(packet PRUDPPacketInterface) { - client := packet.Sender().(*PRUDPClient) - - var ack PRUDPPacketInterface - - if packet.Version() == 2 { - ack, _ = NewPRUDPPacketLite(client, nil) - } else if packet.Version() == 1 { - ack, _ = NewPRUDPPacketV1(client, nil) - } else { - ack, _ = NewPRUDPPacketV0(client, nil) - } - - connectionSignature, err := packet.calculateConnectionSignature(client.address) - if err != nil { - logger.Error(err.Error()) - } - - client.reset() - client.clientConnectionSignature = connectionSignature - client.SourceStreamType = packet.SourceStreamType() - client.SourcePort = packet.SourcePort() - client.DestinationStreamType = packet.DestinationStreamType() - client.DestinationPort = packet.DestinationPort() - - ack.SetType(SynPacket) - ack.AddFlag(FlagAck) - ack.AddFlag(FlagHasSize) - ack.SetSourceStreamType(packet.DestinationStreamType()) - ack.SetSourcePort(packet.DestinationPort()) - ack.SetDestinationStreamType(packet.SourceStreamType()) - ack.SetDestinationPort(packet.SourcePort()) - ack.setConnectionSignature(connectionSignature) - ack.setSignature(ack.calculateSignature([]byte{}, []byte{})) - - if ack, ok := ack.(*PRUDPPacketV1); ok { - // * Negotiate with the client what we support - ack.maximumSubstreamID = packet.(*PRUDPPacketV1).maximumSubstreamID // * No change needed, we can just support what the client wants - ack.minorVersion = packet.(*PRUDPPacketV1).minorVersion // * No change needed, we can just support what the client wants - ack.supportedFunctions = s.SupportedFunctions & packet.(*PRUDPPacketV1).supportedFunctions - } - - s.emit("syn", ack) - - s.sendRaw(client, ack.Bytes()) -} - -func (s *PRUDPServer) handleConnect(packet PRUDPPacketInterface) { - client := packet.Sender().(*PRUDPClient) - - var ack PRUDPPacketInterface - - if packet.Version() == 2 { - ack, _ = NewPRUDPPacketLite(client, nil) - } else if packet.Version() == 1 { - ack, _ = NewPRUDPPacketV1(client, nil) - } else { - ack, _ = NewPRUDPPacketV0(client, nil) - } - - client.serverConnectionSignature = packet.getConnectionSignature() - client.clientSessionID = packet.SessionID() - - connectionSignature, err := packet.calculateConnectionSignature(client.address) - if err != nil { - logger.Error(err.Error()) - } - - client.serverSessionID = packet.SessionID() - - ack.SetType(ConnectPacket) - ack.AddFlag(FlagAck) - ack.AddFlag(FlagHasSize) - ack.SetSourceStreamType(packet.DestinationStreamType()) - ack.SetSourcePort(packet.DestinationPort()) - ack.SetDestinationStreamType(packet.SourceStreamType()) - ack.SetDestinationPort(packet.SourcePort()) - ack.setConnectionSignature(make([]byte, len(connectionSignature))) - ack.SetSessionID(client.serverSessionID) - ack.SetSequenceID(1) - - if ack, ok := ack.(*PRUDPPacketV1); ok { - // * At this stage the client and server have already - // * negotiated what they each can support, so configure - // * the client now and just send the client back the - // * negotiated configuration - ack.maximumSubstreamID = packet.(*PRUDPPacketV1).maximumSubstreamID - ack.minorVersion = packet.(*PRUDPPacketV1).minorVersion - ack.supportedFunctions = packet.(*PRUDPPacketV1).supportedFunctions - - client.minorVersion = ack.minorVersion - client.supportedFunctions = ack.supportedFunctions - client.createReliableSubstreams(ack.maximumSubstreamID) - client.outgoingUnreliableSequenceIDCounter = NewCounter[uint16](packet.(*PRUDPPacketV1).initialUnreliableSequenceID) - } else { - client.createReliableSubstreams(0) - } - - payload := make([]byte, 0) - - if slices.Contains(s.SecureVirtualServerPorts, packet.DestinationPort()) { - sessionKey, pid, checkValue, err := s.readKerberosTicket(packet.Payload()) - if err != nil { - logger.Error(err.Error()) - } - - client.SetPID(pid) - client.setSessionKey(sessionKey) - - responseCheckValue := checkValue + 1 - responseCheckValueBytes := make([]byte, 4) - - binary.LittleEndian.PutUint32(responseCheckValueBytes, responseCheckValue) - - checkValueResponse := types.NewBuffer(responseCheckValueBytes) - stream := NewByteStreamOut(s) - - checkValueResponse.WriteTo(stream) - - payload = stream.Bytes() - } - - ack.SetPayload(payload) - ack.setSignature(ack.calculateSignature([]byte{}, packet.getConnectionSignature())) - - s.emit("connect", ack) - - s.sendRaw(client, ack.Bytes()) -} - -func (s *PRUDPServer) handleData(packet PRUDPPacketInterface) { - if packet.HasFlag(FlagReliable) { - s.handleReliable(packet) - } else { - s.handleUnreliable(packet) - } -} - -func (s *PRUDPServer) handleDisconnect(packet PRUDPPacketInterface) { - if packet.HasFlag(FlagNeedsAck) { - s.acknowledgePacket(packet) - } - - virtualServer, _ := s.virtualServers.Get(packet.DestinationPort()) - virtualServerStream, _ := virtualServer.Get(packet.DestinationStreamType()) - - client := packet.Sender().(*PRUDPClient) - discriminator := fmt.Sprintf("%s-%d-%d", client.address.String(), packet.SourcePort(), packet.SourceStreamType()) - - client.cleanup() // * "removed" event is dispatched here - virtualServerStream.Delete(discriminator) - - s.emit("disconnect", packet) -} - -func (s *PRUDPServer) handlePing(packet PRUDPPacketInterface) { - if packet.HasFlag(FlagNeedsAck) { - s.acknowledgePacket(packet) - } -} - -func (s *PRUDPServer) readKerberosTicket(payload []byte) ([]byte, *types.PID, uint32, error) { - stream := NewByteStreamIn(payload, s) - - ticketData := types.NewBuffer(nil) - if err := ticketData.ExtractFrom(stream); err != nil { - return nil, nil, 0, err - } - - requestData := types.NewBuffer(nil) - if err := requestData.ExtractFrom(stream); err != nil { - return nil, nil, 0, err - } - - serverKey := DeriveKerberosKey(types.NewPID(2), s.kerberosPassword) - - ticket := NewKerberosTicketInternalData() - if err := ticket.Decrypt(NewByteStreamIn(ticketData.Value, s), serverKey); err != nil { - return nil, nil, 0, err - } - - ticketTime := ticket.Issued.Standard() - serverTime := time.Now().UTC() - - timeLimit := ticketTime.Add(time.Minute * 2) - if serverTime.After(timeLimit) { - return nil, nil, 0, errors.New("Kerberos ticket expired") - } - - sessionKey := ticket.SessionKey - kerberos := NewKerberosEncryption(sessionKey) - - decryptedRequestData, err := kerberos.Decrypt(requestData.Value) - if err != nil { - return nil, nil, 0, err - } - - checkDataStream := NewByteStreamIn(decryptedRequestData, s) - - userPID := types.NewPID(0) - if err := userPID.ExtractFrom(checkDataStream); err != nil { - return nil, nil, 0, err - } - - _, err = checkDataStream.ReadPrimitiveUInt32LE() // * CID of secure server station url - if err != nil { - return nil, nil, 0, err - } - - responseCheck, err := checkDataStream.ReadPrimitiveUInt32LE() - if err != nil { - return nil, nil, 0, err + if packet.DestinationVirtualPortStreamType() > StreamTypeRelay { + logger.Warningf("Client %s trying to use invalid to destination stream type %d", address.String(), packet.DestinationVirtualPortStreamType()) + return } - return sessionKey, userPID, responseCheck, nil -} - -func (s *PRUDPServer) acknowledgePacket(packet PRUDPPacketInterface) { - var ack PRUDPPacketInterface - - if packet.Version() == 2 { - ack, _ = NewPRUDPPacketLite(packet.Sender().(*PRUDPClient), nil) - } else if packet.Version() == 1 { - ack, _ = NewPRUDPPacketV1(packet.Sender().(*PRUDPClient), nil) - } else { - ack, _ = NewPRUDPPacketV0(packet.Sender().(*PRUDPClient), nil) + if packet.SourceVirtualPortStreamType() > StreamTypeRelay { + logger.Warningf("Client %s trying to use invalid to source stream type %d", address.String(), packet.DestinationVirtualPortStreamType()) + return } - ack.SetType(packet.Type()) - ack.AddFlag(FlagAck) - ack.SetSourceStreamType(packet.DestinationStreamType()) - ack.SetSourcePort(packet.DestinationPort()) - ack.SetDestinationStreamType(packet.SourceStreamType()) - ack.SetDestinationPort(packet.SourcePort()) - ack.SetSequenceID(packet.SequenceID()) - ack.setFragmentID(packet.getFragmentID()) - ack.SetSubstreamID(packet.SubstreamID()) - - s.sendPacket(ack) - - // * Servers send the DISCONNECT ACK 3 times - if packet.Type() == DisconnectPacket { - s.sendPacket(ack) - s.sendPacket(ack) - } -} + sourcePortNumber := packet.SourceVirtualPortStreamID() + invalidSourcePort := false -func (s *PRUDPServer) handleReliable(packet PRUDPPacketInterface) { - if packet.HasFlag(FlagNeedsAck) { - s.acknowledgePacket(packet) + // * PRUDPLite packets can use port numbers 0-31 + // * PRUDPv0 and PRUDPv1 can use port numbers 0-15 + if _, ok := packet.(*PRUDPPacketLite); ok && sourcePortNumber > 31 { + invalidSourcePort = true + } else if sourcePortNumber > 15 { + invalidSourcePort = true } - substream := packet.Sender().(*PRUDPClient).reliableSubstream(packet.SubstreamID()) - - for _, pendingPacket := range substream.Update(packet) { - if packet.Type() == DataPacket { - var decryptedPayload []byte - - if packet.Version() != 2 { - decryptedPayload = pendingPacket.decryptPayload() - } else { - // * PRUDPLite does not encrypt payloads - decryptedPayload = pendingPacket.Payload() - } - - decompressedPayload, err := s.CompressionAlgorithm.Decompress(decryptedPayload) - if err != nil { - logger.Error(err.Error()) - } - - payload := substream.AddFragment(decompressedPayload) - - if packet.getFragmentID() == 0 { - message := NewRMCMessage(s) - err := message.FromBytes(payload) - if err != nil { - // TODO - Should this return the error too? - logger.Error(err.Error()) - } - - substream.ResetFragmentedPayload() - - packet.SetRMCMessage(message) - - s.emit("data", packet) - } - } - } -} - -func (s *PRUDPServer) handleUnreliable(packet PRUDPPacketInterface) { - if packet.HasFlag(FlagNeedsAck) { - s.acknowledgePacket(packet) - } - - // * Since unreliable DATA packets can in theory reach the - // * server in any order, and they lack a substream, it's - // * not actually possible to know what order they should - // * be processed in for each request. So assume all packets - // * MUST be fragment 0 (unreliable packets do not have frags) - // * - // * Example - - // * - // * Say there is 2 requests to the same protocol, methods 1 - // * and 2. The starting unreliable sequence ID is 10. If both - // * method 1 and 2 are called at the same time, but method 1 - // * has a fragmented payload, the packets could, in theory, reach - // * the server like so: - // * - // * - Method1 - Sequence 10, Fragment 1 - // * - Method1 - Sequence 13, Fragment 3 - // * - Method2 - Sequence 12, Fragment 0 - // * - Method1 - Sequence 11, Fragment 2 - // * - Method1 - Sequence 14, Fragment 0 - // * - // * If we reorder these to the proper order, like so: - // * - // * - Method1 - Sequence 10, Fragment 1 - // * - Method1 - Sequence 11, Fragment 2 - // * - Method2 - Sequence 12, Fragment 0 - // * - Method1 - Sequence 13, Fragment 3 - // * - Method1 - Sequence 14, Fragment 0 - // * - // * We still have a gap where Method2 was called. It's not - // * possible to know if the packet with sequence ID 12 belongs - // * to the Method1 calls or not. We don't even know which methods - // * the packets are for at this stage yet, since the RMC data - // * can't be checked until all the fragments are collected and - // * the payload decrypted. In this case, we would see fragment 0 - // * and assume that's the end of fragments, losing the real last - // * fragments and resulting in a bad decryption - // TODO - Is this actually true? I'm just assuming, based on common sense, tbh. Kinnay also does not implement fragmented unreliable packets? - if packet.getFragmentID() != 0 { - logger.Warningf("Unexpected unreliable fragment ID. Expected 0, got %d", packet.getFragmentID()) + if invalidSourcePort { + logger.Warningf("Client %s trying to use invalid to source port number %d. Port number too large", address.String(), sourcePortNumber) return } - payload := packet.processUnreliableCrypto() - - message := NewRMCMessage(s) - err := message.FromBytes(payload) - if err != nil { - // TODO - Should this return the error too? - logger.Error(err.Error()) - } - - packet.SetRMCMessage(message) - - s.emit("data", packet) -} - -func (s *PRUDPServer) sendPing(client *PRUDPClient) { - var ping PRUDPPacketInterface - - if s.websocketServer != nil { - ping, _ = NewPRUDPPacketLite(client, nil) - } else if s.PRUDPVersion == 0 { - ping, _ = NewPRUDPPacketV0(client, nil) - } else { - ping, _ = NewPRUDPPacketV1(client, nil) + discriminator := address.String() + socket, ok := s.Connections.Get(discriminator) + if !ok { + socket = NewSocketConnection(s, address, webSocketConnection) + s.Connections.Set(discriminator, socket) } - ping.SetType(PingPacket) - ping.AddFlag(FlagNeedsAck) - ping.SetSourceStreamType(client.DestinationStreamType) - ping.SetSourcePort(client.DestinationPort) - ping.SetDestinationStreamType(client.SourceStreamType) - ping.SetDestinationPort(client.SourcePort) - ping.SetSubstreamID(0) - - s.sendPacket(ping) + endpoint.processPacket(packet, socket) } // Send sends the packet to the packets sender @@ -753,74 +241,75 @@ func (s *PRUDPServer) sendPacket(packet PRUDPPacketInterface) { // * multiple packets at once, due to the same pointer being // * reused, we must make a copy of the packet being sent packetCopy := packet.Copy() - client := packetCopy.Sender().(*PRUDPClient) + connection := packetCopy.Sender().(*PRUDPConnection) if !packetCopy.HasFlag(FlagAck) && !packetCopy.HasFlag(FlagMultiAck) { if packetCopy.HasFlag(FlagReliable) { - substream := client.reliableSubstream(packetCopy.SubstreamID()) - packetCopy.SetSequenceID(substream.NextOutgoingSequenceID()) + slidingWindow := connection.SlidingWindow(packetCopy.SubstreamID()) + packetCopy.SetSequenceID(slidingWindow.NextOutgoingSequenceID()) } else if packetCopy.Type() == DataPacket { - packetCopy.SetSequenceID(client.nextOutgoingUnreliableSequenceID()) + packetCopy.SetSequenceID(connection.outgoingUnreliableSequenceIDCounter.Next()) } else if packetCopy.Type() == PingPacket { - packetCopy.SetSequenceID(client.nextOutgoingPingSequenceID()) + packetCopy.SetSequenceID(connection.outgoingPingSequenceIDCounter.Next()) } else { packetCopy.SetSequenceID(0) } } - packetCopy.SetSessionID(client.serverSessionID) + packetCopy.SetSessionID(connection.ServerSessionID) if packetCopy.Type() == DataPacket && !packetCopy.HasFlag(FlagAck) && !packetCopy.HasFlag(FlagMultiAck) { if packetCopy.HasFlag(FlagReliable) { + slidingWindow := connection.SlidingWindow(packetCopy.SubstreamID()) payload := packetCopy.Payload() - compressedPayload, err := s.CompressionAlgorithm.Compress(payload) + compressedPayload, err := slidingWindow.streamSettings.CompressionAlgorithm.Compress(payload) if err != nil { logger.Error(err.Error()) } - substream := client.reliableSubstream(packetCopy.SubstreamID()) - // * According to other Quazal server implementations, // * the RC4 stream is always reset to the default key // * regardless if the client is connecting to a secure // * server (prudps) or not if packet.Version() == 0 && s.PRUDPV0Settings.IsQuazalMode { - substream.SetCipherKey([]byte("CD&ML")) + slidingWindow.SetCipherKey([]byte("CD&ML")) } - // * PRUDPLite packet. No RC4 - if packetCopy.Version() != 2 { - packetCopy.SetPayload(substream.Encrypt(compressedPayload)) + encryptedPayload, err := slidingWindow.streamSettings.EncryptionAlgorithm.Encrypt(compressedPayload) + if err != nil { + logger.Error(err.Error()) } + + packetCopy.SetPayload(encryptedPayload) } else { - // * PRUDPLite packet. No RC4 + // * PRUDPLite does not encrypt payloads, since they go over WSS if packetCopy.Version() != 2 { packetCopy.SetPayload(packetCopy.processUnreliableCrypto()) } } } - packetCopy.setSignature(packetCopy.calculateSignature(client.sessionKey, client.serverConnectionSignature)) + packetCopy.setSignature(packetCopy.calculateSignature(connection.SessionKey, connection.ServerConnectionSignature)) if packetCopy.HasFlag(FlagReliable) && packetCopy.HasFlag(FlagNeedsAck) { - substream := client.reliableSubstream(packetCopy.SubstreamID()) - substream.ResendScheduler.AddPacket(packetCopy) + slidingWindow := connection.SlidingWindow(packetCopy.SubstreamID()) + slidingWindow.ResendScheduler.AddPacket(packetCopy) } - s.sendRaw(packetCopy.Sender().(*PRUDPClient), packetCopy.Bytes()) + s.sendRaw(packetCopy.Sender().(*PRUDPConnection).Socket, packetCopy.Bytes()) } -// sendRaw will send the given client the provided packet -func (s *PRUDPServer) sendRaw(client *PRUDPClient, data []byte) { +// sendRaw will send the given socket the provided packet +func (s *PRUDPServer) sendRaw(socket *SocketConnection, data []byte) { // TODO - Should this return the error too? var err error - if s.udpSocket != nil { - _, err = s.udpSocket.WriteToUDP(data, client.address.(*net.UDPAddr)) - } else if client.webSocketConnection != nil { - err = client.webSocketConnection.WriteMessage(gws.OpcodeBinary, data) + if address, ok := socket.Address.(*net.UDPAddr); ok && s.udpSocket != nil { + _, err = s.udpSocket.WriteToUDP(data, address) + } else if socket.WebSocketConnection != nil { + err = socket.WebSocketConnection.WriteMessage(gws.OpcodeBinary, data) } if err != nil { @@ -966,49 +455,6 @@ func (s *PRUDPServer) NATTraversalProtocolVersion() *LibraryVersion { return s.natTraversalProtocolVersion } -// ConnectionIDCounter returns the servers CID counter -func (s *PRUDPServer) ConnectionIDCounter() *Counter[uint32] { - return s.connectionIDCounter -} - -// FindClientByConnectionID returns the PRUDP client connected with the given connection ID -func (s *PRUDPServer) FindClientByConnectionID(serverPort, serverStreamType uint8, connectedID uint32) *PRUDPClient { - var client *PRUDPClient - - virtualServer, _ := s.virtualServers.Get(serverPort) - virtualServerStream, _ := virtualServer.Get(serverStreamType) - - virtualServerStream.Each(func(discriminator string, c *PRUDPClient) bool { - if c.ConnectionID == connectedID { - client = c - return true - } - - return false - }) - - return client -} - -// FindClientByPID returns the PRUDP client connected with the given PID -func (s *PRUDPServer) FindClientByPID(serverPort, serverStreamType uint8, pid uint64) *PRUDPClient { - var client *PRUDPClient - - virtualServer, _ := s.virtualServers.Get(serverPort) - virtualServerStream, _ := virtualServer.Get(serverStreamType) - - virtualServerStream.Each(func(discriminator string, c *PRUDPClient) bool { - if c.pid.Value() == pid { - client = c - return true - } - - return false - }) - - return client -} - // PasswordFromPID calls the function set with SetPasswordFromPIDFunction and returns the result func (s *PRUDPServer) PasswordFromPID(pid *types.PID) (string, uint32) { if s.passwordFromPIDHandler == nil { @@ -1037,16 +483,12 @@ func (s *PRUDPServer) SetByteStreamSettings(byteStreamSettings *ByteStreamSettin // NewPRUDPServer will return a new PRUDP server func NewPRUDPServer() *PRUDPServer { return &PRUDPServer{ - VirtualServerPorts: []uint8{1}, - SecureVirtualServerPorts: make([]uint8, 0), - virtualServers: NewMutexMap[uint8, *MutexMap[uint8, *MutexMap[string, *PRUDPClient]]](), - kerberosKeySize: 32, - FragmentSize: 1300, - prudpEventHandlers: make(map[string][]func(PacketInterface)), - connectionIDCounter: NewCounter[uint32](10), - pingTimeout: time.Second * 15, - CompressionAlgorithm: compression.NewDummyCompression(), - byteStreamSettings: NewByteStreamSettings(), - PRUDPV0Settings: NewPRUDPV0Settings(), + Endpoints: NewMutexMap[uint8, *PRUDPEndPoint](), + Connections: NewMutexMap[string, *SocketConnection](), + kerberosKeySize: 32, + FragmentSize: 1300, + pingTimeout: time.Second * 15, + byteStreamSettings: NewByteStreamSettings(), + PRUDPV0Settings: NewPRUDPV0Settings(), } } diff --git a/prudp_virtual_stream_types.go b/prudp_virtual_stream_types.go deleted file mode 100644 index b639f0c0..00000000 --- a/prudp_virtual_stream_types.go +++ /dev/null @@ -1,36 +0,0 @@ -package nex - -const ( - // VirtualStreamTypeDO represents the DO PRUDP virtual connection stream type - VirtualStreamTypeDO uint8 = 1 - - // VirtualStreamTypeRV represents the RV PRUDP virtual connection stream type - VirtualStreamTypeRV uint8 = 2 - - // VirtualStreamTypeOldRVSec represents the OldRVSec PRUDP virtual connection stream type - VirtualStreamTypeOldRVSec uint8 = 3 - - // VirtualStreamTypeSBMGMT represents the SBMGMT PRUDP virtual connection stream type - VirtualStreamTypeSBMGMT uint8 = 4 - - // VirtualStreamTypeNAT represents the NAT PRUDP virtual connection stream type - VirtualStreamTypeNAT uint8 = 5 - - // VirtualStreamTypeSessionDiscovery represents the SessionDiscovery PRUDP virtual connection stream type - VirtualStreamTypeSessionDiscovery uint8 = 6 - - // VirtualStreamTypeNATEcho represents the NATEcho PRUDP virtual connection stream type - VirtualStreamTypeNATEcho uint8 = 7 - - // VirtualStreamTypeRouting represents the Routing PRUDP virtual connection stream type - VirtualStreamTypeRouting uint8 = 8 - - // VirtualStreamTypeGame represents the Game PRUDP virtual connection stream type - VirtualStreamTypeGame uint8 = 9 - - // VirtualStreamTypeRVSecure represents the RVSecure PRUDP virtual connection stream type - VirtualStreamTypeRVSecure uint8 = 10 - - // VirtualStreamTypeRelay represents the Relay PRUDP virtual connection stream type - VirtualStreamTypeRelay uint8 = 11 -) diff --git a/resend_scheduler.go b/resend_scheduler.go index df174edc..080f4e0c 100644 --- a/resend_scheduler.go +++ b/resend_scheduler.go @@ -5,11 +5,13 @@ import ( "time" ) +// TODO - REMOVE THIS ENTIRELY AND REPLACE IT WITH AN IMPLEMENTATION OF rdv::Timeout AND rdv::TimeoutManager AND USE MORE STREAM SETTINGS! + // PendingPacket represends a packet scheduled to be resent type PendingPacket struct { packet PRUDPPacketInterface lastSendTime time.Time - resendCount int + resendCount uint32 isAcknowledged bool interval time.Duration ticker *time.Ticker @@ -32,10 +34,9 @@ func (pi *PendingPacket) startResendTimer() { // ResendScheduler manages the resending of reliable PRUDP packets type ResendScheduler struct { - packets *MutexMap[uint16, *PendingPacket] - MaxResendCount int - Interval time.Duration - Increase time.Duration + packets *MutexMap[uint16, *PendingPacket] + Interval time.Duration + Increase time.Duration } // Stop kills the resend scheduler and stops all pending packets @@ -87,29 +88,29 @@ func (rs *ResendScheduler) resendPacket(pendingPacket *PendingPacket) { } packet := pendingPacket.packet - client := packet.Sender().(*PRUDPClient) + connection := packet.Sender().(*PRUDPConnection) + slidingWindow := connection.SlidingWindow(packet.SubstreamID()) - if pendingPacket.resendCount >= rs.MaxResendCount { - // * The maximum resend count has been reached, consider the client dead. + if pendingPacket.resendCount >= slidingWindow.streamSettings.MaxPacketRetransmissions { + // * The maximum resend count has been reached, consider the connection dead. pendingPacket.ticker.Stop() rs.packets.Delete(packet.SequenceID()) - client.cleanup() // * "removed" event is dispatched here - - virtualServer, _ := client.server.virtualServers.Get(client.DestinationPort) - virtualServerStream, _ := virtualServer.Get(client.DestinationStreamType) + connection.cleanup() // * "removed" event is dispatched here - discriminator := fmt.Sprintf("%s-%d-%d", client.address.String(), client.SourcePort, client.SourceStreamType) + streamType := packet.SourceVirtualPortStreamType() + streamID := packet.SourceVirtualPortStreamID() + discriminator := fmt.Sprintf("%s-%d-%d", packet.Sender().Address().String(), streamType, streamID) - virtualServerStream.Delete(discriminator) + connection.Endpoint.Connections.Delete(discriminator) return } if time.Since(pendingPacket.lastSendTime) >= rs.Interval { - // * Resend the packet to the client - server := client.server + // * Resend the packet to the connection + server := connection.Endpoint.Server data := packet.Bytes() - server.sendRaw(client, data) + server.sendRaw(connection.Socket, data) pendingPacket.interval += rs.Increase pendingPacket.ticker.Reset(pendingPacket.interval) @@ -125,9 +126,8 @@ func (rs *ResendScheduler) resendPacket(pendingPacket *PendingPacket) { // after the 1st, and the 3rd will take place 11 seconds after the 2nd func NewResendScheduler(maxResendCount int, interval, increase time.Duration) *ResendScheduler { return &ResendScheduler{ - packets: NewMutexMap[uint16, *PendingPacket](), - MaxResendCount: maxResendCount, - Interval: interval, - Increase: increase, + packets: NewMutexMap[uint16, *PendingPacket](), + Interval: interval, + Increase: increase, } } diff --git a/server_interface.go b/server_interface.go index a3ae1a2a..bed89aaa 100644 --- a/server_interface.go +++ b/server_interface.go @@ -16,7 +16,6 @@ type ServerInterface interface { NATTraversalProtocolVersion() *LibraryVersion SetDefaultLibraryVersion(version *LibraryVersion) Send(packet PacketInterface) - OnData(handler func(packet PacketInterface)) PasswordFromPID(pid *types.PID) (string, uint32) SetPasswordFromPIDFunction(handler func(pid *types.PID) (string, uint32)) ByteStreamSettings() *ByteStreamSettings diff --git a/sliding_window.go b/sliding_window.go new file mode 100644 index 00000000..c5f58002 --- /dev/null +++ b/sliding_window.go @@ -0,0 +1,81 @@ +package nex + +import ( + "time" +) + +// SlidingWindow is an implementation of rdv::SlidingWindow. +// SlidingWindow reorders pending reliable packets to ensure they are handled in the expected order. +// In the original library each virtual connection stream only uses a single SlidingWindow, but starting +// in PRUDPv1 with NEX virtual connections may have multiple reliable substreams and thus multiple SlidingWindows. +type SlidingWindow struct { + pendingPackets *MutexMap[uint16, PRUDPPacketInterface] + incomingSequenceIDCounter *Counter[uint16] + outgoingSequenceIDCounter *Counter[uint16] + streamSettings *StreamSettings + fragmentedPayload []byte + ResendScheduler *ResendScheduler +} + +// Update adds an incoming packet to the list of known packets and returns a list of packets to be processed in order +func (sw *SlidingWindow) Update(packet PRUDPPacketInterface) []PRUDPPacketInterface { + packets := make([]PRUDPPacketInterface, 0) + + if packet.SequenceID() >= sw.incomingSequenceIDCounter.Value && !sw.pendingPackets.Has(packet.SequenceID()) { + sw.pendingPackets.Set(packet.SequenceID(), packet) + + for sw.pendingPackets.Has(sw.incomingSequenceIDCounter.Value) { + storedPacket, _ := sw.pendingPackets.Get(sw.incomingSequenceIDCounter.Value) + packets = append(packets, storedPacket) + sw.pendingPackets.Delete(sw.incomingSequenceIDCounter.Value) + sw.incomingSequenceIDCounter.Next() + } + } + + return packets +} + +// SetCipherKey sets the reliable substreams RC4 cipher keys +func (sw *SlidingWindow) SetCipherKey(key []byte) { + sw.streamSettings.EncryptionAlgorithm.SetKey(key) +} + +// NextOutgoingSequenceID sets the reliable substreams RC4 cipher keys +func (sw *SlidingWindow) NextOutgoingSequenceID() uint16 { + return sw.outgoingSequenceIDCounter.Next() +} + +// Decrypt decrypts the provided data with the substreams decipher +func (sw *SlidingWindow) Decrypt(data []byte) ([]byte, error) { + return sw.streamSettings.EncryptionAlgorithm.Decrypt(data) +} + +// Encrypt encrypts the provided data with the substreams cipher +func (sw *SlidingWindow) Encrypt(data []byte) ([]byte, error) { + return sw.streamSettings.EncryptionAlgorithm.Encrypt(data) +} + +// AddFragment adds the given fragment to the substreams fragmented payload +// Returns the current fragmented payload +func (sw *SlidingWindow) AddFragment(fragment []byte) []byte { + sw.fragmentedPayload = append(sw.fragmentedPayload, fragment...) + + return sw.fragmentedPayload +} + +// ResetFragmentedPayload resets the substreams fragmented payload +func (sw *SlidingWindow) ResetFragmentedPayload() { + sw.fragmentedPayload = make([]byte, 0) +} + +// NewSlidingWindow initializes a new SlidingWindow with a starting counter value. +func NewSlidingWindow() *SlidingWindow { + sw := &SlidingWindow{ + pendingPackets: NewMutexMap[uint16, PRUDPPacketInterface](), + incomingSequenceIDCounter: NewCounter[uint16](0), + outgoingSequenceIDCounter: NewCounter[uint16](0), + ResendScheduler: NewResendScheduler(5, time.Second, 0), + } + + return sw +} diff --git a/socket_connection.go b/socket_connection.go new file mode 100644 index 00000000..941d318d --- /dev/null +++ b/socket_connection.go @@ -0,0 +1,26 @@ +package nex + +import ( + "net" + + "github.com/lxzan/gws" +) + +// SocketConnection represents a single open socket. +// A single socket may have many PRUDP connections open on it. +type SocketConnection struct { + Server *PRUDPServer // * PRUDP server the socket is connected to + Address net.Addr // * Sockets address + WebSocketConnection *gws.Conn // * Only used in PRUDPLite + Connections *MutexMap[uint8, *PRUDPConnection] // * Open PRUDP connections separated by rdv::Stream ID, also called "port number" +} + +// NewSocketConnection creates a new SocketConnection +func NewSocketConnection(server *PRUDPServer, address net.Addr, webSocketConnection *gws.Conn) *SocketConnection { + return &SocketConnection{ + Server: server, + Address: address, + WebSocketConnection: webSocketConnection, + Connections: NewMutexMap[uint8, *PRUDPConnection](), + } +} diff --git a/stream_settings.go b/stream_settings.go new file mode 100644 index 00000000..1cbb8408 --- /dev/null +++ b/stream_settings.go @@ -0,0 +1,70 @@ +package nex + +import ( + "github.com/PretendoNetwork/nex-go/compression" + "github.com/PretendoNetwork/nex-go/encryption" +) + +// StreamSettings is an implementation of rdv::StreamSettings. +// StreamSettings holds the state and settings for a PRUDP virtual connection stream. +// Each virtual connection is composed of a virtual port and stream type. +// In the original library this would be tied to a rdv::Stream class, but here it is not. +// The original library has more settings which are not present here as their use is unknown. +// Not all values are used at this time, and only exist to future-proof for a later time. +type StreamSettings struct { + ExtraRestransmitTimeoutTrigger uint32 // * Unused. The number of times a packet can be retransmitted before ExtraRetransmitTimeoutMultiplier is used + MaxPacketRetransmissions uint32 // *The number of times a packet can be retransmitted before the timeout time is checked + KeepAliveTimeout uint32 // * Unused. Presumably the time a packet can be alive for without acknowledgement? Milliseconds? + ChecksumBase uint32 // * Unused. The base value for PRUDPv0 checksum calculations + FaultDetectionEnabled bool // * Unused. Presumably used to detect PIA faults? + InitialRTT uint32 // * Unused. The connections initial RTT + EncryptionAlgorithm encryption.Algorithm // * The encryption algorithm used for packet payloads + ExtraRetransmitTimeoutMultiplier float32 // * Unused. Used as part of the RTO calculations when retransmitting a packet. Only used if ExtraRestransmitTimeoutTrigger has been reached + WindowSize uint32 // * Unused. The max number of (reliable?) packets allowed in a SlidingWindow + CompressionAlgorithm compression.Algorithm // * The compression algorithm used for packet payloads + RTTRetransmit uint32 // * Unused. Unknown use + RetransmitTimeoutMultiplier float32 // * Unused. Used as part of the RTO calculations when retransmitting a packet. Only used if ExtraRestransmitTimeoutTrigger has not been reached + MaxSilenceTime uint32 // * Unused. Presumably the time a connection can go without any packets from the other side? Milliseconds? +} + +// Copy returns a new copy of the settings +func (ss *StreamSettings) Copy() *StreamSettings { + copied := NewStreamSettings() + + copied.ExtraRestransmitTimeoutTrigger = ss.ExtraRestransmitTimeoutTrigger + copied.MaxPacketRetransmissions = ss.MaxPacketRetransmissions + copied.KeepAliveTimeout = ss.KeepAliveTimeout + copied.ChecksumBase = ss.ChecksumBase + copied.FaultDetectionEnabled = ss.FaultDetectionEnabled + copied.InitialRTT = ss.InitialRTT + copied.EncryptionAlgorithm = ss.EncryptionAlgorithm.Copy() + copied.ExtraRetransmitTimeoutMultiplier = ss.ExtraRetransmitTimeoutMultiplier + copied.WindowSize = ss.WindowSize + copied.CompressionAlgorithm = ss.CompressionAlgorithm.Copy() + copied.RTTRetransmit = ss.RTTRetransmit + copied.RetransmitTimeoutMultiplier = ss.RetransmitTimeoutMultiplier + copied.MaxSilenceTime = ss.MaxSilenceTime + + return copied +} + +// NewStreamSettings returns a new instance of StreamSettings with default params +func NewStreamSettings() *StreamSettings { + // * Default values based on WATCH_DOGS. Not all values are used currently, and only + // * exist to mimic what is seen in that game. Many are planned for future use. + return &StreamSettings{ + ExtraRestransmitTimeoutTrigger: 0x32, + MaxPacketRetransmissions: 0x14, + KeepAliveTimeout: 1000, + ChecksumBase: 0, + FaultDetectionEnabled: true, + InitialRTT: 0xFA, + EncryptionAlgorithm: encryption.NewRC4Encryption(), + ExtraRetransmitTimeoutMultiplier: 1.0, + WindowSize: 8, + CompressionAlgorithm: compression.NewDummyCompression(), + RTTRetransmit: 0x32, + RetransmitTimeoutMultiplier: 1.25, + MaxSilenceTime: 5000, + } +} diff --git a/stream_type.go b/stream_type.go new file mode 100644 index 00000000..8264cbf6 --- /dev/null +++ b/stream_type.go @@ -0,0 +1,49 @@ +package nex + +// TODO - Should this be moved to the types module? + +// StreamType is an implementation of rdv::Stream::Type. +// StreamType is used to create VirtualPorts used in PRUDP virtual +// connections. Each stream may be one of these types, and each stream +// has it's own state. +type StreamType uint8 + +// EnumIndex returns the StreamType enum index as a uint8 +func (st StreamType) EnumIndex() uint8 { + return uint8(st) +} + +const ( + // StreamTypeDO represents the DO PRUDP virtual connection stream type + StreamTypeDO StreamType = iota + 1 + + // StreamTypeRV represents the RV PRUDP virtual connection stream type + StreamTypeRV + + // StreamTypeOldRVSec represents the OldRVSec PRUDP virtual connection stream type + StreamTypeOldRVSec + + // StreamTypeSBMGMT represents the SBMGMT PRUDP virtual connection stream type + StreamTypeSBMGMT + + // StreamTypeNAT represents the NAT PRUDP virtual connection stream type + StreamTypeNAT + + // StreamTypeSessionDiscovery represents the SessionDiscovery PRUDP virtual connection stream type + StreamTypeSessionDiscovery + + // StreamTypeNATEcho represents the NATEcho PRUDP virtual connection stream type + StreamTypeNATEcho + + // StreamTypeRouting represents the Routing PRUDP virtual connection stream type + StreamTypeRouting + + // StreamTypeGame represents the Game PRUDP virtual connection stream type + StreamTypeGame + + // StreamTypeRVSecure represents the RVSecure PRUDP virtual connection stream type + StreamTypeRVSecure + + // StreamTypeRelay represents the Relay PRUDP virtual connection stream type + StreamTypeRelay +) diff --git a/test/auth.go b/test/auth.go index bdf30c2b..7c674603 100644 --- a/test/auth.go +++ b/test/auth.go @@ -16,7 +16,9 @@ func startAuthenticationServer() { authServer = nex.NewPRUDPServer() - authServer.OnData(func(packet nex.PacketInterface) { + endpoint := nex.NewPRUDPEndPoint(1) + + endpoint.OnData(func(packet nex.PacketInterface) { if packet, ok := packet.(nex.PRUDPPacketInterface); ok { request := packet.RMCMessage() @@ -35,11 +37,11 @@ func startAuthenticationServer() { }) authServer.SetFragmentSize(962) - //authServer.PRUDPVersion = 1 authServer.SetDefaultLibraryVersion(nex.NewLibraryVersion(1, 1, 0)) authServer.SetKerberosPassword([]byte("password")) authServer.SetKerberosKeySize(16) authServer.SetAccessKey("ridfebb9") + authServer.BindPRUDPEndPoint(endpoint) authServer.Listen(60000) } @@ -89,16 +91,16 @@ func login(packet nex.PRUDPPacketInterface) { response.MethodID = request.MethodID response.Parameters = responseStream.Bytes() - responsePacket, _ := nex.NewPRUDPPacketV0(packet.Sender().(*nex.PRUDPClient), nil) + responsePacket, _ := nex.NewPRUDPPacketV0(packet.Sender().(*nex.PRUDPConnection), nil) responsePacket.SetType(packet.Type()) responsePacket.AddFlag(nex.FlagHasSize) responsePacket.AddFlag(nex.FlagReliable) responsePacket.AddFlag(nex.FlagNeedsAck) - responsePacket.SetSourceStreamType(packet.DestinationStreamType()) - responsePacket.SetSourcePort(packet.DestinationPort()) - responsePacket.SetDestinationStreamType(packet.SourceStreamType()) - responsePacket.SetDestinationPort(packet.SourcePort()) + responsePacket.SetSourceVirtualPortStreamType(packet.DestinationVirtualPortStreamType()) + responsePacket.SetSourceVirtualPortStreamID(packet.DestinationVirtualPortStreamID()) + responsePacket.SetDestinationVirtualPortStreamType(packet.SourceVirtualPortStreamType()) + responsePacket.SetDestinationVirtualPortStreamID(packet.SourceVirtualPortStreamID()) responsePacket.SetSubstreamID(packet.SubstreamID()) responsePacket.SetPayload(response.Bytes()) @@ -139,16 +141,16 @@ func requestTicket(packet nex.PRUDPPacketInterface) { response.MethodID = request.MethodID response.Parameters = responseStream.Bytes() - responsePacket, _ := nex.NewPRUDPPacketV0(packet.Sender().(*nex.PRUDPClient), nil) + responsePacket, _ := nex.NewPRUDPPacketV0(packet.Sender().(*nex.PRUDPConnection), nil) responsePacket.SetType(packet.Type()) responsePacket.AddFlag(nex.FlagHasSize) responsePacket.AddFlag(nex.FlagReliable) responsePacket.AddFlag(nex.FlagNeedsAck) - responsePacket.SetSourceStreamType(packet.DestinationStreamType()) - responsePacket.SetSourcePort(packet.DestinationPort()) - responsePacket.SetDestinationStreamType(packet.SourceStreamType()) - responsePacket.SetDestinationPort(packet.SourcePort()) + responsePacket.SetSourceVirtualPortStreamType(packet.DestinationVirtualPortStreamType()) + responsePacket.SetSourceVirtualPortStreamID(packet.DestinationVirtualPortStreamID()) + responsePacket.SetDestinationVirtualPortStreamType(packet.SourceVirtualPortStreamType()) + responsePacket.SetDestinationVirtualPortStreamID(packet.SourceVirtualPortStreamID()) responsePacket.SetSubstreamID(packet.SubstreamID()) responsePacket.SetPayload(response.Bytes()) diff --git a/test/secure.go b/test/secure.go index a9eddc0c..f0aaa42b 100644 --- a/test/secure.go +++ b/test/secure.go @@ -46,7 +46,10 @@ func startSecureServer() { secureServer = nex.NewPRUDPServer() - secureServer.OnData(func(packet nex.PacketInterface) { + endpoint := nex.NewPRUDPEndPoint(1) + endpoint.IsSecureEndpoint = true + + endpoint.OnData(func(packet nex.PacketInterface) { if packet, ok := packet.(nex.PRUDPPacketInterface); ok { request := packet.RMCMessage() @@ -72,19 +75,19 @@ func startSecureServer() { } }) - secureServer.SecureVirtualServerPorts = []uint8{1} - //secureServer.PRUDPVersion = 1 secureServer.SetFragmentSize(962) secureServer.SetDefaultLibraryVersion(nex.NewLibraryVersion(1, 1, 0)) secureServer.SetKerberosPassword([]byte("password")) secureServer.SetKerberosKeySize(16) secureServer.SetAccessKey("ridfebb9") + secureServer.BindPRUDPEndPoint(endpoint) secureServer.Listen(60001) } func registerEx(packet nex.PRUDPPacketInterface) { request := packet.RMCMessage() response := nex.NewRMCMessage(secureServer) + connection := packet.Sender().(*nex.PRUDPConnection) parameters := request.Parameters @@ -114,7 +117,7 @@ func registerEx(packet nex.PRUDPPacketInterface) { responseStream := nex.NewByteStreamOut(secureServer) retval.WriteTo(responseStream) - responseStream.WritePrimitiveUInt32LE(secureServer.ConnectionIDCounter().Next()) + responseStream.WritePrimitiveUInt32LE(connection.ID) localStationURL.WriteTo(responseStream) response.IsSuccess = true @@ -125,16 +128,16 @@ func registerEx(packet nex.PRUDPPacketInterface) { response.MethodID = request.MethodID response.Parameters = responseStream.Bytes() - responsePacket, _ := nex.NewPRUDPPacketV0(packet.Sender().(*nex.PRUDPClient), nil) + responsePacket, _ := nex.NewPRUDPPacketV0(connection, nil) responsePacket.SetType(packet.Type()) responsePacket.AddFlag(nex.FlagHasSize) responsePacket.AddFlag(nex.FlagReliable) responsePacket.AddFlag(nex.FlagNeedsAck) - responsePacket.SetSourceStreamType(packet.DestinationStreamType()) - responsePacket.SetSourcePort(packet.DestinationPort()) - responsePacket.SetDestinationStreamType(packet.SourceStreamType()) - responsePacket.SetDestinationPort(packet.SourcePort()) + responsePacket.SetSourceVirtualPortStreamType(packet.DestinationVirtualPortStreamType()) + responsePacket.SetSourceVirtualPortStreamID(packet.DestinationVirtualPortStreamID()) + responsePacket.SetDestinationVirtualPortStreamType(packet.SourceVirtualPortStreamType()) + responsePacket.SetDestinationVirtualPortStreamID(packet.SourceVirtualPortStreamID()) responsePacket.SetSubstreamID(packet.SubstreamID()) responsePacket.SetPayload(response.Bytes()) @@ -173,16 +176,16 @@ func updateAndGetAllInformation(packet nex.PRUDPPacketInterface) { response.MethodID = request.MethodID response.Parameters = responseStream.Bytes() - responsePacket, _ := nex.NewPRUDPPacketV0(packet.Sender().(*nex.PRUDPClient), nil) + responsePacket, _ := nex.NewPRUDPPacketV0(packet.Sender().(*nex.PRUDPConnection), nil) responsePacket.SetType(packet.Type()) responsePacket.AddFlag(nex.FlagHasSize) responsePacket.AddFlag(nex.FlagReliable) responsePacket.AddFlag(nex.FlagNeedsAck) - responsePacket.SetSourceStreamType(packet.DestinationStreamType()) - responsePacket.SetSourcePort(packet.DestinationPort()) - responsePacket.SetDestinationStreamType(packet.SourceStreamType()) - responsePacket.SetDestinationPort(packet.SourcePort()) + responsePacket.SetSourceVirtualPortStreamType(packet.DestinationVirtualPortStreamType()) + responsePacket.SetSourceVirtualPortStreamID(packet.DestinationVirtualPortStreamID()) + responsePacket.SetDestinationVirtualPortStreamType(packet.SourceVirtualPortStreamType()) + responsePacket.SetDestinationVirtualPortStreamID(packet.SourceVirtualPortStreamID()) responsePacket.SetSubstreamID(packet.SubstreamID()) responsePacket.SetPayload(response.Bytes()) @@ -205,16 +208,16 @@ func checkSettingStatus(packet nex.PRUDPPacketInterface) { response.MethodID = request.MethodID response.Parameters = responseStream.Bytes() - responsePacket, _ := nex.NewPRUDPPacketV0(packet.Sender().(*nex.PRUDPClient), nil) + responsePacket, _ := nex.NewPRUDPPacketV0(packet.Sender().(*nex.PRUDPConnection), nil) responsePacket.SetType(packet.Type()) responsePacket.AddFlag(nex.FlagHasSize) responsePacket.AddFlag(nex.FlagReliable) responsePacket.AddFlag(nex.FlagNeedsAck) - responsePacket.SetSourceStreamType(packet.DestinationStreamType()) - responsePacket.SetSourcePort(packet.DestinationPort()) - responsePacket.SetDestinationStreamType(packet.SourceStreamType()) - responsePacket.SetDestinationPort(packet.SourcePort()) + responsePacket.SetSourceVirtualPortStreamType(packet.DestinationVirtualPortStreamType()) + responsePacket.SetSourceVirtualPortStreamID(packet.DestinationVirtualPortStreamID()) + responsePacket.SetDestinationVirtualPortStreamType(packet.SourceVirtualPortStreamType()) + responsePacket.SetDestinationVirtualPortStreamID(packet.SourceVirtualPortStreamID()) responsePacket.SetSubstreamID(packet.SubstreamID()) responsePacket.SetPayload(response.Bytes()) @@ -232,16 +235,16 @@ func updatePresence(packet nex.PRUDPPacketInterface) { response.CallID = request.CallID response.MethodID = request.MethodID - responsePacket, _ := nex.NewPRUDPPacketV0(packet.Sender().(*nex.PRUDPClient), nil) + responsePacket, _ := nex.NewPRUDPPacketV0(packet.Sender().(*nex.PRUDPConnection), nil) responsePacket.SetType(packet.Type()) responsePacket.AddFlag(nex.FlagHasSize) responsePacket.AddFlag(nex.FlagReliable) responsePacket.AddFlag(nex.FlagNeedsAck) - responsePacket.SetSourceStreamType(packet.DestinationStreamType()) - responsePacket.SetSourcePort(packet.DestinationPort()) - responsePacket.SetDestinationStreamType(packet.SourceStreamType()) - responsePacket.SetDestinationPort(packet.SourcePort()) + responsePacket.SetSourceVirtualPortStreamType(packet.DestinationVirtualPortStreamType()) + responsePacket.SetSourceVirtualPortStreamID(packet.DestinationVirtualPortStreamID()) + responsePacket.SetDestinationVirtualPortStreamType(packet.SourceVirtualPortStreamType()) + responsePacket.SetDestinationVirtualPortStreamID(packet.SourceVirtualPortStreamID()) responsePacket.SetSubstreamID(packet.SubstreamID()) responsePacket.SetPayload(response.Bytes()) diff --git a/virtual_port.go b/virtual_port.go new file mode 100644 index 00000000..7e38244f --- /dev/null +++ b/virtual_port.go @@ -0,0 +1,31 @@ +package nex + +// TODO - Should this be moved to the types module? + +// VirtualPort in an implementation of rdv::VirtualPort. +// PRUDP will reuse a single physical socket connection for many virtual PRUDP connections. +// VirtualPorts are a byte which represents a stream for a virtual PRUDP connection. +// This byte is two 4-bit fields. The upper 4 bits are the stream type, the lower 4 bits +// are the stream ID. The client starts with stream ID 15, decrementing by one with each new +// virtual connection. +type VirtualPort byte + +// SetStreamType sets the VirtualPort stream type +func (vp *VirtualPort) SetStreamType(streamType StreamType) { + *vp = VirtualPort((byte(*vp) & 0x0F) | (byte(streamType) << 4)) +} + +// StreamType returns the VirtualPort stream type +func (vp VirtualPort) StreamType() StreamType { + return StreamType(vp >> 4) +} + +// SetStreamID sets the VirtualPort stream ID +func (vp *VirtualPort) SetStreamID(streamID uint8) { + *vp = VirtualPort((byte(*vp) & 0xF0) | (streamID & 0x0F)) +} + +// StreamID returns the VirtualPort stream ID +func (vp VirtualPort) StreamID() uint8 { + return uint8(vp & 0xF) +} diff --git a/websocket_server.go b/websocket_server.go index 8a3f1587..e7f92931 100644 --- a/websocket_server.go +++ b/websocket_server.go @@ -3,7 +3,6 @@ package nex import ( "fmt" "net/http" - "strings" "time" "github.com/lxzan/gws" @@ -22,26 +21,17 @@ func (wseh *wsEventHandler) OnOpen(socket *gws.Conn) { _ = socket.SetDeadline(time.Now().Add(pingInterval + pingWait)) } -func (wseh *wsEventHandler) OnClose(socket *gws.Conn, err error) { - clientsToCleanup := make([]*PRUDPClient, 0) +func (wseh *wsEventHandler) OnClose(wsConn *gws.Conn, err error) { + connections := make([]*PRUDPConnection, 0) - // * Loop over all bound ports, and each ports stream types - // * to look for clients connecting from this WebSocket - // TODO - This kinda sucks tbh. Unsure how much this effects performance. Test more and refactor? - wseh.prudpServer.virtualServers.Each(func(port uint8, stream *MutexMap[uint8, *MutexMap[string, *PRUDPClient]]) bool { - stream.Each(func(streamType uint8, clients *MutexMap[string, *PRUDPClient]) bool { - clients.Each(func(discriminator string, client *PRUDPClient) bool { - if strings.HasPrefix(discriminator, socket.RemoteAddr().String()) { - clientsToCleanup = append(clientsToCleanup, client) - return true // * Assume only one client connected per server port per stream type - } - - return false - }) - - return false - }) + socket, ok := wseh.prudpServer.Connections.Get(wsConn.RemoteAddr().String()) + if !ok { + // TODO - Error? + return + } + socket.Connections.Each(func(_ uint8, connection *PRUDPConnection) bool { + connections = append(connections, connection) return false }) @@ -49,15 +39,8 @@ func (wseh *wsEventHandler) OnClose(socket *gws.Conn, err error) { // * since the mutex is locked. We first need to grab // * the entries we want to delete, and then loop over // * them here to actually clean them up - for _, client := range clientsToCleanup { - client.cleanup() // * "removed" event is dispatched here - - virtualServer, _ := wseh.prudpServer.virtualServers.Get(client.DestinationPort) - virtualServerStream, _ := virtualServer.Get(client.DestinationStreamType) - - discriminator := fmt.Sprintf("%s-%d-%d", client.address.String(), client.SourcePort, client.SourceStreamType) - - virtualServerStream.Delete(discriminator) + for _, connection := range connections { + connection.cleanup() // * "removed" event is dispatched here } } From a63ca441a93c33d60c3f8b1763bee735dbd84a2f Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Mon, 15 Jan 2024 15:07:57 -0500 Subject: [PATCH 112/178] encryption: added QuazalRC4 --- encryption/quazal_rc4.go | 95 ++++++++++++++++++++++++++++++++++++++++ prudp_packet.go | 8 ---- prudp_server.go | 8 ---- 3 files changed, 95 insertions(+), 16 deletions(-) create mode 100644 encryption/quazal_rc4.go diff --git a/encryption/quazal_rc4.go b/encryption/quazal_rc4.go new file mode 100644 index 00000000..63c423d7 --- /dev/null +++ b/encryption/quazal_rc4.go @@ -0,0 +1,95 @@ +package encryption + +import ( + "crypto/rc4" +) + +// QuazalRC4 encrypts data with RC4. Each iteration uses a new cipher instance. The key is always CD&ML +type QuazalRC4 struct { + key []byte + cipher *rc4.Cipher + decipher *rc4.Cipher + cipheredCount uint64 + decipheredCount uint64 +} + +// Key returns the crypto key +func (r *QuazalRC4) Key() []byte { + return r.key +} + +// SetKey sets the crypto key and updates the ciphers +func (r *QuazalRC4) SetKey(key []byte) error { + r.key = key + + cipher, err := rc4.NewCipher(key) + if err != nil { + return err + } + + decipher, err := rc4.NewCipher(key) + if err != nil { + return err + } + + r.cipher = cipher + r.decipher = decipher + + return nil +} + +// Encrypt encrypts the payload with the outgoing QuazalRC4 stream +func (r *QuazalRC4) Encrypt(payload []byte) ([]byte, error) { + r.SetKey([]byte("CD&ML")) + + ciphered := make([]byte, len(payload)) + + r.cipher.XORKeyStream(ciphered, payload) + + r.cipheredCount += uint64(len(payload)) + + return ciphered, nil +} + +// Decrypt decrypts the payload with the incoming QuazalRC4 stream +func (r *QuazalRC4) Decrypt(payload []byte) ([]byte, error) { + r.SetKey([]byte("CD&ML")) + + deciphered := make([]byte, len(payload)) + + r.decipher.XORKeyStream(deciphered, payload) + + r.decipheredCount += uint64(len(payload)) + + return deciphered, nil +} + +// Copy returns a copy of the algorithm while retaining it's state +func (r *QuazalRC4) Copy() Algorithm { + copied := NewQuazalRC4Encryption() + + copied.SetKey(r.key) + + // * crypto/rc4 does not expose a way to directly copy streams and retain their state. + // * This just discards the number of iterations done in the original ciphers to sync + // * the copied ciphers states to the original + for i := 0; i < int(r.cipheredCount); i++ { + copied.cipher.XORKeyStream([]byte{0}, []byte{0}) + } + + for i := 0; i < int(r.decipheredCount); i++ { + copied.decipher.XORKeyStream([]byte{0}, []byte{0}) + } + + copied.cipheredCount = r.cipheredCount + copied.decipheredCount = r.decipheredCount + + return copied +} + +// NewQuazalRC4Encryption returns a new instance of the QuazalRC4 encryption +func NewQuazalRC4Encryption() *QuazalRC4 { + return &QuazalRC4{ + key: make([]byte, 0), + } +} diff --git a/prudp_packet.go b/prudp_packet.go index a29fa418..535628a9 100644 --- a/prudp_packet.go +++ b/prudp_packet.go @@ -148,14 +148,6 @@ func (p *PRUDPPacket) decryptPayload() []byte { if p.packetType == DataPacket { slidingWindow := p.sender.SlidingWindow(p.SubstreamID()) - // * According to other Quazal server implementations, - // * the RC4 stream is always reset to the default key - // * regardless if the client is connecting to a secure - // * server (prudps) or not - if p.version == 0 && p.sender.Endpoint.Server.PRUDPV0Settings.IsQuazalMode { - slidingWindow.SetCipherKey([]byte("CD&ML")) - } - payload, _ = slidingWindow.streamSettings.EncryptionAlgorithm.Decrypt(payload) } diff --git a/prudp_server.go b/prudp_server.go index 7cfa9712..9efdc011 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -268,14 +268,6 @@ func (s *PRUDPServer) sendPacket(packet PRUDPPacketInterface) { logger.Error(err.Error()) } - // * According to other Quazal server implementations, - // * the RC4 stream is always reset to the default key - // * regardless if the client is connecting to a secure - // * server (prudps) or not - if packet.Version() == 0 && s.PRUDPV0Settings.IsQuazalMode { - slidingWindow.SetCipherKey([]byte("CD&ML")) - } - encryptedPayload, err := slidingWindow.streamSettings.EncryptionAlgorithm.Encrypt(compressedPayload) if err != nil { logger.Error(err.Error()) From a8eb72b445a62c9efc1e7fc6b1ad3d15cbd267ed Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Mon, 15 Jan 2024 15:08:17 -0500 Subject: [PATCH 113/178] encryption: updated RC4 Godoc comment --- encryption/rc4.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/encryption/rc4.go b/encryption/rc4.go index c4b16b9a..a59f7813 100644 --- a/encryption/rc4.go +++ b/encryption/rc4.go @@ -4,7 +4,7 @@ import ( "crypto/rc4" ) -// RC4 does no encryption. Payloads are returned as-is +// RC4 encrypts data with RC4 type RC4 struct { key []byte cipher *rc4.Cipher From 24d78540b5f720f355edbf4d2501c4dc2f03e974 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Mon, 15 Jan 2024 15:13:43 -0500 Subject: [PATCH 114/178] chore: update README --- README.md | 46 ++++++++++++++++++++++++++++++---------------- 1 file changed, 30 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index 0c5effbf..ab6df18f 100644 --- a/README.md +++ b/README.md @@ -59,23 +59,37 @@ import ( ) func main() { - nexServer := nex.NewPRUDPServer() - nexServer.PRUDPVersion = 0 - nexServer.SetFragmentSize(962) - nexServer.SetDefaultLibraryVersion(nex.NewLibraryVersion(1, 1, 0)) - nexServer.SetKerberosPassword([]byte("password")) - nexServer.SetKerberosKeySize(16) - nexServer.SetAccessKey("ridfebb9") - - nexServer.OnData(func(packet nex.PacketInterface) { - request := packet.RMCMessage() - - fmt.Println("==Friends - Auth==") - fmt.Printf("Protocol ID: %#v\n", request.ProtocolID) - fmt.Printf("Method ID: %#v\n", request.MethodID) - fmt.Println("==================") + // Skeleton of a WiiU/3DS Friends server running on PRUDPv0 with a single endpoint + + authServer := nex.NewPRUDPServer() // The main PRUDP server + endpoint := nex.NewPRUDPEndPoint(1) // A PRUDP endpoint for PRUDP connections to connect to. Bound to StreamID 1 + + // Setup event handlers for the endpoint + endpoint.OnData(func(packet nex.PacketInterface) { + if packet, ok := packet.(nex.PRUDPPacketInterface); ok { + request := packet.RMCMessage() + + fmt.Println("[AUTH]", request.ProtocolID, request.MethodID) + + if request.ProtocolID == 0xA { // TicketGrantingProtocol + if request.MethodID == 0x1 { // TicketGrantingProtocol::Login + handleLogin(packet) + } + + if request.MethodID == 0x3 { // TicketGrantingProtocol::RequestTicket + handleRequestTicket(packet) + } + } + } }) - nexServer.Listen(60000) + // Bind the endpoint to the server and configure it's settings + authServer.BindPRUDPEndPoint(endpoint) + authServer.SetFragmentSize(962) + authServer.SetDefaultLibraryVersion(nex.NewLibraryVersion(1, 1, 0)) + authServer.SetKerberosPassword([]byte("password")) + authServer.SetKerberosKeySize(16) + authServer.SetAccessKey("ridfebb9") + authServer.Listen(60000) } ``` From 8dba91974d5b090c24fe09f7f2a6ca99227461ec Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Mon, 15 Jan 2024 16:32:57 -0500 Subject: [PATCH 115/178] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Daniel López Guimaraes <112760654+DaniElectra@users.noreply.github.com> --- prudp_server.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/prudp_server.go b/prudp_server.go index 9efdc011..626ea6a0 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -102,7 +102,6 @@ func (s *PRUDPServer) listenDatagram(quit chan struct{}) { // ListenWebSocket starts a PRUDP server on a given port using a WebSocket server func (s *PRUDPServer) ListenWebSocket(port int) { s.initPRUDPv1ConnectionSignatureKey() - //s.initVirtualPorts() s.websocketServer = &WebSocketServer{ prudpServer: s, @@ -114,7 +113,6 @@ func (s *PRUDPServer) ListenWebSocket(port int) { // ListenWebSocketSecure starts a PRUDP server on a given port using a secure (TLS) WebSocket server func (s *PRUDPServer) ListenWebSocketSecure(port int, certFile, keyFile string) { s.initPRUDPv1ConnectionSignatureKey() - //s.initVirtualPorts() s.websocketServer = &WebSocketServer{ prudpServer: s, From 9dc57ac1ff13275fc137dc596d916c40fcc3a500 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Mon, 15 Jan 2024 16:39:05 -0500 Subject: [PATCH 116/178] types: Result -> QResult --- test/auth.go | 4 +- test/secure.go | 4 +- types/qresult.go | 96 ++++++++++++++++++++++++++++++++++++++++++++++++ types/result.go | 96 ------------------------------------------------ 4 files changed, 100 insertions(+), 100 deletions(-) create mode 100644 types/qresult.go delete mode 100644 types/result.go diff --git a/test/auth.go b/test/auth.go index 7c674603..b6590415 100644 --- a/test/auth.go +++ b/test/auth.go @@ -63,7 +63,7 @@ func login(packet nex.PRUDPPacketInterface) { panic(err) } - retval := types.NewResultSuccess(0x00010001) + retval := types.NewQResultSuccess(0x00010001) pidPrincipal := types.NewPID(uint64(converted)) pbufResponse := types.NewBuffer(generateTicket(pidPrincipal, types.NewPID(2))) pConnectionData := types.NewRVConnectionData() @@ -125,7 +125,7 @@ func requestTicket(packet nex.PRUDPPacketInterface) { panic(err) } - retval := types.NewResultSuccess(0x00010001) + retval := types.NewQResultSuccess(0x00010001) pbufResponse := types.NewBuffer(generateTicket(idSource, idTarget)) responseStream := nex.NewByteStreamOut(authServer) diff --git a/test/secure.go b/test/secure.go index f0aaa42b..9b3c54cb 100644 --- a/test/secure.go +++ b/test/secure.go @@ -46,7 +46,7 @@ func startSecureServer() { secureServer = nex.NewPRUDPServer() - endpoint := nex.NewPRUDPEndPoint(1) + endpoint := nex.NewPRUDPEndPoint(2) endpoint.IsSecureEndpoint = true endpoint.OnData(func(packet nex.PacketInterface) { @@ -111,7 +111,7 @@ func registerEx(packet nex.PRUDPPacketInterface) { localStation.Fields["address"] = address localStation.Fields["port"] = strconv.Itoa(packet.Sender().Address().(*net.UDPAddr).Port) - retval := types.NewResultSuccess(0x00010001) + retval := types.NewQResultSuccess(0x00010001) localStationURL := types.NewString(localStation.EncodeToString()) responseStream := nex.NewByteStreamOut(secureServer) diff --git a/types/qresult.go b/types/qresult.go new file mode 100644 index 00000000..b3af2637 --- /dev/null +++ b/types/qresult.go @@ -0,0 +1,96 @@ +package types + +import ( + "fmt" + "strings" +) + +var errorMask = 1 << 31 + +// QResult is an implementation of rdv::qResult. +// Determines the result of an operation. +// If the MSB is set the result is an error, otherwise success +type QResult struct { + Code uint32 +} + +// WriteTo writes the QResult to the given writable +func (r *QResult) WriteTo(writable Writable) { + writable.WritePrimitiveUInt32LE(r.Code) +} + +// ExtractFrom extracts the QResult from the given readable +func (r *QResult) ExtractFrom(readable Readable) error { + code, err := readable.ReadPrimitiveUInt32LE() + if err != nil { + return fmt.Errorf("Failed to read QResult code. %s", err.Error()) + } + + r.Code = code + + return nil +} + +// Copy returns a pointer to a copy of the QResult. Requires type assertion when used +func (r *QResult) Copy() RVType { + return NewQResult(r.Code) +} + +// Equals checks if the input is equal in value to the current instance +func (r *QResult) Equals(o RVType) bool { + if _, ok := o.(*QResult); !ok { + return false + } + + return r.Code == o.(*QResult).Code +} + +// IsSuccess returns true if the QResult is a success +func (r *QResult) IsSuccess() bool { + return int(r.Code)&errorMask == 0 +} + +// IsError returns true if the QResult is a error +func (r *QResult) IsError() bool { + return int(r.Code)&errorMask != 0 +} + +// String returns a string representation of the struct +func (r *QResult) String() string { + return r.FormatToString(0) +} + +// FormatToString pretty-prints the struct data using the provided indentation level +func (r *QResult) FormatToString(indentationLevel int) string { + indentationValues := strings.Repeat("\t", indentationLevel+1) + indentationEnd := strings.Repeat("\t", indentationLevel) + + var b strings.Builder + + b.WriteString("QResult{\n") + + if r.IsSuccess() { + b.WriteString(fmt.Sprintf("%scode: %d (success)\n", indentationValues, r.Code)) + } else { + b.WriteString(fmt.Sprintf("%scode: %d (error)\n", indentationValues, r.Code)) + } + + b.WriteString(fmt.Sprintf("%s}", indentationEnd)) + + return b.String() +} + +// NewQResult returns a new QResult +func NewQResult(code uint32) *QResult { + return &QResult{code} +} + +// NewQResultSuccess returns a new QResult set as a success +func NewQResultSuccess(code uint32) *QResult { + return NewQResult(uint32(int(code) & ^errorMask)) +} + +// NewQResultError returns a new QResult set as an error +func NewQResultError(code uint32) *QResult { + return NewQResult(uint32(int(code) | errorMask)) +} diff --git a/types/result.go b/types/result.go deleted file mode 100644 index c813ade8..00000000 --- a/types/result.go +++ /dev/null @@ -1,96 +0,0 @@ -package types - -import ( - "fmt" - "strings" -) - -var errorMask = 1 << 31 - -// Result is an implementation of nn::Result. -// Determines the result of an operation. -// If the MSB is set the result is an error, otherwise success -type Result struct { - Code uint32 -} - -// WriteTo writes the Result to the given writable -func (r *Result) WriteTo(writable Writable) { - writable.WritePrimitiveUInt32LE(r.Code) -} - -// ExtractFrom extracts the Result from the given readable -func (r *Result) ExtractFrom(readable Readable) error { - code, err := readable.ReadPrimitiveUInt32LE() - if err != nil { - return fmt.Errorf("Failed to read Result code. %s", err.Error()) - } - - r.Code = code - - return nil -} - -// Copy returns a pointer to a copy of the Result. Requires type assertion when used -func (r *Result) Copy() RVType { - return NewResult(r.Code) -} - -// Equals checks if the input is equal in value to the current instance -func (r *Result) Equals(o RVType) bool { - if _, ok := o.(*Result); !ok { - return false - } - - return r.Code == o.(*Result).Code -} - -// IsSuccess returns true if the Result is a success -func (r *Result) IsSuccess() bool { - return int(r.Code)&errorMask == 0 -} - -// IsError returns true if the Result is a error -func (r *Result) IsError() bool { - return int(r.Code)&errorMask != 0 -} - -// String returns a string representation of the struct -func (r *Result) String() string { - return r.FormatToString(0) -} - -// FormatToString pretty-prints the struct data using the provided indentation level -func (r *Result) FormatToString(indentationLevel int) string { - indentationValues := strings.Repeat("\t", indentationLevel+1) - indentationEnd := strings.Repeat("\t", indentationLevel) - - var b strings.Builder - - b.WriteString("Result{\n") - - if r.IsSuccess() { - b.WriteString(fmt.Sprintf("%scode: %d (success)\n", indentationValues, r.Code)) - } else { - b.WriteString(fmt.Sprintf("%scode: %d (error)\n", indentationValues, r.Code)) - } - - b.WriteString(fmt.Sprintf("%s}", indentationEnd)) - - return b.String() -} - -// NewResult returns a new Result -func NewResult(code uint32) *Result { - return &Result{code} -} - -// NewResultSuccess returns a new Result set as a success -func NewResultSuccess(code uint32) *Result { - return NewResult(uint32(int(code) & ^errorMask)) -} - -// NewResultError returns a new Result set as an error -func NewResultError(code uint32) *Result { - return NewResult(uint32(int(code) | errorMask)) -} From ee2bca04dfdb879b379a67047d79af49dfc45e07 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Mon, 15 Jan 2024 17:55:35 -0500 Subject: [PATCH 117/178] prudp: add events back to PRUDPServer --- prudp_endpoint.go | 3 +++ prudp_server.go | 38 +++++++++++++++++++++++++++++++------- server_interface.go | 1 + 3 files changed, 35 insertions(+), 7 deletions(-) diff --git a/prudp_endpoint.go b/prudp_endpoint.go index 1502e417..e1e4ff01 100644 --- a/prudp_endpoint.go +++ b/prudp_endpoint.go @@ -58,6 +58,9 @@ func (pep *PRUDPEndPoint) emit(name string, packet PRUDPPacketInterface) { go handler(packet) } } + + // * propagate the event up to the PRUDP server + pep.Server.emit(name, packet) } func (pep *PRUDPEndPoint) emitConnectionEnded(connection *PRUDPConnection) { diff --git a/prudp_server.go b/prudp_server.go index 626ea6a0..1176fe9e 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -37,6 +37,7 @@ type PRUDPServer struct { PRUDPv1ConnectionSignatureKey []byte byteStreamSettings *ByteStreamSettings PRUDPV0Settings *PRUDPV0Settings + packetEventHandlers map[string][]func(packet PacketInterface) } // BindPRUDPEndPoint binds a provided PRUDPEndPoint to the server @@ -470,15 +471,38 @@ func (s *PRUDPServer) SetByteStreamSettings(byteStreamSettings *ByteStreamSettin s.byteStreamSettings = byteStreamSettings } +// OnData adds an event handler which is fired when a new DATA packet is received +func (s *PRUDPServer) OnData(handler func(packet PacketInterface)) { + s.on("data", handler) +} + +func (s *PRUDPServer) on(name string, handler func(packet PacketInterface)) { + if _, ok := s.packetEventHandlers[name]; !ok { + s.packetEventHandlers[name] = make([]func(packet PacketInterface), 0) + } + + s.packetEventHandlers[name] = append(s.packetEventHandlers[name], handler) +} + +// emit emits an event to all relevant listeners. These events fire after the PRUDPEndPoint event handlers +func (s *PRUDPServer) emit(name string, packet PRUDPPacketInterface) { + if handlers, ok := s.packetEventHandlers[name]; ok { + for _, handler := range handlers { + go handler(packet) + } + } +} + // NewPRUDPServer will return a new PRUDP server func NewPRUDPServer() *PRUDPServer { return &PRUDPServer{ - Endpoints: NewMutexMap[uint8, *PRUDPEndPoint](), - Connections: NewMutexMap[string, *SocketConnection](), - kerberosKeySize: 32, - FragmentSize: 1300, - pingTimeout: time.Second * 15, - byteStreamSettings: NewByteStreamSettings(), - PRUDPV0Settings: NewPRUDPV0Settings(), + Endpoints: NewMutexMap[uint8, *PRUDPEndPoint](), + Connections: NewMutexMap[string, *SocketConnection](), + kerberosKeySize: 32, + FragmentSize: 1300, + pingTimeout: time.Second * 15, + byteStreamSettings: NewByteStreamSettings(), + PRUDPV0Settings: NewPRUDPV0Settings(), + packetEventHandlers: make(map[string][]func(PacketInterface)), } } diff --git a/server_interface.go b/server_interface.go index bed89aaa..a3ae1a2a 100644 --- a/server_interface.go +++ b/server_interface.go @@ -16,6 +16,7 @@ type ServerInterface interface { NATTraversalProtocolVersion() *LibraryVersion SetDefaultLibraryVersion(version *LibraryVersion) Send(packet PacketInterface) + OnData(handler func(packet PacketInterface)) PasswordFromPID(pid *types.PID) (string, uint32) SetPasswordFromPIDFunction(handler func(pid *types.PID) (string, uint32)) ByteStreamSettings() *ByteStreamSettings From 6697eec1f851ce7e0bedf78bf4bcbe82ae2b04b5 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Mon, 15 Jan 2024 17:56:49 -0500 Subject: [PATCH 118/178] prudp: change PRUDPServer variable name --- prudp_server.go | 240 ++++++++++++++++++++++++------------------------ 1 file changed, 120 insertions(+), 120 deletions(-) diff --git a/prudp_server.go b/prudp_server.go index 1176fe9e..b35fb191 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -41,24 +41,24 @@ type PRUDPServer struct { } // BindPRUDPEndPoint binds a provided PRUDPEndPoint to the server -func (s *PRUDPServer) BindPRUDPEndPoint(endpoint *PRUDPEndPoint) { - if s.Endpoints.Has(endpoint.StreamID) { +func (ps *PRUDPServer) BindPRUDPEndPoint(endpoint *PRUDPEndPoint) { + if ps.Endpoints.Has(endpoint.StreamID) { logger.Warningf("Tried to bind already existing PRUDPEndPoint %d", endpoint.StreamID) return } - endpoint.Server = s - s.Endpoints.Set(endpoint.StreamID, endpoint) + endpoint.Server = ps + ps.Endpoints.Set(endpoint.StreamID, endpoint) } // Listen is an alias of ListenUDP. Implemented to conform to the ServerInterface -func (s *PRUDPServer) Listen(port int) { - s.ListenUDP(port) +func (ps *PRUDPServer) Listen(port int) { + ps.ListenUDP(port) } // ListenUDP starts a PRUDP server on a given port using a UDP server -func (s *PRUDPServer) ListenUDP(port int) { - s.initPRUDPv1ConnectionSignatureKey() +func (ps *PRUDPServer) ListenUDP(port int) { + ps.initPRUDPv1ConnectionSignatureKey() udpAddress, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", port)) if err != nil { @@ -70,18 +70,18 @@ func (s *PRUDPServer) ListenUDP(port int) { panic(err) } - s.udpSocket = socket + ps.udpSocket = socket quit := make(chan struct{}) for i := 0; i < runtime.NumCPU(); i++ { - go s.listenDatagram(quit) + go ps.listenDatagram(quit) } <-quit } -func (s *PRUDPServer) listenDatagram(quit chan struct{}) { +func (ps *PRUDPServer) listenDatagram(quit chan struct{}) { var err error for err == nil { @@ -89,10 +89,10 @@ func (s *PRUDPServer) listenDatagram(quit chan struct{}) { var read int var addr *net.UDPAddr - read, addr, err = s.udpSocket.ReadFromUDP(buffer) + read, addr, err = ps.udpSocket.ReadFromUDP(buffer) packetData := buffer[:read] - err = s.handleSocketMessage(packetData, addr, nil) + err = ps.handleSocketMessage(packetData, addr, nil) } quit <- struct{}{} @@ -101,40 +101,40 @@ func (s *PRUDPServer) listenDatagram(quit chan struct{}) { } // ListenWebSocket starts a PRUDP server on a given port using a WebSocket server -func (s *PRUDPServer) ListenWebSocket(port int) { - s.initPRUDPv1ConnectionSignatureKey() +func (ps *PRUDPServer) ListenWebSocket(port int) { + ps.initPRUDPv1ConnectionSignatureKey() - s.websocketServer = &WebSocketServer{ - prudpServer: s, + ps.websocketServer = &WebSocketServer{ + prudpServer: ps, } - s.websocketServer.listen(port) + ps.websocketServer.listen(port) } // ListenWebSocketSecure starts a PRUDP server on a given port using a secure (TLS) WebSocket server -func (s *PRUDPServer) ListenWebSocketSecure(port int, certFile, keyFile string) { - s.initPRUDPv1ConnectionSignatureKey() +func (ps *PRUDPServer) ListenWebSocketSecure(port int, certFile, keyFile string) { + ps.initPRUDPv1ConnectionSignatureKey() - s.websocketServer = &WebSocketServer{ - prudpServer: s, + ps.websocketServer = &WebSocketServer{ + prudpServer: ps, } - s.websocketServer.listenSecure(port, certFile, keyFile) + ps.websocketServer.listenSecure(port, certFile, keyFile) } -func (s *PRUDPServer) initPRUDPv1ConnectionSignatureKey() { +func (ps *PRUDPServer) initPRUDPv1ConnectionSignatureKey() { // * Ensure the server has a key for PRUDPv1 connection signatures - if len(s.PRUDPv1ConnectionSignatureKey) != 16 { - s.PRUDPv1ConnectionSignatureKey = make([]byte, 16) - _, err := rand.Read(s.PRUDPv1ConnectionSignatureKey) + if len(ps.PRUDPv1ConnectionSignatureKey) != 16 { + ps.PRUDPv1ConnectionSignatureKey = make([]byte, 16) + _, err := rand.Read(ps.PRUDPv1ConnectionSignatureKey) if err != nil { panic(err) } } } -func (s *PRUDPServer) handleSocketMessage(packetData []byte, address net.Addr, webSocketConnection *gws.Conn) error { - readStream := NewByteStreamIn(packetData, s) +func (ps *PRUDPServer) handleSocketMessage(packetData []byte, address net.Addr, webSocketConnection *gws.Conn) error { + readStream := NewByteStreamIn(packetData, ps) var packets []PRUDPPacketInterface @@ -142,7 +142,7 @@ func (s *PRUDPServer) handleSocketMessage(packetData []byte, address net.Addr, w // * with that same type. Also keep reading from the stream // * until no more data is left, to account for multiple // * packets being sent at once - if s.websocketServer != nil && packetData[0] == 0x80 { + if ps.websocketServer != nil && packetData[0] == 0x80 { packets, _ = NewPRUDPPacketsLite(nil, readStream) } else if bytes.Equal(packetData[:2], []byte{0xEA, 0xD0}) { packets, _ = NewPRUDPPacketsV1(nil, readStream) @@ -151,19 +151,19 @@ func (s *PRUDPServer) handleSocketMessage(packetData []byte, address net.Addr, w } for _, packet := range packets { - go s.processPacket(packet, address, webSocketConnection) + go ps.processPacket(packet, address, webSocketConnection) } return nil } -func (s *PRUDPServer) processPacket(packet PRUDPPacketInterface, address net.Addr, webSocketConnection *gws.Conn) { - if !s.Endpoints.Has(packet.DestinationVirtualPortStreamID()) { +func (ps *PRUDPServer) processPacket(packet PRUDPPacketInterface, address net.Addr, webSocketConnection *gws.Conn) { + if !ps.Endpoints.Has(packet.DestinationVirtualPortStreamID()) { logger.Warningf("Client %s trying to connect to unbound PRUDPEndPoint %d", address.String(), packet.DestinationVirtualPortStreamID()) return } - endpoint, ok := s.Endpoints.Get(packet.DestinationVirtualPortStreamID()) + endpoint, ok := ps.Endpoints.Get(packet.DestinationVirtualPortStreamID()) if !ok { logger.Warningf("Client %s trying to connect to unbound PRUDPEndPoint %d", address.String(), packet.DestinationVirtualPortStreamID()) return @@ -201,40 +201,40 @@ func (s *PRUDPServer) processPacket(packet PRUDPPacketInterface, address net.Add } discriminator := address.String() - socket, ok := s.Connections.Get(discriminator) + socket, ok := ps.Connections.Get(discriminator) if !ok { - socket = NewSocketConnection(s, address, webSocketConnection) - s.Connections.Set(discriminator, socket) + socket = NewSocketConnection(ps, address, webSocketConnection) + ps.Connections.Set(discriminator, socket) } endpoint.processPacket(packet, socket) } // Send sends the packet to the packets sender -func (s *PRUDPServer) Send(packet PacketInterface) { +func (ps *PRUDPServer) Send(packet PacketInterface) { if packet, ok := packet.(PRUDPPacketInterface); ok { data := packet.Payload() - fragments := int(len(data) / s.FragmentSize) + fragments := int(len(data) / ps.FragmentSize) var fragmentID uint8 = 1 for i := 0; i <= fragments; i++ { - if len(data) < s.FragmentSize { + if len(data) < ps.FragmentSize { packet.SetPayload(data) packet.setFragmentID(0) } else { - packet.SetPayload(data[:s.FragmentSize]) + packet.SetPayload(data[:ps.FragmentSize]) packet.setFragmentID(fragmentID) - data = data[s.FragmentSize:] + data = data[ps.FragmentSize:] fragmentID++ } - s.sendPacket(packet) + ps.sendPacket(packet) } } } -func (s *PRUDPServer) sendPacket(packet PRUDPPacketInterface) { +func (ps *PRUDPServer) sendPacket(packet PRUDPPacketInterface) { // * PRUDPServer.Send will send fragments as the same packet, // * just with different fields. In order to prevent modifying // * multiple packets at once, due to the same pointer being @@ -288,17 +288,17 @@ func (s *PRUDPServer) sendPacket(packet PRUDPPacketInterface) { slidingWindow.ResendScheduler.AddPacket(packetCopy) } - s.sendRaw(packetCopy.Sender().(*PRUDPConnection).Socket, packetCopy.Bytes()) + ps.sendRaw(packetCopy.Sender().(*PRUDPConnection).Socket, packetCopy.Bytes()) } // sendRaw will send the given socket the provided packet -func (s *PRUDPServer) sendRaw(socket *SocketConnection, data []byte) { +func (ps *PRUDPServer) sendRaw(socket *SocketConnection, data []byte) { // TODO - Should this return the error too? var err error - if address, ok := socket.Address.(*net.UDPAddr); ok && s.udpSocket != nil { - _, err = s.udpSocket.WriteToUDP(data, address) + if address, ok := socket.Address.(*net.UDPAddr); ok && ps.udpSocket != nil { + _, err = ps.udpSocket.WriteToUDP(data, address) } else if socket.WebSocketConnection != nil { err = socket.WebSocketConnection.WriteMessage(gws.OpcodeBinary, data) } @@ -309,27 +309,27 @@ func (s *PRUDPServer) sendRaw(socket *SocketConnection, data []byte) { } // AccessKey returns the servers sandbox access key -func (s *PRUDPServer) AccessKey() string { - return s.accessKey +func (ps *PRUDPServer) AccessKey() string { + return ps.accessKey } // SetAccessKey sets the servers sandbox access key -func (s *PRUDPServer) SetAccessKey(accessKey string) { - s.accessKey = accessKey +func (ps *PRUDPServer) SetAccessKey(accessKey string) { + ps.accessKey = accessKey } // KerberosPassword returns the server kerberos password -func (s *PRUDPServer) KerberosPassword() []byte { - return s.kerberosPassword +func (ps *PRUDPServer) KerberosPassword() []byte { + return ps.kerberosPassword } // SetKerberosPassword sets the server kerberos password -func (s *PRUDPServer) SetKerberosPassword(kerberosPassword []byte) { - s.kerberosPassword = kerberosPassword +func (ps *PRUDPServer) SetKerberosPassword(kerberosPassword []byte) { + ps.kerberosPassword = kerberosPassword } // SetFragmentSize sets the max size for a packets payload -func (s *PRUDPServer) SetFragmentSize(fragmentSize int) { +func (ps *PRUDPServer) SetFragmentSize(fragmentSize int) { // TODO - Derive this value from the MTU // * From the wiki: // * @@ -341,152 +341,152 @@ func (s *PRUDPServer) SetFragmentSize(fragmentSize int) { // * // * Later, the MTU was increased to 1364, and the maximum payload // * size is seems to be 1300 bytes, unless PRUDP v0 is used, in which case it’s 1264 bytes. - s.FragmentSize = fragmentSize + ps.FragmentSize = fragmentSize } // SetKerberosTicketVersion sets the version used when handling kerberos tickets -func (s *PRUDPServer) SetKerberosTicketVersion(kerberosTicketVersion int) { - s.kerberosTicketVersion = kerberosTicketVersion +func (ps *PRUDPServer) SetKerberosTicketVersion(kerberosTicketVersion int) { + ps.kerberosTicketVersion = kerberosTicketVersion } // KerberosKeySize gets the size for the kerberos session key -func (s *PRUDPServer) KerberosKeySize() int { - return s.kerberosKeySize +func (ps *PRUDPServer) KerberosKeySize() int { + return ps.kerberosKeySize } // SetKerberosKeySize sets the size for the kerberos session key -func (s *PRUDPServer) SetKerberosKeySize(kerberosKeySize int) { - s.kerberosKeySize = kerberosKeySize +func (ps *PRUDPServer) SetKerberosKeySize(kerberosKeySize int) { + ps.kerberosKeySize = kerberosKeySize } // LibraryVersion returns the server NEX version -func (s *PRUDPServer) LibraryVersion() *LibraryVersion { - return s.version +func (ps *PRUDPServer) LibraryVersion() *LibraryVersion { + return ps.version } // SetDefaultLibraryVersion sets the default NEX protocol versions -func (s *PRUDPServer) SetDefaultLibraryVersion(version *LibraryVersion) { - s.version = version - s.datastoreProtocolVersion = version.Copy() - s.matchMakingProtocolVersion = version.Copy() - s.rankingProtocolVersion = version.Copy() - s.ranking2ProtocolVersion = version.Copy() - s.messagingProtocolVersion = version.Copy() - s.utilityProtocolVersion = version.Copy() - s.natTraversalProtocolVersion = version.Copy() +func (ps *PRUDPServer) SetDefaultLibraryVersion(version *LibraryVersion) { + ps.version = version + ps.datastoreProtocolVersion = version.Copy() + ps.matchMakingProtocolVersion = version.Copy() + ps.rankingProtocolVersion = version.Copy() + ps.ranking2ProtocolVersion = version.Copy() + ps.messagingProtocolVersion = version.Copy() + ps.utilityProtocolVersion = version.Copy() + ps.natTraversalProtocolVersion = version.Copy() } // DataStoreProtocolVersion returns the servers DataStore protocol version -func (s *PRUDPServer) DataStoreProtocolVersion() *LibraryVersion { - return s.datastoreProtocolVersion +func (ps *PRUDPServer) DataStoreProtocolVersion() *LibraryVersion { + return ps.datastoreProtocolVersion } // SetDataStoreProtocolVersion sets the servers DataStore protocol version -func (s *PRUDPServer) SetDataStoreProtocolVersion(version *LibraryVersion) { - s.datastoreProtocolVersion = version +func (ps *PRUDPServer) SetDataStoreProtocolVersion(version *LibraryVersion) { + ps.datastoreProtocolVersion = version } // MatchMakingProtocolVersion returns the servers MatchMaking protocol version -func (s *PRUDPServer) MatchMakingProtocolVersion() *LibraryVersion { - return s.matchMakingProtocolVersion +func (ps *PRUDPServer) MatchMakingProtocolVersion() *LibraryVersion { + return ps.matchMakingProtocolVersion } // SetMatchMakingProtocolVersion sets the servers MatchMaking protocol version -func (s *PRUDPServer) SetMatchMakingProtocolVersion(version *LibraryVersion) { - s.matchMakingProtocolVersion = version +func (ps *PRUDPServer) SetMatchMakingProtocolVersion(version *LibraryVersion) { + ps.matchMakingProtocolVersion = version } // RankingProtocolVersion returns the servers Ranking protocol version -func (s *PRUDPServer) RankingProtocolVersion() *LibraryVersion { - return s.rankingProtocolVersion +func (ps *PRUDPServer) RankingProtocolVersion() *LibraryVersion { + return ps.rankingProtocolVersion } // SetRankingProtocolVersion sets the servers Ranking protocol version -func (s *PRUDPServer) SetRankingProtocolVersion(version *LibraryVersion) { - s.rankingProtocolVersion = version +func (ps *PRUDPServer) SetRankingProtocolVersion(version *LibraryVersion) { + ps.rankingProtocolVersion = version } // Ranking2ProtocolVersion returns the servers Ranking2 protocol version -func (s *PRUDPServer) Ranking2ProtocolVersion() *LibraryVersion { - return s.ranking2ProtocolVersion +func (ps *PRUDPServer) Ranking2ProtocolVersion() *LibraryVersion { + return ps.ranking2ProtocolVersion } // SetRanking2ProtocolVersion sets the servers Ranking2 protocol version -func (s *PRUDPServer) SetRanking2ProtocolVersion(version *LibraryVersion) { - s.ranking2ProtocolVersion = version +func (ps *PRUDPServer) SetRanking2ProtocolVersion(version *LibraryVersion) { + ps.ranking2ProtocolVersion = version } // MessagingProtocolVersion returns the servers Messaging protocol version -func (s *PRUDPServer) MessagingProtocolVersion() *LibraryVersion { - return s.messagingProtocolVersion +func (ps *PRUDPServer) MessagingProtocolVersion() *LibraryVersion { + return ps.messagingProtocolVersion } // SetMessagingProtocolVersion sets the servers Messaging protocol version -func (s *PRUDPServer) SetMessagingProtocolVersion(version *LibraryVersion) { - s.messagingProtocolVersion = version +func (ps *PRUDPServer) SetMessagingProtocolVersion(version *LibraryVersion) { + ps.messagingProtocolVersion = version } // UtilityProtocolVersion returns the servers Utility protocol version -func (s *PRUDPServer) UtilityProtocolVersion() *LibraryVersion { - return s.utilityProtocolVersion +func (ps *PRUDPServer) UtilityProtocolVersion() *LibraryVersion { + return ps.utilityProtocolVersion } // SetUtilityProtocolVersion sets the servers Utility protocol version -func (s *PRUDPServer) SetUtilityProtocolVersion(version *LibraryVersion) { - s.utilityProtocolVersion = version +func (ps *PRUDPServer) SetUtilityProtocolVersion(version *LibraryVersion) { + ps.utilityProtocolVersion = version } // SetNATTraversalProtocolVersion sets the servers NAT Traversal protocol version -func (s *PRUDPServer) SetNATTraversalProtocolVersion(version *LibraryVersion) { - s.natTraversalProtocolVersion = version +func (ps *PRUDPServer) SetNATTraversalProtocolVersion(version *LibraryVersion) { + ps.natTraversalProtocolVersion = version } // NATTraversalProtocolVersion returns the servers NAT Traversal protocol version -func (s *PRUDPServer) NATTraversalProtocolVersion() *LibraryVersion { - return s.natTraversalProtocolVersion +func (ps *PRUDPServer) NATTraversalProtocolVersion() *LibraryVersion { + return ps.natTraversalProtocolVersion } // PasswordFromPID calls the function set with SetPasswordFromPIDFunction and returns the result -func (s *PRUDPServer) PasswordFromPID(pid *types.PID) (string, uint32) { - if s.passwordFromPIDHandler == nil { +func (ps *PRUDPServer) PasswordFromPID(pid *types.PID) (string, uint32) { + if ps.passwordFromPIDHandler == nil { logger.Errorf("Missing PasswordFromPID handler. Set with SetPasswordFromPIDFunction") return "", Errors.Core.NotImplemented } - return s.passwordFromPIDHandler(pid) + return ps.passwordFromPIDHandler(pid) } // SetPasswordFromPIDFunction sets the function for the auth server to get a NEX password using the PID -func (s *PRUDPServer) SetPasswordFromPIDFunction(handler func(pid *types.PID) (string, uint32)) { - s.passwordFromPIDHandler = handler +func (ps *PRUDPServer) SetPasswordFromPIDFunction(handler func(pid *types.PID) (string, uint32)) { + ps.passwordFromPIDHandler = handler } // ByteStreamSettings returns the settings to be used for ByteStreams -func (s *PRUDPServer) ByteStreamSettings() *ByteStreamSettings { - return s.byteStreamSettings +func (ps *PRUDPServer) ByteStreamSettings() *ByteStreamSettings { + return ps.byteStreamSettings } // SetByteStreamSettings sets the settings to be used for ByteStreams -func (s *PRUDPServer) SetByteStreamSettings(byteStreamSettings *ByteStreamSettings) { - s.byteStreamSettings = byteStreamSettings +func (ps *PRUDPServer) SetByteStreamSettings(byteStreamSettings *ByteStreamSettings) { + ps.byteStreamSettings = byteStreamSettings } // OnData adds an event handler which is fired when a new DATA packet is received -func (s *PRUDPServer) OnData(handler func(packet PacketInterface)) { - s.on("data", handler) +func (ps *PRUDPServer) OnData(handler func(packet PacketInterface)) { + ps.on("data", handler) } -func (s *PRUDPServer) on(name string, handler func(packet PacketInterface)) { - if _, ok := s.packetEventHandlers[name]; !ok { - s.packetEventHandlers[name] = make([]func(packet PacketInterface), 0) +func (ps *PRUDPServer) on(name string, handler func(packet PacketInterface)) { + if _, ok := ps.packetEventHandlers[name]; !ok { + ps.packetEventHandlers[name] = make([]func(packet PacketInterface), 0) } - s.packetEventHandlers[name] = append(s.packetEventHandlers[name], handler) + ps.packetEventHandlers[name] = append(ps.packetEventHandlers[name], handler) } // emit emits an event to all relevant listeners. These events fire after the PRUDPEndPoint event handlers -func (s *PRUDPServer) emit(name string, packet PRUDPPacketInterface) { - if handlers, ok := s.packetEventHandlers[name]; ok { +func (ps *PRUDPServer) emit(name string, packet PRUDPPacketInterface) { + if handlers, ok := ps.packetEventHandlers[name]; ok { for _, handler := range handlers { go handler(packet) } From e6c15b1b3319ffd05bd24502d372cbb5c2e02581 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Mon, 15 Jan 2024 17:57:37 -0500 Subject: [PATCH 119/178] prudp: remove TODO from virtual_port.go --- virtual_port.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/virtual_port.go b/virtual_port.go index 7e38244f..d479515f 100644 --- a/virtual_port.go +++ b/virtual_port.go @@ -1,7 +1,5 @@ package nex -// TODO - Should this be moved to the types module? - // VirtualPort in an implementation of rdv::VirtualPort. // PRUDP will reuse a single physical socket connection for many virtual PRUDP connections. // VirtualPorts are a byte which represents a stream for a virtual PRUDP connection. From cc18168f0e4e331169eddbb5c286db4269cadd3e Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sat, 20 Jan 2024 13:20:01 -0500 Subject: [PATCH 120/178] types: added List.Length receiver --- types/list.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/types/list.go b/types/list.go index 39c26c54..4df9dc86 100644 --- a/types/list.go +++ b/types/list.go @@ -100,6 +100,11 @@ func (l *List[T]) SetFromData(data []T) { l.real = data } +// Length returns the number of elements in the List +func (l *List[T]) Length() int { + return len(l.real) +} + // String returns a string representation of the struct func (l *List[T]) String() string { return fmt.Sprintf("%v", l.real) From a9cf3d73cbf1a97cce26b3f595a318f72ea5a470 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sat, 20 Jan 2024 13:20:48 -0500 Subject: [PATCH 121/178] types: added bitwise receivers to numeric types --- types/primitive_s16.go | 70 ++++++++++++++++++++++++++++++++++++++++++ types/primitive_s32.go | 70 ++++++++++++++++++++++++++++++++++++++++++ types/primitive_s64.go | 70 ++++++++++++++++++++++++++++++++++++++++++ types/primitive_s8.go | 70 ++++++++++++++++++++++++++++++++++++++++++ types/primitive_u16.go | 70 ++++++++++++++++++++++++++++++++++++++++++ types/primitive_u32.go | 70 ++++++++++++++++++++++++++++++++++++++++++ types/primitive_u64.go | 70 ++++++++++++++++++++++++++++++++++++++++++ types/primitive_u8.go | 70 ++++++++++++++++++++++++++++++++++++++++++ 8 files changed, 560 insertions(+) diff --git a/types/primitive_s16.go b/types/primitive_s16.go index 6c79085d..4bf8ebc1 100644 --- a/types/primitive_s16.go +++ b/types/primitive_s16.go @@ -43,6 +43,76 @@ func (s16 *PrimitiveS16) String() string { return fmt.Sprintf("%d", s16.Value) } +// AND runs a bitwise AND operation on the PrimitiveS16 value. Consumes and returns a NEX primitive +func (s16 *PrimitiveS16) AND(other *PrimitiveS16) *PrimitiveS16 { + return NewPrimitiveS16(s16.PAND(other.Value)) +} + +// PAND (Primitive AND) runs a bitwise AND operation on the PrimitiveS16 value. Consumes and returns a Go primitive +func (s16 *PrimitiveS16) PAND(value int16) int16 { + return s16.Value & value +} + +// OR runs a bitwise OR operation on the PrimitiveS16 value. Consumes and returns a NEX primitive +func (s16 *PrimitiveS16) OR(other *PrimitiveS16) *PrimitiveS16 { + return NewPrimitiveS16(s16.POR(other.Value)) +} + +// POR (Primitive OR) runs a bitwise OR operation on the PrimitiveS16 value. Consumes and returns a Go primitive +func (s16 *PrimitiveS16) POR(value int16) int16 { + return s16.Value | value +} + +// XOR runs a bitwise XOR operation on the PrimitiveS16 value. Consumes and returns a NEX primitive +func (s16 *PrimitiveS16) XOR(other *PrimitiveS16) *PrimitiveS16 { + return NewPrimitiveS16(s16.PXOR(other.Value)) +} + +// PXOR (Primitive XOR) runs a bitwise XOR operation on the PrimitiveS16 value. Consumes and returns a Go primitive +func (s16 *PrimitiveS16) PXOR(value int16) int16 { + return s16.Value ^ value +} + +// NOT runs a bitwise NOT operation on the PrimitiveS16 value. Returns a NEX primitive +func (s16 *PrimitiveS16) NOT() *PrimitiveS16 { + return NewPrimitiveS16(s16.PNOT()) +} + +// PNOT (Primitive NOT) runs a bitwise NOT operation on the PrimitiveS16 value. Returns a Go primitive +func (s16 *PrimitiveS16) PNOT() int16 { + return ^s16.Value +} + +// ANDNOT runs a bitwise ANDNOT operation on the PrimitiveS16 value. Consumes and returns a NEX primitive +func (s16 *PrimitiveS16) ANDNOT(other *PrimitiveS16) *PrimitiveS16 { + return NewPrimitiveS16(s16.PANDNOT(other.Value)) +} + +// PANDNOT (Primitive AND-NOT) runs a bitwise AND-NOT operation on the PrimitiveS16 value. Consumes and returns a Go primitive +func (s16 *PrimitiveS16) PANDNOT(value int16) int16 { + return s16.Value &^ value +} + +// LShift runs a left shift operation on the PrimitiveS16 value. Consumes and returns a NEX primitive +func (s16 *PrimitiveS16) LShift(other *PrimitiveS16) *PrimitiveS16 { + return NewPrimitiveS16(s16.PLShift(other.Value)) +} + +// PLShift (Primitive Left Shift) runs a left shift operation on the PrimitiveS16 value. Consumes and returns a Go primitive +func (s16 *PrimitiveS16) PLShift(value int16) int16 { + return s16.Value &^ value +} + +// RShift runs a right shift operation on the PrimitiveS16 value. Consumes and returns a NEX primitive +func (s16 *PrimitiveS16) RShift(other *PrimitiveS16) *PrimitiveS16 { + return NewPrimitiveS16(s16.PRShift(other.Value)) +} + +// PRShift (Primitive Right Shift) runs a right shift operation on the PrimitiveS16 value. Consumes and returns a Go primitive +func (s16 *PrimitiveS16) PRShift(value int16) int16 { + return s16.Value &^ value +} + // NewPrimitiveS16 returns a new PrimitiveS16 func NewPrimitiveS16(i16 int16) *PrimitiveS16 { return &PrimitiveS16{Value: i16} diff --git a/types/primitive_s32.go b/types/primitive_s32.go index b2a2cb54..a0cca294 100644 --- a/types/primitive_s32.go +++ b/types/primitive_s32.go @@ -43,6 +43,76 @@ func (s32 *PrimitiveS32) String() string { return fmt.Sprintf("%d", s32.Value) } +// AND runs a bitwise AND operation on the PrimitiveS32 value. Consumes and returns a NEX primitive +func (s32 *PrimitiveS32) AND(other *PrimitiveS32) *PrimitiveS32 { + return NewPrimitiveS32(s32.PAND(other.Value)) +} + +// PAND (Primitive AND) runs a bitwise AND operation on the PrimitiveS32 value. Consumes and returns a Go primitive +func (s32 *PrimitiveS32) PAND(value int32) int32 { + return s32.Value & value +} + +// OR runs a bitwise OR operation on the PrimitiveS32 value. Consumes and returns a NEX primitive +func (s32 *PrimitiveS32) OR(other *PrimitiveS32) *PrimitiveS32 { + return NewPrimitiveS32(s32.POR(other.Value)) +} + +// POR (Primitive OR) runs a bitwise OR operation on the PrimitiveS32 value. Consumes and returns a Go primitive +func (s32 *PrimitiveS32) POR(value int32) int32 { + return s32.Value | value +} + +// XOR runs a bitwise XOR operation on the PrimitiveS32 value. Consumes and returns a NEX primitive +func (s32 *PrimitiveS32) XOR(other *PrimitiveS32) *PrimitiveS32 { + return NewPrimitiveS32(s32.PXOR(other.Value)) +} + +// PXOR (Primitive XOR) runs a bitwise XOR operation on the PrimitiveS32 value. Consumes and returns a Go primitive +func (s32 *PrimitiveS32) PXOR(value int32) int32 { + return s32.Value ^ value +} + +// NOT runs a bitwise NOT operation on the PrimitiveS32 value. Returns a NEX primitive +func (s32 *PrimitiveS32) NOT() *PrimitiveS32 { + return NewPrimitiveS32(s32.PNOT()) +} + +// PNOT (Primitive NOT) runs a bitwise NOT operation on the PrimitiveS32 value. Returns a Go primitive +func (s32 *PrimitiveS32) PNOT() int32 { + return ^s32.Value +} + +// ANDNOT runs a bitwise ANDNOT operation on the PrimitiveS32 value. Consumes and returns a NEX primitive +func (s32 *PrimitiveS32) ANDNOT(other *PrimitiveS32) *PrimitiveS32 { + return NewPrimitiveS32(s32.PANDNOT(other.Value)) +} + +// PANDNOT (Primitive AND-NOT) runs a bitwise AND-NOT operation on the PrimitiveS32 value. Consumes and returns a Go primitive +func (s32 *PrimitiveS32) PANDNOT(value int32) int32 { + return s32.Value &^ value +} + +// LShift runs a left shift operation on the PrimitiveS32 value. Consumes and returns a NEX primitive +func (s32 *PrimitiveS32) LShift(other *PrimitiveS32) *PrimitiveS32 { + return NewPrimitiveS32(s32.PLShift(other.Value)) +} + +// PLShift (Primitive Left Shift) runs a left shift operation on the PrimitiveS32 value. Consumes and returns a Go primitive +func (s32 *PrimitiveS32) PLShift(value int32) int32 { + return s32.Value &^ value +} + +// RShift runs a right shift operation on the PrimitiveS32 value. Consumes and returns a NEX primitive +func (s32 *PrimitiveS32) RShift(other *PrimitiveS32) *PrimitiveS32 { + return NewPrimitiveS32(s32.PRShift(other.Value)) +} + +// PRShift (Primitive Right Shift) runs a right shift operation on the PrimitiveS32 value. Consumes and returns a Go primitive +func (s32 *PrimitiveS32) PRShift(value int32) int32 { + return s32.Value &^ value +} + // NewPrimitiveS32 returns a new PrimitiveS32 func NewPrimitiveS32(i32 int32) *PrimitiveS32 { return &PrimitiveS32{Value: i32} diff --git a/types/primitive_s64.go b/types/primitive_s64.go index c97e7d12..946cf9f0 100644 --- a/types/primitive_s64.go +++ b/types/primitive_s64.go @@ -43,6 +43,76 @@ func (s64 *PrimitiveS64) String() string { return fmt.Sprintf("%d", s64.Value) } +// AND runs a bitwise AND operation on the PrimitiveS64 value. Consumes and returns a NEX primitive +func (s64 *PrimitiveS64) AND(other *PrimitiveS64) *PrimitiveS64 { + return NewPrimitiveS64(s64.PAND(other.Value)) +} + +// PAND (Primitive AND) runs a bitwise AND operation on the PrimitiveS64 value. Consumes and returns a Go primitive +func (s64 *PrimitiveS64) PAND(value int64) int64 { + return s64.Value & value +} + +// OR runs a bitwise OR operation on the PrimitiveS64 value. Consumes and returns a NEX primitive +func (s64 *PrimitiveS64) OR(other *PrimitiveS64) *PrimitiveS64 { + return NewPrimitiveS64(s64.POR(other.Value)) +} + +// POR (Primitive OR) runs a bitwise OR operation on the PrimitiveS64 value. Consumes and returns a Go primitive +func (s64 *PrimitiveS64) POR(value int64) int64 { + return s64.Value | value +} + +// XOR runs a bitwise XOR operation on the PrimitiveS64 value. Consumes and returns a NEX primitive +func (s64 *PrimitiveS64) XOR(other *PrimitiveS64) *PrimitiveS64 { + return NewPrimitiveS64(s64.PXOR(other.Value)) +} + +// PXOR (Primitive XOR) runs a bitwise XOR operation on the PrimitiveS64 value. Consumes and returns a Go primitive +func (s64 *PrimitiveS64) PXOR(value int64) int64 { + return s64.Value ^ value +} + +// NOT runs a bitwise NOT operation on the PrimitiveS64 value. Returns a NEX primitive +func (s64 *PrimitiveS64) NOT() *PrimitiveS64 { + return NewPrimitiveS64(s64.PNOT()) +} + +// PNOT (Primitive NOT) runs a bitwise NOT operation on the PrimitiveS64 value. Returns a Go primitive +func (s64 *PrimitiveS64) PNOT() int64 { + return ^s64.Value +} + +// ANDNOT runs a bitwise ANDNOT operation on the PrimitiveS64 value. Consumes and returns a NEX primitive +func (s64 *PrimitiveS64) ANDNOT(other *PrimitiveS64) *PrimitiveS64 { + return NewPrimitiveS64(s64.PANDNOT(other.Value)) +} + +// PANDNOT (Primitive AND-NOT) runs a bitwise AND-NOT operation on the PrimitiveS64 value. Consumes and returns a Go primitive +func (s64 *PrimitiveS64) PANDNOT(value int64) int64 { + return s64.Value &^ value +} + +// LShift runs a left shift operation on the PrimitiveS64 value. Consumes and returns a NEX primitive +func (s64 *PrimitiveS64) LShift(other *PrimitiveS64) *PrimitiveS64 { + return NewPrimitiveS64(s64.PLShift(other.Value)) +} + +// PLShift (Primitive Left Shift) runs a left shift operation on the PrimitiveS64 value. Consumes and returns a Go primitive +func (s64 *PrimitiveS64) PLShift(value int64) int64 { + return s64.Value &^ value +} + +// RShift runs a right shift operation on the PrimitiveS64 value. Consumes and returns a NEX primitive +func (s64 *PrimitiveS64) RShift(other *PrimitiveS64) *PrimitiveS64 { + return NewPrimitiveS64(s64.PRShift(other.Value)) +} + +// PRShift (Primitive Right Shift) runs a right shift operation on the PrimitiveS64 value. Consumes and returns a Go primitive +func (s64 *PrimitiveS64) PRShift(value int64) int64 { + return s64.Value &^ value +} + // NewPrimitiveS64 returns a new PrimitiveS64 func NewPrimitiveS64(i64 int64) *PrimitiveS64 { return &PrimitiveS64{Value: i64} diff --git a/types/primitive_s8.go b/types/primitive_s8.go index 3b066c79..2328fed0 100644 --- a/types/primitive_s8.go +++ b/types/primitive_s8.go @@ -43,6 +43,76 @@ func (s8 *PrimitiveS8) String() string { return fmt.Sprintf("%d", s8.Value) } +// AND runs a bitwise AND operation on the PrimitiveS8 value. Consumes and returns a NEX primitive +func (s8 *PrimitiveS8) AND(other *PrimitiveS8) *PrimitiveS8 { + return NewPrimitiveS8(s8.PAND(other.Value)) +} + +// PAND (Primitive AND) runs a bitwise AND operation on the PrimitiveS8 value. Consumes and returns a Go primitive +func (s8 *PrimitiveS8) PAND(value int8) int8 { + return s8.Value & value +} + +// OR runs a bitwise OR operation on the PrimitiveS8 value. Consumes and returns a NEX primitive +func (s8 *PrimitiveS8) OR(other *PrimitiveS8) *PrimitiveS8 { + return NewPrimitiveS8(s8.POR(other.Value)) +} + +// POR (Primitive OR) runs a bitwise OR operation on the PrimitiveS8 value. Consumes and returns a Go primitive +func (s8 *PrimitiveS8) POR(value int8) int8 { + return s8.Value | value +} + +// XOR runs a bitwise XOR operation on the PrimitiveS8 value. Consumes and returns a NEX primitive +func (s8 *PrimitiveS8) XOR(other *PrimitiveS8) *PrimitiveS8 { + return NewPrimitiveS8(s8.PXOR(other.Value)) +} + +// PXOR (Primitive XOR) runs a bitwise XOR operation on the PrimitiveS8 value. Consumes and returns a Go primitive +func (s8 *PrimitiveS8) PXOR(value int8) int8 { + return s8.Value ^ value +} + +// NOT runs a bitwise NOT operation on the PrimitiveS8 value. Returns a NEX primitive +func (s8 *PrimitiveS8) NOT() *PrimitiveS8 { + return NewPrimitiveS8(s8.PNOT()) +} + +// PNOT (Primitive NOT) runs a bitwise NOT operation on the PrimitiveS8 value. Returns a Go primitive +func (s8 *PrimitiveS8) PNOT() int8 { + return ^s8.Value +} + +// ANDNOT runs a bitwise ANDNOT operation on the PrimitiveS8 value. Consumes and returns a NEX primitive +func (s8 *PrimitiveS8) ANDNOT(other *PrimitiveS8) *PrimitiveS8 { + return NewPrimitiveS8(s8.PANDNOT(other.Value)) +} + +// PANDNOT (Primitive AND-NOT) runs a bitwise AND-NOT operation on the PrimitiveS8 value. Consumes and returns a Go primitive +func (s8 *PrimitiveS8) PANDNOT(value int8) int8 { + return s8.Value &^ value +} + +// LShift runs a left shift operation on the PrimitiveS8 value. Consumes and returns a NEX primitive +func (s8 *PrimitiveS8) LShift(other *PrimitiveS8) *PrimitiveS8 { + return NewPrimitiveS8(s8.PLShift(other.Value)) +} + +// PLShift (Primitive Left Shift) runs a left shift operation on the PrimitiveS8 value. Consumes and returns a Go primitive +func (s8 *PrimitiveS8) PLShift(value int8) int8 { + return s8.Value &^ value +} + +// RShift runs a right shift operation on the PrimitiveS8 value. Consumes and returns a NEX primitive +func (s8 *PrimitiveS8) RShift(other *PrimitiveS8) *PrimitiveS8 { + return NewPrimitiveS8(s8.PRShift(other.Value)) +} + +// PRShift (Primitive Right Shift) runs a right shift operation on the PrimitiveS8 value. Consumes and returns a Go primitive +func (s8 *PrimitiveS8) PRShift(value int8) int8 { + return s8.Value &^ value +} + // NewPrimitiveS8 returns a new PrimitiveS8 func NewPrimitiveS8(i8 int8) *PrimitiveS8 { return &PrimitiveS8{Value: i8} diff --git a/types/primitive_u16.go b/types/primitive_u16.go index a83f78eb..7b174ddf 100644 --- a/types/primitive_u16.go +++ b/types/primitive_u16.go @@ -43,6 +43,76 @@ func (u16 *PrimitiveU16) String() string { return fmt.Sprintf("%d", u16.Value) } +// AND runs a bitwise AND operation on the PrimitiveU16 value. Consumes and returns a NEX primitive +func (u16 *PrimitiveU16) AND(other *PrimitiveU16) *PrimitiveU16 { + return NewPrimitiveU16(u16.PAND(other.Value)) +} + +// PAND (Primitive AND) runs a bitwise AND operation on the PrimitiveU16 value. Consumes and returns a Go primitive +func (u16 *PrimitiveU16) PAND(value uint16) uint16 { + return u16.Value & value +} + +// OR runs a bitwise OR operation on the PrimitiveU16 value. Consumes and returns a NEX primitive +func (u16 *PrimitiveU16) OR(other *PrimitiveU16) *PrimitiveU16 { + return NewPrimitiveU16(u16.POR(other.Value)) +} + +// POR (Primitive OR) runs a bitwise OR operation on the PrimitiveU16 value. Consumes and returns a Go primitive +func (u16 *PrimitiveU16) POR(value uint16) uint16 { + return u16.Value | value +} + +// XOR runs a bitwise XOR operation on the PrimitiveU16 value. Consumes and returns a NEX primitive +func (u16 *PrimitiveU16) XOR(other *PrimitiveU16) *PrimitiveU16 { + return NewPrimitiveU16(u16.PXOR(other.Value)) +} + +// PXOR (Primitive XOR) runs a bitwise XOR operation on the PrimitiveU16 value. Consumes and returns a Go primitive +func (u16 *PrimitiveU16) PXOR(value uint16) uint16 { + return u16.Value ^ value +} + +// NOT runs a bitwise NOT operation on the PrimitiveU16 value. Returns a NEX primitive +func (u16 *PrimitiveU16) NOT() *PrimitiveU16 { + return NewPrimitiveU16(u16.PNOT()) +} + +// PNOT (Primitive NOT) runs a bitwise NOT operation on the PrimitiveU16 value. Returns a Go primitive +func (u16 *PrimitiveU16) PNOT() uint16 { + return ^u16.Value +} + +// ANDNOT runs a bitwise ANDNOT operation on the PrimitiveU16 value. Consumes and returns a NEX primitive +func (u16 *PrimitiveU16) ANDNOT(other *PrimitiveU16) *PrimitiveU16 { + return NewPrimitiveU16(u16.PANDNOT(other.Value)) +} + +// PANDNOT (Primitive AND-NOT) runs a bitwise AND-NOT operation on the PrimitiveU16 value. Consumes and returns a Go primitive +func (u16 *PrimitiveU16) PANDNOT(value uint16) uint16 { + return u16.Value &^ value +} + +// LShift runs a left shift operation on the PrimitiveU16 value. Consumes and returns a NEX primitive +func (u16 *PrimitiveU16) LShift(other *PrimitiveU16) *PrimitiveU16 { + return NewPrimitiveU16(u16.PLShift(other.Value)) +} + +// PLShift (Primitive Left Shift) runs a left shift operation on the PrimitiveU16 value. Consumes and returns a Go primitive +func (u16 *PrimitiveU16) PLShift(value uint16) uint16 { + return u16.Value &^ value +} + +// RShift runs a right shift operation on the PrimitiveU16 value. Consumes and returns a NEX primitive +func (u16 *PrimitiveU16) RShift(other *PrimitiveU16) *PrimitiveU16 { + return NewPrimitiveU16(u16.PRShift(other.Value)) +} + +// PRShift (Primitive Right Shift) runs a right shift operation on the PrimitiveU16 value. Consumes and returns a Go primitive +func (u16 *PrimitiveU16) PRShift(value uint16) uint16 { + return u16.Value &^ value +} + // NewPrimitiveU16 returns a new PrimitiveU16 func NewPrimitiveU16(ui16 uint16) *PrimitiveU16 { return &PrimitiveU16{Value: ui16} diff --git a/types/primitive_u32.go b/types/primitive_u32.go index e98a7702..744ad4ac 100644 --- a/types/primitive_u32.go +++ b/types/primitive_u32.go @@ -43,6 +43,76 @@ func (u32 *PrimitiveU32) String() string { return fmt.Sprintf("%d", u32.Value) } +// AND runs a bitwise AND operation on the PrimitiveU32 value. Consumes and returns a NEX primitive +func (u32 *PrimitiveU32) AND(other *PrimitiveU32) *PrimitiveU32 { + return NewPrimitiveU32(u32.PAND(other.Value)) +} + +// PAND (Primitive AND) runs a bitwise AND operation on the PrimitiveU32 value. Consumes and returns a Go primitive +func (u32 *PrimitiveU32) PAND(value uint32) uint32 { + return u32.Value & value +} + +// OR runs a bitwise OR operation on the PrimitiveU32 value. Consumes and returns a NEX primitive +func (u32 *PrimitiveU32) OR(other *PrimitiveU32) *PrimitiveU32 { + return NewPrimitiveU32(u32.POR(other.Value)) +} + +// POR (Primitive OR) runs a bitwise OR operation on the PrimitiveU32 value. Consumes and returns a Go primitive +func (u32 *PrimitiveU32) POR(value uint32) uint32 { + return u32.Value | value +} + +// XOR runs a bitwise XOR operation on the PrimitiveU32 value. Consumes and returns a NEX primitive +func (u32 *PrimitiveU32) XOR(other *PrimitiveU32) *PrimitiveU32 { + return NewPrimitiveU32(u32.PXOR(other.Value)) +} + +// PXOR (Primitive XOR) runs a bitwise XOR operation on the PrimitiveU32 value. Consumes and returns a Go primitive +func (u32 *PrimitiveU32) PXOR(value uint32) uint32 { + return u32.Value ^ value +} + +// NOT runs a bitwise NOT operation on the PrimitiveU32 value. Returns a NEX primitive +func (u32 *PrimitiveU32) NOT() *PrimitiveU32 { + return NewPrimitiveU32(u32.PNOT()) +} + +// PNOT (Primitive NOT) runs a bitwise NOT operation on the PrimitiveU32 value. Returns a Go primitive +func (u32 *PrimitiveU32) PNOT() uint32 { + return ^u32.Value +} + +// ANDNOT runs a bitwise ANDNOT operation on the PrimitiveU32 value. Consumes and returns a NEX primitive +func (u32 *PrimitiveU32) ANDNOT(other *PrimitiveU32) *PrimitiveU32 { + return NewPrimitiveU32(u32.PANDNOT(other.Value)) +} + +// PANDNOT (Primitive AND-NOT) runs a bitwise AND-NOT operation on the PrimitiveU32 value. Consumes and returns a Go primitive +func (u32 *PrimitiveU32) PANDNOT(value uint32) uint32 { + return u32.Value &^ value +} + +// LShift runs a left shift operation on the PrimitiveU32 value. Consumes and returns a NEX primitive +func (u32 *PrimitiveU32) LShift(other *PrimitiveU32) *PrimitiveU32 { + return NewPrimitiveU32(u32.PLShift(other.Value)) +} + +// PLShift (Primitive Left Shift) runs a left shift operation on the PrimitiveU32 value. Consumes and returns a Go primitive +func (u32 *PrimitiveU32) PLShift(value uint32) uint32 { + return u32.Value &^ value +} + +// RShift runs a right shift operation on the PrimitiveU32 value. Consumes and returns a NEX primitive +func (u32 *PrimitiveU32) RShift(other *PrimitiveU32) *PrimitiveU32 { + return NewPrimitiveU32(u32.PRShift(other.Value)) +} + +// PRShift (Primitive Right Shift) runs a right shift operation on the PrimitiveU32 value. Consumes and returns a Go primitive +func (u32 *PrimitiveU32) PRShift(value uint32) uint32 { + return u32.Value &^ value +} + // NewPrimitiveU32 returns a new PrimitiveU32 func NewPrimitiveU32(ui32 uint32) *PrimitiveU32 { return &PrimitiveU32{Value: ui32} diff --git a/types/primitive_u64.go b/types/primitive_u64.go index f2d4baca..f960f351 100644 --- a/types/primitive_u64.go +++ b/types/primitive_u64.go @@ -43,6 +43,76 @@ func (u64 *PrimitiveU64) String() string { return fmt.Sprintf("%d", u64.Value) } +// AND runs a bitwise AND operation on the PrimitiveU64 value. Consumes and returns a NEX primitive +func (u64 *PrimitiveU64) AND(other *PrimitiveU64) *PrimitiveU64 { + return NewPrimitiveU64(u64.PAND(other.Value)) +} + +// PAND (Primitive AND) runs a bitwise AND operation on the PrimitiveU64 value. Consumes and returns a Go primitive +func (u64 *PrimitiveU64) PAND(value uint64) uint64 { + return u64.Value & value +} + +// OR runs a bitwise OR operation on the PrimitiveU64 value. Consumes and returns a NEX primitive +func (u64 *PrimitiveU64) OR(other *PrimitiveU64) *PrimitiveU64 { + return NewPrimitiveU64(u64.POR(other.Value)) +} + +// POR (Primitive OR) runs a bitwise OR operation on the PrimitiveU64 value. Consumes and returns a Go primitive +func (u64 *PrimitiveU64) POR(value uint64) uint64 { + return u64.Value | value +} + +// XOR runs a bitwise XOR operation on the PrimitiveU64 value. Consumes and returns a NEX primitive +func (u64 *PrimitiveU64) XOR(other *PrimitiveU64) *PrimitiveU64 { + return NewPrimitiveU64(u64.PXOR(other.Value)) +} + +// PXOR (Primitive XOR) runs a bitwise XOR operation on the PrimitiveU64 value. Consumes and returns a Go primitive +func (u64 *PrimitiveU64) PXOR(value uint64) uint64 { + return u64.Value ^ value +} + +// NOT runs a bitwise NOT operation on the PrimitiveU64 value. Returns a NEX primitive +func (u64 *PrimitiveU64) NOT() *PrimitiveU64 { + return NewPrimitiveU64(u64.PNOT()) +} + +// PNOT (Primitive NOT) runs a bitwise NOT operation on the PrimitiveU64 value. Returns a Go primitive +func (u64 *PrimitiveU64) PNOT() uint64 { + return ^u64.Value +} + +// ANDNOT runs a bitwise ANDNOT operation on the PrimitiveU64 value. Consumes and returns a NEX primitive +func (u64 *PrimitiveU64) ANDNOT(other *PrimitiveU64) *PrimitiveU64 { + return NewPrimitiveU64(u64.PANDNOT(other.Value)) +} + +// PANDNOT (Primitive AND-NOT) runs a bitwise AND-NOT operation on the PrimitiveU64 value. Consumes and returns a Go primitive +func (u64 *PrimitiveU64) PANDNOT(value uint64) uint64 { + return u64.Value &^ value +} + +// LShift runs a left shift operation on the PrimitiveU64 value. Consumes and returns a NEX primitive +func (u64 *PrimitiveU64) LShift(other *PrimitiveU64) *PrimitiveU64 { + return NewPrimitiveU64(u64.PLShift(other.Value)) +} + +// PLShift (Primitive Left Shift) runs a left shift operation on the PrimitiveU64 value. Consumes and returns a Go primitive +func (u64 *PrimitiveU64) PLShift(value uint64) uint64 { + return u64.Value &^ value +} + +// RShift runs a right shift operation on the PrimitiveU64 value. Consumes and returns a NEX primitive +func (u64 *PrimitiveU64) RShift(other *PrimitiveU64) *PrimitiveU64 { + return NewPrimitiveU64(u64.PRShift(other.Value)) +} + +// PRShift (Primitive Right Shift) runs a right shift operation on the PrimitiveU64 value. Consumes and returns a Go primitive +func (u64 *PrimitiveU64) PRShift(value uint64) uint64 { + return u64.Value &^ value +} + // NewPrimitiveU64 returns a new PrimitiveU64 func NewPrimitiveU64(ui64 uint64) *PrimitiveU64 { return &PrimitiveU64{Value: ui64} diff --git a/types/primitive_u8.go b/types/primitive_u8.go index ccccabc2..78c3c793 100644 --- a/types/primitive_u8.go +++ b/types/primitive_u8.go @@ -43,6 +43,76 @@ func (u8 *PrimitiveU8) String() string { return fmt.Sprintf("%d", u8.Value) } +// AND runs a bitwise AND operation on the PrimitiveU8 value. Consumes and returns a NEX primitive +func (u8 *PrimitiveU8) AND(other *PrimitiveU8) *PrimitiveU8 { + return NewPrimitiveU8(u8.PAND(other.Value)) +} + +// PAND (Primitive AND) runs a bitwise AND operation on the PrimitiveU8 value. Consumes and returns a Go primitive +func (u8 *PrimitiveU8) PAND(value uint8) uint8 { + return u8.Value & value +} + +// OR runs a bitwise OR operation on the PrimitiveU8 value. Consumes and returns a NEX primitive +func (u8 *PrimitiveU8) OR(other *PrimitiveU8) *PrimitiveU8 { + return NewPrimitiveU8(u8.POR(other.Value)) +} + +// POR (Primitive OR) runs a bitwise OR operation on the PrimitiveU8 value. Consumes and returns a Go primitive +func (u8 *PrimitiveU8) POR(value uint8) uint8 { + return u8.Value | value +} + +// XOR runs a bitwise XOR operation on the PrimitiveU8 value. Consumes and returns a NEX primitive +func (u8 *PrimitiveU8) XOR(other *PrimitiveU8) *PrimitiveU8 { + return NewPrimitiveU8(u8.PXOR(other.Value)) +} + +// PXOR (Primitive XOR) runs a bitwise XOR operation on the PrimitiveU8 value. Consumes and returns a Go primitive +func (u8 *PrimitiveU8) PXOR(value uint8) uint8 { + return u8.Value ^ value +} + +// NOT runs a bitwise NOT operation on the PrimitiveU8 value. Returns a NEX primitive +func (u8 *PrimitiveU8) NOT() *PrimitiveU8 { + return NewPrimitiveU8(u8.PNOT()) +} + +// PNOT (Primitive NOT) runs a bitwise NOT operation on the PrimitiveU8 value. Returns a Go primitive +func (u8 *PrimitiveU8) PNOT() uint8 { + return ^u8.Value +} + +// ANDNOT runs a bitwise ANDNOT operation on the PrimitiveU8 value. Consumes and returns a NEX primitive +func (u8 *PrimitiveU8) ANDNOT(other *PrimitiveU8) *PrimitiveU8 { + return NewPrimitiveU8(u8.PANDNOT(other.Value)) +} + +// PANDNOT (Primitive AND-NOT) runs a bitwise AND-NOT operation on the PrimitiveU8 value. Consumes and returns a Go primitive +func (u8 *PrimitiveU8) PANDNOT(value uint8) uint8 { + return u8.Value &^ value +} + +// LShift runs a left shift operation on the PrimitiveU8 value. Consumes and returns a NEX primitive +func (u8 *PrimitiveU8) LShift(other *PrimitiveU8) *PrimitiveU8 { + return NewPrimitiveU8(u8.PLShift(other.Value)) +} + +// PLShift (Primitive Left Shift) runs a left shift operation on the PrimitiveU8 value. Consumes and returns a Go primitive +func (u8 *PrimitiveU8) PLShift(value uint8) uint8 { + return u8.Value &^ value +} + +// RShift runs a right shift operation on the PrimitiveU8 value. Consumes and returns a NEX primitive +func (u8 *PrimitiveU8) RShift(other *PrimitiveU8) *PrimitiveU8 { + return NewPrimitiveU8(u8.PRShift(other.Value)) +} + +// PRShift (Primitive Right Shift) runs a right shift operation on the PrimitiveU8 value. Consumes and returns a Go primitive +func (u8 *PrimitiveU8) PRShift(value uint8) uint8 { + return u8.Value &^ value +} + // NewPrimitiveU8 returns a new PrimitiveU8 func NewPrimitiveU8(ui8 uint8) *PrimitiveU8 { return &PrimitiveU8{Value: ui8} From af6a183ecb057673ce12ef869ae746e1a57baeb7 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sat, 20 Jan 2024 13:44:09 -0500 Subject: [PATCH 122/178] types: added List.Each receiver --- types/list.go | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/types/list.go b/types/list.go index 4df9dc86..9f18249a 100644 --- a/types/list.go +++ b/types/list.go @@ -105,6 +105,19 @@ func (l *List[T]) Length() int { return len(l.real) } +// Each runs a callback function for every element in the List +// The List should not be modified inside the callback function +// Returns true if the loop was terminated early +func (l *List[T]) Each(callback func(i int, value T) bool) bool { + for i, value := range l.real { + if callback(i, value) { + return true + } + } + + return false +} + // String returns a string representation of the struct func (l *List[T]) String() string { return fmt.Sprintf("%v", l.real) From bfc793b039ea0e52b29138bcd75c597f7cb2e709 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sun, 21 Jan 2024 21:06:22 -0500 Subject: [PATCH 123/178] types: added List.Contains receiver --- types/list.go | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/types/list.go b/types/list.go index 9f18249a..0768c262 100644 --- a/types/list.go +++ b/types/list.go @@ -118,6 +118,23 @@ func (l *List[T]) Each(callback func(i int, value T) bool) bool { return false } +// Contains checks if the provided value exists in the List +func (l *List[T]) Contains(checkValue T) bool { + contains := false + + l.Each(func(_ int, value T) bool { + if value.Equals(checkValue) { + contains = true + + return true + } + + return false + }) + + return contains +} + // String returns a string representation of the struct func (l *List[T]) String() string { return fmt.Sprintf("%v", l.real) From af072671da16ba4aee054485be080aba75c54018 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sun, 21 Jan 2024 21:15:13 -0500 Subject: [PATCH 124/178] types: added List.Slice receiver --- types/list.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/types/list.go b/types/list.go index 0768c262..1fa8be50 100644 --- a/types/list.go +++ b/types/list.go @@ -81,6 +81,11 @@ func (l *List[T]) Equals(o RVType) bool { return true } +// Slice returns the real underlying slice for the List +func (l *List[T]) Slice() []T { + return l.real +} + // Append appends an element to the List internal slice func (l *List[T]) Append(value T) { l.real = append(l.real, value) From e8b45e95add79c3805fc45a3d4c64115577b61bc Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sun, 21 Jan 2024 21:15:57 -0500 Subject: [PATCH 125/178] prudp: removed unused arguments from FindConnectionByID and FindConnectionByPID --- prudp_endpoint.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/prudp_endpoint.go b/prudp_endpoint.go index e1e4ff01..67af48c1 100644 --- a/prudp_endpoint.go +++ b/prudp_endpoint.go @@ -561,7 +561,7 @@ func (pep *PRUDPEndPoint) sendPing(connection *PRUDPConnection) { } // FindConnectionByID returns the PRUDP client connected with the given connection ID -func (pep *PRUDPEndPoint) FindConnectionByID(serverPort, serverStreamType uint8, connectedID uint32) *PRUDPConnection { +func (pep *PRUDPEndPoint) FindConnectionByID(connectedID uint32) *PRUDPConnection { var connection *PRUDPConnection pep.Connections.Each(func(discriminator string, pc *PRUDPConnection) bool { @@ -577,7 +577,7 @@ func (pep *PRUDPEndPoint) FindConnectionByID(serverPort, serverStreamType uint8, } // FindConnectionByPID returns the PRUDP client connected with the given PID -func (pep *PRUDPEndPoint) FindConnectionByPID(serverPort, serverStreamType uint8, pid uint64) *PRUDPConnection { +func (pep *PRUDPEndPoint) FindConnectionByPID(pid uint64) *PRUDPConnection { var connection *PRUDPConnection pep.Connections.Each(func(discriminator string, pc *PRUDPConnection) bool { From 6006842ed2ce9c563398e05e4ad2cbbb128b58be Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Mon, 22 Jan 2024 10:43:38 -0500 Subject: [PATCH 126/178] prudp: added StationURLs to PRUDPConnection --- prudp_connection.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/prudp_connection.go b/prudp_connection.go index f0db7fc0..bd576df1 100644 --- a/prudp_connection.go +++ b/prudp_connection.go @@ -32,6 +32,7 @@ type PRUDPConnection struct { outgoingPingSequenceIDCounter *Counter[uint16] heartbeatTimer *time.Timer pingKickTimer *time.Timer + StationURLs *types.List[*types.StationURL] } // Server returns the PRUDP server the connections socket is connected to @@ -208,11 +209,16 @@ func (pc *PRUDPConnection) stopHeartbeatTimers() { // NewPRUDPConnection creates a new PRUDPConnection for a given socket func NewPRUDPConnection(socket *SocketConnection) *PRUDPConnection { - return &PRUDPConnection{ + pc := &PRUDPConnection{ Socket: socket, pid: types.NewPID(0), slidingWindows: NewMutexMap[uint8, *SlidingWindow](), outgoingUnreliableSequenceIDCounter: NewCounter[uint16](1), outgoingPingSequenceIDCounter: NewCounter[uint16](0), + StationURLs: types.NewList[*types.StationURL](), } + + pc.StationURLs.Type = types.NewStationURL("") + + return pc } From af9a57eca0de8916328228bba67724343b106784 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Mon, 22 Jan 2024 11:00:22 -0500 Subject: [PATCH 127/178] types: added List.DeleteIndex and List.Remove receivers --- types/list.go | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/types/list.go b/types/list.go index 1fa8be50..a892e565 100644 --- a/types/list.go +++ b/types/list.go @@ -100,6 +100,27 @@ func (l *List[T]) Get(index int) (T, error) { return l.real[index], nil } +// DeleteIndex deletes an element at the given index. Returns an error if the index is OOB +func (l *List[T]) DeleteIndex(index int) error { + if index < 0 || index >= len(l.real) { + return errors.New("Index out of bounds") + } + + l.real = append(l.real[:index], l.real[index+1:]...) + + return nil +} + +// Remove removes the first occurance of the input from the List. Returns an error if the index is OOB +func (l *List[T]) Remove(check T) { + for i, value := range l.real { + if value.Equals(check) { + l.DeleteIndex(i) + return + } + } +} + // SetFromData sets the List's internal slice to the input data func (l *List[T]) SetFromData(data []T) { l.real = data From 00d12db5832913a18907c44704427a56c5889b1a Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Mon, 22 Jan 2024 12:38:51 -0500 Subject: [PATCH 128/178] types: added List.SetIndex receiver --- types/list.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/types/list.go b/types/list.go index a892e565..56a303cd 100644 --- a/types/list.go +++ b/types/list.go @@ -100,6 +100,11 @@ func (l *List[T]) Get(index int) (T, error) { return l.real[index], nil } +// SetIndex sets a value in the List at the given index +func (l *List[T]) SetIndex(index int, value T) { + l.real[index] = value +} + // DeleteIndex deletes an element at the given index. Returns an error if the index is OOB func (l *List[T]) DeleteIndex(index int) error { if index < 0 || index >= len(l.real) { From 64b26b3eedd7078d4edb9926136f9c0cd6920868 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Wed, 24 Jan 2024 12:42:11 -0500 Subject: [PATCH 129/178] prudp: rework Kerberos ticket generation to use user accounts properly --- account.go | 27 +++++++++++++++++++++++ hpp_packet.go | 9 ++++---- hpp_server.go | 18 ++-------------- kerberos.go | 2 +- prudp_endpoint.go | 19 +++++++++++++---- prudp_server.go | 36 ++++--------------------------- server_interface.go | 4 ---- test/auth.go | 21 ++++++++++-------- test/generate_ticket.go | 13 ++++++------ test/hpp.go | 3 ++- test/main.go | 47 ++++++++++++++++++++++++++++++++++++++++- test/secure.go | 8 ++++--- 12 files changed, 126 insertions(+), 81 deletions(-) create mode 100644 account.go diff --git a/account.go b/account.go new file mode 100644 index 00000000..ab27101a --- /dev/null +++ b/account.go @@ -0,0 +1,27 @@ +package nex + +import "github.com/PretendoNetwork/nex-go/types" + +// Account represents a game server account. +// +// Game server accounts are separate from other accounts, like Uplay, Nintendo Accounts and NNIDs. +// These exist only on the game server. Account passwords are used as part of the servers Kerberos +// authentication. There are also a collection of non-user, special, accounts. These include a +// guest account, an account which represents the authentication server, and one which represents +// the secure server. See https://nintendo-wiki.pretendo.network/docs/nex/kerberos for more information. +type Account struct { + PID *types.PID // * The PID of the account. PIDs are unique IDs per account. NEX PIDs start at 1800000000 and decrement with each new account. + Username string // * The username for the account. For NEX user accounts this is the same as the accounts PID. + Password string // * The password for the account. For NEX accounts this is always 16 characters long using seemingly any ASCII character +} + +// NewAccount returns a new instance of Account. +// This does not register an account, only creates a new +// struct instance. +func NewAccount(pid *types.PID, username, password string) *Account { + return &Account{ + PID: pid, + Username: username, + Password: password, + } +} diff --git a/hpp_packet.go b/hpp_packet.go index fbfb9c3e..64139bf6 100644 --- a/hpp_packet.go +++ b/hpp_packet.go @@ -91,13 +91,14 @@ func (p *HPPPacket) validatePasswordSignature(signature string) error { } func (p *HPPPacket) calculatePasswordSignature() ([]byte, error) { - pid := p.Sender().PID() - password, _ := p.Sender().Server().PasswordFromPID(pid) - if password == "" { + sender := p.Sender() + pid := sender.PID() + account, _ := sender.Server().(*HPPServer).AccountDetailsByPID(pid) + if account == nil { return nil, errors.New("PID does not exist") } - key := DeriveKerberosKey(pid, []byte(password)) + key := DeriveKerberosKey(pid, []byte(account.Password)) signature, err := p.calculateSignature(p.payload, key) if err != nil { diff --git a/hpp_server.go b/hpp_server.go index 9e8a3614..2608f687 100644 --- a/hpp_server.go +++ b/hpp_server.go @@ -23,8 +23,9 @@ type HPPServer struct { utilityProtocolVersion *LibraryVersion natTraversalProtocolVersion *LibraryVersion dataHandlers []func(packet PacketInterface) - passwordFromPIDHandler func(pid *types.PID) (string, uint32) byteStreamSettings *ByteStreamSettings + AccountDetailsByPID func(pid *types.PID) (*Account, uint32) + AccountDetailsByUsername func(username string) (*Account, uint32) } // OnData adds an event handler which is fired when a new HPP request is received @@ -258,21 +259,6 @@ func (s *HPPServer) NATTraversalProtocolVersion() *LibraryVersion { return s.natTraversalProtocolVersion } -// PasswordFromPID calls the function set with SetPasswordFromPIDFunction and returns the result -func (s *HPPServer) PasswordFromPID(pid *types.PID) (string, uint32) { - if s.passwordFromPIDHandler == nil { - logger.Errorf("Missing PasswordFromPID handler. Set with SetPasswordFromPIDFunction") - return "", Errors.Core.NotImplemented - } - - return s.passwordFromPIDHandler(pid) -} - -// SetPasswordFromPIDFunction sets the function for HPP to get a NEX password using the PID -func (s *HPPServer) SetPasswordFromPIDFunction(handler func(pid *types.PID) (string, uint32)) { - s.passwordFromPIDHandler = handler -} - // ByteStreamSettings returns the settings to be used for ByteStreams func (s *HPPServer) ByteStreamSettings() *ByteStreamSettings { return s.byteStreamSettings diff --git a/kerberos.go b/kerberos.go index 22ac1345..3aeea057 100644 --- a/kerberos.go +++ b/kerberos.go @@ -178,7 +178,7 @@ func (ti *KerberosTicketInternalData) Decrypt(stream *ByteStreamIn, key []byte) ti.Issued = timestamp ti.SourcePID = userPID - ti.SessionKey = stream.ReadBytesNext(int64(stream.Server.(*PRUDPServer).kerberosKeySize)) + ti.SessionKey = stream.ReadBytesNext(int64(stream.Server.(*PRUDPServer).SessionKeyLength)) return nil } diff --git a/prudp_endpoint.go b/prudp_endpoint.go index 67af48c1..359f0720 100644 --- a/prudp_endpoint.go +++ b/prudp_endpoint.go @@ -21,7 +21,9 @@ type PRUDPEndPoint struct { packetEventHandlers map[string][]func(packet PacketInterface) connectionEndedEventHandlers []func(connection *PRUDPConnection) ConnectionIDCounter *Counter[uint32] - IsSecureEndpoint bool // TODO - Remove this? Assume if CONNECT packet has a body, it's the secure server? + ServerAccount *Account + AccountDetailsByPID func(pid *types.PID) (*Account, uint32) + AccountDetailsByUsername func(username string) (*Account, uint32) } // OnData adds an event handler which is fired when a new DATA packet is received @@ -285,7 +287,7 @@ func (pep *PRUDPEndPoint) handleConnect(packet PRUDPPacketInterface) { payload := make([]byte, 0) - if pep.IsSecureEndpoint { + if len(packet.Payload()) != 0 { sessionKey, pid, checkValue, err := pep.readKerberosTicket(packet.Payload()) if err != nil { logger.Error(err.Error()) @@ -358,7 +360,17 @@ func (pep *PRUDPEndPoint) readKerberosTicket(payload []byte) ([]byte, *types.PID return nil, nil, 0, err } - serverKey := DeriveKerberosKey(types.NewPID(2), pep.Server.kerberosPassword) + // * Sanity checks + serverAccount, _ := pep.AccountDetailsByUsername(pep.ServerAccount.Username) + if serverAccount == nil { + return nil, nil, 0, errors.New("Failed to find endpoint server account") + } + + if serverAccount.Password != pep.ServerAccount.Password { + return nil, nil, 0, errors.New("Password for endpoint server account does not match the records from AccountDetailsByUsername") + } + + serverKey := DeriveKerberosKey(serverAccount.PID, []byte(serverAccount.Password)) ticket := NewKerberosTicketInternalData() if err := ticket.Decrypt(NewByteStreamIn(ticketData.Value, pep.Server), serverKey); err != nil { @@ -601,6 +613,5 @@ func NewPRUDPEndPoint(streamID uint8) *PRUDPEndPoint { packetEventHandlers: make(map[string][]func(PacketInterface)), connectionEndedEventHandlers: make([]func(connection *PRUDPConnection), 0), ConnectionIDCounter: NewCounter[uint32](0), - IsSecureEndpoint: false, } } diff --git a/prudp_server.go b/prudp_server.go index b35fb191..42c39804 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -8,7 +8,6 @@ import ( "runtime" "time" - "github.com/PretendoNetwork/nex-go/types" "github.com/lxzan/gws" ) @@ -20,9 +19,8 @@ type PRUDPServer struct { Connections *MutexMap[string, *SocketConnection] SupportedFunctions uint32 accessKey string - kerberosPassword []byte kerberosTicketVersion int - kerberosKeySize int + SessionKeyLength int FragmentSize int version *LibraryVersion datastoreProtocolVersion *LibraryVersion @@ -33,7 +31,6 @@ type PRUDPServer struct { utilityProtocolVersion *LibraryVersion natTraversalProtocolVersion *LibraryVersion pingTimeout time.Duration - passwordFromPIDHandler func(pid *types.PID) (string, uint32) PRUDPv1ConnectionSignatureKey []byte byteStreamSettings *ByteStreamSettings PRUDPV0Settings *PRUDPV0Settings @@ -318,16 +315,6 @@ func (ps *PRUDPServer) SetAccessKey(accessKey string) { ps.accessKey = accessKey } -// KerberosPassword returns the server kerberos password -func (ps *PRUDPServer) KerberosPassword() []byte { - return ps.kerberosPassword -} - -// SetKerberosPassword sets the server kerberos password -func (ps *PRUDPServer) SetKerberosPassword(kerberosPassword []byte) { - ps.kerberosPassword = kerberosPassword -} - // SetFragmentSize sets the max size for a packets payload func (ps *PRUDPServer) SetFragmentSize(fragmentSize int) { // TODO - Derive this value from the MTU @@ -351,12 +338,12 @@ func (ps *PRUDPServer) SetKerberosTicketVersion(kerberosTicketVersion int) { // KerberosKeySize gets the size for the kerberos session key func (ps *PRUDPServer) KerberosKeySize() int { - return ps.kerberosKeySize + return ps.SessionKeyLength } // SetKerberosKeySize sets the size for the kerberos session key func (ps *PRUDPServer) SetKerberosKeySize(kerberosKeySize int) { - ps.kerberosKeySize = kerberosKeySize + ps.SessionKeyLength = kerberosKeySize } // LibraryVersion returns the server NEX version @@ -446,21 +433,6 @@ func (ps *PRUDPServer) NATTraversalProtocolVersion() *LibraryVersion { return ps.natTraversalProtocolVersion } -// PasswordFromPID calls the function set with SetPasswordFromPIDFunction and returns the result -func (ps *PRUDPServer) PasswordFromPID(pid *types.PID) (string, uint32) { - if ps.passwordFromPIDHandler == nil { - logger.Errorf("Missing PasswordFromPID handler. Set with SetPasswordFromPIDFunction") - return "", Errors.Core.NotImplemented - } - - return ps.passwordFromPIDHandler(pid) -} - -// SetPasswordFromPIDFunction sets the function for the auth server to get a NEX password using the PID -func (ps *PRUDPServer) SetPasswordFromPIDFunction(handler func(pid *types.PID) (string, uint32)) { - ps.passwordFromPIDHandler = handler -} - // ByteStreamSettings returns the settings to be used for ByteStreams func (ps *PRUDPServer) ByteStreamSettings() *ByteStreamSettings { return ps.byteStreamSettings @@ -498,7 +470,7 @@ func NewPRUDPServer() *PRUDPServer { return &PRUDPServer{ Endpoints: NewMutexMap[uint8, *PRUDPEndPoint](), Connections: NewMutexMap[string, *SocketConnection](), - kerberosKeySize: 32, + SessionKeyLength: 32, FragmentSize: 1300, pingTimeout: time.Second * 15, byteStreamSettings: NewByteStreamSettings(), diff --git a/server_interface.go b/server_interface.go index a3ae1a2a..56c3fad6 100644 --- a/server_interface.go +++ b/server_interface.go @@ -1,7 +1,5 @@ package nex -import "github.com/PretendoNetwork/nex-go/types" - // ServerInterface defines all the methods a server should have regardless of type type ServerInterface interface { AccessKey() string @@ -17,8 +15,6 @@ type ServerInterface interface { SetDefaultLibraryVersion(version *LibraryVersion) Send(packet PacketInterface) OnData(handler func(packet PacketInterface)) - PasswordFromPID(pid *types.PID) (string, uint32) - SetPasswordFromPIDFunction(handler func(pid *types.PID) (string, uint32)) ByteStreamSettings() *ByteStreamSettings SetByteStreamSettings(settings *ByteStreamSettings) } diff --git a/test/auth.go b/test/auth.go index b6590415..22aaaef4 100644 --- a/test/auth.go +++ b/test/auth.go @@ -3,7 +3,6 @@ package main import ( "fmt" - "strconv" "github.com/PretendoNetwork/nex-go" "github.com/PretendoNetwork/nex-go/types" @@ -18,6 +17,10 @@ func startAuthenticationServer() { endpoint := nex.NewPRUDPEndPoint(1) + endpoint.AccountDetailsByPID = accountDetailsByPID + endpoint.AccountDetailsByUsername = accountDetailsByUsername + endpoint.ServerAccount = authenticationServerAccount + endpoint.OnData(func(packet nex.PacketInterface) { if packet, ok := packet.(nex.PRUDPPacketInterface); ok { request := packet.RMCMessage() @@ -38,7 +41,6 @@ func startAuthenticationServer() { authServer.SetFragmentSize(962) authServer.SetDefaultLibraryVersion(nex.NewLibraryVersion(1, 1, 0)) - authServer.SetKerberosPassword([]byte("password")) authServer.SetKerberosKeySize(16) authServer.SetAccessKey("ridfebb9") authServer.BindPRUDPEndPoint(endpoint) @@ -58,14 +60,12 @@ func login(packet nex.PRUDPPacketInterface) { panic(err) } - converted, err := strconv.Atoi(strUserName.Value) - if err != nil { - panic(err) - } + sourceAccount, _ := accountDetailsByUsername(strUserName.Value) + targetAccount, _ := accountDetailsByUsername(secureServerAccount.Username) retval := types.NewQResultSuccess(0x00010001) - pidPrincipal := types.NewPID(uint64(converted)) - pbufResponse := types.NewBuffer(generateTicket(pidPrincipal, types.NewPID(2))) + pidPrincipal := sourceAccount.PID + pbufResponse := types.NewBuffer(generateTicket(sourceAccount, targetAccount)) pConnectionData := types.NewRVConnectionData() strReturnMsg := types.NewString("Test Build") @@ -125,8 +125,11 @@ func requestTicket(packet nex.PRUDPPacketInterface) { panic(err) } + sourceAccount, _ := accountDetailsByPID(idSource) + targetAccount, _ := accountDetailsByPID(idTarget) + retval := types.NewQResultSuccess(0x00010001) - pbufResponse := types.NewBuffer(generateTicket(idSource, idTarget)) + pbufResponse := types.NewBuffer(generateTicket(sourceAccount, targetAccount)) responseStream := nex.NewByteStreamOut(authServer) diff --git a/test/generate_ticket.go b/test/generate_ticket.go index 36079156..7e0d3975 100644 --- a/test/generate_ticket.go +++ b/test/generate_ticket.go @@ -7,9 +7,10 @@ import ( "github.com/PretendoNetwork/nex-go/types" ) -func generateTicket(userPID *types.PID, targetPID *types.PID) []byte { - userKey := nex.DeriveKerberosKey(userPID, []byte("z5sykuHnX0q5SCJN")) - targetKey := nex.DeriveKerberosKey(targetPID, []byte("password")) +// func generateTicket(userPID *types.PID, targetPID *types.PID) []byte { +func generateTicket(source *nex.Account, target *nex.Account) []byte { + sourceKey := nex.DeriveKerberosKey(source.PID, []byte(source.Password)) + targetKey := nex.DeriveKerberosKey(target.PID, []byte(target.Password)) sessionKey := make([]byte, authServer.KerberosKeySize()) _, err := rand.Read(sessionKey) @@ -21,17 +22,17 @@ func generateTicket(userPID *types.PID, targetPID *types.PID) []byte { serverTime := types.NewDateTime(0).Now() ticketInternalData.Issued = serverTime - ticketInternalData.SourcePID = userPID + ticketInternalData.SourcePID = source.PID ticketInternalData.SessionKey = sessionKey encryptedTicketInternalData, _ := ticketInternalData.Encrypt(targetKey, nex.NewByteStreamOut(authServer)) ticket := nex.NewKerberosTicket() ticket.SessionKey = sessionKey - ticket.TargetPID = targetPID + ticket.TargetPID = target.PID ticket.InternalData = types.NewBuffer(encryptedTicketInternalData) - encryptedTicket, _ := ticket.Encrypt(userKey, nex.NewByteStreamOut(authServer)) + encryptedTicket, _ := ticket.Encrypt(sourceKey, nex.NewByteStreamOut(authServer)) return encryptedTicket } diff --git a/test/hpp.go b/test/hpp.go index 1464e060..1eec2b9c 100644 --- a/test/hpp.go +++ b/test/hpp.go @@ -79,7 +79,8 @@ func startHPPServer() { hppServer.SetDefaultLibraryVersion(nex.NewLibraryVersion(2, 4, 1)) hppServer.SetAccessKey("76f26496") - hppServer.SetPasswordFromPIDFunction(passwordFromPID) + hppServer.AccountDetailsByPID = accountDetailsByPID + hppServer.AccountDetailsByUsername = accountDetailsByUsername hppServer.Listen(12345) } diff --git a/test/main.go b/test/main.go index 39ac9b22..de895830 100644 --- a/test/main.go +++ b/test/main.go @@ -1,10 +1,55 @@ package main -import "sync" +import ( + "sync" + + "github.com/PretendoNetwork/nex-go" + "github.com/PretendoNetwork/nex-go/types" +) var wg sync.WaitGroup +var authenticationServerAccount *nex.Account +var secureServerAccount *nex.Account +var testUserAccount *nex.Account + +func accountDetailsByPID(pid *types.PID) (*nex.Account, uint32) { + if pid.Equals(authenticationServerAccount.PID) { + return authenticationServerAccount, 0 + } + + if pid.Equals(secureServerAccount.PID) { + return secureServerAccount, 0 + } + + if pid.Equals(testUserAccount.PID) { + return testUserAccount, 0 + } + + return nil, nex.Errors.RendezVous.InvalidPID +} + +func accountDetailsByUsername(username string) (*nex.Account, uint32) { + if username == authenticationServerAccount.Username { + return authenticationServerAccount, 0 + } + + if username == secureServerAccount.Username { + return secureServerAccount, 0 + } + + if username == testUserAccount.Username { + return testUserAccount, 0 + } + + return nil, nex.Errors.RendezVous.InvalidUsername +} + func main() { + authenticationServerAccount = nex.NewAccount(types.NewPID(1), "Quazal Authentication", "authpassword") + secureServerAccount = nex.NewAccount(types.NewPID(2), "Quazal Rendez-Vous", "securepassword") + testUserAccount = nex.NewAccount(types.NewPID(1800000000), "1800000000", "nexuserpassword") + wg.Add(3) go startAuthenticationServer() diff --git a/test/secure.go b/test/secure.go index 9b3c54cb..9649d226 100644 --- a/test/secure.go +++ b/test/secure.go @@ -46,8 +46,11 @@ func startSecureServer() { secureServer = nex.NewPRUDPServer() - endpoint := nex.NewPRUDPEndPoint(2) - endpoint.IsSecureEndpoint = true + endpoint := nex.NewPRUDPEndPoint(1) + + endpoint.AccountDetailsByPID = accountDetailsByPID + endpoint.AccountDetailsByUsername = accountDetailsByUsername + endpoint.ServerAccount = secureServerAccount endpoint.OnData(func(packet nex.PacketInterface) { if packet, ok := packet.(nex.PRUDPPacketInterface); ok { @@ -77,7 +80,6 @@ func startSecureServer() { secureServer.SetFragmentSize(962) secureServer.SetDefaultLibraryVersion(nex.NewLibraryVersion(1, 1, 0)) - secureServer.SetKerberosPassword([]byte("password")) secureServer.SetKerberosKeySize(16) secureServer.SetAccessKey("ridfebb9") secureServer.BindPRUDPEndPoint(endpoint) From f1549a79ca43a99c014b6a8c4772f433511326bc Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Wed, 24 Jan 2024 13:08:49 -0500 Subject: [PATCH 130/178] prudp: removed unnecessary getters/setters --- kerberos.go | 4 ++-- prudp_server.go | 17 +---------------- test/auth.go | 6 +++--- test/generate_ticket.go | 5 ++--- test/secure.go | 2 +- 5 files changed, 9 insertions(+), 25 deletions(-) diff --git a/kerberos.go b/kerberos.go index 3aeea057..bd858e61 100644 --- a/kerberos.go +++ b/kerberos.go @@ -108,7 +108,7 @@ func (ti *KerberosTicketInternalData) Encrypt(key []byte, stream *ByteStreamOut) data := stream.Bytes() - if stream.Server.(*PRUDPServer).kerberosTicketVersion == 1 { + if stream.Server.(*PRUDPServer).KerberosTicketVersion == 1 { ticketKey := make([]byte, 16) _, err := rand.Read(ticketKey) if err != nil { @@ -140,7 +140,7 @@ func (ti *KerberosTicketInternalData) Encrypt(key []byte, stream *ByteStreamOut) // Decrypt decrypts the given data and populates the struct func (ti *KerberosTicketInternalData) Decrypt(stream *ByteStreamIn, key []byte) error { - if stream.Server.(*PRUDPServer).kerberosTicketVersion == 1 { + if stream.Server.(*PRUDPServer).KerberosTicketVersion == 1 { ticketKey := types.NewBuffer(nil) if err := ticketKey.ExtractFrom(stream); err != nil { return fmt.Errorf("Failed to read Kerberos ticket internal data key. %s", err.Error()) diff --git a/prudp_server.go b/prudp_server.go index 42c39804..a4ab0afb 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -19,7 +19,7 @@ type PRUDPServer struct { Connections *MutexMap[string, *SocketConnection] SupportedFunctions uint32 accessKey string - kerberosTicketVersion int + KerberosTicketVersion int SessionKeyLength int FragmentSize int version *LibraryVersion @@ -331,21 +331,6 @@ func (ps *PRUDPServer) SetFragmentSize(fragmentSize int) { ps.FragmentSize = fragmentSize } -// SetKerberosTicketVersion sets the version used when handling kerberos tickets -func (ps *PRUDPServer) SetKerberosTicketVersion(kerberosTicketVersion int) { - ps.kerberosTicketVersion = kerberosTicketVersion -} - -// KerberosKeySize gets the size for the kerberos session key -func (ps *PRUDPServer) KerberosKeySize() int { - return ps.SessionKeyLength -} - -// SetKerberosKeySize sets the size for the kerberos session key -func (ps *PRUDPServer) SetKerberosKeySize(kerberosKeySize int) { - ps.SessionKeyLength = kerberosKeySize -} - // LibraryVersion returns the server NEX version func (ps *PRUDPServer) LibraryVersion() *LibraryVersion { return ps.version diff --git a/test/auth.go b/test/auth.go index 22aaaef4..16cd8733 100644 --- a/test/auth.go +++ b/test/auth.go @@ -41,7 +41,7 @@ func startAuthenticationServer() { authServer.SetFragmentSize(962) authServer.SetDefaultLibraryVersion(nex.NewLibraryVersion(1, 1, 0)) - authServer.SetKerberosKeySize(16) + authServer.SessionKeyLength = 16 authServer.SetAccessKey("ridfebb9") authServer.BindPRUDPEndPoint(endpoint) authServer.Listen(60000) @@ -65,7 +65,7 @@ func login(packet nex.PRUDPPacketInterface) { retval := types.NewQResultSuccess(0x00010001) pidPrincipal := sourceAccount.PID - pbufResponse := types.NewBuffer(generateTicket(sourceAccount, targetAccount)) + pbufResponse := types.NewBuffer(generateTicket(sourceAccount, targetAccount, authServer.SessionKeyLength)) pConnectionData := types.NewRVConnectionData() strReturnMsg := types.NewString("Test Build") @@ -129,7 +129,7 @@ func requestTicket(packet nex.PRUDPPacketInterface) { targetAccount, _ := accountDetailsByPID(idTarget) retval := types.NewQResultSuccess(0x00010001) - pbufResponse := types.NewBuffer(generateTicket(sourceAccount, targetAccount)) + pbufResponse := types.NewBuffer(generateTicket(sourceAccount, targetAccount, authServer.SessionKeyLength)) responseStream := nex.NewByteStreamOut(authServer) diff --git a/test/generate_ticket.go b/test/generate_ticket.go index 7e0d3975..b2023eab 100644 --- a/test/generate_ticket.go +++ b/test/generate_ticket.go @@ -7,11 +7,10 @@ import ( "github.com/PretendoNetwork/nex-go/types" ) -// func generateTicket(userPID *types.PID, targetPID *types.PID) []byte { -func generateTicket(source *nex.Account, target *nex.Account) []byte { +func generateTicket(source *nex.Account, target *nex.Account, sessionKeyLength int) []byte { sourceKey := nex.DeriveKerberosKey(source.PID, []byte(source.Password)) targetKey := nex.DeriveKerberosKey(target.PID, []byte(target.Password)) - sessionKey := make([]byte, authServer.KerberosKeySize()) + sessionKey := make([]byte, sessionKeyLength) _, err := rand.Read(sessionKey) if err != nil { diff --git a/test/secure.go b/test/secure.go index 9649d226..613324c3 100644 --- a/test/secure.go +++ b/test/secure.go @@ -80,7 +80,7 @@ func startSecureServer() { secureServer.SetFragmentSize(962) secureServer.SetDefaultLibraryVersion(nex.NewLibraryVersion(1, 1, 0)) - secureServer.SetKerberosKeySize(16) + secureServer.SessionKeyLength = 16 secureServer.SetAccessKey("ridfebb9") secureServer.BindPRUDPEndPoint(endpoint) secureServer.Listen(60001) From a26750ac1058d63b8214dba4993e4bb54877e3e7 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Wed, 24 Jan 2024 13:10:42 -0500 Subject: [PATCH 131/178] chore: fix MaxPacketRetransmissions comment on StreamSettings --- stream_settings.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stream_settings.go b/stream_settings.go index 1cbb8408..a7ee51e6 100644 --- a/stream_settings.go +++ b/stream_settings.go @@ -13,7 +13,7 @@ import ( // Not all values are used at this time, and only exist to future-proof for a later time. type StreamSettings struct { ExtraRestransmitTimeoutTrigger uint32 // * Unused. The number of times a packet can be retransmitted before ExtraRetransmitTimeoutMultiplier is used - MaxPacketRetransmissions uint32 // *The number of times a packet can be retransmitted before the timeout time is checked + MaxPacketRetransmissions uint32 // * The number of times a packet can be retransmitted before the timeout time is checked KeepAliveTimeout uint32 // * Unused. Presumably the time a packet can be alive for without acknowledgement? Milliseconds? ChecksumBase uint32 // * Unused. The base value for PRUDPv0 checksum calculations FaultDetectionEnabled bool // * Unused. Presumably used to detect PIA faults? From b5ef1a0a2b1b012b3b2bb110e3bcf2fe85343c67 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Wed, 24 Jan 2024 13:19:23 -0500 Subject: [PATCH 132/178] chore: rename errors to result codes --- errors.go | 660 ------------------------------------------------ hpp_server.go | 2 +- init.go | 2 +- result_codes.go | 660 ++++++++++++++++++++++++++++++++++++++++++++++++ test/main.go | 4 +- 5 files changed, 664 insertions(+), 664 deletions(-) delete mode 100644 errors.go create mode 100644 result_codes.go diff --git a/errors.go b/errors.go deleted file mode 100644 index 921f9418..00000000 --- a/errors.go +++ /dev/null @@ -1,660 +0,0 @@ -package nex - -import ( - "reflect" - "strconv" -) - -var errorMask = 1 << 31 - -type nexerrors struct { - Core struct { - Unknown uint32 - NotImplemented uint32 - InvalidPointer uint32 - OperationAborted uint32 - Exception uint32 - AccessDenied uint32 - InvalidHandle uint32 - InvalidIndex uint32 - OutOfMemory uint32 - InvalidArgument uint32 - Timeout uint32 - InitializationFailure uint32 - CallInitiationFailure uint32 - RegistrationError uint32 - BufferOverflow uint32 - InvalidLockState uint32 - InvalidSequence uint32 - SystemError uint32 - Cancelled uint32 - } - - DDL struct { - InvalidSignature uint32 - IncorrectVersion uint32 - } - - RendezVous struct { - ConnectionFailure uint32 - NotAuthenticated uint32 - InvalidUsername uint32 - InvalidPassword uint32 - UsernameAlreadyExists uint32 - AccountDisabled uint32 - AccountExpired uint32 - ConcurrentLoginDenied uint32 - EncryptionFailure uint32 - InvalidPID uint32 - MaxConnectionsReached uint32 - InvalidGID uint32 - InvalidControlScriptID uint32 - InvalidOperationInLiveEnvironment uint32 - DuplicateEntry uint32 - ControlScriptFailure uint32 - ClassNotFound uint32 - SessionVoid uint32 - DDLMismatch uint32 - InvalidConfiguration uint32 - SessionFull uint32 - InvalidGatheringPassword uint32 - WithoutParticipationPeriod uint32 - PersistentGatheringCreationMax uint32 - PersistentGatheringParticipationMax uint32 - DeniedByParticipants uint32 - ParticipantInBlackList uint32 - GameServerMaintenance uint32 - OperationPostpone uint32 - OutOfRatingRange uint32 - ConnectionDisconnected uint32 - InvalidOperation uint32 - NotParticipatedGathering uint32 - MatchmakeSessionUserPasswordUnmatch uint32 - MatchmakeSessionSystemPasswordUnmatch uint32 - UserIsOffline uint32 - AlreadyParticipatedGathering uint32 - PermissionDenied uint32 - NotFriend uint32 - SessionClosed uint32 - DatabaseTemporarilyUnavailable uint32 - InvalidUniqueID uint32 - MatchmakingWithdrawn uint32 - LimitExceeded uint32 - AccountTemporarilyDisabled uint32 - PartiallyServiceClosed uint32 - ConnectionDisconnectedForConcurrentLogin uint32 - } - - PythonCore struct { - Exception uint32 - TypeError uint32 - IndexError uint32 - InvalidReference uint32 - CallFailure uint32 - MemoryError uint32 - KeyError uint32 - OperationError uint32 - ConversionError uint32 - ValidationError uint32 - } - - Transport struct { - Unknown uint32 - ConnectionFailure uint32 - InvalidURL uint32 - InvalidKey uint32 - InvalidURLType uint32 - DuplicateEndpoint uint32 - IOError uint32 - Timeout uint32 - ConnectionReset uint32 - IncorrectRemoteAuthentication uint32 - ServerRequestError uint32 - DecompressionFailure uint32 - ReliableSendBufferFullFatal uint32 - UPnPCannotInit uint32 - UPnPCannotAddMapping uint32 - NatPMPCannotInit uint32 - NatPMPCannotAddMapping uint32 - UnsupportedNAT uint32 - DNSError uint32 - ProxyError uint32 - DataRemaining uint32 - NoBuffer uint32 - NotFound uint32 - TemporaryServerError uint32 - PermanentServerError uint32 - ServiceUnavailable uint32 - ReliableSendBufferFull uint32 - InvalidStation uint32 - InvalidSubStreamID uint32 - PacketBufferFull uint32 - NatTraversalError uint32 - NatCheckError uint32 - } - - DOCore struct { - StationNotReached uint32 - TargetStationDisconnect uint32 - LocalStationLeaving uint32 - ObjectNotFound uint32 - InvalidRole uint32 - CallTimeout uint32 - RMCDispatchFailed uint32 - MigrationInProgress uint32 - NoAuthority uint32 - NoTargetStationSpecified uint32 - JoinFailed uint32 - JoinDenied uint32 - ConnectivityTestFailed uint32 - Unknown uint32 - UnfreedReferences uint32 - JobTerminationFailed uint32 - InvalidState uint32 - FaultRecoveryFatal uint32 - FaultRecoveryJobProcessFailed uint32 - StationInconsitency uint32 - AbnormalMasterState uint32 - VersionMismatch uint32 - } - - FPD struct { - NotInitialized uint32 - AlreadyInitialized uint32 - NotConnected uint32 - Connected uint32 - InitializationFailure uint32 - OutOfMemory uint32 - RmcFailed uint32 - InvalidArgument uint32 - InvalidLocalAccountID uint32 - InvalidPrincipalID uint32 - InvalidLocalFriendCode uint32 - LocalAccountNotExists uint32 - LocalAccountNotLoaded uint32 - LocalAccountAlreadyLoaded uint32 - FriendAlreadyExists uint32 - FriendNotExists uint32 - FriendNumMax uint32 - NotFriend uint32 - FileIO uint32 - P2PInternetProhibited uint32 - Unknown uint32 - InvalidState uint32 - AddFriendProhibited uint32 - InvalidAccount uint32 - BlacklistedByMe uint32 - FriendAlreadyAdded uint32 - MyFriendListLimitExceed uint32 - RequestLimitExceed uint32 - InvalidMessageID uint32 - MessageIsNotMine uint32 - MessageIsNotForMe uint32 - FriendRequestBlocked uint32 - NotInMyFriendList uint32 - FriendListedByMe uint32 - NotInMyBlacklist uint32 - IncompatibleAccount uint32 - BlockSettingChangeNotAllowed uint32 - SizeLimitExceeded uint32 - OperationNotAllowed uint32 - NotNetworkAccount uint32 - NotificationNotFound uint32 - PreferenceNotInitialized uint32 - FriendRequestNotAllowed uint32 - } - - Ranking struct { - NotInitialized uint32 - InvalidArgument uint32 - RegistrationError uint32 - NotFound uint32 - InvalidScore uint32 - InvalidDataSize uint32 - PermissionDenied uint32 - Unknown uint32 - NotImplemented uint32 - } - - Authentication struct { - NASAuthenticateError uint32 - TokenParseError uint32 - HTTPConnectionError uint32 - HTTPDNSError uint32 - HTTPGetProxySetting uint32 - TokenExpired uint32 - ValidationFailed uint32 - InvalidParam uint32 - PrincipalIDUnmatched uint32 - MoveCountUnmatch uint32 - UnderMaintenance uint32 - UnsupportedVersion uint32 - ServerVersionIsOld uint32 - Unknown uint32 - ClientVersionIsOld uint32 - AccountLibraryError uint32 - ServiceNoLongerAvailable uint32 - UnknownApplication uint32 - ApplicationVersionIsOld uint32 - OutOfService uint32 - NetworkServiceLicenseRequired uint32 - NetworkServiceLicenseSystemError uint32 - NetworkServiceLicenseError3 uint32 - NetworkServiceLicenseError4 uint32 - } - - DataStore struct { - Unknown uint32 - InvalidArgument uint32 - PermissionDenied uint32 - NotFound uint32 - AlreadyLocked uint32 - UnderReviewing uint32 - Expired uint32 - InvalidCheckToken uint32 - SystemFileError uint32 - OverCapacity uint32 - OperationNotAllowed uint32 - InvalidPassword uint32 - ValueNotEqual uint32 - } - - ServiceItem struct { - Unknown uint32 - InvalidArgument uint32 - EShopUnknownHTTPError uint32 - EShopResponseParseError uint32 - NotOwned uint32 - InvalidLimitationType uint32 - ConsumptionRightShortage uint32 - } - - MatchmakeReferee struct { - Unknown uint32 - InvalidArgument uint32 - AlreadyExists uint32 - NotParticipatedGathering uint32 - NotParticipatedRound uint32 - StatsNotFound uint32 - RoundNotFound uint32 - RoundArbitrated uint32 - RoundNotArbitrated uint32 - } - - Subscriber struct { - Unknown uint32 - InvalidArgument uint32 - OverLimit uint32 - PermissionDenied uint32 - } - - Ranking2 struct { - Unknown uint32 - InvalidArgument uint32 - InvalidScore uint32 - } - - SmartDeviceVoiceChat struct { - Unknown uint32 - InvalidArgument uint32 - InvalidResponse uint32 - InvalidAccessToken uint32 - Unauthorized uint32 - AccessError uint32 - UserNotFound uint32 - RoomNotFound uint32 - RoomNotActivated uint32 - ApplicationNotSupported uint32 - InternalServerError uint32 - ServiceUnavailable uint32 - UnexpectedError uint32 - UnderMaintenance uint32 - ServiceNoLongerAvailable uint32 - AccountTemporarilyDisabled uint32 - PermissionDenied uint32 - NetworkServiceLicenseRequired uint32 - AccountLibraryError uint32 - GameModeNotFound uint32 - } - - Screening struct { - Unknown uint32 - InvalidArgument uint32 - NotFound uint32 - } - - Custom struct { - Unknown uint32 - } - - Ess struct { - Unknown uint32 - GameSessionError uint32 - GameSessionMaintenance uint32 - } -} - -// ErrorNames contains a map of all the error string names, indexed by the error ID -var ErrorNames = map[uint32]string{} - -// Errors provides a struct containing error codes using dot-notation -var Errors nexerrors - -func initErrorsData() { - Errors.Core.Unknown = 0x00010001 - Errors.Core.NotImplemented = 0x00010002 - Errors.Core.InvalidPointer = 0x00010003 - Errors.Core.OperationAborted = 0x00010004 - Errors.Core.Exception = 0x00010005 - Errors.Core.AccessDenied = 0x00010006 - Errors.Core.InvalidHandle = 0x00010007 - Errors.Core.InvalidIndex = 0x00010008 - Errors.Core.OutOfMemory = 0x00010009 - Errors.Core.InvalidArgument = 0x0001000A - Errors.Core.Timeout = 0x0001000B - Errors.Core.InitializationFailure = 0x0001000C - Errors.Core.CallInitiationFailure = 0x0001000D - Errors.Core.RegistrationError = 0x0001000E - Errors.Core.BufferOverflow = 0x0001000F - Errors.Core.InvalidLockState = 0x00010010 - Errors.Core.InvalidSequence = 0x00010011 - Errors.Core.SystemError = 0x00010012 - Errors.Core.Cancelled = 0x00010013 - - Errors.DDL.InvalidSignature = 0x00020001 - Errors.DDL.IncorrectVersion = 0x00020002 - - Errors.RendezVous.ConnectionFailure = 0x00030001 - Errors.RendezVous.NotAuthenticated = 0x00030002 - Errors.RendezVous.InvalidUsername = 0x00030064 - Errors.RendezVous.InvalidPassword = 0x00030065 - Errors.RendezVous.UsernameAlreadyExists = 0x00030066 - Errors.RendezVous.AccountDisabled = 0x00030067 - Errors.RendezVous.AccountExpired = 0x00030068 - Errors.RendezVous.ConcurrentLoginDenied = 0x00030069 - Errors.RendezVous.EncryptionFailure = 0x0003006A - Errors.RendezVous.InvalidPID = 0x0003006B - Errors.RendezVous.MaxConnectionsReached = 0x0003006C - Errors.RendezVous.InvalidGID = 0x0003006D - Errors.RendezVous.InvalidControlScriptID = 0x0003006E - Errors.RendezVous.InvalidOperationInLiveEnvironment = 0x0003006F - Errors.RendezVous.DuplicateEntry = 0x00030070 - Errors.RendezVous.ControlScriptFailure = 0x00030071 - Errors.RendezVous.ClassNotFound = 0x00030072 - Errors.RendezVous.SessionVoid = 0x00030073 - Errors.RendezVous.DDLMismatch = 0x00030075 - Errors.RendezVous.InvalidConfiguration = 0x00030076 - Errors.RendezVous.SessionFull = 0x000300C8 - Errors.RendezVous.InvalidGatheringPassword = 0x000300C9 - Errors.RendezVous.WithoutParticipationPeriod = 0x000300CA - Errors.RendezVous.PersistentGatheringCreationMax = 0x000300CB - Errors.RendezVous.PersistentGatheringParticipationMax = 0x000300CC - Errors.RendezVous.DeniedByParticipants = 0x000300CD - Errors.RendezVous.ParticipantInBlackList = 0x000300CE - Errors.RendezVous.GameServerMaintenance = 0x000300CF - Errors.RendezVous.OperationPostpone = 0x000300D0 - Errors.RendezVous.OutOfRatingRange = 0x000300D1 - Errors.RendezVous.ConnectionDisconnected = 0x000300D2 - Errors.RendezVous.InvalidOperation = 0x000300D3 - Errors.RendezVous.NotParticipatedGathering = 0x000300D4 - Errors.RendezVous.MatchmakeSessionUserPasswordUnmatch = 0x000300D5 - Errors.RendezVous.MatchmakeSessionSystemPasswordUnmatch = 0x000300D6 - Errors.RendezVous.UserIsOffline = 0x000300D7 - Errors.RendezVous.AlreadyParticipatedGathering = 0x000300D8 - Errors.RendezVous.PermissionDenied = 0x000300D9 - Errors.RendezVous.NotFriend = 0x000300DA - Errors.RendezVous.SessionClosed = 0x000300DB - Errors.RendezVous.DatabaseTemporarilyUnavailable = 0x000300DC - Errors.RendezVous.InvalidUniqueID = 0x000300DD - Errors.RendezVous.MatchmakingWithdrawn = 0x000300DE - Errors.RendezVous.LimitExceeded = 0x000300DF - Errors.RendezVous.AccountTemporarilyDisabled = 0x000300E0 - Errors.RendezVous.PartiallyServiceClosed = 0x000300E1 - Errors.RendezVous.ConnectionDisconnectedForConcurrentLogin = 0x000300E2 - - Errors.PythonCore.Exception = 0x00040001 - Errors.PythonCore.TypeError = 0x00040002 - Errors.PythonCore.IndexError = 0x00040003 - Errors.PythonCore.InvalidReference = 0x00040004 - Errors.PythonCore.CallFailure = 0x00040005 - Errors.PythonCore.MemoryError = 0x00040006 - Errors.PythonCore.KeyError = 0x00040007 - Errors.PythonCore.OperationError = 0x00040008 - Errors.PythonCore.ConversionError = 0x00040009 - Errors.PythonCore.ValidationError = 0x0004000A - - Errors.Transport.Unknown = 0x00050001 - Errors.Transport.ConnectionFailure = 0x00050002 - Errors.Transport.InvalidURL = 0x00050003 - Errors.Transport.InvalidKey = 0x00050004 - Errors.Transport.InvalidURLType = 0x00050005 - Errors.Transport.DuplicateEndpoint = 0x00050006 - Errors.Transport.IOError = 0x00050007 - Errors.Transport.Timeout = 0x00050008 - Errors.Transport.ConnectionReset = 0x00050009 - Errors.Transport.IncorrectRemoteAuthentication = 0x0005000A - Errors.Transport.ServerRequestError = 0x0005000B - Errors.Transport.DecompressionFailure = 0x0005000C - Errors.Transport.ReliableSendBufferFullFatal = 0x0005000D - Errors.Transport.UPnPCannotInit = 0x0005000E - Errors.Transport.UPnPCannotAddMapping = 0x0005000F - Errors.Transport.NatPMPCannotInit = 0x00050010 - Errors.Transport.NatPMPCannotAddMapping = 0x00050011 - Errors.Transport.UnsupportedNAT = 0x00050013 - Errors.Transport.DNSError = 0x00050014 - Errors.Transport.ProxyError = 0x00050015 - Errors.Transport.DataRemaining = 0x00050016 - Errors.Transport.NoBuffer = 0x00050017 - Errors.Transport.NotFound = 0x00050018 - Errors.Transport.TemporaryServerError = 0x00050019 - Errors.Transport.PermanentServerError = 0x0005001A - Errors.Transport.ServiceUnavailable = 0x0005001B - Errors.Transport.ReliableSendBufferFull = 0x0005001C - Errors.Transport.InvalidStation = 0x0005001D - Errors.Transport.InvalidSubStreamID = 0x0005001E - Errors.Transport.PacketBufferFull = 0x0005001F - Errors.Transport.NatTraversalError = 0x00050020 - Errors.Transport.NatCheckError = 0x00050021 - - Errors.DOCore.StationNotReached = 0x00060001 - Errors.DOCore.TargetStationDisconnect = 0x00060002 - Errors.DOCore.LocalStationLeaving = 0x00060003 - Errors.DOCore.ObjectNotFound = 0x00060004 - Errors.DOCore.InvalidRole = 0x00060005 - Errors.DOCore.CallTimeout = 0x00060006 - Errors.DOCore.RMCDispatchFailed = 0x00060007 - Errors.DOCore.MigrationInProgress = 0x00060008 - Errors.DOCore.NoAuthority = 0x00060009 - Errors.DOCore.NoTargetStationSpecified = 0x0006000A - Errors.DOCore.JoinFailed = 0x0006000B - Errors.DOCore.JoinDenied = 0x0006000C - Errors.DOCore.ConnectivityTestFailed = 0x0006000D - Errors.DOCore.Unknown = 0x0006000E - Errors.DOCore.UnfreedReferences = 0x0006000F - Errors.DOCore.JobTerminationFailed = 0x00060010 - Errors.DOCore.InvalidState = 0x00060011 - Errors.DOCore.FaultRecoveryFatal = 0x00060012 - Errors.DOCore.FaultRecoveryJobProcessFailed = 0x00060013 - Errors.DOCore.StationInconsitency = 0x00060014 - Errors.DOCore.AbnormalMasterState = 0x00060015 - Errors.DOCore.VersionMismatch = 0x00060016 - - Errors.FPD.NotInitialized = 0x00650000 - Errors.FPD.AlreadyInitialized = 0x00650001 - Errors.FPD.NotConnected = 0x00650002 - Errors.FPD.Connected = 0x00650003 - Errors.FPD.InitializationFailure = 0x00650004 - Errors.FPD.OutOfMemory = 0x00650005 - Errors.FPD.RmcFailed = 0x00650006 - Errors.FPD.InvalidArgument = 0x00650007 - Errors.FPD.InvalidLocalAccountID = 0x00650008 - Errors.FPD.InvalidPrincipalID = 0x00650009 - Errors.FPD.InvalidLocalFriendCode = 0x0065000A - Errors.FPD.LocalAccountNotExists = 0x0065000B - Errors.FPD.LocalAccountNotLoaded = 0x0065000C - Errors.FPD.LocalAccountAlreadyLoaded = 0x0065000D - Errors.FPD.FriendAlreadyExists = 0x0065000E - Errors.FPD.FriendNotExists = 0x0065000F - Errors.FPD.FriendNumMax = 0x00650010 - Errors.FPD.NotFriend = 0x00650011 - Errors.FPD.FileIO = 0x00650012 - Errors.FPD.P2PInternetProhibited = 0x00650013 - Errors.FPD.Unknown = 0x00650014 - Errors.FPD.InvalidState = 0x00650015 - Errors.FPD.AddFriendProhibited = 0x00650017 - Errors.FPD.InvalidAccount = 0x00650019 - Errors.FPD.BlacklistedByMe = 0x0065001A - Errors.FPD.FriendAlreadyAdded = 0x0065001C - Errors.FPD.MyFriendListLimitExceed = 0x0065001D - Errors.FPD.RequestLimitExceed = 0x0065001E - Errors.FPD.InvalidMessageID = 0x0065001F - Errors.FPD.MessageIsNotMine = 0x00650020 - Errors.FPD.MessageIsNotForMe = 0x00650021 - Errors.FPD.FriendRequestBlocked = 0x00650022 - Errors.FPD.NotInMyFriendList = 0x00650023 - Errors.FPD.FriendListedByMe = 0x00650024 - Errors.FPD.NotInMyBlacklist = 0x00650025 - Errors.FPD.IncompatibleAccount = 0x00650026 - Errors.FPD.BlockSettingChangeNotAllowed = 0x00650027 - Errors.FPD.SizeLimitExceeded = 0x00650028 - Errors.FPD.OperationNotAllowed = 0x00650029 - Errors.FPD.NotNetworkAccount = 0x0065002A - Errors.FPD.NotificationNotFound = 0x0065002B - Errors.FPD.PreferenceNotInitialized = 0x0065002C - Errors.FPD.FriendRequestNotAllowed = 0x0065002D - - Errors.Ranking.NotInitialized = 0x00670001 - Errors.Ranking.InvalidArgument = 0x00670002 - Errors.Ranking.RegistrationError = 0x00670003 - Errors.Ranking.NotFound = 0x00670005 - Errors.Ranking.InvalidScore = 0x00670006 - Errors.Ranking.InvalidDataSize = 0x00670007 - Errors.Ranking.PermissionDenied = 0x00670009 - Errors.Ranking.Unknown = 0x0067000A - Errors.Ranking.NotImplemented = 0x0067000B - - Errors.Authentication.NASAuthenticateError = 0x00680001 - Errors.Authentication.TokenParseError = 0x00680002 - Errors.Authentication.HTTPConnectionError = 0x00680003 - Errors.Authentication.HTTPDNSError = 0x00680004 - Errors.Authentication.HTTPGetProxySetting = 0x00680005 - Errors.Authentication.TokenExpired = 0x00680006 - Errors.Authentication.ValidationFailed = 0x00680007 - Errors.Authentication.InvalidParam = 0x00680008 - Errors.Authentication.PrincipalIDUnmatched = 0x00680009 - Errors.Authentication.MoveCountUnmatch = 0x0068000A - Errors.Authentication.UnderMaintenance = 0x0068000B - Errors.Authentication.UnsupportedVersion = 0x0068000C - Errors.Authentication.ServerVersionIsOld = 0x0068000D - Errors.Authentication.Unknown = 0x0068000E - Errors.Authentication.ClientVersionIsOld = 0x0068000F - Errors.Authentication.AccountLibraryError = 0x00680010 - Errors.Authentication.ServiceNoLongerAvailable = 0x00680011 - Errors.Authentication.UnknownApplication = 0x00680012 - Errors.Authentication.ApplicationVersionIsOld = 0x00680013 - Errors.Authentication.OutOfService = 0x00680014 - Errors.Authentication.NetworkServiceLicenseRequired = 0x00680015 - Errors.Authentication.NetworkServiceLicenseSystemError = 0x00680016 - Errors.Authentication.NetworkServiceLicenseError3 = 0x00680017 - Errors.Authentication.NetworkServiceLicenseError4 = 0x00680018 - - Errors.DataStore.Unknown = 0x00690001 - Errors.DataStore.InvalidArgument = 0x00690002 - Errors.DataStore.PermissionDenied = 0x00690003 - Errors.DataStore.NotFound = 0x00690004 - Errors.DataStore.AlreadyLocked = 0x00690005 - Errors.DataStore.UnderReviewing = 0x00690006 - Errors.DataStore.Expired = 0x00690007 - Errors.DataStore.InvalidCheckToken = 0x00690008 - Errors.DataStore.SystemFileError = 0x00690009 - Errors.DataStore.OverCapacity = 0x0069000A - Errors.DataStore.OperationNotAllowed = 0x0069000B - Errors.DataStore.InvalidPassword = 0x0069000C - Errors.DataStore.ValueNotEqual = 0x0069000D - - Errors.ServiceItem.Unknown = 0x006C0001 - Errors.ServiceItem.InvalidArgument = 0x006C0002 - Errors.ServiceItem.EShopUnknownHTTPError = 0x006C0003 - Errors.ServiceItem.EShopResponseParseError = 0x006C0004 - Errors.ServiceItem.NotOwned = 0x006C0005 - Errors.ServiceItem.InvalidLimitationType = 0x006C0006 - Errors.ServiceItem.ConsumptionRightShortage = 0x006C0007 - - Errors.MatchmakeReferee.Unknown = 0x006F0001 - Errors.MatchmakeReferee.InvalidArgument = 0x006F0002 - Errors.MatchmakeReferee.AlreadyExists = 0x006F0003 - Errors.MatchmakeReferee.NotParticipatedGathering = 0x006F0004 - Errors.MatchmakeReferee.NotParticipatedRound = 0x006F0005 - Errors.MatchmakeReferee.StatsNotFound = 0x006F0006 - Errors.MatchmakeReferee.RoundNotFound = 0x006F0007 - Errors.MatchmakeReferee.RoundArbitrated = 0x006F0008 - Errors.MatchmakeReferee.RoundNotArbitrated = 0x006F0009 - - Errors.Subscriber.Unknown = 0x00700001 - Errors.Subscriber.InvalidArgument = 0x00700002 - Errors.Subscriber.OverLimit = 0x00700003 - Errors.Subscriber.PermissionDenied = 0x00700004 - - Errors.Ranking2.Unknown = 0x00710001 - Errors.Ranking2.InvalidArgument = 0x00710002 - Errors.Ranking2.InvalidScore = 0x00710003 - - Errors.SmartDeviceVoiceChat.Unknown = 0x00720001 - Errors.SmartDeviceVoiceChat.InvalidArgument = 0x00720002 - Errors.SmartDeviceVoiceChat.InvalidResponse = 0x00720003 - Errors.SmartDeviceVoiceChat.InvalidAccessToken = 0x00720004 - Errors.SmartDeviceVoiceChat.Unauthorized = 0x00720005 - Errors.SmartDeviceVoiceChat.AccessError = 0x00720006 - Errors.SmartDeviceVoiceChat.UserNotFound = 0x00720007 - Errors.SmartDeviceVoiceChat.RoomNotFound = 0x00720008 - Errors.SmartDeviceVoiceChat.RoomNotActivated = 0x00720009 - Errors.SmartDeviceVoiceChat.ApplicationNotSupported = 0x0072000A - Errors.SmartDeviceVoiceChat.InternalServerError = 0x0072000B - Errors.SmartDeviceVoiceChat.ServiceUnavailable = 0x0072000C - Errors.SmartDeviceVoiceChat.UnexpectedError = 0x0072000D - Errors.SmartDeviceVoiceChat.UnderMaintenance = 0x0072000E - Errors.SmartDeviceVoiceChat.ServiceNoLongerAvailable = 0x0072000F - Errors.SmartDeviceVoiceChat.AccountTemporarilyDisabled = 0x00720010 - Errors.SmartDeviceVoiceChat.PermissionDenied = 0x00720011 - Errors.SmartDeviceVoiceChat.NetworkServiceLicenseRequired = 0x00720012 - Errors.SmartDeviceVoiceChat.AccountLibraryError = 0x00720013 - Errors.SmartDeviceVoiceChat.GameModeNotFound = 0x00720014 - - Errors.Screening.Unknown = 0x00730001 - Errors.Screening.InvalidArgument = 0x00730002 - Errors.Screening.NotFound = 0x00730003 - - Errors.Custom.Unknown = 0x00740001 - - Errors.Ess.Unknown = 0x00750001 - Errors.Ess.GameSessionError = 0x00750002 - Errors.Ess.GameSessionMaintenance = 0x00750003 - - valueOfErrors := reflect.ValueOf(Errors) - typeOfErrors := valueOfErrors.Type() - - for i := 0; i < valueOfErrors.NumField(); i++ { - category := typeOfErrors.Field(i).Name - - valueOfCategory := reflect.ValueOf(valueOfErrors.Field(i).Interface()) - typeOfCategory := valueOfCategory.Type() - - for j := 0; j < valueOfCategory.NumField(); j++ { - errorName := typeOfCategory.Field(j).Name - errorCode := valueOfCategory.Field(j).Interface().(uint32) - - ErrorNames[errorCode] = category + "::" + errorName - } - } -} - -// ErrorNameFromCode returns an error code string for the provided error code -func ErrorNameFromCode(errorCode uint32) string { - name := ErrorNames[errorCode] - - if name == "" { - return "Invalid Error Code: " + strconv.Itoa(int(errorCode)) - } - - return name -} diff --git a/hpp_server.go b/hpp_server.go index 2608f687..6e4f6e63 100644 --- a/hpp_server.go +++ b/hpp_server.go @@ -106,7 +106,7 @@ func (s *HPPServer) handleRequest(w http.ResponseWriter, req *http.Request) { rmcMessage := hppPacket.RMCMessage() // HPP returns PythonCore::ValidationError if password is missing or invalid - errorResponse := NewRMCError(s, Errors.PythonCore.ValidationError) + errorResponse := NewRMCError(s, ResultCodes.PythonCore.ValidationError) errorResponse.CallID = rmcMessage.CallID errorResponse.IsHPP = true diff --git a/init.go b/init.go index e3027b1d..27d94f2f 100644 --- a/init.go +++ b/init.go @@ -8,7 +8,7 @@ import ( var logger = plogger.NewLogger() func init() { - initErrorsData() + initResultCodes() types.RegisterVariantType(1, types.NewPrimitiveS64(0)) types.RegisterVariantType(2, types.NewPrimitiveF64(0)) diff --git a/result_codes.go b/result_codes.go new file mode 100644 index 00000000..88eeae93 --- /dev/null +++ b/result_codes.go @@ -0,0 +1,660 @@ +package nex + +import ( + "reflect" + "strconv" +) + +var errorMask = 1 << 31 + +type resultCodes struct { + Core struct { + Unknown uint32 + NotImplemented uint32 + InvalidPointer uint32 + OperationAborted uint32 + Exception uint32 + AccessDenied uint32 + InvalidHandle uint32 + InvalidIndex uint32 + OutOfMemory uint32 + InvalidArgument uint32 + Timeout uint32 + InitializationFailure uint32 + CallInitiationFailure uint32 + RegistrationError uint32 + BufferOverflow uint32 + InvalidLockState uint32 + InvalidSequence uint32 + SystemError uint32 + Cancelled uint32 + } + + DDL struct { + InvalidSignature uint32 + IncorrectVersion uint32 + } + + RendezVous struct { + ConnectionFailure uint32 + NotAuthenticated uint32 + InvalidUsername uint32 + InvalidPassword uint32 + UsernameAlreadyExists uint32 + AccountDisabled uint32 + AccountExpired uint32 + ConcurrentLoginDenied uint32 + EncryptionFailure uint32 + InvalidPID uint32 + MaxConnectionsReached uint32 + InvalidGID uint32 + InvalidControlScriptID uint32 + InvalidOperationInLiveEnvironment uint32 + DuplicateEntry uint32 + ControlScriptFailure uint32 + ClassNotFound uint32 + SessionVoid uint32 + DDLMismatch uint32 + InvalidConfiguration uint32 + SessionFull uint32 + InvalidGatheringPassword uint32 + WithoutParticipationPeriod uint32 + PersistentGatheringCreationMax uint32 + PersistentGatheringParticipationMax uint32 + DeniedByParticipants uint32 + ParticipantInBlackList uint32 + GameServerMaintenance uint32 + OperationPostpone uint32 + OutOfRatingRange uint32 + ConnectionDisconnected uint32 + InvalidOperation uint32 + NotParticipatedGathering uint32 + MatchmakeSessionUserPasswordUnmatch uint32 + MatchmakeSessionSystemPasswordUnmatch uint32 + UserIsOffline uint32 + AlreadyParticipatedGathering uint32 + PermissionDenied uint32 + NotFriend uint32 + SessionClosed uint32 + DatabaseTemporarilyUnavailable uint32 + InvalidUniqueID uint32 + MatchmakingWithdrawn uint32 + LimitExceeded uint32 + AccountTemporarilyDisabled uint32 + PartiallyServiceClosed uint32 + ConnectionDisconnectedForConcurrentLogin uint32 + } + + PythonCore struct { + Exception uint32 + TypeError uint32 + IndexError uint32 + InvalidReference uint32 + CallFailure uint32 + MemoryError uint32 + KeyError uint32 + OperationError uint32 + ConversionError uint32 + ValidationError uint32 + } + + Transport struct { + Unknown uint32 + ConnectionFailure uint32 + InvalidURL uint32 + InvalidKey uint32 + InvalidURLType uint32 + DuplicateEndpoint uint32 + IOError uint32 + Timeout uint32 + ConnectionReset uint32 + IncorrectRemoteAuthentication uint32 + ServerRequestError uint32 + DecompressionFailure uint32 + ReliableSendBufferFullFatal uint32 + UPnPCannotInit uint32 + UPnPCannotAddMapping uint32 + NatPMPCannotInit uint32 + NatPMPCannotAddMapping uint32 + UnsupportedNAT uint32 + DNSError uint32 + ProxyError uint32 + DataRemaining uint32 + NoBuffer uint32 + NotFound uint32 + TemporaryServerError uint32 + PermanentServerError uint32 + ServiceUnavailable uint32 + ReliableSendBufferFull uint32 + InvalidStation uint32 + InvalidSubStreamID uint32 + PacketBufferFull uint32 + NatTraversalError uint32 + NatCheckError uint32 + } + + DOCore struct { + StationNotReached uint32 + TargetStationDisconnect uint32 + LocalStationLeaving uint32 + ObjectNotFound uint32 + InvalidRole uint32 + CallTimeout uint32 + RMCDispatchFailed uint32 + MigrationInProgress uint32 + NoAuthority uint32 + NoTargetStationSpecified uint32 + JoinFailed uint32 + JoinDenied uint32 + ConnectivityTestFailed uint32 + Unknown uint32 + UnfreedReferences uint32 + JobTerminationFailed uint32 + InvalidState uint32 + FaultRecoveryFatal uint32 + FaultRecoveryJobProcessFailed uint32 + StationInconsitency uint32 + AbnormalMasterState uint32 + VersionMismatch uint32 + } + + FPD struct { + NotInitialized uint32 + AlreadyInitialized uint32 + NotConnected uint32 + Connected uint32 + InitializationFailure uint32 + OutOfMemory uint32 + RmcFailed uint32 + InvalidArgument uint32 + InvalidLocalAccountID uint32 + InvalidPrincipalID uint32 + InvalidLocalFriendCode uint32 + LocalAccountNotExists uint32 + LocalAccountNotLoaded uint32 + LocalAccountAlreadyLoaded uint32 + FriendAlreadyExists uint32 + FriendNotExists uint32 + FriendNumMax uint32 + NotFriend uint32 + FileIO uint32 + P2PInternetProhibited uint32 + Unknown uint32 + InvalidState uint32 + AddFriendProhibited uint32 + InvalidAccount uint32 + BlacklistedByMe uint32 + FriendAlreadyAdded uint32 + MyFriendListLimitExceed uint32 + RequestLimitExceed uint32 + InvalidMessageID uint32 + MessageIsNotMine uint32 + MessageIsNotForMe uint32 + FriendRequestBlocked uint32 + NotInMyFriendList uint32 + FriendListedByMe uint32 + NotInMyBlacklist uint32 + IncompatibleAccount uint32 + BlockSettingChangeNotAllowed uint32 + SizeLimitExceeded uint32 + OperationNotAllowed uint32 + NotNetworkAccount uint32 + NotificationNotFound uint32 + PreferenceNotInitialized uint32 + FriendRequestNotAllowed uint32 + } + + Ranking struct { + NotInitialized uint32 + InvalidArgument uint32 + RegistrationError uint32 + NotFound uint32 + InvalidScore uint32 + InvalidDataSize uint32 + PermissionDenied uint32 + Unknown uint32 + NotImplemented uint32 + } + + Authentication struct { + NASAuthenticateError uint32 + TokenParseError uint32 + HTTPConnectionError uint32 + HTTPDNSError uint32 + HTTPGetProxySetting uint32 + TokenExpired uint32 + ValidationFailed uint32 + InvalidParam uint32 + PrincipalIDUnmatched uint32 + MoveCountUnmatch uint32 + UnderMaintenance uint32 + UnsupportedVersion uint32 + ServerVersionIsOld uint32 + Unknown uint32 + ClientVersionIsOld uint32 + AccountLibraryError uint32 + ServiceNoLongerAvailable uint32 + UnknownApplication uint32 + ApplicationVersionIsOld uint32 + OutOfService uint32 + NetworkServiceLicenseRequired uint32 + NetworkServiceLicenseSystemError uint32 + NetworkServiceLicenseError3 uint32 + NetworkServiceLicenseError4 uint32 + } + + DataStore struct { + Unknown uint32 + InvalidArgument uint32 + PermissionDenied uint32 + NotFound uint32 + AlreadyLocked uint32 + UnderReviewing uint32 + Expired uint32 + InvalidCheckToken uint32 + SystemFileError uint32 + OverCapacity uint32 + OperationNotAllowed uint32 + InvalidPassword uint32 + ValueNotEqual uint32 + } + + ServiceItem struct { + Unknown uint32 + InvalidArgument uint32 + EShopUnknownHTTPError uint32 + EShopResponseParseError uint32 + NotOwned uint32 + InvalidLimitationType uint32 + ConsumptionRightShortage uint32 + } + + MatchmakeReferee struct { + Unknown uint32 + InvalidArgument uint32 + AlreadyExists uint32 + NotParticipatedGathering uint32 + NotParticipatedRound uint32 + StatsNotFound uint32 + RoundNotFound uint32 + RoundArbitrated uint32 + RoundNotArbitrated uint32 + } + + Subscriber struct { + Unknown uint32 + InvalidArgument uint32 + OverLimit uint32 + PermissionDenied uint32 + } + + Ranking2 struct { + Unknown uint32 + InvalidArgument uint32 + InvalidScore uint32 + } + + SmartDeviceVoiceChat struct { + Unknown uint32 + InvalidArgument uint32 + InvalidResponse uint32 + InvalidAccessToken uint32 + Unauthorized uint32 + AccessError uint32 + UserNotFound uint32 + RoomNotFound uint32 + RoomNotActivated uint32 + ApplicationNotSupported uint32 + InternalServerError uint32 + ServiceUnavailable uint32 + UnexpectedError uint32 + UnderMaintenance uint32 + ServiceNoLongerAvailable uint32 + AccountTemporarilyDisabled uint32 + PermissionDenied uint32 + NetworkServiceLicenseRequired uint32 + AccountLibraryError uint32 + GameModeNotFound uint32 + } + + Screening struct { + Unknown uint32 + InvalidArgument uint32 + NotFound uint32 + } + + Custom struct { + Unknown uint32 + } + + Ess struct { + Unknown uint32 + GameSessionError uint32 + GameSessionMaintenance uint32 + } +} + +// ResultNames contains a map of all the result code string names, indexed by the result code +var ResultNames = map[uint32]string{} + +// ResultCodes provides a struct containing RDV result codes using dot-notation +var ResultCodes resultCodes + +func initResultCodes() { + ResultCodes.Core.Unknown = 0x00010001 + ResultCodes.Core.NotImplemented = 0x00010002 + ResultCodes.Core.InvalidPointer = 0x00010003 + ResultCodes.Core.OperationAborted = 0x00010004 + ResultCodes.Core.Exception = 0x00010005 + ResultCodes.Core.AccessDenied = 0x00010006 + ResultCodes.Core.InvalidHandle = 0x00010007 + ResultCodes.Core.InvalidIndex = 0x00010008 + ResultCodes.Core.OutOfMemory = 0x00010009 + ResultCodes.Core.InvalidArgument = 0x0001000A + ResultCodes.Core.Timeout = 0x0001000B + ResultCodes.Core.InitializationFailure = 0x0001000C + ResultCodes.Core.CallInitiationFailure = 0x0001000D + ResultCodes.Core.RegistrationError = 0x0001000E + ResultCodes.Core.BufferOverflow = 0x0001000F + ResultCodes.Core.InvalidLockState = 0x00010010 + ResultCodes.Core.InvalidSequence = 0x00010011 + ResultCodes.Core.SystemError = 0x00010012 + ResultCodes.Core.Cancelled = 0x00010013 + + ResultCodes.DDL.InvalidSignature = 0x00020001 + ResultCodes.DDL.IncorrectVersion = 0x00020002 + + ResultCodes.RendezVous.ConnectionFailure = 0x00030001 + ResultCodes.RendezVous.NotAuthenticated = 0x00030002 + ResultCodes.RendezVous.InvalidUsername = 0x00030064 + ResultCodes.RendezVous.InvalidPassword = 0x00030065 + ResultCodes.RendezVous.UsernameAlreadyExists = 0x00030066 + ResultCodes.RendezVous.AccountDisabled = 0x00030067 + ResultCodes.RendezVous.AccountExpired = 0x00030068 + ResultCodes.RendezVous.ConcurrentLoginDenied = 0x00030069 + ResultCodes.RendezVous.EncryptionFailure = 0x0003006A + ResultCodes.RendezVous.InvalidPID = 0x0003006B + ResultCodes.RendezVous.MaxConnectionsReached = 0x0003006C + ResultCodes.RendezVous.InvalidGID = 0x0003006D + ResultCodes.RendezVous.InvalidControlScriptID = 0x0003006E + ResultCodes.RendezVous.InvalidOperationInLiveEnvironment = 0x0003006F + ResultCodes.RendezVous.DuplicateEntry = 0x00030070 + ResultCodes.RendezVous.ControlScriptFailure = 0x00030071 + ResultCodes.RendezVous.ClassNotFound = 0x00030072 + ResultCodes.RendezVous.SessionVoid = 0x00030073 + ResultCodes.RendezVous.DDLMismatch = 0x00030075 + ResultCodes.RendezVous.InvalidConfiguration = 0x00030076 + ResultCodes.RendezVous.SessionFull = 0x000300C8 + ResultCodes.RendezVous.InvalidGatheringPassword = 0x000300C9 + ResultCodes.RendezVous.WithoutParticipationPeriod = 0x000300CA + ResultCodes.RendezVous.PersistentGatheringCreationMax = 0x000300CB + ResultCodes.RendezVous.PersistentGatheringParticipationMax = 0x000300CC + ResultCodes.RendezVous.DeniedByParticipants = 0x000300CD + ResultCodes.RendezVous.ParticipantInBlackList = 0x000300CE + ResultCodes.RendezVous.GameServerMaintenance = 0x000300CF + ResultCodes.RendezVous.OperationPostpone = 0x000300D0 + ResultCodes.RendezVous.OutOfRatingRange = 0x000300D1 + ResultCodes.RendezVous.ConnectionDisconnected = 0x000300D2 + ResultCodes.RendezVous.InvalidOperation = 0x000300D3 + ResultCodes.RendezVous.NotParticipatedGathering = 0x000300D4 + ResultCodes.RendezVous.MatchmakeSessionUserPasswordUnmatch = 0x000300D5 + ResultCodes.RendezVous.MatchmakeSessionSystemPasswordUnmatch = 0x000300D6 + ResultCodes.RendezVous.UserIsOffline = 0x000300D7 + ResultCodes.RendezVous.AlreadyParticipatedGathering = 0x000300D8 + ResultCodes.RendezVous.PermissionDenied = 0x000300D9 + ResultCodes.RendezVous.NotFriend = 0x000300DA + ResultCodes.RendezVous.SessionClosed = 0x000300DB + ResultCodes.RendezVous.DatabaseTemporarilyUnavailable = 0x000300DC + ResultCodes.RendezVous.InvalidUniqueID = 0x000300DD + ResultCodes.RendezVous.MatchmakingWithdrawn = 0x000300DE + ResultCodes.RendezVous.LimitExceeded = 0x000300DF + ResultCodes.RendezVous.AccountTemporarilyDisabled = 0x000300E0 + ResultCodes.RendezVous.PartiallyServiceClosed = 0x000300E1 + ResultCodes.RendezVous.ConnectionDisconnectedForConcurrentLogin = 0x000300E2 + + ResultCodes.PythonCore.Exception = 0x00040001 + ResultCodes.PythonCore.TypeError = 0x00040002 + ResultCodes.PythonCore.IndexError = 0x00040003 + ResultCodes.PythonCore.InvalidReference = 0x00040004 + ResultCodes.PythonCore.CallFailure = 0x00040005 + ResultCodes.PythonCore.MemoryError = 0x00040006 + ResultCodes.PythonCore.KeyError = 0x00040007 + ResultCodes.PythonCore.OperationError = 0x00040008 + ResultCodes.PythonCore.ConversionError = 0x00040009 + ResultCodes.PythonCore.ValidationError = 0x0004000A + + ResultCodes.Transport.Unknown = 0x00050001 + ResultCodes.Transport.ConnectionFailure = 0x00050002 + ResultCodes.Transport.InvalidURL = 0x00050003 + ResultCodes.Transport.InvalidKey = 0x00050004 + ResultCodes.Transport.InvalidURLType = 0x00050005 + ResultCodes.Transport.DuplicateEndpoint = 0x00050006 + ResultCodes.Transport.IOError = 0x00050007 + ResultCodes.Transport.Timeout = 0x00050008 + ResultCodes.Transport.ConnectionReset = 0x00050009 + ResultCodes.Transport.IncorrectRemoteAuthentication = 0x0005000A + ResultCodes.Transport.ServerRequestError = 0x0005000B + ResultCodes.Transport.DecompressionFailure = 0x0005000C + ResultCodes.Transport.ReliableSendBufferFullFatal = 0x0005000D + ResultCodes.Transport.UPnPCannotInit = 0x0005000E + ResultCodes.Transport.UPnPCannotAddMapping = 0x0005000F + ResultCodes.Transport.NatPMPCannotInit = 0x00050010 + ResultCodes.Transport.NatPMPCannotAddMapping = 0x00050011 + ResultCodes.Transport.UnsupportedNAT = 0x00050013 + ResultCodes.Transport.DNSError = 0x00050014 + ResultCodes.Transport.ProxyError = 0x00050015 + ResultCodes.Transport.DataRemaining = 0x00050016 + ResultCodes.Transport.NoBuffer = 0x00050017 + ResultCodes.Transport.NotFound = 0x00050018 + ResultCodes.Transport.TemporaryServerError = 0x00050019 + ResultCodes.Transport.PermanentServerError = 0x0005001A + ResultCodes.Transport.ServiceUnavailable = 0x0005001B + ResultCodes.Transport.ReliableSendBufferFull = 0x0005001C + ResultCodes.Transport.InvalidStation = 0x0005001D + ResultCodes.Transport.InvalidSubStreamID = 0x0005001E + ResultCodes.Transport.PacketBufferFull = 0x0005001F + ResultCodes.Transport.NatTraversalError = 0x00050020 + ResultCodes.Transport.NatCheckError = 0x00050021 + + ResultCodes.DOCore.StationNotReached = 0x00060001 + ResultCodes.DOCore.TargetStationDisconnect = 0x00060002 + ResultCodes.DOCore.LocalStationLeaving = 0x00060003 + ResultCodes.DOCore.ObjectNotFound = 0x00060004 + ResultCodes.DOCore.InvalidRole = 0x00060005 + ResultCodes.DOCore.CallTimeout = 0x00060006 + ResultCodes.DOCore.RMCDispatchFailed = 0x00060007 + ResultCodes.DOCore.MigrationInProgress = 0x00060008 + ResultCodes.DOCore.NoAuthority = 0x00060009 + ResultCodes.DOCore.NoTargetStationSpecified = 0x0006000A + ResultCodes.DOCore.JoinFailed = 0x0006000B + ResultCodes.DOCore.JoinDenied = 0x0006000C + ResultCodes.DOCore.ConnectivityTestFailed = 0x0006000D + ResultCodes.DOCore.Unknown = 0x0006000E + ResultCodes.DOCore.UnfreedReferences = 0x0006000F + ResultCodes.DOCore.JobTerminationFailed = 0x00060010 + ResultCodes.DOCore.InvalidState = 0x00060011 + ResultCodes.DOCore.FaultRecoveryFatal = 0x00060012 + ResultCodes.DOCore.FaultRecoveryJobProcessFailed = 0x00060013 + ResultCodes.DOCore.StationInconsitency = 0x00060014 + ResultCodes.DOCore.AbnormalMasterState = 0x00060015 + ResultCodes.DOCore.VersionMismatch = 0x00060016 + + ResultCodes.FPD.NotInitialized = 0x00650000 + ResultCodes.FPD.AlreadyInitialized = 0x00650001 + ResultCodes.FPD.NotConnected = 0x00650002 + ResultCodes.FPD.Connected = 0x00650003 + ResultCodes.FPD.InitializationFailure = 0x00650004 + ResultCodes.FPD.OutOfMemory = 0x00650005 + ResultCodes.FPD.RmcFailed = 0x00650006 + ResultCodes.FPD.InvalidArgument = 0x00650007 + ResultCodes.FPD.InvalidLocalAccountID = 0x00650008 + ResultCodes.FPD.InvalidPrincipalID = 0x00650009 + ResultCodes.FPD.InvalidLocalFriendCode = 0x0065000A + ResultCodes.FPD.LocalAccountNotExists = 0x0065000B + ResultCodes.FPD.LocalAccountNotLoaded = 0x0065000C + ResultCodes.FPD.LocalAccountAlreadyLoaded = 0x0065000D + ResultCodes.FPD.FriendAlreadyExists = 0x0065000E + ResultCodes.FPD.FriendNotExists = 0x0065000F + ResultCodes.FPD.FriendNumMax = 0x00650010 + ResultCodes.FPD.NotFriend = 0x00650011 + ResultCodes.FPD.FileIO = 0x00650012 + ResultCodes.FPD.P2PInternetProhibited = 0x00650013 + ResultCodes.FPD.Unknown = 0x00650014 + ResultCodes.FPD.InvalidState = 0x00650015 + ResultCodes.FPD.AddFriendProhibited = 0x00650017 + ResultCodes.FPD.InvalidAccount = 0x00650019 + ResultCodes.FPD.BlacklistedByMe = 0x0065001A + ResultCodes.FPD.FriendAlreadyAdded = 0x0065001C + ResultCodes.FPD.MyFriendListLimitExceed = 0x0065001D + ResultCodes.FPD.RequestLimitExceed = 0x0065001E + ResultCodes.FPD.InvalidMessageID = 0x0065001F + ResultCodes.FPD.MessageIsNotMine = 0x00650020 + ResultCodes.FPD.MessageIsNotForMe = 0x00650021 + ResultCodes.FPD.FriendRequestBlocked = 0x00650022 + ResultCodes.FPD.NotInMyFriendList = 0x00650023 + ResultCodes.FPD.FriendListedByMe = 0x00650024 + ResultCodes.FPD.NotInMyBlacklist = 0x00650025 + ResultCodes.FPD.IncompatibleAccount = 0x00650026 + ResultCodes.FPD.BlockSettingChangeNotAllowed = 0x00650027 + ResultCodes.FPD.SizeLimitExceeded = 0x00650028 + ResultCodes.FPD.OperationNotAllowed = 0x00650029 + ResultCodes.FPD.NotNetworkAccount = 0x0065002A + ResultCodes.FPD.NotificationNotFound = 0x0065002B + ResultCodes.FPD.PreferenceNotInitialized = 0x0065002C + ResultCodes.FPD.FriendRequestNotAllowed = 0x0065002D + + ResultCodes.Ranking.NotInitialized = 0x00670001 + ResultCodes.Ranking.InvalidArgument = 0x00670002 + ResultCodes.Ranking.RegistrationError = 0x00670003 + ResultCodes.Ranking.NotFound = 0x00670005 + ResultCodes.Ranking.InvalidScore = 0x00670006 + ResultCodes.Ranking.InvalidDataSize = 0x00670007 + ResultCodes.Ranking.PermissionDenied = 0x00670009 + ResultCodes.Ranking.Unknown = 0x0067000A + ResultCodes.Ranking.NotImplemented = 0x0067000B + + ResultCodes.Authentication.NASAuthenticateError = 0x00680001 + ResultCodes.Authentication.TokenParseError = 0x00680002 + ResultCodes.Authentication.HTTPConnectionError = 0x00680003 + ResultCodes.Authentication.HTTPDNSError = 0x00680004 + ResultCodes.Authentication.HTTPGetProxySetting = 0x00680005 + ResultCodes.Authentication.TokenExpired = 0x00680006 + ResultCodes.Authentication.ValidationFailed = 0x00680007 + ResultCodes.Authentication.InvalidParam = 0x00680008 + ResultCodes.Authentication.PrincipalIDUnmatched = 0x00680009 + ResultCodes.Authentication.MoveCountUnmatch = 0x0068000A + ResultCodes.Authentication.UnderMaintenance = 0x0068000B + ResultCodes.Authentication.UnsupportedVersion = 0x0068000C + ResultCodes.Authentication.ServerVersionIsOld = 0x0068000D + ResultCodes.Authentication.Unknown = 0x0068000E + ResultCodes.Authentication.ClientVersionIsOld = 0x0068000F + ResultCodes.Authentication.AccountLibraryError = 0x00680010 + ResultCodes.Authentication.ServiceNoLongerAvailable = 0x00680011 + ResultCodes.Authentication.UnknownApplication = 0x00680012 + ResultCodes.Authentication.ApplicationVersionIsOld = 0x00680013 + ResultCodes.Authentication.OutOfService = 0x00680014 + ResultCodes.Authentication.NetworkServiceLicenseRequired = 0x00680015 + ResultCodes.Authentication.NetworkServiceLicenseSystemError = 0x00680016 + ResultCodes.Authentication.NetworkServiceLicenseError3 = 0x00680017 + ResultCodes.Authentication.NetworkServiceLicenseError4 = 0x00680018 + + ResultCodes.DataStore.Unknown = 0x00690001 + ResultCodes.DataStore.InvalidArgument = 0x00690002 + ResultCodes.DataStore.PermissionDenied = 0x00690003 + ResultCodes.DataStore.NotFound = 0x00690004 + ResultCodes.DataStore.AlreadyLocked = 0x00690005 + ResultCodes.DataStore.UnderReviewing = 0x00690006 + ResultCodes.DataStore.Expired = 0x00690007 + ResultCodes.DataStore.InvalidCheckToken = 0x00690008 + ResultCodes.DataStore.SystemFileError = 0x00690009 + ResultCodes.DataStore.OverCapacity = 0x0069000A + ResultCodes.DataStore.OperationNotAllowed = 0x0069000B + ResultCodes.DataStore.InvalidPassword = 0x0069000C + ResultCodes.DataStore.ValueNotEqual = 0x0069000D + + ResultCodes.ServiceItem.Unknown = 0x006C0001 + ResultCodes.ServiceItem.InvalidArgument = 0x006C0002 + ResultCodes.ServiceItem.EShopUnknownHTTPError = 0x006C0003 + ResultCodes.ServiceItem.EShopResponseParseError = 0x006C0004 + ResultCodes.ServiceItem.NotOwned = 0x006C0005 + ResultCodes.ServiceItem.InvalidLimitationType = 0x006C0006 + ResultCodes.ServiceItem.ConsumptionRightShortage = 0x006C0007 + + ResultCodes.MatchmakeReferee.Unknown = 0x006F0001 + ResultCodes.MatchmakeReferee.InvalidArgument = 0x006F0002 + ResultCodes.MatchmakeReferee.AlreadyExists = 0x006F0003 + ResultCodes.MatchmakeReferee.NotParticipatedGathering = 0x006F0004 + ResultCodes.MatchmakeReferee.NotParticipatedRound = 0x006F0005 + ResultCodes.MatchmakeReferee.StatsNotFound = 0x006F0006 + ResultCodes.MatchmakeReferee.RoundNotFound = 0x006F0007 + ResultCodes.MatchmakeReferee.RoundArbitrated = 0x006F0008 + ResultCodes.MatchmakeReferee.RoundNotArbitrated = 0x006F0009 + + ResultCodes.Subscriber.Unknown = 0x00700001 + ResultCodes.Subscriber.InvalidArgument = 0x00700002 + ResultCodes.Subscriber.OverLimit = 0x00700003 + ResultCodes.Subscriber.PermissionDenied = 0x00700004 + + ResultCodes.Ranking2.Unknown = 0x00710001 + ResultCodes.Ranking2.InvalidArgument = 0x00710002 + ResultCodes.Ranking2.InvalidScore = 0x00710003 + + ResultCodes.SmartDeviceVoiceChat.Unknown = 0x00720001 + ResultCodes.SmartDeviceVoiceChat.InvalidArgument = 0x00720002 + ResultCodes.SmartDeviceVoiceChat.InvalidResponse = 0x00720003 + ResultCodes.SmartDeviceVoiceChat.InvalidAccessToken = 0x00720004 + ResultCodes.SmartDeviceVoiceChat.Unauthorized = 0x00720005 + ResultCodes.SmartDeviceVoiceChat.AccessError = 0x00720006 + ResultCodes.SmartDeviceVoiceChat.UserNotFound = 0x00720007 + ResultCodes.SmartDeviceVoiceChat.RoomNotFound = 0x00720008 + ResultCodes.SmartDeviceVoiceChat.RoomNotActivated = 0x00720009 + ResultCodes.SmartDeviceVoiceChat.ApplicationNotSupported = 0x0072000A + ResultCodes.SmartDeviceVoiceChat.InternalServerError = 0x0072000B + ResultCodes.SmartDeviceVoiceChat.ServiceUnavailable = 0x0072000C + ResultCodes.SmartDeviceVoiceChat.UnexpectedError = 0x0072000D + ResultCodes.SmartDeviceVoiceChat.UnderMaintenance = 0x0072000E + ResultCodes.SmartDeviceVoiceChat.ServiceNoLongerAvailable = 0x0072000F + ResultCodes.SmartDeviceVoiceChat.AccountTemporarilyDisabled = 0x00720010 + ResultCodes.SmartDeviceVoiceChat.PermissionDenied = 0x00720011 + ResultCodes.SmartDeviceVoiceChat.NetworkServiceLicenseRequired = 0x00720012 + ResultCodes.SmartDeviceVoiceChat.AccountLibraryError = 0x00720013 + ResultCodes.SmartDeviceVoiceChat.GameModeNotFound = 0x00720014 + + ResultCodes.Screening.Unknown = 0x00730001 + ResultCodes.Screening.InvalidArgument = 0x00730002 + ResultCodes.Screening.NotFound = 0x00730003 + + ResultCodes.Custom.Unknown = 0x00740001 + + ResultCodes.Ess.Unknown = 0x00750001 + ResultCodes.Ess.GameSessionError = 0x00750002 + ResultCodes.Ess.GameSessionMaintenance = 0x00750003 + + valueOfResultCodes := reflect.ValueOf(ResultCodes) + typeOfResultCodes := valueOfResultCodes.Type() + + for i := 0; i < valueOfResultCodes.NumField(); i++ { + category := typeOfResultCodes.Field(i).Name + + valueOfCategory := reflect.ValueOf(valueOfResultCodes.Field(i).Interface()) + typeOfCategory := valueOfCategory.Type() + + for j := 0; j < valueOfCategory.NumField(); j++ { + name := typeOfCategory.Field(j).Name + resultCode := valueOfCategory.Field(j).Interface().(uint32) + + ResultNames[resultCode] = category + "::" + name + } + } +} + +// ResultCodeToName returns an error code string for the provided error code +func ResultCodeToName(resultCode uint32) string { + name := ResultNames[resultCode] + + if name == "" { + return "Invalid Result Code: " + strconv.Itoa(int(resultCode)) + } + + return name +} diff --git a/test/main.go b/test/main.go index de895830..be6c9177 100644 --- a/test/main.go +++ b/test/main.go @@ -26,7 +26,7 @@ func accountDetailsByPID(pid *types.PID) (*nex.Account, uint32) { return testUserAccount, 0 } - return nil, nex.Errors.RendezVous.InvalidPID + return nil, nex.ResultCodes.RendezVous.InvalidPID } func accountDetailsByUsername(username string) (*nex.Account, uint32) { @@ -42,7 +42,7 @@ func accountDetailsByUsername(username string) (*nex.Account, uint32) { return testUserAccount, 0 } - return nil, nex.Errors.RendezVous.InvalidUsername + return nil, nex.ResultCodes.RendezVous.InvalidUsername } func main() { From dc4e7ae733cf5ed1887b186f63fc365e2bc830a8 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Wed, 24 Jan 2024 14:27:17 -0500 Subject: [PATCH 133/178] update: added new Error type which conforms to the error interface --- error.go | 36 ++++++++++++++++++++++++++++++++++++ hpp_server.go | 4 ++-- prudp_endpoint.go | 4 ++-- test/main.go | 20 ++++++++++---------- 4 files changed, 50 insertions(+), 14 deletions(-) create mode 100644 error.go diff --git a/error.go b/error.go new file mode 100644 index 00000000..62467f3a --- /dev/null +++ b/error.go @@ -0,0 +1,36 @@ +package nex + +import "fmt" + +// TODO - Add more metadata? Like the sender or whatever? + +// Error is a custom error type implementing the error interface +type Error struct { + ResultCode uint32 + Message string +} + +// Error satisfies the error interface and prints the underlying error +func (e Error) Error() string { + resultCode := e.ResultCode + + if int(resultCode)&errorMask != 0 { + // * Result codes are stored without the MSB set + resultCode = resultCode & ^uint32(errorMask) + } + + return fmt.Sprintf("[%s] %s", ResultCodeToName(resultCode), e.Message) +} + +// NewError returns a new NEX error with a RDV result code +func NewError(resultCode uint32, message string) *Error { + if int(resultCode)&errorMask == 0 { + // * Set the MSB to mark the result as an error + resultCode = uint32(int(resultCode) | errorMask) + } + + return &Error{ + ResultCode: resultCode, + Message: message, + } +} diff --git a/hpp_server.go b/hpp_server.go index 6e4f6e63..16f6bbe1 100644 --- a/hpp_server.go +++ b/hpp_server.go @@ -24,8 +24,8 @@ type HPPServer struct { natTraversalProtocolVersion *LibraryVersion dataHandlers []func(packet PacketInterface) byteStreamSettings *ByteStreamSettings - AccountDetailsByPID func(pid *types.PID) (*Account, uint32) - AccountDetailsByUsername func(username string) (*Account, uint32) + AccountDetailsByPID func(pid *types.PID) (*Account, *Error) + AccountDetailsByUsername func(username string) (*Account, *Error) } // OnData adds an event handler which is fired when a new HPP request is received diff --git a/prudp_endpoint.go b/prudp_endpoint.go index 359f0720..d636fdce 100644 --- a/prudp_endpoint.go +++ b/prudp_endpoint.go @@ -22,8 +22,8 @@ type PRUDPEndPoint struct { connectionEndedEventHandlers []func(connection *PRUDPConnection) ConnectionIDCounter *Counter[uint32] ServerAccount *Account - AccountDetailsByPID func(pid *types.PID) (*Account, uint32) - AccountDetailsByUsername func(username string) (*Account, uint32) + AccountDetailsByPID func(pid *types.PID) (*Account, *Error) + AccountDetailsByUsername func(username string) (*Account, *Error) } // OnData adds an event handler which is fired when a new DATA packet is received diff --git a/test/main.go b/test/main.go index be6c9177..8c927b9d 100644 --- a/test/main.go +++ b/test/main.go @@ -13,36 +13,36 @@ var authenticationServerAccount *nex.Account var secureServerAccount *nex.Account var testUserAccount *nex.Account -func accountDetailsByPID(pid *types.PID) (*nex.Account, uint32) { +func accountDetailsByPID(pid *types.PID) (*nex.Account, *nex.Error) { if pid.Equals(authenticationServerAccount.PID) { - return authenticationServerAccount, 0 + return authenticationServerAccount, nil } if pid.Equals(secureServerAccount.PID) { - return secureServerAccount, 0 + return secureServerAccount, nil } if pid.Equals(testUserAccount.PID) { - return testUserAccount, 0 + return testUserAccount, nil } - return nil, nex.ResultCodes.RendezVous.InvalidPID + return nil, nex.NewError(nex.ResultCodes.RendezVous.InvalidPID, "Invalid PID") } -func accountDetailsByUsername(username string) (*nex.Account, uint32) { +func accountDetailsByUsername(username string) (*nex.Account, *nex.Error) { if username == authenticationServerAccount.Username { - return authenticationServerAccount, 0 + return authenticationServerAccount, nil } if username == secureServerAccount.Username { - return secureServerAccount, 0 + return secureServerAccount, nil } if username == testUserAccount.Username { - return testUserAccount, 0 + return testUserAccount, nil } - return nil, nex.ResultCodes.RendezVous.InvalidUsername + return nil, nex.NewError(nex.ResultCodes.RendezVous.InvalidPID, "Invalid username") } func main() { From be784327c7716d7510c03401d4b3124061b7c88c Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Wed, 24 Jan 2024 15:57:48 -0500 Subject: [PATCH 134/178] types: add nil and type ID checks to Variant --- types/variant.go | 33 ++++++++++++++++++++++++++++----- 1 file changed, 28 insertions(+), 5 deletions(-) diff --git a/types/variant.go b/types/variant.go index 8ff0e4eb..783220b5 100644 --- a/types/variant.go +++ b/types/variant.go @@ -23,7 +23,10 @@ type Variant struct { // WriteTo writes the Variant to the given writable func (v *Variant) WriteTo(writable Writable) { v.TypeID.WriteTo(writable) - v.Type.WriteTo(writable) + + if v.Type != nil { + v.Type.WriteTo(writable) + } } // ExtractFrom extracts the Variant from the given readable @@ -33,6 +36,11 @@ func (v *Variant) ExtractFrom(readable Readable) error { return fmt.Errorf("Failed to read Variant type ID. %s", err.Error()) } + // * Type ID of 0 is a "None" type. There is no data + if v.TypeID.Value == 0 { + return nil + } + if _, ok := VariantTypes[v.TypeID.Value]; !ok { return fmt.Errorf("Invalid Variant type ID %d", v.TypeID) } @@ -47,7 +55,10 @@ func (v *Variant) Copy() RVType { copied := NewVariant() copied.TypeID = v.TypeID.Copy().(*PrimitiveU8) - copied.Type = v.Type.Copy() + + if v.Type != nil { + copied.Type = v.Type.Copy() + } return copied } @@ -64,7 +75,11 @@ func (v *Variant) Equals(o RVType) bool { return false } - return v.Type.Equals(other.Type) + if v.Type != nil { + return v.Type.Equals(other.Type) + } + + return true } // String returns a string representation of the struct @@ -80,8 +95,14 @@ func (v *Variant) FormatToString(indentationLevel int) string { var b strings.Builder b.WriteString("Variant{\n") - b.WriteString(fmt.Sprintf("%TypeID: %s,\n", indentationValues, v.TypeID)) - b.WriteString(fmt.Sprintf("%Type: %s\n", indentationValues, v.Type)) + b.WriteString(fmt.Sprintf("%sTypeID: %s,\n", indentationValues, v.TypeID)) + + if v.Type != nil { + b.WriteString(fmt.Sprintf("%sType: %s\n", indentationValues, v.Type)) + } else { + b.WriteString(fmt.Sprintf("%sType: None\n", indentationValues)) + } + b.WriteString(fmt.Sprintf("%s}", indentationEnd)) return b.String() @@ -89,7 +110,9 @@ func (v *Variant) FormatToString(indentationLevel int) string { // NewVariant returns a new Variant func NewVariant() *Variant { + // * Type ID of 0 is a "None" type. There is no data return &Variant{ TypeID: NewPrimitiveU8(0), + Type: nil, } } From 191455352a7c1500ee1883245062c86b1452667c Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Wed, 24 Jan 2024 16:40:38 -0500 Subject: [PATCH 135/178] update: added RegisterServiceProtocol to HPP and PRUDP servers --- hpp_server.go | 5 +++++ prudp_endpoint.go | 8 +++++--- prudp_server.go | 38 +++++++------------------------------- server_interface.go | 1 - service_protocol.go | 6 ++++++ 5 files changed, 23 insertions(+), 35 deletions(-) create mode 100644 service_protocol.go diff --git a/hpp_server.go b/hpp_server.go index 16f6bbe1..b5cb9dd1 100644 --- a/hpp_server.go +++ b/hpp_server.go @@ -28,6 +28,11 @@ type HPPServer struct { AccountDetailsByUsername func(username string) (*Account, *Error) } +// RegisterServiceProtocol registers a NEX service with the HPP server +func (s *HPPServer) RegisterServiceProtocol(protocol ServiceProtocol) { + s.OnData(protocol.HandlePacket) +} + // OnData adds an event handler which is fired when a new HPP request is received func (s *HPPServer) OnData(handler func(packet PacketInterface)) { s.dataHandlers = append(s.dataHandlers, handler) diff --git a/prudp_endpoint.go b/prudp_endpoint.go index d636fdce..b2ef39e5 100644 --- a/prudp_endpoint.go +++ b/prudp_endpoint.go @@ -26,6 +26,11 @@ type PRUDPEndPoint struct { AccountDetailsByUsername func(username string) (*Account, *Error) } +// RegisterServiceProtocol registers a NEX service with the endpoint +func (pep *PRUDPEndPoint) RegisterServiceProtocol(protocol ServiceProtocol) { + pep.OnData(protocol.HandlePacket) +} + // OnData adds an event handler which is fired when a new DATA packet is received func (pep *PRUDPEndPoint) OnData(handler func(packet PacketInterface)) { pep.on("data", handler) @@ -60,9 +65,6 @@ func (pep *PRUDPEndPoint) emit(name string, packet PRUDPPacketInterface) { go handler(packet) } } - - // * propagate the event up to the PRUDP server - pep.Server.emit(name, packet) } func (pep *PRUDPEndPoint) emitConnectionEnded(connection *PRUDPConnection) { diff --git a/prudp_server.go b/prudp_server.go index a4ab0afb..1c942fa8 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -34,7 +34,6 @@ type PRUDPServer struct { PRUDPv1ConnectionSignatureKey []byte byteStreamSettings *ByteStreamSettings PRUDPV0Settings *PRUDPV0Settings - packetEventHandlers map[string][]func(packet PacketInterface) } // BindPRUDPEndPoint binds a provided PRUDPEndPoint to the server @@ -428,38 +427,15 @@ func (ps *PRUDPServer) SetByteStreamSettings(byteStreamSettings *ByteStreamSetti ps.byteStreamSettings = byteStreamSettings } -// OnData adds an event handler which is fired when a new DATA packet is received -func (ps *PRUDPServer) OnData(handler func(packet PacketInterface)) { - ps.on("data", handler) -} - -func (ps *PRUDPServer) on(name string, handler func(packet PacketInterface)) { - if _, ok := ps.packetEventHandlers[name]; !ok { - ps.packetEventHandlers[name] = make([]func(packet PacketInterface), 0) - } - - ps.packetEventHandlers[name] = append(ps.packetEventHandlers[name], handler) -} - -// emit emits an event to all relevant listeners. These events fire after the PRUDPEndPoint event handlers -func (ps *PRUDPServer) emit(name string, packet PRUDPPacketInterface) { - if handlers, ok := ps.packetEventHandlers[name]; ok { - for _, handler := range handlers { - go handler(packet) - } - } -} - // NewPRUDPServer will return a new PRUDP server func NewPRUDPServer() *PRUDPServer { return &PRUDPServer{ - Endpoints: NewMutexMap[uint8, *PRUDPEndPoint](), - Connections: NewMutexMap[string, *SocketConnection](), - SessionKeyLength: 32, - FragmentSize: 1300, - pingTimeout: time.Second * 15, - byteStreamSettings: NewByteStreamSettings(), - PRUDPV0Settings: NewPRUDPV0Settings(), - packetEventHandlers: make(map[string][]func(PacketInterface)), + Endpoints: NewMutexMap[uint8, *PRUDPEndPoint](), + Connections: NewMutexMap[string, *SocketConnection](), + SessionKeyLength: 32, + FragmentSize: 1300, + pingTimeout: time.Second * 15, + byteStreamSettings: NewByteStreamSettings(), + PRUDPV0Settings: NewPRUDPV0Settings(), } } diff --git a/server_interface.go b/server_interface.go index 56c3fad6..94183e6c 100644 --- a/server_interface.go +++ b/server_interface.go @@ -14,7 +14,6 @@ type ServerInterface interface { NATTraversalProtocolVersion() *LibraryVersion SetDefaultLibraryVersion(version *LibraryVersion) Send(packet PacketInterface) - OnData(handler func(packet PacketInterface)) ByteStreamSettings() *ByteStreamSettings SetByteStreamSettings(settings *ByteStreamSettings) } diff --git a/service_protocol.go b/service_protocol.go new file mode 100644 index 00000000..bc979ec6 --- /dev/null +++ b/service_protocol.go @@ -0,0 +1,6 @@ +package nex + +// ServiceProtocol represents a NEX service capable of handling PRUDP/HPP packets +type ServiceProtocol interface { + HandlePacket(packet PacketInterface) +} From 4683f67a04aadb2e044c94d3d99a101f53b025f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20L=C3=B3pez=20Guimaraes?= Date: Thu, 25 Jan 2024 17:35:00 +0000 Subject: [PATCH 136/178] prudp: DefaultstreamSettings -> DefaultStreamSettings --- prudp_endpoint.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/prudp_endpoint.go b/prudp_endpoint.go index b2ef39e5..ae215e0c 100644 --- a/prudp_endpoint.go +++ b/prudp_endpoint.go @@ -16,7 +16,7 @@ import ( type PRUDPEndPoint struct { Server *PRUDPServer StreamID uint8 - DefaultstreamSettings *StreamSettings + DefaultStreamSettings *StreamSettings Connections *MutexMap[string, *PRUDPConnection] packetEventHandlers map[string][]func(packet PacketInterface) connectionEndedEventHandlers []func(connection *PRUDPConnection) @@ -86,7 +86,7 @@ func (pep *PRUDPEndPoint) processPacket(packet PRUDPPacketInterface, socket *Soc connection.DefaultPRUDPVersion = packet.Version() connection.StreamType = streamType connection.StreamID = streamID - connection.StreamSettings = pep.DefaultstreamSettings.Copy() + connection.StreamSettings = pep.DefaultStreamSettings.Copy() connection.startHeartbeat() // * Fail-safe. If the server reboots, then @@ -610,7 +610,7 @@ func (pep *PRUDPEndPoint) FindConnectionByPID(pid uint64) *PRUDPConnection { func NewPRUDPEndPoint(streamID uint8) *PRUDPEndPoint { return &PRUDPEndPoint{ StreamID: streamID, - DefaultstreamSettings: NewStreamSettings(), + DefaultStreamSettings: NewStreamSettings(), Connections: NewMutexMap[string, *PRUDPConnection](), packetEventHandlers: make(map[string][]func(PacketInterface)), connectionEndedEventHandlers: make([]func(connection *PRUDPConnection), 0), From 3b90bdc96b483166fb0b6d23ac272d8e96b59cb6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20L=C3=B3pez=20Guimaraes?= Date: Sat, 10 Feb 2024 23:49:16 +0000 Subject: [PATCH 137/178] refactor: Update ServerInterface to EndpointInterface This requires a lot of refactor but it brings back the option of getting events when a connection disconnects. --- README.md | 13 ++-- byte_stream_in.go | 20 +++--- byte_stream_out.go | 22 ++++--- client_interface.go | 2 +- endpoint_interface.go | 11 ++++ go.mod | 5 +- go.sum | 14 +++-- hpp_client.go | 16 ++--- hpp_packet.go | 6 +- hpp_server.go | 102 +++--------------------------- kerberos.go | 17 ++--- library_version.go | 29 +++++++++ prudp_connection.go | 14 ++--- prudp_endpoint.go | 72 +++++++++++++++------- prudp_packet_lite.go | 21 +++---- prudp_packet_v0.go | 25 ++++---- prudp_packet_v1.go | 25 ++++---- prudp_server.go | 133 +++------------------------------------- resend_scheduler.go | 4 +- rmc_message.go | 22 +++---- server_interface.go | 19 ------ test/auth.go | 33 +++++----- test/generate_ticket.go | 6 +- test/hpp.go | 6 +- test/secure.go | 41 +++++++------ websocket_server.go | 8 +-- 26 files changed, 270 insertions(+), 416 deletions(-) create mode 100644 endpoint_interface.go delete mode 100644 server_interface.go diff --git a/README.md b/README.md index ab6df18f..7e1f81c5 100644 --- a/README.md +++ b/README.md @@ -55,7 +55,8 @@ package main import ( "fmt" - nex "github.com/PretendoNetwork/nex-go" + "github.com/PretendoNetwork/nex-go" + "github.com/PretendoNetwork/nex-go/types" ) func main() { @@ -63,6 +64,9 @@ func main() { authServer := nex.NewPRUDPServer() // The main PRUDP server endpoint := nex.NewPRUDPEndPoint(1) // A PRUDP endpoint for PRUDP connections to connect to. Bound to StreamID 1 + endpoint.ServerAccount = nex.NewAccount(types.NewPID(1), "Quazal Authentication", "password")) + endpoint.AccountDetailsByPID = accountDetailsByPID + endpoint.AccountDetailsByUsername = accountDetailsByUsername // Setup event handlers for the endpoint endpoint.OnData(func(packet nex.PacketInterface) { @@ -86,10 +90,9 @@ func main() { // Bind the endpoint to the server and configure it's settings authServer.BindPRUDPEndPoint(endpoint) authServer.SetFragmentSize(962) - authServer.SetDefaultLibraryVersion(nex.NewLibraryVersion(1, 1, 0)) - authServer.SetKerberosPassword([]byte("password")) - authServer.SetKerberosKeySize(16) - authServer.SetAccessKey("ridfebb9") + authServer.LibraryVersions.SetDefault(nex.NewLibraryVersion(1, 1, 0)) + authServer.SessionKeyLength = 16 + authServer.AccessKey = "ridfebb9" authServer.Listen(60000) } ``` diff --git a/byte_stream_in.go b/byte_stream_in.go index 192f549b..294e696b 100644 --- a/byte_stream_in.go +++ b/byte_stream_in.go @@ -9,15 +9,16 @@ import ( // ByteStreamIn is an input stream abstraction of github.com/superwhiskers/crunch/v3 with nex type support type ByteStreamIn struct { *crunch.Buffer - Server ServerInterface + LibraryVersions *LibraryVersions + Settings *ByteStreamSettings } // StringLengthSize returns the expected size of String length fields func (bsi *ByteStreamIn) StringLengthSize() int { size := 2 - if bsi.Server != nil && bsi.Server.ByteStreamSettings() != nil { - size = bsi.Server.ByteStreamSettings().StringLengthSize + if bsi.Settings != nil { + size = bsi.Settings.StringLengthSize } return size @@ -27,8 +28,8 @@ func (bsi *ByteStreamIn) StringLengthSize() int { func (bsi *ByteStreamIn) PIDSize() int { size := 4 - if bsi.Server != nil && bsi.Server.ByteStreamSettings() != nil { - size = bsi.Server.ByteStreamSettings().PIDSize + if bsi.Settings != nil { + size = bsi.Settings.PIDSize } return size @@ -38,8 +39,8 @@ func (bsi *ByteStreamIn) PIDSize() int { func (bsi *ByteStreamIn) UseStructureHeader() bool { useStructureHeader := false - if bsi.Server != nil && bsi.Server.ByteStreamSettings() != nil { - useStructureHeader = bsi.Server.ByteStreamSettings().UseStructureHeader + if bsi.Settings != nil { + useStructureHeader = bsi.Settings.UseStructureHeader } return useStructureHeader @@ -167,9 +168,10 @@ func (bsi *ByteStreamIn) ReadPrimitiveBool() (bool, error) { } // NewByteStreamIn returns a new NEX input byte stream -func NewByteStreamIn(data []byte, server ServerInterface) *ByteStreamIn { +func NewByteStreamIn(data []byte, libraryVersions *LibraryVersions, settings *ByteStreamSettings) *ByteStreamIn { return &ByteStreamIn{ Buffer: crunch.NewBuffer(data), - Server: server, + LibraryVersions: libraryVersions, + Settings: settings, } } diff --git a/byte_stream_out.go b/byte_stream_out.go index 27b338d8..27f63a28 100644 --- a/byte_stream_out.go +++ b/byte_stream_out.go @@ -8,15 +8,16 @@ import ( // ByteStreamOut is an abstraction of github.com/superwhiskers/crunch with nex type support type ByteStreamOut struct { *crunch.Buffer - Server ServerInterface + LibraryVersions *LibraryVersions + Settings *ByteStreamSettings } // StringLengthSize returns the expected size of String length fields func (bso *ByteStreamOut) StringLengthSize() int { size := 2 - if bso.Server != nil && bso.Server.ByteStreamSettings() != nil { - size = bso.Server.ByteStreamSettings().StringLengthSize + if bso.Settings != nil { + size = bso.Settings.StringLengthSize } return size @@ -26,8 +27,8 @@ func (bso *ByteStreamOut) StringLengthSize() int { func (bso *ByteStreamOut) PIDSize() int { size := 4 - if bso.Server != nil && bso.Server.ByteStreamSettings() != nil { - size = bso.Server.ByteStreamSettings().PIDSize + if bso.Settings != nil { + size = bso.Settings.PIDSize } return size @@ -37,8 +38,8 @@ func (bso *ByteStreamOut) PIDSize() int { func (bso *ByteStreamOut) UseStructureHeader() bool { useStructureHeader := false - if bso.Server != nil && bso.Server.ByteStreamSettings() != nil { - useStructureHeader = bso.Server.ByteStreamSettings().UseStructureHeader + if bso.Settings != nil { + useStructureHeader = bso.Settings.UseStructureHeader } return useStructureHeader @@ -46,7 +47,7 @@ func (bso *ByteStreamOut) UseStructureHeader() bool { // CopyNew returns a copy of the StreamOut but with a blank internal buffer. Returns as types.Writable func (bso *ByteStreamOut) CopyNew() types.Writable { - return NewByteStreamOut(bso.Server) + return NewByteStreamOut(bso.LibraryVersions, bso.Settings) } // Writes the input data to the end of the StreamOut @@ -127,9 +128,10 @@ func (bso *ByteStreamOut) WritePrimitiveBool(b bool) { } // NewByteStreamOut returns a new NEX writable byte stream -func NewByteStreamOut(server ServerInterface) *ByteStreamOut { +func NewByteStreamOut(libraryVersions *LibraryVersions, settings *ByteStreamSettings) *ByteStreamOut { return &ByteStreamOut{ Buffer: crunch.NewBuffer(), - Server: server, + LibraryVersions: libraryVersions, + Settings: settings, } } diff --git a/client_interface.go b/client_interface.go index 7788468a..dfc16556 100644 --- a/client_interface.go +++ b/client_interface.go @@ -9,7 +9,7 @@ import ( // ClientInterface defines all the methods a client should have regardless of server type type ClientInterface interface { - Server() ServerInterface + Endpoint() EndpointInterface Address() net.Addr PID() *types.PID SetPID(pid *types.PID) diff --git a/endpoint_interface.go b/endpoint_interface.go new file mode 100644 index 00000000..b8b2dfee --- /dev/null +++ b/endpoint_interface.go @@ -0,0 +1,11 @@ +package nex + +// EndpointInterface defines all the methods an endpoint should have regardless of type +type EndpointInterface interface { + AccessKey() string + SetAccessKey(accessKey string) + Send(packet PacketInterface) + LibraryVersions() *LibraryVersions + ByteStreamSettings() *ByteStreamSettings + SetByteStreamSettings(settings *ByteStreamSettings) +} diff --git a/go.mod b/go.mod index cca3b394..fa87d6c8 100644 --- a/go.mod +++ b/go.mod @@ -5,16 +5,17 @@ go 1.21 require ( github.com/PretendoNetwork/plogger-go v1.0.4 github.com/cyberdelia/lzo v1.0.0 - github.com/lxzan/gws v1.7.0 + github.com/lxzan/gws v1.8.0 github.com/superwhiskers/crunch/v3 v3.5.7 golang.org/x/exp v0.0.0-20230905200255-921286631fa9 golang.org/x/mod v0.12.0 ) require ( + github.com/dolthub/maphash v0.1.0 // indirect github.com/fatih/color v1.15.0 // indirect github.com/jwalton/go-supportscolor v1.2.0 // indirect - github.com/klauspost/compress v1.16.5 // indirect + github.com/klauspost/compress v1.17.5 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.19 // indirect golang.org/x/sys v0.15.0 // indirect diff --git a/go.sum b/go.sum index 6d57e1bc..8af4fe94 100644 --- a/go.sum +++ b/go.sum @@ -4,16 +4,18 @@ github.com/cyberdelia/lzo v1.0.0 h1:smmvcahczwI/VWSzZ7iikt50lubari5py3qL4hAEHII= github.com/cyberdelia/lzo v1.0.0/go.mod h1:UVNk6eM6Sozt1wx17TECJKuqmIY58TJOVeJxjlGGAGs= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dolthub/maphash v0.1.0 h1:bsQ7JsF4FkkWyrP3oCnFJgrCUAFbFf3kOl4L/QxPDyQ= +github.com/dolthub/maphash v0.1.0/go.mod h1:gkg4Ch4CdCDu5h6PMriVLawB7koZ+5ijb9puGMV50a4= github.com/fatih/color v1.15.0 h1:kOqh6YHBtK8aywxGerMG2Eq3H6Qgoqeo13Bk2Mv/nBs= github.com/fatih/color v1.15.0/go.mod h1:0h5ZqXfHYED7Bhv2ZJamyIOUej9KtShiJESRwBDUSsw= github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/jwalton/go-supportscolor v1.2.0 h1:g6Ha4u7Vm3LIsQ5wmeBpS4gazu0UP1DRDE8y6bre4H8= github.com/jwalton/go-supportscolor v1.2.0/go.mod h1:hFVUAZV2cWg+WFFC4v8pT2X/S2qUUBYMioBD9AINXGs= -github.com/klauspost/compress v1.16.5 h1:IFV2oUNUzZaz+XyusxpLzpzS8Pt5rh0Z16For/djlyI= -github.com/klauspost/compress v1.16.5/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= -github.com/lxzan/gws v1.7.0 h1:/yy5/+3eccMy61/scXM57fTDvucN/t7/0t5wLTwL+qY= -github.com/lxzan/gws v1.7.0/go.mod h1:dsC6S7kJNh+iWqqu2HiO8tnNCji04HwyJCYfTOS+6iY= +github.com/klauspost/compress v1.17.5 h1:d4vBd+7CHydUqpFBgUEKkSdtSugf9YFmSkvUYPquI5E= +github.com/klauspost/compress v1.17.5/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= +github.com/lxzan/gws v1.8.0 h1:SqRuU6PUez/BA6CHB9BufV6n+gCnRtWHUntjLcaHA44= +github.com/lxzan/gws v1.8.0/go.mod h1:FcGeRMB7HwGuTvMLR24ku0Zx0p6RXqeKASeMc4VYgi4= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= @@ -21,8 +23,8 @@ github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APP github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= -github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/superwhiskers/crunch/v3 v3.5.7 h1:N9RLxaR65C36i26BUIpzPXGy2f6pQ7wisu2bawbKNqg= github.com/superwhiskers/crunch/v3 v3.5.7/go.mod h1:4ub2EKgF1MAhTjoOCTU4b9uLMsAweHEa89aRrfAypXA= golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= diff --git a/hpp_client.go b/hpp_client.go index 480b453e..6d8fa779 100644 --- a/hpp_client.go +++ b/hpp_client.go @@ -8,14 +8,14 @@ import ( // HPPClient represents a single HPP client type HPPClient struct { - address *net.TCPAddr - server *HPPServer - pid *types.PID + address *net.TCPAddr + endpoint *HPPServer + pid *types.PID } -// Server returns the server the client is connecting to -func (c *HPPClient) Server() ServerInterface { - return c.server +// Endpoint returns the server the client is connecting to +func (c *HPPClient) Endpoint() EndpointInterface { + return c.endpoint } // Address returns the clients address as a net.Addr @@ -36,7 +36,7 @@ func (c *HPPClient) SetPID(pid *types.PID) { // NewHPPClient creates and returns a new Client using the provided IP address and server func NewHPPClient(address *net.TCPAddr, server *HPPServer) *HPPClient { return &HPPClient{ - address: address, - server: server, + address: address, + endpoint: server, } } diff --git a/hpp_packet.go b/hpp_packet.go index 64139bf6..8c4c2a2a 100644 --- a/hpp_packet.go +++ b/hpp_packet.go @@ -55,7 +55,7 @@ func (p *HPPPacket) validateAccessKeySignature(signature string) error { } func (p *HPPPacket) calculateAccessKeySignature() ([]byte, error) { - accessKey := p.Sender().Server().AccessKey() + accessKey := p.Sender().Endpoint().AccessKey() accessKeyBytes, err := hex.DecodeString(accessKey) if err != nil { @@ -93,7 +93,7 @@ func (p *HPPPacket) validatePasswordSignature(signature string) error { func (p *HPPPacket) calculatePasswordSignature() ([]byte, error) { sender := p.Sender() pid := sender.PID() - account, _ := sender.Server().(*HPPServer).AccountDetailsByPID(pid) + account, _ := sender.Endpoint().(*HPPServer).AccountDetailsByPID(pid) if account == nil { return nil, errors.New("PID does not exist") } @@ -140,7 +140,7 @@ func NewHPPPacket(client *HPPClient, payload []byte) (*HPPPacket, error) { } if payload != nil { - rmcMessage := NewRMCRequest(client.Server()) + rmcMessage := NewRMCRequest(client.Endpoint()) err := rmcMessage.FromBytes(payload) if err != nil { return nil, fmt.Errorf("Failed to decode HPP request. %s", err) diff --git a/hpp_server.go b/hpp_server.go index b5cb9dd1..a4522c2a 100644 --- a/hpp_server.go +++ b/hpp_server.go @@ -14,14 +14,7 @@ import ( type HPPServer struct { server *http.Server accessKey string - version *LibraryVersion - datastoreProtocolVersion *LibraryVersion - matchMakingProtocolVersion *LibraryVersion - rankingProtocolVersion *LibraryVersion - ranking2ProtocolVersion *LibraryVersion - messagingProtocolVersion *LibraryVersion - utilityProtocolVersion *LibraryVersion - natTraversalProtocolVersion *LibraryVersion + libraryVersions *LibraryVersions dataHandlers []func(packet PacketInterface) byteStreamSettings *ByteStreamSettings AccountDetailsByPID func(pid *types.PID) (*Account, *Error) @@ -167,6 +160,11 @@ func (s *HPPServer) Send(packet PacketInterface) { } } +// LibraryVersions returns the versions that the server has +func (s *HPPServer) LibraryVersions() *LibraryVersions { + return s.libraryVersions +} + // AccessKey returns the servers sandbox access key func (s *HPPServer) AccessKey() string { return s.accessKey @@ -177,93 +175,6 @@ func (s *HPPServer) SetAccessKey(accessKey string) { s.accessKey = accessKey } -// LibraryVersion returns the server NEX version -func (s *HPPServer) LibraryVersion() *LibraryVersion { - return s.version -} - -// SetDefaultLibraryVersion sets the default NEX protocol versions -func (s *HPPServer) SetDefaultLibraryVersion(version *LibraryVersion) { - s.version = version - s.datastoreProtocolVersion = version.Copy() - s.matchMakingProtocolVersion = version.Copy() - s.rankingProtocolVersion = version.Copy() - s.ranking2ProtocolVersion = version.Copy() - s.messagingProtocolVersion = version.Copy() - s.utilityProtocolVersion = version.Copy() - s.natTraversalProtocolVersion = version.Copy() -} - -// DataStoreProtocolVersion returns the servers DataStore protocol version -func (s *HPPServer) DataStoreProtocolVersion() *LibraryVersion { - return s.datastoreProtocolVersion -} - -// SetDataStoreProtocolVersion sets the servers DataStore protocol version -func (s *HPPServer) SetDataStoreProtocolVersion(version *LibraryVersion) { - s.datastoreProtocolVersion = version -} - -// MatchMakingProtocolVersion returns the servers MatchMaking protocol version -func (s *HPPServer) MatchMakingProtocolVersion() *LibraryVersion { - return s.matchMakingProtocolVersion -} - -// SetMatchMakingProtocolVersion sets the servers MatchMaking protocol version -func (s *HPPServer) SetMatchMakingProtocolVersion(version *LibraryVersion) { - s.matchMakingProtocolVersion = version -} - -// RankingProtocolVersion returns the servers Ranking protocol version -func (s *HPPServer) RankingProtocolVersion() *LibraryVersion { - return s.rankingProtocolVersion -} - -// SetRankingProtocolVersion sets the servers Ranking protocol version -func (s *HPPServer) SetRankingProtocolVersion(version *LibraryVersion) { - s.rankingProtocolVersion = version -} - -// Ranking2ProtocolVersion returns the servers Ranking2 protocol version -func (s *HPPServer) Ranking2ProtocolVersion() *LibraryVersion { - return s.ranking2ProtocolVersion -} - -// SetRanking2ProtocolVersion sets the servers Ranking2 protocol version -func (s *HPPServer) SetRanking2ProtocolVersion(version *LibraryVersion) { - s.ranking2ProtocolVersion = version -} - -// MessagingProtocolVersion returns the servers Messaging protocol version -func (s *HPPServer) MessagingProtocolVersion() *LibraryVersion { - return s.messagingProtocolVersion -} - -// SetMessagingProtocolVersion sets the servers Messaging protocol version -func (s *HPPServer) SetMessagingProtocolVersion(version *LibraryVersion) { - s.messagingProtocolVersion = version -} - -// UtilityProtocolVersion returns the servers Utility protocol version -func (s *HPPServer) UtilityProtocolVersion() *LibraryVersion { - return s.utilityProtocolVersion -} - -// SetUtilityProtocolVersion sets the servers Utility protocol version -func (s *HPPServer) SetUtilityProtocolVersion(version *LibraryVersion) { - s.utilityProtocolVersion = version -} - -// SetNATTraversalProtocolVersion sets the servers NAT Traversal protocol version -func (s *HPPServer) SetNATTraversalProtocolVersion(version *LibraryVersion) { - s.natTraversalProtocolVersion = version -} - -// NATTraversalProtocolVersion returns the servers NAT Traversal protocol version -func (s *HPPServer) NATTraversalProtocolVersion() *LibraryVersion { - return s.natTraversalProtocolVersion -} - // ByteStreamSettings returns the settings to be used for ByteStreams func (s *HPPServer) ByteStreamSettings() *ByteStreamSettings { return s.byteStreamSettings @@ -278,6 +189,7 @@ func (s *HPPServer) SetByteStreamSettings(byteStreamSettings *ByteStreamSettings func NewHPPServer() *HPPServer { s := &HPPServer{ dataHandlers: make([]func(packet PacketInterface), 0), + libraryVersions: NewLibraryVersions(), byteStreamSettings: NewByteStreamSettings(), } diff --git a/kerberos.go b/kerberos.go index bd858e61..fa724bb0 100644 --- a/kerberos.go +++ b/kerberos.go @@ -93,6 +93,7 @@ func NewKerberosTicket() *KerberosTicket { // KerberosTicketInternalData holds the internal data for a kerberos ticket to be processed by the server type KerberosTicketInternalData struct { + Server *PRUDPServer // TODO - Remove this dependency and make a settings struct Issued *types.DateTime SourcePID *types.PID SessionKey []byte @@ -108,7 +109,7 @@ func (ti *KerberosTicketInternalData) Encrypt(key []byte, stream *ByteStreamOut) data := stream.Bytes() - if stream.Server.(*PRUDPServer).KerberosTicketVersion == 1 { + if ti.Server.KerberosTicketVersion == 1 { ticketKey := make([]byte, 16) _, err := rand.Read(ticketKey) if err != nil { @@ -122,7 +123,7 @@ func (ti *KerberosTicketInternalData) Encrypt(key []byte, stream *ByteStreamOut) encrypted := encryption.Encrypt(data) - finalStream := NewByteStreamOut(stream.Server) + finalStream := NewByteStreamOut(stream.LibraryVersions, stream.Settings) ticketBuffer := types.NewBuffer(ticketKey) encryptedBuffer := types.NewBuffer(encrypted) @@ -140,7 +141,7 @@ func (ti *KerberosTicketInternalData) Encrypt(key []byte, stream *ByteStreamOut) // Decrypt decrypts the given data and populates the struct func (ti *KerberosTicketInternalData) Decrypt(stream *ByteStreamIn, key []byte) error { - if stream.Server.(*PRUDPServer).KerberosTicketVersion == 1 { + if ti.Server.KerberosTicketVersion == 1 { ticketKey := types.NewBuffer(nil) if err := ticketKey.ExtractFrom(stream); err != nil { return fmt.Errorf("Failed to read Kerberos ticket internal data key. %s", err.Error()) @@ -154,7 +155,7 @@ func (ti *KerberosTicketInternalData) Decrypt(stream *ByteStreamIn, key []byte) hash := md5.Sum(append(key, ticketKey.Value...)) key = hash[:] - stream = NewByteStreamIn(data.Value, stream.Server) + stream = NewByteStreamIn(data.Value, stream.LibraryVersions, stream.Settings) } encryption := NewKerberosEncryption(key) @@ -164,7 +165,7 @@ func (ti *KerberosTicketInternalData) Decrypt(stream *ByteStreamIn, key []byte) return fmt.Errorf("Failed to decrypt Kerberos ticket internal data. %s", err.Error()) } - stream = NewByteStreamIn(decrypted, stream.Server) + stream = NewByteStreamIn(decrypted, stream.LibraryVersions, stream.Settings) timestamp := types.NewDateTime(0) if err := timestamp.ExtractFrom(stream); err != nil { @@ -178,14 +179,14 @@ func (ti *KerberosTicketInternalData) Decrypt(stream *ByteStreamIn, key []byte) ti.Issued = timestamp ti.SourcePID = userPID - ti.SessionKey = stream.ReadBytesNext(int64(stream.Server.(*PRUDPServer).SessionKeyLength)) + ti.SessionKey = stream.ReadBytesNext(int64(ti.Server.SessionKeyLength)) return nil } // NewKerberosTicketInternalData returns a new KerberosTicketInternalData instance -func NewKerberosTicketInternalData() *KerberosTicketInternalData { - return &KerberosTicketInternalData{} +func NewKerberosTicketInternalData(server *PRUDPServer) *KerberosTicketInternalData { + return &KerberosTicketInternalData{Server: server} } // DeriveKerberosKey derives a users kerberos encryption key based on their PID and password diff --git a/library_version.go b/library_version.go index adb0dcf7..911877a2 100644 --- a/library_version.go +++ b/library_version.go @@ -76,3 +76,32 @@ func NewLibraryVersion(major, minor, patch int) *LibraryVersion { semver: fmt.Sprintf("v%d.%d.%d", major, minor, patch), } } + +// LibraryVersions contains a set of the NEX version that the server uses +type LibraryVersions struct { + Main *LibraryVersion + DataStore *LibraryVersion + MatchMaking *LibraryVersion + Ranking *LibraryVersion + Ranking2 *LibraryVersion + Messaging *LibraryVersion + Utility *LibraryVersion + NATTraversal *LibraryVersion +} + +// SetDefault sets the default NEX protocol versions +func (lvs *LibraryVersions) SetDefault(version *LibraryVersion) { + lvs.Main = version + lvs.DataStore = version.Copy() + lvs.MatchMaking = version.Copy() + lvs.Ranking = version.Copy() + lvs.Ranking2 = version.Copy() + lvs.Messaging = version.Copy() + lvs.Utility = version.Copy() + lvs.NATTraversal = version.Copy() +} + +// NewLibraryVersions returns a new set of LibraryVersions +func NewLibraryVersions() *LibraryVersions { + return &LibraryVersions{} +} diff --git a/prudp_connection.go b/prudp_connection.go index bd576df1..a5a1036f 100644 --- a/prudp_connection.go +++ b/prudp_connection.go @@ -14,7 +14,7 @@ import ( // A single network socket may be used to open multiple PRUDP virtual connections type PRUDPConnection struct { Socket *SocketConnection // * The connections parent socket - Endpoint *PRUDPEndPoint // * The PRUDP endpoint the connection is connected to + endpoint *PRUDPEndPoint // * The PRUDP endpoint the connection is connected to ID uint32 // * Connection ID SessionID uint8 // * Random value generated at the start of the session. Client and server IDs do not need to match ServerSessionID uint8 // * Random value generated at the start of the session. Client and server IDs do not need to match @@ -36,8 +36,8 @@ type PRUDPConnection struct { } // Server returns the PRUDP server the connections socket is connected to -func (pc *PRUDPConnection) Server() ServerInterface { - return pc.Socket.Server +func (pc *PRUDPConnection) Endpoint() EndpointInterface { + return pc.endpoint } // Address returns the socket address of the connection @@ -76,11 +76,11 @@ func (pc *PRUDPConnection) cleanup() { pc.Socket.Connections.Delete(pc.SessionID) - pc.Endpoint.emitConnectionEnded(pc) + pc.endpoint.emitConnectionEnded(pc) if pc.Socket.Connections.Size() == 0 { // * No more PRUDP connections, assume the socket connection is also closed - pc.Endpoint.Server.Connections.Delete(pc.Socket.Address.String()) + pc.endpoint.Server.Connections.Delete(pc.Socket.Address.String()) // TODO - Is there any other cleanup that needs to happen here? // TODO - Should we add an event for when a socket closes too? } @@ -168,12 +168,12 @@ func (pc *PRUDPConnection) resetHeartbeat() { } if pc.heartbeatTimer != nil { - pc.heartbeatTimer.Reset(pc.Endpoint.Server.pingTimeout) // TODO - This is part of StreamSettings + pc.heartbeatTimer.Reset(pc.endpoint.Server.pingTimeout) // TODO - This is part of StreamSettings } } func (pc *PRUDPConnection) startHeartbeat() { - endpoint := pc.Endpoint + endpoint := pc.endpoint server := endpoint.Server // * Every time a packet is sent, connection.resetHeartbeat() diff --git a/prudp_endpoint.go b/prudp_endpoint.go index ae215e0c..74f3bb99 100644 --- a/prudp_endpoint.go +++ b/prudp_endpoint.go @@ -81,7 +81,7 @@ func (pep *PRUDPEndPoint) processPacket(packet PRUDPPacketInterface, socket *Soc if !ok { connection = NewPRUDPConnection(socket) - connection.Endpoint = pep + connection.endpoint = pep connection.ID = pep.ConnectionIDCounter.Next() connection.DefaultPRUDPVersion = packet.Version() connection.StreamType = streamType @@ -149,7 +149,7 @@ func (pep *PRUDPEndPoint) handleAcknowledgment(packet PRUDPPacketInterface) { func (pep *PRUDPEndPoint) handleMultiAcknowledgment(packet PRUDPPacketInterface) { connection := packet.Sender().(*PRUDPConnection) - stream := NewByteStreamIn(packet.Payload(), pep.Server) + stream := NewByteStreamIn(packet.Payload(), pep.Server.LibraryVersions, pep.ByteStreamSettings()) sequenceIDs := make([]uint16, 0) var baseSequenceID uint16 var slidingWindow *SlidingWindow @@ -201,11 +201,11 @@ func (pep *PRUDPEndPoint) handleSyn(packet PRUDPPacketInterface) { var ack PRUDPPacketInterface if packet.Version() == 2 { - ack, _ = NewPRUDPPacketLite(connection, nil) + ack, _ = NewPRUDPPacketLite(pep.Server, connection, nil) } else if packet.Version() == 1 { - ack, _ = NewPRUDPPacketV1(connection, nil) + ack, _ = NewPRUDPPacketV1(pep.Server, connection, nil) } else { - ack, _ = NewPRUDPPacketV0(connection, nil) + ack, _ = NewPRUDPPacketV0(pep.Server, connection, nil) } connectionSignature, err := packet.calculateConnectionSignature(connection.Socket.Address) @@ -244,11 +244,11 @@ func (pep *PRUDPEndPoint) handleConnect(packet PRUDPPacketInterface) { var ack PRUDPPacketInterface if packet.Version() == 2 { - ack, _ = NewPRUDPPacketLite(connection, nil) + ack, _ = NewPRUDPPacketLite(pep.Server, connection, nil) } else if packet.Version() == 1 { - ack, _ = NewPRUDPPacketV1(connection, nil) + ack, _ = NewPRUDPPacketV1(pep.Server, connection, nil) } else { - ack, _ = NewPRUDPPacketV0(connection, nil) + ack, _ = NewPRUDPPacketV0(pep.Server, connection, nil) } connection.ServerConnectionSignature = packet.getConnectionSignature() @@ -304,7 +304,7 @@ func (pep *PRUDPEndPoint) handleConnect(packet PRUDPPacketInterface) { binary.LittleEndian.PutUint32(responseCheckValueBytes, responseCheckValue) checkValueResponse := types.NewBuffer(responseCheckValueBytes) - stream := NewByteStreamOut(pep.Server) + stream := NewByteStreamOut(pep.Server.LibraryVersions, pep.ByteStreamSettings()) checkValueResponse.WriteTo(stream) @@ -350,7 +350,7 @@ func (pep *PRUDPEndPoint) handlePing(packet PRUDPPacketInterface) { } func (pep *PRUDPEndPoint) readKerberosTicket(payload []byte) ([]byte, *types.PID, uint32, error) { - stream := NewByteStreamIn(payload, pep.Server) + stream := NewByteStreamIn(payload, pep.Server.LibraryVersions, pep.ByteStreamSettings()) ticketData := types.NewBuffer(nil) if err := ticketData.ExtractFrom(stream); err != nil { @@ -374,8 +374,8 @@ func (pep *PRUDPEndPoint) readKerberosTicket(payload []byte) ([]byte, *types.PID serverKey := DeriveKerberosKey(serverAccount.PID, []byte(serverAccount.Password)) - ticket := NewKerberosTicketInternalData() - if err := ticket.Decrypt(NewByteStreamIn(ticketData.Value, pep.Server), serverKey); err != nil { + ticket := NewKerberosTicketInternalData(pep.Server) + if err := ticket.Decrypt(NewByteStreamIn(ticketData.Value, pep.Server.LibraryVersions, pep.ByteStreamSettings()), serverKey); err != nil { return nil, nil, 0, err } @@ -395,7 +395,7 @@ func (pep *PRUDPEndPoint) readKerberosTicket(payload []byte) ([]byte, *types.PID return nil, nil, 0, err } - checkDataStream := NewByteStreamIn(decryptedRequestData, pep.Server) + checkDataStream := NewByteStreamIn(decryptedRequestData, pep.Server.LibraryVersions, pep.ByteStreamSettings()) userPID := types.NewPID(0) if err := userPID.ExtractFrom(checkDataStream); err != nil { @@ -419,11 +419,11 @@ func (pep *PRUDPEndPoint) acknowledgePacket(packet PRUDPPacketInterface) { var ack PRUDPPacketInterface if packet.Version() == 2 { - ack, _ = NewPRUDPPacketLite(packet.Sender().(*PRUDPConnection), nil) + ack, _ = NewPRUDPPacketLite(pep.Server, packet.Sender().(*PRUDPConnection), nil) } else if packet.Version() == 1 { - ack, _ = NewPRUDPPacketV1(packet.Sender().(*PRUDPConnection), nil) + ack, _ = NewPRUDPPacketV1(pep.Server, packet.Sender().(*PRUDPConnection), nil) } else { - ack, _ = NewPRUDPPacketV0(packet.Sender().(*PRUDPConnection), nil) + ack, _ = NewPRUDPPacketV0(pep.Server, packet.Sender().(*PRUDPConnection), nil) } ack.SetType(packet.Type()) @@ -473,7 +473,7 @@ func (pep *PRUDPEndPoint) handleReliable(packet PRUDPPacketInterface) { payload := slidingWindow.AddFragment(decompressedPayload) if packet.getFragmentID() == 0 { - message := NewRMCMessage(pep.Server) + message := NewRMCMessage(pep) err := message.FromBytes(payload) if err != nil { // TODO - Should this return the error too? @@ -539,7 +539,7 @@ func (pep *PRUDPEndPoint) handleUnreliable(packet PRUDPPacketInterface) { payload := packet.processUnreliableCrypto() - message := NewRMCMessage(pep.Server) + message := NewRMCMessage(pep) err := message.FromBytes(payload) if err != nil { // TODO - Should this return the error too? @@ -556,11 +556,11 @@ func (pep *PRUDPEndPoint) sendPing(connection *PRUDPConnection) { switch connection.DefaultPRUDPVersion { case 0: - ping, _ = NewPRUDPPacketV0(connection, nil) + ping, _ = NewPRUDPPacketV0(pep.Server, connection, nil) case 1: - ping, _ = NewPRUDPPacketV1(connection, nil) + ping, _ = NewPRUDPPacketV1(pep.Server, connection, nil) case 2: - ping, _ = NewPRUDPPacketLite(connection, nil) + ping, _ = NewPRUDPPacketLite(pep.Server, connection, nil) } ping.SetType(PingPacket) @@ -606,6 +606,36 @@ func (pep *PRUDPEndPoint) FindConnectionByPID(pid uint64) *PRUDPConnection { return connection } +// AccessKey returns the servers sandbox access key +func (pep *PRUDPEndPoint) AccessKey() string { + return pep.Server.AccessKey +} + +// SetAccessKey sets the servers sandbox access key +func (pep *PRUDPEndPoint) SetAccessKey(accessKey string) { + pep.Server.AccessKey = accessKey +} + +// Send sends the packet to the packets sender +func (pep *PRUDPEndPoint) Send(packet PacketInterface) { + pep.Server.Send(packet) +} + +// LibraryVersions returns the versions that the server has +func (pep *PRUDPEndPoint) LibraryVersions() *LibraryVersions { + return pep.Server.LibraryVersions +} + +// ByteStreamSettings returns the settings to be used for ByteStreams +func (pep *PRUDPEndPoint) ByteStreamSettings() *ByteStreamSettings { + return pep.Server.ByteStreamSettings +} + +// SetByteStreamSettings sets the settings to be used for ByteStreams +func (pep *PRUDPEndPoint) SetByteStreamSettings(byteStreamSettings *ByteStreamSettings) { + pep.Server.ByteStreamSettings = byteStreamSettings +} + // NewPRUDPEndPoint returns a new PRUDPEndPoint for a server on the provided stream ID func NewPRUDPEndPoint(streamID uint8) *PRUDPEndPoint { return &PRUDPEndPoint{ diff --git a/prudp_packet_lite.go b/prudp_packet_lite.go index 0c595b00..8557c9f2 100644 --- a/prudp_packet_lite.go +++ b/prudp_packet_lite.go @@ -67,7 +67,7 @@ func (p *PRUDPPacketLite) DestinationVirtualPortStreamID() uint8 { // // Retains the same PRUDPConnection pointer func (p *PRUDPPacketLite) Copy() PRUDPPacketInterface { - copied, _ := NewPRUDPPacketLite(p.sender, nil) + copied, _ := NewPRUDPPacketLite(p.server, p.sender, nil) copied.server = p.server copied.sourceVirtualPortStreamType = p.sourceVirtualPortStreamType @@ -184,7 +184,7 @@ func (p *PRUDPPacketLite) decode() error { func (p *PRUDPPacketLite) Bytes() []byte { options := p.encodeOptions() - stream := NewByteStreamOut(nil) + stream := NewByteStreamOut(p.server.LibraryVersions, p.server.ByteStreamSettings) stream.WritePrimitiveUInt8(0x80) stream.WritePrimitiveUInt8(uint8(len(options))) @@ -207,7 +207,7 @@ func (p *PRUDPPacketLite) Bytes() []byte { func (p *PRUDPPacketLite) decodeOptions() error { data := p.readStream.ReadBytesNext(int64(p.optionsLength)) - optionsStream := NewByteStreamIn(data, nil) + optionsStream := NewByteStreamIn(data, p.server.LibraryVersions, p.server.ByteStreamSettings) for optionsStream.Remaining() > 0 { optionID, err := optionsStream.ReadPrimitiveUInt8() @@ -267,7 +267,7 @@ func (p *PRUDPPacketLite) decodeOptions() error { } func (p *PRUDPPacketLite) encodeOptions() []byte { - optionsStream := NewByteStreamOut(nil) + optionsStream := NewByteStreamOut(p.server.LibraryVersions, p.server.ByteStreamSettings) if p.packetType == SynPacket || p.packetType == ConnectPacket { optionsStream.WritePrimitiveUInt8(0) @@ -320,7 +320,7 @@ func (p *PRUDPPacketLite) calculateSignature(sessionKey, connectionSignature []b } // NewPRUDPPacketLite creates and returns a new PacketLite using the provided Client and stream -func NewPRUDPPacketLite(connection *PRUDPConnection, readStream *ByteStreamIn) (*PRUDPPacketLite, error) { +func NewPRUDPPacketLite(server *PRUDPServer, connection *PRUDPConnection, readStream *ByteStreamIn) (*PRUDPPacketLite, error) { packet := &PRUDPPacketLite{ PRUDPPacket: PRUDPPacket{ sender: connection, @@ -328,27 +328,24 @@ func NewPRUDPPacketLite(connection *PRUDPConnection, readStream *ByteStreamIn) ( }, } + packet.server = server + if readStream != nil { - packet.server = readStream.Server.(*PRUDPServer) err := packet.decode() if err != nil { return nil, fmt.Errorf("Failed to decode PRUDPLite packet. %s", err.Error()) } } - if connection != nil { - packet.server = connection.Endpoint.Server - } - return packet, nil } // NewPRUDPPacketsLite reads all possible PRUDPLite packets from the stream -func NewPRUDPPacketsLite(connection *PRUDPConnection, readStream *ByteStreamIn) ([]PRUDPPacketInterface, error) { +func NewPRUDPPacketsLite(server *PRUDPServer, connection *PRUDPConnection, readStream *ByteStreamIn) ([]PRUDPPacketInterface, error) { packets := make([]PRUDPPacketInterface, 0) for readStream.Remaining() > 0 { - packet, err := NewPRUDPPacketLite(connection, readStream) + packet, err := NewPRUDPPacketLite(server, connection, readStream) if err != nil { return packets, err } diff --git a/prudp_packet_v0.go b/prudp_packet_v0.go index 3b2f05af..7aca22a8 100644 --- a/prudp_packet_v0.go +++ b/prudp_packet_v0.go @@ -19,7 +19,7 @@ type PRUDPPacketV0 struct { // // Retains the same PRUDPConnection pointer func (p *PRUDPPacketV0) Copy() PRUDPPacketInterface { - copied, _ := NewPRUDPPacketV0(p.sender, nil) + copied, _ := NewPRUDPPacketV0(p.server, p.sender, nil) copied.server = p.server copied.sourceVirtualPort = p.sourceVirtualPort @@ -193,7 +193,7 @@ func (p *PRUDPPacketV0) decode() error { // Bytes encodes a PRUDPv0 packet into a byte slice func (p *PRUDPPacketV0) Bytes() []byte { server := p.server - stream := NewByteStreamOut(server) + stream := NewByteStreamOut(server.LibraryVersions, server.ByteStreamSettings) stream.WritePrimitiveUInt8(uint8(p.sourceVirtualPort)) stream.WritePrimitiveUInt8(uint8(p.destinationVirtualPort)) @@ -247,7 +247,7 @@ func (p *PRUDPPacketV0) calculateSignature(sessionKey, connectionSignature []byt } // NewPRUDPPacketV0 creates and returns a new PacketV0 using the provided Client and stream -func NewPRUDPPacketV0(connection *PRUDPConnection, readStream *ByteStreamIn) (*PRUDPPacketV0, error) { +func NewPRUDPPacketV0(server *PRUDPServer, connection *PRUDPConnection, readStream *ByteStreamIn) (*PRUDPPacketV0, error) { packet := &PRUDPPacketV0{ PRUDPPacket: PRUDPPacket{ sender: connection, @@ -256,27 +256,24 @@ func NewPRUDPPacketV0(connection *PRUDPConnection, readStream *ByteStreamIn) (*P }, } + packet.server = server + if readStream != nil { - packet.server = readStream.Server.(*PRUDPServer) err := packet.decode() if err != nil { return nil, fmt.Errorf("Failed to decode PRUDPv0 packet. %s", err.Error()) } } - if connection != nil { - packet.server = connection.Endpoint.Server - } - return packet, nil } // NewPRUDPPacketsV0 reads all possible PRUDPv0 packets from the stream -func NewPRUDPPacketsV0(connection *PRUDPConnection, readStream *ByteStreamIn) ([]PRUDPPacketInterface, error) { +func NewPRUDPPacketsV0(server *PRUDPServer, connection *PRUDPConnection, readStream *ByteStreamIn) ([]PRUDPPacketInterface, error) { packets := make([]PRUDPPacketInterface, 0) for readStream.Remaining() > 0 { - packet, err := NewPRUDPPacketV0(connection, readStream) + packet, err := NewPRUDPPacketV0(server, connection, readStream) if err != nil { return packets, err } @@ -317,7 +314,7 @@ func defaultPRUDPv0CalculateSignature(packet *PRUDPPacketV0, sessionKey, connect return packet.server.PRUDPV0Settings.DataSignatureCalculator(packet, sessionKey) } - if packet.packetType == DisconnectPacket && packet.server.accessKey != "ridfebb9" { + if packet.packetType == DisconnectPacket && packet.server.AccessKey != "ridfebb9" { return packet.server.PRUDPV0Settings.DataSignatureCalculator(packet, sessionKey) } } @@ -333,7 +330,7 @@ func defaultPRUDPv0CalculateDataSignature(packet *PRUDPPacketV0, sessionKey []by server := packet.server data := packet.payload - if server.AccessKey() != "ridfebb9" { + if server.AccessKey != "ridfebb9" { header := []byte{0, 0, packet.fragmentID} binary.LittleEndian.PutUint16(header[:2], packet.sequenceID) @@ -342,7 +339,7 @@ func defaultPRUDPv0CalculateDataSignature(packet *PRUDPPacketV0, sessionKey []by } if len(data) > 0 { - key := md5.Sum([]byte(server.AccessKey())) + key := md5.Sum([]byte(server.AccessKey)) mac := hmac.New(md5.New, key[:]) mac.Write(data) @@ -357,7 +354,7 @@ func defaultPRUDPv0CalculateDataSignature(packet *PRUDPPacketV0, sessionKey []by func defaultPRUDPv0CalculateChecksum(packet *PRUDPPacketV0, data []byte) uint32 { server := packet.server - checksum := sum[byte, uint32]([]byte(server.AccessKey())) + checksum := sum[byte, uint32]([]byte(server.AccessKey)) if server.PRUDPV0Settings.UseEnhancedChecksum { padSize := (len(data) + 3) &^ 3 diff --git a/prudp_packet_v1.go b/prudp_packet_v1.go index 24de698c..cd867ea2 100644 --- a/prudp_packet_v1.go +++ b/prudp_packet_v1.go @@ -25,7 +25,7 @@ type PRUDPPacketV1 struct { // // Retains the same PRUDPConnection pointer func (p *PRUDPPacketV1) Copy() PRUDPPacketInterface { - copied, _ := NewPRUDPPacketV1(p.sender, nil) + copied, _ := NewPRUDPPacketV1(p.server, p.sender, nil) copied.server = p.server copied.sourceVirtualPort = p.sourceVirtualPort @@ -107,7 +107,7 @@ func (p *PRUDPPacketV1) Bytes() []byte { header := p.encodeHeader() - stream := NewByteStreamOut(nil) + stream := NewByteStreamOut(p.server.LibraryVersions, p.server.ByteStreamSettings) stream.Grow(2) stream.WriteBytesNext([]byte{0xEA, 0xD0}) @@ -197,7 +197,7 @@ func (p *PRUDPPacketV1) decodeHeader() error { } func (p *PRUDPPacketV1) encodeHeader() []byte { - stream := NewByteStreamOut(nil) + stream := NewByteStreamOut(p.server.LibraryVersions, p.server.ByteStreamSettings) stream.WritePrimitiveUInt8(1) // * Version stream.WritePrimitiveUInt8(p.optionsLength) @@ -214,7 +214,7 @@ func (p *PRUDPPacketV1) encodeHeader() []byte { func (p *PRUDPPacketV1) decodeOptions() error { data := p.readStream.ReadBytesNext(int64(p.optionsLength)) - optionsStream := NewByteStreamIn(data, nil) + optionsStream := NewByteStreamIn(data, p.server.LibraryVersions, p.server.ByteStreamSettings) for optionsStream.Remaining() > 0 { optionID, err := optionsStream.ReadPrimitiveUInt8() @@ -268,7 +268,7 @@ func (p *PRUDPPacketV1) decodeOptions() error { } func (p *PRUDPPacketV1) encodeOptions() []byte { - optionsStream := NewByteStreamOut(nil) + optionsStream := NewByteStreamOut(p.server.LibraryVersions, p.server.ByteStreamSettings) if p.packetType == SynPacket || p.packetType == ConnectPacket { optionsStream.WritePrimitiveUInt8(0) @@ -331,7 +331,7 @@ func (p *PRUDPPacketV1) calculateConnectionSignature(addr net.Addr) ([]byte, err } func (p *PRUDPPacketV1) calculateSignature(sessionKey, connectionSignature []byte) []byte { - accessKeyBytes := []byte(p.server.accessKey) + accessKeyBytes := []byte(p.server.AccessKey) options := p.encodeOptions() header := p.encodeHeader() @@ -353,7 +353,7 @@ func (p *PRUDPPacketV1) calculateSignature(sessionKey, connectionSignature []byt } // NewPRUDPPacketV1 creates and returns a new PacketV1 using the provided Client and stream -func NewPRUDPPacketV1(connection *PRUDPConnection, readStream *ByteStreamIn) (*PRUDPPacketV1, error) { +func NewPRUDPPacketV1(server *PRUDPServer, connection *PRUDPConnection, readStream *ByteStreamIn) (*PRUDPPacketV1, error) { packet := &PRUDPPacketV1{ PRUDPPacket: PRUDPPacket{ sender: connection, @@ -362,27 +362,24 @@ func NewPRUDPPacketV1(connection *PRUDPConnection, readStream *ByteStreamIn) (*P }, } + packet.server = server + if readStream != nil { - packet.server = readStream.Server.(*PRUDPServer) err := packet.decode() if err != nil { return nil, fmt.Errorf("Failed to decode PRUDPv1 packet. %s", err.Error()) } } - if connection != nil { - packet.server = connection.Endpoint.Server - } - return packet, nil } // NewPRUDPPacketsV1 reads all possible PRUDPv1 packets from the stream -func NewPRUDPPacketsV1(connection *PRUDPConnection, readStream *ByteStreamIn) ([]PRUDPPacketInterface, error) { +func NewPRUDPPacketsV1(server *PRUDPServer, connection *PRUDPConnection, readStream *ByteStreamIn) ([]PRUDPPacketInterface, error) { packets := make([]PRUDPPacketInterface, 0) for readStream.Remaining() > 0 { - packet, err := NewPRUDPPacketV1(connection, readStream) + packet, err := NewPRUDPPacketV1(server, connection, readStream) if err != nil { return packets, err } diff --git a/prudp_server.go b/prudp_server.go index 1c942fa8..f84899f9 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -18,21 +18,14 @@ type PRUDPServer struct { Endpoints *MutexMap[uint8, *PRUDPEndPoint] Connections *MutexMap[string, *SocketConnection] SupportedFunctions uint32 - accessKey string + AccessKey string KerberosTicketVersion int SessionKeyLength int FragmentSize int - version *LibraryVersion - datastoreProtocolVersion *LibraryVersion - matchMakingProtocolVersion *LibraryVersion - rankingProtocolVersion *LibraryVersion - ranking2ProtocolVersion *LibraryVersion - messagingProtocolVersion *LibraryVersion - utilityProtocolVersion *LibraryVersion - natTraversalProtocolVersion *LibraryVersion pingTimeout time.Duration PRUDPv1ConnectionSignatureKey []byte - byteStreamSettings *ByteStreamSettings + LibraryVersions *LibraryVersions + ByteStreamSettings *ByteStreamSettings PRUDPV0Settings *PRUDPV0Settings } @@ -47,7 +40,7 @@ func (ps *PRUDPServer) BindPRUDPEndPoint(endpoint *PRUDPEndPoint) { ps.Endpoints.Set(endpoint.StreamID, endpoint) } -// Listen is an alias of ListenUDP. Implemented to conform to the ServerInterface +// Listen is an alias of ListenUDP. Implemented to conform to the EndpointInterface func (ps *PRUDPServer) Listen(port int) { ps.ListenUDP(port) } @@ -130,7 +123,7 @@ func (ps *PRUDPServer) initPRUDPv1ConnectionSignatureKey() { } func (ps *PRUDPServer) handleSocketMessage(packetData []byte, address net.Addr, webSocketConnection *gws.Conn) error { - readStream := NewByteStreamIn(packetData, ps) + readStream := NewByteStreamIn(packetData, ps.LibraryVersions, ps.ByteStreamSettings) var packets []PRUDPPacketInterface @@ -139,11 +132,11 @@ func (ps *PRUDPServer) handleSocketMessage(packetData []byte, address net.Addr, // * until no more data is left, to account for multiple // * packets being sent at once if ps.websocketServer != nil && packetData[0] == 0x80 { - packets, _ = NewPRUDPPacketsLite(nil, readStream) + packets, _ = NewPRUDPPacketsLite(ps, nil, readStream) } else if bytes.Equal(packetData[:2], []byte{0xEA, 0xD0}) { - packets, _ = NewPRUDPPacketsV1(nil, readStream) + packets, _ = NewPRUDPPacketsV1(ps, nil, readStream) } else { - packets, _ = NewPRUDPPacketsV0(nil, readStream) + packets, _ = NewPRUDPPacketsV0(ps, nil, readStream) } for _, packet := range packets { @@ -304,16 +297,6 @@ func (ps *PRUDPServer) sendRaw(socket *SocketConnection, data []byte) { } } -// AccessKey returns the servers sandbox access key -func (ps *PRUDPServer) AccessKey() string { - return ps.accessKey -} - -// SetAccessKey sets the servers sandbox access key -func (ps *PRUDPServer) SetAccessKey(accessKey string) { - ps.accessKey = accessKey -} - // SetFragmentSize sets the max size for a packets payload func (ps *PRUDPServer) SetFragmentSize(fragmentSize int) { // TODO - Derive this value from the MTU @@ -330,103 +313,6 @@ func (ps *PRUDPServer) SetFragmentSize(fragmentSize int) { ps.FragmentSize = fragmentSize } -// LibraryVersion returns the server NEX version -func (ps *PRUDPServer) LibraryVersion() *LibraryVersion { - return ps.version -} - -// SetDefaultLibraryVersion sets the default NEX protocol versions -func (ps *PRUDPServer) SetDefaultLibraryVersion(version *LibraryVersion) { - ps.version = version - ps.datastoreProtocolVersion = version.Copy() - ps.matchMakingProtocolVersion = version.Copy() - ps.rankingProtocolVersion = version.Copy() - ps.ranking2ProtocolVersion = version.Copy() - ps.messagingProtocolVersion = version.Copy() - ps.utilityProtocolVersion = version.Copy() - ps.natTraversalProtocolVersion = version.Copy() -} - -// DataStoreProtocolVersion returns the servers DataStore protocol version -func (ps *PRUDPServer) DataStoreProtocolVersion() *LibraryVersion { - return ps.datastoreProtocolVersion -} - -// SetDataStoreProtocolVersion sets the servers DataStore protocol version -func (ps *PRUDPServer) SetDataStoreProtocolVersion(version *LibraryVersion) { - ps.datastoreProtocolVersion = version -} - -// MatchMakingProtocolVersion returns the servers MatchMaking protocol version -func (ps *PRUDPServer) MatchMakingProtocolVersion() *LibraryVersion { - return ps.matchMakingProtocolVersion -} - -// SetMatchMakingProtocolVersion sets the servers MatchMaking protocol version -func (ps *PRUDPServer) SetMatchMakingProtocolVersion(version *LibraryVersion) { - ps.matchMakingProtocolVersion = version -} - -// RankingProtocolVersion returns the servers Ranking protocol version -func (ps *PRUDPServer) RankingProtocolVersion() *LibraryVersion { - return ps.rankingProtocolVersion -} - -// SetRankingProtocolVersion sets the servers Ranking protocol version -func (ps *PRUDPServer) SetRankingProtocolVersion(version *LibraryVersion) { - ps.rankingProtocolVersion = version -} - -// Ranking2ProtocolVersion returns the servers Ranking2 protocol version -func (ps *PRUDPServer) Ranking2ProtocolVersion() *LibraryVersion { - return ps.ranking2ProtocolVersion -} - -// SetRanking2ProtocolVersion sets the servers Ranking2 protocol version -func (ps *PRUDPServer) SetRanking2ProtocolVersion(version *LibraryVersion) { - ps.ranking2ProtocolVersion = version -} - -// MessagingProtocolVersion returns the servers Messaging protocol version -func (ps *PRUDPServer) MessagingProtocolVersion() *LibraryVersion { - return ps.messagingProtocolVersion -} - -// SetMessagingProtocolVersion sets the servers Messaging protocol version -func (ps *PRUDPServer) SetMessagingProtocolVersion(version *LibraryVersion) { - ps.messagingProtocolVersion = version -} - -// UtilityProtocolVersion returns the servers Utility protocol version -func (ps *PRUDPServer) UtilityProtocolVersion() *LibraryVersion { - return ps.utilityProtocolVersion -} - -// SetUtilityProtocolVersion sets the servers Utility protocol version -func (ps *PRUDPServer) SetUtilityProtocolVersion(version *LibraryVersion) { - ps.utilityProtocolVersion = version -} - -// SetNATTraversalProtocolVersion sets the servers NAT Traversal protocol version -func (ps *PRUDPServer) SetNATTraversalProtocolVersion(version *LibraryVersion) { - ps.natTraversalProtocolVersion = version -} - -// NATTraversalProtocolVersion returns the servers NAT Traversal protocol version -func (ps *PRUDPServer) NATTraversalProtocolVersion() *LibraryVersion { - return ps.natTraversalProtocolVersion -} - -// ByteStreamSettings returns the settings to be used for ByteStreams -func (ps *PRUDPServer) ByteStreamSettings() *ByteStreamSettings { - return ps.byteStreamSettings -} - -// SetByteStreamSettings sets the settings to be used for ByteStreams -func (ps *PRUDPServer) SetByteStreamSettings(byteStreamSettings *ByteStreamSettings) { - ps.byteStreamSettings = byteStreamSettings -} - // NewPRUDPServer will return a new PRUDP server func NewPRUDPServer() *PRUDPServer { return &PRUDPServer{ @@ -435,7 +321,8 @@ func NewPRUDPServer() *PRUDPServer { SessionKeyLength: 32, FragmentSize: 1300, pingTimeout: time.Second * 15, - byteStreamSettings: NewByteStreamSettings(), + LibraryVersions: NewLibraryVersions(), + ByteStreamSettings: NewByteStreamSettings(), PRUDPV0Settings: NewPRUDPV0Settings(), } } diff --git a/resend_scheduler.go b/resend_scheduler.go index 080f4e0c..b637a68a 100644 --- a/resend_scheduler.go +++ b/resend_scheduler.go @@ -101,14 +101,14 @@ func (rs *ResendScheduler) resendPacket(pendingPacket *PendingPacket) { streamID := packet.SourceVirtualPortStreamID() discriminator := fmt.Sprintf("%s-%d-%d", packet.Sender().Address().String(), streamType, streamID) - connection.Endpoint.Connections.Delete(discriminator) + connection.endpoint.Connections.Delete(discriminator) return } if time.Since(pendingPacket.lastSendTime) >= rs.Interval { // * Resend the packet to the connection - server := connection.Endpoint.Server + server := connection.endpoint.Server data := packet.Bytes() server.sendRaw(connection.Socket, data) diff --git a/rmc_message.go b/rmc_message.go index 57108fa7..147582b7 100644 --- a/rmc_message.go +++ b/rmc_message.go @@ -9,7 +9,7 @@ import ( // RMCMessage represents a message in the RMC (Remote Method Call) protocol type RMCMessage struct { - Server ServerInterface + Server EndpointInterface VerboseMode bool // * Determines whether or not to encode the message using the "verbose" encoding method IsRequest bool // * Indicates if the message is a request message (true) or response message (false) IsSuccess bool // * Indicates if the message is a success message (true) for a response message @@ -56,7 +56,7 @@ func (rmc *RMCMessage) FromBytes(data []byte) error { } func (rmc *RMCMessage) decodePacked(data []byte) error { - stream := NewByteStreamIn(data, rmc.Server) + stream := NewByteStreamIn(data, rmc.Server.LibraryVersions(), rmc.Server.ByteStreamSettings()) length, err := stream.ReadPrimitiveUInt32LE() if err != nil { @@ -143,7 +143,7 @@ func (rmc *RMCMessage) decodePacked(data []byte) error { } func (rmc *RMCMessage) decodeVerbose(data []byte) error { - stream := NewByteStreamIn(data, rmc.Server) + stream := NewByteStreamIn(data, rmc.Server.LibraryVersions(), rmc.Server.ByteStreamSettings()) length, err := stream.ReadPrimitiveUInt32LE() if err != nil { @@ -232,7 +232,7 @@ func (rmc *RMCMessage) Bytes() []byte { } func (rmc *RMCMessage) encodePacked() []byte { - stream := NewByteStreamOut(rmc.Server) + stream := NewByteStreamOut(rmc.Server.LibraryVersions(), rmc.Server.ByteStreamSettings()) // * RMC requests have their protocol IDs ORed with 0x80 var protocolIDFlag uint16 = 0x80 @@ -279,7 +279,7 @@ func (rmc *RMCMessage) encodePacked() []byte { serialized := stream.Bytes() - message := NewByteStreamOut(rmc.Server) + message := NewByteStreamOut(rmc.Server.LibraryVersions(), rmc.Server.ByteStreamSettings()) message.WritePrimitiveUInt32LE(uint32(len(serialized))) message.Grow(int64(len(serialized))) @@ -289,7 +289,7 @@ func (rmc *RMCMessage) encodePacked() []byte { } func (rmc *RMCMessage) encodeVerbose() []byte { - stream := NewByteStreamOut(rmc.Server) + stream := NewByteStreamOut(rmc.Server.LibraryVersions(), rmc.Server.ByteStreamSettings()) rmc.ProtocolName.WriteTo(stream) stream.WritePrimitiveBool(rmc.IsRequest) @@ -328,7 +328,7 @@ func (rmc *RMCMessage) encodeVerbose() []byte { serialized := stream.Bytes() - message := NewByteStreamOut(rmc.Server) + message := NewByteStreamOut(rmc.Server.LibraryVersions(), rmc.Server.ByteStreamSettings()) message.WritePrimitiveUInt32LE(uint32(len(serialized))) message.Grow(int64(len(serialized))) @@ -338,14 +338,14 @@ func (rmc *RMCMessage) encodeVerbose() []byte { } // NewRMCMessage returns a new generic RMC Message -func NewRMCMessage(server ServerInterface) *RMCMessage { +func NewRMCMessage(server EndpointInterface) *RMCMessage { return &RMCMessage{ Server: server, } } // NewRMCRequest returns a new blank RMCRequest -func NewRMCRequest(server ServerInterface) *RMCMessage { +func NewRMCRequest(server EndpointInterface) *RMCMessage { return &RMCMessage{ Server: server, IsRequest: true, @@ -353,7 +353,7 @@ func NewRMCRequest(server ServerInterface) *RMCMessage { } // NewRMCSuccess returns a new RMC Message configured as a success response -func NewRMCSuccess(server ServerInterface, parameters []byte) *RMCMessage { +func NewRMCSuccess(server EndpointInterface, parameters []byte) *RMCMessage { message := NewRMCMessage(server) message.IsRequest = false message.IsSuccess = true @@ -363,7 +363,7 @@ func NewRMCSuccess(server ServerInterface, parameters []byte) *RMCMessage { } // NewRMCError returns a new RMC Message configured as a error response -func NewRMCError(server ServerInterface, errorCode uint32) *RMCMessage { +func NewRMCError(server EndpointInterface, errorCode uint32) *RMCMessage { if int(errorCode)&errorMask == 0 { errorCode = uint32(int(errorCode) | errorMask) } diff --git a/server_interface.go b/server_interface.go deleted file mode 100644 index 94183e6c..00000000 --- a/server_interface.go +++ /dev/null @@ -1,19 +0,0 @@ -package nex - -// ServerInterface defines all the methods a server should have regardless of type -type ServerInterface interface { - AccessKey() string - SetAccessKey(accessKey string) - LibraryVersion() *LibraryVersion - DataStoreProtocolVersion() *LibraryVersion - MatchMakingProtocolVersion() *LibraryVersion - RankingProtocolVersion() *LibraryVersion - Ranking2ProtocolVersion() *LibraryVersion - MessagingProtocolVersion() *LibraryVersion - UtilityProtocolVersion() *LibraryVersion - NATTraversalProtocolVersion() *LibraryVersion - SetDefaultLibraryVersion(version *LibraryVersion) - Send(packet PacketInterface) - ByteStreamSettings() *ByteStreamSettings - SetByteStreamSettings(settings *ByteStreamSettings) -} diff --git a/test/auth.go b/test/auth.go index 16cd8733..a85c4329 100644 --- a/test/auth.go +++ b/test/auth.go @@ -9,19 +9,20 @@ import ( ) var authServer *nex.PRUDPServer +var authEndpoint *nex.PRUDPEndPoint func startAuthenticationServer() { fmt.Println("Starting auth") authServer = nex.NewPRUDPServer() - endpoint := nex.NewPRUDPEndPoint(1) + authEndpoint = nex.NewPRUDPEndPoint(1) - endpoint.AccountDetailsByPID = accountDetailsByPID - endpoint.AccountDetailsByUsername = accountDetailsByUsername - endpoint.ServerAccount = authenticationServerAccount + authEndpoint.AccountDetailsByPID = accountDetailsByPID + authEndpoint.AccountDetailsByUsername = accountDetailsByUsername + authEndpoint.ServerAccount = authenticationServerAccount - endpoint.OnData(func(packet nex.PacketInterface) { + authEndpoint.OnData(func(packet nex.PacketInterface) { if packet, ok := packet.(nex.PRUDPPacketInterface); ok { request := packet.RMCMessage() @@ -40,20 +41,20 @@ func startAuthenticationServer() { }) authServer.SetFragmentSize(962) - authServer.SetDefaultLibraryVersion(nex.NewLibraryVersion(1, 1, 0)) + authServer.LibraryVersions.SetDefault(nex.NewLibraryVersion(1, 1, 0)) authServer.SessionKeyLength = 16 - authServer.SetAccessKey("ridfebb9") - authServer.BindPRUDPEndPoint(endpoint) + authServer.AccessKey = "ridfebb9" + authServer.BindPRUDPEndPoint(authEndpoint) authServer.Listen(60000) } func login(packet nex.PRUDPPacketInterface) { request := packet.RMCMessage() - response := nex.NewRMCMessage(authServer) + response := nex.NewRMCMessage(authEndpoint) parameters := request.Parameters - parametersStream := nex.NewByteStreamIn(parameters, authServer) + parametersStream := nex.NewByteStreamIn(parameters, authEndpoint.LibraryVersions(), authEndpoint.ByteStreamSettings()) strUserName := types.NewString("") if err := strUserName.ExtractFrom(parametersStream); err != nil { @@ -75,7 +76,7 @@ func login(packet nex.PRUDPPacketInterface) { pConnectionData.StationURLSpecialProtocols = types.NewStationURL("") pConnectionData.Time = types.NewDateTime(0).Now() - responseStream := nex.NewByteStreamOut(authServer) + responseStream := nex.NewByteStreamOut(authEndpoint.LibraryVersions(), authEndpoint.ByteStreamSettings()) retval.WriteTo(responseStream) pidPrincipal.WriteTo(responseStream) @@ -91,7 +92,7 @@ func login(packet nex.PRUDPPacketInterface) { response.MethodID = request.MethodID response.Parameters = responseStream.Bytes() - responsePacket, _ := nex.NewPRUDPPacketV0(packet.Sender().(*nex.PRUDPConnection), nil) + responsePacket, _ := nex.NewPRUDPPacketV0(authServer, packet.Sender().(*nex.PRUDPConnection), nil) responsePacket.SetType(packet.Type()) responsePacket.AddFlag(nex.FlagHasSize) @@ -109,11 +110,11 @@ func login(packet nex.PRUDPPacketInterface) { func requestTicket(packet nex.PRUDPPacketInterface) { request := packet.RMCMessage() - response := nex.NewRMCMessage(authServer) + response := nex.NewRMCMessage(authEndpoint) parameters := request.Parameters - parametersStream := nex.NewByteStreamIn(parameters, authServer) + parametersStream := nex.NewByteStreamIn(parameters, authEndpoint.LibraryVersions(), authEndpoint.ByteStreamSettings()) idSource := types.NewPID(0) if err := idSource.ExtractFrom(parametersStream); err != nil { @@ -131,7 +132,7 @@ func requestTicket(packet nex.PRUDPPacketInterface) { retval := types.NewQResultSuccess(0x00010001) pbufResponse := types.NewBuffer(generateTicket(sourceAccount, targetAccount, authServer.SessionKeyLength)) - responseStream := nex.NewByteStreamOut(authServer) + responseStream := nex.NewByteStreamOut(authEndpoint.LibraryVersions(), authEndpoint.ByteStreamSettings()) retval.WriteTo(responseStream) pbufResponse.WriteTo(responseStream) @@ -144,7 +145,7 @@ func requestTicket(packet nex.PRUDPPacketInterface) { response.MethodID = request.MethodID response.Parameters = responseStream.Bytes() - responsePacket, _ := nex.NewPRUDPPacketV0(packet.Sender().(*nex.PRUDPConnection), nil) + responsePacket, _ := nex.NewPRUDPPacketV0(authServer, packet.Sender().(*nex.PRUDPConnection), nil) responsePacket.SetType(packet.Type()) responsePacket.AddFlag(nex.FlagHasSize) diff --git a/test/generate_ticket.go b/test/generate_ticket.go index b2023eab..0dcc4462 100644 --- a/test/generate_ticket.go +++ b/test/generate_ticket.go @@ -17,21 +17,21 @@ func generateTicket(source *nex.Account, target *nex.Account, sessionKeyLength i panic(err) } - ticketInternalData := nex.NewKerberosTicketInternalData() + ticketInternalData := nex.NewKerberosTicketInternalData(authServer) serverTime := types.NewDateTime(0).Now() ticketInternalData.Issued = serverTime ticketInternalData.SourcePID = source.PID ticketInternalData.SessionKey = sessionKey - encryptedTicketInternalData, _ := ticketInternalData.Encrypt(targetKey, nex.NewByteStreamOut(authServer)) + encryptedTicketInternalData, _ := ticketInternalData.Encrypt(targetKey, nex.NewByteStreamOut(authServer.LibraryVersions, authServer.ByteStreamSettings)) ticket := nex.NewKerberosTicket() ticket.SessionKey = sessionKey ticket.TargetPID = target.PID ticket.InternalData = types.NewBuffer(encryptedTicketInternalData) - encryptedTicket, _ := ticket.Encrypt(sourceKey, nex.NewByteStreamOut(authServer)) + encryptedTicket, _ := ticket.Encrypt(sourceKey, nex.NewByteStreamOut(authServer.LibraryVersions, authServer.ByteStreamSettings)) return encryptedTicket } diff --git a/test/hpp.go b/test/hpp.go index 1eec2b9c..03ca1bfe 100644 --- a/test/hpp.go +++ b/test/hpp.go @@ -77,7 +77,7 @@ func startHPPServer() { } }) - hppServer.SetDefaultLibraryVersion(nex.NewLibraryVersion(2, 4, 1)) + hppServer.LibraryVersions().SetDefault(nex.NewLibraryVersion(2, 4, 1)) hppServer.SetAccessKey("76f26496") hppServer.AccountDetailsByPID = accountDetailsByPID hppServer.AccountDetailsByUsername = accountDetailsByUsername @@ -91,7 +91,7 @@ func getNotificationURL(packet *nex.HPPPacket) { parameters := request.Parameters - parametersStream := nex.NewByteStreamIn(parameters, hppServer) + parametersStream := nex.NewByteStreamIn(parameters, hppServer.LibraryVersions(), hppServer.ByteStreamSettings()) param := &dataStoreGetNotificationURLParam{} param.PreviousURL = types.NewString("") @@ -103,7 +103,7 @@ func getNotificationURL(packet *nex.HPPPacket) { fmt.Println("[HPP]", param.PreviousURL) - responseStream := nex.NewByteStreamOut(hppServer) + responseStream := nex.NewByteStreamOut(hppServer.LibraryVersions(), hppServer.ByteStreamSettings()) info := &dataStoreReqGetNotificationURLInfo{} info.URL = types.NewString("https://example.com") diff --git a/test/secure.go b/test/secure.go index 613324c3..615ebf53 100644 --- a/test/secure.go +++ b/test/secure.go @@ -10,6 +10,7 @@ import ( ) var secureServer *nex.PRUDPServer +var secureEndpoint *nex.PRUDPEndPoint // * Took these structs out of the protocols lib for convenience @@ -46,13 +47,13 @@ func startSecureServer() { secureServer = nex.NewPRUDPServer() - endpoint := nex.NewPRUDPEndPoint(1) + secureEndpoint = nex.NewPRUDPEndPoint(1) - endpoint.AccountDetailsByPID = accountDetailsByPID - endpoint.AccountDetailsByUsername = accountDetailsByUsername - endpoint.ServerAccount = secureServerAccount + secureEndpoint.AccountDetailsByPID = accountDetailsByPID + secureEndpoint.AccountDetailsByUsername = accountDetailsByUsername + secureEndpoint.ServerAccount = secureServerAccount - endpoint.OnData(func(packet nex.PacketInterface) { + secureEndpoint.OnData(func(packet nex.PacketInterface) { if packet, ok := packet.(nex.PRUDPPacketInterface); ok { request := packet.RMCMessage() @@ -79,21 +80,21 @@ func startSecureServer() { }) secureServer.SetFragmentSize(962) - secureServer.SetDefaultLibraryVersion(nex.NewLibraryVersion(1, 1, 0)) + secureServer.LibraryVersions.SetDefault(nex.NewLibraryVersion(1, 1, 0)) secureServer.SessionKeyLength = 16 - secureServer.SetAccessKey("ridfebb9") - secureServer.BindPRUDPEndPoint(endpoint) + secureServer.AccessKey = "ridfebb9" + secureServer.BindPRUDPEndPoint(secureEndpoint) secureServer.Listen(60001) } func registerEx(packet nex.PRUDPPacketInterface) { request := packet.RMCMessage() - response := nex.NewRMCMessage(secureServer) + response := nex.NewRMCMessage(secureEndpoint) connection := packet.Sender().(*nex.PRUDPConnection) parameters := request.Parameters - parametersStream := nex.NewByteStreamIn(parameters, secureServer) + parametersStream := nex.NewByteStreamIn(parameters, secureEndpoint.LibraryVersions(), secureEndpoint.ByteStreamSettings()) vecMyURLs := types.NewList[*types.StationURL]() vecMyURLs.Type = types.NewStationURL("") @@ -116,7 +117,7 @@ func registerEx(packet nex.PRUDPPacketInterface) { retval := types.NewQResultSuccess(0x00010001) localStationURL := types.NewString(localStation.EncodeToString()) - responseStream := nex.NewByteStreamOut(secureServer) + responseStream := nex.NewByteStreamOut(secureEndpoint.LibraryVersions(), secureEndpoint.ByteStreamSettings()) retval.WriteTo(responseStream) responseStream.WritePrimitiveUInt32LE(connection.ID) @@ -130,7 +131,7 @@ func registerEx(packet nex.PRUDPPacketInterface) { response.MethodID = request.MethodID response.Parameters = responseStream.Bytes() - responsePacket, _ := nex.NewPRUDPPacketV0(connection, nil) + responsePacket, _ := nex.NewPRUDPPacketV0(secureServer, connection, nil) responsePacket.SetType(packet.Type()) responsePacket.AddFlag(nex.FlagHasSize) @@ -148,9 +149,9 @@ func registerEx(packet nex.PRUDPPacketInterface) { func updateAndGetAllInformation(packet nex.PRUDPPacketInterface) { request := packet.RMCMessage() - response := nex.NewRMCMessage(secureServer) + response := nex.NewRMCMessage(secureEndpoint) - responseStream := nex.NewByteStreamOut(secureServer) + responseStream := nex.NewByteStreamOut(secureEndpoint.LibraryVersions(), secureEndpoint.ByteStreamSettings()) (&principalPreference{ ShowOnlinePresence: types.NewPrimitiveBool(true), @@ -178,7 +179,7 @@ func updateAndGetAllInformation(packet nex.PRUDPPacketInterface) { response.MethodID = request.MethodID response.Parameters = responseStream.Bytes() - responsePacket, _ := nex.NewPRUDPPacketV0(packet.Sender().(*nex.PRUDPConnection), nil) + responsePacket, _ := nex.NewPRUDPPacketV0(secureServer, packet.Sender().(*nex.PRUDPConnection), nil) responsePacket.SetType(packet.Type()) responsePacket.AddFlag(nex.FlagHasSize) @@ -196,9 +197,9 @@ func updateAndGetAllInformation(packet nex.PRUDPPacketInterface) { func checkSettingStatus(packet nex.PRUDPPacketInterface) { request := packet.RMCMessage() - response := nex.NewRMCMessage(secureServer) + response := nex.NewRMCMessage(secureEndpoint) - responseStream := nex.NewByteStreamOut(secureServer) + responseStream := nex.NewByteStreamOut(secureEndpoint.LibraryVersions(), secureEndpoint.ByteStreamSettings()) responseStream.WritePrimitiveUInt8(0) // * Unknown @@ -210,7 +211,7 @@ func checkSettingStatus(packet nex.PRUDPPacketInterface) { response.MethodID = request.MethodID response.Parameters = responseStream.Bytes() - responsePacket, _ := nex.NewPRUDPPacketV0(packet.Sender().(*nex.PRUDPConnection), nil) + responsePacket, _ := nex.NewPRUDPPacketV0(secureServer, packet.Sender().(*nex.PRUDPConnection), nil) responsePacket.SetType(packet.Type()) responsePacket.AddFlag(nex.FlagHasSize) @@ -228,7 +229,7 @@ func checkSettingStatus(packet nex.PRUDPPacketInterface) { func updatePresence(packet nex.PRUDPPacketInterface) { request := packet.RMCMessage() - response := nex.NewRMCMessage(secureServer) + response := nex.NewRMCMessage(secureEndpoint) response.IsSuccess = true response.IsRequest = false @@ -237,7 +238,7 @@ func updatePresence(packet nex.PRUDPPacketInterface) { response.CallID = request.CallID response.MethodID = request.MethodID - responsePacket, _ := nex.NewPRUDPPacketV0(packet.Sender().(*nex.PRUDPConnection), nil) + responsePacket, _ := nex.NewPRUDPPacketV0(secureServer, packet.Sender().(*nex.PRUDPConnection), nil) responsePacket.SetType(packet.Type()) responsePacket.AddFlag(nex.FlagHasSize) diff --git a/websocket_server.go b/websocket_server.go index e7f92931..ceb5ea49 100644 --- a/websocket_server.go +++ b/websocket_server.go @@ -75,10 +75,10 @@ func (ws *WebSocketServer) init() { ws.upgrader = gws.NewUpgrader(&wsEventHandler{ prudpServer: ws.prudpServer, }, &gws.ServerOption{ - ReadAsyncEnabled: true, // * Parallel message processing - Recovery: gws.Recovery, // * Exception recovery - ReadBufferSize: 64000, - WriteBufferSize: 64000, + ParallelEnabled: true, // * Parallel message processing + Recovery: gws.Recovery, // * Exception recovery + ReadBufferSize: 64000, + WriteBufferSize: 64000, }) ws.mux = http.NewServeMux() From 857fb68f629efe1719b5a2b95a05811c9fb9de4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20L=C3=B3pez=20Guimaraes?= Date: Sun, 11 Feb 2024 00:56:06 +0000 Subject: [PATCH 138/178] prudp: Support legacy NEX 1 clients --- prudp_endpoint.go | 40 +++++++++++++++++++++++++++++++++++++--- prudp_packet_v0.go | 2 +- prudp_v0_settings.go | 4 ++++ 3 files changed, 42 insertions(+), 4 deletions(-) diff --git a/prudp_endpoint.go b/prudp_endpoint.go index 74f3bb99..7f4e8241 100644 --- a/prudp_endpoint.go +++ b/prudp_endpoint.go @@ -24,6 +24,7 @@ type PRUDPEndPoint struct { ServerAccount *Account AccountDetailsByPID func(pid *types.PID) (*Account, *Error) AccountDetailsByUsername func(username string) (*Account, *Error) + IsSecureEndpoint bool } // RegisterServiceProtocol registers a NEX service with the endpoint @@ -289,10 +290,25 @@ func (pep *PRUDPEndPoint) handleConnect(packet PRUDPPacketInterface) { payload := make([]byte, 0) - if len(packet.Payload()) != 0 { - sessionKey, pid, checkValue, err := pep.readKerberosTicket(packet.Payload()) + if pep.IsSecureEndpoint { + var decryptedPayload []byte + if pep.Server.PRUDPV0Settings.EncryptedConnect { + decryptedPayload = packet.decryptPayload() + + } else { + decryptedPayload = packet.Payload() + } + + decompressedPayload, err := connection.StreamSettings.CompressionAlgorithm.Decompress(decryptedPayload) if err != nil { logger.Error(err.Error()) + return + } + + sessionKey, pid, checkValue, err := pep.readKerberosTicket(decompressedPayload) + if err != nil { + logger.Error(err.Error()) + return } connection.SetPID(pid) @@ -311,7 +327,24 @@ func (pep *PRUDPEndPoint) handleConnect(packet PRUDPPacketInterface) { payload = stream.Bytes() } - ack.SetPayload(payload) + compressedPayload, err := connection.StreamSettings.CompressionAlgorithm.Compress(payload) + if err != nil { + logger.Error(err.Error()) + return + } + + var encryptedPayload []byte + if pep.Server.PRUDPV0Settings.EncryptedConnect { + encryptedPayload, err = connection.StreamSettings.EncryptionAlgorithm.Encrypt(compressedPayload) + if err != nil { + logger.Error(err.Error()) + return + } + } else { + encryptedPayload = compressedPayload + } + + ack.SetPayload(encryptedPayload) ack.setSignature(ack.calculateSignature([]byte{}, packet.getConnectionSignature())) pep.emit("connect", ack) @@ -645,5 +678,6 @@ func NewPRUDPEndPoint(streamID uint8) *PRUDPEndPoint { packetEventHandlers: make(map[string][]func(PacketInterface)), connectionEndedEventHandlers: make([]func(connection *PRUDPConnection), 0), ConnectionIDCounter: NewCounter[uint32](0), + IsSecureEndpoint: false, } } diff --git a/prudp_packet_v0.go b/prudp_packet_v0.go index 7aca22a8..0ec33800 100644 --- a/prudp_packet_v0.go +++ b/prudp_packet_v0.go @@ -309,7 +309,7 @@ func defaultPRUDPv0ConnectionSignature(packet *PRUDPPacketV0, addr net.Addr) ([] } func defaultPRUDPv0CalculateSignature(packet *PRUDPPacketV0, sessionKey, connectionSignature []byte) []byte { - if !packet.server.PRUDPV0Settings.IsQuazalMode { + if !packet.server.PRUDPV0Settings.LegacyConnectionSignature { if packet.packetType == DataPacket { return packet.server.PRUDPV0Settings.DataSignatureCalculator(packet, sessionKey) } diff --git a/prudp_v0_settings.go b/prudp_v0_settings.go index b76d2df5..e768f3c3 100644 --- a/prudp_v0_settings.go +++ b/prudp_v0_settings.go @@ -7,6 +7,8 @@ import "net" // PRUDPV0Settings defines settings for how to handle aspects of PRUDPv0 packets type PRUDPV0Settings struct { IsQuazalMode bool + EncryptedConnect bool + LegacyConnectionSignature bool UseEnhancedChecksum bool ConnectionSignatureCalculator func(packet *PRUDPPacketV0, addr net.Addr) ([]byte, error) SignatureCalculator func(packet *PRUDPPacketV0, sessionKey, connectionSignature []byte) []byte @@ -18,6 +20,8 @@ type PRUDPV0Settings struct { func NewPRUDPV0Settings() *PRUDPV0Settings { return &PRUDPV0Settings{ IsQuazalMode: false, + EncryptedConnect: false, + LegacyConnectionSignature: false, UseEnhancedChecksum: false, ConnectionSignatureCalculator: defaultPRUDPv0ConnectionSignature, SignatureCalculator: defaultPRUDPv0CalculateSignature, From 9215e65ec3beddf419db40b3ddc5687b4532d73a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20L=C3=B3pez=20Guimaraes?= Date: Sun, 11 Feb 2024 14:44:16 +0000 Subject: [PATCH 139/178] chore: bugfixes and cleanup --- prudp_endpoint.go | 6 +++++- types/data.go | 6 +----- types/map.go | 2 +- types/structure.go | 2 -- 4 files changed, 7 insertions(+), 9 deletions(-) diff --git a/prudp_endpoint.go b/prudp_endpoint.go index 7f4e8241..7b2121bb 100644 --- a/prudp_endpoint.go +++ b/prudp_endpoint.go @@ -293,7 +293,11 @@ func (pep *PRUDPEndPoint) handleConnect(packet PRUDPPacketInterface) { if pep.IsSecureEndpoint { var decryptedPayload []byte if pep.Server.PRUDPV0Settings.EncryptedConnect { - decryptedPayload = packet.decryptPayload() + decryptedPayload, err = connection.StreamSettings.EncryptionAlgorithm.Decrypt(packet.Payload()) + if err != nil { + logger.Error(err.Error()) + return + } } else { decryptedPayload = packet.Payload() diff --git a/types/data.go b/types/data.go index f20ef34e..bd5227d6 100644 --- a/types/data.go +++ b/types/data.go @@ -41,11 +41,7 @@ func (d *Data) Equals(o RVType) bool { other := o.(*Data) - if d.StructureVersion == other.StructureVersion { - return false - } - - return d.StructureContentLength == other.StructureContentLength + return d.StructureVersion == other.StructureVersion } // String returns a string representation of the struct diff --git a/types/map.go b/types/map.go index 5f0a6c49..5a1f2a20 100644 --- a/types/map.go +++ b/types/map.go @@ -51,7 +51,7 @@ func (m *Map[K, V]) ExtractFrom(readable Readable) error { return err } - keys = append(keys, value.(K)) + keys = append(keys, key.(K)) values = append(values, value.(V)) } diff --git a/types/structure.go b/types/structure.go index aa521c63..35964415 100644 --- a/types/structure.go +++ b/types/structure.go @@ -8,7 +8,6 @@ import ( // Structure represents a Quazal Rendez-Vous/NEX Structure (custom class) base struct. type Structure struct { StructureVersion uint8 - StructureContentLength uint32 } // ExtractHeaderFrom extracts the structure header from the given readable @@ -29,7 +28,6 @@ func (s *Structure) ExtractHeaderFrom(readable Readable) error { } s.StructureVersion = version - s.StructureContentLength = contentLength } return nil From 463b5e21f7c112b570d9e6e96a040fefc8f23a5c Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sun, 11 Feb 2024 16:58:22 -0500 Subject: [PATCH 140/178] chore: rename ClientInterface to ConnectionInterface --- client_interface.go => connection_interface.go | 4 ++-- hpp_packet.go | 2 +- packet_interface.go | 2 +- prudp_packet.go | 4 ++-- prudp_packet_interface.go | 4 ++-- 5 files changed, 8 insertions(+), 8 deletions(-) rename client_interface.go => connection_interface.go (67%) diff --git a/client_interface.go b/connection_interface.go similarity index 67% rename from client_interface.go rename to connection_interface.go index dfc16556..e86d8a35 100644 --- a/client_interface.go +++ b/connection_interface.go @@ -7,8 +7,8 @@ import ( "github.com/PretendoNetwork/nex-go/types" ) -// ClientInterface defines all the methods a client should have regardless of server type -type ClientInterface interface { +// ConnectionInterface defines all the methods a connection should have regardless of server type +type ConnectionInterface interface { Endpoint() EndpointInterface Address() net.Addr PID() *types.PID diff --git a/hpp_packet.go b/hpp_packet.go index 8c4c2a2a..3b433a37 100644 --- a/hpp_packet.go +++ b/hpp_packet.go @@ -20,7 +20,7 @@ type HPPPacket struct { } // Sender returns the Client who sent the packet -func (p *HPPPacket) Sender() ClientInterface { +func (p *HPPPacket) Sender() ConnectionInterface { return p.sender } diff --git a/packet_interface.go b/packet_interface.go index a2bcc397..3d004e2f 100644 --- a/packet_interface.go +++ b/packet_interface.go @@ -2,7 +2,7 @@ package nex // PacketInterface defines all the methods a packet for both PRUDP and HPP should have type PacketInterface interface { - Sender() ClientInterface + Sender() ConnectionInterface Payload() []byte SetPayload(payload []byte) RMCMessage() *RMCMessage diff --git a/prudp_packet.go b/prudp_packet.go index 535628a9..ad261f49 100644 --- a/prudp_packet.go +++ b/prudp_packet.go @@ -23,12 +23,12 @@ type PRUDPPacket struct { } // SetSender sets the Client who sent the packet -func (p *PRUDPPacket) SetSender(sender ClientInterface) { +func (p *PRUDPPacket) SetSender(sender ConnectionInterface) { p.sender = sender.(*PRUDPConnection) } // Sender returns the Client who sent the packet -func (p *PRUDPPacket) Sender() ClientInterface { +func (p *PRUDPPacket) Sender() ConnectionInterface { return p.sender } diff --git a/prudp_packet_interface.go b/prudp_packet_interface.go index 0349a002..d1e38dbb 100644 --- a/prudp_packet_interface.go +++ b/prudp_packet_interface.go @@ -7,8 +7,8 @@ type PRUDPPacketInterface interface { Copy() PRUDPPacketInterface Version() int Bytes() []byte - SetSender(sender ClientInterface) - Sender() ClientInterface + SetSender(sender ConnectionInterface) + Sender() ConnectionInterface Flags() uint16 HasFlag(flag uint16) bool AddFlag(flag uint16) From ff03d9956848f4687b36b9e09784d114c69e698f Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Mon, 12 Feb 2024 15:07:49 -0500 Subject: [PATCH 141/178] update: add Packet property to Error and OnError event hooks --- endpoint_interface.go | 1 + error.go | 9 ++++----- hpp_server.go | 23 ++++++++++++++++------- prudp_endpoint.go | 15 +++++++++++++++ 4 files changed, 36 insertions(+), 12 deletions(-) diff --git a/endpoint_interface.go b/endpoint_interface.go index b8b2dfee..796f2f3b 100644 --- a/endpoint_interface.go +++ b/endpoint_interface.go @@ -8,4 +8,5 @@ type EndpointInterface interface { LibraryVersions() *LibraryVersions ByteStreamSettings() *ByteStreamSettings SetByteStreamSettings(settings *ByteStreamSettings) + EmitError(err *Error) } diff --git a/error.go b/error.go index 62467f3a..b98f5138 100644 --- a/error.go +++ b/error.go @@ -2,12 +2,11 @@ package nex import "fmt" -// TODO - Add more metadata? Like the sender or whatever? - -// Error is a custom error type implementing the error interface +// Error is a custom error type implementing the error interface. type Error struct { - ResultCode uint32 - Message string + ResultCode uint32 // * NEX result code. See result_codes.go for details + Message string // * The error base message + Packet PacketInterface // * The packet which caused the error. May not always be present } // Error satisfies the error interface and prints the underlying error diff --git a/hpp_server.go b/hpp_server.go index a4522c2a..1650cc4e 100644 --- a/hpp_server.go +++ b/hpp_server.go @@ -12,13 +12,14 @@ import ( // HPPServer represents a bare-bones HPP server type HPPServer struct { - server *http.Server - accessKey string - libraryVersions *LibraryVersions - dataHandlers []func(packet PacketInterface) - byteStreamSettings *ByteStreamSettings - AccountDetailsByPID func(pid *types.PID) (*Account, *Error) - AccountDetailsByUsername func(username string) (*Account, *Error) + server *http.Server + accessKey string + libraryVersions *LibraryVersions + dataHandlers []func(packet PacketInterface) + errorEventHandlers []func(err *Error) + byteStreamSettings *ByteStreamSettings + AccountDetailsByPID func(pid *types.PID) (*Account, *Error) + AccountDetailsByUsername func(username string) (*Account, *Error) } // RegisterServiceProtocol registers a NEX service with the HPP server @@ -31,6 +32,13 @@ func (s *HPPServer) OnData(handler func(packet PacketInterface)) { s.dataHandlers = append(s.dataHandlers, handler) } +// EmitError calls all the endpoints error event handlers with the provided error +func (s *HPPServer) EmitError(err *Error) { + for _, handler := range s.errorEventHandlers { + go handler(err) + } +} + func (s *HPPServer) handleRequest(w http.ResponseWriter, req *http.Request) { if req.Method != "POST" { w.WriteHeader(http.StatusBadRequest) @@ -189,6 +197,7 @@ func (s *HPPServer) SetByteStreamSettings(byteStreamSettings *ByteStreamSettings func NewHPPServer() *HPPServer { s := &HPPServer{ dataHandlers: make([]func(packet PacketInterface), 0), + errorEventHandlers: make([]func(err *Error), 0), libraryVersions: NewLibraryVersions(), byteStreamSettings: NewByteStreamSettings(), } diff --git a/prudp_endpoint.go b/prudp_endpoint.go index 7b2121bb..59a20a8f 100644 --- a/prudp_endpoint.go +++ b/prudp_endpoint.go @@ -20,6 +20,7 @@ type PRUDPEndPoint struct { Connections *MutexMap[string, *PRUDPConnection] packetEventHandlers map[string][]func(packet PacketInterface) connectionEndedEventHandlers []func(connection *PRUDPConnection) + errorEventHandlers []func(err *Error) ConnectionIDCounter *Counter[uint32] ServerAccount *Account AccountDetailsByPID func(pid *types.PID) (*Account, *Error) @@ -37,6 +38,12 @@ func (pep *PRUDPEndPoint) OnData(handler func(packet PacketInterface)) { pep.on("data", handler) } +// OnError adds an event handler which is fired when an error occurs on the endpoint +func (pep *PRUDPEndPoint) OnError(handler func(err *Error)) { + // * "Ended" events are a special case, so handle them separately + pep.errorEventHandlers = append(pep.errorEventHandlers, handler) +} + // OnDisconnect adds an event handler which is fired when a new DISCONNECT packet is received // // To handle a connection being removed from the server, see OnConnectionEnded which fires on more cases @@ -74,6 +81,13 @@ func (pep *PRUDPEndPoint) emitConnectionEnded(connection *PRUDPConnection) { } } +// EmitError calls all the endpoints error event handlers with the provided error +func (pep *PRUDPEndPoint) EmitError(err *Error) { + for _, handler := range pep.errorEventHandlers { + go handler(err) + } +} + func (pep *PRUDPEndPoint) processPacket(packet PRUDPPacketInterface, socket *SocketConnection) { streamType := packet.SourceVirtualPortStreamType() streamID := packet.SourceVirtualPortStreamID() @@ -681,6 +695,7 @@ func NewPRUDPEndPoint(streamID uint8) *PRUDPEndPoint { Connections: NewMutexMap[string, *PRUDPConnection](), packetEventHandlers: make(map[string][]func(PacketInterface)), connectionEndedEventHandlers: make([]func(connection *PRUDPConnection), 0), + errorEventHandlers: make([]func(err *Error), 0), ConnectionIDCounter: NewCounter[uint32](0), IsSecureEndpoint: false, } From 42d8a5380c783be8f695c71a74379281764561e3 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sun, 18 Feb 2024 12:02:00 -0500 Subject: [PATCH 142/178] chore: set protocols endpoint in RegisterServiceProtocol --- hpp_server.go | 1 + prudp_endpoint.go | 1 + service_protocol.go | 2 ++ 3 files changed, 4 insertions(+) diff --git a/hpp_server.go b/hpp_server.go index 1650cc4e..4abec495 100644 --- a/hpp_server.go +++ b/hpp_server.go @@ -24,6 +24,7 @@ type HPPServer struct { // RegisterServiceProtocol registers a NEX service with the HPP server func (s *HPPServer) RegisterServiceProtocol(protocol ServiceProtocol) { + protocol.SetEndpoint(s) s.OnData(protocol.HandlePacket) } diff --git a/prudp_endpoint.go b/prudp_endpoint.go index 59a20a8f..54b42d19 100644 --- a/prudp_endpoint.go +++ b/prudp_endpoint.go @@ -30,6 +30,7 @@ type PRUDPEndPoint struct { // RegisterServiceProtocol registers a NEX service with the endpoint func (pep *PRUDPEndPoint) RegisterServiceProtocol(protocol ServiceProtocol) { + protocol.SetEndpoint(pep) pep.OnData(protocol.HandlePacket) } diff --git a/service_protocol.go b/service_protocol.go index bc979ec6..f8c8a3a5 100644 --- a/service_protocol.go +++ b/service_protocol.go @@ -3,4 +3,6 @@ package nex // ServiceProtocol represents a NEX service capable of handling PRUDP/HPP packets type ServiceProtocol interface { HandlePacket(packet PacketInterface) + Endpoint() EndpointInterface + SetEndpoint(endpoint EndpointInterface) } From 5f04a25d1d6171cc40c10751af8d3437f1da52d6 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sun, 18 Feb 2024 22:23:55 -0500 Subject: [PATCH 143/178] prudp: added ConnectionState checks --- connection_state.go | 25 +++++++++++++++++++++++++ prudp_connection.go | 6 ++++-- prudp_endpoint.go | 27 +++++++++++++++++++++++++++ 3 files changed, 56 insertions(+), 2 deletions(-) create mode 100644 connection_state.go diff --git a/connection_state.go b/connection_state.go new file mode 100644 index 00000000..a1cd10ef --- /dev/null +++ b/connection_state.go @@ -0,0 +1,25 @@ +package nex + +// ConnectionState is an implementation of nn::nex::EndPoint::_ConnectionState. +// +// The state represents a PRUDP clients connection state. The original Rendez-Vous +// library supports states 0-6, though NEX only supports 0-4. The remaining 2 are +// unknown +type ConnectionState uint8 + +const ( + // StateNotConnected indicates the client has not established a full PRUDP connection + StateNotConnected ConnectionState = iota + + // StateConnecting indicates the client is attempting to establish a PRUDP connection + StateConnecting + + // StateConnected indicates the client has established a full PRUDP connection + StateConnected + + // StateDisconnecting indicates the client is disconnecting from a PRUDP connection. Currently unused + StateDisconnecting + + // StateFaulty indicates the client connection is faulty. Currently unused + StateFaulty +) diff --git a/prudp_connection.go b/prudp_connection.go index a5a1036f..005cd17a 100644 --- a/prudp_connection.go +++ b/prudp_connection.go @@ -13,8 +13,9 @@ import ( // Does not necessarily represent a socket connection. // A single network socket may be used to open multiple PRUDP virtual connections type PRUDPConnection struct { - Socket *SocketConnection // * The connections parent socket - endpoint *PRUDPEndPoint // * The PRUDP endpoint the connection is connected to + Socket *SocketConnection // * The connections parent socket + endpoint *PRUDPEndPoint // * The PRUDP endpoint the connection is connected to + ConnectionState ConnectionState ID uint32 // * Connection ID SessionID uint8 // * Random value generated at the start of the session. Client and server IDs do not need to match ServerSessionID uint8 // * Random value generated at the start of the session. Client and server IDs do not need to match @@ -211,6 +212,7 @@ func (pc *PRUDPConnection) stopHeartbeatTimers() { func NewPRUDPConnection(socket *SocketConnection) *PRUDPConnection { pc := &PRUDPConnection{ Socket: socket, + ConnectionState: StateNotConnected, pid: types.NewPID(0), slidingWindows: NewMutexMap[uint8, *SlidingWindow](), outgoingUnreliableSequenceIDCounter: NewCounter[uint16](1), diff --git a/prudp_endpoint.go b/prudp_endpoint.go index 54b42d19..7e8893c2 100644 --- a/prudp_endpoint.go +++ b/prudp_endpoint.go @@ -214,6 +214,12 @@ func (pep *PRUDPEndPoint) handleMultiAcknowledgment(packet PRUDPPacketInterface) func (pep *PRUDPEndPoint) handleSyn(packet PRUDPPacketInterface) { connection := packet.Sender().(*PRUDPConnection) + if connection.ConnectionState != StateNotConnected { + // TODO - Log this? + // * Connection is in a bad state, drop the packet and let it die + return + } + var ack PRUDPPacketInterface if packet.Version() == 2 { @@ -249,6 +255,8 @@ func (pep *PRUDPEndPoint) handleSyn(packet PRUDPPacketInterface) { ack.supportedFunctions = pep.Server.SupportedFunctions & packet.(*PRUDPPacketV1).supportedFunctions } + connection.ConnectionState = StateConnecting + pep.emit("syn", ack) pep.Server.sendRaw(connection.Socket, ack.Bytes()) @@ -257,6 +265,12 @@ func (pep *PRUDPEndPoint) handleSyn(packet PRUDPPacketInterface) { func (pep *PRUDPEndPoint) handleConnect(packet PRUDPPacketInterface) { connection := packet.Sender().(*PRUDPConnection) + if connection.ConnectionState != StateConnecting { + // TODO - Log this? + // * Connection is in a bad state, drop the packet and let it die + return + } + var ack PRUDPPacketInterface if packet.Version() == 2 { @@ -366,12 +380,22 @@ func (pep *PRUDPEndPoint) handleConnect(packet PRUDPPacketInterface) { ack.SetPayload(encryptedPayload) ack.setSignature(ack.calculateSignature([]byte{}, packet.getConnectionSignature())) + connection.ConnectionState = StateConnected + pep.emit("connect", ack) pep.Server.sendRaw(connection.Socket, ack.Bytes()) } func (pep *PRUDPEndPoint) handleData(packet PRUDPPacketInterface) { + connection := packet.Sender().(*PRUDPConnection) + + if connection.ConnectionState != StateConnected { + // TODO - Log this? + // * Connection is in a bad state, drop the packet and let it die + return + } + if packet.HasFlag(FlagReliable) { pep.handleReliable(packet) } else { @@ -380,6 +404,9 @@ func (pep *PRUDPEndPoint) handleData(packet PRUDPPacketInterface) { } func (pep *PRUDPEndPoint) handleDisconnect(packet PRUDPPacketInterface) { + // TODO - Should we check the state here, or just let the connection disconnect at any time? + // TODO - Should we bother to set the connections state here? It's being destroyed anyway + if packet.HasFlag(FlagNeedsAck) { pep.acknowledgePacket(packet) } From 88a2d4bbac307cf855fba89ac929d670764ee3b1 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sun, 18 Feb 2024 22:35:08 -0500 Subject: [PATCH 144/178] rmc: add a way to enable verbose RMC messages --- endpoint_interface.go | 2 ++ hpp_server.go | 11 +++++++++++ prudp_endpoint.go | 10 ++++++++++ prudp_server.go | 1 + rmc_message.go | 37 ++++++++++++++++++------------------- 5 files changed, 42 insertions(+), 19 deletions(-) diff --git a/endpoint_interface.go b/endpoint_interface.go index 796f2f3b..9cba7189 100644 --- a/endpoint_interface.go +++ b/endpoint_interface.go @@ -8,5 +8,7 @@ type EndpointInterface interface { LibraryVersions() *LibraryVersions ByteStreamSettings() *ByteStreamSettings SetByteStreamSettings(settings *ByteStreamSettings) + UseVerboseRMC() bool // TODO - Move this to a RMCSettings struct? + EnableVerboseRMC(enabled bool) EmitError(err *Error) } diff --git a/hpp_server.go b/hpp_server.go index 4abec495..61f6085a 100644 --- a/hpp_server.go +++ b/hpp_server.go @@ -20,6 +20,7 @@ type HPPServer struct { byteStreamSettings *ByteStreamSettings AccountDetailsByPID func(pid *types.PID) (*Account, *Error) AccountDetailsByUsername func(username string) (*Account, *Error) + useVerboseRMC bool } // RegisterServiceProtocol registers a NEX service with the HPP server @@ -194,6 +195,16 @@ func (s *HPPServer) SetByteStreamSettings(byteStreamSettings *ByteStreamSettings s.byteStreamSettings = byteStreamSettings } +// UseVerboseRMC checks whether or not the endpoint uses verbose RMC +func (s *HPPServer) UseVerboseRMC() bool { + return s.useVerboseRMC +} + +// EnableVerboseRMC enable or disables the use of verbose RMC +func (s *HPPServer) EnableVerboseRMC(enable bool) { + s.useVerboseRMC = enable +} + // NewHPPServer returns a new HPP server func NewHPPServer() *HPPServer { s := &HPPServer{ diff --git a/prudp_endpoint.go b/prudp_endpoint.go index 7e8893c2..2a393172 100644 --- a/prudp_endpoint.go +++ b/prudp_endpoint.go @@ -715,6 +715,16 @@ func (pep *PRUDPEndPoint) SetByteStreamSettings(byteStreamSettings *ByteStreamSe pep.Server.ByteStreamSettings = byteStreamSettings } +// UseVerboseRMC checks whether or not the endpoint uses verbose RMC +func (pep *PRUDPEndPoint) UseVerboseRMC() bool { + return pep.Server.UseVerboseRMC +} + +// EnableVerboseRMC enable or disables the use of verbose RMC +func (pep *PRUDPEndPoint) EnableVerboseRMC(enable bool) { + pep.Server.UseVerboseRMC = enable +} + // NewPRUDPEndPoint returns a new PRUDPEndPoint for a server on the provided stream ID func NewPRUDPEndPoint(streamID uint8) *PRUDPEndPoint { return &PRUDPEndPoint{ diff --git a/prudp_server.go b/prudp_server.go index f84899f9..a1464d9a 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -27,6 +27,7 @@ type PRUDPServer struct { LibraryVersions *LibraryVersions ByteStreamSettings *ByteStreamSettings PRUDPV0Settings *PRUDPV0Settings + UseVerboseRMC bool } // BindPRUDPEndPoint binds a provided PRUDPEndPoint to the server diff --git a/rmc_message.go b/rmc_message.go index 147582b7..26276c52 100644 --- a/rmc_message.go +++ b/rmc_message.go @@ -9,8 +9,7 @@ import ( // RMCMessage represents a message in the RMC (Remote Method Call) protocol type RMCMessage struct { - Server EndpointInterface - VerboseMode bool // * Determines whether or not to encode the message using the "verbose" encoding method + Endpoint EndpointInterface IsRequest bool // * Indicates if the message is a request message (true) or response message (false) IsSuccess bool // * Indicates if the message is a success message (true) for a response message IsHPP bool // * Indicates if the message is an HPP message @@ -27,7 +26,7 @@ type RMCMessage struct { // Copy copies the message into a new RMCMessage func (rmc *RMCMessage) Copy() *RMCMessage { - copied := NewRMCMessage(rmc.Server) + copied := NewRMCMessage(rmc.Endpoint) copied.IsRequest = rmc.IsRequest copied.IsSuccess = rmc.IsSuccess @@ -48,7 +47,7 @@ func (rmc *RMCMessage) Copy() *RMCMessage { // FromBytes decodes an RMCMessage from the given byte slice. func (rmc *RMCMessage) FromBytes(data []byte) error { - if rmc.VerboseMode { + if rmc.Endpoint.UseVerboseRMC() { return rmc.decodeVerbose(data) } else { return rmc.decodePacked(data) @@ -56,7 +55,7 @@ func (rmc *RMCMessage) FromBytes(data []byte) error { } func (rmc *RMCMessage) decodePacked(data []byte) error { - stream := NewByteStreamIn(data, rmc.Server.LibraryVersions(), rmc.Server.ByteStreamSettings()) + stream := NewByteStreamIn(data, rmc.Endpoint.LibraryVersions(), rmc.Endpoint.ByteStreamSettings()) length, err := stream.ReadPrimitiveUInt32LE() if err != nil { @@ -143,7 +142,7 @@ func (rmc *RMCMessage) decodePacked(data []byte) error { } func (rmc *RMCMessage) decodeVerbose(data []byte) error { - stream := NewByteStreamIn(data, rmc.Server.LibraryVersions(), rmc.Server.ByteStreamSettings()) + stream := NewByteStreamIn(data, rmc.Endpoint.LibraryVersions(), rmc.Endpoint.ByteStreamSettings()) length, err := stream.ReadPrimitiveUInt32LE() if err != nil { @@ -224,7 +223,7 @@ func (rmc *RMCMessage) decodeVerbose(data []byte) error { // Bytes serializes the RMCMessage to a byte slice. func (rmc *RMCMessage) Bytes() []byte { - if rmc.VerboseMode { + if rmc.Endpoint.UseVerboseRMC() { return rmc.encodeVerbose() } else { return rmc.encodePacked() @@ -232,7 +231,7 @@ func (rmc *RMCMessage) Bytes() []byte { } func (rmc *RMCMessage) encodePacked() []byte { - stream := NewByteStreamOut(rmc.Server.LibraryVersions(), rmc.Server.ByteStreamSettings()) + stream := NewByteStreamOut(rmc.Endpoint.LibraryVersions(), rmc.Endpoint.ByteStreamSettings()) // * RMC requests have their protocol IDs ORed with 0x80 var protocolIDFlag uint16 = 0x80 @@ -279,7 +278,7 @@ func (rmc *RMCMessage) encodePacked() []byte { serialized := stream.Bytes() - message := NewByteStreamOut(rmc.Server.LibraryVersions(), rmc.Server.ByteStreamSettings()) + message := NewByteStreamOut(rmc.Endpoint.LibraryVersions(), rmc.Endpoint.ByteStreamSettings()) message.WritePrimitiveUInt32LE(uint32(len(serialized))) message.Grow(int64(len(serialized))) @@ -289,7 +288,7 @@ func (rmc *RMCMessage) encodePacked() []byte { } func (rmc *RMCMessage) encodeVerbose() []byte { - stream := NewByteStreamOut(rmc.Server.LibraryVersions(), rmc.Server.ByteStreamSettings()) + stream := NewByteStreamOut(rmc.Endpoint.LibraryVersions(), rmc.Endpoint.ByteStreamSettings()) rmc.ProtocolName.WriteTo(stream) stream.WritePrimitiveBool(rmc.IsRequest) @@ -328,7 +327,7 @@ func (rmc *RMCMessage) encodeVerbose() []byte { serialized := stream.Bytes() - message := NewByteStreamOut(rmc.Server.LibraryVersions(), rmc.Server.ByteStreamSettings()) + message := NewByteStreamOut(rmc.Endpoint.LibraryVersions(), rmc.Endpoint.ByteStreamSettings()) message.WritePrimitiveUInt32LE(uint32(len(serialized))) message.Grow(int64(len(serialized))) @@ -338,23 +337,23 @@ func (rmc *RMCMessage) encodeVerbose() []byte { } // NewRMCMessage returns a new generic RMC Message -func NewRMCMessage(server EndpointInterface) *RMCMessage { +func NewRMCMessage(endpoint EndpointInterface) *RMCMessage { return &RMCMessage{ - Server: server, + Endpoint: endpoint, } } // NewRMCRequest returns a new blank RMCRequest -func NewRMCRequest(server EndpointInterface) *RMCMessage { +func NewRMCRequest(endpoint EndpointInterface) *RMCMessage { return &RMCMessage{ - Server: server, + Endpoint: endpoint, IsRequest: true, } } // NewRMCSuccess returns a new RMC Message configured as a success response -func NewRMCSuccess(server EndpointInterface, parameters []byte) *RMCMessage { - message := NewRMCMessage(server) +func NewRMCSuccess(endpoint EndpointInterface, parameters []byte) *RMCMessage { + message := NewRMCMessage(endpoint) message.IsRequest = false message.IsSuccess = true message.Parameters = parameters @@ -363,12 +362,12 @@ func NewRMCSuccess(server EndpointInterface, parameters []byte) *RMCMessage { } // NewRMCError returns a new RMC Message configured as a error response -func NewRMCError(server EndpointInterface, errorCode uint32) *RMCMessage { +func NewRMCError(endpoint EndpointInterface, errorCode uint32) *RMCMessage { if int(errorCode)&errorMask == 0 { errorCode = uint32(int(errorCode) | errorMask) } - message := NewRMCMessage(server) + message := NewRMCMessage(endpoint) message.IsRequest = false message.IsSuccess = false message.ErrorCode = errorCode From 8fdee7cc60d6fab831fbe8b0755a86866abf38b3 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Tue, 20 Feb 2024 15:20:17 -0500 Subject: [PATCH 145/178] prudp: update godoc comment in PRUDPConnection --- prudp_connection.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prudp_connection.go b/prudp_connection.go index 005cd17a..ad35ef6f 100644 --- a/prudp_connection.go +++ b/prudp_connection.go @@ -36,7 +36,7 @@ type PRUDPConnection struct { StationURLs *types.List[*types.StationURL] } -// Server returns the PRUDP server the connections socket is connected to +// Endpoint returns the PRUDP server the connections socket is connected to func (pc *PRUDPConnection) Endpoint() EndpointInterface { return pc.endpoint } From 582bcab984a675ecf6208d788644f6b1cb1b40d8 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Tue, 20 Feb 2024 15:21:24 -0500 Subject: [PATCH 146/178] prudp: fix ineffectual assignment to err in PRUDPServer --- prudp_server.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/prudp_server.go b/prudp_server.go index a1464d9a..10432ec5 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -80,9 +80,11 @@ func (ps *PRUDPServer) listenDatagram(quit chan struct{}) { var addr *net.UDPAddr read, addr, err = ps.udpSocket.ReadFromUDP(buffer) - packetData := buffer[:read] + if err == nil { + packetData := buffer[:read] - err = ps.handleSocketMessage(packetData, addr, nil) + err = ps.handleSocketMessage(packetData, addr, nil) + } } quit <- struct{}{} From 15672fd9815b3c41f9411af085ea3804defa2d6e Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Tue, 20 Feb 2024 15:29:48 -0500 Subject: [PATCH 147/178] chore: update comment about the SecureEndPoint in PRUDPEndPoint --- prudp_endpoint.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/prudp_endpoint.go b/prudp_endpoint.go index 2a393172..30389fd3 100644 --- a/prudp_endpoint.go +++ b/prudp_endpoint.go @@ -13,6 +13,10 @@ import ( // PRUDPEndPoint is an implementation of rdv::PRUDPEndPoint. // A PRUDPEndPoint represents a remote server location the client may connect to using a given remote stream ID. // Each PRUDPEndPoint handles it's own set of PRUDPConnections, state, and events. +// +// In NEX there exists nn::nex::SecureEndPoint, which presumably is what differentiates between the authentication +// and secure servers. However the functionality of rdv::PRUDPEndPoint and nn::nex::SecureEndPoint is seemingly +// identical. Rather than duplicate the logic from PRUDPEndpoint, a IsSecureEndpoint flag has been added instead. type PRUDPEndPoint struct { Server *PRUDPServer StreamID uint8 From 738045e14e005e2ac32ebbc3439c9e5bf708edc7 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Tue, 20 Feb 2024 15:31:45 -0500 Subject: [PATCH 148/178] prudp: rename IsSecureEndpoint to IsSecureEndPoint for accuracy --- prudp_endpoint.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/prudp_endpoint.go b/prudp_endpoint.go index 30389fd3..196be3ab 100644 --- a/prudp_endpoint.go +++ b/prudp_endpoint.go @@ -29,7 +29,7 @@ type PRUDPEndPoint struct { ServerAccount *Account AccountDetailsByPID func(pid *types.PID) (*Account, *Error) AccountDetailsByUsername func(username string) (*Account, *Error) - IsSecureEndpoint bool + IsSecureEndPoint bool } // RegisterServiceProtocol registers a NEX service with the endpoint @@ -323,7 +323,7 @@ func (pep *PRUDPEndPoint) handleConnect(packet PRUDPPacketInterface) { payload := make([]byte, 0) - if pep.IsSecureEndpoint { + if pep.IsSecureEndPoint { var decryptedPayload []byte if pep.Server.PRUDPV0Settings.EncryptedConnect { decryptedPayload, err = connection.StreamSettings.EncryptionAlgorithm.Decrypt(packet.Payload()) @@ -739,6 +739,6 @@ func NewPRUDPEndPoint(streamID uint8) *PRUDPEndPoint { connectionEndedEventHandlers: make([]func(connection *PRUDPConnection), 0), errorEventHandlers: make([]func(err *Error), 0), ConnectionIDCounter: NewCounter[uint32](0), - IsSecureEndpoint: false, + IsSecureEndPoint: false, } } From 36582bbc2283d162c49fec0e645b0bd410204389 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Tue, 20 Feb 2024 15:33:01 -0500 Subject: [PATCH 149/178] prudp: removed unnecessary sliding window initialization --- prudp_endpoint.go | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/prudp_endpoint.go b/prudp_endpoint.go index 196be3ab..d831a956 100644 --- a/prudp_endpoint.go +++ b/prudp_endpoint.go @@ -109,27 +109,6 @@ func (pep *PRUDPEndPoint) processPacket(packet PRUDPPacketInterface, socket *Soc connection.StreamSettings = pep.DefaultStreamSettings.Copy() connection.startHeartbeat() - // * Fail-safe. If the server reboots, then - // * connection has no record of old connections. - // * An existing client which has not killed - // * the connection on it's end MAY still send - // * DATA packets once the server is back - // * online, assuming it reboots fast enough. - // * Since the client did NOT redo the SYN - // * and CONNECT packets, it's reliable - // * substreams never got remade. This is put - // * in place to ensure there is always AT - // * LEAST one substream in place, so the client - // * can naturally error out due to the RC4 - // * errors. - // * - // * NOTE: THE CLIENT MAY NOT HAVE THE REAL - // * CORRECT NUMBER OF SUBSTREAMS HERE. THIS - // * IS ONLY DONE TO PREVENT A SERVER CRASH, - // * NOT TO SAVE THE CLIENT. THE CLIENT IS - // * EXPECTED TO NATURALLY DIE HERE - connection.InitializeSlidingWindows(0) - pep.Connections.Set(discriminator, connection) } From 7b1bdc17a6b3f0be4bdc2ccab828abd0e9518d28 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Tue, 20 Feb 2024 15:34:53 -0500 Subject: [PATCH 150/178] prudp: only start heartbeat once connected --- prudp_endpoint.go | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/prudp_endpoint.go b/prudp_endpoint.go index d831a956..ff9b37d8 100644 --- a/prudp_endpoint.go +++ b/prudp_endpoint.go @@ -107,13 +107,11 @@ func (pep *PRUDPEndPoint) processPacket(packet PRUDPPacketInterface, socket *Soc connection.StreamType = streamType connection.StreamID = streamID connection.StreamSettings = pep.DefaultStreamSettings.Copy() - connection.startHeartbeat() pep.Connections.Set(discriminator, connection) } packet.SetSender(connection) - connection.resetHeartbeat() if packet.HasFlag(FlagAck) || packet.HasFlag(FlagMultiAck) { pep.handleAcknowledgment(packet) @@ -364,6 +362,7 @@ func (pep *PRUDPEndPoint) handleConnect(packet PRUDPPacketInterface) { ack.setSignature(ack.calculateSignature([]byte{}, packet.getConnectionSignature())) connection.ConnectionState = StateConnected + connection.startHeartbeat() pep.emit("connect", ack) @@ -379,6 +378,8 @@ func (pep *PRUDPEndPoint) handleData(packet PRUDPPacketInterface) { return } + connection.resetHeartbeat() + if packet.HasFlag(FlagReliable) { pep.handleReliable(packet) } else { @@ -406,6 +407,10 @@ func (pep *PRUDPEndPoint) handleDisconnect(packet PRUDPPacketInterface) { } func (pep *PRUDPEndPoint) handlePing(packet PRUDPPacketInterface) { + connection := packet.Sender().(*PRUDPConnection) + + connection.resetHeartbeat() + if packet.HasFlag(FlagNeedsAck) { pep.acknowledgePacket(packet) } From 0c105826808d94295b9893372fb705677c64e1b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20L=C3=B3pez=20Guimaraes?= Date: Thu, 22 Feb 2024 16:09:15 +0000 Subject: [PATCH 151/178] prudp: Add connection check on ACKs We don't expect to receive any ACKs until the client is connected, and packet acknowledgement requires initialized sliding windows. --- prudp_endpoint.go | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/prudp_endpoint.go b/prudp_endpoint.go index ff9b37d8..3eb0a66f 100644 --- a/prudp_endpoint.go +++ b/prudp_endpoint.go @@ -133,13 +133,18 @@ func (pep *PRUDPEndPoint) processPacket(packet PRUDPPacketInterface, socket *Soc } func (pep *PRUDPEndPoint) handleAcknowledgment(packet PRUDPPacketInterface) { + connection := packet.Sender().(*PRUDPConnection) + if connection.ConnectionState != StateConnected { + // TODO - Log this? + // * Connection is in a bad state, drop the packet and let it die + return + } + if packet.HasFlag(FlagMultiAck) { pep.handleMultiAcknowledgment(packet) return } - connection := packet.Sender().(*PRUDPConnection) - slidingWindow := connection.SlidingWindow(packet.SubstreamID()) slidingWindow.ResendScheduler.AcknowledgePacket(packet.SequenceID()) } From 9711453c58585f8f0024a1f67d31d03b87c8b377 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20L=C3=B3pez=20Guimaraes?= Date: Thu, 22 Feb 2024 16:11:26 +0000 Subject: [PATCH 152/178] prudp_connection: Fix godoc comment typo --- prudp_connection.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prudp_connection.go b/prudp_connection.go index ad35ef6f..19cb1698 100644 --- a/prudp_connection.go +++ b/prudp_connection.go @@ -36,7 +36,7 @@ type PRUDPConnection struct { StationURLs *types.List[*types.StationURL] } -// Endpoint returns the PRUDP server the connections socket is connected to +// Endpoint returns the PRUDP endpoint the connections socket is connected to func (pc *PRUDPConnection) Endpoint() EndpointInterface { return pc.endpoint } From cdcc2bbb9d0167daa32d6190a33bbe178ae27e73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20L=C3=B3pez=20Guimaraes?= Date: Sat, 24 Feb 2024 00:18:49 +0000 Subject: [PATCH 153/178] prudpv1: Add PRUDPV1Settings Adds setting for controlling legacy connection signature calculation, needed for some games like Luigi's Mansion 2. --- prudp_packet_v1.go | 4 ++++ prudp_server.go | 10 +++++++++- prudp_v1_settings.go | 15 +++++++++++++++ 3 files changed, 28 insertions(+), 1 deletion(-) create mode 100644 prudp_v1_settings.go diff --git a/prudp_packet_v1.go b/prudp_packet_v1.go index cd867ea2..3571d527 100644 --- a/prudp_packet_v1.go +++ b/prudp_packet_v1.go @@ -342,6 +342,10 @@ func (p *PRUDPPacketV1) calculateSignature(sessionKey, connectionSignature []byt key := md5.Sum(accessKeyBytes) mac := hmac.New(md5.New, key[:]) + if p.packetType == ConnectPacket && p.server.PRUDPV1Settings.LegacyConnectionSignature { + connectionSignature = make([]byte, 0) + } + mac.Write(header[4:]) mac.Write(sessionKey) mac.Write(accessKeySumBytes) diff --git a/prudp_server.go b/prudp_server.go index 10432ec5..965033f2 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -27,6 +27,7 @@ type PRUDPServer struct { LibraryVersions *LibraryVersions ByteStreamSettings *ByteStreamSettings PRUDPV0Settings *PRUDPV0Settings + PRUDPV1Settings *PRUDPV1Settings UseVerboseRMC bool } @@ -273,7 +274,13 @@ func (ps *PRUDPServer) sendPacket(packet PRUDPPacketInterface) { } } - packetCopy.setSignature(packetCopy.calculateSignature(connection.SessionKey, connection.ServerConnectionSignature)) + + if ps.PRUDPV1Settings.LegacyConnectionSignature { + packetCopy.setSignature(packetCopy.calculateSignature(connection.SessionKey, connection.Signature)) + } else { + packetCopy.setSignature(packetCopy.calculateSignature(connection.SessionKey, connection.ServerConnectionSignature)) + } + if packetCopy.HasFlag(FlagReliable) && packetCopy.HasFlag(FlagNeedsAck) { slidingWindow := connection.SlidingWindow(packetCopy.SubstreamID()) @@ -327,5 +334,6 @@ func NewPRUDPServer() *PRUDPServer { LibraryVersions: NewLibraryVersions(), ByteStreamSettings: NewByteStreamSettings(), PRUDPV0Settings: NewPRUDPV0Settings(), + PRUDPV1Settings: NewPRUDPV1Settings(), } } diff --git a/prudp_v1_settings.go b/prudp_v1_settings.go new file mode 100644 index 00000000..5f0d0a87 --- /dev/null +++ b/prudp_v1_settings.go @@ -0,0 +1,15 @@ +package nex + +// TODO - We can also breakout the decoding/encoding functions here too, but that would require getters and setters for all packet fields + +// PRUDPV1Settings defines settings for how to handle aspects of PRUDPv1 packets +type PRUDPV1Settings struct { + LegacyConnectionSignature bool +} + +// NewPRUDPV1Settings returns a new PRUDPV1Settings +func NewPRUDPV1Settings() *PRUDPV1Settings { + return &PRUDPV1Settings{ + LegacyConnectionSignature: false, + } +} From 3f3d8da0ba52a5bf2aef9f8ea5598c625e384141 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20L=C3=B3pez=20Guimaraes?= Date: Sat, 24 Feb 2024 11:22:50 +0000 Subject: [PATCH 154/178] prudpv1: Add signature calculator functions settings --- prudp_packet_v1.go | 92 ++++++++++++++++++++++++-------------------- prudp_v1_settings.go | 10 ++++- 2 files changed, 58 insertions(+), 44 deletions(-) diff --git a/prudp_packet_v1.go b/prudp_packet_v1.go index 3571d527..00a48260 100644 --- a/prudp_packet_v1.go +++ b/prudp_packet_v1.go @@ -309,51 +309,11 @@ func (p *PRUDPPacketV1) encodeOptions() []byte { } func (p *PRUDPPacketV1) calculateConnectionSignature(addr net.Addr) ([]byte, error) { - var ip net.IP - var port int - - switch v := addr.(type) { - case *net.UDPAddr: - ip = v.IP.To4() - port = v.Port - default: - return nil, fmt.Errorf("Unsupported network type: %T", addr) - } - - portBytes := make([]byte, 2) - binary.BigEndian.PutUint16(portBytes, uint16(port)) - - data := append(ip, portBytes...) - hash := hmac.New(md5.New, p.server.PRUDPv1ConnectionSignatureKey) - hash.Write(data) - - return hash.Sum(nil), nil + return p.server.PRUDPV1Settings.ConnectionSignatureCalculator(p, addr) } func (p *PRUDPPacketV1) calculateSignature(sessionKey, connectionSignature []byte) []byte { - accessKeyBytes := []byte(p.server.AccessKey) - options := p.encodeOptions() - header := p.encodeHeader() - - accessKeySum := sum[byte, uint32](accessKeyBytes) - accessKeySumBytes := make([]byte, 4) - binary.LittleEndian.PutUint32(accessKeySumBytes, accessKeySum) - - key := md5.Sum(accessKeyBytes) - mac := hmac.New(md5.New, key[:]) - - if p.packetType == ConnectPacket && p.server.PRUDPV1Settings.LegacyConnectionSignature { - connectionSignature = make([]byte, 0) - } - - mac.Write(header[4:]) - mac.Write(sessionKey) - mac.Write(accessKeySumBytes) - mac.Write(connectionSignature) - mac.Write(options) - mac.Write(p.payload) - - return mac.Sum(nil) + return p.server.PRUDPV1Settings.SignatureCalculator(p, sessionKey, connectionSignature) } // NewPRUDPPacketV1 creates and returns a new PacketV1 using the provided Client and stream @@ -393,3 +353,51 @@ func NewPRUDPPacketsV1(server *PRUDPServer, connection *PRUDPConnection, readStr return packets, nil } + +func defaultPRUDPv1ConnectionSignature(packet *PRUDPPacketV1, addr net.Addr) ([]byte, error) { + var ip net.IP + var port int + + switch v := addr.(type) { + case *net.UDPAddr: + ip = v.IP.To4() + port = v.Port + default: + return nil, fmt.Errorf("Unsupported network type: %T", addr) + } + + portBytes := make([]byte, 2) + binary.BigEndian.PutUint16(portBytes, uint16(port)) + + data := append(ip, portBytes...) + hash := hmac.New(md5.New, packet.server.PRUDPv1ConnectionSignatureKey) + hash.Write(data) + + return hash.Sum(nil), nil +} + +func defaultPRUDPv1CalculateSignature(packet *PRUDPPacketV1, sessionKey, connectionSignature []byte) []byte { + accessKeyBytes := []byte(packet.server.AccessKey) + options := packet.encodeOptions() + header := packet.encodeHeader() + + accessKeySum := sum[byte, uint32](accessKeyBytes) + accessKeySumBytes := make([]byte, 4) + binary.LittleEndian.PutUint32(accessKeySumBytes, accessKeySum) + + key := md5.Sum(accessKeyBytes) + mac := hmac.New(md5.New, key[:]) + + if packet.packetType == ConnectPacket && packet.server.PRUDPV1Settings.LegacyConnectionSignature { + connectionSignature = make([]byte, 0) + } + + mac.Write(header[4:]) + mac.Write(sessionKey) + mac.Write(accessKeySumBytes) + mac.Write(connectionSignature) + mac.Write(options) + mac.Write(packet.payload) + + return mac.Sum(nil) +} diff --git a/prudp_v1_settings.go b/prudp_v1_settings.go index 5f0d0a87..d3033dbe 100644 --- a/prudp_v1_settings.go +++ b/prudp_v1_settings.go @@ -1,15 +1,21 @@ package nex +import "net" + // TODO - We can also breakout the decoding/encoding functions here too, but that would require getters and setters for all packet fields // PRUDPV1Settings defines settings for how to handle aspects of PRUDPv1 packets type PRUDPV1Settings struct { - LegacyConnectionSignature bool + LegacyConnectionSignature bool + ConnectionSignatureCalculator func(packet *PRUDPPacketV1, addr net.Addr) ([]byte, error) + SignatureCalculator func(packet *PRUDPPacketV1, sessionKey, connectionSignature []byte) []byte } // NewPRUDPV1Settings returns a new PRUDPV1Settings func NewPRUDPV1Settings() *PRUDPV1Settings { return &PRUDPV1Settings{ - LegacyConnectionSignature: false, + LegacyConnectionSignature: false, + ConnectionSignatureCalculator: defaultPRUDPv1ConnectionSignature, + SignatureCalculator: defaultPRUDPv1CalculateSignature, } } From 6ddc9bdf46daf65eb09e6415b1196409f87db3c9 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Mon, 26 Feb 2024 13:25:15 -0500 Subject: [PATCH 155/178] prudp: update enum godoc comments --- connection_state.go | 2 +- stream_type.go | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/connection_state.go b/connection_state.go index a1cd10ef..5fd5f216 100644 --- a/connection_state.go +++ b/connection_state.go @@ -1,6 +1,6 @@ package nex -// ConnectionState is an implementation of nn::nex::EndPoint::_ConnectionState. +// ConnectionState is an implementation of the nn::nex::EndPoint::_ConnectionState enum. // // The state represents a PRUDP clients connection state. The original Rendez-Vous // library supports states 0-6, though NEX only supports 0-4. The remaining 2 are diff --git a/stream_type.go b/stream_type.go index 8264cbf6..b1fb0249 100644 --- a/stream_type.go +++ b/stream_type.go @@ -2,7 +2,8 @@ package nex // TODO - Should this be moved to the types module? -// StreamType is an implementation of rdv::Stream::Type. +// StreamType is an implementation of the rdv::Stream::Type enum. +// // StreamType is used to create VirtualPorts used in PRUDP virtual // connections. Each stream may be one of these types, and each stream // has it's own state. From 6488c96db7dd3812ee19cfa8e4044f258a308e90 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Mon, 26 Feb 2024 13:53:40 -0500 Subject: [PATCH 156/178] prudp: add start of SignatureMethod enum --- signature_method.go | 74 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) create mode 100644 signature_method.go diff --git a/signature_method.go b/signature_method.go new file mode 100644 index 00000000..6bd847ce --- /dev/null +++ b/signature_method.go @@ -0,0 +1,74 @@ +package nex + +// SignatureMethod is an implementation of the nn::nex::PRUDPMessageInterface::SignatureMethod enum. +// +// The signature method is used as part of the packet signature calculation process. It determines +// what data is used and from where when calculating the packets signature. +// +// Currently unused. Implemented for future use and dodumentation/note taking purposes. +// +// The following notes are derived from Xenoblade on the Wii U. Many details are unknown. +// +// Based on the `nn::nex::PRUDPMessageV1::CalcSignatureHelper` (`__CPR210__CalcSignatureHelper__Q3_2nn3nex14PRUDPMessageV1FPCQ3_2nn3nex6PacketQ4_2nn3nex21PRUDPMessageInterface15SignatureMethodPCQ3_2nn3nex3KeyQJ68J3nex6Stream4TypePCQ3_2nn3nex14SignatureBytesRQ3_2nn3nexJ167J`) +// function: +// +// There appears to be 9 signature methods. Methods 0, 2, 3, and 9 seem to do nothing. Method 1 +// seems to calculate the signature using the connection address. Methods 4-8 calculate the signature +// using parts of the packet. +// +// - Method 0: Calls `func_0x04b10f90` and bails immediately? +// - Method 1: Seems to calculate the signature using ONLY the connections address? It uses the values +// from `nn::nex::InetAddress::GetAddress` and `nn::nex::InetAddress::GetPortNumber`, among others. +// It does NOT follow the same code path as methods 4-9 +// - Method 2: Unknown. Bails without doing anything +// - Method 3: Unknown. Bails without doing anything +// +// Methods 4-8 build the signature from one or many parts of the packet +// +// - Methods 4-8: Use the value from `nn::nex::Packet::GetHeaderForSignatureCalc`? +// - Methods 5-8: Use whatever is passed as `signature_bytes_1`, but only if: +// 1. `signature_bytes_1` is not empty. +// 2. The packet type is not `SYN`. +// 3. The packet type is not `CONNECT`. +// 4. The packet type is not `USER` (?). +// 5. `type_flags & 0x200 == 0`. +// 6. `type_flags & 0x400 == 0`. +// - Method 6: Use an optional "key", if not null +// - If method 7 is used, 2 local variables are set to 0. Otherwise they get set the content pointer +// and size of the calculated signature buffer. In both cases another local variable is set to +// `packet->field_0x94`, and then some checks are done on it before it's set to the packets payload? +// - Method 8: 16 random numbers generated and appended to `signature_bytes_2` +// - Method 9: The signature seems ignored entirely? +type SignatureMethod uint8 + +const ( + // SignatureMethod0 is an unknown signature type + SignatureMethod0 SignatureMethod = iota + + // SignatureMethodConnectionAddress seems to indicate the signature is based on the connection address + SignatureMethodConnectionAddress + + // SignatureMethod2 is an unknown signature type + SignatureMethod2 + + // SignatureMethod3 is an unknown signature type + SignatureMethod3 + + // SignatureMethod4 is an unknown signature method + SignatureMethod4 + + // SignatureMethod5 is an unknown signature method + SignatureMethod5 + + // SignatureMethodUseKey seems to indicate the signature uses the provided key value, if not null + SignatureMethodUseKey + + // SignatureMethod7 is an unknown signature method + SignatureMethod7 + + // SignatureMethodUseEntropy seems to indicate the signature includes 16 random bytes + SignatureMethodUseEntropy + + // SignatureMethodIgnore seems to indicate the signature is ignored + SignatureMethodIgnore +) From fcf3430788c766bdb128388ecf6e24244cafca4b Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Tue, 27 Feb 2024 18:34:16 -0500 Subject: [PATCH 157/178] prudp: swap to native Go LZO lib --- compression/lzo.go | 44 ++++++++------------------------------------ go.mod | 2 +- go.sum | 4 ++-- 3 files changed, 11 insertions(+), 39 deletions(-) diff --git a/compression/lzo.go b/compression/lzo.go index 5e93d3eb..def2f3d7 100644 --- a/compression/lzo.go +++ b/compression/lzo.go @@ -4,7 +4,7 @@ import ( "bytes" "fmt" - "github.com/cyberdelia/lzo" + "github.com/rasky/go-lzo" ) // TODO - Untested. I think this works. Maybe. Verify and remove this comment @@ -14,29 +14,15 @@ type LZO struct{} // Compress compresses the payload using LZO func (l *LZO) Compress(payload []byte) ([]byte, error) { - var compressed bytes.Buffer + compressed := lzo.Compress1X(payload) - lzoWriter := lzo.NewWriter(&compressed) + compressionRatio := len(payload)/len(compressed) + 1 - _, err := lzoWriter.Write(payload) - if err != nil { - return []byte{}, err - } - - err = lzoWriter.Close() - if err != nil { - return []byte{}, err - } - - compressedBytes := compressed.Bytes() - - compressionRatio := len(payload)/len(compressedBytes) + 1 - - result := make([]byte, len(compressedBytes)+1) + result := make([]byte, len(compressed)+1) result[0] = byte(compressionRatio) - copy(result[1:], compressedBytes) + copy(result[1:], compressed) return result, nil } @@ -52,32 +38,18 @@ func (l *LZO) Decompress(payload []byte) ([]byte, error) { } reader := bytes.NewReader(compressed) - decompressed := bytes.Buffer{} - - lzoReader, err := lzo.NewReader(reader) - if err != nil { - return []byte{}, err - } - - _, err = decompressed.ReadFrom(lzoReader) + decompressed, err := lzo.Decompress1X(reader, len(compressed), 0) if err != nil { return []byte{}, err } - err = lzoReader.Close() - if err != nil { - return []byte{}, err - } - - decompressedBytes := decompressed.Bytes() - - ratioCheck := len(decompressedBytes)/len(compressed) + 1 + ratioCheck := len(decompressed)/len(compressed) + 1 if ratioCheck != int(compressionRatio) { return []byte{}, fmt.Errorf("Failed to decompress payload. Got bad ratio. Expected %d, got %d", compressionRatio, ratioCheck) } - return decompressedBytes, nil + return decompressed, nil } // Copy returns a copy of the algorithm diff --git a/go.mod b/go.mod index fa87d6c8..310f63a7 100644 --- a/go.mod +++ b/go.mod @@ -4,8 +4,8 @@ go 1.21 require ( github.com/PretendoNetwork/plogger-go v1.0.4 - github.com/cyberdelia/lzo v1.0.0 github.com/lxzan/gws v1.8.0 + github.com/rasky/go-lzo v0.0.0-20200203143853-96a758eda86e github.com/superwhiskers/crunch/v3 v3.5.7 golang.org/x/exp v0.0.0-20230905200255-921286631fa9 golang.org/x/mod v0.12.0 diff --git a/go.sum b/go.sum index 8af4fe94..99546b6c 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,5 @@ github.com/PretendoNetwork/plogger-go v1.0.4 h1:PF7xHw9eDRHH+RsAP9tmAE7fG0N0p6H4iPwHKnsoXwc= github.com/PretendoNetwork/plogger-go v1.0.4/go.mod h1:7kD6M4vPq1JL4LTuPg6kuB1OvUBOwQOtAvTaUwMbwvU= -github.com/cyberdelia/lzo v1.0.0 h1:smmvcahczwI/VWSzZ7iikt50lubari5py3qL4hAEHII= -github.com/cyberdelia/lzo v1.0.0/go.mod h1:UVNk6eM6Sozt1wx17TECJKuqmIY58TJOVeJxjlGGAGs= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dolthub/maphash v0.1.0 h1:bsQ7JsF4FkkWyrP3oCnFJgrCUAFbFf3kOl4L/QxPDyQ= @@ -23,6 +21,8 @@ github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APP github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rasky/go-lzo v0.0.0-20200203143853-96a758eda86e h1:dCWirM5F3wMY+cmRda/B1BiPsFtmzXqV9b0hLWtVBMs= +github.com/rasky/go-lzo v0.0.0-20200203143853-96a758eda86e/go.mod h1:9leZcVcItj6m9/CfHY5Em/iBrCz7js8LcRQGTKEEv2M= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/superwhiskers/crunch/v3 v3.5.7 h1:N9RLxaR65C36i26BUIpzPXGy2f6pQ7wisu2bawbKNqg= From 652dbecf487d2ef90a1ac829aa83f17822431125 Mon Sep 17 00:00:00 2001 From: PabloMK7 Date: Tue, 5 Mar 2024 19:28:21 +0100 Subject: [PATCH 158/178] Use more fields from StreamSettings --- prudp_connection.go | 8 ++++---- prudp_server.go | 3 --- reliable_packet_substream_manager.go | 3 +-- resend_scheduler.go | 29 ++++++++++++++-------------- sliding_window.go | 6 +----- stream_settings.go | 10 +++++----- 6 files changed, 25 insertions(+), 34 deletions(-) diff --git a/prudp_connection.go b/prudp_connection.go index 19cb1698..b7d82eb4 100644 --- a/prudp_connection.go +++ b/prudp_connection.go @@ -169,13 +169,13 @@ func (pc *PRUDPConnection) resetHeartbeat() { } if pc.heartbeatTimer != nil { - pc.heartbeatTimer.Reset(pc.endpoint.Server.pingTimeout) // TODO - This is part of StreamSettings + pc.heartbeatTimer.Reset(time.Duration(pc.StreamSettings.MaxSilenceTime) * time.Millisecond) } } func (pc *PRUDPConnection) startHeartbeat() { endpoint := pc.endpoint - server := endpoint.Server + maxSilenceTime := time.Duration(pc.StreamSettings.MaxSilenceTime) * time.Millisecond // * Every time a packet is sent, connection.resetHeartbeat() // * is called which resets this timer. If this function @@ -183,12 +183,12 @@ func (pc *PRUDPConnection) startHeartbeat() { // * in the expected time frame. If this happens, send // * the client a PING packet to try and kick start the // * heartbeat again - pc.heartbeatTimer = time.AfterFunc(server.pingTimeout, func() { + pc.heartbeatTimer = time.AfterFunc(maxSilenceTime, func() { endpoint.sendPing(pc) // * If the heartbeat still did not restart, assume the // * connection is dead and clean up - pc.pingKickTimer = time.AfterFunc(server.pingTimeout, func() { + pc.pingKickTimer = time.AfterFunc(maxSilenceTime, func() { pc.cleanup() // * "removed" event is dispatched here discriminator := fmt.Sprintf("%s-%d-%d", pc.Socket.Address.String(), pc.StreamType, pc.StreamID) diff --git a/prudp_server.go b/prudp_server.go index 965033f2..231ee95a 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -6,7 +6,6 @@ import ( "fmt" "net" "runtime" - "time" "github.com/lxzan/gws" ) @@ -22,7 +21,6 @@ type PRUDPServer struct { KerberosTicketVersion int SessionKeyLength int FragmentSize int - pingTimeout time.Duration PRUDPv1ConnectionSignatureKey []byte LibraryVersions *LibraryVersions ByteStreamSettings *ByteStreamSettings @@ -330,7 +328,6 @@ func NewPRUDPServer() *PRUDPServer { Connections: NewMutexMap[string, *SocketConnection](), SessionKeyLength: 32, FragmentSize: 1300, - pingTimeout: time.Second * 15, LibraryVersions: NewLibraryVersions(), ByteStreamSettings: NewByteStreamSettings(), PRUDPV0Settings: NewPRUDPV0Settings(), diff --git a/reliable_packet_substream_manager.go b/reliable_packet_substream_manager.go index 972df322..cd85ce63 100644 --- a/reliable_packet_substream_manager.go +++ b/reliable_packet_substream_manager.go @@ -2,7 +2,6 @@ package nex import ( "crypto/rc4" - "time" ) // ReliablePacketSubstreamManager represents a substream manager for reliable PRUDP packets @@ -85,7 +84,7 @@ func NewReliablePacketSubstreamManager(startingIncomingSequenceID, startingOutgo packetMap: NewMutexMap[uint16, PRUDPPacketInterface](), incomingSequenceIDCounter: NewCounter[uint16](startingIncomingSequenceID), outgoingSequenceIDCounter: NewCounter[uint16](startingOutgoingSequenceID), - ResendScheduler: NewResendScheduler(5, time.Second, 0), + ResendScheduler: NewResendScheduler(), } psm.SetCipherKey([]byte("CD&ML")) diff --git a/resend_scheduler.go b/resend_scheduler.go index b637a68a..e14f6109 100644 --- a/resend_scheduler.go +++ b/resend_scheduler.go @@ -35,8 +35,6 @@ func (pi *PendingPacket) startResendTimer() { // ResendScheduler manages the resending of reliable PRUDP packets type ResendScheduler struct { packets *MutexMap[uint16, *PendingPacket] - Interval time.Duration - Increase time.Duration } // Stop kills the resend scheduler and stops all pending packets @@ -62,10 +60,13 @@ func (rs *ResendScheduler) Stop() { // AddPacket adds a packet to the scheduler and begins it's timer func (rs *ResendScheduler) AddPacket(packet PRUDPPacketInterface) { + connection := packet.Sender().(*PRUDPConnection) + slidingWindow := connection.SlidingWindow(packet.SubstreamID()) + pendingPacket := &PendingPacket{ packet: packet, rs: rs, - interval: rs.Interval, + interval: time.Duration(slidingWindow.streamSettings.KeepAliveTimeout) * time.Millisecond, } rs.packets.Set(packet.SequenceID(), pendingPacket) @@ -106,28 +107,26 @@ func (rs *ResendScheduler) resendPacket(pendingPacket *PendingPacket) { return } - if time.Since(pendingPacket.lastSendTime) >= rs.Interval { + if time.Since(pendingPacket.lastSendTime) >= time.Duration(slidingWindow.streamSettings.KeepAliveTimeout) * time.Millisecond { // * Resend the packet to the connection server := connection.endpoint.Server data := packet.Bytes() server.sendRaw(connection.Socket, data) - - pendingPacket.interval += rs.Increase - pendingPacket.ticker.Reset(pendingPacket.interval) + pendingPacket.resendCount++ + if (pendingPacket.resendCount < slidingWindow.streamSettings.ExtraRestransmitTimeoutTrigger) { + pendingPacket.interval += time.Duration(uint32(float32(slidingWindow.streamSettings.KeepAliveTimeout) * slidingWindow.streamSettings.RetransmitTimeoutMultiplier)) * time.Millisecond + } else { + pendingPacket.interval += time.Duration(uint32(float32(slidingWindow.streamSettings.KeepAliveTimeout) * slidingWindow.streamSettings.ExtraRetransmitTimeoutMultiplier)) * time.Millisecond + } + pendingPacket.ticker.Reset(pendingPacket.interval) pendingPacket.lastSendTime = time.Now() } } -// NewResendScheduler creates a new ResendScheduler with the provided max resend count and interval and increase durations -// -// If increase is non-zero then every resend will have it's duration increased by that amount. For example an interval of -// 1 second and an increase of 5 seconds. The 1st resend happens after 1 second, the 2nd will take place 6 seconds -// after the 1st, and the 3rd will take place 11 seconds after the 2nd -func NewResendScheduler(maxResendCount int, interval, increase time.Duration) *ResendScheduler { +// NewResendScheduler creates a new ResendScheduler +func NewResendScheduler() *ResendScheduler { return &ResendScheduler{ packets: NewMutexMap[uint16, *PendingPacket](), - Interval: interval, - Increase: increase, } } diff --git a/sliding_window.go b/sliding_window.go index c5f58002..17835623 100644 --- a/sliding_window.go +++ b/sliding_window.go @@ -1,9 +1,5 @@ package nex -import ( - "time" -) - // SlidingWindow is an implementation of rdv::SlidingWindow. // SlidingWindow reorders pending reliable packets to ensure they are handled in the expected order. // In the original library each virtual connection stream only uses a single SlidingWindow, but starting @@ -74,7 +70,7 @@ func NewSlidingWindow() *SlidingWindow { pendingPackets: NewMutexMap[uint16, PRUDPPacketInterface](), incomingSequenceIDCounter: NewCounter[uint16](0), outgoingSequenceIDCounter: NewCounter[uint16](0), - ResendScheduler: NewResendScheduler(5, time.Second, 0), + ResendScheduler: NewResendScheduler(), } return sw diff --git a/stream_settings.go b/stream_settings.go index a7ee51e6..6ee1fac0 100644 --- a/stream_settings.go +++ b/stream_settings.go @@ -12,19 +12,19 @@ import ( // The original library has more settings which are not present here as their use is unknown. // Not all values are used at this time, and only exist to future-proof for a later time. type StreamSettings struct { - ExtraRestransmitTimeoutTrigger uint32 // * Unused. The number of times a packet can be retransmitted before ExtraRetransmitTimeoutMultiplier is used + ExtraRestransmitTimeoutTrigger uint32 // * The number of times a packet can be retransmitted before ExtraRetransmitTimeoutMultiplier is used MaxPacketRetransmissions uint32 // * The number of times a packet can be retransmitted before the timeout time is checked - KeepAliveTimeout uint32 // * Unused. Presumably the time a packet can be alive for without acknowledgement? Milliseconds? + KeepAliveTimeout uint32 // * Presumably the time a packet can be alive for without acknowledgement? Milliseconds? ChecksumBase uint32 // * Unused. The base value for PRUDPv0 checksum calculations FaultDetectionEnabled bool // * Unused. Presumably used to detect PIA faults? InitialRTT uint32 // * Unused. The connections initial RTT EncryptionAlgorithm encryption.Algorithm // * The encryption algorithm used for packet payloads - ExtraRetransmitTimeoutMultiplier float32 // * Unused. Used as part of the RTO calculations when retransmitting a packet. Only used if ExtraRestransmitTimeoutTrigger has been reached + ExtraRetransmitTimeoutMultiplier float32 // * Used as part of the RTO calculations when retransmitting a packet. Only used if ExtraRestransmitTimeoutTrigger has been reached WindowSize uint32 // * Unused. The max number of (reliable?) packets allowed in a SlidingWindow CompressionAlgorithm compression.Algorithm // * The compression algorithm used for packet payloads RTTRetransmit uint32 // * Unused. Unknown use - RetransmitTimeoutMultiplier float32 // * Unused. Used as part of the RTO calculations when retransmitting a packet. Only used if ExtraRestransmitTimeoutTrigger has not been reached - MaxSilenceTime uint32 // * Unused. Presumably the time a connection can go without any packets from the other side? Milliseconds? + RetransmitTimeoutMultiplier float32 // * Used as part of the RTO calculations when retransmitting a packet. Only used if ExtraRestransmitTimeoutTrigger has not been reached + MaxSilenceTime uint32 // * Presumably the time a connection can go without any packets from the other side? Milliseconds? } // Copy returns a new copy of the settings From 89ef0596a1a035b473b9413456e6f1fc4f30e3d8 Mon Sep 17 00:00:00 2001 From: PabloMK7 Date: Wed, 6 Mar 2024 13:59:14 +0100 Subject: [PATCH 159/178] Apply suggestions --- prudp_connection.go | 3 +++ resend_scheduler.go | 10 ++++++++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/prudp_connection.go b/prudp_connection.go index b7d82eb4..5c6e41c5 100644 --- a/prudp_connection.go +++ b/prudp_connection.go @@ -169,12 +169,15 @@ func (pc *PRUDPConnection) resetHeartbeat() { } if pc.heartbeatTimer != nil { + // TODO: This may not be accurate, needs more research pc.heartbeatTimer.Reset(time.Duration(pc.StreamSettings.MaxSilenceTime) * time.Millisecond) } } func (pc *PRUDPConnection) startHeartbeat() { endpoint := pc.endpoint + + // TODO: This may not be accurate, needs more research maxSilenceTime := time.Duration(pc.StreamSettings.MaxSilenceTime) * time.Millisecond // * Every time a packet is sent, connection.resetHeartbeat() diff --git a/resend_scheduler.go b/resend_scheduler.go index e14f6109..3f0dfe86 100644 --- a/resend_scheduler.go +++ b/resend_scheduler.go @@ -66,6 +66,7 @@ func (rs *ResendScheduler) AddPacket(packet PRUDPPacketInterface) { pendingPacket := &PendingPacket{ packet: packet, rs: rs, + // TODO: This may not be accurate, needs more research interval: time.Duration(slidingWindow.streamSettings.KeepAliveTimeout) * time.Millisecond, } @@ -107,6 +108,7 @@ func (rs *ResendScheduler) resendPacket(pendingPacket *PendingPacket) { return } + // TODO: This may not be accurate, needs more research if time.Since(pendingPacket.lastSendTime) >= time.Duration(slidingWindow.streamSettings.KeepAliveTimeout) * time.Millisecond { // * Resend the packet to the connection server := connection.endpoint.Server @@ -114,11 +116,15 @@ func (rs *ResendScheduler) resendPacket(pendingPacket *PendingPacket) { server.sendRaw(connection.Socket, data) pendingPacket.resendCount++ + + var retransmitTimeoutMultiplier float32 if (pendingPacket.resendCount < slidingWindow.streamSettings.ExtraRestransmitTimeoutTrigger) { - pendingPacket.interval += time.Duration(uint32(float32(slidingWindow.streamSettings.KeepAliveTimeout) * slidingWindow.streamSettings.RetransmitTimeoutMultiplier)) * time.Millisecond + retransmitTimeoutMultiplier = slidingWindow.streamSettings.RetransmitTimeoutMultiplier } else { - pendingPacket.interval += time.Duration(uint32(float32(slidingWindow.streamSettings.KeepAliveTimeout) * slidingWindow.streamSettings.ExtraRetransmitTimeoutMultiplier)) * time.Millisecond + retransmitTimeoutMultiplier = slidingWindow.streamSettings.ExtraRetransmitTimeoutMultiplier } + pendingPacket.interval += time.Duration(uint32(float32(slidingWindow.streamSettings.KeepAliveTimeout) * retransmitTimeoutMultiplier)) * time.Millisecond + pendingPacket.ticker.Reset(pendingPacket.interval) pendingPacket.lastSendTime = time.Now() } From ab069f3410f47ded36f841ee9807bb7ab30d3fe9 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Fri, 15 Mar 2024 13:17:00 -0400 Subject: [PATCH 160/178] chore: formatting --- byte_stream_in.go | 4 ++-- byte_stream_out.go | 4 ++-- prudp_server.go | 2 -- resend_scheduler.go | 18 +++++++++--------- types/structure.go | 2 +- 5 files changed, 14 insertions(+), 16 deletions(-) diff --git a/byte_stream_in.go b/byte_stream_in.go index 294e696b..0e52df86 100644 --- a/byte_stream_in.go +++ b/byte_stream_in.go @@ -170,8 +170,8 @@ func (bsi *ByteStreamIn) ReadPrimitiveBool() (bool, error) { // NewByteStreamIn returns a new NEX input byte stream func NewByteStreamIn(data []byte, libraryVersions *LibraryVersions, settings *ByteStreamSettings) *ByteStreamIn { return &ByteStreamIn{ - Buffer: crunch.NewBuffer(data), + Buffer: crunch.NewBuffer(data), LibraryVersions: libraryVersions, - Settings: settings, + Settings: settings, } } diff --git a/byte_stream_out.go b/byte_stream_out.go index 27f63a28..447ab25f 100644 --- a/byte_stream_out.go +++ b/byte_stream_out.go @@ -130,8 +130,8 @@ func (bso *ByteStreamOut) WritePrimitiveBool(b bool) { // NewByteStreamOut returns a new NEX writable byte stream func NewByteStreamOut(libraryVersions *LibraryVersions, settings *ByteStreamSettings) *ByteStreamOut { return &ByteStreamOut{ - Buffer: crunch.NewBuffer(), + Buffer: crunch.NewBuffer(), LibraryVersions: libraryVersions, - Settings: settings, + Settings: settings, } } diff --git a/prudp_server.go b/prudp_server.go index 231ee95a..3b20b388 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -272,14 +272,12 @@ func (ps *PRUDPServer) sendPacket(packet PRUDPPacketInterface) { } } - if ps.PRUDPV1Settings.LegacyConnectionSignature { packetCopy.setSignature(packetCopy.calculateSignature(connection.SessionKey, connection.Signature)) } else { packetCopy.setSignature(packetCopy.calculateSignature(connection.SessionKey, connection.ServerConnectionSignature)) } - if packetCopy.HasFlag(FlagReliable) && packetCopy.HasFlag(FlagNeedsAck) { slidingWindow := connection.SlidingWindow(packetCopy.SubstreamID()) slidingWindow.ResendScheduler.AddPacket(packetCopy) diff --git a/resend_scheduler.go b/resend_scheduler.go index 3f0dfe86..1f0a8048 100644 --- a/resend_scheduler.go +++ b/resend_scheduler.go @@ -34,7 +34,7 @@ func (pi *PendingPacket) startResendTimer() { // ResendScheduler manages the resending of reliable PRUDP packets type ResendScheduler struct { - packets *MutexMap[uint16, *PendingPacket] + packets *MutexMap[uint16, *PendingPacket] } // Stop kills the resend scheduler and stops all pending packets @@ -64,8 +64,8 @@ func (rs *ResendScheduler) AddPacket(packet PRUDPPacketInterface) { slidingWindow := connection.SlidingWindow(packet.SubstreamID()) pendingPacket := &PendingPacket{ - packet: packet, - rs: rs, + packet: packet, + rs: rs, // TODO: This may not be accurate, needs more research interval: time.Duration(slidingWindow.streamSettings.KeepAliveTimeout) * time.Millisecond, } @@ -109,22 +109,22 @@ func (rs *ResendScheduler) resendPacket(pendingPacket *PendingPacket) { } // TODO: This may not be accurate, needs more research - if time.Since(pendingPacket.lastSendTime) >= time.Duration(slidingWindow.streamSettings.KeepAliveTimeout) * time.Millisecond { + if time.Since(pendingPacket.lastSendTime) >= time.Duration(slidingWindow.streamSettings.KeepAliveTimeout)*time.Millisecond { // * Resend the packet to the connection server := connection.endpoint.Server data := packet.Bytes() server.sendRaw(connection.Socket, data) - + pendingPacket.resendCount++ var retransmitTimeoutMultiplier float32 - if (pendingPacket.resendCount < slidingWindow.streamSettings.ExtraRestransmitTimeoutTrigger) { + if pendingPacket.resendCount < slidingWindow.streamSettings.ExtraRestransmitTimeoutTrigger { retransmitTimeoutMultiplier = slidingWindow.streamSettings.RetransmitTimeoutMultiplier } else { retransmitTimeoutMultiplier = slidingWindow.streamSettings.ExtraRetransmitTimeoutMultiplier } - pendingPacket.interval += time.Duration(uint32(float32(slidingWindow.streamSettings.KeepAliveTimeout) * retransmitTimeoutMultiplier)) * time.Millisecond - + pendingPacket.interval += time.Duration(uint32(float32(slidingWindow.streamSettings.KeepAliveTimeout)*retransmitTimeoutMultiplier)) * time.Millisecond + pendingPacket.ticker.Reset(pendingPacket.interval) pendingPacket.lastSendTime = time.Now() } @@ -133,6 +133,6 @@ func (rs *ResendScheduler) resendPacket(pendingPacket *PendingPacket) { // NewResendScheduler creates a new ResendScheduler func NewResendScheduler() *ResendScheduler { return &ResendScheduler{ - packets: NewMutexMap[uint16, *PendingPacket](), + packets: NewMutexMap[uint16, *PendingPacket](), } } diff --git a/types/structure.go b/types/structure.go index 35964415..ce3936f2 100644 --- a/types/structure.go +++ b/types/structure.go @@ -7,7 +7,7 @@ import ( // Structure represents a Quazal Rendez-Vous/NEX Structure (custom class) base struct. type Structure struct { - StructureVersion uint8 + StructureVersion uint8 } // ExtractHeaderFrom extracts the structure header from the given readable From b37dede3c42be60196ebd27bfeaf6e634842004a Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Fri, 15 Mar 2024 13:19:59 -0400 Subject: [PATCH 161/178] prudp: added nil check to pendingPacket.ticker --- resend_scheduler.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/resend_scheduler.go b/resend_scheduler.go index 1f0a8048..66cce228 100644 --- a/resend_scheduler.go +++ b/resend_scheduler.go @@ -52,7 +52,13 @@ func (rs *ResendScheduler) Stop() { for _, sequenceID := range stillPending { if pendingPacket, ok := rs.packets.Get(sequenceID); ok { pendingPacket.isAcknowledged = true // * Prevent an edge case where the ticker is already being processed - pendingPacket.ticker.Stop() + + if pendingPacket.ticker != nil { + // * This should never happen, but popped up in CTGP-7 testing? + // * Did the GC clear this before we called it? + pendingPacket.ticker.Stop() + } + rs.packets.Delete(sequenceID) } } From a1a590392b76666ce49eede108102f9814be6521 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Fri, 15 Mar 2024 16:09:23 -0400 Subject: [PATCH 162/178] chore: move constants and enums to constants package --- .../prudp_packet_flags.go | 2 +- .../prudp_packet_types.go | 10 +--- .../signature_method.go | 2 +- stream_type.go => constants/stream_type.go | 2 +- prudp_connection.go | 3 +- prudp_endpoint.go | 47 ++++++++++--------- prudp_packet.go | 16 ++++--- prudp_packet_interface.go | 14 ++++-- prudp_packet_lite.go | 38 ++++++++------- prudp_packet_v0.go | 20 ++++---- prudp_packet_v1.go | 18 +++---- prudp_server.go | 19 ++++---- test/auth.go | 13 ++--- test/secure.go | 29 ++++++------ virtual_port.go | 8 ++-- 15 files changed, 127 insertions(+), 114 deletions(-) rename prudp_packet_flags.go => constants/prudp_packet_flags.go (95%) rename prudp_packet_types.go => constants/prudp_packet_types.go (71%) rename signature_method.go => constants/signature_method.go (99%) rename stream_type.go => constants/stream_type.go (98%) diff --git a/prudp_packet_flags.go b/constants/prudp_packet_flags.go similarity index 95% rename from prudp_packet_flags.go rename to constants/prudp_packet_flags.go index 147dd7b1..6c54c7f0 100644 --- a/prudp_packet_flags.go +++ b/constants/prudp_packet_flags.go @@ -1,4 +1,4 @@ -package nex +package constants const ( // FlagAck is the ID for the PRUDP Ack Flag diff --git a/prudp_packet_types.go b/constants/prudp_packet_types.go similarity index 71% rename from prudp_packet_types.go rename to constants/prudp_packet_types.go index fc690866..510200e4 100644 --- a/prudp_packet_types.go +++ b/constants/prudp_packet_types.go @@ -1,4 +1,4 @@ -package nex +package constants const ( // SynPacket is the ID for the PRUDP Syn Packet type @@ -16,11 +16,3 @@ const ( // PingPacket is the ID for the PRUDP Ping Packet type PingPacket uint16 = 0x4 ) - -var validPacketTypes = map[uint16]bool{ - SynPacket: true, - ConnectPacket: true, - DataPacket: true, - DisconnectPacket: true, - PingPacket: true, -} diff --git a/signature_method.go b/constants/signature_method.go similarity index 99% rename from signature_method.go rename to constants/signature_method.go index 6bd847ce..49f372d0 100644 --- a/signature_method.go +++ b/constants/signature_method.go @@ -1,4 +1,4 @@ -package nex +package constants // SignatureMethod is an implementation of the nn::nex::PRUDPMessageInterface::SignatureMethod enum. // diff --git a/stream_type.go b/constants/stream_type.go similarity index 98% rename from stream_type.go rename to constants/stream_type.go index b1fb0249..38c58a74 100644 --- a/stream_type.go +++ b/constants/stream_type.go @@ -1,4 +1,4 @@ -package nex +package constants // TODO - Should this be moved to the types module? diff --git a/prudp_connection.go b/prudp_connection.go index 5c6e41c5..4292e462 100644 --- a/prudp_connection.go +++ b/prudp_connection.go @@ -6,6 +6,7 @@ import ( "net" "time" + "github.com/PretendoNetwork/nex-go/constants" "github.com/PretendoNetwork/nex-go/types" ) @@ -22,7 +23,7 @@ type PRUDPConnection struct { SessionKey []byte // * Secret key generated at the start of the session. Used for encrypting packets to the secure server pid *types.PID // * PID of the user DefaultPRUDPVersion int // * The PRUDP version the connection was established with. Used for sending PING packets - StreamType StreamType // * rdv::Stream::Type used in this connection + StreamType constants.StreamType // * rdv::Stream::Type used in this connection StreamID uint8 // * rdv::Stream ID, also called the "port number", used in this connection. 0-15 on PRUDPv0/v1, and 0-31 on PRUDPLite StreamSettings *StreamSettings // * Settings for this virtual connection Signature []byte // * Connection signature for packets coming from the client, as seen by the server diff --git a/prudp_endpoint.go b/prudp_endpoint.go index 3eb0a66f..3778d6ea 100644 --- a/prudp_endpoint.go +++ b/prudp_endpoint.go @@ -7,6 +7,7 @@ import ( "slices" "time" + "github.com/PretendoNetwork/nex-go/constants" "github.com/PretendoNetwork/nex-go/types" ) @@ -113,21 +114,21 @@ func (pep *PRUDPEndPoint) processPacket(packet PRUDPPacketInterface, socket *Soc packet.SetSender(connection) - if packet.HasFlag(FlagAck) || packet.HasFlag(FlagMultiAck) { + if packet.HasFlag(constants.FlagAck) || packet.HasFlag(constants.FlagMultiAck) { pep.handleAcknowledgment(packet) return } switch packet.Type() { - case SynPacket: + case constants.SynPacket: pep.handleSyn(packet) - case ConnectPacket: + case constants.ConnectPacket: pep.handleConnect(packet) - case DataPacket: + case constants.DataPacket: pep.handleData(packet) - case DisconnectPacket: + case constants.DisconnectPacket: pep.handleDisconnect(packet) - case PingPacket: + case constants.PingPacket: pep.handlePing(packet) } } @@ -140,7 +141,7 @@ func (pep *PRUDPEndPoint) handleAcknowledgment(packet PRUDPPacketInterface) { return } - if packet.HasFlag(FlagMultiAck) { + if packet.HasFlag(constants.FlagMultiAck) { pep.handleMultiAcknowledgment(packet) return } @@ -224,9 +225,9 @@ func (pep *PRUDPEndPoint) handleSyn(packet PRUDPPacketInterface) { connection.reset() connection.Signature = connectionSignature - ack.SetType(SynPacket) - ack.AddFlag(FlagAck) - ack.AddFlag(FlagHasSize) + ack.SetType(constants.SynPacket) + ack.AddFlag(constants.FlagAck) + ack.AddFlag(constants.FlagHasSize) ack.SetSourceVirtualPortStreamType(packet.DestinationVirtualPortStreamType()) ack.SetSourceVirtualPortStreamID(packet.DestinationVirtualPortStreamID()) ack.SetDestinationVirtualPortStreamType(packet.SourceVirtualPortStreamType()) @@ -277,9 +278,9 @@ func (pep *PRUDPEndPoint) handleConnect(packet PRUDPPacketInterface) { connection.ServerSessionID = packet.SessionID() - ack.SetType(ConnectPacket) - ack.AddFlag(FlagAck) - ack.AddFlag(FlagHasSize) + ack.SetType(constants.ConnectPacket) + ack.AddFlag(constants.FlagAck) + ack.AddFlag(constants.FlagHasSize) ack.SetSourceVirtualPortStreamType(packet.DestinationVirtualPortStreamType()) ack.SetSourceVirtualPortStreamID(packet.DestinationVirtualPortStreamID()) ack.SetDestinationVirtualPortStreamType(packet.SourceVirtualPortStreamType()) @@ -385,7 +386,7 @@ func (pep *PRUDPEndPoint) handleData(packet PRUDPPacketInterface) { connection.resetHeartbeat() - if packet.HasFlag(FlagReliable) { + if packet.HasFlag(constants.FlagReliable) { pep.handleReliable(packet) } else { pep.handleUnreliable(packet) @@ -396,7 +397,7 @@ func (pep *PRUDPEndPoint) handleDisconnect(packet PRUDPPacketInterface) { // TODO - Should we check the state here, or just let the connection disconnect at any time? // TODO - Should we bother to set the connections state here? It's being destroyed anyway - if packet.HasFlag(FlagNeedsAck) { + if packet.HasFlag(constants.FlagNeedsAck) { pep.acknowledgePacket(packet) } @@ -416,7 +417,7 @@ func (pep *PRUDPEndPoint) handlePing(packet PRUDPPacketInterface) { connection.resetHeartbeat() - if packet.HasFlag(FlagNeedsAck) { + if packet.HasFlag(constants.FlagNeedsAck) { pep.acknowledgePacket(packet) } } @@ -499,7 +500,7 @@ func (pep *PRUDPEndPoint) acknowledgePacket(packet PRUDPPacketInterface) { } ack.SetType(packet.Type()) - ack.AddFlag(FlagAck) + ack.AddFlag(constants.FlagAck) ack.SetSourceVirtualPortStreamType(packet.DestinationVirtualPortStreamType()) ack.SetSourceVirtualPortStreamID(packet.DestinationVirtualPortStreamID()) ack.SetDestinationVirtualPortStreamType(packet.SourceVirtualPortStreamType()) @@ -511,14 +512,14 @@ func (pep *PRUDPEndPoint) acknowledgePacket(packet PRUDPPacketInterface) { pep.Server.sendPacket(ack) // * Servers send the DISCONNECT ACK 3 times - if packet.Type() == DisconnectPacket { + if packet.Type() == constants.DisconnectPacket { pep.Server.sendPacket(ack) pep.Server.sendPacket(ack) } } func (pep *PRUDPEndPoint) handleReliable(packet PRUDPPacketInterface) { - if packet.HasFlag(FlagNeedsAck) { + if packet.HasFlag(constants.FlagNeedsAck) { pep.acknowledgePacket(packet) } @@ -527,7 +528,7 @@ func (pep *PRUDPEndPoint) handleReliable(packet PRUDPPacketInterface) { slidingWindow := packet.Sender().(*PRUDPConnection).SlidingWindow(packet.SubstreamID()) for _, pendingPacket := range slidingWindow.Update(packet) { - if packet.Type() == DataPacket { + if packet.Type() == constants.DataPacket { var decryptedPayload []byte if packet.Version() != 2 { @@ -563,7 +564,7 @@ func (pep *PRUDPEndPoint) handleReliable(packet PRUDPPacketInterface) { } func (pep *PRUDPEndPoint) handleUnreliable(packet PRUDPPacketInterface) { - if packet.HasFlag(FlagNeedsAck) { + if packet.HasFlag(constants.FlagNeedsAck) { pep.acknowledgePacket(packet) } @@ -635,8 +636,8 @@ func (pep *PRUDPEndPoint) sendPing(connection *PRUDPConnection) { ping, _ = NewPRUDPPacketLite(pep.Server, connection, nil) } - ping.SetType(PingPacket) - ping.AddFlag(FlagNeedsAck) + ping.SetType(constants.PingPacket) + ping.AddFlag(constants.FlagNeedsAck) ping.SetSourceVirtualPortStreamType(connection.StreamType) ping.SetSourceVirtualPortStreamID(pep.StreamID) ping.SetDestinationVirtualPortStreamType(connection.StreamType) diff --git a/prudp_packet.go b/prudp_packet.go index ad261f49..e2fed56a 100644 --- a/prudp_packet.go +++ b/prudp_packet.go @@ -1,6 +1,10 @@ package nex -import "crypto/rc4" +import ( + "crypto/rc4" + + "github.com/PretendoNetwork/nex-go/constants" +) // PRUDPPacket holds all the fields each packet should have in all PRUDP versions type PRUDPPacket struct { @@ -58,12 +62,12 @@ func (p *PRUDPPacket) Type() uint16 { } // SetSourceVirtualPortStreamType sets the packets source VirtualPort StreamType -func (p *PRUDPPacket) SetSourceVirtualPortStreamType(streamType StreamType) { +func (p *PRUDPPacket) SetSourceVirtualPortStreamType(streamType constants.StreamType) { p.sourceVirtualPort.SetStreamType(streamType) } // SourceVirtualPortStreamType returns the packets source VirtualPort StreamType -func (p *PRUDPPacket) SourceVirtualPortStreamType() StreamType { +func (p *PRUDPPacket) SourceVirtualPortStreamType() constants.StreamType { return p.sourceVirtualPort.StreamType() } @@ -78,12 +82,12 @@ func (p *PRUDPPacket) SourceVirtualPortStreamID() uint8 { } // SetDestinationVirtualPortStreamType sets the packets destination VirtualPort StreamType -func (p *PRUDPPacket) SetDestinationVirtualPortStreamType(streamType StreamType) { +func (p *PRUDPPacket) SetDestinationVirtualPortStreamType(streamType constants.StreamType) { p.destinationVirtualPort.SetStreamType(streamType) } // DestinationVirtualPortStreamType returns the packets destination VirtualPort StreamType -func (p *PRUDPPacket) DestinationVirtualPortStreamType() StreamType { +func (p *PRUDPPacket) DestinationVirtualPortStreamType() constants.StreamType { return p.destinationVirtualPort.StreamType() } @@ -145,7 +149,7 @@ func (p *PRUDPPacket) decryptPayload() []byte { payload := p.payload // TODO - This assumes a reliable DATA packet. Handle unreliable here? Or do that in a different method? - if p.packetType == DataPacket { + if p.packetType == constants.DataPacket { slidingWindow := p.sender.SlidingWindow(p.SubstreamID()) payload, _ = slidingWindow.streamSettings.EncryptionAlgorithm.Decrypt(payload) diff --git a/prudp_packet_interface.go b/prudp_packet_interface.go index d1e38dbb..337acba7 100644 --- a/prudp_packet_interface.go +++ b/prudp_packet_interface.go @@ -1,6 +1,10 @@ package nex -import "net" +import ( + "net" + + "github.com/PretendoNetwork/nex-go/constants" +) // PRUDPPacketInterface defines all the methods a PRUDP packet should have type PRUDPPacketInterface interface { @@ -14,12 +18,12 @@ type PRUDPPacketInterface interface { AddFlag(flag uint16) SetType(packetType uint16) Type() uint16 - SetSourceVirtualPortStreamType(streamType StreamType) - SourceVirtualPortStreamType() StreamType + SetSourceVirtualPortStreamType(streamType constants.StreamType) + SourceVirtualPortStreamType() constants.StreamType SetSourceVirtualPortStreamID(port uint8) SourceVirtualPortStreamID() uint8 - SetDestinationVirtualPortStreamType(streamType StreamType) - DestinationVirtualPortStreamType() StreamType + SetDestinationVirtualPortStreamType(streamType constants.StreamType) + DestinationVirtualPortStreamType() constants.StreamType SetDestinationVirtualPortStreamID(port uint8) DestinationVirtualPortStreamID() uint8 SessionID() uint8 diff --git a/prudp_packet_lite.go b/prudp_packet_lite.go index 8557c9f2..5d368eab 100644 --- a/prudp_packet_lite.go +++ b/prudp_packet_lite.go @@ -6,14 +6,16 @@ import ( "encoding/binary" "fmt" "net" + + "github.com/PretendoNetwork/nex-go/constants" ) // PRUDPPacketLite represents a PRUDPLite packet type PRUDPPacketLite struct { PRUDPPacket - sourceVirtualPortStreamType StreamType + sourceVirtualPortStreamType constants.StreamType sourceVirtualPortStreamID uint8 - destinationVirtualPortStreamType StreamType + destinationVirtualPortStreamType constants.StreamType destinationVirtualPortStreamID uint8 optionsLength uint8 minorVersion uint32 @@ -23,13 +25,13 @@ type PRUDPPacketLite struct { liteSignature []byte } -// SetSourceVirtualPortStreamType sets the packets source VirtualPort StreamType -func (p *PRUDPPacketLite) SetSourceVirtualPortStreamType(streamType StreamType) { +// SetSourceVirtualPortStreamType sets the packets source VirtualPort StreamType +func (p *PRUDPPacketLite) SetSourceVirtualPortStreamType(streamType constants.StreamType) { p.sourceVirtualPortStreamType = streamType } // SourceVirtualPortStreamType returns the packets source VirtualPort StreamType -func (p *PRUDPPacketLite) SourceVirtualPortStreamType() StreamType { +func (p *PRUDPPacketLite) SourceVirtualPortStreamType() constants.StreamType { return p.sourceVirtualPortStreamType } @@ -43,13 +45,13 @@ func (p *PRUDPPacketLite) SourceVirtualPortStreamID() uint8 { return p.sourceVirtualPort.StreamID() } -// SetDestinationVirtualPortStreamType sets the packets destination VirtualPort StreamType -func (p *PRUDPPacketLite) SetDestinationVirtualPortStreamType(streamType StreamType) { +// SetDestinationVirtualPortStreamType sets the packets destination VirtualPort constants.StreamType +func (p *PRUDPPacketLite) SetDestinationVirtualPortStreamType(streamType constants.StreamType) { p.destinationVirtualPortStreamType = streamType } -// DestinationVirtualPortStreamType returns the packets destination VirtualPort StreamType -func (p *PRUDPPacketLite) DestinationVirtualPortStreamType() StreamType { +// DestinationVirtualPortStreamType returns the packets destination VirtualPort constants.StreamType +func (p *PRUDPPacketLite) DestinationVirtualPortStreamType() constants.StreamType { return p.destinationVirtualPortStreamType } @@ -139,8 +141,8 @@ func (p *PRUDPPacketLite) decode() error { return fmt.Errorf("Failed to decode PRUDPLite virtual ports stream types. %s", err.Error()) } - p.sourceVirtualPortStreamType = StreamType(streamTypes >> 4) - p.destinationVirtualPortStreamType = StreamType(streamTypes & 0xF) + p.sourceVirtualPortStreamType = constants.StreamType(streamTypes >> 4) + p.destinationVirtualPortStreamType = constants.StreamType(streamTypes & 0xF) p.sourceVirtualPortStreamID, err = p.readStream.ReadPrimitiveUInt8() if err != nil { @@ -220,7 +222,7 @@ func (p *PRUDPPacketLite) decodeOptions() error { return err } - if p.packetType == SynPacket || p.packetType == ConnectPacket { + if p.packetType == constants.SynPacket || p.packetType == constants.ConnectPacket { if optionID == 0 { p.supportedFunctions, err = optionsStream.ReadPrimitiveUInt32LE() @@ -237,19 +239,19 @@ func (p *PRUDPPacketLite) decodeOptions() error { } } - if p.packetType == ConnectPacket { + if p.packetType == constants.ConnectPacket { if optionID == 3 { p.initialUnreliableSequenceID, err = optionsStream.ReadPrimitiveUInt16LE() } } - if p.packetType == DataPacket { + if p.packetType == constants.DataPacket { if optionID == 2 { p.fragmentID, err = optionsStream.ReadPrimitiveUInt8() } } - if p.packetType == ConnectPacket && !p.HasFlag(FlagAck) { + if p.packetType == constants.ConnectPacket && !p.HasFlag(constants.FlagAck) { if optionID == 0x80 { p.liteSignature = optionsStream.ReadBytesNext(int64(optionSize)) } @@ -269,19 +271,19 @@ func (p *PRUDPPacketLite) decodeOptions() error { func (p *PRUDPPacketLite) encodeOptions() []byte { optionsStream := NewByteStreamOut(p.server.LibraryVersions, p.server.ByteStreamSettings) - if p.packetType == SynPacket || p.packetType == ConnectPacket { + if p.packetType == constants.SynPacket || p.packetType == constants.ConnectPacket { optionsStream.WritePrimitiveUInt8(0) optionsStream.WritePrimitiveUInt8(4) optionsStream.WritePrimitiveUInt32LE(p.minorVersion | (p.supportedFunctions << 8)) - if p.packetType == SynPacket && p.HasFlag(FlagAck) { + if p.packetType == constants.SynPacket && p.HasFlag(constants.FlagAck) { optionsStream.WritePrimitiveUInt8(1) optionsStream.WritePrimitiveUInt8(16) optionsStream.Grow(16) optionsStream.WriteBytesNext(p.connectionSignature) } - if p.packetType == ConnectPacket && !p.HasFlag(FlagAck) { + if p.packetType == constants.ConnectPacket && !p.HasFlag(constants.FlagAck) { optionsStream.WritePrimitiveUInt8(1) optionsStream.WritePrimitiveUInt8(16) optionsStream.Grow(16) diff --git a/prudp_packet_v0.go b/prudp_packet_v0.go index 0ec33800..1a338de2 100644 --- a/prudp_packet_v0.go +++ b/prudp_packet_v0.go @@ -8,6 +8,8 @@ import ( "fmt" "net" "slices" + + "github.com/PretendoNetwork/nex-go/constants" ) // PRUDPPacketV0 represents a PRUDPv0 packet @@ -98,7 +100,7 @@ func (p *PRUDPPacketV0) decode() error { p.packetType = typeAndFlags & 0xF } - if _, ok := validPacketTypes[p.packetType]; !ok { + if p.packetType > constants.PingPacket { return errors.New("Invalid PRUDPv0 packet type") } @@ -114,7 +116,7 @@ func (p *PRUDPPacketV0) decode() error { return fmt.Errorf("Failed to read PRUDPv0 sequence ID. %s", err.Error()) } - if p.packetType == SynPacket || p.packetType == ConnectPacket { + if p.packetType == constants.SynPacket || p.packetType == constants.ConnectPacket { if p.readStream.Remaining() < 4 { return errors.New("Failed to read PRUDPv0 connection signature. Not have enough data") } @@ -122,7 +124,7 @@ func (p *PRUDPPacketV0) decode() error { p.connectionSignature = p.readStream.ReadBytesNext(4) } - if p.packetType == DataPacket { + if p.packetType == constants.DataPacket { if p.readStream.Remaining() < 1 { return errors.New("Failed to read PRUDPv0 fragment ID. Not have enough data") } @@ -135,7 +137,7 @@ func (p *PRUDPPacketV0) decode() error { var payloadSize uint16 - if p.HasFlag(FlagHasSize) { + if p.HasFlag(constants.FlagHasSize) { if p.readStream.Remaining() < 2 { return errors.New("Failed to read PRUDPv0 payload size. Not have enough data") } @@ -209,16 +211,16 @@ func (p *PRUDPPacketV0) Bytes() []byte { stream.WriteBytesNext(p.signature) stream.WritePrimitiveUInt16LE(p.sequenceID) - if p.packetType == SynPacket || p.packetType == ConnectPacket { + if p.packetType == constants.SynPacket || p.packetType == constants.ConnectPacket { stream.Grow(int64(len(p.connectionSignature))) stream.WriteBytesNext(p.connectionSignature) } - if p.packetType == DataPacket { + if p.packetType == constants.DataPacket { stream.WritePrimitiveUInt8(p.fragmentID) } - if p.HasFlag(FlagHasSize) { + if p.HasFlag(constants.FlagHasSize) { stream.WritePrimitiveUInt16LE(uint16(len(p.payload))) } @@ -310,11 +312,11 @@ func defaultPRUDPv0ConnectionSignature(packet *PRUDPPacketV0, addr net.Addr) ([] func defaultPRUDPv0CalculateSignature(packet *PRUDPPacketV0, sessionKey, connectionSignature []byte) []byte { if !packet.server.PRUDPV0Settings.LegacyConnectionSignature { - if packet.packetType == DataPacket { + if packet.packetType == constants.DataPacket { return packet.server.PRUDPV0Settings.DataSignatureCalculator(packet, sessionKey) } - if packet.packetType == DisconnectPacket && packet.server.AccessKey != "ridfebb9" { + if packet.packetType == constants.DisconnectPacket && packet.server.AccessKey != "ridfebb9" { return packet.server.PRUDPV0Settings.DataSignatureCalculator(packet, sessionKey) } } diff --git a/prudp_packet_v1.go b/prudp_packet_v1.go index 00a48260..9a343084 100644 --- a/prudp_packet_v1.go +++ b/prudp_packet_v1.go @@ -8,6 +8,8 @@ import ( "errors" "fmt" "net" + + "github.com/PretendoNetwork/nex-go/constants" ) // PRUDPPacketV1 represents a PRUDPv1 packet @@ -174,7 +176,7 @@ func (p *PRUDPPacketV1) decodeHeader() error { p.flags = typeAndFlags >> 4 p.packetType = typeAndFlags & 0xF - if _, ok := validPacketTypes[p.packetType]; !ok { + if p.packetType > constants.PingPacket { return errors.New("Invalid PRUDPv1 packet type") } @@ -227,7 +229,7 @@ func (p *PRUDPPacketV1) decodeOptions() error { return err } - if p.packetType == SynPacket || p.packetType == ConnectPacket { + if p.packetType == constants.SynPacket || p.packetType == constants.ConnectPacket { if optionID == 0 { p.supportedFunctions, err = optionsStream.ReadPrimitiveUInt32LE() @@ -244,13 +246,13 @@ func (p *PRUDPPacketV1) decodeOptions() error { } } - if p.packetType == ConnectPacket { + if p.packetType == constants.ConnectPacket { if optionID == 3 { p.initialUnreliableSequenceID, err = optionsStream.ReadPrimitiveUInt16LE() } } - if p.packetType == DataPacket { + if p.packetType == constants.DataPacket { if optionID == 2 { p.fragmentID, err = optionsStream.ReadPrimitiveUInt8() } @@ -270,7 +272,7 @@ func (p *PRUDPPacketV1) decodeOptions() error { func (p *PRUDPPacketV1) encodeOptions() []byte { optionsStream := NewByteStreamOut(p.server.LibraryVersions, p.server.ByteStreamSettings) - if p.packetType == SynPacket || p.packetType == ConnectPacket { + if p.packetType == constants.SynPacket || p.packetType == constants.ConnectPacket { optionsStream.WritePrimitiveUInt8(0) optionsStream.WritePrimitiveUInt8(4) optionsStream.WritePrimitiveUInt32LE(p.minorVersion | (p.supportedFunctions << 8)) @@ -288,7 +290,7 @@ func (p *PRUDPPacketV1) encodeOptions() []byte { // * specific order. Due to how this section is // * parsed, though, order REALLY doesn't matter. // * NintendoClients expects option 3 before 4, though - if p.packetType == ConnectPacket { + if p.packetType == constants.ConnectPacket { optionsStream.WritePrimitiveUInt8(3) optionsStream.WritePrimitiveUInt8(2) optionsStream.WritePrimitiveUInt16LE(p.initialUnreliableSequenceID) @@ -299,7 +301,7 @@ func (p *PRUDPPacketV1) encodeOptions() []byte { optionsStream.WritePrimitiveUInt8(p.maximumSubstreamID) } - if p.packetType == DataPacket { + if p.packetType == constants.DataPacket { optionsStream.WritePrimitiveUInt8(2) optionsStream.WritePrimitiveUInt8(1) optionsStream.WritePrimitiveUInt8(p.fragmentID) @@ -388,7 +390,7 @@ func defaultPRUDPv1CalculateSignature(packet *PRUDPPacketV1, sessionKey, connect key := md5.Sum(accessKeyBytes) mac := hmac.New(md5.New, key[:]) - if packet.packetType == ConnectPacket && packet.server.PRUDPV1Settings.LegacyConnectionSignature { + if packet.packetType == constants.ConnectPacket && packet.server.PRUDPV1Settings.LegacyConnectionSignature { connectionSignature = make([]byte, 0) } diff --git a/prudp_server.go b/prudp_server.go index 3b20b388..ad15cd96 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -7,6 +7,7 @@ import ( "net" "runtime" + "github.com/PretendoNetwork/nex-go/constants" "github.com/lxzan/gws" ) @@ -165,12 +166,12 @@ func (ps *PRUDPServer) processPacket(packet PRUDPPacketInterface, address net.Ad return } - if packet.DestinationVirtualPortStreamType() > StreamTypeRelay { + if packet.DestinationVirtualPortStreamType() > constants.StreamTypeRelay { logger.Warningf("Client %s trying to use invalid to destination stream type %d", address.String(), packet.DestinationVirtualPortStreamType()) return } - if packet.SourceVirtualPortStreamType() > StreamTypeRelay { + if packet.SourceVirtualPortStreamType() > constants.StreamTypeRelay { logger.Warningf("Client %s trying to use invalid to source stream type %d", address.String(), packet.DestinationVirtualPortStreamType()) return } @@ -233,13 +234,13 @@ func (ps *PRUDPServer) sendPacket(packet PRUDPPacketInterface) { packetCopy := packet.Copy() connection := packetCopy.Sender().(*PRUDPConnection) - if !packetCopy.HasFlag(FlagAck) && !packetCopy.HasFlag(FlagMultiAck) { - if packetCopy.HasFlag(FlagReliable) { + if !packetCopy.HasFlag(constants.FlagAck) && !packetCopy.HasFlag(constants.FlagMultiAck) { + if packetCopy.HasFlag(constants.FlagReliable) { slidingWindow := connection.SlidingWindow(packetCopy.SubstreamID()) packetCopy.SetSequenceID(slidingWindow.NextOutgoingSequenceID()) - } else if packetCopy.Type() == DataPacket { + } else if packetCopy.Type() == constants.DataPacket { packetCopy.SetSequenceID(connection.outgoingUnreliableSequenceIDCounter.Next()) - } else if packetCopy.Type() == PingPacket { + } else if packetCopy.Type() == constants.PingPacket { packetCopy.SetSequenceID(connection.outgoingPingSequenceIDCounter.Next()) } else { packetCopy.SetSequenceID(0) @@ -248,8 +249,8 @@ func (ps *PRUDPServer) sendPacket(packet PRUDPPacketInterface) { packetCopy.SetSessionID(connection.ServerSessionID) - if packetCopy.Type() == DataPacket && !packetCopy.HasFlag(FlagAck) && !packetCopy.HasFlag(FlagMultiAck) { - if packetCopy.HasFlag(FlagReliable) { + if packetCopy.Type() == constants.DataPacket && !packetCopy.HasFlag(constants.FlagAck) && !packetCopy.HasFlag(constants.FlagMultiAck) { + if packetCopy.HasFlag(constants.FlagReliable) { slidingWindow := connection.SlidingWindow(packetCopy.SubstreamID()) payload := packetCopy.Payload() @@ -278,7 +279,7 @@ func (ps *PRUDPServer) sendPacket(packet PRUDPPacketInterface) { packetCopy.setSignature(packetCopy.calculateSignature(connection.SessionKey, connection.ServerConnectionSignature)) } - if packetCopy.HasFlag(FlagReliable) && packetCopy.HasFlag(FlagNeedsAck) { + if packetCopy.HasFlag(constants.FlagReliable) && packetCopy.HasFlag(constants.FlagNeedsAck) { slidingWindow := connection.SlidingWindow(packetCopy.SubstreamID()) slidingWindow.ResendScheduler.AddPacket(packetCopy) } diff --git a/test/auth.go b/test/auth.go index a85c4329..004cd574 100644 --- a/test/auth.go +++ b/test/auth.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/PretendoNetwork/nex-go" + "github.com/PretendoNetwork/nex-go/constants" "github.com/PretendoNetwork/nex-go/types" ) @@ -95,9 +96,9 @@ func login(packet nex.PRUDPPacketInterface) { responsePacket, _ := nex.NewPRUDPPacketV0(authServer, packet.Sender().(*nex.PRUDPConnection), nil) responsePacket.SetType(packet.Type()) - responsePacket.AddFlag(nex.FlagHasSize) - responsePacket.AddFlag(nex.FlagReliable) - responsePacket.AddFlag(nex.FlagNeedsAck) + responsePacket.AddFlag(constants.FlagHasSize) + responsePacket.AddFlag(constants.FlagReliable) + responsePacket.AddFlag(constants.FlagNeedsAck) responsePacket.SetSourceVirtualPortStreamType(packet.DestinationVirtualPortStreamType()) responsePacket.SetSourceVirtualPortStreamID(packet.DestinationVirtualPortStreamID()) responsePacket.SetDestinationVirtualPortStreamType(packet.SourceVirtualPortStreamType()) @@ -148,9 +149,9 @@ func requestTicket(packet nex.PRUDPPacketInterface) { responsePacket, _ := nex.NewPRUDPPacketV0(authServer, packet.Sender().(*nex.PRUDPConnection), nil) responsePacket.SetType(packet.Type()) - responsePacket.AddFlag(nex.FlagHasSize) - responsePacket.AddFlag(nex.FlagReliable) - responsePacket.AddFlag(nex.FlagNeedsAck) + responsePacket.AddFlag(constants.FlagHasSize) + responsePacket.AddFlag(constants.FlagReliable) + responsePacket.AddFlag(constants.FlagNeedsAck) responsePacket.SetSourceVirtualPortStreamType(packet.DestinationVirtualPortStreamType()) responsePacket.SetSourceVirtualPortStreamID(packet.DestinationVirtualPortStreamID()) responsePacket.SetDestinationVirtualPortStreamType(packet.SourceVirtualPortStreamType()) diff --git a/test/secure.go b/test/secure.go index 615ebf53..47f11490 100644 --- a/test/secure.go +++ b/test/secure.go @@ -6,6 +6,7 @@ import ( "strconv" "github.com/PretendoNetwork/nex-go" + "github.com/PretendoNetwork/nex-go/constants" "github.com/PretendoNetwork/nex-go/types" ) @@ -111,8 +112,8 @@ func registerEx(packet nex.PRUDPPacketInterface) { address := packet.Sender().Address().(*net.UDPAddr).IP.String() - localStation.Fields["address"] = address - localStation.Fields["port"] = strconv.Itoa(packet.Sender().Address().(*net.UDPAddr).Port) + localStation.Params["address"] = address + localStation.Params["port"] = strconv.Itoa(packet.Sender().Address().(*net.UDPAddr).Port) retval := types.NewQResultSuccess(0x00010001) localStationURL := types.NewString(localStation.EncodeToString()) @@ -134,9 +135,9 @@ func registerEx(packet nex.PRUDPPacketInterface) { responsePacket, _ := nex.NewPRUDPPacketV0(secureServer, connection, nil) responsePacket.SetType(packet.Type()) - responsePacket.AddFlag(nex.FlagHasSize) - responsePacket.AddFlag(nex.FlagReliable) - responsePacket.AddFlag(nex.FlagNeedsAck) + responsePacket.AddFlag(constants.FlagHasSize) + responsePacket.AddFlag(constants.FlagReliable) + responsePacket.AddFlag(constants.FlagNeedsAck) responsePacket.SetSourceVirtualPortStreamType(packet.DestinationVirtualPortStreamType()) responsePacket.SetSourceVirtualPortStreamID(packet.DestinationVirtualPortStreamID()) responsePacket.SetDestinationVirtualPortStreamType(packet.SourceVirtualPortStreamType()) @@ -182,9 +183,9 @@ func updateAndGetAllInformation(packet nex.PRUDPPacketInterface) { responsePacket, _ := nex.NewPRUDPPacketV0(secureServer, packet.Sender().(*nex.PRUDPConnection), nil) responsePacket.SetType(packet.Type()) - responsePacket.AddFlag(nex.FlagHasSize) - responsePacket.AddFlag(nex.FlagReliable) - responsePacket.AddFlag(nex.FlagNeedsAck) + responsePacket.AddFlag(constants.FlagHasSize) + responsePacket.AddFlag(constants.FlagReliable) + responsePacket.AddFlag(constants.FlagNeedsAck) responsePacket.SetSourceVirtualPortStreamType(packet.DestinationVirtualPortStreamType()) responsePacket.SetSourceVirtualPortStreamID(packet.DestinationVirtualPortStreamID()) responsePacket.SetDestinationVirtualPortStreamType(packet.SourceVirtualPortStreamType()) @@ -214,9 +215,9 @@ func checkSettingStatus(packet nex.PRUDPPacketInterface) { responsePacket, _ := nex.NewPRUDPPacketV0(secureServer, packet.Sender().(*nex.PRUDPConnection), nil) responsePacket.SetType(packet.Type()) - responsePacket.AddFlag(nex.FlagHasSize) - responsePacket.AddFlag(nex.FlagReliable) - responsePacket.AddFlag(nex.FlagNeedsAck) + responsePacket.AddFlag(constants.FlagHasSize) + responsePacket.AddFlag(constants.FlagReliable) + responsePacket.AddFlag(constants.FlagNeedsAck) responsePacket.SetSourceVirtualPortStreamType(packet.DestinationVirtualPortStreamType()) responsePacket.SetSourceVirtualPortStreamID(packet.DestinationVirtualPortStreamID()) responsePacket.SetDestinationVirtualPortStreamType(packet.SourceVirtualPortStreamType()) @@ -241,9 +242,9 @@ func updatePresence(packet nex.PRUDPPacketInterface) { responsePacket, _ := nex.NewPRUDPPacketV0(secureServer, packet.Sender().(*nex.PRUDPConnection), nil) responsePacket.SetType(packet.Type()) - responsePacket.AddFlag(nex.FlagHasSize) - responsePacket.AddFlag(nex.FlagReliable) - responsePacket.AddFlag(nex.FlagNeedsAck) + responsePacket.AddFlag(constants.FlagHasSize) + responsePacket.AddFlag(constants.FlagReliable) + responsePacket.AddFlag(constants.FlagNeedsAck) responsePacket.SetSourceVirtualPortStreamType(packet.DestinationVirtualPortStreamType()) responsePacket.SetSourceVirtualPortStreamID(packet.DestinationVirtualPortStreamID()) responsePacket.SetDestinationVirtualPortStreamType(packet.SourceVirtualPortStreamType()) diff --git a/virtual_port.go b/virtual_port.go index d479515f..fdf3d1fe 100644 --- a/virtual_port.go +++ b/virtual_port.go @@ -1,5 +1,7 @@ package nex +import "github.com/PretendoNetwork/nex-go/constants" + // VirtualPort in an implementation of rdv::VirtualPort. // PRUDP will reuse a single physical socket connection for many virtual PRUDP connections. // VirtualPorts are a byte which represents a stream for a virtual PRUDP connection. @@ -9,13 +11,13 @@ package nex type VirtualPort byte // SetStreamType sets the VirtualPort stream type -func (vp *VirtualPort) SetStreamType(streamType StreamType) { +func (vp *VirtualPort) SetStreamType(streamType constants.StreamType) { *vp = VirtualPort((byte(*vp) & 0x0F) | (byte(streamType) << 4)) } // StreamType returns the VirtualPort stream type -func (vp VirtualPort) StreamType() StreamType { - return StreamType(vp >> 4) +func (vp VirtualPort) StreamType() constants.StreamType { + return constants.StreamType(vp >> 4) } // SetStreamID sets the VirtualPort stream ID From 719cb01968bd4c1279511455c4ecba18c26817a2 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Fri, 15 Mar 2024 18:24:10 -0400 Subject: [PATCH 163/178] prudp: added more StationURL constants --- constants/nat_filtering_properties.go | 19 +++++++++++++++++++ constants/nat_mapping_properties.go | 19 +++++++++++++++++++ constants/station_url_flag.go | 12 ++++++++++++ constants/station_url_type.go | 20 ++++++++++++++++++++ 4 files changed, 70 insertions(+) create mode 100644 constants/nat_filtering_properties.go create mode 100644 constants/nat_mapping_properties.go create mode 100644 constants/station_url_flag.go create mode 100644 constants/station_url_type.go diff --git a/constants/nat_filtering_properties.go b/constants/nat_filtering_properties.go new file mode 100644 index 00000000..5cec4136 --- /dev/null +++ b/constants/nat_filtering_properties.go @@ -0,0 +1,19 @@ +package constants + +// NATFilteringProperties is an implementation of the nn::nex::NATProperties::FilteringProperties enum. +// +// NATFilteringProperties is used to indicate the NAT filtering properties of the users router. +// +// See https://datatracker.ietf.org/doc/html/rfc4787 for more details +type NATFilteringProperties uint8 + +const ( + // UnknownNATFiltering indicates the NAT type could not be identified + UnknownNATFiltering NATFilteringProperties = iota + + // PIFNATFiltering indicates port-independent filtering + PIFNATFiltering + + // PDFNATFiltering indicates port-dependent filtering + PDFNATFiltering +) diff --git a/constants/nat_mapping_properties.go b/constants/nat_mapping_properties.go new file mode 100644 index 00000000..10c16a6f --- /dev/null +++ b/constants/nat_mapping_properties.go @@ -0,0 +1,19 @@ +package constants + +// NATMappingProperties is an implementation of the nn::nex::NATProperties::MappingProperties enum. +// +// NATMappingProperties is used to indicate the NAT mapping properties of the users router. +// +// See https://datatracker.ietf.org/doc/html/rfc4787 for more details +type NATMappingProperties uint8 + +const ( + // UnknownNATMapping indicates the NAT type could not be identified + UnknownNATMapping NATMappingProperties = iota + + // EIMNATMapping indicates endpoint-independent mapping + EIMNATMapping + + // EDMNATMapping indicates endpoint-dependent mapping + EDMNATMapping +) diff --git a/constants/station_url_flag.go b/constants/station_url_flag.go new file mode 100644 index 00000000..3942898a --- /dev/null +++ b/constants/station_url_flag.go @@ -0,0 +1,12 @@ +package constants + +// StationURLFlag is an enum of flags used by the StationURL "type" parameter. +type StationURLFlag uint8 + +const ( + // StationURLFlagBehindNAT indicates the user is behind NAT + StationURLFlagBehindNAT StationURLFlag = iota + 1 + + // StationURLFlagPublic indicates the station is a public address + StationURLFlagPublic +) diff --git a/constants/station_url_type.go b/constants/station_url_type.go new file mode 100644 index 00000000..ed750d5e --- /dev/null +++ b/constants/station_url_type.go @@ -0,0 +1,20 @@ +package constants + +// StationURLType is an implementation of the nn::nex::StationURL::URLType enum. +// +// StationURLType is used to indicate the type of connection to use when contacting a station. +type StationURLType uint8 + +const ( + // UnknownStationURLType indicates an unknown URL type + UnknownStationURLType StationURLType = iota + + // StationURLPRUDP indicates the station should be contacted with a standard PRUDP connection + StationURLPRUDP + + // StationURLPRUDPS indicates the station should be contacted with a secure PRUDP connection + StationURLPRUDPS + + // StationURLUDP indicates the station should be contacted with raw UDP data. Used for custom protocols + StationURLUDP +) From ff1d3b1482b47e076a4f3ced6578fa8a6d7ae210 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Fri, 15 Mar 2024 18:25:46 -0400 Subject: [PATCH 164/178] prudp: rename packet flags to distinguish them from other constants --- constants/prudp_packet_flags.go | 20 ++++++++++---------- prudp_endpoint.go | 26 +++++++++++++------------- prudp_packet_lite.go | 6 +++--- prudp_packet_v0.go | 4 ++-- prudp_server.go | 10 +++++----- test/auth.go | 12 ++++++------ test/secure.go | 24 ++++++++++++------------ 7 files changed, 51 insertions(+), 51 deletions(-) diff --git a/constants/prudp_packet_flags.go b/constants/prudp_packet_flags.go index 6c54c7f0..33cc324c 100644 --- a/constants/prudp_packet_flags.go +++ b/constants/prudp_packet_flags.go @@ -1,18 +1,18 @@ package constants const ( - // FlagAck is the ID for the PRUDP Ack Flag - FlagAck uint16 = 0x1 + // PacketFlagAck is the ID for the PRUDP Ack Flag + PacketFlagAck uint16 = 0x1 - // FlagReliable is the ID for the PRUDP Reliable Flag - FlagReliable uint16 = 0x2 + // PacketFlagReliable is the ID for the PRUDP Reliable Flag + PacketFlagReliable uint16 = 0x2 - // FlagNeedsAck is the ID for the PRUDP NeedsAck Flag - FlagNeedsAck uint16 = 0x4 + // PacketFlagNeedsAck is the ID for the PRUDP NeedsAck Flag + PacketFlagNeedsAck uint16 = 0x4 - // FlagHasSize is the ID for the PRUDP HasSize Flag - FlagHasSize uint16 = 0x8 + // PacketFlagHasSize is the ID for the PRUDP HasSize Flag + PacketFlagHasSize uint16 = 0x8 - // FlagMultiAck is the ID for the PRUDP MultiAck Flag - FlagMultiAck uint16 = 0x200 + // PacketFlagMultiAck is the ID for the PRUDP MultiAck Flag + PacketFlagMultiAck uint16 = 0x200 ) diff --git a/prudp_endpoint.go b/prudp_endpoint.go index 3778d6ea..f070353a 100644 --- a/prudp_endpoint.go +++ b/prudp_endpoint.go @@ -114,7 +114,7 @@ func (pep *PRUDPEndPoint) processPacket(packet PRUDPPacketInterface, socket *Soc packet.SetSender(connection) - if packet.HasFlag(constants.FlagAck) || packet.HasFlag(constants.FlagMultiAck) { + if packet.HasFlag(constants.PacketFlagAck) || packet.HasFlag(constants.PacketFlagMultiAck) { pep.handleAcknowledgment(packet) return } @@ -141,7 +141,7 @@ func (pep *PRUDPEndPoint) handleAcknowledgment(packet PRUDPPacketInterface) { return } - if packet.HasFlag(constants.FlagMultiAck) { + if packet.HasFlag(constants.PacketFlagMultiAck) { pep.handleMultiAcknowledgment(packet) return } @@ -226,8 +226,8 @@ func (pep *PRUDPEndPoint) handleSyn(packet PRUDPPacketInterface) { connection.Signature = connectionSignature ack.SetType(constants.SynPacket) - ack.AddFlag(constants.FlagAck) - ack.AddFlag(constants.FlagHasSize) + ack.AddFlag(constants.PacketFlagAck) + ack.AddFlag(constants.PacketFlagHasSize) ack.SetSourceVirtualPortStreamType(packet.DestinationVirtualPortStreamType()) ack.SetSourceVirtualPortStreamID(packet.DestinationVirtualPortStreamID()) ack.SetDestinationVirtualPortStreamType(packet.SourceVirtualPortStreamType()) @@ -279,8 +279,8 @@ func (pep *PRUDPEndPoint) handleConnect(packet PRUDPPacketInterface) { connection.ServerSessionID = packet.SessionID() ack.SetType(constants.ConnectPacket) - ack.AddFlag(constants.FlagAck) - ack.AddFlag(constants.FlagHasSize) + ack.AddFlag(constants.PacketFlagAck) + ack.AddFlag(constants.PacketFlagHasSize) ack.SetSourceVirtualPortStreamType(packet.DestinationVirtualPortStreamType()) ack.SetSourceVirtualPortStreamID(packet.DestinationVirtualPortStreamID()) ack.SetDestinationVirtualPortStreamType(packet.SourceVirtualPortStreamType()) @@ -386,7 +386,7 @@ func (pep *PRUDPEndPoint) handleData(packet PRUDPPacketInterface) { connection.resetHeartbeat() - if packet.HasFlag(constants.FlagReliable) { + if packet.HasFlag(constants.PacketFlagReliable) { pep.handleReliable(packet) } else { pep.handleUnreliable(packet) @@ -397,7 +397,7 @@ func (pep *PRUDPEndPoint) handleDisconnect(packet PRUDPPacketInterface) { // TODO - Should we check the state here, or just let the connection disconnect at any time? // TODO - Should we bother to set the connections state here? It's being destroyed anyway - if packet.HasFlag(constants.FlagNeedsAck) { + if packet.HasFlag(constants.PacketFlagNeedsAck) { pep.acknowledgePacket(packet) } @@ -417,7 +417,7 @@ func (pep *PRUDPEndPoint) handlePing(packet PRUDPPacketInterface) { connection.resetHeartbeat() - if packet.HasFlag(constants.FlagNeedsAck) { + if packet.HasFlag(constants.PacketFlagNeedsAck) { pep.acknowledgePacket(packet) } } @@ -500,7 +500,7 @@ func (pep *PRUDPEndPoint) acknowledgePacket(packet PRUDPPacketInterface) { } ack.SetType(packet.Type()) - ack.AddFlag(constants.FlagAck) + ack.AddFlag(constants.PacketFlagAck) ack.SetSourceVirtualPortStreamType(packet.DestinationVirtualPortStreamType()) ack.SetSourceVirtualPortStreamID(packet.DestinationVirtualPortStreamID()) ack.SetDestinationVirtualPortStreamType(packet.SourceVirtualPortStreamType()) @@ -519,7 +519,7 @@ func (pep *PRUDPEndPoint) acknowledgePacket(packet PRUDPPacketInterface) { } func (pep *PRUDPEndPoint) handleReliable(packet PRUDPPacketInterface) { - if packet.HasFlag(constants.FlagNeedsAck) { + if packet.HasFlag(constants.PacketFlagNeedsAck) { pep.acknowledgePacket(packet) } @@ -564,7 +564,7 @@ func (pep *PRUDPEndPoint) handleReliable(packet PRUDPPacketInterface) { } func (pep *PRUDPEndPoint) handleUnreliable(packet PRUDPPacketInterface) { - if packet.HasFlag(constants.FlagNeedsAck) { + if packet.HasFlag(constants.PacketFlagNeedsAck) { pep.acknowledgePacket(packet) } @@ -637,7 +637,7 @@ func (pep *PRUDPEndPoint) sendPing(connection *PRUDPConnection) { } ping.SetType(constants.PingPacket) - ping.AddFlag(constants.FlagNeedsAck) + ping.AddFlag(constants.PacketFlagNeedsAck) ping.SetSourceVirtualPortStreamType(connection.StreamType) ping.SetSourceVirtualPortStreamID(pep.StreamID) ping.SetDestinationVirtualPortStreamType(connection.StreamType) diff --git a/prudp_packet_lite.go b/prudp_packet_lite.go index 5d368eab..02b84f6e 100644 --- a/prudp_packet_lite.go +++ b/prudp_packet_lite.go @@ -251,7 +251,7 @@ func (p *PRUDPPacketLite) decodeOptions() error { } } - if p.packetType == constants.ConnectPacket && !p.HasFlag(constants.FlagAck) { + if p.packetType == constants.ConnectPacket && !p.HasFlag(constants.PacketFlagAck) { if optionID == 0x80 { p.liteSignature = optionsStream.ReadBytesNext(int64(optionSize)) } @@ -276,14 +276,14 @@ func (p *PRUDPPacketLite) encodeOptions() []byte { optionsStream.WritePrimitiveUInt8(4) optionsStream.WritePrimitiveUInt32LE(p.minorVersion | (p.supportedFunctions << 8)) - if p.packetType == constants.SynPacket && p.HasFlag(constants.FlagAck) { + if p.packetType == constants.SynPacket && p.HasFlag(constants.PacketFlagAck) { optionsStream.WritePrimitiveUInt8(1) optionsStream.WritePrimitiveUInt8(16) optionsStream.Grow(16) optionsStream.WriteBytesNext(p.connectionSignature) } - if p.packetType == constants.ConnectPacket && !p.HasFlag(constants.FlagAck) { + if p.packetType == constants.ConnectPacket && !p.HasFlag(constants.PacketFlagAck) { optionsStream.WritePrimitiveUInt8(1) optionsStream.WritePrimitiveUInt8(16) optionsStream.Grow(16) diff --git a/prudp_packet_v0.go b/prudp_packet_v0.go index 1a338de2..447135f7 100644 --- a/prudp_packet_v0.go +++ b/prudp_packet_v0.go @@ -137,7 +137,7 @@ func (p *PRUDPPacketV0) decode() error { var payloadSize uint16 - if p.HasFlag(constants.FlagHasSize) { + if p.HasFlag(constants.PacketFlagHasSize) { if p.readStream.Remaining() < 2 { return errors.New("Failed to read PRUDPv0 payload size. Not have enough data") } @@ -220,7 +220,7 @@ func (p *PRUDPPacketV0) Bytes() []byte { stream.WritePrimitiveUInt8(p.fragmentID) } - if p.HasFlag(constants.FlagHasSize) { + if p.HasFlag(constants.PacketFlagHasSize) { stream.WritePrimitiveUInt16LE(uint16(len(p.payload))) } diff --git a/prudp_server.go b/prudp_server.go index ad15cd96..fa50b017 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -234,8 +234,8 @@ func (ps *PRUDPServer) sendPacket(packet PRUDPPacketInterface) { packetCopy := packet.Copy() connection := packetCopy.Sender().(*PRUDPConnection) - if !packetCopy.HasFlag(constants.FlagAck) && !packetCopy.HasFlag(constants.FlagMultiAck) { - if packetCopy.HasFlag(constants.FlagReliable) { + if !packetCopy.HasFlag(constants.PacketFlagAck) && !packetCopy.HasFlag(constants.PacketFlagMultiAck) { + if packetCopy.HasFlag(constants.PacketFlagReliable) { slidingWindow := connection.SlidingWindow(packetCopy.SubstreamID()) packetCopy.SetSequenceID(slidingWindow.NextOutgoingSequenceID()) } else if packetCopy.Type() == constants.DataPacket { @@ -249,8 +249,8 @@ func (ps *PRUDPServer) sendPacket(packet PRUDPPacketInterface) { packetCopy.SetSessionID(connection.ServerSessionID) - if packetCopy.Type() == constants.DataPacket && !packetCopy.HasFlag(constants.FlagAck) && !packetCopy.HasFlag(constants.FlagMultiAck) { - if packetCopy.HasFlag(constants.FlagReliable) { + if packetCopy.Type() == constants.DataPacket && !packetCopy.HasFlag(constants.PacketFlagAck) && !packetCopy.HasFlag(constants.PacketFlagMultiAck) { + if packetCopy.HasFlag(constants.PacketFlagReliable) { slidingWindow := connection.SlidingWindow(packetCopy.SubstreamID()) payload := packetCopy.Payload() @@ -279,7 +279,7 @@ func (ps *PRUDPServer) sendPacket(packet PRUDPPacketInterface) { packetCopy.setSignature(packetCopy.calculateSignature(connection.SessionKey, connection.ServerConnectionSignature)) } - if packetCopy.HasFlag(constants.FlagReliable) && packetCopy.HasFlag(constants.FlagNeedsAck) { + if packetCopy.HasFlag(constants.PacketFlagReliable) && packetCopy.HasFlag(constants.PacketFlagNeedsAck) { slidingWindow := connection.SlidingWindow(packetCopy.SubstreamID()) slidingWindow.ResendScheduler.AddPacket(packetCopy) } diff --git a/test/auth.go b/test/auth.go index 004cd574..5e41e264 100644 --- a/test/auth.go +++ b/test/auth.go @@ -96,9 +96,9 @@ func login(packet nex.PRUDPPacketInterface) { responsePacket, _ := nex.NewPRUDPPacketV0(authServer, packet.Sender().(*nex.PRUDPConnection), nil) responsePacket.SetType(packet.Type()) - responsePacket.AddFlag(constants.FlagHasSize) - responsePacket.AddFlag(constants.FlagReliable) - responsePacket.AddFlag(constants.FlagNeedsAck) + responsePacket.AddFlag(constants.PacketFlagHasSize) + responsePacket.AddFlag(constants.PacketFlagReliable) + responsePacket.AddFlag(constants.PacketFlagNeedsAck) responsePacket.SetSourceVirtualPortStreamType(packet.DestinationVirtualPortStreamType()) responsePacket.SetSourceVirtualPortStreamID(packet.DestinationVirtualPortStreamID()) responsePacket.SetDestinationVirtualPortStreamType(packet.SourceVirtualPortStreamType()) @@ -149,9 +149,9 @@ func requestTicket(packet nex.PRUDPPacketInterface) { responsePacket, _ := nex.NewPRUDPPacketV0(authServer, packet.Sender().(*nex.PRUDPConnection), nil) responsePacket.SetType(packet.Type()) - responsePacket.AddFlag(constants.FlagHasSize) - responsePacket.AddFlag(constants.FlagReliable) - responsePacket.AddFlag(constants.FlagNeedsAck) + responsePacket.AddFlag(constants.PacketFlagHasSize) + responsePacket.AddFlag(constants.PacketFlagReliable) + responsePacket.AddFlag(constants.PacketFlagNeedsAck) responsePacket.SetSourceVirtualPortStreamType(packet.DestinationVirtualPortStreamType()) responsePacket.SetSourceVirtualPortStreamID(packet.DestinationVirtualPortStreamID()) responsePacket.SetDestinationVirtualPortStreamType(packet.SourceVirtualPortStreamType()) diff --git a/test/secure.go b/test/secure.go index 47f11490..cacca755 100644 --- a/test/secure.go +++ b/test/secure.go @@ -135,9 +135,9 @@ func registerEx(packet nex.PRUDPPacketInterface) { responsePacket, _ := nex.NewPRUDPPacketV0(secureServer, connection, nil) responsePacket.SetType(packet.Type()) - responsePacket.AddFlag(constants.FlagHasSize) - responsePacket.AddFlag(constants.FlagReliable) - responsePacket.AddFlag(constants.FlagNeedsAck) + responsePacket.AddFlag(constants.PacketFlagHasSize) + responsePacket.AddFlag(constants.PacketFlagReliable) + responsePacket.AddFlag(constants.PacketFlagNeedsAck) responsePacket.SetSourceVirtualPortStreamType(packet.DestinationVirtualPortStreamType()) responsePacket.SetSourceVirtualPortStreamID(packet.DestinationVirtualPortStreamID()) responsePacket.SetDestinationVirtualPortStreamType(packet.SourceVirtualPortStreamType()) @@ -183,9 +183,9 @@ func updateAndGetAllInformation(packet nex.PRUDPPacketInterface) { responsePacket, _ := nex.NewPRUDPPacketV0(secureServer, packet.Sender().(*nex.PRUDPConnection), nil) responsePacket.SetType(packet.Type()) - responsePacket.AddFlag(constants.FlagHasSize) - responsePacket.AddFlag(constants.FlagReliable) - responsePacket.AddFlag(constants.FlagNeedsAck) + responsePacket.AddFlag(constants.PacketFlagHasSize) + responsePacket.AddFlag(constants.PacketFlagReliable) + responsePacket.AddFlag(constants.PacketFlagNeedsAck) responsePacket.SetSourceVirtualPortStreamType(packet.DestinationVirtualPortStreamType()) responsePacket.SetSourceVirtualPortStreamID(packet.DestinationVirtualPortStreamID()) responsePacket.SetDestinationVirtualPortStreamType(packet.SourceVirtualPortStreamType()) @@ -215,9 +215,9 @@ func checkSettingStatus(packet nex.PRUDPPacketInterface) { responsePacket, _ := nex.NewPRUDPPacketV0(secureServer, packet.Sender().(*nex.PRUDPConnection), nil) responsePacket.SetType(packet.Type()) - responsePacket.AddFlag(constants.FlagHasSize) - responsePacket.AddFlag(constants.FlagReliable) - responsePacket.AddFlag(constants.FlagNeedsAck) + responsePacket.AddFlag(constants.PacketFlagHasSize) + responsePacket.AddFlag(constants.PacketFlagReliable) + responsePacket.AddFlag(constants.PacketFlagNeedsAck) responsePacket.SetSourceVirtualPortStreamType(packet.DestinationVirtualPortStreamType()) responsePacket.SetSourceVirtualPortStreamID(packet.DestinationVirtualPortStreamID()) responsePacket.SetDestinationVirtualPortStreamType(packet.SourceVirtualPortStreamType()) @@ -242,9 +242,9 @@ func updatePresence(packet nex.PRUDPPacketInterface) { responsePacket, _ := nex.NewPRUDPPacketV0(secureServer, packet.Sender().(*nex.PRUDPConnection), nil) responsePacket.SetType(packet.Type()) - responsePacket.AddFlag(constants.FlagHasSize) - responsePacket.AddFlag(constants.FlagReliable) - responsePacket.AddFlag(constants.FlagNeedsAck) + responsePacket.AddFlag(constants.PacketFlagHasSize) + responsePacket.AddFlag(constants.PacketFlagReliable) + responsePacket.AddFlag(constants.PacketFlagNeedsAck) responsePacket.SetSourceVirtualPortStreamType(packet.DestinationVirtualPortStreamType()) responsePacket.SetSourceVirtualPortStreamID(packet.DestinationVirtualPortStreamID()) responsePacket.SetDestinationVirtualPortStreamType(packet.SourceVirtualPortStreamType()) From dde9a9921dcedcd221c6a0d2df57c27307f32e2d Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sat, 16 Mar 2024 11:49:41 -0400 Subject: [PATCH 165/178] prudp: more accurate StationURL implementation --- test/secure.go | 5 +- types/station_url.go | 542 +++++++++++++++++++++++++++++++++++++++---- 2 files changed, 502 insertions(+), 45 deletions(-) diff --git a/test/secure.go b/test/secure.go index cacca755..56eaff89 100644 --- a/test/secure.go +++ b/test/secure.go @@ -3,7 +3,6 @@ package main import ( "fmt" "net" - "strconv" "github.com/PretendoNetwork/nex-go" "github.com/PretendoNetwork/nex-go/constants" @@ -112,8 +111,8 @@ func registerEx(packet nex.PRUDPPacketInterface) { address := packet.Sender().Address().(*net.UDPAddr).IP.String() - localStation.Params["address"] = address - localStation.Params["port"] = strconv.Itoa(packet.Sender().Address().(*net.UDPAddr).Port) + localStation.SetAddress(address) + localStation.SetPortNumber(uint16(packet.Sender().Address().(*net.UDPAddr).Port)) retval := types.NewQResultSuccess(0x00010001) localStationURL := types.NewString(localStation.EncodeToString()) diff --git a/types/station_url.go b/types/station_url.go index 3a86a906..cfbbf495 100644 --- a/types/station_url.go +++ b/types/station_url.go @@ -2,16 +2,73 @@ package types import ( "fmt" + "strconv" "strings" + + "github.com/PretendoNetwork/nex-go/constants" ) // StationURL is an implementation of rdv::StationURL. +// // Contains location of a station to connect to, with data about how to connect. type StationURL struct { - local bool // * Not part of the data structure. Used for easier lookups elsewhere - public bool // * Not part of the data structure. Used for easier lookups elsewhere - Scheme string - Fields map[string]string + urlType constants.StationURLType + flags uint8 + params map[string]string +} + +func (s *StationURL) numberParamValue(name string, bits int) (uint64, bool) { + valueString, ok := s.ParamValue(name) + if !ok { + return 0, false + } + + value, err := strconv.ParseUint(valueString, 10, bits) + if err != nil { + return 0, false + } + + return value, true +} + +func (s *StationURL) uint8ParamValue(name string) (uint8, bool) { + value, ok := s.numberParamValue(name, 8) + if !ok { + return 0, false + } + + return uint8(value), true +} + +func (s *StationURL) uint16ParamValue(name string) (uint16, bool) { + value, ok := s.numberParamValue(name, 16) + if !ok { + return 0, false + } + + return uint16(value), true +} + +func (s *StationURL) uint32ParamValue(name string) (uint32, bool) { + value, ok := s.numberParamValue(name, 32) + if !ok { + return 0, false + } + + return uint32(value), true +} + +func (s *StationURL) uint64ParamValue(name string) (uint64, bool) { + return s.numberParamValue(name, 64) +} + +func (s *StationURL) boolParamValue(name string) bool { + valueString, ok := s.ParamValue(name) + if !ok { + return false + } + + return valueString == "1" } // WriteTo writes the StationURL to the given writable @@ -47,24 +104,20 @@ func (s *StationURL) Equals(o RVType) bool { other := o.(*StationURL) - if s.local != other.local { - return false - } - - if s.public != other.public { + if s.urlType != other.urlType { return false } - if s.Scheme != other.Scheme { + if s.flags != other.flags { return false } - if len(s.Fields) != len(other.Fields) { + if len(s.params) != len(other.params) { return false } - for key, value1 := range s.Fields { - value2, ok := other.Fields[key] + for key, value1 := range s.params { + value2, ok := other.params[key] if !ok || value1 != value2 { return false } @@ -73,26 +126,400 @@ func (s *StationURL) Equals(o RVType) bool { return true } -// SetLocal marks the StationURL as an local URL -func (s *StationURL) SetLocal() { - s.local = true - s.public = false +// SetParamValue sets a StationURL parameter +func (s *StationURL) SetParamValue(name, value string) { + s.params[name] = value +} + +// RemoveParam removes a StationURL parameter. +// +// Not part of the original API +func (s *StationURL) RemoveParam(name string) { + delete(s.params, name) +} + +// ParamValue returns the value of the requested param. +// +// Returns the string value and a bool indicating if the value existed or not. +// +// Originally called nn::nex::StationURL::GetParamValue +func (s *StationURL) ParamValue(name string) (string, bool) { + if value, ok := s.params[name]; ok { + return value, true + } + + return "", false +} + +// SetAddress sets the stations IP address +func (s *StationURL) SetAddress(address string) { + s.SetParamValue("address", address) +} + +// Address gets the stations IP address. +// +// Originally called nn::nex::StationURL::GetAddress +func (s *StationURL) Address(address string) (string, bool) { + return s.ParamValue("address") +} + +// SetPortNumber sets the stations port +func (s *StationURL) SetPortNumber(port uint16) { + s.SetParamValue("port", strconv.FormatUint(uint64(port), 10)) +} + +// PortNumber gets the stations port. +// +// Returns a bool indicating if the parameter existed or not. +// +// Originally called nn::nex::StationURL::GetPortNumber +func (s *StationURL) PortNumber() (uint16, bool) { + return s.uint16ParamValue("port") +} + +// SetURLType sets the stations URL scheme type +func (s *StationURL) SetURLType(urlType constants.StationURLType) { + s.urlType = urlType +} + +// URLType returns the stations scheme type +// +// Originally called nn::nex::StationURL::GetURLType +func (s *StationURL) URLType() constants.StationURLType { + return s.urlType +} + +// SetStreamID sets the stations stream ID +// +// See VirtualPort +func (s *StationURL) SetStreamID(streamID uint8) { + s.SetParamValue("sid", strconv.FormatUint(uint64(streamID), 10)) +} + +// StreamID gets the stations stream ID. +// +// See VirtualPort. +// +// Returns a bool indicating if the parameter existed or not. +// +// Originally called nn::nex::StationURL::GetStreamID +func (s *StationURL) StreamID() (uint8, bool) { + return s.uint8ParamValue("sid") +} + +// SetStreamType sets the stations stream type +// +// See VirtualPort +func (s *StationURL) SetStreamType(streamType constants.StreamType) { + s.SetParamValue("stream", strconv.FormatUint(uint64(streamType), 10)) +} + +// StreamType gets the stations stream type. +// +// See VirtualPort. +// +// Returns a bool indicating if the parameter existed or not. +// +// Originally called nn::nex::StationURL::GetStreamType +func (s *StationURL) StreamType() (constants.StreamType, bool) { + streamType, ok := s.uint8ParamValue("stream") + + // TODO - Range check on the enum? + + return constants.StreamType(streamType), ok +} + +// SetNodeID sets the stations node ID +// +// Originally called nn::nex::StationURL::SetNodeId +func (s *StationURL) SetNodeID(nodeID uint16) { + s.SetParamValue("NodeID", strconv.FormatUint(uint64(nodeID), 10)) +} + +// NodeID gets the stations node ID. +// +// Returns a bool indicating if the parameter existed or not. +// +// Originally called nn::nex::StationURL::GetNodeId +func (s *StationURL) NodeID() (uint16, bool) { + return s.uint16ParamValue("NodeID") +} + +// SetPrincipalID sets the stations target PID +func (s *StationURL) SetPrincipalID(pid *PID) { + s.SetParamValue("PID", strconv.FormatUint(pid.Value(), 10)) +} + +// PrincipalID gets the stations target PID. +// +// Returns a bool indicating if the parameter existed or not. +// +// Originally called nn::nex::StationURL::GetPrincipalID +func (s *StationURL) PrincipalID() (*PID, bool) { + pid, ok := s.uint64ParamValue("PID") + if !ok { + return nil, false + } + + return NewPID(pid), true +} + +// SetConnectionID sets the stations connection ID +// +// Unsure how this differs from the Rendez-Vous connection ID +func (s *StationURL) SetConnectionID(connectionID uint32) { + s.SetParamValue("CID", strconv.FormatUint(uint64(connectionID), 10)) +} + +// ConnectionID gets the stations connection ID. +// +// Returns a bool indicating if the parameter existed or not. +// +// Originally called nn::nex::StationURL::GetConnectionID +func (s *StationURL) ConnectionID() (uint32, bool) { + return s.uint32ParamValue("CID") } -// SetPublic marks the StationURL as an public URL -func (s *StationURL) SetPublic() { - s.local = false - s.public = true +// SetRVConnectionID sets the stations Rendez-Vous connection ID +// +// Unsure how this differs from the connection ID +func (s *StationURL) SetRVConnectionID(connectionID uint32) { + s.SetParamValue("RVCID", strconv.FormatUint(uint64(connectionID), 10)) } -// IsLocal checks if the StationURL is a local URL -func (s *StationURL) IsLocal() bool { - return s.local +// RVConnectionID gets the stations Rendez-Vous connection ID. +// +// Returns a bool indicating if the parameter existed or not. +// +// Originally called nn::nex::StationURL::GetRVConnectionID +func (s *StationURL) RVConnectionID() (uint32, bool) { + return s.uint32ParamValue("RVCID") } -// IsPublic checks if the StationURL is a public URL +// SetProbeRequestID sets the probe request ID +func (s *StationURL) SetProbeRequestID(probeRequestID uint32) { + s.SetParamValue("PRID", strconv.FormatUint(uint64(probeRequestID), 10)) +} + +// ProbeRequestID gets the probe request ID. +// +// Returns a bool indicating if the parameter existed or not. +// +// Originally called nn::nex::StationURL::GetProbeRequestID +func (s *StationURL) ProbeRequestID() (uint32, bool) { + return s.uint32ParamValue("PRID") +} + +// SetFastProbeResponse sets whether fast probe response should be enabled or not +func (s *StationURL) SetFastProbeResponse(fast bool) { + if fast { + s.SetParamValue("fastproberesponse", "1") + } else { + s.SetParamValue("fastproberesponse", "0") + } +} + +// IsFastProbeResponseEnabled checks if fast probe response is enabled +// +// Originally called nn::nex::StationURL::GetFastProbeResponse +func (s *StationURL) IsFastProbeResponseEnabled() bool { + return s.boolParamValue("fastproberesponse") +} + +// SetNATMapping sets the clients NAT mapping properties +func (s *StationURL) SetNATMapping(mapping constants.NATMappingProperties) { + s.SetParamValue("natm", strconv.FormatUint(uint64(mapping), 10)) +} + +// NATMapping gets the clients NAT mapping properties. +// +// Returns a bool indicating if the parameter existed or not. +// +// Originally called nn::nex::StationURL::GetNATMapping +func (s *StationURL) NATMapping() (constants.NATMappingProperties, bool) { + natm, ok := s.uint8ParamValue("natm") + + // TODO - Range check on the enum? + + return constants.NATMappingProperties(natm), ok +} + +// SetNATFiltering sets the clients NAT filtering properties +func (s *StationURL) SetNATFiltering(filtering constants.NATFilteringProperties) { + s.SetParamValue("natf", strconv.FormatUint(uint64(filtering), 10)) +} + +// NATFiltering gets the clients NAT filtering properties. +// +// Returns a bool indicating if the parameter existed or not. +// +// Originally called nn::nex::StationURL::GetNATFiltering +func (s *StationURL) NATFiltering() (constants.NATFilteringProperties, bool) { + natf, ok := s.uint8ParamValue("natf") + + // TODO - Range check on the enum? + + return constants.NATFilteringProperties(natf), ok +} + +// SetProbeRequestInitiation sets whether probing should begin or not +func (s *StationURL) SetProbeRequestInitiation(probeinit bool) { + if probeinit { + s.SetParamValue("probeinit", "1") + } else { + s.SetParamValue("probeinit", "0") + } +} + +// IsProbeRequestInitiationEnabled checks wheteher probing should be initiated. +// +// Originally called nn::nex::StationURL::GetProbeRequestInitiation +func (s *StationURL) IsProbeRequestInitiationEnabled() bool { + return s.boolParamValue("probeinit") +} + +// SetUPnPSupport sets whether UPnP should be enabled or not +func (s *StationURL) SetUPnPSupport(supported bool) { + if supported { + s.SetParamValue("upnp", "1") + } else { + s.SetParamValue("upnp", "0") + } +} + +// IsUPnPSupported checks whether UPnP is enabled on the station. +// +// Originally called nn::nex::StationURL::GetUPnPSupport +func (s *StationURL) IsUPnPSupported() bool { + return s.boolParamValue("upnp") +} + +// SetNATPMPSupport sets whether PMP should be enabled or not. +// +// Originally called nn::nex::StationURL::SetNatPMPSupport +func (s *StationURL) SetNATPMPSupport(supported bool) { + if supported { + s.SetParamValue("pmp", "1") + } else { + s.SetParamValue("pmp", "0") + } +} + +// IsNATPMPSupported checks whether PMP is enabled on the station. +// +// Originally called nn::nex::StationURL::GetNatPMPSupport +func (s *StationURL) IsNATPMPSupported() bool { + return s.boolParamValue("pmp") +} + +// SetType sets the stations type flags +func (s *StationURL) SetType(flags uint8) { + s.flags = flags // * This normally isn't done, but makes IsPublic and IsBehindNAT simpler + s.SetParamValue("type", strconv.FormatUint(uint64(flags), 10)) +} + +// Type gets the stations type flags. +// +// Returns a bool indicating if the parameter existed or not. +// +// Originally called nn::nex::StationURL::GetType +func (s *StationURL) Type() (uint8, bool) { + return s.uint8ParamValue("type") +} + +// SetRelayServerAddress sets the address for the relay server +func (s *StationURL) SetRelayServerAddress(address string) { + s.SetParamValue("Rsa", address) +} + +// RelayServerAddress gets the address for the relay server +// +// Originally called nn::nex::StationURL::GetRelayServerAddress +func (s *StationURL) RelayServerAddress() (string, bool) { + return s.ParamValue("Rsa") +} + +// SetRelayServerPort sets the port for the relay server +func (s *StationURL) SetRelayServerPort(port uint16) { + s.SetParamValue("Rsp", strconv.FormatUint(uint64(port), 10)) +} + +// RelayServerPort gets the stations relay server port. +// +// Returns a bool indicating if the parameter existed or not. +// +// Originally called nn::nex::StationURL::GetRelayServerPort +func (s *StationURL) RelayServerPort() (uint16, bool) { + return s.uint16ParamValue("Rsp") +} + +// SetRelayAddress gets the address for the relay +func (s *StationURL) SetRelayAddress(address string) { + s.SetParamValue("Ra", address) +} + +// RelayAddress gets the address for the relay +// +// Originally called nn::nex::StationURL::GetRelayAddress +func (s *StationURL) RelayAddress() (string, bool) { + return s.ParamValue("Ra") +} + +// SetRelayPort sets the port for the relay +func (s *StationURL) SetRelayPort(port uint16) { + s.SetParamValue("Rp", strconv.FormatUint(uint64(port), 10)) +} + +// RelayPort gets the stations relay port. +// +// Returns a bool indicating if the parameter existed or not. +// +// Originally called nn::nex::StationURL::GetRelayPort +func (s *StationURL) RelayPort() (uint16, bool) { + return s.uint16ParamValue("Rp") +} + +// SetUseRelayServer sets whether or not a relay server should be used +func (s *StationURL) SetUseRelayServer(useRelayServer bool) { + if useRelayServer { + s.SetParamValue("R", "1") + } else { + s.SetParamValue("R", "0") + } +} + +// IsRelayServerEnabled checks whether the connection should use a relay server. +// +// Originally called nn::nex::StationURL::GetUseRelayServer +func (s *StationURL) IsRelayServerEnabled() bool { + return s.boolParamValue("R") +} + +// SetPlatformType sets the stations platform type +func (s *StationURL) SetPlatformType(platformType uint8) { + // * This is likely to change based on the target platforms, so no enum + // * 2 = Wii U (Seen in Minecraft) + // * 1 = 3DS? Assumed based on Wii U + s.SetParamValue("Pl", strconv.FormatUint(uint64(platformType), 10)) +} + +// PlatformType gets the stations target platform. Legal values vary by developer and platforms. +// +// Returns a bool indicating if the parameter existed or not. +// +// Originally called nn::nex::StationURL::GetPortNumber +func (s *StationURL) PlatformType() (uint8, bool) { + return s.uint8ParamValue("Pl") +} + +// IsPublic checks if the station is a public address func (s *StationURL) IsPublic() bool { - return s.public + return s.flags&uint8(constants.StationURLFlagPublic) == uint8(constants.StationURLFlagPublic) +} + +// IsBehindNAT checks if the user is behind NAT +func (s *StationURL) IsBehindNAT() bool { + return s.flags&uint8(constants.StationURLFlagBehindNAT) == uint8(constants.StationURLFlagBehindNAT) } // FromString parses the StationURL data from a string @@ -101,41 +528,72 @@ func (s *StationURL) FromString(str string) { return } - split := strings.Split(str, ":/") - - s.Scheme = split[0] + parts := strings.Split(str, ":/") + parametersString := "" + + // * Unknown scehemes seem to be supported based on + // * Format__Q3_2nn3nex10StationURLFv + if len(parts) == 1 { + parametersString = parts[0] + s.SetURLType(constants.UnknownStationURLType) + } else if len(parts) == 2 { + scheme := parts[0] + parametersString = parts[1] + + if scheme == "prudp" { + s.SetURLType(constants.StationURLPRUDP) + } else if scheme == "prudps" { + s.SetURLType(constants.StationURLPRUDPS) + } else if scheme == "udp" { + s.SetURLType(constants.StationURLUDP) + } else { + s.SetURLType(constants.UnknownStationURLType) + } + } else { + // * Badly formatted station + return + } // * Return if there are no fields - if split[1] == "" { + if parametersString == "" { return } - fields := strings.Split(split[1], ";") + parameters := strings.Split(parametersString, ";") - for i := 0; i < len(fields); i++ { - field := strings.Split(fields[i], "=") + for i := 0; i < len(parameters); i++ { + parameter := strings.Split(parameters[i], "=") - key := field[0] - value := field[1] + // TODO - StationURL parameters support extra data through the # delimiter. What is that? Need to support it somehow + name := parameter[0] + value := parameter[1] - s.Fields[key] = value + s.SetParamValue(name, value) } } // EncodeToString encodes the StationURL into a string func (s *StationURL) EncodeToString() string { - // * Don't return anything if no scheme is set - if s.Scheme == "" { - return "" + scheme := "" + + // * Unknown scehemes seem to be supported based on + // * Format__Q3_2nn3nex10StationURLFv + if s.urlType == constants.StationURLPRUDP { + scheme = "prudp:/" + } else if s.urlType == constants.StationURLPRUDPS { + scheme = "prudps:/" + } else if s.urlType == constants.StationURLUDP { + scheme = "udp:/" } fields := make([]string, 0) - for key, value := range s.Fields { + for key, value := range s.params { + // TODO - StationURL parameters support extra data through the # delimiter. What is that? Need to support it somehow fields = append(fields, fmt.Sprintf("%s=%s", key, value)) } - return s.Scheme + ":/" + strings.Join(fields, ";") + return scheme + strings.Join(fields, ";") } // String returns a string representation of the struct @@ -160,7 +618,7 @@ func (s *StationURL) FormatToString(indentationLevel int) string { // NewStationURL returns a new StationURL func NewStationURL(str string) *StationURL { stationURL := &StationURL{ - Fields: make(map[string]string), + params: make(map[string]string), } stationURL.FromString(str) From ce2526f3dd540724955c2483ae0ce71cce409519 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sat, 16 Mar 2024 11:59:50 -0400 Subject: [PATCH 166/178] chore: spelling in StationURL comments --- types/station_url.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/types/station_url.go b/types/station_url.go index cfbbf495..98996881 100644 --- a/types/station_url.go +++ b/types/station_url.go @@ -531,7 +531,7 @@ func (s *StationURL) FromString(str string) { parts := strings.Split(str, ":/") parametersString := "" - // * Unknown scehemes seem to be supported based on + // * Unknown schemes seem to be supported based on // * Format__Q3_2nn3nex10StationURLFv if len(parts) == 1 { parametersString = parts[0] @@ -576,7 +576,7 @@ func (s *StationURL) FromString(str string) { func (s *StationURL) EncodeToString() string { scheme := "" - // * Unknown scehemes seem to be supported based on + // * Unknown schemes seem to be supported based on // * Format__Q3_2nn3nex10StationURLFv if s.urlType == constants.StationURLPRUDP { scheme = "prudp:/" From bfadcf1fc239fdebb1c8dd8267445bcb42947f06 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sat, 16 Mar 2024 12:00:41 -0400 Subject: [PATCH 167/178] chore: StationURL does actually have a Remove method --- types/station_url.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/types/station_url.go b/types/station_url.go index 98996881..e32782c0 100644 --- a/types/station_url.go +++ b/types/station_url.go @@ -133,7 +133,7 @@ func (s *StationURL) SetParamValue(name, value string) { // RemoveParam removes a StationURL parameter. // -// Not part of the original API +// Originally called nn::nex::StationURL::Remove func (s *StationURL) RemoveParam(name string) { delete(s.params, name) } From 050ebb1a884b613526d2b000fd658bca3d0bedd2 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sun, 17 Mar 2024 12:12:27 -0400 Subject: [PATCH 168/178] prudp: allow for custom packet types --- prudp_endpoint.go | 32 ++++++++++++++++++++------------ prudp_packet_v0.go | 4 ---- prudp_packet_v1.go | 4 ---- 3 files changed, 20 insertions(+), 20 deletions(-) diff --git a/prudp_endpoint.go b/prudp_endpoint.go index f070353a..87c31bb5 100644 --- a/prudp_endpoint.go +++ b/prudp_endpoint.go @@ -23,6 +23,7 @@ type PRUDPEndPoint struct { StreamID uint8 DefaultStreamSettings *StreamSettings Connections *MutexMap[string, *PRUDPConnection] + packetHandlers map[uint16]func(packet PRUDPPacketInterface) packetEventHandlers map[string][]func(packet PacketInterface) connectionEndedEventHandlers []func(connection *PRUDPConnection) errorEventHandlers []func(err *Error) @@ -39,6 +40,11 @@ func (pep *PRUDPEndPoint) RegisterServiceProtocol(protocol ServiceProtocol) { pep.OnData(protocol.HandlePacket) } +// RegisterCustomPacketHandler registers a custom handler for a given packet type. Used to override existing handlers or create new ones for custom packet types. +func (pep *PRUDPEndPoint) RegisterCustomPacketHandler(packetType uint16, handler func(packet PRUDPPacketInterface)) { + pep.packetHandlers[packetType] = handler +} + // OnData adds an event handler which is fired when a new DATA packet is received func (pep *PRUDPEndPoint) OnData(handler func(packet PacketInterface)) { pep.on("data", handler) @@ -119,17 +125,10 @@ func (pep *PRUDPEndPoint) processPacket(packet PRUDPPacketInterface, socket *Soc return } - switch packet.Type() { - case constants.SynPacket: - pep.handleSyn(packet) - case constants.ConnectPacket: - pep.handleConnect(packet) - case constants.DataPacket: - pep.handleData(packet) - case constants.DisconnectPacket: - pep.handleDisconnect(packet) - case constants.PingPacket: - pep.handlePing(packet) + if packetHandler, ok := pep.packetHandlers[packet.Type()]; ok { + packetHandler(packet) + } else { + logger.Warningf("Unhandled packet type %d", packet.Type()) } } @@ -721,14 +720,23 @@ func (pep *PRUDPEndPoint) EnableVerboseRMC(enable bool) { // NewPRUDPEndPoint returns a new PRUDPEndPoint for a server on the provided stream ID func NewPRUDPEndPoint(streamID uint8) *PRUDPEndPoint { - return &PRUDPEndPoint{ + pep := &PRUDPEndPoint{ StreamID: streamID, DefaultStreamSettings: NewStreamSettings(), Connections: NewMutexMap[string, *PRUDPConnection](), + packetHandlers: make(map[uint16]func(packet PRUDPPacketInterface)), packetEventHandlers: make(map[string][]func(PacketInterface)), connectionEndedEventHandlers: make([]func(connection *PRUDPConnection), 0), errorEventHandlers: make([]func(err *Error), 0), ConnectionIDCounter: NewCounter[uint32](0), IsSecureEndPoint: false, } + + pep.packetHandlers[constants.SynPacket] = pep.handleSyn + pep.packetHandlers[constants.ConnectPacket] = pep.handleConnect + pep.packetHandlers[constants.DataPacket] = pep.handleData + pep.packetHandlers[constants.DisconnectPacket] = pep.handleDisconnect + pep.packetHandlers[constants.PingPacket] = pep.handlePing + + return pep } diff --git a/prudp_packet_v0.go b/prudp_packet_v0.go index 447135f7..47eb7614 100644 --- a/prudp_packet_v0.go +++ b/prudp_packet_v0.go @@ -100,10 +100,6 @@ func (p *PRUDPPacketV0) decode() error { p.packetType = typeAndFlags & 0xF } - if p.packetType > constants.PingPacket { - return errors.New("Invalid PRUDPv0 packet type") - } - p.sessionID, err = p.readStream.ReadPrimitiveUInt8() if err != nil { return fmt.Errorf("Failed to read PRUDPv0 session ID. %s", err.Error()) diff --git a/prudp_packet_v1.go b/prudp_packet_v1.go index 9a343084..ec6f7e43 100644 --- a/prudp_packet_v1.go +++ b/prudp_packet_v1.go @@ -176,10 +176,6 @@ func (p *PRUDPPacketV1) decodeHeader() error { p.flags = typeAndFlags >> 4 p.packetType = typeAndFlags & 0xF - if p.packetType > constants.PingPacket { - return errors.New("Invalid PRUDPv1 packet type") - } - p.sessionID, err = p.readStream.ReadPrimitiveUInt8() if err != nil { return fmt.Errorf("Failed to read PRUDPv1 session ID. %s", err.Error()) From a0d9eacb86bada4582435acf18751e937b0582c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20L=C3=B3pez=20Guimaraes?= Date: Sun, 17 Mar 2024 22:29:58 +0000 Subject: [PATCH 169/178] types/station_url: Remove unused parameter for Address getter --- types/station_url.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/types/station_url.go b/types/station_url.go index e32782c0..909609ff 100644 --- a/types/station_url.go +++ b/types/station_url.go @@ -159,7 +159,7 @@ func (s *StationURL) SetAddress(address string) { // Address gets the stations IP address. // // Originally called nn::nex::StationURL::GetAddress -func (s *StationURL) Address(address string) (string, bool) { +func (s *StationURL) Address() (string, bool) { return s.ParamValue("address") } From 115a2962317fe4fc1f84e56274bfa1821f6c969d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20L=C3=B3pez=20Guimaraes?= Date: Sun, 17 Mar 2024 22:31:02 +0000 Subject: [PATCH 170/178] rmc_message: Remove unneeded error checks --- rmc_message.go | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/rmc_message.go b/rmc_message.go index 26276c52..d6bd13e9 100644 --- a/rmc_message.go +++ b/rmc_message.go @@ -93,9 +93,6 @@ func (rmc *RMCMessage) decodePacked(data []byte) error { } rmc.Parameters = stream.ReadRemaining() - if err != nil { - return fmt.Errorf("Failed to read RMC Message (request) parameters. %s", err.Error()) - } } else { rmc.IsRequest = false rmc.IsSuccess, err = stream.ReadPrimitiveBool() @@ -115,14 +112,7 @@ func (rmc *RMCMessage) decodePacked(data []byte) error { } rmc.MethodID = rmc.MethodID & ^uint32(0x8000) - if err != nil { - return fmt.Errorf("Failed to read RMC Message (response) method ID. %s", err.Error()) - } - rmc.Parameters = stream.ReadRemaining() - if err != nil { - return fmt.Errorf("Failed to read RMC Message (response) parameters. %s", err.Error()) - } } else { rmc.ErrorCode, err = stream.ReadPrimitiveUInt32LE() @@ -180,9 +170,6 @@ func (rmc *RMCMessage) decodeVerbose(data []byte) error { } rmc.Parameters = stream.ReadRemaining() - if err != nil { - return fmt.Errorf("Failed to read RMC Message (request) parameters. %s", err.Error()) - } } else { rmc.IsSuccess, err = stream.ReadPrimitiveBool() if err != nil { @@ -201,9 +188,6 @@ func (rmc *RMCMessage) decodeVerbose(data []byte) error { } rmc.Parameters = stream.ReadRemaining() - if err != nil { - return fmt.Errorf("Failed to read RMC Message (response) parameters. %s", err.Error()) - } } else { rmc.ErrorCode, err = stream.ReadPrimitiveUInt32LE() From f2b9bf2085d91ea6b5538082dae57e84f77b46ee Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sat, 23 Mar 2024 12:03:48 -0400 Subject: [PATCH 171/178] update: added MutexSlice --- mutex_slice.go | 153 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 153 insertions(+) create mode 100644 mutex_slice.go diff --git a/mutex_slice.go b/mutex_slice.go new file mode 100644 index 00000000..f63d2c60 --- /dev/null +++ b/mutex_slice.go @@ -0,0 +1,153 @@ +package nex + +import "sync" + +// TODO - This currently only properly supports Go native types, due to the use of == for comparisons. Can this be updated to support custom types? + +// MutexSlice implements a slice type with go routine safe accessors through mutex locks. +// +// Embeds sync.RWMutex. +type MutexSlice[V comparable] struct { + *sync.RWMutex + real []V +} + +// Add adds a value to the slice +func (m *MutexSlice[V]) Add(value V) { + m.Lock() + defer m.Unlock() + + m.real = append(m.real, value) +} + +// Delete removes the first instance of the given value from the slice. +// +// Returns true if the value existed and was deleted, otherwise returns false. +func (m *MutexSlice[V]) Delete(value V) bool { + m.Lock() + defer m.Unlock() + + for i, v := range m.real { + if v == value { + m.real = append(m.real[:i], m.real[i+1:]...) + return true + } + } + + return false +} + +// DeleteAll removes all instances of the given value from the slice. +// +// Returns true if the value existed and was deleted, otherwise returns false. +func (m *MutexSlice[V]) DeleteAll(value V) bool { + m.Lock() + defer m.Unlock() + + newSlice := make([]V, 0) + oldLength := len(m.real) + + for _, v := range m.real { + if v != value { + newSlice = append(newSlice, v) + } + } + + m.real = newSlice + + return len(newSlice) < oldLength +} + +// Has checks if the slice contains the given value. +func (m *MutexSlice[V]) Has(value V) bool { + m.Lock() + defer m.Unlock() + + for _, v := range m.real { + if v == value { + return true + } + } + + return false +} + +// GetIndex checks if the slice contains the given value and returns it's index. +// +// Returns -1 if the value does not exist in the slice. +func (m *MutexSlice[V]) GetIndex(value V) int { + m.Lock() + defer m.Unlock() + + for i, v := range m.real { + if v == value { + return i + } + } + + return -1 +} + +// At returns value at the given index. +// +// Returns a bool indicating if the value was found successfully. +func (m *MutexSlice[V]) At(index int) (V, bool) { + m.Lock() + defer m.Unlock() + + if index >= len(m.real) { + return *new(V), false + } + + return m.real[index], true +} + +// Values returns the internal slice. +func (m *MutexSlice[V]) Values() []V { + m.Lock() + defer m.Unlock() + + return m.real +} + +// Size returns the length of the internal slice +func (m *MutexSlice[V]) Size() int { + m.RLock() + defer m.RUnlock() + + return len(m.real) +} + +// Each runs a callback function for every item in the slice. +// +// The slice cannot not be modified inside the callback function. +// +// Returns true if the loop was terminated early. +func (m *MutexSlice[V]) Each(callback func(index int, value V) bool) bool { + m.RLock() + defer m.RUnlock() + + for i, value := range m.real { + if callback(i, value) { + return true + } + } + + return false +} + +// Clear removes all items from the slice. +func (m *MutexSlice[V]) Clear() { + m.Lock() + defer m.Unlock() + + m.real = make([]V, 0) +} + +// NewMutexSlice returns a new instance of MutexSlice with the provided value type +func NewMutexSlice[K comparable, V comparable]() *MutexSlice[V] { + return &MutexSlice[V]{ + RWMutex: &sync.RWMutex{}, + real: make([]V, 0), + } +} From 24a197d7477f1bed4cad723ce5d61d59a6dbd44b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20L=C3=B3pez=20Guimaraes?= Date: Sat, 23 Mar 2024 19:24:36 +0000 Subject: [PATCH 172/178] mutex_slice: Fix typos --- mutex_slice.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mutex_slice.go b/mutex_slice.go index f63d2c60..eec61798 100644 --- a/mutex_slice.go +++ b/mutex_slice.go @@ -120,7 +120,7 @@ func (m *MutexSlice[V]) Size() int { // Each runs a callback function for every item in the slice. // -// The slice cannot not be modified inside the callback function. +// The slice cannot be modified inside the callback function. // // Returns true if the loop was terminated early. func (m *MutexSlice[V]) Each(callback func(index int, value V) bool) bool { @@ -145,7 +145,7 @@ func (m *MutexSlice[V]) Clear() { } // NewMutexSlice returns a new instance of MutexSlice with the provided value type -func NewMutexSlice[K comparable, V comparable]() *MutexSlice[V] { +func NewMutexSlice[V comparable]() *MutexSlice[V] { return &MutexSlice[V]{ RWMutex: &sync.RWMutex{}, real: make([]V, 0), From 2ccb21dfd64ac9e959ac164a715599143e23b765 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20L=C3=B3pez=20Guimaraes?= Date: Tue, 26 Mar 2024 23:15:55 +0000 Subject: [PATCH 173/178] StationURL: Use strings.Cut for extracting parameters This avoids crashes when a field with no value is given. This popped up on Mario Kart 7. --- types/station_url.go | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/types/station_url.go b/types/station_url.go index 909609ff..35e91805 100644 --- a/types/station_url.go +++ b/types/station_url.go @@ -562,11 +562,8 @@ func (s *StationURL) FromString(str string) { parameters := strings.Split(parametersString, ";") for i := 0; i < len(parameters); i++ { - parameter := strings.Split(parameters[i], "=") - // TODO - StationURL parameters support extra data through the # delimiter. What is that? Need to support it somehow - name := parameter[0] - value := parameter[1] + name, value, _ := strings.Cut(parameters[i], "=") s.SetParamValue(name, value) } From 02078bf185fb8e166c4c96146a2ba4ec61b70753 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Fri, 29 Mar 2024 21:18:39 -0400 Subject: [PATCH 174/178] types: fix primitive number type shift methods --- types/primitive_s16.go | 4 ++-- types/primitive_s32.go | 4 ++-- types/primitive_s64.go | 4 ++-- types/primitive_s8.go | 4 ++-- types/primitive_u16.go | 4 ++-- types/primitive_u32.go | 4 ++-- types/primitive_u64.go | 4 ++-- types/primitive_u8.go | 4 ++-- 8 files changed, 16 insertions(+), 16 deletions(-) diff --git a/types/primitive_s16.go b/types/primitive_s16.go index 4bf8ebc1..2d1948d7 100644 --- a/types/primitive_s16.go +++ b/types/primitive_s16.go @@ -100,7 +100,7 @@ func (s16 *PrimitiveS16) LShift(other *PrimitiveS16) *PrimitiveS16 { // PLShift (Primitive Left Shift) runs a left shift operation on the PrimitiveS16 value. Consumes and returns a Go primitive func (s16 *PrimitiveS16) PLShift(value int16) int16 { - return s16.Value &^ value + return s16.Value << value } // RShift runs a right shift operation on the PrimitiveS16 value. Consumes and returns a NEX primitive @@ -110,7 +110,7 @@ func (s16 *PrimitiveS16) RShift(other *PrimitiveS16) *PrimitiveS16 { // PRShift (Primitive Right Shift) runs a right shift operation on the PrimitiveS16 value. Consumes and returns a Go primitive func (s16 *PrimitiveS16) PRShift(value int16) int16 { - return s16.Value &^ value + return s16.Value >> value } // NewPrimitiveS16 returns a new PrimitiveS16 diff --git a/types/primitive_s32.go b/types/primitive_s32.go index a0cca294..a0a439c3 100644 --- a/types/primitive_s32.go +++ b/types/primitive_s32.go @@ -100,7 +100,7 @@ func (s32 *PrimitiveS32) LShift(other *PrimitiveS32) *PrimitiveS32 { // PLShift (Primitive Left Shift) runs a left shift operation on the PrimitiveS32 value. Consumes and returns a Go primitive func (s32 *PrimitiveS32) PLShift(value int32) int32 { - return s32.Value &^ value + return s32.Value << value } // RShift runs a right shift operation on the PrimitiveS32 value. Consumes and returns a NEX primitive @@ -110,7 +110,7 @@ func (s32 *PrimitiveS32) RShift(other *PrimitiveS32) *PrimitiveS32 { // PRShift (Primitive Right Shift) runs a right shift operation on the PrimitiveS32 value. Consumes and returns a Go primitive func (s32 *PrimitiveS32) PRShift(value int32) int32 { - return s32.Value &^ value + return s32.Value >> value } // NewPrimitiveS32 returns a new PrimitiveS32 diff --git a/types/primitive_s64.go b/types/primitive_s64.go index 946cf9f0..ac5c33c0 100644 --- a/types/primitive_s64.go +++ b/types/primitive_s64.go @@ -100,7 +100,7 @@ func (s64 *PrimitiveS64) LShift(other *PrimitiveS64) *PrimitiveS64 { // PLShift (Primitive Left Shift) runs a left shift operation on the PrimitiveS64 value. Consumes and returns a Go primitive func (s64 *PrimitiveS64) PLShift(value int64) int64 { - return s64.Value &^ value + return s64.Value << value } // RShift runs a right shift operation on the PrimitiveS64 value. Consumes and returns a NEX primitive @@ -110,7 +110,7 @@ func (s64 *PrimitiveS64) RShift(other *PrimitiveS64) *PrimitiveS64 { // PRShift (Primitive Right Shift) runs a right shift operation on the PrimitiveS64 value. Consumes and returns a Go primitive func (s64 *PrimitiveS64) PRShift(value int64) int64 { - return s64.Value &^ value + return s64.Value >> value } // NewPrimitiveS64 returns a new PrimitiveS64 diff --git a/types/primitive_s8.go b/types/primitive_s8.go index 2328fed0..f3971319 100644 --- a/types/primitive_s8.go +++ b/types/primitive_s8.go @@ -100,7 +100,7 @@ func (s8 *PrimitiveS8) LShift(other *PrimitiveS8) *PrimitiveS8 { // PLShift (Primitive Left Shift) runs a left shift operation on the PrimitiveS8 value. Consumes and returns a Go primitive func (s8 *PrimitiveS8) PLShift(value int8) int8 { - return s8.Value &^ value + return s8.Value << value } // RShift runs a right shift operation on the PrimitiveS8 value. Consumes and returns a NEX primitive @@ -110,7 +110,7 @@ func (s8 *PrimitiveS8) RShift(other *PrimitiveS8) *PrimitiveS8 { // PRShift (Primitive Right Shift) runs a right shift operation on the PrimitiveS8 value. Consumes and returns a Go primitive func (s8 *PrimitiveS8) PRShift(value int8) int8 { - return s8.Value &^ value + return s8.Value >> value } // NewPrimitiveS8 returns a new PrimitiveS8 diff --git a/types/primitive_u16.go b/types/primitive_u16.go index 7b174ddf..a8678b76 100644 --- a/types/primitive_u16.go +++ b/types/primitive_u16.go @@ -100,7 +100,7 @@ func (u16 *PrimitiveU16) LShift(other *PrimitiveU16) *PrimitiveU16 { // PLShift (Primitive Left Shift) runs a left shift operation on the PrimitiveU16 value. Consumes and returns a Go primitive func (u16 *PrimitiveU16) PLShift(value uint16) uint16 { - return u16.Value &^ value + return u16.Value << value } // RShift runs a right shift operation on the PrimitiveU16 value. Consumes and returns a NEX primitive @@ -110,7 +110,7 @@ func (u16 *PrimitiveU16) RShift(other *PrimitiveU16) *PrimitiveU16 { // PRShift (Primitive Right Shift) runs a right shift operation on the PrimitiveU16 value. Consumes and returns a Go primitive func (u16 *PrimitiveU16) PRShift(value uint16) uint16 { - return u16.Value &^ value + return u16.Value >> value } // NewPrimitiveU16 returns a new PrimitiveU16 diff --git a/types/primitive_u32.go b/types/primitive_u32.go index 744ad4ac..95028516 100644 --- a/types/primitive_u32.go +++ b/types/primitive_u32.go @@ -100,7 +100,7 @@ func (u32 *PrimitiveU32) LShift(other *PrimitiveU32) *PrimitiveU32 { // PLShift (Primitive Left Shift) runs a left shift operation on the PrimitiveU32 value. Consumes and returns a Go primitive func (u32 *PrimitiveU32) PLShift(value uint32) uint32 { - return u32.Value &^ value + return u32.Value << value } // RShift runs a right shift operation on the PrimitiveU32 value. Consumes and returns a NEX primitive @@ -110,7 +110,7 @@ func (u32 *PrimitiveU32) RShift(other *PrimitiveU32) *PrimitiveU32 { // PRShift (Primitive Right Shift) runs a right shift operation on the PrimitiveU32 value. Consumes and returns a Go primitive func (u32 *PrimitiveU32) PRShift(value uint32) uint32 { - return u32.Value &^ value + return u32.Value >> value } // NewPrimitiveU32 returns a new PrimitiveU32 diff --git a/types/primitive_u64.go b/types/primitive_u64.go index f960f351..499fcd19 100644 --- a/types/primitive_u64.go +++ b/types/primitive_u64.go @@ -100,7 +100,7 @@ func (u64 *PrimitiveU64) LShift(other *PrimitiveU64) *PrimitiveU64 { // PLShift (Primitive Left Shift) runs a left shift operation on the PrimitiveU64 value. Consumes and returns a Go primitive func (u64 *PrimitiveU64) PLShift(value uint64) uint64 { - return u64.Value &^ value + return u64.Value << value } // RShift runs a right shift operation on the PrimitiveU64 value. Consumes and returns a NEX primitive @@ -110,7 +110,7 @@ func (u64 *PrimitiveU64) RShift(other *PrimitiveU64) *PrimitiveU64 { // PRShift (Primitive Right Shift) runs a right shift operation on the PrimitiveU64 value. Consumes and returns a Go primitive func (u64 *PrimitiveU64) PRShift(value uint64) uint64 { - return u64.Value &^ value + return u64.Value >> value } // NewPrimitiveU64 returns a new PrimitiveU64 diff --git a/types/primitive_u8.go b/types/primitive_u8.go index 78c3c793..af664ac6 100644 --- a/types/primitive_u8.go +++ b/types/primitive_u8.go @@ -100,7 +100,7 @@ func (u8 *PrimitiveU8) LShift(other *PrimitiveU8) *PrimitiveU8 { // PLShift (Primitive Left Shift) runs a left shift operation on the PrimitiveU8 value. Consumes and returns a Go primitive func (u8 *PrimitiveU8) PLShift(value uint8) uint8 { - return u8.Value &^ value + return u8.Value << value } // RShift runs a right shift operation on the PrimitiveU8 value. Consumes and returns a NEX primitive @@ -110,7 +110,7 @@ func (u8 *PrimitiveU8) RShift(other *PrimitiveU8) *PrimitiveU8 { // PRShift (Primitive Right Shift) runs a right shift operation on the PrimitiveU8 value. Consumes and returns a Go primitive func (u8 *PrimitiveU8) PRShift(value uint8) uint8 { - return u8.Value &^ value + return u8.Value >> value } // NewPrimitiveU8 returns a new PrimitiveU8 From 29b24b93bb3a7f7068f5087ae3e12e3be0510b58 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Fri, 29 Mar 2024 21:20:51 -0400 Subject: [PATCH 175/178] types: add OOB check in List.SetIndex --- types/list.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/types/list.go b/types/list.go index 56a303cd..3269744d 100644 --- a/types/list.go +++ b/types/list.go @@ -101,8 +101,14 @@ func (l *List[T]) Get(index int) (T, error) { } // SetIndex sets a value in the List at the given index -func (l *List[T]) SetIndex(index int, value T) { +func (l *List[T]) SetIndex(index int, value T) error { + if index < 0 || index >= len(l.real) { + return errors.New("Index out of bounds") + } + l.real[index] = value + + return nil } // DeleteIndex deletes an element at the given index. Returns an error if the index is OOB From a7b7cd29f46a92f4c1390ae9deab585f7dabb9ea Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Sun, 31 Mar 2024 21:45:28 -0400 Subject: [PATCH 176/178] prudp: move resetHeartbeat call back to processPacket --- prudp_endpoint.go | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/prudp_endpoint.go b/prudp_endpoint.go index 87c31bb5..c7e643a6 100644 --- a/prudp_endpoint.go +++ b/prudp_endpoint.go @@ -119,6 +119,7 @@ func (pep *PRUDPEndPoint) processPacket(packet PRUDPPacketInterface, socket *Soc } packet.SetSender(connection) + connection.resetHeartbeat() if packet.HasFlag(constants.PacketFlagAck) || packet.HasFlag(constants.PacketFlagMultiAck) { pep.handleAcknowledgment(packet) @@ -383,8 +384,6 @@ func (pep *PRUDPEndPoint) handleData(packet PRUDPPacketInterface) { return } - connection.resetHeartbeat() - if packet.HasFlag(constants.PacketFlagReliable) { pep.handleReliable(packet) } else { @@ -412,10 +411,6 @@ func (pep *PRUDPEndPoint) handleDisconnect(packet PRUDPPacketInterface) { } func (pep *PRUDPEndPoint) handlePing(packet PRUDPPacketInterface) { - connection := packet.Sender().(*PRUDPConnection) - - connection.resetHeartbeat() - if packet.HasFlag(constants.PacketFlagNeedsAck) { pep.acknowledgePacket(packet) } From 58526f8a1a9b36c4c190995e04f6ad529d7880d8 Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Mon, 1 Apr 2024 22:00:02 -0400 Subject: [PATCH 177/178] prudp: fix StationURL.PlatformType comment --- types/station_url.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/types/station_url.go b/types/station_url.go index 35e91805..964f7a27 100644 --- a/types/station_url.go +++ b/types/station_url.go @@ -507,7 +507,7 @@ func (s *StationURL) SetPlatformType(platformType uint8) { // // Returns a bool indicating if the parameter existed or not. // -// Originally called nn::nex::StationURL::GetPortNumber +// Originally called nn::nex::StationURL::GetPlatformType func (s *StationURL) PlatformType() (uint8, bool) { return s.uint8ParamValue("Pl") } From dff47f8748603d4893aea3f82a29d4c3e17bfeea Mon Sep 17 00:00:00 2001 From: Jonathan Barrow Date: Tue, 2 Apr 2024 18:02:43 -0400 Subject: [PATCH 178/178] prudp: remove unused ReliablePacketSubstreamManager --- reliable_packet_substream_manager.go | 93 ---------------------------- 1 file changed, 93 deletions(-) delete mode 100644 reliable_packet_substream_manager.go diff --git a/reliable_packet_substream_manager.go b/reliable_packet_substream_manager.go deleted file mode 100644 index cd85ce63..00000000 --- a/reliable_packet_substream_manager.go +++ /dev/null @@ -1,93 +0,0 @@ -package nex - -import ( - "crypto/rc4" -) - -// ReliablePacketSubstreamManager represents a substream manager for reliable PRUDP packets -type ReliablePacketSubstreamManager struct { - packetMap *MutexMap[uint16, PRUDPPacketInterface] - incomingSequenceIDCounter *Counter[uint16] - outgoingSequenceIDCounter *Counter[uint16] - cipher *rc4.Cipher - decipher *rc4.Cipher - fragmentedPayload []byte - ResendScheduler *ResendScheduler -} - -// Update adds an incoming packet to the list of known packets and returns a list of packets to be processed in order -func (psm *ReliablePacketSubstreamManager) Update(packet PRUDPPacketInterface) []PRUDPPacketInterface { - packets := make([]PRUDPPacketInterface, 0) - - if packet.SequenceID() >= psm.incomingSequenceIDCounter.Value && !psm.packetMap.Has(packet.SequenceID()) { - psm.packetMap.Set(packet.SequenceID(), packet) - - for psm.packetMap.Has(psm.incomingSequenceIDCounter.Value) { - storedPacket, _ := psm.packetMap.Get(psm.incomingSequenceIDCounter.Value) - packets = append(packets, storedPacket) - psm.packetMap.Delete(psm.incomingSequenceIDCounter.Value) - psm.incomingSequenceIDCounter.Next() - } - } - - return packets -} - -// SetCipherKey sets the reliable substreams RC4 cipher keys -func (psm *ReliablePacketSubstreamManager) SetCipherKey(key []byte) { - cipher, _ := rc4.NewCipher(key) - decipher, _ := rc4.NewCipher(key) - - psm.cipher = cipher - psm.decipher = decipher -} - -// NextOutgoingSequenceID sets the reliable substreams RC4 cipher keys -func (psm *ReliablePacketSubstreamManager) NextOutgoingSequenceID() uint16 { - return psm.outgoingSequenceIDCounter.Next() -} - -// Decrypt decrypts the provided data with the substreams decipher -func (psm *ReliablePacketSubstreamManager) Decrypt(data []byte) []byte { - ciphered := make([]byte, len(data)) - - psm.decipher.XORKeyStream(ciphered, data) - - return ciphered -} - -// Encrypt encrypts the provided data with the substreams cipher -func (psm *ReliablePacketSubstreamManager) Encrypt(data []byte) []byte { - ciphered := make([]byte, len(data)) - - psm.cipher.XORKeyStream(ciphered, data) - - return ciphered -} - -// AddFragment adds the given fragment to the substreams fragmented payload -// Returns the current fragmented payload -func (psm *ReliablePacketSubstreamManager) AddFragment(fragment []byte) []byte { - psm.fragmentedPayload = append(psm.fragmentedPayload, fragment...) - - return psm.fragmentedPayload -} - -// ResetFragmentedPayload resets the substreams fragmented payload -func (psm *ReliablePacketSubstreamManager) ResetFragmentedPayload() { - psm.fragmentedPayload = make([]byte, 0) -} - -// NewReliablePacketSubstreamManager initializes a new ReliablePacketSubstreamManager with a starting counter value. -func NewReliablePacketSubstreamManager(startingIncomingSequenceID, startingOutgoingSequenceID uint16) *ReliablePacketSubstreamManager { - psm := &ReliablePacketSubstreamManager{ - packetMap: NewMutexMap[uint16, PRUDPPacketInterface](), - incomingSequenceIDCounter: NewCounter[uint16](startingIncomingSequenceID), - outgoingSequenceIDCounter: NewCounter[uint16](startingOutgoingSequenceID), - ResendScheduler: NewResendScheduler(), - } - - psm.SetCipherKey([]byte("CD&ML")) - - return psm -}