Spaces:
Sleeping
Sleeping
| import os | |
| import torch | |
| from models import Autoformer, Transformer, TimesNet, Nonstationary_Transformer, DLinear, FEDformer, \ | |
| Informer, LightTS, Reformer, ETSformer, Pyraformer, PatchTST, MICN, Crossformer, FiLM, iTransformer, \ | |
| Koopa, TiDE, FreTS, TimeMixer, TSMixer, SegRNN, MambaSimple, TemporalFusionTransformer, SCINet, TimeXer | |
class Exp_Basic:
    """Base class for experiments: holds the model registry, device and model.

    Construction stores ``args``, builds ``model_dict`` (model name -> imported
    ``models`` entry), selects the torch device, and moves the object returned
    by ``_build_model()`` onto that device.

    Attributes read from ``args`` by this class: ``model``, ``use_gpu``,
    ``gpu``, ``use_multi_gpu`` and ``devices``.

    Subclasses must override ``_build_model``; ``_get_data``, ``vali``,
    ``train`` and ``test`` are no-op hooks here.
    """

    def __init__(self, args):
        """Store ``args``, register the models, then set up device and model.

        Raises:
            NotImplementedError: if ``_build_model`` has not been overridden.
        """
        self.args = args
        self.model_dict = {
            'TimesNet': TimesNet,
            'Autoformer': Autoformer,
            'Transformer': Transformer,
            'Nonstationary_Transformer': Nonstationary_Transformer,
            'DLinear': DLinear,
            'FEDformer': FEDformer,
            'Informer': Informer,
            'LightTS': LightTS,
            'Reformer': Reformer,
            'ETSformer': ETSformer,
            'PatchTST': PatchTST,
            'Pyraformer': Pyraformer,
            'MICN': MICN,
            'Crossformer': Crossformer,
            'FiLM': FiLM,
            'iTransformer': iTransformer,
            'Koopa': Koopa,
            'TiDE': TiDE,
            'FreTS': FreTS,
            'MambaSimple': MambaSimple,
            'TimeMixer': TimeMixer,
            'TSMixer': TSMixer,
            'SegRNN': SegRNN,
            'TemporalFusionTransformer': TemporalFusionTransformer,
            'SCINet': SCINet,
            'TimeXer': TimeXer,
        }
        if args.model == 'Mamba':
            # Imported lazily so the optional mamba_ssm dependency is only
            # needed when the Mamba model is actually requested.
            print('Please make sure you have successfully installed mamba_ssm')
            from models import Mamba
            self.model_dict['Mamba'] = Mamba

        self.device = self._acquire_device()
        # _build_model() must return an object that supports ``.to(device)``.
        self.model = self._build_model().to(self.device)

    def _build_model(self):
        """Build and return the model; must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def _acquire_device(self):
        """Select and return the ``torch.device`` described by ``self.args``.

        When ``args.use_gpu`` is truthy, ``CUDA_VISIBLE_DEVICES`` is set to
        ``args.devices`` if ``args.use_multi_gpu`` is truthy, otherwise to
        ``str(args.gpu)``, and ``cuda:<args.gpu>`` is returned. Otherwise the
        CPU device is returned. The choice is printed in both cases.
        """
        if not self.args.use_gpu:
            device = torch.device('cpu')
            print('Use CPU')
            return device

        if self.args.use_multi_gpu:
            # Used as an environment value, so ``devices`` must be a string.
            visible_devices = self.args.devices
        else:
            visible_devices = str(self.args.gpu)
        os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices
        # NOTE(review): once CUDA_VISIBLE_DEVICES is restricted to a single
        # id, 'cuda:<gpu>' indexes into the visible set -- confirm this is
        # the intended device when gpu != 0.
        device = torch.device(f'cuda:{self.args.gpu}')
        print(f'Use GPU: cuda:{self.args.gpu}')
        return device

    def _get_data(self):
        """Hook for data loading; does nothing in the base class."""
        pass

    def vali(self):
        """Hook for validation; does nothing in the base class."""
        pass

    def train(self):
        """Hook for training; does nothing in the base class."""
        pass

    def test(self):
        """Hook for testing; does nothing in the base class."""
        pass