Spaces:
Sleeping
Sleeping
| from data_provider.data_loader import Dataset_ETT_hour, Dataset_ETT_minute, Dataset_Custom, Dataset_M4, PSMSegLoader, \ | |
| MSLSegLoader, SMAPSegLoader, SMDSegLoader, SWATSegLoader, UEAloader, Dataset_Meteorology, TIDE_LEVEL_15MIN_MULTI, Dataset_Pred | |
| from data_provider.uea import collate_fn | |
| from torch.utils.data import DataLoader | |
| data_dict = { | |
| 'TIDE': TIDE_LEVEL_15MIN_MULTI, | |
| 'ETTh1': Dataset_ETT_hour, | |
| 'ETTh2': Dataset_ETT_hour, | |
| 'ETTm1': Dataset_ETT_minute, | |
| 'ETTm2': Dataset_ETT_minute, | |
| 'custom': Dataset_Custom, | |
| 'm4': Dataset_M4, | |
| 'PSM': PSMSegLoader, | |
| 'MSL': MSLSegLoader, | |
| 'SMAP': SMAPSegLoader, | |
| 'SMD': SMDSegLoader, | |
| 'SWAT': SWATSegLoader, | |
| 'UEA': UEAloader, | |
| 'Meteorology' : Dataset_Meteorology | |
| } | |
| def data_provider(args, flag): | |
| Data = data_dict[args.data] | |
| timeenc = 0 if args.embed != 'timeF' else 1 | |
| # ★★★ 핵심 수정 사항 1 ★★★ | |
| # val, test, test_full 에서는 shuffle을 False로 설정 | |
| shuffle_flag = False if flag in ['test', 'TEST', 'val', 'test_full'] else True | |
| # train일 때만 마지막 불완전한 배치를 버리고, 나머지는 모두 사용 | |
| drop_last = True if flag == 'train' else False | |
| # -------------------------- | |
| batch_size = args.batch_size | |
| freq = args.freq | |
| # (if/elif/else 로직은 사용자 환경에 맞게 유지하되, 아래 구조를 따릅니다) | |
| data_set = Data( | |
| args=args, | |
| root_path=args.root_path, | |
| data_path=args.data_path, | |
| flag=flag, | |
| size=[args.seq_len, args.label_len, args.pred_len], | |
| features=args.features, | |
| target=args.target, | |
| timeenc=timeenc, | |
| freq=freq | |
| ) | |
| print(flag, len(data_set)) | |
| data_loader = DataLoader( | |
| data_set, | |
| batch_size=batch_size, | |
| shuffle=shuffle_flag, | |
| num_workers=args.num_workers, | |
| drop_last=drop_last) | |
| return data_set, data_loader |