Skip to content

Commit

Permalink
feature/ml_linearregression initial work for supporting linear regres…
Browse files Browse the repository at this point in the history
…sion ml model as a node
  • Loading branch information
greg-higgins committed Feb 4, 2024
1 parent 723e811 commit 51859f2
Show file tree
Hide file tree
Showing 10 changed files with 261 additions and 0 deletions.
58 changes: 58 additions & 0 deletions compiler/src/test/java/com/fluxtion/runtime/ml/RegressionTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package com.fluxtion.runtime.ml;

import com.fluxtion.compiler.generation.util.CompiledAndInterpretedSepTest;
import com.fluxtion.compiler.generation.util.MultipleSepTargetInProcessTest;
import com.fluxtion.runtime.annotations.ExportService;
import com.fluxtion.runtime.annotations.OnEventHandler;
import lombok.Value;
import org.junit.Assert;
import org.junit.Test;

import java.util.Arrays;

public class RegressionTest extends MultipleSepTargetInProcessTest {
public RegressionTest(CompiledAndInterpretedSepTest.SepTestConfig testConfig) {
super(testConfig);
}

@Test
public void simpleTest() {
sep(c -> c.addNode(new PredictiveLinearRegressionModel(new AreaFeature()), "predictiveModel"));

//initial prediction is NaN
PredictiveModel predictiveModel = getField("predictiveModel");
Assert.assertTrue(Double.isNaN(predictiveModel.predictedValue()));

//set calibration prediction is 0
sep.getExportedService(CalibrationProcessor.class).setCalibration(
Arrays.asList(
Calibration.builder()
.featureClass(AreaFeature.class)
.weight(2)
.co_efficient(1.5)
.featureVersion(0)
.build()));
Assert.assertEquals(0, predictiveModel.predictedValue(), 0.000_1);

//send record to generate a prediction
onEvent(new HouseDetails(12, 3));
Assert.assertEquals(36, predictiveModel.predictedValue(), 0.000_1);

}

public static class AreaFeature extends AbstractFeature implements @ExportService CalibrationProcessor {

@OnEventHandler
public boolean processRecord(HouseDetails houseDetails) {
value = houseDetails.area * co_efficient * weight;
return true;
}

}

@Value
public static class HouseDetails {
double area;
double distance;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package com.fluxtion.runtime.annotations.feature;

/**
* Marks a class or method as a experimental feature. Mirrors the use of jdk experimental features:
* <p/>
* Experimental features represent early versions of (mostly) VM-level features, which can be risky, incomplete, or even
* unstable. In most cases, they need to be enabled using dedicated flags. For the purpose of comparison, if an
* experimental feature is considered 25% “done”, then a preview feature should be at least 95% “done”
*/
public @interface Experimental {
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package com.fluxtion.runtime.annotations.feature;

/**
* Marks a class or method as a preview feature. Mirrors the use of jdk preview features:
* <p/>
* A preview feature is a new feature of the Java language, Java Virtual Machine, or Java SE API that is fully specified,
* fully implemented, and yet impermanent. It is available in a JDK feature release to provoke developer feedback based
* on real world use; this may lead to it becoming permanent in a future Java SE Platform.
*/
public @interface Preview {
}
40 changes: 40 additions & 0 deletions runtime/src/main/java/com/fluxtion/runtime/ml/AbstractFeature.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package com.fluxtion.runtime.ml;

import com.fluxtion.runtime.annotations.Initialise;
import com.fluxtion.runtime.annotations.feature.Experimental;

import java.util.List;

@Experimental
public abstract class AbstractFeature implements Feature, CalibrationProcessor {

protected double co_efficient;
protected double weight;
protected double value;

@Initialise
public void init() {
co_efficient = 0;
weight = 0;
value = 0;
}

@Override
public boolean setCalibration(List<Calibration> calibrations) {
for (int i = 0, calibrationsSize = calibrations.size(); i < calibrationsSize; i++) {
Calibration calibration = calibrations.get(i);
if (calibration.getFeatureIdentifier().equals(identifier())) {
co_efficient = calibration.getCo_efficient();
weight = calibration.getWeight();
return true;
}
}
return false;
}

@Override
public double value() {
return value;
}

}
20 changes: 20 additions & 0 deletions runtime/src/main/java/com/fluxtion/runtime/ml/Calibration.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package com.fluxtion.runtime.ml;

import com.fluxtion.runtime.annotations.feature.Experimental;
import lombok.Builder;
import lombok.Data;

@Data
@Builder
@Experimental
public class Calibration {
private String featureIdentifier;
private Class<? extends Feature> featureClass;
private int featureVersion;
private double co_efficient;
private double weight;

public String getFeatureIdentifier() {
return featureIdentifier == null ? featureClass.getSimpleName() : featureIdentifier;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package com.fluxtion.runtime.ml;

import java.util.List;

public interface CalibrationProcessor {

boolean setCalibration(List<Calibration> calibration);
}
25 changes: 25 additions & 0 deletions runtime/src/main/java/com/fluxtion/runtime/ml/Feature.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package com.fluxtion.runtime.ml;

import com.fluxtion.runtime.annotations.ExportService;
import com.fluxtion.runtime.annotations.feature.Experimental;
import com.fluxtion.runtime.node.NamedNode;

@Experimental
public interface Feature extends NamedNode, @ExportService CalibrationProcessor {

default String identifier() {
return getClass().getSimpleName();
}

default int version() {
return 0;
}

@Override
default String getName() {
return identifier() + "_" + version();
}

double value();

}
20 changes: 20 additions & 0 deletions runtime/src/main/java/com/fluxtion/runtime/ml/MutableDouble.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package com.fluxtion.runtime.ml;

import com.fluxtion.runtime.annotations.feature.Experimental;

@Experimental
public class MutableDouble {
double value;

public MutableDouble(double value) {
this.value = value;
}

public MutableDouble() {
this(Double.NaN);
}

void reset() {
value = Double.NaN;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package com.fluxtion.runtime.ml;

import com.fluxtion.runtime.annotations.*;
import com.fluxtion.runtime.annotations.feature.Experimental;

import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;

@Experimental
public class PredictiveLinearRegressionModel implements PredictiveModel, @ExportService CalibrationProcessor {

private transient final Map<Feature, MutableDouble> valueMap;
private final Feature[] features;
private double prediction = Double.NaN;

public PredictiveLinearRegressionModel(Feature... features) {
this.features = features;
this.valueMap = new IdentityHashMap<>(features.length);
for (Feature feature : features) {
valueMap.put(feature, new MutableDouble(0));
}
}

@Initialise
public void init() {
prediction = Double.NaN;
}

@Override
@NoPropagateFunction
public boolean setCalibration(List<Calibration> calibrations) {
double previousValue = prediction;
prediction = 0;
for (Feature feature : features) {
prediction += feature.value();
}
return previousValue != prediction | Double.isNaN(previousValue) != Double.isNaN(prediction);
}

@OnParentUpdate
public void featureUpdated(Feature featureUpdated) {
MutableDouble previousValue = valueMap.get(featureUpdated);
double newValue = featureUpdated.value();
prediction += newValue - previousValue.value;
previousValue.value = newValue;
}

@OnTrigger
public boolean calculateInference() {
return true;
}


@Override
public double predictedValue() {
return prediction;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package com.fluxtion.runtime.ml;

import com.fluxtion.runtime.annotations.feature.Experimental;

@Experimental
public interface PredictiveModel {

double predictedValue();
}

0 comments on commit 51859f2

Please sign in to comment.