Spaces · Runtime error
zejunyang committed · Commit 9667e74
1 Parent(s): 9464d6e
update
Browse files
- .gitattributes +1 -0
- NTED/NTED_module.py +101 -0
- NTED/base_function.py +434 -0
- NTED/base_module.py +115 -0
- NTED/config.py +202 -0
- NTED/demo_dataset.py +182 -0
- NTED/edge_attention_layer.py +116 -0
- NTED/extraction_distribution_model.py +62 -0
- NTED/fashion_512.yaml +129 -0
- NTED/nted_checkpoint.pt +3 -0
- NTED/op/__init__.py +2 -0
- NTED/op/conv2d_gradfix.py +227 -0
- NTED/op/fused_act.py +127 -0
- NTED/op/fused_bias_act.cpp +32 -0
- NTED/op/fused_bias_act_kernel.cu +105 -0
- NTED/op/upfirdn2d.cpp +31 -0
- NTED/op/upfirdn2d.py +209 -0
- NTED/op/upfirdn2d_kernel.cu +369 -0
- app.py +20 -8
- example/exp1.png +0 -0
- example/exp2.png +0 -0
- example/exp3.png +0 -0
- example/exp4.png +0 -0
- example/exp5.png +0 -0
- example/exp6.png +0 -0
- example/ref_img.png +3 -0
- lite_openpose/body_bbox_detector.py +179 -0
- lite_openpose/checkpoint_iter_370000.pth +3 -0
- lite_openpose/modules/__init__.py +0 -0
- lite_openpose/modules/conv.py +32 -0
- lite_openpose/modules/get_parameters.py +23 -0
- lite_openpose/modules/keypoints.py +201 -0
- lite_openpose/modules/load_state.py +32 -0
- lite_openpose/modules/loss.py +5 -0
- lite_openpose/modules/one_euro_filter.py +51 -0
- lite_openpose/modules/pose.py +118 -0
- lite_openpose/pose2d_models/__init__.py +0 -0
- lite_openpose/pose2d_models/with_mobilenet.py +123 -0
.gitattributes
CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*/ref_img.png filter=lfs diff=lfs merge=lfs -text
NTED/NTED_module.py
ADDED
@@ -0,0 +1,101 @@
import cv2
import numpy as np
import torch
import random

import mediapipe as mp
from lite_openpose.body_bbox_detector import BodyPoseEstimator
from NTED.extraction_distribution_model import Generator
from NTED.demo_dataset import DemoDataset
from NTED.base_function import accumulate
from NTED.config import Config


def set_random_seed(seed):
    r"""Set random seeds for everything.

    Args:
        seed (int): Random seed.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

class NTED():
    def __init__(self):
        super(NTED, self).__init__()

        self.openpose_module = BodyPoseEstimator('cpu')
        set_random_seed(0)
        self.opt = Config('NTED/fashion_512.yaml', is_train=False)

        net_G = Generator(**self.opt.gen.param).to('cpu')
        net_G_ema = Generator(**self.opt.gen.param).to('cpu')
        net_G_ema.eval()
        accumulate(net_G_ema, net_G, 0)

        checkpoint = torch.load('NTED/nted_checkpoint.pt', map_location=lambda storage, loc: storage)
        net_G_ema.load_state_dict(checkpoint['net_G_ema'])
        self.net_G = net_G_ema.eval()

        self.data_loader = DemoDataset()

        mp_hands = mp.solutions.hands
        self.hands = mp_hands.Hands(static_image_mode=True, max_num_hands=2, min_detection_confidence=0.1)

        self.ref_img = cv2.imread('example/ref_img.png')
        self.ref_img = cv2.resize(self.ref_img, (352, 512))

    def hand_pose_est(self, img):
        results = self.hands.process(cv2.cvtColor(cv2.flip(img, 1), cv2.COLOR_BGR2RGB))
        image_height, image_width, _ = img.shape
        pose_data = []

        if results.multi_hand_landmarks is not None:
            for hand_landmarks in results.multi_hand_landmarks:
                for joint_idx in range(21):
                    pose_data.append([image_width - hand_landmarks.landmark[joint_idx].x * image_width, hand_landmarks.landmark[joint_idx].y * image_height])
            if len(results.multi_hand_landmarks) == 2:
                if results.multi_handedness[0].classification[0].label == 'Right':
                    # swap so the left hand comes first, then the right hand
                    tmp = pose_data[:21].copy()
                    pose_data[:21] = pose_data[21:]
                    pose_data[21:] = tmp
            elif len(results.multi_hand_landmarks) == 1:
                miss_hand = [[-1, -1] for _ in range(21)]
                if results.multi_handedness[0].classification[0].label == 'Left':
                    pose_data += miss_hand
                else:
                    pose_data = miss_hand + pose_data
        else:
            for _ in range(42):
                pose_data.append([-1, -1])
        pose_data = np.array(pose_data, dtype=np.int32)

        return pose_data


    def inference(self, img):

        img = cv2.resize(img, (352, 512))

        body_pose, bbox = self.openpose_module.detect_body_pose(img.copy())

        hand_pose = self.hand_pose_est(img.copy())

        data = self.data_loader.load_item(self.ref_img, body_pose[0], hand_pose)

        output = self.net_G(
            data['reference_image'],
            data['target_skeleton'],
        )
        fake_image = output['fake_image'][0]

        fake_image = self.data_loader.tensor2im(fake_image)

        fake_image = cv2.resize(fake_image, (288, 480))

        return data['skeleton_img'], fake_image
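A minimal usage sketch of the class above (not part of the commit; it assumes the code is run from the Space's root directory, and the driver file name and example image are illustrative):

# Hypothetical driver for NTED.inference(); the Space's app.py presumably wires this up through a UI instead.
import cv2
from NTED.NTED_module import NTED

model = NTED()                                    # loads OpenPose, MediaPipe Hands and the EMA generator on CPU
frame = cv2.imread('example/exp1.png')            # any BGR image of a person
skeleton_img, fake_image = model.inference(frame)
cv2.imwrite('skeleton.png', skeleton_img)                               # drawn target skeleton on a 512x352 canvas
cv2.imwrite('result.png', cv2.cvtColor(fake_image, cv2.COLOR_RGB2BGR))  # generated image, resized to 288x480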
NTED/base_function.py
ADDED
@@ -0,0 +1,434 @@
import sys
import math

import torch
from torch import nn
from torch.nn import functional as F

from NTED.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d, conv2d_gradfix

class ExtractionOperation(nn.Module):
    def __init__(self, in_channel, num_label, match_kernel):
        super(ExtractionOperation, self).__init__()
        self.value_conv = EqualConv2d(in_channel, in_channel, match_kernel, 1, match_kernel//2, bias=True)
        self.semantic_extraction_filter = EqualConv2d(in_channel, num_label, match_kernel, 1, match_kernel//2, bias=False)

        self.softmax = nn.Softmax(dim=-1)
        self.num_label = num_label

    def forward(self, value, recoder):
        key = value
        b, c, h, w = value.shape
        key = self.semantic_extraction_filter(self.feature_norm(key))
        extraction_softmax = self.softmax(key.view(b, -1, h*w))  # bkm
        values_flatten = self.value_conv(value).view(b, -1, h*w)
        neural_textures = torch.einsum('bkm,bvm->bvk', extraction_softmax, values_flatten)
        recoder['extraction_softmax'].insert(0, extraction_softmax)
        recoder['neural_textures'].insert(0, neural_textures)
        return neural_textures, extraction_softmax

    def feature_norm(self, input_tensor):
        input_tensor = input_tensor - input_tensor.mean(dim=1, keepdim=True)
        norm = torch.norm(input_tensor, 2, 1, keepdim=True) + sys.float_info.epsilon
        out = torch.div(input_tensor, norm)
        return out

class DistributionOperation(nn.Module):
    def __init__(self, num_label, input_dim, match_kernel=3):
        super(DistributionOperation, self).__init__()
        self.semantic_distribution_filter = EqualConv2d(input_dim, num_label,
                                                        kernel_size=match_kernel,
                                                        stride=1,
                                                        padding=match_kernel//2)
        self.num_label = num_label

    def forward(self, query, extracted_feature, recoder):
        b, c, h, w = query.shape

        query = self.semantic_distribution_filter(query)
        query_flatten = query.view(b, self.num_label, -1)
        query_softmax = F.softmax(query_flatten, 1)
        values_q = torch.einsum('bkm,bkv->bvm', query_softmax, extracted_feature.permute(0, 2, 1))
        attn_out = values_q.view(b, -1, h, w)
        recoder['semantic_distribution'].append(query)
        return attn_out

class EncoderLayer(nn.Sequential):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
        bias=True,
        activate=True,
        use_extraction=False,
        num_label=None,
        match_kernel=None,
        num_extractions=2
    ):
        super().__init__()

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2
            self.blur = Blur(blur_kernel, pad=(pad0, pad1))

            stride = 2
            padding = 0

        else:
            self.blur = None
            stride = 1
            padding = kernel_size // 2

        self.conv = EqualConv2d(
            in_channel,
            out_channel,
            kernel_size,
            padding=padding,
            stride=stride,
            bias=bias and not activate,
        )

        self.activate = FusedLeakyReLU(out_channel, bias=bias) if activate else None
        self.use_extraction = use_extraction
        if self.use_extraction:
            self.extraction_operations = nn.ModuleList()
            for _ in range(num_extractions):
                self.extraction_operations.append(
                    ExtractionOperation(
                        out_channel,
                        num_label,
                        match_kernel
                    )
                )

    def forward(self, input, recoder=None):
        out = self.blur(input) if self.blur is not None else input
        out = self.conv(out)
        out = self.activate(out) if self.activate is not None else out
        if self.use_extraction:
            for extraction_operation in self.extraction_operations:
                extraction_operation(out, recoder)
        return out


class DecoderLayer(nn.Module):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        upsample=False,
        blur_kernel=[1, 3, 3, 1],
        bias=True,
        activate=True,
        use_distribution=True,
        num_label=16,
        match_kernel=3,
    ):
        super().__init__()
        if upsample:
            factor = 2
            p = (len(blur_kernel) - factor) - (kernel_size - 1)
            pad0 = (p + 1) // 2 + factor - 1
            pad1 = p // 2 + 1

            self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
            self.conv = EqualTransposeConv2d(
                in_channel,
                out_channel,
                kernel_size,
                stride=2,
                padding=0,
                bias=bias and not activate,
            )
        else:
            self.conv = EqualConv2d(
                in_channel,
                out_channel,
                kernel_size,
                stride=1,
                padding=kernel_size//2,
                bias=bias and not activate,
            )
            self.blur = None

        self.distribution_operation = DistributionOperation(
            num_label,
            out_channel,
            match_kernel=match_kernel
        ) if use_distribution else None
        self.activate = FusedLeakyReLU(out_channel, bias=bias) if activate else None
        self.use_distribution = use_distribution

    def forward(self, input, neural_texture=None, recoder=None):
        out = self.conv(input)
        out = self.blur(out) if self.blur is not None else out
        if self.use_distribution and neural_texture is not None:
            out_attn = self.distribution_operation(out, neural_texture, recoder)
            out = (out + out_attn) / math.sqrt(2)

        out = self.activate(out.contiguous()) if self.activate is not None else out

        return out

class EqualConv2d(nn.Module):
    def __init__(
        self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
    ):
        super().__init__()

        self.weight = nn.Parameter(
            torch.randn(out_channel, in_channel, kernel_size, kernel_size)
        )
        self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)

        self.stride = stride
        self.padding = padding

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channel))

        else:
            self.bias = None

    def forward(self, input):
        out = conv2d_gradfix.conv2d(
            input,
            self.weight * self.scale,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
        )

        return out

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
            f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
        )


class EqualTransposeConv2d(nn.Module):
    def __init__(
        self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
    ):
        super().__init__()

        self.weight = nn.Parameter(
            torch.randn(out_channel, in_channel, kernel_size, kernel_size)
        )
        self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)

        self.stride = stride
        self.padding = padding

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channel))

        else:
            self.bias = None

    def forward(self, input):
        weight = self.weight.transpose(0, 1)
        out = conv2d_gradfix.conv_transpose2d(
            input,
            weight * self.scale,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
        )

        return out

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
            f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
        )

class ToRGB(nn.Module):
    def __init__(
        self,
        in_channel,
        upsample=True,
        blur_kernel=[1, 3, 3, 1]
    ):
        super().__init__()

        if upsample:
            self.upsample = Upsample(blur_kernel)
        self.conv = EqualConv2d(in_channel, 3, 3, stride=1, padding=1)

    def forward(self, input, skip=None):
        out = self.conv(input)
        if skip is not None:
            skip = self.upsample(skip)
            out = out + skip
        return out


class EqualLinear(nn.Module):
    def __init__(
        self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
    ):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))

        else:
            self.bias = None

        self.activation = activation

        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
        self.lr_mul = lr_mul

    def forward(self, input):
        if self.activation:
            out = F.linear(input, self.weight * self.scale)
            out = fused_leaky_relu(out, self.bias * self.lr_mul)

        else:
            out = F.linear(
                input, self.weight * self.scale, bias=self.bias * self.lr_mul
            )

        return out

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
        )

class Upsample(nn.Module):
    def __init__(self, kernel, factor=2):
        super().__init__()

        self.factor = factor
        kernel = make_kernel(kernel) * (factor ** 2)
        self.register_buffer("kernel", kernel)

        p = kernel.shape[0] - factor

        pad0 = (p + 1) // 2 + factor - 1
        pad1 = p // 2

        self.pad = (pad0, pad1)

    def forward(self, input):
        out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)

        return out

class ResBlock(nn.Module):
    def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        self.conv1 = ConvLayer(in_channel, in_channel, 3)
        self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)

        self.skip = ConvLayer(
            in_channel, out_channel, 1, downsample=True, activate=False, bias=False
        )

    def forward(self, input):
        out = self.conv1(input)
        out = self.conv2(out)

        skip = self.skip(input)
        out = (out + skip) / math.sqrt(2)

        return out

class ConvLayer(nn.Sequential):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
        bias=True,
        activate=True,
    ):
        layers = []

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2

            layers.append(Blur(blur_kernel, pad=(pad0, pad1)))

            stride = 2
            self.padding = 0

        else:
            stride = 1
            self.padding = kernel_size // 2

        layers.append(
            EqualConv2d(
                in_channel,
                out_channel,
                kernel_size,
                padding=self.padding,
                stride=stride,
                bias=bias and not activate,
            )
        )

        if activate:
            layers.append(FusedLeakyReLU(out_channel, bias=bias))

        super().__init__(*layers)


class Blur(nn.Module):
    def __init__(self, kernel, pad, upsample_factor=1):
        super().__init__()

        kernel = make_kernel(kernel)

        if upsample_factor > 1:
            kernel = kernel * (upsample_factor ** 2)

        self.register_buffer("kernel", kernel)

        self.pad = pad

    def forward(self, input):
        out = upfirdn2d(input, self.kernel, pad=self.pad)

        return out


def make_kernel(k):
    k = torch.tensor(k, dtype=torch.float32)

    if k.ndim == 1:
        k = k[None, :] * k[:, None]

    k /= k.sum()

    return k

def accumulate(model1, model2, decay=0.999):
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())

    for k in par1.keys():
        par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)
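A small shape sketch (not from the commit; tensor sizes are arbitrary, and it assumes the NTED/op kernels fall back to CPU as the demo itself does) of the extraction/distribution pair defined above: ExtractionOperation pools a reference feature map into one vector per semantic label, and DistributionOperation spreads those vectors back over a query feature map.

import collections
import torch
from NTED.base_function import ExtractionOperation, DistributionOperation

recoder = collections.defaultdict(list)           # the modules log their softmax maps and textures here
extract = ExtractionOperation(in_channel=64, num_label=16, match_kernel=3)
distribute = DistributionOperation(num_label=16, input_dim=64, match_kernel=3)

ref_feat = torch.randn(2, 64, 32, 32)             # reference-image features (b, c, h, w)
tgt_feat = torch.randn(2, 64, 32, 32)             # skeleton-branch features

textures, softmax = extract(ref_feat, recoder)    # textures: (2, 64, 16), one 64-d vector per label
out = distribute(tgt_feat, textures, recoder)     # (2, 64, 32, 32), textures redistributed over the target layout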
NTED/base_module.py
ADDED
@@ -0,0 +1,115 @@
import math
import functools
import sys

import torch
import torch.nn as nn

from NTED.base_function import EncoderLayer, DecoderLayer, ToRGB
from NTED.edge_attention_layer import Edge_Attn

class Encoder(nn.Module):
    def __init__(
        self,
        size,
        input_dim,
        channels,
        num_labels=None,
        match_kernels=None,
        blur_kernel=[1, 3, 3, 1],
    ):
        super().__init__()
        self.first = EncoderLayer(input_dim, channels[size], 1)
        self.convs = nn.ModuleList()

        log_size = int(math.log(size, 2))
        self.log_size = log_size

        in_channel = channels[size]
        for i in range(log_size-1, 3, -1):
            out_channel = channels[2 ** i]
            num_label = num_labels[2 ** i] if num_labels is not None else None
            match_kernel = match_kernels[2 ** i] if match_kernels is not None else None
            use_extraction = num_label and match_kernel
            conv = EncoderLayer(
                in_channel,
                out_channel,
                kernel_size=3,
                downsample=True,
                blur_kernel=blur_kernel,
                use_extraction=use_extraction,
                num_label=num_label,
                match_kernel=match_kernel
            )

            self.convs.append(conv)
            in_channel = out_channel

    def forward(self, input, recoder=None):
        out = self.first(input)
        for idx, layer in enumerate(self.convs):
            out = layer(out, recoder)
        return out

class Decoder(nn.Module):
    def __init__(
        self,
        size,
        channels,
        num_labels,
        match_kernels,
        blur_kernel=[1, 3, 3, 1],
    ):
        super().__init__()

        self.convs = nn.ModuleList()
        # input at resolution 16*16
        in_channel = channels[16]
        self.log_size = int(math.log(size, 2))

        for i in range(4, self.log_size + 1):
            out_channel = channels[2 ** i]
            num_label, match_kernel = num_labels[2 ** i], match_kernels[2 ** i]
            use_distribution = num_label and match_kernel
            upsample = (i != 4)

            base_layer = functools.partial(
                DecoderLayer,
                out_channel=out_channel,
                kernel_size=3,
                blur_kernel=blur_kernel,
                use_distribution=use_distribution,
                num_label=num_label,
                match_kernel=match_kernel
            )

            up = nn.Module()
            up.conv0 = base_layer(in_channel=in_channel, upsample=upsample)
            up.conv1 = base_layer(in_channel=out_channel, upsample=False)
            up.to_rgb = ToRGB(out_channel, upsample=upsample)
            self.convs.append(up)
            in_channel = out_channel

        self.num_labels, self.match_kernels = num_labels, match_kernels

        self.edge_attn_block = Edge_Attn(in_channels=3)

    def forward(self, input, neural_textures, recoder):
        counter = 0
        out, skip = input, None
        for i, up in enumerate(self.convs):
            if self.num_labels[2**(i+4)] and self.match_kernels[2**(i+4)]:
                neural_texture_conv0 = neural_textures[counter]
                neural_texture_conv1 = neural_textures[counter+1]
                counter += 2
            else:
                neural_texture_conv0, neural_texture_conv1 = None, None
            out = up.conv0(out, neural_texture=neural_texture_conv0, recoder=recoder)
            out = up.conv1(out, neural_texture=neural_texture_conv1, recoder=recoder)

            skip = up.to_rgb(out, skip)
        image = self.edge_attn_block(skip)
        # image = skip
        return image
NTED/config.py
ADDED
@@ -0,0 +1,202 @@
import collections
import functools
import os
import re

import yaml

class AttrDict(dict):
    """Dict as attribute trick."""

    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self
        for key, value in self.__dict__.items():
            if isinstance(value, dict):
                self.__dict__[key] = AttrDict(value)
            elif isinstance(value, (list, tuple)):
                if isinstance(value[0], dict):
                    self.__dict__[key] = [AttrDict(item) for item in value]
                else:
                    self.__dict__[key] = value

    def yaml(self):
        """Convert object to yaml dict and return."""
        yaml_dict = {}
        for key, value in self.__dict__.items():
            if isinstance(value, AttrDict):
                yaml_dict[key] = value.yaml()
            elif isinstance(value, list):
                if isinstance(value[0], AttrDict):
                    new_l = []
                    for item in value:
                        new_l.append(item.yaml())
                    yaml_dict[key] = new_l
                else:
                    yaml_dict[key] = value
            else:
                yaml_dict[key] = value
        return yaml_dict

    def __repr__(self):
        """Print all variables."""
        ret_str = []
        for key, value in self.__dict__.items():
            if isinstance(value, AttrDict):
                ret_str.append('{}:'.format(key))
                child_ret_str = value.__repr__().split('\n')
                for item in child_ret_str:
                    ret_str.append('    ' + item)
            elif isinstance(value, list):
                if isinstance(value[0], AttrDict):
                    ret_str.append('{}:'.format(key))
                    for item in value:
                        # Treat as AttrDict above.
                        child_ret_str = item.__repr__().split('\n')
                        for item in child_ret_str:
                            ret_str.append('    ' + item)
                else:
                    ret_str.append('{}: {}'.format(key, value))
            else:
                ret_str.append('{}: {}'.format(key, value))
        return '\n'.join(ret_str)


class Config(AttrDict):
    r"""Configuration class. This should include every human specifiable
    hyperparameter values for your training."""

    def __init__(self, filename=None, verbose=False, is_train=True):
        super(Config, self).__init__()
        # Set default parameters.
        # Logging.

        large_number = 1000000000
        self.snapshot_save_iter = large_number
        self.snapshot_save_epoch = large_number
        self.snapshot_save_start_iter = 0
        self.snapshot_save_start_epoch = 0
        self.image_save_iter = large_number
        self.eval_epoch = large_number
        self.start_eval_epoch = large_number
        self.eval_epoch = large_number
        self.max_epoch = large_number
        self.max_iter = large_number
        self.logging_iter = 100
        self.image_to_tensorboard = False
        self.which_iter = None
        self.resume = True

        self.checkpoints_dir = 'NTED'
        self.name = 'nted_checkpoint.pt'
        self.phase = 'train' if is_train else 'test'

        # Networks.
        self.gen = AttrDict(type='generators.dummy')
        self.dis = AttrDict(type='discriminators.dummy')

        # Optimizers.
        self.gen_optimizer = AttrDict(type='adam',
                                      lr=0.0001,
                                      adam_beta1=0.0,
                                      adam_beta2=0.999,
                                      eps=1e-8,
                                      lr_policy=AttrDict(iteration_mode=False,
                                                         type='step',
                                                         step_size=large_number,
                                                         gamma=1))
        self.dis_optimizer = AttrDict(type='adam',
                                      lr=0.0001,
                                      adam_beta1=0.0,
                                      adam_beta2=0.999,
                                      eps=1e-8,
                                      lr_policy=AttrDict(iteration_mode=False,
                                                         type='step',
                                                         step_size=large_number,
                                                         gamma=1))
        # Data.
        self.data = AttrDict(name='dummy',
                             type='datasets.images',
                             num_workers=0)
        self.test_data = AttrDict(name='dummy',
                                  type='datasets.images',
                                  num_workers=0,
                                  test=AttrDict(is_lmdb=False,
                                                roots='',
                                                batch_size=1))
        self.trainer = AttrDict(
            image_to_tensorboard=False,
            hparam_to_tensorboard=False)

        # Cudnn.
        self.cudnn = AttrDict(deterministic=False,
                              benchmark=True)

        # Others.
        self.pretrained_weight = ''
        self.inference_args = AttrDict()

        # Update with given configurations.
        assert os.path.exists(filename), 'File {} not exist.'.format(filename)
        loader = yaml.SafeLoader
        loader.add_implicit_resolver(
            u'tag:yaml.org,2002:float',
            re.compile(u'''^(?:
             [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
            |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
            |\\.[0-9_]+(?:[eE][-+][0-9]+)?
            |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
            |[-+]?\\.(?:inf|Inf|INF)
            |\\.(?:nan|NaN|NAN))$''', re.X),
            list(u'-+0123456789.'))
        try:
            with open(filename, 'r') as f:
                cfg_dict = yaml.load(f, Loader=loader)
        except EnvironmentError:
            print('Please check the file with name of "%s"', filename)
        recursive_update(self, cfg_dict)

        # Put common opts in both gen and dis.
        if 'common' in cfg_dict:
            self.common = AttrDict(**cfg_dict['common'])
            self.gen.common = self.common
            self.dis.common = self.common

        if verbose:
            print(' config '.center(80, '-'))
            print(self.__repr__())
            print(''.center(80, '-'))


def rsetattr(obj, attr, val):
    """Recursively find object and set value"""
    pre, _, post = attr.rpartition('.')
    return setattr(rgetattr(obj, pre) if pre else obj, post, val)


def rgetattr(obj, attr, *args):
    """Recursively find object and return value"""

    def _getattr(obj, attr):
        r"""Get attribute."""
        return getattr(obj, attr, *args)

    return functools.reduce(_getattr, [obj] + attr.split('.'))


def recursive_update(d, u):
    """Recursively update AttrDict d with AttrDict u"""
    for key, value in u.items():
        if isinstance(value, collections.abc.Mapping):
            d.__dict__[key] = recursive_update(d.get(key, AttrDict({})), value)
        elif isinstance(value, (list, tuple)):
            if isinstance(value[0], dict):
                d.__dict__[key] = [AttrDict(item) for item in value]
            else:
                d.__dict__[key] = value
        else:
            d.__dict__[key] = value
    return d
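A quick sketch (assuming it is run from the repository root, where NTED/fashion_512.yaml exists) of how Config exposes the YAML file as nested attributes; this is exactly how NTED_module.py builds its opt object:

from NTED.config import Config

opt = Config('NTED/fashion_512.yaml', is_train=False)
print(opt.phase)                     # 'test'
print(opt.gen.type)                  # 'NTED.extraction_distribution_model::Generator'
print(opt.gen.param.size)            # 512
print(dict(opt.gen.param.channels))  # {16: 512, 32: 512, 64: 512, 128: 256, 256: 128, 512: 64, 1024: 32}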
NTED/demo_dataset.py
ADDED
@@ -0,0 +1,182 @@
import os
import cv2
import math
import numpy as np
from PIL import Image

import torch
import torchvision.transforms.functional as F

class DemoDataset(object):
    def __init__(self):
        super().__init__()
        self.LIMBSEQ = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
                        [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
                        [1, 16], [16, 18], [3, 17], [6, 18]]

        self.COLORS = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
                       [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
                       [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]

        self.LIMBSEQ_hands = [[0, 1], [1, 2], [2, 3], [3, 4], \
                              [0, 5], [5, 6], [6, 7], [7, 8], \
                              [0, 9], [9, 10], [10, 11], [11, 12], \
                              [0, 13], [13, 14], [14, 15], [15, 16], \
                              [0, 17], [17, 18], [18, 19], [19, 20], \
                              [21, 22], [22, 23], [23, 24], [24, 25], \
                              [21, 26], [26, 27], [27, 28], [28, 29], \
                              [21, 30], [30, 31], [31, 32], [32, 33], \
                              [21, 34], [34, 35], [35, 36], [36, 37], \
                              [21, 38], [38, 39], [39, 40], [40, 41]]

        self.COLORS_hands = [[85, 0, 0], [170, 0, 0], [85, 85, 0], [85, 170, 0], [170, 85, 0], [170, 170, 0], [85, 85, 85], \
                             [85, 85, 170], [85, 170, 85], [85, 170, 170], [0, 85, 0], [0, 170, 0], [0, 85, 85], [0, 85, 170], \
                             [0, 170, 85], [0, 170, 170], [50, 0, 0], [135, 0, 0], [50, 50, 0], [50, 135, 0], [135, 50, 0], \
                             [135, 135, 0], [50, 50, 50], [50, 50, 135], [50, 135, 50], [50, 135, 135], [0, 50, 0], [0, 135, 0], \
                             [0, 50, 50], [0, 50, 135], [0, 135, 50], [0, 135, 135], [100, 0, 0], [200, 0, 0], [100, 100, 0], \
                             [100, 200, 0], [200, 100, 0], [200, 200, 0], [100, 100, 100], [100, 100, 200], [100, 200, 100], [100, 200, 200]
                             ]

        self.img_size = tuple([512, 352])

    def load_item(self, img, pose, handpose=None):

        reference_img = self.get_image_tensor(img)[None, :]
        label, ske = self.get_label_tensor(pose, handpose)
        label = label[None, :]

        return {'reference_image': reference_img, 'target_skeleton': label, 'skeleton_img': ske}

    def get_image_tensor(self, bgr_img):
        img = Image.fromarray(cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB))
        img = F.resize(img, self.img_size)
        img = F.to_tensor(img)
        img = F.normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        return img

    def get_label_tensor(self, pose, hand_pose=None):
        canvas = np.zeros((self.img_size[0], self.img_size[1], 3)).astype(np.uint8)
        keypoint = np.array(pose)
        if hand_pose is not None:
            keypoint_hands = np.array(hand_pose)
        else:
            keypoint_hands = None

        # keypoint = self.trans_keypoins(keypoint)

        stickwidth = 4
        for i in range(18):
            x, y = keypoint[i, 0:2]
            if x == -1 or y == -1:
                continue
            cv2.circle(canvas, (int(x), int(y)), 4, self.COLORS[i], thickness=-1)
        if keypoint_hands is not None:
            for i in range(42):
                x, y = keypoint_hands[i, 0:2]
                if x == -1 or y == -1:
                    continue
                cv2.circle(canvas, (int(x), int(y)), 4, self.COLORS_hands[i], thickness=-1)

        joints = []
        for i in range(17):
            Y = keypoint[np.array(self.LIMBSEQ[i])-1, 0]
            X = keypoint[np.array(self.LIMBSEQ[i])-1, 1]
            cur_canvas = canvas.copy()
            if -1 in Y or -1 in X:
                joints.append(np.zeros_like(cur_canvas[:, :, 0]))
                continue
            mX = np.mean(X)
            mY = np.mean(Y)
            length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
            angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
            polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
            cv2.fillConvexPoly(cur_canvas, polygon, self.COLORS[i])
            canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)

            joint = np.zeros_like(cur_canvas[:, :, 0])
            cv2.fillConvexPoly(joint, polygon, 255)
            joint = cv2.addWeighted(joint, 0.4, joint, 0.6, 0)
            joints.append(joint)
        if keypoint_hands is not None:
            for i in range(40):
                Y = keypoint_hands[np.array(self.LIMBSEQ_hands[i]), 0]
                X = keypoint_hands[np.array(self.LIMBSEQ_hands[i]), 1]
                cur_canvas = canvas.copy()
                if -1 in Y or -1 in X:
                    if (i+1) % 4 == 0:
                        joints.append(np.zeros_like(cur_canvas[:, :, 0]))
                    continue
                mX = np.mean(X)
                mY = np.mean(Y)
                length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
                angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
                polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), int(stickwidth/2)), int(angle), 0, 360, 1)
                cv2.fillConvexPoly(cur_canvas, polygon, self.COLORS_hands[i])
                canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)

                # one distance-map channel per finger
                if i % 4 == 0:
                    joint = np.zeros_like(cur_canvas[:, :, 0])
                cv2.fillConvexPoly(joint, polygon, 255)
                joint = cv2.addWeighted(joint, 0.4, joint, 0.6, 0)
                if (i+1) % 4 == 0:
                    joints.append(joint)

        pose = F.to_tensor(Image.fromarray(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)))

        tensors_dist = 0
        e = 1
        for i in range(len(joints)):
            im_dist = cv2.distanceTransform(255-joints[i], cv2.DIST_L1, 3)
            im_dist = np.clip((im_dist / 3), 0, 255).astype(np.uint8)
            tensor_dist = F.to_tensor(Image.fromarray(im_dist))
            tensors_dist = tensor_dist if e == 1 else torch.cat([tensors_dist, tensor_dist])
            e += 1

        label_tensor = torch.cat((pose, tensors_dist), dim=0)

        return label_tensor, canvas

    def tensor2im(self, image_tensor, imtype=np.uint8, normalize=True,
                  three_channel_output=True):
        r"""Convert tensor to image.

        Args:
            image_tensor (torch.tensor or list of torch.tensor): If tensor then
                (NxCxHxW) or (NxTxCxHxW) or (CxHxW).
            imtype (np.dtype): Type of output image.
            normalize (bool): Is the input image normalized or not?
            three_channel_output (bool): Should single channel images be made 3
                channel in output?

        Returns:
            (numpy.ndarray, list if case 1, 2 above).
        """
        if image_tensor is None:
            return None
        if isinstance(image_tensor, list):
            return [self.tensor2im(x, imtype, normalize) for x in image_tensor]
        if image_tensor.dim() == 5 or image_tensor.dim() == 4:
            return [self.tensor2im(image_tensor[idx], imtype, normalize)
                    for idx in range(image_tensor.size(0))]

        if image_tensor.dim() == 3:
            image_numpy = image_tensor.detach().float().numpy()
            if normalize:
                image_numpy = (np.transpose(
                    image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
            else:
                image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0
            image_numpy = np.clip(image_numpy, 0, 255)
            if image_numpy.shape[2] == 1 and three_channel_output:
                image_numpy = np.repeat(image_numpy, 3, axis=2)
            elif image_numpy.shape[2] > 3:
                image_numpy = image_numpy[:, :, :3]
            return image_numpy.astype(imtype)

    def trans_keypoins(self, keypoints):
        missing_keypoint_index = keypoints == -1

        keypoints[missing_keypoint_index] = -1
        return keypoints
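A rough shape check of load_item (hypothetical inputs, with every keypoint marked missing): the skeleton label stacks 3 RGB channels, 17 body-limb distance maps and 10 per-finger distance maps, which appears to be where semantic_dim: 30 in fashion_512.yaml comes from.

import numpy as np
import cv2
from NTED.demo_dataset import DemoDataset

dataset = DemoDataset()
ref_img = cv2.imread('example/ref_img.png')        # BGR reference person image
body_pose = np.full((18, 2), -1, dtype=np.int32)   # -1 marks a missing keypoint
hand_pose = np.full((42, 2), -1, dtype=np.int32)   # 21 left-hand + 21 right-hand joints
data = dataset.load_item(ref_img, body_pose, hand_pose)
print(data['reference_image'].shape)   # torch.Size([1, 3, 512, 352])
print(data['target_skeleton'].shape)   # torch.Size([1, 30, 512, 352])
print(data['skeleton_img'].shape)      # (512, 352, 3) numpy canvas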
NTED/edge_attention_layer.py
ADDED
@@ -0,0 +1,116 @@
# Date: 2023-03-14
# Creator: zejunyang
# Function: edge attention layer.

import torch
import torch.nn as nn
import torch.nn.functional as F

from NTED.base_function import Blur


class ResBlock(nn.Module):
    def __init__(self, in_nc, out_nc, scale='down'):  # , norm_layer=nn.BatchNorm2d
        super(ResBlock, self).__init__()
        use_bias = True
        assert scale in ['up', 'down', 'same'], "ResBlock scale must be in 'up' 'down' 'same'"

        if scale == 'same':
            # self.scale = nn.Conv2d(in_nc, out_nc, kernel_size=1, bias=True)
            self.scale = nn.Conv2d(in_nc, out_nc, kernel_size=3, stride=1, padding=1, bias=True)
        if scale == 'up':
            self.scale = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear'),
                nn.Conv2d(in_nc, out_nc, kernel_size=1, bias=True)
            )
        if scale == 'down':
            self.scale = nn.Conv2d(in_nc, out_nc, kernel_size=3, stride=2, padding=1, bias=use_bias)

        self.block = nn.Sequential(
            nn.Conv2d(out_nc, out_nc, kernel_size=3, stride=1, padding=1, bias=use_bias),
            # norm_layer(out_nc),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_nc, out_nc, kernel_size=3, stride=1, padding=1, bias=use_bias),
            # norm_layer(out_nc)
        )
        self.relu = nn.ReLU(inplace=True)
        # self.padding = nn.ReplicationPad2d(padding=(0, 1, 0, 0))

    def forward(self, x):
        residual = self.scale(x)
        return self.relu(residual + self.block(residual))


class Edge_Attn(nn.Module):
    def __init__(self, in_channels=3):
        super(Edge_Attn, self).__init__()
        self.in_channels = in_channels

        blur_kernel = [1, 3, 3, 3, 1]
        self.blur = Blur(blur_kernel, pad=(2, 2), upsample_factor=1)

        # self.conv = nn.Conv2d(self.in_channels, self.in_channels, 3, padding=1, bias=False)
        self.res_block = ResBlock(self.in_channels, self.in_channels, scale='same')
        self.sigmoid = nn.Sigmoid()

    def gradient(self, x):
        h_x = x.size()[2]
        w_x = x.size()[3]
        stride = 3
        r = F.pad(x, (0, stride, 0, 0), mode='replicate')[:, :, :, stride:]
        l = F.pad(x, (stride, 0, 0, 0), mode='replicate')[:, :, :, :w_x]
        t = F.pad(x, (0, 0, stride, 0), mode='replicate')[:, :, :h_x, :]
        b = F.pad(x, (0, 0, 0, stride), mode='replicate')[:, :, stride:, :]
        xgrad = torch.pow(torch.pow((r - l) * 0.5, 2) + torch.pow((t - b) * 0.5, 2), 0.5)
        xgrad = self.blur(xgrad)
        return xgrad

    def forward(self, x):
        # feature_edge = self.gradient(x).detach()
        # attn = self.conv(feature_edge)

        for b in range(x.shape[0]):
            for c in range(x.shape[1]):
                if c == 0:
                    channel_edge = self.gradient(x[b:b+1, c:c+1])
                else:
                    channel_edge = torch.concat([channel_edge, self.gradient(x[b:b+1, c:c+1])], dim=1)
            if b == 0:
                feature_edge = channel_edge
            else:
                feature_edge = torch.concat([feature_edge, channel_edge], dim=0)
        feature_edge = feature_edge.detach()
        feature_edge = x * feature_edge
        attn = self.res_block(feature_edge)
        attn = self.sigmoid(attn)

        # out = x * attn

        out = x * attn + x

        return out


if __name__ == '__main__':
    from PIL import Image
    import numpy as np
    import cv2

    edg_atten = Edge_Attn()

    im = Image.open('/apdcephfs/share_1474453/zejunzhang/dataset/pose_trans_dataset/fake_images/001400.png')
    npim = np.array(im, dtype=np.float32)
    npim = cv2.cvtColor(npim, cv2.COLOR_RGB2GRAY)

    # npim = npim[:, :, 2]
    tim = torch.from_numpy(npim).unsqueeze_(0).unsqueeze_(0)
    edge = edg_atten.gradient(tim)
    npgrad = edge.squeeze(0).squeeze(0).data.clamp(0, 255).numpy()
    Image.fromarray(npgrad.astype('uint8')).save('tmp.png')

    # tim = torch.from_numpy(npim).unsqueeze_(0)
    # edge = edg_atten.gradient_1order(tim)
    # npgrad = edge.squeeze(0).data.clamp(0,255).numpy()[:, :, 0]
    # Image.fromarray(npgrad.astype('uint8')).save('tmp.png')
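A toy forward pass through the edge-attention block (hypothetical tensor sizes; it assumes the NTED/op kernels fall back to CPU, as the demo itself runs on CPU): the per-channel image gradient gates the features and the gated result is added back as a residual.

import torch
from NTED.edge_attention_layer import Edge_Attn

edge_attn = Edge_Attn(in_channels=3).eval()
x = torch.randn(1, 3, 64, 64)         # e.g. the RGB map produced by the decoder's ToRGB skips
with torch.no_grad():
    y = edge_attn(x)
print(y.shape)                         # torch.Size([1, 3, 64, 64]), same size as the input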
NTED/extraction_distribution_model.py
ADDED
@@ -0,0 +1,62 @@
import collections
from torch import nn
from NTED.base_module import Encoder, Decoder

from torch.cuda.amp import autocast as autocast

class Generator(nn.Module):
    def __init__(
        self,
        size,
        semantic_dim,
        channels,
        num_labels,
        match_kernels,
        blur_kernel=[1, 3, 3, 1],
    ):
        super().__init__()
        self.size = size
        self.reference_encoder = Encoder(
            size, 3, channels, num_labels, match_kernels, blur_kernel
        )

        self.skeleton_encoder = Encoder(
            size, semantic_dim, channels,
        )

        self.target_image_renderer = Decoder(
            size, channels, num_labels, match_kernels, blur_kernel
        )

    def _cal_temp(self, module):
        return sum(p.numel() for p in module.parameters() if p.requires_grad)

    def forward(
        self,
        source_image,
        skeleton,
        amp_flag=False,
    ):
        if amp_flag:
            with autocast():
                output_dict = {}
                recoder = collections.defaultdict(list)
                skeleton_feature = self.skeleton_encoder(skeleton)
                _ = self.reference_encoder(source_image, recoder)
                neural_textures = recoder["neural_textures"]
                output_dict['fake_image'] = self.target_image_renderer(
                    skeleton_feature, neural_textures, recoder
                )
                output_dict['info'] = recoder
                return output_dict
        else:
            output_dict = {}
            recoder = collections.defaultdict(list)
            skeleton_feature = self.skeleton_encoder(skeleton)
            _ = self.reference_encoder(source_image, recoder)
            neural_textures = recoder["neural_textures"]
            output_dict['fake_image'] = self.target_image_renderer(
                skeleton_feature, neural_textures, recoder
            )
            output_dict['info'] = recoder
            return output_dict
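An end-to-end sketch of the generator (hypothetical tensors; channel/label settings copied from fashion_512.yaml; CPU fallbacks for NTED/op assumed). In practice NTED_module.py builds the same model from the YAML via Config and loads nted_checkpoint.pt into its EMA copy.

import torch
from NTED.extraction_distribution_model import Generator

channels      = {16: 512, 32: 512, 64: 512, 128: 256, 256: 128, 512: 64, 1024: 32}
num_labels    = {16: 16, 32: 32, 64: 32, 128: 64, 256: 64, 512: False}
match_kernels = {16: 1, 32: 3, 64: 3, 128: 3, 256: 3, 512: False}

net_G = Generator(size=512, semantic_dim=30, channels=channels,
                  num_labels=num_labels, match_kernels=match_kernels).eval()

reference = torch.randn(1, 3, 512, 352)      # reference appearance image
skeleton  = torch.randn(1, 30, 512, 352)     # 30-channel target skeleton from DemoDataset
with torch.no_grad():
    out = net_G(reference, skeleton)
print(out['fake_image'].shape)               # torch.Size([1, 3, 512, 352])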
NTED/fashion_512.yaml
ADDED
@@ -0,0 +1,129 @@
| 1 |
+
distributed: True
|
| 2 |
+
image_to_tensorboard: True
|
| 3 |
+
snapshot_save_iter: 50000
|
| 4 |
+
snapshot_save_epoch: 20
|
| 5 |
+
snapshot_save_start_iter: 20000
|
| 6 |
+
snapshot_save_start_epoch: 100
|
| 7 |
+
image_save_iter: 1000
|
| 8 |
+
max_epoch: 400
|
| 9 |
+
logging_iter: 100
|
| 10 |
+
amp: False
|
| 11 |
+
|
| 12 |
+
gen_optimizer:
|
| 13 |
+
type: adam
|
| 14 |
+
lr: 0.002
|
| 15 |
+
adam_beta1: 0.
|
| 16 |
+
adam_beta2: 0.99
|
| 17 |
+
lr_policy:
|
| 18 |
+
iteration_mode: False
|
| 19 |
+
type: step
|
| 20 |
+
step_size: 1000000
|
| 21 |
+
gamma: 1
|
| 22 |
+
|
| 23 |
+
dis_optimizer:
|
| 24 |
+
type: adam
|
| 25 |
+
lr: 0.001882
|
| 26 |
+
adam_beta1: 0.
|
| 27 |
+
adam_beta2: 0.9905
|
| 28 |
+
lr_policy:
|
| 29 |
+
iteration_mode: False
|
| 30 |
+
type: step
|
| 31 |
+
step_size: 1000000
|
| 32 |
+
gamma: 1
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
trainer:
|
| 36 |
+
type: NTED.extraction_distribution_trainer::Trainer
|
| 37 |
+
gan_mode: style_gan2
|
| 38 |
+
gan_start_iteration: 1000 # 0
|
| 39 |
+
face_crop_method: util.face_crop::crop_face_from_output
|
| 40 |
+
hand_crop_method: util.face_crop::crop_hands_from_output
|
| 41 |
+
d_reg_every: 16
|
| 42 |
+
r1: 10
|
| 43 |
+
loss_weight:
|
| 44 |
+
weight_perceptual: 1
|
| 45 |
+
weight_gan: 1.5
|
| 46 |
+
weight_attn_rec: 15
|
| 47 |
+
weight_face: 1
|
| 48 |
+
weight_hand: 1
|
| 49 |
+
weight_l1: 1
|
| 50 |
+
weight_l1_hand: 0.8
|
| 51 |
+
weight_edge: 100
|
| 52 |
+
attn_weights:
|
| 53 |
+
8: 1
|
| 54 |
+
16: 1
|
| 55 |
+
32: 1
|
| 56 |
+
64: 1
|
| 57 |
+
128: 1
|
| 58 |
+
256: 1
|
| 59 |
+
vgg_param:
|
| 60 |
+
network: vgg19
|
| 61 |
+
layers: ['relu_1_1', 'relu_2_1', 'relu_3_1', 'relu_4_1', 'relu_5_1']
|
| 62 |
+
num_scales: 3
|
| 63 |
+
use_style_loss: True
|
| 64 |
+
style_to_perceptual: 1000
|
| 65 |
+
vgg_hand_param:
|
| 66 |
+
network: vgg19
|
| 67 |
+
layers: ['relu_1_1', 'relu_2_1', 'relu_3_1','relu_3_3', 'relu_4_1', 'relu_4_3', 'relu_5_1']
|
| 68 |
+
|
| 69 |
+
gen:
|
| 70 |
+
type: NTED.extraction_distribution_model::Generator
|
| 71 |
+
param:
|
| 72 |
+
size: 512
|
| 73 |
+
semantic_dim: 30
|
| 74 |
+
channels:
|
| 75 |
+
16: 512
|
| 76 |
+
32: 512
|
| 77 |
+
64: 512
|
| 78 |
+
128: 256
|
| 79 |
+
256: 128
|
| 80 |
+
512: 64
|
| 81 |
+
1024: 32
|
| 82 |
+
num_labels:
|
| 83 |
+
16: 16
|
| 84 |
+
32: 32
|
| 85 |
+
64: 32
|
| 86 |
+
128: 64
|
| 87 |
+
256: 64
|
| 88 |
+
512: False
|
| 89 |
+
match_kernels:
|
| 90 |
+
16: 1
|
| 91 |
+
32: 3
|
| 92 |
+
64: 3
|
| 93 |
+
128: 3
|
| 94 |
+
256: 3
|
| 95 |
+
512: False
|
| 96 |
+
|
| 97 |
+
dis:
|
| 98 |
+
type: generators.discriminator::Discriminator
|
| 99 |
+
param:
|
| 100 |
+
size: 512
|
| 101 |
+
channels:
|
| 102 |
+
4: 512
|
| 103 |
+
8: 512
|
| 104 |
+
16: 512
|
| 105 |
+
32: 512
|
| 106 |
+
64: 512
|
| 107 |
+
128: 256
|
| 108 |
+
256: 128
|
| 109 |
+
512: 64
|
| 110 |
+
is_square_image: False
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
data:
|
| 114 |
+
type: data.fashion_data::Dataset
|
| 115 |
+
preprocess_mode: resize_and_crop # resize_and_crop
|
| 116 |
+
path: /apdcephfs/share_1474453/zejunzhang/dataset/pose_trans_dataset_2d
|
| 117 |
+
num_workers: 16
|
| 118 |
+
sub_path: 512-352
|
| 119 |
+
resolution: 512
|
| 120 |
+
scale_param: 0.1
|
| 121 |
+
train:
|
| 122 |
+
batch_size: 4 # real_batch_size: 2 * 2 (source-->target & target --> source) * 4 (GPUs) = 16
|
| 123 |
+
distributed: True
|
| 124 |
+
val:
|
| 125 |
+
batch_size: 4
|
| 126 |
+
distributed: True
|
| 127 |
+
hand_keypoint: True
|
| 128 |
+
|
| 129 |
+
|
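The config above is consumed through NTED.config.Config (imported in NTED_module.py); as a quick structural check that needs no assumptions about that wrapper, the file can also be read with plain PyYAML:

# Sketch only, not part of the commit: inspecting fashion_512.yaml with PyYAML.
import yaml

with open('NTED/fashion_512.yaml') as f:
    cfg = yaml.safe_load(f)

gen_param = cfg['gen']['param']
print(gen_param['size'])        # 512
print(gen_param['channels'])    # {16: 512, 32: 512, 64: 512, 128: 256, 256: 128, 512: 64, 1024: 32}
print(cfg['data']['sub_path'])  # '512-352', the crop size used for training data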
NTED/nted_checkpoint.pt
ADDED
|
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:359d3d3bac365afe04aa8b906f1dc8891f0dd87ff1dfe5e60059b4fb9bb96af8
size 284375285
NTED/op/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
from .fused_act import FusedLeakyReLU, fused_leaky_relu
from .upfirdn2d import upfirdn2d
NTED/op/conv2d_gradfix.py
ADDED
|
@@ -0,0 +1,227 @@
import contextlib
import warnings

import torch
from torch import autograd
from torch.nn import functional as F

enabled = True
weight_gradients_disabled = False


@contextlib.contextmanager
def no_weight_gradients():
    global weight_gradients_disabled

    old = weight_gradients_disabled
    weight_gradients_disabled = True
    yield
    weight_gradients_disabled = old


def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    if could_use_op(input):
        return conv2d_gradfix(
            transpose=False,
            weight_shape=weight.shape,
            stride=stride,
            padding=padding,
            output_padding=0,
            dilation=dilation,
            groups=groups,
        ).apply(input, weight, bias)

    return F.conv2d(
        input=input,
        weight=weight,
        bias=bias,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
    )


def conv_transpose2d(
    input,
    weight,
    bias=None,
    stride=1,
    padding=0,
    output_padding=0,
    groups=1,
    dilation=1,
):
    if could_use_op(input):
        return conv2d_gradfix(
            transpose=True,
            weight_shape=weight.shape,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            groups=groups,
            dilation=dilation,
        ).apply(input, weight, bias)

    return F.conv_transpose2d(
        input=input,
        weight=weight,
        bias=bias,
        stride=stride,
        padding=padding,
        output_padding=output_padding,
        dilation=dilation,
        groups=groups,
    )


def could_use_op(input):
    if (not enabled) or (not torch.backends.cudnn.enabled):
        return False

    if input.device.type != "cuda":
        return False

    if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8."]):
        return True

    warnings.warn(
        f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()."
    )

    return False


def ensure_tuple(xs, ndim):
    xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim

    return xs


conv2d_gradfix_cache = dict()


def conv2d_gradfix(
    transpose, weight_shape, stride, padding, output_padding, dilation, groups
):
    ndim = 2
    weight_shape = tuple(weight_shape)
    stride = ensure_tuple(stride, ndim)
    padding = ensure_tuple(padding, ndim)
    output_padding = ensure_tuple(output_padding, ndim)
    dilation = ensure_tuple(dilation, ndim)

    key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
    if key in conv2d_gradfix_cache:
        return conv2d_gradfix_cache[key]

    common_kwargs = dict(
        stride=stride, padding=padding, dilation=dilation, groups=groups
    )

    def calc_output_padding(input_shape, output_shape):
        if transpose:
            return [0, 0]

        return [
            input_shape[i + 2]
            - (output_shape[i + 2] - 1) * stride[i]
            - (1 - 2 * padding[i])
            - dilation[i] * (weight_shape[i + 2] - 1)
            for i in range(ndim)
        ]

    class Conv2d(autograd.Function):
        @staticmethod
        def forward(ctx, input, weight, bias):
            if not transpose:
                out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)

            else:
                out = F.conv_transpose2d(
                    input=input,
                    weight=weight,
                    bias=bias,
                    output_padding=output_padding,
                    **common_kwargs,
                )

            ctx.save_for_backward(input, weight)

            return out

        @staticmethod
        def backward(ctx, grad_output):
            input, weight = ctx.saved_tensors
            grad_input, grad_weight, grad_bias = None, None, None

            if ctx.needs_input_grad[0]:
                p = calc_output_padding(
                    input_shape=input.shape, output_shape=grad_output.shape
                )
                grad_input = conv2d_gradfix(
                    transpose=(not transpose),
                    weight_shape=weight_shape,
                    output_padding=p,
                    **common_kwargs,
                ).apply(grad_output, weight, None)

            if ctx.needs_input_grad[1] and not weight_gradients_disabled:
                grad_weight = Conv2dGradWeight.apply(grad_output, input)

            if ctx.needs_input_grad[2]:
                grad_bias = grad_output.sum((0, 2, 3))

            return grad_input, grad_weight, grad_bias

    class Conv2dGradWeight(autograd.Function):
        @staticmethod
        def forward(ctx, grad_output, input):
            op = torch._C._jit_get_operation(
                "aten::cudnn_convolution_backward_weight"
                if not transpose
                else "aten::cudnn_convolution_transpose_backward_weight"
            )
            flags = [
                torch.backends.cudnn.benchmark,
                torch.backends.cudnn.deterministic,
                torch.backends.cudnn.allow_tf32,
            ]
            grad_weight = op(
                weight_shape,
                grad_output,
                input,
                padding,
                stride,
                dilation,
                groups,
                *flags,
            )
            ctx.save_for_backward(grad_output, input)

            return grad_weight

        @staticmethod
        def backward(ctx, grad_grad_weight):
            grad_output, input = ctx.saved_tensors
            grad_grad_output, grad_grad_input = None, None

            if ctx.needs_input_grad[0]:
                grad_grad_output = Conv2d.apply(input, grad_grad_weight, None)

            if ctx.needs_input_grad[1]:
                p = calc_output_padding(
                    input_shape=input.shape, output_shape=grad_output.shape
                )
                grad_grad_input = conv2d_gradfix(
                    transpose=(not transpose),
                    weight_shape=weight_shape,
                    output_padding=p,
                    **common_kwargs,
                ).apply(grad_output, grad_grad_weight, None)

            return grad_grad_output, grad_grad_input

    conv2d_gradfix_cache[key] = Conv2d

    return Conv2d
NTED/op/fused_act.py
ADDED
|
@@ -0,0 +1,127 @@
import os

import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Function
from torch.utils.cpp_extension import load


module_path = os.path.dirname(__file__)
fused = load(
    "fused",
    sources=[
        os.path.join(module_path, "fused_bias_act.cpp"),
        os.path.join(module_path, "fused_bias_act_kernel.cu"),
    ],
)


class FusedLeakyReLUFunctionBackward(Function):
    @staticmethod
    def forward(ctx, grad_output, out, bias, negative_slope, scale):
        ctx.save_for_backward(out)
        ctx.negative_slope = negative_slope
        ctx.scale = scale

        empty = grad_output.new_empty(0)

        grad_input = fused.fused_bias_act(
            grad_output.contiguous(), empty, out, 3, 1, negative_slope, scale
        )

        dim = [0]

        if grad_input.ndim > 2:
            dim += list(range(2, grad_input.ndim))

        if bias:
            grad_bias = grad_input.sum(dim).detach()

        else:
            grad_bias = empty

        return grad_input, grad_bias

    @staticmethod
    def backward(ctx, gradgrad_input, gradgrad_bias):
        out, = ctx.saved_tensors
        gradgrad_out = fused.fused_bias_act(
            gradgrad_input.contiguous(),
            gradgrad_bias.to(gradgrad_input.dtype),
            out,
            3,
            1,
            ctx.negative_slope,
            ctx.scale,
        )

        return gradgrad_out, None, None, None, None


class FusedLeakyReLUFunction(Function):
    @staticmethod
    def forward(ctx, input, bias, negative_slope, scale):
        empty = input.new_empty(0)

        ctx.bias = bias is not None

        if bias is None:
            bias = empty

        out = fused.fused_bias_act(input, bias.to(input.dtype), empty, 3, 0, negative_slope, scale)
        ctx.save_for_backward(out)
        ctx.negative_slope = negative_slope
        ctx.scale = scale

        return out

    @staticmethod
    def backward(ctx, grad_output):
        out, = ctx.saved_tensors

        grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
            grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale
        )

        if not ctx.bias:
            grad_bias = None

        return grad_input, grad_bias, None, None


class FusedLeakyReLU(nn.Module):
    def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
        super().__init__()

        if bias:
            self.bias = nn.Parameter(torch.zeros(channel))

        else:
            self.bias = None

        self.negative_slope = negative_slope
        self.scale = scale

    def forward(self, input):
        return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)


def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
    if input.device.type == "cpu":
        if bias is not None:
            rest_dim = [1] * (input.ndim - bias.ndim - 1)
            return (
                F.leaky_relu(
                    input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2
                )
                * scale
            )

        else:
            return F.leaky_relu(input, negative_slope=0.2) * scale

    else:
        return FusedLeakyReLUFunction.apply(
            input.contiguous(), bias, negative_slope, scale
        )
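On the CPU branch, fused_leaky_relu above reduces to a bias add, a leaky ReLU with slope 0.2, and a sqrt(2) rescale. The snippet below restates that path in plain PyTorch purely as an illustration; it deliberately avoids importing the module, which would JIT-compile the CUDA extension.

# Illustration only: the computation performed by fused_leaky_relu's CPU fallback.
import math
import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 16, 16)      # N, C, H, W
bias = torch.zeros(8)              # one bias value per channel

y = F.leaky_relu(x + bias.view(1, -1, 1, 1), negative_slope=0.2) * math.sqrt(2)
print(y.shape)                     # torch.Size([2, 8, 16, 16])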
NTED/op/fused_bias_act.cpp
ADDED
|
@@ -0,0 +1,32 @@
#include <ATen/ATen.h>
#include <torch/extension.h>

torch::Tensor fused_bias_act_op(const torch::Tensor &input,
                                const torch::Tensor &bias,
                                const torch::Tensor &refer, int act, int grad,
                                float alpha, float scale);

#define CHECK_CUDA(x) \
  TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

torch::Tensor fused_bias_act(const torch::Tensor &input,
                             const torch::Tensor &bias,
                             const torch::Tensor &refer, int act, int grad,
                             float alpha, float scale) {
  CHECK_INPUT(input);
  CHECK_INPUT(bias);

  at::DeviceGuard guard(input.device());

  return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
}
NTED/op/fused_bias_act_kernel.cu
ADDED
|
@@ -0,0 +1,105 @@
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html

#include <torch/types.h>

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/CUDAContext.h>

#include <cuda.h>
#include <cuda_runtime.h>

template <typename scalar_t>
static __global__ void
fused_bias_act_kernel(scalar_t *out, const scalar_t *p_x, const scalar_t *p_b,
                      const scalar_t *p_ref, int act, int grad, scalar_t alpha,
                      scalar_t scale, int loop_x, int size_x, int step_b,
                      int size_b, int use_bias, int use_ref) {
  int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;

  scalar_t zero = 0.0;

  for (int loop_idx = 0; loop_idx < loop_x && xi < size_x;
       loop_idx++, xi += blockDim.x) {
    scalar_t x = p_x[xi];

    if (use_bias) {
      x += p_b[(xi / step_b) % size_b];
    }

    scalar_t ref = use_ref ? p_ref[xi] : zero;

    scalar_t y;

    switch (act * 10 + grad) {
    default:
    case 10:
      y = x;
      break;
    case 11:
      y = x;
      break;
    case 12:
      y = 0.0;
      break;

    case 30:
      y = (x > 0.0) ? x : x * alpha;
      break;
    case 31:
      y = (ref > 0.0) ? x : x * alpha;
      break;
    case 32:
      y = 0.0;
      break;
    }

    out[xi] = y * scale;
  }
}

torch::Tensor fused_bias_act_op(const torch::Tensor &input,
                                const torch::Tensor &bias,
                                const torch::Tensor &refer, int act, int grad,
                                float alpha, float scale) {
  int curDevice = -1;
  cudaGetDevice(&curDevice);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  auto x = input.contiguous();
  auto b = bias.contiguous();
  auto ref = refer.contiguous();

  int use_bias = b.numel() ? 1 : 0;
  int use_ref = ref.numel() ? 1 : 0;

  int size_x = x.numel();
  int size_b = b.numel();
  int step_b = 1;

  for (int i = 1 + 1; i < x.dim(); i++) {
    step_b *= x.size(i);
  }

  int loop_x = 4;
  int block_size = 4 * 32;
  int grid_size = (size_x - 1) / (loop_x * block_size) + 1;

  auto y = torch::empty_like(x);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      x.scalar_type(), "fused_bias_act_kernel", [&] {
        fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
            y.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
            b.data_ptr<scalar_t>(), ref.data_ptr<scalar_t>(), act, grad, alpha,
            scale, loop_x, size_x, step_b, size_b, use_bias, use_ref);
      });

  return y;
}
NTED/op/upfirdn2d.cpp
ADDED
|
@@ -0,0 +1,31 @@
#include <ATen/ATen.h>
#include <torch/extension.h>

torch::Tensor upfirdn2d_op(const torch::Tensor &input,
                           const torch::Tensor &kernel, int up_x, int up_y,
                           int down_x, int down_y, int pad_x0, int pad_x1,
                           int pad_y0, int pad_y1);

#define CHECK_CUDA(x) \
  TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

torch::Tensor upfirdn2d(const torch::Tensor &input, const torch::Tensor &kernel,
                        int up_x, int up_y, int down_x, int down_y, int pad_x0,
                        int pad_x1, int pad_y0, int pad_y1) {
  CHECK_INPUT(input);
  CHECK_INPUT(kernel);

  at::DeviceGuard guard(input.device());

  return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
                      pad_y0, pad_y1);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
}
NTED/op/upfirdn2d.py
ADDED
|
@@ -0,0 +1,209 @@
from collections import abc
import os

import torch
from torch.nn import functional as F
from torch.autograd import Function
from torch.utils.cpp_extension import load


module_path = os.path.dirname(__file__)
upfirdn2d_op = load(
    "upfirdn2d",
    sources=[
        os.path.join(module_path, "upfirdn2d.cpp"),
        os.path.join(module_path, "upfirdn2d_kernel.cu"),
    ],
)


class UpFirDn2dBackward(Function):
    @staticmethod
    def forward(
        ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
    ):

        up_x, up_y = up
        down_x, down_y = down
        g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad

        grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)

        grad_input = upfirdn2d_op.upfirdn2d(
            grad_output,
            grad_kernel.to(grad_output.dtype),
            down_x,
            down_y,
            up_x,
            up_y,
            g_pad_x0,
            g_pad_x1,
            g_pad_y0,
            g_pad_y1,
        )
        grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])

        ctx.save_for_backward(kernel)

        pad_x0, pad_x1, pad_y0, pad_y1 = pad

        ctx.up_x = up_x
        ctx.up_y = up_y
        ctx.down_x = down_x
        ctx.down_y = down_y
        ctx.pad_x0 = pad_x0
        ctx.pad_x1 = pad_x1
        ctx.pad_y0 = pad_y0
        ctx.pad_y1 = pad_y1
        ctx.in_size = in_size
        ctx.out_size = out_size

        return grad_input

    @staticmethod
    def backward(ctx, gradgrad_input):
        kernel, = ctx.saved_tensors

        gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)

        gradgrad_out = upfirdn2d_op.upfirdn2d(
            gradgrad_input,
            kernel.to(gradgrad_input.dtype),
            ctx.up_x,
            ctx.up_y,
            ctx.down_x,
            ctx.down_y,
            ctx.pad_x0,
            ctx.pad_x1,
            ctx.pad_y0,
            ctx.pad_y1,
        )
        # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
        gradgrad_out = gradgrad_out.view(
            ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
        )

        return gradgrad_out, None, None, None, None, None, None, None, None


class UpFirDn2d(Function):
    @staticmethod
    def forward(ctx, input, kernel, up, down, pad):
        up_x, up_y = up
        down_x, down_y = down
        pad_x0, pad_x1, pad_y0, pad_y1 = pad

        kernel_h, kernel_w = kernel.shape
        batch, channel, in_h, in_w = input.shape
        ctx.in_size = input.shape

        input = input.reshape(-1, in_h, in_w, 1)

        ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))

        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
        ctx.out_size = (out_h, out_w)

        ctx.up = (up_x, up_y)
        ctx.down = (down_x, down_y)
        ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)

        g_pad_x0 = kernel_w - pad_x0 - 1
        g_pad_y0 = kernel_h - pad_y0 - 1
        g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
        g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1

        ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)

        out = upfirdn2d_op.upfirdn2d(
            input, kernel.to(input.dtype), up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
        )
        # out = out.view(major, out_h, out_w, minor)
        out = out.view(-1, channel, out_h, out_w)

        return out

    @staticmethod
    def backward(ctx, grad_output):
        kernel, grad_kernel = ctx.saved_tensors

        grad_input = None

        if ctx.needs_input_grad[0]:
            grad_input = UpFirDn2dBackward.apply(
                grad_output,
                kernel,
                grad_kernel,
                ctx.up,
                ctx.down,
                ctx.pad,
                ctx.g_pad,
                ctx.in_size,
                ctx.out_size,
            )

        return grad_input, None, None, None, None


def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    if not isinstance(up, abc.Iterable):
        up = (up, up)

    if not isinstance(down, abc.Iterable):
        down = (down, down)

    if len(pad) == 2:
        pad = (pad[0], pad[1], pad[0], pad[1])

    if input.device.type == "cpu":
        out = upfirdn2d_native(input, kernel, *up, *down, *pad)

    else:
        out = UpFirDn2d.apply(input, kernel, up, down, pad)

    return out


def upfirdn2d_native(
    input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
):
    _, channel, in_h, in_w = input.shape
    input = input.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = input.shape
    kernel_h, kernel_w = kernel.shape

    out = input.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    out = F.pad(
        out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
    )
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    out = out.permute(0, 3, 1, 2)
    out = out.reshape(
        [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
    )
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x

    return out.view(-1, channel, out_h, out_w)
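A typical use of upfirdn2d in this codebase is filtered up/downsampling with a small FIR kernel. The call below is a usage sketch, not code from the commit; importing NTED.op JIT-compiles the CUDA extension, so a working CUDA toolchain is assumed (CPU tensors instead route to upfirdn2d_native). The padding (2, 1) follows the usual StyleGAN2 choice for a 4-tap kernel and factor-2 upsampling.

# Usage sketch: 2x upsampling of a feature map with a normalized binomial filter.
import torch
from NTED.op import upfirdn2d

k = torch.tensor([1., 3., 3., 1.])
kernel = k[None, :] * k[:, None]
kernel = kernel / kernel.sum()

x = torch.randn(1, 3, 64, 64)
y = upfirdn2d(x, kernel * 4, up=2, down=1, pad=(2, 1))  # gain of 4 compensates the zero insertion
print(y.shape)  # torch.Size([1, 3, 128, 128])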
NTED/op/upfirdn2d_kernel.cu
ADDED
|
@@ -0,0 +1,369 @@
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html

#include <torch/types.h>

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/CUDAContext.h>

#include <cuda.h>
#include <cuda_runtime.h>

static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
  int c = a / b;

  if (c * b > a) {
    c--;
  }

  return c;
}

struct UpFirDn2DKernelParams {
  int up_x;
  int up_y;
  int down_x;
  int down_y;
  int pad_x0;
  int pad_x1;
  int pad_y0;
  int pad_y1;

  int major_dim;
  int in_h;
  int in_w;
  int minor_dim;
  int kernel_h;
  int kernel_w;
  int out_h;
  int out_w;
  int loop_major;
  int loop_x;
};

template <typename scalar_t>
__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
                                       const scalar_t *kernel,
                                       const UpFirDn2DKernelParams p) {
  int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int out_y = minor_idx / p.minor_dim;
  minor_idx -= out_y * p.minor_dim;
  int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
  int major_idx_base = blockIdx.z * p.loop_major;

  if (out_x_base >= p.out_w || out_y >= p.out_h ||
      major_idx_base >= p.major_dim) {
    return;
  }

  int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
  int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
  int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
  int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;

  for (int loop_major = 0, major_idx = major_idx_base;
       loop_major < p.loop_major && major_idx < p.major_dim;
       loop_major++, major_idx++) {
    for (int loop_x = 0, out_x = out_x_base;
         loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
      int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
      int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
      int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
      int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;

      const scalar_t *x_p =
          &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
                 minor_idx];
      const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
      int x_px = p.minor_dim;
      int k_px = -p.up_x;
      int x_py = p.in_w * p.minor_dim;
      int k_py = -p.up_y * p.kernel_w;

      scalar_t v = 0.0f;

      for (int y = 0; y < h; y++) {
        for (int x = 0; x < w; x++) {
          v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
          x_p += x_px;
          k_p += k_px;
        }

        x_p += x_py - w * x_px;
        k_p += k_py - w * k_px;
      }

      out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
          minor_idx] = v;
    }
  }
}

template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
          int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
                                 const scalar_t *kernel,
                                 const UpFirDn2DKernelParams p) {
  const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
  const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;

  __shared__ volatile float sk[kernel_h][kernel_w];
  __shared__ volatile float sx[tile_in_h][tile_in_w];

  int minor_idx = blockIdx.x;
  int tile_out_y = minor_idx / p.minor_dim;
  minor_idx -= tile_out_y * p.minor_dim;
  tile_out_y *= tile_out_h;
  int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
  int major_idx_base = blockIdx.z * p.loop_major;

  if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
      major_idx_base >= p.major_dim) {
    return;
  }

  for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
       tap_idx += blockDim.x) {
    int ky = tap_idx / kernel_w;
    int kx = tap_idx - ky * kernel_w;
    scalar_t v = 0.0;

    if (kx < p.kernel_w & ky < p.kernel_h) {
      v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
    }

    sk[ky][kx] = v;
  }

  for (int loop_major = 0, major_idx = major_idx_base;
       loop_major < p.loop_major & major_idx < p.major_dim;
       loop_major++, major_idx++) {
    for (int loop_x = 0, tile_out_x = tile_out_x_base;
         loop_x < p.loop_x & tile_out_x < p.out_w;
         loop_x++, tile_out_x += tile_out_w) {
      int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
      int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
      int tile_in_x = floor_div(tile_mid_x, up_x);
      int tile_in_y = floor_div(tile_mid_y, up_y);

      __syncthreads();

      for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
           in_idx += blockDim.x) {
        int rel_in_y = in_idx / tile_in_w;
        int rel_in_x = in_idx - rel_in_y * tile_in_w;
        int in_x = rel_in_x + tile_in_x;
        int in_y = rel_in_y + tile_in_y;

        scalar_t v = 0.0;

        if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
          v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
                        p.minor_dim +
                    minor_idx];
        }

        sx[rel_in_y][rel_in_x] = v;
      }

      __syncthreads();
      for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
           out_idx += blockDim.x) {
        int rel_out_y = out_idx / tile_out_w;
        int rel_out_x = out_idx - rel_out_y * tile_out_w;
        int out_x = rel_out_x + tile_out_x;
        int out_y = rel_out_y + tile_out_y;

        int mid_x = tile_mid_x + rel_out_x * down_x;
        int mid_y = tile_mid_y + rel_out_y * down_y;
        int in_x = floor_div(mid_x, up_x);
        int in_y = floor_div(mid_y, up_y);
        int rel_in_x = in_x - tile_in_x;
        int rel_in_y = in_y - tile_in_y;
        int kernel_x = (in_x + 1) * up_x - mid_x - 1;
        int kernel_y = (in_y + 1) * up_y - mid_y - 1;

        scalar_t v = 0.0;

#pragma unroll
        for (int y = 0; y < kernel_h / up_y; y++)
#pragma unroll
          for (int x = 0; x < kernel_w / up_x; x++)
            v += sx[rel_in_y + y][rel_in_x + x] *
                 sk[kernel_y + y * up_y][kernel_x + x * up_x];

        if (out_x < p.out_w & out_y < p.out_h) {
          out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
              minor_idx] = v;
        }
      }
    }
  }
}

torch::Tensor upfirdn2d_op(const torch::Tensor &input,
                           const torch::Tensor &kernel, int up_x, int up_y,
                           int down_x, int down_y, int pad_x0, int pad_x1,
                           int pad_y0, int pad_y1) {
  int curDevice = -1;
  cudaGetDevice(&curDevice);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  UpFirDn2DKernelParams p;

  auto x = input.contiguous();
  auto k = kernel.contiguous();

  p.major_dim = x.size(0);
  p.in_h = x.size(1);
  p.in_w = x.size(2);
  p.minor_dim = x.size(3);
  p.kernel_h = k.size(0);
  p.kernel_w = k.size(1);
  p.up_x = up_x;
  p.up_y = up_y;
  p.down_x = down_x;
  p.down_y = down_y;
  p.pad_x0 = pad_x0;
  p.pad_x1 = pad_x1;
  p.pad_y0 = pad_y0;
  p.pad_y1 = pad_y1;

  p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
            p.down_y;
  p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
            p.down_x;

  auto out =
      at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());

  int mode = -1;

  int tile_out_h = -1;
  int tile_out_w = -1;

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 1;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 3 && p.kernel_w <= 3) {
    mode = 2;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 3;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 2 && p.kernel_w <= 2) {
    mode = 4;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
      p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 5;
    tile_out_h = 8;
    tile_out_w = 32;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
      p.kernel_h <= 2 && p.kernel_w <= 2) {
    mode = 6;
    tile_out_h = 8;
    tile_out_w = 32;
  }

  dim3 block_size;
  dim3 grid_size;

  if (tile_out_h > 0 && tile_out_w > 0) {
    p.loop_major = (p.major_dim - 1) / 16384 + 1;
    p.loop_x = 1;
    block_size = dim3(32 * 8, 1, 1);
    grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
                     (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
                     (p.major_dim - 1) / p.loop_major + 1);
  } else {
    p.loop_major = (p.major_dim - 1) / 16384 + 1;
    p.loop_x = 4;
    block_size = dim3(4, 32, 1);
    grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
                     (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
                     (p.major_dim - 1) / p.loop_major + 1);
  }

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
    switch (mode) {
    case 1:
      upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 2:
      upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 3:
      upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 4:
      upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 5:
      upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 6:
      upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    default:
      upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
          out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
          k.data_ptr<scalar_t>(), p);
    }
  });

  return out;
}
app.py
CHANGED
|
@@ -1,17 +1,29 @@
 import gradio as gr
+import cv2
+import numpy as np
+import torch
+from NTED.NTED_module import NTED

-    return "恭喜,您今年" + 年龄预测器_输入您的年龄 + "岁了!"
+NTED_Module = NTED()
+
+def pose_transfer(上传人体姿态图):
+    img = 上传人体姿态图
+    fake_img = NTED_Module.inference(img)
+
+    return fake_img
+
+with gr.Column():
+    result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
+
+gr.Interface(fn=pose_transfer,
+             inputs=["image"],
+             outputs=[result_gallery],
+             title="谷小雨姿态驱动图像",
+             examples=[["example/exp1.png"], ["example/exp2.png"], ["example/exp3.png"],
+                       ["example/exp4.png"], ["example/exp5.png"], ["example/exp6.png"]],
+             ).launch(server_name='0.0.0.0')

 '''
 TODO
-先把openpose light整合进来测试一下 (first integrate openpose light and test it)
 测试视频展示功能 (test the video display feature)
 '''
example/exp1.png
ADDED
|
example/exp2.png
ADDED
|
example/exp3.png
ADDED
|
example/exp4.png
ADDED
|
example/exp5.png
ADDED
|
example/exp6.png
ADDED
|
example/ref_img.png
ADDED
|
lite_openpose/body_bbox_detector.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import os.path as osp
|
| 5 |
+
import sys
|
| 6 |
+
import numpy as np
|
| 7 |
+
import cv2
|
| 8 |
+
import math
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torchvision.transforms as transforms
|
| 12 |
+
# from PIL import Image
|
| 13 |
+
|
| 14 |
+
# Code from https://github.com/Daniil-Osokin/lightweight-human-pose-estimation.pytorch/blob/master/demo.py
|
| 15 |
+
|
| 16 |
+
# 2D body pose estimator
|
| 17 |
+
sys.path.append('/apdcephfs/share_1474453/zejunzhang/workspace/HR-VITON/dataset_process_utils/lite_openpose')
|
| 18 |
+
from pose2d_models.with_mobilenet import PoseEstimationWithMobileNet
|
| 19 |
+
from modules.load_state import load_state
|
| 20 |
+
from modules.pose import Pose, track_poses
|
| 21 |
+
from modules.keypoints import extract_keypoints, group_keypoints
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def normalize(img, img_mean, img_scale):
|
| 25 |
+
img = np.array(img, dtype=np.float32)
|
| 26 |
+
img = (img - img_mean) * img_scale
|
| 27 |
+
return img
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def pad_width(img, stride, pad_value, min_dims):
|
| 31 |
+
h, w, _ = img.shape
|
| 32 |
+
h = min(min_dims[0], h)
|
| 33 |
+
min_dims[0] = math.ceil(min_dims[0] / float(stride)) * stride
|
| 34 |
+
min_dims[1] = max(min_dims[1], w)
|
| 35 |
+
min_dims[1] = math.ceil(min_dims[1] / float(stride)) * stride
|
| 36 |
+
pad = []
|
| 37 |
+
pad.append(int(math.floor((min_dims[0] - h) / 2.0)))
|
| 38 |
+
pad.append(int(math.floor((min_dims[1] - w) / 2.0)))
|
| 39 |
+
pad.append(int(min_dims[0] - h - pad[0]))
|
| 40 |
+
pad.append(int(min_dims[1] - w - pad[1]))
|
| 41 |
+
padded_img = cv2.copyMakeBorder(img, pad[0], pad[2], pad[1], pad[3],
|
| 42 |
+
cv2.BORDER_CONSTANT, value=pad_value)
|
| 43 |
+
return padded_img, pad
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class BodyPoseEstimator(object):
|
| 47 |
+
"""
|
| 48 |
+
Hand Detector for third-view input.
|
| 49 |
+
It combines a body pose estimator (https://github.com/jhugestar/lightweight-human-pose-estimation.pytorch.git)
|
| 50 |
+
"""
|
| 51 |
+
def __init__(self, device='cpu'):
|
| 52 |
+
# print("Loading Body Pose Estimator")
|
| 53 |
+
self.device=device
|
| 54 |
+
self.__load_body_estimator()
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def __load_body_estimator(self):
|
| 59 |
+
net = PoseEstimationWithMobileNet()
|
| 60 |
+
pose2d_checkpoint = "lite_openpose/checkpoint_iter_370000.pth"
|
| 61 |
+
checkpoint = torch.load(pose2d_checkpoint, map_location='cpu')
|
| 62 |
+
load_state(net, checkpoint)
|
| 63 |
+
net = net.eval()
|
| 64 |
+
net = net.to(self.device)
|
| 65 |
+
self.model = net
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
#Code from https://github.com/Daniil-Osokin/lightweight-human-pose-estimation.pytorch/demo.py
|
| 69 |
+
def __infer_fast(self, img, input_height_size, stride, upsample_ratio,
|
| 70 |
+
                     cpu=False, pad_value=(0, 0, 0), img_mean=(128, 128, 128), img_scale=1/256):
        height, width, _ = img.shape
        scale = input_height_size / height

        scaled_img = cv2.resize(img, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
        scaled_img = normalize(scaled_img, img_mean, img_scale)
        min_dims = [input_height_size, max(scaled_img.shape[1], input_height_size)]
        padded_img, pad = pad_width(scaled_img, stride, pad_value, min_dims)

        tensor_img = torch.from_numpy(padded_img).permute(2, 0, 1).unsqueeze(0).float()
        if not cpu:
            tensor_img = tensor_img.to(self.device)

        with torch.no_grad():
            stages_output = self.model(tensor_img)

        stage2_heatmaps = stages_output[-2]
        heatmaps = np.transpose(stage2_heatmaps.squeeze().cpu().data.numpy(), (1, 2, 0))
        heatmaps = cv2.resize(heatmaps, (0, 0), fx=upsample_ratio, fy=upsample_ratio, interpolation=cv2.INTER_CUBIC)

        stage2_pafs = stages_output[-1]
        pafs = np.transpose(stage2_pafs.squeeze().cpu().data.numpy(), (1, 2, 0))
        pafs = cv2.resize(pafs, (0, 0), fx=upsample_ratio, fy=upsample_ratio, interpolation=cv2.INTER_CUBIC)

        return heatmaps, pafs, scale, pad

    def detect_body_pose(self, img):
        """
        Output:
            current_bbox: per-person bounding box as (x0, y0, x1, y1), i.e. left-top-right-bottom
        """
        stride = 8
        upsample_ratio = 4
        orig_img = img.copy()

        # forward
        heatmaps, pafs, scale, pad = self.__infer_fast(img,
            input_height_size=256, stride=stride, upsample_ratio=upsample_ratio)

        total_keypoints_num = 0
        all_keypoints_by_type = []
        num_keypoints = Pose.num_kpts
        for kpt_idx in range(num_keypoints):  # 19th channel is background
            total_keypoints_num += extract_keypoints(heatmaps[:, :, kpt_idx], all_keypoints_by_type, total_keypoints_num)

        pose_entries, all_keypoints = group_keypoints(all_keypoints_by_type, pafs, demo=True)
        for kpt_id in range(all_keypoints.shape[0]):
            all_keypoints[kpt_id, 0] = (all_keypoints[kpt_id, 0] * stride / upsample_ratio - pad[1]) / scale
            all_keypoints[kpt_id, 1] = (all_keypoints[kpt_id, 1] * stride / upsample_ratio - pad[0]) / scale

        '''
        # print(len(pose_entries))
        if len(pose_entries) > 1:
            pose_entries = pose_entries[:1]
            print("We only support one person currently")
        # assert len(pose_entries) == 1, "We only support one person currently"
        '''

        current_poses, current_bbox = list(), list()
        for n in range(len(pose_entries)):
            if len(pose_entries[n]) == 0:
                continue
            pose_keypoints = np.ones((num_keypoints, 2), dtype=np.int32) * -1
            for kpt_id in range(num_keypoints):
                if pose_entries[n][kpt_id] != -1.0:  # keypoint was found
                    pose_keypoints[kpt_id, 0] = int(all_keypoints[int(pose_entries[n][kpt_id]), 0])
                    pose_keypoints[kpt_id, 1] = int(all_keypoints[int(pose_entries[n][kpt_id]), 1])
            pose = Pose(pose_keypoints, pose_entries[n][18])
            current_poses.append(pose.keypoints)
            current_bbox.append(np.array(pose.bbox))

        # enlarge the bbox
        for i, bbox in enumerate(current_bbox):
            x, y, w, h = bbox
            margin = 0.2
            x_margin = int(w * margin)
            y_margin = int(h * margin)
            x0 = max(x - x_margin, 0)
            y0 = max(y - y_margin, 0)
            x1 = min(x + w + x_margin, orig_img.shape[1])
            y1 = min(y + h + y_margin, orig_img.shape[0])
            current_bbox[i] = np.array((x0, y0, x1, y1)).astype(np.int32)  # ltrb

        # keep only the first detected person
        body_point_list = []
        if len(current_poses) > 0:
            for item in current_poses[0]:
                if item[0] == item[1] == -1:
                    body_point_list += [0.0, 0.0, 0.0]
                else:
                    body_point_list += [float(item[0]), float(item[1]), 1.0]
        else:
            for i in range(18):
                body_point_list += [0.0, 0.0, 0.0]

        pose_dict = dict()
        pose_dict["people"] = []
        pose_dict["people"].append({
            "person_id": [-1],
            "pose_keypoints_2d": body_point_list,
            "hand_left_keypoints_2d": [],
            "hand_right_keypoints_2d": [],
            "face_keypoints_2d": [],
            "pose_keypoints_3d": [],
            "face_keypoints_3d": [],
            "hand_left_keypoints_3d": [],
            "hand_right_keypoints_3d": [],
        })

        return current_poses, current_bbox
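For reference, a minimal sketch of how this detector can be driven end to end; the constructor's device argument and the random stand-in image are illustrative assumptions, not part of the commit:

# --- usage sketch (illustrative, not part of the commit) ---
import numpy as np
from lite_openpose.body_bbox_detector import BodyPoseEstimator

estimator = BodyPoseEstimator('cpu')                          # device argument assumed
img = (np.random.rand(512, 384, 3) * 255).astype(np.uint8)    # stand-in for a real BGR photo
poses, bboxes = estimator.detect_body_pose(img)
for kpts, (x0, y0, x1, y1) in zip(poses, bboxes):
    print('18 x 2 keypoint array:', kpts.shape, 'bbox (ltrb):', (x0, y0, x1, y1))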
lite_openpose/checkpoint_iter_370000.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:517c86f769c6636583083f1467e3d212a0006c27109edb3aeffc19a79622d411
size 87959810
lite_openpose/modules/__init__.py
ADDED
File without changes
lite_openpose/modules/conv.py
ADDED
@@ -0,0 +1,32 @@
from torch import nn


def conv(in_channels, out_channels, kernel_size=3, padding=1, bn=True, dilation=1, stride=1, relu=True, bias=True):
    modules = [nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)]
    if bn:
        modules.append(nn.BatchNorm2d(out_channels))
    if relu:
        modules.append(nn.ReLU(inplace=True))
    return nn.Sequential(*modules)


def conv_dw(in_channels, out_channels, kernel_size=3, padding=1, stride=1, dilation=1):
    return nn.Sequential(
        nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, dilation=dilation, groups=in_channels, bias=False),
        nn.BatchNorm2d(in_channels),
        nn.ReLU(inplace=True),

        nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    )


def conv_dw_no_bn(in_channels, out_channels, kernel_size=3, padding=1, stride=1, dilation=1):
    return nn.Sequential(
        nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, dilation=dilation, groups=in_channels, bias=False),
        nn.ELU(inplace=True),

        nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False),
        nn.ELU(inplace=True),
    )
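As a quick sanity check of these helpers, a small sketch stacking a regular conv and a depthwise-separable conv; shapes are chosen arbitrarily for illustration and assume lite_openpose/ is on sys.path, as the in-repo imports expect:

# --- usage sketch (illustrative, not part of the commit) ---
import torch
from modules.conv import conv, conv_dw

block = torch.nn.Sequential(
    conv(3, 32, stride=2, bias=False),  # 3x3 conv + BN + ReLU, halves the resolution
    conv_dw(32, 64),                    # depthwise 3x3 + pointwise 1x1, each with BN + ReLU
)
x = torch.randn(1, 3, 64, 64)
print(block(x).shape)  # torch.Size([1, 64, 32, 32])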
lite_openpose/modules/get_parameters.py
ADDED
@@ -0,0 +1,23 @@
from torch import nn


def get_parameters(model, predicate):
    for module in model.modules():
        for param_name, param in module.named_parameters():
            if predicate(module, param_name):
                yield param


def get_parameters_conv(model, name):
    return get_parameters(model, lambda m, p: isinstance(m, nn.Conv2d) and m.groups == 1 and p == name)


def get_parameters_conv_depthwise(model, name):
    return get_parameters(model, lambda m, p: isinstance(m, nn.Conv2d)
                                              and m.groups == m.in_channels
                                              and m.in_channels == m.out_channels
                                              and p == name)


def get_parameters_bn(model, name):
    return get_parameters(model, lambda m, p: isinstance(m, nn.BatchNorm2d) and p == name)
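These selectors are typically used to build per-group optimizer settings; a hedged sketch follows, where the base learning rate, multipliers, and weight decays are illustrative values, not taken from this commit, and the model comes from the pose2d_models module added later in this change set:

# --- usage sketch (illustrative, not part of the commit) ---
import torch
from modules.get_parameters import get_parameters_conv, get_parameters_conv_depthwise, get_parameters_bn
from pose2d_models.with_mobilenet import PoseEstimationWithMobileNet

net = PoseEstimationWithMobileNet()
base_lr = 4e-5  # illustrative value
optimizer = torch.optim.Adam([
    {'params': get_parameters_conv(net, 'weight')},
    {'params': get_parameters_conv_depthwise(net, 'weight'), 'weight_decay': 0},
    {'params': get_parameters_conv(net, 'bias'), 'lr': base_lr * 2, 'weight_decay': 0},
    {'params': get_parameters_bn(net, 'weight'), 'weight_decay': 0},
    {'params': get_parameters_bn(net, 'bias'), 'lr': base_lr * 2, 'weight_decay': 0},
], lr=base_lr, weight_decay=5e-4)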
lite_openpose/modules/keypoints.py
ADDED
@@ -0,0 +1,201 @@
import math
import numpy as np
from operator import itemgetter

BODY_PARTS_KPT_IDS = [[1, 2], [1, 5], [2, 3], [3, 4], [5, 6], [6, 7], [1, 8], [8, 9], [9, 10], [1, 11],
                      [11, 12], [12, 13], [1, 0], [0, 14], [14, 16], [0, 15], [15, 17], [2, 16], [5, 17]]
BODY_PARTS_PAF_IDS = ([12, 13], [20, 21], [14, 15], [16, 17], [22, 23], [24, 25], [0, 1], [2, 3], [4, 5],
                      [6, 7], [8, 9], [10, 11], [28, 29], [30, 31], [34, 35], [32, 33], [36, 37], [18, 19], [26, 27])


def linspace2d(start, stop, n=10):
    points = 1 / (n - 1) * (stop - start)
    return points[:, None] * np.arange(n) + start[:, None]


def extract_keypoints(heatmap, all_keypoints, total_keypoint_num):
    heatmap[heatmap < 0.1] = 0
    heatmap_with_borders = np.pad(heatmap, [(2, 2), (2, 2)], mode='constant')
    heatmap_center = heatmap_with_borders[1:heatmap_with_borders.shape[0]-1, 1:heatmap_with_borders.shape[1]-1]
    heatmap_left = heatmap_with_borders[1:heatmap_with_borders.shape[0]-1, 2:heatmap_with_borders.shape[1]]
    heatmap_right = heatmap_with_borders[1:heatmap_with_borders.shape[0]-1, 0:heatmap_with_borders.shape[1]-2]
    heatmap_up = heatmap_with_borders[2:heatmap_with_borders.shape[0], 1:heatmap_with_borders.shape[1]-1]
    heatmap_down = heatmap_with_borders[0:heatmap_with_borders.shape[0]-2, 1:heatmap_with_borders.shape[1]-1]

    heatmap_peaks = (heatmap_center > heatmap_left) &\
                    (heatmap_center > heatmap_right) &\
                    (heatmap_center > heatmap_up) &\
                    (heatmap_center > heatmap_down)
    heatmap_peaks = heatmap_peaks[1:heatmap_center.shape[0]-1, 1:heatmap_center.shape[1]-1]
    keypoints = list(zip(np.nonzero(heatmap_peaks)[1], np.nonzero(heatmap_peaks)[0]))  # (w, h)
    keypoints = sorted(keypoints, key=itemgetter(0))

    suppressed = np.zeros(len(keypoints), np.uint8)
    keypoints_with_score_and_id = []
    keypoint_num = 0
    for i in range(len(keypoints)):
        if suppressed[i]:
            continue
        for j in range(i+1, len(keypoints)):
            if math.sqrt((keypoints[i][0] - keypoints[j][0]) ** 2 +
                         (keypoints[i][1] - keypoints[j][1]) ** 2) < 6:
                suppressed[j] = 1
        keypoint_with_score_and_id = (keypoints[i][0], keypoints[i][1], heatmap[keypoints[i][1], keypoints[i][0]],
                                      total_keypoint_num + keypoint_num)
        keypoints_with_score_and_id.append(keypoint_with_score_and_id)
        keypoint_num += 1
    all_keypoints.append(keypoints_with_score_and_id)
    return keypoint_num


def group_keypoints(all_keypoints_by_type, pafs, pose_entry_size=20, min_paf_score=0.05, demo=False):
    pose_entries = []
    all_keypoints = np.array([item for sublist in all_keypoints_by_type for item in sublist])
    for part_id in range(len(BODY_PARTS_PAF_IDS)):
        part_pafs = pafs[:, :, BODY_PARTS_PAF_IDS[part_id]]
        kpts_a = all_keypoints_by_type[BODY_PARTS_KPT_IDS[part_id][0]]
        kpts_b = all_keypoints_by_type[BODY_PARTS_KPT_IDS[part_id][1]]
        num_kpts_a = len(kpts_a)
        num_kpts_b = len(kpts_b)
        kpt_a_id = BODY_PARTS_KPT_IDS[part_id][0]
        kpt_b_id = BODY_PARTS_KPT_IDS[part_id][1]

        if num_kpts_a == 0 and num_kpts_b == 0:  # no keypoints for such body part
            continue
        elif num_kpts_a == 0:  # body part has just 'b' keypoints
            for i in range(num_kpts_b):
                num = 0
                for j in range(len(pose_entries)):  # check if already in some pose, was added by another body part
                    if pose_entries[j][kpt_b_id] == kpts_b[i][3]:
                        num += 1
                        continue
                if num == 0:
                    pose_entry = np.ones(pose_entry_size) * -1
                    pose_entry[kpt_b_id] = kpts_b[i][3]  # keypoint idx
                    pose_entry[-1] = 1  # num keypoints in pose
                    pose_entry[-2] = kpts_b[i][2]  # pose score
                    pose_entries.append(pose_entry)
            continue
        elif num_kpts_b == 0:  # body part has just 'a' keypoints
            for i in range(num_kpts_a):
                num = 0
                for j in range(len(pose_entries)):
                    if pose_entries[j][kpt_a_id] == kpts_a[i][3]:
                        num += 1
                        continue
                if num == 0:
                    pose_entry = np.ones(pose_entry_size) * -1
                    pose_entry[kpt_a_id] = kpts_a[i][3]
                    pose_entry[-1] = 1
                    pose_entry[-2] = kpts_a[i][2]
                    pose_entries.append(pose_entry)
            continue

        connections = []
        for i in range(num_kpts_a):
            kpt_a = np.array(kpts_a[i][0:2])
            for j in range(num_kpts_b):
                kpt_b = np.array(kpts_b[j][0:2])
                mid_point = [(), ()]
                mid_point[0] = (int(round((kpt_a[0] + kpt_b[0]) * 0.5)),
                                int(round((kpt_a[1] + kpt_b[1]) * 0.5)))
                mid_point[1] = mid_point[0]

                vec = [kpt_b[0] - kpt_a[0], kpt_b[1] - kpt_a[1]]
                vec_norm = math.sqrt(vec[0] ** 2 + vec[1] ** 2)
                if vec_norm == 0:
                    continue
                vec[0] /= vec_norm
                vec[1] /= vec_norm
                cur_point_score = (vec[0] * part_pafs[mid_point[0][1], mid_point[0][0], 0] +
                                   vec[1] * part_pafs[mid_point[1][1], mid_point[1][0], 1])

                height_n = pafs.shape[0] // 2
                success_ratio = 0
                point_num = 10  # number of points to integration over paf
                if cur_point_score > -100:
                    passed_point_score = 0
                    passed_point_num = 0
                    x, y = linspace2d(kpt_a, kpt_b)
                    for point_idx in range(point_num):
                        if not demo:
                            px = int(round(x[point_idx]))
                            py = int(round(y[point_idx]))
                        else:
                            px = int(x[point_idx])
                            py = int(y[point_idx])
                        paf = part_pafs[py, px, 0:2]
                        cur_point_score = vec[0] * paf[0] + vec[1] * paf[1]
                        if cur_point_score > min_paf_score:
                            passed_point_score += cur_point_score
                            passed_point_num += 1
                    success_ratio = passed_point_num / point_num
                    ratio = 0
                    if passed_point_num > 0:
                        ratio = passed_point_score / passed_point_num
                    ratio += min(height_n / vec_norm - 1, 0)
                if ratio > 0 and success_ratio > 0.8:
                    score_all = ratio + kpts_a[i][2] + kpts_b[j][2]
                    connections.append([i, j, ratio, score_all])
        if len(connections) > 0:
            connections = sorted(connections, key=itemgetter(2), reverse=True)

        num_connections = min(num_kpts_a, num_kpts_b)
        has_kpt_a = np.zeros(num_kpts_a, dtype=np.int32)
        has_kpt_b = np.zeros(num_kpts_b, dtype=np.int32)
        filtered_connections = []
        for row in range(len(connections)):
            if len(filtered_connections) == num_connections:
                break
            i, j, cur_point_score = connections[row][0:3]
            if not has_kpt_a[i] and not has_kpt_b[j]:
                filtered_connections.append([kpts_a[i][3], kpts_b[j][3], cur_point_score])
                has_kpt_a[i] = 1
                has_kpt_b[j] = 1
        connections = filtered_connections
        if len(connections) == 0:
            continue

        if part_id == 0:
            pose_entries = [np.ones(pose_entry_size) * -1 for _ in range(len(connections))]
            for i in range(len(connections)):
                pose_entries[i][BODY_PARTS_KPT_IDS[0][0]] = connections[i][0]
                pose_entries[i][BODY_PARTS_KPT_IDS[0][1]] = connections[i][1]
                pose_entries[i][-1] = 2
                pose_entries[i][-2] = np.sum(all_keypoints[connections[i][0:2], 2]) + connections[i][2]
        elif part_id == 17 or part_id == 18:
            kpt_a_id = BODY_PARTS_KPT_IDS[part_id][0]
            kpt_b_id = BODY_PARTS_KPT_IDS[part_id][1]
            for i in range(len(connections)):
                for j in range(len(pose_entries)):
                    if pose_entries[j][kpt_a_id] == connections[i][0] and pose_entries[j][kpt_b_id] == -1:
                        pose_entries[j][kpt_b_id] = connections[i][1]
                    elif pose_entries[j][kpt_b_id] == connections[i][1] and pose_entries[j][kpt_a_id] == -1:
                        pose_entries[j][kpt_a_id] = connections[i][0]
            continue
        else:
            kpt_a_id = BODY_PARTS_KPT_IDS[part_id][0]
            kpt_b_id = BODY_PARTS_KPT_IDS[part_id][1]
            for i in range(len(connections)):
                num = 0
                for j in range(len(pose_entries)):
                    if pose_entries[j][kpt_a_id] == connections[i][0]:
                        pose_entries[j][kpt_b_id] = connections[i][1]
                        num += 1
                        pose_entries[j][-1] += 1
                        pose_entries[j][-2] += all_keypoints[connections[i][1], 2] + connections[i][2]
                if num == 0:
                    pose_entry = np.ones(pose_entry_size) * -1
                    pose_entry[kpt_a_id] = connections[i][0]
                    pose_entry[kpt_b_id] = connections[i][1]
                    pose_entry[-1] = 2
                    pose_entry[-2] = np.sum(all_keypoints[connections[i][0:2], 2]) + connections[i][2]
                    pose_entries.append(pose_entry)

    filtered_entries = []
    for i in range(len(pose_entries)):
        if pose_entries[i][-1] < 3 or (pose_entries[i][-2] / pose_entries[i][-1] < 0.2):
            continue
        filtered_entries.append(pose_entries[i])
    pose_entries = np.asarray(filtered_entries)
    return pose_entries, all_keypoints
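Taken together, the two functions above turn per-keypoint heatmaps and part-affinity fields into grouped per-person keypoints. A minimal sketch on dummy arrays follows; the shapes follow the 19-heatmap / 38-PAF convention used elsewhere in this commit, the random values are only there to make it runnable, and lite_openpose/ is assumed to be on sys.path:

# --- usage sketch on dummy data (illustrative, not part of the commit) ---
import numpy as np
from modules.keypoints import extract_keypoints, group_keypoints

heatmaps = np.random.rand(128, 96, 19).astype(np.float32)  # last channel is background
pafs = np.random.rand(128, 96, 38).astype(np.float32)

total_keypoints_num = 0
all_keypoints_by_type = []
for kpt_idx in range(18):  # skip the background channel
    total_keypoints_num += extract_keypoints(heatmaps[:, :, kpt_idx],
                                             all_keypoints_by_type, total_keypoints_num)

pose_entries, all_keypoints = group_keypoints(all_keypoints_by_type, pafs, demo=True)
print(len(pose_entries), 'pose entries,', len(all_keypoints), 'candidate keypoints')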
lite_openpose/modules/load_state.py
ADDED
@@ -0,0 +1,32 @@
import collections


def load_state(net, checkpoint):
    source_state = checkpoint['state_dict']
    target_state = net.state_dict()
    new_target_state = collections.OrderedDict()
    for target_key, target_value in target_state.items():
        if target_key in source_state and source_state[target_key].size() == target_state[target_key].size():
            new_target_state[target_key] = source_state[target_key]
        else:
            new_target_state[target_key] = target_state[target_key]
            print('[WARNING] Not found pre-trained parameters for {}'.format(target_key))

    net.load_state_dict(new_target_state)


def load_from_mobilenet(net, checkpoint):
    source_state = checkpoint['state_dict']
    target_state = net.state_dict()
    new_target_state = collections.OrderedDict()
    for target_key, target_value in target_state.items():
        k = target_key
        if k.find('model') != -1:
            k = k.replace('model', 'module.model')
        if k in source_state and source_state[k].size() == target_state[target_key].size():
            new_target_state[target_key] = source_state[k]
        else:
            new_target_state[target_key] = target_state[target_key]
            print('[WARNING] Not found pre-trained parameters for {}'.format(target_key))

    net.load_state_dict(new_target_state)
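A minimal sketch of how this loader is typically driven; the checkpoint path matches the file added in this commit, while the map_location argument and the assumption that the checkpoint stores a 'state_dict' entry are mine:

# --- usage sketch (illustrative, not part of the commit) ---
import torch
from modules.load_state import load_state
from pose2d_models.with_mobilenet import PoseEstimationWithMobileNet

net = PoseEstimationWithMobileNet()
checkpoint = torch.load('lite_openpose/checkpoint_iter_370000.pth', map_location='cpu')
load_state(net, checkpoint)  # copies matching tensors, warns about the rest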
lite_openpose/modules/loss.py
ADDED
@@ -0,0 +1,5 @@
def l2_loss(input, target, mask, batch_size):
    loss = (input - target) * mask
    loss = (loss * loss) / 2 / batch_size

    return loss.sum()
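For clarity, this masked L2 loss zeroes out unlabeled regions before squaring; a tiny sketch with made-up tensors (shapes are illustrative only):

# --- usage sketch (illustrative, not part of the commit) ---
import torch
from modules.loss import l2_loss

pred = torch.randn(2, 19, 32, 32)    # predicted heatmaps
target = torch.randn(2, 19, 32, 32)  # ground-truth heatmaps
mask = torch.ones(2, 19, 32, 32)     # 0 where annotations are missing
print(l2_loss(pred, target, mask, batch_size=2))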
lite_openpose/modules/one_euro_filter.py
ADDED
@@ -0,0 +1,51 @@
import math


def get_alpha(rate=30, cutoff=1):
    tau = 1 / (2 * math.pi * cutoff)
    te = 1 / rate
    return 1 / (1 + tau / te)


class LowPassFilter:
    def __init__(self):
        self.x_previous = None

    def __call__(self, x, alpha=0.5):
        if self.x_previous is None:
            self.x_previous = x
            return x
        x_filtered = alpha * x + (1 - alpha) * self.x_previous
        self.x_previous = x_filtered
        return x_filtered


class OneEuroFilter:
    def __init__(self, freq=15, mincutoff=1, beta=0.05, dcutoff=1):
        self.freq = freq
        self.mincutoff = mincutoff
        self.beta = beta
        self.dcutoff = dcutoff
        self.filter_x = LowPassFilter()
        self.filter_dx = LowPassFilter()
        self.x_previous = None
        self.dx = None

    def __call__(self, x):
        if self.dx is None:
            self.dx = 0
        else:
            self.dx = (x - self.x_previous) * self.freq
        dx_smoothed = self.filter_dx(self.dx, get_alpha(self.freq, self.dcutoff))
        cutoff = self.mincutoff + self.beta * abs(dx_smoothed)
        x_filtered = self.filter_x(x, get_alpha(self.freq, cutoff))
        self.x_previous = x
        return x_filtered


if __name__ == '__main__':
    filter = OneEuroFilter(freq=15, beta=0.1)
    for val in range(10):
        x = val + (-1)**(val % 2)
        x_filtered = filter(x)
        print(x_filtered, x)
lite_openpose/modules/pose.py
ADDED
@@ -0,0 +1,118 @@
import cv2
import numpy as np

from modules.keypoints import BODY_PARTS_KPT_IDS, BODY_PARTS_PAF_IDS
from modules.one_euro_filter import OneEuroFilter


class Pose:
    num_kpts = 18
    kpt_names = ['nose', 'neck',
                 'r_sho', 'r_elb', 'r_wri', 'l_sho', 'l_elb', 'l_wri',
                 'r_hip', 'r_knee', 'r_ank', 'l_hip', 'l_knee', 'l_ank',
                 'r_eye', 'l_eye',
                 'r_ear', 'l_ear']
    sigmas = np.array([.26, .79, .79, .72, .62, .79, .72, .62, 1.07, .87, .89, 1.07, .87, .89, .25, .25, .35, .35],
                      dtype=np.float32) / 10.0
    vars = (sigmas * 2) ** 2
    last_id = -1
    color = [0, 224, 255]

    def __init__(self, keypoints, confidence):
        super().__init__()
        self.keypoints = keypoints
        self.confidence = confidence
        self.bbox = Pose.get_bbox(self.keypoints)
        self.id = None
        self.filters = [[OneEuroFilter(), OneEuroFilter()] for _ in range(Pose.num_kpts)]

    @staticmethod
    def get_bbox(keypoints):
        found_keypoints = np.zeros((np.count_nonzero(keypoints[:, 0] != -1), 2), dtype=np.int32)
        found_kpt_id = 0
        for kpt_id in range(Pose.num_kpts):
            if keypoints[kpt_id, 0] == -1:
                continue
            found_keypoints[found_kpt_id] = keypoints[kpt_id]
            found_kpt_id += 1
        bbox = cv2.boundingRect(found_keypoints)
        return bbox

    def update_id(self, id=None):
        self.id = id
        if self.id is None:
            self.id = Pose.last_id + 1
            Pose.last_id += 1

    def draw(self, img):
        assert self.keypoints.shape == (Pose.num_kpts, 2)

        for part_id in range(len(BODY_PARTS_PAF_IDS) - 2):
            kpt_a_id = BODY_PARTS_KPT_IDS[part_id][0]
            global_kpt_a_id = self.keypoints[kpt_a_id, 0]
            if global_kpt_a_id != -1:
                x_a, y_a = self.keypoints[kpt_a_id]
                cv2.circle(img, (int(x_a), int(y_a)), 3, Pose.color, -1)
            kpt_b_id = BODY_PARTS_KPT_IDS[part_id][1]
            global_kpt_b_id = self.keypoints[kpt_b_id, 0]
            if global_kpt_b_id != -1:
                x_b, y_b = self.keypoints[kpt_b_id]
                cv2.circle(img, (int(x_b), int(y_b)), 3, Pose.color, -1)
            if global_kpt_a_id != -1 and global_kpt_b_id != -1:
                cv2.line(img, (int(x_a), int(y_a)), (int(x_b), int(y_b)), Pose.color, 2)


def get_similarity(a, b, threshold=0.5):
    num_similar_kpt = 0
    for kpt_id in range(Pose.num_kpts):
        if a.keypoints[kpt_id, 0] != -1 and b.keypoints[kpt_id, 0] != -1:
            distance = np.sum((a.keypoints[kpt_id] - b.keypoints[kpt_id]) ** 2)
            area = max(a.bbox[2] * a.bbox[3], b.bbox[2] * b.bbox[3])
            similarity = np.exp(-distance / (2 * (area + np.spacing(1)) * Pose.vars[kpt_id]))
            if similarity > threshold:
                num_similar_kpt += 1
    return num_similar_kpt


def track_poses(previous_poses, current_poses, threshold=3, smooth=False):
    """Propagate poses ids from previous frame results. Id is propagated,
    if there are at least `threshold` similar keypoints between pose from previous frame and current.
    If correspondence between pose on previous and current frame was established, pose keypoints are smoothed.

    :param previous_poses: poses from previous frame with ids
    :param current_poses: poses from current frame to assign ids
    :param threshold: minimal number of similar keypoints between poses
    :param smooth: smooth pose keypoints between frames
    :return: None
    """
    current_poses = sorted(current_poses, key=lambda pose: pose.confidence, reverse=True)  # match confident poses first
    mask = np.ones(len(previous_poses), dtype=np.int32)
    for current_pose in current_poses:
        best_matched_id = None
        best_matched_pose_id = None
        best_matched_iou = 0
        for id, previous_pose in enumerate(previous_poses):
            if not mask[id]:
                continue
            iou = get_similarity(current_pose, previous_pose)
            if iou > best_matched_iou:
                best_matched_iou = iou
                best_matched_pose_id = previous_pose.id
                best_matched_id = id
        if best_matched_iou >= threshold:
            mask[best_matched_id] = 0
        else:  # pose not similar to any previous
            best_matched_pose_id = None
        current_pose.update_id(best_matched_pose_id)

        if smooth:
            for kpt_id in range(Pose.num_kpts):
                if current_pose.keypoints[kpt_id, 0] == -1:
                    continue
                # reuse filter if previous pose has valid filter
                if (best_matched_pose_id is not None
                        and previous_poses[best_matched_id].keypoints[kpt_id, 0] != -1):
                    current_pose.filters[kpt_id] = previous_poses[best_matched_id].filters[kpt_id]
                current_pose.keypoints[kpt_id, 0] = current_pose.filters[kpt_id][0](current_pose.keypoints[kpt_id, 0])
                current_pose.keypoints[kpt_id, 1] = current_pose.filters[kpt_id][1](current_pose.keypoints[kpt_id, 1])
            current_pose.bbox = Pose.get_bbox(current_pose.keypoints)
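A hedged sketch of how track_poses propagates an id between two frames; the synthetic keypoints and confidences are made up purely to keep the example self-contained:

# --- usage sketch on synthetic keypoints (illustrative, not part of the commit) ---
import numpy as np
from modules.pose import Pose, track_poses

kpts_frame1 = np.random.randint(0, 200, (Pose.num_kpts, 2)).astype(np.int32)
kpts_frame2 = kpts_frame1 + 2  # same person, slightly moved

previous_poses = [Pose(kpts_frame1, confidence=10.0)]
previous_poses[0].update_id()  # assign a fresh id
current_poses = [Pose(kpts_frame2, confidence=10.0)]
track_poses(previous_poses, current_poses, smooth=True)
print(current_poses[0].id)     # same id as the previous-frame pose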
lite_openpose/pose2d_models/__init__.py
ADDED
File without changes
lite_openpose/pose2d_models/with_mobilenet.py
ADDED
@@ -0,0 +1,123 @@
import torch
from torch import nn

from modules.conv import conv, conv_dw, conv_dw_no_bn


class Cpm(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.align = conv(in_channels, out_channels, kernel_size=1, padding=0, bn=False)
        self.trunk = nn.Sequential(
            conv_dw_no_bn(out_channels, out_channels),
            conv_dw_no_bn(out_channels, out_channels),
            conv_dw_no_bn(out_channels, out_channels)
        )
        self.conv = conv(out_channels, out_channels, bn=False)

    def forward(self, x):
        x = self.align(x)
        x = self.conv(x + self.trunk(x))
        return x


class InitialStage(nn.Module):
    def __init__(self, num_channels, num_heatmaps, num_pafs):
        super().__init__()
        self.trunk = nn.Sequential(
            conv(num_channels, num_channels, bn=False),
            conv(num_channels, num_channels, bn=False),
            conv(num_channels, num_channels, bn=False)
        )
        self.heatmaps = nn.Sequential(
            conv(num_channels, 512, kernel_size=1, padding=0, bn=False),
            conv(512, num_heatmaps, kernel_size=1, padding=0, bn=False, relu=False)
        )
        self.pafs = nn.Sequential(
            conv(num_channels, 512, kernel_size=1, padding=0, bn=False),
            conv(512, num_pafs, kernel_size=1, padding=0, bn=False, relu=False)
        )

    def forward(self, x):
        trunk_features = self.trunk(x)
        heatmaps = self.heatmaps(trunk_features)
        pafs = self.pafs(trunk_features)
        return [heatmaps, pafs]


class RefinementStageBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.initial = conv(in_channels, out_channels, kernel_size=1, padding=0, bn=False)
        self.trunk = nn.Sequential(
            conv(out_channels, out_channels),
            conv(out_channels, out_channels, dilation=2, padding=2)
        )

    def forward(self, x):
        initial_features = self.initial(x)
        trunk_features = self.trunk(initial_features)
        return initial_features + trunk_features


class RefinementStage(nn.Module):
    def __init__(self, in_channels, out_channels, num_heatmaps, num_pafs):
        super().__init__()
        self.trunk = nn.Sequential(
            RefinementStageBlock(in_channels, out_channels),
            RefinementStageBlock(out_channels, out_channels),
            RefinementStageBlock(out_channels, out_channels),
            RefinementStageBlock(out_channels, out_channels),
            RefinementStageBlock(out_channels, out_channels)
        )
        self.heatmaps = nn.Sequential(
            conv(out_channels, out_channels, kernel_size=1, padding=0, bn=False),
            conv(out_channels, num_heatmaps, kernel_size=1, padding=0, bn=False, relu=False)
        )
        self.pafs = nn.Sequential(
            conv(out_channels, out_channels, kernel_size=1, padding=0, bn=False),
            conv(out_channels, num_pafs, kernel_size=1, padding=0, bn=False, relu=False)
        )

    def forward(self, x):
        trunk_features = self.trunk(x)
        heatmaps = self.heatmaps(trunk_features)
        pafs = self.pafs(trunk_features)
        return [heatmaps, pafs]


class PoseEstimationWithMobileNet(nn.Module):
    def __init__(self, num_refinement_stages=1, num_channels=128, num_heatmaps=19, num_pafs=38):
        super().__init__()
        self.model = nn.Sequential(
            conv(  3,  32, stride=2, bias=False),
            conv_dw( 32,  64),
            conv_dw( 64, 128, stride=2),
            conv_dw(128, 128),
            conv_dw(128, 256, stride=2),
            conv_dw(256, 256),
            conv_dw(256, 512),  # conv4_2
            conv_dw(512, 512, dilation=2, padding=2),
            conv_dw(512, 512),
            conv_dw(512, 512),
            conv_dw(512, 512),
            conv_dw(512, 512)   # conv5_5
        )
        self.cpm = Cpm(512, num_channels)

        self.initial_stage = InitialStage(num_channels, num_heatmaps, num_pafs)
        self.refinement_stages = nn.ModuleList()
        for idx in range(num_refinement_stages):
            self.refinement_stages.append(RefinementStage(num_channels + num_heatmaps + num_pafs, num_channels,
                                                          num_heatmaps, num_pafs))

    def forward(self, x):
        backbone_features = self.model(x)
        backbone_features = self.cpm(backbone_features)

        stages_output = self.initial_stage(backbone_features)
        for refinement_stage in self.refinement_stages:
            stages_output.extend(
                refinement_stage(torch.cat([backbone_features, stages_output[-2], stages_output[-1]], dim=1)))

        return stages_output
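Finally, a small sketch that wires the backbone up and checks the output shapes; the 256x256 input resolution is arbitrary, and with this stride-8 backbone it yields 32x32 heatmaps and PAFs:

# --- usage sketch (illustrative, not part of the commit) ---
import torch
from pose2d_models.with_mobilenet import PoseEstimationWithMobileNet

net = PoseEstimationWithMobileNet().eval()
x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    stages_output = net(x)
# [initial heatmaps, initial pafs, refined heatmaps, refined pafs]
print(len(stages_output))       # 4 with the default single refinement stage
print(stages_output[-2].shape)  # torch.Size([1, 19, 32, 32]) heatmaps
print(stages_output[-1].shape)  # torch.Size([1, 38, 32, 32]) pafs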