Skip to content

Commit

Permalink
Merge pull request #20 from ReindeerCzar/main
Browse files Browse the repository at this point in the history
Issue #18 lower compute capability requirement.
  • Loading branch information
zjp-shadow authored Jul 24, 2024
2 parents 6fda565 + 7b84296 commit 38a2079
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 3 deletions.
20 changes: 19 additions & 1 deletion 2D_Stage/webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,24 @@
repo_id = "zjpshadow/CharacterGen"
all_files = list_repo_files(repo_id, revision="main")

#7-23-2024 Changed to allow GPU with compute < 8
device_capability = -1

#bfloat Support is typically 8 or higher.
def check_bfloat16_support():
# Check if bfloat16 is supported
device_capability = torch.cuda.get_device_capability()

if device_capability[0] >= 8:
print("CUDA device capability is above 8, using bfloat16.")
return torch.bfloat16
else:
print("CUDA device capability is below 8, using float 32.")
return torch.float32

#7-23-2024 Changed to allow GPU with compute < 8
data_type_float = check_bfloat16_support()

for file in all_files:
if os.path.exists("../" + file):
continue
Expand Down Expand Up @@ -191,7 +209,7 @@ def inference(self, input_image, vae, feature_extractor, image_encoder, unet, re

# (B*Nv, 3, H, W)
B = 1
weight_dtype = torch.bfloat16
weight_dtype = data_type_float #7-23-2024 Changed to allow GPU with compute < 8
imgs_in = process_image(input_image, totensor)
imgs_in = rearrange(imgs_in.unsqueeze(0).unsqueeze(0), "B Nv C H W -> (B Nv) C H W")

Expand Down
22 changes: 20 additions & 2 deletions webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,24 @@

from huggingface_hub import hf_hub_download, list_repo_files

#7-23-2024 Changed to allow GPU with compute < 8
device_capability = -1

#bfloat Support is typically 8 or higher.
def check_bfloat16_support():
# Check if bfloat16 is supported
device_capability = torch.cuda.get_device_capability()

if device_capability[0] >= 8:
print("CUDA device capability is above 8, using bfloat16.")
return torch.bfloat16
else:
print("CUDA device capability is below 8, using float 32.")
return torch.float32

#7-23-2024 Changed to allow GPU with compute < 8
data_type_float = check_bfloat16_support()

repo_id = "zjpshadow/CharacterGen"
all_files = list_repo_files(repo_id, revision="main")

Expand Down Expand Up @@ -248,7 +266,7 @@ def inference(self, input_image, val_width, val_height,

# (B*Nv, 3, H, W)
B = 1
weight_dtype = torch.bfloat16
weight_dtype = data_type_float #7-23-2024 Changed to allow GPU with compute < 8
imgs_in = process_image(input_image, totensor)
imgs_in = rearrange(imgs_in.unsqueeze(0).unsqueeze(0), "B Nv C H W -> (B Nv) C H W")

Expand Down Expand Up @@ -448,4 +466,4 @@ def gen4views(image, width, height, seed, timestep, remove_bg):
demo.launch(server_name="0.0.0.0")

if __name__ == "__main__":
main()
main()

0 comments on commit 38a2079

Please sign in to comment.