jianuo committed on
Commit
c99a3a6
·
1 Parent(s): 7781fb8

first upload

Browse files
.gitattributes CHANGED
@@ -1,35 +1,34 @@
1
  *.7z filter=lfs diff=lfs merge=lfs -text
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
 
4
  *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
  *.ftz filter=lfs diff=lfs merge=lfs -text
7
  *.gz filter=lfs diff=lfs merge=lfs -text
8
  *.h5 filter=lfs diff=lfs merge=lfs -text
9
  *.joblib filter=lfs diff=lfs merge=lfs -text
10
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
  *.model filter=lfs diff=lfs merge=lfs -text
13
  *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
  *.onnx filter=lfs diff=lfs merge=lfs -text
17
  *.ot filter=lfs diff=lfs merge=lfs -text
18
  *.parquet filter=lfs diff=lfs merge=lfs -text
19
  *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
  *.pt filter=lfs diff=lfs merge=lfs -text
23
  *.pth filter=lfs diff=lfs merge=lfs -text
24
  *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
  *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
  *.tflite filter=lfs diff=lfs merge=lfs -text
30
  *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
  *.xz filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
1
  *.7z filter=lfs diff=lfs merge=lfs -text
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
5
  *.bz2 filter=lfs diff=lfs merge=lfs -text
 
6
  *.ftz filter=lfs diff=lfs merge=lfs -text
7
  *.gz filter=lfs diff=lfs merge=lfs -text
8
  *.h5 filter=lfs diff=lfs merge=lfs -text
9
  *.joblib filter=lfs diff=lfs merge=lfs -text
10
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
 
11
  *.model filter=lfs diff=lfs merge=lfs -text
12
  *.msgpack filter=lfs diff=lfs merge=lfs -text
 
 
13
  *.onnx filter=lfs diff=lfs merge=lfs -text
14
  *.ot filter=lfs diff=lfs merge=lfs -text
15
  *.parquet filter=lfs diff=lfs merge=lfs -text
16
  *.pb filter=lfs diff=lfs merge=lfs -text
 
 
17
  *.pt filter=lfs diff=lfs merge=lfs -text
18
  *.pth filter=lfs diff=lfs merge=lfs -text
19
  *.rar filter=lfs diff=lfs merge=lfs -text
 
20
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
21
  *.tar.* filter=lfs diff=lfs merge=lfs -text
 
22
  *.tflite filter=lfs diff=lfs merge=lfs -text
23
  *.tgz filter=lfs diff=lfs merge=lfs -text
 
24
  *.xz filter=lfs diff=lfs merge=lfs -text
25
  *.zip filter=lfs diff=lfs merge=lfs -text
26
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
27
+ *.tfevents* filter=lfs diff=lfs merge=lfs -text
28
+ *.db* filter=lfs diff=lfs merge=lfs -text
29
+ *.ark* filter=lfs diff=lfs merge=lfs -text
30
+ **/*ckpt*data* filter=lfs diff=lfs merge=lfs -text
31
+ **/*ckpt*.meta filter=lfs diff=lfs merge=lfs -text
32
+ **/*ckpt*.index filter=lfs diff=lfs merge=lfs -text
33
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
34
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pyenv
85
+ .python-version
86
+
87
+ # pipenv
88
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
90
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
91
+ # install all needed dependencies.
92
+ #Pipfile.lock
93
+
94
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95
+ __pypackages__/
96
+
97
+ # Celery stuff
98
+ celerybeat-schedule
99
+ celerybeat.pid
100
+
101
+ # SageMath parsed files
102
+ *.sage.py
103
+
104
+ # Environments
105
+ .env
106
+ .venv
107
+ env/
108
+ venv/
109
+ ENV/
110
+ env.bak/
111
+ venv.bak/
112
+
113
+ # Spyder project settings
114
+ .spyderproject
115
+ .spyproject
116
+
117
+ # Rope project settings
118
+ .ropeproject
119
+
120
+ # mkdocs documentation
121
+ /site
122
+
123
+ # mypy
124
+ .mypy_cache/
125
+ .dmypy.json
126
+ dmypy.json
127
+
128
+ # Pyre type checker
129
+ .pyre/
MANIFEST.in ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # added manually
2
+ include *.py
3
+ include *.json
4
+ include *.md
5
+
6
+ global-exclude *.pyc
7
+ global-exclude __pycache__
__init__.py ADDED
File without changes
alternative_prf_schemes.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """实现其他 PRF 函数(这些函数的不同之处仅在于如何从上下文中的令牌生成单个哈希值)。
2
+
3
+ 可作为修改后的基类 WatermarkBase 挂接到现有的 WatermarkLogitsProcessor 中,请参见
4
+ extended_watermark_processor.py 中的实现。
5
+ """
6
+
7
+
8
+
9
+ import torch
10
+ from itertools import combinations
11
+ from functools import cache
12
+
13
# Key properties of a hashing scheme, mapped to the Python type each one uses.
props = dict(
    prf_type=str,  # string name of the underlying PRF, which maps several token ids to a random seed
    context_width=int,  # this is h in the paper: how many previous tokens each PRF should consider
    self_salt=bool,  # whether the token itself is used for seeding (and may reject its own list), per the robust-watermark rules
    hash_key=int,  # integer, a large prime, used to move the seed away from low-entropy bit sequences in the chosen PRF
)
20
+
21
+
22
def seeding_scheme_lookup(seeding_scheme: str):
    """Resolve a seeding-scheme name into its PRF configuration.

    Accepts one of the named presets ("simple_1"/"lefthash",
    "algorithm-3"/"selfhash", "minhash", "skipgram") or a free-form,
    experimental description of the form ``ff-<prf_type>-<context_width>-<self_salt>``
    with an optional trailing ``-<hash_key>``.

    Returns:
        A tuple ``(prf_type, context_width, self_salt, hash_key)``.

    Raises:
        ValueError: if the scheme is not a string, is not a known name, is a
            malformed free-form description, or names a PRF that is not
            registered in ``prf_lookup``.
    """
    default_hash_key = 15485863

    if not isinstance(seeding_scheme, str):
        raise ValueError("Seeding scheme should be a string summarizing the procedure.")

    # "simple_1" is the default simple bigram hash (alias of ff-additive_prf-1-False-15485863).
    named_schemes = {
        "simple_1": ("additive_prf", 1, False, default_hash_key),
        "lefthash": ("additive_prf", 1, False, default_hash_key),
        "algorithm-3": ("anchored_minhash_prf", 4, True, default_hash_key),
        "selfhash": ("anchored_minhash_prf", 4, True, default_hash_key),
        "minhash": ("minhash_prf", 4, False, default_hash_key),
        "skipgram": ("skipgram_prf", 5, False, default_hash_key),
    }

    if seeding_scheme in named_schemes:
        prf_type, context_width, self_salt, hash_key = named_schemes[seeding_scheme]
    elif seeding_scheme.startswith("ff"):
        # Free-form seeding scheme API - for experimental use only.
        # Expected form: ff-additive_prf-4-True-hash or ff-additive_prf-5-True (hash key is optional).
        split_scheme = seeding_scheme.split("-")
        if len(split_scheme) < 4:
            # Fail with a descriptive error instead of an IndexError on the lookups below.
            raise ValueError(
                f"Malformed free-form seeding scheme {seeding_scheme} given. "
                "Expected 'ff-<prf_type>-<context_width>-<self_salt>[-<hash_key>]'."
            )
        prf_type = str(split_scheme[1])
        context_width = int(split_scheme[2])
        self_salt = split_scheme[3] == "True"
        if len(split_scheme) == 5:
            hash_key = int(split_scheme[4])
        else:
            hash_key = default_hash_key
    else:
        raise ValueError(f"Invalid seeding scheme name {seeding_scheme} given. Try 'simple_1'?")

    # Explicit check rather than `assert`, so it still runs under `python -O`.
    if prf_type not in prf_lookup:
        raise ValueError(f"Unknown PRF type {prf_type} in seeding scheme {seeding_scheme}.")
    return prf_type, context_width, self_salt, hash_key
61
+
62
+
63
def multiplicative_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
    """Return ``salt_key`` times the product of all entries of ``input_ids``."""
    token_product = input_ids.prod().item()
    return salt_key * token_product
65
+
66
+
67
def additive_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
    """Return ``salt_key`` times the sum of all entries of ``input_ids``."""
    token_sum = input_ids.sum().item()
    return salt_key * token_sum
69
+
70
+
71
def minfunc_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
    """Return ``salt_key`` times the smallest entry of ``input_ids``.

    Not a good idea for non-random input ids such as text.
    """
    smallest_token = input_ids.min().item()
    return salt_key * smallest_token
74
+
75
+
76
def simple_skip_prf(input_ids: torch.LongTensor, salt_key: int, k=2) -> int:
    """Hash every ``k``-th salted entry of ``input_ids`` and return the product.

    ``k`` is the skip distance used as the slice step.
    """
    strided_tokens = input_ids[::k]
    hashed = hashint(salt_key * strided_tokens)
    return hashed.prod().item()
79
+
80
+
81
def skipgram_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
    """Return the hash of the salted first entry of ``input_ids``.

    This is the maximum-distance skipgram within the context.
    """
    first_token = input_ids[0]
    return hashint(salt_key * first_token).item()
84
+
85
+
86
def anchored_skipgram_prf(input_ids: torch.LongTensor, salt_key: int, anchor: int = -1) -> int:
    """Multiply the hash of the salted first entry by the hash of the salted anchor entry.

    ``anchor`` indexes into ``input_ids`` (default: the last entry). This is
    the maximum-distance skipgram within the context, tied to the anchor.
    """
    first_hash = hashint(salt_key * input_ids[0])
    anchor_hash = hashint(salt_key * input_ids[anchor])
    return (first_hash * anchor_hash).item()
89
+
90
+
91
def minhash_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
    """Hash every salted entry of ``input_ids`` and return the smallest hash."""
    hashed_tokens = hashint(salt_key * input_ids)
    return hashed_tokens.min().item()
93
+
94
+
95
def anchored_minhash_prf(input_ids: torch.LongTensor, salt_key: int, anchor: int = -1) -> int:
    """Return the minimum of ``salt_key * hash(token) * hash(anchor token)`` over ``input_ids``.

    ``anchor`` indexes into ``input_ids`` (default: the last entry); its hash
    multiplies the hash of every entry before the minimum is taken.
    """
    token_hashes = hashint(input_ids)
    anchor_hash = hashint(input_ids[anchor])
    salted = salt_key * token_hashes * anchor_hash
    return salted.min().item()
98
+
99
+
100
def minskipgram_prf(input_ids: torch.LongTensor, salt_key: int, k: int = 2) -> int:
    """Return the minimum product over all ``k``-sized combinations of hashed tokens.

    Every entry of ``input_ids`` is salted and hashed, all combinations of
    ``k`` hashes are formed, and the smallest per-combination product is
    returned. ``k=2`` (the default) means all pairs.

    The combination size previously ignored ``k`` and was fixed at 2; the
    default therefore reproduces the old result exactly.
    """
    hashed_tokens = hashint(salt_key * input_ids)
    skipgrams = torch.as_tensor(list(combinations(hashed_tokens, k)))
    return skipgrams.prod(dim=1).min().item()
104
+
105
+
106
def noncomm_prf(input_ids: torch.LongTensor, salt_key: int, k: int = 2) -> int:
    """Fold ``input_ids`` into a running key, one entry at a time.

    Starting from ``salt_key``, each entry multiplies the running key by the
    hash of ``key * entry`` and then reduces it modulo 2**32, so the result
    depends on the order of the entries. ``k`` is accepted but not used.
    """
    running_key = torch.as_tensor(salt_key, dtype=torch.long)
    for token in input_ids:
        running_key *= hashint(running_key * token)
        running_key %= 2**32
    return running_key.item()
112
+
113
+
114
def position_prf(input_ids: torch.LongTensor, salt_key: int, k: int = 2) -> int:
    """Return the salted sum of each entry weighted by its 1-based position.

    ``k`` is accepted but not used.
    """
    positions = torch.arange(1, len(input_ids) + 1, device=input_ids.device)
    weighted = salt_key * input_ids * positions
    return weighted.sum().item()
116
+
117
+
118
# Registry mapping each PRF's name (as used in seeding-scheme strings) to its function.
prf_lookup = {
    prf.__name__: prf
    for prf in (
        multiplicative_prf,
        additive_prf,
        minfunc_prf,
        simple_skip_prf,
        skipgram_prf,
        anchored_skipgram_prf,
        minhash_prf,
        anchored_minhash_prf,
        minskipgram_prf,
        noncomm_prf,
        position_prf,
    )
}
131
+
132
# Build the global permutation table once, at import time.
table_size = 1_000_003
rng = torch.Generator(device=torch.device("cpu"))
rng.manual_seed(2971215073)
# Seeded random permutation of 0..table_size-1 on the CPU; this is fast.
fixed_table = torch.randperm(table_size, device=torch.device("cpu"), generator=rng)
137
+
138
+
139
def hashint(integer_tensor: torch.LongTensor) -> torch.LongTensor:
    """Map integers to hash values through the global permutation table.

    Each input is reduced modulo ``table_size``, looked up in ``fixed_table``,
    and incremented by one. Note the small trick: the input is moved to the
    CPU first, so the returned tensor always lives on the CPU.
    """
    table_slots = integer_tensor.cpu() % table_size
    return fixed_table[table_slots] + 1
142
+
143
+
144
def _hashint_avalanche_tensor(integer_tensor: torch.LongTensor):
    """Apply a fixed shift/xor/subtract mixing sequence element-wise in int32.

    The input is cast to ``torch.int32`` (or torch.int16?), mixed, and the
    result is cast back to ``torch.long``.
    """
    mixed = integer_tensor.to(torch.int32).clone()
    mixed = mixed - (mixed << 6)
    mixed = mixed ^ (mixed >> 17)
    mixed = mixed - (mixed << 9)
    mixed = mixed ^ (mixed << 4)
    mixed = mixed - (mixed << 3)
    mixed = mixed ^ (mixed << 10)
    mixed = mixed ^ (mixed >> 15)
    return mixed.to(torch.long)
155
+
156
+
157
@cache
def _hashint_avalanche_int(integer: int):
    """Apply the shift/xor/subtract mixing sequence to a single Python int.

    The input is first reduced modulo 2**32; intermediate values are ordinary
    (unbounded) Python integers. Results are memoized via ``functools.cache``.
    """
    mixed = integer % (2**32)
    mixed = mixed - (mixed << 6)
    mixed = mixed ^ (mixed >> 17)
    mixed = mixed - (mixed << 9)
    mixed = mixed ^ (mixed << 4)
    mixed = mixed - (mixed << 3)
    mixed = mixed ^ (mixed << 10)
    mixed = mixed ^ (mixed >> 15)
    return mixed
app.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Set up the environment first,
# then start the demo with `python app.py`.
import os

# Use the hf-mirror.com endpoint; this is set before demo_watermark is imported below.
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

from argparse import Namespace

arg_dict = {
    'run_gradio': True,
    'demo_public': False,
    'model_name_or_path': 'Qwen/Qwen2-0.5B-Instruct',

    # 'model_name_or_path': 'Qwen/Qwen2-0.5B-Instruct-GGUF',
    # Only used for gguf models, i.e. when model_name_or_path contains the string "gguf".
    'gguf_file': './qwen2-0_5b-instruct-q8_0.gguf',
    'prompt_max_length': None,
    'max_new_tokens': 500,
    'generation_seed': 123,
    'use_sampling': True,
    'n_beams': 1,
    'sampling_temp': 0.7,
    'use_gpu': True,
    'seeding_scheme': 'simple_1',
    'gamma': 0.5,
    'delta': 2.0,
    'normalizers': '',
    'ignore_repeated_bigrams': False,
    'detection_z_threshold': 4.0,
    'select_green_tokens': True,
    'skip_model_load': False,
    'seed_separately': True,
}

# Build the namespace that main() receives directly from the settings above.
args = Namespace(**arg_dict)

from demo_watermark import main

main(args)
demo_watermark.py ADDED
@@ -0,0 +1,1085 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import time
4
+ from functools import partial
5
+
6
+ import gradio.exceptions
7
+
8
+ # 设置OpenMP线程数为8,优化CPU并行计算性能
9
+ os.environ["OMP_NUM_THREADS"] = "8"
10
+
11
+ import gradio as gr
12
+ import pandas as pd
13
+ import torch
14
+ from transformers import (AutoTokenizer,
15
+ AutoModelForCausalLM,
16
+ LogitsProcessorList)
17
+
18
+ from watermark_processor import WatermarkLogitsProcessor, WatermarkDetector
19
+
20
# FIXME: correct max lengths for all models

# Models served through the remote API, with their per-model settings (currently none enabled).
API_MODEL_MAP = {
    # "Qwen/Qwen1.5-0.5B-Chat": {"max_length": 2000, "gamma": 0.5, "delta": 2.0},
    # "THUDM/chatglm3-6b": {"max_length": 2048, "gamma": 0.5, "delta": 2.0},
}

# Trace table: one row per known watermark, pairing a numeric id with its label.
default_trace_table = pd.DataFrame(columns=["编号", "水印内容"])
for _row_id, _row_label in enumerate(("默认用户", "张三", "李四")):
    default_trace_table.loc[_row_id] = (_row_id, _row_label)

# Id of the watermark used when generating; matched against the "编号" column.
watermark_salt = 0
33
+
34
+
35
def str2bool(v):
    """Parse a user-friendly boolean flag value for argparse.

    Args:
        v: Either an actual ``bool`` (returned unchanged) or a string such as
            "yes"/"true"/"t"/"y"/"1" or "no"/"false"/"f"/"n"/"0"
            (case-insensitive, surrounding whitespace ignored).

    Returns:
        The corresponding ``bool``.

    Raises:
        argparse.ArgumentTypeError: if the value is not a bool and is not one
            of the recognised strings (including non-string inputs, which
            would otherwise fail with an AttributeError on ``.lower()``).
    """
    if isinstance(v, bool):
        return v
    if not isinstance(v, str):
        raise argparse.ArgumentTypeError('Boolean value expected.')

    token = v.strip().lower()
    if token in ('yes', 'true', 't', 'y', '1'):
        return True
    if token in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
45
+
46
+
47
+ # 定义一个函数用于解析命令行参数
48
def parse_args(argv=None):
    """Parse the demo's command-line arguments.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, exactly as before.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description="")

    # (flag, type, default, help) in the order the flags appear in --help.
    flag_specs = [
        ("--run_gradio", str2bool, True,
         "Whether to launch as a gradio demo. Set to False if not installed and want to just run the stdout version."),
        ("--demo_public", str2bool, False,
         "Whether to expose the gradio demo to the internet."),
        ("--model_name_or_path", str, "Qwen/Qwen1.5-0.5B-Chat",
         "Main model, path to pretrained model or model identifier from huggingface.co/models."),
        ("--prompt_max_length", int, None,
         "Truncation length for prompt, overrides model config's max length field."),
        ("--max_new_tokens", int, 200,
         "Maximmum number of new tokens to generate."),
        ("--generation_seed", int, 123,
         "Seed for setting the torch global rng prior to generation."),
        ("--use_sampling", str2bool, True,
         "Whether to generate using multinomial sampling."),
        ("--sampling_temp", float, 0.7,
         "Sampling temperature to use when generating using multinomial sampling."),
        ("--n_beams", int, 1,
         "Number of beams to use for beam search. 1 is normal greedy decoding"),
        ("--use_gpu", str2bool, True,
         "Whether to run inference and watermark hashing/seeding/permutation on gpu."),
        ("--seeding_scheme", str, "simple_1",
         "Seeding scheme to use to generate the greenlists at each generation and verification step."),
        ("--gamma", float, 0.5,
         "The fraction of the vocabulary to partition into the greenlist at each generation and verification step."),
        ("--delta", float, 2.0,
         "The amount/bias to add to each of the greenlist token logits before each token sampling step."),
        ("--normalizers", str, "",
         "Single or comma separated list of the preprocessors/normalizer names to use when performing watermark detection."),
        ("--ignore_repeated_bigrams", str2bool, False,
         "Whether to use the detection method that only counts each unqiue bigram once as either a green or red hit."),
        ("--detection_z_threshold", float, 4.0,
         "The test statistic threshold for the detection hypothesis test."),
        ("--select_green_tokens", str2bool, True,
         "How to treat the permuation when selecting the greenlist tokens at each step. Legacy is (False) to pick the complement/reds first."),
        ("--skip_model_load", str2bool, False,
         "Skip the model loading to debug the interface."),
        ("--gguf_file", str, './qwen2-0_5b-instruct-q2_k.gguf',
         "gguf文件(如果有)"),
        ("--seed_separately", str2bool, True,
         "Whether to call the torch seed function before both the unwatermarked and watermarked generate calls."),
    ]
    for flag, value_type, default, help_text in flag_specs:
        parser.add_argument(flag, type=value_type, default=default, help=help_text)

    args = parser.parse_args(argv)

    return args
176
+
177
+
178
def load_model(args):
    """Load the model and tokenizer and return them with the chosen device.

    The device is "cuda:0" when ``args.use_gpu`` is set and CUDA is
    available, otherwise "cpu". When ``args.model_name_or_path`` contains
    "gguf" (case-insensitive), ``args.gguf_file`` is passed to both loaders.
    Both loaders are called with ``local_files_only=True`` and
    ``trust_remote_code=True``.

    Returns:
        A ``(model, tokenizer, device)`` tuple.
    """
    if args.use_gpu and torch.cuda.is_available():
        device = "cuda:0"
    else:
        device = "cpu"

    # Keyword arguments shared by the model and tokenizer loaders.
    loader_kwargs = dict(trust_remote_code=True, local_files_only=True)
    if 'gguf' in args.model_name_or_path.lower():
        loader_kwargs["gguf_file"] = args.gguf_file

    model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,
                                                 device_map=device,
                                                 **loader_kwargs)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, **loader_kwargs)

    # Best effort: a failure to switch to eval mode is printed, not raised.
    try:
        model.eval()
    except Exception as e:
        print(e)

    return model, tokenizer, device
211
+
212
+
213
+ from text_generation import InferenceAPIClient
214
+ from requests.exceptions import ReadTimeout
215
+
216
+
217
def _stream_api_text(client, prompt, generation_params, timeout_msg):
    """Yield text pieces from one streaming API call; on ReadTimeout yield ``timeout_msg`` char by char."""
    try:
        for response in client.generate_stream(prompt, **generation_params):
            yield response.token.text
    except ReadTimeout as e:
        # The timeout can surface while the stream is being consumed, so it is
        # handled here rather than only around the generate_stream() call.
        print(e)
        yield from timeout_msg


def generate_with_api(prompt, args):
    """Stream un-watermarked and watermarked text for ``prompt`` from the HF inference API.

    Two streaming requests are issued with identical generation parameters
    except for the ``watermark`` flag. The streams are consumed in lockstep
    and stop as soon as either one is exhausted.

    Yields:
        ``(text_without_watermark, text_with_watermark)`` tuples holding the
        text accumulated so far.

    Raises:
        ValueError: if the ``HF_API_KEY`` environment variable is not set.
        AssertionError: if ``args.n_beams`` is not 1.
    """
    hf_api_key = os.environ.get("HF_API_KEY")
    if hf_api_key is None:
        raise ValueError("HF_API_KEY environment variable not set, cannot use HF API to generate text.")

    client = InferenceAPIClient(args.model_name_or_path, token=hf_api_key, timeout=60)

    assert args.n_beams == 1, "HF API models do not support beam search."
    generation_params = {
        "max_new_tokens": args.max_new_tokens,
        "do_sample": args.use_sampling,
    }
    if args.use_sampling:
        generation_params["temperature"] = args.sampling_temp
    generation_params["seed"] = args.generation_seed

    timeout_msg = "[Model API timeout error. Try reducing the max_new_tokens parameter or the prompt length.]"

    # Both helpers yield plain strings, so the timeout fallback and the real
    # stream can be consumed the same way (the old fallback yielded characters
    # while the loop expected objects with a `.token.text` attribute).
    without_watermark_iterator = _stream_api_text(
        client, prompt, {**generation_params, "watermark": False}, timeout_msg)
    with_watermark_iterator = _stream_api_text(
        client, prompt, {**generation_params, "watermark": True}, timeout_msg)

    all_without_words, all_with_words = "", ""
    for without_piece, with_piece in zip(without_watermark_iterator, with_watermark_iterator):
        all_without_words += without_piece
        all_with_words += with_piece
        yield all_without_words, all_with_words
252
+
253
+
254
def check_prompt(prompt, args, tokenizer, model=None, device=None):
    """Determine the prompt length limit, truncate the prompt to it, and report truncation.

    Works for both the local-model and the API-model scenario.
    ``args.prompt_max_length`` is overwritten with:
      * the ``max_length`` entry from ``API_MODEL_MAP`` for API models;
      * ``model.config.max_position_embeddings - args.max_new_tokens`` when
        the model config exposes that field;
      * ``4096 - args.max_new_tokens`` otherwise, or if anything in the
        lookup fails (e.g. ``model`` is ``None``).

    Returns:
        ``(redecoded_input, truncation_flag, args)`` where ``redecoded_input``
        is the prompt after tokenizing with truncation and decoding again, and
        ``truncation_flag`` is 1 when the tokenized prompt is exactly
        ``args.prompt_max_length`` tokens long, else 0.
    """
    fallback_max_length = 4096 - args.max_new_tokens
    try:
        if args.model_name_or_path in API_MODEL_MAP:
            args.prompt_max_length = API_MODEL_MAP[args.model_name_or_path]["max_length"]
        # The attribute is spelled with a trailing "s"; the previous check for
        # "max_position_embedding" never matched, so the fallback was always used.
        elif hasattr(model.config, "max_position_embeddings"):
            args.prompt_max_length = model.config.max_position_embeddings - args.max_new_tokens
        else:
            args.prompt_max_length = fallback_max_length
    except Exception as e:
        # Deliberately best effort: report the problem and use the fallback.
        print(e)
        args.prompt_max_length = fallback_max_length

    tokd_input = tokenizer(prompt, return_tensors="pt", add_special_tokens=False, truncation=True,
                           max_length=args.prompt_max_length).to(device)
    truncation_warning = tokd_input["input_ids"].shape[-1] == args.prompt_max_length
    redecoded_input = tokenizer.batch_decode(tokd_input["input_ids"], skip_special_tokens=True)[0]

    return (redecoded_input,
            int(truncation_warning),
            args)
275
+
276
+
277
def generate(prompt, args, tokenizer, model=None, device=None):
    """Generate text for ``prompt`` twice, without and with the watermark, and stream both.

    For API models (names listed in ``API_MODEL_MAP``) the work is delegated
    to ``generate_with_api``. Otherwise a ``WatermarkLogitsProcessor`` is
    built from the watermark arguments and passed to the model's ``generate``
    method as a logits processor for the watermarked run. The module-level
    ``watermark_salt`` is used as the processor's ``extra_salt``.

    Yields:
        ``(text_without_watermark, text_with_watermark)`` tuples with the text
        accumulated so far, one whitespace-separated word at a time. When one
        output has more words than the other, the longer one keeps growing
        while the shorter one stays at its final value.
    """
    from itertools import zip_longest

    print(f"Generating with {args}")
    print(f"Prompt: {prompt}")

    if args.model_name_or_path in API_MODEL_MAP:
        api_outputs = generate_with_api(prompt, args)
        yield from api_outputs
    else:
        model_name = args.model_name_or_path.lower()
        if 'chatglm' in model_name or 'qwen' in model_name or 'llama' in model_name:
            # Chat models: wrap the prompt in the tokenizer's chat template first.
            messages = [
                # {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": prompt}
            ]

            tokenized_input = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )

            tokd_input = tokenizer([tokenized_input], return_tensors="pt", truncation=True, add_special_tokens=False,
                                   max_length=args.prompt_max_length).to(device)

        else:
            tokd_input = tokenizer(prompt, return_tensors="pt", add_special_tokens=True, truncation=True,
                                   max_length=args.prompt_max_length).to(device)

        gen_kwargs = dict(max_new_tokens=args.max_new_tokens)

        if args.use_sampling:
            gen_kwargs.update(dict(
                do_sample=True,
                top_k=0,
                temperature=args.sampling_temp
            ))
        else:
            gen_kwargs.update(dict(
                num_beams=args.n_beams
            ))

        watermark_processor = WatermarkLogitsProcessor(vocab=list(tokenizer.get_vocab().values()),
                                                       gamma=args.gamma,
                                                       delta=args.delta,
                                                       seeding_scheme=args.seeding_scheme,
                                                       extra_salt=watermark_salt,
                                                       select_green_tokens=args.select_green_tokens)

        generate_without_watermark = partial(
            model.generate,
            **gen_kwargs
        )

        generate_with_watermark = partial(
            model.generate,
            logits_processor=LogitsProcessorList([watermark_processor]),
            **gen_kwargs
        )

        start_time = time.time()
        gr.Info('开始生成正常内容')
        torch.manual_seed(args.generation_seed)
        output_without_watermark = generate_without_watermark(**tokd_input)

        # Optionally re-seed before the second generation; the outputs will
        # usually still differ unless delta == 0.0 (a no-op watermark).

        matching_labels = default_trace_table.loc[default_trace_table['编号'] == watermark_salt, '水印内容']
        print(watermark_salt)
        print(default_trace_table)
        print(matching_labels)
        # Fall back to the raw salt when it has no row in the trace table,
        # rather than failing on `.item()` of an empty selection.
        watermark_label = matching_labels.iloc[0] if len(matching_labels) else watermark_salt
        gr.Info('开始注入水印:“{}”'.format(watermark_label))
        if args.seed_separately:
            torch.manual_seed(args.generation_seed)

        output_with_watermark = generate_with_watermark(**tokd_input)

        # Drop the prompt tokens so only newly generated tokens are decoded.
        prompt_length = tokd_input["input_ids"].shape[-1]
        output_without_watermark = output_without_watermark[:, prompt_length:]
        output_with_watermark = output_with_watermark[:, prompt_length:]

        decoded_output_without_watermark = tokenizer.batch_decode(output_without_watermark, skip_special_tokens=True)[0]
        decoded_output_with_watermark = tokenizer.batch_decode(output_with_watermark, skip_special_tokens=True)[0]

        end_time = time.time()
        gr.Info(f"生成结束,共用时{end_time - start_time:.2f}秒")

        print(f"Generation took {end_time - start_time:.2f} seconds")

        # Mimic the API's streaming output by yielding space-separated words.
        # zip_longest keeps the tail of the longer output, which plain zip
        # silently discarded.
        all_without_words, all_with_words = "", ""
        for without_word, with_word in zip_longest(decoded_output_without_watermark.split(),
                                                   decoded_output_with_watermark.split()):
            if without_word is not None:
                all_without_words += without_word + " "
            if with_word is not None:
                all_with_words += with_word + " "
            yield all_without_words, all_with_words
372
+
373
+
374
def format_names(s):
    """Replace internal metric names in ``s`` with the labels shown in the gradio demo."""
    display_labels = (
        ("num_tokens_scored", "总Token"),
        ("num_green_tokens", "Green Token数量"),
        ("green_fraction", "Green Token占比"),
        ("z_score", "z-score"),
        ("p_value", "p value"),
        ("prediction", "预测结果"),
        ("confidence", "置信度"),
    )
    # Substitutions are applied one after another, in this order.
    for internal_name, label in display_labels:
        s = s.replace(internal_name, label)
    return s
384
+
385
+
386
def list_format_scores(score_dict, detection_threshold):
    """Turn the detection metrics into the 2-D list used as gradio dataframe input.

    Each metric becomes a ``[label, formatted value]`` row, in dict order. A
    "z-score Threshold" row is then inserted before the last two rows when
    ``score_dict`` has a "confidence" key, otherwise before the last row.
    """
    rows = []
    for key, value in score_dict.items():
        label = format_names(key)
        if key == 'green_fraction':
            shown = f"{value:.1%}"
        elif key == 'confidence':
            shown = f"{value:.3%}"
        elif isinstance(value, float):
            shown = f"{value:.3g}"
        elif isinstance(value, bool):
            shown = "含有水印" if value else "无水印"
        else:
            shown = f"{value}"
        rows.append([label, shown])

    insert_at = -2 if "confidence" in score_dict else -1
    rows.insert(insert_at, ["z-score Threshold", f"{detection_threshold}"])
    return rows
405
+
406
+
407
def detect(input_text, args, tokenizer, device=None, return_green_token_mask=True):
    """Run watermark detection on ``input_text`` against every row of the trace table.

    A ``WatermarkDetector`` is built for each row of ``default_trace_table``
    (using the row's "编号" as ``extra_salt``) and the loop stops at the first
    row whose detector predicts a watermark. The scores of the last detector
    that ran are formatted for display.

    Returns:
        ``(output, args, tokenizer, html_output, html_look)`` where ``output``
        is the 2-D list of formatted metrics (or an error table),
        ``html_output`` is the per-token green/red highlight markup (or an
        explanatory message when no mask is available), and ``html_look`` is a
        ``gr.HTML`` banner stating whether a watermark was found and whose.
    """
    import html

    print(f"Detecting with {args}")
    print(f"Detection Tokenizer: {type(tokenizer)}")

    # Don't show the green token mask when normalizers or
    # ignore_repeated_bigrams are in use. Truthiness is used so that both an
    # empty list and an empty string count as "no normalizers".
    if args.normalizers or args.ignore_repeated_bigrams:
        return_green_token_mask = False

    error = False
    green_token_mask = None
    # Pre-bind everything the later sections read, so the error path cannot
    # hit an unbound name.
    score_dict = None
    name = None
    watermark_detector = None
    output = None
    if input_text == "":
        error = True
    else:
        try:
            vocab = list(tokenizer.get_vocab().values())
            for _, data in default_trace_table.iterrows():
                salt = data["编号"]
                name = data["水印内容"]
                watermark_detector = WatermarkDetector(vocab=vocab,
                                                       gamma=args.gamma,
                                                       seeding_scheme=args.seeding_scheme,
                                                       extra_salt=salt,
                                                       device=device,
                                                       tokenizer=tokenizer,
                                                       z_threshold=args.detection_z_threshold,
                                                       normalizers=args.normalizers,
                                                       ignore_repeated_bigrams=args.ignore_repeated_bigrams,
                                                       select_green_tokens=args.select_green_tokens)
                score_dict = watermark_detector.detect(input_text, return_green_token_mask=return_green_token_mask)
                if score_dict['prediction']:
                    print(f"检测到是“{name}”的水印")
                    break

            if score_dict is None:
                # Empty trace table: nothing was scored.
                error = True
            else:
                green_token_mask = score_dict.pop("green_token_mask", None)
                output = list_format_scores(score_dict, watermark_detector.z_threshold)
        except ValueError as e:
            print(e)
            error = True
    if error:
        output = [["Error", "string too short to compute metrics"]]
        output += [["", ""] for _ in range(6)]

    if green_token_mask is None:
        html_output = "[Visualizing masks with ignore_repeated_bigrams enabled is not supported, toggle off to see the mask for this text. The mask is the same in both cases - only counting/stats are affected.]"
    else:
        # Hack: we need a fast tokenizer with character-span support.
        tokens = tokenizer(input_text, add_special_tokens=False)
        if tokens["input_ids"][0] == tokenizer.bos_token_id:
            tokens["input_ids"] = tokens["input_ids"][1:]  # ignore the attention mask
        skip = watermark_detector.min_prefix_len

        is_chatglm = args.model_name_or_path in ['THUDM/chatglm3-6b']
        if is_chatglm:
            # Assumes vocabulary ids 3-258 are the raw bytes 0-255.
            charspans = []
            for i in range(skip, len(tokens["input_ids"])):
                token_id = tokens.data['input_ids'][i - 1]
                if token_id in range(3, 259):
                    charspans.append("<0x{:X}>".format(token_id - 3))
                else:
                    charspans.append(tokenizer.decode(tokens.data['input_ids'][i - 1:i]))
        else:
            charspans = [tokens.token_to_chars(i - 1) for i in range(skip, len(tokens["input_ids"]))]

        charspans = [cs for cs in charspans if cs is not None]  # remove the special token spans

        assert len(charspans) == len(green_token_mask), \
            "token spans and green token mask have different lengths"

        # Token text is escaped before being embedded in the markup so that
        # characters such as "<" or "&" are displayed rather than interpreted.
        tags = []
        for cs, is_green in zip(charspans, green_token_mask):
            token_text = cs if is_chatglm else input_text[cs.start:cs.end]
            css_class = "green" if is_green else "red"
            tags.append(f'<span class="{css_class}">{html.escape(token_text)}</span>')

        html_output = f'<p>{" ".join(tags)}</p>'

    has_watermark = (not error) and bool(score_dict['prediction'])
    if has_watermark:
        html_look = gr.HTML("""<div style="width: 100%; font-size: 24px; height: 100px; border-radius: 20px; background-color: rgba(255, 0, 0, 0.25); display: flex; justify-content: center; align-items: center; color: white; font-weight: bold;">
            <span>有 “{}” 的水印</span>
        </div>""".format(html.escape(str(name))), visible=True)

    else:
        html_look = gr.HTML("""<div style="width: 100%; font-size: 24px; height: 100px; border-radius: 20px; background-color: rgba(0, 128, 0, 0.25); display: flex; justify-content: center; align-items: center; color: white; font-weight: bold; text-align: center;">
            <span>无水印</span>
        </div>""", visible=True)

    return output, args, tokenizer, html_output, html_look
504
+
505
+
506
+ def run_gradio(args, model=None, device=None, tokenizer=None):
507
+ """定义并启动gradio演示界面"""
508
+ check_prompt_partial = partial(check_prompt, model=model, device=device)
509
+ generate_partial = partial(generate, model=model, device=device)
510
+ detect_partial = partial(detect, device=device)
511
+
512
+ css = """
513
+ .green { color: black!important;line-height:1.9em; padding: 0.2em 0.2em; background: #ccffcc; border-radius:0.5rem;}
514
+ .red { color: black!important;line-height:1.9em; padding: 0.2em 0.2em; background: #ffad99; border-radius:0.5rem;}
515
+ """
516
+
517
+ with gr.Blocks(theme='ParityError/Interstellar', css=css) as demo:
518
+ # 顶部部分,问候语和说明
519
+ with gr.Row():
520
+ with gr.Column(scale=9):
521
+ gr.Markdown(
522
+ """
523
+ # 🌸🖼️ LLMwatermark:面向大语言模型生成内容的数字水印版权保护系统 🌟🎓
524
+ """
525
+ )
526
+ with gr.Column(scale=1):
527
+ # 如果启动时的 model_name_or_path 不是 API 模型之一,则添加到下拉菜单中
528
+ all_models = sorted(list(set(list(API_MODEL_MAP.keys()) + [args.model_name_or_path])))
529
+ model_selector = gr.Dropdown(
530
+ all_models,
531
+ value=args.model_name_or_path,
532
+ label="选择大语言模型,进行模型水印",
533
+ )
534
+
535
+ # 构建参数的状态,定义更新和切换
536
+ default_prompt = args.__dict__.pop("default_prompt")
537
+ session_args = gr.State(value=args)
538
+ # 注意,如果状态对象是可调用的,则自动调用 value,希望在启动时避免调用分词器
539
+ session_tokenizer = gr.State(value=lambda: tokenizer)
540
+
541
+ with gr.Tab("生成回答和添加文本水印🎓"):
542
+
543
+ with gr.Row():
544
+ with gr.Column(scale=5):
545
+ prompt = gr.Textbox(label=f"Prompt", interactive=True, lines=3, max_lines=10, value=default_prompt)
546
+ with gr.Column(scale=3):
547
+ trace_source = gr.Dataframe(default_trace_table, datatype=['number', 'str'], interactive=True,
548
+ col_count=(2, "fixed"))
549
+ with gr.Row(equal_height=True):
550
+ with gr.Column(scale=7):
551
+ generate_btn = gr.Button("Generate", variant='primary')
552
+
553
+ gr.Markdown('水印选择:',
554
+ show_label=False)
555
+ watermark_salt_choice = gr.Dropdown(
556
+ choices=[i[::-1] for i in default_trace_table.to_dict(orient='split')['data']],
557
+ value=0,
558
+ container=False,
559
+ scale=3,
560
+ type="value",
561
+ interactive=True, label="水印标识选择")
562
+
563
+ with gr.Row():
564
+ with gr.Column():
565
+ with gr.Column(scale=2):
566
+ with gr.Tab("原版输出"):
567
+ output_without_watermark = gr.Textbox(interactive=False, lines=7, max_lines=14,
568
+ show_label=False)
569
+ with gr.Tab("显示水印"):
570
+ html_without_watermark = gr.HTML(elem_id="html-without-watermark")
571
+
572
+ original_watermark_state = gr.HTML('', visible=False)
573
+
574
+ with gr.Column(scale=1):
575
+ without_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"],
576
+ interactive=False,
577
+ row_count=7, col_count=2)
578
+ with gr.Column():
579
+ with gr.Column(scale=2):
580
+ with gr.Tab("带水印的输出"):
581
+ output_with_watermark = gr.Textbox(interactive=False, lines=7, max_lines=14,
582
+ show_label=False)
583
+ with gr.Tab("显示水印"):
584
+ html_with_watermark = gr.HTML(elem_id="html-with-watermark")
585
+
586
+ change_watermark_state = gr.HTML('', visible=False)
587
+
588
+ with gr.Column(scale=1):
589
+ with_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False,
590
+ row_count=7, col_count=2)
591
+
592
+ redecoded_input = gr.Textbox(visible=False)
593
+ truncation_warning = gr.Number(visible=False)
594
+
595
def truncate_prompt(redecoded_input, truncation_warning, orig_prompt, args):
    """Choose the prompt text to show after the length check.

    Returns the original prompt when no truncation happened; otherwise the
    re-decoded (truncated) prompt followed by a notice. ``args`` is passed
    through unchanged as the second tuple element.
    """
    if not truncation_warning:
        return orig_prompt, args
    notice = "\n\n[Prompt was truncated before generation due to length...]"
    return redecoded_input + notice, args
600
+
601
+ with gr.Tab("检测文本水印功能🎭"):
602
+ with gr.Row():
603
+ with gr.Column(scale=5):
604
+ with gr.Tab("分析文本"):
605
+ detection_input = gr.Textbox(interactive=True, lines=14, max_lines=14, show_label=False)
606
+ with gr.Tab("显示水印"):
607
+ html_detection_input = gr.HTML(elem_id="html-detection-input")
608
+
609
+ detect_watermark_state = gr.HTML('', visible=False)
610
+ with gr.Column(scale=2):
611
+ trace_source2 = gr.Dataframe(default_trace_table, datatype=['number', 'str'], interactive=True,
612
+ col_count=(2, "fixed"))
613
+ detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False, row_count=7,
614
+ col_count=2)
615
+
616
+ with gr.Row():
617
+ detect_btn = gr.Button("检测", variant='primary')
618
+
619
+ with gr.Tab("About📖"):
620
+ with gr.Row():
621
+ with gr.Column(scale=2):
622
+ gr.Markdown(
623
+ """
624
+ 大语言模型可能带来的潜在危害可以通过*水印*来减轻。*水印*是嵌入在生成的文本中的信息,
625
+ 这对人类来说是不可见的,但是可以被特定算法检测到。
626
+ 这些水印可以使*任何人*用特定工具判断其是否使用带水印的模型生成的。
627
+ 本网站展示了一种水印方法,可以应用于_任何_生成性语言模型。
628
+ """
629
+ )
630
+ gr.Markdown(
631
+ """
632
+ **[生成文本与添加水印]**:可以给大模型的输出添加水印。
633
+ 您可以尝试任何prompt,并比较正常文本(*没有水印的输出*)和水印文本(*有水印的输出*)的质量。
634
+ 您还可以点击**显示水印**来“看到”水印,其中的颜色表示其所在的红绿表。
635
+
636
+ **[检测]**:您还可以将水印文本(或任何其他文本)复制粘贴到第二个选项卡中。
637
+ 可以实验删除多少句子后还能检测到水印。
638
+ 还可以在验证,检测器的误报率有多少;
639
+ """
640
+ )
641
+
642
+ with gr.Column(scale=1):
643
+ gr.Markdown(
644
+ """
645
+ ![]()
646
+ """
647
+ )
648
+
649
+ # 参数选择组
650
+ with gr.Accordion("高级设置", open=False):
651
+ with gr.Row():
652
+ with gr.Column(scale=1):
653
+ gr.Markdown(f"#### 生成参数")
654
+ with gr.Row():
655
+ decoding = gr.Radio(label="解码方法", choices=["多项式解码方法", "贪心解码方法"],
656
+ value=("multinomial" if args.use_sampling else "greedy"))
657
+ with gr.Row():
658
+ sampling_temp = gr.Slider(label="采样温度", minimum=0.1, maximum=1.0, step=0.1,
659
+ value=args.sampling_temp, visible=True)
660
+ with gr.Row():
661
+ generation_seed = gr.Number(label="生成种子", value=args.generation_seed, interactive=True)
662
+ with gr.Row():
663
+ n_beams = gr.Dropdown(label="波束搜索解码", choices=list(range(1, 11, 1)), value=args.n_beams,
664
+ visible=((not args.use_sampling) and (
665
+ not args.model_name_or_path in API_MODEL_MAP)))
666
+ with gr.Row():
667
+ max_new_tokens = gr.Slider(label="(生成文本的最大长度)Max Generated Tokens", minimum=10,
668
+ maximum=4000, step=10, value=args.max_new_tokens)
669
+
670
+ with gr.Column(scale=1):
671
+ gr.Markdown(f"#### 模型水印参数设置")
672
+ with gr.Row():
673
+ gamma = gr.Slider(label="gamma", minimum=0.1, maximum=0.9, step=0.05, value=args.gamma)
674
+ with gr.Row():
675
+ delta = gr.Slider(label="delta", minimum=0.0, maximum=10.0, step=0.1, value=args.delta)
676
+ gr.Markdown(f"#### 检测文本水印参数设置")
677
+ with gr.Row():
678
+ detection_z_threshold = gr.Slider(label="Z分数阈值", minimum=0.0, maximum=10.0, step=0.1,
679
+ value=args.detection_z_threshold)
680
+ with gr.Row():
681
+ ignore_repeated_bigrams = gr.Checkbox(label="避免生成连续重复的双词组合")
682
+ with gr.Row():
683
+ normalizers = gr.CheckboxGroup(label="对文本进行标准化处理",
684
+ choices=["unicode", "homoglyphs", "truecase"],
685
+ value=args.normalizers)
686
+ with gr.Row():
687
+ gr.Markdown(
688
+ f"注意:滑块并不总是能完美更新。点击条形图或使用右侧的数字窗口会有所帮助。下面的窗口显示当前设置。")
689
+ with gr.Row():
690
+ current_parameters = gr.Textbox(label="当前参数设置", value=args, max_lines=10)
691
+ with gr.Accordion("传统设置", open=False):
692
+ with gr.Row():
693
+ with gr.Column(scale=1):
694
+ seed_separately = gr.Checkbox(label="为两个不同的生成过程分别设置随机种子",
695
+ value=args.seed_separately)
696
+ with gr.Column(scale=1):
697
+ select_green_tokens = gr.Checkbox(label="从分区中选择 绿色列表", value=args.select_green_tokens)
698
+
699
+ with gr.Accordion("设置有什么作用?", open=False):
700
+ gr.Markdown(
701
+ """
702
+ #### 生成参数:
703
+
704
+ - **解码方法**:我们可以使用多项式采样或者贪婪解码的方式从模型中生成标记。
705
+ 决定如何从模型中生成token。可以选择多项式采样或贪婪解码。
706
+ 多项式采样允许一定的随机性,而贪婪解码总是选择概率最高的下一个token。
707
+ - **采样温度**:如果使用多项式采样,我们可以设置采样分布的温度。
708
+ - 0.0 相当于贪婪解码,而 1.0 代表最大的随机性。
709
+ - 0.7是文本质量和随机性之间的平衡点。不适用于贪婪解码。
710
+ - **生成种子**:用于在生成前初始化随机数生成器,使多项式采样的输出可复现。此设置不适用于贪婪解码。
711
+ - **束搜索数量**:在使用贪婪解码时,可以设置光束数量来启用束搜索。
712
+ 这允许考虑多个候选序列,而不是只选择最有可能的一个。此设置目前仅适用于贪婪解码。
713
+ - **最大生成标记数**:传递给生成方法的 `max_new_tokens` 参数,以在一定数量的新标记停止输出。
714
+ - 请注意,根据提示,模型可以生成较少的标记。隐含地,这将最大化可能的提示标记数,即模型的最大输入长度减去 `max_new_tokens`,并相应地截断输入。
715
+
716
+ 综上所述,这些参数提供了对生成过程的不同方面的控制,包括随机性和多样性
717
+ (解码方法、采样温度),可复现性(生成种子),以及输出长度和多样性(光束数量、最大生成token数)。
718
+ 合理配置这些参数可以帮助生成高质量的文本输出。
719
+
720
+ #### 水印参数:
721
+
722
+ - **gamma**:在每个生成步骤中将词汇表的一部分划分为绿色列表的比例。
723
+ - 较小的 gamma 值通过使水印模型优先从较小的绿色集中采样,从而使其与人类/未水印文本的差异更大,从而创建更强的水印。
724
+ - **delta**:在每个生成步骤中为绿色列表中的每个标记的 logits 添加的正偏差量。较高的 delta 值意味着水印模型更偏好于绿色列表中的标记。
725
+ - 随着偏差变得非常大,水印从 "软" 过渡到 "硬"。对于硬水印,几乎所有标记都是绿色的,但这可能对生成质量产生不利影响,特别是当分布的灵活性不大时。
726
+
727
+ #### 检测器参数:
728
+
729
+ - **z 分数阈值**:假设检验的 z 分数截断。较高的阈值(例如 4.0)使得 _false positives_(预测人类/未水印文本被标记为水印)非常不可能,
730
+ 因为一个真正的人类文本几乎永远不会达到那么高的 z 分数。
731
+ - 较低的阈值会捕获更多的 _true positives_,因为一些水印文本可能包含较少的绿色标记并达到较低的 z 分数,但仍然通过较低的门槛并被标记为 "水印"。
732
+ 然而,较低的阈值会增加包含略高于平均水平的绿色标记的人类文本错误地被标记为水印的几率。4.0-5.0 提供了极低的假阳性率,同时仍准确捕获大多数水印文本。
733
+ - **忽略二元重复**:这种替代的检测算法在检测期间仅考虑文本中的唯一二元组,根据每对中的第一个计算绿色列表,并检查第二个是否位于列表中。
734
+ - 这意味着 `T` 现在是文本中唯一二元组的数量,如果文本包含大量重复,则这个数字将小于生成的总标记数。有关更详细的讨论,请参阅论文。
735
+ - **归一化**:我们实现了一些基本的归一化来抵御文本在检测期间的各种对抗性扰动。
736
+ - 目前,我们支持将所有字符转换为 Unicode,将同形异义字符替换为规范形式,并标准化大写。有关输入归一化的详细讨论,请参阅论文。
737
+ """
738
+ )
739
+
740
+ with gr.Accordion("输出指标意味着什么?", open=False):
741
+ gr.Markdown(
742
+ """
743
+ - `z-score threshold`:假设检验的截止值。
744
+ - `Tokens Counted (T)`:检测算法计算的输出中的标记数量。在简单的、单标记播种方案中,第一个标记被省略了,因为没有办法为其生成绿色列表,因为它没有前缀标记。
745
+ 在底部面板描述的“忽略二元重复”检测算法下,如果有很多重复,这个数量可能远少于生成的总标记数。
746
+ - ` Tokens in Greenlist`:观察到落在其相应绿色列表中的标记数量。
747
+ - `Fraction of T in Greenlist`:`# Tokens in Greenlist` / `T`。这应该大约等于人类/未水印文本的 `gamma`。
748
+ - `z-score`:用于检测假设检验的检验统计量。如果大于 `z-score threshold`,则我们“拒绝零假设”,即文本是人类/未水印的,并得出结论它是带水印的。
749
+ - `p value`:在零假设下观察到计算出的 `z-score` 的可能性。这是在不知道水印程序/绿色列表的情况下观察到 `Fraction of T in Greenlist` 的可能性。
750
+ 如果这个值极其 _小_,我们可以确信这么多的绿色标记不是由随机机会选择的。
751
+ - `prediction`:假设检验的结果 - 观察到的 `z-score` 是否高于 `z-score threshold`。
752
+ - `confidence`:如果我们拒绝零假设,且 `prediction` 是“带水印”,那么我们报告 1-`p value` 来表示基于这个 `z-score` 观察的检测的置信度。
753
+ """
754
+ )
755
+
756
+ gr.HTML("""
757
+ <p>本方法可以对任何大模型的输出结果进行水印的操作。
758
+ 并且可以对输出结果进行水印检测。
759
+ <p/>
760
+ """)
761
+
762
+ # 注册主要生成标签单击事件,输出生成文本以及编码+重新解码+可能被截断的提示和标志,然后调用检测
763
+ generate_btn.click(fn=check_prompt_partial, inputs=[prompt, session_args, session_tokenizer],
764
+ outputs=[redecoded_input, truncation_warning, session_args]).success(
765
+ fn=generate_partial, inputs=[redecoded_input, session_args, session_tokenizer],
766
+ outputs=[output_without_watermark, output_with_watermark]).success(
767
+ fn=detect_partial, inputs=[output_without_watermark, session_args, session_tokenizer],
768
+ outputs=[without_watermark_detection_result, session_args, session_tokenizer,
769
+ html_without_watermark, original_watermark_state]).success(
770
+ fn=detect_partial, inputs=[output_with_watermark, session_args, session_tokenizer],
771
+ outputs=[with_watermark_detection_result, session_args, session_tokenizer, html_with_watermark,
772
+ change_watermark_state])
773
+ # 如果发生了截断,则显示提示的截断版本
774
+ redecoded_input.change(fn=truncate_prompt, inputs=[redecoded_input, truncation_warning, prompt, session_args],
775
+ outputs=[prompt, session_args])
776
+ # Register main detection tab click
777
+ detect_btn.click(fn=detect_partial, inputs=[detection_input, session_args, session_tokenizer],
778
+ outputs=[detection_result, session_args, session_tokenizer, html_detection_input,
779
+ detect_watermark_state],
780
+ api_name="detection")
781
+
782
+ # 状态管理逻辑
783
+ # 定义更新回调函数以更改状态字典
784
def update_model(session_state, value):
    """Record the selected model name on the session state and return the state."""
    setattr(session_state, "model_name_or_path", value)
    return session_state
787
+
788
def update_sampling_temp(session_state, value):
    """Store the sampling temperature (coerced to float) on the session state."""
    temperature = float(value)
    setattr(session_state, "sampling_temp", temperature)
    return session_state
791
+
792
def update_generation_seed(session_state, value):
    """Store the generation seed (coerced to int) on the session state."""
    seed = int(value)
    setattr(session_state, "generation_seed", seed)
    return session_state
795
+
796
def update_watermark_salt(value):
    """Set the module-level ``watermark_salt`` from a dropdown selection.

    Accepts an int, ``None`` (treated as 0), or a digit string. Any other
    value is looked up as a watermark name in ``default_trace_table`` and
    resolved to that row's id.
    """
    global watermark_salt

    if value is None:
        watermark_salt = 0
        return
    if isinstance(value, int):
        watermark_salt = value
        return
    if isinstance(value, str) and value.isdigit():
        watermark_salt = int(value)
        return

    # The dropdown sometimes delivers the watermark name instead of its id
    # (original author's note: reason for the inversion unknown), so map the
    # name back to the id through the trace table.
    matching_ids = default_trace_table.loc[default_trace_table['水印内容'] == value, '编号']
    watermark_salt = int(matching_ids.item())
808
+
809
def update_trace_source(value):
    """Validate an edited trace table and publish it as the global table.

    Args:
        value: the edited DataFrame with an id column ``编号`` and a name
            column ``水印内容``.

    Returns:
        ``(value, dropdown)`` where ``dropdown`` is a ``gr.Dropdown`` whose
        choices are rebuilt from the table rows (each row reversed). While an
        id cell is still blank, the table is echoed back with an unchanged
        dropdown and the global table is left alone.

    Raises:
        gr.Error: if an id is not an integer or if ids are duplicated.
    """
    global default_trace_table

    # A blank id means the user is still editing; do not validate yet.
    if '' in value.loc[:, '编号'].tolist():
        return value, gr.Dropdown()

    try:
        value.loc[:, '编号'] = value.loc[:, '编号'].astype(int)
    except ValueError as e:
        if 'invalid literal for int() with base 10' in str(e):
            raise gr.Error(f"请检查水印数据,编号必须是整数:{e}") from e
        # Previously any other ValueError was swallowed and the callback
        # silently returned None; surface it instead.
        raise

    # Check the *incoming* table for duplicate ids (the old code inspected the
    # previously stored table, so new duplicates were never detected).
    if value.duplicated(subset='编号').any():
        raise gr.Error("请检查水印编号,编号不能重复")

    default_trace_table = value

    return value, gr.Dropdown(
        choices=[i[::-1] for i in value.to_dict(orient='split')['data']])
835
+
836
def update_gamma(session_state, value):
    """Store the watermark gamma (coerced to float) on the session state."""
    setattr(session_state, "gamma", float(value))
    return session_state
839
+
840
def update_delta(session_state, value):
    """Store the watermark delta (coerced to float) on the session state."""
    setattr(session_state, "delta", float(value))
    return session_state
843
+
844
def update_detection_z_threshold(session_state, value):
    """Store the detection z-score threshold (coerced to float) on the session state."""
    setattr(session_state, "detection_z_threshold", float(value))
    return session_state
847
+
848
def update_decoding(session_state, value):
    """Set ``use_sampling`` on the session state from the decoding selector.

    The radio widget offers the localized labels "多项式解码方法" and
    "贪心解码方法", while its initial value uses the internal names
    "multinomial" / "greedy". Both spellings are accepted; the old code only
    matched the internal names, so a user selection never changed the state.

    Any other value leaves ``use_sampling`` untouched. Returns the state.
    """
    if value in ("multinomial", "多项式解码方法"):
        session_state.use_sampling = True
    elif value in ("greedy", "贪心解码方法"):
        session_state.use_sampling = False
    return session_state
854
+
855
def toggle_sampling_vis(value):
    """Show a sampling-only control for multinomial decoding, hide it for greedy.

    Accepts both the internal names ("multinomial" / "greedy") and the
    localized radio labels; the old code matched only the internal names and
    therefore returned None for every real selection.
    """
    if value in ("multinomial", "多项式解码方法"):
        return gr.update(visible=True)
    if value in ("greedy", "贪心解码方法"):
        return gr.update(visible=False)
    # Unknown selection: send an empty update instead of a bare None.
    return gr.update()
860
+
861
def toggle_sampling_vis_inv(value):
    """Hide a greedy-only control for multinomial decoding, show it for greedy.

    Inverse of ``toggle_sampling_vis``. Accepts both the internal names and
    the localized radio labels; the old code matched only the internal names
    and therefore returned None for every real selection.
    """
    if value in ("multinomial", "多项式解码方法"):
        return gr.update(visible=False)
    if value in ("greedy", "贪心解码方法"):
        return gr.update(visible=True)
    # Unknown selection: send an empty update instead of a bare None.
    return gr.update()
866
+
867
+ # 如果模型名称在 API 模型列表中,则将 num beams 参数设置为 1 并隐藏 n_beams
868
def toggle_vis_for_api_model(value):
    """Return an update that hides the component for API models and shows it otherwise."""
    is_api_model = value in API_MODEL_MAP
    return gr.update(visible=not is_api_model)
873
+
874
def toggle_beams_for_api_model(value, orig_n_beams):
    """Return an update forcing the beam count to 1 for API models, else keeping ``orig_n_beams``."""
    beams = 1 if value in API_MODEL_MAP else orig_n_beams
    return gr.update(value=beams)
879
+
880
+ # 如果模型名称在 API 模型列表中,则将交互参数设置为 false
881
def toggle_interactive_for_api_model(value):
    """Return an update that makes the component read-only for API models, editable otherwise."""
    is_api_model = value in API_MODEL_MAP
    return gr.update(interactive=not is_api_model)
886
+
887
+ # 如果模型名称在 API 模型列表中,则根据 API 映射设置 gamma 和 delta
888
def toggle_gamma_for_api_model(value, orig_gamma):
    """Return an update setting gamma from ``API_MODEL_MAP`` for API models, else ``orig_gamma``."""
    if value not in API_MODEL_MAP:
        return gr.update(value=orig_gamma)
    return gr.update(value=API_MODEL_MAP[value]["gamma"])
893
+
894
def toggle_delta_for_api_model(value, orig_delta):
    """Return an update setting delta from ``API_MODEL_MAP`` for API models, else ``orig_delta``."""
    if value not in API_MODEL_MAP:
        return gr.update(value=orig_delta)
    return gr.update(value=API_MODEL_MAP[value]["delta"])
899
+
900
def update_n_beams(session_state, value):
    """Store the beam count on the session state and return the state."""
    setattr(session_state, "n_beams", value)
    return session_state
903
+
904
def update_max_new_tokens(session_state, value):
    """Store the maximum number of generated tokens (coerced to int) on the session state."""
    token_limit = int(value)
    setattr(session_state, "max_new_tokens", token_limit)
    return session_state
907
+
908
def update_ignore_repeated_bigrams(session_state, value):
    """Store the ignore-repeated-bigrams flag on the session state and return the state."""
    setattr(session_state, "ignore_repeated_bigrams", value)
    return session_state
911
+
912
def update_normalizers(session_state, value):
    """Store the selected normalizers on the session state and return the state."""
    setattr(session_state, "normalizers", value)
    return session_state
915
+
916
def update_seed_separately(session_state, value):
    """Store the seed-separately flag on the session state and return the state."""
    setattr(session_state, "seed_separately", value)
    return session_state
919
+
920
def update_select_green_tokens(session_state, value):
    """Store the select-green-tokens flag on the session state and return the state."""
    setattr(session_state, "select_green_tokens", value)
    return session_state
923
+
924
def update_tokenizer(model_name_or_path):
    """Load and return the tokenizer for ``model_name_or_path`` via ``AutoTokenizer``."""
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    return tokenizer
929
+
930
def check_model(value):
    """Return ``value``, falling back to the launch-time model name when it is empty or None."""
    if value is None or value == "":
        return args.model_name_or_path
    return value
932
+
933
+ # 强制约束模型��能为 null 或空
934
+ # 然后特别附加模型回调函数
935
+ model_selector.change(check_model, inputs=[model_selector], outputs=[model_selector]).then(
936
+ toggle_vis_for_api_model, inputs=[model_selector], outputs=[n_beams]
937
+ ).then(
938
+ toggle_beams_for_api_model, inputs=[model_selector, n_beams], outputs=[n_beams]
939
+ ).then(
940
+ toggle_interactive_for_api_model, inputs=[model_selector], outputs=[gamma]
941
+ ).then(
942
+ toggle_interactive_for_api_model, inputs=[model_selector], outputs=[delta]
943
+ ).then(
944
+ toggle_gamma_for_api_model, inputs=[model_selector, gamma], outputs=[gamma]
945
+ ).then(
946
+ toggle_delta_for_api_model, inputs=[model_selector, delta], outputs=[delta]
947
+ ).then(
948
+ update_tokenizer, inputs=[model_selector], outputs=[session_tokenizer]
949
+ ).then(
950
+ update_model, inputs=[session_args, model_selector], outputs=[session_args]
951
+ ).then(
952
+ lambda value: str(value), inputs=[session_args], outputs=[current_parameters]
953
+ )
954
+ # 根据其他参数的值注册回调函数以切换特定参数的可见性
955
+ decoding.change(toggle_sampling_vis, inputs=[decoding], outputs=[sampling_temp])
956
+ decoding.change(toggle_sampling_vis, inputs=[decoding], outputs=[generation_seed])
957
+ decoding.change(toggle_sampling_vis_inv, inputs=[decoding], outputs=[n_beams])
958
+ decoding.change(toggle_vis_for_api_model, inputs=[model_selector], outputs=[n_beams])
959
+ # 注册所有状态更新回调函数
960
+ decoding.change(update_decoding, inputs=[session_args, decoding], outputs=[session_args])
961
+ sampling_temp.change(update_sampling_temp, inputs=[session_args, sampling_temp], outputs=[session_args])
962
+ generation_seed.change(update_generation_seed, inputs=[session_args, generation_seed], outputs=[session_args])
963
+ watermark_salt_choice.change(update_watermark_salt, inputs=[watermark_salt_choice])
964
+
965
+ # 同步更新
966
+ trace_source.change(update_trace_source, inputs=[trace_source],
967
+ outputs=[trace_source2, watermark_salt_choice])
968
+ trace_source2.change(update_trace_source, inputs=[trace_source2],
969
+ outputs=[trace_source, watermark_salt_choice])
970
+
971
+ n_beams.change(update_n_beams, inputs=[session_args, n_beams], outputs=[session_args])
972
+ max_new_tokens.change(update_max_new_tokens, inputs=[session_args, max_new_tokens], outputs=[session_args])
973
+ gamma.change(update_gamma, inputs=[session_args, gamma], outputs=[session_args])
974
+ delta.change(update_delta, inputs=[session_args, delta], outputs=[session_args])
975
+ detection_z_threshold.change(update_detection_z_threshold, inputs=[session_args, detection_z_threshold],
976
+ outputs=[session_args])
977
+ ignore_repeated_bigrams.change(update_ignore_repeated_bigrams, inputs=[session_args, ignore_repeated_bigrams],
978
+ outputs=[session_args])
979
+ normalizers.change(update_normalizers, inputs=[session_args, normalizers], outputs=[session_args])
980
+ seed_separately.change(update_seed_separately, inputs=[session_args, seed_separately], outputs=[session_args])
981
+ select_green_tokens.change(update_select_green_tokens, inputs=[session_args, select_green_tokens],
982
+ outputs=[session_args])
983
# Extra callbacks that refresh the displayed-parameters window on button clicks.
generate_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
detect_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
# When a parameter changes, refresh the display and re-run detection, because
# some detection parameters do not change the model output.
delta.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])

# Each detection view is (input text box, metrics table, highlight HTML, banner HTML).
# The button wiring above gives detect_partial five outputs; the parameter-change
# handlers used to list only four, so the watermark banner was never refreshed.
detection_views = [
    (output_without_watermark, without_watermark_detection_result, html_without_watermark,
     original_watermark_state),
    (output_with_watermark, with_watermark_detection_result, html_with_watermark,
     change_watermark_state),
    (detection_input, detection_result, html_detection_input, detect_watermark_state),
]
# Registration order matches the original hand-written sequence: for every
# control, first the parameter-window refresh, then the three detection re-runs.
for control in (gamma, detection_z_threshold, ignore_repeated_bigrams, normalizers, select_green_tokens):
    control.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
    for text_source, result_table, highlight_view, banner in detection_views:
        control.change(fn=detect_partial,
                       inputs=[text_source, session_args, session_tokenizer],
                       outputs=[result_table, session_args, session_tokenizer, highlight_view, banner])
1037
+
1038
+ # demo.queue(concurrency_count=3) # delete
1039
+
1040
+ if args.demo_public:
1041
+ demo.launch(share=True) # 通过随机生成的链接将应用程序暴露到互联网上
1042
+ else:
1043
+ demo.launch(server_name='0.0.0.0', share=False)
1044
+
1045
+
1046
def main(args):
    """Prepare the model/tokenizer/device and optionally launch the gradio demo.

    ``args.normalizers`` is split from a comma-separated string into a list
    (empty list when unset). When ``args.skip_model_load`` is set, only the
    tokenizer is loaded and the device is chosen from ``args.use_gpu``;
    otherwise ``load_model`` supplies model, tokenizer and device. The default
    prompt is stored on ``args.default_prompt`` before the demo is started.
    """
    # Initial argument processing and logging
    args.normalizers = args.normalizers.split(",") if args.normalizers else []
    print(args)

    if args.skip_model_load:
        model = None
        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
        device = "cuda:0" if (args.use_gpu and torch.cuda.is_available()) else "cpu"
    else:
        model, tokenizer, device = load_model(args)

    # Default prompt shown in the demo
    args.default_prompt = "为什么A股指数跌的不多,但是我亏损比之前都多?"

    # Launch the app to generate and detect interactively (implements the hf space demo)
    if args.run_gradio:
        run_gradio(args, model=model, tokenizer=tokenizer, device=device)
1079
+
1080
+
1081
# Script entry point: parse the command-line arguments, echo them, and run main().
if __name__ == "__main__":
    args = parse_args()
    # main() prints args again after post-processing the normalizers field.
    print(args)

    main(args)
extended_watermark_processor.py ADDED
@@ -0,0 +1,592 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from __future__ import annotations
3
+ import collections
4
+ from math import sqrt
5
+ from itertools import chain, tee
6
+ from functools import lru_cache
7
+
8
+ import scipy.stats
9
+ import torch
10
+ from tokenizers import Tokenizer
11
+ from transformers import LogitsProcessor
12
+
13
+ from normalizers import normalization_strategy_lookup
14
+ from alternative_prf_schemes import prf_lookup, seeding_scheme_lookup
15
+
16
+
17
class WatermarkBase:
    """Shared configuration and green-list computation for watermarking.

    Holds the vocabulary, the green-list fraction ``gamma``, the bias
    ``delta``, the seeding-scheme settings and a random generator that is
    re-seeded from the token context before each green-list draw.
    """

    def __init__(
        self,
        vocab: list[int] = None,
        gamma: float = 0.25,
        delta: float = 2.0,
        seeding_scheme: str = "selfhash",
        select_green_tokens: bool = True,
    ):
        """Store the watermark parameters.

        Args:
            vocab: token ids of the vocabulary; its length is the vocabulary size.
            gamma: fraction of the vocabulary placed on the green list.
            delta: bias value stored on the instance for use by subclasses.
            seeding_scheme: public name resolved by ``seeding_scheme_lookup``;
                ``None`` falls back to ``"selfhash"``.
            select_green_tokens: take the green list from the front of the
                permutation (True) or from its tail (False, legacy behaviour).
        """
        # None may be passed explicitly as the seeding scheme, so patch it here.
        if seeding_scheme is None:
            seeding_scheme = "selfhash"

        # Vocabulary setup
        self.vocab = vocab
        self.vocab_size = len(vocab)

        # Watermark behaviour
        self.gamma = gamma
        self.delta = delta
        # Created lazily on the device of the first input (see _seed_rng).
        self.rng = None
        self._initialize_seeding_scheme(seeding_scheme)
        # Legacy behaviour
        self.select_green_tokens = select_green_tokens

    def _initialize_seeding_scheme(self, seeding_scheme: str) -> None:
        """Initialize all internal seeding settings from a public scheme name."""
        self.prf_type, self.context_width, self.self_salt, self.hash_key = seeding_scheme_lookup(seeding_scheme)

    def _seed_rng(self, input_ids: torch.LongTensor) -> None:
        """Seed the generator from the last ``context_width`` tokens. Not batched.

        Raises:
            ValueError: if fewer than ``context_width`` tokens are available.
        """
        # Enough tokens are required to derive the seed.
        if input_ids.shape[-1] < self.context_width:
            raise ValueError(f"seeding_scheme requires at least a {self.context_width} token prefix to seed the RNG.")

        # Create the generator on first use so that this method also works
        # before any caller has initialized self.rng.
        if self.rng is None:
            self.rng = torch.Generator(device=input_ids.device)

        prf_key = prf_lookup[self.prf_type](input_ids[-self.context_width:], salt_key=self.hash_key)
        self.rng.manual_seed(prf_key % (2**64 - 1))  # keep the seed within range

    def _get_greenlist_ids(self, input_ids: torch.LongTensor) -> torch.LongTensor:
        """Seed the generator from the local context and return the green-list token ids."""
        self._seed_rng(input_ids)

        greenlist_size = int(self.vocab_size * self.gamma)
        vocab_permutation = torch.randperm(self.vocab_size, device=input_ids.device, generator=self.rng)
        if self.select_green_tokens:
            # Take the green list directly from the front of the permutation.
            return vocab_permutation[:greenlist_size]
        # Legacy behaviour: take the tail of the permutation instead.
        return vocab_permutation[(self.vocab_size - greenlist_size):]
66
+
67
+
68
+ class WatermarkLogitsProcessor(WatermarkBase, LogitsProcessor):
69
+ """LogitsProcessor 在管道中修改模型输出分数。可以在任何 HF 管道中使用它来修改分数以适应水印,但也可以作为一个独立的工具插入到在模型输出和下一个标记采样器之间生成分数的任何模型中。"""
70
def __init__(self, *args, store_spike_ents: bool = False, **kwargs):
    """Initialize the base watermark state and, if requested, spike-entropy bookkeeping.

    ``store_spike_ents`` is keyword-only; all other arguments are forwarded
    to the parent initializer.
    """
    super().__init__(*args, **kwargs)

    self.spike_entropies = None
    self.store_spike_ents = store_spike_ents
    if store_spike_ents:
        self._init_spike_entropies()
77
+
78
+ def _init_spike_entropies(self):
79
+ alpha = torch.exp(torch.tensor(self.delta)).item()
80
+ gamma = self.gamma
81
+
82
+ self.z_value = ((1 - gamma) * (alpha - 1)) / (1 - gamma + (alpha * gamma))
83
+ self.expected_gl_coef = (gamma * alpha) / (1 - gamma + (alpha * gamma))
84
+
85
+ # 当bias 是 "infinite" 时候捕获溢出
86
+ if alpha == torch.inf:
87
+ self.z_value = 1.0
88
+ self.expected_gl_coef = 1.0
89
+
90
+ def _get_spike_entropies(self):
91
+ spike_ents = [[] for _ in range(len(self.spike_entropies))]
92
+ for b_idx, ent_tensor_list in enumerate(self.spike_entropies):
93
+ for ent_tensor in ent_tensor_list:
94
+ spike_ents[b_idx].append(ent_tensor.item())
95
+ return spike_ents
96
+
97
+ def _get_and_clear_stored_spike_ents(self):
98
+ spike_ents = self._get_spike_entropies()
99
+ self.spike_entropies = None
100
+ return spike_ents
101
+
102
+ def _compute_spike_entropy(self, scores):
103
+ # 预先计算z得分
104
+ probs = scores.softmax(dim=-1)
105
+ denoms = 1 + (self.z_value * probs)
106
+ renormed_probs = probs / denoms
107
+ sum_renormed_probs = renormed_probs.sum()
108
+ return sum_renormed_probs
109
+
110
def _calc_greenlist_mask(self, scores: torch.FloatTensor, greenlist_token_ids) -> torch.BoolTensor:
    """Build a boolean mask that marks the green-list tokens of every batch row.

    Args:
        scores: Score tensor; the mask has the same shape.
        greenlist_token_ids: One collection of token ids per batch row. Empty
            collections leave the row entirely ``False``.

    Returns:
        Boolean tensor shaped like ``scores``.
    """
    mask = torch.zeros_like(scores, dtype=torch.bool)
    for row_idx, token_ids in enumerate(greenlist_token_ids):
        if len(token_ids) == 0:
            continue
        mask[row_idx, token_ids] = True
    return mask
116
+
117
def _bias_greenlist_logits(self, scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float) -> torch.Tensor:
    """Add ``greenlist_bias`` to every masked entry of ``scores``.

    ``scores`` is modified in place and the same tensor object is returned.
    """
    scores[greenlist_mask] += greenlist_bias
    return scores
120
+
121
def _score_rejection_sampling(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, tail_rule="fixed_compute") -> torch.Tensor:
    """Build a self-consistent green list from the current next-token candidates.

    Candidates are visited from highest to lowest score. Each candidate is
    appended to the prefix, the green list for that extended prefix is
    generated, and the candidate is kept only if it lands in its own green
    list. This method is not batched.

    This is only a partial version of Algorithm 3 ("Robust Private
    Watermarking") because it always assumes greedy sampling; it still works
    for other sampling schemes, but less effectively.

    Args:
        input_ids: 1-D prefix token ids for a single sequence.
        scores: 1-D scores over the vocabulary for the next token.
        tail_rule: Early-stopping rule for the distribution tail:
            ``"fixed_score"`` stops once the next candidate trails the best
            score by more than ``self.delta``; ``"fixed_list_length"`` stops
            at 10 accepted tokens; ``"fixed_compute"`` stops after the
            candidate at index 40; anything else scans all candidates.

    Returns:
        Tensor of accepted token ids on ``input_ids.device``.
    """
    sorted_scores, greedy_predictions = scores.sort(dim=-1, descending=True)
    num_candidates = len(sorted_scores)

    final_greenlist = []
    for idx, prediction_candidate in enumerate(greedy_predictions):
        # Add the candidate to the prefix, then test it for consistency.
        greenlist_ids = self._get_greenlist_ids(torch.cat([input_ids, prediction_candidate[None]], dim=0))
        if prediction_candidate in greenlist_ids:
            final_greenlist.append(prediction_candidate)

        # Optional early-stopping rules, purely for efficiency.
        if tail_rule == "fixed_score":
            # The last candidate has no successor to compare against; without
            # this bound the lookup below raised IndexError.
            if idx + 1 < num_candidates and sorted_scores[0] - sorted_scores[idx + 1] > self.delta:
                break
        elif tail_rule == "fixed_list_length":
            if len(final_greenlist) == 10:
                break
        elif tail_rule == "fixed_compute":
            if idx == 40:
                break
        # Any other rule: never stop early.
    return torch.as_tensor(final_greenlist, device=input_ids.device)
148
+
149
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
    """Bias the next-token scores towards each sequence's green list.

    Args:
        input_ids: Batch of prefix token ids, one row per sequence.
        scores: Batch of next-token scores, one row per sequence.

    Returns:
        ``scores`` with ``self.delta`` added to every green-list entry.
    """
    if self.rng is None:
        self.rng = torch.Generator(device=input_ids.device)

    # Seeding and partitioning are not vectorised, so every sequence in the
    # batch is handled on its own. Green lists may differ in length.
    per_row_greenlists = []
    for b_idx, input_seq in enumerate(input_ids):
        if self.self_salt:
            row_greenlist = self._score_rejection_sampling(input_seq, scores[b_idx])
        else:
            row_greenlist = self._get_greenlist_ids(input_seq)
        per_row_greenlists.append(row_greenlist)

        if not self.store_spike_ents:
            continue
        # Compute and store the spike entropy of this row.
        if self.spike_entropies is None:
            self.spike_entropies = [[] for _ in range(input_ids.shape[0])]
        self.spike_entropies[b_idx].append(self._compute_spike_entropy(scores[b_idx]))

    mask = self._calc_greenlist_mask(scores=scores, greenlist_token_ids=per_row_greenlists)
    return self._bias_greenlist_logits(scores=scores, greenlist_mask=mask, greenlist_bias=self.delta)
174
+
175
+
176
+ class WatermarkDetector(WatermarkBase):
177
+ """这是用于检测所有使用 WatermarkLogitsProcessor 印记的水印的检测器。
178
+
179
+ 检测器需要给出在文本生成期间给出的完全相同的设置,以复制水印的生成绿名单并检测水印。
180
+ 这包括在文本生成期间使用的正确设备、正确的分词器、正确的 seeding_scheme 名称和参数(delta、gamma)。
181
+
182
+ 可选参数包括
183
+ * normalizers ["unicode", "homoglyphs", "truecase"] -> 这些可以减轻生成文本中可能触发水印的修改。
184
+ * ignore_repeated_ngrams -> 此选项将更改检测规则,只计算每个唯一 ngram 一次。
185
+ * z_threshold -> 更改此阈值将更改检测器的灵敏度。
186
+ """
187
+
188
+
189
def __init__(
    self,
    *args,
    device: torch.device = None,
    tokenizer: Tokenizer = None,
    z_threshold: float = 4.0,
    normalizers: list[str] = None,  # defaults to ["unicode"]; also: ["unicode", "homoglyphs", "truecase"]
    ignore_repeated_ngrams: bool = True,
    **kwargs,
):
    """Configure the detector.

    Args:
        *args: Forwarded to the parent initializer.
        device: Device used for the random generator and token tensors. Required.
        tokenizer: Tokenizer that produced the text under test. Required.
        z_threshold: z-score above which text is predicted to be watermarked.
        normalizers: Names of text normalizers applied before tokenization.
            ``None`` (the default) means ``["unicode"]``; pass an empty list
            to disable normalization.
        ignore_repeated_ngrams: When true, every unique n-gram is scored once.
        **kwargs: Forwarded to the parent initializer.

    Raises:
        AssertionError: If ``device`` or ``tokenizer`` is missing.
    """
    super().__init__(*args, **kwargs)
    # Configuration options. Kept as asserts so callers that catch
    # AssertionError keep working.
    assert device, "Must pass device"
    assert tokenizer, "Need an instance of the generating tokenizer to perform detection"

    # A list literal as the default would be shared between all calls.
    if normalizers is None:
        normalizers = ["unicode"]

    self.tokenizer = tokenizer
    self.device = device
    self.z_threshold = z_threshold
    self.rng = torch.Generator(device=self.device)

    self.normalizers = [normalization_strategy_lookup(name) for name in normalizers]
    self.ignore_repeated_ngrams = ignore_repeated_ngrams
213
+
214
def dummy_detect(
    self,
    return_prediction: bool = True,
    return_scores: bool = True,
    z_threshold: float = None,
    return_num_tokens_scored: bool = True,
    return_num_green_tokens: bool = True,
    return_green_fraction: bool = True,
    return_green_token_mask: bool = False,
    return_all_window_scores: bool = False,
    return_z_score: bool = True,
    return_z_at_T: bool = True,
    return_p_value: bool = True,
):
    """Return a placeholder result with the same keys a real detection yields.

    Numeric scores are ``nan``, list-valued entries are empty, and
    ``z_score_at_T`` is an empty tensor. Each ``return_*`` flag controls
    whether its key is present.

    Returns:
        HF-style output dict. ``"prediction"`` is always ``False`` when
        ``return_prediction`` is set.

    Raises:
        AssertionError: If a prediction is requested but neither
            ``z_threshold`` nor ``self.z_threshold`` is set.
    """
    # HF-style output dictionary; insertion order is part of the output.
    score_dict = {}
    if return_num_tokens_scored:
        score_dict["num_tokens_scored"] = float("nan")
    if return_num_green_tokens:
        score_dict["num_green_tokens"] = float("nan")
    if return_green_fraction:
        score_dict["green_fraction"] = float("nan")
    if return_z_score:
        score_dict["z_score"] = float("nan")
    if return_p_value:
        score_dict["p_value"] = float("nan")
    if return_green_token_mask:
        score_dict["green_token_mask"] = []
    if return_all_window_scores:
        score_dict["window_list"] = []
    if return_z_at_T:
        score_dict["z_score_at_T"] = torch.tensor([])

    output_dict = {}
    if return_scores:
        output_dict.update(score_dict)
    # A requested prediction still validates the threshold, then reports False.
    if return_prediction:
        # `is None` rather than truthiness, so an explicit 0.0 is respected.
        z_threshold = self.z_threshold if z_threshold is None else z_threshold
        assert z_threshold is not None, "Need a threshold in order to decide outcome of detection test"
        output_dict["prediction"] = False

    return output_dict
260
+
261
def _compute_z_score(self, observed_count, T):
    """Return the z-score of seeing ``observed_count`` green tokens among ``T``.

    ``self.gamma`` is used as the expected green fraction:
    ``(observed_count - gamma * T) / sqrt(T * gamma * (1 - gamma))``.
    """
    gamma = self.gamma
    deviation = observed_count - gamma * T
    return deviation / sqrt(T * gamma * (1 - gamma))
268
+
269
def _compute_p_value(self, z):
    """Return the standard-normal survival function at ``z``."""
    return scipy.stats.norm.sf(z)
272
+
273
@lru_cache(maxsize=2**32)
def _get_ngram_score_cached(self, prefix: tuple[int], target: int):
    """Return whether ``target`` is in the green list seeded by ``prefix``.

    Re-seeding and sampling are memoized per ``(self, prefix, target)``.

    NOTE: the original author warns that this needs care; ideally the cache
    would be reset whenever the seeding-related attributes change (their note
    lists prof_type, text_width, self_salt and hash_key).
    """
    prefix_tensor = torch.as_tensor(prefix, device=self.device)
    return target in self._get_greenlist_ids(prefix_tensor)
279
+
280
def _score_ngrams_in_passage(self, input_ids: torch.Tensor):
    """Collect every n-gram in the input and look up its watermark hit.

    Returns:
        ``(ngram_to_watermark_lookup, frequencies_table)``: a dict mapping each
        distinct n-gram to its green/red outcome, and a ``Counter`` of how
        often each n-gram occurs.

    Raises:
        ValueError: If fewer than one token remains after the
            ``self.context_width`` prefix tokens.
    """
    if len(input_ids) - self.context_width < 1:
        raise ValueError(
            f"Must have at least {1} token to score after "
            f"the first min_prefix_len={self.context_width} tokens required by the seeding scheme."
        )

    # With self-salting the target token is also part of the seeding context,
    # so the n-gram is one token shorter.
    ngram_length = self.context_width + 1 - self.self_salt
    frequencies_table = collections.Counter(ngrams(input_ids.cpu().tolist(), ngram_length))

    ngram_to_watermark_lookup = {}
    for ngram in frequencies_table:
        context = ngram if self.self_salt else ngram[:-1]
        ngram_to_watermark_lookup[ngram] = self._get_ngram_score_cached(context, ngram[-1])

    return ngram_to_watermark_lookup, frequencies_table
298
+
299
def _get_green_at_T_booleans(self, input_ids, ngram_to_watermark_lookup) -> tuple[torch.Tensor]:
    """Build per-token green/red flags plus a de-duplicated view of them.

    Returns:
        Three tensors:
        * the green/red flag of every n-gram in input order;
        * the flags of the n-grams that are counted (first occurrences only
          when ``self.ignore_repeated_ngrams`` is set, otherwise all);
        * for every n-gram, the index of the most recently counted entry, so
          that ``green_token_mask == green_token_mask_unique[offsets]`` except
          at positions counted as repeats.
    """
    ngram_length = self.context_width + 1 - self.self_salt
    per_token_flags, counted_flags, offsets = [], [], []
    seen_ngrams = set()
    counted = 0

    for ngram in ngrams(input_ids.cpu().tolist(), ngram_length):
        is_green = ngram_to_watermark_lookup[ngram]
        per_token_flags.append(is_green)
        # Without repeat-skipping every n-gram counts; with it, only the
        # first sighting does.
        if not self.ignore_repeated_ngrams or ngram not in seen_ngrams:
            seen_ngrams.add(ngram)
            counted_flags.append(is_green)
            counted += 1
        offsets.append(counted - 1)

    return (
        torch.tensor(per_token_flags),
        torch.tensor(counted_flags),
        torch.tensor(offsets),
    )
327
+
328
def _score_sequence(
    self,
    input_ids: torch.Tensor,
    return_num_tokens_scored: bool = True,
    return_num_green_tokens: bool = True,
    return_green_fraction: bool = True,
    return_green_token_mask: bool = False,
    return_z_score: bool = True,
    return_z_at_T: bool = True,
    return_p_value: bool = True,
):
    """Score a whole token sequence for the watermark.

    Args:
        input_ids: 1-D tensor of token ids.
        return_*: Select which entries appear in the result.

    Returns:
        HF-style dict with the requested entries, in the order
        ``num_tokens_scored``, ``num_green_tokens``, ``green_fraction``,
        ``z_score``, ``p_value``, ``green_token_mask``, ``z_score_at_T``.
    """
    ngram_to_watermark_lookup, frequencies_table = self._score_ngrams_in_passage(input_ids)
    green_token_mask, green_unique, offsets = self._get_green_at_T_booleans(input_ids, ngram_to_watermark_lookup)

    # Sum up the scores of all n-grams.
    if self.ignore_repeated_ngrams:
        # Every unique n-gram counts once, so the number of scored tokens (T)
        # is the number of unique n-grams.
        num_tokens_scored = len(frequencies_table)
        green_token_count = sum(ngram_to_watermark_lookup.values())
    else:
        num_tokens_scored = sum(frequencies_table.values())
        assert num_tokens_scored == len(input_ids) - self.context_width + self.self_salt
        green_token_count = sum(
            frequencies_table[ngram] * outcome for ngram, outcome in ngram_to_watermark_lookup.items()
        )
    assert green_token_count == green_unique.sum()

    # The z-score feeds the z_score entry, the p-value and the z_at_T sanity
    # check, so compute it once if any of them is requested. Previously it was
    # unbound when only return_z_at_T was set.
    z_score = None
    if return_z_score or return_p_value or return_z_at_T:
        z_score = self._compute_z_score(green_token_count, num_tokens_scored)

    # HF-style output dictionary.
    score_dict = {}
    if return_num_tokens_scored:
        score_dict["num_tokens_scored"] = num_tokens_scored
    if return_num_green_tokens:
        score_dict["num_green_tokens"] = green_token_count
    if return_green_fraction:
        score_dict["green_fraction"] = green_token_count / num_tokens_scored
    if return_z_score:
        score_dict["z_score"] = z_score
    if return_p_value:
        score_dict["p_value"] = self._compute_p_value(z_score)
    if return_green_token_mask:
        score_dict["green_token_mask"] = green_token_mask.tolist()
    if return_z_at_T:
        # Running z-score after each counted n-gram, expanded back to one
        # entry per token through `offsets`.
        sizes = torch.arange(1, len(green_unique) + 1)
        seq_z_score_enum = torch.cumsum(green_unique, dim=0) - self.gamma * sizes
        seq_z_score_denom = torch.sqrt(sizes * self.gamma * (1 - self.gamma))
        z_score_at_effective_T = seq_z_score_enum / seq_z_score_denom
        z_score_at_T = z_score_at_effective_T[offsets]
        assert torch.isclose(z_score_at_T[-1], torch.tensor(z_score))

        score_dict["z_score_at_T"] = z_score_at_T

    return score_dict
386
+
387
def _score_windows_impl_batched(
    self,
    input_ids: torch.Tensor,
    window_size: str,
    window_stride: int = 1,
):
    """Find the sliding window over the green/red hits with the highest z-score.

    Implementation notes:
    1) ``ignore_repeated_ngrams`` is applied globally first; windowing then
       runs on the reduced binary vector. Ignoring repeats inside each window
       would be an alternative, but is harder to parallelise.
    2) Windows slide over the binary hit vector and are independent of
       ``context_width``.
    3) Because of the windowing, the resulting z-scores cannot be converted
       directly into p-values; they should only be used as labels for an ROC
       curve calibrated to a chosen FPR. Multiple hypothesis testing inflates
       the overall score.

    Args:
        input_ids: 1-D tensor of token ids.
        window_size: ``"max"`` to try every size from 1 up to (but excluding)
            the number of counted hits, or a comma-separated list of sizes.
        window_stride: Step used when slicing the prefix-sum table.

    Returns:
        ``(optimal_z, optimal_window_size, z_score_max_per_window,
        cumulative_z_score, green_mask)``.

    Raises:
        ValueError: If none of the requested sizes fits the scored length.
    """
    lookup, _ = self._score_ngrams_in_passage(input_ids)
    green_mask, green_ids, offsets = self._get_green_at_T_booleans(input_ids, lookup)
    len_full_context = len(green_ids)

    # prefix_sums[i] = number of green hits among the first i + 1 entries.
    prefix_sums = torch.cumsum(green_ids, dim=0)

    if window_size == "max":
        sizes = range(1, len_full_context)
    else:
        sizes = [int(token) for token in window_size.split(",") if len(token) > 0]

    z_score_max_per_window = torch.zeros(len(sizes))
    cumulative_eff_z_score = torch.zeros(len_full_context)
    stride = window_stride

    window_fits = False
    for idx, size in enumerate(sizes):
        if size > len_full_context:
            continue

        # Hits inside every window position, computed in one shot.
        hits_per_window = torch.zeros(len_full_context - size + 1, dtype=torch.long)
        # The 0-th window is just the prefix sum at its right edge ...
        hits_per_window[0] = prefix_sums[size - 1]
        # ... every later window is a difference of two prefix sums.
        hits_per_window[1:] = prefix_sums[size::stride] - prefix_sums[:-size:stride]

        # Batched z-scores for this window size.
        z_numerators = hits_per_window - self.gamma * size
        z_denominator = sqrt(size * self.gamma * (1 - self.gamma))
        window_z_scores = z_numerators / z_denominator

        # Best hit for this size.
        z_score_max_per_window[idx] = window_z_scores.max()

        running_best = torch.cummax(window_z_scores, dim=0)[0]
        cumulative_eff_z_score[size::stride] = torch.maximum(
            cumulative_eff_z_score[size::stride], running_best[:-1]
        )
        window_fits = True  # at least one window size could be scored

    if not window_fits:
        raise ValueError(
            f"Could not find a fitting window with window sizes {window_size} for (effective) context length {len_full_context}."
        )

    # Best window size and its z-score.
    cumulative_z_score = cumulative_eff_z_score[offsets]
    optimal_z, optimal_window_size_idx = z_score_max_per_window.max(dim=0)
    optimal_window_size = sizes[optimal_window_size_idx]
    return (
        optimal_z,
        optimal_window_size,
        z_score_max_per_window,
        cumulative_z_score,
        green_mask,
    )
458
+
459
def _score_sequence_window(
    self,
    input_ids: torch.Tensor,
    return_num_tokens_scored: bool = True,
    return_num_green_tokens: bool = True,
    return_green_fraction: bool = True,
    return_green_token_mask: bool = False,
    return_z_score: bool = True,
    return_z_at_T: bool = True,
    return_p_value: bool = True,
    window_size: str = None,
    window_stride: int = 1,
):
    """Score a sequence using the best-scoring sliding window.

    Args:
        input_ids: 1-D tensor of token ids.
        return_*: Select which entries appear in the result.
        window_size: Passed through to ``_score_windows_impl_batched``.
        window_stride: Passed through to ``_score_windows_impl_batched``.

    Returns:
        HF-style dict with the requested entries, in the order
        ``num_tokens_scored``, ``num_green_tokens``, ``green_fraction``,
        ``z_score``, ``z_score_at_T``, ``p_value``, ``green_token_mask``.
    """
    (
        optimal_z,
        optimal_window_size,
        _,
        z_score_at_T,
        green_mask,
    ) = self._score_windows_impl_batched(input_ids, window_size, window_stride)

    # HF-style output dictionary.
    score_dict = {}
    if return_num_tokens_scored:
        score_dict["num_tokens_scored"] = optimal_window_size

    # Invert z = (count - gamma * T) / sqrt(T * gamma * (1 - gamma)) to get
    # the green count of the best window back. The result is an integer up to
    # floating-point error, so round; truncating could lose a token when the
    # value lands just below the integer.
    denom = sqrt(optimal_window_size * self.gamma * (1 - self.gamma))
    green_token_count = int(round(float(optimal_z) * denom + self.gamma * optimal_window_size))
    green_fraction = green_token_count / optimal_window_size
    if return_num_green_tokens:
        score_dict["num_green_tokens"] = green_token_count
    if return_green_fraction:
        score_dict["green_fraction"] = green_fraction
    if return_z_score:
        score_dict["z_score"] = optimal_z
    if return_z_at_T:
        score_dict["z_score_at_T"] = z_score_at_T
    if return_p_value:
        score_dict["p_value"] = self._compute_p_value(optimal_z)

    # The per-token mask is the same as without windowing; it does not mark
    # which tokens fell into the winning window.
    if return_green_token_mask:
        score_dict["green_token_mask"] = green_mask.tolist()

    return score_dict
507
+
508
def detect(
    self,
    text: str = None,
    tokenized_text: list[int] = None,
    window_size: str = None,
    window_stride: int = None,
    return_prediction: bool = True,
    return_scores: bool = True,
    z_threshold: float = None,
    convert_to_float: bool = False,
    **kwargs,
) -> dict:
    """Score a text (raw or already tokenized) and return a result dict.

    Args:
        text: Raw string to test. Mutually exclusive with ``tokenized_text``.
        tokenized_text: Token ids to test. Mutually exclusive with ``text``.
        window_size: If given, use windowed scoring with this size spec.
        window_stride: Stride for windowed scoring; ``None`` means 1.
        return_prediction: Add ``"prediction"`` (and ``"confidence"`` when
            positive) to the result.
        return_scores: Include the score entries in the result.
        z_threshold: Decision threshold; ``None`` falls back to
            ``self.z_threshold``.
        convert_to_float: Convert ``int`` values in the result to ``float``.
        **kwargs: ``return_*`` flags forwarded to the scoring method.

    Returns:
        Result dict.

    Raises:
        AssertionError: If both or neither of ``text`` / ``tokenized_text``
            are given, if no tokenizer is available for raw text, or if no
            threshold is available for a prediction.
    """
    assert (text is not None) ^ (tokenized_text is not None), "Must pass either the raw or tokenized string"
    if return_prediction:
        # The decision needs the z-score, and the "confidence" := 1 - p of a
        # positive detection needs the p-value.
        kwargs["return_p_value"] = True
        kwargs["return_z_score"] = True

    if text is not None:
        # Normalizers work on strings; skip them for pre-tokenized input
        # instead of calling them with None.
        for normalizer in self.normalizers:
            text = normalizer(text)
        if len(self.normalizers) > 0:
            print(f"Text after normalization:\n\n{text}\n")

    if tokenized_text is None:
        assert self.tokenizer is not None, (
            "Watermark detection on raw string "
            "requires an instance of the tokenizer "
            "that was used at generation time."
        )
        tokenized_text = self.tokenizer(text, return_tensors="pt", add_special_tokens=False)["input_ids"][0].to(self.device)

    # Drop a leading BOS token if one is there; guard against empty input.
    if (
        self.tokenizer is not None
        and len(tokenized_text) > 0
        and tokenized_text[0] == self.tokenizer.bos_token_id
    ):
        tokenized_text = tokenized_text[1:]

    # Run the scoring method.
    if window_size is not None:
        score_dict = self._score_sequence_window(
            tokenized_text,
            window_size=window_size,
            window_stride=1 if window_stride is None else window_stride,
            **kwargs,
        )
    else:
        score_dict = self._score_sequence(tokenized_text, **kwargs)

    output_dict = {}
    if return_scores:
        output_dict.update(score_dict)

    # Perform the hypothesis test and report the outcome.
    if return_prediction:
        # `is None` rather than truthiness, so an explicit 0.0 is respected.
        z_threshold = self.z_threshold if z_threshold is None else z_threshold
        assert z_threshold is not None, "Need a threshold in order to decide outcome of detection test"
        output_dict["prediction"] = score_dict["z_score"] > z_threshold
        if output_dict["prediction"]:
            output_dict["confidence"] = 1 - score_dict["p_value"]

    # Convert integer values to floats if requested. bool is a subclass of
    # int, so a plain Python bool is converted as well (unchanged behaviour).
    if convert_to_float:
        for key, value in list(output_dict.items()):
            if isinstance(value, int):
                output_dict[key] = float(value)

    return output_dict
576
+
577
+
578
+
579
+
580
+
581
def ngrams(sequence, n, pad_left=False, pad_right=False, pad_symbol=None):
    """Return an iterator over the consecutive ``n``-tuples of ``sequence``.

    Args:
        sequence: Any iterable.
        n: Length of each tuple.
        pad_left: Prepend ``n - 1`` copies of ``pad_symbol``.
        pad_right: Append ``n - 1`` copies of ``pad_symbol``.
        pad_symbol: Value used for padding.

    Returns:
        A lazy ``zip`` of ``n`` staggered copies of the (padded) input.
    """
    stream = iter(sequence)
    padding = (pad_symbol,) * (n - 1)
    if pad_left:
        stream = chain(padding, stream)
    if pad_right:
        stream = chain(stream, padding)

    staggered = tee(stream, n)
    # Advance the i-th copy by i items so the copies line up as sliding windows.
    for shift, copy in enumerate(staggered):
        skipped = 0
        while skipped < shift:
            next(copy, None)
            skipped += 1
    return zip(*staggered)
homoglyphs.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from collections import defaultdict
3
+ import json
4
+ from itertools import product
5
+ import os
6
+ import unicodedata
7
+
8
# How Homoglyphs treats a character that is not in its alphabet.
STRATEGY_LOAD = 1  # load the character's category (or languages) into the alphabet
STRATEGY_IGNORE = 2  # add the character to the result
STRATEGY_REMOVE = 3  # remove the character from the result

# Code points accepted as ASCII.
ASCII_RANGE = range(128)


# The JSON data files live in "homoglyph_data" next to this module.
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
DATA_LOCATION = os.path.join(CURRENT_DIR, "homoglyph_data")
17
+
18
+
19
class Categories:
    """Lookup helpers for the Unicode category data in ``categories.json``.

    The file holds an ``"aliases"`` collection of category names and a
    ``"points"`` list of ``[start, end, category]`` code-point ranges.
    """

    fpath = os.path.join(DATA_LOCATION, "categories.json")

    # Parsed JSON keyed by file path. `detect` is called once per character by
    # the normalizers, so re-reading the file on every call is wasteful.
    _data_cache = {}

    @classmethod
    def _load(cls):
        """Return the parsed data file, reading it only on first use."""
        path = cls.fpath
        if path not in cls._data_cache:
            with open(path, encoding="utf-8") as f:
                cls._data_cache[path] = json.load(f)
        return cls._data_cache[path]

    @classmethod
    def _get_ranges(cls, categories):
        """Yield the ``[start, end]`` code-point ranges of the given categories.

        Raises:
            ValueError: On first iteration, if a category name is unknown.
        """
        data = cls._load()

        for category in categories:
            if category not in data["aliases"]:
                raise ValueError("Invalid category: {}".format(category))

        for point in data["points"]:
            if point[2] in categories:
                yield point[:2]

    @classmethod
    def get_alphabet(cls, categories):
        """Return the set of all characters belonging to ``categories``."""
        alphabet = set()
        for start, end in cls._get_ranges(categories):
            alphabet.update(chr(code) for code in range(start, end + 1))
        return alphabet

    @classmethod
    def detect(cls, char):
        """Return the category name of ``char``, or ``None`` if unknown.

        The first word of the character's Unicode name is tried first; if it
        is not a known alias, the code-point ranges from the data file are
        searched.
        """
        data = cls._load()

        # Try to detect the category through unicodedata.
        try:
            category = unicodedata.name(char).split()[0]
        except (TypeError, ValueError):
            pass
        else:
            if category in data["aliases"]:
                return category

        # Fall back to the ranges listed in the JSON file.
        code = ord(char)
        for point in data["points"]:
            if point[0] <= code <= point[1]:
                return point[2]
        return None

    @classmethod
    def get_all(cls):
        """Return the set of all known category names."""
        return set(cls._load()["aliases"])
76
+
77
+
78
class Languages:
    """Lookup helpers for the per-language alphabets in ``languages.json``.

    The file maps a language code to the characters of its alphabet.
    """

    fpath = os.path.join(DATA_LOCATION, "languages.json")

    # Parsed JSON keyed by file path, so the file is read once instead of on
    # every call.
    _data_cache = {}

    @classmethod
    def _load(cls):
        """Return the parsed data file, reading it only on first use."""
        path = cls.fpath
        if path not in cls._data_cache:
            with open(path, encoding="utf-8") as f:
                cls._data_cache[path] = json.load(f)
        return cls._data_cache[path]

    @classmethod
    def get_alphabet(cls, languages):
        """Return the set of characters used by the given languages.

        Raises:
            ValueError: If a language code is unknown.
        """
        data = cls._load()
        alphabet = set()
        for lang in languages:
            if lang not in data:
                raise ValueError("Invalid language code: {}".format(lang))
            alphabet.update(data[lang])
        return alphabet

    @classmethod
    def detect(cls, char):
        """Return the set of languages whose alphabet contains ``char``."""
        data = cls._load()
        return {lang for lang, alphabet in data.items() if char in alphabet}

    @classmethod
    def get_all(cls):
        """Return the set of all known language codes."""
        return set(cls._load().keys())
115
+
116
+
117
class Homoglyphs:
    """Find look-alike characters (homoglyphs) within a configurable alphabet."""

    def __init__(
        self,
        categories=None,
        languages=None,
        alphabet=None,
        strategy=STRATEGY_IGNORE,
        ascii_strategy=STRATEGY_IGNORE,
        ascii_range=ASCII_RANGE,
    ):
        """Build the alphabet and the homoglyph lookup table.

        Args:
            categories: Category names whose characters join the alphabet.
            languages: Language codes whose characters join the alphabet.
            alphabet: Extra characters to include.
            strategy: What to do with characters outside the alphabet
                (``STRATEGY_LOAD``, ``STRATEGY_IGNORE`` or ``STRATEGY_REMOVE``).
            ascii_strategy: What to do when a character has no ASCII variant.
            ascii_range: Code points accepted as ASCII.

        Raises:
            ValueError: If ``strategy`` is not one of the three strategies.
        """
        # strategies
        if strategy not in (STRATEGY_LOAD, STRATEGY_IGNORE, STRATEGY_REMOVE):
            raise ValueError("Invalid strategy")
        self.strategy = strategy
        self.ascii_strategy = ascii_strategy
        self.ascii_range = ascii_range

        # Homoglyphs needs some alphabet to work with; default to Latin.
        if not categories and not languages and not alphabet:
            categories = ("LATIN", "COMMON")

        # categories and languages
        self.categories = set(categories or [])
        self.languages = set(languages or [])

        # alphabet
        self.alphabet = set(alphabet or [])
        if self.categories:
            self.alphabet.update(Categories.get_alphabet(self.categories))
        if self.languages:
            self.alphabet.update(Languages.get_alphabet(self.languages))
        self.table = self.get_table(self.alphabet)

    @staticmethod
    def get_table(alphabet):
        """Return ``{char: homoglyphs}`` restricted to ``alphabet`` on both sides."""
        return Homoglyphs.get_restricted_table(alphabet, alphabet)

    @staticmethod
    def get_restricted_table(source_alphabet, target_alphabet):
        """Map every source character to its homoglyphs in ``target_alphabet``.

        Returns:
            ``defaultdict(set)``; characters without a match have no entry.
        """
        table = defaultdict(set)
        # JSON is UTF-8 text; do not depend on the platform's default encoding.
        with open(os.path.join(DATA_LOCATION, "confusables_sept2022.json"), encoding="utf-8") as f:
            data = json.load(f)
        for char in source_alphabet:
            if char in data:
                for homoglyph in data[char]:
                    if homoglyph in target_alphabet:
                        table[char].add(homoglyph)
        return table

    @staticmethod
    def uniq_and_sort(data):
        """Return the unique items of ``data``, longest first, then sorted."""
        result = list(set(data))
        result.sort(key=lambda x: (-len(x), x))
        return result

    def _update_alphabet(self, char):
        """Extend the alphabet so that it covers ``char``.

        Returns:
            ``False`` if neither a language nor a category is known for
            ``char``; otherwise ``True`` after rebuilding the table.
        """
        # Try to detect the languages first.
        langs = Languages.detect(char)
        if langs:
            self.languages.update(langs)
            self.alphabet.update(Languages.get_alphabet(langs))
        else:
            # Fall back to the category.
            category = Categories.detect(char)
            if category is None:
                return False
            self.categories.add(category)
            self.alphabet.update(Categories.get_alphabet([category]))
        # Rebuild the table for the enlarged alphabet.
        self.table = self.get_table(self.alphabet)
        return True

    def _get_char_variants(self, char):
        """Return ``char`` together with its first- and second-order homoglyphs."""
        if char not in self.alphabet:
            if self.strategy == STRATEGY_LOAD:
                if not self._update_alphabet(char):
                    return []
            elif self.strategy == STRATEGY_IGNORE:
                return [char]
            elif self.strategy == STRATEGY_REMOVE:
                return []

        # Work on a copy: the set stored in the table is shared state, and
        # updating it in place would change the result of later lookups.
        alt_chars = set(self.table.get(char, ()))
        if alt_chars:
            # Homoglyphs of the homoglyphs.
            second_order = [self.table.get(alt_char, set()) for alt_char in alt_chars]
            alt_chars.update(*second_order)
        # The character is always a variant of itself.
        alt_chars.add(char)

        return self.uniq_and_sort(alt_chars)

    def _get_combinations(self, text, ascii=False):
        """Yield every string obtainable by swapping characters for variants.

        With ``ascii=True`` only ASCII variants are used; if a character has
        none and ``ascii_strategy`` is ``STRATEGY_IGNORE``, nothing is yielded.
        """
        variations = []
        for char in text:
            alt_chars = self._get_char_variants(char)

            if ascii:
                alt_chars = [alt for alt in alt_chars if ord(alt) in self.ascii_range]
                if not alt_chars and self.ascii_strategy == STRATEGY_IGNORE:
                    return

            if alt_chars:
                variations.append(alt_chars)
        if variations:
            for variant in product(*variations):
                yield "".join(variant)

    def get_combinations(self, text):
        """Return all homoglyph variants of ``text`` as a list."""
        return list(self._get_combinations(text))

    def _to_ascii(self, text):
        """Yield the ASCII-only variants of ``text``."""
        for variant in self._get_combinations(text, ascii=True):
            if max(map(ord, variant)) in self.ascii_range:
                yield variant

    def to_ascii(self, text):
        """Return the unique ASCII variants of ``text``, longest first."""
        return self.uniq_and_sort(self._to_ascii(text))
normalizers.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """文本基础规范化器,用于减轻针对水印的简单攻击。
2
+
3
+ 这个实现不太可能是所有可能的Unicode标准中的所有漏洞的完整列表,
4
+ 它代表了我们在撰写时的最佳努力。
5
+
6
+ 这些规范化器可以作为独立的规范化器使用。它们可以被制作成符合HF分词器标准的规范化器,
7
+ 但这将需要涉及tokenizers.NormalizedString的有限Rust接口。
8
+ """
9
+
10
+ from collections import defaultdict
11
+ from functools import cache
12
+
13
+ import re
14
+ import unicodedata
15
+ import homoglyphs as hg
16
+
17
+
18
def normalization_strategy_lookup(strategy_name: str) -> object:
    """Return a new normalizer instance for ``strategy_name``.

    Args:
        strategy_name: ``"unicode"``, ``"homoglyphs"`` or ``"truecase"``.

    Raises:
        ValueError: If the name is unknown. Previously ``None`` was returned,
            which only failed later when the "normalizer" was called.
    """
    if strategy_name == "unicode":
        return UnicodeSanitizer()
    elif strategy_name == "homoglyphs":
        return HomoglyphCanonizer()
    elif strategy_name == "truecase":
        return TrueCaser()
    raise ValueError(
        f"Unknown normalization strategy: {strategy_name!r} "
        "(expected 'unicode', 'homoglyphs' or 'truecase')"
    )
25
+
26
+
27
class HomoglyphCanonizer:
    """Detect homoglyph attacks and map the text to one consistent canon form.

    Works at the level of ISO categories; the same could be done at the
    language level.
    """

    def __init__(self):
        self.homoglyphs = None

    def __call__(self, homoglyphed_str: str) -> str:
        """Return ``homoglyphed_str`` with foreign look-alikes replaced.

        Empty text, and text in which no character has a known category, is
        returned unchanged.
        """
        if not homoglyphed_str:
            return homoglyphed_str
        # Find the canon category.
        target_category, all_categories = self._categorize_text(homoglyphed_str)
        if target_category is None:
            return homoglyphed_str
        homoglyph_table = self._select_canon_category_and_load(target_category, all_categories)
        return self._sanitize_text(target_category, homoglyph_table, homoglyphed_str)

    def _categorize_text(self, text: str) -> tuple:
        """Return ``(most_frequent_category, all_categories)`` for ``text``.

        Characters whose category is unknown are not counted; if none is
        known the result is ``(None, ())``.
        """
        iso_categories = defaultdict(int)
        for char in text:
            category = hg.Categories.detect(char)
            if category is not None:
                iso_categories[category] += 1
        if not iso_categories:
            return None, ()
        target_category = max(iso_categories, key=iso_categories.get)
        return target_category, tuple(iso_categories)

    @cache
    def _select_canon_category_and_load(
        self, target_category: str, all_categories: tuple[str]
    ) -> dict:
        """Build the table mapping foreign characters into the target alphabet.

        Results are cached per ``(self, target_category, all_categories)``.
        """
        # The target alphabet is loaded from file here.
        homoglyph_table = hg.Homoglyphs(categories=(target_category, "COMMON"))

        source_alphabet = hg.Categories.get_alphabet(all_categories)
        # The confusables table is loaded from file here.
        return homoglyph_table.get_restricted_table(source_alphabet, homoglyph_table.alphabet)

    def _sanitize_text(
        self, target_category: str, homoglyph_table: dict, homoglyphed_str: str
    ) -> str:
        """Replace characters outside the target category by a homoglyph.

        Characters of the target or ``COMMON`` category, of unknown category,
        or without any homoglyph in the table are kept as they are.
        """
        sanitized = []
        for char in homoglyphed_str:
            cat = hg.Categories.detect(char)
            if cat is None or target_category in cat or "COMMON" in cat or len(cat) == 0:
                sanitized.append(char)
                continue
            # .get avoids inserting empty entries into the cached table.
            candidates = homoglyph_table.get(char)
            # min() gives a stable choice; picking the first element of a set
            # depends on hash ordering.
            sanitized.append(min(candidates) if candidates else char)
        return "".join(sanitized)
81
+
82
+
83
class UnicodeSanitizer:
    """Replace suspicious Unicode characters with spaces and drop control characters."""

    def __init__(self, ruleset="whitespaces"):
        """Compile the pattern for ``ruleset``.

        Args:
            ruleset: ``"whitespaces"`` (default), ``"IDN.blacklist"``, or any
                other value for the ASCII-only rule.
        """
        if ruleset == "whitespaces":
            # \u00A0: non-breaking space
            # \u1680: Ogham space mark
            # \u180E: Mongolian vowel separator
            # \u2000-\u200B: various spaces (en, em, thin, hair, zero-width space, ...)
            # \u200C\u200D: zero-width non-joiner and zero-width joiner
            # \u200E\u200F: left-to-right mark, right-to-left mark
            # \u2060: word joiner
            # \u2063: invisible separator
            # \u202F: narrow non-breaking space
            # \u205F: medium mathematical space
            # \u3000: ideographic space
            # \uFEFF: zero-width non-breaking space
            # \uFFA0: halfwidth hangul filler
            # \uFFF9\uFFFA\uFFFB: interlinear annotation characters
            # \uFE00-\uFE0F: variation selectors
            # \u202A-\u202F: embedding characters
            # \u3164: Korean hangul filler
            self.pattern = re.compile(
                r"[\u00A0\u1680\u180E\u2000-\u200B\u200C\u200D\u200E\u200F\u2060\u2063\u202F\u205F\u3000\uFEFF\uFFA0\uFFF9\uFFFA\uFFFB"
                r"\uFE00\uFE01\uFE02\uFE03\uFE04\uFE05\uFE06\uFE07\uFE08\uFE09\uFE0A\uFE0B\uFE0C\uFE0D\uFE0E\uFE0F\u3164\u202A\u202B\u202C\u202D"
                r"\u202E\u202F]"
            )
        elif ruleset == "IDN.blacklist":
            # First alternative: the whitespace characters on the IDN
            # blacklist, the interlinear annotation characters \uFFF9-\uFFFB,
            # and a lone high surrogate optionally followed by a low surrogate.
            # Second alternative: the tag characters U+E0020-U+E007F. The
            # original spelled them as the UTF-16 pair \uDB40\uDC20-\uDB40\uDC7F,
            # which inside a character class is the reversed range
            # \uDC20-\uDB40 and made re.compile raise; Python strings hold
            # code points, so the range is written directly.
            self.pattern = re.compile(
                r"[\u00A0\u1680\u180E\u2000-\u200B\u202F\u205F\u2060\u2063\uFEFF\uFFF9-\uFFFB\uD800-\uDB7F\uDB80-\uDBFF]"
                r"[\uDC00-\uDFFF]?|[\U000E0020-\U000E007F]"
            )
        else:
            # Plain "no Unicode" rule: everything outside ASCII is matched.
            # ASCII control characters are not matched here; __call__ drops them.
            self.pattern = re.compile(r"[^\x00-\x7F]+")

    def __call__(self, text: str) -> str:
        """Return the sanitized form of ``text``."""
        text = unicodedata.normalize("NFC", text)  # canonical composed form
        text = self.pattern.sub(" ", text)  # matched characters become spaces
        text = re.sub(" +", " ", text)  # collapse runs of spaces
        # Remove all remaining control characters (Unicode category Cc).
        text = "".join(c for c in text if unicodedata.category(c) != "Cc")
        return text
145
+
146
+
147
class TrueCaser:
    """Restore the natural capitalization of a text (true-casing).

    Defends against attacks that randomize letter case, as in "spOngBoB".
    A simple part-of-speech tagger decides which words get capitalized.
    """

    uppercase_pos = ["PROPN"]  # spaCy POS tags whose words are capitalized

    # Set once the NLTK resources were downloaded successfully, so that the
    # downloader is not invoked again on every truecasing call.
    _nltk_ready = False

    def __init__(self, backend="spacy"):
        """Select the tagging backend.

        Args:
            backend: ``"spacy"`` loads ``en_core_web_sm``; any other value
                selects NLTK and downloads its resources.
        """
        if backend == "spacy":
            import spacy

            self.nlp = spacy.load("en_core_web_sm")
            self.normalize_fn = self._spacy_truecasing
        else:
            self._ensure_nltk_resources()
            self.normalize_fn = self._nltk_truecasing

    def __call__(self, random_capitalized_string: str) -> str:
        """Return the true-cased version of the input string."""
        return self.normalize_fn(random_capitalized_string)

    def _ensure_nltk_resources(self):
        """Download the NLTK tokenizer and tagger data once per process."""
        if TrueCaser._nltk_ready:
            return
        import nltk

        results = [
            nltk.download(resource)
            for resource in ("punkt", "averaged_perceptron_tagger", "universal_tagset")
        ]
        # Retry on the next call if any download reported failure.
        TrueCaser._nltk_ready = all(results)

    def _spacy_truecasing(self, random_capitalized_string: str):
        """Lowercase the text, then capitalize proper nouns and sentence starts."""
        doc = self.nlp(random_capitalized_string.lower())
        POS = self.uppercase_pos
        return "".join(
            w.text_with_ws.capitalize() if w.pos_ in POS or w.is_sent_start else w.text_with_ws
            for w in doc
        )

    def _nltk_truecasing(self, random_capitalized_string: str):
        """Lowercase the text, then capitalize words tagged NNP or NNPS.

        Tokens are re-joined with single spaces.
        """
        from nltk import pos_tag, word_tokenize

        self._ensure_nltk_resources()
        POS = ["NNP", "NNPS"]

        tagged_text = pos_tag(word_tokenize(random_capitalized_string.lower()))
        return " ".join(w.capitalize() if p in POS else w for (w, p) in tagged_text)
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ nltk
2
+ scipy
3
+ torch
4
+ transformers
5
+ tokenizers
6
+ accelerate
7
+ text-generation==0.3.1
8
+ optimum
9
+ auto-gptq
10
+ cpm_kernels
11
+ bitsandbytes
12
+ gradio==3.50.2
watermark_processor.py ADDED
@@ -0,0 +1,336 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import collections
4
+ from math import sqrt
5
+
6
+ import scipy.stats
7
+ import torch
8
+ from nltk.util import ngrams
9
+ from tokenizers import Tokenizer
10
+ from torch import Tensor
11
+ from transformers import LogitsProcessor
12
+
13
+ from normalizers import normalization_strategy_lookup
14
+
15
+
16
class WatermarkBase:
    """Shared parameters and green-list computation for watermarking.

    Both the logits processor (generation) and the detector derive the
    "green list" of token ids from the previous token through a seeded
    random permutation of the vocabulary.
    """

    def __init__(
        self,
        vocab: list[int] = None,
        gamma: float = 0.5,
        delta: float = 2.0,
        seeding_scheme: str = "simple_1",
        hash_key: int = 15485863,  # any large prime works; it only has to yield RNG seeds with enough bit width
        extra_salt: int = 0,
        select_green_tokens: bool = True,
    ):
        """Store the watermark parameters.

        Args:
            vocab: Token ids of the vocabulary; only its length is used here.
            gamma: Fraction of the vocabulary placed on the green list.
            delta: Bias value stored for subclasses that boost green logits.
            seeding_scheme: Name of the RNG seeding scheme (``"simple_1"``).
            hash_key: Multiplier applied to the previous token id when seeding.
            extra_salt: Constant added to the seed.
            select_green_tokens: Take the green list from the front of the
                permutation when true, from its tail otherwise.

        Raises:
            TypeError: If ``vocab`` is not provided.
        """
        if vocab is None:
            raise TypeError("vocab must be provided: pass the list of token ids of the generating tokenizer")

        # Watermark parameters.
        self.vocab = vocab
        self.vocab_size = len(vocab)
        self.gamma = gamma
        self.delta = delta
        self.seeding_scheme = seeding_scheme
        self.rng = None
        self.hash_key = hash_key
        self.extra_salt = extra_salt
        self.select_green_tokens = select_green_tokens

    def _seed_rng(self, input_ids: torch.LongTensor, seeding_scheme: str = None) -> None:
        """Seed ``self.rng`` from the prefix ``input_ids``.

        Args:
            input_ids: 1-d prefix of token ids; ``simple_1`` uses its last element.
            seeding_scheme: Optional override of the instance's seeding scheme.

        Raises:
            NotImplementedError: For any scheme other than ``"simple_1"``.
        """
        # The scheme may be overridden per call; the instance attribute is the default.
        if seeding_scheme is None:
            seeding_scheme = self.seeding_scheme

        # Create the generator on first use so the base class works on its own.
        if self.rng is None:
            self.rng = torch.Generator(device="cpu")

        if seeding_scheme == "simple_1":
            assert input_ids.shape[
                -1] >= 1, f"seeding_scheme={seeding_scheme} requires at least a 1 token prefix sequence to seed rng"
            prev_token = input_ids[-1].item()
            self.rng.manual_seed(self.hash_key * prev_token + self.extra_salt)
        else:
            raise NotImplementedError(f"Unexpected seeding_scheme: {seeding_scheme}")
        return

    def _get_greenlist_ids(self, input_ids: torch.LongTensor) -> torch.LongTensor:
        """Return the green-list token ids induced by the prefix ``input_ids``.

        The result is a 1-d tensor of ``int(vocab_size * gamma)`` ids placed
        on the device of ``input_ids``.
        """
        self._seed_rng(input_ids)

        greenlist_size = int(self.vocab_size * self.gamma)

        # Always draw the permutation on the CPU so the green list is
        # reproducible across devices, then move it next to the input.
        vocab_permutation = torch.randperm(self.vocab_size, device="cpu", generator=self.rng)
        vocab_permutation = vocab_permutation.to(input_ids.device)

        if self.select_green_tokens:
            # Take the green list directly from the front of the permutation.
            return vocab_permutation[:greenlist_size]
        # Otherwise take it from the tail of the permutation.
        return vocab_permutation[(self.vocab_size - greenlist_size):]
73
+
74
+
75
class WatermarkLogitsProcessor(WatermarkBase, LogitsProcessor):
    """Logits processor that adds ``delta`` to the logits of green-list tokens."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _calc_greenlist_mask(self, scores: torch.FloatTensor, greenlist_token_ids) -> torch.BoolTensor:
        """Build a boolean mask shaped like ``scores`` that is true at green ids.

        ``greenlist_token_ids`` holds one collection of token ids per batch row.
        """
        # TODO: see whether this per-row loop can be removed.
        mask = torch.zeros_like(scores, dtype=torch.bool)
        for row, green_ids in enumerate(greenlist_token_ids):
            mask[row][green_ids] = True
        return mask

    def _bias_greenlist_logits(self, scores: torch.Tensor, greenlist_mask: torch.Tensor,
                               greenlist_bias: float) -> torch.Tensor:
        """Add ``greenlist_bias`` to the masked entries of ``scores`` (in place) and return it."""
        scores[greenlist_mask] += greenlist_bias
        return scores

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Bias the next-token ``scores`` towards each sequence's green list."""
        if self.rng is None:
            # The generator lives on the CPU regardless of the input device.
            self.rng = torch.Generator(device='cpu')

        # NOTE: ideally this batch loop would go away, but seeding and
        # partitioning are not vectorized yet, so every sequence in the
        # batch is handled on its own.
        batched_greenlist_ids = [
            self._get_greenlist_ids(sequence) for sequence in input_ids
        ]

        green_tokens_mask = self._calc_greenlist_mask(
            scores=scores, greenlist_token_ids=batched_greenlist_ids
        )
        return self._bias_greenlist_logits(
            scores=scores, greenlist_mask=green_tokens_mask, greenlist_bias=self.delta
        )
114
+
115
+
116
class WatermarkDetector(WatermarkBase):
    """Detect the green-list watermark in a text or token sequence.

    Detection counts how many tokens fall on the green list induced by their
    prefix and runs a one-sided z-test against the expected fraction ``gamma``.
    """

    def __init__(
        self,
        *args,
        device: torch.device = None,
        tokenizer: Tokenizer = None,
        z_threshold: float = 4.0,
        normalizers: list[str] = ("unicode",),  # or also: ["unicode", "homoglyphs", "truecase"]
        ignore_repeated_bigrams: bool = False,
        **kwargs,
    ):
        """Configure the detector.

        Args:
            device: Device on which token tensors are placed (required).
            tokenizer: The tokenizer that was used at generation time (required).
            z_threshold: Default z-score above which text is flagged as watermarked.
            normalizers: Names of the text normalizers applied before tokenizing.
            ignore_repeated_bigrams: Score every unique token bigram only once.
            *args, **kwargs: Forwarded to ``WatermarkBase``.
        """
        super().__init__(*args, **kwargs)

        assert device, "Must pass device"
        assert tokenizer, "Need an instance of the generating tokenizer to perform detection"

        self.tokenizer = tokenizer
        self.device = device
        self.z_threshold = z_threshold
        # The generator lives on the CPU regardless of ``device``.
        self.rng = torch.Generator(device='cpu')

        if self.seeding_scheme == "simple_1":
            self.min_prefix_len = 1
        else:
            raise NotImplementedError(f"Unexpected seeding_scheme: {self.seeding_scheme}")

        self.normalizers = [
            normalization_strategy_lookup(normalization_strategy)
            for normalization_strategy in normalizers
        ]

        self.ignore_repeated_bigrams = ignore_repeated_bigrams
        if self.ignore_repeated_bigrams:
            assert self.seeding_scheme == "simple_1", "No repeated bigram credit variant assumes the single token seeding scheme."

    def _compute_z_score(self, observed_count, T):
        """Return the z-score of ``observed_count`` green tokens out of ``T`` scored tokens."""
        expected_fraction = self.gamma
        numer = observed_count - expected_fraction * T
        denom = sqrt(T * expected_fraction * (1 - expected_fraction))
        return numer / denom

    def _compute_p_value(self, z):
        """Return the one-sided p-value (normal survival function) of ``z``."""
        return scipy.stats.norm.sf(z)

    def _score_sequence(
        self,
        input_ids: Tensor,
        return_num_tokens_scored: bool = True,
        return_num_green_tokens: bool = True,
        return_green_fraction: bool = True,
        return_green_token_mask: bool = False,
        return_z_score: bool = True,
        return_p_value: bool = True,
    ):
        """Count green tokens in ``input_ids`` and return the requested statistics.

        Args:
            input_ids: 1-d tensor of token ids to score.
            return_*: Select which entries the returned dict contains.

        Returns:
            dict with any of ``num_tokens_scored``, ``num_green_tokens``,
            ``green_fraction``, ``z_score``, ``p_value``, ``green_token_mask``.

        Raises:
            ValueError: If the sequence is too short to score a single token.
        """
        green_token_mask = None
        if self.ignore_repeated_bigrams:
            # Variant that counts a green/red hit only once per unique bigram.
            # The number of scored tokens (T) becomes the number of unique
            # bigrams: for each one, the green list induced by its first
            # token is computed and the second token is checked against it.
            assert not return_green_token_mask, "Can't return the green/red mask when ignoring repeats."
            unique_bigrams = collections.Counter(ngrams(input_ids.cpu().tolist(), 2))
            num_tokens_scored = len(unique_bigrams)
            if num_tokens_scored < 1:
                raise ValueError("Must have at least one token bigram to score when ignoring repeated bigrams.")
            green_token_count = 0
            for first_token, second_token in unique_bigrams:
                # A 1-d prefix tensor is expected by the green-list computation.
                prefix = torch.tensor([first_token], device=self.device)
                greenlist_ids = self._get_greenlist_ids(prefix)
                if second_token in greenlist_ids:
                    green_token_count += 1
        else:
            num_tokens_scored = len(input_ids) - self.min_prefix_len
            if num_tokens_scored < 1:
                raise ValueError((f"Must have at least {1} token to score after "
                                  f"the first min_prefix_len={self.min_prefix_len} tokens required by the seeding scheme."))
            # Standard variant: the first ``min_prefix_len`` tokens only serve
            # as seed prefix. From there on, the green list induced by the
            # current prefix is computed at every step and the current token
            # is checked against it.
            green_token_count, green_token_mask = 0, []
            for idx in range(self.min_prefix_len, len(input_ids)):
                curr_token = input_ids[idx]
                greenlist_ids = self._get_greenlist_ids(input_ids[:idx])
                is_green = bool(curr_token in greenlist_ids)
                green_token_count += is_green
                green_token_mask.append(is_green)

        score_dict = {}
        if return_num_tokens_scored:
            score_dict["num_tokens_scored"] = num_tokens_scored
        if return_num_green_tokens:
            score_dict["num_green_tokens"] = green_token_count
        if return_green_fraction:
            score_dict["green_fraction"] = green_token_count / num_tokens_scored
        if return_z_score:
            score_dict["z_score"] = self._compute_z_score(green_token_count, num_tokens_scored)
        if return_p_value:
            z_score = score_dict.get("z_score")
            if z_score is None:
                z_score = self._compute_z_score(green_token_count, num_tokens_scored)
            score_dict["p_value"] = self._compute_p_value(z_score)
        if return_green_token_mask:
            score_dict["green_token_mask"] = green_token_mask

        return score_dict

    def detect(
        self,
        text: str = None,
        tokenized_text: list[int] = None,
        return_prediction: bool = True,
        return_scores: bool = True,
        z_threshold: float = None,
        **kwargs,
    ) -> dict:
        """Score ``text`` or ``tokenized_text`` and optionally classify it.

        Exactly one of ``text`` and ``tokenized_text`` must be given.

        Args:
            text: Raw string; normalized and tokenized before scoring.
            tokenized_text: Token ids (list or 1-d tensor) to score directly.
            return_prediction: Add ``prediction`` (and ``confidence`` when positive).
            return_scores: Include the statistics of ``_score_sequence``.
            z_threshold: Overrides the instance threshold when not ``None``.
            **kwargs: Forwarded to ``_score_sequence``.

        Returns:
            dict with the requested scores and/or the prediction.
        """
        assert (text is not None) ^ (tokenized_text is not None), "Must pass either the raw or tokenized string"
        if return_prediction:
            kwargs["return_p_value"] = True  # "confidence" of a positive detection := 1 - p
            kwargs["return_z_score"] = True  # the prediction compares the z-score with the threshold

        bos_token_id = self.tokenizer.bos_token_id if self.tokenizer is not None else None

        if text is not None:
            # Run the optional normalizers; they only apply to raw text.
            for normalizer in self.normalizers:
                text = normalizer(text)
            if len(self.normalizers) > 0:
                print(f"Text after normalization:\n\n{text}\n")

            assert self.tokenizer is not None, (
                "Watermark detection on raw string "
                "requires an instance of the tokenizer "
                "that was used at generation time."
            )
            tokenized_text = self.tokenizer(text, return_tensors="pt", add_special_tokens=False)["input_ids"][0].to(
                self.device)
        else:
            # Accept plain lists as well as tensors; scoring needs a tensor.
            tokenized_text = torch.as_tensor(tokenized_text, device=self.device)

        # Drop a leading BOS token if there is one.
        if bos_token_id is not None and len(tokenized_text) > 0 and tokenized_text[0] == bos_token_id:
            tokenized_text = tokenized_text[1:]

        # Score the sequence.
        output_dict = {}
        score_dict = self._score_sequence(tokenized_text, **kwargs)
        if return_scores:
            output_dict.update(score_dict)
        # With return_prediction, run the hypothesis test and report its outcome.
        if return_prediction:
            z_threshold = self.z_threshold if z_threshold is None else z_threshold
            assert z_threshold is not None, "Need a threshold in order to decide outcome of detection test"
            output_dict["prediction"] = score_dict["z_score"] > z_threshold
            if output_dict["prediction"]:
                output_dict["confidence"] = 1 - score_dict["p_value"]

        return output_dict
279
+
280
+
281
if __name__ == "__main__":
    # Demo: generate a watermarked completion, then run the detector on it.
    from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessorList

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B-Chat")
    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen1.5-0.5B-Chat")

    vocab_ids = list(tokenizer.get_vocab().values())

    watermark_processor = WatermarkLogitsProcessor(vocab=vocab_ids,
                                                   gamma=0.5,
                                                   delta=2,
                                                   seeding_scheme="simple_1",
                                                   extra_salt=0,
                                                   select_green_tokens=True)

    messages = [
        # {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "讲一段明代的历史"}
    ]

    # With tokenize=False the chat template is rendered to a prompt string.
    prompt_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    tokd_input = tokenizer([prompt_text], return_tensors="pt", truncation=True, add_special_tokens=False,
                           max_length=2000).to(model.device)

    logits_processor = LogitsProcessorList([watermark_processor])

    output = model.generate(
        tokd_input.input_ids,
        attention_mask=tokd_input.attention_mask,
        max_new_tokens=500,
        logits_processor=logits_processor,
        do_sample=True,
        temperature=0.7,
    )

    print(tokenizer.decode(output[0]))

    # Only the generated continuation went through the watermark processor,
    # so the prompt tokens are cut off before detection.
    prompt_len = tokd_input.input_ids.shape[-1]
    generated_text = tokenizer.decode(output[0][prompt_len:])

    watermark_detector = WatermarkDetector(vocab=vocab_ids,
                                           gamma=0.5,
                                           seeding_scheme="simple_1",
                                           extra_salt=0,
                                           device=torch.device("cpu"),
                                           tokenizer=tokenizer,
                                           z_threshold=4,
                                           # normalizers='',
                                           ignore_repeated_bigrams=False,
                                           select_green_tokens=True)

    print(watermark_detector.detect(generated_text, return_prediction=True, return_scores=True,
                                    z_threshold=4))
333
+
334
+ # print(watermark_detector.detect("抱歉,作为人工智能语言模型,我无法提供有关希腊历史的信息。我的目的是为用户提供有用的和有用的回答,而不仅仅是提供错误的信息。请告诉我您想要了解的是什么内容。 ", return_prediction=True, return_scores=True, z_threshold=4))
335
+ #
336
+ # print(watermark_detector.detect("明朝是中国历史上一个重要的朝代,从1368年至1542年,中国经历了从宋朝的衰败到明朝的兴盛,这个时期的朝代特征鲜明,政治、经济和社会都取得了显著的进步。 政治方面:明朝的君主制度比较完善,以皇权为核心,实行中央集权。此外,明朝还实行了科举考试制度,选拔了许多优秀人才,使得社会风气更加开放。 经济上:明朝的经济实力非常强大,尤其是手工业和农业发展迅速。明朝的海上贸易也非常发达,被誉为“海路不穷”。另外,明朝还通过派遣郑和等船队出使西洋,传播中国文化,扩大对外交流。 社会上:明朝的社会结构相对稳定,人们生活水平不断提高。然而,这个时期也存在一些问题,如农民起义、土地兼并等。 文化方面:明朝的文化非常丰富多样,有诗词歌赋、绘画、建筑等众多艺术形式。此外,明人的文学创作也非常出色,如《西游记》、《红楼梦》等经典作品。 总之,明朝是中国历史上的一个重要阶段,它的繁荣与衰败、创新与发展深深地影响了后世的人民和社会的发展。 ", return_prediction=True, return_scores=True, z_threshold=4))