From 46577f714b911f32b6688df0849be84d1b6eaedb Mon Sep 17 00:00:00 2001 From: Essam Date: Sun, 1 Sep 2024 10:57:05 -0500 Subject: [PATCH] =?UTF-8?q?=E2=9C=85=20Enforce=20columns?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/encoders.jl | 6 +++--- src/entity_embedding_utils.jl | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/encoders.jl b/src/encoders.jl index 6cb3aae..be835b6 100644 --- a/src/encoders.jl +++ b/src/encoders.jl @@ -16,7 +16,7 @@ function ordinal_encoder_fit(X; featinds) # 2. Use feature mapper to compute the mapping of each level in each column for i in featinds - feat_col = Tables.getcolumn(X, i) + feat_col = Tables.getcolumn(Tables.columns(X), i) feat_levels = levels(feat_col) # Check if feat levels is already ordinal encoded in which case we skip (Set([float(i) for i in 1:length(feat_levels)]) == Set(feat_levels)) && continue @@ -59,7 +59,7 @@ function ordinal_encoder_transform(X, mapping_matrix) numfeats = length(feat_names) new_feats = [] for ind in 1:numfeats - col = Tables.getcolumn(X, ind) + col = Tables.getcolumn(Tables.columns(X), ind) # Create the transformation function for each column if ind in keys(mapping_matrix) @@ -125,7 +125,7 @@ function embedding_transform(X, mapping_matrices) new_feat_names = Symbol[] new_cols = [] for feat_name in feat_names - col = Tables.getcolumn(X, feat_name) + col = Tables.getcolumn(Tables.columns(X), feat_name) # Create the transformation function for each column if feat_name in keys(mapping_matrices) level2vector = mapping_matrices[feat_name] diff --git a/src/entity_embedding_utils.jl b/src/entity_embedding_utils.jl index c75b7a9..7477431 100644 --- a/src/entity_embedding_utils.jl +++ b/src/entity_embedding_utils.jl @@ -77,7 +77,7 @@ Returns the number of levels in each categorical column in the table `X`. function get_num_levels(X, cat_inds) num_levels = [] for i in cat_inds - num_levels = push!(num_levels, length(levels(Tables.getcolumn(X, i)))) + num_levels = push!(num_levels, length(levels(Tables.getcolumn(Tables.columns(X), i)))) end return num_levels end