Commit · 34e65b0
1 Parent(s): 5ac7b67

limit to one small model that fits in 24gb vram

Files changed:
- app.py (+9 -9)
- demo_watermark.py (+25 -17)
app.py
CHANGED
@@ -24,17 +24,17 @@ arg_dict = {
     # 'model_name_or_path': 'facebook/opt-2.7b', # historical
     # 'model_name_or_path': 'facebook/opt-6.7b', # historical
     # 'model_name_or_path': 'meta-llama/Llama-2-7b-hf', # historical
-    'model_name_or_path': 'meta-llama/Llama-3.
+    'model_name_or_path': 'meta-llama/Llama-3.2-3B',
     'all_models':[
-        "meta-llama/Llama-3.1-8B",
+        # "meta-llama/Llama-3.1-8B", # too big for the A10G 24GB
         "meta-llama/Llama-3.2-3B",
-        "meta-llama/Llama-3.2-1B",
-        "Qwen/Qwen3-8B",
-        "Qwen/Qwen3-4B",
-        "Qwen/Qwen3-1.7B",
-        "Qwen/Qwen3-0.6B",
-        "Qwen/Qwen3-4B-Instruct-2507",
-        "Qwen/Qwen3-4B-Thinking-2507",
+        # "meta-llama/Llama-3.2-1B",
+        # "Qwen/Qwen3-8B", # too big for the A10G 24GB
+        # "Qwen/Qwen3-4B",
+        # "Qwen/Qwen3-1.7B",
+        # "Qwen/Qwen3-0.6B",
+        # "Qwen/Qwen3-4B-Instruct-2507",
+        # "Qwen/Qwen3-4B-Thinking-2507",
     ],
     # 'load_fp16' : True,
     'load_fp16' : False,
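The commit message points at the Space's 24 GB A10G. As a rough, hypothetical back-of-envelope check (not part of the commit), weight memory alone explains the trimmed list when 'load_fp16' stays False and weights load in full precision:

# Hypothetical sketch, not code from this repo: approximate GiB needed just to hold
# model weights, using published parameter counts and ignoring KV cache, activations,
# and CUDA context overhead.

def weight_gib(params_billion: float, bytes_per_param: int) -> float:
    return params_billion * 1e9 * bytes_per_param / 2**30

candidates = {                      # parameter counts in billions (approximate)
    "meta-llama/Llama-3.1-8B": 8.0,
    "meta-llama/Llama-3.2-3B": 3.2,
    "Qwen/Qwen3-8B": 8.2,
}

for name, size_b in candidates.items():
    print(f"{name}: ~{weight_gib(size_b, 4):.0f} GiB fp32, ~{weight_gib(size_b, 2):.0f} GiB fp16")

On that estimate an 8B checkpoint needs roughly 30 GiB of fp32 weights, already past 24 GB before any generation state, while the 3B checkpoint sits near 12 GiB with headroom to spare, which matches keeping only meta-llama/Llama-3.2-3B enabled.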
demo_watermark.py
CHANGED
@@ -19,6 +19,8 @@ import argparse
 from pprint import pprint
 from functools import partial
 
+import gc
+
 import numpy # for gradio hot reload
 import gradio as gr
 
@@ -206,9 +208,11 @@ def load_model(args):
         model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name_or_path)
     elif args.is_decoder_only_model:
         if args.load_fp16:
-            model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,torch_dtype=torch.float16, device_map='auto')
+            # model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,torch_dtype=torch.float16, device_map='auto')
+            model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,torch_dtype=torch.float16)
         elif args.load_bf16:
-            model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,torch_dtype=torch.bfloat16, device_map='auto')
+            # model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,torch_dtype=torch.bfloat16, device_map='auto')
+            model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,torch_dtype=torch.bfloat16)
         else:
             model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path)
     else:
@@ -216,12 +220,18 @@ def load_model(args):
 
     if args.use_gpu:
         device = "cuda" if torch.cuda.is_available() else "cpu"
-        if args.load_fp16 or args.load_bf16:
-            pass
-        else:
-            model = model.to(device)
+        # if args.load_fp16 or args.load_bf16:
+        #     pass
+        # else:
+        model = model.to(device)
     else:
         device = "cpu"
+
+    if args.load_bf16:
+        model = model.to(torch.bfloat16)
+    if args.load_fp16:
+        model = model.to(torch.float16)
+
     model.eval()
 
     tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
@@ -268,7 +278,7 @@ def generate_with_api(prompt, args):
     yield all_without_words, all_with_words
 
 
-def check_prompt(prompt, args, tokenizer, model, device=None):
+def check_prompt(prompt, args, tokenizer, model=None, device=None):
 
     # This applies to both the local and API model scenarios
     if args.model_name_or_path in API_MODEL_MAP:
@@ -288,7 +298,7 @@ def check_prompt(prompt, args, tokenizer, model, device=None):
 
 
 
-def generate(prompt, args, tokenizer, model, device=None):
+def generate(prompt, args, tokenizer, model=None, device=None):
     """Instatiate the WatermarkLogitsProcessor according to the watermark parameters
     and generate watermarked text by passing it to the generate method of the model
     as a logits processor. """
@@ -486,11 +496,10 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
         default_prompt = args.__dict__.pop("default_prompt")
         session_args = gr.State(value=args)
         # note that state obj automatically calls value if it's a callable, want to avoid calling tokenizer at startup
-        session_tokenizer = gr.State(value=lambda : tokenizer)
-        session_model = gr.State(value=lambda : model)
+        session_tokenizer = gr.State(value=lambda : tokenizer)
 
-        check_prompt_partial = partial(check_prompt, device=device)
-        generate_partial = partial(generate, device=device)
+        check_prompt_partial = partial(check_prompt, model=model, device=device)
+        generate_partial = partial(generate, model=model, device=device)
         detect_partial = partial(detect, device=device)
 
         with gr.Tab("Welcome"):
@@ -704,8 +713,8 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
         """)
 
         # Register main generation tab click, outputing generations as well as a the encoded+redecoded+potentially truncated prompt and flag, then call detection
-        generate_btn.click(fn=check_prompt_partial, inputs=[prompt,session_args,session_tokenizer
-            fn=generate_partial, inputs=[redecoded_input,session_args,session_tokenizer
+        generate_btn.click(fn=check_prompt_partial, inputs=[prompt,session_args,session_tokenizer], outputs=[redecoded_input, truncation_warning, session_args]).success(
+            fn=generate_partial, inputs=[redecoded_input,session_args,session_tokenizer], outputs=[output_without_watermark, output_with_watermark]).success(
             fn=detect_partial, inputs=[output_without_watermark,session_args,session_tokenizer], outputs=[without_watermark_detection_result,session_args,session_tokenizer,html_without_watermark]).success(
             fn=detect_partial, inputs=[output_with_watermark,session_args,session_tokenizer], outputs=[with_watermark_detection_result,session_args,session_tokenizer,html_with_watermark])
         # Show truncated version of prompt if truncation occurred
@@ -781,6 +790,7 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
         def update_model(state, old_model):
             del old_model
             torch.cuda.empty_cache()
+            gc.collect()
             model, _, _ = load_model(state)
             return model
 
@@ -803,8 +813,6 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
             update_model_state,inputs=[session_args, model_selector], outputs=[session_args]
         ).then(
             update_tokenizer,inputs=[model_selector], outputs=[session_tokenizer]
-        ).then(
-            update_model,inputs=[session_args, session_model], outputs=[session_model]
         ).then(
             lambda value: str(value), inputs=[session_args], outputs=[current_parameters]
         )
@@ -852,7 +860,7 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
         select_green_tokens.change(fn=detect_partial, inputs=[output_with_watermark,session_args,session_tokenizer], outputs=[with_watermark_detection_result,session_args,session_tokenizer,html_with_watermark])
         select_green_tokens.change(fn=detect_partial, inputs=[detection_input,session_args,session_tokenizer], outputs=[detection_result,session_args,session_tokenizer,html_detection_input])
 
-
+
     demo.queue()
 
     if args.demo_public:
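Taken together, the demo_watermark.py changes switch from device_map='auto' to an explicit dtype plus a single .to(device) move, and make update_model run gc.collect() alongside torch.cuda.empty_cache() before the next checkpoint is loaded. A minimal, self-contained sketch of that load-and-swap pattern (load_small_model and swap_model are illustrative names, not functions from this repo):

import gc

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def load_small_model(name: str, device: str = "cuda", load_bf16: bool = True):
    # Explicit dtype instead of device_map='auto': the whole model lands on one GPU,
    # which is enough once every allowed checkpoint is known to fit in 24 GB.
    dtype = torch.bfloat16 if load_bf16 else torch.float32
    model = AutoModelForCausalLM.from_pretrained(name, torch_dtype=dtype)
    model = model.to(device)
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(name)
    return model, tokenizer

def swap_model(old_model, new_name: str, device: str = "cuda"):
    # Mirrors the intent of update_model in the diff: release the old weights first
    # so old and new checkpoints never have to coexist in VRAM.
    del old_model
    gc.collect()
    torch.cuda.empty_cache()
    return load_small_model(new_name, device=device)

The ordering matters on a 24 GB card: deleting the old model and collecting before calling from_pretrained keeps peak VRAM close to a single model's footprint.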
|