Skip to content

Commit

Permalink
Fixed the indexing of 'pandas_categorical' attribute elements. Fixes #63
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Jul 14, 2024
1 parent b21e367 commit 8ef3280
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions pmml-lightgbm/src/main/java/org/jpmml/lightgbm/GBDT.java
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,21 @@ public Schema encodeSchema(String targetName, List<String> targetCategories, Lig
if(LightGBMUtil.isNone(featureInfo)){
features.add(null);

pandasCategorical:
if(hasPandasCategories){

if(pandasCategoryIndex >= this.pandas_categorical.size()){
break pandasCategorical;
}

List<?> pandasCategoryValues = this.pandas_categorical.get(pandasCategoryIndex);

// A constant categorical column
if(pandasCategoryValues.size() == 1){
pandasCategoryIndex++;
}
}

continue;
}

Expand Down Expand Up @@ -237,13 +252,17 @@ public Schema encodeSchema(String targetName, List<String> targetCategories, Lig

boolean direct = true;

pandasCategorical:
if(hasPandasCategories){

if(pandasCategoryIndex >= this.pandas_categorical.size()){
throw new IllegalArgumentException("Conflicting categorical feature information between the header and \"pandas_categorical\" sections");
}

List<?> pandasCategoryValues = this.pandas_categorical.get(pandasCategoryIndex);
if(pandasCategoryValues.size() < values.size()){
throw new IllegalArgumentException("Expected at least " + values.size() + " category levels, got " + pandasCategoryValues.size() + " category levels");
}

values = pandasCategoryValues;

Expand Down

0 comments on commit 8ef3280

Please sign in to comment.