Skip to content

Commit

Permalink
Add supernode for comparing a literal string (#196)
Browse files Browse the repository at this point in the history
  • Loading branch information
smarr authored Dec 3, 2023
2 parents a33ffee + 3d74296 commit 1defdd7
Show file tree
Hide file tree
Showing 10 changed files with 539 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
import trufflesom.interpreter.nodes.ReturnNonLocalNode;
import trufflesom.interpreter.nodes.ReturnNonLocalNode.CatchNonLocalReturnNode;
import trufflesom.interpreter.nodes.literals.BlockNode;
import trufflesom.interpreter.nodes.specialized.IntIncrementNode;
import trufflesom.interpreter.supernodes.IntIncrementNode;
import trufflesom.primitives.Primitives;
import trufflesom.vmobjects.SClass;
import trufflesom.vmobjects.SInvokable;
Expand Down
57 changes: 54 additions & 3 deletions src/trufflesom/src/trufflesom/compiler/ParserAst.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,11 @@
import trufflesom.bdt.basic.ProgramDefinitionError;
import trufflesom.bdt.inlining.InlinableNodes;
import trufflesom.bdt.tools.structure.StructuralProbe;
import trufflesom.interpreter.nodes.ArgumentReadNode.LocalArgumentReadNode;
import trufflesom.interpreter.nodes.ArgumentReadNode.NonLocalArgumentReadNode;
import trufflesom.interpreter.nodes.ExpressionNode;
import trufflesom.interpreter.nodes.FieldNode;
import trufflesom.interpreter.nodes.FieldNode.FieldReadNode;
import trufflesom.interpreter.nodes.GlobalNode;
import trufflesom.interpreter.nodes.MessageSendNode;
import trufflesom.interpreter.nodes.SequenceNode;
Expand All @@ -38,9 +41,14 @@
import trufflesom.interpreter.nodes.literals.GenericLiteralNode;
import trufflesom.interpreter.nodes.literals.IntegerLiteralNode;
import trufflesom.interpreter.nodes.literals.LiteralNode;
import trufflesom.interpreter.nodes.specialized.IntIncrementNodeGen;
import trufflesom.interpreter.supernodes.IntIncrementNodeGen;
import trufflesom.interpreter.supernodes.LocalFieldStringEqualsNode;
import trufflesom.interpreter.supernodes.NonLocalFieldStringEqualsNode;
import trufflesom.interpreter.supernodes.StringEqualsNodeGen;
import trufflesom.primitives.Primitives;
import trufflesom.vm.Globals;
import trufflesom.vm.NotYetImplementedException;
import trufflesom.vm.SymbolTable;
import trufflesom.vmobjects.SArray;
import trufflesom.vmobjects.SClass;
import trufflesom.vmobjects.SInvokable;
Expand Down Expand Up @@ -255,15 +263,58 @@ protected ExpressionNode binaryMessage(final MethodGenerationContext mgenc,
mgenc.getHolder().getSuperClass(), msg, args, coordWithL);
}

String binSelector = msg.getString();

if (binSelector.equals("=")) {
if (operand instanceof GenericLiteralNode) {
Object literal = operand.executeGeneric(null);
if (literal instanceof String s) {
if (receiver instanceof FieldReadNode fieldRead) {
ExpressionNode self = fieldRead.getSelf();
if (self instanceof LocalArgumentReadNode localSelf) {
return new LocalFieldStringEqualsNode(fieldRead.getFieldIndex(),
localSelf.getArg(), s).initialize(coordWithL);
} else if (self instanceof NonLocalArgumentReadNode arg) {
return new NonLocalFieldStringEqualsNode(fieldRead.getFieldIndex(), arg.getArg(),
arg.getContextLevel(), s).initialize(coordWithL);
} else {
throw new NotYetImplementedException();
}
}

return StringEqualsNodeGen.create(s, receiver).initialize(coordWithL);
}
}

if (receiver instanceof GenericLiteralNode) {
Object literal = receiver.executeGeneric(null);
if (literal instanceof String s) {
if (operand instanceof FieldReadNode fieldRead) {
ExpressionNode self = fieldRead.getSelf();
if (self instanceof LocalArgumentReadNode localSelf) {
return new LocalFieldStringEqualsNode(fieldRead.getFieldIndex(),
localSelf.getArg(), s).initialize(coordWithL);
} else if (self instanceof NonLocalArgumentReadNode arg) {
return new NonLocalFieldStringEqualsNode(fieldRead.getFieldIndex(), arg.getArg(),
arg.getContextLevel(), s).initialize(coordWithL);
} else {
throw new NotYetImplementedException();
}
}

return StringEqualsNodeGen.create(s, operand).initialize(coordWithL);
}
}
}

ExpressionNode inlined =
inlinableNodes.inline(msg, args, mgenc, coordWithL);
if (inlined != null) {
assert !isSuperSend;
return inlined;
}

if (msg.getString().equals("+") && operand instanceof IntegerLiteralNode) {
IntegerLiteralNode lit = (IntegerLiteralNode) operand;
if (msg == SymbolTable.symPlus && operand instanceof IntegerLiteralNode lit) {
if (lit.executeLong(null) == 1) {
return IntIncrementNodeGen.create(receiver);
}
Expand Down
7 changes: 0 additions & 7 deletions src/trufflesom/src/trufflesom/compiler/Variable.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package trufflesom.compiler;

import static com.oracle.truffle.api.CompilerDirectives.transferToInterpreterAndInvalidate;
import static trufflesom.compiler.bc.BytecodeGenerator.emitPOPARGUMENT;
import static trufflesom.compiler.bc.BytecodeGenerator.emitPOPLOCAL;
import static trufflesom.compiler.bc.BytecodeGenerator.emitPUSHARGUMENT;
Expand Down Expand Up @@ -116,8 +115,6 @@ public Local splitToMergeIntoOuterScope(final int newSlotIndex) {

@Override
public ExpressionNode getReadNode(final int contextLevel, final long coordinate) {
transferToInterpreterAndInvalidate();

if (contextLevel == 0) {
return new LocalArgumentReadNode(this).initialize(coordinate);
} else {
Expand All @@ -128,8 +125,6 @@ public ExpressionNode getReadNode(final int contextLevel, final long coordinate)
@Override
public ExpressionNode getWriteNode(final int contextLevel,
final ExpressionNode valueExpr, final long coordinate) {
transferToInterpreterAndInvalidate();

if (contextLevel == 0) {
return new LocalArgumentWriteNode(this, valueExpr).initialize(coordinate);
} else {
Expand Down Expand Up @@ -172,7 +167,6 @@ public void init(final FrameDescriptor desc) {

@Override
public ExpressionNode getReadNode(final int contextLevel, final long coordinate) {
transferToInterpreterAndInvalidate();
if (contextLevel > 0) {
return NonLocalVariableReadNodeGen.create(contextLevel, this).initialize(coordinate);
}
Expand All @@ -196,7 +190,6 @@ public Local splitToMergeIntoOuterScope(final int newSlotIndex) {
@Override
public ExpressionNode getWriteNode(final int contextLevel,
final ExpressionNode valueExpr, final long coordinate) {
transferToInterpreterAndInvalidate();
if (contextLevel > 0) {
return NonLocalVariableWriteNodeGen.create(contextLevel, this, valueExpr)
.initialize(coordinate);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ public void replaceAfterScopeChange(final ScopeAdaptationVisitor inliner) {
inliner.updateRead(arg, this, 0);
}

public Argument getArg() {
return arg;
}

@Override
public SSymbol getInvocationIdentifier() {
return arg.name;
Expand Down Expand Up @@ -118,6 +122,10 @@ public void replaceAfterScopeChange(final ScopeAdaptationVisitor inliner) {
inliner.updateRead(arg, this, contextLevel);
}

public Argument getArg() {
return arg;
}

@Override
public SSymbol getInvocationIdentifier() {
return arg.name;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package trufflesom.interpreter.nodes.specialized;
package trufflesom.interpreter.supernodes;

import com.oracle.truffle.api.dsl.NodeChild;
import com.oracle.truffle.api.dsl.Specialization;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
package trufflesom.interpreter.supernodes;

import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.nodes.Node;
import com.oracle.truffle.api.nodes.UnexpectedResultException;

import trufflesom.bdt.inlining.ScopeAdaptationVisitor;
import trufflesom.bdt.inlining.ScopeAdaptationVisitor.ScopeElement;
import trufflesom.compiler.Variable.Argument;
import trufflesom.interpreter.bc.RespecializeException;
import trufflesom.interpreter.nodes.ArgumentReadNode.LocalArgumentReadNode;
import trufflesom.interpreter.nodes.ExpressionNode;
import trufflesom.interpreter.nodes.FieldNode.FieldReadNode;
import trufflesom.interpreter.nodes.GenericMessageSendNode;
import trufflesom.interpreter.nodes.MessageSendNode;
import trufflesom.interpreter.nodes.bc.BytecodeLoopNode;
import trufflesom.interpreter.nodes.literals.GenericLiteralNode;
import trufflesom.interpreter.objectstorage.FieldAccessorNode;
import trufflesom.interpreter.objectstorage.FieldAccessorNode.AbstractReadFieldNode;
import trufflesom.interpreter.objectstorage.ObjectLayout;
import trufflesom.interpreter.objectstorage.StorageLocation;
import trufflesom.vm.SymbolTable;
import trufflesom.vm.VmSettings;
import trufflesom.vm.constants.Nil;
import trufflesom.vmobjects.SObject;


public final class LocalFieldStringEqualsNode extends ExpressionNode {

private final int fieldIdx;
private final String value;
protected final Argument arg;

@Child private AbstractReadFieldNode readFieldNode;

@CompilationFinal private int state;

public LocalFieldStringEqualsNode(final int fieldIdx, final Argument arg,
final String value) {
this.fieldIdx = fieldIdx;
this.arg = arg;
this.value = value;

this.state = 0;
}

@Override
public Object executeGeneric(final VirtualFrame frame) {
try {
SObject rcvr = (SObject) frame.getArguments()[0];
return executeEvaluated(frame, rcvr);
} catch (UnexpectedResultException e) {
return e.getResult();
}
}

@Override
public Object doPreEvaluated(final VirtualFrame frame, final Object[] args) {
try {
return executeEvaluated(frame, (SObject) args[0]);
} catch (UnexpectedResultException e) {
return e.getResult();
}
}

public boolean executeEvaluated(final VirtualFrame frame, final SObject rcvr)
throws UnexpectedResultException {
int currentState = state;

if (state == 0) {
// uninitialized
CompilerDirectives.transferToInterpreterAndInvalidate();
final ObjectLayout layout = rcvr.getObjectLayout();
StorageLocation location = layout.getStorageLocation(fieldIdx);

readFieldNode =
insert(location.getReadNode(fieldIdx, layout,
FieldAccessorNode.createRead(fieldIdx)));
}

Object result = readFieldNode.read(rcvr);

if ((state & 0b1) != 0) {
// we saw a string before
if (result instanceof String) {
return ((String) result).equals(value);
}
}

if ((state & 0b10) != 0) {
// we saw a nil before
if (result == Nil.nilObject) {
return false;
}
}

CompilerDirectives.transferToInterpreterAndInvalidate();
return specialize(frame, result, currentState);
}

@Override
public boolean executeBoolean(final VirtualFrame frame) throws UnexpectedResultException {
SObject rcvr = (SObject) frame.getArguments()[0];

return executeEvaluated(frame, rcvr);
}

private boolean specialize(final VirtualFrame frame, final Object result,
final int currentState) throws UnexpectedResultException {
if (result instanceof String) {
state = currentState | 0b1;
return value.equals(result);
}

if (result == Nil.nilObject) {
state = currentState | 0b10;
return false;
}

Object sendResult =
makeGenericSend(result).doPreEvaluated(frame, new Object[] {result, value});
if (sendResult instanceof Boolean) {
return (Boolean) sendResult;
}
throw new UnexpectedResultException(sendResult);
}

public GenericMessageSendNode makeGenericSend(
@SuppressWarnings("unused") final Object receiver) {
GenericMessageSendNode send =
MessageSendNode.createGeneric(SymbolTable.symbolFor("="),
new ExpressionNode[] {new FieldReadNode(new LocalArgumentReadNode(arg), fieldIdx),
new GenericLiteralNode(value)},
sourceCoord);

if (VmSettings.UseAstInterp) {
replace(send);
send.notifyDispatchInserted();
return send;
}

assert getParent() instanceof BytecodeLoopNode : "This node was expected to be a direct child of a `BytecodeLoopNode`.";
throw new RespecializeException(send);
}

@Override
public void replaceAfterScopeChange(final ScopeAdaptationVisitor inliner) {
ScopeElement<? extends Node> se = inliner.getAdaptedVar(arg);
if (se.var != arg || se.contextLevel < 0) {
Node newNode;
if (se.contextLevel == 0) {
newNode =
new LocalFieldStringEqualsNode(fieldIdx, (Argument) se.var, value).initialize(
fieldIdx);
} else {
newNode = new NonLocalFieldStringEqualsNode(fieldIdx, (Argument) se.var,
se.contextLevel, value).initialize(fieldIdx);
}

replace(newNode);
} else {
assert 0 == se.contextLevel;
}
}
}
Loading

0 comments on commit 1defdd7

Please sign in to comment.