remove a couple of "from torch import ..." statements from the code
Laurent2916 committed Aug 21, 2024
1 parent 45143e2 commit 2cb0f06
Showing 13 changed files with 76 additions and 80 deletions.
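
Every file below follows the same pattern: names previously pulled in with "from torch import ..." are now called through the torch namespace. A minimal before/after sketch of the idea (illustrative only, not taken from the diff):

    import torch
    from torch import Tensor

    def concat_pair(a: Tensor, b: Tensor) -> Tensor:
        # before: from torch import cat; return cat([a, b], dim=0)
        # after: call cat through the torch namespace
        return torch.cat([a, b], dim=0)

This keeps import lists short and avoids shadowing generic names such as cat, norm, and tensor at module scope.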
4 changes: 2 additions & 2 deletions src/refiners/fluxion/layers/chain.py
@@ -6,7 +6,7 @@
from typing import Any, Callable, Iterable, Iterator, Sequence, TypeVar, cast, get_origin, overload

import torch
-from torch import Tensor, cat, device as Device, dtype as DType
+from torch import Tensor, device as Device, dtype as DType

from refiners.fluxion.context import ContextProvider, Contexts
from refiners.fluxion.layers.module import ContextModule, Module, ModuleTree, WeightedModule
@@ -950,7 +950,7 @@ def __init__(self, *modules: Module, dim: int = 0) -> None:

def forward(self, *args: Any) -> Tensor:
outputs = [module(*args) for module in self]
-return cat(
+return torch.cat(
[output for output in outputs if output is not None],
dim=self.dim,
)
9 changes: 5 additions & 4 deletions src/refiners/fluxion/layers/norm.py
@@ -1,5 +1,6 @@
+import torch
from jaxtyping import Float
-from torch import Tensor, device as Device, dtype as DType, ones, sqrt, zeros
+from torch import Tensor, device as Device, dtype as DType
from torch.nn import (
GroupNorm as _GroupNorm,
InstanceNorm2d as _InstanceNorm2d,
@@ -111,8 +112,8 @@ def __init__(
dtype: DType | None = None,
) -> None:
super().__init__()
-self.weight = TorchParameter(ones(channels, device=device, dtype=dtype))
-self.bias = TorchParameter(zeros(channels, device=device, dtype=dtype))
+self.weight = TorchParameter(torch.ones(channels, device=device, dtype=dtype))
+self.bias = TorchParameter(torch.zeros(channels, device=device, dtype=dtype))
self.eps = eps

def forward(
@@ -121,7 +122,7 @@ def forward(
) -> Float[Tensor, "batch channels height width"]:
x_mean = x.mean(1, keepdim=True)
x_var = (x - x_mean).pow(2).mean(1, keepdim=True)
-x_norm = (x - x_mean) / sqrt(x_var + self.eps)
+x_norm = (x - x_mean) / torch.sqrt(x_var + self.eps)
x_out = self.weight.unsqueeze(-1).unsqueeze(-1) * x_norm + self.bias.unsqueeze(-1).unsqueeze(-1)
return x_out

18 changes: 5 additions & 13 deletions src/refiners/fluxion/utils.py
@@ -8,30 +8,22 @@
from PIL import Image
from safetensors import safe_open as _safe_open # type: ignore
from safetensors.torch import save_file as _save_file # type: ignore
-from torch import (
-    Tensor,
-    cat,
-    device as Device,
-    dtype as DType,
-    manual_seed as _manual_seed, # type: ignore
-    no_grad as _no_grad, # type: ignore
-    norm as _norm, # type: ignore
-)
+from torch import Tensor, device as Device, dtype as DType
from torch.nn.functional import conv2d, interpolate as _interpolate, pad as _pad # type: ignore

T = TypeVar("T")
E = TypeVar("E")


def norm(x: Tensor) -> Tensor:
-return _norm(x) # type: ignore
+return torch.norm(x) # type: ignore


def manual_seed(seed: int) -> None:
-_manual_seed(seed)
+torch.manual_seed(seed) # type: ignore


-class no_grad(_no_grad):
+class no_grad(torch.no_grad):
def __new__(cls, orig_func: Any | None = None) -> "no_grad": # type: ignore
return object.__new__(cls)

@@ -123,7 +115,7 @@ def default_sigma(kernel_size: int) -> float:
def images_to_tensor(
images: list[Image.Image], device: Device | str | None = None, dtype: DType | None = None
) -> Tensor:
-return cat([image_to_tensor(image, device=device, dtype=dtype) for image in images])
+return torch.cat([image_to_tensor(image, device=device, dtype=dtype) for image in images])


def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtype: DType | None = None) -> Tensor:
5 changes: 3 additions & 2 deletions src/refiners/foundationals/clip/common.py
@@ -1,4 +1,5 @@
-from torch import Tensor, arange, device as Device, dtype as DType
+import torch
+from torch import Tensor, device as Device, dtype as DType

import refiners.fluxion.layers as fl

@@ -25,7 +26,7 @@ def __init__(

@property
def position_ids(self) -> Tensor:
-return arange(end=self.max_sequence_length, device=self.device).reshape(1, -1)
+return torch.arange(end=self.max_sequence_length, device=self.device).reshape(1, -1)

def get_position_ids(self, x: Tensor) -> Tensor:
return self.position_ids[:, : x.shape[1]]
16 changes: 12 additions & 4 deletions src/refiners/foundationals/clip/concepts.py
@@ -1,8 +1,9 @@
import re
from typing import cast

+import torch
import torch.nn.functional as F
-from torch import Tensor, cat, zeros
+from torch import Tensor
from torch.nn import Parameter

import refiners.fluxion.layers as fl
@@ -22,19 +23,26 @@ def __init__(
with self.setup_adapter(target):
super().__init__(fl.Lambda(func=self.lookup))
p = Parameter(
-    zeros([0, target.embedding_dim], device=target.device, dtype=target.dtype)
+    torch.zeros([0, target.embedding_dim], device=target.device, dtype=target.dtype)
) # requires_grad=True by default
self.old_weight = cast(Parameter, target.weight)
self.new_weight = p

# Use F.embedding instead of nn.Embedding to make sure that gradients can only be computed for the new embeddings
def lookup(self, x: Tensor) -> Tensor:
# Concatenate old and new weights for dynamic embedding updates during training
-return F.embedding(x, cat([self.old_weight, self.new_weight]))
+return F.embedding(x, torch.cat([self.old_weight, self.new_weight]))

def add_embedding(self, embedding: Tensor) -> None:
assert embedding.shape == (self.old_weight.shape[1],)
-p = Parameter(cat([self.new_weight, embedding.unsqueeze(0).to(self.new_weight.device, self.new_weight.dtype)]))
+p = Parameter(
+    torch.cat(
+        [
+            self.new_weight,
+            embedding.unsqueeze(0).to(self.new_weight.device, self.new_weight.dtype),
+        ]
+    )
+)
self.new_weight = p

@property
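
The comment about F.embedding above is the heart of this file: concatenating a frozen weight matrix with a small trainable one and indexing the result lets gradients reach only the new rows. A standalone sketch of that behavior, with hypothetical shapes (not the library's code):

    import torch
    import torch.nn.functional as F
    from torch.nn import Parameter

    old_weight = Parameter(torch.randn(10, 4), requires_grad=False)  # frozen vocabulary
    new_weight = Parameter(torch.zeros(2, 4))  # trainable concept embeddings

    token_ids = torch.tensor([3, 10, 11])  # ids 10 and 11 index the new rows
    out = F.embedding(token_ids, torch.cat([old_weight, new_weight]))
    out.sum().backward()
    assert old_weight.grad is None and new_weight.grad is not None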
21 changes: 11 additions & 10 deletions src/refiners/foundationals/latent_diffusion/image_prompt.py
@@ -1,9 +1,10 @@
import math
from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload

+import torch
from jaxtyping import Float
from PIL import Image
-from torch import Tensor, cat, device as Device, dtype as DType, nn, softmax, tensor, zeros_like
+from torch import Tensor, device as Device, dtype as DType, nn

import refiners.fluxion.layers as fl
from refiners.fluxion.adapters.adapter import Adapter
@@ -98,7 +99,7 @@ def forward(
v = self.reshape_tensor(value)

attention = (q * self.scale) @ (k * self.scale).transpose(-2, -1)
-attention = softmax(input=attention.float(), dim=-1).type(attention.dtype)
+attention = torch.softmax(input=attention.float(), dim=-1).type(attention.dtype)
attention = attention @ v

return attention.permute(0, 2, 1, 3).reshape(bs, length, -1)
@@ -159,7 +160,7 @@ def __init__(
)

def to_kv(self, x: Tensor, latents: Tensor) -> Tensor:
-return cat((x, latents), dim=-2)
+return torch.cat((x, latents), dim=-2)


class LatentsToken(fl.Chain):
@@ -484,7 +485,7 @@ def compute_clip_image_embedding(
image_prompt = self.preprocess_image(image_prompt)
elif isinstance(image_prompt, list):
assert all(isinstance(image, Image.Image) for image in image_prompt)
-image_prompt = cat([self.preprocess_image(image) for image in image_prompt])
+image_prompt = torch.cat([self.preprocess_image(image) for image in image_prompt])

negative_embedding, conditional_embedding = self._compute_clip_image_embedding(image_prompt)

@@ -493,28 +494,28 @@ def compute_clip_image_embedding(
assert len(weights) == batch_size, f"Got {len(weights)} weights for {batch_size} images"
if any(weight != 1.0 for weight in weights):
conditional_embedding *= (
-    tensor(weights, device=conditional_embedding.device, dtype=conditional_embedding.dtype)
+    torch.tensor(weights, device=conditional_embedding.device, dtype=conditional_embedding.dtype)
.unsqueeze(-1)
.unsqueeze(-1)
)

if batch_size > 1 and concat_batches:
# Create a longer image tokens sequence when a batch of images is given
# See https://github.com/tencent-ailab/IP-Adapter/issues/99
-negative_embedding = cat(negative_embedding.chunk(batch_size), dim=1)
-conditional_embedding = cat(conditional_embedding.chunk(batch_size), dim=1)
+negative_embedding = torch.cat(negative_embedding.chunk(batch_size), dim=1)
+conditional_embedding = torch.cat(conditional_embedding.chunk(batch_size), dim=1)

-return cat((negative_embedding, conditional_embedding))
+return torch.cat((negative_embedding, conditional_embedding))

def _compute_clip_image_embedding(self, image_prompt: Tensor) -> tuple[Tensor, Tensor]:
image_encoder = self.clip_image_encoder if not self.fine_grained else self.grid_image_encoder
clip_embedding = image_encoder(image_prompt)
conditional_embedding = self.image_proj(clip_embedding)
if not self.fine_grained:
-negative_embedding = self.image_proj(zeros_like(clip_embedding))
+negative_embedding = self.image_proj(torch.zeros_like(clip_embedding))
else:
# See https://github.com/tencent-ailab/IP-Adapter/blob/d580c50/tutorial_train_plus.py#L351-L352
-clip_embedding = image_encoder(zeros_like(image_prompt))
+clip_embedding = image_encoder(torch.zeros_like(image_prompt))
negative_embedding = self.image_proj(clip_embedding)
return negative_embedding, conditional_embedding

9 changes: 5 additions & 4 deletions src/refiners/foundationals/latent_diffusion/range_adapter.py
@@ -1,7 +1,8 @@
import math

+import torch
from jaxtyping import Float, Int
-from torch import Tensor, arange, cat, cos, device as Device, dtype as DType, exp, float32, sin
+from torch import Tensor, device as Device, dtype as DType

import refiners.fluxion.layers as fl
from refiners.fluxion.adapters.adapter import Adapter
@@ -14,10 +15,10 @@ def compute_sinusoidal_embedding(
half_dim = embedding_dim // 2
# Note: it is important that this computation is done in float32.
# The result can be cast to lower precision later if necessary.
-exponent = -math.log(10000) * arange(start=0, end=half_dim, dtype=float32, device=x.device)
+exponent = -math.log(10000) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=x.device)
exponent /= half_dim
-embedding = x.unsqueeze(1).float() * exp(exponent).unsqueeze(0)
-embedding = cat([cos(embedding), sin(embedding)], dim=-1)
+embedding = x.unsqueeze(1).float() * torch.exp(exponent).unsqueeze(0)
+embedding = torch.cat([torch.cos(embedding), torch.sin(embedding)], dim=-1)
return embedding


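
For reference, compute_sinusoidal_embedding above evaluates the standard sinusoidal embedding. With h = embedding_dim / 2, the code computes

    \omega_k = 10000^{-k/h}, \qquad k = 0, \dots, h - 1
    \mathrm{emb}(x) = \left[\cos(x\,\omega_0), \dots, \cos(x\,\omega_{h-1}),\; \sin(x\,\omega_0), \dots, \sin(x\,\omega_{h-1})\right]

The in-code note about float32 is there because these frequencies span several orders of magnitude, which lower-precision dtypes cannot represent accurately.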
11 changes: 6 additions & 5 deletions src/refiners/foundationals/latent_diffusion/solvers/ddim.py
@@ -1,6 +1,7 @@
import dataclasses

-from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, sqrt, tensor
+import torch
+from torch import Generator, Tensor, device as Device, dtype as Dtype

from refiners.foundationals.latent_diffusion.solvers.solver import (
BaseSolverParams,
@@ -28,7 +29,7 @@ def __init__(
first_inference_step: int = 0,
params: BaseSolverParams | None = None,
device: Device | str = "cpu",
-dtype: Dtype = float32,
+dtype: Dtype = torch.float32,
) -> None:
"""Initializes a new DDIM solver.
@@ -71,7 +72,7 @@ def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Gen
(
self.timesteps[step + 1]
if step < self.num_inference_steps - 1
-    else tensor(data=[0], device=self.device, dtype=self.dtype)
+    else torch.tensor(data=[0], device=self.device, dtype=self.dtype)
),
)
current_scale_factor, previous_scale_factor = (
@@ -82,8 +83,8 @@ def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Gen
else self.cumulative_scale_factors[0]
),
)
-predicted_x = (x - sqrt(1 - current_scale_factor**2) * predicted_noise) / current_scale_factor
-noise_factor = sqrt(1 - previous_scale_factor**2)
+predicted_x = (x - torch.sqrt(1 - current_scale_factor**2) * predicted_noise) / current_scale_factor
+noise_factor = torch.sqrt(1 - previous_scale_factor**2)

# Do not add noise at the last step to avoid visual artifacts.
if step == self.num_inference_steps - 1:
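
For context, the two changed lines above implement the usual DDIM estimate of the denoised sample. Writing the cumulative scale factors at the current and previous timesteps as \alpha_t and \alpha_{t'} (a reading of the truncated code, not an authoritative derivation):

    \hat{x}_0 = \frac{x_t - \sqrt{1 - \alpha_t^2}\;\epsilon_\theta}{\alpha_t}, \qquad \mathrm{noise\_factor} = \sqrt{1 - \alpha_{t'}^2}

where noise_factor scales the noise term reinjected for the previous timestep and is skipped at the final step to avoid visual artifacts, per the comment in the diff.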
12 changes: 6 additions & 6 deletions src/refiners/foundationals/latent_diffusion/solvers/dpm.py
@@ -3,7 +3,7 @@

import numpy as np
import torch
-from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor
+from torch import Generator, Tensor, device as Device, dtype as Dtype

from refiners.foundationals.latent_diffusion.solvers.solver import (
BaseSolverParams,
@@ -38,7 +38,7 @@ def __init__(
params: BaseSolverParams | None = None,
last_step_first_order: bool = False,
device: Device | str = "cpu",
-dtype: Dtype = float32,
+dtype: Dtype = torch.float32,
):
"""Initializes a new DPM solver.
@@ -62,7 +62,7 @@ def __init__(
device=device,
dtype=dtype,
)
-self.estimated_data = deque([tensor([])] * 2, maxlen=2)
+self.estimated_data = deque([torch.tensor([])] * 2, maxlen=2)
self.last_step_first_order = last_step_first_order

def rebuild(
@@ -94,7 +94,7 @@ def _generate_timesteps(self) -> Tensor:
offset = self.params.timesteps_offset
max_timestep = self.params.num_train_timesteps - 1 + offset
np_space = np.linspace(offset, max_timestep, self.num_inference_steps + 1).round().astype(int)[1:]
-return tensor(np_space).flip(0)
+return torch.tensor(np_space).flip(0)

def dpm_solver_first_order_update(
self, x: Tensor, noise: Tensor, step: int, sde_noise: Tensor | None = None
@@ -110,7 +110,7 @@ def dpm_solver_first_order_update(
The denoised version of the input data `x`.
"""
current_timestep = self.timesteps[step]
-previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else tensor([0])
+previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else torch.tensor([0])

previous_ratio = self.signal_to_noise_ratios[previous_timestep]
current_ratio = self.signal_to_noise_ratios[current_timestep]
@@ -144,7 +144,7 @@ def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int, sde_noi
Returns:
The denoised version of the input data `x`.
"""
-previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else tensor([0])
+previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else torch.tensor([0])
current_timestep = self.timesteps[step]
next_timestep = self.timesteps[step - 1]

6 changes: 3 additions & 3 deletions src/refiners/foundationals/latent_diffusion/solvers/euler.py
@@ -1,6 +1,6 @@
import numpy as np
import torch
-from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor
+from torch import Generator, Tensor, device as Device, dtype as Dtype

from refiners.foundationals.latent_diffusion.solvers.solver import (
BaseSolverParams,
@@ -23,7 +23,7 @@ def __init__(
first_inference_step: int = 0,
params: BaseSolverParams | None = None,
device: Device | str = "cpu",
-dtype: Dtype = float32,
+dtype: Dtype = torch.float32,
):
"""Initializes a new Euler solver.
@@ -57,7 +57,7 @@ def _generate_sigmas(self) -> Tensor:
"""Generate the sigmas used by the solver."""
sigmas = self.noise_std / self.cumulative_scale_factors
sigmas = torch.tensor(np.interp(self.timesteps.cpu(), np.arange(0, len(sigmas)), sigmas.cpu()))
-sigmas = torch.cat([sigmas, tensor([0.0])])
+sigmas = torch.cat([sigmas, torch.tensor([0.0])])
return sigmas.to(device=self.device, dtype=self.dtype)

def scale_model_input(self, x: Tensor, step: int) -> Tensor:
@@ -1,7 +1,8 @@
import dataclasses
from typing import Any, Callable, Protocol, TypeVar

-from torch import Generator, Tensor, device as Device, dtype as DType, float32
+import torch
+from torch import Generator, Tensor, device as Device, dtype as DType

from refiners.foundationals.latent_diffusion.solvers.solver import Solver, TimestepSpacing

@@ -60,7 +61,7 @@ def __init__(
num_inference_steps: int,
first_inference_step: int = 0,
device: Device | str = "cpu",
-dtype: DType = float32,
+dtype: DType = torch.float32,
**kwargs: Any, # for typing, ignored
) -> None:
self.get_diffusers_scheduler = get_diffusers_scheduler