BaoKhuong committed on
Commit
cba485a
·
verified ·
1 Parent(s): 81039c9

Delete app1 - claude-4.1-opus.py

Browse files
Files changed (1) hide show
  1. app1 - claude-4.1-opus.py +0 -803
app1 - claude-4.1-opus.py DELETED
@@ -1,803 +0,0 @@
1
- import os
2
- import json
3
- import time
4
- import random
5
- from collections import defaultdict
6
- from datetime import date, datetime, timedelta
7
- import gradio as gr
8
- import pandas as pd
9
- import finnhub
10
- import google.generativeai as genai
11
- from datasets import load_dataset
12
- from transformers import AutoTokenizer, AutoModelForCausalLM
13
- from peft import PeftModel
14
- from io import StringIO
15
- import requests
16
- from requests.adapters import HTTPAdapter
17
- from urllib3.util.retry import Retry
18
- from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM, LlamaTokenizerFast
19
- from peft import PeftModel # 0.5.0
20
- import torch
21
-
22
# Quiet the gRPC logging emitted by the Google client libraries.
os.environ.update({"GRPC_VERBOSITY": "ERROR", "GRPC_TRACE": ""})

# Hide UserWarning / FutureWarning noise raised by third-party packages.
import warnings

for _noisy_category in (UserWarning, FutureWarning):
    warnings.filterwarnings("ignore", category=_noisy_category)
30
-
31
# ---------- CONFIGURATION ----------------------------------------------------


def _read_key_list(env_var: str) -> tuple[str, list[str]]:
    """Return the raw value of *env_var* and the API keys parsed from it.

    Each secret stores several keys separated by newlines; surrounding
    whitespace is stripped and blank lines are dropped.
    """
    raw = os.getenv(env_var, "")
    return raw, [line.strip() for line in raw.split("\n") if line.strip()]


GEMINI_MODEL = "gemini-2.5-pro"  # legacy, no longer used for generation
FIN_MODEL_ID = "TheFinAI/Fin-o1-14B"
# Case-insensitive so that "TRUE", "Yes", etc. enable local mode as well.
USE_LOCAL_FIN_MODEL = os.getenv("USE_LOCAL_FIN_MODEL", "0").strip().lower() in {"1", "true", "yes"}

# RapidAPI configuration
RAPIDAPI_HOST = "alpha-vantage.p.rapidapi.com"

# Every service reads its keys from a single secret (one key per line).
FINNHUB_KEYS_RAW, FINNHUB_KEYS = _read_key_list("FINNHUB_KEYS")
RAPIDAPI_KEYS_RAW, RAPIDAPI_KEYS = _read_key_list("RAPIDAPI_KEYS")
GOOGLE_API_KEYS_RAW, GOOGLE_API_KEYS = _read_key_list("GOOGLE_API_KEYS")

# Hugging Face Inference token
HF_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN", "").strip()

# Warn (but keep running) when a service has no key at all.
if not FINNHUB_KEYS:
    print("⚠️ Warning: No Finnhub API keys found in secrets")
if not RAPIDAPI_KEYS:
    print("⚠️ Warning: No RapidAPI keys found in secrets")
if not GOOGLE_API_KEYS:
    print("⚠️ Warning: No Google API keys found in secrets")

# Pick a random Google API key to start with (None when no key is available).
GOOGLE_API_KEY = random.choice(GOOGLE_API_KEYS) if GOOGLE_API_KEYS else None
78
-
79
# Startup banner: summarize which data sources and model back-ends are configured.
print("=" * 50)
print("🚀 FinRobot Forecaster Starting Up...")
print("=" * 50)
print(
    f"📊 Finnhub API: {len(FINNHUB_KEYS)} keys loaded"
    if FINNHUB_KEYS
    else "📊 Finnhub API: Not configured"
)
print(
    f"📈 RapidAPI Alpha Vantage: {RAPIDAPI_HOST} ({len(RAPIDAPI_KEYS)} keys loaded)"
    if RAPIDAPI_KEYS
    else "📈 RapidAPI Alpha Vantage: Not configured"
)
print(
    "🤖 HF Inference API: Token detected for Fin-o1-14B"
    if HF_TOKEN
    else "🤖 HF Inference API: No token found (set HUGGINGFACEHUB_API_TOKEN)"
)
print(f"🦙 LLM Model: {FIN_MODEL_ID} via HF Inference API")
# CUDA is only probed when local mode was requested.
if USE_LOCAL_FIN_MODEL:
    if torch.cuda.is_available():
        print("🧩 Local GPU mode requested and CUDA detected; will try local load of Fin-o1-14B")
    else:
        print("🧩 Local mode requested but CUDA not available; falling back to HF Inference API")
print("✅ Application started successfully!")
print("=" * 50)
102
-
103
# (Legacy) Google Generative AI configuration retained for backward compatibility
if GOOGLE_API_KEYS:
    try:
        genai.configure(api_key=GOOGLE_API_KEYS[0])
    except Exception as exc:
        # Best-effort legacy setup: report the failure instead of hiding it,
        # but never let it block application start-up.
        print(f"⚠️ Legacy Google Generative AI configuration failed: {exc}")

# Configure the Finnhub client (only when keys are available)
if FINNHUB_KEYS:
    # The first key is used for the initial module-level client.
    finnhub_client = finnhub.Client(api_key=FINNHUB_KEYS[0])
    print(f"✅ Finnhub configured with {len(FINNHUB_KEYS)} keys")
else:
    finnhub_client = None
    print("⚠️ Finnhub not configured - will use mock news data")
118
-
119
# Build a requests session with a retry strategy
def create_session():
    """Return a ``requests.Session`` whose HTTP and HTTPS adapters retry failures.

    The adapter is configured with ``Retry(total=3, backoff_factor=1)`` and
    forces a retry on status codes 429, 500, 502, 503 and 504.
    """
    retrying_adapter = HTTPAdapter(
        max_retries=Retry(
            total=3,
            backoff_factor=1,
            status_forcelist=[429, 500, 502, 503, 504],
        )
    )
    http = requests.Session()
    for scheme in ("http://", "https://"):
        http.mount(scheme, retrying_adapter)
    return http


# Shared module-level session
requests_session = create_session()
134
-
135
# Instruction text placed in front of the user prompt; it fixes the
# three-section layout expected in the model's answer.
SYSTEM_PROMPT = "\n".join(
    [
        "You are a seasoned stock-market analyst. "
        "Given recent company news and optional basic financials, "
        "return:",
        "[Positive Developments] – 2-4 bullets",
        "[Potential Concerns] – 2-4 bullets",
        "[Prediction & Analysis] – a one-week price outlook with rationale.",
    ]
)
143
-
144
- # ---------- UTILITY HELPERS ----------------------------------------
145
-
146
def today() -> str:
    """Return the current local date formatted as ``YYYY-MM-DD``."""
    return date.today().isoformat()
148
-
149
def n_weeks_before(date_string: str, n: int) -> str:
    """Shift a ``YYYY-MM-DD`` date back by *n* weeks and return it in the same format.

    A negative *n* moves the date forward instead.
    """
    anchor = datetime.strptime(date_string, "%Y-%m-%d")
    shifted = anchor - timedelta(weeks=n)
    return shifted.strftime("%Y-%m-%d")
152
-
153
- # ---------- DATA FETCHING --------------------------------------------------
154
-
155
def _summarize_price_windows(prices: pd.DataFrame, symbol: str, steps: list[str]) -> pd.DataFrame:
    """Collapse daily prices into one start/end row per consecutive pair of dates in *steps*.

    *prices* must be indexed by date and sorted ascending; the ``close``
    column supplies the prices. Raises ``RuntimeError`` when a window
    contains no rows.
    """
    data = {"Start Date": [], "End Date": [], "Start Price": [], "End Price": []}
    for window_start, window_end in zip(steps, steps[1:]):
        seg = prices.loc[pd.to_datetime(window_start):pd.to_datetime(window_end)]
        if seg.empty:
            raise RuntimeError(
                f"RapidAPI Alpha Vantage cannot get {symbol} data for {window_start} – {window_end}"
            )
        data["Start Date"].append(seg.index[0])
        data["Start Price"].append(seg["close"].iloc[0])
        data["End Date"].append(seg.index[-1])
        data["End Price"].append(seg["close"].iloc[-1])
    return pd.DataFrame(data)


def get_stock_data(symbol: str, steps: list[str]) -> pd.DataFrame:
    """Fetch daily prices for *symbol* and summarize them per window of *steps*.

    Every key in ``RAPIDAPI_KEYS`` is tried in turn, with up to three
    attempts per key, against the Alpha Vantage ``TIME_SERIES_DAILY`` CSV
    endpoint on RapidAPI. The result has the columns ``Start Date``,
    ``End Date``, ``Start Price`` and ``End Price``, one row per consecutive
    pair of dates in *steps*.

    If no key yields usable data, the result of ``create_mock_stock_data``
    is returned instead. Keys are identified in log output by their position
    only, so no part of a secret ends up in the logs.
    """
    url = f"https://{RAPIDAPI_HOST}/query"
    params = {
        "function": "TIME_SERIES_DAILY",
        "symbol": symbol,
        "outputsize": "full",
        "datatype": "csv",
    }

    # Try every RapidAPI Alpha Vantage key.
    for key_no, rapidapi_key in enumerate(RAPIDAPI_KEYS, start=1):
        key_label = f"key #{key_no}"
        try:
            print(f"📈 Fetching stock data for {symbol} via RapidAPI ({key_label})")
            headers = {
                "X-RapidAPI-Host": RAPIDAPI_HOST,
                "X-RapidAPI-Key": rapidapi_key,
            }

            # Up to three attempts with the current key.
            for attempt in range(3):
                try:
                    resp = requests_session.get(url, headers=headers, params=params, timeout=30)
                    if not resp.ok:
                        print(f"RapidAPI HTTP error {resp.status_code} with {key_label}, attempt {attempt + 1}")
                        time.sleep(2 ** attempt)
                        continue

                    text = resp.text.strip()
                    # A JSON body instead of CSV carries an error or quota notice.
                    if text.startswith("{"):
                        info = resp.json()
                        msg = str(
                            info.get("Note")
                            or info.get("Error Message")
                            or info.get("Information")
                            or info
                        )
                        if "rate limit" in msg.lower() or "quota" in msg.lower():
                            print(f"RapidAPI rate limit hit with {key_label}, trying next key")
                            break  # move on to the next key
                        raise RuntimeError(f"RapidAPI Alpha Vantage Error: {msg}")

                    # Parse the CSV payload and index it by date.
                    daily = pd.read_csv(StringIO(text))
                    date_col = "timestamp" if "timestamp" in daily.columns else daily.columns[0]
                    daily[date_col] = pd.to_datetime(daily[date_col])
                    daily = daily.sort_values(date_col).set_index(date_col)

                    # Purely local work: no request is made here, so no throttling sleep.
                    table = _summarize_price_windows(daily, symbol, steps)
                    print(f"✅ Successfully retrieved {symbol} data via RapidAPI ({key_label})")
                    return table

                except requests.exceptions.Timeout:
                    print(f"RapidAPI timeout with {key_label}, attempt {attempt + 1}")
                    if attempt < 2:
                        time.sleep(5 * (attempt + 1))
                except requests.exceptions.RequestException as e:
                    print(f"RapidAPI request error with {key_label}, attempt {attempt + 1}: {e}")
                    if attempt < 2:
                        time.sleep(3)

        except Exception as e:
            # Any other failure (API error message, empty window, parse error):
            # fall through to the next key.
            print(f"RapidAPI Alpha Vantage failed with {key_label}: {e}")

    # Fallback: fabricate data when every RapidAPI key failed.
    print("⚠️ All RapidAPI keys failed, using mock data for demonstration...")
    return create_mock_stock_data(symbol, steps)
240
-
241
def create_mock_stock_data(symbol: str, steps: list[str]) -> pd.DataFrame:
    """Fabricate start/end prices for demos when the price API is unavailable.

    One row is produced per consecutive pair of dates in *steps*. Prices are
    random draws around a per-symbol base price, and each window starts from
    the previous window's (unrounded) end price.
    """
    import numpy as np

    # Different starting levels for a handful of well-known symbols.
    starting_levels = {
        "AAPL": 180.0, "MSFT": 350.0, "GOOGL": 140.0,
        "TSLA": 200.0, "NVDA": 450.0, "AMZN": 150.0,
    }
    level = starting_levels.get(symbol.upper(), 150.0)

    start_dates, end_dates, start_prices, end_prices = [], [], [], []
    for window_start, window_end in zip(steps, steps[1:]):
        opened_on = pd.to_datetime(window_start)
        closed_on = pd.to_datetime(window_end)

        # Random prices with a slight upward drift.
        opening = level + np.random.normal(0, 5)
        closing = opening + np.random.normal(2, 8)

        start_dates.append(opened_on)
        start_prices.append(round(opening, 2))
        end_dates.append(closed_on)
        end_prices.append(round(closing, 2))

        level = closing  # the next window continues from this end price

    return pd.DataFrame(
        {
            "Start Date": start_dates,
            "End Date": end_dates,
            "Start Price": start_prices,
            "End Price": end_prices,
        }
    )
270
-
271
def current_basics(symbol: str, curday: str) -> dict:
    """Return the latest quarterly basic financials reported on or before *curday*.

    Each key in ``FINNHUB_KEYS`` is tried in turn. The result maps metric
    names to values and additionally holds the reporting period under
    ``"period"``. An empty dict is returned when Finnhub is not configured,
    when no quarterly series is available, or when every key fails.

    Periods are compared to *curday* as strings, so both are assumed to be
    ISO ``YYYY-MM-DD`` dates (TODO confirm against the Finnhub response).
    """
    if not FINNHUB_KEYS:
        print(f"⚠️ Finnhub not configured, skipping financial basics for {symbol}")
        return {}

    # Try every Finnhub API key; keys are logged by position only so that
    # no part of a secret reaches the logs.
    for key_no, api_key in enumerate(FINNHUB_KEYS, start=1):
        try:
            client = finnhub.Client(api_key=api_key)
            raw = client.company_basic_financials(symbol, "all")
            # Missing or empty series means "no data", not a failed key.
            quarterly = (raw.get("series") or {}).get("quarterly") or {}
            if not quarterly:
                continue

            # Regroup metric -> [{period, v}, ...] into period -> {metric: value}.
            merged = defaultdict(dict)
            for metric, vals in quarterly.items():
                for v in vals:
                    merged[v["period"]][metric] = v["v"]

            latest = max((p for p in merged if p <= curday), default=None)
            if latest is None:
                continue
            return {**merged[latest], "period": latest}
        except Exception as e:
            print(f"Error getting basics for {symbol} with Finnhub key #{key_no}: {e}")
            time.sleep(2)  # short pause before trying the next key
    return {}
301
-
302
def attach_news(symbol: str, df: pd.DataFrame) -> pd.DataFrame:
    """Add a ``News`` column holding JSON-encoded news items for every row of *df*.

    For each row, news between ``Start Date`` and ``End Date`` is requested
    from Finnhub, trying every key in ``FINNHUB_KEYS`` until one succeeds.
    Items are reduced to ``date`` / ``headline`` / ``summary`` and sorted by
    date. Mock news from ``create_mock_news`` is used when Finnhub is not
    configured, when every key fails, or when no news is returned.

    *df* is modified in place and also returned.
    """
    news_col = []
    for _, row in df.iterrows():
        start = row["Start Date"].strftime("%Y-%m-%d")
        end = row["End Date"].strftime("%Y-%m-%d")

        if not FINNHUB_KEYS:
            # No request is made on this path, so there is nothing to throttle.
            print(f"⚠️ Finnhub not configured, using mock news for {symbol}")
            news_col.append(json.dumps(create_mock_news(symbol, start, end)))
            continue

        time.sleep(2)  # space out real requests to stay under the rate limit

        # Try every Finnhub API key; keys are logged by position only.
        news_data = []
        for key_no, api_key in enumerate(FINNHUB_KEYS, start=1):
            try:
                client = finnhub.Client(api_key=api_key)
                weekly = client.company_news(symbol, _from=start, to=end)
                news_data = sorted(
                    (
                        {
                            "date": datetime.fromtimestamp(n["datetime"]).strftime("%Y%m%d%H%M%S"),
                            "headline": n["headline"],
                            "summary": n["summary"],
                        }
                        for n in weekly
                    ),
                    key=lambda item: item["date"],
                )
                break  # success, stop trying keys
            except Exception as e:
                print(f"Error with Finnhub key #{key_no} for {symbol} from {start} to {end}: {e}")
                time.sleep(3)  # short pause before trying the next key

        # Fall back to mock news when nothing was retrieved.
        if not news_data:
            news_data = create_mock_news(symbol, start, end)

        news_col.append(json.dumps(news_data))
    df["News"] = news_col
    return df
345
-
346
def create_mock_news(symbol: str, start: str, end: str) -> list:
    """Build two placeholder news items for *symbol* when the news API is unavailable.

    *start* and *end* are ``YYYY-MM-DD`` strings. The ``date`` fields use the
    same compact ``%Y%m%d%H%M%S`` layout as real news items produced by
    ``attach_news`` (12:00:00 on *start*, 09:00:00 on *end*).
    """
    # Drop the dashes so mock timestamps match the real-news timestamp format.
    start_stamp = start.replace("-", "")
    end_stamp = end.replace("-", "")
    return [
        {
            "date": f"{start_stamp}120000",
            "headline": f"{symbol} Shows Strong Performance in Recent Trading",
            "summary": f"Company {symbol} has demonstrated resilience in the current market conditions with positive investor sentiment.",
        },
        {
            "date": f"{end_stamp}090000",
            "headline": f"Analysts Maintain Positive Outlook for {symbol}",
            "summary": f"Financial analysts continue to recommend {symbol} based on strong fundamentals and growth prospects.",
        },
    ]
361
-
362
- # ---------- PROMPT CONSTRUCTION -------------------------------------------
363
-
364
def sample_news(news: list[str], k: int = 5) -> list[str]:
    """Return at most *k* randomly chosen items of *news*, kept in their original order.

    When *news* has *k* items or fewer, the list itself is returned unchanged.
    """
    total = len(news)
    if total <= k:
        return news
    chosen = random.sample(range(total), k)
    chosen.sort()
    return [news[position] for position in chosen]
368
-
369
def make_prompt(symbol: str, df: pd.DataFrame, curday: str, use_basics=False) -> str:
    """Assemble the user prompt: company intro, weekly price/news recap, basics, and task.

    *df* needs the columns ``Start Date``, ``End Date``, ``Start Price``,
    ``End Price`` and ``News`` (a JSON list of items with ``headline`` and
    ``summary``). At most five randomly sampled news items are quoted per
    row. When *use_basics* is true, ``current_basics`` supplies the
    financials section. The closing instruction asks for a prediction for
    the week starting at *curday*.
    """
    # Generic introduction, replaced below when a Finnhub profile can be fetched.
    company_blurb = f"[Company Introduction]:\n{symbol} is a publicly traded company.\n"

    if FINNHUB_KEYS:
        # Try every Finnhub API key; keys are logged by position only so that
        # no part of a secret reaches the logs.
        for key_no, api_key in enumerate(FINNHUB_KEYS, start=1):
            try:
                client = finnhub.Client(api_key=api_key)
                prof = client.company_profile2(symbol=symbol)
                company_blurb = (
                    f"[Company Introduction]:\n{prof['name']} operates in the "
                    f"{prof['finnhubIndustry']} sector ({prof['country']}). "
                    f"Founded {prof['ipo']}, market cap {prof['marketCapitalization']:.1f} "
                    f"{prof['currency']}; ticker {symbol} on {prof['exchange']}.\n"
                )
                break  # success, stop trying keys
            except Exception as e:
                print(f"Error getting company profile for {symbol} with Finnhub key #{key_no}: {e}")
                time.sleep(2)  # short pause before trying the next key
    else:
        print(f"⚠️ Finnhub not configured, using basic company info for {symbol}")

    # Past weeks block
    past_parts = []
    for _, row in df.iterrows():
        term = "increased" if row["End Price"] > row["Start Price"] else "decreased"
        head = (f"From {row['Start Date']:%Y-%m-%d} to {row['End Date']:%Y-%m-%d}, "
                f"{symbol}'s stock price {term} from "
                f"{row['Start Price']:.2f} to {row['End Price']:.2f}.")
        news_items = json.loads(row["News"])
        summaries = [
            f"[Headline] {n['headline']}\n[Summary] {n['summary']}\n"
            for n in news_items
            # Skip items whose summary starts with this boilerplate text.
            if not n["summary"].startswith("Looking for stock market analysis")
        ]
        past_parts.append("\n" + head + "\n" + "".join(sample_news(summaries, 5)))
    past_block = "".join(past_parts)

    # Optional basic financials
    if use_basics:
        basics = current_basics(symbol, curday)
        if basics:
            basics_txt = "\n".join(f"{k}: {v}" for k, v in basics.items() if k != "period")
            basics_block = f"\n[Basic Financials] (reported {basics['period']}):\n{basics_txt}\n"
        else:
            basics_block = "\n[Basic Financials]: not available\n"
    else:
        basics_block = "\n[Basic Financials]: not requested\n"

    # n = -1 yields the date one week after curday.
    horizon = f"{curday} to {n_weeks_before(curday, -1)}"
    final_user_msg = (
        company_blurb
        + past_block
        + basics_block
        + f"\nBased on all information before {curday}, analyse positive "
        # This fragment needs the f prefix too; otherwise the model receives a literal "{symbol}".
        f"developments and potential concerns for {symbol}, then predict its "
        f"price movement for next week ({horizon})."
    )
    return final_user_msg
428
-
429
- # ---------- LLM CALL -------------------------------------------------------
430
-
431
def _extract_generated_text(data):
    """Return the generated text contained in a decoded HF Inference API response.

    Handled shapes: ``[{"generated_text": ...}]``, ``{"generated_text": ...}``,
    a ``"text"`` entry in either shape, or a ``"data"`` entry in a dict.
    Anything else is returned as ``str(data)``.
    """
    first = data[0] if isinstance(data, list) and data else None
    if isinstance(first, dict):
        if "generated_text" in first:
            return first["generated_text"].strip()
        candidate = first.get("text") or first.get("generated_text")
    elif isinstance(data, dict):
        if "generated_text" in data:
            return str(data["generated_text"]).strip()
        candidate = data.get("text") or data.get("data")
    else:
        candidate = None

    if isinstance(candidate, str):
        return candidate.strip()
    # Fallback: stringify whatever came back.
    return str(data)


def chat_completion(prompt: str,
                    model: str = FIN_MODEL_ID,
                    temperature: float = 0.2,
                    stream: bool = False,
                    symbol: str = "STOCK") -> str:
    """Generate the analysis text for *prompt* and return it as a string.

    Order of preference: local GPU inference (when ``USE_LOCAL_FIN_MODEL`` is
    set and CUDA is available), then the Hugging Face Inference API with up
    to four attempts, and finally ``create_mock_ai_response(symbol)`` when no
    token is configured or every attempt fails. *temperature* is clamped to
    ``[0, 1]`` for the API call. *stream* is accepted but not used here.
    """
    # Prefer local GPU inference if requested and available.
    if USE_LOCAL_FIN_MODEL and torch.cuda.is_available():
        try:
            local_text = _local_generate_with_fin_model(prompt, model, temperature)
        except Exception as e:
            print(f"Local GPU inference failed: {e}. Falling back to HF Inference API...")
        else:
            if isinstance(local_text, str) and local_text.strip():
                return local_text.strip()

    # The Hugging Face Inference API needs a token.
    if not HF_TOKEN:
        print(f"⚠️ HF token missing, using mock response for {symbol}")
        return create_mock_ai_response(symbol)

    endpoint = f"https://api-inference.huggingface.co/models/{model}"
    request_headers = {
        "Authorization": f"Bearer {HF_TOKEN}",
        "Accept": "application/json",
        "Content-Type": "application/json",
    }
    generation_params = {
        "max_new_tokens": 1024,
        "temperature": max(0.0, min(1.0, float(temperature))),
        "top_p": 0.9,
        "repetition_penalty": 1.05,
        "return_full_text": False,
    }
    payload = {
        "inputs": f"{SYSTEM_PROMPT}\n\n{prompt}",
        "parameters": generation_params,
        "options": {"use_cache": True, "wait_for_model": True},
    }

    # Retry loop; a 503 means the model is still loading.
    for attempt in range(4):
        try:
            resp = requests_session.post(
                endpoint, headers=request_headers, data=json.dumps(payload), timeout=120
            )
            if resp.status_code == 503:
                try:
                    wait_s = float(resp.json().get("estimated_time", 5.0))
                except Exception:
                    wait_s = 5.0
                print(f"Model loading (503). Waiting {wait_s:.1f}s before retry...")
                time.sleep(min(wait_s + attempt, 15))
                continue
            if not resp.ok:
                print(f"HF API error {resp.status_code}: {resp.text[:200]}")
                time.sleep(1 + attempt)
                continue

            return _extract_generated_text(resp.json())
        except requests.exceptions.RequestException as e:
            print(f"HF request error (attempt {attempt+1}): {e}")
            time.sleep(1 + attempt)
            continue
        except Exception as e:
            # Unexpected failure (e.g. while decoding the response): stop retrying.
            print(f"HF unknown error: {e}")
            break

    print("⚠️ All HF attempts failed, using mock AI response for demonstration...")
    return create_mock_ai_response(symbol)
513
-
514
-
515
- # ---------- LOCAL GPU INFERENCE (optional) ---------------------------------
516
-
517
# Process-wide cache so the tokenizer and model are loaded at most once.
_LOCAL_FIN_TOKENIZER = None
_LOCAL_FIN_MODEL = None


def _load_local_fin_model(model_id: str):
    """Load (or reuse) the tokenizer and causal LM for *model_id*; return ``(tokenizer, model)``.

    4-bit NF4 quantization is used when ``bitsandbytes`` can be imported;
    otherwise the model is loaded in bfloat16. Both paths pass
    ``device_map="auto"`` and ``trust_remote_code=True``.
    """
    global _LOCAL_FIN_TOKENIZER, _LOCAL_FIN_MODEL

    already_loaded = _LOCAL_FIN_MODEL is not None and _LOCAL_FIN_TOKENIZER is not None
    if already_loaded:
        return _LOCAL_FIN_TOKENIZER, _LOCAL_FIN_MODEL

    print(f"Loading local model {model_id} ...")
    from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

    # Options shared by the quantized and the bf16 load path.
    load_kwargs = {"device_map": "auto", "trust_remote_code": True}
    try:
        import bitsandbytes as bnb  # noqa: F401
        four_bit = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
    except Exception:
        print("bitsandbytes not available; trying bf16 with accelerate device_map=auto")
        load_kwargs["torch_dtype"] = torch.bfloat16
    else:
        print("Using 4-bit quantization via bitsandbytes")
        load_kwargs["quantization_config"] = four_bit

    _LOCAL_FIN_TOKENIZER = AutoTokenizer.from_pretrained(model_id, use_fast=True)
    _LOCAL_FIN_MODEL = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)

    # Switching to eval mode is best-effort; a failure here is ignored.
    try:
        _LOCAL_FIN_MODEL.eval()
    except Exception:
        pass
    return _LOCAL_FIN_TOKENIZER, _LOCAL_FIN_MODEL
562
-
563
-
564
def _local_generate_with_fin_model(user_prompt: str, model_id: str, temperature: float) -> str:
    """Generate a completion for *user_prompt* with the locally loaded model.

    ``SYSTEM_PROMPT`` is placed in front of *user_prompt*, up to 1024 new
    tokens are generated, and only the newly generated part is decoded and
    returned. *temperature* is clamped to ``[0, 1]``; a value of 0 selects
    greedy decoding.
    """
    tokenizer, model = _load_local_fin_model(model_id)
    full_prompt = f"{SYSTEM_PROMPT}\n\n{user_prompt}"
    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)

    clamped_temperature = float(max(0.0, min(1.0, temperature)))
    generation_kwargs = {
        "max_new_tokens": 1024,
        "repetition_penalty": 1.05,
        "eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": tokenizer.eos_token_id,
    }
    if clamped_temperature > 0.0:
        generation_kwargs.update(do_sample=True, temperature=clamped_temperature, top_p=0.9)
    else:
        # Sampling requires a strictly positive temperature, so a temperature
        # of 0 is mapped to greedy decoding instead of failing in generate().
        generation_kwargs["do_sample"] = False

    with torch.no_grad():
        output_ids = model.generate(**inputs, **generation_kwargs)

    # Drop the prompt tokens; keep only what the model appended.
    generated = output_ids[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(generated, skip_special_tokens=True)
582
-
583
def create_mock_ai_response(symbol: str) -> str:
    """Return a canned three-section analysis for *symbol* for use when no model response is available.

    The text starts and ends with a newline and closes with a note that it
    is a demonstration response built from mock data.
    """
    positive_points = [
        f"• Strong market position and brand recognition for {symbol}",
        "• Recent quarterly earnings showing growth potential",
        "• Positive analyst sentiment and institutional investor interest",
        "• Technological innovation and market expansion opportunities",
    ]
    concern_points = [
        "• Market volatility and economic uncertainty",
        "• Competitive pressures in the industry",
        "• Regulatory changes that may impact operations",
        "• Global economic factors affecting stock performance",
    ]
    outlook = (
        f"Based on the current market conditions and company fundamentals, {symbol} is expected to show moderate growth over the next week. "
        "The stock may experience some volatility but should maintain an upward trend with a potential price increase of 2-5%. "
        "This prediction is based on current market sentiment and technical analysis patterns."
    )
    note = "Note: This is a demonstration response using mock data. For real investment decisions, please consult with qualified financial professionals."

    # The empty first and last entries yield the leading and trailing newline.
    lines = [
        "",
        "[Positive Developments]",
        *positive_points,
        "",
        "[Potential Concerns]",
        *concern_points,
        "",
        "[Prediction & Analysis]",
        outlook,
        "",
        note,
        "",
    ]
    return "\n".join(lines)
603
-
604
- # ---------- MAIN PREDICTION FUNCTION -----------------------------------------
605
-
606
def predict(symbol: str = "AAPL",
            curday: "str | None" = None,
            n_weeks: int = 3,
            use_basics: bool = False,
            stream: bool = False) -> tuple[str, str]:
    """Run the full pipeline for *symbol* and return ``(prompt, answer)``.

    Price data is fetched for the *n_weeks* weekly windows ending at
    *curday* (``YYYY-MM-DD``), news is attached, the prompt is built and the
    model is queried. *curday* defaults to the current date, resolved on
    every call. Any exception is caught and reported as the same error
    message in both tuple positions.
    """
    try:
        if curday is None:
            # Resolved per call: a default of today() in the signature would be
            # evaluated once at import and go stale in a long-running process.
            curday = today()

        # Oldest date first: curday - n_weeks weeks, ..., curday.
        steps = [n_weeks_before(curday, n) for n in range(n_weeks + 1)][::-1]
        df = get_stock_data(symbol, steps)
        df = attach_news(symbol, df)

        prompt_info = make_prompt(symbol, df, curday, use_basics)
        answer = chat_completion(prompt_info, stream=stream, symbol=symbol)

        return prompt_info, answer
    except Exception as e:
        error_msg = f"Error in prediction: {str(e)}"
        print(f"Prediction error: {e}")  # log the error for debugging
        return error_msg, error_msg
624
-
625
- # ---------- HUGGINGFACE SPACES INTERFACE -----------------------------------------
626
-
627
def hf_predict(symbol, n_weeks, use_basics):
    """Gradio callback: run ``predict`` for today's date and return ``(prompt, answer)``.

    *symbol* is stripped of surrounding whitespace and upper-cased. An empty
    symbol is rejected up front with the same error message in both outputs,
    matching the error convention of ``predict``.
    """
    ticker = (symbol or "").strip().upper()
    if not ticker:
        error_msg = "Error in prediction: no stock symbol provided"
        return error_msg, error_msg

    # 1. get curday
    curday = date.today().strftime("%Y-%m-%d")
    # 2. call predict
    prompt, answer = predict(
        symbol=ticker,
        curday=curday,
        n_weeks=int(n_weeks),
        use_basics=bool(use_basics),
        stream=False,
    )
    return prompt, answer
639
-
640
- # ---------- GRADIO INTERFACE -----------------------------------------
641
-
642
def create_interface():
    """Build the Gradio Blocks UI and return it (without launching).

    Layout: an intro text, an input column (symbol, number of weeks,
    basic-financials checkbox, run button), an output column with two tabs
    (analysis result and generated prompt), example inputs and a disclaimer.
    The run button calls ``hf_predict``.
    """
    page_css = """
    .gradio-container {
        max-width: 1200px !important;
        margin: auto !important;
    }
    #model_prompt_textbox textarea {
        overflow-y: auto !important;
        max-height: none !important;
        min-height: 400px !important;
        resize: vertical !important;
        white-space: pre-wrap !important;
        word-wrap: break-word !important;
        height: auto !important;
    }
    #model_prompt_textbox {
        height: auto !important;
    }
    #analysis_results_textbox textarea {
        overflow-y: auto !important;
        max-height: none !important;
        min-height: 400px !important;
        resize: vertical !important;
        white-space: pre-wrap !important;
        word-wrap: break-word !important;
        height: auto !important;
    }
    #analysis_results_textbox {
        height: auto !important;
    }
    .textarea textarea {
        overflow-y: auto !important;
        max-height: 500px !important;
        resize: vertical !important;
    }
    .textarea {
        height: auto !important;
        min-height: 300px !important;
    }
    .gradio-textbox {
        height: auto !important;
        max-height: none !important;
    }
    .gradio-textbox textarea {
        height: auto !important;
        max-height: none !important;
        overflow-y: auto !important;
    }
    """

    intro_markdown = """
        # 🤖 FinRobot Forecaster

        **AI-powered stock market analysis and prediction using Fin-o1-14B**

        This application analyzes stock market data, company news, and financial metrics to provide comprehensive market insights and predictions.

        • Model: **TheFinAI/Fin-o1-14B** (Qwen3-14B finetune) via Hugging Face Inference API
        • Set secret **HUGGINGFACEHUB_API_TOKEN** in your Space for real responses

        ⚠️ **Note**: Free data APIs have daily rate limits. If you encounter errors, the app may use mock data for demonstration purposes.
        """

    disclaimer_markdown = """
        ---
        **Disclaimer**: This application is for educational and research purposes only.
        The predictions and analysis provided should not be considered as financial advice.
        Always consult with qualified financial professionals before making investment decisions.
        """

    example_rows = [
        ["AAPL", 3, False],
        ["MSFT", 4, True],
        ["GOOGL", 2, False],
        ["TSLA", 5, True],
        ["NVDA", 3, True],
    ]

    # Options shared by both read-only result text boxes.
    result_box_options = dict(
        label="",
        lines=40,
        show_copy_button=True,
        interactive=False,
        container=True,
        scale=1,
    )

    with gr.Blocks(
        title="FinRobot Forecaster",
        theme=gr.themes.Soft(),
        css=page_css,
    ) as demo:
        gr.Markdown(intro_markdown)

        with gr.Row():
            # Left column: inputs
            with gr.Column(scale=1):
                symbol = gr.Textbox(
                    label="Stock Symbol",
                    value="AAPL",
                    placeholder="Enter stock symbol (e.g., AAPL, MSFT, GOOGL)",
                    info="Enter the ticker symbol of the stock you want to analyze",
                )
                n_weeks = gr.Slider(
                    1, 6,
                    value=3,
                    step=1,
                    label="Historical Weeks to Analyze",
                    info="Number of weeks of historical data to include in analysis",
                )
                use_basics = gr.Checkbox(
                    label="Include Basic Financials",
                    value=True,
                    info="Include basic financial metrics in the analysis",
                )
                btn = gr.Button("🚀 Run Analysis", variant="primary")

            # Right column: outputs, one tab each
            with gr.Column(scale=2):
                with gr.Tabs():
                    with gr.Tab("📊 Analysis Results"):
                        gr.Markdown("**AI Analysis & Prediction**")
                        output_answer = gr.Textbox(
                            placeholder="AI analysis and predictions will appear here...",
                            elem_id="analysis_results_textbox",
                            **result_box_options,
                        )
                    with gr.Tab("🔍 Model Prompt"):
                        gr.Markdown("**Generated Prompt**")
                        output_prompt = gr.Textbox(
                            placeholder="Generated prompt will appear here...",
                            elem_id="model_prompt_textbox",
                            **result_box_options,
                        )

        form_inputs = [symbol, n_weeks, use_basics]

        # Examples
        gr.Examples(
            examples=example_rows,
            inputs=form_inputs,
            label="💡 Try these examples",
        )

        # Event handlers: hf_predict returns (prompt, answer) in that order.
        btn.click(
            fn=hf_predict,
            inputs=form_inputs,
            outputs=[output_prompt, output_answer],
            show_progress=True,
        )

        # Footer
        gr.Markdown(disclaimer_markdown)

    return demo
791
-
792
- # ---------- MAIN EXECUTION -----------------------------------------
793
-
794
if __name__ == "__main__":
    # Serve on all interfaces at port 7860 without a public share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
        "debug": False,
        "quiet": True,
    }
    demo = create_interface()
    demo.launch(**launch_options)