# Original code from ProtMamba under Apache License 2.0.
#
# Modifications made by Niklas Schmidinger, Lisa Schneckenreiter and Sohvi Luukkonen
# - Add option to pass input state for generation
# - Add functions to generate sequences with xlstm
import numpy as np
import torch
from protxlstm.mamba_utils_generation import (
InferenceParams,
GenerationMixin,
GreedySearchDecoderOnlyOutput,
modify_logits_for_top_p_filtering,
modify_logits_for_min_p_filtering,
modify_logit_for_repetition_penalty,
SampleDecoderOnlyOutput,
update_graph_cache
)
from protxlstm.utils import AA_TO_ID, decode_sequence
def sample_safe(logits, top_k=1, top_p=0.0, min_p=0.0, temperature=1.0):
    """Sample token ids from logits, optionally restricted by top-k, top-p (nucleus)
    or min-p filtering, with temperature scaling.

    Arguments:
        logits: Tensor of shape (batch_size, vocab_size)
    """
if top_k == 1: # Short-circuit for greedy decoding
return logits.argmax(dim=-1)
else:
if top_p > 0.0:
assert top_p <= 1.0, "top-p should be in (0, 1]."
if top_k > 0:
top_k = min(top_k, logits.size(-1)) # Safety check
logits_top, indices = torch.topk(logits, top_k, dim=-1)
if temperature != 1.0:
logits_top /= temperature
modify_logits_for_top_p_filtering(logits_top, top_p)
return indices[
torch.arange(indices.shape[0], device=indices.device),
torch.multinomial(
torch.softmax(logits_top, dim=-1), num_samples=1
).squeeze(dim=-1),
]
else:
            if min_p > 0.0:
                logits_top = logits.clone()
                max_prob = logits_top[..., 0].item()
                min_prob = max_prob * min_p
                # Filter with the scaled min-p threshold before sampling.
                modify_logits_for_min_p_filtering(logits_top, min_prob)
if temperature != 1.0:
logits_top /= temperature
return torch.multinomial(
torch.softmax(logits_top, dim=-1), num_samples=1
).squeeze(dim=-1)
# Clone so that when we modify for top_p we don't change the original logits
logits_top = logits / temperature if temperature != 1.0 else logits.clone()
modify_logits_for_top_p_filtering(logits_top, top_p)
return torch.multinomial(
torch.softmax(logits_top, dim=-1), num_samples=1
).squeeze(dim=-1)
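
# Example usage (illustrative sketch): draw one token id per batch element from
# random logits, combining top-k with nucleus (top-p) filtering:
#
#     logits = torch.randn(4, 40)                  # (batch_size, vocab_size)
#     greedy = sample_safe(logits)                 # top_k=1 -> plain argmax
#     sampled = sample_safe(logits, top_k=10, top_p=0.9, temperature=0.8)
#     assert greedy.shape == sampled.shape == (4,)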
@torch.inference_mode()
def decode_safe(
input_ids,
position_ids,
seq_position_ids,
is_fim,
model,
max_length,
state=None,
top_k=1,
top_p=0.0,
min_p=0.0,
temperature=1.0,
repetition_penalty=1.0,
eos_token_id=None,
teacher_outputs=None,
vocab_size=None,
cg=False,
enable_timing=False,
    streamer=None,
    chunk_chunk_size=2**15,
):
"""Decoding, either greedy or with top-k or top-p sampling.
If top-k = 0, don't limit the number of candidates (pure sampling).
Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
then top-p.
We assume that all sequences in the same batch have the same length.
Arguments:
input_ids: (batch, seq_len)
max_length: int
is_fim: dictionary with mask indices and associated position indices
teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
logits, the next token is taken from the teacher_outputs. Useful for testing.
Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
sequences: (batch, max_length)
scores: tuples of (batch, vocab_size)
"""
if streamer is not None:
streamer.put(input_ids.cpu())
batch_size, seqlen_og = input_ids.shape
teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
if cg:
if not hasattr(model, "_decoding_cache"):
model._decoding_cache = None
model._decoding_cache = update_graph_cache(
model,
model._decoding_cache,
batch_size,
seqlen_og,
max_length,
)
inference_params = model._decoding_cache.inference_params
inference_params.reset(max_length, batch_size)
else:
inference_params = InferenceParams(
max_seqlen=max_length, max_batch_size=batch_size
)
def get_logits(input_ids, position_ids, seq_position_ids, inference_params):
decoding = inference_params.seqlen_offset > 0
if not cg or not decoding:
logits = model(
input_ids,
position_ids=position_ids,
seq_position_ids=seq_position_ids,
inference_params=inference_params,
num_last_tokens=1,
).logits.squeeze(dim=1)
else:
logits = model._decoding_cache.run(
input_ids,
position_ids,
inference_params.seqlen_offset,
seq_position_ids=seq_position_ids,
).squeeze(dim=1)
return logits[..., :vocab_size] if vocab_size is not None else logits
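
    # When cg=True, single-token decode steps replay a captured CUDA graph via
    # model._decoding_cache.run (see update_graph_cache), which avoids per-step
    # kernel-launch overhead; the initial prompt pass always uses the regular forward.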
    def get_xlstm_logits_step(input_ids, position_ids, seq_position_ids, state):
        if input_ids.shape[1] != 1:
            # Several new tokens: feed them through the recurrent step one at a time.
            for i in range(input_ids.shape[1]):
                token_position_ids = position_ids[:, i:(i + 1)] if position_ids is not None else None
                token_seq_position_ids = seq_position_ids[:, i:(i + 1)] if seq_position_ids is not None else None
                logits, state = model.step(
                    input_ids[:, i:(i + 1)],
                    state,
                    position_ids=token_position_ids,
                    seq_position_ids=token_seq_position_ids,
                )
        else:
            logits, state = model.step(
                input_ids, state, position_ids=position_ids, seq_position_ids=seq_position_ids
            )
logits = logits.squeeze(dim=1)
if vocab_size is not None:
logits = logits[..., :vocab_size]
return logits, state
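
    # The step API assumed here: model.step consumes one token of shape (batch, 1)
    # together with the recurrent state and returns (logits, new_state), making
    # decoding O(1) per token instead of re-encoding the whole prefix, e.g.
    #
    #     logits, state = model.step(next_token, state)   # next_token: (batch, 1)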
    def get_xlstm_logits_chunkwise(input_ids, position_ids, seq_position_ids, chunk_chunk_size=2**15, state=None):
        assert model.config.config_dataclass.mlstm_block.mlstm.backend == "chunkwise_variable", \
            "Chunkwise context encoding requires the 'chunkwise_variable' mLSTM backend."
        # Encode the context in chunks of `chunk_chunk_size` tokens, threading the
        # recurrent state from one chunk into the next.
        for start_idx in range(0, input_ids.shape[1], chunk_chunk_size):
            end_idx = min(start_idx + chunk_chunk_size, input_ids.shape[1])
            input_ids_chunk = input_ids[:, start_idx:end_idx]
            position_ids_chunk = position_ids[:, start_idx:end_idx] if position_ids is not None else None
            seq_position_ids_chunk = seq_position_ids[:, start_idx:end_idx] if seq_position_ids is not None else None
            outputs = model(
                input_ids_chunk,
                position_ids=position_ids_chunk,
                seq_position_ids=seq_position_ids_chunk,
                state=state,
            )
            logits, state = outputs.logits, outputs.state
        # Keep only the logits of the last context token.
        logits = logits[:, -1, :]
if vocab_size is not None:
logits = logits[..., :vocab_size]
return logits, state
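
    # Illustration of the chunking: with a 70,000-token context and the default
    # chunk_chunk_size of 2**15 = 32,768, the prompt is encoded in three passes,
    # [0, 32768), [32768, 65536) and [65536, 70000), with the recurrent state
    # carried across passes.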
def sample_tokens(logits, inference_params):
if (
teacher_outputs is None
or teacher_output_len <= inference_params.seqlen_offset
):
token = sample_safe(
logits, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature
)
else:
token = teacher_outputs[:, inference_params.seqlen_offset]
# return rearrange(token, "b -> b 1")
return token.unsqueeze(1)
def get_fim_position_id(
last_position_ids, sampled_tokens, is_fim, repeat_next=False
):
        if isinstance(is_fim, dict):
val = int(last_position_ids) + 1
should_repeat_next = False
if is_fim and int(sampled_tokens) in is_fim:
val = is_fim[int(sampled_tokens)]
should_repeat_next = True
elif repeat_next:
val = int(last_position_ids)
return torch.full_like(last_position_ids, fill_value=val), should_repeat_next
else:
t = [get_fim_position_id(last_position_ids_, sampled_tokens_, is_fim_dict, repeat_next) for
(last_position_ids_, sampled_tokens_, is_fim_dict) in
zip(last_position_ids, sampled_tokens, is_fim)]
            # Note: the repeat flag is taken from the first batch element only.
            return torch.stack([t_[0] for t_ in t], dim=0), t[0][1]
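
    # Worked example (token ids illustrative): with is_fim = {33: 103} and a last
    # position id of 56, ordinary tokens advance the position to 57, 58, ...;
    # sampling token 33 jumps the next position id to 103 and sets the repeat flag,
    # so the step after that re-emits 103 before resuming with 104, 105, ...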
def should_stop(current_token, inference_params):
if inference_params.seqlen_offset == 0:
return False
if eos_token_id is not None and (current_token == eos_token_id).any():
if current_token.shape[1] > 1:
raise NotImplementedError("Batched eos_token_id not supported")
return True
if inference_params.seqlen_offset >= max_length - 1:
return True
return False
start = torch.cuda.Event(enable_timing=enable_timing)
end = torch.cuda.Event(enable_timing=enable_timing)
if enable_timing:
start.record()
scores, sequences = [], [input_ids]
new_position_ids, new_seq_position_ids = [position_ids], [seq_position_ids]
sequences_cat = input_ids
    repeat_next = False
    if position_ids is not None and position_ids.shape[0] > 1:
        raise NotImplementedError("Batched generation with position_ids not supported")
    encode_context = True
    # Local import; resolved once before the decoding loop.
    from protxlstm.models.xlstm import xLSTMLMHeadModel
    while not should_stop(sequences[-1], inference_params):
if isinstance(model, xLSTMLMHeadModel):
if encode_context:
with torch.no_grad():
logits, state = get_xlstm_logits_chunkwise(sequences[-1], new_position_ids[-1], new_seq_position_ids[-1], state=state, chunk_chunk_size=chunk_chunk_size)
encode_context = False
else:
logits, state = get_xlstm_logits_step(sequences[-1], new_position_ids[-1], new_seq_position_ids[-1], state=state)
else:
logits = get_logits(sequences[-1], new_position_ids[-1], new_seq_position_ids[-1], inference_params)
scores.append(logits)
inference_params.seqlen_offset += sequences[-1].shape[1]
if repetition_penalty == 1.0:
sampled_tokens = sample_tokens(scores[-1], inference_params)
else:
logits = modify_logit_for_repetition_penalty(
scores[-1].clone(), sequences_cat, repetition_penalty
)
sampled_tokens = sample_tokens(logits, inference_params)
sequences_cat = torch.cat([sequences_cat, sampled_tokens], dim=1)
sequences.append(sampled_tokens)
# Update position_ids
if position_ids is not None:
last_position_ids, repeat_next = get_fim_position_id(
new_position_ids[-1][:, -1:], sampled_tokens, is_fim, repeat_next
)
new_position_ids.append(last_position_ids)
# Update seq_position_ids
if seq_position_ids is not None:
new_seq_position_ids.append(new_seq_position_ids[-1][:, -1:])
if streamer is not None:
streamer.put(sampled_tokens.cpu())
if streamer is not None:
streamer.end()
if enable_timing:
end.record()
torch.cuda.synchronize()
print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
output_cls = (
GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
)
return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))
class GenerationMixinSafe(GenerationMixin):
def generate(
self,
input_ids,
position_ids,
seq_position_ids,
is_fim=None,
state=None,
max_length=1,
top_k=1,
top_p=0.0,
min_p=0.0,
temperature=1.0,
return_dict_in_generate=False,
output_scores=False,
chunk_chunk_size=2**15,
**kwargs,
):
output = decode_safe(
input_ids,
position_ids,
seq_position_ids,
is_fim,
self,
max_length,
state=state,
top_k=top_k,
top_p=top_p,
min_p=min_p,
temperature=temperature,
chunk_chunk_size=chunk_chunk_size,
**kwargs,
)
if not output_scores:
output.scores = None
return output if return_dict_in_generate else output.sequences
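
# Classes that mix in GenerationMixinSafe gain a `generate` that routes through
# decode_safe above. A minimal sketch (the class name is hypothetical; the forward
# signature mirrors what get_logits in decode_safe expects):
#
#     class MyHeadModel(torch.nn.Module, GenerationMixinSafe):
#         def forward(self, input_ids, position_ids=None, seq_position_ids=None,
#                     inference_params=None, num_last_tokens=0):
#             ...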
def generate_sequence(model, tokens, position_ids=None, seq_position_ids=None, state=None,
                      is_fim=False, max_length=2000, temperature=1.0, top_p=0.0, top_k=1,
                      return_dict_in_generate=False, output_scores=False,
                      eos_token_id=AA_TO_ID["<cls>"], device="cuda", chunk_chunk_size=2**15):
"""Generating, either greedy or with top-k or top-p sampling.
If top-k = 0, don't limit the number of candidates (pure sampling).
Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
then top-p. We assume that all sequences in the same batch have the same length.
"""
input_ids = tokens.to(device)
position_ids = position_ids.to(device) if position_ids is not None else None
seq_position_ids = seq_position_ids.to(device) if seq_position_ids is not None else None
# generate sequence
out = model.generate(input_ids=input_ids,
position_ids=position_ids,
seq_position_ids=seq_position_ids,
is_fim=is_fim,
state=state,
max_length=max_length,
temperature=temperature,
top_p=top_p,
top_k=top_k,
return_dict_in_generate=return_dict_in_generate,
output_scores=output_scores,
eos_token_id=eos_token_id,
chunk_chunk_size=chunk_chunk_size,
)
    # model.generate returns only the sequences tensor unless return_dict_in_generate is set.
    sequences = out.sequences if return_dict_in_generate else out
dic = {"input": [decode_sequence(seq) for seq in sequences[:, :input_ids.shape[-1]].cpu().numpy()],
"generated": [decode_sequence(seq) for seq in sequences[:, input_ids.shape[-1]:].cpu().numpy()],
"input_tokens": [seq for seq in sequences[:, :input_ids.shape[-1]].cpu().numpy()],
"generated_tokens": [seq for seq in sequences[:, input_ids.shape[-1]:].cpu().numpy()]}
    if output_scores:
        assert return_dict_in_generate, "output_scores=True requires return_dict_in_generate=True"
        dic["scores"] = np.array([el.to(torch.float32).cpu().numpy() for el in out.scores]).transpose(1, 0, 2)
return dic
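
# Example call (illustrative; assumes a loaded model on `device`): generate up to
# 500 tokens with combined top-k / nucleus sampling and decode the result back to
# amino-acid strings.
#
#     tokens = ...  # (1, prompt_len) LongTensor of AA_TO_ID token ids
#     out = generate_sequence(model, tokens, max_length=500, top_k=10, top_p=0.9)
#     print(out["generated"][0])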