Spaces:
Runtime error
Runtime error
new app
Browse files- .gitignore +2 -0
- app.py +46 -15
- configs/model/optimus.yaml +2 -1
- lib/model_zoo/common/get_model.py +8 -21
- requirements.txt +1 -1
.gitignore
CHANGED
|
@@ -7,3 +7,5 @@ log/
|
|
| 7 |
log
|
| 8 |
pretrained/
|
| 9 |
pretrained
|
|
|
|
|
|
|
|
|
| 7 |
log
|
| 8 |
pretrained/
|
| 9 |
pretrained
|
| 10 |
+
gradio_cached_examples/
|
| 11 |
+
gradio_cached_examples
|
app.py
CHANGED
|
@@ -252,10 +252,6 @@ class vd_inference(object):
|
|
| 252 |
assert False, 'Model type not supported'
|
| 253 |
net = get_model()(cfgm)
|
| 254 |
|
| 255 |
-
if self.which == 'v1.0':
|
| 256 |
-
sd = torch.load('pretrained/vd-four-flow-v1-0.pth', map_location='cpu')
|
| 257 |
-
net.load_state_dict(sd, strict=False)
|
| 258 |
-
|
| 259 |
if fp16:
|
| 260 |
highlight_print('Running in FP16')
|
| 261 |
if self.which == 'v1.0':
|
|
@@ -266,6 +262,20 @@ class vd_inference(object):
|
|
| 266 |
else:
|
| 267 |
self.dtype = torch.float32
|
| 268 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 269 |
self.use_cuda = torch.cuda.is_available()
|
| 270 |
if self.use_cuda:
|
| 271 |
net.to('cuda')
|
|
@@ -855,9 +865,11 @@ def tcg_interface(with_example=False):
|
|
| 855 |
cache_examples=cache_examples, )
|
| 856 |
|
| 857 |
gr.HTML('<br><p id=myinst>  How to add mask: Please see the following instructions.</p><br>'+
|
| 858 |
-
'<
|
| 859 |
-
|
| 860 |
-
|
|
|
|
|
|
|
| 861 |
|
| 862 |
def mcg_interface(with_example=False):
|
| 863 |
num_img_input = 4
|
|
@@ -917,9 +929,11 @@ def mcg_interface(with_example=False):
|
|
| 917 |
cache_examples=cache_examples, )
|
| 918 |
|
| 919 |
gr.HTML('<br><p id=myinst>  How to add mask: Please see the following instructions.</p><br>'+
|
| 920 |
-
'<
|
| 921 |
-
|
| 922 |
-
|
|
|
|
|
|
|
| 923 |
|
| 924 |
###########
|
| 925 |
# Example #
|
|
@@ -1017,6 +1031,21 @@ css = """
|
|
| 1017 |
margin: 0rem;
|
| 1018 |
color: #6B7280;
|
| 1019 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1020 |
"""
|
| 1021 |
|
| 1022 |
if True:
|
|
@@ -1025,7 +1054,7 @@ if True:
|
|
| 1025 |
"""
|
| 1026 |
<div style="text-align: center; max-width: 1200px; margin: 20px auto;">
|
| 1027 |
<h1 style="font-weight: 900; font-size: 3rem; margin: 0rem">
|
| 1028 |
-
Versatile Diffusion
|
| 1029 |
</h1>
|
| 1030 |
<h2 style="font-weight: 450; font-size: 1rem; margin-top: 0.5rem; margin-bottom: 0.5rem">
|
| 1031 |
We built <b>Versatile Diffusion (VD), the first unified multi-flow multimodal diffusion framework</b>, as a step towards <b>Universal Generative AI</b>.
|
|
@@ -1041,8 +1070,7 @@ if True:
|
|
| 1041 |
[<a href="https://github.com/SHI-Labs/Versatile-Diffusion" style="color:blue;">GitHub</a>]
|
| 1042 |
</h3>
|
| 1043 |
</div>
|
| 1044 |
-
"""
|
| 1045 |
-
# .format('')) #
|
| 1046 |
|
| 1047 |
with gr.Tab('Text-to-Image'):
|
| 1048 |
t2i_interface(with_example=True)
|
|
@@ -1061,7 +1089,10 @@ if True:
|
|
| 1061 |
|
| 1062 |
gr.HTML(
|
| 1063 |
"""
|
| 1064 |
-
<div style="text-align:
|
|
|
|
|
|
|
|
|
|
| 1065 |
<h3 style="font-weight: 450; font-size: 0.8rem; margin: 0rem">
|
| 1066 |
<b>Caution</b>:
|
| 1067 |
We would like the raise the awareness of users of this demo of its potential issues and concerns.
|
|
@@ -1077,7 +1108,7 @@ if True:
|
|
| 1077 |
VD in this demo is meant only for research purposes.
|
| 1078 |
</h3>
|
| 1079 |
</div>
|
| 1080 |
-
""")
|
| 1081 |
|
| 1082 |
demo.launch(share=True)
|
| 1083 |
# demo.launch(debug=True)
|
|
|
|
| 252 |
assert False, 'Model type not supported'
|
| 253 |
net = get_model()(cfgm)
|
| 254 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 255 |
if fp16:
|
| 256 |
highlight_print('Running in FP16')
|
| 257 |
if self.which == 'v1.0':
|
|
|
|
| 262 |
else:
|
| 263 |
self.dtype = torch.float32
|
| 264 |
|
| 265 |
+
if self.which == 'v1.0':
|
| 266 |
+
# if fp16:
|
| 267 |
+
# sd = torch.load('pretrained/vd-four-flow-v1-0-fp16.pth', map_location='cpu')
|
| 268 |
+
# else:
|
| 269 |
+
# sd = torch.load('pretrained/vd-four-flow-v1-0.pth', map_location='cpu')
|
| 270 |
+
from huggingface_hub import hf_hub_download
|
| 271 |
+
if fp16:
|
| 272 |
+
temppath = hf_hub_download('shi-labs/versatile-diffusion-model', 'pretrained_pth/vd-four-flow-v1-0-fp16.pth')
|
| 273 |
+
else:
|
| 274 |
+
temppath = hf_hub_download('shi-labs/versatile-diffusion-model', 'pretrained_pth/vd-four-flow-v1-0.pth')
|
| 275 |
+
sd = torch.load(temppath, map_location='cpu')
|
| 276 |
+
|
| 277 |
+
net.load_state_dict(sd, strict=False)
|
| 278 |
+
|
| 279 |
self.use_cuda = torch.cuda.is_available()
|
| 280 |
if self.use_cuda:
|
| 281 |
net.to('cuda')
|
|
|
|
| 865 |
cache_examples=cache_examples, )
|
| 866 |
|
| 867 |
gr.HTML('<br><p id=myinst>  How to add mask: Please see the following instructions.</p><br>'+
|
| 868 |
+
'<div id="maskinst">'+
|
| 869 |
+
'<img src="file/assets/demo/misc/mask_inst1.gif">'+
|
| 870 |
+
'<img src="file/assets/demo/misc/mask_inst2.gif">'+
|
| 871 |
+
'<img src="file/assets/demo/misc/mask_inst3.gif">'+
|
| 872 |
+
'</div>')
|
| 873 |
|
| 874 |
def mcg_interface(with_example=False):
|
| 875 |
num_img_input = 4
|
|
|
|
| 929 |
cache_examples=cache_examples, )
|
| 930 |
|
| 931 |
gr.HTML('<br><p id=myinst>  How to add mask: Please see the following instructions.</p><br>'+
|
| 932 |
+
'<div id="maskinst">'+
|
| 933 |
+
'<img src="file/assets/demo/misc/mask_inst1.gif">'+
|
| 934 |
+
'<img src="file/assets/demo/misc/mask_inst2.gif">'+
|
| 935 |
+
'<img src="file/assets/demo/misc/mask_inst3.gif">'+
|
| 936 |
+
'</div>')
|
| 937 |
|
| 938 |
###########
|
| 939 |
# Example #
|
|
|
|
| 1031 |
margin: 0rem;
|
| 1032 |
color: #6B7280;
|
| 1033 |
}
|
| 1034 |
+
#maskinst {
|
| 1035 |
+
text-align: justify;
|
| 1036 |
+
min-width: 1200px;
|
| 1037 |
+
}
|
| 1038 |
+
#maskinst>img {
|
| 1039 |
+
min-width:399px;
|
| 1040 |
+
max-width:450px;
|
| 1041 |
+
vertical-align: top;
|
| 1042 |
+
display: inline-block;
|
| 1043 |
+
}
|
| 1044 |
+
#maskinst:after {
|
| 1045 |
+
content: "";
|
| 1046 |
+
width: 100%;
|
| 1047 |
+
display: inline-block;
|
| 1048 |
+
}
|
| 1049 |
"""
|
| 1050 |
|
| 1051 |
if True:
|
|
|
|
| 1054 |
"""
|
| 1055 |
<div style="text-align: center; max-width: 1200px; margin: 20px auto;">
|
| 1056 |
<h1 style="font-weight: 900; font-size: 3rem; margin: 0rem">
|
| 1057 |
+
Versatile Diffusion
|
| 1058 |
</h1>
|
| 1059 |
<h2 style="font-weight: 450; font-size: 1rem; margin-top: 0.5rem; margin-bottom: 0.5rem">
|
| 1060 |
We built <b>Versatile Diffusion (VD), the first unified multi-flow multimodal diffusion framework</b>, as a step towards <b>Universal Generative AI</b>.
|
|
|
|
| 1070 |
[<a href="https://github.com/SHI-Labs/Versatile-Diffusion" style="color:blue;">GitHub</a>]
|
| 1071 |
</h3>
|
| 1072 |
</div>
|
| 1073 |
+
""")
|
|
|
|
| 1074 |
|
| 1075 |
with gr.Tab('Text-to-Image'):
|
| 1076 |
t2i_interface(with_example=True)
|
|
|
|
| 1089 |
|
| 1090 |
gr.HTML(
|
| 1091 |
"""
|
| 1092 |
+
<div style="text-align: justify; max-width: 1200px; margin: 20px auto;">
|
| 1093 |
+
<h3 style="font-weight: 450; font-size: 0.8rem; margin: 0rem">
|
| 1094 |
+
<b>Version</b>: {}
|
| 1095 |
+
</h3>
|
| 1096 |
<h3 style="font-weight: 450; font-size: 0.8rem; margin: 0rem">
|
| 1097 |
<b>Caution</b>:
|
| 1098 |
We would like the raise the awareness of users of this demo of its potential issues and concerns.
|
|
|
|
| 1108 |
VD in this demo is meant only for research purposes.
|
| 1109 |
</h3>
|
| 1110 |
</div>
|
| 1111 |
+
""".format(' '+vd_inference.which))
|
| 1112 |
|
| 1113 |
demo.launch(share=True)
|
| 1114 |
# demo.launch(debug=True)
|
configs/model/optimus.yaml
CHANGED
|
@@ -92,7 +92,6 @@ optimus_gpt2_tokenizer:
|
|
| 92 |
optimus_v1:
|
| 93 |
super_cfg: optimus
|
| 94 |
type: optimus_vae_next
|
| 95 |
-
pth: pretrained/optimus-vae.pth
|
| 96 |
args:
|
| 97 |
encoder: MODEL(optimus_bert_encoder)
|
| 98 |
decoder: MODEL(optimus_gpt2_decoder)
|
|
@@ -100,3 +99,5 @@ optimus_v1:
|
|
| 100 |
tokenizer_decoder: MODEL(optimus_gpt2_tokenizer)
|
| 101 |
args:
|
| 102 |
latent_size: 768
|
|
|
|
|
|
|
|
|
| 92 |
optimus_v1:
|
| 93 |
super_cfg: optimus
|
| 94 |
type: optimus_vae_next
|
|
|
|
| 95 |
args:
|
| 96 |
encoder: MODEL(optimus_bert_encoder)
|
| 97 |
decoder: MODEL(optimus_gpt2_decoder)
|
|
|
|
| 99 |
tokenizer_decoder: MODEL(optimus_gpt2_tokenizer)
|
| 100 |
args:
|
| 101 |
latent_size: 768
|
| 102 |
+
# pth: pretrained/optimus-vae.pth
|
| 103 |
+
hfm: ['shi-labs/versatile-diffusion-model', 'pretrained_pth/optimus-vae.pth']
|
lib/model_zoo/common/get_model.py
CHANGED
|
@@ -8,27 +8,6 @@ from .utils import \
|
|
| 8 |
get_total_param, get_total_param_sum, \
|
| 9 |
get_unit
|
| 10 |
|
| 11 |
-
# def load_state_dict(net, model_path):
|
| 12 |
-
# if isinstance(net, dict):
|
| 13 |
-
# for ni, neti in net.items():
|
| 14 |
-
# paras = torch.load(model_path[ni], map_location=torch.device('cpu'))
|
| 15 |
-
# new_paras = neti.state_dict()
|
| 16 |
-
# new_paras.update(paras)
|
| 17 |
-
# neti.load_state_dict(new_paras)
|
| 18 |
-
# else:
|
| 19 |
-
# paras = torch.load(model_path, map_location=torch.device('cpu'))
|
| 20 |
-
# new_paras = net.state_dict()
|
| 21 |
-
# new_paras.update(paras)
|
| 22 |
-
# net.load_state_dict(new_paras)
|
| 23 |
-
# return
|
| 24 |
-
|
| 25 |
-
# def save_state_dict(net, path):
|
| 26 |
-
# if isinstance(net, (torch.nn.DataParallel,
|
| 27 |
-
# torch.nn.parallel.DistributedDataParallel)):
|
| 28 |
-
# torch.save(net.module.state_dict(), path)
|
| 29 |
-
# else:
|
| 30 |
-
# torch.save(net.state_dict(), path)
|
| 31 |
-
|
| 32 |
def singleton(class_):
|
| 33 |
instances = {}
|
| 34 |
def getinstance(*args, **kwargs):
|
|
@@ -94,6 +73,14 @@ class get_model(object):
|
|
| 94 |
net.load_state_dict(sd, strict=strict_sd)
|
| 95 |
if verbose:
|
| 96 |
print_log('Load pth from {}'.format(cfg.pth))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
|
| 98 |
# display param_num & param_sum
|
| 99 |
if verbose:
|
|
|
|
| 8 |
get_total_param, get_total_param_sum, \
|
| 9 |
get_unit
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
def singleton(class_):
|
| 12 |
instances = {}
|
| 13 |
def getinstance(*args, **kwargs):
|
|
|
|
| 73 |
net.load_state_dict(sd, strict=strict_sd)
|
| 74 |
if verbose:
|
| 75 |
print_log('Load pth from {}'.format(cfg.pth))
|
| 76 |
+
elif 'hfm' in cfg:
|
| 77 |
+
from huggingface_hub import hf_hub_download
|
| 78 |
+
temppath = hf_hub_download(cfg.hfm[0], cfg.hfm[1])
|
| 79 |
+
sd = torch.load(temppath, map_location='cpu')
|
| 80 |
+
strict_sd = cfg.get('strict_sd', True)
|
| 81 |
+
net.load_state_dict(sd, strict=strict_sd)
|
| 82 |
+
if verbose:
|
| 83 |
+
print_log('Load hfm from {}/{}'.format(*cfg.hfm))
|
| 84 |
|
| 85 |
# display param_num & param_sum
|
| 86 |
if verbose:
|
requirements.txt
CHANGED
|
@@ -12,5 +12,5 @@ torchmetrics==0.7.3
|
|
| 12 |
|
| 13 |
einops==0.3.0
|
| 14 |
omegaconf==2.1.1
|
| 15 |
-
huggingface-hub==0.
|
| 16 |
gradio==3.17.1
|
|
|
|
| 12 |
|
| 13 |
einops==0.3.0
|
| 14 |
omegaconf==2.1.1
|
| 15 |
+
huggingface-hub==0.11.1
|
| 16 |
gradio==3.17.1
|