From 8ef3280f79a88b91a147d4f82688abdde169f3f5 Mon Sep 17 00:00:00 2001 From: Villu Ruusmann Date: Sun, 14 Jul 2024 22:49:04 +0300 Subject: [PATCH] Fixed the indexing of 'pandas_categorical' attribute elements. Fixes #63 --- .../main/java/org/jpmml/lightgbm/GBDT.java | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/pmml-lightgbm/src/main/java/org/jpmml/lightgbm/GBDT.java b/pmml-lightgbm/src/main/java/org/jpmml/lightgbm/GBDT.java index 9fc05d6..3ea34ff 100644 --- a/pmml-lightgbm/src/main/java/org/jpmml/lightgbm/GBDT.java +++ b/pmml-lightgbm/src/main/java/org/jpmml/lightgbm/GBDT.java @@ -205,6 +205,21 @@ public Schema encodeSchema(String targetName, List targetCategories, Lig if(LightGBMUtil.isNone(featureInfo)){ features.add(null); + pandasCategorical: + if(hasPandasCategories){ + + if(pandasCategoryIndex >= this.pandas_categorical.size()){ + break pandasCategorical; + } + + List pandasCategoryValues = this.pandas_categorical.get(pandasCategoryIndex); + + // A constant categorical column + if(pandasCategoryValues.size() == 1){ + pandasCategoryIndex++; + } + } + continue; } @@ -237,6 +252,7 @@ public Schema encodeSchema(String targetName, List targetCategories, Lig boolean direct = true; + pandasCategorical: if(hasPandasCategories){ if(pandasCategoryIndex >= this.pandas_categorical.size()){ @@ -244,6 +260,9 @@ public Schema encodeSchema(String targetName, List targetCategories, Lig } List pandasCategoryValues = this.pandas_categorical.get(pandasCategoryIndex); + if(pandasCategoryValues.size() < values.size()){ + throw new IllegalArgumentException("Expected at least " + values.size() + " category levels, got " + pandasCategoryValues.size() + " category levels"); + } values = pandasCategoryValues;