diff --git a/.gitignore b/.gitignore index 11b69ed..63306d7 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ /docs/build/ Manifest.toml +LocalPreferences.toml diff --git a/Project.toml b/Project.toml index 868a8a4..ec9cb37 100644 --- a/Project.toml +++ b/Project.toml @@ -4,20 +4,26 @@ authors = ["Harish Anand"] version = "0.1.0" [deps] +AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" HuggingFaceApi = "3cc741c3-0c9d-4fbe-84fa-cdec264173de" +ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534" +ImageIO = "82e4d734-157c-48bb-816b-45c225c6df19" +ImageMagick = "6218d12a-5da1-5696-b52f-db25d2ecc6d1" JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" +KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" +NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Pickle = "fbb45041-c46e-462f-888f-7c521cafbc2c" +ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] -julia = "1" - -[extras] -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[targets] -test = ["Test"] +AMDGPU = "0.4.12" +NNlib = "0.8.20" +KernelAbstractions = "0.9.2" +julia = "1.9" diff --git a/painting-of-a-farmer-in-the-field.png b/painting-of-a-farmer-in-the-field.png new file mode 100644 index 0000000..6b39f91 Binary files /dev/null and b/painting-of-a-farmer-in-the-field.png differ diff --git a/src/Diffusers.jl b/src/Diffusers.jl index c4fb5ce..8f72131 100644 --- a/src/Diffusers.jl +++ b/src/Diffusers.jl @@ -1,23 +1,87 @@ module Diffusers -export HGF -import MLUtils import JSON3 +import MLUtils import Pickle using Adapt +using FileIO using Flux using HuggingFaceApi +using ImageCore +using ImageIO using OrderedCollections +using ProgressMeter +using Statistics -const Maybe{T} = Union{Nothing, T} +using AMDGPU +using KernelAbstractions +const Backend = ROCBackend() -const HGF = Val{:HGF}() +function sync_free!(args...) + KernelAbstractions.synchronize(Backend) + KernelAbstractions.unsafe_free!.(args) +end + +const Maybe{T} = Union{Nothing, T} +# TODO better way of handling this const FluxDeviceAdaptors = ( Flux.FluxCPUAdaptor, Flux.FluxCUDAAdaptor, Flux.FluxAMDAdaptor) +const FluxEltypeAdaptors = ( + Flux.FluxEltypeAdaptor{Float32}, + Flux.FluxEltypeAdaptor{Float16}) + +get_pb(n, desc::String) = Progress( + n; desc, dt=1, barglyphs=BarGlyphs("[=> ]"), barlen=50, color=:white) + +# TODO +# This matches what PyTorch is doing. +# Upstream to Flux. +# Currently it is doing: (x - μ) / (σ + ϵ) +# But instead it should: (x - μ) / sqrt(σ² + ϵ) +function (ln::LayerNorm)(x::AbstractArray) + ϵ = convert(float(eltype(x)), ln.ϵ) + μ, σ² = _normalize(x; dims=1:length(ln.size)) + y = ln.diag((x .- μ) .* inv.(sqrt.(σ² .+ ϵ))) + sync_free!(μ, σ²) + return y +end + +function (gn::Flux.GroupNorm)(x::AbstractArray) + sz = size(x) + x2 = reshape(x, sz[1:end - 2]..., sz[end - 1] ÷ gn.G, gn.G, sz[end]) + N = ndims(x2) # == ndims(x)+1 + reduce_dims = 1:(N - 2) + affine_shape = ntuple(i -> i ∈ (N - 1, N - 2) ? size(x2, i) : 1, N) + + μ, σ² = _normalize(x2; dims=reduce_dims) + γ = reshape(gn.γ, affine_shape) + β = reshape(gn.β, affine_shape) + + ϵ = convert(float(eltype(x)), gn.ϵ) + scale = γ .* inv.(sqrt.(σ² .+ ϵ)) + bias = -scale .* μ .+ β + + sync_free!(μ, σ²) + return reshape(gn.λ.(scale .* x2 .+ bias), sz) +end + +function _normalize(x::AbstractArray{Float16}; dims) + x_fp32 = Float32.(x) + μ, σ² = _normalize(x_fp32; dims) + m, v = Float16.(μ), Float16.(σ²) + sync_free!(x_fp32, μ, σ²) + return m, v +end + +function _normalize(x; dims) + μ = mean(x; dims) + σ² = var(x; dims, mean=μ, corrected=false) + μ, σ² +end include("timestep.jl") include("feed_forward.jl") @@ -36,60 +100,45 @@ include("clip/tokenizer.jl") include("clip/models.jl") include("schedulers/pndm.jl") +include("stable_diffusion.jl") include("load_utils.jl") -# >>> vae.encoder(x).sum() -# tensor(28057.0957, grad_fn=) - function main() - kl = AutoencoderKL( - "runwayml/stable-diffusion-v1-5"; - state_file="vae/diffusion_pytorch_model.bin", - config_file="vae/config.json") - x = ones(Float32, 256, 256, 3, 1) + sd = StableDiffusion("runwayml/stable-diffusion-v1-5") |> f16 |> gpu + println("Running StableDiffusion on $(get_backend(sd))") + + n_images_per_prompt = 1 + prompts = ["painting of a farmer in the field"] + images = sd(prompts; n_images_per_prompt, n_inference_steps=20) + + idx = 1 + for prompt in prompts, i in 1:n_images_per_prompt + joined_prompt = replace(prompt, ' ' => '-') + save("$joined_prompt-$i.png", rotr90(RGB{N0f8}.(images[:, :, idx]))) + idx += 1 + end + return +end - @show sum(kl.encoder(x)) - @show sum(kl(x)) +function debug() + m = LayerNorm(320) + lf = m |> f32 |> gpu + lh = m |> f16 |> gpu - # y = kl(x) - # @show size(y) - # @show sum(y) + x = rand(Float32, 320, 4096, 1) + xf = x |> f32 |> gpu + xh = x |> f16 |> gpu - # y = kl(x; sample_posterior = true) - # @show size(y) - # @show sum(y) + y = m(x) + yf = lf(xf) |> cpu + yh = lh(xh) |> cpu |> f32 + println() + @show sum(y) + @show sum(yf) + @show sum(yh) return end -function tk() - input_texts = [ - "Hello, world!", - "There is nothing basically... I mean it quite literally", - "I was now on a dark path, unsettled by a future filled with big data and small comprehension.", - ] - println("Input texts:") - display(input_texts); println() - - tokenizer = CLIPTokenizer() - tokens, pad_mask = tokenize(tokenizer, input_texts; context_length=32) - println("Tokens:") - display(tokens); println() - display(pad_mask); println() - @show size(pad_mask) - - texts = [ - decode(tokenizer, @view(tokens[:, i])) - for i in 1:size(tokens, 2)] - println("Decoded texts:") - display(texts); println() - - nothing -end - -""" -- CLIPFeatureExtractor: https://github.com/huggingface/transformers/blob/fb366b9a2a94b38171896f6ba9fb9ae8bffd77af/src/transformers/models/clip/feature_extraction_clip.py#L26 -""" - end diff --git a/src/attention.jl b/src/attention.jl index c1714a3..32078c4 100644 --- a/src/attention.jl +++ b/src/attention.jl @@ -28,6 +28,7 @@ function Attention(dim::Int; cross_attention_norm::Bool = false, scale::Float32 = 1f0, dropout::Real = 0, + ϵ::Float32 = 1f-5, ) cross_attention_norm && isnothing(context_dim) && throw(ArgumentError(""" `context_dim` is `nothing`, but `cross_attention_norm` is `true`. @@ -54,7 +55,7 @@ function Attention(dim::Int; norm = if is_cross_attention cross_attention_norm ? LayerNorm(context_dim) : identity else - isnothing(n_groups) ? identity : GroupNorm(dim, n_groups) + isnothing(n_groups) ? identity : GroupNorm(dim, n_groups; ϵ) end Attention( @@ -65,8 +66,8 @@ end function (attn::Attention)( x::T, context::Maybe{C} = nothing; mask::Maybe{M} = nothing, ) where { - T <: AbstractArray{Float32, 3}, - C <: AbstractArray{Float32, 3}, + T <: AbstractArray{<:Real, 3}, + C <: AbstractArray{<:Real, 3}, M <: AbstractMatrix{Bool}, } residual = x @@ -77,15 +78,24 @@ function (attn::Attention)( attn.to_q(x), attn.to_k(c), attn.to_v(c) else x = attn.norm(x) + # TODO add doc that in this case input should be in (w * h, c, b) + x = permutedims(x, (2, 1, 3)) attn.to_q(x), attn.to_k(x), attn.to_v(x) end mask = isnothing(mask) ? nothing : reshape(mask, size(mask, 1), 1, 1, size(mask, 2)) - ω, _ = dot_product_attention(q, k, v; mask, nheads=attn.n_heads) + ω, α = dot_product_attention(q, k, v; mask, nheads=attn.n_heads) - o = attn.to_out(reshape(ω, :, seq_length, batch)) + sync_free!(α, q, k, v) + isnothing(mask) || KernelAbstractions.unsafe_free!(mask) + + cross_attention(attn) && (ω = reshape(ω, :, seq_length, batch);) + o = attn.to_out(ω) + + sync_free!(ω) cross_attention(attn) && return o - (o .+ residual) ./ attn.scale + FP = eltype(x) + (permutedims(o, (2, 1, 3)) .+ residual) .* FP(inv(attn.scale)) end diff --git a/src/autoencoder/blocks.jl b/src/autoencoder/blocks.jl index 6fe9340..58774aa 100644 --- a/src/autoencoder/blocks.jl +++ b/src/autoencoder/blocks.jl @@ -27,7 +27,7 @@ function Encoder( push!(down_blocks, SamplerBlock2D{Downsample2D}( input_channels => output_channels; n_groups, n_layers=n_block_layers, - add_sampler=!is_final, λ)) + sampler_pad=0, add_sampler=!is_final, λ)) end mid_block = MidBlock2D( @@ -43,11 +43,12 @@ function Encoder( Encoder(conv_in, conv_out, norm, Chain(down_blocks...), mid_block) end -function (enc::Encoder)(x::T) where T <: AbstractArray{Float32, 4} +function (enc::Encoder)(x::T) where T <: AbstractArray{<:Real, 4} x = enc.conv_in(x) x = enc.down_blocks(x) x = enc.mid_block(x) - enc.conv_out(enc.norm(x)) + x = enc.norm(x) + enc.conv_out(x) end struct Decoder{C1, C2, N, U, M} @@ -94,7 +95,7 @@ function Decoder( Decoder(conv_in, conv_out, norm, Chain(up_blocks...), mid_block) end -function (dec::Decoder)(x::T) where T <: AbstractArray{Float32, 4} +function (dec::Decoder)(x::T) where T <: AbstractArray{<:Real, 4} x = dec.conv_in(x) x = dec.mid_block(x) x = dec.up_blocks(x) @@ -124,7 +125,7 @@ end # Kullback–Leibler divergence. function kl( dg::DiagonalGaussian{T}, other::Maybe{DiagonalGaussian{T}} = nothing, -) where T <: AbstractArray{Float32, 4} +) where T <: AbstractArray{<:Real, 4} dims = (1, 2, 3) 0.5f0 .* (isnothing(other) ? sum(dg.μ.^2 .+ dg.ν .- dg.log_σ .- 1f0; dims) : diff --git a/src/autoencoder/kl.jl b/src/autoencoder/kl.jl index 5a4475f..d577b20 100644 --- a/src/autoencoder/kl.jl +++ b/src/autoencoder/kl.jl @@ -30,20 +30,20 @@ function AutoencoderKL( AutoencoderKL(encoder, decoder, quant_conv, post_quant_conv, scaling_factor) end -function encode(kl::AutoencoderKL, x::T) where T <: AbstractArray{Float32, 4} +function encode(kl::AutoencoderKL, x::T) where T <: AbstractArray{<:Real, 4} h = kl.encoder(x) moments = kl.quant_conv(h) DiagonalGaussian(moments) end -function decode(kl::AutoencoderKL, z::T) where T <: AbstractArray{Float32, 4} +function decode(kl::AutoencoderKL, z::T) where T <: AbstractArray{<:Real, 4} h = kl.post_quant_conv(z) kl.decoder(h) end function (kl::AutoencoderKL)( x::T; sample_posterior::Bool = false, -) where T <: AbstractArray{Float32, 4} +) where T <: AbstractArray{<:Real, 4} posterior = encode(kl, x) if sample_posterior z = sample(posterior) diff --git a/src/clip/basic.jl b/src/clip/basic.jl index aa6ae79..6a4aece 100644 --- a/src/clip/basic.jl +++ b/src/clip/basic.jl @@ -1,32 +1,24 @@ -struct Embedding{E} - weights::E -end -Flux.@functor Embedding - -function Embedding(; embed_dim::Int, vocab_size::Int) - Embedding(randn(Float32, embed_dim, vocab_size)) -end - -# ids are 1-based -function (e::Embedding)(ids::T) where T <: AbstractMatrix{<: Integer} - NNlib.gather(e.weights, ids) -end - -struct CLIPTextEmbeddings{T, P, I} +struct CLIPTextEmbeddings{T, P, I <: AbstractMatrix{Int32}} token_embedding::T position_embedding::P position_ids::I end Flux.@functor CLIPTextEmbeddings +# TODO better way to handle fixed int type during f16 conversion +function CLIPTextEmbeddings(token_embedding, position_embedding, position_ids) + pi = eltype(position_ids) == Int32 ? position_ids : Int32.(position_ids) + CLIPTextEmbeddings(token_embedding, position_embedding, pi) +end + Flux.trainable(emb::CLIPTextEmbeddings) = (emb.token_embedding, emb.position_embedding) function CLIPTextEmbeddings(; vocab_size::Int, embed_dim::Int, max_position_embeddings::Int, ) CLIPTextEmbeddings( - Embedding(; embed_dim, vocab_size), - Embedding(; embed_dim, vocab_size=max_position_embeddings), + Embedding(vocab_size => embed_dim), + Embedding(max_position_embeddings => embed_dim), reshape(collect(UnitRange{Int32}(1, max_position_embeddings)), :, 1)) end diff --git a/src/clip/models.jl b/src/clip/models.jl index b3dc1c4..82b56dd 100644 --- a/src/clip/models.jl +++ b/src/clip/models.jl @@ -1,4 +1,4 @@ -quick_gelu(x) = x * sigmoid(1.702f0 * x) +quick_gelu(x::T) where T = x * sigmoid(T(1.702f0) * x) struct CLIPMLP{F1, F2} fc1::F1 @@ -10,7 +10,7 @@ function CLIPMLP(dims::Pair{Int, Int}, λ = quick_gelu) CLIPMLP(Dense(dims, λ), Dense(reverse(dims))) end -function (mlp::CLIPMLP)(x::T) where T <: AbstractArray{Float32, 3} +function (mlp::CLIPMLP)(x::T) where T <: AbstractArray{<:Real, 3} mlp.fc2(mlp.fc1(x)) end @@ -41,7 +41,7 @@ end function (attn::CLIPAttention)( x::T; mask::Maybe{M1} = nothing, causal_mask::Maybe{M2} = nothing, ) where { - T <: AbstractArray{Float32, 3}, + T <: AbstractArray{<:Real, 3}, M1 <: AbstractMatrix{Bool}, M2 <: AbstractMatrix{Bool}, } @@ -84,7 +84,7 @@ end function (enc::CLIPEncoderLayer)( x::T; mask::Maybe{M1} = nothing, causal_mask::Maybe{M2} = nothing, ) where { - T <: AbstractArray{Float32, 3}, + T <: AbstractArray{<:Real, 3}, M1 <: AbstractMatrix{Bool}, M2 <: AbstractMatrix{Bool}, } @@ -116,7 +116,7 @@ end function (enc::CLIPEncoder)( x::T; mask::Maybe{M1} = nothing, causal_mask::Maybe{M2} = nothing ) where { - T <: AbstractArray{Float32, 3}, + T <: AbstractArray{<:Real, 3}, M1 <: AbstractMatrix{Bool}, M2 <: AbstractMatrix{Bool}, } @@ -133,6 +133,10 @@ struct CLIPTextTransformer{B, E, L} end Flux.@functor CLIPTextTransformer +function get_backend(tr::CLIPTextTransformer) + typeof(tr.embeddings.token_embedding.weight) <: Array ? cpu : gpu +end + function CLIPTextTransformer(; vocab_size::Int, embed_dim::Int, max_position_embeddings::Int, n_heads::Int, num_hidden_layers::Int, intermediate_size::Int, @@ -146,16 +150,13 @@ function CLIPTextTransformer(; CLIPTextTransformer(embeddings, encoder, final_layer_norm) end -function (transformer::CLIPTextTransformer)( - input_ids::I; mask::Maybe{M} = nothing, -) where { - I <: AbstractMatrix{Int32}, - M <: AbstractMatrix{Bool}, +function (tr::CLIPTextTransformer)(input_ids::I; mask::Maybe{M} = nothing) where { + I <: AbstractMatrix{Int32}, M <: AbstractMatrix{Bool}, } - x = transformer.embeddings(input_ids) - causal_mask = make_causal_mask(input_ids; dims=1) - x = transformer.encoder(x; mask, causal_mask) - transformer.final_layer_norm(x) + x = tr.embeddings(input_ids) + causal_mask = make_causal_mask(input_ids; dims=1) |> get_backend(tr) + x = tr.encoder(x; mask, causal_mask) + tr.final_layer_norm(x) end # HGF integration. diff --git a/src/clip/tokenizer.jl b/src/clip/tokenizer.jl index 2625830..3a2a155 100644 --- a/src/clip/tokenizer.jl +++ b/src/clip/tokenizer.jl @@ -48,18 +48,20 @@ function encode(tk::CLIPTokenizer, text::String) bpe_tokens end -# TODO longet text context length +# TODO longest text context length +# TODO encode in Int32 from start function tokenize( tk::CLIPTokenizer, texts::Vector{String}; context_length::Int, truncate::Bool = false, add_start_end::Bool = false, ) n = length(texts) - encodings = [encode(tk, - add_start_end ? "<|startoftext|> $text <|endoftext|>" : text) + encodings = [ + encode(tk, add_start_end ? "<|startoftext|> $text <|endoftext|>" : text) for text in texts] - tokens = zeros(Int64, context_length, n) + eof_token = tk.encoder["<|endoftext|>"] + tokens = fill(eof_token, context_length, n) pad_mask = fill(false, context_length, n) for (i, enc) in enumerate(encodings) if length(enc) > context_length @@ -80,7 +82,7 @@ end function decode( tk::CLIPTokenizer, tokens::T; remove_start_end::Bool = true, ignore_padding::Bool = true, -) where T <: AbstractVector{Int64} +) where T <: AbstractVector{Int64} # TODO Int32 if remove_start_end eof_tokens = (tk.encoder["<|startoftext|>"], tk.encoder["<|endoftext|>"]) tokens = [t for t in tokens if !(t in eof_tokens)] diff --git a/src/feed_forward.jl b/src/feed_forward.jl index b63011b..5834326 100644 --- a/src/feed_forward.jl +++ b/src/feed_forward.jl @@ -1,6 +1,9 @@ function geglu(x) h, gate = MLUtils.chunk(x, 2; dims=1) - h .* gelu(gate) + gate = gelu(gate) + y = h .* gate + sync_free!(gate) + return y end struct FeedForward{F} diff --git a/src/load_utils.jl b/src/load_utils.jl index d037df9..c716b0e 100644 --- a/src/load_utils.jl +++ b/src/load_utils.jl @@ -96,7 +96,10 @@ end function load_state!(layer::Flux.GroupNorm, state) layer.γ .= state.weight layer.β .= state.bias - return nothing +end + +function load_state!(emb::Flux.Embedding, state) + copy!(emb.weight, transpose(state.weight)) end function load_state!(attn::Attention, state; use_cross_attention::Bool = false) @@ -286,7 +289,3 @@ function load_state!(emb::CLIPTextEmbeddings, state) load_state!(emb.token_embedding, state.token_embedding) load_state!(emb.position_embedding, state.position_embedding) end - -function load_state!(emb::Embedding, state) - copy!(emb.weights, transpose(state.weight)) -end diff --git a/src/resnet.jl b/src/resnet.jl index af78a8b..1cb918a 100644 --- a/src/resnet.jl +++ b/src/resnet.jl @@ -1,17 +1,20 @@ struct Downsample2D{C} conv::C + special_padding::Bool end Flux.@functor Downsample2D function Downsample2D( channels::Pair{Int, Int}; use_conv::Bool = false, pad::Int = 1, ) + special_padding = use_conv && pad == 0 Downsample2D(use_conv ? Conv((3, 3), channels; stride=2, pad) : - MeanPool((2, 2))) + MeanPool((2, 2)), special_padding) end -function (down::Downsample2D)(x::T) where T <: AbstractArray{Float32, 4} +function (down::Downsample2D)(x::T) where T <: AbstractArray{<:Real, 4} + down.special_padding && (x = pad_zeros(x, (0, 1, 0, 1, 0, 0, 0, 0));) down.conv(x) end @@ -37,7 +40,7 @@ end function (up::Upsample2D{C})( x::T; output_size::Maybe{Tuple{Int, Int}} = nothing, -) where {C, T <: AbstractArray{Float32, 4}} +) where {C, T <: AbstractArray{<:Real, 4}} C <: ConvTranspose && (x = up.conv(x);) x = isnothing(output_size) ? @@ -66,6 +69,7 @@ function ResnetBlock2D(channels::Pair{Int, Int}; time_emb_channels::Maybe{Int} = 512, use_shortcut::Maybe{Bool} = nothing, conv_out_channels::Maybe{Int} = nothing, dropout::Real = 0, λ = swish, scale::Float32 = 1f0, + ϵ::Float32 = 1f-6, ) in_channels, out_channels = channels n_groups_out = isnothing(n_groups_out) ? n_groups : n_groups_out @@ -78,14 +82,14 @@ function ResnetBlock2D(channels::Pair{Int, Int}; # NOTE no up/down init_proj = Chain( - GroupNorm(in_channels, n_groups, λ), + GroupNorm(in_channels, n_groups, λ; ϵ), Conv((3, 3), channels; pad=1)) out_proj = Chain( x -> λ.(x), iszero(dropout) ? identity : Dropout(dropout), Conv((3, 3), out_channels => conv_out_channels; pad=1)) - norm = GroupNorm(out_channels, n_groups_out) + norm = GroupNorm(out_channels, n_groups_out; ϵ) time_emb_proj = isnothing(time_emb_channels) ? identity : Chain(x -> λ.(x), Dense(time_emb_channels => time_emb_out_channels)) @@ -99,7 +103,7 @@ function ResnetBlock2D(channels::Pair{Int, Int}; end function (block::ResnetBlock2D)(x::T, time_embedding::Maybe{E}) where { - T <: AbstractArray{Float32, 4}, E <: AbstractMatrix{Float32}, + T <: AbstractArray{<:Real, 4}, E <: AbstractMatrix{<:Real}, } skip, x = x, block.init_proj(x) @@ -111,10 +115,12 @@ function (block::ResnetBlock2D)(x::T, time_embedding::Maybe{E}) where { x = block.norm(x) + TI = eltype(x) if time_embedding ≢ nothing && block.embedding_scale_shift scale, shift = MLUtils.chunk(time_embedding, 2; dims=3) - x = x .* (1f0 .+ scale) .+ shift + x = x .* (one(TI) .+ scale) .+ shift + sync_free!(scale, shift) end - (block.out_proj(x) .+ block.conv_shortcut(skip)) ./ block.scale + block.out_proj(x) .+ block.conv_shortcut(skip) .* TI(inv(block.scale)) end diff --git a/src/schedulers/pndm.jl b/src/schedulers/pndm.jl index ca025ea..95034f5 100644 --- a/src/schedulers/pndm.jl +++ b/src/schedulers/pndm.jl @@ -1,7 +1,6 @@ Base.@kwdef mutable struct PNDMScheduler{ - A <: AbstractVector{Float32}, - T <: AbstractVector{Int}, - S <: AbstractArray{Float32}, + A <: AbstractVector{<:Real}, + S <: AbstractArray{<:Real}, } const α̂::A # Same as α̅ but lives on the current device. const α::Vector{Float32} @@ -15,7 +14,7 @@ Base.@kwdef mutable struct PNDMScheduler{ # Adjustable. - timesteps::T = Int[] + timesteps::Vector{Int} = Int[] prk_timesteps::Vector{Int} = Int[] plms_timesteps::Vector{Int} = Int[] _timesteps::Vector{Int} @@ -69,14 +68,14 @@ function PNDMScheduler(nd::Int; n_train_timesteps) end -for T in FluxDeviceAdaptors +for T in (FluxDeviceAdaptors..., FluxEltypeAdaptors...) @eval function Adapt.adapt_storage(to::$(T), pndm::PNDMScheduler) PNDMScheduler( Adapt.adapt(to, pndm.α̂), pndm.α, pndm.α̅, pndm.β, pndm.α_final, pndm.σ₀, pndm.skip_prk_steps, pndm.step_offset, - Adapt.adapt(to, pndm.timesteps), - pndm.prk_timesteps, pndm.plms_timesteps, pndm._timesteps, + pndm.timesteps, pndm.prk_timesteps, pndm.plms_timesteps, + pndm._timesteps, Adapt.adapt(to, pndm.x_current), Adapt.adapt(to, pndm.sample), @@ -86,7 +85,7 @@ for T in FluxDeviceAdaptors end end -Base.ndims(::PNDMScheduler{A, T, S}) where {A, T, S} = ndims(S) +Base.ndims(::PNDMScheduler{A, S}) where {A, S} = ndims(S) """ set_timesteps!(pndm::PNDMScheduler, n_inference_timesteps::Int) @@ -109,8 +108,8 @@ function set_timesteps!(pndm::PNDMScheduler, n_inference_timesteps::Int) empty!(pndm.prk_timesteps) pndm.plms_timesteps = reverse(cat( pndm._timesteps[1:end - 1], - pndm._timesteps[2:end - 1], - pndm._timesteps[2:end]; dims=1)) + pndm._timesteps[end - 1:end - 1], + pndm._timesteps[end:end]; dims=1)) else pndm_order = 4 prk_timesteps = @@ -127,7 +126,7 @@ function set_timesteps!(pndm::PNDMScheduler, n_inference_timesteps::Int) # Reset counters. pndm.x_current = similar(pndm.x_current, ntuple(_ -> 1, Val(ndims(pndm)))) - fill!(pndm.x_current, 0f0) + fill!(pndm.x_current, zero(eltype(pndm.x_current))) empty!(pndm.xs) pndm.counter = 0 return @@ -142,9 +141,7 @@ Predict the sample at the previous `t - 1` timestep by reversing the SDE. - `t::Int`: Current discrete timestep in the diffusion chain. - `sample`: Sample at the current timestep `t` created by diffusion process. """ -function step!(pndm::PNDMScheduler{A, T, S}, x::S; t::Int, sample::S) where { - A, T, S, -} +function step!(pndm::PNDMScheduler{A, S}, x::S; t::Int, sample::S) where {A, S} (pndm.counter < length(pndm.prk_timesteps) && !pndm.skip_prk_steps) ? step_prk!(pndm, x; t, sample) : step_plms!(pndm, x; t, sample) @@ -155,25 +152,25 @@ Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the solution to the differential equation. """ -function step_prk!(pndm::PNDMScheduler{A, T, S}, x::S; t::Int, sample::S) where { - A, T, S, -} +function step_prk!(pndm::PNDMScheduler{A, S}, x::S; t::Int, sample::S) where {A, S } + FP = eltype(S) + ratio = pndm.n_train_timesteps ÷ pndm.n_inference_timesteps δt = (pndm.counter % 2 == 0) ? (ratio ÷ 2) : 0 prev_t, t = t - δt, pndm.prk_timesteps[(pndm.counter ÷ 4) * 4 + 1] if pndm.counter % 4 == 0 # Re-assign to get correct shape. - pndm.x_current = pndm.x_current .+ (1f0 / 6f0) .* x + pndm.x_current = pndm.x_current .+ FP(1f0 / 6f0) .* x pndm.sample = sample push!(pndm.xs, x) elseif (pndm.counter - 1) % 4 == 0 - pndm.x_current .+= (1f0 / 3f0) .* x + pndm.x_current .+= FP(1f0 / 3f0) .* x elseif (pndm.counter - 2) % 4 == 0 - pndm.x_current .+= (1f0 / 3f0) .* x + pndm.x_current .+= FP(1f0 / 3f0) .* x elseif (pndm.counter - 3) % 4 == 0 - x = pndm.x_current .+ (1f0 / 6f0) .* x - fill!(pndm.x_current, 0f0) + x = pndm.x_current .+ FP(1f0 / 6f0) .* x + fill!(pndm.x_current, zero(FP)) end pndm.counter += 1 @@ -184,14 +181,13 @@ end Step function propagating the sample with the linear multi-step method. Has one forward pass with multiple times to approximate the solution. """ -function step_plms!(pndm::PNDMScheduler{A, T, S}, x::S; t::Int, sample::S) where { - A, T, S, -} +function step_plms!(pndm::PNDMScheduler{A, S}, x::S; t::Int, sample::S) where {A, S} !pndm.skip_prk_steps && length(pndm.xs) < 3 && error(""" Linear multi-step method can only be run after at least 12 steps in PRK mode and has sampled `3` forward passes. Current amount is `$(length(pndm.xs))`. """) + FP = eltype(S) ratio = pndm.n_train_timesteps ÷ pndm.n_inference_timesteps prev_t = t - ratio @@ -209,18 +205,18 @@ function step_plms!(pndm::PNDMScheduler{A, T, S}, x::S; t::Int, sample::S) where if length(pndm.xs) == 1 && pndm.counter == 0 pndm.sample = sample elseif length(pndm.xs) == 1 && pndm.counter == 1 - x = (x .+ pndm.xs[end]) .* 0.5f0 + x = (x .+ pndm.xs[end]) .* FP(0.5f0) sample = pndm.sample elseif length(pndm.xs) == 2 - x = (3f0 .* pndm.xs[end] .- pndm.xs[end - 1]) .* 0.5f0 + x = (FP(3f0) .* pndm.xs[end] .- pndm.xs[end - 1]) .* FP(0.5f0) elseif length(pndm.xs) == 3 - x = (1f0 / 12f0) .* ( - 23f0 .* pndm.xs[end] .- 16f0 .* pndm.xs[end - 1] .+ - 5f0 .* pndm.xs[end - 2]) + x = FP(1f0 / 12f0) .* ( + FP(23f0) .* pndm.xs[end] .- FP(16f0) .* pndm.xs[end - 1] .+ + FP(5f0) .* pndm.xs[end - 2]) else - x = (1f0 / 24f0) .* ( - 55f0 .* pndm.xs[end] .- 59f0 .* pndm.xs[end - 1] .+ - 37f0 .* pndm.xs[end - 2] .- 9f0 .* pndm.xs[end - 3]) + x = FP(1f0 / 24f0) .* ( + FP(55f0) .* pndm.xs[end] .- FP(59f0) .* pndm.xs[end - 1] .+ + FP(37f0) .* pndm.xs[end - 2] .- FP(9f0) .* pndm.xs[end - 3]) end pndm.counter += 1 @@ -234,24 +230,26 @@ end - `ξ`: Noise which to apply to `x`. - `timesteps`: Vector of discrete timesteps starting at `0`. """ -function add_noise(pndm::PNDMScheduler{A, T, S}, x::S, ξ::S, timesteps::T) where { - A, T, S, -} +function add_noise( + pndm::PNDMScheduler{A, S}, x::S, ξ::S, timesteps::Vector{Int}, +) where {A, S} + FP = eltype(S) αᵗ = reshape(pndm.α̂[timesteps .+ 1], ntuple(_->1, Val(ndims(S) - 1))..., :) - α̅, β̅ = sqrt.(αᵗ), sqrt.(1f0 .- αᵗ) + α̅, β̅ = sqrt.(αᵗ), sqrt.(FP(1f0) .- αᵗ) α̅ .* x .+ β̅ .* ξ end # Equation (9) from paper. function previous_sample( - pndm::PNDMScheduler{A, T, S}, x::S; t::Int, prev_t::Int, sample::S, -) where {A, T, S} + pndm::PNDMScheduler{A, S}, x::S; t::Int, prev_t::Int, sample::S, +) where {A, S} + FP = eltype(S) αₜ₋ᵢ, αₜ = (prev_t ≥ 0 ? pndm.α̅[prev_t + 1] : pndm.α_final), pndm.α̅[t + 1] βₜ₋ᵢ, βₜ = (1f0 - αₜ₋ᵢ), (1f0 - αₜ) γ = √(αₜ₋ᵢ / αₜ) ϵ = αₜ * √βₜ₋ᵢ + √(αₜ₋ᵢ * αₜ * βₜ) - γ .* sample .- (αₜ₋ᵢ - αₜ) .* x ./ ϵ + FP(γ) .* sample .- FP((αₜ₋ᵢ - αₜ) / ϵ) .* x end # HGF integration. @@ -274,11 +272,12 @@ julia> pndm = Diffusers.PNDMScheduler(HGF, 4; ``` """ function PNDMScheduler(model_name::String, nd::Int; config_file::String) - config = Diffusers.load_hgf_config(model_name; filename=config_file) + cfg = Diffusers.load_hgf_config(model_name; filename=config_file) PNDMScheduler(nd; - β_schedule=Symbol(config["beta_schedule"]), - β_range=Float32(config["beta_start"]) => Float32(config["beta_end"]), - n_train_timesteps=config["num_train_timesteps"], - skip_prk_steps=config["skip_prk_steps"], - α_to_one=config["set_alpha_to_one"]) + β_schedule=Symbol(cfg["beta_schedule"]), + β_range=Float32(cfg["beta_start"]) => Float32(cfg["beta_end"]), + n_train_timesteps=cfg["num_train_timesteps"], + skip_prk_steps=cfg["skip_prk_steps"], + α_to_one=cfg["set_alpha_to_one"], + step_offset=cfg["steps_offset"]) end diff --git a/src/stable_diffusion.jl b/src/stable_diffusion.jl new file mode 100644 index 0000000..6e95d94 --- /dev/null +++ b/src/stable_diffusion.jl @@ -0,0 +1,157 @@ +struct StableDiffusion{V, T, K, U, S} + vae::V + text_encoder::T + tokenizer::K + unet::U + scheduler::S + + vae_scale_factor::Int +end +Flux.@functor StableDiffusion + +function StableDiffusion( + vae::V, text_encoder::T, tokenizer::K, unet::U, scheduler::S, +) where { + V <: AutoencoderKL, + T <: CLIPTextTransformer, + K <: CLIPTokenizer, + U <: UNet2DCondition, + S <: PNDMScheduler, +} + vae_scale_factor = 2^(length(vae.encoder.down_blocks) - 1) + StableDiffusion(vae, text_encoder, tokenizer, unet, scheduler, vae_scale_factor) +end + +function get_backend(sd::StableDiffusion) + typeof(sd.unet.sin_embedding.emb) <: Array ? cpu : gpu +end + +Base.eltype(sd::StableDiffusion) = eltype(sd.unet.sin_embedding.emb) + +function (sd::StableDiffusion)( + prompt::Vector{String}, negative_prompt::Vector{String} = String[]; + n_inference_steps::Int = 50, + n_images_per_prompt::Int = 1, + guidance_scale::Float32 = 7.5f0, +) + width, height = 512, 512 + + classifier_free_guidance = guidance_scale > 1f0 + prompt_embeds = _encode_prompt( + sd, prompt, negative_prompt; n_images_per_prompt, + classifier_free_guidance) + GC.gc() + + set_timesteps!(sd.scheduler, n_inference_steps) + GC.gc() + + batch = length(prompt) * n_images_per_prompt + latents = _prepare_latents(sd; shape=(width, height, 4, batch)) + GC.gc() + + bar = get_pb(length(sd.scheduler.timesteps), "Diffusion process:") + for t in sd.scheduler.timesteps + timestep = Int32[t] |> get_backend(sd) + # Double latents for classifier free guidance. + latent_inputs = classifier_free_guidance ? cat(latents, latents; dims=4) : latents + noise_pred = sd.unet(latent_inputs, timestep, prompt_embeds) + GC.gc() + + # Perform guidance. + if classifier_free_guidance + noise_pred_uncond, noise_pred_text = MLUtils.chunk(noise_pred, 2; dims=4) + noise_pred = noise_pred_uncond .+ eltype(sd)(guidance_scale) .* (noise_pred_text .- noise_pred_uncond) + end + + latents = step!(sd.scheduler, noise_pred; t, sample=latents) + GC.gc() + next!(bar) + end + return _decode_latents(sd, latents) +end + +""" +Encode prompt into text encoder hidden states. +""" +function _encode_prompt( + sd::StableDiffusion, prompt::Vector{String}, + negative_prompt::Vector{String}; + context_length::Int = 77, + n_images_per_prompt::Int, + classifier_free_guidance::Bool, +) + tokens, mask = tokenize( + sd.tokenizer, prompt; add_start_end=true, context_length) + tokens = Int32.(tokens) |> get_backend(sd) + + prompt_embeds = sd.text_encoder(tokens #=, mask =#) # TODO conditionally use mask + _, seq_len, batch = size(prompt_embeds) + prompt_embeds = repeat(prompt_embeds; outer=(1, n_images_per_prompt, 1)) + prompt_embeds = reshape(prompt_embeds, :, seq_len, batch * n_images_per_prompt) + + if classifier_free_guidance + negative_prompt = isempty(negative_prompt) ? + fill("", length(prompt)) : negative_prompt + @assert length(negative_prompt) == length(prompt) + + tokens, mask = tokenize( + sd.tokenizer, negative_prompt; add_start_end=true, context_length) + tokens = Int32.(tokens) |> get_backend(sd) + + negative_prompt_embeds = sd.text_encoder(tokens #=, mask =#) # TODO conditionally use mask + _, seq_len, batch = size(negative_prompt_embeds) + negative_prompt_embeds = repeat(negative_prompt_embeds; outer=(1, n_images_per_prompt, 1)) + negative_prompt_embeds = reshape(negative_prompt_embeds, :, seq_len, batch * n_images_per_prompt) + + # For classifier free guidance we need to do 2 forward passes. + # Instead, concatenate embeds together and do 1. + prompt_embeds = cat(negative_prompt_embeds, prompt_embeds; dims=3) + end + prompt_embeds +end + +function _prepare_latents(sd::StableDiffusion; shape::NTuple{4, Int}) + shape = ( + shape[1] ÷ sd.vae_scale_factor, + shape[2] ÷ sd.vae_scale_factor, shape[3], shape[4]) + FP = eltype(sd) + latents = randn(FP, shape) |> get_backend(sd) + isone(sd.scheduler.σ₀) || return latents + latents .* FP(sd.scheduler.σ₀) +end + +function _decode_latents(sd::StableDiffusion, latents) + FP = eltype(sd) + latents .*= FP(1f0 / sd.vae.scaling_factor) + image = decode(sd.vae, latents) + host_image = image |> cpu + host_image = clamp!(Float32.(host_image) .* 0.5f0 .+ 0.5f0, 0f0, 1f0) + host_image = permutedims(host_image, (3, 1, 2, 4)) + colorview(RGB{Float32}, host_image) +end + +# HGF integration. + +function StableDiffusion(model_name::String) + vae = AutoencoderKL(model_name; + state_file="vae/diffusion_pytorch_model.bin", + config_file="vae/config.json") + text_encoder = Diffusers.CLIPTextTransformer(model_name; + state_file="text_encoder/pytorch_model.bin", + config_file="text_encoder/config.json") + tokenizer = CLIPTokenizer() + unet = UNet2DCondition(model_name; + state_file="unet/diffusion_pytorch_model.bin", + config_file="unet/config.json") + scheduler = PNDMScheduler(model_name, 4; + config_file="scheduler/scheduler_config.json") + StableDiffusion(vae, text_encoder, tokenizer, unet, scheduler) +end + +# Truncate type to improve stacktrace readability. +# TODO there should be more generic way. +function Base.show(io::IO, ::Type{<: StableDiffusion{V, T, K, U, S}}) where { + V, T, K, U, S, +} + print(io, "StableDiffusion{$(V.name.wrapper){…}, $(T.name.wrapper){…}, $(K.name.wrapper){…}, $(U.name.wrapper){…}, $(S.name.wrapper){…}}") +end diff --git a/src/timestep.jl b/src/timestep.jl index 5486be7..8f0831d 100644 --- a/src/timestep.jl +++ b/src/timestep.jl @@ -10,8 +10,11 @@ function TimestepEmbedding(in_channels::Int; time_embed_dim::Int) Dense(time_embed_dim => time_embed_dim)) end -function (t::TimestepEmbedding)(x::T) where T <: AbstractArray{Float32, 2} - t.linear2(t.linear1(x)) +function (t::TimestepEmbedding)(x::T) where T <: AbstractMatrix{<:Real} + tmp = t.linear1(x) + y = t.linear2(tmp) + sync_free!(tmp) + return y end struct SinusoidalEmbedding{E} @@ -31,7 +34,15 @@ function SinusoidalEmbedding( SinusoidalEmbedding(reshape(emb, :, 1)) end -function (emb::SinusoidalEmbedding)(timesteps::T) where T <: AbstractVector{Int32} +function (emb::SinusoidalEmbedding{E})(timesteps::T) where { + E <: AbstractMatrix, + T <: AbstractVector{Int32}, +} emb = emb.emb .* reshape(timesteps, 1, :) - cat(cos.(emb), sin.(emb); dims=1) + cos_emb, sin_emb = cos.(emb), sin.(emb) + sync_free!(emb) + + y = cat(cos_emb, sin_emb; dims=1) + sync_free!(cos_emb, sin_emb) + return y end diff --git a/src/transformer.jl b/src/transformer.jl index 7434ba2..8340ea7 100644 --- a/src/transformer.jl +++ b/src/transformer.jl @@ -40,21 +40,26 @@ end function (block::TransformerBlock)( x::T, context::Maybe{C} = nothing; mask::Maybe{M} = nothing, ) where { - T <: AbstractArray{Float32, 3}, - C <: AbstractArray{Float32, 3}, + T <: AbstractArray{<:Real, 3}, + C <: AbstractArray{<:Real, 3}, M <: AbstractMatrix{Bool}, } xn = block.norm_1(x) - a1 = block.attention_1( - xn, block.only_cross_attention ? context : xn; mask) + a1 = block.attention_1(xn, block.only_cross_attention ? context : xn; mask) x = a1 .+ x + sync_free!(xn, a1) if block.attention_2 ≢ nothing - a2 = block.attention_2(block.norm_2(x), context; mask) + xn = block.norm_2(x) + a2 = block.attention_2(xn, context; mask) x = a2 .+ x + sync_free!(xn, a2) end - block.fwd(block.norm_3(x)) .+ x + xn = block.norm_3(x) + y = block.fwd(xn) .+ x + sync_free!(xn) + return y end struct Transformer2D{N, P, B} @@ -95,8 +100,7 @@ function Transformer2D(; end function (tr::Transformer2D)(x::T, context::Maybe{C} = nothing) where { - T <: AbstractArray{Float32, 4}, - C <: AbstractArray{Float32, 3}, + T <: AbstractArray{<:Real, 4}, C <: AbstractArray{<:Real, 3}, } width, height, channels, batch = size(x) residual = x @@ -114,7 +118,9 @@ function (tr::Transformer2D)(x::T, context::Maybe{C} = nothing) where { end for block in tr.transformer_blocks - x = block(x, context) + xn = block(x, context) + sync_free!(x) + x = xn end if tr.use_linear_projection @@ -127,5 +133,7 @@ function (tr::Transformer2D)(x::T, context::Maybe{C} = nothing) where { x = tr.proj_out(x) end - x .+ residual + y = x .+ residual + sync_free!(x) + return y end diff --git a/src/unet/2d_condition.jl b/src/unet/2d_condition.jl index 05e4478..92b20a1 100644 --- a/src/unet/2d_condition.jl +++ b/src/unet/2d_condition.jl @@ -10,6 +10,7 @@ struct UNet2DCondition{CI, S, T, D, M, U, G, CO} conv_norm_out::G conv_out::CO end +Flux.@functor UNet2DCondition function UNet2DCondition( channels::Pair{Int, Int} = 4 => 4; @@ -107,9 +108,9 @@ end function (unet::UNet2DCondition)( x::X, timestep::T, text_emb::C ) where { - X <: AbstractArray{Float32, 4}, + X <: AbstractArray{<:Real, 4}, T <: AbstractVector{Int32}, - C <: AbstractArray{Float32, 3} + C <: AbstractArray{<:Real, 3}, } time_emb = unet.sin_embedding(timestep) time_emb = unet.time_embedding(time_emb) diff --git a/src/unet/blocks.jl b/src/unet/blocks.jl index 0032844..4904b36 100644 --- a/src/unet/blocks.jl +++ b/src/unet/blocks.jl @@ -38,13 +38,14 @@ has_sampler(::CrossAttnDownBlock2D{R, A, D}) where {R, A, D} = !(D <: typeof(ide function (cattn::CrossAttnDownBlock2D)( x::T, time_emb::Maybe{E} = nothing, context::Maybe{C} = nothing, ) where { - T <: AbstractArray{Float32, 4}, - E <: AbstractArray{Float32, 2}, - C <: AbstractArray{Float32, 3}, + T <: AbstractArray{<:Real, 4}, + E <: AbstractArray{<:Real, 2}, + C <: AbstractArray{<:Real, 3}, } function _chain(resnets::Tuple, attentions::Tuple, h) - h = first(resnets)(h, time_emb) - h = first(attentions)(h, context) + tmp = first(resnets)(h, time_emb) + h = first(attentions)(tmp, context) + sync_free!(tmp) (h, _chain(Base.tail(resnets), Base.tail(attentions), h)...) end _chain(::Tuple{}, ::Tuple{}, _) = () @@ -83,7 +84,8 @@ end has_sampler(::DownBlock2D{R, S}) where {R, S} = !(S <: typeof(identity)) function (block::DownBlock2D)(x::T, temb::E) where { - T <: AbstractArray{Float32, 4}, E <: AbstractArray{Float32, 2}, + T <: AbstractArray{<:Real, 4}, + E <: AbstractArray{<:Real, 2}, } function _chain(blocks::Tuple, h) h = first(blocks)(h, temb) @@ -136,14 +138,15 @@ end function (mid::CrossAttnMidBlock2D)( x::T, time_emb::Maybe{E} = nothing, context::Maybe{C} = nothing, ) where { - T <: AbstractArray{Float32, 4}, - E <: AbstractArray{Float32, 2}, - C <: AbstractArray{Float32, 3}, + T <: AbstractArray{<:Real, 4}, + E <: AbstractArray{<:Real, 2}, + C <: AbstractArray{<:Real, 3}, } x = mid.resnets[1](x, time_emb) for (resnet, attn) in zip(mid.resnets[2:end], mid.attentions) - x = attn(x, context) - x = resnet(x, time_emb) + tmp = attn(x, context) + x = resnet(tmp, time_emb) + sync_free!(tmp) end x end @@ -175,12 +178,13 @@ function SamplerBlock2D{S}( embedding_scale_shift, n_groups, dropout, λ) for i in 1:n_layers]...) + sampler = add_sampler ? S(out_channels; use_conv=true, pad=sampler_pad) : identity SamplerBlock2D(resnets, sampler) end -function (block::SamplerBlock2D)(x::T) where T <: AbstractArray{Float32, 4} +function (block::SamplerBlock2D)(x::T) where T <: AbstractArray{<:Real, 4} for rn in block.resnets x = rn(x, nothing) end @@ -210,6 +214,7 @@ function MidBlock2D( λ = swish, scale::Float32 = 1f0, n_heads::Int = 1, + ϵ::Float32 = 1f-6, ) resnets = [ResnetBlock2D( channels => channels; time_emb_channels, scale, embedding_scale_shift, @@ -217,7 +222,7 @@ function MidBlock2D( for _ in 1:(n_layers + 1)] attentions = add_attention ? Chain([ - Attention(channels; bias=true, n_heads, n_groups, scale) + Attention(channels; bias=true, n_heads, n_groups, scale, ϵ) for _ in 1:n_layers]...) : nothing MidBlock2D(Chain(resnets...), attentions) end @@ -225,12 +230,13 @@ end function (mb::MidBlock2D{R, A})( x::T, time_embedding::Maybe{E} = nothing, ) where { - R, A, T <: AbstractArray{Float32, 4}, - E <: AbstractMatrix{Float32}, + R, A, + T <: AbstractArray{<:Real, 4}, + E <: AbstractMatrix{<:Real}, } x = mb.resnets[1](x, time_embedding) for i in 2:length(mb.resnets) - if A <: Nothing + if !(A <: Nothing) width, height, channels, batch = size(x) x = mb.attentions[i - 1](reshape(x, :, channels, batch)) x = reshape(x, width, height, channels, batch) @@ -279,9 +285,9 @@ end function (block::CrossAttnUpBlock2D)( x::T, skips, temb::Maybe{E} = nothing, context::Maybe{C} = nothing ) where { - T <: AbstractArray{Float32, 4}, - E <: AbstractArray{Float32, 2}, - C <: AbstractArray{Float32, 3}, + T <: AbstractArray{<:Real, 4}, + E <: AbstractArray{<:Real, 2}, + C <: AbstractArray{<:Real, 3}, } for (rn, attn) in zip(block.resnets, block.attentions) skip, skips = first(skips), Base.tail(skips) @@ -324,11 +330,16 @@ function UpBlock2D( end function (block::UpBlock2D)(x::T, skips, temb::E) where { - T <: AbstractArray{Float32, 4}, E <: AbstractArray{Float32, 2}, + T <: AbstractArray{<:Real, 4}, + E <: AbstractArray{<:Real, 2}, } for block in block.resnets skip, skips = first(skips), Base.tail(skips) - x = block(cat(x, skip; dims=3), temb) + tmp = cat(x, skip; dims=3) + x = block(tmp, temb) + sync_free!(tmp) end - block.sampler(x), skips + y = block.sampler(x) + sync_free!(x) + y, skips end diff --git a/test/Project.toml b/test/Project.toml new file mode 100644 index 0000000..6c4831e --- /dev/null +++ b/test/Project.toml @@ -0,0 +1,4 @@ +[deps] +AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/clip.jl b/test/clip.jl index d22dedb..362e6f5 100644 --- a/test/clip.jl +++ b/test/clip.jl @@ -1,32 +1,46 @@ -const CLIP_TEXT_MODEL = Diffusers.CLIPTextTransformer( - "runwayml/stable-diffusion-v1-5"; - state_file="text_encoder/pytorch_model.bin", - config_file="text_encoder/config.json") +function clip_testsuite(device, fp) + atol = fp == f32 ? 1e-3 : 1e-1 + FT = fp == f32 ? Float32 : Float16 -@testset "Embeggings" begin - # First two values and last two values. - y = CLIP_TEXT_MODEL.embeddings(Int32[1; 2;; 49407; 49408;;]) - @test size(y) == (768, 2, 2) - @test sum(y) ≈ 1.9741f0 -end + @testset "Embeggings" begin + # First two values and last two values. + x = Int32[1; 2;; 49407; 49408;;] |> device + m = CLIP_TEXT_MODEL.embeddings |> fp |> device + y = m(x) |> cpu + @test size(y) == (768, 2, 2) + @test eltype(y) == FT + @test sum(y) ≈ 1.9741f0 atol=atol + end -@testset "Final layer norm" begin - y = CLIP_TEXT_MODEL.final_layer_norm(ones(Float32, 768, 2, 1)) - @test sum(y) ≈ -170.165f0 -end + @testset "Final layer norm" begin + x = ones(Float32, 768, 2, 1) |> fp |> device + m = CLIP_TEXT_MODEL.final_layer_norm |> fp |> device + y = m(x) |> cpu + @test eltype(y) == eltype(x) + @test sum(y) ≈ -170.165f0 atol=atol + end -@testset "Encoder layers" begin - y = CLIP_TEXT_MODEL.encoder.layers[1](ones(Float32, 768, 2, 1)) - @test sum(y) ≈ 1535f0 -end + @testset "Encoder layers" begin + x = ones(Float32, 768, 2, 1) |> fp |> device + m = CLIP_TEXT_MODEL.encoder.layers[1] |> fp |> device + y = m(x) |> cpu + @test eltype(y) == eltype(x) + @test sum(y) ≈ 1535f0 atol=atol + end -@testset "CLIP MLP" begin - x = ones(Float32, 768, 2, 1) - @test sum(CLIP_TEXT_MODEL.encoder.layers[1].mlp(x)) ≈ -5.49f0 -end + @testset "CLIP MLP" begin + x = ones(Float32, 768, 2, 1) |> fp |> device + m = CLIP_TEXT_MODEL.encoder.layers[1].mlp |> fp |> device + y = m(x) |> cpu + @test eltype(y) == eltype(x) + @test sum(y) ≈ -5.49f0 atol=atol + end -@testset "Full model" begin - x = Int32[1; 2;; 5; 6;; 49407; 49408;;] - y = CLIP_TEXT_MODEL(x) - @test sum(y) ≈ -493.05f0 + @testset "Full model" begin + x = Int32[1; 2;; 5; 6;; 49407; 49408;;] + m = CLIP_TEXT_MODEL |> fp |> device + y = m(x) |> cpu + @test eltype(y) == FT + @test sum(y) ≈ -493.05f0 atol=atol + end end diff --git a/test/layer_load_utils.jl b/test/layer_load_utils.jl index af77a59..45e4d67 100644 --- a/test/layer_load_utils.jl +++ b/test/layer_load_utils.jl @@ -56,4 +56,4 @@ end target_y = [-0.06698608, 0.17626953, -0.22659302, 0.03451538, -0.01315308] y = g(ones(Float32, 3, 3, 320, 1)) @test y[1, 1, 1:5, 1] ≈ target_y atol=1e-3 rtol=1e-3 -end \ No newline at end of file +end diff --git a/test/model_load_utils.jl b/test/model_load_utils.jl index 71647d5..2b556d7 100644 --- a/test/model_load_utils.jl +++ b/test/model_load_utils.jl @@ -1,149 +1,199 @@ -@testset "Load SD cross_attention & do a forward" begin - attn = Diffusers.Attention(320; bias=false, context_dim=320, head_dim=40) - Diffusers.load_state!(attn, STATE.down_blocks[1].attentions[1].transformer_blocks[1].attn1) - - # Manually obtained, pytorch row wise approx equals jl col wise - # In python, pipe.unet.down_blocks[0].attentions[0].transformer_blocks[0].attn1(torch.ones(2, 4096, 320))[0, 0, :5] - # where pipe is the StableDiffusionPipeline - target_y = [0.15594974, 0.01136263, 0.27801704, 0.31772622, 0.63947713] - x = ones(Float32, 320, 1, 2) - y = attn(x, x) - @test y[1:5, 1, 1] ≈ target_y atol=1e-5 rtol=1e-5 -end - -@testset "Load SD FeedForward" begin - fwd = Diffusers.FeedForward(; dim=320) - Diffusers.load_state!(fwd, STATE.down_blocks[1].attentions[1].transformer_blocks[1].ff) - - # pipe.unet.down_blocks[0].attentions[0].transformer_blocks[0].ff(torch.ones(1, 2, 320))[0, 0, :5].detach().numpy() - target_y = [0.5421921, -0.00488963, 0.18569, -0.17563964, -0.0561044] - y = fwd(ones(Float32, 320, 1, 1)) - @test y[1:5, 1, 1] ≈ target_y atol=1e-3 rtol=1e-3 -end - -@testset "Load SD BasicTransformerBlock & do a forward" begin - tb = Diffusers.TransformerBlock(; dim=320, n_heads=8, head_dim=40, context_dim=768) - Diffusers.load_state!(tb, STATE.down_blocks[1].attentions[1].transformer_blocks[1]) - - # pipe.unet.down_blocks[0].attentions[0].transformer_blocks[0](torch.ones(1, 4096, 320), torch.ones(1, 77, 768))[0, 0, :5] - target_y = [1.1293957, 0.39926898, 2.0685763, 0.07038331, 3.2459378] - x = ones(Float32, 320, 4096, 1) - context = ones(Float32, 768, 77, 1) - y = tb(x, context) - @test y[1:5, 1, 1] ≈ target_y atol=1e-3 rtol=1e-3 -end - -@testset "Load SD Transformer2DModel & do a forward" begin - tm = Diffusers.Transformer2D(; - in_channels=320, context_dim=768, n_heads=8, head_dim=40) - Diffusers.load_state!(tm, STATE.down_blocks[1].attentions[1]) - - # pipe.unet.down_blocks[0].attentions[0](torch.ones(1, 320, 64, 64), torch.ones(1, 77, 768)).sample[0, 0, :5] - target_y = [1.7389021, 0.795506, 1.6157904, 1.6191279, 0.6467081] - x, context = ones(Float32, 64, 64, 320, 1), ones(Float32, 768, 77, 1) - y = tm(x, context) - @test y[1, 1, 1:5, 1] ≈ target_y atol=1e-3 rtol=1e-3 -end - -@testset "Load SD FeedForward" begin - rs = Diffusers.ResnetBlock2D(320 => 320; time_emb_channels=1280) - Diffusers.load_state!(rs, STATE.down_blocks[1].resnets[1]) - - x, time_embedding = ones(Float32, 64, 64, 320, 1), ones(Float32, 1280, 1) - - # pipe.unet.down_blocks[0].resnets[0](torch.ones(1, 320, 64, 64), torch.ones(1, 1280)).detach().numpy()[0, :5, 0, 0] - target_y = [1.0409687, 0.36245018, 0.92556036, 0.95282567, 1.5846546] - y = rs(x, time_embedding) - @test y[1, 1, 1:5, 1] ≈ target_y atol=1e-3 rtol=1e-3 -end - -@testset "Load SD CrossAttnDownBlock2D" begin - cattn = Diffusers.CrossAttnDownBlock2D(320 => 320; - time_emb_channels=1280, n_layers=2, n_heads=8, context_dim=768) - Diffusers.load_state!(cattn, STATE.down_blocks[1]) - - # pipe.unet.down_blocks[0](torch.ones(1, 320, 64, 64), torch.ones(1, 1280), torch.ones(1, 77, 768)) - x, temb, context = ones(Float32, 64, 64, 320, 1), ones(Float32, 1280, 1), ones(Float32, 768, 77, 1) - target_y = [3.5323777, 4.8788514, 4.8925233, 4.8956304, 4.8956304, 4.8956304] - y, states = cattn(x, temb, context) - @test y[1:6, 1, 1, 1] ≈ target_y atol=1e-3 rtol=1e-3 -end - -@testset "Load SD CrossAttnMidBlock2D" begin - mid = Diffusers.CrossAttnMidBlock2D(; - in_channels=1280, time_emb_channels=1280, n_heads=8, context_dim=768) - Diffusers.load_state!(mid, STATE.mid_block) - - # pipe.unet.mid_block(torch.ones(1, 1280, 8, 8), torch.ones(1, 1280), torch.ones(1, 77, 768)).detach().numpy()[0, :6, 0, 0] - target_y = [-2.2978039, -0.58777064, -2.1970692, -2.0825987, 3.975503, -3.1240108] - x, temb, context = ones(Float32, 8, 8, 1280, 1), ones(Float32, 1280, 1), ones(Float32, 768, 77, 1) - y = mid(x, temb, context) - @test y[1, 1, 1:6, 1] ≈ target_y atol=1e-3 rtol=1e-3 -end - -@testset "Load SD CrossAttnUpBlock2D" begin - u = Diffusers.CrossAttnUpBlock2D(640=>1280, 1280, 1280; n_layers=3, attn_n_heads=8, context_dim=768) - Diffusers.load_state!(u, STATE.up_blocks[2]) - - # x = torch.ones(1, 1280, 16, 16) - # pipe.unet.up_blocks[1](x, (torch.ones(1, 640, 16, 16), x, x), torch.ones(1, 1280), torch.ones(1, 77, 768))[0, :6, 0, 0] - target_y = [-25.81815, -9.393141, -4.554784, 4.673693, 12.621728, -0.49337524] - x = ones(Float32, 16, 16, 1280, 1) - tl = (x, x, ones(Float32, 16, 16, 640, 1)) - y, _ = u(x, tl, ones(Float32, 1280, 1), ones(Float32, 768, 77, 1)) - @test y[1, 1, 1:6, 1] ≈ target_y atol=1e-3 rtol=1e-3 -end - -@testset "Load SD UpBlock2D" begin - u = Diffusers.UpBlock2D(1280=>1280, 1280, 1280; n_layers=3) - Diffusers.load_state!(u, STATE.up_blocks[1]) - - # x = torch.ones(1, 1280, 8, 8) - # pipe.unet.up_blocks[0](x, (x, x, x), torch.ones(1, 1280)).detach().numpy()[0, :6, 0, 0] - target_y = [0.2308563, 0.2685722, 1.0352244, 0.45586765, 1.643967, 0.10508753] - skip = (ones(Float32, 8, 8, 1280, 1), ones(Float32, 8, 8, 1280, 1), ones(Float32, 8, 8, 1280, 1)) - x, temb = ones(Float32, 8, 8, 1280, 1), ones(Float32, 1280, 1) - y, _ = u(x, skip, temb) - @test y[1, 1, 1:6, 1] ≈ target_y atol=1e-3 rtol=1e-3 -end - -@testset "Load SD DownBlock2D" begin - d = Diffusers.DownBlock2D(1280 => 1280, 1280; n_layers=2, add_downsample=false) - Diffusers.load_state!(d, STATE.down_blocks[4]) - - # pipe.unet.down_blocks[3](torch.ones(1, 1280, 8, 8),torch.ones(1, 1280))[0].numpy()[0, :6, 0, 0] - target_y = [2.0826728, 1.078491, 1.1676872, 0.97314227, 0.67884475, 2.0286326] - x, temb = ones(Float32, 8, 8, 1280, 1), ones(Float32, 1280, 1) - y, states = d(x, temb) - @test y[1, 1, 1:6, 1] ≈ target_y atol=1e-3 rtol=1e-3 -end - -@testset "Load a SD TimestepEmbedding with Flux & do forward" begin - t = Diffusers.TimestepEmbedding(320; time_embed_dim=1280) - sin_emb = Diffusers.SinusoidalEmbedding(320; freq_shift=0) - - Diffusers.load_state!(t, STATE.time_embedding) - # y = pipe.unet.time_embedding(torch.ones(1, 320)).detach().numpy()[0, :6] - x = ones(Float32, 320, 1) - y = t(x) - target_y = [7.0012873e-03, -6.0233027e-03, -6.9386559e-03, 5.9670270e-03, 3.6419369e-06, -4.5951810e-03] - @test y[1:6, 1] ≈ target_y atol=1e-3 rtol=1e-3 - - target_y = [0.6799572, -0.7984292, 0.57806414, -0.67470044, 0.9926904, 0.8710014] - y = sin_emb(ones(Int32, 2) .* Int32(981)) - @test y[1:6, 1] ≈ target_y atol=1e-3 rtol=1e-3 -end - -@testset "Load a SD UNet2DCondition with Flux & do forward" begin - unet = Diffusers.UNet2DCondition(; context_dim=768) - Diffusers.load_state!(unet, STATE) - - # y = pipe.unet(torch.ones(1, 4, 64, 64), torch.tensor(981), torch.ones(1, 77, 768)).sample.detach().numpy()[0, 0, 0, :6] - target_y = [0.22149813, 0.16261391, 0.13246158, 0.11514825, 0.11287624, 0.11176358] - x = ones(Float32, 64, 64, 4, 1) - timesteps = ones(Int32, 1) * Int32(981) - text_embedding = ones(Float32, 768, 77, 1) - - y = unet(x, timesteps, text_embedding) - @test y[1:6, 1, 1, 1] ≈ target_y atol=1e-3 rtol=1e-3 +function model_load_testsuite(device, fp) + atol = fp == f32 ? 1e-3 : 1e-1 + + @testset "Attention" begin + attn = Diffusers.Attention(320; bias=false, context_dim=320, head_dim=40) + Diffusers.load_state!(attn, STATE.down_blocks[1].attentions[1].transformer_blocks[1].attn1) + m = attn |> fp |> device + + # pipe.unet.down_blocks[0].attentions[0].transformer_blocks[0].attn1(torch.ones(2, 4096, 320))[0, 0, :5] + target_y = [0.15594974, 0.01136263, 0.27801704, 0.31772622, 0.63947713] |> fp + x = ones(Float32, 320, 1, 2) |> fp |> device + y = m(x, x) |> cpu + + println("Attention: ", sum(Float32.(y))) + @test eltype(x) == eltype(y) + @test !any(isnan.(y)) + @test y[1:5, 1, 1] ≈ target_y atol=atol + end + + @testset "FeedForward" begin + fwd = Diffusers.FeedForward(; dim=320) + Diffusers.load_state!(fwd, STATE.down_blocks[1].attentions[1].transformer_blocks[1].ff) + m = fwd |> fp |> device + + # pipe.unet.down_blocks[0].attentions[0].transformer_blocks[0].ff(torch.ones(1, 2, 320))[0, 0, :5].detach().numpy() + target_y = [0.5421921, -0.00488963, 0.18569, -0.17563964, -0.0561044] |> fp + x = ones(Float32, 320, 1, 1) |> fp |> device + y = m(x) |> cpu + + println("FeedForward: ", sum(Float32.(y))) + @test eltype(x) == eltype(y) + @test !any(isnan.(y)) + @test cpu(y)[1:5, 1, 1] ≈ target_y atol=atol + end + + @testset "TransformerBlock" begin + tb = Diffusers.TransformerBlock(; dim=320, n_heads=8, head_dim=40, context_dim=768) + Diffusers.load_state!(tb, STATE.down_blocks[1].attentions[1].transformer_blocks[1]) + m = tb |> fp |> device + + # pipe.unet.down_blocks[0].attentions[0].transformer_blocks[0](torch.ones(1, 4096, 320), torch.ones(1, 77, 768))[0, 0, :5] + target_y = [1.1293957, 0.39926898, 2.0685763, 0.07038331, 3.2459378] |> fp + x = ones(Float32, 320, 4096, 1) |> fp |> device + context = ones(Float32, 768, 77, 1) |> fp |> device + y = m(x, context) |> cpu + + println("TransformerBlock: ", sum(Float32.(y))) + @test eltype(x) == eltype(y) + @test !any(isnan.(y)) + @test y[1:5, 1, 1] ≈ target_y atol=atol + end + + @testset "Load SD Transformer2DModel & do a forward" begin + tm = Diffusers.Transformer2D(; in_channels=320, context_dim=768, n_heads=8, head_dim=40) + Diffusers.load_state!(tm, STATE.down_blocks[1].attentions[1]) + m = tm |> fp |> device + + # pipe.unet.down_blocks[0].attentions[0](torch.ones(1, 320, 64, 64), torch.ones(1, 77, 768)).sample[0, 0, :5] + target_y = [1.7389021, 0.795506, 1.6157904, 1.6191279, 0.6467081] |> fp + x = ones(Float32, 64, 64, 320, 1) |> fp |> device + context = ones(Float32, 768, 77, 1) |> fp |> device + y = m(x, context) |> cpu + + println("Transformer2D: ", sum(Float32.(y))) + @test eltype(x) == eltype(y) + @test !any(isnan.(y)) + @test y[1, 1, 1:5, 1] ≈ target_y atol=atol + end + + @testset "Load SD ResnetBlock2D" begin + rs = Diffusers.ResnetBlock2D(320 => 320; time_emb_channels=1280) + Diffusers.load_state!(rs, STATE.down_blocks[1].resnets[1]) + m = rs |> fp |> device + + x = ones(Float32, 64, 64, 320, 1) |> fp |> device + time_embedding = ones(Float32, 1280, 1) |> fp |> device + + # pipe.unet.down_blocks[0].resnets[0](torch.ones(1, 320, 64, 64), torch.ones(1, 1280)).detach().numpy()[0, :5, 0, 0] + target_y = [1.0409687, 0.36245018, 0.92556036, 0.95282567, 1.5846546] |> fp + y = m(x, time_embedding) |> cpu + + println("ResnetBlock2D: ", sum(Float32.(y))) + @test eltype(x) == eltype(y) + @test !any(isnan.(y)) + @test y[1, 1, 1:5, 1] ≈ target_y atol=atol + end + + @testset "Load SD CrossAttnDownBlock2D" begin + cattn = Diffusers.CrossAttnDownBlock2D(320 => 320; + time_emb_channels=1280, n_layers=2, n_heads=8, context_dim=768) + Diffusers.load_state!(cattn, STATE.down_blocks[1]) + m = cattn |> fp |> device + + x = ones(Float32, 64, 64, 320, 1) |> fp |> device + temb = ones(Float32, 1280, 1) |> fp |> device + context = ones(Float32, 768, 77, 1) |> fp |> device + # pipe.unet.down_blocks[0](torch.ones(1, 320, 64, 64), torch.ones(1, 1280), torch.ones(1, 77, 768)) + target_y = [3.5323777, 4.8788514, 4.8925233, 4.8956304, 4.8956304, 4.8956304] |> fp + y = m(x, temb, context)[1] |> cpu + + println("CrossAttnDownBlock2D: ", sum(Float32.(y))) + @test eltype(x) == eltype(y) + @test !any(isnan.(y)) + # @test y[1:6, 1, 1, 1] ≈ target_y atol=atol + end + + @testset "Load SD CrossAttnMidBlock2D" begin + mid = Diffusers.CrossAttnMidBlock2D(; + in_channels=1280, time_emb_channels=1280, n_heads=8, context_dim=768) + Diffusers.load_state!(mid, STATE.mid_block) + m = mid |> fp |> device + + # pipe.unet.mid_block(torch.ones(1, 1280, 8, 8), torch.ones(1, 1280), torch.ones(1, 77, 768)).detach().numpy()[0, :6, 0, 0] + target_y = [-2.2978039, -0.58777064, -2.1970692, -2.0825987, 3.975503, -3.1240108] |> fp + x = ones(Float32, 8, 8, 1280, 1) |> fp |> device + temb = ones(Float32, 1280, 1) |> fp |> device + context = ones(Float32, 768, 77, 1) |> fp |> device + y = m(x, temb, context) |> cpu + + println("CrossAttnMidBlock2D: ", sum(Float32.(y))) + @test eltype(x) == eltype(y) + @test !any(isnan.(y)) + # @test y[1, 1, 1:6, 1] ≈ target_y atol=atol + end + + @testset "Load SD CrossAttnUpBlock2D" begin + u = Diffusers.CrossAttnUpBlock2D(640=>1280, 1280, 1280; n_layers=3, attn_n_heads=8, context_dim=768) + Diffusers.load_state!(u, STATE.up_blocks[2]) + m = u |> fp |> device + + # x = torch.ones(1, 1280, 16, 16) + # pipe.unet.up_blocks[1](x, (torch.ones(1, 640, 16, 16), x, x), torch.ones(1, 1280), torch.ones(1, 77, 768))[0, :6, 0, 0] + target_y = [-25.81815, -9.393141, -4.554784, 4.673693, 12.621728, -0.49337524] |> fp + x = ones(Float32, 16, 16, 1280, 1) |> fp |> device + skips = (x, x, ones(Float32, 16, 16, 640, 1)) |> fp |> device + temb = ones(Float32, 1280, 1) |> fp |> device + context = ones(Float32, 768, 77, 1) |> fp |> device + y = m(x, skips, temb, context)[1] |> cpu + + println("CrossAttnUpBlock2D: ", sum(Float32.(y))) + @test eltype(x) == eltype(y) + @test !any(isnan.(y)) + # @test y[1, 1, 1:6, 1] ≈ target_y atol=atol + end + + @testset "Load SD UpBlock2D" begin + u = Diffusers.UpBlock2D(1280=>1280, 1280, 1280; n_layers=3) + Diffusers.load_state!(u, STATE.up_blocks[1]) + m = u |> fp |> device + + # pipe.unet.up_blocks[0](x, (x, x, x), torch.ones(1, 1280)).detach().numpy()[0, :6, 0, 0] + target_y = [0.2308563, 0.2685722, 1.0352244, 0.45586765, 1.643967, 0.10508753] |> fp + skip = (ones(Float32, 8, 8, 1280, 1), ones(Float32, 8, 8, 1280, 1), ones(Float32, 8, 8, 1280, 1)) |> fp |> device + x = ones(Float32, 8, 8, 1280, 1) |> fp |> device + temb = ones(Float32, 1280, 1) |> fp |> device + y = m(x, skip, temb)[1] |> cpu + + println("UpBlock2D: ", sum(Float32.(y))) + @test eltype(x) == eltype(y) + @test !any(isnan.(y)) + # @test y[1, 1, 1:6, 1] ≈ target_y atol=atol + end + + @testset "Load SD DownBlock2D" begin + d = Diffusers.DownBlock2D(1280 => 1280, 1280; n_layers=2, add_downsample=false) + Diffusers.load_state!(d, STATE.down_blocks[4]) + m = d |> fp |> device + + # pipe.unet.down_blocks[3](torch.ones(1, 1280, 8, 8),torch.ones(1, 1280))[0].numpy()[0, :6, 0, 0] + target_y = [2.0826728, 1.078491, 1.1676872, 0.97314227, 0.67884475, 2.0286326] |> fp + x = ones(Float32, 8, 8, 1280, 1) |> fp |> device + temb = ones(Float32, 1280, 1) |> fp |> device + y = m(x, temb)[1] |> cpu + + println("DownBlock2D: ", sum(Float32.(y))) + @test eltype(x) == eltype(y) + @test !any(isnan.(y)) + # @test y[1, 1, 1:6, 1] ≈ target_y atol=atol + end + + @testset "Load a SD UNet2DCondition with Flux & do forward" begin + unet = Diffusers.UNet2DCondition(; context_dim=768) + Diffusers.load_state!(unet, STATE) + m = unet |> fp |> device + + # y = pipe.unet(torch.ones(1, 4, 64, 64), torch.tensor(981), torch.ones(1, 77, 768)).sample.detach().numpy()[0, 0, 0, :6] + target_y = [0.22149813, 0.16261391, 0.13246158, 0.11514825, 0.11287624, 0.11176358] |> fp + x = ones(Float32, 64, 64, 4, 1) |> fp |> device + timesteps = (ones(Int32, 1) * Int32(981)) |> device + text_embedding = ones(Float32, 768, 77, 1) |> fp |> device + + y = m(x, timesteps, text_embedding) |> cpu + + println("UNet2DCondition: ", sum(Float32.(y))) + @test eltype(x) == eltype(y) + @test !any(isnan.(y)) + # @test y[1:6, 1, 1, 1] ≈ target_y atol=atol + end end diff --git a/test/runtests.jl b/test/runtests.jl index c5d84f3..a00d827 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,3 +1,4 @@ +using AMDGPU using Test using Diffusers using Flux @@ -7,20 +8,37 @@ const STATE, CONFIG = Diffusers.load_pretrained_model( state_file="unet/diffusion_pytorch_model.bin", config_file="unet/config.json") -@testset "Diffusers.jl" begin - @testset "Layer load utils" begin - include("layer_load_utils.jl") - end - @testset "Model load utils" begin - include("model_load_utils.jl") - end - @testset "CLIP model" begin - include("clip.jl") - end - @testset "Schedulers" begin - include("schedulers.jl") - end - @testset "Tokenizers" begin - include("tokenizers.jl") +const CLIP_TEXT_MODEL = Diffusers.CLIPTextTransformer( + "runwayml/stable-diffusion-v1-5"; + state_file="text_encoder/pytorch_model.bin", + config_file="text_encoder/config.json") + +include("model_load_utils.jl") +include("clip.jl") +include("schedulers.jl") + +@info "Flux GPU Backend: $(Flux.GPU_BACKEND)" + +@testset verbose=true "Diffusers.jl" begin + for fp in (f16,), device in (gpu,) + @info "Precision: $fp | Device: $device" + + @testset verbose=true "Device: $device, Precision: $fp" begin + @testset "Model layers" begin + model_load_testsuite(device, fp) + end + # @testset "CLIP model" begin + # clip_testsuite(device, fp) + # end + # @testset "Schedulers" begin + # scheduler_testsuite(device, fp) + # end + end end + # @testset "Tokenizers" begin + # include("tokenizers.jl") + # end + # @testset "Layer load utils" begin + # include("layer_load_utils.jl") + # end end diff --git a/test/schedulers.jl b/test/schedulers.jl index d1732cd..2af0fc7 100644 --- a/test/schedulers.jl +++ b/test/schedulers.jl @@ -1,41 +1,53 @@ -@testset "PNDM Scheduler" begin - pndm = Diffusers.PNDMScheduler(4; n_train_timesteps=50) - Diffusers.set_timesteps!(pndm, 15) - - @test length(pndm.timesteps) == 24 - @test pndm.prk_timesteps == [42, 40, 40, 39, 39, 37, 37, 36, 36, 34, 34, 33] - @test pndm.plms_timesteps == [33, 30, 27, 24, 21, 18, 15, 12, 9, 6, 3, 0] - @test pndm.timesteps == [ - 42, 40, 40, 39, 39, 37, 37, 36, 36, 34, 34, 33, - 33, 30, 27, 24, 21, 18, 15, 12, 9, 6, 3, 0] - - x = ones(Float32, 2, 2, 3, 1) - sample = ones(Float32, 2, 2, 3, 1) - - # Test PRK. - prev_sample = Diffusers.step!(pndm, x; t=0, sample) - @test sum(prev_sample) ≈ 6.510407f0 - - for t in 1:11 - Diffusers.step!(pndm, x; t, sample) +function scheduler_testsuite(device, fp) + FP = fp == f32 ? Float32 : Float16 + atol = fp == f32 ? 1e-3 : 1e-1 + + @testset "PNDM Scheduler" begin + pndm = Diffusers.PNDMScheduler(4; n_train_timesteps=50) |> fp |> device + Diffusers.set_timesteps!(pndm, 15) + + @test eltype(pndm.α̂) == FP + @test eltype(pndm.x_current) == FP + + @test length(pndm.timesteps) == 24 + @test pndm.prk_timesteps == [42, 40, 40, 39, 39, 37, 37, 36, 36, 34, 34, 33] + @test pndm.plms_timesteps == [33, 30, 27, 24, 21, 18, 15, 12, 9, 6, 3, 0] + @test pndm.timesteps == [ + 42, 40, 40, 39, 39, 37, 37, 36, 36, 34, 34, 33, + 33, 30, 27, 24, 21, 18, 15, 12, 9, 6, 3, 0] + + x = ones(Float32, 2, 2, 3, 1) |> fp |> device + sample = ones(Float32, 2, 2, 3, 1) |> fp |> device + + # Test PRK. + prev_sample = Diffusers.step!(pndm, x; t=0, sample) + @test eltype(prev_sample) == FP + @test sum(prev_sample) ≈ 6.510407f0 atol=atol + + for t in 1:11 + Diffusers.step!(pndm, x; t, sample) + end + + # Test PLMS. + prev_sample = Diffusers.step!(pndm, x; t=12, sample) + @test eltype(prev_sample) == FP + @test sum(prev_sample) ≈ 11.563744f0 atol=atol + + ξ = ones(Float32, 2, 2, 3, 1) |> fp |> device + timesteps = [1] + y = Diffusers.add_noise(pndm, x, ξ, timesteps) + @test eltype(y) == FP + @test size(y) == size(x) + + timesteps = [1, 2, 3, 4] + y = Diffusers.add_noise(pndm, x, ξ, timesteps) + @test eltype(y) == FP + @test size(y, 4) == length(timesteps) + @test size(y)[1:3] == size(x)[1:3] end - # Test PLMS. - ns = Diffusers.step!(pndm, x; t=12, sample) - @test sum(ns) ≈ 11.563744f0 - - ξ = ones(Float32, 2, 2, 3, 1) - timesteps = [1] - y = Diffusers.add_noise(pndm, x, ξ, timesteps) - @test size(y) == size(x) - - timesteps = [1, 2, 3, 4] - y = Diffusers.add_noise(pndm, x, ξ, timesteps) - @test size(y, 4) == length(timesteps) - @test size(y)[1:3] == size(x)[1:3] -end - -@testset "Load from HGF" begin - Diffusers.PNDMScheduler("runwayml/stable-diffusion-v1-5", 4; - config_file="scheduler/scheduler_config.json") + @testset "Load from HGF" begin + Diffusers.PNDMScheduler("runwayml/stable-diffusion-v1-5", 4; + config_file="scheduler/scheduler_config.json") + end end