import datetime
import json
import os
import sys
import warnings

import pandas as pd
import plotly.graph_objects as go
import plotly.utils
import pytz
from binance.client import Client
from flask import Flask, render_template, request, jsonify
from flask_cors import CORS

try:
    from technical_indicators import add_technical_indicators, get_available_indicators
    TECHNICAL_INDICATORS_AVAILABLE = True
except ImportError as e:
    print(f"⚠️ Failed to import technical indicators module: {e}")
    TECHNICAL_INDICATORS_AVAILABLE = False

    # Fallback no-op implementations so the rest of the app keeps working
    # when the technical_indicators module is missing.
    def add_technical_indicators(df, indicators_config=None):
        return df

    def get_available_indicators():
        return {'trend': [], 'momentum': [], 'volatility': [], 'volume': []}

warnings.filterwarnings('ignore')

# All timestamps are handled in Beijing time (UTC+8).
BEIJING_TZ = pytz.timezone('Asia/Shanghai')

# Make the repository root importable so the `model` package can be found.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

try:
    from model import Kronos, KronosTokenizer, KronosPredictor
    MODEL_AVAILABLE = True
except ImportError:
    MODEL_AVAILABLE = False
    print("Warning: Kronos model cannot be imported; prediction endpoints will return an error until it is installed")

app = Flask(__name__)
CORS(app)

# Globals holding the currently loaded model; populated by /api/load-model.
tokenizer = None
model = None
predictor = None

AVAILABLE_MODELS = {
    'kronos-mini': {
        'name': 'Kronos-mini',
        'model_id': 'NeoQuasar/Kronos-mini',
        'tokenizer_id': 'NeoQuasar/Kronos-Tokenizer-2k',
        'context_length': 2048,
        'params': '4.1M',
        'description': 'Lightweight model, suitable for fast prediction'
    },
    'kronos-small': {
        'name': 'Kronos-small',
        'model_id': 'NeoQuasar/Kronos-small',
        'tokenizer_id': 'NeoQuasar/Kronos-Tokenizer-base',
        'context_length': 512,
        'params': '24.7M',
        'description': 'Small model, balanced performance and speed'
    },
    'kronos-base': {
        'name': 'Kronos-base',
        'model_id': 'NeoQuasar/Kronos-base',
        'tokenizer_id': 'NeoQuasar/Kronos-Tokenizer-base',
        'context_length': 512,
        'params': '102.3M',
        'description': 'Base model, provides better prediction quality'
    }
}
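
# Example lookup (illustrative only): resolving the checkpoint ids for one
# registry entry.
#
#     cfg = AVAILABLE_MODELS['kronos-small']
#     cfg['model_id']      -> 'NeoQuasar/Kronos-small'
#     cfg['tokenizer_id']  -> 'NeoQuasar/Kronos-Tokenizer-base'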

def get_available_symbols():
    """Return the fixed list of supported trading pairs."""
    return [
        {'symbol': 'BTCUSDT', 'baseAsset': 'BTC', 'quoteAsset': 'USDT', 'name': 'BTC/USDT'},
        {'symbol': 'ETHUSDT', 'baseAsset': 'ETH', 'quoteAsset': 'USDT', 'name': 'ETH/USDT'},
        {'symbol': 'SOLUSDT', 'baseAsset': 'SOL', 'quoteAsset': 'USDT', 'name': 'SOL/USDT'},
        {'symbol': 'BNBUSDT', 'baseAsset': 'BNB', 'quoteAsset': 'USDT', 'name': 'BNB/USDT'}
    ]


# Unauthenticated client; public market-data endpoints do not require API keys.
binance_client = Client("", "")

def get_binance_klines(symbol, interval='1h', limit=1000):
    """Fetch kline data from Binance; return (df, None) on success or (None, error_message) on failure."""
    try:
        klines = binance_client.get_klines(
            symbol=symbol,
            interval=interval,
            limit=limit
        )

        df = pd.DataFrame(klines, columns=[
            'timestamp', 'open', 'high', 'low', 'close', 'volume',
            'close_time', 'quote_asset_volume', 'number_of_trades',
            'taker_buy_base_asset_volume', 'taker_buy_quote_asset_volume', 'ignore'
        ])

        # Convert epoch milliseconds to timezone-aware Beijing time.
        df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True)
        df['timestamp'] = df['timestamp'].dt.tz_convert(BEIJING_TZ)
        df['timestamps'] = df['timestamp']

        numeric_cols = ['open', 'high', 'low', 'close', 'volume', 'quote_asset_volume']
        for col in numeric_cols:
            df[col] = pd.to_numeric(df[col], errors='coerce')

        df['amount'] = df['quote_asset_volume']

        df = df[['timestamp', 'timestamps', 'open', 'high', 'low', 'close', 'volume', 'amount']]
        df = df.sort_values('timestamp').reset_index(drop=True)

        if TECHNICAL_INDICATORS_AVAILABLE:
            try:
                df = add_technical_indicators(df)
                print(f"✅ Fetched Binance data and computed technical indicators: {symbol} {interval}, {len(df)} rows, {len(df.columns)} features")
            except Exception as e:
                print(f"⚠️ Technical indicator calculation failed, using raw data: {e}")
        else:
            print(f"✅ Fetched Binance data: {symbol} {interval}, {len(df)} rows")

        return df, None

    except Exception as e:
        # Return the error to the caller instead of falling through and
        # returning None implicitly, so `df, error = get_binance_klines(...)`
        # always unpacks cleanly.
        error_message = f"Binance API request failed: {str(e)}"
        print(f"⚠️ {error_message}")
        return None, error_message
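
# Usage sketch (illustrative only, not executed at import time):
#
#     df, error = get_binance_klines('BTCUSDT', interval='1h', limit=500)
#     if error:
#         print(f"fetch failed: {error}")
#     else:
#         print(df[['timestamps', 'open', 'close']].tail())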

def get_timeframe_options():
    """Return the available candlestick timeframe options."""
    return [
        {'value': '1m', 'label': '1 minute', 'description': '1-minute candles'},
        {'value': '5m', 'label': '5 minutes', 'description': '5-minute candles'},
        {'value': '15m', 'label': '15 minutes', 'description': '15-minute candles'},
        {'value': '30m', 'label': '30 minutes', 'description': '30-minute candles'},
        {'value': '1h', 'label': '1 hour', 'description': '1-hour candles'},
        {'value': '4h', 'label': '4 hours', 'description': '4-hour candles'},
        {'value': '1d', 'label': '1 day', 'description': 'Daily candles'},
        {'value': '1w', 'label': '1 week', 'description': 'Weekly candles'},
    ]

def save_prediction_results(file_path, prediction_type, prediction_results, actual_data, input_data, prediction_params):
    """Save prediction results to a timestamped JSON file under prediction_results/."""
    try:
        results_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prediction_results')
        os.makedirs(results_dir, exist_ok=True)

        timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
        filename = f'prediction_{timestamp}.json'
        filepath = os.path.join(results_dir, filename)

        save_data = {
            'timestamp': datetime.datetime.now().isoformat(),
            'file_path': file_path,
            'prediction_type': prediction_type,
            'prediction_params': prediction_params,
            'input_data_summary': {
                'rows': len(input_data),
                'columns': list(input_data.columns),
                'price_range': {
                    'open': {'min': float(input_data['open'].min()), 'max': float(input_data['open'].max())},
                    'high': {'min': float(input_data['high'].min()), 'max': float(input_data['high'].max())},
                    'low': {'min': float(input_data['low'].min()), 'max': float(input_data['low'].max())},
                    'close': {'min': float(input_data['close'].min()), 'max': float(input_data['close'].max())}
                },
                'last_values': {
                    'open': float(input_data['open'].iloc[-1]),
                    'high': float(input_data['high'].iloc[-1]),
                    'low': float(input_data['low'].iloc[-1]),
                    'close': float(input_data['close'].iloc[-1])
                }
            },
            'prediction_results': prediction_results,
            'actual_data': actual_data,
            'analysis': {}
        }

        # Continuity check: compare the first predicted bar against the first
        # actual bar, which cover the same timestamp.
        if prediction_results and actual_data:
            first_pred = prediction_results[0]
            first_actual = actual_data[0]

            save_data['analysis']['continuity'] = {
                'first_prediction': {
                    'open': first_pred['open'],
                    'high': first_pred['high'],
                    'low': first_pred['low'],
                    'close': first_pred['close']
                },
                'first_actual': {
                    'open': first_actual['open'],
                    'high': first_actual['high'],
                    'low': first_actual['low'],
                    'close': first_actual['close']
                },
                'gaps': {
                    'open_gap': abs(first_pred['open'] - first_actual['open']),
                    'high_gap': abs(first_pred['high'] - first_actual['high']),
                    'low_gap': abs(first_pred['low'] - first_actual['low']),
                    'close_gap': abs(first_pred['close'] - first_actual['close'])
                },
                'gap_percentages': {
                    'open_gap_pct': (abs(first_pred['open'] - first_actual['open']) / first_actual['open']) * 100,
                    'high_gap_pct': (abs(first_pred['high'] - first_actual['high']) / first_actual['high']) * 100,
                    'low_gap_pct': (abs(first_pred['low'] - first_actual['low']) / first_actual['low']) * 100,
                    'close_gap_pct': (abs(first_pred['close'] - first_actual['close']) / first_actual['close']) * 100
                }
            }

        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(save_data, f, indent=2, ensure_ascii=False)

        print(f"Prediction results saved to: {filepath}")
        return filepath

    except Exception as e:
        print(f"Failed to save prediction results: {e}")
        return None
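
# The gap percentages above follow a simple relative-error formula
# (illustrative):
#
#     gap_pct = abs(predicted - actual) / actual * 100
#
# e.g. a predicted close of 101.0 against an actual close of 100.0 gives
# close_gap_pct = 1.0.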

def create_prediction_chart(df, pred_df, lookback, pred_len, actual_df=None, historical_start_idx=0):
    """Create a Plotly candlestick chart with historical, predicted, and actual data."""
    if historical_start_idx + lookback + pred_len <= len(df):
        historical_df = df.iloc[historical_start_idx:historical_start_idx + lookback]
        prediction_range = range(historical_start_idx + lookback, historical_start_idx + lookback + pred_len)
    else:
        # Not enough rows for the full window; clamp to what is available.
        available_lookback = min(lookback, len(df) - historical_start_idx)
        available_pred_len = min(pred_len, max(0, len(df) - historical_start_idx - available_lookback))
        historical_df = df.iloc[historical_start_idx:historical_start_idx + available_lookback]
        prediction_range = range(historical_start_idx + available_lookback,
                                 historical_start_idx + available_lookback + available_pred_len)

    fig = go.Figure()

    fig.add_trace(go.Candlestick(
        x=historical_df['timestamps'] if 'timestamps' in historical_df.columns else historical_df.index,
        open=historical_df['open'],
        high=historical_df['high'],
        low=historical_df['low'],
        close=historical_df['close'],
        name=f'Historical Data ({len(historical_df)} data points)',
        increasing_line_color='#26A69A',
        decreasing_line_color='#EF5350'
    ))

    if pred_df is not None and len(pred_df) > 0:
        if 'timestamps' in df.columns and len(historical_df) > 0:
            # Extend the time axis one step past the last historical bar.
            last_timestamp = historical_df['timestamps'].iloc[-1]
            time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0] if len(df) > 1 else pd.Timedelta(hours=1)

            pred_timestamps = pd.date_range(
                start=last_timestamp + time_diff,
                periods=len(pred_df),
                freq=time_diff
            )
        else:
            pred_timestamps = range(len(historical_df), len(historical_df) + len(pred_df))

        fig.add_trace(go.Candlestick(
            x=pred_timestamps,
            open=pred_df['open'],
            high=pred_df['high'],
            low=pred_df['low'],
            close=pred_df['close'],
            name=f'Prediction Data ({len(pred_df)} data points)',
            increasing_line_color='#66BB6A',
            decreasing_line_color='#FF7043'
        ))

    if actual_df is not None and len(actual_df) > 0:
        if 'timestamps' in df.columns:
            if 'pred_timestamps' in locals():
                # Actual data covers the same window as the prediction.
                actual_timestamps = pred_timestamps
            else:
                if len(historical_df) > 0:
                    last_timestamp = historical_df['timestamps'].iloc[-1]
                    time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0] if len(df) > 1 else pd.Timedelta(
                        hours=1)
                    actual_timestamps = pd.date_range(
                        start=last_timestamp + time_diff,
                        periods=len(actual_df),
                        freq=time_diff
                    )
                else:
                    actual_timestamps = range(len(historical_df), len(historical_df) + len(actual_df))
        else:
            actual_timestamps = range(len(historical_df), len(historical_df) + len(actual_df))

        fig.add_trace(go.Candlestick(
            x=actual_timestamps,
            open=actual_df['open'],
            high=actual_df['high'],
            low=actual_df['low'],
            close=actual_df['close'],
            name=f'Actual Data ({len(actual_df)} data points)',
            increasing_line_color='#FF9800',
            decreasing_line_color='#F44336'
        ))

    fig.update_layout(
        title=f'Kronos Financial Prediction Results - {len(historical_df)} Historical Points + {pred_len} Prediction Points' + (
            f' vs {len(actual_df)} Actual Points' if actual_df is not None and len(actual_df) > 0 else ''),
        xaxis_title='Time',
        yaxis_title='Price',
        template='plotly_white',
        height=600,
        showlegend=True
    )

    # Pin the x-axis range to the union of all plotted timestamps.
    if 'timestamps' in historical_df.columns:
        all_timestamps = []
        if len(historical_df) > 0:
            all_timestamps.extend(historical_df['timestamps'])
        if 'pred_timestamps' in locals():
            all_timestamps.extend(pred_timestamps)
        if 'actual_timestamps' in locals():
            all_timestamps.extend(actual_timestamps)

        if all_timestamps:
            all_timestamps = sorted(all_timestamps)
            fig.update_xaxes(
                range=[all_timestamps[0], all_timestamps[-1]],
                rangeslider_visible=False,
                type='date'
            )

    return json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder)
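
# Timestamp extension sketch (illustrative): predicted bars continue the
# historical time axis at the same spacing.
#
#     step = df['timestamps'].iloc[1] - df['timestamps'].iloc[0]
#     pd.date_range(start=last_bar + step, periods=pred_len, freq=step)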

@app.route('/')
def index():
    """Home page."""
    return render_template('index.html')


@app.route('/api/symbols')
def get_symbols():
    """Return the list of available trading pairs."""
    symbols = get_available_symbols()
    return jsonify(symbols)


@app.route('/api/timeframes')
def get_timeframes():
    """Return the list of available timeframes."""
    timeframes = get_timeframe_options()
    return jsonify(timeframes)


@app.route('/api/technical-indicators')
def get_technical_indicators():
    """Return the list of available technical indicators."""
    indicators = get_available_indicators()
    return jsonify(indicators)

@app.route('/api/load-data', methods=['POST'])
def load_data():
    """Load kline data from Binance and return a summary."""
    try:
        data = request.get_json()
        symbol = data.get('symbol')
        interval = data.get('interval', '1h')
        limit = int(data.get('limit', 1000))

        if not symbol:
            return jsonify({'error': 'Trading pair must not be empty'}), 400

        df, error = get_binance_klines(symbol, interval, limit)
        if error:
            return jsonify({'error': error}), 400

        def detect_timeframe(df):
            """Estimate the candle interval from the first few timestamp gaps."""
            if len(df) < 2:
                return "Unknown"

            time_diffs = []
            for i in range(1, min(10, len(df))):
                diff = df['timestamps'].iloc[i] - df['timestamps'].iloc[i - 1]
                time_diffs.append(diff)

            if not time_diffs:
                return "Unknown"

            avg_diff = sum(time_diffs, pd.Timedelta(0)) / len(time_diffs)

            if avg_diff < pd.Timedelta(minutes=1):
                return f"{avg_diff.total_seconds():.0f} seconds"
            elif avg_diff < pd.Timedelta(hours=1):
                return f"{avg_diff.total_seconds() / 60:.0f} minutes"
            elif avg_diff < pd.Timedelta(days=1):
                return f"{avg_diff.total_seconds() / 3600:.0f} hours"
            else:
                return f"{avg_diff.days} days"

        def format_beijing_time(timestamp):
            """Format a Beijing-time (UTC+8) timestamp as yyyy-MM-dd HH:mm:ss."""
            if pd.isna(timestamp):
                return 'N/A'

            if timestamp.tz is None:
                timestamp = timestamp.tz_localize(BEIJING_TZ)
            elif timestamp.tz != BEIJING_TZ:
                timestamp = timestamp.tz_convert(BEIJING_TZ)
            return timestamp.strftime('%Y-%m-%d %H:%M:%S')

        data_info = {
            'rows': len(df),
            'columns': list(df.columns),
            'start_date': format_beijing_time(df['timestamps'].min()) if 'timestamps' in df.columns else 'N/A',
            'end_date': format_beijing_time(df['timestamps'].max()) if 'timestamps' in df.columns else 'N/A',
            'price_range': {
                'min': float(df[['open', 'high', 'low', 'close']].min().min()),
                'max': float(df[['open', 'high', 'low', 'close']].max().max())
            },
            'prediction_columns': ['open', 'high', 'low', 'close'] + (['volume'] if 'volume' in df.columns else []),
            'timeframe': detect_timeframe(df)
        }

        return jsonify({
            'success': True,
            'data_info': data_info,
            'message': f'Successfully loaded data, total {len(df)} rows'
        })

    except Exception as e:
        return jsonify({'error': f'Failed to load data: {str(e)}'}), 500
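
# Example request (illustrative; adjust host/port to your deployment):
#
#     curl -X POST http://localhost:7860/api/load-data \
#          -H 'Content-Type: application/json' \
#          -d '{"symbol": "BTCUSDT", "interval": "1h", "limit": 1000}'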

@app.route('/api/predict', methods=['POST'])
def predict():
    """Run a Kronos prediction over the requested window."""
    try:
        data = request.get_json()
        symbol = data.get('symbol')
        interval = data.get('interval', '1h')
        limit = int(data.get('limit', 1000))
        lookback = int(data.get('lookback', 400))
        pred_len = int(data.get('pred_len', 120))

        # Sampling parameters.
        temperature = float(data.get('temperature', 1.0))
        top_p = float(data.get('top_p', 0.9))
        sample_count = int(data.get('sample_count', 1))

        if not symbol:
            return jsonify({'error': 'Trading pair must not be empty'}), 400

        df, error = get_binance_klines(symbol, interval, limit)
        if error:
            return jsonify({'error': error}), 400

        if len(df) < lookback:
            return jsonify({'error': f'Insufficient data length, need at least {lookback} rows'}), 400

        if MODEL_AVAILABLE:
            try:
                required_cols = ['open', 'high', 'low', 'close']
                if 'volume' in df.columns:
                    required_cols.append('volume')
                if 'amount' in df.columns:
                    required_cols.append('amount')

                print(f"🔍 Using features for prediction: {required_cols}")
                print(f"   Available columns in data: {list(df.columns)}")
                print(f"   Data shape: {df.shape}")

                missing_cols = [col for col in required_cols if col not in df.columns]
                if missing_cols:
                    return jsonify({'error': f'Missing required columns: {missing_cols}'}), 400

                start_date = data.get('start_date')

                if start_date:
                    start_dt = pd.to_datetime(start_date)
                    # Localize to Beijing time so the comparison against the
                    # tz-aware 'timestamps' column does not raise.
                    if start_dt.tz is None:
                        start_dt = start_dt.tz_localize(BEIJING_TZ)

                    mask = df['timestamps'] >= start_dt
                    time_range_df = df[mask]

                    if len(time_range_df) < lookback + pred_len:
                        return jsonify({
                            'error': f'Insufficient data from start time {start_dt.strftime("%Y-%m-%d %H:%M")}, need at least {lookback + pred_len} data points, currently only {len(time_range_df)} available'}), 400

                    x_df = time_range_df.iloc[:lookback][required_cols]
                    x_timestamp = time_range_df.iloc[:lookback]['timestamps']

                    print(f"🔍 Custom time period - x_df shape: {x_df.shape}")
                    print(f"   x_timestamp length: {len(x_timestamp)}")
                    print(f"   x_df columns: {list(x_df.columns)}")
                    print(f"   x_df sample:\n{x_df.head()}")

                    if len(time_range_df) >= 2:
                        time_diff = time_range_df['timestamps'].iloc[1] - time_range_df['timestamps'].iloc[0]
                    else:
                        time_diff = pd.Timedelta(hours=1)

                    last_timestamp = time_range_df['timestamps'].iloc[lookback - 1]
                    y_timestamp = pd.date_range(
                        start=last_timestamp + time_diff,
                        periods=pred_len,
                        freq=time_diff
                    )

                    start_timestamp = time_range_df['timestamps'].iloc[0]
                    end_timestamp = y_timestamp[-1]
                    time_span = end_timestamp - start_timestamp

                    prediction_type = f"Kronos model prediction (within selected window: first {lookback} data points for prediction, {pred_len} future predictions, time span: {time_span})"
                else:
                    x_df = df.iloc[:lookback][required_cols]
                    x_timestamp = df.iloc[:lookback]['timestamps']

                    if len(df) >= 2:
                        time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0]
                    else:
                        time_diff = pd.Timedelta(hours=1)

                    last_timestamp = df['timestamps'].iloc[lookback - 1]
                    y_timestamp = pd.date_range(
                        start=last_timestamp + time_diff,
                        periods=pred_len,
                        freq=time_diff
                    )
                    prediction_type = "Kronos model prediction (latest data)"

                    print(f"🔍 Latest data - x_df shape: {x_df.shape}")
                    print(f"   x_timestamp length: {len(x_timestamp)}")
                    print(f"   y_timestamp length: {len(y_timestamp)}")
                    print(f"   x_df columns: {list(x_df.columns)}")
                    print(f"   x_df sample:\n{x_df.head()}")

                if x_df.empty:
                    return jsonify({'error': 'Input data is empty after processing'}), 400

                if len(x_timestamp) == 0:
                    return jsonify({'error': 'Input timestamps are empty'}), 400

                if len(y_timestamp) == 0:
                    return jsonify({'error': 'Target timestamps are empty'}), 400

                # KronosPredictor expects Series, not DatetimeIndex.
                if isinstance(x_timestamp, pd.DatetimeIndex):
                    x_timestamp = pd.Series(x_timestamp, name='timestamps')
                if isinstance(y_timestamp, pd.DatetimeIndex):
                    y_timestamp = pd.Series(y_timestamp, name='timestamps')

                pred_df = predictor.predict(
                    df=x_df,
                    x_timestamp=x_timestamp,
                    y_timestamp=y_timestamp,
                    pred_len=pred_len,
                    T=temperature,
                    top_p=top_p,
                    sample_count=sample_count
                )

            except Exception as e:
                return jsonify({'error': f'Kronos model prediction failed: {str(e)}'}), 500
        else:
            return jsonify({'error': 'Kronos model not loaded, please load model first'}), 400

        # Collect actual data for the prediction window, when available.
        actual_data = []
        actual_df = None

        if start_date:
            start_dt = pd.to_datetime(start_date)
            if start_dt.tz is None:
                start_dt = start_dt.tz_localize(BEIJING_TZ)

            mask = df['timestamps'] >= start_dt
            time_range_df = df[mask]

            if len(time_range_df) >= lookback + pred_len:
                actual_df = time_range_df.iloc[lookback:lookback + pred_len]

                for _, row in actual_df.iterrows():
                    actual_data.append({
                        'timestamp': row['timestamps'].isoformat(),
                        'open': float(row['open']),
                        'high': float(row['high']),
                        'low': float(row['low']),
                        'close': float(row['close']),
                        'volume': float(row['volume']) if 'volume' in row else 0,
                        'amount': float(row['amount']) if 'amount' in row else 0
                    })
        else:
            if len(df) >= lookback + pred_len:
                actual_df = df.iloc[lookback:lookback + pred_len]
                for _, row in actual_df.iterrows():
                    actual_data.append({
                        'timestamp': row['timestamps'].isoformat(),
                        'open': float(row['open']),
                        'high': float(row['high']),
                        'low': float(row['low']),
                        'close': float(row['close']),
                        'volume': float(row['volume']) if 'volume' in row else 0,
                        'amount': float(row['amount']) if 'amount' in row else 0
                    })

        # Locate the index where the historical window starts for charting.
        if start_date:
            start_dt = pd.to_datetime(start_date)
            if start_dt.tz is None:
                start_dt = start_dt.tz_localize(BEIJING_TZ)
            mask = df['timestamps'] >= start_dt
            historical_start_idx = df[mask].index[0] if len(df[mask]) > 0 else 0
        else:
            historical_start_idx = 0

        chart_json = create_prediction_chart(df, pred_df, lookback, pred_len, actual_df, historical_start_idx)

        # Build timestamps for the predicted bars.
        if 'timestamps' in df.columns:
            if start_date:
                start_dt = pd.to_datetime(start_date)
                if start_dt.tz is None:
                    start_dt = start_dt.tz_localize(BEIJING_TZ)
                mask = df['timestamps'] >= start_dt
                time_range_df = df[mask]

                if len(time_range_df) >= lookback:
                    last_timestamp = time_range_df['timestamps'].iloc[lookback - 1]
                    time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0]
                    future_timestamps = pd.date_range(
                        start=last_timestamp + time_diff,
                        periods=pred_len,
                        freq=time_diff
                    )
                else:
                    future_timestamps = []
            else:
                last_timestamp = df['timestamps'].iloc[-1]
                time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0]
                future_timestamps = pd.date_range(
                    start=last_timestamp + time_diff,
                    periods=pred_len,
                    freq=time_diff
                )
        else:
            future_timestamps = range(len(df), len(df) + pred_len)

        prediction_results = []
        for i, (_, row) in enumerate(pred_df.iterrows()):
            # future_timestamps may be a plain range when no timestamp column
            # exists, so only call isoformat() on datetime-like values.
            if i < len(future_timestamps):
                ts = future_timestamps[i]
                ts_str = ts.isoformat() if hasattr(ts, 'isoformat') else str(ts)
            else:
                ts_str = f"T{i}"
            prediction_results.append({
                'timestamp': ts_str,
                'open': float(row['open']),
                'high': float(row['high']),
                'low': float(row['low']),
                'close': float(row['close']),
                'volume': float(row['volume']) if 'volume' in row else 0,
                'amount': float(row['amount']) if 'amount' in row else 0
            })

        # Persist the run for later inspection; failures here are non-fatal.
        try:
            data_source = f"{symbol}_{interval}"
            save_prediction_results(
                file_path=data_source,
                prediction_type=prediction_type,
                prediction_results=prediction_results,
                actual_data=actual_data,
                input_data=x_df,
                prediction_params={
                    'symbol': symbol,
                    'interval': interval,
                    'limit': limit,
                    'lookback': lookback,
                    'pred_len': pred_len,
                    'temperature': temperature,
                    'top_p': top_p,
                    'sample_count': sample_count,
                    'start_date': start_date if start_date else 'latest'
                }
            )
        except Exception as e:
            print(f"Failed to save prediction results: {e}")

        return jsonify({
            'success': True,
            'prediction_type': prediction_type,
            'chart': chart_json,
            'prediction_results': prediction_results,
            'actual_data': actual_data,
            'has_comparison': len(actual_data) > 0,
            'message': f'Prediction completed, generated {pred_len} prediction points' + (
                f', including {len(actual_data)} actual data points for comparison' if len(actual_data) > 0 else '')
        })

    except Exception as e:
        return jsonify({'error': f'Prediction failed: {str(e)}'}), 500
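
# Example request (illustrative; a model must be loaded first via
# /api/load-model):
#
#     curl -X POST http://localhost:7860/api/predict \
#          -H 'Content-Type: application/json' \
#          -d '{"symbol": "BTCUSDT", "interval": "1h", "lookback": 400,
#               "pred_len": 120, "temperature": 1.0, "top_p": 0.9,
#               "sample_count": 1}'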

@app.route('/api/load-model', methods=['POST'])
def load_model():
    """Load Kronos model"""
    global tokenizer, model, predictor

    try:
        if not MODEL_AVAILABLE:
            return jsonify({'error': 'Kronos model library not available'}), 400

        data = request.get_json()
        model_key = data.get('model_key', 'kronos-small')
        device = data.get('device', 'cpu')

        if model_key not in AVAILABLE_MODELS:
            return jsonify({'error': f'Unsupported model: {model_key}'}), 400

        model_config = AVAILABLE_MODELS[model_key]

        tokenizer = KronosTokenizer.from_pretrained(model_config['tokenizer_id'])
        model = Kronos.from_pretrained(model_config['model_id'])

        predictor = KronosPredictor(model, tokenizer, device=device, max_context=model_config['context_length'])

        return jsonify({
            'success': True,
            'message': f'Model loaded successfully: {model_config["name"]} ({model_config["params"]}) on {device}',
            'model_info': {
                'name': model_config['name'],
                'params': model_config['params'],
                'context_length': model_config['context_length'],
                'description': model_config['description']
            }
        })

    except Exception as e:
        return jsonify({'error': f'Model loading failed: {str(e)}'}), 500
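
# Example request (illustrative):
#
#     curl -X POST http://localhost:7860/api/load-model \
#          -H 'Content-Type: application/json' \
#          -d '{"model_key": "kronos-small", "device": "cpu"}'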

@app.route('/api/available-models')
def get_available_models():
    """Get available model list"""
    return jsonify({
        'models': AVAILABLE_MODELS,
        'model_available': MODEL_AVAILABLE
    })

@app.route('/api/model-status')
def get_model_status():
    """Get model status"""
    if MODEL_AVAILABLE:
        if predictor is not None:
            return jsonify({
                'available': True,
                'loaded': True,
                'message': 'Kronos model loaded and available',
                'current_model': {
                    'name': predictor.model.__class__.__name__,
                    'device': str(next(predictor.model.parameters()).device)
                }
            })
        else:
            return jsonify({
                'available': True,
                'loaded': False,
                'message': 'Kronos model available but not loaded'
            })
    else:
        return jsonify({
            'available': False,
            'loaded': False,
            'message': 'Kronos model library not available, please install related dependencies'
        })

if __name__ == '__main__':
    print("Starting Kronos Web UI...")
    print(f"Model availability: {MODEL_AVAILABLE}")
    if MODEL_AVAILABLE:
        print("Tip: You can load a Kronos model through the /api/load-model endpoint")
    else:
        print("Tip: prediction endpoints will return an error until the Kronos model package is installed")

    app.run(debug=True, host='0.0.0.0', port=7860)