
Adds sdxl's VAE decoder implementation #653

Open. IanNod wants to merge 4 commits into main.
Conversation

IanNod (Contributor) commented Dec 6, 2024

No description provided.

```python
    subfolder="vae",
)

def decode(self, inp):
```
Contributor:

You have a scaling factor parameter. You should use that instead of the constant if possible. Otherwise, maybe document why it is this value?

IanNod (Contributor, Author):

This is the diffusers reference code, so it does not use our HParams at all. They also do not do the scaling inside the VAE decode but in their pipeline, as input, here: https://github.com/huggingface/diffusers/blob/18f9b990884883533491fc87f303e7305dc27d75/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L591. I added it to our implementation because forcing shortfin to handle it in the pipeline would introduce extra host/device round trips, which would impact performance.

I will add comments to clarify here.

monorimet (Contributor) commented Dec 6, 2024:

You can probably replace `img = 1 / 0.13025 * inp` with `img = img / self.vae.config.scaling_factor` (source).

This is done differently in flux/sd3, where we have a shift factor:

`img = img / self.ae.config.scaling_factor + self.ae.config.shift_factor`
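For illustration, a minimal sketch of the suggested change, assuming a diffusers-style AutoencoderKL whose config carries `scaling_factor`; the wrapper class here is hypothetical:

```python
import torch

class VaeDecodeWrapper(torch.nn.Module):
    # Hypothetical wrapper; `vae` is assumed to be a diffusers AutoencoderKL.
    def __init__(self, vae):
        super().__init__()
        self.vae = vae

    def decode(self, inp: torch.Tensor) -> torch.Tensor:
        # Undo the latent scaling inside decode rather than in the pipeline,
        # avoiding extra host/device round trips as discussed above. For
        # SDXL, vae.config.scaling_factor == 0.13025, so this matches the
        # hard-coded `1 / 0.13025 * inp` while tracking the config.
        img = inp / self.vae.config.scaling_factor
        return self.vae.decode(img, return_dict=False)[0]
```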



```python
def get_random_inputs(dtype, device, bs: int = 2):
    torch.random.manual_seed(42)
```
Contributor:

If someone calls this function multiple times with the same batch size, I think they get the exact same tensor, which does not sound like desirable behavior. I would leave seed setting to callers, like main/test scripts.

IanNod (Contributor, Author):

Good catch. Left over from some debugging a while ago. Will remove.
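A sketch of the agreed-upon fix, with seeding left to the caller (the tensor shape below is an assumption for illustration):

```python
import torch

def get_random_inputs(dtype, device, bs: int = 2):
    # No manual_seed here; reproducibility is the caller's responsibility.
    # Shape is illustrative only: a 4-channel 128x128 latent per batch entry.
    return torch.rand(bs, 4, 128, 128, dtype=dtype, device=device)

# In a main/test script that needs determinism:
torch.random.manual_seed(42)
inputs = get_random_inputs(dtype=torch.float32, device="cpu", bs=2)
```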

```diff
@@ -60,6 +60,7 @@
     "unflatten",
     "unshard",
     "unsqueeze",
+    "squeeze",
```
Contributor:

Let's keep this alphabetized.


```python
# there is always at least one resnet
if resnet_time_scale_shift == "spatial":
    # TODO
```
Contributor:

Could you fill out the TODO?
Also, it seems like you might want your else branch to be an elif where you verify you're getting what you expect.

IanNod (Contributor, Author):

Each layer always has a resnet, specifically either ResnetBlockCondNorm2d for a spatial time scale shift or ResnetBlock2D otherwise, so no extra check is needed.
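A simplified sketch of the dispatch being described, assuming a recent diffusers where both block classes live in `diffusers.models.resnet` (illustrative, not the PR's exact code):

```python
from diffusers.models.resnet import ResnetBlock2D, ResnetBlockCondNorm2D

def select_resnet_cls(resnet_time_scale_shift: str):
    # There is always at least one resnet per layer; the time-scale-shift
    # mode only selects which resnet class gets instantiated.
    if resnet_time_scale_shift == "spatial":
        return ResnetBlockCondNorm2D  # conditional (spatial) norm variant
    return ResnetBlock2D  # standard resnet block
```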

sharktank/sharktank/models/vae/layers.py (comment resolved)
```python
else:
    inputs = get_random_inputs(dtype=dtype, device=device, bs=args.bs)

if args.export:
```
Contributor:

When this branch is reached, it looks like the VAE is not run, only exported. I would not expect such a thing of a file named run_vae.py. Could you at least add a TODO to move the export code to another file, or put in a print statement like "VAE exported. Skipping execution"?
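A sketch of the lighter-weight suggestion (the helper names here are hypothetical):

```python
if args.export:
    # TODO: move export code into its own script; run_vae.py should run the model.
    export_vae(model, inputs, args.output_path)  # hypothetical export helper
    print("VAE exported. Skipping execution.")
else:
    results = run_vae_eager(model, inputs)  # hypothetical execution helper
```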

```diff
@@ -393,6 +393,11 @@ def unsqueeze(self, dim: int) -> "AnyTensor":
 
         return unsqueeze(self, dim)
 
+    def squeeze(self, dim: Optional[int] = None) -> "AnyTensor":
```
Contributor:

Likewise, let's keep this alphabetized.

```diff
@@ -0,0 +1,261 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
```
Contributor:

Just realized: this should probably be broken up and put in sharktank/sharktank/layers/.

IanNod (Contributor, Author):

Yeah, it should. I kept it separate following the punet layers, and was planning to clean up both in a follow-up PR.
