Haobo Yuan committed
Commit 976cbed · Parent(s): 1f8df14

bugfix back_toke
seg/models/heads/mask2former_vid.py  CHANGED

@@ -191,14 +191,15 @@ class Mask2FormerVideoHead(AnchorFreeHead):
             _dim = cls_embed.size(2)
             _prototypes = cls_embed.size(1)
 
-            if rank == 0:
-                back_token = torch.zeros(1, _dim, dtype=torch.float32, device='cuda')
-                # back_token = back_token / back_token.norm(p=2, dim=-1, keepdim=True)
-            else:
-                back_token = torch.empty(1, _dim, dtype=torch.float32, device='cuda')
-            if world_size > 1:
-                dist.broadcast(back_token, src=0)
-            back_token = back_token.to(device='cpu')
+            # if rank == 0:
+            #     back_token = torch.zeros(1, _dim, dtype=torch.float32, device='cuda')
+            #     # back_token = back_token / back_token.norm(p=2, dim=-1, keepdim=True)
+            # else:
+            #     back_token = torch.empty(1, _dim, dtype=torch.float32, device='cuda')
+            # if world_size > 1:
+            #     dist.broadcast(back_token, src=0)
+            # back_token = back_token.to(device='cpu')
+            back_token = torch.zeros(1, _dim, dtype=torch.float32, device='cpu')
             cls_embed = torch.cat([
                 cls_embed, back_token.repeat(_prototypes, 1)[None]
             ], dim=0)
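The change drops the distributed hand-off: the old code zero-initialized `back_token` on rank 0, broadcast it to the other ranks, and moved it to CPU, which assumes a CUDA device and an initialized process group and so breaks on a CPU-only, single-process host. Since rank 0 zero-initializes the token anyway, every rank ends up with zeros after the broadcast, so the one-line CPU replacement is equivalent. Below is a minimal sketch that keeps both paths behind a guard; the `make_back_token` helper is illustrative and not part of the repository:

import torch
import torch.distributed as dist

def make_back_token(_dim: int) -> torch.Tensor:
    # Distributed path: mirrors the pattern the commit removed, but only
    # when a process group actually exists.
    if dist.is_available() and dist.is_initialized():
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        if dist.get_rank() == 0:
            back_token = torch.zeros(1, _dim, dtype=torch.float32, device=device)
        else:
            back_token = torch.empty(1, _dim, dtype=torch.float32, device=device)
        if dist.get_world_size() > 1:
            # Sync rank 0's (all-zero) token to every rank.
            dist.broadcast(back_token, src=0)
        return back_token.to(device='cpu')
    # Single-process / CPU path taken by the fix: zeros need no broadcast.
    return torch.zeros(1, _dim, dtype=torch.float32, device='cpu')

Guarding on `dist.is_initialized()` would let the same code serve both the multi-GPU training setup and a CPU-only demo; the commit instead hardcodes the CPU path, which is sufficient for inference in this Space.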
