Commit cbd1dd1
Parent(s): e464072

:bug: fix past_length in santacoder

Files changed:
- config.json +1 -1
- modeling_gpt2_mq.py +200 -23
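Context for the fix: GPT2CustomModel previously appears to have inherited GPT2Model.forward, which reads past_length from dim -2 of the cached key tensor; the override added below reads past_key_values[0][0].size(-1) instead, because the multi-query cache in modeling_gpt2_mq.py keeps the sequence dimension last. A minimal sketch of the two layouts this implies (shapes are illustrative assumptions, not taken from this repo):

import torch

# Illustrative shapes only; real sizes come from the model config.
batch, num_heads, seq_len, head_dim = 2, 16, 7, 128

# Standard GPT-2 key cache: (batch, num_heads, seq_len, head_dim),
# so past_length = key.size(-2).
gpt2_key_cache = torch.zeros(batch, num_heads, seq_len, head_dim)
assert gpt2_key_cache.size(-2) == seq_len

# Assumed multi-query key cache: one shared key head, sequence dimension
# last, so past_length = key.size(-1), matching the read in this commit.
mq_key_cache = torch.zeros(batch, head_dim, seq_len)
assert mq_key_cache.size(-1) == seq_len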
    	
config.json
@@ -14,7 +14,7 @@
   "eos_token_id": 50256,
   "initializer_range": 0.02,
   "layer_norm_epsilon": 1e-05,
-  "model_type": "
+  "model_type": "santacoder",
   "n_embd": 2048,
   "n_head": 16,
   "n_inner": 8192,
modeling_gpt2_mq.py
@@ -1,39 +1,21 @@
 """PyTorch OpenAI GPT-2 model modified with MultiQuery attention"""
 
 
-import math
-import os
-from dataclasses import dataclass
 from typing import Optional, Tuple, Union
 
 import torch
 import torch.utils.checkpoint
 from torch import nn
 from torch.cuda.amp import autocast
-
-
-from transformers.activations import ACT2FN
-from transformers.modeling_outputs import (
-    BaseModelOutputWithPastAndCrossAttentions,
-    CausalLMOutputWithCrossAttentions,
-    SequenceClassifierOutputWithPast,
-    TokenClassifierOutput,
-)
-from transformers.modeling_utils import PreTrainedModel, SequenceSummary
+
+from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
 from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
 
-from transformers.utils import (
-    ModelOutput,
-    add_code_sample_docstrings,
-    add_start_docstrings,
-    add_start_docstrings_to_model_forward,
-    logging,
-    replace_return_docstrings,
-)
-from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
+from transformers.utils import logging
 from transformers.models.gpt2.modeling_gpt2 import GPT2Model, GPT2Block, GPT2PreTrainedModel, GPT2LMHeadModel
-from .configuration_gpt2_mq import GPT2CustomConfig, MULTI_QUERY
+from .configuration_gpt2_mq import GPT2CustomConfig, MULTI_QUERY
 
+logger = logging.get_logger(__name__)
 
 
 class GPT2MQAttention(nn.Module):
@@ -329,6 +311,201 @@ class GPT2CustomModel(GPT2Model):
         # Initialize weights and apply final processing
         self.post_init()
 
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+            batch_size = input_ids.shape[0]
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+            batch_size = inputs_embeds.shape[0]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+        if token_type_ids is not None:
+            token_type_ids = token_type_ids.view(-1, input_shape[-1])
+        if position_ids is not None:
+            position_ids = position_ids.view(-1, input_shape[-1])
+
+        if past_key_values is None:
+            past_length = 0
+            past_key_values = tuple([None] * len(self.h))
+        else:
+            # this is different from GPT2
+            past_length = past_key_values[0][0].size(-1)
+        if position_ids is None:
+            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
+            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+
+        # GPT2Attention mask.
+        if attention_mask is not None:
+            if batch_size <= 0:
+                raise ValueError("batch_size has to be defined and > 0")
+            attention_mask = attention_mask.view(batch_size, -1)
+            # We create a 3D attention mask from a 2D tensor mask.
+            # Sizes are [batch_size, 1, 1, to_seq_length]
+            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+            # this attention mask is more simple than the triangular masking of causal attention
+            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
+            attention_mask = attention_mask[:, None, None, :]
+
+            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+            # masked positions, this operation will create a tensor which is 0.0 for
+            # positions we want to attend and the dtype's smallest value for masked positions.
+            # Since we are adding it to the raw scores before the softmax, this is
+            # effectively the same as removing these entirely.
+            attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
+            attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
+
+        # If a 2D or 3D attention mask is provided for the cross-attention
+        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+        if self.config.add_cross_attention and encoder_hidden_states is not None:
+            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+            if encoder_attention_mask is None:
+                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+            encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+        else:
+            encoder_attention_mask = None
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # head_mask has shape n_layer x batch x n_heads x N x N
+        head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+
+        if inputs_embeds is None:
+            inputs_embeds = self.wte(input_ids)
+        position_embeds = self.wpe(position_ids)
+        hidden_states = inputs_embeds + position_embeds
+
+        if token_type_ids is not None:
+            token_type_embeds = self.wte(token_type_ids)
+            hidden_states = hidden_states + token_type_embeds
+
+        hidden_states = self.drop(hidden_states)
+
+        output_shape = input_shape + (hidden_states.size(-1),)
+
+        presents = () if use_cache else None
+        all_self_attentions = () if output_attentions else None
+        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+        all_hidden_states = () if output_hidden_states else None
+        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+
+            # Model parallel
+            if self.model_parallel:
+                torch.cuda.set_device(hidden_states.device)
+                # Ensure layer_past is on same device as hidden_states (might not be correct)
+                if layer_past is not None:
+                    layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
+                # Ensure that attention_mask is always on the same device as hidden_states
+                if attention_mask is not None:
+                    attention_mask = attention_mask.to(hidden_states.device)
+                if isinstance(head_mask, torch.Tensor):
+                    head_mask = head_mask.to(hidden_states.device)
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            if self.gradient_checkpointing and self.training:
+
+                if use_cache:
+                    logger.warning(
+                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                    )
+                    use_cache = False
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        # None for past_key_value
+                        return module(*inputs, use_cache, output_attentions)
+
+                    return custom_forward
+
+                outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    None,
+                    attention_mask,
+                    head_mask[i],
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                )
+            else:
+                outputs = block(
+                    hidden_states,
+                    layer_past=layer_past,
+                    attention_mask=attention_mask,
+                    head_mask=head_mask[i],
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
+                    use_cache=use_cache,
+                    output_attentions=output_attentions,
+                )
+
+            hidden_states = outputs[0]
+            if use_cache is True:
+                presents = presents + (outputs[1],)
+
+            if output_attentions:
+                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
+                if self.config.add_cross_attention:
+                    all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
+
+            # Model Parallel: If it's the last layer for that device, put things on the next device
+            if self.model_parallel:
+                for k, v in self.device_map.items():
+                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
+                        hidden_states = hidden_states.to("cuda:" + str(k + 1))
+
+        hidden_states = self.ln_f(hidden_states)
+
+        hidden_states = hidden_states.view(output_shape)
+        # Add last hidden state
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(
+                v
+                for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
+                if v is not None
+            )
+
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=presents,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+            cross_attentions=all_cross_attentions,
+        )
+
 
 class GPT2LMHeadCustomModel(GPT2LMHeadModel):
     config_class = GPT2CustomConfig
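For reference, a hypothetical usage sketch of the path this commit fixes: incremental decoding with use_cache=True, where the position id of each new token is offset by past_length inside GPT2CustomModel.forward. The checkpoint name and trust_remote_code loading are assumptions, not part of this commit.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumed checkpoint; any repo shipping this modeling_gpt2_mq.py would behave the same.
tok = AutoTokenizer.from_pretrained("bigcode/santacoder")
model = AutoModelForCausalLM.from_pretrained("bigcode/santacoder", trust_remote_code=True)
model.eval()

inputs = tok("def hello_world():", return_tensors="pt")
with torch.no_grad():
    # First pass fills the multi-query key/value cache.
    out = model(**inputs, use_cache=True)
    # Second pass feeds a single new token; its position id is derived from past_length.
    next_token = out.logits[:, -1:].argmax(dim=-1)
    step = model(input_ids=next_token, past_key_values=out.past_key_values, use_cache=True)

print(step.logits.shape)  # (1, 1, vocab_size)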

