diff --git a/src/Microsoft.AspNetCore.OData/Batch/ODataBatchReaderExtensions.cs b/src/Microsoft.AspNetCore.OData/Batch/ODataBatchReaderExtensions.cs index 64ca375ddd..e928a33b81 100644 --- a/src/Microsoft.AspNetCore.OData/Batch/ODataBatchReaderExtensions.cs +++ b/src/Microsoft.AspNetCore.OData/Batch/ODataBatchReaderExtensions.cs @@ -14,7 +14,6 @@ using Microsoft.AspNet.OData.Interfaces; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; -using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Primitives; using Microsoft.OData; @@ -239,17 +238,9 @@ private static HttpContext CreateHttpContext(HttpContext originalContext) features[typeof(IHttpResponseFeature)] = new HttpResponseFeature(); - // Create a context from the factory or use the default context. - HttpContext context = null; - IHttpContextFactory httpContextFactory = originalContext.RequestServices.GetRequiredService(); - if (httpContextFactory != null) - { - context = httpContextFactory.Create(features); - } - else - { - context = new DefaultHttpContext(features); - } + // Create a context. + // IHttpContextFactory should not be used, because it resets IHttpContextAccessor.HttpContext; + HttpContext context = new DefaultHttpContext(features); // Clone parts of the request. All other parts of the request will be // populated during batch processing. diff --git a/test/UnitTest/Microsoft.AspNet.OData.Test.Shared/Batch/DefaultODataBatchHandlerTest.cs b/test/UnitTest/Microsoft.AspNet.OData.Test.Shared/Batch/DefaultODataBatchHandlerTest.cs index cc3c815b32..39194ddd02 100644 --- a/test/UnitTest/Microsoft.AspNet.OData.Test.Shared/Batch/DefaultODataBatchHandlerTest.cs +++ b/test/UnitTest/Microsoft.AspNet.OData.Test.Shared/Batch/DefaultODataBatchHandlerTest.cs @@ -12,11 +12,14 @@ using Microsoft.AspNet.OData.Extensions; using Microsoft.AspNet.OData.Test.Abstraction; using Microsoft.AspNet.OData.Test.Common; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; using Xunit; #if !NETCORE using System.Web.Http; using System.Web.Http.Routing; #else +using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Mvc; using Newtonsoft.Json; #endif @@ -709,6 +712,75 @@ public async Task SendAsync_CorrectlyCopiesHeadersToIndividualRequests( Assert.Contains(deleteRequest, responseContent); Assert.Contains(postRequest, responseContent); } + + [Fact] + public async Task ProcessBatchAsync_PreservesHttpContext() + { + var batchRef = $"batch_{Guid.NewGuid()}"; + var changesetRef = $"changeset_{Guid.NewGuid()}"; + var endpoint = "http://localhost"; + + Type[] controllers = new[] { typeof(BatchTestOrdersController), }; + var server = TestServerFactory.Create( + controllers, + config => + { + var builder = ODataConventionModelBuilderFactory.Create(config); + builder.EntitySet("BatchTestOrders"); + + config.MapODataServiceRoute("odata", null, builder.GetEdmModel(), new CustomODataBatchHandler()); + config.Expand(); + config.EnableDependencyInjection(); + }, + config => + { + config.TryAddSingleton(); + }); + + var client = TestServerFactory.CreateClient(server); + + var orderId = 2; + var createOrderPayload = $@"{{""@odata.type"":""Microsoft.AspNet.OData.Test.Batch.BatchTestOrder"",""Id"":{orderId},""Amount"":50}}"; + + var batchRequest = new HttpRequestMessage(HttpMethod.Post, $"{endpoint}/$batch"); + batchRequest.Headers.Accept.Add(MediaTypeWithQualityHeaderValue.Parse("text/plain")); + + var batchContent = $@" +--{batchRef} +Content-Type: multipart/mixed;boundary={changesetRef} + +--{changesetRef} +Content-Type: application/http +Content-Transfer-Encoding: binary +Content-ID: 1 + +POST {endpoint}/BatchTestOrders HTTP/1.1 +Content-Type: application/json;type=entry +Prefer: return=representation + +{createOrderPayload} +--{changesetRef}-- +--{batchRef} +Content-Type: application/http +Content-Transfer-Encoding: binary + +GET {endpoint}/BatchTestOrders({orderId}) HTTP/1.1 +Content-Type: application/json;type=entry +Prefer: return=representation + +--{batchRef}-- +"; + + var httpContent = new StringContent(batchContent); + httpContent.Headers.ContentType = MediaTypeHeaderValue.Parse($"multipart/mixed;boundary={batchRef}"); + httpContent.Headers.ContentLength = batchContent.Length; + batchRequest.Content = httpContent; + var response = await client.SendAsync(batchRequest); + + ExceptionAssert.DoesNotThrow(() => response.EnsureSuccessStatusCode()); + + // TODO: assert somehow? + } #endif } @@ -748,6 +820,13 @@ public class BatchTestOrder return new List { order01 }; }); + + [EnableQuery] + public SingleResult Get([FromODataUri]int key) + { + return SingleResult.Create(BatchTestOrder.Orders.Where(d => d.Id.Equals(key)).AsQueryable()); + } + public static IList Orders { get @@ -845,5 +924,22 @@ public class BatchTestHeadersCustomer { public int Id { get; set; } } + + public class CustomODataBatchHandler : DefaultODataBatchHandler + { + /// + public override async Task ProcessBatchAsync(HttpContext context, RequestDelegate nextHandler) + { + // Retrieve current httpcontext. + var httpContextAccessor = context.RequestServices.GetService(); + var beforeContext = httpContextAccessor?.HttpContext; + await base.ProcessBatchAsync(context, nextHandler); + var afterContext = httpContextAccessor?.HttpContext; + if (httpContextAccessor != null && beforeContext != afterContext) + { + throw new Exception($"{nameof(ProcessBatchAsync)} has lost HttpContext."); + } + } + } #endif }