akhaliq HF Staff commited on
Commit
8bac254
·
verified ·
1 Parent(s): 892be11

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +9 -3
app.py CHANGED
@@ -6,13 +6,19 @@ from safetensors.torch import load_file
6
 
7
  # Load the model
8
  pipe = FluxPipeline.from_pretrained(
9
- 'black-forest-labs/FLUX.1-dev',
10
  torch_dtype=torch.bfloat16,
11
  use_safetensors=True
12
  ).to('cuda')
13
 
14
- # Load SRPO weights from https://huggingface.co/tencent/SRPO
15
- state_dict = load_file("tencent/SRPO/diffusion_pytorch_model.safetensors")
 
 
 
 
 
 
16
  pipe.transformer.load_state_dict(state_dict)
17
 
18
  @spaces.GPU(duration=120)
 
6
 
7
  # Load the model
8
  pipe = FluxPipeline.from_pretrained(
9
+ './data/flux',
10
  torch_dtype=torch.bfloat16,
11
  use_safetensors=True
12
  ).to('cuda')
13
 
14
+ # Load SRPO weights
15
+ from huggingface_hub import hf_hub_download
16
+
17
+ srpo_path = hf_hub_download(
18
+ repo_id="tencent/SRPO",
19
+ filename="diffusion_pytorch_model.safetensors"
20
+ )
21
+ state_dict = load_file(srpo_path)
22
  pipe.transformer.load_state_dict(state_dict)
23
 
24
  @spaces.GPU(duration=120)