From c7db6d3f31cec9fc79da4f711e234aa1126c2e60 Mon Sep 17 00:00:00 2001 From: Villu Ruusmann Date: Mon, 16 Sep 2024 10:02:23 +0300 Subject: [PATCH] Standardized MiningField element --- .../org/jpmml/model/filters/ExportFilter.java | 20 ++++++++++++++++-- .../org/jpmml/model/filters/ImportFilter.java | 21 ++++++++++++++++--- .../java/org/dmg/pmml/MiningFieldTest.java | 10 +++++---- .../test/resources/pmml/MiningFieldTest.pmml | 2 +- 4 files changed, 43 insertions(+), 10 deletions(-) diff --git a/pmml-model/src/main/java/org/jpmml/model/filters/ExportFilter.java b/pmml-model/src/main/java/org/jpmml/model/filters/ExportFilter.java index 3c31b8f7..869075eb 100644 --- a/pmml-model/src/main/java/org/jpmml/model/filters/ExportFilter.java +++ b/pmml-model/src/main/java/org/jpmml/model/filters/ExportFilter.java @@ -68,9 +68,24 @@ public Attributes filterAttributes(String localName, Attributes attributes){ if(("MiningField").equals(localName)){ if(target.compareTo(Version.PMML_4_3) <= 0){ + String missingValueTreatment = getAttribute(attributes, "missingValueTreatment"); + String invalidValueTreatment = getAttribute(attributes, "invalidValueTreatment"); + attributes = renameAttribute(attributes, "invalidValueReplacement", "x-invalidValueReplacement"); - String invalidValueTreatment = getAttribute(attributes, "invalidValueTreatment"); + if(missingValueTreatment != null){ + + switch(missingValueTreatment){ + case "returnInvalid": + { + attributes = setAttribute(attributes, "missingValueTreatment", "x-" + missingValueTreatment); + } + break; + default: + break; + } + } // End if + if(invalidValueTreatment != null){ switch(invalidValueTreatment){ @@ -98,10 +113,11 @@ public Attributes filterAttributes(String localName, Attributes attributes){ if(("Segmentation").equals(localName)){ if(target.compareTo(Version.PMML_4_3) <= 0){ + String multipleModelMethod = getAttribute(attributes, "multipleModelMethod"); + attributes = renameAttribute(attributes, "missingPredictionTreatment", "x-missingPredictionTreatment"); attributes = renameAttribute(attributes, "missingThreshold", "x-missingThreshold"); - String multipleModelMethod = getAttribute(attributes, "multipleModelMethod"); if(multipleModelMethod != null){ switch(multipleModelMethod){ diff --git a/pmml-model/src/main/java/org/jpmml/model/filters/ImportFilter.java b/pmml-model/src/main/java/org/jpmml/model/filters/ImportFilter.java index 6862668e..0428d313 100644 --- a/pmml-model/src/main/java/org/jpmml/model/filters/ImportFilter.java +++ b/pmml-model/src/main/java/org/jpmml/model/filters/ImportFilter.java @@ -83,8 +83,22 @@ public Attributes filterAttributes(String localName, Attributes attributes){ } // End if if(source.compareTo(Version.PMML_4_4) <= 0){ + String missingValueTreatment = getAttribute(attributes, "missingValueTreatment"); String invalidValueTreatment = getAttribute(attributes, "invalidValueTreatment"); + if(missingValueTreatment != null){ + + switch(missingValueTreatment){ + case "x-returnInvalid": + { + attributes = setAttribute(attributes, "missingValueTreatment", missingValueTreatment.substring("x-".length())); + } + break; + default: + break; + } + } // End if + if(invalidValueTreatment != null){ switch(invalidValueTreatment){ @@ -115,10 +129,8 @@ public Attributes filterAttributes(String localName, Attributes attributes){ if(("Segmentation").equals(localName)){ if(source.compareTo(Version.PMML_4_3) <= 0){ - attributes = renameAttribute(attributes, "x-missingPredictionTreatment", "missingPredictionTreatment"); - attributes = renameAttribute(attributes, "x-missingThreshold", "missingThreshold"); - String multipleModelMethod = getAttribute(attributes, "multipleModelMethod"); + if(multipleModelMethod != null){ switch(multipleModelMethod){ @@ -132,6 +144,9 @@ public Attributes filterAttributes(String localName, Attributes attributes){ break; } } + + attributes = renameAttribute(attributes, "x-missingPredictionTreatment", "missingPredictionTreatment"); + attributes = renameAttribute(attributes, "x-missingThreshold", "missingThreshold"); } } else diff --git a/pmml-model/src/test/java/org/dmg/pmml/MiningFieldTest.java b/pmml-model/src/test/java/org/dmg/pmml/MiningFieldTest.java index f45468c8..822597f6 100644 --- a/pmml-model/src/test/java/org/dmg/pmml/MiningFieldTest.java +++ b/pmml-model/src/test/java/org/dmg/pmml/MiningFieldTest.java @@ -24,15 +24,15 @@ public class MiningFieldTest extends SchemaUpdateTest { public void transform() throws Exception { byte[] original = ResourceUtil.getByteArray(MiningFieldTest.class); - checkMiningField(original, "asIs", new String[]{"0", null}); + checkMiningField(original, "x-returnInvalid", "asIs", new String[]{"0", null}); byte[] latest = upgradeToLatest(original); - checkMiningField(latest, "asValue", new String[]{null, "0"}); + checkMiningField(latest, "returnInvalid", "asValue", new String[]{null, "0"}); byte[] latestToOriginal = downgrade(latest, Version.PMML_4_3); - checkMiningField(latestToOriginal, "asIs", new String[]{"0", null}); + checkMiningField(latestToOriginal, "x-returnInvalid", "asIs", new String[]{"0", null}); } @Test @@ -49,9 +49,11 @@ public void unmarshal() throws Exception { } static - private void checkMiningField(byte[] bytes, String invalidValueTreatment, String[] invalidValueReplacement) throws Exception { + private void checkMiningField(byte[] bytes, String missingValueTreatment, String invalidValueTreatment, String[] invalidValueReplacement) throws Exception { Node node = DOMUtil.selectNode(bytes, "/:PMML/:RegressionModel/:MiningSchema/:MiningField"); + assertEquals(missingValueTreatment, DOMUtil.getAttributeValue(node, "missingValueTreatment")); + assertEquals(invalidValueTreatment, DOMUtil.getAttributeValue(node, "invalidValueTreatment")); assertArrayEquals(invalidValueReplacement, DOMUtil.getExtensionAttributeValues(node, "invalidValueReplacement")); } diff --git a/pmml-model/src/test/resources/pmml/MiningFieldTest.pmml b/pmml-model/src/test/resources/pmml/MiningFieldTest.pmml index 482b0e25..570e5a0e 100644 --- a/pmml-model/src/test/resources/pmml/MiningFieldTest.pmml +++ b/pmml-model/src/test/resources/pmml/MiningFieldTest.pmml @@ -6,7 +6,7 @@ - +