Skip to content

Commit

Permalink
feat: get client IP from the request headers (#233)
Browse files Browse the repository at this point in the history
* feat: get client IP from the request headers

* Fix doc.

* More review comments.
  • Loading branch information
sbruens authored Jan 30, 2025
1 parent 13bdb23 commit b7d1c7a
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 5 deletions.
9 changes: 4 additions & 5 deletions cmd/outline-ss-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -346,14 +346,13 @@ func (s *OutlineServer) runConfig(config Config) (func() error, error) {
defer wsConn.Close()
ctx, contextCancel := context.WithCancel(context.Background())
defer contextCancel()
// TODO: Get the forwarded client address.
raddr, err := transport.MakeNetAddr("tcp", r.RemoteAddr)
clientIP, err := onet.GetClientIPFromRequest(r)
if err != nil {
slog.Error("failed to determine client address", "err", err)
w.WriteHeader(http.StatusBadGateway)
return
}
conn := &streamConn{&replaceAddrConn{Conn: wsConn, raddr: raddr}}
conn := &streamConn{&replaceAddrConn{Conn: wsConn, raddr: &net.TCPAddr{IP: clientIP}}}
streamHandler.HandleStream(ctx, conn, s.serviceMetrics.AddOpenTCPConnection(conn))
}
websocket.Handler(handler).ServeHTTP(w, r)
Expand All @@ -370,13 +369,13 @@ func (s *OutlineServer) runConfig(config Config) (func() error, error) {
defer wsConn.Close()
ctx, contextCancel := context.WithCancel(context.Background())
defer contextCancel()
raddr, err := transport.MakeNetAddr("udp", r.RemoteAddr)
clientIP, err := onet.GetClientIPFromRequest(r)
if err != nil {
slog.Error("failed to determine client address", "err", err)
w.WriteHeader(http.StatusBadGateway)
return
}
conn := &replaceAddrConn{Conn: wsConn, raddr: raddr}
conn := &replaceAddrConn{Conn: wsConn, raddr: &net.UDPAddr{IP: clientIP}}
associationHandler.HandleAssociation(ctx, conn, s.serviceMetrics.AddOpenUDPAssociation(conn))
}
websocket.Handler(handler).ServeHTTP(w, r)
Expand Down
72 changes: 72 additions & 0 deletions net/http.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// Copyright 2025 The Outline Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package net

import (
"errors"
"net"
"net/http"
"strings"
)

// GetClientIPFromRequest retrieves the client's IP address from the request.
// This checks common headers that forward the client IP, falling back to the
// request's `RemoteAddr`.
func GetClientIPFromRequest(r *http.Request) (net.IP, error) {
clientIP, err := func() (string, error) {
// `Forwarded` (RFC 7239).
forwardedHeader := r.Header.Get("Forwarded")
if forwardedHeader != "" {
parts := strings.Split(forwardedHeader, ",")
firstPart := strings.TrimSpace(parts[0])
subParts := strings.Split(firstPart, ";")
for _, part := range subParts {
normalisedPart := strings.ToLower(strings.TrimSpace(part))
if strings.HasPrefix(normalisedPart, "for=") {
return normalisedPart[4:], nil
}
}
}

// `X-Forwarded-For`` is potentially a list of addresses separated with ",".
// The first item represents the original client.
xForwardedForHeader := r.Header.Get("X-Forwarded-For")
if xForwardedForHeader != "" {
parts := strings.Split(xForwardedForHeader, ",")
firstIP := strings.TrimSpace(parts[0])
return firstIP, nil
}

// `X-Real-IP`.
xRealIpHeader := r.Header.Get("X-Real-IP")
if xRealIpHeader != "" {
return xRealIpHeader, nil
}

// Fallback to the request's `RemoteAddr`, but be aware this is the last
// proxy's IP, not the client's.
ip, _, err := net.SplitHostPort(r.RemoteAddr)
return ip, err
}()
if err != nil {
return nil, err
}

parsedIP := net.ParseIP(clientIP)
if parsedIP != nil {
return parsedIP, nil
}
return nil, errors.New("no client IP found")
}
101 changes: 101 additions & 0 deletions net/http_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
// Copyright 2025 The Outline Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package net

import (
"net"
"net/http"
"testing"

"github.com/stretchr/testify/require"
)

func TestGetClientIPFromRequest(t *testing.T) {
tests := []struct {
name string
headers map[string]string
remoteAddr string
wantIP string
wantErr bool
}{
{
name: "X-Forwarded-For (Single IP)",
headers: map[string]string{"X-Forwarded-For": "10.0.0.1"},
wantIP: "10.0.0.1",
},
{
name: "X-Forwarded-For (Multiple IPs)",
headers: map[string]string{"X-Forwarded-For": "10.0.0.1, 172.16.0.1"},
wantIP: "10.0.0.1",
},
{
name: "X-Real-IP",
headers: map[string]string{"X-Real-IP": "192.168.2.200"},
wantIP: "192.168.2.200",
},
{
name: "Forwarded",
headers: map[string]string{"Forwarded": "for=192.168.3.100"},
wantIP: "192.168.3.100",
},
{
name: "RemoteAddr (host:port)",
remoteAddr: "172.17.0.1:12345",
wantIP: "172.17.0.1",
},
{
name: "RemoteAddr (IP only)",
remoteAddr: "172.17.0.1",
wantErr: true,
},
{
name: "No Headers, No RemoteAddr",
wantErr: true,
},
{
name: "Invalid IP in header",
headers: map[string]string{"X-Forwarded-For": "invalid-ip"},
wantErr: true,
},
{
name: "Invalid RemoteAddr",
remoteAddr: "invalid-ip:port",
wantErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := &http.Request{
Header: make(http.Header),
RemoteAddr: tt.remoteAddr,
}
for h, v := range tt.headers {
r.Header.Set(h, v)
}

gotIP, err := GetClientIPFromRequest(r)
if !tt.wantErr {
require.NoError(t, err)
return
}

wantIP := net.ParseIP(tt.wantIP)
if !gotIP.Equal(wantIP) {
t.Errorf("err = %v, want %v", gotIP, wantIP)
}
})
}
}

0 comments on commit b7d1c7a

Please sign in to comment.