Skip to content

Commit

Permalink
Added support for custom term frequencies
Browse files (browse the repository at this point in the history)
  • Loading branch information
vruusmann committed Jun 24, 2021
1 parent df153c1 commit cc71a88
Show file tree
Hide file tree
Showing 8 changed files with 6,595 additions and 1,522 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/
package org.jpmml.translator.regression;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
Expand Down Expand Up @@ -462,15 +463,31 @@ public Number apply(FunctionInvocationPredictor tfTerm){

JVar coefficientVar = context.declare(context.ref(Number.class), "coefficient", coefficientsVar.invoke("get").arg(indexVar));

List<JVar> factorVars = new ArrayList<>();
factorVars.add(coefficientVar);

if(weightsVar != null){
JVar weightVar = context.declare(context.ref(Number.class), "weight", weightsVar.invoke("get").arg(indexVar));

valueBuilder.update("add", coefficientVar, weightVar, frequencyVar);
} else
factorVars.add(weightVar);
}

{
valueBuilder.update("add", coefficientVar, frequencyVar);
TextIndex.LocalTermWeights localTermWeights = textIndex.getLocalTermWeights();
switch(localTermWeights){
case BINARY:
break;
case TERM_FREQUENCY:
factorVars.add(frequencyVar);
break;
case LOGARITHMIC:
JVar logFrequencyVar = context.declare(context.ref(Double.class), "logFrequency", context.staticInvoke(Math.class, "log10", JExpr.lit(1).plus(frequencyVar)));
factorVars.add(logFrequencyVar);
break;
default:
throw new UnsupportedAttributeException(localTextIndex, localTermWeights);
}

valueBuilder.update("add", factorVars.toArray(new JVar[factorVars.size()]));
} finally {
context.popScope();
}
Expand Down
1 change: 1 addition & 0 deletions src/test/java/org/jpmml/transpiler/Algorithms.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ public interface Algorithms {
String GRADIENT_BOOSTING = "GradientBoosting";
String ISOLATION_FOREST = "IsolationForest";
String LIGHT_GBM = "LightGBM";
String LINEAR_DISCRIMINANT_ANALYSIS = "LinearDiscriminantAnalysis";
String LINEAR_REGRESSION = "LinearRegression";
String LINEAR_SVC = "LinearSVC";
String LOGISTIC_REGRESSION = "LogisticRegression";
Expand Down
5 changes: 5 additions & 0 deletions src/test/java/org/jpmml/transpiler/ClassificationTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@ public void evaluateXGBoostAuditNA() throws Exception {
evaluate(XGBOOST, AUDIT_NA, excludeFields(AUDIT_PROBABILITY_FALSE), new FloatEquivalence(8));
}

/**
 * Runs the two-argument {@code evaluate} check for the
 * {@code LINEAR_DISCRIMINANT_ANALYSIS} algorithm identifier paired with
 * {@code SENTIMENT} (presumably the dataset identifier; confirm against the
 * constant's declaration).
 *
 * @throws Exception any exception raised by the evaluation is not caught here
 *         and therefore propagates to the test runner
 */
@Test
public void evaluateLinearDiscriminantAnalysisSentiment() throws Exception {
evaluate(LINEAR_DISCRIMINANT_ANALYSIS, SENTIMENT);
}

@Test
public void evaluateLinearSVCSentiment() throws Exception {
evaluate(LINEAR_SVC, SENTIMENT);
Expand Down
Loading

0 comments on commit cc71a88

Please sign in to comment.