From bbe121f376ce96100e5ed36ecb785f0d8724669a Mon Sep 17 00:00:00 2001 From: Essam Date: Sat, 31 Aug 2024 19:45:05 -0500 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=A8=20Get=20rid=20of=20ScientificTypes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/MLJFlux.jl | 1 - src/entity_embedding_utils.jl | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/MLJFlux.jl b/src/MLJFlux.jl index a3d96cf..59d9c24 100644 --- a/src/MLJFlux.jl +++ b/src/MLJFlux.jl @@ -4,7 +4,6 @@ export CUDALibs, CPU1 import Flux using MLJModelInterface using MLJModelInterface.ScientificTypesBase -using ScientificTypes: schema, Finite import Base.== using ProgressMeter using CategoricalArrays diff --git a/src/entity_embedding_utils.jl b/src/entity_embedding_utils.jl index fe1a913..a87c184 100644 --- a/src/entity_embedding_utils.jl +++ b/src/entity_embedding_utils.jl @@ -66,7 +66,7 @@ Returns the indices of the categorical columns in the table `X`. function get_cat_inds(X) # if input is a matrix; conclude no categorical columns Tables.istable(X) || return Int[] - types = schema(X).scitypes + types = [scitype(Tables.getcolumn(X, name)[1]) for name in Tables.schema(X).names] cat_inds = findall(x -> x <: Finite, types) return cat_inds end