# webui/app.py
import datetime
import json
import os
import sys
import warnings
import pandas as pd
import plotly.graph_objects as go
import plotly.utils
import pytz
from binance.client import Client
from flask import Flask, render_template, request, jsonify
from flask_cors import CORS
try:
from technical_indicators import add_technical_indicators, get_available_indicators
    TECHNICAL_INDICATORS_AVAILABLE = True
except ImportError as e:
print(f"⚠️ 技术指标模块导入失败: {e}")
TECHNICAL_INDICATORS_AVAILABLE = False
# 定义空的替代函数
def add_technical_indicators(df, indicators_config=None):
return df
def get_available_indicators():
return {'trend': [], 'momentum': [], 'volatility': [], 'volume': []}
warnings.filterwarnings('ignore')
# Use the UTC+8 (Asia/Shanghai) timezone for all displayed timestamps
BEIJING_TZ = pytz.timezone('Asia/Shanghai')
# Add project root directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
try:
from model import Kronos, KronosTokenizer, KronosPredictor
MODEL_AVAILABLE = True
except ImportError:
MODEL_AVAILABLE = False
print("Warning: Kronos model cannot be imported, will use simulated data for demonstration")
app = Flask(__name__)
CORS(app)
# Global variables to store models
tokenizer = None
model = None
predictor = None
# Available model configurations
AVAILABLE_MODELS = {
'kronos-mini': {
'name': 'Kronos-mini',
'model_id': 'NeoQuasar/Kronos-mini',
'tokenizer_id': 'NeoQuasar/Kronos-Tokenizer-2k',
'context_length': 2048,
'params': '4.1M',
'description': 'Lightweight model, suitable for fast prediction'
},
'kronos-small': {
'name': 'Kronos-small',
'model_id': 'NeoQuasar/Kronos-small',
'tokenizer_id': 'NeoQuasar/Kronos-Tokenizer-base',
'context_length': 512,
'params': '24.7M',
'description': 'Small model, balanced performance and speed'
},
'kronos-base': {
'name': 'Kronos-base',
'model_id': 'NeoQuasar/Kronos-base',
'tokenizer_id': 'NeoQuasar/Kronos-Tokenizer-base',
'context_length': 512,
'params': '102.3M',
'description': 'Base model, provides better prediction quality'
}
}
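# Note: 'context_length' is forwarded to KronosPredictor as max_context in
# /api/load-model; how the predictor handles a longer lookback (truncation vs.
# rejection) is up to the Kronos library.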
def get_available_symbols():
"""获取固定的交易对列表"""
# 返回固定的主要交易对,不再从币安API获取
return [
{'symbol': 'BTCUSDT', 'baseAsset': 'BTC', 'quoteAsset': 'USDT', 'name': 'BTC/USDT'},
{'symbol': 'ETHUSDT', 'baseAsset': 'ETH', 'quoteAsset': 'USDT', 'name': 'ETH/USDT'},
{'symbol': 'SOLUSDT', 'baseAsset': 'SOL', 'quoteAsset': 'USDT', 'name': 'SOL/USDT'},
{'symbol': 'BNBUSDT', 'baseAsset': 'BNB', 'quoteAsset': 'USDT', 'name': 'BNB/USDT'}
]
# Initialize the Binance client (public market-data API; no API key required)
binance_client = Client("", "")
def get_binance_klines(symbol, interval='1h', limit=1000):
"""从币安获取K线数据,如果失败则生成模拟数据"""
try:
# 尝试获取真实的币安数据
klines = binance_client.get_klines(
symbol=symbol,
interval=interval,
limit=limit
)
        # Convert to a DataFrame
df = pd.DataFrame(klines, columns=[
'timestamp', 'open', 'high', 'low', 'close', 'volume',
'close_time', 'quote_asset_volume', 'number_of_trades',
'taker_buy_base_asset_volume', 'taker_buy_quote_asset_volume', 'ignore'
])
        # Convert timestamps to timezone-aware datetimes in UTC+8
        df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True)
        df['timestamp'] = df['timestamp'].dt.tz_convert(BEIJING_TZ)
        df['timestamps'] = df['timestamp']  # keep a 'timestamps' alias for downstream code
        # Convert numeric columns
        numeric_cols = ['open', 'high', 'low', 'close', 'volume', 'quote_asset_volume']
        for col in numeric_cols:
            df[col] = pd.to_numeric(df[col], errors='coerce')
        # Add an 'amount' column (quote-asset turnover)
        df['amount'] = df['quote_asset_volume']
        # Keep only the columns we need
        df = df[['timestamp', 'timestamps', 'open', 'high', 'low', 'close', 'volume', 'amount']]
        # Sort chronologically
        df = df.sort_values('timestamp').reset_index(drop=True)
        # Add technical indicators (if available)
if TECHNICAL_INDICATORS_AVAILABLE:
try:
df = add_technical_indicators(df)
print(f"✅ 成功获取币安真实数据并计算技术指标: {symbol} {interval} {len(df)}条,{len(df.columns)}个特征")
except Exception as e:
print(f"⚠️ 技术指标计算失败,使用原始数据: {e}")
else:
print(f"✅ 成功获取币安真实数据: {symbol} {interval} {len(df)}条")
return df, None
except Exception as e:
print(f"⚠️ 币安API连接失败,使用模拟数据: {str(e)}")
def get_timeframe_options():
"""获取可用的时间周期选项"""
return [
{'value': '1m', 'label': '1分钟', 'description': '1分钟K线'},
{'value': '5m', 'label': '5分钟', 'description': '5分钟K线'},
{'value': '15m', 'label': '15分钟', 'description': '15分钟K线'},
{'value': '30m', 'label': '30分钟', 'description': '30分钟K线'},
{'value': '1h', 'label': '1小时', 'description': '1小时K线'},
{'value': '4h', 'label': '4小时', 'description': '4小时K线'},
{'value': '1d', 'label': '1天', 'description': '日K线'},
{'value': '1w', 'label': '1周', 'description': '周K线'},
]
def save_prediction_results(file_path, prediction_type, prediction_results, actual_data, input_data, prediction_params):
"""Save prediction results to file"""
try:
# Create prediction results directory
results_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prediction_results')
os.makedirs(results_dir, exist_ok=True)
# Generate filename
timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
filename = f'prediction_{timestamp}.json'
filepath = os.path.join(results_dir, filename)
# Prepare data for saving
save_data = {
'timestamp': datetime.datetime.now().isoformat(),
'file_path': file_path,
'prediction_type': prediction_type,
'prediction_params': prediction_params,
'input_data_summary': {
'rows': len(input_data),
'columns': list(input_data.columns),
'price_range': {
'open': {'min': float(input_data['open'].min()), 'max': float(input_data['open'].max())},
'high': {'min': float(input_data['high'].min()), 'max': float(input_data['high'].max())},
'low': {'min': float(input_data['low'].min()), 'max': float(input_data['low'].max())},
'close': {'min': float(input_data['close'].min()), 'max': float(input_data['close'].max())}
},
'last_values': {
'open': float(input_data['open'].iloc[-1]),
'high': float(input_data['high'].iloc[-1]),
'low': float(input_data['low'].iloc[-1]),
'close': float(input_data['close'].iloc[-1])
}
},
'prediction_results': prediction_results,
'actual_data': actual_data,
'analysis': {}
}
# If actual data exists, perform comparison analysis
if actual_data and len(actual_data) > 0:
            # Continuity analysis: compare the first predicted point with the first actual point
            if len(prediction_results) > 0 and len(actual_data) > 0:
                first_pred = prediction_results[0]  # First prediction point
                first_actual = actual_data[0]  # First actual point
                save_data['analysis']['continuity'] = {
                    'first_prediction': {
                        'open': first_pred['open'],
                        'high': first_pred['high'],
                        'low': first_pred['low'],
                        'close': first_pred['close']
                    },
                    'first_actual': {
                        'open': first_actual['open'],
                        'high': first_actual['high'],
                        'low': first_actual['low'],
                        'close': first_actual['close']
                    },
                    'gaps': {
                        'open_gap': abs(first_pred['open'] - first_actual['open']),
                        'high_gap': abs(first_pred['high'] - first_actual['high']),
                        'low_gap': abs(first_pred['low'] - first_actual['low']),
                        'close_gap': abs(first_pred['close'] - first_actual['close'])
                    },
                    'gap_percentages': {
                        'open_gap_pct': (abs(first_pred['open'] - first_actual['open']) / first_actual['open']) * 100,
                        'high_gap_pct': (abs(first_pred['high'] - first_actual['high']) / first_actual['high']) * 100,
                        'low_gap_pct': (abs(first_pred['low'] - first_actual['low']) / first_actual['low']) * 100,
                        'close_gap_pct': (abs(first_pred['close'] - first_actual['close']) / first_actual['close']) * 100
                    }
                }
# Save to file
with open(filepath, 'w', encoding='utf-8') as f:
json.dump(save_data, f, indent=2, ensure_ascii=False)
print(f"Prediction results saved to: {filepath}")
return filepath
except Exception as e:
print(f"Failed to save prediction results: {e}")
return None
def create_prediction_chart(df, pred_df, lookback, pred_len, actual_df=None, historical_start_idx=0):
"""Create prediction chart"""
# Use specified historical data start position, not always from the beginning of df
if historical_start_idx + lookback + pred_len <= len(df):
# Display lookback historical points + pred_len prediction points starting from specified position
historical_df = df.iloc[historical_start_idx:historical_start_idx + lookback]
prediction_range = range(historical_start_idx + lookback, historical_start_idx + lookback + pred_len)
else:
# If data is insufficient, adjust to maximum available range
available_lookback = min(lookback, len(df) - historical_start_idx)
available_pred_len = min(pred_len, max(0, len(df) - historical_start_idx - available_lookback))
historical_df = df.iloc[historical_start_idx:historical_start_idx + available_lookback]
prediction_range = range(historical_start_idx + available_lookback,
historical_start_idx + available_lookback + available_pred_len)
# Create chart
fig = go.Figure()
# Add historical data (candlestick chart)
fig.add_trace(go.Candlestick(
x=historical_df['timestamps'] if 'timestamps' in historical_df.columns else historical_df.index,
open=historical_df['open'],
high=historical_df['high'],
low=historical_df['low'],
close=historical_df['close'],
        name=f'Historical Data ({len(historical_df)} points)',
increasing_line_color='#26A69A',
decreasing_line_color='#EF5350'
))
# Add prediction data (candlestick chart)
if pred_df is not None and len(pred_df) > 0:
# Calculate prediction data timestamps - ensure continuity with historical data
if 'timestamps' in df.columns and len(historical_df) > 0:
# Start from the last timestamp of historical data, create prediction timestamps with the same time interval
last_timestamp = historical_df['timestamps'].iloc[-1]
time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0] if len(df) > 1 else pd.Timedelta(hours=1)
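            # Assumes evenly spaced candles: the spacing between the first two rows
            # is reused as the bar interval for all generated future timestamps.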
pred_timestamps = pd.date_range(
start=last_timestamp + time_diff,
periods=len(pred_df),
freq=time_diff
)
else:
# If no timestamps, use index
pred_timestamps = range(len(historical_df), len(historical_df) + len(pred_df))
fig.add_trace(go.Candlestick(
x=pred_timestamps,
open=pred_df['open'],
high=pred_df['high'],
low=pred_df['low'],
close=pred_df['close'],
            name=f'Prediction Data ({len(pred_df)} points)',
increasing_line_color='#66BB6A',
decreasing_line_color='#FF7043'
))
# Add actual data for comparison (if exists)
if actual_df is not None and len(actual_df) > 0:
# Actual data should be in the same time period as prediction data
if 'timestamps' in df.columns:
# Actual data should use the same timestamps as prediction data to ensure time alignment
if 'pred_timestamps' in locals():
actual_timestamps = pred_timestamps
else:
# If no prediction timestamps, calculate from the last timestamp of historical data
if len(historical_df) > 0:
last_timestamp = historical_df['timestamps'].iloc[-1]
time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0] if len(df) > 1 else pd.Timedelta(
hours=1)
actual_timestamps = pd.date_range(
start=last_timestamp + time_diff,
periods=len(actual_df),
freq=time_diff
)
else:
actual_timestamps = range(len(historical_df), len(historical_df) + len(actual_df))
else:
actual_timestamps = range(len(historical_df), len(historical_df) + len(actual_df))
fig.add_trace(go.Candlestick(
x=actual_timestamps,
open=actual_df['open'],
high=actual_df['high'],
low=actual_df['low'],
close=actual_df['close'],
            name=f'Actual Data ({len(actual_df)} points)',
increasing_line_color='#FF9800',
decreasing_line_color='#F44336'
))
# Update layout
fig.update_layout(
        title=f'Kronos Financial Prediction Results ({lookback} historical + {pred_len} predicted points)',
xaxis_title='Time',
yaxis_title='Price',
template='plotly_white',
height=600,
showlegend=True
)
# Ensure x-axis time continuity
if 'timestamps' in historical_df.columns:
# Get all timestamps and sort them
all_timestamps = []
if len(historical_df) > 0:
all_timestamps.extend(historical_df['timestamps'])
if 'pred_timestamps' in locals():
all_timestamps.extend(pred_timestamps)
if 'actual_timestamps' in locals():
all_timestamps.extend(actual_timestamps)
if all_timestamps:
all_timestamps = sorted(all_timestamps)
fig.update_xaxes(
range=[all_timestamps[0], all_timestamps[-1]],
rangeslider_visible=False,
type='date'
)
return json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder)
@app.route('/')
def index():
"""Home page"""
return render_template('index.html')
@app.route('/api/symbols')
def get_symbols():
"""获取可用的交易对列表"""
symbols = get_available_symbols()
return jsonify(symbols)
@app.route('/api/timeframes')
def get_timeframes():
"""获取可用的时间周期列表"""
timeframes = get_timeframe_options()
return jsonify(timeframes)
@app.route('/api/technical-indicators')
def get_technical_indicators():
"""获取可用的技术指标列表"""
indicators = get_available_indicators()
return jsonify(indicators)
@app.route('/api/load-data', methods=['POST'])
def load_data():
"""加载币安数据"""
try:
data = request.get_json()
symbol = data.get('symbol')
interval = data.get('interval', '1h')
limit = int(data.get('limit', 1000))
if not symbol:
        return jsonify({'error': 'Trading pair must not be empty'}), 400
df, error = get_binance_klines(symbol, interval, limit)
if error:
return jsonify({'error': error}), 400
# Detect data time frequency
def detect_timeframe(df):
if len(df) < 2:
return "Unknown"
time_diffs = []
            for i in range(1, min(10, len(df))):  # Sample up to the first 9 consecutive intervals
diff = df['timestamps'].iloc[i] - df['timestamps'].iloc[i - 1]
time_diffs.append(diff)
if not time_diffs:
return "Unknown"
# Calculate average time difference
avg_diff = sum(time_diffs, pd.Timedelta(0)) / len(time_diffs)
# Convert to readable format
if avg_diff < pd.Timedelta(minutes=1):
return f"{avg_diff.total_seconds():.0f} seconds"
elif avg_diff < pd.Timedelta(hours=1):
return f"{avg_diff.total_seconds() / 60:.0f} minutes"
elif avg_diff < pd.Timedelta(days=1):
return f"{avg_diff.total_seconds() / 3600:.0f} hours"
else:
return f"{avg_diff.days} days"
# Return data information with formatted time
def format_beijing_time(timestamp):
"""格式化东八区时间为 yyyy-MM-dd HH:mm:ss"""
if pd.isna(timestamp):
return 'N/A'
            # Ensure the timestamp carries timezone info
if timestamp.tz is None:
timestamp = timestamp.tz_localize(BEIJING_TZ)
elif timestamp.tz != BEIJING_TZ:
timestamp = timestamp.tz_convert(BEIJING_TZ)
return timestamp.strftime('%Y-%m-%d %H:%M:%S')
data_info = {
'rows': len(df),
'columns': list(df.columns),
'start_date': format_beijing_time(df['timestamps'].min()) if 'timestamps' in df.columns else 'N/A',
'end_date': format_beijing_time(df['timestamps'].max()) if 'timestamps' in df.columns else 'N/A',
'price_range': {
'min': float(df[['open', 'high', 'low', 'close']].min().min()),
'max': float(df[['open', 'high', 'low', 'close']].max().max())
},
'prediction_columns': ['open', 'high', 'low', 'close'] + (['volume'] if 'volume' in df.columns else []),
'timeframe': detect_timeframe(df)
}
return jsonify({
'success': True,
'data_info': data_info,
'message': f'Successfully loaded data, total {len(df)} rows'
})
except Exception as e:
return jsonify({'error': f'Failed to load data: {str(e)}'}), 500
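# Example request (assuming a local deployment on the default port):
#   curl -X POST http://localhost:7860/api/load-data \
#        -H 'Content-Type: application/json' \
#        -d '{"symbol": "BTCUSDT", "interval": "1h", "limit": 1000}'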
@app.route('/api/predict', methods=['POST'])
def predict():
"""Perform prediction"""
try:
data = request.get_json()
symbol = data.get('symbol')
interval = data.get('interval', '1h')
limit = int(data.get('limit', 1000))
lookback = int(data.get('lookback', 400))
pred_len = int(data.get('pred_len', 120))
# Get prediction quality parameters
temperature = float(data.get('temperature', 1.0))
top_p = float(data.get('top_p', 0.9))
sample_count = int(data.get('sample_count', 1))
if not symbol:
        return jsonify({'error': 'Trading pair must not be empty'}), 400
# Load data from Binance
df, error = get_binance_klines(symbol, interval, limit)
if error:
return jsonify({'error': error}), 400
if len(df) < lookback:
return jsonify({'error': f'Insufficient data length, need at least {lookback} rows'}), 400
# Perform prediction
if MODEL_AVAILABLE:
try:
# Use real Kronos model
# Only use necessary columns: OHLCVA (6 features required by Kronos model)
required_cols = ['open', 'high', 'low', 'close']
if 'volume' in df.columns:
required_cols.append('volume')
if 'amount' in df.columns:
required_cols.append('amount')
print(f"🔍 Using features for prediction: {required_cols}")
print(f" Available columns in data: {list(df.columns)}")
print(f" Data shape: {df.shape}")
# Check if required columns exist
missing_cols = [col for col in required_cols if col not in df.columns]
if missing_cols:
return jsonify({'error': f'Missing required columns: {missing_cols}'}), 400
# Process time period selection
start_date = data.get('start_date')
if start_date:
# Custom time period - fix logic: use data within selected window
start_dt = pd.to_datetime(start_date)
# Find data after start time
mask = df['timestamps'] >= start_dt
time_range_df = df[mask]
# Ensure sufficient data: lookback + pred_len
if len(time_range_df) < lookback + pred_len:
return jsonify({
'error': f'Insufficient data from start time {start_dt.strftime("%Y-%m-%d %H:%M")}, need at least {lookback + pred_len} data points, currently only {len(time_range_df)} available'}), 400
# Use first lookback data points within selected window for prediction
x_df = time_range_df.iloc[:lookback][required_cols]
x_timestamp = time_range_df.iloc[:lookback]['timestamps']
print(f"🔍 Custom time period - x_df shape: {x_df.shape}")
print(f" x_timestamp length: {len(x_timestamp)}")
print(f" x_df columns: {list(x_df.columns)}")
print(f" x_df sample:\n{x_df.head()}")
# Generate future timestamps for prediction instead of using existing data
# Calculate time difference from the data
if len(time_range_df) >= 2:
time_diff = time_range_df['timestamps'].iloc[1] - time_range_df['timestamps'].iloc[0]
else:
time_diff = pd.Timedelta(hours=1) # Default to 1 hour
# Generate future timestamps starting from the last timestamp of input data
last_timestamp = time_range_df['timestamps'].iloc[lookback - 1]
y_timestamp = pd.date_range(
start=last_timestamp + time_diff,
periods=pred_len,
freq=time_diff
)
# Calculate actual time period length
start_timestamp = time_range_df['timestamps'].iloc[0]
end_timestamp = y_timestamp[-1] # Use the last generated timestamp
time_span = end_timestamp - start_timestamp
prediction_type = f"Kronos model prediction (within selected window: first {lookback} data points for prediction, {pred_len} future predictions, time span: {time_span})"
else:
# Use latest data
x_df = df.iloc[:lookback][required_cols]
x_timestamp = df.iloc[:lookback]['timestamps']
# Generate future timestamps for prediction instead of using existing data
# Calculate time difference from the data
if len(df) >= 2:
time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0]
else:
time_diff = pd.Timedelta(hours=1) # Default to 1 hour
# Generate future timestamps starting from the last timestamp of input data
last_timestamp = df['timestamps'].iloc[lookback - 1]
y_timestamp = pd.date_range(
start=last_timestamp + time_diff,
periods=pred_len,
freq=time_diff
)
prediction_type = "Kronos model prediction (latest data)"
print(f"🔍 Latest data - x_df shape: {x_df.shape}")
print(f" x_timestamp length: {len(x_timestamp)}")
print(f" y_timestamp length: {len(y_timestamp)}")
print(f" x_df columns: {list(x_df.columns)}")
print(f" x_df sample:\n{x_df.head()}")
# Check if data is empty
if x_df.empty or len(x_df) == 0:
return jsonify({'error': 'Input data is empty after processing'}), 400
if len(x_timestamp) == 0:
return jsonify({'error': 'Input timestamps are empty'}), 400
if len(y_timestamp) == 0:
return jsonify({'error': 'Target timestamps are empty'}), 400
# Ensure timestamps are Series format, not DatetimeIndex, to avoid .dt attribute error in Kronos model
if isinstance(x_timestamp, pd.DatetimeIndex):
x_timestamp = pd.Series(x_timestamp, name='timestamps')
if isinstance(y_timestamp, pd.DatetimeIndex):
y_timestamp = pd.Series(y_timestamp, name='timestamps')
pred_df = predictor.predict(
df=x_df,
x_timestamp=x_timestamp,
y_timestamp=y_timestamp,
pred_len=pred_len,
T=temperature,
top_p=top_p,
sample_count=sample_count
)
except Exception as e:
return jsonify({'error': f'Kronos model prediction failed: {str(e)}'}), 500
else:
return jsonify({'error': 'Kronos model not loaded, please load model first'}), 400
# Prepare actual data for comparison (if exists)
actual_data = []
actual_df = None
if start_date: # Custom time period
            # Use data within the selected window: prediction consumed the first
            # `lookback` points, so the following `pred_len` points are the actuals
start_dt = pd.to_datetime(start_date)
            # Ensure timezone consistency
if start_dt.tz is None:
start_dt = start_dt.tz_localize(BEIJING_TZ)
# Find data starting from start_date
mask = df['timestamps'] >= start_dt
time_range_df = df[mask]
if len(time_range_df) >= lookback + pred_len:
                # Take the pred_len points after the lookback window as actual values
actual_df = time_range_df.iloc[lookback:lookback + pred_len]
for i, (_, row) in enumerate(actual_df.iterrows()):
actual_data.append({
'timestamp': row['timestamps'].isoformat(),
'open': float(row['open']),
'high': float(row['high']),
'low': float(row['low']),
'close': float(row['close']),
'volume': float(row['volume']) if 'volume' in row else 0,
'amount': float(row['amount']) if 'amount' in row else 0
})
else: # Latest data
            # Prediction used the first `lookback` points; the `pred_len` points
            # immediately after them serve as the actuals
if len(df) >= lookback + pred_len:
actual_df = df.iloc[lookback:lookback + pred_len]
for i, (_, row) in enumerate(actual_df.iterrows()):
actual_data.append({
'timestamp': row['timestamps'].isoformat(),
'open': float(row['open']),
'high': float(row['high']),
'low': float(row['low']),
'close': float(row['close']),
'volume': float(row['volume']) if 'volume' in row else 0,
'amount': float(row['amount']) if 'amount' in row else 0
})
# Create chart - pass historical data start position
if start_date:
# Custom time period: find starting position of historical data in original df
start_dt = pd.to_datetime(start_date)
            # Ensure timezone consistency
if start_dt.tz is None:
start_dt = start_dt.tz_localize(BEIJING_TZ)
mask = df['timestamps'] >= start_dt
historical_start_idx = df[mask].index[0] if len(df[mask]) > 0 else 0
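            # get_binance_klines resets the index, so this label is also the
            # positional offset used by create_prediction_chart's iloc slicing.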
else:
# Latest data: start from beginning
historical_start_idx = 0
chart_json = create_prediction_chart(df, pred_df, lookback, pred_len, actual_df, historical_start_idx)
# Prepare prediction result data - fix timestamp calculation logic
if 'timestamps' in df.columns:
if start_date:
# Custom time period: use selected window data to calculate timestamps
start_dt = pd.to_datetime(start_date)
                # Ensure timezone consistency
if start_dt.tz is None:
start_dt = start_dt.tz_localize(BEIJING_TZ)
mask = df['timestamps'] >= start_dt
time_range_df = df[mask]
if len(time_range_df) >= lookback:
# Calculate prediction timestamps starting from last time point of selected window
last_timestamp = time_range_df['timestamps'].iloc[lookback - 1]
time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0]
future_timestamps = pd.date_range(
start=last_timestamp + time_diff,
periods=pred_len,
freq=time_diff
)
else:
future_timestamps = []
else:
                # Latest data: continue from the end of the lookback window,
                # matching the y_timestamp passed to the predictor
                last_timestamp = df['timestamps'].iloc[lookback - 1]
time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0]
future_timestamps = pd.date_range(
start=last_timestamp + time_diff,
periods=pred_len,
freq=time_diff
)
else:
future_timestamps = range(len(df), len(df) + pred_len)
prediction_results = []
        for i, (_, row) in enumerate(pred_df.iterrows()):
            # future_timestamps may be a DatetimeIndex or a plain integer range fallback
            ts = future_timestamps[i] if i < len(future_timestamps) else None
            prediction_results.append({
                'timestamp': ts.isoformat() if hasattr(ts, 'isoformat') else f"T{i}",
                'open': float(row['open']),
                'high': float(row['high']),
                'low': float(row['low']),
                'close': float(row['close']),
                'volume': float(row['volume']) if 'volume' in row else 0,
                'amount': float(row['amount']) if 'amount' in row else 0
            })
# Save prediction results to file
try:
data_source = f"{symbol}_{interval}"
save_prediction_results(
file_path=data_source,
prediction_type=prediction_type,
prediction_results=prediction_results,
actual_data=actual_data,
input_data=x_df,
prediction_params={
'symbol': symbol,
'interval': interval,
'limit': limit,
'lookback': lookback,
'pred_len': pred_len,
'temperature': temperature,
'top_p': top_p,
'sample_count': sample_count,
'start_date': start_date if start_date else 'latest'
}
)
except Exception as e:
print(f"Failed to save prediction results: {e}")
return jsonify({
'success': True,
'prediction_type': prediction_type,
'chart': chart_json,
'prediction_results': prediction_results,
'actual_data': actual_data,
'has_comparison': len(actual_data) > 0,
'message': f'Prediction completed, generated {pred_len} prediction points' + (
f', including {len(actual_data)} actual data points for comparison' if len(actual_data) > 0 else '')
})
except Exception as e:
return jsonify({'error': f'Prediction failed: {str(e)}'}), 500
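# Example request body for /api/predict; every field except 'symbol' is optional,
# with the defaults shown (an optional 'start_date' selects a custom window):
#   {"symbol": "BTCUSDT", "interval": "1h", "limit": 1000, "lookback": 400,
#    "pred_len": 120, "temperature": 1.0, "top_p": 0.9, "sample_count": 1}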
@app.route('/api/load-model', methods=['POST'])
def load_model():
"""Load Kronos model"""
global tokenizer, model, predictor
try:
if not MODEL_AVAILABLE:
return jsonify({'error': 'Kronos model library not available'}), 400
data = request.get_json()
model_key = data.get('model_key', 'kronos-small')
device = data.get('device', 'cpu')
if model_key not in AVAILABLE_MODELS:
return jsonify({'error': f'Unsupported model: {model_key}'}), 400
model_config = AVAILABLE_MODELS[model_key]
# Load tokenizer and model
tokenizer = KronosTokenizer.from_pretrained(model_config['tokenizer_id'])
model = Kronos.from_pretrained(model_config['model_id'])
# Create predictor
predictor = KronosPredictor(model, tokenizer, device=device, max_context=model_config['context_length'])
return jsonify({
'success': True,
'message': f'Model loaded successfully: {model_config["name"]} ({model_config["params"]}) on {device}',
'model_info': {
'name': model_config['name'],
'params': model_config['params'],
'context_length': model_config['context_length'],
'description': model_config['description']
}
})
except Exception as e:
return jsonify({'error': f'Model loading failed: {str(e)}'}), 500
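# Example request (assuming a local deployment on the default port):
#   curl -X POST http://localhost:7860/api/load-model \
#        -H 'Content-Type: application/json' \
#        -d '{"model_key": "kronos-small", "device": "cpu"}'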
@app.route('/api/available-models')
def get_available_models():
"""Get available model list"""
return jsonify({
'models': AVAILABLE_MODELS,
'model_available': MODEL_AVAILABLE
})
@app.route('/api/model-status')
def get_model_status():
"""Get model status"""
if MODEL_AVAILABLE:
if predictor is not None:
return jsonify({
'available': True,
'loaded': True,
'message': 'Kronos model loaded and available',
'current_model': {
'name': predictor.model.__class__.__name__,
'device': str(next(predictor.model.parameters()).device)
}
})
else:
return jsonify({
'available': True,
'loaded': False,
'message': 'Kronos model available but not loaded'
})
else:
return jsonify({
'available': False,
'loaded': False,
'message': 'Kronos model library not available, please install related dependencies'
})
if __name__ == '__main__':
print("Starting Kronos Web UI...")
print(f"Model availability: {MODEL_AVAILABLE}")
if MODEL_AVAILABLE:
print("Tip: You can load Kronos model through /api/load-model endpoint")
else:
print("Tip: Will use simulated data for demonstration")
app.run(debug=True, host='0.0.0.0', port=7860)