Spaces:
Runtime error
Runtime error
Update
Browse files- README.md +5 -2
- app.py +10 -13
- requirements.txt +2 -2
README.md
CHANGED
|
@@ -1,12 +1,15 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
emoji: ⚡
|
| 4 |
colorFrom: red
|
| 5 |
colorTo: yellow
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version: 3.
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
|
|
|
| 10 |
---
|
| 11 |
|
| 12 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: GAN-Control
|
| 3 |
emoji: ⚡
|
| 4 |
colorFrom: red
|
| 5 |
colorTo: yellow
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 3.35.2
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
+
suggested_hardware: t4-small
|
| 11 |
---
|
| 12 |
|
| 13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
|
| 14 |
+
|
| 15 |
+
https://arxiv.org/abs/2101.02477
|
app.py
CHANGED
|
@@ -5,6 +5,7 @@ from __future__ import annotations
|
|
| 5 |
import functools
|
| 6 |
import os
|
| 7 |
import pathlib
|
|
|
|
| 8 |
import subprocess
|
| 9 |
import sys
|
| 10 |
import tarfile
|
|
@@ -17,27 +18,22 @@ import torch
|
|
| 17 |
|
| 18 |
if os.getenv('SYSTEM') == 'spaces':
|
| 19 |
with open('patch') as f:
|
| 20 |
-
subprocess.run('patch -p1'
|
| 21 |
|
| 22 |
sys.path.insert(0, 'gan-control/src')
|
| 23 |
|
| 24 |
from gan_control.inference.controller import Controller
|
| 25 |
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
This is an unofficial demo for https://github.com/amazon-research/gan-control.
|
| 29 |
-
'''
|
| 30 |
-
|
| 31 |
-
TOKEN = os.getenv('HF_TOKEN')
|
| 32 |
|
| 33 |
|
| 34 |
def download_models() -> None:
|
| 35 |
model_dir = pathlib.Path('controller_age015id025exp02hai04ori02gam15')
|
| 36 |
if not model_dir.exists():
|
| 37 |
path = huggingface_hub.hf_hub_download(
|
| 38 |
-
'
|
| 39 |
-
'controller_age015id025exp02hai04ori02gam15.tar.gz'
|
| 40 |
-
use_auth_token=TOKEN)
|
| 41 |
with tarfile.open(path) as f:
|
| 42 |
f.extractall()
|
| 43 |
|
|
@@ -96,10 +92,10 @@ download_models()
|
|
| 96 |
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
| 97 |
path = 'controller_age015id025exp02hai04ori02gam15/'
|
| 98 |
controller = Controller(path, device)
|
| 99 |
-
|
| 100 |
|
| 101 |
gr.Interface(
|
| 102 |
-
fn=
|
| 103 |
inputs=[
|
| 104 |
gr.Slider(label='Seed', minimum=0, maximum=1000000, step=1, value=0),
|
| 105 |
gr.Slider(label='Truncation',
|
|
@@ -142,5 +138,6 @@ gr.Interface(
|
|
| 142 |
gr.Image(label='Age Controlled', type='pil'),
|
| 143 |
gr.Image(label='Hair Color Controlled', type='pil'),
|
| 144 |
],
|
|
|
|
| 145 |
description=DESCRIPTION,
|
| 146 |
-
).queue().launch(
|
|
|
|
| 5 |
import functools
|
| 6 |
import os
|
| 7 |
import pathlib
|
| 8 |
+
import shlex
|
| 9 |
import subprocess
|
| 10 |
import sys
|
| 11 |
import tarfile
|
|
|
|
| 18 |
|
| 19 |
if os.getenv('SYSTEM') == 'spaces':
|
| 20 |
with open('patch') as f:
|
| 21 |
+
subprocess.run(shlex.split('patch -p1'), cwd='gan-control', stdin=f)
|
| 22 |
|
| 23 |
sys.path.insert(0, 'gan-control/src')
|
| 24 |
|
| 25 |
from gan_control.inference.controller import Controller
|
| 26 |
|
| 27 |
+
TITLE = 'GAN-Control'
|
| 28 |
+
DESCRIPTION = 'https://github.com/amazon-research/gan-control'
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
|
| 31 |
def download_models() -> None:
|
| 32 |
model_dir = pathlib.Path('controller_age015id025exp02hai04ori02gam15')
|
| 33 |
if not model_dir.exists():
|
| 34 |
path = huggingface_hub.hf_hub_download(
|
| 35 |
+
'public-data/gan-control',
|
| 36 |
+
'controller_age015id025exp02hai04ori02gam15.tar.gz')
|
|
|
|
| 37 |
with tarfile.open(path) as f:
|
| 38 |
f.extractall()
|
| 39 |
|
|
|
|
| 92 |
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
| 93 |
path = 'controller_age015id025exp02hai04ori02gam15/'
|
| 94 |
controller = Controller(path, device)
|
| 95 |
+
fn = functools.partial(run, controller=controller, device=device)
|
| 96 |
|
| 97 |
gr.Interface(
|
| 98 |
+
fn=fn,
|
| 99 |
inputs=[
|
| 100 |
gr.Slider(label='Seed', minimum=0, maximum=1000000, step=1, value=0),
|
| 101 |
gr.Slider(label='Truncation',
|
|
|
|
| 138 |
gr.Image(label='Age Controlled', type='pil'),
|
| 139 |
gr.Image(label='Hair Color Controlled', type='pil'),
|
| 140 |
],
|
| 141 |
+
title=TITLE,
|
| 142 |
description=DESCRIPTION,
|
| 143 |
+
).queue(max_size=10).launch()
|
requirements.txt
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
numpy==1.
|
| 2 |
-
Pillow==
|
| 3 |
torch==1.11.0
|
| 4 |
torchvision==0.12.0
|
|
|
|
| 1 |
+
numpy==1.23.5
|
| 2 |
+
Pillow==10.0.0
|
| 3 |
torch==1.11.0
|
| 4 |
torchvision==0.12.0
|