diff --git a/src/main/java/org/openrewrite/staticanalysis/RemoveRedundantTypeCast.java b/src/main/java/org/openrewrite/staticanalysis/RemoveRedundantTypeCast.java index 510e74b39..f39ba0a8f 100644 --- a/src/main/java/org/openrewrite/staticanalysis/RemoveRedundantTypeCast.java +++ b/src/main/java/org/openrewrite/staticanalysis/RemoveRedundantTypeCast.java @@ -101,14 +101,20 @@ public J visitTypeCast(J.TypeCast typeCast, ExecutionContext ctx) { JavaType expressionType = visitedTypeCast.getExpression().getType(); JavaType castType = visitedTypeCast.getType(); - if (targetType == null || - (targetType instanceof JavaType.Primitive || castType instanceof JavaType.Primitive) && castType != expressionType || - (typeCast.getExpression() instanceof J.Lambda || typeCast.getExpression() instanceof J.MemberReference) && castType instanceof JavaType.Parameterized) { + if (targetType == null) { + return visitedTypeCast; + } + if ((targetType instanceof JavaType.Primitive || castType instanceof JavaType.Primitive) && castType != expressionType) { + return visitedTypeCast; + } + if (typeCast.getExpression() instanceof J.Lambda || typeCast.getExpression() instanceof J.MemberReference) { // Not currently supported, this will be more accurate with dataflow analysis. return visitedTypeCast; - } else if (!(targetType instanceof JavaType.Array) && TypeUtils.isOfClassType(targetType, "java.lang.Object") || - TypeUtils.isOfType(targetType, expressionType) || - TypeUtils.isAssignableTo(targetType, expressionType)) { + } + + if (!(targetType instanceof JavaType.Array) && TypeUtils.isOfClassType(targetType, "java.lang.Object") || + TypeUtils.isOfType(targetType, expressionType) || + TypeUtils.isAssignableTo(targetType, expressionType)) { JavaType.FullyQualified fullyQualified = TypeUtils.asFullyQualified(castType); if (fullyQualified != null) { maybeRemoveImport(fullyQualified.getFullyQualifiedName()); diff --git a/src/test/java/org/openrewrite/staticanalysis/RemoveRedundantTypeCastTest.java b/src/test/java/org/openrewrite/staticanalysis/RemoveRedundantTypeCastTest.java index 652079347..3a518a6fc 100644 --- a/src/test/java/org/openrewrite/staticanalysis/RemoveRedundantTypeCastTest.java +++ b/src/test/java/org/openrewrite/staticanalysis/RemoveRedundantTypeCastTest.java @@ -18,6 +18,7 @@ import org.junit.jupiter.api.Test; import org.openrewrite.DocumentExample; import org.openrewrite.Issue; +import org.openrewrite.java.JavaParser; import org.openrewrite.test.RecipeSpec; import org.openrewrite.test.RewriteTest; @@ -394,6 +395,7 @@ class ExtendTest extends Test { @Test void lambdaWithComplexTypeInference() { rewriteRun( + //language=java java( """ import java.util.LinkedHashMap; @@ -425,6 +427,7 @@ public MapDropdownChoice(Supplier> choiceMap) { @Test void returnPrimitiveIntToWrapperLong() { rewriteRun( + //language=java java( """ class Test { @@ -440,6 +443,7 @@ Long method() { @Test void castWildcard() { rewriteRun( + //language=java java( """ import java.util.ArrayList; @@ -459,6 +463,7 @@ void method() { @Test void removeImport() { rewriteRun( + //language=java java( """ import java.util.ArrayList; @@ -482,4 +487,45 @@ List method(List list) { ) ); } + + @Test + void retainCastInMarshaller() { + rewriteRun( + spec -> spec.parser(JavaParser.fromJavaVersion() + //language=java + .dependsOn( + """ + package org.glassfish.jaxb.core.marshaller; + import java.io.IOException; + import java.io.Writer; + + public interface CharacterEscapeHandler { + void escape( char[] ch, int start, int length, boolean isAttVal, Writer out ) throws IOException;\s + } + """, + """ + package javax.xml.bind; + + public interface Marshaller { + void setProperty(String var1, Object var2); + } + """ + ) + ), + //language=java + java( + """ + import javax.xml.bind.Marshaller; + import org.glassfish.jaxb.core.marshaller.CharacterEscapeHandler; + + class Foo { + void bar(Marshaller marshaller) { + marshaller.setProperty("org.glassfish.jaxb.characterEscapeHandler", (CharacterEscapeHandler) (ch, start, length, isAttVal, out) -> { + }); + } + } + """ + ) + ); + } }