Update app.py
app.py CHANGED

@@ -17,27 +17,33 @@ HF_TOKEN = st.secrets["HF_TOKEN"]
 st.set_page_config(page_title="DigiTwin RAG", page_icon="π", layout="centered")
 st.title("π DigiTs the Twin")
 
-# ---
+# --- Sidebar ---
 with st.sidebar:
     st.header("π Upload Knowledge Files")
     uploaded_files = st.file_uploader("Upload PDFs or .txt files", accept_multiple_files=True, type=["pdf", "txt"])
+    model_choice = st.selectbox("π§ Choose Model", ["Qwen", "Mistral"])
     if uploaded_files:
         st.success(f"{len(uploaded_files)} file(s) uploaded")
 
 # --- Load Model & Tokenizer ---
 @st.cache_resource
-def load_model():
-
+def load_model(selected_model):
+    if selected_model == "Qwen":
+        model_id = "amiguel/GM_Qwen1.8B_Finetune"
+    else:
+        model_id = "amiguel/GM_Mistral7B_Finetune"
+
+    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, token=HF_TOKEN)
     model = AutoModelForCausalLM.from_pretrained(
-
+        model_id,
         device_map="auto",
         torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32,
         trust_remote_code=True,
         token=HF_TOKEN
     )
-    return model, tokenizer
+    return model, tokenizer, model_id
 
-model, tokenizer = load_model()
+model, tokenizer, model_id = load_model(model_choice)
 
 # --- System Prompt ---
 SYSTEM_PROMPT = (
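The new loader signature matters for caching: st.cache_resource keys its cache on the function arguments, so load_model(model_choice) loads each checkpoint once per selectbox value instead of reloading on every Streamlit rerun. Below is a minimal, self-contained sketch of that behaviour; load_resources is a hypothetical stand-in for the real AutoTokenizer/AutoModelForCausalLM loading and is not part of app.py.

import streamlit as st

@st.cache_resource
def load_resources(selected_model: str) -> dict:
    # Hypothetical stand-in for the transformers loading in app.py; the body
    # runs once per distinct argument value, then the result is served from cache.
    model_id = ("amiguel/GM_Qwen1.8B_Finetune"
                if selected_model == "Qwen"
                else "amiguel/GM_Mistral7B_Finetune")
    return {"model_id": model_id}

choice = st.selectbox("Choose Model", ["Qwen", "Mistral"])
resources = load_resources(choice)  # cached separately for "Qwen" and "Mistral"
st.caption(f"Active checkpoint: {resources['model_id']}")
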
@@ -128,14 +134,19 @@ if prompt := st.chat_input("Ask something based on uploaded documents..."):
 
     for chunk in generate_response(full_prompt):
         answer += chunk
-
-
+        cleaned = answer
+
+        # π§ Strip <|im_start|>, <|im_end|> if using Mistral (Qwen needs them)
+        if "Mistral" in model_id:
+            cleaned = cleaned.replace("<|im_start|>", "").replace("<|im_end|>", "").strip()
+
+        container.markdown(cleaned + "β", unsafe_allow_html=True)
 
     end = time.time()
-    st.session_state.messages.append({"role": "assistant", "content":
+    st.session_state.messages.append({"role": "assistant", "content": cleaned})
 
     input_tokens = len(tokenizer(full_prompt)["input_ids"])
-    output_tokens = len(tokenizer(
+    output_tokens = len(tokenizer(cleaned)["input_ids"])
     speed = output_tokens / (end - start)
 
     with st.expander("π Debug Info"):
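For context on the streaming loop above: neither container nor generate_response is defined inside these hunks. The pattern assumes container is an st.empty() placeholder created earlier in the chat handler and that generate_response() yields text chunks, so each iteration re-renders the partially cleaned answer with a typing cursor. A self-contained sketch under those assumptions follows; the generator is a hypothetical stand-in for the real model call.

import streamlit as st

def generate_response(prompt: str):
    # Hypothetical stand-in for the real model call; yields text chunks.
    for piece in ("Inspection ", "backlog summary", "<|im_end|>"):
        yield piece

model_id = "amiguel/GM_Mistral7B_Finetune"  # as returned by load_model()
container = st.empty()                      # placeholder rewritten on each chunk
answer = ""
cleaned = ""
for chunk in generate_response("demo prompt"):
    answer += chunk
    cleaned = answer
    if "Mistral" in model_id:               # Qwen keeps its ChatML markers
        cleaned = cleaned.replace("<|im_start|>", "").replace("<|im_end|>", "").strip()
    container.markdown(cleaned + "▌", unsafe_allow_html=True)  # "▌" = typing cursor
container.markdown(cleaned)                 # final render without the cursor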