diff --git a/modules/l4postgres/matcher.go b/modules/l4postgres/matcher.go index 9bd7444..3a9f72d 100644 --- a/modules/l4postgres/matcher.go +++ b/modules/l4postgres/matcher.go @@ -12,7 +12,33 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package l4postgres allows the L4 multiplexing of Postgres connections +// Package l4postgres allows the L4 multiplexing of Postgres connections. +// Connections can be required to have SSL disabled. +// Non-SSL connections can also match on Message parameters. +// +// Example matcher configs: +// +// { +// "postgres": {} +// } +// +// { +// "postgres": { +// "user": { +// "*": ["public_db"], +// "alice": ["planets_db", "stars_db"] +// } +// } +// } +// +// { +// "postgres_client": ["psql", "TablePlus"] +// } +// +// { +// "postgres_ssl": { +// disabled: false +// } // // With thanks to docs and code published at these links: // ref: https://github.com/mholt/caddy-l4/blob/master/modules/l4ssh/matcher.go @@ -27,6 +53,7 @@ import ( "encoding/binary" "errors" "io" + "slices" "github.com/caddyserver/caddy/v2" "github.com/mholt/caddy-l4/layer4" @@ -34,52 +61,38 @@ import ( func init() { caddy.RegisterModule(MatchPostgres{}) + caddy.RegisterModule(MatchPostgresClient{}) + caddy.RegisterModule(MatchPostgresSSL{}) } const ( // Magic number to identify a SSLRequest message sslRequestCode = 80877103 - // byte size of the message length field - initMessageSizeLength = 4 ) -// Message provides readers for various types and -// updates the offset after each read -type message struct { - data []byte - offset uint32 -} - -func (b *message) ReadUint32() (r uint32) { - r = binary.BigEndian.Uint32(b.data[b.offset : b.offset+4]) - b.offset += 4 - return r -} +// NewMessageFromConn create a message from the Connection +func newMessageFromConn(cx *layer4.Connection) (*message, error) { + // Get bytes containing the message length + head := make([]byte, lengthFieldSize) + if _, err := io.ReadFull(cx, head); err != nil { + return nil, err + } -func (b *message) ReadString() (r string) { - end := b.offset - max := uint32(len(b.data)) - for ; end != max && b.data[end] != 0; end++ { + // Get actual message length + data := make([]byte, binary.BigEndian.Uint32(head)-lengthFieldSize) + if _, err := io.ReadFull(cx, data); err != nil { + return nil, err } - r = string(b.data[b.offset:end]) - b.offset = end + 1 - return r -} -// NewMessageFromBytes wraps the raw bytes of a message to enable processing -func newMessageFromBytes(b []byte) *message { - return &message{data: b} + return newMessageFromBytes(data), nil } -// StartupMessage contains the values parsed from the startup message -type startupMessage struct { - ProtocolVersion uint32 - Parameters map[string]string +// MatchPostgres is able to match Postgres connections +type MatchPostgres struct { + User map[string][]string + startup *startupMessage } -// MatchPostgres is able to match Postgres connections. -type MatchPostgres struct{} - // CaddyModule returns the Caddy module information. func (MatchPostgres) CaddyModule() caddy.ModuleInfo { return caddy.ModuleInfo{ @@ -88,46 +101,139 @@ func (MatchPostgres) CaddyModule() caddy.ModuleInfo { } } -// Match returns true if the connection looks like the Postgres protocol. +// Match returns true if the connection looks like the Postgres protocol, and +// can match `user` and `database` parameters func (m MatchPostgres) Match(cx *layer4.Connection) (bool, error) { - // Get bytes containing the message length - head := make([]byte, initMessageSizeLength) - if _, err := io.ReadFull(cx, head); err != nil { + b, err := newMessageFromConn(cx) + if err != nil { return false, err } - // Get actual message length - data := make([]byte, binary.BigEndian.Uint32(head)-initMessageSizeLength) - if _, err := io.ReadFull(cx, data); err != nil { - return false, err + m.startup = newStartupMessage(b) + hasConfig := len(m.User) > 0 + + // Finish if this is a SSLRequest and there are no other matchers + if m.startup.IsSSL() && !hasConfig { + return true, nil } - b := newMessageFromBytes(data) + // Check supported protocol + if !m.startup.IsSupported() { + return false, errors.New("pg protocol < 3.0 is not supported") + } - // Check if it is a SSLRequest - code := b.ReadUint32() - if code == sslRequestCode { + // Finish if no more matchers are configured + if !hasConfig { return true, nil } + // Is there a user to check? + user, ok := m.startup.Parameters["user"] + if !ok { + // Are there public databases to check? + if databases, ok := m.User["*"]; ok { + if db, ok := m.startup.Parameters["database"]; ok { + return slices.Contains(databases, db), nil + } + } + return false, nil + } + + databases, ok := m.User[user] + if !ok { + return false, nil + } + + // Are there databases to check? + if len(databases) > 0 { + if db, ok := m.startup.Parameters["database"]; ok { + return slices.Contains(databases, db), nil + } + } + + return true, nil +} + +// MatchPostgresClient is able to match Postgres connections that +// contain an `application_name` field +type MatchPostgresClient struct { + Client []string + startup *startupMessage +} + +// CaddyModule returns the Caddy module information. +func (MatchPostgresClient) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + ID: "layer4.matchers.postgres_client", + New: func() caddy.Module { return new(MatchPostgresClient) }, + } +} + +// Match returns true if the connection looks like the Postgres protocol and +// passes any `application_name` parameter matchers +func (m MatchPostgresClient) Match(cx *layer4.Connection) (bool, error) { + b, err := newMessageFromConn(cx) + if err != nil { + return false, err + } + + m.startup = newStartupMessage(b) + + // Reject if this is a SSLRequest as it has no params + if m.startup.IsSSL() { + return false, nil + } + // Check supported protocol - if majorVersion := code >> 16; majorVersion < 3 { + if !m.startup.IsSupported() { return false, errors.New("pg protocol < 3.0 is not supported") } - // Try parsing Postgres Params - startup := &startupMessage{ProtocolVersion: code, Parameters: make(map[string]string)} - for { - k := b.ReadString() - if k == "" { - break - } - startup.Parameters[k] = b.ReadString() + // Is there a application_name to check? + name, ok := m.startup.Parameters["application_name"] + if !ok { + return false, nil + } + + // Check clients list + return slices.Contains(m.Client, name), nil +} + +// MatchPostgresSSL is able to require/reject Postgres SSL connections. +type MatchPostgresSSL struct { + Disabled bool +} + +// CaddyModule returns the Caddy module information. +func (MatchPostgresSSL) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + ID: "layer4.matchers.postgres_ssl", + New: func() caddy.Module { return new(MatchPostgresSSL) }, } - // TODO(metafeather): match on param values: user, database, options, etc +} - return len(startup.Parameters) > 0, nil +// Match checks whether the connection is a Postgres SSL request. +func (m MatchPostgresSSL) Match(cx *layer4.Connection) (bool, error) { + b, err := newMessageFromConn(cx) + if err != nil { + return false, err + } + + code := b.ReadUint32() + disabled := !isSSLRequest(code) + + // Enforce SSL enabled + if !m.Disabled && !disabled { + return true, nil + } + // Enforce SSL disabled + if m.Disabled && disabled { + return true, nil + } + return false, nil } // Interface guard var _ layer4.ConnMatcher = (*MatchPostgres)(nil) +var _ layer4.ConnMatcher = (*MatchPostgresClient)(nil) +var _ layer4.ConnMatcher = (*MatchPostgresSSL)(nil) diff --git a/modules/l4postgres/matcher_test.go b/modules/l4postgres/matcher_test.go new file mode 100644 index 0000000..010d5d7 --- /dev/null +++ b/modules/l4postgres/matcher_test.go @@ -0,0 +1,248 @@ +package l4postgres + +import ( + "bytes" + "io" + "net" + "sync" + "testing" + + "github.com/mholt/caddy-l4/layer4" + "go.uber.org/zap" +) + +// MessageReader allows any Example to be used in tests as data +type MessageReader interface { + Read() ([]byte, error) +} + +// Example extends StartupMessage with utils to create messages +type example struct { + startupMessage +} + +// Read gets []byte from an Example struct as the raw protocol is similar to +// "user\u0000alice\u0000database\u0000stars_db" +func (x *example) Read() ([]byte, error) { + buf := new(bytes.Buffer) + _, err := buf.ReadFrom(x.Reader()) + return buf.Bytes(), err +} + +func assertNoError(t *testing.T, err error) { + t.Helper() + if err != nil { + t.Fatalf("Unexpected error: %s\n", err) + } +} + +func closePipe(wg *sync.WaitGroup, c1 net.Conn, c2 net.Conn) { + wg.Wait() + _ = c1.Close() + _ = c2.Close() +} + +func matchTester(t *testing.T, matcher layer4.ConnMatcher, data []byte) (bool, error) { + wg := &sync.WaitGroup{} + in, out := net.Pipe() + defer closePipe(wg, in, out) + + cx := layer4.WrapConnection(in, []byte{}, zap.NewNop()) + + wg.Add(1) + go func() { + defer wg.Done() + defer out.Close() + _, err := out.Write(data) + assertNoError(t, err) + }() + + matched, err := matcher.Match(cx) + + _, _ = io.Copy(io.Discard, in) + + return matched, err +} + +func Fatalf(t *testing.T, err error, matched bool, expect bool, explain string) { + t.Helper() + if matched != expect { + if err != nil { + t.Logf("Unexpected error: %s\n", err) + } + t.Fatalf("matcher did not match: returned %t != expected %t; %s", matched, expect, explain) + } +} + +func TestPostgres(t *testing.T) { + // ref: https://go.dev/wiki/TableDrivenTests + tests := []struct { + name string + matcher layer4.ConnMatcher + data MessageReader + expect bool + explain string + }{ + { + name: "rejects an empty StartupMessage", + matcher: MatchPostgres{}, + data: &example{ + startupMessage: startupMessage{}, + }, + expect: false, + explain: "an empty Postgres StartupMessage has no version to check", + }, + { + name: "allows any SSLRequest", + matcher: MatchPostgres{}, + data: &example{ + startupMessage: startupMessage{ + ProtocolVersion: sslRequestCode, + }, + }, + expect: true, + explain: "any Postgres SSLRequest should be accepted", + }, + { + name: "allows any StartupMessage with a supported ProtocolVersion", + matcher: MatchPostgres{}, + data: &example{ + startupMessage: startupMessage{ + ProtocolVersion: 196608, // v3.0 + }, + }, + expect: true, + explain: "any Postgres StartupMessage without parameters should be rejected", + }, + { + name: "allows any StartupMessage with parameters", + matcher: MatchPostgres{}, + data: &example{ + startupMessage: startupMessage{ + ProtocolVersion: 196608, // v3.0 + Parameters: map[string]string{ + "user": "alice", + "database": "stars_db", + }, + }, + }, + expect: true, + explain: "any Postgres StartupMessage with parameters should be accepted", + }, + } + for _, tc := range tests { + tc := tc // NOTE: /wiki/CommonMistakes#using-goroutines-on-loop-iterator-variables + t.Run(tc.name, func(t *testing.T) { + data, err := tc.data.Read() + assertNoError(t, err) + + matched, err := matchTester(t, tc.matcher, data) + Fatalf(t, err, matched, tc.expect, tc.explain) + }) + } +} + +func TestPostgresSSL(t *testing.T) { + // ref: https://go.dev/wiki/TableDrivenTests + tests := []struct { + name string + matcher layer4.ConnMatcher + data MessageReader + expect bool + explain string + }{ + { + name: "rejects an empty StartupMessage", + matcher: MatchPostgresSSL{}, + data: &example{ + startupMessage: startupMessage{}, + }, + expect: false, + explain: "an empty Postgres StartupMessage has no version to check", + }, + { + name: "implicitly requires SSL Requests", + matcher: MatchPostgresSSL{}, + data: &example{ + startupMessage: startupMessage{ + ProtocolVersion: sslRequestCode, + }, + }, + expect: true, + explain: "SSL is enabled", + }, + { + name: "explictly requires SSL Requests", + matcher: MatchPostgresSSL{ + Disabled: false, + }, + data: &example{ + startupMessage: startupMessage{ + ProtocolVersion: sslRequestCode, + }, + }, + expect: true, + explain: "SSL is enabled", + }, + { + name: "implicitly rejects non-SSL Requests", + matcher: MatchPostgresSSL{}, + data: &example{ + startupMessage: startupMessage{ + ProtocolVersion: 196608, + }, + }, + expect: false, + explain: "SSL is enabled", + }, + { + name: "explictly rejects non-SSL Requests", + matcher: MatchPostgresSSL{ + Disabled: false, + }, + data: &example{ + startupMessage: startupMessage{ + ProtocolVersion: 196608, + }, + }, + expect: false, + explain: "SSL is enabled", + }, + { + name: "explictly requires non-SSL Requests", + matcher: MatchPostgresSSL{ + Disabled: true, + }, + data: &example{ + startupMessage: startupMessage{ + ProtocolVersion: 196608, + }, + }, + expect: true, + explain: "SSL is disabled", + }, + { + name: "explictly rejects SSL Requests", + matcher: MatchPostgresSSL{ + Disabled: true, + }, + data: &example{ + startupMessage: startupMessage{ + ProtocolVersion: sslRequestCode, + }, + }, + expect: false, + explain: "SSL is disabled", + }, + } + for _, tc := range tests { + tc := tc // NOTE: /wiki/CommonMistakes#using-goroutines-on-loop-iterator-variables + t.Run(tc.name, func(t *testing.T) { + data, err := tc.data.Read() + assertNoError(t, err) + + matched, err := matchTester(t, tc.matcher, data) + Fatalf(t, err, matched, tc.expect, tc.explain) + }) + } +} diff --git a/modules/l4postgres/messages.go b/modules/l4postgres/messages.go new file mode 100644 index 0000000..d7663d5 --- /dev/null +++ b/modules/l4postgres/messages.go @@ -0,0 +1,142 @@ +// ref: https://github.com/rueian/pgbroker/blob/master/message/util.go +package l4postgres + +import ( + "bytes" + "encoding/binary" + "io" +) + +const ( + // byte size of the message length field + lengthFieldSize = 4 +) + +// Message provides readers for various types and +// updates the offset after each read/write operation. +type message struct { + data []byte + offset uint32 +} + +func (b *message) ReadUint32() (r uint32) { + r = binary.BigEndian.Uint32(b.data[b.offset : b.offset+4]) + b.offset += 4 + return r +} + +func (b *message) ReadString() (r string) { + end := b.offset + max := uint32(len(b.data)) + for ; end != max && b.data[end] != 0; end++ { + } + r = string(b.data[b.offset:end]) + b.offset = end + 1 + return r +} + +func (b *message) WriteByte(i byte) { + b.data[b.offset] = i + b.offset++ +} + +func (b *message) WriteByteN(i []byte) { + for _, s := range i { + b.WriteByte(s) + } +} + +func (b *message) WriteUint32(i uint32) { + binary.BigEndian.PutUint32(b.data[b.offset:b.offset+4], i) + b.offset += 4 +} + +func (b *message) WriteString(i string) { + b.WriteByteN([]byte(i)) + b.WriteByte(0) +} + +func (b *message) Length() int { + return len(b.data) +} + +func (b *message) Reader() io.Reader { + length := make([]byte, lengthFieldSize) + binary.BigEndian.PutUint32(length, uint32(b.Length()+lengthFieldSize)) + return io.MultiReader( + bytes.NewReader(length), + bytes.NewReader(b.data), + ) +} + +func newMessage(len int) *message { + return &message{data: make([]byte, len)} +} + +// NewMessageFromBytes wraps the raw bytes of a message to enable processing +func newMessageFromBytes(b []byte) *message { + return &message{data: b} +} + +// StartupMessage contains the values parsed from the first message received. +// This should be either a SSLRequest or StartupMessage +type startupMessage struct { + ProtocolVersion uint32 + Parameters map[string]string +} + +func (m *startupMessage) Reader() io.Reader { + length := lengthFieldSize + for k, v := range m.Parameters { + length += len(k) + 1 + length += len(v) + 1 + } + length += 1 + b := newMessage(length) + b.WriteUint32(m.ProtocolVersion) + for k, v := range m.Parameters { + b.WriteString(k) + b.WriteString(v) + } + b.WriteByte(0) + return b.Reader() +} + +// IsSSL confirms this is a SSLRequest +func (s startupMessage) IsSSL() bool { + return isSSLRequest(s.ProtocolVersion) +} + +// IsSupported confirms this is a supported version of Postgres +func (s startupMessage) IsSupported() bool { + return isSupported(s.ProtocolVersion) +} + +// NewStartupMessage creates a new startupMessage from the message bytes +func newStartupMessage(b *message) *startupMessage { + return &startupMessage{ + ProtocolVersion: b.ReadUint32(), + Parameters: parseParameters(b), + } +} + +func isSSLRequest(code uint32) bool { + return code == sslRequestCode +} + +func isSupported(code uint32) bool { + majorVersion := code >> 16 + return majorVersion >= 3 +} + +func parseParameters(b *message) map[string]string { + params := make(map[string]string) + for { + k := b.ReadString() + if k == "" { + break + } + params[k] = b.ReadString() + } + return params +} diff --git a/modules/l4postgres/messages_test.go b/modules/l4postgres/messages_test.go new file mode 100644 index 0000000..f150b42 --- /dev/null +++ b/modules/l4postgres/messages_test.go @@ -0,0 +1,18 @@ +package l4postgres + +import "testing" + +func TestIsSSLRequest(t *testing.T) { + if !isSSLRequest(80877103) { + t.Fatalf("magic SSL number is not recognised") + } +} + +func TestIsSupported(t *testing.T) { + if isSupported(1234) { + t.Fatalf("protocol version should require > v3.0") + } + if !isSupported(196608) { // v3.0 + t.Fatalf("protocol version should require > v3.0") + } +}