Spaces:
Sleeping
Sleeping
jianuo
commited on
Commit
·
c99a3a6
1
Parent(s):
7781fb8
first upload
Browse files- .gitattributes +10 -11
- .gitignore +129 -0
- MANIFEST.in +7 -0
- __init__.py +0 -0
- alternative_prf_schemes.py +167 -0
- app.py +39 -0
- demo_watermark.py +1085 -0
- extended_watermark_processor.py +592 -0
- homoglyphs.py +249 -0
- normalizers.py +199 -0
- requirements.txt +12 -0
- watermark_processor.py +336 -0
.gitattributes
CHANGED
|
@@ -1,35 +1,34 @@
|
|
| 1 |
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
*.bin filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 4 |
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bin.* filter=lfs diff=lfs merge=lfs -text
|
| 5 |
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 6 |
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 11 |
*.model filter=lfs diff=lfs merge=lfs -text
|
| 12 |
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
| 13 |
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 14 |
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 15 |
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 16 |
*.pb filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
| 17 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 18 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 19 |
*.rar filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 20 |
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 21 |
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 22 |
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 23 |
*.tgz filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 24 |
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 25 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.db* filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.ark* filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
**/*ckpt*data* filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
**/*ckpt*.meta filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
**/*ckpt*.index filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
pip-wheel-metadata/
|
| 24 |
+
share/python-wheels/
|
| 25 |
+
*.egg-info/
|
| 26 |
+
.installed.cfg
|
| 27 |
+
*.egg
|
| 28 |
+
MANIFEST
|
| 29 |
+
|
| 30 |
+
# PyInstaller
|
| 31 |
+
# Usually these files are written by a python script from a template
|
| 32 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 33 |
+
*.manifest
|
| 34 |
+
*.spec
|
| 35 |
+
|
| 36 |
+
# Installer logs
|
| 37 |
+
pip-log.txt
|
| 38 |
+
pip-delete-this-directory.txt
|
| 39 |
+
|
| 40 |
+
# Unit test / coverage reports
|
| 41 |
+
htmlcov/
|
| 42 |
+
.tox/
|
| 43 |
+
.nox/
|
| 44 |
+
.coverage
|
| 45 |
+
.coverage.*
|
| 46 |
+
.cache
|
| 47 |
+
nosetests.xml
|
| 48 |
+
coverage.xml
|
| 49 |
+
*.cover
|
| 50 |
+
*.py,cover
|
| 51 |
+
.hypothesis/
|
| 52 |
+
.pytest_cache/
|
| 53 |
+
|
| 54 |
+
# Translations
|
| 55 |
+
*.mo
|
| 56 |
+
*.pot
|
| 57 |
+
|
| 58 |
+
# Django stuff:
|
| 59 |
+
*.log
|
| 60 |
+
local_settings.py
|
| 61 |
+
db.sqlite3
|
| 62 |
+
db.sqlite3-journal
|
| 63 |
+
|
| 64 |
+
# Flask stuff:
|
| 65 |
+
instance/
|
| 66 |
+
.webassets-cache
|
| 67 |
+
|
| 68 |
+
# Scrapy stuff:
|
| 69 |
+
.scrapy
|
| 70 |
+
|
| 71 |
+
# Sphinx documentation
|
| 72 |
+
docs/_build/
|
| 73 |
+
|
| 74 |
+
# PyBuilder
|
| 75 |
+
target/
|
| 76 |
+
|
| 77 |
+
# Jupyter Notebook
|
| 78 |
+
.ipynb_checkpoints
|
| 79 |
+
|
| 80 |
+
# IPython
|
| 81 |
+
profile_default/
|
| 82 |
+
ipython_config.py
|
| 83 |
+
|
| 84 |
+
# pyenv
|
| 85 |
+
.python-version
|
| 86 |
+
|
| 87 |
+
# pipenv
|
| 88 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 89 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 90 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 91 |
+
# install all needed dependencies.
|
| 92 |
+
#Pipfile.lock
|
| 93 |
+
|
| 94 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
| 95 |
+
__pypackages__/
|
| 96 |
+
|
| 97 |
+
# Celery stuff
|
| 98 |
+
celerybeat-schedule
|
| 99 |
+
celerybeat.pid
|
| 100 |
+
|
| 101 |
+
# SageMath parsed files
|
| 102 |
+
*.sage.py
|
| 103 |
+
|
| 104 |
+
# Environments
|
| 105 |
+
.env
|
| 106 |
+
.venv
|
| 107 |
+
env/
|
| 108 |
+
venv/
|
| 109 |
+
ENV/
|
| 110 |
+
env.bak/
|
| 111 |
+
venv.bak/
|
| 112 |
+
|
| 113 |
+
# Spyder project settings
|
| 114 |
+
.spyderproject
|
| 115 |
+
.spyproject
|
| 116 |
+
|
| 117 |
+
# Rope project settings
|
| 118 |
+
.ropeproject
|
| 119 |
+
|
| 120 |
+
# mkdocs documentation
|
| 121 |
+
/site
|
| 122 |
+
|
| 123 |
+
# mypy
|
| 124 |
+
.mypy_cache/
|
| 125 |
+
.dmypy.json
|
| 126 |
+
dmypy.json
|
| 127 |
+
|
| 128 |
+
# Pyre type checker
|
| 129 |
+
.pyre/
|
MANIFEST.in
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# added manually
|
| 2 |
+
include *.py
|
| 3 |
+
include *.json
|
| 4 |
+
include *.md
|
| 5 |
+
|
| 6 |
+
global-exclude *.pyc
|
| 7 |
+
global-exclude __pycache__
|
__init__.py
ADDED
|
File without changes
|
alternative_prf_schemes.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""实现其他 PRF 函数(这些函数的不同之处仅在于如何从上下文中的令牌生成单个哈希值)。
|
| 2 |
+
|
| 3 |
+
可作为修改后的基类 WatermarkBase 挂接到现有的 WatermarkLogitsProcessor 中,请参见
|
| 4 |
+
extended_watermark_processor.py 中的实现。
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from itertools import combinations
|
| 11 |
+
from functools import cache
|
| 12 |
+
|
| 13 |
+
# 哈希方案的关键属性
|
| 14 |
+
props = {
|
| 15 |
+
"prf_type": str, # 基础 PRF 的字符串名称,将多个令牌 ID 映射到随机种子
|
| 16 |
+
"context_width": int, # 这是论文中的 h,每个 PRF 应考虑多少个先前的令牌
|
| 17 |
+
"self_salt": bool, # 根据鲁棒水印技术中的规则,是否使用令牌本身来生成种子,并可能拒绝其自身的列表
|
| 18 |
+
"hash_key": int, # 整数,大质数,用于将种子移动到上述所选 PRF 中的低熵位序列的远离位置
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def seeding_scheme_lookup(seeding_scheme: str):
|
| 23 |
+
if not isinstance(seeding_scheme, str):
|
| 24 |
+
raise ValueError("Seeding scheme should be a string summarizing the procedure.")
|
| 25 |
+
if seeding_scheme == "simple_1" or seeding_scheme == "lefthash":
|
| 26 |
+
# 默认的简单二元哈希 # 别名为 ff-additive_prf-1-False-15485863
|
| 27 |
+
prf_type = "additive_prf"
|
| 28 |
+
context_width = 1
|
| 29 |
+
self_salt = False
|
| 30 |
+
hash_key = 15485863
|
| 31 |
+
elif seeding_scheme == "algorithm-3" or seeding_scheme == "selfhash":
|
| 32 |
+
prf_type = "anchored_minhash_prf"
|
| 33 |
+
context_width = 4
|
| 34 |
+
self_salt = True
|
| 35 |
+
hash_key = 15485863
|
| 36 |
+
elif seeding_scheme == "minhash":
|
| 37 |
+
prf_type = "minhash_prf"
|
| 38 |
+
context_width = 4
|
| 39 |
+
self_salt = False
|
| 40 |
+
hash_key = 15485863
|
| 41 |
+
elif seeding_scheme == "skipgram":
|
| 42 |
+
prf_type = "skipgram_prf"
|
| 43 |
+
context_width = 5
|
| 44 |
+
self_salt = False
|
| 45 |
+
hash_key = 15485863
|
| 46 |
+
elif seeding_scheme.startswith("ff"): # 自由形式的种子方案 API - 仅用于实验目的
|
| 47 |
+
# 期望形式为 ff-additive_prf-4-True-hash 或 ff-additive_prf-5-True (哈希键是可选的)
|
| 48 |
+
split_scheme = seeding_scheme.split("-")
|
| 49 |
+
prf_type = str(split_scheme[1])
|
| 50 |
+
context_width = int(split_scheme[2])
|
| 51 |
+
self_salt = split_scheme[3] == "True"
|
| 52 |
+
if len(split_scheme) == 5:
|
| 53 |
+
hash_key = int(split_scheme[4])
|
| 54 |
+
else:
|
| 55 |
+
hash_key = 15485863
|
| 56 |
+
else:
|
| 57 |
+
raise ValueError(f"Invalid seeding scheme name {seeding_scheme} given. Try 'simple_1'?")
|
| 58 |
+
|
| 59 |
+
assert prf_type in prf_lookup.keys()
|
| 60 |
+
return prf_type, context_width, self_salt, hash_key
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def multiplicative_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
|
| 64 |
+
return salt_key * input_ids.prod().item()
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def additive_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
|
| 68 |
+
return salt_key * input_ids.sum().item()
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def minfunc_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
|
| 72 |
+
# 对于非随机输入 id(如文本),这不是一个好主意
|
| 73 |
+
return salt_key * input_ids.min().item()
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def simple_skip_prf(input_ids: torch.LongTensor, salt_key: int, k=2) -> int:
|
| 77 |
+
# k是一个跳跃的距离
|
| 78 |
+
return hashint(salt_key * input_ids[::k]).prod().item()
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def skipgram_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
|
| 82 |
+
# # 上下文内的最大距离跳字
|
| 83 |
+
return hashint(salt_key * input_ids[0]).item()
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def anchored_skipgram_prf(input_ids: torch.LongTensor, salt_key: int, anchor: int = -1) -> int:
|
| 87 |
+
# 上下文内的最大距离跳字
|
| 88 |
+
return (hashint(salt_key * input_ids[0]) * hashint(salt_key * input_ids[anchor])).item()
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def minhash_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
|
| 92 |
+
return hashint(salt_key * input_ids).min().item()
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def anchored_minhash_prf(input_ids: torch.LongTensor, salt_key: int, anchor: int = -1) -> int:
|
| 96 |
+
# 另一个关键是生成一个key
|
| 97 |
+
return (salt_key * hashint(input_ids) * hashint(input_ids[anchor])).min().item()
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def minskipgram_prf(input_ids: torch.LongTensor, salt_key: int, k: int = 2) -> int:
|
| 101 |
+
# 上下文中所有跳字组合的最小值,k=2 表示所有对
|
| 102 |
+
skipgrams = torch.as_tensor(list(combinations(hashint(salt_key * input_ids), 2)))
|
| 103 |
+
return skipgrams.prod(dim=1).min().item()
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def noncomm_prf(input_ids: torch.LongTensor, salt_key: int, k: int = 2) -> int:
|
| 107 |
+
key = torch.as_tensor(salt_key, dtype=torch.long)
|
| 108 |
+
for entry in input_ids:
|
| 109 |
+
key *= hashint(key * entry)
|
| 110 |
+
key %= 2**32
|
| 111 |
+
return key.item()
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def position_prf(input_ids: torch.LongTensor, salt_key: int, k: int = 2) -> int:
|
| 115 |
+
return (salt_key * input_ids * torch.arange(1, len(input_ids) + 1, device=input_ids.device)).sum().item()
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
prf_lookup = {
|
| 119 |
+
"multiplicative_prf": multiplicative_prf,
|
| 120 |
+
"additive_prf": additive_prf,
|
| 121 |
+
"minfunc_prf": minfunc_prf,
|
| 122 |
+
"simple_skip_prf": simple_skip_prf,
|
| 123 |
+
"skipgram_prf": skipgram_prf,
|
| 124 |
+
"anchored_skipgram_prf": anchored_skipgram_prf,
|
| 125 |
+
"minhash_prf": minhash_prf,
|
| 126 |
+
"anchored_minhash_prf": anchored_minhash_prf,
|
| 127 |
+
"minskipgram_prf": minskipgram_prf,
|
| 128 |
+
"noncomm_prf": noncomm_prf,
|
| 129 |
+
"position_prf": position_prf,
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
# 在启动时生成全局置换表一次
|
| 133 |
+
rng = torch.Generator(device=torch.device("cpu"))
|
| 134 |
+
rng.manual_seed(2971215073)
|
| 135 |
+
table_size = 1_000_003
|
| 136 |
+
fixed_table = torch.randperm(1_000_003, device=torch.device("cpu"), generator=rng) # 这个速度很快
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def hashint(integer_tensor: torch.LongTensor) -> torch.LongTensor:
|
| 140 |
+
|
| 141 |
+
return fixed_table[integer_tensor.cpu() % table_size] + 1 # 这里有一个小技巧,这个函数总是返回 CPU 的值
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def _hashint_avalanche_tensor(integer_tensor: torch.LongTensor):
|
| 145 |
+
|
| 146 |
+
i = integer_tensor.to(torch.int32).clone() # or torch.int16?
|
| 147 |
+
i -= i << 6
|
| 148 |
+
i ^= i >> 17
|
| 149 |
+
i -= i << 9
|
| 150 |
+
i ^= i << 4
|
| 151 |
+
i -= i << 3
|
| 152 |
+
i ^= i << 10
|
| 153 |
+
i ^= i >> 15
|
| 154 |
+
return i.to(torch.long)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
@cache
|
| 158 |
+
def _hashint_avalanche_int(integer: int):
|
| 159 |
+
i = integer % (2**32)
|
| 160 |
+
i -= i << 6
|
| 161 |
+
i ^= i >> 17
|
| 162 |
+
i -= i << 9
|
| 163 |
+
i ^= i << 4
|
| 164 |
+
i -= i << 3
|
| 165 |
+
i ^= i << 10
|
| 166 |
+
i ^= i >> 15
|
| 167 |
+
return i
|
app.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 安装好环境
|
| 2 |
+
# python app.py即可运行
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
os.environ['HF_ENDPOINT']='https://hf-mirror.com'
|
| 6 |
+
|
| 7 |
+
from argparse import Namespace
|
| 8 |
+
args = Namespace()
|
| 9 |
+
|
| 10 |
+
arg_dict = {
|
| 11 |
+
'run_gradio': True,
|
| 12 |
+
'demo_public': False,
|
| 13 |
+
'model_name_or_path': 'Qwen/Qwen2-0.5B-Instruct',
|
| 14 |
+
|
| 15 |
+
# 'model_name_or_path': 'Qwen/Qwen2-0.5B-Instruct-GGUF',
|
| 16 |
+
'gguf_file': './qwen2-0_5b-instruct-q8_0.gguf', # 只有用gguf模型会用到,即model_name_or_path里含有gguf字符串才会用到
|
| 17 |
+
'prompt_max_length': None,
|
| 18 |
+
'max_new_tokens': 500,
|
| 19 |
+
'generation_seed': 123,
|
| 20 |
+
'use_sampling': True,
|
| 21 |
+
'n_beams': 1,
|
| 22 |
+
'sampling_temp': 0.7,
|
| 23 |
+
'use_gpu': True,
|
| 24 |
+
'seeding_scheme': 'simple_1',
|
| 25 |
+
'gamma': 0.5,
|
| 26 |
+
'delta': 2.0,
|
| 27 |
+
'normalizers': '',
|
| 28 |
+
'ignore_repeated_bigrams': False,
|
| 29 |
+
'detection_z_threshold': 4.0,
|
| 30 |
+
'select_green_tokens': True,
|
| 31 |
+
'skip_model_load': False,
|
| 32 |
+
'seed_separately': True,
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
args.__dict__.update(arg_dict)
|
| 36 |
+
|
| 37 |
+
from demo_watermark import main
|
| 38 |
+
|
| 39 |
+
main(args)
|
demo_watermark.py
ADDED
|
@@ -0,0 +1,1085 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import time
|
| 4 |
+
from functools import partial
|
| 5 |
+
|
| 6 |
+
import gradio.exceptions
|
| 7 |
+
|
| 8 |
+
# 设置OpenMP线程数为8,优化CPU并行计算性能
|
| 9 |
+
os.environ["OMP_NUM_THREADS"] = "8"
|
| 10 |
+
|
| 11 |
+
import gradio as gr
|
| 12 |
+
import pandas as pd
|
| 13 |
+
import torch
|
| 14 |
+
from transformers import (AutoTokenizer,
|
| 15 |
+
AutoModelForCausalLM,
|
| 16 |
+
LogitsProcessorList)
|
| 17 |
+
|
| 18 |
+
from watermark_processor import WatermarkLogitsProcessor, WatermarkDetector
|
| 19 |
+
|
| 20 |
+
# FIXME 所有模型的正确长度
|
| 21 |
+
|
| 22 |
+
API_MODEL_MAP = {
|
| 23 |
+
# "Qwen/Qwen1.5-0.5B-Chat": {"max_length": 2000, "gamma": 0.5, "delta": 2.0},
|
| 24 |
+
# "THUDM/chatglm3-6b": {"max_length": 2048, "gamma": 0.5, "delta": 2.0},
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
default_trace_table = pd.DataFrame(columns=["编号", "水印内容"])
|
| 28 |
+
default_trace_table.loc[0] = (0, "默认用户")
|
| 29 |
+
default_trace_table.loc[1] = (1, "张三")
|
| 30 |
+
default_trace_table.loc[2] = (2, "李四")
|
| 31 |
+
|
| 32 |
+
watermark_salt = 0
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def str2bool(v):
|
| 36 |
+
"""用户友好的布尔标志参数的Util函数"""
|
| 37 |
+
if isinstance(v, bool):
|
| 38 |
+
return v
|
| 39 |
+
if v.lower() in ('yes', 'true', 't', 'y', '1'):
|
| 40 |
+
return True
|
| 41 |
+
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
|
| 42 |
+
return False
|
| 43 |
+
else:
|
| 44 |
+
raise argparse.ArgumentTypeError('Boolean value expected.')
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# 定义一个函数用于解析命令行参数
|
| 48 |
+
def parse_args():
|
| 49 |
+
parser = argparse.ArgumentParser(
|
| 50 |
+
description="")
|
| 51 |
+
|
| 52 |
+
parser.add_argument(
|
| 53 |
+
"--run_gradio",
|
| 54 |
+
type=str2bool,
|
| 55 |
+
default=True,
|
| 56 |
+
help="Whether to launch as a gradio demo. Set to False if not installed and want to just run the stdout version.",
|
| 57 |
+
)
|
| 58 |
+
parser.add_argument(
|
| 59 |
+
"--demo_public",
|
| 60 |
+
type=str2bool,
|
| 61 |
+
default=False,
|
| 62 |
+
help="Whether to expose the gradio demo to the internet.",
|
| 63 |
+
)
|
| 64 |
+
parser.add_argument(
|
| 65 |
+
"--model_name_or_path",
|
| 66 |
+
type=str,
|
| 67 |
+
default="Qwen/Qwen1.5-0.5B-Chat",
|
| 68 |
+
help="Main model, path to pretrained model or model identifier from huggingface.co/models.",
|
| 69 |
+
)
|
| 70 |
+
parser.add_argument(
|
| 71 |
+
"--prompt_max_length",
|
| 72 |
+
type=int,
|
| 73 |
+
default=None,
|
| 74 |
+
help="Truncation length for prompt, overrides model config's max length field.",
|
| 75 |
+
)
|
| 76 |
+
parser.add_argument(
|
| 77 |
+
"--max_new_tokens",
|
| 78 |
+
type=int,
|
| 79 |
+
default=200,
|
| 80 |
+
help="Maximmum number of new tokens to generate.",
|
| 81 |
+
)
|
| 82 |
+
parser.add_argument(
|
| 83 |
+
"--generation_seed",
|
| 84 |
+
type=int,
|
| 85 |
+
default=123,
|
| 86 |
+
help="Seed for setting the torch global rng prior to generation.",
|
| 87 |
+
)
|
| 88 |
+
parser.add_argument(
|
| 89 |
+
"--use_sampling",
|
| 90 |
+
type=str2bool,
|
| 91 |
+
default=True,
|
| 92 |
+
help="Whether to generate using multinomial sampling.",
|
| 93 |
+
)
|
| 94 |
+
parser.add_argument(
|
| 95 |
+
"--sampling_temp",
|
| 96 |
+
type=float,
|
| 97 |
+
default=0.7,
|
| 98 |
+
help="Sampling temperature to use when generating using multinomial sampling.",
|
| 99 |
+
)
|
| 100 |
+
parser.add_argument(
|
| 101 |
+
"--n_beams",
|
| 102 |
+
type=int,
|
| 103 |
+
default=1,
|
| 104 |
+
help="Number of beams to use for beam search. 1 is normal greedy decoding",
|
| 105 |
+
)
|
| 106 |
+
parser.add_argument(
|
| 107 |
+
"--use_gpu",
|
| 108 |
+
type=str2bool,
|
| 109 |
+
default=True,
|
| 110 |
+
help="Whether to run inference and watermark hashing/seeding/permutation on gpu.",
|
| 111 |
+
)
|
| 112 |
+
parser.add_argument(
|
| 113 |
+
"--seeding_scheme",
|
| 114 |
+
type=str,
|
| 115 |
+
default="simple_1",
|
| 116 |
+
help="Seeding scheme to use to generate the greenlists at each generation and verification step.",
|
| 117 |
+
)
|
| 118 |
+
parser.add_argument(
|
| 119 |
+
"--gamma",
|
| 120 |
+
type=float,
|
| 121 |
+
default=0.5,
|
| 122 |
+
help="The fraction of the vocabulary to partition into the greenlist at each generation and verification step.",
|
| 123 |
+
)
|
| 124 |
+
parser.add_argument(
|
| 125 |
+
"--delta",
|
| 126 |
+
type=float,
|
| 127 |
+
default=2.0,
|
| 128 |
+
help="The amount/bias to add to each of the greenlist token logits before each token sampling step.",
|
| 129 |
+
)
|
| 130 |
+
parser.add_argument(
|
| 131 |
+
"--normalizers",
|
| 132 |
+
type=str,
|
| 133 |
+
default="",
|
| 134 |
+
help="Single or comma separated list of the preprocessors/normalizer names to use when performing watermark detection.",
|
| 135 |
+
)
|
| 136 |
+
parser.add_argument(
|
| 137 |
+
"--ignore_repeated_bigrams",
|
| 138 |
+
type=str2bool,
|
| 139 |
+
default=False,
|
| 140 |
+
help="Whether to use the detection method that only counts each unqiue bigram once as either a green or red hit.",
|
| 141 |
+
)
|
| 142 |
+
parser.add_argument(
|
| 143 |
+
"--detection_z_threshold",
|
| 144 |
+
type=float,
|
| 145 |
+
default=4.0,
|
| 146 |
+
help="The test statistic threshold for the detection hypothesis test.",
|
| 147 |
+
)
|
| 148 |
+
parser.add_argument(
|
| 149 |
+
"--select_green_tokens",
|
| 150 |
+
type=str2bool,
|
| 151 |
+
default=True,
|
| 152 |
+
help="How to treat the permuation when selecting the greenlist tokens at each step. Legacy is (False) to pick the complement/reds first.",
|
| 153 |
+
)
|
| 154 |
+
parser.add_argument(
|
| 155 |
+
"--skip_model_load",
|
| 156 |
+
type=str2bool,
|
| 157 |
+
default=False,
|
| 158 |
+
help="Skip the model loading to debug the interface.",
|
| 159 |
+
)
|
| 160 |
+
parser.add_argument(
|
| 161 |
+
"--gguf_file",
|
| 162 |
+
type=str,
|
| 163 |
+
default='./qwen2-0_5b-instruct-q2_k.gguf',
|
| 164 |
+
help="gguf文件(如果有)",
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
parser.add_argument(
|
| 168 |
+
"--seed_separately",
|
| 169 |
+
type=str2bool,
|
| 170 |
+
default=True,
|
| 171 |
+
help="Whether to call the torch seed function before both the unwatermarked and watermarked generate calls.",
|
| 172 |
+
)
|
| 173 |
+
args = parser.parse_args()
|
| 174 |
+
|
| 175 |
+
return args
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def load_model(args):
|
| 179 |
+
"""加载并返回模型和分词器"""
|
| 180 |
+
|
| 181 |
+
if args.use_gpu:
|
| 182 |
+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 183 |
+
else:
|
| 184 |
+
device = "cpu"
|
| 185 |
+
|
| 186 |
+
if 'gguf' in args.model_name_or_path.lower():
|
| 187 |
+
|
| 188 |
+
model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, gguf_file=args.gguf_file,
|
| 189 |
+
trust_remote_code=True,
|
| 190 |
+
local_files_only=True,
|
| 191 |
+
device_map=device)
|
| 192 |
+
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, gguf_file=args.gguf_file,
|
| 193 |
+
local_files_only=True,
|
| 194 |
+
trust_remote_code=True)
|
| 195 |
+
|
| 196 |
+
else:
|
| 197 |
+
model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,
|
| 198 |
+
trust_remote_code=True,
|
| 199 |
+
local_files_only=True,
|
| 200 |
+
device_map=device)
|
| 201 |
+
|
| 202 |
+
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True,
|
| 203 |
+
local_files_only=True, )
|
| 204 |
+
|
| 205 |
+
try:
|
| 206 |
+
model.eval()
|
| 207 |
+
except Exception as e:
|
| 208 |
+
print(e)
|
| 209 |
+
|
| 210 |
+
return model, tokenizer, device
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
from text_generation import InferenceAPIClient
|
| 214 |
+
from requests.exceptions import ReadTimeout
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def generate_with_api(prompt, args):
|
| 218 |
+
hf_api_key = os.environ.get("HF_API_KEY")
|
| 219 |
+
if hf_api_key is None:
|
| 220 |
+
raise ValueError("HF_API_KEY environment variable not set, cannot use HF API to generate text.")
|
| 221 |
+
|
| 222 |
+
client = InferenceAPIClient(args.model_name_or_path, token=hf_api_key, timeout=60)
|
| 223 |
+
|
| 224 |
+
assert args.n_beams == 1, "HF API models do not support beam search."
|
| 225 |
+
generation_params = {
|
| 226 |
+
"max_new_tokens": args.max_new_tokens,
|
| 227 |
+
"do_sample": args.use_sampling,
|
| 228 |
+
}
|
| 229 |
+
if args.use_sampling:
|
| 230 |
+
generation_params["temperature"] = args.sampling_temp
|
| 231 |
+
generation_params["seed"] = args.generation_seed
|
| 232 |
+
|
| 233 |
+
timeout_msg = "[Model API timeout error. Try reducing the max_new_tokens parameter or the prompt length.]"
|
| 234 |
+
try:
|
| 235 |
+
generation_params["watermark"] = False
|
| 236 |
+
without_watermark_iterator = client.generate_stream(prompt, **generation_params)
|
| 237 |
+
except ReadTimeout as e:
|
| 238 |
+
print(e)
|
| 239 |
+
without_watermark_iterator = (char for char in timeout_msg)
|
| 240 |
+
try:
|
| 241 |
+
generation_params["watermark"] = True
|
| 242 |
+
with_watermark_iterator = client.generate_stream(prompt, **generation_params)
|
| 243 |
+
except ReadTimeout as e:
|
| 244 |
+
print(e)
|
| 245 |
+
with_watermark_iterator = (char for char in timeout_msg)
|
| 246 |
+
|
| 247 |
+
all_without_words, all_with_words = "", ""
|
| 248 |
+
for without_word, with_word in zip(without_watermark_iterator, with_watermark_iterator):
|
| 249 |
+
all_without_words += without_word.token.text
|
| 250 |
+
all_with_words += with_word.token.text
|
| 251 |
+
yield all_without_words, all_with_words
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def check_prompt(prompt, args, tokenizer, model=None, device=None):
|
| 255 |
+
# 这适用于本地和API模型场景
|
| 256 |
+
try:
|
| 257 |
+
if args.model_name_or_path in API_MODEL_MAP:
|
| 258 |
+
args.prompt_max_length = API_MODEL_MAP[args.model_name_or_path]["max_length"]
|
| 259 |
+
elif hasattr(model.config, "max_position_embedding"):
|
| 260 |
+
args.prompt_max_length = model.config.max_position_embeddings - args.max_new_tokens
|
| 261 |
+
else:
|
| 262 |
+
args.prompt_max_length = 4096 - args.max_new_tokens
|
| 263 |
+
except Exception as e:
|
| 264 |
+
print(e)
|
| 265 |
+
args.prompt_max_length = 4096 - args.max_new_tokens
|
| 266 |
+
|
| 267 |
+
tokd_input = tokenizer(prompt, return_tensors="pt", add_special_tokens=False, truncation=True,
|
| 268 |
+
max_length=args.prompt_max_length).to(device)
|
| 269 |
+
truncation_warning = True if tokd_input["input_ids"].shape[-1] == args.prompt_max_length else False
|
| 270 |
+
redecoded_input = tokenizer.batch_decode(tokd_input["input_ids"], skip_special_tokens=True)[0]
|
| 271 |
+
|
| 272 |
+
return (redecoded_input,
|
| 273 |
+
int(truncation_warning),
|
| 274 |
+
args)
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def generate(prompt, args, tokenizer, model=None, device=None):
|
| 278 |
+
"""根据水印参数实例化 WatermarkLogitsProcessor 并通过将其作为 logits 处理器传递给模型的 generate 方法来生成带水印的文本。"""
|
| 279 |
+
print(f"Generating with {args}")
|
| 280 |
+
print(f"Prompt: {prompt}")
|
| 281 |
+
|
| 282 |
+
if args.model_name_or_path in API_MODEL_MAP:
|
| 283 |
+
api_outputs = generate_with_api(prompt, args)
|
| 284 |
+
yield from api_outputs
|
| 285 |
+
else:
|
| 286 |
+
if 'chatglm' in args.model_name_or_path.lower() or 'qwen' in args.model_name_or_path.lower() or 'llama' in args.model_name_or_path.lower():
|
| 287 |
+
messages = [
|
| 288 |
+
# {"role": "system", "content": "You are a helpful assistant."},
|
| 289 |
+
{"role": "user", "content": prompt}
|
| 290 |
+
]
|
| 291 |
+
|
| 292 |
+
tokenized_input = tokenizer.apply_chat_template(
|
| 293 |
+
messages,
|
| 294 |
+
tokenize=False,
|
| 295 |
+
add_generation_prompt=True
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
tokd_input = tokenizer([tokenized_input], return_tensors="pt", truncation=True, add_special_tokens=False,
|
| 299 |
+
max_length=args.prompt_max_length).to(device)
|
| 300 |
+
|
| 301 |
+
else:
|
| 302 |
+
tokd_input = tokenizer(prompt, return_tensors="pt", add_special_tokens=True, truncation=True,
|
| 303 |
+
max_length=args.prompt_max_length).to(device)
|
| 304 |
+
|
| 305 |
+
gen_kwargs = dict(max_new_tokens=args.max_new_tokens)
|
| 306 |
+
|
| 307 |
+
if args.use_sampling:
|
| 308 |
+
gen_kwargs.update(dict(
|
| 309 |
+
do_sample=True,
|
| 310 |
+
top_k=0,
|
| 311 |
+
temperature=args.sampling_temp
|
| 312 |
+
))
|
| 313 |
+
else:
|
| 314 |
+
gen_kwargs.update(dict(
|
| 315 |
+
num_beams=args.n_beams
|
| 316 |
+
))
|
| 317 |
+
|
| 318 |
+
watermark_processor = WatermarkLogitsProcessor(vocab=list(tokenizer.get_vocab().values()),
|
| 319 |
+
gamma=args.gamma,
|
| 320 |
+
delta=args.delta,
|
| 321 |
+
seeding_scheme=args.seeding_scheme,
|
| 322 |
+
extra_salt=watermark_salt,
|
| 323 |
+
select_green_tokens=args.select_green_tokens)
|
| 324 |
+
|
| 325 |
+
generate_without_watermark = partial(
|
| 326 |
+
model.generate,
|
| 327 |
+
**gen_kwargs
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
generate_with_watermark = partial(
|
| 331 |
+
model.generate,
|
| 332 |
+
logits_processor=LogitsProcessorList([watermark_processor]),
|
| 333 |
+
**gen_kwargs
|
| 334 |
+
)
|
| 335 |
+
|
| 336 |
+
start_time = time.time()
|
| 337 |
+
gr.Info('开始生成正常内容')
|
| 338 |
+
torch.manual_seed(args.generation_seed)
|
| 339 |
+
output_without_watermark = generate_without_watermark(**tokd_input)
|
| 340 |
+
|
| 341 |
+
# 可选择在第二次生成之前种子,但通常不会再次相同,除非 delta==0.0,无操作水印
|
| 342 |
+
|
| 343 |
+
print(watermark_salt)
|
| 344 |
+
print(default_trace_table)
|
| 345 |
+
print(default_trace_table.loc[default_trace_table['编号'] == watermark_salt, '水印内容'])
|
| 346 |
+
gr.Info('开始注入水印:“{}”'.format(
|
| 347 |
+
default_trace_table.loc[default_trace_table['编号'] == watermark_salt, '水印内容'].item()))
|
| 348 |
+
if args.seed_separately:
|
| 349 |
+
torch.manual_seed(args.generation_seed)
|
| 350 |
+
|
| 351 |
+
output_with_watermark = generate_with_watermark(**tokd_input)
|
| 352 |
+
|
| 353 |
+
output_without_watermark = output_without_watermark[:, tokd_input["input_ids"].shape[-1]:]
|
| 354 |
+
output_with_watermark = output_with_watermark[:, tokd_input["input_ids"].shape[-1]:]
|
| 355 |
+
|
| 356 |
+
decoded_output_without_watermark = tokenizer.batch_decode(output_without_watermark, skip_special_tokens=True)[0]
|
| 357 |
+
decoded_output_with_watermark = tokenizer.batch_decode(output_with_watermark, skip_special_tokens=True)[0]
|
| 358 |
+
|
| 359 |
+
end_time = time.time()
|
| 360 |
+
gr.Info(f"生成结束,共用时{end_time - start_time:.2f}秒")
|
| 361 |
+
|
| 362 |
+
print(f"Generation took {end_time - start_time:.2f} seconds")
|
| 363 |
+
|
| 364 |
+
# 使用空格分隔生成器风格模拟 API 输出
|
| 365 |
+
|
| 366 |
+
all_without_words, all_with_words = "", ""
|
| 367 |
+
for without_word, with_word in zip(decoded_output_without_watermark.split(),
|
| 368 |
+
decoded_output_with_watermark.split()):
|
| 369 |
+
all_without_words += without_word + " "
|
| 370 |
+
all_with_words += with_word + " "
|
| 371 |
+
yield all_without_words, all_with_words
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
def format_names(s):
|
| 375 |
+
"""为 gradio 演示界面格式化名称"""
|
| 376 |
+
s = s.replace("num_tokens_scored", "总Token")
|
| 377 |
+
s = s.replace("num_green_tokens", "Green Token数量")
|
| 378 |
+
s = s.replace("green_fraction", "Green Token占比")
|
| 379 |
+
s = s.replace("z_score", "z-score")
|
| 380 |
+
s = s.replace("p_value", "p value")
|
| 381 |
+
s = s.replace("prediction", "预测结果")
|
| 382 |
+
s = s.replace("confidence", "置信度")
|
| 383 |
+
return s
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
def list_format_scores(score_dict, detection_threshold):
|
| 387 |
+
"""将检测指标格式化为 gradio 数据框输入格式"""
|
| 388 |
+
lst_2d = []
|
| 389 |
+
for k, v in score_dict.items():
|
| 390 |
+
if k == 'green_fraction':
|
| 391 |
+
lst_2d.append([format_names(k), f"{v:.1%}"])
|
| 392 |
+
elif k == 'confidence':
|
| 393 |
+
lst_2d.append([format_names(k), f"{v:.3%}"])
|
| 394 |
+
elif isinstance(v, float):
|
| 395 |
+
lst_2d.append([format_names(k), f"{v:.3g}"])
|
| 396 |
+
elif isinstance(v, bool):
|
| 397 |
+
lst_2d.append([format_names(k), ("含有水印" if v else "无水印")])
|
| 398 |
+
else:
|
| 399 |
+
lst_2d.append([format_names(k), f"{v}"])
|
| 400 |
+
if "confidence" in score_dict:
|
| 401 |
+
lst_2d.insert(-2, ["z-score Threshold", f"{detection_threshold}"])
|
| 402 |
+
else:
|
| 403 |
+
lst_2d.insert(-1, ["z-score Threshold", f"{detection_threshold}"])
|
| 404 |
+
return lst_2d
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
def detect(input_text, args, tokenizer, device=None, return_green_token_mask=True):
|
| 408 |
+
"""实例化 WatermarkDetection 对象并调用 detect 方法 在输入文本上返回测试的分数和结果"""
|
| 409 |
+
|
| 410 |
+
print(f"Detecting with {args}")
|
| 411 |
+
print(f"Detection Tokenizer: {type(tokenizer)}")
|
| 412 |
+
|
| 413 |
+
# 现在不要显示绿色的token mask
|
| 414 |
+
# 如果我们使用的是normalizers或ignore_repeated_bigrams
|
| 415 |
+
if args.normalizers != [] or args.ignore_repeated_bigrams:
|
| 416 |
+
return_green_token_mask = False
|
| 417 |
+
|
| 418 |
+
error = False
|
| 419 |
+
green_token_mask = None
|
| 420 |
+
if input_text == "":
|
| 421 |
+
error = True
|
| 422 |
+
else:
|
| 423 |
+
try:
|
| 424 |
+
for _, data in default_trace_table.iterrows():
|
| 425 |
+
salt = data["编号"]
|
| 426 |
+
name = data["水印内容"]
|
| 427 |
+
watermark_detector = WatermarkDetector(vocab=list(tokenizer.get_vocab().values()),
|
| 428 |
+
gamma=args.gamma,
|
| 429 |
+
seeding_scheme=args.seeding_scheme,
|
| 430 |
+
extra_salt=salt,
|
| 431 |
+
device=device,
|
| 432 |
+
tokenizer=tokenizer,
|
| 433 |
+
z_threshold=args.detection_z_threshold,
|
| 434 |
+
normalizers=args.normalizers,
|
| 435 |
+
ignore_repeated_bigrams=args.ignore_repeated_bigrams,
|
| 436 |
+
select_green_tokens=args.select_green_tokens)
|
| 437 |
+
score_dict = watermark_detector.detect(input_text, return_green_token_mask=return_green_token_mask)
|
| 438 |
+
if score_dict['prediction']:
|
| 439 |
+
print(f"检测到是“{name}”的水印")
|
| 440 |
+
break
|
| 441 |
+
|
| 442 |
+
green_token_mask = score_dict.pop("green_token_mask", None)
|
| 443 |
+
output = list_format_scores(score_dict, watermark_detector.z_threshold)
|
| 444 |
+
except ValueError as e:
|
| 445 |
+
print(e)
|
| 446 |
+
error = True
|
| 447 |
+
if error:
|
| 448 |
+
output = [["Error", "string too short to compute metrics"]]
|
| 449 |
+
output += [["", ""] for _ in range(6)]
|
| 450 |
+
|
| 451 |
+
html_output = "[No highlight markup generated]"
|
| 452 |
+
|
| 453 |
+
if green_token_mask is None:
|
| 454 |
+
html_output = "[Visualizing masks with ignore_repeated_bigrams enabled is not supported, toggle off to see the mask for this text. The mask is the same in both cases - only counting/stats are affected.]"
|
| 455 |
+
|
| 456 |
+
if green_token_mask is not None:
|
| 457 |
+
# hack 因为我们需要一个带有字符跨度支持的快速分词器
|
| 458 |
+
tokens = tokenizer(input_text, add_special_tokens=False)
|
| 459 |
+
if tokens["input_ids"][0] == tokenizer.bos_token_id:
|
| 460 |
+
tokens["input_ids"] = tokens["input_ids"][1:] # 忽略注意力掩码
|
| 461 |
+
skip = watermark_detector.min_prefix_len
|
| 462 |
+
|
| 463 |
+
if args.model_name_or_path in ['THUDM/chatglm3-6b']:
|
| 464 |
+
# 假设词表中3-258就是字节0-255
|
| 465 |
+
charspans = []
|
| 466 |
+
for i in range(skip, len(tokens["input_ids"])):
|
| 467 |
+
if tokens.data['input_ids'][i - 1] in range(3, 259):
|
| 468 |
+
charspans.append("<0x{:X}>".format(tokens.data['input_ids'][i - 1] - 3))
|
| 469 |
+
else:
|
| 470 |
+
charspans.append(tokenizer.decode(tokens.data['input_ids'][i - 1:i]))
|
| 471 |
+
|
| 472 |
+
else:
|
| 473 |
+
charspans = [tokens.token_to_chars(i - 1) for i in range(skip, len(tokens["input_ids"]))]
|
| 474 |
+
|
| 475 |
+
charspans = [cs for cs in charspans if cs is not None] # remove the special token spans
|
| 476 |
+
|
| 477 |
+
if len(charspans) != len(green_token_mask): breakpoint()
|
| 478 |
+
assert len(charspans) == len(green_token_mask)
|
| 479 |
+
|
| 480 |
+
if args.model_name_or_path in ['THUDM/chatglm3-6b']:
|
| 481 |
+
tags = []
|
| 482 |
+
for cs, m in zip(charspans, green_token_mask):
|
| 483 |
+
tags.append(
|
| 484 |
+
f'<span class="green">{cs}</span>' if m else f'<span class="red">{cs}</span>')
|
| 485 |
+
|
| 486 |
+
else:
|
| 487 |
+
tags = [(
|
| 488 |
+
f'<span class="green">{input_text[cs.start:cs.end]}</span>' if m else f'<span class="red">{input_text[cs.start:cs.end]}</span>')
|
| 489 |
+
for cs, m in zip(charspans, green_token_mask)]
|
| 490 |
+
|
| 491 |
+
html_output = f'<p>{" ".join(tags)}</p>'
|
| 492 |
+
|
| 493 |
+
if score_dict['prediction']:
|
| 494 |
+
html_look = gr.HTML("""<div style="width: 100%; font-size: 24px; height: 100px; border-radius: 20px; background-color: rgba(255, 0, 0, 0.25); display: flex; justify-content: center; align-items: center; color: white; font-weight: bold;">
|
| 495 |
+
<span>有 “{}” 的水印</span>
|
| 496 |
+
</div>""".format(name), visible=True)
|
| 497 |
+
|
| 498 |
+
else:
|
| 499 |
+
html_look = gr.HTML("""<div style="width: 100%; font-size: 24px; height: 100px; border-radius: 20px; background-color: rgba(0, 128, 0, 0.25); display: flex; justify-content: center; align-items: center; color: white; font-weight: bold; text-align: center;">
|
| 500 |
+
<span>无水印</span>
|
| 501 |
+
</div>""", visible=True)
|
| 502 |
+
|
| 503 |
+
return output, args, tokenizer, html_output, html_look
|
| 504 |
+
|
| 505 |
+
|
| 506 |
+
def run_gradio(args, model=None, device=None, tokenizer=None):
|
| 507 |
+
"""定义并启动gradio演示界面"""
|
| 508 |
+
check_prompt_partial = partial(check_prompt, model=model, device=device)
|
| 509 |
+
generate_partial = partial(generate, model=model, device=device)
|
| 510 |
+
detect_partial = partial(detect, device=device)
|
| 511 |
+
|
| 512 |
+
css = """
|
| 513 |
+
.green { color: black!important;line-height:1.9em; padding: 0.2em 0.2em; background: #ccffcc; border-radius:0.5rem;}
|
| 514 |
+
.red { color: black!important;line-height:1.9em; padding: 0.2em 0.2em; background: #ffad99; border-radius:0.5rem;}
|
| 515 |
+
"""
|
| 516 |
+
|
| 517 |
+
with gr.Blocks(theme='ParityError/Interstellar', css=css) as demo:
|
| 518 |
+
# 顶部部分,问候语和说明
|
| 519 |
+
with gr.Row():
|
| 520 |
+
with gr.Column(scale=9):
|
| 521 |
+
gr.Markdown(
|
| 522 |
+
"""
|
| 523 |
+
# 🌸🖼️ LLMwatermark:面向大语言模型生成内容的数字水印版权保护系统 🌟🎓
|
| 524 |
+
"""
|
| 525 |
+
)
|
| 526 |
+
with gr.Column(scale=1):
|
| 527 |
+
# 如果启动时的 model_name_or_path 不是 API 模型之一,则添加到下拉菜单中
|
| 528 |
+
all_models = sorted(list(set(list(API_MODEL_MAP.keys()) + [args.model_name_or_path])))
|
| 529 |
+
model_selector = gr.Dropdown(
|
| 530 |
+
all_models,
|
| 531 |
+
value=args.model_name_or_path,
|
| 532 |
+
label="选择大语言模型,进行模型水印",
|
| 533 |
+
)
|
| 534 |
+
|
| 535 |
+
# 构建参数的状态,定义更新和切换
|
| 536 |
+
default_prompt = args.__dict__.pop("default_prompt")
|
| 537 |
+
session_args = gr.State(value=args)
|
| 538 |
+
# 注意,如果状态对象是可调用的,则自动调用 value,希望在启动时避免调用分词器
|
| 539 |
+
session_tokenizer = gr.State(value=lambda: tokenizer)
|
| 540 |
+
|
| 541 |
+
with gr.Tab("生成回答和添加文本水印🎓"):
|
| 542 |
+
|
| 543 |
+
with gr.Row():
|
| 544 |
+
with gr.Column(scale=5):
|
| 545 |
+
prompt = gr.Textbox(label=f"Prompt", interactive=True, lines=3, max_lines=10, value=default_prompt)
|
| 546 |
+
with gr.Column(scale=3):
|
| 547 |
+
trace_source = gr.Dataframe(default_trace_table, datatype=['number', 'str'], interactive=True,
|
| 548 |
+
col_count=(2, "fixed"))
|
| 549 |
+
with gr.Row(equal_height=True):
|
| 550 |
+
with gr.Column(scale=7):
|
| 551 |
+
generate_btn = gr.Button("Generate", variant='primary')
|
| 552 |
+
|
| 553 |
+
gr.Markdown('水印选择:',
|
| 554 |
+
show_label=False)
|
| 555 |
+
watermark_salt_choice = gr.Dropdown(
|
| 556 |
+
choices=[i[::-1] for i in default_trace_table.to_dict(orient='split')['data']],
|
| 557 |
+
value=0,
|
| 558 |
+
container=False,
|
| 559 |
+
scale=3,
|
| 560 |
+
type="value",
|
| 561 |
+
interactive=True, label="水印标识选择")
|
| 562 |
+
|
| 563 |
+
with gr.Row():
|
| 564 |
+
with gr.Column():
|
| 565 |
+
with gr.Column(scale=2):
|
| 566 |
+
with gr.Tab("原版输出"):
|
| 567 |
+
output_without_watermark = gr.Textbox(interactive=False, lines=7, max_lines=14,
|
| 568 |
+
show_label=False)
|
| 569 |
+
with gr.Tab("显示水印"):
|
| 570 |
+
html_without_watermark = gr.HTML(elem_id="html-without-watermark")
|
| 571 |
+
|
| 572 |
+
original_watermark_state = gr.HTML('', visible=False)
|
| 573 |
+
|
| 574 |
+
with gr.Column(scale=1):
|
| 575 |
+
without_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"],
|
| 576 |
+
interactive=False,
|
| 577 |
+
row_count=7, col_count=2)
|
| 578 |
+
with gr.Column():
|
| 579 |
+
with gr.Column(scale=2):
|
| 580 |
+
with gr.Tab("带水印的输出"):
|
| 581 |
+
output_with_watermark = gr.Textbox(interactive=False, lines=7, max_lines=14,
|
| 582 |
+
show_label=False)
|
| 583 |
+
with gr.Tab("显示水印"):
|
| 584 |
+
html_with_watermark = gr.HTML(elem_id="html-with-watermark")
|
| 585 |
+
|
| 586 |
+
change_watermark_state = gr.HTML('', visible=False)
|
| 587 |
+
|
| 588 |
+
with gr.Column(scale=1):
|
| 589 |
+
with_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False,
|
| 590 |
+
row_count=7, col_count=2)
|
| 591 |
+
|
| 592 |
+
redecoded_input = gr.Textbox(visible=False)
|
| 593 |
+
truncation_warning = gr.Number(visible=False)
|
| 594 |
+
|
| 595 |
+
def truncate_prompt(redecoded_input, truncation_warning, orig_prompt, args):
|
| 596 |
+
if truncation_warning:
|
| 597 |
+
return redecoded_input + f"\n\n[Prompt was truncated before generation due to length...]", args
|
| 598 |
+
else:
|
| 599 |
+
return orig_prompt, args
|
| 600 |
+
|
| 601 |
+
with gr.Tab("检测文本水印功能🎭"):
|
| 602 |
+
with gr.Row():
|
| 603 |
+
with gr.Column(scale=5):
|
| 604 |
+
with gr.Tab("分析文本"):
|
| 605 |
+
detection_input = gr.Textbox(interactive=True, lines=14, max_lines=14, show_label=False)
|
| 606 |
+
with gr.Tab("显示水印"):
|
| 607 |
+
html_detection_input = gr.HTML(elem_id="html-detection-input")
|
| 608 |
+
|
| 609 |
+
detect_watermark_state = gr.HTML('', visible=False)
|
| 610 |
+
with gr.Column(scale=2):
|
| 611 |
+
trace_source2 = gr.Dataframe(default_trace_table, datatype=['number', 'str'], interactive=True,
|
| 612 |
+
col_count=(2, "fixed"))
|
| 613 |
+
detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False, row_count=7,
|
| 614 |
+
col_count=2)
|
| 615 |
+
|
| 616 |
+
with gr.Row():
|
| 617 |
+
detect_btn = gr.Button("检测", variant='primary')
|
| 618 |
+
|
| 619 |
+
with gr.Tab("About📖"):
|
| 620 |
+
with gr.Row():
|
| 621 |
+
with gr.Column(scale=2):
|
| 622 |
+
gr.Markdown(
|
| 623 |
+
"""
|
| 624 |
+
大语言模型可能带来的潜在危害可以通过*水印*来减轻。*水印*是嵌入在生成的文本中的信息,
|
| 625 |
+
这对人类来说是不可见的,但是可以被特定算法检测到。
|
| 626 |
+
这些水印可以使*任何人*用特定工具判断其是否使用带水印的模型生成的。
|
| 627 |
+
本网站展示了一种水印方法,可以应用于_任何_生成性语言模型。
|
| 628 |
+
"""
|
| 629 |
+
)
|
| 630 |
+
gr.Markdown(
|
| 631 |
+
"""
|
| 632 |
+
**[生成文本与添加水印]**:可以给大模型的输出添加水印。
|
| 633 |
+
您可以尝试任何prompt,并比较正常文本(*没有水印的输出*)和水印文本(*有水印的输出*)的质量。
|
| 634 |
+
您还可以点击**显示水印**来“看到”水印,其中的颜色表示其所在的红绿表。
|
| 635 |
+
|
| 636 |
+
**[检测]**:您还可以将水印文本(或任何其他文本)复制粘贴到第二个选项卡中。
|
| 637 |
+
可以实验删除多少句子后还能检测到水印。
|
| 638 |
+
还可以在验证,检测器的误报率有多少;
|
| 639 |
+
"""
|
| 640 |
+
)
|
| 641 |
+
|
| 642 |
+
with gr.Column(scale=1):
|
| 643 |
+
gr.Markdown(
|
| 644 |
+
"""
|
| 645 |
+
![]()
|
| 646 |
+
"""
|
| 647 |
+
)
|
| 648 |
+
|
| 649 |
+
# 参数选择组
|
| 650 |
+
with gr.Accordion("高级设置", open=False):
|
| 651 |
+
with gr.Row():
|
| 652 |
+
with gr.Column(scale=1):
|
| 653 |
+
gr.Markdown(f"#### 生成参数")
|
| 654 |
+
with gr.Row():
|
| 655 |
+
decoding = gr.Radio(label="解码方法", choices=["多项式解码方法", "贪心解码方法"],
|
| 656 |
+
value=("multinomial" if args.use_sampling else "greedy"))
|
| 657 |
+
with gr.Row():
|
| 658 |
+
sampling_temp = gr.Slider(label="采样温度", minimum=0.1, maximum=1.0, step=0.1,
|
| 659 |
+
value=args.sampling_temp, visible=True)
|
| 660 |
+
with gr.Row():
|
| 661 |
+
generation_seed = gr.Number(label="生成种子", value=args.generation_seed, interactive=True)
|
| 662 |
+
with gr.Row():
|
| 663 |
+
n_beams = gr.Dropdown(label="波束搜索解码", choices=list(range(1, 11, 1)), value=args.n_beams,
|
| 664 |
+
visible=((not args.use_sampling) and (
|
| 665 |
+
not args.model_name_or_path in API_MODEL_MAP)))
|
| 666 |
+
with gr.Row():
|
| 667 |
+
max_new_tokens = gr.Slider(label="(生成文本的最大长度)Max Generated Tokens", minimum=10,
|
| 668 |
+
maximum=4000, step=10, value=args.max_new_tokens)
|
| 669 |
+
|
| 670 |
+
with gr.Column(scale=1):
|
| 671 |
+
gr.Markdown(f"#### 模型水印参数设置")
|
| 672 |
+
with gr.Row():
|
| 673 |
+
gamma = gr.Slider(label="gamma", minimum=0.1, maximum=0.9, step=0.05, value=args.gamma)
|
| 674 |
+
with gr.Row():
|
| 675 |
+
delta = gr.Slider(label="delta", minimum=0.0, maximum=10.0, step=0.1, value=args.delta)
|
| 676 |
+
gr.Markdown(f"#### 检测文本水印参数设置")
|
| 677 |
+
with gr.Row():
|
| 678 |
+
detection_z_threshold = gr.Slider(label="Z分数阈值", minimum=0.0, maximum=10.0, step=0.1,
|
| 679 |
+
value=args.detection_z_threshold)
|
| 680 |
+
with gr.Row():
|
| 681 |
+
ignore_repeated_bigrams = gr.Checkbox(label="避免生成连续重复的双词组合")
|
| 682 |
+
with gr.Row():
|
| 683 |
+
normalizers = gr.CheckboxGroup(label="对文本进行标准化处理",
|
| 684 |
+
choices=["unicode", "homoglyphs", "truecase"],
|
| 685 |
+
value=args.normalizers)
|
| 686 |
+
with gr.Row():
|
| 687 |
+
gr.Markdown(
|
| 688 |
+
f"注意:滑块并不总是能完美更新。点击条形图或使用右侧的数字窗口会有所帮助。下面的窗口显示当前设置。")
|
| 689 |
+
with gr.Row():
|
| 690 |
+
current_parameters = gr.Textbox(label="当前参数设置", value=args, max_lines=10)
|
| 691 |
+
with gr.Accordion("传统设置", open=False):
|
| 692 |
+
with gr.Row():
|
| 693 |
+
with gr.Column(scale=1):
|
| 694 |
+
seed_separately = gr.Checkbox(label="为两个不同的生成过程分别设置随机种子",
|
| 695 |
+
value=args.seed_separately)
|
| 696 |
+
with gr.Column(scale=1):
|
| 697 |
+
select_green_tokens = gr.Checkbox(label="从分区中选择 绿色列表", value=args.select_green_tokens)
|
| 698 |
+
|
| 699 |
+
with gr.Accordion("设置有什么作用?", open=False):
|
| 700 |
+
gr.Markdown(
|
| 701 |
+
"""
|
| 702 |
+
#### 生成参数:
|
| 703 |
+
|
| 704 |
+
- **解码方法**:我们可以使用多项式采样或者贪婪解码的方式从模型中生成标记。
|
| 705 |
+
决定如何从模型中生成token。可以选择多项式采样或贪婪解码。
|
| 706 |
+
多项式采样允许一定的随机性,而贪婪解码总是选择概率最高的下一个token。
|
| 707 |
+
- **采样温度**:如果使用多项式采样,我们可以设置采样分布的温度。
|
| 708 |
+
- 0.0 相当于贪婪解码,而 1.0 代表最大的随机性。
|
| 709 |
+
- 0.7是文本质量和随机性之间的平衡点。不适用于贪婪解码。
|
| 710 |
+
- **生成种子**:用于在生成前初始化随机数生成器,使多项式采样的输出可复现。此设置不适用于贪婪解码。
|
| 711 |
+
- **束搜索数量**:在使用贪婪解码时,可以设置光束数量来启用束搜索。
|
| 712 |
+
这允许考虑多个候选序列,而不是只选择最有可能的一个。此设置目前仅适用于贪婪解码。
|
| 713 |
+
- **最大生成标记数**:传递给生成方法的 `max_new_tokens` 参数,以在一定数量的新标记停止输出。
|
| 714 |
+
- 请注意,根据提示,模型可以生成较少的标记。隐含地,这将最大化可能的提示标记数,即模型的最大输入长度减去 `max_new_tokens`,并相应地截断输入。
|
| 715 |
+
|
| 716 |
+
综上所述,这些参数提供了对生成过程的不同方面的控制,包括随机性和多样性
|
| 717 |
+
(解码方法、采样温度),可复现性(生成种子),以及输出长度和多样性(光束数量、最大生成token数)。
|
| 718 |
+
合理配置这些参数可以帮助生成高质量的文本输出。
|
| 719 |
+
|
| 720 |
+
#### 水印参数:
|
| 721 |
+
|
| 722 |
+
- **gamma**:在每个生成步骤中将词汇表的一部分划分为绿色列表的比例。
|
| 723 |
+
- 较小的 gamma 值通过使水印模型优先从较小的绿色集中采样,从而使其与人类/未水印文本的差异更大,从而创建更强的水印。
|
| 724 |
+
- **delta**:在每个生成步骤中为绿色列表中的每个标记的 logits 添加的正偏差量。较高的 delta 值意味着水印模型更偏好于绿色列表中的标记。
|
| 725 |
+
- 随着偏差变得非常大,水印从 "软" 过渡到 "硬"。对于硬水印,几乎所有标记都是绿色的,但这可能对生成质量产生不利影响,特别是当分布的灵活性不大时。
|
| 726 |
+
|
| 727 |
+
#### 检测器参数:
|
| 728 |
+
|
| 729 |
+
- **z 分数阈值**:假设检验的 z 分数截断。较高的阈值(例如 4.0)使得 _false positives_(预测人类/未水印文本被标记为水印)非常不可能,
|
| 730 |
+
因为一个真正的人类文本几乎永远不会达到那么高的 z 分数。
|
| 731 |
+
- 较低的阈值会捕获更多的 _true positives_,因为一些水印文本可能包含较少的绿色标记并达到较低的 z 分数,但仍然通过较低的门槛并被标记为 "水印"。
|
| 732 |
+
然而,较低的阈值会增加包含略高于平均水平的绿色标记的人类文本错误地被标记为水印的几率。4.0-5.0 提供了极低的假阳性率,同时仍准确捕获大多数水印文本。
|
| 733 |
+
- **忽略二元重复**:这种替代的检测算法在检测期间仅考虑文本中的唯一二元组,根据每对中的第一个计算绿色列表,并检查第二个是否位于列表中。
|
| 734 |
+
- 这意味着 `T` 现在是文本中唯一二元组的数量,如果文本包含大量重复,则这个数字将小于生成的总标记数。有关更详细的讨论,请参阅论文。
|
| 735 |
+
- **归一化**:我们实现了一些基本的归一化来抵御文本在检测期间的各种对抗性扰动。
|
| 736 |
+
- 目前,我们支持将所有字符转换为 Unicode,将同形异义字符替换为规范形式,并标准化大写。有关输入归一化的详细讨论,请参阅论文。
|
| 737 |
+
"""
|
| 738 |
+
)
|
| 739 |
+
|
| 740 |
+
with gr.Accordion("输出指标意味着什么?", open=False):
|
| 741 |
+
gr.Markdown(
|
| 742 |
+
"""
|
| 743 |
+
- `z-score threshold`:假设检验的截止值。
|
| 744 |
+
- `Tokens Counted (T)`:检测算法计算的输出中的标记数量。在简单的、单标记播种方案中,第一个标记被省略了,因为没有办法为其生成绿色列表,因为它没有前缀标记。
|
| 745 |
+
在底部面板描述的“忽略二元重复”检测算法下,如果有很多重复,这个数量可能远少于生成的总标记数。
|
| 746 |
+
- ` Tokens in Greenlist`:观察到落在其相应绿色列表中的标记数量。
|
| 747 |
+
- `Fraction of T in Greenlist`:`# Tokens in Greenlist` / `T`。这应该大约等于人类/未水印文本的 `gamma`。
|
| 748 |
+
- `z-score`:用于检测假设检验的检验统计量。如果大于 `z-score threshold`,则我们“拒绝零假设”,即文本是人类/未水印的,并得出结论它是带水印的。
|
| 749 |
+
- `p value`:在零假设下观察到计算出的 `z-score` 的可能性。这是在不知道水印程序/绿色列表的情况下观察到 `Fraction of T in Greenlist` 的可能性。
|
| 750 |
+
如果这个值极其 _小_,我们可以确信这么多的绿色标记不是由随机机会选择的。
|
| 751 |
+
- `prediction`:假设检验的结果 - 观察到的 `z-score` 是否高于 `z-score threshold`。
|
| 752 |
+
- `confidence`:如果我们拒绝零假设,且 `prediction` 是“带水印”,那么我们报告 1-`p value` 来表示基于这个 `z-score` 观察的检测的置信度。
|
| 753 |
+
"""
|
| 754 |
+
)
|
| 755 |
+
|
| 756 |
+
gr.HTML("""
|
| 757 |
+
<p>本方法可以对任何大模型的输出结果进行水印的操作。
|
| 758 |
+
并且可以对输出结果进行水印检测。
|
| 759 |
+
<p/>
|
| 760 |
+
""")
|
| 761 |
+
|
| 762 |
+
# 注册主要生成标签单击事件,输出生成文本以及编码+重新解码+可能被截断的提示和标志,然后调用检测
|
| 763 |
+
generate_btn.click(fn=check_prompt_partial, inputs=[prompt, session_args, session_tokenizer],
|
| 764 |
+
outputs=[redecoded_input, truncation_warning, session_args]).success(
|
| 765 |
+
fn=generate_partial, inputs=[redecoded_input, session_args, session_tokenizer],
|
| 766 |
+
outputs=[output_without_watermark, output_with_watermark]).success(
|
| 767 |
+
fn=detect_partial, inputs=[output_without_watermark, session_args, session_tokenizer],
|
| 768 |
+
outputs=[without_watermark_detection_result, session_args, session_tokenizer,
|
| 769 |
+
html_without_watermark, original_watermark_state]).success(
|
| 770 |
+
fn=detect_partial, inputs=[output_with_watermark, session_args, session_tokenizer],
|
| 771 |
+
outputs=[with_watermark_detection_result, session_args, session_tokenizer, html_with_watermark,
|
| 772 |
+
change_watermark_state])
|
| 773 |
+
# 如果发生了截断,则显示提示的截断版本
|
| 774 |
+
redecoded_input.change(fn=truncate_prompt, inputs=[redecoded_input, truncation_warning, prompt, session_args],
|
| 775 |
+
outputs=[prompt, session_args])
|
| 776 |
+
# Register main detection tab click
|
| 777 |
+
detect_btn.click(fn=detect_partial, inputs=[detection_input, session_args, session_tokenizer],
|
| 778 |
+
outputs=[detection_result, session_args, session_tokenizer, html_detection_input,
|
| 779 |
+
detect_watermark_state],
|
| 780 |
+
api_name="detection")
|
| 781 |
+
|
| 782 |
+
# 状态管理逻辑
|
| 783 |
+
# 定义更新回调函数以更改状态字典
|
| 784 |
+
def update_model(session_state, value):
|
| 785 |
+
session_state.model_name_or_path = value
|
| 786 |
+
return session_state
|
| 787 |
+
|
| 788 |
+
def update_sampling_temp(session_state, value):
|
| 789 |
+
session_state.sampling_temp = float(value)
|
| 790 |
+
return session_state
|
| 791 |
+
|
| 792 |
+
def update_generation_seed(session_state, value):
|
| 793 |
+
session_state.generation_seed = int(value)
|
| 794 |
+
return session_state
|
| 795 |
+
|
| 796 |
+
def update_watermark_salt(value):
|
| 797 |
+
global watermark_salt
|
| 798 |
+
if isinstance(value, int):
|
| 799 |
+
watermark_salt = value
|
| 800 |
+
elif value is None:
|
| 801 |
+
watermark_salt = 0
|
| 802 |
+
elif isinstance(value, str) and value.isdigit():
|
| 803 |
+
watermark_salt = int(value)
|
| 804 |
+
else:
|
| 805 |
+
# 不知道为什么会出现这种倒置的情况
|
| 806 |
+
watermark_salt = int(
|
| 807 |
+
default_trace_table.loc[default_trace_table['水印内容'] == value, '编号'].item())
|
| 808 |
+
|
| 809 |
+
def update_trace_source(value):
|
| 810 |
+
global default_trace_table
|
| 811 |
+
try:
|
| 812 |
+
if '' in value.loc[:, '编号'].tolist():
|
| 813 |
+
return value, gr.Dropdown()
|
| 814 |
+
|
| 815 |
+
value.loc[:, '编号'] = value.loc[:, '编号'].astype(int)
|
| 816 |
+
|
| 817 |
+
if default_trace_table.duplicated(subset='编号').any():
|
| 818 |
+
raise gr.Error(f"请检查水印编号,编号不能重复")
|
| 819 |
+
|
| 820 |
+
default_trace_table = value
|
| 821 |
+
|
| 822 |
+
return value, gr.Dropdown(
|
| 823 |
+
choices=[i[::-1] for i in value.to_dict(orient='split')['data']])
|
| 824 |
+
|
| 825 |
+
except ValueError as e:
|
| 826 |
+
if 'invalid literal for int() with base 10' in str(e):
|
| 827 |
+
raise gr.Error(f"请检查水印数据,编号必须是整数:{e}")
|
| 828 |
+
|
| 829 |
+
except gradio.exceptions.Error as e:
|
| 830 |
+
raise e
|
| 831 |
+
|
| 832 |
+
except Exception as e:
|
| 833 |
+
print(type(e))
|
| 834 |
+
raise e
|
| 835 |
+
|
| 836 |
+
def update_gamma(session_state, value):
|
| 837 |
+
session_state.gamma = float(value)
|
| 838 |
+
return session_state
|
| 839 |
+
|
| 840 |
+
def update_delta(session_state, value):
|
| 841 |
+
session_state.delta = float(value)
|
| 842 |
+
return session_state
|
| 843 |
+
|
| 844 |
+
def update_detection_z_threshold(session_state, value):
|
| 845 |
+
session_state.detection_z_threshold = float(value)
|
| 846 |
+
return session_state
|
| 847 |
+
|
| 848 |
+
def update_decoding(session_state, value):
|
| 849 |
+
if value == "multinomial":
|
| 850 |
+
session_state.use_sampling = True
|
| 851 |
+
elif value == "greedy":
|
| 852 |
+
session_state.use_sampling = False
|
| 853 |
+
return session_state
|
| 854 |
+
|
| 855 |
+
def toggle_sampling_vis(value):
|
| 856 |
+
if value == "multinomial":
|
| 857 |
+
return gr.update(visible=True)
|
| 858 |
+
elif value == "greedy":
|
| 859 |
+
return gr.update(visible=False)
|
| 860 |
+
|
| 861 |
+
def toggle_sampling_vis_inv(value):
|
| 862 |
+
if value == "multinomial":
|
| 863 |
+
return gr.update(visible=False)
|
| 864 |
+
elif value == "greedy":
|
| 865 |
+
return gr.update(visible=True)
|
| 866 |
+
|
| 867 |
+
# 如果模型名称在 API 模型列表中,则将 num beams 参数设置为 1 并隐藏 n_beams
|
| 868 |
+
def toggle_vis_for_api_model(value):
|
| 869 |
+
if value in API_MODEL_MAP:
|
| 870 |
+
return gr.update(visible=False)
|
| 871 |
+
else:
|
| 872 |
+
return gr.update(visible=True)
|
| 873 |
+
|
| 874 |
+
def toggle_beams_for_api_model(value, orig_n_beams):
|
| 875 |
+
if value in API_MODEL_MAP:
|
| 876 |
+
return gr.update(value=1)
|
| 877 |
+
else:
|
| 878 |
+
return gr.update(value=orig_n_beams)
|
| 879 |
+
|
| 880 |
+
# 如果模型名称在 API 模型列表中,则将交互参数设置为 false
|
| 881 |
+
def toggle_interactive_for_api_model(value):
|
| 882 |
+
if value in API_MODEL_MAP:
|
| 883 |
+
return gr.update(interactive=False)
|
| 884 |
+
else:
|
| 885 |
+
return gr.update(interactive=True)
|
| 886 |
+
|
| 887 |
+
# 如果模型名称在 API 模型列表中,则根据 API 映射设置 gamma 和 delta
|
| 888 |
+
def toggle_gamma_for_api_model(value, orig_gamma):
|
| 889 |
+
if value in API_MODEL_MAP:
|
| 890 |
+
return gr.update(value=API_MODEL_MAP[value]["gamma"])
|
| 891 |
+
else:
|
| 892 |
+
return gr.update(value=orig_gamma)
|
| 893 |
+
|
| 894 |
+
def toggle_delta_for_api_model(value, orig_delta):
|
| 895 |
+
if value in API_MODEL_MAP:
|
| 896 |
+
return gr.update(value=API_MODEL_MAP[value]["delta"])
|
| 897 |
+
else:
|
| 898 |
+
return gr.update(value=orig_delta)
|
| 899 |
+
|
| 900 |
+
def update_n_beams(session_state, value):
|
| 901 |
+
session_state.n_beams = value;
|
| 902 |
+
return session_state
|
| 903 |
+
|
| 904 |
+
def update_max_new_tokens(session_state, value):
|
| 905 |
+
session_state.max_new_tokens = int(value);
|
| 906 |
+
return session_state
|
| 907 |
+
|
| 908 |
+
def update_ignore_repeated_bigrams(session_state, value):
|
| 909 |
+
session_state.ignore_repeated_bigrams = value;
|
| 910 |
+
return session_state
|
| 911 |
+
|
| 912 |
+
def update_normalizers(session_state, value):
|
| 913 |
+
session_state.normalizers = value;
|
| 914 |
+
return session_state
|
| 915 |
+
|
| 916 |
+
def update_seed_separately(session_state, value):
|
| 917 |
+
session_state.seed_separately = value;
|
| 918 |
+
return session_state
|
| 919 |
+
|
| 920 |
+
def update_select_green_tokens(session_state, value):
|
| 921 |
+
session_state.select_green_tokens = value;
|
| 922 |
+
return session_state
|
| 923 |
+
|
| 924 |
+
def update_tokenizer(model_name_or_path):
|
| 925 |
+
# if model_name_or_path == ALPACA_MODEL_NAME:
|
| 926 |
+
# return ALPACA_MODEL_TOKENIZER.from_pretrained(ALPACA_TOKENIZER_PATH)
|
| 927 |
+
# else:
|
| 928 |
+
return AutoTokenizer.from_pretrained(model_name_or_path)
|
| 929 |
+
|
| 930 |
+
def check_model(value):
|
| 931 |
+
return value if (value != "" and value is not None) else args.model_name_or_path
|
| 932 |
+
|
| 933 |
+
# 强制约束模型��能为 null 或空
|
| 934 |
+
# 然后特别附加模型回调函数
|
| 935 |
+
model_selector.change(check_model, inputs=[model_selector], outputs=[model_selector]).then(
|
| 936 |
+
toggle_vis_for_api_model, inputs=[model_selector], outputs=[n_beams]
|
| 937 |
+
).then(
|
| 938 |
+
toggle_beams_for_api_model, inputs=[model_selector, n_beams], outputs=[n_beams]
|
| 939 |
+
).then(
|
| 940 |
+
toggle_interactive_for_api_model, inputs=[model_selector], outputs=[gamma]
|
| 941 |
+
).then(
|
| 942 |
+
toggle_interactive_for_api_model, inputs=[model_selector], outputs=[delta]
|
| 943 |
+
).then(
|
| 944 |
+
toggle_gamma_for_api_model, inputs=[model_selector, gamma], outputs=[gamma]
|
| 945 |
+
).then(
|
| 946 |
+
toggle_delta_for_api_model, inputs=[model_selector, delta], outputs=[delta]
|
| 947 |
+
).then(
|
| 948 |
+
update_tokenizer, inputs=[model_selector], outputs=[session_tokenizer]
|
| 949 |
+
).then(
|
| 950 |
+
update_model, inputs=[session_args, model_selector], outputs=[session_args]
|
| 951 |
+
).then(
|
| 952 |
+
lambda value: str(value), inputs=[session_args], outputs=[current_parameters]
|
| 953 |
+
)
|
| 954 |
+
# 根据其他参数的值注册回调函数以切换特定参数的可见性
|
| 955 |
+
decoding.change(toggle_sampling_vis, inputs=[decoding], outputs=[sampling_temp])
|
| 956 |
+
decoding.change(toggle_sampling_vis, inputs=[decoding], outputs=[generation_seed])
|
| 957 |
+
decoding.change(toggle_sampling_vis_inv, inputs=[decoding], outputs=[n_beams])
|
| 958 |
+
decoding.change(toggle_vis_for_api_model, inputs=[model_selector], outputs=[n_beams])
|
| 959 |
+
# 注册所有状态更新回调函数
|
| 960 |
+
decoding.change(update_decoding, inputs=[session_args, decoding], outputs=[session_args])
|
| 961 |
+
sampling_temp.change(update_sampling_temp, inputs=[session_args, sampling_temp], outputs=[session_args])
|
| 962 |
+
generation_seed.change(update_generation_seed, inputs=[session_args, generation_seed], outputs=[session_args])
|
| 963 |
+
watermark_salt_choice.change(update_watermark_salt, inputs=[watermark_salt_choice])
|
| 964 |
+
|
| 965 |
+
# 同步更新
|
| 966 |
+
trace_source.change(update_trace_source, inputs=[trace_source],
|
| 967 |
+
outputs=[trace_source2, watermark_salt_choice])
|
| 968 |
+
trace_source2.change(update_trace_source, inputs=[trace_source2],
|
| 969 |
+
outputs=[trace_source, watermark_salt_choice])
|
| 970 |
+
|
| 971 |
+
n_beams.change(update_n_beams, inputs=[session_args, n_beams], outputs=[session_args])
|
| 972 |
+
max_new_tokens.change(update_max_new_tokens, inputs=[session_args, max_new_tokens], outputs=[session_args])
|
| 973 |
+
gamma.change(update_gamma, inputs=[session_args, gamma], outputs=[session_args])
|
| 974 |
+
delta.change(update_delta, inputs=[session_args, delta], outputs=[session_args])
|
| 975 |
+
detection_z_threshold.change(update_detection_z_threshold, inputs=[session_args, detection_z_threshold],
|
| 976 |
+
outputs=[session_args])
|
| 977 |
+
ignore_repeated_bigrams.change(update_ignore_repeated_bigrams, inputs=[session_args, ignore_repeated_bigrams],
|
| 978 |
+
outputs=[session_args])
|
| 979 |
+
normalizers.change(update_normalizers, inputs=[session_args, normalizers], outputs=[session_args])
|
| 980 |
+
seed_separately.change(update_seed_separately, inputs=[session_args, seed_separately], outputs=[session_args])
|
| 981 |
+
select_green_tokens.change(update_select_green_tokens, inputs=[session_args, select_green_tokens],
|
| 982 |
+
outputs=[session_args])
|
| 983 |
+
# 注册按钮点击时更新显示参数窗口的额外回调
|
| 984 |
+
generate_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
|
| 985 |
+
detect_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
|
| 986 |
+
# 当参数更改时,显示更新并触发检测,因为某些检测参数不会改变模型输出。
|
| 987 |
+
delta.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
|
| 988 |
+
gamma.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
|
| 989 |
+
gamma.change(fn=detect_partial, inputs=[output_without_watermark, session_args, session_tokenizer],
|
| 990 |
+
outputs=[without_watermark_detection_result, session_args, session_tokenizer,
|
| 991 |
+
html_without_watermark])
|
| 992 |
+
gamma.change(fn=detect_partial, inputs=[output_with_watermark, session_args, session_tokenizer],
|
| 993 |
+
outputs=[with_watermark_detection_result, session_args, session_tokenizer, html_with_watermark])
|
| 994 |
+
gamma.change(fn=detect_partial, inputs=[detection_input, session_args, session_tokenizer],
|
| 995 |
+
outputs=[detection_result, session_args, session_tokenizer, html_detection_input])
|
| 996 |
+
detection_z_threshold.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
|
| 997 |
+
detection_z_threshold.change(fn=detect_partial,
|
| 998 |
+
inputs=[output_without_watermark, session_args, session_tokenizer],
|
| 999 |
+
outputs=[without_watermark_detection_result, session_args, session_tokenizer,
|
| 1000 |
+
html_without_watermark])
|
| 1001 |
+
detection_z_threshold.change(fn=detect_partial, inputs=[output_with_watermark, session_args, session_tokenizer],
|
| 1002 |
+
outputs=[with_watermark_detection_result, session_args, session_tokenizer,
|
| 1003 |
+
html_with_watermark])
|
| 1004 |
+
detection_z_threshold.change(fn=detect_partial, inputs=[detection_input, session_args, session_tokenizer],
|
| 1005 |
+
outputs=[detection_result, session_args, session_tokenizer, html_detection_input])
|
| 1006 |
+
ignore_repeated_bigrams.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
|
| 1007 |
+
ignore_repeated_bigrams.change(fn=detect_partial,
|
| 1008 |
+
inputs=[output_without_watermark, session_args, session_tokenizer],
|
| 1009 |
+
outputs=[without_watermark_detection_result, session_args, session_tokenizer,
|
| 1010 |
+
html_without_watermark])
|
| 1011 |
+
ignore_repeated_bigrams.change(fn=detect_partial,
|
| 1012 |
+
inputs=[output_with_watermark, session_args, session_tokenizer],
|
| 1013 |
+
outputs=[with_watermark_detection_result, session_args, session_tokenizer,
|
| 1014 |
+
html_with_watermark])
|
| 1015 |
+
ignore_repeated_bigrams.change(fn=detect_partial, inputs=[detection_input, session_args, session_tokenizer],
|
| 1016 |
+
outputs=[detection_result, session_args, session_tokenizer,
|
| 1017 |
+
html_detection_input])
|
| 1018 |
+
normalizers.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
|
| 1019 |
+
normalizers.change(fn=detect_partial, inputs=[output_without_watermark, session_args, session_tokenizer],
|
| 1020 |
+
outputs=[without_watermark_detection_result, session_args, session_tokenizer,
|
| 1021 |
+
html_without_watermark])
|
| 1022 |
+
normalizers.change(fn=detect_partial, inputs=[output_with_watermark, session_args, session_tokenizer],
|
| 1023 |
+
outputs=[with_watermark_detection_result, session_args, session_tokenizer,
|
| 1024 |
+
html_with_watermark])
|
| 1025 |
+
normalizers.change(fn=detect_partial, inputs=[detection_input, session_args, session_tokenizer],
|
| 1026 |
+
outputs=[detection_result, session_args, session_tokenizer, html_detection_input])
|
| 1027 |
+
select_green_tokens.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
|
| 1028 |
+
select_green_tokens.change(fn=detect_partial,
|
| 1029 |
+
inputs=[output_without_watermark, session_args, session_tokenizer],
|
| 1030 |
+
outputs=[without_watermark_detection_result, session_args, session_tokenizer,
|
| 1031 |
+
html_without_watermark])
|
| 1032 |
+
select_green_tokens.change(fn=detect_partial, inputs=[output_with_watermark, session_args, session_tokenizer],
|
| 1033 |
+
outputs=[with_watermark_detection_result, session_args, session_tokenizer,
|
| 1034 |
+
html_with_watermark])
|
| 1035 |
+
select_green_tokens.change(fn=detect_partial, inputs=[detection_input, session_args, session_tokenizer],
|
| 1036 |
+
outputs=[detection_result, session_args, session_tokenizer, html_detection_input])
|
| 1037 |
+
|
| 1038 |
+
# demo.queue(concurrency_count=3) # delete
|
| 1039 |
+
|
| 1040 |
+
if args.demo_public:
|
| 1041 |
+
demo.launch(share=True) # 通过随机生成的链接将应用程序暴露到互联网上
|
| 1042 |
+
else:
|
| 1043 |
+
demo.launch(server_name='0.0.0.0', share=False)
|
| 1044 |
+
|
| 1045 |
+
|
| 1046 |
+
def main(args):
|
| 1047 |
+
"""运行生成和检测操作的命令行版本
|
| 1048 |
+
并可选择启动和提供 gradio 演示"""
|
| 1049 |
+
# 初始参数处理和日志记录
|
| 1050 |
+
args.normalizers = (args.normalizers.split(",") if args.normalizers else [])
|
| 1051 |
+
print(args)
|
| 1052 |
+
|
| 1053 |
+
if not args.skip_model_load:
|
| 1054 |
+
model, tokenizer, device = load_model(args)
|
| 1055 |
+
else:
|
| 1056 |
+
model, tokenizer, device = None, None, None
|
| 1057 |
+
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
|
| 1058 |
+
if args.use_gpu:
|
| 1059 |
+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 1060 |
+
else:
|
| 1061 |
+
device = "cpu"
|
| 1062 |
+
|
| 1063 |
+
# terrapin example
|
| 1064 |
+
input_text = (
|
| 1065 |
+
"为什么A股指数跌的不多,但是我亏损比之前都多?"
|
| 1066 |
+
)
|
| 1067 |
+
|
| 1068 |
+
args.default_prompt = input_text
|
| 1069 |
+
|
| 1070 |
+
# Generate and detect, report to stdout
|
| 1071 |
+
if not args.skip_model_load:
|
| 1072 |
+
pass
|
| 1073 |
+
|
| 1074 |
+
# Launch the app to generate and detect interactively (implements the hf space demo)
|
| 1075 |
+
if args.run_gradio:
|
| 1076 |
+
run_gradio(args, model=model, tokenizer=tokenizer, device=device)
|
| 1077 |
+
|
| 1078 |
+
return
|
| 1079 |
+
|
| 1080 |
+
|
| 1081 |
+
if __name__ == "__main__":
|
| 1082 |
+
args = parse_args()
|
| 1083 |
+
print(args)
|
| 1084 |
+
|
| 1085 |
+
main(args)
|
extended_watermark_processor.py
ADDED
|
@@ -0,0 +1,592 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
import collections
|
| 4 |
+
from math import sqrt
|
| 5 |
+
from itertools import chain, tee
|
| 6 |
+
from functools import lru_cache
|
| 7 |
+
|
| 8 |
+
import scipy.stats
|
| 9 |
+
import torch
|
| 10 |
+
from tokenizers import Tokenizer
|
| 11 |
+
from transformers import LogitsProcessor
|
| 12 |
+
|
| 13 |
+
from normalizers import normalization_strategy_lookup
|
| 14 |
+
from alternative_prf_schemes import prf_lookup, seeding_scheme_lookup
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class WatermarkBase:
|
| 18 |
+
def __init__(
|
| 19 |
+
self,
|
| 20 |
+
vocab: list[int] = None,
|
| 21 |
+
gamma: float = 0.25,
|
| 22 |
+
delta: float = 2.0,
|
| 23 |
+
seeding_scheme: str = "selfhash",
|
| 24 |
+
select_green_tokens: bool = True,
|
| 25 |
+
):
|
| 26 |
+
# 现在可能将 None 作为 seeding_scheme 传递,所以现在要修补
|
| 27 |
+
if seeding_scheme is None:
|
| 28 |
+
seeding_scheme = "selfhash"
|
| 29 |
+
|
| 30 |
+
# 词汇设置
|
| 31 |
+
self.vocab = vocab
|
| 32 |
+
self.vocab_size = len(vocab)
|
| 33 |
+
|
| 34 |
+
# 水印行为:
|
| 35 |
+
self.gamma = gamma
|
| 36 |
+
self.delta = delta
|
| 37 |
+
self.rng = None
|
| 38 |
+
self._initialize_seeding_scheme(seeding_scheme)
|
| 39 |
+
# 传统行为:
|
| 40 |
+
self.select_green_tokens = select_green_tokens
|
| 41 |
+
|
| 42 |
+
def _initialize_seeding_scheme(self, seeding_scheme: str) -> None:
|
| 43 |
+
"""从一个通俗的“公共”名称初始化种子策略的所有内部设置。"""
|
| 44 |
+
self.prf_type, self.context_width, self.self_salt, self.hash_key = seeding_scheme_lookup(seeding_scheme)
|
| 45 |
+
|
| 46 |
+
def _seed_rng(self, input_ids: torch.LongTensor) -> None:
|
| 47 |
+
"""从本地上下文种子 RNG。不进行批处理,因为我们使用的生成器(如 cuda.random)不进行批处理。"""
|
| 48 |
+
# 需要有足够的token来进行种子生成
|
| 49 |
+
if input_ids.shape[-1] < self.context_width:
|
| 50 |
+
raise ValueError(f"seeding_scheme requires at least a {self.context_width} token prefix to seed the RNG.")
|
| 51 |
+
|
| 52 |
+
prf_key = prf_lookup[self.prf_type](input_ids[-self.context_width :], salt_key=self.hash_key)
|
| 53 |
+
self.rng.manual_seed(prf_key % (2**64 - 1)) # 防止溢出
|
| 54 |
+
|
| 55 |
+
def _get_greenlist_ids(self, input_ids: torch.LongTensor) -> torch.LongTensor:
|
| 56 |
+
"""根据本地上下文宽度对rng进行种子处理,并使用这些信息在绿色列表上生成id"""
|
| 57 |
+
self._seed_rng(input_ids)
|
| 58 |
+
|
| 59 |
+
greenlist_size = int(self.vocab_size * self.gamma)
|
| 60 |
+
vocab_permutation = torch.randperm(self.vocab_size, device=input_ids.device, generator=self.rng)
|
| 61 |
+
if self.select_green_tokens: # directly
|
| 62 |
+
greenlist_ids = vocab_permutation[:greenlist_size] # new
|
| 63 |
+
else: # 通过红色选择绿色
|
| 64 |
+
greenlist_ids = vocab_permutation[(self.vocab_size - greenlist_size) :] # legacy behavior
|
| 65 |
+
return greenlist_ids
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class WatermarkLogitsProcessor(WatermarkBase, LogitsProcessor):
|
| 69 |
+
"""LogitsProcessor 在管道中修改模型输出分数。可以在任何 HF 管道中使用它来修改分数以适应水印,但也可以作为一个独立的工具插入到在模型输出和下一个标记采样器之间生成分数的任何模型中。"""
|
| 70 |
+
def __init__(self, *args, store_spike_ents: bool = False, **kwargs):
|
| 71 |
+
super().__init__(*args, **kwargs)
|
| 72 |
+
|
| 73 |
+
self.store_spike_ents = store_spike_ents
|
| 74 |
+
self.spike_entropies = None
|
| 75 |
+
if self.store_spike_ents:
|
| 76 |
+
self._init_spike_entropies()
|
| 77 |
+
|
| 78 |
+
def _init_spike_entropies(self):
|
| 79 |
+
alpha = torch.exp(torch.tensor(self.delta)).item()
|
| 80 |
+
gamma = self.gamma
|
| 81 |
+
|
| 82 |
+
self.z_value = ((1 - gamma) * (alpha - 1)) / (1 - gamma + (alpha * gamma))
|
| 83 |
+
self.expected_gl_coef = (gamma * alpha) / (1 - gamma + (alpha * gamma))
|
| 84 |
+
|
| 85 |
+
# 当bias 是 "infinite" 时候捕获溢出
|
| 86 |
+
if alpha == torch.inf:
|
| 87 |
+
self.z_value = 1.0
|
| 88 |
+
self.expected_gl_coef = 1.0
|
| 89 |
+
|
| 90 |
+
def _get_spike_entropies(self):
|
| 91 |
+
spike_ents = [[] for _ in range(len(self.spike_entropies))]
|
| 92 |
+
for b_idx, ent_tensor_list in enumerate(self.spike_entropies):
|
| 93 |
+
for ent_tensor in ent_tensor_list:
|
| 94 |
+
spike_ents[b_idx].append(ent_tensor.item())
|
| 95 |
+
return spike_ents
|
| 96 |
+
|
| 97 |
+
def _get_and_clear_stored_spike_ents(self):
|
| 98 |
+
spike_ents = self._get_spike_entropies()
|
| 99 |
+
self.spike_entropies = None
|
| 100 |
+
return spike_ents
|
| 101 |
+
|
| 102 |
+
def _compute_spike_entropy(self, scores):
|
| 103 |
+
# 预先计算z得分
|
| 104 |
+
probs = scores.softmax(dim=-1)
|
| 105 |
+
denoms = 1 + (self.z_value * probs)
|
| 106 |
+
renormed_probs = probs / denoms
|
| 107 |
+
sum_renormed_probs = renormed_probs.sum()
|
| 108 |
+
return sum_renormed_probs
|
| 109 |
+
|
| 110 |
+
def _calc_greenlist_mask(self, scores: torch.FloatTensor, greenlist_token_ids) -> torch.BoolTensor:
|
| 111 |
+
green_tokens_mask = torch.zeros_like(scores, dtype=torch.bool)
|
| 112 |
+
for b_idx, greenlist in enumerate(greenlist_token_ids):
|
| 113 |
+
if len(greenlist) > 0:
|
| 114 |
+
green_tokens_mask[b_idx][greenlist] = True
|
| 115 |
+
return green_tokens_mask
|
| 116 |
+
|
| 117 |
+
def _bias_greenlist_logits(self, scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float) -> torch.Tensor:
|
| 118 |
+
scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias
|
| 119 |
+
return scores
|
| 120 |
+
|
| 121 |
+
def _score_rejection_sampling(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, tail_rule="fixed_compute") -> list[int]:
|
| 122 |
+
"""基于当前候选的下一个标记生成绿名单。如果需要,拒绝并继续。该方法不进行批处理。
|
| 123 |
+
这只是算法3“鲁棒私有水印”的部分版本,因为它始终假设贪婪采样。它仍然(有点)可以对所有类型的采样进行工作,但效果较差。
|
| 124 |
+
为了高效工作,此函数可以在处理分布尾部的规则之间切换。
|
| 125 |
+
默认情况下不会公开这些规则。"""
|
| 126 |
+
|
| 127 |
+
sorted_scores, greedy_predictions = scores.sort(dim=-1, descending=True)
|
| 128 |
+
|
| 129 |
+
final_greenlist = []
|
| 130 |
+
for idx, prediction_candidate in enumerate(greedy_predictions):
|
| 131 |
+
greenlist_ids = self._get_greenlist_ids(torch.cat([input_ids, prediction_candidate[None]], dim=0)) # add candidate to prefix
|
| 132 |
+
if prediction_candidate in greenlist_ids: # test for consistency
|
| 133 |
+
final_greenlist.append(prediction_candidate)
|
| 134 |
+
|
| 135 |
+
# 为了提高效率,以下是可选的提前停止规则
|
| 136 |
+
if tail_rule == "fixed_score":
|
| 137 |
+
if sorted_scores[0] - sorted_scores[idx + 1] > self.delta:
|
| 138 |
+
break
|
| 139 |
+
elif tail_rule == "fixed_list_length":
|
| 140 |
+
if len(final_greenlist) == 10:
|
| 141 |
+
break
|
| 142 |
+
elif tail_rule == "fixed_compute":
|
| 143 |
+
if idx == 40:
|
| 144 |
+
break
|
| 145 |
+
else:
|
| 146 |
+
pass # do not break early
|
| 147 |
+
return torch.as_tensor(final_greenlist, device=input_ids.device)
|
| 148 |
+
|
| 149 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
| 150 |
+
"""使用上一个词作为input_ids进行调用,并为下一个token打分。"""
|
| 151 |
+
|
| 152 |
+
self.rng = torch.Generator(device=input_ids.device) if self.rng is None else self.rng
|
| 153 |
+
|
| 154 |
+
# 注意,去掉这个批循环会很好,但当前,种子和分区操作都不是张量/向量化的,因此批中的每个序列都需要单独处理。
|
| 155 |
+
|
| 156 |
+
list_of_greenlist_ids = [None for _ in input_ids] # Greenlists could differ in length
|
| 157 |
+
for b_idx, input_seq in enumerate(input_ids):
|
| 158 |
+
if self.self_salt:
|
| 159 |
+
greenlist_ids = self._score_rejection_sampling(input_seq, scores[b_idx])
|
| 160 |
+
else:
|
| 161 |
+
greenlist_ids = self._get_greenlist_ids(input_seq)
|
| 162 |
+
list_of_greenlist_ids[b_idx] = greenlist_ids
|
| 163 |
+
|
| 164 |
+
# 计算和存储熵
|
| 165 |
+
if self.store_spike_ents:
|
| 166 |
+
if self.spike_entropies is None:
|
| 167 |
+
self.spike_entropies = [[] for _ in range(input_ids.shape[0])]
|
| 168 |
+
self.spike_entropies[b_idx].append(self._compute_spike_entropy(scores[b_idx]))
|
| 169 |
+
|
| 170 |
+
green_tokens_mask = self._calc_greenlist_mask(scores=scores, greenlist_token_ids=list_of_greenlist_ids)
|
| 171 |
+
scores = self._bias_greenlist_logits(scores=scores, greenlist_mask=green_tokens_mask, greenlist_bias=self.delta)
|
| 172 |
+
|
| 173 |
+
return scores
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
class WatermarkDetector(WatermarkBase):
|
| 177 |
+
"""这是用于检测所有使用 WatermarkLogitsProcessor 印记的水印的检测器。
|
| 178 |
+
|
| 179 |
+
检测器需要给出在文本生成期间给出的完全相同的设置,以复制水印的生成绿名单并检测水印。
|
| 180 |
+
这包括在文本生成期间使用的正确设备、正确的分词器、正确的 seeding_scheme 名称和参数(delta、gamma)。
|
| 181 |
+
|
| 182 |
+
可选参数包括
|
| 183 |
+
* normalizers ["unicode", "homoglyphs", "truecase"] -> 这些可以减轻生成文本中可能触发水印的修改。
|
| 184 |
+
* ignore_repeated_ngrams -> 此选项将更改检测规则,只计算每个唯一 ngram 一次。
|
| 185 |
+
* z_threshold -> 更改此阈值将更改检测器的灵敏度。
|
| 186 |
+
"""
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def __init__(
|
| 190 |
+
self,
|
| 191 |
+
*args,
|
| 192 |
+
device: torch.device = None,
|
| 193 |
+
tokenizer: Tokenizer = None,
|
| 194 |
+
z_threshold: float = 4.0,
|
| 195 |
+
normalizers: list[str] = ["unicode"], # or also: ["unicode", "homoglyphs", "truecase"]
|
| 196 |
+
ignore_repeated_ngrams: bool = True,
|
| 197 |
+
**kwargs,
|
| 198 |
+
):
|
| 199 |
+
super().__init__(*args, **kwargs)
|
| 200 |
+
# 配置选项
|
| 201 |
+
assert device, "Must pass device"
|
| 202 |
+
assert tokenizer, "Need an instance of the generating tokenizer to perform detection"
|
| 203 |
+
|
| 204 |
+
self.tokenizer = tokenizer
|
| 205 |
+
self.device = device
|
| 206 |
+
self.z_threshold = z_threshold
|
| 207 |
+
self.rng = torch.Generator(device=self.device)
|
| 208 |
+
|
| 209 |
+
self.normalizers = []
|
| 210 |
+
for normalization_strategy in normalizers:
|
| 211 |
+
self.normalizers.append(normalization_strategy_lookup(normalization_strategy))
|
| 212 |
+
self.ignore_repeated_ngrams = ignore_repeated_ngrams
|
| 213 |
+
|
| 214 |
+
def dummy_detect(
|
| 215 |
+
self,
|
| 216 |
+
return_prediction: bool = True,
|
| 217 |
+
return_scores: bool = True,
|
| 218 |
+
z_threshold: float = None,
|
| 219 |
+
return_num_tokens_scored: bool = True,
|
| 220 |
+
return_num_green_tokens: bool = True,
|
| 221 |
+
return_green_fraction: bool = True,
|
| 222 |
+
return_green_token_mask: bool = False,
|
| 223 |
+
return_all_window_scores: bool = False,
|
| 224 |
+
return_z_score: bool = True,
|
| 225 |
+
return_z_at_T: bool = True,
|
| 226 |
+
return_p_value: bool = True,
|
| 227 |
+
):
|
| 228 |
+
# HF-style 输出字典
|
| 229 |
+
score_dict = dict()
|
| 230 |
+
if return_num_tokens_scored:
|
| 231 |
+
score_dict.update(dict(num_tokens_scored=float("nan")))
|
| 232 |
+
if return_num_green_tokens:
|
| 233 |
+
score_dict.update(dict(num_green_tokens=float("nan")))
|
| 234 |
+
if return_green_fraction:
|
| 235 |
+
score_dict.update(dict(green_fraction=float("nan")))
|
| 236 |
+
if return_z_score:
|
| 237 |
+
score_dict.update(dict(z_score=float("nan")))
|
| 238 |
+
if return_p_value:
|
| 239 |
+
z_score = score_dict.get("z_score")
|
| 240 |
+
if z_score is None:
|
| 241 |
+
z_score = float("nan")
|
| 242 |
+
score_dict.update(dict(p_value=float("nan")))
|
| 243 |
+
if return_green_token_mask:
|
| 244 |
+
score_dict.update(dict(green_token_mask=[]))
|
| 245 |
+
if return_all_window_scores:
|
| 246 |
+
score_dict.update(dict(window_list=[]))
|
| 247 |
+
if return_z_at_T:
|
| 248 |
+
score_dict.update(dict(z_score_at_T=torch.tensor([])))
|
| 249 |
+
|
| 250 |
+
output_dict = {}
|
| 251 |
+
if return_scores:
|
| 252 |
+
output_dict.update(score_dict)
|
| 253 |
+
# 如果通过return_prediction,则执行假设检验并返回结果
|
| 254 |
+
if return_prediction:
|
| 255 |
+
z_threshold = z_threshold if z_threshold else self.z_threshold
|
| 256 |
+
assert z_threshold is not None, "Need a threshold in order to decide outcome of detection test"
|
| 257 |
+
output_dict["prediction"] = False
|
| 258 |
+
|
| 259 |
+
return output_dict
|
| 260 |
+
|
| 261 |
+
def _compute_z_score(self, observed_count, T):
|
| 262 |
+
# count是指绿色token的数量,T是token的总数
|
| 263 |
+
expected_count = self.gamma
|
| 264 |
+
numer = observed_count - expected_count * T
|
| 265 |
+
denom = sqrt(T * expected_count * (1 - expected_count))
|
| 266 |
+
z = numer / denom
|
| 267 |
+
return z
|
| 268 |
+
|
| 269 |
+
def _compute_p_value(self, z):
|
| 270 |
+
p_value = scipy.stats.norm.sf(z)
|
| 271 |
+
return p_value
|
| 272 |
+
|
| 273 |
+
@lru_cache(maxsize=2**32)
|
| 274 |
+
def _get_ngram_score_cached(self, prefix: tuple[int], target: int):
|
| 275 |
+
"""缓存了re-seeding and sampling"""
|
| 276 |
+
# 需要小心处理, 理想情况下应在__getattribute__访问self.prof_type、self.text_width、self.self_salt、self.hash_key时重置
|
| 277 |
+
greenlist_ids = self._get_greenlist_ids(torch.as_tensor(prefix, device=self.device))
|
| 278 |
+
return True if target in greenlist_ids else False
|
| 279 |
+
|
| 280 |
+
def _score_ngrams_in_passage(self, input_ids: torch.Tensor):
|
| 281 |
+
"""核心功能是收集输入中的所有ngram并计算其水印"""
|
| 282 |
+
if len(input_ids) - self.context_width < 1:
|
| 283 |
+
raise ValueError(
|
| 284 |
+
f"Must have at least {1} token to score after "
|
| 285 |
+
f"the first min_prefix_len={self.context_width} tokens required by the seeding scheme."
|
| 286 |
+
)
|
| 287 |
+
|
| 288 |
+
# 计算文章中所有ngrams上下文的分数:
|
| 289 |
+
token_ngram_generator = ngrams(input_ids.cpu().tolist(), self.context_width + 1 - self.self_salt)
|
| 290 |
+
frequencies_table = collections.Counter(token_ngram_generator)
|
| 291 |
+
ngram_to_watermark_lookup = {}
|
| 292 |
+
for idx, ngram_example in enumerate(frequencies_table.keys()):
|
| 293 |
+
prefix = ngram_example if self.self_salt else ngram_example[:-1]
|
| 294 |
+
target = ngram_example[-1]
|
| 295 |
+
ngram_to_watermark_lookup[ngram_example] = self._get_ngram_score_cached(prefix, target)
|
| 296 |
+
|
| 297 |
+
return ngram_to_watermark_lookup, frequencies_table
|
| 298 |
+
|
| 299 |
+
def _get_green_at_T_booleans(self, input_ids, ngram_to_watermark_lookup) -> tuple[torch.Tensor]:
|
| 300 |
+
"""生成基于每个标记的绿色与红色二值列表,一个忽略重复n-gram的独立列表,以及一个用于在两种表示法之间转换的偏移量列表:
|
| 301 |
+
green_token_mask = green_token_mask_unique[offsets],除了在所有会被计算为重复的位置之外
|
| 302 |
+
"""
|
| 303 |
+
|
| 304 |
+
green_token_mask, green_token_mask_unique, offsets = [], [], []
|
| 305 |
+
used_ngrams = {}
|
| 306 |
+
unique_ngram_idx = 0
|
| 307 |
+
ngram_examples = ngrams(input_ids.cpu().tolist(), self.context_width + 1 - self.self_salt)
|
| 308 |
+
|
| 309 |
+
for idx, ngram_example in enumerate(ngram_examples):
|
| 310 |
+
green_token_mask.append(ngram_to_watermark_lookup[ngram_example])
|
| 311 |
+
if self.ignore_repeated_ngrams:
|
| 312 |
+
if ngram_example in used_ngrams:
|
| 313 |
+
pass
|
| 314 |
+
else:
|
| 315 |
+
used_ngrams[ngram_example] = True
|
| 316 |
+
unique_ngram_idx += 1
|
| 317 |
+
green_token_mask_unique.append(ngram_to_watermark_lookup[ngram_example])
|
| 318 |
+
else:
|
| 319 |
+
green_token_mask_unique.append(ngram_to_watermark_lookup[ngram_example])
|
| 320 |
+
unique_ngram_idx += 1
|
| 321 |
+
offsets.append(unique_ngram_idx - 1)
|
| 322 |
+
return (
|
| 323 |
+
torch.tensor(green_token_mask),
|
| 324 |
+
torch.tensor(green_token_mask_unique),
|
| 325 |
+
torch.tensor(offsets),
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
def _score_sequence(
|
| 329 |
+
self,
|
| 330 |
+
input_ids: torch.Tensor,
|
| 331 |
+
return_num_tokens_scored: bool = True,
|
| 332 |
+
return_num_green_tokens: bool = True,
|
| 333 |
+
return_green_fraction: bool = True,
|
| 334 |
+
return_green_token_mask: bool = False,
|
| 335 |
+
return_z_score: bool = True,
|
| 336 |
+
return_z_at_T: bool = True,
|
| 337 |
+
return_p_value: bool = True,
|
| 338 |
+
):
|
| 339 |
+
ngram_to_watermark_lookup, frequencies_table = self._score_ngrams_in_passage(input_ids)
|
| 340 |
+
green_token_mask, green_unique, offsets = self._get_green_at_T_booleans(input_ids, ngram_to_watermark_lookup)
|
| 341 |
+
|
| 342 |
+
# 把所有ngrams的分数加起来
|
| 343 |
+
if self.ignore_repeated_ngrams:
|
| 344 |
+
# 一个方法,只对每个唯一的n-gram计算一次绿色/红色命中。
|
| 345 |
+
# 新的总标记评分数(T)变为唯一n-gram的数量。
|
| 346 |
+
# 我们遍历输入中的所有唯一的标记n-gram,计算每个上下文诱导的绿名单,
|
| 347 |
+
# 然后检查最后一个标记是否落在该绿名单中。
|
| 348 |
+
|
| 349 |
+
num_tokens_scored = len(frequencies_table.keys())
|
| 350 |
+
green_token_count = sum(ngram_to_watermark_lookup.values())
|
| 351 |
+
else:
|
| 352 |
+
num_tokens_scored = sum(frequencies_table.values())
|
| 353 |
+
assert num_tokens_scored == len(input_ids) - self.context_width + self.self_salt
|
| 354 |
+
green_token_count = sum(freq * outcome for freq, outcome in zip(frequencies_table.values(), ngram_to_watermark_lookup.values()))
|
| 355 |
+
assert green_token_count == green_unique.sum()
|
| 356 |
+
|
| 357 |
+
# HF-style 字典
|
| 358 |
+
score_dict = dict()
|
| 359 |
+
if return_num_tokens_scored:
|
| 360 |
+
score_dict.update(dict(num_tokens_scored=num_tokens_scored))
|
| 361 |
+
if return_num_green_tokens:
|
| 362 |
+
score_dict.update(dict(num_green_tokens=green_token_count))
|
| 363 |
+
if return_green_fraction:
|
| 364 |
+
score_dict.update(dict(green_fraction=(green_token_count / num_tokens_scored)))
|
| 365 |
+
if return_z_score:
|
| 366 |
+
score_dict.update(dict(z_score=self._compute_z_score(green_token_count, num_tokens_scored)))
|
| 367 |
+
if return_p_value:
|
| 368 |
+
z_score = score_dict.get("z_score")
|
| 369 |
+
if z_score is None:
|
| 370 |
+
z_score = self._compute_z_score(green_token_count, num_tokens_scored)
|
| 371 |
+
score_dict.update(dict(p_value=self._compute_p_value(z_score)))
|
| 372 |
+
if return_green_token_mask:
|
| 373 |
+
score_dict.update(dict(green_token_mask=green_token_mask.tolist()))
|
| 374 |
+
if return_z_at_T:
|
| 375 |
+
# Score z_at_T:
|
| 376 |
+
sizes = torch.arange(1, len(green_unique) + 1)
|
| 377 |
+
seq_z_score_enum = torch.cumsum(green_unique, dim=0) - self.gamma * sizes
|
| 378 |
+
seq_z_score_denom = torch.sqrt(sizes * self.gamma * (1 - self.gamma))
|
| 379 |
+
z_score_at_effective_T = seq_z_score_enum / seq_z_score_denom
|
| 380 |
+
z_score_at_T = z_score_at_effective_T[offsets]
|
| 381 |
+
assert torch.isclose(z_score_at_T[-1], torch.tensor(z_score))
|
| 382 |
+
|
| 383 |
+
score_dict.update(dict(z_score_at_T=z_score_at_T))
|
| 384 |
+
|
| 385 |
+
return score_dict
|
| 386 |
+
|
| 387 |
+
def _score_windows_impl_batched(
|
| 388 |
+
self,
|
| 389 |
+
input_ids: torch.Tensor,
|
| 390 |
+
window_size: str,
|
| 391 |
+
window_stride: int = 1,
|
| 392 |
+
):
|
| 393 |
+
# 实现细节:
|
| 394 |
+
# 1) --ignore_repeated_ngrams 选项被全局应用,然后在减少的二值向量上应用窗口化处理。
|
| 395 |
+
# 这只是实现的一种方式,另一种方法是在每个窗口内忽略bigram(这可能更难并行化处理)。
|
| 396 |
+
# 2) 这些窗口在绿色/红色命中的二值向量上,独立于 context_width,与 Kezhi 的第一个实现不同。
|
| 397 |
+
# 3) 由于窗口化的处理,这个实现得到的 z-分数不能直接转换为 p-值,并且应该只用作对选定 FPR 进行校准的 ROC 图的标签。
|
| 398 |
+
# 由于多次假设测试,整体得分将被提高。
|
| 399 |
+
# naive_count_correction=True 是对这个问题的一种部分解决方法。
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
ngram_to_watermark_lookup, frequencies_table = self._score_ngrams_in_passage(input_ids)
|
| 403 |
+
green_mask, green_ids, offsets = self._get_green_at_T_booleans(input_ids, ngram_to_watermark_lookup)
|
| 404 |
+
len_full_context = len(green_ids)
|
| 405 |
+
|
| 406 |
+
partial_sum_id_table = torch.cumsum(green_ids, dim=0)
|
| 407 |
+
|
| 408 |
+
if window_size == "max":
|
| 409 |
+
# 可以稍后启动,小窗口无法产生足够的能量
|
| 410 |
+
# solve (T * Spike_Entropy - g * T) / sqrt(T * g * (1 - g)) = z_thresh for T
|
| 411 |
+
sizes = range(1, len_full_context)
|
| 412 |
+
else:
|
| 413 |
+
sizes = [int(x) for x in window_size.split(",") if len(x) > 0]
|
| 414 |
+
|
| 415 |
+
z_score_max_per_window = torch.zeros(len(sizes))
|
| 416 |
+
cumulative_eff_z_score = torch.zeros(len_full_context)
|
| 417 |
+
s = window_stride
|
| 418 |
+
|
| 419 |
+
window_fits = False
|
| 420 |
+
for idx, size in enumerate(sizes):
|
| 421 |
+
if size <= len_full_context:
|
| 422 |
+
# 并行计算窗口内所有位置的hit:
|
| 423 |
+
window_score = torch.zeros(len_full_context - size + 1, dtype=torch.long)
|
| 424 |
+
# 包括第0个窗口
|
| 425 |
+
window_score[0] = partial_sum_id_table[size - 1]
|
| 426 |
+
# 从1号开始的所有其他窗口:
|
| 427 |
+
window_score[1:] = partial_sum_id_table[size::s] - partial_sum_id_table[:-size:s]
|
| 428 |
+
|
| 429 |
+
# 现在计算批处理的z_scores
|
| 430 |
+
batched_z_score_enum = window_score - self.gamma * size
|
| 431 |
+
z_score_denom = sqrt(size * self.gamma * (1 - self.gamma))
|
| 432 |
+
batched_z_score = batched_z_score_enum / z_score_denom
|
| 433 |
+
|
| 434 |
+
# 找到最大的hit
|
| 435 |
+
maximal_z_score = batched_z_score.max()
|
| 436 |
+
z_score_max_per_window[idx] = maximal_z_score
|
| 437 |
+
|
| 438 |
+
z_score_at_effective_T = torch.cummax(batched_z_score, dim=0)[0]
|
| 439 |
+
cumulative_eff_z_score[size::s] = torch.maximum(cumulative_eff_z_score[size::s], z_score_at_effective_T[:-1])
|
| 440 |
+
window_fits = True # 成功计算所有大小的窗口
|
| 441 |
+
|
| 442 |
+
if not window_fits:
|
| 443 |
+
raise ValueError(
|
| 444 |
+
f"Could not find a fitting window with window sizes {window_size} for (effective) context length {len_full_context}."
|
| 445 |
+
)
|
| 446 |
+
|
| 447 |
+
# 计算最佳窗口大小和z得分
|
| 448 |
+
cumulative_z_score = cumulative_eff_z_score[offsets]
|
| 449 |
+
optimal_z, optimal_window_size_idx = z_score_max_per_window.max(dim=0)
|
| 450 |
+
optimal_window_size = sizes[optimal_window_size_idx]
|
| 451 |
+
return (
|
| 452 |
+
optimal_z,
|
| 453 |
+
optimal_window_size,
|
| 454 |
+
z_score_max_per_window,
|
| 455 |
+
cumulative_z_score,
|
| 456 |
+
green_mask,
|
| 457 |
+
)
|
| 458 |
+
|
| 459 |
+
def _score_sequence_window(
|
| 460 |
+
self,
|
| 461 |
+
input_ids: torch.Tensor,
|
| 462 |
+
return_num_tokens_scored: bool = True,
|
| 463 |
+
return_num_green_tokens: bool = True,
|
| 464 |
+
return_green_fraction: bool = True,
|
| 465 |
+
return_green_token_mask: bool = False,
|
| 466 |
+
return_z_score: bool = True,
|
| 467 |
+
return_z_at_T: bool = True,
|
| 468 |
+
return_p_value: bool = True,
|
| 469 |
+
window_size: str = None,
|
| 470 |
+
window_stride: int = 1,
|
| 471 |
+
):
|
| 472 |
+
(
|
| 473 |
+
optimal_z,
|
| 474 |
+
optimal_window_size,
|
| 475 |
+
_,
|
| 476 |
+
z_score_at_T,
|
| 477 |
+
green_mask,
|
| 478 |
+
) = self._score_windows_impl_batched(input_ids, window_size, window_stride)
|
| 479 |
+
|
| 480 |
+
# HF-style 字典
|
| 481 |
+
score_dict = dict()
|
| 482 |
+
if return_num_tokens_scored:
|
| 483 |
+
score_dict.update(dict(num_tokens_scored=optimal_window_size))
|
| 484 |
+
|
| 485 |
+
denom = sqrt(optimal_window_size * self.gamma * (1 - self.gamma))
|
| 486 |
+
green_token_count = int(optimal_z * denom + self.gamma * optimal_window_size)
|
| 487 |
+
green_fraction = green_token_count / optimal_window_size
|
| 488 |
+
if return_num_green_tokens:
|
| 489 |
+
score_dict.update(dict(num_green_tokens=green_token_count))
|
| 490 |
+
if return_green_fraction:
|
| 491 |
+
score_dict.update(dict(green_fraction=green_fraction))
|
| 492 |
+
if return_z_score:
|
| 493 |
+
score_dict.update(dict(z_score=optimal_z))
|
| 494 |
+
if return_z_at_T:
|
| 495 |
+
score_dict.update(dict(z_score_at_T=z_score_at_T))
|
| 496 |
+
if return_p_value:
|
| 497 |
+
z_score = score_dict.get("z_score", optimal_z)
|
| 498 |
+
score_dict.update(dict(p_value=self._compute_p_value(z_score)))
|
| 499 |
+
|
| 500 |
+
# 返回掩码的每个标记的结果。这仍然是相同的,只是通过窗口化进行评分。
|
| 501 |
+
# 待办事项是将实际被计数的标记以不同的方式标记。
|
| 502 |
+
|
| 503 |
+
if return_green_token_mask:
|
| 504 |
+
score_dict.update(dict(green_token_mask=green_mask.tolist()))
|
| 505 |
+
|
| 506 |
+
return score_dict
|
| 507 |
+
|
| 508 |
+
def detect(
|
| 509 |
+
self,
|
| 510 |
+
text: str = None,
|
| 511 |
+
tokenized_text: list[int] = None,
|
| 512 |
+
window_size: str = None,
|
| 513 |
+
window_stride: int = None,
|
| 514 |
+
return_prediction: bool = True,
|
| 515 |
+
return_scores: bool = True,
|
| 516 |
+
z_threshold: float = None,
|
| 517 |
+
convert_to_float: bool = False,
|
| 518 |
+
**kwargs,
|
| 519 |
+
) -> dict:
|
| 520 |
+
"""对给定的文本字符串进行评分,并返回结果字典"""
|
| 521 |
+
|
| 522 |
+
assert (text is not None) ^ (tokenized_text is not None), "Must pass either the raw or tokenized string"
|
| 523 |
+
if return_prediction:
|
| 524 |
+
kwargs["return_p_value"] = True # 返回阳性检测的 "confidence":=1-p
|
| 525 |
+
|
| 526 |
+
# 对文本运行可选的normalizers
|
| 527 |
+
for normalizer in self.normalizers:
|
| 528 |
+
text = normalizer(text)
|
| 529 |
+
if len(self.normalizers) > 0:
|
| 530 |
+
print(f"Text after normalization:\n\n{text}\n")
|
| 531 |
+
|
| 532 |
+
if tokenized_text is None:
|
| 533 |
+
assert self.tokenizer is not None, (
|
| 534 |
+
"Watermark detection on raw string ",
|
| 535 |
+
"requires an instance of the tokenizer ",
|
| 536 |
+
"that was used at generation time.",
|
| 537 |
+
)
|
| 538 |
+
tokenized_text = self.tokenizer(text, return_tensors="pt", add_special_tokens=False)["input_ids"][0].to(self.device)
|
| 539 |
+
if tokenized_text[0] == self.tokenizer.bos_token_id:
|
| 540 |
+
tokenized_text = tokenized_text[1:]
|
| 541 |
+
else:
|
| 542 |
+
# 尝试从一开始就删除bos_tok(如果它在那里的话)
|
| 543 |
+
if (self.tokenizer is not None) and (tokenized_text[0] == self.tokenizer.bos_token_id):
|
| 544 |
+
tokenized_text = tokenized_text[1:]
|
| 545 |
+
|
| 546 |
+
# 调用score方法
|
| 547 |
+
output_dict = {}
|
| 548 |
+
|
| 549 |
+
if window_size is not None:
|
| 550 |
+
score_dict = self._score_sequence_window(
|
| 551 |
+
tokenized_text,
|
| 552 |
+
window_size=window_size,
|
| 553 |
+
window_stride=window_stride,
|
| 554 |
+
**kwargs,
|
| 555 |
+
)
|
| 556 |
+
output_dict.update(score_dict)
|
| 557 |
+
else:
|
| 558 |
+
score_dict = self._score_sequence(tokenized_text, **kwargs)
|
| 559 |
+
if return_scores:
|
| 560 |
+
output_dict.update(score_dict)
|
| 561 |
+
# 如果通过return_prediction,则执行假设检验并返回结果
|
| 562 |
+
if return_prediction:
|
| 563 |
+
z_threshold = z_threshold if z_threshold else self.z_threshold
|
| 564 |
+
assert z_threshold is not None, "Need a threshold in order to decide outcome of detection test"
|
| 565 |
+
output_dict["prediction"] = score_dict["z_score"] > z_threshold
|
| 566 |
+
if output_dict["prediction"]:
|
| 567 |
+
output_dict["confidence"] = 1 - score_dict["p_value"]
|
| 568 |
+
|
| 569 |
+
# 如果需要的话,将任何数值转换为浮点值
|
| 570 |
+
if convert_to_float:
|
| 571 |
+
for key, value in output_dict.items():
|
| 572 |
+
if isinstance(value, int):
|
| 573 |
+
output_dict[key] = float(value)
|
| 574 |
+
|
| 575 |
+
return output_dict
|
| 576 |
+
|
| 577 |
+
|
| 578 |
+
|
| 579 |
+
|
| 580 |
+
|
| 581 |
+
def ngrams(sequence, n, pad_left=False, pad_right=False, pad_symbol=None):
|
| 582 |
+
sequence = iter(sequence)
|
| 583 |
+
if pad_left:
|
| 584 |
+
sequence = chain((pad_symbol,) * (n - 1), sequence)
|
| 585 |
+
if pad_right:
|
| 586 |
+
sequence = chain(sequence, (pad_symbol,) * (n - 1))
|
| 587 |
+
iterables = tee(sequence, n)
|
| 588 |
+
|
| 589 |
+
for i, sub_iterable in enumerate(iterables): # For each window,
|
| 590 |
+
for _ in range(i): # iterate through every order of ngrams
|
| 591 |
+
next(sub_iterable, None) # generate the ngrams within the window.
|
| 592 |
+
return zip(*iterables) # Unpack and flattens the iterables.
|
homoglyphs.py
ADDED
|
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from collections import defaultdict
|
| 3 |
+
import json
|
| 4 |
+
from itertools import product
|
| 5 |
+
import os
|
| 6 |
+
import unicodedata
|
| 7 |
+
|
| 8 |
+
STRATEGY_LOAD = 1 # 加载类别
|
| 9 |
+
STRATEGY_IGNORE = 2 # 对结果添加字符
|
| 10 |
+
STRATEGY_REMOVE = 3 # 对结果移除字符
|
| 11 |
+
|
| 12 |
+
ASCII_RANGE = range(128)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 16 |
+
DATA_LOCATION = os.path.join(CURRENT_DIR, "homoglyph_data")
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class Categories:
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
fpath = os.path.join(DATA_LOCATION, "categories.json")
|
| 23 |
+
|
| 24 |
+
@classmethod
|
| 25 |
+
def _get_ranges(cls, categories):
|
| 26 |
+
|
| 27 |
+
with open(cls.fpath, encoding="utf-8") as f:
|
| 28 |
+
data = json.load(f)
|
| 29 |
+
|
| 30 |
+
for category in categories:
|
| 31 |
+
if category not in data["aliases"]:
|
| 32 |
+
raise ValueError("Invalid category: {}".format(category))
|
| 33 |
+
|
| 34 |
+
for point in data["points"]:
|
| 35 |
+
if point[2] in categories:
|
| 36 |
+
yield point[:2]
|
| 37 |
+
|
| 38 |
+
@classmethod
|
| 39 |
+
def get_alphabet(cls, categories):
|
| 40 |
+
|
| 41 |
+
alphabet = set()
|
| 42 |
+
for start, end in cls._get_ranges(categories):
|
| 43 |
+
chars = (chr(code) for code in range(start, end + 1))
|
| 44 |
+
alphabet.update(chars)
|
| 45 |
+
return alphabet
|
| 46 |
+
|
| 47 |
+
@classmethod
|
| 48 |
+
def detect(cls, char):
|
| 49 |
+
"""
|
| 50 |
+
:return: category
|
| 51 |
+
:rtype: str
|
| 52 |
+
"""
|
| 53 |
+
with open(cls.fpath, encoding="utf-8") as f:
|
| 54 |
+
data = json.load(f)
|
| 55 |
+
|
| 56 |
+
# 尝试用unicodedata检测类别
|
| 57 |
+
try:
|
| 58 |
+
category = unicodedata.name(char).split()[0]
|
| 59 |
+
except (TypeError, ValueError):
|
| 60 |
+
pass
|
| 61 |
+
else:
|
| 62 |
+
if category in data["aliases"]:
|
| 63 |
+
return category
|
| 64 |
+
|
| 65 |
+
# 尝试从JSON文件中按范围检测类别
|
| 66 |
+
code = ord(char)
|
| 67 |
+
for point in data["points"]:
|
| 68 |
+
if point[0] <= code <= point[1]:
|
| 69 |
+
return point[2]
|
| 70 |
+
|
| 71 |
+
@classmethod
|
| 72 |
+
def get_all(cls):
|
| 73 |
+
with open(cls.fpath, encoding="utf-8") as f:
|
| 74 |
+
data = json.load(f)
|
| 75 |
+
return set(data["aliases"])
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class Languages:
|
| 79 |
+
fpath = os.path.join(DATA_LOCATION, "languages.json")
|
| 80 |
+
|
| 81 |
+
@classmethod
|
| 82 |
+
def get_alphabet(cls, languages):
|
| 83 |
+
"""
|
| 84 |
+
:return: set of chars in alphabet by languages list
|
| 85 |
+
:rtype: set
|
| 86 |
+
"""
|
| 87 |
+
with open(cls.fpath, encoding="utf-8") as f:
|
| 88 |
+
data = json.load(f)
|
| 89 |
+
alphabet = set()
|
| 90 |
+
for lang in languages:
|
| 91 |
+
if lang not in data:
|
| 92 |
+
raise ValueError("Invalid language code: {}".format(lang))
|
| 93 |
+
alphabet.update(data[lang])
|
| 94 |
+
return alphabet
|
| 95 |
+
|
| 96 |
+
@classmethod
|
| 97 |
+
def detect(cls, char):
|
| 98 |
+
"""
|
| 99 |
+
:return: set of languages which alphabet contains passed char.
|
| 100 |
+
:rtype: set
|
| 101 |
+
"""
|
| 102 |
+
with open(cls.fpath, encoding="utf-8") as f:
|
| 103 |
+
data = json.load(f)
|
| 104 |
+
languages = set()
|
| 105 |
+
for lang, alphabet in data.items():
|
| 106 |
+
if char in alphabet:
|
| 107 |
+
languages.add(lang)
|
| 108 |
+
return languages
|
| 109 |
+
|
| 110 |
+
@classmethod
|
| 111 |
+
def get_all(cls):
|
| 112 |
+
with open(cls.fpath, encoding="utf-8") as f:
|
| 113 |
+
data = json.load(f)
|
| 114 |
+
return set(data.keys())
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class Homoglyphs:
|
| 118 |
+
def __init__(
|
| 119 |
+
self,
|
| 120 |
+
categories=None,
|
| 121 |
+
languages=None,
|
| 122 |
+
alphabet=None,
|
| 123 |
+
strategy=STRATEGY_IGNORE,
|
| 124 |
+
ascii_strategy=STRATEGY_IGNORE,
|
| 125 |
+
ascii_range=ASCII_RANGE,
|
| 126 |
+
):
|
| 127 |
+
# strategies
|
| 128 |
+
if strategy not in (STRATEGY_LOAD, STRATEGY_IGNORE, STRATEGY_REMOVE):
|
| 129 |
+
raise ValueError("Invalid strategy")
|
| 130 |
+
self.strategy = strategy
|
| 131 |
+
self.ascii_strategy = ascii_strategy
|
| 132 |
+
self.ascii_range = ascii_range
|
| 133 |
+
|
| 134 |
+
# Homoglyphs必须由任何字母表初始化才能正确工作
|
| 135 |
+
if not categories and not languages and not alphabet:
|
| 136 |
+
categories = ("LATIN", "COMMON")
|
| 137 |
+
|
| 138 |
+
# cats and langs
|
| 139 |
+
self.categories = set(categories or [])
|
| 140 |
+
self.languages = set(languages or [])
|
| 141 |
+
|
| 142 |
+
# alphabet
|
| 143 |
+
self.alphabet = set(alphabet or [])
|
| 144 |
+
if self.categories:
|
| 145 |
+
alphabet = Categories.get_alphabet(self.categories)
|
| 146 |
+
self.alphabet.update(alphabet)
|
| 147 |
+
if self.languages:
|
| 148 |
+
alphabet = Languages.get_alphabet(self.languages)
|
| 149 |
+
self.alphabet.update(alphabet)
|
| 150 |
+
self.table = self.get_table(self.alphabet)
|
| 151 |
+
|
| 152 |
+
@staticmethod
|
| 153 |
+
def get_table(alphabet):
|
| 154 |
+
table = defaultdict(set)
|
| 155 |
+
with open(os.path.join(DATA_LOCATION, "confusables_sept2022.json")) as f:
|
| 156 |
+
data = json.load(f)
|
| 157 |
+
for char in alphabet:
|
| 158 |
+
if char in data:
|
| 159 |
+
for homoglyph in data[char]:
|
| 160 |
+
if homoglyph in alphabet:
|
| 161 |
+
table[char].add(homoglyph)
|
| 162 |
+
return table
|
| 163 |
+
|
| 164 |
+
@staticmethod
|
| 165 |
+
def get_restricted_table(source_alphabet, target_alphabet):
|
| 166 |
+
table = defaultdict(set)
|
| 167 |
+
with open(os.path.join(DATA_LOCATION, "confusables_sept2022.json")) as f:
|
| 168 |
+
data = json.load(f)
|
| 169 |
+
for char in source_alphabet:
|
| 170 |
+
if char in data:
|
| 171 |
+
for homoglyph in data[char]:
|
| 172 |
+
if homoglyph in target_alphabet:
|
| 173 |
+
table[char].add(homoglyph)
|
| 174 |
+
return table
|
| 175 |
+
|
| 176 |
+
@staticmethod
|
| 177 |
+
def uniq_and_sort(data):
|
| 178 |
+
result = list(set(data))
|
| 179 |
+
result.sort(key=lambda x: (-len(x), x))
|
| 180 |
+
return result
|
| 181 |
+
|
| 182 |
+
def _update_alphabet(self, char):
|
| 183 |
+
# 尝试检测语言
|
| 184 |
+
langs = Languages.detect(char)
|
| 185 |
+
if langs:
|
| 186 |
+
self.languages.update(langs)
|
| 187 |
+
alphabet = Languages.get_alphabet(langs)
|
| 188 |
+
self.alphabet.update(alphabet)
|
| 189 |
+
else:
|
| 190 |
+
# 尝试检测类别
|
| 191 |
+
category = Categories.detect(char)
|
| 192 |
+
if category is None:
|
| 193 |
+
return False
|
| 194 |
+
self.categories.add(category)
|
| 195 |
+
alphabet = Categories.get_alphabet([category])
|
| 196 |
+
self.alphabet.update(alphabet)
|
| 197 |
+
# 更新新字母表的表格
|
| 198 |
+
self.table = self.get_table(self.alphabet)
|
| 199 |
+
return True
|
| 200 |
+
|
| 201 |
+
def _get_char_variants(self, char):
|
| 202 |
+
if char not in self.alphabet:
|
| 203 |
+
if self.strategy == STRATEGY_LOAD:
|
| 204 |
+
if not self._update_alphabet(char):
|
| 205 |
+
return []
|
| 206 |
+
elif self.strategy == STRATEGY_IGNORE:
|
| 207 |
+
return [char]
|
| 208 |
+
elif self.strategy == STRATEGY_REMOVE:
|
| 209 |
+
return []
|
| 210 |
+
|
| 211 |
+
# 查找当前字符的替代字符
|
| 212 |
+
alt_chars = self.table.get(char, set())
|
| 213 |
+
if alt_chars:
|
| 214 |
+
# 为当前字符查找可选字符
|
| 215 |
+
alt_chars2 = [self.table.get(alt_char, set()) for alt_char in alt_chars]
|
| 216 |
+
# 合并所有备选方案
|
| 217 |
+
alt_chars.update(*alt_chars2)
|
| 218 |
+
# 将当前字符添加到备选项
|
| 219 |
+
alt_chars.add(char)
|
| 220 |
+
|
| 221 |
+
# uniq, sort and return
|
| 222 |
+
return self.uniq_and_sort(alt_chars)
|
| 223 |
+
|
| 224 |
+
def _get_combinations(self, text, ascii=False):
|
| 225 |
+
variations = []
|
| 226 |
+
for char in text:
|
| 227 |
+
alt_chars = self._get_char_variants(char)
|
| 228 |
+
|
| 229 |
+
if ascii:
|
| 230 |
+
alt_chars = [char for char in alt_chars if ord(char) in self.ascii_range]
|
| 231 |
+
if not alt_chars and self.ascii_strategy == STRATEGY_IGNORE:
|
| 232 |
+
return
|
| 233 |
+
|
| 234 |
+
if alt_chars:
|
| 235 |
+
variations.append(alt_chars)
|
| 236 |
+
if variations:
|
| 237 |
+
for variant in product(*variations):
|
| 238 |
+
yield "".join(variant)
|
| 239 |
+
|
| 240 |
+
def get_combinations(self, text):
|
| 241 |
+
return list(self._get_combinations(text))
|
| 242 |
+
|
| 243 |
+
def _to_ascii(self, text):
|
| 244 |
+
for variant in self._get_combinations(text, ascii=True):
|
| 245 |
+
if max(map(ord, variant)) in self.ascii_range:
|
| 246 |
+
yield variant
|
| 247 |
+
|
| 248 |
+
def to_ascii(self, text):
|
| 249 |
+
return self.uniq_and_sort(self._to_ascii(text))
|
normalizers.py
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""文本基础规范化器,用于减轻针对水印的简单攻击。
|
| 2 |
+
|
| 3 |
+
这个实现不太可能是所有可能的Unicode标准中的所有漏洞的完整列表,
|
| 4 |
+
它代表了我们在撰写时的最佳努力。
|
| 5 |
+
|
| 6 |
+
这些规范化器可以作为独立的规范化器使用。它们可以被制作成符合HF分词器标准的规范化器,
|
| 7 |
+
但这将需要涉及tokenizers.NormalizedString的有限Rust接口。
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from collections import defaultdict
|
| 11 |
+
from functools import cache
|
| 12 |
+
|
| 13 |
+
import re
|
| 14 |
+
import unicodedata
|
| 15 |
+
import homoglyphs as hg
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def normalization_strategy_lookup(strategy_name: str) -> object:
|
| 19 |
+
if strategy_name == "unicode":
|
| 20 |
+
return UnicodeSanitizer()
|
| 21 |
+
elif strategy_name == "homoglyphs":
|
| 22 |
+
return HomoglyphCanonizer()
|
| 23 |
+
elif strategy_name == "truecase":
|
| 24 |
+
return TrueCaser()
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class HomoglyphCanonizer:
|
| 28 |
+
"""尝试检测同形字攻击并找到一致的标准形式。
|
| 29 |
+
|
| 30 |
+
这个函数是在ISO分类级别上进行的。也可以在语言级别上进行(参见注释掉的代码)。
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def __init__(self):
|
| 35 |
+
self.homoglyphs = None
|
| 36 |
+
|
| 37 |
+
def __call__(self, homoglyphed_str: str) -> str:
|
| 38 |
+
# find canon:
|
| 39 |
+
target_category, all_categories = self._categorize_text(homoglyphed_str)
|
| 40 |
+
homoglyph_table = self._select_canon_category_and_load(target_category, all_categories)
|
| 41 |
+
return self._sanitize_text(target_category, homoglyph_table, homoglyphed_str)
|
| 42 |
+
|
| 43 |
+
def _categorize_text(self, text: str) -> dict:
|
| 44 |
+
iso_categories = defaultdict(int)
|
| 45 |
+
# self.iso_languages = defaultdict(int)
|
| 46 |
+
|
| 47 |
+
for char in text:
|
| 48 |
+
iso_categories[hg.Categories.detect(char)] += 1
|
| 49 |
+
# for lang in hg.Languages.detect(char):
|
| 50 |
+
# self.iso_languages[lang] += 1
|
| 51 |
+
target_category = max(iso_categories, key=iso_categories.get)
|
| 52 |
+
all_categories = tuple(iso_categories)
|
| 53 |
+
return target_category, all_categories
|
| 54 |
+
|
| 55 |
+
@cache
|
| 56 |
+
def _select_canon_category_and_load(
|
| 57 |
+
self, target_category: str, all_categories: tuple[str]
|
| 58 |
+
) -> dict:
|
| 59 |
+
homoglyph_table = hg.Homoglyphs(
|
| 60 |
+
categories=(target_category, "COMMON")
|
| 61 |
+
) # 从文件中加载到此处的字母表
|
| 62 |
+
|
| 63 |
+
source_alphabet = hg.Categories.get_alphabet(all_categories)
|
| 64 |
+
restricted_table = homoglyph_table.get_restricted_table(
|
| 65 |
+
source_alphabet, homoglyph_table.alphabet
|
| 66 |
+
) # 从文件中加载到此处的表
|
| 67 |
+
return restricted_table
|
| 68 |
+
|
| 69 |
+
def _sanitize_text(
|
| 70 |
+
self, target_category: str, homoglyph_table: dict, homoglyphed_str: str
|
| 71 |
+
) -> str:
|
| 72 |
+
sanitized_text = ""
|
| 73 |
+
for char in homoglyphed_str:
|
| 74 |
+
# langs = hg.Languages.detect(char)
|
| 75 |
+
cat = hg.Categories.detect(char)
|
| 76 |
+
if target_category in cat or "COMMON" in cat or len(cat) == 0:
|
| 77 |
+
sanitized_text += char
|
| 78 |
+
else:
|
| 79 |
+
sanitized_text += list(homoglyph_table[char])[0]
|
| 80 |
+
return sanitized_text
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class UnicodeSanitizer:
|
| 84 |
+
|
| 85 |
+
def __init__(self, ruleset="whitespaces"):
|
| 86 |
+
if ruleset == "whitespaces":
|
| 87 |
+
"""Documentation:
|
| 88 |
+
\u00A0: Non-breaking space
|
| 89 |
+
\u1680: Ogham space mark
|
| 90 |
+
\u180E: Mongolian vowel separator
|
| 91 |
+
\u2000-\u200B: Various space characters, including en space, em space, thin space, hair space, zero-width space, and zero-width non-joiner
|
| 92 |
+
\u200C\u200D: Zero-width non-joiner and zero-width joiner
|
| 93 |
+
\u200E,\u200F: Left-to-right-mark, Right-to-left-mark
|
| 94 |
+
\u2060: Word joiner
|
| 95 |
+
\u2063: Invisible separator
|
| 96 |
+
\u202F: Narrow non-breaking space
|
| 97 |
+
\u205F: Medium mathematical space
|
| 98 |
+
\u3000: Ideographic space
|
| 99 |
+
\uFEFF: Zero-width non-breaking space
|
| 100 |
+
\uFFA0: Halfwidth hangul filler
|
| 101 |
+
\uFFF9\uFFFA\uFFFB: Interlinear annotation characters
|
| 102 |
+
\uFE00-\uFE0F: Variation selectors
|
| 103 |
+
\u202A-\u202F: Embedding characters
|
| 104 |
+
\u3164: Korean hangul filler.
|
| 105 |
+
|
| 106 |
+
"""
|
| 107 |
+
|
| 108 |
+
self.pattern = re.compile(
|
| 109 |
+
r"[\u00A0\u1680\u180E\u2000-\u200B\u200C\u200D\u200E\u200F\u2060\u2063\u202F\u205F\u3000\uFEFF\uFFA0\uFFF9\uFFFA\uFFFB"
|
| 110 |
+
r"\uFE00\uFE01\uFE02\uFE03\uFE04\uFE05\uFE06\uFE07\uFE08\uFE09\uFE0A\uFE0B\uFE0C\uFE0D\uFE0E\uFE0F\u3164\u202A\u202B\u202C\u202D"
|
| 111 |
+
r"\u202E\u202F]"
|
| 112 |
+
)
|
| 113 |
+
elif ruleset == "IDN.blacklist":
|
| 114 |
+
"""Documentation:
|
| 115 |
+
[\u00A0\u1680\u180E\u2000-\u200B\u202F\u205F\u2060\u2063\uFEFF]: Matches any whitespace characters in the Unicode character
|
| 116 |
+
set that are included in the IDN blacklist.
|
| 117 |
+
\uFFF9-\uFFFB: Matches characters that are not defined in Unicode but are used as language tags in various legacy encodings.
|
| 118 |
+
These characters are not allowed in domain names.
|
| 119 |
+
\uD800-\uDB7F: Matches the first part of a surrogate pair. Surrogate pairs are used to represent characters in the Unicode character
|
| 120 |
+
set that cannot be represented by a single 16-bit value. The first part of a surrogate pair is in the range U+D800 to U+DBFF,
|
| 121 |
+
and the second part is in the range U+DC00 to U+DFFF.
|
| 122 |
+
\uDB80-\uDBFF][\uDC00-\uDFFF]?: Matches the second part of a surrogate pair. The second part of a surrogate pair is in the range U+DC00
|
| 123 |
+
to U+DFFF, and is optional.
|
| 124 |
+
[\uDB40\uDC20-\uDB40\uDC7F][\uDC00-\uDFFF]: Matches certain invalid UTF-16 sequences which should not appear in IDNs.
|
| 125 |
+
"""
|
| 126 |
+
|
| 127 |
+
self.pattern = re.compile(
|
| 128 |
+
r"[\u00A0\u1680\u180E\u2000-\u200B\u202F\u205F\u2060\u2063\uFEFF\uFFF9-\uFFFB\uD800-\uDB7F\uDB80-\uDBFF]"
|
| 129 |
+
r"[\uDC00-\uDFFF]?|[\uDB40\uDC20-\uDB40\uDC7F][\uDC00-\uDFFF]"
|
| 130 |
+
)
|
| 131 |
+
else:
|
| 132 |
+
"""Documentation:
|
| 133 |
+
This is a simple restriction to "no-unicode", using only ascii characters. Control characters are included.
|
| 134 |
+
"""
|
| 135 |
+
self.pattern = re.compile(r"[^\x00-\x7F]+")
|
| 136 |
+
|
| 137 |
+
def __call__(self, text: str) -> str:
|
| 138 |
+
text = unicodedata.normalize("NFC", text) # canon forms
|
| 139 |
+
text = self.pattern.sub(" ", text) # pattern match
|
| 140 |
+
text = re.sub(" +", " ", text) # collapse whitespaces
|
| 141 |
+
text = "".join(
|
| 142 |
+
c for c in text if unicodedata.category(c) != "Cc"
|
| 143 |
+
) # 删除所有剩余的不可打印字符
|
| 144 |
+
return text
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class TrueCaser:
|
| 148 |
+
"""真大小写还原,是一种将文本还原为其原始大小写形式的大小写规范化处理。
|
| 149 |
+
|
| 150 |
+
这可以防御那些像 spOngBoB 那样随机大小写的攻击。
|
| 151 |
+
|
| 152 |
+
这里使用了简单的词性标注器。
|
| 153 |
+
"""
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
uppercase_pos = ["PROPN"] # 应使用大写字母命名POS
|
| 157 |
+
|
| 158 |
+
def __init__(self, backend="spacy"):
|
| 159 |
+
if backend == "spacy":
|
| 160 |
+
import spacy
|
| 161 |
+
|
| 162 |
+
self.nlp = spacy.load("en_core_web_sm")
|
| 163 |
+
self.normalize_fn = self._spacy_truecasing
|
| 164 |
+
else:
|
| 165 |
+
from nltk import pos_tag, word_tokenize # noqa
|
| 166 |
+
import nltk
|
| 167 |
+
|
| 168 |
+
nltk.download("punkt")
|
| 169 |
+
nltk.download("averaged_perceptron_tagger")
|
| 170 |
+
nltk.download("universal_tagset")
|
| 171 |
+
self.normalize_fn = self._nltk_truecasing
|
| 172 |
+
|
| 173 |
+
def __call__(self, random_capitalized_string: str) -> str:
|
| 174 |
+
truecased_str = self.normalize_fn(random_capitalized_string)
|
| 175 |
+
return truecased_str
|
| 176 |
+
|
| 177 |
+
def _spacy_truecasing(self, random_capitalized_string: str):
|
| 178 |
+
doc = self.nlp(random_capitalized_string.lower())
|
| 179 |
+
POS = self.uppercase_pos
|
| 180 |
+
truecased_str = "".join(
|
| 181 |
+
[
|
| 182 |
+
w.text_with_ws.capitalize() if w.pos_ in POS or w.is_sent_start else w.text_with_ws
|
| 183 |
+
for w in doc
|
| 184 |
+
]
|
| 185 |
+
)
|
| 186 |
+
return truecased_str
|
| 187 |
+
|
| 188 |
+
def _nltk_truecasing(self, random_capitalized_string: str):
|
| 189 |
+
from nltk import pos_tag, word_tokenize
|
| 190 |
+
import nltk
|
| 191 |
+
|
| 192 |
+
nltk.download("punkt")
|
| 193 |
+
nltk.download("averaged_perceptron_tagger")
|
| 194 |
+
nltk.download("universal_tagset")
|
| 195 |
+
POS = ["NNP", "NNPS"]
|
| 196 |
+
|
| 197 |
+
tagged_text = pos_tag(word_tokenize(random_capitalized_string.lower()))
|
| 198 |
+
truecased_str = " ".join([w.capitalize() if p in POS else w for (w, p) in tagged_text])
|
| 199 |
+
return truecased_str
|
requirements.txt
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
nltk
|
| 2 |
+
scipy
|
| 3 |
+
torch
|
| 4 |
+
transformers
|
| 5 |
+
tokenizers
|
| 6 |
+
accelerate
|
| 7 |
+
text-generation==0.3.1
|
| 8 |
+
optimum
|
| 9 |
+
auto-gptq
|
| 10 |
+
cpm_kernels
|
| 11 |
+
bitsandbytes
|
| 12 |
+
gradio==3.50.2
|
watermark_processor.py
ADDED
|
@@ -0,0 +1,336 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import collections
|
| 4 |
+
from math import sqrt
|
| 5 |
+
|
| 6 |
+
import scipy.stats
|
| 7 |
+
import torch
|
| 8 |
+
from nltk.util import ngrams
|
| 9 |
+
from tokenizers import Tokenizer
|
| 10 |
+
from torch import Tensor
|
| 11 |
+
from transformers import LogitsProcessor
|
| 12 |
+
|
| 13 |
+
from normalizers import normalization_strategy_lookup
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class WatermarkBase:
|
| 17 |
+
def __init__(
|
| 18 |
+
self,
|
| 19 |
+
vocab: list[int] = None,
|
| 20 |
+
gamma: float = 0.5,
|
| 21 |
+
delta: float = 2.0,
|
| 22 |
+
seeding_scheme: str = "simple_1",
|
| 23 |
+
hash_key: int = 15485863, # 只需要一个大素数就可以创建一个具有足够位宽的rng种子
|
| 24 |
+
extra_salt: int = 0,
|
| 25 |
+
select_green_tokens: bool = True,
|
| 26 |
+
):
|
| 27 |
+
|
| 28 |
+
# 水印参数
|
| 29 |
+
self.vocab = vocab
|
| 30 |
+
self.vocab_size = len(vocab)
|
| 31 |
+
self.gamma = gamma
|
| 32 |
+
self.delta = delta
|
| 33 |
+
self.seeding_scheme = seeding_scheme
|
| 34 |
+
self.rng = None
|
| 35 |
+
self.hash_key = hash_key
|
| 36 |
+
self.extra_salt = extra_salt
|
| 37 |
+
self.select_green_tokens = select_green_tokens
|
| 38 |
+
|
| 39 |
+
def _seed_rng(self, input_ids: torch.LongTensor, seeding_scheme: str = None) -> None:
|
| 40 |
+
# 可以选择覆盖种子设定方案,但默认情况下使用实例属性
|
| 41 |
+
if seeding_scheme is None:
|
| 42 |
+
seeding_scheme = self.seeding_scheme
|
| 43 |
+
|
| 44 |
+
if seeding_scheme == "simple_1":
|
| 45 |
+
assert input_ids.shape[
|
| 46 |
+
-1] >= 1, f"seeding_scheme={seeding_scheme} requires at least a 1 token prefix sequence to seed rng"
|
| 47 |
+
prev_token = input_ids[-1].item()
|
| 48 |
+
self.rng.manual_seed(self.hash_key * prev_token + self.extra_salt)
|
| 49 |
+
else:
|
| 50 |
+
raise NotImplementedError(f"Unexpected seeding_scheme: {seeding_scheme}")
|
| 51 |
+
return
|
| 52 |
+
|
| 53 |
+
def _get_greenlist_ids(self, input_ids: torch.LongTensor) -> list[int]:
|
| 54 |
+
|
| 55 |
+
self._seed_rng(input_ids)
|
| 56 |
+
|
| 57 |
+
greenlist_size = int(self.vocab_size * self.gamma)
|
| 58 |
+
|
| 59 |
+
if input_ids.device != 'cpu':
|
| 60 |
+
# 为了确保能在不同设备上复现,这里的随机数生成都用cpu
|
| 61 |
+
vocab_permutation = torch.randperm(self.vocab_size, device='cpu', generator=self.rng)
|
| 62 |
+
vocab_permutation = vocab_permutation.to(input_ids.device)
|
| 63 |
+
|
| 64 |
+
else:
|
| 65 |
+
vocab_permutation = torch.randperm(self.vocab_size, device=input_ids.device, generator=self.rng)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
if self.select_green_tokens: # directly
|
| 69 |
+
greenlist_ids = vocab_permutation[:greenlist_size] # new
|
| 70 |
+
else: # 从红色中挑选绿色
|
| 71 |
+
greenlist_ids = vocab_permutation[(self.vocab_size - greenlist_size):]
|
| 72 |
+
return greenlist_ids
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class WatermarkLogitsProcessor(WatermarkBase, LogitsProcessor):
|
| 76 |
+
|
| 77 |
+
def __init__(self, *args, **kwargs):
|
| 78 |
+
super().__init__(*args, **kwargs)
|
| 79 |
+
|
| 80 |
+
def _calc_greenlist_mask(self, scores: torch.FloatTensor, greenlist_token_ids) -> torch.BoolTensor:
|
| 81 |
+
# TODO lets see if we can lose this loop
|
| 82 |
+
green_tokens_mask = torch.zeros_like(scores)
|
| 83 |
+
for b_idx in range(len(greenlist_token_ids)):
|
| 84 |
+
green_tokens_mask[b_idx][greenlist_token_ids[b_idx]] = 1
|
| 85 |
+
final_mask = green_tokens_mask.bool()
|
| 86 |
+
return final_mask
|
| 87 |
+
|
| 88 |
+
def _bias_greenlist_logits(self, scores: torch.Tensor, greenlist_mask: torch.Tensor,
|
| 89 |
+
greenlist_bias: float) -> torch.Tensor:
|
| 90 |
+
scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias
|
| 91 |
+
return scores
|
| 92 |
+
|
| 93 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
| 94 |
+
if self.rng is None:
|
| 95 |
+
# self.rng = torch.Generator(device=input_ids.device)
|
| 96 |
+
self.rng = torch.Generator(device='cpu')
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
# 注意:理想情况下应该去掉这个批处理循环,但目前,
|
| 100 |
+
# 种子和分区操作还没有实现向量化,因此
|
| 101 |
+
# 批处理中的每个序列需要单独处理。
|
| 102 |
+
|
| 103 |
+
batched_greenlist_ids = [None for _ in range(input_ids.shape[0])]
|
| 104 |
+
|
| 105 |
+
for b_idx in range(input_ids.shape[0]):
|
| 106 |
+
|
| 107 |
+
greenlist_ids = self._get_greenlist_ids(input_ids[b_idx])
|
| 108 |
+
batched_greenlist_ids[b_idx] = greenlist_ids
|
| 109 |
+
|
| 110 |
+
green_tokens_mask = self._calc_greenlist_mask(scores=scores, greenlist_token_ids=batched_greenlist_ids)
|
| 111 |
+
|
| 112 |
+
scores = self._bias_greenlist_logits(scores=scores, greenlist_mask=green_tokens_mask, greenlist_bias=self.delta)
|
| 113 |
+
return scores
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class WatermarkDetector(WatermarkBase):
|
| 117 |
+
def __init__(
|
| 118 |
+
self,
|
| 119 |
+
*args,
|
| 120 |
+
device: torch.device = None,
|
| 121 |
+
tokenizer: Tokenizer = None,
|
| 122 |
+
z_threshold: float = 4.0,
|
| 123 |
+
normalizers: list[str] = ["unicode"], # or also: ["unicode", "homoglyphs", "truecase"]
|
| 124 |
+
ignore_repeated_bigrams: bool = False,
|
| 125 |
+
**kwargs,
|
| 126 |
+
):
|
| 127 |
+
super().__init__(*args, **kwargs)
|
| 128 |
+
|
| 129 |
+
assert device, "Must pass device"
|
| 130 |
+
assert tokenizer, "Need an instance of the generating tokenizer to perform detection"
|
| 131 |
+
|
| 132 |
+
self.tokenizer = tokenizer
|
| 133 |
+
self.device = device
|
| 134 |
+
self.z_threshold = z_threshold
|
| 135 |
+
# self.rng = torch.Generator(device=self.device)
|
| 136 |
+
self.rng = torch.Generator(device='cpu')
|
| 137 |
+
|
| 138 |
+
if self.seeding_scheme == "simple_1":
|
| 139 |
+
self.min_prefix_len = 1
|
| 140 |
+
else:
|
| 141 |
+
raise NotImplementedError(f"Unexpected seeding_scheme: {self.seeding_scheme}")
|
| 142 |
+
|
| 143 |
+
self.normalizers = []
|
| 144 |
+
for normalization_strategy in normalizers:
|
| 145 |
+
self.normalizers.append(normalization_strategy_lookup(normalization_strategy))
|
| 146 |
+
|
| 147 |
+
self.ignore_repeated_bigrams = ignore_repeated_bigrams
|
| 148 |
+
if self.ignore_repeated_bigrams:
|
| 149 |
+
assert self.seeding_scheme == "simple_1", "No repeated bigram credit variant assumes the single token seeding scheme."
|
| 150 |
+
|
| 151 |
+
def _compute_z_score(self, observed_count, T):
|
| 152 |
+
# count是指绿色token的数量,T是token的总数
|
| 153 |
+
expected_count = self.gamma
|
| 154 |
+
numer = observed_count - expected_count * T
|
| 155 |
+
denom = sqrt(T * expected_count * (1 - expected_count))
|
| 156 |
+
z = numer / denom
|
| 157 |
+
return z
|
| 158 |
+
|
| 159 |
+
def _compute_p_value(self, z):
|
| 160 |
+
p_value = scipy.stats.norm.sf(z)
|
| 161 |
+
return p_value
|
| 162 |
+
|
| 163 |
+
def _score_sequence(
|
| 164 |
+
self,
|
| 165 |
+
input_ids: Tensor,
|
| 166 |
+
return_num_tokens_scored: bool = True,
|
| 167 |
+
return_num_green_tokens: bool = True,
|
| 168 |
+
return_green_fraction: bool = True,
|
| 169 |
+
return_green_token_mask: bool = False,
|
| 170 |
+
return_z_score: bool = True,
|
| 171 |
+
return_p_value: bool = True,
|
| 172 |
+
):
|
| 173 |
+
if self.ignore_repeated_bigrams:
|
| 174 |
+
# 一个方法,只对每个唯一的bigram计算一次绿色/红色命中。
|
| 175 |
+
# 新的总标记评分数(T)变为唯一bigram的数量。
|
| 176 |
+
# 我们遍历输入中的所有唯一的标记bigram,计算每个bigram的第一个标记诱导的绿名单,
|
| 177 |
+
# 然后检查第二个标记是否落在该绿名单中。
|
| 178 |
+
|
| 179 |
+
assert return_green_token_mask == False, "Can't return the green/red mask when ignoring repeats."
|
| 180 |
+
bigram_table = {}
|
| 181 |
+
token_bigram_generator = ngrams(input_ids.cpu().tolist(), 2)
|
| 182 |
+
freq = collections.Counter(token_bigram_generator)
|
| 183 |
+
num_tokens_scored = len(freq.keys())
|
| 184 |
+
for idx, bigram in enumerate(freq.keys()):
|
| 185 |
+
prefix = torch.tensor([bigram[0]],
|
| 186 |
+
device=self.device) # expects a 1-d prefix tensor on the randperm device
|
| 187 |
+
greenlist_ids = self._get_greenlist_ids(prefix)
|
| 188 |
+
bigram_table[bigram] = True if bigram[1] in greenlist_ids else False
|
| 189 |
+
green_token_count = sum(bigram_table.values())
|
| 190 |
+
else:
|
| 191 |
+
num_tokens_scored = len(input_ids) - self.min_prefix_len
|
| 192 |
+
if num_tokens_scored < 1:
|
| 193 |
+
raise ValueError((f"Must have at least {1} token to score after "
|
| 194 |
+
f"the first min_prefix_len={self.min_prefix_len} tokens required by the seeding scheme."))
|
| 195 |
+
# 标准方法
|
| 196 |
+
# 由于我们通常至少需要1个token(对于最简单的方案)
|
| 197 |
+
# 我们从最小数量的token开始迭代token序列,作为种子方案的第一个前缀,
|
| 198 |
+
# 在每一步中,计算当前前缀诱导的绿名单,
|
| 199 |
+
# 并检查当前token是否落在绿名单中。
|
| 200 |
+
|
| 201 |
+
green_token_count, green_token_mask = 0, []
|
| 202 |
+
for idx in range(self.min_prefix_len, len(input_ids)):
|
| 203 |
+
curr_token = input_ids[idx]
|
| 204 |
+
greenlist_ids = self._get_greenlist_ids(input_ids[:idx])
|
| 205 |
+
if curr_token in greenlist_ids:
|
| 206 |
+
green_token_count += 1
|
| 207 |
+
green_token_mask.append(True)
|
| 208 |
+
else:
|
| 209 |
+
green_token_mask.append(False)
|
| 210 |
+
|
| 211 |
+
score_dict = dict()
|
| 212 |
+
if return_num_tokens_scored:
|
| 213 |
+
score_dict.update(dict(num_tokens_scored=num_tokens_scored))
|
| 214 |
+
if return_num_green_tokens:
|
| 215 |
+
score_dict.update(dict(num_green_tokens=green_token_count))
|
| 216 |
+
if return_green_fraction:
|
| 217 |
+
score_dict.update(dict(green_fraction=(green_token_count / num_tokens_scored)))
|
| 218 |
+
if return_z_score:
|
| 219 |
+
score_dict.update(dict(z_score=self._compute_z_score(green_token_count, num_tokens_scored)))
|
| 220 |
+
if return_p_value:
|
| 221 |
+
z_score = score_dict.get("z_score")
|
| 222 |
+
if z_score is None:
|
| 223 |
+
z_score = self._compute_z_score(green_token_count, num_tokens_scored)
|
| 224 |
+
score_dict.update(dict(p_value=self._compute_p_value(z_score)))
|
| 225 |
+
if return_green_token_mask:
|
| 226 |
+
score_dict.update(dict(green_token_mask=green_token_mask))
|
| 227 |
+
|
| 228 |
+
return score_dict
|
| 229 |
+
|
| 230 |
+
def detect(
|
| 231 |
+
self,
|
| 232 |
+
text: str = None,
|
| 233 |
+
tokenized_text: list[int] = None,
|
| 234 |
+
return_prediction: bool = True,
|
| 235 |
+
return_scores: bool = True,
|
| 236 |
+
z_threshold: float = None,
|
| 237 |
+
**kwargs,
|
| 238 |
+
) -> dict:
|
| 239 |
+
|
| 240 |
+
assert (text is not None) ^ (tokenized_text is not None), "Must pass either the raw or tokenized string"
|
| 241 |
+
if return_prediction:
|
| 242 |
+
kwargs["return_p_value"] = True # 返回阳性检测的"confidence":=1-p
|
| 243 |
+
|
| 244 |
+
# 运行可选的normalizers
|
| 245 |
+
for normalizer in self.normalizers:
|
| 246 |
+
text = normalizer(text)
|
| 247 |
+
if len(self.normalizers) > 0:
|
| 248 |
+
print(f"Text after normalization:\n\n{text}\n")
|
| 249 |
+
|
| 250 |
+
if tokenized_text is None:
|
| 251 |
+
assert self.tokenizer is not None, (
|
| 252 |
+
"Watermark detection on raw string ",
|
| 253 |
+
"requires an instance of the tokenizer ",
|
| 254 |
+
"that was used at generation time.",
|
| 255 |
+
)
|
| 256 |
+
tokenized_text = self.tokenizer(text, return_tensors="pt", add_special_tokens=False)["input_ids"][0].to(
|
| 257 |
+
self.device)
|
| 258 |
+
if tokenized_text[0] == self.tokenizer.bos_token_id:
|
| 259 |
+
tokenized_text = tokenized_text[1:]
|
| 260 |
+
else:
|
| 261 |
+
# 尝试一开始就删除bos_tok(如果它在那里的话)
|
| 262 |
+
if (self.tokenizer is not None) and (tokenized_text[0] == self.tokenizer.bos_token_id):
|
| 263 |
+
tokenized_text = tokenized_text[1:]
|
| 264 |
+
|
| 265 |
+
# 调用score方法
|
| 266 |
+
output_dict = {}
|
| 267 |
+
score_dict = self._score_sequence(tokenized_text, **kwargs)
|
| 268 |
+
if return_scores:
|
| 269 |
+
output_dict.update(score_dict)
|
| 270 |
+
# 如果通过return_prediction,则执行假设检验并返回结果
|
| 271 |
+
if return_prediction:
|
| 272 |
+
z_threshold = z_threshold if z_threshold else self.z_threshold
|
| 273 |
+
assert z_threshold is not None, "Need a threshold in order to decide outcome of detection test"
|
| 274 |
+
output_dict["prediction"] = score_dict["z_score"] > z_threshold
|
| 275 |
+
if output_dict["prediction"]:
|
| 276 |
+
output_dict["confidence"] = 1 - score_dict["p_value"]
|
| 277 |
+
|
| 278 |
+
return output_dict
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
if __name__ == "__main__":
|
| 282 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessorList
|
| 283 |
+
|
| 284 |
+
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B-Chat")
|
| 285 |
+
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen1.5-0.5B-Chat")
|
| 286 |
+
|
| 287 |
+
watermark_processor = WatermarkLogitsProcessor(vocab=list(tokenizer.get_vocab().values()),
|
| 288 |
+
gamma=0.5,
|
| 289 |
+
delta=2,
|
| 290 |
+
seeding_scheme="simple_1",
|
| 291 |
+
extra_salt=0,
|
| 292 |
+
select_green_tokens=True)
|
| 293 |
+
|
| 294 |
+
messages = [
|
| 295 |
+
# {"role": "system", "content": "You are a helpful assistant."},
|
| 296 |
+
{"role": "user", "content": "讲一段明代的历史"}
|
| 297 |
+
]
|
| 298 |
+
|
| 299 |
+
tokenized_input = tokenizer.apply_chat_template(
|
| 300 |
+
messages,
|
| 301 |
+
tokenize=False,
|
| 302 |
+
add_generation_prompt=True
|
| 303 |
+
)
|
| 304 |
+
|
| 305 |
+
tokd_input = tokenizer([tokenized_input], return_tensors="pt", truncation=True, add_special_tokens=False,
|
| 306 |
+
max_length=2000).to(model.device)
|
| 307 |
+
|
| 308 |
+
logits_processor = LogitsProcessorList([watermark_processor])
|
| 309 |
+
|
| 310 |
+
output = model.generate(
|
| 311 |
+
tokd_input.input_ids,
|
| 312 |
+
max_new_tokens=500,
|
| 313 |
+
logits_processor=logits_processor,
|
| 314 |
+
do_sample=True,
|
| 315 |
+
temperature=0.7,
|
| 316 |
+
)
|
| 317 |
+
|
| 318 |
+
print(tokenizer.decode(output[0]))
|
| 319 |
+
|
| 320 |
+
watermark_detector = WatermarkDetector(vocab=list(tokenizer.get_vocab().values()),
|
| 321 |
+
gamma=0.5,
|
| 322 |
+
seeding_scheme="simple_1",
|
| 323 |
+
extra_salt=0,
|
| 324 |
+
device=torch.device("cpu"),
|
| 325 |
+
tokenizer=tokenizer,
|
| 326 |
+
z_threshold=4,
|
| 327 |
+
# normalizers='',
|
| 328 |
+
ignore_repeated_bigrams=False,
|
| 329 |
+
select_green_tokens=True)
|
| 330 |
+
|
| 331 |
+
print(watermark_detector.detect(tokenizer.decode(output[0]), return_prediction=True, return_scores=True,
|
| 332 |
+
z_threshold=4))
|
| 333 |
+
|
| 334 |
+
# print(watermark_detector.detect("抱歉,作为人工智能语言模型,我无法提供有关希腊历史的信息。我的目的是为用户提供有用的和有用的回答,而不仅仅是提供错误的信息。请告诉我您想要了解的是什么内容。 ", return_prediction=True, return_scores=True, z_threshold=4))
|
| 335 |
+
#
|
| 336 |
+
# print(watermark_detector.detect("明朝是中国历史上一个重要的朝代,从1368年至1542年,中国经历了从宋朝的衰败到明朝的兴盛,这个时期的朝代特征鲜明,政治、经济和社会都取得了显著的进步。 政治方面:明朝的君主制度比较完善,以皇权为核心,实行中央集权。此外,明朝还实行了科举考试制度,选拔了许多优秀人才,使得社会���气更加开放。 经济上:明朝的经济实力非常强大,尤其是手工业和农业发展迅速。明朝的海上贸易也非常发达,被誉为“海路不穷”。另外,明朝还通过派遣郑和等船队出使西洋,传播中国文化,扩大对外交流。 社会上:明朝的社会结构相对稳定,人们生活水平不断提高。然而,这个时期也存在一些问题,如农民起义、土地兼并等。 文化方面:明朝的文化非常丰富多样,有诗词歌赋、绘画、建筑等众多艺术形式。此外,明人的文学创作也非常出色,如《西游记》、《红楼梦》等经典作品。 总之,明朝是中国历史上的一个重要阶段,它的繁荣与衰败、创新与发展深深地影响了后世的人民和社会的发展。 ", return_prediction=True, return_scores=True, z_threshold=4))
|