diff --git a/sklearn2pmml/resources/sklearn2pmml-1.0-SNAPSHOT.jar b/sklearn2pmml/resources/sklearn2pmml-1.0-SNAPSHOT.jar index d602c53..d982787 100644 Binary files a/sklearn2pmml/resources/sklearn2pmml-1.0-SNAPSHOT.jar and b/sklearn2pmml/resources/sklearn2pmml-1.0-SNAPSHOT.jar differ diff --git a/src/main/java/com/sklearn2pmml/Main.java b/src/main/java/com/sklearn2pmml/Main.java index 19264f5..d4e2c27 100644 --- a/src/main/java/com/sklearn2pmml/Main.java +++ b/src/main/java/com/sklearn2pmml/Main.java @@ -19,20 +19,33 @@ package com.sklearn2pmml; import java.io.File; +import java.io.FileInputStream; import java.io.FileOutputStream; +import java.io.InputStream; import java.io.OutputStream; +import javax.xml.transform.TransformerFactory; +import javax.xml.transform.sax.SAXTransformerFactory; +import javax.xml.transform.sax.TransformerHandler; +import javax.xml.transform.stream.StreamResult; + import com.beust.jcommander.JCommander; import com.beust.jcommander.Parameter; import org.dmg.pmml.PMML; +import org.dmg.pmml.Version; import org.jpmml.converter.Application; +import org.jpmml.model.SAXUtil; +import org.jpmml.model.filters.ExportFilter; import org.jpmml.model.metro.MetroJAXBUtil; +import org.jpmml.model.visitors.VersionInspector; import org.jpmml.python.PickleUtil; import org.jpmml.python.Storage; import org.jpmml.python.StorageUtil; import org.jpmml.sklearn.Encodable; import org.jpmml.sklearn.EncodableUtil; +import org.jpmml.sklearn.SkLearnException; import org.jpmml.sklearn.SkLearnUtil; +import org.xml.sax.InputSource; public class Main extends Application { @@ -48,6 +61,12 @@ public class Main extends Application { ) private File outputFile = null; + @Parameter ( + names = {"--pmml-schema", "--schema"}, + converter = VersionConverter.class + ) + private Version version = null; + static public void main(String... args) throws Exception { @@ -79,6 +98,21 @@ private void run() throws Exception { PMML pmml = encodable.encodePMML(); + if(this.version != null && this.version.compareTo(Version.PMML_4_4) < 0){ + VersionInspector versionInspector = new VersionInspector(); + versionInspector.applyTo(pmml); + + Version minVersion = versionInspector.getMinimum(); + if(minVersion.compareTo(this.version) > 0){ + throw new SkLearnException("The generated markup requires PMML schema version " + minVersion.getVersion() + " or newer"); + } + + Version maxVersion = versionInspector.getMaximum(); + if(maxVersion.compareTo(this.version) < 0){ + throw new SkLearnException("The generated markup requires PMML schema version " + maxVersion.getVersion() + " or older"); + } + } // End if + if(!this.outputFile.exists()){ File absoluteOutputFile = this.outputFile.getAbsoluteFile(); @@ -88,8 +122,34 @@ private void run() throws Exception { } } - try(OutputStream os = new FileOutputStream(this.outputFile)){ - MetroJAXBUtil.marshalPMML(pmml, os); + if(this.version != null && this.version.compareTo(Version.PMML_4_4) < 0){ + File tempFile = File.createTempFile("sklearn2pmml-", ".pmml"); + + try(OutputStream os = new FileOutputStream(tempFile)){ + MetroJAXBUtil.marshalPMML(pmml, os); + } + + SAXTransformerFactory transformerFactory = (SAXTransformerFactory)TransformerFactory.newInstance(); + + try(OutputStream os = new FileOutputStream(this.outputFile)){ + TransformerHandler transformerHandler = transformerFactory.newTransformerHandler(); + transformerHandler.setResult(new StreamResult(os)); + + ExportFilter exportFilter = new ExportFilter(SAXUtil.createXMLReader(), this.version); + exportFilter.setContentHandler(transformerHandler); + + try(InputStream is = new FileInputStream(tempFile)){ + exportFilter.parse(new InputSource(is)); + } + } + + tempFile.delete(); + } else + + { + try(OutputStream os = new FileOutputStream(this.outputFile)){ + MetroJAXBUtil.marshalPMML(pmml, os); + } } } diff --git a/src/main/java/com/sklearn2pmml/VersionConverter.java b/src/main/java/com/sklearn2pmml/VersionConverter.java new file mode 100644 index 0000000..0ff016f --- /dev/null +++ b/src/main/java/com/sklearn2pmml/VersionConverter.java @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2024 Villu Ruusmann + * + * This file is part of SkLearn2PMML + * + * SkLearn2PMML is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * SkLearn2PMML is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with SkLearn2PMML. If not, see . + */ +package com.sklearn2pmml; + +import java.util.Objects; + +import com.beust.jcommander.IStringConverter; +import org.dmg.pmml.Version; + +public class VersionConverter implements IStringConverter { + + @Override + public Version convert(String string){ + Version[] versions = Version.values(); + + for(Version version : versions){ + + if(!version.isStandard()){ + continue; + } // End if + + if(Objects.equals(version.getNamespaceURI(), string) || Objects.equals(version.getVersion(), string)){ + return version; + } + } + + throw new IllegalArgumentException(string); + } +} \ No newline at end of file