Spaces:
Runtime error
Runtime error
| import torch | |
| import os | |
| import json | |
| import sys | |
| from utils import pickle_util | |
# Paths of checkpoints written by save_model / save_model_v4, in write order.
# delete_last_saved_model pops the most recent entry from the end of this list.
history_array = []


def save_model(epoch, model, optimizer, file_save_path):
    """Save a training checkpoint to ``file_save_path`` with ``torch.save``.

    The saved object is a dict with keys ``'epoch'``, ``'model'`` (the model's
    ``state_dict()``) and ``'optimizer'`` (the optimizer's ``state_dict()``, or
    ``None`` when ``optimizer`` is ``None``). The parent directory is created
    if it does not exist, and ``file_save_path`` is appended to
    ``history_array``.

    Args:
        epoch: Epoch number stored under ``'epoch'``.
        model: Object exposing ``state_dict()`` (presumably an ``nn.Module``).
        optimizer: Object exposing ``state_dict()``, or ``None``.
        file_save_path: Destination file path.
    """
    dirpath = os.path.abspath(os.path.join(file_save_path, os.pardir))
    if not os.path.exists(dirpath):
        print("mkdir:", dirpath)
        # exist_ok=True: another process may create the directory between the
        # existence check above and this call; that must not be an error.
        os.makedirs(dirpath, exist_ok=True)
    opti = optimizer.state_dict() if optimizer is not None else None
    torch.save(obj={
        'epoch': epoch,
        'model': model.state_dict(),
        'optimizer': opti,
    }, f=file_save_path)
    history_array.append(file_save_path)
def save_model_v4(epoch, model, optimizer, file_save_path, discriminator):
    """Save a training checkpoint that also carries a ``'discriminator'`` entry.

    The saved dict has keys ``'epoch'``, ``'model'`` (the model's
    ``state_dict()``), ``'optimizer'`` (the optimizer's ``state_dict()``, or
    ``None`` when ``optimizer`` is ``None``) and ``'discriminator'`` (the
    ``discriminator`` argument stored as-is). The parent directory is created if
    it does not exist, and ``file_save_path`` is appended to ``history_array``.

    Args:
        epoch: Epoch number stored under ``'epoch'``.
        model: Object exposing ``state_dict()`` (presumably an ``nn.Module``).
        optimizer: Object exposing ``state_dict()``, or ``None``.
        file_save_path: Destination file path.
        discriminator: Value stored verbatim under ``'discriminator'``.
    """
    dirpath = os.path.abspath(os.path.join(file_save_path, os.pardir))
    if not os.path.exists(dirpath):
        print("mkdir:", dirpath)
        # exist_ok=True: another process may create the directory between the
        # existence check above and this call; that must not be an error.
        os.makedirs(dirpath, exist_ok=True)
    opti = optimizer.state_dict() if optimizer is not None else None
    torch.save(obj={
        'epoch': epoch,
        'model': model.state_dict(),
        'optimizer': opti,
        # NOTE(review): unlike 'model', no .state_dict() is taken here. If
        # callers pass an nn.Module, the whole object is pickled, which ties
        # loading to the class definition — confirm what callers pass.
        "discriminator": discriminator,
    }, f=file_save_path)
    history_array.append(file_save_path)
def delete_last_saved_model():
    """Remove the most recently recorded checkpoint and its ``.json`` sidecar.

    Pops the last path from ``history_array`` and does nothing if the history
    is empty. Each file is removed only if it currently exists, so a missing
    file is not an error. A message is printed only when the checkpoint file
    itself is removed.
    """
    if not history_array:
        return
    ckpt_path = history_array.pop()
    sidecar_path = ckpt_path + ".json"
    if os.path.exists(ckpt_path):
        os.remove(ckpt_path)
        print("delete model:", ckpt_path)
    if os.path.exists(sidecar_path):
        os.remove(sidecar_path)
def load_model(resume_path, model, optimizer=None, strict=True):
    """Restore model (and optionally optimizer) state from a checkpoint file.

    The checkpoint is loaded onto the CPU and is expected to be a dict with an
    ``'epoch'`` entry and a ``'model'`` state dict. If ``optimizer`` is given,
    its state is restored from ``'optimizer'``. If that entry is missing or
    ``None``, the optimizer is left unchanged and a message is printed instead
    of crashing.

    Args:
        resume_path: Path to the checkpoint file.
        model: Object exposing ``load_state_dict`` (presumably an ``nn.Module``).
        optimizer: Optional object exposing ``load_state_dict``.
        strict: Forwarded to ``model.load_state_dict``.

    Returns:
        ``checkpoint['epoch'] + 1``, i.e. the epoch to resume training from.
    """
    checkpoint = torch.load(resume_path, map_location=torch.device('cpu'))
    start_epoch = checkpoint['epoch'] + 1
    model.load_state_dict(checkpoint['model'], strict=strict)
    if optimizer is not None:
        optimizer_state = checkpoint.get('optimizer')
        if optimizer_state is None:
            # Checkpoints saved without an optimizer store None here; passing
            # None to optimizer.load_state_dict would raise.
            print("checkpoint has no optimizer state; optimizer left unchanged")
        else:
            optimizer.load_state_dict(optimizer_state)
    print("checkpoint loaded!")
    return start_epoch
def save_model_v2(model, args, model_save_name):
    """Save ``model`` weights (epoch 0, no optimizer) into the run's folder.

    The destination is ``<model_save_folder>/<project>/<name>/<model_save_name>``,
    built from the corresponding attributes of ``args``. The path is printed
    after saving.
    """
    run_dir = os.path.join(args.model_save_folder, args.project, args.name)
    target_path = os.path.join(run_dir, model_save_name)
    save_model(0, model, None, target_path)
    print("save:", target_path)
def save_project_info(args):
    """Write the command line and parsed arguments to ``run_info.json``.

    The file is written to ``<model_save_folder>/<project>/<name>/run_info.json``,
    and the folder is created if needed. It contains ``"cmd_str"`` (the
    command-line arguments after the script name, space-joined) and ``"args"``
    (``vars(args)``).

    Args:
        args: Namespace-like object with ``model_save_folder``, ``project`` and
            ``name`` attributes. All of its attributes must be JSON-serializable.
    """
    run_info = {
        "cmd_str": ' '.join(sys.argv[1:]),
        "args": vars(args),
    }
    name = "run_info.json"
    folder = os.path.join(args.model_save_folder, args.project, args.name)
    # exist_ok=True avoids the check-then-create race when several processes
    # start the same run concurrently.
    os.makedirs(folder, exist_ok=True)
    json_file_path = os.path.join(folder, name)
    with open(json_file_path, "w") as f:
        json.dump(run_info, f)
    print("save_project_info:", json_file_path)
def get_pkl_json(folder):
    """Read the single file in ``folder`` whose name contains ``".pkl.json"``.

    The file is parsed with ``pickle_util.read_json`` and the result returned.

    Raises:
        AssertionError: if ``folder`` does not contain exactly one such entry.
    """
    names = [i for i in os.listdir(folder) if ".pkl.json" in i]
    if len(names) != 1:
        # Explicit raise instead of `assert` so the check survives `python -O`.
        # AssertionError is kept so callers see the same exception type as before.
        raise AssertionError(
            f"expected exactly one '.pkl.json' file in {folder!r}, "
            f"found {len(names)}: {sorted(names)}"
        )
    json_path = os.path.join(folder, names[0])
    return pickle_util.read_json(json_path)
# Parallel (DataParallel) checkpoint helpers
def is_data_parallel_checkpoint(state_dict):
    """Return True if any key of ``state_dict`` starts with ``'module.'``.

    That prefix is what ``torch.nn.DataParallel`` adds to parameter names.
    An empty mapping yields False.
    """
    for key in state_dict:
        if key.startswith('module.'):
            return True
    return False
def map_state_dict(state_dict):
    """Strip the ``'module.'`` prefix that ``DataParallel`` adds to key names.

    If no key starts with ``'module.'``, the same ``state_dict`` object is
    returned unchanged. Otherwise a new ``OrderedDict`` is returned. It keeps the
    original key order, has the prefix removed from keys that carry it, and
    copies other keys as-is.
    """
    from collections import OrderedDict

    prefix = 'module.'
    if not any(key.startswith(prefix) for key in state_dict):
        return state_dict
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )