Skip to content

Commit

Permalink
Merge pull request #58 from ponder-lab/24-__call__-not-supported
Browse files Browse the repository at this point in the history
24   call   not supported
  • Loading branch information
khatchad authored Nov 30, 2023
2 parents 791d054 + 8976e5f commit 0aae467
Show file tree
Hide file tree
Showing 12 changed files with 237 additions and 26 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package com.ibm.wala.cast.python.client;

import com.ibm.wala.cast.python.ipa.callgraph.PythonSSAPropagationCallGraphBuilder;
import com.ibm.wala.cast.python.loader.PytestLoader;
import com.ibm.wala.cast.python.loader.PytestLoaderFactory;
import com.ibm.wala.classLoader.CallSiteReference;
Expand All @@ -10,7 +9,6 @@
import com.ibm.wala.core.util.strings.Atom;
import com.ibm.wala.ipa.callgraph.AnalysisOptions;
import com.ibm.wala.ipa.callgraph.CGNode;
import com.ibm.wala.ipa.callgraph.IAnalysisCacheView;
import com.ibm.wala.ipa.callgraph.MethodTargetSelector;
import com.ibm.wala.ipa.callgraph.propagation.InstanceKey;
import com.ibm.wala.ipa.callgraph.propagation.PointerKey;
Expand All @@ -23,8 +21,6 @@

public class PytestAnalysisEngine<T> extends PythonAnalysisEngine<T> {

private PythonSSAPropagationCallGraphBuilder builder;

private class PytestTargetSelector implements MethodTargetSelector {
private final MethodTargetSelector base;

Expand Down Expand Up @@ -88,12 +84,6 @@ protected void addBypassLogic(IClassHierarchy cha, AnalysisOptions options) {
addSummaryBypassLogic(options, "pytest.xml");
}

@Override
protected PythonSSAPropagationCallGraphBuilder getCallGraphBuilder(
IClassHierarchy cha, AnalysisOptions options, IAnalysisCacheView cache) {
return builder = super.getCallGraphBuilder(cha, options, cache);
}

@Override
public T performAnalysis(PropagationCallGraphBuilder arg0) throws CancelException {
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,20 +207,13 @@ public void testTf2()
testTf2("tf2_test_tensor_list.py", "add", 2, 3, 2, 3);
testTf2("tf2_test_tensor_list2.py", "add", 0, 2);
testTf2("tf2_test_tensor_list3.py", "add", 0, 2);
testTf2(
"tf2_test_model_call.py",
"SequentialModel.__call__",
0,
2); // NOTE: Change to testTf2("tf2_test_model_call.py", "SequentialModel.__call__", 1, 4,
// 2) once
// https://github.com/wala/ML/issues/24 is fixed.
testTf2("tf2_test_model_call.py", "SequentialModel.__call__", 1, 4, 2);
testTf2(
"tf2_test_model_call2.py",
"SequentialModel.call",
0,
2); // NOTE: Change to testTf2("tf2_test_model_call2.py", "SequentialModel.call", 1, 4, 2)
// once
// https://github.com/wala/ML/issues/106 is fixed.
// once https://github.com/wala/ML/issues/106 is fixed.
testTf2("tf2_test_model_call3.py", "SequentialModel.call", 1, 4, 2);
testTf2("tf2_test_model_call4.py", "SequentialModel.__call__", 1, 4, 2);
}
Expand Down
9 changes: 9 additions & 0 deletions com.ibm.wala.cast.python.test/data/callables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
class C:

def __call__(self, x):
return x * x


c = C()
a = c.__call__(5)
assert a == 25
9 changes: 9 additions & 0 deletions com.ibm.wala.cast.python.test/data/callables2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
class C:

def __call__(self, x):
return x * x


c = C()
a = c(5)
assert a == 25
9 changes: 9 additions & 0 deletions com.ibm.wala.cast.python.test/data/callables3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
class C(object):

def __call__(self, x):
return x * x


c = C()
a = c.__call__(5)
assert a == 25
9 changes: 9 additions & 0 deletions com.ibm.wala.cast.python.test/data/callables4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
class C(object):

def __call__(self, x):
return x * x


c = C()
a = c(5)
assert a == 25
13 changes: 13 additions & 0 deletions com.ibm.wala.cast.python.test/data/callables5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
class D:
pass


class C(D):

def __call__(self, x):
return x * x


c = C()
a = c.__call__(5)
assert a == 25
13 changes: 13 additions & 0 deletions com.ibm.wala.cast.python.test/data/callables6.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
class D:
pass


class C(D):

def __call__(self, x):
return x * x


c = C()
a = c(5)
assert a == 25
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package com.ibm.wala.cast.python.test;

import static org.junit.Assert.assertTrue;

import com.ibm.wala.ipa.callgraph.CGNode;
import com.ibm.wala.ipa.callgraph.CallGraph;
import com.ibm.wala.ipa.cha.ClassHierarchyException;
import com.ibm.wala.util.CancelException;
import java.io.IOException;
import java.util.Iterator;
import java.util.logging.Logger;
import org.junit.Test;

public class TestCallables extends TestPythonCallGraphShape {

private static Logger logger = Logger.getLogger(TestCallables.class.getName());

@Test
public void testCallables()
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
final String[] testFileNames = {
"callables.py",
"callables2.py",
"callables3.py",
"callables4.py",
"callables5.py",
"callables6.py"
};

for (String fileName : testFileNames) {
CallGraph CG = process(fileName);
boolean found = false;

for (CGNode node : CG) {
if (node.getMethod()
.getDeclaringClass()
.getName()
.toString()
.equals("Lscript " + fileName)) {

for (Iterator<CGNode> it = CG.getSuccNodes(node); it.hasNext(); ) {
CGNode callee = it.next();

logger.info("Found callee: " + callee.getMethod().getSignature());

if (callee
.getMethod()
.getDeclaringClass()
.getName()
.toString()
.equals("L$script " + fileName + "/C/__call__")) found = true;
}
}
}

assertTrue("Expecting to find __call__ method trampoline in: " + fileName + ".", found);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ public abstract class PythonAnalysisEngine<T>

private static final Logger logger = Logger.getLogger(PythonAnalysisEngine.class.getName());

protected PythonSSAPropagationCallGraphBuilder builder;

static {
try {
Class<?> j3 = Class.forName("com.ibm.wala.cast.python.loader.Python3LoaderFactory");
Expand Down Expand Up @@ -286,9 +288,10 @@ public boolean isReferenceType() {

protected void addBypassLogic(IClassHierarchy cha, AnalysisOptions options) {
options.setSelector(
new PythonTrampolineTargetSelector(
new PythonTrampolineTargetSelector<T>(
new PythonConstructorTargetSelector(
new PythonComprehensionTrampolines(options.getMethodTargetSelector()))));
new PythonComprehensionTrampolines(options.getMethodTargetSelector())),
this));

BuiltinFunctions builtins = new BuiltinFunctions(cha);
options.setSelector(builtins.builtinClassTargetSelector(options.getClassTargetSelector()));
Expand Down Expand Up @@ -349,7 +352,11 @@ public int getDefaultValue(SymbolTable symtab, int valueNumber) {

new PythonSuper(cha).handleSuperCalls(builder, options);

return builder;
return this.builder = builder;
}

public PythonSSAPropagationCallGraphBuilder getCachedCallGraphBuilder() {
return this.builder;
}

protected PythonSSAPropagationCallGraphBuilder makeBuilder(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,17 @@
import com.ibm.wala.cast.python.ssa.PythonInstructionVisitor;
import com.ibm.wala.cast.python.ssa.PythonInvokeInstruction;
import com.ibm.wala.cast.python.types.PythonTypes;
import com.ibm.wala.classLoader.IClass;
import com.ibm.wala.classLoader.IField;
import com.ibm.wala.classLoader.IMethod;
import com.ibm.wala.core.util.strings.Atom;
import com.ibm.wala.fixpoint.AbstractOperator;
import com.ibm.wala.ipa.callgraph.AnalysisOptions;
import com.ibm.wala.ipa.callgraph.CGNode;
import com.ibm.wala.ipa.callgraph.IAnalysisCacheView;
import com.ibm.wala.ipa.callgraph.propagation.AbstractFieldPointerKey;
import com.ibm.wala.ipa.callgraph.propagation.InstanceKey;
import com.ibm.wala.ipa.callgraph.propagation.PointerAnalysis;
import com.ibm.wala.ipa.callgraph.propagation.PointerKey;
import com.ibm.wala.ipa.callgraph.propagation.PointerKeyFactory;
import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable;
Expand All @@ -37,11 +40,15 @@
import com.ibm.wala.ssa.SSAGetInstruction;
import com.ibm.wala.ssa.SymbolTable;
import com.ibm.wala.types.FieldReference;
import com.ibm.wala.types.TypeName;
import com.ibm.wala.types.TypeReference;
import com.ibm.wala.util.collections.HashMapFactory;
import com.ibm.wala.util.collections.Pair;
import com.ibm.wala.util.intset.IntIterator;
import com.ibm.wala.util.intset.IntSet;
import com.ibm.wala.util.intset.IntSetUtil;
import com.ibm.wala.util.intset.MutableIntSet;
import com.ibm.wala.util.intset.OrdinalSet;
import java.util.Arrays;
import java.util.Collection;
import java.util.Map;
Expand Down Expand Up @@ -211,7 +218,19 @@ protected void processCallingConstraints(
}
} else {
PointerKey rval = getPointerKeyForLocal(caller, call.getUse(i));
getSystem().newConstraint(lval, assignOperator, rval);

// If we are looking at the implicit parameter of a callable.
if (call.getCallSite().isDispatch() && i == 0 && isCallable(rval)) {
// Ensure that lval's variable refers to the callable method instead of callable object.
IClass callable = getCallable(target);
IntSet instanceKeysForCallable = this.getSystem().getInstanceKeysForClass(callable);

for (IntIterator it = instanceKeysForCallable.intIterator(); it.hasNext(); ) {
int instanceKeyIndex = it.next();
InstanceKey instanceKey = this.getSystem().getInstanceKey(instanceKeyIndex);
this.getSystem().newConstraint(lval, instanceKey);
}
} else getSystem().newConstraint(lval, assignOperator, rval);
}
}

Expand Down Expand Up @@ -271,6 +290,33 @@ protected void processCallingConstraints(
}
}

private IClass getCallable(CGNode target) {
IMethod method = target.getMethod();
IClass declaringClass = method.getDeclaringClass();
TypeName declaringClassName = declaringClass.getName();

TypeReference typeReference =
TypeReference.findOrCreate(
declaringClass.getClassLoader().getReference(), declaringClassName);

return this.getClassHierarchy().lookupClass(typeReference);
}

protected boolean isCallable(PointerKey rval) {
PointerAnalysis<InstanceKey> pointerAnalysis = this.getPointerAnalysis();
OrdinalSet<InstanceKey> pointsToSet = pointerAnalysis.getPointsToSet(rval);

for (InstanceKey instanceKey : pointsToSet) {
IClass concreteType = instanceKey.getConcreteType();
TypeReference reference = concreteType.getReference();

// If it's an "object" method.
if (reference.equals(PythonTypes.object)) return true;
}

return false;
}

@Override
public PythonConstraintVisitor makeVisitor(CGNode node) {
return new PythonConstraintVisitor(this, node);
Expand Down
Loading

0 comments on commit 0aae467

Please sign in to comment.