Upload 14 files

- data/__init__.py +2 -0
- data/configs/example.yaml +45 -0
- data/data_utils.py +177 -0
- data/dataset_base.py +620 -0
- data/dataset_info.py +39 -0
- data/distributed_iterable_dataset.py +58 -0
- data/interleave_datasets/__init__.py +5 -0
- data/interleave_datasets/edit_dataset.py +72 -0
- data/interleave_datasets/interleave_t2i_dataset.py +212 -0
- data/parquet_utils.py +90 -0
- data/t2i_dataset.py +128 -0
- data/transforms.py +287 -0
- data/video_utils.py +165 -0
- data/vlm_dataset.py +195 -0
    	
data/__init__.py
ADDED

@@ -0,0 +1,2 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
data/configs/example.yaml
ADDED

@@ -0,0 +1,45 @@
+t2i_pretrain:
+  dataset_names:
+  - t2i
+  image_transform_args:
+    image_stride: 16
+    max_image_size: 1024
+    min_image_size: 512
+  is_mandatory: true
+  num_used_data: # The sum should be larger than NUM_GPUS x NUM_WORKERS
+  - 10
+  weight: 1
+
+unified_edit:
+  dataset_names:
+  - seedxedit_multi
+  image_transform_args:
+    image_stride: 16
+    max_image_size: 1024
+    min_image_size: 512
+  vit_image_transform_args:
+    image_stride: 14
+    max_image_size: 518
+    min_image_size: 224
+  is_mandatory: false
+  num_used_data:
+  - 10
+  weight: 1
+
+vlm_sft:
+  dataset_names:
+  - llava_ov
+  image_transform_args:
+    image_stride: 14
+    max_image_size: 980
+    min_image_size: 378
+    max_pixels: 2_007_040
+  frame_sampler_args:
+    max_num_frames: 12
+    min_num_frames: 8
+  is_mandatory: true
+  shuffle_lines: True
+  shuffle_seed: 0
+  num_used_data:
+  - 1000
+  weight: 1
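As a rough usage sketch (not part of the commit; assumes PyYAML, path illustrative): each top-level key names one dataset group, and its value is the argument dict consumed by PackedDataset.build_datasets in data/dataset_base.py below.

# Illustrative only, not part of the commit.
import yaml

with open("data/configs/example.yaml") as f:
    grouped = yaml.safe_load(f)

print(grouped["t2i_pretrain"]["weight"])         # 1
print(grouped["vlm_sft"]["frame_sampler_args"])  # {'max_num_frames': 12, 'min_num_frames': 8}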
    	
data/data_utils.py
ADDED

@@ -0,0 +1,177 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+
+
+import math
+import random
+from PIL import Image
+
+import torch
+from torch.nn.attention.flex_attention import or_masks, and_masks
+
+
+def create_sparse_mask(document_lens, split_lens, attn_modes, device):
+    def causal_mask(b, h, q_idx, kv_idx):
+        return q_idx >= kv_idx
+
+    def full_and_noise_mask(b, h, q_idx, kv_idx):
+        return (full_and_noise_seq_id[q_idx] == full_and_noise_seq_id[kv_idx]) & (full_and_noise_seq_id[q_idx] >= 0)
+
+    def remove_noise_mask(b, h, q_idx, kv_idx):
+        return (~((noise_seq_id[kv_idx] >= 0) & (noise_seq_id[q_idx] != noise_seq_id[kv_idx])))
+
+    def sample_mask(b, h, q_idx, kv_idx):
+        return document_id[q_idx] == document_id[kv_idx]
+
+    full_and_noise_tmp = []
+    noise_tmp = []
+
+    for i, (length, mode) in enumerate(zip(split_lens, attn_modes)):
+        value = i if mode in ['full', 'noise'] else -1
+        full_and_noise_tmp.extend([value] * length)
+        value_noise = i if mode == 'noise' else -1
+        noise_tmp.extend([value_noise] * length)
+
+    full_and_noise_seq_id = torch.Tensor(full_and_noise_tmp).to(device)
+    noise_seq_id = torch.Tensor(noise_tmp).to(device)
+
+    document_id = torch.cat([torch.full((l,), i) for i, l in enumerate(document_lens, start=1)]).to(device)
+
+    return and_masks(or_masks(causal_mask, full_and_noise_mask), remove_noise_mask, sample_mask)
+
+
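create_sparse_mask returns a flex-attention mask_mod predicate rather than a dense mask. A minimal sketch of consuming it (assumes PyTorch >= 2.5; all lengths illustrative):

# Illustrative only, not part of the commit.
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# One packed 128-token document: a 64-token causal split then a 64-token full split.
mask_mod = create_sparse_mask(
    document_lens=[128], split_lens=[64, 64],
    attn_modes=['causal', 'full'], device='cpu',
)
block_mask = create_block_mask(mask_mod, B=None, H=None, Q_LEN=128, KV_LEN=128, device='cpu')

q = k = v = torch.randn(1, 4, 128, 64)  # (batch, heads, seq, head_dim)
out = flex_attention(q, k, v, block_mask=block_mask)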
+def patchify(image, patch_size):
+    p = patch_size
+    c, h, w = image.shape
+    assert h % p == 0 and w % p == 0
+    image = image.reshape(c, h // p, p, w // p, p)
+    image = torch.einsum("chpwq->hwpqc", image)
+    image = image.reshape(-1, p**2 * c)
+    return image
+
+
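For intuition: a 3x32x32 image with patch_size=16 gives a 2x2 patch grid, i.e. 4 tokens of 16*16*3 = 768 values each.

# Illustrative only, not part of the commit.
img = torch.randn(3, 32, 32)
tokens = patchify(img, patch_size=16)
print(tokens.shape)  # torch.Size([4, 768]) -- (num_patches, p*p*c)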
+def get_flattened_position_ids_extrapolate(img_h, img_w, patch_size, max_num_patches_per_side):
+    num_patches_h, num_patches_w = img_h // patch_size, img_w // patch_size
+    coords_h = torch.arange(0, num_patches_h)
+    coords_w = torch.arange(0, num_patches_w)
+    pos_ids = (coords_h[:, None] * max_num_patches_per_side + coords_w).flatten()
+    return pos_ids
+
+
+def get_flattened_position_ids_interpolate(img_h, img_w, patch_size, max_num_patches_per_side):
+    num_patches_h, num_patches_w = img_h // patch_size, img_w // patch_size
+    boundaries = torch.arange(1 / max_num_patches_per_side, 1.0, 1 / max_num_patches_per_side)
+    fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / num_patches_h)
+    fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / num_patches_w)
+    bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
+    bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
+    pos_ids = (bucket_coords_h[:, None] * max_num_patches_per_side + bucket_coords_w).flatten()
+    return pos_ids
+
+
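The two variants place a small patch grid differently inside the max_num_patches_per_side table: extrapolate keeps unit steps from the top-left corner, while interpolate stretches the grid indices across the whole table. A toy comparison (a 2x2 grid in a 4x4 table):

# Illustrative only, not part of the commit.
ids_ex = get_flattened_position_ids_extrapolate(28, 28, 14, max_num_patches_per_side=4)
ids_in = get_flattened_position_ids_interpolate(28, 28, 14, max_num_patches_per_side=4)
print(ids_ex)  # tensor([0, 1, 4, 5])  -- packed into the top-left corner
print(ids_in)  # tensor([0, 2, 8, 10]) -- spread over the full table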
+def prepare_attention_mask_per_sample(split_lens, attn_modes, device="cpu"):
+    """
+    split_lens: A list of ints. Each int indicates the length of a split within
+        a sample, where each sample contains multiple splits with different attn modes.
+    attn_modes: whether to use full attn in each split.
+    """
+    sample_len = sum(split_lens)
+    attention_mask = torch.zeros((sample_len, sample_len), dtype=torch.bool, device=device)
+
+    csum = 0
+    for s, attn_mode in zip(split_lens, attn_modes):
+        assert attn_mode in ['causal', 'full', 'noise']
+        if attn_mode == "causal":
+            attention_mask[csum:csum + s, csum:csum + s] = torch.ones((s, s), device=device).tril()
+            attention_mask[csum:csum + s, :csum] = 1
+        else:
+            attention_mask[csum:csum + s, csum:csum + s] = torch.ones((s, s))
+            attention_mask[csum:csum + s, :csum] = 1
+        csum += s
+
+    csum = 0
+    for s, attn_mode in zip(split_lens, attn_modes):
+        if attn_mode == "noise":
+            attention_mask[:, csum : csum + s] = torch.zeros((sample_len, s))
+            attention_mask[csum : csum + s, csum : csum + s] = torch.ones((s, s))
+        csum += s
+
+    attention_mask = torch.zeros_like(attention_mask, dtype=torch.float).masked_fill_(
+        ~attention_mask, float("-inf")
+    )
+
+    return attention_mask
+
+
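A toy call shows the resulting additive mask: 0.0 marks allowed positions, -inf blocked ones. With a 2-token causal split followed by a 2-token full split:

# Illustrative only, not part of the commit.
mask = prepare_attention_mask_per_sample([2, 2], ['causal', 'full'])
print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., 0.],
#         [0., 0., 0., 0.]])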
+def split_integer_exp_decay(S, ng_sample_decay=1.0):
+    if ng_sample_decay == 1.0:
+        N = random.randint(1, S)
+    else:
+        base = (1 - ng_sample_decay) / (1 - math.pow(ng_sample_decay, S))
+        p = [base * math.pow(ng_sample_decay, i) for i in range(S)]
+        N = random.choices(list(range(1, S + 1)), p, k=1)[0]
+    cumsum = [0] + sorted(random.sample(range(1, S), N - 1)) + [S]
+    result = [cumsum[i+1] - cumsum[i] for i in range(len(cumsum) - 1)]
+    return result, cumsum
+
+
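This samples a random composition of S: first the number of parts N (uniform, or geometrically decaying when ng_sample_decay < 1), then N positive parts summing to S.

# Illustrative only, not part of the commit.
random.seed(0)
parts, cuts = split_integer_exp_decay(10)
print(sum(parts))  # 10 -- the parts always sum to S; cuts are the split points starting at 0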
+def pil_img2rgb(image):
+    if image.mode == "RGBA" or image.info.get("transparency", None) is not None:
+        image = image.convert("RGBA")
+        white = Image.new(mode="RGB", size=image.size, color=(255, 255, 255))
+        white.paste(image, mask=image.split()[3])
+        image = white
+    else:
+        image = image.convert("RGB")
+
+    return image
+
+
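Unlike a plain convert("RGB"), this composites transparency onto a white background:

# Illustrative only, not part of the commit.
rgba = Image.new("RGBA", (4, 4), (255, 0, 0, 0))  # fully transparent red
rgb = pil_img2rgb(rgba)
print(rgb.mode, rgb.getpixel((0, 0)))  # RGB (255, 255, 255)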
+def add_special_tokens(tokenizer):
+    all_special_tokens = []
+    for k, v in tokenizer.special_tokens_map.items():
+        if isinstance(v, str):
+            all_special_tokens.append(v)
+        elif isinstance(v, list):
+            all_special_tokens += v
+
+    new_tokens = []
+
+    if '<|im_start|>' not in all_special_tokens:
+        new_tokens.append('<|im_start|>')
+
+    if '<|im_end|>' not in all_special_tokens:
+        new_tokens.append('<|im_end|>')
+
+    if '<|vision_start|>' not in all_special_tokens:
+        new_tokens.append('<|vision_start|>')
+
+    if '<|vision_end|>' not in all_special_tokens:
+        new_tokens.append('<|vision_end|>')
+
+    num_new_tokens = tokenizer.add_tokens(new_tokens)
+    bos_token_id = tokenizer.convert_tokens_to_ids('<|im_start|>')
+    eos_token_id = tokenizer.convert_tokens_to_ids('<|im_end|>')
+    start_of_image = tokenizer.convert_tokens_to_ids('<|vision_start|>')
+    end_of_image = tokenizer.convert_tokens_to_ids('<|vision_end|>')
+
+    new_token_ids = dict(
+        bos_token_id=bos_token_id,
+        eos_token_id=eos_token_id,
+        start_of_image=start_of_image,
+        end_of_image=end_of_image,
+    )
+
+    return tokenizer, new_token_ids, num_new_tokens
+
+
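A usage sketch (assumes a Hugging Face tokenizer; the checkpoint name is illustrative). The returned new_token_ids dict is what PackedDataset below receives as special_tokens:

# Illustrative only, not part of the commit.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")  # any HF tokenizer
tokenizer, new_token_ids, num_new_tokens = add_special_tokens(tokenizer)
print(sorted(new_token_ids))  # ['bos_token_id', 'end_of_image', 'eos_token_id', 'start_of_image']
# If num_new_tokens > 0, the model's input/output embeddings must be resized to match.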
+def len2weight(x, loss_reduction='square'):
+    if x == 0:
+        return x
+    if loss_reduction == 'token':
+        return 1
+    if loss_reduction == 'sample':
+        return 1 / x
+    if loss_reduction == 'square':
+        return 1 / (x ** 0.5)
+    raise NotImplementedError(loss_reduction)
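len2weight maps a text-split length to a per-token CE weight; with the default 'square' reduction a split of length n gets 1/sqrt(n) per token:

# Illustrative only, not part of the commit.
print([len2weight(n) for n in (1, 4, 16)])     # [1.0, 0.5, 0.25]
print(len2weight(8, loss_reduction='sample'))  # 0.125 -- every sample contributes equally overall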
    	
data/dataset_base.py
ADDED

@@ -0,0 +1,620 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+
+
+import random
+import json
+
+import numpy as np
+import torch
+
+from .data_utils import (
+    get_flattened_position_ids_interpolate,
+    get_flattened_position_ids_extrapolate,
+    len2weight,
+    patchify,
+    prepare_attention_mask_per_sample,
+)
+from .dataset_info import DATASET_INFO, DATASET_REGISTRY
+from .transforms import ImageTransform
+from .video_utils import FrameSampler
+
+
+class DataConfig:
+    def __init__(
+        self,
+        grouped_datasets,
+        text_cond_dropout_prob=0.1,
+        vit_cond_dropout_prob=0.4,
+        vae_cond_dropout_prob=0.1,
+        vae_image_downsample=16,
+        max_latent_size=32,
+        vit_patch_size=14,
+        max_num_patch_per_side=70,
+    ):
+        self.grouped_datasets = grouped_datasets
+        self.text_cond_dropout_prob = text_cond_dropout_prob
+        self.vit_cond_dropout_prob = vit_cond_dropout_prob
+        self.vit_patch_size = vit_patch_size
+        self.max_num_patch_per_side = max_num_patch_per_side
+        self.vae_cond_dropout_prob = vae_cond_dropout_prob
+        self.vae_image_downsample = vae_image_downsample
+        self.max_latent_size = max_latent_size
+
+
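Connecting this to data/configs/example.yaml (sketch, not part of the commit): the parsed YAML dict is passed through as grouped_datasets, and the dropout/size fields keep their defaults unless overridden.

# Illustrative only; `grouped` is the parsed example.yaml dict from the earlier snippet.
data_config = DataConfig(grouped_datasets=grouped)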
+class PackedDataset(torch.utils.data.IterableDataset):
+    def __init__(
+        self,
+        data_config,
+        tokenizer,
+        special_tokens,
+        local_rank,
+        world_size,
+        num_workers,
+        expected_num_tokens=32768,
+        max_num_tokens_per_sample=16384,
+        max_num_tokens=36864,
+        prefer_buffer_before=16384,
+        max_buffer_size=50,
+        interpolate_pos=False,
+        use_flex=False,
+        data_status=None,
+    ):
+        super().__init__()
+        self.expected_num_tokens = expected_num_tokens
+        self.max_num_tokens_per_sample = max_num_tokens_per_sample
+        self.prefer_buffer_before = prefer_buffer_before
+        self.max_num_tokens = max_num_tokens
+        self.max_buffer_size = max_buffer_size
+        self.tokenizer = tokenizer
+        self.local_rank = local_rank
+        self.world_size = world_size
+        self.num_workers = num_workers
+        self.use_flex = use_flex
+        for k, v in special_tokens.items():
+            setattr(self, k, v)
+
+        grouped_datasets, is_mandatory, grouped_weights = self.build_datasets(
+            data_config.grouped_datasets, data_status
+        )
+        self.grouped_datasets = grouped_datasets
+        self.dataset_iters = [iter(dataset) for dataset in grouped_datasets]
+        self.is_mandatory = is_mandatory
+        self.grouped_weights = grouped_weights
+        self.data_config = data_config
+        self.interpolate_pos = interpolate_pos
+        if self.interpolate_pos:
+            self.get_flattened_position_ids = get_flattened_position_ids_interpolate
+        else:
+            self.get_flattened_position_ids = get_flattened_position_ids_extrapolate
+
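A single-process construction sketch (illustrative arguments), reusing add_special_tokens from data/data_utils.py:

# Illustrative only, not part of the commit.
dataset = PackedDataset(
    data_config,
    tokenizer=tokenizer,
    special_tokens=new_token_ids,  # becomes self.bos_token_id, self.start_of_image, ...
    local_rank=0, world_size=1, num_workers=1,
)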
+    def build_datasets(self, datasets_metainfo, data_status):
+        datasets = []
+        is_mandatory = []
+        grouped_weights = []
+        for grouped_dataset_name, dataset_args in datasets_metainfo.items():
+            is_mandatory.append(dataset_args.pop('is_mandatory', False))
+            grouped_weights.append(dataset_args.pop('weight', 0.0))
+
+            if 'frame_sampler_args' in dataset_args.keys():
+                frame_sampler = FrameSampler(**dataset_args.pop('frame_sampler_args'))
+                dataset_args['frame_sampler'] = frame_sampler
+            if 'image_transform_args' in dataset_args.keys():
+                transform = ImageTransform(**dataset_args.pop('image_transform_args'))
+                dataset_args['transform'] = transform
+            if 'vit_image_transform_args' in dataset_args.keys():
+                vit_transform = ImageTransform(**dataset_args.pop('vit_image_transform_args'))
+                dataset_args['vit_transform'] = vit_transform
+
+            assert 'dataset_names' in dataset_args.keys()
+            dataset_names = dataset_args.pop('dataset_names')
+            dataset_args['data_dir_list'] = []
+            for item in dataset_names:
+                if self.local_rank == 0:
+                    print(f'Preparing Dataset {grouped_dataset_name}/{item}')
+                meta_info = DATASET_INFO[grouped_dataset_name][item]
+                dataset_args['data_dir_list'].append(meta_info['data_dir'])
+
+                if "parquet_info_path" in meta_info.keys():
+                    if 'parquet_info' not in dataset_args.keys():
+                        dataset_args['parquet_info'] = {}
+                    with open(meta_info['parquet_info_path'], 'r') as f:
+                        parquet_info = json.load(f)
+                    dataset_args['parquet_info'].update(parquet_info)
+
+                if 'json_dir' in meta_info.keys():
+                    # parquet/tar with json
+                    if 'json_dir_list' not in dataset_args.keys():
+                        dataset_args['json_dir_list'] = [meta_info['json_dir']]
+                    else:
+                        dataset_args['json_dir_list'].append(meta_info['json_dir'])
+
+                if 'jsonl_path' in meta_info.keys():
+                    # jsonl with jpeg
+                    if 'jsonl_path_list' not in dataset_args.keys():
+                        dataset_args['jsonl_path_list'] = [meta_info['jsonl_path']]
+                    else:
+                        dataset_args['jsonl_path_list'].append(meta_info['jsonl_path'])
+
+            resume_data_status = dataset_args.pop('resume_data_status', True)
+            if data_status is not None and grouped_dataset_name in data_status.keys() and resume_data_status:
+                data_status_per_group = data_status[grouped_dataset_name]
+            else:
+                data_status_per_group = None
+            dataset = DATASET_REGISTRY[grouped_dataset_name](
+                dataset_name=grouped_dataset_name,
+                tokenizer=self.tokenizer,
+                local_rank=self.local_rank,
+                world_size=self.world_size,
+                num_workers=self.num_workers,
+                data_status=data_status_per_group,
+                **dataset_args
+            )
+            datasets.append(dataset)
+
+        return datasets, is_mandatory, grouped_weights
+
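build_datasets only reads a handful of keys per entry, so a DATASET_INFO record presumably looks roughly like this (hypothetical; data/dataset_info.py is not expanded in this view):

# Hypothetical shape inferred from the lookups above -- not the actual file contents.
DATASET_INFO = {
    't2i_pretrain': {
        't2i': {
            'data_dir': '/path/to/shards',              # required
            'parquet_info_path': '/path/to/info.json',  # optional
            # 'json_dir' / 'jsonl_path' select the other storage layouts
        },
    },
}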
+    def set_epoch(self, seed):
+        for dataset in self.grouped_datasets:
+            dataset.set_epoch(seed)
+
+    def set_sequence_status(self):
+        sequence_status = dict(
+            curr                        = 0,
+            sample_lens                 = list(),
+            packed_position_ids         = list(),
+            nested_attention_masks      = list(),
+            split_lens                  = list(),
+            attn_modes                  = list(),
+            packed_text_ids             = list(),
+            packed_text_indexes         = list(),
+            packed_label_ids            = list(),
+            ce_loss_indexes             = list(),
+            ce_loss_weights             = list(),
+            vae_image_tensors           = list(),
+            packed_latent_position_ids  = list(),
+            vae_latent_shapes           = list(),
+            packed_vae_token_indexes    = list(),
+            packed_timesteps            = list(),
+            mse_loss_indexes            = list(),
+            packed_vit_tokens           = list(),
+            vit_token_seqlens           = list(),
+            packed_vit_position_ids     = list(),
+            packed_vit_token_indexes    = list(),
+        )
+        return sequence_status
+
+    def to_tensor(self, sequence_status):
+        data = dict(
+            sequence_length=sum(sequence_status['sample_lens']),
+            sample_lens=sequence_status['sample_lens'],
+            packed_text_ids=torch.tensor(sequence_status['packed_text_ids']),
+            packed_text_indexes=torch.tensor(sequence_status['packed_text_indexes']),
+            packed_position_ids=torch.tensor(sequence_status['packed_position_ids']),
+        )
+        if not self.use_flex:
+            data['nested_attention_masks'] = sequence_status['nested_attention_masks']
+        else:
+            sequence_len = data['sequence_length']
+            pad_len = self.max_num_tokens - sequence_len
+            data['split_lens'] = sequence_status['split_lens'] + [pad_len]
+            data['attn_modes'] = sequence_status['attn_modes'] + ['causal']
+            data['sample_lens'] += [pad_len]
+
+        # if the model has a convnet vae (e.g., as visual tokenizer)
+        if len(sequence_status['vae_image_tensors']) > 0:
+            image_tensors = sequence_status.pop('vae_image_tensors')
+            image_sizes = [item.shape for item in image_tensors]
+            max_image_size = [max(item) for item in list(zip(*image_sizes))]
+            padded_images = torch.zeros(size=(len(image_tensors), *max_image_size))
+            for i, image_tensor in enumerate(image_tensors):
+                padded_images[i, :, :image_tensor.shape[1], :image_tensor.shape[2]] = image_tensor
+
+            data['padded_images'] = padded_images
+            data['patchified_vae_latent_shapes'] = sequence_status['vae_latent_shapes']
+            data['packed_latent_position_ids'] = torch.cat(sequence_status['packed_latent_position_ids'], dim=0)
+            data['packed_vae_token_indexes'] = torch.tensor(sequence_status['packed_vae_token_indexes'])
+
+        # if the model has a vit (e.g., as visual tokenizer)
+        if len(sequence_status['packed_vit_tokens']) > 0:
+            data['packed_vit_tokens'] = torch.cat(sequence_status['packed_vit_tokens'], dim=0)
+            data['packed_vit_position_ids'] = torch.cat(sequence_status['packed_vit_position_ids'], dim=0)
+            data['packed_vit_token_indexes'] = torch.tensor(sequence_status['packed_vit_token_indexes'])
+            data['vit_token_seqlens'] = torch.tensor(sequence_status['vit_token_seqlens'])
+
+        # if the model is required to perform visual generation
+        if len(sequence_status['packed_timesteps']) > 0:
+            data['packed_timesteps'] = torch.tensor(sequence_status['packed_timesteps'])
+            data['mse_loss_indexes'] = torch.tensor(sequence_status['mse_loss_indexes'])
+
+        # if the model is required to perform text generation
+        if len(sequence_status['packed_label_ids']) > 0:
+            data['packed_label_ids'] = torch.tensor(sequence_status['packed_label_ids'])
+            data['ce_loss_indexes'] = torch.tensor(sequence_status['ce_loss_indexes'])
+            data['ce_loss_weights'] = torch.tensor(sequence_status['ce_loss_weights'])
+
+        return data
+
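The yielded dict is keyed by modality, so a training step can branch on what is present (sketch, continuing the construction above):

# Illustrative only, not part of the commit.
batch = next(iter(dataset))
print(batch['sequence_length'], len(batch['sample_lens']))
if 'padded_images' in batch:    # VAE inputs were packed
    print(batch['padded_images'].shape)
if 'ce_loss_indexes' in batch:  # text-loss positions were packed
    print(batch['packed_label_ids'].shape)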
+    def __iter__(self):
+        total_weights = sum(self.grouped_weights)
+        assert total_weights > 0.0
+        group_cumprobs = [sum(self.grouped_weights[:i + 1]) / total_weights
+                          for i in range(len(self.grouped_weights))]
+        sequence_status = self.set_sequence_status()
+        batch_data_indexes = []
+
+        buffer = []
+        while True:
+            # Ensure at least one sample from each group
+            if sequence_status['curr'] == 0:
+                for group_index, group_iter in enumerate(self.dataset_iters):
+                    if self.is_mandatory[group_index]:
+                        while True:
+                            sample = next(group_iter)
+                            # if a sample is too long, skip it
+                            num_tokens = sample['num_tokens'] + 2 * len(sample['sequence_plan'])
+                            if num_tokens < self.max_num_tokens_per_sample:
+                                sequence_status = self.pack_sequence(sample, sequence_status)
+                                batch_data_indexes.append(sample['data_indexes'])
+                                break
+                            else:
+                                print(f"skip a sample with length {num_tokens}")
+                                continue
+
+            if sequence_status['curr'] < self.prefer_buffer_before and len(buffer) > 0:
+                sample = buffer.pop(0)
+                sample_from_buffer = True
+            else:
+                # sample normally across all groups
+                n = random.random()
+                group_index = 0
+                for i, cumprob in enumerate(group_cumprobs):
+                    if n < cumprob:
+                        group_index = i
+                        break
+                sample = next(self.dataset_iters[group_index])
+                sample_from_buffer = False
+
+            # if a sample is too long, skip it
+            num_tokens = sample['num_tokens'] + 2 * len(sample['sequence_plan'])
+            if num_tokens > self.max_num_tokens_per_sample:
+                print(f"skip a sample with length {num_tokens}")
+                continue
+
+            if sequence_status['curr'] + num_tokens > self.max_num_tokens:
+                if len(buffer) < self.max_buffer_size and not sample_from_buffer:
+                    buffer.append(sample)
+                else:
+                    print(f"Yielding data with length {sum(sequence_status['sample_lens'])}")
+                    data = self.to_tensor(sequence_status)
+                    data['batch_data_indexes'] = batch_data_indexes
+                    yield data
+                    sequence_status = self.set_sequence_status()
+                    batch_data_indexes = []
+                continue
+
+            sequence_status = self.pack_sequence(sample, sequence_status)
+            batch_data_indexes.append(sample['data_indexes'])
+
+            if sequence_status['curr'] >= self.expected_num_tokens:
+                data = self.to_tensor(sequence_status)
+                data['batch_data_indexes'] = batch_data_indexes
+                yield data
+                sequence_status = self.set_sequence_status()
+                batch_data_indexes = []
+
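Group selection is a cumulative-probability draw over the configured weights; e.g. weights [1, 1, 2] give cut points [0.25, 0.5, 1.0], so the last group is picked about half the time (buffered samples take priority while the packed length is below prefer_buffer_before). The same logic in isolation:

# Illustrative only, not part of the commit.
weights = [1, 1, 2]
total = sum(weights)
cumprobs = [sum(weights[:i + 1]) / total for i in range(len(weights))]
print(cumprobs)  # [0.25, 0.5, 1.0]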
+    def pack_sequence(self, sample, sequence_status):
+        image_tensor_list = sample['image_tensor_list']
+        text_ids_list = sample['text_ids_list']
+        sequence_plan = sample['sequence_plan']
+
+        split_lens, attn_modes = list(), list()
+        curr = sequence_status['curr']
+        curr_rope_id = 0
+        sample_lens = 0
+
+        for item in sequence_plan:
+            split_start = item.get('split_start', True)
+            if split_start:
+                curr_split_len = 0
+
+            if item['type'] == 'text':
+                text_ids = text_ids_list.pop(0)
+                if item['enable_cfg'] == 1 and random.random() < self.data_config.text_cond_dropout_prob:
+                    continue
+
+                shifted_text_ids = [self.bos_token_id] + text_ids
+                sequence_status['packed_text_ids'].extend(shifted_text_ids)
+                sequence_status['packed_text_indexes'].extend(range(curr, curr + len(shifted_text_ids)))
+                if item['loss'] == 1:
+                    sequence_status['ce_loss_indexes'].extend(range(curr, curr + len(shifted_text_ids)))
+                    sequence_status['ce_loss_weights'].extend(
+                        [len2weight(len(shifted_text_ids))] * len(shifted_text_ids)
+                    )
+                    sequence_status['packed_label_ids'].extend(text_ids + [self.eos_token_id])
+                curr += len(shifted_text_ids)
+                curr_split_len += len(shifted_text_ids)
+
+                # add a <|im_end|> token
+                sequence_status['packed_text_ids'].append(self.eos_token_id)
+                sequence_status['packed_text_indexes'].append(curr)
+                if item['special_token_loss'] == 1:  # <|im_end|> may have loss
+                    sequence_status['ce_loss_indexes'].append(curr)
+                    sequence_status['ce_loss_weights'].append(1.0)
+                    sequence_status['packed_label_ids'].append(item['special_token_label'])
+                curr += 1
+                curr_split_len += 1
+
+                # update sequence status
+                attn_modes.append("causal")
+                sequence_status['packed_position_ids'].extend(range(curr_rope_id, curr_rope_id + curr_split_len))
+                curr_rope_id += curr_split_len
+
+            elif item['type'] == 'vit_image':
+                image_tensor = image_tensor_list.pop(0)
+                if item['enable_cfg'] == 1 and random.random() < self.data_config.vit_cond_dropout_prob:
+                    curr_rope_id += 1
+                    continue
+
+                # add a <|startofimage|> token
+                sequence_status['packed_text_ids'].append(self.start_of_image)
+                sequence_status['packed_text_indexes'].append(curr)
+                curr += 1
+                curr_split_len += 1
+
+                # preprocess image
+                vit_tokens = patchify(image_tensor, self.data_config.vit_patch_size)
+                num_img_tokens = vit_tokens.shape[0]
+                sequence_status['packed_vit_token_indexes'].extend(range(curr, curr + num_img_tokens))
+                curr += num_img_tokens
+                curr_split_len += num_img_tokens
+
+                sequence_status['packed_vit_tokens'].append(vit_tokens)
+                sequence_status['vit_token_seqlens'].append(num_img_tokens)
+                sequence_status['packed_vit_position_ids'].append(
+                    self.get_flattened_position_ids(
+                        image_tensor.size(1), image_tensor.size(2),
+                        self.data_config.vit_patch_size,
+                        max_num_patches_per_side=self.data_config.max_num_patch_per_side
+                    )
+                )
+
+                # add a <|endofimage|> token
+                sequence_status['packed_text_ids'].append(self.end_of_image)
+                sequence_status['packed_text_indexes'].append(curr)
+                if item['special_token_loss'] == 1:  # <|endofimage|> may have loss
+                    sequence_status['ce_loss_indexes'].append(curr)
+                    sequence_status['ce_loss_weights'].append(1.0)
+                    sequence_status['packed_label_ids'].append(item['special_token_label'])
+                curr += 1
+                curr_split_len += 1
+
+                # update sequence status
+                attn_modes.append("full")
+                sequence_status['packed_position_ids'].extend([curr_rope_id] * curr_split_len)
         | 
| 395 | 
            +
                            curr_rope_id += 1
         | 
| 396 | 
            +
             | 
| 397 | 
            +
                        elif item['type'] == 'vae_image':
         | 
| 398 | 
            +
                            image_tensor = image_tensor_list.pop(0)
         | 
| 399 | 
            +
                            if item['enable_cfg'] == 1 and random.random() < self.data_config.vae_cond_dropout_prob:
         | 
| 400 | 
            +
                                # FIXME fix vae dropout in video2video setting.
         | 
| 401 | 
            +
                                curr_rope_id += 1
         | 
| 402 | 
            +
                                continue
         | 
| 403 | 
            +
             | 
| 404 | 
            +
                            # add a <|startofimage|> token
         | 
| 405 | 
            +
                            sequence_status['packed_text_ids'].append(self.start_of_image)
         | 
| 406 | 
            +
                            sequence_status['packed_text_indexes'].append(curr)
         | 
| 407 | 
            +
                            curr += 1
         | 
| 408 | 
            +
                            curr_split_len += 1
         | 
| 409 | 
            +
             | 
| 410 | 
            +
                            # preprocess image
         | 
| 411 | 
            +
                            sequence_status['vae_image_tensors'].append(image_tensor)
         | 
| 412 | 
            +
                            sequence_status['packed_latent_position_ids'].append(
         | 
| 413 | 
            +
                                self.get_flattened_position_ids(
         | 
| 414 | 
            +
                                    image_tensor.size(1), image_tensor.size(2),
         | 
| 415 | 
            +
                                    self.data_config.vae_image_downsample, 
         | 
| 416 | 
            +
                                    max_num_patches_per_side=self.data_config.max_latent_size
         | 
| 417 | 
            +
                                )
         | 
| 418 | 
            +
                            )
         | 
| 419 | 
            +
                            H, W = image_tensor.shape[1:]
         | 
| 420 | 
            +
                            h = H // self.data_config.vae_image_downsample
         | 
| 421 | 
            +
                            w = W // self.data_config.vae_image_downsample
         | 
| 422 | 
            +
                            sequence_status['vae_latent_shapes'].append((h, w))
         | 
| 423 | 
            +
             | 
| 424 | 
            +
                            num_img_tokens = w * h
         | 
| 425 | 
            +
                            sequence_status['packed_vae_token_indexes'].extend(range(curr, curr + num_img_tokens))
         | 
| 426 | 
            +
                            if item['loss'] == 1:
         | 
| 427 | 
            +
                                sequence_status['mse_loss_indexes'].extend(range(curr, curr + num_img_tokens))
         | 
| 428 | 
            +
                                if split_start:
         | 
| 429 | 
            +
                                    timestep = np.random.randn()
         | 
| 430 | 
            +
                            else:
         | 
| 431 | 
            +
                                timestep = float('-inf')
         | 
| 432 | 
            +
             | 
| 433 | 
            +
                            sequence_status['packed_timesteps'].extend([timestep] * num_img_tokens)
         | 
| 434 | 
            +
                            curr += num_img_tokens
         | 
| 435 | 
            +
                            curr_split_len += num_img_tokens
         | 
| 436 | 
            +
             | 
| 437 | 
            +
                            # add a <|endofimage|> token
         | 
| 438 | 
            +
                            sequence_status['packed_text_ids'].append(self.end_of_image)
         | 
| 439 | 
            +
                            sequence_status['packed_text_indexes'].append(curr)
         | 
| 440 | 
            +
                            # <|endofimage|> may have loss
         | 
| 441 | 
            +
                            if item['special_token_loss'] == 1:
         | 
| 442 | 
            +
                                sequence_status['ce_loss_indexes'].append(curr)
         | 
| 443 | 
            +
                                sequence_status['ce_loss_weights'].append(1.0)
         | 
| 444 | 
            +
                                sequence_status['packed_label_ids'].append(item['special_token_label'])
         | 
| 445 | 
            +
                            curr += 1
         | 
| 446 | 
            +
                            curr_split_len += 1
         | 
| 447 | 
            +
             | 
| 448 | 
            +
                            # update sequence status
         | 
| 449 | 
            +
                            if split_start:
         | 
| 450 | 
            +
                                if item['loss'] == 1 and 'frame_delta' not in item.keys():
         | 
| 451 | 
            +
                                    attn_modes.append("noise")
         | 
| 452 | 
            +
                                else:
         | 
| 453 | 
            +
                                    attn_modes.append("full")
         | 
| 454 | 
            +
                            sequence_status['packed_position_ids'].extend([curr_rope_id] * (num_img_tokens + 2))
         | 
| 455 | 
            +
                            if 'frame_delta' in item.keys():
         | 
| 456 | 
            +
                                curr_rope_id += item['frame_delta']
         | 
| 457 | 
            +
                            elif item['loss'] == 0:
         | 
| 458 | 
            +
                                curr_rope_id += 1
         | 
| 459 | 
            +
             | 
| 460 | 
            +
                        if item.get('split_end', True):
         | 
| 461 | 
            +
                            split_lens.append(curr_split_len)
         | 
| 462 | 
            +
                            sample_lens += curr_split_len
         | 
| 463 | 
            +
             | 
| 464 | 
            +
                    sequence_status['curr'] = curr
         | 
| 465 | 
            +
                    sequence_status['sample_lens'].append(sample_lens)
         | 
| 466 | 
            +
                    # prepare attention mask
         | 
| 467 | 
            +
                    if not self.use_flex:
         | 
| 468 | 
            +
                        sequence_status['nested_attention_masks'].append(
         | 
| 469 | 
            +
                            prepare_attention_mask_per_sample(split_lens, attn_modes)
         | 
| 470 | 
            +
                        )
         | 
| 471 | 
            +
                    else:
         | 
| 472 | 
            +
                        sequence_status['split_lens'].extend(split_lens)
         | 
| 473 | 
            +
                        sequence_status['attn_modes'].extend(attn_modes)
         | 
| 474 | 
            +
             | 
| 475 | 
            +
                    return sequence_status
         | 
| 476 | 
            +
             | 
| 477 | 
            +
             | 
| 478 | 
            +
            class SimpleCustomBatch:
         | 
| 479 | 
            +
                def __init__(self, batch):
         | 
| 480 | 
            +
                    data = batch[0]
         | 
| 481 | 
            +
                    self.batch_data_indexes = data['batch_data_indexes']
         | 
| 482 | 
            +
                    self.sequence_length = data["sequence_length"]
         | 
| 483 | 
            +
                    self.sample_lens = data["sample_lens"]
         | 
| 484 | 
            +
                    self.packed_text_ids = data["packed_text_ids"]
         | 
| 485 | 
            +
                    self.packed_text_indexes = data["packed_text_indexes"]
         | 
| 486 | 
            +
                    self.packed_position_ids = data["packed_position_ids"]
         | 
| 487 | 
            +
             | 
| 488 | 
            +
                    self.use_flex = "nested_attention_masks" not in data.keys()
         | 
| 489 | 
            +
             | 
| 490 | 
            +
                    if self.use_flex:
         | 
| 491 | 
            +
                        self.split_lens = data["split_lens"]
         | 
| 492 | 
            +
                        self.attn_modes = data["attn_modes"]
         | 
| 493 | 
            +
                    else:
         | 
| 494 | 
            +
                        self.nested_attention_masks = data["nested_attention_masks"]
         | 
| 495 | 
            +
             | 
| 496 | 
            +
                    if "padded_images" in data.keys():
         | 
| 497 | 
            +
                        self.padded_images = data["padded_images"]
         | 
| 498 | 
            +
                        self.patchified_vae_latent_shapes = data["patchified_vae_latent_shapes"]
         | 
| 499 | 
            +
                        self.packed_latent_position_ids = data["packed_latent_position_ids"]
         | 
| 500 | 
            +
                        self.packed_vae_token_indexes = data["packed_vae_token_indexes"]
         | 
| 501 | 
            +
             | 
| 502 | 
            +
                    if "packed_vit_tokens" in data.keys():
         | 
| 503 | 
            +
                        self.packed_vit_tokens = data["packed_vit_tokens"]
         | 
| 504 | 
            +
                        self.packed_vit_position_ids = data["packed_vit_position_ids"]
         | 
| 505 | 
            +
                        self.packed_vit_token_indexes = data["packed_vit_token_indexes"]
         | 
| 506 | 
            +
                        self.vit_token_seqlens = data["vit_token_seqlens"]
         | 
| 507 | 
            +
             | 
| 508 | 
            +
                    if "packed_timesteps" in data.keys():
         | 
| 509 | 
            +
                        self.packed_timesteps = data["packed_timesteps"]
         | 
| 510 | 
            +
                        self.mse_loss_indexes = data["mse_loss_indexes"]
         | 
| 511 | 
            +
             | 
| 512 | 
            +
                    if "packed_label_ids" in data.keys():
         | 
| 513 | 
            +
                        self.packed_label_ids = data["packed_label_ids"]
         | 
| 514 | 
            +
                        self.ce_loss_indexes = data["ce_loss_indexes"]
         | 
| 515 | 
            +
                        self.ce_loss_weights = data["ce_loss_weights"]
         | 
| 516 | 
            +
             | 
| 517 | 
            +
                def pin_memory(self):
         | 
| 518 | 
            +
                    self.packed_text_ids = self.packed_text_ids.pin_memory()
         | 
| 519 | 
            +
                    self.packed_text_indexes = self.packed_text_indexes.pin_memory()
         | 
| 520 | 
            +
                    self.packed_position_ids = self.packed_position_ids.pin_memory()
         | 
| 521 | 
            +
             | 
| 522 | 
            +
                    if not self.use_flex:
         | 
| 523 | 
            +
                        self.nested_attention_masks = [item.pin_memory() for item in self.nested_attention_masks]
         | 
| 524 | 
            +
             | 
| 525 | 
            +
                    if hasattr(self, 'padded_images'):
         | 
| 526 | 
            +
                        self.padded_images = self.padded_images.pin_memory()
         | 
| 527 | 
            +
                        self.packed_vae_token_indexes = self.packed_vae_token_indexes.pin_memory()
         | 
| 528 | 
            +
                        self.packed_latent_position_ids = self.packed_latent_position_ids.pin_memory()
         | 
| 529 | 
            +
             | 
| 530 | 
            +
                    if hasattr(self, 'packed_timesteps'):
         | 
| 531 | 
            +
                        self.packed_timesteps = self.packed_timesteps.pin_memory()
         | 
| 532 | 
            +
                        self.mse_loss_indexes = self.mse_loss_indexes.pin_memory()
         | 
| 533 | 
            +
             | 
| 534 | 
            +
                    if hasattr(self, 'packed_vit_tokens'):
         | 
| 535 | 
            +
                        self.packed_vit_tokens = self.packed_vit_tokens.pin_memory()
         | 
| 536 | 
            +
                        self.packed_vit_position_ids = self.packed_vit_position_ids.pin_memory()
         | 
| 537 | 
            +
                        self.packed_vit_token_indexes = self.packed_vit_token_indexes.pin_memory()
         | 
| 538 | 
            +
                        self.vit_token_seqlens = self.vit_token_seqlens.pin_memory()
         | 
| 539 | 
            +
             | 
| 540 | 
            +
                    if hasattr(self, 'packed_label_ids'):
         | 
| 541 | 
            +
                        self.packed_label_ids = self.packed_label_ids.pin_memory()
         | 
| 542 | 
            +
                        self.ce_loss_indexes = self.ce_loss_indexes.pin_memory()
         | 
| 543 | 
            +
                        self.ce_loss_weights = self.ce_loss_weights.pin_memory()
         | 
| 544 | 
            +
             | 
| 545 | 
            +
                    return self
         | 
| 546 | 
            +
             | 
| 547 | 
            +
                def cuda(self, device):
         | 
| 548 | 
            +
                    self.packed_text_ids = self.packed_text_ids.to(device)
         | 
| 549 | 
            +
                    self.packed_text_indexes = self.packed_text_indexes.to(device)
         | 
| 550 | 
            +
                    self.packed_position_ids = self.packed_position_ids.to(device)
         | 
| 551 | 
            +
             | 
| 552 | 
            +
                    if not self.use_flex:
         | 
| 553 | 
            +
                        self.nested_attention_masks = [item.to(device) for item in self.nested_attention_masks]
         | 
| 554 | 
            +
             | 
| 555 | 
            +
                    if hasattr(self, 'padded_images'):
         | 
| 556 | 
            +
                        self.padded_images = self.padded_images.to(device)
         | 
| 557 | 
            +
                        self.packed_vae_token_indexes = self.packed_vae_token_indexes.to(device)
         | 
| 558 | 
            +
                        self.packed_latent_position_ids = self.packed_latent_position_ids.to(device)
         | 
| 559 | 
            +
             | 
| 560 | 
            +
                    if hasattr(self, 'packed_timesteps'):
         | 
| 561 | 
            +
                        self.packed_timesteps = self.packed_timesteps.to(device)
         | 
| 562 | 
            +
                        self.mse_loss_indexes = self.mse_loss_indexes.to(device)
         | 
| 563 | 
            +
             | 
| 564 | 
            +
                    if hasattr(self, 'packed_vit_tokens'):
         | 
| 565 | 
            +
                        self.packed_vit_tokens = self.packed_vit_tokens.to(device)
         | 
| 566 | 
            +
                        self.packed_vit_position_ids = self.packed_vit_position_ids.to(device)
         | 
| 567 | 
            +
                        self.packed_vit_token_indexes = self.packed_vit_token_indexes.to(device)
         | 
| 568 | 
            +
                        self.vit_token_seqlens = self.vit_token_seqlens.to(device)
         | 
| 569 | 
            +
             | 
| 570 | 
            +
                    if hasattr(self, 'packed_label_ids'):
         | 
| 571 | 
            +
                        self.packed_label_ids = self.packed_label_ids.to(device)
         | 
| 572 | 
            +
                        self.ce_loss_indexes = self.ce_loss_indexes.to(device)
         | 
| 573 | 
            +
                        self.ce_loss_weights = self.ce_loss_weights.to(device)
         | 
| 574 | 
            +
             | 
| 575 | 
            +
                    return self
         | 
| 576 | 
            +
             | 
| 577 | 
            +
                def to_dict(self):
         | 
| 578 | 
            +
                    data = dict(
         | 
| 579 | 
            +
                        sequence_length = self.sequence_length,
         | 
| 580 | 
            +
                        sample_lens = self.sample_lens,
         | 
| 581 | 
            +
                        packed_text_ids = self.packed_text_ids,
         | 
| 582 | 
            +
                        packed_text_indexes = self.packed_text_indexes,
         | 
| 583 | 
            +
                        packed_position_ids = self.packed_position_ids,
         | 
| 584 | 
            +
                        batch_data_indexes = self.batch_data_indexes,
         | 
| 585 | 
            +
                    )
         | 
| 586 | 
            +
             | 
| 587 | 
            +
                    if not self.use_flex:
         | 
| 588 | 
            +
                        data['nested_attention_masks'] = self.nested_attention_masks
         | 
| 589 | 
            +
                    else:
         | 
| 590 | 
            +
                        data['split_lens'] = self.split_lens
         | 
| 591 | 
            +
                        data['attn_modes'] = self.attn_modes
         | 
| 592 | 
            +
             | 
| 593 | 
            +
                    if hasattr(self, 'padded_images'):
         | 
| 594 | 
            +
                        data['padded_images'] = self.padded_images
         | 
| 595 | 
            +
                        data['patchified_vae_latent_shapes'] = self.patchified_vae_latent_shapes
         | 
| 596 | 
            +
                        data['packed_latent_position_ids'] = self.packed_latent_position_ids
         | 
| 597 | 
            +
                        data['packed_vae_token_indexes'] = self.packed_vae_token_indexes
         | 
| 598 | 
            +
             | 
| 599 | 
            +
                    if hasattr(self, 'packed_vit_tokens'):
         | 
| 600 | 
            +
                        data['packed_vit_tokens'] = self.packed_vit_tokens
         | 
| 601 | 
            +
                        data['packed_vit_position_ids'] = self.packed_vit_position_ids
         | 
| 602 | 
            +
                        data['packed_vit_token_indexes'] = self.packed_vit_token_indexes
         | 
| 603 | 
            +
                        data['vit_token_seqlens'] = self.vit_token_seqlens
         | 
| 604 | 
            +
             | 
| 605 | 
            +
                    if hasattr(self, 'packed_timesteps'):
         | 
| 606 | 
            +
                        data['packed_timesteps'] = self.packed_timesteps
         | 
| 607 | 
            +
                        data['mse_loss_indexes'] = self.mse_loss_indexes
         | 
| 608 | 
            +
             | 
| 609 | 
            +
                    if hasattr(self, 'packed_label_ids'):
         | 
| 610 | 
            +
                        data['packed_label_ids'] = self.packed_label_ids
         | 
| 611 | 
            +
                        data['ce_loss_indexes'] = self.ce_loss_indexes
         | 
| 612 | 
            +
                        data['ce_loss_weights'] = self.ce_loss_weights
         | 
| 613 | 
            +
             | 
| 614 | 
            +
                    return data
         | 
| 615 | 
            +
             | 
| 616 | 
            +
             | 
| 617 | 
            +
            def collate_wrapper():
         | 
| 618 | 
            +
                def collate_fn(batch):
         | 
| 619 | 
            +
                    return SimpleCustomBatch(batch)
         | 
| 620 | 
            +
                return collate_fn
         | 
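
Usage note (not part of the uploaded files): SimpleCustomBatch follows PyTorch's documented custom-batch pattern, so a minimal sketch of how it plugs into a DataLoader looks like the following. The packed_dataset variable and the device string are placeholders; batch_size stays at 1 because each element is already a fully packed sequence.

# Illustrative sketch only; packed_dataset is an assumed iterable dataset
# that yields one packed dict per step.
from torch.utils.data import DataLoader

loader = DataLoader(
    packed_dataset,
    batch_size=1,                  # each element is already a packed sequence
    collate_fn=collate_wrapper(),  # wraps the dict into a SimpleCustomBatch
    num_workers=8,
    pin_memory=True,               # triggers SimpleCustomBatch.pin_memory() in the loader
)

for batch in loader:
    batch = batch.cuda('cuda:0')    # move all packed tensors to the GPU
    model_inputs = batch.to_dict()  # plain dict view consumed by the model forward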
    	
data/dataset_info.py
ADDED
@@ -0,0 +1,39 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0

from .interleave_datasets import UnifiedEditIterableDataset
from .t2i_dataset import T2IIterableDataset
from .vlm_dataset import SftJSONLIterableDataset


DATASET_REGISTRY = {
    't2i_pretrain': T2IIterableDataset,
    'vlm_sft': SftJSONLIterableDataset,
    'unified_edit': UnifiedEditIterableDataset,
}


DATASET_INFO = {
    't2i_pretrain': {
        't2i': {
            'data_dir': 'your_data_path/bagel_example/t2i',  # path of the parquet files
            'num_files': 10,  # number of data units to be sharded across all ranks and workers
            'num_total_samples': 1000,  # number of total samples in the dataset
        },
    },
    'unified_edit': {
        'seedxedit_multi': {
            'data_dir': 'your_data_path/bagel_example/editing/seedxedit_multi',
            'num_files': 10,
            'num_total_samples': 1000,
            'parquet_info_path': 'your_data_path/bagel_example/editing/parquet_info/seedxedit_multi_nas.json',  # information of the parquet files
        },
    },
    'vlm_sft': {
        'llava_ov': {
            'data_dir': 'your_data_path/bagel_example/vlm/images',
            'jsonl_path': 'your_data_path/bagel_example/vlm/llava_ov_si.jsonl',
            'num_total_samples': 1000,
        },
    },
}
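
A quick sketch (illustrative, not in the upload) of how a new dataset would be wired in: add a metadata entry to DATASET_INFO under the matching group key, and resolve the dataset class through DATASET_REGISTRY. The 't2i_extra' name and its paths below are placeholders.

# Hypothetical registration; keys mirror the existing 't2i' entry.
DATASET_INFO['t2i_pretrain']['t2i_extra'] = {
    'data_dir': 'your_data_path/bagel_example/t2i_extra',
    'num_files': 10,
    'num_total_samples': 1000,
}

dataset_cls = DATASET_REGISTRY['t2i_pretrain']  # resolves to T2IIterableDataset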
    	
data/distributed_iterable_dataset.py
ADDED
@@ -0,0 +1,58 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0

import random
import torch


class DistributedIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self, dataset_name, local_rank=0, world_size=1, num_workers=8):
        self.dataset_name = dataset_name
        self.local_rank = local_rank
        self.world_size = world_size
        self.num_workers = num_workers
        self.rng = random.Random()
        self.data_paths = None

    def get_data_paths(self, *args, **kwargs):
        raise NotImplementedError

    def set_epoch(self, seed=42):
        if self.data_paths is None:
            return

        if isinstance(self.data_paths[0], tuple):
            data_paths = sorted(self.data_paths, key=lambda x: (x[0], x[1]))
        elif isinstance(self.data_paths[0], str):
            data_paths = sorted(self.data_paths)
        else:
            raise ValueError(f"Unknown data_paths type: {type(self.data_paths[0])}")

        self.rng.seed(seed)
        self.rng.shuffle(data_paths)

        num_files_per_rank = len(data_paths) // self.world_size
        local_start = self.local_rank * num_files_per_rank
        local_end = (self.local_rank + 1) * num_files_per_rank
        self.num_files_per_rank = num_files_per_rank
        self.data_paths_per_rank = data_paths[local_start:local_end]

    def get_data_paths_per_worker(self):
        if self.data_paths is None:
            return None

        info = torch.utils.data.get_worker_info()
        if info is None:
            # Single worker: use all files assigned to the rank
            return self.data_paths_per_rank, 0

        worker_id = info.id
        num_files_per_worker = self.num_files_per_rank // info.num_workers
        start = num_files_per_worker * worker_id
        end = num_files_per_worker * (worker_id + 1)
        data_paths_per_worker = self.data_paths_per_rank[start:end]

        return data_paths_per_worker[::-1], worker_id

    def __iter__(self):
        raise NotImplementedError
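
To make the contract above concrete, here is a minimal hypothetical subclass (not part of the upload): it supplies data_paths via get_data_paths, calls set_epoch once paths exist, and consumes get_data_paths_per_worker inside __iter__, looping forever the way the other iterable datasets in this upload do.

# Assumed example; ToyJsonlIterableDataset and its paths are placeholders.
import json


class ToyJsonlIterableDataset(DistributedIterableDataset):
    def __init__(self, dataset_name, jsonl_paths, local_rank=0, world_size=1, num_workers=8):
        super().__init__(dataset_name, local_rank, world_size, num_workers)
        self.data_paths = self.get_data_paths(jsonl_paths)
        self.set_epoch()

    def get_data_paths(self, jsonl_paths):
        return list(jsonl_paths)  # one shard per .jsonl file

    def __iter__(self):
        paths_per_worker, worker_id = self.get_data_paths_per_worker()
        while True:  # repeat indefinitely; the trainer decides when to stop
            for path in paths_per_worker:
                with open(path) as f:
                    for line in f:
                        yield json.loads(line)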
    	
data/interleave_datasets/__init__.py
ADDED
@@ -0,0 +1,5 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0

from .edit_dataset import UnifiedEditIterableDataset
    	
data/interleave_datasets/edit_dataset.py
ADDED
@@ -0,0 +1,72 @@
|  | |
| 1 | 
            +
            # Copyright 2025 Bytedance Ltd. and/or its affiliates.
         | 
| 2 | 
            +
            # SPDX-License-Identifier: Apache-2.0
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            import io
         | 
| 5 | 
            +
            import random
         | 
| 6 | 
            +
            from PIL import Image, ImageFile, PngImagePlugin
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            from .interleave_t2i_dataset import InterleavedBaseIterableDataset, ParquetStandardIterableDataset
         | 
| 9 | 
            +
            from ..data_utils import pil_img2rgb
         | 
| 10 | 
            +
             | 
| 11 | 
            +
             | 
| 12 | 
            +
            Image.MAX_IMAGE_PIXELS = 200000000
         | 
| 13 | 
            +
            ImageFile.LOAD_TRUNCATED_IMAGES = True
         | 
| 14 | 
            +
            MaximumDecompressedSize = 1024
         | 
| 15 | 
            +
            MegaByte = 2 ** 20
         | 
| 16 | 
            +
            PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte
         | 
| 17 | 
            +
             | 
| 18 | 
            +
             | 
| 19 | 
            +
            class UnifiedEditIterableDataset(InterleavedBaseIterableDataset, ParquetStandardIterableDataset):
         | 
| 20 | 
            +
             | 
| 21 | 
            +
                def parse_row(self, row):
         | 
| 22 | 
            +
                    image_num = len(row["image_list"])
         | 
| 23 | 
            +
                    # randomly choose start and end, return [0, 1] when only two images
         | 
| 24 | 
            +
                    start_idx = random.choice(range(image_num - 1))
         | 
| 25 | 
            +
                    max_end = min(start_idx + 3, image_num)
         | 
| 26 | 
            +
                    end_idx = random.choice(range(start_idx + 1, max_end))
         | 
| 27 | 
            +
             | 
| 28 | 
            +
                    data = self._init_data()
         | 
| 29 | 
            +
                    data = self._add_image(
         | 
| 30 | 
            +
                        data, 
         | 
| 31 | 
            +
                        pil_img2rgb(Image.open(io.BytesIO(row["image_list"][start_idx]))),
         | 
| 32 | 
            +
                        need_loss=False, 
         | 
| 33 | 
            +
                        need_vae=True, 
         | 
| 34 | 
            +
                        need_vit=True, 
         | 
| 35 | 
            +
                    )
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                    if end_idx - start_idx > 1 and random.random() < 0.5: # concat multiple insturction
         | 
| 38 | 
            +
                        if end_idx == image_num - 1:
         | 
| 39 | 
            +
                            end_idx -= 1
         | 
| 40 | 
            +
             | 
| 41 | 
            +
                        instruction = ""
         | 
| 42 | 
            +
                        for idx in range(start_idx + 1, end_idx + 1):
         | 
| 43 | 
            +
                            instruction += random.choice(row["instruction_list"][idx-1]) + ". "
         | 
| 44 | 
            +
                        data = self._add_text(data, instruction.rstrip(), need_loss=False)
         | 
| 45 | 
            +
                        data = self._add_image(
         | 
| 46 | 
            +
                            data, 
         | 
| 47 | 
            +
                            pil_img2rgb(Image.open(io.BytesIO(row["image_list"][end_idx]))),
         | 
| 48 | 
            +
                            need_loss=True, 
         | 
| 49 | 
            +
                            need_vae=False, 
         | 
| 50 | 
            +
                            need_vit=False,
         | 
| 51 | 
            +
                        )
         | 
| 52 | 
            +
                    else:
         | 
| 53 | 
            +
                        for idx in range(start_idx + 1, end_idx + 1):
         | 
| 54 | 
            +
                            instruction = random.choice(row["instruction_list"][idx-1])
         | 
| 55 | 
            +
                            data = self._add_text(data, instruction, need_loss=False)
         | 
| 56 | 
            +
                            if idx != end_idx:
         | 
| 57 | 
            +
                                data = self._add_image(
         | 
| 58 | 
            +
                                    data, 
         | 
| 59 | 
            +
                                    pil_img2rgb(Image.open(io.BytesIO(row["image_list"][idx]))),
         | 
| 60 | 
            +
                                    need_loss=True, 
         | 
| 61 | 
            +
                                    need_vae=True, 
         | 
| 62 | 
            +
                                    need_vit=True,
         | 
| 63 | 
            +
                                )
         | 
| 64 | 
            +
                            else:
         | 
| 65 | 
            +
                                data = self._add_image(
         | 
| 66 | 
            +
                                    data, 
         | 
| 67 | 
            +
                                    pil_img2rgb(Image.open(io.BytesIO(row["image_list"][idx]))),
         | 
| 68 | 
            +
                                    need_loss=True, 
         | 
| 69 | 
            +
                                    need_vae=False, 
         | 
| 70 | 
            +
                                    need_vit=False,
         | 
| 71 | 
            +
                                )
         | 
| 72 | 
            +
                    return data
         | 
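
For intuition (not from the upload): the window sampled at the top of parse_row spans at most two edit steps, so every training sample conditions on one image and supervises one or two subsequent edits. A tiny self-check under an assumed image_num illustrates the bounds:

# Illustrative check of the index-sampling window; image_num is a placeholder.
import random

image_num = 4
start_idx = random.choice(range(image_num - 1))                             # 0..2
end_idx = random.choice(range(start_idx + 1, min(start_idx + 3, image_num)))
assert 1 <= end_idx - start_idx <= 2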
    	
data/interleave_datasets/interleave_t2i_dataset.py
ADDED
@@ -0,0 +1,212 @@
| 1 | 
            +
            # Copyright 2025 Bytedance Ltd. and/or its affiliates.
         | 
| 2 | 
            +
            # SPDX-License-Identifier: Apache-2.0
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            import pyarrow.parquet as pq
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            from ..distributed_iterable_dataset import DistributedIterableDataset
         | 
| 7 | 
            +
            from ..parquet_utils import get_parquet_data_paths, init_arrow_pf_fs
         | 
| 8 | 
            +
             | 
| 9 | 
            +
             | 
| 10 | 
            +
            class InterleavedBaseIterableDataset(DistributedIterableDataset):
         | 
| 11 | 
            +
             | 
| 12 | 
            +
                def _init_data(self):
         | 
| 13 | 
            +
                    data = {
         | 
| 14 | 
            +
                        'sequence_plan': [],
         | 
| 15 | 
            +
                        'text_ids_list': [],
         | 
| 16 | 
            +
                        'image_tensor_list': [],
         | 
| 17 | 
            +
                        'num_tokens': 0,
         | 
| 18 | 
            +
                    }
         | 
| 19 | 
            +
                    return data
         | 
| 20 | 
            +
             | 
| 21 | 
            +
                def _add_text(self, data, text, need_loss, enable_cfg=True):
         | 
| 22 | 
            +
                    text_ids = self.tokenizer.encode(text)
         | 
| 23 | 
            +
                    data['num_tokens'] += len(text_ids)
         | 
| 24 | 
            +
                    data['text_ids_list'].append(text_ids)
         | 
| 25 | 
            +
                    data['sequence_plan'].append(
         | 
| 26 | 
            +
                        {
         | 
| 27 | 
            +
                            'type': 'text',
         | 
| 28 | 
            +
                            'enable_cfg': int(enable_cfg),
         | 
| 29 | 
            +
                            'loss': int(need_loss),
         | 
| 30 | 
            +
                            'special_token_loss': 0,
         | 
| 31 | 
            +
                            'special_token_label': None,
         | 
| 32 | 
            +
                        }
         | 
| 33 | 
            +
                    )
         | 
| 34 | 
            +
                    return data
         | 
| 35 | 
            +
             | 
| 36 | 
            +
                def _add_image(self, data, image, need_loss, need_vae, need_vit, enable_cfg=True):
         | 
| 37 | 
            +
                    assert need_loss or need_vae or need_vit
         | 
| 38 | 
            +
             | 
| 39 | 
            +
                    if need_loss:
         | 
| 40 | 
            +
                        data['sequence_plan'].append(
         | 
| 41 | 
            +
                            {
         | 
| 42 | 
            +
                                'type': 'vae_image', 
         | 
| 43 | 
            +
                                'enable_cfg': 0, 
         | 
| 44 | 
            +
                                'loss': 1, 
         | 
| 45 | 
            +
                                'special_token_loss': 0,
         | 
| 46 | 
            +
                                'special_token_label': None,
         | 
| 47 | 
            +
                            }
         | 
| 48 | 
            +
                        )
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                        image_tensor = self.transform(image)
         | 
| 51 | 
            +
                        height, width = image_tensor.shape[1:]
         | 
| 52 | 
            +
                        data['num_tokens'] += width * height // self.transform.stride ** 2
         | 
| 53 | 
            +
                        data['image_tensor_list'].append(image_tensor)
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                    if need_vae:
         | 
| 56 | 
            +
                        data['sequence_plan'].append(
         | 
| 57 | 
            +
                            {
         | 
| 58 | 
            +
                                'type': 'vae_image', 
         | 
| 59 | 
            +
                                'enable_cfg': int(enable_cfg), 
         | 
| 60 | 
            +
                                'loss': 0, 
         | 
| 61 | 
            +
                                'special_token_loss': 0,
         | 
| 62 | 
            +
                                'special_token_label': None,
         | 
| 63 | 
            +
                            }
         | 
| 64 | 
            +
                        )
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                        image_tensor = self.transform(image)
         | 
| 67 | 
            +
                        height, width = image_tensor.shape[1:]
         | 
| 68 | 
            +
                        data['num_tokens'] += width * height // self.transform.stride ** 2
         | 
| 69 | 
            +
                        data['image_tensor_list'].append(image_tensor.clone())
         | 
| 70 | 
            +
             | 
| 71 | 
            +
                    if need_vit:
         | 
| 72 | 
            +
                        data['sequence_plan'].append(
         | 
| 73 | 
            +
                            {
         | 
| 74 | 
            +
                                'type': 'vit_image',
         | 
| 75 | 
            +
                                'enable_cfg': int(enable_cfg), 
         | 
| 76 | 
            +
                                'loss': 0,
         | 
| 77 | 
            +
                                'special_token_loss': 0,
         | 
| 78 | 
            +
                                'special_token_label': None,
         | 
| 79 | 
            +
                            },
         | 
| 80 | 
            +
                        )
         | 
| 81 | 
            +
                        vit_image_tensor = self.vit_transform(image)
         | 
| 82 | 
            +
                        height, width = vit_image_tensor.shape[1:]
         | 
| 83 | 
            +
                        data['num_tokens'] += width * height // self.vit_transform.stride ** 2
         | 
| 84 | 
            +
                        data['image_tensor_list'].append(vit_image_tensor)
         | 
| 85 | 
            +
             | 
| 86 | 
            +
                    return data
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                def _add_video(self, data, frames, frame_indexes, need_loss, need_vae, enable_cfg=True):
         | 
| 89 | 
            +
                    assert int(need_loss) + int(need_vae) == 1
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                    if need_loss:
         | 
| 92 | 
            +
                        for idx, (image, frame_idx) in enumerate(zip(frames, frame_indexes)):
         | 
| 93 | 
            +
                            current_sequence_plan = {
         | 
| 94 | 
            +
                                'type': 'vae_image', 
         | 
| 95 | 
            +
                                'enable_cfg': 0, 
         | 
| 96 | 
            +
                                'loss': 1, 
         | 
| 97 | 
            +
                                'special_token_loss': 0,
         | 
| 98 | 
            +
                                'special_token_label': None,
         | 
| 99 | 
            +
                                'split_start': idx == 0,
         | 
| 100 | 
            +
                                'split_end': idx == len(frames) - 1,
         | 
| 101 | 
            +
                            }
         | 
| 102 | 
            +
                            if idx < len(frame_indexes) - 1:
         | 
| 103 | 
            +
                                current_sequence_plan['frame_delta'] = frame_indexes[idx + 1] - frame_idx
         | 
| 104 | 
            +
                            data['sequence_plan'].append(current_sequence_plan)
         | 
| 105 | 
            +
                            image_tensor = self.transform(image)
         | 
| 106 | 
            +
                            height, width = image_tensor.shape[1:]
         | 
| 107 | 
            +
                            data['image_tensor_list'].append(image_tensor)
         | 
| 108 | 
            +
                            data['num_tokens'] += width * height // self.transform.stride ** 2
         | 
| 109 | 
            +
             | 
| 110 | 
            +
                    elif need_vae:
         | 
| 111 | 
            +
                        for idx, (image, frame_idx) in enumerate(zip(frames, frame_indexes)):
         | 
| 112 | 
            +
                            current_sequence_plan = {
         | 
| 113 | 
            +
                                'type': 'vae_image', 
         | 
| 114 | 
            +
                                'enable_cfg': int(enable_cfg), 
         | 
| 115 | 
            +
                                'loss': 0, 
         | 
| 116 | 
            +
                                'special_token_loss': 0,
         | 
| 117 | 
            +
                                'special_token_label': None,
         | 
| 118 | 
            +
                                'split_start': idx == 0,
         | 
| 119 | 
            +
                                'split_end': idx == len(frames) - 1,
         | 
| 120 | 
            +
                            }
         | 
| 121 | 
            +
                            if idx < len(frame_indexes) - 1:
         | 
| 122 | 
            +
                                current_sequence_plan['frame_delta'] = frame_indexes[idx + 1] - frame_idx
         | 
| 123 | 
            +
                            data['sequence_plan'].append(current_sequence_plan)
         | 
| 124 | 
            +
                            image_tensor = self.transform(image)
         | 
| 125 | 
            +
                            height, width = image_tensor.shape[1:]
         | 
| 126 | 
            +
                            data['image_tensor_list'].append(image_tensor)
         | 
| 127 | 
            +
                            data['num_tokens'] += width * height // self.transform.stride ** 2
         | 
| 128 | 
            +
             | 
| 129 | 
            +
                    return data
         | 
| 130 | 
            +
             | 
| 131 | 
            +
             | 
| 132 | 
            +
            class ParquetStandardIterableDataset(DistributedIterableDataset):
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                def __init__(
         | 
| 135 | 
            +
                    self, dataset_name, transform, tokenizer, vit_transform, 
         | 
| 136 | 
            +
                    data_dir_list, num_used_data, parquet_info,
         | 
| 137 | 
            +
                    local_rank=0, world_size=1, num_workers=8, data_status=None,
         | 
| 138 | 
            +
                ):
         | 
| 139 | 
            +
                    """
         | 
| 140 | 
            +
                    data_dir_list: list of data directories contains parquet files
         | 
| 141 | 
            +
                    num_used_data: list of number of sampled data paths for each data directory
         | 
| 142 | 
            +
                    vit_transform: input transform for vit model.
         | 
| 143 | 
            +
                    """
         | 
| 144 | 
            +
        super().__init__(dataset_name, local_rank, world_size, num_workers)
        self.transform = transform
        self.vit_transform = vit_transform
        self.tokenizer = tokenizer
        self.data_status = data_status
        self.data_paths = self.get_data_paths(data_dir_list, num_used_data, parquet_info)
        self.set_epoch()

    def get_data_paths(self, data_dir_list, num_used_data, parquet_info):
        row_groups = []
        for data_dir, num_data_path in zip(data_dir_list, num_used_data):
            data_paths = get_parquet_data_paths([data_dir], [num_data_path])
            for data_path in data_paths:
                if data_path in parquet_info:
                    num_row_groups = parquet_info[data_path]['num_row_groups']
                    for rg_idx in range(num_row_groups):
                        row_groups.append((data_path, rg_idx))
        return row_groups

    def parse_row(self, row):
        raise NotImplementedError

    def __iter__(self):
        file_paths_per_worker, worker_id = self.get_data_paths_per_worker()
        if self.data_status is not None:
            global_row_group_start_id = self.data_status[worker_id][0]
            row_start_id = self.data_status[worker_id][1] + 1
        else:
            global_row_group_start_id = 0
            row_start_id = 0

        print(
            f"rank-{self.local_rank} worker-{worker_id} dataset-{self.dataset_name}: "
            f"resuming data at global_rg#{global_row_group_start_id}, row#{row_start_id}"
        )

        while True:
            file_paths_per_worker_ = file_paths_per_worker[global_row_group_start_id:]
            for global_row_group_idx, (parquet_file_path, row_group_id) in enumerate(
                file_paths_per_worker_, start=global_row_group_start_id
            ):
                fs = init_arrow_pf_fs(parquet_file_path)
                with fs.open_input_file(parquet_file_path) as f:
                    try:
                        fr = pq.ParquetFile(f)
                        df = fr.read_row_group(row_group_id).to_pandas()
                        df = df.iloc[row_start_id:]
                    except Exception as e:
                        print(f'Error {e} in rg#{row_group_id}, {parquet_file_path}')
                        continue

                    for row_idx, row in df.iterrows():
                        try:
                            data = self.parse_row(row)
                            if len(data) == 0:
                                continue
                            data['data_indexes'] = {
                                "data_indexes": [global_row_group_idx, row_idx],
                                "worker_id": worker_id,
                                "dataset_name": self.dataset_name,
                            }
                        except Exception as e:
                            print(f'Error {e} in rg#{row_group_id}, {parquet_file_path}')
                            continue
                        yield data

                    row_start_id = 0
            global_row_group_start_id = 0
            print(f"{self.dataset_name} repeat in rank-{self.local_rank} worker-{worker_id}")
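Concrete datasets only need to override `parse_row`; sharding, checkpoint resume, and row-group iteration all come from the machinery above. A minimal hypothetical sketch (the base-class name and the 'image'/'caption' column names are assumptions for illustration, not from the repo):

import io
from PIL import Image
from .data_utils import pil_img2rgb

# 'ParquetRowGroupDataset' stands in for the base class shown above.
class MyCaptionedImageDataset(ParquetRowGroupDataset):
    def parse_row(self, row):
        image = pil_img2rgb(Image.open(io.BytesIO(row['image'])))  # assumed column
        text_ids = self.tokenizer.encode(row['caption'])           # assumed column
        if len(text_ids) == 0:
            return {}  # empty dict -> __iter__ skips this row
        return {
            'image_tensor_list': [self.transform(image)],
            'text_ids_list': [text_ids],
        }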
    	
data/parquet_utils.py
ADDED
@@ -0,0 +1,90 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0


import os
import subprocess
import logging

import pyarrow.fs as pf
import torch.distributed as dist

logger = logging.getLogger(__name__)


def get_parquet_data_paths(data_dir_list, num_sampled_data_paths, rank=0, world_size=1):
    num_data_dirs = len(data_dir_list)
    if world_size > 1:
        chunk_size = (num_data_dirs + world_size - 1) // world_size
        start_idx = rank * chunk_size
        end_idx = min(start_idx + chunk_size, num_data_dirs)
        local_data_dir_list = data_dir_list[start_idx:end_idx]
        local_num_sampled_data_paths = num_sampled_data_paths[start_idx:end_idx]
    else:
        local_data_dir_list = data_dir_list
        local_num_sampled_data_paths = num_sampled_data_paths

    local_data_paths = []
    for data_dir, num_data_path in zip(local_data_dir_list, local_num_sampled_data_paths):
        if data_dir.startswith("hdfs://"):
            files = hdfs_ls_cmd(data_dir)
            data_paths_per_dir = [
                file for file in files if file.endswith(".parquet")
            ]
        else:
            files = os.listdir(data_dir)
            data_paths_per_dir = [
                os.path.join(data_dir, name)
                for name in files
                if name.endswith(".parquet")
            ]
        repeat = num_data_path // len(data_paths_per_dir)
        data_paths_per_dir = data_paths_per_dir * (repeat + 1)
        local_data_paths.extend(data_paths_per_dir[:num_data_path])

    if world_size > 1:
        gather_list = [None] * world_size
        dist.all_gather_object(gather_list, local_data_paths)

        combined_chunks = []
        for chunk_list in gather_list:
            if chunk_list is not None:
                combined_chunks.extend(chunk_list)
    else:
        combined_chunks = local_data_paths

    return combined_chunks


# NOTE: customize this function for your cluster
def get_hdfs_host():
    return "hdfs://xxx"


# NOTE: customize this function for your cluster
def get_hdfs_block_size():
    return 134217728  # 128 MB


# NOTE: customize this function for your cluster
def get_hdfs_extra_conf():
    return None


def init_arrow_pf_fs(parquet_file_path):
    if parquet_file_path.startswith("hdfs://"):
        fs = pf.HadoopFileSystem(
            host=get_hdfs_host(),
            port=0,
            buffer_size=get_hdfs_block_size(),
            extra_conf=get_hdfs_extra_conf(),
        )
    else:
        fs = pf.LocalFileSystem()
    return fs


def hdfs_ls_cmd(dir):
    # The HDFS CLI subcommand is '-ls' (the bare 'ls' form is not accepted).
    result = subprocess.run(["hdfs", "dfs", "-ls", dir], capture_output=True, text=True).stdout
    return ['hdfs://' + i.split('hdfs://')[-1].strip() for i in result.split('\n') if 'hdfs://' in i]
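A quick, hedged sketch of how these helpers compose (the directory path is a placeholder): enumerate parquet files, open one through the appropriate filesystem, and read a single row group at a time.

import pyarrow.parquet as pq

paths = get_parquet_data_paths(["/data/t2i"], [10])  # repeats/truncates the listing to 10 paths
fs = init_arrow_pf_fs(paths[0])                      # LocalFileSystem or HadoopFileSystem
with fs.open_input_file(paths[0]) as f:
    fr = pq.ParquetFile(f)
    df = fr.read_row_group(0).to_pandas()            # one row group at a time keeps memory bounded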
    	
data/t2i_dataset.py
ADDED
@@ -0,0 +1,128 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0

import io
import json
import pyarrow.parquet as pq
import random
from PIL import Image

from .data_utils import pil_img2rgb
from .distributed_iterable_dataset import DistributedIterableDataset
from .parquet_utils import get_parquet_data_paths, init_arrow_pf_fs

Image.MAX_IMAGE_PIXELS = 20_000_000


class T2IIterableDataset(DistributedIterableDataset):
    def __init__(
        self, dataset_name, transform, tokenizer, data_dir_list, num_used_data,
        local_rank=0, world_size=1, num_workers=8, data_status=None,
    ):
        """
        data_dir_list: list of data directories containing parquet files
        num_used_data: list of numbers of sampled data paths, one per data directory
        """
        super().__init__(dataset_name, local_rank, world_size, num_workers)
        self.transform = transform
        self.tokenizer = tokenizer
        self.data_status = data_status
        self.data_paths = self.get_data_paths(data_dir_list, num_used_data)
        self.set_epoch()

    def get_data_paths(self, data_dir_list, num_used_data):
        return get_parquet_data_paths(data_dir_list, num_used_data)

    def __iter__(self):
        data_paths_per_worker, worker_id = self.get_data_paths_per_worker()
        if self.data_status is not None:
            parquet_start_id = self.data_status[worker_id][0]
            row_group_start_id = self.data_status[worker_id][1]
            row_start_id = self.data_status[worker_id][2] + 1
        else:
            parquet_start_id = 0
            row_group_start_id = 0
            row_start_id = 0
        transform_stride = self.transform.stride

        print(
            f"rank-{self.local_rank} worker-{worker_id} dataset-{self.dataset_name}: "
            f"resuming data at parquet#{parquet_start_id}, rg#{row_group_start_id}, row#{row_start_id}"
        )

        while True:
            data_paths_per_worker_ = data_paths_per_worker[parquet_start_id:]
            for parquet_idx, parquet_file_path in enumerate(data_paths_per_worker_, start=parquet_start_id):
                fs = init_arrow_pf_fs(parquet_file_path)
                with fs.open_input_file(parquet_file_path) as f:
                    fr = pq.ParquetFile(f)
                    row_group_ids = list(range(fr.num_row_groups))
                    row_group_ids_ = row_group_ids[row_group_start_id:]

                    for row_group_id in row_group_ids_:
                        df = fr.read_row_group(row_group_id).to_pandas()
                        df = df.iloc[row_start_id:]

                        for row_idx, row in df.iterrows():
                            num_tokens = 0
                            try:
                                image_byte = row['image']
                                image = pil_img2rgb(Image.open(io.BytesIO(image_byte)))
                            except Exception as e:
                                print(f'Error: {e} in rg#{row_group_id}, {parquet_file_path}')
                                continue
                            image_tensor = self.transform(image)
                            height, width = image_tensor.shape[1:]
                            num_tokens += width * height // transform_stride ** 2

                            try:
                                caption_dict = row['captions']
                                caption_dict = json.loads(caption_dict)
                            except Exception as e:
                                print(f'Error: {e} in rg#{row_group_id}, {parquet_file_path}')
                                continue

                            caps_token = [self.tokenizer.encode(v) for _, v in caption_dict.items()]
                            if len(caps_token) == 0:
                                print(f'no caption in rg#{row_group_id}, {parquet_file_path}')
                                caption_token = self.tokenizer.encode(' ')
                            else:
                                caption_token = random.choice(caps_token)

                            sequence_plan, text_ids_list = [], []
                            text_ids = caption_token
                            num_tokens += len(caption_token)
                            text_ids_list.append(text_ids)
                            sequence_plan.append({
                                'type': 'text',
                                'enable_cfg': 1,
                                'loss': 0,
                                'special_token_loss': 0,
                                'special_token_label': None,
                            })

                            sequence_plan.append({
                                'type': 'vae_image',
                                'enable_cfg': 0,
                                'loss': 1,
                                'special_token_loss': 0,
                                'special_token_label': None,
                            })

                            sample = dict(
                                image_tensor_list=[image_tensor],
                                text_ids_list=text_ids_list,
                                num_tokens=num_tokens,
                                sequence_plan=sequence_plan,
                                data_indexes={
                                    "data_indexes": [parquet_idx, row_group_id, row_idx],
                                    "worker_id": worker_id,
                                    "dataset_name": self.dataset_name,
                                }
                            )
                            yield sample

                        row_start_id = 0
                    row_group_start_id = 0
            parquet_start_id = 0
            print(f"{self.dataset_name} repeat in rank-{self.local_rank} worker-{worker_id}")
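To make the `num_tokens` bookkeeping concrete: with an image stride of 16, a 1024x688 image contributes 1024 * 688 / 16^2 image tokens, plus the caption length. A small check mirroring the arithmetic above (values are illustrative):

# Worked example of the token accounting (values illustrative).
height, width = 688, 1024           # image_tensor.shape[1:] after the transform
transform_stride = 16               # self.transform.stride
image_tokens = width * height // transform_stride ** 2
assert image_tokens == 2752         # 1024 * 688 / 256
# total num_tokens = image_tokens + len(caption_token)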
    	
data/transforms.py
ADDED
@@ -0,0 +1,287 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0

import random
from PIL import Image

import cv2
import numpy as np
import torch
from torchvision import transforms
from torchvision.transforms import functional as F
from torchvision.transforms import InterpolationMode


class MaxLongEdgeMinShortEdgeResize(torch.nn.Module):
    """Resize the input image so that its longest side and shortest side are within a specified range,
    ensuring that both sides are divisible by a specified stride.

    Args:
        max_size (int): Maximum size for the longest edge of the image.
        min_size (int): Minimum size for the shortest edge of the image.
        stride (int): Value by which the height and width of the image must be divisible.
        max_pixels (int): Maximum number of pixels for the full image.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
            ``InterpolationMode.BILINEAR``, and ``InterpolationMode.BICUBIC`` are supported.
            The corresponding Pillow integer constants, e.g., ``PIL.Image.BILINEAR``, are also accepted.
        antialias (bool, optional): Whether to apply antialiasing (default is True).
    """

    def __init__(
        self,
        max_size: int,
        min_size: int,
        stride: int,
        max_pixels: int,
        interpolation=InterpolationMode.BICUBIC,
        antialias=True
    ):
        super().__init__()
        self.max_size = max_size
        self.min_size = min_size
        self.stride = stride
        self.max_pixels = max_pixels
        self.interpolation = interpolation
        self.antialias = antialias

    def _make_divisible(self, value, stride):
        """Ensure the value is divisible by the stride."""
        return max(stride, int(round(value / stride) * stride))

    def _apply_scale(self, width, height, scale):
        new_width = round(width * scale)
        new_height = round(height * scale)
        new_width = self._make_divisible(new_width, self.stride)
        new_height = self._make_divisible(new_height, self.stride)
        return new_width, new_height

    def forward(self, img, img_num=1):
        """
        Args:
            img (PIL Image): Image to be resized.
            img_num (int): Number of images in the sample; the max_pixels budget is divided by it.
        Returns:
            PIL Image or Tensor: Rescaled image with divisible dimensions.
        """
        if isinstance(img, torch.Tensor):
            height, width = img.shape[-2:]
        else:
            width, height = img.size

        scale = min(self.max_size / max(width, height), 1.0)
        scale = max(scale, self.min_size / min(width, height))
        new_width, new_height = self._apply_scale(width, height, scale)

        # Ensure the number of pixels does not exceed max_pixels
        if new_width * new_height > self.max_pixels / img_num:
            scale = self.max_pixels / img_num / (new_width * new_height)
            new_width, new_height = self._apply_scale(new_width, new_height, scale)

        # Ensure the longest edge does not exceed max_size
        if max(new_width, new_height) > self.max_size:
            scale = self.max_size / max(new_width, new_height)
            new_width, new_height = self._apply_scale(new_width, new_height, scale)

        return F.resize(img, (new_height, new_width), self.interpolation, antialias=self.antialias)
             | 
| 90 | 
            +
            class ImageTransform:
         | 
| 91 | 
            +
                def __init__(
         | 
| 92 | 
            +
                    self, 
         | 
| 93 | 
            +
                    max_image_size, 
         | 
| 94 | 
            +
                    min_image_size, 
         | 
| 95 | 
            +
                    image_stride, 
         | 
| 96 | 
            +
                    max_pixels=14*14*9*1024,
         | 
| 97 | 
            +
                    image_mean=[0.5, 0.5, 0.5], 
         | 
| 98 | 
            +
                    image_std=[0.5, 0.5, 0.5]
         | 
| 99 | 
            +
                ):
         | 
| 100 | 
            +
                    self.stride = image_stride
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                    self.resize_transform = MaxLongEdgeMinShortEdgeResize(
         | 
| 103 | 
            +
                        max_size=max_image_size, 
         | 
| 104 | 
            +
                        min_size=min_image_size, 
         | 
| 105 | 
            +
                        stride=image_stride,
         | 
| 106 | 
            +
                        max_pixels=max_pixels,
         | 
| 107 | 
            +
                    )
         | 
| 108 | 
            +
                    self.to_tensor_transform = transforms.ToTensor()
         | 
| 109 | 
            +
                    self.normalize_transform = transforms.Normalize(mean=image_mean, std=image_std, inplace=True)
         | 
| 110 | 
            +
             | 
| 111 | 
            +
                def __call__(self, img, img_num=1):
         | 
| 112 | 
            +
                    img = self.resize_transform(img, img_num=img_num)
         | 
| 113 | 
            +
                    img = self.to_tensor_transform(img)
         | 
| 114 | 
            +
                    img = self.normalize_transform(img)
         | 
| 115 | 
            +
                    return img
         | 
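A minimal usage sketch, with illustrative sizes (pick values matching your config):

from PIL import Image

transform = ImageTransform(max_image_size=1024, min_image_size=512, image_stride=16)
img = Image.new('RGB', (800, 600))   # placeholder image
tensor = transform(img)              # (3, H, W), H and W divisible by 16, values in [-1, 1]
# Multi-image samples split the pixel budget across images:
# tensor = transform(img, img_num=2)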
def decolorization(image):
    gray_image = image.convert('L')
    return Image.merge(image.mode, [gray_image] * 3) if image.mode in ('RGB', 'L') else gray_image


def downscale(image, scale_factor):
    new_width = int(round(image.width * scale_factor))
    new_height = int(round(image.height * scale_factor))
    new_width = max(1, new_width)
    new_height = max(1, new_height)
    return image.resize((new_width, new_height), resample=Image.BICUBIC)


def crop(image, crop_factors):
    target_h, target_w = crop_factors
    img_w, img_h = image.size

    if target_h > img_h or target_w > img_w:
        raise ValueError("Crop size exceeds image dimensions")

    x = random.randint(0, img_w - target_w)
    y = random.randint(0, img_h - target_h)

    return image.crop((x, y, x + target_w, y + target_h)), [[x, y], [x + target_w, y + target_h]]


def motion_blur_opencv(image, kernel_size=15, angle=0):
    # Linear kernel
    kernel = np.zeros((kernel_size, kernel_size), dtype=np.float32)
    kernel[kernel_size // 2, :] = np.ones(kernel_size, dtype=np.float32)

    # Rotate the kernel
    center = (kernel_size / 2 - 0.5, kernel_size / 2 - 0.5)
    M = cv2.getRotationMatrix2D(center, angle, 1)
    rotated_kernel = cv2.warpAffine(kernel, M, (kernel_size, kernel_size))

    # Normalize the kernel
    rotated_kernel /= rotated_kernel.sum() if rotated_kernel.sum() != 0 else 1

    img = np.array(image)
    if img.ndim == 2:
        blurred = cv2.filter2D(img, -1, rotated_kernel, borderType=cv2.BORDER_REFLECT)
    else:
        # For color images, convolve each channel independently
        blurred = np.zeros_like(img)
        for c in range(img.shape[2]):
            blurred[..., c] = cv2.filter2D(img[..., c], -1, rotated_kernel, borderType=cv2.BORDER_REFLECT)

    return Image.fromarray(blurred.astype(np.uint8))


def shuffle_patch(image, num_splits, gap_size=2):
    """Split the image into patches (sizes need not divide evenly), shuffle them,
    and reassemble with a gap between patches."""
    h_splits, w_splits = num_splits
    img_w, img_h = image.size

    base_patch_h = img_h // h_splits
    patch_heights = [base_patch_h] * (h_splits - 1)
    patch_heights.append(img_h - sum(patch_heights))

    base_patch_w = img_w // w_splits
    patch_widths = [base_patch_w] * (w_splits - 1)
    patch_widths.append(img_w - sum(patch_widths))

    patches = []
    current_y = 0
    for i in range(h_splits):
        current_x = 0
        patch_h = patch_heights[i]
        for j in range(w_splits):
            patch_w = patch_widths[j]
            patch = image.crop((current_x, current_y, current_x + patch_w, current_y + patch_h))
            patches.append(patch)
            current_x += patch_w
        current_y += patch_h

    random.shuffle(patches)

    total_width = sum(patch_widths) + (w_splits - 1) * gap_size
    total_height = sum(patch_heights) + (h_splits - 1) * gap_size
    new_image = Image.new(image.mode, (total_width, total_height), color=(255, 255, 255))

    current_y = 0  # starting Y coordinate of the current row
    patch_idx = 0  # index of the patch being placed
    for i in range(h_splits):
        current_x = 0  # starting X coordinate of the current column
        patch_h = patch_heights[i]  # height of patches in the current row
        for j in range(w_splits):
            # Take the next shuffled patch
            patch = patches[patch_idx]
            patch_w = patch_widths[j]  # width of patches in the current column
            # Paste the patch with its top-left corner at (current_x, current_y)
            new_image.paste(patch, (current_x, current_y))
            # Advance X (next patch starts after the current patch width plus the gap)
            current_x += patch_w + gap_size
            patch_idx += 1
        # Advance Y (next row starts after the current row height plus the gap)
        current_y += patch_h + gap_size

    return new_image


def inpainting(image, num_splits, blank_ratio=0.3, blank_color=(255, 255, 255)):
    """
    Split the image into patches and blank out a random subset, for the inpainting task.

    Args:
        image: PIL.Image, input image (RGB mode)
        num_splits: (h_splits, w_splits), number of vertical and horizontal splits
        blank_ratio: float, fraction of patches to blank out (0-1)
        blank_color: tuple, RGB color of the blanked regions (e.g., white (255, 255, 255))

    Returns:
        PIL.Image, the reassembled image
    """
    h_splits, w_splits = num_splits
    img_w, img_h = image.size

    base_patch_h = img_h // h_splits
    patch_heights = [base_patch_h] * (h_splits - 1)
    patch_heights.append(img_h - sum(patch_heights))

    base_patch_w = img_w // w_splits
    patch_widths = [base_patch_w] * (w_splits - 1)
    patch_widths.append(img_w - sum(patch_widths))

    patches = []
    current_y = 0
    for i in range(h_splits):
        current_x = 0
        patch_h = patch_heights[i]
        for j in range(w_splits):
            patch_w = patch_widths[j]
            patch = image.crop((current_x, current_y, current_x + patch_w, current_y + patch_h))
            patches.append(patch)
            current_x += patch_w
        current_y += patch_h

    total_patches = h_splits * w_splits
    num_blank = int(total_patches * blank_ratio)
    num_blank = max(0, min(num_blank, total_patches))
    blank_indices = random.sample(range(total_patches), num_blank)

    processed_patches = []
    for idx, patch in enumerate(patches):
        if idx in blank_indices:
            blank_patch = Image.new("RGB", patch.size, color=blank_color)
            processed_patches.append(blank_patch)
        else:
            processed_patches.append(patch)

    # Build the result image (same size as the original)
    result_image = Image.new("RGB", (img_w, img_h))
    current_y = 0
    patch_idx = 0
    for i in range(h_splits):
        current_x = 0
        patch_h = patch_heights[i]
        for j in range(w_splits):
            # Take the processed patch
            patch = processed_patches[patch_idx]
            patch_w = patch_widths[j]
            # Paste it back in its original position
            result_image.paste(patch, (current_x, current_y))
            current_x += patch_w
            patch_idx += 1
        current_y += patch_h

    return result_image
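These degradation ops compose naturally into source/target pairs for editing-style data; a hedged sketch (path and parameters are illustrative):

# Sketch: derive degraded source images from a clean target.
from PIL import Image
target = Image.open('example.jpg').convert('RGB')    # placeholder path

blurred = motion_blur_opencv(target, kernel_size=15, angle=30)
gray = decolorization(target)
shuffled = shuffle_patch(target, num_splits=(3, 3), gap_size=2)
masked = inpainting(target, num_splits=(4, 4), blank_ratio=0.3)
cropped, box = crop(target, (256, 256))              # also returns the crop box corners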
    	
data/video_utils.py
ADDED
@@ -0,0 +1,165 @@
# Copyright (c) 2023 OpenGVLab
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: MIT
#
# This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
#
# Original file was released under MIT, with the full license text
# available at https://github.com/OpenGVLab/InternVL/blob/main/LICENSE.
#
# This modified file is released under the same license.


import io
import os
import random
import re

import numpy as np
import decord
from PIL import Image


def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1):
    if sample in ['rand', 'middle']:  # uniform sampling
        acc_samples = min(num_frames, vlen)
        # Split the video into `acc_samples` intervals, and sample from each interval.
        intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
        ranges = []
        for idx, interv in enumerate(intervals[:-1]):
            ranges.append((interv, intervals[idx + 1] - 1))
        if sample == 'rand':
            try:
                frame_indices = [random.choice(range(x[0], x[1])) for x in ranges]
            except IndexError:  # an interval can be empty when vlen is small
                frame_indices = np.random.permutation(vlen)[:acc_samples]
                frame_indices.sort()
                frame_indices = list(frame_indices)
        elif fix_start is not None:
            frame_indices = [x[0] + fix_start for x in ranges]
        elif sample == 'middle':
            frame_indices = [(x[0] + x[1]) // 2 for x in ranges]
        else:
            raise NotImplementedError

        if len(frame_indices) < num_frames:  # pad with the last frame
            padded_frame_indices = [frame_indices[-1]] * num_frames
            padded_frame_indices[:len(frame_indices)] = frame_indices
            frame_indices = padded_frame_indices
    elif 'fps' in sample:  # e.g. fps0.5: sequentially sample frames at 0.5 fps
        output_fps = float(sample[3:])
        duration = float(vlen) / input_fps
        delta = 1 / output_fps  # gap between frames; this is also the clip length each frame represents
        frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
        frame_indices = np.around(frame_seconds * input_fps).astype(int)
        frame_indices = [e for e in frame_indices if e < vlen]
        if max_num_frames > 0 and len(frame_indices) > max_num_frames:
            frame_indices = frame_indices[:max_num_frames]
    else:
        raise ValueError
    return frame_indices
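A hedged trace of the fps mode, which is easy to misread: frames are taken at the midpoints of 1/output_fps windows. For a 10-second clip at 30 fps with sample='fps2':

# Illustrative trace: 300 frames at 30 fps, sampled at 2 fps.
idx = get_frame_indices(num_frames=20, vlen=300, sample='fps2', input_fps=30)
# delta = 0.5 s, frame_seconds = [0.25, 0.75, ..., 9.75]
# indices ~= around(second * 30) = [8, 22, 38, 52, ...], 20 frames total
assert len(idx) == 20 and idx[0] == 8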
| 63 | 
            +
            def read_frames_decord(video_path, num_frames, sample='rand', fix_start=None, clip=None, min_num_frames=4):
         | 
| 64 | 
            +
                video_reader = decord.VideoReader(video_path, num_threads=1)
         | 
| 65 | 
            +
                vlen = len(video_reader)
         | 
| 66 | 
            +
                fps = video_reader.get_avg_fps()
         | 
| 67 | 
            +
                duration = vlen / float(fps)
         | 
| 68 | 
            +
                if clip:
         | 
| 69 | 
            +
                    start, end = clip
         | 
| 70 | 
            +
                    duration = end - start
         | 
| 71 | 
            +
                    vlen = int(duration * fps)
         | 
| 72 | 
            +
                    start_index = int(start * fps)
         | 
| 73 | 
            +
             | 
| 74 | 
            +
                t_num_frames = np.random.randint(min_num_frames, num_frames + 1)
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                frame_indices = get_frame_indices(
         | 
| 77 | 
            +
                    t_num_frames, vlen, sample=sample, fix_start=fix_start,
         | 
| 78 | 
            +
                    input_fps=fps
         | 
| 79 | 
            +
                )
         | 
| 80 | 
            +
                if clip:
         | 
| 81 | 
            +
                    frame_indices = [f + start_index for f in frame_indices]
         | 
| 82 | 
            +
                frames = video_reader.get_batch(frame_indices).asnumpy()  # (T, H, W, C), np.uint8
         | 
| 83 | 
            +
                frames = [Image.fromarray(frames[i]) for i in range(frames.shape[0])]
         | 
| 84 | 
            +
                return frames
         | 
| 85 | 
            +
             | 
| 86 | 
            +
             | 
| 87 | 
            +
def extract_frame_number(filename):
    # Extract the numeric part from the filename using regular expressions
    match = re.search(r'_(\d+)\.jpg$', filename)
    return int(match.group(1)) if match else -1


def sort_frames(frame_paths):
    # Extract filenames from each path and sort by their numeric part
    return sorted(frame_paths, key=lambda x: extract_frame_number(os.path.basename(x)))

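The numeric key matters because plain lexicographic sorting would misorder frame dumps; for example (hypothetical filenames):

paths = ['vid/frame_10.jpg', 'vid/frame_2.jpg', 'vid/frame_1.jpg']
print(sort_frames(paths))
# ['vid/frame_1.jpg', 'vid/frame_2.jpg', 'vid/frame_10.jpg'] -- sorted() alone would put _10 before _2
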
def read_frames_folder(video_path, num_frames, sample='rand', fix_start=None, min_num_frames=4):
    image_list = sort_frames(list(os.listdir(video_path)))
    frames = []
    for image in image_list:
        fp = os.path.join(video_path, image)
        frame = Image.open(fp).convert('RGB')
        frames.append(frame)
    vlen = len(frames)

    t_num_frames = np.random.randint(min_num_frames, num_frames + 1)

    if vlen > t_num_frames:
        frame_indices = get_frame_indices(
            t_num_frames, vlen, sample=sample, fix_start=fix_start
        )
        frames = [frames[i] for i in frame_indices]
    return frames

class FrameSampler:
    def __init__(self, max_num_frames=-1, min_num_frames=8, sample='rand'):
        self.max_num_frames = max_num_frames
        self.min_num_frames = min_num_frames
        self.sample = sample

    def __call__(self, file_name):
        # Paths ending with '/' are treated as folders of pre-extracted frames;
        # everything else is decoded as a video file via decord.
        fn = read_frames_folder if file_name.endswith('/') else read_frames_decord
        frames = fn(file_name, num_frames=self.max_num_frames, min_num_frames=self.min_num_frames, sample=self.sample)
        return frames

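A brief usage sketch (paths hypothetical); the trailing slash is the only signal that routes a sample to the frame-folder reader instead of decord:

sampler = FrameSampler(max_num_frames=16, min_num_frames=8, sample='rand')
clip_frames = sampler('videos/clip.mp4')        # decoded with decord
folder_frames = sampler('videos/clip_frames/')  # read as sorted *_N.jpg frames
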
def decode_video_byte(video_bytes):
    video_stream = io.BytesIO(video_bytes)
    vr = decord.VideoReader(video_stream)
    return vr

def sample_mp4_frames(mp4_p, n_frames=None, fps=None, return_frame_indices=False, random_sample=False):
    if isinstance(mp4_p, str):
        vr = decord.VideoReader(mp4_p, num_threads=1)
    elif isinstance(mp4_p, decord.video_reader.VideoReader):
        vr = mp4_p
    else:
        raise TypeError(f'Unsupported input type: {type(mp4_p)}')
    video_fps = vr.get_avg_fps()  # average frame rate of the video
    video_duration = len(vr) / video_fps
    if n_frames is not None:
        if random_sample:
            frame_indices = sorted(random.sample(range(len(vr)), n_frames))
        else:
            frame_indices = np.linspace(0, len(vr) - 1, n_frames, dtype=int).tolist()
    else:
        frame_indices = [int(i) for i in np.arange(0, len(vr) - 1, video_fps / fps)]
    frames = vr.get_batch(frame_indices).asnumpy()  # convert to a numpy array
    frames = [Image.fromarray(frame).convert("RGB") for frame in frames]
    if not return_frame_indices:
        return frames, video_duration
    else:
        return frames, video_duration, frame_indices

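A short end-to-end sketch combining decode_video_byte with sample_mp4_frames, assuming the MP4 bytes come from some binary payload (the path here is hypothetical):

with open('videos/clip.mp4', 'rb') as f:
    video_bytes = f.read()
vr = decode_video_byte(video_bytes)
frames, duration = sample_mp4_frames(vr, n_frames=8)  # 8 evenly spaced PIL frames
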
def sample_mp4_frames_by_indices(mp4_p, frame_indices: list):
    if isinstance(mp4_p, str):
        vr = decord.VideoReader(mp4_p, num_threads=1)
    elif isinstance(mp4_p, decord.video_reader.VideoReader):
        vr = mp4_p
    else:
        raise TypeError(f'Unsupported input type: {type(mp4_p)}')
    # sample the frames in frame_indices
    frames = vr.get_batch(frame_indices).asnumpy()  # convert to a numpy array
    frames = [Image.fromarray(frame).convert("RGB") for frame in frames]
    return frames

        data/vlm_dataset.py
    ADDED
    
@@ -0,0 +1,195 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0

import json
import os
import traceback
from PIL import Image, ImageFile, PngImagePlugin

from .data_utils import pil_img2rgb
from .distributed_iterable_dataset import DistributedIterableDataset


Image.MAX_IMAGE_PIXELS = 200000000
ImageFile.LOAD_TRUNCATED_IMAGES = True
MaximumDecompressedSize = 1024
MegaByte = 2 ** 20
PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte

class SftJSONLIterableDataset(DistributedIterableDataset):
    def __init__(
        self, dataset_name, transform, tokenizer, frame_sampler,
        jsonl_path_list, data_dir_list, num_used_data,
        local_rank=0, world_size=1, num_workers=8, data_status=None,
        shuffle_lines=False, shuffle_seed=0,
    ):
        """
        jsonl_path_list: list of jsonl file paths
        data_dir_list: list of image directories containing the images of each jsonl file
        num_used_data: list of the number of data points sampled from each jsonl file
        """
        super().__init__(dataset_name, local_rank, world_size, num_workers)
        self.transform = transform
        self.tokenizer = tokenizer
        self.frame_sampler = frame_sampler
        self.data_status = data_status
        self.data_paths = self.get_data_paths(
            jsonl_path_list,
            data_dir_list,
            num_used_data,
            shuffle_lines,
            shuffle_seed,
        )
        self.set_epoch()

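For reference, each jsonl line is expected to follow the LLaVA-style conversation schema consumed by change_format below. This record is illustrative, not drawn from a real dataset:

# One jsonl line, as parsed by json.loads (hypothetical content):
{
    "image": "coco/000000123456.jpg",  # may also be a list of images, or a "video" key instead
    "conversations": [
        {"from": "human", "value": "<image>\nWhat is in the picture?"},
        {"from": "gpt", "value": "A cat sitting on a windowsill."},
    ],
}
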
    def get_data_paths(
        self,
        jsonl_path_list,
        data_dir_list,
        num_used_data,
        shuffle_lines,
        shuffle_seed,
    ):
        data_paths = []
        for jsonl_path, image_dir, num_data_point in zip(
            jsonl_path_list, data_dir_list, num_used_data
        ):
            with open(jsonl_path, 'r') as f:
                raw_data = f.readlines()
            if shuffle_lines:
                self.rng.seed(shuffle_seed)
                self.rng.shuffle(raw_data)
            raw_data = raw_data[:num_data_point]
            data_paths.extend([(json_data, image_dir) for json_data in raw_data])
        return data_paths

    def change_format(self, data, num_images):
        elements = []
        for conversation in data['conversations']:
            if conversation['from'] == 'human':
                if '<image>' not in conversation['value']:
                    elements.append({
                        'type': 'text',
                        'has_loss': 0,
                        'text': conversation['value'],
                    })
                else:
                    text_list = conversation['value'].split('<image>')
                    for idx, text in enumerate(text_list):
                        if text.strip() != '':
                            elements.append({
                                'type': 'text',
                                'has_loss': 0,
                                'text': text.strip(),
                            })
                        if (idx != len(text_list) - 1) and (idx < num_images):
                            elements.append({'type': 'image'})
            elif conversation['from'] == 'gpt':
                elements.append({
                    'type': 'text',
                    'has_loss': 1,
                    'text': conversation['value'],
                })
        return elements

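To make the element layout concrete: for the illustrative record above, change_format(data_item, num_images=1) produces, in order (human text carries has_loss=0, assistant text has_loss=1):

# {'type': 'image'}  <- the leading '<image>' placeholder
# {'type': 'text', 'has_loss': 0, 'text': 'What is in the picture?'}
# {'type': 'text', 'has_loss': 1, 'text': 'A cat sitting on a windowsill.'}
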
    def __iter__(self):
        data_paths_per_worker, worker_id = self.get_data_paths_per_worker()
        if self.data_status is not None:
            row_start_id = self.data_status[worker_id] + 1
        else:
            row_start_id = 0
        transform_stride = self.transform.stride

        print(
            f"rank-{self.local_rank} worker-{worker_id} dataset-{self.dataset_name}: "
            f"resuming data at row#{row_start_id}"
        )

        while True:
            data_paths_per_worker_ = data_paths_per_worker[row_start_id:]
            for row_idx, (data, image_dir) in enumerate(data_paths_per_worker_, start=row_start_id):
                num_tokens = 0
                image_tensor_list = []
                text_ids_list = []
                sequence_plan = []

                try:
                    data_item = json.loads(data)
                    raw_images = None
                    if 'image' in data_item:
                        if isinstance(data_item['image'], list):
                            raw_images = [
                                pil_img2rgb(Image.open(os.path.join(image_dir, image)))
                                for image in data_item['image']
                            ]
                        else:
                            raw_images = [
                                pil_img2rgb(Image.open(os.path.join(image_dir, data_item['image'])))
                            ]
                    elif 'video' in data_item:
                        raw_images = self.frame_sampler(os.path.join(image_dir, data_item['video']))
                        special_tokens = '<image>' * len(raw_images)
                        for item in data_item['conversations']:
                            if '<video>' in item['value']:
                                item['value'] = item['value'].replace('<video>', special_tokens)
                                break
                        else:
                            # for/else: raise only if no turn contained a <video> placeholder
                            raise ValueError("Cannot find <video> in the conversation!")
                except Exception:
                    traceback.print_exc()
                    continue

                if raw_images:
                    for raw_image in raw_images:
                        image_tensor = self.transform(raw_image, img_num=len(raw_images))
                        image_tensor_list.append(image_tensor)
                        height, width = image_tensor.shape[1:]
                        num_tokens += width * height // transform_stride ** 2

                elements = self.change_format(data_item, len(image_tensor_list))

                for item in elements:
                    if item['type'] == 'text':
                        text_data = item['text']
                        text_ids = self.tokenizer.encode(text_data)
                        if len(text_ids) > 0:
                            text_ids_list.append(text_ids)
                            num_tokens += len(text_ids)
                            current_plan = {
                                'type': 'text',
                                'enable_cfg': 0,
                                'loss': item['has_loss'],
                                'special_token_loss': 0,
                                'special_token_label': None,
                            }
                            sequence_plan.append(current_plan)
                    elif item['type'] == 'image':
                        current_plan = {
                            'type': 'vit_image',
                            'enable_cfg': 0,
                            'loss': 0,
                            'special_token_loss': 0,
                            'special_token_label': None,
                        }
                        sequence_plan.append(current_plan)

                has_loss = [item['loss'] for item in sequence_plan]
                if sum(has_loss) == 0:
                    print('No loss defined, skipped.')
                    continue

                yield dict(
                    image_tensor_list=image_tensor_list,
                    text_ids_list=text_ids_list,
                    sequence_plan=sequence_plan,
                    num_tokens=num_tokens,
                    data_indexes={
                        "data_indexes": row_idx,
                        "worker_id": worker_id,
                        "dataset_name": self.dataset_name,
                    }
                )

            row_start_id = 0
            print(f"{self.dataset_name} repeat in rank-{self.local_rank} worker-{worker_id}")

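Finally, a hedged sketch of how this dataset might be wired up; the transform, tokenizer, and paths are placeholders (real settings come from data/configs/*.yaml and the training entry point), and FrameSampler is the helper defined in data/video_utils.py:

dataset = SftJSONLIterableDataset(
    dataset_name='llava_ov',
    transform=image_transform,    # placeholder: any callable with a .stride attribute
    tokenizer=tokenizer,          # placeholder: any tokenizer exposing .encode()
    frame_sampler=FrameSampler(max_num_frames=16),
    jsonl_path_list=['anno/llava_ov.jsonl'],
    data_dir_list=['images/llava_ov'],
    num_used_data=[10],
)
sample = next(iter(dataset))
print(sample['num_tokens'], len(sample['sequence_plan']))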