Implement StableDiffusion #22

Merged · 19 commits · Apr 20, 2023
1 change: 1 addition & 0 deletions .gitignore
@@ -1,2 +1,3 @@
/docs/build/
Manifest.toml
LocalPreferences.toml
20 changes: 13 additions & 7 deletions Project.toml
@@ -4,20 +4,26 @@ authors = ["Harish Anand"]
version = "0.1.0"

[deps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
HuggingFaceApi = "3cc741c3-0c9d-4fbe-84fa-cdec264173de"
ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534"
ImageIO = "82e4d734-157c-48bb-816b-45c225c6df19"
ImageMagick = "6218d12a-5da1-5696-b52f-db25d2ecc6d1"
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Pickle = "fbb45041-c46e-462f-888f-7c521cafbc2c"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
AMDGPU = "0.4.12"
KernelAbstractions = "0.9.2"
NNlib = "0.8.20"
julia = "1.9"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
Binary file added painting-of-a-farmer-in-the-field.png
147 changes: 98 additions & 49 deletions src/Diffusers.jl
@@ -1,23 +1,87 @@
module Diffusers
export HGF

import JSON3
import MLUtils
import Pickle

using Adapt
using FileIO
using Flux
using HuggingFaceApi
using ImageCore
using ImageIO
using OrderedCollections
using ProgressMeter
using Statistics

using AMDGPU
using KernelAbstractions
const Backend = ROCBackend()

const HGF = Val{:HGF}()
const Maybe{T} = Union{Nothing, T}

function sync_free!(args...)
    KernelAbstractions.synchronize(Backend)
    KernelAbstractions.unsafe_free!.(args)
end
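Why synchronize before freeing (my reading, not stated in the diff): the arrays handed to `sync_free!` may still be referenced by queued GPU kernels, so the backend is drained first and only then are the buffers released. A hedged usage sketch with made-up arrays (requires an AMD GPU, matching the `ROCBackend` above):

using AMDGPU, KernelAbstractions, Statistics

x = ROCArray(rand(Float32, 320, 16))
μ = mean(x; dims=1)                            # kernels may still be in flight here
y = x .- μ
KernelAbstractions.synchronize(ROCBackend())   # wait for queued kernels to finish
KernelAbstractions.unsafe_free!(μ)             # now μ's buffer can be released safely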

# TODO better way of handling this
const FluxDeviceAdaptors = (
    Flux.FluxCPUAdaptor,
    Flux.FluxCUDAAdaptor,
    Flux.FluxAMDAdaptor)
const FluxEltypeAdaptors = (
    Flux.FluxEltypeAdaptor{Float32},
    Flux.FluxEltypeAdaptor{Float16})

get_pb(n, desc::String) = Progress(
    n; desc, dt=1, barglyphs=BarGlyphs("[=> ]"), barlen=50, color=:white)

# TODO
# This matches what PyTorch does; upstream to Flux.
# Flux's LayerNorm currently computes: (x - μ) / (σ + ϵ)
# but it should compute: (x - μ) / sqrt(σ² + ϵ)
function (ln::LayerNorm)(x::AbstractArray)
    ϵ = convert(float(eltype(x)), ln.ϵ)
    μ, σ² = _normalize(x; dims=1:length(ln.size))
    y = ln.diag((x .- μ) .* inv.(sqrt.(σ² .+ ϵ)))
    sync_free!(μ, σ²)
    return y
end
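For parity-style debugging, a minimal standalone sketch (not part of the diff; all names hypothetical) of how far the two conventions drift apart:

using Statistics

x = randn(Float32, 8)
ϵ = 1f-5
μ = mean(x)
σ² = var(x; corrected=false)
y_flux  = (x .- μ) ./ (sqrt(σ²) + ϵ)   # divide by (σ + ϵ)
y_torch = (x .- μ) ./ sqrt(σ² + ϵ)     # divide by sqrt(σ² + ϵ)
maximum(abs.(y_flux .- y_torch))       # small but nonzero; enough to break sum-parity checks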

function (gn::Flux.GroupNorm)(x::AbstractArray)
    sz = size(x)
    x2 = reshape(x, sz[1:end - 2]..., sz[end - 1] ÷ gn.G, gn.G, sz[end])
    N = ndims(x2) # == ndims(x) + 1
    reduce_dims = 1:(N - 2)
    affine_shape = ntuple(i -> i ∈ (N - 1, N - 2) ? size(x2, i) : 1, N)

    μ, σ² = _normalize(x2; dims=reduce_dims)
    γ = reshape(gn.γ, affine_shape)
    β = reshape(gn.β, affine_shape)

    ϵ = convert(float(eltype(x)), gn.ϵ)
    scale = γ .* inv.(sqrt.(σ² .+ ϵ))
    bias = -scale .* μ .+ β

    sync_free!(μ, σ²)
    return reshape(gn.λ.(scale .* x2 .+ bias), sz)
end

function _normalize(x::AbstractArray{Float16}; dims)
    x_fp32 = Float32.(x)
    μ, σ² = _normalize(x_fp32; dims)
    m, v = Float16.(μ), Float16.(σ²)
    sync_free!(x_fp32, μ, σ²)
    return m, v
end

function _normalize(x; dims)
    μ = mean(x; dims)
    σ² = var(x; dims, mean=μ, corrected=false)
    μ, σ²
end
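One likely motivation for the Float32 detour in the Float16 method (my reading, not stated in the diff): Float16 reductions can overflow or lose precision long before the individual values do. A standalone sketch:

using Statistics

x = fill(Float16(600), 4096)   # a plausible activation magnitude
sum(x)                         # Inf: the Float16 accumulator overflows (max ≈ 65504) mid-reduction
mean(Float32.(x))              # 600.0f0: the same statistic survives in Float32
Float16(mean(Float32.(x)))     # Float16(600.0): cast back once at the end, as `_normalize` does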

include("timestep.jl")
include("feed_forward.jl")
@@ -36,60 +100,45 @@ include("clip/tokenizer.jl")
include("clip/models.jl")

include("schedulers/pndm.jl")
include("stable_diffusion.jl")

include("load_utils.jl")

# >>> vae.encoder(x).sum()
# tensor(28057.0957, grad_fn=<SumBackward0>)

function main()
    sd = StableDiffusion("runwayml/stable-diffusion-v1-5") |> f16 |> gpu
    println("Running StableDiffusion on $(get_backend(sd))")

    n_images_per_prompt = 1
    prompts = ["painting of a farmer in the field"]
    images = sd(prompts; n_images_per_prompt, n_inference_steps=20)

    idx = 1
    for prompt in prompts, i in 1:n_images_per_prompt
        joined_prompt = replace(prompt, ' ' => '-')
        save("$joined_prompt-$i.png", rotr90(RGB{N0f8}.(images[:, :, idx])))
        idx += 1
    end
    return
end

function debug()
    m = LayerNorm(320)
    lf = m |> f32 |> gpu
    lh = m |> f16 |> gpu

    x = rand(Float32, 320, 4096, 1)
    xf = x |> f32 |> gpu
    xh = x |> f16 |> gpu

    y = m(x)
    yf = lf(xf) |> cpu
    yh = lh(xh) |> cpu |> f32

    println()
    @show sum(y)
    @show sum(yf)
    @show sum(yh)
    return
end

function tk()
    input_texts = [
        "Hello, world!",
        "There is nothing basically... I mean it quite literally",
        "I was now on a dark path, unsettled by a future filled with big data and small comprehension.",
    ]
    println("Input texts:")
    display(input_texts); println()

    tokenizer = CLIPTokenizer()
    tokens, pad_mask = tokenize(tokenizer, input_texts; context_length=32)
    println("Tokens:")
    display(tokens); println()
    display(pad_mask); println()
    @show size(pad_mask)

    texts = [
        decode(tokenizer, @view(tokens[:, i]))
        for i in 1:size(tokens, 2)]
    println("Decoded texts:")
    display(texts); println()

    nothing
end

"""
- CLIPFeatureExtractor: https://github.com/huggingface/transformers/blob/fb366b9a2a94b38171896f6ba9fb9ae8bffd77af/src/transformers/models/clip/feature_extraction_clip.py#L26
"""

end
22 changes: 16 additions & 6 deletions src/attention.jl
@@ -28,6 +28,7 @@ function Attention(dim::Int;
    cross_attention_norm::Bool = false,
    scale::Float32 = 1f0,
    dropout::Real = 0,
    ϵ::Float32 = 1f-5,
)
    cross_attention_norm && isnothing(context_dim) && throw(ArgumentError("""
        `context_dim` is `nothing`, but `cross_attention_norm` is `true`.
@@ -54,7 +55,7 @@ function Attention(dim::Int;
    norm = if is_cross_attention
        cross_attention_norm ? LayerNorm(context_dim) : identity
    else
        isnothing(n_groups) ? identity : GroupNorm(dim, n_groups; ϵ)
    end

    Attention(
@@ -65,8 +66,8 @@ end
function (attn::Attention)(
    x::T, context::Maybe{C} = nothing; mask::Maybe{M} = nothing,
) where {
    T <: AbstractArray{<:Real, 3},
    C <: AbstractArray{<:Real, 3},
    M <: AbstractMatrix{Bool},
}
    residual = x
@@ -77,15 +78,24 @@ function (attn::Attention)(
        attn.to_q(x), attn.to_k(c), attn.to_v(c)
    else
        x = attn.norm(x)
        # TODO add doc that in this case input should be in (w * h, c, b)
        x = permutedims(x, (2, 1, 3))
        attn.to_q(x), attn.to_k(x), attn.to_v(x)
    end

    mask = isnothing(mask) ? nothing :
        reshape(mask, size(mask, 1), 1, 1, size(mask, 2))
    ω, α = dot_product_attention(q, k, v; mask, nheads=attn.n_heads)

    sync_free!(α, q, k, v)
    isnothing(mask) || KernelAbstractions.unsafe_free!(mask)

    cross_attention(attn) && (ω = reshape(ω, :, seq_length, batch);)
    o = attn.to_out(ω)

    sync_free!(ω)
    cross_attention(attn) && return o

    FP = eltype(x)
    (permutedims(o, (2, 1, 3)) .+ residual) .* FP(inv(attn.scale))
end
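The layout TODO above, in code form: a shapes-only sketch (sizes made up, not from the diff). Self-attention receives spatial features flattened to (w * h, channels, batch) and permutes them into the (features, sequence, batch) layout that NNlib's dot_product_attention consumes:

x = rand(Float32, 64 * 64, 320, 2)   # (w * h, c, b) on entry
x = permutedims(x, (2, 1, 3))        # (c, w * h, b) for dot_product_attention
@assert size(x) == (320, 4096, 2)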
11 changes: 6 additions & 5 deletions src/autoencoder/blocks.jl
@@ -27,7 +27,7 @@ function Encoder(
        push!(down_blocks, SamplerBlock2D{Downsample2D}(
            input_channels => output_channels;
            n_groups, n_layers=n_block_layers,
            sampler_pad=0, add_sampler=!is_final, λ))
    end

    mid_block = MidBlock2D(
@@ -43,11 +43,12 @@ function Encoder(
    Encoder(conv_in, conv_out, norm, Chain(down_blocks...), mid_block)
end

function (enc::Encoder)(x::T) where T <: AbstractArray{<:Real, 4}
    x = enc.conv_in(x)
    x = enc.down_blocks(x)
    x = enc.mid_block(x)
    x = enc.norm(x)
    enc.conv_out(x)
end

struct Decoder{C1, C2, N, U, M}
@@ -94,7 +95,7 @@ function Decoder(
    Decoder(conv_in, conv_out, norm, Chain(up_blocks...), mid_block)
end

function (dec::Decoder)(x::T) where T <: AbstractArray{<:Real, 4}
    x = dec.conv_in(x)
    x = dec.mid_block(x)
    x = dec.up_blocks(x)
@@ -124,7 +125,7 @@ end
# Kullback–Leibler divergence.
function kl(
    dg::DiagonalGaussian{T}, other::Maybe{DiagonalGaussian{T}} = nothing,
) where T <: AbstractArray{<:Real, 4}
    dims = (1, 2, 3)
    0.5f0 .* (isnothing(other) ?
        sum(dg.μ.^2 .+ dg.ν .- dg.log_σ .- 1f0; dims) :
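For reference (not part of the diff), and assuming `dg.ν` stores the variance σ² and `dg.log_σ` the log-variance, the `isnothing(other)` branch above is the closed-form KL divergence against a standard normal, summed over the spatial and channel dims:

$$D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \sigma^2)\,\big\|\,\mathcal{N}(0, I)\big) = \tfrac{1}{2}\sum\left(\mu^2 + \sigma^2 - \log\sigma^2 - 1\right)$$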
6 changes: 3 additions & 3 deletions src/autoencoder/kl.jl
@@ -30,20 +30,20 @@ function AutoencoderKL(
    AutoencoderKL(encoder, decoder, quant_conv, post_quant_conv, scaling_factor)
end

function encode(kl::AutoencoderKL, x::T) where T <: AbstractArray{<:Real, 4}
    h = kl.encoder(x)
    moments = kl.quant_conv(h)
    DiagonalGaussian(moments)
end

function decode(kl::AutoencoderKL, z::T) where T <: AbstractArray{<:Real, 4}
    h = kl.post_quant_conv(z)
    kl.decoder(h)
end

function (kl::AutoencoderKL)(
    x::T; sample_posterior::Bool = false,
) where T <: AbstractArray{<:Real, 4}
    posterior = encode(kl, x)
    if sample_posterior
        z = sample(posterior)
26 changes: 9 additions & 17 deletions src/clip/basic.jl
@@ -1,32 +1,24 @@
struct CLIPTextEmbeddings{T, P, I <: AbstractMatrix{Int32}}
    token_embedding::T
    position_embedding::P
    position_ids::I
end
Flux.@functor CLIPTextEmbeddings

# TODO better way to handle fixed int type during f16 conversion
function CLIPTextEmbeddings(token_embedding, position_embedding, position_ids)
    pi = eltype(position_ids) == Int32 ? position_ids : Int32.(position_ids)
    CLIPTextEmbeddings(token_embedding, position_embedding, pi)
end

Flux.trainable(emb::CLIPTextEmbeddings) = (emb.token_embedding, emb.position_embedding)

function CLIPTextEmbeddings(;
    vocab_size::Int, embed_dim::Int, max_position_embeddings::Int,
)
    CLIPTextEmbeddings(
        Embedding(vocab_size => embed_dim),
        Embedding(max_position_embeddings => embed_dim),
        reshape(collect(UnitRange{Int32}(1, max_position_embeddings)), :, 1))
end
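A self-contained toy (hypothetical names, not from the diff) showing why the unconstrained three-argument method above cannot recurse forever: the struct's generated constructor is strictly more specific for `Int32` matrices, so the fallback only fires when the ids arrive with some other eltype (`Int64` from `collect`, or a float eltype after a conversion sweep) and re-pins them:

struct ToyEmbeds{W, I <: AbstractMatrix{Int32}}
    weight::W
    ids::I
end

# Fallback: converts foreign-eltype ids to Int32, then dispatch picks the
# generated (more specific) constructor, so there is no infinite recursion.
function ToyEmbeds(weight, ids)
    pi = eltype(ids) == Int32 ? ids : Int32.(ids)
    ToyEmbeds(weight, pi)
end

te = ToyEmbeds(randn(Float32, 8, 4), [1 2; 3 4])   # Int64 ids in
@assert eltype(te.ids) == Int32                    # Int32 ids stored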
