Commit db3eaec · Parent: 39f4270
Fixes to the representation model, visualisations, and embeddings in CPU mode; package updates and optimisations for compatibility.
Changed files:
- Dockerfile +2 -5
- app.py +6 -6
- funcs/bertopic_vis_documents.py +81 -37
- funcs/embeddings.py +4 -0
- funcs/prompts.py +5 -5
- funcs/representation_model.py +236 -2
- requirements.txt +2 -2
- requirements_aws.txt +2 -2
- requirements_gpu.txt +1 -1
Dockerfile
CHANGED
@@ -1,5 +1,5 @@
 # Stage 1: Build dependencies and download models
-FROM public.ecr.aws/docker/library/python:3.11
+FROM public.ecr.aws/docker/library/python:3.12.11-slim-trixie AS builder

 # Install Lambda web adapter in case you want to run with an AWS Lambda function URL (not essential if not using Lambda)
 #COPY --from=public.ecr.aws/awsguru/aws-lambda-adapter:0.8.4 /lambda-adapter /opt/extensions/lambda-adapter
@@ -31,7 +31,7 @@ RUN python /src/download_model.py
 RUN rm requirements_aws.txt download_model.py

 # Stage 2: Final runtime image
-FROM public.ecr.aws/docker/library/python:3.
+FROM public.ecr.aws/docker/library/python:3.12.1-slim-trixie

 # Create a non-root user
 RUN useradd -m -u 1000 user
@@ -43,9 +43,6 @@ COPY --from=builder /install /usr/local/lib/python3.11/site-packages/
 RUN mkdir -p /home/user/app/output /home/user/.cache/huggingface/hub /home/user/.cache/matplotlib /home/user/app/cache \
     && chown -R user:user /home/user

-# Download the quantised phi model directly with curl. Changed as it is so big - not loaded
-#RUN curl -L -o /home/user/app/model/rep/Llama-3.2-3B-Instruct-Q5_K_M.gguf https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF/tree/main/Llama-3.2-3B-Instruct-Q5_K_M.gguf
-
 # Copy models from the builder stage
 COPY --from=builder /model/rep /home/user/app/model/rep
 COPY --from=builder /model/embed /home/user/app/model/embed
app.py
CHANGED
@@ -27,7 +27,7 @@ usage_logs_folder = 'usage/' + today_rev + '/' + host_name + '/'

 # Gradio app

-app = gr.Blocks(theme
+app = gr.Blocks(theme=gr.themes.Default(primary_hue="blue"), fill_width = True)

 with app:

@@ -77,14 +77,14 @@ with app:
 in_colnames = gr.Dropdown(choices=["Choose a column"], multiselect = True, label="Select column to find topics (first will be chosen if multiple selected).")

 with gr.Accordion("Clean data", open = False):
-with gr.Row():
+with gr.Row(equal_height = True):
 clean_text = gr.Dropdown(value = "No", choices=["Yes", "No"], multiselect=False, label="Remove html, URLs, non-ASCII, large numbers, emails, postcodes (UK).")
 drop_duplicate_text = gr.Dropdown(value = "No", choices=["Yes", "No"], multiselect=False, label="Remove duplicate text, drop < 50 character strings.")
 anonymise_drop = gr.Dropdown(value = "No", choices=["Yes", "No"], multiselect=False, label="Redact personal information - not 100% effective and slow!")
 #with gr.Row():
 split_sentence_drop = gr.Dropdown(value = "No", choices=["Yes", "No"], multiselect=False, label="Split text into sentences. Useful for small datasets.")
 #additional_custom_delimiters_drop = gr.Dropdown(choices=["and", ",", "as well as", "also"], multiselect=True, label="Additional custom delimiters to split sentences.")
-min_sentence_length_num = gr.Number(value=5, label="
+min_sentence_length_num = gr.Number(value=5, label="Minimum character length of split sentences")

 with gr.Row():
 custom_regex = gr.UploadButton(label="Import custom regex removal file", file_count="multiple")
@@ -115,11 +115,11 @@ with app:
 topics_btn = gr.Button("Extract topics", variant="primary")

 with gr.Row():
-output_single_text = gr.Textbox(label="Output topics")
+output_single_text = gr.Textbox(label="Output topics", lines = 5)
 output_file = gr.File(label="Output file")

 with gr.Accordion("Post processing options.", open = True):
-with gr.Row():
+with gr.Row(equal_height = True):
 representation_type = gr.Dropdown(label = "Method for generating new topic labels", value="Default", choices=["Default", "MMR", "KeyBERT", "LLM"])
 represent_llm_btn = gr.Button("Change topic labels")
 with gr.Row():
@@ -135,7 +135,7 @@ with app:

 plot_btn = gr.Button("Visualise topic model")
 with gr.Row():
-vis_output_single_text = gr.Textbox(label="Visualisation output text")
+vis_output_single_text = gr.Textbox(label="Visualisation output text (if data points don't appear below, download the html output to see them)")
 out_plot_file = gr.File(label="Output plots to file", file_count="multiple")
 plot = gr.Plot(label="Visualise your topics here.")
 plot_2 = gr.Plot(label="Visualise your topics here.")
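For reference, the layout options introduced above (theme, fill_width, equal_height, lines) are standard Gradio Blocks/Row/Textbox parameters. Below is a minimal, self-contained sketch of the same pattern, using illustrative component names rather than the app's real ones (assumes Gradio 5.x):

import gradio as gr

# Blocks-level options: a themed layout that stretches to the full browser width.
demo = gr.Blocks(theme=gr.themes.Default(primary_hue="blue"), fill_width=True)

with demo:
    # equal_height keeps the dropdown and button the same height within the row.
    with gr.Row(equal_height=True):
        example_choice = gr.Dropdown(value="No", choices=["Yes", "No"], label="Example option")
        run_btn = gr.Button("Run")
    # lines=5 gives a taller, multi-line output box, as used for the topic output above.
    example_output = gr.Textbox(label="Output", lines=5)

if __name__ == "__main__":
    demo.launch()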
funcs/bertopic_vis_documents.py
CHANGED
@@ -197,50 +197,94 @@ def visualize_documents_custom(topic_model,
 if len(non_selected_topics) == 0:
 non_selected_topics = [-1]

-selection = df.loc[df.topic.isin(non_selected_topics), :]
-[remaining removed lines of the old "other documents" trace were not captured in this view]
+selection = df.loc[df.topic.isin(non_selected_topics), :].copy()
+if len(selection) > 0:
+    selection["text"] = ""
+    # Only add annotation row if selection is not empty
+    if not hide_annotations:
+        annotation_row = pd.DataFrame({
+            "topic": [None],
+            "doc": [None],
+            "hover_labels": [None],
+            "x": [selection.x.mean()],
+            "y": [selection.y.mean()],
+            "text": ["Other documents"]
+        })
+        selection = pd.concat([selection, annotation_row], ignore_index=True)
+
+    # Filter out rows where x or y is NaN to keep arrays aligned
+    valid_mask = selection.x.notna() & selection.y.notna()
+    selection_valid = selection[valid_mask].copy()
+
+    # Convert to lists to avoid Series issues
+    x_vals = selection_valid.x.tolist()
+    y_vals = selection_valid.y.tolist()
+    hover_vals = selection_valid.hover_labels.tolist() if not hide_document_hover and len(selection_valid) > 0 else None
+    text_vals = selection_valid.text.tolist() if len(selection_valid) > 0 else []
+
+    if len(x_vals) > 0:  # Only add trace if there are valid data points
+        fig.add_trace(
+            go.Scattergl(
+                x=x_vals,
+                y=y_vals,
+                hovertext=hover_vals,
+                hoverinfo="text",
+                mode='markers+text',
+                name="other",
+                showlegend=False,
+                marker=dict(color='#CFD8DC', size=5, opacity=0.5),
+                hoverlabel=dict(align='left'),
+                text=text_vals if len(text_vals) > 0 and any(t for t in text_vals if t) else None
+            )
+        )

 # Selected topics
 for name, topic in zip(names, unique_topics):
 #print(name)
 #print(topic)
 if topic in topics and topic != -1:
-selection = df.loc[df.topic == topic, :]
-[remaining removed lines of the old per-topic trace were not captured in this view]
+selection = df.loc[df.topic == topic, :].copy()
+if len(selection) > 0:
+    selection["text"] = ""
+
+    if not hide_annotations:
+        # Add annotation row properly using DataFrame concat
+        annotation_row = pd.DataFrame({
+            "topic": [None],
+            "doc": [None],
+            "hover_labels": [None],
+            "x": [selection.x.mean()],
+            "y": [selection.y.mean()],
+            "text": [name]
+        })
+        selection = pd.concat([selection, annotation_row], ignore_index=True)
+
+    # Filter out rows where x or y is NaN to keep arrays aligned
+    valid_mask = selection.x.notna() & selection.y.notna()
+    selection_valid = selection[valid_mask].copy()
+
+    # Convert to lists to avoid Series issues
+    x_vals = selection_valid.x.tolist()
+    y_vals = selection_valid.y.tolist()
+    hover_vals = selection_valid.hover_labels.tolist() if not hide_document_hover else None
+    text_vals = selection_valid.text.tolist()
+
+    if len(x_vals) > 0:  # Only add trace if there are valid data points
+        fig.add_trace(
+            go.Scattergl(
+                x=x_vals,
+                y=y_vals,
+                hovertext=hover_vals,
+                hoverinfo="text",
+                text=text_vals if len(text_vals) > 0 and any(t for t in text_vals if t) else None,
+                mode='markers+text',
+                name=name,
+                textfont=dict(
+                    size=12,
+                ),
+                marker=dict(size=5, opacity=0.5),
+                hoverlabel=dict(align='left')
+            ))

 # Add grid in a 'plus' shape
 x_range = (df.x.min() - abs((df.x.min()) * .15), df.x.max() + abs((df.x.max()) * .15))
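The pattern these visualisation changes rely on is: copy the selection, optionally append a centroid annotation row, drop any rows whose x or y is NaN so all per-point lists stay the same length, and only call fig.add_trace when something remains. A minimal sketch of that pattern in isolation (the toy DataFrame and values below are illustrative, not the app's data):

import pandas as pd
import plotly.graph_objects as go

# Toy document map: one point has a missing coordinate, as can happen after a concat.
df = pd.DataFrame({
    "x": [0.1, 0.4, None],
    "y": [0.2, 0.5, 0.9],
    "hover_labels": ["doc one", "doc two", "doc three"],
})
df["text"] = ""

# Keep only rows where both coordinates exist so every per-point list stays aligned.
valid = df[df.x.notna() & df.y.notna()].copy()

fig = go.Figure()
if len(valid) > 0:  # only add a trace when there is something left to plot
    fig.add_trace(
        go.Scattergl(
            x=valid.x.tolist(),
            y=valid.y.tolist(),
            text=valid.text.tolist(),
            hovertext=valid.hover_labels.tolist(),
            hoverinfo="text",
            mode="markers+text",
            marker=dict(size=5, opacity=0.5),
        )
    )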
funcs/embeddings.py
CHANGED
@@ -71,6 +71,10 @@ def make_or_load_embeddings(docs: list, file_list: list, embeddings_out: np.ndar
 TruncatedSVD(100, random_state=random_seed)
 )

+# Ensure embeddings_out is a numpy array (handle case where it might be a string from Gradio state)
+if not isinstance(embeddings_out, np.ndarray):
+    embeddings_out = np.array([])
+
 # If no embeddings found, make or load in
 if embeddings_out.size == 0:
 print("Embeddings not found. Loading or generating new ones.")
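This guard matters because a gr.State initialised with a placeholder (for example an empty string) reaches the callback as that placeholder, and calling .size on it would fail. A minimal sketch of the same check, using a hypothetical helper name rather than the function above:

import numpy as np

def ensure_embeddings_array(embeddings_out):
    """Hypothetical helper mirroring the guard above: coerce non-array state to an empty array."""
    if not isinstance(embeddings_out, np.ndarray):
        embeddings_out = np.array([])
    return embeddings_out

# A Gradio State initialised as "" would otherwise break embeddings_out.size
print(ensure_embeddings_array("").size)                # 0 -> triggers the "make or load" branch
print(ensure_embeddings_array(np.zeros((3, 5))).size)  # 15 -> existing embeddings are reused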
funcs/prompts.py
CHANGED
@@ -16,7 +16,7 @@ capybara_example_prompt = """USER:I have a topic that contains the following doc

 The topic is described by the following keywords: 'meat, beef, eat, eating, emissions, steak, food, health, processed, chicken'.

-Based on the information about the topic above, please create a short label of this topic.
+Based on the information about the topic above, please create a short label of this topic. Return only the label and no other text or explanation.

 Topic label: Environmental impacts of eating meat
 """
@@ -54,7 +54,7 @@ I have a topic that contains the following documents:

 The topic is described by the following keywords: 'meat, beef, eat, eating, emissions, steak, food, health, processed, chicken'.

-Based on the information about the topic above, please create a short label of this topic.
+Based on the information about the topic above, please create a short label of this topic. Return only the label and no other text or explanation.

 Topic label: Environmental impacts of eating meat
 """
@@ -83,7 +83,7 @@ I have a topic that contains the following documents:

 The topic is described by the following keywords: 'meat, beef, eat, eating, emissions, steak, food, health, processed, chicken'.

-Based on the information about the topic above, please create a short label of this topic.
+Based on the information about the topic above, please create a short label of this topic. Return only the label and no other text or explanation.

 Topic label: Environmental impacts of eating meat
 """
@@ -115,7 +115,7 @@ I have a topic that contains the following documents:

 The topic is described by the following keywords: 'meat, beef, eat, eating, emissions, steak, food, health, processed, chicken'.

-Based on the information about the topic above, please create a short label of this topic.
+Based on the information about the topic above, please create a short label of this topic. Return only the label and no other text or explanation.

 Topic label: Environmental impacts of eating meat
 """
@@ -129,7 +129,7 @@ I have a topic that contains the following documents:

 The topic is described by the following keywords: '[KEYWORDS]'.

-Based on the information about the topic above, please create a short label of this topic.
+Based on the information about the topic above, please create a short label of this topic. Return only the label and no other text or explanation.<|end|>
 <|assistant|>
 Topic label:"""
funcs/representation_model.py
CHANGED
@@ -1,4 +1,5 @@
 import os
+import re
 import spaces
 from bertopic.representation import LlamaCPP

@@ -8,10 +9,227 @@ from huggingface_hub import hf_hub_download
 from gradio import Warning

 from bertopic.representation import KeyBERTInspired, MaximalMarginalRelevance, BaseRepresentation
-from funcs.embeddings import torch_device
 from funcs.prompts import phi3_prompt, phi3_start
 from funcs.helper_functions import get_or_create_env_var, GPU_SPACE_DURATION

+
+def clean_llm_output_text(text: str) -> str:
+    """
+    Clean LLM output text by removing special characters.
+    Keeps only: letters, numbers, spaces, dashes, and apostrophes (for contractions).
+
+    Args:
+        text: The text to clean
+
+    Returns:
+        Cleaned text with special characters removed
+    """
+    if not text:
+        return ""
+
+    # Keep only alphanumeric characters, spaces, dashes, and apostrophes
+    # This regex keeps: a-z, A-Z, 0-9, spaces, hyphens/dashes, and apostrophes
+    cleaned = re.sub(r'[^a-zA-Z0-9\s\-\']', '', text)
+
+    # Clean up multiple spaces and strip
+    cleaned = re.sub(r'\s+', ' ', cleaned)
+    cleaned = cleaned.strip()
+
+    return cleaned
+
+
+def patch_llama_create_chat_completion(llama_model):
+    """
+    Monkey-patch the create_chat_completion method on a Llama model instance
+    to use raw completion instead of chat format handler.
+    This avoids the "System role not supported" error for models like phi3.
+
+    Args:
+        llama_model: The Llama model instance to patch
+
+    Returns:
+        The same llama_model instance with patched create_chat_completion method
+    """
+    def patched_create_chat_completion(messages, **kwargs):
+        """
+        Override create_chat_completion to use raw completion.
+        This avoids the chat format handler that requires system roles (not supported by phi3).
+        BERTopic's LlamaCPP formats messages and uses the prompt template, so we reconstruct
+        the full prompt from the messages.
+        """
+        # Reconstruct the prompt from messages
+        # BERTopic's LlamaCPP passes messages in OpenAI format: [{"role": "user", "content": "..."}]
+        prompt_parts = []
+        for msg in messages:
+            if isinstance(msg, dict):
+                role = msg.get('role', 'user')
+                content = msg.get('content', '')
+                # Skip system messages as phi3 doesn't support them
+                if role != 'system' and content:
+                    prompt_parts.append(content)
+            else:
+                prompt_parts.append(str(msg))
+
+        # Join all message contents into a single prompt
+        prompt = '\n'.join(prompt_parts) if prompt_parts else ''
+
+        # Use raw completion instead of chat completion
+        # This avoids the chat format handler that requires system roles
+        # Remove chat-specific kwargs that might cause issues, but enable streaming
+        completion_kwargs = {k: v for k, v in kwargs.items()
+                             if k not in ['messages', 'chat_format', 'chat_handler']}
+
+        # Enable streaming to show output in real-time
+        completion_kwargs['stream'] = True
+
+        # Use create_completion for raw text completion (not chat completion)
+        # With stream=True, this returns a generator of CompletionChunk objects
+        text_parts = []
+        try:
+            # Create completion with streaming enabled
+            completion_stream = llama_model.create_completion(prompt, **completion_kwargs)
+
+            # Iterate through the stream and collect text
+            print("\nLLM Output: ", end="", flush=True)  # Print prefix without newline
+            for chunk in completion_stream:
+                # Extract text from each chunk
+                chunk_text = ""
+
+                # Handle dictionary chunks (the format returned by llama_cpp)
+                if isinstance(chunk, dict):
+                    # Extract from chunk['choices'][0]['text'] - this is the standard format
+                    if 'choices' in chunk and len(chunk['choices']) > 0:
+                        choice = chunk['choices'][0]
+                        if isinstance(choice, dict):
+                            chunk_text = choice.get('text', '') or choice.get('content', '')
+                        elif hasattr(choice, 'text'):
+                            chunk_text = choice.text
+                        elif hasattr(choice, 'content'):
+                            chunk_text = choice.content
+                    elif 'text' in chunk:
+                        chunk_text = chunk['text']
+                    elif 'content' in chunk:
+                        chunk_text = chunk['content']
+
+                # Try different ways to extract text from the chunk (object format)
+                elif hasattr(chunk, 'choices') and len(chunk.choices) > 0:
+                    choice = chunk.choices[0]
+                    if hasattr(choice, 'text'):
+                        chunk_text = choice.text
+                    elif hasattr(choice, 'delta') and hasattr(choice.delta, 'content'):
+                        # Some formats use delta.content
+                        chunk_text = choice.delta.content or ""
+                    elif hasattr(choice, 'content'):
+                        chunk_text = choice.content
+                    elif isinstance(choice, dict):
+                        chunk_text = choice.get('text', '') or choice.get('delta', {}).get('content', '')
+                elif hasattr(chunk, 'text'):
+                    chunk_text = chunk.text
+                elif isinstance(chunk, str):
+                    chunk_text = chunk
+                elif hasattr(chunk, '__dict__'):
+                    # Check various possible attributes
+                    chunk_dict = chunk.__dict__
+                    if 'text' in chunk_dict:
+                        chunk_text = chunk_dict['text']
+                    elif 'choices' in chunk_dict:
+                        choices = chunk_dict['choices']
+                        if choices and len(choices) > 0:
+                            if isinstance(choices[0], dict):
+                                chunk_text = choices[0].get('text', '') or choices[0].get('delta', {}).get('content', '')
+                            elif hasattr(choices[0], 'text'):
+                                chunk_text = choices[0].text
+                            elif hasattr(choices[0], 'delta'):
+                                delta = choices[0].delta
+                                if hasattr(delta, 'content'):
+                                    chunk_text = delta.content or ""
+
+                # Only add non-empty text and filter out debug messages
+                if chunk_text and chunk_text.strip():
+                    # Filter out llama.cpp debug messages
+                    if not any(debug_keyword in chunk_text for debug_keyword in [
+                        'llama_perf_context_print', 'Llama.generate', 'load time',
+                        'prompt eval time', 'eval time', 'total time', 'prefix-match hit'
+                    ]):
+                        text_parts.append(chunk_text)
+                        print(chunk_text, end="", flush=True)  # Print without newline, flush immediately
+
+            print()  # Newline after streaming is complete
+            text = ''.join(text_parts)
+
+            # Clean the text to remove special characters
+            text = clean_llm_output_text(text)
+
+            # If no text was collected, there might be an issue with chunk extraction
+            if not text:
+                print("Warning: No text extracted from streaming chunks. Chunk structure may be different.")
+                print("Falling back to non-streaming mode.")
+                raise Exception("No text in stream")
+
+        except (AttributeError, TypeError, Exception) as e:
+            # Fallback to non-streaming if create_completion doesn't exist or streaming fails
+            print(f"\nStreaming failed, falling back to non-streaming mode: {e}")
+            completion_kwargs.pop('stream', None)  # Remove stream parameter
+            try:
+                completion = llama_model.create_completion(prompt, **completion_kwargs)
+            except AttributeError:
+                completion = llama_model(prompt, **completion_kwargs)
+
+            # Extract text from the completion object
+            text = ""
+            if hasattr(completion, 'choices') and len(completion.choices) > 0:
+                # Standard Completion object format
+                if hasattr(completion.choices[0], 'text'):
+                    text = completion.choices[0].text
+                elif hasattr(completion.choices[0], 'content'):
+                    text = completion.choices[0].content
+            elif hasattr(completion, 'text'):
+                # Direct text attribute
+                text = completion.text
+            elif isinstance(completion, str):
+                # Already a string
+                text = completion
+            elif hasattr(completion, '__dict__'):
+                # Try to get text from object attributes
+                if 'text' in completion.__dict__:
+                    text = completion.__dict__['text']
+                elif 'choices' in completion.__dict__:
+                    choices = completion.__dict__['choices']
+                    if choices and len(choices) > 0:
+                        if isinstance(choices[0], dict):
+                            text = choices[0].get('text', '')
+                        elif hasattr(choices[0], 'text'):
+                            text = choices[0].text
+            else:
+                # Last resort: convert to string (but this might not work well)
+                text = str(completion)

+            # Clean up the text - remove special characters and whitespace
+            text = clean_llm_output_text(text) if text else ""
+
+        # Create a chat completion response as a dictionary
+        # BERTopic accesses it as: response["choices"][0]["message"]["content"]
+        # Always return a dictionary to ensure it's subscriptable
+        return {
+            "choices": [{
+                "message": {
+                    "content": text,
+                    "role": "assistant"
+                },
+                "finish_reason": "stop",
+                "index": 0
+            }],
+            "id": "custom",
+            "created": 0,
+            "model": "",
+            "object": "chat.completion"
+        }
+
+    # Replace the method on the instance
+    llama_model.create_chat_completion = patched_create_chat_completion
+
+    return llama_model
+
 chosen_prompt = phi3_prompt #open_hermes_prompt # stablelm_prompt
 chosen_start_tag = phi3_start #open_hermes_start # stablelm_start

@@ -222,7 +440,23 @@ def create_representation_model(representation_type: str, llm_config: dict, hf_m
 print("Loading representation model with", llm_config.n_gpu_layers, "layers allocated to GPU.")

 #llm_config.n_gpu_layers
-[removed line not captured in this view]
+# Initialize Llama model - try to disable chat format handler if supported
+# This helps avoid "System role not supported" error for models like phi3
+try:
+    llm = Llama(model_path=found_file, stop=chosen_start_tag, n_gpu_layers=llm_config.n_gpu_layers, n_ctx=llm_config.n_ctx, seed=seed, chat_format=None)
+except TypeError:
+    # If chat_format parameter doesn't exist, try without it or with chat_handler
+    try:
+        llm = Llama(model_path=found_file, stop=chosen_start_tag, n_gpu_layers=llm_config.n_gpu_layers, n_ctx=llm_config.n_ctx, seed=seed, chat_handler=None)
+    except TypeError:
+        # Fall back to basic initialization if chat format parameters don't exist
+        llm = Llama(model_path=found_file, stop=chosen_start_tag, n_gpu_layers=llm_config.n_gpu_layers, n_ctx=llm_config.n_ctx, seed=seed)
+
+# Monkey-patch the create_chat_completion method to use raw completion
+# This avoids the chat format handler that requires system roles (not supported by phi3)
+# We patch the instance directly so it still passes isinstance checks in BERTopic
+llm = patch_llama_create_chat_completion(llm)
+
 #print(llm.n_gpu_layers)
 #print("Chosen prompt:", chosen_prompt)
 llm_model = LlamaCPP(llm, prompt=chosen_prompt)#, **gen_config.model_dump())
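To show how the new pieces are meant to fit together, here is a minimal, hedged sketch of wiring a patched Llama instance into BERTopic's LlamaCPP representation. The model path, prompt and parameter values are placeholders rather than the app's real configuration, and the imports from funcs.representation_model assume the functions added in this commit:

from llama_cpp import Llama
from bertopic.representation import LlamaCPP

from funcs.representation_model import patch_llama_create_chat_completion, clean_llm_output_text

MODEL_PATH = "model/rep/example-model.gguf"  # placeholder GGUF path
PROMPT = "[DOCUMENTS]\nKeywords: [KEYWORDS]\nTopic label:"  # placeholder prompt template

# Plain llama-cpp-python initialisation; n_gpu_layers=0 keeps everything on CPU.
llm = Llama(model_path=MODEL_PATH, n_ctx=4096, n_gpu_layers=0, seed=42)

# Route chat completions through raw completion so phi3-style models never hit
# the "System role not supported" chat handler; the instance is patched in place,
# so BERTopic's isinstance checks still see a Llama object.
llm = patch_llama_create_chat_completion(llm)

representation_model = LlamaCPP(llm, prompt=PROMPT)

# clean_llm_output_text strips punctuation the model may wrap around a label, e.g.
# clean_llm_output_text("Topic label: **Meat & emissions!**") -> "Topic label Meat emissions"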
requirements.txt
CHANGED
@@ -1,5 +1,4 @@
 pandas==2.3.3
-plotly==6.3.1
 scikit-learn==1.7.2
 umap-learn==0.5.9.post2
 gradio==5.49.1
@@ -23,4 +22,5 @@ llama-cpp-python==0.3.2 --extra-index-url https://abetlen.github.io/llama-cpp-py
 # Specify exact llama_cpp wheel for huggingface compatibility
 # https://github.com/abetlen/llama-cpp-python/releases/download/v0.3.4-cu121/llama_cpp_python-0.3.4-cp310-cp310-linux_x86_64.whl
 spaces==0.42.1
-numpy==2.2.6
+numpy==2.2.6
+plotly<=5.24.1 # Downgrade needed to enable correct topic document output display
requirements_aws.txt
CHANGED
@@ -1,5 +1,4 @@
 pandas==2.3.3
-plotly==6.3.1
 scikit-learn==1.7.2
 umap-learn==0.5.9.post2
 boto3==1.40.72
@@ -18,4 +17,5 @@ accelerate==1.11.0
 bertopic==0.17.3
 sentence-transformers==5.1.2
 spaces==0.42.1
-numpy==2.2.6
+numpy==2.2.6
+plotly<=5.24.1 # Downgrade needed to enable correct topic document output display
requirements_gpu.txt
CHANGED
@@ -1,5 +1,4 @@
 pandas==2.3.3
-plotly==6.3.1
 scikit-learn==1.7.2
 umap-learn==0.5.9.post2
 gradio==5.49.1
@@ -21,4 +20,5 @@ llama-cpp-python==0.3.4 --extra-index-url https://abetlen.github.io/llama-cpp-py
 sentence-transformers==5.1.2
 spaces==0.42.1
 numpy==2.2.6
+plotly<=5.24.1 # Downgrade needed to enable correct topic document output display