Spaces:
Running
Running
| from torch_grammar import GrammarSampler | |
| from transformers.generation.logits_process import LogitsProcessor | |
| from modules import shared | |
# Module-level cache used by GrammarLogitsProcessor (it rebinds these via
# `global`), so the state is shared across all processor instances.
#   grammar_string: the grammar text the cache was last refreshed for
#   sampler:        the GrammarSampler built from that text, or None
#   grammar:        the callable obtained from sampler.logits_processor(), or None
grammar_string = ''
sampler = grammar = None
class GrammarLogitsProcessor(LogitsProcessor):
    """Logits processor that applies a grammar built from a grammar string.

    The ``GrammarSampler`` is cached in the module-level globals ``sampler``
    and ``grammar_string`` so that constructing a processor again with the
    same grammar text does not rebuild it. The active callable is stored in
    the module-level ``grammar`` and is therefore shared by all instances.
    """

    def __init__(self, string):
        """Refresh the module-level grammar state for ``string``.

        Args:
            string: Grammar source text. An empty or whitespace-only value
                disables the grammar (``sampler`` and ``grammar`` become
                ``None``).

        Raises:
            Any exception raised by ``GrammarSampler`` while building the
            grammar. In that case the cached ``sampler``/``grammar_string``
            pair is left exactly as it was.
        """
        global sampler, grammar, grammar_string

        if string != grammar_string:
            stripped = string.strip()
            # Build the sampler BEFORE committing the cache key. If the
            # constructor raises, the key must not already name the new
            # string, otherwise a retry with the same text would skip the
            # rebuild and silently reuse the previous grammar's sampler.
            if stripped:
                # The grammar text is passed with exactly one trailing newline.
                new_sampler = GrammarSampler(stripped + '\n', 'root', shared.tokenizer)
            else:
                new_sampler = None

            sampler = new_sampler
            grammar_string = string

        # A fresh callable is requested from the sampler on every construction.
        # NOTE(review): presumably this resets per-generation parse state —
        # confirm against torch_grammar.
        grammar = sampler.logits_processor() if sampler is not None else None

    def __call__(self, input_ids, scores):
        """Return ``scores`` after passing them through the active grammar.

        Args:
            input_ids: Token ids generated so far, forwarded unchanged to the
                grammar callable.
            scores: Next-token scores, forwarded to the grammar callable.

        Returns:
            The value returned by the grammar callable, or ``scores``
            unchanged when no grammar is active.
        """
        if grammar is None:
            return scores

        return grammar(input_ids, scores)