| import numpy as np | |
| import os | |
| import PIL | |
| import pickle | |
| import torch | |
| import argparse | |
| import json | |
| from PIL import Image | |
| import torch.nn as nn | |
| import torch | |
| from transformers import AutoProcessor, AutoModelForImageTextToText | |
| from src.data_loader import DataCollatorForSupervisedDataset, get_dataset | |
| from src.data_processing import tensor_to_pil | |
| from src.model_processing import get_model | |
| from PIL import Image | |
| from accelerate import Accelerator | |
| from torch.utils.data import DataLoader | |
| from tqdm import tqdm | |
| from concurrent.futures import ThreadPoolExecutor | |
# ---- CLI arguments --------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument("--model_name", type=str, default="chameleon")
parser.add_argument("--model_path", type=str, default=None)
parser.add_argument("--dataset_name", type=str, default="task3-movie-posters")
parser.add_argument("--split_name", type=str, default="test")
parser.add_argument("--batch_size", default=8, type=int)
parser.add_argument("--output_dir", type=str, default=None)
parser.add_argument("--begin_id", default=0, type=int)
parser.add_argument("--n_take", default=-1, type=int)
args = parser.parse_args()

batch_size = args.batch_size
output_dir = args.output_dir

accelerator = Accelerator()

# Only the main process creates the output tree to avoid mkdir races.
if accelerator.is_main_process and output_dir is not None:
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(f"{output_dir}/original_images", exist_ok=True)
    os.makedirs(f"{output_dir}/reconstructed_images", exist_ok=True)
    os.makedirs(f"{output_dir}/results", exist_ok=True)
# FIX: barrier so non-main ranks do not start saving results before the
# directories exist. This is a no-op in single-process runs.
accelerator.wait_for_everyone()

# ---- Model / data pipeline ------------------------------------------------
# n_take <= 0 means "take the whole split" (get_dataset receives None).
model, data_params = get_model(args.model_path, args.model_name)
dataset = get_dataset(
    args.dataset_name, args.split_name, None if args.n_take <= 0 else args.n_take
)
data_collator = DataCollatorForSupervisedDataset(args.dataset_name, **data_params)
dataloader = DataLoader(
    dataset, batch_size=batch_size, num_workers=0, collate_fn=data_collator
)
model, dataloader = accelerator.prepare(model, dataloader)
print("Model prepared...")
def save_results(
    pixel_values, reconstructed_image, idx, output_dir, data_params
):
    """Persist one original/reconstruction pair to disk.

    Writes both images as PNGs under ``output_dir`` and a pickle record
    holding the two PIL images. Samples with no reconstruction
    (``reconstructed_image is None``) are silently skipped.

    Args:
        pixel_values: original image tensor, decoded via ``tensor_to_pil``.
        reconstructed_image: model output tensor for the same sample, or None.
        idx: integer-formattable sample id used to name the output files.
        output_dir: root directory containing the three output subfolders.
        data_params: keyword arguments forwarded to ``tensor_to_pil``.
    """
    if reconstructed_image is None:
        return

    original = tensor_to_pil(pixel_values, **data_params)
    reconstruction = tensor_to_pil(reconstructed_image, **data_params)

    original.save(f"{output_dir}/original_images/{idx:08d}.png")
    reconstruction.save(f"{output_dir}/reconstructed_images/{idx:08d}.png")

    payload = {
        "ori_img": original,
        "rec_img": reconstruction,
    }
    with open(f"{output_dir}/results/{idx:08d}.pickle", "wb") as fw:
        pickle.dump(payload, fw)
# Fan result-saving out to a thread pool so disk I/O overlaps with inference.
# FIX: use the executor as a context manager — the bare
# executor + trailing shutdown() pattern leaked worker threads (and dropped
# queued saves) whenever the loop raised. __exit__ calls shutdown(wait=True)
# on every exit path.
with ThreadPoolExecutor(max_workers=16) as executor, torch.no_grad():
    print("Begin data loading...")
    for batch in tqdm(dataloader):
        pixel_values = batch["image"]
        reconstructed_images = model(pixel_values)
        # Some models return a (reconstruction, aux...) tuple; keep the first.
        if isinstance(reconstructed_images, tuple):
            reconstructed_images = reconstructed_images[0]
        if output_dir is not None:
            idx_list = batch["idx"]
            # Move tensors to CPU before handing them to saver threads.
            original_images = pixel_values.detach().cpu()
            # A list of per-sample outputs is already detached/host-side;
            # only batched tensors need the detach/cpu transfer.
            if not isinstance(reconstructed_images, list):
                reconstructed_images = reconstructed_images.detach().cpu()
            for i in range(pixel_values.shape[0]):
                executor.submit(
                    save_results,
                    original_images[i],
                    reconstructed_images[i],
                    idx_list[i],
                    output_dir,
                    data_params,
                )