# RIFE.axera / run_onnx.py
import os
import cv2
import time
import argparse
import numpy as np
import onnxruntime as ort
import _thread
import torch
import torch.nn.functional as F
import ms_ssim
from tqdm import tqdm
from queue import Queue, Empty
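# torch is used only for bilinear thumbnail resizing and ms_ssim (assumed to be
# the SSIM helper shipped with RIFE) in static-frame / scene-cut detection;
# inference itself runs through ONNX Runtime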
parser = argparse.ArgumentParser(description='Interpolation for a pair of images')
parser.add_argument('--video', dest='video', type=str, default='./demo.mp4')
parser.add_argument('--output', dest='output', type=str, default=None)
parser.add_argument('--img', dest='img', type=str, default=None)
parser.add_argument('--montage', dest='montage', action='store_true', help='montage the original video alongside the interpolated output')
parser.add_argument('--model', dest='model', type=str, default=None, help='directory with trained model files')
parser.add_argument('--fp16', dest='fp16', action='store_true', help='fp16 mode for faster and more lightweight inference on cards with Tensor Cores')
parser.add_argument('--UHD', dest='UHD', action='store_true', help='support 4k video')
parser.add_argument('--scale', dest='scale', type=float, default=1.0, help='Try scale=0.5 for 4k video')
parser.add_argument('--skip', dest='skip', action='store_true', help='whether to remove static frames before processing')
parser.add_argument('--fps', dest='fps', type=int, default=None)
parser.add_argument('--png', dest='png', action='store_true', help='output frames as png files instead of a video')
parser.add_argument('--ext', dest='ext', type=str, default='mp4', help='vid_out video extension')
parser.add_argument('--exp', dest='exp', type=int, default=1)
parser.add_argument('--multi', dest='multi', type=int, default=2)
def read_video(video_path):
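    """Yield RGB frames from video_path, releasing the capture when done."""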
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
raise IOError(f"Cannot open video: {video_path}")
try:
while True:
ret, frame = cap.read()
if not ret:
break
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
yield frame
finally:
cap.release()
def clear_write_buffer(user_args, write_buffer, vid_out):
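    """Writer thread: drain RGB frames from write_buffer until a None sentinel."""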
cnt = 0
while True:
item = write_buffer.get()
if item is None:
break
if user_args.png:
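            # frames are queued as RGB; reverse channel order back to BGR for OpenCV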
cv2.imwrite('vid_out/{:0>7d}.png'.format(cnt), item[:, :, ::-1])
cnt += 1
else:
vid_out.write(item[:, :, ::-1])
def build_read_buffer(user_args, read_buffer, videogen, left=0, w=None):
    # producer thread: queue frames for inference, signalling EOF with None;
    # `left`/`w` define the montage crop window
    try:
        for frame in videogen:
            if user_args.img is not None:
                frame = cv2.imread(os.path.join(user_args.img, frame), cv2.IMREAD_UNCHANGED)[:, :, ::-1].copy()
            if user_args.montage:
                frame = frame[:, left: left + w]
            read_buffer.put(frame)
    except Exception:
        pass
    read_buffer.put(None)
def pad_image(img, padding):
    # pad an NCHW array with the np.pad-style spec used below, optionally
    # casting to fp16 (unused helper kept from the torch pipeline)
    if args.fp16:
        return np.pad(img, padding).astype(np.float16)
    return np.pad(img, padding)
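# A minimal sketch of multi-frame interpolation, following the recursive
# bisection scheme of the original RIFE inference_video.py: the ONNX model only
# predicts the midpoint of a frame pair, so n intermediate frames are produced
# by recursive halving (timestamps are exact for power-of-two --multi values).
def make_inference(session, output_names, input_name, I0, I1, n):
    imgs = np.concatenate((I0, I1), axis=1)
    middle = np.array(session.run(output_names, {input_name: imgs})[-1])
    if n == 1:
        return [middle]
    first_half = make_inference(session, output_names, input_name, I0, middle, n=n // 2)
    second_half = make_inference(session, output_names, input_name, middle, I1, n=n // 2)
    if n % 2:
        return [*first_half, middle, *second_half]
    return [*first_half, *second_half]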
def run(args):
'''onnx inference'''
# model
session = ort.InferenceSession(args.model, providers=['CPUExecutionProvider'])
output_names = [x.name for x in session.get_outputs()]
input_name = session.get_inputs()[0].name
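    # the exported model takes a single input: two frames concatenated along
    # the channel axis (N, 2*C, H, W), and returns the midpoint frame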
    # video / image-sequence input
    if args.img is None:
        videoCapture = cv2.VideoCapture(args.video)
        fps = videoCapture.get(cv2.CAP_PROP_FPS)
        tot_frame = int(videoCapture.get(cv2.CAP_PROP_FRAME_COUNT))
        videoCapture.release()
        if args.fps is None:
            fpsNotAssigned = True
            args.fps = fps * args.multi
        else:
            fpsNotAssigned = False
        videogen = read_video(args.video)
        lastframe = next(videogen)
        video_path_wo_ext, ext = os.path.splitext(args.video)
        print('{}.{}, {} frames in total, {}FPS to {}FPS'.format(video_path_wo_ext, args.ext, tot_frame, fps, args.fps))
    else:
        # image-sequence mode, following the original RIFE inference_video.py:
        # videogen yields file names which build_read_buffer loads from args.img
        videogen = [f for f in os.listdir(args.img) if f.endswith('.png')]
        videogen.sort(key=lambda x: int(x[:-4]))
        tot_frame = len(videogen)
        lastframe = cv2.imread(os.path.join(args.img, videogen[0]), cv2.IMREAD_UNCHANGED)[:, :, ::-1].copy()
        videogen = iter(videogen[1:])
        fpsNotAssigned = False
    fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
    # note: audio transfer from the source clip is not implemented in this port
    if not args.png and fpsNotAssigned:
        print("The audio will be merged after interpolation process")
    else:
        print("Will not merge audio because using png or fps flag!")
#
h, w, _ = lastframe.shape
vid_out_name = None
vid_out = None
if args.png:
if not os.path.exists('vid_out'):
os.mkdir('vid_out')
else:
if args.output is not None:
vid_out_name = args.output
else:
vid_out_name = '{}_{}X_{}fps.{}'.format(video_path_wo_ext, args.multi, int(np.round(args.fps)), args.ext)
vid_out = cv2.VideoWriter(vid_out_name, fourcc, args.fps, (w, h))
    if args.montage:
        # interpolate only the centre half so the original can sit beside the
        # output (as in the original RIFE script)
        left = w // 4
        w = w // 2
        lastframe = lastframe[:, left: left + w]
    else:
        left = 0
    # pad H and W up to a multiple of the model's block size
    tmp = max(128, int(128 / args.scale))
    ph = ((h - 1) // tmp + 1) * tmp
    pw = ((w - 1) // tmp + 1) * tmp
    # np.pad spec for an NCHW batch: pad only the bottom and right edges
    padding = ((0, 0), (0, 0), (0, ph - h), (0, pw - w))
pbar = tqdm(total=tot_frame, ncols=80)
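    # bounded queues decouple the decode, inference, and encode stages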
write_buffer = Queue(maxsize=500)
read_buffer = Queue(maxsize=500)
    _thread.start_new_thread(build_read_buffer, (args, read_buffer, videogen, left, w))
_thread.start_new_thread(clear_write_buffer, (args, write_buffer, vid_out))
    # normalise to an NCHW float32 batch in [0, 1] and pad to the model size
    I1 = np.expand_dims(np.transpose(lastframe, (2, 0, 1)), 0).astype(np.float32) / 255.
    I1 = np.pad(I1, padding)
temp = None # save lastframe when processing static frame
while True:
if temp is not None:
frame = temp
temp = None
else:
frame = read_buffer.get()
if frame is None:
break
I0 = I1
        I1 = np.expand_dims(np.transpose(frame, (2, 0, 1)), 0).astype(np.float32) / 255.
I1 = np.pad(I1, padding)
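        # SSIM on 32x32 thumbnails flags static frames (high) and scene cuts (low)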
I0_small = F.interpolate(torch.from_numpy(I0).float(), (32, 32), mode='bilinear', align_corners=False)
I1_small = F.interpolate(torch.from_numpy(I1).float(), (32, 32), mode='bilinear', align_corners=False)
ssim = ms_ssim.ssim_matlab(I0_small[:, :3], I1_small[:, :3])
break_flag = False
        if ssim > 0.996:
            # nearly identical frames (static scene): pull the next real frame
            # and let the model synthesise the one we just skipped
            frame = read_buffer.get()  # read a new frame
            if frame is None:
                break_flag = True
                frame = lastframe
            else:
                temp = frame
            I1 = np.expand_dims(np.transpose(frame, (2, 0, 1)), 0).astype(np.float32) / 255.
            I1 = np.pad(I1, padding)
            imgs = np.concatenate((I0, I1), axis=1)
            I1 = np.array(session.run(output_names, {input_name: imgs})[-1])
I1_small = F.interpolate(torch.from_numpy(I1).float(), (32, 32), mode='bilinear', align_corners=False)
ssim = ms_ssim.ssim_matlab(I0_small[:, :3], I1_small[:, :3])
            frame = np.clip(I1[0] * 255, 0, 255).astype(np.uint8).transpose(1, 2, 0)[:h, :w]
        if ssim < 0.2:
            # scene cut: interpolating across it would only produce ghosting,
            # so duplicate the first frame instead
            output = []
            for i in range(args.multi - 1):
                output.append(I0)
        else:
            output = make_inference(session, output_names, input_name, I0, I1, args.multi - 1)
if args.montage:
write_buffer.put(np.concatenate((lastframe, lastframe), 1))
for mid in output:
                mid = np.clip(mid[0] * 255, 0, 255).astype(np.uint8).transpose(1, 2, 0)
write_buffer.put(np.concatenate((lastframe, mid[:h, :w]), 1))
else:
write_buffer.put(lastframe)
for mid in output:
                mid = np.clip(mid[0] * 255, 0, 255).astype(np.uint8).transpose(1, 2, 0)
write_buffer.put(mid[:h, :w])
pbar.update(1)
lastframe = frame
if break_flag:
break
if args.montage:
write_buffer.put(np.concatenate((lastframe, lastframe), 1))
else:
write_buffer.put(lastframe)
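    # sentinel: tell the writer thread to exit once the queue drains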
write_buffer.put(None)
    # wait for the writer thread to finish flushing queued frames
    while not write_buffer.empty():
        time.sleep(0.1)
pbar.close()
    if vid_out is not None:
vid_out.release()
if __name__ == '__main__':
args = parser.parse_args()
if args.exp != 1:
args.multi = (2 ** args.exp)
    assert args.video is not None or args.img is not None
if args.skip:
print("skip flag is abandoned, please refer to issue #207.")
if args.UHD and args.scale==1.0:
args.scale = 0.5
assert args.scale in [0.25, 0.5, 1.0, 2.0, 4.0]
    if args.img is not None:
        args.png = True
run(args)