Spaces:
Runtime error
Runtime error
| import os | |
| from PIL import Image | |
| import torch | |
| import gradio as gr | |
| os.system("pip install gradio==2.5.3") | |
| import torch | |
| torch.backends.cudnn.benchmark = True | |
| from torchvision import transforms, utils | |
| from util import * | |
| from PIL import Image | |
| import math | |
| import random | |
| import numpy as np | |
| from torch import nn, autograd, optim | |
| from torch.nn import functional as F | |
| from tqdm import tqdm | |
| import lpips | |
| from model import * | |
| from e4e_projection import projection as e4e_projection | |
| from copy import deepcopy | |
| import imageio | |
# Working directories used by the app; 'models/' receives the dlib landmark
# predictor below and the e4e encoder checkpoint copied in further down.
os.makedirs('inversion_codes', exist_ok=True)
os.makedirs('style_images', exist_ok=True)
os.makedirs('style_images_aligned', exist_ok=True)
os.makedirs('models', exist_ok=True)

# dlib's 68-point landmark model. Keep this exact filename (no separator between
# "dlib" and "shape"): it is presumably the path util.align_face loads — confirm
# before renaming.
# Only fetch when missing, so restarts do not re-download or leave behind
# numbered duplicates (wget never overwrites without -O).
if not os.path.exists("models/dlibshape_predictor_68_face_landmarks.dat"):
    os.system(
        "wget -O shape_predictor_68_face_landmarks.dat.bz2 "
        "http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2"
    )
    # -k keeps the archive; -f overwrites a stale decompressed file if present.
    os.system("bzip2 -dkf shape_predictor_68_face_landmarks.dat.bz2")
    os.system("mv shape_predictor_68_face_landmarks.dat models/dlibshape_predictor_68_face_landmarks.dat")
# All models in this file are placed on the CPU.
device = 'cpu'

# Pretrained StyleGAN2 FFHQ weights; the file loaded just below is assumed to be
# what this download produces in the working directory.
os.system("gdown https://drive.google.com/uc?id=1_cTsjqzD_X9DK3t3IZE53huKgnzj_btZ")

latent_dim = 512
original_generator = Generator(1024, latent_dim, 8, 2).to(device)
ckpt = torch.load(
    'stylegan2-ffhq-config-f.pt',
    # Return each storage unchanged, i.e. keep it where deserialization put it.
    map_location=lambda storage, loc: storage,
)
# strict=False: keys that do not match the model are silently skipped.
original_generator.load_state_dict(ckpt["g_ema"], strict=False)
mean_latent = original_generator.mean_latent(10000)

# One independent deep copy of the base generator per style; style-specific
# weights are loaded into these further down in the file.
(
    generatorjojo,
    generatordisney,
    generatorjinx,
    generatorcaitlyn,
    generatoryasuho,
) = (deepcopy(original_generator) for _ in range(5))
# Resize to 1024x1024, convert to a tensor, then shift each channel from
# [0, 1] to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])
# Encoder checkpoint (presumably the e4e encoder used by e4e_projection); a copy
# is placed under models/.
os.system("gdown https://drive.google.com/uc?id=1jtCg8HQ6RlTmLdnbT2PfW1FJ2AYkWqsK")
os.system("cp e4e_ffhq_encode.pt models/e4e_ffhq_encode.pt")


def _fetch_style_weights(url, filename, generator):
    """Download a fine-tuned checkpoint with gdown and load it into ``generator``.

    The checkpoint's ``"g"`` entry is applied with ``strict=False`` (mismatched
    keys are skipped). Returns the object produced by ``torch.load``.
    """
    os.system(f"gdown {url}")
    checkpoint = torch.load(filename, map_location=lambda storage, loc: storage)
    generator.load_state_dict(checkpoint["g"], strict=False)
    return checkpoint


ckptjojo = _fetch_style_weights(
    "https://drive.google.com/uc?id=1-8E0PFT37v5fZs-61oIrFbNpE28Unp2y",
    'jojo.pt',
    generatorjojo,
)
ckptdisney = _fetch_style_weights(
    "https://drive.google.com/uc?id=1Bnh02DjfvN_Wm8c4JdOiNV4q9J7Z_tsi",
    'disney_preserve_color.pt',
    generatordisney,
)
ckptjinx = _fetch_style_weights(
    "https://drive.google.com/uc?id=1jElwHxaYPod5Itdy18izJk49K1nl4ney",
    'arcane_jinx_preserve_color.pt',
    generatorjinx,
)
ckptcaitlyn = _fetch_style_weights(
    "https://drive.google.com/uc?id=1cUTyjU-q98P75a8THCaO545RTwpVV-aH",
    'arcane_caitlyn_preserve_color.pt',
    generatorcaitlyn,
)
ckptyasuho = _fetch_style_weights(
    "https://drive.google.com/uc?id=1SKBu1h0iRNyeKBnya_3BBmLr4pkPeg_L",
    'jojo_yasuho_preserve_color.pt',
    generatoryasuho,
)
def inference(img, model):
    """Stylize the face in ``img`` with the generator selected by ``model``.

    Args:
        img: value from the Gradio ``Image(type="filepath")`` input; passed
            unchanged to ``align_face``.
        model: 'JoJo', 'Disney', 'Jinx', 'Caitlyn' or 'Yasuho'. Any other
            value falls back to the Yasuho generator (as before).

    Returns:
        Path of the JPEG holding the stylized result ('filename.jpeg').
    """
    generators = {
        'JoJo': generatorjojo,
        'Disney': generatordisney,
        'Jinx': generatorjinx,
        'Caitlyn': generatorcaitlyn,
        'Yasuho': generatoryasuho,
    }
    generator = generators.get(model, generatoryasuho)

    aligned_face = align_face(img)
    my_w = e4e_projection(aligned_face, "test.pt", device).unsqueeze(0)
    with torch.no_grad():
        my_sample = generator(my_w, input_is_latent=True)

    # Assumes the generator emits values in roughly [-1, 1] (the same range
    # `transform` normalizes inputs to) — map that fixed range linearly to
    # uint8. Handing imageio the raw float array instead makes it rescale by
    # this image's own min/max, shifting brightness/contrast per image.
    image = my_sample[0].clamp(-1, 1).add(1).mul(127.5).round().to(torch.uint8)
    npimage = image.permute(1, 2, 0).cpu().numpy()  # CHW -> HWC for imageio

    # NOTE(review): fixed output path shared by every call; concurrent
    # requests would overwrite each other's result.
    output_path = 'filename.jpeg'
    imageio.imwrite(output_path, npimage)
    return output_path
title = "JoJoGAN"
description = "Gradio Demo for JoJoGAN: One Shot Face Stylization. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.11641' target='_blank'>JoJoGAN: One Shot Face Stylization</a>| <a href='https://github.com/mchong6/JoJoGAN' target='_blank'>Github Repo Pytorch</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_jojogan' alt='visitor badge'></center> <p style='text-align: center'>samples from repo: <img src='https://raw.githubusercontent.com/mchong6/JoJoGAN/main/teaser.jpg' alt='animation'/></p>"
examples = [['iu.jpeg', 'Jinx']]

# Build and launch the demo UI: an image file path plus a style choice go to
# `inference`, whose returned file path is shown as the output image.
# 'Yasuho' is listed so the Yasuho generator loaded at startup is actually
# selectable (previously it was only reachable via inference's fallback branch,
# which the dropdown never triggered).
gr.Interface(
    inference,
    [
        gr.inputs.Image(type="filepath"),
        gr.inputs.Dropdown(
            choices=['JoJo', 'Disney', 'Jinx', 'Caitlyn', 'Yasuho'],
            type="value",
            default='JoJo',
            label="Model",
        ),
    ],
    gr.outputs.Image(type="file"),
    title=title,
    description=description,
    article=article,
    enable_queue=True,
    allow_flagging=False,
    examples=examples,
).launch()