Spaces:
Configuration error
Configuration error
| # from transformers import AutoModel | |
| import argparse | |
| import logging | |
| import os | |
| import glob | |
| import tqdm | |
| import torch, re | |
| import PIL | |
| import cv2 | |
| import numpy as np | |
| import torch.nn.functional as F | |
| from torchvision import transforms | |
| from utils import Config, Logger, CharsetMapper | |
| import gradio as gr | |
| #dfgdfg | |
| import gdown | |
# Fetch the checkpoint that is later loaded into the model ('lol.pth') and the
# YAML file that is later passed to Config ('config.yaml') from Google Drive.
for drive_id, target in (
    ('16PF_b4dURVkBt4OT7E-a-vq-SRxi0uDl', 'lol.pth'),
    ('19rGjfo73P25O_keQv30snfe3IHrK0uV2', 'config.yaml'),
):
    gdown.download(id=drive_id, output=target)
| # gdown.download(id='1qyNV80qmYHx_r4KsG3_8PXQ6ff1a1dov', output='modules.zip') | |
| # gdown.download(id='1UMZ7i8SpfuNw0N2JvVY8euaNx9gu3x6N', output='configs.zip') | |
| # gdown.download(id='1yHD7_4DD_keUwGs2nenAYDaQ2CNEA5IU', output='data.zip') | |
| # os.system('unzip data.zip && unzip configs.zip && unzip modules.zip') | |
def get_model(config):
    """Instantiate the class named by ``config.model_name`` and return it in eval mode.

    ``config.model_name`` is a dotted path: everything before the last dot is
    imported as a module, and the final component is looked up on that module.
    The class is called with ``config`` as its only argument, the resulting
    object is logged at INFO level, and ``.eval()`` is called on it.
    """
    import importlib

    module_path, _, class_name = config.model_name.rpartition('.')
    model_cls = getattr(importlib.import_module(module_path), class_name)
    instance = model_cls(config)
    logging.info(instance)
    return instance.eval()
def load(model, file, device=None, strict=True):
    """Load a saved checkpoint into ``model`` and return the model.

    Args:
        model: Object whose ``load_state_dict`` receives the weights.
        file: Path of the checkpoint file to read with ``torch.load``.
        device: ``None`` maps tensors to ``'cpu'``; an ``int`` selects that
            CUDA device index; any other value is passed to ``torch.load``
            as ``map_location`` unchanged.
        strict: Forwarded to ``model.load_state_dict``.

    Returns:
        The same ``model`` object after its state has been loaded.

    Raises:
        FileNotFoundError: If ``file`` is not an existing regular file.
    """
    if device is None:
        device = 'cpu'
    elif isinstance(device, int):
        device = torch.device('cuda', device)

    # Raise explicitly instead of asserting so the check survives ``python -O``.
    if not os.path.isfile(file):
        raise FileNotFoundError('checkpoint file not found: {}'.format(file))

    # NOTE(review): depending on the installed torch version's ``weights_only``
    # default, torch.load may unpickle arbitrary objects -- only load
    # checkpoints from a trusted source.
    state = torch.load(file, map_location=device)

    # A checkpoint whose keys are exactly 'model' and 'opt' wraps the weights;
    # unwrap the 'model' entry before loading.
    if set(state.keys()) == {'model', 'opt'}:
        state = state['model']
    model.load_state_dict(state, strict=strict)
    return model
# Module-level state shared by the request handler: the parsed configuration
# and the model with the downloaded weights loaded into it.
config = Config('config.yaml')
# NOTE(review): presumably cleared so that constructing the model does not try
# to read a separate vision checkpoint -- confirm against the model class.
config.model_vision_checkpoint = None
model = load(get_model(config), 'lol.pth')
def postprocess(output, charset, model_eval):
    """Greedy-decode model output into texts, per-step scores and lengths.

    Args:
        output: Either a single result mapping with a ``'logits'`` entry, or
            a tuple/list of such mappings that each also have a ``'name'``.
        charset: Object providing ``get_text(indices, padding=..., trim=...)``,
            ``null_char`` and ``max_length``.
        model_eval: ``'name'`` of the result to decode when ``output`` is a
            tuple/list (the last entry with that name is used).

    Returns:
        A ``(texts, scores, lengths)`` tuple of lists with one entry per item
        obtained by iterating over the logits: the decoded string cut at the
        first ``charset.null_char``, the per-step maximum softmax value, and
        ``min(len(text) + 1, charset.max_length)``.

    Raises:
        ValueError: If ``output`` is a tuple/list and no entry is named
            ``model_eval``.
    """
    def _select_output(last_output):
        """Return the result mapping that should be decoded."""
        if not isinstance(last_output, (tuple, list)):
            return last_output
        selected = None
        for res in last_output:
            if res['name'] == model_eval:
                selected = res  # keep scanning: the last match wins
        if selected is None:
            available = [res['name'] for res in last_output]
            raise ValueError(
                'no model output named {!r}; available: {!r}'.format(
                    model_eval, available))
        return selected

    def _decode(logit):
        """Greedy decode: take the highest-probability class at every step."""
        probs = F.softmax(logit, dim=2)
        texts, scores, lengths = [], [], []
        for item in probs:
            text = charset.get_text(item.argmax(dim=1), padding=False, trim=False)
            text = text.split(charset.null_char)[0]  # stop at the end token
            texts.append(text)
            scores.append(item.max(dim=1)[0])
            # One extra position accounts for the end token.
            lengths.append(min(len(text) + 1, charset.max_length))
        return texts, scores, lengths

    selected = _select_output(output)
    return _decode(selected['logits'])
def preprocess(img, width, height):
    """Resize an image and convert it to a normalized batch of one tensor.

    The image is converted to a NumPy array, resized with OpenCV to
    ``(width, height)``, turned into a tensor with ``ToTensor`` and given a
    leading batch dimension. Each of the three channels is then shifted and
    scaled by fixed per-channel constants (the values commonly used for
    ImageNet-pretrained models).
    """
    resized = cv2.resize(np.array(img), (width, height))
    batch = transforms.ToTensor()(resized).unsqueeze(0)
    # Shape (3, 1, 1) so the constants broadcast over height and width.
    channel_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    channel_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    return (batch - channel_mean) / channel_std
def process_image(image):
    """Run the model on one PIL image and return the decoded text.

    The image is converted to RGB, resized and normalized to the dimensions
    given in ``config``, passed through the module-level ``model``, and the
    ``'alignment'`` output is greedy-decoded. Only the text of the first
    (and only) batch item is returned.
    """
    charset = CharsetMapper(filename=config.dataset_charset_path,
                            max_length=config.dataset_max_length + 1)
    rgb = image.convert('RGB')
    batch = preprocess(rgb, config.dataset_image_width, config.dataset_image_height)
    # Inference only: disabling autograd avoids recording a graph for every
    # request, which would otherwise cost memory for no benefit.
    with torch.no_grad():
        res = model(batch)
    return postprocess(res, charset, 'alignment')[0][0]
# Gradio UI: one PIL image in, the text returned by process_image out.
# The description shown to users states what the app does.
# NOTE(review): the gr.inputs / gr.outputs namespaces are deprecated in newer
# gradio releases -- confirm the pinned gradio version before upgrading.
iface = gr.Interface(
    fn=process_image,
    inputs=gr.inputs.Image(type="pil"),
    outputs=gr.outputs.Textbox(),
    title="8kun kek",
    description="Upload an image to get the text the model reads from it.",
)
iface.launch(debug=True)