pratyushmaini committed on
Commit c731b5a · 1 Parent(s): 6da6bfa

Update app.py

Files changed (1)
  1. app.py +72 -47
app.py CHANGED
@@ -29,74 +29,99 @@ def respond(message, history, system_message, max_tokens, temperature, top_p, se
         # Create an InferenceClient for the selected model
         client = InferenceClient(model_id)
 
-        # Check if the model supports chat completion
-        if model_tasks.get(model_id) == "chat-completion":
-            # Handle as chat completion
-            messages = [{"role": "system", "content": system_message}]
-            for user_msg, assistant_msg in history:
-                if user_msg: # Only add non-empty messages
-                    messages.append({"role": "user", "content": user_msg})
-                if assistant_msg: # Only add non-empty messages
-                    messages.append({"role": "assistant", "content": assistant_msg})
-            messages.append({"role": "user", "content": message})
-
-            response = ""
-
-            # Stream the response from the client
-            for token_message in client.chat_completion(
-                messages,
-                max_tokens=max_tokens,
-                stream=True,
-                temperature=temperature,
-                top_p=top_p,
-            ):
-                # Safe extraction of token with error handling
-                try:
-                    token = token_message.choices[0].delta.content
-                    if token is not None: # Handle potential None values
-                        response += token
-                        yield response
-                except (AttributeError, IndexError) as e:
-                    # Handle cases where token structure might be different
-                    print(f"Error extracting token: {e}")
-                    continue
-        else:
-            # Handle as text generation for models that don't support chat completion
+        # Always use text generation for locuslab models
+        if "locuslab" in model_id:
             # Format the prompt manually for text generation
-            formatted_prompt = f"{system_message}\n\n"
+            # Simple formatting that works with most models
+            formatted_prompt = ""
 
-            for user_msg, assistant_msg in history:
-                if user_msg:
-                    formatted_prompt += f"User: {user_msg}\n"
-                if assistant_msg:
-                    formatted_prompt += f"Assistant: {assistant_msg}\n"
+            # Add minimal formatting for better results with research models
+            if len(history) > 0:
+                # Include minimal context from history
+                last_exchanges = history[-1:] # Just use the last exchange
+                for user_msg, assistant_msg in last_exchanges:
+                    if user_msg:
+                        formatted_prompt += f"{user_msg}\n"
 
-            formatted_prompt += f"User: {message}\nAssistant:"
+            # Add current message - keep it simple
+            formatted_prompt += f"{message}"
 
             response = ""
 
             # Use text generation instead of chat completion
+            print(f"Using text generation with prompt: {formatted_prompt}")
             for token in client.text_generation(
                 formatted_prompt,
                 max_new_tokens=max_tokens,
                 stream=True,
                 temperature=temperature,
                 top_p=top_p,
+                do_sample=True # Enable sampling for more creative responses
             ):
                 response += token
                 yield response
+        else:
+            # Try chat completion for standard models
+            try:
+                messages = [{"role": "system", "content": system_message}]
+                for user_msg, assistant_msg in history:
+                    if user_msg: # Only add non-empty messages
+                        messages.append({"role": "user", "content": user_msg})
+                    if assistant_msg: # Only add non-empty messages
+                        messages.append({"role": "assistant", "content": assistant_msg})
+                messages.append({"role": "user", "content": message})
+
+                response = ""
+
+                # Stream the response from the client
+                for token_message in client.chat_completion(
+                    messages,
+                    max_tokens=max_tokens,
+                    stream=True,
+                    temperature=temperature,
+                    top_p=top_p,
+                ):
+                    # Safe extraction of token with error handling
+                    try:
+                        token = token_message.choices[0].delta.content
+                        if token is not None: # Handle potential None values
+                            response += token
+                            yield response
+                    except (AttributeError, IndexError) as e:
+                        # Handle cases where token structure might be different
+                        print(f"Error extracting token: {e}")
+                        continue
+            except Exception as e:
+                # If chat completion fails, fall back to text generation
+                print(f"Chat completion failed: {e}. Falling back to text generation.")
+                formatted_prompt = f"{system_message}\n\n"
+
+                for user_msg, assistant_msg in history:
+                    if user_msg:
+                        formatted_prompt += f"User: {user_msg}\n"
+                    if assistant_msg:
+                        formatted_prompt += f"Assistant: {assistant_msg}\n"
+
+                formatted_prompt += f"User: {message}\nAssistant:"
+
+                response = ""
+
+                # Use text generation instead of chat completion
+                for token in client.text_generation(
+                    formatted_prompt,
+                    max_new_tokens=max_tokens,
+                    stream=True,
+                    temperature=temperature,
+                    top_p=top_p,
+                ):
+                    response += token
+                    yield response
 
     except Exception as e:
         # Return detailed error message if the model call fails
         error_message = str(e)
         print(f"Error calling model API: {error_message}")
-
-        # Check for specific error types and give more helpful messages
-        if "Task not found" in error_message:
-            yield ("Sorry, the selected model doesn't support chat completion. "
-                   "I'm switching to text generation mode. Please try again.")
-        else:
-            yield f"Sorry, there was an error: {error_message}"
+        yield f"Error: {error_message}. Please try a different model or adjust parameters."
 
 
 # Custom CSS for styling
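
For reference: after this change, any model whose id contains "locuslab" is routed straight to InferenceClient.text_generation with a minimally formatted prompt, while other models first try chat_completion and fall back to text generation if that call fails. Below is a minimal standalone sketch of the text-generation path; it is not part of the commit, the model id is a hypothetical placeholder, and the sampling values are illustrative only.

# Sketch only: exercises the streaming text-generation path that the
# "locuslab" branch of respond() relies on.
from huggingface_hub import InferenceClient

client = InferenceClient("locuslab/example-model")  # hypothetical model id

prompt = "Hello, world"
response = ""
for token in client.text_generation(
    prompt,
    max_new_tokens=128,   # illustrative value
    stream=True,
    temperature=0.7,      # illustrative value
    top_p=0.95,           # illustrative value
    do_sample=True,
):
    response += token
    print(token, end="", flush=True)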