Update custom_st.py (#19)
Browse files- Update custom_st.py (89146e49f8e1cd1cd5231642a803167c3868f443)
- Update custom_st.py (23e2bf96c6f5d8b13f352d4ac29cd153f522ab91)
- Update custom_st.py (7a9a21cc961c2fd89de1c56c616a7c64e9cbcd2a)
Co-authored-by: kosung <[email protected]>
- custom_st.py +3 -1
 
    	
        custom_st.py
    CHANGED
    
    | 
         @@ -1,3 +1,5 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 1 | 
         
             
            from io import BytesIO
         
     | 
| 2 | 
         
             
            from typing import Any, Dict, Optional, List
         
     | 
| 3 | 
         
             
            import torch
         
     | 
| 
         @@ -51,7 +53,7 @@ class MultiModalTransformer(BaseTransformer): 
     | 
|
| 51 | 
         
             
                    self, features: Dict[str, torch.Tensor], **kwargs
         
     | 
| 52 | 
         
             
                ) -> Dict[str, torch.Tensor]:       
         
     | 
| 53 | 
         
             
                    if features.get("inputs_embeds", None) is None:
         
     | 
| 54 | 
         
            -
                        features["inputs_embeds"] = self.auto_model.base_model. 
     | 
| 55 | 
         
             
                        if features.get("pixel_values", None) is not None:
         
     | 
| 56 | 
         
             
                            features["pixel_values"] = features["pixel_values"].type(self.auto_model.visual.get_dtype())
         
     | 
| 57 | 
         
             
                            image_embeds = self.auto_model.visual(
         
     | 
| 
         | 
|
| 1 | 
         
            +
            import math
         
     | 
| 2 | 
         
            +
            import logging
         
     | 
| 3 | 
         
             
            from io import BytesIO
         
     | 
| 4 | 
         
             
            from typing import Any, Dict, Optional, List
         
     | 
| 5 | 
         
             
            import torch
         
     | 
| 
         | 
|
| 53 | 
         
             
                    self, features: Dict[str, torch.Tensor], **kwargs
         
     | 
| 54 | 
         
             
                ) -> Dict[str, torch.Tensor]:       
         
     | 
| 55 | 
         
             
                    if features.get("inputs_embeds", None) is None:
         
     | 
| 56 | 
         
            +
                        features["inputs_embeds"] = self.auto_model.base_model.get_input_embeddings()(features["input_ids"])
         
     | 
| 57 | 
         
             
                        if features.get("pixel_values", None) is not None:
         
     | 
| 58 | 
         
             
                            features["pixel_values"] = features["pixel_values"].type(self.auto_model.visual.get_dtype())
         
     | 
| 59 | 
         
             
                            image_embeds = self.auto_model.visual(
         
     |