Spaces:
Sleeping
Sleeping
File size: 15,080 Bytes
48097f5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 |
# Original code from ProtMamba under Apache License 2.0.
#
# Modifications made by Niklas Schmidinger, Lisa Schneckenreiter and Sohvi Luukkonen
# - Add option to pass input state for generation
# - Add functions to generate sequences with xlstm
import numpy as np
import torch
from protxlstm.mamba_utils_generation import (
InferenceParams,
GenerationMixin,
GreedySearchDecoderOnlyOutput,
modify_logits_for_top_p_filtering,
modify_logits_for_min_p_filtering,
modify_logit_for_repetition_penalty,
SampleDecoderOnlyOutput,
update_graph_cache
)
from protxlstm.utils import AA_TO_ID, decode_sequence
def sample_safe(logits, top_k=1, top_p=0.0, min_p=0.0, temperature=1.0):
    """Sample one token id per row of ``logits``.

    Strategy, in order of precedence:
      * ``top_k == 1``: greedy decoding (argmax); all other arguments are ignored.
      * ``top_k > 0``: keep the ``top_k`` largest logits, divide them by
        ``temperature``, apply top-p filtering, and sample among the survivors.
      * ``top_k <= 0`` and ``min_p > 0``: min-p filtering, then ``temperature``,
        then sample over the full vocabulary (``top_p`` is not applied here).
      * otherwise: ``temperature``, top-p filtering, then sample over the full
        vocabulary.

    Arguments:
        logits: Tensor of shape (batch_size, vocab_size). Never modified in place.
        top_k: number of candidates to keep (1 = greedy, <= 0 = no limit).
        top_p: nucleus threshold; must be in (0, 1] when positive.
        min_p: min-p threshold; only used when ``top_k <= 0``.
        temperature: logits are divided by this value before sampling.
    Returns:
        Tensor of shape (batch_size,) with the selected token ids.
    """
    if top_k == 1:  # Short-circuit for greedy decoding
        return logits.argmax(dim=-1)
    if top_p > 0.0:
        assert top_p <= 1.0, "top-p should be in (0, 1]."

    if top_k > 0:
        top_k = min(top_k, logits.size(-1))  # Safety check
        logits_top, indices = torch.topk(logits, top_k, dim=-1)
        if temperature != 1.0:
            logits_top /= temperature
        modify_logits_for_top_p_filtering(logits_top, top_p)
        # Sample a position inside the top-k window, then map it back to a vocabulary id.
        return indices[
            torch.arange(indices.shape[0], device=indices.device),
            torch.multinomial(
                torch.softmax(logits_top, dim=-1), num_samples=1
            ).squeeze(dim=-1),
        ]

    if min_p > 0.0:
        # Clone so the in-place min-p filtering does not touch the caller's logits.
        # Works for any batch size (no per-row .item() extraction).
        logits_top = logits.clone()
        modify_logits_for_min_p_filtering(logits_top, min_p)
        if temperature != 1.0:
            logits_top /= temperature
        return torch.multinomial(
            torch.softmax(logits_top, dim=-1), num_samples=1
        ).squeeze(dim=-1)

    # Clone so that when we modify for top_p we don't change the original logits
    logits_top = logits / temperature if temperature != 1.0 else logits.clone()
    modify_logits_for_top_p_filtering(logits_top, top_p)
    return torch.multinomial(
        torch.softmax(logits_top, dim=-1), num_samples=1
    ).squeeze(dim=-1)
@torch.inference_mode()
def decode_safe(
    input_ids,
    position_ids,
    seq_position_ids,
    is_fim,
    model,
    max_length,
    state=None,
    top_k=1,
    top_p=0.0,
    min_p=0.0,
    temperature=1.0,
    repetition_penalty=1.0,
    eos_token_id=None,
    teacher_outputs=None,
    vocab_size=None,
    cg=False,
    enable_timing=False,
    streamer=None,
    chunk_chunk_size=2**15,
):
    """Decoding, either greedy or with top-k, top-p or min-p sampling.

    If top-k = 0, don't limit the number of candidates (pure sampling).
    Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
    then top-p.
    We assume that all sequences in the same batch have the same length.

    For ``xLSTMLMHeadModel`` instances the context is first encoded chunkwise
    (full forward passes over windows of at most ``chunk_chunk_size`` tokens),
    after which tokens are generated one at a time with ``model.step``, carrying
    the recurrent ``state``. Any other model is driven through
    ``inference_params`` (optionally via the graph cache when ``cg`` is True).
    The final recurrent state is not returned.

    Arguments:
        input_ids: (batch, seq_len)
        position_ids: per-token position ids or None. Batched position ids
            (batch > 1) are not supported.
        seq_position_ids: per-token sequence position ids or None.
        is_fim: dictionary with mask indices and associated position indices
            (or an iterable of such dicts, one per batch row). None / False
            mean "no FIM masks".
        model: the language model to decode from.
        max_length: int
        state (optional): initial recurrent state for xLSTM models.
        teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
            logits, the next token is taken from the teacher_outputs. Useful for testing.
        enable_timing: if True, time the call with CUDA events and print the result.
        chunk_chunk_size: maximum number of context tokens per forward call when
            encoding the context with an xLSTM model.
    Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
        sequences: (batch, seq_len) -- the prompt followed by the generated tokens
        scores: tuples of (batch, vocab_size), one entry per decoding step
    Raises:
        NotImplementedError: batched position_ids, or batched eos handling.
    """
    if streamer is not None:
        streamer.put(input_ids.cpu())

    batch_size, seqlen_og = input_ids.shape
    teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
    if cg:
        if not hasattr(model, "_decoding_cache"):
            model._decoding_cache = None
        model._decoding_cache = update_graph_cache(
            model,
            model._decoding_cache,
            batch_size,
            seqlen_og,
            max_length,
        )
        inference_params = model._decoding_cache.inference_params
        inference_params.reset(max_length, batch_size)
    else:
        inference_params = InferenceParams(
            max_seqlen=max_length, max_batch_size=batch_size
        )

    def get_logits(input_ids, position_ids, seq_position_ids, inference_params):
        # Non-xLSTM path: plain forward pass for the prompt (and for every step
        # when cg is off); later steps with cg on go through the decoding cache.
        decoding = inference_params.seqlen_offset > 0
        if not cg or not decoding:
            logits = model(
                input_ids,
                position_ids=position_ids,
                seq_position_ids=seq_position_ids,
                inference_params=inference_params,
                num_last_tokens=1,
            ).logits.squeeze(dim=1)
        else:
            logits = model._decoding_cache.run(
                input_ids,
                position_ids,
                inference_params.seqlen_offset,
                seq_position_ids=seq_position_ids,
            ).squeeze(dim=1)
        return logits[..., :vocab_size] if vocab_size is not None else logits

    def get_xlstm_logits_step(input_ids, position_ids, seq_position_ids, state):
        # Feed tokens one at a time through model.step, threading the state;
        # only the logits produced for the last token are returned.
        if input_ids.shape[1] != 1:
            for i in range(input_ids.shape[1]):
                token_position_ids = (
                    position_ids[:, i:i + 1] if position_ids is not None else None
                )
                token_seq_position_ids = (
                    seq_position_ids[:, i:i + 1] if seq_position_ids is not None else None
                )
                logits, state = model.step(
                    input_ids[:, i:i + 1],
                    state,
                    position_ids=token_position_ids,
                    seq_position_ids=token_seq_position_ids,
                )
        else:
            logits, state = model.step(
                input_ids, state, position_ids=position_ids, seq_position_ids=seq_position_ids
            )
        logits = logits.squeeze(dim=1)
        if vocab_size is not None:
            logits = logits[..., :vocab_size]
        return logits, state

    def get_xlstm_logits_chunkwise(input_ids, position_ids, seq_position_ids, chunk_chunk_size=2**15, state=None):
        # Encode the context with forward passes over consecutive windows of at
        # most chunk_chunk_size tokens, threading the state between windows.
        # Only the logits at the final context position are returned.
        assert model.config.config_dataclass.mlstm_block.mlstm.backend == "chunkwise_variable"
        seq_len = input_ids.shape[1]
        if seq_len == 0:
            raise ValueError("Cannot encode an empty context: input_ids has length 0.")
        for start_idx in range(0, seq_len, chunk_chunk_size):
            end_idx = min(start_idx + chunk_chunk_size, seq_len)
            position_ids_chunk = (
                position_ids[:, start_idx:end_idx] if position_ids is not None else None
            )
            seq_position_ids_chunk = (
                seq_position_ids[:, start_idx:end_idx] if seq_position_ids is not None else None
            )
            outputs = model(
                input_ids[:, start_idx:end_idx],
                position_ids=position_ids_chunk,
                seq_position_ids=seq_position_ids_chunk,
                state=state,
            )
            logits, state = outputs.logits, outputs.state
        logits = logits[:, -1, :]
        logits = logits.squeeze(dim=1)
        if vocab_size is not None:
            logits = logits[..., :vocab_size]
        return logits, state

    def sample_tokens(logits, inference_params):
        # Use the teacher token while teacher_outputs still covers this offset.
        if (
            teacher_outputs is None
            or teacher_output_len <= inference_params.seqlen_offset
        ):
            token = sample_safe(
                logits, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature
            )
        else:
            token = teacher_outputs[:, inference_params.seqlen_offset]
        return token.unsqueeze(1)

    def get_fim_position_id(
        last_position_ids, sampled_tokens, is_fim, repeat_next=False
    ):
        # Next position id: jump to is_fim[token] when the sampled token is a FIM
        # mask, repeat the previous id right after such a jump, else increment.
        # Returns (new_position_ids, repeat_next_flag).
        if isinstance(is_fim, dict):
            val = int(last_position_ids) + 1
            should_repeat_next = False
            if is_fim and int(sampled_tokens) in is_fim:
                val = is_fim[int(sampled_tokens)]
                should_repeat_next = True
            elif repeat_next:
                val = int(last_position_ids)
            return torch.full_like(last_position_ids, fill_value=val), should_repeat_next
        # Per-row dicts: recurse row by row (repeat flag taken from the first row).
        t = [
            get_fim_position_id(last_position_ids_, sampled_tokens_, is_fim_dict, repeat_next)
            for (last_position_ids_, sampled_tokens_, is_fim_dict) in
            zip(last_position_ids, sampled_tokens, is_fim)
        ]
        return torch.stack([t_[0] for t_ in t], dim=0), t[0][1]

    def should_stop(current_token, inference_params):
        if inference_params.seqlen_offset == 0:
            return False
        if eos_token_id is not None and (current_token == eos_token_id).any():
            if current_token.shape[1] > 1:
                raise NotImplementedError("Batched eos_token_id not supported")
            return True
        if inference_params.seqlen_offset >= max_length - 1:
            return True
        return False

    # CUDA events are only created when timing is requested, so plain decoding
    # does not require a CUDA-enabled PyTorch build.
    if enable_timing:
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()

    scores, sequences = [], [input_ids]
    new_position_ids, new_seq_position_ids = [position_ids], [seq_position_ids]
    sequences_cat = input_ids
    repeat_next = False
    if position_ids is not None and position_ids.shape[0] > 1:
        raise NotImplementedError("Batched generation with position_ids not supported")
    if is_fim is None or is_fim is False:
        # get_fim_position_id expects a dict (or an iterable of dicts); the
        # wrappers' defaults None / False mean "no FIM masks".
        is_fim = {}

    # Local import kept from the original (presumably to avoid an import cycle
    # with the model package -- TODO confirm); resolved once per call instead of
    # once per generated token.
    from protxlstm.models.xlstm import xLSTMLMHeadModel
    is_xlstm = isinstance(model, xLSTMLMHeadModel)

    encode_context = True
    while not should_stop(sequences[-1], inference_params):
        if is_xlstm:
            if encode_context:
                # First iteration: encode the whole context chunkwise.
                logits, state = get_xlstm_logits_chunkwise(
                    sequences[-1],
                    new_position_ids[-1],
                    new_seq_position_ids[-1],
                    state=state,
                    chunk_chunk_size=chunk_chunk_size,
                )
                encode_context = False
            else:
                # Later iterations: one recurrent step on the last sampled token.
                logits, state = get_xlstm_logits_step(
                    sequences[-1], new_position_ids[-1], new_seq_position_ids[-1], state=state
                )
        else:
            logits = get_logits(
                sequences[-1], new_position_ids[-1], new_seq_position_ids[-1], inference_params
            )
        scores.append(logits)

        inference_params.seqlen_offset += sequences[-1].shape[1]
        if repetition_penalty == 1.0:
            sampled_tokens = sample_tokens(scores[-1], inference_params)
        else:
            logits = modify_logit_for_repetition_penalty(
                scores[-1].clone(), sequences_cat, repetition_penalty
            )
            sampled_tokens = sample_tokens(logits, inference_params)
            # sequences_cat is only consulted by the repetition penalty.
            sequences_cat = torch.cat([sequences_cat, sampled_tokens], dim=1)
        sequences.append(sampled_tokens)

        # Update position_ids
        if position_ids is not None:
            last_position_ids, repeat_next = get_fim_position_id(
                new_position_ids[-1][:, -1:], sampled_tokens, is_fim, repeat_next
            )
            new_position_ids.append(last_position_ids)
        # Update seq_position_ids (generated tokens keep the last sequence position)
        if seq_position_ids is not None:
            new_seq_position_ids.append(new_seq_position_ids[-1][:, -1:])

        if streamer is not None:
            streamer.put(sampled_tokens.cpu())

    if streamer is not None:
        streamer.end()
    if enable_timing:
        end.record()
        torch.cuda.synchronize()
        print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
    output_cls = (
        GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
    )
    return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))
class GenerationMixinSafe(GenerationMixin):
    """``GenerationMixin`` whose ``generate`` delegates to ``decode_safe``.

    ``generate`` accepts FIM position information (``is_fim``), an optional
    initial recurrent ``state`` and ``chunk_chunk_size``, all of which are
    forwarded to ``decode_safe`` together with the model itself.
    """

    def generate(
        self,
        input_ids,
        position_ids,
        seq_position_ids,
        is_fim=None,
        state=None,
        max_length=1,
        top_k=1,
        top_p=0.0,
        min_p=0.0,
        temperature=1.0,
        return_dict_in_generate=False,
        output_scores=False,
        chunk_chunk_size=2**15,
        **kwargs,
    ):
        """Generate tokens from this model via ``decode_safe``.

        Any extra keyword arguments (for example ``eos_token_id``,
        ``repetition_penalty`` or ``streamer``) are passed to ``decode_safe``
        unchanged.

        Returns:
            The decoder output object when ``return_dict_in_generate`` is truthy
            (its ``scores`` cleared to None unless ``output_scores``), otherwise
            just its ``sequences`` tensor.
        """
        decoded = decode_safe(
            input_ids,
            position_ids,
            seq_position_ids,
            is_fim,
            self,
            max_length,
            state=state,
            chunk_chunk_size=chunk_chunk_size,
            top_k=top_k,
            top_p=top_p,
            min_p=min_p,
            temperature=temperature,
            **kwargs,
        )
        keep_scores = bool(output_scores)
        if not keep_scores:
            decoded.scores = None
        if not return_dict_in_generate:
            return decoded.sequences
        return decoded
def generate_sequence(
    model,
    tokens,
    position_ids=None,
    seq_position_ids=None,
    state=None,
    is_fim=False,
    max_length=2000,
    temperature=1.,
    top_p=0.0,
    top_k=1,
    return_dict_in_generate=False,
    output_scores=False,
    eos_token_id=AA_TO_ID["<cls>"],
    device="cuda",
    chunk_chunk_size=2**15,
):
    """Generating, either greedy or with top-k or top-p sampling.

    If top-k = 0, don't limit the number of candidates (pure sampling).
    Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
    then top-p. We assume that all sequences in the same batch have the same length.

    Arguments:
        model: model exposing ``generate`` (e.g. via GenerationMixinSafe).
        tokens: prompt token ids; moved to ``device`` together with the
            optional ``position_ids`` / ``seq_position_ids``.
        is_fim: FIM mask dict; None / False mean "no FIM masks".
        return_dict_in_generate: kept for backward compatibility and ignored.
            This function always asks ``model.generate`` for the structured
            output, because it reads ``.sequences`` (and ``.scores``) from it;
            with GenerationMixinSafe, False would return a bare tensor instead.
        output_scores: if True, also return the per-step scores.
    Returns:
        dict with "input" / "generated" (decoded strings, one per batch row),
        "input_tokens" / "generated_tokens" (one numpy token array per row) and,
        when ``output_scores``, "scores" as a float32 array of shape
        (batch, steps, vocab) -- assuming each score tensor is (batch, vocab).
    """
    input_ids = tokens.to(device)
    position_ids = position_ids.to(device) if position_ids is not None else None
    seq_position_ids = seq_position_ids.to(device) if seq_position_ids is not None else None
    if is_fim is None or is_fim is False:
        # Downstream FIM position handling expects a dict; empty means "no masks".
        is_fim = {}

    # generate sequence
    out = model.generate(
        input_ids=input_ids,
        position_ids=position_ids,
        seq_position_ids=seq_position_ids,
        is_fim=is_fim,
        state=state,
        max_length=max_length,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        return_dict_in_generate=True,
        output_scores=output_scores,
        eos_token_id=eos_token_id,
        chunk_chunk_size=chunk_chunk_size,
    )

    # Split every row into the prompt part and the newly generated part.
    prompt_len = input_ids.shape[-1]
    all_tokens = out.sequences.cpu().numpy()
    prompt_tokens = all_tokens[:, :prompt_len]
    new_tokens = all_tokens[:, prompt_len:]
    dic = {
        "input": [decode_sequence(seq) for seq in prompt_tokens],
        "generated": [decode_sequence(seq) for seq in new_tokens],
        "input_tokens": list(prompt_tokens),
        "generated_tokens": list(new_tokens),
    }
    if output_scores:
        # Stack per-step scores to (steps, batch, vocab), then put batch first.
        dic["scores"] = np.array(
            [el.to(torch.float32).cpu().numpy() for el in out.scores]
        ).transpose(1, 0, 2)
    return dic
|