Move Streams Work In Progress #1349

Draft · wants to merge 3 commits into base: main
New file: ai.timefold.solver.core.impl.bavet.AbstractSession
@@ -0,0 +1,56 @@
package ai.timefold.solver.core.impl.bavet;

import java.util.IdentityHashMap;
import java.util.Map;

import ai.timefold.solver.core.impl.bavet.uni.AbstractForEachUniNode;

public abstract class AbstractSession {

    private final NodeNetwork nodeNetwork;
    private final Map<Class<?>, AbstractForEachUniNode<Object>[]> effectiveClassToNodeArrayMap;

    protected AbstractSession(NodeNetwork nodeNetwork) {
        this.nodeNetwork = nodeNetwork;
        this.effectiveClassToNodeArrayMap = new IdentityHashMap<>(nodeNetwork.forEachNodeCount());
    }

    public final void insert(Object fact) {
        var factClass = fact.getClass();
        for (var node : findNodes(factClass)) {
            node.insert(fact);
        }
    }

    @SuppressWarnings("unchecked")
    private AbstractForEachUniNode<Object>[] findNodes(Class<?> factClass) {
        // Map.computeIfAbsent() would have created lambdas on the hot path; this will not.
        var nodeArray = effectiveClassToNodeArrayMap.get(factClass);
        if (nodeArray == null) {
            nodeArray = nodeNetwork.getForEachNodes(factClass)
                    .filter(AbstractForEachUniNode::supportsIndividualUpdates)
                    .toArray(AbstractForEachUniNode[]::new);
            effectiveClassToNodeArrayMap.put(factClass, nodeArray);
        }
        return nodeArray;
    }

    public final void update(Object fact) {
        var factClass = fact.getClass();
        for (var node : findNodes(factClass)) {
            node.update(fact);
        }
    }

    public final void retract(Object fact) {
        var factClass = fact.getClass();
        for (var node : findNodes(factClass)) {
            node.retract(fact);
        }
    }

    protected void settle() {
        nodeNetwork.settle();
    }

}
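The get-then-put lookup in findNodes() above deliberately avoids Map.computeIfAbsent(), which, per the comment in that method, would create lambdas on the hot path. Below is a minimal, self-contained sketch of the same caching pattern; NodesPerClassCache and its String[] values are hypothetical stand-ins for the real node arrays, not part of this PR.

import java.util.IdentityHashMap;
import java.util.Map;

// Hypothetical illustration of the findNodes() caching idea: an explicit
// get-then-put on an IdentityHashMap avoids the capturing lambda that
// Map.computeIfAbsent() would need on every call.
final class NodesPerClassCache {

    private final Map<Class<?>, String[]> cache = new IdentityHashMap<>();

    String[] lookup(Class<?> factClass) {
        var cached = cache.get(factClass);
        if (cached == null) {
            // Compute once per fact class; later calls only pay for the get().
            cached = computeNodesFor(factClass);
            cache.put(factClass, cached);
        }
        return cached;
    }

    private String[] computeNodesFor(Class<?> factClass) {
        // Placeholder for the real lookup, NodeNetwork.getForEachNodes(factClass).
        return new String[] { factClass.getSimpleName() };
    }
}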
Modified: NodeNetwork (moved from ai.timefold.solver.core.impl.score.stream.bavet to ai.timefold.solver.core.impl.bavet)
@@ -1,9 +1,10 @@
package ai.timefold.solver.core.impl.score.stream.bavet;
package ai.timefold.solver.core.impl.bavet;

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Stream;

import ai.timefold.solver.core.impl.bavet.common.Propagator;
import ai.timefold.solver.core.impl.bavet.uni.AbstractForEachUniNode;
@@ -17,7 +18,7 @@
* @param layeredNodes nodes grouped first by their layer, then by their index within the layer;
* propagation needs to happen in this order.
*/
record NodeNetwork(Map<Class<?>, List<AbstractForEachUniNode<?>>> declaredClassToNodeMap, Propagator[][] layeredNodes) {
public record NodeNetwork(Map<Class<?>, List<AbstractForEachUniNode<?>>> declaredClassToNodeMap, Propagator[][] layeredNodes) {

public static final NodeNetwork EMPTY = new NodeNetwork(Map.of(), new Propagator[0][0]);

@@ -29,23 +30,27 @@ public int layerCount() {
return layeredNodes.length;
}

@SuppressWarnings("unchecked")
public AbstractForEachUniNode<Object>[] getApplicableForEachNodes(Class<?> factClass) {
public Stream<AbstractForEachUniNode<?>> getForEachNodes() {
return declaredClassToNodeMap.values()
.stream()
.flatMap(List::stream);
}

public Stream<AbstractForEachUniNode<?>> getForEachNodes(Class<?> factClass) {
return declaredClassToNodeMap.entrySet()
.stream()
.filter(entry -> entry.getKey().isAssignableFrom(factClass))
.map(Map.Entry::getValue)
.flatMap(List::stream)
.toArray(AbstractForEachUniNode[]::new);
.flatMap(List::stream);
}

public void propagate() {
public void settle() {
for (var layerIndex = 0; layerIndex < layerCount(); layerIndex++) {
propagateInLayer(layeredNodes[layerIndex]);
settleLayer(layeredNodes[layerIndex]);
}
}

private static void propagateInLayer(Propagator[] nodesInLayer) {
private static void settleLayer(Propagator[] nodesInLayer) {
var nodeCount = nodesInLayer.length;
if (nodeCount == 1) {
nodesInLayer[0].propagateEverything();
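In getForEachNodes(Class<?>) above, the fact class is matched against every declared class via isAssignableFrom, so a forEach node declared for a supertype also receives facts of its subtypes. A self-contained sketch of that lookup follows; AssignabilityLookupDemo and its String labels are hypothetical stand-ins for the real forEach nodes.

import java.util.List;
import java.util.Map;
import java.util.stream.Stream;

// Hypothetical illustration of NodeNetwork.getForEachNodes(Class<?>): a fact is
// routed to every node whose declared class is assignable from the fact's class.
final class AssignabilityLookupDemo {

    static Stream<String> nodesFor(Map<Class<?>, List<String>> declaredClassToNodeMap, Class<?> factClass) {
        return declaredClassToNodeMap.entrySet()
                .stream()
                .filter(entry -> entry.getKey().isAssignableFrom(factClass))
                .map(Map.Entry::getValue)
                .flatMap(List::stream);
    }

    public static void main(String[] args) {
        Map<Class<?>, List<String>> map = Map.of(
                Number.class, List.of("forEach(Number)"),
                Integer.class, List.of("forEach(Integer)"),
                String.class, List.of("forEach(String)"));
        // An Integer fact reaches the Integer node and the Number node, but not the String node.
        nodesFor(map, Integer.class).forEach(System.out::println);
    }
}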
Modified: NodeBuildHelper, generalized into AbstractNodeBuildHelper<Stream_>
@@ -6,49 +6,41 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.UnaryOperator;
import java.util.function.Function;

import ai.timefold.solver.core.api.score.Score;
import ai.timefold.solver.core.api.score.stream.ConstraintStream;
import ai.timefold.solver.core.impl.bavet.common.tuple.AbstractTuple;
import ai.timefold.solver.core.impl.bavet.common.tuple.LeftTupleLifecycle;
import ai.timefold.solver.core.impl.bavet.common.tuple.RightTupleLifecycle;
import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle;
import ai.timefold.solver.core.impl.bavet.uni.AbstractForEachUniNode;
import ai.timefold.solver.core.impl.score.stream.common.AbstractConstraintStream;
import ai.timefold.solver.core.impl.score.stream.common.inliner.AbstractScoreInliner;

public final class NodeBuildHelper<Score_ extends Score<Score_>> {
public abstract class AbstractNodeBuildHelper<Stream_ extends BavetStream> {

private final Set<? extends ConstraintStream> activeStreamSet;
private final AbstractScoreInliner<Score_> scoreInliner;
private final Map<AbstractNode, BavetAbstractConstraintStream<?>> nodeCreatorMap;
private final Map<ConstraintStream, TupleLifecycle<? extends AbstractTuple>> tupleLifecycleMap;
private final Map<ConstraintStream, Integer> storeIndexMap;
private final Set<Stream_> activeStreamSet;
private final Map<AbstractNode, Stream_> nodeCreatorMap;
private final Map<Stream_, TupleLifecycle<? extends AbstractTuple>> tupleLifecycleMap;
private final Map<Stream_, Integer> storeIndexMap;

private List<AbstractNode> reversedNodeList;

public NodeBuildHelper(Set<? extends ConstraintStream> activeStreamSet, AbstractScoreInliner<Score_> scoreInliner) {
public AbstractNodeBuildHelper(Set<Stream_> activeStreamSet) {
this.activeStreamSet = activeStreamSet;
this.scoreInliner = scoreInliner;
int activeStreamSetSize = activeStreamSet.size();
this.nodeCreatorMap = new HashMap<>(Math.max(16, activeStreamSetSize));
this.tupleLifecycleMap = new HashMap<>(Math.max(16, activeStreamSetSize));
this.storeIndexMap = new HashMap<>(Math.max(16, activeStreamSetSize / 2));
this.reversedNodeList = new ArrayList<>(activeStreamSetSize);
}

public boolean isStreamActive(ConstraintStream stream) {
public boolean isStreamActive(Stream_ stream) {
return activeStreamSet.contains(stream);
}

public AbstractScoreInliner<Score_> getScoreInliner() {
return scoreInliner;
}

public void addNode(AbstractNode node, BavetAbstractConstraintStream<?> creator) {
public void addNode(AbstractNode node, Stream_ creator) {
addNode(node, creator, creator);
}

public void addNode(AbstractNode node, BavetAbstractConstraintStream<?> creator, BavetAbstractConstraintStream<?> parent) {
public void addNode(AbstractNode node, Stream_ creator, Stream_ parent) {
reversedNodeList.add(node);
nodeCreatorMap.put(node, creator);
if (!(node instanceof AbstractForEachUniNode<?>)) {
@@ -59,51 +51,50 @@ public void addNode(AbstractNode node, BavetAbstractConstraintStream<?> creator,
}
}

public <Solution_, LeftTuple_ extends AbstractTuple, RightTuple_ extends AbstractTuple> void addNode(
AbstractTwoInputNode<LeftTuple_, RightTuple_> node, BavetAbstractConstraintStream<?> creator,
BavetAbstractConstraintStream<Solution_> leftParent, BavetAbstractConstraintStream<Solution_> rightParent) {
public void addNode(AbstractNode node, Stream_ creator, Stream_ leftParent, Stream_ rightParent) {
reversedNodeList.add(node);
nodeCreatorMap.put(node, creator);
putInsertUpdateRetract(leftParent, TupleLifecycle.ofLeft(node));
putInsertUpdateRetract(rightParent, TupleLifecycle.ofRight(node));
putInsertUpdateRetract(leftParent, TupleLifecycle.ofLeft((LeftTupleLifecycle<? extends AbstractTuple>) node));
putInsertUpdateRetract(rightParent, TupleLifecycle.ofRight((RightTupleLifecycle<? extends AbstractTuple>) node));
}

public <Tuple_ extends AbstractTuple> void putInsertUpdateRetract(ConstraintStream stream,
public <Tuple_ extends AbstractTuple> void putInsertUpdateRetract(Stream_ stream,
TupleLifecycle<Tuple_> tupleLifecycle) {
tupleLifecycleMap.put(stream, tupleLifecycle);
}

public <Tuple_ extends AbstractTuple> void putInsertUpdateRetract(ConstraintStream stream,
List<? extends AbstractConstraintStream<?>> childStreamList,
UnaryOperator<TupleLifecycle<Tuple_>> tupleLifecycleFunction) {
public <Tuple_ extends AbstractTuple> void putInsertUpdateRetract(Stream_ stream, List<? extends Stream_> childStreamList,
Function<TupleLifecycle<Tuple_>, TupleLifecycle<Tuple_>> tupleLifecycleFunction) {
TupleLifecycle<Tuple_> tupleLifecycle = getAggregatedTupleLifecycle(childStreamList);
putInsertUpdateRetract(stream, tupleLifecycleFunction.apply(tupleLifecycle));
}

@SuppressWarnings("unchecked")
public <Tuple_ extends AbstractTuple> TupleLifecycle<Tuple_>
getAggregatedTupleLifecycle(List<? extends ConstraintStream> streamList) {
getAggregatedTupleLifecycle(List<? extends Stream_> streamList) {
var tupleLifecycles = streamList.stream()
.filter(this::isStreamActive)
.map(s -> getTupleLifecycle(s, tupleLifecycleMap))
.toArray(TupleLifecycle[]::new);
if (tupleLifecycles.length == 0) {
throw new IllegalStateException("Impossible state: None of the streamList (%s) are active.".formatted(streamList));
}
return TupleLifecycle.aggregate(tupleLifecycles);
return switch (tupleLifecycles.length) {
case 0 ->
throw new IllegalStateException("Impossible state: None of the streamList (" + streamList + ") are active.");
case 1 -> tupleLifecycles[0];
default -> TupleLifecycle.aggregate(tupleLifecycles);
};
}

@SuppressWarnings("unchecked")
private static <Tuple_ extends AbstractTuple> TupleLifecycle<Tuple_> getTupleLifecycle(ConstraintStream stream,
Map<ConstraintStream, TupleLifecycle<? extends AbstractTuple>> tupleLifecycleMap) {
private static <Stream_, Tuple_ extends AbstractTuple> TupleLifecycle<Tuple_> getTupleLifecycle(Stream_ stream,
Map<Stream_, TupleLifecycle<? extends AbstractTuple>> tupleLifecycleMap) {
var tupleLifecycle = (TupleLifecycle<Tuple_>) tupleLifecycleMap.get(stream);
if (tupleLifecycle == null) {
throw new IllegalStateException("Impossible state: the stream (" + stream + ") hasn't built a node yet.");
}
return tupleLifecycle;
}

public int reserveTupleStoreIndex(ConstraintStream tupleSourceStream) {
public int reserveTupleStoreIndex(Stream_ tupleSourceStream) {
return storeIndexMap.compute(tupleSourceStream, (k, index) -> {
if (index == null) {
return 0;
@@ -116,7 +107,7 @@ public int reserveTupleStoreIndex(ConstraintStream tupleSourceStream) {
});
}

public int extractTupleStoreSize(ConstraintStream tupleSourceStream) {
public int extractTupleStoreSize(Stream_ tupleSourceStream) {
Integer lastIndex = storeIndexMap.put(tupleSourceStream, Integer.MIN_VALUE);
return (lastIndex == null) ? 0 : lastIndex + 1;
}
@@ -128,17 +119,17 @@ public List<AbstractNode> destroyAndGetNodeList() {
return nodeList;
}

public BavetAbstractConstraintStream<?> getNodeCreatingStream(AbstractNode node) {
public Stream_ getNodeCreatingStream(AbstractNode node) {
return nodeCreatorMap.get(node);
}

public AbstractNode findParentNode(BavetAbstractConstraintStream<?> childNodeCreator) {
public AbstractNode findParentNode(Stream_ childNodeCreator) {
if (childNodeCreator == null) { // We've recursed to the bottom without finding a parent node.
throw new IllegalStateException(
"Impossible state: node-creating stream (" + childNodeCreator + ") has no parent node.");
}
// Look the stream up among node creators and if found, the node is the parent node.
for (Map.Entry<AbstractNode, BavetAbstractConstraintStream<?>> entry : this.nodeCreatorMap.entrySet()) {
for (Map.Entry<AbstractNode, Stream_> entry : this.nodeCreatorMap.entrySet()) {
if (entry.getValue() == childNodeCreator) {
return entry.getKey();
}
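reserveTupleStoreIndex() and extractTupleStoreSize() above implement per-stream slot bookkeeping: each reservation hands out the next index for that tuple source stream, and extracting the size replaces the entry with Integer.MIN_VALUE, presumably so that a later reservation for the same stream can be detected in the collapsed branch. The middle of reserveTupleStoreIndex() is collapsed in this diff, so the incrementing branch in the sketch below is an assumption inferred from extractTupleStoreSize() returning lastIndex + 1; TupleStoreIndexDemo and its String keys are hypothetical.

import java.util.HashMap;
import java.util.Map;

// Hypothetical sketch of the tuple-store index bookkeeping pattern.
final class TupleStoreIndexDemo {

    private final Map<String, Integer> storeIndexMap = new HashMap<>();

    int reserveTupleStoreIndex(String tupleSourceStream) {
        return storeIndexMap.compute(tupleSourceStream, (k, index) -> {
            if (index == null) {
                return 0; // First reservation for this source stream.
            }
            return index + 1; // Assumed: each later reservation takes the next slot.
        });
    }

    int extractTupleStoreSize(String tupleSourceStream) {
        // Poison the entry so a reservation after extraction does not go unnoticed.
        Integer lastIndex = storeIndexMap.put(tupleSourceStream, Integer.MIN_VALUE);
        return (lastIndex == null) ? 0 : lastIndex + 1;
    }

    public static void main(String[] args) {
        var demo = new TupleStoreIndexDemo();
        System.out.println(demo.reserveTupleStoreIndex("streamA")); // 0
        System.out.println(demo.reserveTupleStoreIndex("streamA")); // 1
        System.out.println(demo.extractTupleStoreSize("streamA"));  // 2
    }
}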
Modified: BavetAbstractConstraintStream
@@ -10,13 +10,17 @@
import ai.timefold.solver.core.api.score.stream.Constraint;
import ai.timefold.solver.core.impl.score.stream.bavet.BavetConstraint;
import ai.timefold.solver.core.impl.score.stream.bavet.BavetConstraintFactory;
import ai.timefold.solver.core.impl.score.stream.bavet.common.BavetScoringConstraintStream;
import ai.timefold.solver.core.impl.score.stream.bavet.common.ConstraintNodeBuildHelper;
import ai.timefold.solver.core.impl.score.stream.common.AbstractConstraintStream;
import ai.timefold.solver.core.impl.score.stream.common.RetrievalSemantics;
import ai.timefold.solver.core.impl.score.stream.common.ScoreImpactType;

import org.jspecify.annotations.NonNull;

public abstract class BavetAbstractConstraintStream<Solution_> extends AbstractConstraintStream<Solution_> {
public abstract class BavetAbstractConstraintStream<Solution_>
extends AbstractConstraintStream<Solution_>
implements BavetStream {

protected final BavetConstraintFactory<Solution_> constraintFactory;
protected final BavetAbstractConstraintStream<Solution_> parent;
@@ -109,7 +113,7 @@ public BavetAbstractConstraintStream<Solution_> getTupleSource() {
return parent.getTupleSource();
}

public abstract <Score_ extends Score<Score_>> void buildNode(NodeBuildHelper<Score_> buildHelper);
public abstract <Score_ extends Score<Score_>> void buildNode(ConstraintNodeBuildHelper<Solution_, Score_> buildHelper);

// ************************************************************************
// Helper methods
@@ -135,6 +139,7 @@ protected void assertEmptyChildStreamList() {
* @return null for join/ifExists nodes, which have left and right parents instead;
* also null for forEach node, which has no parent.
*/
@Override
public final BavetAbstractConstraintStream<Solution_> getParent() {
return parent;
}
New file: ai.timefold.solver.core.impl.bavet.common.BavetStream
@@ -0,0 +1,7 @@
package ai.timefold.solver.core.impl.bavet.common;

public interface BavetStream {

    <Stream_> Stream_ getParent();

}
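BavetStream exposes the parent through a per-call type parameter instead of making the interface itself generic; BavetAbstractConstraintStream above overrides getParent() with its own covariant return type. The hypothetical ExampleStream below shows one way an implementation can satisfy the contract (an unchecked cast, since the declaration cannot know what type the caller will ask for); it is illustrative only and assumes BavetStream is on the classpath.

// Hypothetical implementation sketch of BavetStream.getParent().
final class ExampleStream implements BavetStream {

    private final ExampleStream parent; // null for a source stream such as forEach

    ExampleStream(ExampleStream parent) {
        this.parent = parent;
    }

    @Override
    @SuppressWarnings("unchecked")
    public <Stream_> Stream_ getParent() {
        return (Stream_) parent;
    }
}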