Spaces:
				
			
			
	
			
			
		Running
		
			on 
			
			Zero
	
	
	
			
			
	
	
	
	
		
		
		Running
		
			on 
			
			Zero
	Update model_helper.py
Browse files- model_helper.py +3 -3
 
    	
        model_helper.py
    CHANGED
    
    | 
         @@ -22,7 +22,7 @@ from model.ymt3 import YourMT3 
     | 
|
| 22 | 
         | 
| 23 | 
         | 
| 24 | 
         | 
| 25 | 
         
            -
            def load_model_checkpoint(args=None):
         
     | 
| 26 | 
         
             
                parser = argparse.ArgumentParser(description="YourMT3")
         
     | 
| 27 | 
         
             
                # General
         
     | 
| 28 | 
         
             
                parser.add_argument('exp_id', type=str, help='A unique identifier for the experiment is used to resume training. The "@" symbol can be used to load a specific checkpoint.')
         
     | 
| 
         @@ -104,7 +104,7 @@ def load_model_checkpoint(args=None): 
     | 
|
| 104 | 
         
             
                print(f"Task: {tm.task_name}, Max Shift Steps: {tm.max_shift_steps}")
         
     | 
| 105 | 
         | 
| 106 | 
         
             
                # Use GPU if available
         
     | 
| 107 | 
         
            -
                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         
     | 
| 108 | 
         | 
| 109 | 
         
             
                # Model
         
     | 
| 110 | 
         
             
                model = YourMT3(
         
     | 
| 
         @@ -120,7 +120,7 @@ def load_model_checkpoint(args=None): 
     | 
|
| 120 | 
         
             
                state_dict = checkpoint['state_dict']
         
     | 
| 121 | 
         
             
                new_state_dict = {k: v for k, v in state_dict.items() if 'pitchshift' not in k}
         
     | 
| 122 | 
         
             
                model.load_state_dict(new_state_dict, strict=False)
         
     | 
| 123 | 
         
            -
                return model.eval()
         
     | 
| 124 | 
         | 
| 125 | 
         | 
| 126 | 
         
             
            def transcribe(model, audio_info):
         
     | 
| 
         | 
|
| 22 | 
         | 
| 23 | 
         | 
| 24 | 
         | 
| 25 | 
         
            +
            def load_model_checkpoint(args=None, device='cpu'):
         
     | 
| 26 | 
         
             
                parser = argparse.ArgumentParser(description="YourMT3")
         
     | 
| 27 | 
         
             
                # General
         
     | 
| 28 | 
         
             
                parser.add_argument('exp_id', type=str, help='A unique identifier for the experiment is used to resume training. The "@" symbol can be used to load a specific checkpoint.')
         
     | 
| 
         | 
|
| 104 | 
         
             
                print(f"Task: {tm.task_name}, Max Shift Steps: {tm.max_shift_steps}")
         
     | 
| 105 | 
         | 
| 106 | 
         
             
                # Use GPU if available
         
     | 
| 107 | 
         
            +
                # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         
     | 
| 108 | 
         | 
| 109 | 
         
             
                # Model
         
     | 
| 110 | 
         
             
                model = YourMT3(
         
     | 
| 
         | 
|
| 120 | 
         
             
                state_dict = checkpoint['state_dict']
         
     | 
| 121 | 
         
             
                new_state_dict = {k: v for k, v in state_dict.items() if 'pitchshift' not in k}
         
     | 
| 122 | 
         
             
                model.load_state_dict(new_state_dict, strict=False)
         
     | 
| 123 | 
         
            +
                return model.eval() # load checkpoint on cpu first
         
     | 
| 124 | 
         | 
| 125 | 
         | 
| 126 | 
         
             
            def transcribe(model, audio_info):
         
     |