Spaces:
Sleeping
Sleeping
Commit
·
a134a9d
1
Parent(s):
b5b3015
fixed args
Browse files- demo_watermark.py +16 -6
demo_watermark.py
CHANGED
|
@@ -157,7 +157,7 @@ def parse_args():
|
|
| 157 |
args = parser.parse_args()
|
| 158 |
return args
|
| 159 |
|
| 160 |
-
def load_model():
|
| 161 |
args.is_seq2seq_model = any([(model_type in args.model_name_or_path) for model_type in ["t5","T0"]])
|
| 162 |
args.is_decoder_only_model = any([(model_type in args.model_name_or_path) for model_type in ["gpt","opt","bloom"]])
|
| 163 |
if args.is_seq2seq_model:
|
|
@@ -178,7 +178,7 @@ def load_model():
|
|
| 178 |
|
| 179 |
return model, tokenizer, device
|
| 180 |
|
| 181 |
-
def generate(prompt, args, model=None, tokenizer=None):
|
| 182 |
|
| 183 |
print(f"Generating with {args}")
|
| 184 |
|
|
@@ -261,7 +261,7 @@ def detect(input_text, args, device=None, tokenizer=None):
|
|
| 261 |
|
| 262 |
def run_gradio(args, model=None, device=None, tokenizer=None):
|
| 263 |
|
| 264 |
-
generate_partial = partial(generate, model=model, tokenizer=tokenizer)
|
| 265 |
detect_partial = partial(detect, device=device, tokenizer=tokenizer)
|
| 266 |
|
| 267 |
with gr.Blocks() as demo:
|
|
@@ -447,9 +447,19 @@ def main(args):
|
|
| 447 |
print("Prompt:")
|
| 448 |
print(input_text)
|
| 449 |
|
| 450 |
-
_, _, decoded_output_without_watermark, decoded_output_with_watermark, _ = generate(input_text,
|
| 451 |
-
|
| 452 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 453 |
|
| 454 |
print("#"*term_width)
|
| 455 |
print("Output without watermark:")
|
|
|
|
| 157 |
args = parser.parse_args()
|
| 158 |
return args
|
| 159 |
|
| 160 |
+
def load_model(args):
|
| 161 |
args.is_seq2seq_model = any([(model_type in args.model_name_or_path) for model_type in ["t5","T0"]])
|
| 162 |
args.is_decoder_only_model = any([(model_type in args.model_name_or_path) for model_type in ["gpt","opt","bloom"]])
|
| 163 |
if args.is_seq2seq_model:
|
|
|
|
| 178 |
|
| 179 |
return model, tokenizer, device
|
| 180 |
|
| 181 |
+
def generate(prompt, args, model=None, device=None, tokenizer=None):
|
| 182 |
|
| 183 |
print(f"Generating with {args}")
|
| 184 |
|
|
|
|
| 261 |
|
| 262 |
def run_gradio(args, model=None, device=None, tokenizer=None):
|
| 263 |
|
| 264 |
+
generate_partial = partial(generate, model=model, device=None, tokenizer=tokenizer)
|
| 265 |
detect_partial = partial(detect, device=device, tokenizer=tokenizer)
|
| 266 |
|
| 267 |
with gr.Blocks() as demo:
|
|
|
|
| 447 |
print("Prompt:")
|
| 448 |
print(input_text)
|
| 449 |
|
| 450 |
+
_, _, decoded_output_without_watermark, decoded_output_with_watermark, _ = generate(input_text,
|
| 451 |
+
args,
|
| 452 |
+
model=model,
|
| 453 |
+
device=device,
|
| 454 |
+
tokenizer=tokenizer)
|
| 455 |
+
without_watermark_detection_result = detect(decoded_output_without_watermark,
|
| 456 |
+
args,
|
| 457 |
+
device=device,
|
| 458 |
+
tokenizer=tokenizer)
|
| 459 |
+
with_watermark_detection_result = detect(decoded_output_with_watermark,
|
| 460 |
+
args,
|
| 461 |
+
device=device,
|
| 462 |
+
tokenizer=tokenizer)
|
| 463 |
|
| 464 |
print("#"*term_width)
|
| 465 |
print("Output without watermark:")
|