Spaces:
Runtime error
Runtime error
Upload app.py with huggingface_hub
Browse files
app.py
CHANGED
|
@@ -6,13 +6,19 @@ from safetensors.torch import load_file
|
|
| 6 |
|
| 7 |
# Load the model
|
| 8 |
pipe = FluxPipeline.from_pretrained(
|
| 9 |
-
'
|
| 10 |
torch_dtype=torch.bfloat16,
|
| 11 |
use_safetensors=True
|
| 12 |
).to('cuda')
|
| 13 |
|
| 14 |
-
# Load SRPO weights
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
pipe.transformer.load_state_dict(state_dict)
|
| 17 |
|
| 18 |
@spaces.GPU(duration=120)
|
|
|
|
| 6 |
|
| 7 |
# Load the model
|
| 8 |
pipe = FluxPipeline.from_pretrained(
|
| 9 |
+
'./data/flux',
|
| 10 |
torch_dtype=torch.bfloat16,
|
| 11 |
use_safetensors=True
|
| 12 |
).to('cuda')
|
| 13 |
|
| 14 |
+
# Load SRPO weights
|
| 15 |
+
from huggingface_hub import hf_hub_download
|
| 16 |
+
|
| 17 |
+
srpo_path = hf_hub_download(
|
| 18 |
+
repo_id="tencent/SRPO",
|
| 19 |
+
filename="diffusion_pytorch_model.safetensors"
|
| 20 |
+
)
|
| 21 |
+
state_dict = load_file(srpo_path)
|
| 22 |
pipe.transformer.load_state_dict(state_dict)
|
| 23 |
|
| 24 |
@spaces.GPU(duration=120)
|