heyunfei commited on
Commit
1ef2135
·
verified ·
1 Parent(s): 613683c

Update webui/run.py

Browse files
Files changed (1) hide show
  1. webui/run.py +862 -89
webui/run.py CHANGED
@@ -1,89 +1,862 @@
1
- #!/usr/bin/env python3
2
- """
3
- Kronos Web UI startup script
4
- """
5
-
6
- import os
7
- import sys
8
- import subprocess
9
- import webbrowser
10
- import time
11
-
12
- def check_dependencies():
13
- """Check if dependencies are installed"""
14
- try:
15
- import flask
16
- import flask_cors
17
- import pandas
18
- import numpy
19
- import plotly
20
- print("✅ All dependencies installed")
21
- return True
22
- except ImportError as e:
23
- print(f"❌ Missing dependency: {e}")
24
- print("Please run: pip install -r requirements.txt")
25
- return False
26
-
27
- def install_dependencies():
28
- """Install dependencies"""
29
- print("Installing dependencies...")
30
- try:
31
- subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", "requirements.txt"])
32
- print("✅ Dependencies installation completed")
33
- return True
34
- except subprocess.CalledProcessError:
35
- print("❌ Dependencies installation failed")
36
- return False
37
-
38
- def main():
39
- """Main function"""
40
- print("🚀 Starting Kronos Web UI...")
41
- print("=" * 50)
42
-
43
- # Check dependencies
44
- if not check_dependencies():
45
- print("\nAuto-install dependencies? (y/n): ", end="")
46
- if input().lower() == 'y':
47
- if not install_dependencies():
48
- return
49
- else:
50
- print("Please manually install dependencies and retry")
51
- return
52
-
53
- # Check model availability
54
- try:
55
- sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
56
- from model import Kronos, KronosTokenizer, KronosPredictor
57
- print("✅ Kronos model library available")
58
- model_available = True
59
- except ImportError:
60
- print("⚠️ Kronos model library not available, will use simulated prediction")
61
- model_available = False
62
-
63
- # Start Flask application
64
- print("\n🌐 Starting Web server...")
65
-
66
- # Set environment variables
67
- os.environ['FLASK_APP'] = 'app.py'
68
- os.environ['FLASK_ENV'] = 'development'
69
-
70
- # Start server
71
- try:
72
- from app import app
73
- print("✅ Web server started successfully!")
74
- print(f"🌐 Access URL: http://localhost:7070")
75
- print("💡 Tip: Press Ctrl+C to stop server")
76
-
77
- # Auto-open browser
78
- time.sleep(2)
79
- webbrowser.open('http://localhost:7070')
80
-
81
- # Start Flask application
82
- app.run(debug=True, host='0.0.0.0', port=7070)
83
-
84
- except Exception as e:
85
- print(f"❌ Startup failed: {e}")
86
- print("Please check if port 7070 is occupied")
87
-
88
- if __name__ == "__main__":
89
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import json
3
+ import os
4
+ import sys
5
+ import warnings
6
+
7
+ import pandas as pd
8
+ import plotly.graph_objects as go
9
+ import plotly.utils
10
+ import pytz
11
+ from binance.client import Client
12
+ from flask import Flask, render_template, request, jsonify
13
+ from flask_cors import CORS
14
+ from sympy import false
15
+
16
+ try:
17
+ from technical_indicators import add_technical_indicators, get_available_indicators
18
+
19
+ TECHNICAL_INDICATORS_AVAILABLE = False
20
+ except ImportError as e:
21
+ print(f"⚠️ 技术指标模块导入失败: {e}")
22
+ TECHNICAL_INDICATORS_AVAILABLE = False
23
+
24
+
25
+ # 定义空的替代函数
26
+ def add_technical_indicators(df, indicators_config=None):
27
+ return df
28
+
29
+
30
+ def get_available_indicators():
31
+ return {'trend': [], 'momentum': [], 'volatility': [], 'volume': []}
32
+
33
+ warnings.filterwarnings('ignore')
34
+
35
+ # 设置东八区时区
36
+ BEIJING_TZ = pytz.timezone('Asia/Shanghai')
37
+
38
+ # Add project root directory to path
39
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
40
+
41
+ try:
42
+ from model import Kronos, KronosTokenizer, KronosPredictor
43
+
44
+ MODEL_AVAILABLE = True
45
+ except ImportError:
46
+ MODEL_AVAILABLE = False
47
+ print("Warning: Kronos model cannot be imported, will use simulated data for demonstration")
48
+
49
+ app = Flask(__name__)
50
+ CORS(app)
51
+
52
+ # Global variables to store models
53
+ tokenizer = None
54
+ model = None
55
+ predictor = None
56
+
57
+ # Available model configurations
58
+ AVAILABLE_MODELS = {
59
+ 'kronos-mini': {
60
+ 'name': 'Kronos-mini',
61
+ 'model_id': 'NeoQuasar/Kronos-mini',
62
+ 'tokenizer_id': 'NeoQuasar/Kronos-Tokenizer-2k',
63
+ 'context_length': 2048,
64
+ 'params': '4.1M',
65
+ 'description': 'Lightweight model, suitable for fast prediction'
66
+ },
67
+ 'kronos-small': {
68
+ 'name': 'Kronos-small',
69
+ 'model_id': 'NeoQuasar/Kronos-small',
70
+ 'tokenizer_id': 'NeoQuasar/Kronos-Tokenizer-base',
71
+ 'context_length': 512,
72
+ 'params': '24.7M',
73
+ 'description': 'Small model, balanced performance and speed'
74
+ },
75
+ 'kronos-base': {
76
+ 'name': 'Kronos-base',
77
+ 'model_id': 'NeoQuasar/Kronos-base',
78
+ 'tokenizer_id': 'NeoQuasar/Kronos-Tokenizer-base',
79
+ 'context_length': 512,
80
+ 'params': '102.3M',
81
+ 'description': 'Base model, provides better prediction quality'
82
+ }
83
+ }
84
+
85
+
86
+
87
+ def get_available_symbols():
88
+ """获取固定的交易对列表"""
89
+ # 返回固定的主要交易对,不再从币安API获取
90
+ return [
91
+ {'symbol': 'BTCUSDT', 'baseAsset': 'BTC', 'quoteAsset': 'USDT', 'name': 'BTC/USDT'},
92
+ {'symbol': 'ETHUSDT', 'baseAsset': 'ETH', 'quoteAsset': 'USDT', 'name': 'ETH/USDT'},
93
+ {'symbol': 'SOLUSDT', 'baseAsset': 'SOL', 'quoteAsset': 'USDT', 'name': 'SOL/USDT'},
94
+ {'symbol': 'BNBUSDT', 'baseAsset': 'BNB', 'quoteAsset': 'USDT', 'name': 'BNB/USDT'}
95
+ ]
96
+
97
+
98
+
99
+ def get_binance_klines(symbol, interval='1h', limit=1000):
100
+ """从币安获取K线数据,如果失败则生成模拟数据"""
101
+ try:
102
+ # 尝试初始化客户端并获取真实的币安数据
103
+ client = Client("", "")
104
+ klines = client.get_klines(
105
+ symbol=symbol,
106
+ interval=interval,
107
+ limit=limit
108
+ )
109
+
110
+ # 转换为DataFrame
111
+ df = pd.DataFrame(klines, columns=[
112
+ 'timestamp', 'open', 'high', 'low', 'close', 'volume',
113
+ 'close_time', 'quote_asset_volume', 'number_of_trades',
114
+ 'taker_buy_base_asset_volume', 'taker_buy_quote_asset_volume', 'ignore'
115
+ ])
116
+
117
+ # 数据类型转换,转换为东八区时间
118
+ df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True)
119
+ df['timestamp'] = df['timestamp'].dt.tz_convert(BEIJING_TZ)
120
+ df['timestamps'] = df['timestamp'] # 保持兼容性
121
+
122
+ # 转换数值列
123
+ numeric_cols = ['open', 'high', 'low', 'close', 'volume', 'quote_asset_volume']
124
+ for col in numeric_cols:
125
+ df[col] = pd.to_numeric(df[col], errors='coerce')
126
+
127
+ # 添加amount列(成交额)
128
+ df['amount'] = df['quote_asset_volume']
129
+
130
+ # 只保留需要的列
131
+ df = df[['timestamp','timestamps', 'open', 'high', 'low', 'close', 'volume', 'amount']]
132
+
133
+ # 按时间排序
134
+ df = df.sort_values('timestamp').reset_index(drop=True)
135
+
136
+ # 添加技术指标(如果可用)
137
+ if TECHNICAL_INDICATORS_AVAILABLE:
138
+ try:
139
+ df = add_technical_indicators(df)
140
+ print(f"✅ 成功获取币安真实数据并计算技术指标: {symbol} {interval} {len(df)}条,{len(df.columns)}个特征")
141
+ except Exception as e:
142
+ print(f"⚠️ 技术指标计算失败,使用原始数据: {e}")
143
+ else:
144
+ print(f"✅ 成功获取币安真实数据: {symbol} {interval} {len(df)}条")
145
+
146
+ return df, None
147
+
148
+ except Exception as e:
149
+ print(f"⚠️ 币安API连接失败,使用模拟数据: {str(e)}")
150
+
151
+
152
+ def get_timeframe_options():
153
+ """获取可用的时间周期选项"""
154
+ return [
155
+ {'value': '1m', 'label': '1分钟', 'description': '1分钟K线'},
156
+ {'value': '5m', 'label': '5分钟', 'description': '5分钟K线'},
157
+ {'value': '15m', 'label': '15分钟', 'description': '15分钟K线'},
158
+ {'value': '30m', 'label': '30分钟', 'description': '30分钟K线'},
159
+ {'value': '1h', 'label': '1小时', 'description': '1小时K线'},
160
+ {'value': '4h', 'label': '4小时', 'description': '4小时K线'},
161
+ {'value': '1d', 'label': '1天', 'description': '日K线'},
162
+ {'value': '1w', 'label': '1周', 'description': '周K线'},
163
+ ]
164
+
165
+
166
+ def save_prediction_results(file_path, prediction_type, prediction_results, actual_data, input_data, prediction_params):
167
+ """Save prediction results to file"""
168
+ try:
169
+ # Create prediction results directory
170
+ results_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prediction_results')
171
+ os.makedirs(results_dir, exist_ok=True)
172
+
173
+ # Generate filename
174
+ timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
175
+ filename = f'prediction_{timestamp}.json'
176
+ filepath = os.path.join(results_dir, filename)
177
+
178
+ # Prepare data for saving
179
+ save_data = {
180
+ 'timestamp': datetime.datetime.now().isoformat(),
181
+ 'file_path': file_path,
182
+ 'prediction_type': prediction_type,
183
+ 'prediction_params': prediction_params,
184
+ 'input_data_summary': {
185
+ 'rows': len(input_data),
186
+ 'columns': list(input_data.columns),
187
+ 'price_range': {
188
+ 'open': {'min': float(input_data['open'].min()), 'max': float(input_data['open'].max())},
189
+ 'high': {'min': float(input_data['high'].min()), 'max': float(input_data['high'].max())},
190
+ 'low': {'min': float(input_data['low'].min()), 'max': float(input_data['low'].max())},
191
+ 'close': {'min': float(input_data['close'].min()), 'max': float(input_data['close'].max())}
192
+ },
193
+ 'last_values': {
194
+ 'open': float(input_data['open'].iloc[-1]),
195
+ 'high': float(input_data['high'].iloc[-1]),
196
+ 'low': float(input_data['low'].iloc[-1]),
197
+ 'close': float(input_data['close'].iloc[-1])
198
+ }
199
+ },
200
+ 'prediction_results': prediction_results,
201
+ 'actual_data': actual_data,
202
+ 'analysis': {}
203
+ }
204
+
205
+ # If actual data exists, perform comparison analysis
206
+ if actual_data and len(actual_data) > 0:
207
+ # Calculate continuity analysis
208
+ if len(prediction_results) > 0 and len(actual_data) > 0:
209
+ last_pred = prediction_results[0] # First prediction point
210
+ first_actual = actual_data[0] # First actual point
211
+
212
+ save_data['analysis']['continuity'] = {
213
+ 'last_prediction': {
214
+ 'open': last_pred['open'],
215
+ 'high': last_pred['high'],
216
+ 'low': last_pred['low'],
217
+ 'close': last_pred['close']
218
+ },
219
+ 'first_actual': {
220
+ 'open': first_actual['open'],
221
+ 'high': first_actual['high'],
222
+ 'low': first_actual['low'],
223
+ 'close': first_actual['close']
224
+ },
225
+ 'gaps': {
226
+ 'open_gap': abs(last_pred['open'] - first_actual['open']),
227
+ 'high_gap': abs(last_pred['high'] - first_actual['high']),
228
+ 'low_gap': abs(last_pred['low'] - first_actual['low']),
229
+ 'close_gap': abs(last_pred['close'] - first_actual['close'])
230
+ },
231
+ 'gap_percentages': {
232
+ 'open_gap_pct': (abs(last_pred['open'] - first_actual['open']) / first_actual['open']) * 100,
233
+ 'high_gap_pct': (abs(last_pred['high'] - first_actual['high']) / first_actual['high']) * 100,
234
+ 'low_gap_pct': (abs(last_pred['low'] - first_actual['low']) / first_actual['low']) * 100,
235
+ 'close_gap_pct': (abs(last_pred['close'] - first_actual['close']) / first_actual['close']) * 100
236
+ }
237
+ }
238
+
239
+ # Save to file
240
+ with open(filepath, 'w', encoding='utf-8') as f:
241
+ json.dump(save_data, f, indent=2, ensure_ascii=False)
242
+
243
+ print(f"Prediction results saved to: {filepath}")
244
+ return filepath
245
+
246
+ except Exception as e:
247
+ print(f"Failed to save prediction results: {e}")
248
+ return None
249
+
250
+
251
+ def create_prediction_chart(df, pred_df, lookback, pred_len, actual_df=None, historical_start_idx=0):
252
+ """Create prediction chart"""
253
+ # Use specified historical data start position, not always from the beginning of df
254
+ if historical_start_idx + lookback + pred_len <= len(df):
255
+ # Display lookback historical points + pred_len prediction points starting from specified position
256
+ historical_df = df.iloc[historical_start_idx:historical_start_idx + lookback]
257
+ prediction_range = range(historical_start_idx + lookback, historical_start_idx + lookback + pred_len)
258
+ else:
259
+ # If data is insufficient, adjust to maximum available range
260
+ available_lookback = min(lookback, len(df) - historical_start_idx)
261
+ available_pred_len = min(pred_len, max(0, len(df) - historical_start_idx - available_lookback))
262
+ historical_df = df.iloc[historical_start_idx:historical_start_idx + available_lookback]
263
+ prediction_range = range(historical_start_idx + available_lookback,
264
+ historical_start_idx + available_lookback + available_pred_len)
265
+
266
+ # Create chart
267
+ fig = go.Figure()
268
+
269
+ # Add historical data (candlestick chart)
270
+ fig.add_trace(go.Candlestick(
271
+ x=historical_df['timestamps'] if 'timestamps' in historical_df.columns else historical_df.index,
272
+ open=historical_df['open'],
273
+ high=historical_df['high'],
274
+ low=historical_df['low'],
275
+ close=historical_df['close'],
276
+ name='Historical Data (400 data points)',
277
+ increasing_line_color='#26A69A',
278
+ decreasing_line_color='#EF5350'
279
+ ))
280
+
281
+ # Add prediction data (candlestick chart)
282
+ if pred_df is not None and len(pred_df) > 0:
283
+ # Calculate prediction data timestamps - ensure continuity with historical data
284
+ if 'timestamps' in df.columns and len(historical_df) > 0:
285
+ # Start from the last timestamp of historical data, create prediction timestamps with the same time interval
286
+ last_timestamp = historical_df['timestamps'].iloc[-1]
287
+ time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0] if len(df) > 1 else pd.Timedelta(hours=1)
288
+
289
+ pred_timestamps = pd.date_range(
290
+ start=last_timestamp + time_diff,
291
+ periods=len(pred_df),
292
+ freq=time_diff
293
+ )
294
+ else:
295
+ # If no timestamps, use index
296
+ pred_timestamps = range(len(historical_df), len(historical_df) + len(pred_df))
297
+
298
+ fig.add_trace(go.Candlestick(
299
+ x=pred_timestamps,
300
+ open=pred_df['open'],
301
+ high=pred_df['high'],
302
+ low=pred_df['low'],
303
+ close=pred_df['close'],
304
+ name='Prediction Data (120 data points)',
305
+ increasing_line_color='#66BB6A',
306
+ decreasing_line_color='#FF7043'
307
+ ))
308
+
309
+ # Add actual data for comparison (if exists)
310
+ if actual_df is not None and len(actual_df) > 0:
311
+ # Actual data should be in the same time period as prediction data
312
+ if 'timestamps' in df.columns:
313
+ # Actual data should use the same timestamps as prediction data to ensure time alignment
314
+ if 'pred_timestamps' in locals():
315
+ actual_timestamps = pred_timestamps
316
+ else:
317
+ # If no prediction timestamps, calculate from the last timestamp of historical data
318
+ if len(historical_df) > 0:
319
+ last_timestamp = historical_df['timestamps'].iloc[-1]
320
+ time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0] if len(df) > 1 else pd.Timedelta(
321
+ hours=1)
322
+ actual_timestamps = pd.date_range(
323
+ start=last_timestamp + time_diff,
324
+ periods=len(actual_df),
325
+ freq=time_diff
326
+ )
327
+ else:
328
+ actual_timestamps = range(len(historical_df), len(historical_df) + len(actual_df))
329
+ else:
330
+ actual_timestamps = range(len(historical_df), len(historical_df) + len(actual_df))
331
+
332
+ fig.add_trace(go.Candlestick(
333
+ x=actual_timestamps,
334
+ open=actual_df['open'],
335
+ high=actual_df['high'],
336
+ low=actual_df['low'],
337
+ close=actual_df['close'],
338
+ name='Actual Data (120 data points)',
339
+ increasing_line_color='#FF9800',
340
+ decreasing_line_color='#F44336'
341
+ ))
342
+
343
+ # Update layout
344
+ fig.update_layout(
345
+ title='Kronos Financial Prediction Results - 400 Historical Points + 120 Prediction Points vs 120 Actual Points',
346
+ xaxis_title='Time',
347
+ yaxis_title='Price',
348
+ template='plotly_white',
349
+ height=600,
350
+ showlegend=True
351
+ )
352
+
353
+ # Ensure x-axis time continuity
354
+ if 'timestamps' in historical_df.columns:
355
+ # Get all timestamps and sort them
356
+ all_timestamps = []
357
+ if len(historical_df) > 0:
358
+ all_timestamps.extend(historical_df['timestamps'])
359
+ if 'pred_timestamps' in locals():
360
+ all_timestamps.extend(pred_timestamps)
361
+ if 'actual_timestamps' in locals():
362
+ all_timestamps.extend(actual_timestamps)
363
+
364
+ if all_timestamps:
365
+ all_timestamps = sorted(all_timestamps)
366
+ fig.update_xaxes(
367
+ range=[all_timestamps[0], all_timestamps[-1]],
368
+ rangeslider_visible=False,
369
+ type='date'
370
+ )
371
+
372
+ return json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder)
373
+
374
+
375
+ @app.route('/')
376
+ def index():
377
+ """Home page"""
378
+ return render_template('index.html')
379
+
380
+
381
+ @app.route('/api/symbols')
382
+ def get_symbols():
383
+ """获取可用的交易对列表"""
384
+ symbols = get_available_symbols()
385
+ return jsonify(symbols)
386
+
387
+
388
+ @app.route('/api/timeframes')
389
+ def get_timeframes():
390
+ """获取可用的时间周期列表"""
391
+ timeframes = get_timeframe_options()
392
+ return jsonify(timeframes)
393
+
394
+
395
+ @app.route('/api/technical-indicators')
396
+ def get_technical_indicators():
397
+ """获取可用的技术指标列表"""
398
+ indicators = get_available_indicators()
399
+ return jsonify(indicators)
400
+
401
+
402
+ @app.route('/api/load-data', methods=['POST'])
403
+ def load_data():
404
+ """加载币安数据"""
405
+ try:
406
+ data = request.get_json()
407
+ symbol = data.get('symbol')
408
+ interval = data.get('interval', '1h')
409
+ limit = int(data.get('limit', 1000))
410
+
411
+ if not symbol:
412
+ return jsonify({'error': '交易对不能为空'}), 400
413
+
414
+ df, error = get_binance_klines(symbol, interval, limit)
415
+ if error:
416
+ return jsonify({'error': error}), 400
417
+
418
+ # Detect data time frequency
419
+ def detect_timeframe(df):
420
+ if len(df) < 2:
421
+ return "Unknown"
422
+
423
+ time_diffs = []
424
+ for i in range(1, min(10, len(df))): # Check first 10 time differences
425
+ diff = df['timestamps'].iloc[i] - df['timestamps'].iloc[i - 1]
426
+ time_diffs.append(diff)
427
+
428
+ if not time_diffs:
429
+ return "Unknown"
430
+
431
+ # Calculate average time difference
432
+ avg_diff = sum(time_diffs, pd.Timedelta(0)) / len(time_diffs)
433
+
434
+ # Convert to readable format
435
+ if avg_diff < pd.Timedelta(minutes=1):
436
+ return f"{avg_diff.total_seconds():.0f} seconds"
437
+ elif avg_diff < pd.Timedelta(hours=1):
438
+ return f"{avg_diff.total_seconds() / 60:.0f} minutes"
439
+ elif avg_diff < pd.Timedelta(days=1):
440
+ return f"{avg_diff.total_seconds() / 3600:.0f} hours"
441
+ else:
442
+ return f"{avg_diff.days} days"
443
+
444
+ # Return data information with formatted time
445
+ def format_beijing_time(timestamp):
446
+ """格式化东八区时间为 yyyy-MM-dd HH:mm:ss"""
447
+ if pd.isna(timestamp):
448
+ return 'N/A'
449
+ # 确保时间戳有时区信息
450
+ if timestamp.tz is None:
451
+ timestamp = timestamp.tz_localize(BEIJING_TZ)
452
+ elif timestamp.tz != BEIJING_TZ:
453
+ timestamp = timestamp.tz_convert(BEIJING_TZ)
454
+ return timestamp.strftime('%Y-%m-%d %H:%M:%S')
455
+
456
+ data_info = {
457
+ 'rows': len(df),
458
+ 'columns': list(df.columns),
459
+ 'start_date': format_beijing_time(df['timestamps'].min()) if 'timestamps' in df.columns else 'N/A',
460
+ 'end_date': format_beijing_time(df['timestamps'].max()) if 'timestamps' in df.columns else 'N/A',
461
+ 'price_range': {
462
+ 'min': float(df[['open', 'high', 'low', 'close']].min().min()),
463
+ 'max': float(df[['open', 'high', 'low', 'close']].max().max())
464
+ },
465
+ 'prediction_columns': ['open', 'high', 'low', 'close'] + (['volume'] if 'volume' in df.columns else []),
466
+ 'timeframe': detect_timeframe(df)
467
+ }
468
+
469
+ return jsonify({
470
+ 'success': True,
471
+ 'data_info': data_info,
472
+ 'message': f'Successfully loaded data, total {len(df)} rows'
473
+ })
474
+
475
+ except Exception as e:
476
+ return jsonify({'error': f'Failed to load data: {str(e)}'}), 500
477
+
478
+
479
+ @app.route('/api/predict', methods=['POST'])
480
+ def predict():
481
+ """Perform prediction"""
482
+ try:
483
+ data = request.get_json()
484
+ symbol = data.get('symbol')
485
+ interval = data.get('interval', '1h')
486
+ limit = int(data.get('limit', 1000))
487
+ lookback = int(data.get('lookback', 400))
488
+ pred_len = int(data.get('pred_len', 120))
489
+
490
+ # Get prediction quality parameters
491
+ temperature = float(data.get('temperature', 1.0))
492
+ top_p = float(data.get('top_p', 0.9))
493
+ sample_count = int(data.get('sample_count', 1))
494
+
495
+ if not symbol:
496
+ return jsonify({'error': '交易对不能为空'}), 400
497
+
498
+ # Load data from Binance
499
+ df, error = get_binance_klines(symbol, interval, limit)
500
+ if error:
501
+ return jsonify({'error': error}), 400
502
+
503
+ if len(df) < lookback:
504
+ return jsonify({'error': f'Insufficient data length, need at least {lookback} rows'}), 400
505
+
506
+ # Perform prediction
507
+ if MODEL_AVAILABLE:
508
+ try:
509
+ # Use real Kronos model
510
+ # Only use necessary columns: OHLCVA (6 features required by Kronos model)
511
+ required_cols = ['open', 'high', 'low', 'close']
512
+ if 'volume' in df.columns:
513
+ required_cols.append('volume')
514
+ if 'amount' in df.columns:
515
+ required_cols.append('amount')
516
+
517
+ print(f"🔍 Using features for prediction: {required_cols}")
518
+ print(f" Available columns in data: {list(df.columns)}")
519
+ print(f" Data shape: {df.shape}")
520
+
521
+ # Check if required columns exist
522
+ missing_cols = [col for col in required_cols if col not in df.columns]
523
+ if missing_cols:
524
+ return jsonify({'error': f'Missing required columns: {missing_cols}'}), 400
525
+
526
+ # Process time period selection
527
+ start_date = data.get('start_date')
528
+
529
+ if start_date:
530
+ # Custom time period - fix logic: use data within selected window
531
+ start_dt = pd.to_datetime(start_date)
532
+
533
+ # Find data after start time
534
+ mask = df['timestamps'] >= start_dt
535
+ time_range_df = df[mask]
536
+
537
+ # Ensure sufficient data: lookback + pred_len
538
+ if len(time_range_df) < lookback + pred_len:
539
+ return jsonify({
540
+ 'error': f'Insufficient data from start time {start_dt.strftime("%Y-%m-%d %H:%M")}, need at least {lookback + pred_len} data points, currently only {len(time_range_df)} available'}), 400
541
+
542
+ # Use first lookback data points within selected window for prediction
543
+ x_df = time_range_df.iloc[:lookback][required_cols]
544
+ x_timestamp = time_range_df.iloc[:lookback]['timestamps']
545
+
546
+ print(f"🔍 Custom time period - x_df shape: {x_df.shape}")
547
+ print(f" x_timestamp length: {len(x_timestamp)}")
548
+ print(f" x_df columns: {list(x_df.columns)}")
549
+ print(f" x_df sample:\n{x_df.head()}")
550
+
551
+ # Generate future timestamps for prediction instead of using existing data
552
+ # Calculate time difference from the data
553
+ if len(time_range_df) >= 2:
554
+ time_diff = time_range_df['timestamps'].iloc[1] - time_range_df['timestamps'].iloc[0]
555
+ else:
556
+ time_diff = pd.Timedelta(hours=1) # Default to 1 hour
557
+
558
+ # Generate future timestamps starting from the last timestamp of input data
559
+ last_timestamp = time_range_df['timestamps'].iloc[lookback - 1]
560
+ y_timestamp = pd.date_range(
561
+ start=last_timestamp + time_diff,
562
+ periods=pred_len,
563
+ freq=time_diff
564
+ )
565
+
566
+ # Calculate actual time period length
567
+ start_timestamp = time_range_df['timestamps'].iloc[0]
568
+ end_timestamp = y_timestamp[-1] # Use the last generated timestamp
569
+ time_span = end_timestamp - start_timestamp
570
+
571
+ prediction_type = f"Kronos model prediction (within selected window: first {lookback} data points for prediction, {pred_len} future predictions, time span: {time_span})"
572
+ else:
573
+ # Use latest data
574
+ x_df = df.iloc[:lookback][required_cols]
575
+ x_timestamp = df.iloc[:lookback]['timestamps']
576
+
577
+ # Generate future timestamps for prediction instead of using existing data
578
+ # Calculate time difference from the data
579
+ if len(df) >= 2:
580
+ time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0]
581
+ else:
582
+ time_diff = pd.Timedelta(hours=1) # Default to 1 hour
583
+
584
+ # Generate future timestamps starting from the last timestamp of input data
585
+ last_timestamp = df['timestamps'].iloc[lookback - 1]
586
+ y_timestamp = pd.date_range(
587
+ start=last_timestamp + time_diff,
588
+ periods=pred_len,
589
+ freq=time_diff
590
+ )
591
+ prediction_type = "Kronos model prediction (latest data)"
592
+
593
+ print(f"🔍 Latest data - x_df shape: {x_df.shape}")
594
+ print(f" x_timestamp length: {len(x_timestamp)}")
595
+ print(f" y_timestamp length: {len(y_timestamp)}")
596
+ print(f" x_df columns: {list(x_df.columns)}")
597
+ print(f" x_df sample:\n{x_df.head()}")
598
+
599
+ # Check if data is empty
600
+ if x_df.empty or len(x_df) == 0:
601
+ return jsonify({'error': 'Input data is empty after processing'}), 400
602
+
603
+ if len(x_timestamp) == 0:
604
+ return jsonify({'error': 'Input timestamps are empty'}), 400
605
+
606
+ if len(y_timestamp) == 0:
607
+ return jsonify({'error': 'Target timestamps are empty'}), 400
608
+
609
+ # Ensure timestamps are Series format, not DatetimeIndex, to avoid .dt attribute error in Kronos model
610
+ if isinstance(x_timestamp, pd.DatetimeIndex):
611
+ x_timestamp = pd.Series(x_timestamp, name='timestamps')
612
+ if isinstance(y_timestamp, pd.DatetimeIndex):
613
+ y_timestamp = pd.Series(y_timestamp, name='timestamps')
614
+
615
+ pred_df = predictor.predict(
616
+ df=x_df,
617
+ x_timestamp=x_timestamp,
618
+ y_timestamp=y_timestamp,
619
+ pred_len=pred_len,
620
+ T=temperature,
621
+ top_p=top_p,
622
+ sample_count=sample_count
623
+ )
624
+
625
+ except Exception as e:
626
+ return jsonify({'error': f'Kronos model prediction failed: {str(e)}'}), 500
627
+ else:
628
+ return jsonify({'error': 'Kronos model not loaded, please load model first'}), 400
629
+
630
+ # Prepare actual data for comparison (if exists)
631
+ actual_data = []
632
+ actual_df = None
633
+
634
+ if start_date: # Custom time period
635
+ # Fix logic: use data within selected window
636
+ # Prediction uses first 400 data points within selected window
637
+ # Actual data should be last 120 data points within selected window
638
+ start_dt = pd.to_datetime(start_date)
639
+ # 确保时区一致性
640
+ if start_dt.tz is None:
641
+ start_dt = start_dt.tz_localize(BEIJING_TZ)
642
+
643
+ # Find data starting from start_date
644
+ mask = df['timestamps'] >= start_dt
645
+ time_range_df = df[mask]
646
+
647
+ if len(time_range_df) >= lookback + pred_len:
648
+ # Get last 120 data points within selected window as actual values
649
+ actual_df = time_range_df.iloc[lookback:lookback + pred_len]
650
+
651
+ for i, (_, row) in enumerate(actual_df.iterrows()):
652
+ actual_data.append({
653
+ 'timestamp': row['timestamps'].isoformat(),
654
+ 'open': float(row['open']),
655
+ 'high': float(row['high']),
656
+ 'low': float(row['low']),
657
+ 'close': float(row['close']),
658
+ 'volume': float(row['volume']) if 'volume' in row else 0,
659
+ 'amount': float(row['amount']) if 'amount' in row else 0
660
+ })
661
+ else: # Latest data
662
+ # Prediction uses first 400 data points
663
+ # Actual data should be 120 data points after first 400 data points
664
+ if len(df) >= lookback + pred_len:
665
+ actual_df = df.iloc[lookback:lookback + pred_len]
666
+ for i, (_, row) in enumerate(actual_df.iterrows()):
667
+ actual_data.append({
668
+ 'timestamp': row['timestamps'].isoformat(),
669
+ 'open': float(row['open']),
670
+ 'high': float(row['high']),
671
+ 'low': float(row['low']),
672
+ 'close': float(row['close']),
673
+ 'volume': float(row['volume']) if 'volume' in row else 0,
674
+ 'amount': float(row['amount']) if 'amount' in row else 0
675
+ })
676
+
677
+ # Create chart - pass historical data start position
678
+ if start_date:
679
+ # Custom time period: find starting position of historical data in original df
680
+ start_dt = pd.to_datetime(start_date)
681
+ # 确保时区一致性
682
+ if start_dt.tz is None:
683
+ start_dt = start_dt.tz_localize(BEIJING_TZ)
684
+ mask = df['timestamps'] >= start_dt
685
+ historical_start_idx = df[mask].index[0] if len(df[mask]) > 0 else 0
686
+ else:
687
+ # Latest data: start from beginning
688
+ historical_start_idx = 0
689
+
690
+ chart_json = create_prediction_chart(df, pred_df, lookback, pred_len, actual_df, historical_start_idx)
691
+
692
+ # Prepare prediction result data - fix timestamp calculation logic
693
+ if 'timestamps' in df.columns:
694
+ if start_date:
695
+ # Custom time period: use selected window data to calculate timestamps
696
+ start_dt = pd.to_datetime(start_date)
697
+ # 确保时区一致性
698
+ if start_dt.tz is None:
699
+ start_dt = start_dt.tz_localize(BEIJING_TZ)
700
+ mask = df['timestamps'] >= start_dt
701
+ time_range_df = df[mask]
702
+
703
+ if len(time_range_df) >= lookback:
704
+ # Calculate prediction timestamps starting from last time point of selected window
705
+ last_timestamp = time_range_df['timestamps'].iloc[lookback - 1]
706
+ time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0]
707
+ future_timestamps = pd.date_range(
708
+ start=last_timestamp + time_diff,
709
+ periods=pred_len,
710
+ freq=time_diff
711
+ )
712
+ else:
713
+ future_timestamps = []
714
+ else:
715
+ # Latest data: calculate from last time point of entire data file
716
+ last_timestamp = df['timestamps'].iloc[-1]
717
+ time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0]
718
+ future_timestamps = pd.date_range(
719
+ start=last_timestamp + time_diff,
720
+ periods=pred_len,
721
+ freq=time_diff
722
+ )
723
+ else:
724
+ future_timestamps = range(len(df), len(df) + pred_len)
725
+
726
+ prediction_results = []
727
+ for i, (_, row) in enumerate(pred_df.iterrows()):
728
+ prediction_results.append({
729
+ 'timestamp': future_timestamps[i].isoformat() if i < len(future_timestamps) else f"T{i}",
730
+ 'open': float(row['open']),
731
+ 'high': float(row['high']),
732
+ 'low': float(row['low']),
733
+ 'close': float(row['close']),
734
+ 'volume': float(row['volume']) if 'volume' in row else 0,
735
+ 'amount': float(row['amount']) if 'amount' in row else 0
736
+ })
737
+
738
+ # Save prediction results to file
739
+ try:
740
+ data_source = f"{symbol}_{interval}"
741
+ save_prediction_results(
742
+ file_path=data_source,
743
+ prediction_type=prediction_type,
744
+ prediction_results=prediction_results,
745
+ actual_data=actual_data,
746
+ input_data=x_df,
747
+ prediction_params={
748
+ 'symbol': symbol,
749
+ 'interval': interval,
750
+ 'limit': limit,
751
+ 'lookback': lookback,
752
+ 'pred_len': pred_len,
753
+ 'temperature': temperature,
754
+ 'top_p': top_p,
755
+ 'sample_count': sample_count,
756
+ 'start_date': start_date if start_date else 'latest'
757
+ }
758
+ )
759
+ except Exception as e:
760
+ print(f"Failed to save prediction results: {e}")
761
+
762
+ return jsonify({
763
+ 'success': True,
764
+ 'prediction_type': prediction_type,
765
+ 'chart': chart_json,
766
+ 'prediction_results': prediction_results,
767
+ 'actual_data': actual_data,
768
+ 'has_comparison': len(actual_data) > 0,
769
+ 'message': f'Prediction completed, generated {pred_len} prediction points' + (
770
+ f', including {len(actual_data)} actual data points for comparison' if len(actual_data) > 0 else '')
771
+ })
772
+
773
+ except Exception as e:
774
+ return jsonify({'error': f'Prediction failed: {str(e)}'}), 500
775
+
776
+
777
+ @app.route('/api/load-model', methods=['POST'])
778
+ def load_model():
779
+ """Load Kronos model"""
780
+ global tokenizer, model, predictor
781
+
782
+ try:
783
+ if not MODEL_AVAILABLE:
784
+ return jsonify({'error': 'Kronos model library not available'}), 400
785
+
786
+ data = request.get_json()
787
+ model_key = data.get('model_key', 'kronos-small')
788
+ device = data.get('device', 'cpu')
789
+
790
+ if model_key not in AVAILABLE_MODELS:
791
+ return jsonify({'error': f'Unsupported model: {model_key}'}), 400
792
+
793
+ model_config = AVAILABLE_MODELS[model_key]
794
+
795
+ # Load tokenizer and model
796
+ tokenizer = KronosTokenizer.from_pretrained(model_config['tokenizer_id'])
797
+ model = Kronos.from_pretrained(model_config['model_id'])
798
+
799
+ # Create predictor
800
+ predictor = KronosPredictor(model, tokenizer, device=device, max_context=model_config['context_length'])
801
+
802
+ return jsonify({
803
+ 'success': True,
804
+ 'message': f'Model loaded successfully: {model_config["name"]} ({model_config["params"]}) on {device}',
805
+ 'model_info': {
806
+ 'name': model_config['name'],
807
+ 'params': model_config['params'],
808
+ 'context_length': model_config['context_length'],
809
+ 'description': model_config['description']
810
+ }
811
+ })
812
+
813
+ except Exception as e:
814
+ return jsonify({'error': f'Model loading failed: {str(e)}'}), 500
815
+
816
+
817
+ @app.route('/api/available-models')
818
+ def get_available_models():
819
+ """Get available model list"""
820
+ return jsonify({
821
+ 'models': AVAILABLE_MODELS,
822
+ 'model_available': MODEL_AVAILABLE
823
+ })
824
+
825
+
826
+ @app.route('/api/model-status')
827
+ def get_model_status():
828
+ """Get model status"""
829
+ if MODEL_AVAILABLE:
830
+ if predictor is not None:
831
+ return jsonify({
832
+ 'available': True,
833
+ 'loaded': True,
834
+ 'message': 'Kronos model loaded and available',
835
+ 'current_model': {
836
+ 'name': predictor.model.__class__.__name__,
837
+ 'device': str(next(predictor.model.parameters()).device)
838
+ }
839
+ })
840
+ else:
841
+ return jsonify({
842
+ 'available': True,
843
+ 'loaded': False,
844
+ 'message': 'Kronos model available but not loaded'
845
+ })
846
+ else:
847
+ return jsonify({
848
+ 'available': False,
849
+ 'loaded': False,
850
+ 'message': 'Kronos model library not available, please install related dependencies'
851
+ })
852
+
853
+
854
+ if __name__ == '__main__':
855
+ print("Starting Kronos Web UI...")
856
+ print(f"Model availability: {MODEL_AVAILABLE}")
857
+ if MODEL_AVAILABLE:
858
+ print("Tip: You can load Kronos model through /api/load-model endpoint")
859
+ else:
860
+ print("Tip: Will use simulated data for demonstration")
861
+
862
+ app.run(debug=True, host='0.0.0.0', port=7070)