Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
|
@@ -161,7 +161,8 @@ class ForgeLoader4Bit(torch.nn.Module):
|
|
| 161 |
data=state_dict[prefix + 'weight'],
|
| 162 |
quantized_stats=quant_state_dict,
|
| 163 |
requires_grad=False,
|
| 164 |
-
device=self.dummy.device,
|
|
|
|
| 165 |
module=self
|
| 166 |
)
|
| 167 |
self.quant_state = self.weight.quant_state
|
|
@@ -717,12 +718,10 @@ def get_image(image) -> torch.Tensor | None:
|
|
| 717 |
# ---------------- Demo ----------------
|
| 718 |
|
| 719 |
|
| 720 |
-
from
|
|
|
|
| 721 |
|
| 722 |
-
|
| 723 |
-
torch.hub.download_url_to_file("https://huggingface.co/lllyasviel/flux1-dev-bnb-nf4/resolve/main/flux1-dev-bnb-nf4.safetensors", "flux1-dev-bnb-nf4.safetensors")
|
| 724 |
-
|
| 725 |
-
sd = load_file("flux1-dev-bnb-nf4.safetensors")
|
| 726 |
sd = {k.replace("model.diffusion_model.", ""): v for k, v in sd.items() if "model.diffusion_model" in k}
|
| 727 |
model = Flux().to(dtype=torch.float16, device="cuda")
|
| 728 |
result = model.load_state_dict(sd)
|
|
@@ -753,7 +752,8 @@ def generate_image(
|
|
| 753 |
init_image = init_image[..., : 16 * (h // 16), : 16 * (w // 16)]
|
| 754 |
height = init_image.shape[-2]
|
| 755 |
width = init_image.shape[-1]
|
| 756 |
-
init_image = ae.encode(init_image.to(torch_device))
|
|
|
|
| 757 |
|
| 758 |
generator = torch.Generator(device=device).manual_seed(seed)
|
| 759 |
x = torch.randn(1, 16, 2 * math.ceil(height / 16), 2 * math.ceil(width / 16), device=device, dtype=torch.bfloat16, generator=generator)
|
|
|
|
| 161 |
data=state_dict[prefix + 'weight'],
|
| 162 |
quantized_stats=quant_state_dict,
|
| 163 |
requires_grad=False,
|
| 164 |
+
# device=self.dummy.device,
|
| 165 |
+
device=torch.device('cuda'),
|
| 166 |
module=self
|
| 167 |
)
|
| 168 |
self.quant_state = self.weight.quant_state
|
|
|
|
| 718 |
# ---------------- Demo ----------------
|
| 719 |
|
| 720 |
|
| 721 |
+
from huggingface_hub import hf_hub_download
|
| 722 |
+
from safetensors.torch import load_file
|
| 723 |
|
| 724 |
+
sd = load_file(hf_hub_download(repo_id="lllyasviel/flux1-dev-bnb-nf4", filename="flux1-dev-bnb-nf4.safetensors"))
|
|
|
|
|
|
|
|
|
|
| 725 |
sd = {k.replace("model.diffusion_model.", ""): v for k, v in sd.items() if "model.diffusion_model" in k}
|
| 726 |
model = Flux().to(dtype=torch.float16, device="cuda")
|
| 727 |
result = model.load_state_dict(sd)
|
|
|
|
| 752 |
init_image = init_image[..., : 16 * (h // 16), : 16 * (w // 16)]
|
| 753 |
height = init_image.shape[-2]
|
| 754 |
width = init_image.shape[-1]
|
| 755 |
+
init_image = ae.encode(init_image.to(torch_device)).latent_dist.sample()
|
| 756 |
+
init_image = (init_image - ae.config.shift_factor) * ae.config.scaling_factor
|
| 757 |
|
| 758 |
generator = torch.Generator(device=device).manual_seed(seed)
|
| 759 |
x = torch.randn(1, 16, 2 * math.ceil(height / 16), 2 * math.ceil(width / 16), device=device, dtype=torch.bfloat16, generator=generator)
|