fix vision-only inference
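
Moves the device lookup on self.apm out of the unconditional prologue of get_audio_embedding and into the training-time dummy-audio branch (adding the matching dtype lookup there), so that vision-only inference no longer dereferences the audio encoder.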
modeling_minicpmo.py  +4 -3
@@ -484,9 +484,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
         Returns:
             List[List[torch.Tensor]]: audio embeddings
         """
-
-        device = self.apm.embed_positions.weight.device
-
+
         wavforms = data.get("audio_features", [])  # (bs, 80, frames) or [], multi audios need filled in advance
         audio_feature_lens_raw = data.get("audio_feature_lens", [])  # list, [[x1, x2], [y1], [z1]]
 
@@ -547,6 +545,9 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
                 final_audio_embeds.append(target_audio_embeds)
             return final_audio_embeds
         elif self.training and dummy:
+            dtype = self.apm.embed_positions.weight.dtype
+            device = self.apm.embed_positions.weight.device
+
             dummy_wavs = torch.zeros((1, 80, 100), device=device, dtype=dtype)
             audio_states = self.apm(dummy_wavs, output_hidden_states=True).hidden_states[self.audio_encoder_layer]
 
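
In effect, self.apm (the audio encoder) is now dereferenced only inside the branch that actually needs it, so a vision-only forward pass (no audio features supplied and not in training mode) never touches the audio tower. Below is a minimal sketch of the patched control flow; AudioEncoderStub, ModelStub, and the init_audio flag are illustrative assumptions standing in for the real MiniCPM-o classes, not the actual implementation.

import torch
import torch.nn as nn


class AudioEncoderStub(nn.Module):
    """Stands in for self.apm, the audio encoder (illustrative stub)."""

    def __init__(self):
        super().__init__()
        self.embed_positions = nn.Embedding(1500, 512)


class ModelStub(nn.Module):
    def __init__(self, init_audio: bool):
        super().__init__()
        # A vision-only checkpoint never builds the audio tower.
        self.apm = AudioEncoderStub() if init_audio else None

    def get_audio_embedding(self, data: dict, dummy: bool = True):
        wavforms = data.get("audio_features", [])
        if len(wavforms) > 0:
            pass  # real audio path, elided
        elif self.training and dummy:
            # Patched flow: self.apm is dereferenced only here, inside
            # the branch that actually encodes dummy audio.
            dtype = self.apm.embed_positions.weight.dtype
            device = self.apm.embed_positions.weight.device
            dummy_wavs = torch.zeros((1, 80, 100), device=device, dtype=dtype)
            _ = dummy_wavs  # self.apm(dummy_wavs, ...) elided
        # Vision-only inference falls through without touching self.apm.
        return []


model = ModelStub(init_audio=False).eval()
print(model.get_audio_embedding({}))  # prints [], no AttributeError

With the old prologue, a model built without an audio tower (self.apm is None) would raise an AttributeError on the embed_positions lookup before any branch was evaluated; after the fix, the vision-only path simply falls through.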

