smirki committed on
Commit d8d468c · verified · 1 Parent(s): c2cc818

Update app.py

Files changed (1)
  1. app.py +115 -146
app.py CHANGED
@@ -1,17 +1,4 @@
import subprocess
-
- # Minimal essential installs (FlashAttention pinned version, skipping cuda build)
- subprocess.run(
-     "pip install flash-attn==2.7.0.post2 --no-build-isolation",
-     env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
-     shell=True
- )
- subprocess.run("pip install transformers 'accelerate>=0.26.0' gradio==3.30.0", shell=True)
-
- # Optional: This can boost performance on some systems.
- import torch
- torch.backends.cudnn.benchmark = True
-
import os
import re
import logging
@@ -19,103 +6,75 @@ import base64
from threading import Thread
from typing import List

+ import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

- # ----------------------------------------------------------------------
- # 1. Setup Model & Tokenizer
- # ----------------------------------------------------------------------
- model_name = "smirki/UIGEN-T1.1-Qwen-7B" # change as needed
-
+ # Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

+ # Optional: Performance boost
+ torch.backends.cudnn.benchmark = True
+
+ # Model setup
+ model_name = "smirki/UIGEN-T1.1-Qwen-7B"
+
logger.info("Loading model and tokenizer...")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
-     device_map="auto", # attempts to automatically place the model on GPU
+     device_map="auto",
    trust_remote_code=True,
)
- model.eval() # disable dropout for faster inference
+ model.eval()

tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True
)
- logger.info("Model and tokenizer loaded successfully.")

- # ----------------------------------------------------------------------
- # 2. Two-Phase Prompt Templates
- # ----------------------------------------------------------------------
+ # Prompt templates
s1_inference_prompt_think_only = """<|im_start|>user
{question}<|im_end|>
<|im_start|>assistant
<|im_start|>think
"""

- # ----------------------------------------------------------------------
- # 3. Generation Parameter Setup
- # ----------------------------------------------------------------------
+ # Constants
THINK_MAX_NEW_TOKENS = 2048
ANSWER_MAX_NEW_TOKENS = 2048

def initialize_gen_kwargs():
    return {
-         "max_new_tokens": 512, # default; updated dynamically for think/answer
+         "max_new_tokens": 512,
        "do_sample": True,
        "temperature": 0.7,
        "top_p": 0.9,
        "repetition_penalty": 1.05,
-         # "eos_token_id": model.generation_config.eos_token_id,
        "pad_token_id": tokenizer.pad_token_id,
        "use_cache": True,
-         "streamer": None, # will attach actual streamer at runtime
    }

- # ----------------------------------------------------------------------
- # 4. Helper to submit chat
- # ----------------------------------------------------------------------
- def submit_chat(chatbot, text_input):
-     if not text_input.strip():
-         return chatbot, ""
-     chatbot.append((text_input, ""))
-     logger.info(f"New chat prompt: {text_input}")
-     return chatbot, ""
-
- # ----------------------------------------------------------------------
- # 5. Artifacts Handling
- # ----------------------------------------------------------------------
def extract_html_code_block(text: str) -> str:
-     """
-     Extract the first ```html ... ``` code block (if any).
-     """
    pattern = r"```html\s*(.*?)\s*```"
    match = re.search(pattern, text, re.DOTALL)
-     if match:
-         return match.group(1).strip()
-     return text.strip()
+     return match.group(1).strip() if match else text.strip()

def send_to_sandbox(html_code: str) -> str:
-     """
-     Renders the extracted HTML in an iframe.
-     """
    encoded_html = base64.b64encode(html_code.encode("utf-8")).decode("utf-8")
    data_uri = f"data:text/html;charset=utf-8;base64,{encoded_html}"
    return f'<iframe src="{data_uri}" width="100%" height="920px"></iframe>'

- # ----------------------------------------------------------------------
- # 6. The Two-Phase Streaming Inference
- # ----------------------------------------------------------------------
- def ovis_chat(chatbot: List[List[str]]):
-     """
-     1) Think Phase: produce chain-of-thought (hidden to user).
-     2) Answer Phase: produce final user-facing answer + HTML artifact if present.
-     """
-     # Phase 1: "think" phase
-     last_query = chatbot[-1][0]
-     formatted_think_prompt = s1_inference_prompt_think_only.format(question=last_query)
+ def chat_stream(history: List[List[str]], text: str):
+     if not text.strip():
+         return history
+
+     history.append([text, ""])
+     logger.info(f"New chat prompt: {text}")

+     # Think Phase
+     formatted_think_prompt = s1_inference_prompt_think_only.format(question=text)
    input_ids_think = tokenizer.encode(formatted_think_prompt, return_tensors="pt").to(model.device)
    attention_mask_think = (input_ids_think != tokenizer.pad_token_id).to(model.device)

@@ -127,24 +86,27 @@ def ovis_chat(chatbot: List[List[str]]):
    full_think = ""
    try:
        with torch.inference_mode():
-             thread_think = Thread(
-                 target=lambda: model.generate(input_ids=input_ids_think, attention_mask=attention_mask_think, **gen_kwargs_think)
+             thread = Thread(
+                 target=lambda: model.generate(
+                     input_ids=input_ids_think,
+                     attention_mask=attention_mask_think,
+                     **gen_kwargs_think
+                 )
            )
-             thread_think.start()
-             # Streaming tokens from 'think' phase
+             thread.start()
+
            for new_text in think_streamer:
                full_think += new_text
-                 # We won't log each token to reduce overhead.
-                 # Update partial chain-of-thought display:
-                 chatbot[-1][1] = f"<|im_start|>think\n{full_think.strip()}"
-                 yield chatbot, ""
-             thread_think.join()
+                 history[-1][1] = f"<|im_start|>think\n{full_think.strip()}"
+                 yield history
+             thread.join()
    except Exception as e:
        logger.error(f"Error during think phase: {e}")
-         yield chatbot, f"Error in think phase: {str(e)}"
+         history[-1][1] = f"Error in think phase: {str(e)}"
+         yield history
        return

-     # Phase 2: "answer" phase
+     # Answer Phase
    new_prompt = (
        formatted_think_prompt
        + full_think.strip()
@@ -161,101 +123,92 @@
    full_answer = ""
    try:
        with torch.inference_mode():
-             thread_answer = Thread(
-                 target=lambda: model.generate(input_ids=input_ids_answer, attention_mask=attention_mask_answer, **gen_kwargs_answer)
+             thread = Thread(
+                 target=lambda: model.generate(
+                     input_ids=input_ids_answer,
+                     attention_mask=attention_mask_answer,
+                     **gen_kwargs_answer
+                 )
            )
-             thread_answer.start()
-             # Streaming tokens from 'answer' phase
+             thread.start()
+
            for new_text in answer_streamer:
                full_answer += new_text
-                 # For the UI, display both think + answer
                display_text = (
                    f"<|im_start|>think\n{full_think.strip()}\n\n"
                    f"<|im_start|>answer\n{full_answer.strip()}"
                )
-                 chatbot[-1][1] = display_text
-                 yield chatbot, ""
-             thread_answer.join()
+                 history[-1][1] = display_text
+                 yield history
+             thread.join()
    except Exception as e:
        logger.error(f"Error during answer phase: {e}")
-         yield chatbot, f"Error in answer phase: {str(e)}"
+         history[-1][1] = f"Error in answer phase: {str(e)}"
+         yield history
        return

-     # Finally, parse out any HTML artifact from the final answer
-     html_code = extract_html_code_block(full_answer)
-     sandbox_iframe = send_to_sandbox(html_code)
-     yield chatbot, sandbox_iframe
+ def process_artifact(history: List[List[str]]):
+     if not history or not history[-1][1]:
+         return ""
+     html_code = extract_html_code_block(history[-1][1])
+     return send_to_sandbox(html_code)

- # ----------------------------------------------------------------------
- # 7. Clearing
- # ----------------------------------------------------------------------
def clear_chat():
    return [], "", ""

- # ----------------------------------------------------------------------
- # 8. Gradio UI Setup
- # ----------------------------------------------------------------------
- css_code = """
+ # Gradio UI
+ css = """
.left_header {
-     display: flex;
-     flex-direction: column;
-     justify-content: center;
-     align-items: center;
+     display: flex;
+     flex-direction: column;
+     justify-content: center;
+     align-items: center;
}
.right_panel {
-     margin-top: 16px;
-     border: 1px solid #BFBFC4;
-     border-radius: 8px;
-     overflow: hidden;
+     margin-top: 16px;
+     border: 1px solid #BFBFC4;
+     border-radius: 8px;
+     overflow: hidden;
}
.render_header {
-     height: 30px;
-     width: 100%;
-     padding: 5px 16px;
-     background-color: #f5f5f5;
+     height: 30px;
+     width: 100%;
+     padding: 5px 16px;
+     background-color: #f5f5f5;
}
.header_btn {
-     display: inline-block;
-     height: 10px;
-     width: 10px;
-     border-radius: 50%;
-     margin-right: 4px;
- }
- .render_header > .header_btn:nth-child(1) {
-     background-color: #f5222d;
- }
- .render_header > .header_btn:nth-child(2) {
-     background-color: #faad14;
- }
- .render_header > .header_btn:nth-child(3) {
-     background-color: #52c41a;
+     display: inline-block;
+     height: 10px;
+     width: 10px;
+     border-radius: 50%;
+     margin-right: 4px;
}
+ .render_header > .header_btn:nth-child(1) { background-color: #f5222d; }
+ .render_header > .header_btn:nth-child(2) { background-color: #faad14; }
+ .render_header > .header_btn:nth-child(3) { background-color: #52c41a; }
.right_content {
-     height: 920px;
-     display: flex;
-     flex-direction: column;
-     justify-content: center;
-     align-items: center;
- }
- .html_content {
-     width: 100%;
-     height: 920px;
+     height: 920px;
+     display: flex;
+     flex-direction: column;
+     justify-content: center;
+     align-items: center;
}
+ .html_content { width: 100%; height: 920px; }
"""

- svg_content = """
+ svg_logo = """
<svg width="40" height="40" viewBox="0 0 45 45" fill="none" xmlns="http://www.w3.org/2000/svg">
-     <circle cx="22.5" cy="22.5" r="22.5" fill="#5572F9"/>
-     <path d="M22.5 11.25L26.25 16.875H18.75L22.5 11.25Z" fill="white"/>
-     <path d="M22.5 33.75L26.25 28.125H18.75L22.5 33.75Z" fill="white"/>
-     <path d="M28.125 22.5L22.5 28.125L16.875 22.5L22.5 16.875L28.125 22.5Z" fill="white"/>
+     <circle cx="22.5" cy="22.5" r="22.5" fill="#5572F9"/>
+     <path d="M22.5 11.25L26.25 16.875H18.75L22.5 11.25Z" fill="white"/>
+     <path d="M22.5 33.75L26.25 28.125H18.75L22.5 33.75Z" fill="white"/>
+     <path d="M28.125 22.5L22.5 28.125L16.875 22.5L22.5 16.875L28.125 22.5Z" fill="white"/>
</svg>
"""

- with gr.Blocks(title=model_name.split('/')[-1], css=css_code) as demo:
+ with gr.Blocks(title=model_name.split('/')[-1], css=css) as demo:
    gr.HTML(f"""
    <div class="left_header" style="margin-bottom: 20px;">
-         {svg_content}
+         {svg_logo}
        <h1>{model_name.split('/')[-1]} - Chat + Artifacts</h1>
        <p>(Two-phase chain-of-thought with artifact extraction)</p>
    </div>
@@ -277,6 +230,7 @@ with gr.Blocks(title=model_name.split('/')[-1], css=css_code) as demo:
            with gr.Row():
                submit_btn = gr.Button("Send", variant="primary")
                clear_btn = gr.Button("Clear", variant="secondary")
+
        with gr.Column(scale=6):
            gr.HTML('<div class="render_header"><span class="header_btn"></span><span class="header_btn"></span><span class="header_btn"></span></div>')
            artifact_html = gr.HTML(
@@ -284,23 +238,38 @@ with gr.Blocks(title=model_name.split('/')[-1], css=css_code) as demo:
                elem_classes="html_content"
            )

-     # Button logic
-     submit_btn.click(
-         submit_chat, [chatbot, text_input], [chatbot, text_input]
+     # Event handlers
+     text_input.submit(
+         fn=chat_stream,
+         inputs=[chatbot, text_input],
+         outputs=chatbot
+     ).then(
+         fn=lambda: "",
+         outputs=text_input
    ).then(
-         ovis_chat, [chatbot], [chatbot, artifact_html]
+         fn=process_artifact,
+         inputs=[chatbot],
+         outputs=artifact_html
    )

-     text_input.submit(
-         submit_chat, [chatbot, text_input], [chatbot, text_input]
+     submit_btn.click(
+         fn=chat_stream,
+         inputs=[chatbot, text_input],
+         outputs=chatbot
+     ).then(
+         fn=lambda: "",
+         outputs=text_input
    ).then(
-         ovis_chat, [chatbot], [chatbot, artifact_html]
+         fn=process_artifact,
+         inputs=[chatbot],
+         outputs=artifact_html
    )

    clear_btn.click(
-         clear_chat,
+         fn=clear_chat,
        outputs=[chatbot, text_input, artifact_html]
    )

- logger.info("Launching Gradio demo...")
- demo.queue(default_concurrency_count=1).launch(server_name="0.0.0.0", share=True)
+ if __name__ == "__main__":
+     logger.info("Launching Gradio demo...")
+     demo.queue(concurrency_limit=1).launch(server_name="0.0.0.0", share=True)
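
Note for reviewers: the artifact panel still renders generated HTML through send_to_sandbox, which embeds the markup as a base64 data: URI inside an iframe. A minimal standalone sketch of that technique, for illustration only (the helper name html_to_iframe below is hypothetical and not part of this commit):

```python
import base64

def html_to_iframe(html_code: str) -> str:
    # Encode the HTML so it can be embedded as a data: URI; the browser
    # then renders it as an isolated document inside the iframe.
    encoded = base64.b64encode(html_code.encode("utf-8")).decode("utf-8")
    return (
        f'<iframe src="data:text/html;charset=utf-8;base64,{encoded}" '
        'width="100%" height="920px"></iframe>'
    )

if __name__ == "__main__":
    print(html_to_iframe("<h1>Hello artifact</h1>"))
```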