| """实现其他 PRF 函数(这些函数的不同之处仅在于如何从上下文中的令牌生成单个哈希值)。 | |
| 可作为修改后的基类 WatermarkBase 挂接到现有的 WatermarkLogitsProcessor 中,请参见 | |
| extended_watermark_processor.py 中的实现。 | |
| """ | |
| import torch | |
| from itertools import combinations | |
| from functools import cache | |
| # 哈希方案的关键属性 | |
| props = { | |
| "prf_type": str, # 基础 PRF 的字符串名称,将多个令牌 ID 映射到随机种子 | |
| "context_width": int, # 这是论文中的 h,每个 PRF 应考虑多少个先前的令牌 | |
| "self_salt": bool, # 根据鲁棒水印技术中的规则,是否使用令牌本身来生成种子,并可能拒绝其自身的列表 | |
| "hash_key": int, # 整数,大质数,用于将种子移动到上述所选 PRF 中的低熵位序列的远离位置 | |
| } | |
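
# For example, the default "lefthash" scheme below resolves these properties to
# {"prf_type": "additive_prf", "context_width": 1, "self_salt": False, "hash_key": 15485863}.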


def seeding_scheme_lookup(seeding_scheme: str):
    if not isinstance(seeding_scheme, str):
        raise ValueError("Seeding scheme should be a string summarizing the procedure.")
    if seeding_scheme == "simple_1" or seeding_scheme == "lefthash":
        # Default simple bigram hash  # alias for ff-additive_prf-1-False-15485863
        prf_type = "additive_prf"
        context_width = 1
        self_salt = False
        hash_key = 15485863
    elif seeding_scheme == "algorithm-3" or seeding_scheme == "selfhash":
        prf_type = "anchored_minhash_prf"
        context_width = 4
        self_salt = True
        hash_key = 15485863
    elif seeding_scheme == "minhash":
        prf_type = "minhash_prf"
        context_width = 4
        self_salt = False
        hash_key = 15485863
    elif seeding_scheme == "skipgram":
        prf_type = "skipgram_prf"
        context_width = 5
        self_salt = False
        hash_key = 15485863
    elif seeding_scheme.startswith("ff"):  # free-form seeding scheme API, for experimental use only
        # expects strings of the form ff-additive_prf-4-True-hash or ff-additive_prf-5-True (the hash key is optional)
        split_scheme = seeding_scheme.split("-")
        prf_type = str(split_scheme[1])
        context_width = int(split_scheme[2])
        self_salt = split_scheme[3] == "True"
        if len(split_scheme) == 5:
            hash_key = int(split_scheme[4])
        else:
            hash_key = 15485863
    else:
        raise ValueError(f"Invalid seeding scheme name {seeding_scheme} given. Try 'simple_1'?")

    assert prf_type in prf_lookup.keys()
    return prf_type, context_width, self_salt, hash_key
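
# Usage sketch; the expected tuples follow directly from the branches above:
# >>> seeding_scheme_lookup("selfhash")
# ('anchored_minhash_prf', 4, True, 15485863)
# >>> seeding_scheme_lookup("ff-additive_prf-2-False")
# ('additive_prf', 2, False, 15485863)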


def multiplicative_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
    return salt_key * input_ids.prod().item()


def additive_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
    return salt_key * input_ids.sum().item()


def minfunc_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
    # not a great idea for non-random input ids such as text
    return salt_key * input_ids.min().item()


def simple_skip_prf(input_ids: torch.LongTensor, salt_key: int, k=2) -> int:
    # k is the skip distance
    return hashint(salt_key * input_ids[::k]).prod().item()


def skipgram_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
    # maximum-distance skipgram within the context: hash only the oldest token
    return hashint(salt_key * input_ids[0]).item()


def anchored_skipgram_prf(input_ids: torch.LongTensor, salt_key: int, anchor: int = -1) -> int:
    # maximum-distance skipgram within the context, anchored to the token at `anchor`
    return (hashint(salt_key * input_ids[0]) * hashint(salt_key * input_ids[anchor])).item()


def minhash_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
    return hashint(salt_key * input_ids).min().item()


def anchored_minhash_prf(input_ids: torch.LongTensor, salt_key: int, anchor: int = -1) -> int:
    # anchor on the token at `anchor` so the min is taken over (token, anchor) pairs
    return (salt_key * hashint(input_ids) * hashint(input_ids[anchor])).min().item()


def minskipgram_prf(input_ids: torch.LongTensor, salt_key: int, k: int = 2) -> int:
    # minimum over all skipgram combinations in the context; k=2 means all pairs
    skipgrams = torch.as_tensor(list(combinations(hashint(salt_key * input_ids), 2)))
    return skipgrams.prod(dim=1).min().item()


def noncomm_prf(input_ids: torch.LongTensor, salt_key: int, k: int = 2) -> int:
    # iterated hashing makes this PRF sensitive to token order, unlike the commutative PRFs above
    key = torch.as_tensor(salt_key, dtype=torch.long)
    for entry in input_ids:
        key *= hashint(key * entry)
        key %= 2**32  # keep the running key bounded to avoid int64 overflow
    return key.item()


def position_prf(input_ids: torch.LongTensor, salt_key: int, k: int = 2) -> int:
    # weights each token by its (1-indexed) position, so this PRF is also order-sensitive
    return (salt_key * input_ids * torch.arange(1, len(input_ids) + 1, device=input_ids.device)).sum().item()


prf_lookup = {
    "multiplicative_prf": multiplicative_prf,
    "additive_prf": additive_prf,
    "minfunc_prf": minfunc_prf,
    "simple_skip_prf": simple_skip_prf,
    "skipgram_prf": skipgram_prf,
    "anchored_skipgram_prf": anchored_skipgram_prf,
    "minhash_prf": minhash_prf,
    "anchored_minhash_prf": anchored_minhash_prf,
    "minskipgram_prf": minskipgram_prf,
    "noncomm_prf": noncomm_prf,
    "position_prf": position_prf,
}

# Generate the global permutation table once at startup
rng = torch.Generator(device=torch.device("cpu"))
rng.manual_seed(2971215073)
table_size = 1_000_003
fixed_table = torch.randperm(1_000_003, device=torch.device("cpu"), generator=rng)  # this is fast


def hashint(integer_tensor: torch.LongTensor) -> torch.LongTensor:
    return fixed_table[integer_tensor.cpu() % table_size] + 1  # minor cheat here: this function always returns CPU values
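
# Note on the lookup above: hashint sends any integer tensor through a fixed
# random permutation, so equal inputs always map to equal outputs while distinct
# inputs (mod table_size) land on unrelated values in [1, table_size]; e.g.
# hashint(torch.tensor([42, 42, 7])) returns identical values at positions 0 and 1.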


def _hashint_avalanche_tensor(integer_tensor: torch.LongTensor):
    # integer avalanche hash via shift/xor mixing, tensor version
    i = integer_tensor.to(torch.int32).clone()  # or torch.int16?
    i -= i << 6
    i ^= i >> 17
    i -= i << 9
    i ^= i << 4
    i -= i << 3
    i ^= i << 10
    i ^= i >> 15
    return i.to(torch.long)


@cache  # memoize, since this variant runs one integer at a time in base Python
def _hashint_avalanche_int(integer: int):
    # integer avalanche hash via shift/xor mixing, plain-int version
    i = integer % (2**32)
    i -= i << 6
    i ^= i >> 17
    i -= i << 9
    i ^= i << 4
    i -= i << 3
    i ^= i << 10
    i ^= i >> 15
    return i
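

# Minimal end-to-end sketch (illustrative only, not the processor's actual call
# site): resolve a scheme, hash a context of token ids, and seed a generator the
# way a watermark processor might. The token ids below are made-up placeholders.
if __name__ == "__main__":
    prf_type, context_width, self_salt, hash_key = seeding_scheme_lookup("minhash")
    context = torch.tensor([5034, 291, 13, 1337, 42])  # hypothetical token ids
    prf_key = prf_lookup[prf_type](context[-context_width:], salt_key=hash_key)
    generator = torch.Generator(device="cpu")
    generator.manual_seed(prf_key % (2**64 - 1))
    print(f"{prf_type=}, {context_width=}, {self_salt=}, {prf_key=}")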