From ce4ebe0aaa67ea059fc23915b51ec564e66f4e2b Mon Sep 17 00:00:00 2001 From: Oleksandr Savchuk Date: Tue, 14 Nov 2023 20:08:04 +0200 Subject: [PATCH] allow to use custom func for get ip from request --- tgb/webhook.go | 14 +++++++++++++- tgb/webhook_test.go | 5 +++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/tgb/webhook.go b/tgb/webhook.go index 613e1eb..3ae59de 100644 --- a/tgb/webhook.go +++ b/tgb/webhook.go @@ -33,6 +33,8 @@ type Webhook struct { securitySubnets []netip.Prefix securityToken string + ipFromRequestFunc func(r *http.Request) string + isSetup bool } @@ -66,6 +68,14 @@ func WithWebhookIP(ip string) WebhookOption { } } +// WithWebhookRequestIP sets function to get the IP address from the request. +// By default the IP address is resolved through the X-Real-Ip and X-Forwarded-For headers. +func WithWebhookRequestIP(ip func(r *http.Request) string) WebhookOption { + return func(webhook *Webhook) { + webhook.ipFromRequestFunc = ip + } +} + // WithWebhookSecuritySubnets sets list of subnets which are allowed to send webhook requests. func WithWebhookSecuritySubnets(subnets ...netip.Prefix) WebhookOption { return func(webhook *Webhook) { @@ -117,6 +127,8 @@ func NewWebhook(handler Handler, client *tg.Client, url string, options ...Webho allowedUpdates: []tg.UpdateType{}, securitySubnets: defaultSubnets, securityToken: token, + + ipFromRequestFunc: realip.FromRequest, } for _, option := range options { @@ -322,7 +334,7 @@ func (webhook *Webhook) ServeRequest(ctx context.Context, r *WebhookRequest) *We // ServeHTTP is the HTTP handler for webhook requests. // Implementation of http.Handler. func (webhook *Webhook) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ip, err := netip.ParseAddr(realip.FromRequest(r)) + ip, err := netip.ParseAddr(webhook.ipFromRequestFunc(r)) if err != nil { webhook.log("failed to parse ip: %s", err) http.Error(w, "failed to parse ip", http.StatusBadRequest) diff --git a/tgb/webhook_test.go b/tgb/webhook_test.go index 0b4b046..45d6c46 100644 --- a/tgb/webhook_test.go +++ b/tgb/webhook_test.go @@ -27,6 +27,7 @@ func TestNewWebhook(t *testing.T) { assert.Equal(t, "https://example.com/webhook", webhook.url) assert.NotNil(t, webhook.handler) + assert.NotNil(t, webhook.ipFromRequestFunc) assert.NotNil(t, webhook.securityToken) assert.Len(t, webhook.securitySubnets, 2) }) @@ -39,10 +40,14 @@ func TestNewWebhook(t *testing.T) { WithWebhookSecuritySubnets(netip.MustParsePrefix("1.1.1.1/24")), WithWebhookSecurityToken("12345"), WithWebhookMaxConnections(10), + WithWebhookRequestIP(func(r *http.Request) string { + return "" + }), ) assert.Equal(t, "https://example.com/webhook", webhook.url) assert.NotNil(t, webhook.handler) + assert.NotNil(t, webhook.ipFromRequestFunc) assert.Equal(t, "12345", webhook.securityToken) assert.Len(t, webhook.securitySubnets, 1) assert.Equal(t, 10, webhook.maxConnections)