SeungHyeok Jang committed
Commit e1ccef5 · 1 parent: 9112b64

Upload model files with Git LFS

This view is limited to 50 files because it contains too many changes; see the raw diff for the full change set.
Files changed (50)
  1. .DS_Store +0 -0
  2. app.py +85 -0
  3. checkpoints/.DS_Store +0 -0
  4. checkpoints/long_term_forecast_DT_0001_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/checkpoint.pth +3 -0
  5. checkpoints/long_term_forecast_DT_0001_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/scaler.gz +3 -0
  6. checkpoints/long_term_forecast_DT_0002_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/checkpoint.pth +3 -0
  7. checkpoints/long_term_forecast_DT_0002_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/scaler.gz +3 -0
  8. checkpoints/long_term_forecast_DT_0003_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/checkpoint.pth +3 -0
  9. checkpoints/long_term_forecast_DT_0003_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/scaler.gz +3 -0
  10. checkpoints/long_term_forecast_DT_0008_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/checkpoint.pth +3 -0
  11. checkpoints/long_term_forecast_DT_0008_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/scaler.gz +3 -0
  12. checkpoints/long_term_forecast_DT_0017_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/checkpoint.pth +3 -0
  13. checkpoints/long_term_forecast_DT_0017_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/scaler.gz +3 -0
  14. checkpoints/long_term_forecast_DT_0018_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/checkpoint.pth +3 -0
  15. checkpoints/long_term_forecast_DT_0018_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/scaler.gz +3 -0
  16. checkpoints/long_term_forecast_DT_0024_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/checkpoint.pth +3 -0
  17. checkpoints/long_term_forecast_DT_0024_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/scaler.gz +3 -0
  18. checkpoints/long_term_forecast_DT_0025_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/checkpoint.pth +3 -0
  19. checkpoints/long_term_forecast_DT_0025_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/scaler.gz +3 -0
  20. checkpoints/long_term_forecast_DT_0037_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/checkpoint.pth +3 -0
  21. checkpoints/long_term_forecast_DT_0037_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/scaler.gz +3 -0
  22. checkpoints/long_term_forecast_DT_0043_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/checkpoint.pth +3 -0
  23. checkpoints/long_term_forecast_DT_0043_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/scaler.gz +3 -0
  24. checkpoints/long_term_forecast_DT_0050_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/checkpoint.pth +3 -0
  25. checkpoints/long_term_forecast_DT_0050_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/scaler.gz +3 -0
  26. checkpoints/long_term_forecast_DT_0051_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/checkpoint.pth +3 -0
  27. checkpoints/long_term_forecast_DT_0051_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/scaler.gz +3 -0
  28. checkpoints/long_term_forecast_DT_0052_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/checkpoint.pth +3 -0
  29. checkpoints/long_term_forecast_DT_0052_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/scaler.gz +3 -0
  30. checkpoints/long_term_forecast_DT_0065_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/checkpoint.pth +3 -0
  31. checkpoints/long_term_forecast_DT_0065_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/scaler.gz +3 -0
  32. checkpoints/long_term_forecast_DT_0066_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/checkpoint.pth +3 -0
  33. checkpoints/long_term_forecast_DT_0066_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/scaler.gz +3 -0
  34. checkpoints/long_term_forecast_DT_0067_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/checkpoint.pth +3 -0
  35. checkpoints/long_term_forecast_DT_0067_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/scaler.gz +3 -0
  36. checkpoints/long_term_forecast_DT_0068_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/checkpoint.pth +3 -0
  37. checkpoints/long_term_forecast_DT_0068_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/scaler.gz +3 -0
  38. data_provider/__init__.py +1 -0
  39. data_provider/__pycache__/__init__.cpython-39.pyc +0 -0
  40. data_provider/__pycache__/data_factory.cpython-39.pyc +0 -0
  41. data_provider/__pycache__/data_loader.cpython-39.pyc +0 -0
  42. data_provider/__pycache__/m4.cpython-39.pyc +0 -0
  43. data_provider/__pycache__/uea.cpython-39.pyc +0 -0
  44. data_provider/data_factory.py +58 -0
  45. data_provider/data_loader.py +1064 -0
  46. data_provider/m4.py +138 -0
  47. data_provider/uea.py +125 -0
  48. exp/.DS_Store +0 -0
  49. exp/__init__.py +0 -0
  50. exp/__pycache__/__init__.cpython-39.pyc +0 -0
.DS_Store ADDED
Binary file (8.2 kB).
 
app.py ADDED
@@ -0,0 +1,85 @@
+ import gradio as gr
+ import torch
+ import numpy as np
+ import pandas as pd
+ import joblib
+ import os
+ import json
+ import sys  # added: needed for the sys.path change below
+
+ # --- Load the model and scaler ---
+ MODEL_LOADED = False
+ MODEL_ERROR = "Unknown"
+ try:
+     # This is the key part:
+     # add the current folder (.) to Python's module search path
+     # so that app.py can find the models and utils packages.
+     sys.path.append('.')
+
+     from models.TimeXer import Model as TimeXerModel
+     from utils.tools import dotdict
+     from utils.timefeatures import time_features
+
+     # 1. Reproduce every setting from the training script (.sh) as-is.
+     args = dotdict()
+     args.model_id = 'DT_0001_144_72'
+     args.model = 'TimeXer'
+     args.task_name = 'long_term_forecast'
+     args.seq_len = 144
+     args.label_len = 96
+     args.pred_len = 72
+     args.features = 'MS'
+     args.target = 'residual'
+     args.e_layers = 1
+     args.d_layers = 1
+     args.factor = 3
+     args.enc_in = 5
+     args.dec_in = 5
+     args.c_out = 1
+     args.d_model = 256
+     args.n_heads = 8
+     args.d_ff = 512
+     args.output_attention = True
+     args.device = torch.device('cpu')
+
+     # 2. Build the model skeleton and load the trained weights.
+     model = TimeXerModel(args).float()
+     model.load_state_dict(torch.load('checkpoints/checkpoint.pth', map_location=args.device))
+     model.eval()
+
+     # 3. Load the scaler.
+     scaler = joblib.load('checkpoints/scaler.gz')
+     MODEL_LOADED = True
+     print("✅ Model and scaler loaded successfully!")
+
+ except Exception as e:
+     MODEL_ERROR = str(e)
+     print(f"❌ Error while loading the model: {MODEL_ERROR}")
+
+ # --- Prediction function ---
+ def predict_tide(input_csv_string):
+     # (This part does not need to be changed)
+     # ... same as the previous code ...
+     if not MODEL_LOADED:
+         raise gr.Error(f"Model loading failed: {MODEL_ERROR}")
+     # ...
+     # ...
+     return json.dumps({"prediction": prediction.flatten().tolist()}, indent=2)
+
+ # --- Build the Gradio interface ---
+ # To avoid a NameError, remove references to args outside the try block,
+ # or fall back to default values in case model loading failed.
+ desc_text = "Given multivariate data for the past 144 time steps, the model predicts the tide-level residual for the next 72 time steps."
+ if MODEL_LOADED:
+     desc_text = f"Given multivariate data for the past {args.seq_len} time steps, the model predicts the tide-level residual for the next {args.pred_len} time steps."
+
+ demo = gr.Interface(
+     fn=predict_tide,
+     inputs=gr.Textbox(lines=10, placeholder="Enter 144 rows of data in CSV format.\nThe first line must be a header (date,OT,...)."),
+     outputs="json",
+     title="Tide Level Prediction Model (TimeXer)",
+     description=desc_text
+ )
+
+ if __name__ == "__main__":
+     demo.launch()
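
For reference, a minimal sketch of how the 144-row CSV string expected by the Textbox above could be assembled before pasting it into the Space. Only the `date,OT,...` header format from the placeholder is taken from the commit; the column names other than `date` and `OT`, and all values, are made-up stand-ins (the exact feature columns required depend on the training CSV).

import numpy as np
import pandas as pd

# 144 time steps at 15-minute intervals; 'feat*' names are hypothetical placeholders.
dates = pd.date_range("2024-01-01 00:00", periods=144, freq="15min")
rng = np.random.default_rng(0)
df = pd.DataFrame({
    "date": dates,
    "OT": rng.normal(size=144),
    "feat1": rng.normal(size=144),
    "feat2": rng.normal(size=144),
    "feat3": rng.normal(size=144),
    "feat4": rng.normal(size=144),
})
csv_text = df.to_csv(index=False)  # paste this string into the Gradio Textbox
print(csv_text.splitlines()[0])    # date,OT,feat1,feat2,feat3,feat4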
checkpoints/.DS_Store ADDED
Binary file (6.15 kB).
 
checkpoints/long_term_forecast_DT_0001_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/checkpoint.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1d90564945616014390c51c042c03006eb598cd7afb08b6413f1176dab99cc91
3
+ size 9201899
checkpoints/long_term_forecast_DT_0001_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/scaler.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d1870b2443b65c3c464d4cde1a6cadfe28afbf0104ad9e34c046630b7f566014
3
+ size 571
checkpoints/long_term_forecast_DT_0002_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/checkpoint.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9f67514657d9951061f5ef1567abc786dd6017dbf75254337b17ff0bd3967033
3
+ size 9201899
checkpoints/long_term_forecast_DT_0002_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/scaler.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0a271b2c895982625cbbaee3b74041d5ca9c3ffe23659a86210bfb9d9f13257f
3
+ size 570
checkpoints/long_term_forecast_DT_0003_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/checkpoint.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2a05fc1542358e8bd36649a63f73b7eeaa43151d724ca93f3dccfa33da15f4e5
3
+ size 9201899
checkpoints/long_term_forecast_DT_0003_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/scaler.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2edce6854e998460e107fa038df48a070e4105a5737dca2eb812f9f53662af83
3
+ size 570
checkpoints/long_term_forecast_DT_0008_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/checkpoint.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3d7a2e47edb29e7f65bfe03ef16b9831f81b72ebdf668ebc1ca1ddf8c6d4c9c9
3
+ size 9201899
checkpoints/long_term_forecast_DT_0008_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/scaler.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6f5098ce4f49d4f62433bf94a9148c9aa3bf7b9442b692bf105be9312aba62fd
3
+ size 573
checkpoints/long_term_forecast_DT_0017_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/checkpoint.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5ae7650c99955d581330a6c41dd1d8df5cd592531e5ef8c8b386addef2b78890
3
+ size 9201899
checkpoints/long_term_forecast_DT_0017_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/scaler.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c095fa14ee779b2f16dc28ab44147d4d8544e5a3ff96961ef51dc241be7534f6
3
+ size 569
checkpoints/long_term_forecast_DT_0018_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/checkpoint.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:01e929ab4d3a6c5b4e2170b62df6d28a7d9411f6f50eaccc1425ca607edb62fc
3
+ size 9201899
checkpoints/long_term_forecast_DT_0018_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/scaler.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5327ce8a59a66b31906bcd9376cc1e65c6bdf7438cae2fd5b6dd14b7a39b821d
3
+ size 571
checkpoints/long_term_forecast_DT_0024_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/checkpoint.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0101778cd5c72b6782608eaf2c95423e130352c1e02d788ed85481edfb3130bf
3
+ size 9201899
checkpoints/long_term_forecast_DT_0024_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/scaler.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cacc57c0e390ff6d229ff0a5dd30bd71bbd81d23ab53e023dd3b595bcbf5cb04
3
+ size 572
checkpoints/long_term_forecast_DT_0025_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/checkpoint.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e24de3b1970a1de91006e748fed91fbb1d84f186a7f2ff3215e894d0a5e2f10
3
+ size 9201899
checkpoints/long_term_forecast_DT_0025_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/scaler.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f52fcae0e85f292f27c2cd37afb3f65f1886a53bb2056a3846fcf9eca59433a7
3
+ size 573
checkpoints/long_term_forecast_DT_0037_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/checkpoint.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9da0a6f076617a9d1f5969eb0e15496a726716caab7b5fdc92911fa35e08cd8e
3
+ size 9201899
checkpoints/long_term_forecast_DT_0037_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/scaler.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4de6608240f88757bb5f6334f5af83cd627405d9d868665796f3840c8dcacc6d
3
+ size 575
checkpoints/long_term_forecast_DT_0043_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/checkpoint.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fa857f259b2b60f2dc8577cead41ae1228eb039862690d2cc08e30a651ae47c6
3
+ size 9201899
checkpoints/long_term_forecast_DT_0043_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/scaler.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9366d698d3809b5359dad715695dfb42313536890258b6efc88c5a7d06c430b0
3
+ size 572
checkpoints/long_term_forecast_DT_0050_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/checkpoint.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:de824421d2b8dbc5ed724d62f1d2aea6a06d72b9b13cfe328f7f2f79d30d92c7
3
+ size 9201899
checkpoints/long_term_forecast_DT_0050_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/scaler.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e617f92c82157759adda11abd98529441fb87745be7eed1e36274b646f15849b
3
+ size 570
checkpoints/long_term_forecast_DT_0051_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/checkpoint.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3fe1fbdd82f33cd3d815a2659e998231991398ba6c6fa1830c98d713505b4b4e
3
+ size 9201899
checkpoints/long_term_forecast_DT_0051_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/scaler.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:36a22e462b56ef6142ba5361dc9e8a5c6550c0514dcbb77e808cf45c9524e3bc
3
+ size 569
checkpoints/long_term_forecast_DT_0052_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/checkpoint.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:511f6cba43023440fad6efaa62e2dfd1ed54ef20406df4b494a16f1615b35c7b
3
+ size 9201899
checkpoints/long_term_forecast_DT_0052_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/scaler.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:020c2a8f89cddc37a7da3d4c081e427ab5e3da535756983ed68f059b749da694
3
+ size 562
checkpoints/long_term_forecast_DT_0065_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/checkpoint.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d5bff6c2695c83b57013e74aa2f242d87c229b5f92b9bfe3d6755e613b5fbd48
3
+ size 9201899
checkpoints/long_term_forecast_DT_0065_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/scaler.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6df622e93b17976afbb742bf46c371d445e9f3a82f37fa900223729a67fd7155
3
+ size 577
checkpoints/long_term_forecast_DT_0066_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/checkpoint.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:71138b36455c41b54e7d77353ee9289a4900e7366edd287d401385f321cff85d
3
+ size 9201899
checkpoints/long_term_forecast_DT_0066_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/scaler.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2bd954bb71a5aa29aafc91f907dc65ebd24c5c94fa19753d95ef599af8f174c4
3
+ size 568
checkpoints/long_term_forecast_DT_0067_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/checkpoint.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:80f185578b53943a7dd8ad63d1e4c0d811710ff139db68d1449f43588a0eb1ae
3
+ size 9201899
checkpoints/long_term_forecast_DT_0067_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/scaler.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:793584b8c0c19a34e6f5b3edea66addcf48ddcb68f996633c406f0338b674c54
3
+ size 575
checkpoints/long_term_forecast_DT_0068_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/checkpoint.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:610e32f697b85b611d19f949f93504c40deab7ec9533761a0f0906182e1d2420
3
+ size 9201899
checkpoints/long_term_forecast_DT_0068_144_72_TimeXer_TIDE_ftMS_sl144_ll96_pl72_dm256_nh8_el1_dl1_df512_expand2_dc4_fc3_ebtimeF_dtTrue_Exp_0/scaler.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:94ca2c6727da8ad9cd931290eac6cd21a57703caa1e3ef4caddef3217f7b35a4
3
+ size 570
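
Each checkpoint.pth / scaler.gz pair above is stored as a Git LFS pointer; the actual binaries are fetched with `git lfs pull`. Below is a hedged sketch of loading them per station, assuming the per-station directory naming shown in the paths above and the same torch/joblib usage as app.py; the prints are only a sanity check that each pair loads.

import glob
import os

import joblib
import torch

# Iterate over the per-station experiment folders listed above; each one holds
# the trained TimeXer weights and the scaler fitted on that station's data.
for exp_dir in sorted(glob.glob("checkpoints/long_term_forecast_DT_*_Exp_0")):
    station_id = os.path.basename(exp_dir).split("_")[4]  # e.g. '0001' (assumes the naming above)
    state_dict = torch.load(os.path.join(exp_dir, "checkpoint.pth"), map_location="cpu")
    scaler = joblib.load(os.path.join(exp_dir, "scaler.gz"))
    print(station_id, len(state_dict), "tensors, scaler features:", scaler.mean_.shape[0])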
data_provider/__init__.py ADDED
@@ -0,0 +1 @@
1
+
data_provider/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (152 Bytes).
 
data_provider/__pycache__/data_factory.cpython-39.pyc ADDED
Binary file (1.95 kB).
 
data_provider/__pycache__/data_loader.cpython-39.pyc ADDED
Binary file (32.3 kB).
 
data_provider/__pycache__/m4.cpython-39.pyc ADDED
Binary file (3.65 kB).
 
data_provider/__pycache__/uea.cpython-39.pyc ADDED
Binary file (4.79 kB).
 
data_provider/data_factory.py ADDED
@@ -0,0 +1,58 @@
+ from data_provider.data_loader import Dataset_ETT_hour, Dataset_ETT_minute, Dataset_Custom, Dataset_M4, PSMSegLoader, \
+     MSLSegLoader, SMAPSegLoader, SMDSegLoader, SWATSegLoader, UEAloader, Dataset_Meteorology, TIDE_LEVEL_15MIN_MULTI, Dataset_Pred
+ from data_provider.uea import collate_fn
+ from torch.utils.data import DataLoader
+
+ data_dict = {
+     'TIDE': TIDE_LEVEL_15MIN_MULTI,
+     'ETTh1': Dataset_ETT_hour,
+     'ETTh2': Dataset_ETT_hour,
+     'ETTm1': Dataset_ETT_minute,
+     'ETTm2': Dataset_ETT_minute,
+     'custom': Dataset_Custom,
+     'm4': Dataset_M4,
+     'PSM': PSMSegLoader,
+     'MSL': MSLSegLoader,
+     'SMAP': SMAPSegLoader,
+     'SMD': SMDSegLoader,
+     'SWAT': SWATSegLoader,
+     'UEA': UEAloader,
+     'Meteorology': Dataset_Meteorology
+ }
+
+
+ def data_provider(args, flag):
+     Data = data_dict[args.data]
+     timeenc = 0 if args.embed != 'timeF' else 1
+
+     # --- Key modification 1 ---
+     # Set shuffle to False for the val, test, and test_full splits.
+     shuffle_flag = False if flag in ['test', 'TEST', 'val', 'test_full'] else True
+     # Drop the last incomplete batch only for train; use every batch otherwise.
+     drop_last = True if flag == 'train' else False
+     # --------------------------
+
+     batch_size = args.batch_size
+     freq = args.freq
+
+     # (Keep any if/elif/else logic needed for your setup, but follow the structure below.)
+     data_set = Data(
+         args=args,
+         root_path=args.root_path,
+         data_path=args.data_path,
+         flag=flag,
+         size=[args.seq_len, args.label_len, args.pred_len],
+         features=args.features,
+         target=args.target,
+         timeenc=timeenc,
+         freq=freq
+     )
+     print(flag, len(data_set))
+     data_loader = DataLoader(
+         data_set,
+         batch_size=batch_size,
+         shuffle=shuffle_flag,
+         num_workers=args.num_workers,
+         drop_last=drop_last)
+
+     return data_set, data_loader
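
A hedged sketch of how `data_provider` above is typically invoked. The field names mirror exactly what the function reads (`args.data`, `args.embed`, `args.batch_size`, ...), but the concrete values (root_path, data_path, batch size) are illustrative placeholders, not settings recorded in this commit.

from utils.tools import dotdict
from data_provider.data_factory import data_provider

args = dotdict()
args.data = 'TIDE'              # selects TIDE_LEVEL_15MIN_MULTI from data_dict
args.embed = 'timeF'            # -> timeenc = 1 (continuous time features)
args.root_path = './dataset/'   # placeholder paths; adjust to your layout
args.data_path = 'DT_0001.csv'
args.seq_len, args.label_len, args.pred_len = 144, 96, 72
args.features = 'MS'
args.target = 'residual'
args.freq = '15min'
args.batch_size = 32
args.num_workers = 0

# Validation split: shuffle=False and drop_last=False per the logic above.
val_set, val_loader = data_provider(args, flag='val')
for seq_x, seq_y, seq_x_mark, seq_y_mark in val_loader:
    print(seq_x.shape)  # (batch, seq_len, num_features)
    break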
data_provider/data_loader.py ADDED
@@ -0,0 +1,1064 @@
1
+ import os
2
+ import numpy as np
3
+ import pandas as pd
4
+ import glob
5
+ import re
6
+ import torch
7
+ from torch.utils.data import Dataset, DataLoader
8
+ from sklearn.preprocessing import StandardScaler
9
+ from utils.timefeatures import time_features
10
+ from data_provider.m4 import M4Dataset, M4Meta
11
+ from data_provider.uea import subsample, interpolate_missing, Normalizer
12
+ from sktime.datasets import load_from_tsfile_to_dataframe
13
+ import warnings
14
+ from utils.augmentation import run_augmentation_single
15
+
16
+ warnings.filterwarnings('ignore')
17
+
18
+ class TIDE_LEVEL_15MIN_MULTI(Dataset):
19
+ def __init__(self, args, root_path, flag='train', size=None,
20
+ features='MS', data_path='DT_0020.csv',
21
+ target='tide_level', scale=True, timeenc=1, freq='15min', seasonal_patterns=None):
22
+ # size [seq_len, label_len, pred_len]
23
+ self.args = args
24
+ # info
25
+ if size == None:
26
+ self.seq_len = 24 * 4 * 4
27
+ self.label_len = 24 * 4
28
+ self.pred_len = 24 * 4
29
+ else:
30
+ self.seq_len = size[0]
31
+ self.label_len = size[1]
32
+ self.pred_len = size[2]
33
+ # init
34
+ assert flag in ['train', 'test', 'val']
35
+ type_map = {'train': 0, 'val': 1, 'test': 2}
36
+ self.set_type = type_map[flag]
37
+
38
+ self.features = features
39
+ self.target = target
40
+ self.scale = scale
41
+ self.timeenc = timeenc
42
+ self.freq = freq
43
+
44
+ self.root_path = root_path
45
+ self.data_path = data_path
46
+ self.__read_data__()
47
+
48
+ def __read_data__(self):
49
+ self.scaler = StandardScaler()
50
+ df_raw = pd.read_csv(os.path.join(self.root_path, self.data_path))
51
+
52
+ # Dynamically calculate data split points
53
+ data_len = len(df_raw)
54
+ train_ratio = 0.7
55
+ val_ratio = 0.1
56
+ # test_ratio is implicitly 1 - train_ratio - val_ratio
57
+
58
+ train_len = int(data_len * train_ratio)
59
+ val_len = int(data_len * val_ratio)
60
+ test_len = data_len - train_len - val_len
61
+
62
+ border1s = [
63
+ 0,
64
+ train_len - self.seq_len,
65
+ train_len + val_len - self.seq_len
66
+ ]
67
+ border2s = [
68
+ train_len,
69
+ train_len + val_len,
70
+ data_len
71
+ ]
72
+
73
+ border1 = border1s[self.set_type]
74
+ border2 = border2s[self.set_type]
75
+
76
+ if self.features == 'M' or self.features == 'MS':
77
+ cols_data = df_raw.columns[1:]
78
+ df_data = df_raw[cols_data]
79
+ elif self.features == 'S':
80
+ df_data = df_raw[[self.target]]
81
+
82
+ if self.scale:
83
+ # Scaler is fit only on the training data
84
+ train_data = df_data.iloc[border1s[0]:border2s[0]]
85
+ self.scaler.fit(train_data.values)
86
+ data = self.scaler.transform(df_data.values)
87
+ else:
88
+ data = df_data.values
89
+
90
+ df_stamp = df_raw[['date']][border1:border2]
91
+ df_stamp['date'] = pd.to_datetime(df_stamp['date'])
92
+
93
+ if self.timeenc == 0:
94
+ df_stamp['month'] = df_stamp['date'].apply(lambda row: row.month)
95
+ df_stamp['day'] = df_stamp['date'].apply(lambda row: row.day)
96
+ df_stamp['weekday'] = df_stamp['date'].apply(lambda row: row.weekday())
97
+ df_stamp['hour'] = df_stamp['date'].apply(lambda row: row.hour)
98
+ df_stamp['minute'] = df_stamp['date'].apply(lambda row: row.minute // 15)
99
+ data_stamp = df_stamp.drop(columns=['date']).values
100
+ elif self.timeenc == 1:
101
+ data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq)
102
+ data_stamp = data_stamp.transpose(1, 0)
103
+
104
+ self.data_x = data[border1:border2]
105
+ self.data_y = data[border1:border2]
106
+
107
+ #if self.set_type == 0 and self.args.augmentation_ratio > 0:
108
+ # self.data_x, self.data_y, augmentation_tags = run_augmentation_single(self.data_x, self.data_y, self.args)
109
+
110
+ self.data_stamp = data_stamp
111
+
112
+ def __getitem__(self, index):
113
+ s_begin = index
114
+ s_end = s_begin + self.seq_len
115
+ r_begin = s_end - self.label_len
116
+ r_end = r_begin + self.label_len + self.pred_len
117
+
118
+ seq_x = self.data_x[s_begin:s_end]
119
+ seq_y = self.data_y[r_begin:r_end]
120
+ seq_x_mark = self.data_stamp[s_begin:s_end]
121
+ seq_y_mark = self.data_stamp[r_begin:r_end]
122
+
123
+ return seq_x, seq_y, seq_x_mark, seq_y_mark
124
+
125
+ def __len__(self):
126
+ return len(self.data_x) - self.seq_len - self.pred_len + 1
127
+
128
+ def inverse_transform(self, data):
129
+ return self.scaler.inverse_transform(data)
130
+
131
+ class Dataset_Pred(Dataset):
132
+ def __init__(self, root_path, flag='pred', size=None,
133
+ features='S', data_path='tide_data_DT_0001.csv',
134
+ target='tide_level', scale=True, inverse=False, timeenc=0, freq='t', cols=None):
135
+ # size [seq_len, label_len, pred_len]
136
+ # info
137
+ if size == None:
138
+ self.seq_len = 3 * 24 * 60  # 3 days of data (4,320 minutes)
139
+ self.label_len = 1 * 24 * 60  # 1 day of data (1,440 minutes)
140
+ self.pred_len = 1 * 24 * 60  # 1 day of data (1,440 minutes)
141
+ else:
142
+ self.seq_len = size[0]
143
+ self.label_len = size[1]
144
+ self.pred_len = size[2]
145
+ # init
146
+ assert flag in ['pred']
147
+
148
+ self.features = features
149
+ self.target = target
150
+ self.scale = scale
151
+ self.inverse = inverse
152
+ self.timeenc = timeenc
153
+ self.freq = freq
154
+ self.cols = cols
155
+ self.root_path = root_path
156
+ self.data_path = data_path
157
+ self.__read_data__()
158
+
159
+ def __read_data__(self):
160
+ self.scaler = StandardScaler()
161
+ df_raw = pd.read_csv(os.path.join(self.root_path, self.data_path))
162
+
163
+ # Dynamically calculate data split points
164
+ data_len = len(df_raw)
165
+ train_ratio = 0.7
166
+ val_ratio = 0.1
167
+ # test_ratio is implicitly 1 - train_ratio - val_ratio
168
+
169
+ train_len = int(data_len * train_ratio)
170
+ val_len = int(data_len * val_ratio)
171
+ test_len = data_len - train_len - val_len
172
+
173
+ border1s = [
174
+ 0,
175
+ train_len - self.seq_len,
176
+ train_len + val_len - self.seq_len
177
+ ]
178
+ border2s = [
179
+ train_len,
180
+ train_len + val_len,
181
+ data_len
182
+ ]
183
+
184
+ border1 = border1s[self.set_type]
185
+ border2 = border2s[self.set_type]
186
+
187
+ if self.features == 'M' or self.features == 'MS':
188
+ cols_data = df_raw.columns[1:]
189
+ df_data = df_raw[cols_data]
190
+ elif self.features == 'S':
191
+ df_data = df_raw[[self.target]]
192
+
193
+ if self.scale:
194
+ # Scaler is fit only on the training data
195
+ train_data = df_data.iloc[border1s[0]:border2s[0]]
196
+ self.scaler.fit(train_data.values)
197
+ data = self.scaler.transform(df_data.values)
198
+ else:
199
+ data = df_data.values
200
+
201
+ df_stamp = df_raw[['date']][border1:border2]
202
+ df_stamp['date'] = pd.to_datetime(df_stamp['date'])
203
+
204
+ if self.timeenc == 0:
205
+ df_stamp['month'] = df_stamp['date'].apply(lambda row: row.month)
206
+ df_stamp['day'] = df_stamp['date'].apply(lambda row: row.day)
207
+ df_stamp['weekday'] = df_stamp['date'].apply(lambda row: row.weekday())
208
+ df_stamp['hour'] = df_stamp['date'].apply(lambda row: row.hour)
209
+ df_stamp['minute'] = df_stamp['date'].apply(lambda row: row.minute // 15)
210
+ data_stamp = df_stamp.drop(columns=['date']).values
211
+ elif self.timeenc == 1:
212
+ data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq)
213
+ data_stamp = data_stamp.transpose(1, 0)
214
+
215
+ self.data_x = data[border1:border2]
216
+ self.data_y = data[border1:border2]
217
+
218
+ if self.set_type == 0 and self.args.augmentation_ratio > 0:
219
+ self.data_x, self.data_y, augmentation_tags = run_augmentation_single(self.data_x, self.data_y, self.args)
220
+
221
+ self.data_stamp = data_stamp
222
+
223
+ def __getitem__(self, index):
224
+ s_begin = index
225
+ s_end = s_begin + self.seq_len
226
+ r_begin = s_end - self.label_len
227
+ r_end = r_begin + self.label_len + self.pred_len
228
+
229
+ seq_x = self.data_x[s_begin:s_end]
230
+ if self.inverse:
231
+ seq_y = self.data_x[r_begin:r_begin + self.label_len]
232
+ else:
233
+ seq_y = self.data_y[r_begin:r_begin + self.label_len]
234
+ seq_x_mark = self.data_stamp[s_begin:s_end]
235
+ seq_y_mark = self.data_stamp[r_begin:r_end]
236
+
237
+ return seq_x, seq_y, seq_x_mark, seq_y_mark
238
+
239
+ def __len__(self):
240
+ return len(self.data_x) - self.seq_len + 1
241
+
242
+ def inverse_transform(self, data):
243
+ return self.scaler.inverse_transform(data)
244
+
245
+ class Dataset_ETT_hour(Dataset):
246
+ def __init__(self, args, root_path, flag='train', size=None,
247
+ features='S', data_path='ETTh1.csv',
248
+ target='OT', scale=True, timeenc=0, freq='h', seasonal_patterns=None):
249
+ # size [seq_len, label_len, pred_len]
250
+ self.args = args
251
+ # info
252
+ if size == None:
253
+ self.seq_len = 24 * 4 * 4
254
+ self.label_len = 24 * 4
255
+ self.pred_len = 24 * 4
256
+ else:
257
+ self.seq_len = size[0]
258
+ self.label_len = size[1]
259
+ self.pred_len = size[2]
260
+ # init
261
+ assert flag in ['train', 'test', 'val']
262
+ type_map = {'train': 0, 'val': 1, 'test': 2}
263
+ self.set_type = type_map[flag]
264
+
265
+ self.features = features
266
+ self.target = target
267
+ self.scale = scale
268
+ self.timeenc = timeenc
269
+ self.freq = freq
270
+
271
+ self.root_path = root_path
272
+ self.data_path = data_path
273
+ self.__read_data__()
274
+
275
+ def __read_data__(self):
276
+ self.scaler = StandardScaler()
277
+ df_raw = pd.read_csv(os.path.join(self.root_path,
278
+ self.data_path))
279
+
280
+ border1s = [0, 12 * 30 * 24 - self.seq_len, 12 * 30 * 24 + 4 * 30 * 24 - self.seq_len]
281
+ border2s = [12 * 30 * 24, 12 * 30 * 24 + 4 * 30 * 24, 12 * 30 * 24 + 8 * 30 * 24]
282
+ border1 = border1s[self.set_type]
283
+ border2 = border2s[self.set_type]
284
+
285
+ if self.features == 'M' or self.features == 'MS':
286
+ cols_data = df_raw.columns[1:]
287
+ df_data = df_raw[cols_data]
288
+ elif self.features == 'S':
289
+ df_data = df_raw[[self.target]]
290
+
291
+ if self.scale:
292
+ train_data = df_data[border1s[0]:border2s[0]]
293
+ self.scaler.fit(train_data.values)
294
+ data = self.scaler.transform(df_data.values)
295
+ else:
296
+ data = df_data.values
297
+
298
+ df_stamp = df_raw[['date']][border1:border2]
299
+ df_stamp['date'] = pd.to_datetime(df_stamp.date)
300
+ if self.timeenc == 0:
301
+ df_stamp['month'] = df_stamp.date.apply(lambda row: row.month, 1)
302
+ df_stamp['day'] = df_stamp.date.apply(lambda row: row.day, 1)
303
+ df_stamp['weekday'] = df_stamp.date.apply(lambda row: row.weekday(), 1)
304
+ df_stamp['hour'] = df_stamp.date.apply(lambda row: row.hour, 1)
305
+ data_stamp = df_stamp.drop(['date'], 1).values
306
+ elif self.timeenc == 1:
307
+ data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq)
308
+ data_stamp = data_stamp.transpose(1, 0)
309
+
310
+ self.data_x = data[border1:border2]
311
+ self.data_y = data[border1:border2]
312
+
313
+ if self.set_type == 0 and self.args.augmentation_ratio > 0:
314
+ self.data_x, self.data_y, augmentation_tags = run_augmentation_single(self.data_x, self.data_y, self.args)
315
+
316
+ self.data_stamp = data_stamp
317
+
318
+ def __getitem__(self, index):
319
+ s_begin = index
320
+ s_end = s_begin + self.seq_len
321
+ r_begin = s_end - self.label_len
322
+ r_end = r_begin + self.label_len + self.pred_len
323
+
324
+ seq_x = self.data_x[s_begin:s_end]
325
+ seq_y = self.data_y[r_begin:r_end]
326
+ seq_x_mark = self.data_stamp[s_begin:s_end]
327
+ seq_y_mark = self.data_stamp[r_begin:r_end]
328
+
329
+ return seq_x, seq_y, seq_x_mark, seq_y_mark
330
+
331
+ def __len__(self):
332
+ return len(self.data_x) - self.seq_len - self.pred_len + 1
333
+
334
+ def inverse_transform(self, data):
335
+ return self.scaler.inverse_transform(data)
336
+
337
+
338
+ class Dataset_ETT_minute(Dataset):
339
+ def __init__(self, args, root_path, flag='train', size=None,
340
+ features='S', data_path='ETTm1.csv',
341
+ target='OT', scale=True, timeenc=0, freq='t', seasonal_patterns=None):
342
+ # size [seq_len, label_len, pred_len]
343
+ self.args = args
344
+ # info
345
+ if size == None:
346
+ self.seq_len = 24 * 4 * 4
347
+ self.label_len = 24 * 4
348
+ self.pred_len = 24 * 4
349
+ else:
350
+ self.seq_len = size[0]
351
+ self.label_len = size[1]
352
+ self.pred_len = size[2]
353
+ # init
354
+ assert flag in ['train', 'test', 'val']
355
+ type_map = {'train': 0, 'val': 1, 'test': 2}
356
+ self.set_type = type_map[flag]
357
+
358
+ self.features = features
359
+ self.target = target
360
+ self.scale = scale
361
+ self.timeenc = timeenc
362
+ self.freq = freq
363
+
364
+ self.root_path = root_path
365
+ self.data_path = data_path
366
+ self.__read_data__()
367
+
368
+ def __read_data__(self):
369
+ self.scaler = StandardScaler()
370
+ df_raw = pd.read_csv(os.path.join(self.root_path,
371
+ self.data_path))
372
+
373
+ border1s = [0, 12 * 30 * 24 * 4 - self.seq_len, 12 * 30 * 24 * 4 + 4 * 30 * 24 * 4 - self.seq_len]
374
+ border2s = [12 * 30 * 24 * 4, 12 * 30 * 24 * 4 + 4 * 30 * 24 * 4, 12 * 30 * 24 * 4 + 8 * 30 * 24 * 4]
375
+ border1 = border1s[self.set_type]
376
+ border2 = border2s[self.set_type]
377
+
378
+ if self.features == 'M' or self.features == 'MS':
379
+ cols_data = df_raw.columns[1:]
380
+ df_data = df_raw[cols_data]
381
+ elif self.features == 'S':
382
+ df_data = df_raw[[self.target]]
383
+
384
+ if self.scale:
385
+ train_data = df_data[border1s[0]:border2s[0]]
386
+ self.scaler.fit(train_data.values)
387
+ data = self.scaler.transform(df_data.values)
388
+ else:
389
+ data = df_data.values
390
+
391
+ df_stamp = df_raw[['date']][border1:border2]
392
+ df_stamp['date'] = pd.to_datetime(df_stamp.date)
393
+ if self.timeenc == 0:
394
+ df_stamp['month'] = df_stamp.date.apply(lambda row: row.month, 1)
395
+ df_stamp['day'] = df_stamp.date.apply(lambda row: row.day, 1)
396
+ df_stamp['weekday'] = df_stamp.date.apply(lambda row: row.weekday(), 1)
397
+ df_stamp['hour'] = df_stamp.date.apply(lambda row: row.hour, 1)
398
+ df_stamp['minute'] = df_stamp.date.apply(lambda row: row.minute, 1)
399
+ df_stamp['minute'] = df_stamp.minute.map(lambda x: x // 15)
400
+ data_stamp = df_stamp.drop(['date'], 1).values
401
+ elif self.timeenc == 1:
402
+ data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq)
403
+ data_stamp = data_stamp.transpose(1, 0)
404
+
405
+ self.data_x = data[border1:border2]
406
+ self.data_y = data[border1:border2]
407
+
408
+ if self.set_type == 0 and self.args.augmentation_ratio > 0:
409
+ self.data_x, self.data_y, augmentation_tags = run_augmentation_single(self.data_x, self.data_y, self.args)
410
+
411
+ self.data_stamp = data_stamp
412
+
413
+ def __getitem__(self, index):
414
+ s_begin = index
415
+ s_end = s_begin + self.seq_len
416
+ r_begin = s_end - self.label_len
417
+ r_end = r_begin + self.label_len + self.pred_len
418
+
419
+ seq_x = self.data_x[s_begin:s_end]
420
+ seq_y = self.data_y[r_begin:r_end]
421
+ seq_x_mark = self.data_stamp[s_begin:s_end]
422
+ seq_y_mark = self.data_stamp[r_begin:r_end]
423
+
424
+ return seq_x, seq_y, seq_x_mark, seq_y_mark
425
+
426
+ def __len__(self):
427
+ return len(self.data_x) - self.seq_len - self.pred_len + 1
428
+
429
+ def inverse_transform(self, data):
430
+ return self.scaler.inverse_transform(data)
431
+
432
+
433
+ class Dataset_Custom(Dataset):
434
+ def __init__(self, args, root_path, flag='train', size=None,
435
+ features='S', data_path='ETTh1.csv',
436
+ target='OT', scale=True, timeenc=0, freq='h', seasonal_patterns=None):
437
+ # size [seq_len, label_len, pred_len]
438
+ self.args = args
439
+ # info
440
+ if size == None:
441
+ self.seq_len = 24 * 4 * 4
442
+ self.label_len = 24 * 4
443
+ self.pred_len = 24 * 4
444
+ else:
445
+ self.seq_len = size[0]
446
+ self.label_len = size[1]
447
+ self.pred_len = size[2]
448
+ # init
449
+ assert flag in ['train', 'test', 'val']
450
+ type_map = {'train': 0, 'val': 1, 'test': 2}
451
+ self.set_type = type_map[flag]
452
+
453
+ self.features = features
454
+ self.target = target
455
+ self.scale = scale
456
+ self.timeenc = timeenc
457
+ self.freq = freq
458
+
459
+ self.root_path = root_path
460
+ self.data_path = data_path
461
+ self.__read_data__()
462
+
463
+ def __read_data__(self):
464
+ self.scaler = StandardScaler()
465
+ df_raw = pd.read_csv(os.path.join(self.root_path,
466
+ self.data_path))
467
+
468
+ '''
469
+ df_raw.columns: ['date', ...(other features), target feature]
470
+ '''
471
+ cols = list(df_raw.columns)
472
+ cols.remove(self.target)
473
+ cols.remove('date')
474
+ df_raw = df_raw[['date'] + cols + [self.target]]
475
+ num_train = int(len(df_raw) * 0.7)
476
+ num_test = int(len(df_raw) * 0.2)
477
+ num_vali = len(df_raw) - num_train - num_test
478
+ border1s = [0, num_train - self.seq_len, len(df_raw) - num_test - self.seq_len]
479
+ border2s = [num_train, num_train + num_vali, len(df_raw)]
480
+ border1 = border1s[self.set_type]
481
+ border2 = border2s[self.set_type]
482
+
483
+ if self.features == 'M' or self.features == 'MS':
484
+ cols_data = df_raw.columns[1:]
485
+ df_data = df_raw[cols_data]
486
+ elif self.features == 'S':
487
+ df_data = df_raw[[self.target]]
488
+
489
+ if self.scale:
490
+ train_data = df_data[border1s[0]:border2s[0]]
491
+ self.scaler.fit(train_data.values)
492
+ data = self.scaler.transform(df_data.values)
493
+ else:
494
+ data = df_data.values
495
+
496
+ df_stamp = df_raw[['date']][border1:border2]
497
+ df_stamp['date'] = pd.to_datetime(df_stamp.date)
498
+ if self.timeenc == 0:
499
+ df_stamp['month'] = df_stamp.date.apply(lambda row: row.month, 1)
500
+ df_stamp['day'] = df_stamp.date.apply(lambda row: row.day, 1)
501
+ df_stamp['weekday'] = df_stamp.date.apply(lambda row: row.weekday(), 1)
502
+ df_stamp['hour'] = df_stamp.date.apply(lambda row: row.hour, 1)
503
+ data_stamp = df_stamp.drop(['date'], 1).values
504
+ elif self.timeenc == 1:
505
+ data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq)
506
+ data_stamp = data_stamp.transpose(1, 0)
507
+
508
+ self.data_x = data[border1:border2]
509
+ self.data_y = data[border1:border2]
510
+
511
+ if self.set_type == 0 and self.args.augmentation_ratio > 0:
512
+ self.data_x, self.data_y, augmentation_tags = run_augmentation_single(self.data_x, self.data_y, self.args)
513
+
514
+ self.data_stamp = data_stamp
515
+
516
+ def __getitem__(self, index):
517
+ s_begin = index
518
+ s_end = s_begin + self.seq_len
519
+ r_begin = s_end - self.label_len
520
+ r_end = r_begin + self.label_len + self.pred_len
521
+
522
+ seq_x = self.data_x[s_begin:s_end]
523
+ seq_y = self.data_y[r_begin:r_end]
524
+ seq_x_mark = self.data_stamp[s_begin:s_end]
525
+ seq_y_mark = self.data_stamp[r_begin:r_end]
526
+
527
+ return seq_x, seq_y, seq_x_mark, seq_y_mark
528
+
529
+ def __len__(self):
530
+ return len(self.data_x) - self.seq_len - self.pred_len + 1
531
+
532
+ def inverse_transform(self, data):
533
+ return self.scaler.inverse_transform(data)
534
+
535
+
536
+ class Dataset_M4(Dataset):
537
+ def __init__(self, args, root_path, flag='pred', size=None,
538
+ features='S', data_path='ETTh1.csv',
539
+ target='OT', scale=False, inverse=False, timeenc=0, freq='15min',
540
+ seasonal_patterns='Yearly'):
541
+ # size [seq_len, label_len, pred_len]
542
+ # init
543
+ self.features = features
544
+ self.target = target
545
+ self.scale = scale
546
+ self.inverse = inverse
547
+ self.timeenc = timeenc
548
+ self.root_path = root_path
549
+
550
+ self.seq_len = size[0]
551
+ self.label_len = size[1]
552
+ self.pred_len = size[2]
553
+
554
+ self.seasonal_patterns = seasonal_patterns
555
+ self.history_size = M4Meta.history_size[seasonal_patterns]
556
+ self.window_sampling_limit = int(self.history_size * self.pred_len)
557
+ self.flag = flag
558
+
559
+ self.__read_data__()
560
+
561
+ def __read_data__(self):
562
+ # M4Dataset.initialize()
563
+ if self.flag == 'train':
564
+ dataset = M4Dataset.load(training=True, dataset_file=self.root_path)
565
+ else:
566
+ dataset = M4Dataset.load(training=False, dataset_file=self.root_path)
567
+ training_values = np.array(
568
+ [v[~np.isnan(v)] for v in
569
+ dataset.values[dataset.groups == self.seasonal_patterns]]) # split different frequencies
570
+ self.ids = np.array([i for i in dataset.ids[dataset.groups == self.seasonal_patterns]])
571
+ self.timeseries = [ts for ts in training_values]
572
+
573
+ def __getitem__(self, index):
574
+ insample = np.zeros((self.seq_len, 1))
575
+ insample_mask = np.zeros((self.seq_len, 1))
576
+ outsample = np.zeros((self.pred_len + self.label_len, 1))
577
+ outsample_mask = np.zeros((self.pred_len + self.label_len, 1)) # m4 dataset
578
+
579
+ sampled_timeseries = self.timeseries[index]
580
+ cut_point = np.random.randint(low=max(1, len(sampled_timeseries) - self.window_sampling_limit),
581
+ high=len(sampled_timeseries),
582
+ size=1)[0]
583
+
584
+ insample_window = sampled_timeseries[max(0, cut_point - self.seq_len):cut_point]
585
+ insample[-len(insample_window):, 0] = insample_window
586
+ insample_mask[-len(insample_window):, 0] = 1.0
587
+ outsample_window = sampled_timeseries[
588
+ cut_point - self.label_len:min(len(sampled_timeseries), cut_point + self.pred_len)]
589
+ outsample[:len(outsample_window), 0] = outsample_window
590
+ outsample_mask[:len(outsample_window), 0] = 1.0
591
+ return insample, outsample, insample_mask, outsample_mask
592
+
593
+ def __len__(self):
594
+ return len(self.timeseries)
595
+
596
+ def inverse_transform(self, data):
597
+ return self.scaler.inverse_transform(data)
598
+
599
+ def last_insample_window(self):
600
+ """
601
+ The last window of insample size of all timeseries.
602
+ This function does not support batching and does not reshuffle timeseries.
603
+
604
+ :return: Last insample window of all timeseries. Shape "timeseries, insample size"
605
+ """
606
+ insample = np.zeros((len(self.timeseries), self.seq_len))
607
+ insample_mask = np.zeros((len(self.timeseries), self.seq_len))
608
+ for i, ts in enumerate(self.timeseries):
609
+ ts_last_window = ts[-self.seq_len:]
610
+ insample[i, -len(ts):] = ts_last_window
611
+ insample_mask[i, -len(ts):] = 1.0
612
+ return insample, insample_mask
613
+
614
+
615
+ class PSMSegLoader(Dataset):
616
+ def __init__(self, args, root_path, win_size, step=1, flag="train"):
617
+ self.flag = flag
618
+ self.step = step
619
+ self.win_size = win_size
620
+ self.scaler = StandardScaler()
621
+ data = pd.read_csv(os.path.join(root_path, 'train.csv'))
622
+ data = data.values[:, 1:]
623
+ data = np.nan_to_num(data)
624
+ self.scaler.fit(data)
625
+ data = self.scaler.transform(data)
626
+ test_data = pd.read_csv(os.path.join(root_path, 'test.csv'))
627
+ test_data = test_data.values[:, 1:]
628
+ test_data = np.nan_to_num(test_data)
629
+ self.test = self.scaler.transform(test_data)
630
+ self.train = data
631
+ data_len = len(self.train)
632
+ self.val = self.train[(int)(data_len * 0.8):]
633
+ self.test_labels = pd.read_csv(os.path.join(root_path, 'test_label.csv')).values[:, 1:]
634
+ print("test:", self.test.shape)
635
+ print("train:", self.train.shape)
636
+
637
+ def __len__(self):
638
+ if self.flag == "train":
639
+ return (self.train.shape[0] - self.win_size) // self.step + 1
640
+ elif (self.flag == 'val'):
641
+ return (self.val.shape[0] - self.win_size) // self.step + 1
642
+ elif (self.flag == 'test'):
643
+ return (self.test.shape[0] - self.win_size) // self.step + 1
644
+ else:
645
+ return (self.test.shape[0] - self.win_size) // self.win_size + 1
646
+
647
+ def __getitem__(self, index):
648
+ index = index * self.step
649
+ if self.flag == "train":
650
+ return np.float32(self.train[index:index + self.win_size]), np.float32(self.test_labels[0:self.win_size])
651
+ elif (self.flag == 'val'):
652
+ return np.float32(self.val[index:index + self.win_size]), np.float32(self.test_labels[0:self.win_size])
653
+ elif (self.flag == 'test'):
654
+ return np.float32(self.test[index:index + self.win_size]), np.float32(
655
+ self.test_labels[index:index + self.win_size])
656
+ else:
657
+ return np.float32(self.test[
658
+ index // self.step * self.win_size:index // self.step * self.win_size + self.win_size]), np.float32(
659
+ self.test_labels[index // self.step * self.win_size:index // self.step * self.win_size + self.win_size])
660
+
661
+
662
+ class MSLSegLoader(Dataset):
663
+ def __init__(self, args, root_path, win_size, step=1, flag="train"):
664
+ self.flag = flag
665
+ self.step = step
666
+ self.win_size = win_size
667
+ self.scaler = StandardScaler()
668
+ data = np.load(os.path.join(root_path, "MSL_train.npy"))
669
+ self.scaler.fit(data)
670
+ data = self.scaler.transform(data)
671
+ test_data = np.load(os.path.join(root_path, "MSL_test.npy"))
672
+ self.test = self.scaler.transform(test_data)
673
+ self.train = data
674
+ data_len = len(self.train)
675
+ self.val = self.train[(int)(data_len * 0.8):]
676
+ self.test_labels = np.load(os.path.join(root_path, "MSL_test_label.npy"))
677
+ print("test:", self.test.shape)
678
+ print("train:", self.train.shape)
679
+
680
+ def __len__(self):
681
+ if self.flag == "train":
682
+ return (self.train.shape[0] - self.win_size) // self.step + 1
683
+ elif (self.flag == 'val'):
684
+ return (self.val.shape[0] - self.win_size) // self.step + 1
685
+ elif (self.flag == 'test'):
686
+ return (self.test.shape[0] - self.win_size) // self.step + 1
687
+ else:
688
+ return (self.test.shape[0] - self.win_size) // self.win_size + 1
689
+
690
+ def __getitem__(self, index):
691
+ index = index * self.step
692
+ if self.flag == "train":
693
+ return np.float32(self.train[index:index + self.win_size]), np.float32(self.test_labels[0:self.win_size])
694
+ elif (self.flag == 'val'):
695
+ return np.float32(self.val[index:index + self.win_size]), np.float32(self.test_labels[0:self.win_size])
696
+ elif (self.flag == 'test'):
697
+ return np.float32(self.test[index:index + self.win_size]), np.float32(
698
+ self.test_labels[index:index + self.win_size])
699
+ else:
700
+ return np.float32(self.test[
701
+ index // self.step * self.win_size:index // self.step * self.win_size + self.win_size]), np.float32(
702
+ self.test_labels[index // self.step * self.win_size:index // self.step * self.win_size + self.win_size])
703
+
704
+
705
+ class SMAPSegLoader(Dataset):
706
+ def __init__(self, args, root_path, win_size, step=1, flag="train"):
707
+ self.flag = flag
708
+ self.step = step
709
+ self.win_size = win_size
710
+ self.scaler = StandardScaler()
711
+ data = np.load(os.path.join(root_path, "SMAP_train.npy"))
712
+ self.scaler.fit(data)
713
+ data = self.scaler.transform(data)
714
+ test_data = np.load(os.path.join(root_path, "SMAP_test.npy"))
715
+ self.test = self.scaler.transform(test_data)
716
+ self.train = data
717
+ data_len = len(self.train)
718
+ self.val = self.train[(int)(data_len * 0.8):]
719
+ self.test_labels = np.load(os.path.join(root_path, "SMAP_test_label.npy"))
720
+ print("test:", self.test.shape)
721
+ print("train:", self.train.shape)
722
+
723
+ def __len__(self):
724
+
725
+ if self.flag == "train":
726
+ return (self.train.shape[0] - self.win_size) // self.step + 1
727
+ elif (self.flag == 'val'):
728
+ return (self.val.shape[0] - self.win_size) // self.step + 1
729
+ elif (self.flag == 'test'):
730
+ return (self.test.shape[0] - self.win_size) // self.step + 1
731
+ else:
732
+ return (self.test.shape[0] - self.win_size) // self.win_size + 1
733
+
734
+ def __getitem__(self, index):
735
+ index = index * self.step
736
+ if self.flag == "train":
737
+ return np.float32(self.train[index:index + self.win_size]), np.float32(self.test_labels[0:self.win_size])
738
+ elif (self.flag == 'val'):
739
+ return np.float32(self.val[index:index + self.win_size]), np.float32(self.test_labels[0:self.win_size])
740
+ elif (self.flag == 'test'):
741
+ return np.float32(self.test[index:index + self.win_size]), np.float32(
742
+ self.test_labels[index:index + self.win_size])
743
+ else:
744
+ return np.float32(self.test[
745
+ index // self.step * self.win_size:index // self.step * self.win_size + self.win_size]), np.float32(
746
+ self.test_labels[index // self.step * self.win_size:index // self.step * self.win_size + self.win_size])
747
+
748
+
749
+ class SMDSegLoader(Dataset):
750
+ def __init__(self, args, root_path, win_size, step=100, flag="train"):
751
+ self.flag = flag
752
+ self.step = step
753
+ self.win_size = win_size
754
+ self.scaler = StandardScaler()
755
+ data = np.load(os.path.join(root_path, "SMD_train.npy"))
756
+ self.scaler.fit(data)
757
+ data = self.scaler.transform(data)
758
+ test_data = np.load(os.path.join(root_path, "SMD_test.npy"))
759
+ self.test = self.scaler.transform(test_data)
760
+ self.train = data
761
+ data_len = len(self.train)
762
+ self.val = self.train[(int)(data_len * 0.8):]
763
+ self.test_labels = np.load(os.path.join(root_path, "SMD_test_label.npy"))
764
+
765
+ def __len__(self):
766
+ if self.flag == "train":
767
+ return (self.train.shape[0] - self.win_size) // self.step + 1
768
+ elif (self.flag == 'val'):
769
+ return (self.val.shape[0] - self.win_size) // self.step + 1
770
+ elif (self.flag == 'test'):
771
+ return (self.test.shape[0] - self.win_size) // self.step + 1
772
+ else:
773
+ return (self.test.shape[0] - self.win_size) // self.win_size + 1
774
+
775
+ def __getitem__(self, index):
776
+ index = index * self.step
777
+ if self.flag == "train":
778
+ return np.float32(self.train[index:index + self.win_size]), np.float32(self.test_labels[0:self.win_size])
779
+ elif (self.flag == 'val'):
780
+ return np.float32(self.val[index:index + self.win_size]), np.float32(self.test_labels[0:self.win_size])
781
+ elif (self.flag == 'test'):
782
+ return np.float32(self.test[index:index + self.win_size]), np.float32(
783
+ self.test_labels[index:index + self.win_size])
784
+ else:
785
+ return np.float32(self.test[
786
+ index // self.step * self.win_size:index // self.step * self.win_size + self.win_size]), np.float32(
787
+ self.test_labels[index // self.step * self.win_size:index // self.step * self.win_size + self.win_size])
788
+
789
+
790
+ class SWATSegLoader(Dataset):
791
+ def __init__(self, args, root_path, win_size, step=1, flag="train"):
792
+ self.flag = flag
793
+ self.step = step
794
+ self.win_size = win_size
795
+ self.scaler = StandardScaler()
796
+
797
+ train_data = pd.read_csv(os.path.join(root_path, 'swat_train2.csv'))
798
+ test_data = pd.read_csv(os.path.join(root_path, 'swat2.csv'))
799
+ labels = test_data.values[:, -1:]
800
+ train_data = train_data.values[:, :-1]
801
+ test_data = test_data.values[:, :-1]
802
+
803
+ self.scaler.fit(train_data)
804
+ train_data = self.scaler.transform(train_data)
805
+ test_data = self.scaler.transform(test_data)
806
+ self.train = train_data
807
+ self.test = test_data
808
+ data_len = len(self.train)
809
+ self.val = self.train[(int)(data_len * 0.8):]
810
+ self.test_labels = labels
811
+ print("test:", self.test.shape)
812
+ print("train:", self.train.shape)
813
+
814
+ def __len__(self):
815
+ """
816
+ Number of images in the object dataset.
817
+ """
818
+ if self.flag == "train":
819
+ return (self.train.shape[0] - self.win_size) // self.step + 1
820
+ elif (self.flag == 'val'):
821
+ return (self.val.shape[0] - self.win_size) // self.step + 1
822
+ elif (self.flag == 'test'):
823
+ return (self.test.shape[0] - self.win_size) // self.step + 1
824
+ else:
825
+ return (self.test.shape[0] - self.win_size) // self.win_size + 1
826
+
827
+ def __getitem__(self, index):
828
+ index = index * self.step
829
+ if self.flag == "train":
830
+ return np.float32(self.train[index:index + self.win_size]), np.float32(self.test_labels[0:self.win_size])
831
+ elif (self.flag == 'val'):
832
+ return np.float32(self.val[index:index + self.win_size]), np.float32(self.test_labels[0:self.win_size])
833
+ elif (self.flag == 'test'):
834
+ return np.float32(self.test[index:index + self.win_size]), np.float32(
835
+ self.test_labels[index:index + self.win_size])
836
+ else:
837
+ return np.float32(self.test[
838
+ index // self.step * self.win_size:index // self.step * self.win_size + self.win_size]), np.float32(
839
+ self.test_labels[index // self.step * self.win_size:index // self.step * self.win_size + self.win_size])
840
+
841
+
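Both segment loaders above share the same sliding-window arithmetic: for a split of length N, window size W and stride S, __len__ yields (N - W) // S + 1 windows and __getitem__ maps index i to the slice [i*S : i*S + W]. A minimal sketch of that indexing, using a toy array rather than the SMD/SWaT splits:

    import numpy as np

    N, W, S = 10, 4, 2                        # toy split length, window size, stride
    data = np.arange(N)

    num_windows = (N - W) // S + 1            # mirrors __len__ above -> 4 windows
    windows = [data[i * S: i * S + W] for i in range(num_windows)]
    # window starts are 0, 2, 4, 6; each window holds exactly W samples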
842
+ class UEAloader(Dataset):
843
+ """
844
+ Dataset class for datasets included in:
845
+ Time Series Classification Archive (www.timeseriesclassification.com)
846
+ Argument:
847
+ limit_size: float in (0, 1) for debug
848
+ Attributes:
849
+ all_df: (num_samples * seq_len, num_columns) dataframe indexed by integer indices, with multiple rows corresponding to the same index (sample).
850
+ Each row is a time step; Each column contains either metadata (e.g. timestamp) or a feature.
851
+ feature_df: (num_samples * seq_len, feat_dim) dataframe; contains the subset of columns of `all_df` which correspond to selected features
852
+ feature_names: names of columns contained in `feature_df` (same as feature_df.columns)
853
+ all_IDs: (num_samples,) series of IDs contained in `all_df`/`feature_df` (same as all_df.index.unique() )
854
+ labels_df: (num_samples, num_labels) pd.DataFrame of label(s) for each sample
855
+ max_seq_len: maximum sequence (time series) length. If None, script argument `max_seq_len` will be used.
856
+ (Moreover, script argument overrides this attribute)
857
+ """
858
+
859
+ def __init__(self, args, root_path, file_list=None, limit_size=None, flag=None):
860
+ self.args = args
861
+ self.root_path = root_path
862
+ self.flag = flag
863
+ self.all_df, self.labels_df = self.load_all(root_path, file_list=file_list, flag=flag)
864
+ self.all_IDs = self.all_df.index.unique() # all sample IDs (integer indices 0 ... num_samples-1)
865
+
866
+ if limit_size is not None:
867
+ if limit_size > 1:
868
+ limit_size = int(limit_size)
869
+ else: # interpret as proportion if in (0, 1]
870
+ limit_size = int(limit_size * len(self.all_IDs))
871
+ self.all_IDs = self.all_IDs[:limit_size]
872
+ self.all_df = self.all_df.loc[self.all_IDs]
873
+
874
+ # use all features
875
+ self.feature_names = self.all_df.columns
876
+ self.feature_df = self.all_df
877
+
878
+ # pre_process
879
+ normalizer = Normalizer()
880
+ self.feature_df = normalizer.normalize(self.feature_df)
881
+ print(len(self.all_IDs))
882
+
883
+ def load_all(self, root_path, file_list=None, flag=None):
884
+ """
885
+ Loads dataset files contained in `root_path` into a dataframe, optionally choosing from `file_list`
886
+ Args:
887
+ root_path: directory containing all individual .ts files
888
+ file_list: optionally, provide a list of file paths within `root_path` to consider.
889
+ Otherwise, entire `root_path` contents will be used.
890
+ Returns:
891
+ all_df: a single (possibly concatenated) dataframe with all data corresponding to specified files
892
+ labels_df: dataframe containing label(s) for each sample
893
+ """
894
+ # Select paths for training and evaluation
895
+ if file_list is None:
896
+ data_paths = glob.glob(os.path.join(root_path, '*')) # list of all paths
897
+ else:
898
+ data_paths = [os.path.join(root_path, p) for p in file_list]
899
+ if len(data_paths) == 0:
900
+ raise Exception('No files found using: {}'.format(os.path.join(root_path, '*')))
901
+ if flag is not None:
902
+ data_paths = list(filter(lambda x: re.search(flag, x), data_paths))
903
+ input_paths = [p for p in data_paths if os.path.isfile(p) and p.endswith('.ts')]
904
+ if len(input_paths) == 0:
905
+ pattern='*.ts'
906
+ raise Exception("No .ts files found using pattern: '{}'".format(pattern))
907
+
908
+ all_df, labels_df = self.load_single(input_paths[0]) # a single file contains dataset
909
+
910
+ return all_df, labels_df
911
+
912
+ def load_single(self, filepath):
913
+ df, labels = load_from_tsfile_to_dataframe(filepath, return_separate_X_and_y=True,
914
+ replace_missing_vals_with='NaN')
915
+ labels = pd.Series(labels, dtype="category")
916
+ self.class_names = labels.cat.categories
917
+ labels_df = pd.DataFrame(labels.cat.codes,
918
+ dtype=np.int8) # int8-32 gives an error when using nn.CrossEntropyLoss
919
+
920
+ lengths = df.applymap(
921
+ lambda x: len(x)).values # (num_samples, num_dimensions) array containing the length of each series
922
+
923
+ horiz_diffs = np.abs(lengths - np.expand_dims(lengths[:, 0], -1))
924
+
925
+ if np.sum(horiz_diffs) > 0: # if any row (sample) has varying length across dimensions
926
+ df = df.applymap(subsample)
927
+
928
+ lengths = df.applymap(lambda x: len(x)).values
929
+ vert_diffs = np.abs(lengths - np.expand_dims(lengths[0, :], 0))
930
+ if np.sum(vert_diffs) > 0: # if any column (dimension) has varying length across samples
931
+ self.max_seq_len = int(np.max(lengths[:, 0]))
932
+ else:
933
+ self.max_seq_len = lengths[0, 0]
934
+
935
+ # First create a (seq_len, feat_dim) dataframe for each sample, indexed by a single integer ("ID" of the sample)
936
+ # Then concatenate into a (num_samples * seq_len, feat_dim) dataframe, with multiple rows corresponding to the
937
+ # sample index (i.e. the same scheme as all datasets in this project)
938
+
939
+ df = pd.concat((pd.DataFrame({col: df.loc[row, col] for col in df.columns}).reset_index(drop=True).set_index(
940
+ pd.Series(lengths[row, 0] * [row])) for row in range(df.shape[0])), axis=0)
941
+
942
+ # Replace NaN values
943
+ grp = df.groupby(by=df.index)
944
+ df = grp.transform(interpolate_missing)
945
+
946
+ return df, labels_df
947
+
948
+ def instance_norm(self, case):
949
+ if self.root_path.count('EthanolConcentration') > 0: # special process for numerical stability
950
+ mean = case.mean(0, keepdim=True)
951
+ case = case - mean
952
+ stdev = torch.sqrt(torch.var(case, dim=1, keepdim=True, unbiased=False) + 1e-5)
953
+ case /= stdev
954
+ return case
955
+ else:
956
+ return case
957
+
958
+ def __getitem__(self, ind):
959
+ batch_x = self.feature_df.loc[self.all_IDs[ind]].values
960
+ labels = self.labels_df.loc[self.all_IDs[ind]].values
961
+ if self.flag == "TRAIN" and self.args.augmentation_ratio > 0:
962
+ num_samples = len(self.all_IDs)
963
+ num_columns = self.feature_df.shape[1]
964
+ seq_len = int(self.feature_df.shape[0] / num_samples)
965
+ batch_x = batch_x.reshape((1, seq_len, num_columns))
966
+ batch_x, labels, augmentation_tags = run_augmentation_single(batch_x, labels, self.args)
967
+
968
+ batch_x = batch_x.reshape((1 * seq_len, num_columns))
969
+
970
+ return self.instance_norm(torch.from_numpy(batch_x)), \
971
+ torch.from_numpy(labels)
972
+
973
+ def __len__(self):
974
+ return len(self.all_IDs)
975
+
976
+
977
+ class Dataset_Meteorology(Dataset):
978
+ def __init__(self, args, root_path, flag='train', size=None,
979
+ features='S', data_path='ETTh1.csv',
980
+ target='OT', scale=True, timeenc=0, freq='h', seasonal_patterns=None):
981
+ # size [seq_len, label_len, pred_len]
982
+ # info
983
+ if size is None:
984
+ self.seq_len = 24 * 4 * 4
985
+ self.label_len = 24 * 4
986
+ self.pred_len = 24 * 4
987
+ else:
988
+ self.seq_len = size[0]
989
+ self.label_len = size[1]
990
+ self.pred_len = size[2]
991
+ # init
992
+ assert flag in ['train', 'test', 'val']
993
+ type_map = {'train': 0, 'val': 1, 'test': 2}
994
+ self.set_type = type_map[flag]
995
+
996
+ self.features = features
997
+ self.target = target
998
+ self.scale = scale
999
+ self.timeenc = timeenc
1000
+ self.freq = freq
1001
+
1002
+ self.root_path = root_path
1003
+ self.data_path = data_path
1004
+ self.__read_data__()
1005
+ self.stations_num = self.data_x.shape[-1]
1006
+ self.tot_len = len(self.data_x) - self.seq_len - self.pred_len + 1
1007
+
1008
+ def __read_data__(self):
1009
+ self.scaler = StandardScaler()
1010
+ data = np.load(os.path.join(self.root_path, self.data_path)) # (L, S, 1)
1011
+ data = np.squeeze(data) # (L S)
1012
+ era5 = np.load(os.path.join(self.root_path, 'era5_norm.npy'))
1013
+
1014
+ # new add
1015
+ era5 = era5.reshape((era5.shape[0], 4, 9, era5.shape[-1]))
1016
+
1017
+ repeat_era5 = np.repeat(era5, 3, axis=0)[:len(data), :, :, :] # (L, 4, 9, S)
1018
+ repeat_era5 = repeat_era5.reshape(repeat_era5.shape[0], -1, repeat_era5.shape[3]) # (L, 36, S)
1019
+
1020
+ num_train = int(len(data) * 0.7)
1021
+ num_test = int(len(data) * 0.2)
1022
+ num_vali = len(data) - num_train - num_test
1023
+ border1s = [0, num_train - self.seq_len, len(data) - num_test - self.seq_len]
1024
+ border2s = [num_train, num_train + num_vali, len(data)]
1025
+ border1 = border1s[self.set_type]
1026
+ border2 = border2s[self.set_type]
1027
+
1028
+ if self.scale:
1029
+ train_data = data[border1s[0]:border2s[0]]
1030
+ self.scaler.fit(train_data)
1031
+ data = self.scaler.transform(data)
1032
+ else:
1033
+ pass
1034
+
1035
+ self.data_x = data[border1:border2]
1036
+ self.data_y = data[border1:border2]
1037
+ self.covariate = repeat_era5[border1:border2]
1038
+
1039
+ def __getitem__(self, index):
1040
+
1041
+ station_id = index // self.tot_len
1042
+ s_begin = index % self.tot_len
1043
+
1044
+ s_end = s_begin + self.seq_len
1045
+ r_begin = s_end - self.label_len
1046
+ r_end = r_begin + self.label_len + self.pred_len
1047
+
1048
+ seq_x = self.data_x[s_begin:s_end, station_id:station_id + 1]
1049
+ seq_y = self.data_y[r_begin:r_end, station_id:station_id + 1] # (L 1)
1050
+ t1 = self.covariate[s_begin:s_end, :, station_id:station_id + 1].squeeze()
1051
+ t2 = self.covariate[r_begin:r_end, :, station_id:station_id + 1].squeeze()
1052
+ seq_x = np.concatenate([t1, seq_x], axis=1)
1053
+ seq_y = np.concatenate([t2, seq_y], axis=1)
1054
+ seq_x_mark = torch.zeros((seq_x.shape[0], 1))
1055
+ seq_y_mark = torch.zeros((seq_y.shape[0], 1))
1056
+
1057
+ return seq_x, seq_y, seq_x_mark, seq_y_mark
1058
+
1059
+ def __len__(self):
1060
+ l = (len(self.data_x) - self.seq_len - self.pred_len + 1) * self.stations_num
1061
+ return l
1062
+
1063
+ def inverse_transform(self, data):
1064
+ return self.scaler.inverse_transform(data)
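Dataset_Meteorology flattens (station, time) pairs into a single index: __len__ returns tot_len * stations_num, and __getitem__ recovers the station via index // self.tot_len and the window start via index % self.tot_len. A small sketch of that decomposition with made-up sizes (the numbers below are illustrative only, not taken from the data):

    tot_len = 500        # windows per station, i.e. len(data_x) - seq_len - pred_len + 1
    stations_num = 3     # number of stations, i.e. data_x.shape[-1]

    dataset_len = tot_len * stations_num      # what __len__ returns (1500)
    for index in (0, 499, 500, 1499):
        station_id = index // tot_len         # which station this sample belongs to
        s_begin = index % tot_len             # window start within that station
        print(index, station_id, s_begin)
    # 0 -> station 0, start 0; 499 -> station 0, start 499;
    # 500 -> station 1, start 0; 1499 -> station 2, start 499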
data_provider/m4.py ADDED
@@ -0,0 +1,138 @@
1
+ # This source code is provided for the purposes of scientific reproducibility
2
+ # under the following limited license from Element AI Inc. The code is an
3
+ # implementation of the N-BEATS model (Oreshkin et al., N-BEATS: Neural basis
4
+ # expansion analysis for interpretable time series forecasting,
5
+ # https://arxiv.org/abs/1905.10437). The copyright to the source code is
6
+ # licensed under the Creative Commons - Attribution-NonCommercial 4.0
7
+ # International license (CC BY-NC 4.0):
8
+ # https://creativecommons.org/licenses/by-nc/4.0/. Any commercial use (whether
9
+ # for the benefit of third parties or internally in production) requires an
10
+ # explicit license. The subject-matter of the N-BEATS model and associated
11
+ # materials are the property of Element AI Inc. and may be subject to patent
12
+ # protection. No license to patents is granted hereunder (whether express or
13
+ # implied). Copyright © 2020 Element AI Inc. All rights reserved.
14
+
15
+ """
16
+ M4 Dataset
17
+ """
18
+ import logging
19
+ import os
20
+ from collections import OrderedDict
21
+ from dataclasses import dataclass
22
+ from glob import glob
23
+
24
+ import numpy as np
25
+ import pandas as pd
26
+ import patoolib
27
+ from tqdm import tqdm
28
30
+ import pathlib
31
+ import sys
32
+ from urllib import request
33
+
34
+
35
+ def url_file_name(url: str) -> str:
36
+ """
37
+ Extract file name from url.
38
+
39
+ :param url: URL to extract file name from.
40
+ :return: File name.
41
+ """
42
+ return url.split('/')[-1] if len(url) > 0 else ''
43
+
44
+
45
+ def download(url: str, file_path: str) -> None:
46
+ """
47
+ Download a file to the given path.
48
+
49
+ :param url: URL to download
50
+ :param file_path: Where to download the content.
51
+ """
52
+
53
+ def progress(count, block_size, total_size):
54
+ progress_pct = float(count * block_size) / float(total_size) * 100.0
55
+ sys.stdout.write('\rDownloading {} to {} {:.1f}%'.format(url, file_path, progress_pct))
56
+ sys.stdout.flush()
57
+
58
+ if not os.path.isfile(file_path):
59
+ opener = request.build_opener()
60
+ opener.addheaders = [('User-agent', 'Mozilla/5.0')]
61
+ request.install_opener(opener)
62
+ pathlib.Path(os.path.dirname(file_path)).mkdir(parents=True, exist_ok=True)
63
+ f, _ = request.urlretrieve(url, file_path, progress)
64
+ sys.stdout.write('\n')
65
+ sys.stdout.flush()
66
+ file_info = os.stat(f)
67
+ logging.info(f'Successfully downloaded {os.path.basename(file_path)} {file_info.st_size} bytes.')
68
+ else:
69
+ file_info = os.stat(file_path)
70
+ logging.info(f'File already exists: {file_path} {file_info.st_size} bytes.')
71
+
72
+
73
+ @dataclass()
74
+ class M4Dataset:
75
+ ids: np.ndarray
76
+ groups: np.ndarray
77
+ frequencies: np.ndarray
78
+ horizons: np.ndarray
79
+ values: np.ndarray
80
+
81
+ @staticmethod
82
+ def load(training: bool = True, dataset_file: str = '../dataset/m4') -> 'M4Dataset':
83
+ """
84
+ Load cached dataset.
85
+
86
+ :param training: Load training part if training is True, test part otherwise.
87
+ """
88
+ info_file = os.path.join(dataset_file, 'M4-info.csv')
89
+ train_cache_file = os.path.join(dataset_file, 'training.npz')
90
+ test_cache_file = os.path.join(dataset_file, 'test.npz')
91
+ m4_info = pd.read_csv(info_file)
92
+ return M4Dataset(ids=m4_info.M4id.values,
93
+ groups=m4_info.SP.values,
94
+ frequencies=m4_info.Frequency.values,
95
+ horizons=m4_info.Horizon.values,
96
+ values=np.load(
97
+ train_cache_file if training else test_cache_file,
98
+ allow_pickle=True))
99
+
100
+
101
+ @dataclass()
102
+ class M4Meta:
103
+ seasonal_patterns = ['Yearly', 'Quarterly', 'Monthly', 'Weekly', 'Daily', 'Hourly']
104
+ horizons = [6, 8, 18, 13, 14, 48]
105
+ frequencies = [1, 4, 12, 1, 1, 24]
106
+ horizons_map = {
107
+ 'Yearly': 6,
108
+ 'Quarterly': 8,
109
+ 'Monthly': 18,
110
+ 'Weekly': 13,
111
+ 'Daily': 14,
112
+ 'Hourly': 48
113
+ } # different predict length
114
+ frequency_map = {
115
+ 'Yearly': 1,
116
+ 'Quarterly': 4,
117
+ 'Monthly': 12,
118
+ 'Weekly': 1,
119
+ 'Daily': 1,
120
+ 'Hourly': 24
121
+ }
122
+ history_size = {
123
+ 'Yearly': 1.5,
124
+ 'Quarterly': 1.5,
125
+ 'Monthly': 1.5,
126
+ 'Weekly': 10,
127
+ 'Daily': 10,
128
+ 'Hourly': 10
129
+ } # from interpretable.gin
130
+
131
+
132
+ def load_m4_info() -> pd.DataFrame:
133
+ """
134
+ Load M4Info file.
135
+
136
+ :return: Pandas DataFrame of M4Info.
137
+ """
138
+ return pd.read_csv(INFO_FILE_PATH)  # note: INFO_FILE_PATH is not defined in this module; the path must be supplied elsewhere
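M4Meta is a plain lookup table: given a seasonal pattern, horizons_map gives the forecast length, frequency_map the sampling frequency, and history_size the ratio of input history to horizon. A hedged sketch of how a caller might derive a prediction and input length from it (the variable names here are illustrative, not from this file):

    pattern = 'Monthly'                                        # one of M4Meta.seasonal_patterns
    pred_len = M4Meta.horizons_map[pattern]                    # 18-step horizon
    freq = M4Meta.frequency_map[pattern]                       # 12 observations per period
    seq_len = int(M4Meta.history_size[pattern] * pred_len)     # 1.5 * 18 = 27 input steps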
data_provider/uea.py ADDED
@@ -0,0 +1,125 @@
1
+ import os
2
+ import numpy as np
3
+ import pandas as pd
4
+ import torch
5
+
6
+
7
+ def collate_fn(data, max_len=None):
8
+ """Build mini-batch tensors from a list of (X, mask) tuples. Mask input. Create
9
+ Args:
10
+ data: len(batch_size) list of tuples (X, y).
11
+ - X: torch tensor of shape (seq_length, feat_dim); variable seq_length.
12
+ - y: torch tensor of shape (num_labels,) : class indices or numerical targets
13
+ (for classification or regression, respectively). num_labels > 1 for multi-task models
14
+ max_len: global fixed sequence length. Used for architectures requiring fixed length input,
15
+ where the batch length cannot vary dynamically. Longer sequences are clipped, shorter are padded with 0s
16
+ Returns:
17
+ X: (batch_size, padded_length, feat_dim) torch tensor of padded features (input)
+ targets: (batch_size, num_labels) torch tensor of stacked labels
+ padding_masks: (batch_size, padded_length) boolean tensor, 1 means keep vector at this position, 0 means padding
22
+ """
23
+
24
+ batch_size = len(data)
25
+ features, labels = zip(*data)
26
+
27
+ # Stack and pad features and masks (convert 2D to 3D tensors, i.e. add batch dimension)
28
+ lengths = [X.shape[0] for X in features] # original sequence length for each time series
29
+ if max_len is None:
30
+ max_len = max(lengths)
31
+
32
+ X = torch.zeros(batch_size, max_len, features[0].shape[-1]) # (batch_size, padded_length, feat_dim)
33
+ for i in range(batch_size):
34
+ end = min(lengths[i], max_len)
35
+ X[i, :end, :] = features[i][:end, :]
36
+
37
+ targets = torch.stack(labels, dim=0) # (batch_size, num_labels)
38
+
39
+ padding_masks = padding_mask(torch.tensor(lengths, dtype=torch.int16),
40
+ max_len=max_len) # (batch_size, padded_length) boolean tensor, "1" means keep
41
+
42
+ return X, targets, padding_masks
43
+
44
+
45
+ def padding_mask(lengths, max_len=None):
46
+ """
47
+ Used to mask padded positions: creates a (batch_size, max_len) boolean mask from a tensor of sequence lengths,
48
+ where 1 means keep element at this position (time step)
49
+ """
50
+ batch_size = lengths.numel()
51
+ max_len = max_len or int(lengths.max())  # 'or' falls back to the longest sequence in the batch when max_len is None
52
+ return (torch.arange(0, max_len, device=lengths.device)
53
+ .type_as(lengths)
54
+ .repeat(batch_size, 1)
55
+ .lt(lengths.unsqueeze(1)))
56
+
57
+
58
+ class Normalizer(object):
59
+ """
60
+ Normalizes dataframe across ALL contained rows (time steps). Different from per-sample normalization.
61
+ """
62
+
63
+ def __init__(self, norm_type='standardization', mean=None, std=None, min_val=None, max_val=None):
64
+ """
65
+ Args:
66
+ norm_type: choose from:
67
+ "standardization", "minmax": normalizes dataframe across ALL contained rows (time steps)
68
+ "per_sample_std", "per_sample_minmax": normalizes each sample separately (i.e. across only its own rows)
69
+ mean, std, min_val, max_val: optional (num_feat,) Series of pre-computed values
70
+ """
71
+
72
+ self.norm_type = norm_type
73
+ self.mean = mean
74
+ self.std = std
75
+ self.min_val = min_val
76
+ self.max_val = max_val
77
+
78
+ def normalize(self, df):
79
+ """
80
+ Args:
81
+ df: input dataframe
82
+ Returns:
83
+ df: normalized dataframe
84
+ """
85
+ if self.norm_type == "standardization":
86
+ if self.mean is None:
87
+ self.mean = df.mean()
88
+ self.std = df.std()
89
+ return (df - self.mean) / (self.std + np.finfo(float).eps)
90
+
91
+ elif self.norm_type == "minmax":
92
+ if self.max_val is None:
93
+ self.max_val = df.max()
94
+ self.min_val = df.min()
95
+ return (df - self.min_val) / (self.max_val - self.min_val + np.finfo(float).eps)
96
+
97
+ elif self.norm_type == "per_sample_std":
98
+ grouped = df.groupby(by=df.index)
99
+ return (df - grouped.transform('mean')) / grouped.transform('std')
100
+
101
+ elif self.norm_type == "per_sample_minmax":
102
+ grouped = df.groupby(by=df.index)
103
+ min_vals = grouped.transform('min')
104
+ return (df - min_vals) / (grouped.transform('max') - min_vals + np.finfo(float).eps)
105
+
106
+ else:
107
+ raise (NameError(f'Normalize method "{self.norm_type}" not implemented'))
108
+
109
+
110
+ def interpolate_missing(y):
111
+ """
112
+ Replaces NaN values in pd.Series `y` using linear interpolation
113
+ """
114
+ if y.isna().any():
115
+ y = y.interpolate(method='linear', limit_direction='both')
116
+ return y
117
+
118
+
119
+ def subsample(y, limit=256, factor=2):
120
+ """
121
+ If a given Series is longer than `limit`, returns subsampled sequence by the specified integer factor
122
+ """
123
+ if len(y) > limit:
124
+ return y[::factor].reset_index(drop=True)
125
+ return y
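collate_fn and padding_mask above are intended to be plugged into a torch DataLoader for variable-length UEA samples; a short sketch of what they produce on a hand-built batch (toy tensors, not repository data):

    import torch

    batch = [
        (torch.randn(5, 3), torch.tensor([0])),   # (seq_len=5, feat_dim=3), label
        (torch.randn(8, 3), torch.tensor([1])),   # (seq_len=8, feat_dim=3), label
    ]
    X, targets, masks = collate_fn(batch, max_len=8)
    # X: (2, 8, 3) zero-padded features, targets: (2, 1) labels, masks: (2, 8) booleans
    # typical use: DataLoader(uea_dataset, batch_size=16,
    #                         collate_fn=lambda b: collate_fn(b, max_len=max_seq_len))
    # where max_seq_len stands in for the dataset's maximum sequence length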
exp/.DS_Store ADDED
Binary file (6.15 kB). View file
 
exp/__init__.py ADDED
File without changes
exp/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (142 Bytes). View file