Spaces:
Runtime error
Runtime error
Commit
·
9c4915e
1
Parent(s):
f598a68
Update chat.py
Browse files
chat.py
CHANGED
|
@@ -131,7 +131,7 @@ class Chat:
|
|
| 131 |
keywords = [stop_str]
|
| 132 |
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
| 133 |
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
|
| 134 |
-
|
| 135 |
max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens - num_seg_tokens - num_depth_tokens)
|
| 136 |
|
| 137 |
if max_new_tokens < 1:
|
|
@@ -159,30 +159,30 @@ class Chat:
|
|
| 159 |
yield json.dumps({"text": generated_text, "error_code": 0}).encode()
|
| 160 |
|
| 161 |
def generate_stream_gate(self, params):
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
|
| 187 |
|
| 188 |
if __name__ == "__main__":
|
|
|
|
| 131 |
keywords = [stop_str]
|
| 132 |
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
| 133 |
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
|
| 134 |
+
|
| 135 |
max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens - num_seg_tokens - num_depth_tokens)
|
| 136 |
|
| 137 |
if max_new_tokens < 1:
|
|
|
|
| 159 |
yield json.dumps({"text": generated_text, "error_code": 0}).encode()
|
| 160 |
|
| 161 |
def generate_stream_gate(self, params):
    """Relay chunks from ``self.generate_stream`` and report failures in-band.

    Every item produced by ``self.generate_stream(params)`` is yielded
    unchanged.  If iteration raises, the exception is printed and a single
    JSON-encoded ``bytes`` payload is yielded instead of propagating the
    error: ``{"text": server_error_msg, "error_code": 1}``.

    Args:
        params: Passed through untouched to ``self.generate_stream``.

    Yields:
        The items yielded by ``self.generate_stream``; after a failure, one
        encoded JSON error payload.
    """
    try:
        # Explicit loop (rather than delegating) so the except clauses
        # below guard each step of the underlying stream.
        for chunk in self.generate_stream(params):
            yield chunk
    except ValueError as err:
        print("Caught ValueError:", err)
        yield json.dumps({"text": server_error_msg, "error_code": 1}).encode()
    except torch.cuda.CudaError as err:
        print("Caught torch.cuda.CudaError:", err)
        yield json.dumps({"text": server_error_msg, "error_code": 1}).encode()
    except Exception as err:
        # Last-resort boundary: anything else is still reported to the
        # consumer as a generic server error.
        print("Caught Unknown Error", err)
        yield json.dumps({"text": server_error_msg, "error_code": 1}).encode()
|
| 186 |
|
| 187 |
|
| 188 |
if __name__ == "__main__":
|