Spaces: Running on Zero
| from typing import Dict, Any, Optional | |
| import torch | |
| import torch.nn as nn | |
| import logging | |
logger = logging.getLogger(__name__)


class ModelLoader:
    """Copy pretrained weights from a checkpoint file into a denoiser module.

    Loading is best-effort: entries that are missing from the checkpoint,
    whose shape cannot be copied into the module's tensor, or that are not
    tensors are skipped with a warning instead of aborting the whole load.
    """

    def __init__(self):
        super().__init__()

    def load(self, denoiser):
        """Load weights into ``denoiser`` in place and return it.

        ``denoiser`` must expose:
          * ``weight_path`` -- checkpoint path; a falsy value skips loading.
          * ``load_ema`` -- when truthy, read the ``ema_denoiser.``-prefixed
            entries of the checkpoint, otherwise the ``denoiser.``-prefixed ones.
          * ``state_dict()`` -- as on ``torch.nn.Module``.

        The checkpoint is expected to be a dict whose ``"state_dict"`` entry
        maps prefixed parameter names to tensors.
        NOTE(review): presumably written by a training wrapper that holds both
        the model and its EMA copy -- confirm against the training code.

        Args:
            denoiser: module to populate; modified in place.

        Returns:
            The same ``denoiser`` object.
        """
        path = denoiser.weight_path
        if not path:
            return denoiser

        # NOTE(review): torch.load unpickles the file; only point weight_path
        # at trusted checkpoints (consider weights_only=True where supported).
        checkpoint = torch.load(path, map_location=torch.device("cpu"))
        source = checkpoint.get("state_dict") if isinstance(checkpoint, dict) else None
        if source is None:
            # One clear warning instead of one identical failure per parameter.
            logger.warning(
                "Checkpoint %s has no 'state_dict' entry; denoiser weights left unchanged",
                path,
            )
            return denoiser

        prefix = "ema_denoiser." if denoiser.load_ema else "denoiser."
        # state_dict() returns references to the module's parameters and
        # buffers, so copy_ writes straight into the live module.
        for name, target in denoiser.state_dict().items():
            key = prefix + name
            try:
                target.copy_(source[key])
            except (KeyError, RuntimeError, TypeError):
                # Missing key, non-broadcastable shape, or non-tensor value:
                # skip this entry but keep loading the rest.
                logger.warning("Failed to copy %s to denoiser weight", key)
        return denoiser