Skip to content

Commit

Permalink
fix: disallow access into ws interceptors
Browse files Browse the repository at this point in the history
  • Loading branch information
bbortt committed Oct 21, 2024
1 parent 4c1290e commit 1ac97c4
Show file tree
Hide file tree
Showing 5 changed files with 144 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@

package org.citrusframework.ws.client;

import static java.util.Collections.singletonList;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;

import org.citrusframework.endpoint.AbstractPollableEndpointConfiguration;
import org.citrusframework.endpoint.resolver.DynamicEndpointUriResolver;
import org.citrusframework.endpoint.resolver.EndpointUriResolver;
Expand Down Expand Up @@ -81,9 +82,7 @@ public class WebServiceEndpointConfiguration extends AbstractPollableEndpointCon
* Default constructor initializes with default logging interceptor.
*/
public WebServiceEndpointConfiguration() {
List<ClientInterceptor> interceptors = new ArrayList<>();
interceptors.add(new LoggingClientInterceptor());
setInterceptors(interceptors);
setInterceptors(new ArrayList<>(singletonList(new LoggingClientInterceptor())));
}

/**
Expand Down Expand Up @@ -228,7 +227,7 @@ public void setDefaultUri(String defaultUri) {
* Gets the client interceptors.
* @return
*/
public List<ClientInterceptor> getInterceptors() {
List<ClientInterceptor> getInterceptors() {
return interceptors;
}

Expand All @@ -246,7 +245,14 @@ public void setInterceptors(List<ClientInterceptor> interceptors) {
* @param interceptor
*/
public void setInterceptor(ClientInterceptor interceptor) {
List<ClientInterceptor> interceptors = new ArrayList<>();
setInterceptors(new ArrayList<>(singletonList(interceptor)));
}

/**
* Adds the client interceptor to the already existing ones.
* @param interceptor
*/
public void addInterceptor(ClientInterceptor interceptor) {
interceptors.add(interceptor);
setInterceptors(interceptors);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package org.citrusframework.ws.client;

import static java.util.Collections.singletonList;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.MockitoAnnotations.openMocks;

import org.citrusframework.ws.interceptor.LoggingClientInterceptor;
import org.mockito.Mock;
import org.springframework.ws.client.core.WebServiceTemplate;
import org.springframework.ws.client.support.interceptor.ClientInterceptor;
import org.testng.annotations.AfterMethod;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;

public class WebServiceEndpointConfigurationTest {

@Mock
private WebServiceTemplate webServiceTemplateMock;

private AutoCloseable mockitoContext;

private WebServiceEndpointConfiguration fixture;

@BeforeMethod
public void setUp() {
mockitoContext = openMocks(this);

fixture = new WebServiceEndpointConfiguration();
fixture.setWebServiceTemplate(webServiceTemplateMock);
}

@AfterMethod
public void tearDown() throws Exception {
mockitoContext.close();
}

@Test
public void containsLoggingClientInterceptorByDefault() {
assertThat(fixture.getInterceptors())
.hasSize(1)
.satisfiesOnlyOnce(i -> assertThat(i).isInstanceOf(LoggingClientInterceptor.class));

verify(webServiceTemplateMock)
.setInterceptors(fixture.getInterceptors().toArray(new ClientInterceptor[0]));
}

@Test
public void setInterceptors_overridesDefaultInterceptor() {
var clientInterceptor = mock(ClientInterceptor.class);

fixture.setInterceptors(singletonList(clientInterceptor));

verifyFixtureContainsOnlyClientInterceptor(clientInterceptor);

verify(webServiceTemplateMock)
.setInterceptors(new ClientInterceptor[]{clientInterceptor});
}

@Test
public void setInterceptor_overridesDefaultInterceptor() {
var clientInterceptor = mock(ClientInterceptor.class);

fixture.setInterceptor(clientInterceptor);

verifyFixtureContainsOnlyClientInterceptor(clientInterceptor);

verify(webServiceTemplateMock)
.setInterceptors(new ClientInterceptor[]{clientInterceptor});
}

@Test
public void addInterceptorAppendsToDefaultInterceptors() {
var clientInterceptor = mock(ClientInterceptor.class);

fixture.addInterceptor(clientInterceptor);

assertThat(fixture.getInterceptors())
.hasSize(2)
.satisfiesOnlyOnce(i -> assertThat(i).isInstanceOf(LoggingClientInterceptor.class))
.satisfiesOnlyOnce(i -> assertThat(i).isEqualTo(clientInterceptor));

verify(webServiceTemplateMock)
.setInterceptors(fixture.getInterceptors().toArray(new ClientInterceptor[0]));
}

private void verifyFixtureContainsOnlyClientInterceptor(ClientInterceptor clientInterceptor) {
assertThat(fixture.getInterceptors())
.hasSize(1)
.satisfiesOnlyOnce(i -> assertThat(i).isEqualTo(clientInterceptor));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package org.citrusframework.ws.client;

import static org.citrusframework.util.ReflectionHelper.getField;

import java.util.List;
import org.springframework.ws.client.support.interceptor.ClientInterceptor;

public final class WsTestUtils {

@SuppressWarnings({"unchecked"})
public static List<ClientInterceptor> getInterceptors(WebServiceClient webServiceClient) throws NoSuchFieldException {
return (List<ClientInterceptor>) getField(
WebServiceEndpointConfiguration.class.getDeclaredField("interceptors"),
webServiceClient.getEndpointConfiguration());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@

package org.citrusframework.ws.config.annotation;

import static org.citrusframework.ws.client.WsTestUtils.getInterceptors;
import static org.mockito.Mockito.when;

import java.util.Arrays;
import java.util.Map;

import org.citrusframework.TestActor;
import org.citrusframework.annotations.CitrusAnnotations;
import org.citrusframework.annotations.CitrusEndpoint;
Expand Down Expand Up @@ -52,8 +54,6 @@
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;

import static org.mockito.Mockito.when;

public class WebServiceClientConfigParserTest extends AbstractTestNGUnitTest {

@CitrusEndpoint(name = "wsClient1")
Expand Down Expand Up @@ -137,15 +137,15 @@ public void setMocks() {
}

@Test
public void testWebServiceClientParser() {
public void testWebServiceClientParser() throws NoSuchFieldException {
CitrusAnnotations.injectEndpoints(this, context);

// 1st message sender
Assert.assertEquals(client1.getEndpointConfiguration().getDefaultUri(), "http://localhost:8080/test");
Assert.assertTrue(client1.getEndpointConfiguration().getMessageFactory() instanceof SoapMessageFactory);
Assert.assertEquals(client1.getEndpointConfiguration().getCorrelator().getClass(), DefaultMessageCorrelator.class);
Assert.assertEquals(client1.getEndpointConfiguration().getInterceptors().size(), 1L);
Assert.assertEquals(client1.getEndpointConfiguration().getInterceptors().get(0).getClass(), LoggingClientInterceptor.class);
Assert.assertEquals(getInterceptors(client1).size(), 1L);
Assert.assertEquals(getInterceptors(client1).get(0).getClass(), LoggingClientInterceptor.class);
Assert.assertTrue(client1.getEndpointConfiguration().getMessageConverter() instanceof SoapMessageConverter);
Assert.assertEquals(client1.getEndpointConfiguration().getErrorHandlingStrategy(), ErrorHandlingStrategy.THROWS_EXCEPTION);
Assert.assertEquals(client1.getEndpointConfiguration().getTimeout(), 5000L);
Expand Down Expand Up @@ -175,19 +175,19 @@ public void testWebServiceClientParser() {
Assert.assertEquals(client4.getEndpointConfiguration().getErrorHandlingStrategy(), ErrorHandlingStrategy.THROWS_EXCEPTION);
Assert.assertNotNull(client4.getEndpointConfiguration().getMessageSender());
Assert.assertEquals(client4.getEndpointConfiguration().getMessageSender(), messageSender);
Assert.assertEquals(client4.getEndpointConfiguration().getInterceptors().size(), 1L);
Assert.assertEquals(client4.getEndpointConfiguration().getInterceptors().get(0), clientInterceptor1);
Assert.assertEquals(getInterceptors(client4).size(), 1L);
Assert.assertEquals(getInterceptors(client4).get(0), clientInterceptor1);
Assert.assertNotNull(client4.getEndpointConfiguration().getWebServiceTemplate());
Assert.assertEquals(client4.getEndpointConfiguration().getWebServiceTemplate().getInterceptors().length, 1L);
Assert.assertTrue(client4.getEndpointConfiguration().getMessageConverter() instanceof WsAddressingMessageConverter);

// 5th message sender
Assert.assertEquals(client5.getEndpointConfiguration().getDefaultUri(), "http://localhost:8080/test");
Assert.assertEquals(client5.getEndpointConfiguration().getErrorHandlingStrategy(), ErrorHandlingStrategy.PROPAGATE);
Assert.assertNotNull(client5.getEndpointConfiguration().getInterceptors());
Assert.assertEquals(client5.getEndpointConfiguration().getInterceptors().size(), 2L);
Assert.assertEquals(client5.getEndpointConfiguration().getInterceptors().get(0), clientInterceptor1);
Assert.assertEquals(client5.getEndpointConfiguration().getInterceptors().get(1), clientInterceptor2);
Assert.assertNotNull(getInterceptors(client5));
Assert.assertEquals(getInterceptors(client5).size(), 2L);
Assert.assertEquals(getInterceptors(client5).get(0), clientInterceptor1);
Assert.assertEquals(getInterceptors(client5).get(1), clientInterceptor2);
Assert.assertEquals(client5.getEndpointConfiguration().getPollingInterval(), 250L);
Assert.assertNotNull(client5.getEndpointConfiguration().getWebServiceTemplate());
Assert.assertEquals(client5.getEndpointConfiguration().getWebServiceTemplate().getInterceptors().length, 2L);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@

package org.citrusframework.ws.config.xml;

import java.util.Map;
import static org.citrusframework.ws.client.WsTestUtils.getInterceptors;

import java.util.Map;
import org.citrusframework.TestActor;
import org.citrusframework.message.DefaultMessageCorrelator;
import org.citrusframework.message.ErrorHandlingStrategy;
Expand All @@ -34,7 +35,7 @@
public class WebServiceClientParserTest extends AbstractBeanDefinitionParserTest {

@Test
public void testWebServiceClientParser() {
public void testWebServiceClientParser() throws NoSuchFieldException {
Map<String, WebServiceClient> messageSenders = beanDefinitionContext.getBeansOfType(WebServiceClient.class);

Assert.assertEquals(messageSenders.size(), 6);
Expand All @@ -44,8 +45,8 @@ public void testWebServiceClientParser() {
Assert.assertEquals(client.getEndpointConfiguration().getDefaultUri(), "http://localhost:8080/test");
Assert.assertTrue(client.getEndpointConfiguration().getMessageFactory() instanceof SoapMessageFactory);
Assert.assertEquals(client.getEndpointConfiguration().getCorrelator().getClass(), DefaultMessageCorrelator.class);
Assert.assertEquals(client.getEndpointConfiguration().getInterceptors().size(), 1L);
Assert.assertEquals(client.getEndpointConfiguration().getInterceptors().get(0).getClass(), LoggingClientInterceptor.class);
Assert.assertEquals(getInterceptors(client).size(), 1L);
Assert.assertEquals(getInterceptors(client).get(0).getClass(), LoggingClientInterceptor.class);
Assert.assertTrue(client.getEndpointConfiguration().getMessageConverter() instanceof SoapMessageConverter);
Assert.assertEquals(client.getEndpointConfiguration().getErrorHandlingStrategy(), ErrorHandlingStrategy.THROWS_EXCEPTION);
Assert.assertEquals(client.getEndpointConfiguration().getTimeout(), 5000L);
Expand Down Expand Up @@ -78,8 +79,8 @@ public void testWebServiceClientParser() {
Assert.assertEquals(client.getEndpointConfiguration().getErrorHandlingStrategy(), ErrorHandlingStrategy.THROWS_EXCEPTION);
Assert.assertNotNull(client.getEndpointConfiguration().getMessageSender());
Assert.assertEquals(client.getEndpointConfiguration().getMessageSender(), beanDefinitionContext.getBean("wsMessageSender"));
Assert.assertEquals(client.getEndpointConfiguration().getInterceptors().size(), 1L);
Assert.assertEquals(client.getEndpointConfiguration().getInterceptors().get(0), beanDefinitionContext.getBean("singleInterceptor"));
Assert.assertEquals(getInterceptors(client).size(), 1L);
Assert.assertEquals(getInterceptors(client).get(0), beanDefinitionContext.getBean("singleInterceptor"));
Assert.assertNotNull(client.getEndpointConfiguration().getWebServiceTemplate());
Assert.assertEquals(client.getEndpointConfiguration().getWebServiceTemplate().getInterceptors().length, 1L);
Assert.assertTrue(client.getEndpointConfiguration().getMessageConverter() instanceof WsAddressingMessageConverter);
Expand All @@ -88,10 +89,10 @@ public void testWebServiceClientParser() {
client = messageSenders.get("soapClient5");
Assert.assertEquals(client.getEndpointConfiguration().getDefaultUri(), "http://localhost:8080/test");
Assert.assertEquals(client.getEndpointConfiguration().getErrorHandlingStrategy(), ErrorHandlingStrategy.PROPAGATE);
Assert.assertNotNull(client.getEndpointConfiguration().getInterceptors());
Assert.assertEquals(client.getEndpointConfiguration().getInterceptors().size(), 2L);
Assert.assertEquals(client.getEndpointConfiguration().getInterceptors().get(0), beanDefinitionContext.getBean("interceptor1"));
Assert.assertEquals(client.getEndpointConfiguration().getInterceptors().get(1), beanDefinitionContext.getBean("interceptor2"));
Assert.assertNotNull(getInterceptors(client));
Assert.assertEquals(getInterceptors(client).size(), 2L);
Assert.assertEquals(getInterceptors(client).get(0), beanDefinitionContext.getBean("interceptor1"));
Assert.assertEquals(getInterceptors(client).get(1), beanDefinitionContext.getBean("interceptor2"));
Assert.assertEquals(client.getEndpointConfiguration().getPollingInterval(), 250L);
Assert.assertNotNull(client.getEndpointConfiguration().getWebServiceTemplate());
Assert.assertEquals(client.getEndpointConfiguration().getWebServiceTemplate().getInterceptors().length, 2L);
Expand Down

0 comments on commit 1ac97c4

Please sign in to comment.