Update src.py
src.py CHANGED
@@ -6,12 +6,14 @@ from PIL import Image
 from pandasai.llm import HuggingFaceTextGen
 from dotenv import load_dotenv
 from langchain_groq.chat_models import ChatGroq
+from langchain_google_genai import GoogleGenerativeAI
 
 load_dotenv()
 Groq_Token = os.environ["GROQ_API_KEY"]
-models = {"mixtral": "mixtral-8x7b-32768", "llama": "llama2-70b-4096", "gemma": "gemma-7b-it"}
+models = {"mixtral": "mixtral-8x7b-32768", "llama": "llama2-70b-4096", "gemma": "gemma-7b-it", "gemini-pro": "gemini-pro"}
 
 hf_token = os.getenv("HF_READ")
+gemini_token = os.getenv("GEMINI_TOKEN")
 
 def preprocess_and_load_df(path: str) -> pd.DataFrame:
     df = pd.read_csv(path)
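The new import and token lookup make Gemini Pro available alongside the three Groq-hosted models. Note that the module reads both GROQ_API_KEY (at import, above) and GROQ_API (inside the ChatGroq calls below), and that the Gemini branches read GEMINI_TOKEN in one place and GOOGLE_API_KEY in another. A minimal startup check covering every variable name that appears in this diff might look like the following sketch (the check itself is not part of the commit):

import os
from dotenv import load_dotenv

load_dotenv()

# Every environment variable name referenced somewhere in this diff.
REQUIRED_VARS = ["GROQ_API_KEY", "GROQ_API", "HF_READ", "GEMINI_TOKEN", "GOOGLE_API_KEY"]
missing = [name for name in REQUIRED_VARS if not os.getenv(name)]
if missing:
    raise RuntimeError(f"missing environment variables: {missing}")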
@@ -27,7 +29,10 @@ def load_agent(df: pd.DataFrame, context: str, inference_server: str, name="mixt
     # top_k=5,
     # )
     # llm.client.headers = {"Authorization": f"Bearer {hf_token}"}
-    llm = ChatGroq(model=models[name], api_key=os.getenv("GROQ_API"), temperature=0.1)
+    if name == "gemini-pro":
+        llm = GoogleGenerativeAI(model=model, google_api_key=gemini_token, temperature=0.1)
+    else:
+        llm = ChatGroq(model=models[name], api_key=os.getenv("GROQ_API"), temperature=0.1)
 
     agent = Agent(df, config={"llm": llm, "enable_cache": False, "options": {"wait_for_model": True}})
     agent.add_message(context)
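Both load_agent here and ask_question below repeat this branching, and both pass a bare model variable to GoogleGenerativeAI even though no model is defined in the context shown; models[name] (or models[model_name]) looks like the intended value. A shared helper would make that explicit. The following is a sketch under that assumption, not part of the commit; models and gemini_token are the module-level globals from the first hunk:

import os
from langchain_groq.chat_models import ChatGroq
from langchain_google_genai import GoogleGenerativeAI

def get_llm(name: str, temperature: float = 0.1):
    # "gemini-pro" routes to Google; every other key in `models` routes to Groq.
    if name == "gemini-pro":
        return GoogleGenerativeAI(model=models[name], google_api_key=gemini_token, temperature=temperature)
    return ChatGroq(model=models[name], api_key=os.getenv("GROQ_API"), temperature=temperature)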
@@ -86,7 +91,10 @@ def show_response(st, response):
     return {"is_image": False}
 
 def ask_question(model_name, question):
-    llm = ChatGroq(model=models[model_name], api_key=os.getenv("GROQ_API"), temperature=0.1)
+    if model_name == "gemini-pro":
+        llm = GoogleGenerativeAI(model=model, google_api_key=os.environ.get("GOOGLE_API_KEY"), temperature=0)
+    else:
+        llm = ChatGroq(model=models[model_name], api_key=os.getenv("GROQ_API"), temperature=0.1)
 
     df_check = pd.read_csv("Data.csv")
     df_check["Timestamp"] = pd.to_datetime(df_check["Timestamp"])
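This branch mirrors the one in load_agent, but it pulls the Gemini credential from GOOGLE_API_KEY instead of the module-level gemini_token and uses temperature=0 instead of 0.1; if that divergence is unintentional, one of the two call sites should change. The reason the next hunk also branches on model_name is a LangChain interface difference: GoogleGenerativeAI is a plain-LLM wrapper whose invoke() returns a string, while ChatGroq is a chat model whose invoke() returns a message carrying its text in .content. A tiny normalizer (hypothetical, not in the commit) would keep call sites uniform:

def get_answer_text(llm, query: str) -> str:
    # Chat models return a message object; plain LLMs return the string itself.
    answer = llm.invoke(query)
    return answer if isinstance(answer, str) else answer.content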
@@ -121,11 +129,13 @@ df["Timestamp"] = pd.to_datetime(df["Timestamp"])
     {template}
 
     """
-    answer = llm.invoke(query)
-
+    if model_name == "gemini-pro":
+        answer = llm.invoke(query)
+    else:
+        answer = llm.invoke(query).content
     code = f"""
 {template.split("```python")[1].split("```")[0]}
-{answer.content.split("```python")[1].split("```")[0]}
+{answer.split("```python")[1].split("```")[0]}
     """
     # update variable `answer` when code is executed
     exec(code)
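Two caveats about these closing lines. First, the split("```python")[1].split("```")[0] idiom raises IndexError whenever the model reply contains no fenced Python block, so a single malformed reply crashes ask_question. Second, exec(code) runs that model-generated code in the current process with no sandboxing. A guarded extractor (a sketch, not part of the commit) handles the first problem:

import re

def extract_python_block(text: str) -> str:
    # Same extraction as split("```python")[1].split("```")[0], but with an
    # explicit error instead of an IndexError on replies without a code fence.
    match = re.search(r"```python\n(.*?)```", text, re.DOTALL)
    if match is None:
        raise ValueError("no fenced Python block in model reply")
    return match.group(1)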