diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..d50c243c4f56d9b456e16b1d5ac6f631c903ddd3 --- /dev/null +++ b/app.py @@ -0,0 +1,205 @@ +# vae: +# class_path: src.models.vae.LatentVAE +# init_args: +# precompute: true +# weight_path: /mnt/bn/wangshuai6/models/sd-vae-ft-ema/ +# denoiser: +# class_path: src.models.denoiser.decoupled_improved_dit.DDT +# init_args: +# in_channels: 4 +# patch_size: 2 +# num_groups: 16 +# hidden_size: &hidden_dim 1152 +# num_blocks: 28 +# num_encoder_blocks: 22 +# num_classes: 1000 +# conditioner: +# class_path: src.models.conditioner.LabelConditioner +# init_args: +# null_class: 1000 +# diffusion_sampler: +# class_path: src.diffusion.stateful_flow_matching.sampling.EulerSampler +# init_args: +# num_steps: 250 +# guidance: 3.0 +# state_refresh_rate: 1 +# guidance_interval_min: 0.3 +# guidance_interval_max: 1.0 +# timeshift: 1.0 +# last_step: 0.04 +# scheduler: *scheduler +# w_scheduler: src.diffusion.stateful_flow_matching.scheduling.LinearScheduler +# guidance_fn: src.diffusion.base.guidance.simple_guidance_fn +# step_fn: src.diffusion.stateful_flow_matching.sampling.ode_step_fn +import random +import os +import torch +import argparse +from omegaconf import OmegaConf +from src.models.autoencoder.base import fp2uint8 +from src.diffusion.base.guidance import simple_guidance_fn +from src.diffusion.flow_matching.adam_sampling import AdamLMSampler +from src.diffusion.flow_matching.scheduling import LinearScheduler +from PIL import Image +import gradio as gr +import tempfile +from huggingface_hub import snapshot_download + + +def instantiate_class(config): + kwargs = config.get("init_args", {}) + class_module, class_name = config["class_path"].rsplit(".", 1) + module = __import__(class_module, fromlist=[class_name]) + args_class = getattr(module, class_name) + return args_class(**kwargs) + +def load_model(weight_dict, denoiser): + prefix = "ema_denoiser." 
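# ---- [Editor's note: illustrative sketch, not part of the patch] ----
# instantiate_class() above turns a {"class_path": ..., "init_args": ...} mapping into an object,
# mirroring the class_path/init_args convention used by the YAML configs in this patch. A minimal,
# self-contained sketch of the same lookup, using a standard-library class purely as a stand-in:
example_cfg = {"class_path": "datetime.timedelta", "init_args": {"days": 2}}
module_name, class_name = example_cfg["class_path"].rsplit(".", 1)
module = __import__(module_name, fromlist=[class_name])
instance = getattr(module, class_name)(**example_cfg.get("init_args", {}))  # datetime.timedelta(days=2)
# ----------------------------------------------------------------------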
+ for k, v in denoiser.state_dict().items(): + try: + v.copy_(weight_dict["state_dict"][prefix + k]) + except (KeyError, RuntimeError): + print(f"Failed to copy {prefix + k} to denoiser weight") + return denoiser + + +class Pipeline: + def __init__(self, vae, denoiser, conditioner, resolution): + self.vae = vae.cuda() + self.denoiser = denoiser.cuda() + self.conditioner = conditioner.cuda() + self.conditioner.compile() + self.resolution = resolution + self.tmp_dir = tempfile.TemporaryDirectory(prefix="traj_gifs_") + # self.denoiser.compile() + + def __del__(self): + self.tmp_dir.cleanup() + + @torch.no_grad() + @torch.autocast(device_type="cuda", dtype=torch.bfloat16) + def __call__(self, y, num_images, seed, image_height, image_width, num_steps, guidance, timeshift, order): + diffusion_sampler = AdamLMSampler( + order=order, + scheduler=LinearScheduler(), + guidance_fn=simple_guidance_fn, + num_steps=num_steps, + guidance=guidance, + timeshift=timeshift + ) + generator = torch.Generator(device="cpu").manual_seed(seed) + image_height = image_height // 32 * 32 + image_width = image_width // 32 * 32 + self.denoiser.decoder_patch_scaling_h = image_height / 512 + self.denoiser.decoder_patch_scaling_w = image_width / 512 + xT = torch.randn((num_images, 3, image_height, image_width), device="cpu", dtype=torch.float32, + generator=generator) + xT = xT.to("cuda") + with torch.no_grad(): + condition, uncondition = self.conditioner([y,]*num_images) + + + # Sample images: + samples, trajs = diffusion_sampler(self.denoiser, xT, condition, uncondition, return_x_trajs=True) + + def decode_images(samples): + samples = self.vae.decode(samples) + samples = fp2uint8(samples) + samples = samples.permute(0, 2, 3, 1).cpu().numpy() + images = [] + for i in range(len(samples)): + image = Image.fromarray(samples[i]) + images.append(image) + return images + + def decode_trajs(trajs): + cat_trajs = torch.stack(trajs, dim=0).permute(1, 0, 2, 3, 4) + animations = [] + for i in range(cat_trajs.shape[0]): + frames = decode_images( + cat_trajs[i] + ) + # Generate a unique file name (combining the seed and sample index to avoid collisions) + gif_filename = f"{random.randint(0, 100000)}.gif" + gif_path = os.path.join(self.tmp_dir.name, gif_filename) + frames[0].save( + gif_path, + format="GIF", + append_images=frames[1:], + save_all=True, + duration=200, + loop=0 + ) + animations.append(gif_path) + return animations + + images = decode_images(samples) + animations = decode_trajs(trajs) + + return images, animations + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, default="configs_t2i/inference_heavydecoder.yaml") + parser.add_argument("--resolution", type=int, default=512) + parser.add_argument("--model_id", type=str, default="MCG-NJU/PixNerd-XXL-P16-T2I") + parser.add_argument("--ckpt_path", type=str, default="models") + + args = parser.parse_args() + if not os.path.exists(args.ckpt_path): + snapshot_download(repo_id=args.model_id, local_dir=args.ckpt_path) + ckpt_path = os.path.join(args.ckpt_path, "model.ckpt") + else: + ckpt_path = args.ckpt_path + + config = OmegaConf.load(args.config) + vae_config = config.model.vae + denoiser_config = config.model.denoiser + conditioner_config = config.model.conditioner + + vae = instantiate_class(vae_config) + denoiser = instantiate_class(denoiser_config) + conditioner = instantiate_class(conditioner_config) + + + ckpt = torch.load(ckpt_path, map_location="cpu") + denoiser = load_model(ckpt, denoiser) + denoiser = denoiser.cuda() + vae = vae.cuda() + denoiser.eval() + + + pipeline = Pipeline(vae, denoiser, 
conditioner, args.resolution) + + with gr.Blocks() as demo: + gr.Markdown(f"config:{args.config}\n\n ckpt_path:{args.ckpt_path}") + with gr.Row(): + with gr.Column(scale=1): + num_steps = gr.Slider(minimum=1, maximum=100, step=1, label="num steps", value=25) + guidance = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, label="CFG", value=4.0) + image_height = gr.Slider(minimum=128, maximum=1024, step=32, label="image height", value=512) + image_width = gr.Slider(minimum=128, maximum=1024, step=32, label="image width", value=512) + num_images = gr.Slider(minimum=1, maximum=4, step=1, label="num images", value=4) + label = gr.Textbox(label="positive prompt", value="a photo of a cat") + seed = gr.Slider(minimum=0, maximum=1000000, step=1, label="seed", value=0) + timeshift = gr.Slider(minimum=0.1, maximum=5.0, step=0.1, label="timeshift", value=3.0) + order = gr.Slider(minimum=1, maximum=4, step=1, label="order", value=2) + with gr.Column(scale=2): + btn = gr.Button("Generate") + output_sample = gr.Gallery(label="Images", columns=2, rows=2) + with gr.Column(scale=2): + output_trajs = gr.Gallery(label="Trajs of Diffusion", columns=2, rows=2) + + btn.click(fn=pipeline, + inputs=[ + label, + num_images, + seed, + image_height, + image_width, + num_steps, + guidance, + timeshift, + order + ], outputs=[output_sample, output_trajs]) + demo.launch(server_name="0.0.0.0", server_port=7861) \ No newline at end of file diff --git a/configs_t2i/inference_heavydecoder.yaml b/configs_t2i/inference_heavydecoder.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f553cfab304d125c91f32e282cd2d7cfa539c843 --- /dev/null +++ b/configs_t2i/inference_heavydecoder.yaml @@ -0,0 +1,133 @@ +# lightning.pytorch==2.4.0 +seed_everything: true +tags: + exp: &exp sftpix512ema +torch_hub_dir: /mnt/bn/wangshuai6/torch_hub +huggingface_cache_dir: null +trainer: + default_root_dir: /mnt/bn/wangshuai6/universal_pix_t2i_pixnerd_workdirs + accelerator: auto + strategy: ddp + devices: auto + num_nodes: 1 + precision: bf16-true + logger: + class_path: lightning.pytorch.loggers.WandbLogger + init_args: + project: universal_pix_t2i_pixnerd_workdirs + name: *exp + num_sanity_val_steps: 2 + max_steps: 60000 + gradient_clip_val: 1.0 + gradient_clip_algorithm: norm + val_check_interval: 10000 + check_val_every_n_epoch: null + log_every_n_steps: 50 + deterministic: null + inference_mode: true + use_distributed_sampler: false + callbacks: + - class_path: src.callbacks.model_checkpoint.CheckpointHook + init_args: + every_n_train_steps: 20000 + save_top_k: -1 + save_last: true + - class_path: src.callbacks.save_images.SaveImagesHook + init_args: + save_dir: val + plugins: + - src.plugins.bd_env.BDEnvironment +model: + vae: + class_path: src.models.autoencoder.pixel.PixelAE + denoiser: + class_path: src.models.transformer.pixnerd_heavydecoder.PixNerDiT + init_args: + in_channels: 3 + patch_size: 16 + num_groups: 24 + hidden_size: &hidden_dim 1536 + txt_embed_dim: &txt_embed_dim 2048 + txt_max_length: 128 + num_text_blocks: 4 + decoder_hidden_size: 64 + num_encoder_blocks: 16 + num_decoder_blocks: 2 + conditioner: + class_path: src.models.conditioner.qwen3_text_encoder.Qwen3TextEncoder + init_args: + weight_path: Qwen/Qwen3-1.7B + embed_dim: *txt_embed_dim + max_length: 128 + diffusion_trainer: + class_path: src.diffusion.flow_matching.training_repa.REPATrainer + init_args: + lognorm_t: true + feat_loss_weight: 0.5 + timeshift: 4.0 + encoder: + class_path: src.models.encoder.DINOv2 + init_args: + weight_path: 
/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main/dinov2_vitb14 + align_layer: 6 + proj_denoiser_dim: *hidden_dim + proj_hidden_dim: *hidden_dim + proj_encoder_dim: 768 + scheduler: &scheduler src.diffusion.flow_matching.scheduling.LinearScheduler + diffusion_sampler: + class_path: src.diffusion.flow_matching.adam_sampling.AdamLMSampler + init_args: + num_steps: 25 + guidance: 4.0 + timeshift: 3.0 + order: 2 + scheduler: *scheduler + guidance_fn: src.diffusion.base.guidance.simple_guidance_fn + ema_tracker: + class_path: src.callbacks.simple_ema.SimpleEMA + init_args: + decay: 0.9999 + optimizer: + class_path: torch.optim.AdamW + init_args: + lr: 1e-5 + betas: + - 0.9 + - 0.999 + weight_decay: 0.00 +data: + train_dataset: + class_path: src.data.dataset.image_txt.ImageText + init_args: + root: datasets/BLIP3o-60k + resolution: 512 + eval_dataset: + class_path: src.data.dataset.geneval.GenEvalDataset + init_args: + meta_json_path: ./evaluations/geneval/evaluation_metadata.jsonl + num_samples_per_instance: 4 + latent_shape: + - 3 + - 512 + - 512 + pred_dataset: + class_path: src.data.dataset.geneval.GenEvalDataset + init_args: + meta_json_path: ./evaluations/geneval/evaluation_metadata_rephrased.jsonl + num_samples_per_instance: 4 + latent_shape: + - 3 + - 512 + - 512 +# class_path: src.data.dataset.dpg.DPGDataset +# init_args: +# prompt_path: ./evaluations/ELLA/dpg_bench/prompts +# num_samples_per_instance: 4 +# latent_shape: +# - 3 +# - 512 +# - 512 + train_batch_size: 24 + train_num_workers: 4 + pred_batch_size: 8 + pred_num_workers: 1 \ No newline at end of file diff --git a/configs_t2i/pretraining_res256.yaml b/configs_t2i/pretraining_res256.yaml new file mode 100644 index 0000000000000000000000000000000000000000..04e467ba1e88fd815bb7e101617b3c8214fb9233 --- /dev/null +++ b/configs_t2i/pretraining_res256.yaml @@ -0,0 +1,120 @@ +# lightning.pytorch==2.4.0 +seed_everything: true +tags: + exp: &exp pretraining_pix256 +torch_hub_dir: /mnt/bn/wangshuai6/torch_hub +huggingface_cache_dir: null +trainer: + default_root_dir: /mnt/bn/wangshuai6/universal_pix_t2i_pixnerd_workdirs + accelerator: auto + strategy: ddp + devices: auto + num_nodes: 1 + precision: bf16-true + logger: + class_path: lightning.pytorch.loggers.WandbLogger + init_args: + project: universal_pix_t2i_pixnerd_workdirs + name: *exp + num_sanity_val_steps: 2 + max_steps: 200000 + gradient_clip_val: 1.0 + gradient_clip_algorithm: norm + val_check_interval: 10000 + check_val_every_n_epoch: null + log_every_n_steps: 50 + deterministic: null + inference_mode: true + use_distributed_sampler: false + callbacks: + - class_path: src.callbacks.model_checkpoint.CheckpointHook + init_args: + every_n_train_steps: 20000 + save_top_k: -1 + save_last: true + - class_path: src.callbacks.save_images.SaveImagesHook + init_args: + save_dir: val + plugins: + - src.plugins.bd_env.BDEnvironment +model: + vae: + class_path: src.models.autoencoder.pixel.PixelAE + denoiser: + class_path: src.models.transformer.pixnerd.PixNerDiT + init_args: + in_channels: 3 + patch_size: 16 + num_groups: 24 + hidden_size: &hidden_dim 1536 + txt_embed_dim: &txt_embed_dim 2048 + txt_max_length: 128 + num_text_blocks: 4 + decoder_hidden_size: 64 + num_encoder_blocks: 16 + num_decoder_blocks: 2 + conditioner: + class_path: src.models.conditioner.qwen3_text_encoder.Qwen3TextEncoder + init_args: + weight_path: Qwen/Qwen3-1.7B + embed_dim: *txt_embed_dim + max_length: 128 + diffusion_trainer: + class_path: src.diffusion.flow_matching.training_repa.REPATrainer + init_args: 
+ lognorm_t: true + feat_loss_weight: 0.5 + timeshift: 2.0 + encoder: + class_path: src.models.encoder.DINOv2 + init_args: + weight_path: /mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main/dinov2_vitb14 + align_layer: 6 + proj_denoiser_dim: *hidden_dim + proj_hidden_dim: *hidden_dim + proj_encoder_dim: 768 + scheduler: &scheduler src.diffusion.flow_matching.scheduling.LinearScheduler + diffusion_sampler: + class_path: src.diffusion.flow_matching.sampling.EulerSampler + init_args: + num_steps: 100 + guidance: 4.0 + scheduler: *scheduler + w_scheduler: src.diffusion.flow_matching.scheduling.LinearScheduler + guidance_fn: src.diffusion.base.guidance.simple_guidance_fn + step_fn: src.diffusion.flow_matching.sampling.ode_step_fn + ema_tracker: + class_path: src.callbacks.simple_ema.SimpleEMA + init_args: + decay: 0.9999 + optimizer: + class_path: torch.optim.AdamW + init_args: + lr: 2e-4 + betas: + - 0.9 + - 0.95 + weight_decay: 0.00 +data: + eval_dataset: + class_path: src.data.dataset.geneval.GenEvalDataset + init_args: + meta_json_path: ./evaluations/geneval/evaluation_metadata.jsonl + num_samples_per_instance: 4 + latent_shape: + - 3 + - 256 + - 256 + pred_dataset: + class_path: src.data.dataset.geneval.GenEvalDataset + init_args: + meta_json_path: ./evaluations/geneval/evaluation_metadata_rephrased.jsonl + num_samples_per_instance: 4 + latent_shape: + - 3 + - 256 + - 256 + train_batch_size: 96 + train_num_workers: 4 + pred_batch_size: 8 + pred_num_workers: 1 \ No newline at end of file diff --git a/configs_t2i/pretraining_res512.yaml b/configs_t2i/pretraining_res512.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6ae8b94bf4ee67c62aadd75c0e79cbd4815cf70e --- /dev/null +++ b/configs_t2i/pretraining_res512.yaml @@ -0,0 +1,120 @@ +# lightning.pytorch==2.4.0 +seed_everything: true +tags: + exp: &exp pretraining_pix512 +torch_hub_dir: /mnt/bn/wangshuai6/torch_hub +huggingface_cache_dir: null +trainer: + default_root_dir: /mnt/bn/wangshuai6/universal_pix_t2i_pixnerd_workdirs + accelerator: auto + strategy: ddp + devices: auto + num_nodes: 1 + precision: bf16-true + logger: + class_path: lightning.pytorch.loggers.WandbLogger + init_args: + project: universal_pix_t2i_pixnerd_workdirs + name: *exp + num_sanity_val_steps: 2 + max_steps: 400000 + gradient_clip_val: 1.0 + gradient_clip_algorithm: norm + val_check_interval: 10000 + check_val_every_n_epoch: null + log_every_n_steps: 50 + deterministic: null + inference_mode: true + use_distributed_sampler: false + callbacks: + - class_path: src.callbacks.model_checkpoint.CheckpointHook + init_args: + every_n_train_steps: 20000 + save_top_k: -1 + save_last: true + - class_path: src.callbacks.save_images.SaveImagesHook + init_args: + save_dir: val + plugins: + - src.plugins.bd_env.BDEnvironment +model: + vae: + class_path: src.models.autoencoder.pixel.PixelAE + denoiser: + class_path: src.models.transformer.pixnerd.PixNerDiT + init_args: + in_channels: 3 + patch_size: 16 + num_groups: 24 + hidden_size: &hidden_dim 1536 + txt_embed_dim: &txt_embed_dim 2048 + txt_max_length: 128 + num_text_blocks: 4 + decoder_hidden_size: 64 + num_encoder_blocks: 16 + num_decoder_blocks: 2 + conditioner: + class_path: src.models.conditioner.qwen3_text_encoder.Qwen3TextEncoder + init_args: + weight_path: Qwen/Qwen3-1.7B + embed_dim: *txt_embed_dim + max_length: 128 + diffusion_trainer: + class_path: src.diffusion.flow_matching.training_repa.REPATrainer + init_args: + lognorm_t: true + feat_loss_weight: 0.5 + timeshift: 4.0 + encoder: + 
class_path: src.models.encoder.DINOv2 + init_args: + weight_path: /mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main/dinov2_vitb14 + align_layer: 6 + proj_denoiser_dim: *hidden_dim + proj_hidden_dim: *hidden_dim + proj_encoder_dim: 768 + scheduler: &scheduler src.diffusion.flow_matching.scheduling.LinearScheduler + diffusion_sampler: + class_path: src.diffusion.flow_matching.sampling.EulerSampler + init_args: + num_steps: 100 + guidance: 4.0 + scheduler: *scheduler + w_scheduler: src.diffusion.flow_matching.scheduling.LinearScheduler + guidance_fn: src.diffusion.base.guidance.simple_guidance_fn + step_fn: src.diffusion.flow_matching.sampling.ode_step_fn + ema_tracker: + class_path: src.callbacks.simple_ema.SimpleEMA + init_args: + decay: 0.9999 + optimizer: + class_path: torch.optim.AdamW + init_args: + lr: 1e-4 + betas: + - 0.9 + - 0.95 + weight_decay: 0.00 +data: + eval_dataset: + class_path: src.data.dataset.geneval.GenEvalDataset + init_args: + meta_json_path: ./evaluations/geneval/evaluation_metadata.jsonl + num_samples_per_instance: 4 + latent_shape: + - 3 + - 512 + - 512 + pred_dataset: + class_path: src.data.dataset.geneval.GenEvalDataset + init_args: + meta_json_path: ./evaluations/geneval/evaluation_metadata_rephrased.jsonl + num_samples_per_instance: 4 + latent_shape: + - 3 + - 512 + - 512 + train_batch_size: 24 + train_num_workers: 4 + pred_batch_size: 8 + pred_num_workers: 1 \ No newline at end of file diff --git a/configs_t2i/sft_res512.yaml b/configs_t2i/sft_res512.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d11f2b2451ce3874f7fa53ec20fa07d275cdec9d --- /dev/null +++ b/configs_t2i/sft_res512.yaml @@ -0,0 +1,133 @@ +# lightning.pytorch==2.4.0 +seed_everything: true +tags: + exp: &exp sftpix512ema +torch_hub_dir: /mnt/bn/wangshuai6/torch_hub +huggingface_cache_dir: null +trainer: + default_root_dir: /mnt/bn/wangshuai6/universal_pix_t2i_pixnerd_workdirs + accelerator: auto + strategy: ddp + devices: auto + num_nodes: 1 + precision: bf16-true + logger: + class_path: lightning.pytorch.loggers.WandbLogger + init_args: + project: universal_pix_t2i_pixnerd_workdirs + name: *exp + num_sanity_val_steps: 2 + max_steps: 60000 + gradient_clip_val: 1.0 + gradient_clip_algorithm: norm + val_check_interval: 10000 + check_val_every_n_epoch: null + log_every_n_steps: 50 + deterministic: null + inference_mode: true + use_distributed_sampler: false + callbacks: + - class_path: src.callbacks.model_checkpoint.CheckpointHook + init_args: + every_n_train_steps: 20000 + save_top_k: -1 + save_last: true + - class_path: src.callbacks.save_images.SaveImagesHook + init_args: + save_dir: val + plugins: + - src.plugins.bd_env.BDEnvironment +model: + vae: + class_path: src.models.autoencoder.pixel.PixelAE + denoiser: + class_path: src.models.transformer.pixnerd.PixNerDiT + init_args: + in_channels: 3 + patch_size: 16 + num_groups: 24 + hidden_size: &hidden_dim 1536 + txt_embed_dim: &txt_embed_dim 2048 + txt_max_length: 128 + num_text_blocks: 4 + decoder_hidden_size: 64 + num_encoder_blocks: 16 + num_decoder_blocks: 2 + conditioner: + class_path: src.models.conditioner.qwen3_text_encoder.Qwen3TextEncoder + init_args: + weight_path: Qwen/Qwen3-1.7B + embed_dim: *txt_embed_dim + max_length: 128 + diffusion_trainer: + class_path: src.diffusion.flow_matching.training_repa.REPATrainer + init_args: + lognorm_t: true + feat_loss_weight: 0.5 + timeshift: 4.0 + encoder: + class_path: src.models.encoder.DINOv2 + init_args: + weight_path: 
/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main/dinov2_vitb14 + align_layer: 6 + proj_denoiser_dim: *hidden_dim + proj_hidden_dim: *hidden_dim + proj_encoder_dim: 768 + scheduler: &scheduler src.diffusion.flow_matching.scheduling.LinearScheduler + diffusion_sampler: + class_path: src.diffusion.flow_matching.adam_sampling.AdamLMSampler + init_args: + num_steps: 25 + guidance: 4.0 + timeshift: 3.0 + order: 2 + scheduler: *scheduler + guidance_fn: src.diffusion.base.guidance.simple_guidance_fn + ema_tracker: + class_path: src.callbacks.simple_ema.SimpleEMA + init_args: + decay: 0.9999 + optimizer: + class_path: torch.optim.AdamW + init_args: + lr: 1e-5 + betas: + - 0.9 + - 0.999 + weight_decay: 0.00 +data: +# train_dataset: +# class_path: src.data.dataset.image_txt.ImageText +# init_args: +# root: /mnt/bn/wangshuai7/datasets/BLIP3o-60k +# resolution: 512 + eval_dataset: + class_path: src.data.dataset.geneval.GenEvalDataset + init_args: + meta_json_path: ./evaluations/geneval/evaluation_metadata.jsonl + num_samples_per_instance: 4 + latent_shape: + - 3 + - 512 + - 512 + pred_dataset: + class_path: src.data.dataset.geneval.GenEvalDataset + init_args: + meta_json_path: ./evaluations/geneval/evaluation_metadata_rephrased.jsonl + num_samples_per_instance: 4 + latent_shape: + - 3 + - 512 + - 512 +# class_path: src.data.dataset.dpg.DPGDataset +# init_args: +# prompt_path: ./evaluations/dpg_bench/prompts +# num_samples_per_instance: 4 +# latent_shape: +# - 3 +# - 512 +# - 512 + train_batch_size: 24 + train_num_workers: 4 + pred_batch_size: 8 + pred_num_workers: 1 \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000000000000000000000000000000000000..da82e72df5c0da373e7a6432f4c721531b2a81af --- /dev/null +++ b/main.py @@ -0,0 +1,84 @@ +import os +import torch +import time +from typing import Any, Union + +from src.utils.patch_bugs import * +from lightning import Trainer, LightningModule + +from src.lightning_data import DataModule +from src.lightning_model import LightningModel +from lightning.pytorch.cli import LightningCLI, LightningArgumentParser, SaveConfigCallback + +import logging +logger = logging.getLogger("lightning.pytorch") +# log_path = os.path.join( f"log.txt") +# logger.addHandler(logging.FileHandler(log_path)) + +class ReWriteRootSaveConfigCallback(SaveConfigCallback): + def save_config(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None: + stamp = time.strftime('%y%m%d%H%M') + file_path = os.path.join(trainer.default_root_dir, f"config-{stage}-{stamp}.yaml") + self.parser.save( + self.config, file_path, skip_none=False, overwrite=self.overwrite, multifile=self.multifile + ) + + +class ReWriteRootDirCli(LightningCLI): + def before_instantiate_classes(self) -> None: + super().before_instantiate_classes() + config_trainer = self._get(self.config, "trainer", default={}) + + # predict path & logger check + if self.subcommand == "predict": + config_trainer.logger = None + + def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + class TagsClass: + def __init__(self, exp:str): + ... 
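# ---- [Editor's note: illustrative sketch, not part of the patch] ----
# TagsClass above only declares the schema for the `tags:` block of the YAML configs (e.g.
# `tags: exp: sftpix512ema`). Those tags are later reused by instantiate_trainer() to name the run
# directory; roughly, and with made-up values purely for illustration:
import os
tags = {"exp": "sftpix512ema"}                  # parsed from the config's `tags` section
dirname = "".join(f"{key}_{value}" for key, value in tags.items())
default_root_dir = os.path.join(os.getcwd(), "workdirs", dirname)   # .../workdirs/exp_sftpix512ema
# ----------------------------------------------------------------------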
+ parser.add_class_arguments(TagsClass, nested_key="tags") + + def add_default_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + super().add_default_arguments_to_parser(parser) + parser.add_argument("--torch_hub_dir", type=str, default=None, help=("torch hub dir"),) + parser.add_argument("--huggingface_cache_dir", type=str, default=None, help=("huggingface hub dir"),) + + def instantiate_trainer(self, **kwargs: Any) -> Trainer: + config_trainer = self._get(self.config_init, "trainer", default={}) + default_root_dir = config_trainer.get("default_root_dir", None) + + if default_root_dir is None: + default_root_dir = os.path.join(os.getcwd(), "workdirs") + + dirname = "" + for v, k in self._get(self.config, "tags", default={}).items(): + dirname += f"{v}_{k}" + default_root_dir = os.path.join(default_root_dir, dirname) + is_resume = self._get(self.config_init, "ckpt_path", default=None) + if os.path.exists(default_root_dir) and "debug" not in default_root_dir: + if os.listdir(default_root_dir) and self.subcommand != "predict" and not is_resume: + raise FileExistsError(f"{default_root_dir} already exists") + + config_trainer.default_root_dir = default_root_dir + trainer = super().instantiate_trainer(**kwargs) + if trainer.is_global_zero: + os.makedirs(default_root_dir, exist_ok=True) + return trainer + + def instantiate_classes(self) -> None: + torch_hub_dir = self._get(self.config, "torch_hub_dir") + huggingface_cache_dir = self._get(self.config, "huggingface_cache_dir") + if huggingface_cache_dir is not None: + os.environ["HUGGINGFACE_HUB_CACHE"] = huggingface_cache_dir + if torch_hub_dir is not None: + os.environ["TORCH_HOME"] = torch_hub_dir + torch.hub.set_dir(torch_hub_dir) + super().instantiate_classes() + +if __name__ == "__main__": + + cli = ReWriteRootDirCli(LightningModel, DataModule, + auto_configure_optimizers=False, + save_config_callback=ReWriteRootSaveConfigCallback, + save_config_kwargs={"overwrite": True}) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..973222f61b1ed2e9977c3d464046403d048f1c27 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,11 @@ +lightning==2.5.0.post0 +omegaconf==2.3.0 +torch==2.5.0 +diffusers==0.30.0 +jsonargparse[signatures]>=4.27.7 +torchvision +timm +accelerate +gradio + + diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/callbacks/__init__.py b/src/callbacks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/callbacks/grad.py b/src/callbacks/grad.py new file mode 100644 index 0000000000000000000000000000000000000000..f9155b68152ac8b4ce2f30e31501ec892175b167 --- /dev/null +++ b/src/callbacks/grad.py @@ -0,0 +1,22 @@ +import torch +import lightning.pytorch as pl +from lightning.pytorch.utilities import grad_norm +from torch.optim import Optimizer + +class GradientMonitor(pl.Callback): + """Logs the gradient norm""" + + def __init__(self, norm_type: int = 2): + norm_type = float(norm_type) + if norm_type <= 0: + raise ValueError(f"`norm_type` must be a positive number or 'inf' (infinity norm). 
Got {norm_type}") + self.norm_type = norm_type + + def on_before_optimizer_step( + self, trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + optimizer: Optimizer + ) -> None: + norms = grad_norm(pl_module, norm_type=self.norm_type) + max_grad = torch.tensor([v for k, v in norms.items() if k != f"grad_{self.norm_type}_norm_total"]).max() + pl_module.log_dict({'train/grad/max': max_grad, 'train/grad/total': norms[f"grad_{self.norm_type}_norm_total"]}) \ No newline at end of file diff --git a/src/callbacks/model_checkpoint.py b/src/callbacks/model_checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..ace6946e90c616913b498e5ede1663ad883bf1e7 --- /dev/null +++ b/src/callbacks/model_checkpoint.py @@ -0,0 +1,24 @@ +import os.path +from typing import Optional, Dict, Any + +import lightning.pytorch as pl +from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint + + +class CheckpointHook(ModelCheckpoint): + """Save checkpoint with only the incremental part of the model""" + def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None: + self.dirpath = trainer.default_root_dir + self.exception_ckpt_path = os.path.join(self.dirpath, "on_exception.pt") + pl_module.strict_loading = False + + def on_save_checkpoint( + self, trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + checkpoint: Dict[str, Any] + ) -> None: + del checkpoint["callbacks"] + + # def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: BaseException) -> None: + # if not "debug" in self.exception_ckpt_path: + # trainer.save_checkpoint(self.exception_ckpt_path) \ No newline at end of file diff --git a/src/callbacks/save_images.py b/src/callbacks/save_images.py new file mode 100644 index 0000000000000000000000000000000000000000..69f75786e010c8cc7cd8b8d9bbfd63fb0b45e448 --- /dev/null +++ b/src/callbacks/save_images.py @@ -0,0 +1,102 @@ +import lightning.pytorch as pl +from lightning.pytorch import Callback + + +import os.path +import numpy +from typing import Sequence, Any, Dict +from concurrent.futures import ThreadPoolExecutor + +from lightning.pytorch.utilities.types import STEP_OUTPUT +from lightning_utilities.core.rank_zero import rank_zero_info + + + +class SaveImagesHook(Callback): + def __init__(self, save_dir="val", save_compressed=False): + self.save_dir = save_dir + self.save_compressed = save_compressed + + def save_start(self, target_dir): + self.samples = [] + self.target_dir = target_dir + self.executor_pool = ThreadPoolExecutor(max_workers=8) + if not os.path.exists(self.target_dir): + os.makedirs(self.target_dir, exist_ok=True) + else: + if os.listdir(target_dir) and "debug" not in str(target_dir): + raise FileExistsError(f'{self.target_dir} already exists and not empty!') + rank_zero_info(f"Save images to {self.target_dir}") + self._saved_num = 0 + + def save_image(self, trainer, pl_module, images, metadatas,): + images = images.permute(0, 2, 3, 1).cpu().numpy() + for sample, metadata in zip(images, metadatas): + save_fn = metadata.pop("save_fn", None) + self.executor_pool.submit(save_fn, sample, metadata, self.target_dir) + + def process_batch( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + samples: STEP_OUTPUT, + batch: Any, + ) -> None: + xT, y, metadata = batch + b, c, h, w = samples.shape + if not self.save_compressed or self._saved_num < 10: + self._saved_num += b + self.save_image(trainer, pl_module, samples, metadata) + + all_samples = 
pl_module.all_gather(samples).view(-1, c, h, w) + if trainer.is_global_zero: + all_samples = all_samples.permute(0, 2, 3, 1).cpu().numpy() + self.samples.append(all_samples) + + def save_end(self): + if self.save_compressed and len(self.samples) > 0: + samples = numpy.concatenate(self.samples) + numpy.savez(f'{self.target_dir}/output.npz', arr_0=samples) + self.executor_pool.shutdown(wait=True) + self.target_dir = None + self.executor_pool = None + self.samples = [] + + def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + target_dir = os.path.join(trainer.default_root_dir, self.save_dir, f"iter_{trainer.global_step}") + self.save_start(target_dir) + + def on_validation_batch_end( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + outputs: STEP_OUTPUT, + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + return self.process_batch(trainer, pl_module, outputs, batch) + + def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + self.save_end() + + def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + target_dir = os.path.join(trainer.default_root_dir, self.save_dir, "predict") + self.save_start(target_dir) + + def on_predict_batch_end( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + samples: Any, + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + return self.process_batch(trainer, pl_module, samples, batch) + + def on_predict_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + self.save_end() + + def state_dict(self) -> Dict[str, Any]: + return dict() \ No newline at end of file diff --git a/src/callbacks/simple_ema.py b/src/callbacks/simple_ema.py new file mode 100644 index 0000000000000000000000000000000000000000..12a9a98f751455a9d6cb17b845be102f5da0f63e --- /dev/null +++ b/src/callbacks/simple_ema.py @@ -0,0 +1,60 @@ +from typing import Any, Dict + +import torch +import torch.nn as nn +import threading +import lightning.pytorch as pl +from lightning.pytorch import Callback +from lightning.pytorch.utilities.types import STEP_OUTPUT + +from src.utils.copy import swap_tensors + +class SimpleEMA(Callback): + def __init__(self, + decay: float = 0.9999, + every_n_steps: int = 1, + ): + super().__init__() + self.decay = decay + self.every_n_steps = every_n_steps + self._stream = torch.cuda.Stream() + self.previous_step = 0 + + def setup_models(self, net: nn.Module, ema_net: nn.Module): + self.net_params = list(net.parameters()) + self.ema_params = list(ema_net.parameters()) + + def ema_step(self): + @torch.no_grad() + def ema_update(ema_model_tuple, current_model_tuple, decay): + torch._foreach_mul_(ema_model_tuple, decay) + torch._foreach_add_( + ema_model_tuple, current_model_tuple, alpha=(1.0 - decay), + ) + + if self._stream is not None: + self._stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(self._stream): + ema_update(self.ema_params, self.net_params, self.decay) + assert self.ema_params[0].dtype == torch.float32 + + def on_train_batch_end( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int + ) -> None: + if trainer.global_step == self.previous_step: + return + self.previous_step = trainer.global_step + if trainer.global_step % self.every_n_steps == 0: + self.ema_step() + + + def state_dict(self) -> Dict[str, Any]: + return { + "decay": self.decay, + 
"every_n_steps": self.every_n_steps, + } + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + self.decay = state_dict["decay"] + self.every_n_steps = state_dict["every_n_steps"] + diff --git a/src/data/__init__.py b/src/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/src/data/__init__.py @@ -0,0 +1 @@ + diff --git a/src/data/dataset/__init__.py b/src/data/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/data/dataset/dpg.py b/src/data/dataset/dpg.py new file mode 100644 index 0000000000000000000000000000000000000000..aa3ac07ad689b0c2f7c21f580becd86d527ab346 --- /dev/null +++ b/src/data/dataset/dpg.py @@ -0,0 +1,42 @@ +import torch +import json +import copy +import os +from torch.utils.data import Dataset +from PIL import Image + +def dpg_save_fn(image, metadata, root_path): + image_path = os.path.join(root_path, str(metadata['filename'])+"_"+str(metadata['seed'])+".png") + Image.fromarray(image).save(image_path) + +class DPGDataset(Dataset): + def __init__(self, prompt_path, num_samples_per_instance, latent_shape): + self.latent_shape = latent_shape + self.prompt_path = prompt_path + prompt_files = os.listdir(self.prompt_path) + self.prompts = [] + self.filenames = [] + for prompt_file in prompt_files: + with open(os.path.join(self.prompt_path, prompt_file)) as fp: + self.prompts.append(fp.readline().strip()) + self.filenames.append(prompt_file.replace('.txt', '')) + self.num_instances = len(self.prompts) + self.num_samples_per_instance = num_samples_per_instance + self.num_samples = self.num_instances * self.num_samples_per_instance + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + instance_idx = idx // self.num_samples_per_instance + sample_idx = idx % self.num_samples_per_instance + generator = torch.Generator().manual_seed(sample_idx) + metadata = dict( + prompt=self.prompts[instance_idx], + filename=self.filenames[instance_idx], + seed=sample_idx, + save_fn=dpg_save_fn, + ) + condition = metadata["prompt"] + latent = torch.randn(self.latent_shape, generator=generator, dtype=torch.float32) + return latent, condition, metadata \ No newline at end of file diff --git a/src/data/dataset/geneval.py b/src/data/dataset/geneval.py new file mode 100644 index 0000000000000000000000000000000000000000..d21ac76b84e440c67d4b91e28cb8a3b5ded633e5 --- /dev/null +++ b/src/data/dataset/geneval.py @@ -0,0 +1,46 @@ +import torch +import json +import copy +from torch.utils.data import Dataset +import os +from PIL import Image + +def geneval_save_fn(image, metadata, root_path): + path = os.path.join(root_path, metadata['filename']) + if not os.path.exists(path): + os.makedirs(path, exist_ok=True) + # save image + image_path = os.path.join(path, "samples", f"{metadata['seed']}.png") + if not os.path.exists(os.path.dirname(image_path)): + os.makedirs(os.path.dirname(image_path), exist_ok=True) + Image.fromarray(image).save(image_path) + # metadata_path + metadata_path = os.path.join(path, "metadata.jsonl") + with open(metadata_path, "w") as fp: + json.dump(metadata, fp) + +class GenEvalDataset(Dataset): + def __init__(self, meta_json_path, num_samples_per_instance, latent_shape): + self.latent_shape = latent_shape + self.meta_json_path = meta_json_path + with open(meta_json_path) as fp: + self.metadatas = [json.loads(line) for line in fp] + self.num_instances = 
len(self.metadatas) + self.num_samples_per_instance = num_samples_per_instance + self.num_samples = self.num_instances * self.num_samples_per_instance + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + instance_idx = idx // self.num_samples_per_instance + sample_idx = idx % self.num_samples_per_instance + metadata = copy.deepcopy(self.metadatas[instance_idx]) + generator = torch.Generator().manual_seed(sample_idx) + condition = metadata["prompt"] + latent = torch.randn(self.latent_shape, generator=generator, dtype=torch.float32) + filename = f"{idx}" + metadata["seed"] = sample_idx + metadata["filename"] = filename + metadata["save_fn"] = geneval_save_fn + return latent, condition, metadata \ No newline at end of file diff --git a/src/data/dataset/image_txt.py b/src/data/dataset/image_txt.py new file mode 100644 index 0000000000000000000000000000000000000000..b785106e61acb618369f74258360662c34ff8cde --- /dev/null +++ b/src/data/dataset/image_txt.py @@ -0,0 +1,54 @@ +import torch +import os + +from torch.utils.data import Dataset +from torchvision.transforms import CenterCrop, Normalize, Resize +from torchvision.transforms.functional import to_tensor +from PIL import Image + +EXTs = ['.png', '.jpg', '.jpeg', ".JPEG"] + + +def is_image_file(filename): + return any(filename.endswith(ext) for ext in EXTs) + +class ImageText(Dataset): + def __init__(self, root, resolution): + super().__init__() + self.image_paths = [] + self.texts = [] + for dir, subdirs, files in os.walk(root): + for file in files: + if is_image_file(file): + image_path = os.path.join(dir, file) + image_base_path = image_path.split(".")[:-1] + text_path = ".".join(image_base_path) + ".txt" + if os.path.exists(text_path): + with open(text_path, 'r') as f: + text = f.read() + self.texts.append(text) + self.image_paths.append(image_path) + + self.resize = Resize(resolution) + self.center_crop = CenterCrop(resolution) + self.normalize = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + + def __getitem__(self, idx: int): + image_path = self.image_paths[idx] + text = self.texts[idx] + pil_image = Image.open(image_path).convert('RGB') + pil_image = self.resize(pil_image) + pil_image = self.center_crop(pil_image) + raw_image = to_tensor(pil_image) + normalized_image = self.normalize(raw_image) + metadata = { + "image_path": image_path, + "prompt": text, + "raw_image": raw_image, + } + return normalized_image, text, metadata + + def __len__(self): + return len(self.image_paths) + + diff --git a/src/data/dataset/imagenet.py b/src/data/dataset/imagenet.py new file mode 100644 index 0000000000000000000000000000000000000000..774715a19ad8656b2fb17eff90b228bacb2cc69b --- /dev/null +++ b/src/data/dataset/imagenet.py @@ -0,0 +1,77 @@ +import torch +import torchvision.transforms +from PIL import Image +from torchvision.datasets import ImageFolder +from torchvision.transforms.functional import to_tensor +from torchvision.transforms import Normalize +from functools import partial + +def center_crop_fn(image, height, width): + crop_x = (image.width - width) // 2 + crop_y = (image.height - height) // 2 + return image.crop((crop_x, crop_y, crop_x + width, crop_y + height)) + + +class LocalCachedDataset(ImageFolder): + def __init__(self, root, resolution=256, cache_root=None): + super().__init__(root) + self.cache_root = cache_root + self.transform = partial(center_crop_fn, height=resolution, width=resolution) + + def load_latent(self, latent_path): + pk_data = torch.load(latent_path) + mean = 
pk_data['mean'].to(torch.float32) + logvar = pk_data['logvar'].to(torch.float32) + logvar = torch.clamp(logvar, -30.0, 20.0) + std = torch.exp(0.5 * logvar) + latent = mean + torch.randn_like(mean) * std + return latent + + def __getitem__(self, idx: int): + image_path, target = self.samples[idx] + latent_path = image_path.replace(self.root, self.cache_root) + ".pt" + + raw_image = Image.open(image_path).convert('RGB') + raw_image = self.transform(raw_image) + raw_image = to_tensor(raw_image) + if self.cache_root is not None: + latent = self.load_latent(latent_path) + else: + latent = raw_image + + metadata = { + "raw_image": raw_image, + "class": target, + } + return latent, target, metadata + + +class PixImageNet(ImageFolder): + def __init__(self, root, resolution=256, random_crop=False, ): + super().__init__(root) + if random_crop: + self.transform = torchvision.transforms.Compose( + [ + torchvision.transforms.Resize(resolution), + torchvision.transforms.RandomCrop(resolution), + torchvision.transforms.RandomHorizontalFlip(), + ] + ) + else: + self.transform = partial(center_crop_fn, height=resolution, width=resolution) + self.normalize = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + + def __getitem__(self, idx: int): + image_path, target = self.samples[idx] + raw_image = Image.open(image_path).convert('RGB') + raw_image = self.transform(raw_image) + raw_image = to_tensor(raw_image) + + normalized_image = self.normalize(raw_image) + + metadata = { + "raw_image": raw_image, + "class": target, + } + return normalized_image, target, metadata \ No newline at end of file diff --git a/src/data/dataset/randn.py b/src/data/dataset/randn.py new file mode 100644 index 0000000000000000000000000000000000000000..bf327b1a0c3b61dc1a6e63f5c049f3dcbeaf69d8 --- /dev/null +++ b/src/data/dataset/randn.py @@ -0,0 +1,91 @@ +import os.path +import random +import re +import unicodedata +import torch +from torch.utils.data import Dataset +from PIL import Image + +from typing import List, Union + +def clean_filename(s): + # Strip leading/trailing whitespace and dots + s = s.strip().strip('.') + # Normalize Unicode characters to their ASCII form + s = unicodedata.normalize('NFKD', s).encode('ASCII', 'ignore').decode('ASCII') + illegal_chars = r'[/]' + reserved_names = set() + # Replace illegal characters with underscores + s = re.sub(illegal_chars, '_', s) + # Collapse consecutive underscores + s = re.sub(r'_{2,}', '_', s) + # Convert to lowercase + s = s.lower() + # Check for reserved file names + if s.upper() in reserved_names: + s = s + '_' + # Limit the file name length + max_length = 200 + s = s[:max_length] + if not s: + return 'untitled' + return s + +def save_fn(image, metadata, root_path): + image_path = os.path.join(root_path, str(metadata['filename'])+".png") + Image.fromarray(image).save(image_path) + +class RandomNDataset(Dataset): + def __init__(self, latent_shape=(4, 64, 64), conditions:Union[int, List, str]=None, seeds=None, max_num_instances=50000, num_samples_per_instance=-1): + if isinstance(conditions, int): + conditions = list(range(conditions)) # class labels + elif isinstance(conditions, str): + if os.path.exists(conditions): + conditions = open(conditions, "r").read().splitlines() + else: + raise FileNotFoundError(conditions) + elif isinstance(conditions, list): + conditions = conditions + self.conditions = conditions + self.num_conditons = len(conditions) + self.seeds = seeds + + if num_samples_per_instance > 0: + max_num_instances = num_samples_per_instance*self.num_conditons + else: + max_num_instances = max_num_instances + + if seeds is not None: + self.max_num_instances = len(seeds)*self.num_conditons + self.num_seeds = len(seeds) + else: 
+ self.num_seeds = (max_num_instances + self.num_conditons - 1) // self.num_conditons + self.max_num_instances = self.num_seeds*self.num_conditons + self.latent_shape = latent_shape + + def __getitem__(self, idx): + condition = self.conditions[idx//self.num_seeds] + + seed = random.randint(0, 1<<31) #idx % self.num_seeds + if self.seeds is not None: + seed = self.seeds[idx % self.num_seeds] + + filename = f"{clean_filename(str(condition))}_{seed}" + generator = torch.Generator().manual_seed(seed) + latent = torch.randn(self.latent_shape, generator=generator, dtype=torch.float32) + + metadata = dict( + filename=filename, + seed=seed, + condition=condition, + save_fn=save_fn, + ) + return latent, condition, metadata + def __len__(self): + return self.max_num_instances + +class ClassLabelRandomNDataset(RandomNDataset): + def __init__(self, latent_shape=(4, 64, 64), num_classes=1000, conditions:Union[int, List, str]=None, seeds=None, max_num_instances=50000, num_samples_per_instance=-1): + if conditions is None: + conditions = list(range(num_classes)) + super().__init__(latent_shape, conditions, seeds, max_num_instances, num_samples_per_instance) diff --git a/src/diffusion/__init__.py b/src/diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/diffusion/base/guidance.py b/src/diffusion/base/guidance.py new file mode 100644 index 0000000000000000000000000000000000000000..9eea893da7a7477d735db61fbd8425d7ba695f3f --- /dev/null +++ b/src/diffusion/base/guidance.py @@ -0,0 +1,13 @@ +import torch + +def simple_guidance_fn(out, cfg): + uncondition, condtion = out.chunk(2, dim=0) + out = uncondition + cfg * (condtion - uncondition) + return out + +def c3_guidance_fn(out, cfg): + # guidance function in DiT/SiT, seems like a bug not a feature? 
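# [Editor's note] To spell out the behaviour implemented just below: classifier-free guidance is
# applied only to the first three output channels, while the remaining channels are passed through
# from the conditional branch unchanged, matching how the original DiT/SiT forward_with_cfg treats
# the epsilon channels separately from the learned-variance channels.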
+ uncondition, condtion = out.chunk(2, dim=0) + out = condtion + out[:, :3] = uncondition[:, :3] + cfg * (condtion[:, :3] - uncondition[:, :3]) + return out \ No newline at end of file diff --git a/src/diffusion/base/sampling.py b/src/diffusion/base/sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..b52bff95ef52fa62348437eaba87d68338d1623b --- /dev/null +++ b/src/diffusion/base/sampling.py @@ -0,0 +1,39 @@ +from typing import Union, List + +import torch +import torch.nn as nn +from typing import Callable +from src.diffusion.base.scheduling import BaseScheduler + +class BaseSampler(nn.Module): + def __init__(self, + scheduler: BaseScheduler = None, + guidance_fn: Callable = None, + num_steps: int = 250, + guidance: Union[float, List[float]] = 1.0, + *args, + **kwargs + ): + super(BaseSampler, self).__init__() + self.num_steps = num_steps + self.guidance = guidance + self.guidance_fn = guidance_fn + self.scheduler = scheduler + + + def _impl_sampling(self, net, noise, condition, uncondition): + raise NotImplementedError + + @torch.autocast("cuda", dtype=torch.bfloat16) + def forward(self, net, noise, condition, uncondition, return_x_trajs=False, return_v_trajs=False): + x_trajs, v_trajs = self._impl_sampling(net, noise, condition, uncondition) + if return_x_trajs and return_v_trajs: + return x_trajs[-1], x_trajs, v_trajs + elif return_x_trajs: + return x_trajs[-1], x_trajs + elif return_v_trajs: + return x_trajs[-1], v_trajs + else: + return x_trajs[-1] + + diff --git a/src/diffusion/base/scheduling.py b/src/diffusion/base/scheduling.py new file mode 100644 index 0000000000000000000000000000000000000000..05c7fb18156e2e8aa28121e9ac855ba6ccf698f6 --- /dev/null +++ b/src/diffusion/base/scheduling.py @@ -0,0 +1,32 @@ +import torch +from torch import Tensor + +class BaseScheduler: + def alpha(self, t) -> Tensor: + ... + def sigma(self, t) -> Tensor: + ... + + def dalpha(self, t) -> Tensor: + ... + def dsigma(self, t) -> Tensor: + ... 
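# ---- [Editor's note: illustrative sketch, not part of the patch] ----
# The samplers in this patch run the network once on a doubled batch (torch.cat([x, x], dim=0) with
# cfg_condition = torch.cat([uncondition, condition], dim=0)) and then merge the two halves with a
# guidance function. Minimal arithmetic for simple_guidance_fn defined above:
import torch
out = torch.tensor([[1.0], [3.0]])        # row 0: unconditional half, row 1: conditional half
uncond, cond = out.chunk(2, dim=0)
guided = uncond + 2.0 * (cond - uncond)   # cfg = 2.0 -> tensor([[5.]])
# cfg = 1.0 recovers the conditional prediction, cfg = 0.0 the unconditional one.
# ----------------------------------------------------------------------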
+ + def dalpha_over_alpha(self, t) -> Tensor: + return self.dalpha(t) / self.alpha(t) + + def dsigma_mul_sigma(self, t) -> Tensor: + return self.dsigma(t)*self.sigma(t) + + def drift_coefficient(self, t): + alpha, sigma = self.alpha(t), self.sigma(t) + dalpha, dsigma = self.dalpha(t), self.dsigma(t) + return dalpha/(alpha + 1e-6) + + def diffuse_coefficient(self, t): + alpha, sigma = self.alpha(t), self.sigma(t) + dalpha, dsigma = self.dalpha(t), self.dsigma(t) + return dsigma*sigma - dalpha/(alpha + 1e-6)*sigma**2 + + def w(self, t): + return self.sigma(t) diff --git a/src/diffusion/base/training.py b/src/diffusion/base/training.py new file mode 100644 index 0000000000000000000000000000000000000000..6aa20a6cc3fe3b5239dea656bc6bcf64323997a9 --- /dev/null +++ b/src/diffusion/base/training.py @@ -0,0 +1,29 @@ +import time + +import torch +import torch.nn as nn + + +class BaseTrainer(nn.Module): + def __init__(self, + null_condition_p=0.1, + ): + super(BaseTrainer, self).__init__() + self.null_condition_p = null_condition_p + + def preproprocess(self, x, condition, uncondition, metadata): + bsz = x.shape[0] + if self.null_condition_p > 0: + mask = torch.rand((bsz), device=condition.device) < self.null_condition_p + mask = mask.view(-1, *([1] * (len(condition.shape) - 1))).to(condition.dtype) + condition = condition*(1-mask) + uncondition*mask + return x, condition, metadata + + def _impl_trainstep(self, net, ema_net, solver, x, y, metadata=None): + raise NotImplementedError + + @torch.autocast(device_type='cuda', dtype=torch.bfloat16) + def __call__(self, net, ema_net, solver, x, condition, uncondition, metadata=None): + x, condition, metadata = self.preproprocess(x, condition, uncondition, metadata) + return self._impl_trainstep(net, ema_net, solver, x, condition, metadata) + diff --git a/src/diffusion/ddpm/ddim_sampling.py b/src/diffusion/ddpm/ddim_sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..523f4d87ccc131591ad5f263579695a3db8bda5e --- /dev/null +++ b/src/diffusion/ddpm/ddim_sampling.py @@ -0,0 +1,45 @@ +import torch +from src.diffusion.base.scheduling import * +from src.diffusion.base.sampling import * + +from typing import Callable + +import logging +logger = logging.getLogger(__name__) + +class DDIMSampler(BaseSampler): + def __init__( + self, + train_num_steps=1000, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.train_num_steps = train_num_steps + assert self.scheduler is not None + + def _impl_sampling(self, net, noise, condition, uncondition): + batch_size = noise.shape[0] + steps = torch.linspace(0.0, self.train_num_steps-1, self.num_steps, device=noise.device) + steps = torch.flip(steps, dims=[0]) + cfg_condition = torch.cat([uncondition, condition], dim=0) + x = x0 = noise + x_trajs = [noise, ] + v_trajs = [] + for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])): + t_cur = t_cur.repeat(batch_size) + t_next = t_next.repeat(batch_size) + sigma = self.scheduler.sigma(t_cur) + alpha = self.scheduler.alpha(t_cur) + sigma_next = self.scheduler.sigma(t_next) + alpha_next = self.scheduler.alpha(t_next) + cfg_x = torch.cat([x, x], dim=0) + t = t_cur.repeat(2) + out = net(cfg_x, t, cfg_condition) + out = self.guidance_fn(out, self.guidance) + x0 = (x - sigma * out) / alpha + x = alpha_next * x0 + sigma_next * out + x_trajs.append(x) + v_trajs.append(out) + v_trajs.append(torch.zeros_like(x)) + return x_trajs, v_trajs \ No newline at end of file diff --git a/src/diffusion/ddpm/scheduling.py 
b/src/diffusion/ddpm/scheduling.py new file mode 100644 index 0000000000000000000000000000000000000000..aff1523b768b9ea83fcb5984e2190a100d5d0922 --- /dev/null +++ b/src/diffusion/ddpm/scheduling.py @@ -0,0 +1,102 @@ +import math +import torch +from src.diffusion.base.scheduling import * + + +class DDPMScheduler(BaseScheduler): + def __init__( + self, + beta_min=0.0001, + beta_max=0.02, + num_steps=1000, + ): + super().__init__() + self.beta_min = beta_min + self.beta_max = beta_max + self.num_steps = num_steps + + self.betas_table = torch.linspace(self.beta_min, self.beta_max, self.num_steps, device="cuda") + self.alphas_table = torch.cumprod(1-self.betas_table, dim=0) + self.sigmas_table = 1-self.alphas_table + + + def beta(self, t) -> Tensor: + t = t.to(torch.long) + return self.betas_table[t].view(-1, 1, 1, 1) + + def alpha(self, t) -> Tensor: + t = t.to(torch.long) + return self.alphas_table[t].view(-1, 1, 1, 1)**0.5 + + def sigma(self, t) -> Tensor: + t = t.to(torch.long) + return self.sigmas_table[t].view(-1, 1, 1, 1)**0.5 + + def dsigma(self, t) -> Tensor: + raise NotImplementedError("wrong usage") + + def dalpha_over_alpha(self, t) ->Tensor: + raise NotImplementedError("wrong usage") + + def dsigma_mul_sigma(self, t) ->Tensor: + raise NotImplementedError("wrong usage") + + def dalpha(self, t) -> Tensor: + raise NotImplementedError("wrong usage") + + def drift_coefficient(self, t): + raise NotImplementedError("wrong usage") + + def diffuse_coefficient(self, t): + raise NotImplementedError("wrong usage") + + def w(self, t): + raise NotImplementedError("wrong usage") + + +class VPScheduler(BaseScheduler): + def __init__( + self, + beta_min=0.1, + beta_max=20, + ): + super().__init__() + self.beta_min = beta_min + self.beta_d = beta_max - beta_min + def beta(self, t) -> Tensor: + t = torch.clamp(t, min=1e-3, max=1) + return (self.beta_min + (self.beta_d * t)).view(-1, 1, 1, 1) + + def sigma(self, t) -> Tensor: + t = torch.clamp(t, min=1e-3, max=1) + inter_beta:Tensor = 0.5*self.beta_d*t**2 + self.beta_min* t + return (1-torch.exp_(-inter_beta)).sqrt().view(-1, 1, 1, 1) + + def dsigma(self, t) -> Tensor: + raise NotImplementedError("wrong usage") + + def dalpha_over_alpha(self, t) ->Tensor: + raise NotImplementedError("wrong usage") + + def dsigma_mul_sigma(self, t) ->Tensor: + raise NotImplementedError("wrong usage") + + def dalpha(self, t) -> Tensor: + raise NotImplementedError("wrong usage") + + def alpha(self, t) -> Tensor: + t = torch.clamp(t, min=1e-3, max=1) + inter_beta: Tensor = 0.5 * self.beta_d * t ** 2 + self.beta_min * t + return torch.exp(-0.5*inter_beta).view(-1, 1, 1, 1) + + def drift_coefficient(self, t): + raise NotImplementedError("wrong usage") + + def diffuse_coefficient(self, t): + raise NotImplementedError("wrong usage") + + def w(self, t): + return self.diffuse_coefficient(t) + + + diff --git a/src/diffusion/ddpm/training.py b/src/diffusion/ddpm/training.py new file mode 100644 index 0000000000000000000000000000000000000000..2b831f68fdab5025c250e92163d448aa452164f8 --- /dev/null +++ b/src/diffusion/ddpm/training.py @@ -0,0 +1,83 @@ +import torch +from typing import Callable +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def 
constant(alpha, sigma): + return 1 + +class VPTrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn:Callable=constant, + train_max_t=1000, + lognorm_t=False, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + self.train_max_t = train_max_t + def _impl_trainstep(self, net, ema_net, solver, x, y, metadata=None): + batch_size = x.shape[0] + if self.lognorm_t: + t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid() + else: + t = torch.rand(batch_size).to(x.device, x.dtype) + + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + sigma = self.scheduler.sigma(t) + x_t = alpha * x + noise * sigma + out = net(x_t, t*self.train_max_t, y) + weight = self.loss_weight_fn(alpha, sigma) + loss = weight*(out - noise)**2 + + out = dict( + loss=loss.mean(), + ) + return out + + +class DDPMTrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn: Callable = constant, + train_max_t=1000, + lognorm_t=False, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + self.train_max_t = train_max_t + + def _impl_trainstep(self, net, ema_net, x, y, metadata=None): + batch_size = x.shape[0] + t = torch.randint(0, self.train_max_t, (batch_size,)) + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + sigma = self.scheduler.sigma(t) + x_t = alpha * x + noise * sigma + out = net(x_t, t, y) + weight = self.loss_weight_fn(alpha, sigma) + loss = weight * (out - noise) ** 2 + + out = dict( + loss=loss.mean(), + ) + return out \ No newline at end of file diff --git a/src/diffusion/ddpm/vp_sampling.py b/src/diffusion/ddpm/vp_sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..6d937587ade86a65b741e676f1837ceb81aeb3ab --- /dev/null +++ b/src/diffusion/ddpm/vp_sampling.py @@ -0,0 +1,64 @@ +import torch + +from src.diffusion.base.scheduling import * +from src.diffusion.base.sampling import * +from typing import Callable + +def ode_step_fn(x, eps, beta, sigma, dt): + return x + (-0.5*beta*x + 0.5*eps*beta/sigma)*dt + +def sde_step_fn(x, eps, beta, sigma, dt): + return x + (-0.5*beta*x + eps*beta/sigma)*dt + torch.sqrt(dt.abs()*beta)*torch.randn_like(x) + +import logging +logger = logging.getLogger(__name__) + +class VPEulerSampler(BaseSampler): + def __init__( + self, + train_max_t=1000, + guidance_fn: Callable = None, + step_fn: Callable = ode_step_fn, + last_step=None, + last_step_fn: Callable = ode_step_fn, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.guidance_fn = guidance_fn + self.step_fn = step_fn + self.last_step = last_step + self.last_step_fn = last_step_fn + self.train_max_t = train_max_t + + if self.last_step is None or self.num_steps == 1: + self.last_step = 1.0 / self.num_steps + assert self.last_step > 0.0 + assert self.scheduler is not None + + def _impl_sampling(self, net, noise, condition, uncondition): + batch_size = noise.shape[0] + steps = torch.linspace(1.0, self.last_step, self.num_steps, device=noise.device) + steps = torch.cat([steps, torch.tensor([0.0], device=noise.device)], dim=0) + cfg_condition = torch.cat([uncondition, condition], dim=0) + x = noise + x_trajs = [noise, ] + eps_trajs = [] + for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])): + dt = t_next - t_cur + t_cur = t_cur.repeat(batch_size) + sigma = 
self.scheduler.sigma(t_cur) + beta = self.scheduler.beta(t_cur) + cfg_x = torch.cat([x, x], dim=0) + cfg_t = t_cur.repeat(2) + out = net(cfg_x, cfg_t*self.train_max_t, cfg_condition) + eps = self.guidance_fn(out, self.guidance) + if i < self.num_steps -1 : + x0 = self.last_step_fn(x, eps, beta, sigma, -t_cur[0]) + x = self.step_fn(x, eps, beta, sigma, dt) + else: + x = x0 = self.last_step_fn(x, eps, beta, sigma, -self.last_step) + x_trajs.append(x) + eps_trajs.append(eps) + eps_trajs.append(torch.zeros_like(x)) + return x_trajs, eps_trajs \ No newline at end of file diff --git a/src/diffusion/flow_matching/adam_sampling.py b/src/diffusion/flow_matching/adam_sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..19a3c3847195d84ae7926a953780b678957d81d5 --- /dev/null +++ b/src/diffusion/flow_matching/adam_sampling.py @@ -0,0 +1,122 @@ +import math +from src.diffusion.base.sampling import * +from src.diffusion.base.scheduling import * +from src.diffusion.pre_integral import * + +from typing import Callable, List, Tuple + +def ode_step_fn(x, v, dt, s, w): + return x + v * dt + +def t2snr(t): + if isinstance(t, torch.Tensor): + return (t.clip(min=1e-8)/(1-t + 1e-8)) + if isinstance(t, List) or isinstance(t, Tuple): + return [t2snr(t) for t in t] + t = max(t, 1e-8) + return (t/(1-t + 1e-8)) + +def t2logsnr(t): + if isinstance(t, torch.Tensor): + return torch.log(t.clip(min=1e-3)/(1-t + 1e-3)) + if isinstance(t, List) or isinstance(t, Tuple): + return [t2logsnr(t) for t in t] + t = max(t, 1e-3) + return math.log(t/(1-t + 1e-3)) + +def t2isnr(t): + return 1/t2snr(t) + +def nop(t): + return t + +def shift_respace_fn(t, shift=3.0): + return t / (t + (1 - t) * shift) + +import logging +logger = logging.getLogger(__name__) + +class AdamLMSampler(BaseSampler): + def __init__( + self, + order: int = 2, + timeshift: float = 1.0, + guidance_interval_min: float = 0.0, + guidance_interval_max: float = 1.0, + lms_transform_fn: Callable = nop, + last_step=None, + step_fn: Callable = ode_step_fn, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.step_fn = step_fn + + assert self.scheduler is not None + assert self.step_fn in [ode_step_fn, ] + self.order = order + self.lms_transform_fn = lms_transform_fn + self.last_step = last_step + self.guidance_interval_min = guidance_interval_min + self.guidance_interval_max = guidance_interval_max + + if self.last_step is None: + self.last_step = 1.0/self.num_steps + timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps) + timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0) + self.timesteps = shift_respace_fn(timesteps, timeshift) + self.timedeltas = self.timesteps[1:] - self.timesteps[:-1] + self._reparameterize_coeffs() + + def _reparameterize_coeffs(self): + solver_coeffs = [[] for _ in range(self.num_steps)] + for i in range(0, self.num_steps): + pre_vs = [1.0, ]*(i+1) + pre_ts = self.lms_transform_fn(self.timesteps[:i+1]) + int_t_start = self.lms_transform_fn(self.timesteps[i]) + int_t_end = self.lms_transform_fn(self.timesteps[i+1]) + + order_annealing = self.order #self.num_steps - i + order = min(self.order, i + 1, order_annealing) + + _, coeffs = lagrange_preint(order, pre_vs, pre_ts, int_t_start, int_t_end) + solver_coeffs[i] = coeffs + self.solver_coeffs = solver_coeffs + + def _impl_sampling(self, net, noise, condition, uncondition): + """ + sampling process of Euler sampler + - + """ + batch_size = noise.shape[0] + cfg_condition = torch.cat([uncondition, condition], dim=0) + x = x0 = noise + 
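# Note on the loop below: AdamLMSampler is an Adams-Bashforth style linear
# multistep solver. _reparameterize_coeffs integrates the Lagrange interpolant
# of the last `order` velocity predictions over [t_i, t_{i+1}]
# (see src/diffusion/pre_integral.py) and normalises by the step, so each
# update is x = x + dt * sum_j coeff_j * v_j.  Minimal sketch, assuming only
# pre_integral.py from this diff, recovering the classic 2-step
# Adams-Bashforth weights (-1/2, 3/2) on a uniform grid:
from src.diffusion.pre_integral import lagrange_preint

_, coeffs = lagrange_preint(
    order=2,
    pre_vs=[1.0, 1.0],   # values do not affect the normalised weights
    pre_ts=[0.0, 0.5],   # the two most recent timesteps
    int_t_start=0.5,     # integrate over the next interval [0.5, 1.0]
    int_t_end=1.0,
)
assert abs(coeffs[0] + 0.5) < 1e-8 and abs(coeffs[1] - 1.5) < 1e-8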
pred_trajectory = [] + x_trajectory = [noise, ] + v_trajectory = [] + t_cur = torch.zeros([batch_size,]).to(noise.device, noise.dtype) + timedeltas = self.timedeltas + solver_coeffs = self.solver_coeffs + for i in range(self.num_steps): + cfg_x = torch.cat([x, x], dim=0) + cfg_t = t_cur.repeat(2) + out = net(cfg_x, cfg_t, cfg_condition) + if t_cur[0] > self.guidance_interval_min and t_cur[0] < self.guidance_interval_max: + guidance = self.guidance + out = self.guidance_fn(out, guidance) + else: + out = self.guidance_fn(out, 1.0) + pred_trajectory.append(out) + out = torch.zeros_like(out) + order = len(self.solver_coeffs[i]) + for j in range(order): + out += solver_coeffs[i][j] * pred_trajectory[-order:][j] + v = out + dt = timedeltas[i] + x0 = self.step_fn(x, v, 1-t_cur[0], s=0, w=0) + x = self.step_fn(x, v, dt, s=0, w=0) + t_cur += dt + x_trajectory.append(x) + v_trajectory.append(v) + v_trajectory.append(torch.zeros_like(noise)) + return x_trajectory, v_trajectory \ No newline at end of file diff --git a/src/diffusion/flow_matching/sampling.py b/src/diffusion/flow_matching/sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..7ef2a723cc4721a00da19dbfdf5fe0924fbe93da --- /dev/null +++ b/src/diffusion/flow_matching/sampling.py @@ -0,0 +1,197 @@ +import torch + +from src.diffusion.base.guidance import * +from src.diffusion.base.scheduling import * +from src.diffusion.base.sampling import * + +from typing import Callable + + +def shift_respace_fn(t, shift=3.0): + return t / (t + (1 - t) * shift) + +def ode_step_fn(x, v, dt, s, w): + return x + v * dt + +def sde_mean_step_fn(x, v, dt, s, w): + return x + v * dt + s * w * dt + +def sde_step_fn(x, v, dt, s, w): + return x + v*dt + s * w* dt + torch.sqrt(2*w*dt)*torch.randn_like(x) + +def sde_preserve_step_fn(x, v, dt, s, w): + return x + v*dt + 0.5*s*w* dt + torch.sqrt(w*dt)*torch.randn_like(x) + + +import logging +logger = logging.getLogger(__name__) + +class EulerSampler(BaseSampler): + def __init__( + self, + w_scheduler: BaseScheduler = None, + timeshift=1.0, + guidance_interval_min: float = 0.0, + guidance_interval_max: float = 1.0, + step_fn: Callable = ode_step_fn, + last_step=None, + last_step_fn: Callable = ode_step_fn, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.step_fn = step_fn + self.last_step = last_step + self.last_step_fn = last_step_fn + self.w_scheduler = w_scheduler + self.timeshift = timeshift + self.guidance_interval_min = guidance_interval_min + self.guidance_interval_max = guidance_interval_max + + if self.last_step is None or self.num_steps == 1: + self.last_step = 1.0 / self.num_steps + + timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps) + timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0) + self.timesteps = shift_respace_fn(timesteps, self.timeshift) + + assert self.last_step > 0.0 + assert self.scheduler is not None + assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ] + if self.w_scheduler is not None: + if self.step_fn == ode_step_fn: + logger.warning("current sampler is ODE sampler, but w_scheduler is enabled") + + def _impl_sampling(self, net, noise, condition, uncondition): + """ + sampling process of Euler sampler + - + """ + batch_size = noise.shape[0] + steps = self.timesteps.to(noise.device, noise.dtype) + cfg_condition = torch.cat([uncondition, condition], dim=0) + x = noise + x_trajs = [noise,] + v_trajs = [] + for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])): + dt = t_next - t_cur + t_cur = 
t_cur.repeat(batch_size) + sigma = self.scheduler.sigma(t_cur) + dalpha_over_alpha = self.scheduler.dalpha_over_alpha(t_cur) + dsigma_mul_sigma = self.scheduler.dsigma_mul_sigma(t_cur) + if self.w_scheduler: + w = self.w_scheduler.w(t_cur) + else: + w = 0.0 + + cfg_x = torch.cat([x, x], dim=0) + cfg_t = t_cur.repeat(2) + out = net(cfg_x, cfg_t, cfg_condition) + if t_cur[0] > self.guidance_interval_min and t_cur[0] < self.guidance_interval_max: + guidance = self.guidance + out = self.guidance_fn(out, guidance) + else: + out = self.guidance_fn(out, 1.0) + v = out + s = ((1/dalpha_over_alpha)*v - x)/(sigma**2 - (1/dalpha_over_alpha)*dsigma_mul_sigma) + if i < self.num_steps -1 : + x = self.step_fn(x, v, dt, s=s, w=w) + else: + x = self.last_step_fn(x, v, dt, s=s, w=w) + x_trajs.append(x) + v_trajs.append(v) + v_trajs.append(torch.zeros_like(x)) + return x_trajs, v_trajs + + +class HeunSampler(BaseSampler): + def __init__( + self, + scheduler: BaseScheduler = None, + w_scheduler: BaseScheduler = None, + exact_henu=False, + timeshift=1.0, + step_fn: Callable = ode_step_fn, + last_step=None, + last_step_fn: Callable = ode_step_fn, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.scheduler = scheduler + self.exact_henu = exact_henu + self.step_fn = step_fn + self.last_step = last_step + self.last_step_fn = last_step_fn + self.w_scheduler = w_scheduler + self.timeshift = timeshift + + timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps) + timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0) + self.timesteps = shift_respace_fn(timesteps, self.timeshift) + + if self.last_step is None or self.num_steps == 1: + self.last_step = 1.0 / self.num_steps + assert self.last_step > 0.0 + assert self.scheduler is not None + assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ] + if self.w_scheduler is not None: + if self.step_fn == ode_step_fn: + logger.warning("current sampler is ODE sampler, but w_scheduler is enabled") + + def _impl_sampling(self, net, noise, condition, uncondition): + """ + sampling process of Henu sampler + - + """ + batch_size = noise.shape[0] + steps = self.timesteps.to(noise.device) + cfg_condition = torch.cat([uncondition, condition], dim=0) + x = noise + v_hat, s_hat = 0.0, 0.0 + x_trajs = [noise, ] + v_trajs = [] + for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])): + dt = t_next - t_cur + t_cur = t_cur.repeat(batch_size) + sigma = self.scheduler.sigma(t_cur) + alpha_over_dalpha = 1/self.scheduler.dalpha_over_alpha(t_cur) + dsigma_mul_sigma = self.scheduler.dsigma_mul_sigma(t_cur) + t_hat = t_next + t_hat = t_hat.repeat(batch_size) + sigma_hat = self.scheduler.sigma(t_hat) + alpha_over_dalpha_hat = 1 / self.scheduler.dalpha_over_alpha(t_hat) + dsigma_mul_sigma_hat = self.scheduler.dsigma_mul_sigma(t_hat) + + if self.w_scheduler: + w = self.w_scheduler.w(t_cur) + else: + w = 0.0 + if i == 0 or self.exact_henu: + cfg_x = torch.cat([x, x], dim=0) + cfg_t_cur = t_cur.repeat(2) + out = net(cfg_x, cfg_t_cur, cfg_condition) + out = self.guidance_fn(out, self.guidance) + v = out + s = ((alpha_over_dalpha)*v - x)/(sigma**2 - (alpha_over_dalpha)*dsigma_mul_sigma) + else: + v = v_hat + s = s_hat + x_hat = self.step_fn(x, v, dt, s=s, w=w) + # henu correct + if i < self.num_steps -1: + cfg_x_hat = torch.cat([x_hat, x_hat], dim=0) + cfg_t_hat = t_hat.repeat(2) + out = net(cfg_x_hat, cfg_t_hat, cfg_condition) + out = self.guidance_fn(out, self.guidance) + v_hat = out + s_hat = ((alpha_over_dalpha_hat)* v_hat - x_hat) / 
(sigma_hat ** 2 - (alpha_over_dalpha_hat) * dsigma_mul_sigma_hat) + v = (v + v_hat) / 2 + s = (s + s_hat) / 2 + x = self.step_fn(x, v, dt, s=s, w=w) + else: + x = self.last_step_fn(x, v, dt, s=s, w=w) + x_trajs.append(x) + v_trajs.append(v) + v_trajs.append(torch.zeros_like(x)) + return x_trajs, v_trajs \ No newline at end of file diff --git a/src/diffusion/flow_matching/scheduling.py b/src/diffusion/flow_matching/scheduling.py new file mode 100644 index 0000000000000000000000000000000000000000..a82cd3a2fcb5e3080710fa0208c5aafff54cd068 --- /dev/null +++ b/src/diffusion/flow_matching/scheduling.py @@ -0,0 +1,39 @@ +import math +import torch +from src.diffusion.base.scheduling import * + + +class LinearScheduler(BaseScheduler): + def alpha(self, t) -> Tensor: + return (t).view(-1, 1, 1, 1) + def sigma(self, t) -> Tensor: + return (1-t).view(-1, 1, 1, 1) + def dalpha(self, t) -> Tensor: + return torch.full_like(t, 1.0).view(-1, 1, 1, 1) + def dsigma(self, t) -> Tensor: + return torch.full_like(t, -1.0).view(-1, 1, 1, 1) + +# SoTA for ImageNet! +class GVPScheduler(BaseScheduler): + def alpha(self, t) -> Tensor: + return torch.cos(t * (math.pi / 2)).view(-1, 1, 1, 1) + def sigma(self, t) -> Tensor: + return torch.sin(t * (math.pi / 2)).view(-1, 1, 1, 1) + def dalpha(self, t) -> Tensor: + return -torch.sin(t * (math.pi / 2)).view(-1, 1, 1, 1) + def dsigma(self, t) -> Tensor: + return torch.cos(t * (math.pi / 2)).view(-1, 1, 1, 1) + def w(self, t): + return torch.sin(t)**2 + +class ConstScheduler(BaseScheduler): + def w(self, t): + return torch.ones(1, 1, 1, 1).to(t.device, t.dtype) + +from src.diffusion.ddpm.scheduling import VPScheduler +class VPBetaScheduler(VPScheduler): + def w(self, t): + return self.beta(t).view(-1, 1, 1, 1) + + + diff --git a/src/diffusion/flow_matching/training.py b/src/diffusion/flow_matching/training.py new file mode 100644 index 0000000000000000000000000000000000000000..8bbb4fbf7e787d58955825b043efa589ffe98af8 --- /dev/null +++ b/src/diffusion/flow_matching/training.py @@ -0,0 +1,61 @@ +import torch +from typing import Callable +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + +def time_shift_fn(t, timeshift=1.0): + return t/(t+(1-t)*timeshift) + +class FlowMatchingTrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn:Callable=constant, + lognorm_t=False, + timeshift=1.0, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.timeshift = timeshift + self.loss_weight_fn = loss_weight_fn + + def _impl_trainstep(self, net, ema_net, solver, x, y, metadata=None): + batch_size = x.shape[0] + if self.lognorm_t: + t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid() + else: + t = torch.rand(batch_size).to(x.device, x.dtype) + t = time_shift_fn(t, self.timeshift) + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + dalpha = self.scheduler.dalpha(t) + sigma = self.scheduler.sigma(t) + dsigma = self.scheduler.dsigma(t) + w = self.scheduler.w(t) + + x_t = alpha * x + noise * sigma + v_t = dalpha * x + dsigma * noise + out = net(x_t, t, y) + + weight = self.loss_weight_fn(alpha, sigma) + + 
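# The regression target above is the flow-matching velocity: the network output
# `out` is pushed towards v_t = dalpha(t) * x + dsigma(t) * noise of the
# interpolation x_t = alpha(t) * x + sigma(t) * noise.  With the LinearScheduler
# defined earlier (alpha = t, sigma = 1 - t) this reduces to
# x_t = t * x + (1 - t) * noise and v_t = x - noise, i.e. the straight-line
# velocity from noise to data.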
loss = weight*(out - v_t)**2 + + out = dict( + loss=loss.mean(), + ) + return out \ No newline at end of file diff --git a/src/diffusion/flow_matching/training_disperse.py b/src/diffusion/flow_matching/training_disperse.py new file mode 100644 index 0000000000000000000000000000000000000000..bf8b5d07026200db2403ca9ed0e1161e71cc2eba --- /dev/null +++ b/src/diffusion/flow_matching/training_disperse.py @@ -0,0 +1,101 @@ +import torch +import copy +import timm +from torch.nn import Parameter + +from src.utils.no_grad import no_grad +from typing import Callable, Iterator, Tuple +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from torchvision.transforms import Normalize +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + + +def time_shift_fn(t, timeshift=1.0): + return t/(t+(1-t)*timeshift) + + +class DisperseTrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn:Callable=constant, + feat_loss_weight: float=0.5, + lognorm_t=False, + timeshift=1.0, + align_layer=8, + temperature=1.0, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.timeshift = timeshift + self.loss_weight_fn = loss_weight_fn + self.feat_loss_weight = feat_loss_weight + self.align_layer = align_layer + self.temperature = temperature + + def _impl_trainstep(self, net, ema_net, solver, x, y, metadata=None): + batch_size, c, height, width = x.shape + if self.lognorm_t: + base_t = torch.randn((batch_size), device=x.device, dtype=torch.float32).sigmoid() + else: + base_t = torch.rand((batch_size), device=x.device, dtype=torch.float32) + t = time_shift_fn(base_t, self.timeshift).to(x.dtype) + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + dalpha = self.scheduler.dalpha(t) + sigma = self.scheduler.sigma(t) + dsigma = self.scheduler.dsigma(t) + + x_t = alpha * x + noise * sigma + v_t = dalpha * x + dsigma * noise + + src_feature = [] + def forward_hook(net, input, output): + feature = output + if isinstance(feature, tuple): + feature = feature[0] # mmdit + src_feature.append(feature) + + if getattr(net, "encoder", None) is not None: + handle = net.encoder.blocks[self.align_layer - 1].register_forward_hook(forward_hook) + else: + handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook) + + out = net(x_t, t, y) + handle.remove() + disperse_distance = 0.0 + for sf in src_feature: + sf = torch.mean(sf, dim=1, keepdim=False) + distance = (sf[None, :, :] - sf[:, None, :])**2 + distance = distance.sum(dim=-1) + sf_disperse_loss = torch.exp(-distance/self.temperature) + mask = 1-torch.eye(batch_size, device=distance.device, dtype=distance.dtype) + disperse_distance += (sf_disperse_loss*mask).sum()/mask.numel() + 1e-6 + disperse_loss = disperse_distance.log() + + + weight = self.loss_weight_fn(alpha, sigma) + fm_loss = weight*(out - v_t)**2 + + out = dict( + fm_loss=fm_loss.mean(), + cos_loss=disperse_loss.mean(), + loss=fm_loss.mean() + self.feat_loss_weight*disperse_loss.mean(), + ) + return out + diff --git a/src/diffusion/flow_matching/training_imadv_ode.py b/src/diffusion/flow_matching/training_imadv_ode.py new file 
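# The DisperseTrainer above adds a dispersion regulariser on an intermediate
# denoiser block: hooked token features are mean-pooled per sample, pairwise
# squared distances are mapped to similarities exp(-d / temperature), and the
# log of their off-diagonal mean is minimised, pushing different samples of the
# batch apart in feature space.  Minimal self-contained sketch of just that
# term (the function name is illustrative, not from the diff):
import torch

def disperse_term(features: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    # features: (batch, tokens, dim), e.g. captured by the forward hook above
    pooled = features.mean(dim=1)                                    # (batch, dim)
    dist = ((pooled[None, :, :] - pooled[:, None, :]) ** 2).sum(-1)  # (batch, batch)
    sim = torch.exp(-dist / temperature)
    mask = 1 - torch.eye(len(pooled), device=sim.device, dtype=sim.dtype)
    return ((sim * mask).sum() / mask.numel() + 1e-6).log()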
mode 100644 index 0000000000000000000000000000000000000000..6c74a6329cda7a249772178c79dab777c7dff07b --- /dev/null +++ b/src/diffusion/flow_matching/training_imadv_ode.py @@ -0,0 +1,218 @@ +import random + +import torch +import copy +import timm +import torchvision.transforms.v2.functional +from torch.nn import Parameter + +from src.utils.no_grad import no_grad +from typing import Callable, Iterator, Tuple +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from torchvision.transforms import Normalize +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler +from src.diffusion.base.sampling import BaseSampler + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + +from PIL import Image +import numpy as np + +def time_shift_fn(t, timeshift=1.0): + return t/(t+(1-t)*timeshift) + +def random_crop(images, resize, crop_size): + images = torchvision.transforms.v2.functional.resize(images, size=resize, antialias=True) + h, w = crop_size + h0 = random.randint(0, images.shape[2]-h) + w0 = random.randint(0, images.shape[3]-w) + return images[:, :, h0:h0+h, w0:w0+w] + +# class EulerSolver: +# def __init__( +# self, +# num_steps: int, +# *args, +# **kwargs +# ): +# super().__init__(*args, **kwargs) +# self.num_steps = num_steps +# self.timesteps = torch.linspace(0.0, 1, self.num_steps+1, dtype=torch.float32) +# +# def __call__(self, net, noise, timeshift, condition): +# steps = time_shift_fn(self.timesteps[:, None], timeshift[None, :]).to(noise.device, noise.dtype) +# x = noise +# trajs = [x, ] +# for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])): +# dt = t_next - t_cur +# v = net(x, t_cur, condition) +# x = x + v*dt[:, None, None, None] +# x = x.to(noise.dtype) +# trajs.append(x) +# return trajs +# +# class NeuralSolver(nn.Module): +# def __init__( +# self, +# num_steps: int, +# *args, +# **kwargs +# ): +# super().__init__(*args, **kwargs) +# self.num_steps = num_steps +# self.timedeltas = torch.nn.Parameter(torch.ones((num_steps))/num_steps, requires_grad=True) +# self.coeffs = torch.nn.Parameter(torch.zeros((num_steps, num_steps)), requires_grad=True) +# # self.golden_noise = torch.nn.Parameter(torch.randn((1, 3, 1024, 1024))*0.01, requires_grad=True) +# +# def forward(self, net, noise, timeshift, condition): +# batch_size, c, height, width = noise.shape +# # golden_noise = torch.nn.functional.interpolate(self.golden_noise, size=(height, width), mode='bicubic', align_corners=False) +# x = noise # + golden_noise.repeat(batch_size, 1, 1, 1) +# x_trajs = [x, ] +# v_trajs = [] +# dts = self.timedeltas.softmax(dim=0) +# print(dts) +# coeffs = self.coeffs +# t_cur = torch.zeros((batch_size,), dtype=noise.dtype, device=noise.device) +# for i, dt in enumerate(dts): +# pred_v = net(x, t_cur, condition) +# v = torch.zeros_like(pred_v) +# v_trajs.append(pred_v) +# acc_coeffs = 0.0 +# for j in range(i): +# acc_coeffs = acc_coeffs + coeffs[i, j] +# v = v + coeffs[i, j]*v_trajs[j] +# v = v + (1-acc_coeffs)*v_trajs[i] +# x = x + v*dt +# x = x.to(noise.dtype) +# x_trajs.append(x) +# t_cur = t_cur + dt +# return x_trajs + +import re +import unicodedata +def clean_filename(s): + # 去除首尾空格和点号 + s = s.strip().strip('.') + # 转换 Unicode 字符为 ASCII 形式 + s = unicodedata.normalize('NFKD', 
s).encode('ASCII', 'ignore').decode('ASCII') + illegal_chars = r'[/]' + reserved_names = set() + # 替换非法字符为下划线 + s = re.sub(illegal_chars, '_', s) + # 合并连续的下划线 + s = re.sub(r'_{2,}', '_', s) + # 转换为小写 + s = s.lower() + # 检查是否为保留文件名 + if s.upper() in reserved_names: + s = s + '_' + # 限制文件名长度 + max_length = 200 + s = s[:max_length] + if not s: + return 'untitled' + return s + +class AdvODETrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn:Callable=constant, + adv_loss_weight: float=0.5, + gan_loss_weight: float=0.5, + im_encoder:nn.Module=None, + adv_head:nn.Module=None, + random_crop_size=448, + max_image_size=512, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + self.adv_loss_weight = adv_loss_weight + self.gan_loss_weight = gan_loss_weight + self.im_encoder = im_encoder + self.adv_head = adv_head + + self.real_buffer = [] + self.fake_buffer = [] + self.random_crop_size = random_crop_size + self.max_image_size = max_image_size + + no_grad(self.im_encoder) + + + def preproprocess(self, x, condition, uncondition, metadata): + self.uncondition = uncondition + return super().preproprocess(x, condition, uncondition, metadata) + + def _impl_trainstep(self, net, ema_net, solver, x, y, metadata=None): + batch_size, c, height, width = x.shape + noise = torch.randn_like(x) + _, trajs = solver(net, noise, y, self.uncondition, return_x_trajs=True, return_v_trajs=False) + + fake_x0 = (trajs[-1]+1)/2 + fake_x0 = fake_x0.clamp(0, 1) + prompt = metadata["prompt"] + # filename = clean_filename(prompt[0]) + # Image.fromarray((fake_x0[0].permute(1, 2, 0).detach().cpu().float() * 255).to(torch.uint8).numpy()).save(f'{filename}.png') + real_x0 = metadata["raw_image"] + fake_x0 = random_crop(fake_x0, resize=self.max_image_size, crop_size=(self.random_crop_size, self.random_crop_size)) + real_x0 = random_crop(real_x0, resize=self.max_image_size, crop_size=(self.random_crop_size, self.random_crop_size)) + + fake_im_features = self.im_encoder(fake_x0, resize=False) + fake_im_features_detach = fake_im_features.detach() + + with torch.no_grad(): + real_im_features = self.im_encoder(real_x0, resize=False) + self.real_buffer.append(real_im_features) + self.fake_buffer.append(fake_im_features_detach) + while len(self.real_buffer) > 10: + self.real_buffer.pop(0) + while len(self.fake_buffer) > 10: + self.fake_buffer.pop(0) + + real_features_gan = torch.cat(self.real_buffer, dim=0) + fake_features_gan = torch.cat(self.fake_buffer, dim=0) + real_score_gan = self.adv_head(real_features_gan) + fake_score_gan = self.adv_head(fake_features_gan) + + fake_score_adv = self.adv_head(fake_im_features) + fake_score_detach_adv = self.adv_head(fake_im_features_detach) + + + loss_gan = -torch.log(1 - fake_score_gan).mean() - torch.log(real_score_gan).mean() + acc_real = (real_score_gan > 0.5).float() + acc_fake = (fake_score_gan < 0.5).float() + loss_adv = -torch.log(fake_score_adv) + loss_adv_hack = torch.log(fake_score_detach_adv) + + + + out = dict( + adv_loss=loss_adv.mean(), + gan_loss=loss_gan.mean(), + acc_real=acc_real.mean(), + acc_fake=acc_fake.mean(), + loss=self.adv_loss_weight*(loss_adv.mean() + loss_adv_hack.mean())+self.gan_loss_weight*loss_gan.mean(), + ) + return out + + def state_dict(self, *args, destination=None, prefix="", keep_vars=False): + self.adv_head.state_dict( + destination=destination, + prefix=prefix + "adv_head.", + keep_vars=keep_vars) + diff --git 
a/src/diffusion/flow_matching/training_mmadv_ode.py b/src/diffusion/flow_matching/training_mmadv_ode.py new file mode 100644 index 0000000000000000000000000000000000000000..a0001565323c6d9676f88402b26ab6e455b3db3e --- /dev/null +++ b/src/diffusion/flow_matching/training_mmadv_ode.py @@ -0,0 +1,261 @@ +import random + +import torch +import copy +import timm +import torchvision.transforms.v2.functional +from torch.nn import Parameter + +from src.utils.no_grad import no_grad +from typing import Callable, Iterator, Tuple +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from torchvision.transforms import Normalize +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler +from src.diffusion.base.sampling import BaseSampler + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + +from PIL import Image +import numpy as np + +def time_shift_fn(t, timeshift=1.0): + return t/(t+(1-t)*timeshift) + +def random_crop(images, resize, crop_size): + images = torchvision.transforms.v2.functional.resize(images, size=resize, antialias=True) + h, w = crop_size + h0 = random.randint(0, images.shape[2]-h) + w0 = random.randint(0, images.shape[3]-w) + return images[:, :, h0:h0+h, w0:w0+w] + +# class EulerSolver: +# def __init__( +# self, +# num_steps: int, +# *args, +# **kwargs +# ): +# super().__init__(*args, **kwargs) +# self.num_steps = num_steps +# self.timesteps = torch.linspace(0.0, 1, self.num_steps+1, dtype=torch.float32) +# +# def __call__(self, net, noise, timeshift, condition): +# steps = time_shift_fn(self.timesteps[:, None], timeshift[None, :]).to(noise.device, noise.dtype) +# x = noise +# trajs = [x, ] +# for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])): +# dt = t_next - t_cur +# v = net(x, t_cur, condition) +# x = x + v*dt[:, None, None, None] +# x = x.to(noise.dtype) +# trajs.append(x) +# return trajs +# +# class NeuralSolver(nn.Module): +# def __init__( +# self, +# num_steps: int, +# *args, +# **kwargs +# ): +# super().__init__(*args, **kwargs) +# self.num_steps = num_steps +# self.timedeltas = torch.nn.Parameter(torch.ones((num_steps))/num_steps, requires_grad=True) +# self.coeffs = torch.nn.Parameter(torch.zeros((num_steps, num_steps)), requires_grad=True) +# # self.golden_noise = torch.nn.Parameter(torch.randn((1, 3, 1024, 1024))*0.01, requires_grad=True) +# +# def forward(self, net, noise, timeshift, condition): +# batch_size, c, height, width = noise.shape +# # golden_noise = torch.nn.functional.interpolate(self.golden_noise, size=(height, width), mode='bicubic', align_corners=False) +# x = noise # + golden_noise.repeat(batch_size, 1, 1, 1) +# x_trajs = [x, ] +# v_trajs = [] +# dts = self.timedeltas.softmax(dim=0) +# print(dts) +# coeffs = self.coeffs +# t_cur = torch.zeros((batch_size,), dtype=noise.dtype, device=noise.device) +# for i, dt in enumerate(dts): +# pred_v = net(x, t_cur, condition) +# v = torch.zeros_like(pred_v) +# v_trajs.append(pred_v) +# acc_coeffs = 0.0 +# for j in range(i): +# acc_coeffs = acc_coeffs + coeffs[i, j] +# v = v + coeffs[i, j]*v_trajs[j] +# v = v + (1-acc_coeffs)*v_trajs[i] +# x = x + v*dt +# x = x.to(noise.dtype) +# x_trajs.append(x) +# t_cur = t_cur + dt +# return x_trajs + +import re +import os +import unicodedata +def 
clean_filename(s): + # 去除首尾空格和点号 + s = s.strip().strip('.') + # 转换 Unicode 字符为 ASCII 形式 + s = unicodedata.normalize('NFKD', s).encode('ASCII', 'ignore').decode('ASCII') + illegal_chars = r'[/]' + reserved_names = set() + # 替换非法字符为下划线 + s = re.sub(illegal_chars, '_', s) + # 合并连续的下划线 + s = re.sub(r'_{2,}', '_', s) + # 转换为小写 + s = s.lower() + # 检查是否为保留文件名 + if s.upper() in reserved_names: + s = s + '_' + # 限制文件名长度 + max_length = 200 + s = s[:max_length] + if not s: + return 'untitled' + return s + +def prompt_augment(prompts, random_prompts, replace_prob=0.5, front_append_prob=0.5, back_append_prob=0.5, delete_prob=0.5,): + random_prompts = random.choices(random_prompts, k=len(prompts)) + new_prompts = [] + for prompt, random_prompt in zip(prompts, random_prompts): + if random.random() < replace_prob: + new_prompt = random_prompt + else: + new_prompt = prompt + if random.random() < front_append_prob: + new_prompt = random_prompt + ", " + new_prompt + if random.random() < back_append_prob: + new_prompt = new_prompt + ", " + random_prompt + if random.random() < delete_prob: + new_length = random.randint(1, len(new_prompt.split(","))) + new_prompt = ", ".join(new_prompt.split(",")[:new_length]) + new_prompts.append(new_prompt) + return new_prompts + +class AdvODETrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn:Callable=constant, + adv_loss_weight: float=0.5, + gan_loss_weight: float=0.5, + im_encoder:nn.Module=None, + mm_encoder:nn.Module=None, + adv_head:nn.Module=None, + random_crop_size=448, + max_image_size=512, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + self.adv_loss_weight = adv_loss_weight + self.gan_loss_weight = gan_loss_weight + self.im_encoder = im_encoder + self.mm_encoder = mm_encoder + self.adv_head = adv_head + + self.real_buffer = [] + self.fake_buffer = [] + self.random_crop_size = random_crop_size + self.max_image_size = max_image_size + + no_grad(self.im_encoder) + no_grad(self.mm_encoder) + self.random_prompts = ["hahahaha", ] + self.saved_filenames = [] + + def preproprocess(self, x, condition, uncondition, metadata): + self.uncondition = uncondition + return super().preproprocess(x, condition, uncondition, metadata) + + def _impl_trainstep(self, net, ema_net, solver, x, y, metadata=None): + batch_size, c, height, width = x.shape + noise = torch.randn_like(x) + _, trajs = solver(net, noise, y, self.uncondition, return_x_trajs=True, return_v_trajs=False) + with torch.no_grad(): + _, ref_trajs = solver(ema_net, noise, y, self.uncondition, return_x_trajs=True, return_v_trajs=False) + + fake_x0 = (trajs[-1]+1)/2 + fake_x0 = fake_x0.clamp(0, 1) + prompt = metadata["prompt"] + self.random_prompts.extend(prompt) + self.random_prompts = self.random_prompts[-50:] + filename = clean_filename(prompt[0])+".png" + Image.fromarray((fake_x0[0].permute(1, 2, 0).detach().cpu().float() * 255).to(torch.uint8).numpy()).save(f'{filename}') + self.saved_filenames.append(filename) + if len(self.saved_filenames) > 100: + os.remove(self.saved_filenames[0]) + self.saved_filenames.pop(0) + + real_x0 = metadata["raw_image"] + fake_x0 = random_crop(fake_x0, resize=self.max_image_size, crop_size=(self.random_crop_size, self.random_crop_size)) + real_x0 = random_crop(real_x0, resize=self.max_image_size, crop_size=(self.random_crop_size, self.random_crop_size)) + + fake_im_features = self.im_encoder(fake_x0, resize=False) + fake_mm_features = self.mm_encoder(fake_x0, prompt, 
resize=True) + fake_im_features_detach = fake_im_features.detach() + fake_mm_features_detach = fake_mm_features.detach() + + with torch.no_grad(): + real_im_features = self.im_encoder(real_x0, resize=False) + real_mm_features = self.mm_encoder(real_x0, prompt, resize=True) + not_match_prompt = prompt_augment(prompt, self.random_prompts)#random.choices(self.random_prompts, k=batch_size) + real_not_match_mm_features = self.mm_encoder(real_x0, not_match_prompt, resize=True) + self.real_buffer.append((real_im_features, real_mm_features)) + self.fake_buffer.append((fake_im_features_detach, fake_mm_features_detach)) + self.fake_buffer.append((real_im_features, real_not_match_mm_features)) + while len(self.real_buffer) > 10: + self.real_buffer.pop(0) + while len(self.fake_buffer) > 10: + self.fake_buffer.pop(0) + + real_features_gan = torch.cat([x[0] for x in self.real_buffer], dim=0) + real_conditions_gan = torch.cat([x[1] for x in self.real_buffer], dim=0) + fake_features_gan = torch.cat([x[0] for x in self.fake_buffer], dim=0) + fake_conditions_gan = torch.cat([x[1] for x in self.fake_buffer], dim=0) + real_score_gan = self.adv_head(real_features_gan, real_conditions_gan) + fake_score_gan = self.adv_head(fake_features_gan, fake_conditions_gan) + + fake_score_adv = self.adv_head(fake_im_features, fake_mm_features) + fake_score_detach_adv = self.adv_head(fake_im_features_detach, fake_mm_features_detach) + + + loss_gan = -torch.log(1 - fake_score_gan).mean() - torch.log(real_score_gan).mean() + acc_real = (real_score_gan > 0.5).float() + acc_fake = (fake_score_gan < 0.5).float() + loss_adv = -torch.log(fake_score_adv) + loss_adv_hack = torch.log(fake_score_detach_adv) + + trajs_loss = 0.0 + for x_t, ref_x_t in zip(trajs, ref_trajs): + trajs_loss = trajs_loss + torch.abs(x_t - ref_x_t).mean() + trajs_loss = trajs_loss / len(trajs) + + out = dict( + trajs_loss=trajs_loss.mean(), + adv_loss=loss_adv.mean(), + gan_loss=loss_gan.mean(), + acc_real=acc_real.mean(), + acc_fake=acc_fake.mean(), + loss=trajs_loss.mean() + self.adv_loss_weight*(loss_adv.mean() + loss_adv_hack.mean())+self.gan_loss_weight*loss_gan.mean(), + ) + return out + + def state_dict(self, *args, destination=None, prefix="", keep_vars=False): + self.adv_head.state_dict( + destination=destination, + prefix=prefix + "adv_head.", + keep_vars=keep_vars) diff --git a/src/diffusion/flow_matching/training_repa.py b/src/diffusion/flow_matching/training_repa.py new file mode 100644 index 0000000000000000000000000000000000000000..bfc72aa39f7100e890389a06aaf07e58d68482ca --- /dev/null +++ b/src/diffusion/flow_matching/training_repa.py @@ -0,0 +1,122 @@ +import torch +import copy +import timm +from torch.nn import Parameter + +from src.utils.no_grad import no_grad +from typing import Callable, Iterator, Tuple +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from torchvision.transforms import Normalize +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + + +def time_shift_fn(t, timeshift=1.0): + return t/(t+(1-t)*timeshift) + + +class REPATrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn:Callable=constant, + feat_loss_weight: 
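# Reading of the adversarial objective in the two AdvODETrainer variants above:
# loss_gan = -log(1 - D(fake)) - log(D(real)) is the discriminator loss over
# the replay buffers, and loss_adv = -log(D(fake)) is the non-saturating
# generator loss back-propagated through the sampled trajectory into the
# denoiser.  loss_adv_hack = +log(D(fake.detach())) only receives gradient
# through adv_head and cancels the head gradient contributed by loss_adv, so
# the generator term updates the denoiser alone while the head is trained
# purely by loss_gan.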
float=0.5, + lognorm_t=False, + timeshift=1.0, + encoder:nn.Module=None, + align_layer=8, + proj_denoiser_dim=256, + proj_hidden_dim=256, + proj_encoder_dim=256, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.timeshift = timeshift + self.loss_weight_fn = loss_weight_fn + self.feat_loss_weight = feat_loss_weight + self.align_layer = align_layer + self.encoder = encoder + no_grad(self.encoder) + + self.proj = nn.Sequential( + nn.Sequential( + nn.Linear(proj_denoiser_dim, proj_hidden_dim), + nn.SiLU(), + nn.Linear(proj_hidden_dim, proj_hidden_dim), + nn.SiLU(), + nn.Linear(proj_hidden_dim, proj_encoder_dim), + ) + ) + + def _impl_trainstep(self, net, ema_net, solver, x, y, metadata=None): + raw_images = metadata["raw_image"] + batch_size, c, height, width = x.shape + if self.lognorm_t: + base_t = torch.randn((batch_size), device=x.device, dtype=torch.float32).sigmoid() + else: + base_t = torch.rand((batch_size), device=x.device, dtype=torch.float32) + t = time_shift_fn(base_t, self.timeshift).to(x.dtype) + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + dalpha = self.scheduler.dalpha(t) + sigma = self.scheduler.sigma(t) + dsigma = self.scheduler.dsigma(t) + + x_t = alpha * x + noise * sigma + v_t = dalpha * x + dsigma * noise + + src_feature = [] + def forward_hook(net, input, output): + feature = output + if isinstance(feature, tuple): + feature = feature[0] # mmdit + src_feature.append(feature) + if getattr(net, "encoder", None) is not None: + handle = net.encoder.blocks[self.align_layer - 1].register_forward_hook(forward_hook) + else: + handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook) + + out = net(x_t, t, y) + src_feature = self.proj(src_feature[0]) + handle.remove() + + with torch.no_grad(): + dst_feature = self.encoder(raw_images) + + if dst_feature.shape[1] != src_feature.shape[1]: + src_feature = src_feature[:, :dst_feature.shape[1]] + + + cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1) + cos_loss = 1 - cos_sim + + weight = self.loss_weight_fn(alpha, sigma) + fm_loss = weight*(out - v_t)**2 + + out = dict( + fm_loss=fm_loss.mean(), + cos_loss=cos_loss.mean(), + loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean(), + ) + return out + + def state_dict(self, *args, destination=None, prefix="", keep_vars=False): + self.proj.state_dict( + destination=destination, + prefix=prefix + "proj.", + keep_vars=keep_vars) + diff --git a/src/diffusion/pre_integral.py b/src/diffusion/pre_integral.py new file mode 100644 index 0000000000000000000000000000000000000000..848533a8e1aa99b4f2249560d4e2cec550f7852c --- /dev/null +++ b/src/diffusion/pre_integral.py @@ -0,0 +1,143 @@ +import torch + +# lagrange interpolation +def lagrange_preint_o1(t1, v1, int_t_start, int_t_end): + ''' + lagrange interpolation of order 1 + Args: + t1: timestepx + v1: value field at t1 + int_t_start: intergation start time + int_t_end: intergation end time + Returns: + integrated value + ''' + int1 = (int_t_end-int_t_start) + return int1*v1, (int1/int1, ) + +def lagrange_preint_o2(t1, t2, v1, v2, int_t_start, int_t_end): + ''' + lagrange interpolation of order 2 + Args: + t1: timestepx + t2: timestepy + v1: value field at t1 + v2: value field at t2 + int_t_start: intergation start time + int_t_end: intergation end time + Returns: + integrated value + ''' + int1 = 0.5/(t1-t2)*((int_t_end-t2)**2 - (int_t_start-t2)**2) + int2 = 0.5/(t2-t1)*((int_t_end-t1)**2 - 
(int_t_start-t1)**2) + int_sum = int1+int2 + return int1*v1 + int2*v2, (int1/int_sum, int2/int_sum) + +def lagrange_preint_o3(t1, t2, t3, v1, v2, v3, int_t_start, int_t_end): + ''' + lagrange interpolation of order 3 + Args: + t1: timestepx + t2: timestepy + t3: timestepz + v1: value field at t1 + v2: value field at t2 + v3: value field at t3 + int_t_start: intergation start time + int_t_end: intergation end time + Returns: + integrated value + ''' + int1_denom = (t1-t2)*(t1-t3) + int1_end = 1/3*(int_t_end)**3 - 1/2*(t2+t3)*(int_t_end)**2 + (t2*t3)*int_t_end + int1_start = 1/3*(int_t_start)**3 - 1/2*(t2+t3)*(int_t_start)**2 + (t2*t3)*int_t_start + int1 = (int1_end - int1_start)/int1_denom + int2_denom = (t2-t1)*(t2-t3) + int2_end = 1/3*(int_t_end)**3 - 1/2*(t1+t3)*(int_t_end)**2 + (t1*t3)*int_t_end + int2_start = 1/3*(int_t_start)**3 - 1/2*(t1+t3)*(int_t_start)**2 + (t1*t3)*int_t_start + int2 = (int2_end - int2_start)/int2_denom + int3_denom = (t3-t1)*(t3-t2) + int3_end = 1/3*(int_t_end)**3 - 1/2*(t1+t2)*(int_t_end)**2 + (t1*t2)*int_t_end + int3_start = 1/3*(int_t_start)**3 - 1/2*(t1+t2)*(int_t_start)**2 + (t1*t2)*int_t_start + int3 = (int3_end - int3_start)/int3_denom + int_sum = int1+int2+int3 + return int1*v1 + int2*v2 + int3*v3, (int1/int_sum, int2/int_sum, int3/int_sum) + +def larange_preint_o4(t1, t2, t3, t4, v1, v2, v3, v4, int_t_start, int_t_end): + ''' + lagrange interpolation of order 4 + Args: + t1: timestepx + t2: timestepy + t3: timestepz + t4: timestepw + v1: value field at t1 + v2: value field at t2 + v3: value field at t3 + v4: value field at t4 + int_t_start: intergation start time + int_t_end: intergation end time + Returns: + integrated value + ''' + int1_denom = (t1-t2)*(t1-t3)*(t1-t4) + int1_end = 1/4*(int_t_end)**4 - 1/3*(t2+t3+t4)*(int_t_end)**3 + 1/2*(t3*t4 + t2*t3 + t2*t4)*int_t_end**2 - t2*t3*t4*int_t_end + int1_start = 1/4*(int_t_start)**4 - 1/3*(t2+t3+t4)*(int_t_start)**3 + 1/2*(t3*t4 + t2*t3 + t2*t4)*int_t_start**2 - t2*t3*t4*int_t_start + int1 = (int1_end - int1_start)/int1_denom + int2_denom = (t2-t1)*(t2-t3)*(t2-t4) + int2_end = 1/4*(int_t_end)**4 - 1/3*(t1+t3+t4)*(int_t_end)**3 + 1/2*(t3*t4 + t1*t3 + t1*t4)*int_t_end**2 - t1*t3*t4*int_t_end + int2_start = 1/4*(int_t_start)**4 - 1/3*(t1+t3+t4)*(int_t_start)**3 + 1/2*(t3*t4 + t1*t3 + t1*t4)*int_t_start**2 - t1*t3*t4*int_t_start + int2 = (int2_end - int2_start)/int2_denom + int3_denom = (t3-t1)*(t3-t2)*(t3-t4) + int3_end = 1/4*(int_t_end)**4 - 1/3*(t1+t2+t4)*(int_t_end)**3 + 1/2*(t4*t2 + t1*t2 + t1*t4)*int_t_end**2 - t1*t2*t4*int_t_end + int3_start = 1/4*(int_t_start)**4 - 1/3*(t1+t2+t4)*(int_t_start)**3 + 1/2*(t4*t2 + t1*t2 + t1*t4)*int_t_start**2 - t1*t2*t4*int_t_start + int3 = (int3_end - int3_start)/int3_denom + int4_denom = (t4-t1)*(t4-t2)*(t4-t3) + int4_end = 1/4*(int_t_end)**4 - 1/3*(t1+t2+t3)*(int_t_end)**3 + 1/2*(t3*t2 + t1*t2 + t1*t3)*int_t_end**2 - t1*t2*t3*int_t_end + int4_start = 1/4*(int_t_start)**4 - 1/3*(t1+t2+t3)*(int_t_start)**3 + 1/2*(t3*t2 + t1*t2 + t1*t3)*int_t_start**2 - t1*t2*t3*int_t_start + int4 = (int4_end - int4_start)/int4_denom + int_sum = int1+int2+int3+int4 + return int1*v1 + int2*v2 + int3*v3 + int4*v4, (int1/int_sum, int2/int_sum, int3/int_sum, int4/int_sum) + + +def lagrange_preint(order, pre_vs, pre_ts, int_t_start, int_t_end): + ''' + lagrange interpolation + Args: + order: order of interpolation + pre_vs: value field at pre_ts + pre_ts: timesteps + int_t_start: intergation start time + int_t_end: intergation end time + Returns: + integrated value + ''' + order = min(order, 
len(pre_vs), len(pre_ts)) + if order == 1: + return lagrange_preint_o1(pre_ts[-1], pre_vs[-1], int_t_start, int_t_end) + elif order == 2: + return lagrange_preint_o2(pre_ts[-2], pre_ts[-1], pre_vs[-2], pre_vs[-1], int_t_start, int_t_end) + elif order == 3: + return lagrange_preint_o3(pre_ts[-3], pre_ts[-2], pre_ts[-1], pre_vs[-3], pre_vs[-2], pre_vs[-1], int_t_start, int_t_end) + elif order == 4: + return larange_preint_o4(pre_ts[-4], pre_ts[-3], pre_ts[-2], pre_ts[-1], pre_vs[-4], pre_vs[-3], pre_vs[-2], pre_vs[-1], int_t_start, int_t_end) + else: + raise ValueError('Invalid order') + + +def polynomial_integral(coeffs, int_t_start, int_t_end): + ''' + polynomial integral + Args: + coeffs: coefficients of the polynomial + int_t_start: intergation start time + int_t_end: intergation end time + Returns: + integrated value + ''' + orders = len(coeffs) + int_val = 0 + for o in range(orders): + int_val += coeffs[o]/(o+1)*(int_t_end**(o+1)-int_t_start**(o+1)) + return int_val + diff --git a/src/lightning_data.py b/src/lightning_data.py new file mode 100644 index 0000000000000000000000000000000000000000..8dc70a009f325f418a8b4962cca62ca20e08d7df --- /dev/null +++ b/src/lightning_data.py @@ -0,0 +1,132 @@ +from typing import Any +import torch +import time +import copy +import lightning.pytorch as pl +from lightning.pytorch.utilities.types import TRAIN_DATALOADERS, EVAL_DATALOADERS +from torch.utils.data import DataLoader, Dataset, IterableDataset +from src.data.dataset.randn import RandomNDataset + +def mirco_batch_collate_fn(batch): + batch = copy.deepcopy(batch) + new_batch = [] + for micro_batch in batch: + new_batch.extend(micro_batch) + x, y, metadata = list(zip(*new_batch)) + stacked_metadata = {} + for key in metadata[0].keys(): + try: + if isinstance(metadata[0][key], torch.Tensor): + stacked_metadata[key] = torch.stack([m[key] for m in metadata], dim=0) + else: + stacked_metadata[key] = [m[key] for m in metadata] + except: + pass + x = torch.stack(x, dim=0) + return x, y, stacked_metadata + +def collate_fn(batch): + batch = copy.deepcopy(batch) + x, y, metadata = list(zip(*batch)) + stacked_metadata = {} + for key in metadata[0].keys(): + try: + if isinstance(metadata[0][key], torch.Tensor): + stacked_metadata[key] = torch.stack([m[key] for m in metadata], dim=0) + else: + stacked_metadata[key] = [m[key] for m in metadata] + except: + pass + x = torch.stack(x, dim=0) + return x, y, stacked_metadata + +def eval_collate_fn(batch): + batch = copy.deepcopy(batch) + x, y, metadata = list(zip(*batch)) + x = torch.stack(x, dim=0) + return x, y, metadata + +class DataModule(pl.LightningDataModule): + def __init__(self, + train_dataset:Dataset=None, + eval_dataset:Dataset=None, + pred_dataset:Dataset=None, + train_batch_size=64, + train_num_workers=16, + train_prefetch_factor=8, + eval_batch_size=32, + eval_num_workers=4, + pred_batch_size=32, + pred_num_workers=4, + ): + super().__init__() + self.train_dataset = train_dataset + self.eval_dataset = eval_dataset + self.pred_dataset = pred_dataset + # stupid data_convert override, just to make nebular happy + self.train_batch_size = train_batch_size + self.train_num_workers = train_num_workers + self.train_prefetch_factor = train_prefetch_factor + + + self.eval_batch_size = eval_batch_size + self.pred_batch_size = pred_batch_size + + self.pred_num_workers = pred_num_workers + self.eval_num_workers = eval_num_workers + + self._train_dataloader = None + + def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any: + return batch + + def 
train_dataloader(self) -> TRAIN_DATALOADERS: + micro_batch_size = getattr(self.train_dataset, "micro_batch_size", None) + if micro_batch_size is not None: + assert self.train_batch_size % micro_batch_size == 0 + dataloader_batch_size = self.train_batch_size // micro_batch_size + train_collate_fn = mirco_batch_collate_fn + else: + dataloader_batch_size = self.train_batch_size + train_collate_fn = collate_fn + + # build dataloader sampler + if not isinstance(self.train_dataset, IterableDataset): + sampler = torch.utils.data.distributed.DistributedSampler(self.train_dataset) + else: + sampler = None + + self._train_dataloader = DataLoader( + self.train_dataset, + dataloader_batch_size, + timeout=6000, + num_workers=self.train_num_workers, + prefetch_factor=self.train_prefetch_factor, + collate_fn=train_collate_fn, + sampler=sampler, + ) + return self._train_dataloader + + def val_dataloader(self) -> EVAL_DATALOADERS: + global_rank = self.trainer.global_rank + world_size = self.trainer.world_size + from torch.utils.data import DistributedSampler + sampler = DistributedSampler(self.eval_dataset, num_replicas=world_size, rank=global_rank, shuffle=False) + return DataLoader(self.eval_dataset, self.eval_batch_size, + num_workers=self.eval_num_workers, + prefetch_factor=2, + sampler=sampler, + collate_fn=eval_collate_fn + ) + + def predict_dataloader(self) -> EVAL_DATALOADERS: + global_rank = self.trainer.global_rank + world_size = self.trainer.world_size + from torch.utils.data import DistributedSampler + sampler = DistributedSampler(self.pred_dataset, num_replicas=world_size, rank=global_rank, shuffle=False) + return DataLoader(self.pred_dataset, batch_size=self.pred_batch_size, + num_workers=self.pred_num_workers, + prefetch_factor=4, + sampler=sampler, + collate_fn=eval_collate_fn + ) diff --git a/src/lightning_model.py b/src/lightning_model.py new file mode 100644 index 0000000000000000000000000000000000000000..56877dc3e721ce6d864b772c8a0979c371c0f93b --- /dev/null +++ b/src/lightning_model.py @@ -0,0 +1,150 @@ +from typing import Callable, Iterable, Any, Optional, Union, Sequence, Mapping, Dict +import os.path +import copy +import torch +import torch.nn as nn +import lightning.pytorch as pl +from lightning.pytorch.core.optimizer import LightningOptimizer +from lightning.pytorch.utilities.types import OptimizerLRScheduler, STEP_OUTPUT +from torch.optim.lr_scheduler import LRScheduler +from torch.optim import Optimizer +from lightning.pytorch.callbacks import Callback + + +from src.models.autoencoder.base import BaseAE, fp2uint8 +from src.models.conditioner.base import BaseConditioner +from src.utils.model_loader import ModelLoader +from src.callbacks.simple_ema import SimpleEMA +from src.diffusion.base.sampling import BaseSampler +from src.diffusion.base.training import BaseTrainer +from src.utils.no_grad import no_grad, filter_nograd_tensors +from src.utils.copy import copy_params + +EMACallable = Callable[[nn.Module, nn.Module], SimpleEMA] +OptimizerCallable = Callable[[Iterable], Optimizer] +LRSchedulerCallable = Callable[[Optimizer], LRScheduler] + +class LightningModel(pl.LightningModule): + def __init__(self, + vae: BaseAE, + conditioner: BaseConditioner, + denoiser: nn.Module, + diffusion_trainer: BaseTrainer, + diffusion_sampler: BaseSampler, + ema_tracker: SimpleEMA=None, + optimizer: OptimizerCallable = None, + lr_scheduler: LRSchedulerCallable = None, + eval_original_model: bool = False, + ): + super().__init__() + self.vae = vae + self.conditioner = conditioner + self.denoiser = 
denoiser + self.ema_denoiser = copy.deepcopy(self.denoiser) + self.diffusion_sampler = diffusion_sampler + self.diffusion_trainer = diffusion_trainer + self.ema_tracker = ema_tracker + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + + self.eval_original_model = eval_original_model + + self._strict_loading = False + + def configure_model(self) -> None: + self.trainer.strategy.barrier() + copy_params(src_model=self.denoiser, dst_model=self.ema_denoiser) + + # disable grad for conditioner and vae + no_grad(self.conditioner) + no_grad(self.vae) + # no_grad(self.diffusion_sampler) + no_grad(self.ema_denoiser) + + # torch.compile + self.denoiser.compile() + self.ema_denoiser.compile() + + def configure_callbacks(self) -> Union[Sequence[Callback], Callback]: + return [self.ema_tracker] + + def configure_optimizers(self) -> OptimizerLRScheduler: + params_denoiser = filter_nograd_tensors(self.denoiser.parameters()) + params_trainer = filter_nograd_tensors(self.diffusion_trainer.parameters()) + params_sampler = filter_nograd_tensors(self.diffusion_sampler.parameters()) + param_groups = [ + {"params": params_denoiser, }, + {"params": params_trainer,}, + {"params": params_sampler, "lr": 1e-3}, + ] + # optimizer: torch.optim.Optimizer = self.optimizer([*params_trainer, *params_denoiser]) + optimizer: torch.optim.Optimizer = self.optimizer(param_groups) + if self.lr_scheduler is None: + return dict( + optimizer=optimizer + ) + else: + lr_scheduler = self.lr_scheduler(optimizer) + return dict( + optimizer=optimizer, + lr_scheduler=lr_scheduler + ) + + def on_validation_start(self) -> None: + self.ema_denoiser.to(torch.float32) + + def on_predict_start(self) -> None: + self.ema_denoiser.to(torch.float32) + + # sanity check before training start + def on_train_start(self) -> None: + self.ema_denoiser.to(torch.float32) + self.ema_tracker.setup_models(net=self.denoiser, ema_net=self.ema_denoiser) + + + def training_step(self, batch, batch_idx): + x, y, metadata = batch + with torch.no_grad(): + x = self.vae.encode(x) + condition, uncondition = self.conditioner(y, metadata) + loss = self.diffusion_trainer(self.denoiser, self.ema_denoiser, self.diffusion_sampler, x, condition, uncondition, metadata) + self.log_dict(loss, prog_bar=True, on_step=True, sync_dist=False) + return loss["loss"] + + def predict_step(self, batch, batch_idx): + xT, y, metadata = batch + with torch.no_grad(): + condition, uncondition = self.conditioner(y) + + # sample images + if self.eval_original_model: + samples = self.diffusion_sampler(self.denoiser, xT, condition, uncondition) + else: + samples = self.diffusion_sampler(self.ema_denoiser, xT, condition, uncondition) + + samples = self.vae.decode(samples) + # fp32 -1,1 -> uint8 0,255 + samples = fp2uint8(samples) + return samples + + def validation_step(self, batch, batch_idx): + samples = self.predict_step(batch, batch_idx) + return samples + + def state_dict(self, *args, destination=None, prefix="", keep_vars=False): + if destination is None: + destination = {} + self._save_to_state_dict(destination, prefix, keep_vars) + self.denoiser.state_dict( + destination=destination, + prefix=prefix+"denoiser.", + keep_vars=keep_vars) + self.ema_denoiser.state_dict( + destination=destination, + prefix=prefix+"ema_denoiser.", + keep_vars=keep_vars) + self.diffusion_trainer.state_dict( + destination=destination, + prefix=prefix+"diffusion_trainer.", + keep_vars=keep_vars) + return destination \ No newline at end of file diff --git a/src/models/__init__.py b/src/models/__init__.py 
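# LightningModel.state_dict above is intentionally partial: only the denoiser,
# its EMA copy and the diffusion_trainer's own parameters are serialised, so
# checkpoints do not duplicate the frozen VAE and conditioner weights (those
# are re-instantiated from their pretrained paths), and _strict_loading is
# disabled so such checkpoints can be loaded back.  Illustrative check,
# assuming an already constructed `model: LightningModel`:
#
#     prefixes = {k.split(".")[0] for k in model.state_dict().keys()}
#     assert prefixes <= {"denoiser", "ema_denoiser", "diffusion_trainer"}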
new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/models/autoencoder/base.py b/src/models/autoencoder/base.py new file mode 100644 index 0000000000000000000000000000000000000000..2801bd3950734cd0aac13b48c822cb13763edeb4 --- /dev/null +++ b/src/models/autoencoder/base.py @@ -0,0 +1,33 @@ +import torch +import logging + + +class BaseAE(torch.nn.Module): + def __init__(self, scale=1.0, shift=0.0): + super().__init__() + self.scale = scale + self.shift = shift + + @torch.autocast("cuda", dtype=torch.bfloat16) + def encode(self, x): + return self._impl_encode(x).to(torch.bfloat16) + + @torch.autocast("cuda", dtype=torch.bfloat16) + def decode(self, x): + return self._impl_decode(x).to(torch.bfloat16) + + def _impl_encode(self, x): + raise NotImplementedError + + def _impl_decode(self, x): + raise NotImplementedError + +def uint82fp(x): + x = x.to(torch.float32) + x = (x - 127.5) / 127.5 + return x + +def fp2uint8(x): + x = torch.clip_((x + 1) * 127.5 + 0.5, 0, 255).to(torch.uint8) + return x + diff --git a/src/models/autoencoder/latent.py b/src/models/autoencoder/latent.py new file mode 100644 index 0000000000000000000000000000000000000000..901bf770c78c62139d78858cc9c7969a588d0c9b --- /dev/null +++ b/src/models/autoencoder/latent.py @@ -0,0 +1,24 @@ +import torch +from src.models.autoencoder.base import BaseAE + +class LatentAE(BaseAE): + def __init__(self, precompute=False, weight_path:str=None): + super().__init__() + self.precompute = precompute + self.model = None + self.weight_path = weight_path + + from diffusers.models import AutoencoderKL + setattr(self, "model", AutoencoderKL.from_pretrained(self.weight_path)) + self.scaling_factor = self.model.config.scaling_factor + + def _impl_encode(self, x): + assert self.model is not None + if self.precompute: + return x.mul_(self.scaling_factor) + encodedx = self.model.encode(x).latent_dist.sample().mul_(self.scaling_factor) + return encodedx + + def _impl_decode(self, x): + assert self.model is not None + return self.model.decode(x.div_(self.scaling_factor)).sample \ No newline at end of file diff --git a/src/models/autoencoder/pixel.py b/src/models/autoencoder/pixel.py new file mode 100644 index 0000000000000000000000000000000000000000..bdbf1c08c93c27beecef6973c6243a22f11f2390 --- /dev/null +++ b/src/models/autoencoder/pixel.py @@ -0,0 +1,12 @@ +import torch +from src.models.autoencoder.base import BaseAE + +class PixelAE(BaseAE): + def __init__(self, scale=1.0, shift=0.0): + super().__init__(scale, shift) + + def _impl_encode(self, x): + return x/self.scale+self.shift + + def _impl_decode(self, x): + return (x-self.shift)*self.scale \ No newline at end of file diff --git a/src/models/conditioner/base.py b/src/models/conditioner/base.py new file mode 100644 index 0000000000000000000000000000000000000000..1e75b8eb152cccf69b3eb78ac3409956339fa4b2 --- /dev/null +++ b/src/models/conditioner/base.py @@ -0,0 +1,44 @@ +import torch +import torch.nn as nn +from typing import List + +class BaseConditioner(nn.Module): + def __init__(self): + super(BaseConditioner, self).__init__() + + def _impl_condition(self, y, metadata)->torch.Tensor: + raise NotImplementedError() + + def _impl_uncondition(self, y, metadata)->torch.Tensor: + raise NotImplementedError() + + @torch.no_grad() + @torch.autocast("cuda", dtype=torch.bfloat16) + def __call__(self, y, metadata:dict={}): + condition = self._impl_condition(y, metadata) + uncondition = self._impl_uncondition(y, metadata) + if 
condition.dtype in [torch.float64, torch.float32, torch.float16]: + condition = condition.to(torch.bfloat16) + if uncondition.dtype in [torch.float64,torch.float32, torch.float16]: + uncondition = uncondition.to(torch.bfloat16) + return condition, uncondition + + +class ComposeConditioner(BaseConditioner): + def __init__(self, conditioners:List[BaseConditioner]): + super().__init__() + self.conditioners = conditioners + + def _impl_condition(self, y, metadata): + condition = [] + for conditioner in self.conditioners: + condition.append(conditioner._impl_condition(y, metadata)) + condition = torch.cat(condition, dim=1) + return condition + + def _impl_uncondition(self, y, metadata): + uncondition = [] + for conditioner in self.conditioners: + uncondition.append(conditioner._impl_uncondition(y, metadata)) + uncondition = torch.cat(uncondition, dim=1) + return uncondition \ No newline at end of file diff --git a/src/models/conditioner/class_label.py b/src/models/conditioner/class_label.py new file mode 100644 index 0000000000000000000000000000000000000000..dcfbf17690f291e93d26e5a8e8526f9085583e91 --- /dev/null +++ b/src/models/conditioner/class_label.py @@ -0,0 +1,13 @@ +import torch +from src.models.conditioner.base import BaseConditioner + +class LabelConditioner(BaseConditioner): + def __init__(self, num_classes): + super().__init__() + self.null_condition = num_classes + + def _impl_condition(self, y, metadata): + return torch.tensor(y).long().cuda() + + def _impl_uncondition(self, y, metadata): + return torch.full((len(y),), self.null_condition, dtype=torch.long).cuda() \ No newline at end of file diff --git a/src/models/conditioner/place_holder.py b/src/models/conditioner/place_holder.py new file mode 100644 index 0000000000000000000000000000000000000000..847e22a8681a00284811dbacb8d0f5007e26d6be --- /dev/null +++ b/src/models/conditioner/place_holder.py @@ -0,0 +1,14 @@ +import torch +from src.models.conditioner.base import BaseConditioner + +class PlaceHolderConditioner(BaseConditioner): + def __init__(self, null_class=1000): + super().__init__() + self.null_condition = null_class + + def _impl_condition(self, y, metadata): + y = torch.randint(0, self.null_condition, (len(y),)).cuda() + return y + + def _impl_uncondition(self, y, metadata): + return torch.full((len(y),), self.null_condition, dtype=torch.long).cuda() \ No newline at end of file diff --git a/src/models/conditioner/qwen3_text_encoder.py b/src/models/conditioner/qwen3_text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..22c1dc00ce276bde4e3686b0f7f535089ff9b5d1 --- /dev/null +++ b/src/models/conditioner/qwen3_text_encoder.py @@ -0,0 +1,36 @@ +import torch +import torch.nn as nn +from src.models.conditioner.base import BaseConditioner + +from transformers import Qwen3Model, Qwen2Tokenizer + + +class Qwen3TextEncoder(BaseConditioner): + def __init__(self, weight_path: str, embed_dim:int=None, max_length=128): + super().__init__() + self.tokenizer = Qwen2Tokenizer.from_pretrained(weight_path, max_length=max_length, padding_side="right") + # self.model = Qwen3Model.from_pretrained(weight_path, attn_implementation="flex_attention").to(torch.bfloat16) + self.model = Qwen3Model.from_pretrained(weight_path).to(torch.bfloat16) + self.model.compile() + self.uncondition_embedding = None + self.embed_dim = embed_dim + self.max_length = max_length + # torch._dynamo.config.optimize_ddp = False + + def _impl_condition(self, y, metadata:dict={}): + tokenized = self.tokenizer(y, truncation=True, 
max_length=self.max_length, padding="max_length", return_tensors="pt") + input_ids = tokenized.input_ids.cuda() + attention_mask = tokenized.attention_mask.cuda() + metadata["valid_length_y"] = torch.sum(attention_mask, dim=-1) + y = self.model(input_ids=input_ids, attention_mask=attention_mask)[0] + if y.shape[2] < self.embed_dim: + y = torch.cat([y, torch.zeros(y.shape[0], y.shape[1], self.embed_dim - y.shape[2]).to(y.device, y.dtype)], dim=-1) + if y.shape[2] > self.embed_dim: + y = y[:, :, :self.embed_dim] + return y + + def _impl_uncondition(self, y, metadata:dict=None): + if self.uncondition_embedding is not None: + return self.uncondition_embedding.repeat(len(y), 1, 1) + self.uncondition_embedding = self._impl_condition(["",]) + return self.uncondition_embedding.repeat(len(y), 1, 1) \ No newline at end of file diff --git a/src/models/encoder.py b/src/models/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..2184509b1a4826df8a180123372b380caeb915ae --- /dev/null +++ b/src/models/encoder.py @@ -0,0 +1,119 @@ +import copy +import torch +import torch.nn as nn +import timm +from torchvision.transforms import Normalize +from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.data.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD +import os + +class IndentityMapping(nn.Module): + def __init__(self): + super().__init__() + def forward(self, x, resize=True): + b, c, h, w = x.shape + x = x.reshape(b, c, h*w).transpose(1, 2) + return x + +class DINOv2(nn.Module): + def __init__(self, weight_path:str, base_patch_size=16): + super(DINOv2, self).__init__() + directory = os.path.dirname(weight_path) + weight_path = os.path.basename(weight_path) + self.encoder = torch.hub.load( + directory, + weight_path, + source="local", + skip_validation=True + ) + self.encoder = self.encoder.to(torch.bfloat16) + self.pos_embed = copy.deepcopy(self.encoder.pos_embed) + self.encoder.head = torch.nn.Identity() + self.patch_size = self.encoder.patch_embed.patch_size + self.precomputed_pos_embed = dict() + self.base_patch_size = base_patch_size + self.encoder.compile() + + @torch.autocast(device_type='cuda', dtype=torch.bfloat16) + def forward(self, x, resize=True): + b, c, h, w = x.shape + x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) + if resize: + x = torch.nn.functional.interpolate(x, (int(14*h/self.base_patch_size), int(14*w/self.base_patch_size)), mode='bicubic') + feature = self.encoder.forward_features(x)['x_norm_patchtokens'] + feature = feature.to(torch.bfloat16) + return feature + +from transformers import CLIPModel, CLIPTokenizer +class CLIP(nn.Module): + def __init__(self, weight_path:str): + super(CLIP, self).__init__() + self.model = CLIPModel.from_pretrained(weight_path).to(torch.bfloat16) + self.tokenizer = CLIPTokenizer.from_pretrained(weight_path) + self.height = self.model.config.vision_config.image_size + self.width = self.model.config.vision_config.image_size + + self.model.vision_model.compile() + self.model.text_model.compile() + def forward(self, x, text, resize=True): + tokens = self.tokenizer(text, truncation=True, return_tensors='pt', padding="max_length", max_length=self.tokenizer.model_max_length).input_ids.cuda() + text_output = self.model.text_model(input_ids=tokens).last_hidden_state + text_output = self.model.text_projection(text_output) + text_output = torch.nn.functional.normalize(text_output, dim=-1, p=2) + if resize: + x = torch.nn.functional.interpolate(x, (self.height, self.width), mode='bicubic') + x 
= Normalize(OPENAI_CLIP_MEAN, OPENAI_CLIP_STD)(x) + vision_output = self.model.vision_model(x).last_hidden_state[:, 1:] + vision_output = self.model.visual_projection(vision_output) + vision_output = torch.nn.functional.normalize(vision_output, dim=-1, p=2) + output = torch.bmm(vision_output, text_output.transpose(1, 2)) + return output + +from transformers import SiglipModel, GemmaTokenizer, SiglipTokenizer +class SigLIP(nn.Module): + def __init__(self, weight_path:str): + super(SigLIP, self).__init__() + if "siglip2" in weight_path: + self.tokenizer = GemmaTokenizer.from_pretrained(weight_path) + else: + self.tokenizer = SiglipTokenizer.from_pretrained(weight_path) + self.model = SiglipModel.from_pretrained(weight_path).to(torch.bfloat16) + + self.mean = 0.5 + self.std = 0.5 + + self.model.vision_model.compile() + self.model.text_model.compile() + def forward(self, x, text, resize=True): + tokens = self.tokenizer(text, truncation=True, return_tensors='pt', padding="max_length", max_length=64).input_ids.cuda() + text_output = self.model.text_model(input_ids=tokens).last_hidden_state + text_output = torch.nn.functional.normalize(text_output, dim=-1, p=2) + if resize: + x = torch.nn.functional.interpolate(x, (self.height, self.width), mode='bicubic') + x = (x - self.mean)/self.std + vision_output = self.model.vision_model(x).last_hidden_state + vision_output = torch.nn.functional.normalize(vision_output, dim=-1, p=2) + output = torch.bmm(vision_output, text_output.transpose(1, 2)) + return output + +from transformers import SiglipVisionModel +class SigLIPVision(nn.Module): + def __init__(self, weight_path:str, base_patch_size=16): + super(SigLIPVision, self).__init__() + self.model = SiglipVisionModel.from_pretrained(weight_path).to(torch.bfloat16) + self.height = self.model.config.image_size + self.width = self.model.config.image_size + self.patch_size = self.model.config.patch_size + self.base_patch_size = base_patch_size + self.model.compile() + self.mean = 0.5 + self.std = 0.5 + def forward(self, x, resize=True): + if resize: + h, w = x.shape[-2:] + new_h = int(self.patch_size * h / self.base_patch_size) + new_w = int(self.patch_size * w / self.base_patch_size) + x = torch.nn.functional.interpolate(x, (new_h, new_w), mode='bicubic') + x = (x - self.mean)/self.std + vision_output = self.model.vision_model(x).last_hidden_state + return vision_output \ No newline at end of file diff --git a/src/models/layers/adv_head.py b/src/models/layers/adv_head.py new file mode 100644 index 0000000000000000000000000000000000000000..bb1f54138e57962f572d9a7c28e9153b87975ff6 --- /dev/null +++ b/src/models/layers/adv_head.py @@ -0,0 +1,236 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +class ConvHead(nn.Module): + def __init__(self, in_channels, hidden_size): + super().__init__() + self.head = nn.Sequential( + nn.Conv2d(kernel_size=4, in_channels=in_channels, out_channels=hidden_size, stride=2, padding=1), # 16x16 -> 8x8 + nn.GroupNorm(num_groups=32, num_channels=hidden_size), + nn.SiLU(), + nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1), # 8x8 -> 4x4 + nn.GroupNorm(num_groups=32, num_channels=hidden_size), + nn.SiLU(), + nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),# 8x8 -> 4x4 + nn.GroupNorm(num_groups=32, num_channels=hidden_size), + nn.SiLU(), + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(kernel_size=1, in_channels=hidden_size, out_channels=1, stride=1, padding=0), # 
1x1 -> 1x1 + ) + + def forward(self, feature, text_embedding=None): + # assume sqrt image size + B, L, C = feature.shape + H = W = int(math.sqrt(L)) + feature = feature.permute(0, 2, 1) + feature = feature.view(B, C, H, W) + out = self.head(feature).sigmoid().clamp(0.01, 0.99) + return out + +class ConvLinearMMHead(nn.Module): + def __init__(self, im_channels, mm_channels, hidden_size): + super().__init__() + self.conv_head = nn.Sequential( + nn.Conv2d(kernel_size=4, in_channels=im_channels, out_channels=hidden_size, stride=2, padding=1), # 16x16 -> 8x8 + nn.GroupNorm(num_groups=32, num_channels=hidden_size), + nn.SiLU(), + nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1), # 8x8 -> 4x4 + nn.GroupNorm(num_groups=32, num_channels=hidden_size), + nn.SiLU(), + nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),# 8x8 -> 4x4 + nn.GroupNorm(num_groups=32, num_channels=hidden_size), + nn.SiLU(), + nn.AdaptiveAvgPool2d(1), + ) + self.linear_head = nn.Sequential( + nn.Linear(mm_channels, hidden_size), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size), + nn.SiLU(), + ) + self.out = nn.Linear(hidden_size*2, 1) + + def forward(self, im_feature, mm_feature=None): + # assume sqrt image size + B, L, C = im_feature.shape + H = W = int(math.sqrt(L)) + im_feature = im_feature.permute(0, 2, 1) + im_feature = im_feature.view(B, C, H, W) + im_out = self.conv_head(im_feature).view(B, -1) + mm_out = self.linear_head(mm_feature).view(B, -1) + out = self.out(torch.cat([im_out, mm_out], dim=-1)).sigmoid().clamp(0.01, 0.99) + return out + +class ConvMMHead(nn.Module): + def __init__(self, im_channels, mm_channels, hidden_size): + super().__init__() + self.conv1_head = nn.Sequential( + nn.Conv2d(kernel_size=4, in_channels=im_channels, out_channels=hidden_size, stride=2, padding=1), # 16x16 -> 8x8 + nn.GroupNorm(num_groups=32, num_channels=hidden_size), + nn.SiLU(), + nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1), # 8x8 -> 4x4 + nn.GroupNorm(num_groups=32, num_channels=hidden_size), + nn.SiLU(), + nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),# 8x8 -> 4x4 + nn.GroupNorm(num_groups=32, num_channels=hidden_size), + nn.SiLU(), + nn.AdaptiveAvgPool2d(1), + ) + self.conv2_head = nn.Sequential( + nn.Conv2d(kernel_size=4, in_channels=mm_channels, out_channels=hidden_size, stride=2, padding=1), + # 16x16 -> 8x8 + nn.GroupNorm(num_groups=32, num_channels=hidden_size), + nn.SiLU(), + nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1), + # 8x8 -> 4x4 + nn.GroupNorm(num_groups=32, num_channels=hidden_size), + nn.SiLU(), + nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1), + # 8x8 -> 4x4 + nn.GroupNorm(num_groups=32, num_channels=hidden_size), + nn.SiLU(), + nn.AdaptiveAvgPool2d(1), + ) + self.out = nn.Linear(hidden_size*2, 1) + + def forward(self, im_feature, mm_feature=None): + # assume sqrt image size + B, L, C = im_feature.shape + H = W = int(math.sqrt(L)) + im_feature = im_feature.permute(0, 2, 1) + im_feature = im_feature.view(B, C, H, W) + + B, Lmm, Cmm = mm_feature.shape + Hmm = Wmm = int(math.sqrt(Lmm)) + mm_feature = mm_feature.permute(0, 2, 1) + mm_feature = mm_feature.view(B, Cmm, Hmm, Wmm) + + im_out = self.conv1_head(im_feature).view(B, -1) + mm_out = self.conv2_head(mm_feature).view(B, -1) + out = self.out(torch.cat([im_out, 
mm_out], dim=-1)).sigmoid().clamp(0.01, 0.99) + return out + +# class ConvTextHead(nn.Module): +# def __init__(self, in_channels, text_channels, hidden_size): +# super().__init__() +# self.head = nn.Sequential( +# nn.Conv2d(kernel_size=4, in_channels=in_channels, out_channels=hidden_size, stride=2, padding=1), # 16x16 -> 8x8 +# nn.GroupNorm(num_groups=32, num_channels=hidden_size), +# nn.SiLU(), +# nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1), # 8x8 -> 4x4 +# nn.GroupNorm(num_groups=32, num_channels=hidden_size), +# nn.SiLU(), +# nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),# 8x8 -> 4x4 +# nn.GroupNorm(num_groups=32, num_channels=hidden_size), +# nn.SiLU(), +# nn.AdaptiveAvgPool2d(1), +# nn.Conv2d(kernel_size=1, in_channels=hidden_size, out_channels=hidden_size, stride=1, padding=0), # 1x1 -> 1x1 +# ) +# self.text_head = nn.Sequential( +# nn.Linear(text_channels, hidden_size), +# nn.SiLU(), +# nn.Linear(hidden_size, hidden_size), +# ) +# +# def forward(self, feature, text_embedding=None): +# # assume sqrt image size +# B, L, C = feature.shape +# H = W = int(math.sqrt(L)) +# feature = feature.permute(0, 2, 1) +# feature = feature.view(B, C, H, W) +# feature = self.head(feature).view(B, -1) +# text_embedding = torch.mean(text_embedding, dim=1, keepdim=False) +# text_embedding = self.text_head(text_embedding) +# logits = torch.sum(feature * text_embedding, dim=1, keepdim=False) +# score = logits.sigmoid().clamp(0.01, 0.99) +# return score +# +# class LinearHead(nn.Module): +# def __init__(self, in_channels, hidden_size): +# super().__init__() +# self.head = nn.Sequential( +# nn.Linear(in_channels, hidden_size), +# nn.SiLU(), +# nn.Linear(hidden_size, hidden_size), +# nn.SiLU(), +# nn.Linear(hidden_size, 1), +# ) +# def forward(self, feature, text_embedding=None): +# out = self.head(feature).sigmoid().clamp(0.01, 0.99) +# return out + + +# class ConvMultiModalHead(nn.Module): +# def __init__(self, in_channels, mm_channels, hidden_size): +# super().__init__() +# self.image_head = nn.Sequential( +# nn.Conv2d(kernel_size=4, in_channels=in_channels, out_channels=hidden_size, stride=2, padding=1), # 16x16 -> 8x8 +# nn.GroupNorm(num_groups=32, num_channels=hidden_size), +# nn.SiLU(), +# nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1), # 8x8 -> 4x4 +# nn.GroupNorm(num_groups=32, num_channels=hidden_size), +# nn.SiLU(), +# nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),# 8x8 -> 4x4 +# nn.GroupNorm(num_groups=32, num_channels=hidden_size), +# nn.SiLU(), +# nn.AdaptiveAvgPool2d(1), +# nn.Conv2d(kernel_size=1, in_channels=hidden_size, out_channels=1, stride=1, padding=0), # 1x1 -> 1x1 +# ) +# self.mm_head = nn.Sequential( +# nn.Linear(mm_channels, hidden_size), +# nn.SiLU(), +# nn.Linear(hidden_size, hidden_size), +# ) +# +# def forward(self, feature, text_embedding=None): +# # assume sqrt image size +# B, L, C = feature.shape +# H = W = int(math.sqrt(L)) +# feature = feature.permute(0, 2, 1) +# feature = feature.view(B, C, H, W) +# feature = self.head(feature).view(B, -1) +# text_embedding = torch.mean(text_embedding, dim=1, keepdim=False) +# text_embedding = self.text_head(text_embedding) +# logits = torch.sum(feature * text_embedding, dim=1, keepdim=False) +# score = logits.sigmoid().clamp(0.01, 0.99) +# return score + +# class TransformerTextHead(nn.Module): +# def __init__(self, in_channels, 
text_channels, hidden_size): +# super().__init__() +# +# self.transformer = nn.Sequential( +# nn.TransformerEncoderLayer(d_model=hidden_size, nhead=8, dim_feedforward=hidden_size, batch_first=True), +# nn.TransformerEncoderLayer(d_model=hidden_size, nhead=8, dim_feedforward=hidden_size, batch_first=True), +# nn.TransformerEncoderLayer(d_model=hidden_size, nhead=8, dim_feedforward=hidden_size, batch_first=True), +# nn.TransformerEncoderLayer(d_model=hidden_size, nhead=8, dim_feedforward=hidden_size, batch_first=True), +# ) +# self.text_head = nn.Sequential( +# nn.Linear(text_channels, hidden_size), +# nn.SiLU(), +# nn.Linear(hidden_size, hidden_size), +# ) +# self.feature_head = nn.Sequential( +# nn.Linear(in_channels, hidden_size), +# nn.SiLU(), +# nn.Linear(hidden_size, hidden_size), +# ) +# self.cls_head = nn.Sequential( +# nn.Linear(hidden_size, hidden_size), +# nn.SiLU(), +# nn.Linear(hidden_size, 1), +# ) +# +# def forward(self, feature, text_embedding=None): +# # assume sqrt image size +# feature = self.feature_head(feature) +# text_embedding = self.text_head(text_embedding) +# tokens = torch.cat([feature, text_embedding], dim=1) +# tokens = self.transformer(tokens) +# cls_token = tokens +# logits = self.cls_head(cls_token) +# logits = torch.mean(logits, dim=1, keepdim=False) +# score = logits.sigmoid().clamp(0.01, 0.99) +# return score diff --git a/src/models/layers/attention_op.py b/src/models/layers/attention_op.py new file mode 100644 index 0000000000000000000000000000000000000000..e4ebf05b7a0a213b1124c13676001a8471381626 --- /dev/null +++ b/src/models/layers/attention_op.py @@ -0,0 +1,7 @@ +import torch +import torch.nn as nn + +from torch.nn.functional import scaled_dot_product_attention as attention + + + diff --git a/src/models/layers/final_layer.py b/src/models/layers/final_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..bb22f7140b7f36d9c1665e4b887f89a1310ed7ad --- /dev/null +++ b/src/models/layers/final_layer.py @@ -0,0 +1,19 @@ +import torch.nn as nn + +def modulate(x, shift, scale): + return x * (1 + scale) + shift + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 2*hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x \ No newline at end of file diff --git a/src/models/layers/msdcn.py b/src/models/layers/msdcn.py new file mode 100644 index 0000000000000000000000000000000000000000..d92c457a68cb08270daa5726bd9b4d1c59b8ac78 --- /dev/null +++ b/src/models/layers/msdcn.py @@ -0,0 +1,302 @@ +import torch +import torch.nn as nn +import math + +from torch.nn.init import zeros_ +from typing import Any +from torch.autograd import Function +from torch.cuda.amp.autocast_mode import custom_bwd, custom_fwd + +import triton +import triton.language as tl + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE': 64,}, num_stages=1, num_warps=2), + # triton.Config({'BLOCK_SIZE': 64, }, num_stages=1, num_warps=1), + ], + key=['B', 'H', 'W', 'G', 'C', 'K'], +) +@triton.jit +def forward_kernel( + B: tl.constexpr, + H: tl.constexpr, # image_size_h + W: tl.constexpr, # image_size_w + G: tl.constexpr, # num_channels_per_group + C: tl.constexpr, # num_groups + 
K: tl.constexpr, # kernel size + input_ptr, # input features [B, H, W, G, C] + deformable_ptr, # deformable offsets [B, H, W, G, 2*K + K] + weights_ptr, # weights [B, H, W, G, K] + out_ptr, # out [B, H, W, G, C] + BLOCK_SIZE: tl.constexpr, # a micro block to process in the Group +): + pid = tl.program_id(0) + wid = pid % W + hid = pid // W % H + gid = pid // (W * H) % G + bid = pid // (W * H * G) + + id_mask = (hid < H) & (wid < W) & (gid < G) & (bid < B) + common_offset = bid*H*W*G + hid*W*G + wid*G + gid + batch_base = bid * H * W * G * C + + for block_base in tl.static_range(0, C, BLOCK_SIZE): + buffer = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32) + block_offset = tl.arange(0, BLOCK_SIZE) + block_base + block_mask = (block_offset < C) & id_mask + for k in tl.static_range(K): + deformable_offset = (common_offset * K + k) * 2 + + x = tl.load(deformable_ptr + deformable_offset, mask=id_mask, other=0.0) + wid + y = tl.load(deformable_ptr + deformable_offset + 1, mask=id_mask, other=0.0) + hid + + floor_x = x.to(tl.int32) + floor_y = y.to(tl.int32) + ceil_x = floor_x + 1 + ceil_y = floor_y + 1 + + # load top left + tl_weight = (ceil_x - x) * (ceil_y - y) + tl_block_offset = (batch_base + floor_y * W * G * C + floor_x * G * C + gid * C) #+ k * BLOCK_SIZE + tl_block_mask = (floor_y >= 0) & (floor_x >= 0) & (floor_x < W) & (floor_y < H) + + # load top right + tr_weight = (x - floor_x) * (ceil_y - y) + tr_block_offset = (batch_base + floor_y * W * G * C + ceil_x * G * C + gid * C) #+ k * BLOCK_SIZE + tr_block_mask = (floor_y >= 0) & (ceil_x < W) & (floor_y < H) & (ceil_x >= 0) + # load bottom left + bl_weight = (ceil_x - x) * (y - floor_y) + bl_block_offset = (batch_base + ceil_y * W * G * C + floor_x * G * C + gid * C) #+ k * BLOCK_SIZE + bl_block_mask = (ceil_y < H) & (ceil_y >= 0) & (floor_x < W) & (floor_x >= 0) + # load bottom right + br_weight = (x - floor_x) * (y - floor_y) + br_block_offset = (batch_base + ceil_y * W * G * C + ceil_x * G * C + gid * C) #+ k * BLOCK_SIZE + br_block_mask = (ceil_y < H) & (ceil_y >= 0) & (ceil_x < W) & (ceil_x >= 0) + + # load dynamic weight and mask + weights_offset = common_offset*K + k + weight = tl.load(weights_ptr + weights_offset, mask=id_mask, other=0.0) + + + + tl_block_input = tl.load(input_ptr + tl_block_offset + block_offset, mask=tl_block_mask & block_mask, other=0.0) + tl_block_input = tl_block_input * tl_weight + + # load top right + tr_block_input = tl.load(input_ptr + tr_block_offset + block_offset, mask=tr_block_mask & block_mask, other=0.0) + tr_block_input = tr_block_input * tr_weight + # load bottom left + bl_block_input = tl.load(input_ptr + bl_block_offset + block_offset, mask=bl_block_mask & block_mask, other=0.0) + bl_block_input = bl_block_input * bl_weight + # load bottom right + br_block_input = tl.load(input_ptr + br_block_offset + block_offset, mask=br_block_mask & block_mask, other=0.0) + br_block_input = br_block_input * br_weight + + # sampled + sampled_input = tl_block_input + tr_block_input + bl_block_input + br_block_input + + weighted_sampled_input = sampled_input * weight + buffer = buffer + weighted_sampled_input + # store to out_ptr + tl.store(out_ptr + common_offset*C + block_offset, buffer, mask=block_mask) + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE': 64,}, num_stages=1, num_warps=1), + triton.Config({'BLOCK_SIZE': 64,}, num_stages=1, num_warps=2), + ], + key=['B', 'H', 'W', 'G', 'C', 'K'], +) +@triton.jit +def backward_kernel( + B: tl.constexpr, + H: tl.constexpr, # image_size_h + W: tl.constexpr, 
# image_size_w + G: tl.constexpr, # num_groups + C: tl.constexpr, # num_channels_per_group + K: tl.constexpr, # kernel size + input_ptr, # input features [B, H, W, G, C] + deformable_ptr, # deformable offsets [B, H, W, G, K, 2] + weights_ptr, # weights [B, H, W, G, K] + grad_ptr, # out [B, H, W, G, C] + grad_input_ptr, # input features [B, H, W, G, C] + grad_deformable_ptr, # deformable offsets [B, H, W, G, K, 2] + grad_weights_ptr, # weights [B, H, W, G, K] + BLOCK_SIZE: tl.constexpr, # a micro block to process in the Group +): + + pid = tl.program_id(0) + wid = pid % W + hid = pid // W % H + gid = pid // (W * H) % G + bid = pid // (W * H * G) + + id_mask = (hid < H) & (wid < W) & (gid < G) & (bid < B) + + common_offset = bid*H*W*G + hid*W*G + wid*G + gid + batch_base = bid * H * W * G * C + for k in tl.static_range(K): + # load dynamic weight and mask + weights_offset = common_offset*K + k + weight = tl.load(weights_ptr + weights_offset, mask=id_mask, other=0.0) + dodx = tl.zeros((1,), dtype=grad_deformable_ptr.type.element_ty) + dody = tl.zeros((1,), dtype=grad_deformable_ptr.type.element_ty) + dodw = tl.zeros((1,), dtype=grad_weights_ptr.type.element_ty) + deformable_offset = (common_offset * K + k)*2 + x = tl.load(deformable_ptr + deformable_offset, mask=id_mask, other=0.0) + wid + y = tl.load(deformable_ptr + deformable_offset + 1, mask=id_mask, other=0.0) + hid + for block_base in tl.static_range(0, C, BLOCK_SIZE): + block_offset = tl.arange(0, BLOCK_SIZE) + block_base + block_mask = (block_offset < C) & id_mask + grad = tl.load(grad_ptr+common_offset*C + block_offset, mask=block_mask, other=0.0) + dods = weight*grad + + floor_x = x.to(tl.int32) + floor_y = y.to(tl.int32) + ceil_x = floor_x + 1 + ceil_y = floor_y + 1 + + # load top left + tl_weight = (ceil_x - x) * (ceil_y - y) + tl_block_offset = (batch_base + floor_y * W * G * C + floor_x * G * C + gid * C) + block_offset + tl_block_mask = ((floor_y >= 0) & (floor_x >= 0) & (floor_x < W) & (floor_y < H)) + tl_block_input = tl.load(input_ptr + tl_block_offset, mask=tl_block_mask & block_mask, other=0.0) + tl_block_input_dot_grad = tl.sum(tl_block_input*grad, axis=0) + dodx = dodx + -1 * tl_block_input_dot_grad * (ceil_y - y) + dody = dody + -1 * tl_block_input_dot_grad * (ceil_x - x) + dodw = dodw + tl_block_input_dot_grad * tl_weight + + dodtl = dods * tl_weight + tl.atomic_add(grad_input_ptr + tl_block_offset, mask=tl_block_mask & block_mask, val=dodtl) + + + # load top right + tr_weight = (x - floor_x) * (ceil_y - y) + tr_block_offset = (batch_base + floor_y * W * G * C + ceil_x * G * C + gid * C) + block_offset + tr_block_mask = ((floor_y >= 0) & (ceil_x < W) & (floor_y < H) & (ceil_x >= 0)) + tr_block_input = tl.load(input_ptr + tr_block_offset, mask=tr_block_mask & block_mask, other=0.0) + tr_block_input_dot_grad = tl.sum(tr_block_input*grad, axis=0) + dodx = dodx + 1 * tr_block_input_dot_grad * (ceil_y - y) + dody = dody + -1 * tr_block_input_dot_grad * (x - floor_x) + dodw = dodw + tr_block_input_dot_grad*tr_weight + + dodtr = dods * tr_weight + tl.atomic_add(grad_input_ptr + tr_block_offset, mask=tr_block_mask & block_mask, val=dodtr) + + + # load bottom left + bl_weight = (ceil_x - x) * (y - floor_y) + bl_block_offset = (batch_base + ceil_y * W * G * C + floor_x * G * C + gid * C) + block_offset + bl_block_mask = ((ceil_y < H) & (ceil_y >= 0) & (floor_x < W) & (floor_x >= 0)) + bl_block_input = tl.load(input_ptr + bl_block_offset, mask=bl_block_mask & block_mask, other=0.0) + bl_block_input_dot_grad = 
tl.sum(bl_block_input*grad, axis=0) + dodx = dodx + -1 * bl_block_input_dot_grad * (y - floor_y) + dody = dody + 1 * bl_block_input_dot_grad * (ceil_x - x) + dodw = dodw + bl_block_input_dot_grad*bl_weight + + dodbl = dods * bl_weight + tl.atomic_add(grad_input_ptr + bl_block_offset, mask=bl_block_mask & block_mask, val=dodbl) + + + # load bottom right + br_weight = (x - floor_x) * (y - floor_y) + br_block_offset = (batch_base + ceil_y * W * G * C + ceil_x * G * C + gid * C) + block_offset + br_block_mask = ((ceil_y < H) & (ceil_y >= 0) & (ceil_x < W) & (ceil_x >= 0)) + br_block_input = tl.load(input_ptr + br_block_offset, mask=br_block_mask & block_mask, other=0.0) + br_block_input_dot_grad = tl.sum(br_block_input*grad, axis=0)*br_block_mask + + dodx = dodx + 1 * br_block_input_dot_grad * (y - floor_y) + dody = dody + 1 * br_block_input_dot_grad * (x - floor_x) + dodw = dodw + br_block_input_dot_grad*br_weight + + dodbr = dods * br_weight + tl.atomic_add(grad_input_ptr + br_block_offset, mask=br_block_mask & block_mask, val=dodbr) + dodx = dodx * weight + dody = dody * weight + tl.store(grad_weights_ptr + weights_offset + tl.arange(0, 1), dodw, mask=id_mask) + tl.store(grad_deformable_ptr + deformable_offset + tl.arange(0, 1), dodx, mask=id_mask) + tl.store(grad_deformable_ptr + deformable_offset + 1 + tl.arange(0, 1), dody, mask=id_mask) + + +class DCNFunction(Function): + @staticmethod + @custom_fwd + def forward(ctx: Any, inputs, deformables, weights) -> Any: + B, H, W, G, C = inputs.shape + _, _, _, _, K, _ = deformables.shape + out = torch.zeros_like(inputs) + grid = lambda META: (B * H * W * G,) + forward_kernel[grid](B, H, W, G, C, K, inputs, deformables, weights, out) + ctx.save_for_backward(inputs, deformables, weights) + return out + + @staticmethod + @custom_bwd + def backward(ctx: Any, *grad_outputs: Any) -> Any: + grad_output = grad_outputs[0].contiguous() + inputs, deformables, weights = ctx.saved_tensors + B, H, W, G, C = inputs.shape + _, _, _, _, K, _ = deformables.shape + grad_inputs = torch.zeros_like(inputs) + grad_deformables = torch.zeros_like(deformables) + grad_weights = torch.zeros_like(weights) + grid = lambda META: (B * H * W * G,) + backward_kernel[grid]( + B, H, W, G, C, K, + inputs, + deformables, + weights, + grad_output, + grad_inputs, + grad_deformables, + grad_weights, + ) + return (grad_inputs, grad_deformables, grad_weights) + + +class MultiScaleDCN(nn.Module): + def __init__(self, in_channels, groups, channels, kernels, deformable_biass=True): + super().__init__() + self.in_channels = in_channels + self.groups = groups + self.channels = channels + self.kernels = kernels + self.v = nn.Linear(in_channels, groups * channels, bias=True) + self.qk_deformables = nn.Linear(in_channels, groups * kernels * 2, bias=True) + self.qk_scales = nn.Linear(in_channels, groups * kernels, bias=False) + self.qk_weights = nn.Linear(in_channels, groups*kernels, bias=True) + self.out = nn.Linear(groups * channels, in_channels) + self.deformables_prior = nn.Parameter(torch.randn((1, 1, 1, 1, kernels, 2)), requires_grad=False) + self.deformables_scale = nn.Parameter(torch.ones((1, 1, 1, groups, 1, 1)), requires_grad=True) + self.max_scale = 6 + self._init_weights() + def _init_weights(self): + zeros_(self.qk_deformables.weight.data) + zeros_(self.qk_scales.weight.data) + zeros_(self.qk_deformables.bias.data) + zeros_(self.qk_weights.weight.data) + zeros_(self.v.bias.data) + zeros_(self.out.bias.data) + num_prior = int(self.kernels ** 0.5) + dx = torch.linspace(-1, 1, 
num_prior, device="cuda") + dy = torch.linspace(-1, 1, num_prior, device="cuda") + dxy = torch.meshgrid([dx, dy], indexing="xy") + dxy = torch.stack(dxy, dim=-1) + dxy = dxy.view(-1, 2) + self.deformables_prior.data[..., :num_prior*num_prior, :] = dxy + for i in range(self.groups): + scale = (i+1)/self.groups - 0.0001 + inv_scale = math.log((scale)/(1-scale)) + self.deformables_scale.data[..., i, :, :] = inv_scale + def forward(self, x): + B, H, W, _ = x.shape + v = self.v(x).view(B, H, W, self.groups, self.channels) + deformables = self.qk_deformables(x).view(B, H, W, self.groups, self.kernels, 2) + scale = self.qk_scales(x).view(B, H, W, self.groups, self.kernels, 1) + self.deformables_scale + deformables = (deformables + self.deformables_prior ) * scale.sigmoid()*self.max_scale + weights = self.qk_weights(x).view(B, H, W, self.groups, self.kernels) + out = DCNFunction.apply(v, deformables, weights) + out = out.view(B, H, W, -1) + out = self.out(out) + return out \ No newline at end of file diff --git a/src/models/layers/patch_embed.py b/src/models/layers/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..a7dba7e2f947e1f58682c47bc3fb1e904f1936fe --- /dev/null +++ b/src/models/layers/patch_embed.py @@ -0,0 +1,22 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Embed(nn.Module): + def __init__( + self, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer = None, + bias: bool = True, + ): + super().__init__() + self.in_chans = in_chans + self.embed_dim = embed_dim + self.proj = nn.Linear(in_chans, embed_dim, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + def forward(self, x): + x = self.proj(x) + x = self.norm(x) + return x \ No newline at end of file diff --git a/src/models/layers/rmsnorm.py b/src/models/layers/rmsnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..6be8f020e70beb329fda68c70f06cfc1d4a99fd7 --- /dev/null +++ b/src/models/layers/rmsnorm.py @@ -0,0 +1,21 @@ +import torch +import torch.nn as nn + + +class _RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return (self.weight * hidden_states).to(input_dtype) + +RMSNorm = _RMSNorm \ No newline at end of file diff --git a/src/models/layers/rope.py b/src/models/layers/rope.py new file mode 100644 index 0000000000000000000000000000000000000000..10105d82eebcc16af34972eeb2a881c710d5850f --- /dev/null +++ b/src/models/layers/rope.py @@ -0,0 +1,69 @@ +from typing import Tuple +import torch + + +def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0): + # assert H * H == end + # flat_patch_pos = torch.linspace(-1, 1, end) # N = end + x_pos = torch.linspace(0, scale, width) + y_pos = torch.linspace(0, scale, height) + y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") + y_pos = y_pos.reshape(-1) + x_pos = x_pos.reshape(-1) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 + x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 + y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 
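+ # editor note (descriptive comment, not part of the original patch): torch.polar builds unit-magnitude complex rotations e^{i*theta} from the per-axis angles above; the x- and y-axis rotations are then stacked and flattened so every group of 4 head-dim channels shares one (x, y) frequency pair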
+ x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) + y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) + freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 + freqs_cis = freqs_cis.reshape(height*width, -1) + return freqs_cis + +def precompute_freqs_cis_ex2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=1.0): + if isinstance(scale, float): + scale = (scale, scale) + x_pos = torch.linspace(0, height*scale[0], width) + y_pos = torch.linspace(0, width*scale[1], height) + y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") + y_pos = y_pos.reshape(-1) + x_pos = x_pos.reshape(-1) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 + x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 + y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 + x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) + y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) + freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 + freqs_cis = freqs_cis.reshape(height*width, -1) + return freqs_cis + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + freqs_cis = freqs_cis[None, None, :, :] + # xq : B N H Hc + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2 + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + +def apply_rotary_emb_crossattention( + xq: torch.Tensor, + xk: torch.Tensor, + yk: torch.Tensor, + freqs_cis1: torch.Tensor, + freqs_cis2: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + freqs_cis1 = freqs_cis1[None, None, :, :] + freqs_cis2 = freqs_cis2[None, None, :, :] + # xq : B N H Hc + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2 + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + yk_ = torch.view_as_complex(yk.float().reshape(*yk.shape[:-1], -1, 2)) + xq_out = torch.view_as_real(xq_ * freqs_cis1).flatten(3) # B, N, H, Hc + xk_out = torch.view_as_real(xk_ * freqs_cis1).flatten(3) + yk_out = torch.view_as_real(yk_ * freqs_cis2).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk), yk_out.type_as(yk) \ No newline at end of file diff --git a/src/models/layers/swiglu.py b/src/models/layers/swiglu.py new file mode 100644 index 0000000000000000000000000000000000000000..90924737222e74d933e8928b7d8ad4a7ec817857 --- /dev/null +++ b/src/models/layers/swiglu.py @@ -0,0 +1,25 @@ +import torch +import torch.nn as nn + +class _SwiGLU(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + self.w12 = nn.Linear(dim, hidden_dim*2, bias=False) + self.w3 = nn.Linear(hidden_dim, dim, bias=False) + def forward(self, x): + x1, x2 = self.w12(x).chunk(2, dim=-1) + return self.w3(torch.nn.functional.silu(x1)*x2) + + +# try: +# from xformers.ops import SwiGLU as aa +# SwiGLU = SwiGLU +# print("use xformers swiglu") +# except: +# print("use slow swiglu") + +SwiGLU = _SwiGLU \ No newline at end of file diff --git a/src/models/layers/time_embed.py b/src/models/layers/time_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..ac983ebf6e37328f984223a9eda78479aa20632e --- /dev/null +++ b/src/models/layers/time_embed.py @@ -0,0 
+1,30 @@ +import math +import torch +import torch.nn as nn + +class TimestepEmbedder(nn.Module): + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[..., None].float() * freqs[None, ...] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding.to(t.dtype) + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb \ No newline at end of file diff --git a/src/models/transformer/pixnerd_c2i.py b/src/models/transformer/pixnerd_c2i.py new file mode 100644 index 0000000000000000000000000000000000000000..fd986c8ad1d7cdc3767419b9b8414e6e7a09eba3 --- /dev/null +++ b/src/models/transformer/pixnerd_c2i.py @@ -0,0 +1,384 @@ +import functools +from typing import Tuple +import torch +import torch.nn as nn +import math + +from functools import lru_cache +from torch.nn.functional import scaled_dot_product_attention + + +def modulate(x, shift, scale): + return x * (1 + scale) + shift + +class Embed(nn.Module): + def __init__( + self, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer = None, + bias: bool = True, + ): + super().__init__() + self.in_chans = in_chans + self.embed_dim = embed_dim + self.proj = nn.Linear(in_chans, embed_dim, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + def forward(self, x): + x = self.proj(x) + x = self.norm(x) + return x + +class TimestepEmbedder(nn.Module): + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[..., None].float() * freqs[None, ...] 
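+ # editor note (descriptive comment, not part of the original patch): this is the standard sinusoidal timestep embedding (cos/sin of t times geometrically spaced frequencies); note max_period defaults to 10 in this file rather than the usual 10000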
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + +class LabelEmbedder(nn.Module): + def __init__(self, num_classes, hidden_size): + super().__init__() + self.embedding_table = nn.Embedding(num_classes, hidden_size) + self.num_classes = num_classes + + def forward(self, labels,): + embeddings = self.embedding_table(labels) + return embeddings + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 2*hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + def forward(self, x): + x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) + return x + +def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0): + # assert H * H == end + # flat_patch_pos = torch.linspace(-1, 1, end) # N = end + x_pos = torch.linspace(0, scale, width) + y_pos = torch.linspace(0, scale, height) + y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") + y_pos = y_pos.reshape(-1) + x_pos = x_pos.reshape(-1) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 + x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 + y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 + x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) + y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) + freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 + freqs_cis = freqs_cis.reshape(height*width, -1) + return freqs_cis + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + freqs_cis = freqs_cis[None, :, None, :] + # xq : B N H Hc + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2 + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class RAttention(nn.Module): + def __init__( + self, + dim: int, + 
num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = True, + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: nn.Module = RMSNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc + q = self.q_norm(q) + k = self.k_norm(k) + q, k = apply_rotary_emb(q, k, freqs_cis=pos) + q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc + k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc + v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() + + x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + + +class FlattenDiTBlock(nn.Module): + def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): + super().__init__() + self.norm1 = RMSNorm(hidden_size, eps=1e-6) + self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False) + self.norm2 = RMSNorm(hidden_size, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp = FeedForward(hidden_size, mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c, pos, mask=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask) + x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + +class NerfEmbedder(nn.Module): + def __init__(self, in_channels, hidden_size_input, max_freqs): + super().__init__() + self.max_freqs = max_freqs + self.hidden_size_input = hidden_size_input + self.embedder = nn.Sequential( + nn.Linear(in_channels+max_freqs**2, hidden_size_input, bias=True), + ) + + @lru_cache + def fetch_pos(self, patch_size, device, dtype): + pos_x = torch.linspace(0, 1, patch_size, device=device, dtype=dtype) + pos_y = torch.linspace(0, 1, patch_size, device=device, dtype=dtype) + pos_y, pos_x = torch.meshgrid(pos_y, pos_x, indexing="ij") + pos_x = pos_x.reshape(-1, 1, 1) + pos_y = pos_y.reshape(-1, 1, 1) + + freqs = torch.linspace(0, self.max_freqs, self.max_freqs, dtype=dtype, device=device) + freqs_x = freqs[None, :, None] + freqs_y = freqs[None, None, :] + coeffs = (1 + freqs_x * freqs_y) ** -1 + dct_x = torch.cos(pos_x * freqs_x * torch.pi) + dct_y = torch.cos(pos_y * freqs_y * torch.pi) + dct = (dct_x * dct_y * coeffs).view(1, -1, self.max_freqs ** 2) + return dct + + + def forward(self, inputs): + B, P2, C = inputs.shape + patch_size = int(P2 ** 0.5) + device = inputs.device + dtype = inputs.dtype + dct = self.fetch_pos(patch_size, device, dtype) + dct = dct.repeat(B, 1, 1) + inputs = torch.cat([inputs, dct], dim=-1) + 
inputs = self.embedder(inputs) + return inputs + + +class NerfBlock(nn.Module): + def __init__(self, hidden_size_s, hidden_size_x, mlp_ratio=4): + super().__init__() + self.param_generator1 = nn.Sequential( + nn.Linear(hidden_size_s, 2*hidden_size_x**2*mlp_ratio, bias=True), + ) + self.norm = RMSNorm(hidden_size_x, eps=1e-6) + self.mlp_ratio = mlp_ratio + def forward(self, x, s): + batch_size, num_x, hidden_size_x = x.shape + mlp_params1 = self.param_generator1(s) + fc1_param1, fc2_param1 = mlp_params1.chunk(2, dim=-1) + fc1_param1 = fc1_param1.view(batch_size, hidden_size_x, hidden_size_x*self.mlp_ratio) + fc2_param1 = fc2_param1.view(batch_size, hidden_size_x*self.mlp_ratio, hidden_size_x) + + # normalize fc1 + normalized_fc1_param1 = torch.nn.functional.normalize(fc1_param1, dim=-2) + # normalize fc2 + normalized_fc2_param1 = torch.nn.functional.normalize(fc2_param1, dim=-2) + # mlp 1 + res_x = x + x = self.norm(x) + x = torch.bmm(x, normalized_fc1_param1) + x = torch.nn.functional.silu(x) + x = torch.bmm(x, normalized_fc2_param1) + x = x + res_x + return x + +class NerfFinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm = RMSNorm(hidden_size, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + def forward(self, x): + x = self.norm(x) + x = self.linear(x) + return x + +class PixNerDiT(nn.Module): + def __init__( + self, + in_channels=4, + num_groups=12, + hidden_size=1152, + hidden_size_x=64, + nerf_mlpratio=4, + num_blocks=18, + num_cond_blocks=4, + patch_size=2, + num_classes=1000, + learn_sigma=True, + deep_supervision=0, + weight_path=None, + load_ema=False, + ): + super().__init__() + self.deep_supervision = deep_supervision + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_groups = num_groups + self.num_blocks = num_blocks + self.num_cond_blocks = num_cond_blocks + self.patch_size = patch_size + self.x_embedder = NerfEmbedder(in_channels, hidden_size_x, max_freqs=8) + self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes+1, hidden_size) + + self.final_layer = NerfFinalLayer(hidden_size_x, self.out_channels) + + self.weight_path = weight_path + + self.load_ema = load_ema + self.blocks = nn.ModuleList([ + FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_cond_blocks) + ]) + self.blocks.extend([ + NerfBlock(self.hidden_size, hidden_size_x, nerf_mlpratio) for _ in range(self.num_cond_blocks, self.num_blocks) + ]) + self.initialize_weights() + self.precompute_pos = dict() + + def fetch_pos(self, height, width, device): + if (height, width) in self.precompute_pos: + return self.precompute_pos[(height, width)].to(device) + else: + pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) + self.precompute_pos[(height, width)] = pos + return pos + + def initialize_weights(self): + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.s_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.s_embedder.proj.bias, 0) + + # Initialize label embedding table: + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + 
# zero init final layer + nn.init.zeros_(self.final_layer.linear.weight) + nn.init.zeros_(self.final_layer.linear.bias) + + + def forward(self, x, t, y, s=None, mask=None): + B, _, H, W = x.shape + pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device) + x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) + t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) + y = self.y_embedder(y).view(B, 1, self.hidden_size) + c = nn.functional.silu(t + y) + if s is None: + s = self.s_embedder(x) + for i in range(self.num_cond_blocks): + s = self.blocks[i](s, c, pos, mask) + s = nn.functional.silu(t + s) + batch_size, length, _ = s.shape + x = x.reshape(batch_size*length, self.in_channels, self.patch_size**2) + x = x.transpose(1, 2) + s = s.view(batch_size*length, self.hidden_size) + x = self.x_embedder(x) + for i in range(self.num_cond_blocks, self.num_blocks): + x = self.blocks[i](x, s) + x = self.final_layer(x) + x = x.transpose(1, 2) + x = x.reshape(batch_size, length, -1) + x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size) + return x \ No newline at end of file diff --git a/src/models/transformer/pixnerd_t2i.py b/src/models/transformer/pixnerd_t2i.py new file mode 100644 index 0000000000000000000000000000000000000000..153091692bf2be2cef96dfc1015688a090e89ad8 --- /dev/null +++ b/src/models/transformer/pixnerd_t2i.py @@ -0,0 +1,310 @@ +import torch +import torch.nn as nn + +from functools import lru_cache +from src.models.layers.attention_op import attention +from src.models.layers.rope import apply_rotary_emb, precompute_freqs_cis_ex2d as precompute_freqs_cis_2d +from src.models.layers.time_embed import TimestepEmbedder as TimestepEmbedder +from src.models.layers.patch_embed import Embed as Embed +from src.models.layers.swiglu import SwiGLU as FeedForward +from src.models.layers.rmsnorm import RMSNorm as Norm + +def modulate(x, shift, scale): + return x * (1 + scale) + shift + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + attn_drop: float = 0., + proj_drop: float = 0., + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.qkv_x = nn.Linear(dim, dim*3, bias=qkv_bias) + self.kv_y = nn.Linear(dim, dim*2, bias=qkv_bias) + + self.q_norm = Norm(self.head_dim) + self.k_norm = Norm(self.head_dim) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor, y, pos) -> torch.Tensor: + B, N, C = x.shape + qkv_x = self.qkv_x(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, kx, vx = qkv_x[0], qkv_x[1], qkv_x[2] + q = self.q_norm(q.contiguous()) + kx = self.k_norm(kx.contiguous()) + q, kx = apply_rotary_emb(q, kx, freqs_cis=pos) + kv_y = self.kv_y(y).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + ky, vy = kv_y[0], kv_y[1] + ky = self.k_norm(ky.contiguous()) + + k = torch.cat([kx, ky], dim=2) + v = torch.cat([vx, vy], dim=2) + + q = q.view(B, self.num_heads, -1, C // self.num_heads) # B, H, N, Hc + k = k.view(B, self.num_heads, -1, C // self.num_heads).contiguous() # B, H, N, Hc + v = v.view(B, self.num_heads, -1, C // self.num_heads).contiguous() + + x = 
attention(q, k, v) + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + +class FlattenDiTBlock(nn.Module): + def __init__(self, hidden_size, groups, mlp_ratio=4, ): + super().__init__() + self.norm1 = Norm(hidden_size, eps=1e-6) + self.attn = Attention(hidden_size, num_heads=groups, qkv_bias=False) + self.norm2 = Norm(hidden_size, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp = FeedForward(hidden_size, mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, y, c, pos): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), y, pos) + x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + +class NerfEmbedder(nn.Module): + def __init__(self, in_channels, hidden_size_input, max_freqs): + super().__init__() + self.max_freqs = max_freqs + self.hidden_size_input = hidden_size_input + self.embedder = nn.Sequential( + nn.Linear(in_channels+max_freqs**2, hidden_size_input, bias=True), + ) + + @lru_cache + def fetch_pos(self, patch_size, device, dtype): + pos = precompute_freqs_cis_2d(self.max_freqs ** 2 * 2, patch_size, patch_size) + pos = pos[None, :, :].to(device=device, dtype=dtype) + return pos + + + def forward(self, inputs): + B, P2, C = inputs.shape + patch_size = int(P2 ** 0.5) + device = inputs.device + dtype = inputs.dtype + dct = self.fetch_pos(patch_size, device, dtype) + dct = dct.repeat(B, 1, 1) + inputs = torch.cat([inputs, dct], dim=-1) + inputs = self.embedder(inputs) + return inputs + +class NerfBlock(nn.Module): + def __init__(self, hidden_size_s, hidden_size_x, mlp_ratio=4): + super().__init__() + self.param_generator1 = nn.Sequential( + nn.Linear(hidden_size_s, 2*hidden_size_x**2*mlp_ratio, bias=True), + ) + self.norm = Norm(hidden_size_x, eps=1e-6) + self.mlp_ratio = mlp_ratio + def forward(self, x, s): + batch_size, num_x, hidden_size_x = x.shape + mlp_params1 = self.param_generator1(s) + fc1_param1, fc2_param1 = mlp_params1.chunk(2, dim=-1) + fc1_param1 = fc1_param1.view(batch_size, hidden_size_x, hidden_size_x*self.mlp_ratio) + fc2_param1 = fc2_param1.view(batch_size, hidden_size_x*self.mlp_ratio, hidden_size_x) + + # normalize fc1 + normalized_fc1_param1 = torch.nn.functional.normalize(fc1_param1, dim=-2) + # mlp 1 + res_x = x + x = self.norm(x) + x = torch.bmm(x, normalized_fc1_param1) + x = torch.nn.functional.silu(x) + x = torch.bmm(x, fc2_param1) + x = x + res_x + return x + +class NerfFinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels): + super().__init__() + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + def forward(self, x): + x = self.linear(x) + return x + +class TextRefineAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + attn_drop: float = 0., + proj_drop: float = 0., + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.qkv = nn.Linear(dim, dim*3, bias=qkv_bias) + self.q_norm = Norm(self.head_dim) + self.k_norm = Norm(self.head_dim) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor) -> 
torch.Tensor: + B, N, C = x.shape + qkv_x = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv_x[0], qkv_x[1], qkv_x[2] + q = self.q_norm(q) + k = self.k_norm(k) + q = q.view(B, self.num_heads, -1, C // self.num_heads) # B, H, N, Hc + k = k.view(B, self.num_heads, -1, C // self.num_heads).contiguous() # B, H, N, Hc + v = v.view(B, self.num_heads, -1, C // self.num_heads).contiguous() + x = attention(q, k, v) + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + +class TextRefineBlock(nn.Module): + def __init__(self, hidden_size, groups, mlp_ratio=4, ): + super().__init__() + self.norm1 = Norm(hidden_size, eps=1e-6) + self.attn = TextRefineAttention(hidden_size, num_heads=groups, qkv_bias=False) + self.norm2 = Norm(hidden_size, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp = FeedForward(hidden_size, mlp_hidden_dim) + + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) + x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class PixNerDiT(nn.Module): + def __init__( + self, + in_channels=4, + num_groups=12, + hidden_size=1152, + decoder_hidden_size=64, + num_encoder_blocks=18, + num_decoder_blocks=4, + num_text_blocks=4, + patch_size=2, + txt_embed_dim=1024, + txt_max_length=100, + weight_path=None, + load_ema=False, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_groups = num_groups + self.decoder_hidden_size = decoder_hidden_size + self.num_encoder_blocks = num_encoder_blocks + self.num_decoder_blocks = num_decoder_blocks + self.num_blocks = self.num_encoder_blocks + self.num_decoder_blocks + self.num_text_blocks = num_text_blocks + self.patch_size = patch_size + self.txt_embed_dim = txt_embed_dim + self.txt_max_length = txt_max_length + self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) + self.x_embedder = NerfEmbedder(in_channels, decoder_hidden_size, max_freqs=8) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = Embed(txt_embed_dim, hidden_size, bias=True, norm_layer=Norm) + self.y_pos_embedding = torch.nn.Parameter( + torch.randn(1, txt_max_length, hidden_size), + requires_grad=True + ) + self.final_layer = NerfFinalLayer(decoder_hidden_size, in_channels) + encoder_blocks = nn.ModuleList([ + FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_encoder_blocks) + ]) + decoder_blocks = nn.ModuleList([ + NerfBlock(self.hidden_size, self.decoder_hidden_size, mlp_ratio=2) for _ in range(self.num_decoder_blocks) + ]) + self.blocks = nn.ModuleList(encoder_blocks + decoder_blocks) + self.text_refine_blocks = nn.ModuleList([ + TextRefineBlock(self.hidden_size, self.num_groups) for _ in range(self.num_text_blocks) + ]) + self.initialize_weights() + self.precompute_pos = dict() + self.weight_path = weight_path + self.load_ema = load_ema + + def fetch_pos(self, height, width, device): + if (height, width) in self.precompute_pos: + return self.precompute_pos[(height, width)].to(device) + else: + pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) + self.precompute_pos[(height, width)] = pos + 
return pos + + def initialize_weights(self): + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.s_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.s_embedder.proj.bias, 0) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def forward(self, x, t, y): + B, _, H, W = x.shape + x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) + xpos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device) + ypos = self.y_pos_embedding + t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) + y = self.y_embedder(y).view(B, -1, self.hidden_size) + ypos.to(y.dtype) + + condition = nn.functional.silu(t) + for i, block in enumerate(self.text_refine_blocks): + y = block(y, condition) + + s = self.s_embedder(x) + for i in range(self.num_encoder_blocks): + s = self.blocks[i](s, y, condition, xpos) + + s = torch.nn.functional.silu(t + s) + batch_size, length, _ = s.shape + x = x.reshape(batch_size * length, self.in_channels, self.patch_size ** 2 ) + x = x.transpose(1, 2) + s = s.view(batch_size * length, self.hidden_size) + x = self.x_embedder(x) + + for i in range(self.num_decoder_blocks): + def checkpoint_forward(x, s, block=self.blocks[i + self.num_encoder_blocks]): + return block(x, s) + x = checkpoint_forward(x, s) + x = self.final_layer(x) + x = x.transpose(1, 2) + x = x.reshape(batch_size, length, -1) + x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), + (H, W), + kernel_size=self.patch_size, + stride=self.patch_size) + return x \ No newline at end of file diff --git a/src/models/transformer/pixnerd_t2i_heavydecoder.py b/src/models/transformer/pixnerd_t2i_heavydecoder.py new file mode 100644 index 0000000000000000000000000000000000000000..eb8ff6fbe9d4fde18173791dbb98d9f7d01417c1 --- /dev/null +++ b/src/models/transformer/pixnerd_t2i_heavydecoder.py @@ -0,0 +1,317 @@ +import torch +import torch.nn as nn + +from functools import lru_cache +from src.models.layers.attention_op import attention +from src.models.layers.rope import apply_rotary_emb, precompute_freqs_cis_ex2d as precompute_freqs_cis_2d +from src.models.layers.time_embed import TimestepEmbedder as TimestepEmbedder +from src.models.layers.patch_embed import Embed as Embed +from src.models.layers.swiglu import SwiGLU as FeedForward +from src.models.layers.rmsnorm import RMSNorm as Norm + +def modulate(x, shift, scale): + return x * (1 + scale) + shift + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + attn_drop: float = 0., + proj_drop: float = 0., + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.qkv_x = nn.Linear(dim, dim*3, bias=qkv_bias) + self.kv_y = nn.Linear(dim, dim*2, bias=qkv_bias) + + self.q_norm = Norm(self.head_dim) + self.k_norm = Norm(self.head_dim) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor, y, pos) -> torch.Tensor: + B, N, C = x.shape + qkv_x = self.qkv_x(x).reshape(B, N, 3, 
self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, kx, vx = qkv_x[0], qkv_x[1], qkv_x[2] + q = self.q_norm(q.contiguous()) + kx = self.k_norm(kx.contiguous()) + q, kx = apply_rotary_emb(q, kx, freqs_cis=pos) + kv_y = self.kv_y(y).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + ky, vy = kv_y[0], kv_y[1] + ky = self.k_norm(ky.contiguous()) + + k = torch.cat([kx, ky], dim=2) + v = torch.cat([vx, vy], dim=2) + + q = q.view(B, self.num_heads, -1, C // self.num_heads) # B, H, N, Hc + k = k.view(B, self.num_heads, -1, C // self.num_heads).contiguous() # B, H, N, Hc + v = v.view(B, self.num_heads, -1, C // self.num_heads).contiguous() + + x = attention(q, k, v) + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + +class FlattenDiTBlock(nn.Module): + def __init__(self, hidden_size, groups, mlp_ratio=4, ): + super().__init__() + self.norm1 = Norm(hidden_size, eps=1e-6) + self.attn = Attention(hidden_size, num_heads=groups, qkv_bias=False) + self.norm2 = Norm(hidden_size, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp = FeedForward(hidden_size, mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, y, c, pos): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), y, pos) + x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + +class NerfEmbedder(nn.Module): + def __init__(self, in_channels, hidden_size_input, max_freqs): + super().__init__() + self.max_freqs = max_freqs + self.hidden_size_input = hidden_size_input + self.embedder = nn.Sequential( + nn.Linear(in_channels+max_freqs**2, hidden_size_input, bias=True), + ) + + @lru_cache + def fetch_pos(self, patch_size_h, patch_size_w, device, dtype): + pos = precompute_freqs_cis_2d(self.max_freqs ** 2 * 2, patch_size_h, patch_size_w, scale=(16/patch_size_h, 16/patch_size_w)) + pos = pos[None, :, :].to(device=device, dtype=dtype) + return pos + + + def forward(self, inputs, patch_size_h, patch_size_w): + B, _, C = inputs.shape + device = inputs.device + dtype = inputs.dtype + dct = self.fetch_pos(patch_size_h, patch_size_w, device, dtype) + dct = dct.repeat(B, 1, 1) + inputs = torch.cat([inputs, dct], dim=-1) + inputs = self.embedder(inputs) + return inputs + +class NerfBlock(nn.Module): + def __init__(self, hidden_size_s, hidden_size_x, mlp_ratio=4): + super().__init__() + self.param_generator1 = nn.Sequential( + nn.Linear(hidden_size_s, 2*hidden_size_x**2*mlp_ratio, bias=True), + ) + self.norm = Norm(hidden_size_x, eps=1e-6) + self.mlp_ratio = mlp_ratio + def forward(self, x, s): + batch_size, num_x, hidden_size_x = x.shape + mlp_params1 = self.param_generator1(s) + fc1_param1, fc2_param1 = mlp_params1.chunk(2, dim=-1) + fc1_param1 = fc1_param1.view(batch_size, hidden_size_x, hidden_size_x*self.mlp_ratio) + fc2_param1 = fc2_param1.view(batch_size, hidden_size_x*self.mlp_ratio, hidden_size_x) + + # normalize fc1 + normalized_fc1_param1 = torch.nn.functional.normalize(fc1_param1, dim=-2) + # mlp 1 + res_x = x + x = self.norm(x) + x = torch.bmm(x, normalized_fc1_param1) + x = torch.nn.functional.silu(x) + x = torch.bmm(x, fc2_param1) + x = x + res_x + return x + +class NerfFinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels): + super().__init__() + self.linear = 
nn.Linear(hidden_size, out_channels, bias=True) + def forward(self, x): + x = self.linear(x) + return x + +class TextRefineAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + attn_drop: float = 0., + proj_drop: float = 0., + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.qkv = nn.Linear(dim, dim*3, bias=qkv_bias) + self.q_norm = Norm(self.head_dim) + self.k_norm = Norm(self.head_dim) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, N, C = x.shape + qkv_x = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv_x[0], qkv_x[1], qkv_x[2] + q = self.q_norm(q) + k = self.k_norm(k) + q = q.view(B, self.num_heads, -1, C // self.num_heads) # B, H, N, Hc + k = k.view(B, self.num_heads, -1, C // self.num_heads).contiguous() # B, H, N, Hc + v = v.view(B, self.num_heads, -1, C // self.num_heads).contiguous() + x = attention(q, k, v) + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + +class TextRefineBlock(nn.Module): + def __init__(self, hidden_size, groups, mlp_ratio=4, ): + super().__init__() + self.norm1 = Norm(hidden_size, eps=1e-6) + self.attn = TextRefineAttention(hidden_size, num_heads=groups, qkv_bias=False) + self.norm2 = Norm(hidden_size, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp = FeedForward(hidden_size, mlp_hidden_dim) + + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) + x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class PixNerDiT(nn.Module): + def __init__( + self, + in_channels=4, + num_groups=12, + hidden_size=1152, + decoder_hidden_size=64, + num_encoder_blocks=18, + num_decoder_blocks=4, + num_text_blocks=4, + patch_size=2, + txt_embed_dim=1024, + txt_max_length=100, + weight_path=None, + load_ema=False, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_groups = num_groups + self.decoder_hidden_size = decoder_hidden_size + self.num_encoder_blocks = num_encoder_blocks + self.num_decoder_blocks = num_decoder_blocks + self.num_blocks = self.num_encoder_blocks + self.num_decoder_blocks + self.num_text_blocks = num_text_blocks + self.decoder_patch_scaling_h = 1.0 + self.decoder_patch_scaling_w = 1.0 + self.patch_size = patch_size + self.txt_embed_dim = txt_embed_dim + self.txt_max_length = txt_max_length + self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) + self.x_embedder = NerfEmbedder(in_channels, decoder_hidden_size, max_freqs=8) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = Embed(txt_embed_dim, hidden_size, bias=True, norm_layer=Norm) + self.y_pos_embedding = torch.nn.Parameter( + torch.randn(1, txt_max_length, hidden_size), + requires_grad=True + ) + self.final_layer = NerfFinalLayer(decoder_hidden_size, in_channels) + encoder_blocks = nn.ModuleList([ + 
FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_encoder_blocks) + ]) + decoder_blocks = nn.ModuleList([ + NerfBlock(self.hidden_size, self.decoder_hidden_size, mlp_ratio=2) for _ in range(self.num_decoder_blocks) + ]) + self.blocks = nn.ModuleList(encoder_blocks + decoder_blocks) + self.text_refine_blocks = nn.ModuleList([ + TextRefineBlock(self.hidden_size, self.num_groups) for _ in range(self.num_text_blocks) + ]) + self.initialize_weights() + self.precompute_pos = dict() + self.weight_path = weight_path + self.load_ema = load_ema + + def fetch_pos(self, height, width, device): + if (height, width) in self.precompute_pos: + return self.precompute_pos[(height, width)].to(device) + else: + pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) + self.precompute_pos[(height, width)] = pos + return pos + + def initialize_weights(self): + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.s_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.s_embedder.proj.bias, 0) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def forward(self, x, t, y): + B, _, H, W = x.shape + encoder_h, encoder_w = int(H/self.decoder_patch_scaling_h), int(W/self.decoder_patch_scaling_w) + decoder_patch_size_h = int(self.patch_size * self.decoder_patch_scaling_h) + decoder_patch_size_w = int(self.patch_size * self.decoder_patch_scaling_w) + x_for_encoder = torch.nn.functional.interpolate(x, (encoder_h, encoder_w)) + + x_for_encoder = torch.nn.functional.unfold(x_for_encoder, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) + x_for_decoder = torch.nn.functional.unfold(x, kernel_size=(decoder_patch_size_h, decoder_patch_size_w), stride=(decoder_patch_size_h, decoder_patch_size_w)).transpose(1, 2) + xpos = self.fetch_pos(encoder_h // self.patch_size, encoder_w // self.patch_size, x.device) + ypos = self.y_pos_embedding + t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) + y = self.y_embedder(y).view(B, -1, self.hidden_size) + ypos.to(y.dtype) + + condition = nn.functional.silu(t) + for i, block in enumerate(self.text_refine_blocks): + y = block(y, condition) + + + s = self.s_embedder(x_for_encoder) + for i in range(self.num_encoder_blocks): + s = self.blocks[i](s, y, condition, xpos) + + s = torch.nn.functional.silu(t + s) + batch_size, length, _ = s.shape + x = x_for_decoder.reshape(batch_size * length, self.in_channels, decoder_patch_size_h * decoder_patch_size_w) + x = x.transpose(1, 2) + s = s.view(batch_size * length, self.hidden_size) + x = self.x_embedder(x, decoder_patch_size_h, decoder_patch_size_w) + + for i in range(self.num_decoder_blocks): + def checkpoint_forward(x, s, block=self.blocks[i + self.num_encoder_blocks]): + return block(x, s) + x = checkpoint_forward(x, s) + x = self.final_layer(x) + x = x.transpose(1, 2) + x = x.reshape(batch_size, length, -1) + x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), + (H, W), + kernel_size=(decoder_patch_size_h, decoder_patch_size_w), + stride=(decoder_patch_size_h, decoder_patch_size_w)) + return x \ No newline at end of file diff --git a/src/plugins/__init__.py b/src/plugins/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/plugins/bd_env.py b/src/plugins/bd_env.py new file mode 100644 index 0000000000000000000000000000000000000000..c1900e9c34422e58cf14dedf285a4e162cee62db --- /dev/null +++ b/src/plugins/bd_env.py @@ -0,0 +1,70 @@ +import torch +import os +import socket +from typing_extensions import override +from lightning.fabric.utilities.rank_zero import rank_zero_only +from lightning.fabric.plugins.environments.lightning import LightningEnvironment + + +class BDEnvironment(LightningEnvironment): + pass + # def __init__(self) -> None: + # super().__init__() + # self._global_rank: int = 0 + # self._world_size: int = 1 + # + # @property + # @override + # def creates_processes_externally(self) -> bool: + # """Returns whether the cluster creates the processes or not. + # + # If at least :code:`LOCAL_RANK` is available as environment variable, Lightning assumes the user acts as the + # process launcher/job scheduler and Lightning will not launch new processes. + # + # """ + # return "LOCAL_RANK" in os.environ + # + # @staticmethod + # @override + # def detect() -> bool: + # assert "ARNOLD_WORKER_0_HOST" in os.environ.keys() + # assert "ARNOLD_WORKER_0_PORT" in os.environ.keys() + # return True + # + # @override + # def world_size(self) -> int: + # return self._world_size + # + # @override + # def set_world_size(self, size: int) -> None: + # self._world_size = size + # + # @override + # def global_rank(self) -> int: + # return self._global_rank + # + # @override + # def set_global_rank(self, rank: int) -> None: + # self._global_rank = rank + # rank_zero_only.rank = rank + # + # @override + # def local_rank(self) -> int: + # return int(os.environ.get("LOCAL_RANK", 0)) + # + # @override + # def node_rank(self) -> int: + # return int(os.environ.get("ARNOLD_ID")) + # + # @override + # def teardown(self) -> None: + # if "WORLD_SIZE" in os.environ: + # del os.environ["WORLD_SIZE"] + # + # @property + # def main_address(self) -> str: + # return os.environ.get("ARNOLD_WORKER_0_HOST") + # + # @property + # def main_port(self) -> int: + # return int(os.environ.get("ARNOLD_WORKER_0_PORT")) diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/utils/copy.py b/src/utils/copy.py new file mode 100644 index 0000000000000000000000000000000000000000..62cd89da7ffd0f3b65fd0206b9c646f8df5e64c4 --- /dev/null +++ b/src/utils/copy.py @@ -0,0 +1,13 @@ +import torch + +@torch.no_grad() +def copy_params(src_model, dst_model): + for src_param, dst_param in zip(src_model.parameters(), dst_model.parameters()): + dst_param.data.copy_(src_param.data) + +@torch.no_grad() +def swap_tensors(tensor1, tensor2): + tmp = torch.empty_like(tensor1) + tmp.copy_(tensor1) + tensor1.copy_(tensor2) + tensor2.copy_(tmp) \ No newline at end of file diff --git a/src/utils/model_loader.py b/src/utils/model_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..68c516cdc659b79775db5a234a9888d5bc5411b1 --- /dev/null +++ b/src/utils/model_loader.py @@ -0,0 +1,27 @@ +from typing import Dict, Any, Optional + +import torch +import torch.nn as nn + + +import logging +logger = logging.getLogger(__name__) + +class ModelLoader: + def __init__(self,): + super().__init__() + + def load(self, denoiser): + if denoiser.weight_path: + weight = torch.load(denoiser.weight_path, map_location=torch.device('cpu')) + + if 
denoiser.load_ema: + prefix = "ema_denoiser." + else: + prefix = "denoiser." + for k, v in denoiser.state_dict().items(): + try: + v.copy_(weight["state_dict"][prefix+k]) + except Exception: + logger.warning(f"Failed to copy {prefix+k} to denoiser weight") + return denoiser \ No newline at end of file diff --git a/src/utils/no_grad.py b/src/utils/no_grad.py new file mode 100644 index 0000000000000000000000000000000000000000..7688c8ea47c51feeb2a574ae7cbb504743424f3c --- /dev/null +++ b/src/utils/no_grad.py @@ -0,0 +1,17 @@ +import torch + +@torch.no_grad() +def no_grad(net): + assert net is not None, "net is None" + for param in net.parameters(): + param.requires_grad = False + net.eval() + return net + +@torch.no_grad() +def filter_nograd_tensors(params_list): + filtered_params_list = [] + for param in params_list: + if param.requires_grad: + filtered_params_list.append(param) + return filtered_params_list \ No newline at end of file diff --git a/src/utils/patch_bugs.py b/src/utils/patch_bugs.py new file mode 100644 index 0000000000000000000000000000000000000000..815c1d12158aff6604c52dc5ec548cd3629c5a2c --- /dev/null +++ b/src/utils/patch_bugs.py @@ -0,0 +1,18 @@ +import torch +import lightning.pytorch.loggers.wandb as wandb + +setattr(wandb, '_WANDB_AVAILABLE', True) +torch.set_float32_matmul_precision('medium') + +import logging +logger = logging.getLogger("wandb") +logger.setLevel(logging.WARNING) + +import os +os.environ["NCCL_DEBUG"] = "WARN" +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' + +import warnings +warnings.simplefilter(action='ignore', category=FutureWarning) +warnings.simplefilter(action='ignore', category=UserWarning) +
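Note (not part of the patch): the PixNerDiT forward passes above patchify pixel space with torch.nn.functional.unfold and reassemble it with fold, so kernel_size == stride (non-overlapping patches) must hold for the round trip to be lossless. A minimal sketch of that convention, using the default in_channels=4 and patch_size=2; the tensor names and sizes here are illustrative only.

import torch
import torch.nn.functional as F

B, C, H, W, p = 2, 4, 32, 32, 2          # batch, channels, height, width, patch size
x = torch.randn(B, C, H, W)

# unfold -> (B, C*p*p, L) with L = (H/p)*(W/p); transpose to token-major (B, L, C*p*p)
tokens = F.unfold(x, kernel_size=p, stride=p).transpose(1, 2)
assert tokens.shape == (B, (H // p) * (W // p), C * p * p)

# each token can be viewed as p*p pixels of C channels, as the Nerf decoder path does
pixels = tokens.reshape(B * tokens.shape[1], C, p * p).transpose(1, 2)   # (B*L, p*p, C)

# fold inverts unfold exactly because the patches do not overlap (kernel_size == stride)
recon = F.fold(tokens.transpose(1, 2), (H, W), kernel_size=p, stride=p)
assert torch.allclose(x, recon)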