Commit 952c41a (parent: fdab143): huggingface dataset

README.md CHANGED

# VideoGrain: Modulating Space-Time Attention for Multi-Grained Video Editing (ICLR 2025)

## [<a href="https://knightyxp.github.io/VideoGrain_project_page/" target="_blank">Project Page</a>]

[](https://arxiv.org/abs/2502.17258)
[](https://huggingface.co/papers/2502.17258)
[](https://knightyxp.github.io/VideoGrain_project_page/)
[](https://drive.google.com/file/d/1dzdvLnXWeMFR3CE2Ew0Bs06vyFSvnGXA/view?usp=drive_link)
[](https://www.youtube.com/watch?v=XEM4Pex7F9E)
[](https://huggingface.co/datasets/XiangpengYang/VideoGrain-dataset)

## Introduction

VideoGrain is a zero-shot method for class-level, instance-level, and part-level video editing.

- **Multi-grained Video Editing**
  - class-level: editing objects within the same class (previous SOTA methods were limited to this level)
  - instance-level: editing each individual instance into a distinct object
  - part-level: adding new objects or modifying existing attributes at the part level
- **Training-Free**
  - does not require any training or fine-tuning
- **One-Prompt Multi-Region Control & Deep Investigation of Cross/Self-Attention** (a conceptual sketch of the modulation follows this list)
  - modulating cross-attention for multi-region control (visualizations available)
  - modulating self-attention for feature decoupling (clustering visualizations available)

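To make the last point concrete, here is a minimal conceptual sketch of region-wise cross-attention modulation. It is **not** the code used in this repository; the tensor shapes, the `alpha` strength, and all names are illustrative assumptions only:

```python
import torch

def modulate_cross_attention(attn_probs, region_masks, local_token_ids, alpha=0.3):
    """
    Conceptual sketch (not the VideoGrain implementation): strengthen each local
    prompt's cross-attention inside its layout mask and weaken it outside, so
    every region is driven by its own text tokens.
      attn_probs:      (heads, hw, tokens) softmaxed cross-attention maps
      region_masks:    list of (hw,) boolean masks, one per edit region
      local_token_ids: list of token-index lists, one per region's local prompt
    """
    out = attn_probs.clone()
    for mask, token_ids in zip(region_masks, local_token_ids):
        for t in token_ids:
            out[:, mask, t] += alpha   # amplify text-to-region weight inside the region
            out[:, ~mask, t] -= alpha  # suppress the same tokens outside the region
    return out.clamp(min=0)
```

Roughly the same idea applies to self-attention: keep feature interactions within a region and suppress query-key pairs that cross region boundaries, which is what decouples per-region features.
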
<table class="center" border="1" cellspacing="0" cellpadding="5">
|
| 26 |
+
<tr>
|
| 27 |
+
<td colspan="2" style="text-align:center;"><img src="assets/teaser/class_level.gif" style="width:250px; height:auto;"></td>
|
| 28 |
+
<td colspan="2" style="text-align:center;"><img src="assets/teaser/instance_part.gif" style="width:250px; height:auto;"></td>
|
| 29 |
+
<td colspan="2" style="text-align:center;"><img src="assets/teaser/2monkeys.gif" style="width:250px; height:auto;"></td>
|
| 30 |
+
</tr>
|
| 31 |
+
<tr>
|
| 32 |
+
<!-- <td colspan="1" style="text-align:right; width:125px;"> </td> -->
|
| 33 |
+
<td colspan="2" style="text-align:right; width:250px;"> class level</td>
|
| 34 |
+
<td colspan="1" style="text-align:center; width:125px;">instance level</td>
|
| 35 |
+
<td colspan="1" style="text-align:center; width:125px;">part level</td>
|
| 36 |
+
<td colspan="2" style="text-align:center; width:250px;">animal instances</td>
|
| 37 |
+
</tr>
|
| 38 |
+
|
| 39 |
+
<tr>
|
| 40 |
+
<td colspan="2" style="text-align:center;"><img src="assets/teaser/2cats.gif" style="width:250px; height:auto;"></td>
|
| 41 |
+
<td colspan="2" style="text-align:center;"><img src="assets/teaser/soap-box.gif" style="width:250px; height:auto;"></td>
|
| 42 |
+
<td colspan="2" style="text-align:center;"><img src="assets/teaser/man-text-message.gif" style="width:250px; height:auto;"></td>
|
| 43 |
+
</tr>
|
| 44 |
+
<tr>
|
| 45 |
+
<td colspan="2" style="text-align:center; width:250px;">animal instances</td>
|
| 46 |
+
<td colspan="2" style="text-align:center; width:250px;">human instances</td>
|
| 47 |
+
<td colspan="2" style="text-align:center; width:250px;">part-level modification</td>
|
| 48 |
+
</tr>
|
| 49 |
+
</table>
|
| 50 |
+
|
| 51 |
+
## 📀 Demo Video
|
| 52 |
+
<!-- [](https://www.youtube.com/watch?v=XEM4Pex7F9E "Demo Video of VideoGrain") -->
|
| 53 |
+
https://github.com/user-attachments/assets/9bec92fc-21bd-4459-86fa-62404d8762bf
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
## 📣 News
|
| 57 |
+
* **[2025/2/25]** Our VideoGrain is posted and recommended by Gradio on [LinkedIn](https://www.linkedin.com/posts/gradio_just-dropped-videograin-a-new-zero-shot-activity-7300094635094261760-hoiE) and [Twitter](https://x.com/Gradio/status/1894328911154028566), and recommended by [AK](https://x.com/_akhaliq/status/1894254599223017622).
|
| 58 |
+
* **[2025/2/25]** Our VideoGrain is submited by AK to [HuggingFace-daily papers](https://huggingface.co/papers?date=2025-02-25), and rank [#1](https://huggingface.co/papers/2502.17258) paper of that day.
|
| 59 |
+
* **[2025/2/24]** We release our paper on [arxiv](https://arxiv.org/abs/2502.17258), we also release [code](https://github.com/knightyxp/VideoGrain) and [full-data](https://drive.google.com/file/d/1dzdvLnXWeMFR3CE2Ew0Bs06vyFSvnGXA/view?usp=drive_link) on google drive.
|
| 60 |
+
* **[2025/1/23]** Our paper is accepted to [ICLR2025](https://openreview.net/forum?id=SSslAtcPB6)! Welcome to **watch** 👀 this repository for the latest updates.
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
## 🍻 Setup Environment

Our method is tested with CUDA 12.1, fp16 mixed precision (via accelerate), and xformers on a single L40 GPU.

```bash
# Step 1: Create and activate Conda environment
conda create -n videograin python==3.10
conda activate videograin

# Step 2: Install PyTorch, CUDA and Xformers
conda install pytorch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 pytorch-cuda=12.1 -c pytorch -c nvidia
pip install --pre -U xformers==0.0.27

# Step 3: Install additional dependencies with pip
pip install -r requirements.txt
```

`xformers` is recommended to reduce memory usage and running time.

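Optionally, a quick sanity check that the environment sees the GPU and xformers (a convenience sketch, not part of the repository):

```python
# Sanity-check sketch: confirm PyTorch sees CUDA and xformers imports cleanly.
import torch
import xformers

print("torch:", torch.__version__, "| cuda available:", torch.cuda.is_available())
print("xformers:", xformers.__version__)
```
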
You may download all the base model checkpoints using the following bash command:
```bash
## download sd 1.5, controlnet depth/pose v10/v11
bash download_all.sh
```

<details><summary>Click for ControlNet annotator weights (if you cannot access HuggingFace)</summary>

You can download all the annotator checkpoints (such as DW-Pose, depth_zoe, depth_midas, and OpenPose; around 4 GB in total) from [Baidu](https://pan.baidu.com/s/1sgBFLFkdTCDTn4oqHjGb9A?pwd=pdm5) or [Google Drive](https://drive.google.com/file/d/1qOsmWshnFMMr8x1HteaTViTSQLh_4rle/view?usp=drive_link).
Then extract them into `./annotator/ckpts`.

</details>

## ⚡️ Prepare all the data

### Full VideoGrain Data
We provide all the video data and layout masks used in VideoGrain at the following link. Please download and unzip the data, then put it in the `./data` root directory:
```bash
gdown --fuzzy https://drive.google.com/file/d/1dzdvLnXWeMFR3CE2Ew0Bs06vyFSvnGXA/view?usp=drive_link
tar -zxvf videograin_data.tar.gz
```

### Customize Your Own Data
**Prepare video frames**

If the input video is an mp4 file, use the following command to split it into frames:
```bash
python image_util/sample_video2frames.py --video_path 'your video path' --output_dir './data/video_name/video_name'
```

**Prepare layout masks**

We segment videos using our ReLER lab's [SAM-Track](https://github.com/z-x-yang/Segment-and-Track-Anything). We suggest using `app.py` in SAM-Track (Gradio mode) to manually select the regions in the video you want to edit. We also provide a script, `image_util/process_webui_mask.py`, to convert masks from the SAM-Track layout to the VideoGrain layout; the expected directory structure is sketched below.

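For orientation, the data layout consumed by the dataset loader (see `video_diffusion/data/dataset.py` further below, which reads frames and per-region mask folders of `00000.png`-style files) looks roughly as follows; the video name and mask-folder names are illustrative assumptions, not fixed by the repository:

```
data
└── video_name
    ├── video_name          # extracted RGB frames: 00000.png, 00001.png, ...
    └── layout_masks
        ├── 1               # binary masks for the 1st edit region: 00000.png, ...
        ├── 2               # masks for the 2nd edit region
        └── 3               # masks for the 3rd edit region
```
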
## 🔥🔥🔥 VideoGrain Editing

### 🎨 Inference
You can reproduce the instance + part level results in our teaser by running:

```bash
bash test.sh
# or
CUDA_VISIBLE_DEVICES=0 accelerate launch test.py --config config/part_level/adding_new_object/run_two_man/spider_polar_sunglass.yaml
```

For the other instance/part/class results in the VideoGrain project page or teaser, we provide all the data (video frames and layout masks) and the corresponding configs to reproduce them; see [🚀Multi-Grained Video Editing](#multi-grained-video-editing-results).

<details><summary>The results are saved at `./result`. (Click for directory structure)</summary>

```
result
├── run_two_man
│   ├── control                 # control condition
│   ├── infer_samples
│   ├── input                   # the input video frames
│   ├── masked_video.mp4        # check whether edit regions are accurately covered
│   ├── sample
│   ├── step_0                  # result image folder
│   ├── step_0.mp4              # result video
│   ├── source_video.mp4        # the input video
│   ├── visualization_denoise   # cross-attention weights
│   ├── sd_study                # clustered inversion features
```
</details>

## Editing guidance for YOUR Video
### 🔛 Prepare your config

VideoGrain is a training-free framework. To run VideoGrain on your video, modify `./config/demo_config.yaml` based on your needs (a hedged example sketch follows this list):

1. Replace the pretrained model path and ControlNet path in your config. You can set the control_type to `dwpose`, `depth_zoe`, or `depth` (midas).
2. Prepare your video frames and layout masks (edit regions) using SAM-Track or SAM2, and point the dataset config to them.
3. Change the `prompt` and extract each `local prompt` from the editing prompt. The local prompt order should match the layout mask order.
4. You can change the flatten resolution with 1->64, 2->16, 4->8. (Commonly, flattening at 64 works best.)
5. To ensure temporal consistency, you can set `use_pnp: True` and `inject_step: 5` or `10`. (Note: more than 10 PnP steps will hurt multi-region editing.)
6. If you want to visualize the cross-attention weights, set `vis_cross_attn: True`.
7. If you want to cluster the DDIM inversion spatial-temporal video features, set `cluster_inversion_feature: True`.

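The sketch below pulls the options above into one place. It only illustrates the overall shape of a config: start from `./config/demo_config.yaml` or one of the shipped configs, and treat every key and value not named in the list above as an assumption to be adapted:

```yaml
# Illustrative config sketch (assumed structure; adapt from ./config/demo_config.yaml)
pretrained_model_path: "./ckpt/stable-diffusion-v1-5"
logdir: ./result/my_video/my_edit

dataset_config:
  path: ./data/my_video/my_video            # extracted RGB frames
  layout_mask_dir: ./data/my_video/layout_masks
  layout_mask_order: ['1', '2']             # same order as the local prompts
  prompt: "a man and a woman are running"   # source prompt describing the video
  n_sample_frame: 16
  sampling_rate: 1

control_type: dwpose                        # or depth_zoe / depth (midas)
use_pnp: True
inject_step: 5                              # keep <= 10 for multi-region editing
vis_cross_attn: False                       # True to dump cross-attention weight maps
cluster_inversion_feature: False            # True to cluster DDIM-inversion features
```
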
### 😍 Editing your video

```bash
bash test.sh
# or
CUDA_VISIBLE_DEVICES=0 accelerate launch test.py --config /path/to/the/config
```

## 🚀Multi-Grained Video Editing Results

### 🌈 Multi-Grained Definition
You can reproduce the multi-grained definition results using the following commands:
```bash
CUDA_VISIBLE_DEVICES=0 accelerate launch test.py --config config/class_level/running_two_man/man2spider.yaml  # class level
# config/instance_level/running_two_man/4cls_spider_polar.yaml  # instance level
# config/part_level/adding_new_object/run_two_man/spider_polar_sunglass.yaml  # part level
```
<table class="center">
<tr>
  <td width=25% style="text-align:center;">source video</td>
  <td width=25% style="text-align:center;">class level</td>
  <td width=25% style="text-align:center;">instance level</td>
  <td width=25% style="text-align:center;">part level</td>
</tr>
<tr>
  <td><img src="./assets/teaser/run_two_man.gif"></td>
  <td><img src="./assets/teaser/class_level_0.gif"></td>
  <td><img src="./assets/teaser/instance_level.gif"></td>
  <td><img src="./assets/teaser/part_level.gif"></td>
</tr>
</table>

## 💃 Instance-level Video Editing
You can get instance-level video editing results using the following command:
```bash
CUDA_VISIBLE_DEVICES=0 accelerate launch test.py --config config/instance_level/running_two_man/running_3cls_iron_spider.yaml
```

<table class="center">
<tr>
  <td width=50% style="text-align:center;">running_two_man/3cls_iron_spider.yaml</td>
  <td width=50% style="text-align:center;">2_monkeys/2cls_teddy_bear_koala.yaml</td>
</tr>
<tr>
  <td><img src="assets/instance-level/left_iron_right_spider.gif"></td>
  <td><img src="assets/instance-level/teddy_koala.gif"></td>
</tr>
<tr>
  <td width=50% style="text-align:center;">badminton/2cls_wonder_woman_spiderman.yaml</td>
  <td width=50% style="text-align:center;">soap-box/soap-box.yaml</td>
</tr>
<tr>
  <td><img src="assets/instance-level/badminton.gif"></td>
  <td><img src="assets/teaser/soap-box.gif"></td>
</tr>
<tr>
  <td width=50% style="text-align:center;">2_cats/4cls_panda_vs_poddle.yaml</td>
  <td width=50% style="text-align:center;">2_cars/left_firetruck_right_bus.yaml</td>
</tr>
<tr>
  <td><img src="assets/instance-level/panda_vs_poddle.gif"></td>
  <td><img src="assets/instance-level/2cars.gif"></td>
</tr>
</table>

## 🕺 Part-level Video Editing
You can get part-level video editing results using the following command:
```bash
CUDA_VISIBLE_DEVICES=0 accelerate launch test.py --config config/part_level/modification/man_text_message/blue_shirt.yaml
```

<table class="center">
<tr>
  <td><img src="assets/part-level/man_text_message.gif"></td>
  <td><img src="assets/part-level/blue-shirt.gif"></td>
  <td><img src="assets/part-level/black-suit.gif"></td>
  <td><img src="assets/part-level/cat_flower.gif"></td>
  <td><img src="assets/part-level/ginger_head.gif"></td>
  <td><img src="assets/part-level/ginger_body.gif"></td>
</tr>
<tr>
  <td width=15% style="text-align:center;">source video</td>
  <td width=15% style="text-align:center;">blue shirt</td>
  <td width=15% style="text-align:center;">black suit</td>
  <td width=15% style="text-align:center;">source video</td>
  <td width=15% style="text-align:center;">ginger head</td>
  <td width=15% style="text-align:center;">ginger body</td>
</tr>
<tr>
  <td><img src="assets/part-level/man_text_message.gif"></td>
  <td><img src="assets/part-level/superman.gif"></td>
  <td><img src="assets/part-level/superman+cap.gif"></td>
  <td><img src="assets/part-level/spin-ball.gif"></td>
  <td><img src="assets/part-level/superman_spin.gif"></td>
  <td><img src="assets/part-level/super_sunglass_spin.gif"></td>
</tr>
<tr>
  <td width=15% style="text-align:center;">source video</td>
  <td width=15% style="text-align:center;">superman</td>
  <td width=15% style="text-align:center;">superman + cap</td>
  <td width=15% style="text-align:center;">source video</td>
  <td width=15% style="text-align:center;">superman</td>
  <td width=15% style="text-align:center;">superman + sunglasses</td>
</tr>
</table>

## 🥳 Class-level Video Editing
You can get class-level video editing results using the following command:
```bash
CUDA_VISIBLE_DEVICES=0 accelerate launch test.py --config config/class_level/wolf/wolf.yaml
```

<table class="center">
<tr>
  <td><img src="assets/class-level/wolf.gif"></td>
  <td><img src="assets/class-level/pig.gif"></td>
  <td><img src="assets/class-level/husky.gif"></td>
  <td><img src="assets/class-level/bear.gif"></td>
  <td><img src="assets/class-level/tiger.gif"></td>
</tr>
<tr>
  <td width=15% style="text-align:center;">input</td>
  <td width=15% style="text-align:center;">pig</td>
  <td width=15% style="text-align:center;">husky</td>
  <td width=15% style="text-align:center;">bear</td>
  <td width=15% style="text-align:center;">tiger</td>
</tr>
<tr>
  <td><img src="assets/class-level/tennis.gif"></td>
  <td><img src="assets/class-level/tennis_1cls.gif"></td>
  <td><img src="assets/class-level/tennis_3cls.gif"></td>
  <td><img src="assets/class-level/car-1.gif"></td>
  <td><img src="assets/class-level/posche.gif"></td>
</tr>
<tr>
  <td width=15% style="text-align:center;">input</td>
  <td width=15% style="text-align:center;">iron man</td>
  <td width=15% style="text-align:center;">Batman + snow court + iced wall</td>
  <td width=15% style="text-align:center;">input</td>
  <td width=15% style="text-align:center;">Porsche</td>
</tr>
</table>

## Solely Edit Specific Subjects, Keeping the Background Unchanged
You can edit only specific subjects, keeping the background unchanged, using the following commands:
```bash
CUDA_VISIBLE_DEVICES=0 accelerate launch test.py --config config/instance_level/soely_edit/only_left.yaml
# --config config/instance_level/soely_edit/only_right.yaml
# --config config/instance_level/soely_edit/joint_edit.yaml
```

<table class="center">
<tr>
  <td><img src="assets/soely_edit/input.gif"></td>
  <td><img src="assets/soely_edit/left.gif"></td>
  <td><img src="assets/soely_edit/right.gif"></td>
  <td><img src="assets/soely_edit/joint.gif"></td>
</tr>
<tr>
  <td width=25% style="text-align:center;">source video</td>
  <td width=25% style="text-align:center;">left→Iron Man</td>
  <td width=25% style="text-align:center;">right→Spiderman</td>
  <td width=25% style="text-align:center;">joint edit</td>
</tr>
</table>

## 🔍 Visualize Cross-Attention Weights
You can visualize the cross-attention weights during editing using the following command:
```bash
# set vis_cross_attn: True in your config
CUDA_VISIBLE_DEVICES=0 accelerate launch test.py --config config/instance_level/running_two_man/3cls_spider_polar_vis_weight.yaml
```

<table class="center">
<tr>
  <td><img src="assets/soely_edit/input.gif"></td>
  <td><img src="assets/vis/edit.gif"></td>
  <td><img src="assets/vis/spiderman_weight.gif"></td>
  <td><img src="assets/vis/bear_weight.gif"></td>
  <td><img src="assets/vis/cherry_weight.gif"></td>
</tr>
<tr>
  <td width=20% style="text-align:center;">source video</td>
  <td width=20% style="text-align:center;">left→spiderman, right→polar bear, trees→cherry blossoms</td>
  <td width=20% style="text-align:center;">spiderman weight</td>
  <td width=20% style="text-align:center;">bear weight</td>
  <td width=20% style="text-align:center;">cherry weight</td>
</tr>
</table>

## ✏️ Citation
If you find this project helpful, please feel free to leave a star ⭐️⭐️⭐️ and cite our paper:
```bibtex
@article{yang2025videograin,
  title={VideoGrain: Modulating Space-Time Attention for Multi-grained Video Editing},
  author={Yang, Xiangpeng and Zhu, Linchao and Fan, Hehe and Yang, Yi},
  journal={arXiv preprint arXiv:2502.17258},
  year={2025}
}
```

## 📞 Contact Authors
Xiangpeng Yang [@knightyxp](https://github.com/knightyxp), email: [email protected]/[email protected]

## ✨ Acknowledgements

- This code builds on [diffusers](https://github.com/huggingface/diffusers) and [FateZero](https://github.com/ChenyangQiQi/FateZero). Thanks for open-sourcing!
- We would like to thank [AK (@_akhaliq)](https://x.com/_akhaliq/status/1894254599223017622) and the Gradio team for the recommendation!

## ⭐️ Star History

[](https://star-history.com/#knightyxp/VideoGrain&Date)

config/part_level/adding_new_object/run_two_man/spider_polar_sunglass.yaml CHANGED

```diff
@@ -1,4 +1,4 @@
-pretrained_model_path: "ckpt/stable-diffusion-v1-5"
+pretrained_model_path: "./ckpt/stable-diffusion-v1-5"
 logdir: ./result/part_level/run_two_man/left2spider_right2polar-sunglasses
 
 dataset_config:
```

image.png ADDED

video_diffusion/data/__pycache__/dataset.cpython-310.pyc CHANGED

Binary files a/video_diffusion/data/__pycache__/dataset.cpython-310.pyc and b/video_diffusion/data/__pycache__/dataset.cpython-310.pyc differ

video_diffusion/data/dataset.py CHANGED

The previous version of `ImageSequenceDataset` accepted raw `.mp4` inputs directly: it converted the input video and each uploaded layout-mask video into PNG frame directories on the fly, kept the upload order as `layout_mask_order` (represented only by indices), and used the first converted mask directory to index the mask frames. The removed conversion helper was:

```python
@staticmethod
def mp4_to_png(video_source: str, target_dir: str):
    """
    Convert an mp4 video to a sequence of PNG images, storing them in target_dir.
    target_dir is a fixed path, e.g. './input-video' or './layout_masks/1'.
    """
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir, exist_ok=True)

    reader = imageio.get_reader(video_source)
    for i, im in enumerate(reader):
        path = os.path.join(target_dir, f"{i:05d}.png")
        cv2.imwrite(path, im[:, :, ::-1])  # imageio yields RGB; cv2 expects BGR
    return target_dir
```

The updated version below drops this on-the-fly conversion (and the `imageio`/`shutil` imports), takes explicit `layout_mask_dir`/`layout_mask_order` arguments again, restores the class-image regularization options (`class_data_root`, `class_prompt_ids`), and reads pre-extracted PNG frames and masks from disk. Only the hunks shown in the diff are reproduced; unchanged lines are omitted.

```python
import os

import numpy as np
from PIL import Image
from einops import rearrange

# ... (unchanged import lines omitted in the diff view) ...

from .transform import short_size_scale, random_crop, center_crop, offset_crop
from ..common.image_util import IMAGE_EXTENSION
import cv2

class ImageSequenceDataset(Dataset):
    def __init__(
        self,
        path: str,
        layout_mask_dir: str,
        layout_mask_order: list,
        prompt_ids: torch.Tensor,
        prompt: str,
        start_sample_frame: int = 0,
        n_sample_frame: int = 8,
        sampling_rate: int = 1,
        stride: int = -1,  # only used during tuning to sample a long video
        image_mode: str = "RGB",
        image_size: int = 512,
        crop: str = "center",

        class_data_root: str = None,
        class_prompt_ids: torch.Tensor = None,

        offset: dict = {
            "left": 0,
            "right": 0,
            # ... (unchanged line omitted in the diff view) ...
            "bottom": 0
        },
        **args
    ):
        self.path = path
        self.images = self.get_image_list(path)

        self.layout_mask_dir = layout_mask_dir
        self.layout_mask_order = list(layout_mask_order)

        layout_mask_dir0 = os.path.join(self.layout_mask_dir, self.layout_mask_order[0])
        self.masks_index = self.get_image_list(layout_mask_dir0)

        self.n_images = len(self.images)
        self.offset = offset
        self.start_sample_frame = start_sample_frame
        if n_sample_frame < 0:
            n_sample_frame = len(self.images)
        self.n_sample_frame = n_sample_frame
        # local sampling rate from the video
        self.sampling_rate = sampling_rate

        self.sequence_length = (n_sample_frame - 1) * sampling_rate + 1
        if self.n_images < self.sequence_length:
            raise ValueError(f"self.n_images {self.n_images} < self.sequence_length {self.sequence_length}: Required number of frames {self.sequence_length} larger than total frames in the dataset {self.n_images}")

        # During tuning, if the video is too long, we sample it every self.stride frames globally
        self.stride = stride if stride > 0 else (self.n_images + 1)
        self.video_len = (self.n_images - self.sequence_length) // self.stride + 1

        self.image_mode = image_mode
        # ... (unchanged lines omitted in the diff view) ...
            "random": random_crop,
        }
        if crop not in crop_methods:
            raise ValueError
        self.crop = crop_methods[crop]

        self.prompt = prompt
        self.prompt_ids = prompt_ids
        # Negative prompt for regularization to avoid overfitting during one-shot tuning
        if class_data_root is not None:
            self.class_data_root = Path(class_data_root)
            self.class_images_path = sorted(list(self.class_data_root.iterdir()))
            self.num_class_images = len(self.class_images_path)
            self.class_prompt_ids = class_prompt_ids

    def __len__(self):
        max_len = (self.n_images - self.sequence_length) // self.stride + 1

        if hasattr(self, 'num_class_images'):
            max_len = max(max_len, self.num_class_images)

        return max_len

    def __getitem__(self, index):
        return_batch = {}
        frame_indices = self.get_frame_indices(index % self.video_len)
        frames = [self.load_frame(i) for i in frame_indices]
        frames = self.transform(frames)

        layout_ = []
        for layout_name in self.layout_mask_order:
            frame_indices = self.get_frame_indices(index % self.video_len)
            layout_mask_dir = os.path.join(self.layout_mask_dir, layout_name)
            mask = [self._read_mask(layout_mask_dir, i) for i in frame_indices]
            masks = np.stack(mask)
            layout_.append(masks)
        layout_ = np.stack(layout_)
        merged_masks = []
        for i in range(int(self.n_sample_frame)):
            merged_mask_frame = np.sum(layout_[:, i, :, :, :], axis=0)
            merged_mask_frame = (merged_mask_frame > 0).astype(np.uint8)
            merged_masks.append(merged_mask_frame)
        masks = rearrange(np.stack(merged_masks), "f c h w -> c f h w")
        masks = torch.from_numpy(masks).half()

        layouts = rearrange(layout_, "s f c h w -> f s c h w")
        layouts = torch.from_numpy(layouts).half()

        return_batch.update(
            {
                "images": frames,
                "masks": masks,
                "layouts": layouts,
                "prompt_ids": self.prompt_ids,
            }
        )

        if hasattr(self, 'class_data_root'):
            class_index = index % (self.num_class_images - self.n_sample_frame)
            class_indices = self.get_class_indices(class_index)
            frames = [self.load_class_frame(i) for i in class_indices]
            return_batch["class_images"] = self.tensorize_frames(frames)
            return_batch["class_prompt_ids"] = self.class_prompt_ids
        return return_batch

    def transform(self, frames):
        # ... (unchanged lines omitted in the diff view) ...
        frames = rearrange(np.stack(frames), "f h w c -> c f h w")
        return torch.from_numpy(frames).div(255) * 2 - 1

    def _read_mask(self, mask_path, index: int):
        mask_path = os.path.join(mask_path, f"{index:05d}.png")

        # read mask with cv2
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        mask = (mask > 0).astype(np.uint8)
        # Determine dynamic destination size (1/8 of the frame, i.e. the latent resolution)
        height, width = mask.shape
        dest_size = (width // 8, height // 8)
        # Resize using nearest neighbor interpolation
        mask = cv2.resize(mask, dest_size, interpolation=cv2.INTER_NEAREST)  # alternative: cv2.INTER_CUBIC
        mask = mask[np.newaxis, ...]
        return mask

    def load_frame(self, index):
        image_path = os.path.join(self.path, self.images[index])
        return Image.open(image_path).convert(self.image_mode)

    # ... (unchanged lines omitted in the diff view) ...

    def get_class_indices(self, index):
        frame_start = index
        return (frame_start + i for i in range(self.n_sample_frame))

    @staticmethod
    def get_image_list(path):
        images = []
        for file in sorted(os.listdir(path)):
            if file.endswith(IMAGE_EXTENSION):
                images.append(file)
        return images
```
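
For reference, a hedged usage sketch of the updated dataset class follows. The paths, the mask-folder names, and the tokenizer handling are assumptions for illustration, not taken from the repository's configs or training code:

```python
# Hypothetical usage sketch of ImageSequenceDataset (paths and names are illustrative).
from transformers import CLIPTokenizer
from video_diffusion.data.dataset import ImageSequenceDataset

tokenizer = CLIPTokenizer.from_pretrained("./ckpt/stable-diffusion-v1-5", subfolder="tokenizer")
prompt = "a man and a woman are running"
prompt_ids = tokenizer(prompt, padding="max_length",
                       max_length=tokenizer.model_max_length,
                       return_tensors="pt").input_ids[0]

dataset = ImageSequenceDataset(
    path="./data/run_two_man/run_two_man",              # extracted RGB frames
    layout_mask_dir="./data/run_two_man/layout_masks",
    layout_mask_order=["1", "2"],                        # one folder per edit region, in local-prompt order
    prompt=prompt,
    prompt_ids=prompt_ids,
    n_sample_frame=16,
    sampling_rate=1,
)

sample = dataset[0]
# sample["images"]:  (c, f, h, w) frames scaled to [-1, 1]
# sample["masks"]:   union of all region masks, (1, f, h/8, w/8)
# sample["layouts"]: per-region masks, (f, regions, 1, h/8, w/8)
```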