# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0

import random
import json

import numpy as np
import torch

from .data_utils import (
    get_flattened_position_ids_interpolate,
    get_flattened_position_ids_extrapolate,
    len2weight,
    patchify,
    prepare_attention_mask_per_sample,
)
from .dataset_info import DATASET_INFO, DATASET_REGISTRY
from .transforms import ImageTransform
from .video_utils import FrameSampler


class DataConfig:
    def __init__(
        self,
        grouped_datasets,
        text_cond_dropout_prob=0.1,
        vit_cond_dropout_prob=0.4,
        vae_cond_dropout_prob=0.1,
        vae_image_downsample=16,
        max_latent_size=32,
        vit_patch_size=14,
        max_num_patch_per_side=70,
    ):
        self.grouped_datasets = grouped_datasets
        self.text_cond_dropout_prob = text_cond_dropout_prob
        self.vit_cond_dropout_prob = vit_cond_dropout_prob
        self.vit_patch_size = vit_patch_size
        self.max_num_patch_per_side = max_num_patch_per_side
        self.vae_cond_dropout_prob = vae_cond_dropout_prob
        self.vae_image_downsample = vae_image_downsample
        self.max_latent_size = max_latent_size
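

# Illustrative sketch of the grouped_datasets mapping that DataConfig (and
# PackedDataset.build_datasets below) consumes. Group and dataset names here
# are placeholders; the per-group keys shown are the ones this file reads, and
# any remaining keys are forwarded unchanged to the dataset class registered in
# DATASET_REGISTRY.
#
#     grouped_datasets = {
#         "<group_name>": {                       # must exist in DATASET_INFO / DATASET_REGISTRY
#             "dataset_names": ["<dataset_name>"],
#             "weight": 1.0,                      # relative sampling weight across groups
#             "is_mandatory": False,              # force one sample per packed sequence
#             "image_transform_args": {...},      # -> ImageTransform
#             "vit_image_transform_args": {...},  # -> ImageTransform (ViT branch)
#             "frame_sampler_args": {...},        # -> FrameSampler (video data)
#             "resume_data_status": True,         # resume from a saved data_status
#         },
#     }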


class PackedDataset(torch.utils.data.IterableDataset):
    """Pack samples drawn from multiple grouped datasets into fixed-budget token
    sequences, interleaving text ids, ViT patches, and VAE image latents while
    recording the index and loss bookkeeping needed to unpack them in the model.
    """

    def __init__(
        self,
        data_config,
        tokenizer,
        special_tokens,
        local_rank,
        world_size,
        num_workers,
        expected_num_tokens=32768,
        max_num_tokens_per_sample=16384,
        max_num_tokens=36864,
        prefer_buffer_before=16384,
        max_buffer_size=50,
        interpolate_pos=False,
        use_flex=False,
        data_status=None,
    ):
        super().__init__()
        self.expected_num_tokens = expected_num_tokens
        self.max_num_tokens_per_sample = max_num_tokens_per_sample
        self.prefer_buffer_before = prefer_buffer_before
        self.max_num_tokens = max_num_tokens
        self.max_buffer_size = max_buffer_size
        self.tokenizer = tokenizer
        self.local_rank = local_rank
        self.world_size = world_size
        self.num_workers = num_workers
        self.use_flex = use_flex
        for k, v in special_tokens.items():
            setattr(self, k, v)

        grouped_datasets, is_mandatory, grouped_weights = self.build_datasets(
            data_config.grouped_datasets, data_status
        )
        self.grouped_datasets = grouped_datasets
        self.dataset_iters = [iter(dataset) for dataset in grouped_datasets]
        self.is_mandatory = is_mandatory
        self.grouped_weights = grouped_weights
        self.data_config = data_config
        self.interpolate_pos = interpolate_pos
        if self.interpolate_pos:
            self.get_flattened_position_ids = get_flattened_position_ids_interpolate
        else:
            self.get_flattened_position_ids = get_flattened_position_ids_extrapolate

    def build_datasets(self, datasets_metainfo, data_status):
        datasets = []
        is_mandatory = []
        grouped_weights = []
        for grouped_dataset_name, dataset_args in datasets_metainfo.items():
            is_mandatory.append(dataset_args.pop('is_mandatory', False))
            grouped_weights.append(dataset_args.pop('weight', 0.0))

            if 'frame_sampler_args' in dataset_args.keys():
                frame_sampler = FrameSampler(**dataset_args.pop('frame_sampler_args'))
                dataset_args['frame_sampler'] = frame_sampler
            if 'image_transform_args' in dataset_args.keys():
                transform = ImageTransform(**dataset_args.pop('image_transform_args'))
                dataset_args['transform'] = transform
            if 'vit_image_transform_args' in dataset_args.keys():
                vit_transform = ImageTransform(**dataset_args.pop('vit_image_transform_args'))
                dataset_args['vit_transform'] = vit_transform

            assert 'dataset_names' in dataset_args.keys()
            dataset_names = dataset_args.pop('dataset_names')
            dataset_args['data_dir_list'] = []
            for item in dataset_names:
                if self.local_rank == 0:
                    print(f'Preparing Dataset {grouped_dataset_name}/{item}')
                meta_info = DATASET_INFO[grouped_dataset_name][item]
                dataset_args['data_dir_list'].append(meta_info['data_dir'])

                if "parquet_info_path" in meta_info.keys():
                    if 'parquet_info' not in dataset_args.keys():
                        dataset_args['parquet_info'] = {}
                    with open(meta_info['parquet_info_path'], 'r') as f:
                        parquet_info = json.load(f)
                    dataset_args['parquet_info'].update(parquet_info)

                if 'json_dir' in meta_info.keys():
                    # parquet/tar with json
                    if 'json_dir_list' not in dataset_args.keys():
                        dataset_args['json_dir_list'] = [meta_info['json_dir']]
                    else:
                        dataset_args['json_dir_list'].append(meta_info['json_dir'])

                if 'jsonl_path' in meta_info.keys():
                    # jsonl with jpeg
                    if 'jsonl_path_list' not in dataset_args.keys():
                        dataset_args['jsonl_path_list'] = [meta_info['jsonl_path']]
                    else:
                        dataset_args['jsonl_path_list'].append(meta_info['jsonl_path'])

            resume_data_status = dataset_args.pop('resume_data_status', True)
            if data_status is not None and grouped_dataset_name in data_status.keys() and resume_data_status:
                data_status_per_group = data_status[grouped_dataset_name]
            else:
                data_status_per_group = None
            dataset = DATASET_REGISTRY[grouped_dataset_name](
                dataset_name=grouped_dataset_name,
                tokenizer=self.tokenizer,
                local_rank=self.local_rank,
                world_size=self.world_size,
                num_workers=self.num_workers,
                data_status=data_status_per_group,
                **dataset_args
            )
            datasets.append(dataset)

        return datasets, is_mandatory, grouped_weights

    def set_epoch(self, seed):
        for dataset in self.grouped_datasets:
            dataset.set_epoch(seed)

    def set_sequence_status(self):
        sequence_status = dict(
            curr                        = 0,
            sample_lens                 = list(),
            packed_position_ids         = list(),
            nested_attention_masks      = list(),
            split_lens                  = list(),
            attn_modes                  = list(),
            packed_text_ids             = list(),
            packed_text_indexes         = list(),
            packed_label_ids            = list(),
            ce_loss_indexes             = list(),
            ce_loss_weights             = list(),
            vae_image_tensors           = list(),
            packed_latent_position_ids  = list(),
            vae_latent_shapes           = list(),
            packed_vae_token_indexes    = list(),
            packed_timesteps            = list(),
            mse_loss_indexes            = list(),
            packed_vit_tokens           = list(),
            vit_token_seqlens           = list(),
            packed_vit_position_ids     = list(),
            packed_vit_token_indexes    = list(),
        )
        return sequence_status

    def to_tensor(self, sequence_status):
        data = dict(
            sequence_length=sum(sequence_status['sample_lens']),
            sample_lens=sequence_status['sample_lens'],
            packed_text_ids=torch.tensor(sequence_status['packed_text_ids']),
            packed_text_indexes=torch.tensor(sequence_status['packed_text_indexes']),
            packed_position_ids=torch.tensor(sequence_status['packed_position_ids']),
        )
        if not self.use_flex:
            data['nested_attention_masks'] = sequence_status['nested_attention_masks']
        else:
            sequence_len = data['sequence_length']
            pad_len = self.max_num_tokens - sequence_len
            data['split_lens'] = sequence_status['split_lens'] + [pad_len]
            data['attn_modes'] = sequence_status['attn_modes'] + ['causal']
            data['sample_lens'] += [pad_len]

        # if the model has a convnet vae (e.g., as visual tokenizer)
        if len(sequence_status['vae_image_tensors']) > 0:
            image_tensors = sequence_status.pop('vae_image_tensors')
            image_sizes = [item.shape for item in image_tensors]
            max_image_size = [max(item) for item in list(zip(*image_sizes))]
            padded_images = torch.zeros(size=(len(image_tensors), *max_image_size))
            for i, image_tensor in enumerate(image_tensors):
                padded_images[i, :, :image_tensor.shape[1], :image_tensor.shape[2]] = image_tensor

            data['padded_images'] = padded_images
            data['patchified_vae_latent_shapes'] = sequence_status['vae_latent_shapes']
            data['packed_latent_position_ids'] = torch.cat(sequence_status['packed_latent_position_ids'], dim=0)
            data['packed_vae_token_indexes'] = torch.tensor(sequence_status['packed_vae_token_indexes'])

        # if the model has a vit (e.g., as visual tokenizer)
        if len(sequence_status['packed_vit_tokens']) > 0:
            data['packed_vit_tokens'] = torch.cat(sequence_status['packed_vit_tokens'], dim=0)
            data['packed_vit_position_ids'] = torch.cat(sequence_status['packed_vit_position_ids'], dim=0)
            data['packed_vit_token_indexes'] = torch.tensor(sequence_status['packed_vit_token_indexes'])
            data['vit_token_seqlens'] = torch.tensor(sequence_status['vit_token_seqlens'])

        # if the model is required to perform visual generation
        if len(sequence_status['packed_timesteps']) > 0:
            data['packed_timesteps'] = torch.tensor(sequence_status['packed_timesteps'])
            data['mse_loss_indexes'] = torch.tensor(sequence_status['mse_loss_indexes'])

        # if the model is required to perform text generation
        if len(sequence_status['packed_label_ids']) > 0:
            data['packed_label_ids'] = torch.tensor(sequence_status['packed_label_ids'])
            data['ce_loss_indexes'] = torch.tensor(sequence_status['ce_loss_indexes'])
            data['ce_loss_weights'] = torch.tensor(sequence_status['ce_loss_weights'])

        return data

    def __iter__(self):
        total_weights = sum(self.grouped_weights)
        assert total_weights > 0.0
        group_cumprobs = [sum(self.grouped_weights[:i + 1]) / total_weights
                          for i in range(len(self.grouped_weights))]
        sequence_status = self.set_sequence_status()
        batch_data_indexes = []

        buffer = []
        while True:
            # Ensure at least one sample from each mandatory group
            if sequence_status['curr'] == 0:
                for group_index, group_iter in enumerate(self.dataset_iters):
                    if self.is_mandatory[group_index]:
                        while True:
                            sample = next(group_iter)
                            # if a sample is too long, skip it
                            num_tokens = sample['num_tokens'] + 2 * len(sample['sequence_plan'])
                            if num_tokens < self.max_num_tokens_per_sample:
                                sequence_status = self.pack_sequence(sample, sequence_status)
                                batch_data_indexes.append(sample['data_indexes'])
                                break
                            else:
                                print(f"skip a sample with length {num_tokens}")
                                continue

            if sequence_status['curr'] < self.prefer_buffer_before and len(buffer) > 0:
                sample = buffer.pop(0)
                sample_from_buffer = True
            else:
                # sample normally across all groups
                n = random.random()
                group_index = 0
                for i, cumprob in enumerate(group_cumprobs):
                    if n < cumprob:
                        group_index = i
                        break
                sample = next(self.dataset_iters[group_index])
                sample_from_buffer = False

            # if a sample is too long, skip it
            num_tokens = sample['num_tokens'] + 2 * len(sample['sequence_plan'])
            if num_tokens > self.max_num_tokens_per_sample:
                print(f"skip a sample with length {num_tokens}")
                continue

            if sequence_status['curr'] + num_tokens > self.max_num_tokens:
                if len(buffer) < self.max_buffer_size and not sample_from_buffer:
                    buffer.append(sample)
                else:
                    print(f"Yielding data with length {sum(sequence_status['sample_lens'])}")
                    data = self.to_tensor(sequence_status)
                    data['batch_data_indexes'] = batch_data_indexes
                    yield data
                    sequence_status = self.set_sequence_status()
                    batch_data_indexes = []
                continue

            sequence_status = self.pack_sequence(sample, sequence_status)
            batch_data_indexes.append(sample['data_indexes'])

            if sequence_status['curr'] >= self.expected_num_tokens:
                data = self.to_tensor(sequence_status)
                data['batch_data_indexes'] = batch_data_indexes
                yield data
                sequence_status = self.set_sequence_status()
                batch_data_indexes = []

    def pack_sequence(self, sample, sequence_status):
        image_tensor_list = sample['image_tensor_list']
        text_ids_list = sample['text_ids_list']
        sequence_plan = sample['sequence_plan']

        split_lens, attn_modes = list(), list()
        curr = sequence_status['curr']
        curr_rope_id = 0
        sample_lens = 0

        for item in sequence_plan:
            split_start = item.get('split_start', True)
            if split_start:
                curr_split_len = 0

            if item['type'] == 'text':
                text_ids = text_ids_list.pop(0)
                if item['enable_cfg'] == 1 and random.random() < self.data_config.text_cond_dropout_prob:
                    continue

                shifted_text_ids = [self.bos_token_id] + text_ids
                sequence_status['packed_text_ids'].extend(shifted_text_ids)
                sequence_status['packed_text_indexes'].extend(range(curr, curr + len(shifted_text_ids)))
                if item['loss'] == 1:
                    sequence_status['ce_loss_indexes'].extend(range(curr, curr + len(shifted_text_ids)))
                    sequence_status['ce_loss_weights'].extend(
                        [len2weight(len(shifted_text_ids))] * len(shifted_text_ids)
                    )
                    sequence_status['packed_label_ids'].extend(text_ids + [self.eos_token_id])
                curr += len(shifted_text_ids)
                curr_split_len += len(shifted_text_ids)

                # add a <|im_end|> token
                sequence_status['packed_text_ids'].append(self.eos_token_id)
                sequence_status['packed_text_indexes'].append(curr)
                if item['special_token_loss'] == 1:  # <|im_end|> may have loss
                    sequence_status['ce_loss_indexes'].append(curr)
                    sequence_status['ce_loss_weights'].append(1.0)
                    sequence_status['packed_label_ids'].append(item['special_token_label'])
                curr += 1
                curr_split_len += 1

                # update sequence status
                attn_modes.append("causal")
                sequence_status['packed_position_ids'].extend(range(curr_rope_id, curr_rope_id + curr_split_len))
                curr_rope_id += curr_split_len

            elif item['type'] == 'vit_image':
                image_tensor = image_tensor_list.pop(0)
                if item['enable_cfg'] == 1 and random.random() < self.data_config.vit_cond_dropout_prob:
                    curr_rope_id += 1
                    continue

                # add a <|startofimage|> token
                sequence_status['packed_text_ids'].append(self.start_of_image)
                sequence_status['packed_text_indexes'].append(curr)
                curr += 1
                curr_split_len += 1

                # preprocess image
                vit_tokens = patchify(image_tensor, self.data_config.vit_patch_size)
                num_img_tokens = vit_tokens.shape[0]
                sequence_status['packed_vit_token_indexes'].extend(range(curr, curr + num_img_tokens))
                curr += num_img_tokens
                curr_split_len += num_img_tokens

                sequence_status['packed_vit_tokens'].append(vit_tokens)
                sequence_status['vit_token_seqlens'].append(num_img_tokens)
                sequence_status['packed_vit_position_ids'].append(
                    self.get_flattened_position_ids(
                        image_tensor.size(1), image_tensor.size(2),
                        self.data_config.vit_patch_size,
                        max_num_patches_per_side=self.data_config.max_num_patch_per_side
                    )
                )

                # add a <|endofimage|> token
                sequence_status['packed_text_ids'].append(self.end_of_image)
                sequence_status['packed_text_indexes'].append(curr)
                if item['special_token_loss'] == 1:  # <|endofimage|> may have loss
                    sequence_status['ce_loss_indexes'].append(curr)
                    sequence_status['ce_loss_weights'].append(1.0)
                    sequence_status['packed_label_ids'].append(item['special_token_label'])
                curr += 1
                curr_split_len += 1

                # update sequence status
                attn_modes.append("full")
                sequence_status['packed_position_ids'].extend([curr_rope_id] * curr_split_len)
                curr_rope_id += 1

            elif item['type'] == 'vae_image':
                image_tensor = image_tensor_list.pop(0)
                if item['enable_cfg'] == 1 and random.random() < self.data_config.vae_cond_dropout_prob:
                    # FIXME fix vae dropout in video2video setting.
                    curr_rope_id += 1
                    continue

                # add a <|startofimage|> token
                sequence_status['packed_text_ids'].append(self.start_of_image)
                sequence_status['packed_text_indexes'].append(curr)
                curr += 1
                curr_split_len += 1

                # preprocess image
                sequence_status['vae_image_tensors'].append(image_tensor)
                sequence_status['packed_latent_position_ids'].append(
                    self.get_flattened_position_ids(
                        image_tensor.size(1), image_tensor.size(2),
                        self.data_config.vae_image_downsample,
                        max_num_patches_per_side=self.data_config.max_latent_size
                    )
                )
                H, W = image_tensor.shape[1:]
                h = H // self.data_config.vae_image_downsample
                w = W // self.data_config.vae_image_downsample
                sequence_status['vae_latent_shapes'].append((h, w))

                num_img_tokens = w * h
                sequence_status['packed_vae_token_indexes'].extend(range(curr, curr + num_img_tokens))
                if item['loss'] == 1:
                    sequence_status['mse_loss_indexes'].extend(range(curr, curr + num_img_tokens))
                    if split_start:
                        timestep = np.random.randn()
                else:
                    timestep = float('-inf')
                sequence_status['packed_timesteps'].extend([timestep] * num_img_tokens)
                curr += num_img_tokens
                curr_split_len += num_img_tokens

                # add a <|endofimage|> token
                sequence_status['packed_text_ids'].append(self.end_of_image)
                sequence_status['packed_text_indexes'].append(curr)
                # <|endofimage|> may have loss
                if item['special_token_loss'] == 1:
                    sequence_status['ce_loss_indexes'].append(curr)
                    sequence_status['ce_loss_weights'].append(1.0)
                    sequence_status['packed_label_ids'].append(item['special_token_label'])
                curr += 1
                curr_split_len += 1

                # update sequence status
                if split_start:
                    if item['loss'] == 1 and 'frame_delta' not in item.keys():
                        attn_modes.append("noise")
                    else:
                        attn_modes.append("full")
                sequence_status['packed_position_ids'].extend([curr_rope_id] * (num_img_tokens + 2))
                if 'frame_delta' in item.keys():
                    curr_rope_id += item['frame_delta']
                elif item['loss'] == 0:
                    curr_rope_id += 1

            if item.get('split_end', True):
                split_lens.append(curr_split_len)
                sample_lens += curr_split_len

        sequence_status['curr'] = curr
        sequence_status['sample_lens'].append(sample_lens)
        # prepare attention mask
        if not self.use_flex:
            sequence_status['nested_attention_masks'].append(
                prepare_attention_mask_per_sample(split_lens, attn_modes)
            )
        else:
            sequence_status['split_lens'].extend(split_lens)
            sequence_status['attn_modes'].extend(attn_modes)

        return sequence_status


class SimpleCustomBatch:
    def __init__(self, batch):
        data = batch[0]
        self.batch_data_indexes = data['batch_data_indexes']
        self.sequence_length = data["sequence_length"]
        self.sample_lens = data["sample_lens"]
        self.packed_text_ids = data["packed_text_ids"]
        self.packed_text_indexes = data["packed_text_indexes"]
        self.packed_position_ids = data["packed_position_ids"]

        self.use_flex = "nested_attention_masks" not in data.keys()

        if self.use_flex:
            self.split_lens = data["split_lens"]
            self.attn_modes = data["attn_modes"]
        else:
            self.nested_attention_masks = data["nested_attention_masks"]

        if "padded_images" in data.keys():
            self.padded_images = data["padded_images"]
            self.patchified_vae_latent_shapes = data["patchified_vae_latent_shapes"]
            self.packed_latent_position_ids = data["packed_latent_position_ids"]
            self.packed_vae_token_indexes = data["packed_vae_token_indexes"]

        if "packed_vit_tokens" in data.keys():
            self.packed_vit_tokens = data["packed_vit_tokens"]
            self.packed_vit_position_ids = data["packed_vit_position_ids"]
            self.packed_vit_token_indexes = data["packed_vit_token_indexes"]
            self.vit_token_seqlens = data["vit_token_seqlens"]

        if "packed_timesteps" in data.keys():
            self.packed_timesteps = data["packed_timesteps"]
            self.mse_loss_indexes = data["mse_loss_indexes"]

        if "packed_label_ids" in data.keys():
            self.packed_label_ids = data["packed_label_ids"]
            self.ce_loss_indexes = data["ce_loss_indexes"]
            self.ce_loss_weights = data["ce_loss_weights"]

    def pin_memory(self):
        self.packed_text_ids = self.packed_text_ids.pin_memory()
        self.packed_text_indexes = self.packed_text_indexes.pin_memory()
        self.packed_position_ids = self.packed_position_ids.pin_memory()

        if not self.use_flex:
            self.nested_attention_masks = [item.pin_memory() for item in self.nested_attention_masks]

        if hasattr(self, 'padded_images'):
            self.padded_images = self.padded_images.pin_memory()
            self.packed_vae_token_indexes = self.packed_vae_token_indexes.pin_memory()
            self.packed_latent_position_ids = self.packed_latent_position_ids.pin_memory()

        if hasattr(self, 'packed_timesteps'):
            self.packed_timesteps = self.packed_timesteps.pin_memory()
            self.mse_loss_indexes = self.mse_loss_indexes.pin_memory()

        if hasattr(self, 'packed_vit_tokens'):
            self.packed_vit_tokens = self.packed_vit_tokens.pin_memory()
            self.packed_vit_position_ids = self.packed_vit_position_ids.pin_memory()
            self.packed_vit_token_indexes = self.packed_vit_token_indexes.pin_memory()
            self.vit_token_seqlens = self.vit_token_seqlens.pin_memory()

        if hasattr(self, 'packed_label_ids'):
            self.packed_label_ids = self.packed_label_ids.pin_memory()
            self.ce_loss_indexes = self.ce_loss_indexes.pin_memory()
            self.ce_loss_weights = self.ce_loss_weights.pin_memory()

        return self

    def cuda(self, device):
        self.packed_text_ids = self.packed_text_ids.to(device)
        self.packed_text_indexes = self.packed_text_indexes.to(device)
        self.packed_position_ids = self.packed_position_ids.to(device)

        if not self.use_flex:
            self.nested_attention_masks = [item.to(device) for item in self.nested_attention_masks]

        if hasattr(self, 'padded_images'):
            self.padded_images = self.padded_images.to(device)
            self.packed_vae_token_indexes = self.packed_vae_token_indexes.to(device)
            self.packed_latent_position_ids = self.packed_latent_position_ids.to(device)

        if hasattr(self, 'packed_timesteps'):
            self.packed_timesteps = self.packed_timesteps.to(device)
            self.mse_loss_indexes = self.mse_loss_indexes.to(device)

        if hasattr(self, 'packed_vit_tokens'):
            self.packed_vit_tokens = self.packed_vit_tokens.to(device)
            self.packed_vit_position_ids = self.packed_vit_position_ids.to(device)
            self.packed_vit_token_indexes = self.packed_vit_token_indexes.to(device)
            self.vit_token_seqlens = self.vit_token_seqlens.to(device)

        if hasattr(self, 'packed_label_ids'):
            self.packed_label_ids = self.packed_label_ids.to(device)
            self.ce_loss_indexes = self.ce_loss_indexes.to(device)
            self.ce_loss_weights = self.ce_loss_weights.to(device)

        return self

    def to_dict(self):
        data = dict(
            sequence_length = self.sequence_length,
            sample_lens = self.sample_lens,
            packed_text_ids = self.packed_text_ids,
            packed_text_indexes = self.packed_text_indexes,
            packed_position_ids = self.packed_position_ids,
            batch_data_indexes = self.batch_data_indexes,
        )

        if not self.use_flex:
            data['nested_attention_masks'] = self.nested_attention_masks
        else:
            data['split_lens'] = self.split_lens
            data['attn_modes'] = self.attn_modes

        if hasattr(self, 'padded_images'):
            data['padded_images'] = self.padded_images
            data['patchified_vae_latent_shapes'] = self.patchified_vae_latent_shapes
            data['packed_latent_position_ids'] = self.packed_latent_position_ids
            data['packed_vae_token_indexes'] = self.packed_vae_token_indexes

        if hasattr(self, 'packed_vit_tokens'):
            data['packed_vit_tokens'] = self.packed_vit_tokens
            data['packed_vit_position_ids'] = self.packed_vit_position_ids
            data['packed_vit_token_indexes'] = self.packed_vit_token_indexes
            data['vit_token_seqlens'] = self.vit_token_seqlens

        if hasattr(self, 'packed_timesteps'):
            data['packed_timesteps'] = self.packed_timesteps
            data['mse_loss_indexes'] = self.mse_loss_indexes

        if hasattr(self, 'packed_label_ids'):
            data['packed_label_ids'] = self.packed_label_ids
            data['ce_loss_indexes'] = self.ce_loss_indexes
            data['ce_loss_weights'] = self.ce_loss_weights

        return data


def collate_wrapper():
    def collate_fn(batch):
        return SimpleCustomBatch(batch)
    return collate_fn
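

# Example usage (illustrative sketch only, not part of the original module):
# wiring a PackedDataset into a torch DataLoader with the collate_wrapper above.
# PackedDataset already yields fully packed sequences, so batch_size=1 and the
# collate_fn simply wraps the single packed item in a SimpleCustomBatch. The
# tokenizer, special-token ids, and grouped_datasets config below are
# placeholders that must come from the surrounding training setup.
#
#     from torch.utils.data import DataLoader
#
#     data_config = DataConfig(grouped_datasets={...})   # dataset metainfo dict
#     dataset = PackedDataset(
#         data_config,
#         tokenizer=tokenizer,                            # text tokenizer (assumed)
#         special_tokens=dict(
#             bos_token_id=..., eos_token_id=...,
#             start_of_image=..., end_of_image=...,
#         ),
#         local_rank=0, world_size=1, num_workers=1,
#     )
#     loader = DataLoader(
#         dataset, batch_size=1, num_workers=1,
#         collate_fn=collate_wrapper(), pin_memory=True,
#     )
#     for batch in loader:                                # batch is a SimpleCustomBatch
#         batch = batch.cuda(torch.device("cuda"))
#         model_inputs = batch.to_dict()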