diff --git a/hilla-test-extension/deployment/pom.xml b/hilla-test-extension/deployment/pom.xml index b0ff9355..241f81d5 100644 --- a/hilla-test-extension/deployment/pom.xml +++ b/hilla-test-extension/deployment/pom.xml @@ -43,6 +43,32 @@ quarkus-junit5-internal test + + io.rest-assured + rest-assured + test + + + io.quarkus + quarkus-hibernate-validator + test + + + io.quarkus + quarkus-security-deployment + test + + + io.quarkus + quarkus-security-test-utils + test + ${quarkus.version} + + + org.assertj + assertj-core + 3.24.2 + diff --git a/hilla-test-extension/deployment/src/main/java/org/acme/hilla/test/extension/deployment/HillaTestExtensionProcessor.java b/hilla-test-extension/deployment/src/main/java/org/acme/hilla/test/extension/deployment/HillaTestExtensionProcessor.java index 51c00097..17334167 100644 --- a/hilla-test-extension/deployment/src/main/java/org/acme/hilla/test/extension/deployment/HillaTestExtensionProcessor.java +++ b/hilla-test-extension/deployment/src/main/java/org/acme/hilla/test/extension/deployment/HillaTestExtensionProcessor.java @@ -222,13 +222,17 @@ void registerHillaSecurityPolicy(HttpBuildTimeConfig buildTimeConfig, @BuildStep @Record(ExecutionTime.RUNTIME_INIT) void registerHillaFormAuthenticationMechanism( + HttpBuildTimeConfig httpBuildTimeConfig, HillaSecurityRecorder recorder, BuildProducer producer) { - producer.produce(SyntheticBeanBuildItem - .configure(HillaFormAuthenticationMechanism.class) - .types(HttpAuthenticationMechanism.class).setRuntimeInit() - .scope(Singleton.class).alternativePriority(1) - .supplier(recorder.setupFormAuthenticationMechanism()).done()); + if (httpBuildTimeConfig.auth.form.enabled) { + producer.produce(SyntheticBeanBuildItem + .configure(HillaFormAuthenticationMechanism.class) + .types(HttpAuthenticationMechanism.class).setRuntimeInit() + .scope(Singleton.class).alternativePriority(1) + .supplier(recorder.setupFormAuthenticationMechanism()) + .done()); + } } @BuildStep diff --git a/hilla-test-extension/deployment/src/test/java/org/acme/hilla/test/extension/deployment/EndpointControllerTest.java b/hilla-test-extension/deployment/src/test/java/org/acme/hilla/test/extension/deployment/EndpointControllerTest.java new file mode 100644 index 00000000..e82ce2f4 --- /dev/null +++ b/hilla-test-extension/deployment/src/test/java/org/acme/hilla/test/extension/deployment/EndpointControllerTest.java @@ -0,0 +1,111 @@ +package org.acme.hilla.test.extension.deployment; + +import dev.hilla.exception.EndpointValidationException; +import io.quarkus.test.QuarkusUnitTest; +import io.restassured.RestAssured; +import io.restassured.http.ContentType; +import org.acme.hilla.test.extension.deployment.TestUtils.Parameters; +import org.acme.hilla.test.extension.deployment.endpoints.TestEndpoint; +import org.hamcrest.CoreMatchers; +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import static org.acme.hilla.test.extension.deployment.TestUtils.givenEndpointRequest; +import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.CoreMatchers.equalTo; + +class EndpointControllerTest { + + private static final String ENDPOINT_NAME = TestEndpoint.class + .getSimpleName(); + + @RegisterExtension + static final QuarkusUnitTest config = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) + .addClasses(TestUtils.class, TestEndpoint.class)); + + @Test + void invokeEndpoint_singleSimpleParameter() { + String msg = "A text message"; + givenEndpointRequest(ENDPOINT_NAME, "echo", + Parameters.param("message", msg)).then().assertThat() + .statusCode(200).and().body(equalTo("\"" + msg + "\"")); + } + + @Test + void invokeEndpoint_singleComplexParameter() { + String msg = "A text message"; + TestEndpoint.Pojo pojo = new TestEndpoint.Pojo(10, msg); + givenEndpointRequest(ENDPOINT_NAME, "pojo", + Parameters.param("pojo", pojo)).then().assertThat() + .statusCode(200).and().body("number", equalTo(100)).and() + .body("text", equalTo(msg + msg)); + } + + @Test + void invokeEndpoint_multipleParameters() { + givenEndpointRequest(ENDPOINT_NAME, "calculate", + Parameters.param("operator", "+").add("a", 10).add("b", 20)) + .then().assertThat().statusCode(200).and().body(equalTo("30")); + } + + @Test + void invokeEndpoint_wrongParametersOrder_badRequest() { + givenEndpointRequest(ENDPOINT_NAME, "calculate", + Parameters.param("a", 10).add("operator", "+").add("b", 20)) + .then().assertThat().statusCode(400).and() + .body("type", equalTo(EndpointValidationException.class.getName())) + .and() + .body("message", + CoreMatchers.allOf(containsString("Validation error"), + containsString("'TestEndpoint'"), + containsString("'calculate'"))) + .body("validationErrorData[0].parameterName", + equalTo("operator")); + } + + @Test + void invokeEndpoint_wrongNumberOfParameters_badRequest() { + givenEndpointRequest(ENDPOINT_NAME, "calculate", + Parameters.param("operator", "+")).then().assertThat() + .statusCode(400).and().body("message", + CoreMatchers.allOf( + containsString( + "Incorrect number of parameters"), + containsString("'TestEndpoint'"), + containsString("'calculate'"), + containsString("expected: 3, got: 1"))); + } + + @Test + void invokeEndpoint_wrongEndpointName_notFound() { + givenEndpointRequest("NotExistingTestEndpoint", "calculate", + Parameters.param("operator", "+")).then().assertThat() + .statusCode(404); + } + + @Test + void invokeEndpoint_wrongMethodName_notFound() { + givenEndpointRequest(ENDPOINT_NAME, "notExistingMethod", + Parameters.param("operator", "+")).then().assertThat() + .statusCode(404); + } + + @Test + void invokeEndpoint_emptyMethodName_notFound() { + givenEndpointRequest(ENDPOINT_NAME, "", + Parameters.param("operator", "+")).then().assertThat() + .statusCode(404); + } + + @Test + void invokeEndpoint_missingMethodName_notFound() { + RestAssured.given().contentType(ContentType.JSON) + .cookie("csrfToken", "CSRF_TOKEN") + .header("X-CSRF-Token", "CSRF_TOKEN").basePath("/connect") + .when().post(ENDPOINT_NAME).then().assertThat().statusCode(404); + } + +} diff --git a/hilla-test-extension/deployment/src/test/java/org/acme/hilla/test/extension/deployment/EndpointSecurityTest.java b/hilla-test-extension/deployment/src/test/java/org/acme/hilla/test/extension/deployment/EndpointSecurityTest.java new file mode 100644 index 00000000..e4b7d57c --- /dev/null +++ b/hilla-test-extension/deployment/src/test/java/org/acme/hilla/test/extension/deployment/EndpointSecurityTest.java @@ -0,0 +1,154 @@ +package org.acme.hilla.test.extension.deployment; + +import java.util.function.UnaryOperator; +import java.util.stream.Stream; + +import io.quarkus.security.test.utils.TestIdentityController; +import io.quarkus.security.test.utils.TestIdentityProvider; +import io.quarkus.test.QuarkusUnitTest; +import io.restassured.specification.RequestSpecification; +import org.acme.hilla.test.extension.deployment.TestUtils.User; +import org.acme.hilla.test.extension.deployment.endpoints.SecureEndpoint; +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.asset.StringAsset; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import static org.acme.hilla.test.extension.deployment.TestUtils.ADMIN; +import static org.acme.hilla.test.extension.deployment.TestUtils.GUEST; +import static org.acme.hilla.test.extension.deployment.TestUtils.USER; +import static org.acme.hilla.test.extension.deployment.TestUtils.givenEndpointRequest; +import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.CoreMatchers.equalTo; + +class EndpointSecurityTest { + + @RegisterExtension + static final QuarkusUnitTest config = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) + .addClasses(TestIdentityProvider.class, + TestIdentityController.class, TestUtils.class, + SecureEndpoint.class) + .addAsResource(new StringAsset( + "quarkus.http.auth.basic=true\nquarkus.http.auth.proactive=true\n"), + "application.properties")); + public static final String SECURE_ENDPOINT = "SecureEndpoint"; + + @BeforeAll + public static void setupUsers() { + TestIdentityController.resetRoles() + .add(ADMIN.username, ADMIN.pwd, "ADMIN") + .add(USER.username, USER.pwd, "USER") + .add(GUEST.username, GUEST.pwd, "GUEST"); + } + + @Test + void securedEndpoint_permitAll_authenticatedUsersAllowed() { + Stream.of(USER, GUEST) + .forEach(user -> givenEndpointRequest(SECURE_ENDPOINT, + "authenticated", authenticate(user)).then().assertThat() + .statusCode(200).and() + .body(equalTo("\"AUTHENTICATED\""))); + + givenEndpointRequest(SECURE_ENDPOINT, "authenticated").then() + .assertThat().statusCode(401).and() + .body("message", containsString(SECURE_ENDPOINT)) + .body("message", containsString("reason: 'Access denied'")); + } + + @Test + void securedEndpoint_adminOnly_onlyAdminAllowed() { + givenEndpointRequest(SECURE_ENDPOINT, "adminOnly", authenticate(ADMIN)) + .then().assertThat().statusCode(200).and() + .body(equalTo("\"ADMIN\"")); + + Stream.of(USER, GUEST) + .forEach(user -> givenEndpointRequest(SECURE_ENDPOINT, + "adminOnly", authenticate(user)).then().assertThat() + .statusCode(401).and() + .body("message", containsString(SECURE_ENDPOINT)) + .body("message", + containsString("reason: 'Access denied'"))); + + givenEndpointRequest(SECURE_ENDPOINT, "adminOnly").then().assertThat() + .statusCode(401).and() + .body("message", containsString(SECURE_ENDPOINT)) + .body("message", containsString("reason: 'Access denied'")); + } + + @Test + void securedEndpoint_userOnly_onlyUserAllowed() { + givenEndpointRequest(SECURE_ENDPOINT, "userOnly", authenticate(USER)) + .then().assertThat().statusCode(200).and() + .body(equalTo("\"USER\"")); + + Stream.of(ADMIN, GUEST) + .forEach(user -> givenEndpointRequest(SECURE_ENDPOINT, + "userOnly", authenticate(user)).then().assertThat() + .statusCode(401).and() + .body("message", containsString(SECURE_ENDPOINT)) + .body("message", + containsString("reason: 'Access denied'"))); + + givenEndpointRequest(SECURE_ENDPOINT, "userOnly").then().assertThat() + .statusCode(401).and() + .body("message", containsString(SECURE_ENDPOINT)) + .body("message", containsString("reason: 'Access denied'")); + } + + @Test + void securedEndpoint_adminAndUserOnly_onlyAdminAndUserAllowed() { + Stream.of(ADMIN, USER) + .forEach(user -> givenEndpointRequest(SECURE_ENDPOINT, + "userAndAdmin", authenticate(user)).then().assertThat() + .statusCode(200).and() + .body(equalTo("\"USER AND ADMIN\""))); + + givenEndpointRequest(SECURE_ENDPOINT, "userAndAdmin", + authenticate(GUEST)).then().assertThat().statusCode(401).and() + .body("message", containsString(SECURE_ENDPOINT)) + .body("message", containsString("reason: 'Access denied'")); + + givenEndpointRequest(SECURE_ENDPOINT, "userAndAdmin").then() + .assertThat().statusCode(401).and() + .body("message", containsString(SECURE_ENDPOINT)) + .body("message", containsString("reason: 'Access denied'")); + } + + @Test + void securedEndpoint_deny_notAllowed() { + Stream.of(ADMIN, USER, GUEST) + .forEach(user -> givenEndpointRequest(SECURE_ENDPOINT, "deny", + authenticate(user)).then().assertThat().statusCode(401) + .and().body("message", containsString(SECURE_ENDPOINT)) + .body("message", + containsString("reason: 'Access denied'"))); + + givenEndpointRequest(SECURE_ENDPOINT, "deny").then().assertThat() + .statusCode(401).and() + .body("message", containsString(SECURE_ENDPOINT)) + .body("message", containsString("reason: 'Access denied'")); + } + + @Test + void securedEndpoint_notAnnotatedMethod_denyAll() { + Stream.of(ADMIN, USER, GUEST) + .forEach(user -> givenEndpointRequest(SECURE_ENDPOINT, + "denyByDefault", authenticate(user)).then().assertThat() + .statusCode(401).and() + .body("message", containsString(SECURE_ENDPOINT)) + .body("message", + containsString("reason: 'Access denied'"))); + + givenEndpointRequest(SECURE_ENDPOINT, "denyByDefault").then() + .assertThat().statusCode(401).and() + .body("message", containsString(SECURE_ENDPOINT)) + .body("message", containsString("reason: 'Access denied'")); + } + + private static UnaryOperator authenticate(User user) { + return spec -> spec.auth().preemptive().basic(user.username, user.pwd); + } +} diff --git a/hilla-test-extension/deployment/src/test/java/org/acme/hilla/test/extension/deployment/HillaPushClient.java b/hilla-test-extension/deployment/src/test/java/org/acme/hilla/test/extension/deployment/HillaPushClient.java new file mode 100644 index 00000000..e8fea1b0 --- /dev/null +++ b/hilla-test-extension/deployment/src/test/java/org/acme/hilla/test/extension/deployment/HillaPushClient.java @@ -0,0 +1,182 @@ +package org.acme.hilla.test.extension.deployment; + +import javax.websocket.CloseReason; +import javax.websocket.Endpoint; +import javax.websocket.EndpointConfig; +import javax.websocket.MessageHandler; +import javax.websocket.Session; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.net.URI; +import java.net.URLEncoder; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.UUID; +import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Consumer; +import java.util.logging.LogManager; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ObjectNode; +import org.assertj.core.api.AbstractStringAssert; +import org.assertj.core.api.ObjectAssert; +import org.assertj.core.api.StringAssert; +import org.junit.jupiter.api.Assertions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.assertj.core.api.Assertions.assertThat; + +public class HillaPushClient extends Endpoint + implements MessageHandler.Whole { + + private static final Logger LOGGER = LoggerFactory + .getLogger(HillaPushClient.class); + private static final AtomicInteger CLIENT_ID_GEN = new AtomicInteger(); + + final LinkedBlockingDeque messages = new LinkedBlockingDeque<>(); + + final String id; + private final String endpointName; + private final String methodName; + private final List parameters = new ArrayList<>(); + private final ObjectMapper objectMapper = new ObjectMapper(); + + private Session session; + + public HillaPushClient(String endpointName, String methodName, + Object... parameters) { + this.id = Integer.toString(CLIENT_ID_GEN.getAndIncrement()); + this.endpointName = endpointName; + this.methodName = methodName; + this.parameters.addAll(Arrays.asList(parameters)); + } + + @Override + public void onOpen(Session session, EndpointConfig config) { + LOGGER.trace("Client {} connected", id); + this.session = session; + messages.add("CONNECT"); + session.addMessageHandler(this); + session.getAsyncRemote().sendText(createSubscribeMessage()); + } + + @Override + public void onClose(Session session, CloseReason closeReason) { + LOGGER.trace("Session closed for client {} with reason {}", id, + closeReason); + messages.add("CLOSED: " + closeReason.toString()); + session.removeMessageHandler(this); + this.session = null; + } + + @Override + public void onError(Session session, Throwable throwable) { + LOGGER.trace("Got error for client {}", id, throwable); + messages.add("ERROR: " + throwable.getMessage()); + } + + public void onMessage(String msg) { + if (msg != null && !msg.isBlank()) { + LOGGER.trace("Message received for client {} :: {}", id, msg); + messages.add(msg); + } else { + LOGGER.trace("Ignored empty message for client {} :: {}", id, msg); + } + } + + public void cancel() { + LOGGER.trace("Canceling subscription for client {}", id); + if (session != null) { + try { + session.getBasicRemote().sendText(createUnsubscribeMessage()); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } else { + throw new IllegalStateException("Not connected"); + } + } + + public void subscribe() { + LOGGER.trace("Subscribing client {} :: {}/{} ({})", id, endpointName, + methodName, parameters); + if (session != null) { + try { + session.getBasicRemote().sendText(createSubscribeMessage()); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } else { + throw new IllegalStateException("Not connected"); + } + } + + public String pollMessage(long timeout, TimeUnit unit) + throws InterruptedException { + String message = messages.poll(timeout, unit); + if (message != null) { + // remove atmosphere internal identifier, to get only the + // application message + message = message.replaceFirst("\\d+\\|", ""); + } + return message; + } + + public void assertMessageReceived(long timeout, TimeUnit unit, + String expected) throws InterruptedException { + String msg = pollMessage(timeout, unit); + Assertions.assertEquals(expected, msg); + } + + public void assertMessageReceived(long timeout, TimeUnit unit, + Consumer> consumer) + throws InterruptedException { + String msg = pollMessage(timeout, unit); + AbstractStringAssert stringAssert = assertThat(msg).isNotNull(); + consumer.accept(stringAssert); + } + + private String createSubscribeMessage() { + LinkedHashMap params = new LinkedHashMap<>(); + params.put("@type", "subscribe"); + params.put("id", id); + params.put("endpointName", endpointName); + params.put("methodName", methodName); + params.put("params", this.parameters); + try { + return objectMapper.writeValueAsString(params); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + private String createUnsubscribeMessage() { + LinkedHashMap params = new LinkedHashMap<>(); + params.put("@type", "unsubscribe"); + params.put("id", id); + try { + return objectMapper.writeValueAsString(params); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + static URI createPUSHConnectURI(URI baseURI) { + String contentType = URLEncoder + .encode("application/json; charset=UTF-8", UTF_8); + return URI.create(baseURI.toASCIIString() + // + "?X-Atmosphere-tracking-id=" + UUID.randomUUID() // + + "&X-Atmosphere-Transport=websocket" // + + "&X-Atmosphere-TrackMessageSize=true" // + + "&Content-Type=" + contentType // + + "&X-atmo-protocol=true&X-CSRF-Token=" + UUID.randomUUID()); + } + +} diff --git a/hilla-test-extension/deployment/src/test/java/org/acme/hilla/test/extension/deployment/ReactiveEndpointTest.java b/hilla-test-extension/deployment/src/test/java/org/acme/hilla/test/extension/deployment/ReactiveEndpointTest.java new file mode 100644 index 00000000..6ba721a5 --- /dev/null +++ b/hilla-test-extension/deployment/src/test/java/org/acme/hilla/test/extension/deployment/ReactiveEndpointTest.java @@ -0,0 +1,155 @@ +package org.acme.hilla.test.extension.deployment; + +import javax.enterprise.context.control.ActivateRequestContext; +import javax.websocket.ContainerProvider; +import javax.websocket.Session; +import java.net.URI; +import java.util.LinkedHashMap; +import java.util.UUID; +import java.util.concurrent.TimeUnit; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.restassured.RestAssured; +import io.restassured.http.ContentType; +import org.acme.hilla.test.extension.deployment.endpoints.ReactiveEndpoint; +import org.hamcrest.Matchers; +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.asset.StringAsset; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +class ReactiveEndpointTest { + private static final String ENDPOINT_NAME = ReactiveEndpoint.class + .getSimpleName(); + + @RegisterExtension + static final QuarkusUnitTest config = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) + .addClasses(ReactiveEndpoint.class, HillaPushClient.class) + .add(new StringAsset( + "com.vaadin.experimental.hillaPush=true"), + "vaadin-featureflags.properties")); + + @TestHTTPResource("/HILLA/push") + URI uri; + + @Test + @ActivateRequestContext + void reactiveEndpoint_messagesPushedToTheClient() throws Exception { + URI connectURI = HillaPushClient.createPUSHConnectURI(uri); + String counterName = UUID.randomUUID().toString(); + HillaPushClient client = new HillaPushClient(ENDPOINT_NAME, "count", + counterName); + try (Session ignored = ContainerProvider.getWebSocketContainer() + .connectToServer(client, null, connectURI)) { + assertThatClientIsConnected(client); + for (int i = 1; i < 10; i++) { + assertThatPushUpdateHasBeenReceived(client, i); + } + } + assertThatConnectionHasBeenClosed(client); + + assertCounterValue(counterName, 9); + } + + @Test + @ActivateRequestContext + void cancelableReactiveEndpoint_clientCancel_serverUnsubscribeCallBackInvoked() + throws Exception { + URI connectURI = HillaPushClient.createPUSHConnectURI(uri); + String counterName = UUID.randomUUID().toString(); + HillaPushClient client = new HillaPushClient(ENDPOINT_NAME, + "cancelableCount", counterName); + try (Session ignored = ContainerProvider.getWebSocketContainer() + .connectToServer(client, null, connectURI)) { + assertThatClientIsConnected(client); + for (int i = 1; i < 10; i++) { + assertThatPushUpdateHasBeenReceived(client, i); + } + client.cancel(); + assertCounterValue(counterName, -1); + } + assertThatConnectionHasBeenClosed(client); + } + + @Test + @ActivateRequestContext + void cancelableReactiveEndpoint_clientDisconnectWithoutCancel_serverUnsubscribeCallBackInvoked() + throws Exception { + URI connectURI = HillaPushClient.createPUSHConnectURI(uri); + String counterName = UUID.randomUUID().toString(); + HillaPushClient client = new HillaPushClient(ENDPOINT_NAME, + "cancelableCount", counterName); + try (Session ignored = ContainerProvider.getWebSocketContainer() + .connectToServer(client, null, connectURI)) { + assertThatClientIsConnected(client); + for (int i = 1; i < 10; i++) { + assertThatPushUpdateHasBeenReceived(client, i); + } + } + assertThatConnectionHasBeenClosed(client); + + assertCounterValue(counterName, -1); + } + + @Test + @ActivateRequestContext + void cancelableReactiveEndpoint_subscribeAfterCancel_connectionNotClosedAndMessagesPushed() + throws Exception { + URI connectURI = HillaPushClient.createPUSHConnectURI(uri); + String counterName = UUID.randomUUID().toString(); + HillaPushClient client = new HillaPushClient(ENDPOINT_NAME, + "cancelableCount", counterName); + try (Session ignored = ContainerProvider.getWebSocketContainer() + .connectToServer(client, null, connectURI)) { + assertThatClientIsConnected(client); + for (int i = 1; i < 5; i++) { + assertThatPushUpdateHasBeenReceived(client, i); + } + client.cancel(); + assertCounterValue(counterName, -1); + + client.subscribe(); + for (int i = 0; i < 3; i++) { + assertThatPushUpdateHasBeenReceived(client, i); + } + client.cancel(); + assertCounterValue(counterName, -1); + + } + assertThatConnectionHasBeenClosed(client); + } + + private static void assertThatClientIsConnected(HillaPushClient client) + throws InterruptedException { + client.assertMessageReceived(10, TimeUnit.SECONDS, "CONNECT"); + } + + private static void assertThatConnectionHasBeenClosed( + HillaPushClient client) throws InterruptedException { + client.assertMessageReceived(1, TimeUnit.SECONDS, + message -> message.isNotNull().startsWith("CLOSED: ")); + } + + private static void assertThatPushUpdateHasBeenReceived( + HillaPushClient client, int i) throws InterruptedException { + client.assertMessageReceived(1, TimeUnit.SECONDS, + message -> message.as("Message %d", i).isEqualTo( + "{\"@type\":\"update\",\"id\":\"%s\",\"item\":%s}", + client.id, i)); + } + + private static void assertCounterValue(String counterName, int expected) { + LinkedHashMap orderedParams = new LinkedHashMap<>(); + orderedParams.put("counterName", counterName); + RestAssured.given().contentType(ContentType.JSON) + .cookie("csrfToken", "CSRF_TOKEN") + .header("X-CSRF-Token", "CSRF_TOKEN").body(orderedParams) + .basePath("/connect").when() + .post("/{endpointName}/counterValue", ENDPOINT_NAME).then() + .body(Matchers.equalTo(Integer.toString(expected))); + } + +} diff --git a/hilla-test-extension/deployment/src/test/java/org/acme/hilla/test/extension/deployment/ReactiveSecureEndpointTest.java b/hilla-test-extension/deployment/src/test/java/org/acme/hilla/test/extension/deployment/ReactiveSecureEndpointTest.java new file mode 100644 index 00000000..1d65cd04 --- /dev/null +++ b/hilla-test-extension/deployment/src/test/java/org/acme/hilla/test/extension/deployment/ReactiveSecureEndpointTest.java @@ -0,0 +1,167 @@ +package org.acme.hilla.test.extension.deployment; + +import javax.enterprise.context.control.ActivateRequestContext; +import javax.websocket.ClientEndpointConfig; +import javax.websocket.ContainerProvider; +import javax.websocket.Session; +import java.net.URI; +import java.util.Base64; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; +import java.util.stream.Stream; + +import io.quarkus.security.test.utils.TestIdentityController; +import io.quarkus.security.test.utils.TestIdentityProvider; +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.vertx.core.http.HttpHeaders; +import org.acme.hilla.test.extension.deployment.endpoints.ReactiveSecureEndpoint; +import org.assertj.core.api.AbstractStringAssert; +import org.assertj.core.api.Assertions; +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.asset.StringAsset; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.acme.hilla.test.extension.deployment.TestUtils.ADMIN; +import static org.acme.hilla.test.extension.deployment.TestUtils.ANONYMOUS; +import static org.acme.hilla.test.extension.deployment.TestUtils.GUEST; +import static org.acme.hilla.test.extension.deployment.TestUtils.USER; + +class ReactiveSecureEndpointTest { + private static final String ENDPOINT_NAME = ReactiveSecureEndpoint.class + .getSimpleName(); + + @RegisterExtension + static final QuarkusUnitTest config = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) + .addClasses(TestIdentityProvider.class, + TestIdentityController.class, TestUtils.class, + ReactiveSecureEndpoint.class, HillaPushClient.class) + .addAsResource(new StringAsset( + "quarkus.http.auth.basic=true\nquarkus.http.auth.proactive=true\n"), + "application.properties") + .add(new StringAsset( + "com.vaadin.experimental.hillaPush=true"), + "vaadin-featureflags.properties")); + + @BeforeAll + public static void setupUsers() { + TestIdentityController.resetRoles() + .add(ADMIN.username, ADMIN.pwd, "ADMIN") + .add(USER.username, USER.pwd, "USER") + .add(GUEST.username, GUEST.pwd, "GUEST"); + } + + @TestHTTPResource("/HILLA/push") + URI uri; + + @Test + @ActivateRequestContext + void securedEndpoint_permitAll_authenticatedUsersAllowed() { + Stream.of(ADMIN, USER, GUEST) + .forEach(user -> pushConnection(user, "authenticated").accept( + msg -> msg.contains("\"item\":\"AUTHENTICATED\""))); + pushConnection(ANONYMOUS, "authenticated") + .accept(assertAccessDenied("authenticated")); + } + + @Test + @ActivateRequestContext + void securedEndpoint_adminOnly_onlyAdminAllowed() { + pushConnection(ADMIN, "adminOnly") + .accept(msg -> msg.contains("\"item\":\"ADMIN\"")); + Stream.of(ANONYMOUS, USER, GUEST) + .forEach(user -> pushConnection(user, "adminOnly") + .accept(assertAccessDenied("adminOnly"))); + } + + @Test + @ActivateRequestContext + void securedEndpoint_userOnly_onlyUserAllowed() { + pushConnection(USER, "userOnly") + .accept(msg -> msg.contains("\"item\":\"USER\"")); + Stream.of(ANONYMOUS, ADMIN, GUEST) + .forEach(user -> pushConnection(user, "userOnly") + .accept(assertAccessDenied("userOnly"))); + } + + @Test + @ActivateRequestContext + void securedEndpoint_adminAndUserOnly_onlyAdminAndUserAllowed() { + Stream.of(ADMIN, USER) + .forEach(user -> pushConnection(user, "userAndAdmin").accept( + msg -> msg.contains("\"item\":\"USER AND ADMIN\""))); + Stream.of(ANONYMOUS, GUEST) + .forEach(user -> pushConnection(user, "userAndAdmin") + .accept(assertAccessDenied("userAndAdmin"))); + } + + @Test + @ActivateRequestContext + void securedEndpoint_deny_notAllowed() { + Stream.of(ANONYMOUS, ADMIN, USER, GUEST) + .forEach(user -> pushConnection(user, "deny") + .accept(assertAccessDenied("deny"))); + } + + @Test + @ActivateRequestContext + void securedEndpoint_notAnnotatedMethod_denyAll() { + Stream.of(ANONYMOUS, ADMIN, USER, GUEST) + .forEach(user -> pushConnection(user, "denyByDefault") + .accept(assertAccessDenied("denyByDefault"))); + } + + private Consumer>> pushConnection( + TestUtils.User user, String methodName) { + return asserter -> { + URI connectURI = HillaPushClient.createPUSHConnectURI(uri); + HillaPushClient client = new HillaPushClient(ENDPOINT_NAME, + methodName); + ClientEndpointConfig cec = ClientEndpointConfig.Builder.create() + .configurator(new BasicAuthConfigurator(user)).build(); + try (Session ignored = ContainerProvider.getWebSocketContainer() + .connectToServer(client, cec, connectURI)) { + client.assertMessageReceived(10, TimeUnit.SECONDS, "CONNECT"); + client.assertMessageReceived(1, TimeUnit.SECONDS, asserter); + } catch (Exception e) { + Assertions.fail("PUSH communication failed", e); + } + }; + } + + private Consumer> assertAccessDenied( + String method) { + return msg -> msg.contains("Access denied") + .containsSequence("Endpoint '", ENDPOINT_NAME, "'") + .contains(String.format("method '%s'", method)); + } + + private static class BasicAuthConfigurator + extends ClientEndpointConfig.Configurator { + + private final TestUtils.User user; + + public BasicAuthConfigurator(TestUtils.User user) { + this.user = user; + } + + @Override + public void beforeRequest(Map> headers) { + if (user.username != null && user.pwd != null) { + String credentials = user.username + ":" + user.pwd; + String authHeader = "Basic " + Base64.getEncoder() + .encodeToString(credentials.getBytes(UTF_8)); + headers.put(HttpHeaders.AUTHORIZATION.toString(), + List.of(authHeader)); + } + } + } + +} diff --git a/hilla-test-extension/deployment/src/test/java/org/acme/hilla/test/extension/deployment/SpringReplacementsTest.java b/hilla-test-extension/deployment/src/test/java/org/acme/hilla/test/extension/deployment/SpringReplacementsTest.java new file mode 100644 index 00000000..253f5d0c --- /dev/null +++ b/hilla-test-extension/deployment/src/test/java/org/acme/hilla/test/extension/deployment/SpringReplacementsTest.java @@ -0,0 +1,112 @@ +package org.acme.hilla.test.extension.deployment; + +import javax.enterprise.context.ApplicationScoped; +import javax.inject.Inject; +import java.security.Principal; +import java.util.Set; +import java.util.function.Function; + +import io.quarkus.security.test.utils.AuthData; +import io.quarkus.security.test.utils.IdentityMock; +import io.quarkus.security.test.utils.TestIdentityController; +import io.quarkus.security.test.utils.TestIdentityProvider; +import io.quarkus.test.QuarkusUnitTest; +import org.acme.hilla.test.extension.SpringReplacements; +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.asset.StringAsset; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import static org.acme.hilla.test.extension.deployment.TestUtils.ADMIN; +import static org.acme.hilla.test.extension.deployment.TestUtils.USER; +import static org.assertj.core.api.Assertions.assertThat; + +class SpringReplacementsTest { + + @RegisterExtension + static final QuarkusUnitTest config = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) + .addClasses(TestIdentityProvider.class, IdentityMock.class, + TestIdentityController.class, TestUtils.class) + .addAsResource(new StringAsset( + "quarkus.http.auth.basic=true\nquarkus.http.auth.proactive=true\n"), + "application.properties") + .add(new StringAsset( + "com.vaadin.experimental.hillaPush=true"), + "vaadin-featureflags.properties")); + + @Inject + MyBean bean; + + @BeforeAll + public static void setupUsers() { + TestIdentityController.resetRoles() + .add(ADMIN.username, ADMIN.pwd, "ADMIN") + .add(USER.username, USER.pwd, "USER"); + } + + @Test + void authenticationUtil_getSecurityHolderAuthentication_anonymous_returnsNull() { + IdentityMock.setUpAuth(IdentityMock.ANONYMOUS); + Principal principal = SpringReplacements + .authenticationUtil_getSecurityHolderAuthentication(); + assertThat(principal).isNull(); + } + + @Test + void authenticationUtil_getSecurityHolderAuthentication_authenticated_returnsPrincipal() { + IdentityMock.setUpAuth(IdentityMock.ADMIN); + Principal principal = SpringReplacements + .authenticationUtil_getSecurityHolderAuthentication(); + assertThat(principal).isNotNull().extracting(Principal::getName) + .isEqualTo("admin"); + } + + @Test + void authenticationUtil_getSecurityHolderRoleChecker_authenticated_checksRoles() { + IdentityMock.setUpAuth( + new AuthData(Set.of("ADMIN", "SUPERUSER"), false, "admin")); + Function checker = SpringReplacements + .authenticationUtil_getSecurityHolderRoleChecker(); + assertThat(checker).isNotNull(); + assertThat(checker.apply("ADMIN")).as("Check for ADMIN role").isTrue(); + assertThat(checker.apply("SUPERUSER")).as("Check for SUPERUSER role") + .isTrue(); + assertThat(checker.apply("GUEST")).as("Check for GUEST role").isFalse(); + assertThat(checker.apply("")).as("Check for blank role").isFalse(); + assertThat(checker.apply(null)).as("Check for null role").isFalse(); + } + + @Test + void authenticationUtil_getSecurityHolderRoleChecker_anonymous_checkIsAlwaysFalse() { + IdentityMock.setUpAuth(IdentityMock.ANONYMOUS); + Function checker = SpringReplacements + .authenticationUtil_getSecurityHolderRoleChecker(); + assertThat(checker).isNotNull(); + assertThat(checker.apply("ADMIN")).as("Check for ADMIN role").isFalse(); + assertThat(checker.apply("SUPERUSER")).as("Check for SUPERUSER role") + .isFalse(); + assertThat(checker.apply("GUEST")).as("Check for GUEST role").isFalse(); + assertThat(checker.apply("ANONYMOUS")).as("Check for ANONYMOUS role") + .isFalse(); + assertThat(checker.apply("")).as("Check for blank role").isFalse(); + assertThat(checker.apply(null)).as("Check for null role").isFalse(); + } + + @Test + void classUtils_getUserClass_proxiedObject_returnRawClass() { + Class userClass = SpringReplacements + .classUtils_getUserClass(bean.getClass()); + assertThat(userClass).isEqualTo(MyBean.class); + + userClass = SpringReplacements.classUtils_getUserClass(new MyBean()); + assertThat(userClass).isEqualTo(MyBean.class); + } + + @ApplicationScoped + public static class MyBean { + + } +} diff --git a/hilla-test-extension/deployment/src/test/java/org/acme/hilla/test/extension/deployment/TestUtils.java b/hilla-test-extension/deployment/src/test/java/org/acme/hilla/test/extension/deployment/TestUtils.java new file mode 100644 index 00000000..a1c507e3 --- /dev/null +++ b/hilla-test-extension/deployment/src/test/java/org/acme/hilla/test/extension/deployment/TestUtils.java @@ -0,0 +1,74 @@ +package org.acme.hilla.test.extension.deployment; + +import java.util.LinkedHashMap; +import java.util.function.UnaryOperator; + +import io.restassured.RestAssured; +import io.restassured.http.ContentType; +import io.restassured.response.Response; +import io.restassured.specification.RequestSpecification; + +final class TestUtils { + + static final User ANONYMOUS = new User(null, null); + static final User ADMIN = new User("admin", "admin"); + static final User USER = new User("user", "user"); + static final User GUEST = new User("guest", "guest"); + + static Response givenEndpointRequest(String endpointName, + String methodName) { + return givenEndpointRequest(endpointName, methodName, new Parameters(), + UnaryOperator.identity()); + } + + static Response givenEndpointRequest(String endpointName, String methodName, + UnaryOperator customizer) { + return givenEndpointRequest(endpointName, methodName, new Parameters(), + customizer); + } + + static Response givenEndpointRequest(String endpointName, String methodName, + Parameters parameters) { + return givenEndpointRequest(endpointName, methodName, parameters, + UnaryOperator.identity()); + } + + static Response givenEndpointRequest(String endpointName, String methodName, + Parameters parameters, + UnaryOperator customizer) { + RequestSpecification specs = RestAssured.given() + .contentType(ContentType.JSON).cookie("csrfToken", "CSRF_TOKEN") + .header("X-CSRF-Token", "CSRF_TOKEN").body(parameters.params) + .basePath("/connect"); + specs = customizer.apply(specs); + return specs.when().post("{endpointName}/{methodName}", endpointName, + methodName); + } + + public static class Parameters { + + private final LinkedHashMap params = new LinkedHashMap<>(); + + public Parameters add(String name, Object value) { + params.put(name, value); + return this; + } + + public static Parameters param(String name, Object value) { + Parameters parameters = new Parameters(); + parameters.add(name, value); + return parameters; + } + } + + static final class User { + final String username; + final String pwd; + + User(String username, String pwd) { + this.username = username; + this.pwd = pwd; + } + } + +} diff --git a/hilla-test-extension/deployment/src/test/java/org/acme/hilla/test/extension/deployment/endpoints/ReactiveEndpoint.java b/hilla-test-extension/deployment/src/test/java/org/acme/hilla/test/extension/deployment/endpoints/ReactiveEndpoint.java new file mode 100644 index 00000000..696e5eed --- /dev/null +++ b/hilla-test-extension/deployment/src/test/java/org/acme/hilla/test/extension/deployment/endpoints/ReactiveEndpoint.java @@ -0,0 +1,39 @@ +package org.acme.hilla.test.extension.deployment.endpoints; + +import java.time.Duration; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; + +import dev.hilla.Endpoint; +import dev.hilla.EndpointSubscription; +import reactor.core.publisher.Flux; + +import com.vaadin.flow.server.auth.AnonymousAllowed; + +@Endpoint +@AnonymousAllowed +public class ReactiveEndpoint { + + private final ConcurrentHashMap counters = new ConcurrentHashMap<>(); + + public Flux count(String counterName) { + return Flux.interval(Duration.ofMillis(200)).onBackpressureDrop() + .map(_interval -> counters + .computeIfAbsent(counterName, + unused -> new AtomicInteger()) + .incrementAndGet()); + } + + public EndpointSubscription cancelableCount(String counterName) { + return EndpointSubscription.of(count(counterName), () -> { + counters.get(counterName).set(-1); + }); + } + + public Integer counterValue(String counterName) { + if (counters.containsKey(counterName)) { + return counters.get(counterName).get(); + } + return null; + } +} diff --git a/hilla-test-extension/deployment/src/test/java/org/acme/hilla/test/extension/deployment/endpoints/ReactiveSecureEndpoint.java b/hilla-test-extension/deployment/src/test/java/org/acme/hilla/test/extension/deployment/endpoints/ReactiveSecureEndpoint.java new file mode 100644 index 00000000..b6b76b2a --- /dev/null +++ b/hilla-test-extension/deployment/src/test/java/org/acme/hilla/test/extension/deployment/endpoints/ReactiveSecureEndpoint.java @@ -0,0 +1,43 @@ +package org.acme.hilla.test.extension.deployment.endpoints; + +import javax.annotation.security.DenyAll; +import javax.annotation.security.PermitAll; +import javax.annotation.security.RolesAllowed; + +import dev.hilla.Endpoint; +import reactor.core.publisher.Flux; + +@Endpoint +public class ReactiveSecureEndpoint { + + @RolesAllowed("ADMIN") + public Flux adminOnly() { + return Flux.just("ADMIN"); + } + + @RolesAllowed("USER") + public Flux userOnly() { + return Flux.just("USER"); + } + + @RolesAllowed({ "USER", "ADMIN" }) + public Flux userAndAdmin() { + return Flux.just("USER AND ADMIN"); + } + + @PermitAll + public Flux authenticated() { + return Flux.just("AUTHENTICATED"); + } + + public Flux denyByDefault() { + throw new IllegalArgumentException( + "Method should be denied by default"); + } + + @DenyAll + public Flux deny() { + throw new IllegalArgumentException("Method denied"); + } + +} diff --git a/hilla-test-extension/deployment/src/test/java/org/acme/hilla/test/extension/deployment/endpoints/SecureEndpoint.java b/hilla-test-extension/deployment/src/test/java/org/acme/hilla/test/extension/deployment/endpoints/SecureEndpoint.java new file mode 100644 index 00000000..a2cd9d22 --- /dev/null +++ b/hilla-test-extension/deployment/src/test/java/org/acme/hilla/test/extension/deployment/endpoints/SecureEndpoint.java @@ -0,0 +1,42 @@ +package org.acme.hilla.test.extension.deployment.endpoints; + +import javax.annotation.security.DenyAll; +import javax.annotation.security.PermitAll; +import javax.annotation.security.RolesAllowed; + +import dev.hilla.Endpoint; + +@Endpoint +public class SecureEndpoint { + + @RolesAllowed("ADMIN") + public String adminOnly() { + return "ADMIN"; + } + + @RolesAllowed("USER") + public String userOnly() { + return "USER"; + } + + @RolesAllowed({ "USER", "ADMIN" }) + public String userAndAdmin() { + return "USER AND ADMIN"; + } + + @PermitAll + public String authenticated() { + return "AUTHENTICATED"; + } + + public String denyByDefault() { + throw new IllegalArgumentException( + "Method should be denied by default"); + } + + @DenyAll + public String deny() { + throw new IllegalArgumentException("Method denied"); + } + +} diff --git a/hilla-test-extension/deployment/src/test/java/org/acme/hilla/test/extension/deployment/endpoints/TestEndpoint.java b/hilla-test-extension/deployment/src/test/java/org/acme/hilla/test/extension/deployment/endpoints/TestEndpoint.java new file mode 100644 index 00000000..908d0366 --- /dev/null +++ b/hilla-test-extension/deployment/src/test/java/org/acme/hilla/test/extension/deployment/endpoints/TestEndpoint.java @@ -0,0 +1,67 @@ +package org.acme.hilla.test.extension.deployment.endpoints; + +import java.util.Objects; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import dev.hilla.Endpoint; + +import com.vaadin.flow.server.auth.AnonymousAllowed; + +@Endpoint +@AnonymousAllowed +public class TestEndpoint { + + public String echo(String message) { + return message; + } + + public int calculate(String operator, int a, int b) { + int result; + switch (operator) { + case "+": + result = a + b; + break; + case "*": + result = a * b; + break; + default: + throw new IllegalArgumentException("Invalid operation"); + } + return result; + } + + public Pojo pojo(Pojo pojo) { + return new Pojo(pojo.number * 10, pojo.text + pojo.text); + } + + public static class Pojo { + @JsonProperty + private final int number; + + @JsonProperty + private final String text; + + @JsonCreator + public Pojo(@JsonProperty int number, @JsonProperty String text) { + this.number = number; + this.text = text; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + Pojo pojo = (Pojo) o; + return number == pojo.number && Objects.equals(text, pojo.text); + } + + @Override + public int hashCode() { + return Objects.hash(number, text); + } + } + +} diff --git a/hilla-test-extension/runtime/src/main/java/org/acme/hilla/test/extension/SpringReplacements.java b/hilla-test-extension/runtime/src/main/java/org/acme/hilla/test/extension/SpringReplacements.java index fac66196..ecce9f77 100644 --- a/hilla-test-extension/runtime/src/main/java/org/acme/hilla/test/extension/SpringReplacements.java +++ b/hilla-test-extension/runtime/src/main/java/org/acme/hilla/test/extension/SpringReplacements.java @@ -23,17 +23,18 @@ public static Class classUtils_getUserClass(Class clazz) { } public static Principal authenticationUtil_getSecurityHolderAuthentication() { - System.out - .println("authenticationUtil_getSecurityHolderAuthentication"); - return CurrentIdentityAssociation.current().getPrincipal(); + SecurityIdentity identity = CurrentIdentityAssociation.current(); + if (identity != null && !identity.isAnonymous()) { + return identity.getPrincipal(); + } + return null; } public static Function authenticationUtil_getSecurityHolderRoleChecker() { - System.out.println("authenticationUtil_getSecurityHolderRoleChecker"); SecurityIdentity identity = CurrentIdentityAssociation.current(); - if (identity == null) { + if (identity == null || identity.isAnonymous()) { return role -> false; } - return identity::hasRole; + return role -> role != null && identity.hasRole(role); } }