SeungHyeok Jang committed
Commit 30e7412 · Parent(s): 23760df

Add automation pipeline with GitHub Actions integration
Files changed:
- .DS_Store +0 -0
- .github/workflows/tide_scheduler.yml +122 -0
- app.py +3 -0
- automation/__init__.py +0 -0
- automation/data_collector.py +176 -0
- automation/data_processor.py +279 -0
- automation/internal_api.py +174 -0
- automation/prediction_updater.py +0 -0
- config.py +9 -0
- models/.DS_Store +0 -0
- models/Autoformer.py +0 -157
- models/Crossformer.py +0 -145
- models/DLinear.py +0 -110
- models/ETSformer.py +0 -110
- models/FEDformer.py +0 -176
- models/FiLM.py +0 -268
- models/FreTS.py +0 -118
- models/Informer.py +0 -147
- models/Koopa.py +0 -337
- models/LightTS.py +0 -165
- models/MICN.py +0 -221
- models/Mamba.py +0 -50
- models/MambaSimple.py +0 -162
- models/Nonstationary_Transformer.py +0 -218
- models/PatchTST.py +0 -227
- models/Pyraformer.py +0 -101
- models/Reformer.py +0 -132
- models/SCINet.py +0 -188
- models/SegRNN.py +0 -119
- models/TSMixer.py +0 -54
- models/TemporalFusionTransformer.py +0 -309
- models/TiDE.py +0 -145
- models/Transformer.py +0 -124
- models/iTransformer.py +0 -132
- requirements.txt +5 -1
.DS_Store
CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
.github/workflows/tide_scheduler.yml
ADDED
@@ -0,0 +1,122 @@
name: Tide Data Automation

on:
  schedule:
    # Collect and process data every 5 minutes
    - cron: '*/5 * * * *'

  workflow_dispatch:  # manual trigger option
    inputs:
      task_type:
        description: 'Task to run'
        required: true
        default: 'all'
        type: choice
        options:
          - all
          - collect_data
          - update_predictions

env:
  HF_SPACE_URL: https://alwaysgood-my-tide-env.hf.space

jobs:
  collect_and_process:
    runs-on: ubuntu-latest
    timeout-minutes: 5

    steps:
      - name: Check current time
        id: time_check
        run: |
          MINUTE=$(date +%M)
          echo "current_minute=$MINUTE" >> $GITHUB_OUTPUT

          # Run predictions only on 10-minute boundaries
          if [ $((MINUTE % 10)) -eq 0 ]; then
            echo "should_predict=true" >> $GITHUB_OUTPUT
          else
            echo "should_predict=false" >> $GITHUB_OUTPUT
          fi

      - name: Collect and Process Data (every 5 minutes)
        run: |
          response=$(curl -X POST "${{ env.HF_SPACE_URL }}/api/internal/collect_data" \
            -H "Authorization: Bearer ${{ secrets.INTERNAL_API_KEY }}" \
            -H "Content-Type: application/json" \
            -d '{
              "task": "collect_and_process",
              "timestamp": "'$(date -u +%Y-%m-%dT%H:%M:%S)'"
            }' \
            -w "\n%{http_code}" \
            -s)

          http_code=$(echo "$response" | tail -n1)
          body=$(echo "$response" | head -n-1)

          if [ "$http_code" != "200" ]; then
            echo "❌ Data collection failed with status $http_code"
            echo "Response: $body"
            exit 1
          fi

          echo "✅ Data collection successful"
          echo "$body" | jq '.'

      - name: Update Predictions (every 10 minutes)
        if: steps.time_check.outputs.should_predict == 'true'
        run: |
          response=$(curl -X POST "${{ env.HF_SPACE_URL }}/api/internal/update_predictions" \
            -H "Authorization: Bearer ${{ secrets.INTERNAL_API_KEY }}" \
            -H "Content-Type: application/json" \
            -d '{
              "task": "update_predictions",
              "timestamp": "'$(date -u +%Y-%m-%dT%H:%M:%S)'"
            }' \
            -w "\n%{http_code}" \
            -s)

          http_code=$(echo "$response" | tail -n1)
          body=$(echo "$response" | head -n-1)

          if [ "$http_code" != "200" ]; then
            echo "❌ Prediction update failed with status $http_code"
            echo "Response: $body"
            exit 1
          fi

          echo "✅ Predictions updated successfully"
          echo "$body" | jq '.'

      - name: Send notification on failure
        if: failure()
        run: |
          # Send a Slack, Discord, or email notification here
          echo "Pipeline failed at $(date)"
          # curl -X POST slack_webhook_url ...

  health_check:
    runs-on: ubuntu-latest
    needs: collect_and_process

    steps:
      - name: Verify System Health
        run: |
          response=$(curl -s "${{ env.HF_SPACE_URL }}/api/health")
          echo "$response" | jq '.'

          status=$(echo "$response" | jq -r '.status')
          if [ "$status" != "healthy" ]; then
            echo "⚠️ System is not healthy: $status"
          fi

      - name: Check Data Freshness
        run: |
          response=$(curl -s "${{ env.HF_SPACE_URL }}/api/internal/data_freshness")
          echo "$response" | jq '.'

          # Warn if the oldest station data is more than 15 minutes old
          minutes_old=$(echo "$response" | jq -r '.oldest_data_minutes')
          if [ "$minutes_old" -gt 15 ]; then
            echo "⚠️ Data is stale: ${minutes_old} minutes old"
          fi
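The two POST steps above are plain HTTP calls, so the same pipeline can be exercised from a developer machine without GitHub Actions. A minimal sketch, not part of this commit, assuming the requests package is installed and INTERNAL_API_KEY is exported locally:

# Local smoke test mirroring the workflow's curl call to the collection endpoint.
import os
from datetime import datetime, timezone

import requests

SPACE_URL = "https://alwaysgood-my-tide-env.hf.space"
headers = {"Authorization": f"Bearer {os.environ['INTERNAL_API_KEY']}"}
payload = {
    "task": "collect_and_process",
    "timestamp": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S"),
}

resp = requests.post(f"{SPACE_URL}/api/internal/collect_data",
                     headers=headers, json=payload, timeout=60)
resp.raise_for_status()
print(resp.json())  # should report stations_collected / records_saved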
app.py
CHANGED
@@ -32,6 +32,9 @@ app = FastAPI(
     version="1.0.0",
 )
 
+from automation.internal_api import register_internal_routes
+register_internal_routes(app)
+
 # 2. Define the FastAPI API endpoints
 @app.get("/api/health", tags=["Status"])
 def health():
automation/__init__.py
ADDED
File without changes
automation/data_collector.py
ADDED
@@ -0,0 +1,176 @@
"""
Data collection module
Collects real-time tide level data from the public KHOA API
"""

import aiohttp
import asyncio
from datetime import datetime, timedelta
import pandas as pd
import numpy as np
from typing import Dict, List, Optional
import logging
from config import STATIONS, KHOA_API_KEY

logger = logging.getLogger(__name__)

class DataCollector:
    """Real-time tide data collector"""

    def __init__(self):
        self.api_base_url = "http://www.khoa.go.kr/api/oceangrid/tideObsRecent/search.do"
        self.api_key = KHOA_API_KEY  # managed via environment variable
        self.stations = STATIONS

    async def collect_station_data(self, station_id: str) -> Dict:
        """Collect data for a single station"""

        params = {
            "ServiceKey": self.api_key,
            "ObsCode": station_id,
            "ResultType": "json",
            "DataType": "tideObs",  # real-time observation data
        }

        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(self.api_base_url, params=params, timeout=30) as response:
                    if response.status == 200:
                        data = await response.json()

                        # Parse the payload
                        if data.get("result", {}).get("data"):
                            observations = data["result"]["data"]

                            # Keep only the most recent observation
                            latest_obs = observations[0] if observations else None

                            if latest_obs:
                                return {
                                    "station_id": station_id,
                                    "observed_at": latest_obs.get("record_time"),
                                    "tide_level": float(latest_obs.get("tide_level", 0)),
                                    "air_temp": float(latest_obs.get("air_temp", 0)),
                                    "water_temp": float(latest_obs.get("water_temp", 0)),
                                    "air_pres": float(latest_obs.get("air_press", 1013)),
                                    "wind_dir": float(latest_obs.get("wind_dir", 0)),
                                    "wind_speed": float(latest_obs.get("wind_speed", 0)),
                                    "status": "success"
                                }

                    logger.warning(f"Failed to collect data for {station_id}: Status {response.status}")
                    return {"station_id": station_id, "status": "failed", "error": f"HTTP {response.status}"}

        except asyncio.TimeoutError:
            logger.error(f"Timeout collecting data for {station_id}")
            return {"station_id": station_id, "status": "timeout"}
        except Exception as e:
            logger.error(f"Error collecting data for {station_id}: {str(e)}")
            return {"station_id": station_id, "status": "error", "error": str(e)}

    async def collect_all_stations(self) -> List[Dict]:
        """Collect data from all stations in parallel"""

        tasks = [
            self.collect_station_data(station_id)
            for station_id in self.stations
        ]

        results = await asyncio.gather(*tasks)

        # Keep only the successful results
        valid_results = [r for r in results if r.get("status") == "success"]

        logger.info(f"Collected data from {len(valid_results)}/{len(self.stations)} stations")

        return valid_results

    async def collect_with_retry(self, max_retries: int = 3) -> List[Dict]:
        """Data collection with retry logic"""

        all_data = []
        failed_stations = list(self.stations)

        for attempt in range(max_retries):
            if not failed_stations:
                break

            logger.info(f"Collection attempt {attempt + 1}/{max_retries} for {len(failed_stations)} stations")

            tasks = [
                self.collect_station_data(station_id)
                for station_id in failed_stations
            ]

            results = await asyncio.gather(*tasks)

            # Split into successes and failures
            newly_succeeded = []
            still_failed = []

            for result in results:
                if result.get("status") == "success":
                    newly_succeeded.append(result)
                else:
                    still_failed.append(result["station_id"])

            all_data.extend(newly_succeeded)
            failed_stations = still_failed

            if failed_stations and attempt < max_retries - 1:
                # Wait before retrying
                await asyncio.sleep(2 ** attempt)  # exponential backoff

        if failed_stations:
            logger.warning(f"Failed to collect data from stations: {failed_stations}")

        return all_data

    def validate_data(self, data: Dict) -> bool:
        """Validate a collected record"""

        # Required fields
        required_fields = ["station_id", "observed_at", "tide_level"]
        if not all(field in data for field in required_fields):
            return False

        # Range check
        tide_level = data.get("tide_level", 0)
        if not -100 <= tide_level <= 1000:  # in cm
            logger.warning(f"Invalid tide level: {tide_level} for station {data['station_id']}")
            return False

        # Time check (reject data that is too old)
        try:
            obs_time = datetime.fromisoformat(data["observed_at"])
            if (datetime.now() - obs_time).total_seconds() > 3600:  # older than 1 hour
                logger.warning(f"Stale data for station {data['station_id']}: {data['observed_at']}")
                return False
        except Exception:
            return False

        return True


class MockDataCollector(DataCollector):
    """Mock data collector for testing"""

    async def collect_station_data(self, station_id: str) -> Dict:
        """Generate synthetic data"""

        # Random data instead of calling the real API
        import random

        base_tide = 300 + 200 * np.sin(datetime.now().timestamp() / 3600 * np.pi / 6)

        return {
            "station_id": station_id,
            "observed_at": datetime.now().isoformat(),
            "tide_level": base_tide + random.uniform(-10, 10),
            "air_temp": 20 + random.uniform(-5, 5),
            "water_temp": 18 + random.uniform(-3, 3),
            "air_pres": 1013 + random.uniform(-10, 10),
            "wind_dir": random.uniform(0, 360),
            "wind_speed": random.uniform(0, 20),
            "status": "success"
        }
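Because MockDataCollector overrides only collect_station_data, the rest of the collection path (retries, filtering, validation) can be exercised without a KHOA_API_KEY. A minimal local run, not part of this commit and assuming config.py is importable from the working directory:

import asyncio
from automation.data_collector import MockDataCollector

async def main():
    collector = MockDataCollector()
    records = await collector.collect_with_retry(max_retries=1)
    # Every synthetic record reports status "success" and should pass validate_data()
    for record in records:
        print(record["station_id"], round(record["tide_level"], 1),
              collector.validate_data(record))

asyncio.run(main())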
automation/data_processor.py
ADDED
@@ -0,0 +1,279 @@
"""
Data processing module
Missing-value handling, resampling, and storage of collected data
"""

import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from scipy import interpolate
from typing import List, Dict, Optional
import logging
from supabase_utils import get_supabase_client
from config import DATA_COLLECTION_CONFIG

logger = logging.getLogger(__name__)

class DataProcessor:
    """Data processing and storage"""

    def __init__(self):
        self.client = get_supabase_client()
        self.resample_interval = DATA_COLLECTION_CONFIG["resample_interval"]
        self.missing_threshold = DATA_COLLECTION_CONFIG["missing_threshold_minutes"]

    async def process_data(self, raw_data: List[Dict]) -> pd.DataFrame:
        """Process raw records"""

        if not raw_data:
            logger.warning("No data to process")
            return pd.DataFrame()

        # Convert to a DataFrame
        df = pd.DataFrame(raw_data)
        df['observed_at'] = pd.to_datetime(df['observed_at'])
        df.set_index('observed_at', inplace=True)

        # Process each station separately
        processed_frames = []

        for station_id in df['station_id'].unique():
            station_data = df[df['station_id'] == station_id].copy()

            # 1. Handle missing values
            station_data = self.handle_missing_values(station_data)

            # 2. Remove outliers
            station_data = self.remove_outliers(station_data)

            # 3. Resample to 5-minute intervals
            station_data = self.resample_data(station_data)

            # 4. Compute the residual
            station_data = self.calculate_residual(station_data, station_id)

            processed_frames.append(station_data)

        # Combine all stations
        if processed_frames:
            result = pd.concat(processed_frames, ignore_index=True)
            return result

        return pd.DataFrame()

    def handle_missing_values(self, df: pd.DataFrame) -> pd.DataFrame:
        """Handle missing values - gaps shorter than 10 minutes are interpolated (cosine-style)"""

        for column in ['tide_level', 'air_temp', 'air_pres', 'wind_speed', 'wind_dir']:
            if column not in df.columns:
                continue

            # Locate missing values
            missing_mask = df[column].isna()

            if not missing_mask.any():
                continue

            # Find runs of consecutive missing values
            missing_groups = self.find_missing_groups(missing_mask)

            for start_idx, end_idx in missing_groups:
                duration_minutes = (df.index[end_idx] - df.index[start_idx]).total_seconds() / 60

                if duration_minutes < self.missing_threshold:
                    # Shorter than 10 minutes: cosine-style interpolation (reflects tidal behaviour)
                    if column == 'tide_level':
                        df[column] = self.cosine_interpolation(df[column], start_idx, end_idx)
                    else:
                        # Other meteorological fields use linear interpolation
                        df[column].interpolate(method='linear', inplace=True)
                else:
                    # 10 minutes or longer: predicted values should be used (handled separately)
                    logger.warning(f"Long missing period ({duration_minutes:.1f} min) for {column}")
                    # Forward fill for now
                    df[column].fillna(method='ffill', inplace=True)

        return df

    def cosine_interpolation(self, series: pd.Series, start_idx: int, end_idx: int) -> pd.Series:
        """Cosine-style interpolation reflecting the periodic shape of the tide (implemented with PCHIP below)"""

        # Valid data points
        valid_mask = ~series.isna()
        valid_indices = np.where(valid_mask)[0]

        if len(valid_indices) < 2:
            return series

        # PCHIP interpolation (Piecewise Cubic Hermite Interpolating Polynomial)
        # produces a smooth curve while avoiding overshoot
        interp_func = interpolate.PchipInterpolator(
            valid_indices,
            series.iloc[valid_indices].values,
            extrapolate=False
        )

        # Interpolate the missing span
        missing_indices = np.arange(start_idx, end_idx + 1)
        interpolated_values = interp_func(missing_indices)

        # Write the results back
        result = series.copy()
        result.iloc[missing_indices] = interpolated_values

        return result

    def find_missing_groups(self, missing_mask: pd.Series) -> List[tuple]:
        """Find runs of consecutive missing values"""

        groups = []
        in_group = False
        start_idx = 0

        for i, is_missing in enumerate(missing_mask):
            if is_missing and not in_group:
                start_idx = i
                in_group = True
            elif not is_missing and in_group:
                groups.append((start_idx, i - 1))
                in_group = False

        if in_group:
            groups.append((start_idx, len(missing_mask) - 1))

        return groups

    def remove_outliers(self, df: pd.DataFrame) -> pd.DataFrame:
        """Remove outliers"""

        # Z-score method
        for column in ['tide_level', 'air_temp', 'air_pres']:
            if column not in df.columns:
                continue

            z_scores = np.abs((df[column] - df[column].mean()) / df[column].std())

            # Treat Z-score > 4 as an outlier
            outlier_mask = z_scores > 4

            if outlier_mask.any():
                logger.info(f"Removing {outlier_mask.sum()} outliers from {column}")
                df.loc[outlier_mask, column] = np.nan
                # Replace outliers by linear interpolation
                df[column].interpolate(method='linear', inplace=True)

        # Physical range check
        if 'tide_level' in df.columns:
            df.loc[df['tide_level'] < -100, 'tide_level'] = np.nan
            df.loc[df['tide_level'] > 1000, 'tide_level'] = np.nan
            df['tide_level'].interpolate(method='linear', inplace=True)

        return df

    def resample_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """Resample to 5-minute intervals"""

        # Ensure a datetime index
        if not isinstance(df.index, pd.DatetimeIndex):
            df.index = pd.to_datetime(df.index)

        # 5-minute mean resampling
        numeric_columns = df.select_dtypes(include=[np.number]).columns
        resampled = df[numeric_columns].resample(self.resample_interval).mean()

        # Categorical columns use the mode
        categorical_columns = df.select_dtypes(exclude=[np.number]).columns
        for col in categorical_columns:
            if col in df.columns:
                resampled[col] = df[col].resample(self.resample_interval).agg(lambda x: x.mode()[0] if not x.empty else None)

        # Restore station_id
        if 'station_id' not in resampled.columns and 'station_id' in df.columns:
            resampled['station_id'] = df['station_id'].iloc[0]

        return resampled.reset_index()

    def calculate_residual(self, df: pd.DataFrame, station_id: str) -> pd.DataFrame:
        """Residual calculation (observation - harmonic prediction)"""

        # Fetch the harmonic (astronomical) prediction - separate implementation required;
        # a simplified example is used here
        df['astronomical_tide'] = self.get_astronomical_tide(station_id, df.index)

        if 'tide_level' in df.columns:
            df['residual'] = df['tide_level'] - df['astronomical_tide']
        else:
            df['residual'] = 0

        return df

    def get_astronomical_tide(self, station_id: str, timestamps: pd.DatetimeIndex) -> np.ndarray:
        """Harmonic prediction (simplified example)"""

        # A real implementation would use the station's harmonic constants;
        # a simple sine wave stands in here
        hours = timestamps.hour + timestamps.minute / 60

        # Main tidal constituent (M2: 12.42-hour period)
        M2_period = 12.42
        tide = 200 * np.sin(2 * np.pi * hours / M2_period)

        # Add mean sea level
        tide += 300

        return tide

    async def save_to_database(self, df: pd.DataFrame) -> int:
        """Save processed data to Supabase"""

        if df.empty:
            return 0

        # Prepare records for insertion
        records = []
        for _, row in df.iterrows():
            record = {
                "station_id": row.get("station_id"),
                "observed_at": row.get("observed_at").isoformat() if pd.notna(row.get("observed_at")) else None,
                "tide_level": float(row.get("tide_level", 0)),
                "astronomical_tide": float(row.get("astronomical_tide", 0)),
                "residual": float(row.get("residual", 0)),
                "air_temp": float(row.get("air_temp", 0)),
                "air_pres": float(row.get("air_pres", 1013)),
                "wind_dir": float(row.get("wind_dir", 0)),
                "wind_speed": float(row.get("wind_speed", 0)),
                "interpolated": False,  # whether the value was interpolated
                "created_at": datetime.now().isoformat()
            }
            records.append(record)

        try:
            # Insert into Supabase
            response = self.client.table("tide_observations_processed").insert(records).execute()

            logger.info(f"Saved {len(records)} records to database")
            return len(records)

        except Exception as e:
            logger.error(f"Failed to save data: {str(e)}")
            return 0

    async def cleanup_old_data(self, days_to_keep: int = 7) -> int:
        """Clean up old raw data"""

        cutoff_date = datetime.now() - timedelta(days=days_to_keep)

        try:
            # Delete old records
            response = self.client.table("tide_observations_raw").delete().lt(
                "created_at", cutoff_date.isoformat()
            ).execute()

            deleted_count = len(response.data) if response.data else 0
            logger.info(f"Cleaned up {deleted_count} old records")

            return deleted_count

        except Exception as e:
            logger.error(f"Cleanup failed: {str(e)}")
            return 0
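The residual stored by this module is simply the observation minus the harmonic (astronomical) prediction, and the placeholder above uses a single M2 constituent, tide(h) = 300 + 200·sin(2πh/12.42) in cm. A standalone sketch of that relation on toy numbers, not part of this commit (real harmonic constants per station would replace the formula):

import numpy as np
import pandas as pd

timestamps = pd.date_range("2024-01-01 00:00", periods=6, freq="5min")
observed = pd.Series([310.0, 325.5, 338.2, 352.0, 361.7, 370.1], index=timestamps)  # toy values

# Same placeholder as DataProcessor.get_astronomical_tide: one M2 constituent plus mean sea level
hours = timestamps.hour + timestamps.minute / 60
astronomical = 300 + 200 * np.sin(2 * np.pi * hours / 12.42)

residual = observed.values - astronomical
print(pd.DataFrame({"observed": observed.values,
                    "astronomical": astronomical.round(1),
                    "residual": residual.round(1)}, index=timestamps))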
automation/internal_api.py
ADDED
@@ -0,0 +1,174 @@
"""
Internal API endpoints for automation
Called by the GitHub Actions workflow
"""

from fastapi import FastAPI, HTTPException, Header, Request  # FastAPI added: needed for the register_internal_routes type hint
from datetime import datetime, timedelta
import os
import asyncio
from typing import Optional
import logging
from config import INTERNAL_API_KEY

logger = logging.getLogger(__name__)

# Read the internal API key from the environment
INTERNAL_API_KEY = os.getenv("INTERNAL_API_KEY", "")

def verify_internal_api_key(authorization: str = Header(None)):
    """Verify the internal API key"""
    if not authorization or authorization != f"Bearer {INTERNAL_API_KEY}":
        raise HTTPException(status_code=401, detail="Unauthorized")
    return True

def register_internal_routes(app: FastAPI):
    """Register the internal API routes on the FastAPI app"""

    @app.post("/api/internal/collect_data", tags=["Internal"])
    async def collect_data_endpoint(
        request: Request,
        authorization: str = Header(None)
    ):
        """Data collection and processing endpoint"""
        verify_internal_api_key(authorization)

        try:
            # Kick off the data collection job
            from automation.data_collector import DataCollector
            from automation.data_processor import DataProcessor

            collector = DataCollector()
            processor = DataProcessor()

            # 1. Collect data from all stations
            collected_data = await collector.collect_all_stations()

            # 2. Process the data (missing values, resampling)
            processed_data = await processor.process_data(collected_data)

            # 3. Save to Supabase
            saved_count = await processor.save_to_database(processed_data)

            return {
                "success": True,
                "timestamp": datetime.now().isoformat(),
                "stations_collected": len(collected_data),
                "records_saved": saved_count,
                "message": f"Successfully collected and processed data for {len(collected_data)} stations"
            }

        except Exception as e:
            logger.error(f"Data collection failed: {str(e)}")
            raise HTTPException(status_code=500, detail=str(e))

    @app.post("/api/internal/update_predictions", tags=["Internal"])
    async def update_predictions_endpoint(
        request: Request,
        authorization: str = Header(None)
    ):
        """Prediction update endpoint"""
        verify_internal_api_key(authorization)

        try:
            from automation.prediction_updater import PredictionUpdater

            updater = PredictionUpdater()

            # 1. Update predictions for every station
            results = await updater.update_all_predictions()

            return {
                "success": True,
                "timestamp": datetime.now().isoformat(),
                "predictions_updated": results["updated_count"],
                "stations": results["stations"],
                "prediction_horizon": "72 hours",
                "message": f"Successfully updated predictions for {results['updated_count']} stations"
            }

        except Exception as e:
            logger.error(f"Prediction update failed: {str(e)}")
            raise HTTPException(status_code=500, detail=str(e))

    @app.get("/api/internal/data_freshness", tags=["Internal"])
    async def check_data_freshness(
        authorization: str = Header(None)
    ):
        """Data freshness check"""
        verify_internal_api_key(authorization)

        try:
            from supabase_utils import get_supabase_client

            client = get_supabase_client()

            # Check the latest observation time for each station
            freshness_report = {}

            for station_id in ["DT_0001", "DT_0002", "DT_0003", "DT_0004", "DT_0005"]:
                response = client.table("tide_observations_processed").select("observed_at").eq(
                    "station_id", station_id
                ).order("observed_at", desc=True).limit(1).execute()

                if response.data:
                    last_update = datetime.fromisoformat(response.data[0]["observed_at"])
                    minutes_old = (datetime.now() - last_update).total_seconds() / 60
                    freshness_report[station_id] = {
                        "last_update": last_update.isoformat(),
                        "minutes_old": round(minutes_old, 2),
                        "status": "fresh" if minutes_old < 10 else "stale"
                    }
                else:
                    freshness_report[station_id] = {
                        "last_update": None,
                        "minutes_old": None,
                        "status": "no_data"
                    }

            # Find the oldest data
            oldest_minutes = max(
                [v["minutes_old"] for v in freshness_report.values() if v["minutes_old"]],
                default=0
            )

            return {
                "timestamp": datetime.now().isoformat(),
                "oldest_data_minutes": round(oldest_minutes, 2),
                "stations": freshness_report,
                "overall_status": "healthy" if oldest_minutes < 15 else "warning"
            }

        except Exception as e:
            logger.error(f"Freshness check failed: {str(e)}")
            raise HTTPException(status_code=500, detail=str(e))

    @app.post("/api/internal/manual_trigger", tags=["Internal"])
    async def manual_trigger(
        task: str,
        authorization: str = Header(None)
    ):
        """Manually trigger a task"""
        verify_internal_api_key(authorization)

        if task == "collect_now":
            # Run data collection immediately
            result = await collect_data_endpoint(None, authorization)
            return result
        elif task == "predict_now":
            # Update predictions immediately
            result = await update_predictions_endpoint(None, authorization)
            return result
        elif task == "cleanup":
            # Clean up old data
            from automation.data_cleaner import cleanup_old_data
            deleted_count = await cleanup_old_data(days_to_keep=7)
            return {
                "success": True,
                "deleted_records": deleted_count,
                "message": f"Cleaned up {deleted_count} old records"
            }
        else:
            raise HTTPException(status_code=400, detail=f"Unknown task: {task}")

    logger.info("Internal API routes registered successfully")
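A quick way to confirm the bearer-token guard without going through GitHub Actions is FastAPI's TestClient. A sketch, not part of this commit; it assumes httpx is installed (TestClient depends on it) and that the Space's environment variables, including INTERNAL_API_KEY, are set before app.py is imported. Note that the authorized call runs the real collection pipeline, so only use it against a test configuration:

import os
from fastapi.testclient import TestClient

from app import app  # app.py calls register_internal_routes(app) at import time

client = TestClient(app)

# Without an Authorization header the guard should reject the request
assert client.post("/api/internal/collect_data").status_code == 401

# With the correct bearer token the guard passes and the collection pipeline runs
resp = client.post(
    "/api/internal/collect_data",
    headers={"Authorization": f"Bearer {os.environ['INTERNAL_API_KEY']}"},
)
print(resp.status_code, resp.json())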
automation/prediction_updater.py
ADDED
File without changes
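prediction_updater.py is committed as an empty file, yet update_predictions_endpoint above already imports a PredictionUpdater and reads "updated_count" and "stations" from its update_all_predictions() result. A minimal placeholder that satisfies that contract, offered as an assumption rather than the author's implementation (the actual model inference is not in this commit):

class PredictionUpdater:
    """Placeholder matching what /api/internal/update_predictions expects."""

    async def update_all_predictions(self) -> dict:
        # TODO: load the trained forecasting model, build input windows from
        # tide_observations_processed, and write 72-hour predictions to Supabase.
        return {"updated_count": 0, "stations": []}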
config.py
CHANGED
@@ -4,6 +4,8 @@ import os
 SUPABASE_URL = os.environ.get("SUPABASE_URL")
 SUPABASE_KEY = os.environ.get("SUPABASE_KEY")
 GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
+INTERNAL_API_KEY = os.environ.get("INTERNAL_API_KEY")
+KHOA_API_KEY = os.environ.get("KHOA_API_KEY")
 
 STATIONS = [
     "DT_0001", "DT_0065", "DT_0008", "DT_0067", "DT_0043", "DT_0002",
@@ -17,4 +19,11 @@ STATION_NAMES = {
     "DT_0037": "어청도", "DT_0043": "영흥도", "DT_0050": "태안", "DT_0051": "서천마량",
     "DT_0052": "인천송도", "DT_0065": "덕적도", "DT_0066": "향화도", "DT_0067": "안흥",
     "DT_0068": "위도"
+}
+
+DATA_COLLECTION_CONFIG = {
+    "raw_data_retention_days": 3,  # retention period for raw data
+    "processed_data_retention_days": 365,  # retention period for processed data
+    "resample_interval": "5T",  # resample to 5-minute intervals
+    "missing_threshold_minutes": 10,  # gap-length threshold for interpolation
 }
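A note on resample_interval: "5T" is pandas' offset alias for 5 minutes, and DataProcessor.resample_data passes it straight to DataFrame.resample. Newer pandas releases deprecate the single-letter "T" alias in favour of "min", so depending on the pinned pandas version "5min" may be the safer spelling. A small check, not part of this commit:

import pandas as pd

idx = pd.date_range("2024-01-01", periods=10, freq="min")
s = pd.Series(range(10), index=idx)

# "5min" and "5T" denote the same 5-minute bins; recent pandas warns on "T"
print(s.resample("5min").mean())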
models/.DS_Store
ADDED
Binary file (6.15 kB)
models/Autoformer.py
DELETED
@@ -1,157 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from layers.Embed import DataEmbedding, DataEmbedding_wo_pos
from layers.AutoCorrelation import AutoCorrelation, AutoCorrelationLayer
from layers.Autoformer_EncDec import Encoder, Decoder, EncoderLayer, DecoderLayer, my_Layernorm, series_decomp
import math
import numpy as np


class Model(nn.Module):
    """
    Autoformer is the first method to achieve the series-wise connection,
    with inherent O(LlogL) complexity
    Paper link: https://openreview.net/pdf?id=I55UqU-M11y
    """

    def __init__(self, configs):
        super(Model, self).__init__()
        self.task_name = configs.task_name
        self.seq_len = configs.seq_len
        self.label_len = configs.label_len
        self.pred_len = configs.pred_len

        # Decomp
        kernel_size = configs.moving_avg
        self.decomp = series_decomp(kernel_size)

        # Embedding
        self.enc_embedding = DataEmbedding_wo_pos(configs.enc_in, configs.d_model, configs.embed, configs.freq,
                                                  configs.dropout)
        # Encoder
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AutoCorrelationLayer(
                        AutoCorrelation(False, configs.factor, attention_dropout=configs.dropout,
                                        output_attention=False),
                        configs.d_model, configs.n_heads),
                    configs.d_model,
                    configs.d_ff,
                    moving_avg=configs.moving_avg,
                    dropout=configs.dropout,
                    activation=configs.activation
                ) for l in range(configs.e_layers)
            ],
            norm_layer=my_Layernorm(configs.d_model)
        )
        # Decoder
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            self.dec_embedding = DataEmbedding_wo_pos(configs.dec_in, configs.d_model, configs.embed, configs.freq,
                                                      configs.dropout)
            self.decoder = Decoder(
                [
                    DecoderLayer(
                        AutoCorrelationLayer(
                            AutoCorrelation(True, configs.factor, attention_dropout=configs.dropout,
                                            output_attention=False),
                            configs.d_model, configs.n_heads),
                        AutoCorrelationLayer(
                            AutoCorrelation(False, configs.factor, attention_dropout=configs.dropout,
                                            output_attention=False),
                            configs.d_model, configs.n_heads),
                        configs.d_model,
                        configs.c_out,
                        configs.d_ff,
                        moving_avg=configs.moving_avg,
                        dropout=configs.dropout,
                        activation=configs.activation,
                    )
                    for l in range(configs.d_layers)
                ],
                norm_layer=my_Layernorm(configs.d_model),
                projection=nn.Linear(configs.d_model, configs.c_out, bias=True)
            )
        if self.task_name == 'imputation':
            self.projection = nn.Linear(
                configs.d_model, configs.c_out, bias=True)
        if self.task_name == 'anomaly_detection':
            self.projection = nn.Linear(
                configs.d_model, configs.c_out, bias=True)
        if self.task_name == 'classification':
            self.act = F.gelu
            self.dropout = nn.Dropout(configs.dropout)
            self.projection = nn.Linear(
                configs.d_model * configs.seq_len, configs.num_class)

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        # decomp init
        mean = torch.mean(x_enc, dim=1).unsqueeze(
            1).repeat(1, self.pred_len, 1)
        zeros = torch.zeros([x_dec.shape[0], self.pred_len,
                             x_dec.shape[2]], device=x_enc.device)
        seasonal_init, trend_init = self.decomp(x_enc)
        # decoder input
        trend_init = torch.cat(
            [trend_init[:, -self.label_len:, :], mean], dim=1)
        seasonal_init = torch.cat(
            [seasonal_init[:, -self.label_len:, :], zeros], dim=1)
        # enc
        enc_out = self.enc_embedding(x_enc, x_mark_enc)
        enc_out, attns = self.encoder(enc_out, attn_mask=None)
        # dec
        dec_out = self.dec_embedding(seasonal_init, x_mark_dec)
        seasonal_part, trend_part = self.decoder(dec_out, enc_out, x_mask=None, cross_mask=None,
                                                 trend=trend_init)
        # final
        dec_out = trend_part + seasonal_part
        return dec_out

    def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
        # enc
        enc_out = self.enc_embedding(x_enc, x_mark_enc)
        enc_out, attns = self.encoder(enc_out, attn_mask=None)
        # final
        dec_out = self.projection(enc_out)
        return dec_out

    def anomaly_detection(self, x_enc):
        # enc
        enc_out = self.enc_embedding(x_enc, None)
        enc_out, attns = self.encoder(enc_out, attn_mask=None)
        # final
        dec_out = self.projection(enc_out)
        return dec_out

    def classification(self, x_enc, x_mark_enc):
        # enc
        enc_out = self.enc_embedding(x_enc, None)
        enc_out, attns = self.encoder(enc_out, attn_mask=None)

        # Output
        # the output transformer encoder/decoder embeddings don't include non-linearity
        output = self.act(enc_out)
        output = self.dropout(output)
        # zero-out padding embeddings
        output = output * x_mark_enc.unsqueeze(-1)
        # (batch_size, seq_length * d_model)
        output = output.reshape(output.shape[0], -1)
        output = self.projection(output)  # (batch_size, num_classes)
        return output

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
        if self.task_name == 'imputation':
            dec_out = self.imputation(
                x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
            return dec_out  # [B, L, D]
        if self.task_name == 'anomaly_detection':
            dec_out = self.anomaly_detection(x_enc)
            return dec_out  # [B, L, D]
        if self.task_name == 'classification':
            dec_out = self.classification(x_enc, x_mark_enc)
            return dec_out  # [B, N]
        return None
models/Crossformer.py
DELETED
@@ -1,145 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from layers.Crossformer_EncDec import scale_block, Encoder, Decoder, DecoderLayer
from layers.Embed import PatchEmbedding
from layers.SelfAttention_Family import AttentionLayer, FullAttention, TwoStageAttentionLayer
from models.PatchTST import FlattenHead


from math import ceil


class Model(nn.Module):
    """
    Paper link: https://openreview.net/pdf?id=vSVLM2j9eie
    """
    def __init__(self, configs):
        super(Model, self).__init__()
        self.enc_in = configs.enc_in
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.seg_len = 12
        self.win_size = 2
        self.task_name = configs.task_name

        # The padding operation to handle invisible segment length
        self.pad_in_len = ceil(1.0 * configs.seq_len / self.seg_len) * self.seg_len
        self.pad_out_len = ceil(1.0 * configs.pred_len / self.seg_len) * self.seg_len
        self.in_seg_num = self.pad_in_len // self.seg_len
        self.out_seg_num = ceil(self.in_seg_num / (self.win_size ** (configs.e_layers - 1)))
        self.head_nf = configs.d_model * self.out_seg_num

        # Embedding
        self.enc_value_embedding = PatchEmbedding(configs.d_model, self.seg_len, self.seg_len, self.pad_in_len - configs.seq_len, 0)
        self.enc_pos_embedding = nn.Parameter(
            torch.randn(1, configs.enc_in, self.in_seg_num, configs.d_model))
        self.pre_norm = nn.LayerNorm(configs.d_model)

        # Encoder
        self.encoder = Encoder(
            [
                scale_block(configs, 1 if l is 0 else self.win_size, configs.d_model, configs.n_heads, configs.d_ff,
                            1, configs.dropout,
                            self.in_seg_num if l is 0 else ceil(self.in_seg_num / self.win_size ** l), configs.factor
                            ) for l in range(configs.e_layers)
            ]
        )
        # Decoder
        self.dec_pos_embedding = nn.Parameter(
            torch.randn(1, configs.enc_in, (self.pad_out_len // self.seg_len), configs.d_model))

        self.decoder = Decoder(
            [
                DecoderLayer(
                    TwoStageAttentionLayer(configs, (self.pad_out_len // self.seg_len), configs.factor, configs.d_model, configs.n_heads,
                                           configs.d_ff, configs.dropout),
                    AttentionLayer(
                        FullAttention(False, configs.factor, attention_dropout=configs.dropout,
                                      output_attention=False),
                        configs.d_model, configs.n_heads),
                    self.seg_len,
                    configs.d_model,
                    configs.d_ff,
                    dropout=configs.dropout,
                    # activation=configs.activation,
                )
                for l in range(configs.e_layers + 1)
            ],
        )
        if self.task_name == 'imputation' or self.task_name == 'anomaly_detection':
            self.head = FlattenHead(configs.enc_in, self.head_nf, configs.seq_len,
                                    head_dropout=configs.dropout)
        elif self.task_name == 'classification':
            self.flatten = nn.Flatten(start_dim=-2)
            self.dropout = nn.Dropout(configs.dropout)
            self.projection = nn.Linear(
                self.head_nf * configs.enc_in, configs.num_class)

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        # embedding
        x_enc, n_vars = self.enc_value_embedding(x_enc.permute(0, 2, 1))
        x_enc = rearrange(x_enc, '(b d) seg_num d_model -> b d seg_num d_model', d=n_vars)
        x_enc += self.enc_pos_embedding
        x_enc = self.pre_norm(x_enc)
        enc_out, attns = self.encoder(x_enc)

        dec_in = repeat(self.dec_pos_embedding, 'b ts_d l d -> (repeat b) ts_d l d', repeat=x_enc.shape[0])
        dec_out = self.decoder(dec_in, enc_out)
        return dec_out

    def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
        # embedding
        x_enc, n_vars = self.enc_value_embedding(x_enc.permute(0, 2, 1))
        x_enc = rearrange(x_enc, '(b d) seg_num d_model -> b d seg_num d_model', d=n_vars)
        x_enc += self.enc_pos_embedding
        x_enc = self.pre_norm(x_enc)
        enc_out, attns = self.encoder(x_enc)

        dec_out = self.head(enc_out[-1].permute(0, 1, 3, 2)).permute(0, 2, 1)

        return dec_out

    def anomaly_detection(self, x_enc):
        # embedding
        x_enc, n_vars = self.enc_value_embedding(x_enc.permute(0, 2, 1))
        x_enc = rearrange(x_enc, '(b d) seg_num d_model -> b d seg_num d_model', d=n_vars)
        x_enc += self.enc_pos_embedding
        x_enc = self.pre_norm(x_enc)
        enc_out, attns = self.encoder(x_enc)

        dec_out = self.head(enc_out[-1].permute(0, 1, 3, 2)).permute(0, 2, 1)
        return dec_out

    def classification(self, x_enc, x_mark_enc):
        # embedding
        x_enc, n_vars = self.enc_value_embedding(x_enc.permute(0, 2, 1))

        x_enc = rearrange(x_enc, '(b d) seg_num d_model -> b d seg_num d_model', d=n_vars)
        x_enc += self.enc_pos_embedding
        x_enc = self.pre_norm(x_enc)
        enc_out, attns = self.encoder(x_enc)
        # Output from Non-stationary Transformer
        output = self.flatten(enc_out[-1].permute(0, 1, 3, 2))
        output = self.dropout(output)
        output = output.reshape(output.shape[0], -1)
        output = self.projection(output)
        return output

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
        if self.task_name == 'imputation':
            dec_out = self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
            return dec_out  # [B, L, D]
        if self.task_name == 'anomaly_detection':
            dec_out = self.anomaly_detection(x_enc)
            return dec_out  # [B, L, D]
        if self.task_name == 'classification':
            dec_out = self.classification(x_enc, x_mark_enc)
            return dec_out  # [B, N]
        return None
models/DLinear.py
DELETED
@@ -1,110 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from layers.Autoformer_EncDec import series_decomp


class Model(nn.Module):
    """
    Paper link: https://arxiv.org/pdf/2205.13504.pdf
    """

    def __init__(self, configs, individual=False):
        """
        individual: Bool, whether shared model among different variates.
        """
        super(Model, self).__init__()
        self.task_name = configs.task_name
        self.seq_len = configs.seq_len
        if self.task_name == 'classification' or self.task_name == 'anomaly_detection' or self.task_name == 'imputation':
            self.pred_len = configs.seq_len
        else:
            self.pred_len = configs.pred_len
        # Series decomposition block from Autoformer
        self.decompsition = series_decomp(configs.moving_avg)
        self.individual = individual
        self.channels = configs.enc_in

        if self.individual:
            self.Linear_Seasonal = nn.ModuleList()
            self.Linear_Trend = nn.ModuleList()

            for i in range(self.channels):
                self.Linear_Seasonal.append(
                    nn.Linear(self.seq_len, self.pred_len))
                self.Linear_Trend.append(
                    nn.Linear(self.seq_len, self.pred_len))

                self.Linear_Seasonal[i].weight = nn.Parameter(
                    (1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len]))
                self.Linear_Trend[i].weight = nn.Parameter(
                    (1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len]))
        else:
            self.Linear_Seasonal = nn.Linear(self.seq_len, self.pred_len)
            self.Linear_Trend = nn.Linear(self.seq_len, self.pred_len)

            self.Linear_Seasonal.weight = nn.Parameter(
                (1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len]))
            self.Linear_Trend.weight = nn.Parameter(
                (1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len]))

        if self.task_name == 'classification':
            self.projection = nn.Linear(
                configs.enc_in * configs.seq_len, configs.num_class)

    def encoder(self, x):
        seasonal_init, trend_init = self.decompsition(x)
        seasonal_init, trend_init = seasonal_init.permute(
            0, 2, 1), trend_init.permute(0, 2, 1)
        if self.individual:
            seasonal_output = torch.zeros([seasonal_init.size(0), seasonal_init.size(1), self.pred_len],
                                          dtype=seasonal_init.dtype).to(seasonal_init.device)
            trend_output = torch.zeros([trend_init.size(0), trend_init.size(1), self.pred_len],
                                       dtype=trend_init.dtype).to(trend_init.device)
            for i in range(self.channels):
                seasonal_output[:, i, :] = self.Linear_Seasonal[i](
                    seasonal_init[:, i, :])
                trend_output[:, i, :] = self.Linear_Trend[i](
                    trend_init[:, i, :])
        else:
            seasonal_output = self.Linear_Seasonal(seasonal_init)
            trend_output = self.Linear_Trend(trend_init)
        x = seasonal_output + trend_output
        return x.permute(0, 2, 1)

    def forecast(self, x_enc):
        # Encoder
        return self.encoder(x_enc)

    def imputation(self, x_enc):
        # Encoder
        return self.encoder(x_enc)

    def anomaly_detection(self, x_enc):
        # Encoder
        return self.encoder(x_enc)

    def classification(self, x_enc):
        # Encoder
        enc_out = self.encoder(x_enc)
        # Output
        # (batch_size, seq_length * d_model)
        output = enc_out.reshape(enc_out.shape[0], -1)
        # (batch_size, num_classes)
        output = self.projection(output)
        return output

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            dec_out = self.forecast(x_enc)
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
        if self.task_name == 'imputation':
            dec_out = self.imputation(x_enc)
            return dec_out  # [B, L, D]
        if self.task_name == 'anomaly_detection':
            dec_out = self.anomaly_detection(x_enc)
            return dec_out  # [B, L, D]
        if self.task_name == 'classification':
            dec_out = self.classification(x_enc)
            return dec_out  # [B, N]
        return None
models/ETSformer.py
DELETED
@@ -1,110 +0,0 @@
import torch
import torch.nn as nn
from layers.Embed import DataEmbedding
from layers.ETSformer_EncDec import EncoderLayer, Encoder, DecoderLayer, Decoder, Transform


class Model(nn.Module):
    """
    Paper link: https://arxiv.org/abs/2202.01381
    """

    def __init__(self, configs):
        super(Model, self).__init__()
        self.task_name = configs.task_name
        self.seq_len = configs.seq_len
        self.label_len = configs.label_len
        if self.task_name == 'classification' or self.task_name == 'anomaly_detection' or self.task_name == 'imputation':
            self.pred_len = configs.seq_len
        else:
            self.pred_len = configs.pred_len

        assert configs.e_layers == configs.d_layers, "Encoder and decoder layers must be equal"

        # Embedding
        self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq,
                                           configs.dropout)

        # Encoder
        self.encoder = Encoder(
            [
                EncoderLayer(
                    configs.d_model, configs.n_heads, configs.enc_in, configs.seq_len, self.pred_len, configs.top_k,
                    dim_feedforward=configs.d_ff,
                    dropout=configs.dropout,
                    activation=configs.activation,
                ) for _ in range(configs.e_layers)
            ]
        )
        # Decoder
        self.decoder = Decoder(
            [
                DecoderLayer(
                    configs.d_model, configs.n_heads, configs.c_out, self.pred_len,
                    dropout=configs.dropout,
                ) for _ in range(configs.d_layers)
            ],
        )
        self.transform = Transform(sigma=0.2)

        if self.task_name == 'classification':
            self.act = torch.nn.functional.gelu
            self.dropout = nn.Dropout(configs.dropout)
            self.projection = nn.Linear(configs.d_model * configs.seq_len, configs.num_class)

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        with torch.no_grad():
            if self.training:
                x_enc = self.transform.transform(x_enc)
        res = self.enc_embedding(x_enc, x_mark_enc)
        level, growths, seasons = self.encoder(res, x_enc, attn_mask=None)

        growth, season = self.decoder(growths, seasons)
        preds = level[:, -1:] + growth + season
        return preds

    def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
        res = self.enc_embedding(x_enc, x_mark_enc)
        level, growths, seasons = self.encoder(res, x_enc, attn_mask=None)
        growth, season = self.decoder(growths, seasons)
        preds = level[:, -1:] + growth + season
        return preds

    def anomaly_detection(self, x_enc):
        res = self.enc_embedding(x_enc, None)
        level, growths, seasons = self.encoder(res, x_enc, attn_mask=None)
        growth, season = self.decoder(growths, seasons)
        preds = level[:, -1:] + growth + season
        return preds

    def classification(self, x_enc, x_mark_enc):
        res = self.enc_embedding(x_enc, None)
        _, growths, seasons = self.encoder(res, x_enc, attn_mask=None)

        growths = torch.sum(torch.stack(growths, 0), 0)[:, :self.seq_len, :]
        seasons = torch.sum(torch.stack(seasons, 0), 0)[:, :self.seq_len, :]

        enc_out = growths + seasons
        output = self.act(enc_out)  # the output transformer encoder/decoder embeddings don't include non-linearity
        output = self.dropout(output)

        # Output
        output = output * x_mark_enc.unsqueeze(-1)  # zero-out padding embeddings
        output = output.reshape(output.shape[0], -1)  # (batch_size, seq_length * d_model)
        output = self.projection(output)  # (batch_size, num_classes)
        return output

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
        if self.task_name == 'imputation':
            dec_out = self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
            return dec_out  # [B, L, D]
        if self.task_name == 'anomaly_detection':
            dec_out = self.anomaly_detection(x_enc)
            return dec_out  # [B, L, D]
        if self.task_name == 'classification':
            dec_out = self.classification(x_enc, x_mark_enc)
            return dec_out  # [B, N]
        return None
models/FEDformer.py
DELETED
@@ -1,176 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from layers.Embed import DataEmbedding
from layers.AutoCorrelation import AutoCorrelationLayer
from layers.FourierCorrelation import FourierBlock, FourierCrossAttention
from layers.MultiWaveletCorrelation import MultiWaveletCross, MultiWaveletTransform
from layers.Autoformer_EncDec import Encoder, Decoder, EncoderLayer, DecoderLayer, my_Layernorm, series_decomp


class Model(nn.Module):
    """
    FEDformer performs the attention mechanism on frequency domain and achieved O(N) complexity
    Paper link: https://proceedings.mlr.press/v162/zhou22g.html
    """

    def __init__(self, configs, version='fourier', mode_select='random', modes=32):
        """
        version: str, for FEDformer, there are two versions to choose, options: [Fourier, Wavelets].
        mode_select: str, for FEDformer, there are two mode selection method, options: [random, low].
        modes: int, modes to be selected.
        """
        super(Model, self).__init__()
        self.task_name = configs.task_name
        self.seq_len = configs.seq_len
        self.label_len = configs.label_len
        self.pred_len = configs.pred_len

        self.version = version
        self.mode_select = mode_select
        self.modes = modes

        # Decomp
        self.decomp = series_decomp(configs.moving_avg)
        self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq,
                                           configs.dropout)
        self.dec_embedding = DataEmbedding(configs.dec_in, configs.d_model, configs.embed, configs.freq,
                                           configs.dropout)

        if self.version == 'Wavelets':
            encoder_self_att = MultiWaveletTransform(ich=configs.d_model, L=1, base='legendre')
            decoder_self_att = MultiWaveletTransform(ich=configs.d_model, L=1, base='legendre')
            decoder_cross_att = MultiWaveletCross(in_channels=configs.d_model,
                                                  out_channels=configs.d_model,
                                                  seq_len_q=self.seq_len // 2 + self.pred_len,
                                                  seq_len_kv=self.seq_len,
                                                  modes=self.modes,
                                                  ich=configs.d_model,
                                                  base='legendre',
                                                  activation='tanh')
        else:
            encoder_self_att = FourierBlock(in_channels=configs.d_model,
                                            out_channels=configs.d_model,
                                            seq_len=self.seq_len,
                                            modes=self.modes,
                                            mode_select_method=self.mode_select)
            decoder_self_att = FourierBlock(in_channels=configs.d_model,
                                            out_channels=configs.d_model,
                                            seq_len=self.seq_len // 2 + self.pred_len,
                                            modes=self.modes,
                                            mode_select_method=self.mode_select)
            decoder_cross_att = FourierCrossAttention(in_channels=configs.d_model,
                                                      out_channels=configs.d_model,
                                                      seq_len_q=self.seq_len // 2 + self.pred_len,
                                                      seq_len_kv=self.seq_len,
                                                      modes=self.modes,
                                                      mode_select_method=self.mode_select,
                                                      num_heads=configs.n_heads)
        # Encoder
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AutoCorrelationLayer(
                        encoder_self_att,  # instead of multi-head attention in transformer
                        configs.d_model, configs.n_heads),
                    configs.d_model,
                    configs.d_ff,
                    moving_avg=configs.moving_avg,
                    dropout=configs.dropout,
                    activation=configs.activation
                ) for l in range(configs.e_layers)
            ],
            norm_layer=my_Layernorm(configs.d_model)
        )
        # Decoder
        self.decoder = Decoder(
            [
                DecoderLayer(
                    AutoCorrelationLayer(
                        decoder_self_att,
                        configs.d_model, configs.n_heads),
                    AutoCorrelationLayer(
                        decoder_cross_att,
                        configs.d_model, configs.n_heads),
                    configs.d_model,
                    configs.c_out,
                    configs.d_ff,
                    moving_avg=configs.moving_avg,
                    dropout=configs.dropout,
                    activation=configs.activation,
                )
                for l in range(configs.d_layers)
            ],
            norm_layer=my_Layernorm(configs.d_model),
            projection=nn.Linear(configs.d_model, configs.c_out, bias=True)
        )

        if self.task_name == 'imputation':
            self.projection = nn.Linear(configs.d_model, configs.c_out, bias=True)
        if self.task_name == 'anomaly_detection':
            self.projection = nn.Linear(configs.d_model, configs.c_out, bias=True)
        if self.task_name == 'classification':
            self.act = F.gelu
            self.dropout = nn.Dropout(configs.dropout)
            self.projection = nn.Linear(configs.d_model * configs.seq_len, configs.num_class)

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        # decomp init
        mean = torch.mean(x_enc, dim=1).unsqueeze(1).repeat(1, self.pred_len, 1)
        seasonal_init, trend_init = self.decomp(x_enc)  # x - moving_avg, moving_avg
        # decoder input
        trend_init = torch.cat([trend_init[:, -self.label_len:, :], mean], dim=1)
        seasonal_init = F.pad(seasonal_init[:, -self.label_len:, :], (0, 0, 0, self.pred_len))
        # enc
        enc_out = self.enc_embedding(x_enc, x_mark_enc)
        dec_out = self.dec_embedding(seasonal_init, x_mark_dec)
        enc_out, attns = self.encoder(enc_out, attn_mask=None)
        # dec
        seasonal_part, trend_part = self.decoder(dec_out, enc_out, x_mask=None, cross_mask=None, trend=trend_init)
        # final
        dec_out = trend_part + seasonal_part
        return dec_out

    def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
        # enc
        enc_out = self.enc_embedding(x_enc, x_mark_enc)
        enc_out, attns = self.encoder(enc_out, attn_mask=None)
        # final
        dec_out = self.projection(enc_out)
        return dec_out

    def anomaly_detection(self, x_enc):
        # enc
        enc_out = self.enc_embedding(x_enc, None)
        enc_out, attns = self.encoder(enc_out, attn_mask=None)
        # final
        dec_out = self.projection(enc_out)
        return dec_out

    def classification(self, x_enc, x_mark_enc):
        # enc
        enc_out = self.enc_embedding(x_enc, None)
        enc_out, attns = self.encoder(enc_out, attn_mask=None)

        # Output
        output = self.act(enc_out)
        output = self.dropout(output)
        output = output * x_mark_enc.unsqueeze(-1)
        output = output.reshape(output.shape[0], -1)
        output = self.projection(output)
        return output

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
        if self.task_name == 'imputation':
            dec_out = self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
            return dec_out  # [B, L, D]
        if self.task_name == 'anomaly_detection':
            dec_out = self.anomaly_detection(x_enc)
            return dec_out  # [B, L, D]
        if self.task_name == 'classification':
            dec_out = self.classification(x_enc, x_mark_enc)
            return dec_out  # [B, N]
        return None
models/FiLM.py
DELETED
@@ -1,268 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from scipy import signal
from scipy import special as ss

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def transition(N):
    Q = np.arange(N, dtype=np.float64)
    R = (2 * Q + 1)[:, None]  # / theta
    j, i = np.meshgrid(Q, Q)
    A = np.where(i < j, -1, (-1.) ** (i - j + 1)) * R
    B = (-1.) ** Q[:, None] * R
    return A, B


class HiPPO_LegT(nn.Module):
    def __init__(self, N, dt=1.0, discretization='bilinear'):
        """
        N: the order of the HiPPO projection
        dt: discretization step size - should be roughly inverse to the length of the sequence
        """
        super(HiPPO_LegT, self).__init__()
        self.N = N
        A, B = transition(N)
        C = np.ones((1, N))
        D = np.zeros((1,))
        A, B, _, _, _ = signal.cont2discrete((A, B, C, D), dt=dt, method=discretization)

        B = B.squeeze(-1)

        self.register_buffer('A', torch.Tensor(A).to(device))
        self.register_buffer('B', torch.Tensor(B).to(device))
        vals = np.arange(0.0, 1.0, dt)
        self.register_buffer('eval_matrix', torch.Tensor(
            ss.eval_legendre(np.arange(N)[:, None], 1 - 2 * vals).T).to(device))

    def forward(self, inputs):
        """
        inputs : (length, ...)
        output : (length, ..., N) where N is the order of the HiPPO projection
        """
        c = torch.zeros(inputs.shape[:-1] + tuple([self.N])).to(device)
        cs = []
        for f in inputs.permute([-1, 0, 1]):
            f = f.unsqueeze(-1)
            new = f @ self.B.unsqueeze(0)
            c = F.linear(c, self.A) + new
            cs.append(c)
        return torch.stack(cs, dim=0)

    def reconstruct(self, c):
        return (self.eval_matrix @ c.unsqueeze(-1)).squeeze(-1)


class SpectralConv1d(nn.Module):
    def __init__(self, in_channels, out_channels, seq_len, ratio=0.5):
        """
        1D Fourier layer. It does FFT, linear transform, and Inverse FFT.
        """
        super(SpectralConv1d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.ratio = ratio
        self.modes = min(32, seq_len // 2)
        self.index = list(range(0, self.modes))

        self.scale = (1 / (in_channels * out_channels))
        self.weights_real = nn.Parameter(
            self.scale * torch.rand(in_channels, out_channels, len(self.index), dtype=torch.float))
        self.weights_imag = nn.Parameter(
            self.scale * torch.rand(in_channels, out_channels, len(self.index), dtype=torch.float))

    def compl_mul1d(self, order, x, weights_real, weights_imag):
        return torch.complex(torch.einsum(order, x.real, weights_real) - torch.einsum(order, x.imag, weights_imag),
                             torch.einsum(order, x.real, weights_imag) + torch.einsum(order, x.imag, weights_real))

    def forward(self, x):
        B, H, E, N = x.shape
        x_ft = torch.fft.rfft(x)
        out_ft = torch.zeros(B, H, self.out_channels, x.size(-1) // 2 + 1, device=x.device, dtype=torch.cfloat)
        a = x_ft[:, :, :, :self.modes]
        out_ft[:, :, :, :self.modes] = self.compl_mul1d("bjix,iox->bjox", a, self.weights_real, self.weights_imag)
        x = torch.fft.irfft(out_ft, n=x.size(-1))
        return x


class Model(nn.Module):
    """
    Paper link: https://arxiv.org/abs/2205.08897
    """
    def __init__(self, configs):
        super(Model, self).__init__()
        self.task_name = configs.task_name
        self.configs = configs
        self.seq_len = configs.seq_len
        self.label_len = configs.label_len
        self.pred_len = configs.seq_len if configs.pred_len == 0 else configs.pred_len

        self.seq_len_all = self.seq_len + self.label_len

        self.layers = configs.e_layers
        self.enc_in = configs.enc_in
        self.e_layers = configs.e_layers
        # b, s, f means b, f
        self.affine_weight = nn.Parameter(torch.ones(1, 1, configs.enc_in))
        self.affine_bias = nn.Parameter(torch.zeros(1, 1, configs.enc_in))

        self.multiscale = [1, 2, 4]
        self.window_size = [256]
        configs.ratio = 0.5
        self.legts = nn.ModuleList(
            [HiPPO_LegT(N=n, dt=1. / self.pred_len / i) for n in self.window_size for i in self.multiscale])
        self.spec_conv_1 = nn.ModuleList([SpectralConv1d(in_channels=n, out_channels=n,
                                                         seq_len=min(self.pred_len, self.seq_len),
                                                         ratio=configs.ratio) for n in
                                          self.window_size for _ in range(len(self.multiscale))])
        self.mlp = nn.Linear(len(self.multiscale) * len(self.window_size), 1)

        if self.task_name == 'imputation' or self.task_name == 'anomaly_detection':
            self.projection = nn.Linear(
                configs.d_model, configs.c_out, bias=True)
        if self.task_name == 'classification':
            self.act = F.gelu
            self.dropout = nn.Dropout(configs.dropout)
            self.projection = nn.Linear(
                configs.enc_in * configs.seq_len, configs.num_class)

    def forecast(self, x_enc, x_mark_enc, x_dec_true, x_mark_dec):
        # Normalization from Non-stationary Transformer
        means = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - means
        stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
        x_enc /= stdev

        x_enc = x_enc * self.affine_weight + self.affine_bias
        x_decs = []
        jump_dist = 0
        for i in range(0, len(self.multiscale) * len(self.window_size)):
            x_in_len = self.multiscale[i % len(self.multiscale)] * self.pred_len
            x_in = x_enc[:, -x_in_len:]
            legt = self.legts[i]
            x_in_c = legt(x_in.transpose(1, 2)).permute([1, 2, 3, 0])[:, :, :, jump_dist:]
            out1 = self.spec_conv_1[i](x_in_c)
            if self.seq_len >= self.pred_len:
                x_dec_c = out1.transpose(2, 3)[:, :, self.pred_len - 1 - jump_dist, :]
            else:
                x_dec_c = out1.transpose(2, 3)[:, :, -1, :]
            x_dec = x_dec_c @ legt.eval_matrix[-self.pred_len:, :].T
            x_decs.append(x_dec)
        x_dec = torch.stack(x_decs, dim=-1)
        x_dec = self.mlp(x_dec).squeeze(-1).permute(0, 2, 1)

        # De-Normalization from Non-stationary Transformer
        x_dec = x_dec - self.affine_bias
        x_dec = x_dec / (self.affine_weight + 1e-10)
        x_dec = x_dec * stdev
        x_dec = x_dec + means
        return x_dec

    def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
        # Normalization from Non-stationary Transformer
        means = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - means
        stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
        x_enc /= stdev

        x_enc = x_enc * self.affine_weight + self.affine_bias
        x_decs = []
        jump_dist = 0
        for i in range(0, len(self.multiscale) * len(self.window_size)):
            x_in_len = self.multiscale[i % len(self.multiscale)] * self.pred_len
            x_in = x_enc[:, -x_in_len:]
            legt = self.legts[i]
            x_in_c = legt(x_in.transpose(1, 2)).permute([1, 2, 3, 0])[:, :, :, jump_dist:]
            out1 = self.spec_conv_1[i](x_in_c)
            if self.seq_len >= self.pred_len:
                x_dec_c = out1.transpose(2, 3)[:, :, self.pred_len - 1 - jump_dist, :]
            else:
                x_dec_c = out1.transpose(2, 3)[:, :, -1, :]
            x_dec = x_dec_c @ legt.eval_matrix[-self.pred_len:, :].T
            x_decs.append(x_dec)
        x_dec = torch.stack(x_decs, dim=-1)
        x_dec = self.mlp(x_dec).squeeze(-1).permute(0, 2, 1)

        # De-Normalization from Non-stationary Transformer
        x_dec = x_dec - self.affine_bias
        x_dec = x_dec / (self.affine_weight + 1e-10)
        x_dec = x_dec * stdev
        x_dec = x_dec + means
        return x_dec

    def anomaly_detection(self, x_enc):
        # Normalization from Non-stationary Transformer
        means = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - means
        stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
        x_enc /= stdev

        x_enc = x_enc * self.affine_weight + self.affine_bias
        x_decs = []
        jump_dist = 0
        for i in range(0, len(self.multiscale) * len(self.window_size)):
            x_in_len = self.multiscale[i % len(self.multiscale)] * self.pred_len
            x_in = x_enc[:, -x_in_len:]
            legt = self.legts[i]
            x_in_c = legt(x_in.transpose(1, 2)).permute([1, 2, 3, 0])[:, :, :, jump_dist:]
            out1 = self.spec_conv_1[i](x_in_c)
            if self.seq_len >= self.pred_len:
                x_dec_c = out1.transpose(2, 3)[:, :, self.pred_len - 1 - jump_dist, :]
            else:
                x_dec_c = out1.transpose(2, 3)[:, :, -1, :]
            x_dec = x_dec_c @ legt.eval_matrix[-self.pred_len:, :].T
            x_decs.append(x_dec)
        x_dec = torch.stack(x_decs, dim=-1)
        x_dec = self.mlp(x_dec).squeeze(-1).permute(0, 2, 1)

        # De-Normalization from Non-stationary Transformer
        x_dec = x_dec - self.affine_bias
        x_dec = x_dec / (self.affine_weight + 1e-10)
        x_dec = x_dec * stdev
        x_dec = x_dec + means
        return x_dec

    def classification(self, x_enc, x_mark_enc):
        x_enc = x_enc * self.affine_weight + self.affine_bias
        x_decs = []
        jump_dist = 0
        for i in range(0, len(self.multiscale) * len(self.window_size)):
            x_in_len = self.multiscale[i % len(self.multiscale)] * self.pred_len
            x_in = x_enc[:, -x_in_len:]
            legt = self.legts[i]
            x_in_c = legt(x_in.transpose(1, 2)).permute([1, 2, 3, 0])[:, :, :, jump_dist:]
            out1 = self.spec_conv_1[i](x_in_c)
            if self.seq_len >= self.pred_len:
                x_dec_c = out1.transpose(2, 3)[:, :, self.pred_len - 1 - jump_dist, :]
            else:
                x_dec_c = out1.transpose(2, 3)[:, :, -1, :]
            x_dec = x_dec_c @ legt.eval_matrix[-self.pred_len:, :].T
            x_decs.append(x_dec)
        x_dec = torch.stack(x_decs, dim=-1)
        x_dec = self.mlp(x_dec).squeeze(-1).permute(0, 2, 1)

        # Output from Non-stationary Transformer
        output = self.act(x_dec)
        output = self.dropout(output)
        output = output * x_mark_enc.unsqueeze(-1)
        output = output.reshape(output.shape[0], -1)
        output = self.projection(output)
        return output

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
        if self.task_name == 'imputation':
            dec_out = self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
            return dec_out  # [B, L, D]
        if self.task_name == 'anomaly_detection':
            dec_out = self.anomaly_detection(x_enc)
            return dec_out  # [B, L, D]
        if self.task_name == 'classification':
            dec_out = self.classification(x_enc, x_mark_enc)
            return dec_out  # [B, N]
        return None
models/FreTS.py
DELETED
@@ -1,118 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class Model(nn.Module):
    """
    Paper link: https://arxiv.org/pdf/2311.06184.pdf
    """

    def __init__(self, configs):
        super(Model, self).__init__()
        self.task_name = configs.task_name
        if self.task_name == 'classification' or self.task_name == 'anomaly_detection' or self.task_name == 'imputation':
            self.pred_len = configs.seq_len
        else:
            self.pred_len = configs.pred_len
        self.embed_size = 128  # embed_size
        self.hidden_size = 256  # hidden_size
        self.pred_len = configs.pred_len
        self.feature_size = configs.enc_in  # channels
        self.seq_len = configs.seq_len
        self.channel_independence = configs.channel_independence
        self.sparsity_threshold = 0.01
        self.scale = 0.02
        self.embeddings = nn.Parameter(torch.randn(1, self.embed_size))
        self.r1 = nn.Parameter(self.scale * torch.randn(self.embed_size, self.embed_size))
        self.i1 = nn.Parameter(self.scale * torch.randn(self.embed_size, self.embed_size))
        self.rb1 = nn.Parameter(self.scale * torch.randn(self.embed_size))
        self.ib1 = nn.Parameter(self.scale * torch.randn(self.embed_size))
        self.r2 = nn.Parameter(self.scale * torch.randn(self.embed_size, self.embed_size))
        self.i2 = nn.Parameter(self.scale * torch.randn(self.embed_size, self.embed_size))
        self.rb2 = nn.Parameter(self.scale * torch.randn(self.embed_size))
        self.ib2 = nn.Parameter(self.scale * torch.randn(self.embed_size))

        self.fc = nn.Sequential(
            nn.Linear(self.seq_len * self.embed_size, self.hidden_size),
            nn.LeakyReLU(),
            nn.Linear(self.hidden_size, self.pred_len)
        )

    # dimension extension
    def tokenEmb(self, x):
        # x: [Batch, Input length, Channel]
        x = x.permute(0, 2, 1)
        x = x.unsqueeze(3)
        # N*T*1 x 1*D = N*T*D
        y = self.embeddings
        return x * y

    # frequency temporal learner
    def MLP_temporal(self, x, B, N, L):
        # [B, N, T, D]
        x = torch.fft.rfft(x, dim=2, norm='ortho')  # FFT on L dimension
        y = self.FreMLP(B, N, L, x, self.r2, self.i2, self.rb2, self.ib2)
        x = torch.fft.irfft(y, n=self.seq_len, dim=2, norm="ortho")
        return x

    # frequency channel learner
    def MLP_channel(self, x, B, N, L):
        # [B, N, T, D]
        x = x.permute(0, 2, 1, 3)
        # [B, T, N, D]
        x = torch.fft.rfft(x, dim=2, norm='ortho')  # FFT on N dimension
        y = self.FreMLP(B, L, N, x, self.r1, self.i1, self.rb1, self.ib1)
        x = torch.fft.irfft(y, n=self.feature_size, dim=2, norm="ortho")
        x = x.permute(0, 2, 1, 3)
        # [B, N, T, D]
        return x

    # frequency-domain MLPs
    # dimension: FFT along the dimension, r: the real part of weights, i: the imaginary part of weights
    # rb: the real part of bias, ib: the imaginary part of bias
    def FreMLP(self, B, nd, dimension, x, r, i, rb, ib):
        o1_real = torch.zeros([B, nd, dimension // 2 + 1, self.embed_size],
                              device=x.device)
        o1_imag = torch.zeros([B, nd, dimension // 2 + 1, self.embed_size],
                              device=x.device)

        o1_real = F.relu(
            torch.einsum('bijd,dd->bijd', x.real, r) - \
            torch.einsum('bijd,dd->bijd', x.imag, i) + \
            rb
        )

        o1_imag = F.relu(
            torch.einsum('bijd,dd->bijd', x.imag, r) + \
            torch.einsum('bijd,dd->bijd', x.real, i) + \
            ib
        )

        y = torch.stack([o1_real, o1_imag], dim=-1)
        y = F.softshrink(y, lambd=self.sparsity_threshold)
        y = torch.view_as_complex(y)
        return y

    def forecast(self, x_enc):
        # x: [Batch, Input length, Channel]
        B, T, N = x_enc.shape
        # embedding x: [B, N, T, D]
        x = self.tokenEmb(x_enc)
        bias = x
        # [B, N, T, D]
        if self.channel_independence == '0':
            x = self.MLP_channel(x, B, N, T)
        # [B, N, T, D]
        x = self.MLP_temporal(x, B, N, T)
        x = x + bias
        x = self.fc(x.reshape(B, N, -1)).permute(0, 2, 1)
        return x

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            dec_out = self.forecast(x_enc)
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
        else:
            raise ValueError('Only forecast tasks implemented yet')
models/Informer.py
DELETED
@@ -1,147 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from layers.Transformer_EncDec import Decoder, DecoderLayer, Encoder, EncoderLayer, ConvLayer
from layers.SelfAttention_Family import ProbAttention, AttentionLayer
from layers.Embed import DataEmbedding


class Model(nn.Module):
    """
    Informer with Propspare attention in O(LlogL) complexity
    Paper link: https://ojs.aaai.org/index.php/AAAI/article/view/17325/17132
    """

    def __init__(self, configs):
        super(Model, self).__init__()
        self.task_name = configs.task_name
        self.pred_len = configs.pred_len
        self.label_len = configs.label_len

        # Embedding
        self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq,
                                           configs.dropout)
        self.dec_embedding = DataEmbedding(configs.dec_in, configs.d_model, configs.embed, configs.freq,
                                           configs.dropout)

        # Encoder
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AttentionLayer(
                        ProbAttention(False, configs.factor, attention_dropout=configs.dropout,
                                      output_attention=False),
                        configs.d_model, configs.n_heads),
                    configs.d_model,
                    configs.d_ff,
                    dropout=configs.dropout,
                    activation=configs.activation
                ) for l in range(configs.e_layers)
            ],
            [
                ConvLayer(
                    configs.d_model
                ) for l in range(configs.e_layers - 1)
            ] if configs.distil and ('forecast' in configs.task_name) else None,
            norm_layer=torch.nn.LayerNorm(configs.d_model)
        )
        # Decoder
        self.decoder = Decoder(
            [
                DecoderLayer(
                    AttentionLayer(
                        ProbAttention(True, configs.factor, attention_dropout=configs.dropout, output_attention=False),
                        configs.d_model, configs.n_heads),
                    AttentionLayer(
                        ProbAttention(False, configs.factor, attention_dropout=configs.dropout, output_attention=False),
                        configs.d_model, configs.n_heads),
                    configs.d_model,
                    configs.d_ff,
                    dropout=configs.dropout,
                    activation=configs.activation,
                )
                for l in range(configs.d_layers)
            ],
            norm_layer=torch.nn.LayerNorm(configs.d_model),
            projection=nn.Linear(configs.d_model, configs.c_out, bias=True)
        )
        if self.task_name == 'imputation':
            self.projection = nn.Linear(configs.d_model, configs.c_out, bias=True)
        if self.task_name == 'anomaly_detection':
            self.projection = nn.Linear(configs.d_model, configs.c_out, bias=True)
        if self.task_name == 'classification':
            self.act = F.gelu
            self.dropout = nn.Dropout(configs.dropout)
            self.projection = nn.Linear(configs.d_model * configs.seq_len, configs.num_class)

    def long_forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        enc_out = self.enc_embedding(x_enc, x_mark_enc)
        dec_out = self.dec_embedding(x_dec, x_mark_dec)
        enc_out, attns = self.encoder(enc_out, attn_mask=None)

        dec_out = self.decoder(dec_out, enc_out, x_mask=None, cross_mask=None)

        return dec_out  # [B, L, D]

    def short_forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        # Normalization
        mean_enc = x_enc.mean(1, keepdim=True).detach()  # B x 1 x E
        x_enc = x_enc - mean_enc
        std_enc = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()  # B x 1 x E
        x_enc = x_enc / std_enc

        enc_out = self.enc_embedding(x_enc, x_mark_enc)
        dec_out = self.dec_embedding(x_dec, x_mark_dec)
        enc_out, attns = self.encoder(enc_out, attn_mask=None)

        dec_out = self.decoder(dec_out, enc_out, x_mask=None, cross_mask=None)

        dec_out = dec_out * std_enc + mean_enc
        return dec_out  # [B, L, D]

    def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
        # enc
        enc_out = self.enc_embedding(x_enc, x_mark_enc)
        enc_out, attns = self.encoder(enc_out, attn_mask=None)
        # final
        dec_out = self.projection(enc_out)
        return dec_out

    def anomaly_detection(self, x_enc):
        # enc
        enc_out = self.enc_embedding(x_enc, None)
        enc_out, attns = self.encoder(enc_out, attn_mask=None)
        # final
        dec_out = self.projection(enc_out)
        return dec_out

    def classification(self, x_enc, x_mark_enc):
        # enc
        enc_out = self.enc_embedding(x_enc, None)
        enc_out, attns = self.encoder(enc_out, attn_mask=None)

        # Output
        output = self.act(enc_out)  # the output transformer encoder/decoder embeddings don't include non-linearity
        output = self.dropout(output)
        output = output * x_mark_enc.unsqueeze(-1)  # zero-out padding embeddings
        output = output.reshape(output.shape[0], -1)  # (batch_size, seq_length * d_model)
        output = self.projection(output)  # (batch_size, num_classes)
        return output

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        if self.task_name == 'long_term_forecast':
            dec_out = self.long_forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
        if self.task_name == 'short_term_forecast':
            dec_out = self.short_forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
        if self.task_name == 'imputation':
            dec_out = self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
            return dec_out  # [B, L, D]
        if self.task_name == 'anomaly_detection':
            dec_out = self.anomaly_detection(x_enc)
            return dec_out  # [B, L, D]
        if self.task_name == 'classification':
            dec_out = self.classification(x_enc, x_mark_enc)
            return dec_out  # [B, N]
        return None
models/Koopa.py
DELETED
@@ -1,337 +0,0 @@
import math
import torch
import torch.nn as nn
from data_provider.data_factory import data_provider



class FourierFilter(nn.Module):
    """
    Fourier Filter: to time-variant and time-invariant term
    """
    def __init__(self, mask_spectrum):
        super(FourierFilter, self).__init__()
        self.mask_spectrum = mask_spectrum

    def forward(self, x):
        xf = torch.fft.rfft(x, dim=1)
        mask = torch.ones_like(xf)
        mask[:, self.mask_spectrum, :] = 0
        x_var = torch.fft.irfft(xf*mask, dim=1)
        x_inv = x - x_var

        return x_var, x_inv


class MLP(nn.Module):
    '''
    Multilayer perceptron to encode/decode high dimension representation of sequential data
    '''
    def __init__(self,
                 f_in,
                 f_out,
                 hidden_dim=128,
                 hidden_layers=2,
                 dropout=0.05,
                 activation='tanh'):
        super(MLP, self).__init__()
        self.f_in = f_in
        self.f_out = f_out
        self.hidden_dim = hidden_dim
        self.hidden_layers = hidden_layers
        self.dropout = dropout
        if activation == 'relu':
            self.activation = nn.ReLU()
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        else:
            raise NotImplementedError

        layers = [nn.Linear(self.f_in, self.hidden_dim),
                  self.activation, nn.Dropout(self.dropout)]
        for i in range(self.hidden_layers-2):
            layers += [nn.Linear(self.hidden_dim, self.hidden_dim),
                       self.activation, nn.Dropout(dropout)]

        layers += [nn.Linear(hidden_dim, f_out)]
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        # x: B x S x f_in
        # y: B x S x f_out
        y = self.layers(x)
        return y


class KPLayer(nn.Module):
    """
    A demonstration of finding one step transition of linear system by DMD iteratively
    """
    def __init__(self):
        super(KPLayer, self).__init__()

        self.K = None  # B E E

    def one_step_forward(self, z, return_rec=False, return_K=False):
        B, input_len, E = z.shape
        assert input_len > 1, 'snapshots number should be larger than 1'
        x, y = z[:, :-1], z[:, 1:]

        # solve linear system
        self.K = torch.linalg.lstsq(x, y).solution  # B E E
        if torch.isnan(self.K).any():
            print('Encounter K with nan, replace K by identity matrix')
            self.K = torch.eye(self.K.shape[1]).to(self.K.device).unsqueeze(0).repeat(B, 1, 1)

        z_pred = torch.bmm(z[:, -1:], self.K)
        if return_rec:
            z_rec = torch.cat((z[:, :1], torch.bmm(x, self.K)), dim=1)
            return z_rec, z_pred

        return z_pred

    def forward(self, z, pred_len=1):
        assert pred_len >= 1, 'prediction length should not be less than 1'
        z_rec, z_pred= self.one_step_forward(z, return_rec=True)
        z_preds = [z_pred]
        for i in range(1, pred_len):
            z_pred = torch.bmm(z_pred, self.K)
            z_preds.append(z_pred)
        z_preds = torch.cat(z_preds, dim=1)
        return z_rec, z_preds


class KPLayerApprox(nn.Module):
    """
    Find koopman transition of linear system by DMD with multistep K approximation
    """
    def __init__(self):
        super(KPLayerApprox, self).__init__()

        self.K = None  # B E E
        self.K_step = None  # B E E

    def forward(self, z, pred_len=1):
        # z: B L E, koopman invariance space representation
        # z_rec: B L E, reconstructed representation
        # z_pred: B S E, forecasting representation
        B, input_len, E = z.shape
        assert input_len > 1, 'snapshots number should be larger than 1'
        x, y = z[:, :-1], z[:, 1:]

        # solve linear system
        self.K = torch.linalg.lstsq(x, y).solution  # B E E

        if torch.isnan(self.K).any():
            print('Encounter K with nan, replace K by identity matrix')
            self.K = torch.eye(self.K.shape[1]).to(self.K.device).unsqueeze(0).repeat(B, 1, 1)

        z_rec = torch.cat((z[:, :1], torch.bmm(x, self.K)), dim=1)  # B L E

        if pred_len <= input_len:
            self.K_step = torch.linalg.matrix_power(self.K, pred_len)
            if torch.isnan(self.K_step).any():
                print('Encounter multistep K with nan, replace it by identity matrix')
                self.K_step = torch.eye(self.K_step.shape[1]).to(self.K_step.device).unsqueeze(0).repeat(B, 1, 1)
            z_pred = torch.bmm(z[:, -pred_len:, :], self.K_step)
        else:
            self.K_step = torch.linalg.matrix_power(self.K, input_len)
            if torch.isnan(self.K_step).any():
                print('Encounter multistep K with nan, replace it by identity matrix')
                self.K_step = torch.eye(self.K_step.shape[1]).to(self.K_step.device).unsqueeze(0).repeat(B, 1, 1)
            temp_z_pred, all_pred = z, []
            for _ in range(math.ceil(pred_len / input_len)):
                temp_z_pred = torch.bmm(temp_z_pred, self.K_step)
                all_pred.append(temp_z_pred)
            z_pred = torch.cat(all_pred, dim=1)[:, :pred_len, :]

        return z_rec, z_pred


class TimeVarKP(nn.Module):
    """
    Koopman Predictor with DMD (analysitical solution of Koopman operator)
    Utilize local variations within individual sliding window to predict the future of time-variant term
    """
    def __init__(self,
                 enc_in=8,
                 input_len=96,
                 pred_len=96,
                 seg_len=24,
                 dynamic_dim=128,
                 encoder=None,
                 decoder=None,
                 multistep=False,
                 ):
        super(TimeVarKP, self).__init__()
        self.input_len = input_len
        self.pred_len = pred_len
        self.enc_in = enc_in
        self.seg_len = seg_len
        self.dynamic_dim = dynamic_dim
        self.multistep = multistep
        self.encoder, self.decoder = encoder, decoder
        self.freq = math.ceil(self.input_len / self.seg_len)  # segment number of input
        self.step = math.ceil(self.pred_len / self.seg_len)  # segment number of output
        self.padding_len = self.seg_len * self.freq - self.input_len
        # Approximate mulitstep K by KPLayerApprox when pred_len is large
        self.dynamics = KPLayerApprox() if self.multistep else KPLayer()

    def forward(self, x):
        # x: B L C
        B, L, C = x.shape

        res = torch.cat((x[:, L-self.padding_len:, :], x) ,dim=1)

        res = res.chunk(self.freq, dim=1)  # F x B P C, P means seg_len
        res = torch.stack(res, dim=1).reshape(B, self.freq, -1)  # B F PC

        res = self.encoder(res)  # B F H
        x_rec, x_pred = self.dynamics(res, self.step)  # B F H, B S H

        x_rec = self.decoder(x_rec)  # B F PC
        x_rec = x_rec.reshape(B, self.freq, self.seg_len, self.enc_in)
        x_rec = x_rec.reshape(B, -1, self.enc_in)[:, :self.input_len, :]  # B L C

        x_pred = self.decoder(x_pred)  # B S PC
        x_pred = x_pred.reshape(B, self.step, self.seg_len, self.enc_in)
        x_pred = x_pred.reshape(B, -1, self.enc_in)[:, :self.pred_len, :]  # B S C

        return x_rec, x_pred


class TimeInvKP(nn.Module):
    """
    Koopman Predictor with learnable Koopman operator
    Utilize lookback and forecast window snapshots to predict the future of time-invariant term
    """
    def __init__(self,
                 input_len=96,
                 pred_len=96,
                 dynamic_dim=128,
                 encoder=None,
                 decoder=None):
        super(TimeInvKP, self).__init__()
        self.dynamic_dim = dynamic_dim
        self.input_len = input_len
        self.pred_len = pred_len
        self.encoder = encoder
        self.decoder = decoder

        K_init = torch.randn(self.dynamic_dim, self.dynamic_dim)
        U, _, V = torch.svd(K_init)  # stable initialization
        self.K = nn.Linear(self.dynamic_dim, self.dynamic_dim, bias=False)
        self.K.weight.data = torch.mm(U, V.t())

    def forward(self, x):
        # x: B L C
        res = x.transpose(1, 2)  # B C L
        res = self.encoder(res)  # B C H
        res = self.K(res)  # B C H
        res = self.decoder(res)  # B C S
        res = res.transpose(1, 2)  # B S C

        return res


class Model(nn.Module):
    '''
    Paper link: https://arxiv.org/pdf/2305.18803.pdf
    '''
    def __init__(self, configs, dynamic_dim=128, hidden_dim=64, hidden_layers=2, num_blocks=3, multistep=False):
        """
        mask_spectrum: list, shared frequency spectrums
        seg_len: int, segment length of time series
        dynamic_dim: int, latent dimension of koopman embedding
        hidden_dim: int, hidden dimension of en/decoder
        hidden_layers: int, number of hidden layers of en/decoder
        num_blocks: int, number of Koopa blocks
        multistep: bool, whether to use approximation for multistep K
        alpha: float, spectrum filter ratio
        """
        super(Model, self).__init__()
        self.task_name = configs.task_name
        self.enc_in = configs.enc_in
        self.input_len = configs.seq_len
        self.pred_len = configs.pred_len

        self.seg_len = self.pred_len
        self.num_blocks = num_blocks
        self.dynamic_dim = dynamic_dim
        self.hidden_dim = hidden_dim
        self.hidden_layers = hidden_layers
        self.multistep = multistep
        self.alpha = 0.2
        self.mask_spectrum = self._get_mask_spectrum(configs)

        self.disentanglement = FourierFilter(self.mask_spectrum)

        # shared encoder/decoder to make koopman embedding consistent
        self.time_inv_encoder = MLP(f_in=self.input_len, f_out=self.dynamic_dim, activation='relu',
                                    hidden_dim=self.hidden_dim, hidden_layers=self.hidden_layers)
        self.time_inv_decoder = MLP(f_in=self.dynamic_dim, f_out=self.pred_len, activation='relu',
                                    hidden_dim=self.hidden_dim, hidden_layers=self.hidden_layers)
        self.time_inv_kps = self.time_var_kps = nn.ModuleList([
            TimeInvKP(input_len=self.input_len,
                      pred_len=self.pred_len,
                      dynamic_dim=self.dynamic_dim,
                      encoder=self.time_inv_encoder,
                      decoder=self.time_inv_decoder)
            for _ in range(self.num_blocks)])

        # shared encoder/decoder to make koopman embedding consistent
        self.time_var_encoder = MLP(f_in=self.seg_len*self.enc_in, f_out=self.dynamic_dim, activation='tanh',
                                    hidden_dim=self.hidden_dim, hidden_layers=self.hidden_layers)
        self.time_var_decoder = MLP(f_in=self.dynamic_dim, f_out=self.seg_len*self.enc_in, activation='tanh',
                                    hidden_dim=self.hidden_dim, hidden_layers=self.hidden_layers)
        self.time_var_kps = nn.ModuleList([
            TimeVarKP(enc_in=configs.enc_in,
                      input_len=self.input_len,
                      pred_len=self.pred_len,
                      seg_len=self.seg_len,
                      dynamic_dim=self.dynamic_dim,
                      encoder=self.time_var_encoder,
                      decoder=self.time_var_decoder,
                      multistep=self.multistep)
            for _ in range(self.num_blocks)])

    def _get_mask_spectrum(self, configs):
        """
        get shared frequency spectrums
        """
        train_data, train_loader = data_provider(configs, 'train')
        amps = 0.0
        for data in train_loader:
            lookback_window = data[0]
            amps += abs(torch.fft.rfft(lookback_window, dim=1)).mean(dim=0).mean(dim=1)
        mask_spectrum = amps.topk(int(amps.shape[0]*self.alpha)).indices
        return mask_spectrum  # as the spectrums of time-invariant component

    def forecast(self, x_enc):
        # Series Stationarization adopted from NSformer
        mean_enc = x_enc.mean(1, keepdim=True).detach()  # B x 1 x E
        x_enc = x_enc - mean_enc
        std_enc = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
        x_enc = x_enc / std_enc

        # Koopman Forecasting
        residual, forecast = x_enc, None
        for i in range(self.num_blocks):
            time_var_input, time_inv_input = self.disentanglement(residual)
            time_inv_output = self.time_inv_kps[i](time_inv_input)
            time_var_backcast, time_var_output = self.time_var_kps[i](time_var_input)
            residual = residual - time_var_backcast
            if forecast is None:
                forecast = (time_inv_output + time_var_output)
            else:
                forecast += (time_inv_output + time_var_output)

        # Series Stationarization adopted from NSformer
        res = forecast * std_enc + mean_enc

        return res

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        if self.task_name == 'long_term_forecast':
            dec_out = self.forecast(x_enc)
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
models/LightTS.py DELETED
@@ -1,165 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class IEBlock(nn.Module):
    def __init__(self, input_dim, hid_dim, output_dim, num_node):
        super(IEBlock, self).__init__()

        self.input_dim = input_dim
        self.hid_dim = hid_dim
        self.output_dim = output_dim
        self.num_node = num_node

        self._build()

    def _build(self):
        self.spatial_proj = nn.Sequential(
            nn.Linear(self.input_dim, self.hid_dim),
            nn.LeakyReLU(),
            nn.Linear(self.hid_dim, self.hid_dim // 4)
        )

        self.channel_proj = nn.Linear(self.num_node, self.num_node)
        torch.nn.init.eye_(self.channel_proj.weight)

        self.output_proj = nn.Linear(self.hid_dim // 4, self.output_dim)

    def forward(self, x):
        x = self.spatial_proj(x.permute(0, 2, 1))
        x = x.permute(0, 2, 1) + self.channel_proj(x.permute(0, 2, 1))
        x = self.output_proj(x.permute(0, 2, 1))

        x = x.permute(0, 2, 1)

        return x


class Model(nn.Module):
    """
    Paper link: https://arxiv.org/abs/2207.01186
    """

    def __init__(self, configs, chunk_size=24):
        """
        chunk_size: int, reshape T into [num_chunks, chunk_size]
        """
        super(Model, self).__init__()
        self.task_name = configs.task_name
        self.seq_len = configs.seq_len
        if self.task_name == 'classification' or self.task_name == 'anomaly_detection' or self.task_name == 'imputation':
            self.pred_len = configs.seq_len
        else:
            self.pred_len = configs.pred_len

        if configs.task_name == 'long_term_forecast' or configs.task_name == 'short_term_forecast':
            self.chunk_size = min(configs.pred_len, configs.seq_len, chunk_size)
        else:
            self.chunk_size = min(configs.seq_len, chunk_size)
        # assert (self.seq_len % self.chunk_size == 0)
        if self.seq_len % self.chunk_size != 0:
            self.seq_len += (self.chunk_size - self.seq_len % self.chunk_size)  # padding in order to ensure complete division
        self.num_chunks = self.seq_len // self.chunk_size

        self.d_model = configs.d_model
        self.enc_in = configs.enc_in
        self.dropout = configs.dropout
        if self.task_name == 'classification':
            self.act = F.gelu
            self.dropout = nn.Dropout(configs.dropout)
            self.projection = nn.Linear(configs.enc_in * configs.seq_len, configs.num_class)
        self._build()

    def _build(self):
        self.layer_1 = IEBlock(
            input_dim=self.chunk_size,
            hid_dim=self.d_model // 4,
            output_dim=self.d_model // 4,
            num_node=self.num_chunks
        )

        self.chunk_proj_1 = nn.Linear(self.num_chunks, 1)

        self.layer_2 = IEBlock(
            input_dim=self.chunk_size,
            hid_dim=self.d_model // 4,
            output_dim=self.d_model // 4,
            num_node=self.num_chunks
        )

        self.chunk_proj_2 = nn.Linear(self.num_chunks, 1)

        self.layer_3 = IEBlock(
            input_dim=self.d_model // 2,
            hid_dim=self.d_model // 2,
            output_dim=self.pred_len,
            num_node=self.enc_in
        )

        self.ar = nn.Linear(self.seq_len, self.pred_len)

    def encoder(self, x):
        B, T, N = x.size()

        highway = self.ar(x.permute(0, 2, 1))
        highway = highway.permute(0, 2, 1)

        # continuous sampling
        x1 = x.reshape(B, self.num_chunks, self.chunk_size, N)
        x1 = x1.permute(0, 3, 2, 1)
        x1 = x1.reshape(-1, self.chunk_size, self.num_chunks)
        x1 = self.layer_1(x1)
        x1 = self.chunk_proj_1(x1).squeeze(dim=-1)

        # interval sampling
        x2 = x.reshape(B, self.chunk_size, self.num_chunks, N)
        x2 = x2.permute(0, 3, 1, 2)
        x2 = x2.reshape(-1, self.chunk_size, self.num_chunks)
        x2 = self.layer_2(x2)
        x2 = self.chunk_proj_2(x2).squeeze(dim=-1)

        x3 = torch.cat([x1, x2], dim=-1)

        x3 = x3.reshape(B, N, -1)
        x3 = x3.permute(0, 2, 1)

        out = self.layer_3(x3)

        out = out + highway
        return out

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        return self.encoder(x_enc)

    def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
        return self.encoder(x_enc)

    def anomaly_detection(self, x_enc):
        return self.encoder(x_enc)

    def classification(self, x_enc, x_mark_enc):
        # padding
        x_enc = torch.cat([x_enc, torch.zeros((x_enc.shape[0], self.seq_len-x_enc.shape[1], x_enc.shape[2])).to(x_enc.device)], dim=1)

        enc_out = self.encoder(x_enc)

        # Output
        output = enc_out.reshape(enc_out.shape[0], -1)  # (batch_size, seq_length * d_model)
        output = self.projection(output)  # (batch_size, num_classes)
        return output

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
        if self.task_name == 'imputation':
            dec_out = self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
            return dec_out  # [B, L, D]
        if self.task_name == 'anomaly_detection':
            dec_out = self.anomaly_detection(x_enc)
            return dec_out  # [B, L, D]
        if self.task_name == 'classification':
            dec_out = self.classification(x_enc, x_mark_enc)
            return dec_out  # [B, N]
        return None
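The constructor above pads seq_len up to a multiple of chunk_size so the input can be reshaped into [num_chunks, chunk_size]. A small worked sketch of that arithmetic (the example lengths are illustrative, not taken from this Space's config):

# Illustrative only: how LightTS derives num_chunks from seq_len and chunk_size.
def chunking(seq_len: int, pred_len: int, chunk_size: int = 24):
    chunk_size = min(pred_len, seq_len, chunk_size)
    if seq_len % chunk_size != 0:
        seq_len += chunk_size - seq_len % chunk_size   # pad so the division is exact
    return seq_len, seq_len // chunk_size

print(chunking(96, 24))    # (96, 4)  - already divisible
print(chunking(100, 24))   # (120, 5) - padded from 100 to 120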
models/MICN.py DELETED
@@ -1,221 +0,0 @@
import torch
import torch.nn as nn
from layers.Embed import DataEmbedding
from layers.Autoformer_EncDec import series_decomp, series_decomp_multi
import torch.nn.functional as F


class MIC(nn.Module):
    """
    MIC layer to extract local and global features
    """

    def __init__(self, feature_size=512, n_heads=8, dropout=0.05, decomp_kernel=[32], conv_kernel=[24],
                 isometric_kernel=[18, 6], device='cuda'):
        super(MIC, self).__init__()
        self.conv_kernel = conv_kernel
        self.device = device

        # isometric convolution
        self.isometric_conv = nn.ModuleList([nn.Conv1d(in_channels=feature_size, out_channels=feature_size,
                                                       kernel_size=i, padding=0, stride=1)
                                             for i in isometric_kernel])

        # downsampling convolution: padding=i//2, stride=i
        self.conv = nn.ModuleList([nn.Conv1d(in_channels=feature_size, out_channels=feature_size,
                                             kernel_size=i, padding=i // 2, stride=i)
                                   for i in conv_kernel])

        # upsampling convolution
        self.conv_trans = nn.ModuleList([nn.ConvTranspose1d(in_channels=feature_size, out_channels=feature_size,
                                                            kernel_size=i, padding=0, stride=i)
                                         for i in conv_kernel])

        self.decomp = nn.ModuleList([series_decomp(k) for k in decomp_kernel])
        self.merge = torch.nn.Conv2d(in_channels=feature_size, out_channels=feature_size,
                                     kernel_size=(len(self.conv_kernel), 1))

        # feedforward network
        self.conv1 = nn.Conv1d(in_channels=feature_size, out_channels=feature_size * 4, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=feature_size * 4, out_channels=feature_size, kernel_size=1)
        self.norm1 = nn.LayerNorm(feature_size)
        self.norm2 = nn.LayerNorm(feature_size)

        self.norm = torch.nn.LayerNorm(feature_size)
        self.act = torch.nn.Tanh()
        self.drop = torch.nn.Dropout(0.05)

    def conv_trans_conv(self, input, conv1d, conv1d_trans, isometric):
        batch, seq_len, channel = input.shape
        x = input.permute(0, 2, 1)

        # downsampling convolution
        x1 = self.drop(self.act(conv1d(x)))
        x = x1

        # isometric convolution
        zeros = torch.zeros((x.shape[0], x.shape[1], x.shape[2] - 1), device=self.device)
        x = torch.cat((zeros, x), dim=-1)
        x = self.drop(self.act(isometric(x)))
        x = self.norm((x + x1).permute(0, 2, 1)).permute(0, 2, 1)

        # upsampling convolution
        x = self.drop(self.act(conv1d_trans(x)))
        x = x[:, :, :seq_len]  # truncate

        x = self.norm(x.permute(0, 2, 1) + input)
        return x

    def forward(self, src):
        # multi-scale
        multi = []
        for i in range(len(self.conv_kernel)):
            src_out, trend1 = self.decomp[i](src)
            src_out = self.conv_trans_conv(src_out, self.conv[i], self.conv_trans[i], self.isometric_conv[i])
            multi.append(src_out)

        # merge
        mg = torch.tensor([], device=self.device)
        for i in range(len(self.conv_kernel)):
            mg = torch.cat((mg, multi[i].unsqueeze(1)), dim=1)
        mg = self.merge(mg.permute(0, 3, 1, 2)).squeeze(-2).permute(0, 2, 1)

        y = self.norm1(mg)
        y = self.conv2(self.conv1(y.transpose(-1, 1))).transpose(-1, 1)

        return self.norm2(mg + y)


class SeasonalPrediction(nn.Module):
    def __init__(self, embedding_size=512, n_heads=8, dropout=0.05, d_layers=1, decomp_kernel=[32], c_out=1,
                 conv_kernel=[2, 4], isometric_kernel=[18, 6], device='cuda'):
        super(SeasonalPrediction, self).__init__()

        self.mic = nn.ModuleList([MIC(feature_size=embedding_size, n_heads=n_heads,
                                      decomp_kernel=decomp_kernel, conv_kernel=conv_kernel,
                                      isometric_kernel=isometric_kernel, device=device)
                                  for i in range(d_layers)])

        self.projection = nn.Linear(embedding_size, c_out)

    def forward(self, dec):
        for mic_layer in self.mic:
            dec = mic_layer(dec)
        return self.projection(dec)


class Model(nn.Module):
    """
    Paper link: https://openreview.net/pdf?id=zt53IDUR1U
    """
    def __init__(self, configs, conv_kernel=[12, 16]):
        """
        conv_kernel: downsampling and upsampling convolution kernel_size
        """
        super(Model, self).__init__()

        decomp_kernel = []  # kernel of decomposition operation
        isometric_kernel = []  # kernel of isometric convolution
        for ii in conv_kernel:
            if ii % 2 == 0:  # the kernel of decomposition operation must be odd
                decomp_kernel.append(ii + 1)
                isometric_kernel.append((configs.seq_len + configs.pred_len + ii) // ii)
            else:
                decomp_kernel.append(ii)
                isometric_kernel.append((configs.seq_len + configs.pred_len + ii - 1) // ii)

        self.task_name = configs.task_name
        self.pred_len = configs.pred_len
        self.seq_len = configs.seq_len

        # Multiple Series decomposition block from FEDformer
        self.decomp_multi = series_decomp_multi(decomp_kernel)

        # embedding
        self.dec_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq,
                                           configs.dropout)

        self.conv_trans = SeasonalPrediction(embedding_size=configs.d_model, n_heads=configs.n_heads,
                                             dropout=configs.dropout,
                                             d_layers=configs.d_layers, decomp_kernel=decomp_kernel,
                                             c_out=configs.c_out, conv_kernel=conv_kernel,
                                             isometric_kernel=isometric_kernel, device=torch.device('cuda:0'))
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            # refer to DLinear
            self.regression = nn.Linear(configs.seq_len, configs.pred_len)
            self.regression.weight = nn.Parameter(
                (1 / configs.pred_len) * torch.ones([configs.pred_len, configs.seq_len]),
                requires_grad=True)
        if self.task_name == 'imputation':
            self.projection = nn.Linear(configs.d_model, configs.c_out, bias=True)
        if self.task_name == 'anomaly_detection':
            self.projection = nn.Linear(configs.d_model, configs.c_out, bias=True)
        if self.task_name == 'classification':
            self.act = F.gelu
            self.dropout = nn.Dropout(configs.dropout)
            self.projection = nn.Linear(configs.c_out * configs.seq_len, configs.num_class)

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        # Multi-scale Hybrid Decomposition
        seasonal_init_enc, trend = self.decomp_multi(x_enc)
        trend = self.regression(trend.permute(0, 2, 1)).permute(0, 2, 1)

        # embedding
        zeros = torch.zeros([x_dec.shape[0], self.pred_len, x_dec.shape[2]], device=x_enc.device)
        seasonal_init_dec = torch.cat([seasonal_init_enc[:, -self.seq_len:, :], zeros], dim=1)
        dec_out = self.dec_embedding(seasonal_init_dec, x_mark_dec)
        dec_out = self.conv_trans(dec_out)
        dec_out = dec_out[:, -self.pred_len:, :] + trend[:, -self.pred_len:, :]
        return dec_out

    def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
        # Multi-scale Hybrid Decomposition
        seasonal_init_enc, trend = self.decomp_multi(x_enc)

        # embedding
        dec_out = self.dec_embedding(seasonal_init_enc, x_mark_dec)
        dec_out = self.conv_trans(dec_out)
        dec_out = dec_out + trend
        return dec_out

    def anomaly_detection(self, x_enc):
        # Multi-scale Hybrid Decomposition
        seasonal_init_enc, trend = self.decomp_multi(x_enc)

        # embedding
        dec_out = self.dec_embedding(seasonal_init_enc, None)
        dec_out = self.conv_trans(dec_out)
        dec_out = dec_out + trend
        return dec_out

    def classification(self, x_enc, x_mark_enc):
        # Multi-scale Hybrid Decomposition
        seasonal_init_enc, trend = self.decomp_multi(x_enc)
        # embedding
        dec_out = self.dec_embedding(seasonal_init_enc, None)
        dec_out = self.conv_trans(dec_out)
        dec_out = dec_out + trend

        # Output from Non-stationary Transformer
        output = self.act(dec_out)  # the output transformer encoder/decoder embeddings don't include non-linearity
        output = self.dropout(output)
        output = output * x_mark_enc.unsqueeze(-1)  # zero-out padding embeddings
        output = output.reshape(output.shape[0], -1)  # (batch_size, seq_length * d_model)
        output = self.projection(output)  # (batch_size, num_classes)
        return output

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
        if self.task_name == 'imputation':
            dec_out = self.imputation(
                x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
            return dec_out  # [B, L, D]
        if self.task_name == 'anomaly_detection':
            dec_out = self.anomaly_detection(x_enc)
            return dec_out  # [B, L, D]
        if self.task_name == 'classification':
            dec_out = self.classification(x_enc, x_mark_enc)
            return dec_out  # [B, N]
        return None
models/Mamba.py DELETED
@@ -1,50 +0,0 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from mamba_ssm import Mamba

from layers.Embed import DataEmbedding

class Model(nn.Module):

    def __init__(self, configs):
        super(Model, self).__init__()
        self.task_name = configs.task_name
        self.pred_len = configs.pred_len

        self.d_inner = configs.d_model * configs.expand
        self.dt_rank = math.ceil(configs.d_model / 16)  # TODO implement "auto"

        self.embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq, configs.dropout)

        self.mamba = Mamba(
            d_model = configs.d_model,
            d_state = configs.d_ff,
            d_conv = configs.d_conv,
            expand = configs.expand,
        )

        self.out_layer = nn.Linear(configs.d_model, configs.c_out, bias=False)

    def forecast(self, x_enc, x_mark_enc):
        mean_enc = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - mean_enc
        std_enc = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
        x_enc = x_enc / std_enc

        x = self.embedding(x_enc, x_mark_enc)
        x = self.mamba(x)
        x_out = self.out_layer(x)

        x_out = x_out * std_enc + mean_enc
        return x_out

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        if self.task_name in ['short_term_forecast', 'long_term_forecast']:
            x_out = self.forecast(x_enc, x_mark_enc)
            return x_out[:, -self.pred_len:, :]

        # other tasks not implemented
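For reference, this wrapper only needs a handful of config fields. A hypothetical instantiation sketch follows; the concrete values are made up, and it assumes the mamba_ssm package and the repository's layers/ package are importable:

# Hypothetical usage sketch only - values are illustrative, not this Space's configuration.
from types import SimpleNamespace

configs = SimpleNamespace(
    task_name='long_term_forecast', pred_len=96,
    enc_in=7, c_out=7, d_model=128, d_ff=16, d_conv=4, expand=2,
    embed='timeF', freq='h', dropout=0.1,
)
# model = Model(configs)                        # requires mamba_ssm and the layers/ package
# out = model(x_enc, x_mark_enc, None, None)    # -> [batch, pred_len, c_out]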
models/MambaSimple.py DELETED
@@ -1,162 +0,0 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat, einsum

from layers.Embed import DataEmbedding


class Model(nn.Module):
    """
    Mamba, linear-time sequence modeling with selective state spaces O(L)
    Paper link: https://arxiv.org/abs/2312.00752
    Implementation refernce: https://github.com/johnma2006/mamba-minimal/
    """

    def __init__(self, configs):
        super(Model, self).__init__()
        self.task_name = configs.task_name
        self.pred_len = configs.pred_len

        self.d_inner = configs.d_model * configs.expand
        self.dt_rank = math.ceil(configs.d_model / 16)

        self.embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq, configs.dropout)

        self.layers = nn.ModuleList([ResidualBlock(configs, self.d_inner, self.dt_rank) for _ in range(configs.e_layers)])
        self.norm = RMSNorm(configs.d_model)

        self.out_layer = nn.Linear(configs.d_model, configs.c_out, bias=False)

    def forecast(self, x_enc, x_mark_enc):
        mean_enc = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - mean_enc
        std_enc = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
        x_enc = x_enc / std_enc

        x = self.embedding(x_enc, x_mark_enc)
        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)
        x_out = self.out_layer(x)

        x_out = x_out * std_enc + mean_enc
        return x_out

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        if self.task_name in ['short_term_forecast', 'long_term_forecast']:
            x_out = self.forecast(x_enc, x_mark_enc)
            return x_out[:, -self.pred_len:, :]


class ResidualBlock(nn.Module):
    def __init__(self, configs, d_inner, dt_rank):
        super(ResidualBlock, self).__init__()

        self.mixer = MambaBlock(configs, d_inner, dt_rank)
        self.norm = RMSNorm(configs.d_model)

    def forward(self, x):
        output = self.mixer(self.norm(x)) + x
        return output

class MambaBlock(nn.Module):
    def __init__(self, configs, d_inner, dt_rank):
        super(MambaBlock, self).__init__()
        self.d_inner = d_inner
        self.dt_rank = dt_rank

        self.in_proj = nn.Linear(configs.d_model, self.d_inner * 2, bias=False)

        self.conv1d = nn.Conv1d(
            in_channels = self.d_inner,
            out_channels = self.d_inner,
            bias = True,
            kernel_size = configs.d_conv,
            padding = configs.d_conv - 1,
            groups = self.d_inner
        )

        # takes in x and outputs the input-specific delta, B, C
        self.x_proj = nn.Linear(self.d_inner, self.dt_rank + configs.d_ff * 2, bias=False)

        # projects delta
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)

        A = repeat(torch.arange(1, configs.d_ff + 1), "n -> d n", d=self.d_inner)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(self.d_inner))

        self.out_proj = nn.Linear(self.d_inner, configs.d_model, bias=False)

    def forward(self, x):
        """
        Figure 3 in Section 3.4 in the paper
        """
        (b, l, d) = x.shape

        x_and_res = self.in_proj(x)  # [B, L, 2 * d_inner]
        (x, res) = x_and_res.split(split_size=[self.d_inner, self.d_inner], dim=-1)

        x = rearrange(x, "b l d -> b d l")
        x = self.conv1d(x)[:, :, :l]
        x = rearrange(x, "b d l -> b l d")

        x = F.silu(x)

        y = self.ssm(x)
        y = y * F.silu(res)

        output = self.out_proj(y)
        return output


    def ssm(self, x):
        """
        Algorithm 2 in Section 3.2 in the paper
        """

        (d_in, n) = self.A_log.shape

        A = -torch.exp(self.A_log.float())  # [d_in, n]
        D = self.D.float()  # [d_in]

        x_dbl = self.x_proj(x)  # [B, L, d_rank + 2 * d_ff]
        (delta, B, C) = x_dbl.split(split_size=[self.dt_rank, n, n], dim=-1)  # delta: [B, L, d_rank]; B, C: [B, L, n]
        delta = F.softplus(self.dt_proj(delta))  # [B, L, d_in]
        y = self.selective_scan(x, delta, A, B, C, D)

        return y

    def selective_scan(self, u, delta, A, B, C, D):
        (b, l, d_in) = u.shape
        n = A.shape[1]

        deltaA = torch.exp(einsum(delta, A, "b l d, d n -> b l d n"))  # A is discretized using zero-order hold (ZOH) discretization
        deltaB_u = einsum(delta, B, u, "b l d, b l n, b l d -> b l d n")  # B is discretized using a simplified Euler discretization instead of ZOH. From a discussion with authors: "A is the more important term and the performance doesn't change much with the simplification on B"

        # selective scan, sequential instead of parallel
        x = torch.zeros((b, d_in, n), device=deltaA.device)
        ys = []
        for i in range(l):
            x = deltaA[:, i] * x + deltaB_u[:, i]
            y = einsum(x, C[:, i, :], "b d n, b n -> b d")
            ys.append(y)

        y = torch.stack(ys, dim=1)  # [B, L, d_in]
        y = y + u * D

        return y

class RMSNorm(nn.Module):
    def __init__(self, d_model, eps=1e-5):
        super(RMSNorm, self).__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
        return output
models/Nonstationary_Transformer.py DELETED
@@ -1,218 +0,0 @@
import torch
import torch.nn as nn
from layers.Transformer_EncDec import Decoder, DecoderLayer, Encoder, EncoderLayer
from layers.SelfAttention_Family import DSAttention, AttentionLayer
from layers.Embed import DataEmbedding
import torch.nn.functional as F


class Projector(nn.Module):
    '''
    MLP to learn the De-stationary factors
    Paper link: https://openreview.net/pdf?id=ucNDIDRNjjv
    '''

    def __init__(self, enc_in, seq_len, hidden_dims, hidden_layers, output_dim, kernel_size=3):
        super(Projector, self).__init__()

        padding = 1 if torch.__version__ >= '1.5.0' else 2
        self.series_conv = nn.Conv1d(in_channels=seq_len, out_channels=1, kernel_size=kernel_size, padding=padding,
                                     padding_mode='circular', bias=False)

        layers = [nn.Linear(2 * enc_in, hidden_dims[0]), nn.ReLU()]
        for i in range(hidden_layers - 1):
            layers += [nn.Linear(hidden_dims[i], hidden_dims[i + 1]), nn.ReLU()]

        layers += [nn.Linear(hidden_dims[-1], output_dim, bias=False)]
        self.backbone = nn.Sequential(*layers)

    def forward(self, x, stats):
        # x: B x S x E
        # stats: B x 1 x E
        # y: B x O
        batch_size = x.shape[0]
        x = self.series_conv(x)  # B x 1 x E
        x = torch.cat([x, stats], dim=1)  # B x 2 x E
        x = x.view(batch_size, -1)  # B x 2E
        y = self.backbone(x)  # B x O

        return y


class Model(nn.Module):
    """
    Paper link: https://openreview.net/pdf?id=ucNDIDRNjjv
    """

    def __init__(self, configs):
        super(Model, self).__init__()
        self.task_name = configs.task_name
        self.pred_len = configs.pred_len
        self.seq_len = configs.seq_len
        self.label_len = configs.label_len

        # Embedding
        self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq,
                                           configs.dropout)

        # Encoder
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AttentionLayer(
                        DSAttention(False, configs.factor, attention_dropout=configs.dropout,
                                    output_attention=False), configs.d_model, configs.n_heads),
                    configs.d_model,
                    configs.d_ff,
                    dropout=configs.dropout,
                    activation=configs.activation
                ) for l in range(configs.e_layers)
            ],
            norm_layer=torch.nn.LayerNorm(configs.d_model)
        )
        # Decoder
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            self.dec_embedding = DataEmbedding(configs.dec_in, configs.d_model, configs.embed, configs.freq,
                                               configs.dropout)
            self.decoder = Decoder(
                [
                    DecoderLayer(
                        AttentionLayer(
                            DSAttention(True, configs.factor, attention_dropout=configs.dropout,
                                        output_attention=False),
                            configs.d_model, configs.n_heads),
                        AttentionLayer(
                            DSAttention(False, configs.factor, attention_dropout=configs.dropout,
                                        output_attention=False),
                            configs.d_model, configs.n_heads),
                        configs.d_model,
                        configs.d_ff,
                        dropout=configs.dropout,
                        activation=configs.activation,
                    )
                    for l in range(configs.d_layers)
                ],
                norm_layer=torch.nn.LayerNorm(configs.d_model),
                projection=nn.Linear(configs.d_model, configs.c_out, bias=True)
            )
        if self.task_name == 'imputation':
            self.projection = nn.Linear(configs.d_model, configs.c_out, bias=True)
        if self.task_name == 'anomaly_detection':
            self.projection = nn.Linear(configs.d_model, configs.c_out, bias=True)
        if self.task_name == 'classification':
            self.act = F.gelu
            self.dropout = nn.Dropout(configs.dropout)
            self.projection = nn.Linear(configs.d_model * configs.seq_len, configs.num_class)

        self.tau_learner = Projector(enc_in=configs.enc_in, seq_len=configs.seq_len, hidden_dims=configs.p_hidden_dims,
                                     hidden_layers=configs.p_hidden_layers, output_dim=1)
        self.delta_learner = Projector(enc_in=configs.enc_in, seq_len=configs.seq_len,
                                       hidden_dims=configs.p_hidden_dims, hidden_layers=configs.p_hidden_layers,
                                       output_dim=configs.seq_len)

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        x_raw = x_enc.clone().detach()

        # Normalization
        mean_enc = x_enc.mean(1, keepdim=True).detach()  # B x 1 x E
        x_enc = x_enc - mean_enc
        std_enc = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()  # B x 1 x E
        x_enc = x_enc / std_enc
        # B x S x E, B x 1 x E -> B x 1, positive scalar
        tau = self.tau_learner(x_raw, std_enc).exp()
        # B x S x E, B x 1 x E -> B x S
        delta = self.delta_learner(x_raw, mean_enc)

        x_dec_new = torch.cat([x_enc[:, -self.label_len:, :], torch.zeros_like(x_dec[:, -self.pred_len:, :])],
                              dim=1).to(x_enc.device).clone()

        enc_out = self.enc_embedding(x_enc, x_mark_enc)
        enc_out, attns = self.encoder(enc_out, attn_mask=None, tau=tau, delta=delta)

        dec_out = self.dec_embedding(x_dec_new, x_mark_dec)
        dec_out = self.decoder(dec_out, enc_out, x_mask=None, cross_mask=None, tau=tau, delta=delta)
        dec_out = dec_out * std_enc + mean_enc
        return dec_out

    def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
        x_raw = x_enc.clone().detach()

        # Normalization
        mean_enc = torch.sum(x_enc, dim=1) / torch.sum(mask == 1, dim=1)
        mean_enc = mean_enc.unsqueeze(1).detach()
        x_enc = x_enc - mean_enc
        x_enc = x_enc.masked_fill(mask == 0, 0)
        std_enc = torch.sqrt(torch.sum(x_enc * x_enc, dim=1) / torch.sum(mask == 1, dim=1) + 1e-5)
        std_enc = std_enc.unsqueeze(1).detach()
        x_enc /= std_enc
        # B x S x E, B x 1 x E -> B x 1, positive scalar
        tau = self.tau_learner(x_raw, std_enc).exp()
        # B x S x E, B x 1 x E -> B x S
        delta = self.delta_learner(x_raw, mean_enc)

        enc_out = self.enc_embedding(x_enc, x_mark_enc)
        enc_out, attns = self.encoder(enc_out, attn_mask=None, tau=tau, delta=delta)

        dec_out = self.projection(enc_out)
        dec_out = dec_out * std_enc + mean_enc
        return dec_out

    def anomaly_detection(self, x_enc):
        x_raw = x_enc.clone().detach()

        # Normalization
        mean_enc = x_enc.mean(1, keepdim=True).detach()  # B x 1 x E
        x_enc = x_enc - mean_enc
        std_enc = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()  # B x 1 x E
        x_enc = x_enc / std_enc
        # B x S x E, B x 1 x E -> B x 1, positive scalar
        tau = self.tau_learner(x_raw, std_enc).exp()
        # B x S x E, B x 1 x E -> B x S
        delta = self.delta_learner(x_raw, mean_enc)
        # embedding
        enc_out = self.enc_embedding(x_enc, None)
        enc_out, attns = self.encoder(enc_out, attn_mask=None, tau=tau, delta=delta)

        dec_out = self.projection(enc_out)
        dec_out = dec_out * std_enc + mean_enc
        return dec_out

    def classification(self, x_enc, x_mark_enc):
        x_raw = x_enc.clone().detach()

        # Normalization
        mean_enc = x_enc.mean(1, keepdim=True).detach()  # B x 1 x E
        std_enc = torch.sqrt(
            torch.var(x_enc - mean_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()  # B x 1 x E
        # B x S x E, B x 1 x E -> B x 1, positive scalar
        tau = self.tau_learner(x_raw, std_enc).exp()
        # B x S x E, B x 1 x E -> B x S
        delta = self.delta_learner(x_raw, mean_enc)
        # embedding
        enc_out = self.enc_embedding(x_enc, None)
        enc_out, attns = self.encoder(enc_out, attn_mask=None, tau=tau, delta=delta)

        # Output
        output = self.act(enc_out)  # the output transformer encoder/decoder embeddings don't include non-linearity
        output = self.dropout(output)
        output = output * x_mark_enc.unsqueeze(-1)  # zero-out padding embeddings
        # (batch_size, seq_length * d_model)
        output = output.reshape(output.shape[0], -1)
        # (batch_size, num_classes)
        output = self.projection(output)
        return output

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
        if self.task_name == 'imputation':
            dec_out = self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
            return dec_out  # [B, L, D]
        if self.task_name == 'anomaly_detection':
            dec_out = self.anomaly_detection(x_enc)
            return dec_out  # [B, L, D]
        if self.task_name == 'classification':
            dec_out = self.classification(x_enc, x_mark_enc)
            return dec_out  # [B, L, D]
        return None
models/PatchTST.py DELETED
@@ -1,227 +0,0 @@
import torch
from torch import nn
from layers.Transformer_EncDec import Encoder, EncoderLayer
from layers.SelfAttention_Family import FullAttention, AttentionLayer
from layers.Embed import PatchEmbedding

class Transpose(nn.Module):
    def __init__(self, *dims, contiguous=False):
        super().__init__()
        self.dims, self.contiguous = dims, contiguous
    def forward(self, x):
        if self.contiguous: return x.transpose(*self.dims).contiguous()
        else: return x.transpose(*self.dims)


class FlattenHead(nn.Module):
    def __init__(self, n_vars, nf, target_window, head_dropout=0):
        super().__init__()
        self.n_vars = n_vars
        self.flatten = nn.Flatten(start_dim=-2)
        self.linear = nn.Linear(nf, target_window)
        self.dropout = nn.Dropout(head_dropout)

    def forward(self, x):  # x: [bs x nvars x d_model x patch_num]
        x = self.flatten(x)
        x = self.linear(x)
        x = self.dropout(x)
        return x


class Model(nn.Module):
    """
    Paper link: https://arxiv.org/pdf/2211.14730.pdf
    """

    def __init__(self, configs, patch_len=16, stride=8):
        """
        patch_len: int, patch len for patch_embedding
        stride: int, stride for patch_embedding
        """
        super().__init__()
        self.task_name = configs.task_name
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        padding = stride

        # patching and embedding
        self.patch_embedding = PatchEmbedding(
            configs.d_model, patch_len, stride, padding, configs.dropout)

        # Encoder
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AttentionLayer(
                        FullAttention(False, configs.factor, attention_dropout=configs.dropout,
                                      output_attention=False), configs.d_model, configs.n_heads),
                    configs.d_model,
                    configs.d_ff,
                    dropout=configs.dropout,
                    activation=configs.activation
                ) for l in range(configs.e_layers)
            ],
            norm_layer=nn.Sequential(Transpose(1,2), nn.BatchNorm1d(configs.d_model), Transpose(1,2))
        )

        # Prediction Head
        self.head_nf = configs.d_model * \
                       int((configs.seq_len - patch_len) / stride + 2)
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            self.head = FlattenHead(configs.enc_in, self.head_nf, configs.pred_len,
                                    head_dropout=configs.dropout)
        elif self.task_name == 'imputation' or self.task_name == 'anomaly_detection':
            self.head = FlattenHead(configs.enc_in, self.head_nf, configs.seq_len,
                                    head_dropout=configs.dropout)
        elif self.task_name == 'classification':
            self.flatten = nn.Flatten(start_dim=-2)
            self.dropout = nn.Dropout(configs.dropout)
            self.projection = nn.Linear(
                self.head_nf * configs.enc_in, configs.num_class)

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        # Normalization from Non-stationary Transformer
        means = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - means
        stdev = torch.sqrt(
            torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
        x_enc /= stdev

        # do patching and embedding
        x_enc = x_enc.permute(0, 2, 1)
        # u: [bs * nvars x patch_num x d_model]
        enc_out, n_vars = self.patch_embedding(x_enc)

        # Encoder
        # z: [bs * nvars x patch_num x d_model]
        enc_out, attns = self.encoder(enc_out)
        # z: [bs x nvars x patch_num x d_model]
        enc_out = torch.reshape(
            enc_out, (-1, n_vars, enc_out.shape[-2], enc_out.shape[-1]))
        # z: [bs x nvars x d_model x patch_num]
        enc_out = enc_out.permute(0, 1, 3, 2)

        # Decoder
        dec_out = self.head(enc_out)  # z: [bs x nvars x target_window]
        dec_out = dec_out.permute(0, 2, 1)

        # De-Normalization from Non-stationary Transformer
        dec_out = dec_out * \
                  (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
        dec_out = dec_out + \
                  (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
        return dec_out

    def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
        # Normalization from Non-stationary Transformer
        means = torch.sum(x_enc, dim=1) / torch.sum(mask == 1, dim=1)
        means = means.unsqueeze(1).detach()
        x_enc = x_enc - means
        x_enc = x_enc.masked_fill(mask == 0, 0)
        stdev = torch.sqrt(torch.sum(x_enc * x_enc, dim=1) /
                           torch.sum(mask == 1, dim=1) + 1e-5)
        stdev = stdev.unsqueeze(1).detach()
        x_enc /= stdev

        # do patching and embedding
        x_enc = x_enc.permute(0, 2, 1)
        # u: [bs * nvars x patch_num x d_model]
        enc_out, n_vars = self.patch_embedding(x_enc)

        # Encoder
        # z: [bs * nvars x patch_num x d_model]
        enc_out, attns = self.encoder(enc_out)
        # z: [bs x nvars x patch_num x d_model]
        enc_out = torch.reshape(
            enc_out, (-1, n_vars, enc_out.shape[-2], enc_out.shape[-1]))
        # z: [bs x nvars x d_model x patch_num]
        enc_out = enc_out.permute(0, 1, 3, 2)

        # Decoder
        dec_out = self.head(enc_out)  # z: [bs x nvars x target_window]
        dec_out = dec_out.permute(0, 2, 1)

        # De-Normalization from Non-stationary Transformer
        dec_out = dec_out * \
                  (stdev[:, 0, :].unsqueeze(1).repeat(1, self.seq_len, 1))
        dec_out = dec_out + \
                  (means[:, 0, :].unsqueeze(1).repeat(1, self.seq_len, 1))
        return dec_out

    def anomaly_detection(self, x_enc):
        # Normalization from Non-stationary Transformer
        means = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - means
        stdev = torch.sqrt(
            torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
        x_enc /= stdev

        # do patching and embedding
        x_enc = x_enc.permute(0, 2, 1)
        # u: [bs * nvars x patch_num x d_model]
        enc_out, n_vars = self.patch_embedding(x_enc)

        # Encoder
        # z: [bs * nvars x patch_num x d_model]
        enc_out, attns = self.encoder(enc_out)
        # z: [bs x nvars x patch_num x d_model]
        enc_out = torch.reshape(
            enc_out, (-1, n_vars, enc_out.shape[-2], enc_out.shape[-1]))
        # z: [bs x nvars x d_model x patch_num]
        enc_out = enc_out.permute(0, 1, 3, 2)

        # Decoder
        dec_out = self.head(enc_out)  # z: [bs x nvars x target_window]
        dec_out = dec_out.permute(0, 2, 1)

        # De-Normalization from Non-stationary Transformer
        dec_out = dec_out * \
                  (stdev[:, 0, :].unsqueeze(1).repeat(1, self.seq_len, 1))
        dec_out = dec_out + \
                  (means[:, 0, :].unsqueeze(1).repeat(1, self.seq_len, 1))
        return dec_out

    def classification(self, x_enc, x_mark_enc):
        # Normalization from Non-stationary Transformer
        means = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - means
        stdev = torch.sqrt(
            torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
        x_enc /= stdev

        # do patching and embedding
        x_enc = x_enc.permute(0, 2, 1)
        # u: [bs * nvars x patch_num x d_model]
        enc_out, n_vars = self.patch_embedding(x_enc)

        # Encoder
        # z: [bs * nvars x patch_num x d_model]
        enc_out, attns = self.encoder(enc_out)
        # z: [bs x nvars x patch_num x d_model]
        enc_out = torch.reshape(
            enc_out, (-1, n_vars, enc_out.shape[-2], enc_out.shape[-1]))
        # z: [bs x nvars x d_model x patch_num]
        enc_out = enc_out.permute(0, 1, 3, 2)

        # Decoder
        output = self.flatten(enc_out)
        output = self.dropout(output)
        output = output.reshape(output.shape[0], -1)
        output = self.projection(output)  # (batch_size, num_classes)
        return output

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
        if self.task_name == 'imputation':
            dec_out = self.imputation(
                x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
            return dec_out  # [B, L, D]
        if self.task_name == 'anomaly_detection':
            dec_out = self.anomaly_detection(x_enc)
            return dec_out  # [B, L, D]
        if self.task_name == 'classification':
            dec_out = self.classification(x_enc, x_mark_enc)
            return dec_out  # [B, N]
        return None
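The prediction head size above follows from the number of patches produced by PatchEmbedding with end-padding equal to stride. A small worked check (the example lengths are illustrative, not this Space's settings):

# Illustrative only: patch count behind self.head_nf in the deleted PatchTST Model.
seq_len, patch_len, stride, d_model = 96, 16, 8, 128
patch_num = int((seq_len - patch_len) / stride + 2)   # +2 accounts for the padded tail patch
head_nf = d_model * patch_num
print(patch_num, head_nf)   # 12 1536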
models/Pyraformer.py DELETED
@@ -1,101 +0,0 @@
import torch
import torch.nn as nn
from layers.Pyraformer_EncDec import Encoder


class Model(nn.Module):
    """
    Pyraformer: Pyramidal attention to reduce complexity
    Paper link: https://openreview.net/pdf?id=0EXmFzUn5I
    """

    def __init__(self, configs, window_size=[4,4], inner_size=5):
        """
        window_size: list, the downsample window size in pyramidal attention.
        inner_size: int, the size of neighbour attention
        """
        super().__init__()
        self.task_name = configs.task_name
        self.pred_len = configs.pred_len
        self.d_model = configs.d_model

        if self.task_name == 'short_term_forecast':
            window_size = [2,2]
        self.encoder = Encoder(configs, window_size, inner_size)

        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            self.projection = nn.Linear(
                (len(window_size)+1)*self.d_model, self.pred_len * configs.enc_in)
        elif self.task_name == 'imputation' or self.task_name == 'anomaly_detection':
            self.projection = nn.Linear(
                (len(window_size)+1)*self.d_model, configs.enc_in, bias=True)
        elif self.task_name == 'classification':
            self.act = torch.nn.functional.gelu
            self.dropout = nn.Dropout(configs.dropout)
            self.projection = nn.Linear(
                (len(window_size)+1)*self.d_model * configs.seq_len, configs.num_class)

    def long_forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        enc_out = self.encoder(x_enc, x_mark_enc)[:, -1, :]
        dec_out = self.projection(enc_out).view(
            enc_out.size(0), self.pred_len, -1)
        return dec_out

    def short_forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        # Normalization
        mean_enc = x_enc.mean(1, keepdim=True).detach()  # B x 1 x E
        x_enc = x_enc - mean_enc
        std_enc = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()  # B x 1 x E
        x_enc = x_enc / std_enc

        enc_out = self.encoder(x_enc, x_mark_enc)[:, -1, :]
        dec_out = self.projection(enc_out).view(
            enc_out.size(0), self.pred_len, -1)

        dec_out = dec_out * std_enc + mean_enc
        return dec_out

    def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
        enc_out = self.encoder(x_enc, x_mark_enc)
        dec_out = self.projection(enc_out)
        return dec_out

    def anomaly_detection(self, x_enc, x_mark_enc):
        enc_out = self.encoder(x_enc, x_mark_enc)
        dec_out = self.projection(enc_out)
        return dec_out

    def classification(self, x_enc, x_mark_enc):
        # enc
        enc_out = self.encoder(x_enc, x_mark_enc=None)

        # Output
        # the output transformer encoder/decoder embeddings don't include non-linearity
        output = self.act(enc_out)
        output = self.dropout(output)
        # zero-out padding embeddings
        output = output * x_mark_enc.unsqueeze(-1)
        # (batch_size, seq_length * d_model)
        output = output.reshape(output.shape[0], -1)
        output = self.projection(output)  # (batch_size, num_classes)

        return output

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        if self.task_name == 'long_term_forecast':
            dec_out = self.long_forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
        if self.task_name == 'short_term_forecast':
            dec_out = self.short_forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
        if self.task_name == 'imputation':
            dec_out = self.imputation(
                x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
            return dec_out  # [B, L, D]
        if self.task_name == 'anomaly_detection':
            dec_out = self.anomaly_detection(x_enc, x_mark_enc)
            return dec_out  # [B, L, D]
        if self.task_name == 'classification':
            dec_out = self.classification(x_enc, x_mark_enc)
            return dec_out  # [B, N]
        return None
|
models/Reformer.py
DELETED
@@ -1,132 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from layers.Transformer_EncDec import Encoder, EncoderLayer
from layers.SelfAttention_Family import ReformerLayer
from layers.Embed import DataEmbedding


class Model(nn.Module):
    """
    Reformer with O(LlogL) complexity
    Paper link: https://openreview.net/forum?id=rkgNKkHtvB
    """

    def __init__(self, configs, bucket_size=4, n_hashes=4):
        """
        bucket_size: int,
        n_hashes: int,
        """
        super(Model, self).__init__()
        self.task_name = configs.task_name
        self.pred_len = configs.pred_len
        self.seq_len = configs.seq_len

        self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq,
                                           configs.dropout)
        # Encoder
        self.encoder = Encoder(
            [
                EncoderLayer(
                    ReformerLayer(None, configs.d_model, configs.n_heads,
                                  bucket_size=bucket_size, n_hashes=n_hashes),
                    configs.d_model,
                    configs.d_ff,
                    dropout=configs.dropout,
                    activation=configs.activation
                ) for l in range(configs.e_layers)
            ],
            norm_layer=torch.nn.LayerNorm(configs.d_model)
        )

        if self.task_name == 'classification':
            self.act = F.gelu
            self.dropout = nn.Dropout(configs.dropout)
            self.projection = nn.Linear(
                configs.d_model * configs.seq_len, configs.num_class)
        else:
            self.projection = nn.Linear(
                configs.d_model, configs.c_out, bias=True)

    def long_forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        # add placeholder
        x_enc = torch.cat([x_enc, x_dec[:, -self.pred_len:, :]], dim=1)
        if x_mark_enc is not None:
            x_mark_enc = torch.cat(
                [x_mark_enc, x_mark_dec[:, -self.pred_len:, :]], dim=1)

        enc_out = self.enc_embedding(x_enc, x_mark_enc)  # [B,T,C]
        enc_out, attns = self.encoder(enc_out, attn_mask=None)
        dec_out = self.projection(enc_out)

        return dec_out  # [B, L, D]

    def short_forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        # Normalization
        mean_enc = x_enc.mean(1, keepdim=True).detach()  # B x 1 x E
        x_enc = x_enc - mean_enc
        std_enc = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()  # B x 1 x E
        x_enc = x_enc / std_enc

        # add placeholder
        x_enc = torch.cat([x_enc, x_dec[:, -self.pred_len:, :]], dim=1)
        if x_mark_enc is not None:
            x_mark_enc = torch.cat(
                [x_mark_enc, x_mark_dec[:, -self.pred_len:, :]], dim=1)

        enc_out = self.enc_embedding(x_enc, x_mark_enc)  # [B,T,C]
        enc_out, attns = self.encoder(enc_out, attn_mask=None)
        dec_out = self.projection(enc_out)

        dec_out = dec_out * std_enc + mean_enc
        return dec_out  # [B, L, D]

    def imputation(self, x_enc, x_mark_enc):
        enc_out = self.enc_embedding(x_enc, x_mark_enc)  # [B,T,C]

        enc_out, attns = self.encoder(enc_out)
        enc_out = self.projection(enc_out)

        return enc_out  # [B, L, D]

    def anomaly_detection(self, x_enc):
        enc_out = self.enc_embedding(x_enc, None)  # [B,T,C]

        enc_out, attns = self.encoder(enc_out)
        enc_out = self.projection(enc_out)

        return enc_out  # [B, L, D]

    def classification(self, x_enc, x_mark_enc):
        # enc
        enc_out = self.enc_embedding(x_enc, None)
        enc_out, attns = self.encoder(enc_out)

        # Output
        # the output transformer encoder/decoder embeddings don't include non-linearity
        output = self.act(enc_out)
        output = self.dropout(output)
        # zero-out padding embeddings
        output = output * x_mark_enc.unsqueeze(-1)
        # (batch_size, seq_length * d_model)
        output = output.reshape(output.shape[0], -1)
        output = self.projection(output)  # (batch_size, num_classes)
        return output

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        if self.task_name == 'long_term_forecast':
            dec_out = self.long_forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
        if self.task_name == 'short_term_forecast':
            dec_out = self.short_forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
        if self.task_name == 'imputation':
            dec_out = self.imputation(x_enc, x_mark_enc)
            return dec_out  # [B, L, D]
        if self.task_name == 'anomaly_detection':
            dec_out = self.anomaly_detection(x_enc)
            return dec_out  # [B, L, D]
        if self.task_name == 'classification':
            dec_out = self.classification(x_enc, x_mark_enc)
            return dec_out  # [B, N]
        return None
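The deleted model files in this commit all expose the same entry point: a Model class built from a configs object, whose forward takes (x_enc, x_mark_enc, x_dec, x_mark_dec). The sketch below is not part of the commit; it only illustrates that calling convention, and every config field and shape in it is an illustrative assumption rather than a value taken from this repository.

from types import SimpleNamespace
import torch

# Illustrative config (assumed values, not from this commit).
configs = SimpleNamespace(
    task_name='long_term_forecast',
    seq_len=96, label_len=48, pred_len=24,
    enc_in=7, dec_in=7, c_out=7,
    d_model=64, n_heads=4, e_layers=2, d_layers=1, d_ff=128,
    dropout=0.1, embed='timeF', freq='h', factor=1, activation='gelu',
)

batch = 4
x_enc = torch.randn(batch, configs.seq_len, configs.enc_in)        # encoder history
x_mark_enc = torch.randn(batch, configs.seq_len, 4)                # time features (freq 'h' -> 4)
x_dec = torch.randn(batch, configs.label_len + configs.pred_len, configs.dec_in)
x_mark_dec = torch.randn(batch, configs.label_len + configs.pred_len, 4)

# With the file still present, usage would look like:
#   from models.Reformer import Model
#   model = Model(configs)
#   y = model(x_enc, x_mark_enc, x_dec, x_mark_dec)  # -> [batch, pred_len, c_out]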
models/SCINet.py
DELETED
@@ -1,188 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class Splitting(nn.Module):
    def __init__(self):
        super(Splitting, self).__init__()

    def even(self, x):
        return x[:, ::2, :]

    def odd(self, x):
        return x[:, 1::2, :]

    def forward(self, x):
        # return the odd and even part
        return self.even(x), self.odd(x)


class CausalConvBlock(nn.Module):
    def __init__(self, d_model, kernel_size=5, dropout=0.0):
        super(CausalConvBlock, self).__init__()
        module_list = [
            nn.ReplicationPad1d((kernel_size - 1, kernel_size - 1)),

            nn.Conv1d(d_model, d_model,
                      kernel_size=kernel_size),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),

            nn.Dropout(dropout),
            nn.Conv1d(d_model, d_model,
                      kernel_size=kernel_size),
            nn.Tanh()
        ]
        self.causal_conv = nn.Sequential(*module_list)

    def forward(self, x):
        return self.causal_conv(x)  # return value is the same as input dimension


class SCIBlock(nn.Module):
    def __init__(self, d_model, kernel_size=5, dropout=0.0):
        super(SCIBlock, self).__init__()
        self.splitting = Splitting()
        self.modules_even, self.modules_odd, self.interactor_even, self.interactor_odd = [CausalConvBlock(d_model) for _ in range(4)]

    def forward(self, x):
        x_even, x_odd = self.splitting(x)
        x_even = x_even.permute(0, 2, 1)
        x_odd = x_odd.permute(0, 2, 1)

        x_even_temp = x_even.mul(torch.exp(self.modules_even(x_odd)))
        x_odd_temp = x_odd.mul(torch.exp(self.modules_odd(x_even)))

        x_even_update = x_even_temp + self.interactor_even(x_odd_temp)
        x_odd_update = x_odd_temp - self.interactor_odd(x_even_temp)

        return x_even_update.permute(0, 2, 1), x_odd_update.permute(0, 2, 1)


class SCINet(nn.Module):
    def __init__(self, d_model, current_level=3, kernel_size=5, dropout=0.0):
        super(SCINet, self).__init__()
        self.current_level = current_level
        self.working_block = SCIBlock(d_model, kernel_size, dropout)

        if current_level != 0:
            self.SCINet_Tree_odd = SCINet(d_model, current_level-1, kernel_size, dropout)
            self.SCINet_Tree_even = SCINet(d_model, current_level-1, kernel_size, dropout)

    def forward(self, x):
        odd_flag = False
        if x.shape[1] % 2 == 1:
            odd_flag = True
            x = torch.cat((x, x[:, -1:, :]), dim=1)
        x_even_update, x_odd_update = self.working_block(x)
        if odd_flag:
            x_odd_update = x_odd_update[:, :-1]

        if self.current_level == 0:
            return self.zip_up_the_pants(x_even_update, x_odd_update)
        else:
            return self.zip_up_the_pants(self.SCINet_Tree_even(x_even_update), self.SCINet_Tree_odd(x_odd_update))

    def zip_up_the_pants(self, even, odd):
        even = even.permute(1, 0, 2)
        odd = odd.permute(1, 0, 2)
        even_len = even.shape[0]
        odd_len = odd.shape[0]
        min_len = min(even_len, odd_len)

        zipped_data = []
        for i in range(min_len):
            zipped_data.append(even[i].unsqueeze(0))
            zipped_data.append(odd[i].unsqueeze(0))
        if even_len > odd_len:
            zipped_data.append(even[-1].unsqueeze(0))
        return torch.cat(zipped_data, 0).permute(1, 0, 2)


class Model(nn.Module):
    def __init__(self, configs):
        super(Model, self).__init__()
        self.task_name = configs.task_name
        self.seq_len = configs.seq_len
        self.label_len = configs.label_len
        self.pred_len = configs.pred_len

        # You can set the number of SCINet stacks by argument "d_layers", but should choose 1 or 2.
        self.num_stacks = configs.d_layers
        if self.num_stacks == 1:
            self.sci_net_1 = SCINet(configs.enc_in, dropout=configs.dropout)
            self.projection_1 = nn.Conv1d(self.seq_len, self.seq_len + self.pred_len, kernel_size=1, stride=1, bias=False)
        else:
            self.sci_net_1, self.sci_net_2 = [SCINet(configs.enc_in, dropout=configs.dropout) for _ in range(2)]
            self.projection_1 = nn.Conv1d(self.seq_len, self.pred_len, kernel_size=1, stride=1, bias=False)
            self.projection_2 = nn.Conv1d(self.seq_len+self.pred_len, self.seq_len+self.pred_len,
                                          kernel_size = 1, bias = False)

        # For positional encoding
        self.pe_hidden_size = configs.enc_in
        if self.pe_hidden_size % 2 == 1:
            self.pe_hidden_size += 1

        num_timescales = self.pe_hidden_size // 2
        max_timescale = 10000.0
        min_timescale = 1.0

        log_timescale_increment = (
            math.log(float(max_timescale) / float(min_timescale)) /
            max(num_timescales - 1, 1))
        inv_timescales = min_timescale * torch.exp(
            torch.arange(num_timescales, dtype=torch.float32) *
            -log_timescale_increment)
        self.register_buffer('inv_timescales', inv_timescales)

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)  # [B,pred_len,C]
            dec_out = torch.cat([torch.zeros_like(x_enc), dec_out], dim=1)
            return dec_out  # [B, T, D]
        return None

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        # Normalization from Non-stationary Transformer
        means = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - means
        stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
        x_enc /= stdev

        # position-encoding
        pe = self.get_position_encoding(x_enc)
        if pe.shape[2] > x_enc.shape[2]:
            x_enc += pe[:, :, :-1]
        else:
            x_enc += self.get_position_encoding(x_enc)

        # SCINet
        dec_out = self.sci_net_1(x_enc)
        dec_out += x_enc
        dec_out = self.projection_1(dec_out)
        if self.num_stacks != 1:
            dec_out = torch.cat((x_enc, dec_out), dim=1)
            temp = dec_out
            dec_out = self.sci_net_2(dec_out)
            dec_out += temp
            dec_out = self.projection_2(dec_out)

        # De-Normalization from Non-stationary Transformer
        dec_out = dec_out * \
            (stdev[:, 0, :].unsqueeze(1).repeat(
                1, self.pred_len + self.seq_len, 1))
        dec_out = dec_out + \
            (means[:, 0, :].unsqueeze(1).repeat(
                1, self.pred_len + self.seq_len, 1))
        return dec_out

    def get_position_encoding(self, x):
        max_length = x.size()[1]
        position = torch.arange(max_length, dtype=torch.float32,
                                device=x.device)  # tensor([0., 1., 2., 3., 4.], device='cuda:0')
        scaled_time = position.unsqueeze(1) * self.inv_timescales.unsqueeze(0)  # 5 256
        signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)  # [T, C]
        signal = F.pad(signal, (0, 0, 0, self.pe_hidden_size % 2))
        signal = signal.view(1, max_length, self.pe_hidden_size)

        return signal
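Several of the deleted files (SCINet above, and the forecasting models below) share the same per-instance "Normalization from Non-stationary Transformer" pattern: statistics are computed over the time dimension, detached, and re-applied to the prediction. The standalone sketch below restates only that pattern; it is not code from this commit, and the stand-in prediction tensor is an assumption for illustration.

import torch

def normalize(x_enc: torch.Tensor):
    # x_enc: [B, T, C]; statistics over the time dimension, detached from the graph
    means = x_enc.mean(1, keepdim=True).detach()
    x = x_enc - means
    stdev = torch.sqrt(torch.var(x, dim=1, keepdim=True, unbiased=False) + 1e-5)
    return x / stdev, means, stdev

def denormalize(dec_out: torch.Tensor, means, stdev, horizon: int):
    # dec_out: [B, horizon, C]; re-apply the stored statistics to the prediction
    dec_out = dec_out * stdev[:, 0, :].unsqueeze(1).repeat(1, horizon, 1)
    return dec_out + means[:, 0, :].unsqueeze(1).repeat(1, horizon, 1)

x = torch.randn(4, 96, 7)
x_norm, mu, sigma = normalize(x)
pred = torch.randn(4, 24, 7)          # stand-in for a model's normalized output
pred = denormalize(pred, mu, sigma, 24)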
models/SegRNN.py
DELETED
@@ -1,119 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from layers.Autoformer_EncDec import series_decomp


class Model(nn.Module):
    """
    Paper link: https://arxiv.org/abs/2308.11200.pdf
    """

    def __init__(self, configs):
        super(Model, self).__init__()

        # get parameters
        self.seq_len = configs.seq_len
        self.enc_in = configs.enc_in
        self.d_model = configs.d_model
        self.dropout = configs.dropout

        self.task_name = configs.task_name
        if self.task_name == 'classification' or self.task_name == 'anomaly_detection' or self.task_name == 'imputation':
            self.pred_len = configs.seq_len
        else:
            self.pred_len = configs.pred_len

        self.seg_len = configs.seg_len
        self.seg_num_x = self.seq_len // self.seg_len
        self.seg_num_y = self.pred_len // self.seg_len

        # building model
        self.valueEmbedding = nn.Sequential(
            nn.Linear(self.seg_len, self.d_model),
            nn.ReLU()
        )
        self.rnn = nn.GRU(input_size=self.d_model, hidden_size=self.d_model, num_layers=1, bias=True,
                          batch_first=True, bidirectional=False)
        self.pos_emb = nn.Parameter(torch.randn(self.seg_num_y, self.d_model // 2))
        self.channel_emb = nn.Parameter(torch.randn(self.enc_in, self.d_model // 2))

        self.predict = nn.Sequential(
            nn.Dropout(self.dropout),
            nn.Linear(self.d_model, self.seg_len)
        )

        if self.task_name == 'classification':
            self.act = F.gelu
            self.dropout = nn.Dropout(configs.dropout)
            self.projection = nn.Linear(
                configs.enc_in * configs.seq_len, configs.num_class)

    def encoder(self, x):
        # b:batch_size c:channel_size s:seq_len s:seq_len
        # d:d_model w:seg_len n:seg_num_x m:seg_num_y
        batch_size = x.size(0)

        # normalization and permute b,s,c -> b,c,s
        seq_last = x[:, -1:, :].detach()
        x = (x - seq_last).permute(0, 2, 1)  # b,c,s

        # segment and embedding b,c,s -> bc,n,w -> bc,n,d
        x = self.valueEmbedding(x.reshape(-1, self.seg_num_x, self.seg_len))

        # encoding
        _, hn = self.rnn(x)  # bc,n,d 1,bc,d

        # m,d//2 -> 1,m,d//2 -> c,m,d//2
        # c,d//2 -> c,1,d//2 -> c,m,d//2
        # c,m,d -> cm,1,d -> bcm, 1, d
        pos_emb = torch.cat([
            self.pos_emb.unsqueeze(0).repeat(self.enc_in, 1, 1),
            self.channel_emb.unsqueeze(1).repeat(1, self.seg_num_y, 1)
        ], dim=-1).view(-1, 1, self.d_model).repeat(batch_size,1,1)

        _, hy = self.rnn(pos_emb, hn.repeat(1, 1, self.seg_num_y).view(1, -1, self.d_model))  # bcm,1,d 1,bcm,d

        # 1,bcm,d -> 1,bcm,w -> b,c,s
        y = self.predict(hy).view(-1, self.enc_in, self.pred_len)

        # permute and denorm
        y = y.permute(0, 2, 1) + seq_last
        return y

    def forecast(self, x_enc):
        # Encoder
        return self.encoder(x_enc)

    def imputation(self, x_enc):
        # Encoder
        return self.encoder(x_enc)

    def anomaly_detection(self, x_enc):
        # Encoder
        return self.encoder(x_enc)

    def classification(self, x_enc):
        # Encoder
        enc_out = self.encoder(x_enc)
        # Output
        # (batch_size, seq_length * d_model)
        output = enc_out.reshape(enc_out.shape[0], -1)
        # (batch_size, num_classes)
        output = self.projection(output)
        return output

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            dec_out = self.forecast(x_enc)
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
        if self.task_name == 'imputation':
            dec_out = self.imputation(x_enc)
            return dec_out  # [B, L, D]
        if self.task_name == 'anomaly_detection':
            dec_out = self.anomaly_detection(x_enc)
            return dec_out  # [B, L, D]
        if self.task_name == 'classification':
            dec_out = self.classification(x_enc)
            return dec_out  # [B, N]
        return None
models/TSMixer.py
DELETED
@@ -1,54 +0,0 @@
import torch.nn as nn


class ResBlock(nn.Module):
    def __init__(self, configs):
        super(ResBlock, self).__init__()

        self.temporal = nn.Sequential(
            nn.Linear(configs.seq_len, configs.d_model),
            nn.ReLU(),
            nn.Linear(configs.d_model, configs.seq_len),
            nn.Dropout(configs.dropout)
        )

        self.channel = nn.Sequential(
            nn.Linear(configs.enc_in, configs.d_model),
            nn.ReLU(),
            nn.Linear(configs.d_model, configs.enc_in),
            nn.Dropout(configs.dropout)
        )

    def forward(self, x):
        # x: [B, L, D]
        x = x + self.temporal(x.transpose(1, 2)).transpose(1, 2)
        x = x + self.channel(x)

        return x


class Model(nn.Module):
    def __init__(self, configs):
        super(Model, self).__init__()
        self.task_name = configs.task_name
        self.layer = configs.e_layers
        self.model = nn.ModuleList([ResBlock(configs)
                                    for _ in range(configs.e_layers)])
        self.pred_len = configs.pred_len
        self.projection = nn.Linear(configs.seq_len, configs.pred_len)

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):

        # x: [B, L, D]
        for i in range(self.layer):
            x_enc = self.model[i](x_enc)
        enc_out = self.projection(x_enc.transpose(1, 2)).transpose(1, 2)

        return enc_out

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
        else:
            raise ValueError('Only forecast tasks implemented yet')
models/TemporalFusionTransformer.py
DELETED
@@ -1,309 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from layers.Embed import DataEmbedding, TemporalEmbedding
from torch import Tensor
from typing import Optional
from collections import namedtuple

# static: time-independent features
# observed: time features of the past(e.g. predicted targets)
# known: known information about the past and future(i.e. time stamp)
TypePos = namedtuple('TypePos', ['static', 'observed'])

# When you want to use new dataset, please add the index of 'static, observed' columns here.
# 'known' columns needn't be added, because 'known' inputs are automatically judged and provided by the program.
datatype_dict = {'ETTh1': TypePos([], [x for x in range(7)]),
                 'ETTm1': TypePos([], [x for x in range(7)])}


def get_known_len(embed_type, freq):
    if embed_type != 'timeF':
        if freq == 't':
            return 5
        else:
            return 4
    else:
        freq_map = {'h': 4, 't': 5, 's': 6,
                    'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3}
        return freq_map[freq]


class TFTTemporalEmbedding(TemporalEmbedding):
    def __init__(self, d_model, embed_type='fixed', freq='h'):
        super(TFTTemporalEmbedding, self).__init__(d_model, embed_type, freq)

    def forward(self, x):
        x = x.long()
        minute_x = self.minute_embed(x[:, :, 4]) if hasattr(
            self, 'minute_embed') else 0.
        hour_x = self.hour_embed(x[:, :, 3])
        weekday_x = self.weekday_embed(x[:, :, 2])
        day_x = self.day_embed(x[:, :, 1])
        month_x = self.month_embed(x[:, :, 0])

        embedding_x = torch.stack([month_x, day_x, weekday_x, hour_x, minute_x], dim=-2) if hasattr(
            self, 'minute_embed') else torch.stack([month_x, day_x, weekday_x, hour_x], dim=-2)
        return embedding_x


class TFTTimeFeatureEmbedding(nn.Module):
    def __init__(self, d_model, embed_type='timeF', freq='h'):
        super(TFTTimeFeatureEmbedding, self).__init__()
        d_inp = get_known_len(embed_type, freq)
        self.embed = nn.ModuleList([nn.Linear(1, d_model, bias=False) for _ in range(d_inp)])

    def forward(self, x):
        return torch.stack([embed(x[:,:,i].unsqueeze(-1)) for i, embed in enumerate(self.embed)], dim=-2)


class TFTEmbedding(nn.Module):
    def __init__(self, configs):
        super(TFTEmbedding, self).__init__()
        self.pred_len = configs.pred_len
        self.static_pos = datatype_dict[configs.data].static
        self.observed_pos = datatype_dict[configs.data].observed
        self.static_len = len(self.static_pos)
        self.observed_len = len(self.observed_pos)

        self.static_embedding = nn.ModuleList([DataEmbedding(1,configs.d_model,dropout=configs.dropout) for _ in range(self.static_len)]) \
            if self.static_len else None
        self.observed_embedding = nn.ModuleList([DataEmbedding(1,configs.d_model,dropout=configs.dropout) for _ in range(self.observed_len)])
        self.known_embedding = TFTTemporalEmbedding(configs.d_model, configs.embed, configs.freq) \
            if configs.embed != 'timeF' else TFTTimeFeatureEmbedding(configs.d_model, configs.embed, configs.freq)

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        if self.static_len:
            # static_input: [B,C,d_model]
            static_input = torch.stack([embed(x_enc[:,:1,self.static_pos[i]].unsqueeze(-1), None).squeeze(1) for i, embed in enumerate(self.static_embedding)], dim=-2)
        else:
            static_input = None

        # observed_input: [B,T,C,d_model]
        observed_input = torch.stack([embed(x_enc[:,:,self.observed_pos[i]].unsqueeze(-1), None) for i, embed in enumerate(self.observed_embedding)], dim=-2)

        x_mark = torch.cat([x_mark_enc, x_mark_dec[:,-self.pred_len:,:]], dim=-2)
        # known_input: [B,T,C,d_model]
        known_input = self.known_embedding(x_mark)

        return static_input, observed_input, known_input


class GLU(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, output_size)
        self.fc2 = nn.Linear(input_size, output_size)
        self.glu = nn.GLU()

    def forward(self, x):
        a = self.fc1(x)
        b = self.fc2(x)
        return self.glu(torch.cat([a, b], dim=-1))


class GateAddNorm(nn.Module):
    def __init__(self, input_size, output_size):
        super(GateAddNorm, self).__init__()
        self.glu = GLU(input_size, input_size)
        self.projection = nn.Linear(input_size, output_size) if input_size != output_size else nn.Identity()
        self.layer_norm = nn.LayerNorm(output_size)

    def forward(self, x, skip_a):
        x = self.glu(x)
        x = x + skip_a
        return self.layer_norm(self.projection(x))


class GRN(nn.Module):
    def __init__(self, input_size, output_size, hidden_size=None, context_size=None, dropout=0.0):
        super(GRN, self).__init__()
        hidden_size = input_size if hidden_size is None else hidden_size
        self.lin_a = nn.Linear(input_size, hidden_size)
        self.lin_c = nn.Linear(context_size, hidden_size) if context_size is not None else None
        self.lin_i = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout)
        self.project_a = nn.Linear(input_size, hidden_size) if hidden_size != input_size else nn.Identity()
        self.gate = GateAddNorm(hidden_size, output_size)

    def forward(self, a: Tensor, c: Optional[Tensor] = None):
        # a: [B,T,d], c: [B,d]
        x = self.lin_a(a)
        if c is not None:
            x = x + self.lin_c(c).unsqueeze(1)
        x = F.elu(x)
        x = self.lin_i(x)
        x = self.dropout(x)
        return self.gate(x, self.project_a(a))


class VariableSelectionNetwork(nn.Module):
    def __init__(self, d_model, variable_num, dropout=0.0):
        super(VariableSelectionNetwork, self).__init__()
        self.joint_grn = GRN(d_model * variable_num, variable_num, hidden_size=d_model, context_size=d_model, dropout=dropout)
        self.variable_grns = nn.ModuleList([GRN(d_model, d_model, dropout=dropout) for _ in range(variable_num)])

    def forward(self, x: Tensor, context: Optional[Tensor] = None):
        # x: [B,T,C,d] or [B,C,d]
        # selection_weights: [B,T,C] or [B,C]
        # x_processed: [B,T,d,C] or [B,d,C]
        # selection_result: [B,T,d] or [B,d]
        x_flattened = torch.flatten(x, start_dim=-2)
        selection_weights = self.joint_grn(x_flattened, context)
        selection_weights = F.softmax(selection_weights, dim=-1)

        x_processed = torch.stack([grn(x[...,i,:]) for i, grn in enumerate(self.variable_grns)], dim=-1)

        selection_result = torch.matmul(x_processed, selection_weights.unsqueeze(-1)).squeeze(-1)
        return selection_result


class StaticCovariateEncoder(nn.Module):
    def __init__(self, d_model, static_len, dropout=0.0):
        super(StaticCovariateEncoder, self).__init__()
        self.static_vsn = VariableSelectionNetwork(d_model, static_len) if static_len else None
        self.grns = nn.ModuleList([GRN(d_model, d_model, dropout=dropout) for _ in range(4)])

    def forward(self, static_input):
        # static_input: [B,C,d]
        if static_input is not None:
            static_features = self.static_vsn(static_input)
            return [grn(static_features) for grn in self.grns]
        else:
            return [None] * 4


class InterpretableMultiHeadAttention(nn.Module):
    def __init__(self, configs):
        super(InterpretableMultiHeadAttention, self).__init__()
        self.n_heads = configs.n_heads
        assert configs.d_model % configs.n_heads == 0
        self.d_head = configs.d_model // configs.n_heads
        self.qkv_linears = nn.Linear(configs.d_model, (2 * self.n_heads + 1) * self.d_head, bias=False)
        self.out_projection = nn.Linear(self.d_head, configs.d_model, bias=False)
        self.out_dropout = nn.Dropout(configs.dropout)
        self.scale = self.d_head ** -0.5
        example_len = configs.seq_len + configs.pred_len
        self.register_buffer("mask", torch.triu(torch.full((example_len, example_len), float('-inf')), 1))

    def forward(self, x):
        # Q,K,V are all from x
        B, T, d_model = x.shape
        qkv = self.qkv_linears(x)
        q, k, v = qkv.split((self.n_heads * self.d_head, self.n_heads * self.d_head, self.d_head), dim=-1)
        q = q.view(B, T, self.n_heads, self.d_head)
        k = k.view(B, T, self.n_heads, self.d_head)
        v = v.view(B, T, self.d_head)

        attention_score = torch.matmul(q.permute((0, 2, 1, 3)), k.permute((0, 2, 3, 1)))  # [B,n,T,T]
        attention_score.mul_(self.scale)
        attention_score = attention_score + self.mask
        attention_prob = F.softmax(attention_score, dim=3)  # [B,n,T,T]

        attention_out = torch.matmul(attention_prob, v.unsqueeze(1))  # [B,n,T,d]
        attention_out = torch.mean(attention_out, dim=1)  # [B,T,d]
        out = self.out_projection(attention_out)
        out = self.out_dropout(out)  # [B,T,d]
        return out


class TemporalFusionDecoder(nn.Module):
    def __init__(self, configs):
        super(TemporalFusionDecoder, self).__init__()
        self.pred_len = configs.pred_len

        self.history_encoder = nn.LSTM(configs.d_model, configs.d_model, batch_first=True)
        self.future_encoder = nn.LSTM(configs.d_model, configs.d_model, batch_first=True)
        self.gate_after_lstm = GateAddNorm(configs.d_model, configs.d_model)
        self.enrichment_grn = GRN(configs.d_model, configs.d_model, context_size=configs.d_model, dropout=configs.dropout)
        self.attention = InterpretableMultiHeadAttention(configs)
        self.gate_after_attention = GateAddNorm(configs.d_model, configs.d_model)
        self.position_wise_grn = GRN(configs.d_model, configs.d_model, dropout=configs.dropout)
        self.gate_final = GateAddNorm(configs.d_model, configs.d_model)
        self.out_projection = nn.Linear(configs.d_model, configs.c_out)

    def forward(self, history_input, future_input, c_c, c_h, c_e):
        # history_input, future_input: [B,T,d]
        # c_c, c_h, c_e: [B,d]
        # LSTM
        c = (c_c.unsqueeze(0), c_h.unsqueeze(0)) if c_c is not None and c_h is not None else None
        historical_features, state = self.history_encoder(history_input, c)
        future_features, _ = self.future_encoder(future_input, state)

        # Skip connection
        temporal_input = torch.cat([history_input, future_input], dim=1)
        temporal_features = torch.cat([historical_features, future_features], dim=1)
        temporal_features = self.gate_after_lstm(temporal_features, temporal_input)  # [B,T,d]

        # Static enrichment
        enriched_features = self.enrichment_grn(temporal_features, c_e)  # [B,T,d]

        # Temporal self-attention
        attention_out = self.attention(enriched_features)  # [B,T,d]
        # Don't compute historical loss
        attention_out = self.gate_after_attention(attention_out[:,-self.pred_len:], enriched_features[:,-self.pred_len:])

        # Position-wise feed-forward
        out = self.position_wise_grn(attention_out)  # [B,T,d]

        # Final skip connection
        out = self.gate_final(out, temporal_features[:,-self.pred_len:])
        return self.out_projection(out)


class Model(nn.Module):
    def __init__(self, configs):
        super(Model, self).__init__()
        self.configs = configs
        self.task_name = configs.task_name
        self.seq_len = configs.seq_len
        self.label_len = configs.label_len
        self.pred_len = configs.pred_len

        # Number of variables
        self.static_len = len(datatype_dict[configs.data].static)
        self.observed_len = len(datatype_dict[configs.data].observed)
        self.known_len = get_known_len(configs.embed, configs.freq)

        self.embedding = TFTEmbedding(configs)
        self.static_encoder = StaticCovariateEncoder(configs.d_model, self.static_len)
        self.history_vsn = VariableSelectionNetwork(configs.d_model, self.observed_len + self.known_len)
        self.future_vsn = VariableSelectionNetwork(configs.d_model, self.known_len)
        self.temporal_fusion_decoder = TemporalFusionDecoder(configs)

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        # Normalization from Non-stationary Transformer
        means = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - means
        stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
        x_enc /= stdev

        # Data embedding
        # static_input: [B,C,d], observed_input:[B,T,C,d], known_input: [B,T,C,d]
        static_input, observed_input, known_input = self.embedding(x_enc, x_mark_enc, x_dec, x_mark_dec)

        # Static context
        # c_s,...,c_e: [B,d]
        c_s, c_c, c_h, c_e = self.static_encoder(static_input)

        # Temporal input Selection
        history_input = torch.cat([observed_input, known_input[:,:self.seq_len]], dim=-2)
        future_input = known_input[:,self.seq_len:]
        history_input = self.history_vsn(history_input, c_s)
        future_input = self.future_vsn(future_input, c_s)

        # TFT main procedure after variable selection
        # history_input: [B,T,d], future_input: [B,T,d]
        dec_out = self.temporal_fusion_decoder(history_input, future_input, c_c, c_h, c_e)

        # De-Normalization from Non-stationary Transformer
        dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
        dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
        return dec_out

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)  # [B,pred_len,C]
            dec_out = torch.cat([torch.zeros_like(x_enc), dec_out], dim=1)
            return dec_out  # [B, T, D]
        return None
models/TiDE.py
DELETED
@@ -1,145 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class LayerNorm(nn.Module):
    """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, input):
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)



class ResBlock(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.1, bias=True):
        super().__init__()

        self.fc1 = nn.Linear(input_dim, hidden_dim, bias=bias)
        self.fc2 = nn.Linear(hidden_dim, output_dim, bias=bias)
        self.fc3 = nn.Linear(input_dim, output_dim, bias=bias)
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()
        self.ln = LayerNorm(output_dim, bias=bias)

    def forward(self, x):

        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        out = self.dropout(out)
        out = out + self.fc3(x)
        out = self.ln(out)
        return out


#TiDE
class Model(nn.Module):
    """
    paper: https://arxiv.org/pdf/2304.08424.pdf
    """
    def __init__(self, configs, bias=True, feature_encode_dim=2):
        super(Model, self).__init__()
        self.configs = configs
        self.task_name = configs.task_name
        self.seq_len = configs.seq_len  #L
        self.label_len = configs.label_len
        self.pred_len = configs.pred_len  #H
        self.hidden_dim=configs.d_model
        self.res_hidden=configs.d_model
        self.encoder_num=configs.e_layers
        self.decoder_num=configs.d_layers
        self.freq=configs.freq
        self.feature_encode_dim=feature_encode_dim
        self.decode_dim = configs.c_out
        self.temporalDecoderHidden=configs.d_ff
        dropout=configs.dropout


        freq_map = {'h': 4, 't': 5, 's': 6,
                    'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3}

        self.feature_dim=freq_map[self.freq]


        flatten_dim = self.seq_len + (self.seq_len + self.pred_len) * self.feature_encode_dim

        self.feature_encoder = ResBlock(self.feature_dim, self.res_hidden, self.feature_encode_dim, dropout, bias)
        self.encoders = nn.Sequential(ResBlock(flatten_dim, self.res_hidden, self.hidden_dim, dropout, bias),*([ ResBlock(self.hidden_dim, self.res_hidden, self.hidden_dim, dropout, bias)]*(self.encoder_num-1)))
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            self.decoders = nn.Sequential(*([ ResBlock(self.hidden_dim, self.res_hidden, self.hidden_dim, dropout, bias)]*(self.decoder_num-1)),ResBlock(self.hidden_dim, self.res_hidden, self.decode_dim * self.pred_len, dropout, bias))
            self.temporalDecoder = ResBlock(self.decode_dim + self.feature_encode_dim, self.temporalDecoderHidden, 1, dropout, bias)
            self.residual_proj = nn.Linear(self.seq_len, self.pred_len, bias=bias)
        if self.task_name == 'imputation':
            self.decoders = nn.Sequential(*([ ResBlock(self.hidden_dim, self.res_hidden, self.hidden_dim, dropout, bias)]*(self.decoder_num-1)),ResBlock(self.hidden_dim, self.res_hidden, self.decode_dim * self.seq_len, dropout, bias))
            self.temporalDecoder = ResBlock(self.decode_dim + self.feature_encode_dim, self.temporalDecoderHidden, 1, dropout, bias)
            self.residual_proj = nn.Linear(self.seq_len, self.seq_len, bias=bias)
        if self.task_name == 'anomaly_detection':
            self.decoders = nn.Sequential(*([ ResBlock(self.hidden_dim, self.res_hidden, self.hidden_dim, dropout, bias)]*(self.decoder_num-1)),ResBlock(self.hidden_dim, self.res_hidden, self.decode_dim * self.seq_len, dropout, bias))
            self.temporalDecoder = ResBlock(self.decode_dim + self.feature_encode_dim, self.temporalDecoderHidden, 1, dropout, bias)
            self.residual_proj = nn.Linear(self.seq_len, self.seq_len, bias=bias)


    def forecast(self, x_enc, x_mark_enc, x_dec, batch_y_mark):
        # Normalization
        means = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - means
        stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
        x_enc /= stdev

        feature = self.feature_encoder(batch_y_mark)
        hidden = self.encoders(torch.cat([x_enc, feature.reshape(feature.shape[0], -1)], dim=-1))
        decoded = self.decoders(hidden).reshape(hidden.shape[0], self.pred_len, self.decode_dim)
        dec_out = self.temporalDecoder(torch.cat([feature[:,self.seq_len:], decoded], dim=-1)).squeeze(-1) + self.residual_proj(x_enc)


        # De-Normalization
        dec_out = dec_out * (stdev[:, 0].unsqueeze(1).repeat(1, self.pred_len))
        dec_out = dec_out + (means[:, 0].unsqueeze(1).repeat(1, self.pred_len))
        return dec_out

    def imputation(self, x_enc, x_mark_enc, x_dec, batch_y_mark, mask):
        # Normalization
        means = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - means
        stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
        x_enc /= stdev

        feature = self.feature_encoder(x_mark_enc)
        hidden = self.encoders(torch.cat([x_enc, feature.reshape(feature.shape[0], -1)], dim=-1))
        decoded = self.decoders(hidden).reshape(hidden.shape[0], self.seq_len, self.decode_dim)
        dec_out = self.temporalDecoder(torch.cat([feature[:,:self.seq_len], decoded], dim=-1)).squeeze(-1) + self.residual_proj(x_enc)

        # De-Normalization
        dec_out = dec_out * (stdev[:, 0].unsqueeze(1).repeat(1, self.seq_len))
        dec_out = dec_out + (means[:, 0].unsqueeze(1).repeat(1, self.seq_len))
        return dec_out


    def forward(self, x_enc, x_mark_enc, x_dec, batch_y_mark, mask=None):
        '''x_mark_enc is the exogenous dynamic feature described in the original paper'''
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            if batch_y_mark is None:
                batch_y_mark = torch.zeros((x_enc.shape[0], self.seq_len+self.pred_len, self.feature_dim)).to(x_enc.device).detach()
            else:
                batch_y_mark = torch.concat([x_mark_enc, batch_y_mark[:, -self.pred_len:, :]],dim=1)
            dec_out = torch.stack([self.forecast(x_enc[:, :, feature], x_mark_enc, x_dec, batch_y_mark) for feature in range(x_enc.shape[-1])],dim=-1)
            return dec_out  # [B, L, D]
        if self.task_name == 'imputation':
            dec_out = torch.stack([self.imputation(x_enc[:, :, feature], x_mark_enc, x_dec, batch_y_mark, mask) for feature in range(x_enc.shape[-1])],dim=-1)
            return dec_out  # [B, L, D]
        if self.task_name == 'anomaly_detection':
            raise NotImplementedError("Task anomaly_detection for Tide is temporarily not supported")
        if self.task_name == 'classification':
            raise NotImplementedError("Task classification for Tide is temporarily not supported")
        return None
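TiDE's forward above processes the input channel by channel: each feature is forecast on its own and the per-channel outputs are stacked back along the last dimension. The sketch below restates only that channel-independent pattern in isolation; it is not code from this commit, and forecast_one_channel is a deliberately trivial stand-in for the model's per-channel path.

import torch

def forecast_one_channel(x_1d: torch.Tensor, pred_len: int) -> torch.Tensor:
    # x_1d: [B, seq_len]; trivial stand-in: repeat the last observed value
    return x_1d[:, -1:].repeat(1, pred_len)

def channel_independent_forecast(x_enc: torch.Tensor, pred_len: int) -> torch.Tensor:
    # x_enc: [B, seq_len, C] -> [B, pred_len, C]
    outs = [forecast_one_channel(x_enc[:, :, c], pred_len) for c in range(x_enc.shape[-1])]
    return torch.stack(outs, dim=-1)

y = channel_independent_forecast(torch.randn(4, 96, 7), pred_len=24)
print(y.shape)  # torch.Size([4, 24, 7])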
models/Transformer.py
DELETED
@@ -1,124 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from layers.Transformer_EncDec import Decoder, DecoderLayer, Encoder, EncoderLayer, ConvLayer
from layers.SelfAttention_Family import FullAttention, AttentionLayer
from layers.Embed import DataEmbedding
import numpy as np


class Model(nn.Module):
    """
    Vanilla Transformer
    with O(L^2) complexity
    Paper link: https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf
    """

    def __init__(self, configs):
        super(Model, self).__init__()
        self.task_name = configs.task_name
        self.pred_len = configs.pred_len
        # Embedding
        self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq,
                                           configs.dropout)
        # Encoder
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AttentionLayer(
                        FullAttention(False, configs.factor, attention_dropout=configs.dropout,
                                      output_attention=False), configs.d_model, configs.n_heads),
                    configs.d_model,
                    configs.d_ff,
                    dropout=configs.dropout,
                    activation=configs.activation
                ) for l in range(configs.e_layers)
            ],
            norm_layer=torch.nn.LayerNorm(configs.d_model)
        )
        # Decoder
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            self.dec_embedding = DataEmbedding(configs.dec_in, configs.d_model, configs.embed, configs.freq,
                                               configs.dropout)
            self.decoder = Decoder(
                [
                    DecoderLayer(
                        AttentionLayer(
                            FullAttention(True, configs.factor, attention_dropout=configs.dropout,
                                          output_attention=False),
                            configs.d_model, configs.n_heads),
                        AttentionLayer(
                            FullAttention(False, configs.factor, attention_dropout=configs.dropout,
                                          output_attention=False),
                            configs.d_model, configs.n_heads),
                        configs.d_model,
                        configs.d_ff,
                        dropout=configs.dropout,
                        activation=configs.activation,
                    )
                    for l in range(configs.d_layers)
                ],
                norm_layer=torch.nn.LayerNorm(configs.d_model),
                projection=nn.Linear(configs.d_model, configs.c_out, bias=True)
            )
        if self.task_name == 'imputation':
            self.projection = nn.Linear(configs.d_model, configs.c_out, bias=True)
        if self.task_name == 'anomaly_detection':
            self.projection = nn.Linear(configs.d_model, configs.c_out, bias=True)
        if self.task_name == 'classification':
            self.act = F.gelu
            self.dropout = nn.Dropout(configs.dropout)
            self.projection = nn.Linear(configs.d_model * configs.seq_len, configs.num_class)

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        # Embedding
        enc_out = self.enc_embedding(x_enc, x_mark_enc)
        enc_out, attns = self.encoder(enc_out, attn_mask=None)

        dec_out = self.dec_embedding(x_dec, x_mark_dec)
        dec_out = self.decoder(dec_out, enc_out, x_mask=None, cross_mask=None)
        return dec_out

    def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
        # Embedding
        enc_out = self.enc_embedding(x_enc, x_mark_enc)
        enc_out, attns = self.encoder(enc_out, attn_mask=None)

        dec_out = self.projection(enc_out)
        return dec_out

    def anomaly_detection(self, x_enc):
        # Embedding
        enc_out = self.enc_embedding(x_enc, None)
        enc_out, attns = self.encoder(enc_out, attn_mask=None)

        dec_out = self.projection(enc_out)
        return dec_out

    def classification(self, x_enc, x_mark_enc):
        # Embedding
        enc_out = self.enc_embedding(x_enc, None)
        enc_out, attns = self.encoder(enc_out, attn_mask=None)

        # Output
        output = self.act(enc_out)  # the output transformer encoder/decoder embeddings don't include non-linearity
        output = self.dropout(output)
        output = output * x_mark_enc.unsqueeze(-1)  # zero-out padding embeddings
        output = output.reshape(output.shape[0], -1)  # (batch_size, seq_length * d_model)
        output = self.projection(output)  # (batch_size, num_classes)
        return output

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
        if self.task_name == 'imputation':
            dec_out = self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
            return dec_out  # [B, L, D]
        if self.task_name == 'anomaly_detection':
            dec_out = self.anomaly_detection(x_enc)
            return dec_out  # [B, L, D]
        if self.task_name == 'classification':
            dec_out = self.classification(x_enc, x_mark_enc)
            return dec_out  # [B, N]
        return None
models/iTransformer.py
DELETED
@@ -1,132 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from layers.Transformer_EncDec import Encoder, EncoderLayer
from layers.SelfAttention_Family import FullAttention, AttentionLayer
from layers.Embed import DataEmbedding_inverted
import numpy as np


class Model(nn.Module):
    """
    Paper link: https://arxiv.org/abs/2310.06625
    """

    def __init__(self, configs):
        super(Model, self).__init__()
        self.task_name = configs.task_name
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        # Embedding
        self.enc_embedding = DataEmbedding_inverted(configs.seq_len, configs.d_model, configs.embed, configs.freq,
                                                    configs.dropout)
        # Encoder
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AttentionLayer(
                        FullAttention(False, configs.factor, attention_dropout=configs.dropout,
                                      output_attention=False), configs.d_model, configs.n_heads),
                    configs.d_model,
                    configs.d_ff,
                    dropout=configs.dropout,
                    activation=configs.activation
                ) for l in range(configs.e_layers)
            ],
            norm_layer=torch.nn.LayerNorm(configs.d_model)
        )
        # Decoder
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            self.projection = nn.Linear(configs.d_model, configs.pred_len, bias=True)
        if self.task_name == 'imputation':
            self.projection = nn.Linear(configs.d_model, configs.seq_len, bias=True)
        if self.task_name == 'anomaly_detection':
            self.projection = nn.Linear(configs.d_model, configs.seq_len, bias=True)
        if self.task_name == 'classification':
            self.act = F.gelu
            self.dropout = nn.Dropout(configs.dropout)
            self.projection = nn.Linear(configs.d_model * configs.enc_in, configs.num_class)

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        # Normalization from Non-stationary Transformer
        means = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - means
        stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
        x_enc /= stdev

        _, _, N = x_enc.shape

        # Embedding
        enc_out = self.enc_embedding(x_enc, x_mark_enc)
        enc_out, attns = self.encoder(enc_out, attn_mask=None)

        dec_out = self.projection(enc_out).permute(0, 2, 1)[:, :, :N]
        # De-Normalization from Non-stationary Transformer
        dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
        dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
        return dec_out

    def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
        # Normalization from Non-stationary Transformer
        means = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - means
        stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
        x_enc /= stdev

        _, L, N = x_enc.shape

        # Embedding
        enc_out = self.enc_embedding(x_enc, x_mark_enc)
        enc_out, attns = self.encoder(enc_out, attn_mask=None)

        dec_out = self.projection(enc_out).permute(0, 2, 1)[:, :, :N]
        # De-Normalization from Non-stationary Transformer
        dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, L, 1))
        dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, L, 1))
        return dec_out

    def anomaly_detection(self, x_enc):
        # Normalization from Non-stationary Transformer
        means = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - means
        stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
        x_enc /= stdev

        _, L, N = x_enc.shape

        # Embedding
        enc_out = self.enc_embedding(x_enc, None)
        enc_out, attns = self.encoder(enc_out, attn_mask=None)

        dec_out = self.projection(enc_out).permute(0, 2, 1)[:, :, :N]
        # De-Normalization from Non-stationary Transformer
        dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, L, 1))
        dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, L, 1))
        return dec_out

    def classification(self, x_enc, x_mark_enc):
        # Embedding
        enc_out = self.enc_embedding(x_enc, None)
        enc_out, attns = self.encoder(enc_out, attn_mask=None)

        # Output
        output = self.act(enc_out)  # the output transformer encoder/decoder embeddings don't include non-linearity
        output = self.dropout(output)
        output = output.reshape(output.shape[0], -1)  # (batch_size, c_in * d_model)
        output = self.projection(output)  # (batch_size, num_classes)
        return output

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
        if self.task_name == 'imputation':
            dec_out = self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
            return dec_out  # [B, L, D]
        if self.task_name == 'anomaly_detection':
            dec_out = self.anomaly_detection(x_enc)
            return dec_out  # [B, L, D]
        if self.task_name == 'classification':
            dec_out = self.classification(x_enc, x_mark_enc)
            return dec_out  # [B, N]
        return None
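The forecast, imputation and anomaly-detection paths above all wrap the encoder in the same "Normalization from Non-stationary Transformer" step: each series is standardized over the time axis before embedding, and the stored mean and standard deviation are broadcast back onto the output after the projection. Below is a standalone sketch of that round trip; the shapes and the zero placeholder prediction are illustrative and stand in for the real encoder:

import torch

batch, seq_len, n_vars, pred_len = 2, 96, 7, 24
x_enc = torch.randn(batch, seq_len, n_vars) * 3.0 + 10.0  # non-zero mean, non-unit scale

# Normalize per series over the time dimension (as in forecast()):
means = x_enc.mean(1, keepdim=True).detach()
x_norm = x_enc - means
stdev = torch.sqrt(torch.var(x_norm, dim=1, keepdim=True, unbiased=False) + 1e-5)
x_norm = x_norm / stdev

# ... the model would map x_norm to a prediction of shape [batch, pred_len, n_vars] ...
pred_norm = torch.zeros(batch, pred_len, n_vars)  # placeholder prediction in normalized space

# De-normalize: broadcast the stored statistics back onto the prediction horizon
pred = pred_norm * stdev[:, 0, :].unsqueeze(1).repeat(1, pred_len, 1) \
       + means[:, 0, :].unsqueeze(1).repeat(1, pred_len, 1)
print(pred.shape)  # torch.Size([2, 24, 7]) -- back on the original scale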
requirements.txt
CHANGED
@@ -21,4 +21,8 @@ pytz
 google-generativeai

 fastapi
-uvicorn
+uvicorn
+
+aiohttp>=3.8.0
+scipy>=1.9.0
+apscheduler>=3.10.0  # for backup use
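The new apscheduler dependency is annotated as a backup for the external cron trigger. A rough sketch of how such an in-process fallback could be wired up is shown below; the job body and the 5-minute interval are assumptions for illustration, not code from this commit:

# Hypothetical in-process fallback scheduler (APScheduler BackgroundScheduler).
from apscheduler.schedulers.background import BackgroundScheduler

def collect_and_process():
    # Placeholder body; the real task would call the Space's internal collection API.
    print("collecting and processing tide data...")

scheduler = BackgroundScheduler()
scheduler.add_job(collect_and_process, "interval", minutes=5, id="tide_collect")
scheduler.start()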