Skip to content

Commit

Permalink
Change abstract arrays back to arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
alexmul1114 committed Jan 12, 2024
1 parent 2a7f0c2 commit b72f6e9
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
3 changes: 1 addition & 2 deletions ext/CUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ function _gcp(
T = promote_type(TX, Float32)

Check warning on line 19 in ext/CUDAExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/CUDAExt.jl#L19

Added line #L19 was not covered by tests

# Random initialization
M0 = CPD(ones(T, r), rand.(T, size(X), r))
#M0norm = sqrt(mapreduce(abs2, +, M0[I] for I in CartesianIndices(size(M0))))
M0 = CPD(ones(T, r), rand.(T, size(X), r))X_gpu
M0norm = sqrt(sum(abs2, M0[I] for I in CartesianIndices(size(M0))))
Xnorm = sqrt(mapreduce(x -> isnan(x) ? 0 : abs2(x), +, X, init=0f0))
for k in Base.OneTo(N)
Expand Down
8 changes: 4 additions & 4 deletions src/gcp-opt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# Main fitting function
"""
gcp(X::AbstractArray, r, loss = LeastSquaresLoss();
gcp(X::Array, r, loss = LeastSquaresLoss();
constraints = default_constraints(loss),
algorithm = default_algorithm(X, r, loss, constraints)) -> CPD
Expand All @@ -21,7 +21,7 @@ to see what losses are supported.
See also: `CPD`, `AbstractLoss`.
"""
gcp(
X::AbstractArray,
X::Array,
r,
loss = LeastSquaresLoss();
constraints = default_constraints(loss),
Expand All @@ -43,7 +43,7 @@ function default_constraints(loss)
end

# Choose default algorithm
default_algorithm(X::AbstractArray{<:Real}, r, loss::LeastSquaresLoss, constraints::Tuple{}) =
default_algorithm(X::Array{<:Real}, r, loss::LeastSquaresLoss, constraints::Tuple{}) =
GCPAlgorithms.ALS()
default_algorithm(X, r, loss, constraints) = GCPAlgorithms.LBFGSB()

Expand Down Expand Up @@ -213,7 +213,7 @@ function mttkrp(X, U, n)
return Rn
end

function khatrirao(A::Vararg{T,N}) where {T<:AbstractMatrix,N}
function khatrirao(A::Vararg{T,N}) where {T<:Matrix,N}
r = size(A[1],2)
R = ntuple(Val(N)) do k
dims = (ntuple(i->1,Val(N-k))..., :, ntuple(i->1,Val(k-1))..., r)
Expand Down

0 comments on commit b72f6e9

Please sign in to comment.