Spaces:
Runtime error
Runtime error
| import torch | |
| from kornia.morphology import dilation, closing | |
| import requests | |
| from transformers import SamModel, SamProcessor | |
print('Loading SAM...')
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Model and processor are loaded from the same checkpoint id so that the
# preprocessing matches the weights.
model = SamModel.from_pretrained("facebook/sam-vit-huge")
model = model.to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
print('DONE')
def _segment_points(image, points):
    """Run SAM on one set of point prompts and return a single 2-D boolean mask.

    ``points`` is forwarded unchanged as ``input_points`` to the module-level
    ``processor``; presumably a nested list of (x, y) pixel coordinates in the
    layout ``SamProcessor`` expects -- TODO confirm against callers.

    Only the first entry at each of the two outer levels of the post-processed
    output is used (``masks[0][0]``). Its candidate masks are intersected with
    ``.all(0)``, so a pixel is kept only if every candidate contains it.
    """
    with torch.no_grad():
        inputs = processor(image, input_points=points, return_tensors="pt").to(device)
        outputs = model(**inputs)
        # Resize the low-resolution predictions back to the original image size.
        masks = processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu(),
        )
    return masks[0][0].all(0)


def _close_mask(mask, kernel_size=3):
    """Apply a morphological closing (dilate, then erode) to a 2-D mask; return bool."""
    # kornia's morphology ops are documented for float (B, C, H, W) images.
    # Their padding and max/min arithmetic is not meant for bool tensors, so
    # convert the mask to a 0/1 float image first.
    image = mask[None, None].float()
    kernel = torch.ones(kernel_size, kernel_size, device=image.device)
    return closing(image, kernel)[0, 0].bool()


def build_mask(image, faces, hairs):
    """Build a combined face+hair mask for ``image`` using SAM point prompts.

    Args:
        image: The input image, passed as-is to ``SamProcessor``.
        faces: Point prompt(s) located on the face, passed as ``input_points``.
        hairs: Point prompt(s) located on the hair, passed as ``input_points``.

    Returns:
        A 2-D boolean tensor (on CPU). It is the union of the face and hair
        segmentations, smoothed with a 3x3 morphological closing to fill small
        gaps along the seam.
    """
    face_mask = _segment_points(image, faces)
    hair_mask = _segment_points(image, hairs)
    return _close_mask(face_mask | hair_mask)
def build_mask_multi(image, faces, hairs):
    """Build the union of per-person face+hair masks for one image.

    ``faces`` and ``hairs`` are paired element-wise with ``zip``. Each pair is
    segmented independently via ``build_mask(image, [face], [hair])``, which
    applies the closing per person. The per-person masks are then OR-ed
    together. If the two sequences differ in length, the extra entries of the
    longer one are ignored (``zip`` truncation).

    Args:
        image: The input image, passed as-is to ``build_mask``.
        faces: Sequence of face point prompts, one per person.
        hairs: Sequence of hair point prompts, one per person.

    Returns:
        A 2-D boolean tensor that is True wherever any person's mask is True.

    Raises:
        ValueError: If there is no (face, hair) pair to segment.
    """
    per_person = [build_mask(image, [face], [hair]) for face, hair in zip(faces, hairs)]
    if not per_person:
        raise ValueError("build_mask_multi needs at least one (face, hair) pair")
    # Every mask is post-processed to the same image's original size,
    # so the masks stack cleanly; pixel-wise OR across people.
    return torch.stack(per_person).any(dim=0)