diff --git a/matching-eval/src/main/java/de/uni_mannheim/informatik/dws/melt/matching_eval/evaluator/metric/cm/ConfusionMatrixMetric.java b/matching-eval/src/main/java/de/uni_mannheim/informatik/dws/melt/matching_eval/evaluator/metric/cm/ConfusionMatrixMetric.java index 9247593727..ac80617d9d 100644 --- a/matching-eval/src/main/java/de/uni_mannheim/informatik/dws/melt/matching_eval/evaluator/metric/cm/ConfusionMatrixMetric.java +++ b/matching-eval/src/main/java/de/uni_mannheim/informatik/dws/melt/matching_eval/evaluator/metric/cm/ConfusionMatrixMetric.java @@ -245,27 +245,27 @@ public ConfusionMatrix getMacroAveragesForResults(Iterable resu Alignment truePositive = new Alignment(); Alignment falsePositive = new Alignment(); Alignment falseNegative = new Alignment(); - - double precision = 0.0; // dummy init - double recall = 0.0; // dummy init - - // for aggregation: + int numberOfCorrespondences = 0; + double aggregatedPrecision = 0.0; + double aggregatedRecall = 0.0; + double aggregatedF1 = 0.0; + for (ConfusionMatrix individualConfusionMatrix : confusionMatrices) { truePositive.addAll(individualConfusionMatrix.getTruePositive()); falsePositive.addAll(individualConfusionMatrix.getFalsePositive()); falseNegative.addAll(individualConfusionMatrix.getFalseNegative()); - } - - double aggregatedPrecision = 0.0; - double aggregatedRecall = 0.0; - for (ConfusionMatrix individualConfusionMatrix : confusionMatrices) { + + numberOfCorrespondences += individualConfusionMatrix.getNumberOfCorrespondences(); + aggregatedPrecision = aggregatedPrecision + individualConfusionMatrix.getPrecision(); aggregatedRecall = aggregatedRecall + individualConfusionMatrix.getRecall(); + aggregatedF1 = aggregatedF1 + individualConfusionMatrix.getF1measure(); } - precision = aggregatedPrecision / numberOfTestCases; - recall = aggregatedRecall / numberOfTestCases; - - return new ConfusionMatrix(truePositive, falsePositive, falseNegative, precision, recall); + double precision = aggregatedPrecision / numberOfTestCases; + double recall = aggregatedRecall / numberOfTestCases; + double f1 = aggregatedF1 / numberOfTestCases; + + return new ConfusionMatrixMacroAveraged(truePositive, falsePositive, falseNegative, numberOfCorrespondences, precision, recall, f1); } diff --git a/matching-eval/src/test/java/de/uni_mannheim/informatik/dws/melt/matching_eval/evaluator/metric/cm/ConfusionMatrixMetricTest.java b/matching-eval/src/test/java/de/uni_mannheim/informatik/dws/melt/matching_eval/evaluator/metric/cm/ConfusionMatrixMetricTest.java index 281622156e..e740e32e8d 100644 --- a/matching-eval/src/test/java/de/uni_mannheim/informatik/dws/melt/matching_eval/evaluator/metric/cm/ConfusionMatrixMetricTest.java +++ b/matching-eval/src/test/java/de/uni_mannheim/informatik/dws/melt/matching_eval/evaluator/metric/cm/ConfusionMatrixMetricTest.java @@ -1,5 +1,7 @@ package de.uni_mannheim.informatik.dws.melt.matching_eval.evaluator.metric.cm; +import de.uni_mannheim.informatik.dws.melt.matching_data.GoldStandardCompleteness; +import de.uni_mannheim.informatik.dws.melt.matching_data.LocalTrack; import de.uni_mannheim.informatik.dws.melt.matching_data.TestCase; import de.uni_mannheim.informatik.dws.melt.matching_data.TrackRepository; import de.uni_mannheim.informatik.dws.melt.matching_eval.ExecutionResultSet; @@ -9,6 +11,14 @@ import org.junit.jupiter.api.Test; import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.function.BiConsumer; +import java.util.function.Predicate; +import org.apache.jena.ontology.OntModel; +import org.apache.jena.ontology.OntModelSpec; +import org.apache.jena.rdf.model.ModelFactory; +import org.apache.jena.vocabulary.OWL; import static org.junit.jupiter.api.Assertions.*; @@ -242,4 +252,172 @@ void realTest() { assertEquals(0.615, confusionMatrix1Dome.getRecall(), 0.001); } + + @Test + void micromacroTest() { + ConfusionMatrixMetric metric = new ConfusionMatrixMetric(); + double delta = 0.01; + //https://iamirmasoud.com/2022/06/19/understanding-micro-macro-and-weighted-averages-for-scikit-learn-metrics-in-multi-class-classification-with-example/ + ExecutionResult resultFirst = createResultWith("testCaseA", GoldStandardCompleteness.COMPLETE, + 2,1,1,0,0,0,0,0,0); + ExecutionResult resultSecond = createResultWith("testCaseB", GoldStandardCompleteness.COMPLETE, + 1,3,0,0,0,0,0,0,0); + ExecutionResult resultThird = createResultWith("testCaseC", GoldStandardCompleteness.COMPLETE, + 3,0,3,0,0,0,0,0,0); + ExecutionResultSet all = new ExecutionResultSet(); + all.add(resultFirst); + all.add(resultSecond); + all.add(resultThird); + + ConfusionMatrix confusionMatrixFirst = metric.compute(resultFirst); + ConfusionMatrix confusionMatrixSecond = metric.compute(resultSecond); + ConfusionMatrix confusionMatrixThird = metric.compute(resultThird); + + assertEquals(2, confusionMatrixFirst.getTruePositiveSize()); + assertEquals(1, confusionMatrixFirst.getFalsePositiveSize()); + assertEquals(1, confusionMatrixFirst.getFalseNegativeSize()); + assertEquals(0.67, confusionMatrixFirst.getPrecision(), delta); + assertEquals(0.67, confusionMatrixFirst.getRecall(), delta); + assertEquals(0.67, confusionMatrixFirst.getF1measure(), delta); + + + assertEquals(1, confusionMatrixSecond.getTruePositiveSize()); + assertEquals(3, confusionMatrixSecond.getFalsePositiveSize()); + assertEquals(0, confusionMatrixSecond.getFalseNegativeSize()); + assertEquals(0.25, confusionMatrixSecond.getPrecision(), delta); + assertEquals(1.0, confusionMatrixSecond.getRecall(), delta); + assertEquals(0.4, confusionMatrixSecond.getF1measure(), delta); + + assertEquals(3, confusionMatrixThird.getTruePositiveSize()); + assertEquals(0, confusionMatrixThird.getFalsePositiveSize()); + assertEquals(3, confusionMatrixThird.getFalseNegativeSize()); + assertEquals(1.0, confusionMatrixThird.getPrecision(), delta); + assertEquals(0.5, confusionMatrixThird.getRecall(), delta); + assertEquals(0.67, confusionMatrixThird.getF1measure(), delta); + + + + ConfusionMatrix microAll = metric.getMicroAveragesForResults(all); + + assertEquals(6, microAll.getTruePositiveSize()); + assertEquals(4, microAll.getFalsePositiveSize()); + assertEquals(4, microAll.getFalseNegativeSize()); + assertEquals(0.6, microAll.getPrecision(), delta); + assertEquals(0.6, microAll.getRecall(), delta); + assertEquals(0.6, microAll.getF1measure(), delta); + + + ConfusionMatrix macroAll = metric.getMacroAveragesForResults(all); + assertEquals(6, macroAll.getTruePositiveSize()); + assertEquals(4, macroAll.getFalsePositiveSize()); + assertEquals(4, macroAll.getFalseNegativeSize()); + assertEquals(0.64, macroAll.getPrecision(), delta); + assertEquals(0.72, macroAll.getRecall(), delta); + assertEquals(0.58, macroAll.getF1measure(), delta); + + ConfusionMatrix macroSpecifiedNumber = metric.getMacroAveragesForResults(all, 3); + assertEquals(6, macroSpecifiedNumber.getTruePositiveSize()); + assertEquals(4, macroSpecifiedNumber.getFalsePositiveSize()); + assertEquals(4, macroSpecifiedNumber.getFalseNegativeSize()); + assertEquals(0.64, macroSpecifiedNumber.getPrecision(), delta); + assertEquals(0.72, macroSpecifiedNumber.getRecall(), delta); + assertEquals(0.58, macroSpecifiedNumber.getF1measure(), delta); + } + + + + + + private static ExecutionResult createResultWith(String testCase, GoldStandardCompleteness goldStandardCompleteness, + int classTP, int classFP, int classFN, + int propTP, int propFP, int propFN, + int instTP, int instFP, int instFN){ + int counter = 0; + String sourceBase = "http://source.com/" + testCase + "/"; + String targetBase = "http://target.com/" + testCase + "/"; + + Alignment systemAlignment = new Alignment(); + Alignment refAlignment = new Alignment(); + + OntModel src = ModelFactory.createOntologyModel(OntModelSpec.OWL_MEM); + OntModel tgt = ModelFactory.createOntologyModel(OntModelSpec.OWL_MEM); + + //TP + for(int i = 0; i < classTP; i++){ + String sourceURI = sourceBase + counter++; + String targetURI = targetBase + counter++; + src.createClass(sourceURI); + tgt.createClass(targetURI); + systemAlignment.add(sourceURI, targetURI); + refAlignment.add(sourceURI, targetURI); + } + for(int i = 0; i < propTP; i++){ + String sourceURI = sourceBase + counter++; + String targetURI = targetBase + counter++; + src.createProperty(sourceURI); + tgt.createProperty(targetURI); + systemAlignment.add(sourceURI, targetURI); + refAlignment.add(sourceURI, targetURI); + } + for(int i = 0; i < instTP; i++){ + String sourceURI = sourceBase + counter++; + String targetURI = targetBase + counter++; + src.createIndividual(sourceURI, OWL.Thing); + tgt.createIndividual(targetURI, OWL.Thing); + systemAlignment.add(sourceURI, targetURI); + refAlignment.add(sourceURI, targetURI); + } + + //FP + for(int i = 0; i < classFP; i++){ + String sourceURI = sourceBase + counter++; + String targetURI = targetBase + counter++; + src.createClass(sourceURI); + tgt.createClass(targetURI); + systemAlignment.add(sourceURI, targetURI); + } + for(int i = 0; i < propFP; i++){ + String sourceURI = sourceBase + counter++; + String targetURI = targetBase + counter++; + src.createProperty(sourceURI); + tgt.createProperty(targetURI); + systemAlignment.add(sourceURI, targetURI); + } + for(int i = 0; i < instFP; i++){ + String sourceURI = sourceBase + counter++; + String targetURI = targetBase + counter++; + src.createIndividual(sourceURI, OWL.Thing); + tgt.createIndividual(targetURI, OWL.Thing); + systemAlignment.add(sourceURI, targetURI); + } + + //FN + for(int i = 0; i < classFN; i++){ + String sourceURI = sourceBase + counter++; + String targetURI = targetBase + counter++; + src.createClass(sourceURI); + tgt.createClass(targetURI); + refAlignment.add(sourceURI, targetURI); + } + for(int i = 0; i < propFN; i++){ + String sourceURI = sourceBase + counter++; + String targetURI = targetBase + counter++; + src.createProperty(sourceURI); + tgt.createProperty(targetURI); + refAlignment.add(sourceURI, targetURI); + } + for(int i = 0; i < instFN; i++){ + String sourceURI = sourceBase + counter++; + String targetURI = targetBase + counter++; + src.createIndividual(sourceURI, OWL.Thing); + tgt.createIndividual(targetURI, OWL.Thing); + refAlignment.add(sourceURI, targetURI); + } + + + LocalTrack track = new LocalTrack("testtrack", "1.0"); + TestCase tc = new TestCaseWithModel(testCase, src, tgt, refAlignment, track, goldStandardCompleteness); + + return new ExecutionResult(tc, "myTestMatcher", systemAlignment, refAlignment); + } } \ No newline at end of file diff --git a/matching-eval/src/test/java/de/uni_mannheim/informatik/dws/melt/matching_eval/evaluator/metric/cm/TestCaseWithModel.java b/matching-eval/src/test/java/de/uni_mannheim/informatik/dws/melt/matching_eval/evaluator/metric/cm/TestCaseWithModel.java new file mode 100644 index 0000000000..516ed2b318 --- /dev/null +++ b/matching-eval/src/test/java/de/uni_mannheim/informatik/dws/melt/matching_eval/evaluator/metric/cm/TestCaseWithModel.java @@ -0,0 +1,55 @@ +package de.uni_mannheim.informatik.dws.melt.matching_eval.evaluator.metric.cm; + +import de.uni_mannheim.informatik.dws.melt.matching_data.GoldStandardCompleteness; +import de.uni_mannheim.informatik.dws.melt.matching_data.TestCase; +import de.uni_mannheim.informatik.dws.melt.matching_data.Track; +import de.uni_mannheim.informatik.dws.melt.yet_another_alignment_api.Alignment; +import java.util.Properties; +import org.apache.jena.ontology.OntModel; + +public class TestCaseWithModel extends TestCase { + private OntModel sourceModel; + private OntModel targetModel; + private Alignment referenceAlignment; + public TestCaseWithModel(String name, OntModel source, OntModel target, Alignment reference, Track track, GoldStandardCompleteness goldStandardCompleteness) { + super(name, null, null, null, track, null, goldStandardCompleteness, null, null); + this.sourceModel = source; + this.targetModel = target; + this.referenceAlignment = reference; + } + + @Override + public T getSourceOntology(Class clazz){ + return getSourceOntology(clazz, null); + } + @Override + @SuppressWarnings("unchecked") + public T getSourceOntology(Class clazz, Properties parameters){ + if(clazz.equals(OntModel.class)){ + return (T) sourceModel; + }else{ + throw new IllegalArgumentException("Wrong ontology type"); + } + } + + + @Override + public T getTargetOntology(Class clazz){ + return getTargetOntology(clazz, null); + } + @Override + @SuppressWarnings("unchecked") + public T getTargetOntology(Class clazz, Properties parameters){ + if(clazz.equals(OntModel.class)){ + return (T) targetModel; + }else{ + throw new IllegalArgumentException("Wrong ontology type"); + } + } + + @Override + public Alignment getParsedReferenceAlignment() { + return referenceAlignment; + } + +}