Spaces:
Runtime error
Runtime error
| # -*- coding: utf-8 -*- | |
| # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is | |
| # holder of all proprietary rights on this computer program. | |
| # You can only use this computer program if you have closed | |
| # a license agreement with MPG or you get the right to use the computer | |
| # program from someone who is authorized to grant you that right. | |
| # Any use of the computer program without a valid license is prohibited and | |
| # liable to prosecution. | |
| # | |
| # Copyright©2019 Max-Planck-Gesellschaft zur Förderung | |
| # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute | |
| # for Intelligent Systems. All rights reserved. | |
| # | |
| # Contact: ps-license@tuebingen.mpg.de | |
| import numpy as np | |
| import pytorch_lightning as pl | |
| import torch | |
| from lib.common.seg3d_lossless import Seg3dLossless | |
| from lib.common.train_util import * | |
# Enable the cuDNN autotuner: PyTorch benchmarks the available convolution
# algorithms on first use of each input configuration and reuses the fastest.
# This is a process-wide flag set as a side effect of importing this module;
# per the PyTorch reproducibility notes it may make GPU results non-deterministic.
torch.backends.cudnn.benchmark = True
| class IFGeo(pl.LightningModule): | |
def __init__(self, cfg):
    """Create the geometry LightningModule from a configuration object.

    Copies the options it needs from ``cfg``, builds the IFGeoNet generator
    (``self.netG``) and a ``Seg3dLossless`` reconstruction engine
    (``self.reconEngine``) configured with a coarse-to-fine list of grid
    resolutions.

    Args:
        cfg: configuration object. This constructor reads ``batch_size``,
            ``lr_G``, ``sdf``, ``mcube_res``, ``clean_mesh``, ``overfit`` and
            ``dataset.prior_type``, and passes the whole object to IFGeoNet.
    """
    super().__init__()

    self.cfg = cfg
    self.batch_size = cfg.batch_size
    self.lr_G = cfg.lr_G
    self.use_sdf = cfg.sdf
    self.mcube_res = cfg.mcube_res
    self.clean_mesh_flag = cfg.clean_mesh
    self.overfit = cfg.overfit

    # Two generator variants share the class name ``IFGeoNet``: one module is
    # used when the body prior is "SMPL", the "nobody" module for any other prior.
    if cfg.dataset.prior_type == "SMPL":
        from lib.net.IFGeoNet import IFGeoNet as generator_cls
    else:
        from lib.net.IFGeoNet_nobody import IFGeoNet as generator_cls
    self.netG = generator_cls(cfg)

    # Grid sizes 2**k + 1 for int(log2(mcube_res) - 4) levels spaced evenly in k
    # from 5 to log2(mcube_res); e.g. mcube_res=256 -> [33, 65, 129, 257].
    # The int16 cast truncates any fractional sizes (non-power-of-two mcube_res).
    top_level = np.log2(self.mcube_res)
    grid_sizes = 1.0 + np.logspace(
        start=5,
        stop=top_level,
        base=2,
        num=int(top_level - 4),
        endpoint=True,
    )
    self.resolutions = grid_sizes.astype(np.int16).tolist()

    # NOTE(review): the y bounds are swapped (b_min y = +1, b_max y = -1),
    # presumably a deliberate y-flip of the query grid — confirm before changing.
    # ``query_func_IF`` is not defined here; presumably it comes from the
    # ``lib.common.train_util`` star import.
    self.reconEngine = Seg3dLossless(
        query_func=query_func_IF,
        b_min=[[-1.0, 1.0, -1.0]],
        b_max=[[1.0, -1.0, 1.0]],
        resolutions=self.resolutions,
        align_corners=True,
        balance_value=0.50,
        visualize=False,
        debug=False,
        use_cuda_impl=False,
        faster=True,
    )

    # Initialised empty; not filled in by any method visible in this section.
    self.export_dir = None
    self.result_eval = {}
| # Training related | |
def configure_optimizers(self):
    """Build the optimizer and learning-rate scheduler for ``self.netG``.

    The optimizer class is selected by ``self.cfg.optim`` ("Adadelta", "Adam"
    or "RMSprop"). It uses learning rate ``self.lr_G`` and weight decay
    ``self.cfg.weight_decay``. ``self.cfg.momentum`` is read only for RMSprop,
    the one optimizer here that accepts it. A ``MultiStepLR`` scheduler
    multiplies the learning rate by ``self.cfg.gamma`` at each milestone in
    ``self.cfg.schedule``.

    Returns:
        tuple[list, list]: ``([optimizer_G], [scheduler_G])``.

    Raises:
        NotImplementedError: if ``self.cfg.optim`` names an unsupported optimizer.
    """
    weight_decay = self.cfg.weight_decay
    optim_name = self.cfg.optim
    optim_params_G = [{"params": self.netG.parameters(), "lr": self.lr_G}]

    if optim_name == "Adadelta":
        optimizer_G = torch.optim.Adadelta(
            optim_params_G, lr=self.lr_G, weight_decay=weight_decay
        )
    elif optim_name == "Adam":
        optimizer_G = torch.optim.Adam(
            optim_params_G, lr=self.lr_G, weight_decay=weight_decay
        )
    elif optim_name == "RMSprop":
        optimizer_G = torch.optim.RMSprop(
            optim_params_G,
            lr=self.lr_G,
            weight_decay=weight_decay,
            momentum=self.cfg.momentum,
        )
    else:
        # Name the offending value so a config typo is easy to spot.
        raise NotImplementedError(
            f"Unsupported optimizer {optim_name!r}; "
            "expected one of 'Adadelta', 'Adam', 'RMSprop'"
        )

    scheduler_G = torch.optim.lr_scheduler.MultiStepLR(
        optimizer_G, milestones=self.cfg.schedule, gamma=self.cfg.gamma
    )
    return [optimizer_G], [scheduler_G]
def training_step(self, batch, batch_idx):
    """Run the generator on one training batch and log its loss.

    Args:
        batch: mapping fed to ``self.netG``; must contain ``"labels_geo"``,
            which is passed to ``netG.compute_loss`` as the target.
        batch_idx: index of the batch (unused).

    Returns:
        dict: ``{"loss": loss}``, the same dict that is logged.
    """
    net = self.netG
    net.train()

    predictions = net(batch)
    loss = net.compute_loss(predictions, batch["labels_geo"])

    step_metrics = dict(loss=loss)
    # Per-step value only, shown on the progress bar and synced across processes.
    self.log_dict(
        step_metrics,
        on_step=True,
        on_epoch=False,
        prog_bar=True,
        logger=True,
        sync_dist=True,
    )
    return step_metrics
def training_epoch_end(self, outputs):
    """Log the epoch-level average of the training loss.

    Args:
        outputs: the per-step dicts returned by ``training_step`` for this
            epoch. ``batch_mean`` (presumably from the ``lib.common.train_util``
            star import) is asked to reduce their ``"loss"`` entries.
    """
    epoch_loss = batch_mean(outputs, "loss")
    # Logged once per epoch, from rank zero only.
    self.log_dict(
        {"train/avgloss": epoch_loss},
        on_step=False,
        on_epoch=True,
        prog_bar=False,
        logger=True,
        rank_zero_only=True,
    )
def validation_step(self, batch, batch_idx):
    """Compute the generator loss on one validation batch.

    Args:
        batch: mapping fed to ``self.netG``; must contain ``"labels_geo"``,
            which is passed to ``netG.compute_loss`` as the target.
        batch_idx: index of the batch (unused).

    Returns:
        dict: ``{"val/loss": loss}``, the same dict that is logged.
    """
    net = self.netG
    # nn.Module.eval() normally clears ``training`` already; the explicit
    # assignment is kept in case IFGeoNet overrides train()/eval().
    net.eval()
    net.training = False

    predictions = net(batch)
    val_loss = net.compute_loss(predictions, batch["labels_geo"])

    step_metrics = {"val/loss": val_loss}
    # Progress bar only (logger=False), per step, synced across processes.
    self.log_dict(
        step_metrics,
        on_step=True,
        on_epoch=False,
        prog_bar=True,
        logger=False,
        sync_dist=True,
    )
    return step_metrics
def validation_epoch_end(self, outputs):
    """Log the epoch-level average of the validation loss.

    Args:
        outputs: the per-step dicts returned by ``validation_step`` for this
            epoch. ``batch_mean`` (presumably from the ``lib.common.train_util``
            star import) is asked to reduce their ``"val/loss"`` entries.
    """
    epoch_val_loss = batch_mean(outputs, "val/loss")
    # Logged once per epoch, from rank zero only.
    self.log_dict(
        {"val/avgloss": epoch_val_loss},
        on_step=False,
        on_epoch=True,
        prog_bar=False,
        logger=True,
        rank_zero_only=True,
    )