Spaces:
Runtime error
Runtime error
remove ema parts in checkpoints.
Browse files- app.py +3 -3
- model_states.pt → model_wo_ema.ckpt +2 -2
- transfer.py +8 -2
app.py
CHANGED
|
@@ -66,11 +66,11 @@ def process_multi_wrapper_only_show_rendered(rendered_txt_0, rendered_txt_1, ren
|
|
| 66 |
shared_eta, shared_a_prompt, shared_n_prompt,
|
| 67 |
only_show_rendered_image=True)
|
| 68 |
|
| 69 |
-
# cfg = OmegaConf.load("config.yaml")
|
| 70 |
-
# model = load_model_from_config(cfg, "model_states.pt", verbose=True)
|
| 71 |
|
| 72 |
cfg = OmegaConf.load("config.yaml")
|
| 73 |
-
model = load_model_from_config(cfg, "
|
|
|
|
|
|
|
| 74 |
|
| 75 |
ddim_sampler = DDIMSampler(model)
|
| 76 |
render_tool = Render_Text(model)
|
|
|
|
| 66 |
shared_eta, shared_a_prompt, shared_n_prompt,
|
| 67 |
only_show_rendered_image=True)
|
| 68 |
|
|
|
|
|
|
|
| 69 |
|
| 70 |
cfg = OmegaConf.load("config.yaml")
|
| 71 |
+
model = load_model_from_config(cfg, "model_wo_ema.ckpt", verbose=True)
|
| 72 |
+
# model = load_model_from_config(cfg, "model_states.pt", verbose=True)
|
| 73 |
+
# model = load_model_from_config(cfg, "model.ckpt", verbose=True)
|
| 74 |
|
| 75 |
ddim_sampler = DDIMSampler(model)
|
| 76 |
render_tool = Render_Text(model)
|
model_states.pt → model_wo_ema.ckpt
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0b86b22188bf580e80773a5ae101bf9787eb258349f3f1acf0ae50fd10cb3fec
|
| 3 |
+
size 6671922039
|
transfer.py
CHANGED
|
@@ -6,9 +6,15 @@ model = load_model_from_config(cfg, "model_states.pt", verbose=True)
|
|
| 6 |
|
| 7 |
from pytorch_lightning.callbacks import ModelCheckpoint
|
| 8 |
with model.ema_scope("store ema weights"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
file_content = {
|
| 10 |
-
'state_dict':
|
| 11 |
}
|
| 12 |
-
torch.save(file_content, "
|
| 13 |
print("has stored the transfered ckpt.")
|
| 14 |
print("trial ends!")
|
|
|
|
| 6 |
|
| 7 |
from pytorch_lightning.callbacks import ModelCheckpoint
|
| 8 |
with model.ema_scope("store ema weights"):
|
| 9 |
+
model_sd = model.state_dict()
|
| 10 |
+
store_sd = {}
|
| 11 |
+
for key in model_sd:
|
| 12 |
+
if "ema" in key:
|
| 13 |
+
continue
|
| 14 |
+
store_sd[key] = model_sd[key]
|
| 15 |
file_content = {
|
| 16 |
+
'state_dict': store_sd
|
| 17 |
}
|
| 18 |
+
torch.save(file_content, "model_wo_ema.ckpt")
|
| 19 |
print("has stored the transfered ckpt.")
|
| 20 |
print("trial ends!")
|