support audio finetuning (#22)
support audio finetuning (29a824ea848aa2a72b26706a1c641bc339a51869)
Co-authored-by: Zhangchi Feng <[email protected]>
- modeling_minicpmo.py +14 -1
--- a/modeling_minicpmo.py
+++ b/modeling_minicpmo.py
@@ -466,7 +466,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
         else:
             return []
 
-    def get_audio_embedding(self, data, chunk_length=-1):
+    def get_audio_embedding(self, data, chunk_length=-1, dummy=True):
         r"""
         Extract full audio embeddings with optional chunk-based attention.
 
@@ -484,6 +484,8 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
         Returns:
             List[List[torch.Tensor]]: audio embeddings
         """
+        dtype = self.apm.embed_positions.weight.dtype
+        device = self.apm.embed_positions.weight.device
 
         wavforms = data.get("audio_features", [])  # (bs, 80, frames) or [], multi audios need filled in advance
         audio_feature_lens_raw = data.get("audio_feature_lens", [])  # list, [[x1, x2], [y1], [z1]]
@@ -544,6 +546,17 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
                     idx += 1
                 final_audio_embeds.append(target_audio_embeds)
             return final_audio_embeds
+        elif self.training and dummy:
+            dummy_wavs = torch.zeros((1, 80, 100), device=device, dtype=dtype)
+            audio_states = self.apm(dummy_wavs, output_hidden_states=True).hidden_states[self.audio_encoder_layer]
+
+            audio_embeds = self.audio_projection_layer(audio_states)
+
+            audio_embeds = audio_embeds.transpose(1, 2)
+            audio_embeds = self.audio_avg_pooler(audio_embeds)
+            audio_embeds = audio_embeds.transpose(1, 2)
+            return [audio_embeds]
+
         else:
             return []
 
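A likely motivation for the new `elif self.training and dummy:` branch: during finetuning on mixed data, a batch may contain no audio at all, and under DistributedDataParallel (or similar setups) every trainable parameter is expected to participate in each forward pass, or gradient synchronization can fail. Running a zero waveform through the audio tower keeps `apm`, `audio_projection_layer`, and `audio_avg_pooler` in the autograd graph on text-only steps. The sketch below is a minimal, hypothetical reproduction of that pattern; `ToyAudioTower` and its layer shapes are invented for illustration and are not MiniCPM-o's real modules.

import torch
import torch.nn as nn

class ToyAudioTower(nn.Module):
    """Hypothetical stand-in for an audio encoder plus projection layer."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Conv1d(80, 32, kernel_size=3, padding=1)  # plays the role of the audio encoder
        self.proj = nn.Linear(32, 16)                               # plays the role of the projection layer

    def forward(self, data):
        feats = data.get("audio_features", [])
        if len(feats) > 0:
            hidden = self.encoder(feats)                  # (bs, 32, frames)
            return [self.proj(hidden.transpose(1, 2))]    # (bs, frames, 16)
        elif self.training:
            # Text-only batch: push a zero "mel spectrogram" through the
            # tower so its weights still join the autograd graph this step.
            dummy = torch.zeros(1, 80, 100)
            hidden = self.encoder(dummy)
            return [self.proj(hidden.transpose(1, 2))]
        else:
            return []

tower = ToyAudioTower().train()
out = tower({})        # batch with no audio still yields one dummy embedding
print(out[0].shape)    # torch.Size([1, 100, 16])

A caller would typically scale such dummy embeddings by zero before mixing them into the text hidden states, so the audio parameters receive (zero) gradients without affecting the loss.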