Skip to content

Commit

Permalink
Merge pull request #12 from JuliaAlgebra/sexponents
Browse files Browse the repository at this point in the history
Make SExponents opaque isbits type
  • Loading branch information
saschatimme authored Jun 17, 2018
2 parents 3498aa7 + 683a0ea commit a87630f
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 32 deletions.
8 changes: 4 additions & 4 deletions src/evaluation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ Evaluate the polynomial `f` at `x`.
end
(f::Polynomial)(x::AbstractVector) = evaluate(f, x)

function evaluate_impl(f::Type{Polynomial{T, NVars, E}}) where {T, NVars, E<:SExponents}
function evaluate_impl(f::Type{Polynomial{T, NVars, E}}) where {T, NVars, E}
quote
@boundscheck length(x) NVars
c = coefficients(f)
@inbounds out = begin
$(generate_evaluate(exponents(E, NVars), T))
$(generate_evaluate(exponents(E), T))
end
out
end
Expand Down Expand Up @@ -76,12 +76,12 @@ end
_val_gradient_impl(f)
end

function _val_gradient_impl(f::Type{Polynomial{T, NVars, E}}) where {T, NVars, E<:SExponents}
function _val_gradient_impl(f::Type{Polynomial{T, NVars, E}}) where {T, NVars, E}
quote
@boundscheck length(x) NVars
c = coefficients(f)
@inbounds val, grad = begin
$(generate_gradient(exponents(E, NVars), T))
$(generate_gradient(exponents(E), T))
end
val, grad
end
Expand Down
20 changes: 11 additions & 9 deletions src/polynomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,18 @@ export Polynomial, coefficients, exponents, nvariables, coefficienttype
Construct a Polynomial from `f`.
"""
struct Polynomial{T, NVars, E<:SExponents}
struct Polynomial{T, NVars, SE}
coefficients::Vector{T}
variables::SVector{NVars, Symbol}

function Polynomial{T, NVars, SExponents{E}}(coefficients::Vector{T}, variables::SVector{NVars, Symbol}) where {T, NVars, E}
@assert length(coefficients) == div(length(E), NVars) "Coefficients size does not match exponents size"
function Polynomial{T, NVars, SE}(coefficients::Vector{T}, variables::SVector{NVars, Symbol}) where {T, NVars, SE}
@assert length(coefficients) == div(length(SE), NVars) "Coefficients size does not match exponents size"
new(coefficients, variables)
end
end

function Polynomial(coefficients::Vector{T}, nvars, exponents::E, variables) where {T, E<:SExponents}
return Polynomial{T, nvars, E}(coefficients, variables)
function Polynomial(coefficients::Vector{T}, nvars, exponents::SExponents, variables) where {T}
return Polynomial{T, nvars, exponents}(coefficients, variables)
end

function Polynomial(coefficients::Vector{T}, exponents::Matrix{<:Integer}, variables=SVector((Symbol("x", i) for i=1:size(exponents, 1))...)) where {T}
Expand Down Expand Up @@ -75,10 +75,12 @@ coefficients(f::Polynomial) = f.coefficients
Return the exponents of `f` as an matrix where each column represents
the exponents of a monomial.
"""
function exponents(::Polynomial{T, NVars, E}) where {T, NVars, E<:SExponents}
exponents(E, NVars)
function exponents(::Polynomial{T, NVars, E}) where {T, NVars, E}
exponents(E)
end

sexponents(::Polynomial{T, NVars, E}) where {T, NVars, E} = E

"""
nvariables(f::Polynomial)
Expand All @@ -94,6 +96,6 @@ Return the type of the coefficients of `f`.
coefficienttype(::Polynomial{T, NVars}) where {T, NVars} = T


function Base.:(==)(f::Polynomial{T, NVars, E}, g::Polynomial{T, NVars, E}) where {T, NVars, E<:SExponents}
coefficients(f) == coefficients(g)
function Base.:(==)(f::Polynomial{T, NVars, E1}, g::Polynomial{T, NVars, E2}) where {T, NVars, E1, E2}
E1 == E2 && coefficients(f) == coefficients(g)
end
35 changes: 19 additions & 16 deletions src/sexponents.jl
Original file line number Diff line number Diff line change
@@ -1,36 +1,39 @@
export SExponents

struct SExponents{E}
function SExponents{E}() where { E}
@assert typeof(E) <: NTuple{N, Int} where N "Exponents type invalid"
new()
end
struct SExponents{N}
exponents::NTuple{N, UInt8}
size::Tuple{Int,Int} # nvars, nterms
end

function SExponents(exponents::Matrix{<:Integer})
# NVars = size(exponents, 1)
E = ntuple(i -> convert(Int, exponents[i]), length(exponents))
E = ntuple(i -> convert(UInt8, exponents[i]), length(exponents))

return SExponents(E, size(exponents))
end

return SExponents{E}()
Base.isbits(::Type{<:SExponents}) = true
Base.length(::SExponents{N}) where N = N
function Base.:(==)(f::SExponents{N}, g::SExponents{N}) where {N}
f.exponents == g.exponents && f.size == g.size
end
Base.hash(f::SExponents, h) = hash(f.exponents, hash(f.size, h))


"""
exponents(::SExponents)
Converts exponents stored in a `SExponents` to a matrix.
"""
function exponents(::Type{SExponents{E}}, nvars) where {E}
nterms = div(length(E), nvars)
function exponents(SE::SExponents)
nvars, nterms = SE.size
exps = fill(0, nvars, nterms)
for k=1:nvars*nterms
exps[k] = E[k]
exps[k] = SE.exponents[k]
end
exps
end
exponents(::S, nvars) where {S<:SExponents} = exponents(S, nvars)

function Base.show(io::IO, ::Type{SExponents{E}}) where {E}
exps_hash = num2hex(hash(E))
print(io, "SExponents{$(exps_hash)}")
function Base.show(io::IO, SE::SExponents{N}) where {N}
exps_hash = num2hex(hash(SE.exponents))
print(io, "SExponents{$N}($(exps_hash))")
end
Base.show(io::IO, S::SExponents) = print(io, typeof(S), "()")
2 changes: 1 addition & 1 deletion src/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ function Base.show(io::IO, p::Polynomial{T, N, E}) where {T,N,E}
first = true
cfs = coefficients(p)

exps = exponents(E, N)
exps = exponents(E)
NVars, NTerms = size(exps)

for j=1:NTerms
Expand Down
2 changes: 1 addition & 1 deletion src/system.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ module Systems
fs = [Symbol("f", i) for i=1:n]
Es = [Symbol("E", i) for i=1:n]
fields = [:($(fs[i])::Polynomial{T, N, $(Symbol("E", i))}) for i=1:n]
types = [:($(Es[i])<:SExponents) for i=1:n]
types = [:($(Es[i])) for i=1:n]
name = Symbol("System", n)
quote
struct $(name){T, N, $(types...)} <: AbstractSystem{T, $n, N}
Expand Down
2 changes: 1 addition & 1 deletion test/basic_tests.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
@testset "constructors" begin
A = round.(Int, max.(0.0, 5 * rand(6, 10) .- 1))
f = Polynomial(rand(10), A)
@test typeof(f) <: Polynomial{Float64, 6, <:SExponents}
@test typeof(f) <: Polynomial{Float64, 6}

@test_throws AssertionError Polynomial(rand(9), A)

Expand Down
10 changes: 10 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,13 @@ include("codegen_tests.jl")
include("basic_tests.jl")
include("evaluation_tests.jl")
include("gradient_tests.jl")

A = [1 3 3; 0 2 3]

isbits(SP.SExponents(A))

@polyvar x y

f = Polynomial(x^2+y^2+2y*x-3x^3*y)
w = rand(2)
SP.evaluate(f,w)

0 comments on commit a87630f

Please sign in to comment.