Mel Seto committed on
Commit
c78b0c8
·
1 Parent(s): a40ed0a

add RAG option to app

Browse files
Files changed (2) hide show
  1. requirements.txt +0 -133
  2. src/app.py +34 -6
requirements.txt CHANGED
@@ -1,9 +1,5 @@
1
  # This file was autogenerated by uv via the following command:
2
- <<<<<<< HEAD
3
- # uv pip compile pyproject.toml -o requirements.txt
4
- =======
5
  # uv export --no-hashes --format requirements-txt
6
- >>>>>>> 660f6fb (organizing files into src folder etc.)
7
  aiofiles==24.1.0
8
  # via gradio
9
  aiohappyeyeballs==2.6.1
@@ -20,14 +16,6 @@ anyio==4.10.0
20
  # gradio
21
  # httpx
22
  # starlette
23
- <<<<<<< HEAD
24
- attrs==25.3.0
25
- # via aiohttp
26
- brotli==1.1.0
27
- # via gradio
28
- cerebras-cloud-sdk==1.50.1
29
- # via chinese-idioms (pyproject.toml)
30
- =======
31
  # watchfiles
32
  astroid==3.3.11
33
  # via pylint
@@ -40,7 +28,6 @@ brotli==1.1.0
40
  # via gradio
41
  cerebras-cloud-sdk==1.50.1
42
  # via chinese-idioms
43
- >>>>>>> 660f6fb (organizing files into src folder etc.)
44
  certifi==2025.8.3
45
  # via
46
  # httpcore
@@ -50,12 +37,6 @@ charset-normalizer==3.4.3
50
  # via requests
51
  click==8.2.1
52
  # via
53
- <<<<<<< HEAD
54
- # typer
55
- # uvicorn
56
- datasets==4.1.0
57
- # via chinese-idioms (pyproject.toml)
58
- =======
59
  # black
60
  # typer
61
  # uvicorn
@@ -67,21 +48,14 @@ colorama==0.4.6 ; sys_platform == 'win32'
67
  # tqdm
68
  datasets==4.1.0
69
  # via chinese-idioms
70
- >>>>>>> 660f6fb (organizing files into src folder etc.)
71
  dill==0.4.0
72
  # via
73
  # datasets
74
  # multiprocess
75
- <<<<<<< HEAD
76
- distro==1.9.0
77
- # via cerebras-cloud-sdk
78
- fastapi==0.116.2
79
- =======
80
  # pylint
81
  distro==1.9.0
82
  # via cerebras-cloud-sdk
83
  fastapi==0.116.1
84
- >>>>>>> 660f6fb (organizing files into src folder etc.)
85
  # via gradio
86
  ffmpy==0.6.1
87
  # via gradio
@@ -89,34 +63,21 @@ filelock==3.19.1
89
  # via
90
  # datasets
91
  # huggingface-hub
92
- <<<<<<< HEAD
93
- =======
94
  # torch
95
  # transformers
96
- >>>>>>> 660f6fb (organizing files into src folder etc.)
97
  frozenlist==1.7.0
98
  # via
99
  # aiohttp
100
  # aiosignal
101
- <<<<<<< HEAD
102
- fsspec==2025.9.0
103
- =======
104
  fsspec==2025.7.0
105
- >>>>>>> 660f6fb (organizing files into src folder etc.)
106
  # via
107
  # datasets
108
  # gradio-client
109
  # huggingface-hub
110
- <<<<<<< HEAD
111
- gradio==5.46.0
112
- # via chinese-idioms (pyproject.toml)
113
- gradio-client==1.13.0
114
- =======
115
  # torch
116
  gradio==5.44.0
117
  # via chinese-idioms
118
  gradio-client==1.12.1
119
- >>>>>>> 660f6fb (organizing files into src folder etc.)
120
  # via gradio
121
  groovy==0.1.2
122
  # via gradio
@@ -124,11 +85,7 @@ h11==0.16.0
124
  # via
125
  # httpcore
126
  # uvicorn
127
- <<<<<<< HEAD
128
- hf-xet==1.1.10
129
- =======
130
  hf-xet==1.1.8 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
131
- >>>>>>> 660f6fb (organizing files into src folder etc.)
132
  # via huggingface-hub
133
  httpcore==1.0.9
134
  # via httpx
@@ -139,32 +96,20 @@ httpx==0.28.1
139
  # gradio-client
140
  # ollama
141
  # safehttpx
142
- <<<<<<< HEAD
143
- huggingface-hub==0.35.0
144
- =======
145
  huggingface-hub==0.34.4
146
- >>>>>>> 660f6fb (organizing files into src folder etc.)
147
  # via
148
  # datasets
149
  # gradio
150
  # gradio-client
151
- <<<<<<< HEAD
152
- =======
153
  # sentence-transformers
154
  # tokenizers
155
  # transformers
156
- >>>>>>> 660f6fb (organizing files into src folder etc.)
157
  idna==3.10
158
  # via
159
  # anyio
160
  # httpx
161
  # requests
162
  # yarl
163
- <<<<<<< HEAD
164
- jinja2==3.1.6
165
- # via gradio
166
- markdown-it-py==4.0.0
167
- =======
168
  iniconfig==2.1.0
169
  # via pytest
170
  isort==6.0.1
@@ -176,38 +121,23 @@ jinja2==3.1.6
176
  joblib==1.5.2
177
  # via scikit-learn
178
  markdown-it-py==4.0.0 ; sys_platform != 'emscripten'
179
- >>>>>>> 660f6fb (organizing files into src folder etc.)
180
  # via rich
181
  markupsafe==3.0.2
182
  # via
183
  # gradio
184
  # jinja2
185
- <<<<<<< HEAD
186
- mdurl==0.1.2
187
- # via markdown-it-py
188
- =======
189
  mccabe==0.7.0
190
  # via pylint
191
  mdurl==0.1.2 ; sys_platform != 'emscripten'
192
  # via markdown-it-py
193
  mpmath==1.3.0
194
  # via sympy
195
- >>>>>>> 660f6fb (organizing files into src folder etc.)
196
  multidict==6.6.4
197
  # via
198
  # aiohttp
199
  # yarl
200
  multiprocess==0.70.16
201
  # via datasets
202
- <<<<<<< HEAD
203
- numpy >= 2.0, < 3.0
204
- # via
205
- # datasets
206
- # gradio
207
- # pandas
208
- ollama==0.5.4
209
- # via chinese-idioms (pyproject.toml)
210
- =======
211
  mypy-extensions==1.1.0
212
  # via black
213
  networkx==3.5
@@ -260,32 +190,21 @@ nvidia-nvtx-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'li
260
  # via torch
261
  ollama==0.5.3
262
  # via chinese-idioms
263
- >>>>>>> 660f6fb (organizing files into src folder etc.)
264
  orjson==3.11.3
265
  # via gradio
266
  packaging==25.0
267
  # via
268
- <<<<<<< HEAD
269
- =======
270
  # black
271
- >>>>>>> 660f6fb (organizing files into src folder etc.)
272
  # datasets
273
  # gradio
274
  # gradio-client
275
  # huggingface-hub
276
- <<<<<<< HEAD
277
- =======
278
  # pytest
279
  # transformers
280
- >>>>>>> 660f6fb (organizing files into src folder etc.)
281
  pandas==2.3.2
282
  # via
283
  # datasets
284
  # gradio
285
- <<<<<<< HEAD
286
- pillow==11.3.0
287
- # via gradio
288
- =======
289
  pathspec==0.12.1
290
  # via black
291
  pillow==11.3.0
@@ -298,7 +217,6 @@ platformdirs==4.4.0
298
  # pylint
299
  pluggy==1.6.0
300
  # via pytest
301
- >>>>>>> 660f6fb (organizing files into src folder etc.)
302
  propcache==0.3.2
303
  # via
304
  # aiohttp
@@ -306,13 +224,8 @@ propcache==0.3.2
306
  pyarrow==21.0.0
307
  # via datasets
308
  pycccedict==1.2.0
309
- <<<<<<< HEAD
310
- # via chinese-idioms (pyproject.toml)
311
- pydantic==2.11.9
312
- =======
313
  # via chinese-idioms
314
  pydantic==2.11.7
315
- >>>>>>> 660f6fb (organizing files into src folder etc.)
316
  # via
317
  # cerebras-cloud-sdk
318
  # fastapi
@@ -323,13 +236,6 @@ pydantic-core==2.33.2
323
  pydub==0.25.1
324
  # via gradio
325
  pygments==2.19.2
326
- <<<<<<< HEAD
327
- # via rich
328
- pypinyin==0.55.0
329
- # via chinese-idioms (pyproject.toml)
330
- python-dateutil==2.9.0.post0
331
- # via pandas
332
- =======
333
  # via
334
  # pytest
335
  # rich
@@ -340,7 +246,6 @@ pytest==8.4.2
340
  python-dateutil==2.9.0.post0
341
  # via pandas
342
  python-dotenv==1.1.1
343
- >>>>>>> 660f6fb (organizing files into src folder etc.)
344
  python-multipart==0.0.20
345
  # via gradio
346
  pytz==2025.2
@@ -350,27 +255,13 @@ pyyaml==6.0.2
350
  # datasets
351
  # gradio
352
  # huggingface-hub
353
- <<<<<<< HEAD
354
- =======
355
  # transformers
356
  regex==2025.9.18
357
  # via transformers
358
- >>>>>>> 660f6fb (organizing files into src folder etc.)
359
  requests==2.32.5
360
  # via
361
  # datasets
362
  # huggingface-hub
363
- <<<<<<< HEAD
364
- rich==14.1.0
365
- # via typer
366
- ruff==0.13.0
367
- # via gradio
368
- safehttpx==0.1.6
369
- # via gradio
370
- semantic-version==2.10.0
371
- # via gradio
372
- shellingham==1.5.4
373
- =======
374
  # transformers
375
  rich==14.1.0 ; sys_platform != 'emscripten'
376
  # via typer
@@ -395,7 +286,6 @@ setuptools==80.9.0
395
  # torch
396
  # triton
397
  shellingham==1.5.4 ; sys_platform != 'emscripten'
398
- >>>>>>> 660f6fb (organizing files into src folder etc.)
399
  # via typer
400
  six==1.17.0
401
  # via python-dateutil
@@ -403,14 +293,6 @@ sniffio==1.3.1
403
  # via
404
  # anyio
405
  # cerebras-cloud-sdk
406
- <<<<<<< HEAD
407
- starlette==0.48.0
408
- # via
409
- # fastapi
410
- # gradio
411
- tomlkit==0.13.3
412
- # via gradio
413
- =======
414
  starlette==0.47.3
415
  # via
416
  # fastapi
@@ -427,14 +309,10 @@ tomlkit==0.13.3
427
  # pylint
428
  torch==2.8.0
429
  # via sentence-transformers
430
- >>>>>>> 660f6fb (organizing files into src folder etc.)
431
  tqdm==4.67.1
432
  # via
433
  # datasets
434
  # huggingface-hub
435
- <<<<<<< HEAD
436
- typer==0.17.4
437
- =======
438
  # sentence-transformers
439
  # transformers
440
  transformers==4.56.2
@@ -442,7 +320,6 @@ transformers==4.56.2
442
  triton==3.4.0 ; platform_machine == 'x86_64' and sys_platform == 'linux'
443
  # via torch
444
  typer==0.16.1 ; sys_platform != 'emscripten'
445
- >>>>>>> 660f6fb (organizing files into src folder etc.)
446
  # via gradio
447
  typing-extensions==4.15.0
448
  # via
@@ -455,13 +332,9 @@ typing-extensions==4.15.0
455
  # huggingface-hub
456
  # pydantic
457
  # pydantic-core
458
- <<<<<<< HEAD
459
- # starlette
460
- =======
461
  # sentence-transformers
462
  # starlette
463
  # torch
464
- >>>>>>> 660f6fb (organizing files into src folder etc.)
465
  # typer
466
  # typing-inspection
467
  typing-inspection==0.4.1
@@ -469,18 +342,12 @@ typing-inspection==0.4.1
469
  tzdata==2025.2
470
  # via pandas
471
  urllib3==2.5.0
472
- <<<<<<< HEAD
473
- # via requests
474
- uvicorn==0.35.0
475
- # via gradio
476
- =======
477
  # via
478
  # gradio
479
  # requests
480
  uvicorn==0.35.0 ; sys_platform != 'emscripten'
481
  # via gradio
482
  watchfiles==1.1.0
483
- >>>>>>> 660f6fb (organizing files into src folder etc.)
484
  websockets==15.0.1
485
  # via gradio-client
486
  xxhash==3.5.0
 
1
  # This file was autogenerated by uv via the following command:
 
 
 
2
  # uv export --no-hashes --format requirements-txt
 
3
  aiofiles==24.1.0
4
  # via gradio
5
  aiohappyeyeballs==2.6.1
 
16
  # gradio
17
  # httpx
18
  # starlette
 
 
 
 
 
 
 
 
19
  # watchfiles
20
  astroid==3.3.11
21
  # via pylint
 
28
  # via gradio
29
  cerebras-cloud-sdk==1.50.1
30
  # via chinese-idioms
 
31
  certifi==2025.8.3
32
  # via
33
  # httpcore
 
37
  # via requests
38
  click==8.2.1
39
  # via
 
 
 
 
 
 
40
  # black
41
  # typer
42
  # uvicorn
 
48
  # tqdm
49
  datasets==4.1.0
50
  # via chinese-idioms
 
51
  dill==0.4.0
52
  # via
53
  # datasets
54
  # multiprocess
 
 
 
 
 
55
  # pylint
56
  distro==1.9.0
57
  # via cerebras-cloud-sdk
58
  fastapi==0.116.1
 
59
  # via gradio
60
  ffmpy==0.6.1
61
  # via gradio
 
63
  # via
64
  # datasets
65
  # huggingface-hub
 
 
66
  # torch
67
  # transformers
 
68
  frozenlist==1.7.0
69
  # via
70
  # aiohttp
71
  # aiosignal
 
 
 
72
  fsspec==2025.7.0
 
73
  # via
74
  # datasets
75
  # gradio-client
76
  # huggingface-hub
 
 
 
 
 
77
  # torch
78
  gradio==5.44.0
79
  # via chinese-idioms
80
  gradio-client==1.12.1
 
81
  # via gradio
82
  groovy==0.1.2
83
  # via gradio
 
85
  # via
86
  # httpcore
87
  # uvicorn
 
 
 
88
  hf-xet==1.1.8 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
 
89
  # via huggingface-hub
90
  httpcore==1.0.9
91
  # via httpx
 
96
  # gradio-client
97
  # ollama
98
  # safehttpx
 
 
 
99
  huggingface-hub==0.34.4
 
100
  # via
101
  # datasets
102
  # gradio
103
  # gradio-client
 
 
104
  # sentence-transformers
105
  # tokenizers
106
  # transformers
 
107
  idna==3.10
108
  # via
109
  # anyio
110
  # httpx
111
  # requests
112
  # yarl
 
 
 
 
 
113
  iniconfig==2.1.0
114
  # via pytest
115
  isort==6.0.1
 
121
  joblib==1.5.2
122
  # via scikit-learn
123
  markdown-it-py==4.0.0 ; sys_platform != 'emscripten'
 
124
  # via rich
125
  markupsafe==3.0.2
126
  # via
127
  # gradio
128
  # jinja2
 
 
 
 
129
  mccabe==0.7.0
130
  # via pylint
131
  mdurl==0.1.2 ; sys_platform != 'emscripten'
132
  # via markdown-it-py
133
  mpmath==1.3.0
134
  # via sympy
 
135
  multidict==6.6.4
136
  # via
137
  # aiohttp
138
  # yarl
139
  multiprocess==0.70.16
140
  # via datasets
 
 
 
 
 
 
 
 
 
141
  mypy-extensions==1.1.0
142
  # via black
143
  networkx==3.5
 
190
  # via torch
191
  ollama==0.5.3
192
  # via chinese-idioms
 
193
  orjson==3.11.3
194
  # via gradio
195
  packaging==25.0
196
  # via
 
 
197
  # black
 
198
  # datasets
199
  # gradio
200
  # gradio-client
201
  # huggingface-hub
 
 
202
  # pytest
203
  # transformers
 
204
  pandas==2.3.2
205
  # via
206
  # datasets
207
  # gradio
 
 
 
 
208
  pathspec==0.12.1
209
  # via black
210
  pillow==11.3.0
 
217
  # pylint
218
  pluggy==1.6.0
219
  # via pytest
 
220
  propcache==0.3.2
221
  # via
222
  # aiohttp
 
224
  pyarrow==21.0.0
225
  # via datasets
226
  pycccedict==1.2.0
 
 
 
 
227
  # via chinese-idioms
228
  pydantic==2.11.7
 
229
  # via
230
  # cerebras-cloud-sdk
231
  # fastapi
 
236
  pydub==0.25.1
237
  # via gradio
238
  pygments==2.19.2
 
 
 
 
 
 
 
239
  # via
240
  # pytest
241
  # rich
 
246
  python-dateutil==2.9.0.post0
247
  # via pandas
248
  python-dotenv==1.1.1
 
249
  python-multipart==0.0.20
250
  # via gradio
251
  pytz==2025.2
 
255
  # datasets
256
  # gradio
257
  # huggingface-hub
 
 
258
  # transformers
259
  regex==2025.9.18
260
  # via transformers
 
261
  requests==2.32.5
262
  # via
263
  # datasets
264
  # huggingface-hub
 
 
 
 
 
 
 
 
 
 
 
265
  # transformers
266
  rich==14.1.0 ; sys_platform != 'emscripten'
267
  # via typer
 
286
  # torch
287
  # triton
288
  shellingham==1.5.4 ; sys_platform != 'emscripten'
 
289
  # via typer
290
  six==1.17.0
291
  # via python-dateutil
 
293
  # via
294
  # anyio
295
  # cerebras-cloud-sdk
 
 
 
 
 
 
 
 
296
  starlette==0.47.3
297
  # via
298
  # fastapi
 
309
  # pylint
310
  torch==2.8.0
311
  # via sentence-transformers
 
312
  tqdm==4.67.1
313
  # via
314
  # datasets
315
  # huggingface-hub
 
 
 
316
  # sentence-transformers
317
  # transformers
318
  transformers==4.56.2
 
320
  triton==3.4.0 ; platform_machine == 'x86_64' and sys_platform == 'linux'
321
  # via torch
322
  typer==0.16.1 ; sys_platform != 'emscripten'
 
323
  # via gradio
324
  typing-extensions==4.15.0
325
  # via
 
332
  # huggingface-hub
333
  # pydantic
334
  # pydantic-core
 
 
 
335
  # sentence-transformers
336
  # starlette
337
  # torch
 
338
  # typer
339
  # typing-inspection
340
  typing-inspection==0.4.1
 
342
  tzdata==2025.2
343
  # via pandas
344
  urllib3==2.5.0
 
 
 
 
 
345
  # via
346
  # gradio
347
  # requests
348
  uvicorn==0.35.0 ; sys_platform != 'emscripten'
349
  # via gradio
350
  watchfiles==1.1.0
 
351
  websockets==15.0.1
352
  # via gradio-client
353
  xxhash==3.5.0
src/app.py CHANGED
@@ -5,6 +5,7 @@ import gradio as gr
5
  from cerebras.cloud.sdk import Cerebras
6
  from dotenv import load_dotenv
7
 
 
8
  from utils.utils import get_pinyin
9
 
10
  # ======================
@@ -92,11 +93,30 @@ Answer:"""
92
  # ======================
93
  # UI Wrapper
94
  # ======================
95
- def update_ui(situation):
96
- if USE_MOCK:
97
- idiom, explanation = generate_idiom_mock()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  else:
99
- idiom, explanation = generate_idiom(situation)
 
100
 
101
  return (
102
  f"<div class='idiom-output'>{idiom}</div>",
@@ -108,9 +128,9 @@ def update_ui(situation):
108
  # Launch app
109
  # ======================
110
  def launch_app():
 
111
  with gr.Blocks(css="style.css") as demo:
112
  gr.Markdown("# 🎋 Chinese Idiom Finder")
113
-
114
  with gr.Row():
115
  with gr.Column():
116
  situation = gr.Textbox(
@@ -118,6 +138,11 @@ def launch_app():
118
  lines=2,
119
  placeholder="e.g., When facing a big challenge",
120
  )
 
 
 
 
 
121
  generate_btn = gr.Button("✨ Find Idiom")
122
 
123
  # ✅ Example situations
@@ -138,9 +163,12 @@ def launch_app():
138
 
139
  # pylint: disable=no-member
140
  generate_btn.click(
141
- fn=update_ui, inputs=situation, outputs=[idiom_output, explanation_output]
 
 
142
  )
143
 
 
144
  demo.launch()
145
 
146
 
 
5
  from cerebras.cloud.sdk import Cerebras
6
  from dotenv import load_dotenv
7
 
8
+ from retrieval.retriever import retrieve_idiom
9
  from utils.utils import get_pinyin
10
 
11
  # ======================
 
93
  # ======================
94
  # UI Wrapper
95
  # ======================
96
+ def update_ui(situation, mode):
97
+ if mode == "LLM":
98
+ if USE_MOCK:
99
+ idiom, explanation = generate_idiom_mock()
100
+ else:
101
+ idiom, explanation = generate_idiom(situation)
102
+ elif mode == "RAG":
103
+ top_idioms = retrieve_idiom(situation, top_k=3)
104
+ formatted_idioms = []
105
+ for idiom_entry in top_idioms:
106
+ # Split "<Chinese>: <English>" format
107
+ if ": " in idiom_entry:
108
+ chinese, english = idiom_entry.split(": ", 1)
109
+ else:
110
+ chinese, english = idiom_entry, ""
111
+ pinyin_text = get_pinyin(chinese)
112
+ formatted_idioms.append(f"<div class='idiom-entry'><b>{chinese}</b><br>{pinyin_text}<br>{english}</div>")
113
+
114
+ # Combine all entries with horizontal separators
115
+ idiom = "<hr>".join(formatted_idioms)
116
+ explanation = "Retrieved using embeddings (RAG)."
117
  else:
118
+ idiom = "Unknown mode"
119
+ explanation = ""
120
 
121
  return (
122
  f"<div class='idiom-output'>{idiom}</div>",
 
128
  # Launch app
129
  # ======================
130
  def launch_app():
131
+
132
  with gr.Blocks(css="style.css") as demo:
133
  gr.Markdown("# 🎋 Chinese Idiom Finder")
 
134
  with gr.Row():
135
  with gr.Column():
136
  situation = gr.Textbox(
 
138
  lines=2,
139
  placeholder="e.g., When facing a big challenge",
140
  )
141
+ mode_dropdown = gr.Dropdown(
142
+ ["LLM", "RAG"],
143
+ label="Mode",
144
+ value="LLM",
145
+ )
146
  generate_btn = gr.Button("✨ Find Idiom")
147
 
148
  # ✅ Example situations
 
163
 
164
  # pylint: disable=no-member
165
  generate_btn.click(
166
+ fn=update_ui,
167
+ inputs=[situation, mode_dropdown],
168
+ outputs=[idiom_output, explanation_output],
169
  )
170
 
171
+
172
  demo.launch()
173
 
174