From cde4b9c86bb64f1e0a22efcb7c37d13cfaee0e4c Mon Sep 17 00:00:00 2001 From: Fan Hong Date: Wed, 8 Feb 2023 11:33:50 +0800 Subject: [PATCH 01/47] Add preprocess for GBT algorithms --- .../apache/flink/ml/common/gbt/DataUtils.java | 38 +++ .../flink/ml/common/gbt/Preprocess.java | 253 ++++++++++++++++++ .../flink/ml/common/gbt/defs/FeatureMeta.java | 148 ++++++++++ .../flink/ml/common/gbt/defs/GbtParams.java | 30 +++ .../flink/ml/common/gbt/defs/TaskType.java | 25 ++ .../flink/ml/common/gbt/DataUtilsTest.java | 39 +++ .../flink/ml/common/gbt/PreprocessTest.java | 218 +++++++++++++++ 7 files changed, 751 insertions(+) create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/DataUtils.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/Preprocess.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/FeatureMeta.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/GbtParams.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/TaskType.java create mode 100644 flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/DataUtilsTest.java create mode 100644 flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/PreprocessTest.java diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/DataUtils.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/DataUtils.java new file mode 100644 index 000000000..3e13e7604 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/DataUtils.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt; + +import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizerModel; + +import java.util.Arrays; + +/** Some data utilities. */ +public class DataUtils { + /** The mapping computation is from {@link KBinsDiscretizerModel}. */ + public static int findBin(double[] binEdges, double v) { + int index = Arrays.binarySearch(binEdges, v); + if (index < 0) { + // Computes the index to insert. + index = -index - 1; + // Puts it in the left bin. + index--; + } + return Math.max(Math.min(index, (binEdges.length - 2)), 0); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/Preprocess.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/Preprocess.java new file mode 100644 index 000000000..dfc5faedc --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/Preprocess.java @@ -0,0 +1,253 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt; + +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.common.gbt.defs.FeatureMeta; +import org.apache.flink.ml.common.gbt.defs.GbtParams; +import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizer; +import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizerModel; +import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizerModelData; +import org.apache.flink.ml.feature.stringindexer.StringIndexer; +import org.apache.flink.ml.feature.stringindexer.StringIndexerModel; +import org.apache.flink.ml.feature.stringindexer.StringIndexerModelData; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.ApiExpression; +import org.apache.flink.table.api.Expressions; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.util.Arrays; + +import static org.apache.flink.table.api.Expressions.$; + +/** + * Preprocesses input data table for gradient boosting trees algorithms. + * + *

Multiple non-vector columns or a single vector column can be specified for preprocessing. + * Values of these column(s) are mapped to integers inplace through discretizer or string indexer, + * and the meta information of column(s) are obtained. + */ +class Preprocess { + + /** + * Maps continuous and categorical columns to integers inplace using quantile discretizer and + * string indexer respectively, and obtains meta information for all columns. + */ + static Tuple2> preprocessCols(Table dataTable, GbtParams p) { + + final String[] relatedCols = ArrayUtils.add(p.featureCols, p.labelCol); + dataTable = + dataTable.select( + Arrays.stream(relatedCols) + .map(Expressions::$) + .toArray(ApiExpression[]::new)); + + // Maps continuous columns to integers, and obtain corresponding discretizer model. + String[] continuousCols = ArrayUtils.removeElements(p.featureCols, p.categoricalCols); + Tuple2> continuousMappedDataAndModelData = + discretizeContinuousCols(dataTable, continuousCols, p.maxBins); + dataTable = continuousMappedDataAndModelData.f0; + DataStream continuousFeatureMeta = + buildContinuousFeatureMeta(continuousMappedDataAndModelData.f1, continuousCols); + + // Maps categorical columns to integers, and obtain string indexer model. + DataStream categoricalFeatureMeta; + if (p.categoricalCols.length > 0) { + String[] mappedCategoricalCols = + Arrays.stream(p.categoricalCols).map(d -> d + "_output").toArray(String[]::new); + StringIndexer stringIndexer = + new StringIndexer() + .setInputCols(p.categoricalCols) + .setOutputCols(mappedCategoricalCols) + .setHandleInvalid("keep"); + StringIndexerModel stringIndexerModel = stringIndexer.fit(dataTable); + dataTable = stringIndexerModel.transform(dataTable)[0]; + + categoricalFeatureMeta = + buildCategoricalFeatureMeta( + StringIndexerModelData.getModelDataStream( + stringIndexerModel.getModelData()[0]), + p.categoricalCols); + } else { + categoricalFeatureMeta = + continuousFeatureMeta + .flatMap((value, out) -> {}) + .returns(TypeInformation.of(FeatureMeta.class)); + } + + // Rename results columns. + ApiExpression[] dropColumnExprs = + Arrays.stream(p.categoricalCols).map(Expressions::$).toArray(ApiExpression[]::new); + ApiExpression[] renameColumnExprs = + Arrays.stream(p.categoricalCols) + .map(d -> $(d + "_output").as(d)) + .toArray(ApiExpression[]::new); + dataTable = dataTable.dropColumns(dropColumnExprs).renameColumns(renameColumnExprs); + + return Tuple2.of(dataTable, continuousFeatureMeta.union(categoricalFeatureMeta)); + } + + /** + * Maps features values in vectors to integers using quantile discretizer, and obtains meta + * information for all features. + */ + static Tuple2> preprocessVecCol(Table dataTable, GbtParams p) { + dataTable = dataTable.select($(p.vectorCol), $(p.labelCol)); + Tuple2> mappedDataAndModelData = + discretizeVectorCol(dataTable, p.vectorCol, p.maxBins); + dataTable = mappedDataAndModelData.f0; + DataStream featureMeta = + buildContinuousFeatureMeta(mappedDataAndModelData.f1, null); + return Tuple2.of(dataTable, featureMeta); + } + + /** Builds {@link FeatureMeta} from {@link StringIndexerModelData}. 
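+ * Each element of the model's {@code stringArrays} yields one categorical feature meta, whose
+ * missing-value bin index equals the number of distinct categories.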
*/ + private static DataStream buildCategoricalFeatureMeta( + DataStream stringIndexerModelData, String[] cols) { + return stringIndexerModelData + .flatMap( + (d, out) -> { + Preconditions.checkArgument(d.stringArrays.length == cols.length); + for (int i = 0; i < cols.length; i += 1) { + out.collect( + FeatureMeta.categorical( + cols[i], + d.stringArrays[i].length, + d.stringArrays[i])); + } + }) + .returns(TypeInformation.of(FeatureMeta.class)); + } + + /** Builds {@link FeatureMeta} from {@link KBinsDiscretizerModelData}. */ + private static DataStream buildContinuousFeatureMeta( + DataStream discretizerModelData, String[] cols) { + return discretizerModelData + .flatMap( + (d, out) -> { + double[][] binEdges = d.binEdges; + for (int i = 0; i < binEdges.length; i += 1) { + String name = (null != cols) ? cols[i] : "_vec_f" + i; + out.collect( + FeatureMeta.continuous( + name, binEdges[i].length - 1, binEdges[i])); + } + }) + .returns(TypeInformation.of(FeatureMeta.class)); + } + + /** Discretizes continuous columns inplace, and obtains quantile discretizer model data. */ + @SuppressWarnings("checkstyle:RegexpSingleline") + private static Tuple2> discretizeContinuousCols( + Table dataTable, String[] continuousCols, int numBins) { + final StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) dataTable).getTableEnvironment(); + final int nCols = continuousCols.length; + + // Merges all continuous columns into a vector columns. + final String vectorCol = "_vec"; + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(dataTable.getResolvedSchema()); + DataStream data = tEnv.toDataStream(dataTable, Row.class); + DataStream dataWithVectors = + data.map( + (row) -> { + double[] values = new double[nCols]; + for (int i = 0; i < nCols; i += 1) { + Number number = row.getFieldAs(continuousCols[i]); + // Null values are represented using `Double.NaN` in `DenseVector`. + values[i] = (null == number) ? Double.NaN : number.doubleValue(); + } + return Row.join(row, Row.of(Vectors.dense(values))); + }, + new RowTypeInfo( + ArrayUtils.add( + inputTypeInfo.getFieldTypes(), + DenseVectorTypeInfo.INSTANCE), + ArrayUtils.add(inputTypeInfo.getFieldNames(), vectorCol))); + + Tuple2> mappedDataAndModelData = + discretizeVectorCol(tEnv.fromDataStream(dataWithVectors), vectorCol, numBins); + DataStream discretized = tEnv.toDataStream(mappedDataAndModelData.f0); + + // Maps the result vector back to multiple continuous columns. + final String[] otherCols = + ArrayUtils.removeElements(inputTypeInfo.getFieldNames(), continuousCols); + final TypeInformation[] otherColTypes = + Arrays.stream(otherCols) + .map(inputTypeInfo::getTypeAt) + .toArray(TypeInformation[]::new); + final TypeInformation[] mappedColTypes = + Arrays.stream(continuousCols).map(d -> Types.INT).toArray(TypeInformation[]::new); + + DataStream mapped = + discretized.map( + (row) -> { + DenseVector vec = row.getFieldAs(vectorCol); + Integer[] ints = + Arrays.stream(vec.values) + .mapToObj(d -> (Integer) ((int) d)) + .toArray(Integer[]::new); + Row result = Row.project(row, otherCols); + for (int i = 0; i < ints.length; i += 1) { + result.setField(continuousCols[i], ints[i]); + } + return result; + }, + new RowTypeInfo( + ArrayUtils.addAll(otherColTypes, mappedColTypes), + ArrayUtils.addAll(otherCols, continuousCols))); + + return Tuple2.of(tEnv.fromDataStream(mapped), mappedDataAndModelData.f1); + } + + /** + * Discretize the vector column inplace using quantile discretizer, and obtains quantile + * discretizer model data.. 
+ */ + private static Tuple2> discretizeVectorCol( + Table dataTable, String vectorCol, int numBins) { + final String outputCol = "_output_col"; + KBinsDiscretizer kBinsDiscretizer = + new KBinsDiscretizer() + .setInputCol(vectorCol) + .setOutputCol(outputCol) + .setStrategy("quantile") + .setNumBins(numBins); + KBinsDiscretizerModel model = kBinsDiscretizer.fit(dataTable); + Table discretizedDataTable = model.transform(dataTable)[0]; + return Tuple2.of( + discretizedDataTable + .dropColumns($(vectorCol)) + .renameColumns($(outputCol).as(vectorCol)), + KBinsDiscretizerModelData.getModelDataStream(model.getModelData()[0])); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/FeatureMeta.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/FeatureMeta.java new file mode 100644 index 000000000..f14caed61 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/FeatureMeta.java @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.defs; + +import org.apache.flink.ml.common.gbt.DataUtils; + +import java.io.Serializable; +import java.util.Arrays; + +/** Stores meta information of a feature. */ +public abstract class FeatureMeta { + + public String name; + public Type type; + // The bin index representing the missing values. + public int missingBin; + + public FeatureMeta() {} + + public FeatureMeta(String name, int missingBin, Type type) { + this.name = name; + this.missingBin = missingBin; + this.type = type; + } + + public static CategoricalFeatureMeta categorical( + String name, int missingBin, String[] categories) { + return new CategoricalFeatureMeta(name, missingBin, categories); + } + + public static ContinuousFeatureMeta continuous(String name, int missingBin, double[] binEdges) { + return new ContinuousFeatureMeta(name, missingBin, binEdges); + } + + /** + * Calculate number of bins used for this feature. + * + * @param useMissing Whether to assign an addition bin for missing values. + * @return The number of bins. + */ + public abstract int numBins(boolean useMissing); + + @Override + public String toString() { + return String.format( + "FeatureMeta{name='%s', type=%s, missingBin=%d}", name, type, missingBin); + } + + /** Indicates the feature type. */ + public enum Type implements Serializable { + CATEGORICAL, + CONTINUOUS + } + + /** Stores meta information for a categorical feature. */ + public static class CategoricalFeatureMeta extends FeatureMeta { + // Stores ordered categorical values. 
+ public String[] categories; + + public CategoricalFeatureMeta() {} + + public CategoricalFeatureMeta(String name, int missingBin, String[] categories) { + super(name, missingBin, Type.CATEGORICAL); + this.categories = categories; + } + + @Override + public int numBins(boolean useMissing) { + return useMissing ? categories.length + 1 : categories.length; + } + + @Override + public boolean equals(Object obj) { + if (obj == this) { + return true; + } + return obj instanceof CategoricalFeatureMeta + && this.type.equals(((CategoricalFeatureMeta) obj).type) + && (this.name.equals(((CategoricalFeatureMeta) obj).name)) + && (this.missingBin == ((CategoricalFeatureMeta) obj).missingBin) + && (Arrays.equals(this.categories, ((CategoricalFeatureMeta) obj).categories)); + } + + @Override + public String toString() { + return String.format( + "CategoricalFeatureMeta{categories=%s} %s", + Arrays.toString(categories), super.toString()); + } + } + + /** Stores meta information for a continuous feature. */ + public static class ContinuousFeatureMeta extends FeatureMeta { + // Stores the edges of bins. + public double[] binEdges; + // The bin index for value 0. + public int zeroBin; + + public ContinuousFeatureMeta() {} + + public ContinuousFeatureMeta(String name, int missingBin, double[] binEdges) { + super(name, missingBin, Type.CONTINUOUS); + this.binEdges = binEdges; + this.zeroBin = DataUtils.findBin(binEdges, 0.); + } + + @Override + public int numBins(boolean useMissing) { + return useMissing ? binEdges.length : binEdges.length - 1; + } + + @Override + public boolean equals(Object obj) { + if (obj == this) { + return true; + } + return obj instanceof ContinuousFeatureMeta + && this.type.equals(((ContinuousFeatureMeta) obj).type) + && (this.name.equals(((ContinuousFeatureMeta) obj).name)) + && (this.missingBin == ((ContinuousFeatureMeta) obj).missingBin) + && (Arrays.equals(this.binEdges, ((ContinuousFeatureMeta) obj).binEdges)) + && (this.zeroBin == ((ContinuousFeatureMeta) obj).zeroBin); + } + + @Override + public String toString() { + return String.format( + "ContinuousFeatureMeta{binEdges=%s, zeroBin=%d} %s", + Arrays.toString(binEdges), zeroBin, super.toString()); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/GbtParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/GbtParams.java new file mode 100644 index 000000000..d5bf5d489 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/GbtParams.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.defs; + +import java.io.Serializable; + +/** Internal parameters of a gradient boosting trees algorithm. 
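+ *
+ * <p>A minimal construction sketch (field values are illustrative and mirror the tests):
+ *
+ * <pre>{@code
+ * GbtParams p = new GbtParams();
+ * p.featureCols = new String[] {"f0", "f1", "f2"};
+ * p.categoricalCols = new String[] {"f2"};
+ * p.labelCol = "label";
+ * p.maxBins = 3;
+ * }</pre>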
*/ +public class GbtParams implements Serializable { + public String[] featureCols; + public String vectorCol; + public String labelCol; + public String[] categoricalCols; + public int maxBins; +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/TaskType.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/TaskType.java new file mode 100644 index 000000000..3d375823e --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/TaskType.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.defs; + +/** Indicates the type of task. */ +public enum TaskType { + CLASSIFICATION, + REGRESSION, +} diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/DataUtilsTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/DataUtilsTest.java new file mode 100644 index 000000000..567a68651 --- /dev/null +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/DataUtilsTest.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt; + +import org.junit.Assert; +import org.junit.Test; + +/** Test {@link DataUtils}. 
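+ *
+ * <p>{@link DataUtils#findBin} maps a value to the index of the half-open interval
+ * {@code [binEdges[i], binEdges[i + 1])} containing it, clamped to
+ * {@code [0, binEdges.length - 2]}. For edges {@code {1., 2., 3., 4.}}, values {@code .5} and
+ * {@code 1.5} fall in bin 0, while {@code 4.5} is clamped to the last bin, i.e., bin 2.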
*/ +public class DataUtilsTest { + @Test + public void testFindBin() { + double[] binEdges = new double[] {1., 2., 3., 4.}; + for (int i = 0; i < binEdges.length; i += 1) { + Assert.assertEquals( + Math.min(binEdges.length - 2, i), DataUtils.findBin(binEdges, binEdges[i])); + } + double[] values = new double[] {.5, 1.5, 2.5, 3.5, 4.5}; + int[] bins = new int[] {0, 0, 1, 2, 2}; + for (int i = 0; i < values.length; i += 1) { + Assert.assertEquals(bins[i], DataUtils.findBin(binEdges, values[i])); + } + } +} diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/PreprocessTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/PreprocessTest.java new file mode 100644 index 000000000..a9e3d6665 --- /dev/null +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/PreprocessTest.java @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt; + +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.common.gbt.defs.FeatureMeta; +import org.apache.flink.ml.common.gbt.defs.GbtParams; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.util.TestUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.sink.SinkFunction; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.test.util.AbstractTestBase; +import org.apache.flink.test.util.TestBaseUtils; +import org.apache.flink.testutils.junit.SharedObjects; +import org.apache.flink.testutils.junit.SharedReference; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; +import java.util.concurrent.ArrayBlockingQueue; + +/** Tests {@link Preprocess}. 
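+ *
+ * <p>With {@code maxBins = 3}, continuous features are expected to map to bin indices in
+ * {@code [0, 3)} with bin 3 reserved for missing values; the categorical feature maps to string
+ * indexer indices with one extra index for missing or unseen values.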
*/ +public class PreprocessTest extends AbstractTestBase { + + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + @Rule public final SharedObjects sharedObjects = SharedObjects.create(); + + private static final List inputDataRows = + Arrays.asList( + Row.of(1.2, 2, null, 40., Vectors.dense(1.2, 2, Double.NaN)), + Row.of(2.3, 3, "b", 40., Vectors.dense(2.3, 3, 2.)), + Row.of(3.4, 4, "c", 40., Vectors.dense(3.4, 4, 3.)), + Row.of(4.5, 5, "a", 40., Vectors.dense(4.5, 5, 1.)), + Row.of(5.6, 2, "b", 40., Vectors.dense(5.6, 2, 2.)), + Row.of(null, 3, "c", 41., Vectors.dense(Double.NaN, 3, 3.)), + Row.of(12.8, 4, "e", 41., Vectors.dense(12.8, 4, 5.)), + Row.of(13.9, 2, "b", 41., Vectors.dense(13.9, 2, 2.)), + Row.of(14.1, 4, "a", 41., Vectors.dense(14.1, 4, 1.)), + Row.of(15.3, 1, "d", 41., Vectors.dense(15.3, 1, 4.))); + + private StreamTableEnvironment tEnv; + private Table inputTable; + private SharedReference> actualMeta; + + // private static void verifyPredictionResult(Table output, List expected) throws + // Exception { + // StreamTableEnvironment tEnv = + // (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment(); + // //noinspection unchecked + // List results = + // IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); + // final double delta = 1e-3; + // final Comparator denseVectorComparator = + // new TestUtils.DenseVectorComparatorWithDelta(delta); + // final Comparator comparator = + // Comparator.comparing(d -> d.getFieldAs(0)) + // .thenComparing(d -> d.getFieldAs(1), denseVectorComparator) + // .thenComparing(d -> d.getFieldAs(2), denseVectorComparator); + // TestUtils.compareResultCollectionsWithComparator(expected, results, comparator); + // } + + @Before + public void before() { + StreamExecutionEnvironment env = TestUtils.getExecutionEnvironment(); + tEnv = StreamTableEnvironment.create(env); + + inputTable = + tEnv.fromDataStream( + env.fromCollection( + inputDataRows, + new RowTypeInfo( + new TypeInformation[] { + Types.DOUBLE, + Types.INT, + Types.STRING, + Types.DOUBLE, + DenseVectorTypeInfo.INSTANCE + }, + new String[] {"f0", "f1", "f2", "label", "vec"}))); + + actualMeta = sharedObjects.add(new ArrayBlockingQueue<>(8)); + } + + private static class CollectSink implements SinkFunction { + private final SharedReference> q; + + public CollectSink(SharedReference> q) { + this.q = q; + } + + @Override + public void invoke(T value, Context context) { + q.get().add(value); + } + } + + @Test + public void testPreprocessCols() throws Exception { + GbtParams p = new GbtParams(); + p.featureCols = new String[] {"f0", "f1", "f2"}; + p.categoricalCols = new String[] {"f2"}; + p.labelCol = "label"; + p.maxBins = 3; + Tuple2> results = Preprocess.preprocessCols(inputTable, p); + + actualMeta.get().clear(); + results.f1.addSink(new CollectSink<>(actualMeta)); + //noinspection unchecked + List preprocessedRows = + IteratorUtils.toList(tEnv.toDataStream(results.f0).executeAndCollect()); + + // TODO: correct `binEdges` of feature `f0` after FLINK-30734 resolved. 
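+        // `missingBin` is expected to be 3 (= maxBins) for the continuous features, and 5 (= the
+        // number of categories, since the string indexer keeps invalid values) for `f2`.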
+ List expectedMeta = + Arrays.asList( + FeatureMeta.continuous("f0", 3, new double[] {1.2, 4.5, 13.9, Double.NaN}), + FeatureMeta.continuous("f1", 3, new double[] {1.0, 2.0, 4.0, 5.0}), + FeatureMeta.categorical("f2", 5, new String[] {"a", "b", "c", "d", "e"})); + + List expectedPreprocessedRows = + Arrays.asList( + Row.of(40.0, 0, 1, 5.0), + Row.of(40.0, 0, 1, 1.0), + Row.of(40.0, 0, 2, 2.0), + Row.of(40.0, 1, 2, 0.0), + Row.of(40.0, 1, 1, 1.0), + Row.of(41.0, 2, 1, 2.0), + Row.of(41.0, 1, 2, 4.0), + Row.of(41.0, 2, 1, 1.0), + Row.of(41.0, 2, 2, 0.0), + Row.of(41.0, 2, 0, 3.0)); + Comparator preprocessedRowComparator = + Comparator.comparing(d -> d.getFieldAs(0)) + .thenComparing(d -> d.getFieldAs(1)) + .thenComparing(d -> d.getFieldAs(2)) + .thenComparing(d -> d.getFieldAs(3)); + + TestBaseUtils.compareResultCollections( + expectedPreprocessedRows, preprocessedRows, preprocessedRowComparator); + TestBaseUtils.compareResultCollections( + expectedMeta, new ArrayList<>(actualMeta.get()), Comparator.comparing(d -> d.name)); + } + + @Test + public void testPreprocessVectorCol() throws Exception { + GbtParams p = new GbtParams(); + p.vectorCol = "vec"; + p.labelCol = "label"; + p.maxBins = 3; + Tuple2> results = Preprocess.preprocessVecCol(inputTable, p); + + actualMeta.get().clear(); + results.f1.addSink(new CollectSink<>(actualMeta)); + //noinspection unchecked + List preprocessedRows = + IteratorUtils.toList(tEnv.toDataStream(results.f0).executeAndCollect()); + + // TODO: correct `binEdges` of feature `_vec_f0` and `_vec_f2` after FLINK-30734 resolved. + List expectedMeta = + Arrays.asList( + FeatureMeta.continuous( + "_vec_f0", 3, new double[] {1.2, 4.5, 13.9, Double.NaN}), + FeatureMeta.continuous("_vec_f1", 3, new double[] {1.0, 2.0, 4.0, 5.0}), + FeatureMeta.continuous( + "_vec_f2", 3, new double[] {1.0, 2.0, 3.0, Double.NaN})); + List expectedPreprocessedRows = + Arrays.asList( + Row.of(40.0, Vectors.dense(0, 1, 2.0)), + Row.of(40.0, Vectors.dense(0, 1, 1.0)), + Row.of(40.0, Vectors.dense(0, 2, 2.0)), + Row.of(40.0, Vectors.dense(1, 2, 0.0)), + Row.of(40.0, Vectors.dense(1, 1, 1.0)), + Row.of(41.0, Vectors.dense(2, 1, 2.0)), + Row.of(41.0, Vectors.dense(1, 2, 2.0)), + Row.of(41.0, Vectors.dense(2, 1, 1.0)), + Row.of(41.0, Vectors.dense(2, 2, 0.0)), + Row.of(41.0, Vectors.dense(2, 0, 2.0))); + + Comparator preprocessedRowComparator = + Comparator.comparing(d -> d.getFieldAs(0)) + .thenComparing(d -> d.getFieldAs(1), TestUtils::compare); + + TestBaseUtils.compareResultCollections( + expectedPreprocessedRows, preprocessedRows, preprocessedRowComparator); + TestBaseUtils.compareResultCollections( + expectedMeta, new ArrayList<>(actualMeta.get()), Comparator.comparing(d -> d.name)); + } +} From 7d561c37899b523fdd28ddc7d407a14a1955bf36 Mon Sep 17 00:00:00 2001 From: Fan Hong Date: Wed, 8 Feb 2023 11:33:50 +0800 Subject: [PATCH 02/47] Add training and prediction --- flink-ml-lib/pom.xml | 10 + .../ml/common/gbt/BoostIterationBody.java | 210 +++++++++++++++++ .../apache/flink/ml/common/gbt/DataUtils.java | 30 +++ .../flink/ml/common/gbt/GBTModelData.java | 176 ++++++++++++++ .../apache/flink/ml/common/gbt/GBTRunner.java | 173 ++++++++++++++ .../datastorage/IterationSharedStorage.java | 136 +++++++++++ .../ml/common/gbt/defs/BinnedInstance.java | 56 +++++ .../flink/ml/common/gbt/defs/Distributor.java | 77 ++++++ .../flink/ml/common/gbt/defs/GbtParams.java | 25 ++ .../ml/common/gbt/defs/HessianImpurity.java | 107 +++++++++ .../flink/ml/common/gbt/defs/Histogram.java | 78 +++++++ 
.../flink/ml/common/gbt/defs/Impurity.java | 101 ++++++++ .../ml/common/gbt/defs/LearningNode.java | 45 ++++ .../flink/ml/common/gbt/defs/LocalState.java | 85 +++++++ .../apache/flink/ml/common/gbt/defs/Node.java | 30 +++ .../ml/common/gbt/defs/PredGradHess.java | 40 ++++ .../flink/ml/common/gbt/defs/Slice.java | 40 ++++ .../flink/ml/common/gbt/defs/Split.java | 167 +++++++++++++ .../flink/ml/common/gbt/defs/Splits.java | 81 +++++++ .../ml/common/gbt/loss/AbsoluteError.java | 46 ++++ .../flink/ml/common/gbt/loss/LogLoss.java | 51 ++++ .../apache/flink/ml/common/gbt/loss/Loss.java | 52 +++++ .../ml/common/gbt/loss/SquaredError.java | 46 ++++ .../CacheDataCalcLocalHistsOperator.java | 219 ++++++++++++++++++ .../operators/CalcLocalSplitsOperator.java | 109 +++++++++ .../ml/common/gbt/operators/HistBuilder.java | 186 +++++++++++++++ .../operators/HistogramAggregateFunction.java | 55 +++++ .../common/gbt/operators/InstanceUpdater.java | 91 ++++++++ .../gbt/operators/LocalStateInitializer.java | 166 +++++++++++++ .../ml/common/gbt/operators/NodeSplitter.java | 136 +++++++++++ .../gbt/operators/PostSplitsOperator.java | 156 +++++++++++++ .../ml/common/gbt/operators/SplitFinder.java | 101 ++++++++ .../operators/SplitsAggregateFunction.java | 54 +++++ .../common/gbt/operators/TreeInitializer.java | 73 ++++++ .../splitter/CategoricalFeatureSplitter.java | 107 +++++++++ .../splitter/ContinuousFeatureSplitter.java | 70 ++++++ .../common/gbt/splitter/FeatureSplitter.java | 52 +++++ .../splitter/HistogramFeatureSplitter.java | 140 +++++++++++ .../flink/ml/common/gbt/GBTRunnerTest.java | 136 +++++++++++ 39 files changed, 3713 insertions(+) create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BoostIterationBody.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelData.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTRunner.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/datastorage/IterationSharedStorage.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/BinnedInstance.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Distributor.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/HessianImpurity.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Histogram.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Impurity.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/LearningNode.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/LocalState.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Node.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/PredGradHess.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Slice.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Split.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Splits.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/loss/AbsoluteError.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/loss/LogLoss.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/loss/Loss.java create mode 100644 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/loss/SquaredError.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CalcLocalSplitsOperator.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistBuilder.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistogramAggregateFunction.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/InstanceUpdater.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/LocalStateInitializer.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/NodeSplitter.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/PostSplitsOperator.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SplitFinder.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SplitsAggregateFunction.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TreeInitializer.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/CategoricalFeatureSplitter.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/ContinuousFeatureSplitter.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/FeatureSplitter.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/HistogramFeatureSplitter.java create mode 100644 flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/GBTRunnerTest.java diff --git a/flink-ml-lib/pom.xml b/flink-ml-lib/pom.xml index 1773fc7d4..2c5d86c75 100644 --- a/flink-ml-lib/pom.xml +++ b/flink-ml-lib/pom.xml @@ -138,6 +138,16 @@ under the License. test test-jar + + org.eclipse.collections + eclipse-collections-api + 11.1.0 + + + org.eclipse.collections + eclipse-collections + 11.1.0 + diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BoostIterationBody.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BoostIterationBody.java new file mode 100644 index 000000000..d3cc62355 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BoostIterationBody.java @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.gbt; + +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.common.functions.Partitioner; +import org.apache.flink.api.common.typeinfo.TypeHint; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.iteration.DataStreamList; +import org.apache.flink.iteration.IterationBody; +import org.apache.flink.iteration.IterationBodyResult; +import org.apache.flink.iteration.IterationID; +import org.apache.flink.ml.common.gbt.defs.GbtParams; +import org.apache.flink.ml.common.gbt.defs.Histogram; +import org.apache.flink.ml.common.gbt.defs.LocalState; +import org.apache.flink.ml.common.gbt.defs.Splits; +import org.apache.flink.ml.common.gbt.operators.CacheDataCalcLocalHistsOperator; +import org.apache.flink.ml.common.gbt.operators.CalcLocalSplitsOperator; +import org.apache.flink.ml.common.gbt.operators.HistogramAggregateFunction; +import org.apache.flink.ml.common.gbt.operators.PostSplitsOperator; +import org.apache.flink.ml.common.gbt.operators.SplitsAggregateFunction; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.OutputTag; + +import org.apache.commons.lang3.ArrayUtils; + +/** + * Implements iteration body for boosting algorithms. This implementation uses horizontal partition + * of data and row-store storage of instances. + */ +class BoostIterationBody implements IterationBody { + private final IterationID iterationID; + private final GbtParams gbtParams; + + public BoostIterationBody(IterationID iterationID, GbtParams gbtParams) { + this.iterationID = iterationID; + this.gbtParams = gbtParams; + } + + @Override + public IterationBodyResult process(DataStreamList variableStreams, DataStreamList dataStreams) { + DataStream data = dataStreams.get(0); + DataStream localState = variableStreams.get(0); + + final OutputTag stateOutputTag = + new OutputTag<>("state", TypeInformation.of(LocalState.class)); + + final OutputTag finalStateOutputTag = + new OutputTag<>("final_state", TypeInformation.of(LocalState.class)); + + /** + * In the iteration, some data needs to be shared between subtasks of different operators + * within one machine. We use {@link IterationSharedStorage} with co-location mechanism to + * achieve such purpose. The data is stored in JVM static region, and is accessed through + * string keys from different operator subtasks. Note the first operator to put the data is + * the owner of the data, and only the owner can update or delete the data. + * + *

To be specific, in the gradient boosting trees algorithm, there are three types of shared
+ * data (a usage sketch follows the list):
+ *
+ * <ul>
+ *   <li>binned instances together with their corresponding predictions, gradients, and hessians;
+ *   <li>shuffled instance indices, generated when initializing every new tree;
+ *   <li>swapped instance indices, updated after splitting the nodes of each layer.
+ * </ul>
+ *

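+ * <p>A minimal sketch of the read/write pattern (illustrative only; {@code iterationID},
+ * {@code subtaskId}, and {@code operatorID} come from the runtime context):
+ *
+ * <pre>{@code
+ * // The owner operator creates the writer and is the only one allowed to update the data.
+ * IterationSharedStorage.Writer<int[]> writer =
+ *         IterationSharedStorage.getWriter(
+ *                 iterationID, subtaskId, "shuffled_indices", operatorID, new int[0]);
+ * writer.set(shuffledIndices);
+ *
+ * // Subtasks of other operators on the same machine read the data through the same key.
+ * IterationSharedStorage.Reader<int[]> reader =
+ *         IterationSharedStorage.getReader(iterationID, subtaskId, "shuffled_indices");
+ * int[] indices = reader.get();
+ * }</pre>
+ *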
+ */ + final String sharedInstancesKey = "instances"; + final String sharedPredGradHessKey = "preds_grads_hessians"; + final String sharedShuffledIndicesKey = "shuffled_indices"; + final String sharedSwappedIndicesKey = "swapped_indices"; + + final String coLocationKey = "boosting"; + + // In 1st round, cache all data. For all rounds calculate local histogram based on + // current tree layer. + SingleOutputStreamOperator localHists = + data.connect(localState) + .transform( + "CacheDataCalcLocalHists", + TypeInformation.of(Histogram.class), + new CacheDataCalcLocalHistsOperator( + gbtParams, + iterationID, + sharedInstancesKey, + sharedPredGradHessKey, + sharedShuffledIndicesKey, + sharedSwappedIndicesKey, + stateOutputTag)); + localHists.getTransformation().setCoLocationGroupKey("coLocationKey"); + DataStream modelData = localHists.getSideOutput(stateOutputTag); + + DataStream globalHists = scatterReduceHistograms(localHists); + + SingleOutputStreamOperator localSplits = + modelData + .connect(globalHists) + .transform( + "CalcLocalSplits", + TypeInformation.of(Splits.class), + new CalcLocalSplitsOperator(stateOutputTag)); + localHists.getTransformation().setCoLocationGroupKey(coLocationKey); + DataStream globalSplits = + localSplits.broadcast().flatMap(new SplitsAggregateFunction()); + + SingleOutputStreamOperator updatedModelData = + modelData + .connect(globalSplits.broadcast()) + .transform( + "PostSplits", + TypeInformation.of(LocalState.class), + new PostSplitsOperator( + iterationID, + sharedInstancesKey, + sharedPredGradHessKey, + sharedShuffledIndicesKey, + sharedSwappedIndicesKey, + finalStateOutputTag)); + updatedModelData.getTransformation().setCoLocationGroupKey(coLocationKey); + + DataStream termination = + updatedModelData.flatMap( + new FlatMapFunction() { + @Override + public void flatMap(LocalState value, Collector out) { + LocalState.Dynamics dynamics = value.dynamics; + boolean terminated = + !dynamics.inWeakLearner + && dynamics.roots.size() + == value.statics.params.maxIter; + // TODO: add validation error rate + if (!terminated) { + out.collect(0); + } + } + }); + termination.getTransformation().setCoLocationGroupKey(coLocationKey); + + return new IterationBodyResult( + DataStreamList.of(updatedModelData), + DataStreamList.of(updatedModelData.getSideOutput(finalStateOutputTag)), + termination); + } + + public DataStream scatterReduceHistograms(DataStream localHists) { + return localHists + .flatMap( + (FlatMapFunction>) + (value, out) -> { + double[] hists = value.hists; + int[] recvcnts = value.recvcnts; + int p = 0; + for (int i = 0; i < recvcnts.length; i += 1) { + out.collect( + Tuple2.of( + i, + new Histogram( + value.subtaskId, + ArrayUtils.subarray( + hists, p, p + recvcnts[i]), + recvcnts))); + p += recvcnts[i]; + } + }) + .returns(new TypeHint>() {}) + .partitionCustom( + new Partitioner() { + @Override + public int partition(Integer key, int numPartitions) { + return key; + } + }, + new KeySelector, Integer>() { + @Override + public Integer getKey(Tuple2 value) + throws Exception { + return value.f0; + } + }) + .map(d -> d.f1) + .flatMap(new HistogramAggregateFunction()); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/DataUtils.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/DataUtils.java index 3e13e7604..32331cd31 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/DataUtils.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/DataUtils.java @@ -21,9 +21,39 @@ import 
org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizerModel; import java.util.Arrays; +import java.util.Random; /** Some data utilities. */ public class DataUtils { + + // Stores 4 values for one histogram bin, i.e., gradient, hessian, weight, and count. + public static final int BIN_SIZE = 4; + + public static void shuffle(int[] array, Random random) { + int n = array.length; + for (int i = 0; i < n; i += 1) { + int index = i + random.nextInt(n - i); + int tmp = array[index]; + array[index] = array[i]; + array[i] = tmp; + } + } + + public static int[] sample(int[] values, int numSamples, Random random) { + int n = values.length; + int[] sampled = new int[numSamples]; + + for (int i = 0; i < numSamples; i += 1) { + int index = i + random.nextInt(n - i); + sampled[i] = values[index]; + + int temp = values[i]; + values[i] = values[index]; + values[index] = temp; + } + return sampled; + } + /** The mapping computation is from {@link KBinsDiscretizerModel}. */ public static int findBin(double[] binEdges, double v) { int index = Arrays.binarySearch(binEdges, v); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelData.java new file mode 100644 index 000000000..910588b97 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelData.java @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt; + +import org.apache.flink.ml.common.gbt.defs.FeatureMeta; +import org.apache.flink.ml.common.gbt.defs.LocalState; +import org.apache.flink.ml.common.gbt.defs.Node; +import org.apache.flink.ml.feature.stringindexer.StringIndexerModel; +import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; + +import org.eclipse.collections.impl.map.mutable.primitive.IntDoubleHashMap; +import org.eclipse.collections.impl.map.mutable.primitive.IntObjectHashMap; +import org.eclipse.collections.impl.map.mutable.primitive.ObjectIntHashMap; + +import java.util.BitSet; +import java.util.List; + +/** + * Model data of gradient boosting trees. + * + *

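+ * <p>Prediction with this model adds the {@code stepSize}-weighted leaf predictions of all trees
+ * on top of {@code prior}; see {@code predictRaw}.
+ *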
This class also provides methods to convert model data from Table to Datastream, and classes + * to save/load model data. + */ +public class GBTModelData { + + public String type; + public boolean isInputVector; + + public double prior; + public double stepSize; + + public List roots; + public IntObjectHashMap> categoryToIdMaps; + public IntObjectHashMap featureIdToBinEdges; + public BitSet isCategorical; + + public GBTModelData() {} + + public GBTModelData( + String type, + boolean isInputVector, + double prior, + double stepSize, + List roots, + IntObjectHashMap> categoryToIdMaps, + IntObjectHashMap featureIdToBinEdges, + BitSet isCategorical) { + this.type = type; + this.isInputVector = isInputVector; + this.prior = prior; + this.stepSize = stepSize; + this.roots = roots; + this.categoryToIdMaps = categoryToIdMaps; + this.featureIdToBinEdges = featureIdToBinEdges; + this.isCategorical = isCategorical; + } + + public static GBTModelData fromLocalState(LocalState state) { + IntObjectHashMap> categoryToIdMaps = new IntObjectHashMap<>(); + IntObjectHashMap featureIdToBinEdges = new IntObjectHashMap<>(); + BitSet isCategorical = new BitSet(); + + FeatureMeta[] featureMetas = state.statics.featureMetas; + for (int k = 0; k < featureMetas.length; k += 1) { + FeatureMeta featureMeta = featureMetas[k]; + if (featureMeta instanceof FeatureMeta.CategoricalFeatureMeta) { + String[] categories = ((FeatureMeta.CategoricalFeatureMeta) featureMeta).categories; + ObjectIntHashMap categoryToId = new ObjectIntHashMap<>(); + for (int i = 0; i < categories.length; i += 1) { + categoryToId.put(categories[i], i); + } + categoryToIdMaps.put(k, categoryToId); + isCategorical.set(k); + } else if (featureMeta instanceof FeatureMeta.ContinuousFeatureMeta) { + featureIdToBinEdges.put( + k, ((FeatureMeta.ContinuousFeatureMeta) featureMeta).binEdges); + } + } + return new GBTModelData( + state.statics.params.taskType.name(), + state.statics.params.isInputVector, + state.statics.prior, + state.statics.params.stepSize, + state.dynamics.roots, + categoryToIdMaps, + featureIdToBinEdges, + isCategorical); + } + + public static DataStream getModelDataStream(Table modelDataTable) { + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) modelDataTable).getTableEnvironment(); + return tEnv.toDataStream(modelDataTable).map(x -> x.getFieldAs(0)); + } + + /** The mapping computation is from {@link StringIndexerModel}. */ + private static int mapCategoricalFeature(ObjectIntHashMap categoryToId, Object v) { + String s; + if (v instanceof String) { + s = (String) v; + } else if (v instanceof Number) { + s = String.valueOf(v); + } else if (null == v) { + s = null; + } else { + throw new RuntimeException("Categorical column only supports string and numeric type."); + } + return categoryToId.getIfAbsent(s, categoryToId.size()); + } + + public IntDoubleHashMap rowToFeatures(Row row, String[] featureCols, String vectorCol) { + IntDoubleHashMap features = new IntDoubleHashMap(); + if (isInputVector) { + Vector vec = row.getFieldAs(vectorCol); + SparseVector sv = vec.toSparse(); + for (int i = 0; i < sv.indices.length; i += 1) { + features.put(sv.indices[i], sv.values[i]); + } + } else { + for (int i = 0; i < featureCols.length; i += 1) { + Object obj = row.getField(featureCols[i]); + double v; + if (isCategorical.get(i)) { + v = mapCategoricalFeature(categoryToIdMaps.get(i), obj); + } else { + Number number = (Number) obj; + v = (null == number) ? 
Double.NaN : number.doubleValue(); + } + features.put(i, v); + } + } + return features; + } + + public double predictRaw(IntDoubleHashMap rawFeatures) { + double v = prior; + for (Node root : roots) { + Node node = root; + while (!node.isLeaf) { + boolean goLeft = node.split.shouldGoLeft(rawFeatures); + node = goLeft ? node.left : node.right; + } + v += stepSize * node.split.prediction; + } + return v; + } + + @Override + public String toString() { + return String.format( + "GBTModelData{type=%s, prior=%s, roots=%s, categoryToIdMaps=%s, featureIdToBinEdges=%s, isCategorical=%s}", + type, prior, roots, categoryToIdMaps, featureIdToBinEdges, isCategorical); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTRunner.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTRunner.java new file mode 100644 index 000000000..217f836a2 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTRunner.java @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt; + +import org.apache.flink.api.common.functions.AggregateFunction; +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.iteration.DataStreamList; +import org.apache.flink.iteration.IterationConfig; +import org.apache.flink.iteration.IterationID; +import org.apache.flink.iteration.Iterations; +import org.apache.flink.iteration.ReplayableDataStreamList; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.common.gbt.defs.FeatureMeta; +import org.apache.flink.ml.common.gbt.defs.GbtParams; +import org.apache.flink.ml.common.gbt.defs.LocalState; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; + +import org.apache.commons.lang3.ArrayUtils; + +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.Map; + +/** Runs a gradient boosting trees implementation. */ +public class GBTRunner { + + /** Trains a gradient boosting tree model with given data and parameters. */ + static DataStream train(Table dataTable, GbtParams p) { + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) dataTable).getTableEnvironment(); + Tuple2> preprocessResult = + p.isInputVector + ? 
Preprocess.preprocessVecCol(dataTable, p) + : Preprocess.preprocessCols(dataTable, p); + dataTable = preprocessResult.f0; + DataStream featureMeta = preprocessResult.f1; + + DataStream data = tEnv.toDataStream(dataTable); + DataStream> labelSumCount = + DataStreamUtils.aggregate(data, new LabelSumCountFunction(p.labelCol)); + return boost(dataTable, p, featureMeta, labelSumCount); + } + + private static DataStream boost( + Table dataTable, + GbtParams p, + DataStream featureMeta, + DataStream> labelSumCount) { + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) dataTable).getTableEnvironment(); + + final String featureMetaBcName = "featureMeta"; + final String labelSumCountBcName = "labelSumCount"; + Map> bcMap = new HashMap<>(); + bcMap.put(featureMetaBcName, featureMeta); + bcMap.put(labelSumCountBcName, labelSumCount); + + DataStream initStates = + BroadcastUtils.withBroadcastStream( + Collections.singletonList( + tEnv.toDataStream(tEnv.fromValues(0), Integer.class)), + bcMap, + (inputs) -> { + //noinspection unchecked + DataStream input = (DataStream) (inputs.get(0)); + return input.map( + new InitLocalStateFunction( + featureMetaBcName, labelSumCountBcName, p)); + }); + + DataStream data = tEnv.toDataStream(dataTable); + final IterationID iterationID = new IterationID(); + DataStreamList dataStreamList = + Iterations.iterateBoundedStreamsUntilTermination( + DataStreamList.of(initStates.broadcast()), + ReplayableDataStreamList.notReplay(data, featureMeta), + IterationConfig.newBuilder() + .setOperatorLifeCycle(IterationConfig.OperatorLifeCycle.ALL_ROUND) + .build(), + new BoostIterationBody(iterationID, p)); + DataStream state = dataStreamList.get(0); + return state.map(GBTModelData::fromLocalState); + } + + private static class InitLocalStateFunction extends RichMapFunction { + private final String featureMetaBcName; + private final String labelSumCountBcName; + private final GbtParams p; + + private InitLocalStateFunction( + String featureMetaBcName, String labelSumCountBcName, GbtParams p) { + this.featureMetaBcName = featureMetaBcName; + this.labelSumCountBcName = labelSumCountBcName; + this.p = p; + } + + @Override + public LocalState map(Integer value) { + LocalState.Statics statics = new LocalState.Statics(); + statics.params = p; + statics.featureMetas = + getRuntimeContext() + .getBroadcastVariable(featureMetaBcName) + .toArray(new FeatureMeta[0]); + if (!statics.params.isInputVector) { + Arrays.sort( + statics.featureMetas, + Comparator.comparing(d -> ArrayUtils.indexOf(p.featureCols, d.name))); + } + statics.numFeatures = statics.featureMetas.length; + statics.labelSumCount = + getRuntimeContext() + .>getBroadcastVariable(labelSumCountBcName) + .get(0); + return new LocalState(statics, new LocalState.Dynamics()); + } + } + + private static class LabelSumCountFunction + implements AggregateFunction, Tuple2> { + + private final String labelCol; + + private LabelSumCountFunction(String labelCol) { + this.labelCol = labelCol; + } + + @Override + public Tuple2 createAccumulator() { + return Tuple2.of(0., 0L); + } + + @Override + public Tuple2 add(Row value, Tuple2 accumulator) { + double label = ((Number) value.getFieldAs(labelCol)).doubleValue(); + return Tuple2.of(accumulator.f0 + label, accumulator.f1 + 1); + } + + @Override + public Tuple2 getResult(Tuple2 accumulator) { + return accumulator; + } + + @Override + public Tuple2 merge(Tuple2 a, Tuple2 b) { + return Tuple2.of(a.f0 + b.f0, a.f1 + b.f1); + } + } +} diff --git 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/datastorage/IterationSharedStorage.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/datastorage/IterationSharedStorage.java
new file mode 100644
index 000000000..d933de447
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/datastorage/IterationSharedStorage.java
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.gbt.datastorage;
+
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.iteration.IterationID;
+import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.util.Preconditions;
+
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+
+/** A shared storage across subtasks of different operators. */
+public class IterationSharedStorage {
+    private static final Map<Tuple3<IterationID, Integer, String>, Object> m =
+            new ConcurrentHashMap<>();
+
+    private static final Map<Tuple3<IterationID, Integer, String>, OperatorID> owners =
+            new ConcurrentHashMap<>();
+
+    /**
+     * Gets a {@link Reader} of shared data identified by (iterationId, subtaskId, key).
+     *
+     * @param iterationID The iteration ID.
+     * @param subtaskId The subtask ID.
+     * @param key The string key.
+     * @return A {@link Reader} of shared data.
+     * @param <T> The type of shared data.
+     */
+    public static <T> Reader<T> getReader(IterationID iterationID, int subtaskId, String key) {
+        return new Reader<>(Tuple3.of(iterationID, subtaskId, key));
+    }
+
+    /**
+     * Gets a {@link Writer} of shared data identified by (iterationId, subtaskId, key).
+     *
+     * @param iterationID The iteration ID.
+     * @param subtaskId The subtask ID.
+     * @param key The string key.
+     * @param operatorID The owner operator.
+     * @param initVal The initial value of the data.
+     * @return A {@link Writer} of shared data.
+     * @param <T> The type of shared data.
+     */
+    public static <T> Writer<T> getWriter(
+            IterationID iterationID, int subtaskId, String key, OperatorID operatorID, T initVal) {
+        Tuple3<IterationID, Integer, String> t = Tuple3.of(iterationID, subtaskId, key);
+        OperatorID lastOwner = owners.putIfAbsent(t, operatorID);
+        if (null != lastOwner) {
+            throw new IllegalStateException(
+                    String.format(
+                            "The shared data (%s, %s, %s) already has a writer %s.",
+                            iterationID, subtaskId, key, lastOwner));
+        }
+        Writer<T> writer = new Writer<>(t, operatorID);
+        writer.set(initVal);
+        return writer;
+    }
+
+    /**
+     * A reader of shared data identified by (iterationID, subtaskId, key).
+     *
+     * @param <T> The type of shared data.
+     */
+    public static class Reader<T> {
+        protected final Tuple3<IterationID, Integer, String> t;
+
+        public Reader(Tuple3<IterationID, Integer, String> t) {
+            this.t = t;
+        }
+
+        /**
+         * Gets the value.
+         *
+         * @return The value.
+         */
+        public T get() {
+            //noinspection unchecked
+            return (T) m.get(t);
+        }
+    }
+
+    /**
+     * A writer of shared data identified by (iterationID, subtaskId, key). A writer is
+     * responsible for the checkpointing of the data.
+     *
+     * @param <T> The type of shared data.
+     */
+    public static class Writer<T> extends Reader<T> {
+        private final OperatorID operatorID;
+
+        public Writer(Tuple3<IterationID, Integer, String> t, OperatorID operatorID) {
+            super(t);
+            this.operatorID = operatorID;
+        }
+
+        private void ensureOwner() {
+            // Double-checks the owner, because a writer may call this method after the key has
+            // been removed and re-added by another operator.
+            Preconditions.checkState(owners.get(t).equals(operatorID));
+        }
+
+        /**
+         * Sets a new value.
+         *
+         * @param value The new value.
+         */
+        public void set(T value) {
+            ensureOwner();
+            m.put(t, value);
+        }
+
+        /** Removes this data entry. */
+        public void remove() {
+            ensureOwner();
+            m.remove(t);
+            owners.remove(t);
+        }
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/BinnedInstance.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/BinnedInstance.java
new file mode 100644
index 000000000..b7509fe36
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/BinnedInstance.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.gbt.defs;
+
+import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizer;
+import org.apache.flink.ml.feature.stringindexer.StringIndexer;
+import org.apache.flink.ml.linalg.SparseVector;
+
+import org.eclipse.collections.impl.map.mutable.primitive.IntIntHashMap;
+
+/**
+ * Represents an instance including binned values of all features, weight, and label.
+ *
+ * Categorical and continuous features are mapped to integers by {@link StringIndexer} and {@link
+ * KBinsDiscretizer}, respectively. Null values ({@code null} or {@code Double.NaN}) are also
+ * mapped to certain integers.
+ *
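+ * For example (with hypothetical values): a row whose first feature is the category mapped to
+ * index 2 and whose second feature falls into bin 5 is stored as {@code features = {0: 2, 1: 5}}.
+ *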
NOTE: When the input features are sparse, i.e., from {@link SparseVector}s, unseen indices are + * not stored in `features`. They should be handled separately. + */ +public class BinnedInstance { + + public IntIntHashMap features; + public double weight; + public double label; + + public BinnedInstance() {} + + public BinnedInstance(IntIntHashMap features, double weight, double label) { + this.weight = weight; + this.label = label; + this.features = features; + } + + @Override + public String toString() { + return String.format( + "BinnedInstance{features=%s, weight=%s, label=%s}", features, weight, label); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Distributor.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Distributor.java new file mode 100644 index 000000000..0061dfc61 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Distributor.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.defs; + +import java.io.Serializable; + +/** + * A utility class which helps data partitioning. + * + *
+ * Given an indexable linear structure of n elements, like an array, and m tasks, the goal is to
+ * partition the linear structure into m consecutive segments and assign them to the tasks
+ * accordingly. This class calculates the segment assigned to each task, including the start
+ * position and the element count of the segment. For example, with n = 10 and m = 4, the segments
+ * produced by {@link EvenDistributor} start at 0, 3, 6, and 8, and contain 3, 3, 2, and 2
+ * elements, respectively.
+ */
+public abstract class Distributor implements Serializable {
+    protected final long numTasks;
+    protected final long total;
+
+    public Distributor(long total, long numTasks) {
+        this.numTasks = numTasks;
+        this.total = total;
+    }
+
+    /**
+     * Calculates the start position of the segment assigned to the task.
+     *
+     * @param taskId The task index.
+     * @return The start position.
+     */
+    public abstract long start(long taskId);
+
+    /**
+     * Calculates the count of elements of the segment assigned to the task.
+     *
+     * @param taskId The task index.
+     * @return The count of elements.
+     */
+    public abstract long count(long taskId);
+
+    /** An implementation of {@link Distributor} which evenly partitions the elements. */
+    public static class EvenDistributor extends Distributor {
+
+        public EvenDistributor(long parallelism, long totalCnt) {
+            super(totalCnt, parallelism);
+        }
+
+        @Override
+        public long start(long taskId) {
+            long div = total / numTasks;
+            long mod = total % numTasks;
+            return taskId < mod ? div * taskId + taskId : div * taskId + mod;
+        }
+
+        @Override
+        public long count(long taskId) {
+            long div = total / numTasks;
+            long mod = total % numTasks;
+            return taskId < mod ? div + 1 : div;
+        }
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/GbtParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/GbtParams.java
index d5bf5d489..ea6d84bbc 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/GbtParams.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/GbtParams.java
@@ -22,9 +22,34 @@
 /** Internal parameters of a gradient boosting trees algorithm. */
 public class GbtParams implements Serializable {
+    public TaskType taskType;
+
+    // Parameters related to input data.
     public String[] featureCols;
     public String vectorCol;
+    public boolean isInputVector;
     public String labelCol;
+    public String weightCol;
     public String[] categoricalCols;
+
+    // Parameters related to the algorithms.
+    public int maxDepth;
     public int maxBins;
+    public int minInstancesPerNode;
+    public double minWeightFractionPerNode;
+    public double minInfoGain;
+    public int maxIter;
+    public double stepSize;
+    public long seed;
+    public double subsamplingRate;
+    public String featureSubsetStrategy;
+    public double validationTol;
+    public double lambda;
+    public double gamma;
+
+    // Derived parameters.
+    public String lossType;
+    public int maxNumLeaves;
+    // useMissing is always true right now.
+    public boolean useMissing;
 }
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/HessianImpurity.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/HessianImpurity.java
new file mode 100644
index 000000000..063572625
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/HessianImpurity.java
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.defs; + +/** + * The impurity introduced in XGBoost. + * + *
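+ * A sketch of the formulas from the reference below: writing G and H for the sums of instance
+ * gradients and hessians in a node, the optimal leaf prediction is -G / (H + lambda), the node
+ * score is G^2 / (H + lambda), and a split must additionally overcome the complexity penalty
+ * gamma to be worthwhile.
+ *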
+ * See: Introduction to Boosted Trees.
+ */
+public class HessianImpurity extends Impurity {
+
+    // L2 regularization term on the leaf scores (lambda in the XGBoost notation).
+    protected final double lambda;
+    // Complexity penalty on the number of leaves (gamma in the XGBoost notation).
+    protected final double gamma;
+    // Total of instance gradients.
+    protected double totalGradients;
+    // Total of instance hessians.
+    protected double totalHessians;
+
+    public HessianImpurity(
+            double lambda,
+            double gamma,
+            int numInstances,
+            double totalWeights,
+            double totalGradients,
+            double totalHessians) {
+        super(numInstances, totalWeights);
+        this.lambda = lambda;
+        this.gamma = gamma;
+        this.totalGradients = totalGradients;
+        this.totalHessians = totalHessians;
+    }
+
+    @Override
+    public double prediction() {
+        // The optimal leaf weight is -G / (H + lambda).
+        return -totalGradients / (totalHessians + lambda);
+    }
+
+    @Override
+    public double impurity() {
+        if (totalHessians + lambda == 0) {
+            return 0.;
+        }
+        return totalGradients * totalGradients / (totalHessians + lambda);
+    }
+
+    @Override
+    public double gain(Impurity... others) {
+        double sum = 0.;
+        for (Impurity other : others) {
+            sum += other.impurity();
+        }
+        return .5 * (sum - impurity()) - gamma;
+    }
+
+    @Override
+    public HessianImpurity add(Impurity other) {
+        HessianImpurity impurity = (HessianImpurity) other;
+        this.numInstances += impurity.numInstances;
+        this.totalWeights += impurity.totalWeights;
+        this.totalGradients += impurity.totalGradients;
+        this.totalHessians += impurity.totalHessians;
+        return this;
+    }
+
+    @Override
+    public HessianImpurity subtract(Impurity other) {
+        HessianImpurity impurity = (HessianImpurity) other;
+        this.numInstances -= impurity.numInstances;
+        this.totalWeights -= impurity.totalWeights;
+        this.totalGradients -= impurity.totalGradients;
+        this.totalHessians -= impurity.totalHessians;
+        return this;
+    }
+
+    public void add(int numInstances, double weights, double gradients, double hessians) {
+        this.numInstances += numInstances;
+        this.totalWeights += weights;
+        this.totalGradients += gradients;
+        this.totalHessians += hessians;
+    }
+
+    public void subtract(int numInstances, double weights, double gradients, double hessians) {
+        this.numInstances -= numInstances;
+        this.totalWeights -= weights;
+        this.totalGradients -= gradients;
+        this.totalHessians -= hessians;
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Histogram.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Histogram.java
new file mode 100644
index 000000000..bfc9a2641
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Histogram.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.flink.ml.common.gbt.defs; + +import org.apache.flink.api.common.functions.AggregateFunction; +import org.apache.flink.util.Preconditions; + +import java.io.Serializable; + +/** + * This class stores values of histogram bins, and necessary information of reducing and scattering. + */ +public class Histogram implements Serializable { + + // Stores source subtask ID when reducing or target subtask ID when scattering. + public int subtaskId; + // Stores values of histogram bins. + public double[] hists; + // Stores the number of elements received by subtasks in scattering. + public int[] recvcnts; + + public Histogram(int subtaskId, double[] hists, int[] recvcnts) { + this.subtaskId = subtaskId; + this.hists = hists; + this.recvcnts = recvcnts; + } + + private Histogram accumulate(Histogram other) { + Preconditions.checkArgument(hists.length == other.hists.length); + for (int i = 0; i < hists.length; i += 1) { + hists[i] += other.hists[i]; + } + return this; + } + + /** Aggregator for Histogram. */ + public static class Aggregator + implements AggregateFunction, Serializable { + @Override + public Histogram createAccumulator() { + return null; + } + + @Override + public Histogram add(Histogram value, Histogram accumulator) { + if (null == accumulator) { + return value; + } + return accumulator.accumulate(value); + } + + @Override + public Histogram getResult(Histogram accumulator) { + return accumulator; + } + + @Override + public Histogram merge(Histogram a, Histogram b) { + return a.accumulate(b); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Impurity.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Impurity.java new file mode 100644 index 000000000..4e30c57ab --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Impurity.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.defs; + +import java.io.Serializable; + +/** The base class for calculating information gain from statistics. */ +public abstract class Impurity implements Cloneable, Serializable { + + // Number of instances. + protected int numInstances; + // Total of instance weights. + protected double totalWeights; + + public Impurity(int numInstances, double totalWeights) { + this.numInstances = numInstances; + this.totalWeights = totalWeights; + } + + /** + * Calculates the prediction. + * + * @return The prediction. + */ + public abstract double prediction(); + + /** + * Calculates the impurity. + * + * @return The impurity score. + */ + public abstract double impurity(); + + /** + * Calculate the information gain over other impurity instances, usually coming from splitting + * nodes. 
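+ *
+ * For a binary split, this is typically invoked as {@code parent.gain(left, right)}, where
+ * {@code left} and {@code right} hold the statistics of the two candidate children.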
+ * + * @param others Other impurity instances. + * @return The value of information gain. + */ + public abstract double gain(Impurity... others); + + /** + * Add statistics from other impurity instance. + * + * @param other The other impurity instance. + * @return The result after adding. + */ + public abstract Impurity add(Impurity other); + + /** + * Subtract statistics from other impurity instance. + * + * @param other The other impurity instance. + * @return The result after subtracting. + */ + public abstract Impurity subtract(Impurity other); + + /** + * Get the total of instance weights. + * + * @return The total of instance weights. + */ + public double getTotalWeights() { + return totalWeights; + } + + /** + * Get the number of instances. + * + * @return The number of instances. + */ + public int getNumInstances() { + return numInstances; + } + + @Override + public Impurity clone() { + try { + return (Impurity) super.clone(); + } catch (CloneNotSupportedException e) { + throw new IllegalStateException("Can not clone the impurity instance.", e); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/LearningNode.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/LearningNode.java new file mode 100644 index 000000000..e208269fb --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/LearningNode.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.defs; + +/** A node used in learning procedure. */ +public class LearningNode { + + // The corresponding tree node. + public Node node; + // Slice of indices of bagging instances. + public Slice slice; + // Slice of indices of non-bagging instances. + public Slice oob; + // Depth of corresponding tree node. + public int depth; + + public LearningNode(Node node, Slice slice, Slice oob, int depth) { + this.node = node; + this.slice = slice; + this.oob = oob; + this.depth = depth; + } + + @Override + public String toString() { + return String.format( + "LearningNode{node=%s, slice=%s, oob=%s, depth=%d}", node, slice, oob, depth); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/LocalState.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/LocalState.java new file mode 100644 index 000000000..c13df2d43 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/LocalState.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.defs; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.ml.common.gbt.loss.Loss; + +import org.eclipse.collections.api.tuple.primitive.IntIntPair; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + +/** + * Stores training state, including static parts and dynamic parts. Static parts won't change across + * the iteration rounds (except initialization), while dynamic parts are updated on every round. + * + *
+ * An instance of training state is bound to a subtask ID, so the operators accepting training
+ * states should be co-located.
+ */
+public class LocalState implements Serializable {
+
+    public Statics statics;
+    public Dynamics dynamics;
+
+    public LocalState(Statics statics, Dynamics dynamics) {
+        this.statics = statics;
+        this.dynamics = dynamics;
+    }
+
+    /** Static part of local state. */
+    public static class Statics {
+
+        public int subtaskId;
+        public int numSubtasks;
+        public GbtParams params;
+
+        public int numInstances;
+        public int numBaggingInstances;
+        public Random instanceRandomizer;
+
+        public int numFeatures;
+        public int numBaggingFeatures;
+        public Random featureRandomizer;
+
+        public FeatureMeta[] featureMetas;
+        public int[] numFeatureBins;
+
+        public Tuple2<Double, Long> labelSumCount;
+        public double prior;
+        public Loss loss;
+    }
+
+    /** Dynamic part of local state. */
+    public static class Dynamics {
+        // Root nodes of every tree built so far.
+        public List<Node> roots = new ArrayList<>();
+        // When false, a new tree will be initialized; when true, nodes in the current layer are
+        // being split.
+        public boolean inWeakLearner;
+
+        // Nodes to be split in the current layer.
+        public List<LearningNode> layer = new ArrayList<>();
+        // (nodeId, featureId) pairs to be considered in the current layer.
+        public List<IntIntPair> nodeFeaturePairs = new ArrayList<>();
+        // Leaf nodes in the current tree.
+        public List<LearningNode> leaves = new ArrayList<>();
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Node.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Node.java
new file mode 100644
index 000000000..691d1c0ef
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Node.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.gbt.defs;
+
+import java.io.Serializable;
+
+/** A node in the decision tree, which is serialized to and deserialized from JSON. */
+public class Node implements Serializable {
+
+    public Split split;
+    public boolean isLeaf = false;
+    public Node left;
+    public Node right;
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/PredGradHess.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/PredGradHess.java
new file mode 100644
index 000000000..36272af57
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/PredGradHess.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.defs; + +/** Stores prediction, gradient, and hessian of an instance. */ +public class PredGradHess { + public double pred; + public double gradient; + public double hessian; + + public PredGradHess() {} + + public PredGradHess(double pred, double gradient, double hessian) { + this.pred = pred; + this.gradient = gradient; + this.hessian = hessian; + } + + @Override + public String toString() { + return String.format( + "PredGradHess{pred=%s, gradient=%s, hessian=%s}", pred, gradient, hessian); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Slice.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Slice.java new file mode 100644 index 000000000..44ece2453 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Slice.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.defs; + +/** Represents a slice of an indexable linear structure, like an array. */ +public final class Slice { + + public int start; + public int end; + + public Slice(int start, int end) { + this.start = start; + this.end = end; + } + + public int size() { + return end - start; + } + + @Override + public String toString() { + return String.format("Slice{start=%d, end=%d}", start, end); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Split.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Split.java new file mode 100644 index 000000000..78052bbe7 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Split.java @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.gbt.defs;
+
+import org.eclipse.collections.impl.map.mutable.primitive.IntDoubleHashMap;
+import org.eclipse.collections.impl.map.mutable.primitive.IntIntHashMap;
+
+import java.util.BitSet;
+
+/** Stores a split on a feature. */
+public abstract class Split {
+    public static final double INVALID_GAIN = 0.0;
+
+    // Stores the feature index of this split.
+    public final int featureId;
+
+    // Stores the impurity gain. A value of {@code INVALID_GAIN} indicates this split is invalid.
+    public final double gain;
+
+    // Bin index for missing values of this feature.
+    public final int missingBin;
+    // Whether the missing values should go left.
+    public final boolean missingGoLeft;
+
+    // The prediction value if this split is invalid.
+    public final double prediction;
+
+    public Split(
+            int featureId, double gain, int missingBin, boolean missingGoLeft, double prediction) {
+        this.featureId = featureId;
+        this.gain = gain;
+        this.missingBin = missingBin;
+        this.missingGoLeft = missingGoLeft;
+        this.prediction = prediction;
+    }
+
+    /**
+     * Tests whether the binned instance should go to the left child or the right child.
+     *
+     * @param binnedInstance The binned instance.
+     * @return True if the instance should go to the left child.
+     */
+    public abstract boolean shouldGoLeft(BinnedInstance binnedInstance);
+
+    /**
+     * Tests whether the raw features should go to the left child or the right child. In the raw
+     * features, the categorical values are mapped to integers, while the continuous values are
+     * kept unmapped.
+     *
+     * @param rawFeatures The feature map from feature indices to values.
+     * @return True if the raw features should go to the left child.
+     */
+    public abstract boolean shouldGoLeft(IntDoubleHashMap rawFeatures);
+
+    public boolean isValid() {
+        return gain != INVALID_GAIN;
+    }
+
+    /** Stores a split on a continuous feature. */
+    public static class ContinuousSplit extends Split {
+
+        /**
+         * Stores the threshold that decides whether a continuous feature goes to the left or the
+         * right. Before splitting the node, the threshold is the bin index. After that, the
+         * threshold is replaced with the actual value of the bin edge.
+         */
+        public double threshold;
+
+        // If true, unseen values are treated as missing values; otherwise they are treated as 0s.
+        public boolean isUnseenMissing;
+
+        // Bin index for 0 values.
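+        // Sparse inputs omit zero-valued entries, so bin lookups for absent feature indices fall
+        // back to this bin, unless isUnseenMissing is set, in which case absent entries are
+        // treated as missing (see shouldGoLeft below).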
+ public int zeroBin; + + public ContinuousSplit( + int featureIndex, + double gain, + int missingBin, + boolean missingGoLeft, + double prediction, + double threshold, + boolean isUnseenMissing, + int zeroBin) { + super(featureIndex, gain, missingBin, missingGoLeft, prediction); + this.threshold = threshold; + this.isUnseenMissing = isUnseenMissing; + this.zeroBin = zeroBin; + } + + public static ContinuousSplit invalid(double prediction) { + return new ContinuousSplit(0, INVALID_GAIN, 0, false, prediction, 0., false, 0); + } + + @Override + public boolean shouldGoLeft(BinnedInstance binnedInstance) { + IntIntHashMap features = binnedInstance.features; + if (!features.containsKey(featureId) && isUnseenMissing) { + return missingGoLeft; + } + int binId = features.getIfAbsent(featureId, zeroBin); + return binId == missingBin ? missingGoLeft : binId <= threshold; + } + + @Override + public boolean shouldGoLeft(IntDoubleHashMap rawFeatures) { + if (!rawFeatures.containsKey(featureId) && isUnseenMissing) { + return missingGoLeft; + } + double v = rawFeatures.getIfAbsent(featureId, 0.); + return Double.isNaN(v) ? missingGoLeft : v < threshold; + } + } + + /** Stores a split on a categorical feature. */ + public static class CategoricalSplit extends Split { + // Stores the indices of categorical values that should go to the left child. + public final BitSet categoriesGoLeft; + + public CategoricalSplit( + int featureId, + double gain, + int missingBin, + boolean missingGoLeft, + double prediction, + BitSet categoriesGoLeft) { + super(featureId, gain, missingBin, missingGoLeft, prediction); + this.categoriesGoLeft = categoriesGoLeft; + } + + public static CategoricalSplit invalid(double prediction) { + return new CategoricalSplit(0, INVALID_GAIN, 0, false, prediction, new BitSet()); + } + + @Override + public boolean shouldGoLeft(BinnedInstance binnedInstance) { + IntIntHashMap features = binnedInstance.features; + if (!features.containsKey(featureId)) { + return missingGoLeft; + } + int binId = features.get(featureId); + return binId == missingBin ? missingGoLeft : categoriesGoLeft.get(binId); + } + + @Override + public boolean shouldGoLeft(IntDoubleHashMap rawFeatures) { + if (!rawFeatures.containsKey(featureId)) { + return missingGoLeft; + } + return categoriesGoLeft.get((int) rawFeatures.get(featureId)); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Splits.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Splits.java new file mode 100644 index 000000000..c80598383 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Splits.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.flink.ml.common.gbt.defs;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+
+/**
+ * This class stores the splits of the nodes in the current layer, together with the information
+ * needed for the all-reduce operation.
+ */
+public class Splits {
+
+    // Stores the source subtask ID when reducing or the target subtask ID when scattering.
+    public int subtaskId;
+    // Stores splits of nodes in the current layer.
+    public Split[] splits;
+
+    public Splits(int subtaskId, Split[] splits) {
+        this.subtaskId = subtaskId;
+        this.splits = splits;
+    }
+
+    private Splits accumulate(Splits other) {
+        // Keeps, per node, the split with the larger gain; ties are broken by preferring the
+        // larger feature ID so that the result is deterministic.
+        for (int i = 0; i < splits.length; ++i) {
+            if (splits[i] == null && other.splits[i] != null) {
+                splits[i] = other.splits[i];
+            } else if (splits[i] != null && other.splits[i] != null) {
+                if (splits[i].gain < other.splits[i].gain) {
+                    splits[i] = other.splits[i];
+                } else if (splits[i].gain == other.splits[i].gain) {
+                    if (splits[i].featureId < other.splits[i].featureId) {
+                        splits[i] = other.splits[i];
+                    }
+                }
+            }
+        }
+        return this;
+    }
+
+    /** Aggregator for Splits. */
+    public static class Aggregator implements AggregateFunction<Splits, Splits, Splits> {
+        @Override
+        public Splits createAccumulator() {
+            return null;
+        }
+
+        @Override
+        public Splits add(Splits value, Splits accumulator) {
+            if (null == accumulator) {
+                return value;
+            }
+            return accumulator.accumulate(value);
+        }
+
+        @Override
+        public Splits getResult(Splits accumulator) {
+            return accumulator;
+        }
+
+        @Override
+        public Splits merge(Splits a, Splits b) {
+            return a.accumulate(b);
+        }
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/loss/AbsoluteError.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/loss/AbsoluteError.java
new file mode 100644
index 000000000..b09240205
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/loss/AbsoluteError.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.gbt.loss;
+
+/**
+ * Absolute error loss function, defined as |y - pred| where y and pred are the label and the
+ * prediction for the instance, respectively.
+ */
+public class AbsoluteError implements Loss {
+
+    public static final AbsoluteError INSTANCE = new AbsoluteError();
+
+    private AbsoluteError() {}
+
+    @Override
+    public double loss(double pred, double y) {
+        double error = y - pred;
+        return Math.abs(error);
+    }
+
+    @Override
+    public double gradient(double pred, double y) {
+        return y > pred ? -1.
: 1; + } + + @Override + public double hessian(double pred, double y) { + return 0.; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/loss/LogLoss.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/loss/LogLoss.java new file mode 100644 index 000000000..b2efe8c6c --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/loss/LogLoss.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.loss; + +import org.apache.commons.math3.analysis.function.Sigmoid; + +/** + * The loss function for binary log loss. + * + *

The binary log loss defined as -y * pred + log(1 + exp(pred)) where y is a label in {0, 1} and + * pred is the predicted logit for the sample point. + */ +public class LogLoss implements Loss { + + public static final LogLoss INSTANCE = new LogLoss(); + private final Sigmoid sigmoid = new Sigmoid(); + + private LogLoss() {} + + @Override + public double loss(double pred, double y) { + return -y * pred + Math.log(1 + Math.exp(pred)); + } + + @Override + public double gradient(double pred, double y) { + return sigmoid.value(pred) - y; + } + + @Override + public double hessian(double pred, double y) { + double sig = sigmoid.value(pred); + return sig * (1 - sig); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/loss/Loss.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/loss/Loss.java new file mode 100644 index 000000000..fa6fadf7f --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/loss/Loss.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.loss; + +import java.io.Serializable; + +/** Loss functions for gradient boosting algorithms. */ +public interface Loss extends Serializable { + + /** + * Calculates loss given pred and y. + * + * @param pred prediction value. + * @param y label value. + * @return loss value. + */ + double loss(double pred, double y); + + /** + * Calculates value of gradient given prediction and label. + * + * @param pred prediction value. + * @param y label value. + * @return the value of gradient. + */ + double gradient(double pred, double y); + + /** + * Calculates value of second derivative, i.e. hessian, given prediction and label. + * + * @param pred prediction value. + * @param y label value. + * @return the value of second derivative, i.e. hessian. + */ + double hessian(double pred, double y); +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/loss/SquaredError.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/loss/SquaredError.java new file mode 100644 index 000000000..14321c024 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/loss/SquaredError.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.loss; + +/** + * Squared error loss function defined as (y - pred)^2 where y and pred are label and predictions + * for the instance respectively. + */ +public class SquaredError implements Loss { + + public static final SquaredError INSTANCE = new SquaredError(); + + private SquaredError() {} + + @Override + public double loss(double pred, double y) { + double error = y - pred; + return error * error; + } + + @Override + public double gradient(double pred, double y) { + return -2. * (y - pred); + } + + @Override + public double hessian(double pred, double y) { + return 2.; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java new file mode 100644 index 000000000..50ba95a20 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.gbt.operators; + +import org.apache.flink.iteration.IterationID; +import org.apache.flink.iteration.IterationListener; +import org.apache.flink.ml.common.gbt.datastorage.IterationSharedStorage; +import org.apache.flink.ml.common.gbt.defs.BinnedInstance; +import org.apache.flink.ml.common.gbt.defs.GbtParams; +import org.apache.flink.ml.common.gbt.defs.Histogram; +import org.apache.flink.ml.common.gbt.defs.LocalState; +import org.apache.flink.ml.common.gbt.defs.PredGradHess; +import org.apache.flink.ml.common.gbt.loss.Loss; +import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.OutputTag; +import org.apache.flink.util.Preconditions; + +import org.eclipse.collections.impl.map.mutable.primitive.IntIntHashMap; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.List; + +/** + * Calculates local histograms for local data partition. Specifically in the first round, this + * operator caches all data instances to JVM static region. + */ +public class CacheDataCalcLocalHistsOperator extends AbstractStreamOperator + implements TwoInputStreamOperator, + IterationListener { + private static final Logger LOG = + LoggerFactory.getLogger(CacheDataCalcLocalHistsOperator.class); + + private final GbtParams gbtParams; + private final IterationID iterationID; + private final String sharedInstancesKey; + private final String sharedPredGradHessKey; + private final String sharedShuffledIndicesKey; + private final String sharedSwappedIndicesKey; + private final OutputTag stateOutputTag; + + // States of local data. + private transient List instancesCollecting; + private transient LocalState localState; + private transient TreeInitializer treeInitializer; + private transient HistBuilder histBuilder; + + // Readers/writers of shared data. 
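+ // This operator writes (and thus checkpoints) the cached instances and the shuffled indices;
+ // the predictions/gradients/hessians and the swapped indices are only read here.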
+ private transient IterationSharedStorage.Writer instancesWriter; + private transient IterationSharedStorage.Reader pghReader; + private transient IterationSharedStorage.Writer shuffledIndicesWriter; + private transient IterationSharedStorage.Reader swappedIndicesReader; + + public CacheDataCalcLocalHistsOperator( + GbtParams gbtParams, + IterationID iterationID, + String sharedInstancesKey, + String sharedPredGradHessKey, + String sharedShuffledIndicesKey, + String sharedSwappedIndicesKey, + OutputTag stateOutputTag) { + super(); + this.gbtParams = gbtParams; + this.iterationID = iterationID; + this.sharedInstancesKey = sharedInstancesKey; + this.sharedPredGradHessKey = sharedPredGradHessKey; + this.sharedShuffledIndicesKey = sharedShuffledIndicesKey; + this.sharedSwappedIndicesKey = sharedSwappedIndicesKey; + this.stateOutputTag = stateOutputTag; + } + + @Override + public void open() throws Exception { + instancesCollecting = new ArrayList<>(); + + int subtaskId = getRuntimeContext().getIndexOfThisSubtask(); + instancesWriter = + IterationSharedStorage.getWriter( + iterationID, + subtaskId, + sharedInstancesKey, + getOperatorID(), + new BinnedInstance[0]); + + shuffledIndicesWriter = + IterationSharedStorage.getWriter( + iterationID, + subtaskId, + sharedShuffledIndicesKey, + getOperatorID(), + new int[0]); + + this.pghReader = + IterationSharedStorage.getReader(iterationID, subtaskId, sharedPredGradHessKey); + this.swappedIndicesReader = + IterationSharedStorage.getReader(iterationID, subtaskId, sharedSwappedIndicesKey); + } + + @Override + public void processElement1(StreamRecord streamRecord) throws Exception { + Row row = streamRecord.getValue(); + IntIntHashMap features = new IntIntHashMap(); + if (gbtParams.isInputVector) { + Vector vec = row.getFieldAs(gbtParams.vectorCol); + SparseVector sv = vec.toSparse(); + for (int i = 0; i < sv.indices.length; i += 1) { + features.put(sv.indices[i], (int) sv.values[i]); + } + } else { + for (int i = 0; i < gbtParams.featureCols.length; i += 1) { + // Values from StringIndexModel#transform are double. + features.put(i, ((Number) row.getFieldAs(gbtParams.featureCols[i])).intValue()); + } + } + double label = row.getFieldAs(gbtParams.labelCol); + instancesCollecting.add(new BinnedInstance(features, 1., label)); + } + + @Override + public void processElement2(StreamRecord streamRecord) throws Exception { + localState = streamRecord.getValue(); + } + + @Override + public void onEpochWatermarkIncremented( + int epochWatermark, Context context, Collector out) throws Exception { + if (0 == epochWatermark) { + // Initializes local state in first round. + instancesWriter.set(instancesCollecting.toArray(new BinnedInstance[0])); + instancesCollecting.clear(); + new LocalStateInitializer(gbtParams) + .init( + localState, + getRuntimeContext().getIndexOfThisSubtask(), + getRuntimeContext().getNumberOfParallelSubtasks(), + instancesWriter.get()); + + treeInitializer = new TreeInitializer(localState.statics); + histBuilder = new HistBuilder(localState.statics); + } + + BinnedInstance[] instances = instancesWriter.get(); + Preconditions.checkArgument( + getRuntimeContext().getIndexOfThisSubtask() == localState.statics.subtaskId); + PredGradHess[] pgh = pghReader.get(); + + // In the first round, use prior as the predictions. 
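+ // That is, for every cached instance: pred = prior, gradient = loss.gradient(prior,
+ // label), and hessian = loss.hessian(prior, label).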
+ if (0 == pgh.length) { + pgh = new PredGradHess[instances.length]; + double prior = localState.statics.prior; + Loss loss = localState.statics.loss; + for (int i = 0; i < instances.length; i += 1) { + double label = instances[i].label; + pgh[i] = + new PredGradHess( + prior, loss.gradient(prior, label), loss.hessian(prior, label)); + } + } + + int[] indices; + if (!localState.dynamics.inWeakLearner) { + // When last tree is finished, initializes a new tree, and shuffle instance indices. + treeInitializer.init(localState.dynamics, shuffledIndicesWriter::set); + localState.dynamics.inWeakLearner = true; + indices = shuffledIndicesWriter.get(); + } else { + // Otherwise, uses the swapped instance indices. + shuffledIndicesWriter.set(new int[0]); + indices = swappedIndicesReader.get(); + } + + Histogram localHists = + histBuilder.build( + localState.dynamics.layer, + localState.dynamics.nodeFeaturePairs, + indices, + instances, + pgh); + out.collect(localHists); + context.output(stateOutputTag, localState); + } + + @Override + public void onIterationTerminated(Context context, Collector collector) { + instancesCollecting.clear(); + instancesWriter.set(new BinnedInstance[0]); + shuffledIndicesWriter.set(new int[0]); + } + + @Override + public void close() throws Exception { + instancesWriter.remove(); + shuffledIndicesWriter.remove(); + super.close(); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CalcLocalSplitsOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CalcLocalSplitsOperator.java new file mode 100644 index 000000000..3cf86557c --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CalcLocalSplitsOperator.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.operators; + +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.iteration.IterationListener; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.common.gbt.defs.Histogram; +import org.apache.flink.ml.common.gbt.defs.LocalState; +import org.apache.flink.ml.common.gbt.defs.Splits; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.util.Collector; +import org.apache.flink.util.OutputTag; + +import java.util.Collections; + +/** Calculates local splits for assigned (nodeId, featureId) pairs. 
*/ +public class CalcLocalSplitsOperator extends AbstractStreamOperator + implements TwoInputStreamOperator, + IterationListener { + + private static final String LOCAL_STATE_STATE_NAME = "local_state"; + private static final String CALC_BEST_SPLIT_STATE_NAME = "split_finder"; + private static final String HISTOGRAM_STATE_NAME = "histogram"; + + private final OutputTag stateOutputTag; + + private transient ListState localState; + private transient ListState splitFinder; + private transient ListState histogram; + + public CalcLocalSplitsOperator(OutputTag stateOutputTag) { + this.stateOutputTag = stateOutputTag; + } + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + localState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + LOCAL_STATE_STATE_NAME, LocalState.class)); + splitFinder = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + CALC_BEST_SPLIT_STATE_NAME, SplitFinder.class)); + histogram = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>(HISTOGRAM_STATE_NAME, Histogram.class)); + } + + @SuppressWarnings("OptionalGetWithoutIsPresent") + @Override + public void onEpochWatermarkIncremented( + int epochWatermark, Context context, Collector collector) throws Exception { + LocalState localStateValue = + OperatorStateUtils.getUniqueElement(localState, LOCAL_STATE_STATE_NAME).get(); + if (0 == epochWatermark) { + splitFinder.update(Collections.singletonList(new SplitFinder(localStateValue.statics))); + } + Splits splits = + OperatorStateUtils.getUniqueElement(splitFinder, CALC_BEST_SPLIT_STATE_NAME) + .get() + .calc( + localStateValue.dynamics.layer, + localStateValue.dynamics.nodeFeaturePairs, + localStateValue.dynamics.leaves, + OperatorStateUtils.getUniqueElement(histogram, HISTOGRAM_STATE_NAME) + .get()); + collector.collect(splits); + context.output(stateOutputTag, localStateValue); + } + + @Override + public void onIterationTerminated(Context context, Collector collector) {} + + @Override + public void processElement1(StreamRecord element) throws Exception { + localState.update(Collections.singletonList(element.getValue())); + } + + @Override + public void processElement2(StreamRecord element) throws Exception { + histogram.update(Collections.singletonList(element.getValue())); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistBuilder.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistBuilder.java new file mode 100644 index 000000000..a9428b5b4 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistBuilder.java @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.operators; + +import org.apache.flink.ml.common.gbt.DataUtils; +import org.apache.flink.ml.common.gbt.defs.BinnedInstance; +import org.apache.flink.ml.common.gbt.defs.Distributor; +import org.apache.flink.ml.common.gbt.defs.FeatureMeta; +import org.apache.flink.ml.common.gbt.defs.Histogram; +import org.apache.flink.ml.common.gbt.defs.LearningNode; +import org.apache.flink.ml.common.gbt.defs.LocalState; +import org.apache.flink.ml.common.gbt.defs.PredGradHess; + +import org.eclipse.collections.api.tuple.primitive.IntIntPair; +import org.eclipse.collections.impl.tuple.primitive.PrimitiveTuples; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Arrays; +import java.util.List; +import java.util.Random; +import java.util.stream.IntStream; + +class HistBuilder { + private static final Logger LOG = LoggerFactory.getLogger(HistBuilder.class); + + private final int subtaskId; + private final int numSubtasks; + + private final int[] numFeatureBins; + private final FeatureMeta[] featureMetas; + + private final int numBaggingFeatures; + private final Random featureRandomizer; + private final int[] featureIndicesPool; + + private final boolean isInputVector; + + private final double[] hists; + + public HistBuilder(LocalState.Statics statics) { + subtaskId = statics.subtaskId; + numSubtasks = statics.numSubtasks; + + numFeatureBins = statics.numFeatureBins; + featureMetas = statics.featureMetas; + + numBaggingFeatures = statics.numBaggingFeatures; + featureRandomizer = statics.featureRandomizer; + featureIndicesPool = IntStream.range(0, statics.numFeatures).toArray(); + + isInputVector = statics.params.isInputVector; + + int maxNumNodes = + Math.min( + ((int) Math.pow(2, statics.params.maxDepth - 1)), + statics.params.maxNumLeaves); + + int maxFeatureBins = Arrays.stream(numFeatureBins).max().orElse(0); + int totalNumFeatureBins = Arrays.stream(numFeatureBins).sum(); + int maxNumBins = + maxNumNodes * Math.min(maxFeatureBins * numBaggingFeatures, totalNumFeatureBins); + hists = new double[maxNumBins * DataUtils.BIN_SIZE]; + } + + /** + * Calculate histograms for all (nodeId, featureId) pairs. The results are written to `hists`, + * so `hists` must be large enough to store values. + */ + private static void calcNodeFeaturePairHists( + List layer, + List nodeFeaturePairs, + FeatureMeta[] featureMetas, + boolean isInputVector, + int[] numFeatureBins, + int[] indices, + BinnedInstance[] instances, + PredGradHess[] pgh, + double[] hists) { + Arrays.fill(hists, 0.); + int binOffset = 0; + for (IntIntPair nodeFeaturePair : nodeFeaturePairs) { + int nodeId = nodeFeaturePair.getOne(); + int featureId = nodeFeaturePair.getTwo(); + FeatureMeta featureMeta = featureMetas[featureId]; + + int defaultValue = featureMeta.missingBin; + // When isInputVector is true, values of unseen features are treated as 0s. 
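+            // Sparse vectors store only non-zero entries, so an absent feature means a value of
+            // 0 rather than a missing value; such instances are counted into the bin containing
+            // 0 (`zeroBin`) instead of the missing-value bin.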
+ if (isInputVector && featureMeta instanceof FeatureMeta.ContinuousFeatureMeta) { + defaultValue = ((FeatureMeta.ContinuousFeatureMeta) featureMeta).zeroBin; + } + + LearningNode node = layer.get(nodeId); + for (int i = node.slice.start; i < node.slice.end; i += 1) { + int instanceId = indices[i]; + BinnedInstance binnedInstance = instances[instanceId]; + double gradient = pgh[instanceId].gradient; + double hessian = pgh[instanceId].hessian; + + int val = binnedInstance.features.getIfAbsent(featureId, defaultValue); + int startIndex = (binOffset + val) * DataUtils.BIN_SIZE; + hists[startIndex] += gradient; + hists[startIndex + 1] += hessian; + hists[startIndex + 2] += binnedInstance.weight; + hists[startIndex + 3] += 1.; + } + binOffset += numFeatureBins[featureId]; + } + } + + /** + * Calculates elements counts of histogram distributed to each downstream subtask. The elements + * counts is bin counts multiplied by STEP. The minimum unit to be distributed is (nodeId, + * featureId), i.e., all bins belonging to the same (nodeId, featureId) pair must go to one + * subtask. + */ + private static int[] calcRecvCounts( + int numSubtasks, List nodeFeaturePairs, int[] numFeatureBins) { + int[] recvcnts = new int[numSubtasks]; + Distributor.EvenDistributor distributor = + new Distributor.EvenDistributor(numSubtasks, nodeFeaturePairs.size()); + for (int k = 0; k < numSubtasks; k += 1) { + int pairStart = (int) distributor.start(k); + int pairCnt = (int) distributor.count(k); + for (int i = pairStart; i < pairStart + pairCnt; i += 1) { + int featureId = nodeFeaturePairs.get(i).getTwo(); + recvcnts[k] += numFeatureBins[featureId] * DataUtils.BIN_SIZE; + } + } + return recvcnts; + } + + /** Calculate local histograms for nodes in current layer of tree. */ + public Histogram build( + List layer, + List nodeFeaturePairs, + int[] indices, + BinnedInstance[] instances, + PredGradHess[] pgh) { + LOG.info("subtaskId: {}, {} start", subtaskId, HistBuilder.class.getSimpleName()); + + // Generates (nodeId, featureId) pairs that are required to build histograms. + nodeFeaturePairs.clear(); + for (int k = 0; k < layer.size(); k += 1) { + int[] sampledFeatures = + DataUtils.sample(featureIndicesPool, numBaggingFeatures, featureRandomizer); + for (int featureId : sampledFeatures) { + nodeFeaturePairs.add(PrimitiveTuples.pair(k, featureId)); + } + } + + // Calculates histograms for (nodeId, featureId) pairs. + calcNodeFeaturePairHists( + layer, + nodeFeaturePairs, + featureMetas, + isInputVector, + numFeatureBins, + indices, + instances, + pgh, + hists); + + // Calculates number of elements received by each downstream subtask. + int[] recvcnts = calcRecvCounts(numSubtasks, nodeFeaturePairs, numFeatureBins); + + LOG.info("subtaskId: {}, {} end", this.subtaskId, HistBuilder.class.getSimpleName()); + return new Histogram(this.subtaskId, hists, recvcnts); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistogramAggregateFunction.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistogramAggregateFunction.java new file mode 100644 index 000000000..8fbb869f4 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistogramAggregateFunction.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.operators; + +import org.apache.flink.api.common.functions.AggregateFunction; +import org.apache.flink.api.common.functions.RichFlatMapFunction; +import org.apache.flink.ml.common.gbt.defs.Histogram; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import java.util.BitSet; + +/** Aggregation function for merging histograms. */ +public class HistogramAggregateFunction extends RichFlatMapFunction { + + private final AggregateFunction aggregator = + new Histogram.Aggregator(); + private int numSubtasks; + private BitSet accepted; + private Histogram acc = null; + + @Override + public void flatMap(Histogram value, Collector out) throws Exception { + if (null == accepted) { + numSubtasks = getRuntimeContext().getNumberOfParallelSubtasks(); + accepted = new BitSet(numSubtasks); + } + int receivedSubtaskId = value.subtaskId; + Preconditions.checkState(!accepted.get(receivedSubtaskId)); + accepted.set(receivedSubtaskId); + acc = aggregator.add(value, acc); + if (numSubtasks == accepted.cardinality()) { + acc.subtaskId = getRuntimeContext().getIndexOfThisSubtask(); + out.collect(acc); + accepted = null; + acc = null; + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/InstanceUpdater.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/InstanceUpdater.java new file mode 100644 index 000000000..cb5482014 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/InstanceUpdater.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.gbt.operators; + +import org.apache.flink.ml.common.gbt.defs.BinnedInstance; +import org.apache.flink.ml.common.gbt.defs.LearningNode; +import org.apache.flink.ml.common.gbt.defs.LocalState; +import org.apache.flink.ml.common.gbt.defs.PredGradHess; +import org.apache.flink.ml.common.gbt.defs.Split; +import org.apache.flink.ml.common.gbt.loss.Loss; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; +import java.util.function.Consumer; + +class InstanceUpdater { + private static final Logger LOG = LoggerFactory.getLogger(InstanceUpdater.class); + + private final int subtaskId; + private final Loss loss; + private final double stepSize; + private final PredGradHess[] pgh; + private final double prior; + + private boolean initialized; + + public InstanceUpdater(LocalState.Statics statics) { + subtaskId = statics.subtaskId; + loss = statics.loss; + stepSize = statics.params.stepSize; + prior = statics.prior; + pgh = new PredGradHess[statics.numInstances]; + initialized = false; + } + + public void update( + List leaves, + int[] indices, + BinnedInstance[] instances, + Consumer pghSetter) { + LOG.info("subtaskId: {}, {} start", subtaskId, InstanceUpdater.class.getSimpleName()); + if (!initialized) { + for (int i = 0; i < instances.length; i += 1) { + double label = instances[i].label; + pgh[i] = + new PredGradHess( + prior, loss.gradient(prior, label), loss.hessian(prior, label)); + } + initialized = true; + } + + for (LearningNode nodeInfo : leaves) { + Split split = nodeInfo.node.split; + double pred = split.prediction * stepSize; + for (int i = nodeInfo.slice.start; i < nodeInfo.slice.end; ++i) { + int instanceId = indices[i]; + updatePgh(pred, instances[instanceId].label, pgh[instanceId]); + } + for (int i = nodeInfo.oob.start; i < nodeInfo.oob.end; ++i) { + int instanceId = indices[i]; + updatePgh(pred, instances[instanceId].label, pgh[instanceId]); + } + } + pghSetter.accept(pgh); + LOG.info("subtaskId: {}, {} end", subtaskId, InstanceUpdater.class.getSimpleName()); + } + + private void updatePgh(double pred, double label, PredGradHess pgh) { + pgh.pred += pred; + pgh.gradient = loss.gradient(pgh.pred, label); + pgh.hessian = loss.hessian(pgh.pred, label); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/LocalStateInitializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/LocalStateInitializer.java new file mode 100644 index 000000000..baf5548fc --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/LocalStateInitializer.java @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.gbt.operators; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.ml.common.gbt.defs.BinnedInstance; +import org.apache.flink.ml.common.gbt.defs.GbtParams; +import org.apache.flink.ml.common.gbt.defs.LocalState; +import org.apache.flink.ml.common.gbt.defs.TaskType; +import org.apache.flink.ml.common.gbt.loss.AbsoluteError; +import org.apache.flink.ml.common.gbt.loss.LogLoss; +import org.apache.flink.ml.common.gbt.loss.Loss; +import org.apache.flink.ml.common.gbt.loss.SquaredError; +import org.apache.flink.util.Preconditions; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Arrays; +import java.util.List; +import java.util.Random; +import java.util.function.Function; + +import static java.util.Arrays.stream; + +class LocalStateInitializer { + private static final Logger LOG = LoggerFactory.getLogger(LocalStateInitializer.class); + private final GbtParams params; + + public LocalStateInitializer(GbtParams params) { + this.params = params; + } + + /** + * Initializes local state. + * + *
<p>
Note that local state already has some properties set in advance, see GBTRunner#boost. + */ + public LocalState init( + LocalState localState, int subtaskId, int numSubtasks, BinnedInstance[] instances) { + LOG.info("subtaskId: {}, {} start", subtaskId, LocalStateInitializer.class.getSimpleName()); + + LocalState.Statics statics = localState.statics; + statics.subtaskId = subtaskId; + statics.numSubtasks = numSubtasks; + + int numInstances = instances.length; + int numFeatures = statics.featureMetas.length; + + LOG.info( + "subtaskId: {}, #samples: {}, #features: {}", subtaskId, numInstances, numFeatures); + + statics.params = params; + statics.numInstances = numInstances; + statics.numFeatures = numFeatures; + + statics.numBaggingInstances = getNumBaggingSamples(numInstances); + statics.numBaggingFeatures = getNumBaggingFeatures(numFeatures); + + statics.instanceRandomizer = new Random(subtaskId + params.seed); + statics.featureRandomizer = new Random(params.seed); + + statics.loss = getLoss(); + statics.prior = calcPrior(statics.labelSumCount); + + statics.numFeatureBins = + stream(statics.featureMetas) + .mapToInt(d -> d.numBins(statics.params.useMissing)) + .toArray(); + + LocalState.Dynamics dynamics = localState.dynamics; + dynamics.inWeakLearner = false; + + LOG.info("subtaskId: {}, {} end", subtaskId, LocalStateInitializer.class.getSimpleName()); + return new LocalState(statics, dynamics); + } + + private int getNumBaggingSamples(int numSamples) { + return (int) Math.min(numSamples, Math.ceil(numSamples * params.subsamplingRate)); + } + + private int getNumBaggingFeatures(int numFeatures) { + final List supported = Arrays.asList("auto", "all", "onethird", "sqrt", "log2"); + final String errorMsg = + String.format( + "Parameter `featureSubsetStrategy` supports %s, (0.0 - 1.0], [1 - n].", + String.join(", ", supported)); + final Function clamp = + d -> Math.max(1, Math.min(d.intValue(), numFeatures)); + String strategy = params.featureSubsetStrategy; + try { + int numBaggingFeatures = Integer.parseInt(strategy); + Preconditions.checkArgument( + numBaggingFeatures >= 1 && numBaggingFeatures <= numFeatures, errorMsg); + } catch (NumberFormatException ignored) { + } + try { + double baggingRatio = Double.parseDouble(strategy); + Preconditions.checkArgument(baggingRatio > 0. && baggingRatio <= 1., errorMsg); + return clamp.apply(baggingRatio * numFeatures); + } catch (NumberFormatException ignored) { + } + + Preconditions.checkArgument(supported.contains(strategy), errorMsg); + switch (strategy) { + case "auto": + return TaskType.CLASSIFICATION.equals(params.taskType) + ? 
clamp.apply(Math.sqrt(numFeatures)) + : clamp.apply(numFeatures / 3.); + case "all": + return numFeatures; + case "onethird": + return clamp.apply(numFeatures / 3.); + case "sqrt": + return clamp.apply(Math.sqrt(numFeatures)); + case "log2": + return clamp.apply(Math.log(numFeatures) / Math.log(2)); + default: + throw new IllegalArgumentException(errorMsg); + } + } + + private Loss getLoss() { + String lossType = params.lossType; + switch (lossType) { + case "logistic": + return LogLoss.INSTANCE; + case "squared": + return SquaredError.INSTANCE; + case "absolute": + return AbsoluteError.INSTANCE; + default: + throw new UnsupportedOperationException("Unsupported loss."); + } + } + + private double calcPrior(Tuple2 labelStat) { + String lossType = params.lossType; + switch (lossType) { + case "logistic": + return Math.log(labelStat.f0 / (labelStat.f1 - labelStat.f0)); + case "squared": + return labelStat.f0 / labelStat.f1; + case "absolute": + throw new UnsupportedOperationException("absolute error is not supported yet."); + default: + throw new UnsupportedOperationException("Unsupported loss."); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/NodeSplitter.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/NodeSplitter.java new file mode 100644 index 000000000..72e96352d --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/NodeSplitter.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.gbt.operators; + +import org.apache.flink.ml.common.gbt.defs.BinnedInstance; +import org.apache.flink.ml.common.gbt.defs.FeatureMeta; +import org.apache.flink.ml.common.gbt.defs.LearningNode; +import org.apache.flink.ml.common.gbt.defs.LocalState; +import org.apache.flink.ml.common.gbt.defs.Node; +import org.apache.flink.ml.common.gbt.defs.Slice; +import org.apache.flink.ml.common.gbt.defs.Split; +import org.apache.flink.util.Preconditions; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.List; + +class NodeSplitter { + private static final Logger LOG = LoggerFactory.getLogger(NodeSplitter.class); + + private final int subtaskId; + private final FeatureMeta[] featureMetas; + private final int maxLeaves; + private final int maxDepth; + + public NodeSplitter(LocalState.Statics statics) { + subtaskId = statics.subtaskId; + featureMetas = statics.featureMetas; + maxLeaves = statics.params.maxNumLeaves; + maxDepth = statics.params.maxDepth; + } + + private int partitionInstances( + Split split, Slice slice, int[] indices, BinnedInstance[] instances) { + int lstart = slice.start; + int lend = slice.end - 1; + while (lstart <= lend) { + while (lstart <= lend && split.shouldGoLeft(instances[indices[lstart]])) { + lstart += 1; + } + while (lstart <= lend && !split.shouldGoLeft(instances[indices[lend]])) { + lend -= 1; + } + if (lstart < lend) { + int temp = indices[lstart]; + indices[lstart] = indices[lend]; + indices[lend] = temp; + } + } + return lstart; + } + + private void splitNode( + LearningNode nodeInfo, + int[] indices, + BinnedInstance[] instances, + List nextLayer) { + int mid = partitionInstances(nodeInfo.node.split, nodeInfo.slice, indices, instances); + int oobMid = partitionInstances(nodeInfo.node.split, nodeInfo.oob, indices, instances); + nodeInfo.node.left = new Node(); + nodeInfo.node.right = new Node(); + nextLayer.add( + new LearningNode( + nodeInfo.node.left, + new Slice(nodeInfo.slice.start, mid), + new Slice(nodeInfo.oob.start, oobMid), + nodeInfo.depth + 1)); + nextLayer.add( + new LearningNode( + nodeInfo.node.right, + new Slice(mid, nodeInfo.slice.end), + new Slice(oobMid, nodeInfo.oob.end), + nodeInfo.depth + 1)); + } + + public void split( + List layer, + List leaves, + Split[] splits, + int[] indices, + BinnedInstance[] instances) { + LOG.info("subtaskId: {}, {} start", subtaskId, NodeSplitter.class.getSimpleName()); + Preconditions.checkState(splits.length == layer.size()); + + List nextLayer = new ArrayList<>(); + + // nodes in current layer or next layer are expected to generate at least 1 leaf. + int numQueued = layer.size(); + for (int i = 0; i < layer.size(); i += 1) { + LearningNode node = layer.get(i); + Split split = splits[i]; + numQueued -= 1; + node.node.split = split; + if (!split.isValid() + || node.node.isLeaf + || (leaves.size() + numQueued + 2) > maxLeaves + || node.depth + 1 > maxDepth) { + node.node.isLeaf = true; + leaves.add(node); + } else { + splitNode(node, indices, instances, nextLayer); + // Converts splits point from bin id to real feature value after splitting node. 
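+                // The best split was searched over bin ids, so `threshold` still holds a bin id
+                // here; replacing it with the corresponding bin edge lets the final model compare
+                // raw feature values directly, without re-binning at prediction time.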
+ if (split instanceof Split.ContinuousSplit) { + Split.ContinuousSplit cs = (Split.ContinuousSplit) split; + FeatureMeta.ContinuousFeatureMeta featureMeta = + (FeatureMeta.ContinuousFeatureMeta) featureMetas[cs.featureId]; + cs.threshold = featureMeta.binEdges[(int) cs.threshold + 1]; + } + numQueued += 2; + } + } + + layer.clear(); + layer.addAll(nextLayer); + + LOG.info("subtaskId: {}, {} end", subtaskId, NodeSplitter.class.getSimpleName()); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/PostSplitsOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/PostSplitsOperator.java new file mode 100644 index 000000000..1baf89669 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/PostSplitsOperator.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.operators; + +import org.apache.flink.iteration.IterationID; +import org.apache.flink.iteration.IterationListener; +import org.apache.flink.ml.common.gbt.datastorage.IterationSharedStorage; +import org.apache.flink.ml.common.gbt.defs.BinnedInstance; +import org.apache.flink.ml.common.gbt.defs.LocalState; +import org.apache.flink.ml.common.gbt.defs.PredGradHess; +import org.apache.flink.ml.common.gbt.defs.Splits; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.util.Collector; +import org.apache.flink.util.OutputTag; + +/** + * Post-process after global splits obtained, including split instances to left or child nodes, and + * update instances scores after a tree is complete. 
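+ *
+ * <p>Splitting is done in place by partitioning the shared index array. Once the current layer
+ * becomes empty, the tree is finished: the predictions, gradients, and hessians of all instances
+ * are refreshed, and the swapped indices are cleared so the next tree starts from a fresh
+ * shuffle.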
+ */ +public class PostSplitsOperator extends AbstractStreamOperator + implements TwoInputStreamOperator, + IterationListener { + + private final IterationID iterationID; + private final String sharedInstancesKey; + private final String sharedPredGradHessKey; + private final String sharedShuffledIndicesKey; + private final String sharedSwappedIndicesKey; + private final OutputTag finalStateOutputTag; + + private IterationSharedStorage.Reader instancesReader; + private IterationSharedStorage.Writer pghWriter; + private IterationSharedStorage.Reader shuffledIndicesReader; + private IterationSharedStorage.Writer swappedIndicesWriter; + + private transient LocalState localState; + private transient Splits splits; + private transient NodeSplitter nodeSplitter; + private transient InstanceUpdater instanceUpdater; + + public PostSplitsOperator( + IterationID iterationID, + String sharedInstancesKey, + String sharedPredGradHessKey, + String sharedShuffledIndicesKey, + String sharedSwappedIndicesKey, + OutputTag finalStateOutputTag) { + this.iterationID = iterationID; + this.sharedInstancesKey = sharedInstancesKey; + this.sharedPredGradHessKey = sharedPredGradHessKey; + this.sharedShuffledIndicesKey = sharedShuffledIndicesKey; + this.sharedSwappedIndicesKey = sharedSwappedIndicesKey; + this.finalStateOutputTag = finalStateOutputTag; + } + + @Override + public void open() throws Exception { + int subtaskId = getRuntimeContext().getIndexOfThisSubtask(); + pghWriter = + IterationSharedStorage.getWriter( + iterationID, + subtaskId, + sharedPredGradHessKey, + getOperatorID(), + new PredGradHess[0]); + swappedIndicesWriter = + IterationSharedStorage.getWriter( + iterationID, + subtaskId, + sharedSwappedIndicesKey, + getOperatorID(), + new int[0]); + + this.instancesReader = + IterationSharedStorage.getReader(iterationID, subtaskId, sharedInstancesKey); + this.shuffledIndicesReader = + IterationSharedStorage.getReader(iterationID, subtaskId, sharedShuffledIndicesKey); + } + + @Override + public void onEpochWatermarkIncremented( + int epochWatermark, Context context, Collector collector) throws Exception { + LocalState localStateValue = localState; + if (0 == epochWatermark) { + nodeSplitter = new NodeSplitter(localStateValue.statics); + instanceUpdater = new InstanceUpdater(localStateValue.statics); + } + + int[] indices = swappedIndicesWriter.get(); + if (0 == indices.length) { + indices = shuffledIndicesReader.get().clone(); + } + + BinnedInstance[] instances = instancesReader.get(); + nodeSplitter.split( + localStateValue.dynamics.layer, + localStateValue.dynamics.leaves, + splits.splits, + indices, + instances); + + if (localStateValue.dynamics.layer.isEmpty()) { + localStateValue.dynamics.inWeakLearner = false; + instanceUpdater.update( + localStateValue.dynamics.leaves, indices, instances, pghWriter::set); + swappedIndicesWriter.set(new int[0]); + } else { + swappedIndicesWriter.set(indices); + } + collector.collect(localStateValue); + } + + @Override + public void onIterationTerminated(Context context, Collector collector) { + pghWriter.set(new PredGradHess[0]); + swappedIndicesWriter.set(new int[0]); + if (0 == getRuntimeContext().getIndexOfThisSubtask()) { + context.output(finalStateOutputTag, localState); + } + } + + @Override + public void processElement1(StreamRecord element) throws Exception { + localState = element.getValue(); + } + + @Override + public void processElement2(StreamRecord element) throws Exception { + splits = element.getValue(); + } + + @Override + public void close() throws 
Exception { + pghWriter.remove(); + swappedIndicesWriter.remove(); + super.close(); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SplitFinder.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SplitFinder.java new file mode 100644 index 000000000..ea4cf302b --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SplitFinder.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.operators; + +import org.apache.flink.ml.common.gbt.defs.Distributor; +import org.apache.flink.ml.common.gbt.defs.FeatureMeta; +import org.apache.flink.ml.common.gbt.defs.Histogram; +import org.apache.flink.ml.common.gbt.defs.LearningNode; +import org.apache.flink.ml.common.gbt.defs.LocalState; +import org.apache.flink.ml.common.gbt.defs.Slice; +import org.apache.flink.ml.common.gbt.defs.Split; +import org.apache.flink.ml.common.gbt.defs.Splits; +import org.apache.flink.ml.common.gbt.splitter.CategoricalFeatureSplitter; +import org.apache.flink.ml.common.gbt.splitter.ContinuousFeatureSplitter; +import org.apache.flink.ml.common.gbt.splitter.HistogramFeatureSplitter; +import org.apache.flink.util.Preconditions; + +import org.eclipse.collections.api.tuple.primitive.IntIntPair; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; + +class SplitFinder { + private static final Logger LOG = LoggerFactory.getLogger(SplitFinder.class); + + private final int subtaskId; + private final int numSubtasks; + private final int[] numFeatureBins; + private final HistogramFeatureSplitter[] splitters; + private final int maxDepth; + private final int maxNumLeaves; + + public SplitFinder(LocalState.Statics statics) { + subtaskId = statics.subtaskId; + numSubtasks = statics.numSubtasks; + + numFeatureBins = statics.numFeatureBins; + FeatureMeta[] featureMetas = statics.featureMetas; + splitters = new HistogramFeatureSplitter[statics.numFeatures]; + for (int i = 0; i < statics.numFeatures; ++i) { + splitters[i] = + FeatureMeta.Type.CATEGORICAL == featureMetas[i].type + ? 
new CategoricalFeatureSplitter(i, featureMetas[i], statics.params) + : new ContinuousFeatureSplitter(i, featureMetas[i], statics.params); + } + maxDepth = statics.params.maxDepth; + maxNumLeaves = statics.params.maxNumLeaves; + } + + public Splits calc( + List layer, + List nodeFeaturePairs, + List leaves, + Histogram histogram) { + LOG.info("subtaskId: {}, {} start", subtaskId, SplitFinder.class.getSimpleName()); + + Distributor distributor = + new Distributor.EvenDistributor(numSubtasks, nodeFeaturePairs.size()); + long start = distributor.start(subtaskId); + long cnt = distributor.count(subtaskId); + + Split[] nodesBestSplits = new Split[layer.size()]; + int binOffset = 0; + for (long i = start; i < start + cnt; i += 1) { + IntIntPair nodeFeaturePair = nodeFeaturePairs.get((int) i); + int nodeId = nodeFeaturePair.getOne(); + int featureId = nodeFeaturePair.getTwo(); + LearningNode node = layer.get(nodeId); + + Preconditions.checkState(node.depth < maxDepth || leaves.size() + 2 <= maxNumLeaves); + splitters[featureId].reset( + histogram.hists, new Slice(binOffset, binOffset + numFeatureBins[featureId])); + Split bestSplit = splitters[featureId].bestSplit(); + if (null == nodesBestSplits[nodeId] + || (bestSplit.gain > nodesBestSplits[nodeId].gain)) { + nodesBestSplits[nodeId] = bestSplit; + } + binOffset += numFeatureBins[featureId]; + } + + LOG.info("subtaskId: {}, {} end", subtaskId, SplitFinder.class.getSimpleName()); + return new Splits(subtaskId, nodesBestSplits); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SplitsAggregateFunction.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SplitsAggregateFunction.java new file mode 100644 index 000000000..8b1c0cee4 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SplitsAggregateFunction.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.operators; + +import org.apache.flink.api.common.functions.AggregateFunction; +import org.apache.flink.api.common.functions.RichFlatMapFunction; +import org.apache.flink.ml.common.gbt.defs.Splits; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import java.util.BitSet; + +/** Aggregation function for merging splits. 
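+ *
+ * <p>The function emits only after local splits from every upstream subtask have been merged; a
+ * bit set keyed by subtask id guards against merging the same subtask's splits twice in one
+ * round. A minimal sketch of the merging contract, assuming {@code Splits.Aggregator} keeps the
+ * higher-gain split per node:
+ *
+ * <pre>{@code
+ * Splits acc = null;
+ * for (Splits local : localSplitsOfAllSubtasks) { // hypothetical collection
+ *     acc = new Splits.Aggregator().add(local, acc);
+ * }
+ * // acc now holds the globally best split for every node in the layer.
+ * }</pre>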
*/ +public class SplitsAggregateFunction extends RichFlatMapFunction { + + private final AggregateFunction aggregator = new Splits.Aggregator(); + private int numSubtasks; + private BitSet accepted; + private Splits acc = null; + + @Override + public void flatMap(Splits value, Collector out) throws Exception { + if (null == accepted) { + numSubtasks = getRuntimeContext().getNumberOfParallelSubtasks(); + accepted = new BitSet(numSubtasks); + } + int receivedSubtaskId = value.subtaskId; + Preconditions.checkState(!accepted.get(receivedSubtaskId)); + accepted.set(receivedSubtaskId); + acc = aggregator.add(value, acc); + if (numSubtasks == accepted.cardinality()) { + acc.subtaskId = getRuntimeContext().getIndexOfThisSubtask(); + out.collect(acc); + accepted = null; + acc = null; + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TreeInitializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TreeInitializer.java new file mode 100644 index 000000000..42240fc54 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TreeInitializer.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.operators; + +import org.apache.flink.ml.common.gbt.DataUtils; +import org.apache.flink.ml.common.gbt.defs.LearningNode; +import org.apache.flink.ml.common.gbt.defs.LocalState; +import org.apache.flink.ml.common.gbt.defs.Node; +import org.apache.flink.ml.common.gbt.defs.Slice; +import org.apache.flink.util.Preconditions; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Random; +import java.util.function.Consumer; +import java.util.stream.IntStream; + +class TreeInitializer { + private static final Logger LOG = LoggerFactory.getLogger(TreeInitializer.class); + + private final int subtaskId; + private final int numInstances; + private final int numBaggingInstances; + private final int[] shuffledIndices; + private final Random instanceRandomizer; + + public TreeInitializer(LocalState.Statics statics) { + subtaskId = statics.subtaskId; + numInstances = statics.numInstances; + numBaggingInstances = statics.numBaggingInstances; + instanceRandomizer = statics.instanceRandomizer; + shuffledIndices = IntStream.range(0, numInstances).toArray(); + } + + /** Calculate local histograms for nodes in current layer of tree. */ + public void init(LocalState.Dynamics dynamics, Consumer shuffledIndicesSetter) { + LOG.info("subtaskId: {}, {} start", subtaskId, TreeInitializer.class.getSimpleName()); + Preconditions.checkState(!dynamics.inWeakLearner); + Preconditions.checkState(dynamics.layer.isEmpty()); + + // Initializes the root node of a new tree when last tree is finalized. 
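+        // Shuffling the index array once per tree implements row subsampling without
+        // replacement: the first `numBaggingInstances` positions form the in-bag slice used to
+        // grow this tree, while the remaining positions form the out-of-bag slice.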
+ DataUtils.shuffle(shuffledIndices, instanceRandomizer); + Node root = new Node(); + dynamics.layer.add( + new LearningNode( + root, + new Slice(0, numBaggingInstances), + new Slice(numBaggingInstances, numInstances), + 1)); + dynamics.roots.add(root); + dynamics.leaves.clear(); + shuffledIndicesSetter.accept(shuffledIndices); + + LOG.info("subtaskId: {}, {} end", this.subtaskId, TreeInitializer.class.getSimpleName()); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/CategoricalFeatureSplitter.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/CategoricalFeatureSplitter.java new file mode 100644 index 000000000..eaac29c47 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/CategoricalFeatureSplitter.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.splitter; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.ml.common.gbt.defs.FeatureMeta; +import org.apache.flink.ml.common.gbt.defs.GbtParams; +import org.apache.flink.ml.common.gbt.defs.HessianImpurity; +import org.apache.flink.ml.common.gbt.defs.Split; + +import org.apache.commons.lang3.ArrayUtils; + +import java.util.Arrays; +import java.util.BitSet; +import java.util.Comparator; + +import static org.apache.flink.ml.common.gbt.DataUtils.BIN_SIZE; + +/** Splitter for a categorical feature using LightGBM many-vs-many split. */ +public class CategoricalFeatureSplitter extends HistogramFeatureSplitter { + + public CategoricalFeatureSplitter(int featureId, FeatureMeta featureMeta, GbtParams params) { + super(featureId, featureMeta, params); + } + + @Override + public Split.CategoricalSplit bestSplit() { + Tuple2 totalMissing = countTotalMissing(); + HessianImpurity total = totalMissing.f0; + HessianImpurity missing = totalMissing.f1; + + if (total.getNumInstances() <= minSamplesPerLeaf) { + return Split.CategoricalSplit.invalid(total.prediction()); + } + + int numBins = slice.size(); + // Sorts categories based on grads / hessians, i.e., LightGBM many-vs-many approach. 
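+        // Ordering categories by their gradient/hessian ratio (an estimate of the best leaf
+        // value per category) reduces the exponential many-vs-many subset search to a single
+        // linear scan over the sorted categories, as in LightGBM.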
+ Integer[] sortedCategories = new Integer[numBins]; + { + double[] scores = new double[numBins]; + for (int i = 0; i < numBins; ++i) { + sortedCategories[i] = i; + int startIndex = (slice.start + i) * BIN_SIZE; + scores[i] = hists[startIndex] / hists[startIndex + 1]; + } + Arrays.sort(sortedCategories, Comparator.comparing(d -> scores[d])); + } + + Tuple3 bestSplit = + findBestSplit(ArrayUtils.toPrimitive(sortedCategories), total, missing); + double bestGain = bestSplit.f0; + int bestSplitBinId = bestSplit.f1; + boolean missingGoLeft = bestSplit.f2; + + if (bestGain <= Split.INVALID_GAIN || bestGain <= minInfoGain) { + return Split.CategoricalSplit.invalid(total.prediction()); + } + + // Indicates which bins should go left. + BitSet binsGoLeft = new BitSet(numBins); + if (useMissing) { + for (int i = 0; i < numBins; ++i) { + int binId = sortedCategories[i]; + if (i <= bestSplitBinId) { + if (binId < featureMeta.missingBin) { + binsGoLeft.set(binId); + } else if (binId > featureMeta.missingBin) { + binsGoLeft.set(binId - 1); + } + } + } + } else { + int numCategories = + ((FeatureMeta.CategoricalFeatureMeta) featureMeta).categories.length; + for (int i = 0; i < numCategories; i += 1) { + int binId = sortedCategories[i]; + if (i <= bestSplitBinId) { + binsGoLeft.set(binId); + } + } + } + return new Split.CategoricalSplit( + featureId, + bestGain, + featureMeta.missingBin, + missingGoLeft, + total.prediction(), + binsGoLeft); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/ContinuousFeatureSplitter.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/ContinuousFeatureSplitter.java new file mode 100644 index 000000000..ce2656cd8 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/ContinuousFeatureSplitter.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.splitter; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.ml.common.gbt.defs.FeatureMeta; +import org.apache.flink.ml.common.gbt.defs.GbtParams; +import org.apache.flink.ml.common.gbt.defs.HessianImpurity; +import org.apache.flink.ml.common.gbt.defs.Split; + +import java.util.stream.IntStream; + +/** Splitter for a continuous feature. 
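+ *
+ * <p>Bins of a continuous feature are already ordered by value, so candidate thresholds are
+ * scanned from the lowest bin upwards; the missing-value bin is skipped during the scan and
+ * handled separately for both the missing-go-left and missing-go-right cases.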
*/ +public final class ContinuousFeatureSplitter extends HistogramFeatureSplitter { + + public ContinuousFeatureSplitter(int featureId, FeatureMeta featureMeta, GbtParams params) { + super(featureId, featureMeta, params); + } + + @Override + public Split.ContinuousSplit bestSplit() { + Tuple2 totalMissing = countTotalMissing(); + HessianImpurity total = totalMissing.f0; + HessianImpurity missing = totalMissing.f1; + + if (total.getNumInstances() <= minSamplesPerLeaf) { + return Split.ContinuousSplit.invalid(total.prediction()); + } + + int[] sortedBinIds = IntStream.range(0, slice.size()).toArray(); + Tuple3 bestSplit = findBestSplit(sortedBinIds, total, missing); + double bestGain = bestSplit.f0; + int bestSplitBinId = bestSplit.f1; + boolean missingGoLeft = bestSplit.f2; + + if (bestGain <= Split.INVALID_GAIN || bestGain <= minInfoGain) { + return Split.ContinuousSplit.invalid(total.prediction()); + } + int splitPoint = + useMissing && bestSplitBinId > featureMeta.missingBin + ? bestSplitBinId - 1 + : bestSplitBinId; + return new Split.ContinuousSplit( + featureId, + bestGain, + featureMeta.missingBin, + missingGoLeft, + total.prediction(), + splitPoint, + !params.isInputVector, + ((FeatureMeta.ContinuousFeatureMeta) featureMeta).zeroBin); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/FeatureSplitter.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/FeatureSplitter.java new file mode 100644 index 000000000..b9fccf037 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/FeatureSplitter.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.splitter; + +import org.apache.flink.ml.common.gbt.defs.FeatureMeta; +import org.apache.flink.ml.common.gbt.defs.GbtParams; +import org.apache.flink.ml.common.gbt.defs.Split; + +/** + * Tests if the node can be split on a given feature and obtains best split. + * + *
<p>
When testing the node, we only check internal criteria, such as minimum info gain, minium + * samples per leaf, etc. The external criteria, like maximum depth or maximum number of leaves are + * not checked. + */ +public abstract class FeatureSplitter { + protected final int featureId; + protected final FeatureMeta featureMeta; + protected final GbtParams params; + + protected final int minSamplesPerLeaf; + protected final double minSampleRatioPerChild; + protected final double minInfoGain; + + public FeatureSplitter(int featureId, FeatureMeta featureMeta, GbtParams params) { + this.params = params; + this.featureId = featureId; + this.featureMeta = featureMeta; + + this.minSamplesPerLeaf = params.minInstancesPerNode; + this.minSampleRatioPerChild = params.minWeightFractionPerNode; // TODO: not exactly the same + this.minInfoGain = params.minInfoGain; + } + + public abstract Split bestSplit(); +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/HistogramFeatureSplitter.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/HistogramFeatureSplitter.java new file mode 100644 index 000000000..0d22a7466 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/HistogramFeatureSplitter.java @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.splitter; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.ml.common.gbt.defs.FeatureMeta; +import org.apache.flink.ml.common.gbt.defs.GbtParams; +import org.apache.flink.ml.common.gbt.defs.HessianImpurity; +import org.apache.flink.ml.common.gbt.defs.Impurity; +import org.apache.flink.ml.common.gbt.defs.Slice; +import org.apache.flink.ml.common.gbt.defs.Split; + +import static org.apache.flink.ml.common.gbt.DataUtils.BIN_SIZE; + +/** Histogram based feature splitter. */ +public abstract class HistogramFeatureSplitter extends FeatureSplitter { + protected final boolean useMissing; + protected double[] hists; + protected Slice slice; + + public HistogramFeatureSplitter(int featureId, FeatureMeta featureMeta, GbtParams params) { + super(featureId, featureMeta, params); + this.useMissing = params.useMissing; + } + + protected boolean isSplitIllegal(Impurity total, Impurity left, Impurity right) { + return (minSamplesPerLeaf > left.getTotalWeights() + || minSamplesPerLeaf > right.getTotalWeights()) + || minSampleRatioPerChild > 1. * left.getNumInstances() / total.getNumInstances() + || minSampleRatioPerChild > 1. * right.getNumInstances() / total.getNumInstances(); + } + + protected double gain(Impurity total, Impurity left, Impurity right) { + return isSplitIllegal(total, left, right) ? 
Split.INVALID_GAIN : total.gain(left, right); + } + + protected void addBinToLeft(int binId, HessianImpurity left, HessianImpurity right) { + int index = (slice.start + binId) * BIN_SIZE; + left.add((int) hists[index + 3], hists[index + 2], hists[index], hists[index + 1]); + if (null != right) { + right.subtract( + (int) hists[index + 3], hists[index + 2], hists[index], hists[index + 1]); + } + } + + protected Tuple2 findBestSplitWithInitial( + int[] sortedBinIds, + HessianImpurity total, + HessianImpurity left, + HessianImpurity right) { + // Bins [0, bestSplitBinId] go left. + int bestSplitBinId = 0; + double bestGain = Split.INVALID_GAIN; + for (int i = 0; i < sortedBinIds.length; i += 1) { + int binId = sortedBinIds[i]; + if (useMissing && binId == featureMeta.missingBin) { + continue; + } + addBinToLeft(binId, left, right); + double gain = gain(total, left, right); + if (gain > bestGain && gain >= minInfoGain) { + bestGain = gain; + bestSplitBinId = i; + } + } + return Tuple2.of(bestGain, bestSplitBinId); + } + + protected Tuple3 findBestSplit( + int[] sortedBinIds, HessianImpurity total, HessianImpurity missing) { + double bestGain = Split.INVALID_GAIN; + int bestSplitBinId = 0; + boolean missingGoLeft = false; + + { + // The cases where the missing values go right, or missing values are not allowed. + HessianImpurity left = emptyImpurity(); + HessianImpurity right = (HessianImpurity) total.clone(); + Tuple2 bestSplit = + findBestSplitWithInitial(sortedBinIds, total, left, right); + if (bestSplit.f0 > bestGain) { + bestGain = bestSplit.f0; + bestSplitBinId = bestSplit.f1; + } + } + + if (useMissing) { + // The cases where the missing values go left. + HessianImpurity leftWithMissing = emptyImpurity().add(missing); + HessianImpurity rightWithoutMissing = (HessianImpurity) total.clone().subtract(missing); + Tuple2 bestSplitMissingGoLeft = + findBestSplitWithInitial( + sortedBinIds, total, leftWithMissing, rightWithoutMissing); + if (bestSplitMissingGoLeft.f0 > bestGain) { + bestGain = bestSplitMissingGoLeft.f0; + bestSplitBinId = bestSplitMissingGoLeft.f1; + missingGoLeft = true; + } + } + return Tuple3.of(bestGain, bestSplitBinId, missingGoLeft); + } + + public void reset(double[] hists, Slice slice) { + this.hists = hists; + this.slice = slice; + } + + protected Tuple2 countTotalMissing() { + HessianImpurity total = emptyImpurity(); + HessianImpurity missing = emptyImpurity(); + for (int i = 0; i < slice.size(); ++i) { + addBinToLeft(i, total, null); + } + if (useMissing) { + addBinToLeft(featureMeta.missingBin, missing, null); + } + return Tuple2.of(total, missing); + } + + protected HessianImpurity emptyImpurity() { + return new HessianImpurity(params.lambda, params.gamma, 0, 0, 0, 0); + } +} diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/GBTRunnerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/GBTRunnerTest.java new file mode 100644 index 000000000..16a770189 --- /dev/null +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/GBTRunnerTest.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.common.gbt.defs.GbtParams; +import org.apache.flink.ml.common.gbt.defs.TaskType; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.test.util.AbstractTestBase; +import org.apache.flink.types.Row; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.Arrays; +import java.util.List; + +public class GBTRunnerTest extends AbstractTestBase { + private static final List inputDataRows = + Arrays.asList( + Row.of(1.2, 2, null, 40., 1., 0., Vectors.dense(1.2, 2, Double.NaN)), + Row.of(2.3, 3, "b", 40., 2., 0., Vectors.dense(2.3, 3, 2.)), + Row.of(3.4, 4, "c", 40., 3., 0., Vectors.dense(3.4, 4, 3.)), + Row.of(4.5, 5, "a", 40., 4., 0., Vectors.dense(4.5, 5, 1.)), + Row.of(5.6, 2, "b", 40., 5., 0., Vectors.dense(5.6, 2, 2.)), + Row.of(null, 3, "c", 41., 1., 1., Vectors.dense(Double.NaN, 3, 3.)), + Row.of(12.8, 4, "e", 41., 2., 1., Vectors.dense(12.8, 4, 5.)), + Row.of(13.9, 2, "b", 41., 3., 1., Vectors.dense(13.9, 2, 2.)), + Row.of(14.1, 4, "a", 41., 4., 1., Vectors.dense(14.1, 4, 1.)), + Row.of(15.3, 1, "d", 41., 5., 1., Vectors.dense(15.3, 1, 4.))); + + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + private Table inputTable; + + @Before + public void before() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.getConfig().enableObjectReuse(); + env.setParallelism(4); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + StreamTableEnvironment tEnv = StreamTableEnvironment.create(env); + + inputTable = + tEnv.fromDataStream( + env.fromCollection( + inputDataRows, + new RowTypeInfo( + new TypeInformation[] { + Types.DOUBLE, + Types.INT, + Types.STRING, + Types.DOUBLE, + Types.DOUBLE, + Types.DOUBLE, + VectorTypeInfo.INSTANCE + }, + new String[] { + "f0", "f1", "f2", "label", "weight", "cls_label", "vec" + }))); + } + + private GbtParams getCommonGbtParams() { + GbtParams p = new GbtParams(); + p.featureCols = new String[] {"f0", "f1", "f2"}; + p.categoricalCols = new String[] {"f2"}; + p.isInputVector = false; + p.gamma = 0.; + p.maxBins = 3; + p.seed = 123; + p.featureSubsetStrategy = "all"; + p.maxDepth = 3; + 
p.maxNumLeaves = 1 << (p.maxDepth - 1); + p.maxIter = 20; + p.stepSize = 0.1; + return p; + } + + private void verifyModelData(GBTModelData modelData, GbtParams p) { + Assert.assertEquals(p.taskType, TaskType.valueOf(modelData.type)); + Assert.assertEquals(p.stepSize, modelData.stepSize, 1e-12); + Assert.assertEquals(p.maxIter, modelData.roots.size()); + } + + @Test + public void testTrainClassifier() throws Exception { + GbtParams p = getCommonGbtParams(); + p.taskType = TaskType.CLASSIFICATION; + p.labelCol = "cls_label"; + p.lossType = "logistic"; + + GBTModelData modelData = GBTRunner.train(inputTable, p).executeAndCollect().next(); + verifyModelData(modelData, p); + } + + @Test + public void testTrainRegressor() throws Exception { + GbtParams p = getCommonGbtParams(); + p.taskType = TaskType.REGRESSION; + p.labelCol = "label"; + p.lossType = "squared"; + + GBTModelData modelData = GBTRunner.train(inputTable, p).executeAndCollect().next(); + verifyModelData(modelData, p); + } +} From ee76f338b014e154ea1730a9a18e69b5d9629722 Mon Sep 17 00:00:00 2001 From: Fan Hong Date: Fri, 10 Feb 2023 16:32:12 +0800 Subject: [PATCH 03/47] Support checkpoint for operator states --- .../datastorage/IterationSharedStorage.java | 57 +++++++- .../CacheDataCalcLocalHistsOperator.java | 116 +++++++++++++---- .../gbt/operators/PostSplitsOperator.java | 99 +++++++++++--- .../typeinfo/BinnedInstanceSerializer.java | 123 ++++++++++++++++++ .../gbt/typeinfo/PredGradHessSerializer.java | 115 ++++++++++++++++ 5 files changed, 461 insertions(+), 49 deletions(-) create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/BinnedInstanceSerializer.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/PredGradHessSerializer.java diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/datastorage/IterationSharedStorage.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/datastorage/IterationSharedStorage.java index d933de447..e9922bf72 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/datastorage/IterationSharedStorage.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/datastorage/IterationSharedStorage.java @@ -18,11 +18,20 @@ package org.apache.flink.ml.common.gbt.datastorage; +import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.core.memory.DataInputViewStreamWrapper; +import org.apache.flink.core.memory.DataOutputViewStreamWrapper; import org.apache.flink.iteration.IterationID; import org.apache.flink.runtime.jobgraph.OperatorID; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StatePartitionStreamProvider; +import org.apache.flink.runtime.state.StateSnapshotContext; import org.apache.flink.util.Preconditions; +import org.apache.commons.collections.IteratorUtils; + +import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @@ -54,12 +63,18 @@ public static Reader getReader(IterationID iterationID, int subtaskId, St * @param subtaskId The subtask ID. * @param key The string key. * @param operatorID The owner operator. + * @param serializer Serializer of the data. * @param initVal Initialize value of the data. * @return A {@link Writer} of shared data. * @param The type of shared ata. 
*/ public static Writer getWriter( - IterationID iterationID, int subtaskId, String key, OperatorID operatorID, T initVal) { + IterationID iterationID, + int subtaskId, + String key, + OperatorID operatorID, + TypeSerializer serializer, + T initVal) { Tuple3 t = Tuple3.of(iterationID, subtaskId, key); OperatorID lastOwner = owners.putIfAbsent(t, operatorID); if (null != lastOwner) { @@ -68,7 +83,7 @@ public static Writer getWriter( "The shared data (%s, %s, %s) already has a writer %s.", iterationID, subtaskId, key, operatorID)); } - Writer writer = new Writer<>(t, operatorID); + Writer writer = new Writer<>(t, operatorID, serializer); writer.set(initVal); return writer; } @@ -104,10 +119,15 @@ public T get() { */ public static class Writer extends Reader { private final OperatorID operatorID; + private final TypeSerializer serializer; - public Writer(Tuple3 t, OperatorID operatorID) { + public Writer( + Tuple3 t, + OperatorID operatorID, + TypeSerializer serializer) { super(t); this.operatorID = operatorID; + this.serializer = serializer; } private void ensureOwner() { @@ -132,5 +152,36 @@ public void remove() { m.remove(t); owners.remove(t); } + + /** + * Initialize the state. + * + * @param context The state initialization context. + * @throws Exception + */ + public void initializeState(StateInitializationContext context) throws Exception { + //noinspection unchecked + List inputs = + IteratorUtils.toList(context.getRawOperatorStateInputs().iterator()); + Preconditions.checkState( + inputs.size() < 2, "The input from raw operator state should be one or zero."); + if (inputs.size() > 0) { + T value = + serializer.deserialize( + new DataInputViewStreamWrapper(inputs.get(0).getStream())); + set(value); + } + } + + /** + * Snapshot the state. + * + * @param context The state snapshot context. 
+ * @throws Exception + */ + public void snapshotState(StateSnapshotContext context) throws Exception { + serializer.serialize( + get(), new DataOutputViewStreamWrapper(context.getRawOperatorStateOutput())); + } } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java index 50ba95a20..1b7d41770 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java @@ -18,8 +18,14 @@ package org.apache.flink.ml.common.gbt.operators; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeutils.base.GenericArraySerializer; +import org.apache.flink.api.common.typeutils.base.array.IntPrimitiveArraySerializer; import org.apache.flink.iteration.IterationID; import org.apache.flink.iteration.IterationListener; +import org.apache.flink.iteration.datacache.nonkeyed.ListStateWithCache; +import org.apache.flink.iteration.operator.OperatorStateUtils; import org.apache.flink.ml.common.gbt.datastorage.IterationSharedStorage; import org.apache.flink.ml.common.gbt.defs.BinnedInstance; import org.apache.flink.ml.common.gbt.defs.GbtParams; @@ -27,8 +33,11 @@ import org.apache.flink.ml.common.gbt.defs.LocalState; import org.apache.flink.ml.common.gbt.defs.PredGradHess; import org.apache.flink.ml.common.gbt.loss.Loss; +import org.apache.flink.ml.common.gbt.typeinfo.BinnedInstanceSerializer; import org.apache.flink.ml.linalg.SparseVector; import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; @@ -37,12 +46,12 @@ import org.apache.flink.util.OutputTag; import org.apache.flink.util.Preconditions; +import org.apache.commons.collections.IteratorUtils; import org.eclipse.collections.impl.map.mutable.primitive.IntIntHashMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.util.ArrayList; -import java.util.List; +import java.util.Collections; /** * Calculates local histograms for local data partition. Specifically in the first round, this @@ -54,6 +63,10 @@ public class CacheDataCalcLocalHistsOperator extends AbstractStreamOperator stateOutputTag; // States of local data. - private transient List instancesCollecting; - private transient LocalState localState; - private transient TreeInitializer treeInitializer; - private transient HistBuilder histBuilder; + private transient ListStateWithCache instancesCollecting; + private transient ListState localState; + private transient ListState treeInitializer; + private transient ListState histBuilder; // Readers/writers of shared data. 
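    // The writers below also checkpoint their current values through the raw
    // operator state streams (see IterationSharedStorage.Writer#initializeState
    // and #snapshotState in this commit).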
private transient IterationSharedStorage.Writer instancesWriter; @@ -93,8 +106,31 @@ public CacheDataCalcLocalHistsOperator( } @Override - public void open() throws Exception { - instancesCollecting = new ArrayList<>(); + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + + instancesCollecting = + new ListStateWithCache<>( + BinnedInstanceSerializer.INSTANCE, + getContainingTask(), + getRuntimeContext(), + context, + getOperatorID()); + localState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + LOCAL_STATE_STATE_NAME, LocalState.class)); + treeInitializer = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + TREE_INITIALIZER_STATE_NAME, TreeInitializer.class)); + histBuilder = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + HIST_BUILDER_STATE_NAME, HistBuilder.class)); int subtaskId = getRuntimeContext().getIndexOfThisSubtask(); instancesWriter = @@ -103,7 +139,10 @@ public void open() throws Exception { subtaskId, sharedInstancesKey, getOperatorID(), + new GenericArraySerializer<>( + BinnedInstance.class, BinnedInstanceSerializer.INSTANCE), new BinnedInstance[0]); + instancesWriter.initializeState(context); shuffledIndicesWriter = IterationSharedStorage.getWriter( @@ -111,7 +150,9 @@ public void open() throws Exception { subtaskId, sharedShuffledIndicesKey, getOperatorID(), + IntPrimitiveArraySerializer.INSTANCE, new int[0]); + shuffledIndicesWriter.initializeState(context); this.pghReader = IterationSharedStorage.getReader(iterationID, subtaskId, sharedPredGradHessKey); @@ -119,6 +160,14 @@ public void open() throws Exception { IterationSharedStorage.getReader(iterationID, subtaskId, sharedSwappedIndicesKey); } + @Override + public void snapshotState(StateSnapshotContext context) throws Exception { + super.snapshotState(context); + instancesCollecting.snapshotState(context); + instancesWriter.snapshotState(context); + shuffledIndicesWriter.snapshotState(context); + } + @Override public void processElement1(StreamRecord streamRecord) throws Exception { Row row = streamRecord.getValue(); @@ -141,37 +190,44 @@ public void processElement1(StreamRecord streamRecord) throws Exception { @Override public void processElement2(StreamRecord streamRecord) throws Exception { - localState = streamRecord.getValue(); + localState.update(Collections.singletonList(streamRecord.getValue())); } + @SuppressWarnings("OptionalGetWithoutIsPresent") @Override public void onEpochWatermarkIncremented( int epochWatermark, Context context, Collector out) throws Exception { + LocalState localStateValue = + OperatorStateUtils.getUniqueElement(localState, "local_state").get(); if (0 == epochWatermark) { // Initializes local state in first round. 
- instancesWriter.set(instancesCollecting.toArray(new BinnedInstance[0])); + instancesWriter.set( + (BinnedInstance[]) + IteratorUtils.toArray( + instancesCollecting.get().iterator(), BinnedInstance.class)); instancesCollecting.clear(); new LocalStateInitializer(gbtParams) .init( - localState, + localStateValue, getRuntimeContext().getIndexOfThisSubtask(), getRuntimeContext().getNumberOfParallelSubtasks(), instancesWriter.get()); - treeInitializer = new TreeInitializer(localState.statics); - histBuilder = new HistBuilder(localState.statics); + treeInitializer.update( + Collections.singletonList(new TreeInitializer(localStateValue.statics))); + histBuilder.update(Collections.singletonList(new HistBuilder(localStateValue.statics))); } BinnedInstance[] instances = instancesWriter.get(); Preconditions.checkArgument( - getRuntimeContext().getIndexOfThisSubtask() == localState.statics.subtaskId); + getRuntimeContext().getIndexOfThisSubtask() == localStateValue.statics.subtaskId); PredGradHess[] pgh = pghReader.get(); // In the first round, use prior as the predictions. if (0 == pgh.length) { pgh = new PredGradHess[instances.length]; - double prior = localState.statics.prior; - Loss loss = localState.statics.loss; + double prior = localStateValue.statics.prior; + Loss loss = localStateValue.statics.loss; for (int i = 0; i < instances.length; i += 1) { double label = instances[i].label; pgh[i] = @@ -181,10 +237,12 @@ public void onEpochWatermarkIncremented( } int[] indices; - if (!localState.dynamics.inWeakLearner) { + if (!localStateValue.dynamics.inWeakLearner) { // When last tree is finished, initializes a new tree, and shuffle instance indices. - treeInitializer.init(localState.dynamics, shuffledIndicesWriter::set); - localState.dynamics.inWeakLearner = true; + OperatorStateUtils.getUniqueElement(treeInitializer, TREE_INITIALIZER_STATE_NAME) + .get() + .init(localStateValue.dynamics, shuffledIndicesWriter::set); + localStateValue.dynamics.inWeakLearner = true; indices = shuffledIndicesWriter.get(); } else { // Otherwise, uses the swapped instance indices. 
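A note on the round-0 branch in the hunk above: no gradients or hessians exist yet, so every instance is seeded from the prior prediction. A minimal sketch of that bootstrap, assuming Loss exposes gradient(pred, label) and hessian(pred, label) (the exact Loss API is not shown in this patch; the PredGradHess fields follow the serializer added below):

    // Sketch only, not the patch's exact code: seed per-instance state from the prior.
    static PredGradHess[] bootstrapFromPrior(BinnedInstance[] instances, double prior, Loss loss) {
        PredGradHess[] pgh = new PredGradHess[instances.length];
        for (int i = 0; i < instances.length; i += 1) {
            PredGradHess p = new PredGradHess();
            p.pred = prior;
            p.gradient = loss.gradient(prior, instances[i].label); // assumed signature
            p.hessian = loss.hessian(prior, instances[i].label); // assumed signature
            pgh[i] = p;
        }
        return pgh;
    }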
@@ -193,19 +251,25 @@ public void onEpochWatermarkIncremented( } Histogram localHists = - histBuilder.build( - localState.dynamics.layer, - localState.dynamics.nodeFeaturePairs, - indices, - instances, - pgh); + OperatorStateUtils.getUniqueElement(histBuilder, HIST_BUILDER_STATE_NAME) + .get() + .build( + localStateValue.dynamics.layer, + localStateValue.dynamics.nodeFeaturePairs, + indices, + instances, + pgh); out.collect(localHists); - context.output(stateOutputTag, localState); + context.output(stateOutputTag, localStateValue); } @Override public void onIterationTerminated(Context context, Collector collector) { instancesCollecting.clear(); + localState.clear(); + treeInitializer.clear(); + histBuilder.clear(); + instancesWriter.set(new BinnedInstance[0]); shuffledIndicesWriter.set(new int[0]); } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/PostSplitsOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/PostSplitsOperator.java index 1baf89669..52a1ad1b3 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/PostSplitsOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/PostSplitsOperator.java @@ -18,19 +18,29 @@ package org.apache.flink.ml.common.gbt.operators; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeutils.base.GenericArraySerializer; +import org.apache.flink.api.common.typeutils.base.array.IntPrimitiveArraySerializer; import org.apache.flink.iteration.IterationID; import org.apache.flink.iteration.IterationListener; +import org.apache.flink.iteration.operator.OperatorStateUtils; import org.apache.flink.ml.common.gbt.datastorage.IterationSharedStorage; import org.apache.flink.ml.common.gbt.defs.BinnedInstance; import org.apache.flink.ml.common.gbt.defs.LocalState; import org.apache.flink.ml.common.gbt.defs.PredGradHess; import org.apache.flink.ml.common.gbt.defs.Splits; +import org.apache.flink.ml.common.gbt.typeinfo.PredGradHessSerializer; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.util.Collector; import org.apache.flink.util.OutputTag; +import java.util.Collections; + /** * Post-process after global splits obtained, including split instances to left or child nodes, and * update instances scores after a tree is complete. 
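Both operators in this commit move what used to be plain transient fields into single-element operator ListState so they survive checkpoint and restore. The recurring pattern, condensed into a sketch (newSplits is a hypothetical value; the calls are the ones used in this patch):

    // Register in initializeState: one ListState per value, holding at most one element.
    ListState<Splits> splitsState =
            context.getOperatorStateStore()
                    .getListState(new ListStateDescriptor<>("splits", Splits.class));
    // Write: overwrite the singleton on every update.
    splitsState.update(Collections.singletonList(newSplits));
    // Read after restore: getUniqueElement returns an Optional of the single element.
    Splits restored = OperatorStateUtils.getUniqueElement(splitsState, "splits").get();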
@@ -39,6 +49,11 @@ public class PostSplitsOperator extends AbstractStreamOperator implements TwoInputStreamOperator, IterationListener { + private static final String LOCAL_STATE_STATE_NAME = "local_state"; + private static final String SPLITS_STATE_NAME = "splits"; + private static final String NODE_SPLITTER_STATE_NAME = "node_splitter"; + private static final String INSTANCE_UPDATER_STATE_NAME = "instance_updater"; + private final IterationID iterationID; private final String sharedInstancesKey; private final String sharedPredGradHessKey; @@ -51,10 +66,10 @@ public class PostSplitsOperator extends AbstractStreamOperator private IterationSharedStorage.Reader shuffledIndicesReader; private IterationSharedStorage.Writer swappedIndicesWriter; - private transient LocalState localState; - private transient Splits splits; - private transient NodeSplitter nodeSplitter; - private transient InstanceUpdater instanceUpdater; + private transient ListState localState; + private transient ListState splits; + private transient ListState nodeSplitter; + private transient ListState instanceUpdater; public PostSplitsOperator( IterationID iterationID, @@ -72,7 +87,28 @@ public PostSplitsOperator( } @Override - public void open() throws Exception { + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + + localState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + LOCAL_STATE_STATE_NAME, LocalState.class)); + splits = + context.getOperatorStateStore() + .getListState(new ListStateDescriptor<>(SPLITS_STATE_NAME, Splits.class)); + nodeSplitter = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + NODE_SPLITTER_STATE_NAME, NodeSplitter.class)); + instanceUpdater = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + INSTANCE_UPDATER_STATE_NAME, InstanceUpdater.class)); + int subtaskId = getRuntimeContext().getIndexOfThisSubtask(); pghWriter = IterationSharedStorage.getWriter( @@ -80,14 +116,19 @@ public void open() throws Exception { subtaskId, sharedPredGradHessKey, getOperatorID(), + new GenericArraySerializer<>( + PredGradHess.class, PredGradHessSerializer.INSTANCE), new PredGradHess[0]); + pghWriter.initializeState(context); swappedIndicesWriter = IterationSharedStorage.getWriter( iterationID, subtaskId, sharedSwappedIndicesKey, getOperatorID(), + IntPrimitiveArraySerializer.INSTANCE, new int[0]); + swappedIndicesWriter.initializeState(context); this.instancesReader = IterationSharedStorage.getReader(iterationID, subtaskId, sharedInstancesKey); @@ -95,13 +136,24 @@ public void open() throws Exception { IterationSharedStorage.getReader(iterationID, subtaskId, sharedShuffledIndicesKey); } + @Override + public void snapshotState(StateSnapshotContext context) throws Exception { + super.snapshotState(context); + pghWriter.snapshotState(context); + swappedIndicesWriter.snapshotState(context); + } + + @SuppressWarnings("OptionalGetWithoutIsPresent") @Override public void onEpochWatermarkIncremented( int epochWatermark, Context context, Collector collector) throws Exception { - LocalState localStateValue = localState; + LocalState localStateValue = + OperatorStateUtils.getUniqueElement(localState, LOCAL_STATE_STATE_NAME).get(); if (0 == epochWatermark) { - nodeSplitter = new NodeSplitter(localStateValue.statics); - instanceUpdater = new InstanceUpdater(localStateValue.statics); + nodeSplitter.update( + Collections.singletonList(new 
NodeSplitter(localStateValue.statics))); + instanceUpdater.update( + Collections.singletonList(new InstanceUpdater(localStateValue.statics))); } int[] indices = swappedIndicesWriter.get(); @@ -110,17 +162,20 @@ public void onEpochWatermarkIncremented( } BinnedInstance[] instances = instancesReader.get(); - nodeSplitter.split( - localStateValue.dynamics.layer, - localStateValue.dynamics.leaves, - splits.splits, - indices, - instances); + OperatorStateUtils.getUniqueElement(nodeSplitter, NODE_SPLITTER_STATE_NAME) + .get() + .split( + localStateValue.dynamics.layer, + localStateValue.dynamics.leaves, + OperatorStateUtils.getUniqueElement(splits, SPLITS_STATE_NAME).get().splits, + indices, + instances); if (localStateValue.dynamics.layer.isEmpty()) { localStateValue.dynamics.inWeakLearner = false; - instanceUpdater.update( - localStateValue.dynamics.leaves, indices, instances, pghWriter::set); + OperatorStateUtils.getUniqueElement(instanceUpdater, INSTANCE_UPDATER_STATE_NAME) + .get() + .update(localStateValue.dynamics.leaves, indices, instances, pghWriter::set); swappedIndicesWriter.set(new int[0]); } else { swappedIndicesWriter.set(indices); @@ -129,22 +184,26 @@ public void onEpochWatermarkIncremented( } @Override - public void onIterationTerminated(Context context, Collector collector) { + public void onIterationTerminated(Context context, Collector collector) + throws Exception { pghWriter.set(new PredGradHess[0]); swappedIndicesWriter.set(new int[0]); if (0 == getRuntimeContext().getIndexOfThisSubtask()) { - context.output(finalStateOutputTag, localState); + //noinspection OptionalGetWithoutIsPresent + context.output( + finalStateOutputTag, + OperatorStateUtils.getUniqueElement(localState, LOCAL_STATE_STATE_NAME).get()); } } @Override public void processElement1(StreamRecord element) throws Exception { - localState = element.getValue(); + localState.update(Collections.singletonList(element.getValue())); } @Override public void processElement2(StreamRecord element) throws Exception { - splits = element.getValue(); + splits.update(Collections.singletonList(element.getValue())); } @Override diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/BinnedInstanceSerializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/BinnedInstanceSerializer.java new file mode 100644 index 000000000..af195af16 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/BinnedInstanceSerializer.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.gbt.typeinfo; + +import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.base.DoubleSerializer; +import org.apache.flink.api.common.typeutils.base.IntSerializer; +import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.ml.common.gbt.defs.BinnedInstance; + +import org.eclipse.collections.impl.map.mutable.primitive.IntIntHashMap; + +import java.io.IOException; + +/** Serializer for {@link BinnedInstance}. */ +public final class BinnedInstanceSerializer extends TypeSerializerSingleton { + + public static final BinnedInstanceSerializer INSTANCE = new BinnedInstanceSerializer(); + private static final long serialVersionUID = 1L; + + @Override + public boolean isImmutableType() { + return false; + } + + @Override + public BinnedInstance createInstance() { + return new BinnedInstance(); + } + + @Override + public BinnedInstance copy(BinnedInstance from) { + BinnedInstance instance = new BinnedInstance(); + instance.features = new IntIntHashMap(from.features); + instance.label = from.label; + instance.weight = from.weight; + return instance; + } + + @Override + public BinnedInstance copy(BinnedInstance from, BinnedInstance reuse) { + assert from.getClass() == reuse.getClass(); + return copy(from); + } + + @Override + public int getLength() { + return -1; + } + + @Override + public void serialize(BinnedInstance record, DataOutputView target) throws IOException { + IntSerializer.INSTANCE.serialize(record.features.size(), target); + for (int k : record.features.keysView().toArray()) { + IntSerializer.INSTANCE.serialize(k, target); + IntSerializer.INSTANCE.serialize(record.features.get(k), target); + } + DoubleSerializer.INSTANCE.serialize(record.label, target); + DoubleSerializer.INSTANCE.serialize(record.weight, target); + } + + @Override + public BinnedInstance deserialize(DataInputView source) throws IOException { + BinnedInstance instance = new BinnedInstance(); + int numFeatures = IntSerializer.INSTANCE.deserialize(source); + instance.features = new IntIntHashMap(); + for (int i = 0; i < numFeatures; i += 1) { + int k = IntSerializer.INSTANCE.deserialize(source); + int v = IntSerializer.INSTANCE.deserialize(source); + instance.features.put(k, v); + } + instance.label = DoubleSerializer.INSTANCE.deserialize(source); + instance.weight = DoubleSerializer.INSTANCE.deserialize(source); + return instance; + } + + @Override + public BinnedInstance deserialize(BinnedInstance reuse, DataInputView source) + throws IOException { + return deserialize(source); + } + + @Override + public void copy(DataInputView source, DataOutputView target) throws IOException { + serialize(deserialize(source), target); + } + + // ------------------------------------------------------------------------ + + @Override + public TypeSerializerSnapshot snapshotConfiguration() { + return new BinnedInstanceSerializerSnapshot(); + } + + /** Serializer configuration snapshot for compatibility and format evolution. 
*/ + @SuppressWarnings("WeakerAccess") + public static final class BinnedInstanceSerializerSnapshot + extends SimpleTypeSerializerSnapshot { + + public BinnedInstanceSerializerSnapshot() { + super(BinnedInstanceSerializer::new); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/PredGradHessSerializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/PredGradHessSerializer.java new file mode 100644 index 000000000..d206a1427 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/PredGradHessSerializer.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.typeinfo; + +import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.base.DoubleSerializer; +import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.ml.common.gbt.defs.PredGradHess; + +import java.io.IOException; + +/** Serializer for {@link PredGradHess}. 
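+ * The serialized layout is three doubles in order: pred, gradient, hessian.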
*/ +public final class PredGradHessSerializer extends TypeSerializerSingleton { + + public static final PredGradHessSerializer INSTANCE = new PredGradHessSerializer(); + private static final long serialVersionUID = 1L; + + @Override + public boolean isImmutableType() { + return false; + } + + @Override + public PredGradHess createInstance() { + return new PredGradHess(); + } + + @Override + public PredGradHess copy(PredGradHess from) { + PredGradHess instance = new PredGradHess(); + instance.pred = from.pred; + instance.gradient = from.gradient; + instance.hessian = from.hessian; + return instance; + } + + @Override + public PredGradHess copy(PredGradHess from, PredGradHess reuse) { + assert from.getClass() == reuse.getClass(); + reuse.pred = from.pred; + reuse.gradient = from.gradient; + reuse.hessian = from.hessian; + return reuse; + } + + @Override + public int getLength() { + return -1; + } + + @Override + public void serialize(PredGradHess record, DataOutputView target) throws IOException { + DoubleSerializer.INSTANCE.serialize(record.pred, target); + DoubleSerializer.INSTANCE.serialize(record.gradient, target); + DoubleSerializer.INSTANCE.serialize(record.hessian, target); + } + + @Override + public PredGradHess deserialize(DataInputView source) throws IOException { + PredGradHess instance = new PredGradHess(); + instance.pred = DoubleSerializer.INSTANCE.deserialize(source); + instance.gradient = DoubleSerializer.INSTANCE.deserialize(source); + instance.hessian = DoubleSerializer.INSTANCE.deserialize(source); + return instance; + } + + @Override + public PredGradHess deserialize(PredGradHess reuse, DataInputView source) throws IOException { + reuse.pred = DoubleSerializer.INSTANCE.deserialize(source); + reuse.gradient = DoubleSerializer.INSTANCE.deserialize(source); + reuse.hessian = DoubleSerializer.INSTANCE.deserialize(source); + return reuse; + } + + @Override + public void copy(DataInputView source, DataOutputView target) throws IOException { + serialize(deserialize(source), target); + } + + // ------------------------------------------------------------------------ + + @Override + public TypeSerializerSnapshot snapshotConfiguration() { + return new PredGradHessSerializerSnapshot(); + } + + /** Serializer configuration snapshot for compatibility and format evolution. 
*/ + @SuppressWarnings("WeakerAccess") + public static final class PredGradHessSerializerSnapshot + extends SimpleTypeSerializerSnapshot { + + public PredGradHessSerializerSnapshot() { + super(PredGradHessSerializer::new); + } + } +} From 69e0774553931e12efd16a82e2b3b631c5169412 Mon Sep 17 00:00:00 2001 From: Fan Hong Date: Fri, 10 Feb 2023 15:54:15 +0800 Subject: [PATCH 04/47] Add GBTClassifier --- .../apache/flink/ml/util/ReadWriteUtils.java | 25 + .../org/apache/flink/ml/util/TestUtils.java | 52 ++ .../gbtclassifier/GBTClassifier.java | 75 +++ .../gbtclassifier/GBTClassifierModel.java | 134 +++++ .../gbtclassifier/GBTClassifierParams.java | 47 ++ .../flink/ml/common/gbt/BaseGBTModel.java | 68 +++ .../flink/ml/common/gbt/BaseGBTParams.java | 90 ++++ .../flink/ml/common/gbt/GBTModelData.java | 57 ++ .../flink/ml/common/gbt/GBTModelParams.java | 55 ++ .../apache/flink/ml/common/gbt/GBTRunner.java | 75 +++ .../typeinfo/CategoricalSplitSerializer.java | 124 +++++ .../typeinfo/ContinuousSplitSerializer.java | 127 +++++ .../gbt/typeinfo/GBTModelDataSerializer.java | 189 +++++++ .../gbt/typeinfo/GBTModelDataTypeInfo.java | 88 +++ .../typeinfo/GBTModelDataTypeInfoFactory.java | 39 ++ .../common/gbt/typeinfo/NodeSerializer.java | 131 +++++ .../common/gbt/typeinfo/SplitSerializer.java | 132 +++++ .../param/HasFeatureSubsetStrategy.java | 42 ++ .../flink/ml/common/param/HasLeafCol.java | 37 ++ .../flink/ml/common/param/HasLossType.java | 44 ++ .../flink/ml/common/param/HasMaxBins.java | 42 ++ .../flink/ml/common/param/HasMaxDepth.java | 38 ++ .../flink/ml/common/param/HasMinInfoGain.java | 42 ++ .../common/param/HasMinInstancesPerNode.java | 42 ++ .../param/HasMinWeightFractionPerNode.java | 42 ++ .../ml/common/param/HasProbabilityCol.java | 42 ++ .../flink/ml/common/param/HasStepSize.java | 42 ++ .../ml/common/param/HasSubsamplingRate.java | 42 ++ .../param/HasValidationIndicatorCol.java | 40 ++ .../ml/common/param/HasValidationTol.java | 43 ++ .../ml/classification/GBTClassifierTest.java | 500 ++++++++++++++++++ 31 files changed, 2546 insertions(+) create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifier.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifierModel.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifierParams.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BaseGBTModel.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BaseGBTParams.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelParams.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/CategoricalSplitSerializer.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/ContinuousSplitSerializer.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/GBTModelDataSerializer.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/GBTModelDataTypeInfo.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/GBTModelDataTypeInfoFactory.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/NodeSerializer.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/SplitSerializer.java create mode 100644 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasFeatureSubsetStrategy.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasLeafCol.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasLossType.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMaxBins.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMaxDepth.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMinInfoGain.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMinInstancesPerNode.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMinWeightFractionPerNode.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasProbabilityCol.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasStepSize.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasSubsamplingRate.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasValidationIndicatorCol.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasValidationTol.java create mode 100644 flink-ml-lib/src/test/java/org/apache/flink/ml/classification/GBTClassifierTest.java diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java b/flink-ml-core/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java index b284bbb3d..300f753eb 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java @@ -20,6 +20,7 @@ import org.apache.flink.api.common.eventtime.WatermarkStrategy; import org.apache.flink.api.common.serialization.Encoder; +import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.connector.source.Source; import org.apache.flink.connector.file.sink.FileSink; import org.apache.flink.connector.file.src.FileSource; @@ -323,4 +324,28 @@ public static Table loadModelData( env.fromSource(source, WatermarkStrategy.noWatermarks(), "modelData"); return tEnv.fromDataStream(modelDataStream); } + + /** + * Loads the model data from the given path using the model decoder. This overloaded version + * returns a table with only 1 column whose type is the class of the model data. + * + * @param tEnv A StreamTableEnvironment instance. + * @param path The parent directory of the model data file. + * @param modelDecoder The decoder used to decode the model data. + * @param typeInfo The type information of model data. + * @param The class type of the model data. + * @return The loaded model data. 
+ */ + public static Table loadModelData( + StreamTableEnvironment tEnv, + String path, + SimpleStreamFormat modelDecoder, + TypeInformation typeInfo) { + StreamExecutionEnvironment env = TableUtils.getExecutionEnvironment(tEnv); + Source source = + FileSource.forRecordStreamFormat(modelDecoder, new Path(getDataPath(path))).build(); + DataStream modelDataStream = + env.fromSource(source, WatermarkStrategy.noWatermarks(), "modelData", typeInfo); + return tEnv.fromDataStream(modelDataStream); + } } diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/util/TestUtils.java b/flink-ml-core/src/test/java/org/apache/flink/ml/util/TestUtils.java index ec97b48c6..59d1b119a 100644 --- a/flink-ml-core/src/test/java/org/apache/flink/ml/util/TestUtils.java +++ b/flink-ml-core/src/test/java/org/apache/flink/ml/util/TestUtils.java @@ -51,6 +51,7 @@ import org.apache.commons.collections.IteratorUtils; import org.apache.commons.lang3.ArrayUtils; +import org.junit.Assert; import java.io.DataInputStream; import java.io.DataOutputStream; @@ -324,4 +325,55 @@ public static DataFrame constructDataFrame( } return new DataFrame(columnNames, dataTypes, rowList); } + + /** + * Compare two lists of elements with the given comparator. Different from {@link + * org.apache.flink.test.util.TestBaseUtils#compareResultCollections}, the comparator is also + * used when comparing elements. + */ + public static void compareResultCollectionsWithComparator( + List expected, List actual, Comparator comparator) { + Assert.assertEquals(expected.size(), actual.size()); + expected.sort(comparator); + actual.sort(comparator); + for (int i = 0; i < expected.size(); i++) { + Assert.assertEquals(0, comparator.compare(expected.get(i), actual.get(i))); + } + } + + public static class DoubleComparatorWithDelta implements Comparator { + private final double delta; + + public DoubleComparatorWithDelta(double delta) { + this.delta = delta; + } + + @Override + public int compare(Double o1, Double o2) { + return Math.abs(o1 - o2) <= delta ? 0 : Double.compare(o1, o2); + } + } + + public static class DenseVectorComparatorWithDelta implements Comparator { + private final DoubleComparatorWithDelta doubleComparatorWithDelta; + + public DenseVectorComparatorWithDelta(double delta) { + doubleComparatorWithDelta = new DoubleComparatorWithDelta(delta); + } + + @Override + public int compare(DenseVector o1, DenseVector o2) { + if (o1.size() != o2.size()) { + return Integer.compare(o1.size(), o2.size()); + } else { + for (int i = 0; i < o1.size(); i++) { + int cmp = doubleComparatorWithDelta.compare(o1.get(i), o2.get(i)); + if (cmp != 0) { + return cmp; + } + } + } + return 0; + } + } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifier.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifier.java new file mode 100644 index 000000000..2ff827ea9 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifier.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.gbtclassifier; + +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.gbt.GBTModelData; +import org.apache.flink.ml.common.gbt.GBTRunner; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.apache.flink.table.api.Expressions.$; + +/** An Estimator which implements the gradient boosting trees classification algorithm. */ +public class GBTClassifier + implements Estimator, + GBTClassifierParams { + + private final Map, Object> paramMap = new HashMap<>(); + + public GBTClassifier() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + public static GBTClassifier load(StreamTableEnvironment tEnv, String path) throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + + @Override + public Map, Object> getParamMap() { + return paramMap; + } + + @Override + public GBTClassifierModel fit(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream modelData = GBTRunner.trainClassifier(inputs[0], this); + GBTClassifierModel model = new GBTClassifierModel(); + model.setModelData(tEnv.fromDataStream(modelData).renameColumns($("f0").as("modelData"))); + ReadWriteUtils.updateExistingParams(model, getParamMap()); + return model; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifierModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifierModel.java new file mode 100644 index 000000000..8e3110079 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifierModel.java @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.gbtclassifier; + +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.common.gbt.BaseGBTModel; +import org.apache.flink.ml.common.gbt.GBTModelData; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; +import org.apache.commons.math3.analysis.function.Sigmoid; +import org.eclipse.collections.impl.map.mutable.primitive.IntDoubleHashMap; + +import java.io.IOException; +import java.util.Collections; + +/** A Model computed by {@link GBTClassifier}. */ +public class GBTClassifierModel extends BaseGBTModel + implements GBTClassifierParams { + + /** + * Loads model data from path. + * + * @param tEnv A StreamTableEnvironment instance. + * @param path Model path. + * @return GBT classification model. + */ + public static GBTClassifierModel load(StreamTableEnvironment tEnv, String path) + throws IOException { + GBTClassifierModel model = ReadWriteUtils.loadStageParam(path); + Table modelDataTable = + ReadWriteUtils.loadModelData(tEnv, path, new GBTModelData.ModelDataDecoder()); + return model.setModelData(modelDataTable); + } + + @Override + public Table[] transform(Table... 
inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream inputStream = tEnv.toDataStream(inputs[0]); + final String broadcastModelKey = "broadcastModelKey"; + DataStream modelDataStream = GBTModelData.getModelDataStream(modelDataTable); + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll( + inputTypeInfo.getFieldTypes(), + Types.DOUBLE, + DenseVectorTypeInfo.INSTANCE, + DenseVectorTypeInfo.INSTANCE), + ArrayUtils.addAll( + inputTypeInfo.getFieldNames(), + getPredictionCol(), + getRawPredictionCol(), + getProbabilityCol())); + DataStream predictionResult = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(inputStream), + Collections.singletonMap(broadcastModelKey, modelDataStream), + inputList -> { + //noinspection unchecked + DataStream inputData = (DataStream) inputList.get(0); + return inputData.map( + new PredictLabelFunction( + broadcastModelKey, getInputCols(), getFeaturesCol()), + outputTypeInfo); + }); + return new Table[] {tEnv.fromDataStream(predictionResult)}; + } + + private static class PredictLabelFunction extends RichMapFunction { + + private static final Sigmoid sigmoid = new Sigmoid(); + + private final String broadcastModelKey; + private final String[] inputCols; + private final String featuresCol; + private GBTModelData modelData; + + public PredictLabelFunction( + String broadcastModelKey, String[] inputCols, String featuresCol) { + this.broadcastModelKey = broadcastModelKey; + this.inputCols = inputCols; + this.featuresCol = featuresCol; + } + + @Override + public Row map(Row value) throws Exception { + if (null == modelData) { + modelData = + (GBTModelData) + getRuntimeContext().getBroadcastVariable(broadcastModelKey).get(0); + } + IntDoubleHashMap features = modelData.rowToFeatures(value, inputCols, featuresCol); + double logits = modelData.predictRaw(features); + double prob = sigmoid.value(logits); + return Row.join( + value, + Row.of( + logits >= 0. ? 1. : 0., + Vectors.dense(-logits, logits), + Vectors.dense(1 - prob, prob))); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifierParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifierParams.java new file mode 100644 index 000000000..20ee450ee --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifierParams.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.classification.gbtclassifier; + +import org.apache.flink.ml.common.gbt.BaseGBTParams; +import org.apache.flink.ml.common.param.HasProbabilityCol; +import org.apache.flink.ml.common.param.HasRawPredictionCol; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.StringParam; + +/** + * Parameters for {@link GBTClassifier}. + * + * @param The class type of this instance. + */ +public interface GBTClassifierParams + extends BaseGBTParams, HasRawPredictionCol, HasProbabilityCol { + + Param LOSS_TYPE = + new StringParam( + "lossType", "Loss type.", "logistic", ParamValidators.inArray("logistic")); + + default String getLossType() { + return get(LOSS_TYPE); + } + + default T setLossType(String value) { + return set(LOSS_TYPE, value); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BaseGBTModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BaseGBTModel.java new file mode 100644 index 000000000..5ba8b7a9d --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BaseGBTModel.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt; + +import org.apache.flink.ml.api.Model; +import org.apache.flink.ml.classification.gbtclassifier.GBTClassifier; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.table.api.Table; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +/** Base model computed by {@link GBTClassifier}. */ +public abstract class BaseGBTModel> + implements Model, GBTModelParams { + + protected final Map, Object> paramMap = new HashMap<>(); + protected Table modelDataTable; + + public BaseGBTModel() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public Table[] getModelData() { + return new Table[] {modelDataTable}; + } + + @Override + public T setModelData(Table... 
inputs) { + modelDataTable = inputs[0]; + //noinspection unchecked + return (T) this; + } + + @Override + public Map, Object> getParamMap() { + return paramMap; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + ReadWriteUtils.saveModelData( + GBTModelData.getModelDataStream(modelDataTable), + path, + new GBTModelData.ModelDataEncoder()); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BaseGBTParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BaseGBTParams.java new file mode 100644 index 000000000..de65c77f5 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BaseGBTParams.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt; + +import org.apache.flink.ml.common.param.HasFeatureSubsetStrategy; +import org.apache.flink.ml.common.param.HasLeafCol; +import org.apache.flink.ml.common.param.HasMaxBins; +import org.apache.flink.ml.common.param.HasMaxDepth; +import org.apache.flink.ml.common.param.HasMaxIter; +import org.apache.flink.ml.common.param.HasMinInfoGain; +import org.apache.flink.ml.common.param.HasMinInstancesPerNode; +import org.apache.flink.ml.common.param.HasMinWeightFractionPerNode; +import org.apache.flink.ml.common.param.HasSeed; +import org.apache.flink.ml.common.param.HasStepSize; +import org.apache.flink.ml.common.param.HasSubsamplingRate; +import org.apache.flink.ml.common.param.HasValidationIndicatorCol; +import org.apache.flink.ml.common.param.HasValidationTol; +import org.apache.flink.ml.common.param.HasWeightCol; +import org.apache.flink.ml.param.DoubleParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; + +/** + * Common parameters for GBT classifier and regressor. + * + *
<p>
TODO: support param thresholds, impurity (actually meaningless) + * + * @param The class type of this instance. + */ +public interface BaseGBTParams + extends GBTModelParams, + HasLeafCol, + HasWeightCol, + HasMaxDepth, + HasMaxBins, + HasMinInstancesPerNode, + HasMinWeightFractionPerNode, + HasMinInfoGain, + HasMaxIter, + HasStepSize, + HasSeed, + HasSubsamplingRate, + HasFeatureSubsetStrategy, + HasValidationIndicatorCol, + HasValidationTol { + Param REG_LAMBDA = + new DoubleParam( + "regLambda", + "Regularization term for the number of leaves.", + 0., + ParamValidators.gtEq(0.)); + Param REG_GAMMA = + new DoubleParam( + "regGamma", + "L2 regularization term for the weights of leaves.", + 1., + ParamValidators.gtEq(0)); + + default double getRegLambda() { + return get(REG_LAMBDA); + } + + default T setRegLambda(Double value) { + return set(REG_LAMBDA, value); + } + + default double getRegGamma() { + return get(REG_GAMMA); + } + + default T setRegGamma(Double value) { + return set(REG_GAMMA, value); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelData.java index 910588b97..542caf7e2 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelData.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelData.java @@ -18,9 +18,22 @@ package org.apache.flink.ml.common.gbt; +import org.apache.flink.api.common.serialization.Encoder; +import org.apache.flink.api.common.typeinfo.TypeInfo; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.connector.file.src.reader.SimpleStreamFormat; +import org.apache.flink.core.fs.FSDataInputStream; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataInputViewStreamWrapper; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.core.memory.DataOutputViewStreamWrapper; +import org.apache.flink.ml.classification.gbtclassifier.GBTClassifierModel; import org.apache.flink.ml.common.gbt.defs.FeatureMeta; import org.apache.flink.ml.common.gbt.defs.LocalState; import org.apache.flink.ml.common.gbt.defs.Node; +import org.apache.flink.ml.common.gbt.typeinfo.GBTModelDataSerializer; +import org.apache.flink.ml.common.gbt.typeinfo.GBTModelDataTypeInfoFactory; import org.apache.flink.ml.feature.stringindexer.StringIndexerModel; import org.apache.flink.ml.linalg.SparseVector; import org.apache.flink.ml.linalg.Vector; @@ -34,6 +47,8 @@ import org.eclipse.collections.impl.map.mutable.primitive.IntObjectHashMap; import org.eclipse.collections.impl.map.mutable.primitive.ObjectIntHashMap; +import java.io.IOException; +import java.io.OutputStream; import java.util.BitSet; import java.util.List; @@ -43,6 +58,7 @@ *
<p>
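Model data is serialized with {@link GBTModelDataSerializer}; the {@link ModelDataEncoder} and
 * {@link ModelDataDecoder} classes below wrap that serializer for file sinks and sources. A
 * minimal save sketch (mirroring the save path in {@code GBTModel#save}):
 *
 * <pre>{@code
 * ReadWriteUtils.saveModelData(
 *         GBTModelData.getModelDataStream(modelDataTable), path, new ModelDataEncoder());
 * }</pre>
 *
 * <p>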
This class also provides methods to convert model data from Table to Datastream, and classes * to save/load model data. */ +@TypeInfo(GBTModelDataTypeInfoFactory.class) public class GBTModelData { public String type; @@ -173,4 +189,45 @@ public String toString() { "GBTModelData{type=%s, prior=%s, roots=%s, categoryToIdMaps=%s, featureIdToBinEdges=%s, isCategorical=%s}", type, prior, roots, categoryToIdMaps, featureIdToBinEdges, isCategorical); } + + /** Encoder for {@link GBTModelData}. */ + public static class ModelDataEncoder implements Encoder { + @Override + public void encode(GBTModelData modelData, OutputStream outputStream) throws IOException { + DataOutputView dataOutputView = new DataOutputViewStreamWrapper(outputStream); + final GBTModelDataSerializer serializer = GBTModelDataSerializer.INSTANCE; + serializer.serialize(modelData, dataOutputView); + } + } + + /** Decoder for {@link GBTModelData}. */ + public static class ModelDataDecoder extends SimpleStreamFormat { + @Override + public Reader createReader(Configuration config, FSDataInputStream stream) { + return new Reader() { + + private final GBTModelDataSerializer serializer = GBTModelDataSerializer.INSTANCE; + + @Override + public GBTModelData read() { + DataInputView source = new DataInputViewStreamWrapper(stream); + try { + return serializer.deserialize(source); + } catch (IOException e) { + return null; + } + } + + @Override + public void close() throws IOException { + stream.close(); + } + }; + } + + @Override + public TypeInformation getProducedType() { + return TypeInformation.of(GBTModelData.class); + } + } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelParams.java new file mode 100644 index 000000000..c0997c4cf --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelParams.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt; + +import org.apache.flink.ml.classification.gbtclassifier.GBTClassifierModel; +import org.apache.flink.ml.common.param.HasCategoricalCols; +import org.apache.flink.ml.common.param.HasFeaturesCol; +import org.apache.flink.ml.common.param.HasLabelCol; +import org.apache.flink.ml.common.param.HasPredictionCol; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.StringArrayParam; + +/** + * Params of {@link GBTClassifierModel}. + * + *
<p>
If the input features come from 1 column of vector type, `featuresCol` should be used, and all + * features are treated as continuous features. Otherwise, `inputCols` should be used for multiple + * columns. Columns whose names specified in `categoricalCols` are treated as categorical features, + * while others are continuous features. + * + *
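+ * <p>For example (column names are illustrative): with {@code inputCols = ["f0", "f1", "f2"]}
+ * and {@code categoricalCols = ["f2"]}, {@code f0} and {@code f1} are treated as continuous
+ * features while {@code f2} is treated as a categorical feature.
+ *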
<p>
NOTE: `inputCols` and `featuresCol` are in conflict with each other, so they should not be set + * at the same time. In addition, `inputCols` has a higher precedence than `featuresCol`, that is, + * `featuresCol` is ignored when `inputCols` is not `null`. + * + * @param The class type of this instance. + */ +public interface GBTModelParams + extends HasFeaturesCol, HasLabelCol, HasCategoricalCols, HasPredictionCol { + + Param INPUT_COLS = new StringArrayParam("inputCols", "Input column names.", null); + + default String[] getInputCols() { + return get(INPUT_COLS); + } + + default T setInputCols(String... value) { + return set(INPUT_COLS, value); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTRunner.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTRunner.java index 217f836a2..c748325ee 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTRunner.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTRunner.java @@ -26,11 +26,14 @@ import org.apache.flink.iteration.IterationID; import org.apache.flink.iteration.Iterations; import org.apache.flink.iteration.ReplayableDataStreamList; +import org.apache.flink.ml.classification.gbtclassifier.GBTClassifierParams; import org.apache.flink.ml.common.broadcast.BroadcastUtils; import org.apache.flink.ml.common.datastream.DataStreamUtils; import org.apache.flink.ml.common.gbt.defs.FeatureMeta; import org.apache.flink.ml.common.gbt.defs.GbtParams; import org.apache.flink.ml.common.gbt.defs.LocalState; +import org.apache.flink.ml.common.gbt.defs.TaskType; +import org.apache.flink.ml.param.Param; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.table.api.Table; import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; @@ -43,11 +46,28 @@ import java.util.Collections; import java.util.Comparator; import java.util.HashMap; +import java.util.HashSet; +import java.util.List; import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; /** Runs a gradient boosting trees implementation. */ public class GBTRunner { + public static DataStream trainClassifier(Table data, BaseGBTParams estimator) { + return train(data, estimator, TaskType.CLASSIFICATION); + } + + public static DataStream trainRegressor(Table data, BaseGBTParams estimator) { + return train(data, estimator, TaskType.REGRESSION); + } + + static DataStream train( + Table data, BaseGBTParams estimator, TaskType taskType) { + return train(data, fromEstimator(estimator, taskType)); + } + /** Trains a gradient boosting tree model with given data and parameters. 
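     *
     * <p>A usage sketch (the estimator is assumed to be fully configured):
     *
     * <pre>{@code
     * GbtParams params = GBTRunner.fromEstimator(estimator, TaskType.CLASSIFICATION);
     * DataStream<GBTModelData> modelData = GBTRunner.train(dataTable, params);
     * }</pre>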
*/ static DataStream train(Table dataTable, GbtParams p) { StreamTableEnvironment tEnv = @@ -106,6 +126,61 @@ private static DataStream boost( return state.map(GBTModelData::fromLocalState); } + public static GbtParams fromEstimator(BaseGBTParams estimator, TaskType taskType) { + final Map, Object> paramMap = estimator.getParamMap(); + final Set> unsupported = + new HashSet<>( + Arrays.asList( + BaseGBTParams.WEIGHT_COL, + BaseGBTParams.LEAF_COL, + BaseGBTParams.VALIDATION_INDICATOR_COL)); + List> unsupportedButSet = + unsupported.stream() + .filter(d -> null != paramMap.get(d)) + .collect(Collectors.toList()); + if (!unsupportedButSet.isEmpty()) { + throw new UnsupportedOperationException( + String.format( + "Parameters %s are not supported yet right now.", + unsupportedButSet.stream() + .map(d -> d.name) + .collect(Collectors.joining(", ")))); + } + + GbtParams p = new GbtParams(); + p.taskType = taskType; + p.featureCols = estimator.getInputCols(); + p.vectorCol = estimator.getFeaturesCol(); + p.isInputVector = (null == p.featureCols); + p.labelCol = estimator.getLabelCol(); + p.weightCol = estimator.getWeightCol(); + p.categoricalCols = estimator.getCategoricalCols(); + + p.maxDepth = estimator.getMaxDepth(); + p.maxBins = estimator.getMaxBins(); + p.minInstancesPerNode = estimator.getMinInstancesPerNode(); + p.minWeightFractionPerNode = estimator.getMinWeightFractionPerNode(); + p.minInfoGain = estimator.getMinInfoGain(); + p.maxIter = estimator.getMaxIter(); + p.stepSize = estimator.getStepSize(); + p.seed = estimator.getSeed(); + p.subsamplingRate = estimator.getSubsamplingRate(); + p.featureSubsetStrategy = estimator.getFeatureSubsetStrategy(); + p.validationTol = estimator.getValidationTol(); + p.gamma = estimator.getRegGamma(); + p.lambda = estimator.getRegLambda(); + + if (TaskType.CLASSIFICATION.equals(p.taskType)) { + p.lossType = estimator.get(GBTClassifierParams.LOSS_TYPE); + } else { + // TODO: add GBTRegressorParams.LOSS_TYPE in next PR. + p.lossType = estimator.get(GBTClassifierParams.LOSS_TYPE); + } + p.maxNumLeaves = 1 << p.maxDepth - 1; + p.useMissing = true; + return p; + } + private static class InitLocalStateFunction extends RichMapFunction { private final String featureMetaBcName; private final String labelSumCountBcName; diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/CategoricalSplitSerializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/CategoricalSplitSerializer.java new file mode 100644 index 000000000..13889a186 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/CategoricalSplitSerializer.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.gbt.typeinfo; + +import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.base.BooleanSerializer; +import org.apache.flink.api.common.typeutils.base.DoubleSerializer; +import org.apache.flink.api.common.typeutils.base.IntSerializer; +import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton; +import org.apache.flink.api.common.typeutils.base.array.BytePrimitiveArraySerializer; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.ml.common.gbt.defs.Split; + +import java.io.IOException; +import java.util.BitSet; + +/** Specialized serializer for {@link Split.CategoricalSplit}. */ +public final class CategoricalSplitSerializer + extends TypeSerializerSingleton { + + private static final long serialVersionUID = 1L; + + public static final CategoricalSplitSerializer INSTANCE = new CategoricalSplitSerializer(); + + @Override + public boolean isImmutableType() { + return false; + } + + @Override + public Split.CategoricalSplit createInstance() { + return new Split.CategoricalSplit(-1, Split.INVALID_GAIN, 0, false, 0., new BitSet()); + } + + @Override + public Split.CategoricalSplit copy(Split.CategoricalSplit from) { + return new Split.CategoricalSplit( + from.featureId, + from.gain, + from.missingBin, + from.missingGoLeft, + from.prediction, + from.categoriesGoLeft); + } + + @Override + public Split.CategoricalSplit copy(Split.CategoricalSplit from, Split.CategoricalSplit reuse) { + assert from.getClass() == reuse.getClass(); + return copy(from); + } + + @Override + public int getLength() { + return -1; + } + + @Override + public void serialize(Split.CategoricalSplit record, DataOutputView target) throws IOException { + IntSerializer.INSTANCE.serialize(record.featureId, target); + DoubleSerializer.INSTANCE.serialize(record.gain, target); + IntSerializer.INSTANCE.serialize(record.missingBin, target); + BooleanSerializer.INSTANCE.serialize(record.missingGoLeft, target); + DoubleSerializer.INSTANCE.serialize(record.prediction, target); + BytePrimitiveArraySerializer.INSTANCE.serialize( + record.categoriesGoLeft.toByteArray(), target); + } + + @Override + public Split.CategoricalSplit deserialize(DataInputView source) throws IOException { + return new Split.CategoricalSplit( + IntSerializer.INSTANCE.deserialize(source), + DoubleSerializer.INSTANCE.deserialize(source), + IntSerializer.INSTANCE.deserialize(source), + BooleanSerializer.INSTANCE.deserialize(source), + DoubleSerializer.INSTANCE.deserialize(source), + BitSet.valueOf(BytePrimitiveArraySerializer.INSTANCE.deserialize(source))); + } + + @Override + public Split.CategoricalSplit deserialize(Split.CategoricalSplit reuse, DataInputView source) + throws IOException { + return deserialize(source); + } + + @Override + public void copy(DataInputView source, DataOutputView target) throws IOException { + serialize(deserialize(source), target); + } + + // ------------------------------------------------------------------------ + + @Override + public TypeSerializerSnapshot snapshotConfiguration() { + return new CategoricalSplitSerializerSnapshot(); + } + + /** Serializer configuration snapshot for compatibility and format evolution. 
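+     * Because the serializer is a stateless singleton, restoring from this snapshot simply
+     * re-instantiates {@link CategoricalSplitSerializer}.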
*/ + @SuppressWarnings("WeakerAccess") + public static final class CategoricalSplitSerializerSnapshot + extends SimpleTypeSerializerSnapshot { + + public CategoricalSplitSerializerSnapshot() { + super(CategoricalSplitSerializer::new); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/ContinuousSplitSerializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/ContinuousSplitSerializer.java new file mode 100644 index 000000000..b57e2c94d --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/ContinuousSplitSerializer.java @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.typeinfo; + +import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.base.BooleanSerializer; +import org.apache.flink.api.common.typeutils.base.DoubleSerializer; +import org.apache.flink.api.common.typeutils.base.IntSerializer; +import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.ml.common.gbt.defs.Split; + +import java.io.IOException; + +/** Specialized serializer for {@link Split.ContinuousSplit}. 
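+ *
+ * <p>Fields are written in a fixed order ({@code featureId}, {@code gain}, {@code missingBin},
+ * {@code missingGoLeft}, {@code prediction}, {@code threshold}, {@code isUnseenMissing},
+ * {@code zeroBin}); deserialization reads them back in exactly the same order.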
*/ +public final class ContinuousSplitSerializer + extends TypeSerializerSingleton { + + private static final long serialVersionUID = 1L; + + public static final ContinuousSplitSerializer INSTANCE = new ContinuousSplitSerializer(); + + @Override + public boolean isImmutableType() { + return false; + } + + @Override + public Split.ContinuousSplit createInstance() { + return new Split.ContinuousSplit(-1, Split.INVALID_GAIN, 0, false, 0., 0., false, 0); + } + + @Override + public Split.ContinuousSplit copy(Split.ContinuousSplit from) { + return new Split.ContinuousSplit( + from.featureId, + from.gain, + from.missingBin, + from.missingGoLeft, + from.prediction, + from.threshold, + from.isUnseenMissing, + from.zeroBin); + } + + @Override + public Split.ContinuousSplit copy(Split.ContinuousSplit from, Split.ContinuousSplit reuse) { + assert from.getClass() == reuse.getClass(); + return copy(from); + } + + @Override + public int getLength() { + return -1; + } + + @Override + public void serialize(Split.ContinuousSplit record, DataOutputView target) throws IOException { + IntSerializer.INSTANCE.serialize(record.featureId, target); + DoubleSerializer.INSTANCE.serialize(record.gain, target); + IntSerializer.INSTANCE.serialize(record.missingBin, target); + BooleanSerializer.INSTANCE.serialize(record.missingGoLeft, target); + DoubleSerializer.INSTANCE.serialize(record.prediction, target); + DoubleSerializer.INSTANCE.serialize(record.threshold, target); + BooleanSerializer.INSTANCE.serialize(record.isUnseenMissing, target); + IntSerializer.INSTANCE.serialize(record.zeroBin, target); + } + + @Override + public Split.ContinuousSplit deserialize(DataInputView source) throws IOException { + return new Split.ContinuousSplit( + IntSerializer.INSTANCE.deserialize(source), + DoubleSerializer.INSTANCE.deserialize(source), + IntSerializer.INSTANCE.deserialize(source), + BooleanSerializer.INSTANCE.deserialize(source), + DoubleSerializer.INSTANCE.deserialize(source), + DoubleSerializer.INSTANCE.deserialize(source), + BooleanSerializer.INSTANCE.deserialize(source), + IntSerializer.INSTANCE.deserialize(source)); + } + + @Override + public Split.ContinuousSplit deserialize(Split.ContinuousSplit reuse, DataInputView source) + throws IOException { + return deserialize(source); + } + + @Override + public void copy(DataInputView source, DataOutputView target) throws IOException { + serialize(deserialize(source), target); + } + + // ------------------------------------------------------------------------ + + @Override + public TypeSerializerSnapshot snapshotConfiguration() { + return new ContinuousSplitSplitSerializerSnapshot(); + } + + /** Serializer configuration snapshot for compatibility and format evolution. */ + @SuppressWarnings("WeakerAccess") + public static final class ContinuousSplitSplitSerializerSnapshot + extends SimpleTypeSerializerSnapshot { + + public ContinuousSplitSplitSerializerSnapshot() { + super(ContinuousSplitSerializer::new); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/GBTModelDataSerializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/GBTModelDataSerializer.java new file mode 100644 index 000000000..419130146 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/GBTModelDataSerializer.java @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.flink.ml.common.gbt.typeinfo; + +import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.base.BooleanSerializer; +import org.apache.flink.api.common.typeutils.base.DoubleSerializer; +import org.apache.flink.api.common.typeutils.base.IntSerializer; +import org.apache.flink.api.common.typeutils.base.StringSerializer; +import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton; +import org.apache.flink.api.common.typeutils.base.array.BytePrimitiveArraySerializer; +import org.apache.flink.api.common.typeutils.base.array.DoublePrimitiveArraySerializer; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.ml.common.gbt.GBTModelData; +import org.apache.flink.ml.common.gbt.defs.Node; + +import org.eclipse.collections.impl.map.mutable.primitive.IntObjectHashMap; +import org.eclipse.collections.impl.map.mutable.primitive.ObjectIntHashMap; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.BitSet; + +/** Specialized serializer for {@link GBTModelData}. 
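+ *
+ * <p>Variable-size members are length-prefixed: the tree roots, the per-feature category-to-id
+ * maps, and the per-feature bin edges are each written as a count followed by that many entries,
+ * while {@code isCategorical} is stored as the byte array of the underlying
+ * {@link java.util.BitSet}.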
*/ +public final class GBTModelDataSerializer extends TypeSerializerSingleton { + + public static final GBTModelDataSerializer INSTANCE = new GBTModelDataSerializer(); + private static final long serialVersionUID = 1L; + private static final NodeSerializer NODE_SERIALIZER = NodeSerializer.INSTANCE; + + @Override + public boolean isImmutableType() { + return false; + } + + @Override + public GBTModelData createInstance() { + return new GBTModelData(); + } + + @Override + public GBTModelData copy(GBTModelData from) { + GBTModelData record = new GBTModelData(); + record.type = from.type; + record.isInputVector = from.isInputVector; + + record.prior = from.prior; + record.stepSize = from.stepSize; + + record.roots = new ArrayList<>(from.roots); + record.categoryToIdMaps = new IntObjectHashMap<>(from.categoryToIdMaps); + record.featureIdToBinEdges = new IntObjectHashMap<>(from.featureIdToBinEdges); + record.isCategorical = BitSet.valueOf(from.isCategorical.toByteArray()); + return record; + } + + @Override + public GBTModelData copy(GBTModelData from, GBTModelData reuse) { + return copy(from); + } + + @Override + public int getLength() { + return -1; + } + + @Override + public void serialize(GBTModelData record, DataOutputView target) throws IOException { + StringSerializer.INSTANCE.serialize(record.type, target); + BooleanSerializer.INSTANCE.serialize(record.isInputVector, target); + + DoubleSerializer.INSTANCE.serialize(record.prior, target); + DoubleSerializer.INSTANCE.serialize(record.stepSize, target); + + IntSerializer.INSTANCE.serialize(record.roots.size(), target); + for (Node root : record.roots) { + NODE_SERIALIZER.serialize(root, target); + } + + IntSerializer.INSTANCE.serialize(record.categoryToIdMaps.size(), target); + for (int featureId : record.categoryToIdMaps.keysView().toArray()) { + ObjectIntHashMap categoryToIdMap = record.categoryToIdMaps.get(featureId); + IntSerializer.INSTANCE.serialize(featureId, target); + IntSerializer.INSTANCE.serialize(categoryToIdMap.size(), target); + for (String category : categoryToIdMap.keysView()) { + StringSerializer.INSTANCE.serialize(category, target); + IntSerializer.INSTANCE.serialize(categoryToIdMap.get(category), target); + } + } + + IntSerializer.INSTANCE.serialize(record.featureIdToBinEdges.size(), target); + for (int featureId : record.featureIdToBinEdges.keysView().toArray()) { + double[] binEdges = record.featureIdToBinEdges.get(featureId); + IntSerializer.INSTANCE.serialize(featureId, target); + DoublePrimitiveArraySerializer.INSTANCE.serialize(binEdges, target); + } + + BytePrimitiveArraySerializer.INSTANCE.serialize(record.isCategorical.toByteArray(), target); + } + + @Override + public GBTModelData deserialize(DataInputView source) throws IOException { + GBTModelData record = new GBTModelData(); + + record.type = StringSerializer.INSTANCE.deserialize(source); + record.isInputVector = BooleanSerializer.INSTANCE.deserialize(source); + + record.prior = DoubleSerializer.INSTANCE.deserialize(source); + record.stepSize = DoubleSerializer.INSTANCE.deserialize(source); + + int numRoots = IntSerializer.INSTANCE.deserialize(source); + record.roots = new ArrayList<>(); + for (int i = 0; i < numRoots; i += 1) { + record.roots.add(NODE_SERIALIZER.deserialize(source)); + } + + int numCategoricalFeatures = IntSerializer.INSTANCE.deserialize(source); + record.categoryToIdMaps = IntObjectHashMap.newMap(); + for (int k = 0; k < numCategoricalFeatures; k += 1) { + int featureId = IntSerializer.INSTANCE.deserialize(source); + int 
categoryToIdMapSize = IntSerializer.INSTANCE.deserialize(source); + ObjectIntHashMap categoryToIdMap = ObjectIntHashMap.newMap(); + for (int i = 0; i < categoryToIdMapSize; i += 1) { + categoryToIdMap.put( + StringSerializer.INSTANCE.deserialize(source), + IntSerializer.INSTANCE.deserialize(source)); + } + record.categoryToIdMaps.put(featureId, categoryToIdMap); + } + + int numContinuousFeatures = IntSerializer.INSTANCE.deserialize(source); + record.featureIdToBinEdges = IntObjectHashMap.newMap(); + for (int i = 0; i < numContinuousFeatures; i += 1) { + int featureId = IntSerializer.INSTANCE.deserialize(source); + double[] binEdges = DoublePrimitiveArraySerializer.INSTANCE.deserialize(source); + record.featureIdToBinEdges.put(featureId, binEdges); + } + + record.isCategorical = + BitSet.valueOf(BytePrimitiveArraySerializer.INSTANCE.deserialize(source)); + return record; + } + + @Override + public GBTModelData deserialize(GBTModelData reuse, DataInputView source) throws IOException { + return deserialize(source); + } + + @Override + public void copy(DataInputView source, DataOutputView target) throws IOException { + serialize(deserialize(source), target); + } + + // ------------------------------------------------------------------------ + + @Override + public TypeSerializerSnapshot snapshotConfiguration() { + return new GBTModelDataSerializerSnapshot(); + } + + /** Serializer configuration snapshot for compatibility and format evolution. */ + @SuppressWarnings("WeakerAccess") + public static final class GBTModelDataSerializerSnapshot + extends SimpleTypeSerializerSnapshot { + + public GBTModelDataSerializerSnapshot() { + super(GBTModelDataSerializer::new); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/GBTModelDataTypeInfo.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/GBTModelDataTypeInfo.java new file mode 100644 index 000000000..42fc88718 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/GBTModelDataTypeInfo.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.flink.ml.common.gbt.typeinfo; + +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.ml.common.gbt.GBTModelData; + +/** A {@link TypeInformation} for the {@link GBTModelData} type. 
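+ *
+ * <p>Flink resolves this type information through the {@code @TypeInfo} annotation on {@link
+ * GBTModelData}, which points to {@link GBTModelDataTypeInfoFactory}, so no explicit
+ * registration is needed.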
*/ +public class GBTModelDataTypeInfo extends TypeInformation { + + public static final GBTModelDataTypeInfo INSTANCE = new GBTModelDataTypeInfo(); + + private GBTModelDataTypeInfo() {} + + @Override + public boolean isBasicType() { + return false; + } + + @Override + public boolean isTupleType() { + return false; + } + + @Override + public int getArity() { + return 2; + } + + @Override + public int getTotalFields() { + return 2; + } + + @Override + public Class getTypeClass() { + return GBTModelData.class; + } + + @Override + public boolean isKeyType() { + return false; + } + + @Override + public TypeSerializer createSerializer(ExecutionConfig executionConfig) { + return new GBTModelDataSerializer(); + } + + @Override + public String toString() { + return "SplitTypeInfo"; + } + + @Override + public boolean equals(Object o) { + return o instanceof GBTModelDataTypeInfo; + } + + @Override + public int hashCode() { + return getClass().hashCode(); + } + + @Override + public boolean canEqual(Object o) { + return o instanceof GBTModelDataTypeInfo; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/GBTModelDataTypeInfoFactory.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/GBTModelDataTypeInfoFactory.java new file mode 100644 index 000000000..f32e1176e --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/GBTModelDataTypeInfoFactory.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.typeinfo; + +import org.apache.flink.api.common.typeinfo.TypeInfoFactory; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.TypeExtractor; +import org.apache.flink.ml.common.gbt.GBTModelData; + +import java.lang.reflect.Type; +import java.util.Map; + +/** + * Used by {@link TypeExtractor} to create a {@link TypeInformation} for implementations of {@link + * GBTModelData}. + */ +public class GBTModelDataTypeInfoFactory extends TypeInfoFactory { + @Override + public TypeInformation createTypeInfo( + Type t, Map> genericParameters) { + return GBTModelDataTypeInfo.INSTANCE; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/NodeSerializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/NodeSerializer.java new file mode 100644 index 000000000..6fbe73e87 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/NodeSerializer.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.typeinfo; + +import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.base.BooleanSerializer; +import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.ml.common.gbt.defs.Node; + +import java.io.IOException; + +/** Serializer for {@link Node}. */ +public final class NodeSerializer extends TypeSerializerSingleton { + + public static final NodeSerializer INSTANCE = new NodeSerializer(); + private static final long serialVersionUID = 1L; + + private static final SplitSerializer SPLIT_SERIALIZER = SplitSerializer.INSTANCE; + + @Override + public boolean isImmutableType() { + return false; + } + + @Override + public Node createInstance() { + return new Node(); + } + + @Override + public Node copy(Node from) { + Node node = new Node(); + node.split = SPLIT_SERIALIZER.copy(from.split); + node.isLeaf = from.isLeaf; + if (!node.isLeaf) { + node.left = copy(from.left); + node.right = copy(from.right); + } + return node; + } + + @Override + public Node copy(Node from, Node reuse) { + assert from.getClass() == reuse.getClass(); + SPLIT_SERIALIZER.copy(from.split, reuse.split); + reuse.isLeaf = from.isLeaf; + if (!reuse.isLeaf) { + copy(from.left, reuse.left); + copy(from.right, reuse.right); + } + return reuse; + } + + @Override + public int getLength() { + return -1; + } + + @Override + public void serialize(Node record, DataOutputView target) throws IOException { + SPLIT_SERIALIZER.serialize(record.split, target); + BooleanSerializer.INSTANCE.serialize(record.isLeaf, target); + if (!record.isLeaf) { + serialize(record.left, target); + serialize(record.right, target); + } + } + + @Override + public Node deserialize(DataInputView source) throws IOException { + Node node = new Node(); + node.split = SPLIT_SERIALIZER.deserialize(source); + node.isLeaf = BooleanSerializer.INSTANCE.deserialize(source); + if (!node.isLeaf) { + node.left = deserialize(source); + node.right = deserialize(source); + } + return node; + } + + @Override + public Node deserialize(Node reuse, DataInputView source) throws IOException { + reuse.split = SPLIT_SERIALIZER.deserialize(source); + reuse.isLeaf = BooleanSerializer.INSTANCE.deserialize(source); + if (!reuse.isLeaf) { + reuse.left = deserialize(source); + reuse.right = deserialize(source); + } + return reuse; + } + + @Override + public void copy(DataInputView source, DataOutputView target) throws IOException { + serialize(deserialize(source), target); + } + + // ------------------------------------------------------------------------ + + @Override + public TypeSerializerSnapshot snapshotConfiguration() { + return new NodeSerializerSnapshot(); + } + + /** Serializer 
configuration snapshot for compatibility and format evolution. */ + @SuppressWarnings("WeakerAccess") + public static final class NodeSerializerSnapshot extends SimpleTypeSerializerSnapshot { + + public NodeSerializerSnapshot() { + super(NodeSerializer::new); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/SplitSerializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/SplitSerializer.java new file mode 100644 index 000000000..c8d44df7f --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/SplitSerializer.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.typeinfo; + +import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.ml.common.gbt.defs.Split; + +import java.io.IOException; + +/** Specialized serializer for {@link Split}. 
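+ *
+ * <p>A one-byte tag distinguishes the concrete subtype on the wire ({@code 0} for {@link
+ * Split.CategoricalSplit}, {@code 1} for {@link Split.ContinuousSplit}); the record body is then
+ * delegated to the matching specialized serializer.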
*/ +public final class SplitSerializer extends TypeSerializerSingleton { + + private static final long serialVersionUID = 1L; + + private static final CategoricalSplitSerializer CATEGORICAL_SPLIT_SERIALIZER = + CategoricalSplitSerializer.INSTANCE; + + private static final ContinuousSplitSerializer CONTINUOUS_SPLIT_SERIALIZER = + ContinuousSplitSerializer.INSTANCE; + + public static final SplitSerializer INSTANCE = new SplitSerializer(); + + @Override + public boolean isImmutableType() { + return false; + } + + @Override + public Split createInstance() { + return CATEGORICAL_SPLIT_SERIALIZER.createInstance(); + } + + @Override + public Split copy(Split from) { + if (from instanceof Split.CategoricalSplit) { + return CATEGORICAL_SPLIT_SERIALIZER.copy((Split.CategoricalSplit) from); + } else { + return CONTINUOUS_SPLIT_SERIALIZER.copy((Split.ContinuousSplit) from); + } + } + + @Override + public Split copy(Split from, Split reuse) { + assert from.getClass() == reuse.getClass(); + if (from instanceof Split.CategoricalSplit) { + return CATEGORICAL_SPLIT_SERIALIZER.copy( + (Split.CategoricalSplit) from, (Split.CategoricalSplit) reuse); + } else { + return CONTINUOUS_SPLIT_SERIALIZER.copy( + (Split.ContinuousSplit) from, (Split.ContinuousSplit) reuse); + } + } + + @Override + public int getLength() { + return -1; + } + + @Override + public void serialize(Split record, DataOutputView target) throws IOException { + if (record instanceof Split.CategoricalSplit) { + target.writeByte(0); + CATEGORICAL_SPLIT_SERIALIZER.serialize((Split.CategoricalSplit) record, target); + } else { + target.writeByte(1); + CONTINUOUS_SPLIT_SERIALIZER.serialize((Split.ContinuousSplit) record, target); + } + } + + @Override + public Split deserialize(DataInputView source) throws IOException { + byte type = source.readByte(); + if (type == 0) { + return CATEGORICAL_SPLIT_SERIALIZER.deserialize(source); + } else { + return CONTINUOUS_SPLIT_SERIALIZER.deserialize(source); + } + } + + @Override + public Split deserialize(Split reuse, DataInputView source) throws IOException { + byte type = source.readByte(); + assert type == 0 && reuse instanceof Split.CategoricalSplit + || type == 1 && reuse instanceof Split.ContinuousSplit; + if (type == 0) { + return CATEGORICAL_SPLIT_SERIALIZER.deserialize((Split.CategoricalSplit) reuse, source); + } else { + return CONTINUOUS_SPLIT_SERIALIZER.deserialize((Split.ContinuousSplit) reuse, source); + } + } + + @Override + public void copy(DataInputView source, DataOutputView target) throws IOException { + serialize(deserialize(source), target); + } + + // ------------------------------------------------------------------------ + + @Override + public TypeSerializerSnapshot snapshotConfiguration() { + return new SplitSerializerSnapshot(); + } + + /** Serializer configuration snapshot for compatibility and format evolution. 
*/
+    @SuppressWarnings("WeakerAccess")
+    public static final class SplitSerializerSnapshot extends SimpleTypeSerializerSnapshot<Split> {
+
+        public SplitSerializerSnapshot() {
+            super(SplitSerializer::new);
+        }
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasFeatureSubsetStrategy.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasFeatureSubsetStrategy.java
new file mode 100644
index 000000000..77bee2c3a
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasFeatureSubsetStrategy.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringParam;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for shared param feature subset strategy. */
+public interface HasFeatureSubsetStrategy<T> extends WithParams<T> {
+    Param<String> FEATURE_SUBSET_STRATEGY =
+            new StringParam(
+                    "featureSubsetStrategy",
+                    "The number of features to consider for splits at each tree node. Supports \"auto\", \"all\", \"onethird\", \"sqrt\", \"log2\", (0.0 - 1.0], and [1 - n].",
+                    "auto",
+                    ParamValidators.notNull());
+
+    default String getFeatureSubsetStrategy() {
+        return get(FEATURE_SUBSET_STRATEGY);
+    }
+
+    default T setFeatureSubsetStrategy(String value) {
+        return set(FEATURE_SUBSET_STRATEGY, value);
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasLeafCol.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasLeafCol.java
new file mode 100644
index 000000000..52dd29e73
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasLeafCol.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.StringParam;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for shared param leaf column. */
+public interface HasLeafCol<T> extends WithParams<T> {
+    Param<String> LEAF_COL =
+            new StringParam("leafCol", "Predicted leaf index of each instance in each tree.", null);
+
+    default String getLeafCol() {
+        return get(LEAF_COL);
+    }
+
+    default T setLeafCol(String value) {
+        return set(LEAF_COL, value);
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasLossType.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasLossType.java
new file mode 100644
index 000000000..daed708ac
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasLossType.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringParam;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for the shared lossType param. */
+public interface HasLossType<T> extends WithParams<T> {
+
+    Param<String> LOSS_TYPE =
+            new StringParam(
+                    "lossType",
+                    "Loss type. Supports \"squared\", \"absolute\" and \"logistic\".",
+                    "squared",
+                    ParamValidators.inArray("squared", "absolute", "logistic"));
+
+    default String getLossType() {
+        return get(LOSS_TYPE);
+    }
+
+    default T setLossType(String value) {
+        set(LOSS_TYPE, value);
+        return (T) this;
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMaxBins.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMaxBins.java
new file mode 100644
index 000000000..45042c903
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMaxBins.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.flink.ml.common.param; + +import org.apache.flink.ml.param.IntParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.WithParams; + +/** Interface for the shared maxBins param. */ +public interface HasMaxBins extends WithParams { + Param MAX_BINS = + new IntParam( + "maxBins", + "Maximum number of bins used for discretizing continuous features.", + 32, + ParamValidators.gtEq(2)); + + default int getMaxBins() { + return get(MAX_BINS); + } + + default T setMaxBins(int value) { + return set(MAX_BINS, value); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMaxDepth.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMaxDepth.java new file mode 100644 index 000000000..68a746f4e --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMaxDepth.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.param; + +import org.apache.flink.ml.param.IntParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.WithParams; + +/** Interface for the shared maxDepth param. */ +public interface HasMaxDepth extends WithParams { + Param MAX_DEPTH = + new IntParam("maxDepth", "Maximum depth of the tree.", 5, ParamValidators.gtEq(1)); + + default int getMaxDepth() { + return get(MAX_DEPTH); + } + + default T setMaxDepth(int value) { + return set(MAX_DEPTH, value); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMinInfoGain.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMinInfoGain.java new file mode 100644 index 000000000..cbb5c4c08 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMinInfoGain.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.param; + +import org.apache.flink.ml.param.DoubleParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.WithParams; + +/** Interface for shared param minInfoGain. */ +public interface HasMinInfoGain extends WithParams { + Param MIN_INFO_GAIN = + new DoubleParam( + "minInfoGain", + "Minimum information gain for a split to be considered valid.", + 0., + ParamValidators.gtEq(0.)); + + default double getMinInfoGain() { + return get(MIN_INFO_GAIN); + } + + default T setMinInfoGain(Double value) { + return set(MIN_INFO_GAIN, value); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMinInstancesPerNode.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMinInstancesPerNode.java new file mode 100644 index 000000000..91cf8ab8d --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMinInstancesPerNode.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.param; + +import org.apache.flink.ml.param.IntParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.WithParams; + +/** Interface for the shared minInstancesPerNode param. */ +public interface HasMinInstancesPerNode extends WithParams { + Param MIN_INSTANCES_PER_NODE = + new IntParam( + "minInstancesPerNode", + "Minimum number of instances each node must have. If a split causes the left or right child to have fewer instances than minInstancesPerNode, the split is invalid.", + 1, + ParamValidators.gtEq(1)); + + default int getMinInstancesPerNode() { + return get(MIN_INSTANCES_PER_NODE); + } + + default T setMinInstancesPerNode(int value) { + return set(MIN_INSTANCES_PER_NODE, value); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMinWeightFractionPerNode.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMinWeightFractionPerNode.java new file mode 100644 index 000000000..c8fbaa3ae --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMinWeightFractionPerNode.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.param; + +import org.apache.flink.ml.param.DoubleParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.WithParams; + +/** Interface for shared param minWeightFractionPerNode. */ +public interface HasMinWeightFractionPerNode extends WithParams { + Param MIN_WEIGHT_FRACTION_PER_NODE = + new DoubleParam( + "minWeightFractionPerNode", + "Minimum fraction of the weighted sample count that each node must have. If a split causes the left or right child to have a smaller fraction of the total weight than minWeightFractionPerNode, the split is invalid.", + 0., + ParamValidators.gtEq(0.)); + + default double getMinWeightFractionPerNode() { + return get(MIN_WEIGHT_FRACTION_PER_NODE); + } + + default T setMinWeightFractionPerNode(Double value) { + return set(MIN_WEIGHT_FRACTION_PER_NODE, value); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasProbabilityCol.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasProbabilityCol.java new file mode 100644 index 000000000..3eba2b72f --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasProbabilityCol.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.param; + +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.StringParam; +import org.apache.flink.ml.param.WithParams; + +/** Interface for shared param probability column. 
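+ * For a binary classifier this column typically holds the predicted class-conditional
+ * probabilities, e.g. a two-element vector such as {@code [0.1, 0.9]}.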
*/ +public interface HasProbabilityCol extends WithParams { + Param PROBABILITY_COL = + new StringParam( + "probabilityCol", + "Column name for predicted class conditional probabilities.", + "probability", + ParamValidators.notNull()); + + default String getProbabilityCol() { + return get(PROBABILITY_COL); + } + + default T setProbabilityCol(String value) { + return set(PROBABILITY_COL, value); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasStepSize.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasStepSize.java new file mode 100644 index 000000000..f0faa2edd --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasStepSize.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.param; + +import org.apache.flink.ml.param.DoubleParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.WithParams; + +/** Interface for the shared step size param. */ +public interface HasStepSize extends WithParams { + Param STEP_SIZE = + new DoubleParam( + "stepSize", + "Step size for shrinking the contribution of each estimator.", + 0.1, + ParamValidators.inRange(0., 1.)); + + default double getStepSize() { + return get(STEP_SIZE); + } + + default T setStepSize(Double value) { + return set(STEP_SIZE, value); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasSubsamplingRate.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasSubsamplingRate.java new file mode 100644 index 000000000..3f04d6282 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasSubsamplingRate.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for shared param subsamplingRate. */
+public interface HasSubsamplingRate<T> extends WithParams<T> {
+    Param<Double> SUBSAMPLING_RATE =
+            new DoubleParam(
+                    "subsamplingRate",
+                    "Fraction of the training data used for learning one tree.",
+                    1.,
+                    ParamValidators.inRange(0., 1.));
+
+    default double getSubsamplingRate() {
+        return get(SUBSAMPLING_RATE);
+    }
+
+    default T setSubsamplingRate(Double value) {
+        return set(SUBSAMPLING_RATE, value);
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasValidationIndicatorCol.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasValidationIndicatorCol.java
new file mode 100644
index 000000000..5e076474e
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasValidationIndicatorCol.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.StringParam;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for the shared validation indicator column param. */
+public interface HasValidationIndicatorCol<T> extends WithParams<T> {
+    Param<String> VALIDATION_INDICATOR_COL =
+            new StringParam(
+                    "validationIndicatorCol",
+                    "The name of the column that indicates whether each row is for training or for validation.",
+                    null);
+
+    default String getValidationIndicatorCol() {
+        return get(VALIDATION_INDICATOR_COL);
+    }
+
+    default T setValidationIndicatorCol(String value) {
+        return set(VALIDATION_INDICATOR_COL, value);
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasValidationTol.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasValidationTol.java
new file mode 100644
index 000000000..d50d958d6
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasValidationTol.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.param; + +import org.apache.flink.ml.param.DoubleParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.WithParams; + +/** Interface for the shared tolerance param. */ +public interface HasValidationTol extends WithParams { + + Param VALIDATION_TOL = + new DoubleParam( + "validationTol", + "Threshold for early stopping when fitting with validation is used.", + .01, + ParamValidators.gtEq(0)); + + default double getValidationTol() { + return get(VALIDATION_TOL); + } + + default T setValidationTol(Double value) { + return set(VALIDATION_TOL, value); + } +} diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/GBTClassifierTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/GBTClassifierTest.java new file mode 100644 index 000000000..c7de6d290 --- /dev/null +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/GBTClassifierTest.java @@ -0,0 +1,500 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.classification; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.classification.gbtclassifier.GBTClassifier; +import org.apache.flink.ml.classification.gbtclassifier.GBTClassifierModel; +import org.apache.flink.ml.common.gbt.GBTModelData; +import org.apache.flink.ml.common.gbt.defs.TaskType; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.ml.util.TestUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.test.util.AbstractTestBase; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.apache.commons.lang3.ArrayUtils; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.Arrays; +import java.util.BitSet; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; + +import static org.apache.flink.table.api.Expressions.$; + +/** Tests {@link GBTClassifier} and {@link GBTClassifierModel}. 
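 *
+ * <p>The fixture rows below deliberately mix a null/NaN feature value, a categorical string
+ * column ({@code f2}) and a pre-assembled vector column ({@code vec}), so that both the
+ * {@code inputCols}-based and the {@code featuresCol}-based code paths are exercised.
+ *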
*/ +public class GBTClassifierTest extends AbstractTestBase { + private static final List inputDataRows = + Arrays.asList( + Row.of(1.2, 2, null, 40., 1., 0., Vectors.dense(1.2, 2, Double.NaN)), + Row.of(2.3, 3, "b", 40., 2., 0., Vectors.dense(2.3, 3, 2.)), + Row.of(3.4, 4, "c", 40., 3., 0., Vectors.dense(3.4, 4, 3.)), + Row.of(4.5, 5, "a", 40., 4., 0., Vectors.dense(4.5, 5, 1.)), + Row.of(5.6, 2, "b", 40., 5., 0., Vectors.dense(5.6, 2, 2.)), + Row.of(null, 3, "c", 41., 1., 1., Vectors.dense(Double.NaN, 3, 3.)), + Row.of(12.8, 4, "e", 41., 2., 1., Vectors.dense(12.8, 4, 5.)), + Row.of(13.9, 2, "b", 41., 3., 1., Vectors.dense(13.9, 2, 2.)), + Row.of(14.1, 4, "a", 41., 4., 1., Vectors.dense(14.1, 4, 1.)), + Row.of(15.3, 1, "d", 41., 5., 1., Vectors.dense(15.3, 1, 4.))); + + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + List outputRows = + Arrays.asList( + Row.of( + 0.0, + Vectors.dense(2.376078066514637, -2.376078066514637), + Vectors.dense(0.914984852695779, 0.08501514730422102)), + Row.of( + 1.0, + Vectors.dense(-2.5493892913102703, 2.5493892913102703), + Vectors.dense(0.07246752402942669, 0.9275324759705733)), + Row.of( + 1.0, + Vectors.dense(-2.658830586839206, 2.658830586839206), + Vectors.dense(0.06544682253255263, 0.9345531774674474)), + Row.of( + 0.0, + Vectors.dense(2.3309355512336296, -2.3309355512336296), + Vectors.dense(0.9114069063091061, 0.08859309369089385)), + Row.of( + 1.0, + Vectors.dense(-2.6577392865785714, 2.6577392865785714), + Vectors.dense(0.06551360197733425, 0.9344863980226658)), + Row.of( + 0.0, + Vectors.dense(2.5532653631402114, -2.5532653631402114), + Vectors.dense(0.9277925785910718, 0.07220742140892823)), + Row.of( + 0.0, + Vectors.dense(2.3773197509703996, -2.3773197509703996), + Vectors.dense(0.9150813905583675, 0.0849186094416325)), + Row.of( + 1.0, + Vectors.dense(-2.132645378098387, 2.132645378098387), + Vectors.dense(0.10596411850817689, 0.8940358814918231)), + Row.of( + 0.0, + Vectors.dense(2.3105035625447106, -2.3105035625447106), + Vectors.dense(0.9097432116019103, 0.09025678839808973)), + Row.of( + 1.0, + Vectors.dense(-2.0541952729346695, 2.0541952729346695), + Vectors.dense(0.11362915817869357, 0.8863708418213064))); + private StreamTableEnvironment tEnv; + private Table inputTable; + + private static void verifyPredictionResult(Table output, List expected) throws Exception { + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment(); + //noinspection unchecked + List results = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); + final double delta = 1e-3; + final Comparator denseVectorComparator = + new TestUtils.DenseVectorComparatorWithDelta(delta); + final Comparator comparator = + Comparator.comparing(d -> d.getFieldAs(0)) + .thenComparing(d -> d.getFieldAs(1), denseVectorComparator) + .thenComparing(d -> d.getFieldAs(2), denseVectorComparator); + TestUtils.compareResultCollectionsWithComparator(expected, results, comparator); + } + + @Before + public void before() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.getConfig().enableObjectReuse(); + env.setParallelism(4); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + + inputTable = + tEnv.fromDataStream( + env.fromCollection( + 
inputDataRows, + new RowTypeInfo( + new TypeInformation[] { + Types.DOUBLE, + Types.INT, + Types.STRING, + Types.DOUBLE, + Types.DOUBLE, + Types.DOUBLE, + VectorTypeInfo.INSTANCE + }, + new String[] { + "f0", "f1", "f2", "label", "weight", "cls_label", "vec" + }))); + } + + @Test + public void testParam() { + GBTClassifier gbtc = new GBTClassifier(); + Assert.assertEquals("features", gbtc.getFeaturesCol()); + Assert.assertNull(gbtc.getInputCols()); + Assert.assertEquals("label", gbtc.getLabelCol()); + Assert.assertArrayEquals(new String[] {}, gbtc.getCategoricalCols()); + Assert.assertEquals("prediction", gbtc.getPredictionCol()); + + Assert.assertNull(gbtc.getLeafCol()); + Assert.assertNull(gbtc.getWeightCol()); + Assert.assertEquals(5, gbtc.getMaxDepth()); + Assert.assertEquals(32, gbtc.getMaxBins()); + Assert.assertEquals(1, gbtc.getMinInstancesPerNode()); + Assert.assertEquals(0., gbtc.getMinWeightFractionPerNode(), 1e-12); + Assert.assertEquals(0., gbtc.getMinInfoGain(), 1e-12); + Assert.assertEquals(20, gbtc.getMaxIter()); + Assert.assertEquals(.1, gbtc.getStepSize(), 1e-12); + Assert.assertEquals(GBTClassifier.class.getName().hashCode(), gbtc.getSeed()); + Assert.assertEquals(1., gbtc.getSubsamplingRate(), 1e-12); + Assert.assertEquals("auto", gbtc.getFeatureSubsetStrategy()); + Assert.assertNull(gbtc.getValidationIndicatorCol()); + Assert.assertEquals(.01, gbtc.getValidationTol(), 1e-12); + Assert.assertEquals(0., gbtc.getRegLambda(), 1e-12); + Assert.assertEquals(1., gbtc.getRegGamma(), 1e-12); + + Assert.assertEquals("logistic", gbtc.getLossType()); + Assert.assertEquals("rawPrediction", gbtc.getRawPredictionCol()); + Assert.assertEquals("probability", gbtc.getProbabilityCol()); + + gbtc.setFeaturesCol("vec") + .setInputCols("f0", "f1", "f2") + .setLabelCol("cls_label") + .setCategoricalCols("f0", "f1") + .setPredictionCol("pred") + .setLeafCol("leaf") + .setWeightCol("weight") + .setMaxDepth(6) + .setMaxBins(64) + .setMinInstancesPerNode(2) + .setMinWeightFractionPerNode(.1) + .setMinInfoGain(.1) + .setMaxIter(10) + .setStepSize(.2) + .setSeed(123) + .setSubsamplingRate(.8) + .setFeatureSubsetStrategy("0.5") + .setValidationIndicatorCol("val") + .setValidationTol(.1) + .setRegLambda(.1) + .setRegGamma(.1) + .setRawPredictionCol("raw_pred") + .setProbabilityCol("prob"); + + Assert.assertEquals("vec", gbtc.getFeaturesCol()); + Assert.assertArrayEquals(new String[] {"f0", "f1", "f2"}, gbtc.getInputCols()); + Assert.assertEquals("cls_label", gbtc.getLabelCol()); + Assert.assertArrayEquals(new String[] {"f0", "f1"}, gbtc.getCategoricalCols()); + Assert.assertEquals("pred", gbtc.getPredictionCol()); + + Assert.assertEquals("leaf", gbtc.getLeafCol()); + Assert.assertEquals("weight", gbtc.getWeightCol()); + Assert.assertEquals(6, gbtc.getMaxDepth()); + Assert.assertEquals(64, gbtc.getMaxBins()); + Assert.assertEquals(2, gbtc.getMinInstancesPerNode()); + Assert.assertEquals(.1, gbtc.getMinWeightFractionPerNode(), 1e-12); + Assert.assertEquals(.1, gbtc.getMinInfoGain(), 1e-12); + Assert.assertEquals(10, gbtc.getMaxIter()); + Assert.assertEquals(.2, gbtc.getStepSize(), 1e-12); + Assert.assertEquals(123, gbtc.getSeed()); + Assert.assertEquals(.8, gbtc.getSubsamplingRate(), 1e-12); + Assert.assertEquals("0.5", gbtc.getFeatureSubsetStrategy()); + Assert.assertEquals("val", gbtc.getValidationIndicatorCol()); + Assert.assertEquals(.1, gbtc.getValidationTol(), 1e-12); + Assert.assertEquals(.1, gbtc.getRegLambda(), 1e-12); + Assert.assertEquals(.1, gbtc.getRegGamma(), 1e-12); + + 
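+        // The raw prediction and probability output columns are classification-specific
+        // parameters; they have no counterpart on GBTRegressor.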
Assert.assertEquals("raw_pred", gbtc.getRawPredictionCol()); + Assert.assertEquals("prob", gbtc.getProbabilityCol()); + } + + @Test + public void testOutputSchema() throws Exception { + GBTClassifier gbtc = + new GBTClassifier().setInputCols("f0", "f1", "f2").setCategoricalCols("f2"); + GBTClassifierModel model = gbtc.fit(inputTable); + Table output = model.transform(inputTable)[0]; + Assert.assertArrayEquals( + ArrayUtils.addAll( + inputTable.getResolvedSchema().getColumnNames().toArray(new String[0]), + gbtc.getPredictionCol(), + gbtc.getRawPredictionCol(), + gbtc.getProbabilityCol()), + output.getResolvedSchema().getColumnNames().toArray(new String[0])); + } + + @Test + public void testFitAndPredict() throws Exception { + GBTClassifier gbtc = + new GBTClassifier() + .setInputCols("f0", "f1", "f2") + .setCategoricalCols("f2") + .setLabelCol("cls_label") + .setRegGamma(0.) + .setMaxBins(3) + .setSeed(123); + GBTClassifierModel model = gbtc.fit(inputTable); + Table output = + model.transform(inputTable)[0].select( + $(gbtc.getPredictionCol()), + $(gbtc.getRawPredictionCol()), + $(gbtc.getProbabilityCol())); + verifyPredictionResult(output, outputRows); + } + + @Test + public void testFitAndPredictWithVectorCol() throws Exception { + GBTClassifier gbtc = + new GBTClassifier() + .setFeaturesCol("vec") + .setLabelCol("cls_label") + .setRegGamma(0.) + .setMaxBins(3) + .setSeed(123); + GBTClassifierModel model = gbtc.fit(inputTable); + Table output = + model.transform(inputTable)[0].select( + $(gbtc.getPredictionCol()), + $(gbtc.getRawPredictionCol()), + $(gbtc.getProbabilityCol())); + List outputRowsUsingVectorCol = + Arrays.asList( + Row.of( + 0.0, + Vectors.dense(1.9834935486026828, -1.9834935486026828), + Vectors.dense(0.8790530839977041, 0.12094691600229594)), + Row.of( + 1.0, + Vectors.dense(-1.9962334686995544, 1.9962334686995544), + Vectors.dense(0.11959895119804398, 0.880401048801956)), + Row.of( + 0.0, + Vectors.dense(2.2596958412285053, -2.2596958412285053), + Vectors.dense(0.9054836034255209, 0.0945163965744791)), + Row.of( + 1.0, + Vectors.dense(-2.23023965816558, 2.23023965816558), + Vectors.dense(0.09706763399626683, 0.9029323660037332)), + Row.of( + 1.0, + Vectors.dense(-2.520667396406638, 2.520667396406638), + Vectors.dense(0.0744219596185437, 0.9255780403814563)), + Row.of( + 0.0, + Vectors.dense(2.5005544570205114, -2.5005544570205114), + Vectors.dense(0.9241806803368346, 0.07581931966316532)), + Row.of( + 0.0, + Vectors.dense(2.155310746068554, -2.155310746068554), + Vectors.dense(0.8961640042377698, 0.10383599576223027)), + Row.of( + 1.0, + Vectors.dense(-2.2386996519306424, 2.2386996519306424), + Vectors.dense(0.09632867690962832, 0.9036713230903717)), + Row.of( + 0.0, + Vectors.dense(2.0375281995821273, -2.0375281995821273), + Vectors.dense(0.8846813338862343, 0.11531866611376576)), + Row.of( + 1.0, + Vectors.dense(-1.9751553623558855, 1.9751553623558855), + Vectors.dense(0.12183622723878906, 0.8781637727612109))); + verifyPredictionResult(output, outputRowsUsingVectorCol); + } + + @Test + public void testFitAndPredictWithNoCategoricalCols() throws Exception { + GBTClassifier gbtc = + new GBTClassifier() + .setInputCols("f0", "f1") + .setLabelCol("cls_label") + .setRegGamma(0.) 
+ .setMaxBins(5) + .setSeed(123); + GBTClassifierModel model = gbtc.fit(inputTable); + Table output = + model.transform(inputTable)[0].select( + $(gbtc.getPredictionCol()), + $(gbtc.getRawPredictionCol()), + $(gbtc.getProbabilityCol())); + List outputRowsUsingNoCategoricalCols = + Arrays.asList( + Row.of( + 0.0, + Vectors.dense(2.4386858360079877, -2.4386858360079877), + Vectors.dense(0.9197301210345855, 0.08026987896541447)), + Row.of( + 0.0, + Vectors.dense(2.079593609142336, -2.079593609142336), + Vectors.dense(0.8889039070093702, 0.11109609299062985)), + Row.of( + 1.0, + Vectors.dense(-2.4477766607449594, 2.4477766607449594), + Vectors.dense(0.07960128978764613, 0.9203987102123539)), + Row.of( + 0.0, + Vectors.dense(2.3680506847981113, -2.3680506847981113), + Vectors.dense(0.9143583384561507, 0.0856416615438493)), + Row.of( + 1.0, + Vectors.dense(-2.0115161495245792, 2.0115161495245792), + Vectors.dense(0.11799909267017583, 0.8820009073298242)), + Row.of( + 0.0, + Vectors.dense(2.3680506847981113, -2.3680506847981113), + Vectors.dense(0.9143583384561507, 0.0856416615438493)), + Row.of( + 1.0, + Vectors.dense(-2.1774376078697983, 2.1774376078697983), + Vectors.dense(0.10179497553813543, 0.8982050244618646)), + Row.of( + 0.0, + Vectors.dense(2.434832949283468, -2.434832949283468), + Vectors.dense(0.9194452150195366, 0.08055478498046341)), + Row.of( + 1.0, + Vectors.dense(-2.441225164856452, 2.441225164856452), + Vectors.dense(0.08008260858505134, 0.9199173914149487)), + Row.of( + 1.0, + Vectors.dense(-2.672457199454413, 2.672457199454413), + Vectors.dense(0.06461828968951666, 0.9353817103104833))); + verifyPredictionResult(output, outputRowsUsingNoCategoricalCols); + } + + @Test + public void testEstimatorSaveLoadAndPredict() throws Exception { + GBTClassifier gbtc = + new GBTClassifier() + .setInputCols("f0", "f1", "f2") + .setCategoricalCols("f2") + .setLabelCol("cls_label") + .setRegGamma(0.) + .setMaxBins(3) + .setSeed(123); + GBTClassifier loadedGbtc = + TestUtils.saveAndReload(tEnv, gbtc, tempFolder.newFolder().getAbsolutePath()); + GBTClassifierModel model = loadedGbtc.fit(inputTable); + Assert.assertEquals( + Collections.singletonList("modelData"), + model.getModelData()[0].getResolvedSchema().getColumnNames()); + Table output = + model.transform(inputTable)[0].select( + $(gbtc.getPredictionCol()), + $(gbtc.getRawPredictionCol()), + $(gbtc.getProbabilityCol())); + verifyPredictionResult(output, outputRows); + } + + @Test + public void testModelSaveLoadAndPredict() throws Exception { + GBTClassifier gbtc = + new GBTClassifier() + .setInputCols("f0", "f1", "f2") + .setCategoricalCols("f2") + .setLabelCol("cls_label") + .setRegGamma(0.) + .setMaxBins(3) + .setSeed(123); + GBTClassifierModel model = gbtc.fit(inputTable); + GBTClassifierModel loadedModel = + TestUtils.saveAndReload(tEnv, model, tempFolder.newFolder().getAbsolutePath()); + Table output = + loadedModel.transform(inputTable)[0].select( + $(gbtc.getPredictionCol()), + $(gbtc.getRawPredictionCol()), + $(gbtc.getProbabilityCol())); + verifyPredictionResult(output, outputRows); + } + + @Test + public void testGetModelData() throws Exception { + GBTClassifier gbtc = + new GBTClassifier() + .setInputCols("f0", "f1", "f2") + .setCategoricalCols("f2") + .setLabelCol("cls_label") + .setRegGamma(0.) 
+                        .setMaxBins(3)
+                        .setSeed(123);
+        GBTClassifierModel model = gbtc.fit(inputTable);
+        Table modelDataTable = model.getModelData()[0];
+        List<String> modelDataColumnNames = modelDataTable.getResolvedSchema().getColumnNames();
+        DataStream<Row> output = tEnv.toDataStream(modelDataTable);
+        Assert.assertArrayEquals(
+                new String[] {"modelData"}, modelDataColumnNames.toArray(new String[0]));
+
+        Row modelDataRow = (Row) IteratorUtils.toList(output.executeAndCollect()).get(0);
+        GBTModelData modelData = modelDataRow.getFieldAs(0);
+        Assert.assertNotNull(modelData);
+
+        Assert.assertEquals(TaskType.CLASSIFICATION, TaskType.valueOf(modelData.type));
+        Assert.assertFalse(modelData.isInputVector);
+        Assert.assertEquals(0., modelData.prior, 1e-12);
+        Assert.assertEquals(gbtc.getStepSize(), modelData.stepSize, 1e-12);
+        Assert.assertEquals(gbtc.getMaxIter(), modelData.roots.size());
+        Assert.assertEquals(gbtc.getCategoricalCols().length, modelData.categoryToIdMaps.size());
+        Assert.assertEquals(
+                gbtc.getInputCols().length - gbtc.getCategoricalCols().length,
+                modelData.featureIdToBinEdges.size());
+        Assert.assertEquals(BitSet.valueOf(new byte[] {4}), modelData.isCategorical);
+    }
+
+    @Test
+    public void testSetModelData() throws Exception {
+        GBTClassifier gbtc =
+                new GBTClassifier()
+                        .setInputCols("f0", "f1", "f2")
+                        .setCategoricalCols("f2")
+                        .setLabelCol("cls_label")
+                        .setRegGamma(0.)
+                        .setMaxBins(3)
+                        .setSeed(123);
+        GBTClassifierModel modelA = gbtc.fit(inputTable);
+        Table modelDataTable = modelA.getModelData()[0];
+        GBTClassifierModel modelB = new GBTClassifierModel().setModelData(modelDataTable);
+        ReadWriteUtils.updateExistingParams(modelB, modelA.getParamMap());
+        Table output =
+                modelB.transform(inputTable)[0].select(
+                        $(gbtc.getPredictionCol()),
+                        $(gbtc.getRawPredictionCol()),
+                        $(gbtc.getProbabilityCol()));
+        verifyPredictionResult(output, outputRows);
+    }
+}

From 2c43fa65850d5c2724285aa50cf62a65fe4689e7 Mon Sep 17 00:00:00 2001
From: Fan Hong
Date: Tue, 7 Feb 2023 12:02:49 +0800
Subject: [PATCH 05/47] Add GBTRegressor

---
 .../flink/ml/common/gbt/BaseGBTModel.java     |   3 +-
 .../flink/ml/common/gbt/GBTModelData.java     |   3 +-
 .../flink/ml/common/gbt/GBTModelParams.java   |   3 +-
 .../apache/flink/ml/common/gbt/GBTRunner.java |   4 +-
 .../regression/gbtregressor/GBTRegressor.java |  74 ++++
 .../gbtregressor/GBTRegressorModel.java       | 114 ++++++
 .../gbtregressor/GBTRegressorParams.java      |  46 +++
 .../flink/ml/regression/GBTRegressorTest.java | 374 ++++++++++++++++++
 8 files changed, 616 insertions(+), 5 deletions(-)
 create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressor.java
 create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressorModel.java
 create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressorParams.java
 create mode 100644 flink-ml-lib/src/test/java/org/apache/flink/ml/regression/GBTRegressorTest.java

diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BaseGBTModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BaseGBTModel.java
index 5ba8b7a9d..9b4e81001 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BaseGBTModel.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BaseGBTModel.java
@@ -21,6 +21,7 @@
 import org.apache.flink.ml.api.Model;
 import org.apache.flink.ml.classification.gbtclassifier.GBTClassifier;
 import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.regression.gbtregressor.GBTRegressor;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.ReadWriteUtils;
 import org.apache.flink.table.api.Table;
@@ -29,7 +30,7 @@
 import java.util.HashMap;
 import java.util.Map;
 
-/** Base model computed by {@link GBTClassifier}. */
+/** Base model computed by {@link GBTClassifier} or {@link GBTRegressor}. */
 public abstract class BaseGBTModel<T extends BaseGBTModel<T>>
         implements Model<T>, GBTModelParams<T> {
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelData.java
index 542caf7e2..bc1f12fc6 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelData.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelData.java
@@ -37,6 +37,7 @@
 import org.apache.flink.ml.feature.stringindexer.StringIndexerModel;
 import org.apache.flink.ml.linalg.SparseVector;
 import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.regression.gbtregressor.GBTRegressorModel;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.table.api.Table;
 import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
@@ -53,7 +54,7 @@
 import java.util.List;
 
 /**
- * Model data of gradient boosting trees.
+ * Model data of {@link GBTClassifierModel}.
  *
  * 
<p>
This class also provides methods to convert model data from Table to Datastream, and classes * to save/load model data. diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelParams.java index c0997c4cf..50c078303 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelParams.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelParams.java @@ -25,9 +25,10 @@ import org.apache.flink.ml.common.param.HasPredictionCol; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.param.StringArrayParam; +import org.apache.flink.ml.regression.gbtregressor.GBTRegressorModel; /** - * Params of {@link GBTClassifierModel}. + * Params of {@link GBTClassifierModel} and {@link GBTRegressorModel}. * *
<p>
If the input features come from 1 column of vector type, `featuresCol` should be used, and all * features are treated as continuous features. Otherwise, `inputCols` should be used for multiple diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTRunner.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTRunner.java index c748325ee..f1c73473b 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTRunner.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTRunner.java @@ -34,6 +34,7 @@ import org.apache.flink.ml.common.gbt.defs.LocalState; import org.apache.flink.ml.common.gbt.defs.TaskType; import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.regression.gbtregressor.GBTRegressorParams; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.table.api.Table; import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; @@ -173,8 +174,7 @@ public static GbtParams fromEstimator(BaseGBTParams estimator, TaskType taskT if (TaskType.CLASSIFICATION.equals(p.taskType)) { p.lossType = estimator.get(GBTClassifierParams.LOSS_TYPE); } else { - // TODO: add GBTRegressorParams.LOSS_TYPE in next PR. - p.lossType = estimator.get(GBTClassifierParams.LOSS_TYPE); + p.lossType = estimator.get(GBTRegressorParams.LOSS_TYPE); } p.maxNumLeaves = 1 << p.maxDepth - 1; p.useMissing = true; diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressor.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressor.java new file mode 100644 index 000000000..d9435bdaf --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressor.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.regression.gbtregressor; + +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.gbt.GBTModelData; +import org.apache.flink.ml.common.gbt.GBTRunner; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.apache.flink.table.api.Expressions.$; + +/** An Estimator which implements the gradient boosting trees regression algorithm. 
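 *
+ * <p>A minimal usage sketch (the {@code trainTable} is assumed to contain feature columns
+ * {@code f0} and {@code f1} plus a {@code label} column; see {@code GBTRegressorTest} for
+ * complete examples):
+ *
+ * <pre>{@code
+ * GBTRegressor gbtr =
+ *         new GBTRegressor().setInputCols("f0", "f1").setLabelCol("label").setMaxIter(10);
+ * GBTRegressorModel model = gbtr.fit(trainTable);
+ * Table predictions = model.transform(trainTable)[0];
+ * }</pre>
+ *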
*/ +public class GBTRegressor + implements Estimator, GBTRegressorParams { + + private final Map, Object> paramMap = new HashMap<>(); + + public GBTRegressor() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + public static GBTRegressor load(StreamTableEnvironment tEnv, String path) throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + + @Override + public Map, Object> getParamMap() { + return paramMap; + } + + @Override + public GBTRegressorModel fit(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream modelData = GBTRunner.trainRegressor(inputs[0], this); + GBTRegressorModel model = new GBTRegressorModel(); + model.setModelData(tEnv.fromDataStream(modelData).renameColumns($("f0").as("modelData"))); + ReadWriteUtils.updateExistingParams(model, getParamMap()); + return model; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressorModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressorModel.java new file mode 100644 index 000000000..acd6e65c1 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressorModel.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.regression.gbtregressor; + +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.common.gbt.BaseGBTModel; +import org.apache.flink.ml.common.gbt.GBTModelData; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; +import org.eclipse.collections.impl.map.mutable.primitive.IntDoubleHashMap; + +import java.io.IOException; +import java.util.Collections; + +/** A Model computed by {@link GBTRegressor}. */ +public class GBTRegressorModel extends BaseGBTModel { + + /** + * Loads model data from path. + * + * @param tEnv A StreamTableEnvironment instance. + * @param path Model path. + * @return GBT regression model. 
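+     * @throws IOException if the model metadata or model data cannot be read from the given
+     *     path.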
+ */ + public static GBTRegressorModel load(StreamTableEnvironment tEnv, String path) + throws IOException { + GBTRegressorModel model = ReadWriteUtils.loadStageParam(path); + Table modelDataTable = + ReadWriteUtils.loadModelData(tEnv, path, new GBTModelData.ModelDataDecoder()); + return model.setModelData(modelDataTable); + } + + @Override + public Table[] transform(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream inputStream = tEnv.toDataStream(inputs[0]); + final String broadcastModelKey = "broadcastModelKey"; + DataStream modelDataStream = GBTModelData.getModelDataStream(modelDataTable); + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), Types.DOUBLE), + ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol())); + DataStream predictionResult = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(inputStream), + Collections.singletonMap(broadcastModelKey, modelDataStream), + inputList -> { + //noinspection unchecked + DataStream inputData = (DataStream) inputList.get(0); + return inputData.map( + new PredictLabelFunction( + broadcastModelKey, getInputCols(), getFeaturesCol()), + outputTypeInfo); + }); + return new Table[] {tEnv.fromDataStream(predictionResult)}; + } + + private static class PredictLabelFunction extends RichMapFunction { + + private final String broadcastModelKey; + private final String[] inputCols; + private final String featuresCol; + private GBTModelData modelData; + + public PredictLabelFunction( + String broadcastModelKey, String[] inputCols, String featuresCol) { + this.broadcastModelKey = broadcastModelKey; + this.inputCols = inputCols; + this.featuresCol = featuresCol; + } + + @Override + public Row map(Row value) throws Exception { + if (null == modelData) { + modelData = + (GBTModelData) + getRuntimeContext().getBroadcastVariable(broadcastModelKey).get(0); + } + IntDoubleHashMap features = modelData.rowToFeatures(value, inputCols, featuresCol); + double pred = modelData.predictRaw(features); + return Row.join(value, Row.of(pred)); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressorParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressorParams.java new file mode 100644 index 000000000..0f9ed5d27 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressorParams.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.regression.gbtregressor; + +import org.apache.flink.ml.common.gbt.BaseGBTParams; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.StringParam; + +/** + * Parameters for {@link GBTRegressor}. + * + * @param The class type of this instance. + */ +public interface GBTRegressorParams extends BaseGBTParams { + Param LOSS_TYPE = + new StringParam( + "lossType", + "Loss type.", + "squared", + ParamValidators.inArray("squared", "absolute")); + + default String getLossType() { + return get(LOSS_TYPE); + } + + default T setLossType(String value) { + return set(LOSS_TYPE, value); + } +} diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/GBTRegressorTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/GBTRegressorTest.java new file mode 100644 index 000000000..af0810e6e --- /dev/null +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/GBTRegressorTest.java @@ -0,0 +1,374 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.regression; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.common.gbt.GBTModelData; +import org.apache.flink.ml.common.gbt.defs.TaskType; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; +import org.apache.flink.ml.regression.gbtregressor.GBTRegressor; +import org.apache.flink.ml.regression.gbtregressor.GBTRegressorModel; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.ml.util.TestUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.test.util.AbstractTestBase; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.apache.commons.lang3.ArrayUtils; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.Arrays; +import java.util.BitSet; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; + +import static org.apache.flink.table.api.Expressions.$; + +/** Tests {@link GBTRegressor} and {@link GBTRegressorModel}. */ +public class GBTRegressorTest extends AbstractTestBase { + private static final List inputDataRows = + Arrays.asList( + Row.of(1.2, 2, null, 40., 1., 0., Vectors.dense(1.2, 2, Double.NaN)), + Row.of(2.3, 3, "b", 40., 2., 0., Vectors.dense(2.3, 3, 2.)), + Row.of(3.4, 4, "c", 40., 3., 0., Vectors.dense(3.4, 4, 3.)), + Row.of(4.5, 5, "a", 40., 4., 0., Vectors.dense(4.5, 5, 1.)), + Row.of(5.6, 2, "b", 40., 5., 0., Vectors.dense(5.6, 2, 2.)), + Row.of(null, 3, "c", 41., 1., 1., Vectors.dense(Double.NaN, 3, 3.)), + Row.of(12.8, 4, "e", 41., 2., 1., Vectors.dense(12.8, 4, 5.)), + Row.of(13.9, 2, "b", 41., 3., 1., Vectors.dense(13.9, 2, 2.)), + Row.of(14.1, 4, "a", 41., 4., 1., Vectors.dense(14.1, 4, 1.)), + Row.of(15.3, 1, "d", 41., 5., 1., Vectors.dense(15.3, 1, 4.))); + + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + List outputRows = + Arrays.asList( + Row.of(40.06841194119824), + Row.of(40.94100994144195), + Row.of(40.93898887207972), + Row.of(40.14918141164082), + Row.of(40.90620397010659), + Row.of(40.06041865505043), + Row.of(40.1049148535624), + Row.of(40.88096567879293), + Row.of(40.08071914298763), + Row.of(40.86772065751431)); + + private StreamTableEnvironment tEnv; + private Table inputTable; + + private static void verifyPredictionResult(Table output, List expected) throws Exception { + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment(); + //noinspection unchecked + List results = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); + final double delta = 1e-9; + final Comparator comparator = + Comparator.comparing( + d -> d.getFieldAs(0), new TestUtils.DoubleComparatorWithDelta(delta)); + TestUtils.compareResultCollectionsWithComparator(expected, 
results, comparator); + } + + @Before + public void before() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.getConfig().enableObjectReuse(); + env.setParallelism(4); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + + inputTable = + tEnv.fromDataStream( + env.fromCollection( + inputDataRows, + new RowTypeInfo( + new TypeInformation[] { + Types.DOUBLE, + Types.INT, + Types.STRING, + Types.DOUBLE, + Types.DOUBLE, + Types.DOUBLE, + VectorTypeInfo.INSTANCE + }, + new String[] { + "f0", "f1", "f2", "label", "weight", "cls_label", "vec" + }))); + } + + @Test + public void testParam() { + GBTRegressor gbtr = new GBTRegressor(); + Assert.assertEquals("features", gbtr.getFeaturesCol()); + Assert.assertNull(gbtr.getInputCols()); + Assert.assertEquals("label", gbtr.getLabelCol()); + Assert.assertArrayEquals(new String[] {}, gbtr.getCategoricalCols()); + Assert.assertEquals("prediction", gbtr.getPredictionCol()); + + Assert.assertNull(gbtr.getLeafCol()); + Assert.assertNull(gbtr.getWeightCol()); + Assert.assertEquals(5, gbtr.getMaxDepth()); + Assert.assertEquals(32, gbtr.getMaxBins()); + Assert.assertEquals(1, gbtr.getMinInstancesPerNode()); + Assert.assertEquals(0., gbtr.getMinWeightFractionPerNode(), 1e-12); + Assert.assertEquals(0., gbtr.getMinInfoGain(), 1e-12); + Assert.assertEquals(20, gbtr.getMaxIter()); + Assert.assertEquals(.1, gbtr.getStepSize(), 1e-12); + Assert.assertEquals(GBTRegressor.class.getName().hashCode(), gbtr.getSeed()); + Assert.assertEquals(1., gbtr.getSubsamplingRate(), 1e-12); + Assert.assertEquals("auto", gbtr.getFeatureSubsetStrategy()); + Assert.assertNull(gbtr.getValidationIndicatorCol()); + Assert.assertEquals(.01, gbtr.getValidationTol(), 1e-12); + Assert.assertEquals(0., gbtr.getRegLambda(), 1e-12); + Assert.assertEquals(1., gbtr.getRegGamma(), 1e-12); + + Assert.assertEquals("squared", gbtr.getLossType()); + + gbtr.setFeaturesCol("vec") + .setInputCols("f0", "f1", "f2") + .setLabelCol("label") + .setCategoricalCols("f0", "f1") + .setPredictionCol("pred") + .setLeafCol("leaf") + .setWeightCol("weight") + .setMaxDepth(6) + .setMaxBins(64) + .setMinInstancesPerNode(2) + .setMinWeightFractionPerNode(.1) + .setMinInfoGain(.1) + .setMaxIter(10) + .setStepSize(.2) + .setSeed(123) + .setSubsamplingRate(.8) + .setFeatureSubsetStrategy("0.5") + .setValidationIndicatorCol("val") + .setValidationTol(.1) + .setRegLambda(.1) + .setRegGamma(.1); + + Assert.assertEquals("vec", gbtr.getFeaturesCol()); + Assert.assertArrayEquals(new String[] {"f0", "f1", "f2"}, gbtr.getInputCols()); + Assert.assertEquals("label", gbtr.getLabelCol()); + Assert.assertArrayEquals(new String[] {"f0", "f1"}, gbtr.getCategoricalCols()); + Assert.assertEquals("pred", gbtr.getPredictionCol()); + + Assert.assertEquals("leaf", gbtr.getLeafCol()); + Assert.assertEquals("weight", gbtr.getWeightCol()); + Assert.assertEquals(6, gbtr.getMaxDepth()); + Assert.assertEquals(64, gbtr.getMaxBins()); + Assert.assertEquals(2, gbtr.getMinInstancesPerNode()); + Assert.assertEquals(.1, gbtr.getMinWeightFractionPerNode(), 1e-12); + Assert.assertEquals(.1, gbtr.getMinInfoGain(), 1e-12); + Assert.assertEquals(10, gbtr.getMaxIter()); + Assert.assertEquals(.2, gbtr.getStepSize(), 1e-12); + Assert.assertEquals(123, gbtr.getSeed()); + Assert.assertEquals(.8, 
gbtr.getSubsamplingRate(), 1e-12); + Assert.assertEquals("0.5", gbtr.getFeatureSubsetStrategy()); + Assert.assertEquals("val", gbtr.getValidationIndicatorCol()); + Assert.assertEquals(.1, gbtr.getValidationTol(), 1e-12); + Assert.assertEquals(.1, gbtr.getRegLambda(), 1e-12); + Assert.assertEquals(.1, gbtr.getRegGamma(), 1e-12); + } + + @Test + public void testOutputSchema() throws Exception { + GBTRegressor gbtr = + new GBTRegressor().setInputCols("f0", "f1", "f2").setCategoricalCols("f2"); + GBTRegressorModel model = gbtr.fit(inputTable); + Table output = model.transform(inputTable)[0]; + Assert.assertArrayEquals( + ArrayUtils.addAll( + inputTable.getResolvedSchema().getColumnNames().toArray(new String[0]), + gbtr.getPredictionCol()), + output.getResolvedSchema().getColumnNames().toArray(new String[0])); + } + + @Test + public void testFitAndPredict() throws Exception { + GBTRegressor gbtr = + new GBTRegressor() + .setInputCols("f0", "f1", "f2") + .setCategoricalCols("f2") + .setLabelCol("label") + .setRegGamma(0.) + .setMaxBins(3) + .setSeed(123); + GBTRegressorModel model = gbtr.fit(inputTable); + Table output = model.transform(inputTable)[0].select($(gbtr.getPredictionCol())); + verifyPredictionResult(output, outputRows); + } + + @Test + public void testFitAndPredictWithVectorCol() throws Exception { + GBTRegressor gbtr = + new GBTRegressor() + .setFeaturesCol("vec") + .setLabelCol("label") + .setRegGamma(0.) + .setMaxBins(3) + .setSeed(123); + GBTRegressorModel model = gbtr.fit(inputTable); + Table output = model.transform(inputTable)[0].select($(gbtr.getPredictionCol())); + List outputRowsUsingVectorCol = + Arrays.asList( + Row.of(40.11011764668384), + Row.of(40.8838231947867), + Row.of(40.064839102170275), + Row.of(40.10374937485196), + Row.of(40.909914467915144), + Row.of(40.11472131282394), + Row.of(40.88106076252836), + Row.of(40.089859516616336), + Row.of(40.90833852360301), + Row.of(40.94920075468803)); + verifyPredictionResult(output, outputRowsUsingVectorCol); + } + + @Test + public void testFitAndPredictWithNoCategoricalCols() throws Exception { + GBTRegressor gbtr = + new GBTRegressor() + .setInputCols("f0", "f1") + .setLabelCol("label") + .setRegGamma(0.) + .setMaxBins(5) + .setSeed(123); + GBTRegressorModel model = gbtr.fit(inputTable); + Table output = model.transform(inputTable)[0].select($(gbtr.getPredictionCol())); + List outputRowsUsingNoCategoricalCols = + Arrays.asList( + Row.of(40.07663214615239), + Row.of(40.92462268161843), + Row.of(40.941626445241624), + Row.of(40.06608854749729), + Row.of(40.12272436518743), + Row.of(40.92737873124178), + Row.of(40.08092204935494), + Row.of(40.898529570430696), + Row.of(40.08092204935494), + Row.of(40.88296818645738)); + verifyPredictionResult(output, outputRowsUsingNoCategoricalCols); + } + + @Test + public void testEstimatorSaveLoadAndPredict() throws Exception { + GBTRegressor gbtr = + new GBTRegressor() + .setInputCols("f0", "f1", "f2") + .setCategoricalCols("f2") + .setLabelCol("label") + .setRegGamma(0.) 
+                        .setMaxBins(3)
+                        .setSeed(123);
+        GBTRegressor loadedGbtr =
+                TestUtils.saveAndReload(tEnv, gbtr, tempFolder.newFolder().getAbsolutePath());
+        GBTRegressorModel model = loadedGbtr.fit(inputTable);
+        Assert.assertEquals(
+                Collections.singletonList("modelData"),
+                model.getModelData()[0].getResolvedSchema().getColumnNames());
+        Table output = model.transform(inputTable)[0].select($(gbtr.getPredictionCol()));
+        verifyPredictionResult(output, outputRows);
+    }
+
+    @Test
+    public void testModelSaveLoadAndPredict() throws Exception {
+        GBTRegressor gbtr =
+                new GBTRegressor()
+                        .setInputCols("f0", "f1", "f2")
+                        .setCategoricalCols("f2")
+                        .setLabelCol("label")
+                        .setRegGamma(0.)
+                        .setMaxBins(3)
+                        .setSeed(123);
+        GBTRegressorModel model = gbtr.fit(inputTable);
+        GBTRegressorModel loadedModel =
+                TestUtils.saveAndReload(tEnv, model, tempFolder.newFolder().getAbsolutePath());
+        Table output = loadedModel.transform(inputTable)[0].select($(gbtr.getPredictionCol()));
+        verifyPredictionResult(output, outputRows);
+    }
+
+    @Test
+    public void testGetModelData() throws Exception {
+        GBTRegressor gbtr =
+                new GBTRegressor()
+                        .setInputCols("f0", "f1", "f2")
+                        .setCategoricalCols("f2")
+                        .setLabelCol("label")
+                        .setRegGamma(0.)
+                        .setMaxBins(3)
+                        .setSeed(123);
+        GBTRegressorModel model = gbtr.fit(inputTable);
+        Table modelDataTable = model.getModelData()[0];
+        List<String> modelDataColumnNames = modelDataTable.getResolvedSchema().getColumnNames();
+        DataStream<Row> output = tEnv.toDataStream(modelDataTable);
+        Assert.assertArrayEquals(
+                new String[] {"modelData"}, modelDataColumnNames.toArray(new String[0]));
+
+        Row modelDataRow = (Row) IteratorUtils.toList(output.executeAndCollect()).get(0);
+        GBTModelData modelData = modelDataRow.getFieldAs(0);
+        Assert.assertNotNull(modelData);
+
+        Assert.assertEquals(TaskType.REGRESSION, TaskType.valueOf(modelData.type));
+        Assert.assertFalse(modelData.isInputVector);
+        Assert.assertEquals(40.5, modelData.prior, .5);
+        Assert.assertEquals(gbtr.getStepSize(), modelData.stepSize, 1e-12);
+        Assert.assertEquals(gbtr.getMaxIter(), modelData.roots.size());
+        Assert.assertEquals(gbtr.getCategoricalCols().length, modelData.categoryToIdMaps.size());
+        Assert.assertEquals(
+                gbtr.getInputCols().length - gbtr.getCategoricalCols().length,
+                modelData.featureIdToBinEdges.size());
+        Assert.assertEquals(BitSet.valueOf(new byte[] {4}), modelData.isCategorical);
+    }
+
+    @Test
+    public void testSetModelData() throws Exception {
+        GBTRegressor gbtr =
+                new GBTRegressor()
+                        .setInputCols("f0", "f1", "f2")
+                        .setCategoricalCols("f2")
+                        .setLabelCol("label")
+                        .setRegGamma(0.)
+                        .setMaxBins(3)
+                        .setSeed(123);
+        GBTRegressorModel modelA = gbtr.fit(inputTable);
+        Table modelDataTable = modelA.getModelData()[0];
+        GBTRegressorModel modelB = new GBTRegressorModel().setModelData(modelDataTable);
+        ReadWriteUtils.updateExistingParams(modelB, modelA.getParamMap());
+        Table output = modelB.transform(inputTable)[0].select($(gbtr.getPredictionCol()));
+        verifyPredictionResult(output, outputRows);
+    }
+}

From 2fe1f723011f08a266b915f8411c7a532f4dc265 Mon Sep 17 00:00:00 2001
From: Fan Hong
Date: Mon, 13 Feb 2023 19:23:51 +0800
Subject: [PATCH 06/47] Fix some missing Javadoc comments.
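
The added Javadoc documents the comparator semantics used throughout the GBT tests:
two values whose difference is at most the given delta compare as equal, and the
default comparison is used otherwise. A quick sketch (the delta here is illustrative):

    Comparator<Double> cmp = new TestUtils.DoubleComparatorWithDelta(1e-3);
    cmp.compare(1.0000, 1.0004); // 0: difference within delta, treated as equal
    cmp.compare(1.0, 2.0);       // negative: falls back to the default comparison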
--- .../test/java/org/apache/flink/ml/util/TestUtils.java | 9 +++++++++ .../org/apache/flink/ml/common/gbt/GBTModelData.java | 2 +- .../org/apache/flink/ml/common/gbt/GBTRunnerTest.java | 1 + 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/util/TestUtils.java b/flink-ml-core/src/test/java/org/apache/flink/ml/util/TestUtils.java index 59d1b119a..33fc59cd4 100644 --- a/flink-ml-core/src/test/java/org/apache/flink/ml/util/TestUtils.java +++ b/flink-ml-core/src/test/java/org/apache/flink/ml/util/TestUtils.java @@ -341,6 +341,11 @@ public static void compareResultCollectionsWithComparator( } } + /** + * Compare two doubles with specified delta. If the differences between the two doubles are + * equal or less than delta, they are considered equal. Otherwise, they are compared with + * default comparison. + */ public static class DoubleComparatorWithDelta implements Comparator { private final double delta; @@ -354,6 +359,10 @@ public int compare(Double o1, Double o2) { } } + /** + * Compare two dense vectors with specified delta. When comparing their values, {@link + * DoubleComparatorWithDelta} is used. + */ public static class DenseVectorComparatorWithDelta implements Comparator { private final DoubleComparatorWithDelta doubleComparatorWithDelta; diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelData.java index bc1f12fc6..3904819ba 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelData.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelData.java @@ -54,7 +54,7 @@ import java.util.List; /** - * Model data of {@link GBTClassifierModel}. + * Model data of {@link GBTClassifierModel} and {@link GBTRegressorModel}. * *
<p>
This class also provides methods to convert model data from Table to Datastream, and classes * to save/load model data. diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/GBTRunnerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/GBTRunnerTest.java index 16a770189..2811d83cd 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/GBTRunnerTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/GBTRunnerTest.java @@ -43,6 +43,7 @@ import java.util.Arrays; import java.util.List; +/** Tests {@link GBTRunner}. */ public class GBTRunnerTest extends AbstractTestBase { private static final List inputDataRows = Arrays.asList( From dbe0009d7e3f2f5072d7a2691a965a04b38f227f Mon Sep 17 00:00:00 2001 From: Fan Hong Date: Tue, 7 Feb 2023 12:01:02 +0800 Subject: [PATCH 07/47] [NO MERGE] Ad-hoc fix of KBinsDiscretizer --- .../ml/feature/kbinsdiscretizer/KBinsDiscretizer.java | 9 +++++++-- .../feature/kbinsdiscretizer/KBinsDiscretizerModel.java | 4 ++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizer.java index 763a0df22..ad3132cf6 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizer.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizer.java @@ -217,8 +217,13 @@ private static double[][] findBinEdgesWithQuantileStrategy( features[i] = input.get(i).get(columnId); } Arrays.sort(features); + int n = numData; - if (features[0] == features[numData - 1]) { + while (n > 0 && Double.isNaN(features[n - 1])) { + n -= 1; + } + + if (features[0] == features[n - 1]) { LOG.warn("Feature " + columnId + " is constant and the output will all be zero."); binEdges[columnId] = new double[] {Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY}; @@ -231,7 +236,7 @@ private static double[][] findBinEdgesWithQuantileStrategy( for (int binEdgeId = 0; binEdgeId < numBins; binEdgeId++) { tempBinEdges[binEdgeId] = features[(int) (binEdgeId * width)]; } - tempBinEdges[numBins] = features[numData - 1]; + tempBinEdges[numBins] = features[n - 1]; } else { tempBinEdges = features; } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizerModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizerModel.java index 03f2fc394..3bd429cf9 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizerModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizerModel.java @@ -154,6 +154,10 @@ public Row map(Row row) { DenseVector outputVec = inputVec.clone(); for (int i = 0; i < inputVec.size(); i++) { double targetFeature = inputVec.get(i); + if (Double.isNaN(targetFeature)) { + outputVec.set(i, binEdges[i].length - 1); + continue; + } int index = Arrays.binarySearch(binEdges[i], targetFeature); if (index < 0) { // Computes the index to insert. From 638ebee6ffdbcf3552f877f4fdd1f726cc2860c6 Mon Sep 17 00:00:00 2001 From: Fan Hong Date: Mon, 13 Feb 2023 19:38:13 +0800 Subject: [PATCH 08/47] Fix tests according to ad-hoc fix of KBinsDiscretizer. 
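
With the ad-hoc KBinsDiscretizer fix, NaN values no longer take part in the quantile
computation (they sort to the end of the feature array and are skipped), and at
transform time a NaN feature is mapped to the extra bin index binEdges.length - 1.
A rough illustration with the f0 bin edges used in these tests:

    // binEdges of f0 = {1.2, 4.5, 13.9, 15.3}
    // 2.3        -> bin 0
    // Double.NaN -> bin 3 (= binEdges.length - 1)

This is why the expected rows containing NaN below change from bin 2 to bin 3.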
--- .../apache/flink/ml/common/gbt/PreprocessTest.java | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/PreprocessTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/PreprocessTest.java index a9e3d6665..0ad602a74 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/PreprocessTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/PreprocessTest.java @@ -143,7 +143,7 @@ public void testPreprocessCols() throws Exception { // TODO: correct `binEdges` of feature `f0` after FLINK-30734 resolved. List expectedMeta = Arrays.asList( - FeatureMeta.continuous("f0", 3, new double[] {1.2, 4.5, 13.9, Double.NaN}), + FeatureMeta.continuous("f0", 3, new double[] {1.2, 4.5, 13.9, 15.3}), FeatureMeta.continuous("f1", 3, new double[] {1.0, 2.0, 4.0, 5.0}), FeatureMeta.categorical("f2", 5, new String[] {"a", "b", "c", "d", "e"})); @@ -154,7 +154,7 @@ public void testPreprocessCols() throws Exception { Row.of(40.0, 0, 2, 2.0), Row.of(40.0, 1, 2, 0.0), Row.of(40.0, 1, 1, 1.0), - Row.of(41.0, 2, 1, 2.0), + Row.of(41.0, 3, 1, 2.0), Row.of(41.0, 1, 2, 4.0), Row.of(41.0, 2, 1, 1.0), Row.of(41.0, 2, 2, 0.0), @@ -188,19 +188,17 @@ public void testPreprocessVectorCol() throws Exception { // TODO: correct `binEdges` of feature `_vec_f0` and `_vec_f2` after FLINK-30734 resolved. List expectedMeta = Arrays.asList( - FeatureMeta.continuous( - "_vec_f0", 3, new double[] {1.2, 4.5, 13.9, Double.NaN}), + FeatureMeta.continuous("_vec_f0", 3, new double[] {1.2, 4.5, 13.9, 15.3}), FeatureMeta.continuous("_vec_f1", 3, new double[] {1.0, 2.0, 4.0, 5.0}), - FeatureMeta.continuous( - "_vec_f2", 3, new double[] {1.0, 2.0, 3.0, Double.NaN})); + FeatureMeta.continuous("_vec_f2", 3, new double[] {1.0, 2.0, 3.0, 5.0})); List expectedPreprocessedRows = Arrays.asList( - Row.of(40.0, Vectors.dense(0, 1, 2.0)), + Row.of(40.0, Vectors.dense(0, 1, 3.0)), Row.of(40.0, Vectors.dense(0, 1, 1.0)), Row.of(40.0, Vectors.dense(0, 2, 2.0)), Row.of(40.0, Vectors.dense(1, 2, 0.0)), Row.of(40.0, Vectors.dense(1, 1, 1.0)), - Row.of(41.0, Vectors.dense(2, 1, 2.0)), + Row.of(41.0, Vectors.dense(3, 1, 2.0)), Row.of(41.0, Vectors.dense(1, 2, 2.0)), Row.of(41.0, Vectors.dense(2, 1, 1.0)), Row.of(41.0, Vectors.dense(2, 2, 0.0)), From 34797f0bedf8071369b483f804d47d5ce1e6a5a2 Mon Sep 17 00:00:00 2001 From: Fan Hong Date: Mon, 13 Feb 2023 12:04:13 +0800 Subject: [PATCH 09/47] Change LocalState datastream to JVM static memory --- .../ml/common/gbt/BoostIterationBody.java | 108 +++------- .../flink/ml/common/gbt/GBTModelData.java | 34 +-- .../apache/flink/ml/common/gbt/GBTRunner.java | 33 ++- .../flink/ml/common/gbt/defs/GbtParams.java | 2 + .../flink/ml/common/gbt/defs/Histogram.java | 2 + .../ml/common/gbt/defs/LearningNode.java | 13 +- .../flink/ml/common/gbt/defs/LocalState.java | 85 -------- .../apache/flink/ml/common/gbt/defs/Node.java | 10 +- .../flink/ml/common/gbt/defs/Slice.java | 2 + .../flink/ml/common/gbt/defs/Split.java | 4 + .../flink/ml/common/gbt/defs/Splits.java | 2 + .../ml/common/gbt/defs/TrainContext.java | 46 ++++ .../CacheDataCalcLocalHistsOperator.java | 181 ++++++++++------ .../operators/CalcLocalSplitsOperator.java | 67 +++--- .../ml/common/gbt/operators/HistBuilder.java | 69 +++--- .../common/gbt/operators/InstanceUpdater.java | 20 +- .../ml/common/gbt/operators/NodeSplitter.java | 49 +++-- .../gbt/operators/PostSplitsOperator.java | 202 ++++++++++++------ 
.../ml/common/gbt/operators/SharedKeys.java | 63 ++++++ .../ml/common/gbt/operators/SplitFinder.java | 47 ++-- .../gbt/operators/TerminationOperator.java | 78 +++++++ ...izer.java => TrainContextInitializer.java} | 55 +++-- .../common/gbt/operators/TreeInitializer.java | 37 ++-- .../typeinfo/BinnedInstanceSerializer.java | 2 +- .../typeinfo/ContinuousSplitSerializer.java | 36 ++-- .../gbt/typeinfo/GBTModelDataSerializer.java | 28 ++- .../gbt/typeinfo/HistogramSerializer.java | 119 +++++++++++ .../gbt/typeinfo/IntIntPairSerializer.java | 102 +++++++++ .../gbt/typeinfo/LearningNodeSerializer.java | 122 +++++++++++ .../common/gbt/typeinfo/NodeSerializer.java | 31 +-- .../ml/common/gbt/typeinfo/NodeTypeInfo.java | 88 ++++++++ .../gbt/typeinfo/NodeTypeInfoFactory.java | 40 ++++ .../common/gbt/typeinfo/SliceSerializer.java | 110 ++++++++++ .../common/gbt/typeinfo/SplitSerializer.java | 15 +- .../ml/common/gbt/typeinfo/SplitTypeInfo.java | 88 ++++++++ .../gbt/typeinfo/SplitTypeInfoFactory.java | 47 ++++ .../ml/classification/GBTClassifierTest.java | 2 +- .../flink/ml/common/gbt/GBTRunnerTest.java | 2 +- .../flink/ml/regression/GBTRegressorTest.java | 2 +- 39 files changed, 1484 insertions(+), 559 deletions(-) delete mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/LocalState.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/TrainContext.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SharedKeys.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TerminationOperator.java rename flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/{LocalStateInitializer.java => TrainContextInitializer.java} (76%) create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/HistogramSerializer.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/IntIntPairSerializer.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/LearningNodeSerializer.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/NodeTypeInfo.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/NodeTypeInfoFactory.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/SliceSerializer.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/SplitTypeInfo.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/SplitTypeInfoFactory.java diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BoostIterationBody.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BoostIterationBody.java index d3cc62355..faa4acf7a 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BoostIterationBody.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BoostIterationBody.java @@ -22,6 +22,7 @@ import org.apache.flink.api.common.functions.Partitioner; import org.apache.flink.api.common.typeinfo.TypeHint; import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.iteration.DataStreamList; @@ -30,17 +31,17 @@ import org.apache.flink.iteration.IterationID; import 
org.apache.flink.ml.common.gbt.defs.GbtParams; import org.apache.flink.ml.common.gbt.defs.Histogram; -import org.apache.flink.ml.common.gbt.defs.LocalState; import org.apache.flink.ml.common.gbt.defs.Splits; +import org.apache.flink.ml.common.gbt.defs.TrainContext; import org.apache.flink.ml.common.gbt.operators.CacheDataCalcLocalHistsOperator; import org.apache.flink.ml.common.gbt.operators.CalcLocalSplitsOperator; import org.apache.flink.ml.common.gbt.operators.HistogramAggregateFunction; import org.apache.flink.ml.common.gbt.operators.PostSplitsOperator; import org.apache.flink.ml.common.gbt.operators.SplitsAggregateFunction; +import org.apache.flink.ml.common.gbt.operators.TerminationOperator; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; import org.apache.flink.types.Row; -import org.apache.flink.util.Collector; import org.apache.flink.util.OutputTag; import org.apache.commons.lang3.ArrayUtils; @@ -61,111 +62,54 @@ public BoostIterationBody(IterationID iterationID, GbtParams gbtParams) { @Override public IterationBodyResult process(DataStreamList variableStreams, DataStreamList dataStreams) { DataStream data = dataStreams.get(0); - DataStream localState = variableStreams.get(0); - - final OutputTag stateOutputTag = - new OutputTag<>("state", TypeInformation.of(LocalState.class)); - - final OutputTag finalStateOutputTag = - new OutputTag<>("final_state", TypeInformation.of(LocalState.class)); - - /** - * In the iteration, some data needs to be shared between subtasks of different operators - * within one machine. We use {@link IterationSharedStorage} with co-location mechanism to - * achieve such purpose. The data is stored in JVM static region, and is accessed through - * string keys from different operator subtasks. Note the first operator to put the data is - * the owner of the data, and only the owner can update or delete the data. - * - *
To be specified, in the gradient boosting trees algorithm, there are three types of shared - * data: - *
    - *
  • Instances (after binned) and their corresponding predictions, gradients, and - * hessians are shared, to avoid redundant storage and communication. - *
  • When initializing every new tree, instances need to be shuffled and split into - * bagging instances and non-bagging ones. To reduce the cost, we shuffle instance - * indices rather than the instances themselves. Therefore, the shuffled indices need to be shared to - * access the actual instances. - *
  • After splitting nodes of each layer, instance indices need to be swapped to - * maintain {@link LearningNode#slice} and {@link LearningNode#oob}. However, we - * cannot directly update the shuffled indices above, as they already have an - * owner. So we use another key to store the instance indices after swapping. (A condensed usage sketch of this pattern follows.) - *
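A condensed sketch (not part of the patch) of the reader/writer pattern described in the comment above, using only calls that appear later in this commit; the surrounding operator context (iterationID, subtaskId, getOperatorID(), and the serializer) is assumed to be in scope, as in CacheDataCalcLocalHistsOperator.

IterationSharedStorage.Writer<int[]> shuffledIndicesWriter =
        IterationSharedStorage.getWriter(
                iterationID,                          // scopes the data to one iteration
                subtaskId,                            // scopes the data to one parallel subtask
                SharedKeys.SHUFFLED_INDICES,          // string key identifying the shared data
                getOperatorID(),                      // the first operator to put data becomes its owner
                IntPrimitiveArraySerializer.INSTANCE, // serializer used when snapshotting state
                new int[0]);                          // initial value
shuffledIndicesWriter.set(new int[] {2, 0, 1});       // only the owner may update or remove the data

// Subtasks of other co-located operators access the same data through the key.
IterationSharedStorage.Reader<int[]> reader =
        IterationSharedStorage.getReader(iterationID, subtaskId, SharedKeys.SHUFFLED_INDICES);
int[] shuffledIndices = reader.get();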
- */ - final String sharedInstancesKey = "instances"; - final String sharedPredGradHessKey = "preds_grads_hessians"; - final String sharedShuffledIndicesKey = "shuffled_indices"; - final String sharedSwappedIndicesKey = "swapped_indices"; + DataStream trainContext = variableStreams.get(0); final String coLocationKey = "boosting"; // In 1st round, cache all data. For all rounds calculate local histogram based on // current tree layer. SingleOutputStreamOperator localHists = - data.connect(localState) + data.connect(trainContext) .transform( "CacheDataCalcLocalHists", TypeInformation.of(Histogram.class), - new CacheDataCalcLocalHistsOperator( - gbtParams, - iterationID, - sharedInstancesKey, - sharedPredGradHessKey, - sharedShuffledIndicesKey, - sharedSwappedIndicesKey, - stateOutputTag)); + new CacheDataCalcLocalHistsOperator(gbtParams, iterationID)); localHists.getTransformation().setCoLocationGroupKey("coLocationKey"); - DataStream modelData = localHists.getSideOutput(stateOutputTag); DataStream globalHists = scatterReduceHistograms(localHists); SingleOutputStreamOperator localSplits = - modelData - .connect(globalHists) - .transform( - "CalcLocalSplits", - TypeInformation.of(Splits.class), - new CalcLocalSplitsOperator(stateOutputTag)); + globalHists.transform( + "CalcLocalSplits", + TypeInformation.of(Splits.class), + new CalcLocalSplitsOperator(iterationID)); localHists.getTransformation().setCoLocationGroupKey(coLocationKey); DataStream globalSplits = localSplits.broadcast().flatMap(new SplitsAggregateFunction()); - SingleOutputStreamOperator updatedModelData = - modelData - .connect(globalSplits.broadcast()) + SingleOutputStreamOperator updatedModelData = + globalSplits + .broadcast() .transform( "PostSplits", - TypeInformation.of(LocalState.class), - new PostSplitsOperator( - iterationID, - sharedInstancesKey, - sharedPredGradHessKey, - sharedShuffledIndicesKey, - sharedSwappedIndicesKey, - finalStateOutputTag)); + TypeInformation.of(Integer.class), + new PostSplitsOperator(iterationID)); updatedModelData.getTransformation().setCoLocationGroupKey(coLocationKey); - DataStream termination = - updatedModelData.flatMap( - new FlatMapFunction() { - @Override - public void flatMap(LocalState value, Collector out) { - LocalState.Dynamics dynamics = value.dynamics; - boolean terminated = - !dynamics.inWeakLearner - && dynamics.roots.size() - == value.statics.params.maxIter; - // TODO: add validation error rate - if (!terminated) { - out.collect(0); - } - } - }); + final OutputTag modelDataOutputTag = + new OutputTag<>("model_data", TypeInformation.of(GBTModelData.class)); + SingleOutputStreamOperator termination = + updatedModelData.transform( + "check_termination", + Types.INT, + new TerminationOperator(iterationID, modelDataOutputTag)); termination.getTransformation().setCoLocationGroupKey(coLocationKey); return new IterationBodyResult( - DataStreamList.of(updatedModelData), - DataStreamList.of(updatedModelData.getSideOutput(finalStateOutputTag)), + DataStreamList.of( + updatedModelData.flatMap( + (d, out) -> {}, TypeInformation.of(TrainContext.class))), + DataStreamList.of(termination.getSideOutput(modelDataOutputTag)), termination); } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelData.java index 3904819ba..d34f44211 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelData.java +++ 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelData.java @@ -30,14 +30,13 @@ import org.apache.flink.core.memory.DataOutputViewStreamWrapper; import org.apache.flink.ml.classification.gbtclassifier.GBTClassifierModel; import org.apache.flink.ml.common.gbt.defs.FeatureMeta; -import org.apache.flink.ml.common.gbt.defs.LocalState; import org.apache.flink.ml.common.gbt.defs.Node; +import org.apache.flink.ml.common.gbt.defs.TrainContext; import org.apache.flink.ml.common.gbt.typeinfo.GBTModelDataSerializer; import org.apache.flink.ml.common.gbt.typeinfo.GBTModelDataTypeInfoFactory; import org.apache.flink.ml.feature.stringindexer.StringIndexerModel; import org.apache.flink.ml.linalg.SparseVector; import org.apache.flink.ml.linalg.Vector; -import org.apache.flink.ml.regression.gbtregressor.GBTRegressorModel; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.table.api.Table; import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; @@ -68,19 +67,20 @@ public class GBTModelData { public double prior; public double stepSize; - public List roots; + public List> allTrees; public IntObjectHashMap> categoryToIdMaps; public IntObjectHashMap featureIdToBinEdges; public BitSet isCategorical; public GBTModelData() {} + // TODO: !!! public GBTModelData( String type, boolean isInputVector, double prior, double stepSize, - List roots, + List> allTrees, IntObjectHashMap> categoryToIdMaps, IntObjectHashMap featureIdToBinEdges, BitSet isCategorical) { @@ -88,18 +88,18 @@ public GBTModelData( this.isInputVector = isInputVector; this.prior = prior; this.stepSize = stepSize; - this.roots = roots; + this.allTrees = allTrees; this.categoryToIdMaps = categoryToIdMaps; this.featureIdToBinEdges = featureIdToBinEdges; this.isCategorical = isCategorical; } - public static GBTModelData fromLocalState(LocalState state) { + public static GBTModelData from(TrainContext trainContext, List> allTrees) { IntObjectHashMap> categoryToIdMaps = new IntObjectHashMap<>(); IntObjectHashMap featureIdToBinEdges = new IntObjectHashMap<>(); BitSet isCategorical = new BitSet(); - FeatureMeta[] featureMetas = state.statics.featureMetas; + FeatureMeta[] featureMetas = trainContext.featureMetas; for (int k = 0; k < featureMetas.length; k += 1) { FeatureMeta featureMeta = featureMetas[k]; if (featureMeta instanceof FeatureMeta.CategoricalFeatureMeta) { @@ -116,11 +116,11 @@ public static GBTModelData fromLocalState(LocalState state) { } } return new GBTModelData( - state.statics.params.taskType.name(), - state.statics.params.isInputVector, - state.statics.prior, - state.statics.params.stepSize, - state.dynamics.roots, + trainContext.params.taskType.name(), + trainContext.params.isInputVector, + trainContext.prior, + trainContext.params.stepSize, + allTrees, categoryToIdMaps, featureIdToBinEdges, isCategorical); @@ -173,11 +173,11 @@ public IntDoubleHashMap rowToFeatures(Row row, String[] featureCols, String vect public double predictRaw(IntDoubleHashMap rawFeatures) { double v = prior; - for (Node root : roots) { - Node node = root; + for (List treeNodes : allTrees) { + Node node = treeNodes.get(0); while (!node.isLeaf) { boolean goLeft = node.split.shouldGoLeft(rawFeatures); - node = goLeft ? node.left : node.right; + node = goLeft ? 
treeNodes.get(node.left) : treeNodes.get(node.right); } v += stepSize * node.split.prediction; } @@ -187,8 +187,8 @@ public double predictRaw(IntDoubleHashMap rawFeatures) { @Override public String toString() { return String.format( - "GBTModelData{type=%s, prior=%s, roots=%s, categoryToIdMaps=%s, featureIdToBinEdges=%s, isCategorical=%s}", - type, prior, roots, categoryToIdMaps, featureIdToBinEdges, isCategorical); + "GBTModelData{type=%s, prior=%s, allTrees=%s, categoryToIdMaps=%s, featureIdToBinEdges=%s, isCategorical=%s}", + type, prior, allTrees, categoryToIdMaps, featureIdToBinEdges, isCategorical); } /** Encoder for {@link GBTModelData}. */ diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTRunner.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTRunner.java index f1c73473b..2630111a1 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTRunner.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTRunner.java @@ -31,8 +31,8 @@ import org.apache.flink.ml.common.datastream.DataStreamUtils; import org.apache.flink.ml.common.gbt.defs.FeatureMeta; import org.apache.flink.ml.common.gbt.defs.GbtParams; -import org.apache.flink.ml.common.gbt.defs.LocalState; import org.apache.flink.ml.common.gbt.defs.TaskType; +import org.apache.flink.ml.common.gbt.defs.TrainContext; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.regression.gbtregressor.GBTRegressorParams; import org.apache.flink.streaming.api.datastream.DataStream; @@ -100,7 +100,7 @@ private static DataStream boost( bcMap.put(featureMetaBcName, featureMeta); bcMap.put(labelSumCountBcName, labelSumCount); - DataStream initStates = + DataStream initTrainContext = BroadcastUtils.withBroadcastStream( Collections.singletonList( tEnv.toDataStream(tEnv.fromValues(0), Integer.class)), @@ -109,7 +109,7 @@ private static DataStream boost( //noinspection unchecked DataStream input = (DataStream) (inputs.get(0)); return input.map( - new InitLocalStateFunction( + new InitTrainContextFunction( featureMetaBcName, labelSumCountBcName, p)); }); @@ -117,14 +117,13 @@ private static DataStream boost( final IterationID iterationID = new IterationID(); DataStreamList dataStreamList = Iterations.iterateBoundedStreamsUntilTermination( - DataStreamList.of(initStates.broadcast()), + DataStreamList.of(initTrainContext.broadcast()), ReplayableDataStreamList.notReplay(data, featureMeta), IterationConfig.newBuilder() .setOperatorLifeCycle(IterationConfig.OperatorLifeCycle.ALL_ROUND) .build(), new BoostIterationBody(iterationID, p)); - DataStream state = dataStreamList.get(0); - return state.map(GBTModelData::fromLocalState); + return dataStreamList.get(0); } public static GbtParams fromEstimator(BaseGBTParams estimator, TaskType taskType) { @@ -181,12 +180,12 @@ public static GbtParams fromEstimator(BaseGBTParams estimator, TaskType taskT return p; } - private static class InitLocalStateFunction extends RichMapFunction { + private static class InitTrainContextFunction extends RichMapFunction { private final String featureMetaBcName; private final String labelSumCountBcName; private final GbtParams p; - private InitLocalStateFunction( + private InitTrainContextFunction( String featureMetaBcName, String labelSumCountBcName, GbtParams p) { this.featureMetaBcName = featureMetaBcName; this.labelSumCountBcName = labelSumCountBcName; @@ -194,24 +193,24 @@ private InitLocalStateFunction( } @Override - public LocalState map(Integer value) { - LocalState.Statics statics = new 
LocalState.Statics(); - statics.params = p; - statics.featureMetas = + public TrainContext map(Integer value) { + TrainContext trainContext = new TrainContext(); + trainContext.params = p; + trainContext.featureMetas = getRuntimeContext() .getBroadcastVariable(featureMetaBcName) .toArray(new FeatureMeta[0]); - if (!statics.params.isInputVector) { + if (!trainContext.params.isInputVector) { Arrays.sort( - statics.featureMetas, + trainContext.featureMetas, Comparator.comparing(d -> ArrayUtils.indexOf(p.featureCols, d.name))); } - statics.numFeatures = statics.featureMetas.length; - statics.labelSumCount = + trainContext.numFeatures = trainContext.featureMetas.length; + trainContext.labelSumCount = getRuntimeContext() .>getBroadcastVariable(labelSumCountBcName) .get(0); - return new LocalState(statics, new LocalState.Dynamics()); + return trainContext; } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/GbtParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/GbtParams.java index ea6d84bbc..f55d47773 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/GbtParams.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/GbtParams.java @@ -52,4 +52,6 @@ public class GbtParams implements Serializable { public int maxNumLeaves; // useMissing is always true right now. public boolean useMissing; + + public GbtParams() {} } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Histogram.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Histogram.java index bfc9a2641..497460327 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Histogram.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Histogram.java @@ -35,6 +35,8 @@ public class Histogram implements Serializable { // Stores the number of elements received by subtasks in scattering. public int[] recvcnts; + public Histogram() {} + public Histogram(int subtaskId, double[] hists, int[] recvcnts) { this.subtaskId = subtaskId; this.hists = hists; diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/LearningNode.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/LearningNode.java index e208269fb..a97c6d435 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/LearningNode.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/LearningNode.java @@ -21,8 +21,8 @@ /** A node used in learning procedure. */ public class LearningNode { - // The corresponding tree node. - public Node node; + // The node index in `currentTreeNodes` used in `PostSplitsOperator`. + public int nodeIndex; // Slice of indices of bagging instances. public Slice slice; // Slice of indices of non-bagging instances. @@ -30,8 +30,10 @@ public class LearningNode { // Depth of corresponding tree node. 
public int depth; - public LearningNode(Node node, Slice slice, Slice oob, int depth) { - this.node = node; + public LearningNode() {} + + public LearningNode(int nodeIndex, Slice slice, Slice oob, int depth) { + this.nodeIndex = nodeIndex; this.slice = slice; this.oob = oob; this.depth = depth; @@ -40,6 +42,7 @@ public LearningNode(Node node, Slice slice, Slice oob, int depth) { @Override public String toString() { return String.format( - "LearningNode{node=%s, slice=%s, oob=%s, depth=%d}", node, slice, oob, depth); + "LearningNode{nodeIndex=%s, slice=%s, oob=%s, depth=%d}", + nodeIndex, slice, oob, depth); } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/LocalState.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/LocalState.java deleted file mode 100644 index c13df2d43..000000000 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/LocalState.java +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.ml.common.gbt.defs; - -import org.apache.flink.api.java.tuple.Tuple2; -import org.apache.flink.ml.common.gbt.loss.Loss; - -import org.eclipse.collections.api.tuple.primitive.IntIntPair; - -import java.io.Serializable; -import java.util.ArrayList; -import java.util.List; -import java.util.Random; - -/** - * Stores training state, including static parts and dynamic parts. Static parts won't change across - * the iteration rounds (except initialization), while dynamic parts are updated on every round. - * - *
An instance of training states is bound to a subtask id, so the operators accepting training - * states should be co-located. - */ -public class LocalState implements Serializable { - - public Statics statics; - public Dynamics dynamics; - - public LocalState(Statics statics, Dynamics dynamics) { - this.statics = statics; - this.dynamics = dynamics; - } - - /** Static part of local state. */ - public static class Statics { - - public int subtaskId; - public int numSubtasks; - public GbtParams params; - - public int numInstances; - public int numBaggingInstances; - public Random instanceRandomizer; - - public int numFeatures; - public int numBaggingFeatures; - public Random featureRandomizer; - - public FeatureMeta[] featureMetas; - public int[] numFeatureBins; - - public Tuple2 labelSumCount; - public double prior; - public Loss loss; - } - - /** Dynamic part of local state. */ - public static class Dynamics { - // Root nodes of every tree. - public List roots = new ArrayList<>(); - // Initializes a new tree when false, otherwise splits nodes in current layer. - public boolean inWeakLearner; - - // Nodes to be split in the current layer. - public List layer = new ArrayList<>(); - // Node ID and feature ID pairs to be considered in current layer. - public List nodeFeaturePairs = new ArrayList<>(); - // Leaf nodes in the current tree. - public List leaves = new ArrayList<>(); - } -} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Node.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Node.java index 691d1c0ef..82121fdf6 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Node.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Node.java @@ -18,13 +18,19 @@ package org.apache.flink.ml.common.gbt.defs; +import org.apache.flink.api.common.typeinfo.TypeInfo; +import org.apache.flink.ml.common.gbt.typeinfo.NodeTypeInfoFactory; + import java.io.Serializable; /** Tree node in the decision tree that will be serialized to json and deserialized from json. 
*/ +@TypeInfo(NodeTypeInfoFactory.class) public class Node implements Serializable { public Split split; public boolean isLeaf = false; - public Node left; - public Node right; + public int left; + public int right; + + public Node() {} } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Slice.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Slice.java index 44ece2453..e29a7dc39 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Slice.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Slice.java @@ -24,6 +24,8 @@ public final class Slice { public int start; public int end; + public Slice() {} + public Slice(int start, int end) { this.start = start; this.end = end; diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Split.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Split.java index 78052bbe7..226577262 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Split.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Split.java @@ -18,12 +18,16 @@ package org.apache.flink.ml.common.gbt.defs; +import org.apache.flink.api.common.typeinfo.TypeInfo; +import org.apache.flink.ml.common.gbt.typeinfo.SplitTypeInfoFactory; + import org.eclipse.collections.impl.map.mutable.primitive.IntDoubleHashMap; import org.eclipse.collections.impl.map.mutable.primitive.IntIntHashMap; import java.util.BitSet; /** Stores a split on a feature. */ +@TypeInfo(SplitTypeInfoFactory.class) public abstract class Split { public static final double INVALID_GAIN = 0.0; diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Splits.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Splits.java index c80598383..72c69e4ca 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Splits.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Splits.java @@ -31,6 +31,8 @@ public class Splits { // Stores splits of nodes in the current layer. public Split[] splits; + public Splits() {} + public Splits(int subtaskId, Split[] splits) { this.subtaskId = subtaskId; this.splits = splits; diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/TrainContext.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/TrainContext.java new file mode 100644 index 000000000..2364ab203 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/TrainContext.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.gbt.defs; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.ml.common.gbt.loss.Loss; + +import java.util.Random; + +/** Stores the training context. */ +public class TrainContext { + public int subtaskId; + public int numSubtasks; + public GbtParams params; + + public int numInstances; + public int numBaggingInstances; + public Random instanceRandomizer; + + public int numFeatures; + public int numBaggingFeatures; + public Random featureRandomizer; + + public FeatureMeta[] featureMetas; + public int[] numFeatureBins; + + public Tuple2 labelSumCount; + public double prior; + public Loss loss; +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java index 1b7d41770..2f5fd4e8e 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java @@ -20,8 +20,10 @@ import org.apache.flink.api.common.state.ListState; import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeutils.base.BooleanSerializer; import org.apache.flink.api.common.typeutils.base.GenericArraySerializer; import org.apache.flink.api.common.typeutils.base.array.IntPrimitiveArraySerializer; +import org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer; import org.apache.flink.iteration.IterationID; import org.apache.flink.iteration.IterationListener; import org.apache.flink.iteration.datacache.nonkeyed.ListStateWithCache; @@ -30,10 +32,12 @@ import org.apache.flink.ml.common.gbt.defs.BinnedInstance; import org.apache.flink.ml.common.gbt.defs.GbtParams; import org.apache.flink.ml.common.gbt.defs.Histogram; -import org.apache.flink.ml.common.gbt.defs.LocalState; +import org.apache.flink.ml.common.gbt.defs.LearningNode; import org.apache.flink.ml.common.gbt.defs.PredGradHess; +import org.apache.flink.ml.common.gbt.defs.TrainContext; import org.apache.flink.ml.common.gbt.loss.Loss; import org.apache.flink.ml.common.gbt.typeinfo.BinnedInstanceSerializer; +import org.apache.flink.ml.common.gbt.typeinfo.LearningNodeSerializer; import org.apache.flink.ml.linalg.SparseVector; import org.apache.flink.ml.linalg.Vector; import org.apache.flink.runtime.state.StateInitializationContext; @@ -43,7 +47,6 @@ import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.types.Row; import org.apache.flink.util.Collector; -import org.apache.flink.util.OutputTag; import org.apache.flink.util.Preconditions; import org.apache.commons.collections.IteratorUtils; @@ -52,32 +55,26 @@ import org.slf4j.LoggerFactory; import java.util.Collections; +import java.util.List; /** * Calculates local histograms for local data partition. Specifically in the first round, this * operator caches all data instances to JVM static region. 
*/ public class CacheDataCalcLocalHistsOperator extends AbstractStreamOperator - implements TwoInputStreamOperator, + implements TwoInputStreamOperator, IterationListener { private static final Logger LOG = LoggerFactory.getLogger(CacheDataCalcLocalHistsOperator.class); - private static final String LOCAL_STATE_STATE_NAME = "local_state"; private static final String TREE_INITIALIZER_STATE_NAME = "tree_initializer"; private static final String HIST_BUILDER_STATE_NAME = "hist_builder"; private final GbtParams gbtParams; private final IterationID iterationID; - private final String sharedInstancesKey; - private final String sharedPredGradHessKey; - private final String sharedShuffledIndicesKey; - private final String sharedSwappedIndicesKey; - private final OutputTag stateOutputTag; // States of local data. private transient ListStateWithCache instancesCollecting; - private transient ListState localState; private transient ListState treeInitializer; private transient ListState histBuilder; @@ -86,23 +83,17 @@ public class CacheDataCalcLocalHistsOperator extends AbstractStreamOperator pghReader; private transient IterationSharedStorage.Writer shuffledIndicesWriter; private transient IterationSharedStorage.Reader swappedIndicesReader; + private IterationSharedStorage.Writer nodeFeaturePairsWriter; + private IterationSharedStorage.Reader> layerReader; + private IterationSharedStorage.Writer rootLearningNodeWriter; + private IterationSharedStorage.Reader needInitTreeReader; + private IterationSharedStorage.Writer hasInitedTreeWriter; + private IterationSharedStorage.Writer trainContextWriter; - public CacheDataCalcLocalHistsOperator( - GbtParams gbtParams, - IterationID iterationID, - String sharedInstancesKey, - String sharedPredGradHessKey, - String sharedShuffledIndicesKey, - String sharedSwappedIndicesKey, - OutputTag stateOutputTag) { + public CacheDataCalcLocalHistsOperator(GbtParams gbtParams, IterationID iterationID) { super(); this.gbtParams = gbtParams; this.iterationID = iterationID; - this.sharedInstancesKey = sharedInstancesKey; - this.sharedPredGradHessKey = sharedPredGradHessKey; - this.sharedShuffledIndicesKey = sharedShuffledIndicesKey; - this.sharedSwappedIndicesKey = sharedSwappedIndicesKey; - this.stateOutputTag = stateOutputTag; } @Override @@ -116,11 +107,6 @@ public void initializeState(StateInitializationContext context) throws Exception getRuntimeContext(), context, getOperatorID()); - localState = - context.getOperatorStateStore() - .getListState( - new ListStateDescriptor<>( - LOCAL_STATE_STATE_NAME, LocalState.class)); treeInitializer = context.getOperatorStateStore() .getListState( @@ -137,7 +123,7 @@ public void initializeState(StateInitializationContext context) throws Exception IterationSharedStorage.getWriter( iterationID, subtaskId, - sharedInstancesKey, + SharedKeys.INSTANCES, getOperatorID(), new GenericArraySerializer<>( BinnedInstance.class, BinnedInstanceSerializer.INSTANCE), @@ -148,16 +134,62 @@ public void initializeState(StateInitializationContext context) throws Exception IterationSharedStorage.getWriter( iterationID, subtaskId, - sharedShuffledIndicesKey, + SharedKeys.SHUFFLED_INDICES, getOperatorID(), IntPrimitiveArraySerializer.INSTANCE, new int[0]); shuffledIndicesWriter.initializeState(context); + nodeFeaturePairsWriter = + IterationSharedStorage.getWriter( + iterationID, + subtaskId, + SharedKeys.NODE_FEATURE_PAIRS, + getOperatorID(), + IntPrimitiveArraySerializer.INSTANCE, + new int[0]); + nodeFeaturePairsWriter.initializeState(context); + + 
rootLearningNodeWriter = + IterationSharedStorage.getWriter( + iterationID, + subtaskId, + SharedKeys.ROOT_LEARNING_NODE, + getOperatorID(), + LearningNodeSerializer.INSTANCE, + new LearningNode()); + rootLearningNodeWriter.initializeState(context); + + hasInitedTreeWriter = + IterationSharedStorage.getWriter( + iterationID, + subtaskId, + SharedKeys.HAS_INITED_TREE, + getOperatorID(), + BooleanSerializer.INSTANCE, + false); + hasInitedTreeWriter.initializeState(context); + + trainContextWriter = + IterationSharedStorage.getWriter( + iterationID, + subtaskId, + SharedKeys.TRAIN_CONTEXT, + getOperatorID(), + new KryoSerializer<>(TrainContext.class, getExecutionConfig()), + new TrainContext()); + trainContextWriter.initializeState(context); + this.pghReader = - IterationSharedStorage.getReader(iterationID, subtaskId, sharedPredGradHessKey); + IterationSharedStorage.getReader( + iterationID, subtaskId, SharedKeys.PREDS_GRADS_HESSIANS); this.swappedIndicesReader = - IterationSharedStorage.getReader(iterationID, subtaskId, sharedSwappedIndicesKey); + IterationSharedStorage.getReader( + iterationID, subtaskId, SharedKeys.SWAPPED_INDICES); + this.layerReader = + IterationSharedStorage.getReader(iterationID, subtaskId, SharedKeys.LAYER); + this.needInitTreeReader = + IterationSharedStorage.getReader(iterationID, subtaskId, SharedKeys.NEED_INIT_TREE); } @Override @@ -166,19 +198,22 @@ public void snapshotState(StateSnapshotContext context) throws Exception { instancesCollecting.snapshotState(context); instancesWriter.snapshotState(context); shuffledIndicesWriter.snapshotState(context); + hasInitedTreeWriter.snapshotState(context); } @Override public void processElement1(StreamRecord streamRecord) throws Exception { Row row = streamRecord.getValue(); - IntIntHashMap features = new IntIntHashMap(); + IntIntHashMap features; if (gbtParams.isInputVector) { Vector vec = row.getFieldAs(gbtParams.vectorCol); SparseVector sv = vec.toSparse(); + features = new IntIntHashMap(sv.indices.length); for (int i = 0; i < sv.indices.length; i += 1) { features.put(sv.indices[i], (int) sv.values[i]); } } else { + features = new IntIntHashMap(gbtParams.featureCols.length); for (int i = 0; i < gbtParams.featureCols.length; i += 1) { // Values from StringIndexModel#transform are double. features.put(i, ((Number) row.getFieldAs(gbtParams.featureCols[i])).intValue()); @@ -189,16 +224,18 @@ public void processElement1(StreamRecord streamRecord) throws Exception { } @Override - public void processElement2(StreamRecord streamRecord) throws Exception { - localState.update(Collections.singletonList(streamRecord.getValue())); + public void processElement2(StreamRecord streamRecord) throws Exception { + TrainContext trainContext = streamRecord.getValue(); + if (null != trainContext) { + // Not null only in first round. + trainContextWriter.set(trainContext); + } } @SuppressWarnings("OptionalGetWithoutIsPresent") @Override public void onEpochWatermarkIncremented( int epochWatermark, Context context, Collector out) throws Exception { - LocalState localStateValue = - OperatorStateUtils.getUniqueElement(localState, "local_state").get(); if (0 == epochWatermark) { // Initializes local state in first round. 
instancesWriter.set( @@ -206,28 +243,30 @@ public void onEpochWatermarkIncremented( IteratorUtils.toArray( instancesCollecting.get().iterator(), BinnedInstance.class)); instancesCollecting.clear(); - new LocalStateInitializer(gbtParams) - .init( - localStateValue, - getRuntimeContext().getIndexOfThisSubtask(), - getRuntimeContext().getNumberOfParallelSubtasks(), - instancesWriter.get()); - - treeInitializer.update( - Collections.singletonList(new TreeInitializer(localStateValue.statics))); - histBuilder.update(Collections.singletonList(new HistBuilder(localStateValue.statics))); + TrainContext trainContext = + new TrainContextInitializer(gbtParams) + .init( + trainContextWriter.get(), + getRuntimeContext().getIndexOfThisSubtask(), + getRuntimeContext().getNumberOfParallelSubtasks(), + instancesWriter.get()); + trainContextWriter.set(trainContext); + + treeInitializer.update(Collections.singletonList(new TreeInitializer(trainContext))); + histBuilder.update(Collections.singletonList(new HistBuilder(trainContext))); } + TrainContext trainContext = trainContextWriter.get(); BinnedInstance[] instances = instancesWriter.get(); Preconditions.checkArgument( - getRuntimeContext().getIndexOfThisSubtask() == localStateValue.statics.subtaskId); + getRuntimeContext().getIndexOfThisSubtask() == trainContext.subtaskId); PredGradHess[] pgh = pghReader.get(); // In the first round, use prior as the predictions. if (0 == pgh.length) { pgh = new PredGradHess[instances.length]; - double prior = localStateValue.statics.prior; - Loss loss = localStateValue.statics.loss; + double prior = trainContext.prior; + Loss loss = trainContext.loss; for (int i = 0; i < instances.length; i += 1) { double label = instances[i].label; pgh[i] = @@ -237,47 +276,63 @@ public void onEpochWatermarkIncremented( } int[] indices; - if (!localStateValue.dynamics.inWeakLearner) { + if (needInitTreeReader.get()) { + TreeInitializer treeInit = + OperatorStateUtils.getUniqueElement( + treeInitializer, TREE_INITIALIZER_STATE_NAME) + .get(); + // When last tree is finished, initializes a new tree, and shuffle instance indices. - OperatorStateUtils.getUniqueElement(treeInitializer, TREE_INITIALIZER_STATE_NAME) - .get() - .init(localStateValue.dynamics, shuffledIndicesWriter::set); - localStateValue.dynamics.inWeakLearner = true; + treeInit.init(shuffledIndicesWriter::set); + + LearningNode rootLearningNode = treeInit.getRootLearningNode(); indices = shuffledIndicesWriter.get(); + rootLearningNodeWriter.set(rootLearningNode); + hasInitedTreeWriter.set(true); } else { // Otherwise, uses the swapped instance indices. 
shuffledIndicesWriter.set(new int[0]); indices = swappedIndicesReader.get(); + hasInitedTreeWriter.set(false); + } + + List layer = layerReader.get(); + if (layer.size() == 0) { + layer = Collections.singletonList(rootLearningNodeWriter.get()); } + int[] nodeFeaturePairs = + OperatorStateUtils.getUniqueElement(histBuilder, HIST_BUILDER_STATE_NAME) + .get() + .getNodeFeaturePairs(layer.size()); + nodeFeaturePairsWriter.set(nodeFeaturePairs); + Histogram localHists = OperatorStateUtils.getUniqueElement(histBuilder, HIST_BUILDER_STATE_NAME) .get() - .build( - localStateValue.dynamics.layer, - localStateValue.dynamics.nodeFeaturePairs, - indices, - instances, - pgh); + .build(layer, nodeFeaturePairs, indices, instances, pgh); out.collect(localHists); - context.output(stateOutputTag, localStateValue); } @Override public void onIterationTerminated(Context context, Collector collector) { instancesCollecting.clear(); - localState.clear(); treeInitializer.clear(); histBuilder.clear(); instancesWriter.set(new BinnedInstance[0]); shuffledIndicesWriter.set(new int[0]); + nodeFeaturePairsWriter.set(new int[0]); } @Override public void close() throws Exception { instancesWriter.remove(); shuffledIndicesWriter.remove(); + nodeFeaturePairsWriter.remove(); + rootLearningNodeWriter.remove(); + hasInitedTreeWriter.remove(); + trainContextWriter.remove(); super.close(); } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CalcLocalSplitsOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CalcLocalSplitsOperator.java index 3cf86557c..d32015679 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CalcLocalSplitsOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CalcLocalSplitsOperator.java @@ -20,47 +20,47 @@ import org.apache.flink.api.common.state.ListState; import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.iteration.IterationID; import org.apache.flink.iteration.IterationListener; import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.common.gbt.datastorage.IterationSharedStorage; import org.apache.flink.ml.common.gbt.defs.Histogram; -import org.apache.flink.ml.common.gbt.defs.LocalState; +import org.apache.flink.ml.common.gbt.defs.LearningNode; import org.apache.flink.ml.common.gbt.defs.Splits; +import org.apache.flink.ml.common.gbt.defs.TrainContext; import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; -import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.util.Collector; -import org.apache.flink.util.OutputTag; import java.util.Collections; +import java.util.List; /** Calculates local splits for assigned (nodeId, featureId) pairs. 
*/ public class CalcLocalSplitsOperator extends AbstractStreamOperator - implements TwoInputStreamOperator, - IterationListener { + implements OneInputStreamOperator, IterationListener { - private static final String LOCAL_STATE_STATE_NAME = "local_state"; private static final String CALC_BEST_SPLIT_STATE_NAME = "split_finder"; private static final String HISTOGRAM_STATE_NAME = "histogram"; - private final OutputTag stateOutputTag; + private final IterationID iterationID; - private transient ListState localState; private transient ListState splitFinder; private transient ListState histogram; + private IterationSharedStorage.Reader nodeFeaturePairsReader; + private IterationSharedStorage.Reader> leavesReader; + private IterationSharedStorage.Reader> layerReader; + private IterationSharedStorage.Reader rootLearningNodeReader; + private IterationSharedStorage.Reader trainContextReader; - public CalcLocalSplitsOperator(OutputTag stateOutputTag) { - this.stateOutputTag = stateOutputTag; + public CalcLocalSplitsOperator(IterationID iterationID) { + this.iterationID = iterationID; } @Override public void initializeState(StateInitializationContext context) throws Exception { super.initializeState(context); - localState = - context.getOperatorStateStore() - .getListState( - new ListStateDescriptor<>( - LOCAL_STATE_STATE_NAME, LocalState.class)); splitFinder = context.getOperatorStateStore() .getListState( @@ -70,40 +70,51 @@ public void initializeState(StateInitializationContext context) throws Exception context.getOperatorStateStore() .getListState( new ListStateDescriptor<>(HISTOGRAM_STATE_NAME, Histogram.class)); + + int subtaskId = getRuntimeContext().getIndexOfThisSubtask(); + nodeFeaturePairsReader = + IterationSharedStorage.getReader( + iterationID, subtaskId, SharedKeys.NODE_FEATURE_PAIRS); + leavesReader = IterationSharedStorage.getReader(iterationID, subtaskId, SharedKeys.LEAVES); + layerReader = IterationSharedStorage.getReader(iterationID, subtaskId, SharedKeys.LAYER); + rootLearningNodeReader = + IterationSharedStorage.getReader( + iterationID, subtaskId, SharedKeys.ROOT_LEARNING_NODE); + trainContextReader = + IterationSharedStorage.getReader(iterationID, subtaskId, SharedKeys.TRAIN_CONTEXT); } @SuppressWarnings("OptionalGetWithoutIsPresent") @Override public void onEpochWatermarkIncremented( int epochWatermark, Context context, Collector collector) throws Exception { - LocalState localStateValue = - OperatorStateUtils.getUniqueElement(localState, LOCAL_STATE_STATE_NAME).get(); if (0 == epochWatermark) { - splitFinder.update(Collections.singletonList(new SplitFinder(localStateValue.statics))); + splitFinder.update( + Collections.singletonList(new SplitFinder(trainContextReader.get()))); + } + + List layer = layerReader.get(); + if (layer.size() == 0) { + layer = Collections.singletonList(rootLearningNodeReader.get()); } + Splits splits = OperatorStateUtils.getUniqueElement(splitFinder, CALC_BEST_SPLIT_STATE_NAME) .get() .calc( - localStateValue.dynamics.layer, - localStateValue.dynamics.nodeFeaturePairs, - localStateValue.dynamics.leaves, + layer, + nodeFeaturePairsReader.get(), + leavesReader.get().size(), OperatorStateUtils.getUniqueElement(histogram, HISTOGRAM_STATE_NAME) .get()); collector.collect(splits); - context.output(stateOutputTag, localStateValue); } @Override public void onIterationTerminated(Context context, Collector collector) {} @Override - public void processElement1(StreamRecord element) throws Exception { - 
localState.update(Collections.singletonList(element.getValue())); - } - - @Override - public void processElement2(StreamRecord element) throws Exception { + public void processElement(StreamRecord element) throws Exception { histogram.update(Collections.singletonList(element.getValue())); } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistBuilder.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistBuilder.java index a9428b5b4..89c18ef5e 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistBuilder.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistBuilder.java @@ -24,11 +24,9 @@ import org.apache.flink.ml.common.gbt.defs.FeatureMeta; import org.apache.flink.ml.common.gbt.defs.Histogram; import org.apache.flink.ml.common.gbt.defs.LearningNode; -import org.apache.flink.ml.common.gbt.defs.LocalState; import org.apache.flink.ml.common.gbt.defs.PredGradHess; +import org.apache.flink.ml.common.gbt.defs.TrainContext; -import org.eclipse.collections.api.tuple.primitive.IntIntPair; -import org.eclipse.collections.impl.tuple.primitive.PrimitiveTuples; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -54,23 +52,23 @@ class HistBuilder { private final double[] hists; - public HistBuilder(LocalState.Statics statics) { - subtaskId = statics.subtaskId; - numSubtasks = statics.numSubtasks; + public HistBuilder(TrainContext trainContext) { + subtaskId = trainContext.subtaskId; + numSubtasks = trainContext.numSubtasks; - numFeatureBins = statics.numFeatureBins; - featureMetas = statics.featureMetas; + numFeatureBins = trainContext.numFeatureBins; + featureMetas = trainContext.featureMetas; - numBaggingFeatures = statics.numBaggingFeatures; - featureRandomizer = statics.featureRandomizer; - featureIndicesPool = IntStream.range(0, statics.numFeatures).toArray(); + numBaggingFeatures = trainContext.numBaggingFeatures; + featureRandomizer = trainContext.featureRandomizer; + featureIndicesPool = IntStream.range(0, trainContext.numFeatures).toArray(); - isInputVector = statics.params.isInputVector; + isInputVector = trainContext.params.isInputVector; int maxNumNodes = Math.min( - ((int) Math.pow(2, statics.params.maxDepth - 1)), - statics.params.maxNumLeaves); + ((int) Math.pow(2, trainContext.params.maxDepth - 1)), + trainContext.params.maxNumLeaves); int maxFeatureBins = Arrays.stream(numFeatureBins).max().orElse(0); int totalNumFeatureBins = Arrays.stream(numFeatureBins).sum(); @@ -85,7 +83,7 @@ public HistBuilder(LocalState.Statics statics) { */ private static void calcNodeFeaturePairHists( List layer, - List nodeFeaturePairs, + int[] nodeFeaturePairs, FeatureMeta[] featureMetas, boolean isInputVector, int[] numFeatureBins, @@ -95,9 +93,9 @@ private static void calcNodeFeaturePairHists( double[] hists) { Arrays.fill(hists, 0.); int binOffset = 0; - for (IntIntPair nodeFeaturePair : nodeFeaturePairs) { - int nodeId = nodeFeaturePair.getOne(); - int featureId = nodeFeaturePair.getTwo(); + for (int k = 0; k < nodeFeaturePairs.length; k += 2) { + int nodeId = nodeFeaturePairs[k]; + int featureId = nodeFeaturePairs[k + 1]; FeatureMeta featureMeta = featureMetas[featureId]; int defaultValue = featureMeta.missingBin; @@ -131,40 +129,45 @@ private static void calcNodeFeaturePairHists( * subtask. 
*/ private static int[] calcRecvCounts( - int numSubtasks, List nodeFeaturePairs, int[] numFeatureBins) { + int numSubtasks, int[] nodeFeaturePairs, int[] numFeatureBins) { int[] recvcnts = new int[numSubtasks]; Distributor.EvenDistributor distributor = - new Distributor.EvenDistributor(numSubtasks, nodeFeaturePairs.size()); + new Distributor.EvenDistributor(numSubtasks, nodeFeaturePairs.length / 2); for (int k = 0; k < numSubtasks; k += 1) { int pairStart = (int) distributor.start(k); int pairCnt = (int) distributor.count(k); for (int i = pairStart; i < pairStart + pairCnt; i += 1) { - int featureId = nodeFeaturePairs.get(i).getTwo(); + int featureId = nodeFeaturePairs[2 * i + 1]; recvcnts[k] += numFeatureBins[featureId] * DataUtils.BIN_SIZE; } } return recvcnts; } + /** Generates (nodeId, featureId) pairs that are required to build histograms. */ + int[] getNodeFeaturePairs(int numLayerNodes) { + int[] nodeFeaturePairs = new int[numLayerNodes * numBaggingFeatures * 2]; + int p = 0; + for (int k = 0; k < numLayerNodes; k += 1) { + int[] sampledFeatures = + DataUtils.sample(featureIndicesPool, numBaggingFeatures, featureRandomizer); + for (int featureId : sampledFeatures) { + nodeFeaturePairs[p++] = k; + nodeFeaturePairs[p++] = featureId; + } + } + return nodeFeaturePairs; + } + /** Calculate local histograms for nodes in current layer of tree. */ - public Histogram build( + Histogram build( List layer, - List nodeFeaturePairs, + int[] nodeFeaturePairs, int[] indices, BinnedInstance[] instances, PredGradHess[] pgh) { LOG.info("subtaskId: {}, {} start", subtaskId, HistBuilder.class.getSimpleName()); - // Generates (nodeId, featureId) pairs that are required to build histograms. - nodeFeaturePairs.clear(); - for (int k = 0; k < layer.size(); k += 1) { - int[] sampledFeatures = - DataUtils.sample(featureIndicesPool, numBaggingFeatures, featureRandomizer); - for (int featureId : sampledFeatures) { - nodeFeaturePairs.add(PrimitiveTuples.pair(k, featureId)); - } - } - // Calculates histograms for (nodeId, featureId) pairs. 
calcNodeFeaturePairHists( layer, diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/InstanceUpdater.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/InstanceUpdater.java index cb5482014..1ddb01369 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/InstanceUpdater.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/InstanceUpdater.java @@ -20,9 +20,10 @@ import org.apache.flink.ml.common.gbt.defs.BinnedInstance; import org.apache.flink.ml.common.gbt.defs.LearningNode; -import org.apache.flink.ml.common.gbt.defs.LocalState; +import org.apache.flink.ml.common.gbt.defs.Node; import org.apache.flink.ml.common.gbt.defs.PredGradHess; import org.apache.flink.ml.common.gbt.defs.Split; +import org.apache.flink.ml.common.gbt.defs.TrainContext; import org.apache.flink.ml.common.gbt.loss.Loss; import org.slf4j.Logger; @@ -42,12 +43,12 @@ class InstanceUpdater { private boolean initialized; - public InstanceUpdater(LocalState.Statics statics) { - subtaskId = statics.subtaskId; - loss = statics.loss; - stepSize = statics.params.stepSize; - prior = statics.prior; - pgh = new PredGradHess[statics.numInstances]; + public InstanceUpdater(TrainContext trainContext) { + subtaskId = trainContext.subtaskId; + loss = trainContext.loss; + stepSize = trainContext.params.stepSize; + prior = trainContext.prior; + pgh = new PredGradHess[trainContext.numInstances]; initialized = false; } @@ -55,7 +56,8 @@ public void update( List leaves, int[] indices, BinnedInstance[] instances, - Consumer pghSetter) { + Consumer pghSetter, + List treeNodes) { LOG.info("subtaskId: {}, {} start", subtaskId, InstanceUpdater.class.getSimpleName()); if (!initialized) { for (int i = 0; i < instances.length; i += 1) { @@ -68,7 +70,7 @@ public void update( } for (LearningNode nodeInfo : leaves) { - Split split = nodeInfo.node.split; + Split split = treeNodes.get(nodeInfo.nodeIndex).split; double pred = split.prediction * stepSize; for (int i = nodeInfo.slice.start; i < nodeInfo.slice.end; ++i) { int instanceId = indices[i]; diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/NodeSplitter.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/NodeSplitter.java index 72e96352d..15f7012b9 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/NodeSplitter.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/NodeSplitter.java @@ -21,10 +21,10 @@ import org.apache.flink.ml.common.gbt.defs.BinnedInstance; import org.apache.flink.ml.common.gbt.defs.FeatureMeta; import org.apache.flink.ml.common.gbt.defs.LearningNode; -import org.apache.flink.ml.common.gbt.defs.LocalState; import org.apache.flink.ml.common.gbt.defs.Node; import org.apache.flink.ml.common.gbt.defs.Slice; import org.apache.flink.ml.common.gbt.defs.Split; +import org.apache.flink.ml.common.gbt.defs.TrainContext; import org.apache.flink.util.Preconditions; import org.slf4j.Logger; @@ -41,11 +41,11 @@ class NodeSplitter { private final int maxLeaves; private final int maxDepth; - public NodeSplitter(LocalState.Statics statics) { - subtaskId = statics.subtaskId; - featureMetas = statics.featureMetas; - maxLeaves = statics.params.maxNumLeaves; - maxDepth = statics.params.maxDepth; + public NodeSplitter(TrainContext trainContext) { + subtaskId = trainContext.subtaskId; + featureMetas = trainContext.featureMetas; + maxLeaves = trainContext.params.maxNumLeaves; + maxDepth = 
trainContext.params.maxDepth; } private int partitionInstances( @@ -69,29 +69,36 @@ private int partitionInstances( } private void splitNode( + Node treeNode, LearningNode nodeInfo, int[] indices, BinnedInstance[] instances, - List nextLayer) { - int mid = partitionInstances(nodeInfo.node.split, nodeInfo.slice, indices, instances); - int oobMid = partitionInstances(nodeInfo.node.split, nodeInfo.oob, indices, instances); - nodeInfo.node.left = new Node(); - nodeInfo.node.right = new Node(); + List nextLayer, + List treeNodes) { + int mid = partitionInstances(treeNode.split, nodeInfo.slice, indices, instances); + int oobMid = partitionInstances(treeNode.split, nodeInfo.oob, indices, instances); + + treeNode.left = treeNodes.size(); + treeNodes.add(new Node()); + treeNode.right = treeNodes.size(); + treeNodes.add(new Node()); + nextLayer.add( new LearningNode( - nodeInfo.node.left, + treeNode.left, new Slice(nodeInfo.slice.start, mid), new Slice(nodeInfo.oob.start, oobMid), nodeInfo.depth + 1)); nextLayer.add( new LearningNode( - nodeInfo.node.right, + treeNode.right, new Slice(mid, nodeInfo.slice.end), new Slice(oobMid, nodeInfo.oob.end), nodeInfo.depth + 1)); } - public void split( + public List split( + List treeNodes, List layer, List leaves, Split[] splits, @@ -108,15 +115,16 @@ public void split( LearningNode node = layer.get(i); Split split = splits[i]; numQueued -= 1; - node.node.split = split; + Node treeNode = treeNodes.get(node.nodeIndex); + treeNode.split = split; if (!split.isValid() - || node.node.isLeaf + || treeNode.isLeaf || (leaves.size() + numQueued + 2) > maxLeaves || node.depth + 1 > maxDepth) { - node.node.isLeaf = true; + treeNode.isLeaf = true; leaves.add(node); } else { - splitNode(node, indices, instances, nextLayer); + splitNode(treeNode, node, indices, instances, nextLayer, treeNodes); // Converts splits point from bin id to real feature value after splitting node. 
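                    // During training a ContinuousSplit stores its threshold as a bin id; mapping
                    // it back through the feature's bin edges lets the finished model compare raw
                    // feature values without re-binning. (The mapping itself is elided by this hunk.)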
if (split instanceof Split.ContinuousSplit) { Split.ContinuousSplit cs = (Split.ContinuousSplit) split; @@ -127,10 +135,7 @@ public void split( numQueued += 2; } } - - layer.clear(); - layer.addAll(nextLayer); - LOG.info("subtaskId: {}, {} end", subtaskId, NodeSplitter.class.getSimpleName()); + return nextLayer; } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/PostSplitsOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/PostSplitsOperator.java index 52a1ad1b3..2b526ba1d 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/PostSplitsOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/PostSplitsOperator.java @@ -20,81 +20,72 @@ import org.apache.flink.api.common.state.ListState; import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeutils.base.BooleanSerializer; import org.apache.flink.api.common.typeutils.base.GenericArraySerializer; +import org.apache.flink.api.common.typeutils.base.ListSerializer; import org.apache.flink.api.common.typeutils.base.array.IntPrimitiveArraySerializer; import org.apache.flink.iteration.IterationID; import org.apache.flink.iteration.IterationListener; import org.apache.flink.iteration.operator.OperatorStateUtils; import org.apache.flink.ml.common.gbt.datastorage.IterationSharedStorage; import org.apache.flink.ml.common.gbt.defs.BinnedInstance; -import org.apache.flink.ml.common.gbt.defs.LocalState; +import org.apache.flink.ml.common.gbt.defs.LearningNode; +import org.apache.flink.ml.common.gbt.defs.Node; import org.apache.flink.ml.common.gbt.defs.PredGradHess; import org.apache.flink.ml.common.gbt.defs.Splits; +import org.apache.flink.ml.common.gbt.defs.TrainContext; +import org.apache.flink.ml.common.gbt.typeinfo.LearningNodeSerializer; +import org.apache.flink.ml.common.gbt.typeinfo.NodeSerializer; import org.apache.flink.ml.common.gbt.typeinfo.PredGradHessSerializer; import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.StateSnapshotContext; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; -import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.util.Collector; -import org.apache.flink.util.OutputTag; +import java.util.ArrayList; import java.util.Collections; +import java.util.List; /** * Post-process after global splits obtained, including split instances to left or child nodes, and * update instances scores after a tree is complete. 
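 *
 * <p>With this change a finished tree is stored flat: a tree is a {@code List<Node>} whose
 * {@code left}/{@code right} fields are indices into that list (the root is always at index 0)
 * rather than object references. A minimal traversal sketch under that layout, where the
 * hypothetical {@code goLeft} predicate stands in for evaluating a {@code Split} against one
 * instance:
 *
 * <pre>{@code
 * static double predict(List<Node> treeNodes, Predicate<Split> goLeft) {
 *     Node node = treeNodes.get(0);
 *     while (!node.isLeaf) {
 *         node = treeNodes.get(goLeft.test(node.split) ? node.left : node.right);
 *     }
 *     return node.split.prediction;
 * }
 * }</pre>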
*/ -public class PostSplitsOperator extends AbstractStreamOperator - implements TwoInputStreamOperator, - IterationListener { +public class PostSplitsOperator extends AbstractStreamOperator + implements OneInputStreamOperator, IterationListener { - private static final String LOCAL_STATE_STATE_NAME = "local_state"; private static final String SPLITS_STATE_NAME = "splits"; private static final String NODE_SPLITTER_STATE_NAME = "node_splitter"; private static final String INSTANCE_UPDATER_STATE_NAME = "instance_updater"; + private static final String CURRENT_TREE_NODES_STATE_NAME = "current_tree_nodes"; private final IterationID iterationID; - private final String sharedInstancesKey; - private final String sharedPredGradHessKey; - private final String sharedShuffledIndicesKey; - private final String sharedSwappedIndicesKey; - private final OutputTag finalStateOutputTag; private IterationSharedStorage.Reader instancesReader; private IterationSharedStorage.Writer pghWriter; private IterationSharedStorage.Reader shuffledIndicesReader; private IterationSharedStorage.Writer swappedIndicesWriter; - private transient ListState localState; private transient ListState splits; private transient ListState nodeSplitter; private transient ListState instanceUpdater; + private IterationSharedStorage.Writer> leavesWriter; + private IterationSharedStorage.Writer> layerWriter; + private IterationSharedStorage.Reader rootLearningNodeReader; + private IterationSharedStorage.Writer>> allTreesWriter; + private IterationSharedStorage.Writer> currentTreeNodesWriter; + private IterationSharedStorage.Writer needInitTreeWriter; + private IterationSharedStorage.Reader trainContextReader; - public PostSplitsOperator( - IterationID iterationID, - String sharedInstancesKey, - String sharedPredGradHessKey, - String sharedShuffledIndicesKey, - String sharedSwappedIndicesKey, - OutputTag finalStateOutputTag) { + public PostSplitsOperator(IterationID iterationID) { this.iterationID = iterationID; - this.sharedInstancesKey = sharedInstancesKey; - this.sharedPredGradHessKey = sharedPredGradHessKey; - this.sharedShuffledIndicesKey = sharedShuffledIndicesKey; - this.sharedSwappedIndicesKey = sharedSwappedIndicesKey; - this.finalStateOutputTag = finalStateOutputTag; } @Override public void initializeState(StateInitializationContext context) throws Exception { super.initializeState(context); - localState = - context.getOperatorStateStore() - .getListState( - new ListStateDescriptor<>( - LOCAL_STATE_STATE_NAME, LocalState.class)); splits = context.getOperatorStateStore() .getListState(new ListStateDescriptor<>(SPLITS_STATE_NAME, Splits.class)); @@ -114,7 +105,7 @@ public void initializeState(StateInitializationContext context) throws Exception IterationSharedStorage.getWriter( iterationID, subtaskId, - sharedPredGradHessKey, + SharedKeys.PREDS_GRADS_HESSIANS, getOperatorID(), new GenericArraySerializer<>( PredGradHess.class, PredGradHessSerializer.INSTANCE), @@ -124,16 +115,71 @@ public void initializeState(StateInitializationContext context) throws Exception IterationSharedStorage.getWriter( iterationID, subtaskId, - sharedSwappedIndicesKey, + SharedKeys.SWAPPED_INDICES, getOperatorID(), IntPrimitiveArraySerializer.INSTANCE, new int[0]); swappedIndicesWriter.initializeState(context); + leavesWriter = + IterationSharedStorage.getWriter( + iterationID, + subtaskId, + SharedKeys.LEAVES, + getOperatorID(), + new ListSerializer<>(LearningNodeSerializer.INSTANCE), + new ArrayList<>()); + leavesWriter.initializeState(context); + + 
layerWriter = + IterationSharedStorage.getWriter( + iterationID, + subtaskId, + SharedKeys.LAYER, + getOperatorID(), + new ListSerializer<>(LearningNodeSerializer.INSTANCE), + new ArrayList<>()); + layerWriter.initializeState(context); + + allTreesWriter = + IterationSharedStorage.getWriter( + iterationID, + subtaskId, + SharedKeys.ALL_TREES, + getOperatorID(), + new ListSerializer<>(new ListSerializer<>(NodeSerializer.INSTANCE)), + new ArrayList<>()); + allTreesWriter.initializeState(context); + + needInitTreeWriter = + IterationSharedStorage.getWriter( + iterationID, + subtaskId, + SharedKeys.NEED_INIT_TREE, + getOperatorID(), + BooleanSerializer.INSTANCE, + true); + needInitTreeWriter.initializeState(context); + + currentTreeNodesWriter = + IterationSharedStorage.getWriter( + iterationID, + subtaskId, + CURRENT_TREE_NODES_STATE_NAME, + getOperatorID(), + new ListSerializer<>(NodeSerializer.INSTANCE), + new ArrayList<>()); + currentTreeNodesWriter.initializeState(context); - this.instancesReader = - IterationSharedStorage.getReader(iterationID, subtaskId, sharedInstancesKey); - this.shuffledIndicesReader = - IterationSharedStorage.getReader(iterationID, subtaskId, sharedShuffledIndicesKey); + instancesReader = + IterationSharedStorage.getReader(iterationID, subtaskId, SharedKeys.INSTANCES); + shuffledIndicesReader = + IterationSharedStorage.getReader( + iterationID, subtaskId, SharedKeys.SHUFFLED_INDICES); + rootLearningNodeReader = + IterationSharedStorage.getReader( + iterationID, subtaskId, SharedKeys.ROOT_LEARNING_NODE); + trainContextReader = + IterationSharedStorage.getReader(iterationID, subtaskId, SharedKeys.TRAIN_CONTEXT); } @Override @@ -141,19 +187,19 @@ public void snapshotState(StateSnapshotContext context) throws Exception { super.snapshotState(context); pghWriter.snapshotState(context); swappedIndicesWriter.snapshotState(context); + leavesWriter.snapshotState(context); + needInitTreeWriter.snapshotState(context); } @SuppressWarnings("OptionalGetWithoutIsPresent") @Override public void onEpochWatermarkIncremented( - int epochWatermark, Context context, Collector collector) throws Exception { - LocalState localStateValue = - OperatorStateUtils.getUniqueElement(localState, LOCAL_STATE_STATE_NAME).get(); + int epochWatermark, Context context, Collector collector) throws Exception { if (0 == epochWatermark) { nodeSplitter.update( - Collections.singletonList(new NodeSplitter(localStateValue.statics))); + Collections.singletonList(new NodeSplitter(trainContextReader.get()))); instanceUpdater.update( - Collections.singletonList(new InstanceUpdater(localStateValue.statics))); + Collections.singletonList(new InstanceUpdater(trainContextReader.get()))); } int[] indices = swappedIndicesWriter.get(); @@ -162,47 +208,62 @@ public void onEpochWatermarkIncremented( } BinnedInstance[] instances = instancesReader.get(); - OperatorStateUtils.getUniqueElement(nodeSplitter, NODE_SPLITTER_STATE_NAME) - .get() - .split( - localStateValue.dynamics.layer, - localStateValue.dynamics.leaves, - OperatorStateUtils.getUniqueElement(splits, SPLITS_STATE_NAME).get().splits, - indices, - instances); - - if (localStateValue.dynamics.layer.isEmpty()) { - localStateValue.dynamics.inWeakLearner = false; + List leaves = leavesWriter.get(); + List layer = layerWriter.get(); + List currentTreeNodes; + if (layer.size() == 0) { + layer = Collections.singletonList(rootLearningNodeReader.get()); + currentTreeNodes = new ArrayList<>(); + currentTreeNodes.add(new Node()); + } else { + currentTreeNodes = 
currentTreeNodesWriter.get(); + } + + List nextLayer = + OperatorStateUtils.getUniqueElement(nodeSplitter, NODE_SPLITTER_STATE_NAME) + .get() + .split( + currentTreeNodes, + layer, + leaves, + OperatorStateUtils.getUniqueElement(splits, SPLITS_STATE_NAME) + .get() + .splits, + indices, + instances); + leavesWriter.set(leaves); + layerWriter.set(nextLayer); + currentTreeNodesWriter.set(currentTreeNodes); + + if (nextLayer.isEmpty()) { + needInitTreeWriter.set(true); OperatorStateUtils.getUniqueElement(instanceUpdater, INSTANCE_UPDATER_STATE_NAME) .get() - .update(localStateValue.dynamics.leaves, indices, instances, pghWriter::set); + .update(leaves, indices, instances, pghWriter::set, currentTreeNodes); + leaves.clear(); + List> allTrees = allTreesWriter.get(); + allTrees.add(currentTreeNodes); + + leavesWriter.set(new ArrayList<>()); swappedIndicesWriter.set(new int[0]); + allTreesWriter.set(allTrees); } else { swappedIndicesWriter.set(indices); + needInitTreeWriter.set(false); } - collector.collect(localStateValue); } @Override - public void onIterationTerminated(Context context, Collector collector) - throws Exception { + public void onIterationTerminated(Context context, Collector collector) { pghWriter.set(new PredGradHess[0]); swappedIndicesWriter.set(new int[0]); - if (0 == getRuntimeContext().getIndexOfThisSubtask()) { - //noinspection OptionalGetWithoutIsPresent - context.output( - finalStateOutputTag, - OperatorStateUtils.getUniqueElement(localState, LOCAL_STATE_STATE_NAME).get()); - } - } - - @Override - public void processElement1(StreamRecord element) throws Exception { - localState.update(Collections.singletonList(element.getValue())); + leavesWriter.set(Collections.emptyList()); + layerWriter.set(Collections.emptyList()); + currentTreeNodesWriter.set(Collections.emptyList()); } @Override - public void processElement2(StreamRecord element) throws Exception { + public void processElement(StreamRecord element) throws Exception { splits.update(Collections.singletonList(element.getValue())); } @@ -210,6 +271,11 @@ public void processElement2(StreamRecord element) throws Exception { public void close() throws Exception { pghWriter.remove(); swappedIndicesWriter.remove(); + leavesWriter.remove(); + layerWriter.remove(); + allTreesWriter.remove(); + currentTreeNodesWriter.remove(); + needInitTreeWriter.remove(); super.close(); } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SharedKeys.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SharedKeys.java new file mode 100644 index 000000000..4e0278969 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SharedKeys.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.gbt.operators; + +import org.apache.flink.ml.common.gbt.datastorage.IterationSharedStorage; + +/** Stores keys for shared data stored in {@link IterationSharedStorage}. */ +class SharedKeys { + /** + * In the iteration, some data needs to be shared between subtasks of different operators within + * one machine. We use {@link IterationSharedStorage} with co-location mechanism to achieve such + * purpose. The data is stored in JVM static region, and is accessed through string keys from + * different operator subtasks. Note the first operator to put the data is the owner of the + * data, and only the owner can update or delete the data. + * + *
<p>To be specific, in the gradient boosting trees algorithm, there are three types of shared data:
+ *
+ * <ul>
+ *   <li>Instances (after binning) and their corresponding predictions, gradients, and hessians
+ *       are shared to avoid duplicated storage or communication.
+ *   <li>When initializing every new tree, instances need to be shuffled and split into bagging
+ *       instances and non-bagging ones. To reduce the cost, we shuffle instance indices rather
+ *       than the instances themselves. Therefore, the shuffled indices need to be shared to
+ *       access actual instances.
+ *   <li>After splitting nodes of each layer, instance indices need to be swapped to maintain
+ *       {@link LearningNode#slice} and {@link LearningNode#oob}. However, we cannot directly
+ *       update the shuffled indices above, as they already have an owner. So we use another
+ *       key to store instance indices after swapping.
+ * </ul>
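+ *
+ * <p>A minimal usage sketch (illustrative, reusing the accessor signatures that appear in this
+ * patch): the owner registers a writer with a serializer and an initial value, while other
+ * operators on the same machine look up readers by key:
+ *
+ * <pre>{@code
+ * IterationSharedStorage.Writer<int[]> writer =
+ *         IterationSharedStorage.getWriter(
+ *                 iterationID, subtaskId, SharedKeys.SWAPPED_INDICES,
+ *                 getOperatorID(), IntPrimitiveArraySerializer.INSTANCE, new int[0]);
+ * writer.set(new int[] {1, 0, 2});
+ *
+ * IterationSharedStorage.Reader<int[]> reader =
+ *         IterationSharedStorage.getReader(iterationID, subtaskId, SharedKeys.SWAPPED_INDICES);
+ * int[] indices = reader.get();
+ * }</pre>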
+ */ + static final String INSTANCES = "instances"; + + static final String PREDS_GRADS_HESSIANS = "preds_grads_hessians"; + static final String SHUFFLED_INDICES = "shuffled_indices"; + static final String SWAPPED_INDICES = "swapped_indices"; + + static final String NODE_FEATURE_PAIRS = "node_feature_pairs"; + static final String LEAVES = "leaves"; + static final String LAYER = "layer"; + + static final String ROOT_LEARNING_NODE = "root_learning_node"; + static final String ALL_TREES = "all_trees"; + static final String NEED_INIT_TREE = "need_init_tree"; + static final String HAS_INITED_TREE = "has_inited_tree"; + + static final String TRAIN_CONTEXT = "train_context"; +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SplitFinder.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SplitFinder.java index ea4cf302b..da5ef22b2 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SplitFinder.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SplitFinder.java @@ -22,16 +22,15 @@ import org.apache.flink.ml.common.gbt.defs.FeatureMeta; import org.apache.flink.ml.common.gbt.defs.Histogram; import org.apache.flink.ml.common.gbt.defs.LearningNode; -import org.apache.flink.ml.common.gbt.defs.LocalState; import org.apache.flink.ml.common.gbt.defs.Slice; import org.apache.flink.ml.common.gbt.defs.Split; import org.apache.flink.ml.common.gbt.defs.Splits; +import org.apache.flink.ml.common.gbt.defs.TrainContext; import org.apache.flink.ml.common.gbt.splitter.CategoricalFeatureSplitter; import org.apache.flink.ml.common.gbt.splitter.ContinuousFeatureSplitter; import org.apache.flink.ml.common.gbt.splitter.HistogramFeatureSplitter; import org.apache.flink.util.Preconditions; -import org.eclipse.collections.api.tuple.primitive.IntIntPair; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -47,44 +46,42 @@ class SplitFinder { private final int maxDepth; private final int maxNumLeaves; - public SplitFinder(LocalState.Statics statics) { - subtaskId = statics.subtaskId; - numSubtasks = statics.numSubtasks; + public SplitFinder(TrainContext trainContext) { + subtaskId = trainContext.subtaskId; + numSubtasks = trainContext.numSubtasks; - numFeatureBins = statics.numFeatureBins; - FeatureMeta[] featureMetas = statics.featureMetas; - splitters = new HistogramFeatureSplitter[statics.numFeatures]; - for (int i = 0; i < statics.numFeatures; ++i) { + numFeatureBins = trainContext.numFeatureBins; + FeatureMeta[] featureMetas = trainContext.featureMetas; + splitters = new HistogramFeatureSplitter[trainContext.numFeatures]; + for (int i = 0; i < trainContext.numFeatures; ++i) { splitters[i] = FeatureMeta.Type.CATEGORICAL == featureMetas[i].type - ? new CategoricalFeatureSplitter(i, featureMetas[i], statics.params) - : new ContinuousFeatureSplitter(i, featureMetas[i], statics.params); + ? 
new CategoricalFeatureSplitter( + i, featureMetas[i], trainContext.params) + : new ContinuousFeatureSplitter( + i, featureMetas[i], trainContext.params); } - maxDepth = statics.params.maxDepth; - maxNumLeaves = statics.params.maxNumLeaves; + maxDepth = trainContext.params.maxDepth; + maxNumLeaves = trainContext.params.maxNumLeaves; } public Splits calc( - List layer, - List nodeFeaturePairs, - List leaves, - Histogram histogram) { + List layer, int[] nodeFeaturePairs, int numLeaves, Histogram histogram) { LOG.info("subtaskId: {}, {} start", subtaskId, SplitFinder.class.getSimpleName()); Distributor distributor = - new Distributor.EvenDistributor(numSubtasks, nodeFeaturePairs.size()); - long start = distributor.start(subtaskId); - long cnt = distributor.count(subtaskId); + new Distributor.EvenDistributor(numSubtasks, nodeFeaturePairs.length / 2); + int start = (int) distributor.start(subtaskId); + int cnt = (int) distributor.count(subtaskId); Split[] nodesBestSplits = new Split[layer.size()]; int binOffset = 0; - for (long i = start; i < start + cnt; i += 1) { - IntIntPair nodeFeaturePair = nodeFeaturePairs.get((int) i); - int nodeId = nodeFeaturePair.getOne(); - int featureId = nodeFeaturePair.getTwo(); + for (int i = start; i < start + cnt; i += 1) { + int nodeId = nodeFeaturePairs[2 * i]; + int featureId = nodeFeaturePairs[2 * i + 1]; LearningNode node = layer.get(nodeId); - Preconditions.checkState(node.depth < maxDepth || leaves.size() + 2 <= maxNumLeaves); + Preconditions.checkState(node.depth < maxDepth || numLeaves + 2 <= maxNumLeaves); splitters[featureId].reset( histogram.hists, new Slice(binOffset, binOffset + numFeatureBins[featureId])); Split bestSplit = splitters[featureId].bestSplit(); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TerminationOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TerminationOperator.java new file mode 100644 index 000000000..7caf81078 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TerminationOperator.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.gbt.operators; + +import org.apache.flink.iteration.IterationID; +import org.apache.flink.iteration.IterationListener; +import org.apache.flink.ml.common.gbt.GBTModelData; +import org.apache.flink.ml.common.gbt.datastorage.IterationSharedStorage; +import org.apache.flink.ml.common.gbt.defs.Node; +import org.apache.flink.ml.common.gbt.defs.TrainContext; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.util.Collector; +import org.apache.flink.util.OutputTag; + +import java.util.List; + +/** Determines whether to terminate training. */ +public class TerminationOperator extends AbstractStreamOperator + implements OneInputStreamOperator, IterationListener { + + private final IterationID iterationID; + private final OutputTag modelDataOutputTag; + private IterationSharedStorage.Reader>> allTreesReader; + private IterationSharedStorage.Reader trainContextReader; + + public TerminationOperator( + IterationID iterationID, OutputTag modelDataOutputTag) { + this.iterationID = iterationID; + this.modelDataOutputTag = modelDataOutputTag; + } + + @Override + public void open() throws Exception { + int subtaskId = getRuntimeContext().getIndexOfThisSubtask(); + allTreesReader = + IterationSharedStorage.getReader(iterationID, subtaskId, SharedKeys.ALL_TREES); + trainContextReader = + IterationSharedStorage.getReader(iterationID, subtaskId, SharedKeys.TRAIN_CONTEXT); + } + + @Override + public void processElement(StreamRecord element) throws Exception {} + + @Override + public void onEpochWatermarkIncremented( + int epochWatermark, Context context, Collector collector) { + boolean terminated = allTreesReader.get().size() == trainContextReader.get().params.maxIter; + // TODO: add validation error rate + if (!terminated) { + output.collect(new StreamRecord<>(0)); + } + } + + @Override + public void onIterationTerminated(Context context, Collector collector) { + context.output( + modelDataOutputTag, + GBTModelData.from(trainContextReader.get(), allTreesReader.get())); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/LocalStateInitializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TrainContextInitializer.java similarity index 76% rename from flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/LocalStateInitializer.java rename to flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TrainContextInitializer.java index baf5548fc..3bf7f3c41 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/LocalStateInitializer.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TrainContextInitializer.java @@ -21,8 +21,8 @@ import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.ml.common.gbt.defs.BinnedInstance; import org.apache.flink.ml.common.gbt.defs.GbtParams; -import org.apache.flink.ml.common.gbt.defs.LocalState; import org.apache.flink.ml.common.gbt.defs.TaskType; +import org.apache.flink.ml.common.gbt.defs.TrainContext; import org.apache.flink.ml.common.gbt.loss.AbsoluteError; import org.apache.flink.ml.common.gbt.loss.LogLoss; import org.apache.flink.ml.common.gbt.loss.Loss; @@ -39,11 +39,11 @@ import static java.util.Arrays.stream; -class LocalStateInitializer { - private static final Logger LOG =
LoggerFactory.getLogger(LocalStateInitializer.class); +class TrainContextInitializer { + private static final Logger LOG = LoggerFactory.getLogger(TrainContextInitializer.class); private final GbtParams params; - public LocalStateInitializer(GbtParams params) { + public TrainContextInitializer(GbtParams params) { this.params = params; } @@ -52,43 +52,42 @@ public LocalStateInitializer(GbtParams params) { * *
Note that local state already has some properties set in advance, see GBTRunner#boost. */ - public LocalState init( - LocalState localState, int subtaskId, int numSubtasks, BinnedInstance[] instances) { - LOG.info("subtaskId: {}, {} start", subtaskId, LocalStateInitializer.class.getSimpleName()); + public TrainContext init( + TrainContext trainContext, int subtaskId, int numSubtasks, BinnedInstance[] instances) { + LOG.info( + "subtaskId: {}, {} start", + subtaskId, + TrainContextInitializer.class.getSimpleName()); - LocalState.Statics statics = localState.statics; - statics.subtaskId = subtaskId; - statics.numSubtasks = numSubtasks; + trainContext.subtaskId = subtaskId; + trainContext.numSubtasks = numSubtasks; int numInstances = instances.length; - int numFeatures = statics.featureMetas.length; + int numFeatures = trainContext.featureMetas.length; LOG.info( "subtaskId: {}, #samples: {}, #features: {}", subtaskId, numInstances, numFeatures); - statics.params = params; - statics.numInstances = numInstances; - statics.numFeatures = numFeatures; + trainContext.params = params; + trainContext.numInstances = numInstances; + trainContext.numFeatures = numFeatures; - statics.numBaggingInstances = getNumBaggingSamples(numInstances); - statics.numBaggingFeatures = getNumBaggingFeatures(numFeatures); + trainContext.numBaggingInstances = getNumBaggingSamples(numInstances); + trainContext.numBaggingFeatures = getNumBaggingFeatures(numFeatures); - statics.instanceRandomizer = new Random(subtaskId + params.seed); - statics.featureRandomizer = new Random(params.seed); + trainContext.instanceRandomizer = new Random(subtaskId + params.seed); + trainContext.featureRandomizer = new Random(params.seed); - statics.loss = getLoss(); - statics.prior = calcPrior(statics.labelSumCount); + trainContext.loss = getLoss(); + trainContext.prior = calcPrior(trainContext.labelSumCount); - statics.numFeatureBins = - stream(statics.featureMetas) - .mapToInt(d -> d.numBins(statics.params.useMissing)) + trainContext.numFeatureBins = + stream(trainContext.featureMetas) + .mapToInt(d -> d.numBins(trainContext.params.useMissing)) .toArray(); - LocalState.Dynamics dynamics = localState.dynamics; - dynamics.inWeakLearner = false; - - LOG.info("subtaskId: {}, {} end", subtaskId, LocalStateInitializer.class.getSimpleName()); - return new LocalState(statics, dynamics); + LOG.info("subtaskId: {}, {} end", subtaskId, TrainContextInitializer.class.getSimpleName()); + return trainContext; } private int getNumBaggingSamples(int numSamples) { diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TreeInitializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TreeInitializer.java index 42240fc54..e63e1da9e 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TreeInitializer.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TreeInitializer.java @@ -20,10 +20,8 @@ import org.apache.flink.ml.common.gbt.DataUtils; import org.apache.flink.ml.common.gbt.defs.LearningNode; -import org.apache.flink.ml.common.gbt.defs.LocalState; -import org.apache.flink.ml.common.gbt.defs.Node; import org.apache.flink.ml.common.gbt.defs.Slice; -import org.apache.flink.util.Preconditions; +import org.apache.flink.ml.common.gbt.defs.TrainContext; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -41,33 +39,28 @@ class TreeInitializer { private final int[] shuffledIndices; private final Random instanceRandomizer; - public 
TreeInitializer(LocalState.Statics statics) { - subtaskId = statics.subtaskId; - numInstances = statics.numInstances; - numBaggingInstances = statics.numBaggingInstances; - instanceRandomizer = statics.instanceRandomizer; + public TreeInitializer(TrainContext trainContext) { + subtaskId = trainContext.subtaskId; + numInstances = trainContext.numInstances; + numBaggingInstances = trainContext.numBaggingInstances; + instanceRandomizer = trainContext.instanceRandomizer; shuffledIndices = IntStream.range(0, numInstances).toArray(); } /** Calculate local histograms for nodes in current layer of tree. */ - public void init(LocalState.Dynamics dynamics, Consumer shuffledIndicesSetter) { + public void init(Consumer shuffledIndicesSetter) { LOG.info("subtaskId: {}, {} start", subtaskId, TreeInitializer.class.getSimpleName()); - Preconditions.checkState(!dynamics.inWeakLearner); - Preconditions.checkState(dynamics.layer.isEmpty()); - // Initializes the root node of a new tree when last tree is finalized. DataUtils.shuffle(shuffledIndices, instanceRandomizer); - Node root = new Node(); - dynamics.layer.add( - new LearningNode( - root, - new Slice(0, numBaggingInstances), - new Slice(numBaggingInstances, numInstances), - 1)); - dynamics.roots.add(root); - dynamics.leaves.clear(); shuffledIndicesSetter.accept(shuffledIndices); - LOG.info("subtaskId: {}, {} end", this.subtaskId, TreeInitializer.class.getSimpleName()); } + + public LearningNode getRootLearningNode() { + return new LearningNode( + 0, + new Slice(0, numBaggingInstances), + new Slice(numBaggingInstances, numInstances), + 1); + } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/BinnedInstanceSerializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/BinnedInstanceSerializer.java index af195af16..e02f316f6 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/BinnedInstanceSerializer.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/BinnedInstanceSerializer.java @@ -82,7 +82,7 @@ public void serialize(BinnedInstance record, DataOutputView target) throws IOExc public BinnedInstance deserialize(DataInputView source) throws IOException { BinnedInstance instance = new BinnedInstance(); int numFeatures = IntSerializer.INSTANCE.deserialize(source); - instance.features = new IntIntHashMap(); + instance.features = new IntIntHashMap(numFeatures); for (int i = 0; i < numFeatures; i += 1) { int k = IntSerializer.INSTANCE.deserialize(source); int v = IntSerializer.INSTANCE.deserialize(source); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/ContinuousSplitSerializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/ContinuousSplitSerializer.java index b57e2c94d..3e85a9354 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/ContinuousSplitSerializer.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/ContinuousSplitSerializer.java @@ -69,32 +69,34 @@ public Split.ContinuousSplit copy(Split.ContinuousSplit from, Split.ContinuousSp @Override public int getLength() { - return -1; + return 3 * IntSerializer.INSTANCE.getLength() + + 3 * DoubleSerializer.INSTANCE.getLength() + + 2 * BooleanSerializer.INSTANCE.getLength(); } @Override public void serialize(Split.ContinuousSplit record, DataOutputView target) throws IOException { - IntSerializer.INSTANCE.serialize(record.featureId, target); - DoubleSerializer.INSTANCE.serialize(record.gain, 
target); - IntSerializer.INSTANCE.serialize(record.missingBin, target); - BooleanSerializer.INSTANCE.serialize(record.missingGoLeft, target); - DoubleSerializer.INSTANCE.serialize(record.prediction, target); - DoubleSerializer.INSTANCE.serialize(record.threshold, target); - BooleanSerializer.INSTANCE.serialize(record.isUnseenMissing, target); - IntSerializer.INSTANCE.serialize(record.zeroBin, target); + target.writeInt(record.featureId); + target.writeDouble(record.gain); + target.writeInt(record.missingBin); + target.writeBoolean(record.missingGoLeft); + target.writeDouble(record.prediction); + target.writeDouble(record.threshold); + target.writeBoolean(record.isUnseenMissing); + target.writeInt(record.zeroBin); } @Override public Split.ContinuousSplit deserialize(DataInputView source) throws IOException { return new Split.ContinuousSplit( - IntSerializer.INSTANCE.deserialize(source), - DoubleSerializer.INSTANCE.deserialize(source), - IntSerializer.INSTANCE.deserialize(source), - BooleanSerializer.INSTANCE.deserialize(source), - DoubleSerializer.INSTANCE.deserialize(source), - DoubleSerializer.INSTANCE.deserialize(source), - BooleanSerializer.INSTANCE.deserialize(source), - IntSerializer.INSTANCE.deserialize(source)); + source.readInt(), + source.readDouble(), + source.readInt(), + source.readBoolean(), + source.readDouble(), + source.readDouble(), + source.readBoolean(), + source.readInt()); } @Override diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/GBTModelDataSerializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/GBTModelDataSerializer.java index 419130146..53c188a48 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/GBTModelDataSerializer.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/GBTModelDataSerializer.java @@ -39,6 +39,7 @@ import java.io.IOException; import java.util.ArrayList; import java.util.BitSet; +import java.util.List; /** Specialized serializer for {@link GBTModelData}. 
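Each tree is serialized as its node count followed by its nodes, after a leading tree count, so the nested {@code List<List<Node>>} structure round-trips without extra type information.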
*/ public final class GBTModelDataSerializer extends TypeSerializerSingleton { @@ -66,7 +67,10 @@ public GBTModelData copy(GBTModelData from) { record.prior = from.prior; record.stepSize = from.stepSize; - record.roots = new ArrayList<>(from.roots); + record.allTrees = new ArrayList<>(from.allTrees.size()); + for (int i = 0; i < from.allTrees.size(); i += 1) { + record.allTrees.add(new ArrayList<>(from.allTrees.get(i))); + } record.categoryToIdMaps = new IntObjectHashMap<>(from.categoryToIdMaps); record.featureIdToBinEdges = new IntObjectHashMap<>(from.featureIdToBinEdges); record.isCategorical = BitSet.valueOf(from.isCategorical.toByteArray()); @@ -91,9 +95,12 @@ public void serialize(GBTModelData record, DataOutputView target) throws IOExcep DoubleSerializer.INSTANCE.serialize(record.prior, target); DoubleSerializer.INSTANCE.serialize(record.stepSize, target); - IntSerializer.INSTANCE.serialize(record.roots.size(), target); - for (Node root : record.roots) { - NODE_SERIALIZER.serialize(root, target); + IntSerializer.INSTANCE.serialize(record.allTrees.size(), target); + for (List treeNodes : record.allTrees) { + IntSerializer.INSTANCE.serialize(treeNodes.size(), target); + for (Node treeNode : treeNodes) { + NodeSerializer.INSTANCE.serialize(treeNode, target); + } } IntSerializer.INSTANCE.serialize(record.categoryToIdMaps.size(), target); @@ -127,10 +134,15 @@ public GBTModelData deserialize(DataInputView source) throws IOException { record.prior = DoubleSerializer.INSTANCE.deserialize(source); record.stepSize = DoubleSerializer.INSTANCE.deserialize(source); - int numRoots = IntSerializer.INSTANCE.deserialize(source); - record.roots = new ArrayList<>(); - for (int i = 0; i < numRoots; i += 1) { - record.roots.add(NODE_SERIALIZER.deserialize(source)); + int numTrees = IntSerializer.INSTANCE.deserialize(source); + record.allTrees = new ArrayList<>(numTrees); + for (int k = 0; k < numTrees; k += 1) { + int numTreeNodes = IntSerializer.INSTANCE.deserialize(source); + List treeNodes = new ArrayList<>(numTreeNodes); + for (int i = 0; i < numTreeNodes; i += 1) { + treeNodes.add(NODE_SERIALIZER.deserialize(source)); + } + record.allTrees.add(treeNodes); } int numCategoricalFeatures = IntSerializer.INSTANCE.deserialize(source); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/HistogramSerializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/HistogramSerializer.java new file mode 100644 index 000000000..9c4648399 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/HistogramSerializer.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.gbt.typeinfo; + +import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.base.IntSerializer; +import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton; +import org.apache.flink.api.common.typeutils.base.array.DoublePrimitiveArraySerializer; +import org.apache.flink.api.common.typeutils.base.array.IntPrimitiveArraySerializer; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.ml.common.gbt.defs.Histogram; + +import java.io.IOException; + +/** Serializer for {@link Histogram}. */ +public final class HistogramSerializer extends TypeSerializerSingleton { + + public static final HistogramSerializer INSTANCE = new HistogramSerializer(); + private static final long serialVersionUID = 1L; + + private static final SplitSerializer SPLIT_SERIALIZER = SplitSerializer.INSTANCE; + + @Override + public boolean isImmutableType() { + return false; + } + + @Override + public Histogram createInstance() { + return new Histogram(); + } + + @Override + public Histogram copy(Histogram from) { + Histogram histogram = new Histogram(); + histogram.subtaskId = from.subtaskId; + histogram.hists = from.hists.clone(); + histogram.recvcnts = from.recvcnts.clone(); + return histogram; + } + + @Override + public Histogram copy(Histogram from, Histogram reuse) { + assert from.getClass() == reuse.getClass(); + reuse.subtaskId = from.subtaskId; + reuse.hists = from.hists.clone(); + reuse.recvcnts = from.recvcnts.clone(); + return reuse; + } + + @Override + public int getLength() { + return -1; + } + + @Override + public void serialize(Histogram record, DataOutputView target) throws IOException { + IntSerializer.INSTANCE.serialize(record.subtaskId, target); + DoublePrimitiveArraySerializer.INSTANCE.serialize(record.hists, target); + IntPrimitiveArraySerializer.INSTANCE.serialize(record.recvcnts, target); + } + + @Override + public Histogram deserialize(DataInputView source) throws IOException { + Histogram histogram = new Histogram(); + histogram.subtaskId = IntSerializer.INSTANCE.deserialize(source); + histogram.hists = DoublePrimitiveArraySerializer.INSTANCE.deserialize(source); + histogram.recvcnts = IntPrimitiveArraySerializer.INSTANCE.deserialize(source); + return histogram; + } + + @Override + public Histogram deserialize(Histogram reuse, DataInputView source) throws IOException { + reuse.subtaskId = IntSerializer.INSTANCE.deserialize(source); + reuse.hists = DoublePrimitiveArraySerializer.INSTANCE.deserialize(source); + reuse.recvcnts = IntPrimitiveArraySerializer.INSTANCE.deserialize(source); + return reuse; + } + + @Override + public void copy(DataInputView source, DataOutputView target) throws IOException { + serialize(deserialize(source), target); + } + + // ------------------------------------------------------------------------ + + @Override + public TypeSerializerSnapshot snapshotConfiguration() { + return new HistogramSerializerSnapshot(); + } + + /** Serializer configuration snapshot for compatibility and format evolution. 
*/ + @SuppressWarnings("WeakerAccess") + public static final class HistogramSerializerSnapshot + extends SimpleTypeSerializerSnapshot { + + public HistogramSerializerSnapshot() { + super(HistogramSerializer::new); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/IntIntPairSerializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/IntIntPairSerializer.java new file mode 100644 index 000000000..27975e07c --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/IntIntPairSerializer.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.typeinfo; + +import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.base.IntSerializer; +import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; + +import org.eclipse.collections.api.tuple.primitive.IntIntPair; +import org.eclipse.collections.impl.tuple.primitive.PrimitiveTuples; + +import java.io.IOException; + +/** Serializer for {@link IntIntPair}. 
*/ +public class IntIntPairSerializer extends TypeSerializerSingleton { + + public static final IntIntPairSerializer INSTANCE = new IntIntPairSerializer(); + + @Override + public boolean isImmutableType() { + return false; + } + + @Override + public IntIntPair createInstance() { + return PrimitiveTuples.pair(0, 0); + } + + @Override + public IntIntPair copy(IntIntPair from) { + return PrimitiveTuples.pair(from.getOne(), from.getTwo()); + } + + @Override + public IntIntPair copy(IntIntPair from, IntIntPair reuse) { + return copy(from); + } + + @Override + public int getLength() { + return 2 * IntSerializer.INSTANCE.getLength(); + } + + @Override + public void serialize(IntIntPair record, DataOutputView target) throws IOException { + IntSerializer.INSTANCE.serialize(record.getOne(), target); + IntSerializer.INSTANCE.serialize(record.getTwo(), target); + } + + @Override + public IntIntPair deserialize(DataInputView source) throws IOException { + return PrimitiveTuples.pair( + (int) IntSerializer.INSTANCE.deserialize(source), + (int) IntSerializer.INSTANCE.deserialize(source)); + } + + @Override + public IntIntPair deserialize(IntIntPair reuse, DataInputView source) throws IOException { + return deserialize(source); + } + + @Override + public void copy(DataInputView source, DataOutputView target) throws IOException { + serialize(deserialize(source), target); + } + + // ------------------------------------------------------------------------ + + @Override + public TypeSerializerSnapshot snapshotConfiguration() { + return new IntIntPairSerializer.IntIntPairSerializerSnapshot(); + } + + /** Serializer configuration snapshot for compatibility and format evolution. */ + @SuppressWarnings("WeakerAccess") + public static final class IntIntPairSerializerSnapshot + extends SimpleTypeSerializerSnapshot { + + public IntIntPairSerializerSnapshot() { + super(IntIntPairSerializer::new); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/LearningNodeSerializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/LearningNodeSerializer.java new file mode 100644 index 000000000..69136cba4 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/LearningNodeSerializer.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.gbt.typeinfo; + +import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.base.IntSerializer; +import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.ml.common.gbt.defs.LearningNode; + +import java.io.IOException; + +/** Serializer for {@link LearningNode}. */ +public final class LearningNodeSerializer extends TypeSerializerSingleton { + + public static final LearningNodeSerializer INSTANCE = new LearningNodeSerializer(); + private static final long serialVersionUID = 1L; + + private static final SliceSerializer SLICE_SERIALIZER = SliceSerializer.INSTANCE; + + @Override + public boolean isImmutableType() { + return false; + } + + @Override + public LearningNode createInstance() { + return new LearningNode(); + } + + @Override + public LearningNode copy(LearningNode from) { + LearningNode learningNode = new LearningNode(); + learningNode.nodeIndex = from.nodeIndex; + learningNode.slice = SLICE_SERIALIZER.copy(from.slice); + learningNode.oob = SLICE_SERIALIZER.copy(from.oob); + learningNode.depth = from.depth; + return learningNode; + } + + @Override + public LearningNode copy(LearningNode from, LearningNode reuse) { + assert from.getClass() == reuse.getClass(); + reuse.nodeIndex = from.nodeIndex; + SLICE_SERIALIZER.copy(from.slice, reuse.slice); + SLICE_SERIALIZER.copy(from.oob, reuse.oob); + reuse.depth = from.depth; + return reuse; + } + + @Override + public int getLength() { + return 2 * SLICE_SERIALIZER.getLength() + 2 * IntSerializer.INSTANCE.getLength(); + } + + @Override + public void serialize(LearningNode record, DataOutputView target) throws IOException { + IntSerializer.INSTANCE.serialize(record.nodeIndex, target); + SLICE_SERIALIZER.serialize(record.slice, target); + SLICE_SERIALIZER.serialize(record.oob, target); + IntSerializer.INSTANCE.serialize(record.depth, target); + } + + @Override + public LearningNode deserialize(DataInputView source) throws IOException { + LearningNode learningNode = new LearningNode(); + learningNode.nodeIndex = IntSerializer.INSTANCE.deserialize(source); + learningNode.slice = SLICE_SERIALIZER.deserialize(source); + learningNode.oob = SLICE_SERIALIZER.deserialize(source); + learningNode.depth = IntSerializer.INSTANCE.deserialize(source); + return learningNode; + } + + @Override + public LearningNode deserialize(LearningNode reuse, DataInputView source) throws IOException { + reuse.nodeIndex = IntSerializer.INSTANCE.deserialize(source); + reuse.slice = SLICE_SERIALIZER.deserialize(source); + reuse.oob = SLICE_SERIALIZER.deserialize(source); + reuse.depth = IntSerializer.INSTANCE.deserialize(source); + return reuse; + } + + @Override + public void copy(DataInputView source, DataOutputView target) throws IOException { + serialize(deserialize(source), target); + } + + // ------------------------------------------------------------------------ + + @Override + public TypeSerializerSnapshot snapshotConfiguration() { + return new LearningNodeSerializerSnapshot(); + } + + /** Serializer configuration snapshot for compatibility and format evolution.
*/ + @SuppressWarnings("WeakerAccess") + public static final class LearningNodeSerializerSnapshot + extends SimpleTypeSerializerSnapshot { + + public LearningNodeSerializerSnapshot() { + super(LearningNodeSerializer::new); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/NodeSerializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/NodeSerializer.java index 6fbe73e87..c6087c6d5 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/NodeSerializer.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/NodeSerializer.java @@ -21,6 +21,7 @@ import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot; import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; import org.apache.flink.api.common.typeutils.base.BooleanSerializer; +import org.apache.flink.api.common.typeutils.base.IntSerializer; import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton; import org.apache.flink.core.memory.DataInputView; import org.apache.flink.core.memory.DataOutputView; @@ -51,10 +52,8 @@ public Node copy(Node from) { Node node = new Node(); node.split = SPLIT_SERIALIZER.copy(from.split); node.isLeaf = from.isLeaf; - if (!node.isLeaf) { - node.left = copy(from.left); - node.right = copy(from.right); - } + node.left = from.left; + node.right = from.right; return node; } @@ -63,10 +62,8 @@ public Node copy(Node from, Node reuse) { assert from.getClass() == reuse.getClass(); SPLIT_SERIALIZER.copy(from.split, reuse.split); reuse.isLeaf = from.isLeaf; - if (!reuse.isLeaf) { - copy(from.left, reuse.left); - copy(from.right, reuse.right); - } + reuse.left = from.left; + reuse.right = from.right; return reuse; } @@ -79,10 +76,8 @@ public int getLength() { public void serialize(Node record, DataOutputView target) throws IOException { SPLIT_SERIALIZER.serialize(record.split, target); BooleanSerializer.INSTANCE.serialize(record.isLeaf, target); - if (!record.isLeaf) { - serialize(record.left, target); - serialize(record.right, target); - } + IntSerializer.INSTANCE.serialize(record.left, target); + IntSerializer.INSTANCE.serialize(record.right, target); } @Override @@ -90,10 +85,8 @@ public Node deserialize(DataInputView source) throws IOException { Node node = new Node(); node.split = SPLIT_SERIALIZER.deserialize(source); node.isLeaf = BooleanSerializer.INSTANCE.deserialize(source); - if (!node.isLeaf) { - node.left = deserialize(source); - node.right = deserialize(source); - } + node.left = IntSerializer.INSTANCE.deserialize(source); + node.right = IntSerializer.INSTANCE.deserialize(source); return node; } @@ -101,10 +94,8 @@ public Node deserialize(DataInputView source) throws IOException { public Node deserialize(Node reuse, DataInputView source) throws IOException { reuse.split = SPLIT_SERIALIZER.deserialize(source); reuse.isLeaf = BooleanSerializer.INSTANCE.deserialize(source); - if (!reuse.isLeaf) { - reuse.left = deserialize(source); - reuse.right = deserialize(source); - } + reuse.left = IntSerializer.INSTANCE.deserialize(source); + reuse.right = IntSerializer.INSTANCE.deserialize(source); return reuse; } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/NodeTypeInfo.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/NodeTypeInfo.java new file mode 100644 index 000000000..48e8342c5 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/NodeTypeInfo.java @@ -0,0 +1,88 @@ +/* + * 
Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.flink.ml.common.gbt.typeinfo; + +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.ml.common.gbt.defs.Node; + +/** A {@link TypeInformation} for the {@link Node} type. */ +public class NodeTypeInfo extends TypeInformation { + + public static final NodeTypeInfo INSTANCE = new NodeTypeInfo(); + + private NodeTypeInfo() {} + + @Override + public boolean isBasicType() { + return false; + } + + @Override + public boolean isTupleType() { + return false; + } + + @Override + public int getArity() { + return 4; + } + + @Override + public int getTotalFields() { + return 4; + } + + @Override + public Class getTypeClass() { + return Node.class; + } + + @Override + public boolean isKeyType() { + return false; + } + + @Override + public TypeSerializer createSerializer(ExecutionConfig executionConfig) { + return new NodeSerializer(); + } + + @Override + public String toString() { + return "Node"; + } + + @Override + public boolean equals(Object o) { + return o instanceof NodeTypeInfo; + } + + @Override + public int hashCode() { + return getClass().hashCode(); + } + + @Override + public boolean canEqual(Object o) { + return o instanceof NodeTypeInfo; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/NodeTypeInfoFactory.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/NodeTypeInfoFactory.java new file mode 100644 index 000000000..de6267ed3 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/NodeTypeInfoFactory.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.gbt.typeinfo; + +import org.apache.flink.api.common.typeinfo.TypeInfoFactory; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.TypeExtractor; +import org.apache.flink.ml.common.gbt.defs.Node; + +import java.lang.reflect.Type; +import java.util.Map; + +/** + * Used by {@link TypeExtractor} to create a {@link TypeInformation} for implementations of {@link + * Node}. + */ +public class NodeTypeInfoFactory extends TypeInfoFactory { + + @Override + public TypeInformation createTypeInfo( + Type t, Map> genericParameters) { + return NodeTypeInfo.INSTANCE; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/SliceSerializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/SliceSerializer.java new file mode 100644 index 000000000..4d5e9583a --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/SliceSerializer.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.typeinfo; + +import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.base.IntSerializer; +import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.ml.common.gbt.defs.Slice; + +import java.io.IOException; + +/** Serializer for {@link Slice}. 
*/ +public final class SliceSerializer extends TypeSerializerSingleton { + + public static final SliceSerializer INSTANCE = new SliceSerializer(); + private static final long serialVersionUID = 1L; + + private static final SplitSerializer SPLIT_SERIALIZER = SplitSerializer.INSTANCE; + + @Override + public boolean isImmutableType() { + return false; + } + + @Override + public Slice createInstance() { + return new Slice(); + } + + @Override + public Slice copy(Slice from) { + Slice slice = new Slice(); + slice.start = from.start; + slice.end = from.end; + return slice; + } + + @Override + public Slice copy(Slice from, Slice reuse) { + reuse.start = from.start; + reuse.end = from.end; + return reuse; + } + + @Override + public int getLength() { + return 2 * IntSerializer.INSTANCE.getLength(); + } + + @Override + public void serialize(Slice record, DataOutputView target) throws IOException { + IntSerializer.INSTANCE.serialize(record.start, target); + IntSerializer.INSTANCE.serialize(record.end, target); + } + + @Override + public Slice deserialize(DataInputView source) throws IOException { + Slice slice = new Slice(); + slice.start = IntSerializer.INSTANCE.deserialize(source); + slice.end = IntSerializer.INSTANCE.deserialize(source); + return slice; + } + + @Override + public Slice deserialize(Slice reuse, DataInputView source) throws IOException { + reuse.start = IntSerializer.INSTANCE.deserialize(source); + reuse.end = IntSerializer.INSTANCE.deserialize(source); + return reuse; + } + + @Override + public void copy(DataInputView source, DataOutputView target) throws IOException { + serialize(deserialize(source), target); + } + + // ------------------------------------------------------------------------ + + @Override + public TypeSerializerSnapshot snapshotConfiguration() { + return new SliceSerializerSnapshot(); + } + + /** Serializer configuration snapshot for compatibility and format evolution. 
*/ + @SuppressWarnings("WeakerAccess") + public static final class SliceSerializerSnapshot extends SimpleTypeSerializerSnapshot { + + public SliceSerializerSnapshot() { + super(SliceSerializer::new); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/SplitSerializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/SplitSerializer.java index c8d44df7f..5c9efadfd 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/SplitSerializer.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/SplitSerializer.java @@ -78,11 +78,13 @@ public int getLength() { @Override public void serialize(Split record, DataOutputView target) throws IOException { - if (record instanceof Split.CategoricalSplit) { + if (null == record) { target.writeByte(0); + } else if (record instanceof Split.CategoricalSplit) { + target.writeByte(1); CATEGORICAL_SPLIT_SERIALIZER.serialize((Split.CategoricalSplit) record, target); } else { - target.writeByte(1); + target.writeByte(2); CONTINUOUS_SPLIT_SERIALIZER.serialize((Split.ContinuousSplit) record, target); } } @@ -91,6 +93,8 @@ public void serialize(Split record, DataOutputView target) throws IOException { public Split deserialize(DataInputView source) throws IOException { byte type = source.readByte(); if (type == 0) { + return null; + } else if (type == 1) { return CATEGORICAL_SPLIT_SERIALIZER.deserialize(source); } else { return CONTINUOUS_SPLIT_SERIALIZER.deserialize(source); @@ -100,9 +104,12 @@ public Split deserialize(DataInputView source) throws IOException { @Override public Split deserialize(Split reuse, DataInputView source) throws IOException { byte type = source.readByte(); - assert type == 0 && reuse instanceof Split.CategoricalSplit - || type == 1 && reuse instanceof Split.ContinuousSplit; if (type == 0) { + return null; + } + assert type == 1 && reuse instanceof Split.CategoricalSplit + || type == 2 && reuse instanceof Split.ContinuousSplit; + if (type == 1) { return CATEGORICAL_SPLIT_SERIALIZER.deserialize((Split.CategoricalSplit) reuse, source); } else { return CONTINUOUS_SPLIT_SERIALIZER.deserialize((Split.ContinuousSplit) reuse, source); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/SplitTypeInfo.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/SplitTypeInfo.java new file mode 100644 index 000000000..ca26987ee --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/SplitTypeInfo.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.flink.ml.common.gbt.typeinfo; + +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.ml.common.gbt.defs.Split; + +/** A {@link TypeInformation} for the {@link Split} type. */ +public class SplitTypeInfo extends TypeInformation { + + public static final SplitTypeInfo INSTANCE = new SplitTypeInfo(); + + private SplitTypeInfo() {} + + @Override + public boolean isBasicType() { + return false; + } + + @Override + public boolean isTupleType() { + return false; + } + + @Override + public int getArity() { + return 8; + } + + @Override + public int getTotalFields() { + return 8; + } + + @Override + public Class getTypeClass() { + return Split.class; + } + + @Override + public boolean isKeyType() { + return false; + } + + @Override + public TypeSerializer createSerializer(ExecutionConfig executionConfig) { + return new SplitSerializer(); + } + + @Override + public String toString() { + return "Split"; + } + + @Override + public boolean equals(Object o) { + return o instanceof SplitTypeInfo; + } + + @Override + public int hashCode() { + return getClass().hashCode(); + } + + @Override + public boolean canEqual(Object o) { + return o instanceof SplitTypeInfo; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/SplitTypeInfoFactory.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/SplitTypeInfoFactory.java new file mode 100644 index 000000000..68c47f4c3 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/SplitTypeInfoFactory.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.typeinfo; + +import org.apache.flink.api.common.typeinfo.TypeInfoFactory; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.TypeExtractor; +import org.apache.flink.ml.common.gbt.defs.Split; + +import java.lang.reflect.Type; +import java.util.HashMap; +import java.util.Map; + +/** + * Used by {@link TypeExtractor} to create a {@link TypeInformation} for implementations of {@link + * Split}. 
+ */ +public class SplitTypeInfoFactory extends TypeInfoFactory { + + private static final Map> fields; + + static { + fields = new HashMap<>(); + } + + @Override + public TypeInformation createTypeInfo( + Type t, Map> genericParameters) { + return SplitTypeInfo.INSTANCE; + } +} diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/GBTClassifierTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/GBTClassifierTest.java index c7de6d290..4462daf36 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/GBTClassifierTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/GBTClassifierTest.java @@ -468,7 +468,7 @@ public void testGetModelData() throws Exception { Assert.assertFalse(modelData.isInputVector); Assert.assertEquals(0., modelData.prior, 1e-12); Assert.assertEquals(gbtc.getStepSize(), modelData.stepSize, 1e-12); - Assert.assertEquals(gbtc.getMaxIter(), modelData.roots.size()); + Assert.assertEquals(gbtc.getMaxIter(), modelData.allTrees.size()); Assert.assertEquals(gbtc.getCategoricalCols().length, modelData.categoryToIdMaps.size()); Assert.assertEquals( gbtc.getInputCols().length - gbtc.getCategoricalCols().length, diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/GBTRunnerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/GBTRunnerTest.java index 2811d83cd..f609c2de3 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/GBTRunnerTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/GBTRunnerTest.java @@ -110,7 +110,7 @@ private GbtParams getCommonGbtParams() { private void verifyModelData(GBTModelData modelData, GbtParams p) { Assert.assertEquals(p.taskType, TaskType.valueOf(modelData.type)); Assert.assertEquals(p.stepSize, modelData.stepSize, 1e-12); - Assert.assertEquals(p.maxIter, modelData.roots.size()); + Assert.assertEquals(p.maxIter, modelData.allTrees.size()); } @Test diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/GBTRegressorTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/GBTRegressorTest.java index af0810e6e..f6dc213ce 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/GBTRegressorTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/GBTRegressorTest.java @@ -346,7 +346,7 @@ public void testGetModelData() throws Exception { Assert.assertFalse(modelData.isInputVector); Assert.assertEquals(40.5, modelData.prior, .5); Assert.assertEquals(gbtr.getStepSize(), modelData.stepSize, 1e-12); - Assert.assertEquals(gbtr.getMaxIter(), modelData.roots.size()); + Assert.assertEquals(gbtr.getMaxIter(), modelData.allTrees.size()); Assert.assertEquals(gbtr.getCategoricalCols().length, modelData.categoryToIdMaps.size()); Assert.assertEquals( gbtr.getInputCols().length - gbtr.getCategoricalCols().length, From 4a76eda484d22938a169dc5c957fa9e13fd72fc5 Mon Sep 17 00:00:00 2001 From: Fan Hong Date: Thu, 16 Feb 2023 10:53:18 +0800 Subject: [PATCH 10/47] Change features storage in BinnedInstance --- .../ml/common/gbt/defs/BinnedInstance.java | 23 ++- .../flink/ml/common/gbt/defs/Split.java | 13 +- .../CacheDataCalcLocalHistsOperator.java | 33 ++-- .../ml/common/gbt/operators/HistBuilder.java | 157 +++++++++++++----- .../typeinfo/BinnedInstanceSerializer.java | 28 ++-- 5 files changed, 165 insertions(+), 89 deletions(-) diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/BinnedInstance.java 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/BinnedInstance.java index b7509fe36..1e03e736a 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/BinnedInstance.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/BinnedInstance.java @@ -22,7 +22,9 @@ import org.apache.flink.ml.feature.stringindexer.StringIndexer; import org.apache.flink.ml.linalg.SparseVector; -import org.eclipse.collections.impl.map.mutable.primitive.IntIntHashMap; +import javax.annotation.Nullable; + +import java.util.Arrays; /** * Represents an instance including binned values of all features, weight, and label. @@ -36,21 +38,28 @@ */ public class BinnedInstance { - public IntIntHashMap features; + @Nullable public int[] featureIds; + public int[] featureValues; public double weight; public double label; public BinnedInstance() {} - public BinnedInstance(IntIntHashMap features, double weight, double label) { - this.weight = weight; - this.label = label; - this.features = features; + /** + * Get the index of `featureId` in `featureValues`. + * + * @param featureId The feature ID. + * @return The index in `featureValues`. If the index is negative, the corresponding feature is + * not stored in `featureValues`. + */ + public int getFeatureIndex(int featureId) { + return null == featureIds ? featureId : Arrays.binarySearch(featureIds, featureId); } @Override public String toString() { return String.format( - "BinnedInstance{features=%s, weight=%s, label=%s}", features, weight, label); + "BinnedInstance{featureIds=%s, featureValues=%s, weight=%s, label=%s}", + Arrays.toString(featureIds), Arrays.toString(featureValues), weight, label); } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Split.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Split.java index 226577262..a88f5e668 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Split.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Split.java @@ -22,7 +22,6 @@ import org.apache.flink.ml.common.gbt.typeinfo.SplitTypeInfoFactory; import org.eclipse.collections.impl.map.mutable.primitive.IntDoubleHashMap; -import org.eclipse.collections.impl.map.mutable.primitive.IntIntHashMap; import java.util.BitSet; @@ -112,11 +111,11 @@ public static ContinuousSplit invalid(double prediction) { @Override public boolean shouldGoLeft(BinnedInstance binnedInstance) { - IntIntHashMap features = binnedInstance.features; - if (!features.containsKey(featureId) && isUnseenMissing) { + int index = binnedInstance.getFeatureIndex(featureId); + if (index < 0 && isUnseenMissing) { return missingGoLeft; } - int binId = features.getIfAbsent(featureId, zeroBin); + int binId = index >= 0 ? binnedInstance.featureValues[index] : zeroBin; return binId == missingBin ? missingGoLeft : binId <= threshold; } @@ -152,11 +151,11 @@ public static CategoricalSplit invalid(double prediction) { @Override public boolean shouldGoLeft(BinnedInstance binnedInstance) { - IntIntHashMap features = binnedInstance.features; - if (!features.containsKey(featureId)) { + int index = binnedInstance.getFeatureIndex(featureId); + if (index < 0) { return missingGoLeft; } - int binId = features.get(featureId); + int binId = binnedInstance.featureValues[index]; return binId == missingBin ? 
missingGoLeft : categoriesGoLeft.get(binId); } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java index 2f5fd4e8e..741a8ccfc 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java @@ -50,10 +50,10 @@ import org.apache.flink.util.Preconditions; import org.apache.commons.collections.IteratorUtils; -import org.eclipse.collections.impl.map.mutable.primitive.IntIntHashMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -204,23 +204,22 @@ public void snapshotState(StateSnapshotContext context) throws Exception { @Override public void processElement1(StreamRecord streamRecord) throws Exception { Row row = streamRecord.getValue(); - IntIntHashMap features; + BinnedInstance instance = new BinnedInstance(); + instance.weight = 1.; + instance.label = row.getFieldAs(gbtParams.labelCol); + if (gbtParams.isInputVector) { Vector vec = row.getFieldAs(gbtParams.vectorCol); SparseVector sv = vec.toSparse(); - features = new IntIntHashMap(sv.indices.length); - for (int i = 0; i < sv.indices.length; i += 1) { - features.put(sv.indices[i], (int) sv.values[i]); - } + instance.featureIds = sv.indices.length == sv.size() ? null : sv.indices; + instance.featureValues = Arrays.stream(sv.values).mapToInt(d -> (int) d).toArray(); } else { - features = new IntIntHashMap(gbtParams.featureCols.length); - for (int i = 0; i < gbtParams.featureCols.length; i += 1) { - // Values from StringIndexModel#transform are double. 
- features.put(i, ((Number) row.getFieldAs(gbtParams.featureCols[i])).intValue()); - } + instance.featureValues = + Arrays.stream(gbtParams.featureCols) + .mapToInt(col -> ((Number) row.getFieldAs(col)).intValue()) + .toArray(); } - double label = row.getFieldAs(gbtParams.labelCol); - instancesCollecting.add(new BinnedInstance(features, 1., label)); + instancesCollecting.add(instance); } @Override @@ -301,16 +300,10 @@ public void onEpochWatermarkIncremented( layer = Collections.singletonList(rootLearningNodeWriter.get()); } - int[] nodeFeaturePairs = - OperatorStateUtils.getUniqueElement(histBuilder, HIST_BUILDER_STATE_NAME) - .get() - .getNodeFeaturePairs(layer.size()); - nodeFeaturePairsWriter.set(nodeFeaturePairs); - Histogram localHists = OperatorStateUtils.getUniqueElement(histBuilder, HIST_BUILDER_STATE_NAME) .get() - .build(layer, nodeFeaturePairs, indices, instances, pgh); + .build(layer, indices, instances, pgh, nodeFeaturePairsWriter::set); out.collect(localHists); } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistBuilder.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistBuilder.java index 89c18ef5e..757f66369 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistBuilder.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistBuilder.java @@ -31,10 +31,14 @@ import org.slf4j.LoggerFactory; import java.util.Arrays; +import java.util.BitSet; import java.util.List; import java.util.Random; +import java.util.function.Consumer; import java.util.stream.IntStream; +import static org.apache.flink.ml.common.gbt.DataUtils.BIN_SIZE; + class HistBuilder { private static final Logger LOG = LoggerFactory.getLogger(HistBuilder.class); @@ -74,7 +78,7 @@ public HistBuilder(TrainContext trainContext) { int totalNumFeatureBins = Arrays.stream(numFeatureBins).sum(); int maxNumBins = maxNumNodes * Math.min(maxFeatureBins * numBaggingFeatures, totalNumFeatureBins); - hists = new double[maxNumBins * DataUtils.BIN_SIZE]; + hists = new double[maxNumBins * BIN_SIZE]; } /** @@ -83,42 +87,111 @@ public HistBuilder(TrainContext trainContext) { */ private static void calcNodeFeaturePairHists( List layer, - int[] nodeFeaturePairs, + int[][] nodeToFeatures, FeatureMeta[] featureMetas, - boolean isInputVector, int[] numFeatureBins, + boolean isInputVector, int[] indices, BinnedInstance[] instances, PredGradHess[] pgh, double[] hists) { - Arrays.fill(hists, 0.); + + int numNodes = layer.size(); + int numFeatures = featureMetas.length; + + int[][] nodeToBinOffsets = new int[numNodes][]; int binOffset = 0; - for (int k = 0; k < nodeFeaturePairs.length; k += 2) { - int nodeId = nodeFeaturePairs[k]; - int featureId = nodeFeaturePairs[k + 1]; - FeatureMeta featureMeta = featureMetas[featureId]; - - int defaultValue = featureMeta.missingBin; - // When isInputVector is true, values of unseen features are treated as 0s. 
- if (isInputVector && featureMeta instanceof FeatureMeta.ContinuousFeatureMeta) { - defaultValue = ((FeatureMeta.ContinuousFeatureMeta) featureMeta).zeroBin; + for (int k = 0; k < numNodes; k += 1) { + int[] features = nodeToFeatures[k]; + nodeToBinOffsets[k] = new int[features.length]; + for (int i = 0; i < features.length; i += 1) { + nodeToBinOffsets[k][i] = binOffset; + binOffset += numFeatureBins[features[i]]; } + } + + int[] featureDefaultVal = new int[numFeatures]; + for (int i = 0; i < numFeatures; i += 1) { + FeatureMeta d = featureMetas[i]; + featureDefaultVal[i] = + isInputVector && d instanceof FeatureMeta.ContinuousFeatureMeta + ? ((FeatureMeta.ContinuousFeatureMeta) d).zeroBin + : d.missingBin; + } + + int[] featureOffset = new int[numFeatures]; + for (int k = 0; k < numNodes; k += 1) { + int[] features = nodeToFeatures[k]; + int[] binOffsets = nodeToBinOffsets[k]; + LearningNode node = layer.get(k); - LearningNode node = layer.get(nodeId); + BitSet featureValid = new BitSet(numFeatures); + for (int i = 0; i < features.length; i += 1) { + featureValid.set(features[i]); + featureOffset[features[i]] = binOffsets[i]; + } + + double[] totalHists = new double[4]; for (int i = node.slice.start; i < node.slice.end; i += 1) { int instanceId = indices[i]; BinnedInstance binnedInstance = instances[instanceId]; + double weight = binnedInstance.weight; double gradient = pgh[instanceId].gradient; double hessian = pgh[instanceId].hessian; - int val = binnedInstance.features.getIfAbsent(featureId, defaultValue); - int startIndex = (binOffset + val) * DataUtils.BIN_SIZE; - hists[startIndex] += gradient; - hists[startIndex + 1] += hessian; - hists[startIndex + 2] += binnedInstance.weight; - hists[startIndex + 3] += 1.; + totalHists[0] += gradient; + totalHists[1] += hessian; + totalHists[2] += weight; + totalHists[3] += 1.; + + if (null == binnedInstance.featureIds) { + for (int j = 0; j < binnedInstance.featureValues.length; j += 1) { + if (!featureValid.get(j)) { + continue; + } + int val = binnedInstance.featureValues[j]; + int offset = featureOffset[j]; + int index = (offset + val) * BIN_SIZE; + hists[index] += gradient; + hists[index + 1] += hessian; + hists[index + 2] += weight; + hists[index + 3] += 1.; + } + } else { + for (int j = 0; j < binnedInstance.featureIds.length; j += 1) { + int featureId = binnedInstance.featureIds[j]; + if (!featureValid.get(featureId)) { + continue; + } + int val = binnedInstance.featureValues[j]; + int offset = featureOffset[featureId]; + int index = (offset + val) * BIN_SIZE; + hists[index] += gradient; + hists[index + 1] += hessian; + hists[index + 2] += weight; + hists[index + 3] += 1.; + } + } + } + + for (int featureId : features) { + int defaultVal = featureDefaultVal[featureId]; + int defaultValIndex = (featureOffset[featureId] + defaultVal) * BIN_SIZE; + hists[defaultValIndex] = totalHists[0]; + hists[defaultValIndex + 1] = totalHists[1]; + hists[defaultValIndex + 2] = totalHists[2]; + hists[defaultValIndex + 3] = totalHists[3]; + + for (int i = 0; i < numFeatureBins[featureId]; i += 1) { + if (i != defaultVal) { + int index = (featureOffset[featureId] + i) * BIN_SIZE; + hists[defaultValIndex] -= hists[index]; + hists[defaultValIndex + 1] -= hists[index + 1]; + hists[defaultValIndex + 2] -= hists[index + 2]; + hists[defaultValIndex + 3] -= hists[index + 3]; + } + } } - binOffset += numFeatureBins[featureId]; } } @@ -138,43 +211,45 @@ private static int[] calcRecvCounts( int pairCnt = (int) distributor.count(k); for (int i = pairStart; i < 
pairStart + pairCnt; i += 1) { int featureId = nodeFeaturePairs[2 * i + 1]; - recvcnts[k] += numFeatureBins[featureId] * DataUtils.BIN_SIZE; + recvcnts[k] += numFeatureBins[featureId] * BIN_SIZE; } } return recvcnts; } - /** Generates (nodeId, featureId) pairs that are required to build histograms. */ - int[] getNodeFeaturePairs(int numLayerNodes) { - int[] nodeFeaturePairs = new int[numLayerNodes * numBaggingFeatures * 2]; - int p = 0; - for (int k = 0; k < numLayerNodes; k += 1) { - int[] sampledFeatures = - DataUtils.sample(featureIndicesPool, numBaggingFeatures, featureRandomizer); - for (int featureId : sampledFeatures) { - nodeFeaturePairs[p++] = k; - nodeFeaturePairs[p++] = featureId; - } - } - return nodeFeaturePairs; - } - /** Calculate local histograms for nodes in current layer of tree. */ Histogram build( List layer, - int[] nodeFeaturePairs, int[] indices, BinnedInstance[] instances, - PredGradHess[] pgh) { + PredGradHess[] pgh, + Consumer nodeFeaturePairsSetter) { LOG.info("subtaskId: {}, {} start", subtaskId, HistBuilder.class.getSimpleName()); + int numNodes = layer.size(); + + // Generates (nodeId, featureId) pairs that are required to build histograms. + int[][] nodeToFeatures = new int[numNodes][]; + int[] nodeFeaturePairs = new int[numNodes * numBaggingFeatures * 2]; + int p = 0; + for (int k = 0; k < numNodes; k += 1) { + nodeToFeatures[k] = + DataUtils.sample(featureIndicesPool, numBaggingFeatures, featureRandomizer); + Arrays.sort(nodeToFeatures[k]); + for (int featureId : nodeToFeatures[k]) { + nodeFeaturePairs[p++] = k; + nodeFeaturePairs[p++] = featureId; + } + } + nodeFeaturePairsSetter.accept(nodeFeaturePairs); + Arrays.fill(hists, 0); // Calculates histograms for (nodeId, featureId) pairs. calcNodeFeaturePairHists( layer, - nodeFeaturePairs, + nodeToFeatures, featureMetas, - isInputVector, numFeatureBins, + isInputVector, indices, instances, pgh, diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/BinnedInstanceSerializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/BinnedInstanceSerializer.java index e02f316f6..f3e038143 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/BinnedInstanceSerializer.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/BinnedInstanceSerializer.java @@ -21,14 +21,12 @@ import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot; import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; import org.apache.flink.api.common.typeutils.base.DoubleSerializer; -import org.apache.flink.api.common.typeutils.base.IntSerializer; import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton; +import org.apache.flink.api.common.typeutils.base.array.IntPrimitiveArraySerializer; import org.apache.flink.core.memory.DataInputView; import org.apache.flink.core.memory.DataOutputView; import org.apache.flink.ml.common.gbt.defs.BinnedInstance; -import org.eclipse.collections.impl.map.mutable.primitive.IntIntHashMap; - import java.io.IOException; /** Serializer for {@link BinnedInstance}. */ @@ -50,7 +48,8 @@ public BinnedInstance createInstance() { @Override public BinnedInstance copy(BinnedInstance from) { BinnedInstance instance = new BinnedInstance(); - instance.features = new IntIntHashMap(from.features); + instance.featureIds = null == from.featureIds ? 
null : from.featureIds.clone(); + instance.featureValues = from.featureValues.clone(); instance.label = from.label; instance.weight = from.weight; return instance; @@ -69,11 +68,13 @@ public int getLength() { @Override public void serialize(BinnedInstance record, DataOutputView target) throws IOException { - IntSerializer.INSTANCE.serialize(record.features.size(), target); - for (int k : record.features.keysView().toArray()) { - IntSerializer.INSTANCE.serialize(k, target); - IntSerializer.INSTANCE.serialize(record.features.get(k), target); + if (null == record.featureIds) { + target.writeBoolean(true); + } else { + target.writeBoolean(false); + IntPrimitiveArraySerializer.INSTANCE.serialize(record.featureIds, target); } + IntPrimitiveArraySerializer.INSTANCE.serialize(record.featureValues, target); DoubleSerializer.INSTANCE.serialize(record.label, target); DoubleSerializer.INSTANCE.serialize(record.weight, target); } @@ -81,13 +82,12 @@ public void serialize(BinnedInstance record, DataOutputView target) throws IOExc @Override public BinnedInstance deserialize(DataInputView source) throws IOException { BinnedInstance instance = new BinnedInstance(); - int numFeatures = IntSerializer.INSTANCE.deserialize(source); - instance.features = new IntIntHashMap(numFeatures); - for (int i = 0; i < numFeatures; i += 1) { - int k = IntSerializer.INSTANCE.deserialize(source); - int v = IntSerializer.INSTANCE.deserialize(source); - instance.features.put(k, v); + if (source.readBoolean()) { + instance.featureIds = null; + } else { + instance.featureIds = IntPrimitiveArraySerializer.INSTANCE.deserialize(source); } + instance.featureValues = IntPrimitiveArraySerializer.INSTANCE.deserialize(source); instance.label = DoubleSerializer.INSTANCE.deserialize(source); instance.weight = DoubleSerializer.INSTANCE.deserialize(source); return instance; From 4b970644cf32d28c74411244ffdc4b10a0fca005 Mon Sep 17 00:00:00 2001 From: Fan Hong Date: Tue, 21 Feb 2023 11:25:38 +0800 Subject: [PATCH 11/47] [NO MERGE] Ignore GBT operators to pass Python completeness tests. --- flink-ml-python/pyflink/ml/tests/test_ml_lib_completeness.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/flink-ml-python/pyflink/ml/tests/test_ml_lib_completeness.py b/flink-ml-python/pyflink/ml/tests/test_ml_lib_completeness.py index db59df0b7..ad5e6210b 100644 --- a/flink-ml-python/pyflink/ml/tests/test_ml_lib_completeness.py +++ b/flink-ml-python/pyflink/ml/tests/test_ml_lib_completeness.py @@ -92,7 +92,8 @@ def module(self): pass def exclude_java_stage(self): - return [] + return ['gbtclassifier.GBTClassifier', 'gbtclassifier.GBTClassifierModel', + 'gbtregressor.GBTRegressor', 'gbtregressor.GBTRegressorModel'] class ClassificationCompletenessTest(CompletenessTest, MLLibTest): From c7156a4ad6d5fbb5f2286246e489958b5e3952b2 Mon Sep 17 00:00:00 2001 From: Fan Hong Date: Fri, 17 Feb 2023 10:45:41 +0800 Subject: [PATCH 12/47] Improve feature splitter. 
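Refactor the splitters to cut per-call allocations on the split-finding hot path: countTotalMissing now fills caller-provided accumulators instead of returning a Tuple2, new overloads of findBestSplit/findBestSplitWithInitial take a bin count so the continuous splitter no longer materializes an identity array of sorted bin IDs, and the missing-go-left pass is skipped when no instances are missing. A minimal sketch of the bin-count scan, assuming the gain/addBinToLeft/HessianImpurity helpers defined in the files touched below:

    // Bins [0, bestSplitBinId] go left; the missing bin is evaluated in a
    // separate pass for both directions, so skip it in the ordered scan.
    for (int binId = 0; binId < numBins; binId += 1) {
        if (useMissing && binId == featureMeta.missingBin) {
            continue;
        }
        addBinToLeft(binId, left, right); // move one bin from right to left
        double gain = gain(total, left, right);
        if (gain > bestGain && gain >= minInfoGain) {
            bestGain = gain;
            bestSplitBinId = binId;
        }
    }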
--- .../splitter/CategoricalFeatureSplitter.java | 7 +-- .../splitter/ContinuousFeatureSplitter.java | 12 ++-- .../splitter/HistogramFeatureSplitter.java | 59 +++++++++++++++++-- 3 files changed, 61 insertions(+), 17 deletions(-) diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/CategoricalFeatureSplitter.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/CategoricalFeatureSplitter.java index eaac29c47..db2b07d55 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/CategoricalFeatureSplitter.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/CategoricalFeatureSplitter.java @@ -18,7 +18,6 @@ package org.apache.flink.ml.common.gbt.splitter; -import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.tuple.Tuple3; import org.apache.flink.ml.common.gbt.defs.FeatureMeta; import org.apache.flink.ml.common.gbt.defs.GbtParams; @@ -42,9 +41,9 @@ public CategoricalFeatureSplitter(int featureId, FeatureMeta featureMeta, GbtPar @Override public Split.CategoricalSplit bestSplit() { - Tuple2 totalMissing = countTotalMissing(); - HessianImpurity total = totalMissing.f0; - HessianImpurity missing = totalMissing.f1; + HessianImpurity total = emptyImpurity(); + HessianImpurity missing = emptyImpurity(); + countTotalMissing(total, missing); if (total.getNumInstances() <= minSamplesPerLeaf) { return Split.CategoricalSplit.invalid(total.prediction()); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/ContinuousFeatureSplitter.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/ContinuousFeatureSplitter.java index ce2656cd8..924c3ad4f 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/ContinuousFeatureSplitter.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/ContinuousFeatureSplitter.java @@ -18,15 +18,12 @@ package org.apache.flink.ml.common.gbt.splitter; -import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.tuple.Tuple3; import org.apache.flink.ml.common.gbt.defs.FeatureMeta; import org.apache.flink.ml.common.gbt.defs.GbtParams; import org.apache.flink.ml.common.gbt.defs.HessianImpurity; import org.apache.flink.ml.common.gbt.defs.Split; -import java.util.stream.IntStream; - /** Splitter for a continuous feature. 
*/ public final class ContinuousFeatureSplitter extends HistogramFeatureSplitter { @@ -36,16 +33,15 @@ public ContinuousFeatureSplitter(int featureId, FeatureMeta featureMeta, GbtPara @Override public Split.ContinuousSplit bestSplit() { - Tuple2 totalMissing = countTotalMissing(); - HessianImpurity total = totalMissing.f0; - HessianImpurity missing = totalMissing.f1; + HessianImpurity total = emptyImpurity(); + HessianImpurity missing = emptyImpurity(); + countTotalMissing(total, missing); if (total.getNumInstances() <= minSamplesPerLeaf) { return Split.ContinuousSplit.invalid(total.prediction()); } - int[] sortedBinIds = IntStream.range(0, slice.size()).toArray(); - Tuple3 bestSplit = findBestSplit(sortedBinIds, total, missing); + Tuple3 bestSplit = findBestSplit(slice.size(), total, missing); double bestGain = bestSplit.f0; int bestSplitBinId = bestSplit.f1; boolean missingGoLeft = bestSplit.f2; diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/HistogramFeatureSplitter.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/HistogramFeatureSplitter.java index 0d22a7466..3774d4b86 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/HistogramFeatureSplitter.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/HistogramFeatureSplitter.java @@ -83,6 +83,25 @@ protected Tuple2 findBestSplitWithInitial( return Tuple2.of(bestGain, bestSplitBinId); } + protected Tuple2 findBestSplitWithInitial( + int numBins, HessianImpurity total, HessianImpurity left, HessianImpurity right) { + // Bins [0, bestSplitBinId] go left. + int bestSplitBinId = 0; + double bestGain = Split.INVALID_GAIN; + for (int binId = 0; binId < numBins; binId += 1) { + if (useMissing && binId == featureMeta.missingBin) { + continue; + } + addBinToLeft(binId, left, right); + double gain = gain(total, left, right); + if (gain > bestGain && gain >= minInfoGain) { + bestGain = gain; + bestSplitBinId = binId; + } + } + return Tuple2.of(bestGain, bestSplitBinId); + } + protected Tuple3 findBestSplit( int[] sortedBinIds, HessianImpurity total, HessianImpurity missing) { double bestGain = Split.INVALID_GAIN; @@ -101,7 +120,7 @@ protected Tuple3 findBestSplit( } } - if (useMissing) { + if (useMissing && missing.getNumInstances() > 0) { // The cases where the missing values go left. HessianImpurity leftWithMissing = emptyImpurity().add(missing); HessianImpurity rightWithoutMissing = (HessianImpurity) total.clone().subtract(missing); @@ -117,21 +136,51 @@ protected Tuple3 findBestSplit( return Tuple3.of(bestGain, bestSplitBinId, missingGoLeft); } + protected Tuple3 findBestSplit( + int numBins, HessianImpurity total, HessianImpurity missing) { + double bestGain = Split.INVALID_GAIN; + int bestSplitBinId = 0; + boolean missingGoLeft = false; + + { + // The cases where the missing values go right, or missing values are not allowed. + HessianImpurity left = emptyImpurity(); + HessianImpurity right = (HessianImpurity) total.clone(); + Tuple2 bestSplit = + findBestSplitWithInitial(numBins, total, left, right); + if (bestSplit.f0 > bestGain) { + bestGain = bestSplit.f0; + bestSplitBinId = bestSplit.f1; + } + } + + if (useMissing) { + // The cases where the missing values go left. 
+ HessianImpurity leftWithMissing = emptyImpurity().add(missing); + HessianImpurity rightWithoutMissing = (HessianImpurity) total.clone().subtract(missing); + Tuple2 bestSplitMissingGoLeft = + findBestSplitWithInitial(numBins, total, leftWithMissing, rightWithoutMissing); + if (bestSplitMissingGoLeft.f0 > bestGain) { + bestGain = bestSplitMissingGoLeft.f0; + bestSplitBinId = bestSplitMissingGoLeft.f1; + missingGoLeft = true; + } + } + return Tuple3.of(bestGain, bestSplitBinId, missingGoLeft); + } + public void reset(double[] hists, Slice slice) { this.hists = hists; this.slice = slice; } - protected Tuple2 countTotalMissing() { - HessianImpurity total = emptyImpurity(); - HessianImpurity missing = emptyImpurity(); + protected void countTotalMissing(HessianImpurity total, HessianImpurity missing) { for (int i = 0; i < slice.size(); ++i) { addBinToLeft(i, total, null); } if (useMissing) { addBinToLeft(featureMeta.missingBin, missing, null); } - return Tuple2.of(total, missing); } protected HessianImpurity emptyImpurity() { From 5525b858630b11526f1c3e2527885050cf8ba8d0 Mon Sep 17 00:00:00 2001 From: Fan Hong Date: Fri, 17 Feb 2023 11:40:53 +0800 Subject: [PATCH 13/47] Improve hist builder when no feature subsampling. --- .../ml/common/gbt/operators/HistBuilder.java | 84 +++++++++++++------ 1 file changed, 57 insertions(+), 27 deletions(-) diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistBuilder.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistBuilder.java index 757f66369..d7fcd093e 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistBuilder.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistBuilder.java @@ -120,15 +120,25 @@ private static void calcNodeFeaturePairHists( } int[] featureOffset = new int[numFeatures]; + BitSet featureValid = null; + boolean allFeatureValid; for (int k = 0; k < numNodes; k += 1) { int[] features = nodeToFeatures[k]; int[] binOffsets = nodeToBinOffsets[k]; LearningNode node = layer.get(k); - BitSet featureValid = new BitSet(numFeatures); - for (int i = 0; i < features.length; i += 1) { - featureValid.set(features[i]); - featureOffset[features[i]] = binOffsets[i]; + if (numFeatures != features.length) { + allFeatureValid = false; + featureValid = new BitSet(numFeatures); + for (int feature : features) { + featureValid.set(feature); + } + for (int i = 0; i < features.length; i += 1) { + featureOffset[features[i]] = binOffsets[i]; + } + } else { + allFeatureValid = true; + System.arraycopy(binOffsets, 0, featureOffset, 0, numFeatures); } double[] totalHists = new double[4]; @@ -143,33 +153,41 @@ private static void calcNodeFeaturePairHists( totalHists[1] += hessian; totalHists[2] += weight; totalHists[3] += 1.; + } + + for (int i = node.slice.start; i < node.slice.end; i += 1) { + int instanceId = indices[i]; + BinnedInstance binnedInstance = instances[instanceId]; + double weight = binnedInstance.weight; + double gradient = pgh[instanceId].gradient; + double hessian = pgh[instanceId].hessian; if (null == binnedInstance.featureIds) { for (int j = 0; j < binnedInstance.featureValues.length; j += 1) { - if (!featureValid.get(j)) { - continue; + if (allFeatureValid || featureValid.get(j)) { + add( + hists, + featureOffset[j], + binnedInstance.featureValues[j], + gradient, + hessian, + weight, + 1.); } - int val = binnedInstance.featureValues[j]; - int offset = featureOffset[j]; - int index = (offset + val) * BIN_SIZE; - hists[index] += 
gradient; - hists[index + 1] += hessian; - hists[index + 2] += weight; - hists[index + 3] += 1.; } } else { for (int j = 0; j < binnedInstance.featureIds.length; j += 1) { int featureId = binnedInstance.featureIds[j]; - if (!featureValid.get(featureId)) { - continue; + if (allFeatureValid || featureValid.get(featureId)) { + add( + hists, + featureOffset[featureId], + binnedInstance.featureValues[j], + gradient, + hessian, + weight, + 1.); } - int val = binnedInstance.featureValues[j]; - int offset = featureOffset[featureId]; - int index = (offset + val) * BIN_SIZE; - hists[index] += gradient; - hists[index + 1] += hessian; - hists[index + 2] += weight; - hists[index + 3] += 1.; } } } @@ -181,20 +199,32 @@ private static void calcNodeFeaturePairHists( hists[defaultValIndex + 1] = totalHists[1]; hists[defaultValIndex + 2] = totalHists[2]; hists[defaultValIndex + 3] = totalHists[3]; - for (int i = 0; i < numFeatureBins[featureId]; i += 1) { if (i != defaultVal) { int index = (featureOffset[featureId] + i) * BIN_SIZE; - hists[defaultValIndex] -= hists[index]; - hists[defaultValIndex + 1] -= hists[index + 1]; - hists[defaultValIndex + 2] -= hists[index + 2]; - hists[defaultValIndex + 3] -= hists[index + 3]; + add( + hists, + featureOffset[featureId], + defaultVal, + -hists[index], + -hists[index + 1], + -hists[index + 2], + -hists[index + 3]); } } } } } + private static void add( + double[] hists, int offset, int val, double d0, double d1, double d2, double d3) { + int index = (offset + val) * BIN_SIZE; + hists[index] += d0; + hists[index + 1] += d1; + hists[index + 2] += d2; + hists[index + 3] += d3; + } + /** * Calculates elements counts of histogram distributed to each downstream subtask. The elements * counts is bin counts multiplied by STEP. The minimum unit to be distributed is (nodeId, From 89323927e7f3f32111e1ef9ea38e38561d82c387 Mon Sep 17 00:00:00 2001 From: Fan Hong Date: Fri, 17 Feb 2023 19:09:54 +0800 Subject: [PATCH 14/47] Add optimized serializer for double arrays. 
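Extract the buffered double[] read/write path out of DenseVectorSerializer into a standalone OptimizedDoublePrimitiveArraySerializer, so plain double arrays (such as histogram bins) also avoid per-element writeDouble calls. The write loop added below packs values into a reusable 1 KiB buffer and flushes it in blocks of 128 doubles:

    // (i & 127) << 3 is the byte offset of slot i within a 128-double block.
    for (int i = 0; i < len; i++) {
        Bits.putDouble(buf, (i & 127) << 3, record[i]);
        if ((i & 127) == 127) {
            target.write(buf); // block full: flush 128 doubles at once
        }
    }
    target.write(buf, 0, (len & 127) << 3); // flush the partial tail block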
--- .../typeinfo/DenseVectorSerializerTest.java | 57 +++++++ ...zedDoublePrimitiveArraySerializerTest.java | 59 +++++++ .../flink/ml/common/gbt/defs/Histogram.java | 3 + .../gbt/typeinfo/HistogramSerializer.java | 11 +- .../gbt/typeinfo/HistogramTypeInfo.java | 88 ++++++++++ .../typeinfo/HistogramTypeInfoFactory.java | 40 +++++ .../typeinfo/DenseVectorSerializer.java | 44 +---- ...timizedDoublePrimitiveArraySerializer.java | 158 ++++++++++++++++++ 8 files changed, 417 insertions(+), 43 deletions(-) create mode 100644 flink-ml-core/src/test/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorSerializerTest.java create mode 100644 flink-ml-core/src/test/java/org/apache/flink/ml/linalg/typeinfo/OptimizedDoublePrimitiveArraySerializerTest.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/HistogramTypeInfo.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/HistogramTypeInfoFactory.java create mode 100644 flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/OptimizedDoublePrimitiveArraySerializer.java diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorSerializerTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorSerializerTest.java new file mode 100644 index 000000000..84dba96a3 --- /dev/null +++ b/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorSerializerTest.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.linalg.typeinfo; + +import org.apache.flink.core.memory.DataInputViewStreamWrapper; +import org.apache.flink.core.memory.DataOutputViewStreamWrapper; +import org.apache.flink.ml.linalg.DenseVector; + +import org.junit.Assert; +import org.junit.Test; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.util.Random; + +/** Tests the serialization and deserialization from {@link DenseVectorSerializer}. 
*/ +public class DenseVectorSerializerTest { + @Test + public void testSerializationDeserialization() throws IOException { + Random random = new Random(0); + int[] lens = new int[] {0, 100, 128, 500, 1024, 4096}; + + DenseVectorSerializer serializer = new DenseVectorSerializer(); + for (int len : lens) { + double[] arr = new double[len]; + for (int i = 0; i < len; i += 1) { + arr[i] = random.nextDouble(); + } + DenseVector expected = new DenseVector(arr); + + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + serializer.serialize(expected, new DataOutputViewStreamWrapper(baos)); + DenseVector actual = + serializer.deserialize( + new DataInputViewStreamWrapper( + new ByteArrayInputStream(baos.toByteArray()))); + Assert.assertEquals(expected, actual); + } + } +} diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/typeinfo/OptimizedDoublePrimitiveArraySerializerTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/typeinfo/OptimizedDoublePrimitiveArraySerializerTest.java new file mode 100644 index 000000000..24225c069 --- /dev/null +++ b/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/typeinfo/OptimizedDoublePrimitiveArraySerializerTest.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.linalg.typeinfo; + +import org.apache.flink.core.memory.DataInputViewStreamWrapper; +import org.apache.flink.core.memory.DataOutputViewStreamWrapper; + +import org.junit.Assert; +import org.junit.Test; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.util.Random; + +/** + * Tests the serialization and deserialization for the double array from {@link + * OptimizedDoublePrimitiveArraySerializer}. 
+ */ +public class OptimizedDoublePrimitiveArraySerializerTest { + @Test + public void testSerializationDeserialization() throws IOException { + Random random = new Random(0); + int[] lens = new int[] {0, 100, 128, 500, 1024, 4096}; + + OptimizedDoublePrimitiveArraySerializer serializer = + new OptimizedDoublePrimitiveArraySerializer(); + for (int len : lens) { + double[] arr = new double[len]; + for (int i = 0; i < len; i += 1) { + arr[i] = random.nextDouble(); + } + + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + serializer.serialize(arr, new DataOutputViewStreamWrapper(baos)); + double[] actual = + serializer.deserialize( + new DataInputViewStreamWrapper( + new ByteArrayInputStream(baos.toByteArray()))); + Assert.assertArrayEquals(arr, actual, 0.); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Histogram.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Histogram.java index 497460327..b939bd7fc 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Histogram.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Histogram.java @@ -19,6 +19,8 @@ package org.apache.flink.ml.common.gbt.defs; import org.apache.flink.api.common.functions.AggregateFunction; +import org.apache.flink.api.common.typeinfo.TypeInfo; +import org.apache.flink.ml.common.gbt.typeinfo.HistogramTypeInfoFactory; import org.apache.flink.util.Preconditions; import java.io.Serializable; @@ -26,6 +28,7 @@ /** * This class stores values of histogram bins, and necessary information of reducing and scattering. */ +@TypeInfo(HistogramTypeInfoFactory.class) public class Histogram implements Serializable { // Stores source subtask ID when reducing or target subtask ID when scattering. 
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/HistogramSerializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/HistogramSerializer.java index 9c4648399..02b0628f4 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/HistogramSerializer.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/HistogramSerializer.java @@ -22,11 +22,11 @@ import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; import org.apache.flink.api.common.typeutils.base.IntSerializer; import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton; -import org.apache.flink.api.common.typeutils.base.array.DoublePrimitiveArraySerializer; import org.apache.flink.api.common.typeutils.base.array.IntPrimitiveArraySerializer; import org.apache.flink.core.memory.DataInputView; import org.apache.flink.core.memory.DataOutputView; import org.apache.flink.ml.common.gbt.defs.Histogram; +import org.apache.flink.ml.linalg.typeinfo.OptimizedDoublePrimitiveArraySerializer; import java.io.IOException; @@ -36,7 +36,8 @@ public final class HistogramSerializer extends TypeSerializerSingleton { + + public static final HistogramTypeInfo INSTANCE = new HistogramTypeInfo(); + + private HistogramTypeInfo() {} + + @Override + public boolean isBasicType() { + return false; + } + + @Override + public boolean isTupleType() { + return false; + } + + @Override + public int getArity() { + return 3; + } + + @Override + public int getTotalFields() { + return 3; + } + + @Override + public Class getTypeClass() { + return Histogram.class; + } + + @Override + public boolean isKeyType() { + return false; + } + + @Override + public TypeSerializer createSerializer(ExecutionConfig executionConfig) { + return new HistogramSerializer(); + } + + @Override + public String toString() { + return "Histogram"; + } + + @Override + public boolean equals(Object o) { + return o instanceof HistogramTypeInfo; + } + + @Override + public int hashCode() { + return getClass().hashCode(); + } + + @Override + public boolean canEqual(Object o) { + return o instanceof HistogramTypeInfo; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/HistogramTypeInfoFactory.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/HistogramTypeInfoFactory.java new file mode 100644 index 000000000..e956e20a4 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/HistogramTypeInfoFactory.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.gbt.typeinfo; + +import org.apache.flink.api.common.typeinfo.TypeInfoFactory; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.TypeExtractor; +import org.apache.flink.ml.common.gbt.defs.Histogram; + +import java.lang.reflect.Type; +import java.util.Map; + +/** + * Used by {@link TypeExtractor} to create a {@link TypeInformation} for implementations of {@link + * Histogram}. + */ +public class HistogramTypeInfoFactory extends TypeInfoFactory { + + @Override + public TypeInformation createTypeInfo( + Type t, Map> genericParameters) { + return HistogramTypeInfo.INSTANCE; + } +} diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorSerializer.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorSerializer.java index 5b6f984aa..de2a882d9 100644 --- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorSerializer.java +++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorSerializer.java @@ -25,7 +25,6 @@ import org.apache.flink.core.memory.DataInputView; import org.apache.flink.core.memory.DataOutputView; import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.util.Bits; import java.io.IOException; import java.util.Arrays; @@ -38,7 +37,8 @@ public final class DenseVectorSerializer extends TypeSerializer { private static final double[] EMPTY = new double[0]; - private final byte[] buf = new byte[1024]; + private final OptimizedDoublePrimitiveArraySerializer valuesSerializer = + new OptimizedDoublePrimitiveArraySerializer(); @Override public boolean isImmutableType() { @@ -79,53 +79,21 @@ public void serialize(DenseVector vector, DataOutputView target) throws IOExcept if (vector == null) { throw new IllegalArgumentException("The vector must not be null."); } - - final int len = vector.values.length; - target.writeInt(len); - - for (int i = 0; i < len; i++) { - Bits.putDouble(buf, (i & 127) << 3, vector.values[i]); - if ((i & 127) == 127) { - target.write(buf); - } - } - target.write(buf, 0, (len & 127) << 3); + valuesSerializer.serialize(vector.values, target); } @Override public DenseVector deserialize(DataInputView source) throws IOException { - int len = source.readInt(); - double[] values = new double[len]; - readDoubleArray(values, source, len); - return new DenseVector(values); - } - - // Reads `len` double values from `source` into `dst`. 
- private void readDoubleArray(double[] dst, DataInputView source, int len) throws IOException { - int index = 0; - for (int i = 0; i < (len >> 7); i++) { - source.readFully(buf, 0, 1024); - for (int j = 0; j < 128; j++) { - dst[index++] = Bits.getDouble(buf, j << 3); - } - } - source.readFully(buf, 0, (len << 3) & 1023); - for (int j = 0; j < (len & 127); j++) { - dst[index++] = Bits.getDouble(buf, j << 3); - } + return new DenseVector(valuesSerializer.deserialize(source)); } @Override public DenseVector deserialize(DenseVector reuse, DataInputView source) throws IOException { int len = source.readInt(); if (len == reuse.values.length) { - readDoubleArray(reuse.values, source, len); - return reuse; + valuesSerializer.deserialize(reuse.values, source); } - - double[] values = new double[len]; - readDoubleArray(values, source, len); - return new DenseVector(values); + return new DenseVector(valuesSerializer.deserialize(source)); } @Override diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/OptimizedDoublePrimitiveArraySerializer.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/OptimizedDoublePrimitiveArraySerializer.java new file mode 100644 index 000000000..9264be2ce --- /dev/null +++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/OptimizedDoublePrimitiveArraySerializer.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.linalg.typeinfo; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.ml.util.Bits; + +import java.io.IOException; +import java.util.Objects; + +/** A serializer for double arrays. 
*/ +@Internal +public final class OptimizedDoublePrimitiveArraySerializer extends TypeSerializer { + + private static final long serialVersionUID = 1L; + + private static final double[] EMPTY = new double[0]; + + private static final int BUFFER_SIZE = 1024; + private final byte[] buf = new byte[BUFFER_SIZE]; + + public OptimizedDoublePrimitiveArraySerializer() {} + + @Override + public boolean isImmutableType() { + return false; + } + + @Override + public TypeSerializer duplicate() { + return new OptimizedDoublePrimitiveArraySerializer(); + } + + @Override + public double[] createInstance() { + return EMPTY; + } + + @Override + public double[] copy(double[] from) { + double[] copy = new double[from.length]; + System.arraycopy(from, 0, copy, 0, from.length); + return copy; + } + + @Override + public double[] copy(double[] from, double[] reuse) { + return copy(from); + } + + @Override + public int getLength() { + return -1; + } + + @Override + public void serialize(double[] record, DataOutputView target) throws IOException { + if (record == null) { + throw new IllegalArgumentException("The record must not be null."); + } + final int len = record.length; + target.writeInt(len); + for (int i = 0; i < len; i++) { + Bits.putDouble(buf, (i & 127) << 3, record[i]); + if ((i & 127) == 127) { + target.write(buf); + } + } + target.write(buf, 0, (len & 127) << 3); + } + + @Override + public double[] deserialize(DataInputView source) throws IOException { + final int len = source.readInt(); + double[] result = new double[len]; + readDoubleArray(len, result, source); + return result; + } + + public void readDoubleArray(int len, double[] result, DataInputView source) throws IOException { + int index = 0; + for (int i = 0; i < (len >> 7); i++) { + source.readFully(buf, 0, 1024); + for (int j = 0; j < 128; j++) { + result[index++] = Bits.getDouble(buf, j << 3); + } + } + source.readFully(buf, 0, (len & 127) << 3); + for (int j = 0; j < (len & 127); j++) { + result[index++] = Bits.getDouble(buf, j << 3); + } + } + + @Override + public double[] deserialize(double[] reuse, DataInputView source) throws IOException { + int len = source.readInt(); + if (len == reuse.length) { + readDoubleArray(len, reuse, source); + return reuse; + } + return deserialize(source); + } + + @Override + public void copy(DataInputView source, DataOutputView target) throws IOException { + final int len = source.readInt(); + target.writeInt(len); + target.write(source, len * Double.BYTES); + } + + @Override + public boolean equals(Object obj) { + return obj instanceof OptimizedDoublePrimitiveArraySerializer; + } + + @Override + public int hashCode() { + return Objects.hashCode(OptimizedDoublePrimitiveArraySerializer.class); + } + + @Override + public TypeSerializerSnapshot snapshotConfiguration() { + return new DoublePrimitiveArraySerializerSnapshot(); + } + + // ------------------------------------------------------------------------ + + /** Serializer configuration snapshot for compatibility and format evolution. */ + @SuppressWarnings("WeakerAccess") + public static final class DoublePrimitiveArraySerializerSnapshot + extends SimpleTypeSerializerSnapshot { + + public DoublePrimitiveArraySerializerSnapshot() { + super(OptimizedDoublePrimitiveArraySerializer::new); + } + } +} From 74fb657078bc802ae91b79ebc6d2127ce06eda16 Mon Sep 17 00:00:00 2001 From: Fan Hong Date: Mon, 20 Feb 2023 17:36:48 +0800 Subject: [PATCH 15/47] Fixed checkpoint problem. 
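Operator-local helpers (TreeInitializer, HistBuilder, SplitFinder, NodeSplitter,
InstanceUpdater) and received values (Histogram, Splits) are moved from plain ListState
into ListStateWithCache, mirrored into transient fields, and snapshotted explicitly in
snapshotState together with the writers of shared data, so that all of them survive a
checkpoint/restore cycle. The recurring restore pattern, condensed from
CalcLocalSplitsOperator (identifiers as in the patch):

    histogramState =
            new ListStateWithCache<>(
                    new HistogramSerializer(),
                    getContainingTask(),
                    getRuntimeContext(),
                    context,
                    getOperatorID());
    // Mirror the unique element into a transient field; it stays null until the first
    // histogram arrives or state is restored.
    histogram =
            OperatorStateUtils.getUniqueElement(histogramState, HISTOGRAM_STATE_NAME)
                    .orElse(null);

Helpers derived from TrainContext are re-created at epoch watermark 0 and written back
into their caches.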
--- .../ml/common/gbt/defs/LearningNode.java | 4 +- .../CacheDataCalcLocalHistsOperator.java | 85 +++++++------ .../operators/CalcLocalSplitsOperator.java | 84 ++++++++----- .../ml/common/gbt/operators/HistBuilder.java | 20 ++- .../common/gbt/operators/InstanceUpdater.java | 10 +- .../gbt/operators/PostSplitsOperator.java | 115 +++++++++++------- .../gbt/operators/TerminationOperator.java | 4 +- 7 files changed, 185 insertions(+), 137 deletions(-) diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/LearningNode.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/LearningNode.java index a97c6d435..3afb8b05c 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/LearningNode.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/LearningNode.java @@ -24,9 +24,9 @@ public class LearningNode { // The node index in `currentTreeNodes` used in `PostSplitsOperator`. public int nodeIndex; // Slice of indices of bagging instances. - public Slice slice; + public Slice slice = new Slice(); // Slice of indices of non-bagging instances. - public Slice oob; + public Slice oob = new Slice(); // Depth of corresponding tree node. public int depth; diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java index 741a8ccfc..1e6ad9ea1 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java @@ -18,8 +18,6 @@ package org.apache.flink.ml.common.gbt.operators; -import org.apache.flink.api.common.state.ListState; -import org.apache.flink.api.common.state.ListStateDescriptor; import org.apache.flink.api.common.typeutils.base.BooleanSerializer; import org.apache.flink.api.common.typeutils.base.GenericArraySerializer; import org.apache.flink.api.common.typeutils.base.array.IntPrimitiveArraySerializer; @@ -50,8 +48,6 @@ import org.apache.flink.util.Preconditions; import org.apache.commons.collections.IteratorUtils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import java.util.Arrays; import java.util.Collections; @@ -64,8 +60,6 @@ public class CacheDataCalcLocalHistsOperator extends AbstractStreamOperator implements TwoInputStreamOperator, IterationListener { - private static final Logger LOG = - LoggerFactory.getLogger(CacheDataCalcLocalHistsOperator.class); private static final String TREE_INITIALIZER_STATE_NAME = "tree_initializer"; private static final String HIST_BUILDER_STATE_NAME = "hist_builder"; @@ -75,20 +69,22 @@ public class CacheDataCalcLocalHistsOperator extends AbstractStreamOperator instancesCollecting; - private transient ListState treeInitializer; - private transient ListState histBuilder; + private transient ListStateWithCache treeInitializerState; + private transient TreeInitializer treeInitializer; + private transient ListStateWithCache histBuilderState; + private transient HistBuilder histBuilder; // Readers/writers of shared data. 
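+    // These handles are re-obtained in initializeState on every start or restore, which is
+    // why all of them are transient rather than checkpointed.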
private transient IterationSharedStorage.Writer instancesWriter; private transient IterationSharedStorage.Reader pghReader; private transient IterationSharedStorage.Writer shuffledIndicesWriter; private transient IterationSharedStorage.Reader swappedIndicesReader; - private IterationSharedStorage.Writer nodeFeaturePairsWriter; - private IterationSharedStorage.Reader> layerReader; - private IterationSharedStorage.Writer rootLearningNodeWriter; - private IterationSharedStorage.Reader needInitTreeReader; - private IterationSharedStorage.Writer hasInitedTreeWriter; - private IterationSharedStorage.Writer trainContextWriter; + private transient IterationSharedStorage.Writer nodeFeaturePairsWriter; + private transient IterationSharedStorage.Reader> layerReader; + private transient IterationSharedStorage.Writer rootLearningNodeWriter; + private transient IterationSharedStorage.Reader needInitTreeReader; + private transient IterationSharedStorage.Writer hasInitedTreeWriter; + private transient IterationSharedStorage.Writer trainContextWriter; public CacheDataCalcLocalHistsOperator(GbtParams gbtParams, IterationID iterationID) { super(); @@ -107,16 +103,27 @@ public void initializeState(StateInitializationContext context) throws Exception getRuntimeContext(), context, getOperatorID()); + treeInitializerState = + new ListStateWithCache<>( + new KryoSerializer<>(TreeInitializer.class, getExecutionConfig()), + getContainingTask(), + getRuntimeContext(), + context, + getOperatorID()); treeInitializer = - context.getOperatorStateStore() - .getListState( - new ListStateDescriptor<>( - TREE_INITIALIZER_STATE_NAME, TreeInitializer.class)); + OperatorStateUtils.getUniqueElement( + treeInitializerState, TREE_INITIALIZER_STATE_NAME) + .orElse(null); + histBuilderState = + new ListStateWithCache<>( + new KryoSerializer<>(HistBuilder.class, getExecutionConfig()), + getContainingTask(), + getRuntimeContext(), + context, + getOperatorID()); histBuilder = - context.getOperatorStateStore() - .getListState( - new ListStateDescriptor<>( - HIST_BUILDER_STATE_NAME, HistBuilder.class)); + OperatorStateUtils.getUniqueElement(histBuilderState, HIST_BUILDER_STATE_NAME) + .orElse(null); int subtaskId = getRuntimeContext().getIndexOfThisSubtask(); instancesWriter = @@ -196,9 +203,15 @@ public void initializeState(StateInitializationContext context) throws Exception public void snapshotState(StateSnapshotContext context) throws Exception { super.snapshotState(context); instancesCollecting.snapshotState(context); + treeInitializerState.snapshotState(context); + histBuilderState.snapshotState(context); + instancesWriter.snapshotState(context); shuffledIndicesWriter.snapshotState(context); + nodeFeaturePairsWriter.snapshotState(context); + rootLearningNodeWriter.snapshotState(context); hasInitedTreeWriter.snapshotState(context); + trainContextWriter.snapshotState(context); } @Override @@ -251,8 +264,10 @@ public void onEpochWatermarkIncremented( instancesWriter.get()); trainContextWriter.set(trainContext); - treeInitializer.update(Collections.singletonList(new TreeInitializer(trainContext))); - histBuilder.update(Collections.singletonList(new HistBuilder(trainContext))); + treeInitializer = new TreeInitializer(trainContext); + treeInitializerState.update(Collections.singletonList(treeInitializer)); + histBuilder = new HistBuilder(trainContext); + histBuilderState.update(Collections.singletonList(histBuilder)); } TrainContext trainContext = trainContextWriter.get(); @@ -276,15 +291,9 @@ public void onEpochWatermarkIncremented( 
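+            // Either start a new tree (shuffling the instance indices afresh) or continue
+            // the current one with the indices swapped in place by PostSplitsOperator.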
int[] indices; if (needInitTreeReader.get()) { - TreeInitializer treeInit = - OperatorStateUtils.getUniqueElement( - treeInitializer, TREE_INITIALIZER_STATE_NAME) - .get(); - // When last tree is finished, initializes a new tree, and shuffle instance indices. - treeInit.init(shuffledIndicesWriter::set); - - LearningNode rootLearningNode = treeInit.getRootLearningNode(); + treeInitializer.init(shuffledIndicesWriter::set); + LearningNode rootLearningNode = treeInitializer.getRootLearningNode(); indices = shuffledIndicesWriter.get(); rootLearningNodeWriter.set(rootLearningNode); hasInitedTreeWriter.set(true); @@ -301,17 +310,15 @@ public void onEpochWatermarkIncremented( } Histogram localHists = - OperatorStateUtils.getUniqueElement(histBuilder, HIST_BUILDER_STATE_NAME) - .get() - .build(layer, indices, instances, pgh, nodeFeaturePairsWriter::set); + histBuilder.build(layer, indices, instances, pgh, nodeFeaturePairsWriter::set); out.collect(localHists); } @Override public void onIterationTerminated(Context context, Collector collector) { instancesCollecting.clear(); - treeInitializer.clear(); - histBuilder.clear(); + treeInitializerState.clear(); + histBuilderState.clear(); instancesWriter.set(new BinnedInstance[0]); shuffledIndicesWriter.set(new int[0]); @@ -320,6 +327,10 @@ public void onIterationTerminated(Context context, Collector collecto @Override public void close() throws Exception { + instancesCollecting.clear(); + treeInitializerState.clear(); + histBuilderState.clear(); + instancesWriter.remove(); shuffledIndicesWriter.remove(); nodeFeaturePairsWriter.remove(); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CalcLocalSplitsOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CalcLocalSplitsOperator.java index d32015679..dcc68ee3b 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CalcLocalSplitsOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CalcLocalSplitsOperator.java @@ -18,17 +18,19 @@ package org.apache.flink.ml.common.gbt.operators; -import org.apache.flink.api.common.state.ListState; -import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer; import org.apache.flink.iteration.IterationID; import org.apache.flink.iteration.IterationListener; +import org.apache.flink.iteration.datacache.nonkeyed.ListStateWithCache; import org.apache.flink.iteration.operator.OperatorStateUtils; import org.apache.flink.ml.common.gbt.datastorage.IterationSharedStorage; import org.apache.flink.ml.common.gbt.defs.Histogram; import org.apache.flink.ml.common.gbt.defs.LearningNode; import org.apache.flink.ml.common.gbt.defs.Splits; import org.apache.flink.ml.common.gbt.defs.TrainContext; +import org.apache.flink.ml.common.gbt.typeinfo.HistogramSerializer; import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; import org.apache.flink.streaming.api.operators.OneInputStreamOperator; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; @@ -41,18 +43,23 @@ public class CalcLocalSplitsOperator extends AbstractStreamOperator implements OneInputStreamOperator, IterationListener { - private static final String CALC_BEST_SPLIT_STATE_NAME = "split_finder"; + private static final String SPLIT_FINDER_STATE_NAME = "split_finder"; private static final 
String HISTOGRAM_STATE_NAME = "histogram"; private final IterationID iterationID; - private transient ListState splitFinder; - private transient ListState histogram; - private IterationSharedStorage.Reader nodeFeaturePairsReader; - private IterationSharedStorage.Reader> leavesReader; - private IterationSharedStorage.Reader> layerReader; - private IterationSharedStorage.Reader rootLearningNodeReader; - private IterationSharedStorage.Reader trainContextReader; + // States of local data. + private transient ListStateWithCache splitFinderState; + private transient SplitFinder splitFinder; + private transient ListStateWithCache histogramState; + private transient Histogram histogram; + + // Readers/writers of shared data. + private transient IterationSharedStorage.Reader nodeFeaturePairsReader; + private transient IterationSharedStorage.Reader> leavesReader; + private transient IterationSharedStorage.Reader> layerReader; + private transient IterationSharedStorage.Reader rootLearningNodeReader; + private transient IterationSharedStorage.Reader trainContextReader; public CalcLocalSplitsOperator(IterationID iterationID) { this.iterationID = iterationID; @@ -61,15 +68,27 @@ public CalcLocalSplitsOperator(IterationID iterationID) { @Override public void initializeState(StateInitializationContext context) throws Exception { super.initializeState(context); + splitFinderState = + new ListStateWithCache<>( + new KryoSerializer<>(SplitFinder.class, getExecutionConfig()), + getContainingTask(), + getRuntimeContext(), + context, + getOperatorID()); splitFinder = - context.getOperatorStateStore() - .getListState( - new ListStateDescriptor<>( - CALC_BEST_SPLIT_STATE_NAME, SplitFinder.class)); + OperatorStateUtils.getUniqueElement(splitFinderState, SPLIT_FINDER_STATE_NAME) + .orElse(null); + + histogramState = + new ListStateWithCache<>( + new HistogramSerializer(), + getContainingTask(), + getRuntimeContext(), + context, + getOperatorID()); histogram = - context.getOperatorStateStore() - .getListState( - new ListStateDescriptor<>(HISTOGRAM_STATE_NAME, Histogram.class)); + OperatorStateUtils.getUniqueElement(histogramState, HISTOGRAM_STATE_NAME) + .orElse(null); int subtaskId = getRuntimeContext().getIndexOfThisSubtask(); nodeFeaturePairsReader = @@ -84,13 +103,20 @@ public void initializeState(StateInitializationContext context) throws Exception IterationSharedStorage.getReader(iterationID, subtaskId, SharedKeys.TRAIN_CONTEXT); } + @Override + public void snapshotState(StateSnapshotContext context) throws Exception { + super.snapshotState(context); + splitFinderState.snapshotState(context); + histogramState.snapshotState(context); + } + @SuppressWarnings("OptionalGetWithoutIsPresent") @Override public void onEpochWatermarkIncremented( int epochWatermark, Context context, Collector collector) throws Exception { if (0 == epochWatermark) { - splitFinder.update( - Collections.singletonList(new SplitFinder(trainContextReader.get()))); + splitFinder = new SplitFinder(trainContextReader.get()); + splitFinderState.update(Collections.singletonList(splitFinder)); } List layer = layerReader.get(); @@ -99,22 +125,20 @@ public void onEpochWatermarkIncremented( } Splits splits = - OperatorStateUtils.getUniqueElement(splitFinder, CALC_BEST_SPLIT_STATE_NAME) - .get() - .calc( - layer, - nodeFeaturePairsReader.get(), - leavesReader.get().size(), - OperatorStateUtils.getUniqueElement(histogram, HISTOGRAM_STATE_NAME) - .get()); + splitFinder.calc( + layer, nodeFeaturePairsReader.get(), leavesReader.get().size(), histogram); 
collector.collect(splits); } @Override - public void onIterationTerminated(Context context, Collector collector) {} + public void processElement(StreamRecord element) throws Exception { + histogram = element.getValue(); + histogramState.update(Collections.singletonList(histogram)); + } @Override - public void processElement(StreamRecord element) throws Exception { - histogram.update(Collections.singletonList(element.getValue())); + public void onIterationTerminated(Context context, Collector collector) { + splitFinderState.clear(); + histogramState.clear(); } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistBuilder.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistBuilder.java index d7fcd093e..ecaa2a1c9 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistBuilder.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistBuilder.java @@ -54,7 +54,8 @@ class HistBuilder { private final boolean isInputVector; - private final double[] hists; + private final int maxFeatureBins; + private final int totalNumFeatureBins; public HistBuilder(TrainContext trainContext) { subtaskId = trainContext.subtaskId; @@ -69,16 +70,8 @@ public HistBuilder(TrainContext trainContext) { isInputVector = trainContext.params.isInputVector; - int maxNumNodes = - Math.min( - ((int) Math.pow(2, trainContext.params.maxDepth - 1)), - trainContext.params.maxNumLeaves); - - int maxFeatureBins = Arrays.stream(numFeatureBins).max().orElse(0); - int totalNumFeatureBins = Arrays.stream(numFeatureBins).sum(); - int maxNumBins = - maxNumNodes * Math.min(maxFeatureBins * numBaggingFeatures, totalNumFeatureBins); - hists = new double[maxNumBins * BIN_SIZE]; + maxFeatureBins = Arrays.stream(numFeatureBins).max().orElse(0); + totalNumFeatureBins = Arrays.stream(numFeatureBins).sum(); } /** @@ -95,7 +88,6 @@ private static void calcNodeFeaturePairHists( BinnedInstance[] instances, PredGradHess[] pgh, double[] hists) { - int numNodes = layer.size(); int numFeatures = featureMetas.length; @@ -272,7 +264,9 @@ Histogram build( } nodeFeaturePairsSetter.accept(nodeFeaturePairs); - Arrays.fill(hists, 0); + int maxNumBins = + numNodes * Math.min(maxFeatureBins * numBaggingFeatures, totalNumFeatureBins); + double[] hists = new double[maxNumBins * BIN_SIZE]; // Calculates histograms for (nodeId, featureId) pairs. 
calcNodeFeaturePairHists( layer, diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/InstanceUpdater.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/InstanceUpdater.java index 1ddb01369..9b00ac57b 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/InstanceUpdater.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/InstanceUpdater.java @@ -38,35 +38,31 @@ class InstanceUpdater { private final int subtaskId; private final Loss loss; private final double stepSize; - private final PredGradHess[] pgh; private final double prior; - private boolean initialized; - public InstanceUpdater(TrainContext trainContext) { subtaskId = trainContext.subtaskId; loss = trainContext.loss; stepSize = trainContext.params.stepSize; prior = trainContext.prior; - pgh = new PredGradHess[trainContext.numInstances]; - initialized = false; } public void update( + PredGradHess[] pgh, List leaves, int[] indices, BinnedInstance[] instances, Consumer pghSetter, List treeNodes) { LOG.info("subtaskId: {}, {} start", subtaskId, InstanceUpdater.class.getSimpleName()); - if (!initialized) { + if (pgh.length == 0) { + pgh = new PredGradHess[instances.length]; for (int i = 0; i < instances.length; i += 1) { double label = instances[i].label; pgh[i] = new PredGradHess( prior, loss.gradient(prior, label), loss.hessian(prior, label)); } - initialized = true; } for (LearningNode nodeInfo : leaves) { diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/PostSplitsOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/PostSplitsOperator.java index 2b526ba1d..72e9b11cc 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/PostSplitsOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/PostSplitsOperator.java @@ -18,14 +18,14 @@ package org.apache.flink.ml.common.gbt.operators; -import org.apache.flink.api.common.state.ListState; -import org.apache.flink.api.common.state.ListStateDescriptor; import org.apache.flink.api.common.typeutils.base.BooleanSerializer; import org.apache.flink.api.common.typeutils.base.GenericArraySerializer; import org.apache.flink.api.common.typeutils.base.ListSerializer; import org.apache.flink.api.common.typeutils.base.array.IntPrimitiveArraySerializer; +import org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer; import org.apache.flink.iteration.IterationID; import org.apache.flink.iteration.IterationListener; +import org.apache.flink.iteration.datacache.nonkeyed.ListStateWithCache; import org.apache.flink.iteration.operator.OperatorStateUtils; import org.apache.flink.ml.common.gbt.datastorage.IterationSharedStorage; import org.apache.flink.ml.common.gbt.defs.BinnedInstance; @@ -62,21 +62,26 @@ public class PostSplitsOperator extends AbstractStreamOperator private final IterationID iterationID; - private IterationSharedStorage.Reader instancesReader; - private IterationSharedStorage.Writer pghWriter; - private IterationSharedStorage.Reader shuffledIndicesReader; - private IterationSharedStorage.Writer swappedIndicesWriter; + // States of local data. 
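+    // Each local helper is kept twice: in a transient field for direct access, and in a
+    // ListStateWithCache entry that carries it across checkpoints.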
+ private transient ListStateWithCache splitsState; + private transient Splits splits; + private transient ListStateWithCache nodeSplitterState; + private transient NodeSplitter nodeSplitter; + private transient ListStateWithCache instanceUpdaterState; + private transient InstanceUpdater instanceUpdater; - private transient ListState splits; - private transient ListState nodeSplitter; - private transient ListState instanceUpdater; - private IterationSharedStorage.Writer> leavesWriter; - private IterationSharedStorage.Writer> layerWriter; - private IterationSharedStorage.Reader rootLearningNodeReader; - private IterationSharedStorage.Writer>> allTreesWriter; - private IterationSharedStorage.Writer> currentTreeNodesWriter; - private IterationSharedStorage.Writer needInitTreeWriter; - private IterationSharedStorage.Reader trainContextReader; + // Readers/writers of shared data. + private transient IterationSharedStorage.Reader instancesReader; + private transient IterationSharedStorage.Writer pghWriter; + private transient IterationSharedStorage.Reader shuffledIndicesReader; + private transient IterationSharedStorage.Writer swappedIndicesWriter; + private transient IterationSharedStorage.Writer> leavesWriter; + private transient IterationSharedStorage.Writer> layerWriter; + private transient IterationSharedStorage.Reader rootLearningNodeReader; + private transient IterationSharedStorage.Writer>> allTreesWriter; + private transient IterationSharedStorage.Writer> currentTreeNodesWriter; + private transient IterationSharedStorage.Writer needInitTreeWriter; + private transient IterationSharedStorage.Reader trainContextReader; public PostSplitsOperator(IterationID iterationID) { this.iterationID = iterationID; @@ -86,19 +91,35 @@ public PostSplitsOperator(IterationID iterationID) { public void initializeState(StateInitializationContext context) throws Exception { super.initializeState(context); - splits = - context.getOperatorStateStore() - .getListState(new ListStateDescriptor<>(SPLITS_STATE_NAME, Splits.class)); + splitsState = + new ListStateWithCache<>( + new KryoSerializer<>(Splits.class, getExecutionConfig()), + getContainingTask(), + getRuntimeContext(), + context, + getOperatorID()); + splits = OperatorStateUtils.getUniqueElement(splitsState, SPLITS_STATE_NAME).orElse(null); + nodeSplitterState = + new ListStateWithCache<>( + new KryoSerializer<>(NodeSplitter.class, getExecutionConfig()), + getContainingTask(), + getRuntimeContext(), + context, + getOperatorID()); nodeSplitter = - context.getOperatorStateStore() - .getListState( - new ListStateDescriptor<>( - NODE_SPLITTER_STATE_NAME, NodeSplitter.class)); + OperatorStateUtils.getUniqueElement(nodeSplitterState, NODE_SPLITTER_STATE_NAME) + .orElse(null); + instanceUpdaterState = + new ListStateWithCache<>( + new KryoSerializer<>(InstanceUpdater.class, getExecutionConfig()), + getContainingTask(), + getRuntimeContext(), + context, + getOperatorID()); instanceUpdater = - context.getOperatorStateStore() - .getListState( - new ListStateDescriptor<>( - INSTANCE_UPDATER_STATE_NAME, InstanceUpdater.class)); + OperatorStateUtils.getUniqueElement( + instanceUpdaterState, INSTANCE_UPDATER_STATE_NAME) + .orElse(null); int subtaskId = getRuntimeContext().getIndexOfThisSubtask(); pghWriter = @@ -185,9 +206,16 @@ public void initializeState(StateInitializationContext context) throws Exception @Override public void snapshotState(StateSnapshotContext context) throws Exception { super.snapshotState(context); + splitsState.snapshotState(context); + 
nodeSplitterState.snapshotState(context); + instanceUpdaterState.snapshotState(context); + pghWriter.snapshotState(context); swappedIndicesWriter.snapshotState(context); leavesWriter.snapshotState(context); + layerWriter.snapshotState(context); + allTreesWriter.snapshotState(context); + currentTreeNodesWriter.snapshotState(context); needInitTreeWriter.snapshotState(context); } @@ -196,10 +224,10 @@ public void snapshotState(StateSnapshotContext context) throws Exception { public void onEpochWatermarkIncremented( int epochWatermark, Context context, Collector collector) throws Exception { if (0 == epochWatermark) { - nodeSplitter.update( - Collections.singletonList(new NodeSplitter(trainContextReader.get()))); - instanceUpdater.update( - Collections.singletonList(new InstanceUpdater(trainContextReader.get()))); + nodeSplitter = new NodeSplitter(trainContextReader.get()); + nodeSplitterState.update(Collections.singletonList(nodeSplitter)); + instanceUpdater = new InstanceUpdater(trainContextReader.get()); + instanceUpdaterState.update(Collections.singletonList(instanceUpdater)); } int[] indices = swappedIndicesWriter.get(); @@ -220,26 +248,16 @@ public void onEpochWatermarkIncremented( } List nextLayer = - OperatorStateUtils.getUniqueElement(nodeSplitter, NODE_SPLITTER_STATE_NAME) - .get() - .split( - currentTreeNodes, - layer, - leaves, - OperatorStateUtils.getUniqueElement(splits, SPLITS_STATE_NAME) - .get() - .splits, - indices, - instances); + nodeSplitter.split( + currentTreeNodes, layer, leaves, splits.splits, indices, instances); leavesWriter.set(leaves); layerWriter.set(nextLayer); currentTreeNodesWriter.set(currentTreeNodes); if (nextLayer.isEmpty()) { needInitTreeWriter.set(true); - OperatorStateUtils.getUniqueElement(instanceUpdater, INSTANCE_UPDATER_STATE_NAME) - .get() - .update(leaves, indices, instances, pghWriter::set, currentTreeNodes); + instanceUpdater.update( + pghWriter.get(), leaves, indices, instances, pghWriter::set, currentTreeNodes); leaves.clear(); List> allTrees = allTreesWriter.get(); allTrees.add(currentTreeNodes); @@ -264,11 +282,16 @@ public void onIterationTerminated(Context context, Collector collector) @Override public void processElement(StreamRecord element) throws Exception { - splits.update(Collections.singletonList(element.getValue())); + splits = element.getValue(); + splitsState.update(Collections.singletonList(splits)); } @Override public void close() throws Exception { + splitsState.clear(); + nodeSplitterState.clear(); + instanceUpdaterState.clear(); + pghWriter.remove(); swappedIndicesWriter.remove(); leavesWriter.remove(); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TerminationOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TerminationOperator.java index 7caf81078..dffa53b48 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TerminationOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TerminationOperator.java @@ -38,8 +38,8 @@ public class TerminationOperator extends AbstractStreamOperator private final IterationID iterationID; private final OutputTag modelDataOutputTag; - private IterationSharedStorage.Reader>> allTreesReader; - private IterationSharedStorage.Reader trainContextReader; + private transient IterationSharedStorage.Reader>> allTreesReader; + private transient IterationSharedStorage.Reader trainContextReader; public TerminationOperator( IterationID iterationID, OutputTag modelDataOutputTag) { 
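As a rough sense of scale for HistBuilder's change above (allocating the histogram per
build instead of once at worst-case size), a hypothetical sizing sketch; every number is
invented, and the value of BIN_SIZE is an assumption since the patch does not show it:

    // All numbers are illustrative; BIN_SIZE is assumed to be 4 doubles per bin.
    int numNodes = 8;               // learning nodes in the current layer
    int maxFeatureBins = 32;        // largest bin count over all features
    int numBaggingFeatures = 10;    // features sampled per node
    int totalNumFeatureBins = 200;  // sum of bin counts over all features
    int binSize = 4;                // assumed BIN_SIZE

    int maxNumBins =
            numNodes * Math.min(maxFeatureBins * numBaggingFeatures, totalNumFeatureBins);
    // maxNumBins = 8 * min(320, 200) = 1600 bins.
    double[] hists = new double[maxNumBins * binSize];
    // 1600 * 4 = 6400 doubles, i.e., about 50 KiB allocated for this build, instead of a
    // buffer sized for the worst case of min(2^(maxDepth - 1), maxNumLeaves) nodes.

Since the current layer is usually much smaller than the worst case, this trades a
short-lived allocation per build for a smaller steady-state footprint.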
From 99061bc6d5cc06e47aeab203043ee7fe343d5a88 Mon Sep 17 00:00:00 2001 From: Fan Hong Date: Tue, 21 Feb 2023 11:13:59 +0800 Subject: [PATCH 16/47] Rewrite shared storage. --- .../common/sharedstorage/ItemDescriptor.java | 75 +++++ .../common/sharedstorage/SharedStorage.java | 154 ++++++++++ .../sharedstorage/SharedStorageBody.java | 92 ++++++ .../sharedstorage/SharedStorageContext.java | 80 +++++ .../SharedStorageContextImpl.java | 117 +++++++ .../SharedStorageStreamOperator.java | 38 +++ .../sharedstorage/SharedStorageUtils.java | 86 ++++++ .../ml/common/sharedstorage/StorageID.java | 32 ++ .../AbstractSharedStorageWrapperOperator.java | 284 +++++++++++++++++ .../OneInputSharedStorageWrapperOperator.java | 74 +++++ .../operator/SharedStorageWrapper.java | 113 +++++++ .../TwoInputSharedStorageWrapperOperator.java | 91 ++++++ .../ml/common/gbt/BoostIterationBody.java | 79 +++-- .../apache/flink/ml/common/gbt/GBTRunner.java | 4 +- .../datastorage/IterationSharedStorage.java | 187 ----------- .../ml/common/gbt/defs/LearningNode.java | 4 +- .../flink/ml/common/gbt/defs/Slice.java | 4 +- .../ml/common/gbt/defs/TrainContext.java | 3 +- .../CacheDataCalcLocalHistsOperator.java | 290 +++++++----------- .../operators/CalcLocalSplitsOperator.java | 109 +++---- .../gbt/operators/PostSplitsOperator.java | 270 ++++++---------- .../ml/common/gbt/operators/SharedKeys.java | 63 ---- .../gbt/operators/SharedStorageConstants.java | 161 ++++++++++ .../gbt/operators/TerminationOperator.java | 75 +++-- .../flink/ml/common/gbt/GBTRunnerTest.java | 2 + 25 files changed, 1765 insertions(+), 722 deletions(-) create mode 100644 flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/ItemDescriptor.java create mode 100644 flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorage.java create mode 100644 flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageBody.java create mode 100644 flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageContext.java create mode 100644 flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageContextImpl.java create mode 100644 flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageStreamOperator.java create mode 100644 flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageUtils.java create mode 100644 flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/StorageID.java create mode 100644 flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/operator/AbstractSharedStorageWrapperOperator.java create mode 100644 flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/operator/OneInputSharedStorageWrapperOperator.java create mode 100644 flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/operator/SharedStorageWrapper.java create mode 100644 flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/operator/TwoInputSharedStorageWrapperOperator.java delete mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/datastorage/IterationSharedStorage.java delete mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SharedKeys.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SharedStorageConstants.java diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/ItemDescriptor.java 
b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/ItemDescriptor.java new file mode 100644 index 000000000..8848d2ba4 --- /dev/null +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/ItemDescriptor.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.sharedstorage; + +import org.apache.flink.annotation.Experimental; +import org.apache.flink.api.common.typeutils.TypeSerializer; + +import java.io.Serializable; + +/** + * Descriptor for a shared item. + * + * @param The type of the shared item. + */ +@Experimental +public class ItemDescriptor implements Serializable { + + /** Name of the item. */ + public String key; + + /** Type serializer. */ + public TypeSerializer serializer; + + /** Initialize value. */ + public T initVal; + + private ItemDescriptor(String key, TypeSerializer serializer, T initVal) { + this.key = key; + this.serializer = serializer; + this.initVal = initVal; + } + + public static ItemDescriptor of(String key, TypeSerializer serializer, T initVal) { + return new ItemDescriptor<>(key, serializer, initVal); + } + + @Override + public int hashCode() { + return key.hashCode(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ItemDescriptor that = (ItemDescriptor) o; + return key.equals(that.key); + } + + @Override + public String toString() { + return String.format( + "ItemDescriptor{key='%s', serializer=%s, initVal=%s}", key, serializer, initVal); + } +} diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorage.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorage.java new file mode 100644 index 000000000..591bb9481 --- /dev/null +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorage.java @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.sharedstorage; + +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.iteration.datacache.nonkeyed.ListStateWithCache; +import org.apache.flink.runtime.jobgraph.OperatorID; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.operators.StreamingRuntimeContext; +import org.apache.flink.streaming.runtime.tasks.StreamTask; +import org.apache.flink.util.Preconditions; + +import java.util.Collections; +import java.util.Iterator; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** A shared storage to support access through subtasks of different operators. */ +class SharedStorage { + private static final Map, Object> m = + new ConcurrentHashMap<>(); + + private static final Map, String> owners = + new ConcurrentHashMap<>(); + + /** Gets a {@link Reader} of shared item identified by (storageID, subtaskId, descriptor). */ + static Reader getReader( + StorageID storageID, int subtaskId, ItemDescriptor descriptor) { + return new Reader<>(Tuple3.of(storageID, subtaskId, descriptor.key)); + } + + /** Gets a {@link Writer} of shared item identified by (storageID, subtaskId, key). */ + static Writer getWriter( + StorageID storageID, + int subtaskId, + ItemDescriptor descriptor, + String ownerId, + OperatorID operatorID, + StreamTask containingTask, + StreamingRuntimeContext runtimeContext, + StateInitializationContext stateInitializationContext) { + Tuple3 t = Tuple3.of(storageID, subtaskId, descriptor.key); + String lastOwner = owners.putIfAbsent(t, ownerId); + if (null != lastOwner) { + throw new IllegalStateException( + String.format( + "The shared item (%s, %s, %s) already has a writer %s.", + storageID, subtaskId, descriptor.key, ownerId)); + } + Writer writer = + new Writer<>( + t, + ownerId, + descriptor.serializer, + containingTask, + runtimeContext, + stateInitializationContext, + operatorID); + writer.set(descriptor.initVal); + return writer; + } + + static class Reader { + protected final Tuple3 t; + + Reader(Tuple3 t) { + this.t = t; + } + + T get() { + //noinspection unchecked + return (T) m.get(t); + } + } + + static class Writer extends Reader { + private final String ownerId; + private final ListStateWithCache cache; + + Writer( + Tuple3 t, + String ownerId, + TypeSerializer serializer, + StreamTask containingTask, + StreamingRuntimeContext runtimeContext, + StateInitializationContext stateInitializationContext, + OperatorID operatorID) { + super(t); + this.ownerId = ownerId; + try { + cache = + new ListStateWithCache<>( + serializer, + containingTask, + runtimeContext, + stateInitializationContext, + operatorID); + Iterator iterator = cache.get().iterator(); + if (iterator.hasNext()) { + T value = iterator.next(); + ensureOwner(); + m.put(t, value); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private void ensureOwner() { + // Double-checks the owner, because a writer may call this method after the key removed + // and re-added by other operators. 
+ Preconditions.checkState(owners.get(t).equals(ownerId)); + } + + void set(T value) { + ensureOwner(); + m.put(t, value); + try { + cache.update(Collections.singletonList(value)); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + void remove() { + ensureOwner(); + m.remove(t); + owners.remove(t); + cache.clear(); + } + + void snapshotState(StateSnapshotContext context) throws Exception { + cache.snapshotState(context); + } + } +} diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageBody.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageBody.java new file mode 100644 index 000000000..a7a59b28c --- /dev/null +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageBody.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.sharedstorage; + +import org.apache.flink.annotation.Experimental; +import org.apache.flink.streaming.api.datastream.DataStream; + +import java.io.Serializable; +import java.util.List; +import java.util.Map; + +/** + * The builder of the subgraph that will be executed with a common shared storage. Users can only + * create data streams from {@code inputs}. Users can not refer to data streams outside, and can not + * add sources/sinks. + * + *
<p>
The shared storage body requires all streams accessing the shared storage, i.e., {@link + * SharedStorageBodyResult#accessors} have same parallelism and can be co-located. + */ +@Experimental +@FunctionalInterface +public interface SharedStorageBody extends Serializable { + + /** + * This method creates the subgraph for the shared storage body. + * + * @param inputs Input data streams. + * @return Result of the subgraph, including output data streams, data streams with access to + * the shared storage, and a mapping from share items to their owners. + */ + SharedStorageBodyResult process(List> inputs); + + /** + * The result of a {@link SharedStorageBody}, including output data streams, data streams with + * access to the shared storage, and a mapping from descriptors of share items to their owners. + */ + @Experimental + class SharedStorageBodyResult { + /** A list of output streams. */ + private final List> outputs; + + /** + * A list of data streams which access to the shared storage. All data streams in the list + * should implement {@link SharedStorageStreamOperator}. + */ + private final List> accessors; + + /** + * A mapping from descriptors of shared items to their owners. The owner is specified by + * {@link SharedStorageStreamOperator#getSharedStorageAccessorID()}, which must be kept + * unchanged for an instance of {@link SharedStorageStreamOperator}. + */ + private final Map, String> ownerMap; + + public SharedStorageBodyResult( + List> outputs, + List> accessors, + Map, String> ownerMap) { + this.outputs = outputs; + this.accessors = accessors; + this.ownerMap = ownerMap; + } + + public List> getOutputs() { + return outputs; + } + + public List> getAccessors() { + return accessors; + } + + public Map, String> getOwnerMap() { + return ownerMap; + } + } +} diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageContext.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageContext.java new file mode 100644 index 000000000..534051332 --- /dev/null +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageContext.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.sharedstorage; + +import org.apache.flink.annotation.Experimental; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.StreamOperator; +import org.apache.flink.streaming.api.operators.StreamOperatorStateHandler; +import org.apache.flink.streaming.api.operators.StreamingRuntimeContext; +import org.apache.flink.util.function.BiConsumerWithException; + +/** + * Context for shared storage. Every operator implementing {@link SharedStorageStreamOperator} will + * have an instance of this context set by {@link + * SharedStorageStreamOperator#onSharedStorageContextSet} in runtime. User defined logic can be + * invoked through {@link #invoke} with the access to shared items. + * + *
<p>
NOTE: The corresponding operator must explicitly invoke + * + *
<ul>
  • {@link #initializeState} to initialize this context and possibly restore data items owned + * by itself in {@link StreamOperatorStateHandler.CheckpointedStreamOperator#initializeState}; + *
  • {@link #snapshotState} in order to save data items owned by itself in {@link + * StreamOperatorStateHandler.CheckpointedStreamOperator#snapshotState}; + *
  • {@link #clear()} in order to clear all data items owned by itself in {@link + * StreamOperator#close}. + *
+ */ +@Experimental +public interface SharedStorageContext { + + /** + * Invoke user defined function with provided getters/setters of the shared storage. + * + * @param func User defined function where share items can be accessed through getters/setters. + * @throws Exception Possible exception. + */ + void invoke(BiConsumerWithException func) + throws Exception; + + /** Initializes shared storage context and restores of shared items owned by this operator. */ + & SharedStorageStreamOperator> void initializeState( + T operator, StreamingRuntimeContext runtimeContext, StateInitializationContext context); + + /** Save shared items owned by this operator. */ + void snapshotState(StateSnapshotContext context) throws Exception; + + /** Clear all internal states. */ + void clear(); + + /** Interface of shared item getter. */ + @FunctionalInterface + interface SharedItemGetter { + T get(ItemDescriptor key); + } + + /** Interface of shared item writer. */ + @FunctionalInterface + interface SharedItemSetter { + void set(ItemDescriptor key, T value); + } +} diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageContextImpl.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageContextImpl.java new file mode 100644 index 000000000..434f78ad3 --- /dev/null +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageContextImpl.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.sharedstorage; + +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.StreamingRuntimeContext; +import org.apache.flink.util.Preconditions; +import org.apache.flink.util.function.BiConsumerWithException; + +import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; + +/** Default implementation of {@link SharedStorageContext} using {@link SharedStorage}. 
*/ +@SuppressWarnings("rawtypes") +class SharedStorageContextImpl implements SharedStorageContext, Serializable { + private final StorageID storageID; + private final Map writers = new HashMap<>(); + private final Map readers = new HashMap<>(); + private Map, String> ownerMap; + + public SharedStorageContextImpl() { + this.storageID = new StorageID(); + } + + public void setOwnerMap(Map, String> ownerMap) { + this.ownerMap = ownerMap; + } + + @Override + public void invoke(BiConsumerWithException func) + throws Exception { + func.accept(this::getSharedItem, this::setSharedItem); + } + + private T getSharedItem(ItemDescriptor key) { + //noinspection unchecked + SharedStorage.Reader reader = readers.get(key); + Preconditions.checkState( + null != reader, + String.format( + "The operator requested to read a shared item %s not owned by itself.", + key)); + return reader.get(); + } + + private void setSharedItem(ItemDescriptor key, T value) { + //noinspection unchecked + SharedStorage.Writer writer = writers.get(key); + Preconditions.checkState( + null != writer, + String.format( + "The operator requested to read a shared item %s not owned by itself.", + key)); + writer.set(value); + } + + @Override + public & SharedStorageStreamOperator> void initializeState( + T operator, + StreamingRuntimeContext runtimeContext, + StateInitializationContext context) { + String ownerId = operator.getSharedStorageAccessorID(); + int subtaskId = runtimeContext.getIndexOfThisSubtask(); + for (Map.Entry, String> entry : ownerMap.entrySet()) { + ItemDescriptor descriptor = entry.getKey(); + if (ownerId.equals(entry.getValue())) { + writers.put( + descriptor, + SharedStorage.getWriter( + storageID, + subtaskId, + descriptor, + ownerId, + operator.getOperatorID(), + operator.getContainingTask(), + runtimeContext, + context)); + } + readers.put(descriptor, SharedStorage.getReader(storageID, subtaskId, descriptor)); + } + } + + @Override + public void snapshotState(StateSnapshotContext context) throws Exception { + for (SharedStorage.Writer writer : writers.values()) { + writer.snapshotState(context); + } + } + + @Override + public void clear() { + for (SharedStorage.Writer writer : writers.values()) { + writer.remove(); + } + writers.clear(); + readers.clear(); + } +} diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageStreamOperator.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageStreamOperator.java new file mode 100644 index 000000000..81d964d11 --- /dev/null +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageStreamOperator.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.sharedstorage; + +/** Interface for all operators that need to access the shared storage. */ +public interface SharedStorageStreamOperator { + + /** + * Set the shared storage context in runtime. + * + * @param context The shared storage context. + */ + void onSharedStorageContextSet(SharedStorageContext context); + + /** + * Get a unique ID to represent the operator instance. The ID must be kept unchanged through its + * lifetime. + * + * @return A unique ID. + */ + String getSharedStorageAccessorID(); +} diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageUtils.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageUtils.java new file mode 100644 index 000000000..9fe21d978 --- /dev/null +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageUtils.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.sharedstorage; + +import org.apache.flink.annotation.Experimental; +import org.apache.flink.iteration.compile.DraftExecutionEnvironment; +import org.apache.flink.ml.common.sharedstorage.operator.SharedStorageWrapper; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.util.Preconditions; + +import java.util.ArrayList; +import java.util.List; +import java.util.UUID; +import java.util.stream.Collectors; + +/** Utility class to support {@link SharedStorage} in DataStream. */ +@Experimental +public class SharedStorageUtils { + + /** + * Support read/write access of data in the shared storage from operators which implements + * {@link SharedStorageStreamOperator}. + * + *
<p>
In the shared storage `body`, users build the subgraph with data streams only from + * `inputs`, return streams that have access to the shared storage, and return the mapping from + * shared items to their owners. + * + * @param inputs Input data streams. + * @param body User defined logic to build subgraph and to specify owners of every shared data + * item. + * @return The output data streams. + */ + public static List> withSharedStorage( + List> inputs, SharedStorageBody body) { + Preconditions.checkArgument(inputs.size() > 0); + StreamExecutionEnvironment env = inputs.get(0).getExecutionEnvironment(); + String coLocationID = "shared-storage-" + UUID.randomUUID(); + SharedStorageContextImpl context = new SharedStorageContextImpl(); + + DraftExecutionEnvironment draftEnv = + new DraftExecutionEnvironment(env, new SharedStorageWrapper<>(context)); + List> draftSources = + inputs.stream() + .map( + dataStream -> + draftEnv.addDraftSource(dataStream, dataStream.getType())) + .collect(Collectors.toList()); + SharedStorageBody.SharedStorageBodyResult result = body.process(draftSources); + + List> draftOutputs = result.getOutputs(); + context.setOwnerMap(result.getOwnerMap()); + + for (DataStream draftOutput : draftOutputs) { + draftEnv.addOperator(draftOutput.getTransformation()); + } + draftEnv.copyToActualEnvironment(); + + for (DataStream accessor : result.getAccessors()) { + DataStream ds = draftEnv.getActualStream(accessor.getTransformation().getId()); + ds.getTransformation().setCoLocationGroupKey(coLocationID); + } + + List> outputs = new ArrayList<>(); + for (DataStream draftOutput : draftOutputs) { + outputs.add(draftEnv.getActualStream(draftOutput.getId())); + } + return outputs; + } +} diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/StorageID.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/StorageID.java new file mode 100644 index 000000000..123edcb64 --- /dev/null +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/StorageID.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.sharedstorage; + +import org.apache.flink.util.AbstractID; + +/** ID of a shared storage. 
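+ * A fresh ID is created for every {@link SharedStorageUtils#withSharedStorage} call (via the
+ * per-call context implementation), so item keys are scoped to one shared storage instance.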
*/ +class StorageID extends AbstractID { + private static final long serialVersionUID = 1L; + + public StorageID(byte[] bytes) { + super(bytes); + } + + public StorageID() {} +} diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/operator/AbstractSharedStorageWrapperOperator.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/operator/AbstractSharedStorageWrapperOperator.java new file mode 100644 index 000000000..274fd9a16 --- /dev/null +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/operator/AbstractSharedStorageWrapperOperator.java @@ -0,0 +1,284 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.sharedstorage.operator; + +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.core.memory.ManagedMemoryUseCase; +import org.apache.flink.iteration.IterationListener; +import org.apache.flink.iteration.proxy.state.ProxyStreamOperatorStateContext; +import org.apache.flink.metrics.groups.OperatorMetricGroup; +import org.apache.flink.ml.common.sharedstorage.SharedStorageContext; +import org.apache.flink.ml.common.sharedstorage.SharedStorageStreamOperator; +import org.apache.flink.runtime.checkpoint.CheckpointOptions; +import org.apache.flink.runtime.execution.Environment; +import org.apache.flink.runtime.jobgraph.OperatorID; +import org.apache.flink.runtime.metrics.groups.InternalOperatorIOMetricGroup; +import org.apache.flink.runtime.metrics.groups.UnregisteredMetricGroups; +import org.apache.flink.runtime.state.CheckpointStreamFactory; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.graph.StreamConfig; +import org.apache.flink.streaming.api.operators.InternalTimeServiceManager; +import org.apache.flink.streaming.api.operators.OperatorSnapshotFutures; +import org.apache.flink.streaming.api.operators.Output; +import org.apache.flink.streaming.api.operators.StreamOperator; +import org.apache.flink.streaming.api.operators.StreamOperatorFactory; +import org.apache.flink.streaming.api.operators.StreamOperatorFactoryUtil; +import org.apache.flink.streaming.api.operators.StreamOperatorParameters; +import org.apache.flink.streaming.api.operators.StreamOperatorStateContext; +import org.apache.flink.streaming.api.operators.StreamOperatorStateHandler; +import org.apache.flink.streaming.api.operators.StreamOperatorStateHandler.CheckpointedStreamOperator; +import org.apache.flink.streaming.api.operators.StreamTaskStateInitializer; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.runtime.tasks.StreamTask; +import 
org.apache.flink.util.CloseableIterator; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Objects; +import java.util.Optional; + +/** Base class for the shared storage wrapper operators. */ +abstract class AbstractSharedStorageWrapperOperator> + implements StreamOperator, IterationListener, CheckpointedStreamOperator { + + private static final Logger LOG = + LoggerFactory.getLogger(AbstractSharedStorageWrapperOperator.class); + + protected final StreamOperatorParameters parameters; + + protected final StreamConfig streamConfig; + + protected final StreamTask containingTask; + + protected final Output> output; + + protected final StreamOperatorFactory operatorFactory; + protected final OperatorMetricGroup metrics; + protected final S wrappedOperator; + protected transient StreamOperatorStateHandler stateHandler; + + protected transient InternalTimeServiceManager timeServiceManager; + + @SuppressWarnings({"unchecked", "rawtypes"}) + AbstractSharedStorageWrapperOperator( + StreamOperatorParameters parameters, + StreamOperatorFactory operatorFactory, + SharedStorageContext context) { + this.parameters = Objects.requireNonNull(parameters); + this.streamConfig = Objects.requireNonNull(parameters.getStreamConfig()); + this.containingTask = Objects.requireNonNull(parameters.getContainingTask()); + this.output = Objects.requireNonNull(parameters.getOutput()); + this.operatorFactory = Objects.requireNonNull(operatorFactory); + this.metrics = createOperatorMetricGroup(containingTask.getEnvironment(), streamConfig); + this.wrappedOperator = + (S) + StreamOperatorFactoryUtil.createOperator( + operatorFactory, + (StreamTask) containingTask, + streamConfig, + output, + parameters.getOperatorEventDispatcher()) + .f0; + Preconditions.checkArgument( + wrappedOperator instanceof SharedStorageStreamOperator, + String.format( + "The wrapped operator is not an instance of %s.", + SharedStorageStreamOperator.class.getSimpleName())); + ((SharedStorageStreamOperator) wrappedOperator).onSharedStorageContextSet(context); + } + + private OperatorMetricGroup createOperatorMetricGroup( + Environment environment, StreamConfig streamConfig) { + try { + OperatorMetricGroup operatorMetricGroup = + environment + .getMetricGroup() + .getOrAddOperator( + streamConfig.getOperatorID(), streamConfig.getOperatorName()); + if (streamConfig.isChainEnd()) { + ((InternalOperatorIOMetricGroup) operatorMetricGroup.getIOMetricGroup()) + .reuseOutputMetricsForTask(); + } + return operatorMetricGroup; + } catch (Exception e) { + LOG.warn("An error occurred while instantiating task metrics.", e); + return UnregisteredMetricGroups.createUnregisteredOperatorMetricGroup(); + } + } + + @Override + public void open() throws Exception { + wrappedOperator.open(); + } + + @Override + public void close() throws Exception { + wrappedOperator.close(); + } + + @Override + public void finish() throws Exception { + wrappedOperator.finish(); + } + + @Override + public void prepareSnapshotPreBarrier(long checkpointId) throws Exception { + wrappedOperator.prepareSnapshotPreBarrier(checkpointId); + } + + @Override + public void initializeState(StateInitializationContext stateInitializationContext) + throws Exception {} + + @Override + public void snapshotState(StateSnapshotContext stateSnapshotContext) throws Exception { + if (wrappedOperator instanceof StreamOperatorStateHandler.CheckpointedStreamOperator) { + ((CheckpointedStreamOperator) 
wrappedOperator).snapshotState(stateSnapshotContext); + } + } + + @Override + public OperatorSnapshotFutures snapshotState( + long checkpointId, + long timestamp, + CheckpointOptions checkpointOptions, + CheckpointStreamFactory storageLocation) + throws Exception { + return stateHandler.snapshotState( + this, + Optional.ofNullable(timeServiceManager), + streamConfig.getOperatorName(), + checkpointId, + timestamp, + checkpointOptions, + storageLocation, + false); + } + + @Override + public void initializeState(StreamTaskStateInitializer streamTaskStateManager) + throws Exception { + final TypeSerializer keySerializer = + streamConfig.getStateKeySerializer(containingTask.getUserCodeClassLoader()); + + StreamOperatorStateContext streamOperatorStateContext = + streamTaskStateManager.streamOperatorStateContext( + getOperatorID(), + getClass().getSimpleName(), + parameters.getProcessingTimeService(), + this, + keySerializer, + containingTask.getCancelables(), + metrics, + streamConfig.getManagedMemoryFractionOperatorUseCaseOfSlot( + ManagedMemoryUseCase.STATE_BACKEND, + containingTask + .getEnvironment() + .getTaskManagerInfo() + .getConfiguration(), + containingTask.getUserCodeClassLoader()), + false); + stateHandler = + new StreamOperatorStateHandler( + streamOperatorStateContext, + containingTask.getExecutionConfig(), + containingTask.getCancelables()); + stateHandler.initializeOperatorState(this); + + timeServiceManager = streamOperatorStateContext.internalTimerServiceManager(); + + wrappedOperator.initializeState( + (operatorID, + operatorClassName, + processingTimeService, + keyContext, + keySerializerX, + streamTaskCloseableRegistry, + metricGroup, + managedMemoryFraction, + isUsingCustomRawKeyedState) -> + new ProxyStreamOperatorStateContext( + streamOperatorStateContext, + "wrapped-", + CloseableIterator.empty(), + 0)); + } + + @Override + public void setKeyContextElement1(StreamRecord record) throws Exception { + wrappedOperator.setKeyContextElement1(record); + } + + @Override + public void setKeyContextElement2(StreamRecord record) throws Exception { + wrappedOperator.setKeyContextElement2(record); + } + + @Override + public OperatorMetricGroup getMetricGroup() { + return wrappedOperator.getMetricGroup(); + } + + @Override + public OperatorID getOperatorID() { + return wrappedOperator.getOperatorID(); + } + + @Override + public void notifyCheckpointComplete(long checkpointId) throws Exception { + wrappedOperator.notifyCheckpointComplete(checkpointId); + } + + @Override + public void notifyCheckpointAborted(long checkpointId) throws Exception { + wrappedOperator.notifyCheckpointAborted(checkpointId); + } + + @Override + public Object getCurrentKey() { + return wrappedOperator.getCurrentKey(); + } + + @Override + public void setCurrentKey(Object key) { + wrappedOperator.setCurrentKey(key); + } + + @Override + public void onEpochWatermarkIncremented( + int epochWatermark, Context context, Collector collector) throws Exception { + if (wrappedOperator instanceof IterationListener) { + //noinspection unchecked + ((IterationListener) wrappedOperator) + .onEpochWatermarkIncremented(epochWatermark, context, collector); + } + } + + @Override + public void onIterationTerminated(Context context, Collector collector) throws Exception { + if (wrappedOperator instanceof IterationListener) { + //noinspection unchecked + ((IterationListener) wrappedOperator).onIterationTerminated(context, collector); + } + } +} diff --git 
a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/operator/OneInputSharedStorageWrapperOperator.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/operator/OneInputSharedStorageWrapperOperator.java new file mode 100644 index 000000000..6e4bc0cd4 --- /dev/null +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/operator/OneInputSharedStorageWrapperOperator.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.sharedstorage.operator; + +import org.apache.flink.iteration.operator.OperatorUtils; +import org.apache.flink.ml.common.sharedstorage.SharedStorageContext; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.StreamOperatorFactory; +import org.apache.flink.streaming.api.operators.StreamOperatorParameters; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.runtime.watermarkstatus.WatermarkStatus; + +/** Wrapper for {@link OneInputStreamOperator}. 
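+ * All element, watermark, watermark-status, and latency-marker calls are forwarded to the
+ * wrapped operator; {@code endInput} is forwarded only when the wrapped operator (or its
+ * wrapped UDF) implements {@link BoundedOneInput}.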
*/ +class OneInputSharedStorageWrapperOperator + extends AbstractSharedStorageWrapperOperator> + implements OneInputStreamOperator, BoundedOneInput { + + OneInputSharedStorageWrapperOperator( + StreamOperatorParameters parameters, + StreamOperatorFactory operatorFactory, + SharedStorageContext context) { + super(parameters, operatorFactory, context); + } + + @Override + public void processElement(StreamRecord streamRecord) throws Exception { + wrappedOperator.processElement(streamRecord); + } + + @Override + public void endInput() throws Exception { + OperatorUtils.processOperatorOrUdfIfSatisfy( + wrappedOperator, BoundedOneInput.class, BoundedOneInput::endInput); + } + + @Override + public void processWatermark(Watermark watermark) throws Exception { + wrappedOperator.processWatermark(watermark); + } + + @Override + public void processWatermarkStatus(WatermarkStatus watermarkStatus) throws Exception { + wrappedOperator.processWatermarkStatus(watermarkStatus); + } + + @Override + public void processLatencyMarker(LatencyMarker latencyMarker) throws Exception { + wrappedOperator.processLatencyMarker(latencyMarker); + } + + @Override + public void setKeyContextElement(StreamRecord streamRecord) throws Exception { + wrappedOperator.setKeyContextElement(streamRecord); + } +} diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/operator/SharedStorageWrapper.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/operator/SharedStorageWrapper.java new file mode 100644 index 000000000..153ba8980 --- /dev/null +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/operator/SharedStorageWrapper.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.sharedstorage.operator; + +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.iteration.operator.OperatorUtils; +import org.apache.flink.iteration.operator.OperatorWrapper; +import org.apache.flink.ml.common.sharedstorage.SharedStorageContext; +import org.apache.flink.ml.common.sharedstorage.SharedStorageStreamOperator; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.StreamOperator; +import org.apache.flink.streaming.api.operators.StreamOperatorFactory; +import org.apache.flink.streaming.api.operators.StreamOperatorFactoryUtil; +import org.apache.flink.streaming.api.operators.StreamOperatorParameters; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner; +import org.apache.flink.streaming.runtime.tasks.StreamTask; +import org.apache.flink.util.OutputTag; + +/** The operator wrapper for {@link AbstractSharedStorageWrapperOperator}. */ +public class SharedStorageWrapper implements OperatorWrapper { + + /** Shared storage context. */ + private final SharedStorageContext context; + + public SharedStorageWrapper(SharedStorageContext context) { + this.context = context; + } + + @Override + public StreamOperator wrap( + StreamOperatorParameters operatorParameters, + StreamOperatorFactory operatorFactory) { + Class operatorClass = + operatorFactory.getStreamOperatorClass(getClass().getClassLoader()); + if (SharedStorageStreamOperator.class.isAssignableFrom(operatorClass)) { + if (OneInputStreamOperator.class.isAssignableFrom(operatorClass)) { + return new OneInputSharedStorageWrapperOperator<>( + operatorParameters, operatorFactory, context); + } else if (TwoInputStreamOperator.class.isAssignableFrom(operatorClass)) { + return new TwoInputSharedStorageWrapperOperator<>( + operatorParameters, operatorFactory, context); + } else { + return nowrap(operatorParameters, operatorFactory); + } + } else { + return nowrap(operatorParameters, operatorFactory); + } + } + + public StreamOperator nowrap( + StreamOperatorParameters parameters, StreamOperatorFactory operatorFactory) { + return StreamOperatorFactoryUtil.createOperator( + operatorFactory, + (StreamTask) parameters.getContainingTask(), + OperatorUtils.createWrappedOperatorConfig(parameters.getStreamConfig()), + parameters.getOutput(), + parameters.getOperatorEventDispatcher()) + .f0; + } + + @Override + public Class getStreamOperatorClass( + ClassLoader classLoader, StreamOperatorFactory operatorFactory) { + Class operatorClass = + operatorFactory.getStreamOperatorClass(getClass().getClassLoader()); + if (OneInputStreamOperator.class.isAssignableFrom(operatorClass)) { + return OneInputSharedStorageWrapperOperator.class; + } else if (TwoInputStreamOperator.class.isAssignableFrom(operatorClass)) { + return TwoInputSharedStorageWrapperOperator.class; + } else { + throw new UnsupportedOperationException( + "Unsupported operator class for shared storage wrapper: " + operatorClass); + } + } + + @Override + public KeySelector wrapKeySelector(KeySelector keySelector) { + return keySelector; + } + + @Override + public StreamPartitioner wrapStreamPartitioner(StreamPartitioner streamPartitioner) { + return streamPartitioner; + } + + @Override + public OutputTag wrapOutputTag(OutputTag outputTag) { + return outputTag; + } + + @Override + public TypeInformation 
getWrappedTypeInfo(TypeInformation typeInfo) { + return typeInfo; + } +} diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/operator/TwoInputSharedStorageWrapperOperator.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/operator/TwoInputSharedStorageWrapperOperator.java new file mode 100644 index 000000000..03824a48f --- /dev/null +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/operator/TwoInputSharedStorageWrapperOperator.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.sharedstorage.operator; + +import org.apache.flink.iteration.operator.OperatorUtils; +import org.apache.flink.ml.common.sharedstorage.SharedStorageContext; +import org.apache.flink.streaming.api.operators.BoundedMultiInput; +import org.apache.flink.streaming.api.operators.StreamOperatorFactory; +import org.apache.flink.streaming.api.operators.StreamOperatorParameters; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.runtime.watermarkstatus.WatermarkStatus; + +/** Wrapper for {@link TwoInputStreamOperator}. 
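+ * Analogous to the one-input wrapper: elements, watermarks, and statuses of both inputs are
+ * forwarded, and {@code endInput(int)} is forwarded only when the wrapped operator (or its
+ * wrapped UDF) implements {@link BoundedMultiInput}.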
*/ +class TwoInputSharedStorageWrapperOperator + extends AbstractSharedStorageWrapperOperator> + implements TwoInputStreamOperator, BoundedMultiInput { + + TwoInputSharedStorageWrapperOperator( + StreamOperatorParameters parameters, + StreamOperatorFactory operatorFactory, + SharedStorageContext context) { + super(parameters, operatorFactory, context); + } + + @Override + public void processElement1(StreamRecord streamRecord) throws Exception { + wrappedOperator.processElement1(streamRecord); + } + + @Override + public void processElement2(StreamRecord streamRecord) throws Exception { + wrappedOperator.processElement2(streamRecord); + } + + @Override + public void endInput(int inputId) throws Exception { + OperatorUtils.processOperatorOrUdfIfSatisfy( + wrappedOperator, + BoundedMultiInput.class, + boundedMultipleInput -> boundedMultipleInput.endInput(inputId)); + } + + @Override + public void processWatermark1(Watermark watermark) throws Exception { + wrappedOperator.processWatermark1(watermark); + } + + @Override + public void processWatermark2(Watermark watermark) throws Exception { + wrappedOperator.processWatermark2(watermark); + } + + @Override + public void processLatencyMarker1(LatencyMarker latencyMarker) throws Exception { + wrappedOperator.processLatencyMarker1(latencyMarker); + } + + @Override + public void processLatencyMarker2(LatencyMarker latencyMarker) throws Exception { + wrappedOperator.processLatencyMarker2(latencyMarker); + } + + @Override + public void processWatermarkStatus1(WatermarkStatus watermarkStatus) throws Exception { + wrappedOperator.processWatermarkStatus1(watermarkStatus); + } + + @Override + public void processWatermarkStatus2(WatermarkStatus watermarkStatus) throws Exception { + wrappedOperator.processWatermarkStatus2(watermarkStatus); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BoostIterationBody.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BoostIterationBody.java index faa4acf7a..549deb292 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BoostIterationBody.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BoostIterationBody.java @@ -28,7 +28,6 @@ import org.apache.flink.iteration.DataStreamList; import org.apache.flink.iteration.IterationBody; import org.apache.flink.iteration.IterationBodyResult; -import org.apache.flink.iteration.IterationID; import org.apache.flink.ml.common.gbt.defs.GbtParams; import org.apache.flink.ml.common.gbt.defs.Histogram; import org.apache.flink.ml.common.gbt.defs.Splits; @@ -37,8 +36,12 @@ import org.apache.flink.ml.common.gbt.operators.CalcLocalSplitsOperator; import org.apache.flink.ml.common.gbt.operators.HistogramAggregateFunction; import org.apache.flink.ml.common.gbt.operators.PostSplitsOperator; +import org.apache.flink.ml.common.gbt.operators.SharedStorageConstants; import org.apache.flink.ml.common.gbt.operators.SplitsAggregateFunction; import org.apache.flink.ml.common.gbt.operators.TerminationOperator; +import org.apache.flink.ml.common.sharedstorage.ItemDescriptor; +import org.apache.flink.ml.common.sharedstorage.SharedStorageBody; +import org.apache.flink.ml.common.sharedstorage.SharedStorageUtils; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; import org.apache.flink.types.Row; @@ -46,35 +49,44 @@ import org.apache.commons.lang3.ArrayUtils; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import 
java.util.Map; + /** * Implements iteration body for boosting algorithms. This implementation uses horizontal partition * of data and row-store storage of instances. */ class BoostIterationBody implements IterationBody { - private final IterationID iterationID; private final GbtParams gbtParams; - public BoostIterationBody(IterationID iterationID, GbtParams gbtParams) { - this.iterationID = iterationID; + public BoostIterationBody(GbtParams gbtParams) { this.gbtParams = gbtParams; } - @Override - public IterationBodyResult process(DataStreamList variableStreams, DataStreamList dataStreams) { - DataStream data = dataStreams.get(0); - DataStream trainContext = variableStreams.get(0); + private SharedStorageBody.SharedStorageBodyResult sharedStorageBody( + List> inputs) { + //noinspection unchecked + DataStream data = (DataStream) inputs.get(0); + //noinspection unchecked + DataStream trainContext = (DataStream) inputs.get(1); - final String coLocationKey = "boosting"; + Map, String> ownerMap = new HashMap<>(); // In 1st round, cache all data. For all rounds calculate local histogram based on // current tree layer. + CacheDataCalcLocalHistsOperator cacheDataCalcLocalHistsOp = + new CacheDataCalcLocalHistsOperator(gbtParams); SingleOutputStreamOperator localHists = data.connect(trainContext) .transform( "CacheDataCalcLocalHists", TypeInformation.of(Histogram.class), - new CacheDataCalcLocalHistsOperator(gbtParams, iterationID)); - localHists.getTransformation().setCoLocationGroupKey("coLocationKey"); + cacheDataCalcLocalHistsOp); + for (ItemDescriptor s : SharedStorageConstants.OWNED_BY_CACHE_DATA_CALC_LOCAL_HISTS_OP) { + ownerMap.put(s, cacheDataCalcLocalHistsOp.getSharedStorageAccessorID()); + } DataStream globalHists = scatterReduceHistograms(localHists); @@ -82,34 +94,52 @@ public IterationBodyResult process(DataStreamList variableStreams, DataStreamLis globalHists.transform( "CalcLocalSplits", TypeInformation.of(Splits.class), - new CalcLocalSplitsOperator(iterationID)); - localHists.getTransformation().setCoLocationGroupKey(coLocationKey); + new CalcLocalSplitsOperator()); DataStream globalSplits = localSplits.broadcast().flatMap(new SplitsAggregateFunction()); + PostSplitsOperator postSplitsOp = new PostSplitsOperator(); SingleOutputStreamOperator updatedModelData = globalSplits .broadcast() - .transform( - "PostSplits", - TypeInformation.of(Integer.class), - new PostSplitsOperator(iterationID)); - updatedModelData.getTransformation().setCoLocationGroupKey(coLocationKey); + .transform("PostSplits", TypeInformation.of(Integer.class), postSplitsOp); + for (ItemDescriptor descriptor : SharedStorageConstants.OWNED_BY_POST_SPLITS_OP) { + ownerMap.put(descriptor, postSplitsOp.getSharedStorageAccessorID()); + } - final OutputTag modelDataOutputTag = + final OutputTag finalModelDataOutputTag = new OutputTag<>("model_data", TypeInformation.of(GBTModelData.class)); SingleOutputStreamOperator termination = updatedModelData.transform( - "check_termination", + "CheckTermination", Types.INT, - new TerminationOperator(iterationID, modelDataOutputTag)); - termination.getTransformation().setCoLocationGroupKey(coLocationKey); + new TerminationOperator(finalModelDataOutputTag)); + DataStream finalModelData = + termination.getSideOutput(finalModelDataOutputTag); + + return new SharedStorageBody.SharedStorageBodyResult( + Arrays.asList(updatedModelData, finalModelData, termination), + Arrays.asList(localHists, localSplits, updatedModelData, termination), + ownerMap); + } + + @Override + public 
IterationBodyResult process(DataStreamList variableStreams, DataStreamList dataStreams) { + DataStream data = dataStreams.get(0); + DataStream trainContext = variableStreams.get(0); + + List> outputs = + SharedStorageUtils.withSharedStorage( + Arrays.asList(data, trainContext), this::sharedStorageBody); + DataStream updatedModelData = outputs.get(0); + DataStream finalModelData = outputs.get(1); + DataStream termination = outputs.get(2); return new IterationBodyResult( DataStreamList.of( updatedModelData.flatMap( (d, out) -> {}, TypeInformation.of(TrainContext.class))), - DataStreamList.of(termination.getSideOutput(modelDataOutputTag)), + DataStreamList.of(finalModelData), termination); } @@ -143,8 +173,7 @@ public int partition(Integer key, int numPartitions) { }, new KeySelector, Integer>() { @Override - public Integer getKey(Tuple2 value) - throws Exception { + public Integer getKey(Tuple2 value) { return value.f0; } }) diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTRunner.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTRunner.java index 2630111a1..4d600f94f 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTRunner.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTRunner.java @@ -23,7 +23,6 @@ import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.iteration.DataStreamList; import org.apache.flink.iteration.IterationConfig; -import org.apache.flink.iteration.IterationID; import org.apache.flink.iteration.Iterations; import org.apache.flink.iteration.ReplayableDataStreamList; import org.apache.flink.ml.classification.gbtclassifier.GBTClassifierParams; @@ -114,7 +113,6 @@ private static DataStream boost( }); DataStream data = tEnv.toDataStream(dataTable); - final IterationID iterationID = new IterationID(); DataStreamList dataStreamList = Iterations.iterateBoundedStreamsUntilTermination( DataStreamList.of(initTrainContext.broadcast()), @@ -122,7 +120,7 @@ private static DataStream boost( IterationConfig.newBuilder() .setOperatorLifeCycle(IterationConfig.OperatorLifeCycle.ALL_ROUND) .build(), - new BoostIterationBody(iterationID, p)); + new BoostIterationBody(p)); return dataStreamList.get(0); } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/datastorage/IterationSharedStorage.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/datastorage/IterationSharedStorage.java deleted file mode 100644 index e9922bf72..000000000 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/datastorage/IterationSharedStorage.java +++ /dev/null @@ -1,187 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.flink.ml.common.gbt.datastorage; - -import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.api.java.tuple.Tuple3; -import org.apache.flink.core.memory.DataInputViewStreamWrapper; -import org.apache.flink.core.memory.DataOutputViewStreamWrapper; -import org.apache.flink.iteration.IterationID; -import org.apache.flink.runtime.jobgraph.OperatorID; -import org.apache.flink.runtime.state.StateInitializationContext; -import org.apache.flink.runtime.state.StatePartitionStreamProvider; -import org.apache.flink.runtime.state.StateSnapshotContext; -import org.apache.flink.util.Preconditions; - -import org.apache.commons.collections.IteratorUtils; - -import java.util.List; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; - -/** A shared storage across subtasks of different operators. */ -public class IterationSharedStorage { - private static final Map, Object> m = - new ConcurrentHashMap<>(); - - private static final Map, OperatorID> owners = - new ConcurrentHashMap<>(); - - /** - * Gets a {@link Reader} of shared data identified by (iterationId, subtaskId, key). - * - * @param iterationID The iteration ID. - * @param subtaskId The subtask ID. - * @param key The string key. - * @return A {@link Reader} of shared data. - * @param The type of shared ata. - */ - public static Reader getReader(IterationID iterationID, int subtaskId, String key) { - return new Reader<>(Tuple3.of(iterationID, subtaskId, key)); - } - - /** - * Gets a {@link Writer} of shared data identified by (iterationId, subtaskId, key). - * - * @param iterationID The iteration ID. - * @param subtaskId The subtask ID. - * @param key The string key. - * @param operatorID The owner operator. - * @param serializer Serializer of the data. - * @param initVal Initialize value of the data. - * @return A {@link Writer} of shared data. - * @param The type of shared ata. - */ - public static Writer getWriter( - IterationID iterationID, - int subtaskId, - String key, - OperatorID operatorID, - TypeSerializer serializer, - T initVal) { - Tuple3 t = Tuple3.of(iterationID, subtaskId, key); - OperatorID lastOwner = owners.putIfAbsent(t, operatorID); - if (null != lastOwner) { - throw new IllegalStateException( - String.format( - "The shared data (%s, %s, %s) already has a writer %s.", - iterationID, subtaskId, key, operatorID)); - } - Writer writer = new Writer<>(t, operatorID, serializer); - writer.set(initVal); - return writer; - } - - /** - * A reader of shared data identified by key (IterationID, subtaskID, key). - * - * @param The type of shared ata. - */ - public static class Reader { - protected final Tuple3 t; - - public Reader(Tuple3 t) { - this.t = t; - } - - /** - * Get the value. - * - * @return The value. - */ - public T get() { - //noinspection unchecked - return (T) m.get(t); - } - } - - /** - * A writer of shared data identified by key (IterationID, subtaskID, key). A writer is - * responsible for the checkpointing of data. - * - * @param The type of shared ata. - */ - public static class Writer extends Reader { - private final OperatorID operatorID; - private final TypeSerializer serializer; - - public Writer( - Tuple3 t, - OperatorID operatorID, - TypeSerializer serializer) { - super(t); - this.operatorID = operatorID; - this.serializer = serializer; - } - - private void ensureOwner() { - // Double-checks the owner, because a writer may call this method after the key removed - // and re-added by other operators. 
- Preconditions.checkState(owners.get(t).equals(operatorID)); - } - - /** - * Set new value. - * - * @param value The new value. - */ - public void set(T value) { - ensureOwner(); - m.put(t, value); - } - - /** Remove this data entry. */ - public void remove() { - ensureOwner(); - m.remove(t); - owners.remove(t); - } - - /** - * Initialize the state. - * - * @param context The state initialization context. - * @throws Exception - */ - public void initializeState(StateInitializationContext context) throws Exception { - //noinspection unchecked - List inputs = - IteratorUtils.toList(context.getRawOperatorStateInputs().iterator()); - Preconditions.checkState( - inputs.size() < 2, "The input from raw operator state should be one or zero."); - if (inputs.size() > 0) { - T value = - serializer.deserialize( - new DataInputViewStreamWrapper(inputs.get(0).getStream())); - set(value); - } - } - - /** - * Snapshot the state. - * - * @param context The state snapshot context. - * @throws Exception - */ - public void snapshotState(StateSnapshotContext context) throws Exception { - serializer.serialize( - get(), new DataOutputViewStreamWrapper(context.getRawOperatorStateOutput())); - } - } -} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/LearningNode.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/LearningNode.java index 3afb8b05c..71a54c333 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/LearningNode.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/LearningNode.java @@ -18,8 +18,10 @@ package org.apache.flink.ml.common.gbt.defs; +import java.io.Serializable; + /** A node used in learning procedure. */ -public class LearningNode { +public class LearningNode implements Serializable { // The node index in `currentTreeNodes` used in `PostSplitsOperator`. public int nodeIndex; diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Slice.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Slice.java index e29a7dc39..4c2a1acef 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Slice.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Slice.java @@ -18,8 +18,10 @@ package org.apache.flink.ml.common.gbt.defs; +import java.io.Serializable; + /** Represents a slice of an indexable linear structure, like an array. */ -public final class Slice { +public final class Slice implements Serializable { public int start; public int end; diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/TrainContext.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/TrainContext.java index 2364ab203..32ed86bc3 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/TrainContext.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/TrainContext.java @@ -21,10 +21,11 @@ import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.ml.common.gbt.loss.Loss; +import java.io.Serializable; import java.util.Random; /** Stores the training context. 
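 * (Now serializable; presumably needed because instances are held in the shared storage and
 * pass through its snapshot/serialization path, per the accompanying operator changes.)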
*/ -public class TrainContext { +public class TrainContext implements Serializable { public int subtaskId; public int numSubtasks; public GbtParams params; diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java index 1e6ad9ea1..db004df4e 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java @@ -18,15 +18,10 @@ package org.apache.flink.ml.common.gbt.operators; -import org.apache.flink.api.common.typeutils.base.BooleanSerializer; -import org.apache.flink.api.common.typeutils.base.GenericArraySerializer; -import org.apache.flink.api.common.typeutils.base.array.IntPrimitiveArraySerializer; import org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer; -import org.apache.flink.iteration.IterationID; import org.apache.flink.iteration.IterationListener; import org.apache.flink.iteration.datacache.nonkeyed.ListStateWithCache; import org.apache.flink.iteration.operator.OperatorStateUtils; -import org.apache.flink.ml.common.gbt.datastorage.IterationSharedStorage; import org.apache.flink.ml.common.gbt.defs.BinnedInstance; import org.apache.flink.ml.common.gbt.defs.GbtParams; import org.apache.flink.ml.common.gbt.defs.Histogram; @@ -35,7 +30,8 @@ import org.apache.flink.ml.common.gbt.defs.TrainContext; import org.apache.flink.ml.common.gbt.loss.Loss; import org.apache.flink.ml.common.gbt.typeinfo.BinnedInstanceSerializer; -import org.apache.flink.ml.common.gbt.typeinfo.LearningNodeSerializer; +import org.apache.flink.ml.common.sharedstorage.SharedStorageContext; +import org.apache.flink.ml.common.sharedstorage.SharedStorageStreamOperator; import org.apache.flink.ml.linalg.SparseVector; import org.apache.flink.ml.linalg.Vector; import org.apache.flink.runtime.state.StateInitializationContext; @@ -52,6 +48,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.UUID; /** * Calculates local histograms for local data partition. Specifically in the first round, this @@ -59,13 +56,14 @@ */ public class CacheDataCalcLocalHistsOperator extends AbstractStreamOperator implements TwoInputStreamOperator, - IterationListener { + IterationListener, + SharedStorageStreamOperator { private static final String TREE_INITIALIZER_STATE_NAME = "tree_initializer"; private static final String HIST_BUILDER_STATE_NAME = "hist_builder"; private final GbtParams gbtParams; - private final IterationID iterationID; + private final String sharedStorageAccessorID; // States of local data. private transient ListStateWithCache instancesCollecting; @@ -73,23 +71,12 @@ public class CacheDataCalcLocalHistsOperator extends AbstractStreamOperator histBuilderState; private transient HistBuilder histBuilder; + private transient SharedStorageContext sharedStorageContext; - // Readers/writers of shared data. 
- private transient IterationSharedStorage.Writer instancesWriter; - private transient IterationSharedStorage.Reader pghReader; - private transient IterationSharedStorage.Writer shuffledIndicesWriter; - private transient IterationSharedStorage.Reader swappedIndicesReader; - private transient IterationSharedStorage.Writer nodeFeaturePairsWriter; - private transient IterationSharedStorage.Reader> layerReader; - private transient IterationSharedStorage.Writer rootLearningNodeWriter; - private transient IterationSharedStorage.Reader needInitTreeReader; - private transient IterationSharedStorage.Writer hasInitedTreeWriter; - private transient IterationSharedStorage.Writer trainContextWriter; - - public CacheDataCalcLocalHistsOperator(GbtParams gbtParams, IterationID iterationID) { + public CacheDataCalcLocalHistsOperator(GbtParams gbtParams) { super(); this.gbtParams = gbtParams; - this.iterationID = iterationID; + sharedStorageAccessorID = getClass().getSimpleName() + "-" + UUID.randomUUID(); } @Override @@ -125,78 +112,7 @@ public void initializeState(StateInitializationContext context) throws Exception OperatorStateUtils.getUniqueElement(histBuilderState, HIST_BUILDER_STATE_NAME) .orElse(null); - int subtaskId = getRuntimeContext().getIndexOfThisSubtask(); - instancesWriter = - IterationSharedStorage.getWriter( - iterationID, - subtaskId, - SharedKeys.INSTANCES, - getOperatorID(), - new GenericArraySerializer<>( - BinnedInstance.class, BinnedInstanceSerializer.INSTANCE), - new BinnedInstance[0]); - instancesWriter.initializeState(context); - - shuffledIndicesWriter = - IterationSharedStorage.getWriter( - iterationID, - subtaskId, - SharedKeys.SHUFFLED_INDICES, - getOperatorID(), - IntPrimitiveArraySerializer.INSTANCE, - new int[0]); - shuffledIndicesWriter.initializeState(context); - - nodeFeaturePairsWriter = - IterationSharedStorage.getWriter( - iterationID, - subtaskId, - SharedKeys.NODE_FEATURE_PAIRS, - getOperatorID(), - IntPrimitiveArraySerializer.INSTANCE, - new int[0]); - nodeFeaturePairsWriter.initializeState(context); - - rootLearningNodeWriter = - IterationSharedStorage.getWriter( - iterationID, - subtaskId, - SharedKeys.ROOT_LEARNING_NODE, - getOperatorID(), - LearningNodeSerializer.INSTANCE, - new LearningNode()); - rootLearningNodeWriter.initializeState(context); - - hasInitedTreeWriter = - IterationSharedStorage.getWriter( - iterationID, - subtaskId, - SharedKeys.HAS_INITED_TREE, - getOperatorID(), - BooleanSerializer.INSTANCE, - false); - hasInitedTreeWriter.initializeState(context); - - trainContextWriter = - IterationSharedStorage.getWriter( - iterationID, - subtaskId, - SharedKeys.TRAIN_CONTEXT, - getOperatorID(), - new KryoSerializer<>(TrainContext.class, getExecutionConfig()), - new TrainContext()); - trainContextWriter.initializeState(context); - - this.pghReader = - IterationSharedStorage.getReader( - iterationID, subtaskId, SharedKeys.PREDS_GRADS_HESSIANS); - this.swappedIndicesReader = - IterationSharedStorage.getReader( - iterationID, subtaskId, SharedKeys.SWAPPED_INDICES); - this.layerReader = - IterationSharedStorage.getReader(iterationID, subtaskId, SharedKeys.LAYER); - this.needInitTreeReader = - IterationSharedStorage.getReader(iterationID, subtaskId, SharedKeys.NEED_INIT_TREE); + sharedStorageContext.initializeState(this, getRuntimeContext(), context); } @Override @@ -205,13 +121,7 @@ public void snapshotState(StateSnapshotContext context) throws Exception { instancesCollecting.snapshotState(context); treeInitializerState.snapshotState(context); 
histBuilderState.snapshotState(context); - - instancesWriter.snapshotState(context); - shuffledIndicesWriter.snapshotState(context); - nodeFeaturePairsWriter.snapshotState(context); - rootLearningNodeWriter.snapshotState(context); - hasInitedTreeWriter.snapshotState(context); - trainContextWriter.snapshotState(context); + sharedStorageContext.snapshotState(context); } @Override @@ -237,92 +147,112 @@ public void processElement1(StreamRecord streamRecord) throws Exception { @Override public void processElement2(StreamRecord streamRecord) throws Exception { - TrainContext trainContext = streamRecord.getValue(); - if (null != trainContext) { - // Not null only in first round. - trainContextWriter.set(trainContext); - } + TrainContext rawTrainContext = streamRecord.getValue(); + sharedStorageContext.invoke( + (getter, setter) -> + setter.set(SharedStorageConstants.TRAIN_CONTEXT, rawTrainContext)); } - @SuppressWarnings("OptionalGetWithoutIsPresent") @Override public void onEpochWatermarkIncremented( int epochWatermark, Context context, Collector out) throws Exception { if (0 == epochWatermark) { // Initializes local state in first round. - instancesWriter.set( - (BinnedInstance[]) - IteratorUtils.toArray( - instancesCollecting.get().iterator(), BinnedInstance.class)); - instancesCollecting.clear(); - TrainContext trainContext = - new TrainContextInitializer(gbtParams) - .init( - trainContextWriter.get(), - getRuntimeContext().getIndexOfThisSubtask(), - getRuntimeContext().getNumberOfParallelSubtasks(), - instancesWriter.get()); - trainContextWriter.set(trainContext); - - treeInitializer = new TreeInitializer(trainContext); - treeInitializerState.update(Collections.singletonList(treeInitializer)); - histBuilder = new HistBuilder(trainContext); - histBuilderState.update(Collections.singletonList(histBuilder)); - } - - TrainContext trainContext = trainContextWriter.get(); - BinnedInstance[] instances = instancesWriter.get(); - Preconditions.checkArgument( - getRuntimeContext().getIndexOfThisSubtask() == trainContext.subtaskId); - PredGradHess[] pgh = pghReader.get(); - - // In the first round, use prior as the predictions. - if (0 == pgh.length) { - pgh = new PredGradHess[instances.length]; - double prior = trainContext.prior; - Loss loss = trainContext.loss; - for (int i = 0; i < instances.length; i += 1) { - double label = instances[i].label; - pgh[i] = - new PredGradHess( - prior, loss.gradient(prior, label), loss.hessian(prior, label)); - } + sharedStorageContext.invoke( + (getter, setter) -> { + BinnedInstance[] instances = + (BinnedInstance[]) + IteratorUtils.toArray( + instancesCollecting.get().iterator(), + BinnedInstance.class); + setter.set(SharedStorageConstants.INSTANCES, instances); + instancesCollecting.clear(); + + TrainContext rawTrainContext = + getter.get(SharedStorageConstants.TRAIN_CONTEXT); + TrainContext trainContext = + new TrainContextInitializer(gbtParams) + .init( + rawTrainContext, + getRuntimeContext().getIndexOfThisSubtask(), + getRuntimeContext().getNumberOfParallelSubtasks(), + instances); + setter.set(SharedStorageConstants.TRAIN_CONTEXT, trainContext); + + treeInitializer = new TreeInitializer(trainContext); + treeInitializerState.update(Collections.singletonList(treeInitializer)); + histBuilder = new HistBuilder(trainContext); + histBuilderState.update(Collections.singletonList(histBuilder)); + }); } - int[] indices; - if (needInitTreeReader.get()) { - // When last tree is finished, initializes a new tree, and shuffle instance indices. 
- treeInitializer.init(shuffledIndicesWriter::set); - LearningNode rootLearningNode = treeInitializer.getRootLearningNode(); - indices = shuffledIndicesWriter.get(); - rootLearningNodeWriter.set(rootLearningNode); - hasInitedTreeWriter.set(true); - } else { - // Otherwise, uses the swapped instance indices. - shuffledIndicesWriter.set(new int[0]); - indices = swappedIndicesReader.get(); - hasInitedTreeWriter.set(false); - } - - List<LearningNode> layer = layerReader.get(); - if (layer.size() == 0) { - layer = Collections.singletonList(rootLearningNodeWriter.get()); - } - - Histogram localHists = - histBuilder.build(layer, indices, instances, pgh, nodeFeaturePairsWriter::set); - out.collect(localHists); + sharedStorageContext.invoke( + (getter, setter) -> { + TrainContext trainContext = getter.get(SharedStorageConstants.TRAIN_CONTEXT); + Preconditions.checkArgument( + getRuntimeContext().getIndexOfThisSubtask() == trainContext.subtaskId); + BinnedInstance[] instances = getter.get(SharedStorageConstants.INSTANCES); + PredGradHess[] pgh = getter.get(SharedStorageConstants.PREDS_GRADS_HESSIANS); + // In the first round, uses the prior as the predictions. + if (0 == pgh.length) { + pgh = new PredGradHess[instances.length]; + double prior = trainContext.prior; + Loss loss = trainContext.loss; + for (int i = 0; i < instances.length; i += 1) { + double label = instances[i].label; + pgh[i] = + new PredGradHess( + prior, + loss.gradient(prior, label), + loss.hessian(prior, label)); + } + } + + boolean needInitTree = getter.get(SharedStorageConstants.NEED_INIT_TREE); + int[] indices; + List<LearningNode> layer; + if (needInitTree) { + // When the last tree is finished, initializes a new tree and shuffles instance + // indices. + treeInitializer.init( + d -> setter.set(SharedStorageConstants.SHUFFLED_INDICES, d)); + LearningNode rootLearningNode = treeInitializer.getRootLearningNode(); + indices = getter.get(SharedStorageConstants.SHUFFLED_INDICES); + layer = Collections.singletonList(rootLearningNode); + setter.set(SharedStorageConstants.ROOT_LEARNING_NODE, rootLearningNode); + setter.set(SharedStorageConstants.HAS_INITED_TREE, true); + } else { + // Otherwise, uses the swapped instance indices. 
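+                        // (The swapped indices are written by PostSplitsOperator when it splits
+                        // the nodes of the current layer.)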
+ indices = getter.get(SharedStorageConstants.SWAPPED_INDICES); + layer = getter.get(SharedStorageConstants.LAYER); + setter.set(SharedStorageConstants.SHUFFLED_INDICES, new int[0]); + setter.set(SharedStorageConstants.HAS_INITED_TREE, false); + } + + Histogram localHists = + histBuilder.build( + layer, + indices, + instances, + pgh, + d -> setter.set(SharedStorageConstants.NODE_FEATURE_PAIRS, d)); + out.collect(localHists); + }); } @Override - public void onIterationTerminated(Context context, Collector collector) { + public void onIterationTerminated(Context context, Collector collector) + throws Exception { instancesCollecting.clear(); treeInitializerState.clear(); histBuilderState.clear(); - instancesWriter.set(new BinnedInstance[0]); - shuffledIndicesWriter.set(new int[0]); - nodeFeaturePairsWriter.set(new int[0]); + sharedStorageContext.invoke( + (getter, setter) -> { + setter.set(SharedStorageConstants.INSTANCES, new BinnedInstance[0]); + setter.set(SharedStorageConstants.SHUFFLED_INDICES, new int[0]); + setter.set(SharedStorageConstants.NODE_FEATURE_PAIRS, new int[0]); + }); } @Override @@ -330,13 +260,17 @@ public void close() throws Exception { instancesCollecting.clear(); treeInitializerState.clear(); histBuilderState.clear(); - - instancesWriter.remove(); - shuffledIndicesWriter.remove(); - nodeFeaturePairsWriter.remove(); - rootLearningNodeWriter.remove(); - hasInitedTreeWriter.remove(); - trainContextWriter.remove(); + sharedStorageContext.clear(); super.close(); } + + @Override + public void onSharedStorageContextSet(SharedStorageContext context) { + this.sharedStorageContext = context; + } + + @Override + public String getSharedStorageAccessorID() { + return sharedStorageAccessorID; + } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CalcLocalSplitsOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CalcLocalSplitsOperator.java index dcc68ee3b..7750c5fa6 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CalcLocalSplitsOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CalcLocalSplitsOperator.java @@ -19,16 +19,14 @@ package org.apache.flink.ml.common.gbt.operators; import org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer; -import org.apache.flink.iteration.IterationID; import org.apache.flink.iteration.IterationListener; import org.apache.flink.iteration.datacache.nonkeyed.ListStateWithCache; import org.apache.flink.iteration.operator.OperatorStateUtils; -import org.apache.flink.ml.common.gbt.datastorage.IterationSharedStorage; import org.apache.flink.ml.common.gbt.defs.Histogram; import org.apache.flink.ml.common.gbt.defs.LearningNode; import org.apache.flink.ml.common.gbt.defs.Splits; -import org.apache.flink.ml.common.gbt.defs.TrainContext; -import org.apache.flink.ml.common.gbt.typeinfo.HistogramSerializer; +import org.apache.flink.ml.common.sharedstorage.SharedStorageContext; +import org.apache.flink.ml.common.sharedstorage.SharedStorageStreamOperator; import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.StateSnapshotContext; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; @@ -38,31 +36,23 @@ import java.util.Collections; import java.util.List; +import java.util.UUID; /** Calculates local splits for assigned (nodeId, featureId) pairs. 
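+ * The pairs are read from the shared item {@code SharedStorageConstants.NODE_FEATURE_PAIRS},
+ * an {@code int[]} that presumably flattens the pairs as
+ * {@code [nodeId0, featureId0, nodeId1, featureId1, ...]}.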
*/ public class CalcLocalSplitsOperator extends AbstractStreamOperator - implements OneInputStreamOperator, IterationListener { + implements OneInputStreamOperator, + IterationListener, + SharedStorageStreamOperator { private static final String SPLIT_FINDER_STATE_NAME = "split_finder"; - private static final String HISTOGRAM_STATE_NAME = "histogram"; - - private final IterationID iterationID; - + private final String sharedStorageAccessorID; // States of local data. private transient ListStateWithCache splitFinderState; private transient SplitFinder splitFinder; - private transient ListStateWithCache histogramState; - private transient Histogram histogram; - - // Readers/writers of shared data. - private transient IterationSharedStorage.Reader nodeFeaturePairsReader; - private transient IterationSharedStorage.Reader> leavesReader; - private transient IterationSharedStorage.Reader> layerReader; - private transient IterationSharedStorage.Reader rootLearningNodeReader; - private transient IterationSharedStorage.Reader trainContextReader; + private transient SharedStorageContext sharedStorageContext; - public CalcLocalSplitsOperator(IterationID iterationID) { - this.iterationID = iterationID; + public CalcLocalSplitsOperator() { + sharedStorageAccessorID = getClass().getSimpleName() + "-" + UUID.randomUUID(); } @Override @@ -79,66 +69,61 @@ public void initializeState(StateInitializationContext context) throws Exception OperatorStateUtils.getUniqueElement(splitFinderState, SPLIT_FINDER_STATE_NAME) .orElse(null); - histogramState = - new ListStateWithCache<>( - new HistogramSerializer(), - getContainingTask(), - getRuntimeContext(), - context, - getOperatorID()); - histogram = - OperatorStateUtils.getUniqueElement(histogramState, HISTOGRAM_STATE_NAME) - .orElse(null); - - int subtaskId = getRuntimeContext().getIndexOfThisSubtask(); - nodeFeaturePairsReader = - IterationSharedStorage.getReader( - iterationID, subtaskId, SharedKeys.NODE_FEATURE_PAIRS); - leavesReader = IterationSharedStorage.getReader(iterationID, subtaskId, SharedKeys.LEAVES); - layerReader = IterationSharedStorage.getReader(iterationID, subtaskId, SharedKeys.LAYER); - rootLearningNodeReader = - IterationSharedStorage.getReader( - iterationID, subtaskId, SharedKeys.ROOT_LEARNING_NODE); - trainContextReader = - IterationSharedStorage.getReader(iterationID, subtaskId, SharedKeys.TRAIN_CONTEXT); + sharedStorageContext.initializeState(this, getRuntimeContext(), context); } @Override public void snapshotState(StateSnapshotContext context) throws Exception { super.snapshotState(context); splitFinderState.snapshotState(context); - histogramState.snapshotState(context); } - @SuppressWarnings("OptionalGetWithoutIsPresent") @Override public void onEpochWatermarkIncremented( - int epochWatermark, Context context, Collector collector) throws Exception { - if (0 == epochWatermark) { - splitFinder = new SplitFinder(trainContextReader.get()); - splitFinderState.update(Collections.singletonList(splitFinder)); - } + int epochWatermark, Context context, Collector collector) {} - List layer = layerReader.get(); - if (layer.size() == 0) { - layer = Collections.singletonList(rootLearningNodeReader.get()); + @Override + public void processElement(StreamRecord element) throws Exception { + if (null == splitFinder) { + sharedStorageContext.invoke( + (getter, setter) -> { + splitFinder = + new SplitFinder(getter.get(SharedStorageConstants.TRAIN_CONTEXT)); + splitFinderState.update(Collections.singletonList(splitFinder)); + }); } - Splits splits = - 
splitFinder.calc( - layer, nodeFeaturePairsReader.get(), leavesReader.get().size(), histogram); - collector.collect(splits); + Histogram histogram = element.getValue(); + sharedStorageContext.invoke( + (getter, setter) -> { + List layer = getter.get(SharedStorageConstants.LAYER); + if (layer.size() == 0) { + layer = + Collections.singletonList( + getter.get(SharedStorageConstants.ROOT_LEARNING_NODE)); + } + Splits splits = + splitFinder.calc( + layer, + getter.get(SharedStorageConstants.NODE_FEATURE_PAIRS), + getter.get(SharedStorageConstants.LEAVES).size(), + histogram); + output.collect(new StreamRecord<>(splits)); + }); } @Override - public void processElement(StreamRecord element) throws Exception { - histogram = element.getValue(); - histogramState.update(Collections.singletonList(histogram)); + public void onIterationTerminated(Context context, Collector collector) { + splitFinderState.clear(); } @Override - public void onIterationTerminated(Context context, Collector collector) { - splitFinderState.clear(); - histogramState.clear(); + public void onSharedStorageContextSet(SharedStorageContext context) { + this.sharedStorageContext = context; + } + + @Override + public String getSharedStorageAccessorID() { + return sharedStorageAccessorID; } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/PostSplitsOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/PostSplitsOperator.java index 72e9b11cc..1668a87a2 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/PostSplitsOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/PostSplitsOperator.java @@ -18,25 +18,18 @@ package org.apache.flink.ml.common.gbt.operators; -import org.apache.flink.api.common.typeutils.base.BooleanSerializer; -import org.apache.flink.api.common.typeutils.base.GenericArraySerializer; -import org.apache.flink.api.common.typeutils.base.ListSerializer; -import org.apache.flink.api.common.typeutils.base.array.IntPrimitiveArraySerializer; import org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer; -import org.apache.flink.iteration.IterationID; import org.apache.flink.iteration.IterationListener; import org.apache.flink.iteration.datacache.nonkeyed.ListStateWithCache; import org.apache.flink.iteration.operator.OperatorStateUtils; -import org.apache.flink.ml.common.gbt.datastorage.IterationSharedStorage; import org.apache.flink.ml.common.gbt.defs.BinnedInstance; import org.apache.flink.ml.common.gbt.defs.LearningNode; import org.apache.flink.ml.common.gbt.defs.Node; import org.apache.flink.ml.common.gbt.defs.PredGradHess; import org.apache.flink.ml.common.gbt.defs.Splits; import org.apache.flink.ml.common.gbt.defs.TrainContext; -import org.apache.flink.ml.common.gbt.typeinfo.LearningNodeSerializer; -import org.apache.flink.ml.common.gbt.typeinfo.NodeSerializer; -import org.apache.flink.ml.common.gbt.typeinfo.PredGradHessSerializer; +import org.apache.flink.ml.common.sharedstorage.SharedStorageContext; +import org.apache.flink.ml.common.sharedstorage.SharedStorageStreamOperator; import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.StateSnapshotContext; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; @@ -47,20 +40,22 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.UUID; /** * Post-process after global splits obtained, including split instances to left or child 
nodes, and * update instances scores after a tree is complete. */ public class PostSplitsOperator extends AbstractStreamOperator - implements OneInputStreamOperator, IterationListener { + implements OneInputStreamOperator, + IterationListener, + SharedStorageStreamOperator { private static final String SPLITS_STATE_NAME = "splits"; private static final String NODE_SPLITTER_STATE_NAME = "node_splitter"; private static final String INSTANCE_UPDATER_STATE_NAME = "instance_updater"; - private static final String CURRENT_TREE_NODES_STATE_NAME = "current_tree_nodes"; - private final IterationID iterationID; + private final String sharedStorageAccessorID; // States of local data. private transient ListStateWithCache splitsState; @@ -69,22 +64,10 @@ public class PostSplitsOperator extends AbstractStreamOperator private transient NodeSplitter nodeSplitter; private transient ListStateWithCache instanceUpdaterState; private transient InstanceUpdater instanceUpdater; + private transient SharedStorageContext sharedStorageContext; - // Readers/writers of shared data. - private transient IterationSharedStorage.Reader instancesReader; - private transient IterationSharedStorage.Writer pghWriter; - private transient IterationSharedStorage.Reader shuffledIndicesReader; - private transient IterationSharedStorage.Writer swappedIndicesWriter; - private transient IterationSharedStorage.Writer> leavesWriter; - private transient IterationSharedStorage.Writer> layerWriter; - private transient IterationSharedStorage.Reader rootLearningNodeReader; - private transient IterationSharedStorage.Writer>> allTreesWriter; - private transient IterationSharedStorage.Writer> currentTreeNodesWriter; - private transient IterationSharedStorage.Writer needInitTreeWriter; - private transient IterationSharedStorage.Reader trainContextReader; - - public PostSplitsOperator(IterationID iterationID) { - this.iterationID = iterationID; + public PostSplitsOperator() { + sharedStorageAccessorID = getClass().getSimpleName() + "-" + UUID.randomUUID(); } @Override @@ -121,86 +104,7 @@ public void initializeState(StateInitializationContext context) throws Exception instanceUpdaterState, INSTANCE_UPDATER_STATE_NAME) .orElse(null); - int subtaskId = getRuntimeContext().getIndexOfThisSubtask(); - pghWriter = - IterationSharedStorage.getWriter( - iterationID, - subtaskId, - SharedKeys.PREDS_GRADS_HESSIANS, - getOperatorID(), - new GenericArraySerializer<>( - PredGradHess.class, PredGradHessSerializer.INSTANCE), - new PredGradHess[0]); - pghWriter.initializeState(context); - swappedIndicesWriter = - IterationSharedStorage.getWriter( - iterationID, - subtaskId, - SharedKeys.SWAPPED_INDICES, - getOperatorID(), - IntPrimitiveArraySerializer.INSTANCE, - new int[0]); - swappedIndicesWriter.initializeState(context); - leavesWriter = - IterationSharedStorage.getWriter( - iterationID, - subtaskId, - SharedKeys.LEAVES, - getOperatorID(), - new ListSerializer<>(LearningNodeSerializer.INSTANCE), - new ArrayList<>()); - leavesWriter.initializeState(context); - - layerWriter = - IterationSharedStorage.getWriter( - iterationID, - subtaskId, - SharedKeys.LAYER, - getOperatorID(), - new ListSerializer<>(LearningNodeSerializer.INSTANCE), - new ArrayList<>()); - layerWriter.initializeState(context); - - allTreesWriter = - IterationSharedStorage.getWriter( - iterationID, - subtaskId, - SharedKeys.ALL_TREES, - getOperatorID(), - new ListSerializer<>(new ListSerializer<>(NodeSerializer.INSTANCE)), - new ArrayList<>()); - allTreesWriter.initializeState(context); - - 
needInitTreeWriter = - IterationSharedStorage.getWriter( - iterationID, - subtaskId, - SharedKeys.NEED_INIT_TREE, - getOperatorID(), - BooleanSerializer.INSTANCE, - true); - needInitTreeWriter.initializeState(context); - - currentTreeNodesWriter = - IterationSharedStorage.getWriter( - iterationID, - subtaskId, - CURRENT_TREE_NODES_STATE_NAME, - getOperatorID(), - new ListSerializer<>(NodeSerializer.INSTANCE), - new ArrayList<>()); - currentTreeNodesWriter.initializeState(context); - - instancesReader = - IterationSharedStorage.getReader(iterationID, subtaskId, SharedKeys.INSTANCES); - shuffledIndicesReader = - IterationSharedStorage.getReader( - iterationID, subtaskId, SharedKeys.SHUFFLED_INDICES); - rootLearningNodeReader = - IterationSharedStorage.getReader( - iterationID, subtaskId, SharedKeys.ROOT_LEARNING_NODE); - trainContextReader = - IterationSharedStorage.getReader(iterationID, subtaskId, SharedKeys.TRAIN_CONTEXT); + sharedStorageContext.initializeState(this, getRuntimeContext(), context); } @Override @@ -209,75 +113,92 @@ public void snapshotState(StateSnapshotContext context) throws Exception { splitsState.snapshotState(context); nodeSplitterState.snapshotState(context); instanceUpdaterState.snapshotState(context); - - pghWriter.snapshotState(context); - swappedIndicesWriter.snapshotState(context); - leavesWriter.snapshotState(context); - layerWriter.snapshotState(context); - allTreesWriter.snapshotState(context); - currentTreeNodesWriter.snapshotState(context); - needInitTreeWriter.snapshotState(context); + sharedStorageContext.snapshotState(context); } - @SuppressWarnings("OptionalGetWithoutIsPresent") @Override public void onEpochWatermarkIncremented( int epochWatermark, Context context, Collector collector) throws Exception { if (0 == epochWatermark) { - nodeSplitter = new NodeSplitter(trainContextReader.get()); - nodeSplitterState.update(Collections.singletonList(nodeSplitter)); - instanceUpdater = new InstanceUpdater(trainContextReader.get()); - instanceUpdaterState.update(Collections.singletonList(instanceUpdater)); - } - - int[] indices = swappedIndicesWriter.get(); - if (0 == indices.length) { - indices = shuffledIndicesReader.get().clone(); - } - - BinnedInstance[] instances = instancesReader.get(); - List leaves = leavesWriter.get(); - List layer = layerWriter.get(); - List currentTreeNodes; - if (layer.size() == 0) { - layer = Collections.singletonList(rootLearningNodeReader.get()); - currentTreeNodes = new ArrayList<>(); - currentTreeNodes.add(new Node()); - } else { - currentTreeNodes = currentTreeNodesWriter.get(); + sharedStorageContext.invoke( + (getter, setter) -> { + TrainContext trainContext = + getter.get(SharedStorageConstants.TRAIN_CONTEXT); + nodeSplitter = new NodeSplitter(trainContext); + nodeSplitterState.update(Collections.singletonList(nodeSplitter)); + instanceUpdater = new InstanceUpdater(trainContext); + instanceUpdaterState.update(Collections.singletonList(instanceUpdater)); + }); } - List nextLayer = - nodeSplitter.split( - currentTreeNodes, layer, leaves, splits.splits, indices, instances); - leavesWriter.set(leaves); - layerWriter.set(nextLayer); - currentTreeNodesWriter.set(currentTreeNodes); - - if (nextLayer.isEmpty()) { - needInitTreeWriter.set(true); - instanceUpdater.update( - pghWriter.get(), leaves, indices, instances, pghWriter::set, currentTreeNodes); - leaves.clear(); - List> allTrees = allTreesWriter.get(); - allTrees.add(currentTreeNodes); - - leavesWriter.set(new ArrayList<>()); - swappedIndicesWriter.set(new int[0]); - 
allTreesWriter.set(allTrees); - } else { - swappedIndicesWriter.set(indices); - needInitTreeWriter.set(false); - } + sharedStorageContext.invoke( + (getter, setter) -> { + int[] indices = getter.get(SharedStorageConstants.SWAPPED_INDICES); + if (0 == indices.length) { + indices = getter.get(SharedStorageConstants.SHUFFLED_INDICES).clone(); + } + + BinnedInstance[] instances = getter.get(SharedStorageConstants.INSTANCES); + List leaves = getter.get(SharedStorageConstants.LEAVES); + List layer = getter.get(SharedStorageConstants.LAYER); + List currentTreeNodes; + if (layer.size() == 0) { + layer = + Collections.singletonList( + getter.get(SharedStorageConstants.ROOT_LEARNING_NODE)); + currentTreeNodes = new ArrayList<>(); + currentTreeNodes.add(new Node()); + } else { + currentTreeNodes = getter.get(SharedStorageConstants.CURRENT_TREE_NODES); + } + + List nextLayer = + nodeSplitter.split( + currentTreeNodes, + layer, + leaves, + splits.splits, + indices, + instances); + setter.set(SharedStorageConstants.LEAVES, leaves); + setter.set(SharedStorageConstants.LAYER, nextLayer); + setter.set(SharedStorageConstants.CURRENT_TREE_NODES, currentTreeNodes); + + if (nextLayer.isEmpty()) { + // Current tree is finished. + setter.set(SharedStorageConstants.NEED_INIT_TREE, true); + instanceUpdater.update( + getter.get(SharedStorageConstants.PREDS_GRADS_HESSIANS), + leaves, + indices, + instances, + d -> setter.set(SharedStorageConstants.PREDS_GRADS_HESSIANS, d), + currentTreeNodes); + leaves.clear(); + List> allTrees = getter.get(SharedStorageConstants.ALL_TREES); + allTrees.add(currentTreeNodes); + + setter.set(SharedStorageConstants.LEAVES, new ArrayList<>()); + setter.set(SharedStorageConstants.SWAPPED_INDICES, new int[0]); + setter.set(SharedStorageConstants.ALL_TREES, allTrees); + } else { + setter.set(SharedStorageConstants.SWAPPED_INDICES, indices); + setter.set(SharedStorageConstants.NEED_INIT_TREE, false); + } + }); } @Override - public void onIterationTerminated(Context context, Collector collector) { - pghWriter.set(new PredGradHess[0]); - swappedIndicesWriter.set(new int[0]); - leavesWriter.set(Collections.emptyList()); - layerWriter.set(Collections.emptyList()); - currentTreeNodesWriter.set(Collections.emptyList()); + public void onIterationTerminated(Context context, Collector collector) + throws Exception { + sharedStorageContext.invoke( + (getter, setter) -> { + setter.set(SharedStorageConstants.PREDS_GRADS_HESSIANS, new PredGradHess[0]); + setter.set(SharedStorageConstants.SWAPPED_INDICES, new int[0]); + setter.set(SharedStorageConstants.LEAVES, Collections.emptyList()); + setter.set(SharedStorageConstants.LAYER, Collections.emptyList()); + setter.set(SharedStorageConstants.CURRENT_TREE_NODES, Collections.emptyList()); + }); } @Override @@ -291,14 +212,17 @@ public void close() throws Exception { splitsState.clear(); nodeSplitterState.clear(); instanceUpdaterState.clear(); - - pghWriter.remove(); - swappedIndicesWriter.remove(); - leavesWriter.remove(); - layerWriter.remove(); - allTreesWriter.remove(); - currentTreeNodesWriter.remove(); - needInitTreeWriter.remove(); + sharedStorageContext.clear(); super.close(); } + + @Override + public void onSharedStorageContextSet(SharedStorageContext context) { + sharedStorageContext = context; + } + + @Override + public String getSharedStorageAccessorID() { + return sharedStorageAccessorID; + } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SharedKeys.java 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SharedKeys.java deleted file mode 100644 index 4e0278969..000000000 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SharedKeys.java +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.ml.common.gbt.operators; - -import org.apache.flink.ml.common.gbt.datastorage.IterationSharedStorage; - -/** Stores keys for shared data stored in {@link IterationSharedStorage}. */ -class SharedKeys { - /** - * In the iteration, some data needs to be shared between subtasks of different operators within - * one machine. We use {@link IterationSharedStorage} with co-location mechanism to achieve such - * purpose. The data is stored in JVM static region, and is accessed through string keys from - * different operator subtasks. Note the first operator to put the data is the owner of the - * data, and only the owner can update or delete the data. - * - *
<p>
To be specific, in the gradient boosting trees algorithm, there are three types of shared data: - *
<ul>
    - *
  <li>Instances (after binned) and their corresponding predictions, gradients, and hessians - * are shared to avoid storing them multiple times or communicating them repeatedly. - *
  <li>When initializing every new tree, instances need to be shuffled and split into bagging - * instances and non-bagging ones. To reduce the cost, we shuffle instance indices rather - * than the instances themselves. Therefore, the shuffled indices need to be shared to access the - * actual instances. - *
  <li>After splitting the nodes of each layer, instance indices need to be swapped to maintain - * {@link LearningNode#slice} and {@link LearningNode#oob}. However, we cannot directly - * update the shuffled indices above, as they already have an owner, so we use another - * key to store the instance indices after swapping. - * </ul>
- */ - static final String INSTANCES = "instances"; - - static final String PREDS_GRADS_HESSIANS = "preds_grads_hessians"; - static final String SHUFFLED_INDICES = "shuffled_indices"; - static final String SWAPPED_INDICES = "swapped_indices"; - - static final String NODE_FEATURE_PAIRS = "node_feature_pairs"; - static final String LEAVES = "leaves"; - static final String LAYER = "layer"; - - static final String ROOT_LEARNING_NODE = "root_learning_node"; - static final String ALL_TREES = "all_trees"; - static final String NEED_INIT_TREE = "need_init_tree"; - static final String HAS_INITED_TREE = "has_inited_tree"; - - static final String TRAIN_CONTEXT = "train_context"; -} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SharedStorageConstants.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SharedStorageConstants.java new file mode 100644 index 000000000..4d09f18d9 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SharedStorageConstants.java @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.operators; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.typeutils.base.BooleanSerializer; +import org.apache.flink.api.common.typeutils.base.GenericArraySerializer; +import org.apache.flink.api.common.typeutils.base.ListSerializer; +import org.apache.flink.api.common.typeutils.base.array.IntPrimitiveArraySerializer; +import org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer; +import org.apache.flink.ml.common.gbt.GBTRunner; +import org.apache.flink.ml.common.gbt.defs.BinnedInstance; +import org.apache.flink.ml.common.gbt.defs.LearningNode; +import org.apache.flink.ml.common.gbt.defs.Node; +import org.apache.flink.ml.common.gbt.defs.PredGradHess; +import org.apache.flink.ml.common.gbt.defs.TrainContext; +import org.apache.flink.ml.common.gbt.typeinfo.BinnedInstanceSerializer; +import org.apache.flink.ml.common.gbt.typeinfo.LearningNodeSerializer; +import org.apache.flink.ml.common.gbt.typeinfo.NodeSerializer; +import org.apache.flink.ml.common.gbt.typeinfo.PredGradHessSerializer; +import org.apache.flink.ml.common.sharedstorage.ItemDescriptor; +import org.apache.flink.ml.common.sharedstorage.SharedStorageUtils; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** + * Stores constants used for {@link SharedStorageUtils} in {@link GBTRunner}. + * + *
<p>
In the iteration, some data needs to be shared and accessed between subtasks of different + * operators within one JVM to reduce the memory footprint and communication cost. We use {@link + * SharedStorageUtils} with a co-location mechanism to achieve this. + * + *
<p>
All shared data items have corresponding {@link ItemDescriptor}s, and can be read/written + * through {@link ItemDescriptor}s from different operator subtasks. Note that every shared item has + * an owner, and the owner can set new values and snapshot the item. + * + *
<p>
This class records all {@link ItemDescriptor}s used in {@link GBTRunner} and their owners. + */ +@Internal +public class SharedStorageConstants { + + /** Instances (after binned). */ + static final ItemDescriptor<BinnedInstance[]> INSTANCES = + ItemDescriptor.of( + "instances", + new GenericArraySerializer<>( + BinnedInstance.class, BinnedInstanceSerializer.INSTANCE), + new BinnedInstance[0]); + + /** + * Predictions, gradients, and hessians of instances, aligned one-to-one with {@link + * #INSTANCES}. + */ + static final ItemDescriptor<PredGradHess[]> PREDS_GRADS_HESSIANS = + ItemDescriptor.of( + "preds_grads_hessians", + new GenericArraySerializer<>( + PredGradHess.class, PredGradHessSerializer.INSTANCE), + new PredGradHess[0]); + + /** Shuffled indices of instances, used right after a new tree is initialized. */ + static final ItemDescriptor<int[]> SHUFFLED_INDICES = + ItemDescriptor.of("shuffled_indices", IntPrimitiveArraySerializer.INSTANCE, new int[0]); + + /** Swapped indices of instances, used when {@link #SHUFFLED_INDICES} is not applicable. */ + static final ItemDescriptor<int[]> SWAPPED_INDICES = + ItemDescriptor.of("swapped_indices", IntPrimitiveArraySerializer.INSTANCE, new int[0]); + + /** (nodeId, featureId) pairs used to calculate histograms. */ + static final ItemDescriptor<int[]> NODE_FEATURE_PAIRS = + ItemDescriptor.of( + "node_feature_pairs", IntPrimitiveArraySerializer.INSTANCE, new int[0]); + + /** Leaf nodes of the current working tree. */ + static final ItemDescriptor<List<LearningNode>> LEAVES = + ItemDescriptor.of( + "leaves", + new ListSerializer<>(LearningNodeSerializer.INSTANCE), + new ArrayList<>()); + + /** Nodes in the current layer of the current working tree. */ + static final ItemDescriptor<List<LearningNode>> LAYER = + ItemDescriptor.of( + "layer", + new ListSerializer<>(LearningNodeSerializer.INSTANCE), + new ArrayList<>()); + + /** The root node when initializing a new tree. */ + static final ItemDescriptor<LearningNode> ROOT_LEARNING_NODE = + ItemDescriptor.of( + "root_learning_node", LearningNodeSerializer.INSTANCE, new LearningNode()); + + /** All finished trees. */ + static final ItemDescriptor<List<List<Node>>> ALL_TREES = + ItemDescriptor.of( + "all_trees", + new ListSerializer<>(new ListSerializer<>(NodeSerializer.INSTANCE)), + new ArrayList<>()); + + /** Nodes in the current working tree. */ + static final ItemDescriptor<List<Node>> CURRENT_TREE_NODES = + ItemDescriptor.of( + "current_tree_nodes", + new ListSerializer<>(NodeSerializer.INSTANCE), + new ArrayList<>()); + + /** Indicates the necessity of initializing a new tree. */ + static final ItemDescriptor<Boolean> NEED_INIT_TREE = + ItemDescriptor.of("need_init_tree", BooleanSerializer.INSTANCE, true); + + /** Data items owned by the `PostSplits` operator. */ + public static final List<ItemDescriptor<?>> OWNED_BY_POST_SPLITS_OP = + Arrays.asList( + PREDS_GRADS_HESSIANS, + SWAPPED_INDICES, + LEAVES, + LAYER, + ALL_TREES, + CURRENT_TREE_NODES, + NEED_INIT_TREE); + + /** Indicates that a new tree has been initialized. */ + static final ItemDescriptor<Boolean> HAS_INITED_TREE = + ItemDescriptor.of("has_inited_tree", BooleanSerializer.INSTANCE, false); + + /** Training context. */ + static final ItemDescriptor<TrainContext> TRAIN_CONTEXT = + ItemDescriptor.of( + "train_context", + new KryoSerializer<>(TrainContext.class, new ExecutionConfig()), + new TrainContext()); + + /** Data items owned by the `CacheDataCalcLocalHists` operator.
*/ + public static final List> OWNED_BY_CACHE_DATA_CALC_LOCAL_HISTS_OP = + Arrays.asList( + INSTANCES, + SHUFFLED_INDICES, + NODE_FEATURE_PAIRS, + ROOT_LEARNING_NODE, + HAS_INITED_TREE, + TRAIN_CONTEXT); +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TerminationOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TerminationOperator.java index dffa53b48..c724301dd 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TerminationOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TerminationOperator.java @@ -18,42 +18,38 @@ package org.apache.flink.ml.common.gbt.operators; -import org.apache.flink.iteration.IterationID; import org.apache.flink.iteration.IterationListener; import org.apache.flink.ml.common.gbt.GBTModelData; -import org.apache.flink.ml.common.gbt.datastorage.IterationSharedStorage; -import org.apache.flink.ml.common.gbt.defs.Node; -import org.apache.flink.ml.common.gbt.defs.TrainContext; +import org.apache.flink.ml.common.sharedstorage.SharedStorageContext; +import org.apache.flink.ml.common.sharedstorage.SharedStorageStreamOperator; +import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; import org.apache.flink.streaming.api.operators.OneInputStreamOperator; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.util.Collector; import org.apache.flink.util.OutputTag; -import java.util.List; +import java.util.UUID; /** Determines whether to terminated training. */ public class TerminationOperator extends AbstractStreamOperator - implements OneInputStreamOperator, IterationListener { + implements OneInputStreamOperator, + IterationListener, + SharedStorageStreamOperator { - private final IterationID iterationID; private final OutputTag modelDataOutputTag; - private transient IterationSharedStorage.Reader>> allTreesReader; - private transient IterationSharedStorage.Reader trainContextReader; + private final String sharedStorageAccessorID; + private transient SharedStorageContext sharedStorageContext; - public TerminationOperator( - IterationID iterationID, OutputTag modelDataOutputTag) { - this.iterationID = iterationID; + public TerminationOperator(OutputTag modelDataOutputTag) { this.modelDataOutputTag = modelDataOutputTag; + sharedStorageAccessorID = getClass().getSimpleName() + "-" + UUID.randomUUID(); } @Override - public void open() throws Exception { - int subtaskId = getRuntimeContext().getIndexOfThisSubtask(); - allTreesReader = - IterationSharedStorage.getReader(iterationID, subtaskId, SharedKeys.ALL_TREES); - trainContextReader = - IterationSharedStorage.getReader(iterationID, subtaskId, SharedKeys.TRAIN_CONTEXT); + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + sharedStorageContext.initializeState(this, getRuntimeContext(), context); } @Override @@ -61,18 +57,41 @@ public void processElement(StreamRecord element) throws Exception {} @Override public void onEpochWatermarkIncremented( - int epochWatermark, Context context, Collector collector) { - boolean terminated = allTreesReader.get().size() == trainContextReader.get().params.maxIter; - // TODO: add validation error rate - if (!terminated) { - output.collect(new StreamRecord<>(0)); - } + int epochWatermark, Context context, Collector collector) + throws Exception { + sharedStorageContext.invoke( 
+ (getter, setter) -> { + boolean terminated = + getter.get(SharedStorageConstants.ALL_TREES).size() + == getter.get(SharedStorageConstants.TRAIN_CONTEXT) + .params + .maxIter; + // TODO: add validation error rate + if (!terminated) { + output.collect(new StreamRecord<>(0)); + } + }); } @Override - public void onIterationTerminated(Context context, Collector collector) { - context.output( - modelDataOutputTag, - GBTModelData.from(trainContextReader.get(), allTreesReader.get())); + public void onIterationTerminated(Context context, Collector collector) + throws Exception { + sharedStorageContext.invoke( + (getter, setter) -> + context.output( + modelDataOutputTag, + GBTModelData.from( + getter.get(SharedStorageConstants.TRAIN_CONTEXT), + getter.get(SharedStorageConstants.ALL_TREES)))); + } + + @Override + public void onSharedStorageContextSet(SharedStorageContext context) { + sharedStorageContext = context; + } + + @Override + public String getSharedStorageAccessorID() { + return sharedStorageAccessorID; } } diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/GBTRunnerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/GBTRunnerTest.java index f609c2de3..b63494ca5 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/GBTRunnerTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/GBTRunnerTest.java @@ -119,6 +119,7 @@ public void testTrainClassifier() throws Exception { p.taskType = TaskType.CLASSIFICATION; p.labelCol = "cls_label"; p.lossType = "logistic"; + p.useMissing = true; GBTModelData modelData = GBTRunner.train(inputTable, p).executeAndCollect().next(); verifyModelData(modelData, p); @@ -130,6 +131,7 @@ public void testTrainRegressor() throws Exception { p.taskType = TaskType.REGRESSION; p.labelCol = "label"; p.lossType = "squared"; + p.useMissing = true; GBTModelData modelData = GBTRunner.train(inputTable, p).executeAndCollect().next(); verifyModelData(modelData, p); From ad8d4a779d4ca54d747d2c93e090eec7c96c5ae3 Mon Sep 17 00:00:00 2001 From: Fan Hong Date: Tue, 28 Feb 2023 20:08:42 +0800 Subject: [PATCH 17/47] Replace inputCols and featuresCol with featuresCols. 
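With this change, the single `featuresCols` parameter covers both input styles that `inputCols` and `featuresCol` previously expressed separately: a one-element array naming a vector-typed column (all features treated as continuous), or several non-vector columns, with `categoricalCols` optionally marking some of them as categorical. A minimal usage sketch mirroring the updated tests below; the column names are taken from the test fixtures:

    // Multiple scalar feature columns; f2 is treated as categorical.
    GBTClassifier byColumns =
            new GBTClassifier()
                    .setFeaturesCols("f0", "f1", "f2")
                    .setCategoricalCols("f2")
                    .setLabelCol("cls_label");

    // A single vector-typed column; all features are continuous.
    GBTClassifier byVector =
            new GBTClassifier().setFeaturesCols("vec").setLabelCol("cls_label");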
--- .../gbtclassifier/GBTClassifierModel.java | 14 ++---- .../flink/ml/common/gbt/GBTModelData.java | 8 ++-- .../flink/ml/common/gbt/GBTModelParams.java | 28 ++--------- .../apache/flink/ml/common/gbt/GBTRunner.java | 37 +++++++++++--- .../flink/ml/common/gbt/Preprocess.java | 8 ++-- .../flink/ml/common/gbt/defs/GbtParams.java | 3 +- .../CacheDataCalcLocalHistsOperator.java | 4 +- .../ml/common/param/HasFeaturesCols.java | 48 +++++++++++++++++++ .../gbtregressor/GBTRegressorModel.java | 14 ++---- .../ml/classification/GBTClassifierTest.java | 27 +++++------ .../flink/ml/common/gbt/GBTRunnerTest.java | 2 +- .../flink/ml/common/gbt/PreprocessTest.java | 6 ++- .../flink/ml/regression/GBTRegressorTest.java | 27 +++++------ 13 files changed, 134 insertions(+), 92 deletions(-) create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasFeaturesCols.java diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifierModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifierModel.java index 8e3110079..24116bac4 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifierModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifierModel.java @@ -90,8 +90,7 @@ public Table[] transform(Table... inputs) { //noinspection unchecked DataStream inputData = (DataStream) inputList.get(0); return inputData.map( - new PredictLabelFunction( - broadcastModelKey, getInputCols(), getFeaturesCol()), + new PredictLabelFunction(broadcastModelKey, getFeaturesCols()), outputTypeInfo); }); return new Table[] {tEnv.fromDataStream(predictionResult)}; @@ -102,15 +101,12 @@ private static class PredictLabelFunction extends RichMapFunction { private static final Sigmoid sigmoid = new Sigmoid(); private final String broadcastModelKey; - private final String[] inputCols; - private final String featuresCol; + private final String[] featuresCols; private GBTModelData modelData; - public PredictLabelFunction( - String broadcastModelKey, String[] inputCols, String featuresCol) { + public PredictLabelFunction(String broadcastModelKey, String[] featuresCols) { this.broadcastModelKey = broadcastModelKey; - this.inputCols = inputCols; - this.featuresCol = featuresCol; + this.featuresCols = featuresCols; } @Override @@ -120,7 +116,7 @@ public Row map(Row value) throws Exception { (GBTModelData) getRuntimeContext().getBroadcastVariable(broadcastModelKey).get(0); } - IntDoubleHashMap features = modelData.rowToFeatures(value, inputCols, featuresCol); + IntDoubleHashMap features = modelData.rowToFeatures(value, featuresCols); double logits = modelData.predictRaw(features); double prob = sigmoid.value(logits); return Row.join( diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelData.java index d34f44211..555d9b6c3 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelData.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelData.java @@ -147,17 +147,17 @@ private static int mapCategoricalFeature(ObjectIntHashMap categoryToId, return categoryToId.getIfAbsent(s, categoryToId.size()); } - public IntDoubleHashMap rowToFeatures(Row row, String[] featureCols, String vectorCol) { + public IntDoubleHashMap rowToFeatures(Row row, String[] featuresCols) { IntDoubleHashMap features = new IntDoubleHashMap(); if 
(isInputVector) { - Vector vec = row.getFieldAs(vectorCol); + Vector vec = row.getFieldAs(featuresCols[0]); SparseVector sv = vec.toSparse(); for (int i = 0; i < sv.indices.length; i += 1) { features.put(sv.indices[i], sv.values[i]); } } else { - for (int i = 0; i < featureCols.length; i += 1) { - Object obj = row.getField(featureCols[i]); + for (int i = 0; i < featuresCols.length; i += 1) { + Object obj = row.getField(featuresCols[i]); double v; if (isCategorical.get(i)) { v = mapCategoricalFeature(categoryToIdMaps.get(i), obj); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelParams.java index 50c078303..d6de81591 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelParams.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelParams.java @@ -20,37 +20,19 @@ import org.apache.flink.ml.classification.gbtclassifier.GBTClassifierModel; import org.apache.flink.ml.common.param.HasCategoricalCols; -import org.apache.flink.ml.common.param.HasFeaturesCol; +import org.apache.flink.ml.common.param.HasFeaturesCols; import org.apache.flink.ml.common.param.HasLabelCol; import org.apache.flink.ml.common.param.HasPredictionCol; -import org.apache.flink.ml.param.Param; -import org.apache.flink.ml.param.StringArrayParam; import org.apache.flink.ml.regression.gbtregressor.GBTRegressorModel; /** * Params of {@link GBTClassifierModel} and {@link GBTRegressorModel}. * - *
<p>
If the input features come from 1 column of vector type, `featuresCol` should be used, and all - * features are treated as continuous features. Otherwise, `inputCols` should be used for multiple - * columns. Columns whose names specified in `categoricalCols` are treated as categorical features, - * while others are continuous features. - * - *
<p>
NOTE: `inputCols` and `featuresCol` are in conflict with each other, so they should not be set - * at the same time. In addition, `inputCols` has a higher precedence than `featuresCol`, that is, - * `featuresCol` is ignored when `inputCols` is not `null`. + *
<p>
The value `featuresCols` can be either one column name of vector type, or multiple column + * names of non-vector types. In the latter case, `categoricalCols` can further be set to + * specify columns that need to be treated as categorical features. * * @param <T> The class type of this instance. */ public interface GBTModelParams - extends HasFeaturesCol, HasLabelCol, HasCategoricalCols, HasPredictionCol { - - Param INPUT_COLS = new StringArrayParam("inputCols", "Input column names.", null); - - default String[] getInputCols() { - return get(INPUT_COLS); - } - - default T setInputCols(String... value) { - return set(INPUT_COLS, value); - } -} + extends HasFeaturesCols, HasLabelCol, HasCategoricalCols, HasPredictionCol {} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTRunner.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTRunner.java index 4d600f94f..5f3ab5cc7 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTRunner.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTRunner.java @@ -20,6 +20,7 @@ import org.apache.flink.api.common.functions.AggregateFunction; import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.iteration.DataStreamList; import org.apache.flink.iteration.IterationConfig; @@ -28,10 +29,14 @@ import org.apache.flink.ml.classification.gbtclassifier.GBTClassifierParams; import org.apache.flink.ml.common.broadcast.BroadcastUtils; import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.common.datastream.TableUtils; import org.apache.flink.ml.common.gbt.defs.FeatureMeta; import org.apache.flink.ml.common.gbt.defs.GbtParams; import org.apache.flink.ml.common.gbt.defs.TaskType; import org.apache.flink.ml.common.gbt.defs.TrainContext; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.linalg.typeinfo.SparseVectorTypeInfo; +import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.regression.gbtregressor.GBTRegressorParams; import org.apache.flink.streaming.api.datastream.DataStream; @@ -39,6 +44,7 @@ import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; import org.apache.flink.table.api.internal.TableImpl; import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; import org.apache.commons.lang3.ArrayUtils; @@ -63,9 +69,28 @@ public static DataStream trainRegressor(Table data, BaseGBTParams< return train(data, estimator, TaskType.REGRESSION); } + private static boolean isVectorType(TypeInformation typeInfo) { + return typeInfo instanceof DenseVectorTypeInfo + || typeInfo instanceof SparseVectorTypeInfo + || typeInfo instanceof VectorTypeInfo; + } + static DataStream train( Table data, BaseGBTParams estimator, TaskType taskType) { - return train(data, fromEstimator(estimator, taskType)); + String[] featuresCols = estimator.getFeaturesCols(); + TypeInformation[] featuresTypes = + Arrays.stream(featuresCols) + .map(d -> TableUtils.getTypeInfoByName(data.getResolvedSchema(), d)) + .toArray(TypeInformation[]::new); + for (int i = 0; i < featuresCols.length; i += 1) { + Preconditions.checkArgument( + null != featuresTypes[i], + String.format( + "Column name %s does not exist in the input data.", featuresCols[i])); + } + + boolean isInputVector = featuresCols.length == 1 &&
isVectorType(featuresTypes[0]); + return train(data, fromEstimator(estimator, isInputVector, taskType)); } /** Trains a gradient boosting tree model with given data and parameters. */ @@ -124,7 +149,8 @@ private static DataStream boost( return dataStreamList.get(0); } - public static GbtParams fromEstimator(BaseGBTParams estimator, TaskType taskType) { + public static GbtParams fromEstimator( + BaseGBTParams estimator, boolean isInputVector, TaskType taskType) { final Map, Object> paramMap = estimator.getParamMap(); final Set> unsupported = new HashSet<>( @@ -147,9 +173,8 @@ public static GbtParams fromEstimator(BaseGBTParams estimator, TaskType taskT GbtParams p = new GbtParams(); p.taskType = taskType; - p.featureCols = estimator.getInputCols(); - p.vectorCol = estimator.getFeaturesCol(); - p.isInputVector = (null == p.featureCols); + p.featuresCols = estimator.getFeaturesCols(); + p.isInputVector = isInputVector; p.labelCol = estimator.getLabelCol(); p.weightCol = estimator.getWeightCol(); p.categoricalCols = estimator.getCategoricalCols(); @@ -201,7 +226,7 @@ public TrainContext map(Integer value) { if (!trainContext.params.isInputVector) { Arrays.sort( trainContext.featureMetas, - Comparator.comparing(d -> ArrayUtils.indexOf(p.featureCols, d.name))); + Comparator.comparing(d -> ArrayUtils.indexOf(p.featuresCols, d.name))); } trainContext.numFeatures = trainContext.featureMetas.length; trainContext.labelSumCount = diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/Preprocess.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/Preprocess.java index dfc5faedc..890804a11 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/Preprocess.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/Preprocess.java @@ -64,7 +64,7 @@ class Preprocess { */ static Tuple2> preprocessCols(Table dataTable, GbtParams p) { - final String[] relatedCols = ArrayUtils.add(p.featureCols, p.labelCol); + final String[] relatedCols = ArrayUtils.add(p.featuresCols, p.labelCol); dataTable = dataTable.select( Arrays.stream(relatedCols) @@ -72,7 +72,7 @@ static Tuple2> preprocessCols(Table dataTable, Gb .toArray(ApiExpression[]::new)); // Maps continuous columns to integers, and obtain corresponding discretizer model. - String[] continuousCols = ArrayUtils.removeElements(p.featureCols, p.categoricalCols); + String[] continuousCols = ArrayUtils.removeElements(p.featuresCols, p.categoricalCols); Tuple2> continuousMappedDataAndModelData = discretizeContinuousCols(dataTable, continuousCols, p.maxBins); dataTable = continuousMappedDataAndModelData.f0; @@ -121,9 +121,9 @@ static Tuple2> preprocessCols(Table dataTable, Gb * information for all features. 
*/ static Tuple2> preprocessVecCol(Table dataTable, GbtParams p) { - dataTable = dataTable.select($(p.vectorCol), $(p.labelCol)); + dataTable = dataTable.select($(p.featuresCols[0]), $(p.labelCol)); Tuple2> mappedDataAndModelData = - discretizeVectorCol(dataTable, p.vectorCol, p.maxBins); + discretizeVectorCol(dataTable, p.featuresCols[0], p.maxBins); dataTable = mappedDataAndModelData.f0; DataStream featureMeta = buildContinuousFeatureMeta(mappedDataAndModelData.f1, null); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/GbtParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/GbtParams.java index f55d47773..82d536bcc 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/GbtParams.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/GbtParams.java @@ -25,8 +25,7 @@ public class GbtParams implements Serializable { public TaskType taskType; // Parameters related with input data. - public String[] featureCols; - public String vectorCol; + public String[] featuresCols; public boolean isInputVector; public String labelCol; public String weightCol; diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java index db004df4e..2f0f9e79a 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java @@ -132,13 +132,13 @@ public void processElement1(StreamRecord streamRecord) throws Exception { instance.label = row.getFieldAs(gbtParams.labelCol); if (gbtParams.isInputVector) { - Vector vec = row.getFieldAs(gbtParams.vectorCol); + Vector vec = row.getFieldAs(gbtParams.featuresCols[0]); SparseVector sv = vec.toSparse(); instance.featureIds = sv.indices.length == sv.size() ? null : sv.indices; instance.featureValues = Arrays.stream(sv.values).mapToInt(d -> (int) d).toArray(); } else { instance.featureValues = - Arrays.stream(gbtParams.featureCols) + Arrays.stream(gbtParams.featuresCols) .mapToInt(col -> ((Number) row.getFieldAs(col)).intValue()) .toArray(); } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasFeaturesCols.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasFeaturesCols.java new file mode 100644 index 000000000..842deb6ff --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasFeaturesCols.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.param; + +import org.apache.flink.ml.api.Stage; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.StringArrayParam; +import org.apache.flink.ml.param.WithParams; + +/** + * Interface for the shared featuresCols param. + * + *
<p>
{@link HasFeaturesCols} is typically used for {@link Stage}s that implement {@link + * HasLabelCol}. It is preferred to use {@link HasInputCol} for other cases. + */ +public interface HasFeaturesCols extends WithParams { + Param FEATURES_COLS = + new StringArrayParam( + "featuresCols", + "Feature column names.", + new String[] {"features"}, + ParamValidators.nonEmptyArray()); + + default String[] getFeaturesCols() { + return get(FEATURES_COLS); + } + + default T setFeaturesCols(String... value) { + return set(FEATURES_COLS, value); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressorModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressorModel.java index acd6e65c1..05c127525 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressorModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressorModel.java @@ -78,8 +78,7 @@ public Table[] transform(Table... inputs) { //noinspection unchecked DataStream inputData = (DataStream) inputList.get(0); return inputData.map( - new PredictLabelFunction( - broadcastModelKey, getInputCols(), getFeaturesCol()), + new PredictLabelFunction(broadcastModelKey, getFeaturesCols()), outputTypeInfo); }); return new Table[] {tEnv.fromDataStream(predictionResult)}; @@ -88,15 +87,12 @@ broadcastModelKey, getInputCols(), getFeaturesCol()), private static class PredictLabelFunction extends RichMapFunction { private final String broadcastModelKey; - private final String[] inputCols; - private final String featuresCol; + private final String[] featuresCols; private GBTModelData modelData; - public PredictLabelFunction( - String broadcastModelKey, String[] inputCols, String featuresCol) { + public PredictLabelFunction(String broadcastModelKey, String[] featuresCols) { this.broadcastModelKey = broadcastModelKey; - this.inputCols = inputCols; - this.featuresCol = featuresCol; + this.featuresCols = featuresCols; } @Override @@ -106,7 +102,7 @@ public Row map(Row value) throws Exception { (GBTModelData) getRuntimeContext().getBroadcastVariable(broadcastModelKey).get(0); } - IntDoubleHashMap features = modelData.rowToFeatures(value, inputCols, featuresCol); + IntDoubleHashMap features = modelData.rowToFeatures(value, featuresCols); double pred = modelData.predictRaw(features); return Row.join(value, Row.of(pred)); } diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/GBTClassifierTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/GBTClassifierTest.java index 4462daf36..edbb719fe 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/GBTClassifierTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/GBTClassifierTest.java @@ -166,8 +166,7 @@ public void before() { @Test public void testParam() { GBTClassifier gbtc = new GBTClassifier(); - Assert.assertEquals("features", gbtc.getFeaturesCol()); - Assert.assertNull(gbtc.getInputCols()); + Assert.assertArrayEquals(new String[] {"features"}, gbtc.getFeaturesCols()); Assert.assertEquals("label", gbtc.getLabelCol()); Assert.assertArrayEquals(new String[] {}, gbtc.getCategoricalCols()); Assert.assertEquals("prediction", gbtc.getPredictionCol()); @@ -193,8 +192,7 @@ public void testParam() { Assert.assertEquals("rawPrediction", gbtc.getRawPredictionCol()); Assert.assertEquals("probability", gbtc.getProbabilityCol()); - gbtc.setFeaturesCol("vec") - .setInputCols("f0", "f1", "f2") + 
gbtc.setFeaturesCols("f0", "f1", "f2") .setLabelCol("cls_label") .setCategoricalCols("f0", "f1") .setPredictionCol("pred") @@ -217,8 +215,7 @@ public void testParam() { .setRawPredictionCol("raw_pred") .setProbabilityCol("prob"); - Assert.assertEquals("vec", gbtc.getFeaturesCol()); - Assert.assertArrayEquals(new String[] {"f0", "f1", "f2"}, gbtc.getInputCols()); + Assert.assertArrayEquals(new String[] {"f0", "f1", "f2"}, gbtc.getFeaturesCols()); Assert.assertEquals("cls_label", gbtc.getLabelCol()); Assert.assertArrayEquals(new String[] {"f0", "f1"}, gbtc.getCategoricalCols()); Assert.assertEquals("pred", gbtc.getPredictionCol()); @@ -247,7 +244,7 @@ public void testParam() { @Test public void testOutputSchema() throws Exception { GBTClassifier gbtc = - new GBTClassifier().setInputCols("f0", "f1", "f2").setCategoricalCols("f2"); + new GBTClassifier().setFeaturesCols("f0", "f1", "f2").setCategoricalCols("f2"); GBTClassifierModel model = gbtc.fit(inputTable); Table output = model.transform(inputTable)[0]; Assert.assertArrayEquals( @@ -263,7 +260,7 @@ public void testOutputSchema() throws Exception { public void testFitAndPredict() throws Exception { GBTClassifier gbtc = new GBTClassifier() - .setInputCols("f0", "f1", "f2") + .setFeaturesCols("f0", "f1", "f2") .setCategoricalCols("f2") .setLabelCol("cls_label") .setRegGamma(0.) @@ -282,7 +279,7 @@ public void testFitAndPredict() throws Exception { public void testFitAndPredictWithVectorCol() throws Exception { GBTClassifier gbtc = new GBTClassifier() - .setFeaturesCol("vec") + .setFeaturesCols("vec") .setLabelCol("cls_label") .setRegGamma(0.) .setMaxBins(3) @@ -342,7 +339,7 @@ public void testFitAndPredictWithVectorCol() throws Exception { public void testFitAndPredictWithNoCategoricalCols() throws Exception { GBTClassifier gbtc = new GBTClassifier() - .setInputCols("f0", "f1") + .setFeaturesCols("f0", "f1") .setLabelCol("cls_label") .setRegGamma(0.) .setMaxBins(5) @@ -402,7 +399,7 @@ public void testFitAndPredictWithNoCategoricalCols() throws Exception { public void testEstimatorSaveLoadAndPredict() throws Exception { GBTClassifier gbtc = new GBTClassifier() - .setInputCols("f0", "f1", "f2") + .setFeaturesCols("f0", "f1", "f2") .setCategoricalCols("f2") .setLabelCol("cls_label") .setRegGamma(0.) @@ -426,7 +423,7 @@ public void testEstimatorSaveLoadAndPredict() throws Exception { public void testModelSaveLoadAndPredict() throws Exception { GBTClassifier gbtc = new GBTClassifier() - .setInputCols("f0", "f1", "f2") + .setFeaturesCols("f0", "f1", "f2") .setCategoricalCols("f2") .setLabelCol("cls_label") .setRegGamma(0.) @@ -447,7 +444,7 @@ public void testModelSaveLoadAndPredict() throws Exception { public void testGetModelData() throws Exception { GBTClassifier gbtc = new GBTClassifier() - .setInputCols("f0", "f1", "f2") + .setFeaturesCols("f0", "f1", "f2") .setCategoricalCols("f2") .setLabelCol("cls_label") .setRegGamma(0.) 
@@ -471,7 +468,7 @@ public void testGetModelData() throws Exception { Assert.assertEquals(gbtc.getMaxIter(), modelData.allTrees.size()); Assert.assertEquals(gbtc.getCategoricalCols().length, modelData.categoryToIdMaps.size()); Assert.assertEquals( - gbtc.getInputCols().length - gbtc.getCategoricalCols().length, + gbtc.getFeaturesCols().length - gbtc.getCategoricalCols().length, modelData.featureIdToBinEdges.size()); Assert.assertEquals(BitSet.valueOf(new byte[] {4}), modelData.isCategorical); } @@ -480,7 +477,7 @@ public void testGetModelData() throws Exception { public void testSetModelData() throws Exception { GBTClassifier gbtc = new GBTClassifier() - .setInputCols("f0", "f1", "f2") + .setFeaturesCols("f0", "f1", "f2") .setCategoricalCols("f2") .setLabelCol("cls_label") .setRegGamma(0.) diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/GBTRunnerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/GBTRunnerTest.java index b63494ca5..0df47c61d 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/GBTRunnerTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/GBTRunnerTest.java @@ -93,7 +93,7 @@ public void before() { private GbtParams getCommonGbtParams() { GbtParams p = new GbtParams(); - p.featureCols = new String[] {"f0", "f1", "f2"}; + p.featuresCols = new String[] {"f0", "f1", "f2"}; p.categoricalCols = new String[] {"f2"}; p.isInputVector = false; p.gamma = 0.; diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/PreprocessTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/PreprocessTest.java index 0ad602a74..cc99b9627 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/PreprocessTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/PreprocessTest.java @@ -128,7 +128,8 @@ public void invoke(T value, Context context) { @Test public void testPreprocessCols() throws Exception { GbtParams p = new GbtParams(); - p.featureCols = new String[] {"f0", "f1", "f2"}; + p.isInputVector = false; + p.featuresCols = new String[] {"f0", "f1", "f2"}; p.categoricalCols = new String[] {"f2"}; p.labelCol = "label"; p.maxBins = 3; @@ -174,7 +175,8 @@ public void testPreprocessCols() throws Exception { @Test public void testPreprocessVectorCol() throws Exception { GbtParams p = new GbtParams(); - p.vectorCol = "vec"; + p.isInputVector = true; + p.featuresCols = new String[] {"vec"}; p.labelCol = "label"; p.maxBins = 3; Tuple2> results = Preprocess.preprocessVecCol(inputTable, p); diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/GBTRegressorTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/GBTRegressorTest.java index f6dc213ce..852e7969f 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/GBTRegressorTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/GBTRegressorTest.java @@ -133,8 +133,7 @@ public void before() { @Test public void testParam() { GBTRegressor gbtr = new GBTRegressor(); - Assert.assertEquals("features", gbtr.getFeaturesCol()); - Assert.assertNull(gbtr.getInputCols()); + Assert.assertArrayEquals(new String[] {"features"}, gbtr.getFeaturesCols()); Assert.assertEquals("label", gbtr.getLabelCol()); Assert.assertArrayEquals(new String[] {}, gbtr.getCategoricalCols()); Assert.assertEquals("prediction", gbtr.getPredictionCol()); @@ -158,8 +157,7 @@ public void testParam() { Assert.assertEquals("squared", gbtr.getLossType()); - gbtr.setFeaturesCol("vec") - 
.setInputCols("f0", "f1", "f2") + gbtr.setFeaturesCols("f0", "f1", "f2") .setLabelCol("label") .setCategoricalCols("f0", "f1") .setPredictionCol("pred") @@ -180,8 +178,7 @@ public void testParam() { .setRegLambda(.1) .setRegGamma(.1); - Assert.assertEquals("vec", gbtr.getFeaturesCol()); - Assert.assertArrayEquals(new String[] {"f0", "f1", "f2"}, gbtr.getInputCols()); + Assert.assertArrayEquals(new String[] {"f0", "f1", "f2"}, gbtr.getFeaturesCols()); Assert.assertEquals("label", gbtr.getLabelCol()); Assert.assertArrayEquals(new String[] {"f0", "f1"}, gbtr.getCategoricalCols()); Assert.assertEquals("pred", gbtr.getPredictionCol()); @@ -207,7 +204,7 @@ public void testParam() { @Test public void testOutputSchema() throws Exception { GBTRegressor gbtr = - new GBTRegressor().setInputCols("f0", "f1", "f2").setCategoricalCols("f2"); + new GBTRegressor().setFeaturesCols("f0", "f1", "f2").setCategoricalCols("f2"); GBTRegressorModel model = gbtr.fit(inputTable); Table output = model.transform(inputTable)[0]; Assert.assertArrayEquals( @@ -221,7 +218,7 @@ public void testOutputSchema() throws Exception { public void testFitAndPredict() throws Exception { GBTRegressor gbtr = new GBTRegressor() - .setInputCols("f0", "f1", "f2") + .setFeaturesCols("f0", "f1", "f2") .setCategoricalCols("f2") .setLabelCol("label") .setRegGamma(0.) @@ -236,7 +233,7 @@ public void testFitAndPredict() throws Exception { public void testFitAndPredictWithVectorCol() throws Exception { GBTRegressor gbtr = new GBTRegressor() - .setFeaturesCol("vec") + .setFeaturesCols("vec") .setLabelCol("label") .setRegGamma(0.) .setMaxBins(3) @@ -262,7 +259,7 @@ public void testFitAndPredictWithVectorCol() throws Exception { public void testFitAndPredictWithNoCategoricalCols() throws Exception { GBTRegressor gbtr = new GBTRegressor() - .setInputCols("f0", "f1") + .setFeaturesCols("f0", "f1") .setLabelCol("label") .setRegGamma(0.) .setMaxBins(5) @@ -288,7 +285,7 @@ public void testFitAndPredictWithNoCategoricalCols() throws Exception { public void testEstimatorSaveLoadAndPredict() throws Exception { GBTRegressor gbtr = new GBTRegressor() - .setInputCols("f0", "f1", "f2") + .setFeaturesCols("f0", "f1", "f2") .setCategoricalCols("f2") .setLabelCol("label") .setRegGamma(0.) @@ -308,7 +305,7 @@ public void testEstimatorSaveLoadAndPredict() throws Exception { public void testModelSaveLoadAndPredict() throws Exception { GBTRegressor gbtr = new GBTRegressor() - .setInputCols("f0", "f1", "f2") + .setFeaturesCols("f0", "f1", "f2") .setCategoricalCols("f2") .setLabelCol("label") .setRegGamma(0.) @@ -325,7 +322,7 @@ public void testModelSaveLoadAndPredict() throws Exception { public void testGetModelData() throws Exception { GBTRegressor gbtr = new GBTRegressor() - .setInputCols("f0", "f1", "f2") + .setFeaturesCols("f0", "f1", "f2") .setCategoricalCols("f2") .setLabelCol("label") .setRegGamma(0.) 
@@ -349,7 +346,7 @@ public void testGetModelData() throws Exception { Assert.assertEquals(gbtr.getMaxIter(), modelData.allTrees.size()); Assert.assertEquals(gbtr.getCategoricalCols().length, modelData.categoryToIdMaps.size()); Assert.assertEquals( - gbtr.getInputCols().length - gbtr.getCategoricalCols().length, + gbtr.getFeaturesCols().length - gbtr.getCategoricalCols().length, modelData.featureIdToBinEdges.size()); Assert.assertEquals(BitSet.valueOf(new byte[] {4}), modelData.isCategorical); } @@ -358,7 +355,7 @@ public void testGetModelData() throws Exception { public void testSetModelData() throws Exception { GBTRegressor gbtr = new GBTRegressor() - .setInputCols("f0", "f1", "f2") + .setFeaturesCols("f0", "f1", "f2") .setCategoricalCols("f2") .setLabelCol("label") .setRegGamma(0.) From b0309e49ffce24cc8e98e0e63c9e48c9db1efa55 Mon Sep 17 00:00:00 2001 From: Fan Hong Date: Tue, 28 Feb 2023 20:39:09 +0800 Subject: [PATCH 18/47] Fix cases when reading a shared item earlier than its initialization. --- .../common/sharedstorage/SharedStorage.java | 20 +++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorage.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorage.java index 591bb9481..f056034bb 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorage.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorage.java @@ -86,8 +86,24 @@ static class Reader { } T get() { - //noinspection unchecked - return (T) m.get(t); + // It is possible that the `get` request of an item is triggered earlier than its + // initialization. In this case, we wait for a while. + long waitTime = 10; + do { + //noinspection unchecked + T value = (T) m.get(t); + if (null != value) { + return value; + } + try { + Thread.sleep(waitTime); + } catch (InterruptedException e) { + break; + } + waitTime *= 2; + } while (waitTime < 10 * 1000); + throw new IllegalStateException( + String.format("Failed to get value of %s after waiting %d ms.", t, waitTime)); } } From 9d3f10f77328c0adc57b3af75572dd0912193a47 Mon Sep 17 00:00:00 2001 From: Fan Hong Date: Tue, 28 Feb 2023 20:45:24 +0800 Subject: [PATCH 19/47] Improve GBT params. 
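For context on PATCH 18 above: SharedStorage.Reader#get now polls with exponential backoff so that a read arriving before the item is initialized waits instead of returning null. A minimal, self-contained sketch of the same pattern, assuming a plain ConcurrentHashMap as an illustrative stand-in for the shared storage (names here are not from the patch):

    import java.util.Map;
    import java.util.concurrent.ConcurrentHashMap;

    /** Poll-with-backoff read: tolerates reads that race ahead of the corresponding write. */
    class BackoffReadSketch {
        private static final Map<String, Object> store = new ConcurrentHashMap<>();

        static Object get(String key) throws InterruptedException {
            long waitTime = 10;
            do {
                Object value = store.get(key);
                if (value != null) {
                    return value; // initialized in the meantime (or already present)
                }
                Thread.sleep(waitTime); // back off, doubling the interval each attempt
                waitTime *= 2;
            } while (waitTime < 10 * 1000);
            throw new IllegalStateException("No value for key '" + key + "' after backoff.");
        }
    }

The doubling keeps early retries cheap while bounding the total wait at roughly ten seconds (10 + 20 + ... + 5120 ms), after which the reader fails fast instead of blocking forever.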
--- .../gbtclassifier/GBTClassifierModel.java | 2 +- .../GBTClassifierModelParams.java | 31 +++++++++++++++++++ .../gbtclassifier/GBTClassifierParams.java | 5 +-- .../flink/ml/common/gbt/BaseGBTModel.java | 3 +- ...delParams.java => BaseGBTModelParams.java} | 2 +- .../flink/ml/common/gbt/BaseGBTParams.java | 2 +- .../gbtregressor/GBTRegressorModel.java | 3 +- .../gbtregressor/GBTRegressorModelParams.java | 28 +++++++++++++++++ .../gbtregressor/GBTRegressorParams.java | 2 +- 9 files changed, 67 insertions(+), 11 deletions(-) create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifierModelParams.java rename flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/{GBTModelParams.java => BaseGBTModelParams.java} (97%) create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressorModelParams.java diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifierModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifierModel.java index 24116bac4..7bf0763fe 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifierModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifierModel.java @@ -44,7 +44,7 @@ /** A Model computed by {@link GBTClassifier}. */ public class GBTClassifierModel extends BaseGBTModel - implements GBTClassifierParams { + implements GBTClassifierModelParams { /** * Loads model data from path. diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifierModelParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifierModelParams.java new file mode 100644 index 000000000..e4625e9e4 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifierModelParams.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.gbtclassifier; + +import org.apache.flink.ml.common.gbt.BaseGBTModelParams; +import org.apache.flink.ml.common.param.HasProbabilityCol; +import org.apache.flink.ml.common.param.HasRawPredictionCol; + +/** + * Parameters for {@link GBTClassifierModel}. + * + * @param The class type of this instance. 
+ */ +public interface GBTClassifierModelParams + extends BaseGBTModelParams, HasRawPredictionCol, HasProbabilityCol {} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifierParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifierParams.java index 20ee450ee..0640f56c5 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifierParams.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifierParams.java @@ -19,8 +19,6 @@ package org.apache.flink.ml.classification.gbtclassifier; import org.apache.flink.ml.common.gbt.BaseGBTParams; -import org.apache.flink.ml.common.param.HasProbabilityCol; -import org.apache.flink.ml.common.param.HasRawPredictionCol; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.param.ParamValidators; import org.apache.flink.ml.param.StringParam; @@ -30,8 +28,7 @@ * * @param The class type of this instance. */ -public interface GBTClassifierParams - extends BaseGBTParams, HasRawPredictionCol, HasProbabilityCol { +public interface GBTClassifierParams extends BaseGBTParams, GBTClassifierModelParams { Param LOSS_TYPE = new StringParam( diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BaseGBTModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BaseGBTModel.java index 9b4e81001..315a49489 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BaseGBTModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BaseGBTModel.java @@ -31,8 +31,7 @@ import java.util.Map; /** Base model computed by {@link GBTClassifier} or {@link GBTRegressor}. */ -public abstract class BaseGBTModel> - implements Model, GBTModelParams { +public abstract class BaseGBTModel> implements Model { protected final Map, Object> paramMap = new HashMap<>(); protected Table modelDataTable; diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BaseGBTModelParams.java similarity index 97% rename from flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelParams.java rename to flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BaseGBTModelParams.java index d6de81591..ec221472d 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelParams.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BaseGBTModelParams.java @@ -34,5 +34,5 @@ * * @param The class type of this instance. */ -public interface GBTModelParams +public interface BaseGBTModelParams extends HasFeaturesCols, HasLabelCol, HasCategoricalCols, HasPredictionCol {} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BaseGBTParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BaseGBTParams.java index de65c77f5..fc75cba8b 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BaseGBTParams.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BaseGBTParams.java @@ -44,7 +44,7 @@ * @param The class type of this instance. 
*/ public interface BaseGBTParams - extends GBTModelParams, + extends BaseGBTModelParams, HasLeafCol, HasWeightCol, HasMaxDepth, diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressorModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressorModel.java index 05c127525..6cadaf696 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressorModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressorModel.java @@ -40,7 +40,8 @@ import java.util.Collections; /** A Model computed by {@link GBTRegressor}. */ -public class GBTRegressorModel extends BaseGBTModel { +public class GBTRegressorModel extends BaseGBTModel + implements GBTRegressorModelParams { /** * Loads model data from path. diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressorModelParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressorModelParams.java new file mode 100644 index 000000000..84fe9c4f8 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressorModelParams.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.regression.gbtregressor; + +import org.apache.flink.ml.common.gbt.BaseGBTModelParams; + +/** + * Parameters for {@link GBTRegressorModel}. + * + * @param The class type of this instance. + */ +public interface GBTRegressorModelParams extends BaseGBTModelParams {} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressorParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressorParams.java index 0f9ed5d27..184cf158a 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressorParams.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressorParams.java @@ -28,7 +28,7 @@ * * @param The class type of this instance. */ -public interface GBTRegressorParams extends BaseGBTParams { +public interface GBTRegressorParams extends BaseGBTParams, GBTRegressorModelParams { Param LOSS_TYPE = new StringParam( "lossType", From 1f592eb05f596439977032ab2be40c7da1f5308a Mon Sep 17 00:00:00 2001 From: Fan Hong Date: Tue, 28 Feb 2023 20:49:04 +0800 Subject: [PATCH 20/47] Remove unused HasLossType. 
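PATCH 19 above folds several Has* mixin interfaces directly into BaseGBTParams, and PATCH 20 below deletes the now-unused HasLossType. For readers unfamiliar with the mixin shape being removed: a shared param lives in a tiny interface extending WithParams, and any stage implementing it inherits the param plus its getter and setter. A sketch with a hypothetical param (WithParams, Param, StringParam, and ParamValidators are the real Flink ML types; everything else is illustrative):

    import org.apache.flink.ml.param.Param;
    import org.apache.flink.ml.param.ParamValidators;
    import org.apache.flink.ml.param.StringParam;
    import org.apache.flink.ml.param.WithParams;

    /** Hypothetical shared-param mixin in the style of the removed HasLossType. */
    public interface HasExampleMode<T> extends WithParams<T> {
        Param<String> EXAMPLE_MODE =
                new StringParam(
                        "exampleMode",
                        "A hypothetical mode switch.",
                        "auto",
                        ParamValidators.inArray("auto", "manual"));

        default String getExampleMode() {
            return get(EXAMPLE_MODE);
        }

        default T setExampleMode(String value) {
            return set(EXAMPLE_MODE, value);
        }
    }

Inlining, as PATCH 19 does, trades that reuse for a smaller public surface: a param used by exactly one hierarchy no longer needs a standalone interface.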
--- .../flink/ml/common/param/HasLossType.java | 44 ------------------- 1 file changed, 44 deletions(-) delete mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasLossType.java diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasLossType.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasLossType.java deleted file mode 100644 index daed708ac..000000000 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasLossType.java +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.ml.common.param; - -import org.apache.flink.ml.param.Param; -import org.apache.flink.ml.param.ParamValidators; -import org.apache.flink.ml.param.StringParam; -import org.apache.flink.ml.param.WithParams; - -/** Interface for the shared maxBins param. */ -public interface HasLossType extends WithParams { - - Param LOSS_TYPE = - new StringParam( - "lossType", - "Loss type.", - "squared", - ParamValidators.inArray("squared", "absolute", "logistic")); - - default String getLossType() { - return get(LOSS_TYPE); - } - - default T setLossType(String value) { - set(LOSS_TYPE, value); - return (T) this; - } -} From bd7917663b58693c0699203e425ce5067015b7e4 Mon Sep 17 00:00:00 2001 From: Fan Hong Date: Tue, 28 Feb 2023 21:02:37 +0800 Subject: [PATCH 21/47] Improve loss func. 
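PATCH 21 below reworks the loss abstraction so that each loss exposes per-sample loss, gradient, and hessian. The reason a gradient boosting trainer wants both derivatives: with gradient sum G and hessian sum H over the instances routed to a leaf, the Newton-step leaf weight is w = -G / (H + lambda), where lambda is an L2 regularizer. A standalone arithmetic sketch using the squared-error formulas from the patch (whether the runner applies exactly this closed form is not shown in these hunks):

    /** Newton leaf weight from squared-error gradients and hessians (illustrative only). */
    public class LeafWeightSketch {
        // Squared error: loss = (label - pred)^2, so gradient = -2 * (label - pred), hessian = 2.
        static double gradient(double pred, double label) {
            return -2. * (label - pred);
        }

        static double hessian(double pred, double label) {
            return 2.;
        }

        public static void main(String[] args) {
            double[] labels = {1.0, 2.0, 4.0};
            double pred = 2.0; // current prediction for every instance in this leaf
            double lambda = 1.0; // illustrative L2 regularization term
            double sumG = 0.;
            double sumH = 0.;
            for (double label : labels) {
                sumG += gradient(pred, label);
                sumH += hessian(pred, label);
            }
            // w = -G / (H + lambda) = -(-2) / (6 + 1), about 0.286: nudge predictions upward.
            System.out.println(-sumG / (sumH + lambda));
        }
    }

For a loss with constant hessian such as squared error this reduces to a damped mean residual; for log loss the per-instance hessian sig * (1 - sig) makes the update a genuine second-order (Newton) step rather than plain gradient descent.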
--- .../ml/common/gbt/defs/TrainContext.java | 4 +- .../apache/flink/ml/common/gbt/loss/Loss.java | 52 ------------------- .../CacheDataCalcLocalHistsOperator.java | 4 +- .../common/gbt/operators/InstanceUpdater.java | 4 +- .../operators/TrainContextInitializer.java | 14 ++--- .../AbsoluteErrorLoss.java} | 32 ++++++++---- .../{gbt/loss => lossfunc}/LogLoss.java | 28 +++++++--- .../flink/ml/common/lossfunc/LossFunc.java | 33 ++++++++++++ .../SquaredErrorLoss.java} | 32 ++++++++---- 9 files changed, 113 insertions(+), 90 deletions(-) delete mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/loss/Loss.java rename flink-ml-lib/src/main/java/org/apache/flink/ml/common/{gbt/loss/AbsoluteError.java => lossfunc/AbsoluteErrorLoss.java} (50%) rename flink-ml-lib/src/main/java/org/apache/flink/ml/common/{gbt/loss => lossfunc}/LogLoss.java (61%) rename flink-ml-lib/src/main/java/org/apache/flink/ml/common/{gbt/loss/SquaredError.java => lossfunc/SquaredErrorLoss.java} (53%) diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/TrainContext.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/TrainContext.java index 32ed86bc3..9dc1b627a 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/TrainContext.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/TrainContext.java @@ -19,7 +19,7 @@ package org.apache.flink.ml.common.gbt.defs; import org.apache.flink.api.java.tuple.Tuple2; -import org.apache.flink.ml.common.gbt.loss.Loss; +import org.apache.flink.ml.common.lossfunc.LossFunc; import java.io.Serializable; import java.util.Random; @@ -43,5 +43,5 @@ public class TrainContext implements Serializable { public Tuple2 labelSumCount; public double prior; - public Loss loss; + public LossFunc loss; } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/loss/Loss.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/loss/Loss.java deleted file mode 100644 index fa6fadf7f..000000000 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/loss/Loss.java +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.ml.common.gbt.loss; - -import java.io.Serializable; - -/** Loss functions for gradient boosting algorithms. */ -public interface Loss extends Serializable { - - /** - * Calculates loss given pred and y. - * - * @param pred prediction value. - * @param y label value. - * @return loss value. - */ - double loss(double pred, double y); - - /** - * Calculates value of gradient given prediction and label. - * - * @param pred prediction value. - * @param y label value. - * @return the value of gradient. 
- */ - double gradient(double pred, double y); - - /** - * Calculates value of second derivative, i.e. hessian, given prediction and label. - * - * @param pred prediction value. - * @param y label value. - * @return the value of second derivative, i.e. hessian. - */ - double hessian(double pred, double y); -} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java index 2f0f9e79a..5c46f7015 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java @@ -28,8 +28,8 @@ import org.apache.flink.ml.common.gbt.defs.LearningNode; import org.apache.flink.ml.common.gbt.defs.PredGradHess; import org.apache.flink.ml.common.gbt.defs.TrainContext; -import org.apache.flink.ml.common.gbt.loss.Loss; import org.apache.flink.ml.common.gbt.typeinfo.BinnedInstanceSerializer; +import org.apache.flink.ml.common.lossfunc.LossFunc; import org.apache.flink.ml.common.sharedstorage.SharedStorageContext; import org.apache.flink.ml.common.sharedstorage.SharedStorageStreamOperator; import org.apache.flink.ml.linalg.SparseVector; @@ -197,7 +197,7 @@ public void onEpochWatermarkIncremented( if (0 == pgh.length) { pgh = new PredGradHess[instances.length]; double prior = trainContext.prior; - Loss loss = trainContext.loss; + LossFunc loss = trainContext.loss; for (int i = 0; i < instances.length; i += 1) { double label = instances[i].label; pgh[i] = diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/InstanceUpdater.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/InstanceUpdater.java index 9b00ac57b..2c157ad10 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/InstanceUpdater.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/InstanceUpdater.java @@ -24,7 +24,7 @@ import org.apache.flink.ml.common.gbt.defs.PredGradHess; import org.apache.flink.ml.common.gbt.defs.Split; import org.apache.flink.ml.common.gbt.defs.TrainContext; -import org.apache.flink.ml.common.gbt.loss.Loss; +import org.apache.flink.ml.common.lossfunc.LossFunc; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -36,7 +36,7 @@ class InstanceUpdater { private static final Logger LOG = LoggerFactory.getLogger(InstanceUpdater.class); private final int subtaskId; - private final Loss loss; + private final LossFunc loss; private final double stepSize; private final double prior; diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TrainContextInitializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TrainContextInitializer.java index 3bf7f3c41..e657eb64a 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TrainContextInitializer.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TrainContextInitializer.java @@ -23,10 +23,10 @@ import org.apache.flink.ml.common.gbt.defs.GbtParams; import org.apache.flink.ml.common.gbt.defs.TaskType; import org.apache.flink.ml.common.gbt.defs.TrainContext; -import org.apache.flink.ml.common.gbt.loss.AbsoluteError; -import org.apache.flink.ml.common.gbt.loss.LogLoss; -import org.apache.flink.ml.common.gbt.loss.Loss; -import 
org.apache.flink.ml.common.gbt.loss.SquaredError; +import org.apache.flink.ml.common.lossfunc.AbsoluteErrorLoss; +import org.apache.flink.ml.common.lossfunc.LogLoss; +import org.apache.flink.ml.common.lossfunc.LossFunc; +import org.apache.flink.ml.common.lossfunc.SquaredErrorLoss; import org.apache.flink.util.Preconditions; import org.slf4j.Logger; @@ -135,15 +135,15 @@ private int getNumBaggingFeatures(int numFeatures) { } } - private Loss getLoss() { + private LossFunc getLoss() { String lossType = params.lossType; switch (lossType) { case "logistic": return LogLoss.INSTANCE; case "squared": - return SquaredError.INSTANCE; + return SquaredErrorLoss.INSTANCE; case "absolute": - return AbsoluteError.INSTANCE; + return AbsoluteErrorLoss.INSTANCE; default: throw new UnsupportedOperationException("Unsupported loss."); } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/loss/AbsoluteError.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/AbsoluteErrorLoss.java similarity index 50% rename from flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/loss/AbsoluteError.java rename to flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/AbsoluteErrorLoss.java index b09240205..a79c3e575 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/loss/AbsoluteError.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/AbsoluteErrorLoss.java @@ -16,31 +16,45 @@ * limitations under the License. */ -package org.apache.flink.ml.common.gbt.loss; +package org.apache.flink.ml.common.lossfunc; + +import org.apache.flink.ml.common.feature.LabeledPointWithWeight; +import org.apache.flink.ml.linalg.DenseVector; /** - * Squared error loss function defined as |y - pred| where y and pred are label and predictions for + * Absolute error loss function defined as |y - pred| where y and pred are label and predictions for * the instance respectively. */ -public class AbsoluteError implements Loss { +public class AbsoluteErrorLoss implements LossFunc { - public static final AbsoluteError INSTANCE = new AbsoluteError(); + public static final AbsoluteErrorLoss INSTANCE = new AbsoluteErrorLoss(); - private AbsoluteError() {} + private AbsoluteErrorLoss() {} @Override - public double loss(double pred, double y) { - double error = y - pred; + public double loss(double pred, double label) { + double error = label - pred; return Math.abs(error); } @Override - public double gradient(double pred, double y) { - return y > pred ? -1. : 1; + public double gradient(double pred, double label) { + return label > pred ? -1. 
: 1; } @Override public double hessian(double pred, double y) { return 0.; } + + @Override + public double computeLoss(LabeledPointWithWeight dataPoint, DenseVector coefficient) { + throw new UnsupportedOperationException(); + } + + @Override + public void computeGradient( + LabeledPointWithWeight dataPoint, DenseVector coefficient, DenseVector cumGradient) { + throw new UnsupportedOperationException(); + } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/loss/LogLoss.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LogLoss.java similarity index 61% rename from flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/loss/LogLoss.java rename to flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LogLoss.java index b2efe8c6c..507d21e19 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/loss/LogLoss.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LogLoss.java @@ -16,7 +16,10 @@ * limitations under the License. */ -package org.apache.flink.ml.common.gbt.loss; +package org.apache.flink.ml.common.lossfunc; + +import org.apache.flink.ml.common.feature.LabeledPointWithWeight; +import org.apache.flink.ml.linalg.DenseVector; import org.apache.commons.math3.analysis.function.Sigmoid; @@ -26,7 +29,7 @@ *
<p>
The binary log loss defined as -y * pred + log(1 + exp(pred)) where y is a label in {0, 1} and * pred is the predicted logit for the sample point. */ -public class LogLoss implements Loss { +public class LogLoss implements LossFunc { public static final LogLoss INSTANCE = new LogLoss(); private final Sigmoid sigmoid = new Sigmoid(); @@ -34,18 +37,29 @@ public class LogLoss implements Loss { private LogLoss() {} @Override - public double loss(double pred, double y) { - return -y * pred + Math.log(1 + Math.exp(pred)); + public double loss(double pred, double label) { + return -label * pred + Math.log(1 + Math.exp(pred)); } @Override - public double gradient(double pred, double y) { - return sigmoid.value(pred) - y; + public double gradient(double pred, double label) { + return sigmoid.value(pred) - label; } @Override - public double hessian(double pred, double y) { + public double hessian(double pred, double label) { double sig = sigmoid.value(pred); return sig * (1 - sig); } + + @Override + public double computeLoss(LabeledPointWithWeight dataPoint, DenseVector coefficient) { + throw new UnsupportedOperationException(); + } + + @Override + public void computeGradient( + LabeledPointWithWeight dataPoint, DenseVector coefficient, DenseVector cumGradient) { + throw new UnsupportedOperationException(); + } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LossFunc.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LossFunc.java index a90967a73..e1326f88e 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LossFunc.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LossFunc.java @@ -48,4 +48,37 @@ public interface LossFunc extends Serializable { */ void computeGradient( LabeledPointWithWeight dataPoint, DenseVector coefficient, DenseVector cumGradient); + + /** + * Calculates loss given pred and label. + * + * @param pred prediction value. + * @param label label value. + * @return loss value. + */ + default double loss(double pred, double label) { + throw new UnsupportedOperationException(); + } + + /** + * Calculates value of gradient given prediction and label. + * + * @param pred prediction value. + * @param label label value. + * @return the value of gradient. + */ + default double gradient(double pred, double label) { + throw new UnsupportedOperationException(); + } + + /** + * Calculates value of second derivative, i.e. hessian, given prediction and label. + * + * @param pred prediction value. + * @param label label value. + * @return the value of second derivative, i.e. hessian. + */ + default double hessian(double pred, double label) { + throw new UnsupportedOperationException(); + } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/loss/SquaredError.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/SquaredErrorLoss.java similarity index 53% rename from flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/loss/SquaredError.java rename to flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/SquaredErrorLoss.java index 14321c024..d9c3edf69 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/loss/SquaredError.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/SquaredErrorLoss.java @@ -16,31 +16,45 @@ * limitations under the License. 
*/ -package org.apache.flink.ml.common.gbt.loss; +package org.apache.flink.ml.common.lossfunc; + +import org.apache.flink.ml.common.feature.LabeledPointWithWeight; +import org.apache.flink.ml.linalg.DenseVector; /** * Squared error loss function defined as (y - pred)^2 where y and pred are label and predictions * for the instance respectively. */ -public class SquaredError implements Loss { +public class SquaredErrorLoss implements LossFunc { - public static final SquaredError INSTANCE = new SquaredError(); + public static final SquaredErrorLoss INSTANCE = new SquaredErrorLoss(); - private SquaredError() {} + private SquaredErrorLoss() {} @Override - public double loss(double pred, double y) { - double error = y - pred; + public double loss(double pred, double label) { + double error = label - pred; return error * error; } @Override - public double gradient(double pred, double y) { - return -2. * (y - pred); + public double gradient(double pred, double label) { + return -2. * (label - pred); } @Override - public double hessian(double pred, double y) { + public double hessian(double pred, double label) { return 2.; } + + @Override + public double computeLoss(LabeledPointWithWeight dataPoint, DenseVector coefficient) { + throw new UnsupportedOperationException(); + } + + @Override + public void computeGradient( + LabeledPointWithWeight dataPoint, DenseVector coefficient, DenseVector cumGradient) { + throw new UnsupportedOperationException(); + } } From fdaa732ad6d33420e81cbe9219cc2259ea70a2e0 Mon Sep 17 00:00:00 2001 From: Fan Hong Date: Wed, 1 Mar 2023 10:53:48 +0800 Subject: [PATCH 22/47] Refactor params, BoostingStrategy, and Distributor. --- .../apache/flink/ml/util}/Distributor.java | 2 +- .../gbtclassifier/GBTClassifier.java | 2 +- .../flink/ml/common/gbt/BaseGBTParams.java | 177 +++++++++++++++--- .../ml/common/gbt/BoostIterationBody.java | 10 +- .../flink/ml/common/gbt/GBTModelData.java | 6 +- .../apache/flink/ml/common/gbt/GBTRunner.java | 112 ++++++----- .../flink/ml/common/gbt/Preprocess.java | 35 ++-- .../ml/common/gbt/defs/BoostingStrategy.java | 111 +++++++++++ .../flink/ml/common/gbt/defs/GbtParams.java | 56 ------ .../defs/LossType.java} | 22 +-- .../apache/flink/ml/common/gbt/defs/Node.java | 8 +- .../ml/common/gbt/defs/TrainContext.java | 30 ++- .../CacheDataCalcLocalHistsOperator.java | 18 +- .../ml/common/gbt/operators/HistBuilder.java | 4 +- .../common/gbt/operators/InstanceUpdater.java | 2 +- .../ml/common/gbt/operators/NodeSplitter.java | 4 +- .../ml/common/gbt/operators/SplitFinder.java | 10 +- .../gbt/operators/TerminationOperator.java | 2 +- .../operators/TrainContextInitializer.java | 49 +++-- .../splitter/CategoricalFeatureSplitter.java | 7 +- .../splitter/ContinuousFeatureSplitter.java | 9 +- .../common/gbt/splitter/FeatureSplitter.java | 15 +- .../splitter/HistogramFeatureSplitter.java | 11 +- .../ml/common/lossfunc/AbsoluteErrorLoss.java | 60 ------ .../param/HasFeatureSubsetStrategy.java | 42 ----- .../flink/ml/common/param/HasMaxBins.java | 42 ----- .../flink/ml/common/param/HasMaxDepth.java | 38 ---- .../flink/ml/common/param/HasMinInfoGain.java | 42 ----- .../common/param/HasMinInstancesPerNode.java | 42 ----- .../param/HasMinWeightFractionPerNode.java | 42 ----- .../flink/ml/common/param/HasStepSize.java | 42 ----- .../ml/common/param/HasSubsamplingRate.java | 42 ----- .../param/HasValidationIndicatorCol.java | 40 ---- .../ml/common/param/HasValidationTol.java | 43 ----- .../regression/gbtregressor/GBTRegressor.java | 2 +- 
.../gbtregressor/GBTRegressorParams.java | 5 +- .../flink/ml/common/gbt/GBTRunnerTest.java | 61 +++--- .../flink/ml/common/gbt/PreprocessTest.java | 30 +-- 38 files changed, 500 insertions(+), 775 deletions(-) rename {flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs => flink-ml-core/src/main/java/org/apache/flink/ml/util}/Distributor.java (98%) create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/BoostingStrategy.java delete mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/GbtParams.java rename flink-ml-lib/src/main/java/org/apache/flink/ml/common/{param/HasLeafCol.java => gbt/defs/LossType.java} (59%) delete mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/AbsoluteErrorLoss.java delete mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasFeatureSubsetStrategy.java delete mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMaxBins.java delete mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMaxDepth.java delete mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMinInfoGain.java delete mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMinInstancesPerNode.java delete mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMinWeightFractionPerNode.java delete mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasStepSize.java delete mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasSubsamplingRate.java delete mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasValidationIndicatorCol.java delete mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasValidationTol.java diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Distributor.java b/flink-ml-core/src/main/java/org/apache/flink/ml/util/Distributor.java similarity index 98% rename from flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Distributor.java rename to flink-ml-core/src/main/java/org/apache/flink/ml/util/Distributor.java index 0061dfc61..b985adfb8 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Distributor.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/util/Distributor.java @@ -16,7 +16,7 @@ * limitations under the License. */ -package org.apache.flink.ml.common.gbt.defs; +package org.apache.flink.ml.util; import java.io.Serializable; diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifier.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifier.java index 2ff827ea9..980af398d 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifier.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifier.java @@ -61,7 +61,7 @@ public GBTClassifierModel fit(Table... 
inputs) { Preconditions.checkArgument(inputs.length == 1); StreamTableEnvironment tEnv = (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); - DataStream modelData = GBTRunner.trainClassifier(inputs[0], this); + DataStream modelData = GBTRunner.train(inputs[0], this); GBTClassifierModel model = new GBTClassifierModel(); model.setModelData(tEnv.fromDataStream(modelData).renameColumns($("f0").as("modelData"))); ReadWriteUtils.updateExistingParams(model, getParamMap()); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BaseGBTParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BaseGBTParams.java index fc75cba8b..d2ae5b446 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BaseGBTParams.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BaseGBTParams.java @@ -18,47 +18,25 @@ package org.apache.flink.ml.common.gbt; -import org.apache.flink.ml.common.param.HasFeatureSubsetStrategy; -import org.apache.flink.ml.common.param.HasLeafCol; -import org.apache.flink.ml.common.param.HasMaxBins; -import org.apache.flink.ml.common.param.HasMaxDepth; import org.apache.flink.ml.common.param.HasMaxIter; -import org.apache.flink.ml.common.param.HasMinInfoGain; -import org.apache.flink.ml.common.param.HasMinInstancesPerNode; -import org.apache.flink.ml.common.param.HasMinWeightFractionPerNode; import org.apache.flink.ml.common.param.HasSeed; -import org.apache.flink.ml.common.param.HasStepSize; -import org.apache.flink.ml.common.param.HasSubsamplingRate; -import org.apache.flink.ml.common.param.HasValidationIndicatorCol; -import org.apache.flink.ml.common.param.HasValidationTol; import org.apache.flink.ml.common.param.HasWeightCol; import org.apache.flink.ml.param.DoubleParam; +import org.apache.flink.ml.param.IntParam; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.StringParam; /** * Common parameters for GBT classifier and regressor. * - *
<p>
TODO: support param thresholds, impurity (actually meaningless) + *
<p>
NOTE: Features related with {@link #WEIGHT_COL}, {@link #LEAF_COL}, and {@link + * #VALIDATION_INDICATOR_COL} are not implemented yet. * * @param The class type of this instance. */ public interface BaseGBTParams - extends BaseGBTModelParams, - HasLeafCol, - HasWeightCol, - HasMaxDepth, - HasMaxBins, - HasMinInstancesPerNode, - HasMinWeightFractionPerNode, - HasMinInfoGain, - HasMaxIter, - HasStepSize, - HasSeed, - HasSubsamplingRate, - HasFeatureSubsetStrategy, - HasValidationIndicatorCol, - HasValidationTol { + extends BaseGBTModelParams, HasWeightCol, HasMaxIter, HasSeed { Param REG_LAMBDA = new DoubleParam( "regLambda", @@ -71,6 +49,63 @@ public interface BaseGBTParams "L2 regularization term for the weights of leaves.", 1., ParamValidators.gtEq(0)); + Param LEAF_COL = + new StringParam("leafCol", "Predicted leaf index of each instance in each tree.", null); + Param MAX_DEPTH = + new IntParam("maxDepth", "Maximum depth of the tree.", 5, ParamValidators.gtEq(1)); + Param MAX_BINS = + new IntParam( + "maxBins", + "Maximum number of bins used for discretizing continuous features.", + 32, + ParamValidators.gtEq(2)); + Param MIN_INSTANCES_PER_NODE = + new IntParam( + "minInstancesPerNode", + "Minimum number of instances each node must have. If a split causes the left or right child to have fewer instances than minInstancesPerNode, the split is invalid.", + 1, + ParamValidators.gtEq(1)); + Param MIN_WEIGHT_FRACTION_PER_NODE = + new DoubleParam( + "minWeightFractionPerNode", + "Minimum fraction of the weighted sample count that each node must have. If a split causes the left or right child to have a smaller fraction of the total weight than minWeightFractionPerNode, the split is invalid.", + 0., + ParamValidators.gtEq(0.)); + Param MIN_INFO_GAIN = + new DoubleParam( + "minInfoGain", + "Minimum information gain for a split to be considered valid.", + 0., + ParamValidators.gtEq(0.)); + Param STEP_SIZE = + new DoubleParam( + "stepSize", + "Step size for shrinking the contribution of each estimator.", + 0.1, + ParamValidators.inRange(0., 1.)); + Param SUBSAMPLING_RATE = + new DoubleParam( + "subsamplingRate", + "Fraction of the training data used for learning one tree.", + 1., + ParamValidators.inRange(0., 1.)); + Param FEATURE_SUBSET_STRATEGY = + new StringParam( + "featureSubsetStrategy", + "The number of features to consider for splits at each tree node. 
Supports \"auto\", \"all\", \"onethird\", \"sqrt\", \"log2\", (0.0 - 1.0], and [1 - n].", + "auto", + ParamValidators.notNull()); + Param VALIDATION_INDICATOR_COL = + new StringParam( + "validationIndicatorCol", + "The name of the column that indicates whether each row is for training or for validation.", + null); + Param VALIDATION_TOL = + new DoubleParam( + "validationTol", + "Threshold for early stopping when fitting with validation is used.", + .01, + ParamValidators.gtEq(0)); default double getRegLambda() { return get(REG_LAMBDA); @@ -87,4 +122,92 @@ default double getRegGamma() { default T setRegGamma(Double value) { return set(REG_GAMMA, value); } + + default String getLeafCol() { + return get(LEAF_COL); + } + + default T setLeafCol(String value) { + return set(LEAF_COL, value); + } + + default int getMaxDepth() { + return get(MAX_DEPTH); + } + + default T setMaxDepth(int value) { + return set(MAX_DEPTH, value); + } + + default int getMaxBins() { + return get(MAX_BINS); + } + + default T setMaxBins(int value) { + return set(MAX_BINS, value); + } + + default int getMinInstancesPerNode() { + return get(MIN_INSTANCES_PER_NODE); + } + + default T setMinInstancesPerNode(int value) { + return set(MIN_INSTANCES_PER_NODE, value); + } + + default double getMinWeightFractionPerNode() { + return get(MIN_WEIGHT_FRACTION_PER_NODE); + } + + default T setMinWeightFractionPerNode(Double value) { + return set(MIN_WEIGHT_FRACTION_PER_NODE, value); + } + + default double getMinInfoGain() { + return get(MIN_INFO_GAIN); + } + + default T setMinInfoGain(Double value) { + return set(MIN_INFO_GAIN, value); + } + + default double getStepSize() { + return get(STEP_SIZE); + } + + default T setStepSize(Double value) { + return set(STEP_SIZE, value); + } + + default double getSubsamplingRate() { + return get(SUBSAMPLING_RATE); + } + + default T setSubsamplingRate(Double value) { + return set(SUBSAMPLING_RATE, value); + } + + default String getFeatureSubsetStrategy() { + return get(FEATURE_SUBSET_STRATEGY); + } + + default T setFeatureSubsetStrategy(String value) { + return set(FEATURE_SUBSET_STRATEGY, value); + } + + default String getValidationIndicatorCol() { + return get(VALIDATION_INDICATOR_COL); + } + + default T setValidationIndicatorCol(String value) { + return set(VALIDATION_INDICATOR_COL, value); + } + + default double getValidationTol() { + return get(VALIDATION_TOL); + } + + default T setValidationTol(Double value) { + return set(VALIDATION_TOL, value); + } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BoostIterationBody.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BoostIterationBody.java index 549deb292..b54cb9b44 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BoostIterationBody.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BoostIterationBody.java @@ -28,7 +28,7 @@ import org.apache.flink.iteration.DataStreamList; import org.apache.flink.iteration.IterationBody; import org.apache.flink.iteration.IterationBodyResult; -import org.apache.flink.ml.common.gbt.defs.GbtParams; +import org.apache.flink.ml.common.gbt.defs.BoostingStrategy; import org.apache.flink.ml.common.gbt.defs.Histogram; import org.apache.flink.ml.common.gbt.defs.Splits; import org.apache.flink.ml.common.gbt.defs.TrainContext; @@ -59,10 +59,10 @@ * of data and row-store storage of instances. 
*/ class BoostIterationBody implements IterationBody { - private final GbtParams gbtParams; + private final BoostingStrategy strategy; - public BoostIterationBody(GbtParams gbtParams) { - this.gbtParams = gbtParams; + public BoostIterationBody(BoostingStrategy strategy) { + this.strategy = strategy; } private SharedStorageBody.SharedStorageBodyResult sharedStorageBody( @@ -77,7 +77,7 @@ private SharedStorageBody.SharedStorageBodyResult sharedStorageBody( // In 1st round, cache all data. For all rounds calculate local histogram based on // current tree layer. CacheDataCalcLocalHistsOperator cacheDataCalcLocalHistsOp = - new CacheDataCalcLocalHistsOperator(gbtParams); + new CacheDataCalcLocalHistsOperator(strategy); SingleOutputStreamOperator localHists = data.connect(trainContext) .transform( diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelData.java index 555d9b6c3..8b93b0272 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelData.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelData.java @@ -116,10 +116,10 @@ public static GBTModelData from(TrainContext trainContext, List> allT } } return new GBTModelData( - trainContext.params.taskType.name(), - trainContext.params.isInputVector, + trainContext.strategy.taskType.name(), + trainContext.strategy.isInputVector, trainContext.prior, - trainContext.params.stepSize, + trainContext.strategy.stepSize, allTrees, categoryToIdMaps, featureIdToBinEdges, diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTRunner.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTRunner.java index 5f3ab5cc7..0709dd37c 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTRunner.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTRunner.java @@ -26,19 +26,20 @@ import org.apache.flink.iteration.IterationConfig; import org.apache.flink.iteration.Iterations; import org.apache.flink.iteration.ReplayableDataStreamList; -import org.apache.flink.ml.classification.gbtclassifier.GBTClassifierParams; +import org.apache.flink.ml.classification.gbtclassifier.GBTClassifier; import org.apache.flink.ml.common.broadcast.BroadcastUtils; import org.apache.flink.ml.common.datastream.DataStreamUtils; import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.common.gbt.defs.BoostingStrategy; import org.apache.flink.ml.common.gbt.defs.FeatureMeta; -import org.apache.flink.ml.common.gbt.defs.GbtParams; +import org.apache.flink.ml.common.gbt.defs.LossType; import org.apache.flink.ml.common.gbt.defs.TaskType; import org.apache.flink.ml.common.gbt.defs.TrainContext; import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; import org.apache.flink.ml.linalg.typeinfo.SparseVectorTypeInfo; import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; import org.apache.flink.ml.param.Param; -import org.apache.flink.ml.regression.gbtregressor.GBTRegressorParams; +import org.apache.flink.ml.regression.gbtregressor.GBTRegressor; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.table.api.Table; import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; @@ -61,22 +62,13 @@ /** Runs a gradient boosting trees implementation. 
*/ public class GBTRunner { - public static DataStream trainClassifier(Table data, BaseGBTParams estimator) { - return train(data, estimator, TaskType.CLASSIFICATION); - } - - public static DataStream trainRegressor(Table data, BaseGBTParams estimator) { - return train(data, estimator, TaskType.REGRESSION); - } - private static boolean isVectorType(TypeInformation typeInfo) { return typeInfo instanceof DenseVectorTypeInfo || typeInfo instanceof SparseVectorTypeInfo || typeInfo instanceof VectorTypeInfo; } - static DataStream train( - Table data, BaseGBTParams estimator, TaskType taskType) { + public static DataStream train(Table data, BaseGBTParams estimator) { String[] featuresCols = estimator.getFeaturesCols(); TypeInformation[] featuresTypes = Arrays.stream(featuresCols) @@ -90,29 +82,29 @@ static DataStream train( } boolean isInputVector = featuresCols.length == 1 && isVectorType(featuresTypes[0]); - return train(data, fromEstimator(estimator, isInputVector, taskType)); + return train(data, getStrategy(estimator, isInputVector)); } /** Trains a gradient boosting tree model with given data and parameters. */ - static DataStream train(Table dataTable, GbtParams p) { + static DataStream train(Table dataTable, BoostingStrategy strategy) { StreamTableEnvironment tEnv = (StreamTableEnvironment) ((TableImpl) dataTable).getTableEnvironment(); Tuple2> preprocessResult = - p.isInputVector - ? Preprocess.preprocessVecCol(dataTable, p) - : Preprocess.preprocessCols(dataTable, p); + strategy.isInputVector + ? Preprocess.preprocessVecCol(dataTable, strategy) + : Preprocess.preprocessCols(dataTable, strategy); dataTable = preprocessResult.f0; DataStream featureMeta = preprocessResult.f1; DataStream data = tEnv.toDataStream(dataTable); DataStream> labelSumCount = - DataStreamUtils.aggregate(data, new LabelSumCountFunction(p.labelCol)); - return boost(dataTable, p, featureMeta, labelSumCount); + DataStreamUtils.aggregate(data, new LabelSumCountFunction(strategy.labelCol)); + return boost(dataTable, strategy, featureMeta, labelSumCount); } private static DataStream boost( Table dataTable, - GbtParams p, + BoostingStrategy strategy, DataStream featureMeta, DataStream> labelSumCount) { StreamTableEnvironment tEnv = @@ -134,7 +126,7 @@ private static DataStream boost( DataStream input = (DataStream) (inputs.get(0)); return input.map( new InitTrainContextFunction( - featureMetaBcName, labelSumCountBcName, p)); + featureMetaBcName, labelSumCountBcName, strategy)); }); DataStream data = tEnv.toDataStream(dataTable); @@ -145,12 +137,11 @@ private static DataStream boost( IterationConfig.newBuilder() .setOperatorLifeCycle(IterationConfig.OperatorLifeCycle.ALL_ROUND) .build(), - new BoostIterationBody(p)); + new BoostIterationBody(strategy)); return dataStreamList.get(0); } - public static GbtParams fromEstimator( - BaseGBTParams estimator, boolean isInputVector, TaskType taskType) { + public static BoostingStrategy getStrategy(BaseGBTParams estimator, boolean isInputVector) { final Map, Object> paramMap = estimator.getParamMap(); final Set> unsupported = new HashSet<>( @@ -171,62 +162,69 @@ public static GbtParams fromEstimator( .collect(Collectors.joining(", ")))); } - GbtParams p = new GbtParams(); - p.taskType = taskType; - p.featuresCols = estimator.getFeaturesCols(); - p.isInputVector = isInputVector; - p.labelCol = estimator.getLabelCol(); - p.weightCol = estimator.getWeightCol(); - p.categoricalCols = estimator.getCategoricalCols(); + BoostingStrategy strategy = new BoostingStrategy(); + 
strategy.featuresCols = estimator.getFeaturesCols(); + strategy.isInputVector = isInputVector; + strategy.labelCol = estimator.getLabelCol(); + strategy.categoricalCols = estimator.getCategoricalCols(); - p.maxDepth = estimator.getMaxDepth(); - p.maxBins = estimator.getMaxBins(); - p.minInstancesPerNode = estimator.getMinInstancesPerNode(); - p.minWeightFractionPerNode = estimator.getMinWeightFractionPerNode(); - p.minInfoGain = estimator.getMinInfoGain(); - p.maxIter = estimator.getMaxIter(); - p.stepSize = estimator.getStepSize(); - p.seed = estimator.getSeed(); - p.subsamplingRate = estimator.getSubsamplingRate(); - p.featureSubsetStrategy = estimator.getFeatureSubsetStrategy(); - p.validationTol = estimator.getValidationTol(); - p.gamma = estimator.getRegGamma(); - p.lambda = estimator.getRegLambda(); + strategy.maxDepth = estimator.getMaxDepth(); + strategy.maxBins = estimator.getMaxBins(); + strategy.minInstancesPerNode = estimator.getMinInstancesPerNode(); + strategy.minWeightFractionPerNode = estimator.getMinWeightFractionPerNode(); + strategy.minInfoGain = estimator.getMinInfoGain(); + strategy.maxIter = estimator.getMaxIter(); + strategy.stepSize = estimator.getStepSize(); + strategy.seed = estimator.getSeed(); + strategy.subsamplingRate = estimator.getSubsamplingRate(); + strategy.featureSubsetStrategy = estimator.getFeatureSubsetStrategy(); + strategy.regGamma = estimator.getRegGamma(); + strategy.regLambda = estimator.getRegLambda(); - if (TaskType.CLASSIFICATION.equals(p.taskType)) { - p.lossType = estimator.get(GBTClassifierParams.LOSS_TYPE); + String lossTypeStr; + if (estimator instanceof GBTClassifier) { + strategy.taskType = TaskType.CLASSIFICATION; + lossTypeStr = ((GBTClassifier) estimator).getLossType(); + } else if (estimator instanceof GBTRegressor) { + strategy.taskType = TaskType.REGRESSION; + lossTypeStr = ((GBTRegressor) estimator).getLossType(); } else { - p.lossType = estimator.get(GBTRegressorParams.LOSS_TYPE); + throw new IllegalArgumentException( + String.format( + "Unexpected type of estimator: %s.", + estimator.getClass().getSimpleName())); } - p.maxNumLeaves = 1 << p.maxDepth - 1; - p.useMissing = true; - return p; + strategy.lossType = LossType.valueOf(lossTypeStr.toUpperCase()); + strategy.maxNumLeaves = 1 << strategy.maxDepth - 1; + strategy.useMissing = true; + return strategy; } private static class InitTrainContextFunction extends RichMapFunction { private final String featureMetaBcName; private final String labelSumCountBcName; - private final GbtParams p; + private final BoostingStrategy strategy; private InitTrainContextFunction( - String featureMetaBcName, String labelSumCountBcName, GbtParams p) { + String featureMetaBcName, String labelSumCountBcName, BoostingStrategy strategy) { this.featureMetaBcName = featureMetaBcName; this.labelSumCountBcName = labelSumCountBcName; - this.p = p; + this.strategy = strategy; } @Override public TrainContext map(Integer value) { TrainContext trainContext = new TrainContext(); - trainContext.params = p; + trainContext.strategy = strategy; trainContext.featureMetas = getRuntimeContext() .getBroadcastVariable(featureMetaBcName) .toArray(new FeatureMeta[0]); - if (!trainContext.params.isInputVector) { + if (!trainContext.strategy.isInputVector) { Arrays.sort( trainContext.featureMetas, - Comparator.comparing(d -> ArrayUtils.indexOf(p.featuresCols, d.name))); + Comparator.comparing( + d -> ArrayUtils.indexOf(strategy.featuresCols, d.name))); } trainContext.numFeatures = trainContext.featureMetas.length; 
trainContext.labelSumCount = diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/Preprocess.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/Preprocess.java index 890804a11..5e43cf846 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/Preprocess.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/Preprocess.java @@ -23,8 +23,8 @@ import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.typeutils.RowTypeInfo; import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.common.gbt.defs.BoostingStrategy; import org.apache.flink.ml.common.gbt.defs.FeatureMeta; -import org.apache.flink.ml.common.gbt.defs.GbtParams; import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizer; import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizerModel; import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizerModelData; @@ -62,9 +62,10 @@ class Preprocess { * Maps continuous and categorical columns to integers inplace using quantile discretizer and * string indexer respectively, and obtains meta information for all columns. */ - static Tuple2> preprocessCols(Table dataTable, GbtParams p) { + static Tuple2> preprocessCols( + Table dataTable, BoostingStrategy strategy) { - final String[] relatedCols = ArrayUtils.add(p.featuresCols, p.labelCol); + final String[] relatedCols = ArrayUtils.add(strategy.featuresCols, strategy.labelCol); dataTable = dataTable.select( Arrays.stream(relatedCols) @@ -72,21 +73,24 @@ static Tuple2> preprocessCols(Table dataTable, Gb .toArray(ApiExpression[]::new)); // Maps continuous columns to integers, and obtain corresponding discretizer model. - String[] continuousCols = ArrayUtils.removeElements(p.featuresCols, p.categoricalCols); + String[] continuousCols = + ArrayUtils.removeElements(strategy.featuresCols, strategy.categoricalCols); Tuple2> continuousMappedDataAndModelData = - discretizeContinuousCols(dataTable, continuousCols, p.maxBins); + discretizeContinuousCols(dataTable, continuousCols, strategy.maxBins); dataTable = continuousMappedDataAndModelData.f0; DataStream continuousFeatureMeta = buildContinuousFeatureMeta(continuousMappedDataAndModelData.f1, continuousCols); // Maps categorical columns to integers, and obtain string indexer model. DataStream categoricalFeatureMeta; - if (p.categoricalCols.length > 0) { + if (strategy.categoricalCols.length > 0) { String[] mappedCategoricalCols = - Arrays.stream(p.categoricalCols).map(d -> d + "_output").toArray(String[]::new); + Arrays.stream(strategy.categoricalCols) + .map(d -> d + "_output") + .toArray(String[]::new); StringIndexer stringIndexer = new StringIndexer() - .setInputCols(p.categoricalCols) + .setInputCols(strategy.categoricalCols) .setOutputCols(mappedCategoricalCols) .setHandleInvalid("keep"); StringIndexerModel stringIndexerModel = stringIndexer.fit(dataTable); @@ -96,7 +100,7 @@ static Tuple2> preprocessCols(Table dataTable, Gb buildCategoricalFeatureMeta( StringIndexerModelData.getModelDataStream( stringIndexerModel.getModelData()[0]), - p.categoricalCols); + strategy.categoricalCols); } else { categoricalFeatureMeta = continuousFeatureMeta @@ -106,9 +110,11 @@ static Tuple2> preprocessCols(Table dataTable, Gb // Rename results columns. 
ApiExpression[] dropColumnExprs = - Arrays.stream(p.categoricalCols).map(Expressions::$).toArray(ApiExpression[]::new); + Arrays.stream(strategy.categoricalCols) + .map(Expressions::$) + .toArray(ApiExpression[]::new); ApiExpression[] renameColumnExprs = - Arrays.stream(p.categoricalCols) + Arrays.stream(strategy.categoricalCols) .map(d -> $(d + "_output").as(d)) .toArray(ApiExpression[]::new); dataTable = dataTable.dropColumns(dropColumnExprs).renameColumns(renameColumnExprs); @@ -120,10 +126,11 @@ static Tuple2> preprocessCols(Table dataTable, Gb * Maps features values in vectors to integers using quantile discretizer, and obtains meta * information for all features. */ - static Tuple2> preprocessVecCol(Table dataTable, GbtParams p) { - dataTable = dataTable.select($(p.featuresCols[0]), $(p.labelCol)); + static Tuple2> preprocessVecCol( + Table dataTable, BoostingStrategy strategy) { + dataTable = dataTable.select($(strategy.featuresCols[0]), $(strategy.labelCol)); Tuple2> mappedDataAndModelData = - discretizeVectorCol(dataTable, p.featuresCols[0], p.maxBins); + discretizeVectorCol(dataTable, strategy.featuresCols[0], strategy.maxBins); dataTable = mappedDataAndModelData.f0; DataStream featureMeta = buildContinuousFeatureMeta(mappedDataAndModelData.f1, null); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/BoostingStrategy.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/BoostingStrategy.java new file mode 100644 index 000000000..c64908241 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/BoostingStrategy.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.defs; + +import java.io.Serializable; + +/** Configurations for {@link org.apache.flink.ml.common.gbt.GBTRunner}. */ +public class BoostingStrategy implements Serializable { + + /** Indicates the task is classification or regression. */ + public TaskType taskType; + + /** + * Indicates whether the features are in one column of vector type or multiple columns of + * non-vector types. + */ + public boolean isInputVector; + + /** + * Names of features columns used for training. Contains only 1 column name when `isInputVector` + * is `true`. + */ + public String[] featuresCols; + + /** Name of label column. */ + public String labelCol; + + /** + * Names of columns which should be treated as categorical features, when `isInputVector` is + * `false`. + */ + public String[] categoricalCols; + + /** + * Max depth of the tree (root node is the 1st level). Depth 1 means 1 leaf node, i.e., the root + * node; Depth 2 means 1 internal node + 2 leaf nodes; etc. 
+     */
+    public int maxDepth;
+
+    /** Maximum number of bins used for discretizing continuous features. */
+    public int maxBins;
+
+    /**
+     * Minimum number of instances each node must have. If a split causes the left or right child
+     * to have fewer instances than minInstancesPerNode, the split is invalid.
+     */
+    public int minInstancesPerNode;
+
+    /**
+     * Minimum fraction of the weighted sample count that each node must have. If a split causes
+     * the left or right child to have a smaller fraction of the total weight than
+     * minWeightFractionPerNode, the split is invalid.
+     *
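+     * <p>For illustration: with minWeightFractionPerNode = 0.1 and unit weights, a split leaving
+     * either child with fewer than 10% of the samples is invalid.
+     *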
+     * <p>NOTE: Weight column is not supported right now, so all samples have equal weights.
+     */
+    public double minWeightFractionPerNode;
+
+    /** Minimum information gain for a split to be considered valid. */
+    public double minInfoGain;
+
+    /** Maximum number of iterations of boosting, i.e., the number of trees in the final model. */
+    public int maxIter;
+
+    /** Step size for shrinking the contribution of each estimator. */
+    public double stepSize;
+
+    /** The random seed used for subsampling instances and features. */
+    public long seed;
+
+    /** Fraction of the training data used for learning one tree. */
+    public double subsamplingRate;
+
+    /**
+     * Strategy for sampling the features used when learning one tree. Supports "auto", "all",
+     * "onethird", "sqrt", "log2", a fraction in (0.0, 1.0], and an integer in [1, n].
+     */
+    public String featureSubsetStrategy;
+
+    /** L2 regularization term for the weights of leaves. */
+    public double regLambda;
+
+    /** Regularization term for the number of leaves. */
+    public double regGamma;
+
+    /** The type of loss used in boosting. */
+    public LossType lossType;
+
+    // Derived parameters.
+    /** Maximum number of leaves. */
+    public int maxNumLeaves;
+    /** Whether to consider missing values in the model. Always `true` right now. */
+    public boolean useMissing;
+
+    public BoostingStrategy() {}
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/GbtParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/GbtParams.java
deleted file mode 100644
index 82d536bcc..000000000
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/GbtParams.java
+++ /dev/null
@@ -1,56 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.ml.common.gbt.defs;
-
-import java.io.Serializable;
-
-/** Internal parameters of a gradient boosting trees algorithm. */
-public class GbtParams implements Serializable {
-    public TaskType taskType;
-
-    // Parameters related with input data.
-    public String[] featuresCols;
-    public boolean isInputVector;
-    public String labelCol;
-    public String weightCol;
-    public String[] categoricalCols;
-
-    // Parameters related with algorithms.
-    public int maxDepth;
-    public int maxBins;
-    public int minInstancesPerNode;
-    public double minWeightFractionPerNode;
-    public double minInfoGain;
-    public int maxIter;
-    public double stepSize;
-    public long seed;
-    public double subsamplingRate;
-    public String featureSubsetStrategy;
-    public double validationTol;
-    public double lambda;
-    public double gamma;
-
-    // Derived parameters.
-    public String lossType;
-    public int maxNumLeaves;
-    // useMissing is always true right now.
-    public boolean useMissing;
-
-    public GbtParams() {}
-}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasLeafCol.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/LossType.java
similarity index 59%
rename from flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasLeafCol.java
rename to flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/LossType.java
index 52dd29e73..58f047b57 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasLeafCol.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/LossType.java
@@ -16,22 +16,10 @@
  * limitations under the License.
  */
 
-package org.apache.flink.ml.common.param;
+package org.apache.flink.ml.common.gbt.defs;
 
-import org.apache.flink.ml.param.Param;
-import org.apache.flink.ml.param.StringParam;
-import org.apache.flink.ml.param.WithParams;
-
-/** Interface for shared param leaf column. */
-public interface HasLeafCol<T> extends WithParams<T> {
-    Param<String> LEAF_COL =
-            new StringParam("leafCol", "Predicted leaf index of each instance in each tree.", null);
-
-    default String getLeafCol() {
-        return get(LEAF_COL);
-    }
-
-    default T setLeafCol(String value) {
-        return set(LEAF_COL, value);
-    }
+/** Indicates the type of loss. */
+public enum LossType {
+    SQUARED,
+    LOGISTIC,
 }
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Node.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Node.java
index 82121fdf6..c83ab2a07 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Node.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Node.java
@@ -23,7 +23,13 @@
 
 import java.io.Serializable;
 
-/** Tree node in the decision tree that will be serialized to json and deserialized from json. */
+/**
+ * Represents a tree node in a decision tree.
+ *
+ * <p>NOTE: This class should be used together with a linear indexable structure, e.g., a list or
+ * an array, which stores all tree nodes, because {@link #left} and {@link #right} are indices of
+ * nodes in the linear structure.
+ */
 @TypeInfo(NodeTypeInfoFactory.class)
 public class Node implements Serializable {
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/TrainContext.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/TrainContext.java
index 9dc1b627a..b66a78b7c 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/TrainContext.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/TrainContext.java
@@ -24,24 +24,50 @@
 import java.io.Serializable;
 import java.util.Random;
 
-/** Stores the training context. */
+/**
+ * Stores necessary static context information for training. Subtasks of co-located operators
+ * scheduled in the same TaskManager share the same context.
+ */
 public class TrainContext implements Serializable {
+    /** Subtask ID of co-located operators. */
     public int subtaskId;
+
+    /** Number of subtasks of co-located operators. */
     public int numSubtasks;
 
-    public GbtParams params;
+    /** Configuration of the boosting process. */
+    public BoostingStrategy strategy;
+
+    /** Number of instances. */
     public int numInstances;
+
+    /** Number of bagging instances used for training one tree. */
     public int numBaggingInstances;
+
+    /** Randomizer for sampling instances. */
     public Random instanceRandomizer;
 
+    /** Number of features. */
     public int numFeatures;
+
+    /** Number of bagging features tested for splitting one node. */
     public int numBaggingFeatures;
+
+    /** Randomizer for sampling features. */
     public Random featureRandomizer;
 
+    /** Meta information of every feature. */
     public FeatureMeta[] featureMetas;
+
+    /** Number of bins for every feature. */
    public int[] numFeatureBins;
 
+    /** Sum and count of labels of all samples. */
    public Tuple2<Double, Double> labelSumCount;
+
+    /** The prior value for prediction. */
    public double prior;
+
+    /** The loss function.
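+     * Chosen from {@code strategy.lossType} by the TrainContextInitializer: LOGISTIC maps to
+     * LogLoss and SQUARED maps to SquaredErrorLoss.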
*/ public LossFunc loss; } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java index 5c46f7015..ddb546ba2 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java @@ -23,7 +23,7 @@ import org.apache.flink.iteration.datacache.nonkeyed.ListStateWithCache; import org.apache.flink.iteration.operator.OperatorStateUtils; import org.apache.flink.ml.common.gbt.defs.BinnedInstance; -import org.apache.flink.ml.common.gbt.defs.GbtParams; +import org.apache.flink.ml.common.gbt.defs.BoostingStrategy; import org.apache.flink.ml.common.gbt.defs.Histogram; import org.apache.flink.ml.common.gbt.defs.LearningNode; import org.apache.flink.ml.common.gbt.defs.PredGradHess; @@ -62,7 +62,7 @@ public class CacheDataCalcLocalHistsOperator extends AbstractStreamOperator streamRecord) throws Exception { Row row = streamRecord.getValue(); BinnedInstance instance = new BinnedInstance(); instance.weight = 1.; - instance.label = row.getFieldAs(gbtParams.labelCol); + instance.label = row.getFieldAs(strategy.labelCol); - if (gbtParams.isInputVector) { - Vector vec = row.getFieldAs(gbtParams.featuresCols[0]); + if (strategy.isInputVector) { + Vector vec = row.getFieldAs(strategy.featuresCols[0]); SparseVector sv = vec.toSparse(); instance.featureIds = sv.indices.length == sv.size() ? null : sv.indices; instance.featureValues = Arrays.stream(sv.values).mapToInt(d -> (int) d).toArray(); } else { instance.featureValues = - Arrays.stream(gbtParams.featuresCols) + Arrays.stream(strategy.featuresCols) .mapToInt(col -> ((Number) row.getFieldAs(col)).intValue()) .toArray(); } @@ -171,7 +171,7 @@ public void onEpochWatermarkIncremented( TrainContext rawTrainContext = getter.get(SharedStorageConstants.TRAIN_CONTEXT); TrainContext trainContext = - new TrainContextInitializer(gbtParams) + new TrainContextInitializer(strategy) .init( rawTrainContext, getRuntimeContext().getIndexOfThisSubtask(), diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistBuilder.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistBuilder.java index ecaa2a1c9..255825b09 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistBuilder.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistBuilder.java @@ -20,12 +20,12 @@ import org.apache.flink.ml.common.gbt.DataUtils; import org.apache.flink.ml.common.gbt.defs.BinnedInstance; -import org.apache.flink.ml.common.gbt.defs.Distributor; import org.apache.flink.ml.common.gbt.defs.FeatureMeta; import org.apache.flink.ml.common.gbt.defs.Histogram; import org.apache.flink.ml.common.gbt.defs.LearningNode; import org.apache.flink.ml.common.gbt.defs.PredGradHess; import org.apache.flink.ml.common.gbt.defs.TrainContext; +import org.apache.flink.ml.util.Distributor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -68,7 +68,7 @@ public HistBuilder(TrainContext trainContext) { featureRandomizer = trainContext.featureRandomizer; featureIndicesPool = IntStream.range(0, trainContext.numFeatures).toArray(); - isInputVector = trainContext.params.isInputVector; + isInputVector = trainContext.strategy.isInputVector; maxFeatureBins = 
Arrays.stream(numFeatureBins).max().orElse(0); totalNumFeatureBins = Arrays.stream(numFeatureBins).sum(); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/InstanceUpdater.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/InstanceUpdater.java index 2c157ad10..c8516575c 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/InstanceUpdater.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/InstanceUpdater.java @@ -43,7 +43,7 @@ class InstanceUpdater { public InstanceUpdater(TrainContext trainContext) { subtaskId = trainContext.subtaskId; loss = trainContext.loss; - stepSize = trainContext.params.stepSize; + stepSize = trainContext.strategy.stepSize; prior = trainContext.prior; } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/NodeSplitter.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/NodeSplitter.java index 15f7012b9..66e3b7d51 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/NodeSplitter.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/NodeSplitter.java @@ -44,8 +44,8 @@ class NodeSplitter { public NodeSplitter(TrainContext trainContext) { subtaskId = trainContext.subtaskId; featureMetas = trainContext.featureMetas; - maxLeaves = trainContext.params.maxNumLeaves; - maxDepth = trainContext.params.maxDepth; + maxLeaves = trainContext.strategy.maxNumLeaves; + maxDepth = trainContext.strategy.maxDepth; } private int partitionInstances( diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SplitFinder.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SplitFinder.java index da5ef22b2..9e850bc82 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SplitFinder.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SplitFinder.java @@ -18,7 +18,6 @@ package org.apache.flink.ml.common.gbt.operators; -import org.apache.flink.ml.common.gbt.defs.Distributor; import org.apache.flink.ml.common.gbt.defs.FeatureMeta; import org.apache.flink.ml.common.gbt.defs.Histogram; import org.apache.flink.ml.common.gbt.defs.LearningNode; @@ -29,6 +28,7 @@ import org.apache.flink.ml.common.gbt.splitter.CategoricalFeatureSplitter; import org.apache.flink.ml.common.gbt.splitter.ContinuousFeatureSplitter; import org.apache.flink.ml.common.gbt.splitter.HistogramFeatureSplitter; +import org.apache.flink.ml.util.Distributor; import org.apache.flink.util.Preconditions; import org.slf4j.Logger; @@ -57,12 +57,12 @@ public SplitFinder(TrainContext trainContext) { splitters[i] = FeatureMeta.Type.CATEGORICAL == featureMetas[i].type ? 
new CategoricalFeatureSplitter( - i, featureMetas[i], trainContext.params) + i, featureMetas[i], trainContext.strategy) : new ContinuousFeatureSplitter( - i, featureMetas[i], trainContext.params); + i, featureMetas[i], trainContext.strategy); } - maxDepth = trainContext.params.maxDepth; - maxNumLeaves = trainContext.params.maxNumLeaves; + maxDepth = trainContext.strategy.maxDepth; + maxNumLeaves = trainContext.strategy.maxNumLeaves; } public Splits calc( diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TerminationOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TerminationOperator.java index c724301dd..eed26b1f6 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TerminationOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TerminationOperator.java @@ -64,7 +64,7 @@ public void onEpochWatermarkIncremented( boolean terminated = getter.get(SharedStorageConstants.ALL_TREES).size() == getter.get(SharedStorageConstants.TRAIN_CONTEXT) - .params + .strategy .maxIter; // TODO: add validation error rate if (!terminated) { diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TrainContextInitializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TrainContextInitializer.java index e657eb64a..fff2c6c1b 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TrainContextInitializer.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TrainContextInitializer.java @@ -20,10 +20,9 @@ import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.ml.common.gbt.defs.BinnedInstance; -import org.apache.flink.ml.common.gbt.defs.GbtParams; +import org.apache.flink.ml.common.gbt.defs.BoostingStrategy; import org.apache.flink.ml.common.gbt.defs.TaskType; import org.apache.flink.ml.common.gbt.defs.TrainContext; -import org.apache.flink.ml.common.lossfunc.AbsoluteErrorLoss; import org.apache.flink.ml.common.lossfunc.LogLoss; import org.apache.flink.ml.common.lossfunc.LossFunc; import org.apache.flink.ml.common.lossfunc.SquaredErrorLoss; @@ -41,10 +40,10 @@ class TrainContextInitializer { private static final Logger LOG = LoggerFactory.getLogger(TrainContextInitializer.class); - private final GbtParams params; + private final BoostingStrategy strategy; - public TrainContextInitializer(GbtParams params) { - this.params = params; + public TrainContextInitializer(BoostingStrategy strategy) { + this.strategy = strategy; } /** @@ -68,22 +67,22 @@ public TrainContext init( LOG.info( "subtaskId: {}, #samples: {}, #features: {}", subtaskId, numInstances, numFeatures); - trainContext.params = params; + trainContext.strategy = strategy; trainContext.numInstances = numInstances; trainContext.numFeatures = numFeatures; trainContext.numBaggingInstances = getNumBaggingSamples(numInstances); trainContext.numBaggingFeatures = getNumBaggingFeatures(numFeatures); - trainContext.instanceRandomizer = new Random(subtaskId + params.seed); - trainContext.featureRandomizer = new Random(params.seed); + trainContext.instanceRandomizer = new Random(subtaskId + strategy.seed); + trainContext.featureRandomizer = new Random(strategy.seed); trainContext.loss = getLoss(); trainContext.prior = calcPrior(trainContext.labelSumCount); trainContext.numFeatureBins = stream(trainContext.featureMetas) - .mapToInt(d -> d.numBins(trainContext.params.useMissing)) + .mapToInt(d -> 
d.numBins(trainContext.strategy.useMissing)) .toArray(); LOG.info("subtaskId: {}, {} end", subtaskId, TrainContextInitializer.class.getSimpleName()); @@ -91,7 +90,7 @@ public TrainContext init( } private int getNumBaggingSamples(int numSamples) { - return (int) Math.min(numSamples, Math.ceil(numSamples * params.subsamplingRate)); + return (int) Math.min(numSamples, Math.ceil(numSamples * strategy.subsamplingRate)); } private int getNumBaggingFeatures(int numFeatures) { @@ -102,24 +101,24 @@ private int getNumBaggingFeatures(int numFeatures) { String.join(", ", supported)); final Function clamp = d -> Math.max(1, Math.min(d.intValue(), numFeatures)); - String strategy = params.featureSubsetStrategy; + String featureSubsetStrategy = strategy.featureSubsetStrategy; try { - int numBaggingFeatures = Integer.parseInt(strategy); + int numBaggingFeatures = Integer.parseInt(featureSubsetStrategy); Preconditions.checkArgument( numBaggingFeatures >= 1 && numBaggingFeatures <= numFeatures, errorMsg); } catch (NumberFormatException ignored) { } try { - double baggingRatio = Double.parseDouble(strategy); + double baggingRatio = Double.parseDouble(featureSubsetStrategy); Preconditions.checkArgument(baggingRatio > 0. && baggingRatio <= 1., errorMsg); return clamp.apply(baggingRatio * numFeatures); } catch (NumberFormatException ignored) { } - Preconditions.checkArgument(supported.contains(strategy), errorMsg); - switch (strategy) { + Preconditions.checkArgument(supported.contains(featureSubsetStrategy), errorMsg); + switch (featureSubsetStrategy) { case "auto": - return TaskType.CLASSIFICATION.equals(params.taskType) + return TaskType.CLASSIFICATION.equals(strategy.taskType) ? clamp.apply(Math.sqrt(numFeatures)) : clamp.apply(numFeatures / 3.); case "all": @@ -136,28 +135,22 @@ private int getNumBaggingFeatures(int numFeatures) { } private LossFunc getLoss() { - String lossType = params.lossType; - switch (lossType) { - case "logistic": + switch (strategy.lossType) { + case LOGISTIC: return LogLoss.INSTANCE; - case "squared": + case SQUARED: return SquaredErrorLoss.INSTANCE; - case "absolute": - return AbsoluteErrorLoss.INSTANCE; default: throw new UnsupportedOperationException("Unsupported loss."); } } private double calcPrior(Tuple2 labelStat) { - String lossType = params.lossType; - switch (lossType) { - case "logistic": + switch (strategy.lossType) { + case LOGISTIC: return Math.log(labelStat.f0 / (labelStat.f1 - labelStat.f0)); - case "squared": + case SQUARED: return labelStat.f0 / labelStat.f1; - case "absolute": - throw new UnsupportedOperationException("absolute error is not supported yet."); default: throw new UnsupportedOperationException("Unsupported loss."); } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/CategoricalFeatureSplitter.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/CategoricalFeatureSplitter.java index db2b07d55..0006b4e0d 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/CategoricalFeatureSplitter.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/CategoricalFeatureSplitter.java @@ -19,8 +19,8 @@ package org.apache.flink.ml.common.gbt.splitter; import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.ml.common.gbt.defs.BoostingStrategy; import org.apache.flink.ml.common.gbt.defs.FeatureMeta; -import org.apache.flink.ml.common.gbt.defs.GbtParams; import org.apache.flink.ml.common.gbt.defs.HessianImpurity; import 
org.apache.flink.ml.common.gbt.defs.Split; @@ -35,8 +35,9 @@ /** Splitter for a categorical feature using LightGBM many-vs-many split. */ public class CategoricalFeatureSplitter extends HistogramFeatureSplitter { - public CategoricalFeatureSplitter(int featureId, FeatureMeta featureMeta, GbtParams params) { - super(featureId, featureMeta, params); + public CategoricalFeatureSplitter( + int featureId, FeatureMeta featureMeta, BoostingStrategy strategy) { + super(featureId, featureMeta, strategy); } @Override diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/ContinuousFeatureSplitter.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/ContinuousFeatureSplitter.java index 924c3ad4f..bee536691 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/ContinuousFeatureSplitter.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/ContinuousFeatureSplitter.java @@ -19,16 +19,17 @@ package org.apache.flink.ml.common.gbt.splitter; import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.ml.common.gbt.defs.BoostingStrategy; import org.apache.flink.ml.common.gbt.defs.FeatureMeta; -import org.apache.flink.ml.common.gbt.defs.GbtParams; import org.apache.flink.ml.common.gbt.defs.HessianImpurity; import org.apache.flink.ml.common.gbt.defs.Split; /** Splitter for a continuous feature. */ public final class ContinuousFeatureSplitter extends HistogramFeatureSplitter { - public ContinuousFeatureSplitter(int featureId, FeatureMeta featureMeta, GbtParams params) { - super(featureId, featureMeta, params); + public ContinuousFeatureSplitter( + int featureId, FeatureMeta featureMeta, BoostingStrategy strategy) { + super(featureId, featureMeta, strategy); } @Override @@ -60,7 +61,7 @@ public Split.ContinuousSplit bestSplit() { missingGoLeft, total.prediction(), splitPoint, - !params.isInputVector, + !strategy.isInputVector, ((FeatureMeta.ContinuousFeatureMeta) featureMeta).zeroBin); } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/FeatureSplitter.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/FeatureSplitter.java index b9fccf037..26936e0a6 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/FeatureSplitter.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/FeatureSplitter.java @@ -18,8 +18,8 @@ package org.apache.flink.ml.common.gbt.splitter; +import org.apache.flink.ml.common.gbt.defs.BoostingStrategy; import org.apache.flink.ml.common.gbt.defs.FeatureMeta; -import org.apache.flink.ml.common.gbt.defs.GbtParams; import org.apache.flink.ml.common.gbt.defs.Split; /** @@ -32,20 +32,21 @@ public abstract class FeatureSplitter { protected final int featureId; protected final FeatureMeta featureMeta; - protected final GbtParams params; + protected final BoostingStrategy strategy; protected final int minSamplesPerLeaf; protected final double minSampleRatioPerChild; protected final double minInfoGain; - public FeatureSplitter(int featureId, FeatureMeta featureMeta, GbtParams params) { - this.params = params; + public FeatureSplitter(int featureId, FeatureMeta featureMeta, BoostingStrategy strategy) { + this.strategy = strategy; this.featureId = featureId; this.featureMeta = featureMeta; - this.minSamplesPerLeaf = params.minInstancesPerNode; - this.minSampleRatioPerChild = params.minWeightFractionPerNode; // TODO: not exactly the same - this.minInfoGain = params.minInfoGain; + 
this.minSamplesPerLeaf = strategy.minInstancesPerNode; + this.minSampleRatioPerChild = + strategy.minWeightFractionPerNode; // TODO: not exactly the same + this.minInfoGain = strategy.minInfoGain; } public abstract Split bestSplit(); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/HistogramFeatureSplitter.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/HistogramFeatureSplitter.java index 3774d4b86..8aa3a5f69 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/HistogramFeatureSplitter.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/HistogramFeatureSplitter.java @@ -20,8 +20,8 @@ import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.ml.common.gbt.defs.BoostingStrategy; import org.apache.flink.ml.common.gbt.defs.FeatureMeta; -import org.apache.flink.ml.common.gbt.defs.GbtParams; import org.apache.flink.ml.common.gbt.defs.HessianImpurity; import org.apache.flink.ml.common.gbt.defs.Impurity; import org.apache.flink.ml.common.gbt.defs.Slice; @@ -35,9 +35,10 @@ public abstract class HistogramFeatureSplitter extends FeatureSplitter { protected double[] hists; protected Slice slice; - public HistogramFeatureSplitter(int featureId, FeatureMeta featureMeta, GbtParams params) { - super(featureId, featureMeta, params); - this.useMissing = params.useMissing; + public HistogramFeatureSplitter( + int featureId, FeatureMeta featureMeta, BoostingStrategy strategy) { + super(featureId, featureMeta, strategy); + this.useMissing = strategy.useMissing; } protected boolean isSplitIllegal(Impurity total, Impurity left, Impurity right) { @@ -184,6 +185,6 @@ protected void countTotalMissing(HessianImpurity total, HessianImpurity missing) } protected HessianImpurity emptyImpurity() { - return new HessianImpurity(params.lambda, params.gamma, 0, 0, 0, 0); + return new HessianImpurity(strategy.regLambda, strategy.regGamma, 0, 0, 0, 0); } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/AbsoluteErrorLoss.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/AbsoluteErrorLoss.java deleted file mode 100644 index a79c3e575..000000000 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/AbsoluteErrorLoss.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.ml.common.lossfunc; - -import org.apache.flink.ml.common.feature.LabeledPointWithWeight; -import org.apache.flink.ml.linalg.DenseVector; - -/** - * Absolute error loss function defined as |y - pred| where y and pred are label and predictions for - * the instance respectively. 
- */ -public class AbsoluteErrorLoss implements LossFunc { - - public static final AbsoluteErrorLoss INSTANCE = new AbsoluteErrorLoss(); - - private AbsoluteErrorLoss() {} - - @Override - public double loss(double pred, double label) { - double error = label - pred; - return Math.abs(error); - } - - @Override - public double gradient(double pred, double label) { - return label > pred ? -1. : 1; - } - - @Override - public double hessian(double pred, double y) { - return 0.; - } - - @Override - public double computeLoss(LabeledPointWithWeight dataPoint, DenseVector coefficient) { - throw new UnsupportedOperationException(); - } - - @Override - public void computeGradient( - LabeledPointWithWeight dataPoint, DenseVector coefficient, DenseVector cumGradient) { - throw new UnsupportedOperationException(); - } -} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasFeatureSubsetStrategy.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasFeatureSubsetStrategy.java deleted file mode 100644 index 77bee2c3a..000000000 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasFeatureSubsetStrategy.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.ml.common.param; - -import org.apache.flink.ml.param.Param; -import org.apache.flink.ml.param.ParamValidators; -import org.apache.flink.ml.param.StringParam; -import org.apache.flink.ml.param.WithParams; - -/** Interface for shared param feature subset strategy. */ -public interface HasFeatureSubsetStrategy extends WithParams { - Param FEATURE_SUBSET_STRATEGY = - new StringParam( - "featureSubsetStrategy.", - "Fraction of the training data used for learning one tree. Supports \"auto\", \"all\", \"onethird\", \"sqrt\", \"log2\", (0.0 - 1.0], and [1 - n].", - "auto", - ParamValidators.notNull()); - - default String getFeatureSubsetStrategy() { - return get(FEATURE_SUBSET_STRATEGY); - } - - default T setFeatureSubsetStrategy(String value) { - return set(FEATURE_SUBSET_STRATEGY, value); - } -} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMaxBins.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMaxBins.java deleted file mode 100644 index 45042c903..000000000 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMaxBins.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.ml.common.param; - -import org.apache.flink.ml.param.IntParam; -import org.apache.flink.ml.param.Param; -import org.apache.flink.ml.param.ParamValidators; -import org.apache.flink.ml.param.WithParams; - -/** Interface for the shared maxBins param. */ -public interface HasMaxBins extends WithParams { - Param MAX_BINS = - new IntParam( - "maxBins", - "Maximum number of bins used for discretizing continuous features.", - 32, - ParamValidators.gtEq(2)); - - default int getMaxBins() { - return get(MAX_BINS); - } - - default T setMaxBins(int value) { - return set(MAX_BINS, value); - } -} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMaxDepth.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMaxDepth.java deleted file mode 100644 index 68a746f4e..000000000 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMaxDepth.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.ml.common.param; - -import org.apache.flink.ml.param.IntParam; -import org.apache.flink.ml.param.Param; -import org.apache.flink.ml.param.ParamValidators; -import org.apache.flink.ml.param.WithParams; - -/** Interface for the shared maxDepth param. */ -public interface HasMaxDepth extends WithParams { - Param MAX_DEPTH = - new IntParam("maxDepth", "Maximum depth of the tree.", 5, ParamValidators.gtEq(1)); - - default int getMaxDepth() { - return get(MAX_DEPTH); - } - - default T setMaxDepth(int value) { - return set(MAX_DEPTH, value); - } -} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMinInfoGain.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMinInfoGain.java deleted file mode 100644 index cbb5c4c08..000000000 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMinInfoGain.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.ml.common.param; - -import org.apache.flink.ml.param.DoubleParam; -import org.apache.flink.ml.param.Param; -import org.apache.flink.ml.param.ParamValidators; -import org.apache.flink.ml.param.WithParams; - -/** Interface for shared param minInfoGain. */ -public interface HasMinInfoGain extends WithParams { - Param MIN_INFO_GAIN = - new DoubleParam( - "minInfoGain", - "Minimum information gain for a split to be considered valid.", - 0., - ParamValidators.gtEq(0.)); - - default double getMinInfoGain() { - return get(MIN_INFO_GAIN); - } - - default T setMinInfoGain(Double value) { - return set(MIN_INFO_GAIN, value); - } -} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMinInstancesPerNode.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMinInstancesPerNode.java deleted file mode 100644 index 91cf8ab8d..000000000 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMinInstancesPerNode.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.ml.common.param; - -import org.apache.flink.ml.param.IntParam; -import org.apache.flink.ml.param.Param; -import org.apache.flink.ml.param.ParamValidators; -import org.apache.flink.ml.param.WithParams; - -/** Interface for the shared minInstancesPerNode param. */ -public interface HasMinInstancesPerNode extends WithParams { - Param MIN_INSTANCES_PER_NODE = - new IntParam( - "minInstancesPerNode", - "Minimum number of instances each node must have. 
If a split causes the left or right child to have fewer instances than minInstancesPerNode, the split is invalid.", - 1, - ParamValidators.gtEq(1)); - - default int getMinInstancesPerNode() { - return get(MIN_INSTANCES_PER_NODE); - } - - default T setMinInstancesPerNode(int value) { - return set(MIN_INSTANCES_PER_NODE, value); - } -} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMinWeightFractionPerNode.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMinWeightFractionPerNode.java deleted file mode 100644 index c8fbaa3ae..000000000 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMinWeightFractionPerNode.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.ml.common.param; - -import org.apache.flink.ml.param.DoubleParam; -import org.apache.flink.ml.param.Param; -import org.apache.flink.ml.param.ParamValidators; -import org.apache.flink.ml.param.WithParams; - -/** Interface for shared param minWeightFractionPerNode. */ -public interface HasMinWeightFractionPerNode extends WithParams { - Param MIN_WEIGHT_FRACTION_PER_NODE = - new DoubleParam( - "minWeightFractionPerNode", - "Minimum fraction of the weighted sample count that each node must have. If a split causes the left or right child to have a smaller fraction of the total weight than minWeightFractionPerNode, the split is invalid.", - 0., - ParamValidators.gtEq(0.)); - - default double getMinWeightFractionPerNode() { - return get(MIN_WEIGHT_FRACTION_PER_NODE); - } - - default T setMinWeightFractionPerNode(Double value) { - return set(MIN_WEIGHT_FRACTION_PER_NODE, value); - } -} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasStepSize.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasStepSize.java deleted file mode 100644 index f0faa2edd..000000000 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasStepSize.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.ml.common.param; - -import org.apache.flink.ml.param.DoubleParam; -import org.apache.flink.ml.param.Param; -import org.apache.flink.ml.param.ParamValidators; -import org.apache.flink.ml.param.WithParams; - -/** Interface for the shared step size param. */ -public interface HasStepSize extends WithParams { - Param STEP_SIZE = - new DoubleParam( - "stepSize", - "Step size for shrinking the contribution of each estimator.", - 0.1, - ParamValidators.inRange(0., 1.)); - - default double getStepSize() { - return get(STEP_SIZE); - } - - default T setStepSize(Double value) { - return set(STEP_SIZE, value); - } -} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasSubsamplingRate.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasSubsamplingRate.java deleted file mode 100644 index 3f04d6282..000000000 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasSubsamplingRate.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.ml.common.param; - -import org.apache.flink.ml.param.DoubleParam; -import org.apache.flink.ml.param.Param; -import org.apache.flink.ml.param.ParamValidators; -import org.apache.flink.ml.param.WithParams; - -/** Interface for shared param subsampling rate. */ -public interface HasSubsamplingRate extends WithParams { - Param SUBSAMPLING_RATE = - new DoubleParam( - "subsamplingRate", - "Fraction of the training data used for learning one tree.", - 1., - ParamValidators.inRange(0., 1.)); - - default double getSubsamplingRate() { - return get(SUBSAMPLING_RATE); - } - - default T setSubsamplingRate(Double value) { - return set(SUBSAMPLING_RATE, value); - } -} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasValidationIndicatorCol.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasValidationIndicatorCol.java deleted file mode 100644 index 5e076474e..000000000 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasValidationIndicatorCol.java +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.ml.common.param; - -import org.apache.flink.ml.param.Param; -import org.apache.flink.ml.param.StringParam; -import org.apache.flink.ml.param.WithParams; - -/** Interface for the shared validation indicate column param. */ -public interface HasValidationIndicatorCol extends WithParams { - Param VALIDATION_INDICATOR_COL = - new StringParam( - "validationIndicatorCol", - "The name of the column that indicates whether each row is for training or for validation.", - null); - - default String getValidationIndicatorCol() { - return get(VALIDATION_INDICATOR_COL); - } - - default T setValidationIndicatorCol(String value) { - return set(VALIDATION_INDICATOR_COL, value); - } -} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasValidationTol.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasValidationTol.java deleted file mode 100644 index d50d958d6..000000000 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasValidationTol.java +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.ml.common.param; - -import org.apache.flink.ml.param.DoubleParam; -import org.apache.flink.ml.param.Param; -import org.apache.flink.ml.param.ParamValidators; -import org.apache.flink.ml.param.WithParams; - -/** Interface for the shared tolerance param. */ -public interface HasValidationTol extends WithParams { - - Param VALIDATION_TOL = - new DoubleParam( - "validationTol", - "Threshold for early stopping when fitting with validation is used.", - .01, - ParamValidators.gtEq(0)); - - default double getValidationTol() { - return get(VALIDATION_TOL); - } - - default T setValidationTol(Double value) { - return set(VALIDATION_TOL, value); - } -} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressor.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressor.java index d9435bdaf..92b7f783b 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressor.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressor.java @@ -60,7 +60,7 @@ public GBTRegressorModel fit(Table... 
inputs) { Preconditions.checkArgument(inputs.length == 1); StreamTableEnvironment tEnv = (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); - DataStream modelData = GBTRunner.trainRegressor(inputs[0], this); + DataStream modelData = GBTRunner.train(inputs[0], this); GBTRegressorModel model = new GBTRegressorModel(); model.setModelData(tEnv.fromDataStream(modelData).renameColumns($("f0").as("modelData"))); ReadWriteUtils.updateExistingParams(model, getParamMap()); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressorParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressorParams.java index 184cf158a..ab21aee8b 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressorParams.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressorParams.java @@ -31,10 +31,7 @@ public interface GBTRegressorParams extends BaseGBTParams, GBTRegressorModelParams { Param LOSS_TYPE = new StringParam( - "lossType", - "Loss type.", - "squared", - ParamValidators.inArray("squared", "absolute")); + "lossType", "Loss type.", "squared", ParamValidators.inArray("squared")); default String getLossType() { return get(LOSS_TYPE); diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/GBTRunnerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/GBTRunnerTest.java index 0df47c61d..d58357e7c 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/GBTRunnerTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/GBTRunnerTest.java @@ -23,7 +23,8 @@ import org.apache.flink.api.common.typeinfo.Types; import org.apache.flink.api.java.typeutils.RowTypeInfo; import org.apache.flink.configuration.Configuration; -import org.apache.flink.ml.common.gbt.defs.GbtParams; +import org.apache.flink.ml.common.gbt.defs.BoostingStrategy; +import org.apache.flink.ml.common.gbt.defs.LossType; import org.apache.flink.ml.common.gbt.defs.TaskType; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; @@ -91,23 +92,23 @@ public void before() { }))); } - private GbtParams getCommonGbtParams() { - GbtParams p = new GbtParams(); - p.featuresCols = new String[] {"f0", "f1", "f2"}; - p.categoricalCols = new String[] {"f2"}; - p.isInputVector = false; - p.gamma = 0.; - p.maxBins = 3; - p.seed = 123; - p.featureSubsetStrategy = "all"; - p.maxDepth = 3; - p.maxNumLeaves = 1 << (p.maxDepth - 1); - p.maxIter = 20; - p.stepSize = 0.1; - return p; + private BoostingStrategy getCommonStrategy() { + BoostingStrategy strategy = new BoostingStrategy(); + strategy.featuresCols = new String[] {"f0", "f1", "f2"}; + strategy.categoricalCols = new String[] {"f2"}; + strategy.isInputVector = false; + strategy.regGamma = 0.; + strategy.maxBins = 3; + strategy.seed = 123; + strategy.featureSubsetStrategy = "all"; + strategy.maxDepth = 3; + strategy.maxNumLeaves = 1 << (strategy.maxDepth - 1); + strategy.maxIter = 20; + strategy.stepSize = 0.1; + return strategy; } - private void verifyModelData(GBTModelData modelData, GbtParams p) { + private void verifyModelData(GBTModelData modelData, BoostingStrategy p) { Assert.assertEquals(p.taskType, TaskType.valueOf(modelData.type)); Assert.assertEquals(p.stepSize, modelData.stepSize, 1e-12); Assert.assertEquals(p.maxIter, modelData.allTrees.size()); @@ -115,25 +116,25 @@ private void verifyModelData(GBTModelData modelData, GbtParams p) { @Test 
public void testTrainClassifier() throws Exception { - GbtParams p = getCommonGbtParams(); - p.taskType = TaskType.CLASSIFICATION; - p.labelCol = "cls_label"; - p.lossType = "logistic"; - p.useMissing = true; + BoostingStrategy strategy = getCommonStrategy(); + strategy.taskType = TaskType.CLASSIFICATION; + strategy.labelCol = "cls_label"; + strategy.lossType = LossType.LOGISTIC; + strategy.useMissing = true; - GBTModelData modelData = GBTRunner.train(inputTable, p).executeAndCollect().next(); - verifyModelData(modelData, p); + GBTModelData modelData = GBTRunner.train(inputTable, strategy).executeAndCollect().next(); + verifyModelData(modelData, strategy); } @Test public void testTrainRegressor() throws Exception { - GbtParams p = getCommonGbtParams(); - p.taskType = TaskType.REGRESSION; - p.labelCol = "label"; - p.lossType = "squared"; - p.useMissing = true; + BoostingStrategy strategy = getCommonStrategy(); + strategy.taskType = TaskType.REGRESSION; + strategy.labelCol = "label"; + strategy.lossType = LossType.SQUARED; + strategy.useMissing = true; - GBTModelData modelData = GBTRunner.train(inputTable, p).executeAndCollect().next(); - verifyModelData(modelData, p); + GBTModelData modelData = GBTRunner.train(inputTable, strategy).executeAndCollect().next(); + verifyModelData(modelData, strategy); } } diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/PreprocessTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/PreprocessTest.java index cc99b9627..be4a76d10 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/PreprocessTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/PreprocessTest.java @@ -22,8 +22,8 @@ import org.apache.flink.api.common.typeinfo.Types; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.common.gbt.defs.BoostingStrategy; import org.apache.flink.ml.common.gbt.defs.FeatureMeta; -import org.apache.flink.ml.common.gbt.defs.GbtParams; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; import org.apache.flink.ml.util.TestUtils; @@ -127,13 +127,14 @@ public void invoke(T value, Context context) { @Test public void testPreprocessCols() throws Exception { - GbtParams p = new GbtParams(); - p.isInputVector = false; - p.featuresCols = new String[] {"f0", "f1", "f2"}; - p.categoricalCols = new String[] {"f2"}; - p.labelCol = "label"; - p.maxBins = 3; - Tuple2> results = Preprocess.preprocessCols(inputTable, p); + BoostingStrategy strategy = new BoostingStrategy(); + strategy.isInputVector = false; + strategy.featuresCols = new String[] {"f0", "f1", "f2"}; + strategy.categoricalCols = new String[] {"f2"}; + strategy.labelCol = "label"; + strategy.maxBins = 3; + Tuple2> results = + Preprocess.preprocessCols(inputTable, strategy); actualMeta.get().clear(); results.f1.addSink(new CollectSink<>(actualMeta)); @@ -174,12 +175,13 @@ public void testPreprocessCols() throws Exception { @Test public void testPreprocessVectorCol() throws Exception { - GbtParams p = new GbtParams(); - p.isInputVector = true; - p.featuresCols = new String[] {"vec"}; - p.labelCol = "label"; - p.maxBins = 3; - Tuple2> results = Preprocess.preprocessVecCol(inputTable, p); + BoostingStrategy strategy = new BoostingStrategy(); + strategy.isInputVector = true; + strategy.featuresCols = new String[] {"vec"}; + strategy.labelCol = "label"; + strategy.maxBins = 3; + Tuple2> results = + 
Preprocess.preprocessVecCol(inputTable, strategy); actualMeta.get().clear(); results.f1.addSink(new CollectSink<>(actualMeta)); From f236703b26112ea2f869fc4054d67c588f57a335 Mon Sep 17 00:00:00 2001 From: Fan Hong Date: Wed, 1 Mar 2023 11:39:03 +0800 Subject: [PATCH 23/47] Remove GBTRunnerTest. --- .../flink/ml/common/gbt/GBTRunnerTest.java | 140 ------------------ 1 file changed, 140 deletions(-) delete mode 100644 flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/GBTRunnerTest.java diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/GBTRunnerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/GBTRunnerTest.java deleted file mode 100644 index d58357e7c..000000000 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/GBTRunnerTest.java +++ /dev/null @@ -1,140 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.ml.common.gbt; - -import org.apache.flink.api.common.restartstrategy.RestartStrategies; -import org.apache.flink.api.common.typeinfo.TypeInformation; -import org.apache.flink.api.common.typeinfo.Types; -import org.apache.flink.api.java.typeutils.RowTypeInfo; -import org.apache.flink.configuration.Configuration; -import org.apache.flink.ml.common.gbt.defs.BoostingStrategy; -import org.apache.flink.ml.common.gbt.defs.LossType; -import org.apache.flink.ml.common.gbt.defs.TaskType; -import org.apache.flink.ml.linalg.Vectors; -import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; -import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; -import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; -import org.apache.flink.table.api.Table; -import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; -import org.apache.flink.test.util.AbstractTestBase; -import org.apache.flink.types.Row; - -import org.junit.Assert; -import org.junit.Before; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; - -import java.util.Arrays; -import java.util.List; - -/** Tests {@link GBTRunner}. 
*/ -public class GBTRunnerTest extends AbstractTestBase { - private static final List inputDataRows = - Arrays.asList( - Row.of(1.2, 2, null, 40., 1., 0., Vectors.dense(1.2, 2, Double.NaN)), - Row.of(2.3, 3, "b", 40., 2., 0., Vectors.dense(2.3, 3, 2.)), - Row.of(3.4, 4, "c", 40., 3., 0., Vectors.dense(3.4, 4, 3.)), - Row.of(4.5, 5, "a", 40., 4., 0., Vectors.dense(4.5, 5, 1.)), - Row.of(5.6, 2, "b", 40., 5., 0., Vectors.dense(5.6, 2, 2.)), - Row.of(null, 3, "c", 41., 1., 1., Vectors.dense(Double.NaN, 3, 3.)), - Row.of(12.8, 4, "e", 41., 2., 1., Vectors.dense(12.8, 4, 5.)), - Row.of(13.9, 2, "b", 41., 3., 1., Vectors.dense(13.9, 2, 2.)), - Row.of(14.1, 4, "a", 41., 4., 1., Vectors.dense(14.1, 4, 1.)), - Row.of(15.3, 1, "d", 41., 5., 1., Vectors.dense(15.3, 1, 4.))); - - @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); - private Table inputTable; - - @Before - public void before() { - Configuration config = new Configuration(); - config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); - StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config); - env.getConfig().enableObjectReuse(); - env.setParallelism(4); - env.enableCheckpointing(100); - env.setRestartStrategy(RestartStrategies.noRestart()); - StreamTableEnvironment tEnv = StreamTableEnvironment.create(env); - - inputTable = - tEnv.fromDataStream( - env.fromCollection( - inputDataRows, - new RowTypeInfo( - new TypeInformation[] { - Types.DOUBLE, - Types.INT, - Types.STRING, - Types.DOUBLE, - Types.DOUBLE, - Types.DOUBLE, - VectorTypeInfo.INSTANCE - }, - new String[] { - "f0", "f1", "f2", "label", "weight", "cls_label", "vec" - }))); - } - - private BoostingStrategy getCommonStrategy() { - BoostingStrategy strategy = new BoostingStrategy(); - strategy.featuresCols = new String[] {"f0", "f1", "f2"}; - strategy.categoricalCols = new String[] {"f2"}; - strategy.isInputVector = false; - strategy.regGamma = 0.; - strategy.maxBins = 3; - strategy.seed = 123; - strategy.featureSubsetStrategy = "all"; - strategy.maxDepth = 3; - strategy.maxNumLeaves = 1 << (strategy.maxDepth - 1); - strategy.maxIter = 20; - strategy.stepSize = 0.1; - return strategy; - } - - private void verifyModelData(GBTModelData modelData, BoostingStrategy p) { - Assert.assertEquals(p.taskType, TaskType.valueOf(modelData.type)); - Assert.assertEquals(p.stepSize, modelData.stepSize, 1e-12); - Assert.assertEquals(p.maxIter, modelData.allTrees.size()); - } - - @Test - public void testTrainClassifier() throws Exception { - BoostingStrategy strategy = getCommonStrategy(); - strategy.taskType = TaskType.CLASSIFICATION; - strategy.labelCol = "cls_label"; - strategy.lossType = LossType.LOGISTIC; - strategy.useMissing = true; - - GBTModelData modelData = GBTRunner.train(inputTable, strategy).executeAndCollect().next(); - verifyModelData(modelData, strategy); - } - - @Test - public void testTrainRegressor() throws Exception { - BoostingStrategy strategy = getCommonStrategy(); - strategy.taskType = TaskType.REGRESSION; - strategy.labelCol = "label"; - strategy.lossType = LossType.SQUARED; - strategy.useMissing = true; - - GBTModelData modelData = GBTRunner.train(inputTable, strategy).executeAndCollect().next(); - verifyModelData(modelData, strategy); - } -} From b21a23b1bd762b953e3d0650dcf8168a3b061c58 Mon Sep 17 00:00:00 2001 From: Fan Hong Date: Wed, 1 Mar 2023 12:57:03 +0800 Subject: [PATCH 24/47] Only call ListStateWithCache#update just before snapshot. 
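Writing through to the cached list state on every set() call serializes the value even when no checkpoint is imminent. With a dirty flag, the value is flushed at most once per checkpoint, right before the snapshot is taken. A minimal sketch of the pattern (the class and field names below are illustrative only, not part of this patch):

import org.apache.flink.iteration.datacache.nonkeyed.ListStateWithCache;
import org.apache.flink.runtime.state.StateSnapshotContext;

import java.util.Collections;

/** Illustrative sketch: defer the expensive cache update to snapshot time. */
class DeferredFlushValue<T> {
    private final ListStateWithCache<T> cache;
    private T value;
    private boolean isDirty = false;

    DeferredFlushValue(ListStateWithCache<T> cache) {
        this.cache = cache;
    }

    void set(T value) {
        this.value = value;
        isDirty = true; // The hot path stays cheap; nothing is serialized here.
    }

    void snapshotState(StateSnapshotContext context) throws Exception {
        if (isDirty) {
            // Flush at most once per checkpoint, immediately before the snapshot.
            cache.update(Collections.singletonList(value));
            isDirty = false;
        }
        cache.snapshotState(context);
    }
}

Nothing extra is lost on failure: recovery restarts from the last completed snapshot, which already contains the latest flushed value.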
--- .../flink/ml/common/sharedstorage/SharedStorage.java | 12 +++++++----- .../ml/common/gbt/operators/PostSplitsOperator.java | 2 +- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorage.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorage.java index f056034bb..747bcd6cf 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorage.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorage.java @@ -110,6 +110,7 @@ T get() { static class Writer extends Reader { private final String ownerId; private final ListStateWithCache cache; + private boolean isDirty; Writer( Tuple3 t, @@ -138,6 +139,7 @@ static class Writer extends Reader { } catch (Exception e) { throw new RuntimeException(e); } + isDirty = false; } private void ensureOwner() { @@ -149,11 +151,7 @@ private void ensureOwner() { void set(T value) { ensureOwner(); m.put(t, value); - try { - cache.update(Collections.singletonList(value)); - } catch (Exception e) { - throw new RuntimeException(e); - } + isDirty = true; } void remove() { @@ -164,6 +162,10 @@ void remove() { } void snapshotState(StateSnapshotContext context) throws Exception { + if (isDirty) { + cache.update(Collections.singletonList(get())); + isDirty = false; + } cache.snapshotState(context); } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/PostSplitsOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/PostSplitsOperator.java index 1668a87a2..6aa9a4968 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/PostSplitsOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/PostSplitsOperator.java @@ -110,6 +110,7 @@ public void initializeState(StateInitializationContext context) throws Exception @Override public void snapshotState(StateSnapshotContext context) throws Exception { super.snapshotState(context); + splitsState.update(Collections.singletonList(splits)); splitsState.snapshotState(context); nodeSplitterState.snapshotState(context); instanceUpdaterState.snapshotState(context); @@ -204,7 +205,6 @@ public void onIterationTerminated(Context context, Collector collector) @Override public void processElement(StreamRecord element) throws Exception { splits = element.getValue(); - splitsState.update(Collections.singletonList(splits)); } @Override From 768820adb85be4550de0f7cc7cab1e04b8cf5336 Mon Sep 17 00:00:00 2001 From: Fan Hong Date: Wed, 1 Mar 2023 13:00:27 +0800 Subject: [PATCH 25/47] Refine some TODOs. --- .../java/org/apache/flink/ml/common/gbt/GBTModelData.java | 1 - .../flink/ml/common/gbt/operators/TerminationOperator.java | 2 +- .../apache/flink/ml/common/gbt/splitter/FeatureSplitter.java | 4 ++-- .../java/org/apache/flink/ml/common/gbt/PreprocessTest.java | 4 ++-- 4 files changed, 5 insertions(+), 6 deletions(-) diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelData.java index 8b93b0272..165d2f9f4 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelData.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelData.java @@ -74,7 +74,6 @@ public class GBTModelData { public GBTModelData() {} - // TODO: !!! 
public GBTModelData( String type, boolean isInputVector, diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TerminationOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TerminationOperator.java index eed26b1f6..30157676b 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TerminationOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TerminationOperator.java @@ -66,7 +66,7 @@ public void onEpochWatermarkIncremented( == getter.get(SharedStorageConstants.TRAIN_CONTEXT) .strategy .maxIter; - // TODO: add validation error rate + // TODO: Add validation error rate if (!terminated) { output.collect(new StreamRecord<>(0)); } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/FeatureSplitter.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/FeatureSplitter.java index 26936e0a6..702bd2189 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/FeatureSplitter.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/FeatureSplitter.java @@ -44,8 +44,8 @@ public FeatureSplitter(int featureId, FeatureMeta featureMeta, BoostingStrategy this.featureMeta = featureMeta; this.minSamplesPerLeaf = strategy.minInstancesPerNode; - this.minSampleRatioPerChild = - strategy.minWeightFractionPerNode; // TODO: not exactly the same + // TODO: not exactly the same since weights are not supported right now. + this.minSampleRatioPerChild = strategy.minWeightFractionPerNode; this.minInfoGain = strategy.minInfoGain; } diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/PreprocessTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/PreprocessTest.java index be4a76d10..3a8ca1573 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/PreprocessTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/gbt/PreprocessTest.java @@ -142,7 +142,7 @@ public void testPreprocessCols() throws Exception { List preprocessedRows = IteratorUtils.toList(tEnv.toDataStream(results.f0).executeAndCollect()); - // TODO: correct `binEdges` of feature `f0` after FLINK-30734 resolved. + // TODO: Correct `binEdges` of feature `f0` after FLINK-30734 resolved. List expectedMeta = Arrays.asList( FeatureMeta.continuous("f0", 3, new double[] {1.2, 4.5, 13.9, 15.3}), @@ -189,7 +189,7 @@ public void testPreprocessVectorCol() throws Exception { List preprocessedRows = IteratorUtils.toList(tEnv.toDataStream(results.f0).executeAndCollect()); - // TODO: correct `binEdges` of feature `_vec_f0` and `_vec_f2` after FLINK-30734 resolved. + // TODO: Correct `binEdges` of feature `_vec_f0` and `_vec_f2` after FLINK-30734 resolved. List expectedMeta = Arrays.asList( FeatureMeta.continuous("_vec_f0", 3, new double[] {1.2, 4.5, 13.9, 15.3}), From 132cca8f58d6b446a367d30ee2ff0037870a3cfb Mon Sep 17 00:00:00 2001 From: Fan Hong Date: Wed, 1 Mar 2023 14:46:37 +0800 Subject: [PATCH 26/47] Improve javadoc for GBTClassifier and GBTRegressor. 
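For context, a typical use of the estimators documented here could look as follows. This is a hypothetical sketch: the input table and column names are made up, and the setters are assumed to be the usual Flink ML param setters.

import org.apache.flink.ml.classification.gbtclassifier.GBTClassifier;
import org.apache.flink.ml.classification.gbtclassifier.GBTClassifierModel;
import org.apache.flink.table.api.Table;

// Assume `trainData` has numeric feature columns f0 and f1, a categorical column f2,
// and a binary label column cls_label.
GBTClassifier classifier =
        new GBTClassifier()
                .setFeaturesCols("f0", "f1", "f2")
                .setCategoricalCols("f2")
                .setLabelCol("cls_label");
GBTClassifierModel model = classifier.fit(trainData);
Table predictions = model.transform(trainData)[0];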
--- .../gbtclassifier/GBTClassifier.java | 18 +++++++++++++++++- .../regression/gbtregressor/GBTRegressor.java | 18 +++++++++++++++++- 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifier.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifier.java index 980af398d..3d7f76ee3 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifier.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifier.java @@ -36,7 +36,23 @@ import static org.apache.flink.table.api.Expressions.$; -/** An Estimator which implements the gradient boosting trees classification algorithm. */ +/** + * An Estimator which implements the gradient boosting trees classification algorithm (Gradient Boosting). + * + *
The implementation is inspired by advanced implementations such as XGBoost and LightGBM. + * It supports a regularized learning objective with second-order approximation and a + * histogram-based, sparsity-aware split-finding algorithm. + * + *
The distributed implementation takes this work as a reference. Right now, we + * support horizontal partitioning of data and row-store storage of instances. + * + *
NOTE: Currently, some features are not supported yet: weighted input samples, early-stopping + * with validation set, encoding with leaf ids, etc. + */ public class GBTClassifier implements Estimator, GBTClassifierParams { diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressor.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressor.java index 92b7f783b..bbe828bce 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressor.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressor.java @@ -36,7 +36,23 @@ import static org.apache.flink.table.api.Expressions.$; -/** An Estimator which implements the gradient boosting trees regression algorithm. */ +/** + * An Estimator which implements the gradient boosting trees regression algorithm (Gradient Boosting). + * + *
The implementation is inspired by advanced implementations such as XGBoost and LightGBM. + * It supports a regularized learning objective with second-order approximation and a + * histogram-based, sparsity-aware split-finding algorithm. + * + *
The distributed implementation takes this work as a reference. Right now, we + * support horizontal partitioning of data and row-store storage of instances. + * + *
NOTE: Currently, some features are not supported yet: weighted input samples, early-stopping + * with validation set, encoding with leaf ids, etc. + */ public class GBTRegressor implements Estimator, GBTRegressorParams { From 546591ef518b219c86cfd8714b0715f7328313ad Mon Sep 17 00:00:00 2001 From: Fan Hong Date: Mon, 6 Mar 2023 16:23:40 +0800 Subject: [PATCH 27/47] Improve categorical feature splitter by ignoring less frequent categories. --- .../splitter/CategoricalFeatureSplitter.java | 51 ++++++++++++------- 1 file changed, 32 insertions(+), 19 deletions(-) diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/CategoricalFeatureSplitter.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/CategoricalFeatureSplitter.java index 0006b4e0d..ee0d5aa5e 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/CategoricalFeatureSplitter.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/splitter/CategoricalFeatureSplitter.java @@ -24,11 +24,10 @@ import org.apache.flink.ml.common.gbt.defs.HessianImpurity; import org.apache.flink.ml.common.gbt.defs.Split; -import org.apache.commons.lang3.ArrayUtils; +import org.eclipse.collections.api.list.primitive.MutableIntList; +import org.eclipse.collections.impl.list.mutable.primitive.IntArrayList; -import java.util.Arrays; import java.util.BitSet; -import java.util.Comparator; import static org.apache.flink.ml.common.gbt.DataUtils.BIN_SIZE; @@ -51,22 +50,35 @@ public Split.CategoricalSplit bestSplit() { } int numBins = slice.size(); - // Sorts categories based on grads / hessians, i.e., LightGBM many-vs-many approach. - Integer[] sortedCategories = new Integer[numBins]; + // Sorts categories (bins) based on grads / hessians, i.e., LightGBM many-vs-many approach. + MutableIntList sortedIndices = new IntArrayList(numBins); + // A category (bin) is treated as missing values if its occurrences is smaller than a + // threshold. Currently, the threshold is 0. + BitSet ignoredIndices = new BitSet(numBins); { double[] scores = new double[numBins]; for (int i = 0; i < numBins; ++i) { - sortedCategories[i] = i; - int startIndex = (slice.start + i) * BIN_SIZE; - scores[i] = hists[startIndex] / hists[startIndex + 1]; + int index = (slice.start + i) * BIN_SIZE; + if (hists[index + 3] > 0) { + sortedIndices.add(i); + scores[i] = hists[index] / hists[index + 1]; + } else { + ignoredIndices.set(i); + missing.add( + (int) hists[index + 3], + hists[index + 2], + hists[index], + hists[index + 1]); + } } - Arrays.sort(sortedCategories, Comparator.comparing(d -> scores[d])); + sortedIndices.sortThis( + (value1, value2) -> Double.compare(scores[value1], scores[value2])); } Tuple3 bestSplit = - findBestSplit(ArrayUtils.toPrimitive(sortedCategories), total, missing); + findBestSplit(sortedIndices.toArray(), total, missing); double bestGain = bestSplit.f0; - int bestSplitBinId = bestSplit.f1; + int bestSplitIndex = bestSplit.f1; boolean missingGoLeft = bestSplit.f2; if (bestGain <= Split.INVALID_GAIN || bestGain <= minInfoGain) { @@ -76,9 +88,9 @@ public Split.CategoricalSplit bestSplit() { // Indicates which bins should go left. 
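// Bins ordered before the best split point go left. The missing-value bin itself is
// never set in this bitset; whether missing values (together with the infrequent bins
// treated as missing above) go left is recorded separately via `missingGoLeft`.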
BitSet binsGoLeft = new BitSet(numBins); if (useMissing) { - for (int i = 0; i < numBins; ++i) { - int binId = sortedCategories[i]; - if (i <= bestSplitBinId) { + for (int i = 0; i < sortedIndices.size(); ++i) { + int binId = sortedIndices.get(i); + if (i <= bestSplitIndex) { if (binId < featureMeta.missingBin) { binsGoLeft.set(binId); } else if (binId > featureMeta.missingBin) { @@ -87,15 +99,16 @@ public Split.CategoricalSplit bestSplit() { } } } else { - int numCategories = - ((FeatureMeta.CategoricalFeatureMeta) featureMeta).categories.length; - for (int i = 0; i < numCategories; i += 1) { - int binId = sortedCategories[i]; - if (i <= bestSplitBinId) { + for (int i = 0; i < sortedIndices.size(); i += 1) { + int binId = sortedIndices.get(i); + if (i <= bestSplitIndex) { binsGoLeft.set(binId); } } } + if (missingGoLeft) { + binsGoLeft.or(ignoredIndices); + } return new Split.CategoricalSplit( featureId, bestGain, From 5562e8033a1400e5523bdd87dfbf5bf155de24ea Mon Sep 17 00:00:00 2001 From: Fan Hong Date: Tue, 7 Mar 2023 11:25:05 +0800 Subject: [PATCH 28/47] Change PredGradHess to double[]. --- .../ml/common/gbt/defs/PredGradHess.java | 40 ------ .../CacheDataCalcLocalHistsOperator.java | 13 +- .../ml/common/gbt/operators/HistBuilder.java | 13 +- .../common/gbt/operators/InstanceUpdater.java | 25 ++-- .../gbt/operators/PostSplitsOperator.java | 3 +- .../gbt/operators/SharedStorageConstants.java | 12 +- .../operators/TrainContextInitializer.java | 2 +- .../gbt/typeinfo/PredGradHessSerializer.java | 115 ------------------ 8 files changed, 30 insertions(+), 193 deletions(-) delete mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/PredGradHess.java delete mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/PredGradHessSerializer.java diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/PredGradHess.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/PredGradHess.java deleted file mode 100644 index 36272af57..000000000 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/PredGradHess.java +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.ml.common.gbt.defs; - -/** Stores prediction, gradient, and hessian of an instance. 
*/ -public class PredGradHess { - public double pred; - public double gradient; - public double hessian; - - public PredGradHess() {} - - public PredGradHess(double pred, double gradient, double hessian) { - this.pred = pred; - this.gradient = gradient; - this.hessian = hessian; - } - - @Override - public String toString() { - return String.format( - "PredGradHess{pred=%s, gradient=%s, hessian=%s}", pred, gradient, hessian); - } -} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java index ddb546ba2..5c1f16679 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java @@ -26,7 +26,6 @@ import org.apache.flink.ml.common.gbt.defs.BoostingStrategy; import org.apache.flink.ml.common.gbt.defs.Histogram; import org.apache.flink.ml.common.gbt.defs.LearningNode; -import org.apache.flink.ml.common.gbt.defs.PredGradHess; import org.apache.flink.ml.common.gbt.defs.TrainContext; import org.apache.flink.ml.common.gbt.typeinfo.BinnedInstanceSerializer; import org.apache.flink.ml.common.lossfunc.LossFunc; @@ -192,19 +191,17 @@ public void onEpochWatermarkIncremented( Preconditions.checkArgument( getRuntimeContext().getIndexOfThisSubtask() == trainContext.subtaskId); BinnedInstance[] instances = getter.get(SharedStorageConstants.INSTANCES); - PredGradHess[] pgh = getter.get(SharedStorageConstants.PREDS_GRADS_HESSIANS); + double[] pgh = getter.get(SharedStorageConstants.PREDS_GRADS_HESSIANS); // In the first round, use prior as the predictions. 
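// The flattened layout stores instance i's prediction, gradient, and hessian at
// indices 3 * i, 3 * i + 1, and 3 * i + 2; an empty array means no tree has been
// trained yet.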
if (0 == pgh.length) { - pgh = new PredGradHess[instances.length]; + pgh = new double[instances.length * 3]; double prior = trainContext.prior; LossFunc loss = trainContext.loss; for (int i = 0; i < instances.length; i += 1) { double label = instances[i].label; - pgh[i] = - new PredGradHess( - prior, - loss.gradient(prior, label), - loss.hessian(prior, label)); + pgh[3 * i] = prior; + pgh[3 * i + 1] = loss.gradient(prior, label); + pgh[3 * i + 2] = loss.hessian(prior, label); } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistBuilder.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistBuilder.java index 255825b09..bba786009 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistBuilder.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistBuilder.java @@ -23,7 +23,6 @@ import org.apache.flink.ml.common.gbt.defs.FeatureMeta; import org.apache.flink.ml.common.gbt.defs.Histogram; import org.apache.flink.ml.common.gbt.defs.LearningNode; -import org.apache.flink.ml.common.gbt.defs.PredGradHess; import org.apache.flink.ml.common.gbt.defs.TrainContext; import org.apache.flink.ml.util.Distributor; @@ -86,7 +85,7 @@ private static void calcNodeFeaturePairHists( boolean isInputVector, int[] indices, BinnedInstance[] instances, - PredGradHess[] pgh, + double[] pgh, double[] hists) { int numNodes = layer.size(); int numFeatures = featureMetas.length; @@ -138,8 +137,8 @@ private static void calcNodeFeaturePairHists( int instanceId = indices[i]; BinnedInstance binnedInstance = instances[instanceId]; double weight = binnedInstance.weight; - double gradient = pgh[instanceId].gradient; - double hessian = pgh[instanceId].hessian; + double gradient = pgh[3 * instanceId + 1]; + double hessian = pgh[3 * instanceId + 2]; totalHists[0] += gradient; totalHists[1] += hessian; @@ -151,8 +150,8 @@ private static void calcNodeFeaturePairHists( int instanceId = indices[i]; BinnedInstance binnedInstance = instances[instanceId]; double weight = binnedInstance.weight; - double gradient = pgh[instanceId].gradient; - double hessian = pgh[instanceId].hessian; + double gradient = pgh[3 * instanceId + 1]; + double hessian = pgh[3 * instanceId + 2]; if (null == binnedInstance.featureIds) { for (int j = 0; j < binnedInstance.featureValues.length; j += 1) { @@ -244,7 +243,7 @@ Histogram build( List layer, int[] indices, BinnedInstance[] instances, - PredGradHess[] pgh, + double[] pgh, Consumer nodeFeaturePairsSetter) { LOG.info("subtaskId: {}, {} start", subtaskId, HistBuilder.class.getSimpleName()); int numNodes = layer.size(); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/InstanceUpdater.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/InstanceUpdater.java index c8516575c..bee71a068 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/InstanceUpdater.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/InstanceUpdater.java @@ -21,7 +21,6 @@ import org.apache.flink.ml.common.gbt.defs.BinnedInstance; import org.apache.flink.ml.common.gbt.defs.LearningNode; import org.apache.flink.ml.common.gbt.defs.Node; -import org.apache.flink.ml.common.gbt.defs.PredGradHess; import org.apache.flink.ml.common.gbt.defs.Split; import org.apache.flink.ml.common.gbt.defs.TrainContext; import org.apache.flink.ml.common.lossfunc.LossFunc; @@ -48,20 +47,20 @@ public InstanceUpdater(TrainContext trainContext) { } 
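// Folds each new leaf's step-size-scaled prediction into the cached predictions of
// the instances it covers (both in-bag and out-of-bag slices), then recomputes their
// gradients and hessians in place.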
public void update( - PredGradHess[] pgh, + double[] pgh, List leaves, int[] indices, BinnedInstance[] instances, - Consumer pghSetter, + Consumer pghSetter, List treeNodes) { LOG.info("subtaskId: {}, {} start", subtaskId, InstanceUpdater.class.getSimpleName()); if (pgh.length == 0) { - pgh = new PredGradHess[instances.length]; + pgh = new double[instances.length * 3]; for (int i = 0; i < instances.length; i += 1) { double label = instances[i].label; - pgh[i] = - new PredGradHess( - prior, loss.gradient(prior, label), loss.hessian(prior, label)); + pgh[3 * i] = prior; + pgh[3 * i + 1] = loss.gradient(prior, label); + pgh[3 * i + 2] = loss.hessian(prior, label); } } @@ -70,20 +69,20 @@ public void update( double pred = split.prediction * stepSize; for (int i = nodeInfo.slice.start; i < nodeInfo.slice.end; ++i) { int instanceId = indices[i]; - updatePgh(pred, instances[instanceId].label, pgh[instanceId]); + updatePgh(instanceId, pred, instances[instanceId].label, pgh); } for (int i = nodeInfo.oob.start; i < nodeInfo.oob.end; ++i) { int instanceId = indices[i]; - updatePgh(pred, instances[instanceId].label, pgh[instanceId]); + updatePgh(instanceId, pred, instances[instanceId].label, pgh); } } pghSetter.accept(pgh); LOG.info("subtaskId: {}, {} end", subtaskId, InstanceUpdater.class.getSimpleName()); } - private void updatePgh(double pred, double label, PredGradHess pgh) { - pgh.pred += pred; - pgh.gradient = loss.gradient(pgh.pred, label); - pgh.hessian = loss.hessian(pgh.pred, label); + private void updatePgh(int instanceId, double pred, double label, double[] pgh) { + pgh[instanceId * 3] += pred; + pgh[instanceId * 3 + 1] = loss.gradient(pgh[instanceId * 3], label); + pgh[instanceId * 3 + 2] = loss.hessian(pgh[instanceId * 3], label); } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/PostSplitsOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/PostSplitsOperator.java index 6aa9a4968..4c5f04ee5 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/PostSplitsOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/PostSplitsOperator.java @@ -25,7 +25,6 @@ import org.apache.flink.ml.common.gbt.defs.BinnedInstance; import org.apache.flink.ml.common.gbt.defs.LearningNode; import org.apache.flink.ml.common.gbt.defs.Node; -import org.apache.flink.ml.common.gbt.defs.PredGradHess; import org.apache.flink.ml.common.gbt.defs.Splits; import org.apache.flink.ml.common.gbt.defs.TrainContext; import org.apache.flink.ml.common.sharedstorage.SharedStorageContext; @@ -194,7 +193,7 @@ public void onIterationTerminated(Context context, Collector collector) throws Exception { sharedStorageContext.invoke( (getter, setter) -> { - setter.set(SharedStorageConstants.PREDS_GRADS_HESSIANS, new PredGradHess[0]); + setter.set(SharedStorageConstants.PREDS_GRADS_HESSIANS, new double[0]); setter.set(SharedStorageConstants.SWAPPED_INDICES, new int[0]); setter.set(SharedStorageConstants.LEAVES, Collections.emptyList()); setter.set(SharedStorageConstants.LAYER, Collections.emptyList()); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SharedStorageConstants.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SharedStorageConstants.java index 4d09f18d9..a42ffe762 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SharedStorageConstants.java +++ 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SharedStorageConstants.java @@ -29,14 +29,13 @@ import org.apache.flink.ml.common.gbt.defs.BinnedInstance; import org.apache.flink.ml.common.gbt.defs.LearningNode; import org.apache.flink.ml.common.gbt.defs.Node; -import org.apache.flink.ml.common.gbt.defs.PredGradHess; import org.apache.flink.ml.common.gbt.defs.TrainContext; import org.apache.flink.ml.common.gbt.typeinfo.BinnedInstanceSerializer; import org.apache.flink.ml.common.gbt.typeinfo.LearningNodeSerializer; import org.apache.flink.ml.common.gbt.typeinfo.NodeSerializer; -import org.apache.flink.ml.common.gbt.typeinfo.PredGradHessSerializer; import org.apache.flink.ml.common.sharedstorage.ItemDescriptor; import org.apache.flink.ml.common.sharedstorage.SharedStorageUtils; +import org.apache.flink.ml.linalg.typeinfo.OptimizedDoublePrimitiveArraySerializer; import java.util.ArrayList; import java.util.Arrays; @@ -67,15 +66,14 @@ public class SharedStorageConstants { new BinnedInstance[0]); /** - * Predictions, gradients, and hessians of instances, sharing same instances with {@link + * (prediction, gradient, and hessian) of instances, sharing same indexing with {@link * #INSTANCES}. */ - static final ItemDescriptor PREDS_GRADS_HESSIANS = + static final ItemDescriptor PREDS_GRADS_HESSIANS = ItemDescriptor.of( "preds_grads_hessians", - new GenericArraySerializer<>( - PredGradHess.class, PredGradHessSerializer.INSTANCE), - new PredGradHess[0]); + new OptimizedDoublePrimitiveArraySerializer(), + new double[0]); /** Shuffle indices of instances used after every new tree just initialized. */ static final ItemDescriptor SHUFFLED_INDICES = diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TrainContextInitializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TrainContextInitializer.java index fff2c6c1b..b01ecc707 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TrainContextInitializer.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TrainContextInitializer.java @@ -84,7 +84,7 @@ public TrainContext init( stream(trainContext.featureMetas) .mapToInt(d -> d.numBins(trainContext.strategy.useMissing)) .toArray(); - + LOG.info("Number of bins for each feature: {}", trainContext.numFeatureBins); LOG.info("subtaskId: {}, {} end", subtaskId, TrainContextInitializer.class.getSimpleName()); return trainContext; } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/PredGradHessSerializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/PredGradHessSerializer.java deleted file mode 100644 index d206a1427..000000000 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/PredGradHessSerializer.java +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.ml.common.gbt.typeinfo; - -import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot; -import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; -import org.apache.flink.api.common.typeutils.base.DoubleSerializer; -import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton; -import org.apache.flink.core.memory.DataInputView; -import org.apache.flink.core.memory.DataOutputView; -import org.apache.flink.ml.common.gbt.defs.PredGradHess; - -import java.io.IOException; - -/** Serializer for {@link PredGradHess}. */ -public final class PredGradHessSerializer extends TypeSerializerSingleton { - - public static final PredGradHessSerializer INSTANCE = new PredGradHessSerializer(); - private static final long serialVersionUID = 1L; - - @Override - public boolean isImmutableType() { - return false; - } - - @Override - public PredGradHess createInstance() { - return new PredGradHess(); - } - - @Override - public PredGradHess copy(PredGradHess from) { - PredGradHess instance = new PredGradHess(); - instance.pred = from.pred; - instance.gradient = from.gradient; - instance.hessian = from.hessian; - return instance; - } - - @Override - public PredGradHess copy(PredGradHess from, PredGradHess reuse) { - assert from.getClass() == reuse.getClass(); - reuse.pred = from.pred; - reuse.gradient = from.gradient; - reuse.hessian = from.hessian; - return reuse; - } - - @Override - public int getLength() { - return -1; - } - - @Override - public void serialize(PredGradHess record, DataOutputView target) throws IOException { - DoubleSerializer.INSTANCE.serialize(record.pred, target); - DoubleSerializer.INSTANCE.serialize(record.gradient, target); - DoubleSerializer.INSTANCE.serialize(record.hessian, target); - } - - @Override - public PredGradHess deserialize(DataInputView source) throws IOException { - PredGradHess instance = new PredGradHess(); - instance.pred = DoubleSerializer.INSTANCE.deserialize(source); - instance.gradient = DoubleSerializer.INSTANCE.deserialize(source); - instance.hessian = DoubleSerializer.INSTANCE.deserialize(source); - return instance; - } - - @Override - public PredGradHess deserialize(PredGradHess reuse, DataInputView source) throws IOException { - reuse.pred = DoubleSerializer.INSTANCE.deserialize(source); - reuse.gradient = DoubleSerializer.INSTANCE.deserialize(source); - reuse.hessian = DoubleSerializer.INSTANCE.deserialize(source); - return reuse; - } - - @Override - public void copy(DataInputView source, DataOutputView target) throws IOException { - serialize(deserialize(source), target); - } - - // ------------------------------------------------------------------------ - - @Override - public TypeSerializerSnapshot snapshotConfiguration() { - return new PredGradHessSerializerSnapshot(); - } - - /** Serializer configuration snapshot for compatibility and format evolution. 
*/ - @SuppressWarnings("WeakerAccess") - public static final class PredGradHessSerializerSnapshot - extends SimpleTypeSerializerSnapshot { - - public PredGradHessSerializerSnapshot() { - super(PredGradHessSerializer::new); - } - } -} From 01c42119b1eac1a1f7e13c2a96e2f5b62232e81a Mon Sep 17 00:00:00 2001 From: Fan Hong Date: Tue, 7 Mar 2023 19:14:42 +0800 Subject: [PATCH 29/47] Improve Histogram to remove scattering. --- .../ml/common/gbt/BoostIterationBody.java | 53 +++---------------- .../flink/ml/common/gbt/defs/Histogram.java | 17 +++--- .../CacheDataCalcLocalHistsOperator.java | 21 +++++--- .../ml/common/gbt/operators/HistBuilder.java | 31 ++++++++++- .../ml/common/gbt/operators/SplitFinder.java | 1 + .../gbt/typeinfo/HistogramSerializer.java | 28 +++++----- ...timizedDoublePrimitiveArraySerializer.java | 10 ++-- 7 files changed, 79 insertions(+), 82 deletions(-) diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BoostIterationBody.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BoostIterationBody.java index b54cb9b44..29b6c7cb5 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BoostIterationBody.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BoostIterationBody.java @@ -18,12 +18,9 @@ package org.apache.flink.ml.common.gbt; -import org.apache.flink.api.common.functions.FlatMapFunction; -import org.apache.flink.api.common.functions.Partitioner; import org.apache.flink.api.common.typeinfo.TypeHint; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeinfo.Types; -import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.iteration.DataStreamList; import org.apache.flink.iteration.IterationBody; @@ -47,8 +44,6 @@ import org.apache.flink.types.Row; import org.apache.flink.util.OutputTag; -import org.apache.commons.lang3.ArrayUtils; - import java.util.Arrays; import java.util.HashMap; import java.util.List; @@ -78,17 +73,21 @@ private SharedStorageBody.SharedStorageBodyResult sharedStorageBody( // current tree layer. 
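// The operator emits (target subtask id, histogram slice) pairs, so the scatter step
// below is simply a partitionCustom on the tuple's first field followed by local
// aggregation of the received slices.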
CacheDataCalcLocalHistsOperator cacheDataCalcLocalHistsOp = new CacheDataCalcLocalHistsOperator(strategy); - SingleOutputStreamOperator localHists = + SingleOutputStreamOperator> localHists = data.connect(trainContext) .transform( "CacheDataCalcLocalHists", - TypeInformation.of(Histogram.class), + new TypeHint>() {}.getTypeInfo(), cacheDataCalcLocalHistsOp); for (ItemDescriptor s : SharedStorageConstants.OWNED_BY_CACHE_DATA_CALC_LOCAL_HISTS_OP) { ownerMap.put(s, cacheDataCalcLocalHistsOp.getSharedStorageAccessorID()); } - DataStream globalHists = scatterReduceHistograms(localHists); + DataStream globalHists = + localHists + .partitionCustom((key, numPartitions) -> key, value -> value.f0) + .map(d -> d.f1) + .flatMap(new HistogramAggregateFunction()); SingleOutputStreamOperator localSplits = globalHists.transform( @@ -142,42 +141,4 @@ public IterationBodyResult process(DataStreamList variableStreams, DataStreamLis DataStreamList.of(finalModelData), termination); } - - public DataStream scatterReduceHistograms(DataStream localHists) { - return localHists - .flatMap( - (FlatMapFunction>) - (value, out) -> { - double[] hists = value.hists; - int[] recvcnts = value.recvcnts; - int p = 0; - for (int i = 0; i < recvcnts.length; i += 1) { - out.collect( - Tuple2.of( - i, - new Histogram( - value.subtaskId, - ArrayUtils.subarray( - hists, p, p + recvcnts[i]), - recvcnts))); - p += recvcnts[i]; - } - }) - .returns(new TypeHint>() {}) - .partitionCustom( - new Partitioner() { - @Override - public int partition(Integer key, int numPartitions) { - return key; - } - }, - new KeySelector, Integer>() { - @Override - public Integer getKey(Tuple2 value) { - return value.f0; - } - }) - .map(d -> d.f1) - .flatMap(new HistogramAggregateFunction()); - } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Histogram.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Histogram.java index b939bd7fc..c85e44a03 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Histogram.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Histogram.java @@ -30,26 +30,25 @@ */ @TypeInfo(HistogramTypeInfoFactory.class) public class Histogram implements Serializable { - - // Stores source subtask ID when reducing or target subtask ID when scattering. + // Stores source subtask ID. public int subtaskId; // Stores values of histogram bins. public double[] hists; - // Stores the number of elements received by subtasks in scattering. - public int[] recvcnts; + // Stores the valid slice of `hists`. 
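+    // Slicing lets several Histogram objects share one backing array without copying;
+    // only the [start, end) range is meaningful, and accumulate() adds element-wise
+    // over the valid ranges of both operands.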
+ public Slice slice = new Slice(); public Histogram() {} - public Histogram(int subtaskId, double[] hists, int[] recvcnts) { + public Histogram(int subtaskId, double[] hists, Slice slice) { this.subtaskId = subtaskId; this.hists = hists; - this.recvcnts = recvcnts; + this.slice = slice; } private Histogram accumulate(Histogram other) { - Preconditions.checkArgument(hists.length == other.hists.length); - for (int i = 0; i < hists.length; i += 1) { - hists[i] += other.hists[i]; + Preconditions.checkArgument(slice.size() == other.slice.size()); + for (int i = 0; i < slice.size(); i += 1) { + hists[slice.start + i] += other.hists[other.slice.start + i]; } return this; } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java index 5c1f16679..5a1983fdd 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java @@ -18,6 +18,7 @@ package org.apache.flink.ml.common.gbt.operators; +import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer; import org.apache.flink.iteration.IterationListener; import org.apache.flink.iteration.datacache.nonkeyed.ListStateWithCache; @@ -53,9 +54,10 @@ * Calculates local histograms for local data partition. Specifically in the first round, this * operator caches all data instances to JVM static region. */ -public class CacheDataCalcLocalHistsOperator extends AbstractStreamOperator - implements TwoInputStreamOperator, - IterationListener, +public class CacheDataCalcLocalHistsOperator + extends AbstractStreamOperator> + implements TwoInputStreamOperator>, + IterationListener>, SharedStorageStreamOperator { private static final String TREE_INITIALIZER_STATE_NAME = "tree_initializer"; @@ -154,7 +156,8 @@ public void processElement2(StreamRecord streamRecord) throws Exce @Override public void onEpochWatermarkIncremented( - int epochWatermark, Context context, Collector out) throws Exception { + int epochWatermark, Context context, Collector> out) + throws Exception { if (0 == epochWatermark) { // Initializes local state in first round. 
sharedStorageContext.invoke( @@ -226,20 +229,22 @@ public void onEpochWatermarkIncremented( setter.set(SharedStorageConstants.HAS_INITED_TREE, false); } - Histogram localHists = + List> histograms = histBuilder.build( layer, indices, instances, pgh, d -> setter.set(SharedStorageConstants.NODE_FEATURE_PAIRS, d)); - out.collect(localHists); + for (Tuple2 t : histograms) { + out.collect(t); + } }); } @Override - public void onIterationTerminated(Context context, Collector collector) - throws Exception { + public void onIterationTerminated( + Context context, Collector> collector) throws Exception { instancesCollecting.clear(); treeInitializerState.clear(); histBuilderState.clear(); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistBuilder.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistBuilder.java index bba786009..3d2a5f370 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistBuilder.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistBuilder.java @@ -18,17 +18,20 @@ package org.apache.flink.ml.common.gbt.operators; +import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.ml.common.gbt.DataUtils; import org.apache.flink.ml.common.gbt.defs.BinnedInstance; import org.apache.flink.ml.common.gbt.defs.FeatureMeta; import org.apache.flink.ml.common.gbt.defs.Histogram; import org.apache.flink.ml.common.gbt.defs.LearningNode; +import org.apache.flink.ml.common.gbt.defs.Slice; import org.apache.flink.ml.common.gbt.defs.TrainContext; import org.apache.flink.ml.util.Distributor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.ArrayList; import java.util.Arrays; import java.util.BitSet; import java.util.List; @@ -114,6 +117,7 @@ private static void calcNodeFeaturePairHists( BitSet featureValid = null; boolean allFeatureValid; for (int k = 0; k < numNodes; k += 1) { + long start = System.currentTimeMillis(); int[] features = nodeToFeatures[k]; int[] binOffsets = nodeToBinOffsets[k]; LearningNode node = layer.get(k); @@ -204,6 +208,12 @@ private static void calcNodeFeaturePairHists( } } } + LOG.info( + "STEP 3: node {}, {} #instances, {} #features, {} ms", + k, + node.slice.size(), + features.length, + System.currentTimeMillis() - start); } } @@ -239,7 +249,7 @@ private static int[] calcRecvCounts( } /** Calculate local histograms for nodes in current layer of tree. */ - Histogram build( + List> build( List layer, int[] indices, BinnedInstance[] instances, @@ -267,6 +277,7 @@ Histogram build( numNodes * Math.min(maxFeatureBins * numBaggingFeatures, totalNumFeatureBins); double[] hists = new double[maxNumBins * BIN_SIZE]; // Calculates histograms for (nodeId, featureId) pairs. + long start = System.currentTimeMillis(); calcNodeFeaturePairHists( layer, nodeToFeatures, @@ -277,11 +288,27 @@ Histogram build( instances, pgh, hists); + long elapsed = System.currentTimeMillis() - start; + LOG.info("Elapsed time for calcNodeFeaturePairHists: {} ms", elapsed); // Calculates number of elements received by each downstream subtask. 
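// recvcnts[i] is the width of the `hists` range destined for downstream subtask i;
// the loop below wraps each range in a Slice so that scattering needs no subarray
// copies.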
int[] recvcnts = calcRecvCounts(numSubtasks, nodeFeaturePairs, numFeatureBins); + List> histograms = new ArrayList<>(); + int sliceStart = 0; + for (int i = 0; i < recvcnts.length; i += 1) { + int sliceSize = recvcnts[i]; + histograms.add( + Tuple2.of( + i, + new Histogram( + subtaskId, + hists, + new Slice(sliceStart, sliceStart + sliceSize)))); + sliceStart += sliceSize; + } + LOG.info("subtaskId: {}, {} end", this.subtaskId, HistBuilder.class.getSimpleName()); - return new Histogram(this.subtaskId, hists, recvcnts); + return histograms; } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SplitFinder.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SplitFinder.java index 9e850bc82..186f675b4 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SplitFinder.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SplitFinder.java @@ -82,6 +82,7 @@ public Splits calc( LearningNode node = layer.get(nodeId); Preconditions.checkState(node.depth < maxDepth || numLeaves + 2 <= maxNumLeaves); + Preconditions.checkState(histogram.slice.start == 0); splitters[featureId].reset( histogram.hists, new Slice(binOffset, binOffset + numFeatureBins[featureId])); Split bestSplit = splitters[featureId].bestSplit(); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/HistogramSerializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/HistogramSerializer.java index 02b0628f4..2b151c87b 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/HistogramSerializer.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/HistogramSerializer.java @@ -22,12 +22,14 @@ import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; import org.apache.flink.api.common.typeutils.base.IntSerializer; import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton; -import org.apache.flink.api.common.typeutils.base.array.IntPrimitiveArraySerializer; import org.apache.flink.core.memory.DataInputView; import org.apache.flink.core.memory.DataOutputView; import org.apache.flink.ml.common.gbt.defs.Histogram; +import org.apache.flink.ml.common.gbt.defs.Slice; import org.apache.flink.ml.linalg.typeinfo.OptimizedDoublePrimitiveArraySerializer; +import org.apache.commons.lang3.ArrayUtils; + import java.io.IOException; /** Serializer for {@link Histogram}. 
*/ @@ -53,18 +55,15 @@ public Histogram createInstance() { public Histogram copy(Histogram from) { Histogram histogram = new Histogram(); histogram.subtaskId = from.subtaskId; - histogram.hists = from.hists.clone(); - histogram.recvcnts = from.recvcnts.clone(); + histogram.hists = ArrayUtils.subarray(from.hists, from.slice.start, from.slice.end); + histogram.slice.start = 0; + histogram.slice.end = from.slice.size(); return histogram; } @Override public Histogram copy(Histogram from, Histogram reuse) { - assert from.getClass() == reuse.getClass(); - reuse.subtaskId = from.subtaskId; - reuse.hists = from.hists.clone(); - reuse.recvcnts = from.recvcnts.clone(); - return reuse; + return copy(from); } @Override @@ -74,9 +73,9 @@ public int getLength() { @Override public void serialize(Histogram record, DataOutputView target) throws IOException { - IntSerializer.INSTANCE.serialize(record.subtaskId, target); - histsSerializer.serialize(record.hists, target); - IntPrimitiveArraySerializer.INSTANCE.serialize(record.recvcnts, target); + target.writeInt(record.subtaskId); + // Only writes valid slice of `hists`. + histsSerializer.serialize(record.hists, record.slice.start, record.slice.size(), target); } @Override @@ -84,15 +83,16 @@ public Histogram deserialize(DataInputView source) throws IOException { Histogram histogram = new Histogram(); histogram.subtaskId = IntSerializer.INSTANCE.deserialize(source); histogram.hists = histsSerializer.deserialize(source); - histogram.recvcnts = IntPrimitiveArraySerializer.INSTANCE.deserialize(source); + histogram.slice = new Slice(0, histogram.hists.length); return histogram; } @Override public Histogram deserialize(Histogram reuse, DataInputView source) throws IOException { reuse.subtaskId = IntSerializer.INSTANCE.deserialize(source); - reuse.hists = histsSerializer.deserialize(source); - reuse.recvcnts = IntPrimitiveArraySerializer.INSTANCE.deserialize(source); + reuse.hists = histsSerializer.deserialize(reuse.hists, source); + reuse.slice.start = 0; + reuse.slice.end = reuse.hists.length; return reuse; } diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/OptimizedDoublePrimitiveArraySerializer.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/OptimizedDoublePrimitiveArraySerializer.java index 9264be2ce..7af39064b 100644 --- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/OptimizedDoublePrimitiveArraySerializer.java +++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/OptimizedDoublePrimitiveArraySerializer.java @@ -79,10 +79,14 @@ public void serialize(double[] record, DataOutputView target) throws IOException if (record == null) { throw new IllegalArgumentException("The record must not be null."); } - final int len = record.length; + serialize(record, 0, record.length, target); + } + + public void serialize(double[] record, int start, int len, DataOutputView target) + throws IOException { target.writeInt(len); - for (int i = 0; i < len; i++) { - Bits.putDouble(buf, (i & 127) << 3, record[i]); + for (int i = 0; i < len; i += 1) { + Bits.putDouble(buf, (i & 127) << 3, record[start + i]); if ((i & 127) == 127) { target.write(buf); } From 7f73aa156f0ebee77990177ee571670633e472bd Mon Sep 17 00:00:00 2001 From: Fan Hong Date: Thu, 16 Mar 2023 14:50:23 +0800 Subject: [PATCH 30/47] Add eclipse collection jars to uber jar --- flink-ml-lib/pom.xml | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/flink-ml-lib/pom.xml 
b/flink-ml-lib/pom.xml index 2c5d86c75..1b5f06539 100644 --- a/flink-ml-lib/pom.xml +++ b/flink-ml-lib/pom.xml @@ -163,10 +163,22 @@ under the License. shade + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + com.github.wendykierp:JTransforms pl.edu.icm:JLargeArrays + org.eclipse.collections:eclipse-collections-api + org.eclipse.collections:eclipse-collections From 075262dbb416816a16a0e6df4c2c0114bb2fa803 Mon Sep 17 00:00:00 2001 From: Fan Hong Date: Mon, 20 Mar 2023 11:51:42 +0800 Subject: [PATCH 31/47] Support output feature importance --- .../gbtclassifier/GBTClassifier.java | 6 +++- .../flink/ml/common/gbt/GBTModelData.java | 8 +++++ .../apache/flink/ml/common/gbt/GBTRunner.java | 36 +++++++++++++++++++ .../gbt/typeinfo/GBTModelDataSerializer.java | 13 +++++++ .../regression/gbtregressor/GBTRegressor.java | 6 +++- 5 files changed, 67 insertions(+), 2 deletions(-) diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifier.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifier.java index 3d7f76ee3..7aef49aa9 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifier.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifier.java @@ -78,8 +78,12 @@ public GBTClassifierModel fit(Table... inputs) { StreamTableEnvironment tEnv = (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); DataStream modelData = GBTRunner.train(inputs[0], this); + DataStream> featureImportance = + GBTRunner.getFeatureImportance(modelData); GBTClassifierModel model = new GBTClassifierModel(); - model.setModelData(tEnv.fromDataStream(modelData).renameColumns($("f0").as("modelData"))); + model.setModelData( + tEnv.fromDataStream(modelData).renameColumns($("f0").as("modelData")), + tEnv.fromDataStream(featureImportance)); ReadWriteUtils.updateExistingParams(model, getParamMap()); return model; } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelData.java index 165d2f9f4..49831f50f 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelData.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTModelData.java @@ -37,6 +37,7 @@ import org.apache.flink.ml.feature.stringindexer.StringIndexerModel; import org.apache.flink.ml.linalg.SparseVector; import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.regression.gbtregressor.GBTRegressorModel; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.table.api.Table; import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; @@ -49,6 +50,7 @@ import java.io.IOException; import java.io.OutputStream; +import java.util.ArrayList; import java.util.BitSet; import java.util.List; @@ -68,6 +70,7 @@ public class GBTModelData { public double stepSize; public List> allTrees; + public List featureNames; public IntObjectHashMap> categoryToIdMaps; public IntObjectHashMap featureIdToBinEdges; public BitSet isCategorical; @@ -80,6 +83,7 @@ public GBTModelData( double prior, double stepSize, List> allTrees, + List featureNames, IntObjectHashMap> categoryToIdMaps, IntObjectHashMap featureIdToBinEdges, BitSet isCategorical) { @@ -88,12 +92,14 @@ public GBTModelData( this.prior = prior; this.stepSize = stepSize; this.allTrees = allTrees; + this.featureNames = featureNames; this.categoryToIdMaps 
= categoryToIdMaps; this.featureIdToBinEdges = featureIdToBinEdges; this.isCategorical = isCategorical; } public static GBTModelData from(TrainContext trainContext, List> allTrees) { + List featureNames = new ArrayList<>(); IntObjectHashMap> categoryToIdMaps = new IntObjectHashMap<>(); IntObjectHashMap featureIdToBinEdges = new IntObjectHashMap<>(); BitSet isCategorical = new BitSet(); @@ -101,6 +107,7 @@ public static GBTModelData from(TrainContext trainContext, List> allT FeatureMeta[] featureMetas = trainContext.featureMetas; for (int k = 0; k < featureMetas.length; k += 1) { FeatureMeta featureMeta = featureMetas[k]; + featureNames.add(featureMeta.name); if (featureMeta instanceof FeatureMeta.CategoricalFeatureMeta) { String[] categories = ((FeatureMeta.CategoricalFeatureMeta) featureMeta).categories; ObjectIntHashMap categoryToId = new ObjectIntHashMap<>(); @@ -120,6 +127,7 @@ public static GBTModelData from(TrainContext trainContext, List> allT trainContext.prior, trainContext.strategy.stepSize, allTrees, + featureNames, categoryToIdMaps, featureIdToBinEdges, isCategorical); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTRunner.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTRunner.java index 0709dd37c..6f753da4c 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTRunner.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTRunner.java @@ -21,6 +21,7 @@ import org.apache.flink.api.common.functions.AggregateFunction; import org.apache.flink.api.common.functions.RichMapFunction; import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.iteration.DataStreamList; import org.apache.flink.iteration.IterationConfig; @@ -33,6 +34,7 @@ import org.apache.flink.ml.common.gbt.defs.BoostingStrategy; import org.apache.flink.ml.common.gbt.defs.FeatureMeta; import org.apache.flink.ml.common.gbt.defs.LossType; +import org.apache.flink.ml.common.gbt.defs.Node; import org.apache.flink.ml.common.gbt.defs.TaskType; import org.apache.flink.ml.common.gbt.defs.TrainContext; import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; @@ -102,6 +104,40 @@ static DataStream train(Table dataTable, BoostingStrategy strategy return boost(dataTable, strategy, featureMeta, labelSumCount); } + public static DataStream> getFeatureImportance( + DataStream modelData) { + return modelData + .map( + value -> { + Map featureImportanceMap = new HashMap<>(); + double sum = 0.; + for (List tree : value.allTrees) { + for (Node node : tree) { + if (node.isLeaf) { + continue; + } + featureImportanceMap.merge( + node.split.featureId, node.split.gain, Double::sum); + sum += node.split.gain; + } + } + if (sum > 0.) 
{ + for (Map.Entry entry : + featureImportanceMap.entrySet()) { + entry.setValue(entry.getValue() / sum); + } + } + + List featureNames = value.featureNames; + return featureImportanceMap.entrySet().stream() + .collect( + Collectors.toMap( + d -> featureNames.get(d.getKey()), + Map.Entry::getValue)); + }) + .returns(Types.MAP(Types.STRING, Types.DOUBLE)); + } + private static DataStream boost( Table dataTable, BoostingStrategy strategy, diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/GBTModelDataSerializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/GBTModelDataSerializer.java index 53c188a48..6c815254f 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/GBTModelDataSerializer.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/GBTModelDataSerializer.java @@ -71,6 +71,7 @@ public GBTModelData copy(GBTModelData from) { for (int i = 0; i < from.allTrees.size(); i += 1) { record.allTrees.add(new ArrayList<>(from.allTrees.get(i))); } + record.featureNames = new ArrayList<>(from.featureNames); record.categoryToIdMaps = new IntObjectHashMap<>(from.categoryToIdMaps); record.featureIdToBinEdges = new IntObjectHashMap<>(from.featureIdToBinEdges); record.isCategorical = BitSet.valueOf(from.isCategorical.toByteArray()); @@ -103,6 +104,11 @@ public void serialize(GBTModelData record, DataOutputView target) throws IOExcep } } + IntSerializer.INSTANCE.serialize(record.featureNames.size(), target); + for (int i = 0; i < record.featureNames.size(); i += 1) { + StringSerializer.INSTANCE.serialize(record.featureNames.get(i), target); + } + IntSerializer.INSTANCE.serialize(record.categoryToIdMaps.size(), target); for (int featureId : record.categoryToIdMaps.keysView().toArray()) { ObjectIntHashMap categoryToIdMap = record.categoryToIdMaps.get(featureId); @@ -145,6 +151,13 @@ public GBTModelData deserialize(DataInputView source) throws IOException { record.allTrees.add(treeNodes); } + int numFeatures = IntSerializer.INSTANCE.deserialize(source); + record.featureNames = new ArrayList<>(numFeatures); + for (int k = 0; k < numFeatures; k += 1) { + String featureName = StringSerializer.INSTANCE.deserialize(source); + record.featureNames.add(featureName); + } + int numCategoricalFeatures = IntSerializer.INSTANCE.deserialize(source); record.categoryToIdMaps = IntObjectHashMap.newMap(); for (int k = 0; k < numCategoricalFeatures; k += 1) { diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressor.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressor.java index bbe828bce..c26ae9f64 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressor.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressor.java @@ -77,8 +77,12 @@ public GBTRegressorModel fit(Table... 
inputs) { StreamTableEnvironment tEnv = (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); DataStream modelData = GBTRunner.train(inputs[0], this); + DataStream> featureImportance = + GBTRunner.getFeatureImportance(modelData); GBTRegressorModel model = new GBTRegressorModel(); - model.setModelData(tEnv.fromDataStream(modelData).renameColumns($("f0").as("modelData"))); + model.setModelData( + tEnv.fromDataStream(modelData).renameColumns($("f0").as("modelData")), + tEnv.fromDataStream(featureImportance)); ReadWriteUtils.updateExistingParams(model, getParamMap()); return model; } From 28cdd28d0355c9dc34308e4635b77c31b9fed807 Mon Sep 17 00:00:00 2001 From: Fan Hong Date: Mon, 20 Mar 2023 15:12:57 +0800 Subject: [PATCH 32/47] Update setModelData to support featureImportanceTable. --- .../ml/classification/gbtclassifier/GBTClassifierModel.java | 6 +++++- .../java/org/apache/flink/ml/common/gbt/BaseGBTModel.java | 6 +++++- .../flink/ml/regression/gbtregressor/GBTRegressorModel.java | 6 +++++- .../apache/flink/ml/classification/GBTClassifierTest.java | 3 +-- 4 files changed, 16 insertions(+), 5 deletions(-) diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifierModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifierModel.java index 7bf0763fe..191458353 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifierModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifierModel.java @@ -25,6 +25,7 @@ import org.apache.flink.ml.common.datastream.TableUtils; import org.apache.flink.ml.common.gbt.BaseGBTModel; import org.apache.flink.ml.common.gbt.GBTModelData; +import org.apache.flink.ml.common.gbt.GBTRunner; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; import org.apache.flink.ml.util.ReadWriteUtils; @@ -41,6 +42,7 @@ import java.io.IOException; import java.util.Collections; +import java.util.Map; /** A Model computed by {@link GBTClassifier}. 
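 *
 * <p>In addition to the tree ensemble, the model carries a feature importance table: for every
 * non-leaf node the split gain is accumulated per feature, and the per-feature totals are
 * normalized to sum to 1 (see {@code GBTRunner#getFeatureImportance}).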
*/ public class GBTClassifierModel extends BaseGBTModel @@ -58,7 +60,9 @@ public static GBTClassifierModel load(StreamTableEnvironment tEnv, String path) GBTClassifierModel model = ReadWriteUtils.loadStageParam(path); Table modelDataTable = ReadWriteUtils.loadModelData(tEnv, path, new GBTModelData.ModelDataDecoder()); - return model.setModelData(modelDataTable); + DataStream> featureImportance = + GBTRunner.getFeatureImportance(GBTModelData.getModelDataStream(modelDataTable)); + return model.setModelData(modelDataTable, tEnv.fromDataStream(featureImportance)); } @Override diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BaseGBTModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BaseGBTModel.java index 315a49489..ae4152e84 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BaseGBTModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BaseGBTModel.java @@ -25,6 +25,7 @@ import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.ReadWriteUtils; import org.apache.flink.table.api.Table; +import org.apache.flink.util.Preconditions; import java.io.IOException; import java.util.HashMap; @@ -35,6 +36,7 @@ public abstract class BaseGBTModel> implements Model, Object> paramMap = new HashMap<>(); protected Table modelDataTable; + protected Table featureImportanceTable; public BaseGBTModel() { ParamUtils.initializeMapWithDefaultValues(paramMap, this); @@ -42,12 +44,14 @@ public BaseGBTModel() { @Override public Table[] getModelData() { - return new Table[] {modelDataTable}; + return new Table[] {modelDataTable, featureImportanceTable}; } @Override public T setModelData(Table... inputs) { + Preconditions.checkArgument(inputs.length == 2); modelDataTable = inputs[0]; + featureImportanceTable = inputs[1]; //noinspection unchecked return (T) this; } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressorModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressorModel.java index 6cadaf696..fcc4d8ba0 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressorModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressorModel.java @@ -25,6 +25,7 @@ import org.apache.flink.ml.common.datastream.TableUtils; import org.apache.flink.ml.common.gbt.BaseGBTModel; import org.apache.flink.ml.common.gbt.GBTModelData; +import org.apache.flink.ml.common.gbt.GBTRunner; import org.apache.flink.ml.util.ReadWriteUtils; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.table.api.Table; @@ -38,6 +39,7 @@ import java.io.IOException; import java.util.Collections; +import java.util.Map; /** A Model computed by {@link GBTRegressor}. 
*/ public class GBTRegressorModel extends BaseGBTModel @@ -55,7 +57,9 @@ public static GBTRegressorModel load(StreamTableEnvironment tEnv, String path) GBTRegressorModel model = ReadWriteUtils.loadStageParam(path); Table modelDataTable = ReadWriteUtils.loadModelData(tEnv, path, new GBTModelData.ModelDataDecoder()); - return model.setModelData(modelDataTable); + DataStream> featureImportance = + GBTRunner.getFeatureImportance(GBTModelData.getModelDataStream(modelDataTable)); + return model.setModelData(modelDataTable, tEnv.fromDataStream(featureImportance)); } @Override diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/GBTClassifierTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/GBTClassifierTest.java index edbb719fe..e0fb35c50 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/GBTClassifierTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/GBTClassifierTest.java @@ -484,8 +484,7 @@ public void testSetModelData() throws Exception { .setMaxBins(3) .setSeed(123); GBTClassifierModel modelA = gbtc.fit(inputTable); - Table modelDataTable = modelA.getModelData()[0]; - GBTClassifierModel modelB = new GBTClassifierModel().setModelData(modelDataTable); + GBTClassifierModel modelB = new GBTClassifierModel().setModelData(modelA.getModelData()); ReadWriteUtils.updateExistingParams(modelB, modelA.getParamMap()); Table output = modelA.transform(inputTable)[0].select( From bc6465deb6f4ed77c3c9db650ef805b18d0c1c80 Mon Sep 17 00:00:00 2001 From: Fan Hong Date: Mon, 20 Mar 2023 15:47:44 +0800 Subject: [PATCH 33/47] Remove duplicated model data --- .../gbt/operators/TerminationOperator.java | 16 +++++++++------- .../ml/classification/GBTClassifierTest.java | 5 ++++- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TerminationOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TerminationOperator.java index 30157676b..8b1726949 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TerminationOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TerminationOperator.java @@ -76,13 +76,15 @@ public void onEpochWatermarkIncremented( @Override public void onIterationTerminated(Context context, Collector collector) throws Exception { - sharedStorageContext.invoke( - (getter, setter) -> - context.output( - modelDataOutputTag, - GBTModelData.from( - getter.get(SharedStorageConstants.TRAIN_CONTEXT), - getter.get(SharedStorageConstants.ALL_TREES)))); + if (0 == getRuntimeContext().getIndexOfThisSubtask()) { + sharedStorageContext.invoke( + (getter, setter) -> + context.output( + modelDataOutputTag, + GBTModelData.from( + getter.get(SharedStorageConstants.TRAIN_CONTEXT), + getter.get(SharedStorageConstants.ALL_TREES)))); + } } @Override diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/GBTClassifierTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/GBTClassifierTest.java index e0fb35c50..fa2376b84 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/GBTClassifierTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/GBTClassifierTest.java @@ -457,7 +457,10 @@ public void testGetModelData() throws Exception { Assert.assertArrayEquals( new String[] {"modelData"}, modelDataColumnNames.toArray(new String[0])); - Row modelDataRow = (Row) 
IteratorUtils.toList(output.executeAndCollect()).get(0); + //noinspection unchecked + List modelDataRows = IteratorUtils.toList(output.executeAndCollect()); + Assert.assertEquals(1, modelDataRows.size()); + Row modelDataRow = modelDataRows.get(0); GBTModelData modelData = modelDataRow.getFieldAs(0); Assert.assertNotNull(modelData); From d4f1dd513f70a94b96b4f32a2cd461038c10cac7 Mon Sep 17 00:00:00 2001 From: Fan Hong Date: Tue, 21 Mar 2023 17:59:54 +0800 Subject: [PATCH 34/47] Fix save/load for feature importance. --- .../gbtclassifier/GBTClassifier.java | 3 +- .../gbtclassifier/GBTClassifierModel.java | 10 +-- .../flink/ml/common/gbt/BaseGBTModel.java | 87 ++++++++++++++++++- .../regression/gbtregressor/GBTRegressor.java | 3 +- .../gbtregressor/GBTRegressorModel.java | 10 +-- .../ml/classification/GBTClassifierTest.java | 25 ++++-- .../flink/ml/regression/GBTRegressorTest.java | 29 +++++-- 7 files changed, 135 insertions(+), 32 deletions(-) diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifier.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifier.java index 7aef49aa9..248aaea76 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifier.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifier.java @@ -83,7 +83,8 @@ public GBTClassifierModel fit(Table... inputs) { GBTClassifierModel model = new GBTClassifierModel(); model.setModelData( tEnv.fromDataStream(modelData).renameColumns($("f0").as("modelData")), - tEnv.fromDataStream(featureImportance)); + tEnv.fromDataStream(featureImportance) + .renameColumns($("f0").as("featureImportance"))); ReadWriteUtils.updateExistingParams(model, getParamMap()); return model; } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifierModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifierModel.java index 191458353..7833e5d52 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifierModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifierModel.java @@ -25,10 +25,8 @@ import org.apache.flink.ml.common.datastream.TableUtils; import org.apache.flink.ml.common.gbt.BaseGBTModel; import org.apache.flink.ml.common.gbt.GBTModelData; -import org.apache.flink.ml.common.gbt.GBTRunner; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; -import org.apache.flink.ml.util.ReadWriteUtils; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.table.api.Table; import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; @@ -42,7 +40,6 @@ import java.io.IOException; import java.util.Collections; -import java.util.Map; /** A Model computed by {@link GBTClassifier}. 
*/ public class GBTClassifierModel extends BaseGBTModel @@ -57,12 +54,7 @@ public class GBTClassifierModel extends BaseGBTModel */ public static GBTClassifierModel load(StreamTableEnvironment tEnv, String path) throws IOException { - GBTClassifierModel model = ReadWriteUtils.loadStageParam(path); - Table modelDataTable = - ReadWriteUtils.loadModelData(tEnv, path, new GBTModelData.ModelDataDecoder()); - DataStream> featureImportance = - GBTRunner.getFeatureImportance(GBTModelData.getModelDataStream(modelDataTable)); - return model.setModelData(modelDataTable, tEnv.fromDataStream(featureImportance)); + return BaseGBTModel.load(tEnv, path); } @Override diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BaseGBTModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BaseGBTModel.java index ae4152e84..11c1ff3c6 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BaseGBTModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BaseGBTModel.java @@ -18,21 +18,41 @@ package org.apache.flink.ml.common.gbt; +import org.apache.flink.api.common.serialization.Encoder; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.common.typeutils.base.DoubleSerializer; +import org.apache.flink.api.common.typeutils.base.MapSerializer; +import org.apache.flink.api.common.typeutils.base.StringSerializer; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.connector.file.src.reader.SimpleStreamFormat; +import org.apache.flink.core.fs.FSDataInputStream; +import org.apache.flink.core.fs.Path; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataInputViewStreamWrapper; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.core.memory.DataOutputViewStreamWrapper; import org.apache.flink.ml.api.Model; import org.apache.flink.ml.classification.gbtclassifier.GBTClassifier; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.regression.gbtregressor.GBTRegressor; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.table.api.DataTypes; import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; import org.apache.flink.util.Preconditions; import java.io.IOException; +import java.io.OutputStream; import java.util.HashMap; import java.util.Map; /** Base model computed by {@link GBTClassifier} or {@link GBTRegressor}. 
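 *
 * <p>Model data is exposed as two tables: the GBT model data itself (column {@code modelData})
 * and the feature importance map (column {@code featureImportance}). {@code save} writes them
 * under the {@code model_data} and {@code feature_importance} sub-paths, and {@code load}
 * restores both before handing them to {@code setModelData}.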
*/ public abstract class BaseGBTModel> implements Model { + protected static final String MODEL_DATA_PATH = "model_data"; + protected static final String FEATURE_IMPORTANCE_PATH = "feature_importance"; protected final Map, Object> paramMap = new HashMap<>(); protected Table modelDataTable; @@ -42,6 +62,22 @@ public BaseGBTModel() { ParamUtils.initializeMapWithDefaultValues(paramMap, this); } + protected static > T load(StreamTableEnvironment tEnv, String path) + throws IOException { + T model = ReadWriteUtils.loadStageParam(path); + Table modelDataTable = + ReadWriteUtils.loadModelData( + tEnv, + new Path(path, MODEL_DATA_PATH).toString(), + new GBTModelData.ModelDataDecoder()); + Table featureImportanceTable = + ReadWriteUtils.loadModelData( + tEnv, + new Path(path, FEATURE_IMPORTANCE_PATH).toString(), + new FeatureImportanceEncoderDecoder()); + return model.setModelData(modelDataTable, featureImportanceTable); + } + @Override public Table[] getModelData() { return new Table[] {modelDataTable, featureImportanceTable}; @@ -66,7 +102,56 @@ public void save(String path) throws IOException { ReadWriteUtils.saveMetadata(this, path); ReadWriteUtils.saveModelData( GBTModelData.getModelDataStream(modelDataTable), - path, + new Path(path, MODEL_DATA_PATH).toString(), new GBTModelData.ModelDataEncoder()); + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) featureImportanceTable).getTableEnvironment(); + ReadWriteUtils.saveModelData( + tEnv.toDataStream( + featureImportanceTable, + DataTypes.MAP(DataTypes.STRING(), DataTypes.DOUBLE())), + new Path(path, FEATURE_IMPORTANCE_PATH).toString(), + new FeatureImportanceEncoderDecoder()); + } + + private static class FeatureImportanceEncoderDecoder + extends SimpleStreamFormat> + implements Encoder> { + + final MapSerializer serializer = + new MapSerializer<>(StringSerializer.INSTANCE, DoubleSerializer.INSTANCE); + + @Override + public void encode(Map element, OutputStream stream) throws IOException { + DataOutputView dataOutputView = new DataOutputViewStreamWrapper(stream); + serializer.serialize(element, dataOutputView); + } + + @Override + public Reader> createReader( + Configuration config, FSDataInputStream stream) throws IOException { + return new Reader>() { + @Override + public Map read() { + DataInputView source = new DataInputViewStreamWrapper(stream); + try { + return serializer.deserialize(source); + } catch (IOException e) { + return null; + } + } + + @Override + public void close() throws IOException { + stream.close(); + } + }; + } + + @Override + public TypeInformation> getProducedType() { + return Types.MAP(Types.STRING, Types.DOUBLE); + } } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressor.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressor.java index c26ae9f64..1ccb7e194 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressor.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressor.java @@ -82,7 +82,8 @@ public GBTRegressorModel fit(Table... 
inputs) { GBTRegressorModel model = new GBTRegressorModel(); model.setModelData( tEnv.fromDataStream(modelData).renameColumns($("f0").as("modelData")), - tEnv.fromDataStream(featureImportance)); + tEnv.fromDataStream(featureImportance) + .renameColumns($("f0").as("featureImportance"))); ReadWriteUtils.updateExistingParams(model, getParamMap()); return model; } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressorModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressorModel.java index fcc4d8ba0..0c78f151a 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressorModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressorModel.java @@ -25,8 +25,6 @@ import org.apache.flink.ml.common.datastream.TableUtils; import org.apache.flink.ml.common.gbt.BaseGBTModel; import org.apache.flink.ml.common.gbt.GBTModelData; -import org.apache.flink.ml.common.gbt.GBTRunner; -import org.apache.flink.ml.util.ReadWriteUtils; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.table.api.Table; import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; @@ -39,7 +37,6 @@ import java.io.IOException; import java.util.Collections; -import java.util.Map; /** A Model computed by {@link GBTRegressor}. */ public class GBTRegressorModel extends BaseGBTModel @@ -54,12 +51,7 @@ public class GBTRegressorModel extends BaseGBTModel */ public static GBTRegressorModel load(StreamTableEnvironment tEnv, String path) throws IOException { - GBTRegressorModel model = ReadWriteUtils.loadStageParam(path); - Table modelDataTable = - ReadWriteUtils.loadModelData(tEnv, path, new GBTModelData.ModelDataDecoder()); - DataStream> featureImportance = - GBTRunner.getFeatureImportance(GBTModelData.getModelDataStream(modelDataTable)); - return model.setModelData(modelDataTable, tEnv.fromDataStream(featureImportance)); + return BaseGBTModel.load(tEnv, path); } @Override diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/GBTClassifierTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/GBTClassifierTest.java index fa2376b84..4a8bae57e 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/GBTClassifierTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/GBTClassifierTest.java @@ -32,7 +32,6 @@ import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; import org.apache.flink.ml.util.ReadWriteUtils; import org.apache.flink.ml.util.TestUtils; -import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.table.api.Table; @@ -54,6 +53,7 @@ import java.util.Collections; import java.util.Comparator; import java.util.List; +import java.util.Map; import static org.apache.flink.table.api.Expressions.$; @@ -411,6 +411,9 @@ public void testEstimatorSaveLoadAndPredict() throws Exception { Assert.assertEquals( Collections.singletonList("modelData"), model.getModelData()[0].getResolvedSchema().getColumnNames()); + Assert.assertEquals( + Collections.singletonList("featureImportance"), + model.getModelData()[1].getResolvedSchema().getColumnNames()); Table output = model.transform(inputTable)[0].select( $(gbtc.getPredictionCol()), @@ -453,15 +456,14 @@ public void testGetModelData() throws 
Exception { GBTClassifierModel model = gbtc.fit(inputTable); Table modelDataTable = model.getModelData()[0]; List modelDataColumnNames = modelDataTable.getResolvedSchema().getColumnNames(); - DataStream output = tEnv.toDataStream(modelDataTable); Assert.assertArrayEquals( new String[] {"modelData"}, modelDataColumnNames.toArray(new String[0])); //noinspection unchecked - List modelDataRows = IteratorUtils.toList(output.executeAndCollect()); + List modelDataRows = + IteratorUtils.toList(tEnv.toDataStream(modelDataTable).executeAndCollect()); Assert.assertEquals(1, modelDataRows.size()); - Row modelDataRow = modelDataRows.get(0); - GBTModelData modelData = modelDataRow.getFieldAs(0); + GBTModelData modelData = modelDataRows.get(0).getFieldAs(0); Assert.assertNotNull(modelData); Assert.assertEquals(TaskType.CLASSIFICATION, TaskType.valueOf(modelData.type)); @@ -474,6 +476,19 @@ public void testGetModelData() throws Exception { gbtc.getFeaturesCols().length - gbtc.getCategoricalCols().length, modelData.featureIdToBinEdges.size()); Assert.assertEquals(BitSet.valueOf(new byte[] {4}), modelData.isCategorical); + + Table featureImportanceTable = model.getModelData()[1]; + Assert.assertEquals( + Collections.singletonList("featureImportance"), + featureImportanceTable.getResolvedSchema().getColumnNames()); + //noinspection unchecked + List featureImportanceRows = + IteratorUtils.toList(tEnv.toDataStream(featureImportanceTable).executeAndCollect()); + Assert.assertEquals(1, featureImportanceRows.size()); + Map featureImportanceMap = + featureImportanceRows.get(0).getFieldAs("featureImportance"); + Assert.assertArrayEquals( + gbtc.getFeaturesCols(), featureImportanceMap.keySet().toArray(new String[0])); } @Test diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/GBTRegressorTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/GBTRegressorTest.java index 852e7969f..6f8e60a92 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/GBTRegressorTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/GBTRegressorTest.java @@ -31,7 +31,6 @@ import org.apache.flink.ml.regression.gbtregressor.GBTRegressorModel; import org.apache.flink.ml.util.ReadWriteUtils; import org.apache.flink.ml.util.TestUtils; -import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.table.api.Table; @@ -53,6 +52,7 @@ import java.util.Collections; import java.util.Comparator; import java.util.List; +import java.util.Map; import static org.apache.flink.table.api.Expressions.$; @@ -297,6 +297,9 @@ public void testEstimatorSaveLoadAndPredict() throws Exception { Assert.assertEquals( Collections.singletonList("modelData"), model.getModelData()[0].getResolvedSchema().getColumnNames()); + Assert.assertEquals( + Collections.singletonList("featureImportance"), + model.getModelData()[1].getResolvedSchema().getColumnNames()); Table output = model.transform(inputTable)[0].select($(gbtr.getPredictionCol())); verifyPredictionResult(output, outputRows); } @@ -331,12 +334,14 @@ public void testGetModelData() throws Exception { GBTRegressorModel model = gbtr.fit(inputTable); Table modelDataTable = model.getModelData()[0]; List modelDataColumnNames = modelDataTable.getResolvedSchema().getColumnNames(); - DataStream output = tEnv.toDataStream(modelDataTable); Assert.assertArrayEquals( new String[] 
{"modelData"}, modelDataColumnNames.toArray(new String[0])); - Row modelDataRow = (Row) IteratorUtils.toList(output.executeAndCollect()).get(0); - GBTModelData modelData = modelDataRow.getFieldAs(0); + //noinspection unchecked + List modelDataRows = + IteratorUtils.toList(tEnv.toDataStream(modelDataTable).executeAndCollect()); + Assert.assertEquals(1, modelDataRows.size()); + GBTModelData modelData = modelDataRows.get(0).getFieldAs(0); Assert.assertNotNull(modelData); Assert.assertEquals(TaskType.REGRESSION, TaskType.valueOf(modelData.type)); @@ -349,6 +354,19 @@ public void testGetModelData() throws Exception { gbtr.getFeaturesCols().length - gbtr.getCategoricalCols().length, modelData.featureIdToBinEdges.size()); Assert.assertEquals(BitSet.valueOf(new byte[] {4}), modelData.isCategorical); + + Table featureImportanceTable = model.getModelData()[1]; + Assert.assertEquals( + Collections.singletonList("featureImportance"), + featureImportanceTable.getResolvedSchema().getColumnNames()); + //noinspection unchecked + List featureImportanceRows = + IteratorUtils.toList(tEnv.toDataStream(featureImportanceTable).executeAndCollect()); + Assert.assertEquals(1, featureImportanceRows.size()); + Map featureImportanceMap = + featureImportanceRows.get(0).getFieldAs("featureImportance"); + Assert.assertArrayEquals( + gbtr.getFeaturesCols(), featureImportanceMap.keySet().toArray(new String[0])); } @Test @@ -362,8 +380,7 @@ public void testSetModelData() throws Exception { .setMaxBins(3) .setSeed(123); GBTRegressorModel modelA = gbtr.fit(inputTable); - Table modelDataTable = modelA.getModelData()[0]; - GBTRegressorModel modelB = new GBTRegressorModel().setModelData(modelDataTable); + GBTRegressorModel modelB = new GBTRegressorModel().setModelData(modelA.getModelData()); ReadWriteUtils.updateExistingParams(modelB, modelA.getParamMap()); Table output = modelA.transform(inputTable)[0].select($(gbtr.getPredictionCol())); verifyPredictionResult(output, outputRows); From d8c71fc62f3ee7416860ef43e26864d7aa0e7e79 Mon Sep 17 00:00:00 2001 From: Fan Hong Date: Fri, 24 Mar 2023 17:33:23 +0800 Subject: [PATCH 35/47] Fix get label when type is not double. --- .../common/gbt/operators/CacheDataCalcLocalHistsOperator.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java index 5a1983fdd..a19ae1ab9 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java @@ -130,7 +130,7 @@ public void processElement1(StreamRecord streamRecord) throws Exception { Row row = streamRecord.getValue(); BinnedInstance instance = new BinnedInstance(); instance.weight = 1.; - instance.label = row.getFieldAs(strategy.labelCol); + instance.label = row.getFieldAs(strategy.labelCol).doubleValue(); if (strategy.isInputVector) { Vector vec = row.getFieldAs(strategy.featuresCols[0]); From adf3115691cfae7358514fbc8982cd5b6dbdd976 Mon Sep 17 00:00:00 2001 From: Fan Hong Date: Tue, 28 Mar 2023 20:38:45 +0800 Subject: [PATCH 36/47] Reduce computation for nodes with max depth. 
--- .../ml/common/gbt/operators/HistBuilder.java | 169 ++++++++++-------- .../ml/common/gbt/operators/SplitFinder.java | 11 +- .../operators/TrainContextInitializer.java | 12 +- 3 files changed, 111 insertions(+), 81 deletions(-) diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistBuilder.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistBuilder.java index 3d2a5f370..ec7995f6b 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistBuilder.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistBuilder.java @@ -47,6 +47,7 @@ class HistBuilder { private final int subtaskId; private final int numSubtasks; + private final int numFeatures; private final int[] numFeatureBins; private final FeatureMeta[] featureMetas; @@ -55,14 +56,13 @@ class HistBuilder { private final int[] featureIndicesPool; private final boolean isInputVector; - - private final int maxFeatureBins; - private final int totalNumFeatureBins; + private final int maxDepth; public HistBuilder(TrainContext trainContext) { subtaskId = trainContext.subtaskId; numSubtasks = trainContext.numSubtasks; + numFeatures = trainContext.numFeatures; numFeatureBins = trainContext.numFeatureBins; featureMetas = trainContext.featureMetas; @@ -71,9 +71,7 @@ public HistBuilder(TrainContext trainContext) { featureIndicesPool = IntStream.range(0, trainContext.numFeatures).toArray(); isInputVector = trainContext.strategy.isInputVector; - - maxFeatureBins = Arrays.stream(numFeatureBins).max().orElse(0); - totalNumFeatureBins = Arrays.stream(numFeatureBins).sum(); + maxDepth = trainContext.strategy.maxDepth; } /** @@ -83,6 +81,7 @@ public HistBuilder(TrainContext trainContext) { private static void calcNodeFeaturePairHists( List layer, int[][] nodeToFeatures, + boolean[] needSplit, FeatureMeta[] featureMetas, int[] numFeatureBins, boolean isInputVector, @@ -122,20 +121,6 @@ private static void calcNodeFeaturePairHists( int[] binOffsets = nodeToBinOffsets[k]; LearningNode node = layer.get(k); - if (numFeatures != features.length) { - allFeatureValid = false; - featureValid = new BitSet(numFeatures); - for (int feature : features) { - featureValid.set(feature); - } - for (int i = 0; i < features.length; i += 1) { - featureOffset[features[i]] = binOffsets[i]; - } - } else { - allFeatureValid = true; - System.arraycopy(binOffsets, 0, featureOffset, 0, numFeatures); - } - double[] totalHists = new double[4]; for (int i = node.slice.start; i < node.slice.end; i += 1) { int instanceId = indices[i]; @@ -150,63 +135,86 @@ private static void calcNodeFeaturePairHists( totalHists[3] += 1.; } - for (int i = node.slice.start; i < node.slice.end; i += 1) { - int instanceId = indices[i]; - BinnedInstance binnedInstance = instances[instanceId]; - double weight = binnedInstance.weight; - double gradient = pgh[3 * instanceId + 1]; - double hessian = pgh[3 * instanceId + 2]; + if (needSplit[k]) { + if (numFeatures != features.length) { + allFeatureValid = false; + featureValid = new BitSet(numFeatures); + for (int i = 0; i < features.length; i += 1) { + featureValid.set(features[i]); + featureOffset[features[i]] = binOffsets[i]; + } + } else { + allFeatureValid = true; + System.arraycopy(binOffsets, 0, featureOffset, 0, numFeatures); + } - if (null == binnedInstance.featureIds) { - for (int j = 0; j < binnedInstance.featureValues.length; j += 1) { - if (allFeatureValid || featureValid.get(j)) { - add( - hists, - featureOffset[j], - 
binnedInstance.featureValues[j], - gradient, - hessian, - weight, - 1.); + for (int i = node.slice.start; i < node.slice.end; i += 1) { + int instanceId = indices[i]; + BinnedInstance binnedInstance = instances[instanceId]; + double weight = binnedInstance.weight; + double gradient = pgh[3 * instanceId + 1]; + double hessian = pgh[3 * instanceId + 2]; + + if (null == binnedInstance.featureIds) { + for (int j = 0; j < binnedInstance.featureValues.length; j += 1) { + if (allFeatureValid || featureValid.get(j)) { + add( + hists, + featureOffset[j], + binnedInstance.featureValues[j], + gradient, + hessian, + weight, + 1.); + } + } + } else { + for (int j = 0; j < binnedInstance.featureIds.length; j += 1) { + int featureId = binnedInstance.featureIds[j]; + if (allFeatureValid || featureValid.get(featureId)) { + add( + hists, + featureOffset[featureId], + binnedInstance.featureValues[j], + gradient, + hessian, + weight, + 1.); + } } } - } else { - for (int j = 0; j < binnedInstance.featureIds.length; j += 1) { - int featureId = binnedInstance.featureIds[j]; - if (allFeatureValid || featureValid.get(featureId)) { + } + + for (int featureId : features) { + int defaultVal = featureDefaultVal[featureId]; + int defaultValIndex = (featureOffset[featureId] + defaultVal) * BIN_SIZE; + hists[defaultValIndex] = totalHists[0]; + hists[defaultValIndex + 1] = totalHists[1]; + hists[defaultValIndex + 2] = totalHists[2]; + hists[defaultValIndex + 3] = totalHists[3]; + for (int i = 0; i < numFeatureBins[featureId]; i += 1) { + if (i != defaultVal) { + int index = (featureOffset[featureId] + i) * BIN_SIZE; add( hists, featureOffset[featureId], - binnedInstance.featureValues[j], - gradient, - hessian, - weight, - 1.); + defaultVal, + -hists[index], + -hists[index + 1], + -hists[index + 2], + -hists[index + 3]); } } } - } - - for (int featureId : features) { - int defaultVal = featureDefaultVal[featureId]; - int defaultValIndex = (featureOffset[featureId] + defaultVal) * BIN_SIZE; - hists[defaultValIndex] = totalHists[0]; - hists[defaultValIndex + 1] = totalHists[1]; - hists[defaultValIndex + 2] = totalHists[2]; - hists[defaultValIndex + 3] = totalHists[3]; - for (int i = 0; i < numFeatureBins[featureId]; i += 1) { - if (i != defaultVal) { - int index = (featureOffset[featureId] + i) * BIN_SIZE; - add( - hists, - featureOffset[featureId], - defaultVal, - -hists[index], - -hists[index + 1], - -hists[index + 2], - -hists[index + 3]); - } - } + } else { + add( + hists, + binOffsets[0], + 0, + totalHists[0], + totalHists[1], + totalHists[2], + totalHists[3]); } LOG.info( "STEP 3: node {}, {} #instances, {} #features, {} ms", @@ -261,26 +269,39 @@ List> build( // Generates (nodeId, featureId) pairs that are required to build histograms. int[][] nodeToFeatures = new int[numNodes][]; int[] nodeFeaturePairs = new int[numNodes * numBaggingFeatures * 2]; + boolean[] needSplit = new boolean[numNodes]; int p = 0; + int numTotalBins = 0; for (int k = 0; k < numNodes; k += 1) { - nodeToFeatures[k] = - DataUtils.sample(featureIndicesPool, numBaggingFeatures, featureRandomizer); - Arrays.sort(nodeToFeatures[k]); + LearningNode node = layer.get(k); + if (node.depth == maxDepth) { + needSplit[k] = false; + // Ignores the results, just to consume the randomizer. + DataUtils.sample(featureIndicesPool, numBaggingFeatures, featureRandomizer); + // No need to calculate histograms for features, only sum of gradients and hessians + // are needed. Uses `numFeatures` to indicate this special "feature". 
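// (Illustrative note, not part of the patch: the pseudo-feature is backed by one extra
// single-bin entry appended to `numFeatureBins` in TrainContextInitializer and by one extra
// splitter in SplitFinder, so the downstream split-finding path handles max-depth nodes
// uniformly.)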
+ nodeToFeatures[k] = new int[] {numFeatures}; + } else { + needSplit[k] = true; + nodeToFeatures[k] = + DataUtils.sample(featureIndicesPool, numBaggingFeatures, featureRandomizer); + Arrays.sort(nodeToFeatures[k]); + } for (int featureId : nodeToFeatures[k]) { nodeFeaturePairs[p++] = k; nodeFeaturePairs[p++] = featureId; + numTotalBins += numFeatureBins[featureId]; } } nodeFeaturePairsSetter.accept(nodeFeaturePairs); - int maxNumBins = - numNodes * Math.min(maxFeatureBins * numBaggingFeatures, totalNumFeatureBins); - double[] hists = new double[maxNumBins * BIN_SIZE]; + double[] hists = new double[numTotalBins * BIN_SIZE]; // Calculates histograms for (nodeId, featureId) pairs. long start = System.currentTimeMillis(); calcNodeFeaturePairHists( layer, nodeToFeatures, + needSplit, featureMetas, numFeatureBins, isInputVector, diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SplitFinder.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SplitFinder.java index 186f675b4..767973925 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SplitFinder.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SplitFinder.java @@ -52,8 +52,9 @@ public SplitFinder(TrainContext trainContext) { numFeatureBins = trainContext.numFeatureBins; FeatureMeta[] featureMetas = trainContext.featureMetas; - splitters = new HistogramFeatureSplitter[trainContext.numFeatures]; - for (int i = 0; i < trainContext.numFeatures; ++i) { + int numFeatures = trainContext.numFeatures; + splitters = new HistogramFeatureSplitter[numFeatures + 1]; + for (int i = 0; i < numFeatures; ++i) { splitters[i] = FeatureMeta.Type.CATEGORICAL == featureMetas[i].type ? new CategoricalFeatureSplitter( @@ -61,6 +62,12 @@ public SplitFinder(TrainContext trainContext) { : new ContinuousFeatureSplitter( i, featureMetas[i], trainContext.strategy); } + // Adds an addition splitter to obtain the prediction of the node. + splitters[numFeatures] = + new ContinuousFeatureSplitter( + numFeatures, + new FeatureMeta.ContinuousFeatureMeta("SPECIAL", 0, new double[0]), + trainContext.strategy); maxDepth = trainContext.strategy.maxDepth; maxNumLeaves = trainContext.strategy.maxNumLeaves; } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TrainContextInitializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TrainContextInitializer.java index b01ecc707..6b0226f6e 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TrainContextInitializer.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TrainContextInitializer.java @@ -28,6 +28,7 @@ import org.apache.flink.ml.common.lossfunc.SquaredErrorLoss; import org.apache.flink.util.Preconditions; +import org.apache.commons.lang3.ArrayUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -36,8 +37,6 @@ import java.util.Random; import java.util.function.Function; -import static java.util.Arrays.stream; - class TrainContextInitializer { private static final Logger LOG = LoggerFactory.getLogger(TrainContextInitializer.class); private final BoostingStrategy strategy; @@ -80,10 +79,13 @@ public TrainContext init( trainContext.loss = getLoss(); trainContext.prior = calcPrior(trainContext.labelSumCount); + // A special `feature` is appended with #bins = 1 to simplify codes. 
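// For example (values illustrative): per-feature bin counts {8, 4, 16} become
// {8, 4, 16, 1} after the padding below, and index `numFeatures` addresses the
// single-bin pseudo-feature.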
trainContext.numFeatureBins = - stream(trainContext.featureMetas) - .mapToInt(d -> d.numBins(trainContext.strategy.useMissing)) - .toArray(); + ArrayUtils.add( + Arrays.stream(trainContext.featureMetas) + .mapToInt(d -> d.numBins(trainContext.strategy.useMissing)) + .toArray(), + 1); LOG.info("Number of bins for each feature: {}", trainContext.numFeatureBins); LOG.info("subtaskId: {}, {} end", subtaskId, TrainContextInitializer.class.getSimpleName()); return trainContext; From a105e3592bcb6d3962370f4487617288c1ad0807 Mon Sep 17 00:00:00 2001 From: Fan Hong Date: Mon, 3 Apr 2023 18:01:54 +0800 Subject: [PATCH 37/47] [NO MERGE] Ad-hoc fix for NaN values in KBinsDiscretizer --- .../kbinsdiscretizer/KBinsDiscretizer.java | 17 ++++--- .../ml/classification/GBTClassifierTest.java | 44 +++++++++---------- .../flink/ml/regression/GBTRegressorTest.java | 20 ++++----- 3 files changed, 44 insertions(+), 37 deletions(-) diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizer.java index ad3132cf6..07948cc27 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizer.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizer.java @@ -36,6 +36,7 @@ import org.apache.flink.util.Collector; import org.apache.flink.util.Preconditions; +import org.apache.commons.lang3.ArrayUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -210,20 +211,26 @@ private static double[][] findBinEdgesWithQuantileStrategy( int numColumns = input.get(0).size(); int numData = input.size(); double[][] binEdges = new double[numColumns][]; - double[] features = new double[numData]; for (int columnId = 0; columnId < numColumns; columnId++) { + double[] features = new double[numData]; for (int i = 0; i < numData; i++) { features[i] = input.get(i).get(columnId); } Arrays.sort(features); int n = numData; - while (n > 0 && Double.isNaN(features[n - 1])) { - n -= 1; + { + int validRange = numData; + while (validRange > 0 && Double.isNaN(features[validRange - 1])) { + validRange -= 1; + } + if (validRange < numData) { + features = ArrayUtils.subarray(features, 0, validRange); + } } - if (features[0] == features[n - 1]) { + if (features[0] == features[features.length - 1]) { LOG.warn("Feature " + columnId + " is constant and the output will all be zero."); binEdges[columnId] = new double[] {Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY}; @@ -236,7 +243,7 @@ private static double[][] findBinEdgesWithQuantileStrategy( for (int binEdgeId = 0; binEdgeId < numBins; binEdgeId++) { tempBinEdges[binEdgeId] = features[(int) (binEdgeId * width)]; } - tempBinEdges[numBins] = features[n - 1]; + tempBinEdges[numBins] = features[features.length - 1]; } else { tempBinEdges = features; } diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/GBTClassifierTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/GBTClassifierTest.java index 4a8bae57e..357a7303c 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/GBTClassifierTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/GBTClassifierTest.java @@ -354,44 +354,44 @@ public void testFitAndPredictWithNoCategoricalCols() throws Exception { Arrays.asList( Row.of( 0.0, - Vectors.dense(2.4386858360079877, -2.4386858360079877), - Vectors.dense(0.9197301210345855, 0.08026987896541447)), + 
Vectors.dense(2.34563907006811, -2.34563907006811), + Vectors.dense(0.9125869728543822, 0.0874130271456178)), Row.of( - 0.0, - Vectors.dense(2.079593609142336, -2.079593609142336), - Vectors.dense(0.8889039070093702, 0.11109609299062985)), + 1.0, + Vectors.dense(-2.3303467465269785, 2.3303467465269785), + Vectors.dense(0.0886406478666607, 0.9113593521333393)), Row.of( 1.0, - Vectors.dense(-2.4477766607449594, 2.4477766607449594), - Vectors.dense(0.07960128978764613, 0.9203987102123539)), + Vectors.dense(-2.6627806586536007, 2.6627806586536007), + Vectors.dense(0.06520563648648892, 0.9347943635135111)), Row.of( 0.0, - Vectors.dense(2.3680506847981113, -2.3680506847981113), - Vectors.dense(0.9143583384561507, 0.0856416615438493)), + Vectors.dense(2.2219234863111987, -2.2219234863111987), + Vectors.dense(0.9022010445561748, 0.09779895544382528)), Row.of( 1.0, - Vectors.dense(-2.0115161495245792, 2.0115161495245792), - Vectors.dense(0.11799909267017583, 0.8820009073298242)), + Vectors.dense(-2.4261826518456586, 2.4261826518456586), + Vectors.dense(0.08119780449041314, 0.9188021955095869)), Row.of( 0.0, - Vectors.dense(2.3680506847981113, -2.3680506847981113), - Vectors.dense(0.9143583384561507, 0.0856416615438493)), + Vectors.dense(2.6577392865785714, -2.6577392865785714), + Vectors.dense(0.9344863980226659, 0.06551360197733418)), Row.of( 1.0, - Vectors.dense(-2.1774376078697983, 2.1774376078697983), - Vectors.dense(0.10179497553813543, 0.8982050244618646)), + Vectors.dense(-2.6641132494818254, 2.6641132494818254), + Vectors.dense(0.0651244569774293, 0.9348755430225707)), Row.of( 0.0, - Vectors.dense(2.434832949283468, -2.434832949283468), - Vectors.dense(0.9194452150195366, 0.08055478498046341)), + Vectors.dense(2.6577392865785714, -2.6577392865785714), + Vectors.dense(0.9344863980226659, 0.06551360197733418)), Row.of( - 1.0, - Vectors.dense(-2.441225164856452, 2.441225164856452), - Vectors.dense(0.08008260858505134, 0.9199173914149487)), + 0.0, + Vectors.dense(2.6577392865785714, -2.6577392865785714), + Vectors.dense(0.9344863980226659, 0.06551360197733418)), Row.of( 1.0, - Vectors.dense(-2.672457199454413, 2.672457199454413), - Vectors.dense(0.06461828968951666, 0.9353817103104833))); + Vectors.dense(-2.4318453555603523, 2.4318453555603523), + Vectors.dense(0.08077634070928352, 0.9192236592907165))); verifyPredictionResult(output, outputRowsUsingNoCategoricalCols); } diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/GBTRegressorTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/GBTRegressorTest.java index 6f8e60a92..ffcddd9aa 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/GBTRegressorTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/GBTRegressorTest.java @@ -268,16 +268,16 @@ public void testFitAndPredictWithNoCategoricalCols() throws Exception { Table output = model.transform(inputTable)[0].select($(gbtr.getPredictionCol())); List outputRowsUsingNoCategoricalCols = Arrays.asList( - Row.of(40.07663214615239), - Row.of(40.92462268161843), - Row.of(40.941626445241624), - Row.of(40.06608854749729), - Row.of(40.12272436518743), - Row.of(40.92737873124178), - Row.of(40.08092204935494), - Row.of(40.898529570430696), - Row.of(40.08092204935494), - Row.of(40.88296818645738)); + Row.of(40.060788327295285), + Row.of(40.92126707025628), + Row.of(40.08161253493682), + Row.of(40.916655697518976), + Row.of(40.95467692795112), + Row.of(40.070253879056665), + Row.of(40.06975535946203), + Row.of(40.923228418693306), + 
Row.of(40.093329043797524), + Row.of(40.923115214426424)); verifyPredictionResult(output, outputRowsUsingNoCategoricalCols); } From 10068129578ae9bf90ffddffa1d0b3caa0b65c5e Mon Sep 17 00:00:00 2001 From: Fan Hong Date: Mon, 3 Apr 2023 19:08:54 +0800 Subject: [PATCH 38/47] Fix out-of-bound exception for nodeFeaturePairs --- .../flink/ml/common/gbt/operators/HistBuilder.java | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistBuilder.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistBuilder.java index ec7995f6b..8c09810c6 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistBuilder.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistBuilder.java @@ -28,6 +28,7 @@ import org.apache.flink.ml.common.gbt.defs.TrainContext; import org.apache.flink.ml.util.Distributor; +import org.eclipse.collections.impl.list.mutable.primitive.IntArrayList; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -268,9 +269,8 @@ List> build( // Generates (nodeId, featureId) pairs that are required to build histograms. int[][] nodeToFeatures = new int[numNodes][]; - int[] nodeFeaturePairs = new int[numNodes * numBaggingFeatures * 2]; + IntArrayList nodeFeaturePairs = new IntArrayList(numNodes * numBaggingFeatures * 2); boolean[] needSplit = new boolean[numNodes]; - int p = 0; int numTotalBins = 0; for (int k = 0; k < numNodes; k += 1) { LearningNode node = layer.get(k); @@ -288,12 +288,12 @@ List> build( Arrays.sort(nodeToFeatures[k]); } for (int featureId : nodeToFeatures[k]) { - nodeFeaturePairs[p++] = k; - nodeFeaturePairs[p++] = featureId; + nodeFeaturePairs.add(k); + nodeFeaturePairs.add(featureId); numTotalBins += numFeatureBins[featureId]; } } - nodeFeaturePairsSetter.accept(nodeFeaturePairs); + nodeFeaturePairsSetter.accept(nodeFeaturePairs.toArray()); double[] hists = new double[numTotalBins * BIN_SIZE]; // Calculates histograms for (nodeId, featureId) pairs. @@ -313,7 +313,7 @@ List> build( LOG.info("Elapsed time for calcNodeFeaturePairHists: {} ms", elapsed); // Calculates number of elements received by each downstream subtask. 
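// (Note: `nodeFeaturePairs` is now a growable IntArrayList, so the change below
// materializes it via toArray() for calcRecvCounts, which still expects a primitive
// int array.)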
- int[] recvcnts = calcRecvCounts(numSubtasks, nodeFeaturePairs, numFeatureBins); + int[] recvcnts = calcRecvCounts(numSubtasks, nodeFeaturePairs.toArray(), numFeatureBins); List> histograms = new ArrayList<>(); int sliceStart = 0; From ad646c43fcba9df6c271e2089376bc76d8ace9b8 Mon Sep 17 00:00:00 2001 From: Fan Hong Date: Tue, 4 Apr 2023 18:02:18 +0800 Subject: [PATCH 39/47] Change to streaming processing from histogram building to global splitting --- .../operator/SharedStorageWrapper.java | 3 +- .../ml/common/gbt/BoostIterationBody.java | 38 +- .../flink/ml/common/gbt/defs/Histogram.java | 40 +- .../flink/ml/common/gbt/defs/Split.java | 11 + .../flink/ml/common/gbt/defs/Splits.java | 83 ---- .../CacheDataCalcLocalHistsOperator.java | 39 +- .../operators/CalcLocalSplitsOperator.java | 52 +- .../ml/common/gbt/operators/HistBuilder.java | 461 +++++++++--------- .../operators/HistogramAggregateFunction.java | 55 --- .../common/gbt/operators/InstanceUpdater.java | 5 + .../ml/common/gbt/operators/NodeSplitter.java | 5 + .../gbt/operators/PostSplitsOperator.java | 45 +- .../operators/ReduceHistogramFunction.java | 81 +++ .../gbt/operators/ReduceSplitsOperator.java | 138 ++++++ .../ml/common/gbt/operators/SplitFinder.java | 53 +- .../operators/SplitsAggregateFunction.java | 54 -- .../operators/TrainContextInitializer.java | 2 +- .../common/gbt/operators/TreeInitializer.java | 3 +- .../gbt/typeinfo/HistogramSerializer.java | 5 - 19 files changed, 597 insertions(+), 576 deletions(-) delete mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Splits.java delete mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistogramAggregateFunction.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/ReduceHistogramFunction.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/ReduceSplitsOperator.java delete mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SplitsAggregateFunction.java diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/operator/SharedStorageWrapper.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/operator/SharedStorageWrapper.java index 153ba8980..5ce35f699 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/operator/SharedStorageWrapper.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/operator/SharedStorageWrapper.java @@ -20,7 +20,6 @@ import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.functions.KeySelector; -import org.apache.flink.iteration.operator.OperatorUtils; import org.apache.flink.iteration.operator.OperatorWrapper; import org.apache.flink.ml.common.sharedstorage.SharedStorageContext; import org.apache.flink.ml.common.sharedstorage.SharedStorageStreamOperator; @@ -70,7 +69,7 @@ public StreamOperator nowrap( return StreamOperatorFactoryUtil.createOperator( operatorFactory, (StreamTask) parameters.getContainingTask(), - OperatorUtils.createWrappedOperatorConfig(parameters.getStreamConfig()), + parameters.getStreamConfig(), parameters.getOutput(), parameters.getOperatorEventDispatcher()) .f0; diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BoostIterationBody.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BoostIterationBody.java index 29b6c7cb5..86f73e54c 100644 --- 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BoostIterationBody.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BoostIterationBody.java @@ -18,23 +18,23 @@ package org.apache.flink.ml.common.gbt; -import org.apache.flink.api.common.typeinfo.TypeHint; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeinfo.Types; import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; import org.apache.flink.iteration.DataStreamList; import org.apache.flink.iteration.IterationBody; import org.apache.flink.iteration.IterationBodyResult; import org.apache.flink.ml.common.gbt.defs.BoostingStrategy; import org.apache.flink.ml.common.gbt.defs.Histogram; -import org.apache.flink.ml.common.gbt.defs.Splits; +import org.apache.flink.ml.common.gbt.defs.Split; import org.apache.flink.ml.common.gbt.defs.TrainContext; import org.apache.flink.ml.common.gbt.operators.CacheDataCalcLocalHistsOperator; import org.apache.flink.ml.common.gbt.operators.CalcLocalSplitsOperator; -import org.apache.flink.ml.common.gbt.operators.HistogramAggregateFunction; import org.apache.flink.ml.common.gbt.operators.PostSplitsOperator; +import org.apache.flink.ml.common.gbt.operators.ReduceHistogramFunction; +import org.apache.flink.ml.common.gbt.operators.ReduceSplitsOperator; import org.apache.flink.ml.common.gbt.operators.SharedStorageConstants; -import org.apache.flink.ml.common.gbt.operators.SplitsAggregateFunction; import org.apache.flink.ml.common.gbt.operators.TerminationOperator; import org.apache.flink.ml.common.sharedstorage.ItemDescriptor; import org.apache.flink.ml.common.sharedstorage.SharedStorageBody; @@ -69,33 +69,35 @@ private SharedStorageBody.SharedStorageBodyResult sharedStorageBody( Map, String> ownerMap = new HashMap<>(); - // In 1st round, cache all data. For all rounds calculate local histogram based on - // current tree layer. 
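// Sketch of the reworked dataflow in this hunk (tuple field roles inferred from the code):
//   localHists  : Tuple3<Integer, Integer, Histogram> --keyBy(f1)--> ReduceHistogramFunction --> globalHists
//   globalHists --> CalcLocalSplitsOperator --> localSplits : Tuple3<Integer, Integer, Split>
//   localSplits --keyBy(f0)--> ReduceSplitsOperator --> globalSplits : Tuple2<Integer, Split>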
CacheDataCalcLocalHistsOperator cacheDataCalcLocalHistsOp = new CacheDataCalcLocalHistsOperator(strategy); - SingleOutputStreamOperator> localHists = + SingleOutputStreamOperator> localHists = data.connect(trainContext) .transform( "CacheDataCalcLocalHists", - new TypeHint>() {}.getTypeInfo(), + Types.TUPLE( + Types.INT, Types.INT, TypeInformation.of(Histogram.class)), cacheDataCalcLocalHistsOp); for (ItemDescriptor s : SharedStorageConstants.OWNED_BY_CACHE_DATA_CALC_LOCAL_HISTS_OP) { ownerMap.put(s, cacheDataCalcLocalHistsOp.getSharedStorageAccessorID()); } - DataStream globalHists = - localHists - .partitionCustom((key, numPartitions) -> key, value -> value.f0) - .map(d -> d.f1) - .flatMap(new HistogramAggregateFunction()); + DataStream> globalHists = + localHists.keyBy(d -> d.f1).flatMap(new ReduceHistogramFunction()); - SingleOutputStreamOperator localSplits = + SingleOutputStreamOperator> localSplits = globalHists.transform( "CalcLocalSplits", - TypeInformation.of(Splits.class), + Types.TUPLE(Types.INT, Types.INT, TypeInformation.of(Split.class)), new CalcLocalSplitsOperator()); - DataStream globalSplits = - localSplits.broadcast().flatMap(new SplitsAggregateFunction()); + + DataStream> globalSplits = + localSplits + .keyBy(d -> d.f0) + .transform( + "ReduceGlobalSplits", + Types.TUPLE(Types.INT, TypeInformation.of(Split.class)), + new ReduceSplitsOperator()); PostSplitsOperator postSplitsOp = new PostSplitsOperator(); SingleOutputStreamOperator updatedModelData = @@ -118,7 +120,7 @@ private SharedStorageBody.SharedStorageBodyResult sharedStorageBody( return new SharedStorageBody.SharedStorageBodyResult( Arrays.asList(updatedModelData, finalModelData, termination), - Arrays.asList(localHists, localSplits, updatedModelData, termination), + Arrays.asList(localHists, localSplits, globalSplits, updatedModelData, termination), ownerMap); } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Histogram.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Histogram.java index c85e44a03..de6dc42bd 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Histogram.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Histogram.java @@ -18,7 +18,6 @@ package org.apache.flink.ml.common.gbt.defs; -import org.apache.flink.api.common.functions.AggregateFunction; import org.apache.flink.api.common.typeinfo.TypeInfo; import org.apache.flink.ml.common.gbt.typeinfo.HistogramTypeInfoFactory; import org.apache.flink.util.Preconditions; @@ -26,12 +25,13 @@ import java.io.Serializable; /** - * This class stores values of histogram bins, and necessary information of reducing and scattering. + * This class stores values of histogram bins. + * + *
<p>
Note that only the part of {@link Histogram#hists} specified by {@link Histogram#slice} is + * valid. */ @TypeInfo(HistogramTypeInfoFactory.class) public class Histogram implements Serializable { - // Stores source subtask ID. - public int subtaskId; // Stores values of histogram bins. public double[] hists; // Stores the valid slice of `hists`. @@ -39,44 +39,16 @@ public class Histogram implements Serializable { public Histogram() {} - public Histogram(int subtaskId, double[] hists, Slice slice) { - this.subtaskId = subtaskId; + public Histogram(double[] hists, Slice slice) { this.hists = hists; this.slice = slice; } - private Histogram accumulate(Histogram other) { + public Histogram accumulate(Histogram other) { Preconditions.checkArgument(slice.size() == other.slice.size()); for (int i = 0; i < slice.size(); i += 1) { hists[slice.start + i] += other.hists[other.slice.start + i]; } return this; } - - /** Aggregator for Histogram. */ - public static class Aggregator - implements AggregateFunction, Serializable { - @Override - public Histogram createAccumulator() { - return null; - } - - @Override - public Histogram add(Histogram value, Histogram accumulator) { - if (null == accumulator) { - return value; - } - return accumulator.accumulate(value); - } - - @Override - public Histogram getResult(Histogram accumulator) { - return accumulator; - } - - @Override - public Histogram merge(Histogram a, Histogram b) { - return a.accumulate(b); - } - } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Split.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Split.java index a88f5e668..1baeb35f4 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Split.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Split.java @@ -53,6 +53,17 @@ public Split( this.prediction = prediction; } + public Split accumulate(Split other) { + if (gain < other.gain) { + return other; + } else if (gain == other.gain) { + if (featureId < other.featureId) { + return other; + } + } + return this; + } + /** * Test the binned instance should go to the left child or the right child. * diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Splits.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Splits.java deleted file mode 100644 index 72c69e4ca..000000000 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/defs/Splits.java +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.ml.common.gbt.defs; - -import org.apache.flink.api.common.functions.AggregateFunction; - -/** - * This class stores splits of nodes in the current layer, and necessary information of - * all-reducing.. 
- */ -public class Splits { - - // Stores source subtask ID when reducing or target subtask ID when scattering. - public int subtaskId; - // Stores splits of nodes in the current layer. - public Split[] splits; - - public Splits() {} - - public Splits(int subtaskId, Split[] splits) { - this.subtaskId = subtaskId; - this.splits = splits; - } - - private Splits accumulate(Splits other) { - for (int i = 0; i < splits.length; ++i) { - if (splits[i] == null && other.splits[i] != null) { - splits[i] = other.splits[i]; - } else if (splits[i] != null && other.splits[i] != null) { - if (splits[i].gain < other.splits[i].gain) { - splits[i] = other.splits[i]; - } else if (splits[i].gain == other.splits[i].gain) { - if (splits[i].featureId < other.splits[i].featureId) { - splits[i] = other.splits[i]; - } - } - } - } - return this; - } - - /** Aggregator for Splits. */ - public static class Aggregator implements AggregateFunction { - @Override - public Splits createAccumulator() { - return null; - } - - @Override - public Splits add(Splits value, Splits accumulator) { - if (null == accumulator) { - return value; - } - return accumulator.accumulate(value); - } - - @Override - public Splits getResult(Splits accumulator) { - return accumulator; - } - - @Override - public Splits merge(Splits a, Splits b) { - return a.accumulate(b); - } - } -} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java index a19ae1ab9..7b6c79be1 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java @@ -18,7 +18,7 @@ package org.apache.flink.ml.common.gbt.operators; -import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; import org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer; import org.apache.flink.iteration.IterationListener; import org.apache.flink.iteration.datacache.nonkeyed.ListStateWithCache; @@ -51,13 +51,16 @@ import java.util.UUID; /** - * Calculates local histograms for local data partition. Specifically in the first round, this - * operator caches all data instances to JVM static region. + * Calculates local histograms for local data partition. + * + *
<p>
This operator only has input elements in the first round, including data instances and raw + * training context. There will be no input elements in other rounds. The output elements are tuples + * of (subtask index, (nodeId, featureId) pair index, Histogram). */ public class CacheDataCalcLocalHistsOperator - extends AbstractStreamOperator> - implements TwoInputStreamOperator>, - IterationListener>, + extends AbstractStreamOperator> + implements TwoInputStreamOperator>, + IterationListener>, SharedStorageStreamOperator { private static final String TREE_INITIALIZER_STATE_NAME = "tree_initializer"; @@ -154,9 +157,8 @@ public void processElement2(StreamRecord streamRecord) throws Exce setter.set(SharedStorageConstants.TRAIN_CONTEXT, rawTrainContext)); } - @Override public void onEpochWatermarkIncremented( - int epochWatermark, Context context, Collector> out) + int epochWatermark, Context context, Collector> out) throws Exception { if (0 == epochWatermark) { // Initializes local state in first round. @@ -215,6 +217,7 @@ public void onEpochWatermarkIncremented( // When last tree is finished, initializes a new tree, and shuffle instance // indices. treeInitializer.init( + getter.get(SharedStorageConstants.ALL_TREES).size(), d -> setter.set(SharedStorageConstants.SHUFFLED_INDICES, d)); LearningNode rootLearningNode = treeInitializer.getRootLearningNode(); indices = getter.get(SharedStorageConstants.SHUFFLED_INDICES); @@ -229,22 +232,20 @@ public void onEpochWatermarkIncremented( setter.set(SharedStorageConstants.HAS_INITED_TREE, false); } - List> histograms = - histBuilder.build( - layer, - indices, - instances, - pgh, - d -> setter.set(SharedStorageConstants.NODE_FEATURE_PAIRS, d)); - for (Tuple2 t : histograms) { - out.collect(t); - } + histBuilder.build( + layer, + indices, + instances, + pgh, + d -> setter.set(SharedStorageConstants.NODE_FEATURE_PAIRS, d), + out); }); } @Override public void onIterationTerminated( - Context context, Collector> collector) throws Exception { + Context context, Collector> collector) + throws Exception { instancesCollecting.clear(); treeInitializerState.clear(); histBuilderState.clear(); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CalcLocalSplitsOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CalcLocalSplitsOperator.java index 7750c5fa6..dbc77b1c6 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CalcLocalSplitsOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CalcLocalSplitsOperator.java @@ -18,13 +18,14 @@ package org.apache.flink.ml.common.gbt.operators; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; import org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer; -import org.apache.flink.iteration.IterationListener; import org.apache.flink.iteration.datacache.nonkeyed.ListStateWithCache; import org.apache.flink.iteration.operator.OperatorStateUtils; import org.apache.flink.ml.common.gbt.defs.Histogram; import org.apache.flink.ml.common.gbt.defs.LearningNode; -import org.apache.flink.ml.common.gbt.defs.Splits; +import org.apache.flink.ml.common.gbt.defs.Split; import org.apache.flink.ml.common.sharedstorage.SharedStorageContext; import org.apache.flink.ml.common.sharedstorage.SharedStorageStreamOperator; import org.apache.flink.runtime.state.StateInitializationContext; @@ -32,18 +33,26 @@ import org.apache.flink.streaming.api.operators.AbstractStreamOperator; 
import org.apache.flink.streaming.api.operators.OneInputStreamOperator; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; -import org.apache.flink.util.Collector; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.util.Collections; import java.util.List; import java.util.UUID; -/** Calculates local splits for assigned (nodeId, featureId) pairs. */ -public class CalcLocalSplitsOperator extends AbstractStreamOperator - implements OneInputStreamOperator, - IterationListener, +/** + * Calculates best splits from histograms for (nodeId, featureId) pairs. + * + *
<p>
The input elements are tuples of ((nodeId, featureId) pair index, Histogram). The output + * elements are tuples of (node index, (nodeId, featureId) pair index, Split). + */ +public class CalcLocalSplitsOperator extends AbstractStreamOperator> + implements OneInputStreamOperator< + Tuple2, Tuple3>, SharedStorageStreamOperator { + private static final Logger LOG = LoggerFactory.getLogger(CalcLocalSplitsOperator.class); private static final String SPLIT_FINDER_STATE_NAME = "split_finder"; private final String sharedStorageAccessorID; // States of local data. @@ -79,11 +88,7 @@ public void snapshotState(StateSnapshotContext context) throws Exception { } @Override - public void onEpochWatermarkIncremented( - int epochWatermark, Context context, Collector collector) {} - - @Override - public void processElement(StreamRecord element) throws Exception { + public void processElement(StreamRecord> element) throws Exception { if (null == splitFinder) { sharedStorageContext.invoke( (getter, setter) -> { @@ -93,7 +98,10 @@ public void processElement(StreamRecord element) throws Exception { }); } - Histogram histogram = element.getValue(); + Tuple2 value = element.getValue(); + int pairId = value.f0; + Histogram histogram = value.f1; + LOG.debug("Received histogram for pairId: {}", pairId); sharedStorageContext.invoke( (getter, setter) -> { List layer = getter.get(SharedStorageConstants.LAYER); @@ -102,18 +110,26 @@ public void processElement(StreamRecord element) throws Exception { Collections.singletonList( getter.get(SharedStorageConstants.ROOT_LEARNING_NODE)); } - Splits splits = + + int[] nodeFeaturePairs = getter.get(SharedStorageConstants.NODE_FEATURE_PAIRS); + int nodeId = nodeFeaturePairs[2 * pairId]; + int featureId = nodeFeaturePairs[2 * pairId + 1]; + LearningNode node = layer.get(nodeId); + + Split bestSplit = splitFinder.calc( - layer, - getter.get(SharedStorageConstants.NODE_FEATURE_PAIRS), + node, + featureId, getter.get(SharedStorageConstants.LEAVES).size(), histogram); - output.collect(new StreamRecord<>(splits)); + output.collect(new StreamRecord<>(Tuple3.of(nodeId, pairId, bestSplit))); }); + LOG.debug("Output split for pairId: {}", pairId); } @Override - public void onIterationTerminated(Context context, Collector collector) { + public void close() throws Exception { + super.close(); splitFinderState.clear(); } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistBuilder.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistBuilder.java index 8c09810c6..9ccc62211 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistBuilder.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistBuilder.java @@ -18,7 +18,7 @@ package org.apache.flink.ml.common.gbt.operators; -import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; import org.apache.flink.ml.common.gbt.DataUtils; import org.apache.flink.ml.common.gbt.defs.BinnedInstance; import org.apache.flink.ml.common.gbt.defs.FeatureMeta; @@ -26,13 +26,13 @@ import org.apache.flink.ml.common.gbt.defs.LearningNode; import org.apache.flink.ml.common.gbt.defs.Slice; import org.apache.flink.ml.common.gbt.defs.TrainContext; -import org.apache.flink.ml.util.Distributor; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; import org.eclipse.collections.impl.list.mutable.primitive.IntArrayList; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import 
java.util.ArrayList; import java.util.Arrays; import java.util.BitSet; import java.util.List; @@ -46,7 +46,6 @@ class HistBuilder { private static final Logger LOG = LoggerFactory.getLogger(HistBuilder.class); private final int subtaskId; - private final int numSubtasks; private final int numFeatures; private final int[] numFeatureBins; @@ -61,7 +60,6 @@ class HistBuilder { public HistBuilder(TrainContext trainContext) { subtaskId = trainContext.subtaskId; - numSubtasks = trainContext.numSubtasks; numFeatures = trainContext.numFeatures; numFeatureBins = trainContext.numFeatureBins; @@ -75,261 +73,282 @@ public HistBuilder(TrainContext trainContext) { maxDepth = trainContext.strategy.maxDepth; } - /** - * Calculate histograms for all (nodeId, featureId) pairs. The results are written to `hists`, - * so `hists` must be large enough to store values. - */ - private static void calcNodeFeaturePairHists( + /** Calculate local histograms for nodes in current layer of tree. */ + void build( List layer, - int[][] nodeToFeatures, - boolean[] needSplit, - FeatureMeta[] featureMetas, - int[] numFeatureBins, - boolean isInputVector, int[] indices, BinnedInstance[] instances, double[] pgh, - double[] hists) { + Consumer nodeFeaturePairsSetter, + Collector> out) { + LOG.info("subtaskId: {}, {} start", subtaskId, HistBuilder.class.getSimpleName()); int numNodes = layer.size(); - int numFeatures = featureMetas.length; - int[][] nodeToBinOffsets = new int[numNodes][]; - int binOffset = 0; + // Generates (nodeId, featureId) pairs that are required to build histograms. + int[][] nodeToFeatures = new int[numNodes][]; + IntArrayList nodeFeaturePairs = new IntArrayList(numNodes * numBaggingFeatures * 2); for (int k = 0; k < numNodes; k += 1) { - int[] features = nodeToFeatures[k]; - nodeToBinOffsets[k] = new int[features.length]; - for (int i = 0; i < features.length; i += 1) { - nodeToBinOffsets[k][i] = binOffset; - binOffset += numFeatureBins[features[i]]; + LearningNode node = layer.get(k); + if (node.depth == maxDepth) { + // Ignores the results, just to consume the randomizer. + DataUtils.sample(featureIndicesPool, numBaggingFeatures, featureRandomizer); + // No need to calculate histograms for features, only sum of gradients and hessians + // are needed. Uses `numFeatures` to indicate this special "feature". + nodeToFeatures[k] = new int[] {numFeatures}; + } else { + nodeToFeatures[k] = + DataUtils.sample(featureIndicesPool, numBaggingFeatures, featureRandomizer); + Arrays.sort(nodeToFeatures[k]); + } + for (int featureId : nodeToFeatures[k]) { + nodeFeaturePairs.add(k); + nodeFeaturePairs.add(featureId); } } + nodeFeaturePairsSetter.accept(nodeFeaturePairs.toArray()); - int[] featureDefaultVal = new int[numFeatures]; - for (int i = 0; i < numFeatures; i += 1) { - FeatureMeta d = featureMetas[i]; - featureDefaultVal[i] = - isInputVector && d instanceof FeatureMeta.ContinuousFeatureMeta - ? ((FeatureMeta.ContinuousFeatureMeta) d).zeroBin - : d.missingBin; + // Calculates histograms for (nodeId, featureId) pairs. 
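(Annotation, not part of the patch.) A sketch of the histogram layout both the old and new implementations assume, where BIN_SIZE covers the four statistics accumulated per bin by `add(...)`:

    // Each bin occupies BIN_SIZE consecutive doubles, in the order passed to add():
    //   [gradient sum, hessian sum, weight sum, instance count]
    // Bin `b` of a feature whose bins start at `featureOffset` therefore lives at:
    int base = (featureOffset + b) * DataUtils.BIN_SIZE;
    double gradSum = hists[base];
    double hessSum = hists[base + 1];
    // The default (missing/zero) bin is filled last by subtraction
    // (defaultBin = totals - all explicitly filled bins, see calcHistsForDefaultBin),
    // because sparse instances omit default-valued features during the scan.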
+ HistBuilderImpl builderImpl = + new HistBuilderImpl( + layer, + maxDepth, + numFeatures, + numFeatureBins, + nodeToFeatures, + indices, + instances, + pgh); + builderImpl.init(isInputVector, featureMetas); + builderImpl.calcHistsForPairs(subtaskId, out); + + LOG.info("subtaskId: {}, {} end", subtaskId, HistBuilder.class.getSimpleName()); + } + + static class HistBuilderImpl { + private final List layer; + private final int maxDepth; + private final int numFeatures; + private final int[] numFeatureBins; + private final int[][] nodeToFeatures; + private final int[] indices; + private final BinnedInstance[] instances; + private final double[] pgh; + + private int[] featureDefaultVal; + + public HistBuilderImpl( + List layer, + int maxDepth, + int numFeatures, + int[] numFeatureBins, + int[][] nodeToFeatures, + int[] indices, + BinnedInstance[] instances, + double[] pgh) { + this.layer = layer; + this.maxDepth = maxDepth; + this.numFeatures = numFeatures; + this.numFeatureBins = numFeatureBins; + this.nodeToFeatures = nodeToFeatures; + this.indices = indices; + this.instances = instances; + this.pgh = pgh; + Preconditions.checkArgument(numFeatureBins.length == numFeatures + 1); } - int[] featureOffset = new int[numFeatures]; - BitSet featureValid = null; - boolean allFeatureValid; - for (int k = 0; k < numNodes; k += 1) { - long start = System.currentTimeMillis(); - int[] features = nodeToFeatures[k]; - int[] binOffsets = nodeToBinOffsets[k]; - LearningNode node = layer.get(k); + private static void calcHistsForDefaultBin( + int defaultVal, + int featureOffset, + int numBins, + double[] totalHists, + double[] hists, + int nodeOffset) { + int defaultValIndex = (nodeOffset + featureOffset + defaultVal) * BIN_SIZE; + hists[defaultValIndex] = totalHists[0]; + hists[defaultValIndex + 1] = totalHists[1]; + hists[defaultValIndex + 2] = totalHists[2]; + hists[defaultValIndex + 3] = totalHists[3]; + for (int i = 0; i < numBins; i += 1) { + if (i != defaultVal) { + int index = (nodeOffset + featureOffset + i) * BIN_SIZE; + add( + hists, + nodeOffset + featureOffset, + defaultVal, + -hists[index], + -hists[index + 1], + -hists[index + 2], + -hists[index + 3]); + } + } + } - double[] totalHists = new double[4]; + private static void add( + double[] hists, int offset, int val, double d0, double d1, double d2, double d3) { + int index = (offset + val) * BIN_SIZE; + hists[index] += d0; + hists[index + 1] += d1; + hists[index + 2] += d2; + hists[index + 3] += d3; + } + + private void init(boolean isInputVector, FeatureMeta[] featureMetas) { + featureDefaultVal = new int[numFeatures]; + for (int i = 0; i < numFeatures; i += 1) { + FeatureMeta d = featureMetas[i]; + featureDefaultVal[i] = + isInputVector && d instanceof FeatureMeta.ContinuousFeatureMeta + ? 
((FeatureMeta.ContinuousFeatureMeta) d).zeroBin + : d.missingBin; + } + } + + private void calcTotalHists(LearningNode node, double[] totalHists, int offset) { for (int i = node.slice.start; i < node.slice.end; i += 1) { int instanceId = indices[i]; BinnedInstance binnedInstance = instances[instanceId]; double weight = binnedInstance.weight; double gradient = pgh[3 * instanceId + 1]; double hessian = pgh[3 * instanceId + 2]; - - totalHists[0] += gradient; - totalHists[1] += hessian; - totalHists[2] += weight; - totalHists[3] += 1.; + add(totalHists, offset, 0, gradient, hessian, weight, 1.); } + } - if (needSplit[k]) { - if (numFeatures != features.length) { - allFeatureValid = false; - featureValid = new BitSet(numFeatures); - for (int i = 0; i < features.length; i += 1) { - featureValid.set(features[i]); - featureOffset[features[i]] = binOffsets[i]; - } - } else { - allFeatureValid = true; - System.arraycopy(binOffsets, 0, featureOffset, 0, numFeatures); - } + private void calcHistsForNonDefaultBins( + LearningNode node, + boolean allFeatureValid, + BitSet featureValid, + int[] featureOffset, + double[] hists, + int nodeOffset) { + for (int i = node.slice.start; i < node.slice.end; i += 1) { + int instanceId = indices[i]; + BinnedInstance binnedInstance = instances[instanceId]; + double weight = binnedInstance.weight; + double gradient = pgh[3 * instanceId + 1]; + double hessian = pgh[3 * instanceId + 2]; - for (int i = node.slice.start; i < node.slice.end; i += 1) { - int instanceId = indices[i]; - BinnedInstance binnedInstance = instances[instanceId]; - double weight = binnedInstance.weight; - double gradient = pgh[3 * instanceId + 1]; - double hessian = pgh[3 * instanceId + 2]; - - if (null == binnedInstance.featureIds) { - for (int j = 0; j < binnedInstance.featureValues.length; j += 1) { - if (allFeatureValid || featureValid.get(j)) { - add( - hists, - featureOffset[j], - binnedInstance.featureValues[j], - gradient, - hessian, - weight, - 1.); - } - } - } else { - for (int j = 0; j < binnedInstance.featureIds.length; j += 1) { - int featureId = binnedInstance.featureIds[j]; - if (allFeatureValid || featureValid.get(featureId)) { - add( - hists, - featureOffset[featureId], - binnedInstance.featureValues[j], - gradient, - hessian, - weight, - 1.); - } + if (null == binnedInstance.featureIds) { + for (int j = 0; j < binnedInstance.featureValues.length; j += 1) { + if (allFeatureValid || featureValid.get(j)) { + add( + hists, + nodeOffset + featureOffset[j], + binnedInstance.featureValues[j], + gradient, + hessian, + weight, + 1.); } } - } - - for (int featureId : features) { - int defaultVal = featureDefaultVal[featureId]; - int defaultValIndex = (featureOffset[featureId] + defaultVal) * BIN_SIZE; - hists[defaultValIndex] = totalHists[0]; - hists[defaultValIndex + 1] = totalHists[1]; - hists[defaultValIndex + 2] = totalHists[2]; - hists[defaultValIndex + 3] = totalHists[3]; - for (int i = 0; i < numFeatureBins[featureId]; i += 1) { - if (i != defaultVal) { - int index = (featureOffset[featureId] + i) * BIN_SIZE; + } else { + for (int j = 0; j < binnedInstance.featureIds.length; j += 1) { + int featureId = binnedInstance.featureIds[j]; + if (allFeatureValid || featureValid.get(featureId)) { add( hists, - featureOffset[featureId], - defaultVal, - -hists[index], - -hists[index + 1], - -hists[index + 2], - -hists[index + 3]); + nodeOffset + featureOffset[featureId], + binnedInstance.featureValues[j], + gradient, + hessian, + weight, + 1.); } } } - } else { - add( - hists, - binOffsets[0], 
- 0, - totalHists[0], - totalHists[1], - totalHists[2], - totalHists[3]); } - LOG.info( - "STEP 3: node {}, {} #instances, {} #features, {} ms", - k, - node.slice.size(), - features.length, - System.currentTimeMillis() - start); } - } - - private static void add( - double[] hists, int offset, int val, double d0, double d1, double d2, double d3) { - int index = (offset + val) * BIN_SIZE; - hists[index] += d0; - hists[index + 1] += d1; - hists[index + 2] += d2; - hists[index + 3] += d3; - } - /** - * Calculates elements counts of histogram distributed to each downstream subtask. The elements - * counts is bin counts multiplied by STEP. The minimum unit to be distributed is (nodeId, - * featureId), i.e., all bins belonging to the same (nodeId, featureId) pair must go to one - * subtask. - */ - private static int[] calcRecvCounts( - int numSubtasks, int[] nodeFeaturePairs, int[] numFeatureBins) { - int[] recvcnts = new int[numSubtasks]; - Distributor.EvenDistributor distributor = - new Distributor.EvenDistributor(numSubtasks, nodeFeaturePairs.length / 2); - for (int k = 0; k < numSubtasks; k += 1) { - int pairStart = (int) distributor.start(k); - int pairCnt = (int) distributor.count(k); - for (int i = pairStart; i < pairStart + pairCnt; i += 1) { - int featureId = nodeFeaturePairs[2 * i + 1]; - recvcnts[k] += numFeatureBins[featureId] * BIN_SIZE; + private void calcHistsForSplitNode( + LearningNode node, + int[] features, + int[] binOffsets, + double[] hists, + int nodeOffset) { + double[] totalHists = new double[4]; + calcTotalHists(node, totalHists, 0); + + int[] featureOffsets = new int[numFeatures]; + BitSet featureValid = null; + boolean allFeatureValid; + if (numFeatures != features.length) { + allFeatureValid = false; + featureValid = new BitSet(numFeatures); + for (int i = 0; i < features.length; i += 1) { + featureValid.set(features[i]); + featureOffsets[features[i]] = binOffsets[i]; + } + } else { + allFeatureValid = true; + System.arraycopy(binOffsets, 0, featureOffsets, 0, numFeatures); } - } - return recvcnts; - } - /** Calculate local histograms for nodes in current layer of tree. */ - List> build( - List layer, - int[] indices, - BinnedInstance[] instances, - double[] pgh, - Consumer nodeFeaturePairsSetter) { - LOG.info("subtaskId: {}, {} start", subtaskId, HistBuilder.class.getSimpleName()); - int numNodes = layer.size(); + calcHistsForNonDefaultBins( + node, allFeatureValid, featureValid, featureOffsets, hists, nodeOffset); - // Generates (nodeId, featureId) pairs that are required to build histograms. - int[][] nodeToFeatures = new int[numNodes][]; - IntArrayList nodeFeaturePairs = new IntArrayList(numNodes * numBaggingFeatures * 2); - boolean[] needSplit = new boolean[numNodes]; - int numTotalBins = 0; - for (int k = 0; k < numNodes; k += 1) { - LearningNode node = layer.get(k); - if (node.depth == maxDepth) { - needSplit[k] = false; - // Ignores the results, just to consume the randomizer. - DataUtils.sample(featureIndicesPool, numBaggingFeatures, featureRandomizer); - // No need to calculate histograms for features, only sum of gradients and hessians - // are needed. Uses `numFeatures` to indicate this special "feature". 
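(Annotation, not part of the patch.) The `numFeatures` sentinel works because `numFeatureBins` has length `numFeatures + 1` (checked in `HistBuilderImpl`); the extra last slot holds a single-bin pseudo-feature assigned to max-depth nodes, which only need gradient/hessian totals for their leaf predictions. Illustrative values, assuming three real features:

    int numFeatures = 3;
    int[] numFeatureBins = {32, 32, 16, 1};     // last entry: sentinel pseudo-feature
    int[] maxDepthNodeFeatures = {numFeatures}; // its histogram carries totals only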
- nodeToFeatures[k] = new int[] {numFeatures}; - } else { - needSplit[k] = true; - nodeToFeatures[k] = - DataUtils.sample(featureIndicesPool, numBaggingFeatures, featureRandomizer); - Arrays.sort(nodeToFeatures[k]); - } - for (int featureId : nodeToFeatures[k]) { - nodeFeaturePairs.add(k); - nodeFeaturePairs.add(featureId); - numTotalBins += numFeatureBins[featureId]; + for (int featureId : features) { + calcHistsForDefaultBin( + featureDefaultVal[featureId], + featureOffsets[featureId], + numFeatureBins[featureId], + totalHists, + hists, + nodeOffset); } } - nodeFeaturePairsSetter.accept(nodeFeaturePairs.toArray()); - double[] hists = new double[numTotalBins * BIN_SIZE]; - // Calculates histograms for (nodeId, featureId) pairs. - long start = System.currentTimeMillis(); - calcNodeFeaturePairHists( - layer, - nodeToFeatures, - needSplit, - featureMetas, - numFeatureBins, - isInputVector, - indices, - instances, - pgh, - hists); - long elapsed = System.currentTimeMillis() - start; - LOG.info("Elapsed time for calcNodeFeaturePairHists: {} ms", elapsed); - - // Calculates number of elements received by each downstream subtask. - int[] recvcnts = calcRecvCounts(numSubtasks, nodeFeaturePairs.toArray(), numFeatureBins); - - List> histograms = new ArrayList<>(); - int sliceStart = 0; - for (int i = 0; i < recvcnts.length; i += 1) { - int sliceSize = recvcnts[i]; - histograms.add( - Tuple2.of( - i, - new Histogram( + /** Calculate histograms for all (nodeId, featureId) pairs. */ + private void calcHistsForPairs( + int subtaskId, Collector> out) { + long start = System.currentTimeMillis(); + int numNodes = layer.size(); + int offset = 0; + int pairBaseId = 0; + for (int k = 0; k < numNodes; k += 1) { + int[] features = nodeToFeatures[k]; + final int nodeOffset = offset; + int[] binOffsets = new int[features.length]; + for (int i = 0; i < features.length; i += 1) { + binOffsets[i] = offset - nodeOffset; + offset += numFeatureBins[features[i]]; + } + + double[] nodeHists = new double[(offset - nodeOffset) * BIN_SIZE]; + long nodeStart = System.currentTimeMillis(); + LearningNode node = layer.get(k); + if (node.depth != maxDepth) { + calcHistsForSplitNode(node, features, binOffsets, nodeHists, 0); + } else { + calcTotalHists(node, nodeHists, 0); + } + LOG.info( + "subtaskId: {}, node {}, {} #instances, {} #features, {} ms", + subtaskId, + k, + node.slice.size(), + features.length, + System.currentTimeMillis() - nodeStart); + + int sliceStart = 0; + for (int i = 0; i < features.length; i += 1) { + int sliceSize = numFeatureBins[features[i]] * BIN_SIZE; + int pairId = pairBaseId + i; + out.collect( + Tuple3.of( subtaskId, - hists, - new Slice(sliceStart, sliceStart + sliceSize)))); - sliceStart += sliceSize; - } + pairId, + new Histogram( + nodeHists, + new Slice(sliceStart, sliceStart + sliceSize)))); + sliceStart += sliceSize; + } + pairBaseId += features.length; + } - LOG.info("subtaskId: {}, {} end", this.subtaskId, HistBuilder.class.getSimpleName()); - return histograms; + LOG.info( + "subtaskId: {}, elapsed time for calculating histograms: {} ms", + subtaskId, + System.currentTimeMillis() - start); + } } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistogramAggregateFunction.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistogramAggregateFunction.java deleted file mode 100644 index 8fbb869f4..000000000 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/HistogramAggregateFunction.java +++ /dev/null @@ -1,55 
+0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.ml.common.gbt.operators; - -import org.apache.flink.api.common.functions.AggregateFunction; -import org.apache.flink.api.common.functions.RichFlatMapFunction; -import org.apache.flink.ml.common.gbt.defs.Histogram; -import org.apache.flink.util.Collector; -import org.apache.flink.util.Preconditions; - -import java.util.BitSet; - -/** Aggregation function for merging histograms. */ -public class HistogramAggregateFunction extends RichFlatMapFunction { - - private final AggregateFunction aggregator = - new Histogram.Aggregator(); - private int numSubtasks; - private BitSet accepted; - private Histogram acc = null; - - @Override - public void flatMap(Histogram value, Collector out) throws Exception { - if (null == accepted) { - numSubtasks = getRuntimeContext().getNumberOfParallelSubtasks(); - accepted = new BitSet(numSubtasks); - } - int receivedSubtaskId = value.subtaskId; - Preconditions.checkState(!accepted.get(receivedSubtaskId)); - accepted.set(receivedSubtaskId); - acc = aggregator.add(value, acc); - if (numSubtasks == accepted.cardinality()) { - acc.subtaskId = getRuntimeContext().getIndexOfThisSubtask(); - out.collect(acc); - accepted = null; - acc = null; - } - } -} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/InstanceUpdater.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/InstanceUpdater.java index bee71a068..3cb3bd4cd 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/InstanceUpdater.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/InstanceUpdater.java @@ -54,6 +54,7 @@ public void update( Consumer pghSetter, List treeNodes) { LOG.info("subtaskId: {}, {} start", subtaskId, InstanceUpdater.class.getSimpleName()); + long start = System.currentTimeMillis(); if (pgh.length == 0) { pgh = new double[instances.length * 3]; for (int i = 0; i < instances.length; i += 1) { @@ -78,6 +79,10 @@ public void update( } pghSetter.accept(pgh); LOG.info("subtaskId: {}, {} end", subtaskId, InstanceUpdater.class.getSimpleName()); + LOG.info( + "subtaskId: {}, elapsed time for updating instances: {} ms", + subtaskId, + System.currentTimeMillis() - start); } private void updatePgh(int instanceId, double pred, double label, double[] pgh) { diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/NodeSplitter.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/NodeSplitter.java index 66e3b7d51..6e6dd071b 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/NodeSplitter.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/NodeSplitter.java @@ -105,6 
+105,7 @@ public List split( int[] indices, BinnedInstance[] instances) { LOG.info("subtaskId: {}, {} start", subtaskId, NodeSplitter.class.getSimpleName()); + long start = System.currentTimeMillis(); Preconditions.checkState(splits.length == layer.size()); List nextLayer = new ArrayList<>(); @@ -136,6 +137,10 @@ public List split( } } LOG.info("subtaskId: {}, {} end", subtaskId, NodeSplitter.class.getSimpleName()); + LOG.info( + "subtaskId: {}, elapsed time for splitting nodes: {} ms", + subtaskId, + System.currentTimeMillis() - start); return nextLayer; } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/PostSplitsOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/PostSplitsOperator.java index 4c5f04ee5..5ae41968c 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/PostSplitsOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/PostSplitsOperator.java @@ -18,6 +18,7 @@ package org.apache.flink.ml.common.gbt.operators; +import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer; import org.apache.flink.iteration.IterationListener; import org.apache.flink.iteration.datacache.nonkeyed.ListStateWithCache; @@ -25,7 +26,7 @@ import org.apache.flink.ml.common.gbt.defs.BinnedInstance; import org.apache.flink.ml.common.gbt.defs.LearningNode; import org.apache.flink.ml.common.gbt.defs.Node; -import org.apache.flink.ml.common.gbt.defs.Splits; +import org.apache.flink.ml.common.gbt.defs.Split; import org.apache.flink.ml.common.gbt.defs.TrainContext; import org.apache.flink.ml.common.sharedstorage.SharedStorageContext; import org.apache.flink.ml.common.sharedstorage.SharedStorageStreamOperator; @@ -36,6 +37,9 @@ import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.util.Collector; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -46,19 +50,19 @@ * update instances scores after a tree is complete. */ public class PostSplitsOperator extends AbstractStreamOperator - implements OneInputStreamOperator, + implements OneInputStreamOperator, Integer>, IterationListener, SharedStorageStreamOperator { - private static final String SPLITS_STATE_NAME = "splits"; private static final String NODE_SPLITTER_STATE_NAME = "node_splitter"; private static final String INSTANCE_UPDATER_STATE_NAME = "instance_updater"; + private static final Logger LOG = LoggerFactory.getLogger(PostSplitsOperator.class); + private final String sharedStorageAccessorID; // States of local data. 
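(Annotation, not part of the patch.) With splits now arriving one node at a time as (nodeId, Split) tuples, this operator drops the checkpointed `Splits` state and instead buffers the best split per node in a transient array, sized lazily from the current layer as done in `processElement` below:

    int numNodes = (layer.size() == 0) ? 1 : layer.size(); // first layer: root only
    nodeSplits = new Split[numNodes];
    // ... later, per incoming element:
    nodeSplits[nodeId] = split; // consumed and reset in onEpochWatermarkIncremented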
- private transient ListStateWithCache splitsState; - private transient Splits splits; + private transient Split[] nodeSplits; private transient ListStateWithCache nodeSplitterState; private transient NodeSplitter nodeSplitter; private transient ListStateWithCache instanceUpdaterState; @@ -73,14 +77,6 @@ public PostSplitsOperator() { public void initializeState(StateInitializationContext context) throws Exception { super.initializeState(context); - splitsState = - new ListStateWithCache<>( - new KryoSerializer<>(Splits.class, getExecutionConfig()), - getContainingTask(), - getRuntimeContext(), - context, - getOperatorID()); - splits = OperatorStateUtils.getUniqueElement(splitsState, SPLITS_STATE_NAME).orElse(null); nodeSplitterState = new ListStateWithCache<>( new KryoSerializer<>(NodeSplitter.class, getExecutionConfig()), @@ -109,8 +105,6 @@ public void initializeState(StateInitializationContext context) throws Exception @Override public void snapshotState(StateSnapshotContext context) throws Exception { super.snapshotState(context); - splitsState.update(Collections.singletonList(splits)); - splitsState.snapshotState(context); nodeSplitterState.snapshotState(context); instanceUpdaterState.snapshotState(context); sharedStorageContext.snapshotState(context); @@ -157,9 +151,10 @@ public void onEpochWatermarkIncremented( currentTreeNodes, layer, leaves, - splits.splits, + nodeSplits, indices, instances); + nodeSplits = null; setter.set(SharedStorageConstants.LEAVES, leaves); setter.set(SharedStorageConstants.LAYER, nextLayer); setter.set(SharedStorageConstants.CURRENT_TREE_NODES, currentTreeNodes); @@ -181,6 +176,7 @@ public void onEpochWatermarkIncremented( setter.set(SharedStorageConstants.LEAVES, new ArrayList<>()); setter.set(SharedStorageConstants.SWAPPED_INDICES, new int[0]); setter.set(SharedStorageConstants.ALL_TREES, allTrees); + LOG.info("finalize {}-th tree", allTrees.size()); } else { setter.set(SharedStorageConstants.SWAPPED_INDICES, indices); setter.set(SharedStorageConstants.NEED_INIT_TREE, false); @@ -202,13 +198,24 @@ public void onIterationTerminated(Context context, Collector collector) } @Override - public void processElement(StreamRecord element) throws Exception { - splits = element.getValue(); + public void processElement(StreamRecord> element) throws Exception { + if (null == nodeSplits) { + sharedStorageContext.invoke( + (getter, setter) -> { + List layer = getter.get(SharedStorageConstants.LAYER); + int numNodes = (layer.size() == 0) ? 1 : layer.size(); + nodeSplits = new Split[numNodes]; + }); + } + Tuple2 value = element.getValue(); + int nodeId = value.f0; + Split split = value.f1; + LOG.debug("Received split for node {}", nodeId); + nodeSplits[nodeId] = split; } @Override public void close() throws Exception { - splitsState.clear(); nodeSplitterState.clear(); instanceUpdaterState.clear(); sharedStorageContext.clear(); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/ReduceHistogramFunction.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/ReduceHistogramFunction.java new file mode 100644 index 000000000..c3a0617bf --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/ReduceHistogramFunction.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.gbt.operators; + +import org.apache.flink.api.common.functions.RichFlatMapFunction; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.common.gbt.defs.Histogram; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.BitSet; +import java.util.HashMap; +import java.util.Map; + +/** + * This operator reduces histograms for (nodeId, featureId) pairs. + * + *
<p>
The input elements are tuples of (subtask index, (nodeId, featureId) pair index, Histogram). + * The output elements are tuples of ((nodeId, featureId) pair index, Histogram). + */ +public class ReduceHistogramFunction + extends RichFlatMapFunction< + Tuple3, Tuple2> { + + private static final Logger LOG = LoggerFactory.getLogger(ReduceHistogramFunction.class); + + private final Map pairAccepted = new HashMap<>(); + private final Map pairHistogram = new HashMap<>(); + private int numSubtasks; + + @Override + public void open(Configuration parameters) throws Exception { + numSubtasks = getRuntimeContext().getNumberOfParallelSubtasks(); + } + + @Override + public void flatMap( + Tuple3 value, Collector> out) + throws Exception { + int sourceSubtaskId = value.f0; + int pairId = value.f1; + Histogram histogram = value.f2; + + BitSet accepted = pairAccepted.getOrDefault(pairId, new BitSet(numSubtasks)); + if (accepted.isEmpty()) { + LOG.debug("Received histogram for new pair {}", pairId); + } + Preconditions.checkState(!accepted.get(sourceSubtaskId)); + accepted.set(sourceSubtaskId); + pairAccepted.put(pairId, accepted); + + pairHistogram.compute(pairId, (k, v) -> null == v ? histogram : v.accumulate(histogram)); + if (numSubtasks == accepted.cardinality()) { + out.collect(Tuple2.of(pairId, pairHistogram.get(pairId))); + LOG.debug("Output accumulated histogram for pair {}", pairId); + pairAccepted.remove(pairId); + pairHistogram.remove(pairId); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/ReduceSplitsOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/ReduceSplitsOperator.java new file mode 100644 index 000000000..d0e5839f6 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/ReduceSplitsOperator.java @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.gbt.operators; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.ml.common.gbt.defs.Split; +import org.apache.flink.ml.common.sharedstorage.SharedStorageContext; +import org.apache.flink.ml.common.sharedstorage.SharedStorageStreamOperator; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.util.Preconditions; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.BitSet; +import java.util.HashMap; +import java.util.Map; +import java.util.UUID; + +/** + * Reduces best splits for nodes. + * + *
<p>
The input elements are tuples of (node index, (nodeId, featureId) pair index, Split). The + * output elements are tuples of (node index, Split). + */ +public class ReduceSplitsOperator extends AbstractStreamOperator> + implements OneInputStreamOperator, Tuple2>, + SharedStorageStreamOperator { + + private static final Logger LOG = LoggerFactory.getLogger(ReduceSplitsOperator.class); + + private final String sharedStorageAccessorID; + + private transient SharedStorageContext sharedStorageContext; + + private Map nodeFeatureMap; + private Map nodeBestSplit; + private Map nodeFeatureCounter; + + public ReduceSplitsOperator() { + sharedStorageAccessorID = getClass().getSimpleName() + "-" + UUID.randomUUID(); + } + + @Override + public void onSharedStorageContextSet(SharedStorageContext context) { + sharedStorageContext = context; + } + + @Override + public String getSharedStorageAccessorID() { + return sharedStorageAccessorID; + } + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + sharedStorageContext.initializeState(this, getRuntimeContext(), context); + nodeFeatureMap = new HashMap<>(); + nodeBestSplit = new HashMap<>(); + nodeFeatureCounter = new HashMap<>(); + } + + @Override + public void snapshotState(StateSnapshotContext context) throws Exception { + sharedStorageContext.snapshotState(context); + } + + @Override + public void processElement(StreamRecord> element) + throws Exception { + if (nodeFeatureMap.isEmpty()) { + Preconditions.checkState(nodeBestSplit.isEmpty()); + nodeFeatureCounter.clear(); + sharedStorageContext.invoke( + (getter, setter) -> { + int[] nodeFeaturePairs = + getter.get(SharedStorageConstants.NODE_FEATURE_PAIRS); + for (int i = 0; i < nodeFeaturePairs.length / 2; i += 1) { + int nodeId = nodeFeaturePairs[2 * i]; + nodeFeatureCounter.compute(nodeId, (k, v) -> null == v ? 1 : v + 1); + } + }); + } + + Tuple3 value = element.getValue(); + int nodeId = value.f0; + int pairId = value.f1; + Split split = value.f2; + BitSet featureMap = nodeFeatureMap.getOrDefault(nodeId, new BitSet()); + if (featureMap.isEmpty()) { + LOG.debug("Received split for new node {}", nodeId); + } + sharedStorageContext.invoke( + (getter, setter) -> { + int[] nodeFeaturePairs = getter.get(SharedStorageConstants.NODE_FEATURE_PAIRS); + Preconditions.checkState(nodeId == nodeFeaturePairs[pairId * 2]); + int featureId = nodeFeaturePairs[pairId * 2 + 1]; + Preconditions.checkState(!featureMap.get(featureId)); + featureMap.set(featureId); + }); + nodeFeatureMap.put(nodeId, featureMap); + + nodeBestSplit.compute(nodeId, (k, v) -> null == v ? 
split : v.accumulate(split)); + if (featureMap.cardinality() == nodeFeatureCounter.get(nodeId)) { + output.collect(new StreamRecord<>(Tuple2.of(nodeId, nodeBestSplit.get(nodeId)))); + LOG.debug("Output accumulated split for node {}", nodeId); + nodeBestSplit.remove(nodeId); + nodeFeatureMap.remove(nodeId); + nodeFeatureCounter.remove(nodeId); + } + } + + @Override + public void close() throws Exception { + sharedStorageContext.clear(); + super.close(); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SplitFinder.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SplitFinder.java index 767973925..458ac0b9c 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SplitFinder.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SplitFinder.java @@ -18,39 +18,24 @@ package org.apache.flink.ml.common.gbt.operators; +import org.apache.flink.ml.common.gbt.DataUtils; import org.apache.flink.ml.common.gbt.defs.FeatureMeta; import org.apache.flink.ml.common.gbt.defs.Histogram; import org.apache.flink.ml.common.gbt.defs.LearningNode; import org.apache.flink.ml.common.gbt.defs.Slice; import org.apache.flink.ml.common.gbt.defs.Split; -import org.apache.flink.ml.common.gbt.defs.Splits; import org.apache.flink.ml.common.gbt.defs.TrainContext; import org.apache.flink.ml.common.gbt.splitter.CategoricalFeatureSplitter; import org.apache.flink.ml.common.gbt.splitter.ContinuousFeatureSplitter; import org.apache.flink.ml.common.gbt.splitter.HistogramFeatureSplitter; -import org.apache.flink.ml.util.Distributor; import org.apache.flink.util.Preconditions; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.List; - class SplitFinder { - private static final Logger LOG = LoggerFactory.getLogger(SplitFinder.class); - - private final int subtaskId; - private final int numSubtasks; - private final int[] numFeatureBins; private final HistogramFeatureSplitter[] splitters; private final int maxDepth; private final int maxNumLeaves; public SplitFinder(TrainContext trainContext) { - subtaskId = trainContext.subtaskId; - numSubtasks = trainContext.numSubtasks; - - numFeatureBins = trainContext.numFeatureBins; FeatureMeta[] featureMetas = trainContext.featureMetas; int numFeatures = trainContext.numFeatures; splitters = new HistogramFeatureSplitter[numFeatures + 1]; @@ -72,35 +57,11 @@ public SplitFinder(TrainContext trainContext) { maxNumLeaves = trainContext.strategy.maxNumLeaves; } - public Splits calc( - List layer, int[] nodeFeaturePairs, int numLeaves, Histogram histogram) { - LOG.info("subtaskId: {}, {} start", subtaskId, SplitFinder.class.getSimpleName()); - - Distributor distributor = - new Distributor.EvenDistributor(numSubtasks, nodeFeaturePairs.length / 2); - int start = (int) distributor.start(subtaskId); - int cnt = (int) distributor.count(subtaskId); - - Split[] nodesBestSplits = new Split[layer.size()]; - int binOffset = 0; - for (int i = start; i < start + cnt; i += 1) { - int nodeId = nodeFeaturePairs[2 * i]; - int featureId = nodeFeaturePairs[2 * i + 1]; - LearningNode node = layer.get(nodeId); - - Preconditions.checkState(node.depth < maxDepth || numLeaves + 2 <= maxNumLeaves); - Preconditions.checkState(histogram.slice.start == 0); - splitters[featureId].reset( - histogram.hists, new Slice(binOffset, binOffset + numFeatureBins[featureId])); - Split bestSplit = splitters[featureId].bestSplit(); - if (null == nodesBestSplits[nodeId] - || (bestSplit.gain > 
nodesBestSplits[nodeId].gain)) { - nodesBestSplits[nodeId] = bestSplit; - } - binOffset += numFeatureBins[featureId]; - } - - LOG.info("subtaskId: {}, {} end", subtaskId, SplitFinder.class.getSimpleName()); - return new Splits(subtaskId, nodesBestSplits); + public Split calc(LearningNode node, int featureId, int numLeaves, Histogram histogram) { + Preconditions.checkState(node.depth < maxDepth || numLeaves + 2 <= maxNumLeaves); + Preconditions.checkState(histogram.slice.start == 0); + splitters[featureId].reset( + histogram.hists, new Slice(0, histogram.hists.length / DataUtils.BIN_SIZE)); + return splitters[featureId].bestSplit(); } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SplitsAggregateFunction.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SplitsAggregateFunction.java deleted file mode 100644 index 8b1c0cee4..000000000 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SplitsAggregateFunction.java +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.ml.common.gbt.operators; - -import org.apache.flink.api.common.functions.AggregateFunction; -import org.apache.flink.api.common.functions.RichFlatMapFunction; -import org.apache.flink.ml.common.gbt.defs.Splits; -import org.apache.flink.util.Collector; -import org.apache.flink.util.Preconditions; - -import java.util.BitSet; - -/** Aggregation function for merging splits. 
*/ -public class SplitsAggregateFunction extends RichFlatMapFunction { - - private final AggregateFunction aggregator = new Splits.Aggregator(); - private int numSubtasks; - private BitSet accepted; - private Splits acc = null; - - @Override - public void flatMap(Splits value, Collector out) throws Exception { - if (null == accepted) { - numSubtasks = getRuntimeContext().getNumberOfParallelSubtasks(); - accepted = new BitSet(numSubtasks); - } - int receivedSubtaskId = value.subtaskId; - Preconditions.checkState(!accepted.get(receivedSubtaskId)); - accepted.set(receivedSubtaskId); - acc = aggregator.add(value, acc); - if (numSubtasks == accepted.cardinality()) { - acc.subtaskId = getRuntimeContext().getIndexOfThisSubtask(); - out.collect(acc); - accepted = null; - acc = null; - } - } -} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TrainContextInitializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TrainContextInitializer.java index 6b0226f6e..7c124b68e 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TrainContextInitializer.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TrainContextInitializer.java @@ -86,7 +86,7 @@ public TrainContext init( .mapToInt(d -> d.numBins(trainContext.strategy.useMissing)) .toArray(), 1); - LOG.info("Number of bins for each feature: {}", trainContext.numFeatureBins); + LOG.debug("Number of bins for each feature: {}", trainContext.numFeatureBins); LOG.info("subtaskId: {}, {} end", subtaskId, TrainContextInitializer.class.getSimpleName()); return trainContext; } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TreeInitializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TreeInitializer.java index e63e1da9e..4cb9090a1 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TreeInitializer.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TreeInitializer.java @@ -48,11 +48,12 @@ public TreeInitializer(TrainContext trainContext) { } /** Calculate local histograms for nodes in current layer of tree. */ - public void init(Consumer shuffledIndicesSetter) { + public void init(int numTrees, Consumer shuffledIndicesSetter) { LOG.info("subtaskId: {}, {} start", subtaskId, TreeInitializer.class.getSimpleName()); // Initializes the root node of a new tree when last tree is finalized. 
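(Annotation, not part of the patch.) `init` now also receives the number of trees finished so far, used only for the progress log below; the matching call site added in `CacheDataCalcLocalHistsOperator` in this patch is:

    treeInitializer.init(
            getter.get(SharedStorageConstants.ALL_TREES).size(),
            d -> setter.set(SharedStorageConstants.SHUFFLED_INDICES, d));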
DataUtils.shuffle(shuffledIndices, instanceRandomizer); shuffledIndicesSetter.accept(shuffledIndices); + LOG.info("subtaskId: {}, initialize {}-th tree", subtaskId, numTrees + 1); LOG.info("subtaskId: {}, {} end", this.subtaskId, TreeInitializer.class.getSimpleName()); } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/HistogramSerializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/HistogramSerializer.java index 2b151c87b..a4bb4e28e 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/HistogramSerializer.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/HistogramSerializer.java @@ -20,7 +20,6 @@ import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot; import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; -import org.apache.flink.api.common.typeutils.base.IntSerializer; import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton; import org.apache.flink.core.memory.DataInputView; import org.apache.flink.core.memory.DataOutputView; @@ -54,7 +53,6 @@ public Histogram createInstance() { @Override public Histogram copy(Histogram from) { Histogram histogram = new Histogram(); - histogram.subtaskId = from.subtaskId; histogram.hists = ArrayUtils.subarray(from.hists, from.slice.start, from.slice.end); histogram.slice.start = 0; histogram.slice.end = from.slice.size(); @@ -73,7 +71,6 @@ public int getLength() { @Override public void serialize(Histogram record, DataOutputView target) throws IOException { - target.writeInt(record.subtaskId); // Only writes valid slice of `hists`. histsSerializer.serialize(record.hists, record.slice.start, record.slice.size(), target); } @@ -81,7 +78,6 @@ public void serialize(Histogram record, DataOutputView target) throws IOExceptio @Override public Histogram deserialize(DataInputView source) throws IOException { Histogram histogram = new Histogram(); - histogram.subtaskId = IntSerializer.INSTANCE.deserialize(source); histogram.hists = histsSerializer.deserialize(source); histogram.slice = new Slice(0, histogram.hists.length); return histogram; @@ -89,7 +85,6 @@ public Histogram deserialize(DataInputView source) throws IOException { @Override public Histogram deserialize(Histogram reuse, DataInputView source) throws IOException { - reuse.subtaskId = IntSerializer.INSTANCE.deserialize(source); reuse.hists = histsSerializer.deserialize(reuse.hists, source); reuse.slice.start = 0; reuse.slice.end = reuse.hists.length; From 513ed960a4212cfba06908d12fc93d450a586590 Mon Sep 17 00:00:00 2001 From: Fan Hong Date: Mon, 13 Mar 2023 16:19:09 +0800 Subject: [PATCH 40/47] Simplify APIs in SharedStorageContext --- .../AbstractSharedStorageWrapperOperator.java | 18 +- .../OneInputSharedStorageWrapperOperator.java | 5 +- .../common/sharedstorage/SharedStorage.java | 31 +++- .../sharedstorage/SharedStorageBody.java | 6 +- .../sharedstorage/SharedStorageContext.java | 27 --- .../SharedStorageContextImpl.java | 24 +-- .../sharedstorage/SharedStorageUtils.java | 10 +- .../{operator => }/SharedStorageWrapper.java | 10 +- .../TwoInputSharedStorageWrapperOperator.java | 5 +- .../sharedstorage/SharedStorageUtilsTest.java | 163 ++++++++++++++++++ .../ml/common/gbt/BoostIterationBody.java | 7 +- .../CacheDataCalcLocalHistsOperator.java | 4 - .../operators/CalcLocalSplitsOperator.java | 2 - .../gbt/operators/PostSplitsOperator.java | 4 - .../gbt/operators/ReduceSplitsOperator.java | 8 - .../gbt/operators/TerminationOperator.java | 1 
- 16 files changed, 241 insertions(+), 84 deletions(-) rename flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/{operator => }/AbstractSharedStorageWrapperOperator.java (95%) rename flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/{operator => }/OneInputSharedStorageWrapperOperator.java (94%) rename flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/{operator => }/SharedStorageWrapper.java (92%) rename flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/{operator => }/TwoInputSharedStorageWrapperOperator.java (95%) create mode 100644 flink-ml-core/src/test/java/org/apache/flink/ml/common/sharedstorage/SharedStorageUtilsTest.java diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/operator/AbstractSharedStorageWrapperOperator.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/AbstractSharedStorageWrapperOperator.java similarity index 95% rename from flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/operator/AbstractSharedStorageWrapperOperator.java rename to flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/AbstractSharedStorageWrapperOperator.java index 274fd9a16..0bccef8b7 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/operator/AbstractSharedStorageWrapperOperator.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/AbstractSharedStorageWrapperOperator.java @@ -16,15 +16,13 @@ * limitations under the License. */ -package org.apache.flink.ml.common.sharedstorage.operator; +package org.apache.flink.ml.common.sharedstorage; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.core.memory.ManagedMemoryUseCase; import org.apache.flink.iteration.IterationListener; import org.apache.flink.iteration.proxy.state.ProxyStreamOperatorStateContext; import org.apache.flink.metrics.groups.OperatorMetricGroup; -import org.apache.flink.ml.common.sharedstorage.SharedStorageContext; -import org.apache.flink.ml.common.sharedstorage.SharedStorageStreamOperator; import org.apache.flink.runtime.checkpoint.CheckpointOptions; import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.jobgraph.OperatorID; @@ -34,6 +32,7 @@ import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.StateSnapshotContext; import org.apache.flink.streaming.api.graph.StreamConfig; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; import org.apache.flink.streaming.api.operators.InternalTimeServiceManager; import org.apache.flink.streaming.api.operators.OperatorSnapshotFutures; import org.apache.flink.streaming.api.operators.Output; @@ -73,6 +72,7 @@ abstract class AbstractSharedStorageWrapperOperator> output; protected final StreamOperatorFactory operatorFactory; + private final SharedStorageContextImpl context; protected final OperatorMetricGroup metrics; protected final S wrappedOperator; protected transient StreamOperatorStateHandler stateHandler; @@ -83,12 +83,13 @@ abstract class AbstractSharedStorageWrapperOperator parameters, StreamOperatorFactory operatorFactory, - SharedStorageContext context) { + SharedStorageContextImpl context) { this.parameters = Objects.requireNonNull(parameters); this.streamConfig = Objects.requireNonNull(parameters.getStreamConfig()); this.containingTask = Objects.requireNonNull(parameters.getContainingTask()); this.output = 
Objects.requireNonNull(parameters.getOutput()); this.operatorFactory = Objects.requireNonNull(operatorFactory); + this.context = context; this.metrics = createOperatorMetricGroup(containingTask.getEnvironment(), streamConfig); this.wrappedOperator = (S) @@ -134,6 +135,7 @@ public void open() throws Exception { @Override public void close() throws Exception { wrappedOperator.close(); + context.clear(); } @Override @@ -148,10 +150,16 @@ public void prepareSnapshotPreBarrier(long checkpointId) throws Exception { @Override public void initializeState(StateInitializationContext stateInitializationContext) - throws Exception {} + throws Exception { + context.initializeState( + wrappedOperator, + ((AbstractStreamOperator) wrappedOperator).getRuntimeContext(), + stateInitializationContext); + } @Override public void snapshotState(StateSnapshotContext stateSnapshotContext) throws Exception { + context.snapshotState(stateSnapshotContext); if (wrappedOperator instanceof StreamOperatorStateHandler.CheckpointedStreamOperator) { ((CheckpointedStreamOperator) wrappedOperator).snapshotState(stateSnapshotContext); } diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/operator/OneInputSharedStorageWrapperOperator.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/OneInputSharedStorageWrapperOperator.java similarity index 94% rename from flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/operator/OneInputSharedStorageWrapperOperator.java rename to flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/OneInputSharedStorageWrapperOperator.java index 6e4bc0cd4..6b427f03a 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/operator/OneInputSharedStorageWrapperOperator.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/OneInputSharedStorageWrapperOperator.java @@ -16,10 +16,9 @@ * limitations under the License. */ -package org.apache.flink.ml.common.sharedstorage.operator; +package org.apache.flink.ml.common.sharedstorage; import org.apache.flink.iteration.operator.OperatorUtils; -import org.apache.flink.ml.common.sharedstorage.SharedStorageContext; import org.apache.flink.streaming.api.operators.BoundedOneInput; import org.apache.flink.streaming.api.operators.OneInputStreamOperator; import org.apache.flink.streaming.api.operators.StreamOperatorFactory; @@ -37,7 +36,7 @@ class OneInputSharedStorageWrapperOperator OneInputSharedStorageWrapperOperator( StreamOperatorParameters parameters, StreamOperatorFactory operatorFactory, - SharedStorageContext context) { + SharedStorageContextImpl context) { super(parameters, operatorFactory, context); } diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorage.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorage.java index 747bcd6cf..c0ca41594 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorage.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorage.java @@ -41,10 +41,30 @@ class SharedStorage { private static final Map, String> owners = new ConcurrentHashMap<>(); + private static final ConcurrentHashMap, Integer> + numItemRefs = new ConcurrentHashMap<>(); + + static int incRef(Tuple3 t) { + return numItemRefs.compute(t, (k, oldV) -> null == oldV ? 
1 : oldV + 1); + } + + static int decRef(Tuple3 t) { + int numRefs = numItemRefs.compute(t, (k, oldV) -> oldV - 1); + if (numRefs == 0) { + m.remove(t); + owners.remove(t); + numItemRefs.remove(t); + } + return numRefs; + } + /** Gets a {@link Reader} of shared item identified by (storageID, subtaskId, descriptor). */ static Reader getReader( StorageID storageID, int subtaskId, ItemDescriptor descriptor) { - return new Reader<>(Tuple3.of(storageID, subtaskId, descriptor.key)); + Tuple3 t = Tuple3.of(storageID, subtaskId, descriptor.key); + Reader reader = new Reader<>(t); + incRef(t); + return reader; } /** Gets a {@link Writer} of shared item identified by (storageID, subtaskId, key). */ @@ -75,6 +95,7 @@ static Writer getWriter( stateInitializationContext, operatorID); writer.set(descriptor.initVal); + incRef(t); return writer; } @@ -105,6 +126,10 @@ T get() { throw new IllegalStateException( String.format("Failed to get value of %s after waiting %d ms.", t, waitTime)); } + + void remove() { + decRef(t); + } } static class Writer extends Reader { @@ -154,10 +179,10 @@ void set(T value) { isDirty = true; } + @Override void remove() { ensureOwner(); - m.remove(t); - owners.remove(t); + super.remove(); cache.clear(); } diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageBody.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageBody.java index a7a59b28c..e9a4be7bf 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageBody.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageBody.java @@ -66,12 +66,12 @@ class SharedStorageBodyResult { * {@link SharedStorageStreamOperator#getSharedStorageAccessorID()}, which must be kept * unchanged for an instance of {@link SharedStorageStreamOperator}. 
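The reference counting introduced above leans on ConcurrentHashMap#compute so that increments, decrements, and the final cleanup are atomic per key. A stripped-down sketch of the same pattern (a toy generic pool, not the patch's classes):

    import java.util.concurrent.ConcurrentHashMap;

    class RefCountedPool<K> {
        private final ConcurrentHashMap<K, Object> values = new ConcurrentHashMap<>();
        private final ConcurrentHashMap<K, Integer> refs = new ConcurrentHashMap<>();

        int incRef(K key) {
            // The first reader/writer of `key` starts the count at 1.
            return refs.compute(key, (k, old) -> old == null ? 1 : old + 1);
        }

        int decRef(K key) {
            // The last accessor to release the key also removes the value itself.
            int remaining = refs.compute(key, (k, old) -> old - 1);
            if (remaining == 0) {
                values.remove(key);
                refs.remove(key);
            }
            return remaining;
        }
    }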
*/ - private final Map, String> ownerMap; + private final Map, SharedStorageStreamOperator> ownerMap; public SharedStorageBodyResult( List> outputs, List> accessors, - Map, String> ownerMap) { + Map, SharedStorageStreamOperator> ownerMap) { this.outputs = outputs; this.accessors = accessors; this.ownerMap = ownerMap; @@ -85,7 +85,7 @@ public List> getAccessors() { return accessors; } - public Map, String> getOwnerMap() { + public Map, SharedStorageStreamOperator> getOwnerMap() { return ownerMap; } } diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageContext.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageContext.java index 534051332..6117c6623 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageContext.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageContext.java @@ -19,12 +19,6 @@ package org.apache.flink.ml.common.sharedstorage; import org.apache.flink.annotation.Experimental; -import org.apache.flink.runtime.state.StateInitializationContext; -import org.apache.flink.runtime.state.StateSnapshotContext; -import org.apache.flink.streaming.api.operators.AbstractStreamOperator; -import org.apache.flink.streaming.api.operators.StreamOperator; -import org.apache.flink.streaming.api.operators.StreamOperatorStateHandler; -import org.apache.flink.streaming.api.operators.StreamingRuntimeContext; import org.apache.flink.util.function.BiConsumerWithException; /** @@ -32,17 +26,6 @@ * have an instance of this context set by {@link * SharedStorageStreamOperator#onSharedStorageContextSet} in runtime. User defined logic can be * invoked through {@link #invoke} with the access to shared items. - * - *
<p>NOTE: The corresponding operator must explicitly invoke
- *
- * <ul>
- *   <li>{@link #initializeState} to initialize this context and possibly restore data items owned
- *       by itself in {@link StreamOperatorStateHandler.CheckpointedStreamOperator#initializeState};
- *   <li>{@link #snapshotState} in order to save data items owned by itself in {@link
- *       StreamOperatorStateHandler.CheckpointedStreamOperator#snapshotState};
- *   <li>{@link #clear()} in order to clear all data items owned by itself in {@link
- *       StreamOperator#close}.
- * </ul>
*/ @Experimental public interface SharedStorageContext { @@ -56,16 +39,6 @@ public interface SharedStorageContext { void invoke(BiConsumerWithException func) throws Exception; - /** Initializes shared storage context and restores of shared items owned by this operator. */ - & SharedStorageStreamOperator> void initializeState( - T operator, StreamingRuntimeContext runtimeContext, StateInitializationContext context); - - /** Save shared items owned by this operator. */ - void snapshotState(StateSnapshotContext context) throws Exception; - - /** Clear all internal states. */ - void clear(); - /** Interface of shared item getter. */ @FunctionalInterface interface SharedItemGetter { diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageContextImpl.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageContextImpl.java index 434f78ad3..25e08d573 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageContextImpl.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageContextImpl.java @@ -21,6 +21,7 @@ import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.StateSnapshotContext; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.StreamOperator; import org.apache.flink.streaming.api.operators.StreamingRuntimeContext; import org.apache.flink.util.Preconditions; import org.apache.flink.util.function.BiConsumerWithException; @@ -41,7 +42,7 @@ public SharedStorageContextImpl() { this.storageID = new StorageID(); } - public void setOwnerMap(Map, String> ownerMap) { + void setOwnerMap(Map, String> ownerMap) { this.ownerMap = ownerMap; } @@ -73,12 +74,14 @@ private void setSharedItem(ItemDescriptor key, T value) { writer.set(value); } - @Override - public & SharedStorageStreamOperator> void initializeState( - T operator, + void initializeState( + StreamOperator operator, StreamingRuntimeContext runtimeContext, StateInitializationContext context) { - String ownerId = operator.getSharedStorageAccessorID(); + Preconditions.checkArgument( + operator instanceof SharedStorageStreamOperator + && operator instanceof AbstractStreamOperator); + String ownerId = ((SharedStorageStreamOperator) operator).getSharedStorageAccessorID(); int subtaskId = runtimeContext.getIndexOfThisSubtask(); for (Map.Entry, String> entry : ownerMap.entrySet()) { ItemDescriptor descriptor = entry.getKey(); @@ -91,7 +94,7 @@ public & SharedStorageStreamOperator> void descriptor, ownerId, operator.getOperatorID(), - operator.getContainingTask(), + ((AbstractStreamOperator) operator).getContainingTask(), runtimeContext, context)); } @@ -99,18 +102,19 @@ public & SharedStorageStreamOperator> void } } - @Override - public void snapshotState(StateSnapshotContext context) throws Exception { + void snapshotState(StateSnapshotContext context) throws Exception { for (SharedStorage.Writer writer : writers.values()) { writer.snapshotState(context); } } - @Override - public void clear() { + void clear() { for (SharedStorage.Writer writer : writers.values()) { writer.remove(); } + for (SharedStorage.Reader reader : readers.values()) { + reader.remove(); + } writers.clear(); readers.clear(); } diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageUtils.java 
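For reference, user code now reaches shared items exclusively through invoke(); a minimal usage sketch (SUM is an assumed ItemDescriptor<Long>, like the one declared in the test added by this patch):

    sharedStorageContext.invoke(
            (getter, setter) -> {
                Long current = getter.get(SUM);
                setter.set(SUM, current + 1L);
            });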
b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageUtils.java index 9fe21d978..0276f1c0a 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageUtils.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageUtils.java @@ -20,13 +20,14 @@ import org.apache.flink.annotation.Experimental; import org.apache.flink.iteration.compile.DraftExecutionEnvironment; -import org.apache.flink.ml.common.sharedstorage.operator.SharedStorageWrapper; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.util.Preconditions; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.UUID; import java.util.stream.Collectors; @@ -65,7 +66,12 @@ public static List> withSharedStorage( SharedStorageBody.SharedStorageBodyResult result = body.process(draftSources); List> draftOutputs = result.getOutputs(); - context.setOwnerMap(result.getOwnerMap()); + Map, SharedStorageStreamOperator> rawOwnerMap = result.getOwnerMap(); + Map, String> ownerMap = new HashMap<>(); + for (ItemDescriptor item : rawOwnerMap.keySet()) { + ownerMap.put(item, rawOwnerMap.get(item).getSharedStorageAccessorID()); + } + context.setOwnerMap(ownerMap); for (DataStream draftOutput : draftOutputs) { draftEnv.addOperator(draftOutput.getTransformation()); diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/operator/SharedStorageWrapper.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageWrapper.java similarity index 92% rename from flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/operator/SharedStorageWrapper.java rename to flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageWrapper.java index 5ce35f699..357f0899d 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/operator/SharedStorageWrapper.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageWrapper.java @@ -16,13 +16,11 @@ * limitations under the License. */ -package org.apache.flink.ml.common.sharedstorage.operator; +package org.apache.flink.ml.common.sharedstorage; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.iteration.operator.OperatorWrapper; -import org.apache.flink.ml.common.sharedstorage.SharedStorageContext; -import org.apache.flink.ml.common.sharedstorage.SharedStorageStreamOperator; import org.apache.flink.streaming.api.operators.OneInputStreamOperator; import org.apache.flink.streaming.api.operators.StreamOperator; import org.apache.flink.streaming.api.operators.StreamOperatorFactory; @@ -34,12 +32,12 @@ import org.apache.flink.util.OutputTag; /** The operator wrapper for {@link AbstractSharedStorageWrapperOperator}. */ -public class SharedStorageWrapper implements OperatorWrapper { +class SharedStorageWrapper implements OperatorWrapper { /** Shared storage context. 
*/ - private final SharedStorageContext context; + private final SharedStorageContextImpl context; - public SharedStorageWrapper(SharedStorageContext context) { + public SharedStorageWrapper(SharedStorageContextImpl context) { this.context = context; } diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/operator/TwoInputSharedStorageWrapperOperator.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/TwoInputSharedStorageWrapperOperator.java similarity index 95% rename from flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/operator/TwoInputSharedStorageWrapperOperator.java rename to flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/TwoInputSharedStorageWrapperOperator.java index 03824a48f..2b23b7b4e 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/operator/TwoInputSharedStorageWrapperOperator.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/TwoInputSharedStorageWrapperOperator.java @@ -16,10 +16,9 @@ * limitations under the License. */ -package org.apache.flink.ml.common.sharedstorage.operator; +package org.apache.flink.ml.common.sharedstorage; import org.apache.flink.iteration.operator.OperatorUtils; -import org.apache.flink.ml.common.sharedstorage.SharedStorageContext; import org.apache.flink.streaming.api.operators.BoundedMultiInput; import org.apache.flink.streaming.api.operators.StreamOperatorFactory; import org.apache.flink.streaming.api.operators.StreamOperatorParameters; @@ -37,7 +36,7 @@ class TwoInputSharedStorageWrapperOperator TwoInputSharedStorageWrapperOperator( StreamOperatorParameters parameters, StreamOperatorFactory operatorFactory, - SharedStorageContext context) { + SharedStorageContextImpl context) { super(parameters, operatorFactory, context); } diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/common/sharedstorage/SharedStorageUtilsTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/common/sharedstorage/SharedStorageUtilsTest.java new file mode 100644 index 000000000..5fd30daa2 --- /dev/null +++ b/flink-ml-core/src/test/java/org/apache/flink/ml/common/sharedstorage/SharedStorageUtilsTest.java @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.sharedstorage; + +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.base.LongSerializer; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.util.TestUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; + +/** Tests the {@link SharedStorageUtils}. */ +public class SharedStorageUtilsTest { + + private static final ItemDescriptor SUM = + ItemDescriptor.of("sum", LongSerializer.INSTANCE, 0L); + @Rule public TemporaryFolder tempFolder = new TemporaryFolder(); + + static SharedStorageBody.SharedStorageBodyResult sharedStorageBody(List> inputs) { + //noinspection unchecked + DataStream data = (DataStream) inputs.get(0); + + AOperator aOp = new AOperator(); + SingleOutputStreamOperator afterAOp = + data.transform("a", TypeInformation.of(Long.class), aOp); + + BOperator bOp = new BOperator(); + SingleOutputStreamOperator afterBOp = + data.transform("b", TypeInformation.of(Long.class), bOp); + + Map, SharedStorageStreamOperator> ownerMap = new HashMap<>(); + ownerMap.put(SUM, aOp); + + return new SharedStorageBody.SharedStorageBodyResult( + Collections.singletonList(afterBOp), Arrays.asList(afterAOp, afterBOp), ownerMap); + } + + @Test + public void testSharedStorage() throws Exception { + StreamExecutionEnvironment env = TestUtils.getExecutionEnvironment(); + + DataStream data = env.fromSequence(1, 100); + List> outputs = + SharedStorageUtils.withSharedStorage( + Collections.singletonList(data), SharedStorageUtilsTest::sharedStorageBody); + //noinspection unchecked + DataStream partitionSum = (DataStream) outputs.get(0); + DataStream allSum = DataStreamUtils.reduce(partitionSum, new SumReduceFunction()); + allSum.getTransformation().setParallelism(1); + //noinspection unchecked + List results = IteratorUtils.toList(allSum.executeAndCollect()); + Assert.assertEquals(Collections.singletonList(5050L), results); + } + + /** Operator A: add input elements to the shared {@link #SUM}. 
*/ + static class AOperator extends AbstractStreamOperator + implements OneInputStreamOperator, SharedStorageStreamOperator { + + private final String sharedStorageAccessorID; + private SharedStorageContext sharedStorageContext; + + public AOperator() { + sharedStorageAccessorID = getClass().getSimpleName() + "-" + UUID.randomUUID(); + } + + @Override + public void onSharedStorageContextSet(SharedStorageContext context) { + this.sharedStorageContext = context; + } + + @Override + public String getSharedStorageAccessorID() { + return sharedStorageAccessorID; + } + + @Override + public void processElement(StreamRecord element) throws Exception { + sharedStorageContext.invoke( + (getter, setter) -> { + Long currentSum = getter.get(SUM); + setter.set(SUM, currentSum + element.getValue()); + }); + output.collect(element); + } + } + + /** Operator B: when input ends, get the value from shared {@link #SUM}. */ + static class BOperator extends AbstractStreamOperator + implements OneInputStreamOperator, + SharedStorageStreamOperator, + BoundedOneInput { + + private final String sharedStorageAccessorID; + private SharedStorageContext sharedStorageContext; + + public BOperator() { + sharedStorageAccessorID = getClass().getSimpleName() + "-" + UUID.randomUUID(); + } + + @Override + public void onSharedStorageContextSet(SharedStorageContext context) { + this.sharedStorageContext = context; + } + + @Override + public String getSharedStorageAccessorID() { + return sharedStorageAccessorID; + } + + @Override + public void processElement(StreamRecord element) throws Exception {} + + @Override + public void endInput() throws Exception { + sharedStorageContext.invoke( + (getter, setter) -> { + output.collect(new StreamRecord<>(getter.get(SUM))); + }); + } + } + + static class SumReduceFunction implements ReduceFunction { + @Override + public Long reduce(Long value1, Long value2) { + return value1 + value2; + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BoostIterationBody.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BoostIterationBody.java index 86f73e54c..85be1c91c 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BoostIterationBody.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BoostIterationBody.java @@ -38,6 +38,7 @@ import org.apache.flink.ml.common.gbt.operators.TerminationOperator; import org.apache.flink.ml.common.sharedstorage.ItemDescriptor; import org.apache.flink.ml.common.sharedstorage.SharedStorageBody; +import org.apache.flink.ml.common.sharedstorage.SharedStorageStreamOperator; import org.apache.flink.ml.common.sharedstorage.SharedStorageUtils; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; @@ -67,7 +68,7 @@ private SharedStorageBody.SharedStorageBodyResult sharedStorageBody( //noinspection unchecked DataStream trainContext = (DataStream) inputs.get(1); - Map, String> ownerMap = new HashMap<>(); + Map, SharedStorageStreamOperator> ownerMap = new HashMap<>(); CacheDataCalcLocalHistsOperator cacheDataCalcLocalHistsOp = new CacheDataCalcLocalHistsOperator(strategy); @@ -79,7 +80,7 @@ private SharedStorageBody.SharedStorageBodyResult sharedStorageBody( Types.INT, Types.INT, TypeInformation.of(Histogram.class)), cacheDataCalcLocalHistsOp); for (ItemDescriptor s : SharedStorageConstants.OWNED_BY_CACHE_DATA_CALC_LOCAL_HISTS_OP) { - ownerMap.put(s, cacheDataCalcLocalHistsOp.getSharedStorageAccessorID()); + 
ownerMap.put(s, cacheDataCalcLocalHistsOp); } DataStream> globalHists = @@ -105,7 +106,7 @@ private SharedStorageBody.SharedStorageBodyResult sharedStorageBody( .broadcast() .transform("PostSplits", TypeInformation.of(Integer.class), postSplitsOp); for (ItemDescriptor descriptor : SharedStorageConstants.OWNED_BY_POST_SPLITS_OP) { - ownerMap.put(descriptor, postSplitsOp.getSharedStorageAccessorID()); + ownerMap.put(descriptor, postSplitsOp); } final OutputTag finalModelDataOutputTag = diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java index 7b6c79be1..acbf2b6b6 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java @@ -115,8 +115,6 @@ public void initializeState(StateInitializationContext context) throws Exception histBuilder = OperatorStateUtils.getUniqueElement(histBuilderState, HIST_BUILDER_STATE_NAME) .orElse(null); - - sharedStorageContext.initializeState(this, getRuntimeContext(), context); } @Override @@ -125,7 +123,6 @@ public void snapshotState(StateSnapshotContext context) throws Exception { instancesCollecting.snapshotState(context); treeInitializerState.snapshotState(context); histBuilderState.snapshotState(context); - sharedStorageContext.snapshotState(context); } @Override @@ -263,7 +260,6 @@ public void close() throws Exception { instancesCollecting.clear(); treeInitializerState.clear(); histBuilderState.clear(); - sharedStorageContext.clear(); super.close(); } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CalcLocalSplitsOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CalcLocalSplitsOperator.java index dbc77b1c6..03ccf8342 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CalcLocalSplitsOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CalcLocalSplitsOperator.java @@ -77,8 +77,6 @@ public void initializeState(StateInitializationContext context) throws Exception splitFinder = OperatorStateUtils.getUniqueElement(splitFinderState, SPLIT_FINDER_STATE_NAME) .orElse(null); - - sharedStorageContext.initializeState(this, getRuntimeContext(), context); } @Override diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/PostSplitsOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/PostSplitsOperator.java index 5ae41968c..1ba439de9 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/PostSplitsOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/PostSplitsOperator.java @@ -98,8 +98,6 @@ public void initializeState(StateInitializationContext context) throws Exception OperatorStateUtils.getUniqueElement( instanceUpdaterState, INSTANCE_UPDATER_STATE_NAME) .orElse(null); - - sharedStorageContext.initializeState(this, getRuntimeContext(), context); } @Override @@ -107,7 +105,6 @@ public void snapshotState(StateSnapshotContext context) throws Exception { super.snapshotState(context); nodeSplitterState.snapshotState(context); instanceUpdaterState.snapshotState(context); - sharedStorageContext.snapshotState(context); } @Override @@ -218,7 +215,6 @@ public void processElement(StreamRecord> element) 
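Note that the loops above now register the operator instances themselves as item owners; the accessor-ID strings are resolved later inside SharedStorageUtils. A condensed sketch of the registration pattern (names taken from the surrounding diff):

    Map<ItemDescriptor<?>, SharedStorageStreamOperator> ownerMap = new HashMap<>();
    for (ItemDescriptor<?> descriptor : SharedStorageConstants.OWNED_BY_POST_SPLITS_OP) {
        ownerMap.put(descriptor, postSplitsOp);
    }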
throws public void close() throws Exception { nodeSplitterState.clear(); instanceUpdaterState.clear(); - sharedStorageContext.clear(); super.close(); } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/ReduceSplitsOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/ReduceSplitsOperator.java index d0e5839f6..d33a4a8e1 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/ReduceSplitsOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/ReduceSplitsOperator.java @@ -24,7 +24,6 @@ import org.apache.flink.ml.common.sharedstorage.SharedStorageContext; import org.apache.flink.ml.common.sharedstorage.SharedStorageStreamOperator; import org.apache.flink.runtime.state.StateInitializationContext; -import org.apache.flink.runtime.state.StateSnapshotContext; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; import org.apache.flink.streaming.api.operators.OneInputStreamOperator; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; @@ -74,17 +73,11 @@ public String getSharedStorageAccessorID() { @Override public void initializeState(StateInitializationContext context) throws Exception { - sharedStorageContext.initializeState(this, getRuntimeContext(), context); nodeFeatureMap = new HashMap<>(); nodeBestSplit = new HashMap<>(); nodeFeatureCounter = new HashMap<>(); } - @Override - public void snapshotState(StateSnapshotContext context) throws Exception { - sharedStorageContext.snapshotState(context); - } - @Override public void processElement(StreamRecord> element) throws Exception { @@ -132,7 +125,6 @@ public void processElement(StreamRecord> element @Override public void close() throws Exception { - sharedStorageContext.clear(); super.close(); } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TerminationOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TerminationOperator.java index 8b1726949..ee711d78a 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TerminationOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TerminationOperator.java @@ -49,7 +49,6 @@ public TerminationOperator(OutputTag modelDataOutputTag) { @Override public void initializeState(StateInitializationContext context) throws Exception { super.initializeState(context); - sharedStorageContext.initializeState(this, getRuntimeContext(), context); } @Override From e709a35b312c53b4ed3cb55d7c8b185231f8d378 Mon Sep 17 00:00:00 2001 From: Fan Hong Date: Fri, 5 May 2023 16:58:42 +0800 Subject: [PATCH 41/47] Fix after merging master --- .../apache/flink/ml/util/ReadWriteUtils.java | 2 +- .../gbtclassifier/GBTClassifier.java | 2 +- .../regression/gbtregressor/GBTRegressor.java | 2 +- .../ml/classification/GBTClassifierTest.java | 13 +++++++++---- .../flink/ml/regression/GBTRegressorTest.java | 17 +++++++++++------ 5 files changed, 23 insertions(+), 13 deletions(-) diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java b/flink-ml-core/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java index 300f753eb..989b5a8a8 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java @@ -343,7 +343,7 @@ public static Table loadModelData( TypeInformation typeInfo) { StreamExecutionEnvironment env = 
TableUtils.getExecutionEnvironment(tEnv); Source source = - FileSource.forRecordStreamFormat(modelDecoder, new Path(getDataPath(path))).build(); + FileSource.forRecordStreamFormat(modelDecoder, FileUtils.getDataPath(path)).build(); DataStream modelDataStream = env.fromSource(source, WatermarkStrategy.noWatermarks(), "modelData", typeInfo); return tEnv.fromDataStream(modelDataStream); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifier.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifier.java index 248aaea76..c33614556 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifier.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/gbtclassifier/GBTClassifier.java @@ -85,7 +85,7 @@ public GBTClassifierModel fit(Table... inputs) { tEnv.fromDataStream(modelData).renameColumns($("f0").as("modelData")), tEnv.fromDataStream(featureImportance) .renameColumns($("f0").as("featureImportance"))); - ReadWriteUtils.updateExistingParams(model, getParamMap()); + ParamUtils.updateExistingParams(model, getParamMap()); return model; } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressor.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressor.java index 1ccb7e194..81d06018b 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressor.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/gbtregressor/GBTRegressor.java @@ -84,7 +84,7 @@ public GBTRegressorModel fit(Table... inputs) { tEnv.fromDataStream(modelData).renameColumns($("f0").as("modelData")), tEnv.fromDataStream(featureImportance) .renameColumns($("f0").as("featureImportance"))); - ReadWriteUtils.updateExistingParams(model, getParamMap()); + ParamUtils.updateExistingParams(model, getParamMap()); return model; } diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/GBTClassifierTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/GBTClassifierTest.java index 357a7303c..f0d62e1a5 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/GBTClassifierTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/GBTClassifierTest.java @@ -30,7 +30,7 @@ import org.apache.flink.ml.linalg.DenseVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; -import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.TestUtils; import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; @@ -406,7 +406,8 @@ public void testEstimatorSaveLoadAndPredict() throws Exception { .setMaxBins(3) .setSeed(123); GBTClassifier loadedGbtc = - TestUtils.saveAndReload(tEnv, gbtc, tempFolder.newFolder().getAbsolutePath()); + TestUtils.saveAndReload( + tEnv, gbtc, tempFolder.newFolder().getAbsolutePath(), GBTClassifier::load); GBTClassifierModel model = loadedGbtc.fit(inputTable); Assert.assertEquals( Collections.singletonList("modelData"), @@ -434,7 +435,11 @@ public void testModelSaveLoadAndPredict() throws Exception { .setSeed(123); GBTClassifierModel model = gbtc.fit(inputTable); GBTClassifierModel loadedModel = - TestUtils.saveAndReload(tEnv, model, tempFolder.newFolder().getAbsolutePath()); + 
TestUtils.saveAndReload( + tEnv, + model, + tempFolder.newFolder().getAbsolutePath(), + GBTClassifierModel::load); Table output = loadedModel.transform(inputTable)[0].select( $(gbtc.getPredictionCol()), @@ -503,7 +508,7 @@ public void testSetModelData() throws Exception { .setSeed(123); GBTClassifierModel modelA = gbtc.fit(inputTable); GBTClassifierModel modelB = new GBTClassifierModel().setModelData(modelA.getModelData()); - ReadWriteUtils.updateExistingParams(modelB, modelA.getParamMap()); + ParamUtils.updateExistingParams(modelB, modelA.getParamMap()); Table output = modelA.transform(inputTable)[0].select( $(gbtc.getPredictionCol()), diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/GBTRegressorTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/GBTRegressorTest.java index ffcddd9aa..db1366d85 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/GBTRegressorTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/GBTRegressorTest.java @@ -29,7 +29,7 @@ import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; import org.apache.flink.ml.regression.gbtregressor.GBTRegressor; import org.apache.flink.ml.regression.gbtregressor.GBTRegressorModel; -import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.TestUtils; import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; @@ -291,9 +291,10 @@ public void testEstimatorSaveLoadAndPredict() throws Exception { .setRegGamma(0.) .setMaxBins(3) .setSeed(123); - GBTRegressor loadedgbtr = - TestUtils.saveAndReload(tEnv, gbtr, tempFolder.newFolder().getAbsolutePath()); - GBTRegressorModel model = loadedgbtr.fit(inputTable); + GBTRegressor loadedGbtr = + TestUtils.saveAndReload( + tEnv, gbtr, tempFolder.newFolder().getAbsolutePath(), GBTRegressor::load); + GBTRegressorModel model = loadedGbtr.fit(inputTable); Assert.assertEquals( Collections.singletonList("modelData"), model.getModelData()[0].getResolvedSchema().getColumnNames()); @@ -316,7 +317,11 @@ public void testModelSaveLoadAndPredict() throws Exception { .setSeed(123); GBTRegressorModel model = gbtr.fit(inputTable); GBTRegressorModel loadedModel = - TestUtils.saveAndReload(tEnv, model, tempFolder.newFolder().getAbsolutePath()); + TestUtils.saveAndReload( + tEnv, + model, + tempFolder.newFolder().getAbsolutePath(), + GBTRegressorModel::load); Table output = loadedModel.transform(inputTable)[0].select($(gbtr.getPredictionCol())); verifyPredictionResult(output, outputRows); } @@ -381,7 +386,7 @@ public void testSetModelData() throws Exception { .setSeed(123); GBTRegressorModel modelA = gbtr.fit(inputTable); GBTRegressorModel modelB = new GBTRegressorModel().setModelData(modelA.getModelData()); - ReadWriteUtils.updateExistingParams(modelB, modelA.getParamMap()); + ParamUtils.updateExistingParams(modelB, modelA.getParamMap()); Table output = modelA.transform(inputTable)[0].select($(gbtr.getPredictionCol())); verifyPredictionResult(output, outputRows); } From c109c76a337a168a05460af87e0707cf50008432 Mon Sep 17 00:00:00 2001 From: Fan Hong Date: Fri, 5 May 2023 17:12:07 +0800 Subject: [PATCH 42/47] Fix SharedStorageUtilsTest --- .../sharedstorage/SharedStorageUtilsTest.java | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/common/sharedstorage/SharedStorageUtilsTest.java 
b/flink-ml-core/src/test/java/org/apache/flink/ml/common/sharedstorage/SharedStorageUtilsTest.java index 5fd30daa2..29fcdda15 100644 --- a/flink-ml-core/src/test/java/org/apache/flink/ml/common/sharedstorage/SharedStorageUtilsTest.java +++ b/flink-ml-core/src/test/java/org/apache/flink/ml/common/sharedstorage/SharedStorageUtilsTest.java @@ -61,7 +61,7 @@ static SharedStorageBody.SharedStorageBodyResult sharedStorageBody(List afterBOp = - data.transform("b", TypeInformation.of(Long.class), bOp); + afterAOp.transform("b", TypeInformation.of(Long.class), bOp); Map, SharedStorageStreamOperator> ownerMap = new HashMap<>(); ownerMap.put(SUM, aOp); @@ -89,7 +89,9 @@ public void testSharedStorage() throws Exception { /** Operator A: add input elements to the shared {@link #SUM}. */ static class AOperator extends AbstractStreamOperator - implements OneInputStreamOperator, SharedStorageStreamOperator { + implements OneInputStreamOperator, + SharedStorageStreamOperator, + BoundedOneInput { private final String sharedStorageAccessorID; private SharedStorageContext sharedStorageContext; @@ -115,15 +117,18 @@ public void processElement(StreamRecord element) throws Exception { Long currentSum = getter.get(SUM); setter.set(SUM, currentSum + element.getValue()); }); - output.collect(element); + } + + @Override + public void endInput() throws Exception { + // Informs BOperator to get the value from shared {@link #SUM}. + output.collect(new StreamRecord<>(0L)); } } /** Operator B: when input ends, get the value from shared {@link #SUM}. */ static class BOperator extends AbstractStreamOperator - implements OneInputStreamOperator, - SharedStorageStreamOperator, - BoundedOneInput { + implements OneInputStreamOperator, SharedStorageStreamOperator { private final String sharedStorageAccessorID; private SharedStorageContext sharedStorageContext; @@ -143,10 +148,7 @@ public String getSharedStorageAccessorID() { } @Override - public void processElement(StreamRecord element) throws Exception {} - - @Override - public void endInput() throws Exception { + public void processElement(StreamRecord element) throws Exception { sharedStorageContext.invoke( (getter, setter) -> { output.collect(new StreamRecord<>(getter.get(SUM))); From 5313d5181b80927d9f30f95e5f7f17c0a501e28c Mon Sep 17 00:00:00 2001 From: Fan Hong Date: Fri, 5 May 2023 19:04:21 +0800 Subject: [PATCH 43/47] Rename shared storage to shared objects and change according to comments --- ...AbstractSharedObjectsWrapperOperator.java} | 20 ++-- .../ItemDescriptor.java | 25 ++--- ...OneInputSharedObjectsWrapperOperator.java} | 10 +- .../PoolID.java} | 10 +- .../SharedObjectsBody.java} | 51 +++++----- .../SharedObjectsContext.java} | 12 +-- .../SharedObjectsContextImpl.java} | 42 +++++---- .../SharedObjectsPools.java} | 94 +++++++++++-------- .../SharedObjectsStreamOperator.java} | 14 +-- .../SharedObjectsUtils.java} | 33 +++---- .../SharedObjectsWrapper.java} | 24 ++--- ...TwoInputSharedObjectsWrapperOperator.java} | 10 +- .../SharedObjectsUtilsTest.java} | 58 ++++++------ .../flink/ml/common/gbt/BaseGBTParams.java | 14 ++- .../ml/common/gbt/BoostIterationBody.java | 31 +++--- .../CacheDataCalcLocalHistsOperator.java | 70 +++++++------- .../operators/CalcLocalSplitsOperator.java | 34 +++---- .../gbt/operators/PostSplitsOperator.java | 80 ++++++++-------- .../gbt/operators/ReduceSplitsOperator.java | 28 +++--- ...tants.java => SharedObjectsConstants.java} | 10 +- .../gbt/operators/TerminationOperator.java | 32 +++---- 21 files changed, 372 insertions(+), 330 
deletions(-) rename flink-ml-core/src/main/java/org/apache/flink/ml/common/{sharedstorage/AbstractSharedStorageWrapperOperator.java => sharedobjects/AbstractSharedObjectsWrapperOperator.java} (95%) rename flink-ml-core/src/main/java/org/apache/flink/ml/common/{sharedstorage => sharedobjects}/ItemDescriptor.java (68%) rename flink-ml-core/src/main/java/org/apache/flink/ml/common/{sharedstorage/OneInputSharedStorageWrapperOperator.java => sharedobjects/OneInputSharedObjectsWrapperOperator.java} (91%) rename flink-ml-core/src/main/java/org/apache/flink/ml/common/{sharedstorage/StorageID.java => sharedobjects/PoolID.java} (83%) rename flink-ml-core/src/main/java/org/apache/flink/ml/common/{sharedstorage/SharedStorageBody.java => sharedobjects/SharedObjectsBody.java} (58%) rename flink-ml-core/src/main/java/org/apache/flink/ml/common/{sharedstorage/SharedStorageContext.java => sharedobjects/SharedObjectsContext.java} (82%) rename flink-ml-core/src/main/java/org/apache/flink/ml/common/{sharedstorage/SharedStorageContextImpl.java => sharedobjects/SharedObjectsContextImpl.java} (73%) rename flink-ml-core/src/main/java/org/apache/flink/ml/common/{sharedstorage/SharedStorage.java => sharedobjects/SharedObjectsPools.java} (64%) rename flink-ml-core/src/main/java/org/apache/flink/ml/common/{sharedstorage/SharedStorageStreamOperator.java => sharedobjects/SharedObjectsStreamOperator.java} (72%) rename flink-ml-core/src/main/java/org/apache/flink/ml/common/{sharedstorage/SharedStorageUtils.java => sharedobjects/SharedObjectsUtils.java} (72%) rename flink-ml-core/src/main/java/org/apache/flink/ml/common/{sharedstorage/SharedStorageWrapper.java => sharedobjects/SharedObjectsWrapper.java} (84%) rename flink-ml-core/src/main/java/org/apache/flink/ml/common/{sharedstorage/TwoInputSharedStorageWrapperOperator.java => sharedobjects/TwoInputSharedObjectsWrapperOperator.java} (92%) rename flink-ml-core/src/test/java/org/apache/flink/ml/common/{sharedstorage/SharedStorageUtilsTest.java => sharedobjects/SharedObjectsUtilsTest.java} (74%) rename flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/{SharedStorageConstants.java => SharedObjectsConstants.java} (95%) diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/AbstractSharedStorageWrapperOperator.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/AbstractSharedObjectsWrapperOperator.java similarity index 95% rename from flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/AbstractSharedStorageWrapperOperator.java rename to flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/AbstractSharedObjectsWrapperOperator.java index 0bccef8b7..6b1f3e5cf 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/AbstractSharedStorageWrapperOperator.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/AbstractSharedObjectsWrapperOperator.java @@ -16,7 +16,7 @@ * limitations under the License. */ -package org.apache.flink.ml.common.sharedstorage; +package org.apache.flink.ml.common.sharedobjects; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.core.memory.ManagedMemoryUseCase; @@ -56,12 +56,12 @@ import java.util.Objects; import java.util.Optional; -/** Base class for the shared storage wrapper operators. */ -abstract class AbstractSharedStorageWrapperOperator> +/** Base class for the shared objects wrapper operators. 
*/ +abstract class AbstractSharedObjectsWrapperOperator> implements StreamOperator, IterationListener, CheckpointedStreamOperator { private static final Logger LOG = - LoggerFactory.getLogger(AbstractSharedStorageWrapperOperator.class); + LoggerFactory.getLogger(AbstractSharedObjectsWrapperOperator.class); protected final StreamOperatorParameters parameters; @@ -72,7 +72,7 @@ abstract class AbstractSharedStorageWrapperOperator> output; protected final StreamOperatorFactory operatorFactory; - private final SharedStorageContextImpl context; + private final SharedObjectsContextImpl context; protected final OperatorMetricGroup metrics; protected final S wrappedOperator; protected transient StreamOperatorStateHandler stateHandler; @@ -80,10 +80,10 @@ abstract class AbstractSharedStorageWrapperOperator timeServiceManager; @SuppressWarnings({"unchecked", "rawtypes"}) - AbstractSharedStorageWrapperOperator( + AbstractSharedObjectsWrapperOperator( StreamOperatorParameters parameters, StreamOperatorFactory operatorFactory, - SharedStorageContextImpl context) { + SharedObjectsContextImpl context) { this.parameters = Objects.requireNonNull(parameters); this.streamConfig = Objects.requireNonNull(parameters.getStreamConfig()); this.containingTask = Objects.requireNonNull(parameters.getContainingTask()); @@ -101,11 +101,11 @@ abstract class AbstractSharedStorageWrapperOperator implements Serializable { /** Name of the item. */ - public String key; + public final String name; /** Type serializer. */ - public TypeSerializer serializer; + public final TypeSerializer serializer; /** Initialize value. */ - public T initVal; + public final T initVal; - private ItemDescriptor(String key, TypeSerializer serializer, T initVal) { - this.key = key; + private ItemDescriptor(String name, TypeSerializer serializer, T initVal) { + Preconditions.checkNotNull( + initVal, "Cannot use `null` as the initial value of a shared item."); + this.name = name; this.serializer = serializer; this.initVal = initVal; } - public static ItemDescriptor of(String key, TypeSerializer serializer, T initVal) { - return new ItemDescriptor<>(key, serializer, initVal); + public static ItemDescriptor of(String name, TypeSerializer serializer, T initVal) { + return new ItemDescriptor<>(name, serializer, initVal); } @Override public int hashCode() { - return key.hashCode(); + return name.hashCode(); } @Override @@ -64,12 +67,12 @@ public boolean equals(Object o) { return false; } ItemDescriptor that = (ItemDescriptor) o; - return key.equals(that.key); + return name.equals(that.name); } @Override public String toString() { return String.format( - "ItemDescriptor{key='%s', serializer=%s, initVal=%s}", key, serializer, initVal); + "ItemDescriptor{name='%s', serializer=%s, initVal=%s}", name, serializer, initVal); } } diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/OneInputSharedStorageWrapperOperator.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/OneInputSharedObjectsWrapperOperator.java similarity index 91% rename from flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/OneInputSharedStorageWrapperOperator.java rename to flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/OneInputSharedObjectsWrapperOperator.java index 6b427f03a..b6a197f5c 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/OneInputSharedStorageWrapperOperator.java +++ 
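With ItemDescriptor fields now final and null initial values rejected eagerly, declaring a shared item stays a one-liner; a minimal sketch (the "sum" item mirrors the one in SharedObjectsUtilsTest):

    ItemDescriptor<Long> SUM = ItemDescriptor.of("sum", LongSerializer.INSTANCE, 0L);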
b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/OneInputSharedObjectsWrapperOperator.java @@ -16,7 +16,7 @@ * limitations under the License. */ -package org.apache.flink.ml.common.sharedstorage; +package org.apache.flink.ml.common.sharedobjects; import org.apache.flink.iteration.operator.OperatorUtils; import org.apache.flink.streaming.api.operators.BoundedOneInput; @@ -29,14 +29,14 @@ import org.apache.flink.streaming.runtime.watermarkstatus.WatermarkStatus; /** Wrapper for {@link OneInputStreamOperator}. */ -class OneInputSharedStorageWrapperOperator - extends AbstractSharedStorageWrapperOperator> +class OneInputSharedObjectsWrapperOperator + extends AbstractSharedObjectsWrapperOperator> implements OneInputStreamOperator, BoundedOneInput { - OneInputSharedStorageWrapperOperator( + OneInputSharedObjectsWrapperOperator( StreamOperatorParameters parameters, StreamOperatorFactory operatorFactory, - SharedStorageContextImpl context) { + SharedObjectsContextImpl context) { super(parameters, operatorFactory, context); } diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/StorageID.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/PoolID.java similarity index 83% rename from flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/StorageID.java rename to flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/PoolID.java index 123edcb64..77ff6573e 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/StorageID.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/PoolID.java @@ -16,17 +16,17 @@ * limitations under the License. */ -package org.apache.flink.ml.common.sharedstorage; +package org.apache.flink.ml.common.sharedobjects; import org.apache.flink.util.AbstractID; -/** ID of a shared storage. */ -class StorageID extends AbstractID { +/** ID of a pool for shared objects. */ +class PoolID extends AbstractID { private static final long serialVersionUID = 1L; - public StorageID(byte[] bytes) { + public PoolID(byte[] bytes) { super(bytes); } - public StorageID() {} + public PoolID() {} } diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageBody.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsBody.java similarity index 58% rename from flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageBody.java rename to flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsBody.java index e9a4be7bf..64c737c70 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageBody.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsBody.java @@ -16,9 +16,10 @@ * limitations under the License. */ -package org.apache.flink.ml.common.sharedstorage; +package org.apache.flink.ml.common.sharedobjects; import org.apache.flink.annotation.Experimental; +import org.apache.flink.api.dag.Transformation; import org.apache.flink.streaming.api.datastream.DataStream; import java.io.Serializable; @@ -26,54 +27,52 @@ import java.util.Map; /** - * The builder of the subgraph that will be executed with a common shared storage. Users can only + * The builder of the subgraph that will be executed with a common shared objects. Users can only * create data streams from {@code inputs}. Users can not refer to data streams outside, and can not * add sources/sinks. * - *
<p>The shared storage body requires all streams accessing the shared storage, i.e., {@link
- * SharedStorageBodyResult#accessors} have same parallelism and can be co-located.
+ *
+ * <p>The shared objects body requires all transformations accessing the shared objects,
+ * {@link SharedObjectsBodyResult#coLocatedTransformations}, to have same parallelism and can be
+ * co-located.
*/ - private final Map, SharedStorageStreamOperator> ownerMap; + private final Map, SharedObjectsStreamOperator> ownerMap; - public SharedStorageBodyResult( + public SharedObjectsBodyResult( List> outputs, - List> accessors, - Map, SharedStorageStreamOperator> ownerMap) { + List> coLocatedTransformations, + Map, SharedObjectsStreamOperator> ownerMap) { this.outputs = outputs; - this.accessors = accessors; + this.coLocatedTransformations = coLocatedTransformations; this.ownerMap = ownerMap; } @@ -81,11 +80,11 @@ public List> getOutputs() { return outputs; } - public List> getAccessors() { - return accessors; + public List> getCoLocatedTransformations() { + return coLocatedTransformations; } - public Map, SharedStorageStreamOperator> getOwnerMap() { + public Map, SharedObjectsStreamOperator> getOwnerMap() { return ownerMap; } } diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageContext.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsContext.java similarity index 82% rename from flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageContext.java rename to flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsContext.java index 6117c6623..f5419780e 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageContext.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsContext.java @@ -16,22 +16,22 @@ * limitations under the License. */ -package org.apache.flink.ml.common.sharedstorage; +package org.apache.flink.ml.common.sharedobjects; import org.apache.flink.annotation.Experimental; import org.apache.flink.util.function.BiConsumerWithException; /** - * Context for shared storage. Every operator implementing {@link SharedStorageStreamOperator} will - * have an instance of this context set by {@link - * SharedStorageStreamOperator#onSharedStorageContextSet} in runtime. User defined logic can be + * Context for shared objects. Every operator implementing {@link SharedObjectsStreamOperator} will + * get an instance of this context set by {@link + * SharedObjectsStreamOperator#onSharedObjectsContextSet} in runtime. User-defined logic can be * invoked through {@link #invoke} with the access to shared items. */ @Experimental -public interface SharedStorageContext { +public interface SharedObjectsContext { /** - * Invoke user defined function with provided getters/setters of the shared storage. + * Invoke user defined function with provided getters/setters of the shared objects. * * @param func User defined function where share items can be accessed through getters/setters. * @throws Exception Possible exception. diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageContextImpl.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsContextImpl.java similarity index 73% rename from flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageContextImpl.java rename to flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsContextImpl.java index 25e08d573..0c57b9cef 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageContextImpl.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsContextImpl.java @@ -16,7 +16,7 @@ * limitations under the License. 
*/ -package org.apache.flink.ml.common.sharedstorage; +package org.apache.flink.ml.common.sharedobjects; import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.StateSnapshotContext; @@ -30,16 +30,22 @@ import java.util.HashMap; import java.util.Map; -/** Default implementation of {@link SharedStorageContext} using {@link SharedStorage}. */ +/** + * Default implementation of {@link SharedObjectsContext}. + * + *

It initializes readers and writers according to the owner map when the subtask starts and + * clean internal states when the subtask finishes. It also handles `initializeState` and + * `snapshotState` automatically. + */ @SuppressWarnings("rawtypes") -class SharedStorageContextImpl implements SharedStorageContext, Serializable { - private final StorageID storageID; - private final Map writers = new HashMap<>(); - private final Map readers = new HashMap<>(); +class SharedObjectsContextImpl implements SharedObjectsContext, Serializable { + private final PoolID poolID; + private final Map writers = new HashMap<>(); + private final Map readers = new HashMap<>(); private Map, String> ownerMap; - public SharedStorageContextImpl() { - this.storageID = new StorageID(); + public SharedObjectsContextImpl() { + this.poolID = new PoolID(); } void setOwnerMap(Map, String> ownerMap) { @@ -54,7 +60,7 @@ public void invoke(BiConsumerWithException T getSharedItem(ItemDescriptor key) { //noinspection unchecked - SharedStorage.Reader reader = readers.get(key); + SharedObjectsPools.Reader reader = readers.get(key); Preconditions.checkState( null != reader, String.format( @@ -65,7 +71,7 @@ private T getSharedItem(ItemDescriptor key) { private void setSharedItem(ItemDescriptor key, T value) { //noinspection unchecked - SharedStorage.Writer writer = writers.get(key); + SharedObjectsPools.Writer writer = writers.get(key); Preconditions.checkState( null != writer, String.format( @@ -79,17 +85,17 @@ void initializeState( StreamingRuntimeContext runtimeContext, StateInitializationContext context) { Preconditions.checkArgument( - operator instanceof SharedStorageStreamOperator + operator instanceof SharedObjectsStreamOperator && operator instanceof AbstractStreamOperator); - String ownerId = ((SharedStorageStreamOperator) operator).getSharedStorageAccessorID(); + String ownerId = ((SharedObjectsStreamOperator) operator).getSharedObjectsAccessorID(); int subtaskId = runtimeContext.getIndexOfThisSubtask(); for (Map.Entry, String> entry : ownerMap.entrySet()) { ItemDescriptor descriptor = entry.getKey(); if (ownerId.equals(entry.getValue())) { writers.put( descriptor, - SharedStorage.getWriter( - storageID, + SharedObjectsPools.getWriter( + poolID, subtaskId, descriptor, ownerId, @@ -98,21 +104,21 @@ void initializeState( runtimeContext, context)); } - readers.put(descriptor, SharedStorage.getReader(storageID, subtaskId, descriptor)); + readers.put(descriptor, SharedObjectsPools.getReader(poolID, subtaskId, descriptor)); } } void snapshotState(StateSnapshotContext context) throws Exception { - for (SharedStorage.Writer writer : writers.values()) { + for (SharedObjectsPools.Writer writer : writers.values()) { writer.snapshotState(context); } } void clear() { - for (SharedStorage.Writer writer : writers.values()) { + for (SharedObjectsPools.Writer writer : writers.values()) { writer.remove(); } - for (SharedStorage.Reader reader : readers.values()) { + for (SharedObjectsPools.Reader reader : readers.values()) { reader.remove(); } writers.clear(); diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorage.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsPools.java similarity index 64% rename from flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorage.java rename to flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsPools.java index c0ca41594..5be832635 100644 --- 
a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorage.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsPools.java
@@ -16,7 +16,7 @@
  * limitations under the License.
  */

-package org.apache.flink.ml.common.sharedstorage;
+package org.apache.flink.ml.common.sharedobjects;

 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.tuple.Tuple3;
@@ -33,43 +33,56 @@ import java.util.Map;
 import java.util.concurrent.ConcurrentHashMap;

-/** A shared storage to support access through subtasks of different operators. */
-class SharedStorage {
-    private static final Map<Tuple3<StorageID, Integer, String>, Object> m =
+/**
+ * Stores and manages all shared objects. Every shared object is identified by a tuple of (pool ID,
+ * subtask ID, name). Every call of {@link SharedObjectsUtils#withSharedObjects} generates a
+ * different {@link PoolID}, so that different calls do not interfere with each other.
+ */
+class SharedObjectsPools {
+
+    // Stores values of all shared objects.
+    private static final Map<Tuple3<PoolID, Integer, String>, Object> values =
             new ConcurrentHashMap<>();

-    private static final Map<Tuple3<StorageID, Integer, String>, String> owners =
+    /**
+     * Stores owners of all shared objects, where the owner is identified by the accessor ID
+     * obtained from {@link SharedObjectsStreamOperator#getSharedObjectsAccessorID()}.
+     */
+    private static final Map<Tuple3<PoolID, Integer, String>, String> owners =
             new ConcurrentHashMap<>();

-    private static final ConcurrentHashMap<Tuple3<StorageID, Integer, String>, Integer>
-            numItemRefs = new ConcurrentHashMap<>();
+    // Stores the number of references of all shared objects. A shared object is removed when its
+    // number of references decreases to 0.
+    private static final ConcurrentHashMap<Tuple3<PoolID, Integer, String>, Integer> numRefs =
+            new ConcurrentHashMap<>();

-    static int incRef(Tuple3<StorageID, Integer, String> t) {
-        return numItemRefs.compute(t, (k, oldV) -> null == oldV ? 1 : oldV + 1);
+    @SuppressWarnings("UnusedReturnValue")
+    static int incRef(Tuple3<PoolID, Integer, String> itemId) {
+        return numRefs.compute(itemId, (k, oldV) -> null == oldV ? 1 : oldV + 1);
     }

-    static int decRef(Tuple3<StorageID, Integer, String> t) {
-        int numRefs = numItemRefs.compute(t, (k, oldV) -> oldV - 1);
-        if (numRefs == 0) {
-            m.remove(t);
-            owners.remove(t);
-            numItemRefs.remove(t);
+    @SuppressWarnings("UnusedReturnValue")
+    static int decRef(Tuple3<PoolID, Integer, String> itemId) {
+        int num = numRefs.compute(itemId, (k, oldV) -> oldV - 1);
+        if (num == 0) {
+            values.remove(itemId);
+            owners.remove(itemId);
+            numRefs.remove(itemId);
         }
-        return numRefs;
+        return num;
     }

-    /** Gets a {@link Reader} of shared item identified by (storageID, subtaskId, descriptor). */
-    static <T> Reader<T> getReader(
-            StorageID storageID, int subtaskId, ItemDescriptor<T> descriptor) {
-        Tuple3<StorageID, Integer, String> t = Tuple3.of(storageID, subtaskId, descriptor.key);
-        Reader<T> reader = new Reader<>(t);
-        incRef(t);
+    /** Gets a {@link Reader} of a shared object. */
+    static <T> Reader<T> getReader(PoolID poolID, int subtaskId, ItemDescriptor<T> descriptor) {
+        Tuple3<PoolID, Integer, String> itemId = Tuple3.of(poolID, subtaskId, descriptor.name);
+        Reader<T> reader = new Reader<>(itemId);
+        incRef(itemId);
         return reader;
     }

-    /** Gets a {@link Writer} of shared item identified by (storageID, subtaskId, key). */
+    /** Gets a {@link Writer} of a shared object.
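+     *
+     * <p>Only the registered owner may hold the {@link Writer} of a shared object; a second
+     * {@code getWriter} call for the same (pool ID, subtask ID, name) triple throws an
+     * {@link IllegalStateException}. Readers obtained via {@code getReader} block (with
+     * exponentially growing waits) until the owner has written a value. A minimal reading
+     * sketch, assuming {@code poolID} and {@code subtaskId} are in scope:
+     *
+     * <pre>{@code
+     * ItemDescriptor<Long> sum = ItemDescriptor.of("sum", LongSerializer.INSTANCE, 0L);
+     * Reader<Long> reader = SharedObjectsPools.getReader(poolID, subtaskId, sum);
+     * Long value = reader.get();
+     * }</pre>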
*/ static Writer getWriter( - StorageID storageID, + PoolID poolId, int subtaskId, ItemDescriptor descriptor, String ownerId, @@ -77,17 +90,17 @@ static Writer getWriter( StreamTask containingTask, StreamingRuntimeContext runtimeContext, StateInitializationContext stateInitializationContext) { - Tuple3 t = Tuple3.of(storageID, subtaskId, descriptor.key); - String lastOwner = owners.putIfAbsent(t, ownerId); + Tuple3 objId = Tuple3.of(poolId, subtaskId, descriptor.name); + String lastOwner = owners.putIfAbsent(objId, ownerId); if (null != lastOwner) { throw new IllegalStateException( String.format( "The shared item (%s, %s, %s) already has a writer %s.", - storageID, subtaskId, descriptor.key, ownerId)); + poolId, subtaskId, descriptor.name, ownerId)); } Writer writer = new Writer<>( - t, + objId, ownerId, descriptor.serializer, containingTask, @@ -95,15 +108,15 @@ static Writer getWriter( stateInitializationContext, operatorID); writer.set(descriptor.initVal); - incRef(t); + incRef(objId); return writer; } static class Reader { - protected final Tuple3 t; + protected final Tuple3 objId; - Reader(Tuple3 t) { - this.t = t; + Reader(Tuple3 objId) { + this.objId = objId; } T get() { @@ -112,7 +125,7 @@ T get() { long waitTime = 10; do { //noinspection unchecked - T value = (T) m.get(t); + T value = (T) values.get(objId); if (null != value) { return value; } @@ -124,11 +137,12 @@ T get() { waitTime *= 2; } while (waitTime < 10 * 1000); throw new IllegalStateException( - String.format("Failed to get value of %s after waiting %d ms.", t, waitTime)); + String.format( + "Failed to get value of %s after waiting %d ms.", objId, waitTime)); } void remove() { - decRef(t); + decRef(objId); } } @@ -138,14 +152,14 @@ static class Writer extends Reader { private boolean isDirty; Writer( - Tuple3 t, + Tuple3 itemId, String ownerId, TypeSerializer serializer, StreamTask containingTask, StreamingRuntimeContext runtimeContext, StateInitializationContext stateInitializationContext, OperatorID operatorID) { - super(t); + super(itemId); this.ownerId = ownerId; try { cache = @@ -159,7 +173,7 @@ static class Writer extends Reader { if (iterator.hasNext()) { T value = iterator.next(); ensureOwner(); - m.put(t, value); + values.put(itemId, value); } } catch (Exception e) { throw new RuntimeException(e); @@ -170,12 +184,12 @@ static class Writer extends Reader { private void ensureOwner() { // Double-checks the owner, because a writer may call this method after the key removed // and re-added by other operators. - Preconditions.checkState(owners.get(t).equals(ownerId)); + Preconditions.checkState(owners.get(objId).equals(ownerId)); } void set(T value) { ensureOwner(); - m.put(t, value); + values.put(objId, value); isDirty = true; } diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageStreamOperator.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsStreamOperator.java similarity index 72% rename from flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageStreamOperator.java rename to flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsStreamOperator.java index 81d964d11..52afa9617 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageStreamOperator.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsStreamOperator.java @@ -16,17 +16,17 @@ * limitations under the License. 
 */

-package org.apache.flink.ml.common.sharedstorage;
+package org.apache.flink.ml.common.sharedobjects;

-/** Interface for all operators that need to access the shared storage. */
-public interface SharedStorageStreamOperator {
+/** Interface for all operators that need to access the shared objects. */
+public interface SharedObjectsStreamOperator {

     /**
-     * Set the shared storage context in runtime.
+     * Set the shared objects context at runtime.
      *
-     * @param context The shared storage context.
+     * @param context The context for shared objects.
      */
-    void onSharedStorageContextSet(SharedStorageContext context);
+    void onSharedObjectsContextSet(SharedObjectsContext context);

     /**
      * Get a unique ID to represent the operator instance. The ID must be kept unchanged through its
      * whole lifetime.
      *
      * @return A unique ID.
      */
-    String getSharedStorageAccessorID();
+    String getSharedObjectsAccessorID();
 }
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageUtils.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsUtils.java
similarity index 72%
rename from flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageUtils.java
rename to flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsUtils.java
index 0276f1c0a..022469bb8 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageUtils.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsUtils.java
@@ -16,9 +16,10 @@
  * limitations under the License.
  */

-package org.apache.flink.ml.common.sharedstorage;
+package org.apache.flink.ml.common.sharedobjects;

 import org.apache.flink.annotation.Experimental;
+import org.apache.flink.api.dag.Transformation;
 import org.apache.flink.iteration.compile.DraftExecutionEnvironment;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -31,16 +32,16 @@ import java.util.UUID;
 import java.util.stream.Collectors;

-/** Utility class to support {@link SharedStorage} in DataStream. */
+/** Utility class to support the shared objects mechanism in DataStream. */
 @Experimental
-public class SharedStorageUtils {
+public class SharedObjectsUtils {

     /**
-     * Support read/write access of data in the shared storage from operators which implements
-     * {@link SharedStorageStreamOperator}.
+     * Supports read/write access of data in the shared objects from operators which implement
+     * {@link SharedObjectsStreamOperator}.
      *
-     *

In the shared storage `body`, users build the subgraph with data streams only from - * `inputs`, return streams that have access to the shared storage, and return the mapping from + *

In the shared objects `body`, users build the subgraph with data streams only from + * `inputs`, return streams that have access to the shared objects, and return the mapping from * shared items to their owners. * * @param inputs Input data streams. @@ -48,28 +49,28 @@ public class SharedStorageUtils { * item. * @return The output data streams. */ - public static List> withSharedStorage( - List> inputs, SharedStorageBody body) { + public static List> withSharedObjects( + List> inputs, SharedObjectsBody body) { Preconditions.checkArgument(inputs.size() > 0); StreamExecutionEnvironment env = inputs.get(0).getExecutionEnvironment(); String coLocationID = "shared-storage-" + UUID.randomUUID(); - SharedStorageContextImpl context = new SharedStorageContextImpl(); + SharedObjectsContextImpl context = new SharedObjectsContextImpl(); DraftExecutionEnvironment draftEnv = - new DraftExecutionEnvironment(env, new SharedStorageWrapper<>(context)); + new DraftExecutionEnvironment(env, new SharedObjectsWrapper<>(context)); List> draftSources = inputs.stream() .map( dataStream -> draftEnv.addDraftSource(dataStream, dataStream.getType())) .collect(Collectors.toList()); - SharedStorageBody.SharedStorageBodyResult result = body.process(draftSources); + SharedObjectsBody.SharedObjectsBodyResult result = body.process(draftSources); List> draftOutputs = result.getOutputs(); - Map, SharedStorageStreamOperator> rawOwnerMap = result.getOwnerMap(); + Map, SharedObjectsStreamOperator> rawOwnerMap = result.getOwnerMap(); Map, String> ownerMap = new HashMap<>(); for (ItemDescriptor item : rawOwnerMap.keySet()) { - ownerMap.put(item, rawOwnerMap.get(item).getSharedStorageAccessorID()); + ownerMap.put(item, rawOwnerMap.get(item).getSharedObjectsAccessorID()); } context.setOwnerMap(ownerMap); @@ -78,8 +79,8 @@ public static List> withSharedStorage( } draftEnv.copyToActualEnvironment(); - for (DataStream accessor : result.getAccessors()) { - DataStream ds = draftEnv.getActualStream(accessor.getTransformation().getId()); + for (Transformation transformation : result.getCoLocatedTransformations()) { + DataStream ds = draftEnv.getActualStream(transformation.getId()); ds.getTransformation().setCoLocationGroupKey(coLocationID); } diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageWrapper.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsWrapper.java similarity index 84% rename from flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageWrapper.java rename to flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsWrapper.java index 357f0899d..9c0c06564 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/SharedStorageWrapper.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsWrapper.java @@ -16,7 +16,7 @@ * limitations under the License. */ -package org.apache.flink.ml.common.sharedstorage; +package org.apache.flink.ml.common.sharedobjects; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.functions.KeySelector; @@ -31,13 +31,13 @@ import org.apache.flink.streaming.runtime.tasks.StreamTask; import org.apache.flink.util.OutputTag; -/** The operator wrapper for {@link AbstractSharedStorageWrapperOperator}. */ -class SharedStorageWrapper implements OperatorWrapper { +/** The operator wrapper for {@link AbstractSharedObjectsWrapperOperator}. 
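+ *
+ * <p>Only operators whose class implements {@link SharedObjectsStreamOperator} are wrapped (as
+ * one-input or two-input wrapper operators); all other operators are passed through unchanged.
+ * A minimal sketch of the entry point that installs this wrapper, mirroring the test in this
+ * patch ({@code MyBodies::sharedObjectsBody} is a hypothetical {@link SharedObjectsBody}):
+ *
+ * <pre>{@code
+ * List<DataStream<?>> outputs =
+ *         SharedObjectsUtils.withSharedObjects(
+ *                 Collections.singletonList(input), MyBodies::sharedObjectsBody);
+ * }</pre>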
*/ +class SharedObjectsWrapper implements OperatorWrapper { - /** Shared storage context. */ - private final SharedStorageContextImpl context; + /** Shared objects context. */ + private final SharedObjectsContextImpl context; - public SharedStorageWrapper(SharedStorageContextImpl context) { + public SharedObjectsWrapper(SharedObjectsContextImpl context) { this.context = context; } @@ -47,12 +47,12 @@ public StreamOperator wrap( StreamOperatorFactory operatorFactory) { Class operatorClass = operatorFactory.getStreamOperatorClass(getClass().getClassLoader()); - if (SharedStorageStreamOperator.class.isAssignableFrom(operatorClass)) { + if (SharedObjectsStreamOperator.class.isAssignableFrom(operatorClass)) { if (OneInputStreamOperator.class.isAssignableFrom(operatorClass)) { - return new OneInputSharedStorageWrapperOperator<>( + return new OneInputSharedObjectsWrapperOperator<>( operatorParameters, operatorFactory, context); } else if (TwoInputStreamOperator.class.isAssignableFrom(operatorClass)) { - return new TwoInputSharedStorageWrapperOperator<>( + return new TwoInputSharedObjectsWrapperOperator<>( operatorParameters, operatorFactory, context); } else { return nowrap(operatorParameters, operatorFactory); @@ -79,12 +79,12 @@ public Class getStreamOperatorClass( Class operatorClass = operatorFactory.getStreamOperatorClass(getClass().getClassLoader()); if (OneInputStreamOperator.class.isAssignableFrom(operatorClass)) { - return OneInputSharedStorageWrapperOperator.class; + return OneInputSharedObjectsWrapperOperator.class; } else if (TwoInputStreamOperator.class.isAssignableFrom(operatorClass)) { - return TwoInputSharedStorageWrapperOperator.class; + return TwoInputSharedObjectsWrapperOperator.class; } else { throw new UnsupportedOperationException( - "Unsupported operator class for shared storage wrapper: " + operatorClass); + "Unsupported operator class for shared objects wrapper: " + operatorClass); } } diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/TwoInputSharedStorageWrapperOperator.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/TwoInputSharedObjectsWrapperOperator.java similarity index 92% rename from flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/TwoInputSharedStorageWrapperOperator.java rename to flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/TwoInputSharedObjectsWrapperOperator.java index 2b23b7b4e..306c84517 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedstorage/TwoInputSharedStorageWrapperOperator.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/TwoInputSharedObjectsWrapperOperator.java @@ -16,7 +16,7 @@ * limitations under the License. */ -package org.apache.flink.ml.common.sharedstorage; +package org.apache.flink.ml.common.sharedobjects; import org.apache.flink.iteration.operator.OperatorUtils; import org.apache.flink.streaming.api.operators.BoundedMultiInput; @@ -29,14 +29,14 @@ import org.apache.flink.streaming.runtime.watermarkstatus.WatermarkStatus; /** Wrapper for {@link TwoInputStreamOperator}. 
*/ -class TwoInputSharedStorageWrapperOperator - extends AbstractSharedStorageWrapperOperator> +class TwoInputSharedObjectsWrapperOperator + extends AbstractSharedObjectsWrapperOperator> implements TwoInputStreamOperator, BoundedMultiInput { - TwoInputSharedStorageWrapperOperator( + TwoInputSharedObjectsWrapperOperator( StreamOperatorParameters parameters, StreamOperatorFactory operatorFactory, - SharedStorageContextImpl context) { + SharedObjectsContextImpl context) { super(parameters, operatorFactory, context); } diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/common/sharedstorage/SharedStorageUtilsTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsUtilsTest.java similarity index 74% rename from flink-ml-core/src/test/java/org/apache/flink/ml/common/sharedstorage/SharedStorageUtilsTest.java rename to flink-ml-core/src/test/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsUtilsTest.java index 29fcdda15..e5241ba88 100644 --- a/flink-ml-core/src/test/java/org/apache/flink/ml/common/sharedstorage/SharedStorageUtilsTest.java +++ b/flink-ml-core/src/test/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsUtilsTest.java @@ -16,7 +16,7 @@ * limitations under the License. */ -package org.apache.flink.ml.common.sharedstorage; +package org.apache.flink.ml.common.sharedobjects; import org.apache.flink.api.common.functions.ReduceFunction; import org.apache.flink.api.common.typeinfo.TypeInformation; @@ -44,14 +44,14 @@ import java.util.Map; import java.util.UUID; -/** Tests the {@link SharedStorageUtils}. */ -public class SharedStorageUtilsTest { +/** Tests the {@link SharedObjectsUtils}. */ +public class SharedObjectsUtilsTest { private static final ItemDescriptor SUM = ItemDescriptor.of("sum", LongSerializer.INSTANCE, 0L); @Rule public TemporaryFolder tempFolder = new TemporaryFolder(); - static SharedStorageBody.SharedStorageBodyResult sharedStorageBody(List> inputs) { + static SharedObjectsBody.SharedObjectsBodyResult sharedObjectsBody(List> inputs) { //noinspection unchecked DataStream data = (DataStream) inputs.get(0); @@ -63,21 +63,23 @@ static SharedStorageBody.SharedStorageBodyResult sharedStorageBody(List afterBOp = afterAOp.transform("b", TypeInformation.of(Long.class), bOp); - Map, SharedStorageStreamOperator> ownerMap = new HashMap<>(); + Map, SharedObjectsStreamOperator> ownerMap = new HashMap<>(); ownerMap.put(SUM, aOp); - return new SharedStorageBody.SharedStorageBodyResult( - Collections.singletonList(afterBOp), Arrays.asList(afterAOp, afterBOp), ownerMap); + return new SharedObjectsBody.SharedObjectsBodyResult( + Collections.singletonList(afterBOp), + Arrays.asList(afterAOp.getTransformation(), afterBOp.getTransformation()), + ownerMap); } @Test - public void testSharedStorage() throws Exception { + public void testSharedObjects() throws Exception { StreamExecutionEnvironment env = TestUtils.getExecutionEnvironment(); DataStream data = env.fromSequence(1, 100); List> outputs = - SharedStorageUtils.withSharedStorage( - Collections.singletonList(data), SharedStorageUtilsTest::sharedStorageBody); + SharedObjectsUtils.withSharedObjects( + Collections.singletonList(data), SharedObjectsUtilsTest::sharedObjectsBody); //noinspection unchecked DataStream partitionSum = (DataStream) outputs.get(0); DataStream allSum = DataStreamUtils.reduce(partitionSum, new SumReduceFunction()); @@ -90,29 +92,29 @@ public void testSharedStorage() throws Exception { /** Operator A: add input elements to the shared {@link #SUM}. 
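     *
     * <p>AOperator is registered as the owner of {@link #SUM} via {@code ownerMap.put(SUM, aOp)}
     * above, so its subtasks obtain the writer; {@code BOperator} only reads the accumulated value.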
*/ static class AOperator extends AbstractStreamOperator implements OneInputStreamOperator, - SharedStorageStreamOperator, + SharedObjectsStreamOperator, BoundedOneInput { - private final String sharedStorageAccessorID; - private SharedStorageContext sharedStorageContext; + private final String sharedObjectsAccessorID; + private SharedObjectsContext sharedObjectsContext; public AOperator() { - sharedStorageAccessorID = getClass().getSimpleName() + "-" + UUID.randomUUID(); + sharedObjectsAccessorID = getClass().getSimpleName() + "-" + UUID.randomUUID(); } @Override - public void onSharedStorageContextSet(SharedStorageContext context) { - this.sharedStorageContext = context; + public void onSharedObjectsContextSet(SharedObjectsContext context) { + this.sharedObjectsContext = context; } @Override - public String getSharedStorageAccessorID() { - return sharedStorageAccessorID; + public String getSharedObjectsAccessorID() { + return sharedObjectsAccessorID; } @Override public void processElement(StreamRecord element) throws Exception { - sharedStorageContext.invoke( + sharedObjectsContext.invoke( (getter, setter) -> { Long currentSum = getter.get(SUM); setter.set(SUM, currentSum + element.getValue()); @@ -128,28 +130,28 @@ public void endInput() throws Exception { /** Operator B: when input ends, get the value from shared {@link #SUM}. */ static class BOperator extends AbstractStreamOperator - implements OneInputStreamOperator, SharedStorageStreamOperator { + implements OneInputStreamOperator, SharedObjectsStreamOperator { - private final String sharedStorageAccessorID; - private SharedStorageContext sharedStorageContext; + private final String sharedObjectsAccessorID; + private SharedObjectsContext sharedObjectsContext; public BOperator() { - sharedStorageAccessorID = getClass().getSimpleName() + "-" + UUID.randomUUID(); + sharedObjectsAccessorID = getClass().getSimpleName() + "-" + UUID.randomUUID(); } @Override - public void onSharedStorageContextSet(SharedStorageContext context) { - this.sharedStorageContext = context; + public void onSharedObjectsContextSet(SharedObjectsContext context) { + this.sharedObjectsContext = context; } @Override - public String getSharedStorageAccessorID() { - return sharedStorageAccessorID; + public String getSharedObjectsAccessorID() { + return sharedObjectsAccessorID; } @Override public void processElement(StreamRecord element) throws Exception { - sharedStorageContext.invoke( + sharedObjectsContext.invoke( (getter, setter) -> { output.collect(new StreamRecord<>(getter.get(SUM))); }); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BaseGBTParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BaseGBTParams.java index d2ae5b446..a8c8272e4 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BaseGBTParams.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BaseGBTParams.java @@ -30,7 +30,7 @@ /** * Common parameters for GBT classifier and regressor. * - *

NOTE: Features related with {@link #WEIGHT_COL}, {@link #LEAF_COL}, and {@link + *

NOTE: Features related to {@link #WEIGHT_COL}, {@link #LEAF_COL}, and {@link * #VALIDATION_INDICATOR_COL} are not implemented yet. * * @param The class type of this instance. @@ -43,63 +43,75 @@ public interface BaseGBTParams "Regularization term for the number of leaves.", 0., ParamValidators.gtEq(0.)); + Param REG_GAMMA = new DoubleParam( "regGamma", "L2 regularization term for the weights of leaves.", 1., ParamValidators.gtEq(0)); + Param LEAF_COL = new StringParam("leafCol", "Predicted leaf index of each instance in each tree.", null); + Param MAX_DEPTH = new IntParam("maxDepth", "Maximum depth of the tree.", 5, ParamValidators.gtEq(1)); + Param MAX_BINS = new IntParam( "maxBins", "Maximum number of bins used for discretizing continuous features.", 32, ParamValidators.gtEq(2)); + Param MIN_INSTANCES_PER_NODE = new IntParam( "minInstancesPerNode", "Minimum number of instances each node must have. If a split causes the left or right child to have fewer instances than minInstancesPerNode, the split is invalid.", 1, ParamValidators.gtEq(1)); + Param MIN_WEIGHT_FRACTION_PER_NODE = new DoubleParam( "minWeightFractionPerNode", "Minimum fraction of the weighted sample count that each node must have. If a split causes the left or right child to have a smaller fraction of the total weight than minWeightFractionPerNode, the split is invalid.", 0., ParamValidators.gtEq(0.)); + Param MIN_INFO_GAIN = new DoubleParam( "minInfoGain", "Minimum information gain for a split to be considered valid.", 0., ParamValidators.gtEq(0.)); + Param STEP_SIZE = new DoubleParam( "stepSize", "Step size for shrinking the contribution of each estimator.", 0.1, ParamValidators.inRange(0., 1.)); + Param SUBSAMPLING_RATE = new DoubleParam( "subsamplingRate", "Fraction of the training data used for learning one tree.", 1., ParamValidators.inRange(0., 1.)); + Param FEATURE_SUBSET_STRATEGY = new StringParam( "featureSubsetStrategy.", "Fraction of the training data used for learning one tree. 
Supports \"auto\", \"all\", \"onethird\", \"sqrt\", \"log2\", (0.0 - 1.0], and [1 - n].", "auto", ParamValidators.notNull()); + Param VALIDATION_INDICATOR_COL = new StringParam( "validationIndicatorCol", "The name of the column that indicates whether each row is for training or for validation.", null); + Param VALIDATION_TOL = new DoubleParam( "validationTol", diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BoostIterationBody.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BoostIterationBody.java index 85be1c91c..3da40a9ca 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BoostIterationBody.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BoostIterationBody.java @@ -34,12 +34,12 @@ import org.apache.flink.ml.common.gbt.operators.PostSplitsOperator; import org.apache.flink.ml.common.gbt.operators.ReduceHistogramFunction; import org.apache.flink.ml.common.gbt.operators.ReduceSplitsOperator; -import org.apache.flink.ml.common.gbt.operators.SharedStorageConstants; +import org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants; import org.apache.flink.ml.common.gbt.operators.TerminationOperator; -import org.apache.flink.ml.common.sharedstorage.ItemDescriptor; -import org.apache.flink.ml.common.sharedstorage.SharedStorageBody; -import org.apache.flink.ml.common.sharedstorage.SharedStorageStreamOperator; -import org.apache.flink.ml.common.sharedstorage.SharedStorageUtils; +import org.apache.flink.ml.common.sharedobjects.ItemDescriptor; +import org.apache.flink.ml.common.sharedobjects.SharedObjectsBody; +import org.apache.flink.ml.common.sharedobjects.SharedObjectsStreamOperator; +import org.apache.flink.ml.common.sharedobjects.SharedObjectsUtils; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; import org.apache.flink.types.Row; @@ -61,14 +61,14 @@ public BoostIterationBody(BoostingStrategy strategy) { this.strategy = strategy; } - private SharedStorageBody.SharedStorageBodyResult sharedStorageBody( + private SharedObjectsBody.SharedObjectsBodyResult sharedObjectsBody( List> inputs) { //noinspection unchecked DataStream data = (DataStream) inputs.get(0); //noinspection unchecked DataStream trainContext = (DataStream) inputs.get(1); - Map, SharedStorageStreamOperator> ownerMap = new HashMap<>(); + Map, SharedObjectsStreamOperator> ownerMap = new HashMap<>(); CacheDataCalcLocalHistsOperator cacheDataCalcLocalHistsOp = new CacheDataCalcLocalHistsOperator(strategy); @@ -79,7 +79,7 @@ private SharedStorageBody.SharedStorageBodyResult sharedStorageBody( Types.TUPLE( Types.INT, Types.INT, TypeInformation.of(Histogram.class)), cacheDataCalcLocalHistsOp); - for (ItemDescriptor s : SharedStorageConstants.OWNED_BY_CACHE_DATA_CALC_LOCAL_HISTS_OP) { + for (ItemDescriptor s : SharedObjectsConstants.OWNED_BY_CACHE_DATA_CALC_LOCAL_HISTS_OP) { ownerMap.put(s, cacheDataCalcLocalHistsOp); } @@ -105,7 +105,7 @@ private SharedStorageBody.SharedStorageBodyResult sharedStorageBody( globalSplits .broadcast() .transform("PostSplits", TypeInformation.of(Integer.class), postSplitsOp); - for (ItemDescriptor descriptor : SharedStorageConstants.OWNED_BY_POST_SPLITS_OP) { + for (ItemDescriptor descriptor : SharedObjectsConstants.OWNED_BY_POST_SPLITS_OP) { ownerMap.put(descriptor, postSplitsOp); } @@ -119,9 +119,14 @@ private SharedStorageBody.SharedStorageBodyResult sharedStorageBody( DataStream finalModelData = 
termination.getSideOutput(finalModelDataOutputTag); - return new SharedStorageBody.SharedStorageBodyResult( + return new SharedObjectsBody.SharedObjectsBodyResult( Arrays.asList(updatedModelData, finalModelData, termination), - Arrays.asList(localHists, localSplits, globalSplits, updatedModelData, termination), + Arrays.asList( + localHists.getTransformation(), + localSplits.getTransformation(), + globalSplits.getTransformation(), + updatedModelData.getTransformation(), + termination.getTransformation()), ownerMap); } @@ -131,8 +136,8 @@ public IterationBodyResult process(DataStreamList variableStreams, DataStreamLis DataStream trainContext = variableStreams.get(0); List> outputs = - SharedStorageUtils.withSharedStorage( - Arrays.asList(data, trainContext), this::sharedStorageBody); + SharedObjectsUtils.withSharedObjects( + Arrays.asList(data, trainContext), this::sharedObjectsBody); DataStream updatedModelData = outputs.get(0); DataStream finalModelData = outputs.get(1); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java index acbf2b6b6..bedaccff5 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java @@ -30,8 +30,8 @@ import org.apache.flink.ml.common.gbt.defs.TrainContext; import org.apache.flink.ml.common.gbt.typeinfo.BinnedInstanceSerializer; import org.apache.flink.ml.common.lossfunc.LossFunc; -import org.apache.flink.ml.common.sharedstorage.SharedStorageContext; -import org.apache.flink.ml.common.sharedstorage.SharedStorageStreamOperator; +import org.apache.flink.ml.common.sharedobjects.SharedObjectsContext; +import org.apache.flink.ml.common.sharedobjects.SharedObjectsStreamOperator; import org.apache.flink.ml.linalg.SparseVector; import org.apache.flink.ml.linalg.Vector; import org.apache.flink.runtime.state.StateInitializationContext; @@ -61,13 +61,13 @@ public class CacheDataCalcLocalHistsOperator extends AbstractStreamOperator> implements TwoInputStreamOperator>, IterationListener>, - SharedStorageStreamOperator { + SharedObjectsStreamOperator { private static final String TREE_INITIALIZER_STATE_NAME = "tree_initializer"; private static final String HIST_BUILDER_STATE_NAME = "hist_builder"; private final BoostingStrategy strategy; - private final String sharedStorageAccessorID; + private final String sharedObjectsAccessorID; // States of local data. 
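    // The cached instances and the tree/histogram helpers below are stored in ListStateWithCache
    // so that they are included in checkpoints and restored after failover.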
private transient ListStateWithCache instancesCollecting; @@ -75,12 +75,12 @@ public class CacheDataCalcLocalHistsOperator private transient TreeInitializer treeInitializer; private transient ListStateWithCache histBuilderState; private transient HistBuilder histBuilder; - private transient SharedStorageContext sharedStorageContext; + private transient SharedObjectsContext sharedObjectsContext; public CacheDataCalcLocalHistsOperator(BoostingStrategy strategy) { super(); this.strategy = strategy; - sharedStorageAccessorID = getClass().getSimpleName() + "-" + UUID.randomUUID(); + sharedObjectsAccessorID = getClass().getSimpleName() + "-" + UUID.randomUUID(); } @Override @@ -149,9 +149,9 @@ public void processElement1(StreamRecord streamRecord) throws Exception { @Override public void processElement2(StreamRecord streamRecord) throws Exception { TrainContext rawTrainContext = streamRecord.getValue(); - sharedStorageContext.invoke( + sharedObjectsContext.invoke( (getter, setter) -> - setter.set(SharedStorageConstants.TRAIN_CONTEXT, rawTrainContext)); + setter.set(SharedObjectsConstants.TRAIN_CONTEXT, rawTrainContext)); } public void onEpochWatermarkIncremented( @@ -159,18 +159,18 @@ public void onEpochWatermarkIncremented( throws Exception { if (0 == epochWatermark) { // Initializes local state in first round. - sharedStorageContext.invoke( + sharedObjectsContext.invoke( (getter, setter) -> { BinnedInstance[] instances = (BinnedInstance[]) IteratorUtils.toArray( instancesCollecting.get().iterator(), BinnedInstance.class); - setter.set(SharedStorageConstants.INSTANCES, instances); + setter.set(SharedObjectsConstants.INSTANCES, instances); instancesCollecting.clear(); TrainContext rawTrainContext = - getter.get(SharedStorageConstants.TRAIN_CONTEXT); + getter.get(SharedObjectsConstants.TRAIN_CONTEXT); TrainContext trainContext = new TrainContextInitializer(strategy) .init( @@ -178,7 +178,7 @@ public void onEpochWatermarkIncremented( getRuntimeContext().getIndexOfThisSubtask(), getRuntimeContext().getNumberOfParallelSubtasks(), instances); - setter.set(SharedStorageConstants.TRAIN_CONTEXT, trainContext); + setter.set(SharedObjectsConstants.TRAIN_CONTEXT, trainContext); treeInitializer = new TreeInitializer(trainContext); treeInitializerState.update(Collections.singletonList(treeInitializer)); @@ -187,13 +187,13 @@ public void onEpochWatermarkIncremented( }); } - sharedStorageContext.invoke( + sharedObjectsContext.invoke( (getter, setter) -> { - TrainContext trainContext = getter.get(SharedStorageConstants.TRAIN_CONTEXT); + TrainContext trainContext = getter.get(SharedObjectsConstants.TRAIN_CONTEXT); Preconditions.checkArgument( getRuntimeContext().getIndexOfThisSubtask() == trainContext.subtaskId); - BinnedInstance[] instances = getter.get(SharedStorageConstants.INSTANCES); - double[] pgh = getter.get(SharedStorageConstants.PREDS_GRADS_HESSIANS); + BinnedInstance[] instances = getter.get(SharedObjectsConstants.INSTANCES); + double[] pgh = getter.get(SharedObjectsConstants.PREDS_GRADS_HESSIANS); // In the first round, use prior as the predictions. if (0 == pgh.length) { pgh = new double[instances.length * 3]; @@ -207,26 +207,26 @@ public void onEpochWatermarkIncremented( } } - boolean needInitTree = getter.get(SharedStorageConstants.NEED_INIT_TREE); + boolean needInitTree = getter.get(SharedObjectsConstants.NEED_INIT_TREE); int[] indices; List layer; if (needInitTree) { // When last tree is finished, initializes a new tree, and shuffle instance // indices. 
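                    // Note: SHUFFLED_INDICES holds the fresh random order produced when a tree is
                    // initialized, while SWAPPED_INDICES (read in the else-branch below) reuses
                    // the order left behind by the previous layer's node splits.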
treeInitializer.init( - getter.get(SharedStorageConstants.ALL_TREES).size(), - d -> setter.set(SharedStorageConstants.SHUFFLED_INDICES, d)); + getter.get(SharedObjectsConstants.ALL_TREES).size(), + d -> setter.set(SharedObjectsConstants.SHUFFLED_INDICES, d)); LearningNode rootLearningNode = treeInitializer.getRootLearningNode(); - indices = getter.get(SharedStorageConstants.SHUFFLED_INDICES); + indices = getter.get(SharedObjectsConstants.SHUFFLED_INDICES); layer = Collections.singletonList(rootLearningNode); - setter.set(SharedStorageConstants.ROOT_LEARNING_NODE, rootLearningNode); - setter.set(SharedStorageConstants.HAS_INITED_TREE, true); + setter.set(SharedObjectsConstants.ROOT_LEARNING_NODE, rootLearningNode); + setter.set(SharedObjectsConstants.HAS_INITED_TREE, true); } else { // Otherwise, uses the swapped instance indices. - indices = getter.get(SharedStorageConstants.SWAPPED_INDICES); - layer = getter.get(SharedStorageConstants.LAYER); - setter.set(SharedStorageConstants.SHUFFLED_INDICES, new int[0]); - setter.set(SharedStorageConstants.HAS_INITED_TREE, false); + indices = getter.get(SharedObjectsConstants.SWAPPED_INDICES); + layer = getter.get(SharedObjectsConstants.LAYER); + setter.set(SharedObjectsConstants.SHUFFLED_INDICES, new int[0]); + setter.set(SharedObjectsConstants.HAS_INITED_TREE, false); } histBuilder.build( @@ -234,7 +234,7 @@ public void onEpochWatermarkIncremented( indices, instances, pgh, - d -> setter.set(SharedStorageConstants.NODE_FEATURE_PAIRS, d), + d -> setter.set(SharedObjectsConstants.NODE_FEATURE_PAIRS, d), out); }); } @@ -247,11 +247,11 @@ public void onIterationTerminated( treeInitializerState.clear(); histBuilderState.clear(); - sharedStorageContext.invoke( + sharedObjectsContext.invoke( (getter, setter) -> { - setter.set(SharedStorageConstants.INSTANCES, new BinnedInstance[0]); - setter.set(SharedStorageConstants.SHUFFLED_INDICES, new int[0]); - setter.set(SharedStorageConstants.NODE_FEATURE_PAIRS, new int[0]); + setter.set(SharedObjectsConstants.INSTANCES, new BinnedInstance[0]); + setter.set(SharedObjectsConstants.SHUFFLED_INDICES, new int[0]); + setter.set(SharedObjectsConstants.NODE_FEATURE_PAIRS, new int[0]); }); } @@ -264,12 +264,12 @@ public void close() throws Exception { } @Override - public void onSharedStorageContextSet(SharedStorageContext context) { - this.sharedStorageContext = context; + public void onSharedObjectsContextSet(SharedObjectsContext context) { + this.sharedObjectsContext = context; } @Override - public String getSharedStorageAccessorID() { - return sharedStorageAccessorID; + public String getSharedObjectsAccessorID() { + return sharedObjectsAccessorID; } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CalcLocalSplitsOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CalcLocalSplitsOperator.java index 03ccf8342..b55b1f87d 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CalcLocalSplitsOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CalcLocalSplitsOperator.java @@ -26,8 +26,8 @@ import org.apache.flink.ml.common.gbt.defs.Histogram; import org.apache.flink.ml.common.gbt.defs.LearningNode; import org.apache.flink.ml.common.gbt.defs.Split; -import org.apache.flink.ml.common.sharedstorage.SharedStorageContext; -import org.apache.flink.ml.common.sharedstorage.SharedStorageStreamOperator; +import org.apache.flink.ml.common.sharedobjects.SharedObjectsContext; +import 
org.apache.flink.ml.common.sharedobjects.SharedObjectsStreamOperator; import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.StateSnapshotContext; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; @@ -50,18 +50,18 @@ public class CalcLocalSplitsOperator extends AbstractStreamOperator> implements OneInputStreamOperator< Tuple2, Tuple3>, - SharedStorageStreamOperator { + SharedObjectsStreamOperator { private static final Logger LOG = LoggerFactory.getLogger(CalcLocalSplitsOperator.class); private static final String SPLIT_FINDER_STATE_NAME = "split_finder"; - private final String sharedStorageAccessorID; + private final String sharedObjectsAccessorID; // States of local data. private transient ListStateWithCache splitFinderState; private transient SplitFinder splitFinder; - private transient SharedStorageContext sharedStorageContext; + private transient SharedObjectsContext sharedObjectsContext; public CalcLocalSplitsOperator() { - sharedStorageAccessorID = getClass().getSimpleName() + "-" + UUID.randomUUID(); + sharedObjectsAccessorID = getClass().getSimpleName() + "-" + UUID.randomUUID(); } @Override @@ -88,10 +88,10 @@ public void snapshotState(StateSnapshotContext context) throws Exception { @Override public void processElement(StreamRecord> element) throws Exception { if (null == splitFinder) { - sharedStorageContext.invoke( + sharedObjectsContext.invoke( (getter, setter) -> { splitFinder = - new SplitFinder(getter.get(SharedStorageConstants.TRAIN_CONTEXT)); + new SplitFinder(getter.get(SharedObjectsConstants.TRAIN_CONTEXT)); splitFinderState.update(Collections.singletonList(splitFinder)); }); } @@ -100,16 +100,16 @@ public void processElement(StreamRecord> element) thr int pairId = value.f0; Histogram histogram = value.f1; LOG.debug("Received histogram for pairId: {}", pairId); - sharedStorageContext.invoke( + sharedObjectsContext.invoke( (getter, setter) -> { - List layer = getter.get(SharedStorageConstants.LAYER); + List layer = getter.get(SharedObjectsConstants.LAYER); if (layer.size() == 0) { layer = Collections.singletonList( - getter.get(SharedStorageConstants.ROOT_LEARNING_NODE)); + getter.get(SharedObjectsConstants.ROOT_LEARNING_NODE)); } - int[] nodeFeaturePairs = getter.get(SharedStorageConstants.NODE_FEATURE_PAIRS); + int[] nodeFeaturePairs = getter.get(SharedObjectsConstants.NODE_FEATURE_PAIRS); int nodeId = nodeFeaturePairs[2 * pairId]; int featureId = nodeFeaturePairs[2 * pairId + 1]; LearningNode node = layer.get(nodeId); @@ -118,7 +118,7 @@ public void processElement(StreamRecord> element) thr splitFinder.calc( node, featureId, - getter.get(SharedStorageConstants.LEAVES).size(), + getter.get(SharedObjectsConstants.LEAVES).size(), histogram); output.collect(new StreamRecord<>(Tuple3.of(nodeId, pairId, bestSplit))); }); @@ -132,12 +132,12 @@ public void close() throws Exception { } @Override - public void onSharedStorageContextSet(SharedStorageContext context) { - this.sharedStorageContext = context; + public void onSharedObjectsContextSet(SharedObjectsContext context) { + this.sharedObjectsContext = context; } @Override - public String getSharedStorageAccessorID() { - return sharedStorageAccessorID; + public String getSharedObjectsAccessorID() { + return sharedObjectsAccessorID; } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/PostSplitsOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/PostSplitsOperator.java index 
1ba439de9..a17b95ce6 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/PostSplitsOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/PostSplitsOperator.java @@ -28,8 +28,8 @@ import org.apache.flink.ml.common.gbt.defs.Node; import org.apache.flink.ml.common.gbt.defs.Split; import org.apache.flink.ml.common.gbt.defs.TrainContext; -import org.apache.flink.ml.common.sharedstorage.SharedStorageContext; -import org.apache.flink.ml.common.sharedstorage.SharedStorageStreamOperator; +import org.apache.flink.ml.common.sharedobjects.SharedObjectsContext; +import org.apache.flink.ml.common.sharedobjects.SharedObjectsStreamOperator; import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.StateSnapshotContext; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; @@ -52,14 +52,14 @@ public class PostSplitsOperator extends AbstractStreamOperator implements OneInputStreamOperator, Integer>, IterationListener, - SharedStorageStreamOperator { + SharedObjectsStreamOperator { private static final String NODE_SPLITTER_STATE_NAME = "node_splitter"; private static final String INSTANCE_UPDATER_STATE_NAME = "instance_updater"; private static final Logger LOG = LoggerFactory.getLogger(PostSplitsOperator.class); - private final String sharedStorageAccessorID; + private final String sharedObjectsAccessorID; // States of local data. private transient Split[] nodeSplits; @@ -67,10 +67,10 @@ public class PostSplitsOperator extends AbstractStreamOperator private transient NodeSplitter nodeSplitter; private transient ListStateWithCache instanceUpdaterState; private transient InstanceUpdater instanceUpdater; - private transient SharedStorageContext sharedStorageContext; + private transient SharedObjectsContext sharedObjectsContext; public PostSplitsOperator() { - sharedStorageAccessorID = getClass().getSimpleName() + "-" + UUID.randomUUID(); + sharedObjectsAccessorID = getClass().getSimpleName() + "-" + UUID.randomUUID(); } @Override @@ -111,10 +111,10 @@ public void snapshotState(StateSnapshotContext context) throws Exception { public void onEpochWatermarkIncremented( int epochWatermark, Context context, Collector collector) throws Exception { if (0 == epochWatermark) { - sharedStorageContext.invoke( + sharedObjectsContext.invoke( (getter, setter) -> { TrainContext trainContext = - getter.get(SharedStorageConstants.TRAIN_CONTEXT); + getter.get(SharedObjectsConstants.TRAIN_CONTEXT); nodeSplitter = new NodeSplitter(trainContext); nodeSplitterState.update(Collections.singletonList(nodeSplitter)); instanceUpdater = new InstanceUpdater(trainContext); @@ -122,25 +122,25 @@ public void onEpochWatermarkIncremented( }); } - sharedStorageContext.invoke( + sharedObjectsContext.invoke( (getter, setter) -> { - int[] indices = getter.get(SharedStorageConstants.SWAPPED_INDICES); + int[] indices = getter.get(SharedObjectsConstants.SWAPPED_INDICES); if (0 == indices.length) { - indices = getter.get(SharedStorageConstants.SHUFFLED_INDICES).clone(); + indices = getter.get(SharedObjectsConstants.SHUFFLED_INDICES).clone(); } - BinnedInstance[] instances = getter.get(SharedStorageConstants.INSTANCES); - List leaves = getter.get(SharedStorageConstants.LEAVES); - List layer = getter.get(SharedStorageConstants.LAYER); + BinnedInstance[] instances = getter.get(SharedObjectsConstants.INSTANCES); + List leaves = getter.get(SharedObjectsConstants.LEAVES); + List layer = getter.get(SharedObjectsConstants.LAYER); List 
currentTreeNodes; if (layer.size() == 0) { layer = Collections.singletonList( - getter.get(SharedStorageConstants.ROOT_LEARNING_NODE)); + getter.get(SharedObjectsConstants.ROOT_LEARNING_NODE)); currentTreeNodes = new ArrayList<>(); currentTreeNodes.add(new Node()); } else { - currentTreeNodes = getter.get(SharedStorageConstants.CURRENT_TREE_NODES); + currentTreeNodes = getter.get(SharedObjectsConstants.CURRENT_TREE_NODES); } List nextLayer = @@ -152,31 +152,31 @@ public void onEpochWatermarkIncremented( indices, instances); nodeSplits = null; - setter.set(SharedStorageConstants.LEAVES, leaves); - setter.set(SharedStorageConstants.LAYER, nextLayer); - setter.set(SharedStorageConstants.CURRENT_TREE_NODES, currentTreeNodes); + setter.set(SharedObjectsConstants.LEAVES, leaves); + setter.set(SharedObjectsConstants.LAYER, nextLayer); + setter.set(SharedObjectsConstants.CURRENT_TREE_NODES, currentTreeNodes); if (nextLayer.isEmpty()) { // Current tree is finished. - setter.set(SharedStorageConstants.NEED_INIT_TREE, true); + setter.set(SharedObjectsConstants.NEED_INIT_TREE, true); instanceUpdater.update( - getter.get(SharedStorageConstants.PREDS_GRADS_HESSIANS), + getter.get(SharedObjectsConstants.PREDS_GRADS_HESSIANS), leaves, indices, instances, - d -> setter.set(SharedStorageConstants.PREDS_GRADS_HESSIANS, d), + d -> setter.set(SharedObjectsConstants.PREDS_GRADS_HESSIANS, d), currentTreeNodes); leaves.clear(); - List> allTrees = getter.get(SharedStorageConstants.ALL_TREES); + List> allTrees = getter.get(SharedObjectsConstants.ALL_TREES); allTrees.add(currentTreeNodes); - setter.set(SharedStorageConstants.LEAVES, new ArrayList<>()); - setter.set(SharedStorageConstants.SWAPPED_INDICES, new int[0]); - setter.set(SharedStorageConstants.ALL_TREES, allTrees); + setter.set(SharedObjectsConstants.LEAVES, new ArrayList<>()); + setter.set(SharedObjectsConstants.SWAPPED_INDICES, new int[0]); + setter.set(SharedObjectsConstants.ALL_TREES, allTrees); LOG.info("finalize {}-th tree", allTrees.size()); } else { - setter.set(SharedStorageConstants.SWAPPED_INDICES, indices); - setter.set(SharedStorageConstants.NEED_INIT_TREE, false); + setter.set(SharedObjectsConstants.SWAPPED_INDICES, indices); + setter.set(SharedObjectsConstants.NEED_INIT_TREE, false); } }); } @@ -184,22 +184,22 @@ public void onEpochWatermarkIncremented( @Override public void onIterationTerminated(Context context, Collector collector) throws Exception { - sharedStorageContext.invoke( + sharedObjectsContext.invoke( (getter, setter) -> { - setter.set(SharedStorageConstants.PREDS_GRADS_HESSIANS, new double[0]); - setter.set(SharedStorageConstants.SWAPPED_INDICES, new int[0]); - setter.set(SharedStorageConstants.LEAVES, Collections.emptyList()); - setter.set(SharedStorageConstants.LAYER, Collections.emptyList()); - setter.set(SharedStorageConstants.CURRENT_TREE_NODES, Collections.emptyList()); + setter.set(SharedObjectsConstants.PREDS_GRADS_HESSIANS, new double[0]); + setter.set(SharedObjectsConstants.SWAPPED_INDICES, new int[0]); + setter.set(SharedObjectsConstants.LEAVES, Collections.emptyList()); + setter.set(SharedObjectsConstants.LAYER, Collections.emptyList()); + setter.set(SharedObjectsConstants.CURRENT_TREE_NODES, Collections.emptyList()); }); } @Override public void processElement(StreamRecord> element) throws Exception { if (null == nodeSplits) { - sharedStorageContext.invoke( + sharedObjectsContext.invoke( (getter, setter) -> { - List layer = getter.get(SharedStorageConstants.LAYER); + List layer = 
getter.get(SharedObjectsConstants.LAYER); int numNodes = (layer.size() == 0) ? 1 : layer.size(); nodeSplits = new Split[numNodes]; }); @@ -219,12 +219,12 @@ public void close() throws Exception { } @Override - public void onSharedStorageContextSet(SharedStorageContext context) { - sharedStorageContext = context; + public void onSharedObjectsContextSet(SharedObjectsContext context) { + sharedObjectsContext = context; } @Override - public String getSharedStorageAccessorID() { - return sharedStorageAccessorID; + public String getSharedObjectsAccessorID() { + return sharedObjectsAccessorID; } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/ReduceSplitsOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/ReduceSplitsOperator.java index d33a4a8e1..686733246 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/ReduceSplitsOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/ReduceSplitsOperator.java @@ -21,8 +21,8 @@ import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.tuple.Tuple3; import org.apache.flink.ml.common.gbt.defs.Split; -import org.apache.flink.ml.common.sharedstorage.SharedStorageContext; -import org.apache.flink.ml.common.sharedstorage.SharedStorageStreamOperator; +import org.apache.flink.ml.common.sharedobjects.SharedObjectsContext; +import org.apache.flink.ml.common.sharedobjects.SharedObjectsStreamOperator; import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; import org.apache.flink.streaming.api.operators.OneInputStreamOperator; @@ -45,30 +45,30 @@ */ public class ReduceSplitsOperator extends AbstractStreamOperator> implements OneInputStreamOperator, Tuple2>, - SharedStorageStreamOperator { + SharedObjectsStreamOperator { private static final Logger LOG = LoggerFactory.getLogger(ReduceSplitsOperator.class); - private final String sharedStorageAccessorID; + private final String sharedObjectsAccessorID; - private transient SharedStorageContext sharedStorageContext; + private transient SharedObjectsContext sharedObjectsContext; private Map nodeFeatureMap; private Map nodeBestSplit; private Map nodeFeatureCounter; public ReduceSplitsOperator() { - sharedStorageAccessorID = getClass().getSimpleName() + "-" + UUID.randomUUID(); + sharedObjectsAccessorID = getClass().getSimpleName() + "-" + UUID.randomUUID(); } @Override - public void onSharedStorageContextSet(SharedStorageContext context) { - sharedStorageContext = context; + public void onSharedObjectsContextSet(SharedObjectsContext context) { + sharedObjectsContext = context; } @Override - public String getSharedStorageAccessorID() { - return sharedStorageAccessorID; + public String getSharedObjectsAccessorID() { + return sharedObjectsAccessorID; } @Override @@ -84,10 +84,10 @@ public void processElement(StreamRecord> element if (nodeFeatureMap.isEmpty()) { Preconditions.checkState(nodeBestSplit.isEmpty()); nodeFeatureCounter.clear(); - sharedStorageContext.invoke( + sharedObjectsContext.invoke( (getter, setter) -> { int[] nodeFeaturePairs = - getter.get(SharedStorageConstants.NODE_FEATURE_PAIRS); + getter.get(SharedObjectsConstants.NODE_FEATURE_PAIRS); for (int i = 0; i < nodeFeaturePairs.length / 2; i += 1) { int nodeId = nodeFeaturePairs[2 * i]; nodeFeatureCounter.compute(nodeId, (k, v) -> null == v ? 
1 : v + 1); @@ -103,9 +103,9 @@ public void processElement(StreamRecord> element if (featureMap.isEmpty()) { LOG.debug("Received split for new node {}", nodeId); } - sharedStorageContext.invoke( + sharedObjectsContext.invoke( (getter, setter) -> { - int[] nodeFeaturePairs = getter.get(SharedStorageConstants.NODE_FEATURE_PAIRS); + int[] nodeFeaturePairs = getter.get(SharedObjectsConstants.NODE_FEATURE_PAIRS); Preconditions.checkState(nodeId == nodeFeaturePairs[pairId * 2]); int featureId = nodeFeaturePairs[pairId * 2 + 1]; Preconditions.checkState(!featureMap.get(featureId)); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SharedStorageConstants.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SharedObjectsConstants.java similarity index 95% rename from flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SharedStorageConstants.java rename to flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SharedObjectsConstants.java index a42ffe762..1c5d4e8fd 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SharedStorageConstants.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SharedObjectsConstants.java @@ -33,8 +33,8 @@ import org.apache.flink.ml.common.gbt.typeinfo.BinnedInstanceSerializer; import org.apache.flink.ml.common.gbt.typeinfo.LearningNodeSerializer; import org.apache.flink.ml.common.gbt.typeinfo.NodeSerializer; -import org.apache.flink.ml.common.sharedstorage.ItemDescriptor; -import org.apache.flink.ml.common.sharedstorage.SharedStorageUtils; +import org.apache.flink.ml.common.sharedobjects.ItemDescriptor; +import org.apache.flink.ml.common.sharedobjects.SharedObjectsUtils; import org.apache.flink.ml.linalg.typeinfo.OptimizedDoublePrimitiveArraySerializer; import java.util.ArrayList; @@ -42,11 +42,11 @@ import java.util.List; /** - * Stores constants used for {@link SharedStorageUtils} in {@link GBTRunner}. + * Stores constants used for {@link SharedObjectsUtils} in {@link GBTRunner}. * *
In the iteration, some data needs to be shared and accessed between subtasks of different * operators within one JVM to reduce memory footprint and communication cost. We use {@link - * SharedStorageUtils} with co-location mechanism to achieve such purpose. + * SharedObjectsUtils} with a co-location mechanism to achieve this purpose. * *
All shared data items have corresponding {@link ItemDescriptor}s, and can be read/written * through {@link ItemDescriptor}s from different operator subtasks. Note that every shared item has @@ -55,7 +55,7 @@ *
This class records all {@link ItemDescriptor}s used in {@link GBTRunner} and their owners. */ @Internal -public class SharedStorageConstants { +public class SharedObjectsConstants { /** Instances (after binned). */ static final ItemDescriptor INSTANCES = diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TerminationOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TerminationOperator.java index ee711d78a..b5f77585d 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TerminationOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TerminationOperator.java @@ -20,8 +20,8 @@ import org.apache.flink.iteration.IterationListener; import org.apache.flink.ml.common.gbt.GBTModelData; -import org.apache.flink.ml.common.sharedstorage.SharedStorageContext; -import org.apache.flink.ml.common.sharedstorage.SharedStorageStreamOperator; +import org.apache.flink.ml.common.sharedobjects.SharedObjectsContext; +import org.apache.flink.ml.common.sharedobjects.SharedObjectsStreamOperator; import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; import org.apache.flink.streaming.api.operators.OneInputStreamOperator; @@ -35,15 +35,15 @@ public class TerminationOperator extends AbstractStreamOperator implements OneInputStreamOperator, IterationListener, - SharedStorageStreamOperator { + SharedObjectsStreamOperator { private final OutputTag modelDataOutputTag; - private final String sharedStorageAccessorID; - private transient SharedStorageContext sharedStorageContext; + private final String sharedObjectsAccessorID; + private transient SharedObjectsContext sharedObjectsContext; public TerminationOperator(OutputTag modelDataOutputTag) { this.modelDataOutputTag = modelDataOutputTag; - sharedStorageAccessorID = getClass().getSimpleName() + "-" + UUID.randomUUID(); + sharedObjectsAccessorID = getClass().getSimpleName() + "-" + UUID.randomUUID(); } @Override @@ -58,11 +58,11 @@ public void processElement(StreamRecord element) throws Exception {} public void onEpochWatermarkIncremented( int epochWatermark, Context context, Collector collector) throws Exception { - sharedStorageContext.invoke( + sharedObjectsContext.invoke( (getter, setter) -> { boolean terminated = - getter.get(SharedStorageConstants.ALL_TREES).size() - == getter.get(SharedStorageConstants.TRAIN_CONTEXT) + getter.get(SharedObjectsConstants.ALL_TREES).size() + == getter.get(SharedObjectsConstants.TRAIN_CONTEXT) .strategy .maxIter; // TODO: Add validation error rate @@ -76,23 +76,23 @@ public void onEpochWatermarkIncremented( public void onIterationTerminated(Context context, Collector collector) throws Exception { if (0 == getRuntimeContext().getIndexOfThisSubtask()) { - sharedStorageContext.invoke( + sharedObjectsContext.invoke( (getter, setter) -> context.output( modelDataOutputTag, GBTModelData.from( - getter.get(SharedStorageConstants.TRAIN_CONTEXT), - getter.get(SharedStorageConstants.ALL_TREES)))); + getter.get(SharedObjectsConstants.TRAIN_CONTEXT), + getter.get(SharedObjectsConstants.ALL_TREES)))); } } @Override - public void onSharedStorageContextSet(SharedStorageContext context) { - sharedStorageContext = context; + public void onSharedObjectsContextSet(SharedObjectsContext context) { + sharedObjectsContext = context; } @Override - public String getSharedStorageAccessorID() { - return sharedStorageAccessorID; + public String 
getSharedObjectsAccessorID() { + return sharedObjectsAccessorID; } } From e098994211ad7c26cfd83113d9f90f8d888cdaa4 Mon Sep 17 00:00:00 2001 From: Fan Hong Date: Mon, 29 May 2023 19:59:47 +0800 Subject: [PATCH 44/47] Update codes according to comments. --- .../org/apache/flink/ml/common/gbt/BoostIterationBody.java | 2 +- .../main/java/org/apache/flink/ml/common/gbt/GBTRunner.java | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BoostIterationBody.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BoostIterationBody.java index 3da40a9ca..7326ad17b 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BoostIterationBody.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BoostIterationBody.java @@ -73,7 +73,7 @@ private SharedObjectsBody.SharedObjectsBodyResult sharedObjectsBody( CacheDataCalcLocalHistsOperator cacheDataCalcLocalHistsOp = new CacheDataCalcLocalHistsOperator(strategy); SingleOutputStreamOperator> localHists = - data.connect(trainContext) + data.connect(trainContext.broadcast()) .transform( "CacheDataCalcLocalHists", Types.TUPLE( diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTRunner.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTRunner.java index 6f753da4c..28aa36277 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTRunner.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/GBTRunner.java @@ -168,8 +168,8 @@ private static DataStream boost( DataStream data = tEnv.toDataStream(dataTable); DataStreamList dataStreamList = Iterations.iterateBoundedStreamsUntilTermination( - DataStreamList.of(initTrainContext.broadcast()), - ReplayableDataStreamList.notReplay(data, featureMeta), + DataStreamList.of(initTrainContext), + ReplayableDataStreamList.notReplay(data), IterationConfig.newBuilder() .setOperatorLifeCycle(IterationConfig.OperatorLifeCycle.ALL_ROUND) .build(), From 233b885b3eac4f30146afcaad4db3193be86256e Mon Sep 17 00:00:00 2001 From: Fan Hong Date: Mon, 15 May 2023 17:22:49 +0800 Subject: [PATCH 45/47] Refactor share objects infra to resolve challenges. 
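(Editorial note, not part of the original commit message.) This refactor replaces the callback-style invoke(getter, setter) API with descriptors plus explicit read requests carrying step offsets, so the wrapper operators can cache incoming records until every declared read is satisfied. A minimal sketch of the intended usage; the shared object COUNT, its name, its serializer, and the surrounding logic are illustrative only, and imports and the enclosing classes are elided:

    // Hypothetical shared object owned by some operator A, which writes it once per epoch:
    static final Descriptor<Integer> COUNT =
            Descriptor.of("count", IntSerializer.INSTANCE, 0);

    // Inside operator A, at the end of an epoch:
    context.write(COUNT, newCount); // recorded with A's current step

    // Operator B, a subclass of AbstractSharedObjectsOneInputStreamOperator<Row, Row>,
    // reads the value A wrote in the previous epoch:
    @Override
    public List<ReadRequest<?>> readRequestsInProcessElement() {
        // Declared up front; the wrapper caches records until this read is satisfiable.
        return Collections.singletonList(COUNT.prevStep());
    }

    @Override
    public void processElement(StreamRecord<Row> record) throws Exception {
        Integer prev = context.read(COUNT.prevStep()); // guaranteed non-null here
        // ... use prev ...
    }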
--- ...tSharedObjectsOneInputStreamOperator.java} | 15 +- .../AbstractSharedObjectsStreamOperator.java | 56 +++ ...tSharedObjectsTwoInputStreamOperator.java} | 25 +- .../AbstractSharedObjectsWrapperOperator.java | 246 ++++++++++++- .../ml/common/sharedobjects/Descriptor.java | 117 +++++++ .../common/sharedobjects/ItemDescriptor.java | 78 ----- .../OneInputSharedObjectsWrapperOperator.java | 41 ++- .../ml/common/sharedobjects/ReadRequest.java | 65 ++++ .../sharedobjects/SharedObjectsBody.java | 22 +- .../sharedobjects/SharedObjectsContext.java | 53 +-- .../SharedObjectsContextImpl.java | 134 +++++--- .../sharedobjects/SharedObjectsPools.java | 212 +++++++++--- .../sharedobjects/SharedObjectsUtils.java | 33 +- .../sharedobjects/SharedObjectsWrapper.java | 2 +- .../TwoInputSharedObjectsWrapperOperator.java | 76 ++++- .../org/apache/flink/ml/util/Distributor.java | 77 ----- .../sharedobjects/SharedObjectsUtilsTest.java | 322 +++++++++++++----- .../ml/common/gbt/BoostIterationBody.java | 10 +- .../CacheDataCalcLocalHistsOperator.java | 211 ++++++------ .../operators/CalcLocalSplitsOperator.java | 92 +++-- .../gbt/operators/PostSplitsOperator.java | 198 +++++------ .../gbt/operators/ReduceSplitsOperator.java | 66 ++-- .../gbt/operators/SharedObjectsConstants.java | 74 ++-- .../gbt/operators/TerminationOperator.java | 72 ++-- 24 files changed, 1472 insertions(+), 825 deletions(-) rename flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/{PoolID.java => AbstractSharedObjectsOneInputStreamOperator.java} (65%) create mode 100644 flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/AbstractSharedObjectsStreamOperator.java rename flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/{SharedObjectsStreamOperator.java => AbstractSharedObjectsTwoInputStreamOperator.java} (61%) create mode 100644 flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/Descriptor.java delete mode 100644 flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/ItemDescriptor.java create mode 100644 flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/ReadRequest.java delete mode 100644 flink-ml-core/src/main/java/org/apache/flink/ml/util/Distributor.java diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/PoolID.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/AbstractSharedObjectsOneInputStreamOperator.java similarity index 65% rename from flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/PoolID.java rename to flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/AbstractSharedObjectsOneInputStreamOperator.java index 77ff6573e..9d0ccbfdc 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/PoolID.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/AbstractSharedObjectsOneInputStreamOperator.java @@ -18,15 +18,14 @@ package org.apache.flink.ml.common.sharedobjects; -import org.apache.flink.util.AbstractID; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; -/** ID of a pool for shared objects. */ -class PoolID extends AbstractID { - private static final long serialVersionUID = 1L; +import java.util.List; - public PoolID(byte[] bytes) { - super(bytes); - } +/** The base class for {@link OneInputStreamOperator}s where shared objects are accessed. 
*/ +public abstract class AbstractSharedObjectsOneInputStreamOperator + extends AbstractSharedObjectsStreamOperator + implements OneInputStreamOperator { - public PoolID() {} + public abstract List> readRequestsInProcessElement(); } diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/AbstractSharedObjectsStreamOperator.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/AbstractSharedObjectsStreamOperator.java new file mode 100644 index 000000000..edc5a530f --- /dev/null +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/AbstractSharedObjectsStreamOperator.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.sharedobjects; + +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; + +import java.util.UUID; + +/** + * A base class of stream operators where shared objects are required. + * + *
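(Editorial note: the accessor ID below is fixed once at construction time on the client, e.g. a value like "PostSplitsOperator-<uuid>"; because it identifies the owner of shared objects, it must stay identical between the client side and runtime, which is why it is a final field rather than something derived at open() time.)

    // From the constructor below; the concrete class name and UUID vary per instance.
    accessorID = getClass().getSimpleName() + "-" + UUID.randomUUID();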
Official subclasses, i.e., {@link AbstractSharedObjectsOneInputStreamOperator} and {@link + * AbstractSharedObjectsTwoInputStreamOperator}, are strongly recommended. + * + *
If you are going to implement a subclass by yourself, you have to handle potential deadlocks. + */ +public abstract class AbstractSharedObjectsStreamOperator extends AbstractStreamOperator { + + /** + * A unique identifier for the instance, which is kept unchanged between client side and + * runtime. + */ + private final String accessorID; + + /** The context for shared objects reads/writes. */ + protected transient SharedObjectsContext context; + + AbstractSharedObjectsStreamOperator() { + super(); + accessorID = getClass().getSimpleName() + "-" + UUID.randomUUID(); + } + + void onSharedObjectsContextSet(SharedObjectsContext context) { + this.context = context; + } + + String getAccessorID() { + return accessorID; + } +} diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsStreamOperator.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/AbstractSharedObjectsTwoInputStreamOperator.java similarity index 61% rename from flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsStreamOperator.java rename to flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/AbstractSharedObjectsTwoInputStreamOperator.java index 52afa9617..ee7b59dc5 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsStreamOperator.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/AbstractSharedObjectsTwoInputStreamOperator.java @@ -18,21 +18,16 @@ package org.apache.flink.ml.common.sharedobjects; -/** Interface for all operators that need to access the shared objects. */ -public interface SharedObjectsStreamOperator { +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; - /** - * Set the shared objects context in runtime. - * - * @param context The context for shared objects. - */ - void onSharedObjectsContextSet(SharedObjectsContext context); +import java.util.List; - /** - * Get a unique ID to represent the operator instance. The ID must be kept unchanged through its - * lifetime. - * - * @return A unique ID. - */ - String getSharedObjectsAccessorID(); +/** The base class for {@link TwoInputStreamOperator}s where shared objects are accessed. 
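A sketch of how a two-input subclass declares its reads (editorial; WEIGHTS is a hypothetical descriptor, and only the first input reads it here):

    @Override
    public List<ReadRequest<?>> readRequestsInProcessElement1() {
        return Collections.singletonList(WEIGHTS.sameStep());
    }

    @Override
    public List<ReadRequest<?>> readRequestsInProcessElement2() {
        return Collections.emptyList(); // the second input reads nothing
    }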
*/ +public abstract class AbstractSharedObjectsTwoInputStreamOperator + extends AbstractSharedObjectsStreamOperator + implements TwoInputStreamOperator { + + public abstract List> readRequestsInProcessElement1(); + + public abstract List> readRequestsInProcessElement2(); } diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/AbstractSharedObjectsWrapperOperator.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/AbstractSharedObjectsWrapperOperator.java index 6b1f3e5cf..bed1f23df 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/AbstractSharedObjectsWrapperOperator.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/AbstractSharedObjectsWrapperOperator.java @@ -21,8 +21,11 @@ import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.core.memory.ManagedMemoryUseCase; import org.apache.flink.iteration.IterationListener; +import org.apache.flink.iteration.datacache.nonkeyed.ListStateWithCache; import org.apache.flink.iteration.proxy.state.ProxyStreamOperatorStateContext; import org.apache.flink.metrics.groups.OperatorMetricGroup; +import org.apache.flink.ml.common.broadcast.typeinfo.CacheElement; +import org.apache.flink.ml.common.broadcast.typeinfo.CacheElementSerializer; import org.apache.flink.runtime.checkpoint.CheckpointOptions; import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.jobgraph.OperatorID; @@ -32,7 +35,6 @@ import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.StateSnapshotContext; import org.apache.flink.streaming.api.graph.StreamConfig; -import org.apache.flink.streaming.api.operators.AbstractStreamOperator; import org.apache.flink.streaming.api.operators.InternalTimeServiceManager; import org.apache.flink.streaming.api.operators.OperatorSnapshotFutures; import org.apache.flink.streaming.api.operators.Output; @@ -44,20 +46,28 @@ import org.apache.flink.streaming.api.operators.StreamOperatorStateHandler; import org.apache.flink.streaming.api.operators.StreamOperatorStateHandler.CheckpointedStreamOperator; import org.apache.flink.streaming.api.operators.StreamTaskStateInitializer; +import org.apache.flink.streaming.api.operators.StreamingRuntimeContext; +import org.apache.flink.streaming.api.watermark.Watermark; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.streaming.runtime.tasks.StreamTask; import org.apache.flink.util.CloseableIterator; import org.apache.flink.util.Collector; import org.apache.flink.util.Preconditions; +import org.apache.flink.util.function.ThrowingConsumer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.ArrayDeque; +import java.util.Arrays; +import java.util.List; import java.util.Objects; import java.util.Optional; +import java.util.Queue; /** Base class for the shared objects wrapper operators. 
*/ -abstract class AbstractSharedObjectsWrapperOperator> +abstract class AbstractSharedObjectsWrapperOperator< + T, S extends AbstractSharedObjectsStreamOperator> implements StreamOperator, IterationListener, CheckpointedStreamOperator { private static final Logger LOG = @@ -72,9 +82,18 @@ abstract class AbstractSharedObjectsWrapperOperator> output; protected final StreamOperatorFactory operatorFactory; - private final SharedObjectsContextImpl context; + protected final OperatorMetricGroup metrics; + protected final S wrappedOperator; + + private final SharedObjectsContextImpl context; + private final int numInputs; + private final TypeSerializer[] inTypeSerializers; + private final ListStateWithCache>[] cachedElements; + private final Queue>[] readRequests; + private final boolean[] hasCachedElements; + protected transient StreamOperatorStateHandler stateHandler; protected transient InternalTimeServiceManager timeServiceManager; @@ -100,12 +119,27 @@ abstract class AbstractSharedObjectsWrapperOperator(getInputReadRequests(i)); + } + cachedElements = new ListStateWithCache[numInputs]; + hasCachedElements = new boolean[numInputs]; + Arrays.fill(hasCachedElements, false); } private OperatorMetricGroup createOperatorMetricGroup( @@ -127,6 +161,171 @@ private OperatorMetricGroup createOperatorMetricGroup( } } + /** + * Checks if the read requests are satisfied for the input. + * + * @param inputId The input id, starting from 0. + * @param wait Whether to wait until all requests satisfied, or not. + * @return If all requests of this input are satisfied. + */ + private boolean checkReadRequestsReady(int inputId, boolean wait) { + Queue> requests = readRequests[inputId]; + while (!requests.isEmpty()) { + ReadRequest request = requests.poll(); + try { + if (null == context.read(request, wait)) { + requests.add(request); + return false; + } + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + return true; + } + + /** + * Gets {@link ReadRequest}s required for processing elements in the input. + * + * @param inputId The input id, starting from 0. + * @return The {@link ReadRequest}s required for processing elements. + */ + protected abstract List> getInputReadRequests(int inputId); + + /** + * Extracts common processing logic in subclasses' processing elements. + * + * @param streamRecord The input record. + * @param inputId The input id, starting from 0. + * @param elementConsumer The consumer function of StreamRecord, i.e., + * operator.processElement(...). + * @param watermarkConsumer The consumer function of WaterMark, i.e., + * operator.processWatermark(...). + * @param keyContextSetter The consumer function of setting key context, i.e., + * operator.setKeyContext(...). + * @throws Exception Possible exception. + */ + @SuppressWarnings({"rawtypes"}) + protected void processElementX( + StreamRecord streamRecord, + int inputId, + ThrowingConsumer elementConsumer, + ThrowingConsumer watermarkConsumer, + ThrowingConsumer keyContextSetter) + throws Exception { + if (checkReadRequestsReady(inputId, false)) { + if (hasCachedElements[inputId]) { + processCachedElements( + inputId, elementConsumer, watermarkConsumer, keyContextSetter); + hasCachedElements[inputId] = false; + } + keyContextSetter.accept(streamRecord); + elementConsumer.accept(streamRecord); + } else { + cachedElements[inputId].add(CacheElement.newRecord(streamRecord.getValue())); + hasCachedElements[inputId] = true; + } + } + + /** + * Extracts common processing logic in subclasses' processing watermarks. 
+ * + * @param watermark The input watermark. + * @param inputId The input id, starting from 0. + * @param elementConsumer The consumer function of StreamRecord, i.e., + * operator.processElement(...). + * @param watermarkConsumer The consumer function of WaterMark, i.e., + * operator.processWatermark(...). + * @param keyContextSetter The consumer function of setting key context, i.e., + * operator.setKeyContext(...). + * @throws Exception Possible exception. + */ + @SuppressWarnings({"rawtypes"}) + protected void processWatermarkX( + Watermark watermark, + int inputId, + ThrowingConsumer elementConsumer, + ThrowingConsumer watermarkConsumer, + ThrowingConsumer keyContextSetter) + throws Exception { + if (checkReadRequestsReady(inputId, false)) { + if (hasCachedElements[inputId]) { + processCachedElements( + inputId, elementConsumer, watermarkConsumer, keyContextSetter); + hasCachedElements[inputId] = false; + } + watermarkConsumer.accept(watermark); + } else { + cachedElements[inputId].add(CacheElement.newWatermark(watermark.getTimestamp())); + hasCachedElements[inputId] = true; + } + } + + /** + * Extracts common processing logic in subclasses' endInput(...). + * + * @param inputId The input id, starting from 0. + * @param elementConsumer The consumer function of StreamRecord, i.e., + * operator.processElement(...). + * @param watermarkConsumer The consumer function of WaterMark, i.e., + * operator.processWatermark(...). + * @param keyContextSetter The consumer function of setting key context, i.e., + * operator.setKeyContext(...). + * @throws Exception Possible exception. + */ + @SuppressWarnings("rawtypes") + protected void endInputX( + int inputId, + ThrowingConsumer elementConsumer, + ThrowingConsumer watermarkConsumer, + ThrowingConsumer keyContextSetter) + throws Exception { + if (hasCachedElements[inputId]) { + checkReadRequestsReady(inputId, true); + processCachedElements(inputId, elementConsumer, watermarkConsumer, keyContextSetter); + hasCachedElements[inputId] = false; + } + } + + /** + * Processes elements that are cached by {@link ListStateWithCache}. + * + * @param inputId The input id, starting from 0. + * @param elementConsumer The consumer function of StreamRecord, i.e., + * operator.processElement(...). + * @param watermarkConsumer The consumer function of WaterMark, i.e., + * operator.processWatermark(...). + * @param keyContextSetter The consumer function of setting key context, i.e., + * operator.setKeyContext(...). + * @throws Exception Possible exception. 
+ */ + @SuppressWarnings({"rawtypes", "unchecked"}) + private void processCachedElements( + int inputId, + ThrowingConsumer elementConsumer, + ThrowingConsumer watermarkConsumer, + ThrowingConsumer keyContextSetter) + throws Exception { + for (CacheElement cacheElement : cachedElements[inputId].get()) { + switch (cacheElement.getType()) { + case RECORD: + StreamRecord record = new StreamRecord(cacheElement.getRecord()); + keyContextSetter.accept(record); + elementConsumer.accept(record); + break; + case WATERMARK: + watermarkConsumer.accept(new Watermark(cacheElement.getWatermark())); + break; + default: + throw new RuntimeException( + "Unsupported CacheElement type: " + cacheElement.getType()); + } + } + cachedElements[inputId].clear(); + Preconditions.checkState(readRequests[inputId].isEmpty()); + readRequests[inputId].addAll(getInputReadRequests(inputId)); + } + @Override public void open() throws Exception { wrappedOperator.open(); @@ -149,19 +348,28 @@ public void prepareSnapshotPreBarrier(long checkpointId) throws Exception { } @Override + @SuppressWarnings("unchecked, rawtypes") public void initializeState(StateInitializationContext stateInitializationContext) throws Exception { - context.initializeState( - wrappedOperator, - ((AbstractStreamOperator) wrappedOperator).getRuntimeContext(), - stateInitializationContext); + StreamingRuntimeContext runtimeContext = wrappedOperator.getRuntimeContext(); + context.initializeState(wrappedOperator, runtimeContext, stateInitializationContext); + for (int i = 0; i < numInputs; i++) { + cachedElements[i] = + new ListStateWithCache<>( + new CacheElementSerializer(inTypeSerializers[i]), + containingTask, + runtimeContext, + stateInitializationContext, + streamConfig.getOperatorID()); + } } @Override public void snapshotState(StateSnapshotContext stateSnapshotContext) throws Exception { context.snapshotState(stateSnapshotContext); - if (wrappedOperator instanceof StreamOperatorStateHandler.CheckpointedStreamOperator) { - ((CheckpointedStreamOperator) wrappedOperator).snapshotState(stateSnapshotContext); + wrappedOperator.snapshotState(stateSnapshotContext); + for (int i = 0; i < numInputs; i++) { + cachedElements[i].snapshotState(stateSnapshotContext); } } @@ -272,9 +480,16 @@ public void setCurrentKey(Object key) { wrappedOperator.setCurrentKey(key); } + protected abstract void processCachedElementsBeforeEpochIncremented(int inputId) + throws Exception; + @Override public void onEpochWatermarkIncremented( int epochWatermark, Context context, Collector collector) throws Exception { + for (int i = 0; i < numInputs; i += 1) { + processCachedElementsBeforeEpochIncremented(i); + } + this.context.incStep(epochWatermark); if (wrappedOperator instanceof IterationListener) { //noinspection unchecked ((IterationListener) wrappedOperator) @@ -284,6 +499,7 @@ public void onEpochWatermarkIncremented( @Override public void onIterationTerminated(Context context, Collector collector) throws Exception { + this.context.incStep(); if (wrappedOperator instanceof IterationListener) { //noinspection unchecked ((IterationListener) wrappedOperator).onIterationTerminated(context, collector); diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/Descriptor.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/Descriptor.java new file mode 100644 index 000000000..fd61957b2 --- /dev/null +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/Descriptor.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache 
Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.sharedobjects; + +import org.apache.flink.annotation.Experimental; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.util.Preconditions; + +import javax.annotation.Nullable; + +import java.io.Serializable; + +/** + * Descriptor for a shared object. + * + *
A shared object can have a non-null initial value, or no initial value at all. If a non-null + * initial value is provided, it is set with an initial write-step (see {@link ReadRequest}). + * + * @param <T> The type of the shared object. + */ +@Experimental +public class Descriptor<T> implements Serializable { + + /** Name of the shared object. */ + public final String name; + + /** Type serializer. */ + public final TypeSerializer<T> serializer; + + /** Initial value. */ + public final @Nullable T initVal; + + private Descriptor(String name, TypeSerializer<T> serializer, T initVal) { + this.name = name; + this.serializer = serializer; + this.initVal = initVal; + } + + public static <T> Descriptor<T> of(String name, TypeSerializer<T> serializer, T initVal) { + Preconditions.checkNotNull( + initVal, "Cannot use `null` as the initial value of a shared object."); + return new Descriptor<>(name, serializer, initVal); + } + + public static <T> Descriptor<T> of(String name, TypeSerializer<T> serializer) { + return new Descriptor<>(name, serializer, null); + } + + /** + * Creates a read request which always reads this shared object with the same read-step as the + * operator's step. + * + * @return A read request. + */ + public ReadRequest<T> sameStep() { + return new ReadRequest<>(this, ReadRequest.OFFSET.SAME); + } + + /** + * Creates a read request which always reads this shared object with the read-step one step + * behind the operator's step. + * + * @return A read request. + */ + public ReadRequest<T> prevStep() { + return new ReadRequest<>(this, ReadRequest.OFFSET.PREV); + } + + /** + * Creates a read request which always reads this shared object with the read-step one step + * ahead of the operator's step. + * + * @return A read request. + */ + public ReadRequest<T> nextStep() { + return new ReadRequest<>(this, ReadRequest.OFFSET.NEXT); + } + + @Override + public int hashCode() { + return name.hashCode(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Descriptor<?> that = (Descriptor<?>) o; + return name.equals(that.name); + } + + @Override + public String toString() { + return String.format( + "Descriptor{name='%s', serializer=%s, initVal=%s}", name, serializer, initVal); + } +} diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/ItemDescriptor.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/ItemDescriptor.java deleted file mode 100644 index 53570e342..000000000 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/ItemDescriptor.java +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
- */ - -package org.apache.flink.ml.common.sharedobjects; - -import org.apache.flink.annotation.Experimental; -import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.util.Preconditions; - -import java.io.Serializable; - -/** - * Descriptor for a shared item. - * - * @param The type of the shared item. - */ -@Experimental -public class ItemDescriptor implements Serializable { - - /** Name of the item. */ - public final String name; - - /** Type serializer. */ - public final TypeSerializer serializer; - - /** Initialize value. */ - public final T initVal; - - private ItemDescriptor(String name, TypeSerializer serializer, T initVal) { - Preconditions.checkNotNull( - initVal, "Cannot use `null` as the initial value of a shared item."); - this.name = name; - this.serializer = serializer; - this.initVal = initVal; - } - - public static ItemDescriptor of(String name, TypeSerializer serializer, T initVal) { - return new ItemDescriptor<>(name, serializer, initVal); - } - - @Override - public int hashCode() { - return name.hashCode(); - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - ItemDescriptor that = (ItemDescriptor) o; - return name.equals(that.name); - } - - @Override - public String toString() { - return String.format( - "ItemDescriptor{name='%s', serializer=%s, initVal=%s}", name, serializer, initVal); - } -} diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/OneInputSharedObjectsWrapperOperator.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/OneInputSharedObjectsWrapperOperator.java index b6a197f5c..68f3b3687 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/OneInputSharedObjectsWrapperOperator.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/OneInputSharedObjectsWrapperOperator.java @@ -27,10 +27,14 @@ import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.streaming.runtime.watermarkstatus.WatermarkStatus; +import org.apache.flink.util.Preconditions; + +import java.util.List; /** Wrapper for {@link OneInputStreamOperator}. 
*/ class OneInputSharedObjectsWrapperOperator - extends AbstractSharedObjectsWrapperOperator> + extends AbstractSharedObjectsWrapperOperator< + OUT, AbstractSharedObjectsOneInputStreamOperator> implements OneInputStreamOperator, BoundedOneInput { OneInputSharedObjectsWrapperOperator( @@ -40,20 +44,51 @@ class OneInputSharedObjectsWrapperOperator super(parameters, operatorFactory, context); } + @Override + protected List> getInputReadRequests(int inputId) { + Preconditions.checkArgument(0 == inputId); + return wrappedOperator.readRequestsInProcessElement(); + } + + @Override + protected void processCachedElementsBeforeEpochIncremented(int inputId) throws Exception { + Preconditions.checkArgument(0 == inputId); + endInputX( + 0, + wrappedOperator::processElement, + wrappedOperator::processWatermark, + wrappedOperator::setKeyContextElement); + } + @Override public void processElement(StreamRecord streamRecord) throws Exception { - wrappedOperator.processElement(streamRecord); + processElementX( + streamRecord, + 0, + wrappedOperator::processElement, + wrappedOperator::processWatermark, + wrappedOperator::setKeyContextElement); } @Override public void endInput() throws Exception { + endInputX( + 0, + wrappedOperator::processElement, + wrappedOperator::processWatermark, + wrappedOperator::setKeyContextElement); OperatorUtils.processOperatorOrUdfIfSatisfy( wrappedOperator, BoundedOneInput.class, BoundedOneInput::endInput); } @Override public void processWatermark(Watermark watermark) throws Exception { - wrappedOperator.processWatermark(watermark); + processWatermarkX( + watermark, + 0, + wrappedOperator::processElement, + wrappedOperator::processWatermark, + wrappedOperator::setKeyContextElement); } @Override diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/ReadRequest.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/ReadRequest.java new file mode 100644 index 000000000..c4f6178de --- /dev/null +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/ReadRequest.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.sharedobjects; + +import org.apache.flink.annotation.Experimental; +import org.apache.flink.iteration.IterationListener; + +import java.io.Serializable; + +/** + * A read request for a shared object with given step offset. The step {@link OFFSET} is used to + * calculate read-step from current operator step. + * + *
The concept of `step` is first defined on operators. Every operator maintains its `step` + * implicitly. For operators used outside of iterations, their `step`s are treated as constants, + * while for operators used within iterations, their `step`s are bound to the epoch watermarks: + * + *
With every call of {@link IterationListener#onEpochWatermarkIncremented}, the value of step is + * set to the epoch watermark. Before the first call of {@link + * IterationListener#onEpochWatermarkIncremented}, the step is set to a small enough value, while + * after {@link IterationListener#onIterationTerminated} it is set to a large enough value. In this + * way, the successive values of the step form an ordered sequence. Note that the `step` is + * maintained implicitly by the infrastructure, even if the operator does not implement {@link + * IterationListener}. + * + *
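To make the ordering concrete, here is an editorial sketch that is consistent with the -1 initialization and the per-epoch increment in SharedObjectsContextImpl later in this patch:

    // Steps observed by an operator across an iteration with epochs 0..2:
    //   -1  before the first epoch watermark
    //    0  after onEpochWatermarkIncremented(0, ...)
    //    1  after onEpochWatermarkIncremented(1, ...)
    //    2  after onEpochWatermarkIncremented(2, ...)
    //    3  after onIterationTerminated(...)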
Then, the concept of `step` is defined on reads and writes of shared objects. Every write + * carries the step of its owner operator at that moment, which is called its `write-step`. To read + * the shared object with the exact `write-step`, the reader operator must provide the same + * `read-step`. The `read-step` may differ from the reader operator's own step, but their difference + * is kept unchanged, namely the step offset defined in {@link ReadRequest#offset}. + * + * @param <T> The type of the shared object. + */ +@Experimental +public class ReadRequest<T> implements Serializable { + final Descriptor<T> descriptor; + final OFFSET offset; + + ReadRequest(Descriptor<T> descriptor, OFFSET offset) { + this.descriptor = descriptor; + this.offset = offset; + } + + enum OFFSET { + SAME, + PREV, + NEXT, + } +} diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsBody.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsBody.java index 64c737c70..9f938a233 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsBody.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsBody.java @@ -44,33 +44,33 @@ public interface SharedObjectsBody extends Serializable { * * @param inputs Input data streams. * @return Result of the subgraph, including output data streams, data streams with access to - * the shared objects, and a mapping from share items to their owners. + * the shared objects, and a mapping from shared objects to their owners. */ SharedObjectsBodyResult process(List<DataStream<?>> inputs); /** * The result of a {@link SharedObjectsBody}, including output data streams, data streams with - * access to the shared objects, and a mapping from descriptors of share items to their owners. + * access to the shared objects, and a mapping from descriptors of shared objects to their + * owners. */ @Experimental class SharedObjectsBodyResult { /** A list of output streams. */ private final List<DataStream<?>> outputs; - /** A list of {@link Transformation}s that should be co-located. */ - private final List<Transformation<?>> coLocatedTransformations; - /** - * A mapping from descriptors of shared items to their owners. The owner is specified by - * {@link SharedObjectsStreamOperator#getSharedObjectsAccessorID()}, which must be kept - * unchanged for an instance of {@link SharedObjectsStreamOperator}. + * A list of {@link Transformation}s that should be co-located, which should include all + * subclasses of {@link AbstractSharedObjectsStreamOperator}. */ - private final Map<ItemDescriptor<?>, SharedObjectsStreamOperator> ownerMap; + private final List<Transformation<?>> coLocatedTransformations; + + /** A mapping from descriptors of shared objects to their owner operators.
*/ + private final Map, AbstractSharedObjectsStreamOperator> ownerMap; public SharedObjectsBodyResult( List> outputs, List> coLocatedTransformations, - Map, SharedObjectsStreamOperator> ownerMap) { + Map, AbstractSharedObjectsStreamOperator> ownerMap) { this.outputs = outputs; this.coLocatedTransformations = coLocatedTransformations; this.ownerMap = ownerMap; @@ -84,7 +84,7 @@ public List> getCoLocatedTransformations() { return coLocatedTransformations; } - public Map, SharedObjectsStreamOperator> getOwnerMap() { + public Map, AbstractSharedObjectsStreamOperator> getOwnerMap() { return ownerMap; } } diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsContext.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsContext.java index f5419780e..7ec0acf8c 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsContext.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsContext.java @@ -19,35 +19,48 @@ package org.apache.flink.ml.common.sharedobjects; import org.apache.flink.annotation.Experimental; -import org.apache.flink.util.function.BiConsumerWithException; /** - * Context for shared objects. Every operator implementing {@link SharedObjectsStreamOperator} will - * get an instance of this context set by {@link - * SharedObjectsStreamOperator#onSharedObjectsContextSet} in runtime. User-defined logic can be - * invoked through {@link #invoke} with the access to shared items. + * Context for shared objects. Every operator implementing {@link + * AbstractSharedObjectsStreamOperator} will get an instance of this context set by {@link + * AbstractSharedObjectsStreamOperator#onSharedObjectsContextSet} in runtime. + * + *
See {@link ReadRequest} for details about coordination between reads and writes. */ @Experimental public interface SharedObjectsContext { /** - * Invoke user defined function with provided getters/setters of the shared objects. + * Reads the value of a shared object. + * + *
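(Editorial sketch of the three offset variants; X is a hypothetical descriptor and the calling operator is currently at step N:)

    T same = context.read(X.sameStep()); // the value written at step N
    T prev = context.read(X.prevStep()); // the value written at step N - 1
    T next = context.read(X.nextStep()); // the value written at step N + 1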
For subclasses of {@link AbstractSharedObjectsOneInputStreamOperator} and {@link + * AbstractSharedObjectsTwoInputStreamOperator}, this method is guaranteed to return non-null + * values immediately. * - * @param func User defined function where share items can be accessed through getters/setters. - * @throws Exception Possible exception. + * @param request A read request of a shared object. + * @return The value of the shared object. + * @param The type of the shared object. */ - void invoke(BiConsumerWithException func) - throws Exception; + T read(ReadRequest request); - /** Interface of shared item getter. */ - @FunctionalInterface - interface SharedItemGetter { - T get(ItemDescriptor key); - } + /** + * Writes a new value to the shared object. + * + * @param descriptor The shared object descriptor. + * @param value The value to be set. + * @param The type of the shared object. + */ + void write(Descriptor descriptor, T value); - /** Interface of shared item writer. */ - @FunctionalInterface - interface SharedItemSetter { - void set(ItemDescriptor key, T value); - } + /** + * Renew the shared object with current step. + * + *
For subclasses of {@link AbstractSharedObjectsOneInputStreamOperator} and {@link + * AbstractSharedObjectsTwoInputStreamOperator}, this method is guaranteed to return + * immediately. + * + * @param descriptor The shared object descriptor. + * @param The type of the shared object. + */ + void renew(Descriptor descriptor); } diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsContextImpl.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsContextImpl.java index 0c57b9cef..9819baf82 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsContextImpl.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsContextImpl.java @@ -24,77 +24,65 @@ import org.apache.flink.streaming.api.operators.StreamOperator; import org.apache.flink.streaming.api.operators.StreamingRuntimeContext; import org.apache.flink.util.Preconditions; -import org.apache.flink.util.function.BiConsumerWithException; + +import javax.annotation.Nullable; import java.io.Serializable; import java.util.HashMap; import java.util.Map; +import static org.apache.flink.ml.common.sharedobjects.SharedObjectsPools.getReader; +import static org.apache.flink.ml.common.sharedobjects.SharedObjectsPools.getWriter; + /** - * Default implementation of {@link SharedObjectsContext}. + * A default implementation of {@link SharedObjectsContext}. * - *
It initializes readers and writers according to the owner map when the subtask starts and - * clean internal states when the subtask finishes. It also handles `initializeState` and - * `snapshotState` automatically. + *
It initializes readers and writers of shared objects according to the owner map when the + * subtask starts and clean internal states when the subtask finishes. It also handles + * `initializeState` and `snapshotState` automatically. */ @SuppressWarnings("rawtypes") class SharedObjectsContextImpl implements SharedObjectsContext, Serializable { - private final PoolID poolID; - private final Map writers = new HashMap<>(); - private final Map readers = new HashMap<>(); - private Map, String> ownerMap; + private final SharedObjectsPools.PoolID poolID; + private final Map writers = new HashMap<>(); + private final Map readers = new HashMap<>(); + private Map, String> ownerMap; + + /** The step of corresponding operator. See {@link ReadRequest} for more information. */ + private int step; public SharedObjectsContextImpl() { - this.poolID = new PoolID(); + this.poolID = new SharedObjectsPools.PoolID(); + step = -1; } - void setOwnerMap(Map, String> ownerMap) { + void setOwnerMap(Map, String> ownerMap) { this.ownerMap = ownerMap; } - @Override - public void invoke(BiConsumerWithException func) - throws Exception { - func.accept(this::getSharedItem, this::setSharedItem); - } - - private T getSharedItem(ItemDescriptor key) { - //noinspection unchecked - SharedObjectsPools.Reader reader = readers.get(key); - Preconditions.checkState( - null != reader, - String.format( - "The operator requested to read a shared item %s not owned by itself.", - key)); - return reader.get(); + void incStep(@Nullable Integer targetStep) { + step += 1; + // Sanity check + Preconditions.checkState(null == targetStep || step == targetStep); } - private void setSharedItem(ItemDescriptor key, T value) { - //noinspection unchecked - SharedObjectsPools.Writer writer = writers.get(key); - Preconditions.checkState( - null != writer, - String.format( - "The operator requested to read a shared item %s not owned by itself.", - key)); - writer.set(value); + void incStep() { + incStep(null); } void initializeState( StreamOperator operator, StreamingRuntimeContext runtimeContext, StateInitializationContext context) { - Preconditions.checkArgument( - operator instanceof SharedObjectsStreamOperator - && operator instanceof AbstractStreamOperator); - String ownerId = ((SharedObjectsStreamOperator) operator).getSharedObjectsAccessorID(); + Preconditions.checkArgument(operator instanceof AbstractSharedObjectsStreamOperator); + String ownerId = ((AbstractSharedObjectsStreamOperator) operator).getAccessorID(); int subtaskId = runtimeContext.getIndexOfThisSubtask(); - for (Map.Entry, String> entry : ownerMap.entrySet()) { - ItemDescriptor descriptor = entry.getKey(); + for (Map.Entry, String> entry : ownerMap.entrySet()) { + Descriptor descriptor = entry.getKey(); if (ownerId.equals(entry.getValue())) { writers.put( descriptor, - SharedObjectsPools.getWriter( + getWriter( poolID, subtaskId, descriptor, @@ -102,9 +90,10 @@ void initializeState( operator.getOperatorID(), ((AbstractStreamOperator) operator).getContainingTask(), runtimeContext, - context)); + context, + step)); } - readers.put(descriptor, SharedObjectsPools.getReader(poolID, subtaskId, descriptor)); + readers.put(descriptor, getReader(poolID, subtaskId, descriptor)); } } @@ -124,4 +113,61 @@ void clear() { writers.clear(); readers.clear(); } + + @Override + public T read(ReadRequest request) { + try { + return read(request, false); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + + /** + * Gets the value of the shared object with possible waiting. 
* @param request A read request of a shared object. + * @param wait Whether to wait or not. + * @return The value of the shared object, or null if it is not set yet. + * @param <T> The type of the shared object. + */ + <T> T read(ReadRequest<T> request, boolean wait) throws InterruptedException { + Descriptor<T> descriptor = request.descriptor; + //noinspection unchecked + SharedObjectsPools.Reader<T> reader = readers.get(descriptor); + switch (request.offset) { + case SAME: + return reader.get(step, wait); + case PREV: + return reader.get(step - 1, wait); + case NEXT: + return reader.get(step + 1, wait); + default: + throw new UnsupportedOperationException(); + } + } + + @Override + public <T> void write(Descriptor<T> descriptor, T value) { + //noinspection unchecked + SharedObjectsPools.Writer<T> writer = writers.get(descriptor); + Preconditions.checkState( + null != writer, + String.format( + "The operator requests to write a shared object %s not owned by itself.", + descriptor)); + writer.set(value, step); + } + + @Override + public <T> void renew(Descriptor<T> descriptor) { + try { + //noinspection unchecked + write( + descriptor, + ((SharedObjectsPools.Reader<T>) readers.get(descriptor)).get(step - 1, false)); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } } diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsPools.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsPools.java index 5be832635..25f64151c 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsPools.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsPools.java @@ -19,64 +19,86 @@ package org.apache.flink.ml.common.sharedobjects; import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.base.IntSerializer; +import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.typeutils.runtime.TupleSerializer; import org.apache.flink.iteration.datacache.nonkeyed.ListStateWithCache; import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.StateSnapshotContext; import org.apache.flink.streaming.api.operators.StreamingRuntimeContext; import org.apache.flink.streaming.runtime.tasks.StreamTask; +import org.apache.flink.util.AbstractID; import org.apache.flink.util.Preconditions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; import java.util.Collections; import java.util.Iterator; +import java.util.List; +import java.util.ListIterator; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; /** - * Stores and manages all shared objects. Every shared object is identified by a tuple of (Pool ID, - * subtask ID, name). Every call of {@link SharedObjectsUtils#withSharedObjects} generated a - * different {@link PoolID}, so that they do not interfere with each other. + * Stores all shared objects and coordinates their reads and writes. + * + *
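(Editorial sketch of the storage layout described below, with the generic parameters spelled out as assumed:)

    // key: (poolID, subtaskId, objectName) -> value: (writeStep, value)
    Map<Tuple3<PoolID, Integer, String>, Tuple2<Integer, Object>> values;

    // waiting reads per object: (readStep, latch) pairs, notified when the
    // matching write-step arrives
    Map<Tuple3<PoolID, Integer, String>, List<Tuple2<Integer, CountDownLatch>>> waitQueues;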
Every shared object is identified by a tuple of (Pool ID, subtask ID, name). Their reads and + * writes are coordinated through the read- and write-steps. */ class SharedObjectsPools { - // Stores values of all shared objects. - private static final Map, Object> values = + private static final Logger LOG = LoggerFactory.getLogger(SharedObjectsPools.class); + + /** Stores values and corresponding write-steps of all shared objects. */ + private static final Map, Tuple2> values = new ConcurrentHashMap<>(); + /** + * Stores waiting read requests of all shared objects, including read-steps and count-down + * latches for notification when shared objects are ready. + */ + private static final Map, List>> + waitQueues = new ConcurrentHashMap<>(); + /** * Stores owners of all shared objects, where the owner is identified by the accessor ID - * obtained from {@link SharedObjectsStreamOperator#getSharedObjectsAccessorID()}. + * obtained from {@link AbstractSharedObjectsStreamOperator#getAccessorID()}. */ private static final Map, String> owners = new ConcurrentHashMap<>(); - // Stores number of references of all shared objects. A shared object is removed when its number - // of references decreased to 0. + /** + * Stores number of references of all shared objects. Every {@link Reader} and {@link Writer} + * counts. A shared object is removed from the pool when its number of references decreased to + * 0. + */ private static final ConcurrentHashMap, Integer> numRefs = new ConcurrentHashMap<>(); - @SuppressWarnings("UnusedReturnValue") - static int incRef(Tuple3 itemId) { - return numRefs.compute(itemId, (k, oldV) -> null == oldV ? 1 : oldV + 1); + private static void incRef(Tuple3 objId) { + numRefs.compute(objId, (k, oldV) -> null == oldV ? 1 : oldV + 1); } - @SuppressWarnings("UnusedReturnValue") - static int decRef(Tuple3 itemId) { - int num = numRefs.compute(itemId, (k, oldV) -> oldV - 1); + private static void decRef(Tuple3 objId) { + int num = numRefs.compute(objId, (k, oldV) -> oldV - 1); if (num == 0) { - values.remove(itemId); - owners.remove(itemId); - numRefs.remove(itemId); + values.remove(objId); + waitQueues.remove(objId); + owners.remove(objId); + numRefs.remove(objId); } - return num; } /** Gets a {@link Reader} of a shared object. 
*/ - static Reader getReader(PoolID poolID, int subtaskId, ItemDescriptor descriptor) { - Tuple3 itemId = Tuple3.of(poolID, subtaskId, descriptor.name); - Reader reader = new Reader<>(itemId); - incRef(itemId); + static Reader getReader(PoolID poolID, int subtaskId, Descriptor descriptor) { + Tuple3 objId = Tuple3.of(poolID, subtaskId, descriptor.name); + Reader reader = new Reader<>(objId); + incRef(objId); return reader; } @@ -84,18 +106,19 @@ static Reader getReader(PoolID poolID, int subtaskId, ItemDescriptor d static Writer getWriter( PoolID poolId, int subtaskId, - ItemDescriptor descriptor, + Descriptor descriptor, String ownerId, OperatorID operatorID, StreamTask containingTask, StreamingRuntimeContext runtimeContext, - StateInitializationContext stateInitializationContext) { + StateInitializationContext stateInitializationContext, + int step) { Tuple3 objId = Tuple3.of(poolId, subtaskId, descriptor.name); String lastOwner = owners.putIfAbsent(objId, ownerId); if (null != lastOwner) { throw new IllegalStateException( String.format( - "The shared item (%s, %s, %s) already has a writer %s.", + "The shared object (%s, %s, %s) already has a writer %s.", poolId, subtaskId, descriptor.name, ownerId)); } Writer writer = @@ -107,11 +130,18 @@ static Writer getWriter( runtimeContext, stateInitializationContext, operatorID); - writer.set(descriptor.initVal); incRef(objId); + if (null != descriptor.initVal) { + writer.set(descriptor.initVal, step); + } return writer; } + /** + * Reader of a shared object. + * + * @param The type of the shared object. + */ static class Reader { protected final Tuple3 objId; @@ -119,26 +149,55 @@ static class Reader { this.objId = objId; } - T get() { - // It is possible that the `get` request of an item is triggered earlier than its - // initialization. In this case, we wait for a while. - long waitTime = 10; - do { - //noinspection unchecked - T value = (T) values.get(objId); - if (null != value) { - return value; + /** + * Gets the value with given read-step. There are 3 cases: + * + *

    + *
  1. The read-step is equal to the write-step: returns the value immediately. + *
  2. The read-step is larger than the write-step, or no value has been written yet: + * waits until the value with the same write-step is set if `wait` is true, or returns null + * otherwise. +
  3. The read-step is smaller than the write-step: throws an exception, as reading a step that has already been passed is illegal. +
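+     *
+     * <p>A minimal sketch of the three cases, with hypothetical {@code reader} and {@code
+     * writer} handles, assuming the latest write happened at write-step 3:
+     *
+     * <pre>{@code
+     * writer.set("v3", 3);
+     * reader.get(3, true);   // Case 1: returns "v3" immediately.
+     * reader.get(4, false);  // Case 2: nothing written for step 4 yet, returns null.
+     * reader.get(4, true);   // Case 2: blocks until set(value, 4) is called.
+     * reader.get(2, true);   // Case 3: throws, as write-step 3 has already passed step 2.
+     * }</pre>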
+ * + * @param readStep The read-step. + * @param wait Whether to wait until the value with same write-step presents. + * @return The value or null. A return value of null means the corresponding value if not + * presented. If `wait` is true, the return value of this function is guaranteed to be a + * non-null value if it returns. + * @throws InterruptedException Interrupted when waiting. + */ + T get(int readStep, boolean wait) throws InterruptedException { + //noinspection unchecked + Tuple2 stepV = (Tuple2) values.get(objId); + if (null != stepV) { + int writeStep = stepV.f0; + LOG.debug("Get {} with read-step {}, write-step is {}", objId, readStep, writeStep); + Preconditions.checkState( + writeStep <= readStep, + String.format( + "Current write-step %d of %s is larger than read-step %d, which is illegal.", + writeStep, objId, readStep)); + if (readStep == stepV.f0) { + return stepV.f1; } - try { - Thread.sleep(waitTime); - } catch (InterruptedException e) { - break; + } + if (!wait) { + return null; + } + CountDownLatch latch = new CountDownLatch(1); + synchronized (waitQueues) { + if (!waitQueues.containsKey(objId)) { + waitQueues.put(objId, new ArrayList<>()); } - waitTime *= 2; - } while (waitTime < 10 * 1000); - throw new IllegalStateException( - String.format( - "Failed to get value of %s after waiting %d ms.", objId, waitTime)); + List> q = waitQueues.get(objId); + q.add(Tuple2.of(readStep, latch)); + } + latch.await(); + //noinspection unchecked + stepV = (Tuple2) values.get(objId); + Preconditions.checkState(stepV.f0 == readStep); + return stepV.f1; } void remove() { @@ -146,34 +205,43 @@ void remove() { } } + /** + * Writer of a shared object. + * + * @param The type of the shared object. + */ static class Writer extends Reader { private final String ownerId; - private final ListStateWithCache cache; + private final ListStateWithCache> cache; private boolean isDirty; Writer( - Tuple3 itemId, + Tuple3 objId, String ownerId, TypeSerializer serializer, StreamTask containingTask, StreamingRuntimeContext runtimeContext, StateInitializationContext stateInitializationContext, OperatorID operatorID) { - super(itemId); + super(objId); this.ownerId = ownerId; try { + //noinspection unchecked cache = new ListStateWithCache<>( - serializer, + new TupleSerializer<>( + (Class>) (Class) Tuple2.class, + new TypeSerializer[] {IntSerializer.INSTANCE, serializer}), containingTask, runtimeContext, stateInitializationContext, operatorID); - Iterator iterator = cache.get().iterator(); + Iterator> iterator = cache.get().iterator(); if (iterator.hasNext()) { - T value = iterator.next(); + Tuple2 stepV = iterator.next(); ensureOwner(); - values.put(itemId, value); + //noinspection unchecked + values.put(objId, (Tuple2) stepV); } } catch (Exception e) { throw new RuntimeException(e); @@ -182,15 +250,35 @@ static class Writer extends Reader { } private void ensureOwner() { - // Double-checks the owner, because a writer may call this method after the key removed - // and re-added by other operators. Preconditions.checkState(owners.get(objId).equals(ownerId)); } - void set(T value) { + /** + * Sets the value with given write-step. If there are read requests waiting for the value of + * exact the same write-step, they are notified. + * + * @param value The value. + * @param writeStep The write-step. 
+ */ + void set(T value, int writeStep) { ensureOwner(); - values.put(objId, value); + values.put(objId, Tuple2.of(writeStep, value)); + LOG.debug("Set {} with write-step {}", objId, writeStep); isDirty = true; + synchronized (waitQueues) { + if (!waitQueues.containsKey(objId)) { + waitQueues.put(objId, new ArrayList<>()); + } + List> q = waitQueues.get(objId); + ListIterator> iter = q.listIterator(); + while (iter.hasNext()) { + Tuple2 next = iter.next(); + if (writeStep == next.f0) { + iter.remove(); + next.f1.countDown(); + } + } + } } @Override @@ -202,10 +290,22 @@ void remove() { void snapshotState(StateSnapshotContext context) throws Exception { if (isDirty) { - cache.update(Collections.singletonList(get())); + //noinspection unchecked + cache.update(Collections.singletonList((Tuple2) values.get(objId))); isDirty = false; } cache.snapshotState(context); } } + + /** ID of a pool for shared objects. */ + static class PoolID extends AbstractID { + private static final long serialVersionUID = 1L; + + public PoolID(byte[] bytes) { + super(bytes); + } + + public PoolID() {} + } } diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsUtils.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsUtils.java index 022469bb8..2c884006a 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsUtils.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsUtils.java @@ -20,6 +20,7 @@ import org.apache.flink.annotation.Experimental; import org.apache.flink.api.dag.Transformation; +import org.apache.flink.iteration.IterationListener; import org.apache.flink.iteration.compile.DraftExecutionEnvironment; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; @@ -37,21 +38,32 @@ public class SharedObjectsUtils { /** - * Support read/write access of data in the shared objects from operators which implements - * {@link SharedObjectsStreamOperator}. + * Supports read/write access of data in the shared objects from operators which inherit {@link + * AbstractSharedObjectsStreamOperator}. * *

 In the shared objects `body`, users build the subgraph using only data streams from * `inputs`, return streams that have access to the shared objects, and return the mapping from - * shared items to their owners. + * shared objects to their owners (see the sketch below). + * + *

 There are several limitations to using this function: + * + *

    + *
  1. Only synchronized iterations and non-iterative jobs are supported. +
  2. Reads and writes of shared objects must obey strict rules defined on `step`s, as stated + * in {@link ReadRequest}. + *
  3. Within iterations, writes of shared objects can only occur in {@link + * IterationListener#onEpochWatermarkIncremented} and {@link + * IterationListener#onIterationTerminated}. +
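+     *
+     * <p>A minimal sketch of a body, assuming a hypothetical {@code MyOperator} that inherits
+     * {@link AbstractSharedObjectsStreamOperator} and owns a shared object {@code SUM}:
+     *
+     * <pre>{@code
+     * List<DataStream<?>> outputs =
+     *         SharedObjectsUtils.withSharedObjects(
+     *                 Collections.singletonList(input),
+     *                 inputs -> {
+     *                     MyOperator myOp = new MyOperator();
+     *                     SingleOutputStreamOperator<Long> out =
+     *                             ((DataStream<Long>) inputs.get(0))
+     *                                     .transform("my-op", Types.LONG, myOp);
+     *                     return new SharedObjectsBody.SharedObjectsBodyResult(
+     *                             Collections.singletonList(out),
+     *                             Collections.singletonList(out.getTransformation()),
+     *                             Collections.singletonMap(SUM, myOp));
+     *                 });
+     * }</pre>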
* * @param inputs Input data streams. - * @param body User defined logic to build subgraph and to specify owners of every shared data - * item. + * @param body User defined logic to build subgraph and to specify owners of every shared + * object. * @return The output data streams. */ public static List> withSharedObjects( List> inputs, SharedObjectsBody body) { - Preconditions.checkArgument(inputs.size() > 0); + Preconditions.checkArgument(!inputs.isEmpty()); StreamExecutionEnvironment env = inputs.get(0).getExecutionEnvironment(); String coLocationID = "shared-storage-" + UUID.randomUUID(); SharedObjectsContextImpl context = new SharedObjectsContextImpl(); @@ -67,10 +79,11 @@ public static List> withSharedObjects( SharedObjectsBody.SharedObjectsBodyResult result = body.process(draftSources); List> draftOutputs = result.getOutputs(); - Map, SharedObjectsStreamOperator> rawOwnerMap = result.getOwnerMap(); - Map, String> ownerMap = new HashMap<>(); - for (ItemDescriptor item : rawOwnerMap.keySet()) { - ownerMap.put(item, rawOwnerMap.get(item).getSharedObjectsAccessorID()); + Map, AbstractSharedObjectsStreamOperator> rawOwnerMap = + result.getOwnerMap(); + Map, String> ownerMap = new HashMap<>(); + for (Descriptor descriptor : rawOwnerMap.keySet()) { + ownerMap.put(descriptor, rawOwnerMap.get(descriptor).getAccessorID()); } context.setOwnerMap(ownerMap); diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsWrapper.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsWrapper.java index 9c0c06564..c4e4e0102 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsWrapper.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsWrapper.java @@ -47,7 +47,7 @@ public StreamOperator wrap( StreamOperatorFactory operatorFactory) { Class operatorClass = operatorFactory.getStreamOperatorClass(getClass().getClassLoader()); - if (SharedObjectsStreamOperator.class.isAssignableFrom(operatorClass)) { + if (AbstractSharedObjectsStreamOperator.class.isAssignableFrom(operatorClass)) { if (OneInputStreamOperator.class.isAssignableFrom(operatorClass)) { return new OneInputSharedObjectsWrapperOperator<>( operatorParameters, operatorFactory, context); diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/TwoInputSharedObjectsWrapperOperator.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/TwoInputSharedObjectsWrapperOperator.java index 306c84517..fedc8f3ac 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/TwoInputSharedObjectsWrapperOperator.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/sharedobjects/TwoInputSharedObjectsWrapperOperator.java @@ -27,10 +27,14 @@ import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.streaming.runtime.watermarkstatus.WatermarkStatus; +import org.apache.flink.util.Preconditions; + +import java.util.List; /** Wrapper for {@link TwoInputStreamOperator}. 
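 * Elements and watermarks of each input are routed through the shared helper methods (e.g. {@code
 * processElementX}), which may cache them until the wrapped operator's declared read requests for
 * that input can be satisfied.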
*/ class TwoInputSharedObjectsWrapperOperator - extends AbstractSharedObjectsWrapperOperator> + extends AbstractSharedObjectsWrapperOperator< + OUT, AbstractSharedObjectsTwoInputStreamOperator> implements TwoInputStreamOperator, BoundedMultiInput { TwoInputSharedObjectsWrapperOperator( @@ -40,18 +44,70 @@ class TwoInputSharedObjectsWrapperOperator super(parameters, operatorFactory, context); } + @Override + protected List> getInputReadRequests(int inputId) { + Preconditions.checkArgument(0 == inputId || 1 == inputId); + if (0 == inputId) { + return wrappedOperator.readRequestsInProcessElement1(); + } else { + return wrappedOperator.readRequestsInProcessElement2(); + } + } + + @Override + protected void processCachedElementsBeforeEpochIncremented(int inputId) throws Exception { + Preconditions.checkArgument(0 == inputId || 1 == inputId); + if (inputId == 0) { + endInputX( + inputId, + wrappedOperator::processElement1, + wrappedOperator::processWatermark1, + wrappedOperator::setKeyContextElement1); + } else { + endInputX( + inputId, + wrappedOperator::processElement2, + wrappedOperator::processWatermark2, + wrappedOperator::setKeyContextElement2); + } + } + @Override public void processElement1(StreamRecord streamRecord) throws Exception { - wrappedOperator.processElement1(streamRecord); + processElementX( + streamRecord, + 0, + wrappedOperator::processElement1, + wrappedOperator::processWatermark1, + wrappedOperator::setKeyContextElement1); } @Override public void processElement2(StreamRecord streamRecord) throws Exception { - wrappedOperator.processElement2(streamRecord); + processElementX( + streamRecord, + 1, + wrappedOperator::processElement2, + wrappedOperator::processWatermark2, + wrappedOperator::setKeyContextElement2); } @Override public void endInput(int inputId) throws Exception { + Preconditions.checkArgument(1 == inputId || 2 == inputId); + if (1 == inputId) { + endInputX( + 0, + wrappedOperator::processElement1, + wrappedOperator::processWatermark1, + wrappedOperator::setKeyContextElement1); + } else { + endInputX( + inputId - 1, + wrappedOperator::processElement2, + wrappedOperator::processWatermark2, + wrappedOperator::setKeyContextElement2); + } OperatorUtils.processOperatorOrUdfIfSatisfy( wrappedOperator, BoundedMultiInput.class, @@ -60,12 +116,22 @@ public void endInput(int inputId) throws Exception { @Override public void processWatermark1(Watermark watermark) throws Exception { - wrappedOperator.processWatermark1(watermark); + processWatermarkX( + watermark, + 0, + wrappedOperator::processElement1, + wrappedOperator::processWatermark1, + wrappedOperator::setKeyContextElement1); } @Override public void processWatermark2(Watermark watermark) throws Exception { - wrappedOperator.processWatermark2(watermark); + processWatermarkX( + watermark, + 1, + wrappedOperator::processElement2, + wrappedOperator::processWatermark2, + wrappedOperator::setKeyContextElement2); } @Override diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/util/Distributor.java b/flink-ml-core/src/main/java/org/apache/flink/ml/util/Distributor.java deleted file mode 100644 index b985adfb8..000000000 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/util/Distributor.java +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.ml.util; - -import java.io.Serializable; - -/** - * A utility class which helps data partitioning. - * - *

Given an indexable linear structures, like an array, of n elements and m tasks, the goal is to - * partition the linear structure into m consecutive segments and assign them to tasks accordingly. - * This class calculates the segment assigned to each task, including the start position and element - * count of the segment. - */ -public abstract class Distributor implements Serializable { - protected final long numTasks; - protected final long total; - - public Distributor(long total, long numTasks) { - this.numTasks = numTasks; - this.total = total; - } - - /** - * Calculates the start position of the segment assigned to the task. - * - * @param taskId The task index. - * @return The start position. - */ - public abstract long start(long taskId); - - /** - * Calculates the count of elements of the segment assigned to the task. - * - * @param taskId The task index. - * @return The count of elements. - */ - public abstract long count(long taskId); - - /** An implementation of {@link Distributor} which evenly partitioned the elements. */ - public static class EvenDistributor extends Distributor { - - public EvenDistributor(long parallelism, long totalCnt) { - super(totalCnt, parallelism); - } - - @Override - public long start(long taskId) { - long div = total / numTasks; - long mod = total % numTasks; - return taskId < mod ? div * taskId + taskId : div * taskId + mod; - } - - @Override - public long count(long taskId) { - long div = total / numTasks; - long mod = total % numTasks; - return taskId < mod ? div + 1 : div; - } - } -} diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsUtilsTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsUtilsTest.java index e5241ba88..cc0a5ff27 100644 --- a/flink-ml-core/src/test/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsUtilsTest.java +++ b/flink-ml-core/src/test/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsUtilsTest.java @@ -21,14 +21,17 @@ import org.apache.flink.api.common.functions.ReduceFunction; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeutils.base.LongSerializer; +import org.apache.flink.api.common.typeutils.base.StringSerializer; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.core.fs.Path; import org.apache.flink.ml.common.datastream.DataStreamUtils; import org.apache.flink.ml.util.TestUtils; +import org.apache.flink.runtime.state.storage.FileSystemCheckpointStorage; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; -import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.functions.sink.SinkFunction; import org.apache.flink.streaming.api.operators.BoundedOneInput; -import org.apache.flink.streaming.api.operators.OneInputStreamOperator; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.commons.collections.IteratorUtils; @@ -36,132 +39,295 @@ import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.testcontainers.shaded.org.apache.commons.lang3.RandomStringUtils; +import java.io.File; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; -import 
java.util.UUID; /** Tests the {@link SharedObjectsUtils}. */ public class SharedObjectsUtilsTest { - private static final ItemDescriptor SUM = - ItemDescriptor.of("sum", LongSerializer.INSTANCE, 0L); @Rule public TemporaryFolder tempFolder = new TemporaryFolder(); - static SharedObjectsBody.SharedObjectsBodyResult sharedObjectsBody(List> inputs) { - //noinspection unchecked - DataStream data = (DataStream) inputs.get(0); - - AOperator aOp = new AOperator(); - SingleOutputStreamOperator afterAOp = - data.transform("a", TypeInformation.of(Long.class), aOp); - - BOperator bOp = new BOperator(); - SingleOutputStreamOperator afterBOp = - afterAOp.transform("b", TypeInformation.of(Long.class), bOp); - - Map, SharedObjectsStreamOperator> ownerMap = new HashMap<>(); - ownerMap.put(SUM, aOp); - - return new SharedObjectsBody.SharedObjectsBodyResult( - Collections.singletonList(afterBOp), - Arrays.asList(afterAOp.getTransformation(), afterBOp.getTransformation()), - ownerMap); - } - @Test - public void testSharedObjects() throws Exception { + public void testWithDataDeps() throws Exception { StreamExecutionEnvironment env = TestUtils.getExecutionEnvironment(); DataStream data = env.fromSequence(1, 100); List> outputs = SharedObjectsUtils.withSharedObjects( - Collections.singletonList(data), SharedObjectsUtilsTest::sharedObjectsBody); + Collections.singletonList(data), new SharedObjectsBodyWithDataDeps()); //noinspection unchecked DataStream partitionSum = (DataStream) outputs.get(0); - DataStream allSum = DataStreamUtils.reduce(partitionSum, new SumReduceFunction()); + DataStream allSum = + DataStreamUtils.reduce( + partitionSum, new SharedObjectsBodyWithDataDeps.SumReduceFunction()); allSum.getTransformation().setParallelism(1); //noinspection unchecked List results = IteratorUtils.toList(allSum.executeAndCollect()); Assert.assertEquals(Collections.singletonList(5050L), results); } - /** Operator A: add input elements to the shared {@link #SUM}. */ - static class AOperator extends AbstractStreamOperator - implements OneInputStreamOperator, - SharedObjectsStreamOperator, - BoundedOneInput { - - private final String sharedObjectsAccessorID; - private SharedObjectsContext sharedObjectsContext; + @Test + public void testWithoutDataDeps() throws Exception { + StreamExecutionEnvironment env = TestUtils.getExecutionEnvironment(); - public AOperator() { - sharedObjectsAccessorID = getClass().getSimpleName() + "-" + UUID.randomUUID(); + DataStream data = env.fromSequence(1, 100); + List> outputs = + SharedObjectsUtils.withSharedObjects( + Collections.singletonList(data), new SharedObjectsBodyWithoutDataDeps()); + //noinspection unchecked + DataStream added = (DataStream) outputs.get(0); + //noinspection unchecked + List results = IteratorUtils.toList(added.executeAndCollect()); + Collections.sort(results); + List expected = new ArrayList<>(); + for (long i = 1; i <= 100; i += 1) { + expected.add(i + 5050); } + Assert.assertEquals(expected, results); + } + + @Test + public void testPotentialDeadlock() throws Exception { + Configuration configuration = new Configuration(); + StreamExecutionEnvironment env = TestUtils.getExecutionEnvironment(configuration); + File stateFolder = tempFolder.newFolder(); + env.getCheckpointConfig() + .setCheckpointStorage( + new FileSystemCheckpointStorage( + new Path("file://", stateFolder.getPath()))); + final int n = 100; + // Set it to a large value, thus causing a deadlock situation. 
+ final int len = 1 << 20; + DataStream data = + env.fromSequence(1, n).map(d -> RandomStringUtils.randomAlphabetic(len)); + List> outputs = + SharedObjectsUtils.withSharedObjects( + Collections.singletonList(data), new SharedObjectsBodyPotentialDeadlock()); + //noinspection unchecked + DataStream added = (DataStream) outputs.get(0); + added.addSink( + new SinkFunction() { + @Override + public void invoke(String value, Context context) { + Assert.assertEquals(2 * len, value.length()); + } + }); + env.execute(); + } + + static class SharedObjectsBodyWithDataDeps implements SharedObjectsBody { + private static final Descriptor SUM = + Descriptor.of("sum", LongSerializer.INSTANCE, 0L); @Override - public void onSharedObjectsContextSet(SharedObjectsContext context) { - this.sharedObjectsContext = context; + public SharedObjectsBodyResult process(List> inputs) { + //noinspection unchecked + DataStream data = (DataStream) inputs.get(0); + + AOperator aOp = new AOperator(); + SingleOutputStreamOperator afterAOp = + data.transform("a", TypeInformation.of(Long.class), aOp); + + BOperator bOp = new BOperator(); + SingleOutputStreamOperator afterBOp = + afterAOp.transform("b", TypeInformation.of(Long.class), bOp); + + Map, AbstractSharedObjectsStreamOperator> ownerMap = new HashMap<>(); + ownerMap.put(SUM, aOp); + + return new SharedObjectsBodyResult( + Collections.singletonList(afterBOp), + Arrays.asList(afterAOp.getTransformation(), afterBOp.getTransformation()), + ownerMap); } - @Override - public String getSharedObjectsAccessorID() { - return sharedObjectsAccessorID; + /** Operator A: add input elements to the shared {@link #SUM}. */ + static class AOperator extends AbstractSharedObjectsOneInputStreamOperator + implements BoundedOneInput { + + private transient long sum = 0; + + @Override + public void processElement(StreamRecord element) throws Exception { + sum += element.getValue(); + } + + @Override + public void endInput() throws Exception { + context.write(SUM, sum); + // Informs BOperator to get the value from shared {@link #SUM}. + output.collect(new StreamRecord<>(0L)); + } + + @Override + public List> readRequestsInProcessElement() { + return Collections.emptyList(); + } } - @Override - public void processElement(StreamRecord element) throws Exception { - sharedObjectsContext.invoke( - (getter, setter) -> { - Long currentSum = getter.get(SUM); - setter.set(SUM, currentSum + element.getValue()); - }); + /** Operator B: when input ends, get the value from shared {@link #SUM}. */ + static class BOperator extends AbstractSharedObjectsOneInputStreamOperator { + + @Override + public void processElement(StreamRecord element) throws Exception { + output.collect(new StreamRecord<>(context.read(SUM.sameStep()))); + } + + @Override + public List> readRequestsInProcessElement() { + return Collections.singletonList(SUM.sameStep()); + } } - @Override - public void endInput() throws Exception { - // Informs BOperator to get the value from shared {@link #SUM}. - output.collect(new StreamRecord<>(0L)); + static class SumReduceFunction implements ReduceFunction { + @Override + public Long reduce(Long value1, Long value2) { + return value1 + value2; + } } } - /** Operator B: when input ends, get the value from shared {@link #SUM}. 
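         * The signal record emitted by AOperator in {@code endInput()} guarantees that {@code
         * SUM} has been written before this operator reads it with {@code sameStep()}.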
*/ - static class BOperator extends AbstractStreamOperator - implements OneInputStreamOperator, SharedObjectsStreamOperator { + static class SharedObjectsBodyWithoutDataDeps implements SharedObjectsBody { + private static final Descriptor SUM = Descriptor.of("sum", LongSerializer.INSTANCE); - private final String sharedObjectsAccessorID; - private SharedObjectsContext sharedObjectsContext; + @Override + public SharedObjectsBodyResult process(List> inputs) { + //noinspection unchecked + DataStream data = (DataStream) inputs.get(0); + DataStream sum = DataStreamUtils.reduce(data, Long::sum); - public BOperator() { - sharedObjectsAccessorID = getClass().getSimpleName() + "-" + UUID.randomUUID(); - } + COperator cOp = new COperator(); + SingleOutputStreamOperator afterCOp = + sum.broadcast().transform("c", TypeInformation.of(Long.class), cOp); - @Override - public void onSharedObjectsContextSet(SharedObjectsContext context) { - this.sharedObjectsContext = context; + DOperator dOp = new DOperator(); + SingleOutputStreamOperator afterDOp = + data.transform("d", TypeInformation.of(Long.class), dOp); + + Map, AbstractSharedObjectsStreamOperator> ownerMap = new HashMap<>(); + ownerMap.put(SUM, cOp); + + return new SharedObjectsBodyResult( + Collections.singletonList(afterDOp), + Arrays.asList(afterCOp.getTransformation(), afterDOp.getTransformation()), + ownerMap); } - @Override - public String getSharedObjectsAccessorID() { - return sharedObjectsAccessorID; + /** Operator C: set the shared object. */ + static class COperator extends AbstractSharedObjectsOneInputStreamOperator + implements BoundedOneInput { + private transient long sum; + + @Override + public void processElement(StreamRecord element) throws Exception { + sum = element.getValue(); + } + + @Override + public void endInput() throws Exception { + Thread.sleep(2 * 1000); + context.write(SUM, sum); + } + + @Override + public List> readRequestsInProcessElement() { + return Collections.emptyList(); + } } - @Override - public void processElement(StreamRecord element) throws Exception { - sharedObjectsContext.invoke( - (getter, setter) -> { - output.collect(new StreamRecord<>(getter.get(SUM))); - }); + /** Operator D: get the value from shared {@link #SUM}. 
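         * The value is read only once and cached in a field, since COperator writes {@code SUM}
         * exactly once in {@code endInput()}.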
*/ + static class DOperator extends AbstractSharedObjectsOneInputStreamOperator { + + private Long sum; + + @Override + public void processElement(StreamRecord element) throws Exception { + if (null == sum) { + sum = context.read(SUM.sameStep()); + } + output.collect(new StreamRecord<>(sum + element.getValue())); + } + + @Override + public List> readRequestsInProcessElement() { + return Collections.singletonList(SUM.sameStep()); + } } } - static class SumReduceFunction implements ReduceFunction { + static class SharedObjectsBodyPotentialDeadlock implements SharedObjectsBody { + private static final Descriptor LAST = + Descriptor.of("last", StringSerializer.INSTANCE); + @Override - public Long reduce(Long value1, Long value2) { - return value1 + value2; + public SharedObjectsBodyResult process(List> inputs) { + //noinspection unchecked + DataStream data = (DataStream) inputs.get(0); + DataStream sum = DataStreamUtils.reduce(data, (v1, v2) -> v2); + + EOperator eOp = new EOperator(); + SingleOutputStreamOperator afterCOp = + sum.broadcast().transform("e", TypeInformation.of(String.class), eOp); + + FOperator dOp = new FOperator(); + SingleOutputStreamOperator afterDOp = + data.transform("d", TypeInformation.of(String.class), dOp); + + Map, AbstractSharedObjectsStreamOperator> ownerMap = new HashMap<>(); + ownerMap.put(LAST, eOp); + + return new SharedObjectsBodyResult( + Collections.singletonList(afterDOp), + Arrays.asList(afterCOp.getTransformation(), afterDOp.getTransformation()), + ownerMap); + } + + /** Operator E: set the shared object. */ + static class EOperator extends AbstractSharedObjectsOneInputStreamOperator + implements BoundedOneInput { + private transient String last; + + @Override + public void processElement(StreamRecord element) throws Exception { + last = element.getValue(); + } + + @Override + public void endInput() throws Exception { + Thread.sleep(2 * 1000); + context.write(LAST, last); + } + + @Override + public List> readRequestsInProcessElement() { + return Collections.emptyList(); + } + } + + /** Operator F: get the value from shared {@link #LAST}. 
*/ + static class FOperator extends AbstractSharedObjectsOneInputStreamOperator { + + private String last; + + @Override + public void processElement(StreamRecord element) throws Exception { + if (null == last) { + last = context.read(LAST.sameStep()); + } + output.collect(new StreamRecord<>(last + element.getValue())); + } + + @Override + public List> readRequestsInProcessElement() { + return Collections.singletonList(LAST.sameStep()); + } } } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BoostIterationBody.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BoostIterationBody.java index 7326ad17b..34ec8d1f1 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BoostIterationBody.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/BoostIterationBody.java @@ -36,9 +36,9 @@ import org.apache.flink.ml.common.gbt.operators.ReduceSplitsOperator; import org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants; import org.apache.flink.ml.common.gbt.operators.TerminationOperator; -import org.apache.flink.ml.common.sharedobjects.ItemDescriptor; +import org.apache.flink.ml.common.sharedobjects.AbstractSharedObjectsStreamOperator; +import org.apache.flink.ml.common.sharedobjects.Descriptor; import org.apache.flink.ml.common.sharedobjects.SharedObjectsBody; -import org.apache.flink.ml.common.sharedobjects.SharedObjectsStreamOperator; import org.apache.flink.ml.common.sharedobjects.SharedObjectsUtils; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; @@ -68,7 +68,7 @@ private SharedObjectsBody.SharedObjectsBodyResult sharedObjectsBody( //noinspection unchecked DataStream trainContext = (DataStream) inputs.get(1); - Map, SharedObjectsStreamOperator> ownerMap = new HashMap<>(); + Map, AbstractSharedObjectsStreamOperator> ownerMap = new HashMap<>(); CacheDataCalcLocalHistsOperator cacheDataCalcLocalHistsOp = new CacheDataCalcLocalHistsOperator(strategy); @@ -79,7 +79,7 @@ private SharedObjectsBody.SharedObjectsBodyResult sharedObjectsBody( Types.TUPLE( Types.INT, Types.INT, TypeInformation.of(Histogram.class)), cacheDataCalcLocalHistsOp); - for (ItemDescriptor s : SharedObjectsConstants.OWNED_BY_CACHE_DATA_CALC_LOCAL_HISTS_OP) { + for (Descriptor s : SharedObjectsConstants.OWNED_BY_CACHE_DATA_CALC_LOCAL_HISTS_OP) { ownerMap.put(s, cacheDataCalcLocalHistsOp); } @@ -105,7 +105,7 @@ private SharedObjectsBody.SharedObjectsBodyResult sharedObjectsBody( globalSplits .broadcast() .transform("PostSplits", TypeInformation.of(Integer.class), postSplitsOp); - for (ItemDescriptor descriptor : SharedObjectsConstants.OWNED_BY_POST_SPLITS_OP) { + for (Descriptor descriptor : SharedObjectsConstants.OWNED_BY_POST_SPLITS_OP) { ownerMap.put(descriptor, postSplitsOp); } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java index bedaccff5..a36748aa4 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CacheDataCalcLocalHistsOperator.java @@ -30,14 +30,12 @@ import org.apache.flink.ml.common.gbt.defs.TrainContext; import org.apache.flink.ml.common.gbt.typeinfo.BinnedInstanceSerializer; import org.apache.flink.ml.common.lossfunc.LossFunc; -import 
org.apache.flink.ml.common.sharedobjects.SharedObjectsContext; -import org.apache.flink.ml.common.sharedobjects.SharedObjectsStreamOperator; +import org.apache.flink.ml.common.sharedobjects.AbstractSharedObjectsTwoInputStreamOperator; +import org.apache.flink.ml.common.sharedobjects.ReadRequest; import org.apache.flink.ml.linalg.SparseVector; import org.apache.flink.ml.linalg.Vector; import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.StateSnapshotContext; -import org.apache.flink.streaming.api.operators.AbstractStreamOperator; -import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.types.Row; import org.apache.flink.util.Collector; @@ -48,7 +46,18 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; -import java.util.UUID; + +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.ALL_TREES; +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.HAS_INITED_TREE; +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.INSTANCES; +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.LAYER; +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.NEED_INIT_TREE; +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.NODE_FEATURE_PAIRS; +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.PREDS_GRADS_HESSIANS; +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.ROOT_LEARNING_NODE; +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.SHUFFLED_INDICES; +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.SWAPPED_INDICES; +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.TRAIN_CONTEXT; /** * Calculates local histograms for local data partition. @@ -58,29 +67,26 @@ * of (subtask index, (nodeId, featureId) pair index, Histogram). */ public class CacheDataCalcLocalHistsOperator - extends AbstractStreamOperator> - implements TwoInputStreamOperator>, - IterationListener>, - SharedObjectsStreamOperator { + extends AbstractSharedObjectsTwoInputStreamOperator< + Row, TrainContext, Tuple3> + implements IterationListener> { private static final String TREE_INITIALIZER_STATE_NAME = "tree_initializer"; private static final String HIST_BUILDER_STATE_NAME = "hist_builder"; private final BoostingStrategy strategy; - private final String sharedObjectsAccessorID; // States of local data. 
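+    // The raw train context arrives on the second input; in the first round it is specialized
+    // per subtask, and the ListStateWithCache fields below keep the cached data recoverable
+    // across failovers.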
+ private transient TrainContext rawTrainContext; private transient ListStateWithCache instancesCollecting; private transient ListStateWithCache treeInitializerState; private transient TreeInitializer treeInitializer; private transient ListStateWithCache histBuilderState; private transient HistBuilder histBuilder; - private transient SharedObjectsContext sharedObjectsContext; public CacheDataCalcLocalHistsOperator(BoostingStrategy strategy) { super(); this.strategy = strategy; - sharedObjectsAccessorID = getClass().getSimpleName() + "-" + UUID.randomUUID(); } @Override @@ -147,112 +153,111 @@ public void processElement1(StreamRecord streamRecord) throws Exception { } @Override - public void processElement2(StreamRecord streamRecord) throws Exception { - TrainContext rawTrainContext = streamRecord.getValue(); - sharedObjectsContext.invoke( - (getter, setter) -> - setter.set(SharedObjectsConstants.TRAIN_CONTEXT, rawTrainContext)); + public List> readRequestsInProcessElement1() { + return Collections.emptyList(); + } + + @Override + public void processElement2(StreamRecord streamRecord) { + rawTrainContext = streamRecord.getValue(); + } + + @Override + public List> readRequestsInProcessElement2() { + return Collections.emptyList(); } public void onEpochWatermarkIncremented( - int epochWatermark, Context context, Collector> out) + int epochWatermark, Context c, Collector> out) throws Exception { if (0 == epochWatermark) { // Initializes local state in first round. - sharedObjectsContext.invoke( - (getter, setter) -> { - BinnedInstance[] instances = - (BinnedInstance[]) - IteratorUtils.toArray( - instancesCollecting.get().iterator(), - BinnedInstance.class); - setter.set(SharedObjectsConstants.INSTANCES, instances); - instancesCollecting.clear(); + BinnedInstance[] instances = + (BinnedInstance[]) + IteratorUtils.toArray( + instancesCollecting.get().iterator(), BinnedInstance.class); + context.write(INSTANCES, instances); + instancesCollecting.clear(); + + TrainContext trainContext = + new TrainContextInitializer(strategy) + .init( + rawTrainContext, + getRuntimeContext().getIndexOfThisSubtask(), + getRuntimeContext().getNumberOfParallelSubtasks(), + instances); + context.write(TRAIN_CONTEXT, trainContext); - TrainContext rawTrainContext = - getter.get(SharedObjectsConstants.TRAIN_CONTEXT); - TrainContext trainContext = - new TrainContextInitializer(strategy) - .init( - rawTrainContext, - getRuntimeContext().getIndexOfThisSubtask(), - getRuntimeContext().getNumberOfParallelSubtasks(), - instances); - setter.set(SharedObjectsConstants.TRAIN_CONTEXT, trainContext); + treeInitializer = new TreeInitializer(trainContext); + treeInitializerState.update(Collections.singletonList(treeInitializer)); + histBuilder = new HistBuilder(trainContext); + histBuilderState.update(Collections.singletonList(histBuilder)); - treeInitializer = new TreeInitializer(trainContext); - treeInitializerState.update(Collections.singletonList(treeInitializer)); - histBuilder = new HistBuilder(trainContext); - histBuilderState.update(Collections.singletonList(histBuilder)); - }); + } else { + context.renew(TRAIN_CONTEXT); + context.renew(INSTANCES); } - sharedObjectsContext.invoke( - (getter, setter) -> { - TrainContext trainContext = getter.get(SharedObjectsConstants.TRAIN_CONTEXT); - Preconditions.checkArgument( - getRuntimeContext().getIndexOfThisSubtask() == trainContext.subtaskId); - BinnedInstance[] instances = getter.get(SharedObjectsConstants.INSTANCES); - double[] pgh = 
getter.get(SharedObjectsConstants.PREDS_GRADS_HESSIANS); - // In the first round, use prior as the predictions. - if (0 == pgh.length) { - pgh = new double[instances.length * 3]; - double prior = trainContext.prior; - LossFunc loss = trainContext.loss; - for (int i = 0; i < instances.length; i += 1) { - double label = instances[i].label; - pgh[3 * i] = prior; - pgh[3 * i + 1] = loss.gradient(prior, label); - pgh[3 * i + 2] = loss.hessian(prior, label); - } - } + TrainContext trainContext = context.read(TRAIN_CONTEXT.sameStep()); + Preconditions.checkArgument( + getRuntimeContext().getIndexOfThisSubtask() == trainContext.subtaskId); + BinnedInstance[] instances = context.read(INSTANCES.sameStep()); - boolean needInitTree = getter.get(SharedObjectsConstants.NEED_INIT_TREE); - int[] indices; - List layer; - if (needInitTree) { - // When last tree is finished, initializes a new tree, and shuffle instance - // indices. - treeInitializer.init( - getter.get(SharedObjectsConstants.ALL_TREES).size(), - d -> setter.set(SharedObjectsConstants.SHUFFLED_INDICES, d)); - LearningNode rootLearningNode = treeInitializer.getRootLearningNode(); - indices = getter.get(SharedObjectsConstants.SHUFFLED_INDICES); - layer = Collections.singletonList(rootLearningNode); - setter.set(SharedObjectsConstants.ROOT_LEARNING_NODE, rootLearningNode); - setter.set(SharedObjectsConstants.HAS_INITED_TREE, true); - } else { - // Otherwise, uses the swapped instance indices. - indices = getter.get(SharedObjectsConstants.SWAPPED_INDICES); - layer = getter.get(SharedObjectsConstants.LAYER); - setter.set(SharedObjectsConstants.SHUFFLED_INDICES, new int[0]); - setter.set(SharedObjectsConstants.HAS_INITED_TREE, false); - } + double[] pgh = new double[0]; + boolean needInitTree = true; + int numTrees = 0; + if (epochWatermark > 0) { + pgh = context.read(PREDS_GRADS_HESSIANS.prevStep()); + needInitTree = context.read(NEED_INIT_TREE.prevStep()); + numTrees = context.read(ALL_TREES.prevStep()).size(); + } + // In the first round, use prior as the predictions. + if (0 == pgh.length) { + pgh = new double[instances.length * 3]; + double prior = trainContext.prior; + LossFunc loss = trainContext.loss; + for (int i = 0; i < instances.length; i += 1) { + double label = instances[i].label; + pgh[3 * i] = prior; + pgh[3 * i + 1] = loss.gradient(prior, label); + pgh[3 * i + 2] = loss.hessian(prior, label); + } + } - histBuilder.build( - layer, - indices, - instances, - pgh, - d -> setter.set(SharedObjectsConstants.NODE_FEATURE_PAIRS, d), - out); - }); + int[] indices; + List layer; + if (needInitTree) { + // When last tree is finished, initializes a new tree, and shuffle instance + // indices. + treeInitializer.init(numTrees, d -> context.write(SHUFFLED_INDICES, d)); + LearningNode rootLearningNode = treeInitializer.getRootLearningNode(); + indices = context.read(SHUFFLED_INDICES.sameStep()); + layer = Collections.singletonList(rootLearningNode); + context.write(ROOT_LEARNING_NODE, rootLearningNode); + context.write(HAS_INITED_TREE, true); + } else { + // Otherwise, uses the swapped instance indices. 
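+            // prevStep() reads the values written by PostSplitsOperator in the previous round;
+            // this operator runs earlier in the per-round dataflow, so the current round's
+            // values do not exist yet.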
+ indices = context.read(SWAPPED_INDICES.prevStep()); + layer = context.read(LAYER.prevStep()); + context.write(SHUFFLED_INDICES, new int[0]); + context.write(HAS_INITED_TREE, false); + context.renew(ROOT_LEARNING_NODE); + } + + histBuilder.build( + layer, indices, instances, pgh, d -> context.write(NODE_FEATURE_PAIRS, d), out); } @Override public void onIterationTerminated( - Context context, Collector> collector) - throws Exception { + Context c, Collector> collector) { instancesCollecting.clear(); treeInitializerState.clear(); histBuilderState.clear(); - sharedObjectsContext.invoke( - (getter, setter) -> { - setter.set(SharedObjectsConstants.INSTANCES, new BinnedInstance[0]); - setter.set(SharedObjectsConstants.SHUFFLED_INDICES, new int[0]); - setter.set(SharedObjectsConstants.NODE_FEATURE_PAIRS, new int[0]); - }); + context.write(INSTANCES, new BinnedInstance[0]); + context.write(SHUFFLED_INDICES, new int[0]); + context.write(NODE_FEATURE_PAIRS, new int[0]); } @Override @@ -262,14 +267,4 @@ public void close() throws Exception { histBuilderState.clear(); super.close(); } - - @Override - public void onSharedObjectsContextSet(SharedObjectsContext context) { - this.sharedObjectsContext = context; - } - - @Override - public String getSharedObjectsAccessorID() { - return sharedObjectsAccessorID; - } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CalcLocalSplitsOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CalcLocalSplitsOperator.java index b55b1f87d..9a33b95dc 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CalcLocalSplitsOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/CalcLocalSplitsOperator.java @@ -26,20 +26,24 @@ import org.apache.flink.ml.common.gbt.defs.Histogram; import org.apache.flink.ml.common.gbt.defs.LearningNode; import org.apache.flink.ml.common.gbt.defs.Split; -import org.apache.flink.ml.common.sharedobjects.SharedObjectsContext; -import org.apache.flink.ml.common.sharedobjects.SharedObjectsStreamOperator; +import org.apache.flink.ml.common.sharedobjects.AbstractSharedObjectsOneInputStreamOperator; +import org.apache.flink.ml.common.sharedobjects.ReadRequest; import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.StateSnapshotContext; -import org.apache.flink.streaming.api.operators.AbstractStreamOperator; -import org.apache.flink.streaming.api.operators.OneInputStreamOperator; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.Arrays; import java.util.Collections; import java.util.List; -import java.util.UUID; + +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.LAYER; +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.LEAVES; +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.NODE_FEATURE_PAIRS; +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.ROOT_LEARNING_NODE; +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.TRAIN_CONTEXT; /** * Calculates best splits from histograms for (nodeId, featureId) pairs. @@ -47,22 +51,15 @@ *

The input elements are tuples of ((nodeId, featureId) pair index, Histogram). The output * elements are tuples of (node index, (nodeId, featureId) pair index, Split). */ -public class CalcLocalSplitsOperator extends AbstractStreamOperator> - implements OneInputStreamOperator< - Tuple2, Tuple3>, - SharedObjectsStreamOperator { +public class CalcLocalSplitsOperator + extends AbstractSharedObjectsOneInputStreamOperator< + Tuple2, Tuple3> { private static final Logger LOG = LoggerFactory.getLogger(CalcLocalSplitsOperator.class); private static final String SPLIT_FINDER_STATE_NAME = "split_finder"; - private final String sharedObjectsAccessorID; // States of local data. private transient ListStateWithCache splitFinderState; private transient SplitFinder splitFinder; - private transient SharedObjectsContext sharedObjectsContext; - - public CalcLocalSplitsOperator() { - sharedObjectsAccessorID = getClass().getSimpleName() + "-" + UUID.randomUUID(); - } @Override public void initializeState(StateInitializationContext context) throws Exception { @@ -88,56 +85,45 @@ public void snapshotState(StateSnapshotContext context) throws Exception { @Override public void processElement(StreamRecord> element) throws Exception { if (null == splitFinder) { - sharedObjectsContext.invoke( - (getter, setter) -> { - splitFinder = - new SplitFinder(getter.get(SharedObjectsConstants.TRAIN_CONTEXT)); - splitFinderState.update(Collections.singletonList(splitFinder)); - }); + splitFinder = new SplitFinder(context.read(TRAIN_CONTEXT.nextStep())); + splitFinderState.update(Collections.singletonList(splitFinder)); } Tuple2 value = element.getValue(); int pairId = value.f0; Histogram histogram = value.f1; LOG.debug("Received histogram for pairId: {}", pairId); - sharedObjectsContext.invoke( - (getter, setter) -> { - List layer = getter.get(SharedObjectsConstants.LAYER); - if (layer.size() == 0) { - layer = - Collections.singletonList( - getter.get(SharedObjectsConstants.ROOT_LEARNING_NODE)); - } - - int[] nodeFeaturePairs = getter.get(SharedObjectsConstants.NODE_FEATURE_PAIRS); - int nodeId = nodeFeaturePairs[2 * pairId]; - int featureId = nodeFeaturePairs[2 * pairId + 1]; - LearningNode node = layer.get(nodeId); - - Split bestSplit = - splitFinder.calc( - node, - featureId, - getter.get(SharedObjectsConstants.LEAVES).size(), - histogram); - output.collect(new StreamRecord<>(Tuple3.of(nodeId, pairId, bestSplit))); - }); - LOG.debug("Output split for pairId: {}", pairId); - } - @Override - public void close() throws Exception { - super.close(); - splitFinderState.clear(); + List layer = context.read(LAYER.sameStep()); + if (layer.isEmpty()) { + layer = Collections.singletonList(context.read(ROOT_LEARNING_NODE.nextStep())); + } + + int[] nodeFeaturePairs = context.read(NODE_FEATURE_PAIRS.nextStep()); + int nodeId = nodeFeaturePairs[2 * pairId]; + int featureId = nodeFeaturePairs[2 * pairId + 1]; + LearningNode node = layer.get(nodeId); + + Split bestSplit = + splitFinder.calc( + node, featureId, context.read(LEAVES.sameStep()).size(), histogram); + output.collect(new StreamRecord<>(Tuple3.of(nodeId, pairId, bestSplit))); + LOG.debug("Output split for pairId: {}", pairId); } @Override - public void onSharedObjectsContextSet(SharedObjectsContext context) { - this.sharedObjectsContext = context; + public List> readRequestsInProcessElement() { + return Arrays.asList( + TRAIN_CONTEXT.nextStep(), + LAYER.sameStep(), + ROOT_LEARNING_NODE.nextStep(), + NODE_FEATURE_PAIRS.nextStep(), + LEAVES.sameStep()); } @Override - public String 
getSharedObjectsAccessorID() { - return sharedObjectsAccessorID; + public void close() throws Exception { + super.close(); + splitFinderState.clear(); } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/PostSplitsOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/PostSplitsOperator.java index a17b95ce6..d8a0b909c 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/PostSplitsOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/PostSplitsOperator.java @@ -28,12 +28,10 @@ import org.apache.flink.ml.common.gbt.defs.Node; import org.apache.flink.ml.common.gbt.defs.Split; import org.apache.flink.ml.common.gbt.defs.TrainContext; -import org.apache.flink.ml.common.sharedobjects.SharedObjectsContext; -import org.apache.flink.ml.common.sharedobjects.SharedObjectsStreamOperator; +import org.apache.flink.ml.common.sharedobjects.AbstractSharedObjectsOneInputStreamOperator; +import org.apache.flink.ml.common.sharedobjects.ReadRequest; import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.StateSnapshotContext; -import org.apache.flink.streaming.api.operators.AbstractStreamOperator; -import org.apache.flink.streaming.api.operators.OneInputStreamOperator; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.util.Collector; @@ -43,35 +41,38 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; -import java.util.UUID; + +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.ALL_TREES; +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.CURRENT_TREE_NODES; +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.INSTANCES; +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.LAYER; +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.LEAVES; +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.NEED_INIT_TREE; +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.PREDS_GRADS_HESSIANS; +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.ROOT_LEARNING_NODE; +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.SHUFFLED_INDICES; +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.SWAPPED_INDICES; +import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.TRAIN_CONTEXT; /** * Post-process after global splits obtained, including split instances to left or child nodes, and * update instances scores after a tree is complete. */ -public class PostSplitsOperator extends AbstractStreamOperator - implements OneInputStreamOperator, Integer>, - IterationListener, - SharedObjectsStreamOperator { +public class PostSplitsOperator + extends AbstractSharedObjectsOneInputStreamOperator, Integer> + implements IterationListener { private static final String NODE_SPLITTER_STATE_NAME = "node_splitter"; private static final String INSTANCE_UPDATER_STATE_NAME = "instance_updater"; private static final Logger LOG = LoggerFactory.getLogger(PostSplitsOperator.class); - private final String sharedObjectsAccessorID; - // States of local data. 
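    // Best splits for the nodes of the current layer, collected in processElement and consumed
    // when the epoch watermark increments.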
     private transient Split[] nodeSplits;
     private transient ListStateWithCache<NodeSplitter> nodeSplitterState;
     private transient NodeSplitter nodeSplitter;
     private transient ListStateWithCache<InstanceUpdater> instanceUpdaterState;
     private transient InstanceUpdater instanceUpdater;
-    private transient SharedObjectsContext sharedObjectsContext;
-
-    public PostSplitsOperator() {
-        sharedObjectsAccessorID = getClass().getSimpleName() + "-" + UUID.randomUUID();
-    }
 
     @Override
     public void initializeState(StateInitializationContext context) throws Exception {
@@ -109,100 +110,84 @@ public void snapshotState(StateSnapshotContext context) throws Exception {
 
     @Override
     public void onEpochWatermarkIncremented(
-            int epochWatermark, Context context, Collector<Integer> collector) throws Exception {
+            int epochWatermark, Context c, Collector<Integer> collector) throws Exception {
         if (0 == epochWatermark) {
-            sharedObjectsContext.invoke(
-                    (getter, setter) -> {
-                        TrainContext trainContext =
-                                getter.get(SharedObjectsConstants.TRAIN_CONTEXT);
-                        nodeSplitter = new NodeSplitter(trainContext);
-                        nodeSplitterState.update(Collections.singletonList(nodeSplitter));
-                        instanceUpdater = new InstanceUpdater(trainContext);
-                        instanceUpdaterState.update(Collections.singletonList(instanceUpdater));
-                    });
+            TrainContext trainContext = context.read(TRAIN_CONTEXT.sameStep());
+            nodeSplitter = new NodeSplitter(trainContext);
+            nodeSplitterState.update(Collections.singletonList(nodeSplitter));
+            instanceUpdater = new InstanceUpdater(trainContext);
+            instanceUpdaterState.update(Collections.singletonList(instanceUpdater));
+        }
+
+        int[] indices = new int[0];
+        if (epochWatermark > 0) {
+            indices = context.read(SWAPPED_INDICES.prevStep());
+        }
+        if (0 == indices.length) {
+            indices = context.read(SHUFFLED_INDICES.sameStep()).clone();
         }
-        sharedObjectsContext.invoke(
-                (getter, setter) -> {
-                    int[] indices = getter.get(SharedObjectsConstants.SWAPPED_INDICES);
-                    if (0 == indices.length) {
-                        indices = getter.get(SharedObjectsConstants.SHUFFLED_INDICES).clone();
-                    }
-
-                    BinnedInstance[] instances = getter.get(SharedObjectsConstants.INSTANCES);
-                    List<LearningNode> leaves = getter.get(SharedObjectsConstants.LEAVES);
-                    List<LearningNode> layer = getter.get(SharedObjectsConstants.LAYER);
-                    List<Node> currentTreeNodes;
-                    if (layer.size() == 0) {
-                        layer =
-                                Collections.singletonList(
-                                        getter.get(SharedObjectsConstants.ROOT_LEARNING_NODE));
-                        currentTreeNodes = new ArrayList<>();
-                        currentTreeNodes.add(new Node());
-                    } else {
-                        currentTreeNodes = getter.get(SharedObjectsConstants.CURRENT_TREE_NODES);
-                    }
-
-                    List<LearningNode> nextLayer =
-                            nodeSplitter.split(
-                                    currentTreeNodes,
-                                    layer,
-                                    leaves,
-                                    nodeSplits,
-                                    indices,
-                                    instances);
-                    nodeSplits = null;
-                    setter.set(SharedObjectsConstants.LEAVES, leaves);
-                    setter.set(SharedObjectsConstants.LAYER, nextLayer);
-                    setter.set(SharedObjectsConstants.CURRENT_TREE_NODES, currentTreeNodes);
-
-                    if (nextLayer.isEmpty()) {
-                        // Current tree is finished.
-                        setter.set(SharedObjectsConstants.NEED_INIT_TREE, true);
-                        instanceUpdater.update(
-                                getter.get(SharedObjectsConstants.PREDS_GRADS_HESSIANS),
-                                leaves,
-                                indices,
-                                instances,
-                                d -> setter.set(SharedObjectsConstants.PREDS_GRADS_HESSIANS, d),
-                                currentTreeNodes);
-                        leaves.clear();
-                        List<List<Node>> allTrees = getter.get(SharedObjectsConstants.ALL_TREES);
-                        allTrees.add(currentTreeNodes);
-
-                        setter.set(SharedObjectsConstants.LEAVES, new ArrayList<>());
-                        setter.set(SharedObjectsConstants.SWAPPED_INDICES, new int[0]);
-                        setter.set(SharedObjectsConstants.ALL_TREES, allTrees);
-                        LOG.info("finalize {}-th tree", allTrees.size());
-                    } else {
-                        setter.set(SharedObjectsConstants.SWAPPED_INDICES, indices);
-                        setter.set(SharedObjectsConstants.NEED_INIT_TREE, false);
-                    }
-                });
+        BinnedInstance[] instances = context.read(INSTANCES.sameStep());
+        List<LearningNode> leaves = context.read(LEAVES.prevStep());
+        List<LearningNode> layer = context.read(LAYER.prevStep());
+        List<Node> currentTreeNodes;
+        if (layer.isEmpty()) {
+            layer = Collections.singletonList(context.read(ROOT_LEARNING_NODE.sameStep()));
+            currentTreeNodes = new ArrayList<>();
+            currentTreeNodes.add(new Node());
+        } else {
+            currentTreeNodes = context.read(CURRENT_TREE_NODES.prevStep());
+        }
+
+        List<LearningNode> nextLayer =
+                nodeSplitter.split(currentTreeNodes, layer, leaves, nodeSplits, indices, instances);
+        nodeSplits = null;
+        context.write(LEAVES, leaves);
+        context.write(LAYER, nextLayer);
+        context.write(CURRENT_TREE_NODES, currentTreeNodes);
+
+        if (nextLayer.isEmpty()) {
+            // Current tree is finished.
+            context.write(NEED_INIT_TREE, true);
+            instanceUpdater.update(
+                    context.read(PREDS_GRADS_HESSIANS.prevStep()),
+                    leaves,
+                    indices,
+                    instances,
+                    d -> context.write(PREDS_GRADS_HESSIANS, d),
+                    currentTreeNodes);
+            leaves.clear();
+            List<List<Node>> allTrees = context.read(ALL_TREES.prevStep());
+            allTrees.add(currentTreeNodes);
+
+            context.write(LEAVES, new ArrayList<>());
+            context.write(SWAPPED_INDICES, new int[0]);
+            context.write(ALL_TREES, allTrees);
+            LOG.info("finalize {}-th tree", allTrees.size());
+        } else {
+            context.write(SWAPPED_INDICES, indices);
+            context.write(NEED_INIT_TREE, false);
+
+            context.renew(PREDS_GRADS_HESSIANS);
+            context.renew(ALL_TREES);
+        }
     }
 
     @Override
-    public void onIterationTerminated(Context context, Collector<Integer> collector)
-            throws Exception {
-        sharedObjectsContext.invoke(
-                (getter, setter) -> {
-                    setter.set(SharedObjectsConstants.PREDS_GRADS_HESSIANS, new double[0]);
-                    setter.set(SharedObjectsConstants.SWAPPED_INDICES, new int[0]);
-                    setter.set(SharedObjectsConstants.LEAVES, Collections.emptyList());
-                    setter.set(SharedObjectsConstants.LAYER, Collections.emptyList());
-                    setter.set(SharedObjectsConstants.CURRENT_TREE_NODES, Collections.emptyList());
-                });
+    public void onIterationTerminated(Context c, Collector<Integer> collector) {
+        context.write(PREDS_GRADS_HESSIANS, new double[0]);
+        context.write(SWAPPED_INDICES, new int[0]);
+        context.write(LEAVES, Collections.emptyList());
+        context.write(LAYER, Collections.emptyList());
+        context.write(CURRENT_TREE_NODES, Collections.emptyList());
     }
 
     @Override
     public void processElement(StreamRecord<Tuple2<Integer, Split>> element) throws Exception {
         if (null == nodeSplits) {
-            sharedObjectsContext.invoke(
-                    (getter, setter) -> {
-                        List<LearningNode> layer = getter.get(SharedObjectsConstants.LAYER);
-                        int numNodes = (layer.size() == 0) ? 1 : layer.size();
-                        nodeSplits = new Split[numNodes];
-                    });
+            List<LearningNode> layer = context.read(LAYER.sameStep());
+            int numNodes = (layer.isEmpty()) ? 1 : layer.size();
+            nodeSplits = new Split[numNodes];
         }
         Tuple2<Integer, Split> value = element.getValue();
         int nodeId = value.f0;
@@ -211,20 +196,15 @@ public void processElement(StreamRecord<Tuple2<Integer, Split>> element) throws
         nodeSplits[nodeId] = split;
     }
 
+    @Override
+    public List<ReadRequest<?>> readRequestsInProcessElement() {
+        return Collections.singletonList(LAYER.sameStep());
+    }
+
     @Override
     public void close() throws Exception {
         nodeSplitterState.clear();
         instanceUpdaterState.clear();
         super.close();
     }
-
-    @Override
-    public void onSharedObjectsContextSet(SharedObjectsContext context) {
-        sharedObjectsContext = context;
-    }
-
-    @Override
-    public String getSharedObjectsAccessorID() {
-        return sharedObjectsAccessorID;
-    }
 }
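
The sameStep()/prevStep() qualifiers above replace the old getter/setter callbacks with reads that are explicit about which iteration step a value belongs to, and context.renew(...) carries a value forward into the current step unchanged. A minimal, self-contained sketch of the assumed contract follows; the classes below are toy stand-ins, not the actual org.apache.flink.ml.common.sharedobjects API.

    import java.util.HashMap;
    import java.util.Map;

    public class StepScopedStoreSketch {
        // Values are versioned by the epoch (iteration step) in which they were written.
        private final Map<String, Map<Integer, Object>> store = new HashMap<>();

        public void write(String name, int epoch, Object value) {
            store.computeIfAbsent(name, k -> new HashMap<>()).put(epoch, value);
        }

        // read(NAME.sameStep()) corresponds to the value written in the current epoch.
        public Object readSameStep(String name, int epoch) {
            return store.getOrDefault(name, new HashMap<>()).get(epoch);
        }

        // read(NAME.prevStep()) corresponds to the value written one epoch earlier.
        public Object readPrevStep(String name, int epoch) {
            return store.getOrDefault(name, new HashMap<>()).get(epoch - 1);
        }

        // renew(NAME) forwards the previous epoch's value, which is how
        // PostSplitsOperator keeps PREDS_GRADS_HESSIANS and ALL_TREES alive on
        // epochs where it does not rewrite them.
        public void renew(String name, int epoch) {
            write(name, epoch, readPrevStep(name, epoch));
        }

        public static void main(String[] args) {
            StepScopedStoreSketch ctx = new StepScopedStoreSketch();
            ctx.write("all_trees", 0, "trees@0");
            ctx.renew("all_trees", 1);
            System.out.println(ctx.readPrevStep("all_trees", 1)); // trees@0
            System.out.println(ctx.readSameStep("all_trees", 1)); // trees@0 (renewed)
        }
    }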

diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/ReduceSplitsOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/ReduceSplitsOperator.java
index 686733246..289b68c8a 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/ReduceSplitsOperator.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/ReduceSplitsOperator.java
@@ -21,11 +21,9 @@
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.tuple.Tuple3;
 import org.apache.flink.ml.common.gbt.defs.Split;
-import org.apache.flink.ml.common.sharedobjects.SharedObjectsContext;
-import org.apache.flink.ml.common.sharedobjects.SharedObjectsStreamOperator;
+import org.apache.flink.ml.common.sharedobjects.AbstractSharedObjectsOneInputStreamOperator;
+import org.apache.flink.ml.common.sharedobjects.ReadRequest;
 import org.apache.flink.runtime.state.StateInitializationContext;
-import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
-import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.util.Preconditions;
 
@@ -33,9 +31,12 @@
 import org.slf4j.LoggerFactory;
 
 import java.util.BitSet;
+import java.util.Collections;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
-import java.util.UUID;
+
+import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.NODE_FEATURE_PAIRS;
 
 /**
  * Reduces best splits for nodes.
@@ -43,34 +44,16 @@
  *
  * <p>The input elements are tuples of (node index, (nodeId, featureId) pair index, Split). The
  * output elements are tuples of (node index, Split).
  */
-public class ReduceSplitsOperator extends AbstractStreamOperator<Tuple2<Integer, Split>>
-        implements OneInputStreamOperator<Tuple3<Integer, Integer, Split>, Tuple2<Integer, Split>>,
-                SharedObjectsStreamOperator {
+public class ReduceSplitsOperator
+        extends AbstractSharedObjectsOneInputStreamOperator<
+                Tuple3<Integer, Integer, Split>, Tuple2<Integer, Split>> {
 
     private static final Logger LOG = LoggerFactory.getLogger(ReduceSplitsOperator.class);
 
-    private final String sharedObjectsAccessorID;
-
-    private transient SharedObjectsContext sharedObjectsContext;
-
     private Map<Integer, BitSet> nodeFeatureMap;
     private Map<Integer, Split> nodeBestSplit;
     private Map<Integer, Integer> nodeFeatureCounter;
 
-    public ReduceSplitsOperator() {
-        sharedObjectsAccessorID = getClass().getSimpleName() + "-" + UUID.randomUUID();
-    }
-
-    @Override
-    public void onSharedObjectsContextSet(SharedObjectsContext context) {
-        sharedObjectsContext = context;
-    }
-
-    @Override
-    public String getSharedObjectsAccessorID() {
-        return sharedObjectsAccessorID;
-    }
-
     @Override
     public void initializeState(StateInitializationContext context) throws Exception {
         nodeFeatureMap = new HashMap<>();
@@ -84,15 +67,11 @@ public void processElement(StreamRecord<Tuple3<Integer, Integer, Split>> element
         if (nodeFeatureMap.isEmpty()) {
             Preconditions.checkState(nodeBestSplit.isEmpty());
             nodeFeatureCounter.clear();
-            sharedObjectsContext.invoke(
-                    (getter, setter) -> {
-                        int[] nodeFeaturePairs =
-                                getter.get(SharedObjectsConstants.NODE_FEATURE_PAIRS);
-                        for (int i = 0; i < nodeFeaturePairs.length / 2; i += 1) {
-                            int nodeId = nodeFeaturePairs[2 * i];
-                            nodeFeatureCounter.compute(nodeId, (k, v) -> null == v ? 1 : v + 1);
-                        }
-                    });
+            int[] nodeFeaturePairs = context.read(NODE_FEATURE_PAIRS.nextStep());
+            for (int i = 0; i < nodeFeaturePairs.length / 2; i += 1) {
+                int nodeId = nodeFeaturePairs[2 * i];
+                nodeFeatureCounter.compute(nodeId, (k, v) -> null == v ? 1 : v + 1);
+            }
         }
 
         Tuple3<Integer, Integer, Split> value = element.getValue();
@@ -103,14 +82,11 @@ public void processElement(StreamRecord<Tuple3<Integer, Integer, Split>> element
         if (featureMap.isEmpty()) {
             LOG.debug("Received split for new node {}", nodeId);
         }
-        sharedObjectsContext.invoke(
-                (getter, setter) -> {
-                    int[] nodeFeaturePairs = getter.get(SharedObjectsConstants.NODE_FEATURE_PAIRS);
-                    Preconditions.checkState(nodeId == nodeFeaturePairs[pairId * 2]);
-                    int featureId = nodeFeaturePairs[pairId * 2 + 1];
-                    Preconditions.checkState(!featureMap.get(featureId));
-                    featureMap.set(featureId);
-                });
+        int[] nodeFeaturePairs = context.read(NODE_FEATURE_PAIRS.nextStep());
+        Preconditions.checkState(nodeId == nodeFeaturePairs[pairId * 2]);
+        int featureId = nodeFeaturePairs[pairId * 2 + 1];
+        Preconditions.checkState(!featureMap.get(featureId));
+        featureMap.set(featureId);
 
         nodeFeatureMap.put(nodeId, featureMap);
         nodeBestSplit.compute(nodeId, (k, v) -> null == v ? split : v.accumulate(split));
@@ -124,7 +100,7 @@ public void processElement(StreamRecord<Tuple3<Integer, Integer, Split>> element
     }
 
     @Override
-    public void close() throws Exception {
-        super.close();
+    public List<ReadRequest<?>> readRequestsInProcessElement() {
+        return Collections.singletonList(NODE_FEATURE_PAIRS.nextStep());
     }
 }
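
ReduceSplitsOperator folds per-(node, feature) candidate splits into one best split per node via Split.accumulate. A self-contained sketch of that reduction pattern follows; the Split type below is a toy stand-in, and accumulate is assumed to keep the higher-gain candidate, which this patch does not itself show.

    import java.util.HashMap;
    import java.util.Map;

    public class ReduceBestSplitSketch {
        static class Split {
            final double gain;

            Split(double gain) {
                this.gain = gain;
            }

            // Assumed semantics: keep the candidate with the larger gain.
            Split accumulate(Split other) {
                return other.gain > gain ? other : this;
            }
        }

        public static void main(String[] args) {
            Map<Integer, Split> nodeBestSplit = new HashMap<>();
            // Arrivals of (nodeId, gain), as emitted by the histogram-based split finders.
            int[][] arrivals = {{0, 3}, {1, 5}, {0, 7}, {1, 2}};
            for (int[] a : arrivals) {
                int nodeId = a[0];
                Split split = new Split(a[1]);
                // Same fold as in processElement above.
                nodeBestSplit.compute(nodeId, (k, v) -> null == v ? split : v.accumulate(split));
            }
            nodeBestSplit.forEach(
                    (n, s) -> System.out.println("node " + n + " best gain " + s.gain));
            // node 0 best gain 7.0, node 1 best gain 5.0
        }
    }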

diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SharedObjectsConstants.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SharedObjectsConstants.java
index 1c5d4e8fd..1359fb4a8 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SharedObjectsConstants.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/SharedObjectsConstants.java
@@ -33,7 +33,7 @@
 import org.apache.flink.ml.common.gbt.typeinfo.BinnedInstanceSerializer;
 import org.apache.flink.ml.common.gbt.typeinfo.LearningNodeSerializer;
 import org.apache.flink.ml.common.gbt.typeinfo.NodeSerializer;
-import org.apache.flink.ml.common.sharedobjects.ItemDescriptor;
+import org.apache.flink.ml.common.sharedobjects.Descriptor;
 import org.apache.flink.ml.common.sharedobjects.SharedObjectsUtils;
 import org.apache.flink.ml.linalg.typeinfo.OptimizedDoublePrimitiveArraySerializer;
 
@@ -48,85 +48,79 @@
  * operators within one JVM to reduce memory footprint and communication cost. We use {@link
  * SharedObjectsUtils} with co-location mechanism to achieve such purpose.
  *
- * <p>All shared data items have corresponding {@link ItemDescriptor}s, and can be read/written
- * through {@link ItemDescriptor}s from different operator subtasks. Note that every shared item has
- * an owner, and the owner can set new values and snapshot the item.
+ * <p>All shared objects have corresponding {@link Descriptor}s, and can be read/written through
+ * {@link Descriptor}s from different operator subtasks. Note that every shared object has an owner,
+ * and the owner can set new values and snapshot the object.
  *
- * <p>This class records all {@link ItemDescriptor}s used in {@link GBTRunner} and their owners.
+ * <p>This class records all {@link Descriptor}s used in {@link GBTRunner} and their owners.
  */
 @Internal
 public class SharedObjectsConstants {
 
     /** Instances (after binned). */
-    static final ItemDescriptor<BinnedInstance[]> INSTANCES =
-            ItemDescriptor.of(
+    static final Descriptor<BinnedInstance[]> INSTANCES =
+            Descriptor.of(
                     "instances",
                     new GenericArraySerializer<>(
-                            BinnedInstance.class, BinnedInstanceSerializer.INSTANCE),
-                    new BinnedInstance[0]);
+                            BinnedInstance.class, BinnedInstanceSerializer.INSTANCE));
 
     /**
      * (prediction, gradient, and hessian) of instances, sharing same indexing with {@link
      * #INSTANCES}.
      */
-    static final ItemDescriptor<double[]> PREDS_GRADS_HESSIANS =
-            ItemDescriptor.of(
+    static final Descriptor<double[]> PREDS_GRADS_HESSIANS =
+            Descriptor.of(
                     "preds_grads_hessians",
                     new OptimizedDoublePrimitiveArraySerializer(),
                     new double[0]);
 
     /** Shuffle indices of instances used after every new tree just initialized. */
-    static final ItemDescriptor<int[]> SHUFFLED_INDICES =
-            ItemDescriptor.of("shuffled_indices", IntPrimitiveArraySerializer.INSTANCE, new int[0]);
+    static final Descriptor<int[]> SHUFFLED_INDICES =
+            Descriptor.of("shuffled_indices", IntPrimitiveArraySerializer.INSTANCE);
 
     /** Swapped indices of instances used when {@link #SHUFFLED_INDICES} not applicable. */
-    static final ItemDescriptor<int[]> SWAPPED_INDICES =
-            ItemDescriptor.of("swapped_indices", IntPrimitiveArraySerializer.INSTANCE, new int[0]);
+    static final Descriptor<int[]> SWAPPED_INDICES =
+            Descriptor.of("swapped_indices", IntPrimitiveArraySerializer.INSTANCE);
 
     /** (nodeId, featureId) pairs used to calculate histograms. */
-    static final ItemDescriptor<int[]> NODE_FEATURE_PAIRS =
-            ItemDescriptor.of(
-                    "node_feature_pairs", IntPrimitiveArraySerializer.INSTANCE, new int[0]);
+    static final Descriptor<int[]> NODE_FEATURE_PAIRS =
+            Descriptor.of("node_feature_pairs", IntPrimitiveArraySerializer.INSTANCE);
 
     /** Leaves nodes of current working tree. */
-    static final ItemDescriptor<List<LearningNode>> LEAVES =
-            ItemDescriptor.of(
+    static final Descriptor<List<LearningNode>> LEAVES =
+            Descriptor.of(
                     "leaves", new ListSerializer<>(LearningNodeSerializer.INSTANCE), new ArrayList<>());
 
     /** Nodes in current layer of current working tree. */
-    static final ItemDescriptor<List<LearningNode>> LAYER =
-            ItemDescriptor.of(
+    static final Descriptor<List<LearningNode>> LAYER =
+            Descriptor.of(
                     "layer", new ListSerializer<>(LearningNodeSerializer.INSTANCE), new ArrayList<>());
 
     /** The root node when initializing a new tree. */
-    static final ItemDescriptor<LearningNode> ROOT_LEARNING_NODE =
-            ItemDescriptor.of(
-                    "root_learning_node", LearningNodeSerializer.INSTANCE, new LearningNode());
+    static final Descriptor<LearningNode> ROOT_LEARNING_NODE =
+            Descriptor.of("root_learning_node", LearningNodeSerializer.INSTANCE);
 
     /** All finished trees. */
-    static final ItemDescriptor<List<List<Node>>> ALL_TREES =
-            ItemDescriptor.of(
+    static final Descriptor<List<List<Node>>> ALL_TREES =
+            Descriptor.of(
                     "all_trees",
                     new ListSerializer<>(new ListSerializer<>(NodeSerializer.INSTANCE)),
                     new ArrayList<>());
 
     /** Nodes in current working tree. */
-    static final ItemDescriptor<List<Node>> CURRENT_TREE_NODES =
-            ItemDescriptor.of(
-                    "current_tree_nodes",
-                    new ListSerializer<>(NodeSerializer.INSTANCE),
-                    new ArrayList<>());
+    static final Descriptor<List<Node>> CURRENT_TREE_NODES =
+            Descriptor.of("current_tree_nodes", new ListSerializer<>(NodeSerializer.INSTANCE));
 
     /** Indicates the necessity of initializing a new tree. */
-    static final ItemDescriptor<Boolean> NEED_INIT_TREE =
-            ItemDescriptor.of("need_init_tree", BooleanSerializer.INSTANCE, true);
+    static final Descriptor<Boolean> NEED_INIT_TREE =
+            Descriptor.of("need_init_tree", BooleanSerializer.INSTANCE, true);
 
     /** Data items owned by the `PostSplits` operator. */
-    public static final List<ItemDescriptor<?>> OWNED_BY_POST_SPLITS_OP =
+    public static final List<Descriptor<?>> OWNED_BY_POST_SPLITS_OP =
             Arrays.asList(
                     PREDS_GRADS_HESSIANS,
                     SWAPPED_INDICES,
@@ -137,18 +131,18 @@ public class SharedObjectsConstants {
                     NEED_INIT_TREE);
 
     /** Indicates a new tree has been initialized. */
-    static final ItemDescriptor<Boolean> HAS_INITED_TREE =
-            ItemDescriptor.of("has_inited_tree", BooleanSerializer.INSTANCE, false);
+    static final Descriptor<Boolean> HAS_INITED_TREE =
+            Descriptor.of("has_inited_tree", BooleanSerializer.INSTANCE, false);
 
     /** Training context. */
-    static final ItemDescriptor<TrainContext> TRAIN_CONTEXT =
-            ItemDescriptor.of(
+    static final Descriptor<TrainContext> TRAIN_CONTEXT =
+            Descriptor.of(
                     "train_context",
                     new KryoSerializer<>(TrainContext.class, new ExecutionConfig()),
                     new TrainContext());
 
     /** Data items owned by the `CacheDataCalcLocalHists` operator. */
-    public static final List<ItemDescriptor<?>> OWNED_BY_CACHE_DATA_CALC_LOCAL_HISTS_OP =
+    public static final List<Descriptor<?>> OWNED_BY_CACHE_DATA_CALC_LOCAL_HISTS_OP =
             Arrays.asList(
                     INSTANCES,
                     SHUFFLED_INDICES,
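
The diff above shows Descriptor.of both with and without an initial-value argument. A sketch of declaring a new shared object under each overload follows; it compiles only against this patch's sharedobjects package, EXAMPLE_COUNTS and EXAMPLE_OFFSETS are hypothetical, and the wait-for-first-write behavior is an assumption not confirmed by the patch.

    import org.apache.flink.api.common.typeutils.base.array.IntPrimitiveArraySerializer;
    import org.apache.flink.ml.common.sharedobjects.Descriptor;

    class ExampleDescriptors {
        // With an initial value (third argument), as PREDS_GRADS_HESSIANS keeps:
        // reads can presumably succeed before the owner's first write.
        static final Descriptor<int[]> EXAMPLE_COUNTS =
                Descriptor.of("example_counts", IntPrimitiveArraySerializer.INSTANCE, new int[0]);

        // Without an initial value, as SHUFFLED_INDICES now does: readers are
        // presumably expected to wait for the owner's first write.
        static final Descriptor<int[]> EXAMPLE_OFFSETS =
                Descriptor.of("example_offsets", IntPrimitiveArraySerializer.INSTANCE);
    }

An operator would then read such an object via context.read(EXAMPLE_COUNTS.prevStep()) and its owner would update it via context.write(EXAMPLE_COUNTS, newValue), mirroring how PREDS_GRADS_HESSIANS is handled in PostSplitsOperator above.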
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TerminationOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TerminationOperator.java
index b5f77585d..3d0bc921a 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TerminationOperator.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/operators/TerminationOperator.java
@@ -20,30 +20,28 @@
 
 import org.apache.flink.iteration.IterationListener;
 import org.apache.flink.ml.common.gbt.GBTModelData;
-import org.apache.flink.ml.common.sharedobjects.SharedObjectsContext;
-import org.apache.flink.ml.common.sharedobjects.SharedObjectsStreamOperator;
+import org.apache.flink.ml.common.sharedobjects.AbstractSharedObjectsOneInputStreamOperator;
+import org.apache.flink.ml.common.sharedobjects.ReadRequest;
 import org.apache.flink.runtime.state.StateInitializationContext;
-import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
-import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.util.Collector;
 import org.apache.flink.util.OutputTag;
 
-import java.util.UUID;
+import java.util.Collections;
+import java.util.List;
+
+import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.ALL_TREES;
+import static org.apache.flink.ml.common.gbt.operators.SharedObjectsConstants.TRAIN_CONTEXT;
 
 /** Determines whether to terminate training. */
-public class TerminationOperator extends AbstractStreamOperator<Integer>
-        implements OneInputStreamOperator<Integer, Integer>,
-                IterationListener<Integer>,
-                SharedObjectsStreamOperator {
+public class TerminationOperator
+        extends AbstractSharedObjectsOneInputStreamOperator<Integer, Integer>
+        implements IterationListener<Integer> {
 
     private final OutputTag<GBTModelData> modelDataOutputTag;
-    private final String sharedObjectsAccessorID;
-    private transient SharedObjectsContext sharedObjectsContext;
 
     public TerminationOperator(OutputTag<GBTModelData> modelDataOutputTag) {
         this.modelDataOutputTag = modelDataOutputTag;
-        sharedObjectsAccessorID = getClass().getSimpleName() + "-" + UUID.randomUUID();
     }
 
     @Override
@@ -55,44 +53,30 @@ public void initializeState(StateInitializationContext context) throws Exception
     public void processElement(StreamRecord<Integer> element) throws Exception {}
 
     @Override
-    public void onEpochWatermarkIncremented(
-            int epochWatermark, Context context, Collector<Integer> collector)
-            throws Exception {
-        sharedObjectsContext.invoke(
-                (getter, setter) -> {
-                    boolean terminated =
-                            getter.get(SharedObjectsConstants.ALL_TREES).size()
-                                    == getter.get(SharedObjectsConstants.TRAIN_CONTEXT)
-                                            .strategy
-                                            .maxIter;
-                    // TODO: Add validation error rate
-                    if (!terminated) {
-                        output.collect(new StreamRecord<>(0));
-                    }
-                });
+    public List<ReadRequest<?>> readRequestsInProcessElement() {
+        return Collections.emptyList();
     }
 
     @Override
-    public void onIterationTerminated(Context context, Collector<Integer> collector)
-            throws Exception {
-        if (0 == getRuntimeContext().getIndexOfThisSubtask()) {
-            sharedObjectsContext.invoke(
-                    (getter, setter) ->
-                            context.output(
-                                    modelDataOutputTag,
-                                    GBTModelData.from(
-                                            getter.get(SharedObjectsConstants.TRAIN_CONTEXT),
-                                            getter.get(SharedObjectsConstants.ALL_TREES))));
+    public void onEpochWatermarkIncremented(
+            int epochWatermark, Context c, Collector<Integer> collector) {
+        boolean terminated =
+                context.read(ALL_TREES.sameStep()).size()
+                        == context.read(TRAIN_CONTEXT.sameStep()).strategy.maxIter;
+        // TODO: Add validation error rate
+        if (!terminated) {
+            output.collect(new StreamRecord<>(0));
         }
     }
 
     @Override
-    public void onSharedObjectsContextSet(SharedObjectsContext context) {
-        sharedObjectsContext = context;
-    }
-
-    @Override
-    public String getSharedObjectsAccessorID() {
-        return sharedObjectsAccessorID;
+    public void onIterationTerminated(Context c, Collector<Integer> collector) {
+        if (0 == getRuntimeContext().getIndexOfThisSubtask()) {
+            c.output(
+                    modelDataOutputTag,
+                    GBTModelData.from(
+                            context.read(TRAIN_CONTEXT.prevStep()),
+                            context.read(ALL_TREES.prevStep())));
+        }
    }
 }
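
All three refactored operators now pre-declare their processElement-time reads via readRequestsInProcessElement(): PostSplitsOperator declares LAYER.sameStep(), ReduceSplitsOperator declares NODE_FEATURE_PAIRS.nextStep(), and TerminationOperator declares none. A toy model of why such declarations help follows; the hold-back behavior is an assumption about the intended contract, not something this patch shows.

    import java.util.Arrays;
    import java.util.HashSet;
    import java.util.List;
    import java.util.Set;

    public class ReadGateSketch {
        // Shared-object values written so far, keyed as "name@step".
        private final Set<String> available = new HashSet<>();

        public void markWritten(String nameAtStep) {
            available.add(nameAtStep);
        }

        // An input record is released to processElement only once every declared
        // (object, step) pair it reads has been written, so reads never stall
        // mid-record.
        public boolean canProcess(List<String> declaredReads) {
            return available.containsAll(declaredReads);
        }

        public static void main(String[] args) {
            ReadGateSketch gate = new ReadGateSketch();
            gate.markWritten("layer@3");
            System.out.println(gate.canProcess(Arrays.asList("layer@3"))); // true
            System.out.println(gate.canProcess(Arrays.asList("node_feature_pairs@4"))); // false
        }
    }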
From 4db52f36f26fef60d5554aeb3580699500c319c1 Mon Sep 17 00:00:00 2001
From: Fan Hong
Date: Wed, 9 Aug 2023 10:54:33 +0800
Subject: [PATCH 46/47] Remove unused files.

---
 .../gbt/typeinfo/IntIntPairSerializer.java    | 102 ------------------
 1 file changed, 102 deletions(-)
 delete mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/IntIntPairSerializer.java

diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/IntIntPairSerializer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/IntIntPairSerializer.java
deleted file mode 100644
index 27975e07c..000000000
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/gbt/typeinfo/IntIntPairSerializer.java
+++ /dev/null
@@ -1,102 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.ml.common.gbt.typeinfo;
-
-import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot;
-import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
-import org.apache.flink.api.common.typeutils.base.IntSerializer;
-import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton;
-import org.apache.flink.core.memory.DataInputView;
-import org.apache.flink.core.memory.DataOutputView;
-
-import org.eclipse.collections.api.tuple.primitive.IntIntPair;
-import org.eclipse.collections.impl.tuple.primitive.PrimitiveTuples;
-
-import java.io.IOException;
-
-/** Serializer for {@link IntIntPair}. */
-public class IntIntPairSerializer extends TypeSerializerSingleton<IntIntPair> {
-
-    public static final IntIntPairSerializer INSTANCE = new IntIntPairSerializer();
-
-    @Override
-    public boolean isImmutableType() {
-        return false;
-    }
-
-    @Override
-    public IntIntPair createInstance() {
-        return PrimitiveTuples.pair(0, 0);
-    }
-
-    @Override
-    public IntIntPair copy(IntIntPair from) {
-        return PrimitiveTuples.pair(from.getOne(), from.getTwo());
-    }
-
-    @Override
-    public IntIntPair copy(IntIntPair from, IntIntPair reuse) {
-        return copy(from);
-    }
-
-    @Override
-    public int getLength() {
-        return 2 * IntSerializer.INSTANCE.getLength();
-    }
-
-    @Override
-    public void serialize(IntIntPair record, DataOutputView target) throws IOException {
-        IntSerializer.INSTANCE.serialize(record.getOne(), target);
-        IntSerializer.INSTANCE.serialize(record.getTwo(), target);
-    }
-
-    @Override
-    public IntIntPair deserialize(DataInputView source) throws IOException {
-        return PrimitiveTuples.pair(
-                (int) IntSerializer.INSTANCE.deserialize(source),
-                (int) IntSerializer.INSTANCE.deserialize(source));
-    }
-
-    @Override
-    public IntIntPair deserialize(IntIntPair reuse, DataInputView source) throws IOException {
-        return deserialize(source);
-    }
-
-    @Override
-    public void copy(DataInputView source, DataOutputView target) throws IOException {
-        serialize(deserialize(source), target);
-    }
-
-    // ------------------------------------------------------------------------
-
-    @Override
-    public TypeSerializerSnapshot<IntIntPair> snapshotConfiguration() {
-        return new IntIntPairSerializer.IntIntPairSerializerSnapshot();
-    }
-
-    /** Serializer configuration snapshot for compatibility and format evolution. */
-    @SuppressWarnings("WeakerAccess")
-    public static final class IntIntPairSerializerSnapshot
-            extends SimpleTypeSerializerSnapshot<IntIntPair> {
-
-        public IntIntPairSerializerSnapshot() {
-            super(IntIntPairSerializer::new);
-        }
-    }
-}

From 0109bb87bc665b66cfcee2ba53bc67f9f2ecea45 Mon Sep 17 00:00:00 2001
From: Fan Hong
Date: Tue, 15 Aug 2023 17:44:30 +0800
Subject: [PATCH 47/47] Fix imports.

---
 .../flink/ml/common/sharedobjects/SharedObjectsUtilsTest.java | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsUtilsTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsUtilsTest.java
index cc0a5ff27..910096abb 100644
--- a/flink-ml-core/src/test/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsUtilsTest.java
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/common/sharedobjects/SharedObjectsUtilsTest.java
@@ -35,11 +35,11 @@
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 
 import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.RandomStringUtils;
 import org.junit.Assert;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.TemporaryFolder;
-import org.testcontainers.shaded.org.apache.commons.lang3.RandomStringUtils;
 
 import java.io.File;
 import java.util.ArrayList;