omer15699 commited on
Commit
c5cf83b
·
verified ·
1 Parent(s): 3814bb1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +186 -9
app.py CHANGED
@@ -1,13 +1,190 @@
 
 
1
  import gradio as gr
 
 
2
 
3
- def echo(txt):
4
- return f"📢 קיבלתי: {txt}"
 
 
5
 
6
- with gr.Blocks(title="Tweet UI - sanity") as demo:
7
- gr.Markdown("### Hello! תכתוב משהו ולחץ שלח")
8
- inp = gr.Textbox(label="טקסט")
9
- btn = gr.Button("שלח")
10
- out = gr.Textbox(label="תשובה", interactive=False)
11
- btn.click(fn=echo, inputs=inp, outputs=out)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
- demo.queue().launch()
 
1
+ # app.py
2
+ import os, re, functools, numpy as np, pandas as pd
3
  import gradio as gr
4
+ from datasets import load_dataset
5
+ from sklearn.metrics.pairwise import cosine_similarity
6
 
7
# -------- Config --------
# Number of corpus rows to download; kept small by default so CPU-only
# Spaces start quickly (override with the SAMPLE_SIZE env var).
SAMPLE_SIZE = int(os.getenv("SAMPLE_SIZE", "3000"))
# Seed shared by the dataframe shuffle and the text-generation pipeline.
RANDOM_STATE = 42
# Example text pre-filled in the input textbox.
DEFAULT_INPUT = "I am so happy with this product"
11
 
12
# -------- Helpers --------
def clean_text(text: str) -> str:
    """Normalize a tweet: lowercase, then strip URLs, @mentions, #hashtags,
    punctuation, and redundant whitespace (in that order — URLs must go
    before punctuation so `http://…` is removed whole)."""
    out = (text or "").lower()
    for pattern in (r"http\S+", r"@\w+", r"#\w+", r"[^\w\s]"):
        out = re.sub(pattern, "", out)
    return re.sub(r"\s+", " ", out).strip()
21
+
22
+ def _to_numpy(x):
23
+ try:
24
+ import torch
25
+ if hasattr(torch, "Tensor") and isinstance(x, torch.Tensor):
26
+ return x.detach().cpu().numpy()
27
+ except Exception:
28
+ pass
29
+ return np.asarray(x)
30
+
31
+ def _l2norm(x: np.ndarray) -> np.ndarray:
32
+ x = x.astype(np.float32, copy=False)
33
+ if x.ndim == 1:
34
+ x = x.reshape(1, -1)
35
+ return x / (np.linalg.norm(x, axis=1, keepdims=True) + 1e-12)
36
+
37
# -------- Load sample data once (FAST: only a slice) --------
@functools.lru_cache(maxsize=1)
def load_sample_df():
    """Download a small slice of sentiment140, filter and clean it, and cache
    the result for the life of the process.

    Only the first SAMPLE_SIZE rows are fetched instead of the full 1.6M.
    Rows with missing text/sentiment or with raw text shorter than 5 or
    longer than 280 characters are dropped, then the frame is shuffled with
    a fixed seed. Returns a DataFrame with columns ["text", "clean_text"].
    """
    raw = load_dataset("sentiment140", split=f"train[:{SAMPLE_SIZE}]").to_pandas()

    raw = raw.dropna(subset=["text", "sentiment"]).copy()
    raw["text_length"] = raw["text"].str.len()
    in_range = (raw["text_length"] >= 5) & (raw["text_length"] <= 280)
    raw = raw[in_range].copy()
    raw["clean_text"] = raw["text"].apply(clean_text)
    shuffled = raw.sample(frac=1.0, random_state=RANDOM_STATE).reset_index(drop=True)
    return shuffled[["text", "clean_text"]]
50
+
51
# -------- Lazy model loaders --------
@functools.lru_cache(maxsize=None)
def load_sentence_model(model_id: str):
    """Load a SentenceTransformer for *model_id*, memoizing one instance per id.

    The import is deferred so the app can start before the heavy dependency
    is touched.
    """
    from sentence_transformers import SentenceTransformer
    model = SentenceTransformer(model_id)
    return model
56
+
57
@functools.lru_cache(maxsize=None)
def load_generator():
    """Build (once) a distilgpt2 text-generation pipeline with a fixed seed.

    ``set_seed`` makes sampling reproducible across calls within a process;
    the transformers import is deferred until first use.
    """
    from transformers import pipeline, set_seed
    set_seed(RANDOM_STATE)
    generator = pipeline("text-generation", model="distilgpt2")
    return generator
62
+
63
+ # HF model ids
64
+ EMBEDDERS = {
65
+ "MiniLM (fast)": "sentence-transformers/all-MiniLM-L6-v2",
66
+ "MPNet (heavier)": "sentence-transformers/all-mpnet-base-v2",
67
+ "DistilRoBERTa (paraphrase)": "sentence-transformers/paraphrase-distilroberta-base-v1",
68
+ }
69
+
70
+ # Cache for corpus embeddings per model
71
+ _CORPUS_CACHE = {}
72
+
73
def _encode_norm(model, texts):
    """Encode *texts* with *model* in a way that works across
    sentence-transformers versions, returning L2-normalized float32
    numpy rows of shape (n, d)."""
    raw = model.encode(texts, show_progress_bar=False)
    return _l2norm(_to_numpy(raw))
78
+
79
def ensure_corpus_embeddings(model_name: str, texts: list):
    """Return the corpus embedding matrix for *model_name*, computing and
    caching it on first request (embedding the corpus is the slow step)."""
    cached = _CORPUS_CACHE.get(model_name)
    if cached is None:
        model = load_sentence_model(EMBEDDERS[model_name])
        cached = _encode_norm(model, texts)
        _CORPUS_CACHE[model_name] = cached
    return cached
86
+
87
# -------- Retrieval --------
def top3_for_each_model(user_input: str, selected_models: list):
    """For each selected embedder, find the 3 corpus tweets most similar to
    *user_input* and return one combined DataFrame.

    A model that raises (e.g. download failure) contributes a single error
    row instead of aborting the whole comparison.
    """
    df = load_sample_df()
    corpus = df["clean_text"].tolist()
    records = []
    for name in selected_models:
        try:
            model = load_sentence_model(EMBEDDERS[name])
            corpus_emb = ensure_corpus_embeddings(name, corpus)
            query_vec = _encode_norm(model, [clean_text(user_input)])
            sims = cosine_similarity(query_vec, corpus_emb)[0]
            # Indices of the 3 largest similarities, best first.
            best3 = sims.argsort()[-3:][::-1]
            for rank, idx in enumerate(best3, start=1):
                records.append({
                    "Model": name,
                    "Rank": rank,
                    "Similarity": float(sims[idx]),
                    "Tweet (clean)": corpus[idx],
                    "Tweet (orig)": df.loc[idx, "text"],
                })
        except Exception as exc:
            records.append({
                "Model": name, "Rank": "-", "Similarity": "-",
                "Tweet (clean)": f"[Error: {exc}]", "Tweet (orig)": ""
            })
    return pd.DataFrame(records, columns=["Model", "Rank", "Similarity", "Tweet (clean)", "Tweet (orig)"])
113
+
114
# -------- Generation + scoring (with progress) --------
def generate_and_pick_best(prompt: str, n_sequences: int, max_length: int,
                           temperature: float, scorer_model_name: str,
                           progress=gr.Progress()):
    """Generate *n_sequences* tweet candidates from *prompt* with distilgpt2,
    score each against the prompt with the chosen embedder, and return
    (best candidate, best similarity, ranked DataFrame of all candidates).

    Parameters come straight from the UI sliders/dropdown; *progress* is the
    Gradio progress tracker injected by the event handler.
    """
    progress(0.0, desc="Loading models…")
    gen = load_generator()
    scorer = load_sentence_model(EMBEDDERS[scorer_model_name])

    progress(0.3, desc="Generating candidates…")
    outputs = gen(
        prompt,
        max_new_tokens=int(max_length),       # number of NEW tokens to generate
        num_return_sequences=int(n_sequences),
        do_sample=True,
        temperature=float(temperature),
        pad_token_id=50256,                   # GPT-2 EOS id; silences the pad warning
    )
    candidates = [o["generated_text"].strip() for o in outputs]

    progress(0.7, desc="Scoring candidates…")
    q = _encode_norm(scorer, [prompt])
    cand_vecs = _encode_norm(scorer, candidates)
    sims = cosine_similarity(q, cand_vecs)[0]
    best_idx = int(sims.argmax())

    # BUG FIX: the previous table used `np.argsort(-sims) + 1` as the Rank
    # column — that is the original candidate *index* (+1) in rank order, not
    # the ranks 1..n — and sorted the Similarity and text columns by two
    # independent sorts, so tied scores could misalign rows. Derive all three
    # columns from one shared descending-similarity ordering instead.
    order = np.argsort(-sims)
    table = pd.DataFrame({
        "Rank": np.arange(1, len(candidates) + 1),
        "Similarity": sims[order],
        "Generated Tweet": [candidates[i] for i in order],
    })
    progress(1.0)
    return candidates[best_idx], float(sims[best_idx]), table
146
+
147
# ---------------- UI ----------------
with gr.Blocks(title="Sentiment140 Embeddings + Generation") as demo:
    gr.Markdown(
        """
        # 🧪 Sentiment140 — Embeddings & Tweet Generator
        Type a tweet, get similar tweets from Sentiment140, and generate a new one.
        """
    )

    # --- Retrieval section: free-text input + choice of embedders ---
    with gr.Row():
        test_input = gr.Textbox(label="Your input", value=DEFAULT_INPUT, lines=2)
        models = gr.CheckboxGroup(
            choices=list(EMBEDDERS.keys()),
            value=["MiniLM (fast)"],
            label="Embedding models to compare",
        )

    run_btn = gr.Button("🔎 Find Top‑3 Similar Tweets")
    table_out = gr.Dataframe(interactive=False)

    run_btn.click(top3_for_each_model, inputs=[test_input, models], outputs=table_out)

    gr.Markdown("---")
    gr.Markdown("## 📝 Generate Tweets and Pick the Best")

    # --- Generation section: sampling controls + scorer choice ---
    with gr.Row():
        n_seq = gr.Slider(1, 8, value=4, step=1, label="Number of candidates")
        max_len = gr.Slider(20, 80, value=40, step=1, label="Max length (new tokens)")
        temp = gr.Slider(0.7, 1.3, value=0.9, step=0.05, label="Temperature")
        scorer_model = gr.Dropdown(list(EMBEDDERS.keys()), value="MiniLM (fast)", label="Scorer embedding")

    gen_btn = gr.Button("✨ Generate & Score")
    best_txt = gr.Textbox(label="Best generated tweet")
    best_score = gr.Number(label="Similarity (best)")
    gen_table = gr.Dataframe(interactive=False)

    gen_btn.click(
        generate_and_pick_best,
        inputs=[test_input, n_seq, max_len, temp, scorer_model],
        outputs=[best_txt, best_score, gen_table],
    )

demo.queue(max_size=32).launch()
190