Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -24,7 +24,7 @@ if torch.cuda.is_available():
|
|
| 24 |
torch.set_num_threads(min(4, os.cpu_count() or 1))
|
| 25 |
|
| 26 |
class FastAIStockAnalyzer:
|
| 27 |
-
"""Optimized AI Stock Analyzer for Gradio"""
|
| 28 |
|
| 29 |
def __init__(self):
|
| 30 |
self.context_length = 32
|
|
@@ -58,7 +58,7 @@ class FastAIStockAnalyzer:
|
|
| 58 |
return None, None
|
| 59 |
|
| 60 |
def load_chronos_tiny(self) -> Tuple[Optional[Any], str]:
|
| 61 |
-
"""Load Chronos model with caching"""
|
| 62 |
model_key = "chronos_tiny"
|
| 63 |
|
| 64 |
if model_key in self.model_cache:
|
|
@@ -67,6 +67,7 @@ class FastAIStockAnalyzer:
|
|
| 67 |
try:
|
| 68 |
from chronos import ChronosPipeline
|
| 69 |
|
|
|
|
| 70 |
pipeline = ChronosPipeline.from_pretrained(
|
| 71 |
"amazon/chronos-t5-tiny",
|
| 72 |
device_map="cpu",
|
|
@@ -78,11 +79,24 @@ class FastAIStockAnalyzer:
|
|
| 78 |
self.model_cache[model_key] = pipeline
|
| 79 |
return pipeline, "chronos"
|
| 80 |
|
| 81 |
-
except
|
|
|
|
| 82 |
return None, None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
def load_moirai_small(self) -> Tuple[Optional[Any], str]:
|
| 85 |
-
"""Load Moirai model with
|
| 86 |
model_key = "moirai_small"
|
| 87 |
|
| 88 |
if model_key in self.model_cache:
|
|
@@ -91,33 +105,81 @@ class FastAIStockAnalyzer:
|
|
| 91 |
try:
|
| 92 |
from uni2ts.model.moirai import MoiraiForecast, MoiraiModule
|
| 93 |
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
except Exception as e:
|
| 117 |
return None, None
|
| 118 |
|
| 119 |
def predict_chronos_fast(self, pipeline: Any, data: np.ndarray) -> Optional[Dict]:
|
| 120 |
-
"""Fast Chronos prediction"""
|
| 121 |
try:
|
| 122 |
context_data = data[-self.context_length:]
|
| 123 |
context = torch.tensor(context_data, dtype=torch.float32).unsqueeze(0)
|
|
@@ -146,265 +208,381 @@ class FastAIStockAnalyzer:
|
|
| 146 |
return None
|
| 147 |
|
| 148 |
def predict_moirai_fast(self, model: Any, data: np.ndarray) -> Optional[Dict]:
|
| 149 |
-
"""Fast Moirai prediction"""
|
| 150 |
try:
|
| 151 |
from gluonts.dataset.common import ListDataset
|
| 152 |
|
|
|
|
| 153 |
dataset = ListDataset([{
|
| 154 |
"item_id": "stock",
|
| 155 |
"start": "2023-01-01",
|
| 156 |
"target": data[-self.context_length:].tolist()
|
| 157 |
}], freq='D')
|
| 158 |
|
|
|
|
| 159 |
predictor = model.create_predictor(
|
| 160 |
batch_size=1,
|
| 161 |
-
num_parallel_samples=
|
| 162 |
)
|
| 163 |
|
|
|
|
| 164 |
forecasts = list(predictor.predict(dataset))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
forecast = forecasts[0]
|
| 166 |
|
| 167 |
predictions = {
|
| 168 |
'mean': forecast.mean,
|
| 169 |
'q10': forecast.quantile(0.1),
|
| 170 |
'q90': forecast.quantile(0.9),
|
| 171 |
-
'std': np.std(forecast.samples, axis=0)
|
| 172 |
}
|
| 173 |
|
| 174 |
return predictions
|
| 175 |
|
| 176 |
except Exception as e:
|
| 177 |
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
|
| 179 |
# Initialize analyzer globally for caching
|
| 180 |
analyzer = FastAIStockAnalyzer()
|
| 181 |
|
| 182 |
def analyze_stock(stock_symbol, model_choice, investment_amount, progress=gr.Progress()):
|
| 183 |
-
"""Main analysis function
|
| 184 |
-
|
| 185 |
-
progress(0.1, desc="Fetching stock data...")
|
| 186 |
-
|
| 187 |
-
# Fetch data
|
| 188 |
-
stock_data, stock_info = analyzer.fetch_stock_data(stock_symbol)
|
| 189 |
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 267 |
|
| 268 |
-
## π€
|
| 269 |
**{explanation}**
|
| 270 |
*Powered by {model_name}*
|
| 271 |
|
| 272 |
## π Key Metrics
|
| 273 |
- **Current Price**: ${current_price:.2f}
|
| 274 |
- **7-Day Prediction**: ${final_pred:.2f} ({week_change:+.2f}%)
|
| 275 |
-
- **
|
| 276 |
-
- **
|
| 277 |
|
| 278 |
## π° Investment Scenario (${investment_amount:,.0f})
|
| 279 |
- **Shares**: {investment_amount/current_price:.2f}
|
|
|
|
| 280 |
- **Predicted Value**: ${investment_amount + ((final_pred - current_price) * (investment_amount/current_price)):,.2f}
|
| 281 |
- **Profit/Loss**: ${((final_pred - current_price) * (investment_amount/current_price)):+,.2f} ({week_change:+.2f}%)
|
| 282 |
|
| 283 |
-
β οΈ
|
|
|
|
|
|
|
|
|
|
|
|
|
| 284 |
"""
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
fig.add_trace(go.Scatter(
|
| 294 |
-
x=recent.index,
|
| 295 |
-
y=recent['Close'],
|
| 296 |
-
mode='lines',
|
| 297 |
-
name='Historical Price',
|
| 298 |
-
line=dict(color='blue', width=2)
|
| 299 |
-
))
|
| 300 |
-
|
| 301 |
-
# Predictions
|
| 302 |
-
future_dates = pd.date_range(
|
| 303 |
-
start=stock_data.index[-1] + pd.Timedelta(days=1),
|
| 304 |
-
periods=7,
|
| 305 |
-
freq='D'
|
| 306 |
-
)
|
| 307 |
-
|
| 308 |
-
fig.add_trace(go.Scatter(
|
| 309 |
-
x=future_dates,
|
| 310 |
-
y=mean_pred,
|
| 311 |
-
mode='lines+markers',
|
| 312 |
-
name='AI Prediction',
|
| 313 |
-
line=dict(color='red', width=3),
|
| 314 |
-
marker=dict(size=8)
|
| 315 |
-
))
|
| 316 |
-
|
| 317 |
-
# Confidence bands
|
| 318 |
-
if 'q10' in predictions and 'q90' in predictions:
|
| 319 |
fig.add_trace(go.Scatter(
|
| 320 |
-
x=
|
| 321 |
-
y=
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
line=dict(color='
|
| 325 |
-
name='Confidence Range',
|
| 326 |
-
showlegend=True
|
| 327 |
))
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 354 |
|
| 355 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 356 |
with gr.Blocks(
|
| 357 |
theme=gr.themes.Soft(),
|
| 358 |
title="β‘ Fast AI Stock Predictor",
|
| 359 |
-
css="
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 360 |
) as demo:
|
| 361 |
|
| 362 |
gr.HTML("""
|
| 363 |
-
<div
|
| 364 |
-
<h1>β‘
|
| 365 |
-
<p><strong>π€ Powered by Amazon Chronos & Salesforce Moirai</strong></p>
|
| 366 |
-
<p style="
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 367 |
</div>
|
| 368 |
""")
|
| 369 |
|
| 370 |
with gr.Row():
|
| 371 |
-
with gr.Column(scale=1):
|
| 372 |
-
gr.HTML("<h3>π― Configuration</h3>")
|
| 373 |
|
| 374 |
stock_input = gr.Dropdown(
|
| 375 |
-
choices=["AAPL", "GOOGL", "MSFT", "TSLA", "AMZN", "META", "NFLX", "NVDA"],
|
| 376 |
value="AAPL",
|
| 377 |
-
label="Select Stock",
|
| 378 |
allow_custom_value=True,
|
| 379 |
-
info="Choose
|
| 380 |
)
|
| 381 |
|
| 382 |
model_input = gr.Radio(
|
| 383 |
-
choices=["π Chronos (Fast)", "π― Moirai (
|
| 384 |
-
value="π Chronos (Fast)",
|
| 385 |
-
label="AI Model",
|
| 386 |
-
info="Chronos: Faster | Moirai: More
|
| 387 |
)
|
| 388 |
|
| 389 |
investment_input = gr.Slider(
|
| 390 |
minimum=500,
|
| 391 |
-
maximum=
|
| 392 |
value=5000,
|
| 393 |
step=500,
|
| 394 |
-
label="Investment Amount ($)",
|
| 395 |
-
info="Amount
|
| 396 |
)
|
| 397 |
|
| 398 |
analyze_btn = gr.Button(
|
| 399 |
-
"π Analyze Stock",
|
| 400 |
variant="primary",
|
| 401 |
-
size="lg"
|
|
|
|
| 402 |
)
|
| 403 |
|
| 404 |
-
|
| 405 |
-
gr.HTML("<h3>π Results</h3>")
|
| 406 |
|
| 407 |
-
|
|
|
|
|
|
|
| 408 |
current_price_display = gr.Textbox(
|
| 409 |
label="Current Price",
|
| 410 |
interactive=False,
|
|
@@ -416,21 +594,41 @@ with gr.Blocks(
|
|
| 416 |
container=True
|
| 417 |
)
|
| 418 |
decision_display = gr.Textbox(
|
| 419 |
-
label="AI
|
| 420 |
interactive=False,
|
| 421 |
container=True
|
| 422 |
)
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 429 |
|
| 430 |
with gr.Row():
|
| 431 |
chart_output = gr.Plot(
|
| 432 |
-
label="Price Chart & Predictions",
|
| 433 |
-
container=True
|
|
|
|
| 434 |
)
|
| 435 |
|
| 436 |
# Event handlers
|
|
@@ -443,19 +641,44 @@ with gr.Blocks(
|
|
| 443 |
decision_display,
|
| 444 |
current_price_display,
|
| 445 |
prediction_display
|
| 446 |
-
]
|
|
|
|
| 447 |
)
|
| 448 |
|
| 449 |
-
# Examples
|
|
|
|
| 450 |
gr.Examples(
|
| 451 |
examples=[
|
| 452 |
-
["AAPL", "π Chronos (Fast)", 5000],
|
| 453 |
-
["TSLA", "π― Moirai (
|
| 454 |
-
["GOOGL", "π Chronos (Fast)", 2500],
|
|
|
|
|
|
|
| 455 |
],
|
| 456 |
inputs=[stock_input, model_input, investment_input],
|
| 457 |
-
label="
|
|
|
|
| 458 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 459 |
|
|
|
|
| 460 |
if __name__ == "__main__":
|
| 461 |
-
demo.launch(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
torch.set_num_threads(min(4, os.cpu_count() or 1))
|
| 25 |
|
| 26 |
class FastAIStockAnalyzer:
|
| 27 |
+
"""Optimized AI Stock Analyzer for Gradio with robust error handling"""
|
| 28 |
|
| 29 |
def __init__(self):
|
| 30 |
self.context_length = 32
|
|
|
|
| 58 |
return None, None
|
| 59 |
|
| 60 |
def load_chronos_tiny(self) -> Tuple[Optional[Any], str]:
|
| 61 |
+
"""Load Chronos model with caching and fallback"""
|
| 62 |
model_key = "chronos_tiny"
|
| 63 |
|
| 64 |
if model_key in self.model_cache:
|
|
|
|
| 67 |
try:
|
| 68 |
from chronos import ChronosPipeline
|
| 69 |
|
| 70 |
+
# Try primary loading method
|
| 71 |
pipeline = ChronosPipeline.from_pretrained(
|
| 72 |
"amazon/chronos-t5-tiny",
|
| 73 |
device_map="cpu",
|
|
|
|
| 79 |
self.model_cache[model_key] = pipeline
|
| 80 |
return pipeline, "chronos"
|
| 81 |
|
| 82 |
+
except ImportError:
|
| 83 |
+
# Chronos not available
|
| 84 |
return None, None
|
| 85 |
+
except Exception as e:
|
| 86 |
+
# Try fallback loading method
|
| 87 |
+
try:
|
| 88 |
+
pipeline = ChronosPipeline.from_pretrained(
|
| 89 |
+
"amazon/chronos-t5-tiny",
|
| 90 |
+
device_map="auto",
|
| 91 |
+
torch_dtype=torch.float32
|
| 92 |
+
)
|
| 93 |
+
self.model_cache[model_key] = pipeline
|
| 94 |
+
return pipeline, "chronos"
|
| 95 |
+
except:
|
| 96 |
+
return None, None
|
| 97 |
|
| 98 |
def load_moirai_small(self) -> Tuple[Optional[Any], str]:
|
| 99 |
+
"""Load Moirai model with updated method and fallbacks"""
|
| 100 |
model_key = "moirai_small"
|
| 101 |
|
| 102 |
if model_key in self.model_cache:
|
|
|
|
| 105 |
try:
|
| 106 |
from uni2ts.model.moirai import MoiraiForecast, MoiraiModule
|
| 107 |
|
| 108 |
+
# Method 1: Try the standard approach
|
| 109 |
+
try:
|
| 110 |
+
module = MoiraiModule.from_pretrained(
|
| 111 |
+
"Salesforce/moirai-1.0-R-small",
|
| 112 |
+
device_map="cpu",
|
| 113 |
+
torch_dtype=torch.float32,
|
| 114 |
+
trust_remote_code=True,
|
| 115 |
+
low_cpu_mem_usage=True
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
model = MoiraiForecast(
|
| 119 |
+
module=module,
|
| 120 |
+
prediction_length=self.prediction_length,
|
| 121 |
+
context_length=self.context_length,
|
| 122 |
+
patch_size="auto",
|
| 123 |
+
num_samples=10,
|
| 124 |
+
target_dim=1,
|
| 125 |
+
feat_dynamic_real_dim=0,
|
| 126 |
+
past_feat_dynamic_real_dim=0
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
self.model_cache[model_key] = model
|
| 130 |
+
return model, "moirai"
|
| 131 |
+
|
| 132 |
+
except Exception as e1:
|
| 133 |
+
# Method 2: Try newer version
|
| 134 |
+
try:
|
| 135 |
+
module = MoiraiModule.from_pretrained(
|
| 136 |
+
"Salesforce/moirai-1.1-R-small",
|
| 137 |
+
device_map="cpu",
|
| 138 |
+
torch_dtype=torch.float32,
|
| 139 |
+
trust_remote_code=True
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
model = MoiraiForecast(
|
| 143 |
+
module=module,
|
| 144 |
+
prediction_length=self.prediction_length,
|
| 145 |
+
context_length=self.context_length,
|
| 146 |
+
patch_size="auto",
|
| 147 |
+
num_samples=5, # Reduced for stability
|
| 148 |
+
target_dim=1,
|
| 149 |
+
feat_dynamic_real_dim=0,
|
| 150 |
+
past_feat_dynamic_real_dim=0
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
self.model_cache[model_key] = model
|
| 154 |
+
return model, "moirai"
|
| 155 |
+
|
| 156 |
+
except Exception as e2:
|
| 157 |
+
# Method 3: Minimal configuration
|
| 158 |
+
try:
|
| 159 |
+
module = MoiraiModule.from_pretrained("Salesforce/moirai-1.0-R-small")
|
| 160 |
+
model = MoiraiForecast(
|
| 161 |
+
module=module,
|
| 162 |
+
prediction_length=7,
|
| 163 |
+
context_length=32,
|
| 164 |
+
patch_size="auto",
|
| 165 |
+
num_samples=5,
|
| 166 |
+
target_dim=1
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
self.model_cache[model_key] = model
|
| 170 |
+
return model, "moirai"
|
| 171 |
+
|
| 172 |
+
except Exception as e3:
|
| 173 |
+
return None, None
|
| 174 |
+
|
| 175 |
+
except ImportError:
|
| 176 |
+
# uni2ts not available
|
| 177 |
+
return None, None
|
| 178 |
except Exception as e:
|
| 179 |
return None, None
|
| 180 |
|
| 181 |
def predict_chronos_fast(self, pipeline: Any, data: np.ndarray) -> Optional[Dict]:
|
| 182 |
+
"""Fast Chronos prediction with error handling"""
|
| 183 |
try:
|
| 184 |
context_data = data[-self.context_length:]
|
| 185 |
context = torch.tensor(context_data, dtype=torch.float32).unsqueeze(0)
|
|
|
|
| 208 |
return None
|
| 209 |
|
| 210 |
def predict_moirai_fast(self, model: Any, data: np.ndarray) -> Optional[Dict]:
|
| 211 |
+
"""Fast Moirai prediction with enhanced error handling"""
|
| 212 |
try:
|
| 213 |
from gluonts.dataset.common import ListDataset
|
| 214 |
|
| 215 |
+
# Prepare dataset with minimal configuration
|
| 216 |
dataset = ListDataset([{
|
| 217 |
"item_id": "stock",
|
| 218 |
"start": "2023-01-01",
|
| 219 |
"target": data[-self.context_length:].tolist()
|
| 220 |
}], freq='D')
|
| 221 |
|
| 222 |
+
# Create predictor with safe parameters
|
| 223 |
predictor = model.create_predictor(
|
| 224 |
batch_size=1,
|
| 225 |
+
num_parallel_samples=5 # Further reduced for stability
|
| 226 |
)
|
| 227 |
|
| 228 |
+
# Generate forecast
|
| 229 |
forecasts = list(predictor.predict(dataset))
|
| 230 |
+
|
| 231 |
+
if not forecasts:
|
| 232 |
+
return None
|
| 233 |
+
|
| 234 |
forecast = forecasts[0]
|
| 235 |
|
| 236 |
predictions = {
|
| 237 |
'mean': forecast.mean,
|
| 238 |
'q10': forecast.quantile(0.1),
|
| 239 |
'q90': forecast.quantile(0.9),
|
| 240 |
+
'std': np.std(forecast.samples, axis=0) if hasattr(forecast, 'samples') else np.zeros(7)
|
| 241 |
}
|
| 242 |
|
| 243 |
return predictions
|
| 244 |
|
| 245 |
except Exception as e:
|
| 246 |
return None
|
| 247 |
+
|
| 248 |
+
def generate_simple_prediction(self, data: np.ndarray) -> Dict:
|
| 249 |
+
"""Fallback prediction method using simple statistical models"""
|
| 250 |
+
try:
|
| 251 |
+
# Simple moving average with trend
|
| 252 |
+
recent_data = data[-30:] # Last 30 days
|
| 253 |
+
short_ma = np.mean(recent_data[-7:]) # 7-day average
|
| 254 |
+
long_ma = np.mean(recent_data[-21:]) # 21-day average
|
| 255 |
+
|
| 256 |
+
# Calculate trend
|
| 257 |
+
trend = (short_ma - long_ma) / long_ma if long_ma != 0 else 0
|
| 258 |
+
|
| 259 |
+
# Generate predictions
|
| 260 |
+
current_price = data[-1]
|
| 261 |
+
predictions = []
|
| 262 |
+
|
| 263 |
+
for i in range(7):
|
| 264 |
+
# Simple trend projection with some noise
|
| 265 |
+
predicted_price = current_price * (1 + trend * (i + 1) * 0.1)
|
| 266 |
+
predictions.append(predicted_price)
|
| 267 |
+
|
| 268 |
+
predictions = np.array(predictions)
|
| 269 |
+
|
| 270 |
+
return {
|
| 271 |
+
'mean': predictions,
|
| 272 |
+
'q10': predictions * 0.95, # 5% lower
|
| 273 |
+
'q90': predictions * 1.05, # 5% higher
|
| 274 |
+
'std': np.full(7, np.std(recent_data) * 0.5)
|
| 275 |
+
}
|
| 276 |
+
|
| 277 |
+
except Exception:
|
| 278 |
+
# Ultimate fallback - flat prediction
|
| 279 |
+
current_price = data[-1]
|
| 280 |
+
return {
|
| 281 |
+
'mean': np.full(7, current_price),
|
| 282 |
+
'q10': np.full(7, current_price * 0.98),
|
| 283 |
+
'q90': np.full(7, current_price * 1.02),
|
| 284 |
+
'std': np.full(7, 0.01)
|
| 285 |
+
}
|
| 286 |
|
| 287 |
# Initialize analyzer globally for caching
|
| 288 |
analyzer = FastAIStockAnalyzer()
|
| 289 |
|
| 290 |
def analyze_stock(stock_symbol, model_choice, investment_amount, progress=gr.Progress()):
|
| 291 |
+
"""Main analysis function with comprehensive error handling and fallbacks"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 292 |
|
| 293 |
+
try:
|
| 294 |
+
progress(0.1, desc="Fetching stock data...")
|
| 295 |
+
|
| 296 |
+
# Validate input
|
| 297 |
+
if not stock_symbol or stock_symbol.strip() == "":
|
| 298 |
+
return (
|
| 299 |
+
"β Error: Please enter a valid stock symbol.",
|
| 300 |
+
None,
|
| 301 |
+
"β Invalid Input",
|
| 302 |
+
"N/A",
|
| 303 |
+
"N/A"
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
# Fetch data
|
| 307 |
+
stock_data, stock_info = analyzer.fetch_stock_data(stock_symbol.upper())
|
| 308 |
+
|
| 309 |
+
if stock_data is None or len(stock_data) < 50:
|
| 310 |
+
return (
|
| 311 |
+
f"β Error: Insufficient data for {stock_symbol.upper()}. Please check the stock symbol or try a different one.",
|
| 312 |
+
None,
|
| 313 |
+
"β Data Error",
|
| 314 |
+
"N/A",
|
| 315 |
+
"N/A"
|
| 316 |
+
)
|
| 317 |
+
|
| 318 |
+
current_price = stock_data['Close'].iloc[-1]
|
| 319 |
+
company_name = stock_info.get('longName', stock_symbol) if stock_info else stock_symbol
|
| 320 |
+
|
| 321 |
+
progress(0.3, desc="Loading AI model...")
|
| 322 |
+
|
| 323 |
+
# Determine model type and load
|
| 324 |
+
model_type = "chronos" if "Chronos" in model_choice else "moirai"
|
| 325 |
+
model = None
|
| 326 |
+
model_name = ""
|
| 327 |
+
prediction_method = None
|
| 328 |
+
|
| 329 |
+
if model_type == "chronos":
|
| 330 |
+
model, loaded_type = analyzer.load_chronos_tiny()
|
| 331 |
+
model_name = "Amazon Chronos Tiny"
|
| 332 |
+
prediction_method = "chronos"
|
| 333 |
+
else:
|
| 334 |
+
model, loaded_type = analyzer.load_moirai_small()
|
| 335 |
+
model_name = "Salesforce Moirai Small"
|
| 336 |
+
prediction_method = "moirai"
|
| 337 |
+
|
| 338 |
+
# Fallback to Chronos if Moirai fails
|
| 339 |
+
if model is None:
|
| 340 |
+
progress(0.4, desc="Moirai unavailable, switching to Chronos...")
|
| 341 |
+
model, loaded_type = analyzer.load_chronos_tiny()
|
| 342 |
+
model_name = "Amazon Chronos Tiny (Fallback)"
|
| 343 |
+
prediction_method = "chronos"
|
| 344 |
+
|
| 345 |
+
# If both models fail, use simple prediction
|
| 346 |
+
if model is None:
|
| 347 |
+
progress(0.5, desc="Using statistical fallback method...")
|
| 348 |
+
model_name = "Statistical Trend Model (Fallback)"
|
| 349 |
+
prediction_method = "simple"
|
| 350 |
+
|
| 351 |
+
progress(0.6, desc="Generating predictions...")
|
| 352 |
+
|
| 353 |
+
# Generate predictions based on available method
|
| 354 |
+
predictions = None
|
| 355 |
+
|
| 356 |
+
if prediction_method == "chronos" and model is not None:
|
| 357 |
+
predictions = analyzer.predict_chronos_fast(model, stock_data['Close'].values)
|
| 358 |
+
elif prediction_method == "moirai" and model is not None:
|
| 359 |
+
predictions = analyzer.predict_moirai_fast(model, stock_data['Close'].values)
|
| 360 |
+
|
| 361 |
+
# Use simple prediction if AI models fail
|
| 362 |
+
if predictions is None:
|
| 363 |
+
predictions = analyzer.generate_simple_prediction(stock_data['Close'].values)
|
| 364 |
+
model_name = "Statistical Trend Model (AI Models Unavailable)"
|
| 365 |
+
|
| 366 |
+
progress(0.8, desc="Calculating investment scenarios...")
|
| 367 |
+
|
| 368 |
+
# Analysis results
|
| 369 |
+
mean_pred = predictions['mean']
|
| 370 |
+
final_pred = mean_pred[-1]
|
| 371 |
+
week_change = ((final_pred - current_price) / current_price) * 100
|
| 372 |
+
|
| 373 |
+
# Decision logic
|
| 374 |
+
if week_change > 5:
|
| 375 |
+
decision = "π’ STRONG BUY"
|
| 376 |
+
explanation = "Model expects significant gains!"
|
| 377 |
+
elif week_change > 2:
|
| 378 |
+
decision = "π’ BUY"
|
| 379 |
+
explanation = "Model expects moderate gains"
|
| 380 |
+
elif week_change < -5:
|
| 381 |
+
decision = "π΄ STRONG SELL"
|
| 382 |
+
explanation = "Model expects significant losses"
|
| 383 |
+
elif week_change < -2:
|
| 384 |
+
decision = "π΄ SELL"
|
| 385 |
+
explanation = "Model expects losses"
|
| 386 |
+
else:
|
| 387 |
+
decision = "βͺ HOLD"
|
| 388 |
+
explanation = "Model expects stable prices"
|
| 389 |
+
|
| 390 |
+
# Create analysis text
|
| 391 |
+
analysis_text = f"""
|
| 392 |
+
# π― {company_name} ({stock_symbol.upper()}) Analysis
|
| 393 |
|
| 394 |
+
## π€ RECOMMENDATION: {decision}
|
| 395 |
**{explanation}**
|
| 396 |
*Powered by {model_name}*
|
| 397 |
|
| 398 |
## π Key Metrics
|
| 399 |
- **Current Price**: ${current_price:.2f}
|
| 400 |
- **7-Day Prediction**: ${final_pred:.2f} ({week_change:+.2f}%)
|
| 401 |
+
- **Confidence Level**: {min(100, max(50, 70 + abs(week_change) * 1.5)):.0f}%
|
| 402 |
+
- **Analysis Method**: {model_name}
|
| 403 |
|
| 404 |
## π° Investment Scenario (${investment_amount:,.0f})
|
| 405 |
- **Shares**: {investment_amount/current_price:.2f}
|
| 406 |
+
- **Current Value**: ${investment_amount:,.2f}
|
| 407 |
- **Predicted Value**: ${investment_amount + ((final_pred - current_price) * (investment_amount/current_price)):,.2f}
|
| 408 |
- **Profit/Loss**: ${((final_pred - current_price) * (investment_amount/current_price)):+,.2f} ({week_change:+.2f}%)
|
| 409 |
|
| 410 |
+
## β οΈ Important Disclaimers
|
| 411 |
+
- **This is for educational purposes only**
|
| 412 |
+
- **Not financial advice - consult professionals**
|
| 413 |
+
- **AI predictions can be wrong - invest responsibly**
|
| 414 |
+
- **Past performance β future results**
|
| 415 |
"""
|
| 416 |
+
|
| 417 |
+
progress(0.9, desc="Creating visualizations...")
|
| 418 |
+
|
| 419 |
+
# Create chart
|
| 420 |
+
fig = go.Figure()
|
| 421 |
+
|
| 422 |
+
# Historical data (last 30 days)
|
| 423 |
+
recent = stock_data.tail(30)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 424 |
fig.add_trace(go.Scatter(
|
| 425 |
+
x=recent.index,
|
| 426 |
+
y=recent['Close'],
|
| 427 |
+
mode='lines',
|
| 428 |
+
name='Historical Price',
|
| 429 |
+
line=dict(color='blue', width=2)
|
|
|
|
|
|
|
| 430 |
))
|
| 431 |
+
|
| 432 |
+
# Predictions
|
| 433 |
+
future_dates = pd.date_range(
|
| 434 |
+
start=stock_data.index[-1] + pd.Timedelta(days=1),
|
| 435 |
+
periods=7,
|
| 436 |
+
freq='D'
|
| 437 |
+
)
|
| 438 |
+
|
| 439 |
+
fig.add_trace(go.Scatter(
|
| 440 |
+
x=future_dates,
|
| 441 |
+
y=mean_pred,
|
| 442 |
+
mode='lines+markers',
|
| 443 |
+
name='Prediction',
|
| 444 |
+
line=dict(color='red', width=3),
|
| 445 |
+
marker=dict(size=8)
|
| 446 |
+
))
|
| 447 |
+
|
| 448 |
+
# Confidence bands
|
| 449 |
+
if 'q10' in predictions and 'q90' in predictions:
|
| 450 |
+
fig.add_trace(go.Scatter(
|
| 451 |
+
x=future_dates.tolist() + future_dates[::-1].tolist(),
|
| 452 |
+
y=predictions['q90'].tolist() + predictions['q10'][::-1].tolist(),
|
| 453 |
+
fill='toself',
|
| 454 |
+
fillcolor='rgba(255,0,0,0.1)',
|
| 455 |
+
line=dict(color='rgba(255,255,255,0)'),
|
| 456 |
+
name='Confidence Range',
|
| 457 |
+
showlegend=True
|
| 458 |
+
))
|
| 459 |
+
|
| 460 |
+
fig.update_layout(
|
| 461 |
+
title=f"{stock_symbol.upper()} - Stock Forecast ({model_name})",
|
| 462 |
+
xaxis_title="Date",
|
| 463 |
+
yaxis_title="Price ($)",
|
| 464 |
+
height=500,
|
| 465 |
+
showlegend=True,
|
| 466 |
+
template="plotly_white"
|
| 467 |
+
)
|
| 468 |
+
|
| 469 |
+
progress(1.0, desc="Analysis complete!")
|
| 470 |
+
|
| 471 |
+
# Create summary metrics
|
| 472 |
+
try:
|
| 473 |
+
day_change = stock_data['Close'].iloc[-1] - stock_data['Close'].iloc[-2]
|
| 474 |
+
day_change_pct = (day_change / stock_data['Close'].iloc[-2]) * 100
|
| 475 |
+
except:
|
| 476 |
+
day_change_pct = 0
|
| 477 |
+
|
| 478 |
+
current_metrics = f"${current_price:.2f} ({day_change_pct:+.2f}%)"
|
| 479 |
+
prediction_metrics = f"${final_pred:.2f} ({week_change:+.2f}%)"
|
| 480 |
+
|
| 481 |
+
return (
|
| 482 |
+
analysis_text,
|
| 483 |
+
fig,
|
| 484 |
+
decision,
|
| 485 |
+
current_metrics,
|
| 486 |
+
prediction_metrics
|
| 487 |
+
)
|
| 488 |
+
|
| 489 |
+
except Exception as e:
|
| 490 |
+
# Ultimate error handler
|
| 491 |
+
error_msg = f"""
|
| 492 |
+
# β Analysis Error
|
| 493 |
+
|
| 494 |
+
**Something went wrong during the analysis:**
|
| 495 |
+
|
| 496 |
+
- **Error**: {str(e)[:200]}...
|
| 497 |
+
- **Stock**: {stock_symbol}
|
| 498 |
+
- **Model**: {model_choice}
|
| 499 |
+
|
| 500 |
+
## π§ Try these solutions:
|
| 501 |
+
1. **Check stock symbol** - Make sure it's valid (e.g., AAPL, GOOGL)
|
| 502 |
+
2. **Try different model** - Switch between Chronos and Moirai
|
| 503 |
+
3. **Refresh and try again** - Temporary server issues
|
| 504 |
+
4. **Use popular stocks** - AAPL, MSFT, GOOGL work best
|
| 505 |
|
| 506 |
+
## π Still having issues?
|
| 507 |
+
This may be due to Hugging Face Spaces resource limitations or model availability.
|
| 508 |
+
"""
|
| 509 |
+
|
| 510 |
+
return (
|
| 511 |
+
error_msg,
|
| 512 |
+
None,
|
| 513 |
+
"β Error",
|
| 514 |
+
"N/A",
|
| 515 |
+
"N/A"
|
| 516 |
+
)
|
| 517 |
+
|
| 518 |
+
# Create Gradio interface with enhanced UI
|
| 519 |
with gr.Blocks(
|
| 520 |
theme=gr.themes.Soft(),
|
| 521 |
title="β‘ Fast AI Stock Predictor",
|
| 522 |
+
css="""
|
| 523 |
+
footer {visibility: hidden}
|
| 524 |
+
.gradio-container {max-width: 1200px; margin: auto;}
|
| 525 |
+
.main-header {text-align: center; padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; border-radius: 10px; margin-bottom: 20px;}
|
| 526 |
+
.disclaimer {background-color: #fff3cd; border: 1px solid #ffeaa7; padding: 15px; border-radius: 8px; margin: 10px 0;}
|
| 527 |
+
"""
|
| 528 |
) as demo:
|
| 529 |
|
| 530 |
gr.HTML("""
|
| 531 |
+
<div class="main-header">
|
| 532 |
+
<h1 style="margin: 0; font-size: 2.5em;">β‘ AI Stock Predictor</h1>
|
| 533 |
+
<p style="margin: 10px 0 0 0; font-size: 1.2em;"><strong>π€ Powered by Amazon Chronos & Salesforce Moirai</strong></p>
|
| 534 |
+
<p style="margin: 5px 0 0 0; opacity: 0.9;">Advanced AI models for stock price forecasting</p>
|
| 535 |
+
</div>
|
| 536 |
+
""")
|
| 537 |
+
|
| 538 |
+
gr.HTML("""
|
| 539 |
+
<div class="disclaimer">
|
| 540 |
+
<strong>β οΈ IMPORTANT DISCLAIMER:</strong> This tool is for educational purposes only.
|
| 541 |
+
Not financial advice. AI predictions can be wrong. Always consult financial professionals
|
| 542 |
+
before making investment decisions. Only invest what you can afford to lose.
|
| 543 |
</div>
|
| 544 |
""")
|
| 545 |
|
| 546 |
with gr.Row():
|
| 547 |
+
with gr.Column(scale=1, min_width=300):
|
| 548 |
+
gr.HTML("<h3>π― Analysis Configuration</h3>")
|
| 549 |
|
| 550 |
stock_input = gr.Dropdown(
|
| 551 |
+
choices=["AAPL", "GOOGL", "MSFT", "TSLA", "AMZN", "META", "NFLX", "NVDA", "ORCL", "CRM"],
|
| 552 |
value="AAPL",
|
| 553 |
+
label="π Select Stock Symbol",
|
| 554 |
allow_custom_value=True,
|
| 555 |
+
info="Choose popular stocks or enter any valid symbol"
|
| 556 |
)
|
| 557 |
|
| 558 |
model_input = gr.Radio(
|
| 559 |
+
choices=["π Chronos (Fast & Reliable)", "π― Moirai (Advanced)"],
|
| 560 |
+
value="π Chronos (Fast & Reliable)",
|
| 561 |
+
label="π€ AI Model Selection",
|
| 562 |
+
info="Chronos: Faster, more stable | Moirai: More sophisticated (may fallback to Chronos)"
|
| 563 |
)
|
| 564 |
|
| 565 |
investment_input = gr.Slider(
|
| 566 |
minimum=500,
|
| 567 |
+
maximum=100000,
|
| 568 |
value=5000,
|
| 569 |
step=500,
|
| 570 |
+
label="π° Investment Amount ($)",
|
| 571 |
+
info="Amount for profit/loss calculation"
|
| 572 |
)
|
| 573 |
|
| 574 |
analyze_btn = gr.Button(
|
| 575 |
+
"π Analyze Stock Now",
|
| 576 |
variant="primary",
|
| 577 |
+
size="lg",
|
| 578 |
+
scale=1
|
| 579 |
)
|
| 580 |
|
| 581 |
+
gr.HTML("<br>")
|
|
|
|
| 582 |
|
| 583 |
+
# Quick stats
|
| 584 |
+
with gr.Group():
|
| 585 |
+
gr.HTML("<h4>π Quick Metrics</h4>")
|
| 586 |
current_price_display = gr.Textbox(
|
| 587 |
label="Current Price",
|
| 588 |
interactive=False,
|
|
|
|
| 594 |
container=True
|
| 595 |
)
|
| 596 |
decision_display = gr.Textbox(
|
| 597 |
+
label="AI Recommendation",
|
| 598 |
interactive=False,
|
| 599 |
container=True
|
| 600 |
)
|
| 601 |
+
|
| 602 |
+
with gr.Column(scale=2, min_width=600):
|
| 603 |
+
gr.HTML("<h3>π Analysis Results</h3>")
|
| 604 |
+
|
| 605 |
+
analysis_output = gr.Markdown(
|
| 606 |
+
value="""
|
| 607 |
+
# π Welcome to AI Stock Predictor!
|
| 608 |
+
|
| 609 |
+
**Ready to analyze stocks with cutting-edge AI?**
|
| 610 |
+
|
| 611 |
+
π― **How to use:**
|
| 612 |
+
1. Select a stock symbol (or enter your own)
|
| 613 |
+
2. Choose AI model (Chronos recommended for beginners)
|
| 614 |
+
3. Set investment amount for scenario analysis
|
| 615 |
+
4. Click "Analyze Stock Now"
|
| 616 |
+
|
| 617 |
+
π‘ **Tips:**
|
| 618 |
+
- Try popular stocks like AAPL, GOOGL, MSFT first
|
| 619 |
+
- Chronos model is faster and more reliable
|
| 620 |
+
- Analysis takes 30-60 seconds for first-time model loading
|
| 621 |
+
|
| 622 |
+
β‘ **Click the button to get started!**
|
| 623 |
+
""",
|
| 624 |
+
container=True
|
| 625 |
+
)
|
| 626 |
|
| 627 |
with gr.Row():
|
| 628 |
chart_output = gr.Plot(
|
| 629 |
+
label="π Stock Price Chart & AI Predictions",
|
| 630 |
+
container=True,
|
| 631 |
+
show_label=True
|
| 632 |
)
|
| 633 |
|
| 634 |
# Event handlers
|
|
|
|
| 641 |
decision_display,
|
| 642 |
current_price_display,
|
| 643 |
prediction_display
|
| 644 |
+
],
|
| 645 |
+
show_progress=True
|
| 646 |
)
|
| 647 |
|
| 648 |
+
# Examples section
|
| 649 |
+
gr.HTML("<h3>π Try These Examples</h3>")
|
| 650 |
gr.Examples(
|
| 651 |
examples=[
|
| 652 |
+
["AAPL", "π Chronos (Fast & Reliable)", 5000],
|
| 653 |
+
["TSLA", "π― Moirai (Advanced)", 10000],
|
| 654 |
+
["GOOGL", "π Chronos (Fast & Reliable)", 2500],
|
| 655 |
+
["MSFT", "π― Moirai (Advanced)", 7500],
|
| 656 |
+
["NVDA", "π Chronos (Fast & Reliable)", 3000],
|
| 657 |
],
|
| 658 |
inputs=[stock_input, model_input, investment_input],
|
| 659 |
+
label="Click any example to load it:",
|
| 660 |
+
examples_per_page=5
|
| 661 |
)
|
| 662 |
+
|
| 663 |
+
# Footer
|
| 664 |
+
gr.HTML("""
|
| 665 |
+
<div style="text-align: center; padding: 20px; margin-top: 30px; border-top: 1px solid #eee;">
|
| 666 |
+
<p><strong>π€ AI Stock Predictor</strong> | Built with β€οΈ using Gradio & Hugging Face</p>
|
| 667 |
+
<p style="font-size: 12px; color: #666;">
|
| 668 |
+
Powered by Amazon Chronos & Salesforce Moirai |
|
| 669 |
+
Educational Tool - Not Financial Advice
|
| 670 |
+
</p>
|
| 671 |
+
</div>
|
| 672 |
+
""")
|
| 673 |
|
| 674 |
+
# Launch configuration
|
| 675 |
if __name__ == "__main__":
|
| 676 |
+
demo.launch(
|
| 677 |
+
share=True, # Set to True for public sharing
|
| 678 |
+
server_name="0.0.0.0",
|
| 679 |
+
server_port=7860,
|
| 680 |
+
show_error=True,
|
| 681 |
+
show_tips=True,
|
| 682 |
+
enable_queue=True,
|
| 683 |
+
max_threads=10
|
| 684 |
+
)
|