Spaces:
Runtime error
Runtime error
Ahsen Khaliq
commited on
Commit
·
11137cc
1
Parent(s):
541a5f9
Update app.py
Browse files
app.py
CHANGED
|
@@ -102,6 +102,8 @@ generatorart = deepcopy(original_generator)
|
|
| 102 |
|
| 103 |
generatorspider = deepcopy(original_generator)
|
| 104 |
|
|
|
|
|
|
|
| 105 |
|
| 106 |
transform = transforms.Compose(
|
| 107 |
[
|
|
@@ -162,6 +164,10 @@ modelSpiderverse = hf_hub_download(repo_id="akhaliq/jojo-gan-spiderverse", filen
|
|
| 162 |
ckptspider = torch.load(modelSpiderverse, map_location=lambda storage, loc: storage)
|
| 163 |
generatorspider.load_state_dict(ckptspider["g"], strict=False)
|
| 164 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
|
| 166 |
def inference(img, model):
|
| 167 |
img.save('out.jpg')
|
|
@@ -189,9 +195,12 @@ def inference(img, model):
|
|
| 189 |
elif model == 'Art':
|
| 190 |
with torch.no_grad():
|
| 191 |
my_sample = generatorart(my_w, input_is_latent=True)
|
| 192 |
-
|
| 193 |
with torch.no_grad():
|
| 194 |
my_sample = generatorspider(my_w, input_is_latent=True)
|
|
|
|
|
|
|
|
|
|
| 195 |
|
| 196 |
|
| 197 |
npimage = my_sample[0].permute(1, 2, 0).detach().numpy()
|
|
@@ -204,4 +213,4 @@ description = "Gradio Demo for JoJoGAN: One Shot Face Stylization. To use it, si
|
|
| 204 |
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.11641' target='_blank'>JoJoGAN: One Shot Face Stylization</a>| <a href='https://github.com/mchong6/JoJoGAN' target='_blank'>Github Repo Pytorch</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_jojogan' alt='visitor badge'></center>"
|
| 205 |
|
| 206 |
examples=[['mona.png','Jinx']]
|
| 207 |
-
gr.Interface(inference, [gr.inputs.Image(type="pil"),gr.inputs.Dropdown(choices=['JoJo', 'Disney','Jinx','Caitlyn','Yasuho','Arcane Multi','Art','Spider-Verse'], type="value", default='JoJo', label="Model")], gr.outputs.Image(type="file"),title=title,description=description,article=article,allow_flagging=False,examples=examples,allow_screenshot=False).launch(enable_queue=True, cache_examples=True)
|
|
|
|
| 102 |
|
| 103 |
generatorspider = deepcopy(original_generator)
|
| 104 |
|
| 105 |
+
generatorsketch = deepcopy(original_generator)
|
| 106 |
+
|
| 107 |
|
| 108 |
transform = transforms.Compose(
|
| 109 |
[
|
|
|
|
| 164 |
ckptspider = torch.load(modelSpiderverse, map_location=lambda storage, loc: storage)
|
| 165 |
generatorspider.load_state_dict(ckptspider["g"], strict=False)
|
| 166 |
|
| 167 |
+
modelSketch = hf_hub_download(repo_id="akhaliq/akhaliq/jojogan-sketch", filename="sketch_multi.pt")
|
| 168 |
+
|
| 169 |
+
ckptsketch = torch.load(modelSketch, map_location=lambda storage, loc: storage)
|
| 170 |
+
generatorsketch.load_state_dict(ckptsketch["g"], strict=False)
|
| 171 |
|
| 172 |
def inference(img, model):
|
| 173 |
img.save('out.jpg')
|
|
|
|
| 195 |
elif model == 'Art':
|
| 196 |
with torch.no_grad():
|
| 197 |
my_sample = generatorart(my_w, input_is_latent=True)
|
| 198 |
+
elif model == 'Spider-Verse':
|
| 199 |
with torch.no_grad():
|
| 200 |
my_sample = generatorspider(my_w, input_is_latent=True)
|
| 201 |
+
else:
|
| 202 |
+
with torch.no_grad():
|
| 203 |
+
my_sample = generatorsketch(my_w, input_is_latent=True)
|
| 204 |
|
| 205 |
|
| 206 |
npimage = my_sample[0].permute(1, 2, 0).detach().numpy()
|
|
|
|
| 213 |
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.11641' target='_blank'>JoJoGAN: One Shot Face Stylization</a>| <a href='https://github.com/mchong6/JoJoGAN' target='_blank'>Github Repo Pytorch</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_jojogan' alt='visitor badge'></center>"
|
| 214 |
|
| 215 |
examples=[['mona.png','Jinx']]
|
| 216 |
+
gr.Interface(inference, [gr.inputs.Image(type="pil"),gr.inputs.Dropdown(choices=['JoJo', 'Disney','Jinx','Caitlyn','Yasuho','Arcane Multi','Art','Spider-Verse','Sketch'], type="value", default='JoJo', label="Model")], gr.outputs.Image(type="file"),title=title,description=description,article=article,allow_flagging=False,examples=examples,allow_screenshot=False).launch(enable_queue=True, cache_examples=True)
|