From 9514d476f2fed7290497ec18fdd9bf8b46827ee6 Mon Sep 17 00:00:00 2001 From: Aaron Powell Date: Thu, 23 May 2024 16:06:46 +1000 Subject: [PATCH] Fixes #136 - filter location header properly --- .../IntegrationTests/CustomHeaderTests.cs | 40 ++++++++++++++++++- src/Teapot.Web/CustomHttpStatusCodeResult.cs | 11 ++++- 2 files changed, 48 insertions(+), 3 deletions(-) diff --git a/src/Teapot.Web.Tests/IntegrationTests/CustomHeaderTests.cs b/src/Teapot.Web.Tests/IntegrationTests/CustomHeaderTests.cs index 48b136b..e37b02b 100644 --- a/src/Teapot.Web.Tests/IntegrationTests/CustomHeaderTests.cs +++ b/src/Teapot.Web.Tests/IntegrationTests/CustomHeaderTests.cs @@ -1,14 +1,27 @@ using Microsoft.AspNetCore.Mvc.Testing; +using System.Net.Http; namespace Teapot.Web.Tests.IntegrationTests; public class CustomHeaderTests { + [OneTimeSetUp] + public void OneTimeSetUp() + { + _httpClient = new WebApplicationFactory().CreateDefaultClient(); + } + + [OneTimeTearDown] + public void OneTimeTearDown() + { + _httpClient.Dispose(); + } + + private HttpClient _httpClient = null!; [Test] public async Task CanSetCustomHeaders() { - HttpClient httpClient = new WebApplicationFactory().CreateDefaultClient(); string uri = "/200"; string headerName = "Foo"; string headerValue = "bar"; @@ -16,7 +29,7 @@ public async Task CanSetCustomHeaders() using HttpRequestMessage request = new(HttpMethod.Get, uri); request.Headers.Add($"{StatusExtensions.CUSTOM_RESPONSE_HEADER_PREFIX}{headerName}", headerValue); - using HttpResponseMessage response = await httpClient.SendAsync(request); + using HttpResponseMessage response = await _httpClient.SendAsync(request); System.Net.Http.Headers.HttpResponseHeaders headers = response.Headers; Assert.Multiple(() => @@ -28,4 +41,27 @@ public async Task CanSetCustomHeaders() Assert.That(values!.First(), Is.EqualTo(headerValue)); }); } + + [Test] + public async Task Redirects302ToCorrectLocation() + { + string uri = "/302"; + string headerName = "Location"; + string headerValue = "example.com"; + + using HttpRequestMessage request = new(HttpMethod.Get, uri); + request.Headers.Add($"{StatusExtensions.CUSTOM_RESPONSE_HEADER_PREFIX}{headerName}", headerValue); + + using var response = await _httpClient.SendAsync(request); + + var headers = response.Headers; + Assert.Multiple(() => + { + Assert.That(headers.Contains(headerName), Is.True); + Assert.That(headers.TryGetValues(headerName, out var values), Is.True); + Assert.That(values, Is.Not.Null); + Assert.That(values!.Count(), Is.EqualTo(1)); + Assert.That(values!.First(), Is.EqualTo(headerValue)); + }); + } } diff --git a/src/Teapot.Web/CustomHttpStatusCodeResult.cs b/src/Teapot.Web/CustomHttpStatusCodeResult.cs index e131f25..7259d36 100644 --- a/src/Teapot.Web/CustomHttpStatusCodeResult.cs +++ b/src/Teapot.Web/CustomHttpStatusCodeResult.cs @@ -4,6 +4,7 @@ using Microsoft.Net.Http.Headers; using System; using System.Collections.Generic; +using System.Linq; using System.Text.Json; using System.Threading.Tasks; using Teapot.Web.Models; @@ -19,6 +20,7 @@ public class CustomHttpStatusCodeResult( { private const int SLEEP_MIN = 0; private const int SLEEP_MAX = 5 * 60 * 1000; // 5 mins in milliseconds + private static readonly string[] onlySingleHeader = ["Location"]; private static readonly MediaTypeHeaderValue jsonMimeType = new("application/json"); @@ -51,7 +53,14 @@ public async Task ExecuteAsync(HttpContext context) foreach ((string header, StringValues values) in customResponseHeaders) { - context.Response.Headers.Append(header, values); + if (onlySingleHeader.Contains(header)) + { + context.Response.Headers[header] = values; + } + else + { + context.Response.Headers.Append(header, values); + } } if (metadata.ExcludeBody || suppressBody == true)