Allow changing LoRA scaling alpha
Browse files- app_inference.py +6 -0
- inference.py +8 -4
    	
        app_inference.py
    CHANGED
    
    | @@ -99,6 +99,11 @@ def create_inference_demo(pipe: InferencePipeline, | |
| 99 | 
             
                                max_lines=1,
         | 
| 100 | 
             
                                placeholder='Example: "A picture of a sks dog in a bucket"'
         | 
| 101 | 
             
                            )
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 102 | 
             
                            seed = gr.Slider(label='Seed',
         | 
| 103 | 
             
                                             minimum=0,
         | 
| 104 | 
             
                                             maximum=100000,
         | 
| @@ -149,6 +154,7 @@ def create_inference_demo(pipe: InferencePipeline, | |
| 149 | 
             
                    inputs = [
         | 
| 150 | 
             
                        lora_model_id,
         | 
| 151 | 
             
                        prompt,
         | 
|  | |
| 152 | 
             
                        seed,
         | 
| 153 | 
             
                        num_steps,
         | 
| 154 | 
             
                        guidance_scale,
         | 
|  | |
| 99 | 
             
                                max_lines=1,
         | 
| 100 | 
             
                                placeholder='Example: "A picture of a sks dog in a bucket"'
         | 
| 101 | 
             
                            )
         | 
| 102 | 
            +
                            alpha = gr.Slider(label='LoRA alpha',
         | 
| 103 | 
            +
                                              minimum=0,
         | 
| 104 | 
            +
                                              maximum=2,
         | 
| 105 | 
            +
                                              step=0.05,
         | 
| 106 | 
            +
                                              value=1)
         | 
| 107 | 
             
                            seed = gr.Slider(label='Seed',
         | 
| 108 | 
             
                                             minimum=0,
         | 
| 109 | 
             
                                             maximum=100000,
         | 
|  | |
| 154 | 
             
                    inputs = [
         | 
| 155 | 
             
                        lora_model_id,
         | 
| 156 | 
             
                        prompt,
         | 
| 157 | 
            +
                        alpha,
         | 
| 158 | 
             
                        seed,
         | 
| 159 | 
             
                        num_steps,
         | 
| 160 | 
             
                        guidance_scale,
         | 
    	
        inference.py
    CHANGED
    
    | @@ -73,6 +73,7 @@ class InferencePipeline: | |
| 73 | 
             
                    self,
         | 
| 74 | 
             
                    lora_model_id: str,
         | 
| 75 | 
             
                    prompt: str,
         | 
|  | |
| 76 | 
             
                    seed: int,
         | 
| 77 | 
             
                    n_steps: int,
         | 
| 78 | 
             
                    guidance_scale: float,
         | 
| @@ -83,8 +84,11 @@ class InferencePipeline: | |
| 83 | 
             
                    self.load_pipe(lora_model_id)
         | 
| 84 |  | 
| 85 | 
             
                    generator = torch.Generator(device=self.device).manual_seed(seed)
         | 
| 86 | 
            -
                    out = self.pipe( | 
| 87 | 
            -
             | 
| 88 | 
            -
             | 
| 89 | 
            -
             | 
|  | |
|  | |
|  | |
| 90 | 
             
                    return out.images[0]
         | 
|  | |
| 73 | 
             
                    self,
         | 
| 74 | 
             
                    lora_model_id: str,
         | 
| 75 | 
             
                    prompt: str,
         | 
| 76 | 
            +
                    lora_scale: float,
         | 
| 77 | 
             
                    seed: int,
         | 
| 78 | 
             
                    n_steps: int,
         | 
| 79 | 
             
                    guidance_scale: float,
         | 
|  | |
| 84 | 
             
                    self.load_pipe(lora_model_id)
         | 
| 85 |  | 
| 86 | 
             
                    generator = torch.Generator(device=self.device).manual_seed(seed)
         | 
| 87 | 
            +
                    out = self.pipe(
         | 
| 88 | 
            +
                        prompt,
         | 
| 89 | 
            +
                        num_inference_steps=n_steps,
         | 
| 90 | 
            +
                        guidance_scale=guidance_scale,
         | 
| 91 | 
            +
                        generator=generator,
         | 
| 92 | 
            +
                        cross_attention_kwargs={'scale': lora_scale},
         | 
| 93 | 
            +
                    )  # type: ignore
         | 
| 94 | 
             
                    return out.images[0]
         | 
 
			
