Spaces:
Runtime error
Runtime error
| import torch | |
| from .model import PHNet | |
| import torchvision.transforms.functional as tf | |
| from .util import inference_img, log | |
| from .stylematte import StyleMatte | |
| import numpy as np | |
class Inference:
    """Harmonize a composite image with a pretrained PHNet model.

    Configuration is passed as keyword arguments and copied verbatim onto the
    instance via ``self.__dict__.update(kwargs)``. Attributes read by this
    class:

    * ``enc_sizes``, ``skips``, ``grid_counts``, ``init_weights``,
      ``init_value`` -- forwarded to the ``PHNet`` constructor.
    * ``checkpoint`` -- object whose ``harmonizer`` attribute is the path
      passed to ``torch.load``.
    * ``device`` -- used as ``map_location`` for the checkpoint and as the
      target of ``Module.to`` / ``Tensor.to``.
    * ``image_size`` -- side length both inputs are resized to before
      inference.
    """

    def __init__(self, **kwargs):
        self.rank = 0
        self.__dict__.update(kwargs)
        # Place the network on the configured device. load_state_dict copies
        # into the existing parameters, so mapping the checkpoint to
        # self.device alone would leave the model on the CPU.
        self.model = PHNet(enc_sizes=self.enc_sizes,
                           skips=self.skips,
                           grid_count=self.grid_counts,
                           init_weights=self.init_weights,
                           init_value=self.init_value).to(self.device)
        log(f"checkpoint: {self.checkpoint.harmonizer}")
        # NOTE(review): torch.load unpickles the file -- only load trusted
        # checkpoints.
        state = torch.load(self.checkpoint.harmonizer,
                           map_location=self.device)
        self.model.load_state_dict(state, strict=True)
        self.model.eval()

    def harmonize(self, composite, mask):
        """Harmonize the masked region of ``composite``.

        Args:
            composite: image tensor; a batch dimension is prepended when it
                has fewer than 4 dims (assumed (C, H, W) -> (1, C, H, W) --
                TODO confirm against callers).
            mask: tensor; leading singleton dims are prepended until it is
                4-D. Presumably a float matte in [0, 1] -- ``1 - mask`` below
                would fail for a bool tensor; verify against callers.

        Returns:
            Tensor of spatial size ``image_size x image_size``: the model's
            ``'harmonized'`` output inside the mask, blended with the resized
            composite outside it. It is returned on the same device as the
            ``composite`` argument.
        """
        if len(composite.shape) < 4:
            composite = composite.unsqueeze(0)
        while len(mask.shape) < 4:
            mask = mask.unsqueeze(0)
        # Return the result where the caller's data lives, but run the model
        # on whatever device its weights were placed on.
        out_device = composite.device
        composite = tf.resize(composite.to(self.device),
                              [self.image_size, self.image_size])
        mask = tf.resize(mask.to(self.device),
                         [self.image_size, self.image_size])
        log(composite.shape, mask.shape)
        with torch.no_grad():
            harmonized = self.model(composite, mask)['harmonized']
        # Keep original pixels outside the mask; take model output inside it.
        result = harmonized * mask + composite * (1 - mask)
        log(result.shape)
        return result.to(out_device)
class Matting:
    """Alpha matting with a pretrained StyleMatte network.

    Keyword arguments are copied onto the instance unchanged. Attributes
    read here:

    * ``device`` -- target for ``Module.to`` and ``map_location`` for the
      checkpoint.
    * ``checkpoint`` -- object whose ``matting`` attribute is the weights
      path.
    """

    def __init__(self, **kwargs):
        self.rank = 0
        self.__dict__.update(kwargs)
        net = StyleMatte()
        self.model = net.to(self.device)
        weights_path = self.checkpoint.matting
        log(f"checkpoint: {weights_path}")
        weights = torch.load(weights_path, map_location=self.device)
        self.model.load_state_dict(weights, strict=True)
        self.model.eval()

    def extract(self, inp):
        """Compute the matte for ``inp`` and the matte-weighted foreground.

        Returns a two-element list ``[mask, fg]``, where ``mask`` is whatever
        ``inference_img`` returns (indexed below as 2-D) and ``fg`` is that
        mask multiplied into ``np.array(inp)``.
        """
        alpha = inference_img(self.model, inp, self.device)
        pixels = np.array(inp)
        # Trailing singleton axis lets the matte broadcast across channels.
        foreground = alpha[:, :, None] * pixels
        return [alpha, foreground]