my-tide-env / data_provider /data_factory.py
SeungHyeok Jang
Upload model files with Git LFS
e1ccef5
raw
history blame
1.95 kB
from data_provider.data_loader import Dataset_ETT_hour, Dataset_ETT_minute, Dataset_Custom, Dataset_M4, PSMSegLoader, \
MSLSegLoader, SMAPSegLoader, SMDSegLoader, SWATSegLoader, UEAloader, Dataset_Meteorology, TIDE_LEVEL_15MIN_MULTI, Dataset_Pred
from data_provider.uea import collate_fn
from torch.utils.data import DataLoader
data_dict = {
'TIDE': TIDE_LEVEL_15MIN_MULTI,
'ETTh1': Dataset_ETT_hour,
'ETTh2': Dataset_ETT_hour,
'ETTm1': Dataset_ETT_minute,
'ETTm2': Dataset_ETT_minute,
'custom': Dataset_Custom,
'm4': Dataset_M4,
'PSM': PSMSegLoader,
'MSL': MSLSegLoader,
'SMAP': SMAPSegLoader,
'SMD': SMDSegLoader,
'SWAT': SWATSegLoader,
'UEA': UEAloader,
'Meteorology' : Dataset_Meteorology
}
def data_provider(args, flag):
Data = data_dict[args.data]
timeenc = 0 if args.embed != 'timeF' else 1
# ★★★ 핵심 수정 사항 1 ★★★
# val, test, test_full 에서는 shuffle을 False로 설정
shuffle_flag = False if flag in ['test', 'TEST', 'val', 'test_full'] else True
# train일 때만 마지막 불완전한 배치를 버리고, 나머지는 모두 사용
drop_last = True if flag == 'train' else False
# --------------------------
batch_size = args.batch_size
freq = args.freq
# (if/elif/else 로직은 사용자 환경에 맞게 유지하되, 아래 구조를 따릅니다)
data_set = Data(
args=args,
root_path=args.root_path,
data_path=args.data_path,
flag=flag,
size=[args.seq_len, args.label_len, args.pred_len],
features=args.features,
target=args.target,
timeenc=timeenc,
freq=freq
)
print(flag, len(data_set))
data_loader = DataLoader(
data_set,
batch_size=batch_size,
shuffle=shuffle_flag,
num_workers=args.num_workers,
drop_last=drop_last)
return data_set, data_loader