pratyushmaini committed on
Commit 6da6bfa
1 Parent(s): e9867ef

Update app.py

Files changed (1)
  1. app.py +82 -33
app.py CHANGED
@@ -13,45 +13,91 @@ model_list = {
     "Mix IFT V2 - Score0 Only MBS16 GBS1024": "locuslab/mix_ift_v2-smollm2-360m-smollm2-360m-score0_only-300B-mbs16-gbs1024-16feb-lr2e-05-gbs16"
 }
 
+# Dictionary to track which models support chat completion vs. text generation
+model_tasks = {
+    "HuggingFaceH4/zephyr-7b-beta": "chat-completion",  # This model supports chat completion
+    # Add other models that support chat completion
+}
+# Default to text-generation for models not specified above
+
 
 def respond(message, history, system_message, max_tokens, temperature, top_p, selected_model):
     try:
-        # Create an InferenceClient for the selected model
-        client = InferenceClient(model_list.get(selected_model, "HuggingFaceH4/zephyr-7b-beta"))
-
-        # Build conversation messages for the client
-        messages = [{"role": "system", "content": system_message}]
-        for user_msg, assistant_msg in history:
-            if user_msg:  # Only add non-empty messages
-                messages.append({"role": "user", "content": user_msg})
-            if assistant_msg:  # Only add non-empty messages
-                messages.append({"role": "assistant", "content": assistant_msg})
-        messages.append({"role": "user", "content": message})
+        # Get the model ID for the selected model
+        model_id = model_list.get(selected_model, "HuggingFaceH4/zephyr-7b-beta")
 
-        response = ""
+        # Create an InferenceClient for the selected model
+        client = InferenceClient(model_id)
 
-        # Stream the response from the client
-        for token_message in client.chat_completion(
-            messages,
-            max_tokens=max_tokens,
-            stream=True,
-            temperature=temperature,
-            top_p=top_p,
-        ):
-            # Safe extraction of token with error handling
-            try:
-                token = token_message.choices[0].delta.content
-                if token is not None:  # Handle potential None values
-                    response += token
-                    yield response
-            except (AttributeError, IndexError) as e:
-                # Handle cases where token structure might be different
-                print(f"Error extracting token: {e}")
-                continue
+        # Check if the model supports chat completion
+        if model_tasks.get(model_id) == "chat-completion":
+            # Handle as chat completion
+            messages = [{"role": "system", "content": system_message}]
+            for user_msg, assistant_msg in history:
+                if user_msg:  # Only add non-empty messages
+                    messages.append({"role": "user", "content": user_msg})
+                if assistant_msg:  # Only add non-empty messages
+                    messages.append({"role": "assistant", "content": assistant_msg})
+            messages.append({"role": "user", "content": message})
+
+            response = ""
+
+            # Stream the response from the client
+            for token_message in client.chat_completion(
+                messages,
+                max_tokens=max_tokens,
+                stream=True,
+                temperature=temperature,
+                top_p=top_p,
+            ):
+                # Safe extraction of token with error handling
+                try:
+                    token = token_message.choices[0].delta.content
+                    if token is not None:  # Handle potential None values
+                        response += token
+                        yield response
+                except (AttributeError, IndexError) as e:
+                    # Handle cases where token structure might be different
+                    print(f"Error extracting token: {e}")
+                    continue
+        else:
+            # Handle as text generation for models that don't support chat completion
+            # Format the prompt manually for text generation
+            formatted_prompt = f"{system_message}\n\n"
+
+            for user_msg, assistant_msg in history:
+                if user_msg:
+                    formatted_prompt += f"User: {user_msg}\n"
+                if assistant_msg:
+                    formatted_prompt += f"Assistant: {assistant_msg}\n"
+
+            formatted_prompt += f"User: {message}\nAssistant:"
+
+            response = ""
+
+            # Use text generation instead of chat completion
+            for token in client.text_generation(
+                formatted_prompt,
+                max_new_tokens=max_tokens,
+                stream=True,
+                temperature=temperature,
+                top_p=top_p,
+            ):
+                response += token
+                yield response
+
     except Exception as e:
-        # Return error message if the model call fails
-        print(f"Error calling model API: {e}")
-        yield f"Sorry, there was an error: {str(e)}"
+        # Return detailed error message if the model call fails
+        error_message = str(e)
+        print(f"Error calling model API: {error_message}")
+
+        # Check for specific error types and give more helpful messages
+        if "Task not found" in error_message:
+            yield ("Sorry, the selected model doesn't support chat completion. "
+                   "I'm switching to text generation mode. Please try again.")
+        else:
+            yield f"Sorry, there was an error: {error_message}"
+
 
 # Custom CSS for styling
 css = """
@@ -118,6 +164,9 @@ with gr.Blocks(css=css) as demo:
         </h1>
     </div>
     """)
+
+    # Status message for API errors
+    status_message = gr.Markdown("", elem_id="status-message")
 
     with gr.Row():
         # Left sidebar: Model selector
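
The commit routes each request either through chat_completion or, for models not listed in model_tasks, through a manually formatted text_generation prompt. Below is a minimal smoke-test sketch of that routing, not part of the commit: it assumes app.py exposes respond() and model_tasks exactly as in the diff above, that importing app.py does not launch the Gradio demo, and that the message, history, and sampling values are purely illustrative.

    import app  # assumption: importing app.py builds the Blocks UI but does not call launch()

    # Models absent from model_tasks take the text-generation fallback branch,
    # so the locuslab checkpoints in model_list use the manually built prompt.
    print(app.model_tasks.get("HuggingFaceH4/zephyr-7b-beta"))  # -> "chat-completion"

    # Drive the generator the way the Gradio callback would: each yielded value
    # is the full response accumulated so far, so only the last value is kept.
    final = ""
    for partial in app.respond(
        message="What does this Space do?",
        history=[("Hi", "Hello! How can I help?")],  # (user, assistant) tuples, as in the diff
        system_message="You are a helpful assistant.",
        max_tokens=128,
        temperature=0.7,
        top_p=0.95,
        selected_model="Mix IFT V2 - Score0 Only MBS16 GBS1024",
    ):
        final = partial

    print(final)

Note that the fallback only covers task routing: a selected model with no deployed Inference API endpoint will still surface the error text yielded by the except branch of respond().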