diff --git a/hilla-test-extension/deployment/pom.xml b/hilla-test-extension/deployment/pom.xml
index b0ff9355..241f81d5 100644
--- a/hilla-test-extension/deployment/pom.xml
+++ b/hilla-test-extension/deployment/pom.xml
@@ -43,6 +43,32 @@
quarkus-junit5-internal
test
+
+ io.rest-assured
+ rest-assured
+ test
+
+
+ io.quarkus
+ quarkus-hibernate-validator
+ test
+
+
+ io.quarkus
+ quarkus-security-deployment
+ test
+
+
+ io.quarkus
+ quarkus-security-test-utils
+ test
+ ${quarkus.version}
+
+
+ org.assertj
+ assertj-core
+ 3.24.2
+
diff --git a/hilla-test-extension/deployment/src/main/java/org/acme/hilla/test/extension/deployment/HillaTestExtensionProcessor.java b/hilla-test-extension/deployment/src/main/java/org/acme/hilla/test/extension/deployment/HillaTestExtensionProcessor.java
index 51c00097..17334167 100644
--- a/hilla-test-extension/deployment/src/main/java/org/acme/hilla/test/extension/deployment/HillaTestExtensionProcessor.java
+++ b/hilla-test-extension/deployment/src/main/java/org/acme/hilla/test/extension/deployment/HillaTestExtensionProcessor.java
@@ -222,13 +222,17 @@ void registerHillaSecurityPolicy(HttpBuildTimeConfig buildTimeConfig,
@BuildStep
@Record(ExecutionTime.RUNTIME_INIT)
void registerHillaFormAuthenticationMechanism(
+ HttpBuildTimeConfig httpBuildTimeConfig,
HillaSecurityRecorder recorder,
BuildProducer producer) {
- producer.produce(SyntheticBeanBuildItem
- .configure(HillaFormAuthenticationMechanism.class)
- .types(HttpAuthenticationMechanism.class).setRuntimeInit()
- .scope(Singleton.class).alternativePriority(1)
- .supplier(recorder.setupFormAuthenticationMechanism()).done());
+ if (httpBuildTimeConfig.auth.form.enabled) {
+ producer.produce(SyntheticBeanBuildItem
+ .configure(HillaFormAuthenticationMechanism.class)
+ .types(HttpAuthenticationMechanism.class).setRuntimeInit()
+ .scope(Singleton.class).alternativePriority(1)
+ .supplier(recorder.setupFormAuthenticationMechanism())
+ .done());
+ }
}
@BuildStep
diff --git a/hilla-test-extension/deployment/src/test/java/org/acme/hilla/test/extension/deployment/EndpointControllerTest.java b/hilla-test-extension/deployment/src/test/java/org/acme/hilla/test/extension/deployment/EndpointControllerTest.java
new file mode 100644
index 00000000..e82ce2f4
--- /dev/null
+++ b/hilla-test-extension/deployment/src/test/java/org/acme/hilla/test/extension/deployment/EndpointControllerTest.java
@@ -0,0 +1,111 @@
+package org.acme.hilla.test.extension.deployment;
+
+import dev.hilla.exception.EndpointValidationException;
+import io.quarkus.test.QuarkusUnitTest;
+import io.restassured.RestAssured;
+import io.restassured.http.ContentType;
+import org.acme.hilla.test.extension.deployment.TestUtils.Parameters;
+import org.acme.hilla.test.extension.deployment.endpoints.TestEndpoint;
+import org.hamcrest.CoreMatchers;
+import org.jboss.shrinkwrap.api.ShrinkWrap;
+import org.jboss.shrinkwrap.api.spec.JavaArchive;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
+
+import static org.acme.hilla.test.extension.deployment.TestUtils.givenEndpointRequest;
+import static org.hamcrest.CoreMatchers.containsString;
+import static org.hamcrest.CoreMatchers.equalTo;
+
+class EndpointControllerTest {
+
+ private static final String ENDPOINT_NAME = TestEndpoint.class
+ .getSimpleName();
+
+ @RegisterExtension
+ static final QuarkusUnitTest config = new QuarkusUnitTest()
+ .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
+ .addClasses(TestUtils.class, TestEndpoint.class));
+
+ @Test
+ void invokeEndpoint_singleSimpleParameter() {
+ String msg = "A text message";
+ givenEndpointRequest(ENDPOINT_NAME, "echo",
+ Parameters.param("message", msg)).then().assertThat()
+ .statusCode(200).and().body(equalTo("\"" + msg + "\""));
+ }
+
+ @Test
+ void invokeEndpoint_singleComplexParameter() {
+ String msg = "A text message";
+ TestEndpoint.Pojo pojo = new TestEndpoint.Pojo(10, msg);
+ givenEndpointRequest(ENDPOINT_NAME, "pojo",
+ Parameters.param("pojo", pojo)).then().assertThat()
+ .statusCode(200).and().body("number", equalTo(100)).and()
+ .body("text", equalTo(msg + msg));
+ }
+
+ @Test
+ void invokeEndpoint_multipleParameters() {
+ givenEndpointRequest(ENDPOINT_NAME, "calculate",
+ Parameters.param("operator", "+").add("a", 10).add("b", 20))
+ .then().assertThat().statusCode(200).and().body(equalTo("30"));
+ }
+
+ @Test
+ void invokeEndpoint_wrongParametersOrder_badRequest() {
+ givenEndpointRequest(ENDPOINT_NAME, "calculate",
+ Parameters.param("a", 10).add("operator", "+").add("b", 20))
+ .then().assertThat().statusCode(400).and()
+ .body("type", equalTo(EndpointValidationException.class.getName()))
+ .and()
+ .body("message",
+ CoreMatchers.allOf(containsString("Validation error"),
+ containsString("'TestEndpoint'"),
+ containsString("'calculate'")))
+ .body("validationErrorData[0].parameterName",
+ equalTo("operator"));
+ }
+
+ @Test
+ void invokeEndpoint_wrongNumberOfParameters_badRequest() {
+ givenEndpointRequest(ENDPOINT_NAME, "calculate",
+ Parameters.param("operator", "+")).then().assertThat()
+ .statusCode(400).and().body("message",
+ CoreMatchers.allOf(
+ containsString(
+ "Incorrect number of parameters"),
+ containsString("'TestEndpoint'"),
+ containsString("'calculate'"),
+ containsString("expected: 3, got: 1")));
+ }
+
+ @Test
+ void invokeEndpoint_wrongEndpointName_notFound() {
+ givenEndpointRequest("NotExistingTestEndpoint", "calculate",
+ Parameters.param("operator", "+")).then().assertThat()
+ .statusCode(404);
+ }
+
+ @Test
+ void invokeEndpoint_wrongMethodName_notFound() {
+ givenEndpointRequest(ENDPOINT_NAME, "notExistingMethod",
+ Parameters.param("operator", "+")).then().assertThat()
+ .statusCode(404);
+ }
+
+ @Test
+ void invokeEndpoint_emptyMethodName_notFound() {
+ givenEndpointRequest(ENDPOINT_NAME, "",
+ Parameters.param("operator", "+")).then().assertThat()
+ .statusCode(404);
+ }
+
+ @Test
+ void invokeEndpoint_missingMethodName_notFound() {
+ RestAssured.given().contentType(ContentType.JSON)
+ .cookie("csrfToken", "CSRF_TOKEN")
+ .header("X-CSRF-Token", "CSRF_TOKEN").basePath("/connect")
+ .when().post(ENDPOINT_NAME).then().assertThat().statusCode(404);
+ }
+
+}
diff --git a/hilla-test-extension/deployment/src/test/java/org/acme/hilla/test/extension/deployment/EndpointSecurityTest.java b/hilla-test-extension/deployment/src/test/java/org/acme/hilla/test/extension/deployment/EndpointSecurityTest.java
new file mode 100644
index 00000000..e4b7d57c
--- /dev/null
+++ b/hilla-test-extension/deployment/src/test/java/org/acme/hilla/test/extension/deployment/EndpointSecurityTest.java
@@ -0,0 +1,154 @@
+package org.acme.hilla.test.extension.deployment;
+
+import java.util.function.UnaryOperator;
+import java.util.stream.Stream;
+
+import io.quarkus.security.test.utils.TestIdentityController;
+import io.quarkus.security.test.utils.TestIdentityProvider;
+import io.quarkus.test.QuarkusUnitTest;
+import io.restassured.specification.RequestSpecification;
+import org.acme.hilla.test.extension.deployment.TestUtils.User;
+import org.acme.hilla.test.extension.deployment.endpoints.SecureEndpoint;
+import org.jboss.shrinkwrap.api.ShrinkWrap;
+import org.jboss.shrinkwrap.api.asset.StringAsset;
+import org.jboss.shrinkwrap.api.spec.JavaArchive;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
+
+import static org.acme.hilla.test.extension.deployment.TestUtils.ADMIN;
+import static org.acme.hilla.test.extension.deployment.TestUtils.GUEST;
+import static org.acme.hilla.test.extension.deployment.TestUtils.USER;
+import static org.acme.hilla.test.extension.deployment.TestUtils.givenEndpointRequest;
+import static org.hamcrest.CoreMatchers.containsString;
+import static org.hamcrest.CoreMatchers.equalTo;
+
+class EndpointSecurityTest {
+
+ @RegisterExtension
+ static final QuarkusUnitTest config = new QuarkusUnitTest()
+ .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
+ .addClasses(TestIdentityProvider.class,
+ TestIdentityController.class, TestUtils.class,
+ SecureEndpoint.class)
+ .addAsResource(new StringAsset(
+ "quarkus.http.auth.basic=true\nquarkus.http.auth.proactive=true\n"),
+ "application.properties"));
+ public static final String SECURE_ENDPOINT = "SecureEndpoint";
+
+ @BeforeAll
+ public static void setupUsers() {
+ TestIdentityController.resetRoles()
+ .add(ADMIN.username, ADMIN.pwd, "ADMIN")
+ .add(USER.username, USER.pwd, "USER")
+ .add(GUEST.username, GUEST.pwd, "GUEST");
+ }
+
+ @Test
+ void securedEndpoint_permitAll_authenticatedUsersAllowed() {
+ Stream.of(USER, GUEST)
+ .forEach(user -> givenEndpointRequest(SECURE_ENDPOINT,
+ "authenticated", authenticate(user)).then().assertThat()
+ .statusCode(200).and()
+ .body(equalTo("\"AUTHENTICATED\"")));
+
+ givenEndpointRequest(SECURE_ENDPOINT, "authenticated").then()
+ .assertThat().statusCode(401).and()
+ .body("message", containsString(SECURE_ENDPOINT))
+ .body("message", containsString("reason: 'Access denied'"));
+ }
+
+ @Test
+ void securedEndpoint_adminOnly_onlyAdminAllowed() {
+ givenEndpointRequest(SECURE_ENDPOINT, "adminOnly", authenticate(ADMIN))
+ .then().assertThat().statusCode(200).and()
+ .body(equalTo("\"ADMIN\""));
+
+ Stream.of(USER, GUEST)
+ .forEach(user -> givenEndpointRequest(SECURE_ENDPOINT,
+ "adminOnly", authenticate(user)).then().assertThat()
+ .statusCode(401).and()
+ .body("message", containsString(SECURE_ENDPOINT))
+ .body("message",
+ containsString("reason: 'Access denied'")));
+
+ givenEndpointRequest(SECURE_ENDPOINT, "adminOnly").then().assertThat()
+ .statusCode(401).and()
+ .body("message", containsString(SECURE_ENDPOINT))
+ .body("message", containsString("reason: 'Access denied'"));
+ }
+
+ @Test
+ void securedEndpoint_userOnly_onlyUserAllowed() {
+ givenEndpointRequest(SECURE_ENDPOINT, "userOnly", authenticate(USER))
+ .then().assertThat().statusCode(200).and()
+ .body(equalTo("\"USER\""));
+
+ Stream.of(ADMIN, GUEST)
+ .forEach(user -> givenEndpointRequest(SECURE_ENDPOINT,
+ "userOnly", authenticate(user)).then().assertThat()
+ .statusCode(401).and()
+ .body("message", containsString(SECURE_ENDPOINT))
+ .body("message",
+ containsString("reason: 'Access denied'")));
+
+ givenEndpointRequest(SECURE_ENDPOINT, "userOnly").then().assertThat()
+ .statusCode(401).and()
+ .body("message", containsString(SECURE_ENDPOINT))
+ .body("message", containsString("reason: 'Access denied'"));
+ }
+
+ @Test
+ void securedEndpoint_adminAndUserOnly_onlyAdminAndUserAllowed() {
+ Stream.of(ADMIN, USER)
+ .forEach(user -> givenEndpointRequest(SECURE_ENDPOINT,
+ "userAndAdmin", authenticate(user)).then().assertThat()
+ .statusCode(200).and()
+ .body(equalTo("\"USER AND ADMIN\"")));
+
+ givenEndpointRequest(SECURE_ENDPOINT, "userAndAdmin",
+ authenticate(GUEST)).then().assertThat().statusCode(401).and()
+ .body("message", containsString(SECURE_ENDPOINT))
+ .body("message", containsString("reason: 'Access denied'"));
+
+ givenEndpointRequest(SECURE_ENDPOINT, "userAndAdmin").then()
+ .assertThat().statusCode(401).and()
+ .body("message", containsString(SECURE_ENDPOINT))
+ .body("message", containsString("reason: 'Access denied'"));
+ }
+
+ @Test
+ void securedEndpoint_deny_notAllowed() {
+ Stream.of(ADMIN, USER, GUEST)
+ .forEach(user -> givenEndpointRequest(SECURE_ENDPOINT, "deny",
+ authenticate(user)).then().assertThat().statusCode(401)
+ .and().body("message", containsString(SECURE_ENDPOINT))
+ .body("message",
+ containsString("reason: 'Access denied'")));
+
+ givenEndpointRequest(SECURE_ENDPOINT, "deny").then().assertThat()
+ .statusCode(401).and()
+ .body("message", containsString(SECURE_ENDPOINT))
+ .body("message", containsString("reason: 'Access denied'"));
+ }
+
+ @Test
+ void securedEndpoint_notAnnotatedMethod_denyAll() {
+ Stream.of(ADMIN, USER, GUEST)
+ .forEach(user -> givenEndpointRequest(SECURE_ENDPOINT,
+ "denyByDefault", authenticate(user)).then().assertThat()
+ .statusCode(401).and()
+ .body("message", containsString(SECURE_ENDPOINT))
+ .body("message",
+ containsString("reason: 'Access denied'")));
+
+ givenEndpointRequest(SECURE_ENDPOINT, "denyByDefault").then()
+ .assertThat().statusCode(401).and()
+ .body("message", containsString(SECURE_ENDPOINT))
+ .body("message", containsString("reason: 'Access denied'"));
+ }
+
+ private static UnaryOperator authenticate(User user) {
+ return spec -> spec.auth().preemptive().basic(user.username, user.pwd);
+ }
+}
diff --git a/hilla-test-extension/deployment/src/test/java/org/acme/hilla/test/extension/deployment/HillaPushClient.java b/hilla-test-extension/deployment/src/test/java/org/acme/hilla/test/extension/deployment/HillaPushClient.java
new file mode 100644
index 00000000..e8fea1b0
--- /dev/null
+++ b/hilla-test-extension/deployment/src/test/java/org/acme/hilla/test/extension/deployment/HillaPushClient.java
@@ -0,0 +1,182 @@
+package org.acme.hilla.test.extension.deployment;
+
+import javax.websocket.CloseReason;
+import javax.websocket.Endpoint;
+import javax.websocket.EndpointConfig;
+import javax.websocket.MessageHandler;
+import javax.websocket.Session;
+import java.io.IOException;
+import java.io.UncheckedIOException;
+import java.net.URI;
+import java.net.URLEncoder;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.UUID;
+import java.util.concurrent.LinkedBlockingDeque;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.Consumer;
+import java.util.logging.LogManager;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.node.ObjectNode;
+import org.assertj.core.api.AbstractStringAssert;
+import org.assertj.core.api.ObjectAssert;
+import org.assertj.core.api.StringAssert;
+import org.junit.jupiter.api.Assertions;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import static java.nio.charset.StandardCharsets.UTF_8;
+import static org.assertj.core.api.Assertions.assertThat;
+
+public class HillaPushClient extends Endpoint
+ implements MessageHandler.Whole {
+
+ private static final Logger LOGGER = LoggerFactory
+ .getLogger(HillaPushClient.class);
+ private static final AtomicInteger CLIENT_ID_GEN = new AtomicInteger();
+
+ final LinkedBlockingDeque messages = new LinkedBlockingDeque<>();
+
+ final String id;
+ private final String endpointName;
+ private final String methodName;
+ private final List