jounery-d committed on
Commit 6756f18 · 0 Parent(s)

first commit

Files changed (7)
  1. .gitattributes +35 -0
  2. README.md +90 -0
  3. build_config.json +51 -0
  4. ms_ssim.py +200 -0
  5. requirements.txt +5 -0
  6. run_axmodel.py +233 -0
  7. run_onnx.py +233 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,90 @@
+ ---
+ license: mit
+ language:
+ - en
+ base_model:
+ - rife
+ pipeline_tag: frame
+ tags:
+ - Image
+ - SuperResolution
+ ---
+
+ # RIFE
+
+ This version of RIFE has been converted to run on the Axera NPU using **w16a16** quantization.
+
+ Compatible with Pulsar2 version: 4.2
+
+ ## Conversion tool links
+
+ If you are interested in model conversion, you can export an axmodel yourself with the tools below:
+
+ - [The AXera Platform samples repo](https://github.com/AXERA-TECH/ax-samples), which provides a detailed guide
+
+ - [Pulsar2 documentation: how to convert ONNX to axmodel](https://pulsar2-docs.readthedocs.io/en/latest/pulsar2/introduction.html)
+
+
+ ## Supported Platforms
+
+ - AX650
+   - [M4N-Dock(爱芯派Pro)](https://wiki.sipeed.com/hardware/zh/maixIV/m4ndock/m4ndock.html)
+   - [M.2 Accelerator card](https://axcl-docs.readthedocs.io/zh-cn/latest/doc_guide_hardware.html)
+ - AX630C
+   - [爱芯派2](https://axera-pi-2-docs-cn.readthedocs.io/zh-cn/latest/index.html)
+   - [Module-LLM](https://docs.m5stack.com/zh_CN/module/Module-LLM)
+   - [LLM630 Compute Kit](https://docs.m5stack.com/zh_CN/core/LLM630%20Compute%20Kit)
+
+ |Chip|Model|Latency|
+ |--|--|--|
+ |AX650|RIFE|200 ms|
+
+ ## How to use
+
+ Download all files from this repository to the device.
+
+ ```
+ root@ax650:~/rife# tree
+ .
+ |-- model
+ |   |-- rife_x2_720p.onnx
+ |   `-- rife_x2_720p.axmodel
+ |-- run_onnx.py
+ |-- run_axmodel.py
+ |-- ms_ssim.py
+ |-- build_config.json
+ `-- requirements.txt
+ ```
+
+ ### Inference
+
+ Input data:
+
+ ```
+ |-- video
+ |   `-- demo.mp4
+ ```
+
+ #### Inference on an AX650 host, such as the M4N-Dock(爱芯派Pro)
+
+ ```
+ root@ax650 ~/rife #python3 run_axmodel.py --model ./rife_x2_720p.axmodel --video ./demo.mp4
+ [INFO] Available providers: ['AxEngineExecutionProvider']
+ [INFO] Using provider: AxEngineExecutionProvider
+ [INFO] Chip type: ChipType.MC50
+ [INFO] VNPU type: VNPUType.DISABLED
+ [INFO] Engine version: 2.12.0s
+ [INFO] Model type: 2 (triple core)
+ [INFO] Compiler version: 4.2 77cdc0c2
+ input name: onnx::Slice_0
+ demo.mp4, 128.0 frames in total, 25.0FPS to 50.0FPS
+ The audio will be merged after interpolation process
+ 99%|██████████████████████████████████████▋| 127/128.0 [01:38<00:00, 1.29it/s]
+ ```
+
+ Output:
+ [INFO]:
build_config.json ADDED
@@ -0,0 +1,51 @@
+ {
+   "input": "./rife_x2_720p.onnx",
+   "output_dir": "./output",
+   "output_name": "rife_x2_720p.axmodel",
+   "work_dir": "",
+   "model_type": "ONNX",
+   "target_hardware": "AX650",
+   "npu_mode": "NPU3",
+   "onnx_opt": {
+     "disable_onnx_optimization": false,
+     "model_check": false
+   },
+   "quant": {
+     "input_configs": [
+       {
+         "tensor_name": "DEFAULT",
+         "calibration_dataset": "1.zip",
+         "calibration_format": "Numpy",
+         "calibration_size": 10,
+         "calibration_mean": [0, 0, 0, 0, 0, 0],
+         "calibration_std": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
+       }
+     ],
+     "layer_configs": [
+       {
+         "start_tensor_names": ["DEFAULT"],
+         "end_tensor_names": ["DEFAULT"],
+         "data_type": "U16"
+       }
+     ],
+     "calibration_method": "MinMax",
+     "precision_analysis": true,
+     "precision_analysis_method": "EndToEnd",
+     "precision_analysis_mode": "Reference"
+   },
+   "input_processors": [
+     {
+       "tensor_name": "DEFAULT",
+       "src_dtype": "FP32"
+     }
+   ],
+   "output_processors": [
+     {
+       "tensor_name": "DEFAULT"
+     }
+   ],
+   "compiler": {
+     "check": 0
+   }
+ }
+
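The `quant` section above declares a 10-sample calibration set (`1.zip`, `calibration_format: Numpy`); the six-entry `calibration_mean`/`calibration_std` reflect the model's 6-channel input, i.e. two RGB frames stacked along the channel axis. Below is a minimal sketch of how such an archive could be produced from `demo.mp4`. It assumes the exported ONNX takes a 1×6×768×1280 tensor (720p padded up to a multiple of 128, as the run scripts do) and that Pulsar2 accepts a flat zip of `.npy` samples; check both assumptions against the Pulsar2 documentation and your model's actual input shape.

```python
# Sketch: build a small Numpy calibration archive for Pulsar2 from a video.
# Assumptions (not part of this repo): input shape 1x6x768x1280 and a flat
# zip of .npy files for "calibration_format": "Numpy".
import zipfile
import cv2
import numpy as np

def to_tensor(frame, ph=768, pw=1280):
    # HWC uint8 RGB -> 1x3xHxW float32 in [0, 1], zero-padded to (ph, pw)
    x = np.expand_dims(np.transpose(frame, (2, 0, 1)), 0).astype(np.float32) / 255.
    h, w = frame.shape[:2]
    return np.pad(x, ((0, 0), (0, 0), (0, ph - h), (0, pw - w)))

cap = cv2.VideoCapture("./demo.mp4")
ok, prev = cap.read()
samples = []
while ok and len(samples) < 10:  # calibration_size is 10 in build_config.json
    ok, cur = cap.read()
    if not ok:
        break
    pair = np.concatenate((to_tensor(cv2.cvtColor(prev, cv2.COLOR_BGR2RGB)),
                           to_tensor(cv2.cvtColor(cur, cv2.COLOR_BGR2RGB))), axis=1)
    samples.append(pair)  # shape (1, 6, 768, 1280): two frames stacked on the channel axis
    prev = cur
cap.release()

with zipfile.ZipFile("1.zip", "w") as zf:
    for i, s in enumerate(samples):
        np.save(f"{i}.npy", s)
        zf.write(f"{i}.npy")
```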
ms_ssim.py ADDED
@@ -0,0 +1,200 @@
+ import torch
+ import torch.nn.functional as F
+ from math import exp
+ import numpy as np
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ def gaussian(window_size, sigma):
+     gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
+     return gauss/gauss.sum()
+
+
+ def create_window(window_size, channel=1):
+     _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
+     _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0).to(device)
+     window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
+     return window
+
+ def create_window_3d(window_size, channel=1):
+     _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
+     _2D_window = _1D_window.mm(_1D_window.t())
+     _3D_window = _2D_window.unsqueeze(2) @ (_1D_window.t())
+     window = _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().to(device)
+     return window
+
+
+ def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
+     # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
+     if val_range is None:
+         if torch.max(img1) > 128:
+             max_val = 255
+         else:
+             max_val = 1
+
+         if torch.min(img1) < -0.5:
+             min_val = -1
+         else:
+             min_val = 0
+         L = max_val - min_val
+     else:
+         L = val_range
+
+     padd = 0
+     (_, channel, height, width) = img1.size()
+     if window is None:
+         real_size = min(window_size, height, width)
+         window = create_window(real_size, channel=channel).to(img1.device)
+
+     # mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
+     # mu2 = F.conv2d(img2, window, padding=padd, groups=channel)
+     mu1 = F.conv2d(F.pad(img1, (5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=channel)
+     mu2 = F.conv2d(F.pad(img2, (5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=channel)
+
+     mu1_sq = mu1.pow(2)
+     mu2_sq = mu2.pow(2)
+     mu1_mu2 = mu1 * mu2
+
+     sigma1_sq = F.conv2d(F.pad(img1 * img1, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu1_sq
+     sigma2_sq = F.conv2d(F.pad(img2 * img2, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu2_sq
+     sigma12 = F.conv2d(F.pad(img1 * img2, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu1_mu2
+
+     C1 = (0.01 * L) ** 2
+     C2 = (0.03 * L) ** 2
+
+     v1 = 2.0 * sigma12 + C2
+     v2 = sigma1_sq + sigma2_sq + C2
+     cs = torch.mean(v1 / v2)  # contrast sensitivity
+
+     ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
+
+     if size_average:
+         ret = ssim_map.mean()
+     else:
+         ret = ssim_map.mean(1).mean(1).mean(1)
+
+     if full:
+         return ret, cs
+     return ret
+
+
+ def ssim_matlab(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
+     # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
+     if val_range is None:
+         if torch.max(img1) > 128:
+             max_val = 255
+         else:
+             max_val = 1
+
+         if torch.min(img1) < -0.5:
+             min_val = -1
+         else:
+             min_val = 0
+         L = max_val - min_val
+     else:
+         L = val_range
+
+     padd = 0
+     (_, _, height, width) = img1.size()
+     if window is None:
+         real_size = min(window_size, height, width)
+         window = create_window_3d(real_size, channel=1).to(img1.device)
+         # Channel is set to 1 since we consider color images as volumetric images
+
+     img1 = img1.unsqueeze(1)
+     img2 = img2.unsqueeze(1)
+
+     mu1 = F.conv3d(F.pad(img1, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1)
+     mu2 = F.conv3d(F.pad(img2, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1)
+
+     mu1_sq = mu1.pow(2)
+     mu2_sq = mu2.pow(2)
+     mu1_mu2 = mu1 * mu2
+
+     sigma1_sq = F.conv3d(F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_sq
+     sigma2_sq = F.conv3d(F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu2_sq
+     sigma12 = F.conv3d(F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_mu2
+
+     C1 = (0.01 * L) ** 2
+     C2 = (0.03 * L) ** 2
+
+     v1 = 2.0 * sigma12 + C2
+     v2 = sigma1_sq + sigma2_sq + C2
+     cs = torch.mean(v1 / v2)  # contrast sensitivity
+
+     ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
+
+     if size_average:
+         ret = ssim_map.mean()
+     else:
+         ret = ssim_map.mean(1).mean(1).mean(1)
+
+     if full:
+         return ret, cs
+     return ret
+
+
+ def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=False):
+     device = img1.device
+     weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)
+     levels = weights.size()[0]
+     mssim = []
+     mcs = []
+     for _ in range(levels):
+         sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)
+         mssim.append(sim)
+         mcs.append(cs)
+
+         img1 = F.avg_pool2d(img1, (2, 2))
+         img2 = F.avg_pool2d(img2, (2, 2))
+
+     mssim = torch.stack(mssim)
+     mcs = torch.stack(mcs)
+
+     # Normalize (to avoid NaNs during training unstable models, not compliant with original definition)
+     if normalize:
+         mssim = (mssim + 1) / 2
+         mcs = (mcs + 1) / 2
+
+     pow1 = mcs ** weights
+     pow2 = mssim ** weights
+     # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/
+     output = torch.prod(pow1[:-1] * pow2[-1])
+     return output
+
+
+ # Classes to re-use window
+ class SSIM(torch.nn.Module):
+     def __init__(self, window_size=11, size_average=True, val_range=None):
+         super(SSIM, self).__init__()
+         self.window_size = window_size
+         self.size_average = size_average
+         self.val_range = val_range
+
+         # Assume 3 channel for SSIM
+         self.channel = 3
+         self.window = create_window(window_size, channel=self.channel)
+
+     def forward(self, img1, img2):
+         (_, channel, _, _) = img1.size()
+
+         if channel == self.channel and self.window.dtype == img1.dtype:
+             window = self.window
+         else:
+             window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
+             self.window = window
+             self.channel = channel
+
+         _ssim = ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)
+         dssim = (1 - _ssim) / 2
+         return dssim
+
+ class MSSSIM(torch.nn.Module):
+     def __init__(self, window_size=11, size_average=True, channel=3):
+         super(MSSSIM, self).__init__()
+         self.window_size = window_size
+         self.size_average = size_average
+         self.channel = channel
+
+     def forward(self, img1, img2):
+         return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average)
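Both run scripts use `ms_ssim.py` only through `ssim_matlab`, computed on 32×32 thumbnails to detect near-duplicate frames (score > 0.996) and scene changes (score < 0.2). A small usage sketch, with placeholder image paths:

```python
# Sketch: how the run scripts apply ssim_matlab in practice.
# "frame_0000.png" / "frame_0001.png" are placeholder file names.
import cv2
import numpy as np
import torch
import torch.nn.functional as F
import ms_ssim

def load(path):
    # HWC uint8 BGR -> 1x3xHxW float tensor in [0, 1]
    img = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
    return torch.from_numpy(np.transpose(img, (2, 0, 1))).unsqueeze(0).float() / 255.

a, b = load("frame_0000.png"), load("frame_0001.png")
a_small = F.interpolate(a, (32, 32), mode='bilinear', align_corners=False)
b_small = F.interpolate(b, (32, 32), mode='bilinear', align_corners=False)
score = ms_ssim.ssim_matlab(a_small, b_small)
print(float(score))  # > 0.996: treated as a static frame; < 0.2: treated as a scene change
```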
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ numpy>=1.16, <=1.23.5
+ tqdm>=4.35.0
+ torch>=1.3.0
+ opencv-python>=4.1.2
+ torchvision>=0.7.0
run_axmodel.py ADDED
@@ -0,0 +1,233 @@
+ import os
+ import cv2
+ import time
+ import argparse
+ import numpy as np
+ import axengine as axe
+ import _thread
+ import torch
+ import torch.nn.functional as F
+ import ms_ssim
+
+ from tqdm import tqdm
+ from queue import Queue, Empty
+
+ parser = argparse.ArgumentParser(description='Interpolation for a pair of images')
+ parser.add_argument('--video', dest='video', type=str, default='./demo.mp4')
+ parser.add_argument('--output', dest='output', type=str, default=None)
+ parser.add_argument('--img', dest='img', type=str, default=None)
+ parser.add_argument('--montage', dest='montage', action='store_true', help='montage origin video')
+ parser.add_argument('--model', dest='model', type=str, default=None, help='directory with trained model files')
+ parser.add_argument('--fp16', dest='fp16', action='store_true', help='fp16 mode for faster and more lightweight inference on cards with Tensor Cores')
+ parser.add_argument('--UHD', dest='UHD', action='store_true', help='support 4k video')
+ parser.add_argument('--scale', dest='scale', type=float, default=1.0, help='Try scale=0.5 for 4k video')
+ parser.add_argument('--skip', dest='skip', action='store_true', help='whether to remove static frames before processing')
+ parser.add_argument('--fps', dest='fps', type=int, default=None)
+ parser.add_argument('--png', dest='png', action='store_true', help='whether to vid_out png format vid_outs')
+ parser.add_argument('--ext', dest='ext', type=str, default='mp4', help='vid_out video extension')
+ parser.add_argument('--exp', dest='exp', type=int, default=1)
+ parser.add_argument('--multi', dest='multi', type=int, default=2)
+
+ def read_video(video_path):
+     cap = cv2.VideoCapture(video_path)
+     if not cap.isOpened():
+         raise IOError(f"Cannot open video: {video_path}")
+     try:
+         while True:
+             ret, frame = cap.read()
+             if not ret:
+                 break
+             frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+             yield frame
+     finally:
+         cap.release()
+
+ def clear_write_buffer(user_args, write_buffer, vid_out):
+     cnt = 0
+     while True:
+         item = write_buffer.get()
+         if item is None:
+             break
+         if user_args.png:
+             cv2.imwrite('vid_out/{:0>7d}.png'.format(cnt), item[:, :, ::-1])
+             cnt += 1
+         else:
+             vid_out.write(item[:, :, ::-1])
+
+ def build_read_buffer(user_args, read_buffer, videogen):
+     try:
+         for frame in videogen:
+             if not user_args.img is None:
+                 frame = cv2.imread(os.path.join(user_args.img, frame), cv2.IMREAD_UNCHANGED)[:, :, ::-1].copy()
+             if user_args.montage:
+                 frame = frame[:, left: left + w]
+             read_buffer.put(frame)
+     except:
+         pass
+     read_buffer.put(None)
+
+ def pad_image(img, padding):
+     if(args.fp16):
+         return F.pad(img, padding).half()
+     else:
+         return F.pad(img, padding)
+
+ def run(args):
+     '''axmodel inference'''
+     # model
+     session = axe.InferenceSession(args.model, providers=['AxEngineExecutionProvider'])
+     output_names = [x.name for x in session.get_outputs()]
+     input_name = session.get_inputs()[0].name
+
+     # video
+     videoCapture = cv2.VideoCapture(args.video)
+     fps = videoCapture.get(cv2.CAP_PROP_FPS)
+     tot_frame = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT)
+     videoCapture.release()
+     if args.fps is None:
+         fpsNotAssigned = True
+         args.fps = fps * args.multi
+     else:
+         fpsNotAssigned = False
+     videogen = read_video(args.video)
+     lastframe = next(videogen)
+     fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
+     video_path_wo_ext, ext = os.path.splitext(args.video)
+     print('{}.{}, {} frames in total, {}FPS to {}FPS'.format(video_path_wo_ext, args.ext, tot_frame, fps, args.fps))
+     if args.png == False and fpsNotAssigned == True:
+         print("The audio will be merged after interpolation process")
+     else:
+         print("Will not merge audio because using png or fps flag!")
+
+     #
+     h, w, _ = lastframe.shape
+     vid_out_name = None
+     vid_out = None
+     if args.png:
+         if not os.path.exists('vid_out'):
+             os.mkdir('vid_out')
+     else:
+         if args.output is not None:
+             vid_out_name = args.output
+         else:
+             vid_out_name = '{}_{}X_{}fps.{}'.format(video_path_wo_ext, args.multi, int(np.round(args.fps)), args.ext)
+         vid_out = cv2.VideoWriter(vid_out_name, fourcc, args.fps, (w, h))
+
+     tmp = max(128, int(128 / args.scale))
+     ph = ((h - 1) // tmp + 1) * tmp
+     pw = ((w - 1) // tmp + 1) * tmp
+     #padding = (0, pw - w, 0, ph - h)
+     padding = ((0, 0), (0, 0), (0, ph - h), (0, pw - w))
+     pbar = tqdm(total=tot_frame, ncols=80)
+
+     write_buffer = Queue(maxsize=500)
+     read_buffer = Queue(maxsize=500)
+     _thread.start_new_thread(build_read_buffer, (args, read_buffer, videogen))
+     _thread.start_new_thread(clear_write_buffer, (args, write_buffer, vid_out))
+
+     #device = 'cpu'
+     #I1 = torch.from_numpy(np.transpose(lastframe, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
+     I1 = np.expand_dims(np.transpose(lastframe, (2,0,1)), 0).astype(np.float32) / 255.
+     I1 = np.pad(I1, padding)
+
+     temp = None # save lastframe when processing static frame
+     while True:
+         if temp is not None:
+             frame = temp
+             temp = None
+         else:
+             frame = read_buffer.get()
+         if frame is None:
+             break
+         I0 = I1
+         #I1 = torch.from_numpy(np.transpose(frame, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
+         I1 = np.expand_dims(np.transpose(frame, (2,0,1)), 0).astype(np.float32) / 255.
+         I1 = np.pad(I1, padding)
+
+         I0_small = F.interpolate(torch.from_numpy(I0).float(), (32, 32), mode='bilinear', align_corners=False)
+         I1_small = F.interpolate(torch.from_numpy(I1).float(), (32, 32), mode='bilinear', align_corners=False)
+         ssim = ms_ssim.ssim_matlab(I0_small[:, :3], I1_small[:, :3])
+
+         break_flag = False
+         if ssim > 0.996: #0.996
+             frame = read_buffer.get() # read a new frame
+             if frame is None:
+                 break_flag = True
+                 frame = lastframe
+             else:
+                 temp = frame
+             #I1 = torch.from_numpy(np.transpose(frame, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
+             I1 = np.expand_dims(np.transpose(frame, (2,0,1)), 0).astype(np.float32) / 255.
+             I1 = np.pad(I1, padding)
+
+             #imgs = torch.cat((I0, I1), 1).cpu().numpy()
+             imgs = np.concatenate((I0, I1), axis=1)
+             I1 = session.run(output_names, {input_name: imgs})
+
+             #I1 = torch.from_numpy(I1[-1])
+             I1 = np.array(I1[-1])
+             I1_small = F.interpolate(torch.from_numpy(I1).float(), (32, 32), mode='bilinear', align_corners=False)
+             ssim = ms_ssim.ssim_matlab(I0_small[:, :3], I1_small[:, :3])
+             #frame = (I1[0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w]
+             frame = np.clip(I1[0] * 255, 0, 255).astype(np.uint8).transpose(1, 2, 0)[:h, :w]
+
+         if ssim < 0.2:
+             output = []
+             for i in range(args.multi - 1):
+                 output.append(I0)
+             '''
+             output = []
+             step = 1 / args.multi
+             alpha = 0
+             for i in range(args.multi - 1):
+                 alpha += step
+                 beta = 1-alpha
+                 output.append(torch.from_numpy(np.transpose((cv2.addWeighted(frame[:, :, ::-1], alpha, lastframe[:, :, ::-1], beta, 0)[:, :, ::-1].copy()), (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.)
+             '''
+         else:
+             imgs = np.concatenate((I0, I1), axis=1)
+             output = [session.run(output_names, {input_name: imgs})[-1]]
+
+         if args.montage:
+             write_buffer.put(np.concatenate((lastframe, lastframe), 1))
+             for mid in output:
+                 #mid = (((mid[0] * 255.).byte().cpu().numpy().transpose(1, 2, 0)))
+                 mid = np.clip(mid[0] * 255, 0, 255).astype(np.uint8).transpose(1, 2, 0)
+                 write_buffer.put(np.concatenate((lastframe, mid[:h, :w]), 1))
+         else:
+             write_buffer.put(lastframe)
+             for mid in output:
+                 #mid = (((mid[0] * 255.).byte().cpu().numpy().transpose(1, 2, 0)))
+                 mid = np.clip(mid[0] * 255, 0, 255).astype(np.uint8).transpose(1, 2, 0)
+                 write_buffer.put(mid[:h, :w])
+         pbar.update(1)
+         lastframe = frame
+         if break_flag:
+             break
+
+     if args.montage:
+         write_buffer.put(np.concatenate((lastframe, lastframe), 1))
+     else:
+         write_buffer.put(lastframe)
+     write_buffer.put(None)
+
+     while(not write_buffer.empty()):
+         time.sleep(0.1)
+     pbar.close()
+     if not vid_out is None:
+         vid_out.release()
+
+ if __name__ == '__main__':
+     args = parser.parse_args()
+     if args.exp != 1:
+         args.multi = (2 ** args.exp)
+     assert (not args.video is None or not args.img is None)
+     if args.skip:
+         print("skip flag is abandoned, please refer to issue #207.")
+     if args.UHD and args.scale==1.0:
+         args.scale = 0.5
+     assert args.scale in [0.25, 0.5, 1.0, 2.0, 4.0]
+     if not args.img is None:
+         args.png = True
+
+     run(args)
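Most of `run_axmodel.py` is video plumbing (reader/writer threads, padding, the static-frame heuristic); the actual NPU call is a single `session.run` on two frames stacked along the channel axis. Below is a pared-down sketch of one interpolation step, assuming the same `axengine` API used in the script above and an x2 model whose last output is the middle frame; the image paths are placeholders.

```python
# Sketch: one interpolation step with the axmodel, without the video plumbing.
import cv2
import numpy as np
import axengine as axe

def preprocess(path, ph, pw):
    # HWC uint8 BGR -> 1x3xphxpw float32 in [0, 1], plus the original size
    img = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
    x = np.expand_dims(np.transpose(img, (2, 0, 1)), 0).astype(np.float32) / 255.
    h, w = img.shape[:2]
    return np.pad(x, ((0, 0), (0, 0), (0, ph - h), (0, pw - w))), h, w

session = axe.InferenceSession("./rife_x2_720p.axmodel",
                               providers=['AxEngineExecutionProvider'])
input_name = session.get_inputs()[0].name
output_names = [o.name for o in session.get_outputs()]

# Pad each 720p frame up to a multiple of 128 (768x1280), as run_axmodel.py does.
I0, h, w = preprocess("frame_a.png", 768, 1280)
I1, _, _ = preprocess("frame_b.png", 768, 1280)

# The model consumes the two frames stacked along the channel axis (1x6xHxW).
mid = session.run(output_names, {input_name: np.concatenate((I0, I1), axis=1)})[-1]
mid = np.clip(mid[0] * 255, 0, 255).astype(np.uint8).transpose(1, 2, 0)[:h, :w]
cv2.imwrite("frame_mid.png", cv2.cvtColor(mid, cv2.COLOR_RGB2BGR))
```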
run_onnx.py ADDED
@@ -0,0 +1,233 @@
+ import os
+ import cv2
+ import time
+ import argparse
+ import numpy as np
+ import onnxruntime as ort
+ import _thread
+ import torch
+ import torch.nn.functional as F
+ import ms_ssim
+
+ from tqdm import tqdm
+ from queue import Queue, Empty
+
+ parser = argparse.ArgumentParser(description='Interpolation for a pair of images')
+ parser.add_argument('--video', dest='video', type=str, default='./demo.mp4')
+ parser.add_argument('--output', dest='output', type=str, default=None)
+ parser.add_argument('--img', dest='img', type=str, default=None)
+ parser.add_argument('--montage', dest='montage', action='store_true', help='montage origin video')
+ parser.add_argument('--model', dest='model', type=str, default=None, help='directory with trained model files')
+ parser.add_argument('--fp16', dest='fp16', action='store_true', help='fp16 mode for faster and more lightweight inference on cards with Tensor Cores')
+ parser.add_argument('--UHD', dest='UHD', action='store_true', help='support 4k video')
+ parser.add_argument('--scale', dest='scale', type=float, default=1.0, help='Try scale=0.5 for 4k video')
+ parser.add_argument('--skip', dest='skip', action='store_true', help='whether to remove static frames before processing')
+ parser.add_argument('--fps', dest='fps', type=int, default=None)
+ parser.add_argument('--png', dest='png', action='store_true', help='whether to vid_out png format vid_outs')
+ parser.add_argument('--ext', dest='ext', type=str, default='mp4', help='vid_out video extension')
+ parser.add_argument('--exp', dest='exp', type=int, default=1)
+ parser.add_argument('--multi', dest='multi', type=int, default=2)
+
+ def read_video(video_path):
+     cap = cv2.VideoCapture(video_path)
+     if not cap.isOpened():
+         raise IOError(f"Cannot open video: {video_path}")
+     try:
+         while True:
+             ret, frame = cap.read()
+             if not ret:
+                 break
+             frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+             yield frame
+     finally:
+         cap.release()
+
+ def clear_write_buffer(user_args, write_buffer, vid_out):
+     cnt = 0
+     while True:
+         item = write_buffer.get()
+         if item is None:
+             break
+         if user_args.png:
+             cv2.imwrite('vid_out/{:0>7d}.png'.format(cnt), item[:, :, ::-1])
+             cnt += 1
+         else:
+             vid_out.write(item[:, :, ::-1])
+
+ def build_read_buffer(user_args, read_buffer, videogen):
+     try:
+         for frame in videogen:
+             if not user_args.img is None:
+                 frame = cv2.imread(os.path.join(user_args.img, frame), cv2.IMREAD_UNCHANGED)[:, :, ::-1].copy()
+             if user_args.montage:
+                 frame = frame[:, left: left + w]
+             read_buffer.put(frame)
+     except:
+         pass
+     read_buffer.put(None)
+
+ def pad_image(img, padding):
+     if(args.fp16):
+         return F.pad(img, padding).half()
+     else:
+         return F.pad(img, padding)
+
+ def run(args):
+     '''onnx inference'''
+     # model
+     session = ort.InferenceSession(args.model, providers=['CPUExecutionProvider'])
+     output_names = [x.name for x in session.get_outputs()]
+     input_name = session.get_inputs()[0].name
+
+     # video
+     videoCapture = cv2.VideoCapture(args.video)
+     fps = videoCapture.get(cv2.CAP_PROP_FPS)
+     tot_frame = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT)
+     videoCapture.release()
+     if args.fps is None:
+         fpsNotAssigned = True
+         args.fps = fps * args.multi
+     else:
+         fpsNotAssigned = False
+     videogen = read_video(args.video)
+     lastframe = next(videogen)
+     fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
+     video_path_wo_ext, ext = os.path.splitext(args.video)
+     print('{}.{}, {} frames in total, {}FPS to {}FPS'.format(video_path_wo_ext, args.ext, tot_frame, fps, args.fps))
+     if args.png == False and fpsNotAssigned == True:
+         print("The audio will be merged after interpolation process")
+     else:
+         print("Will not merge audio because using png or fps flag!")
+
+     #
+     h, w, _ = lastframe.shape
+     vid_out_name = None
+     vid_out = None
+     if args.png:
+         if not os.path.exists('vid_out'):
+             os.mkdir('vid_out')
+     else:
+         if args.output is not None:
+             vid_out_name = args.output
+         else:
+             vid_out_name = '{}_{}X_{}fps.{}'.format(video_path_wo_ext, args.multi, int(np.round(args.fps)), args.ext)
+         vid_out = cv2.VideoWriter(vid_out_name, fourcc, args.fps, (w, h))
+
+     tmp = max(128, int(128 / args.scale))
+     ph = ((h - 1) // tmp + 1) * tmp
+     pw = ((w - 1) // tmp + 1) * tmp
+     #padding = (0, pw - w, 0, ph - h)
+     padding = ((0, 0), (0, 0), (0, ph - h), (0, pw - w))
+     pbar = tqdm(total=tot_frame, ncols=80)
+
+     write_buffer = Queue(maxsize=500)
+     read_buffer = Queue(maxsize=500)
+     _thread.start_new_thread(build_read_buffer, (args, read_buffer, videogen))
+     _thread.start_new_thread(clear_write_buffer, (args, write_buffer, vid_out))
+
+     #device = 'cpu'
+     #I1 = torch.from_numpy(np.transpose(lastframe, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
+     I1 = np.expand_dims(np.transpose(lastframe, (2,0,1)), 0).astype(np.float32) / 255.
+     I1 = np.pad(I1, padding)
+
+     temp = None # save lastframe when processing static frame
+     while True:
+         if temp is not None:
+             frame = temp
+             temp = None
+         else:
+             frame = read_buffer.get()
+         if frame is None:
+             break
+         I0 = I1
+         #I1 = torch.from_numpy(np.transpose(frame, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
+         I1 = np.expand_dims(np.transpose(frame, (2,0,1)), 0).astype(np.float32) / 255.
+         I1 = np.pad(I1, padding)
+
+         I0_small = F.interpolate(torch.from_numpy(I0).float(), (32, 32), mode='bilinear', align_corners=False)
+         I1_small = F.interpolate(torch.from_numpy(I1).float(), (32, 32), mode='bilinear', align_corners=False)
+         ssim = ms_ssim.ssim_matlab(I0_small[:, :3], I1_small[:, :3])
+
+         break_flag = False
+         if ssim > 0.996: #0.996
+             frame = read_buffer.get() # read a new frame
+             if frame is None:
+                 break_flag = True
+                 frame = lastframe
+             else:
+                 temp = frame
+             #I1 = torch.from_numpy(np.transpose(frame, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
+             I1 = np.expand_dims(np.transpose(frame, (2,0,1)), 0).astype(np.float32) / 255.
+             I1 = np.pad(I1, padding)
+
+             #imgs = torch.cat((I0, I1), 1).cpu().numpy()
+             imgs = np.concatenate((I0, I1), axis=1)
+             I1 = session.run(output_names, {input_name: imgs})
+
+             #I1 = torch.from_numpy(I1[-1])
+             I1 = np.array(I1[-1])
+             I1_small = F.interpolate(torch.from_numpy(I1).float(), (32, 32), mode='bilinear', align_corners=False)
+             ssim = ms_ssim.ssim_matlab(I0_small[:, :3], I1_small[:, :3])
+             #frame = (I1[0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w]
+             frame = np.clip(I1[0] * 255, 0, 255).astype(np.uint8).transpose(1, 2, 0)[:h, :w]
+
+         if ssim < 0.2:
+             output = []
+             for i in range(args.multi - 1):
+                 output.append(I0)
+             '''
+             output = []
+             step = 1 / args.multi
+             alpha = 0
+             for i in range(args.multi - 1):
+                 alpha += step
+                 beta = 1-alpha
+                 output.append(torch.from_numpy(np.transpose((cv2.addWeighted(frame[:, :, ::-1], alpha, lastframe[:, :, ::-1], beta, 0)[:, :, ::-1].copy()), (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.)
+             '''
+         else:
+             imgs = np.concatenate((I0, I1), axis=1)
+             output = [session.run(output_names, {input_name: imgs})[-1]]
+
+         if args.montage:
+             write_buffer.put(np.concatenate((lastframe, lastframe), 1))
+             for mid in output:
+                 #mid = (((mid[0] * 255.).byte().cpu().numpy().transpose(1, 2, 0)))
+                 mid = np.clip(mid[0] * 255, 0, 255).astype(np.uint8).transpose(1, 2, 0)
+                 write_buffer.put(np.concatenate((lastframe, mid[:h, :w]), 1))
+         else:
+             write_buffer.put(lastframe)
+             for mid in output:
+                 #mid = (((mid[0] * 255.).byte().cpu().numpy().transpose(1, 2, 0)))
+                 mid = np.clip(mid[0] * 255, 0, 255).astype(np.uint8).transpose(1, 2, 0)
+                 write_buffer.put(mid[:h, :w])
+         pbar.update(1)
+         lastframe = frame
+         if break_flag:
+             break
+
+     if args.montage:
+         write_buffer.put(np.concatenate((lastframe, lastframe), 1))
+     else:
+         write_buffer.put(lastframe)
+     write_buffer.put(None)
+
+     while(not write_buffer.empty()):
+         time.sleep(0.1)
+     pbar.close()
+     if not vid_out is None:
+         vid_out.release()
+
+ if __name__ == '__main__':
+     args = parser.parse_args()
+     if args.exp != 1:
+         args.multi = (2 ** args.exp)
+     assert (not args.video is None or not args.img is None)
+     if args.skip:
+         print("skip flag is abandoned, please refer to issue #207.")
+     if args.UHD and args.scale==1.0:
+         args.scale = 0.5
+     assert args.scale in [0.25, 0.5, 1.0, 2.0, 4.0]
+     if not args.img is None:
+         args.png = True
+
+     run(args)
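`run_onnx.py` is the CPU reference path: identical to `run_axmodel.py` except that it runs the FP32 ONNX model through onnxruntime. That makes a quick parity check of the w16a16 quantization straightforward. The sketch below feeds the same 1×6×H×W input to both runtimes and scores the outputs with `ssim_matlab`; it assumes both runtimes are installed in the same environment and uses a placeholder `.npy` sample prepared exactly as in the scripts above.

```python
# Sketch: compare one interpolated frame from the FP32 ONNX model and the
# w16a16 axmodel, using ms_ssim.py as the metric. "pair_0.npy" is a
# placeholder 1x6xHxW float32 sample.
import numpy as np
import torch
import onnxruntime as ort
import axengine as axe
import ms_ssim

imgs = np.load("pair_0.npy")

onnx_sess = ort.InferenceSession("./rife_x2_720p.onnx", providers=['CPUExecutionProvider'])
ax_sess = axe.InferenceSession("./rife_x2_720p.axmodel", providers=['AxEngineExecutionProvider'])

ref = onnx_sess.run([o.name for o in onnx_sess.get_outputs()],
                    {onnx_sess.get_inputs()[0].name: imgs})[-1]
out = ax_sess.run([o.name for o in ax_sess.get_outputs()],
                  {ax_sess.get_inputs()[0].name: imgs})[-1]

# ssim_matlab expects NCHW tensors; compare the two interpolated RGB frames.
score = ms_ssim.ssim_matlab(torch.from_numpy(ref).float(), torch.from_numpy(out).float())
print("max abs diff:", float(np.abs(ref - out).max()), "SSIM vs FP32:", float(score))
```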