Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add constraint support #119

Merged
merged 17 commits into from
Aug 29, 2024
Merged
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Legolas"
uuid = "741b9549-f6ed-4911-9fbf-4a1c0c97f0cd"
authors = ["Beacon Biosignals, Inc."]
version = "0.5.19"
version = "0.5.20"

[deps]
Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45"
Expand Down
21 changes: 14 additions & 7 deletions ext/LegolasConstructionBaseExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,21 @@
if VERSION < v"1.7"
ConstructionBase.getproperties(r::AbstractRecord) = NamedTuple(r)

# This is largely copy-paste from `ConstructionBase.setproperties_object`:
# https://github.com/JuliaObjects/ConstructionBase.jl/blob/cd24e541fd90ab54d2ee12ddd6ccd229be9a5f1e/src/ConstructionBase.jl#L211-L218
function ConstructionBase.setproperties(r::R, patch::NamedTuple) where {R<:AbstractRecord}
nt = getproperties(r)
nt_new = merge(nt, patch)
ConstructionBase.check_patch_properties_exist(nt_new, nt, r, patch)
args = Tuple(nt_new) # old Julia inference prefers if we wrap in `Tuple`
return constructorof(R)(args...)
if isdefined(ConstructionBase, :check_patch_properties_exist)
# This is largely copy-paste from `ConstructionBase.setproperties_object`:
# https://github.com/JuliaObjects/ConstructionBase.jl/blob/cd24e541fd90ab54d2ee12ddd6ccd229be9a5f1e/src/ConstructionBase.jl#L211-L218
nt = getproperties(r)
nt_new = merge(nt, patch)
ConstructionBase.check_patch_properties_exist(nt_new, nt, r, patch)
args = Tuple(nt_new) # old Julia inference prefers if we wrap in `Tuple`
return constructorof(R)(args...)

Check warning on line 23 in ext/LegolasConstructionBaseExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LegolasConstructionBaseExt.jl#L19-L23

Added lines #L19 - L23 were not covered by tests
else
# As of ConstructionBase 1.5.7 the internals of `ConstructionBase.setproperties_object` have changed:
# https://github.com/JuliaObjects/ConstructionBase.jl/blob/71fb5a5198f41f3ef29a53c01940cf7cf6b233eb/src/ConstructionBase.jl#L205-L209
ConstructionBase.check_patch_fields_exist(r, patch)
return ConstructionBase.setfields_object(r, patch)
end
end
end

Expand Down
1 change: 1 addition & 0 deletions src/Legolas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using Tables, Arrow, UUIDs
const LEGOLAS_SCHEMA_QUALIFIED_METADATA_KEY = "legolas_schema_qualified"

include("lift.jl")
include("constraints.jl")
include("schemas.jl")
include("tables.jl")
include("record_merge.jl")
Expand Down
13 changes: 13 additions & 0 deletions src/constraints.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
struct CheckConstraintError <: Exception
predicate::Expr
end

function Base.showerror(io::IO, ex::CheckConstraintError)
print(io, "$CheckConstraintError: $(ex.predicate)")
return nothing

Check warning on line 7 in src/constraints.jl

View check run for this annotation

Codecov / codecov/patch

src/constraints.jl#L5-L7

Added lines #L5 - L7 were not covered by tests
end

macro check(expr)
quoted_expr = QuoteNode(expr)
return :($(esc(expr)) || throw(CheckConstraintError($quoted_expr)))
end
27 changes: 24 additions & 3 deletions src/schemas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@
# includes the parent's declared field RHS statements. We cannot interpolate/incorporate these statements
# in the child's record type definition because they may reference bindings from the parent's `@version`
# callsite that are not available/valid at the child's `@version` callsite.
function _generate_record_type_definitions(schema_version::SchemaVersion, record_type_symbol::Symbol)
function _generate_record_type_definitions(schema_version::SchemaVersion, record_type_symbol::Symbol, constraints::AbstractVector)
# generate `schema_version_type_alias_definition`
T = Symbol(string(record_type_symbol, "SchemaVersion"))
schema_version_type_alias_definition = :(const $T = $(Base.Meta.quot(typeof(schema_version))))
Expand Down Expand Up @@ -616,6 +616,7 @@
function $R(; $(field_kwargs...))
$parent_record_application
$(field_assignments...)
$(constraints...)
return new($(keys(record_fields)...))
end
end
Expand All @@ -625,11 +626,13 @@
function $R{$(type_param_names...)}(; $(field_kwargs...)) where {$(type_param_names...)}
$parent_record_application
$(parametric_field_assignments...)
$(constraints...)

Check warning on line 629 in src/schemas.jl

View check run for this annotation

Codecov / codecov/patch

src/schemas.jl#L629

Added line #L629 was not covered by tests
return new{$(type_param_names...)}($(keys(record_fields)...))
end
function $R(; $(field_kwargs...))
$parent_record_application
$(field_assignments...)
$(constraints...)

Check warning on line 635 in src/schemas.jl

View check run for this annotation

Codecov / codecov/patch

src/schemas.jl#L635

Added line #L635 was not covered by tests
return new{$((:(typeof($n)) for n in names_of_parameterized_fields)...)}($(keys(record_fields)...))
end
end
Expand Down Expand Up @@ -778,8 +781,25 @@

# parse `declared_fields_block`
declared_field_statements = Any[]
declared_constraint_statements = Any[]
if declared_fields_block isa Expr && declared_fields_block.head == :block && !isempty(declared_fields_block.args)
declared_field_statements = [f for f in declared_fields_block.args if !(f isa LineNumberNode)]
for f in declared_fields_block.args
if f isa LineNumberNode
continue
elseif f isa Expr && f.head === :macrocall && f.args[1] === Symbol("@check")
constraint_expr = Base.macroexpand(Legolas, f)
# Update the expression such that a failure shows the location of the user
# defined `@check` call. Ideally `Meta.replace_sourceloc!` would do this.
if f.args[2] isa LineNumberNode
constraint_expr = Expr(:block, f.args[2], constraint_expr)
end
omus marked this conversation as resolved.
Show resolved Hide resolved
push!(declared_constraint_statements, constraint_expr)
elseif isempty(declared_constraint_statements)
push!(declared_field_statements, f)
else
return :(throw(SchemaVersionDeclarationError("all `@version` field expressions must be defined before constraints:\n", $(Base.Meta.quot(declared_fields_block)))))
end
end
end
declared_field_infos = DeclaredFieldInfo[]
for stmt in declared_field_statements
Expand All @@ -800,6 +820,7 @@
return :(throw(SchemaVersionDeclarationError($msg)))
end
declared_field_names_types = Expr(:tuple, Expr(:parameters, (Expr(:kw, f.name, esc(f.type)) for f in declared_field_infos)...))
constraints = [Base.Meta.quot(ex) for ex in declared_constraint_statements]

return quote
if !isdefined((@__MODULE__), :__legolas_schema_name_from_prefix__)
Expand Down Expand Up @@ -827,7 +848,7 @@
Base.@__doc__($(Base.Meta.quot(record_type)))
$(esc(:eval))($Legolas._generate_schema_version_definitions(schema_version, parent, $declared_field_names_types, schema_version_declaration))
$(esc(:eval))($Legolas._generate_validation_definitions(schema_version))
$(esc(:eval))($Legolas._generate_record_type_definitions(schema_version, $(Base.Meta.quot(record_type))))
$(esc(:eval))($Legolas._generate_record_type_definitions(schema_version, $(Base.Meta.quot(record_type)), [$(constraints...)]))
end
end
nothing
Expand Down
72 changes: 71 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Compat: current_exceptions
using Legolas, Test, DataFrames, Arrow, UUIDs
using Legolas: SchemaVersion, @schema, @version, SchemaVersionDeclarationError, DeclaredFieldInfo
using Legolas: @schema, @version, CheckConstraintError, SchemaVersion,
SchemaVersionDeclarationError, DeclaredFieldInfo
using Accessors
using Aqua

Expand Down Expand Up @@ -817,3 +818,72 @@ end
@test r.i isa UInt16
@test r.i == 1
end

#####
##### constraints
#####

@schema "test.constraint" Constraint

const CONSTRAINT_V1_EQUAL_CONSTRAINT_LINE = @__LINE__() + 4
@version ConstraintV1 begin
a
b = clamp(b, 0, 5)
@check a == b
@check a > 0
end

@testset "constraints" begin
r = ConstraintV1(; a=1, b=1)
@test r isa ConstraintV1
@test r.a === 1
@test r.b === 1

r = ConstraintV1(; a=1, b=1.0)
@test r isa ConstraintV1
@test r.a === 1
@test r.b === 1.0

# In Julia 1.8+ we can test can test against "CheckConstraintError: a == b"
try
ConstraintV1(; a=1, b=2)
@test false
catch e
@test e isa CheckConstraintError
@test e.predicate == :(a == b)
end

try
ConstraintV1(; a=0, b=0)
@test false
catch e
@test e isa CheckConstraintError
@test e.predicate == :(a > 0)
end

try
ConstraintV1(; a=6, b=6)
@test false
catch e
@test e isa CheckConstraintError
@test e.predicate == :(a == b)
end

# For exceptions that occur during processing constraints its convenient to include the
# location of the `@check` in the stacktrace.
try
ConstraintV1(; a=1, b=missing) # Fails when running check `a == b`
@test false
omus marked this conversation as resolved.
Show resolved Hide resolved
catch e
@test e isa TypeError

bt = Base.process_backtrace(catch_backtrace())
sf = bt[1][1]::Base.StackFrame
@test string(sf.file) == @__FILE__
@test sf.line == CONSTRAINT_V1_EQUAL_CONSTRAINT_LINE
end
end

@testset "constraints must be after all fields" begin
@test_throws SchemaVersionDeclarationError @version(ConstraintV2, begin a; @check a == 1; b end)
end
Loading