Spaces:
Sleeping
Commit
·
5b3e92c
1
Parent(s):
c3f4b90
Increase API timeout to 60 seconds
Browse files
- demo_watermark.py +19 -10
demo_watermark.py
CHANGED
|
@@ -210,12 +210,13 @@ def load_model(args):
|
|
| 210 |
|
| 211 |
|
| 212 |
from text_generation import InferenceAPIClient
|
|
|
|
| 213 |
def generate_with_api(prompt, args):
|
| 214 |
hf_api_key = os.environ.get("HF_API_KEY")
|
| 215 |
if hf_api_key is None:
|
| 216 |
raise ValueError("HF_API_KEY environment variable not set, cannot use HF API to generate text.")
|
| 217 |
|
| 218 |
-
client = InferenceAPIClient(args.model_name_or_path, token=hf_api_key)
|
| 219 |
|
| 220 |
assert args.n_beams == 1, "HF API models do not support beam search."
|
| 221 |
generation_params = {
|
|
@@ -226,14 +227,22 @@ def generate_with_api(prompt, args):
|
|
| 226 |
generation_params["temperature"] = args.sampling_temp
|
| 227 |
generation_params["seed"] = args.generation_seed
|
| 228 |
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
return (output_text_without_watermark,
|
| 238 |
output_text_with_watermark)
|
| 239 |
|
|
@@ -737,7 +746,7 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
|
|
| 737 |
generate_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
|
| 738 |
detect_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
|
| 739 |
model_selector.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
|
| 740 |
-
# When the parameters change, display the update and fire detection, since some detection params dont change the model output.
|
| 741 |
delta.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
|
| 742 |
gamma.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
|
| 743 |
gamma.change(fn=detect_partial, inputs=[output_without_watermark,session_args,session_tokenizer], outputs=[without_watermark_detection_result,session_args,session_tokenizer])
|
|
|
|
| 210 |
|
| 211 |
|
| 212 |
from text_generation import InferenceAPIClient
|
| 213 |
+
from requests.exceptions import ReadTimeout
|
| 214 |
def generate_with_api(prompt, args):
|
| 215 |
hf_api_key = os.environ.get("HF_API_KEY")
|
| 216 |
if hf_api_key is None:
|
| 217 |
raise ValueError("HF_API_KEY environment variable not set, cannot use HF API to generate text.")
|
| 218 |
|
| 219 |
+
client = InferenceAPIClient(args.model_name_or_path, token=hf_api_key, timeout=60)
|
| 220 |
|
| 221 |
assert args.n_beams == 1, "HF API models do not support beam search."
|
| 222 |
generation_params = {
|
|
|
|
| 227 |
generation_params["temperature"] = args.sampling_temp
|
| 228 |
generation_params["seed"] = args.generation_seed
|
| 229 |
|
| 230 |
+
timeout_msg = "[Model API timeout error. Try reducing the max_new_tokens parameter or the prompt length.]"
|
| 231 |
+
try:
|
| 232 |
+
generation_params["watermark"] = False
|
| 233 |
+
output = client.generate(prompt, **generation_params)
|
| 234 |
+
output_text_without_watermark = output.generated_text
|
| 235 |
+
except ReadTimeout as e:
|
| 236 |
+
print(e)
|
| 237 |
+
output_text_without_watermark = timeout_msg
|
| 238 |
+
try:
|
| 239 |
+
generation_params["watermark"] = True
|
| 240 |
+
output = client.generate(prompt, **generation_params)
|
| 241 |
+
output_text_with_watermark = output.generated_text
|
| 242 |
+
except ReadTimeout as e:
|
| 243 |
+
print(e)
|
| 244 |
+
output_text_with_watermark = timeout_msg
|
| 245 |
+
|
| 246 |
return (output_text_without_watermark,
|
| 247 |
output_text_with_watermark)
|
| 248 |
|
|
|
|
| 746 |
generate_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
|
| 747 |
detect_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
|
| 748 |
model_selector.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
|
| 749 |
+
# When the parameters change, display the update and also fire detection, since some detection params dont change the model output.
|
| 750 |
delta.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
|
| 751 |
gamma.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
|
| 752 |
gamma.change(fn=detect_partial, inputs=[output_without_watermark,session_args,session_tokenizer], outputs=[without_watermark_detection_result,session_args,session_tokenizer])
|