"""Flask web UI for Kronos price prediction on Binance kline data.

Endpoints:
    /                          HTML front page
    /api/symbols               fixed list of trading pairs
    /api/timeframes            supported kline intervals
    /api/technical-indicators  indicator catalogue (empty when module missing)
    /api/load-data             fetch klines and summarise them
    /api/predict               run the loaded Kronos model and chart the result
    /api/load-model            load one of AVAILABLE_MODELS
    /api/available-models      model catalogue
    /api/model-status          whether a model is importable / loaded
"""
import datetime
import json
import os
import sys
import warnings

import pandas as pd
import plotly.graph_objects as go
import plotly.utils
import pytz
from binance.client import Client
from flask import Flask, render_template, request, jsonify
from flask_cors import CORS
# NOTE(review): `false` is not referenced anywhere in this file; it looks like
# an accidental IDE auto-import. Kept so the module's imports are unchanged.
from sympy import false  # noqa: F401

try:
    from technical_indicators import add_technical_indicators, get_available_indicators
    # NOTE(review): the flag stays False even when the import succeeds, so the
    # indicator branch in get_binance_klines never runs. This looks like a
    # deliberate switch-off - confirm before flipping it to True.
    TECHNICAL_INDICATORS_AVAILABLE = False
except ImportError as e:
    print(f"⚠️ 技术指标模块导入失败: {e}")
    TECHNICAL_INDICATORS_AVAILABLE = False

    # No-op replacements so the rest of the module can call these unconditionally.
    def add_technical_indicators(df, indicators_config=None):
        return df

    def get_available_indicators():
        return {'trend': [], 'momentum': [], 'volatility': [], 'volume': []}

warnings.filterwarnings('ignore')

# All timestamps handled by this app are expressed in UTC+8.
BEIJING_TZ = pytz.timezone('Asia/Shanghai')

# Add project root directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

try:
    from model import Kronos, KronosTokenizer, KronosPredictor
    MODEL_AVAILABLE = True
except ImportError:
    MODEL_AVAILABLE = False
    print("Warning: Kronos model cannot be imported, will use simulated data for demonstration")

app = Flask(__name__)
CORS(app)

# Global variables to store models (populated by /api/load-model)
tokenizer = None
model = None
predictor = None

# Price fields shared by the summaries and comparisons below.
_PRICE_FIELDS = ('open', 'high', 'low', 'close')

# Available model configurations
AVAILABLE_MODELS = {
    'kronos-mini': {
        'name': 'Kronos-mini',
        'model_id': 'NeoQuasar/Kronos-mini',
        'tokenizer_id': 'NeoQuasar/Kronos-Tokenizer-2k',
        'context_length': 2048,
        'params': '4.1M',
        'description': 'Lightweight model, suitable for fast prediction'
    },
    'kronos-small': {
        'name': 'Kronos-small',
        'model_id': 'NeoQuasar/Kronos-small',
        'tokenizer_id': 'NeoQuasar/Kronos-Tokenizer-base',
        'context_length': 512,
        'params': '24.7M',
        'description': 'Small model, balanced performance and speed'
    },
    'kronos-base': {
        'name': 'Kronos-base',
        'model_id': 'NeoQuasar/Kronos-base',
        'tokenizer_id': 'NeoQuasar/Kronos-Tokenizer-base',
        'context_length': 512,
        'params': '102.3M',
        'description': 'Base model, provides better prediction quality'
    }
}


def get_available_symbols():
    """Return the fixed list of supported trading pairs."""
    # The list is hard-coded; it is no longer fetched from the Binance API.
    return [
        {'symbol': 'BTCUSDT', 'baseAsset': 'BTC', 'quoteAsset': 'USDT', 'name': 'BTC/USDT'},
        {'symbol': 'ETHUSDT', 'baseAsset': 'ETH', 'quoteAsset': 'USDT', 'name': 'ETH/USDT'},
        {'symbol': 'SOLUSDT', 'baseAsset': 'SOL', 'quoteAsset': 'USDT', 'name': 'SOL/USDT'},
        {'symbol': 'BNBUSDT', 'baseAsset': 'BNB', 'quoteAsset': 'USDT', 'name': 'BNB/USDT'}
    ]


# Binance client created with empty credentials (public endpoints only).
# NOTE(review): depending on the python-binance version the constructor may
# contact the API at import time - confirm if offline start-up matters.
binance_client = Client("", "")


def get_binance_klines(symbol, interval='1h', limit=1000):
    """Fetch kline (candlestick) data from Binance.

    Args:
        symbol: Trading pair, e.g. ``'BTCUSDT'``.
        interval: Kline interval string passed through to Binance.
        limit: Maximum number of klines to request.

    Returns:
        ``(df, None)`` on success, where ``df`` is sorted by time and has the
        columns ``timestamp``, ``timestamps`` (same values, kept for
        compatibility), ``open``, ``high``, ``low``, ``close``, ``volume`` and
        ``amount`` (quote-asset volume), with timestamps converted to
        ``BEIJING_TZ``.
        ``(None, message)`` on failure. Callers unpack two values, so the
        failure path must return a pair as well; there is no simulated-data
        fallback.
    """
    try:
        klines = binance_client.get_klines(
            symbol=symbol,
            interval=interval,
            limit=limit
        )
        if not klines:
            return None, f'No kline data returned for {symbol} ({interval})'

        df = pd.DataFrame(klines, columns=[
            'timestamp', 'open', 'high', 'low', 'close', 'volume',
            'close_time', 'quote_asset_volume', 'number_of_trades',
            'taker_buy_base_asset_volume', 'taker_buy_quote_asset_volume', 'ignore'
        ])

        # Millisecond epoch -> tz-aware UTC -> UTC+8.
        df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True)
        df['timestamp'] = df['timestamp'].dt.tz_convert(BEIJING_TZ)
        df['timestamps'] = df['timestamp']  # kept for compatibility

        numeric_cols = ['open', 'high', 'low', 'close', 'volume', 'quote_asset_volume']
        for col in numeric_cols:
            df[col] = pd.to_numeric(df[col], errors='coerce')

        # "amount" is the traded value, i.e. the quote-asset volume.
        df['amount'] = df['quote_asset_volume']

        df = df[['timestamp', 'timestamps', 'open', 'high', 'low', 'close', 'volume', 'amount']]
        df = df.sort_values('timestamp').reset_index(drop=True)

        if TECHNICAL_INDICATORS_AVAILABLE:
            try:
                df = add_technical_indicators(df)
                print(f"✅ 成功获取币安真实数据并计算技术指标: {symbol} {interval} {len(df)}条,{len(df.columns)}个特征")
            except Exception as e:
                # Best effort: fall back to the raw OHLCV frame.
                print(f"⚠️ 技术指标计算失败,使用原始数据: {e}")
        else:
            print(f"✅ 成功获取币安真实数据: {symbol} {interval} {len(df)}条")

        return df, None
    except Exception as e:
        # Boundary handler: any client/network/parsing error is reported to
        # the caller as an error message instead of propagating.
        print(f"⚠️ 币安API连接失败: {str(e)}")
        return None, f'Failed to fetch Binance data for {symbol} ({interval}): {e}'


def get_timeframe_options():
    """Return the selectable kline intervals with display labels."""
    return [
        {'value': '1m', 'label': '1分钟', 'description': '1分钟K线'},
        {'value': '5m', 'label': '5分钟', 'description': '5分钟K线'},
        {'value': '15m', 'label': '15分钟', 'description': '15分钟K线'},
        {'value': '30m', 'label': '30分钟', 'description': '30分钟K线'},
        {'value': '1h', 'label': '1小时', 'description': '1小时K线'},
        {'value': '4h', 'label': '4小时', 'description': '4小时K线'},
        {'value': '1d', 'label': '1天', 'description': '日K线'},
        {'value': '1w', 'label': '1周', 'description': '周K线'},
    ]


def _gap_percentage(predicted, actual):
    """Return ``|predicted - actual| / actual * 100``, or None when actual is 0."""
    if actual == 0:
        return None
    return (abs(predicted - actual) / actual) * 100


def save_prediction_results(file_path, prediction_type, prediction_results, actual_data, input_data,
                            prediction_params):
    """Write one prediction run to ``prediction_results/prediction_<time>.json``.

    Args:
        file_path: Label of the data source, stored verbatim.
        prediction_type: Human-readable description of the run.
        prediction_results: List of predicted OHLCV dicts.
        actual_data: List of actual OHLCV dicts (may be empty).
        input_data: DataFrame fed to the model; only summarised, not stored.
        prediction_params: Dict of request parameters, stored verbatim.

    Returns:
        The path of the written file, or None if anything failed (the error
        is printed, never raised).
    """
    try:
        results_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prediction_results')
        os.makedirs(results_dir, exist_ok=True)

        # One clock reading so the filename and the stored timestamp agree.
        now = datetime.datetime.now()
        filename = f"prediction_{now.strftime('%Y%m%d_%H%M%S')}.json"
        filepath = os.path.join(results_dir, filename)

        save_data = {
            'timestamp': now.isoformat(),
            'file_path': file_path,
            'prediction_type': prediction_type,
            'prediction_params': prediction_params,
            'input_data_summary': {
                'rows': len(input_data),
                'columns': list(input_data.columns),
                'price_range': {
                    field: {
                        'min': float(input_data[field].min()),
                        'max': float(input_data[field].max())
                    }
                    for field in _PRICE_FIELDS
                },
                'last_values': {
                    field: float(input_data[field].iloc[-1]) for field in _PRICE_FIELDS
                }
            },
            'prediction_results': prediction_results,
            'actual_data': actual_data,
            'analysis': {}
        }

        # Continuity analysis: compare the first predicted candle with the
        # first actual candle. The 'last_prediction' key name is historical.
        if actual_data and prediction_results:
            first_pred = prediction_results[0]
            first_actual = actual_data[0]
            save_data['analysis']['continuity'] = {
                'last_prediction': {field: first_pred[field] for field in _PRICE_FIELDS},
                'first_actual': {field: first_actual[field] for field in _PRICE_FIELDS},
                'gaps': {
                    f'{field}_gap': abs(first_pred[field] - first_actual[field])
                    for field in _PRICE_FIELDS
                },
                'gap_percentages': {
                    f'{field}_gap_pct': _gap_percentage(first_pred[field], first_actual[field])
                    for field in _PRICE_FIELDS
                }
            }

        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(save_data, f, indent=2, ensure_ascii=False)

        print(f"Prediction results saved to: {filepath}")
        return filepath
    except Exception as e:
        print(f"Failed to save prediction results: {e}")
        return None


def create_prediction_chart(df, pred_df, lookback, pred_len, actual_df=None, historical_start_idx=0):
    """Build the candlestick chart for a prediction run.

    Args:
        df: Full kline DataFrame (expects OHLC columns, optionally ``timestamps``).
        pred_df: Predicted OHLC DataFrame, or None.
        lookback: Number of historical rows to draw.
        pred_len: Requested prediction length (kept for interface
            compatibility; the drawn length comes from ``pred_df``).
        actual_df: Optional actual OHLC rows drawn against the prediction.
        historical_start_idx: Positional row in ``df`` where history starts.

    Returns:
        The Plotly figure serialised to a JSON string.
    """
    # iloc slicing clamps at the end of the frame, so a short frame simply
    # yields fewer historical rows.
    historical_df = df.iloc[historical_start_idx:historical_start_idx + lookback]
    has_time_axis = 'timestamps' in df.columns and len(historical_df) > 0

    def future_axis(count):
        """X values for `count` candles that directly follow the history."""
        if has_time_axis:
            last_timestamp = historical_df['timestamps'].iloc[-1]
            if len(df) > 1:
                time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0]
            else:
                time_diff = pd.Timedelta(hours=1)
            return pd.date_range(start=last_timestamp + time_diff, periods=count, freq=time_diff)
        # No usable timestamps: fall back to integer positions.
        return range(len(historical_df), len(historical_df) + count)

    n_hist = len(historical_df)
    n_pred = len(pred_df) if pred_df is not None else 0
    n_actual = len(actual_df) if actual_df is not None else 0

    fig = go.Figure()

    # Historical candles
    fig.add_trace(go.Candlestick(
        x=historical_df['timestamps'] if 'timestamps' in historical_df.columns else historical_df.index,
        open=historical_df['open'],
        high=historical_df['high'],
        low=historical_df['low'],
        close=historical_df['close'],
        name=f'Historical Data ({n_hist} data points)',
        increasing_line_color='#26A69A',
        decreasing_line_color='#EF5350'
    ))

    # Predicted candles, continuing right after the history
    pred_timestamps = None
    if n_pred > 0:
        pred_timestamps = future_axis(n_pred)
        fig.add_trace(go.Candlestick(
            x=pred_timestamps,
            open=pred_df['open'],
            high=pred_df['high'],
            low=pred_df['low'],
            close=pred_df['close'],
            name=f'Prediction Data ({n_pred} data points)',
            increasing_line_color='#66BB6A',
            decreasing_line_color='#FF7043'
        ))

    # Actual candles are drawn on the same time axis as the prediction so the
    # two series line up for comparison.
    actual_timestamps = None
    if n_actual > 0:
        actual_timestamps = future_axis(n_actual)
        fig.add_trace(go.Candlestick(
            x=actual_timestamps,
            open=actual_df['open'],
            high=actual_df['high'],
            low=actual_df['low'],
            close=actual_df['close'],
            name=f'Actual Data ({n_actual} data points)',
            increasing_line_color='#FF9800',
            decreasing_line_color='#F44336'
        ))

    title = f'Kronos Financial Prediction Results - {n_hist} Historical Points'
    if n_pred:
        title += f' + {n_pred} Prediction Points'
    if n_actual:
        title += f' vs {n_actual} Actual Points'

    fig.update_layout(
        title=title,
        xaxis_title='Time',
        yaxis_title='Price',
        template='plotly_white',
        height=600,
        showlegend=True
    )

    # Pin the x-axis to the full drawn time span.
    if has_time_axis:
        all_timestamps = list(historical_df['timestamps'])
        if pred_timestamps is not None:
            all_timestamps.extend(pred_timestamps)
        if actual_timestamps is not None:
            all_timestamps.extend(actual_timestamps)
        fig.update_xaxes(
            range=[min(all_timestamps), max(all_timestamps)],
            rangeslider_visible=False,
            type='date'
        )

    return json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder)


def _detect_timeframe(df):
    """Describe the average spacing of the first (up to) ten timestamps."""
    if len(df) < 2:
        return "Unknown"

    time_diffs = df['timestamps'].iloc[:10].diff().dropna()
    if time_diffs.empty:
        return "Unknown"

    avg_diff = time_diffs.mean()
    if avg_diff < pd.Timedelta(minutes=1):
        return f"{avg_diff.total_seconds():.0f} seconds"
    if avg_diff < pd.Timedelta(hours=1):
        return f"{avg_diff.total_seconds() / 60:.0f} minutes"
    if avg_diff < pd.Timedelta(days=1):
        return f"{avg_diff.total_seconds() / 3600:.0f} hours"
    return f"{avg_diff.days} days"


def _format_beijing_time(timestamp):
    """Format a timestamp as ``YYYY-MM-DD HH:MM:SS`` in UTC+8 ('N/A' if missing)."""
    if pd.isna(timestamp):
        return 'N/A'
    # Naive values are taken to already be Beijing time; aware ones are converted.
    if timestamp.tz is None:
        timestamp = timestamp.tz_localize(BEIJING_TZ)
    else:
        timestamp = timestamp.tz_convert(BEIJING_TZ)
    return timestamp.strftime('%Y-%m-%d %H:%M:%S')


def _parse_start_date(start_date):
    """Parse a request start date; naive values are interpreted as Beijing time.

    The kline timestamps are tz-aware, and pandas refuses to compare tz-aware
    data with a tz-naive timestamp, so the value must be localized first.
    """
    start_dt = pd.to_datetime(start_date)
    if start_dt.tz is None:
        start_dt = start_dt.tz_localize(BEIJING_TZ)
    return start_dt


def _infer_time_step(timestamps):
    """Spacing between the first two timestamps (defaults to one hour)."""
    if len(timestamps) >= 2:
        return timestamps.iloc[1] - timestamps.iloc[0]
    return pd.Timedelta(hours=1)


def _ohlcv_record(timestamp_label, row):
    """Convert one OHLCV row into the JSON-friendly dict used by the API."""
    return {
        'timestamp': timestamp_label,
        'open': float(row['open']),
        'high': float(row['high']),
        'low': float(row['low']),
        'close': float(row['close']),
        'volume': float(row['volume']) if 'volume' in row else 0,
        'amount': float(row['amount']) if 'amount' in row else 0
    }


@app.route('/')
def index():
    """Home page"""
    return render_template('index.html')


@app.route('/api/symbols')
def get_symbols():
    """Return the available trading pairs."""
    symbols = get_available_symbols()
    return jsonify(symbols)


@app.route('/api/timeframes')
def get_timeframes():
    """Return the available kline intervals."""
    timeframes = get_timeframe_options()
    return jsonify(timeframes)


@app.route('/api/technical-indicators')
def get_technical_indicators():
    """Return the available technical indicators."""
    indicators = get_available_indicators()
    return jsonify(indicators)


@app.route('/api/load-data', methods=['POST'])
def load_data():
    """Fetch Binance klines and return a summary of the loaded data.

    Request JSON: ``symbol`` (required), ``interval`` (default ``'1h'``),
    ``limit`` (default 1000). Responds 400 for a missing symbol or a fetch
    error and 500 for unexpected failures.
    """
    try:
        # A missing or invalid JSON body becomes {} so validation answers 400.
        data = request.get_json(silent=True) or {}
        symbol = data.get('symbol')
        interval = data.get('interval', '1h')
        limit = int(data.get('limit', 1000))

        if not symbol:
            return jsonify({'error': '交易对不能为空'}), 400

        df, error = get_binance_klines(symbol, interval, limit)
        if error:
            return jsonify({'error': error}), 400

        has_timestamps = 'timestamps' in df.columns
        price_columns = df[['open', 'high', 'low', 'close']]
        data_info = {
            'rows': len(df),
            'columns': list(df.columns),
            'start_date': _format_beijing_time(df['timestamps'].min()) if has_timestamps else 'N/A',
            'end_date': _format_beijing_time(df['timestamps'].max()) if has_timestamps else 'N/A',
            'price_range': {
                'min': float(price_columns.min().min()),
                'max': float(price_columns.max().max())
            },
            'prediction_columns': ['open', 'high', 'low', 'close'] + (
                ['volume'] if 'volume' in df.columns else []),
            'timeframe': _detect_timeframe(df)
        }

        return jsonify({
            'success': True,
            'data_info': data_info,
            'message': f'Successfully loaded data, total {len(df)} rows'
        })
    except Exception as e:
        return jsonify({'error': f'Failed to load data: {str(e)}'}), 500


@app.route('/api/predict', methods=['POST'])
def predict():
    """Run a Kronos prediction on Binance klines and return chart + data.

    Request JSON: ``symbol`` (required), ``interval``, ``limit``,
    ``lookback``, ``pred_len``, ``temperature``, ``top_p``, ``sample_count``
    and optional ``start_date``.

    The model is fed the first ``lookback`` rows of the selected window (the
    whole frame, or the rows at/after ``start_date``) and predicts the
    ``pred_len`` steps that follow. When the window holds those following
    rows they are returned as actual data for comparison.
    """
    try:
        data = request.get_json(silent=True) or {}
        symbol = data.get('symbol')
        interval = data.get('interval', '1h')
        limit = int(data.get('limit', 1000))
        lookback = int(data.get('lookback', 400))
        pred_len = int(data.get('pred_len', 120))
        start_date = data.get('start_date')

        # Prediction quality parameters
        temperature = float(data.get('temperature', 1.0))
        top_p = float(data.get('top_p', 0.9))
        sample_count = int(data.get('sample_count', 1))

        if not symbol:
            return jsonify({'error': '交易对不能为空'}), 400
        if lookback < 1 or pred_len < 1:
            return jsonify({'error': 'lookback and pred_len must be positive integers'}), 400

        df, error = get_binance_klines(symbol, interval, limit)
        if error:
            return jsonify({'error': error}), 400

        if len(df) < lookback:
            return jsonify({'error': f'Insufficient data length, need at least {lookback} rows'}), 400

        # The library being importable is not enough: a model must have been
        # loaded through /api/load-model.
        if not MODEL_AVAILABLE or predictor is None:
            return jsonify({'error': 'Kronos model not loaded, please load model first'}), 400

        # Only OHLC plus volume/amount (when present) are passed to the model.
        required_cols = ['open', 'high', 'low', 'close']
        required_cols += [col for col in ('volume', 'amount') if col in df.columns]
        print(f"🔍 Using features for prediction: {required_cols}")
        print(f" Available columns in data: {list(df.columns)}")
        print(f" Data shape: {df.shape}")

        missing_cols = [col for col in required_cols if col not in df.columns]
        if missing_cols:
            return jsonify({'error': f'Missing required columns: {missing_cols}'}), 400

        # Select the window the run works on.
        if start_date:
            start_dt = _parse_start_date(start_date)
            window_df = df[df['timestamps'] >= start_dt]
            if len(window_df) < lookback + pred_len:
                return jsonify({
                    'error': f'Insufficient data from start time {start_dt.strftime("%Y-%m-%d %H:%M")}, need at least {lookback + pred_len} data points, currently only {len(window_df)} available'}), 400
            # df has a fresh RangeIndex, so the label equals the row position.
            historical_start_idx = int(window_df.index[0])
        else:
            window_df = df
            historical_start_idx = 0

        x_df = window_df.iloc[:lookback][required_cols]
        x_timestamp = window_df.iloc[:lookback]['timestamps']

        # Future timestamps continue from the last input row at the data's own spacing.
        time_diff = _infer_time_step(window_df['timestamps'])
        last_timestamp = window_df['timestamps'].iloc[lookback - 1]
        future_index = pd.date_range(
            start=last_timestamp + time_diff,
            periods=pred_len,
            freq=time_diff
        )

        if start_date:
            time_span = future_index[-1] - window_df['timestamps'].iloc[0]
            prediction_type = f"Kronos model prediction (within selected window: first {lookback} data points for prediction, {pred_len} future predictions, time span: {time_span})"
            mode_label = 'Custom time period'
        else:
            prediction_type = "Kronos model prediction (latest data)"
            mode_label = 'Latest data'

        print(f"🔍 {mode_label} - x_df shape: {x_df.shape}")
        print(f" x_timestamp length: {len(x_timestamp)}")
        print(f" y_timestamp length: {len(future_index)}")
        print(f" x_df columns: {list(x_df.columns)}")
        print(f" x_df sample:\n{x_df.head()}")

        if x_df.empty:
            return jsonify({'error': 'Input data is empty after processing'}), 400
        if len(x_timestamp) == 0:
            return jsonify({'error': 'Input timestamps are empty'}), 400
        if len(future_index) == 0:
            return jsonify({'error': 'Target timestamps are empty'}), 400

        # Timestamps are passed as Series rather than DatetimeIndex, to avoid
        # a `.dt` attribute error inside the Kronos model.
        if isinstance(x_timestamp, pd.DatetimeIndex):
            x_timestamp = pd.Series(x_timestamp, name='timestamps')
        y_timestamp = pd.Series(future_index, name='timestamps')

        try:
            pred_df = predictor.predict(
                df=x_df,
                x_timestamp=x_timestamp,
                y_timestamp=y_timestamp,
                pred_len=pred_len,
                T=temperature,
                top_p=top_p,
                sample_count=sample_count
            )
        except Exception as e:
            return jsonify({'error': f'Kronos model prediction failed: {str(e)}'}), 500

        # Actual rows that follow the input window, when the data has them.
        actual_df = None
        actual_data = []
        if len(window_df) >= lookback + pred_len:
            actual_df = window_df.iloc[lookback:lookback + pred_len]
            actual_data = [
                _ohlcv_record(row['timestamps'].isoformat(), row)
                for _, row in actual_df.iterrows()
            ]

        chart_json = create_prediction_chart(df, pred_df, lookback, pred_len, actual_df, historical_start_idx)

        # Label predictions with the same timestamps the model was asked for,
        # so table, chart and actual data all line up.
        future_labels = [ts.isoformat() for ts in future_index]
        prediction_results = [
            _ohlcv_record(future_labels[i] if i < len(future_labels) else f"T{i}", row)
            for i, (_, row) in enumerate(pred_df.iterrows())
        ]

        # Persisting the run is best effort and must not fail the request.
        try:
            data_source = f"{symbol}_{interval}"
            save_prediction_results(
                file_path=data_source,
                prediction_type=prediction_type,
                prediction_results=prediction_results,
                actual_data=actual_data,
                input_data=x_df,
                prediction_params={
                    'symbol': symbol,
                    'interval': interval,
                    'limit': limit,
                    'lookback': lookback,
                    'pred_len': pred_len,
                    'temperature': temperature,
                    'top_p': top_p,
                    'sample_count': sample_count,
                    'start_date': start_date if start_date else 'latest'
                }
            )
        except Exception as e:
            print(f"Failed to save prediction results: {e}")

        return jsonify({
            'success': True,
            'prediction_type': prediction_type,
            'chart': chart_json,
            'prediction_results': prediction_results,
            'actual_data': actual_data,
            'has_comparison': len(actual_data) > 0,
            'message': f'Prediction completed, generated {pred_len} prediction points' + (
                f', including {len(actual_data)} actual data points for comparison' if len(
                    actual_data) > 0 else '')
        })
    except Exception as e:
        return jsonify({'error': f'Prediction failed: {str(e)}'}), 500


@app.route('/api/load-model', methods=['POST'])
def load_model():
    """Load a Kronos model and tokenizer into the module-level globals.

    Request JSON: ``model_key`` (default ``'kronos-small'``, must be a key of
    AVAILABLE_MODELS) and ``device`` (default ``'cpu'``).
    """
    global tokenizer, model, predictor

    try:
        if not MODEL_AVAILABLE:
            return jsonify({'error': 'Kronos model library not available'}), 400

        data = request.get_json(silent=True) or {}
        model_key = data.get('model_key', 'kronos-small')
        device = data.get('device', 'cpu')

        if model_key not in AVAILABLE_MODELS:
            return jsonify({'error': f'Unsupported model: {model_key}'}), 400

        model_config = AVAILABLE_MODELS[model_key]

        tokenizer = KronosTokenizer.from_pretrained(model_config['tokenizer_id'])
        model = Kronos.from_pretrained(model_config['model_id'])
        predictor = KronosPredictor(model, tokenizer, device=device,
                                    max_context=model_config['context_length'])

        return jsonify({
            'success': True,
            'message': f'Model loaded successfully: {model_config["name"]} ({model_config["params"]}) on {device}',
            'model_info': {
                'name': model_config['name'],
                'params': model_config['params'],
                'context_length': model_config['context_length'],
                'description': model_config['description']
            }
        })
    except Exception as e:
        return jsonify({'error': f'Model loading failed: {str(e)}'}), 500


@app.route('/api/available-models')
def get_available_models():
    """Get available model list"""
    return jsonify({
        'models': AVAILABLE_MODELS,
        'model_available': MODEL_AVAILABLE
    })


@app.route('/api/model-status')
def get_model_status():
    """Report whether the Kronos library is importable and a model is loaded."""
    if not MODEL_AVAILABLE:
        return jsonify({
            'available': False,
            'loaded': False,
            'message': 'Kronos model library not available, please install related dependencies'
        })

    if predictor is None:
        return jsonify({
            'available': True,
            'loaded': False,
            'message': 'Kronos model available but not loaded'
        })

    return jsonify({
        'available': True,
        'loaded': True,
        'message': 'Kronos model loaded and available',
        'current_model': {
            'name': predictor.model.__class__.__name__,
            'device': str(next(predictor.model.parameters()).device)
        }
    })


if __name__ == '__main__':
    print("Starting Kronos Web UI...")
    print(f"Model availability: {MODEL_AVAILABLE}")
    if MODEL_AVAILABLE:
        print("Tip: You can load Kronos model through /api/load-model endpoint")
    else:
        print("Tip: Will use simulated data for demonstration")

    # NOTE(review): debug=True together with host='0.0.0.0' exposes the
    # interactive debugger on every network interface - confirm this is only
    # ever run on a trusted machine, or turn debug off.
    app.run(debug=True, host='0.0.0.0', port=7860)