	Update ootd/inference_ootd_hd.py
ootd/inference_ootd_hd.py CHANGED
@@ -32,7 +32,7 @@ MODEL_PATH = "./checkpoints/ootd"
 class OOTDiffusionHD:
 
     def __init__(self, gpu_id):
-        self.gpu_id = 'cuda:' + str(gpu_id)
+        # self.gpu_id = 'cuda:' + str(gpu_id)
 
         vae = AutoencoderKL.from_pretrained(
             VAE_PATH,
@@ -63,12 +63,12 @@ class OOTDiffusionHD:
             use_safetensors=True,
             safety_checker=None,
             requires_safety_checker=False,
-        ).to(self.gpu_id)
+        )#.to(self.gpu_id)
 
         self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
 
         self.auto_processor = AutoProcessor.from_pretrained(VIT_PATH)
-        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(VIT_PATH).to(self.gpu_id)
+        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(VIT_PATH)#.to(self.gpu_id)
 
         self.tokenizer = CLIPTokenizer.from_pretrained(
             MODEL_PATH,
@@ -77,7 +77,7 @@ class OOTDiffusionHD:
         self.text_encoder = CLIPTextModel.from_pretrained(
             MODEL_PATH,
             subfolder="text_encoder",
-        ).to(self.gpu_id)
+        )#.to(self.gpu_id)
 
 
     def tokenize_captions(self, captions, max_length):
@@ -106,14 +106,14 @@ class OOTDiffusionHD:
         generator = torch.manual_seed(seed)
 
         with torch.no_grad():
-            prompt_image = self.auto_processor(images=image_garm, return_tensors="pt").to(self.gpu_id)
+            prompt_image = self.auto_processor(images=image_garm, return_tensors="pt").to('cuda')
             prompt_image = self.image_encoder(prompt_image.data['pixel_values']).image_embeds
             prompt_image = prompt_image.unsqueeze(1)
             if model_type == 'hd':
-                prompt_embeds = self.text_encoder(self.tokenize_captions([""], 2).to(self.gpu_id))[0]
+                prompt_embeds = self.text_encoder(self.tokenize_captions([""], 2).to('cuda'))[0]
                 prompt_embeds[:, 1:] = prompt_image[:]
             elif model_type == 'dc':
-                prompt_embeds = self.text_encoder(self.tokenize_captions([category], 3).to(self.gpu_id))[0]
+                prompt_embeds = self.text_encoder(self.tokenize_captions([category], 3).to('cuda'))[0]
                 prompt_embeds = torch.cat([prompt_embeds, prompt_image], dim=1)
             else:
                 raise ValueError("model_type must be \'hd\' or \'dc\'!")
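
The edit comments out the eager .to(self.gpu_id) placement in __init__ and hard-codes .to('cuda') at inference time instead. A plausible reason, given that this Space runs on ZeroGPU, is that ZeroGPU only attaches a GPU inside functions decorated with @spaces.GPU, so CUDA must not be initialized while the model is constructed at startup. A minimal sketch of that pattern, assuming the standard spaces package; the wiring around OOTDiffusionHD below is illustrative, not taken from this commit:

import spaces

# Construct on CPU at startup: with the .to(self.gpu_id) calls commented
# out, __init__ never touches CUDA, which ZeroGPU forbids at this point.
model = OOTDiffusionHD(gpu_id=0)

@spaces.GPU
def run_inference(*args, **kwargs):
    # A GPU is attached only while this function runs, so the hard-coded
    # .to('cuda') calls inside the class execute in a legal context.
    return model(*args, **kwargs)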
