Update app.py

app.py CHANGED
@@ -66,23 +66,14 @@ COLORS = [
     [255, 0, 255], [255, 0, 170], [255, 0, 85], [255, 0, 0]
 ]
 
-
-skeleton = []
-count = 0
-color_idx = 0
-prev_pt = None
-prev_pt_idx = None
-prev_clicked = None
-original_support_image = None
-checkpoint_path = ''
-
-def process(query_img,
-            cfg_path='configs/demo_b.py'):
-    global skeleton
+def process(query_img, state,
+            cfg_path='configs/demo_b.py'):
     cfg = Config.fromfile(cfg_path)
-    kp_src_np = np.array(kp_src).copy().astype(np.float32)
-    kp_src_np[:, 0] = kp_src_np[:,
-                      0] / 128. * cfg.model.encoder_config.img_size
+    kp_src_np = np.array(state['kp_src']).copy().astype(np.float32)
+    kp_src_np[:, 0] = kp_src_np[:,
+                      0] / 128. * cfg.model.encoder_config.img_size
+    kp_src_np[:, 1] = kp_src_np[:,
+                      1] / 128. * cfg.model.encoder_config.img_size
     kp_src_np = np.flip(kp_src_np, 1).copy()
     kp_src_tensor = torch.tensor(kp_src_np).float()
     preprocess = transforms.Compose([
@@ -91,10 +82,10 @@ def process(query_img,
         Resize_Pad(cfg.model.encoder_config.img_size,
                    cfg.model.encoder_config.img_size)])
 
-    if len(skeleton) == 0:
+    if len(state['skeleton']) == 0:
         skeleton = [(0, 0)]
 
-    support_img = preprocess(original_support_image).flip(0)[None]
+    support_img = preprocess(state['original_support_image']).flip(0)[None]
     np_query = np.array(query_img)[:, :, ::-1].copy()
     q_img = preprocess(np_query).flip(0)[None]
     # Create heatmap from keypoints
@@ -104,9 +95,9 @@ def process(query_img,
         cfg.model.encoder_config.img_size])
     data_cfg['joint_weights'] = None
     data_cfg['use_different_joint_weights'] = False
-    kp_src_3d = torch.
+    kp_src_3d = torch.concatenate(
        (kp_src_tensor, torch.zeros(kp_src_tensor.shape[0], 1)), dim=-1)
-    kp_src_3d_weight = torch.
+    kp_src_3d_weight = torch.concatenate(
        (torch.ones_like(kp_src_tensor),
         torch.zeros(kp_src_tensor.shape[0], 1)), dim=-1)
     target_s, target_weight_s = genHeatMap._msra_generate_target(data_cfg,
@@ -125,8 +116,8 @@ def process(query_img,
             'target_q': None,
             'target_weight_q': None,
             'return_loss': False,
-            'img_metas': [{'sample_skeleton': [skeleton],
-                           'query_skeleton': skeleton,
+            'img_metas': [{'sample_skeleton': [state['skeleton']],
+                           'query_skeleton': state['skeleton'],
                            'sample_joints_3d': [kp_src_3d],
                            'query_joints_3d': kp_src_3d,
                            'sample_center': [kp_src_tensor.mean(dim=0)],
@@ -165,54 +156,77 @@ def process(query_img,
         vis_s_weight,
         None,
         vis_q_weight,
-        skeleton,
+        state['skeleton'],
         None,
         torch.tensor(outputs['points']).squeeze(0),
     )
-    return out
+    return out, state
 
 
 with gr.Blocks() as demo:
+    state = gr.State({
+        'kp_src': [],
+        'skeleton': [],
+        'count': 0,
+        'color_idx': 0,
+        'prev_pt': None,
+        'prev_pt_idx': None,
+        'prev_clicked': None,
+        'original_support_image': None,
+    })
+
     gr.Markdown('''
     # Pose Anything Demo
-    We present a novel approach to category agnostic pose estimation that
-
-
+    We present a novel approach to category agnostic pose estimation that
+    leverages the inherent geometrical relations between keypoints through a
+    newly designed Graph Transformer Decoder. By capturing and incorporating
+    this crucial structural information, our method enhances the accuracy of
+    keypoint localization, marking a significant departure from conventional
+    CAPE techniques that treat keypoints as isolated entities.
+    ### [Paper](https://arxiv.org/abs/2311.17891) | [Official Repo](
+    https://github.com/orhir/PoseAnything)
    ## Instructions
    1. Upload an image of the object you want to pose on the **left** image.
    2. Click on the **left** image to mark keypoints.
    3. Click on the keypoints on the **right** image to mark limbs.
-    4. Upload an image of the object you want to pose to the query image (
+    4. Upload an image of the object you want to pose to the query image (
+    **bottom**).
    5. Click **Evaluate** to pose the query image.
    ''')
    with gr.Row():
        support_img = gr.Image(label="Support Image",
                               type="pil",
                               info='Click to mark keypoints').style(
-            height=
+            height=400, width=400)
        posed_support = gr.Image(label="Posed Support Image",
                                 type="pil",
-                                 interactive=False).style(height=
+                                 interactive=False).style(height=400,
+                                                          width=400)
    with gr.Row():
        query_img = gr.Image(label="Query Image",
-                             type="pil").style(height=
+                             type="pil").style(height=400, width=400)
    with gr.Row():
        eval_btn = gr.Button(value="Evaluate")
    with gr.Row():
-        output_img = gr.Plot(label="Output Image", height=
+        output_img = gr.Plot(label="Output Image", height=400, width=400)
 
 
    def get_select_coords(kp_support,
                          limb_support,
+                          state,
                          evt: gr.SelectData,
                          r=0.015):
+        # global original_support_image
+        # if len(kp_src) == 0:
+        #     original_support_image = np.array(kp_support)[:, :,
+        #                              ::-1].copy()
        pixels_in_queue = set()
        pixels_in_queue.add((evt.index[1], evt.index[0]))
        while len(pixels_in_queue) > 0:
            pixel = pixels_in_queue.pop()
            if pixel[0] is not None and pixel[
-                1] is not None and pixel not in kp_src:
-                kp_src.append(pixel)
+                1] is not None and pixel not in state['kp_src']:
+                state['kp_src'].append(pixel)
            else:
                print("Invalid pixel")
            if limb_support is None:
@@ -230,13 +244,13 @@ with gr.Blocks() as demo:
             draw_pose.ellipse(twoPointList, fill=(255, 0, 0, 255))
             draw_limb.ellipse(twoPointList, fill=(255, 0, 0, 255))
 
-        return canvas_kp, canvas_limb
+        return canvas_kp, canvas_limb, state
 
 
     def get_limbs(kp_support,
+                  state,
                   evt: gr.SelectData,
                   r=0.02, width=0.02):
-        global count, color_idx, prev_pt, skeleton, prev_pt_idx, prev_clicked
         curr_pixel = (evt.index[1], evt.index[0])
         pixels_in_queue = set()
         pixels_in_queue.add((evt.index[1], evt.index[0]))
@@ -244,64 +258,62 @@ with gr.Blocks() as demo:
         w, h = canvas_kp.size
         r = int(r * w)
         width = int(width * w)
-        while (len(pixels_in_queue) > 0 and
-               curr_pixel != prev_clicked and
-               len(kp_src) > 0):
+        while len(pixels_in_queue) > 0 and curr_pixel != state['prev_clicked']:
             pixel = pixels_in_queue.pop()
-            prev_clicked = pixel
-            closest_point = min(kp_src,
+            state['prev_clicked'] = pixel
+            closest_point = min(state['kp_src'],
                                 key=lambda p: (p[0] - pixel[0]) ** 2 +
                                               (p[1] - pixel[1]) ** 2)
-            closest_point_index = kp_src.index(closest_point)
+            closest_point_index = state['kp_src'].index(closest_point)
             draw_limb = ImageDraw.Draw(canvas_kp)
-            if color_idx < len(COLORS):
-                c = COLORS[color_idx]
+            if state['color_idx'] < len(COLORS):
+                c = COLORS[state['color_idx']]
             else:
                 c = random.choices(range(256), k=3)
             leftUpPoint = (closest_point[1] - r, closest_point[0] - r)
             rightDownPoint = (closest_point[1] + r, closest_point[0] + r)
             twoPointList = [leftUpPoint, rightDownPoint]
             draw_limb.ellipse(twoPointList, fill=tuple(c))
-            if count == 0:
-                prev_pt = closest_point[1], closest_point[0]
-                prev_pt_idx = closest_point_index
-                count = count + 1
+            if state['count'] == 0:
+                state['prev_pt'] = closest_point[1], closest_point[0]
+                state['prev_pt_idx'] = closest_point_index
+                state['count'] = state['count'] + 1
             else:
-                if prev_pt_idx != closest_point_index:
+                if state['prev_pt_idx'] != closest_point_index:
                     # Create Line and add Limb
-                    draw_limb.line(
-
-
-
-
+                    draw_limb.line(
+                        [state['prev_pt'], (closest_point[1], closest_point[0])],
+                        fill=tuple(c),
+                        width=width)
+                    state['skeleton'].append((state['prev_pt_idx'], closest_point_index))
+                    state['color_idx'] = state['color_idx'] + 1
                 else:
                     draw_limb.ellipse(twoPointList, fill=(255, 0, 0, 255))
-                count = 0
-        return canvas_kp
+                state['count'] = 0
+        return canvas_kp, state
 
 
-    def set_qery(support_img):
-
-
-
-        original_support_image = np.array(support_img)[:, :, ::-1].copy()
+    def set_qery(support_img, state):
+        state['skeleton'].clear()
+        state['kp_src'].clear()
+        state['original_support_image'] = np.array(support_img)[:, :, ::-1].copy()
         support_img = support_img.resize((128, 128), Image.Resampling.LANCZOS)
-        return support_img, support_img
+        return support_img, support_img, state
 
 
     support_img.select(get_select_coords,
-                       [support_img, posed_support],
-                       [support_img, posed_support]
-
-
-
-                       outputs=[support_img,posed_support])
+                       [support_img, posed_support, state],
+                       [support_img, posed_support, state])
+    support_img.upload(set_qery,
+                       inputs=[support_img, state],
+                       outputs=[support_img, posed_support, state])
     posed_support.select(get_limbs,
-                         posed_support,
-                         posed_support)
+                         [posed_support, state],
+                         [posed_support, state])
     eval_btn.click(fn=process,
-                   inputs=[query_img],
-                   outputs=output_img)
+                   inputs=[query_img, state],
+                   outputs=[output_img, state])
+
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description='Pose Anything Demo')
|
|
|
| 66 |
[255, 0, 255], [255, 0, 170], [255, 0, 85], [255, 0, 0]
|
| 67 |
]
|
| 68 |
|
| 69 |
+
def process(query_img, state,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
cfg_path='configs/demo_b.py'):
|
|
|
|
| 71 |
cfg = Config.fromfile(cfg_path)
|
| 72 |
+
kp_src_np = np.array(state['kp_src']).copy().astype(np.float32)
|
| 73 |
+
kp_src_np[:, 0] = kp_src_np[:,
|
| 74 |
+
0] / 128. * cfg.model.encoder_config.img_size
|
| 75 |
+
kp_src_np[:, 1] = kp_src_np[:,
|
| 76 |
+
1] / 128. * cfg.model.encoder_config.img_size
|
| 77 |
kp_src_np = np.flip(kp_src_np, 1).copy()
|
| 78 |
kp_src_tensor = torch.tensor(kp_src_np).float()
|
| 79 |
preprocess = transforms.Compose([
|
|
|
|
| 82 |
Resize_Pad(cfg.model.encoder_config.img_size,
|
| 83 |
cfg.model.encoder_config.img_size)])
|
| 84 |
|
| 85 |
+
if len(state['skeleton']) == 0:
|
| 86 |
skeleton = [(0, 0)]
|
| 87 |
|
| 88 |
+
support_img = preprocess(state['original_support_image']).flip(0)[None]
|
| 89 |
np_query = np.array(query_img)[:, :, ::-1].copy()
|
| 90 |
q_img = preprocess(np_query).flip(0)[None]
|
| 91 |
# Create heatmap from keypoints
|
|
|
|
| 95 |
cfg.model.encoder_config.img_size])
|
| 96 |
data_cfg['joint_weights'] = None
|
| 97 |
data_cfg['use_different_joint_weights'] = False
|
| 98 |
+
kp_src_3d = torch.concatenate(
|
| 99 |
(kp_src_tensor, torch.zeros(kp_src_tensor.shape[0], 1)), dim=-1)
|
| 100 |
+
kp_src_3d_weight = torch.concatenate(
|
| 101 |
(torch.ones_like(kp_src_tensor),
|
| 102 |
torch.zeros(kp_src_tensor.shape[0], 1)), dim=-1)
|
| 103 |
target_s, target_weight_s = genHeatMap._msra_generate_target(data_cfg,
|
|
|
|
| 116 |
'target_q': None,
|
| 117 |
'target_weight_q': None,
|
| 118 |
'return_loss': False,
|
| 119 |
+
'img_metas': [{'sample_skeleton': [state['skeleton']],
|
| 120 |
+
'query_skeleton': state['skeleton'],
|
| 121 |
'sample_joints_3d': [kp_src_3d],
|
| 122 |
'query_joints_3d': kp_src_3d,
|
| 123 |
'sample_center': [kp_src_tensor.mean(dim=0)],
|
|
|
|
| 156 |
vis_s_weight,
|
| 157 |
None,
|
| 158 |
vis_q_weight,
|
| 159 |
+
state['skeleton'],
|
| 160 |
None,
|
| 161 |
torch.tensor(outputs['points']).squeeze(0),
|
| 162 |
)
|
| 163 |
+
return out, state
|
| 164 |
|
| 165 |
|
| 166 |
with gr.Blocks() as demo:
|
| 167 |
+
state = gr.State({
|
| 168 |
+
'kp_src': [],
|
| 169 |
+
'skeleton': [],
|
| 170 |
+
'count': 0,
|
| 171 |
+
'color_idx': 0,
|
| 172 |
+
'prev_pt': None,
|
| 173 |
+
'prev_pt_idx': None,
|
| 174 |
+
'prev_clicked': None,
|
| 175 |
+
'original_support_image': None,
|
| 176 |
+
})
|
| 177 |
+
|
| 178 |
gr.Markdown('''
|
| 179 |
# Pose Anything Demo
|
| 180 |
+
We present a novel approach to category agnostic pose estimation that
|
| 181 |
+
leverages the inherent geometrical relations between keypoints through a
|
| 182 |
+
newly designed Graph Transformer Decoder. By capturing and incorporating
|
| 183 |
+
this crucial structural information, our method enhances the accuracy of
|
| 184 |
+
keypoint localization, marking a significant departure from conventional
|
| 185 |
+
CAPE techniques that treat keypoints as isolated entities.
|
| 186 |
+
### [Paper](https://arxiv.org/abs/2311.17891) | [Official Repo](
|
| 187 |
+
https://github.com/orhir/PoseAnything)
|
| 188 |
## Instructions
|
| 189 |
1. Upload an image of the object you want to pose on the **left** image.
|
| 190 |
2. Click on the **left** image to mark keypoints.
|
| 191 |
3. Click on the keypoints on the **right** image to mark limbs.
|
| 192 |
+
4. Upload an image of the object you want to pose to the query image (
|
| 193 |
+
**bottom**).
|
| 194 |
5. Click **Evaluate** to pose the query image.
|
| 195 |
''')
|
| 196 |
with gr.Row():
|
| 197 |
support_img = gr.Image(label="Support Image",
|
| 198 |
type="pil",
|
| 199 |
info='Click to mark keypoints').style(
|
| 200 |
+
height=400, width=400)
|
| 201 |
posed_support = gr.Image(label="Posed Support Image",
|
| 202 |
type="pil",
|
| 203 |
+
interactive=False).style(height=400,
|
| 204 |
+
width=400)
|
| 205 |
with gr.Row():
|
| 206 |
query_img = gr.Image(label="Query Image",
|
| 207 |
+
type="pil").style(height=400, width=400)
|
| 208 |
with gr.Row():
|
| 209 |
eval_btn = gr.Button(value="Evaluate")
|
| 210 |
with gr.Row():
|
| 211 |
+
output_img = gr.Plot(label="Output Image", height=400, width=400)
|
| 212 |
|
| 213 |
|
| 214 |
def get_select_coords(kp_support,
|
| 215 |
limb_support,
|
| 216 |
+
state,
|
| 217 |
evt: gr.SelectData,
|
| 218 |
r=0.015):
|
| 219 |
+
# global original_support_image
|
| 220 |
+
# if len(kp_src) == 0:
|
| 221 |
+
# original_support_image = np.array(kp_support)[:, :,
|
| 222 |
+
# ::-1].copy()
|
| 223 |
pixels_in_queue = set()
|
| 224 |
pixels_in_queue.add((evt.index[1], evt.index[0]))
|
| 225 |
while len(pixels_in_queue) > 0:
|
| 226 |
pixel = pixels_in_queue.pop()
|
| 227 |
if pixel[0] is not None and pixel[
|
| 228 |
+
1] is not None and pixel not in state['kp_src']:
|
| 229 |
+
state['kp_src'].append(pixel)
|
| 230 |
else:
|
| 231 |
print("Invalid pixel")
|
| 232 |
if limb_support is None:
|
|
|
|
| 244 |
draw_pose.ellipse(twoPointList, fill=(255, 0, 0, 255))
|
| 245 |
draw_limb.ellipse(twoPointList, fill=(255, 0, 0, 255))
|
| 246 |
|
| 247 |
+
return canvas_kp, canvas_limb, state
|
| 248 |
|
| 249 |
|
| 250 |
def get_limbs(kp_support,
|
| 251 |
+
state,
|
| 252 |
evt: gr.SelectData,
|
| 253 |
r=0.02, width=0.02):
|
|
|
|
| 254 |
curr_pixel = (evt.index[1], evt.index[0])
|
| 255 |
pixels_in_queue = set()
|
| 256 |
pixels_in_queue.add((evt.index[1], evt.index[0]))
|
|
|
|
| 258 |
w, h = canvas_kp.size
|
| 259 |
r = int(r * w)
|
| 260 |
width = int(width * w)
|
| 261 |
+
while len(pixels_in_queue) > 0 and curr_pixel != state['prev_clicked']:
|
|
|
|
|
|
|
| 262 |
pixel = pixels_in_queue.pop()
|
| 263 |
+
state['prev_clicked'] = pixel
|
| 264 |
+
closest_point = min(state['kp_src'],
|
| 265 |
key=lambda p: (p[0] - pixel[0]) ** 2 +
|
| 266 |
(p[1] - pixel[1]) ** 2)
|
| 267 |
+
closest_point_index = state['kp_src'].index(closest_point)
|
| 268 |
draw_limb = ImageDraw.Draw(canvas_kp)
|
| 269 |
+
if state['color_idx'] < len(COLORS):
|
| 270 |
+
c = COLORS[state['color_idx']]
|
| 271 |
else:
|
| 272 |
c = random.choices(range(256), k=3)
|
| 273 |
leftUpPoint = (closest_point[1] - r, closest_point[0] - r)
|
| 274 |
rightDownPoint = (closest_point[1] + r, closest_point[0] + r)
|
| 275 |
twoPointList = [leftUpPoint, rightDownPoint]
|
| 276 |
draw_limb.ellipse(twoPointList, fill=tuple(c))
|
| 277 |
+
if state['count'] == 0:
|
| 278 |
+
state['prev_pt'] = closest_point[1], closest_point[0]
|
| 279 |
+
state['prev_pt_idx'] = closest_point_index
|
| 280 |
+
state['count'] = state['count'] + 1
|
| 281 |
else:
|
| 282 |
+
if state['prev_pt_idx'] != closest_point_index:
|
| 283 |
# Create Line and add Limb
|
| 284 |
+
draw_limb.line(
|
| 285 |
+
[state['prev_pt'], (closest_point[1], closest_point[0])],
|
| 286 |
+
fill=tuple(c),
|
| 287 |
+
width=width)
|
| 288 |
+
state['skeleton'].append((state['prev_pt_idx'], closest_point_index))
|
| 289 |
+
state['color_idx'] = state['color_idx'] + 1
|
| 290 |
else:
|
| 291 |
draw_limb.ellipse(twoPointList, fill=(255, 0, 0, 255))
|
| 292 |
+
state['count'] = 0
|
| 293 |
+
return canvas_kp, state
|
| 294 |
|
| 295 |
|
| 296 |
+
def set_qery(support_img, state):
|
| 297 |
+
state['skeleton'].clear()
|
| 298 |
+
state['kp_src'].clear()
|
| 299 |
+
state['original_support_image'] = np.array(support_img)[:, :, ::-1].copy()
|
|
|
|
| 300 |
support_img = support_img.resize((128, 128), Image.Resampling.LANCZOS)
|
| 301 |
+
return support_img, support_img, state
|
| 302 |
|
| 303 |
|
| 304 |
support_img.select(get_select_coords,
|
| 305 |
+
[support_img, posed_support, state],
|
| 306 |
+
[support_img, posed_support, state])
|
| 307 |
+
support_img.upload(set_qery,
|
| 308 |
+
inputs=[support_img, state],
|
| 309 |
+
outputs=[support_img, posed_support, state])
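`set_qery` (spelling as in the source) is wired to `support_img.upload`, so choosing a new support image now also resets the session's annotations instead of leaving keypoints from the previous image behind. A sketch of that reset-on-upload idea, under the same assumptions as above:

```python
import numpy as np
import gradio as gr

def reset_on_upload(img, state):
    # Drop annotations made on the previous support image; they would
    # otherwise be rescaled against the wrong picture at evaluation time.
    state['kp_src'].clear()
    state['skeleton'].clear()
    state['original_support_image'] = np.array(img)[:, :, ::-1].copy()  # RGB -> BGR
    return img, state

with gr.Blocks() as demo:
    state = gr.State({'kp_src': [], 'skeleton': [], 'original_support_image': None})
    support = gr.Image(label="Support Image", type="pil")
    # .upload() fires whenever the user supplies a new file for the component.
    support.upload(reset_on_upload, inputs=[support, state], outputs=[support, state])
```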
|
|
|
|
| 310 |
posed_support.select(get_limbs,
|
| 311 |
+
[posed_support, state],
|
| 312 |
+
[posed_support, state])
|
| 313 |
eval_btn.click(fn=process,
|
| 314 |
+
inputs=[query_img, state],
|
| 315 |
+
outputs=[output_img, state])
|
| 316 |
+
|
| 317 |
|
| 318 |
if __name__ == "__main__":
|
| 319 |
parser = argparse.ArgumentParser(description='Pose Anything Demo')
|