From b7d1c7a7f2ed6b4bb818ca6e72edac56f368e324 Mon Sep 17 00:00:00 2001 From: Sander Bruens Date: Wed, 29 Jan 2025 20:28:32 -0500 Subject: [PATCH] feat: get client IP from the request headers (#233) * feat: get client IP from the request headers * Fix doc. * More review comments. --- cmd/outline-ss-server/main.go | 9 ++- net/http.go | 72 ++++++++++++++++++++++++ net/http_test.go | 101 ++++++++++++++++++++++++++++++++++ 3 files changed, 177 insertions(+), 5 deletions(-) create mode 100644 net/http.go create mode 100644 net/http_test.go diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 143d0f2e..0c00fe47 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -346,14 +346,13 @@ func (s *OutlineServer) runConfig(config Config) (func() error, error) { defer wsConn.Close() ctx, contextCancel := context.WithCancel(context.Background()) defer contextCancel() - // TODO: Get the forwarded client address. - raddr, err := transport.MakeNetAddr("tcp", r.RemoteAddr) + clientIP, err := onet.GetClientIPFromRequest(r) if err != nil { slog.Error("failed to determine client address", "err", err) w.WriteHeader(http.StatusBadGateway) return } - conn := &streamConn{&replaceAddrConn{Conn: wsConn, raddr: raddr}} + conn := &streamConn{&replaceAddrConn{Conn: wsConn, raddr: &net.TCPAddr{IP: clientIP}}} streamHandler.HandleStream(ctx, conn, s.serviceMetrics.AddOpenTCPConnection(conn)) } websocket.Handler(handler).ServeHTTP(w, r) @@ -370,13 +369,13 @@ func (s *OutlineServer) runConfig(config Config) (func() error, error) { defer wsConn.Close() ctx, contextCancel := context.WithCancel(context.Background()) defer contextCancel() - raddr, err := transport.MakeNetAddr("udp", r.RemoteAddr) + clientIP, err := onet.GetClientIPFromRequest(r) if err != nil { slog.Error("failed to determine client address", "err", err) w.WriteHeader(http.StatusBadGateway) return } - conn := &replaceAddrConn{Conn: wsConn, raddr: raddr} + conn := &replaceAddrConn{Conn: wsConn, raddr: &net.UDPAddr{IP: clientIP}} associationHandler.HandleAssociation(ctx, conn, s.serviceMetrics.AddOpenUDPAssociation(conn)) } websocket.Handler(handler).ServeHTTP(w, r) diff --git a/net/http.go b/net/http.go new file mode 100644 index 00000000..a809e2ed --- /dev/null +++ b/net/http.go @@ -0,0 +1,72 @@ +// Copyright 2025 The Outline Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package net + +import ( + "errors" + "net" + "net/http" + "strings" +) + +// GetClientIPFromRequest retrieves the client's IP address from the request. +// This checks common headers that forward the client IP, falling back to the +// request's `RemoteAddr`. +func GetClientIPFromRequest(r *http.Request) (net.IP, error) { + clientIP, err := func() (string, error) { + // `Forwarded` (RFC 7239). + forwardedHeader := r.Header.Get("Forwarded") + if forwardedHeader != "" { + parts := strings.Split(forwardedHeader, ",") + firstPart := strings.TrimSpace(parts[0]) + subParts := strings.Split(firstPart, ";") + for _, part := range subParts { + normalisedPart := strings.ToLower(strings.TrimSpace(part)) + if strings.HasPrefix(normalisedPart, "for=") { + return normalisedPart[4:], nil + } + } + } + + // `X-Forwarded-For`` is potentially a list of addresses separated with ",". + // The first item represents the original client. + xForwardedForHeader := r.Header.Get("X-Forwarded-For") + if xForwardedForHeader != "" { + parts := strings.Split(xForwardedForHeader, ",") + firstIP := strings.TrimSpace(parts[0]) + return firstIP, nil + } + + // `X-Real-IP`. + xRealIpHeader := r.Header.Get("X-Real-IP") + if xRealIpHeader != "" { + return xRealIpHeader, nil + } + + // Fallback to the request's `RemoteAddr`, but be aware this is the last + // proxy's IP, not the client's. + ip, _, err := net.SplitHostPort(r.RemoteAddr) + return ip, err + }() + if err != nil { + return nil, err + } + + parsedIP := net.ParseIP(clientIP) + if parsedIP != nil { + return parsedIP, nil + } + return nil, errors.New("no client IP found") +} diff --git a/net/http_test.go b/net/http_test.go new file mode 100644 index 00000000..dfe1caa6 --- /dev/null +++ b/net/http_test.go @@ -0,0 +1,101 @@ +// Copyright 2025 The Outline Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package net + +import ( + "net" + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGetClientIPFromRequest(t *testing.T) { + tests := []struct { + name string + headers map[string]string + remoteAddr string + wantIP string + wantErr bool + }{ + { + name: "X-Forwarded-For (Single IP)", + headers: map[string]string{"X-Forwarded-For": "10.0.0.1"}, + wantIP: "10.0.0.1", + }, + { + name: "X-Forwarded-For (Multiple IPs)", + headers: map[string]string{"X-Forwarded-For": "10.0.0.1, 172.16.0.1"}, + wantIP: "10.0.0.1", + }, + { + name: "X-Real-IP", + headers: map[string]string{"X-Real-IP": "192.168.2.200"}, + wantIP: "192.168.2.200", + }, + { + name: "Forwarded", + headers: map[string]string{"Forwarded": "for=192.168.3.100"}, + wantIP: "192.168.3.100", + }, + { + name: "RemoteAddr (host:port)", + remoteAddr: "172.17.0.1:12345", + wantIP: "172.17.0.1", + }, + { + name: "RemoteAddr (IP only)", + remoteAddr: "172.17.0.1", + wantErr: true, + }, + { + name: "No Headers, No RemoteAddr", + wantErr: true, + }, + { + name: "Invalid IP in header", + headers: map[string]string{"X-Forwarded-For": "invalid-ip"}, + wantErr: true, + }, + { + name: "Invalid RemoteAddr", + remoteAddr: "invalid-ip:port", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &http.Request{ + Header: make(http.Header), + RemoteAddr: tt.remoteAddr, + } + for h, v := range tt.headers { + r.Header.Set(h, v) + } + + gotIP, err := GetClientIPFromRequest(r) + if !tt.wantErr { + require.NoError(t, err) + return + } + + wantIP := net.ParseIP(tt.wantIP) + if !gotIP.Equal(wantIP) { + t.Errorf("err = %v, want %v", gotIP, wantIP) + } + }) + } +}