from fastapi import FastAPI, Request, Form
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
import mlflow
import pickle
import os
import pandas as pd
import numpy as np
from nltk.stem import WordNetLemmatizer
from nltk.corpus import stopwords
import string
import re
import dagshub
import nltk
import warnings
warnings.simplefilter("ignore", UserWarning)
warnings.filterwarnings("ignore")
from dotenv import load_dotenv
load_dotenv()
# Download required NLTK data
try:
    nltk.download('stopwords', quiet=True)
    nltk.download('wordnet', quiet=True)
    nltk.download('omw-1.4', quiet=True)
except Exception:
    # Non-fatal: the app can still start if the data is already cached locally
    pass


def lemmatization(text):
    """Lemmatize the text."""
    lemmatizer = WordNetLemmatizer()
    text = text.split()
    text = [lemmatizer.lemmatize(word) for word in text]
    return " ".join(text)


def remove_stop_words(text):
    """Remove stop words from the text."""
    stop_words = set(stopwords.words("english"))
    text = [word for word in str(text).split() if word not in stop_words]
    return " ".join(text)


def removing_numbers(text):
    """Remove numbers from the text."""
    text = ''.join([char for char in text if not char.isdigit()])
    return text


def lower_case(text):
    """Convert text to lower case."""
    text = text.split()
    text = [word.lower() for word in text]
    return " ".join(text)


def removing_punctuations(text):
    """Remove punctuation from the text."""
    text = re.sub('[%s]' % re.escape(string.punctuation), ' ', text)
    text = text.replace('؛', "")
    text = re.sub(r'\s+', ' ', text).strip()
    return text


def removing_urls(text):
    """Remove URLs from the text."""
    url_pattern = re.compile(r'https?://\S+|www\.\S+')
    return url_pattern.sub(r'', text)


def remove_small_sentences(df):
    """Replace sentences with fewer than 3 words with NaN (in place)."""
    for i in range(len(df)):
        if len(df.text.iloc[i].split()) < 3:
            # .loc avoids pandas' chained-assignment pitfall, which can
            # silently write to a temporary copy instead of the DataFrame
            df.loc[df.index[i], 'text'] = np.nan


def normalize_text(text):
    """Apply the full preprocessing pipeline to a single string."""
    text = lower_case(text)
    text = remove_stop_words(text)
    text = removing_numbers(text)
    text = removing_punctuations(text)
    text = removing_urls(text)
    text = lemmatization(text)
    return text
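
# A quick illustrative sanity check of the pipeline above (example input is
# ours, not from the project; exact output depends on the installed NLTK data):
#   normalize_text("This is GREAT 123!")  # -> "great"
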
# Below code block is for local use
# -------------------------------------------------------------------------------------
# mlflow.set_tracking_uri('https://dagshub.com/CodeBy-HP/Sentiment-Classification-Mlflow-DVC.mlflow')
# dagshub.init(repo_owner='CodeBy-HP', repo_name='Sentiment-Classification-Mlflow-DVC', mlflow=True)
# -------------------------------------------------------------------------------------
# Below code block is for production use
# -------------------------------------------------------------------------------------
# Set up DagsHub credentials for MLflow tracking
dagshub_token = os.getenv("CAPSTONE_TEST")
if not dagshub_token:
    raise EnvironmentError("CAPSTONE_TEST environment variable is not set")
os.environ["MLFLOW_TRACKING_USERNAME"] = dagshub_token
os.environ["MLFLOW_TRACKING_PASSWORD"] = dagshub_token
dagshub_url = "https://dagshub.com"
repo_owner = "CodeBy-HP"
repo_name = "Sentiment-Classification-Mlflow-DVC"
# Set up MLflow tracking URI
mlflow.set_tracking_uri(f'{dagshub_url}/{repo_owner}/{repo_name}.mlflow')
# -------------------------------------------------------------------------------------
# Initialize FastAPI app
app = FastAPI(title="Sentiment Analysis API", version="1.0.0")
# Set up Jinja2 templates
current_file_dir = os.path.dirname(os.path.abspath(__file__))
templates_dir = os.path.join(current_file_dir, "templates")
templates = Jinja2Templates(directory=templates_dir)
# ------------------------------------------------------------------------------------------
# Model and vectorizer setup
model_name = "my_model"
# Get the path to the vectorizer file
current_dir = os.path.dirname(os.path.abspath(__file__))
vectorizer_path = os.path.join(current_dir, 'models', 'vectorizer.pkl')
if not os.path.exists(vectorizer_path):
    # Try alternative paths
    alt_paths = [
        os.path.join(os.getcwd(), 'models', 'vectorizer.pkl'),
        os.path.join(current_dir, '..', 'models', 'vectorizer.pkl'),
        '/app/models/vectorizer.pkl'  # Docker path
    ]
    for path in alt_paths:
        if os.path.exists(path):
            vectorizer_path = path
            break


def get_latest_model_version(model_name):
    """Return the latest registered model version, preferring the Production stage."""
    client = mlflow.MlflowClient()
    latest_version = client.get_latest_versions(model_name, stages=["Production"])
    if not latest_version:
        latest_version = client.get_latest_versions(model_name, stages=["None"])
    return latest_version[0].version if latest_version else None
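
# Note: registry "stages" are deprecated in recent MLflow releases. A minimal
# alias-based sketch (assumes a "champion" alias has been assigned to the
# registered model, which this project may not have set up):
#   client = mlflow.MlflowClient()
#   version = client.get_model_version_by_alias(model_name, "champion").version
#   model = mlflow.sklearn.load_model(f"models:/{model_name}@champion")
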
model_version = get_latest_model_version(model_name)
model_uri = f'models:/{model_name}/{model_version}'
print(f"Fetching model from: {model_uri}")
model = mlflow.sklearn.load_model(model_uri)
with open(vectorizer_path, 'rb') as f:
    vectorizer = pickle.load(f)


# Routes
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
"""Render the home page."""
return templates.TemplateResponse(
request=request,
name="index.html",
context={"result": None}
)
@app.post("/predict", response_class=HTMLResponse)
async def predict(request: Request, text: str = Form(...)):
"""Handle sentiment prediction."""
# Clean text
cleaned_text = normalize_text(text)
# Convert to features
features = vectorizer.transform([cleaned_text])
# Convert to array without column names to avoid sklearn warning
features_array = features.toarray()
# Predict
result = model.predict(features_array)
prediction = int(result[0])
# Get probability scores for confidence
# Note: predict_proba returns [prob_negative, prob_positive]
probabilities = model.predict_proba(features_array)[0]
confidence = float(probabilities[prediction]) * 100 # Convert to percentage
return templates.TemplateResponse(
request=request,
name="index.html",
context={"result": prediction, "confidence": confidence}
)
@app.get("/health")
async def health_check():
"""Health check endpoint for monitoring."""
return {"status": "healthy", "model_version": model_version}
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
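
# Example requests against a local run (illustrative; not from the project docs):
#   curl -X POST http://localhost:8000/predict -d "text=I loved this movie"
#   curl http://localhost:8000/health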