Skip to content

Commit

Permalink
Merge pull request #787 from JuliaAI/dev
Browse files Browse the repository at this point in the history
For a 0.20.5 release
  • Loading branch information
ablaom authored Jun 12, 2022
2 parents 1a89215 + adb341f commit df704df
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 13 deletions.
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJBase"
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.20.4"
version = "0.20.5"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Expand Down Expand Up @@ -49,6 +49,7 @@ Tables = "0.2, 1.0"
julia = "1.6"

[extras]
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
Expand All @@ -59,4 +60,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"

[targets]
test = ["DecisionTree", "Distances", "Logging", "MultivariateStats", "NearestNeighbors", "StableRNGs", "Test", "TypedTables"]
test = ["DataFrames", "DecisionTree", "Distances", "Logging", "MultivariateStats", "NearestNeighbors", "StableRNGs", "Test", "TypedTables"]
18 changes: 11 additions & 7 deletions src/interface/data_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,18 @@ function MMI.selectrows(::FI, ::Val{:table}, X, r)
end

# Select the single column of table `X` identified by `c` (a `Symbol` or an `Integer`).
# NOTE(review): the statements after the first `return` can never run as written; this
# looks like the pre- and post-change bodies of a diff shown together — confirm which
# pair of lines is meant to remain.
function MMI.selectcols(::FI, ::Val{:table}, X, c::Union{Symbol, Integer})
cols = Tables.columntable(X) # named tuple of vectors
return cols[c]
cols = Tables.columns(X)
return Tables.getcolumn(cols, c)
end

function MMI.selectcols(::FI, ::Val{:table}, X, c::AbstractArray)
cols = Tables.columntable(X) # named tuple of vectors
newcols = project(cols, c)
return Tables.materializer(X)(newcols)
# Select several columns of table `X`, where `c` is either `:` or an array of
# column identifiers.
#
# When `isdataframe(X)` holds, `X` is indexed directly with `X[!, c]`. Otherwise
# `X` is converted to a column table (named tuple of vectors), projected onto
# `c` with `project`, and rebuilt using the materializer that `Tables` reports
# for `X`.
function MMI.selectcols(::FI, ::Val{:table}, X, c::Union{Colon, AbstractArray})
    isdataframe(X) && return X[!, c]
    selected = project(Tables.columntable(X), c)
    return Tables.materializer(X)(selected)
end

# -------------------------------
Expand All @@ -124,7 +128,7 @@ function project(t::NamedTuple, indices::AbstractArray{<:Integer})
end

# utils for selectrows
# `typename(X)` takes the supertype of `typeof(X)`, converts it (or, in the first
# definition, its `.name` field) to a string, and returns the last '.'-separated piece.
# NOTE(review): two definitions with the same signature follow; presumably these are the
# pre- and post-change lines of a diff shown together — confirm only one is meant to stay.
typename(X) = split(string(supertype(typeof(X)).name), '.')[end]
typename(X) = split(string(supertype(typeof(X))), '.')[end]
# `isdataframe(X)` is true exactly when `typename(X)` equals "AbstractDataFrame".
isdataframe(X) = typename(X) == "AbstractDataFrame"

# ----------------------------------------------------------------
Expand Down
4 changes: 2 additions & 2 deletions src/machines.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
## SCITYPE CHECK LEVEL

"""
default_scitype_check_level()
default_scitype_check_level()
Return the current global default value for scientific type checking
when constructing machines.
default_scitype_check_level(i::Integer)
default_scitype_check_level(i::Integer)
Set the global default value for scientific type checking to `i`.
Expand Down
15 changes: 13 additions & 2 deletions test/interface/data_utils.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import DataFrames

rng = StableRNGs.StableRNG(123)

@testset "categorical" begin
Expand All @@ -23,7 +25,7 @@ end
b = categorical(["a", "b", "c"])
c = categorical(["a", "b", "c"]; ordered=true)
X = (x1=x, x2=z, x3=b, x4=c)
@test MLJModelInterface.scitype(x) == ST.scitype(x)
@test MLJModelInterface.scitype(x) == ST.scitype(x)
@test MLJModelInterface.scitype(y) == ST.scitype(y)
@test MLJModelInterface.scitype(z) == ST.scitype(z)
@test MLJModelInterface.scitype(a) == ST.scitype(a)
Expand All @@ -39,7 +41,7 @@ end
b = categorical(["a", "b", "c"])
c = categorical(["a", "b", "c"]; ordered=true)
X = (x1=x, x2=z, x3=b, x4=c)
@test_throws ArgumentError MLJModelInterface.schema(x)
@test_throws ArgumentError MLJModelInterface.schema(x)
@test MLJModelInterface.schema(X) == ST.schema(X)
end

Expand Down Expand Up @@ -197,4 +199,13 @@ end
@test selectcols(tt, :w) == v
end

# https://github.com/JuliaAI/MLJBase.jl/issues/784
@testset "typename and dataframes" begin
    # Small three-column frame used by every check below.
    frame = DataFrames.DataFrame(x=[1,2,3], y=[2,3,4], z=[4,5,6])
    rows = 2:3
    names_xz = [:x, :z]

    # The frame must be recognised through its supertype's name.
    @test MLJBase.typename(frame) == "AbstractDataFrame"
    @test MLJBase.isdataframe(frame)

    # Row and column selection must agree with direct indexing of the frame.
    @test selectrows(frame, rows) == frame[rows, :]
    @test selectcols(frame, names_xz) == frame[!, names_xz]
end

true

0 comments on commit df704df

Please sign in to comment.