from fastapi import FastAPI, Request, Form
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
import mlflow
import pickle
import os
import pandas as pd
import numpy as np
from nltk.stem import WordNetLemmatizer
from nltk.corpus import stopwords
import string
import re
import dagshub
import nltk
import warnings
warnings.filterwarnings("ignore")  # silence all warnings (subsumes the earlier UserWarning-only filter)
from dotenv import load_dotenv
load_dotenv()
# Download required NLTK data
try:
nltk.download('stopwords', quiet=True)
nltk.download('wordnet', quiet=True)
nltk.download('omw-1.4', quiet=True)
except Exception:
    # Downloads can fail offline; proceed if the corpora are already cached locally
    pass
def lemmatization(text):
"""Lemmatize the text."""
lemmatizer = WordNetLemmatizer()
text = text.split()
text = [lemmatizer.lemmatize(word) for word in text]
return " ".join(text)
def remove_stop_words(text):
"""Remove stop words from the text."""
stop_words = set(stopwords.words("english"))
text = [word for word in str(text).split() if word not in stop_words]
return " ".join(text)
def removing_numbers(text):
"""Remove numbers from the text."""
text = ''.join([char for char in text if not char.isdigit()])
return text
def lower_case(text):
"""Convert text to lower case."""
text = text.split()
text = [word.lower() for word in text]
return " ".join(text)
def removing_punctuations(text):
"""Remove punctuations from the text."""
text = re.sub('[%s]' % re.escape(string.punctuation), ' ', text)
text = text.replace('؛', "")
    text = re.sub(r'\s+', ' ', text).strip()
return text
def removing_urls(text):
"""Remove URLs from the text."""
url_pattern = re.compile(r'https?://\S+|www\.\S+')
return url_pattern.sub(r'', text)
def remove_small_sentences(df):
    """Set sentences with fewer than 3 words to NaN."""
    # Vectorized mask with .loc assignment; chained .iloc writes can silently fail on a copy
    mask = df['text'].str.split().str.len() < 3
    df.loc[mask, 'text'] = np.nan
def normalize_text(text):
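    """Run the full text-cleaning pipeline; this must mirror the preprocessing used at training time."""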
text = lower_case(text)
text = remove_stop_words(text)
text = removing_numbers(text)
text = removing_punctuations(text)
text = removing_urls(text)
text = lemmatization(text)
return text
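# Example output of the pipeline above (with the standard NLTK English stopword list):
#   normalize_text("I loved the movie, it was AMAZING!!! 10/10") -> "loved movie amazing"
# Note: removing_urls runs after removing_punctuations, so URL syntax is already broken
# apart by then and URL fragments survive as plain words.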
# Below code block is for local use
# -------------------------------------------------------------------------------------
# mlflow.set_tracking_uri('https://dagshub.com/CodeBy-HP/Sentiment-Classification-Mlflow-DVC.mlflow')
# dagshub.init(repo_owner='CodeBy-HP', repo_name='Sentiment-Classification-Mlflow-DVC', mlflow=True)
# -------------------------------------------------------------------------------------
# Below code block is for production use
# -------------------------------------------------------------------------------------
# Set up DagsHub credentials for MLflow tracking
dagshub_token = os.getenv("CAPSTONE_TEST")
if not dagshub_token:
raise EnvironmentError("CAPSTONE_TEST environment variable is not set")
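# DagsHub accepts the personal access token as both the MLflow username and password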
os.environ["MLFLOW_TRACKING_USERNAME"] = dagshub_token
os.environ["MLFLOW_TRACKING_PASSWORD"] = dagshub_token
dagshub_url = "https://dagshub.com"
repo_owner = "CodeBy-HP"
repo_name = "Sentiment-Classification-Mlflow-DVC"
# Set up MLflow tracking URI
mlflow.set_tracking_uri(f'{dagshub_url}/{repo_owner}/{repo_name}.mlflow')
# -------------------------------------------------------------------------------------
# Initialize FastAPI app
app = FastAPI(title="Sentiment Analysis API", version="1.0.0")
# Set up Jinja2 templates
current_file_dir = os.path.dirname(os.path.abspath(__file__))
templates_dir = os.path.join(current_file_dir, "templates")
templates = Jinja2Templates(directory=templates_dir)
# ------------------------------------------------------------------------------------------
# Model and vectorizer setup
model_name = "my_model"
# Get the path to the vectorizer file
current_dir = os.path.dirname(os.path.abspath(__file__))
vectorizer_path = os.path.join(current_dir, 'models', 'vectorizer.pkl')
if not os.path.exists(vectorizer_path):
# Try alternative paths
alt_paths = [
os.path.join(os.getcwd(), 'models', 'vectorizer.pkl'),
os.path.join(current_dir, '..', 'models', 'vectorizer.pkl'),
'/app/models/vectorizer.pkl' # Docker path
]
for path in alt_paths:
if os.path.exists(path):
vectorizer_path = path
break
def get_latest_model_version(model_name):
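    """Return the latest registered version of `model_name`, preferring the Production stage.

    Falls back to unstaged versions; returns None if nothing is registered.
    (Stage-based get_latest_versions is deprecated in newer MLflow releases but still works.)
    """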
client = mlflow.MlflowClient()
latest_version = client.get_latest_versions(model_name, stages=["Production"])
if not latest_version:
latest_version = client.get_latest_versions(model_name, stages=["None"])
return latest_version[0].version if latest_version else None
model_version = get_latest_model_version(model_name)
if model_version is None:
    raise RuntimeError(f"No versions of model '{model_name}' found in the MLflow registry")
model_uri = f'models:/{model_name}/{model_version}'
print(f"Fetching model from: {model_uri}")
model = mlflow.sklearn.load_model(model_uri)
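# Load the vectorizer fitted during training; its vocabulary must match the model's expected features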
with open(vectorizer_path, 'rb') as f:
    vectorizer = pickle.load(f)
# Routes
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
"""Render the home page."""
return templates.TemplateResponse(
request=request,
name="index.html",
context={"result": None}
)
@app.post("/predict", response_class=HTMLResponse)
async def predict(request: Request, text: str = Form(...)):
"""Handle sentiment prediction."""
# Clean text
cleaned_text = normalize_text(text)
# Convert to features
features = vectorizer.transform([cleaned_text])
    # Densify the sparse matrix; a plain ndarray sidesteps sklearn's
    # feature-name mismatch warning when predicting
    features_array = features.toarray()
# Predict
result = model.predict(features_array)
prediction = int(result[0])
# Get probability scores for confidence
    # predict_proba column order follows model.classes_, assumed here to be [0, 1]
probabilities = model.predict_proba(features_array)[0]
confidence = float(probabilities[prediction]) * 100 # Convert to percentage
return templates.TemplateResponse(
request=request,
name="index.html",
context={"result": prediction, "confidence": confidence}
)
@app.get("/health")
async def health_check():
"""Health check endpoint for monitoring."""
return {"status": "healthy", "model_version": model_version}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)