diff --git a/com.ibm.wala.cast.python.jython3/source/com/ibm/wala/cast/python/client/PytestAnalysisEngine.java b/com.ibm.wala.cast.python.jython3/source/com/ibm/wala/cast/python/client/PytestAnalysisEngine.java index 765e0b4d0..49bdea156 100644 --- a/com.ibm.wala.cast.python.jython3/source/com/ibm/wala/cast/python/client/PytestAnalysisEngine.java +++ b/com.ibm.wala.cast.python.jython3/source/com/ibm/wala/cast/python/client/PytestAnalysisEngine.java @@ -1,6 +1,5 @@ package com.ibm.wala.cast.python.client; -import com.ibm.wala.cast.python.ipa.callgraph.PythonSSAPropagationCallGraphBuilder; import com.ibm.wala.cast.python.loader.PytestLoader; import com.ibm.wala.cast.python.loader.PytestLoaderFactory; import com.ibm.wala.classLoader.CallSiteReference; @@ -10,7 +9,6 @@ import com.ibm.wala.core.util.strings.Atom; import com.ibm.wala.ipa.callgraph.AnalysisOptions; import com.ibm.wala.ipa.callgraph.CGNode; -import com.ibm.wala.ipa.callgraph.IAnalysisCacheView; import com.ibm.wala.ipa.callgraph.MethodTargetSelector; import com.ibm.wala.ipa.callgraph.propagation.InstanceKey; import com.ibm.wala.ipa.callgraph.propagation.PointerKey; @@ -23,8 +21,6 @@ public class PytestAnalysisEngine extends PythonAnalysisEngine { - private PythonSSAPropagationCallGraphBuilder builder; - private class PytestTargetSelector implements MethodTargetSelector { private final MethodTargetSelector base; @@ -88,12 +84,6 @@ protected void addBypassLogic(IClassHierarchy cha, AnalysisOptions options) { addSummaryBypassLogic(options, "pytest.xml"); } - @Override - protected PythonSSAPropagationCallGraphBuilder getCallGraphBuilder( - IClassHierarchy cha, AnalysisOptions options, IAnalysisCacheView cache) { - return builder = super.getCallGraphBuilder(cha, options, cache); - } - @Override public T performAnalysis(PropagationCallGraphBuilder arg0) throws CancelException { return null; diff --git a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflowModel.java b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflowModel.java index bef0ac9bb..0efaa55ff 100644 --- a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflowModel.java +++ b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflowModel.java @@ -207,20 +207,13 @@ public void testTf2() testTf2("tf2_test_tensor_list.py", "add", 2, 3, 2, 3); testTf2("tf2_test_tensor_list2.py", "add", 0, 2); testTf2("tf2_test_tensor_list3.py", "add", 0, 2); - testTf2( - "tf2_test_model_call.py", - "SequentialModel.__call__", - 0, - 2); // NOTE: Change to testTf2("tf2_test_model_call.py", "SequentialModel.__call__", 1, 4, - // 2) once - // https://github.com/wala/ML/issues/24 is fixed. + testTf2("tf2_test_model_call.py", "SequentialModel.__call__", 1, 4, 2); testTf2( "tf2_test_model_call2.py", "SequentialModel.call", 0, 2); // NOTE: Change to testTf2("tf2_test_model_call2.py", "SequentialModel.call", 1, 4, 2) - // once - // https://github.com/wala/ML/issues/106 is fixed. + // once https://github.com/wala/ML/issues/106 is fixed. testTf2("tf2_test_model_call3.py", "SequentialModel.call", 1, 4, 2); testTf2("tf2_test_model_call4.py", "SequentialModel.__call__", 1, 4, 2); } diff --git a/com.ibm.wala.cast.python.test/data/callables.py b/com.ibm.wala.cast.python.test/data/callables.py new file mode 100644 index 000000000..de9c08e1b --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/callables.py @@ -0,0 +1,9 @@ +class C: + + def __call__(self, x): + return x * x + + +c = C() +a = c.__call__(5) +assert a == 25 diff --git a/com.ibm.wala.cast.python.test/data/callables2.py b/com.ibm.wala.cast.python.test/data/callables2.py new file mode 100644 index 000000000..56b017789 --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/callables2.py @@ -0,0 +1,9 @@ +class C: + + def __call__(self, x): + return x * x + + +c = C() +a = c(5) +assert a == 25 diff --git a/com.ibm.wala.cast.python.test/data/callables3.py b/com.ibm.wala.cast.python.test/data/callables3.py new file mode 100644 index 000000000..c64a8d661 --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/callables3.py @@ -0,0 +1,9 @@ +class C(object): + + def __call__(self, x): + return x * x + + +c = C() +a = c.__call__(5) +assert a == 25 diff --git a/com.ibm.wala.cast.python.test/data/callables4.py b/com.ibm.wala.cast.python.test/data/callables4.py new file mode 100644 index 000000000..a37c8548b --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/callables4.py @@ -0,0 +1,9 @@ +class C(object): + + def __call__(self, x): + return x * x + + +c = C() +a = c(5) +assert a == 25 diff --git a/com.ibm.wala.cast.python.test/data/callables5.py b/com.ibm.wala.cast.python.test/data/callables5.py new file mode 100644 index 000000000..e29d03592 --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/callables5.py @@ -0,0 +1,13 @@ +class D: + pass + + +class C(D): + + def __call__(self, x): + return x * x + + +c = C() +a = c.__call__(5) +assert a == 25 diff --git a/com.ibm.wala.cast.python.test/data/callables6.py b/com.ibm.wala.cast.python.test/data/callables6.py new file mode 100644 index 000000000..0040d95ec --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/callables6.py @@ -0,0 +1,13 @@ +class D: + pass + + +class C(D): + + def __call__(self, x): + return x * x + + +c = C() +a = c(5) +assert a == 25 diff --git a/com.ibm.wala.cast.python.test/source/com/ibm/wala/cast/python/test/TestCallables.java b/com.ibm.wala.cast.python.test/source/com/ibm/wala/cast/python/test/TestCallables.java new file mode 100644 index 000000000..17ac6c2c1 --- /dev/null +++ b/com.ibm.wala.cast.python.test/source/com/ibm/wala/cast/python/test/TestCallables.java @@ -0,0 +1,59 @@ +package com.ibm.wala.cast.python.test; + +import static org.junit.Assert.assertTrue; + +import com.ibm.wala.ipa.callgraph.CGNode; +import com.ibm.wala.ipa.callgraph.CallGraph; +import com.ibm.wala.ipa.cha.ClassHierarchyException; +import com.ibm.wala.util.CancelException; +import java.io.IOException; +import java.util.Iterator; +import java.util.logging.Logger; +import org.junit.Test; + +public class TestCallables extends TestPythonCallGraphShape { + + private static Logger logger = Logger.getLogger(TestCallables.class.getName()); + + @Test + public void testCallables() + throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { + final String[] testFileNames = { + "callables.py", + "callables2.py", + "callables3.py", + "callables4.py", + "callables5.py", + "callables6.py" + }; + + for (String fileName : testFileNames) { + CallGraph CG = process(fileName); + boolean found = false; + + for (CGNode node : CG) { + if (node.getMethod() + .getDeclaringClass() + .getName() + .toString() + .equals("Lscript " + fileName)) { + + for (Iterator it = CG.getSuccNodes(node); it.hasNext(); ) { + CGNode callee = it.next(); + + logger.info("Found callee: " + callee.getMethod().getSignature()); + + if (callee + .getMethod() + .getDeclaringClass() + .getName() + .toString() + .equals("L$script " + fileName + "/C/__call__")) found = true; + } + } + } + + assertTrue("Expecting to find __call__ method trampoline in: " + fileName + ".", found); + } + } +} diff --git a/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/client/PythonAnalysisEngine.java b/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/client/PythonAnalysisEngine.java index 574951c41..5da40bfb5 100644 --- a/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/client/PythonAnalysisEngine.java +++ b/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/client/PythonAnalysisEngine.java @@ -73,6 +73,8 @@ public abstract class PythonAnalysisEngine private static final Logger logger = Logger.getLogger(PythonAnalysisEngine.class.getName()); + protected PythonSSAPropagationCallGraphBuilder builder; + static { try { Class j3 = Class.forName("com.ibm.wala.cast.python.loader.Python3LoaderFactory"); @@ -286,9 +288,10 @@ public boolean isReferenceType() { protected void addBypassLogic(IClassHierarchy cha, AnalysisOptions options) { options.setSelector( - new PythonTrampolineTargetSelector( + new PythonTrampolineTargetSelector( new PythonConstructorTargetSelector( - new PythonComprehensionTrampolines(options.getMethodTargetSelector())))); + new PythonComprehensionTrampolines(options.getMethodTargetSelector())), + this)); BuiltinFunctions builtins = new BuiltinFunctions(cha); options.setSelector(builtins.builtinClassTargetSelector(options.getClassTargetSelector())); @@ -349,7 +352,11 @@ public int getDefaultValue(SymbolTable symtab, int valueNumber) { new PythonSuper(cha).handleSuperCalls(builder, options); - return builder; + return this.builder = builder; + } + + public PythonSSAPropagationCallGraphBuilder getCachedCallGraphBuilder() { + return this.builder; } protected PythonSSAPropagationCallGraphBuilder makeBuilder( diff --git a/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/ipa/callgraph/PythonSSAPropagationCallGraphBuilder.java b/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/ipa/callgraph/PythonSSAPropagationCallGraphBuilder.java index 7629a9879..5e80fcd22 100644 --- a/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/ipa/callgraph/PythonSSAPropagationCallGraphBuilder.java +++ b/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/ipa/callgraph/PythonSSAPropagationCallGraphBuilder.java @@ -17,7 +17,9 @@ import com.ibm.wala.cast.python.ssa.PythonInstructionVisitor; import com.ibm.wala.cast.python.ssa.PythonInvokeInstruction; import com.ibm.wala.cast.python.types.PythonTypes; +import com.ibm.wala.classLoader.IClass; import com.ibm.wala.classLoader.IField; +import com.ibm.wala.classLoader.IMethod; import com.ibm.wala.core.util.strings.Atom; import com.ibm.wala.fixpoint.AbstractOperator; import com.ibm.wala.ipa.callgraph.AnalysisOptions; @@ -25,6 +27,7 @@ import com.ibm.wala.ipa.callgraph.IAnalysisCacheView; import com.ibm.wala.ipa.callgraph.propagation.AbstractFieldPointerKey; import com.ibm.wala.ipa.callgraph.propagation.InstanceKey; +import com.ibm.wala.ipa.callgraph.propagation.PointerAnalysis; import com.ibm.wala.ipa.callgraph.propagation.PointerKey; import com.ibm.wala.ipa.callgraph.propagation.PointerKeyFactory; import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable; @@ -37,11 +40,15 @@ import com.ibm.wala.ssa.SSAGetInstruction; import com.ibm.wala.ssa.SymbolTable; import com.ibm.wala.types.FieldReference; +import com.ibm.wala.types.TypeName; import com.ibm.wala.types.TypeReference; import com.ibm.wala.util.collections.HashMapFactory; import com.ibm.wala.util.collections.Pair; +import com.ibm.wala.util.intset.IntIterator; +import com.ibm.wala.util.intset.IntSet; import com.ibm.wala.util.intset.IntSetUtil; import com.ibm.wala.util.intset.MutableIntSet; +import com.ibm.wala.util.intset.OrdinalSet; import java.util.Arrays; import java.util.Collection; import java.util.Map; @@ -211,7 +218,19 @@ protected void processCallingConstraints( } } else { PointerKey rval = getPointerKeyForLocal(caller, call.getUse(i)); - getSystem().newConstraint(lval, assignOperator, rval); + + // If we are looking at the implicit parameter of a callable. + if (call.getCallSite().isDispatch() && i == 0 && isCallable(rval)) { + // Ensure that lval's variable refers to the callable method instead of callable object. + IClass callable = getCallable(target); + IntSet instanceKeysForCallable = this.getSystem().getInstanceKeysForClass(callable); + + for (IntIterator it = instanceKeysForCallable.intIterator(); it.hasNext(); ) { + int instanceKeyIndex = it.next(); + InstanceKey instanceKey = this.getSystem().getInstanceKey(instanceKeyIndex); + this.getSystem().newConstraint(lval, instanceKey); + } + } else getSystem().newConstraint(lval, assignOperator, rval); } } @@ -271,6 +290,33 @@ protected void processCallingConstraints( } } + private IClass getCallable(CGNode target) { + IMethod method = target.getMethod(); + IClass declaringClass = method.getDeclaringClass(); + TypeName declaringClassName = declaringClass.getName(); + + TypeReference typeReference = + TypeReference.findOrCreate( + declaringClass.getClassLoader().getReference(), declaringClassName); + + return this.getClassHierarchy().lookupClass(typeReference); + } + + protected boolean isCallable(PointerKey rval) { + PointerAnalysis pointerAnalysis = this.getPointerAnalysis(); + OrdinalSet pointsToSet = pointerAnalysis.getPointsToSet(rval); + + for (InstanceKey instanceKey : pointsToSet) { + IClass concreteType = instanceKey.getConcreteType(); + TypeReference reference = concreteType.getReference(); + + // If it's an "object" method. + if (reference.equals(PythonTypes.object)) return true; + } + + return false; + } + @Override public PythonConstraintVisitor makeVisitor(CGNode node) { return new PythonConstraintVisitor(this, node); diff --git a/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/ipa/callgraph/PythonTrampolineTargetSelector.java b/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/ipa/callgraph/PythonTrampolineTargetSelector.java index 2cec10ac2..20108293d 100644 --- a/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/ipa/callgraph/PythonTrampolineTargetSelector.java +++ b/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/ipa/callgraph/PythonTrampolineTargetSelector.java @@ -11,6 +11,7 @@ package com.ibm.wala.cast.python.ipa.callgraph; import com.ibm.wala.cast.loader.DynamicCallSiteReference; +import com.ibm.wala.cast.python.client.PythonAnalysisEngine; import com.ibm.wala.cast.python.ipa.summaries.PythonInstanceMethodTrampoline; import com.ibm.wala.cast.python.ipa.summaries.PythonSummarizedFunction; import com.ibm.wala.cast.python.ipa.summaries.PythonSummary; @@ -24,19 +25,32 @@ import com.ibm.wala.core.util.strings.Atom; import com.ibm.wala.ipa.callgraph.CGNode; import com.ibm.wala.ipa.callgraph.MethodTargetSelector; +import com.ibm.wala.ipa.callgraph.propagation.InstanceKey; +import com.ibm.wala.ipa.callgraph.propagation.NormalAllocationInNode; +import com.ibm.wala.ipa.callgraph.propagation.PointerKey; +import com.ibm.wala.ipa.callgraph.propagation.PointerKeyFactory; import com.ibm.wala.ipa.cha.IClassHierarchy; import com.ibm.wala.ssa.SSAReturnInstruction; import com.ibm.wala.types.FieldReference; import com.ibm.wala.types.MethodReference; +import com.ibm.wala.types.TypeName; +import com.ibm.wala.types.TypeReference; import com.ibm.wala.util.collections.HashMapFactory; import com.ibm.wala.util.collections.Pair; +import com.ibm.wala.util.intset.OrdinalSet; import java.util.Map; -public class PythonTrampolineTargetSelector implements MethodTargetSelector { +public class PythonTrampolineTargetSelector implements MethodTargetSelector { + private static final String CALL = "__call__"; + private final MethodTargetSelector base; - public PythonTrampolineTargetSelector(MethodTargetSelector base) { + private PythonAnalysisEngine engine; + + public PythonTrampolineTargetSelector( + MethodTargetSelector base, PythonAnalysisEngine pythonAnalysisEngine) { this.base = base; + this.engine = pythonAnalysisEngine; } private final Map, IMethod> codeBodies = HashMapFactory.make(); @@ -46,8 +60,18 @@ public PythonTrampolineTargetSelector(MethodTargetSelector base) { public IMethod getCalleeTarget(CGNode caller, CallSiteReference site, IClass receiver) { if (receiver != null) { IClassHierarchy cha = receiver.getClassHierarchy(); - if (cha.isSubclassOf(receiver, cha.lookupClass(PythonTypes.trampoline))) { + final boolean callable = receiver.getReference().equals(PythonTypes.object); + + if (cha.isSubclassOf(receiver, cha.lookupClass(PythonTypes.trampoline)) || callable) { PythonInvokeInstruction call = (PythonInvokeInstruction) caller.getIR().getCalls(site)[0]; + + if (callable) { + // It's a callable. Change the receiver. + receiver = getCallable(caller, cha, call); + + if (receiver == null) return null; // not found. + } + Pair key = Pair.make(receiver, call.getNumberOfTotalParameters()); if (!codeBodies.containsKey(key)) { Map names = HashMapFactory.make(); @@ -122,4 +146,34 @@ public IMethod getCalleeTarget(CGNode caller, CallSiteReference site, IClass rec return base.getCalleeTarget(caller, site, receiver); } + + private IClass getCallable(CGNode caller, IClassHierarchy cha, PythonInvokeInstruction call) { + PythonSSAPropagationCallGraphBuilder builder = this.getEngine().getCachedCallGraphBuilder(); + + // Lookup the __call__ method. + PointerKeyFactory pkf = builder.getPointerKeyFactory(); + PointerKey receiver = pkf.getPointerKeyForLocal(caller, call.getUse(0)); + OrdinalSet objs = builder.getPointerAnalysis().getPointsToSet(receiver); + + for (InstanceKey o : objs) { + NormalAllocationInNode instanceKey = (NormalAllocationInNode) o; + CGNode node = instanceKey.getNode(); + IMethod method = node.getMethod(); + IClass declaringClass = method.getDeclaringClass(); + TypeName declaringClassName = declaringClass.getName(); + final String packageName = "$" + declaringClassName.toString().substring(1); + TypeReference typeReference = + TypeReference.findOrCreateClass( + declaringClass.getClassLoader().getReference(), packageName, CALL); + IClass lookupClass = cha.lookupClass(typeReference); + + if (lookupClass != null) return lookupClass; + } + + return null; + } + + public PythonAnalysisEngine getEngine() { + return engine; + } }