Adds sdxl's VAE decoder implementation #653
base: main
Conversation
Branch updated from 3e6c40c to d967c01
    subfolder="vae",
)

def decode(self, inp):
You have a scaling factor parameter. You should use that instead of the constant if possible. Otherwise, maybe document why it is this value?
This is the diffusers reference code, so it does not use our HParams at all. They also do not do the scaling inside the VAE decode, but in their pipeline as input here: https://github.com/huggingface/diffusers/blob/18f9b990884883533491fc87f303e7305dc27d75/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L591. I added it to our implementation because forcing shortfin to handle it in the pipeline would introduce extra host/device roundtrips, which would impact performance.
I will add comments to clarify here.
You can probably replace `img = 1 / 0.13025 * inp` with `img = img / self.vae.config.scaling_factor` (source). This is done differently in flux/sd3, where we have a shift factor: `img = img / self.ae.config.scaling_factor + self.ae.config.shift_factor`.
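For illustration, a minimal sketch of what config-driven scaling could look like inside decode, assuming a hypothetical hparams holder (`self.hp`, a made-up name) that carries `scaling_factor` and optionally `shift_factor`:

```python
def decode(self, inp):
    # Undo the latent scaling using the configured value instead of the
    # hard-coded constant (1 / 0.13025; 0.13025 is SDXL's VAE scaling factor).
    # Doing it inside decode avoids an extra host/device roundtrip in shortfin.
    img = inp / self.hp.scaling_factor  # self.hp is a hypothetical hparams object
    # flux/sd3-style VAEs additionally apply a shift factor before decoding.
    if getattr(self.hp, "shift_factor", None) is not None:
        img = img + self.hp.shift_factor
    # ... remainder of the decoder forward pass consumes `img` ...
    return img
```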
def get_random_inputs(dtype, device, bs: int = 2):
    torch.random.manual_seed(42)
If someone calls this function multiple times with the same batch size, I think they get the exact same tensor, which does not sound like desirable behavior. I would leave seed setting to callers, like main/test scripts.
Good catch. This is left over from some debugging a while ago; will remove.
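A sketch of the suggested split, with seeding left to the caller; the tensor shape and the call site below are illustrative, not the PR's actual values:

```python
import torch

def get_random_inputs(dtype, device, bs: int = 2):
    # No manual_seed here, so repeated calls produce different tensors.
    return torch.rand(bs, 4, 128, 128, dtype=dtype, device=device)

# Callers that need reproducibility (tests, main scripts) seed themselves:
torch.random.manual_seed(42)
inputs = get_random_inputs(dtype=torch.float32, device="cpu", bs=2)
```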
@@ -60,6 +60,7 @@
    "unflatten",
    "unshard",
    "unsqueeze",
    "squeeze",
Let's keep this alphabetized.
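For reference, the ordering being asked for, assuming these strings live in an `__all__`-style export list (only the entries visible in the hunk are shown):

```python
__all__ = [
    # ... earlier entries unchanged ...
    "squeeze",
    "unflatten",
    "unshard",
    "unsqueeze",
    # ... later entries unchanged ...
]
```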
# there is always at least one resnet
if resnet_time_scale_shift == "spatial":
    # TODO
Could you fill out the TODO? Also, it seems to me like you might want your `else` branch to be an `elif` where you verify you're getting what you expect.
Each layer always has a resnet, specifically either ResnetBlockCondNorm2d for spatial time shift or ResnetBlock2D otherwise. No else is needed.
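A sketch of the stricter branching the review suggests, assuming the two resnet classes named above are importable from the PR's VAE layers and that "default" is the only other expected value (as in diffusers' configs):

```python
def _resnet_class_for(resnet_time_scale_shift: str):
    # Each block builds exactly one resnet flavor; validate the value rather
    # than treating any non-"spatial" string as the default.
    if resnet_time_scale_shift == "spatial":
        return ResnetBlockCondNorm2d
    elif resnet_time_scale_shift == "default":
        return ResnetBlock2D
    raise ValueError(
        f"unexpected resnet_time_scale_shift: {resnet_time_scale_shift!r}"
    )
```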
else:
    inputs = get_random_inputs(dtype=dtype, device=device, bs=args.bs)

if args.export:
When this branch is reached, it looks like the VAE is not run, only exported. I would not expect such a thing of a file named run_vae.py. Could you at least add a TODO to move the export code to another file, or put a print statement like "VAE exported. Skipping execution"?
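One way to make the export-only path explicit in run_vae.py, along the lines suggested above; `export_model`, `model.decode`, and the `main` signature are illustrative stand-ins, not the PR's actual code:

```python
def main(args, model, inputs):
    if args.export:
        # TODO: move export into its own script so run_vae.py always runs the model.
        export_model(model, inputs)  # stand-in for the PR's export path
        print("VAE exported. Skipping execution.")
        return None
    # Otherwise run the decoder, as the file name implies.
    return model.decode(inputs)
```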
sharktank/sharktank/types/tensors.py (outdated)
@@ -393,6 +393,11 @@ def unsqueeze(self, dim: int) -> "AnyTensor":

        return unsqueeze(self, dim)

    def squeeze(self, dim: Optional[int] = None) -> "AnyTensor":
Likewise, let's keep this alphabetized.
@@ -0,0 +1,261 @@
# Copyright 2024 Advanced Micro Devices, Inc.
Just realized: this should probably be broken up and put in sharktank/sharktank/layers/.
Yeah, it should. I kept it separate to follow the punet layers, and I was planning to clean up both in a follow-up PR.
Moves irpa generation into vae setup