Spaces: Paused

Oleh Kuznetsov committed · Commit bdaca7e · 1 Parent(s): 6e1997a

feat(rec): Finalize recommendations (almost done)

Changed files:
- .gitignore +2 -1
- Dockerfile +1 -1
- app.py +307 -24
- ingest.py +6 -2
- prompts/api.txt +0 -7
- resources/description.md +33 -0
- resources/prompt_api.md +12 -0
- prompts/local.txt → resources/prompt_vllm.md +1 -1
.gitignore CHANGED
@@ -1,4 +1,5 @@
 *__pycache__*
 .venv
 .env
-data
+data
+*sandbox*
Dockerfile CHANGED
@@ -31,7 +31,7 @@ ENV HOME=/home/user \
 
 # Setup application directory
 WORKDIR $HOME/app
-ADD --chown=user ./…
+ADD --chown=user ./resources $HOME/app/resources
 ADD --chown=user ./ingest.py $HOME/app/ingest.py
 ADD --chown=user ./app.py $HOME/app/app.py
 
app.py CHANGED
@@ -1,15 +1,25 @@
 import json
 import os
 import random
+import urllib.parse
 from pathlib import Path
+from typing import Optional
 
 import gradio as gr
+import numpy as np
+import pandas as pd
+from dotenv import load_dotenv
+from fastembed import SparseEmbedding, SparseTextEmbedding
 from google import genai
 from google.genai import types
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
+from qdrant_client import QdrantClient
+from qdrant_client import models as qmodels
+from sentence_transformers import CrossEncoder, SentenceTransformer
 from vllm import LLM, SamplingParams
 from vllm.sampling_params import GuidedDecodingParams
 
+load_dotenv()
 
 HF_TOKEN = os.getenv("HF_TOKEN")
 
@@ -20,11 +30,44 @@ VLLM_DTYPE = os.getenv("VLLM_DTYPE")
 
 GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
 
+DATA_PATH = Path(os.getenv("DATA_PATH"))
+DB_PATH = DATA_PATH / "db"
+
+client = QdrantClient(path=str(DB_PATH))
+collection_name = "knowledge_cards"
+num_chunks_base = 500
+alpha = 0.5
+top_k = 5  # we only want top 5 genres
+
+youtube_url_template = "{genre} music playlist"
+
+
 # -------------------------------- HELPERS -------------------------------------
-def …
+def load_text_resource(path: Path) -> str:
     with path.open("r") as file:
-        …
-    return …
+        resource = file.read()
+    return resource
+
+
+def youtube_search_link_for_genre(genre: str) -> str:
+    base_url = "https://www.youtube.com/results"
+    params = {
+        "search_query": youtube_url_template.format(
+            genre=genre.replace("_", " ").lower()
+        )
+    }
+    return f"{base_url}?{urllib.parse.urlencode(params)}"
+
+
+def generate_recommendation_string(ranking: dict[str, float]) -> str:
+    recommendation_string = "## Recommendations for You\n\n"
+    for idx, (genre, score) in enumerate(ranking.items(), start=1):
+        youtube_link = youtube_search_link_for_genre(genre=genre)
+        recommendation_string += (
+            f"{idx}. **{genre.replace('_', ' ').capitalize()}** ({score:.2f}); "
+            f"[YouTube link]({youtube_link})\n"
+        )
+    return recommendation_string
 
 
 # -------------------------------- Data Models -------------------------------
@@ -41,8 +84,51 @@ class QueryRewrite(BaseModel):
     structured: StructuredQueryRewriteResponse | None = None
 
 
+class APIGenreRecommendation(BaseModel):
+    name: str = Field(description="Name of the music genre.")
+    score: float = Field(
+        description="Score you assign to the genre (from 0 to 1).", ge=0, le=1
+    )
+
+
 class APIGenreRecommendationResponse(BaseModel):
-    genres: list[…
+    genres: list[APIGenreRecommendation]
+
+
+class RetrievalResult(BaseModel):
+    chunk: str
+    genre: str
+    score: float
+
+
+class RerankingResult(BaseModel):
+    query: str
+    genre: str
+    chunk: str
+    score: float
+
+
+class Recommendation(BaseModel):
+    name: str
+    rank: int
+    score: Optional[float] = None
+
+
+class PipelineResult(BaseModel):
+    query: str
+    rewrite: Optional[QueryRewrite] = None
+    retrieval_result: Optional[list[RetrievalResult]] = None
+    reranking_result: Optional[list[RerankingResult]] = None
+    recommendations: Optional[dict[str, Recommendation]] = None
+
+    def to_ranking(self) -> dict[str, float]:
+        if not self.recommendations:
+            return {}
+
+        return {
+            genre: recommendation.score
+            for genre, recommendation in self.recommendations.items()
+        }
 
 
 # -------------------------------- VLLM --------------------------------------
@@ -68,7 +154,7 @@ vllm_system_prompt = (
     "You are a search query optimization assistant built into"
     " music genre search engine, helping users discover novel music genres."
 )
-vllm_prompt = …
+vllm_prompt = load_text_resource(Path("./resources/prompt_vllm.md"))
 
 # -------------------------------- GEMINI ------------------------------------
 gemini_config = types.GenerateContentConfig(
@@ -76,20 +162,35 @@ gemini_config = types.GenerateContentConfig(
     response_schema=APIGenreRecommendationResponse,
     temperature=0.7,
     max_output_tokens=1024,
-    system_instruction=(…
+    system_instruction=(
+        "You are a helpful music genre recommendation assistant built into"
+        " music genre search engine, helping users discover novel music genres."
+    )
 )
 gemini_llm = genai.Client(
    api_key=GEMINI_API_KEY,
    http_options={"api_version": "v1alpha"},
 )
-gemini_prompt = …
+gemini_prompt = load_text_resource(Path("./resources/prompt_api.md"))
 
-
-
+# ---------------------------- EMBEDDING MODELS --------------------------------
+dense_encoder = SentenceTransformer(
+    model_name_or_path="mixedbread-ai/mxbai-embed-large-v1",
+    device="cuda",
+    model_kwargs={"torch_dtype": VLLM_DTYPE},
+)
+sparse_encoder = SparseTextEmbedding(model_name="Qdrant/bm25", cuda=True)
+reranker = CrossEncoder(
+    model_name_or_path="BAAI/bge-reranker-v2-m3",
+    max_length=1024,
+    device="cuda",
+    model_kwargs={"torch_dtype": VLLM_DTYPE},
+)
+reranker_batch_size = 128
 
 
-# …
-def …
+# ---------------------------- RETRIEVAL ---------------------------------------
+def run_query_rewrite(query: str) -> QueryRewrite:
     prompt = vllm_prompt.format(query=query)
     messages = [
         {"role": "system", "content": vllm_system_prompt},
@@ -104,10 +205,181 @@ def recommend_sadaimrec(query: str):
         rewrites=[x for x in list(rewrite_json.values()) if x is not None],
         structured=rewrite_json,
     )
-    return …
+    return rewrite
+
+
+def prepare_queries_for_retrieval(
+    query: str, rewrite: QueryRewrite
+) -> list[dict[str, str | None]]:
+    queries_to_retrieve = [{"text": query, "topic": None}]
+    for cat, rewrite in rewrite.structured.model_dump().items():
+        if rewrite is None:
+            continue
+        topic = cat
+        if cat not in ["subjective", "purpose", "technical"]:
+            topic = None
+        queries_to_retrieve.append({"text": rewrite, "topic": topic})
+    return queries_to_retrieve
+
+
+def run_retrieval(
+    queries: list[dict[str, str]],
+) -> RetrievalResult:
+    queries_to_embed = [query["text"] for query in queries]
+    dense_queries = list(
+        dense_encoder.encode(
+            queries_to_embed, convert_to_numpy=True, normalize_embeddings=True
+        )
+    )
+    sparse_queries = list(sparse_encoder.query_embed(queries_to_embed))
+    prefetches: list[qmodels.Prefetch] = []
+
+    for query, dense_query, sparse_query in zip(queries, dense_queries, sparse_queries):
+        assert dense_query is not None and sparse_query is not None
+        assert isinstance(dense_query, np.ndarray) and isinstance(
+            sparse_query, SparseEmbedding
+        )
+        topic = query.get("topic", None)
+        prefetch = [
+            qmodels.Prefetch(
+                query=dense_query,
+                using="dense",
+                filter=qmodels.Filter(
+                    must=[
+                        qmodels.FieldCondition(
+                            key="topic", match=qmodels.MatchValue(value=topic)
+                        )
+                    ]
+                )
+                if topic is not None
+                else None,
+                limit=num_chunks_base,
+            ),
+            qmodels.Prefetch(
+                query=qmodels.SparseVector(**sparse_query.as_object()),
+                using="sparse",
+                filter=qmodels.Filter(
+                    must=[
+                        qmodels.FieldCondition(
+                            key="topic", match=qmodels.MatchValue(value=topic)
+                        )
+                    ]
+                )
+                if topic is not None
+                else None,
+                limit=num_chunks_base,
+            ),
+        ]
+        prefetches.extend(prefetch)
+
+    retrieval_results = client.query_points(
+        collection_name=collection_name,
+        prefetch=prefetches,
+        query=qmodels.FusionQuery(fusion=qmodels.Fusion.RRF),
+        limit=num_chunks_base,
+    )
+
+    final_hits: list[RetrievalResult] = [
+        RetrievalResult(
+            chunk=hit.payload["text"], genre=hit.payload["genre"], score=hit.score
+        )
+        for hit in retrieval_results.points
+    ]
+    return final_hits
+
+
+def run_reranking(
+    query: str, retrieval_result: list[RetrievalResult]
+) -> list[RerankingResult]:
+    hit_texts: list[str] = [result.chunk for result in retrieval_result]
+    hit_genres: list[str] = [result.genre for result in retrieval_result]
+    hit_rerank = reranker.rank(
+        query=query,
+        documents=hit_texts,
+        batch_size=reranker_batch_size,
+    )
+    ranking = [
+        RerankingResult(
+            query=query,
+            genre=hit_genres[hit["corpus_id"]],
+            chunk=hit_texts[hit["corpus_id"]],
+            score=hit["score"],
+        )
+        for hit in hit_rerank
+    ]
+    ranking.sort(key=lambda x: x.score, reverse=True)
+    return ranking
+
+
+def get_top_genres(
+    df: pd.DataFrame,
+    column: str,
+    alpha: float = 1.0,
+    # beta: float = 1.0,
+    top_k: int | None = None,
+) -> pd.Series:
+    assert 0 <= alpha <= 1.0
+
+    # Min-max normalization of re-ranker scores before aggregation
+    task_scores = df[column]
+    min_score = task_scores.min()
+    max_score = task_scores.max()
+    if max_score > min_score:  # Avoid division by zero
+        df.loc[:, column] = (task_scores - min_score) / (max_score - min_score)
+
+    tg_df = df.groupby("genre").agg(size=("chunk", "size"), score=(column, "sum"))
+    tg_df["weighted_score"] = alpha * (tg_df["size"] / tg_df["size"].max()) + (
+        1 - alpha
+    ) * (tg_df["score"] / tg_df["score"].max())
+    tg = tg_df.sort_values("weighted_score", ascending=False)["weighted_score"]
+
+    if top_k:
+        tg = tg.head(top_k)
+
+    return tg
+
+
+def get_recommendations(
+    reranking_result: list[RerankingResult],
+) -> dict[str, Recommendation]:
+    ranking_df = pd.DataFrame([x.model_dump(mode="python") for x in reranking_result])
+    top_genres_series = get_top_genres(
+        df=ranking_df, column="score", alpha=alpha, top_k=top_k
+    )
+    recommendations = {
+        genre: Recommendation(name=genre, rank=rank, score=score)
+        for rank, (genre, score) in enumerate(
+            top_genres_series.to_dict().items(), start=1
+        )
+    }
+    return recommendations
+
+
+# ----------------------- GENERATE RECOMMENDATIONS -----------------------------
+def recommend_sadaimrec(query: str):
+    result = PipelineResult(query=query)
+    print("Running query processing...", flush=True)
+    result.rewrite = run_query_rewrite(query=query)
+    queries_to_retrieve = prepare_queries_for_retrieval(
+        query=query, rewrite=result.rewrite
+    )
+
+    print("Running retrieval...", flush=True)
+    result.retrieval_result = run_retrieval(queries_to_retrieve)
+
+    print("Running re-ranking...", flush=True)
+    result.reranking_result = run_reranking(
+        query=query, retrieval_result=result.retrieval_result
+    )
+
+    print("Aggregating recommendations...", flush=True)
+    result.recommendations = get_recommendations(result.reranking_result)
+    recommendation_string = generate_recommendation_string(result.to_ranking())
+    return f"{recommendation_string}"
 
 
 def recommend_gemini(query: str):
+    print("Generating recommendations using Gemini...", flush=True)
     prompt = gemini_prompt.format(query=query)
     response = gemini_llm.models.generate_content(
         model="gemini-2.0-flash",
@@ -115,17 +387,19 @@ def recommend_gemini(query: str):
         config=gemini_config,
     )
     parsed_content: APIGenreRecommendationResponse = response.parsed
-    …
+    parsed_content.genres.sort(key=lambda x: x.score, reverse=True)
+    ranking = {x.name.lower(): x.score for x in parsed_content.genres}
+    recommendation_string = generate_recommendation_string(ranking)
+    return f"{recommendation_string}"
 
 
-# …
+# -------------------------------------- INTERFACE -----------------------------
 pipelines = {
     "sadaimrec": recommend_sadaimrec,
     "chatgpt": recommend_gemini,
 }
 
 
-# -------------------------------------- INTERFACE -----------------------------
 def generate_responses(query):
     # Randomize model order
     pipeline_names = list(pipelines.keys())
@@ -156,30 +430,37 @@ def reset_ui():
         gr.update(value=""),  # clear query
         gr.update(visible=False),  # hide radio
         gr.update(visible=False),  # hide vote button
-        gr.update(value=""),  # clear Option 1 text
-        gr.update(value=""),  # clear Option 2 text
+        gr.update(value="**Generating...**"),  # clear Option 1 text
+        gr.update(value="**Generating...**"),  # clear Option 2 text
         gr.update(value=""),  # clear result
         gr.update(active=False),
     )
 
 
-…
-…
-…
+app_description = load_text_resource(Path("./resources/description.md"))
+
+with gr.Blocks(title="SADAIMREC") as demo:
+    gr.Markdown(app_description)
+    query = gr.Textbox(
+        label="Your Query",
+        placeholder="Calming, music for deep relaxation with echoing sounds and deep bass",
+    )
     submit_btn = gr.Button("Submit")
     # timer that resets ui after feedback is sent
     reset_timer = gr.Timer(value=2.0, active=False)
 
     # Hidden components to store model responses and names
     with gr.Row(visible=False) as response_row:
-        response_1 = gr.…
-        response_2 = gr.…
+        response_1 = gr.Markdown(value="**Generating...**", label="Option 1")
+        response_2 = gr.Markdown(value="**Generating...**", label="Option 2")
         model_label_1 = gr.Textbox(visible=False)
         model_label_2 = gr.Textbox(visible=False)
 
     # Feedback
     vote = gr.Radio(
-        ["Option 1", "Option 2"],
+        ["Option 1 (left)", "Option 2 (right)"],
+        label="Select Best Response",
+        visible=False,
     )
     vote_btn = gr.Button("Vote", visible=False)
     result = gr.Textbox(label="Console", interactive=False)
@@ -189,6 +470,7 @@ with gr.Blocks() as demo:
         fn=generate_responses,
         inputs=[query],
         outputs=[response_1, response_2, model_label_1, model_label_2],
+        show_progress="full",
     )
     submit_btn.click(  # update ui
         fn=lambda: (
@@ -222,6 +504,7 @@ with gr.Blocks() as demo:
         trigger_mode="once",
     )
 
+
 if __name__ == "__main__":
     demo.queue(max_size=10, default_concurrency_limit=1).launch(
         server_name="0.0.0.0", server_port=7860
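A note on the core of the new pipeline: after hybrid retrieval and cross-encoder re-ranking, `get_top_genres` min-max normalizes the re-ranker scores and then blends each genre's chunk count with its summed score using the weight `alpha` (0.5 in this commit). The sketch below isolates just that aggregation step so it can be run standalone; the DataFrame contents are made up, only the formula mirrors the committed code.

```python
import pandas as pd

# Hypothetical re-ranked hits in the shape get_top_genres receives:
# one row per retrieved chunk, with its genre and cross-encoder score.
df = pd.DataFrame(
    {
        "genre": ["ambient", "ambient", "dub", "drone", "dub"],
        "chunk": ["c1", "c2", "c3", "c4", "c5"],
        "score": [0.91, 0.72, 0.65, 0.40, 0.33],
    }
)
alpha = 0.5  # same default as app.py: 0 = score-only, 1 = chunk-count-only

# Min-max normalize scores so both signals live on [0, 1]
s = df["score"]
if s.max() > s.min():  # avoid division by zero, as the commit does
    df["score"] = (s - s.min()) / (s.max() - s.min())

# Per genre: number of surviving chunks and their summed (normalized) score
agg = df.groupby("genre").agg(size=("chunk", "size"), score=("score", "sum"))

# weighted_score = alpha * size/max(size) + (1 - alpha) * score/max(score)
agg["weighted_score"] = alpha * (agg["size"] / agg["size"].max()) + (
    1 - alpha
) * (agg["score"] / agg["score"].max())

# Top-k genres, highest blended score first (top_k = 5 in the commit)
print(agg["weighted_score"].sort_values(ascending=False).head(5))
```

Genres that both appear often among retrieved chunks and score well with the re-ranker float to the top; `alpha` trades off the two signals.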
ingest.py CHANGED
@@ -9,11 +9,13 @@ from huggingface_hub import hf_hub_download
 from qdrant_client import QdrantClient
 from qdrant_client import models as qmodels
 
+VLLM_DTYPE = os.getenv("VLLM_DTYPE")
+
 DATA_PATH = Path(os.getenv("DATA_PATH"))
 DB_PATH = DATA_PATH / "db"
 HF_TOKEN = os.getenv("HF_TOKEN")
 
-RECREATE_DB = bool(os.getenv("RECREATE_DB", "False").lower == "true")
+RECREATE_DB = bool(os.getenv("RECREATE_DB", "False").lower() == "true")
 DATA_REPO = os.getenv("DATA_REPO")
 DATA_FILENAME = os.getenv("DATA_FILENAME")
 
@@ -24,7 +26,9 @@ dense_batch_size = 128
 sparse_batch_size = 256
 
 dense_encoder = SentenceTransformer(
-    model_name_or_path="mixedbread-ai/mxbai-embed-large-v1",
+    model_name_or_path="mixedbread-ai/mxbai-embed-large-v1",
+    device="cuda",
+    model_kwargs={"torch_dtype": VLLM_DTYPE},
 )
 sparse_encoder = SparseTextEmbedding(model_name="Qdrant/bm25", cuda=True)
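The one-character `RECREATE_DB` change is a genuine bug fix, not a refactor: `.lower` without parentheses is a bound method object, so the old comparison against `"true"` was always `False` and the flag could never be enabled. A minimal repro:

```python
import os

os.environ["RECREATE_DB"] = "True"  # hypothetical value, e.g. set in Space settings
raw = os.getenv("RECREATE_DB", "False")

# Old (broken): compares the method object itself, never the lowercased string.
old = bool(raw.lower == "true")   # always False

# New (fixed): call the method, then compare. The outer bool() is redundant
# since == already yields a bool, but it is harmless.
new = bool(raw.lower() == "true")

print(old, new)  # False True
```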
prompts/api.txt DELETED
@@ -1,7 +0,0 @@
-# Purpose
-
-Recommend 5 genres based on the user query
-
-# Query
-
-{query}
resources/description.md ADDED
@@ -0,0 +1,33 @@
+# Music Genre Recommendation Side-By-Side Comparison
+
+This simple application was developed and deployed as **complementary material for my thesis**.
+
+In case of any complications, questions or suggestions, please reach out via [email](mailto:[email protected]).
+
+## Instructions
+
+1. Formulate a **search query** with a description of a music genre you would like to listen to. The expected format is described below.
+2. Explore **two generated recommendation rankings**: one is created by my system, one is generated using `gemini-2.0-flash`. The order is **randomized** each run.
+3. Determine which ranking you prefer.
+4. Vote for your choice.
+5. Wait for the refresh and repeat as many times as you want.
+
+## Expected Query Format
+
+- The system was designed to support **3 categories** of music genre descriptors:
+  - **Subjective**: Emotional & perceptual qualities, desired **inner feeling** (melancholic, energetic)
+  - **Purpose-Based**: Listening setting, context, suitable activities, scenario (party, workout)
+  - **Technical**: Musical & production attributes, **HOW the sound is made** (instrumentation, timbre, tempo, lo-fi)
+- **Other descriptors are out of scope of the current implementation**:
+  - I kindly ask you to only use the above 3 categories for your queries
+  - Usage of cultural, historical, etc. descriptors can lead to suboptimal results
+- You can make the descriptors **as complex and poetic as you want**, but I kindly ask you to **limit your query to a couple of sentences**
+
+## Query Examples
+
+- `Music for deep relaxation with echoing sounds and heavy bass, perfect for unwinding after a long day`
+- `Music that feels like the echo of a forgotten world—slow, sorrowful. Guitars and distant vocals create the sensation of a long, drifting sleep on the edge of melancholy and oblivion. A soundtrack to isolation, it slowly pulls you into the depths of existential despair.`
+- `Raw and filled with aggression, high-energy drums, mosh-pit vibes, high bpm, guitars`
+- `Music to study to, relaxing, chill with calm drums, some piano, and suitable for background`
+- `Creamy and cozy, suitable for evenings with loved ones`
+- `Dreamy instrumental music for midnight melancholia`
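The three descriptor categories above are not just guidance for the user: in `app.py`, `prepare_queries_for_retrieval` keeps the category as a `topic` tag only for `subjective`, `purpose`, and `technical`, and `run_retrieval` turns that tag into a Qdrant payload filter on each prefetch. A reduced sketch of that mapping (field and category names from the commit, standalone function for illustration):

```python
from qdrant_client import models as qmodels


def topic_filter(topic: str | None) -> qmodels.Filter | None:
    """Build the payload filter run_retrieval attaches to a prefetch."""
    # Out-of-scope descriptors (cultural, historical, ...) get no filter,
    # so they search across the whole knowledge_cards collection.
    if topic not in ["subjective", "purpose", "technical"]:
        return None
    return qmodels.Filter(
        must=[
            qmodels.FieldCondition(
                key="topic", match=qmodels.MatchValue(value=topic)
            )
        ]
    )


print(topic_filter("technical"))  # FieldCondition on payload key "topic"
print(topic_filter("cultural"))   # None: falls back to unfiltered search
```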
resources/prompt_api.md ADDED
@@ -0,0 +1,12 @@
+# Purpose and Context
+
+Given a user-generated Search Query describing music they wish to explore, create a ranking of the most suitable music genres.
+
+# Instructions
+
+1. Create a music genre ranking, including the 5 most suitable music genres, ordered from the most to the least suitable.
+2. Respond in JSON.
+
+# Search Query
+
+{query}
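This prompt can get away with a bare "Respond in JSON" because the shape is enforced in code rather than in the prompt: `gemini_config` passes the `APIGenreRecommendationResponse` Pydantic model as `response_schema`, and `response.parsed` deserializes straight into it. A trimmed sketch of that contract; the `response_mime_type` line is an assumption, since the diff hunk starts just below where it would appear:

```python
from google import genai
from google.genai import types
from pydantic import BaseModel, Field


class APIGenreRecommendation(BaseModel):
    name: str = Field(description="Name of the music genre.")
    score: float = Field(
        description="Score you assign to the genre (from 0 to 1).", ge=0, le=1
    )


class APIGenreRecommendationResponse(BaseModel):
    genres: list[APIGenreRecommendation]


client = genai.Client(api_key="...")  # GEMINI_API_KEY in the Space

response = client.models.generate_content(
    model="gemini-2.0-flash",
    contents="Dreamy instrumental music for midnight melancholia",
    config=types.GenerateContentConfig(
        response_mime_type="application/json",  # assumed; not visible in the hunk
        response_schema=APIGenreRecommendationResponse,
        temperature=0.7,
        max_output_tokens=1024,
    ),
)

# The SDK parses the JSON into the schema; no manual json.loads needed.
parsed: APIGenreRecommendationResponse = response.parsed
print(sorted(parsed.genres, key=lambda g: g.score, reverse=True))
```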
prompts/local.txt → resources/prompt_vllm.md RENAMED
@@ -228,4 +228,4 @@ Given a user-generated Search Query describing music they wish to explore, you m
 
 # Search Query
 
-{query}
+{query}