import json
import os
import random
import signal
import sys
import urllib.parse
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional
from uuid import uuid4

import gradio as gr
import numpy as np
import pandas as pd

# from dotenv import load_dotenv
from fastembed import SparseEmbedding, SparseTextEmbedding
from google import genai
from google.genai import types
from huggingface_hub import CommitScheduler
from pydantic import BaseModel, Field
from qdrant_client import QdrantClient
from qdrant_client import models as qmodels
from sentence_transformers import CrossEncoder, SentenceTransformer
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

# load_dotenv()

# Configuration comes from environment variables; the numeric conversions
# below fail fast at startup if their variables are unset.
HF_TOKEN = os.getenv("HF_TOKEN")
VLLM_MODEL_NAME = os.getenv("VLLM_MODEL_NAME")
VLLM_GPU_MEMORY_UTILIZATION = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION"))
VLLM_MAX_SEQ_LEN = int(os.getenv("VLLM_MAX_SEQ_LEN"))
VLLM_DTYPE = os.getenv("VLLM_DTYPE")
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")

DATA_PATH = Path(os.getenv("DATA_PATH"))
DB_PATH = DATA_PATH / "db"

FEEDBACK_REPO = os.getenv("FEEDBACK_REPO")
FEEDBACK_DIR = DATA_PATH / "feedback"
FEEDBACK_DIR.mkdir(parents=True, exist_ok=True)
FEEDBACK_FILE = FEEDBACK_DIR / f"votes_{uuid4()}.jsonl"
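# Each process writes its own votes_<uuid>.jsonl, so replicas or restarts never
# append to the same file; the scheduler below uploads the whole folder.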
scheduler = CommitScheduler(
    repo_id=FEEDBACK_REPO,
    repo_type="dataset",
    folder_path=FEEDBACK_DIR,
    path_in_repo="data",
    every=5,  # commit every 5 minutes
    token=HF_TOKEN,
    private=True,
)

client = QdrantClient(path=str(DB_PATH))
collection_name = "knowledge_cards"

num_chunks_base = 500
alpha = 0.5
top_k = 5  # we only want top 5 genres
youtube_url_template = "{genre} music playlist"
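# `num_chunks_base` caps each candidate pool fetched from Qdrant; `alpha`
# blends how often a genre appears among the ranked chunks against its summed
# re-ranker score (see get_top_genres below).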

# -------------------------------- HELPERS -------------------------------------
def load_text_resource(path: Path) -> str:
    with path.open("r") as file:
        resource = file.read()
    return resource


def youtube_search_link_for_genre(genre: str) -> str:
    base_url = "https://www.youtube.com/results"
    params = {
        "search_query": youtube_url_template.format(
            genre=genre.replace("_", " ").lower()
        )
    }
    return f"{base_url}?{urllib.parse.urlencode(params)}"


def generate_recommendation_string(ranking: dict[str, float]) -> str:
    recommendation_string = "## Recommendations for You\n\n"
    for idx, genre in enumerate(ranking, start=1):
        youtube_link = youtube_search_link_for_genre(genre=genre)
        recommendation_string += (
            f"{idx}. **{genre.replace('_', ' ').capitalize()}**; "
            f"[YouTube link]({youtube_link})\n"
        )
    return recommendation_string


def graceful_shutdown(signum, frame):
    print(f"{signum} received - flushing feedback …", flush=True)
    # trigger() forces an immediate commit; result() blocks until the upload
    # has finished, so no votes are lost on exit.
    scheduler.trigger().result()
    sys.exit(0)


signal.signal(signal.SIGTERM, graceful_shutdown)
signal.signal(signal.SIGINT, graceful_shutdown)
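# The Space runtime stops the container with SIGTERM (SIGINT when run
# interactively), so the handlers above cover both pause/restart and Ctrl-C.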

# -------------------------------- Data Models -------------------------------
class StructuredQueryRewriteResponse(BaseModel):
    general: str | None
    subjective: str | None
    purpose: str | None
    technical: str | None
    curiosity: str | None


class QueryRewrite(BaseModel):
    rewrites: list[str] | None = None
    structured: StructuredQueryRewriteResponse | None = None


class APIGenreRecommendation(BaseModel):
    name: str = Field(description="Name of the music genre.")
    score: float = Field(
        description="Score you assign to the genre (from 0 to 1).", ge=0, le=1
    )


class APIGenreRecommendationResponse(BaseModel):
    genres: list[APIGenreRecommendation]


class RetrievalResult(BaseModel):
    chunk: str
    genre: str
    score: float


class RerankingResult(BaseModel):
    query: str
    genre: str
    chunk: str
    score: float


class Recommendation(BaseModel):
    name: str
    rank: int
    score: Optional[float] = None


class PipelineResult(BaseModel):
    query: str
    rewrite: Optional[QueryRewrite] = None
    retrieval_result: Optional[list[RetrievalResult]] = None
    reranking_result: Optional[list[RerankingResult]] = None
    recommendations: Optional[dict[str, Recommendation]] = None

    def to_ranking(self) -> dict[str, float]:
        if not self.recommendations:
            return {}
        return {
            genre: recommendation.score
            for genre, recommendation in self.recommendations.items()
        }

# -------------------------------- VLLM --------------------------------------
local_llm = LLM(
    model=VLLM_MODEL_NAME,
    max_model_len=VLLM_MAX_SEQ_LEN,
    gpu_memory_utilization=VLLM_GPU_MEMORY_UTILIZATION,
    hf_token=HF_TOKEN,
    enforce_eager=True,  # skip CUDA graph capture: faster startup, lower memory
    dtype=VLLM_DTYPE,
)

json_schema = StructuredQueryRewriteResponse.model_json_schema()
guided_decoding_params_json = GuidedDecodingParams(json=json_schema)
sampling_params_json = SamplingParams(
    guided_decoding=guided_decoding_params_json,
    temperature=0.7,
    top_p=0.8,
    repetition_penalty=1.05,
    max_tokens=1024,
)
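# Guided decoding constrains token sampling so the model's output always
# parses as StructuredQueryRewriteResponse JSON - no free-form text to clean up.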

vllm_system_prompt = (
    "You are a search query optimization assistant built into"
    " a music genre search engine, helping users discover novel music genres."
)
vllm_prompt = load_text_resource(Path("./resources/prompt_vllm.md"))

# -------------------------------- GEMINI ------------------------------------
gemini_config = types.GenerateContentConfig(
    response_mime_type="application/json",
    response_schema=APIGenreRecommendationResponse,
    temperature=0.7,
    max_output_tokens=1024,
    system_instruction=(
        "You are a helpful music genre recommendation assistant built into"
        " a music genre search engine, helping users discover novel music genres."
    ),
)
gemini_llm = genai.Client(
    api_key=GEMINI_API_KEY,
    http_options={"api_version": "v1alpha"},
)
gemini_prompt = load_text_resource(Path("./resources/prompt_api.md"))

# ---------------------------- EMBEDDING MODELS --------------------------------
dense_encoder = SentenceTransformer(
    model_name_or_path="mixedbread-ai/mxbai-embed-large-v1",
    device="cuda",
    model_kwargs={"torch_dtype": VLLM_DTYPE},
)
sparse_encoder = SparseTextEmbedding(model_name="Qdrant/bm25", cuda=True)
reranker = CrossEncoder(
    model_name_or_path="BAAI/bge-reranker-v2-m3",
    max_length=1024,
    device="cuda",
    model_kwargs={"torch_dtype": VLLM_DTYPE},
)
reranker_batch_size = 128
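# The dense (mxbai) and sparse (BM25) encoders must match the models the
# "dense" and "sparse" vectors of the knowledge_cards collection were indexed
# with; the cross-encoder re-scores query-chunk pairs after retrieval.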

# ---------------------------- RETRIEVAL ---------------------------------------
def run_query_rewrite(query: str) -> QueryRewrite:
    prompt = vllm_prompt.format(query=query)
    messages = [
        {"role": "system", "content": vllm_system_prompt},
        {"role": "user", "content": prompt},
    ]
    outputs = local_llm.chat(
        messages=messages,
        sampling_params=sampling_params_json,
    )
    rewrite_json = json.loads(outputs[0].outputs[0].text)
    rewrite = QueryRewrite(
        rewrites=[x for x in rewrite_json.values() if x is not None],
        structured=rewrite_json,
    )
    return rewrite
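

# Example rewrite from run_query_rewrite for "calm music with deep bass"
# (illustrative only - actual output depends on the model):
# {"general": "relaxing ambient music", "subjective": "soothing, mellow mood",
#  "purpose": "music for unwinding", "technical": "sub-bass, slow tempo",
#  "curiosity": "genres adjacent to ambient dub"}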


def prepare_queries_for_retrieval(
    query: str, rewrite: QueryRewrite
) -> list[dict[str, str | None]]:
    queries_to_retrieve = [{"text": query, "topic": None}]
    for category, rewrite_text in rewrite.structured.model_dump().items():
        if rewrite_text is None:
            continue
        # Keep a topic filter only for these rewrite categories; the rest
        # search across all topics.
        topic = category if category in ["subjective", "purpose", "technical"] else None
        queries_to_retrieve.append({"text": rewrite_text, "topic": topic})
    return queries_to_retrieve


def run_retrieval(
    queries: list[dict[str, str | None]],
) -> list[RetrievalResult]:
    queries_to_embed = [query["text"] for query in queries]
    dense_queries = list(
        dense_encoder.encode(
            queries_to_embed, convert_to_numpy=True, normalize_embeddings=True
        )
    )
    sparse_queries = list(sparse_encoder.query_embed(queries_to_embed))
    prefetches: list[qmodels.Prefetch] = []
    for query, dense_query, sparse_query in zip(queries, dense_queries, sparse_queries):
        assert dense_query is not None and sparse_query is not None
        assert isinstance(dense_query, np.ndarray) and isinstance(
            sparse_query, SparseEmbedding
        )
        topic = query.get("topic", None)
        # Build the (optional) topic filter once and share it between the
        # dense and sparse branches of the hybrid search.
        topic_filter = (
            qmodels.Filter(
                must=[
                    qmodels.FieldCondition(
                        key="topic", match=qmodels.MatchValue(value=topic)
                    )
                ]
            )
            if topic is not None
            else None
        )
        prefetches.extend(
            [
                qmodels.Prefetch(
                    query=dense_query,
                    using="dense",
                    filter=topic_filter,
                    limit=num_chunks_base,
                ),
                qmodels.Prefetch(
                    query=qmodels.SparseVector(**sparse_query.as_object()),
                    using="sparse",
                    filter=topic_filter,
                    limit=num_chunks_base,
                ),
            ]
        )
    retrieval_results = client.query_points(
        collection_name=collection_name,
        prefetch=prefetches,
        query=qmodels.FusionQuery(fusion=qmodels.Fusion.RRF),
        limit=num_chunks_base,
    )
    final_hits: list[RetrievalResult] = [
        RetrievalResult(
            chunk=hit.payload["text"], genre=hit.payload["genre"], score=hit.score
        )
        for hit in retrieval_results.points
    ]
    return final_hits
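

# run_retrieval fuses every dense/sparse candidate list with Reciprocal Rank
# Fusion, so no single retriever or rewrite dominates the merged ranking.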


def run_reranking(
    query: str, retrieval_result: list[RetrievalResult]
) -> list[RerankingResult]:
    hit_texts: list[str] = [result.chunk for result in retrieval_result]
    hit_genres: list[str] = [result.genre for result in retrieval_result]
    hit_rerank = reranker.rank(
        query=query,
        documents=hit_texts,
        batch_size=reranker_batch_size,
    )
    ranking = [
        RerankingResult(
            query=query,
            genre=hit_genres[hit["corpus_id"]],
            chunk=hit_texts[hit["corpus_id"]],
            score=hit["score"],
        )
        for hit in hit_rerank
    ]
    ranking.sort(key=lambda x: x.score, reverse=True)
    return ranking


def get_top_genres(
    df: pd.DataFrame,
    column: str,
    alpha: float = 1.0,
    # beta: float = 1.0,
    top_k: int | None = None,
) -> pd.Series:
    assert 0 <= alpha <= 1.0
    # Min-max normalization of re-ranker scores before aggregation
    task_scores = df[column]
    min_score = task_scores.min()
    max_score = task_scores.max()
    if max_score > min_score:  # Avoid division by zero
        df.loc[:, column] = (task_scores - min_score) / (max_score - min_score)
    tg_df = df.groupby("genre").agg(size=("chunk", "size"), score=(column, "sum"))
    tg_df["weighted_score"] = alpha * (tg_df["size"] / tg_df["size"].max()) + (
        1 - alpha
    ) * (tg_df["score"] / tg_df["score"].max())
    tg = tg_df.sort_values("weighted_score", ascending=False)["weighted_score"]
    if top_k:
        tg = tg.head(top_k)
    return tg
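

# Worked example for get_top_genres (illustrative numbers): at alpha=0.5, a
# genre holding the most chunks (size ratio 1.0) but only 60% of the best
# score mass gets 0.5 * 1.0 + 0.5 * 0.6 = 0.8.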


def get_recommendations(
    reranking_result: list[RerankingResult],
) -> dict[str, Recommendation]:
    ranking_df = pd.DataFrame([x.model_dump(mode="python") for x in reranking_result])
    top_genres_series = get_top_genres(
        df=ranking_df, column="score", alpha=alpha, top_k=top_k
    )
    recommendations = {
        genre: Recommendation(name=genre, rank=rank, score=score)
        for rank, (genre, score) in enumerate(
            top_genres_series.to_dict().items(), start=1
        )
    }
    return recommendations


# ----------------------- GENERATE RECOMMENDATIONS -----------------------------
def recommend_sadaimrec(query: str) -> str:
    result = PipelineResult(query=query)
    print("Running query processing...", flush=True)
    result.rewrite = run_query_rewrite(query=query)
    print(f"Rewrites:\n{result.rewrite.model_dump_json(indent=4)}")
    queries_to_retrieve = prepare_queries_for_retrieval(
        query=query, rewrite=result.rewrite
    )
    print("Running retrieval...", flush=True)
    result.retrieval_result = run_retrieval(queries_to_retrieve)
    print("Running re-ranking...", flush=True)
    result.reranking_result = run_reranking(
        query=query, retrieval_result=result.retrieval_result
    )
    print("Aggregating recommendations...", flush=True)
    result.recommendations = get_recommendations(result.reranking_result)
    return generate_recommendation_string(result.to_ranking())


def recommend_gemini(query: str) -> str:
    print("Generating recommendations using Gemini...", flush=True)
    prompt = gemini_prompt.format(query=query)
    response = gemini_llm.models.generate_content(
        model="gemini-2.0-flash",
        contents=prompt,
        config=gemini_config,
    )
    parsed_content: APIGenreRecommendationResponse = response.parsed
    parsed_content.genres.sort(key=lambda x: x.score, reverse=True)
    ranking = {x.name.lower(): x.score for x in parsed_content.genres}
    return generate_recommendation_string(ranking)


# -------------------------------------- INTERFACE -----------------------------
pipelines = {
    "sadaimrec": recommend_sadaimrec,
    "gemini": recommend_gemini,
}


def generate_responses(query):
    if not query.strip():
        raise gr.Error("Please enter a query before submitting.")
    # Randomize model order
    pipeline_names = list(pipelines.keys())
    random.shuffle(pipeline_names)
    # Generate responses
    resp1 = pipelines[pipeline_names[0]](query)
    resp2 = pipelines[pipeline_names[1]](query)
    # Return texts and hidden labels
    return resp1, resp2, pipeline_names[0], pipeline_names[1]
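

# Shuffling hides which pipeline produced which column, so each vote is a
# blind A/B comparison between sadaimrec and the Gemini baseline.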


# Callback to capture vote
def handle_vote(nickname, query, selected, label1, label2, resp1, resp2):
    if not selected:
        raise gr.Error("Please select an option before voting.")
    nick = nickname.strip() or uuid4().hex[:8]
    winner_name, loser_name = (
        (label1, label2) if selected == "Option 1 (left)" else (label2, label1)
    )
    winner_resp, loser_resp = (
        (resp1, resp2) if selected == "Option 1 (left)" else (resp2, resp1)
    )
    print(
        (
            f"User voted:\nwinner = {winner_name}: {winner_resp};"
            f" loser = {loser_name}: {loser_resp}"
        ),
        flush=True,
    )
    # ---------- persist feedback locally ----------
    entry = {
        # UTC timestamp in ISO-8601 "Z" form
        "ts": datetime.now(timezone.utc).isoformat(timespec="seconds").replace("+00:00", "Z"),
        "nickname": nick,
        "query": query,
        "winner": winner_name,
        "loser": loser_name,
        "winner_response": winner_resp,
        "loser_response": loser_resp,
    }
    # Hold the scheduler's lock so a background commit never uploads a
    # half-written line.
    with scheduler.lock:
        with FEEDBACK_FILE.open("a", encoding="utf-8") as f:
            f.write(json.dumps(entry) + "\n")
    return (
        f"Thank you for your vote! Winner: {winner_name}. Restarting in 3 seconds...",
        gr.update(active=True),
        gr.update(value=nick),
    )


def reset_ui():
    return (
        gr.update(visible=False),  # hide response row
        gr.update(value=""),  # clear query
        gr.update(visible=False),  # hide radio
        gr.update(visible=False),  # hide vote button
        gr.update(value="**Generating...**"),  # reset Option 1 text
        gr.update(value="**Generating...**"),  # reset Option 2 text
        gr.update(value=""),  # clear Model Label 1 text
        gr.update(value=""),  # clear Model Label 2 text
        gr.update(value=""),  # clear result
        gr.update(active=False),  # stop the reset timer
    )


app_description = load_text_resource(Path("./resources/description.md"))
app_instructions = load_text_resource(Path("./resources/instructions.md"))

with gr.Blocks(
    title="sadai-mrec", theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg)
) as demo:
    gr.Markdown(app_description)
    with gr.Accordion("Detailed usage instructions", open=False):
        gr.Markdown(app_instructions)
    nickname = gr.Textbox(
        label="Your nickname",
        placeholder="Leave empty to generate a random nickname on first vote within session",
    )
    query = gr.Textbox(
        label="Your Query",
        placeholder="Calming music for deep relaxation with echoing sounds and deep bass",
    )
    submit_btn = gr.Button("Submit")
    # Timer that resets the UI after feedback is sent
    reset_timer = gr.Timer(value=3.0, active=False)
    # Hidden components to store model responses and names
    with gr.Row(visible=False) as response_row:
        response_1 = gr.Markdown(value="**Generating...**", label="Option 1")
        response_2 = gr.Markdown(value="**Generating...**", label="Option 2")
    model_label_1 = gr.Textbox(visible=False)
    model_label_2 = gr.Textbox(visible=False)
    # Feedback
    vote = gr.Radio(
        ["Option 1 (left)", "Option 2 (right)"],
        label="Select Best Response",
        visible=False,
    )
    vote_btn = gr.Button("Vote", visible=False)
    result = gr.Textbox(label="Console", interactive=False)
    # On submit
    submit_btn.click(  # generate
        fn=generate_responses,
        inputs=[query],
        outputs=[response_1, response_2, model_label_1, model_label_2],
        show_progress="full",
    )
    submit_btn.click(  # update ui
        fn=lambda: (
            gr.update(visible=True),
            gr.update(visible=True),
            gr.update(visible=True),
        ),
        inputs=None,
        outputs=[response_row, vote, vote_btn],
    )
    # Feedback handling
    vote_btn.click(
        fn=handle_vote,
        inputs=[
            nickname,
            query,
            vote,
            model_label_1,
            model_label_2,
            response_1,
            response_2,
        ],
        outputs=[result, reset_timer, nickname],
    )
    reset_timer.tick(
        fn=reset_ui,
        inputs=None,
        outputs=[
            response_row,
            query,
            vote,
            vote_btn,
            response_1,
            response_2,
            model_label_1,
            model_label_2,
            result,
            reset_timer,
        ],
        trigger_mode="once",
    )
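
# default_concurrency_limit=1 serializes requests, most likely because the
# vLLM engine, the encoders, and the re-ranker all sit on a single GPU and
# concurrent pipelines would contend for its memory.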
if __name__ == "__main__":
    demo.queue(max_size=10, default_concurrency_limit=1).launch(
        server_name="0.0.0.0", server_port=7860
    )