Update app.py
app.py
CHANGED
@@ -21,7 +21,7 @@ st.title("🚀 DigiTs the Twin")
 with st.sidebar:
     st.header("📁 Upload Knowledge Files")
     uploaded_files = st.file_uploader("Upload PDFs or .txt files", accept_multiple_files=True, type=["pdf", "txt"])
-    model_choice = st.selectbox("🧠 Choose Model", ["Qwen", "Mistral"])
+    model_choice = st.selectbox("🧠 Choose Model", ["Qwen", "Mistral", "Llama3"])
     if uploaded_files:
         st.success(f"{len(uploaded_files)} file(s) uploaded")
 
@@ -30,10 +30,8 @@ with st.sidebar:
 def load_model(selected_model):
     if selected_model == "Qwen":
         model_id = "amiguel/GM_Qwen1.8B_Finetune"
-
-
-        model_id = "amiguel/Llama3_8B_Instruct_FP16"
-
+    elif selected_model == "Llama3":
+        model_id = "amiguel/Llama3_8B_Instruct_FP16"
     else:
         model_id = "amiguel/GM_Mistral7B_Finetune"
 
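This hunk only touches the model_id dispatch; the loading itself happens further down in load_model and is not shown here. As a rough sketch of how the resolved id would typically be consumed (the transformers loading step below is an assumption, not part of this diff):

from transformers import AutoModelForCausalLM, AutoTokenizer

def resolve_model_id(selected_model: str) -> str:
    # Mirrors the updated dispatch in load_model above.
    if selected_model == "Qwen":
        return "amiguel/GM_Qwen1.8B_Finetune"
    elif selected_model == "Llama3":
        return "amiguel/Llama3_8B_Instruct_FP16"
    return "amiguel/GM_Mistral7B_Finetune"

# Hypothetical loading step (not shown in this diff):
tokenizer = AutoTokenizer.from_pretrained(resolve_model_id("Llama3"))
model = AutoModelForCausalLM.from_pretrained(resolve_model_id("Llama3"))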
@@ -61,7 +59,6 @@ SYSTEM_PROMPT = (
 # --- Prompt Builder ---
 def build_prompt(messages, context="", model_name="Qwen"):
     if "Mistral" in model_name:
-        # Alpaca-style prompt
         prompt = f"You are DigiTwin, an expert in offshore inspection, maintenance, and asset integrity.\n"
         if context:
             prompt += f"Here is relevant context:\n{context}\n\n"
@@ -71,8 +68,18 @@ def build_prompt(messages, context="", model_name="Qwen"):
             elif msg["role"] == "assistant":
                 prompt += f"### Response:\n{msg['content'].strip()}\n"
         prompt += "### Response:\n"
-
-
+
+    elif "Llama" in model_name:
+        prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
+        prompt += f"{SYSTEM_PROMPT}\n\nContext:\n{context}\n"
+        for msg in messages:
+            if msg["role"] == "user":
+                prompt += "<|start_header_id|>user<|end_header_id|>\n" + msg["content"].strip() + "\n"
+            elif msg["role"] == "assistant":
+                prompt += "<|start_header_id|>assistant<|end_header_id|>\n" + msg["content"].strip() + "\n"
+        prompt += "<|start_header_id|>assistant<|end_header_id|>\n"
+
+    else: # Qwen
         prompt = f"<|im_start|>system\n{SYSTEM_PROMPT}\n\nContext:\n{context}<|im_end|>\n"
         for msg in messages:
             role = msg["role"]
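For a quick sanity check of the new Llama branch, calling build_prompt with a Llama model id shows the emitted format (the import path below is hypothetical):

from app import build_prompt  # hypothetical import path

history = [{"role": "user", "content": "Summarise the last riser inspection."}]
print(build_prompt(history, context="Riser A-12: no anomalies found.",
                   model_name="amiguel/Llama3_8B_Instruct_FP16"))
# <|begin_of_text|><|start_header_id|>system<|end_header_id|>
# ...SYSTEM_PROMPT and context...
# <|start_header_id|>user<|end_header_id|>
# Summarise the last riser inspection.
# <|start_header_id|>assistant<|end_header_id|>

Note that Meta's reference Llama 3 instruct template additionally terminates each message with <|eot_id|> and puts a blank line after each header; this simplified format relies instead on the token cleanup applied during streaming further down.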
@@ -80,7 +87,6 @@ def build_prompt(messages, context="", model_name="Qwen"):
         prompt += "<|im_start|>assistant\n"
     return prompt
 
-
 # --- Embed Uploaded Documents ---
 @st.cache_resource
 def embed_uploaded_files(files):
@@ -125,7 +131,7 @@ BOT_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/99
 if "messages" not in st.session_state:
     st.session_state.messages = []
 
-# --- Display
+# --- Display Chat History ---
 for msg in st.session_state.messages:
     with st.chat_message(msg["role"], avatar=USER_AVATAR if msg["role"] == "user" else BOT_AVATAR):
         st.markdown(msg["content"])
@@ -141,7 +147,6 @@ if prompt := st.chat_input("Ask something based on uploaded documents..."):
     docs = retriever.similarity_search(prompt, k=3)
     context = "\n\n".join([doc.page_content for doc in docs])
 
-    # Limit to last 6 messages for memory
     recent_messages = st.session_state.messages[-6:]
     full_prompt = build_prompt(recent_messages, context, model_name=model_id)
 
@@ -154,9 +159,10 @@ if prompt := st.chat_input("Ask something based on uploaded documents..."):
         answer += chunk
         cleaned = answer
 
-
-        if "Mistral" in model_id:
+        if "Mistral" in model_id or "Llama" in model_id:
             cleaned = cleaned.replace("<|im_start|>", "").replace("<|im_end|>", "").strip()
+            cleaned = cleaned.replace("<|start_header_id|>", "").replace("<|end_header_id|>", "")
+            cleaned = cleaned.replace("<|begin_of_text|>", "").strip()
 
         container.markdown(cleaned + "▌", unsafe_allow_html=True)
 
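The chained replace() calls work, but a single compiled pattern covering the same special tokens is a compact alternative (a sketch, not part of the commit):

import re

SPECIAL_TOKENS = re.compile(
    r"<\|(?:begin_of_text|im_start|im_end|start_header_id|end_header_id)\|>"
)

def strip_special(text: str) -> str:
    # Removes the same markers the replace() chain above targets.
    return SPECIAL_TOKENS.sub("", text).strip()

If the stricter Llama 3 template with <|eot_id|> were adopted, that token would need to be added to the pattern as well.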