sync with the latest official code
modeling_qwen.py  +5 -7

modeling_qwen.py  CHANGED
@@ -520,11 +520,9 @@ class QWenAttention(nn.Module):
 
             if not self.use_cache_quantization and SUPPORT_TORCH2:
                 if attention_mask is not None:
-                    attention_mask = attention_mask.expand(
-                        -1, -1, causal_mask.size(2), -1
-                    )
+                    attention_mask = attention_mask.expand(-1, -1, query.size(2), -1)
                     if causal_mask is not None:
-                        attention_mask.masked_fill_(~causal_mask, torch.finfo(query.dtype).min)
+                        attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min)
                 else:
                     attention_mask = causal_mask
                 attn_output = F.scaled_dot_product_attention(
@@ -1330,14 +1328,14 @@ def apply_rotary_pos_emb(t, freqs):
       t (tensor(batch_size, seq_len, n_head, head_dim)):
         the input embedding/hidden states
       freqs (list[tensor(1, seq_len, 1, rotary_dim), tensor(1, seq_len, 1, rotary_dim)]):
-        the cached cos/sin position embeddings
+        the cached cos/sin position embeddings
     """
     rot_dim = freqs[0].shape[-1]
     cos, sin = freqs
     t_float = t.float()
     if apply_rotary_emb_func is not None and t.is_cuda:
-        # apply_rotary_emb in flash_attn requires cos/sin to be of
-        # shape (seqlen, rotary_dim / 2) and apply rotary embedding
+        # apply_rotary_emb in flash_attn requires cos/sin to be of
+        # shape (seqlen, rotary_dim / 2) and apply rotary embedding
         # to the first rotary_dim of the input
         cos = cos.squeeze(0).squeeze(1)[:, : rot_dim // 2]
         sin = sin.squeeze(0).squeeze(1)[:, : rot_dim // 2]
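Note on the first hunk: expand() only produces a broadcast view of the padding mask, so the causal constraint has to be folded in with the out-of-place masked_fill and the result reassigned (an in-place masked_fill_ on the expanded view raises an error), and using query.size(2) instead of causal_mask.size(2) keeps the expand valid even when causal_mask is None. The sketch below is a minimal, standalone illustration of that pattern; the shapes and the toy padding/causal masks are made up for the example, and only the expand / masked_fill / F.scaled_dot_product_attention calls mirror the diff.

    import torch
    import torch.nn.functional as F

    # Illustrative shapes, not taken from the model config.
    batch, n_head, q_len, k_len, head_dim = 2, 4, 5, 5, 8
    query = torch.randn(batch, n_head, q_len, head_dim)
    key = torch.randn(batch, n_head, k_len, head_dim)
    value = torch.randn(batch, n_head, k_len, head_dim)

    # Additive padding mask: 0 where attention is allowed, a large negative
    # value where the key position is padding (last key of sample 0 here).
    pad_mask = torch.zeros(batch, 1, 1, k_len)
    pad_mask[0, ..., -1] = torch.finfo(query.dtype).min

    # Boolean lower-triangular causal mask, True = keep.
    causal_mask = torch.ones(1, 1, q_len, k_len, dtype=torch.bool).tril()

    # expand() returns a broadcast view, so write back the result of the
    # out-of-place masked_fill instead of calling masked_fill_ in place.
    attention_mask = pad_mask.expand(-1, -1, query.size(2), -1)
    attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min)

    attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
    print(attn_output.shape)  # torch.Size([2, 4, 5, 8])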


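Note on the second hunk: apply_rotary_pos_emb takes freqs as a cached [cos, sin] pair, each of shape (1, seq_len, 1, rotary_dim); the CUDA path squeezes and slices them to (seq_len, rotary_dim / 2) for flash_attn's apply_rotary_emb, as the restored comments describe, while the fallback path rotates only the first rotary_dim channels in plain PyTorch. The snippet below is a rough, self-contained sketch of that fallback path; the way the cos/sin cache is built here is illustrative, and only the input shapes and the rotate-half arithmetic follow the file.

    import torch

    def _rotate_half(x):
        # Split the last dim in half and rotate: (x1, x2) -> (-x2, x1).
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb_fallback(t, freqs):
        # t:     (batch_size, seq_len, n_head, head_dim)
        # freqs: [cos, sin], each of shape (1, seq_len, 1, rotary_dim)
        rot_dim = freqs[0].shape[-1]
        cos, sin = freqs
        t_float = t.float()
        # Rotate the first rot_dim channels, pass the rest through unchanged.
        t_rot, t_pass = t_float[..., :rot_dim], t_float[..., rot_dim:]
        t_rot = (t_rot * cos) + (_rotate_half(t_rot) * sin)
        return torch.cat((t_rot, t_pass), dim=-1).type_as(t)

    # Toy cos/sin cache matching the docstring shapes (rotary_dim == head_dim).
    seq_len, head_dim = 6, 8
    inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = torch.outer(torch.arange(seq_len).float(), inv_freq)   # (seq_len, head_dim / 2)
    emb = torch.cat((angles, angles), dim=-1)[None, :, None, :]     # (1, seq_len, 1, head_dim)
    freqs = [emb.cos(), emb.sin()]

    t = torch.randn(2, seq_len, 4, head_dim)
    print(apply_rotary_pos_emb_fallback(t, freqs).shape)  # torch.Size([2, 6, 4, 8])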