Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Development: Add annotation to disable thread group check #378

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions src/main/java/de/tum/in/test/api/DisableThreadGroupCheckFor.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package de.tum.in.test.api;

import static java.lang.annotation.ElementType.*;
import static java.lang.annotation.RetentionPolicy.RUNTIME;

import java.lang.annotation.Documented;
import java.lang.annotation.Inherited;
import java.lang.annotation.Retention;
import java.lang.annotation.Target;

import org.apiguardian.api.API;

/**
* Allows to disable the thread group check for threads which names start with
* any of the given prefixes.
*
* @author Benjamin Schmitz
* @since 1.14.0
* @version 1.0.0
*/
@API(status = API.Status.EXPERIMENTAL)
@Inherited
@Documented
@Retention(RUNTIME)
@Target({ TYPE, ANNOTATION_TYPE })
public @interface DisableThreadGroupCheckFor {
String[] value();
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ public static AresSecurityConfiguration generateConfiguration(TestContext contex
config.withPackageWhitelist(generatePackageWhiteList(context));
config.withTrustedPackages(getTrustedPackages(context));
config.withThreadTrustScope(getThreadTrustScope(context));
config.withAllowedThreadsInThreadGroup(getAllowedThreadsInThreadGroup(context));
configureAllowLocalPort(config, context);
return config.build();
}
Expand Down Expand Up @@ -109,4 +110,9 @@ private static TrustScope getThreadTrustScope(TestContext context) {
return TestContextUtils.findAnnotationIn(context, TrustedThreads.class).map(TrustedThreads::value)
.orElse(TrustScope.MINIMAL);
}

public static Set<String> getAllowedThreadsInThreadGroup(TestContext context) {
return new HashSet<>(TestContextUtils.findAnnotationIn(context, DisableThreadGroupCheckFor.class)
.map(DisableThreadGroupCheckFor::value).map(Arrays::asList).orElse(Collections.emptyList()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@ public final class AresSecurityConfiguration {
private final Set<PackageRule> whitelistedPackages;
private final Set<PackageRule> trustedPackages;
private final TrustScope threadTrustScope;
private final Set<String> allowedThreadsInThreadGroup;

AresSecurityConfiguration(Optional<Class<?>> testClass, Optional<Method> testMethod, Path executionPath, // NOSONAR
Collection<String> whitelistedClassNames, Optional<Collection<PathRule>> whitelistedPaths,
Collection<PathRule> blacklistedPaths, Set<Integer> allowedLocalPorts, OptionalInt allowLocalPortsAbove,
Set<Integer> excludedLocalPorts, OptionalInt allowedThreadCount, Set<PackageRule> blacklistedPackages,
Set<PackageRule> whitelistedPackages, Set<PackageRule> trustedPackages, TrustScope threadTrustScope) {
Set<PackageRule> whitelistedPackages, Set<PackageRule> trustedPackages, TrustScope threadTrustScope,
Set<String> allowedThreadsInThreadGroup) {
this.testClass = Objects.requireNonNull(testClass);
this.testMethod = Objects.requireNonNull(testMethod);
this.executionPath = executionPath.toAbsolutePath();
Expand All @@ -46,6 +48,7 @@ public final class AresSecurityConfiguration {
this.whitelistedPackages = Set.copyOf(whitelistedPackages);
this.trustedPackages = Set.copyOf(trustedPackages);
this.threadTrustScope = threadTrustScope;
this.allowedThreadsInThreadGroup = allowedThreadsInThreadGroup;
}

public Optional<Class<?>> testClass() {
Expand Down Expand Up @@ -104,6 +107,10 @@ public TrustScope threadTrustScope() {
return threadTrustScope;
}

public Set<String> getAllowedThreadsInThreadGroup() {
return allowedThreadsInThreadGroup;
}

@Override
public boolean equals(Object obj) {
if (this == obj)
Expand All @@ -122,24 +129,27 @@ public boolean equals(Object obj) {
&& Objects.equals(blacklistedPaths, other.blacklistedPaths)
&& Objects.equals(blacklistedPackages, other.blacklistedPackages)
&& Objects.equals(whitelistedPackages, other.whitelistedPackages)
&& Objects.equals(threadTrustScope, other.threadTrustScope);
&& Objects.equals(threadTrustScope, other.threadTrustScope)
&& Objects.equals(allowedThreadsInThreadGroup, other.allowedThreadsInThreadGroup);
}

@Override
public int hashCode() {
return Objects.hash(executionPath, testClass, testMethod, whitelistedClassNames, allowedThreadCount,
whitelistedPaths, blacklistedPaths, blacklistedPackages, whitelistedPackages, threadTrustScope);
whitelistedPaths, blacklistedPaths, blacklistedPackages, whitelistedPackages, threadTrustScope,
allowedThreadsInThreadGroup);
}

@Override
public String toString() {
return String.format("AresSecurityConfiguration [whitelistedClassNames=%s, executionPath=%s," //$NON-NLS-1$
+ " testClass=%s, testMethod=%s, whitelistedPaths=%s, blacklistedPaths=%s, allowedLocalPorts=%s," //$NON-NLS-1$
+ " allowLocalPortsAbove=%s, excludedLocalPorts=%s, allowedThreadCount=%s," //$NON-NLS-1$
+ " blacklistedPackages=%s, whitelistedPackages=%s, trustedPackages=%s, threadTrustScope=%s]", //$NON-NLS-1$
+ " blacklistedPackages=%s, whitelistedPackages=%s, trustedPackages=%s, threadTrustScope=%s," //$NON-NLS-1$
+ " allowedThreadsInThreadGroup=%b]", //$NON-NLS-1$
whitelistedClassNames, executionPath, testClass, testMethod, whitelistedPaths, blacklistedPaths,
allowedLocalPorts, allowLocalPortsAbove, excludedLocalPorts, allowedThreadCount, blacklistedPackages,
whitelistedPackages, trustedPackages, threadTrustScope);
whitelistedPackages, trustedPackages, threadTrustScope, allowedThreadsInThreadGroup);
}

public String shortDesc() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ public final class AresSecurityConfigurationBuilder {
private OptionalInt allowedThreadCount;
private Set<PackageRule> trustedPackages;
private TrustScope threadTrustScope;
private Set<String> allowedThreadsInThreadGroup;

private AresSecurityConfigurationBuilder() {
testClass = Optional.empty();
Expand All @@ -74,6 +75,7 @@ private AresSecurityConfigurationBuilder() {
allowedThreadCount = OptionalInt.empty();
trustedPackages = Set.of();
threadTrustScope = TrustScope.MINIMAL;
allowedThreadsInThreadGroup = Set.of();
}

public AresSecurityConfigurationBuilder withPath(Path executionPath) {
Expand Down Expand Up @@ -142,12 +144,17 @@ public AresSecurityConfigurationBuilder withThreadTrustScope(TrustScope threadTr
return this;
}

public AresSecurityConfigurationBuilder withAllowedThreadsInThreadGroup(Set allowedThreadsInThreadGroup) {
this.allowedThreadsInThreadGroup = allowedThreadsInThreadGroup;
return this;
}

public AresSecurityConfiguration build() {
validate();
return new AresSecurityConfiguration(testClass, testMethod, executionPath, whitelistedClassNames,
Optional.ofNullable(whitelistedPaths), blacklistedPaths, allowedLocalPorts, allowLocalPortsAbove,
excludedLocalPorts, allowedThreadCount, blacklistedPackages, whitelistedPackages, trustedPackages,
threadTrustScope);
threadTrustScope, allowedThreadsInThreadGroup);
}

private void validate() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,8 @@ private Thread[] checkThreadGroup() {
for (Thread thread : threads) {
if (thread == null)
continue;
if (checkIfThreadNameStartsWithAny(thread, configuration.getAllowedThreadsInThreadGroup()))
continue;
try {
thread.interrupt();
thread.join(500 / originalCount + 1L);
Expand All @@ -612,6 +614,13 @@ private Thread[] checkThreadGroup() {
}
if (testThreadGroup.activeCount() == 0)
return new Thread[0];

if (Arrays.stream(threads).filter(Thread::isAlive)
.allMatch(t -> checkIfThreadNameStartsWithAny(t, configuration.getAllowedThreadsInThreadGroup()))) {
LOG.debug("All threads in the test thread group are allowed to run."); //$NON-NLS-1$
return new Thread[0];
}

// try forceful shutdown
var securityException = new SecurityException(
localized("security.error_threads_not_stoppable", Arrays.toString(threads))); //$NON-NLS-1$
Expand Down Expand Up @@ -657,6 +666,10 @@ private Thread[] checkThreadGroup() {
return threads;
}

private boolean checkIfThreadNameStartsWithAny(Thread thread, Set<String> allowedThreadStarts) {
return allowedThreadStarts.stream().anyMatch(thread.getName()::startsWith);
}

private void checkCommonThreadPool() {
var commonPool = ForkJoinPool.commonPool();
if (commonPool.isQuiescent())
Expand Down
Loading