from fastapi import FastAPI, Request, Form
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
import mlflow
import pickle
import os
import pandas as pd
import numpy as np
from nltk.stem import WordNetLemmatizer
from nltk.corpus import stopwords
import string
import re
import dagshub
import nltk
import warnings

warnings.simplefilter("ignore", UserWarning)
warnings.filterwarnings("ignore")

from dotenv import load_dotenv

load_dotenv()

# Download required NLTK data
try:
    nltk.download('stopwords', quiet=True)
    nltk.download('wordnet', quiet=True)
    nltk.download('omw-1.4', quiet=True)
except Exception:
    # Downloads are best-effort; the corpora may already be present locally.
    pass

def lemmatization(text):
    """Lemmatize the text."""
    lemmatizer = WordNetLemmatizer()
    text = text.split()
    text = [lemmatizer.lemmatize(word) for word in text]
    return " ".join(text)

def remove_stop_words(text):
    """Remove stop words from the text."""
    stop_words = set(stopwords.words("english"))
    text = [word for word in str(text).split() if word not in stop_words]
    return " ".join(text)

def removing_numbers(text):
    """Remove numbers from the text."""
    text = ''.join([char for char in text if not char.isdigit()])
    return text

def lower_case(text):
    """Convert text to lower case."""
    text = text.split()
    text = [word.lower() for word in text]
    return " ".join(text)

def removing_punctuations(text):
    """Remove punctuations from the text."""
    text = re.sub('[%s]' % re.escape(string.punctuation), ' ', text)
    text = text.replace('؛', "")  # Arabic semicolon, not covered by string.punctuation
    text = re.sub(r'\s+', ' ', text).strip()  # collapse repeated whitespace
    return text

def removing_urls(text):
    """Remove URLs from the text."""
    url_pattern = re.compile(r'https?://\S+|www\.\S+')
    return url_pattern.sub(r'', text)

def remove_small_sentences(df):
    """Blank out rows whose text has fewer than 3 words.

    Not called by the API routes; retained as a shared preprocessing helper.
    """
    # Vectorized assignment via .loc avoids the chained-indexing pattern
    # (df.text.iloc[i] = ...), which triggers SettingWithCopyWarning and
    # may silently fail to write back.
    df.loc[df['text'].str.split().str.len() < 3, 'text'] = np.nan

def normalize_text(text):
    """Apply the full preprocessing pipeline to a single string.

    The step order is kept exactly as-is so that inference-time preprocessing
    matches whatever the vectorizer saw at training time.
    """
    text = lower_case(text)
    text = remove_stop_words(text)
    text = removing_numbers(text)
    text = removing_punctuations(text)
    text = removing_urls(text)
    text = lemmatization(text)
    return text

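# Illustrative example (not part of the app flow). Note that because punctuation
# stripping runs before URL removal, URL fragments survive as plain tokens:
#   normalize_text("Loving the new update!!! Visit https://example.com")
#   -> approximately "loving new update visit https example com"
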
# Below code block is for local use
# -------------------------------------------------------------------------------------
# mlflow.set_tracking_uri('https://dagshub.com/CodeBy-HP/Sentiment-Classification-Mlflow-DVC.mlflow')
# dagshub.init(repo_owner='CodeBy-HP', repo_name='Sentiment-Classification-Mlflow-DVC', mlflow=True)
# -------------------------------------------------------------------------------------

# Below code block is for production use
# -------------------------------------------------------------------------------------
# Set up DagsHub credentials for MLflow tracking.
# DagsHub accepts a personal access token as both the basic-auth username and password.
dagshub_token = os.getenv("CAPSTONE_TEST")
if not dagshub_token:
    raise EnvironmentError("CAPSTONE_TEST environment variable is not set")

os.environ["MLFLOW_TRACKING_USERNAME"] = dagshub_token
os.environ["MLFLOW_TRACKING_PASSWORD"] = dagshub_token

dagshub_url = "https://dagshub.com"
repo_owner = "CodeBy-HP"
repo_name = "Sentiment-Classification-Mlflow-DVC"

# Set up MLflow tracking URI
mlflow.set_tracking_uri(f'{dagshub_url}/{repo_owner}/{repo_name}.mlflow')
# -------------------------------------------------------------------------------------

# Initialize FastAPI app
app = FastAPI(title="Sentiment Analysis API", version="1.0.0")

# Set up Jinja2 templates
current_file_dir = os.path.dirname(os.path.abspath(__file__))
templates_dir = os.path.join(current_file_dir, "templates")
templates = Jinja2Templates(directory=templates_dir)

# ------------------------------------------------------------------------------------------
# Model and vectorizer setup
model_name = "my_model"

# Get the path to the vectorizer file
current_dir = os.path.dirname(os.path.abspath(__file__))
vectorizer_path = os.path.join(current_dir, 'models', 'vectorizer.pkl')

if not os.path.exists(vectorizer_path):
    # Try alternative paths
    alt_paths = [
        os.path.join(os.getcwd(), 'models', 'vectorizer.pkl'),
        os.path.join(current_dir, '..', 'models', 'vectorizer.pkl'),
        '/app/models/vectorizer.pkl'  # Docker path
    ]
    for path in alt_paths:
        if os.path.exists(path):
            vectorizer_path = path
            break
    else:
        # Fail fast with a clear message instead of an obscure pickle error later
        raise FileNotFoundError(f"vectorizer.pkl not found in any known location: {alt_paths}")

def get_latest_model_version(model_name):
    """Return the newest registered version, preferring the Production stage."""
    client = mlflow.MlflowClient()
    latest_version = client.get_latest_versions(model_name, stages=["Production"])
    if not latest_version:
        latest_version = client.get_latest_versions(model_name, stages=["None"])
    return latest_version[0].version if latest_version else None

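# Note: registry stages (and client.get_latest_versions) are deprecated in recent
# MLflow releases in favor of model version aliases; the stage-based lookup is kept
# here because this project's registry appears to still label versions with stages.
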
model_version = get_latest_model_version(model_name)
if model_version is None:
    raise RuntimeError(f"No registered versions found for model '{model_name}'")

model_uri = f'models:/{model_name}/{model_version}'
print(f"Fetching model from: {model_uri}")
model = mlflow.sklearn.load_model(model_uri)

with open(vectorizer_path, 'rb') as f:
    vectorizer = pickle.load(f)

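# Both artifacts are loaded once at import time, so an unreachable tracking
# server or a missing artifact fails the app at startup rather than on the
# first request.
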
# Routes (paths follow the conventional layout for this template-driven app)
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Render the home page."""
    return templates.TemplateResponse(
        request=request,
        name="index.html",
        context={"result": None}
    )

@app.post("/predict", response_class=HTMLResponse)
async def predict(request: Request, text: str = Form(...)):
    """Handle sentiment prediction."""
    # Clean text
    cleaned_text = normalize_text(text)

    # Convert to features
    features = vectorizer.transform([cleaned_text])

    # Convert to a dense array without column names to avoid sklearn warnings
    features_array = features.toarray()

    # Predict
    result = model.predict(features_array)
    prediction = int(result[0])

    # Get probability scores for confidence
    # Note: predict_proba returns [prob_negative, prob_positive]
    probabilities = model.predict_proba(features_array)[0]
    confidence = float(probabilities[prediction]) * 100  # convert to a percentage

    return templates.TemplateResponse(
        request=request,
        name="index.html",
        context={"result": prediction, "confidence": confidence}
    )

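# Illustrative request (field name "text" as declared above; FastAPI form
# parsing requires the python-multipart package to be installed):
#   curl -X POST -d "text=I love this product" http://localhost:8000/predict
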
@app.get("/health")
async def health_check():
    """Health check endpoint for monitoring."""
    return {"status": "healthy", "model_version": model_version}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
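
# Typical ways to start the server (assuming this file is named app.py):
#   python app.py
#   uvicorn app:app --host 0.0.0.0 --port 8000 --reload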