Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(common) #62: Message routing headers #70

Merged
merged 8 commits into from
Jan 19, 2024
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
package com.izivia.ocpi.toolkit.common

import com.fasterxml.jackson.core.type.TypeReference
import com.izivia.ocpi.toolkit.common.Header.OCPI_FROM_COUNTRY_CODE
import com.izivia.ocpi.toolkit.common.Header.OCPI_FROM_PARTY_ID
import com.izivia.ocpi.toolkit.common.Header.OCPI_TO_COUNTRY_CODE
import com.izivia.ocpi.toolkit.common.Header.OCPI_TO_PARTY_ID
import com.izivia.ocpi.toolkit.common.context.*
import com.izivia.ocpi.toolkit.common.validation.validate
import com.izivia.ocpi.toolkit.common.validation.validateLength
import com.izivia.ocpi.toolkit.modules.credentials.repositories.PartnerRepository
import com.izivia.ocpi.toolkit.modules.versions.domain.ModuleID
import com.izivia.ocpi.toolkit.transport.TransportClient
Expand All @@ -18,12 +25,25 @@ object Header {
const val X_LIMIT = "X-Limit"
const val LINK = "Link"
const val CONTENT_TYPE = "Content-Type"
const val OCPI_TO_PARTY_ID = "OCPI-to-party-id"
const val OCPI_TO_COUNTRY_CODE = "OCPI-to-country-code"
const val OCPI_FROM_PARTY_ID = "OCPI-from-party-id"
const val OCPI_FROM_COUNTRY_CODE = "OCPI-from-country-code"
}

object ContentType {
const val APPLICATION_JSON = "application/json"
}

fun Map<String, String>.validateMessageRoutingHeaders() {
validate {
validateLength(OCPI_TO_PARTY_ID, getByNormalizedKey(OCPI_TO_PARTY_ID).orEmpty(), 3)
validateLength(OCPI_TO_COUNTRY_CODE, getByNormalizedKey(OCPI_TO_COUNTRY_CODE).orEmpty(), 2)
validateLength(OCPI_FROM_PARTY_ID, getByNormalizedKey(OCPI_FROM_PARTY_ID).orEmpty(), 3)
validateLength(OCPI_FROM_COUNTRY_CODE, getByNormalizedKey(OCPI_FROM_COUNTRY_CODE).orEmpty(), 2)
}
}

/**
* Parse body of a paginated request. The result will be stored in a SearchResult which contains all pagination
* information.
Expand Down Expand Up @@ -130,14 +150,59 @@ fun HttpRequest.authenticate(token: String): AuthenticatedHttpRequest =
* It adds Content-Type header as "application/json" if the body is not null.
*/
private fun HttpRequest.withContentTypeHeaderIfNeeded(): HttpRequest =
withHeaders(
headers = if (body != null) {
headers.plus(Header.CONTENT_TYPE to ContentType.APPLICATION_JSON)
} else {
headers
}
if (body != null) {
withHeaders(headers = headers.plus(Header.CONTENT_TYPE to ContentType.APPLICATION_JSON))
} else {
this
}

/**
* It adds message routing header if they are set in the current coroutine context
*/
private suspend fun HttpRequest.withRequestMessageRoutingHeadersIfPresent(): HttpRequest {
val requestMessageRoutingHeaders = currentRequestMessageRoutingHeadersOrNull()

return if (requestMessageRoutingHeaders != null) {
withHeaders(headers = headers.plus(requestMessageRoutingHeaders.httpHeaders()))
} else {
this
}
}

/**
* It builds MessageRoutingHeaders from the headers of the request.
*/
fun HttpRequest.messageRoutingHeaders(): RequestMessageRoutingHeaders =
RequestMessageRoutingHeaders(
toPartyId = headers.getByNormalizedKey(OCPI_TO_PARTY_ID),
toCountryCode = headers.getByNormalizedKey(OCPI_TO_COUNTRY_CODE),
fromPartyId = headers.getByNormalizedKey(OCPI_FROM_PARTY_ID),
fromCountryCode = headers.getByNormalizedKey(OCPI_FROM_COUNTRY_CODE)
)

/**
* It builds headers from a ResponseMessageRoutingHeaders
*/
private fun RequestMessageRoutingHeaders.httpHeaders(): Map<String, String> =
mapOf(
OCPI_TO_PARTY_ID to toPartyId,
OCPI_TO_COUNTRY_CODE to toCountryCode,
OCPI_FROM_PARTY_ID to fromPartyId,
OCPI_FROM_COUNTRY_CODE to fromCountryCode
)
.filter { it.value != null }
.mapValues { it.value!! }

fun ResponseMessageRoutingHeaders.httpHeaders(): Map<String, String> =
mapOf(
OCPI_TO_PARTY_ID to toPartyId,
OCPI_TO_COUNTRY_CODE to toCountryCode,
OCPI_FROM_PARTY_ID to fromPartyId,
OCPI_FROM_COUNTRY_CODE to fromCountryCode
)
.filter { it.value != null }
.mapValues { it.value!! }

/**
* For debugging issues, OCPI implementations are required to include unique IDs via HTTP headers in every
* request/response.
Expand All @@ -156,15 +221,17 @@ private fun HttpRequest.withContentTypeHeaderIfNeeded(): HttpRequest =
* Dev note: When the server does a request (not a response), it must keep the same X-Correlation-ID but generate a new
* X-Request-ID. So don't call this method in that case.
*/
fun HttpRequest.withRequiredHeaders(
suspend fun HttpRequest.withRequiredHeaders(
requestId: String,
correlationId: String
): HttpRequest =
withHeaders(
headers = headers
.plus(Header.X_REQUEST_ID to requestId)
.plus(Header.X_CORRELATION_ID to correlationId)
).withContentTypeHeaderIfNeeded()
)
.withContentTypeHeaderIfNeeded()
.withRequestMessageRoutingHeadersIfPresent()

/**
* For debugging issues, OCPI implementations are required to include unique IDs via HTTP headers in every
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.izivia.ocpi.toolkit.common

import com.fasterxml.jackson.core.JsonProcessingException
import com.izivia.ocpi.toolkit.common.context.currentResponseMessageRoutingHeadersOrNull
import com.izivia.ocpi.toolkit.common.validation.toReadableString
import com.izivia.ocpi.toolkit.transport.domain.HttpException
import com.izivia.ocpi.toolkit.transport.domain.HttpRequest
Expand Down Expand Up @@ -139,10 +140,11 @@ suspend fun <T> HttpRequest.httpResponse(fn: suspend () -> OcpiResponseBody<T>):
),
headers = getDebugHeaders()
.plus(Header.CONTENT_TYPE to ContentType.APPLICATION_JSON)
.plus(currentResponseMessageRoutingHeadersOrNull()?.httpHeaders().orEmpty())
).let {
if (isPaginated) {
it.copy(
headers = (ocpiResponseBody as OcpiResponseBody<SearchResult<*>>)
headers = it.headers + (ocpiResponseBody as OcpiResponseBody<SearchResult<*>>)
.getPaginatedHeaders(request = this)
)
} else {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package com.izivia.ocpi.toolkit.common.context

import kotlin.coroutines.AbstractCoroutineContextElement
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.coroutineContext

/**
* Contains context about the current MessageRoutingHeaders
*/
data class RequestMessageRoutingHeaders(
val toPartyId: String? = null,
val toCountryCode: String? = null,
val fromPartyId: String? = null,
val fromCountryCode: String? = null
) : AbstractCoroutineContextElement(RequestMessageRoutingHeaders) {
companion object Key : CoroutineContext.Key<RequestMessageRoutingHeaders>
}

/**
* Retrieves MessageRoutingHeaders in the current coroutine if it is found.
*/
suspend fun currentRequestMessageRoutingHeadersOrNull(): RequestMessageRoutingHeaders? =
coroutineContext[RequestMessageRoutingHeaders]

/**
* Retrieves MessageRoutingHeaders in the current coroutine, and throws IllegalStateException
* if it could not be found.
*/
suspend fun currentRequestMessageRoutingHeaders(): RequestMessageRoutingHeaders =
coroutineContext[RequestMessageRoutingHeaders]
?: throw IllegalStateException("No MessageRoutingHeaders object in current coroutine context")
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package com.izivia.ocpi.toolkit.common.context

import kotlin.coroutines.AbstractCoroutineContextElement
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.coroutineContext

/**
* Contains context about the current MessageRoutingHeaders
*/
data class ResponseMessageRoutingHeaders(
var toPartyId: String? = null,
var toCountryCode: String? = null,
var fromPartyId: String? = null,
var fromCountryCode: String? = null
) : AbstractCoroutineContextElement(ResponseMessageRoutingHeaders) {
companion object Key : CoroutineContext.Key<ResponseMessageRoutingHeaders> {
/**
* Creates [ResponseMessageRoutingHeaders] by inverting "from" and "to" headers of
* the [RequestMessageRoutingHeaders].
*/
fun invertFromRequest(requestMessageRoutingHeaders: RequestMessageRoutingHeaders) =
ResponseMessageRoutingHeaders(
toPartyId = requestMessageRoutingHeaders.fromPartyId,
toCountryCode = requestMessageRoutingHeaders.fromCountryCode,
fromPartyId = requestMessageRoutingHeaders.toPartyId,
fromCountryCode = requestMessageRoutingHeaders.toCountryCode
)
}
}

/**
* Retrieves MessageRoutingHeaders in the current coroutine if it is found.
*/
suspend fun currentResponseMessageRoutingHeadersOrNull(): ResponseMessageRoutingHeaders? =
coroutineContext[ResponseMessageRoutingHeaders]

/**
* Retrieves MessageRoutingHeaders in the current coroutine, and throws IllegalStateException
* if it could not be found.
*/
suspend fun currentResponseMessageRoutingHeaders(): ResponseMessageRoutingHeaders =
coroutineContext[ResponseMessageRoutingHeaders]
?: throw IllegalStateException("No MessageRoutingHeaders object in current coroutine context")
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
package com.izivia.ocpi.toolkit.samples.common

import com.izivia.ocpi.toolkit.common.OcpiException
import com.izivia.ocpi.toolkit.common.toHttpResponse
import com.izivia.ocpi.toolkit.common.*
import com.izivia.ocpi.toolkit.common.context.ResponseMessageRoutingHeaders
import com.izivia.ocpi.toolkit.common.validation.toReadableString
import com.izivia.ocpi.toolkit.transport.TransportServer
import com.izivia.ocpi.toolkit.transport.domain.*
import kotlinx.coroutines.runBlocking
import org.http4k.core.*
import org.http4k.filter.DebuggingFilters
import org.http4k.routing.RoutingHttpHandler
import org.http4k.routing.bind
import org.http4k.routing.path
import org.http4k.routing.routes
import org.http4k.routing.*
import org.http4k.server.Http4kServer
import org.http4k.server.Netty
import org.http4k.server.asServer
import org.valiktor.ConstraintViolationException

class Http4kTransportServer(
val baseUrl: String,
Expand All @@ -34,7 +33,7 @@ class Http4kTransportServer(
callback: suspend (request: HttpRequest) -> HttpResponse
) {
val pathParams = path
.filterIsInstance(VariablePathSegment::class.java)
.filterIsInstance<VariablePathSegment>()
.map { it.path }

val route = path.joinToString("/") { segment ->
Expand All @@ -58,12 +57,25 @@ class Http4kTransportServer(
.associate { (key, value) -> key to value!! },
body = req.bodyString()
)
.also { httpRequest ->
try {
httpRequest.headers.validateMessageRoutingHeaders()
} catch (e: ConstraintViolationException) {
throw OcpiClientInvalidParametersException(
message = "invalid message routing headers: " + e.toReadableString()
)
}
}
.also { httpRequest ->
runBlocking { secureFilter(httpRequest) }
}
.also { httpRequest -> filters.forEach { filter -> filter(httpRequest) } }
.let { httpRequest ->
httpRequest to runBlocking {
val requestMessageRoutingHeaders = httpRequest.messageRoutingHeaders()
val responseMessageRoutingHeaders = ResponseMessageRoutingHeaders
.invertFromRequest(requestMessageRoutingHeaders)

httpRequest to runBlocking(requestMessageRoutingHeaders + responseMessageRoutingHeaders) {
callback(httpRequest)
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package com.izivia.ocpi.toolkit.tests.integration

import com.izivia.ocpi.toolkit.common.Header
import com.izivia.ocpi.toolkit.common.OcpiStatus
import com.izivia.ocpi.toolkit.common.context.RequestMessageRoutingHeaders
import com.izivia.ocpi.toolkit.modules.locations.LocationsCpoServer
import com.izivia.ocpi.toolkit.modules.locations.LocationsEmspClient
import com.izivia.ocpi.toolkit.modules.locations.domain.Location
Expand All @@ -10,6 +12,7 @@ import com.izivia.ocpi.toolkit.modules.versions.repositories.InMemoryVersionsRep
import com.izivia.ocpi.toolkit.samples.common.*
import com.izivia.ocpi.toolkit.tests.integration.common.BaseServerIntegrationTest
import com.izivia.ocpi.toolkit.tests.integration.mock.LocationsCpoMongoRepository
import com.izivia.ocpi.toolkit.transport.domain.HttpMethod
import com.mongodb.client.MongoDatabase
import kotlinx.coroutines.runBlocking
import org.junit.jupiter.api.Test
Expand Down Expand Up @@ -83,8 +86,15 @@ class LocationsIntegrationTest : BaseServerIntegrationTest() {
var dateFrom: Instant? = null
var dateTo: Instant? = null

val requestMessageRoutingHeaders = RequestMessageRoutingHeaders(
toPartyId = "AAA",
toCountryCode = "AA",
fromPartyId = "BBB",
fromCountryCode = "BB"
)

expectThat(
runBlocking {
runBlocking(requestMessageRoutingHeaders) {
locationsEmspClient.getLocations(
dateFrom = dateFrom,
dateTo = dateTo,
Expand Down Expand Up @@ -125,6 +135,36 @@ class LocationsIntegrationTest : BaseServerIntegrationTest() {
}
}

expectThat(cpoServer.requestHistory)
.hasSize(1)[0]
.and {
get { first }.and {
// request
get { method }.isEqualTo(HttpMethod.GET)
get { path }.isEqualTo("/2.2.1/locations")
get { headers[Header.OCPI_FROM_PARTY_ID] }.isNotNull()
.isEqualTo(requestMessageRoutingHeaders.fromPartyId)
get { headers[Header.OCPI_FROM_COUNTRY_CODE] }.isNotNull()
.isEqualTo(requestMessageRoutingHeaders.fromCountryCode)
get { headers[Header.OCPI_TO_PARTY_ID] }.isNotNull()
.isEqualTo(requestMessageRoutingHeaders.toPartyId)
get { headers[Header.OCPI_TO_COUNTRY_CODE] }.isNotNull()
.isEqualTo(requestMessageRoutingHeaders.toCountryCode)
}

get { second }.and {
// response
get { headers[Header.OCPI_FROM_PARTY_ID] }.isNotNull()
.isEqualTo(requestMessageRoutingHeaders.toPartyId)
get { headers[Header.OCPI_FROM_COUNTRY_CODE] }.isNotNull()
.isEqualTo(requestMessageRoutingHeaders.toCountryCode)
get { headers[Header.OCPI_TO_PARTY_ID] }.isNotNull()
.isEqualTo(requestMessageRoutingHeaders.fromPartyId)
get { headers[Header.OCPI_TO_COUNTRY_CODE] }.isNotNull()
.isEqualTo(requestMessageRoutingHeaders.fromCountryCode)
}
}

limit = 100
offset = 100
dateFrom = null
Expand Down