Luis Oala committed

Commit d62813f · 1 Parent(s): 255a136
first test
- LICENSE +21 -0
- README.md +1 -1
- README.md~ +37 -0
- app.py +199 -0
- app.py~ +196 -0
- glide_text2im/__init__.py +3 -0
- glide_text2im/clip/__init__.py +0 -0
- glide_text2im/clip/attention.py +179 -0
- glide_text2im/clip/config.yaml +18 -0
- glide_text2im/clip/encoders.py +497 -0
- glide_text2im/clip/model_creation.py +117 -0
- glide_text2im/clip/utils.py +97 -0
- glide_text2im/download.py +71 -0
- glide_text2im/fp16_util.py +25 -0
- glide_text2im/gaussian_diffusion.py +639 -0
- glide_text2im/model_creation.py +195 -0
- glide_text2im/nn.py +105 -0
- glide_text2im/respace.py +117 -0
- glide_text2im/text2im_model.py +233 -0
- glide_text2im/tokenizer/__init__.py +0 -0
- glide_text2im/tokenizer/bpe.py +151 -0
- glide_text2im/tokenizer/bpe_simple_vocab_16e6.txt.gz +3 -0
- glide_text2im/tokenizer/encoder.json.gz +3 -0
- glide_text2im/tokenizer/simple_tokenizer.py +163 -0
- glide_text2im/tokenizer/vocab.bpe.gz +3 -0
- glide_text2im/unet.py +635 -0
- glide_text2im/xf.py +130 -0
- model-card.md +50 -0
- notebooks/clip_guided.ipynb +246 -0
- notebooks/grass.png +0 -0
- notebooks/inpaint.ipynb +302 -0
- notebooks/text2im.ipynb +251 -0
- requirements.txt +4 -0
- setup.py +29 -0
    	
        LICENSE
    ADDED
    
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2021 OpenAI
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
    	
        README.md
    CHANGED
    
@@ -1,5 +1,5 @@
 ---
-title:
+title: glide
 emoji: 🔥
 colorFrom: red
 colorTo: purple
    	
        README.md~
    ADDED
    
@@ -0,0 +1,37 @@
+---
+title: Glide
+emoji: 🔥
+colorFrom: red
+colorTo: purple
+sdk: gradio
+app_file: app.py
+pinned: false
+---
+
+# Configuration
+
+`title`: _string_
+Display title for the Space
+
+`emoji`: _string_
+Space emoji (emoji-only character allowed)
+
+`colorFrom`: _string_
+Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
+
+`colorTo`: _string_
+Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
+
+`sdk`: _string_
+Can be either `gradio` or `streamlit`
+
+`sdk_version` : _string_
+Only applicable for `streamlit` SDK.
+See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
+
+`app_file`: _string_
+Path to your main application file (which contains either `gradio` or `streamlit` Python code).
+Path is relative to the root of the repository.
+
+`pinned`: _boolean_
+Whether the Space stays on top of your list.
    	
        app.py
    ADDED
    
@@ -0,0 +1,199 @@
+import os
+os.system('pip install -e .')
+import gradio as gr
+
+import base64
+from io import BytesIO
+# from fastapi import FastAPI
+
+from PIL import Image
+import torch as th
+
+from glide_text2im.download import load_checkpoint
+from glide_text2im.model_creation import (
+    create_model_and_diffusion,
+    model_and_diffusion_defaults,
+    model_and_diffusion_defaults_upsampler
+)
+
+# print("Loading models...")
+# app = FastAPI()
+
+# This notebook supports both CPU and GPU.
+# On CPU, generating one sample may take on the order of 20 minutes.
+# On a GPU, it should be under a minute.
+
+"""
+credit: follows the gradio glide example by valhalla https://huggingface.co/spaces/valhalla/glide-text2im
+"""
+
+has_cuda = th.cuda.is_available()
+device = th.device('cpu' if not has_cuda else 'cuda')
+
+# Create base model.
+options = model_and_diffusion_defaults()
+options['use_fp16'] = has_cuda
+options['timestep_respacing'] = '100' # use 100 diffusion steps for fast sampling
+model, diffusion = create_model_and_diffusion(**options)
+model.eval()
+if has_cuda:
+    model.convert_to_fp16()
+model.to(device)
+model.load_state_dict(load_checkpoint('base', device))
+print('total base parameters', sum(x.numel() for x in model.parameters()))
+
+# Create upsampler model.
+options_up = model_and_diffusion_defaults_upsampler()
+options_up['use_fp16'] = has_cuda
+options_up['timestep_respacing'] = 'fast27' # use 27 diffusion steps for very fast sampling
+model_up, diffusion_up = create_model_and_diffusion(**options_up)
+model_up.eval()
+if has_cuda:
+    model_up.convert_to_fp16()
+model_up.to(device)
+model_up.load_state_dict(load_checkpoint('upsample', device))
+print('total upsampler parameters', sum(x.numel() for x in model_up.parameters()))
+
+
+def get_images(batch: th.Tensor):
+    """ Display a batch of images inline. """
+    scaled = ((batch + 1)*127.5).round().clamp(0,255).to(th.uint8).cpu()
+    reshaped = scaled.permute(2, 0, 3, 1).reshape([batch.shape[2], -1, 3])
+    return Image.fromarray(reshaped.numpy())
+
+
+# Create a classifier-free guidance sampling function
+guidance_scale = 3.0
+
+def model_fn(x_t, ts, **kwargs):
+    half = x_t[: len(x_t) // 2]
+    combined = th.cat([half, half], dim=0)
+    model_out = model(combined, ts, **kwargs)
+    eps, rest = model_out[:, :3], model_out[:, 3:]
+    cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)
+    half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
+    eps = th.cat([half_eps, half_eps], dim=0)
+    return th.cat([eps, rest], dim=1)
+
+
+# @app.get("/")
+def read_root():
+    return {"glide!"}
+
+# @app.get("/{generate}")
+def sample(prompt):
+    # Sampling parameters
+    batch_size = 1
+
+    # Tune this parameter to control the sharpness of 256x256 images.
+    # A value of 1.0 is sharper, but sometimes results in grainy artifacts.
+    upsample_temp = 0.997
+
+    ##############################
+    # Sample from the base model #
+    ##############################
+
+    # Create the text tokens to feed to the model.
+    tokens = model.tokenizer.encode(prompt)
+    tokens, mask = model.tokenizer.padded_tokens_and_mask(
+        tokens, options['text_ctx']
+    )
+
+    # Create the classifier-free guidance tokens (empty)
+    full_batch_size = batch_size * 2
+    uncond_tokens, uncond_mask = model.tokenizer.padded_tokens_and_mask(
+        [], options['text_ctx']
+    )
+
+    # Pack the tokens together into model kwargs.
+    model_kwargs = dict(
+        tokens=th.tensor(
+            [tokens] * batch_size + [uncond_tokens] * batch_size, device=device
+        ),
+        mask=th.tensor(
+            [mask] * batch_size + [uncond_mask] * batch_size,
+            dtype=th.bool,
+            device=device,
+        ),
+    )
+
+    # Sample from the base model.
+    model.del_cache()
+    samples = diffusion.p_sample_loop(
+        model_fn,
+        (full_batch_size, 3, options["image_size"], options["image_size"]),
+        device=device,
+        clip_denoised=True,
+        progress=True,
+        model_kwargs=model_kwargs,
+        cond_fn=None,
+    )[:batch_size]
+    model.del_cache()
+
+
+    ##############################
+    # Upsample the 64x64 samples #
+    ##############################
+
+    tokens = model_up.tokenizer.encode(prompt)
+    tokens, mask = model_up.tokenizer.padded_tokens_and_mask(
+        tokens, options_up['text_ctx']
+    )
+
+    # Create the model conditioning dict.
+    model_kwargs = dict(
+        # Low-res image to upsample.
+        low_res=((samples+1)*127.5).round()/127.5 - 1,
+
+        # Text tokens
+        tokens=th.tensor(
+            [tokens] * batch_size, device=device
+        ),
+        mask=th.tensor(
+            [mask] * batch_size,
+            dtype=th.bool,
+            device=device,
+        ),
+    )
+
+    # Sample from the base model.
+    model_up.del_cache()
+    up_shape = (batch_size, 3, options_up["image_size"], options_up["image_size"])
+    up_samples = diffusion_up.ddim_sample_loop(
+        model_up,
+        up_shape,
+        noise=th.randn(up_shape, device=device) * upsample_temp,
+        device=device,
+        clip_denoised=True,
+        progress=True,
+        model_kwargs=model_kwargs,
+        cond_fn=None,
+    )[:batch_size]
+    model_up.del_cache()
+
+    # Show the output
+    image = get_images(up_samples)
+    # image = to_base64(image)
+    # return {"image": image}
+    return image
+
+
+def to_base64(pil_image):
+    buffered = BytesIO()
+    pil_image.save(buffered, format="JPEG")
+    return base64.b64encode(buffered.getvalue())
+
+title = "Interactive demo: glide-text2im"
+description = "Demo for OpenAI's GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models."
+article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.10741'>GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models</a> | <a href='https://github.com/openai/glide-text2im/'>Official Repo</a></p>"
+examples =["an oil painting of a corgi"]
+
+iface = gr.Interface(fn=sample,
+                     inputs=gr.inputs.Textbox(label='What would you like to see?'),
+                     outputs=gr.outputs.Image(type="pil", label="Model input + completions"),
+                     title=title,
+                     description=description,
+                     article=article,
+                     examples=examples,
+                     enable_queue=True)
+iface.launch(debug=True)
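Note on the classifier-free guidance step: `model_fn` above duplicates the conditional half of the batch, splits the predicted noise into conditional and unconditional parts, and blends them with `guidance_scale`. The following is a minimal, self-contained sketch of just that blending logic on placeholder tensors; it assumes only PyTorch is installed, and the helper name `blend_guidance` and the dummy shapes are illustrative, not part of the app.

import torch as th

guidance_scale = 3.0  # same value the app uses

def blend_guidance(model_out: th.Tensor) -> th.Tensor:
    # The batch stacks [conditional, unconditional] predictions along dim 0;
    # the first 3 channels are the predicted noise (eps), the rest is the variance output.
    eps, rest = model_out[:, :3], model_out[:, 3:]
    cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)
    half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
    eps = th.cat([half_eps, half_eps], dim=0)
    return th.cat([eps, rest], dim=1)

# Placeholder model output: 2 samples (cond + uncond), 3 eps + 3 variance channels, 64x64.
dummy_out = th.randn(2, 6, 64, 64)
print(blend_guidance(dummy_out).shape)  # torch.Size([2, 6, 64, 64])

In the app itself, `diffusion.p_sample_loop` calls `model_fn` at every denoising step, so this guided noise estimate drives the whole sampling trajectory.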
    	
        app.py~
    ADDED
    
@@ -0,0 +1,196 @@
+
+import os
+os.system('pip install -e .')
+import gradio as gr
+
+import base64
+from io import BytesIO
+# from fastapi import FastAPI
+
+from PIL import Image
+import torch as th
+
+from glide_text2im.download import load_checkpoint
+from glide_text2im.model_creation import (
+    create_model_and_diffusion,
+    model_and_diffusion_defaults,
+    model_and_diffusion_defaults_upsampler
+)
+
+# print("Loading models...")
+# app = FastAPI()
+
+# This notebook supports both CPU and GPU.
+# On CPU, generating one sample may take on the order of 20 minutes.
+# On a GPU, it should be under a minute.
+
+has_cuda = th.cuda.is_available()
+device = th.device('cpu' if not has_cuda else 'cuda')
+
+# Create base model.
+options = model_and_diffusion_defaults()
+options['use_fp16'] = has_cuda
+options['timestep_respacing'] = '100' # use 100 diffusion steps for fast sampling
+model, diffusion = create_model_and_diffusion(**options)
+model.eval()
+if has_cuda:
+    model.convert_to_fp16()
+model.to(device)
+model.load_state_dict(load_checkpoint('base', device))
+print('total base parameters', sum(x.numel() for x in model.parameters()))
+
+# Create upsampler model.
+options_up = model_and_diffusion_defaults_upsampler()
+options_up['use_fp16'] = has_cuda
+options_up['timestep_respacing'] = 'fast27' # use 27 diffusion steps for very fast sampling
+model_up, diffusion_up = create_model_and_diffusion(**options_up)
+model_up.eval()
+if has_cuda:
+    model_up.convert_to_fp16()
+model_up.to(device)
+model_up.load_state_dict(load_checkpoint('upsample', device))
+print('total upsampler parameters', sum(x.numel() for x in model_up.parameters()))
+
+
+def get_images(batch: th.Tensor):
+    """ Display a batch of images inline. """
+    scaled = ((batch + 1)*127.5).round().clamp(0,255).to(th.uint8).cpu()
+    reshaped = scaled.permute(2, 0, 3, 1).reshape([batch.shape[2], -1, 3])
+    return Image.fromarray(reshaped.numpy())
+
+
+# Create a classifier-free guidance sampling function
+guidance_scale = 3.0
+
+def model_fn(x_t, ts, **kwargs):
+    half = x_t[: len(x_t) // 2]
+    combined = th.cat([half, half], dim=0)
+    model_out = model(combined, ts, **kwargs)
+    eps, rest = model_out[:, :3], model_out[:, 3:]
+    cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)
+    half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
+    eps = th.cat([half_eps, half_eps], dim=0)
+    return th.cat([eps, rest], dim=1)
+
+
+# @app.get("/")
+def read_root():
+    return {"glide!"}
+
+# @app.get("/{generate}")
+def sample(prompt):
+    # Sampling parameters
+    batch_size = 1
+
+    # Tune this parameter to control the sharpness of 256x256 images.
+    # A value of 1.0 is sharper, but sometimes results in grainy artifacts.
+    upsample_temp = 0.997
+
+    ##############################
+    # Sample from the base model #
+    ##############################
+
+    # Create the text tokens to feed to the model.
+    tokens = model.tokenizer.encode(prompt)
+    tokens, mask = model.tokenizer.padded_tokens_and_mask(
+        tokens, options['text_ctx']
+    )
+
+    # Create the classifier-free guidance tokens (empty)
+    full_batch_size = batch_size * 2
+    uncond_tokens, uncond_mask = model.tokenizer.padded_tokens_and_mask(
+        [], options['text_ctx']
+    )
+
+    # Pack the tokens together into model kwargs.
+    model_kwargs = dict(
+        tokens=th.tensor(
+            [tokens] * batch_size + [uncond_tokens] * batch_size, device=device
+        ),
+        mask=th.tensor(
+            [mask] * batch_size + [uncond_mask] * batch_size,
+            dtype=th.bool,
+            device=device,
+        ),
+    )
+
+    # Sample from the base model.
+    model.del_cache()
+    samples = diffusion.p_sample_loop(
+        model_fn,
+        (full_batch_size, 3, options["image_size"], options["image_size"]),
+        device=device,
+        clip_denoised=True,
+        progress=True,
+        model_kwargs=model_kwargs,
+        cond_fn=None,
+    )[:batch_size]
+    model.del_cache()
+
+
+    ##############################
+    # Upsample the 64x64 samples #
+    ##############################
+
+    tokens = model_up.tokenizer.encode(prompt)
+    tokens, mask = model_up.tokenizer.padded_tokens_and_mask(
+        tokens, options_up['text_ctx']
+    )
+
+    # Create the model conditioning dict.
+    model_kwargs = dict(
+        # Low-res image to upsample.
+        low_res=((samples+1)*127.5).round()/127.5 - 1,
+
+        # Text tokens
+        tokens=th.tensor(
+            [tokens] * batch_size, device=device
+        ),
+        mask=th.tensor(
+            [mask] * batch_size,
+            dtype=th.bool,
+            device=device,
+        ),
+    )
+
+    # Sample from the base model.
+    model_up.del_cache()
+    up_shape = (batch_size, 3, options_up["image_size"], options_up["image_size"])
+    up_samples = diffusion_up.ddim_sample_loop(
+        model_up,
+        up_shape,
+        noise=th.randn(up_shape, device=device) * upsample_temp,
+        device=device,
+        clip_denoised=True,
+        progress=True,
+        model_kwargs=model_kwargs,
+        cond_fn=None,
+    )[:batch_size]
+    model_up.del_cache()
+
+    # Show the output
+    image = get_images(up_samples)
+    # image = to_base64(image)
+    # return {"image": image}
+    return image
+
+
+def to_base64(pil_image):
+    buffered = BytesIO()
+    pil_image.save(buffered, format="JPEG")
+    return base64.b64encode(buffered.getvalue())
+
+title = "Interactive demo: glide-text2im"
+description = "Demo for OpenAI's GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models."
+article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.10741'>GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models</a> | <a href='https://github.com/openai/glide-text2im/'>Official Repo</a></p>"
+examples =["an oil painting of a corgi"]
+
+iface = gr.Interface(fn=sample,
+                     inputs=gr.inputs.Textbox(label='What would you like to see?'),
+                     outputs=gr.outputs.Image(type="pil", label="Model input + completions"),
+                     title=title,
+                     description=description,
+                     article=article,
+                     examples=examples,
+                     enable_queue=True)
+iface.launch(debug=True)
    	
        glide_text2im/__init__.py
    ADDED
    
@@ -0,0 +1,3 @@
+"""
+A codebase for performing model inference with a text-conditional diffusion model.
+"""
    	
        glide_text2im/clip/__init__.py
    ADDED
    
File without changes
    	
        glide_text2im/clip/attention.py
    ADDED
    
    | @@ -0,0 +1,179 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import math
         | 
| 2 | 
            +
            from abc import ABC, abstractmethod
         | 
| 3 | 
            +
            from itertools import product
         | 
| 4 | 
            +
            from typing import Any, Optional
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import attr
         | 
| 7 | 
            +
            import numpy as np
         | 
| 8 | 
            +
            import torch
         | 
| 9 | 
            +
             | 
| 10 | 
            +
             | 
| 11 | 
            +
            @attr.s
         | 
| 12 | 
            +
            class AttentionMask(ABC):
         | 
| 13 | 
            +
                query_context_size: int = attr.ib(validator=lambda i, a, x: x >= 1)  # type: ignore
         | 
| 14 | 
            +
                key_context_size: int = attr.ib(validator=lambda i, a, x: x >= 1)  # type: ignore
         | 
| 15 | 
            +
                block_size: int = attr.ib(validator=lambda i, a, x: x >= 1)  # type: ignore
         | 
| 16 | 
            +
                n_head: int = attr.ib(validator=lambda i, a, x: x >= 1)  # type: ignore
         | 
| 17 | 
            +
                is_head_specific: bool = attr.ib(default=False)
         | 
| 18 | 
            +
                n_query_pad: int = attr.ib(default=0)
         | 
| 19 | 
            +
                n_key_pad: int = attr.ib(default=0)
         | 
| 20 | 
            +
             | 
| 21 | 
            +
                def __attrs_post_init__(self) -> None:
         | 
| 22 | 
            +
                    if self.query_context_size % self.block_size != 0:
         | 
| 23 | 
            +
                        raise ValueError()
         | 
| 24 | 
            +
                    if self.key_context_size % self.block_size != 0:
         | 
| 25 | 
            +
                        raise ValueError()
         | 
| 26 | 
            +
                    if self.n_query_pad >= self.query_context_size:
         | 
| 27 | 
            +
                        raise ValueError()
         | 
| 28 | 
            +
                    if self.n_key_pad >= self.key_context_size:
         | 
| 29 | 
            +
                        raise ValueError()
         | 
| 30 | 
            +
             | 
| 31 | 
            +
                    self.n_query_block = self.query_context_size // self.block_size
         | 
| 32 | 
            +
                    self.n_key_block = self.key_context_size // self.block_size
         | 
| 33 | 
            +
                    self.first_pad_query_block_idx = self.n_query_block - int(
         | 
| 34 | 
            +
                        math.ceil(self.n_query_pad / self.block_size)
         | 
| 35 | 
            +
                    )
         | 
| 36 | 
            +
                    self.first_pad_key_block_idx = self.n_key_block - int(
         | 
| 37 | 
            +
                        math.ceil(self.n_key_pad / self.block_size)
         | 
| 38 | 
            +
                    )
         | 
| 39 | 
            +
             | 
| 40 | 
            +
                def _make_global_layout(self) -> None:
         | 
| 41 | 
            +
                    if not self.is_head_specific:
         | 
| 42 | 
            +
                        m = np.ones([self.n_query_block, self.n_key_block], dtype=np.bool)
         | 
| 43 | 
            +
                        r = product(*[range(n) for n in m.shape])
         | 
| 44 | 
            +
             | 
| 45 | 
            +
                        for qb, kb in r:
         | 
| 46 | 
            +
                            m[qb, kb] = np.any(self.block_layout(None, 0, qb, kb, 0))
         | 
| 47 | 
            +
                    else:
         | 
| 48 | 
            +
                        m = np.ones([self.n_head, self.n_query_block, self.n_key_block], dtype=np.bool)
         | 
| 49 | 
            +
                        r = product(*[range(n) for n in m.shape])
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                        for h, qb, kb in r:
         | 
| 52 | 
            +
                            m[h, qb, kb] = np.any(self.block_layout(None, h, qb, kb, 0))
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                    self.global_layout = m
         | 
| 55 | 
            +
             | 
| 56 | 
            +
                @abstractmethod
         | 
| 57 | 
            +
                def _block_layout(
         | 
| 58 | 
            +
                    self, blk_shape: Any, head_idx: int, query_idx: int, key_idx: int, blk_idx: int
         | 
| 59 | 
            +
                ) -> np.ndarray:
         | 
| 60 | 
            +
                    raise NotImplementedError()
         | 
| 61 | 
            +
             | 
| 62 | 
            +
                def block_layout(
         | 
| 63 | 
            +
                    self, blk_shape: Any, head_idx: int, query_idx: int, key_idx: int, blk_idx: int
         | 
| 64 | 
            +
                ) -> np.ndarray:
         | 
| 65 | 
            +
                    """
         | 
| 66 | 
            +
                    `query_idx`, `key_idx` are block-level, zero-based indices.
         | 
| 67 | 
            +
                    """
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                    m = np.ones([self.block_size, self.block_size], dtype=np.bool)
         | 
| 70 | 
            +
             | 
| 71 | 
            +
                    if query_idx >= self.first_pad_query_block_idx:
         | 
| 72 | 
            +
                        n_pad = min(
         | 
| 73 | 
            +
                            self.block_size,
         | 
| 74 | 
            +
                            (query_idx + 1) * self.block_size - (self.query_context_size - self.n_query_pad),
         | 
| 75 | 
            +
                        )
         | 
| 76 | 
            +
                        assert n_pad > 0
         | 
| 77 | 
            +
                        m[self.block_size - n_pad :] = False
         | 
| 78 | 
            +
                    if key_idx >= self.first_pad_key_block_idx:
         | 
| 79 | 
            +
                        n_pad = min(
         | 
| 80 | 
            +
                            self.block_size,
         | 
| 81 | 
            +
                            (key_idx + 1) * self.block_size - (self.key_context_size - self.n_key_pad),
         | 
| 82 | 
            +
                        )
         | 
| 83 | 
            +
                        assert n_pad > 0
         | 
| 84 | 
            +
                        m[:, self.block_size - n_pad :] = False
         | 
| 85 | 
            +
             | 
| 86 | 
            +
                    return m & self._block_layout(blk_shape, head_idx, query_idx, key_idx, blk_idx)
         | 
| 87 | 
            +
             | 
| 88 | 
            +
             | 
| 89 | 
            +
            @attr.s
         | 
| 90 | 
            +
            class DenseAttentionMask(AttentionMask):
         | 
| 91 | 
            +
                def __attrs_post_init__(self) -> None:
         | 
| 92 | 
            +
                    super().__attrs_post_init__()
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                    self.global_layout = np.ones([self.n_query_block, self.n_key_block], dtype=np.bool)
         | 
| 95 | 
            +
                    n_zero_query_blocks = self.n_query_pad // self.block_size
         | 
| 96 | 
            +
                    n_zero_key_blocks = self.n_key_pad // self.block_size
         | 
| 97 | 
            +
                    self.global_layout[self.n_query_block - n_zero_query_blocks :] = False
         | 
| 98 | 
            +
                    self.global_layout[:, self.n_key_block - n_zero_key_blocks :] = False
         | 
| 99 | 
            +
             | 
| 100 | 
            +
                def _block_layout(
         | 
| 101 | 
            +
                    self, blk_shape: Any, head_idx: int, query_idx: int, key_idx: int, blk_idx: int
         | 
| 102 | 
            +
                ) -> np.ndarray:
         | 
| 103 | 
            +
                    return np.ones([self.block_size, self.block_size], dtype=np.bool)
         | 
| 104 | 
            +
             | 
| 105 | 
            +
             | 
| 106 | 
            +
            @attr.s
         | 
| 107 | 
            +
            class DenseCausalAttentionMask(AttentionMask):
         | 
| 108 | 
            +
                def __attrs_post_init__(self) -> None:
         | 
| 109 | 
            +
                    super().__attrs_post_init__()
         | 
| 110 | 
            +
             | 
| 111 | 
            +
                    self.global_layout = np.tril(np.ones([self.n_query_block, self.n_key_block], dtype=np.bool))
         | 
| 112 | 
            +
                    n_zero_query_blocks = self.n_query_pad // self.block_size
         | 
| 113 | 
            +
                    n_zero_key_blocks = self.n_key_pad // self.block_size
         | 
| 114 | 
            +
                    self.global_layout[self.n_query_block - n_zero_query_blocks :] = False
         | 
| 115 | 
            +
                    self.global_layout[:, self.n_key_block - n_zero_key_blocks :] = False
         | 
| 116 | 
            +
             | 
| 117 | 
            +
                def _block_layout(
         | 
| 118 | 
            +
                    self, blk_shape: Any, head_idx: int, query_idx: int, key_idx: int, blk_idx: int
         | 
| 119 | 
            +
                ) -> np.ndarray:
         | 
| 120 | 
            +
                    if query_idx > key_idx:
         | 
| 121 | 
            +
                        return np.ones(2 * [self.block_size], dtype=np.bool)
         | 
| 122 | 
            +
                    elif query_idx < key_idx:
         | 
| 123 | 
            +
                        return np.zeros(2 * [self.block_size], dtype=np.bool)
         | 
| 124 | 
            +
                    else:
         | 
| 125 | 
            +
                        return np.tril(np.ones(2 * [self.block_size], dtype=np.bool))
         | 
| 126 | 
            +
             | 
| 127 | 
            +
             | 
| 128 | 
            +
            @attr.s(eq=False, repr=False)
         | 
| 129 | 
            +
            class AttentionInfo:
         | 
| 130 | 
            +
                n_heads: int = attr.ib()
         | 
| 131 | 
            +
                ctx_blks_q: int = attr.ib()
         | 
| 132 | 
            +
                ctx_blks_k: int = attr.ib()
         | 
| 133 | 
            +
                block_size: int = attr.ib()
         | 
| 134 | 
            +
                pytorch_attn_bias: Optional[torch.Tensor] = attr.ib()
         | 
| 135 | 
            +
             | 
| 136 | 
            +
             | 
| 137 | 
            +
+def to_attention_info(d: AttentionMask) -> AttentionInfo:
+    return AttentionInfo(
+        n_heads=d.n_head,
+        ctx_blks_q=d.n_query_block,
+        ctx_blks_k=d.n_key_block,
+        block_size=d.block_size,
+        pytorch_attn_bias=None,
+    )
+
+
+def make_full_layout(d: AttentionMask) -> np.ndarray:
+    """
+    Returns the `context_size x context_size` layout matrix described by `d`. If the layout is dependent on the index of
+    the attention head, a `attention_head x context_size x context_size` layout matrix is returned instead.
+    """
+
+    if not d.is_head_specific:
+        u = np.reshape(d.global_layout, [d.n_query_block, d.n_key_block, 1, 1])
+        r = product(range(d.n_query_block), range(d.n_key_block))
+        v = np.array([d.block_layout(None, 0, i, j, 0) for i, j in r])
+        v = np.reshape(v, [d.n_query_block, d.n_key_block, d.block_size, d.block_size])
+
+        w = u * v
+        w = np.transpose(w, [0, 2, 1, 3])
+        w = np.reshape(w, [d.query_context_size, d.key_context_size])
+        return w
+    else:
+        if len(d.global_layout.shape) == 2:
+            u = np.reshape(d.global_layout, [1, d.n_query_block, d.n_key_block, 1, 1])
+            u = np.tile(u, [d.n_head, 1, 1, 1, 1])
+        elif len(d.global_layout.shape) == 3:
+            u = np.reshape(d.global_layout, [d.n_head, d.n_query_block, d.n_key_block, 1, 1])
+        else:
+            raise RuntimeError()
+
+        s = product(range(d.n_head), range(d.n_query_block), range(d.n_key_block))
+        v = np.array([d.block_layout(None, i, j, k, 0) for i, j, k in s])
+        v = np.reshape(v, [d.n_head, d.n_query_block, d.n_key_block, d.block_size, d.block_size])
+
+        w = u * v
+        w = np.transpose(w, [0, 1, 3, 2, 4])
+        w = np.reshape(w, [d.n_head, d.query_context_size, d.key_context_size])
+        return w
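For orientation, a minimal sketch (not part of the commit) of how these two helpers are used downstream: the encoders in glide_text2im/clip/encoders.py below turn a dense causal mask into the additive softmax bias stored on AttentionInfo.pytorch_attn_bias. The hyperparameters (block size 32, 8 heads, 77 text tokens) come from config.yaml; the positional arguments mirror TextEncoder.__attrs_post_init__.

import math

import numpy as np
import torch

from glide_text2im.clip.attention import (
    DenseCausalAttentionMask,
    make_full_layout,
    to_attention_info,
)

block_size, n_head, max_text_len = 32, 8, 77
n_rounded = block_size * int(math.ceil(max_text_len / block_size))  # 96
n_pad = n_rounded - max_text_len                                     # 19

# Same positional arguments the text encoder passes to the mask class.
mask = DenseCausalAttentionMask(n_rounded, n_rounded, block_size, n_head, False, n_pad, n_pad)
attn_info = to_attention_info(mask)

# The layout is 1 where attention is allowed and 0 where it is masked out;
# inverting it and scaling gives a large negative bias added before the softmax.
m = 1 - make_full_layout(mask).astype(np.float32)
m[m == 1] = -1e10
attn_info.pytorch_attn_bias = torch.from_numpy(m)

# Expected: [96, 96], or [n_head, 96, 96] if the mask is head-specific.
print(attn_info.pytorch_attn_bias.shape)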
    	
glide_text2im/clip/config.yaml ADDED
@@ -0,0 +1,18 @@
+logit_scale: 100.0
+
+# Diffusion settings
+beta_schedule: "squaredcos_cap_v2"
+n_timesteps: 1000
+
+# Architecture settings
+image_size: 64
+patch_size: 4
+n_vocab: 65536
+max_text_len: 77
+n_embd: 512
+n_head_state_text: 64
+n_head_text: 8
+n_xf_blocks_text: 12
+n_head_state_image: 64
+n_head_image: 12
+n_xf_blocks_image: 12
    	
glide_text2im/clip/encoders.py ADDED
@@ -0,0 +1,497 @@
+import math
+from collections import OrderedDict
+from typing import List, Optional, Tuple, cast
+
+import attr
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .attention import (
+    AttentionInfo,
+    DenseAttentionMask,
+    DenseCausalAttentionMask,
+    make_full_layout,
+    to_attention_info,
+)
+from .utils import Affine, LayerNorm, zero_key_bias_grad
+
+# Constants used in the original CLIP implementation.
+image_channel_means = [122.77093945, 116.74601272, 104.09373519]
+image_channel_stds = [68.50053285, 66.63215831, 70.32316309]
+
+
+@attr.s(eq=False, repr=False)
+class TextEmbedding(nn.Module):
+    n_vocab: int = attr.ib()
+    n_context: int = attr.ib()
+    n_state: int = attr.ib()
+    device: torch.device = attr.ib(default=torch.device("cuda"))
+
+    def __attrs_post_init__(self) -> None:
+        super().__init__()
+
+        w_voc = torch.empty((self.n_vocab, self.n_state), dtype=torch.float32, device=self.device)
+        w_pos = torch.empty((self.n_context, self.n_state), dtype=torch.float32, device=self.device)
+
+        with torch.no_grad():
+            w_voc.normal_(std=0.02)
+            w_pos.normal_(std=0.01)
+
+        self.w_voc = nn.Parameter(w_voc)
+        self.w_pos = nn.Parameter(w_pos)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if len(x.shape) != 2:
+            raise ValueError()
+
+        return F.embedding(x, self.w_voc) + self.w_pos[None, :, :]
+
+
+@attr.s(eq=False, repr=False)
+class ImageEmbedding(nn.Module):
+    image_size: int = attr.ib()
+    patch_size: int = attr.ib()
+    n_state: int = attr.ib()
+    n_timestep: int = attr.ib(default=0)
+    device: torch.device = attr.ib(default=torch.device("cuda"))
+
+    def __attrs_post_init__(self) -> None:
+        super().__init__()
+
+        if self.image_size % self.patch_size != 0:
+            raise ValueError()
+
+        n_patch = self.image_size // self.patch_size
+        patch_proj = torch.empty(
+            (self.n_state, 3) + 2 * (self.patch_size,), dtype=torch.float32, device=self.device
+        )
+        w_pos = torch.empty(
+            (1 + n_patch ** 2, self.n_state), dtype=torch.float32, device=self.device
+        )
+
+        with torch.no_grad():
+            if self.n_timestep == 0:
+                pred_state = torch.empty((self.n_state,), dtype=torch.float32, device=self.device)
+                pred_state.normal_(std=1 / np.sqrt(self.n_state))
+                self.pred_state = nn.Parameter(pred_state)
+            else:
+                w_t = torch.empty(
+                    (self.n_timestep, self.n_state), dtype=torch.float32, device=self.device
+                )
+                w_t.normal_(std=1 / np.sqrt(self.n_state))
+                self.w_t = nn.Parameter(w_t)
+
+            patch_proj.normal_(std=np.sqrt(2 / (self.n_state * self.patch_size ** 2)))
+            w_pos.normal_(std=1 / np.sqrt(self.n_state))
+
+        self.patch_proj = nn.Parameter(patch_proj)
+        self.w_pos = nn.Parameter(w_pos)
+
+        self.channel_means = torch.tensor(
+            image_channel_means, dtype=torch.float32, device=self.device
+        )[None, :, None, None]
+        self.channel_stds = torch.tensor(
+            image_channel_stds, dtype=torch.float32, device=self.device
+        )[None, :, None, None]
+        self.ln = LayerNorm(self.n_state, eps=1e-5, device=self.device)
+
+    def forward(self, x: torch.Tensor, t: Optional[torch.Tensor] = None) -> torch.Tensor:
+        if len(x.shape) != 4:
+            raise ValueError("input should be 4d")
+        if x.shape[1] != 3:
+            raise ValueError("input should have 3 channels")
+        if not (x.shape[2] == self.image_size and x.shape[3] == self.image_size):
+            raise ValueError(f"input is not {self.image_size} x {self.image_size}")
+
+        if (self.n_timestep == 0 and t is not None) or (self.n_timestep != 0 and t is None):
+            raise ValueError()
+        if self.n_timestep != 0:
+            assert t is not None
+            if len(t.shape) != 1:
+                raise ValueError()
+            if t.shape[0] != x.shape[0]:
+                raise ValueError()
+
+        x = (x - self.channel_means) / self.channel_stds
+        x = F.conv2d(x, self.patch_proj, stride=self.patch_size)
+        x = x.reshape(x.shape[0], self.n_state, (self.image_size // self.patch_size) ** 2).permute(
+            0, 2, 1
+        )
+
+        sot = (
+            self.pred_state[None, None].expand(x.shape[0], -1, -1)
+            if self.n_timestep == 0
+            else F.embedding(cast(torch.Tensor, t), self.w_t)[:, None]
+        )
+        x = torch.cat((sot, x), dim=1) + self.w_pos[None]
+        return self.ln(x)
+
+
+@attr.s(eq=False, repr=False)
+class AttentionResblock(nn.Module):
+    n_state: int = attr.ib()
+    n_resblocks: int = attr.ib()
+    attn_fn: AttentionInfo = attr.ib()
+    device: torch.device = attr.ib(default=torch.device("cuda"))
+
+    def __attrs_post_init__(self) -> None:
+        super().__init__()
+
+        self.n_head_state = self.n_state // self.attn_fn.n_heads
+        self.qk_scale = 1 / np.sqrt(self.n_head_state)
+
+        self.ln = LayerNorm(self.n_state, eps=1e-5, device=self.device)
+        self.f_q = Affine(
+            self.n_state,
+            self.n_state,
+            std=1 / math.sqrt(self.n_state),
+            use_bias=True,
+            bias_filter_fn=zero_key_bias_grad,
+            device=self.device,
+        )
+        self.f_k = Affine(
+            self.n_state,
+            self.n_state,
+            std=1 / math.sqrt(self.n_state),
+            use_bias=False,
+            bias_filter_fn=zero_key_bias_grad,
+            device=self.device,
+        )
+        self.f_v = Affine(
+            self.n_state,
+            self.n_state,
+            std=1 / math.sqrt(self.n_state),
+            use_bias=True,
+            bias_filter_fn=zero_key_bias_grad,
+            device=self.device,
+        )
+        self.f_c = Affine(
+            self.n_state,
+            self.n_state,
+            use_bias=True,
+            std=1 / np.sqrt(self.n_state * self.n_resblocks ** 2),
+            device=self.device,
+        )  # XXX
+
+    def forward(self, m: torch.Tensor) -> torch.Tensor:
+        n_context = m.shape[1]
+        n_query_pad = self.attn_fn.ctx_blks_q * self.attn_fn.block_size - n_context
+        n_key_pad = self.attn_fn.ctx_blks_k * self.attn_fn.block_size - n_context
+        assert n_query_pad >= 0
+        assert n_key_pad >= 0
+
+        r = m
+        r = self.ln(r)
+        q, k, v = self.f_q(r), self.f_k(r), self.f_v(r)
+
+        if n_query_pad != 0:
+            q = F.pad(q, (0, 0, 0, n_query_pad))
+
+        if n_key_pad != 0:
+            k = F.pad(k, (0, 0, 0, n_key_pad))
+            v = F.pad(v, (0, 0, 0, n_key_pad))
+
+        q = q.view([q.shape[0], -1, self.attn_fn.n_heads, self.n_head_state]).permute((0, 2, 1, 3))
+        k = k.view([k.shape[0], -1, self.attn_fn.n_heads, self.n_head_state]).permute((0, 2, 1, 3))
+        v = v.view([v.shape[0], -1, self.attn_fn.n_heads, self.n_head_state]).permute((0, 2, 1, 3))
+        w = torch.einsum(
+            "bhcd,bhkd->bhck", q * math.sqrt(self.qk_scale), k * math.sqrt(self.qk_scale)
+        )
+
+        if hasattr(self.attn_fn, "pytorch_attn_bias"):
+            bias = self.attn_fn.pytorch_attn_bias
+            assert len(bias.shape) in {2, 3}
+
+            if len(bias.shape) == 2:
+                w = torch.softmax(w + self.attn_fn.pytorch_attn_bias[None, None], dim=-1)
+            elif len(bias.shape) == 3:
+                w = torch.softmax(w + self.attn_fn.pytorch_attn_bias[None], dim=-1)
+        else:
+            w = torch.softmax(w, dim=-1)
+
+        r = torch.einsum("bhck,bhkd->bhcd", w, v)
+        r = r.permute((0, 2, 1, 3)).reshape((r.shape[0], -1, self.n_state))
+
+        if n_query_pad != 0:
+            r = r[:, :-n_query_pad]
+
+        assert r.shape[1] == n_context
+
+        r = self.f_c(r)
+        return m + r
+
+
+@attr.s(eq=False, repr=False)
+class FullyConnectedResblock(nn.Module):
+    """
+    Not imported from other files because we retain Alec's original inits.
+    """
+
+    n_state: int = attr.ib()
+    n_resblocks: int = attr.ib()
+    device: torch.device = attr.ib(default=torch.device("cuda"))
+
+    def __attrs_post_init__(self) -> None:
+        super().__init__()
+
+        self.ln = LayerNorm(self.n_state, eps=1e-5, device=self.device)
+        self.f_1 = Affine(
+            self.n_state,
+            4 * self.n_state,
+            use_bias=True,
+            std=np.sqrt(2 / (4 * self.n_state)),
+            device=self.device,
+        )
+        self.f_2 = Affine(
+            4 * self.n_state,
+            self.n_state,
+            use_bias=True,
+            std=1 / np.sqrt(self.n_state * self.n_resblocks ** 2),
+            device=self.device,
+        )  # XXX
+
+    def forward(self, m: torch.Tensor) -> torch.Tensor:
+        r = m
+        r = self.ln(r)
+
+        r = self.f_2(F.gelu(self.f_1(r)))
+        return m + r
+
+
+@attr.s(eq=False, repr=False)
+class TransformerBlock(nn.Module):
+    n_state: int = attr.ib()
+    n_resblocks: int = attr.ib()
+    attn_fn: AttentionInfo = attr.ib()
+    device: torch.device = attr.ib(default=torch.device("cuda"))
+
+    def __attrs_post_init__(self) -> None:
+        super().__init__()
+
+        self.f_attn = AttentionResblock(
+            self.n_state,
+            self.n_resblocks,
+            self.attn_fn,
+            self.device,
+        )
+        self.f_mlp = FullyConnectedResblock(self.n_state, self.n_resblocks, self.device)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.f_mlp(self.f_attn(x))
+
+
+@attr.s(eq=False, repr=False)
+class TextFeatureExtractor(nn.Module):
+    n_state: int = attr.ib()
+    n_embd: int = attr.ib()
+    device: torch.device = attr.ib(default=torch.device("cuda"))
+
+    def __attrs_post_init__(self) -> None:
+        super().__init__()
+
+        self.ln = LayerNorm(self.n_state, eps=1e-5, device=self.device)
+        self.f = Affine(self.n_state, self.n_embd, use_bias=False, device=self.device)
+
+    def forward(
+        self, text: torch.Tensor, text_len: torch.Tensor, return_probe_features: bool = False
+    ) -> torch.Tensor:
+        if len(text.shape) != 3:
+            raise ValueError("expected text to be 3d")
+        if len(text_len.shape) != 1:
+            raise ValueError("expected text length to be 1d")
+        if text.shape[0] != text_len.shape[0]:
+            raise ValueError("text and text_len have inconsistent batch dimensions")
+
+        index = (text_len - 1)[:, None, None].expand(-1, 1, text.shape[2])
+        x = torch.gather(text, dim=1, index=index)
+        assert list(x.shape) == [text.shape[0], 1, text.shape[2]]
+
+        if return_probe_features:
+            return x[:, 0]
+
+        x = self.ln(x)
+        return self.f(x[:, 0])
+
+
+@attr.s(eq=False, repr=False)
+class ImageFeatureExtractor(nn.Module):
+    n_state: int = attr.ib()
+    n_embd: int = attr.ib()
+    device: torch.device = attr.ib(default=torch.device("cuda"))
+
+    def __attrs_post_init__(self) -> None:
+        super().__init__()
+
+        self.ln = LayerNorm(self.n_state, eps=1e-5, device=self.device)
+        self.f = Affine(self.n_state, self.n_embd, use_bias=False, device=self.device)
+
+    def forward(self, x: torch.Tensor, return_probe_features: bool = False) -> torch.Tensor:
+        if return_probe_features:
+            return x[:, 0]
+
+        x = self.ln(x[:, :1])
+        return self.f(x[:, 0])
+
+
+@attr.s(eq=False, repr=False)
+class TextEncoder(nn.Module):
+    n_bpe_vocab: int = attr.ib()
+    max_text_len: int = attr.ib()
+    n_embd: int = attr.ib()
+    n_head: int = attr.ib()
+    n_xf_blocks: int = attr.ib()
+    n_head_state: int = attr.ib(default=64)
+    device: torch.device = attr.ib(default=torch.device("cuda"))
+    block_size: int = attr.ib(init=False, default=32)
+
+    def __attrs_post_init__(self) -> None:
+        super().__init__()
+
+        self.n_state = self.n_head * self.n_head_state
+        n_rounded_context = self.block_size * int(math.ceil(self.max_text_len / self.block_size))
+        n_pad = n_rounded_context - self.max_text_len
+
+        args = (
+            n_rounded_context,
+            n_rounded_context,
+            self.block_size,
+            self.n_head,
+            False,
+            n_pad,
+            n_pad,
+        )
+        mask = DenseCausalAttentionMask(*args)
+        attn_fn = to_attention_info(mask)
+
+        m = 1 - make_full_layout(mask).astype(np.float32)
+        m[m == 1] = -1e10
+        attn_fn.pytorch_attn_bias = torch.from_numpy(m).to(self.device)
+
+        blocks: List[Tuple[str, nn.Module]] = [
+            (
+                "input",
+                TextEmbedding(
+                    self.n_bpe_vocab, self.max_text_len, self.n_state, device=self.device
+                ),
+            )
+        ]
+
+        for i in range(self.n_xf_blocks):
+            blocks.append(
+                (
+                    f"block_{i}",
+                    TransformerBlock(self.n_state, 2 * self.n_xf_blocks, attn_fn, self.device),
+                )
+            )
+
+        blocks.append(
+            ("output", TextFeatureExtractor(self.n_state, self.n_embd, device=self.device))
+        )
+
+        self.blocks = nn.ModuleDict(OrderedDict(blocks))
+
+    def forward(
+        self,
+        text: torch.Tensor,
+        text_len: torch.Tensor,
+        return_probe_features: bool = False,
+    ) -> torch.Tensor:
+
+        n_batch = text.shape[0]
+        h = self.blocks["input"](text)
+
+        for i in range(self.n_xf_blocks):
+            h = self.blocks[f"block_{i}"](h)
+
+        h = self.blocks["output"](h, text_len, return_probe_features=return_probe_features)
+
+        assert list(h.shape) == [
+            n_batch,
+            self.n_embd if not return_probe_features else self.n_state,
+        ]
+        return h
+
+
+@attr.s(eq=False, repr=False)
+class ImageEncoder(nn.Module):
+    image_size: int = attr.ib()
+    patch_size: int = attr.ib()
+    n_embd: int = attr.ib()
+    n_head: int = attr.ib()
+    n_xf_blocks: int = attr.ib()
+    n_head_state: int = attr.ib(default=64)
+    n_timestep: int = attr.ib(default=0)
+    device: torch.device = attr.ib(default=torch.device("cuda"))
+    block_size: int = attr.ib(init=False, default=32)
+
+    def __attrs_post_init__(self) -> None:
+        super().__init__()
+
+        self.n_state = self.n_head * self.n_head_state
+        self.n_context = 1 + (self.image_size // self.patch_size) ** 2
+        n_rounded_context = self.block_size * int(math.ceil(self.n_context / self.block_size))
+        n_pad = n_rounded_context - self.n_context
+
+        args = (
+            n_rounded_context,
+            n_rounded_context,
+            self.block_size,
+            self.n_head,
+            False,
+            n_pad,
+            n_pad,
+        )
+        mask = DenseAttentionMask(*args)
+        attn_fn = to_attention_info(mask)
+
+        m = 1 - make_full_layout(mask).astype(np.float32)
+        m[m == 1] = -1e10
+        attn_fn.pytorch_attn_bias = torch.from_numpy(m).to(self.device)
+
+        blocks: List[Tuple[str, nn.Module]] = [
+            (
+                "input",
+                ImageEmbedding(
+                    self.image_size,
+                    self.patch_size,
+                    self.n_state,
+                    n_timestep=self.n_timestep,
+                    device=self.device,
+                ),
+            )
+        ]
+
+        for i in range(self.n_xf_blocks):
+            blocks.append(
+                (
+                    f"block_{i}",
+                    TransformerBlock(self.n_state, 2 * self.n_xf_blocks, attn_fn, self.device),
+                )
+            )
+
+        blocks.append(("output", ImageFeatureExtractor(self.n_state, self.n_embd, self.device)))
+
+        self.blocks = nn.ModuleDict(OrderedDict(blocks))
+
+    def forward(
+        self,
+        image: torch.Tensor,
+        timesteps: Optional[torch.Tensor] = None,
+        return_probe_features: bool = False,
+    ) -> torch.Tensor:
+        n_batch = image.shape[0]
+        h = self.blocks["input"](image, t=timesteps)
+
+        for i in range(self.n_xf_blocks):
+            h = self.blocks[f"block_{i}"](h)
+
+        h = self.blocks["output"](h, return_probe_features=return_probe_features)
+
+        assert list(h.shape) == [
+            n_batch,
+            self.n_embd if not return_probe_features else self.n_state,
+        ]
+
+        return h
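A quick shape check (illustrative only, not part of the commit): instantiating the text encoder with the hyperparameters from config.yaml above and pushing a batch of random token ids through it. The weights here are randomly initialized, so only the output shape is meaningful; real usage goes through CLIPModel in model_creation.py below.

import torch

from glide_text2im.clip.encoders import TextEncoder

device = torch.device("cpu")  # the attrs default is "cuda"; override explicitly for this sketch

text_encoder = TextEncoder(
    n_bpe_vocab=65536,   # n_vocab in config.yaml
    max_text_len=77,
    n_embd=512,
    n_head=8,            # n_head_text
    n_xf_blocks=12,      # n_xf_blocks_text
    n_head_state=64,     # n_head_state_text
    device=device,
)

batch = 2
tokens = torch.randint(0, 65536, (batch, 77), device=device)  # stand-in for BPE token ids
text_len = torch.randint(1, 78, (batch,), device=device)      # number of real tokens per row

with torch.no_grad():
    z_t = text_encoder(tokens, text_len)

print(z_t.shape)  # torch.Size([2, 512]): one n_embd-dimensional embedding per prompt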
    	
glide_text2im/clip/model_creation.py ADDED
@@ -0,0 +1,117 @@
+import os
+from functools import lru_cache
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import attr
+import numpy as np
+import torch
+import torch.nn as nn
+import yaml
+from glide_text2im.tokenizer.simple_tokenizer import SimpleTokenizer
+
+from .encoders import ImageEncoder, TextEncoder
+
+
+@lru_cache()
+def default_config_path() -> str:
+    return os.path.join(os.path.dirname(os.path.abspath(__file__)), "config.yaml")
+
+
+@attr.s
+class CLIPModel:
+    config: Dict[str, Any] = attr.ib()
+    text_encoder: nn.Module = attr.ib()
+    image_encoder: nn.Module = attr.ib()
+    logit_scale: torch.Tensor = attr.ib()
+    device: torch.device = attr.ib()
+    tokenizer: SimpleTokenizer = attr.ib()
+
+    def encode_prompts(self, prompts: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
+        tokens = []
+        lens = []
+        for prompt in prompts:
+            sub_tokens, sub_len = self.tokenizer.padded_tokens_and_len(
+                self.tokenizer.encode(prompt), self.text_encoder.max_text_len
+            )
+            tokens.append(sub_tokens)
+            lens.append(sub_len)
+        return (
+            torch.tensor(tokens).to(dtype=torch.long, device=self.device),
+            torch.tensor(lens).to(dtype=torch.long, device=self.device),
+        )
+
+    def text_embeddings(self, prompts: List[str]) -> torch.Tensor:
+        tokens, lens = self.encode_prompts(prompts)
+        z_t = self.text_encoder(tokens, lens)
+        return z_t / (torch.linalg.norm(z_t, dim=-1, keepdim=True) + 1e-12)
+
+    def image_embeddings(self, images: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+        z_i = self.image_encoder((images + 1) * 127.5, t)
+        return z_i / (torch.linalg.norm(z_i, dim=-1, keepdim=True) + 1e-12)
+
+    def cond_fn(self, prompts: List[str], grad_scale: float) -> Callable[..., torch.Tensor]:
+        with torch.no_grad():
+            z_t = self.text_embeddings(prompts)
+
+        def cond_fn(x, t, grad_scale=grad_scale, **kwargs):
+            with torch.enable_grad():
+                x_var = x.detach().requires_grad_(True)
+                z_i = self.image_embeddings(x_var, t)
+                loss = torch.exp(self.logit_scale) * (z_t * z_i).sum()
+                grad = torch.autograd.grad(loss, x_var)[0].detach()
+            return grad * grad_scale
+
+        return cond_fn
+
+
+def create_clip_model(
+    config_path: Optional[str] = None,
+    device: Optional[torch.device] = None,
+    tokenizer: Optional[SimpleTokenizer] = None,
+) -> CLIPModel:
+    if config_path is None:
+        config_path = default_config_path()
+    if device is None:
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    if tokenizer is None:
+        tokenizer = SimpleTokenizer()
+
+    with open(config_path, "r") as f:
+        config = yaml.load(f, Loader=yaml.SafeLoader)
+
+    text_encoder = TextEncoder(
+        n_bpe_vocab=config["n_vocab"],
+        max_text_len=config["max_text_len"],
+        n_embd=config["n_embd"],
+        n_head=config["n_head_text"],
+        n_xf_blocks=config["n_xf_blocks_text"],
+        n_head_state=config["n_head_state_text"],
+        device=device,
+    )
+
+    image_encoder = ImageEncoder(
+        image_size=config["image_size"],
+        patch_size=config["patch_size"],
+        n_embd=config["n_embd"],
+        n_head=config["n_head_image"],
+        n_xf_blocks=config["n_xf_blocks_image"],
+        n_head_state=config["n_head_state_image"],
+        n_timestep=config["n_timesteps"],
+        device=device,
+    )
+
+    logit_scale = torch.tensor(
+        np.log(config["logit_scale"]),
+        dtype=torch.float32,
+        device=device,
+        requires_grad=False,
+    )
+
+    return CLIPModel(
+        config=config,
+        text_encoder=text_encoder,
+        image_encoder=image_encoder,
+        logit_scale=logit_scale,
+        device=device,
+        tokenizer=tokenizer,
+    )
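Putting the pieces together, a hedged usage sketch (not part of the commit): create_clip_model wires the two encoders to the tokenizer, and cond_fn returns the gradient callback that a CLIP-guided diffusion sampler calls at each step. Loading pretrained weights via glide_text2im.download.load_checkpoint with the "clip/image-enc" and "clip/text-enc" names is an assumption taken from elsewhere in this repository (e.g. the bundled clip_guided notebook); that helper's source is not part of this excerpt.

import torch

from glide_text2im.clip.model_creation import create_clip_model
from glide_text2im.download import load_checkpoint  # assumed helper from this repo

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

clip_model = create_clip_model(device=device)
# Checkpoint names assumed from the repo's notebooks.
clip_model.image_encoder.load_state_dict(load_checkpoint("clip/image-enc", device))
clip_model.text_encoder.load_state_dict(load_checkpoint("clip/text-enc", device))

# Normalized text embeddings for a batch of prompts: shape [2, n_embd].
z_t = clip_model.text_embeddings(["an oil painting of a corgi", "a stained glass window"])

# Gradient callback for CLIP-guided sampling: given noisy images x in [-1, 1] and
# timesteps t, it returns grad_scale * d/dx <z_t, z_i(x, t)>.
cond_fn = clip_model.cond_fn(["an oil painting of a corgi"], grad_scale=3.0)
# A guided sampler (e.g. a p_sample_loop that accepts cond_fn=...) would add this
# gradient to the denoised mean at every diffusion step.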
    	
glide_text2im/clip/utils.py ADDED
@@ -0,0 +1,97 @@
+import math
+from typing import Callable, Optional
+
+import attr
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+FilterFn = Callable[[torch.Tensor], torch.Tensor]
+
+
+class ZeroKeyBiasGrad(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x):
+        return x
+
+    @staticmethod
+    def backward(ctx, output_grad):
+        output_grad = output_grad.clone()
+        output_grad.chunk(3)[1].zero_()
+        return output_grad
+
+
+def zero_key_bias_grad(x: torch.Tensor) -> torch.Tensor:
+    return ZeroKeyBiasGrad.apply(x)
+
+
+@attr.s(eq=False, repr=False)
+class LayerNorm(nn.Module):
+    n_state: int = attr.ib()
+    eps: float = attr.ib(default=1e-6)
+    device: torch.device = attr.ib(default=torch.device("cuda"))
+
+    def __attrs_post_init__(self) -> None:
+        super().__init__()
+        self.g = nn.Parameter(torch.ones((self.n_state,), dtype=torch.float32, device=self.device))
+        self.b = nn.Parameter(torch.zeros((self.n_state,), dtype=torch.float32, device=self.device))
+        self.g.weight_decay_level = "disable"  # type: ignore
+        self.b.weight_decay_level = "disable"  # type: ignore
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return F.layer_norm(
+            x.type(torch.float32), torch.Size((self.n_state,)), self.g, self.b, self.eps
+        )
+
+
+@attr.s(eq=False, repr=False)
+class Affine(nn.Module):
+    n_in: int = attr.ib()
+    n_out: int = attr.ib()
+    use_bias: bool = attr.ib(default=True)
+    use_admnet_init: bool = attr.ib(default=False)
+    std: Optional[float] = attr.ib(default=None)
+    extra_init_scale: Optional[float] = attr.ib(default=None)
+    bias_filter_fn: FilterFn = attr.ib(default=lambda x: x)
+    device: torch.device = attr.ib(default=torch.device("cuda"))
+
+    def __attrs_post_init__(self) -> None:
+        super().__init__()
+
+        if not self.use_admnet_init:
+            self.std = self.std if self.std is not None else math.sqrt(2 / (self.n_in + self.n_out))
+            self.std = (
+                self.std if self.extra_init_scale is None else self.std * self.extra_init_scale
+            )
+
+            w = torch.empty((self.n_out, self.n_in), dtype=torch.float32, device=self.device)
+            self.w = nn.Parameter(w)
+
+            if self.use_bias:
+                self.b = nn.Parameter(
+                    torch.zeros((self.n_out,), dtype=torch.float32, device=self.device)
+                )
+                self.b.weight_decay_level = "disable"  # type: ignore
+        else:
+            if self.extra_init_scale is not None:
                            raise ValueError("extra_init_scale incompatible with admnet init")
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                        w = torch.empty((self.n_out, self.n_in), dtype=torch.float32, device=self.device)
         | 
| 80 | 
            +
             | 
| 81 | 
            +
                        if self.use_bias:
         | 
| 82 | 
            +
                            b = torch.empty((self.n_out,), dtype=torch.float32, device=self.device)
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                        self.w = nn.Parameter(w)
         | 
| 85 | 
            +
             | 
| 86 | 
            +
                        if self.use_bias:
         | 
| 87 | 
            +
                            self.b = nn.Parameter(b)
         | 
| 88 | 
            +
                            self.b.weight_decay_level = "disable"  # type: ignore
         | 
| 89 | 
            +
             | 
| 90 | 
            +
                def forward(self, x: torch.Tensor) -> torch.Tensor:
         | 
| 91 | 
            +
                    w = self.w if self.w.dtype == x.dtype else self.w.to(x.dtype)
         | 
| 92 | 
            +
                    b = (
         | 
| 93 | 
            +
                        self.bias_filter_fn(self.b if self.b.dtype == x.dtype else self.b.to(x.dtype))
         | 
| 94 | 
            +
                        if self.use_bias
         | 
| 95 | 
            +
                        else None
         | 
| 96 | 
            +
                    )
         | 
| 97 | 
            +
                    return F.linear(x, w, b)
         | 
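Editor's note — a minimal usage sketch for the two attrs-based modules above (not part of the commit). Both build their parameters in __attrs_post_init__ and default to a CUDA device; in practice their weights come from the pretrained CLIP checkpoints listed in download.py, so the CPU override and sizes below are illustrative assumptions only.

import torch
from glide_text2im.clip.utils import Affine, LayerNorm

device = torch.device("cpu")                       # override the torch.device("cuda") default
ln = LayerNorm(n_state=512, device=device)         # learned per-feature scale g and shift b
proj = Affine(n_in=512, n_out=256, device=device)  # linear map; Affine.w is left for a checkpoint to fill

x = torch.randn(4, 512)
y = proj(ln(x))                                    # shape (4, 256)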
    	
        glide_text2im/download.py
    ADDED
    
    | @@ -0,0 +1,71 @@ | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            from functools import lru_cache
         | 
| 3 | 
            +
            from typing import Dict, Optional
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            import requests
         | 
| 6 | 
            +
            import torch as th
         | 
| 7 | 
            +
            from filelock import FileLock
         | 
| 8 | 
            +
            from tqdm.auto import tqdm
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            MODEL_PATHS = {
         | 
| 11 | 
            +
                "base": "https://openaipublic.blob.core.windows.net/diffusion/dec-2021/base.pt",
         | 
| 12 | 
            +
                "upsample": "https://openaipublic.blob.core.windows.net/diffusion/dec-2021/upsample.pt",
         | 
| 13 | 
            +
                "base-inpaint": "https://openaipublic.blob.core.windows.net/diffusion/dec-2021/base_inpaint.pt",
         | 
| 14 | 
            +
                "upsample-inpaint": "https://openaipublic.blob.core.windows.net/diffusion/dec-2021/upsample_inpaint.pt",
         | 
| 15 | 
            +
                "clip/image-enc": "https://openaipublic.blob.core.windows.net/diffusion/dec-2021/clip_image_enc.pt",
         | 
| 16 | 
            +
                "clip/text-enc": "https://openaipublic.blob.core.windows.net/diffusion/dec-2021/clip_text_enc.pt",
         | 
| 17 | 
            +
            }
         | 
| 18 | 
            +
             | 
| 19 | 
            +
             | 
| 20 | 
            +
            @lru_cache()
         | 
| 21 | 
            +
            def default_cache_dir() -> str:
         | 
| 22 | 
            +
                return os.path.join(os.path.abspath(os.getcwd()), "glide_model_cache")
         | 
| 23 | 
            +
             | 
| 24 | 
            +
             | 
| 25 | 
            +
            def fetch_file_cached(
         | 
| 26 | 
            +
                url: str, progress: bool = True, cache_dir: Optional[str] = None, chunk_size: int = 4096
         | 
| 27 | 
            +
            ) -> str:
         | 
| 28 | 
            +
                """
         | 
| 29 | 
            +
                Download the file at the given URL into a local file and return the path.
         | 
| 30 | 
            +
             | 
| 31 | 
            +
                If cache_dir is specified, it will be used to download the files.
         | 
| 32 | 
            +
                Otherwise, default_cache_dir() is used.
         | 
| 33 | 
            +
                """
         | 
| 34 | 
            +
                if cache_dir is None:
         | 
| 35 | 
            +
                    cache_dir = default_cache_dir()
         | 
| 36 | 
            +
                os.makedirs(cache_dir, exist_ok=True)
         | 
| 37 | 
            +
                response = requests.get(url, stream=True)
         | 
| 38 | 
            +
                size = int(response.headers.get("content-length", "0"))
         | 
| 39 | 
            +
                local_path = os.path.join(cache_dir, url.split("/")[-1])
         | 
| 40 | 
            +
                with FileLock(local_path + ".lock"):
         | 
| 41 | 
            +
                    if os.path.exists(local_path):
         | 
| 42 | 
            +
                        return local_path
         | 
| 43 | 
            +
                    if progress:
         | 
| 44 | 
            +
                        pbar = tqdm(total=size, unit="iB", unit_scale=True)
         | 
| 45 | 
            +
                    tmp_path = local_path + ".tmp"
         | 
| 46 | 
            +
                    with open(tmp_path, "wb") as f:
         | 
| 47 | 
            +
                        for chunk in response.iter_content(chunk_size):
         | 
| 48 | 
            +
                            if progress:
         | 
| 49 | 
            +
                                pbar.update(len(chunk))
         | 
| 50 | 
            +
                            f.write(chunk)
         | 
| 51 | 
            +
                    os.rename(tmp_path, local_path)
         | 
| 52 | 
            +
                    if progress:
         | 
| 53 | 
            +
                        pbar.close()
         | 
| 54 | 
            +
                    return local_path
         | 
| 55 | 
            +
             | 
| 56 | 
            +
             | 
| 57 | 
            +
            def load_checkpoint(
         | 
| 58 | 
            +
                checkpoint_name: str,
         | 
| 59 | 
            +
                device: th.device,
         | 
| 60 | 
            +
                progress: bool = True,
         | 
| 61 | 
            +
                cache_dir: Optional[str] = None,
         | 
| 62 | 
            +
                chunk_size: int = 4096,
         | 
| 63 | 
            +
            ) -> Dict[str, th.Tensor]:
         | 
| 64 | 
            +
                if checkpoint_name not in MODEL_PATHS:
         | 
| 65 | 
            +
                    raise ValueError(
         | 
| 66 | 
            +
                        f"Unknown checkpoint name {checkpoint_name}. Known names are: {MODEL_PATHS.keys()}."
         | 
| 67 | 
            +
                    )
         | 
| 68 | 
            +
                path = fetch_file_cached(
         | 
| 69 | 
            +
                    MODEL_PATHS[checkpoint_name], progress=progress, cache_dir=cache_dir, chunk_size=chunk_size
         | 
| 70 | 
            +
                )
         | 
| 71 | 
            +
                return th.load(path, map_location=device)
         | 
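Editor's note — a usage sketch for the helpers above (not part of the commit). load_checkpoint() only accepts the keys in MODEL_PATHS, downloads the weights into ./glide_model_cache on first use via fetch_file_cached(), and returns a plain dict of tensors.

import torch as th
from glide_text2im.download import load_checkpoint

device = th.device("cuda" if th.cuda.is_available() else "cpu")
state_dict = load_checkpoint("base", device)  # cached under default_cache_dir() after the first download
print(f"{sum(v.numel() for v in state_dict.values()):,} parameters in the base checkpoint")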
    	
        glide_text2im/fp16_util.py
    ADDED
    
    | @@ -0,0 +1,25 @@ | |
| 1 | 
            +
            """
         | 
| 2 | 
            +
            Helpers to inference with 16-bit precision.
         | 
| 3 | 
            +
            """
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            import torch.nn as nn
         | 
| 6 | 
            +
             | 
| 7 | 
            +
             | 
| 8 | 
            +
            def convert_module_to_f16(l):
         | 
| 9 | 
            +
                """
         | 
| 10 | 
            +
                Convert primitive modules to float16.
         | 
| 11 | 
            +
                """
         | 
| 12 | 
            +
                if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
         | 
| 13 | 
            +
                    l.weight.data = l.weight.data.half()
         | 
| 14 | 
            +
                    if l.bias is not None:
         | 
| 15 | 
            +
                        l.bias.data = l.bias.data.half()
         | 
| 16 | 
            +
             | 
| 17 | 
            +
             | 
| 18 | 
            +
            def convert_module_to_f32(l):
         | 
| 19 | 
            +
                """
         | 
| 20 | 
            +
                Convert primitive modules to float32, undoing convert_module_to_f16().
         | 
| 21 | 
            +
                """
         | 
| 22 | 
            +
                if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
         | 
| 23 | 
            +
                    l.weight.data = l.weight.data.float()
         | 
| 24 | 
            +
                    if l.bias is not None:
         | 
| 25 | 
            +
                        l.bias.data = l.bias.data.float()
         | 
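Editor's note — a usage sketch (not part of the commit). The converters above are written to be applied recursively with nn.Module.apply(); they only touch convolution layers, leaving every other module's dtype unchanged.

import torch.nn as nn
from glide_text2im.fp16_util import convert_module_to_f16, convert_module_to_f32

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 3, 3))
model.apply(convert_module_to_f16)  # conv weights and biases become float16
model.apply(convert_module_to_f32)  # undo the conversion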
    	
        glide_text2im/gaussian_diffusion.py
    ADDED
    
    | @@ -0,0 +1,639 @@ | |
| 1 | 
            +
            """
         | 
| 2 | 
            +
            Simplified from https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/gaussian_diffusion.py.
         | 
| 3 | 
            +
            """
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            import math
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            import numpy as np
         | 
| 8 | 
            +
            import torch as th
         | 
| 9 | 
            +
             | 
| 10 | 
            +
             | 
| 11 | 
            +
            def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
         | 
| 12 | 
            +
                betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
         | 
| 13 | 
            +
                warmup_time = int(num_diffusion_timesteps * warmup_frac)
         | 
| 14 | 
            +
                betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
         | 
| 15 | 
            +
                return betas
         | 
| 16 | 
            +
             | 
| 17 | 
            +
             | 
| 18 | 
            +
            def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
         | 
| 19 | 
            +
                """
         | 
| 20 | 
            +
                This is the deprecated API for creating beta schedules.
         | 
| 21 | 
            +
             | 
| 22 | 
            +
                See get_named_beta_schedule() for the new library of schedules.
         | 
| 23 | 
            +
                """
         | 
| 24 | 
            +
                if beta_schedule == "quad":
         | 
| 25 | 
            +
                    betas = (
         | 
| 26 | 
            +
                        np.linspace(
         | 
| 27 | 
            +
                            beta_start ** 0.5,
         | 
| 28 | 
            +
                            beta_end ** 0.5,
         | 
| 29 | 
            +
                            num_diffusion_timesteps,
         | 
| 30 | 
            +
                            dtype=np.float64,
         | 
| 31 | 
            +
                        )
         | 
| 32 | 
            +
                        ** 2
         | 
| 33 | 
            +
                    )
         | 
| 34 | 
            +
                elif beta_schedule == "linear":
         | 
| 35 | 
            +
                    betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
         | 
| 36 | 
            +
                elif beta_schedule == "warmup10":
         | 
| 37 | 
            +
                    betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
         | 
| 38 | 
            +
                elif beta_schedule == "warmup50":
         | 
| 39 | 
            +
                    betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
         | 
| 40 | 
            +
                elif beta_schedule == "const":
         | 
| 41 | 
            +
                    betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
         | 
| 42 | 
            +
                elif beta_schedule == "jsd":  # 1/T, 1/(T-1), 1/(T-2), ..., 1
         | 
| 43 | 
            +
                    betas = 1.0 / np.linspace(
         | 
| 44 | 
            +
                        num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
         | 
| 45 | 
            +
                    )
         | 
| 46 | 
            +
                else:
         | 
| 47 | 
            +
                    raise NotImplementedError(beta_schedule)
         | 
| 48 | 
            +
                assert betas.shape == (num_diffusion_timesteps,)
         | 
| 49 | 
            +
                return betas
         | 
| 50 | 
            +
             | 
| 51 | 
            +
             | 
| 52 | 
            +
            def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
         | 
| 53 | 
            +
                """
         | 
| 54 | 
            +
                Get a pre-defined beta schedule for the given name.
         | 
| 55 | 
            +
             | 
| 56 | 
            +
                The beta schedule library consists of beta schedules which remain similar
         | 
| 57 | 
            +
                in the limit of num_diffusion_timesteps.
         | 
| 58 | 
            +
                Beta schedules may be added, but should not be removed or changed once
         | 
| 59 | 
            +
                they are committed to maintain backwards compatibility.
         | 
| 60 | 
            +
                """
         | 
| 61 | 
            +
                if schedule_name == "linear":
         | 
| 62 | 
            +
                    # Linear schedule from Ho et al, extended to work for any number of
         | 
| 63 | 
            +
                    # diffusion steps.
         | 
| 64 | 
            +
                    scale = 1000 / num_diffusion_timesteps
         | 
| 65 | 
            +
                    return get_beta_schedule(
         | 
| 66 | 
            +
                        "linear",
         | 
| 67 | 
            +
                        beta_start=scale * 0.0001,
         | 
| 68 | 
            +
                        beta_end=scale * 0.02,
         | 
| 69 | 
            +
                        num_diffusion_timesteps=num_diffusion_timesteps,
         | 
| 70 | 
            +
                    )
         | 
| 71 | 
            +
                elif schedule_name == "squaredcos_cap_v2":
         | 
| 72 | 
            +
                    return betas_for_alpha_bar(
         | 
| 73 | 
            +
                        num_diffusion_timesteps,
         | 
| 74 | 
            +
                        lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
         | 
| 75 | 
            +
                    )
         | 
| 76 | 
            +
                else:
         | 
| 77 | 
            +
                    raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
         | 
| 78 | 
            +
             | 
| 79 | 
            +
             | 
| 80 | 
            +
            def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
         | 
| 81 | 
            +
                """
         | 
| 82 | 
            +
                Create a beta schedule that discretizes the given alpha_t_bar function,
         | 
| 83 | 
            +
                which defines the cumulative product of (1-beta) over time from t = [0,1].
         | 
| 84 | 
            +
             | 
| 85 | 
            +
                :param num_diffusion_timesteps: the number of betas to produce.
         | 
| 86 | 
            +
                :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
         | 
| 87 | 
            +
                                  produces the cumulative product of (1-beta) up to that
         | 
| 88 | 
            +
                                  part of the diffusion process.
         | 
| 89 | 
            +
                :param max_beta: the maximum beta to use; use values lower than 1 to
         | 
| 90 | 
            +
                                 prevent singularities.
         | 
| 91 | 
            +
                """
         | 
| 92 | 
            +
                betas = []
         | 
| 93 | 
            +
                for i in range(num_diffusion_timesteps):
         | 
| 94 | 
            +
                    t1 = i / num_diffusion_timesteps
         | 
| 95 | 
            +
                    t2 = (i + 1) / num_diffusion_timesteps
         | 
| 96 | 
            +
                    betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
         | 
| 97 | 
            +
                return np.array(betas)
         | 
| 98 | 
            +
             | 
| 99 | 
            +
             | 
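            # Editor's sketch (not part of the committed file): the named schedules above
            # feed directly into GaussianDiffusion below, e.g.
            #
            #     betas = get_named_beta_schedule("squaredcos_cap_v2", 100)
            #     diffusion = GaussianDiffusion(betas=betas)
            #     assert diffusion.num_timesteps == 100
            #
            # The "linear" branch rescales Ho et al.'s 1000-step endpoints by
            # 1000 / num_diffusion_timesteps, so the same schedule shape works for
            # any number of diffusion steps.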
| 100 | 
            +
            class GaussianDiffusion:
         | 
| 101 | 
            +
                """
         | 
| 102 | 
            +
                Utilities for training and sampling diffusion models.
         | 
| 103 | 
            +
             | 
| 104 | 
            +
                Original ported from this codebase:
         | 
| 105 | 
            +
                https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                :param betas: a 1-D numpy array of betas for each diffusion timestep,
         | 
| 108 | 
            +
                              starting at T and going to 1.
         | 
| 109 | 
            +
                """
         | 
| 110 | 
            +
             | 
| 111 | 
            +
                def __init__(
         | 
| 112 | 
            +
                    self,
         | 
| 113 | 
            +
                    *,
         | 
| 114 | 
            +
                    betas,
         | 
| 115 | 
            +
                ):
         | 
| 116 | 
            +
                    # Use float64 for accuracy.
         | 
| 117 | 
            +
                    betas = np.array(betas, dtype=np.float64)
         | 
| 118 | 
            +
                    self.betas = betas
         | 
| 119 | 
            +
                    assert len(betas.shape) == 1, "betas must be 1-D"
         | 
| 120 | 
            +
                    assert (betas > 0).all() and (betas <= 1).all()
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                    self.num_timesteps = int(betas.shape[0])
         | 
| 123 | 
            +
             | 
| 124 | 
            +
                    alphas = 1.0 - betas
         | 
| 125 | 
            +
                    self.alphas_cumprod = np.cumprod(alphas, axis=0)
         | 
| 126 | 
            +
                    self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
         | 
| 127 | 
            +
                    self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
         | 
| 128 | 
            +
                    assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                    # calculations for diffusion q(x_t | x_{t-1}) and others
         | 
| 131 | 
            +
                    self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
         | 
| 132 | 
            +
                    self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
         | 
| 133 | 
            +
                    self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
         | 
| 134 | 
            +
                    self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
         | 
| 135 | 
            +
                    self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
         | 
| 136 | 
            +
             | 
| 137 | 
            +
                    # calculations for posterior q(x_{t-1} | x_t, x_0)
         | 
| 138 | 
            +
                    self.posterior_variance = (
         | 
| 139 | 
            +
                        betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
         | 
| 140 | 
            +
                    )
         | 
| 141 | 
            +
                    # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
         | 
| 142 | 
            +
                    self.posterior_log_variance_clipped = np.log(
         | 
| 143 | 
            +
                        np.append(self.posterior_variance[1], self.posterior_variance[1:])
         | 
| 144 | 
            +
                    )
         | 
| 145 | 
            +
                    self.posterior_mean_coef1 = (
         | 
| 146 | 
            +
                        betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
         | 
| 147 | 
            +
                    )
         | 
| 148 | 
            +
                    self.posterior_mean_coef2 = (
         | 
| 149 | 
            +
                        (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
         | 
| 150 | 
            +
                    )
         | 
| 151 | 
            +
             | 
| 152 | 
            +
                def q_mean_variance(self, x_start, t):
         | 
| 153 | 
            +
                    """
         | 
| 154 | 
            +
                    Get the distribution q(x_t | x_0).
         | 
| 155 | 
            +
             | 
| 156 | 
            +
                    :param x_start: the [N x C x ...] tensor of noiseless inputs.
         | 
| 157 | 
            +
                    :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
         | 
| 158 | 
            +
                    :return: A tuple (mean, variance, log_variance), all of x_start's shape.
         | 
| 159 | 
            +
                    """
         | 
| 160 | 
            +
                    mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
         | 
| 161 | 
            +
                    variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
         | 
| 162 | 
            +
                    log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
         | 
| 163 | 
            +
                    return mean, variance, log_variance
         | 
| 164 | 
            +
             | 
| 165 | 
            +
                def q_sample(self, x_start, t, noise=None):
         | 
| 166 | 
            +
                    """
         | 
| 167 | 
            +
                    Diffuse the data for a given number of diffusion steps.
         | 
| 168 | 
            +
             | 
| 169 | 
            +
                    In other words, sample from q(x_t | x_0).
         | 
| 170 | 
            +
             | 
| 171 | 
            +
                    :param x_start: the initial data batch.
         | 
| 172 | 
            +
                    :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
         | 
| 173 | 
            +
                    :param noise: if specified, the split-out normal noise.
         | 
| 174 | 
            +
                    :return: A noisy version of x_start.
         | 
| 175 | 
            +
                    """
         | 
| 176 | 
            +
                    if noise is None:
         | 
| 177 | 
            +
                        noise = th.randn_like(x_start)
         | 
| 178 | 
            +
                    assert noise.shape == x_start.shape
         | 
| 179 | 
            +
                    return (
         | 
| 180 | 
            +
                        _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
         | 
| 181 | 
            +
                        + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
         | 
| 182 | 
            +
                    )
         | 
| 183 | 
            +
             | 
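            # Editor's note (not part of the committed file): q_sample() above is the
            # closed-form forward process
            #
            #     x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise,
            #
            # built from the sqrt_alphas_cumprod / sqrt_one_minus_alphas_cumprod buffers
            # precomputed in __init__, so any timestep is reachable in a single call
            # rather than by iterating the chain.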
| 184 | 
            +
                def q_posterior_mean_variance(self, x_start, x_t, t):
         | 
| 185 | 
            +
                    """
         | 
| 186 | 
            +
                    Compute the mean and variance of the diffusion posterior:
         | 
| 187 | 
            +
             | 
| 188 | 
            +
                        q(x_{t-1} | x_t, x_0)
         | 
| 189 | 
            +
             | 
| 190 | 
            +
                    """
         | 
| 191 | 
            +
                    assert x_start.shape == x_t.shape
         | 
| 192 | 
            +
                    posterior_mean = (
         | 
| 193 | 
            +
                        _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
         | 
| 194 | 
            +
                        + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
         | 
| 195 | 
            +
                    )
         | 
| 196 | 
            +
                    posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
         | 
| 197 | 
            +
                    posterior_log_variance_clipped = _extract_into_tensor(
         | 
| 198 | 
            +
                        self.posterior_log_variance_clipped, t, x_t.shape
         | 
| 199 | 
            +
                    )
         | 
| 200 | 
            +
                    assert (
         | 
| 201 | 
            +
                        posterior_mean.shape[0]
         | 
| 202 | 
            +
                        == posterior_variance.shape[0]
         | 
| 203 | 
            +
                        == posterior_log_variance_clipped.shape[0]
         | 
| 204 | 
            +
                        == x_start.shape[0]
         | 
| 205 | 
            +
                    )
         | 
| 206 | 
            +
                    return posterior_mean, posterior_variance, posterior_log_variance_clipped
         | 
| 207 | 
            +
             | 
| 208 | 
            +
                def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
         | 
| 209 | 
            +
                    """
         | 
| 210 | 
            +
                    Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
         | 
| 211 | 
            +
                    the initial x, x_0.
         | 
| 212 | 
            +
             | 
| 213 | 
            +
                    :param model: the model, which takes a signal and a batch of timesteps
         | 
| 214 | 
            +
                                  as input.
         | 
| 215 | 
            +
                    :param x: the [N x C x ...] tensor at time t.
         | 
| 216 | 
            +
                    :param t: a 1-D Tensor of timesteps.
         | 
| 217 | 
            +
                    :param clip_denoised: if True, clip the denoised signal into [-1, 1].
         | 
| 218 | 
            +
                    :param denoised_fn: if not None, a function which applies to the
         | 
| 219 | 
            +
                        x_start prediction before it is used to sample. Applies before
         | 
| 220 | 
            +
                        clip_denoised.
         | 
| 221 | 
            +
                    :param model_kwargs: if not None, a dict of extra keyword arguments to
         | 
| 222 | 
            +
                        pass to the model. This can be used for conditioning.
         | 
| 223 | 
            +
                    :return: a dict with the following keys:
         | 
| 224 | 
            +
                             - 'mean': the model mean output.
         | 
| 225 | 
            +
                             - 'variance': the model variance output.
         | 
| 226 | 
            +
                             - 'log_variance': the log of 'variance'.
         | 
| 227 | 
            +
                             - 'pred_xstart': the prediction for x_0.
         | 
| 228 | 
            +
                    """
         | 
| 229 | 
            +
                    if model_kwargs is None:
         | 
| 230 | 
            +
                        model_kwargs = {}
         | 
| 231 | 
            +
             | 
| 232 | 
            +
                    B, C = x.shape[:2]
         | 
| 233 | 
            +
                    assert t.shape == (B,)
         | 
| 234 | 
            +
                    model_output = model(x, t, **model_kwargs)
         | 
| 235 | 
            +
                    if isinstance(model_output, tuple):
         | 
| 236 | 
            +
                        model_output, extra = model_output
         | 
| 237 | 
            +
                    else:
         | 
| 238 | 
            +
                        extra = None
         | 
| 239 | 
            +
             | 
| 240 | 
            +
                    assert model_output.shape == (B, C * 2, *x.shape[2:])
         | 
| 241 | 
            +
                    model_output, model_var_values = th.split(model_output, C, dim=1)
         | 
| 242 | 
            +
                    min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
         | 
| 243 | 
            +
                    max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
         | 
| 244 | 
            +
                    # The model_var_values is [-1, 1] for [min_var, max_var].
         | 
| 245 | 
            +
                    frac = (model_var_values + 1) / 2
         | 
| 246 | 
            +
                    model_log_variance = frac * max_log + (1 - frac) * min_log
         | 
| 247 | 
            +
                    model_variance = th.exp(model_log_variance)
         | 
| 248 | 
            +
             | 
| 249 | 
            +
                    def process_xstart(x):
         | 
| 250 | 
            +
                        if denoised_fn is not None:
         | 
| 251 | 
            +
                            x = denoised_fn(x)
         | 
| 252 | 
            +
                        if clip_denoised:
         | 
| 253 | 
            +
                            return x.clamp(-1, 1)
         | 
| 254 | 
            +
                        return x
         | 
| 255 | 
            +
             | 
| 256 | 
            +
                    pred_xstart = process_xstart(self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output))
         | 
| 257 | 
            +
                    model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
         | 
| 258 | 
            +
             | 
| 259 | 
            +
                    assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
         | 
| 260 | 
            +
                    return {
         | 
| 261 | 
            +
                        "mean": model_mean,
         | 
| 262 | 
            +
                        "variance": model_variance,
         | 
| 263 | 
            +
                        "log_variance": model_log_variance,
         | 
| 264 | 
            +
                        "pred_xstart": pred_xstart,
         | 
| 265 | 
            +
                        "extra": extra,
         | 
| 266 | 
            +
                    }
         | 
| 267 | 
            +
             | 
| 268 | 
            +
                def _predict_xstart_from_eps(self, x_t, t, eps):
         | 
| 269 | 
            +
                    assert x_t.shape == eps.shape
         | 
| 270 | 
            +
                    return (
         | 
| 271 | 
            +
                        _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
         | 
| 272 | 
            +
                        - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
         | 
| 273 | 
            +
                    )
         | 
| 274 | 
            +
             | 
| 275 | 
            +
                def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
         | 
| 276 | 
            +
                    return (
         | 
| 277 | 
            +
                        _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
         | 
| 278 | 
            +
                    ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
         | 
| 279 | 
            +
             | 
| 280 | 
            +
                def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
         | 
| 281 | 
            +
                    """
         | 
| 282 | 
            +
                    Compute the mean for the previous step, given a function cond_fn that
         | 
| 283 | 
            +
                    computes the gradient of a conditional log probability with respect to
         | 
| 284 | 
            +
                    x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
         | 
| 285 | 
            +
                    condition on y.
         | 
| 286 | 
            +
             | 
| 287 | 
            +
                    This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
         | 
| 288 | 
            +
                    """
         | 
| 289 | 
            +
                    gradient = cond_fn(x, t, **model_kwargs)
         | 
| 290 | 
            +
                    new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
         | 
| 291 | 
            +
                    return new_mean
         | 
| 292 | 
            +
             | 
| 293 | 
            +
                def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
         | 
| 294 | 
            +
                    """
         | 
| 295 | 
            +
                    Compute what the p_mean_variance output would have been, should the
         | 
| 296 | 
            +
                    model's score function be conditioned by cond_fn.
         | 
| 297 | 
            +
             | 
| 298 | 
            +
                    See condition_mean() for details on cond_fn.
         | 
| 299 | 
            +
             | 
| 300 | 
            +
                    Unlike condition_mean(), this instead uses the conditioning strategy
         | 
| 301 | 
            +
                    from Song et al (2020).
         | 
| 302 | 
            +
                    """
         | 
| 303 | 
            +
                    alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
         | 
| 304 | 
            +
             | 
| 305 | 
            +
                    eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
         | 
| 306 | 
            +
                    eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
         | 
| 307 | 
            +
             | 
| 308 | 
            +
                    out = p_mean_var.copy()
         | 
| 309 | 
            +
                    out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
         | 
| 310 | 
            +
                    out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
         | 
| 311 | 
            +
                    return out
         | 
| 312 | 
            +
             | 
| 313 | 
            +
                def p_sample(
         | 
| 314 | 
            +
                    self,
         | 
| 315 | 
            +
                    model,
         | 
| 316 | 
            +
                    x,
         | 
| 317 | 
            +
                    t,
         | 
| 318 | 
            +
                    clip_denoised=True,
         | 
| 319 | 
            +
                    denoised_fn=None,
         | 
| 320 | 
            +
                    cond_fn=None,
         | 
| 321 | 
            +
                    model_kwargs=None,
         | 
| 322 | 
            +
                ):
         | 
| 323 | 
            +
                    """
         | 
| 324 | 
            +
                    Sample x_{t-1} from the model at the given timestep.
         | 
| 325 | 
            +
             | 
| 326 | 
            +
                    :param model: the model to sample from.
         | 
| 327 | 
            +
                    :param x: the current tensor at x_{t-1}.
         | 
| 328 | 
            +
                    :param t: the value of t, starting at 0 for the first diffusion step.
         | 
| 329 | 
            +
                    :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
         | 
| 330 | 
            +
                    :param denoised_fn: if not None, a function which applies to the
         | 
| 331 | 
            +
                        x_start prediction before it is used to sample.
         | 
| 332 | 
            +
                    :param cond_fn: if not None, this is a gradient function that acts
         | 
| 333 | 
            +
                                    similarly to the model.
         | 
| 334 | 
            +
                    :param model_kwargs: if not None, a dict of extra keyword arguments to
         | 
| 335 | 
            +
                        pass to the model. This can be used for conditioning.
         | 
| 336 | 
            +
                    :return: a dict containing the following keys:
         | 
| 337 | 
            +
                             - 'sample': a random sample from the model.
         | 
| 338 | 
            +
                             - 'pred_xstart': a prediction of x_0.
         | 
| 339 | 
            +
                    """
         | 
| 340 | 
            +
                    out = self.p_mean_variance(
         | 
| 341 | 
            +
                        model,
         | 
| 342 | 
            +
                        x,
         | 
| 343 | 
            +
                        t,
         | 
| 344 | 
            +
                        clip_denoised=clip_denoised,
         | 
| 345 | 
            +
                        denoised_fn=denoised_fn,
         | 
| 346 | 
            +
                        model_kwargs=model_kwargs,
         | 
| 347 | 
            +
                    )
         | 
| 348 | 
            +
                    noise = th.randn_like(x)
         | 
| 349 | 
            +
                    nonzero_mask = (
         | 
| 350 | 
            +
                        (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
         | 
| 351 | 
            +
                    )  # no noise when t == 0
         | 
| 352 | 
            +
                    if cond_fn is not None:
         | 
| 353 | 
            +
                        out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
         | 
| 354 | 
            +
                    sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
         | 
| 355 | 
            +
                    return {"sample": sample, "pred_xstart": out["pred_xstart"]}
         | 
| 356 | 
            +
             | 
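            # Editor's note (not part of the committed file): p_sample() above performs one
            # ancestral step, sample = mean + exp(0.5 * log_variance) * noise, with the noise
            # zeroed at t == 0 via nonzero_mask; when cond_fn is supplied, condition_mean()
            # shifts the mean by variance * grad(log p(y|x)) for classifier guidance.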
| 357 | 
            +
                def p_sample_loop(
         | 
| 358 | 
            +
                    self,
         | 
| 359 | 
            +
                    model,
         | 
| 360 | 
            +
                    shape,
         | 
| 361 | 
            +
                    noise=None,
         | 
| 362 | 
            +
                    clip_denoised=True,
         | 
| 363 | 
            +
                    denoised_fn=None,
         | 
| 364 | 
            +
                    cond_fn=None,
         | 
| 365 | 
            +
                    model_kwargs=None,
         | 
| 366 | 
            +
                    device=None,
         | 
| 367 | 
            +
                    progress=False,
         | 
| 368 | 
            +
                ):
         | 
| 369 | 
            +
                    """
         | 
| 370 | 
            +
                    Generate samples from the model.
         | 
| 371 | 
            +
             | 
| 372 | 
            +
                    :param model: the model module.
         | 
| 373 | 
            +
                    :param shape: the shape of the samples, (N, C, H, W).
         | 
| 374 | 
            +
                    :param noise: if specified, the noise from the encoder to sample.
         | 
| 375 | 
            +
                                  Should be of the same shape as `shape`.
         | 
| 376 | 
            +
                    :param clip_denoised: if True, clip x_start predictions to [-1, 1].
         | 
| 377 | 
            +
                    :param denoised_fn: if not None, a function which applies to the
         | 
| 378 | 
            +
                        x_start prediction before it is used to sample.
         | 
| 379 | 
            +
                    :param cond_fn: if not None, this is a gradient function that acts
         | 
| 380 | 
            +
                                    similarly to the model.
         | 
| 381 | 
            +
                    :param model_kwargs: if not None, a dict of extra keyword arguments to
         | 
| 382 | 
            +
                        pass to the model. This can be used for conditioning.
         | 
| 383 | 
            +
                    :param device: if specified, the device to create the samples on.
         | 
| 384 | 
            +
                                   If not specified, use a model parameter's device.
         | 
| 385 | 
            +
                    :param progress: if True, show a tqdm progress bar.
         | 
| 386 | 
            +
                    :return: a non-differentiable batch of samples.
         | 
| 387 | 
            +
                    """
         | 
| 388 | 
            +
                    final = None
         | 
| 389 | 
            +
                    for sample in self.p_sample_loop_progressive(
         | 
| 390 | 
            +
                        model,
         | 
| 391 | 
            +
                        shape,
         | 
| 392 | 
            +
                        noise=noise,
         | 
| 393 | 
            +
                        clip_denoised=clip_denoised,
         | 
| 394 | 
            +
                        denoised_fn=denoised_fn,
         | 
| 395 | 
            +
                        cond_fn=cond_fn,
         | 
| 396 | 
            +
                        model_kwargs=model_kwargs,
         | 
| 397 | 
            +
                        device=device,
         | 
| 398 | 
            +
                        progress=progress,
         | 
| 399 | 
            +
                    ):
         | 
| 400 | 
            +
                        final = sample
         | 
| 401 | 
            +
                    return final["sample"]
         | 
| 402 | 
            +
             | 
| 403 | 
            +
                def p_sample_loop_progressive(
         | 
| 404 | 
            +
                    self,
         | 
| 405 | 
            +
                    model,
         | 
| 406 | 
            +
                    shape,
         | 
| 407 | 
            +
                    noise=None,
         | 
| 408 | 
            +
                    clip_denoised=True,
         | 
| 409 | 
            +
                    denoised_fn=None,
         | 
| 410 | 
            +
                    cond_fn=None,
         | 
| 411 | 
            +
                    model_kwargs=None,
         | 
| 412 | 
            +
                    device=None,
         | 
| 413 | 
            +
                    progress=False,
         | 
| 414 | 
            +
                ):
         | 
| 415 | 
            +
                    """
         | 
| 416 | 
            +
                    Generate samples from the model and yield intermediate samples from
         | 
| 417 | 
            +
                    each timestep of diffusion.
         | 
| 418 | 
            +
             | 
| 419 | 
            +
                    Arguments are the same as p_sample_loop().
         | 
| 420 | 
            +
                    Returns a generator over dicts, where each dict is the return value of
         | 
| 421 | 
            +
                    p_sample().
         | 
| 422 | 
            +
                    """
         | 
| 423 | 
            +
                    if device is None:
         | 
| 424 | 
            +
                        device = next(model.parameters()).device
         | 
| 425 | 
            +
                    assert isinstance(shape, (tuple, list))
         | 
| 426 | 
            +
                    if noise is not None:
         | 
| 427 | 
            +
                        img = noise
         | 
| 428 | 
            +
                    else:
         | 
| 429 | 
            +
                        img = th.randn(*shape, device=device)
         | 
| 430 | 
            +
                    indices = list(range(self.num_timesteps))[::-1]
         | 
| 431 | 
            +
             | 
| 432 | 
            +
                    if progress:
         | 
| 433 | 
            +
                        # Lazy import so that we don't depend on tqdm.
         | 
| 434 | 
            +
                        from tqdm.auto import tqdm
         | 
| 435 | 
            +
             | 
| 436 | 
            +
                        indices = tqdm(indices)
         | 
| 437 | 
            +
             | 
| 438 | 
            +
                    for i in indices:
         | 
| 439 | 
            +
                        t = th.tensor([i] * shape[0], device=device)
         | 
| 440 | 
            +
                        with th.no_grad():
         | 
| 441 | 
            +
                            out = self.p_sample(
         | 
| 442 | 
            +
                                model,
         | 
| 443 | 
            +
                                img,
         | 
| 444 | 
            +
                                t,
         | 
| 445 | 
            +
                                clip_denoised=clip_denoised,
         | 
| 446 | 
            +
                                denoised_fn=denoised_fn,
         | 
| 447 | 
            +
                                cond_fn=cond_fn,
         | 
| 448 | 
            +
                                model_kwargs=model_kwargs,
         | 
| 449 | 
            +
                            )
         | 
| 450 | 
            +
                            yield out
         | 
| 451 | 
            +
                            img = out["sample"]
         | 
| 452 | 
            +
             | 
| 453 | 
            +
                def ddim_sample(
         | 
| 454 | 
            +
                    self,
         | 
| 455 | 
            +
                    model,
         | 
| 456 | 
            +
                    x,
         | 
| 457 | 
            +
                    t,
         | 
| 458 | 
            +
                    clip_denoised=True,
         | 
| 459 | 
            +
                    denoised_fn=None,
         | 
| 460 | 
            +
                    cond_fn=None,
         | 
| 461 | 
            +
                    model_kwargs=None,
         | 
| 462 | 
            +
                    eta=0.0,
         | 
| 463 | 
            +
                ):
         | 
| 464 | 
            +
                    """
         | 
| 465 | 
            +
                    Sample x_{t-1} from the model using DDIM.
         | 
| 466 | 
            +
             | 
| 467 | 
            +
                    Same usage as p_sample().
         | 
| 468 | 
            +
                    """
         | 
| 469 | 
            +
                    out = self.p_mean_variance(
         | 
| 470 | 
            +
                        model,
         | 
| 471 | 
            +
                        x,
         | 
| 472 | 
            +
                        t,
         | 
| 473 | 
            +
                        clip_denoised=clip_denoised,
         | 
| 474 | 
            +
                        denoised_fn=denoised_fn,
         | 
| 475 | 
            +
                        model_kwargs=model_kwargs,
         | 
| 476 | 
            +
                    )
         | 
| 477 | 
            +
                    if cond_fn is not None:
         | 
| 478 | 
            +
                        out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
         | 
| 479 | 
            +
             | 
| 480 | 
            +
                    # Usually our model outputs epsilon, but we re-derive it
         | 
| 481 | 
            +
                    # in case we used x_start or x_prev prediction.
         | 
| 482 | 
            +
                    eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
         | 
| 483 | 
            +
             | 
| 484 | 
            +
                    alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
         | 
| 485 | 
            +
                    alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
         | 
| 486 | 
            +
                    sigma = (
         | 
| 487 | 
            +
                        eta
         | 
| 488 | 
            +
                        * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
         | 
| 489 | 
            +
                        * th.sqrt(1 - alpha_bar / alpha_bar_prev)
         | 
| 490 | 
            +
                    )
         | 
| 491 | 
            +
                    # Equation 12.
         | 
| 492 | 
            +
                    noise = th.randn_like(x)
         | 
| 493 | 
            +
                    mean_pred = (
         | 
| 494 | 
            +
                        out["pred_xstart"] * th.sqrt(alpha_bar_prev)
         | 
| 495 | 
            +
                        + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
         | 
| 496 | 
            +
                    )
         | 
| 497 | 
            +
                    nonzero_mask = (
         | 
| 498 | 
            +
                        (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
         | 
| 499 | 
            +
                    )  # no noise when t == 0
         | 
| 500 | 
            +
                    sample = mean_pred + nonzero_mask * sigma * noise
         | 
| 501 | 
            +
                    return {"sample": sample, "pred_xstart": out["pred_xstart"]}
         | 
| 502 | 
            +
             | 
| 503 | 
            +
                def ddim_reverse_sample(
         | 
| 504 | 
            +
                    self,
         | 
| 505 | 
            +
                    model,
         | 
| 506 | 
            +
                    x,
         | 
| 507 | 
            +
                    t,
         | 
| 508 | 
            +
                    clip_denoised=True,
         | 
| 509 | 
            +
                    denoised_fn=None,
         | 
| 510 | 
            +
                    cond_fn=None,
         | 
| 511 | 
            +
                    model_kwargs=None,
         | 
| 512 | 
            +
                    eta=0.0,
         | 
| 513 | 
            +
                ):
         | 
| 514 | 
            +
                    """
         | 
| 515 | 
            +
                    Sample x_{t+1} from the model using DDIM reverse ODE.
         | 
| 516 | 
            +
                    """
         | 
| 517 | 
            +
                    assert eta == 0.0, "Reverse ODE only for deterministic path"
         | 
| 518 | 
            +
                    out = self.p_mean_variance(
         | 
| 519 | 
            +
                        model,
         | 
| 520 | 
            +
                        x,
         | 
| 521 | 
            +
                        t,
         | 
| 522 | 
            +
                        clip_denoised=clip_denoised,
         | 
| 523 | 
            +
                        denoised_fn=denoised_fn,
         | 
| 524 | 
            +
                        model_kwargs=model_kwargs,
         | 
| 525 | 
            +
                    )
         | 
| 526 | 
            +
                    if cond_fn is not None:
         | 
| 527 | 
            +
                        out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
         | 
| 528 | 
            +
                    # Usually our model outputs epsilon, but we re-derive it
         | 
| 529 | 
            +
                    # in case we used x_start or x_prev prediction.
         | 
| 530 | 
            +
                    eps = (
         | 
| 531 | 
            +
                        _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
         | 
| 532 | 
            +
                        - out["pred_xstart"]
         | 
| 533 | 
            +
                    ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
         | 
| 534 | 
            +
                    alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
         | 
| 535 | 
            +
             | 
| 536 | 
            +
                    # Equation 12, reversed: deterministic step from x_t to x_{t+1}.
         | 
| 537 | 
            +
                    mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
         | 
| 538 | 
            +
             | 
| 539 | 
            +
                    return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
         | 
| 540 | 
            +
             | 
| 541 | 
            +
                def ddim_sample_loop(
         | 
| 542 | 
            +
                    self,
         | 
| 543 | 
            +
                    model,
         | 
| 544 | 
            +
                    shape,
         | 
| 545 | 
            +
                    noise=None,
         | 
| 546 | 
            +
                    clip_denoised=True,
         | 
| 547 | 
            +
                    denoised_fn=None,
         | 
| 548 | 
            +
                    cond_fn=None,
         | 
| 549 | 
            +
                    model_kwargs=None,
         | 
| 550 | 
            +
                    device=None,
         | 
| 551 | 
            +
                    progress=False,
         | 
| 552 | 
            +
                    eta=0.0,
         | 
| 553 | 
            +
                ):
         | 
| 554 | 
            +
                    """
         | 
| 555 | 
            +
                    Generate samples from the model using DDIM.
         | 
| 556 | 
            +
             | 
| 557 | 
            +
                    Same usage as p_sample_loop().
         | 
| 558 | 
            +
                    """
         | 
| 559 | 
            +
                    final = None
         | 
| 560 | 
            +
                    for sample in self.ddim_sample_loop_progressive(
         | 
| 561 | 
            +
                        model,
         | 
| 562 | 
            +
                        shape,
         | 
| 563 | 
            +
                        noise=noise,
         | 
| 564 | 
            +
                        clip_denoised=clip_denoised,
         | 
| 565 | 
            +
                        denoised_fn=denoised_fn,
         | 
| 566 | 
            +
                        cond_fn=cond_fn,
         | 
| 567 | 
            +
                        model_kwargs=model_kwargs,
         | 
| 568 | 
            +
                        device=device,
         | 
| 569 | 
            +
                        progress=progress,
         | 
| 570 | 
            +
                        eta=eta,
         | 
| 571 | 
            +
                    ):
         | 
| 572 | 
            +
                        final = sample
         | 
| 573 | 
            +
                    return final["sample"]
         | 
| 574 | 
            +
             | 
| 575 | 
            +
                def ddim_sample_loop_progressive(
         | 
| 576 | 
            +
                    self,
         | 
| 577 | 
            +
                    model,
         | 
| 578 | 
            +
                    shape,
         | 
| 579 | 
            +
                    noise=None,
         | 
| 580 | 
            +
                    clip_denoised=True,
         | 
| 581 | 
            +
                    denoised_fn=None,
         | 
| 582 | 
            +
                    cond_fn=None,
         | 
| 583 | 
            +
                    model_kwargs=None,
         | 
| 584 | 
            +
                    device=None,
         | 
| 585 | 
            +
                    progress=False,
         | 
| 586 | 
            +
                    eta=0.0,
         | 
| 587 | 
            +
                ):
         | 
| 588 | 
            +
                    """
         | 
| 589 | 
            +
                    Use DDIM to sample from the model and yield intermediate samples from
         | 
| 590 | 
            +
                    each timestep of DDIM.
         | 
| 591 | 
            +
             | 
| 592 | 
            +
                    Same usage as p_sample_loop_progressive().
         | 
| 593 | 
            +
                    """
         | 
| 594 | 
            +
                    if device is None:
         | 
| 595 | 
            +
                        device = next(model.parameters()).device
         | 
| 596 | 
            +
                    assert isinstance(shape, (tuple, list))
         | 
| 597 | 
            +
                    if noise is not None:
         | 
| 598 | 
            +
                        img = noise
         | 
| 599 | 
            +
                    else:
         | 
| 600 | 
            +
                        img = th.randn(*shape, device=device)
         | 
| 601 | 
            +
                    indices = list(range(self.num_timesteps))[::-1]
         | 
| 602 | 
            +
             | 
| 603 | 
            +
                    if progress:
         | 
| 604 | 
            +
                        # Lazy import so that we don't depend on tqdm.
         | 
| 605 | 
            +
                        from tqdm.auto import tqdm
         | 
| 606 | 
            +
             | 
| 607 | 
            +
                        indices = tqdm(indices)
         | 
| 608 | 
            +
             | 
| 609 | 
            +
                    for i in indices:
         | 
| 610 | 
            +
                        t = th.tensor([i] * shape[0], device=device)
         | 
| 611 | 
            +
                        with th.no_grad():
         | 
| 612 | 
            +
                            out = self.ddim_sample(
         | 
| 613 | 
            +
                                model,
         | 
| 614 | 
            +
                                img,
         | 
| 615 | 
            +
                                t,
         | 
| 616 | 
            +
                                clip_denoised=clip_denoised,
         | 
| 617 | 
            +
                                denoised_fn=denoised_fn,
         | 
| 618 | 
            +
                                cond_fn=cond_fn,
         | 
| 619 | 
            +
                                model_kwargs=model_kwargs,
         | 
| 620 | 
            +
                                eta=eta,
         | 
| 621 | 
            +
                            )
         | 
| 622 | 
            +
                            yield out
         | 
| 623 | 
            +
                            img = out["sample"]
         | 
| 624 | 
            +
             | 
| 625 | 
            +
             | 
| 626 | 
            +
            def _extract_into_tensor(arr, timesteps, broadcast_shape):
         | 
| 627 | 
            +
                """
         | 
| 628 | 
            +
                Extract values from a 1-D numpy array for a batch of indices.
         | 
| 629 | 
            +
             | 
| 630 | 
            +
                :param arr: the 1-D numpy array.
         | 
| 631 | 
            +
                :param timesteps: a tensor of indices into the array to extract.
         | 
| 632 | 
            +
                :param broadcast_shape: a larger shape of K dimensions with the batch
         | 
| 633 | 
            +
                                        dimension equal to the length of timesteps.
         | 
| 634 | 
            +
                :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
         | 
| 635 | 
            +
                """
         | 
| 636 | 
            +
                res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
         | 
| 637 | 
            +
                while len(res.shape) < len(broadcast_shape):
         | 
| 638 | 
            +
                    res = res[..., None]
         | 
| 639 | 
            +
                return res + th.zeros(broadcast_shape, device=timesteps.device)
         | 
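The `ddim_sample` step above implements Equation 12 of the DDIM paper, x_{t-1} = sqrt(alpha_bar_{t-1}) * pred_xstart + sqrt(1 - alpha_bar_{t-1} - sigma_t^2) * eps + sigma_t * z, and `ddim_sample_loop` simply iterates it from the noisiest timestep down to 0. A minimal usage sketch follows; `model` and `diffusion` are illustrative placeholders for an epsilon-prediction UNet and a diffusion instance built elsewhere (e.g. by `create_model_and_diffusion` in model_creation.py below):

    import torch as th

    # Placeholders: `model` is an epsilon-prediction network, `diffusion` a
    # GaussianDiffusion/SpacedDiffusion instance. A text-conditioned Text2ImUNet
    # would additionally need model_kwargs=dict(tokens=..., mask=...).
    batch_size, image_size = 4, 64
    with th.no_grad():
        samples = diffusion.ddim_sample_loop(
            model,
            (batch_size, 3, image_size, image_size),
            device=next(model.parameters()).device,
            progress=True,  # lazily imports tqdm, as in the loops above
            eta=0.0,        # deterministic DDIM; eta > 0 re-injects noise per Equation 12
        )
    # With clip_denoised=True the outputs are approximately in [-1, 1].
    images = ((samples + 1) * 127.5).clamp(0, 255).to(th.uint8)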
    	
        glide_text2im/model_creation.py
    ADDED
    
    | @@ -0,0 +1,195 @@ | |
| 1 | 
            +
            from glide_text2im.gaussian_diffusion import get_named_beta_schedule
         | 
| 2 | 
            +
            from glide_text2im.respace import SpacedDiffusion, space_timesteps
         | 
| 3 | 
            +
            from glide_text2im.text2im_model import (
         | 
| 4 | 
            +
                InpaintText2ImUNet,
         | 
| 5 | 
            +
                SuperResInpaintText2ImUnet,
         | 
| 6 | 
            +
                SuperResText2ImUNet,
         | 
| 7 | 
            +
                Text2ImUNet,
         | 
| 8 | 
            +
            )
         | 
| 9 | 
            +
            from glide_text2im.tokenizer.bpe import get_encoder
         | 
| 10 | 
            +
             | 
| 11 | 
            +
             | 
| 12 | 
            +
            def model_and_diffusion_defaults():
         | 
| 13 | 
            +
                return dict(
         | 
| 14 | 
            +
                    image_size=64,
         | 
| 15 | 
            +
                    num_channels=192,
         | 
| 16 | 
            +
                    num_res_blocks=3,
         | 
| 17 | 
            +
                    channel_mult="",
         | 
| 18 | 
            +
                    num_heads=1,
         | 
| 19 | 
            +
                    num_head_channels=64,
         | 
| 20 | 
            +
                    num_heads_upsample=-1,
         | 
| 21 | 
            +
                    attention_resolutions="32,16,8",
         | 
| 22 | 
            +
                    dropout=0.1,
         | 
| 23 | 
            +
                    text_ctx=128,
         | 
| 24 | 
            +
                    xf_width=512,
         | 
| 25 | 
            +
                    xf_layers=16,
         | 
| 26 | 
            +
                    xf_heads=8,
         | 
| 27 | 
            +
                    xf_final_ln=True,
         | 
| 28 | 
            +
                    xf_padding=True,
         | 
| 29 | 
            +
                    diffusion_steps=1000,
         | 
| 30 | 
            +
                    noise_schedule="squaredcos_cap_v2",
         | 
| 31 | 
            +
                    timestep_respacing="",
         | 
| 32 | 
            +
                    use_scale_shift_norm=True,
         | 
| 33 | 
            +
                    resblock_updown=True,
         | 
| 34 | 
            +
                    use_fp16=True,
         | 
| 35 | 
            +
                    cache_text_emb=False,
         | 
| 36 | 
            +
                    inpaint=False,
         | 
| 37 | 
            +
                    super_res=False,
         | 
| 38 | 
            +
                )
         | 
| 39 | 
            +
             | 
| 40 | 
            +
             | 
| 41 | 
            +
            def model_and_diffusion_defaults_upsampler():
         | 
| 42 | 
            +
                result = model_and_diffusion_defaults()
         | 
| 43 | 
            +
                result.update(
         | 
| 44 | 
            +
                    dict(
         | 
| 45 | 
            +
                        image_size=256,
         | 
| 46 | 
            +
                        num_res_blocks=2,
         | 
| 47 | 
            +
                        noise_schedule="linear",
         | 
| 48 | 
            +
                        super_res=True,
         | 
| 49 | 
            +
                    )
         | 
| 50 | 
            +
                )
         | 
| 51 | 
            +
                return result
         | 
| 52 | 
            +
             | 
| 53 | 
            +
             | 
| 54 | 
            +
            def create_model_and_diffusion(
         | 
| 55 | 
            +
                image_size,
         | 
| 56 | 
            +
                num_channels,
         | 
| 57 | 
            +
                num_res_blocks,
         | 
| 58 | 
            +
                channel_mult,
         | 
| 59 | 
            +
                num_heads,
         | 
| 60 | 
            +
                num_head_channels,
         | 
| 61 | 
            +
                num_heads_upsample,
         | 
| 62 | 
            +
                attention_resolutions,
         | 
| 63 | 
            +
                dropout,
         | 
| 64 | 
            +
                text_ctx,
         | 
| 65 | 
            +
                xf_width,
         | 
| 66 | 
            +
                xf_layers,
         | 
| 67 | 
            +
                xf_heads,
         | 
| 68 | 
            +
                xf_final_ln,
         | 
| 69 | 
            +
                xf_padding,
         | 
| 70 | 
            +
                diffusion_steps,
         | 
| 71 | 
            +
                noise_schedule,
         | 
| 72 | 
            +
                timestep_respacing,
         | 
| 73 | 
            +
                use_scale_shift_norm,
         | 
| 74 | 
            +
                resblock_updown,
         | 
| 75 | 
            +
                use_fp16,
         | 
| 76 | 
            +
                cache_text_emb,
         | 
| 77 | 
            +
                inpaint,
         | 
| 78 | 
            +
                super_res,
         | 
| 79 | 
            +
            ):
         | 
| 80 | 
            +
                model = create_model(
         | 
| 81 | 
            +
                    image_size,
         | 
| 82 | 
            +
                    num_channels,
         | 
| 83 | 
            +
                    num_res_blocks,
         | 
| 84 | 
            +
                    channel_mult=channel_mult,
         | 
| 85 | 
            +
                    attention_resolutions=attention_resolutions,
         | 
| 86 | 
            +
                    num_heads=num_heads,
         | 
| 87 | 
            +
                    num_head_channels=num_head_channels,
         | 
| 88 | 
            +
                    num_heads_upsample=num_heads_upsample,
         | 
| 89 | 
            +
                    use_scale_shift_norm=use_scale_shift_norm,
         | 
| 90 | 
            +
                    dropout=dropout,
         | 
| 91 | 
            +
                    text_ctx=text_ctx,
         | 
| 92 | 
            +
                    xf_width=xf_width,
         | 
| 93 | 
            +
                    xf_layers=xf_layers,
         | 
| 94 | 
            +
                    xf_heads=xf_heads,
         | 
| 95 | 
            +
                    xf_final_ln=xf_final_ln,
         | 
| 96 | 
            +
                    xf_padding=xf_padding,
         | 
| 97 | 
            +
                    resblock_updown=resblock_updown,
         | 
| 98 | 
            +
                    use_fp16=use_fp16,
         | 
| 99 | 
            +
                    cache_text_emb=cache_text_emb,
         | 
| 100 | 
            +
                    inpaint=inpaint,
         | 
| 101 | 
            +
                    super_res=super_res,
         | 
| 102 | 
            +
                )
         | 
| 103 | 
            +
                diffusion = create_gaussian_diffusion(
         | 
| 104 | 
            +
                    steps=diffusion_steps,
         | 
| 105 | 
            +
                    noise_schedule=noise_schedule,
         | 
| 106 | 
            +
                    timestep_respacing=timestep_respacing,
         | 
| 107 | 
            +
                )
         | 
| 108 | 
            +
                return model, diffusion
         | 
| 109 | 
            +
             | 
| 110 | 
            +
             | 
| 111 | 
            +
            def create_model(
         | 
| 112 | 
            +
                image_size,
         | 
| 113 | 
            +
                num_channels,
         | 
| 114 | 
            +
                num_res_blocks,
         | 
| 115 | 
            +
                channel_mult,
         | 
| 116 | 
            +
                attention_resolutions,
         | 
| 117 | 
            +
                num_heads,
         | 
| 118 | 
            +
                num_head_channels,
         | 
| 119 | 
            +
                num_heads_upsample,
         | 
| 120 | 
            +
                use_scale_shift_norm,
         | 
| 121 | 
            +
                dropout,
         | 
| 122 | 
            +
                text_ctx,
         | 
| 123 | 
            +
                xf_width,
         | 
| 124 | 
            +
                xf_layers,
         | 
| 125 | 
            +
                xf_heads,
         | 
| 126 | 
            +
                xf_final_ln,
         | 
| 127 | 
            +
                xf_padding,
         | 
| 128 | 
            +
                resblock_updown,
         | 
| 129 | 
            +
                use_fp16,
         | 
| 130 | 
            +
                cache_text_emb,
         | 
| 131 | 
            +
                inpaint,
         | 
| 132 | 
            +
                super_res,
         | 
| 133 | 
            +
            ):
         | 
| 134 | 
            +
                if channel_mult == "":
         | 
| 135 | 
            +
                    if image_size == 256:
         | 
| 136 | 
            +
                        channel_mult = (1, 1, 2, 2, 4, 4)
         | 
| 137 | 
            +
                    elif image_size == 128:
         | 
| 138 | 
            +
                        channel_mult = (1, 1, 2, 3, 4)
         | 
| 139 | 
            +
                    elif image_size == 64:
         | 
| 140 | 
            +
                        channel_mult = (1, 2, 3, 4)
         | 
| 141 | 
            +
                    else:
         | 
| 142 | 
            +
                        raise ValueError(f"unsupported image size: {image_size}")
         | 
| 143 | 
            +
                else:
         | 
| 144 | 
            +
                    channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(","))
         | 
| 145 | 
            +
                    assert 2 ** (len(channel_mult) + 2) == image_size
         | 
| 146 | 
            +
             | 
| 147 | 
            +
                attention_ds = []
         | 
| 148 | 
            +
                for res in attention_resolutions.split(","):
         | 
| 149 | 
            +
                    attention_ds.append(image_size // int(res))
         | 
| 150 | 
            +
             | 
| 151 | 
            +
                if inpaint and super_res:
         | 
| 152 | 
            +
                    model_cls = SuperResInpaintText2ImUnet
         | 
| 153 | 
            +
                elif inpaint:
         | 
| 154 | 
            +
                    model_cls = InpaintText2ImUNet
         | 
| 155 | 
            +
                elif super_res:
         | 
| 156 | 
            +
                    model_cls = SuperResText2ImUNet
         | 
| 157 | 
            +
                else:
         | 
| 158 | 
            +
                    model_cls = Text2ImUNet
         | 
| 159 | 
            +
                return model_cls(
         | 
| 160 | 
            +
                    text_ctx=text_ctx,
         | 
| 161 | 
            +
                    xf_width=xf_width,
         | 
| 162 | 
            +
                    xf_layers=xf_layers,
         | 
| 163 | 
            +
                    xf_heads=xf_heads,
         | 
| 164 | 
            +
                    xf_final_ln=xf_final_ln,
         | 
| 165 | 
            +
                    tokenizer=get_encoder(),
         | 
| 166 | 
            +
                    xf_padding=xf_padding,
         | 
| 167 | 
            +
                    in_channels=3,
         | 
| 168 | 
            +
                    model_channels=num_channels,
         | 
| 169 | 
            +
                    out_channels=6,
         | 
| 170 | 
            +
                    num_res_blocks=num_res_blocks,
         | 
| 171 | 
            +
                    attention_resolutions=tuple(attention_ds),
         | 
| 172 | 
            +
                    dropout=dropout,
         | 
| 173 | 
            +
                    channel_mult=channel_mult,
         | 
| 174 | 
            +
                    use_fp16=use_fp16,
         | 
| 175 | 
            +
                    num_heads=num_heads,
         | 
| 176 | 
            +
                    num_head_channels=num_head_channels,
         | 
| 177 | 
            +
                    num_heads_upsample=num_heads_upsample,
         | 
| 178 | 
            +
                    use_scale_shift_norm=use_scale_shift_norm,
         | 
| 179 | 
            +
                    resblock_updown=resblock_updown,
         | 
| 180 | 
            +
                    cache_text_emb=cache_text_emb,
         | 
| 181 | 
            +
                )
         | 
| 182 | 
            +
             | 
| 183 | 
            +
             | 
| 184 | 
            +
            def create_gaussian_diffusion(
         | 
| 185 | 
            +
                steps,
         | 
| 186 | 
            +
                noise_schedule,
         | 
| 187 | 
            +
                timestep_respacing,
         | 
| 188 | 
            +
            ):
         | 
| 189 | 
            +
                betas = get_named_beta_schedule(noise_schedule, steps)
         | 
| 190 | 
            +
                if not timestep_respacing:
         | 
| 191 | 
            +
                    timestep_respacing = [steps]
         | 
| 192 | 
            +
                return SpacedDiffusion(
         | 
| 193 | 
            +
                    use_timesteps=space_timesteps(steps, timestep_respacing),
         | 
| 194 | 
            +
                    betas=betas,
         | 
| 195 | 
            +
                )
         | 
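These factory helpers are the intended entry point for the package: start from `model_and_diffusion_defaults()` (or the `_upsampler` variant), override fields such as `timestep_respacing` or `use_fp16`, and pass the dict to `create_model_and_diffusion`. A minimal sketch, assuming pretrained weights are loaded separately (e.g. via glide_text2im/download.py, not shown here):

    from glide_text2im.model_creation import (
        create_model_and_diffusion,
        model_and_diffusion_defaults,
    )

    options = model_and_diffusion_defaults()
    options["use_fp16"] = False            # keep fp32, e.g. when running on CPU
    options["timestep_respacing"] = "100"  # sample with 100 of the 1000 training steps

    model, diffusion = create_model_and_diffusion(**options)
    model.eval()
    if options["use_fp16"]:
        model.convert_to_fp16()
    # model.load_state_dict(...) with the downloaded checkpoint would go here.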
    	
        glide_text2im/nn.py
    ADDED
    
    | @@ -0,0 +1,105 @@ | |
| 1 | 
            +
            """
         | 
| 2 | 
            +
            Various utilities for neural networks.
         | 
| 3 | 
            +
            """
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            import math
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            import torch as th
         | 
| 8 | 
            +
            import torch.nn as nn
         | 
| 9 | 
            +
            import torch.nn.functional as F
         | 
| 10 | 
            +
             | 
| 11 | 
            +
             | 
| 12 | 
            +
            class GroupNorm32(nn.GroupNorm):
         | 
| 13 | 
            +
                def __init__(self, num_groups, num_channels, swish, eps=1e-5):
         | 
| 14 | 
            +
                    super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps)
         | 
| 15 | 
            +
                    self.swish = swish
         | 
| 16 | 
            +
             | 
| 17 | 
            +
                def forward(self, x):
         | 
| 18 | 
            +
                    y = super().forward(x.float()).to(x.dtype)
         | 
| 19 | 
            +
                    if self.swish == 1.0:
         | 
| 20 | 
            +
                        y = F.silu(y)
         | 
| 21 | 
            +
                    elif self.swish:
         | 
| 22 | 
            +
                        y = y * th.sigmoid(y * float(self.swish))  # F.sigmoid is deprecated
         | 
| 23 | 
            +
                    return y
         | 
| 24 | 
            +
             | 
| 25 | 
            +
             | 
| 26 | 
            +
            def conv_nd(dims, *args, **kwargs):
         | 
| 27 | 
            +
                """
         | 
| 28 | 
            +
                Create a 1D, 2D, or 3D convolution module.
         | 
| 29 | 
            +
                """
         | 
| 30 | 
            +
                if dims == 1:
         | 
| 31 | 
            +
                    return nn.Conv1d(*args, **kwargs)
         | 
| 32 | 
            +
                elif dims == 2:
         | 
| 33 | 
            +
                    return nn.Conv2d(*args, **kwargs)
         | 
| 34 | 
            +
                elif dims == 3:
         | 
| 35 | 
            +
                    return nn.Conv3d(*args, **kwargs)
         | 
| 36 | 
            +
                raise ValueError(f"unsupported dimensions: {dims}")
         | 
| 37 | 
            +
             | 
| 38 | 
            +
             | 
| 39 | 
            +
            def linear(*args, **kwargs):
         | 
| 40 | 
            +
                """
         | 
| 41 | 
            +
                Create a linear module.
         | 
| 42 | 
            +
                """
         | 
| 43 | 
            +
                return nn.Linear(*args, **kwargs)
         | 
| 44 | 
            +
             | 
| 45 | 
            +
             | 
| 46 | 
            +
            def avg_pool_nd(dims, *args, **kwargs):
         | 
| 47 | 
            +
                """
         | 
| 48 | 
            +
                Create a 1D, 2D, or 3D average pooling module.
         | 
| 49 | 
            +
                """
         | 
| 50 | 
            +
                if dims == 1:
         | 
| 51 | 
            +
                    return nn.AvgPool1d(*args, **kwargs)
         | 
| 52 | 
            +
                elif dims == 2:
         | 
| 53 | 
            +
                    return nn.AvgPool2d(*args, **kwargs)
         | 
| 54 | 
            +
                elif dims == 3:
         | 
| 55 | 
            +
                    return nn.AvgPool3d(*args, **kwargs)
         | 
| 56 | 
            +
                raise ValueError(f"unsupported dimensions: {dims}")
         | 
| 57 | 
            +
             | 
| 58 | 
            +
             | 
| 59 | 
            +
            def zero_module(module):
         | 
| 60 | 
            +
                """
         | 
| 61 | 
            +
                Zero out the parameters of a module and return it.
         | 
| 62 | 
            +
                """
         | 
| 63 | 
            +
                for p in module.parameters():
         | 
| 64 | 
            +
                    p.detach().zero_()
         | 
| 65 | 
            +
                return module
         | 
| 66 | 
            +
             | 
| 67 | 
            +
             | 
| 68 | 
            +
            def scale_module(module, scale):
         | 
| 69 | 
            +
                """
         | 
| 70 | 
            +
                Scale the parameters of a module and return it.
         | 
| 71 | 
            +
                """
         | 
| 72 | 
            +
                for p in module.parameters():
         | 
| 73 | 
            +
                    p.detach().mul_(scale)
         | 
| 74 | 
            +
                return module
         | 
| 75 | 
            +
             | 
| 76 | 
            +
             | 
| 77 | 
            +
            def normalization(channels, swish=0.0):
         | 
| 78 | 
            +
                """
         | 
| 79 | 
            +
                Make a standard normalization layer, with an optional swish activation.
         | 
| 80 | 
            +
             | 
| 81 | 
            +
                :param channels: number of input channels.
         | 
| 82 | 
            +
                :return: an nn.Module for normalization.
         | 
| 83 | 
            +
                """
         | 
| 84 | 
            +
                return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
         | 
| 85 | 
            +
             | 
| 86 | 
            +
             | 
| 87 | 
            +
            def timestep_embedding(timesteps, dim, max_period=10000):
         | 
| 88 | 
            +
                """
         | 
| 89 | 
            +
                Create sinusoidal timestep embeddings.
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                :param timesteps: a 1-D Tensor of N indices, one per batch element.
         | 
| 92 | 
            +
                                  These may be fractional.
         | 
| 93 | 
            +
                :param dim: the dimension of the output.
         | 
| 94 | 
            +
                :param max_period: controls the minimum frequency of the embeddings.
         | 
| 95 | 
            +
                :return: an [N x dim] Tensor of positional embeddings.
         | 
| 96 | 
            +
                """
         | 
| 97 | 
            +
                half = dim // 2
         | 
| 98 | 
            +
                freqs = th.exp(
         | 
| 99 | 
            +
                    -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
         | 
| 100 | 
            +
                ).to(device=timesteps.device)
         | 
| 101 | 
            +
                args = timesteps[:, None].float() * freqs[None]
         | 
| 102 | 
            +
                embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
         | 
| 103 | 
            +
                if dim % 2:
         | 
| 104 | 
            +
                    embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
         | 
| 105 | 
            +
                return embedding
         | 
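As a quick illustration of `timestep_embedding`: the first `dim // 2` output columns are cosine features and the remaining columns the matching sine features at geometrically spaced frequencies; an odd `dim` is zero-padded by the final branch. A standalone check:

    import torch as th
    from glide_text2im.nn import timestep_embedding

    t = th.tensor([0, 10, 999])       # one (possibly fractional) timestep per batch element
    emb = timestep_embedding(t, dim=192)
    print(emb.shape)                  # torch.Size([3, 192])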
    	
        glide_text2im/respace.py
    ADDED
    
    | @@ -0,0 +1,117 @@ | |
| 1 | 
            +
            """
         | 
| 2 | 
            +
            Utilities for changing sampling schedules of a trained model.
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            Simplified from: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/respace.py
         | 
| 5 | 
            +
            """
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            import numpy as np
         | 
| 8 | 
            +
            import torch as th
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            from .gaussian_diffusion import GaussianDiffusion
         | 
| 11 | 
            +
             | 
| 12 | 
            +
             | 
| 13 | 
            +
            def space_timesteps(num_timesteps, section_counts):
         | 
| 14 | 
            +
                """
         | 
| 15 | 
            +
                Create a list of timesteps to use from an original diffusion process,
         | 
| 16 | 
            +
                given the number of timesteps we want to take from equally-sized portions
         | 
| 17 | 
            +
                of the original process.
         | 
| 18 | 
            +
             | 
| 19 | 
            +
                For example, if there are 300 timesteps and the section counts are [10,15,20],
         | 
| 20 | 
            +
                then the first 100 timesteps are strided to be 10 timesteps, the second 100
         | 
| 21 | 
            +
                are strided to be 15 timesteps, and the final 100 are strided to be 20.
         | 
| 22 | 
            +
             | 
| 23 | 
            +
                :param num_timesteps: the number of diffusion steps in the original
         | 
| 24 | 
            +
                                      process to divide up.
         | 
| 25 | 
            +
                :param section_counts: either a list of numbers, or a string containing
         | 
| 26 | 
            +
                                       comma-separated numbers, indicating the step count
         | 
| 27 | 
            +
                                       per section. As a special case, use "ddimN" where N
         | 
| 28 | 
            +
                                       is a number of steps to use the striding from the
         | 
| 29 | 
            +
                                       DDIM paper.
         | 
| 30 | 
            +
                :return: a set of diffusion steps from the original process to use.
         | 
| 31 | 
            +
                """
         | 
| 32 | 
            +
                if isinstance(section_counts, str):
         | 
| 33 | 
            +
                    if section_counts.startswith("ddim"):
         | 
| 34 | 
            +
                        desired_count = int(section_counts[len("ddim") :])
         | 
| 35 | 
            +
                        for i in range(1, num_timesteps):
         | 
| 36 | 
            +
                            if len(range(0, num_timesteps, i)) == desired_count:
         | 
| 37 | 
            +
                                return set(range(0, num_timesteps, i))
         | 
| 38 | 
            +
                        raise ValueError(f"cannot create exactly {num_timesteps} steps with an integer stride")
         | 
| 39 | 
            +
                    elif section_counts == "fast27":
         | 
| 40 | 
            +
                        steps = space_timesteps(num_timesteps, "10,10,3,2,2")
         | 
| 41 | 
            +
                        # Help reduce DDIM artifacts from noisiest timesteps.
         | 
| 42 | 
            +
                        steps.remove(num_timesteps - 1)
         | 
| 43 | 
            +
                        steps.add(num_timesteps - 3)
         | 
| 44 | 
            +
                        return steps
         | 
| 45 | 
            +
                    section_counts = [int(x) for x in section_counts.split(",")]
         | 
| 46 | 
            +
                size_per = num_timesteps // len(section_counts)
         | 
| 47 | 
            +
                extra = num_timesteps % len(section_counts)
         | 
| 48 | 
            +
                start_idx = 0
         | 
| 49 | 
            +
                all_steps = []
         | 
| 50 | 
            +
                for i, section_count in enumerate(section_counts):
         | 
| 51 | 
            +
                    size = size_per + (1 if i < extra else 0)
         | 
| 52 | 
            +
                    if size < section_count:
         | 
| 53 | 
            +
                        raise ValueError(f"cannot divide section of {size} steps into {section_count}")
         | 
| 54 | 
            +
                    if section_count <= 1:
         | 
| 55 | 
            +
                        frac_stride = 1
         | 
| 56 | 
            +
                    else:
         | 
| 57 | 
            +
                        frac_stride = (size - 1) / (section_count - 1)
         | 
| 58 | 
            +
                    cur_idx = 0.0
         | 
| 59 | 
            +
                    taken_steps = []
         | 
| 60 | 
            +
                    for _ in range(section_count):
         | 
| 61 | 
            +
                        taken_steps.append(start_idx + round(cur_idx))
         | 
| 62 | 
            +
                        cur_idx += frac_stride
         | 
| 63 | 
            +
                    all_steps += taken_steps
         | 
| 64 | 
            +
                    start_idx += size
         | 
| 65 | 
            +
                return set(all_steps)
         | 
| 66 | 
            +
             | 
| 67 | 
            +
             | 
| 68 | 
            +
            class SpacedDiffusion(GaussianDiffusion):
         | 
| 69 | 
            +
                """
         | 
| 70 | 
            +
                A diffusion process which can skip steps in a base diffusion process.
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                :param use_timesteps: a collection (sequence or set) of timesteps from the
         | 
| 73 | 
            +
                                      original diffusion process to retain.
         | 
| 74 | 
            +
                :param kwargs: the kwargs to create the base diffusion process.
         | 
| 75 | 
            +
                """
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                def __init__(self, use_timesteps, **kwargs):
         | 
| 78 | 
            +
                    self.use_timesteps = set(use_timesteps)
         | 
| 79 | 
            +
                    self.timestep_map = []
         | 
| 80 | 
            +
                    self.original_num_steps = len(kwargs["betas"])
         | 
| 81 | 
            +
             | 
| 82 | 
            +
                    base_diffusion = GaussianDiffusion(**kwargs)  # pylint: disable=missing-kwoa
         | 
| 83 | 
            +
                    last_alpha_cumprod = 1.0
         | 
| 84 | 
            +
                    new_betas = []
         | 
| 85 | 
            +
                    for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
         | 
| 86 | 
            +
                        if i in self.use_timesteps:
         | 
| 87 | 
            +
                            new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
         | 
| 88 | 
            +
                            last_alpha_cumprod = alpha_cumprod
         | 
| 89 | 
            +
                            self.timestep_map.append(i)
         | 
| 90 | 
            +
                    kwargs["betas"] = np.array(new_betas)
         | 
| 91 | 
            +
                    super().__init__(**kwargs)
         | 
| 92 | 
            +
             | 
| 93 | 
            +
                def p_mean_variance(self, model, *args, **kwargs):
         | 
| 94 | 
            +
                    return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
         | 
| 95 | 
            +
             | 
| 96 | 
            +
                def condition_mean(self, cond_fn, *args, **kwargs):
         | 
| 97 | 
            +
                    return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                def condition_score(self, cond_fn, *args, **kwargs):
         | 
| 100 | 
            +
                    return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                def _wrap_model(self, model):
         | 
| 103 | 
            +
                    if isinstance(model, _WrappedModel):
         | 
| 104 | 
            +
                        return model
         | 
| 105 | 
            +
                    return _WrappedModel(model, self.timestep_map, self.original_num_steps)
         | 
| 106 | 
            +
             | 
| 107 | 
            +
             | 
| 108 | 
            +
            class _WrappedModel:
         | 
| 109 | 
            +
                def __init__(self, model, timestep_map, original_num_steps):
         | 
| 110 | 
            +
                    self.model = model
         | 
| 111 | 
            +
                    self.timestep_map = timestep_map
         | 
| 112 | 
            +
                    self.original_num_steps = original_num_steps
         | 
| 113 | 
            +
             | 
| 114 | 
            +
                def __call__(self, x, ts, **kwargs):
         | 
| 115 | 
            +
                    map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
         | 
| 116 | 
            +
                    new_ts = map_tensor[ts]
         | 
| 117 | 
            +
                    return self.model(x, new_ts, **kwargs)
         | 
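The "ddimN" and "fast27" shortcuts in `space_timesteps` sit alongside the generic per-section striding; two concrete calls, standalone and exercising only the function defined above:

    from glide_text2im.respace import space_timesteps

    # 1000 training steps thinned to 25 evenly strided DDIM steps (stride 40).
    print(sorted(space_timesteps(1000, "ddim25")))  # [0, 40, 80, ..., 960]

    # "fast27" takes 10+10+3+2+2 steps from five equal sections, then swaps the
    # noisiest retained timestep (999) for 997 to reduce DDIM artifacts.
    print(len(space_timesteps(1000, "fast27")))     # 27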
    	
        glide_text2im/text2im_model.py
    ADDED
    
    | @@ -0,0 +1,233 @@ | |
| 1 | 
            +
            import torch as th
         | 
| 2 | 
            +
            import torch.nn as nn
         | 
| 3 | 
            +
            import torch.nn.functional as F
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            from .nn import timestep_embedding
         | 
| 6 | 
            +
            from .unet import UNetModel
         | 
| 7 | 
            +
            from .xf import LayerNorm, Transformer, convert_module_to_f16
         | 
| 8 | 
            +
             | 
| 9 | 
            +
             | 
| 10 | 
            +
            class Text2ImUNet(UNetModel):
         | 
| 11 | 
            +
                """
         | 
| 12 | 
            +
                A UNetModel that conditions on text with an encoding transformer.
         | 
| 13 | 
            +
             | 
| 14 | 
            +
                Expects an extra kwarg `tokens` of text.
         | 
| 15 | 
            +
             | 
| 16 | 
            +
                :param text_ctx: number of text tokens to expect.
         | 
| 17 | 
            +
                :param xf_width: width of the transformer.
         | 
| 18 | 
            +
                :param xf_layers: depth of the transformer.
         | 
| 19 | 
            +
                :param xf_heads: heads in the transformer.
         | 
| 20 | 
            +
                :param xf_final_ln: use a LayerNorm after the output layer.
         | 
| 21 | 
            +
                :param tokenizer: the text tokenizer for sampling/vocab size.
         | 
| 22 | 
            +
                """
         | 
| 23 | 
            +
             | 
| 24 | 
            +
                def __init__(
         | 
| 25 | 
            +
                    self,
         | 
| 26 | 
            +
                    text_ctx,
         | 
| 27 | 
            +
                    xf_width,
         | 
| 28 | 
            +
                    xf_layers,
         | 
| 29 | 
            +
                    xf_heads,
         | 
| 30 | 
            +
                    xf_final_ln,
         | 
| 31 | 
            +
                    tokenizer,
         | 
| 32 | 
            +
                    *args,
         | 
| 33 | 
            +
                    cache_text_emb=False,
         | 
| 34 | 
            +
                    xf_ar=0.0,
         | 
| 35 | 
            +
                    xf_padding=False,
         | 
| 36 | 
            +
        share_unemb=False,
        **kwargs,
    ):
        self.text_ctx = text_ctx
        self.xf_width = xf_width
        self.xf_ar = xf_ar
        self.xf_padding = xf_padding
        self.tokenizer = tokenizer

        if not xf_width:
            super().__init__(*args, **kwargs, encoder_channels=None)
        else:
            super().__init__(*args, **kwargs, encoder_channels=xf_width)
        if self.xf_width:
            self.transformer = Transformer(
                text_ctx,
                xf_width,
                xf_layers,
                xf_heads,
            )
            if xf_final_ln:
                self.final_ln = LayerNorm(xf_width)
            else:
                self.final_ln = None

            self.token_embedding = nn.Embedding(self.tokenizer.n_vocab, xf_width)
            self.positional_embedding = nn.Parameter(th.empty(text_ctx, xf_width, dtype=th.float32))
            self.transformer_proj = nn.Linear(xf_width, self.model_channels * 4)

            if self.xf_padding:
                self.padding_embedding = nn.Parameter(
                    th.empty(text_ctx, xf_width, dtype=th.float32)
                )
            if self.xf_ar:
                self.unemb = nn.Linear(xf_width, self.tokenizer.n_vocab)
                if share_unemb:
                    self.unemb.weight = self.token_embedding.weight

        self.cache_text_emb = cache_text_emb
        self.cache = None

    def convert_to_fp16(self):
        super().convert_to_fp16()
        if self.xf_width:
            self.transformer.apply(convert_module_to_f16)
            self.transformer_proj.to(th.float16)
            self.token_embedding.to(th.float16)
            self.positional_embedding.to(th.float16)
            if self.xf_padding:
                self.padding_embedding.to(th.float16)
            if self.xf_ar:
                self.unemb.to(th.float16)

    def get_text_emb(self, tokens, mask):
        assert tokens is not None

        if self.cache_text_emb and self.cache is not None:
            assert (
                tokens == self.cache["tokens"]
            ).all(), f"Tokens {tokens.cpu().numpy().tolist()} do not match cache {self.cache['tokens'].cpu().numpy().tolist()}"
            return self.cache

        xf_in = self.token_embedding(tokens.long())
        xf_in = xf_in + self.positional_embedding[None]
        if self.xf_padding:
            assert mask is not None
            xf_in = th.where(mask[..., None], xf_in, self.padding_embedding[None])
        xf_out = self.transformer(xf_in.to(self.dtype))
        if self.final_ln is not None:
            xf_out = self.final_ln(xf_out)
        xf_proj = self.transformer_proj(xf_out[:, -1])
        xf_out = xf_out.permute(0, 2, 1)  # NLC -> NCL

        outputs = dict(xf_proj=xf_proj, xf_out=xf_out)

        if self.cache_text_emb:
            self.cache = dict(
                tokens=tokens,
                xf_proj=xf_proj.detach(),
                xf_out=xf_out.detach() if xf_out is not None else None,
            )

        return outputs

    def del_cache(self):
        self.cache = None

    def forward(self, x, timesteps, tokens=None, mask=None):
        hs = []
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
        if self.xf_width:
            text_outputs = self.get_text_emb(tokens, mask)
            xf_proj, xf_out = text_outputs["xf_proj"], text_outputs["xf_out"]
            emb = emb + xf_proj.to(emb)
        else:
            xf_out = None
        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb, xf_out)
            hs.append(h)
        h = self.middle_block(h, emb, xf_out)
        for module in self.output_blocks:
            h = th.cat([h, hs.pop()], dim=1)
            h = module(h, emb, xf_out)
        h = h.type(x.dtype)
        h = self.out(h)
        return h


class SuperResText2ImUNet(Text2ImUNet):
    """
    A text2im model that performs super-resolution.
    Expects an extra kwarg `low_res` to condition on a low-resolution image.
    """

    def __init__(self, *args, **kwargs):
        if "in_channels" in kwargs:
            kwargs = dict(kwargs)
            kwargs["in_channels"] = kwargs["in_channels"] * 2
        else:
            # Curse you, Python. Or really, just curse positional arguments :|.
            args = list(args)
            args[1] = args[1] * 2
        super().__init__(*args, **kwargs)

    def forward(self, x, timesteps, low_res=None, **kwargs):
        _, _, new_height, new_width = x.shape
        upsampled = F.interpolate(
            low_res, (new_height, new_width), mode="bilinear", align_corners=False
        )
        x = th.cat([x, upsampled], dim=1)
        return super().forward(x, timesteps, **kwargs)


class InpaintText2ImUNet(Text2ImUNet):
    """
    A text2im model which can perform inpainting.
    """

    def __init__(self, *args, **kwargs):
        if "in_channels" in kwargs:
            kwargs = dict(kwargs)
            kwargs["in_channels"] = kwargs["in_channels"] * 2 + 1
        else:
            # Curse you, Python. Or really, just curse positional arguments :|.
            args = list(args)
            args[1] = args[1] * 2 + 1
        super().__init__(*args, **kwargs)

    def forward(self, x, timesteps, inpaint_image=None, inpaint_mask=None, **kwargs):
        if inpaint_image is None:
            inpaint_image = th.zeros_like(x)
        if inpaint_mask is None:
            inpaint_mask = th.zeros_like(x[:, :1])
        return super().forward(
            th.cat([x, inpaint_image * inpaint_mask, inpaint_mask], dim=1),
            timesteps,
            **kwargs,
        )


class SuperResInpaintText2ImUnet(Text2ImUNet):
    """
    A text2im model which can perform both upsampling and inpainting.
    """

    def __init__(self, *args, **kwargs):
        if "in_channels" in kwargs:
            kwargs = dict(kwargs)
            kwargs["in_channels"] = kwargs["in_channels"] * 3 + 1
        else:
            # Curse you, Python. Or really, just curse positional arguments :|.
            args = list(args)
            args[1] = args[1] * 3 + 1
        super().__init__(*args, **kwargs)

    def forward(
        self,
        x,
        timesteps,
        inpaint_image=None,
        inpaint_mask=None,
        low_res=None,
        **kwargs,
    ):
        if inpaint_image is None:
            inpaint_image = th.zeros_like(x)
        if inpaint_mask is None:
            inpaint_mask = th.zeros_like(x[:, :1])
        _, _, new_height, new_width = x.shape
        upsampled = F.interpolate(
            low_res, (new_height, new_width), mode="bilinear", align_corners=False
        )
        return super().forward(
            th.cat([x, inpaint_image * inpaint_mask, inpaint_mask, upsampled], dim=1),
            timesteps,
            **kwargs,
        )
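The three subclasses above only change how the UNet input is assembled: the channel arithmetic in their __init__ overrides (* 2, * 2 + 1, * 3 + 1) has to match the th.cat calls in their forward methods. A minimal sketch of that bookkeeping with dummy tensors, not part of the committed file; the shapes below are illustrative assumptions, not values taken from this repo:

import torch as th
import torch.nn.functional as F

# Hypothetical shapes: one RGB sample at 64x64, a low-res conditioning image at 16x16.
x = th.randn(1, 3, 64, 64)             # noisy sample the UNet denoises
low_res = th.randn(1, 3, 16, 16)       # conditioning image for super-resolution
inpaint_image = th.randn(1, 3, 64, 64)
inpaint_mask = th.ones(1, 1, 64, 64)   # 1 = keep pixel, 0 = region to fill in

# SuperResText2ImUNet: upsample low_res and stack it -> 3 * 2 = 6 channels.
upsampled = F.interpolate(low_res, (64, 64), mode="bilinear", align_corners=False)
print(th.cat([x, upsampled], dim=1).shape)                                   # [1, 6, 64, 64]

# InpaintText2ImUNet: masked image plus the mask itself -> 3 * 2 + 1 = 7 channels.
print(th.cat([x, inpaint_image * inpaint_mask, inpaint_mask], dim=1).shape)  # [1, 7, 64, 64]

# SuperResInpaintText2ImUnet: both kinds of conditioning -> 3 * 3 + 1 = 10 channels.
print(th.cat([x, inpaint_image * inpaint_mask, inpaint_mask, upsampled], dim=1).shape)  # [1, 10, 64, 64]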
    	
        glide_text2im/tokenizer/__init__.py
    ADDED
    
File without changes
    	
        glide_text2im/tokenizer/bpe.py
    ADDED
    
@@ -0,0 +1,151 @@
"""
Byte pair encoding utilities adapted from:
https://github.com/openai/gpt-2/blob/master/src/encoder.py
"""

import gzip
import json
import os
from functools import lru_cache
from typing import List, Tuple

import regex as re


@lru_cache()
def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a signficant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    for b in range(2 ** 8):
        if b not in bs:
            bs.append(b)
            cs.append(2 ** 8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


class Encoder:
    def __init__(self, encoder, bpe_merges, errors="replace"):
        self.encoder = encoder
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors  # how to handle errors in decoding
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}

        # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
        self.pat = re.compile(
            r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
        )

    @property
    def n_vocab(self) -> int:
        return len(self.encoder)

    @property
    def end_token(self) -> int:
        return self.n_vocab - 1

    def padded_tokens_and_mask(
        self, tokens: List[int], text_ctx: int
    ) -> Tuple[List[int], List[bool]]:
        tokens = tokens[:text_ctx]
        padding = text_ctx - len(tokens)
        padded_tokens = tokens + [self.end_token] * padding
        mask = [True] * len(tokens) + [False] * padding
        return padded_tokens, mask

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)

        if not pairs:
            return token

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except:  # pylint: disable=bare-except
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = " ".join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        text = text.lower()
        bpe_tokens = []
        for token in re.findall(self.pat, text):
            token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" "))
        return bpe_tokens

    def decode(self, tokens):
        text = "".join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
        return text


def get_encoder():
    root_dir = os.path.dirname(os.path.abspath(__file__))
    with gzip.open(os.path.join(root_dir, "encoder.json.gz"), "r") as f:
        encoder = json.load(f)
    with gzip.open(os.path.join(root_dir, "vocab.bpe.gz"), "r") as f:
        bpe_data = str(f.read(), "utf-8")
    bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]]
    return Encoder(
        encoder=encoder,
        bpe_merges=bpe_merges,
    )
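A hedged usage sketch for the tokenizer above, not part of the committed file. It assumes the package is importable as glide_text2im and that encoder.json.gz / vocab.bpe.gz sit next to bpe.py, which is what get_encoder() expects. It shows how padded_tokens_and_mask produces the fixed-length token/mask pair that Text2ImUNet.get_text_emb consumes:

from glide_text2im.tokenizer.bpe import get_encoder

enc = get_encoder()
tokens = enc.encode("a corgi in a field")  # illustrative prompt; text is lowercased, then BPE-encoded
padded, mask = enc.padded_tokens_and_mask(tokens, text_ctx=128)

assert len(padded) == len(mask) == 128
# Real tokens are marked True; padding positions are filled with the end token and marked False,
# which is what lets the model swap in padding_embedding when xf_padding is enabled.
print(padded[:8], mask[:8])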
    	
        glide_text2im/tokenizer/bpe_simple_vocab_16e6.txt.gz
    ADDED
    
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
size 1356917
    	
        glide_text2im/tokenizer/encoder.json.gz
    ADDED
    
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4debc1cf25180021b07744bc9f4488d53c7bf112c8ce5de8097c6a7518f4ec7c
size 348346
    	
        glide_text2im/tokenizer/simple_tokenizer.py
    ADDED
    
@@ -0,0 +1,163 @@
"""
Copied from: https://github.com/openai/CLIP/blob/573315e83f07b53a61ff5098757e8fc885f1703e/clip/simple_tokenizer.py
"""

import gzip
import html
import os
from functools import lru_cache
from typing import List, Tuple

import ftfy
import regex as re


@lru_cache()
def default_bpe():
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")


@lru_cache()
def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a signficant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    for b in range(2 ** 8):
        if b not in bs:
            bs.append(b)
            cs.append(2 ** 8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r"\s+", " ", text)
    text = text.strip()
    return text


class SimpleTokenizer(object):
    def __init__(self, bpe_path: str = default_bpe()):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        merges = gzip.open(bpe_path).read().decode("utf-8").split("\n")
        merges = merges[1 : 49152 - 256 - 2 + 1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v + "</w>" for v in vocab]
        for merge in merges:
            vocab.append("".join(merge))
        vocab.extend(["<|startoftext|>", "<|endoftext|>"])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {"<|startoftext|>": "<|startoftext|>", "<|endoftext|>": "<|endoftext|>"}
        self.pat = re.compile(
            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
            re.IGNORECASE,
        )

    @property
    def start_token(self):
        return self.encoder["<|startoftext|>"]

    @property
    def end_token(self):
        return self.encoder["<|endoftext|>"]

    def padded_tokens_and_len(self, tokens: List[int], text_ctx: int) -> Tuple[List[int], int]:
        tokens = [self.start_token] + tokens[: text_ctx - 2] + [self.end_token]
        text_len = len(tokens)
        padding = text_ctx - len(tokens)
        padded_tokens = tokens + [0] * padding
        return padded_tokens, text_len

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + "</w>",)
        pairs = get_pairs(word)

        if not pairs:
            return token + "</w>"

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except:  # pylint: disable=bare-except
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = " ".join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" "))
        return bpe_tokens

    def decode(self, tokens):
        text = "".join([self.decoder[token] for token in tokens])
        text = (
            bytearray([self.byte_decoder[c] for c in text])
            .decode("utf-8", errors="replace")
            .replace("</w>", " ")
        )
        return text
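For comparison with bpe.Encoder, a sketch of the CLIP-style tokenizer above, not part of the committed file. It assumes the bundled bpe_simple_vocab_16e6.txt.gz is present, which default_bpe() points at: the sequence is wrapped in start/end tokens, zero-padded to text_ctx, and a length is returned instead of a boolean mask.

from glide_text2im.tokenizer.simple_tokenizer import SimpleTokenizer

tok = SimpleTokenizer()
tokens = tok.encode("a corgi in a field")  # illustrative prompt
padded, text_len = tok.padded_tokens_and_len(tokens, text_ctx=128)

# padded_tokens_and_len wraps the sequence in <|startoftext|> / <|endoftext|> and pads with zeros;
# text_len records how many positions are real tokens.
assert padded[0] == tok.start_token and padded[text_len - 1] == tok.end_token
print(text_len, padded[:text_len])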
    	
        glide_text2im/tokenizer/vocab.bpe.gz
    ADDED
    
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ce239dd5a898827423fee00e3f7ab37de7900f247f2ba360753d860e8a46524d
size 213544
    	
        glide_text2im/unet.py
    ADDED
    
@@ -0,0 +1,635 @@
import math
from abc import abstractmethod

import torch as th
import torch.nn as nn
import torch.nn.functional as F

from .fp16_util import convert_module_to_f16, convert_module_to_f32
from .nn import avg_pool_nd, conv_nd, linear, normalization, timestep_embedding, zero_module


class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """


class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """

    def forward(self, x, emb, encoder_out=None):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb)
            elif isinstance(layer, AttentionBlock):
                x = layer(x, encoder_out)
            else:
                x = layer(x)
        return x


class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 upsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        if self.use_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 downsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        stride = 2 if dims != 3 else (1, 2, 2)
        if use_conv:
            self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=1)
        else:
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)


class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.

    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        self.in_layers = nn.Sequential(
            normalization(channels, swish=1.0),
            nn.Identity(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down

        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
            normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
            nn.SiLU() if use_scale_shift_norm else nn.Identity(),
            nn.Dropout(p=dropout),
            zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.

        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        if self.updown:
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        while len(emb_out.shape) < len(h.shape):
            +
                        emb_out = emb_out[..., None]
         | 
| 195 | 
            +
                    if self.use_scale_shift_norm:
         | 
| 196 | 
            +
                        out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
         | 
| 197 | 
            +
                        scale, shift = th.chunk(emb_out, 2, dim=1)
         | 
| 198 | 
            +
                        h = out_norm(h) * (1 + scale) + shift
         | 
| 199 | 
            +
                        h = out_rest(h)
         | 
| 200 | 
            +
                    else:
         | 
| 201 | 
            +
                        h = h + emb_out
         | 
| 202 | 
            +
                        h = self.out_layers(h)
         | 
| 203 | 
            +
                    return self.skip_connection(x) + h
         | 
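As a quick orientation for the `ResBlock` above: it maps a feature map plus a timestep embedding to a feature map of the same spatial size, injecting the embedding either additively or via the FiLM-style scale/shift path. A minimal usage sketch follows, assuming the `glide_text2im` package added in this commit is importable; the tensor sizes and constructor arguments are illustrative only, not the released GLIDE configuration.

import torch as th

from glide_text2im.unet import ResBlock

# Illustrative sizes: batch of 2, 64 input channels, 32x32 features,
# a 256-dim timestep embedding, and a widened 128-channel output.
block = ResBlock(
    channels=64,
    emb_channels=256,
    dropout=0.0,
    out_channels=128,
    use_scale_shift_norm=True,  # FiLM-style conditioning path
)
x = th.randn(2, 64, 32, 32)
emb = th.randn(2, 256)
h = block(x, emb)
print(h.shape)  # torch.Size([2, 128, 32, 32]); spatial size is preserved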
| 204 | 
            +
             | 
| 205 | 
            +
             | 
| 206 | 
            +
            class AttentionBlock(nn.Module):
         | 
| 207 | 
            +
                """
         | 
| 208 | 
            +
                An attention block that allows spatial positions to attend to each other.
         | 
| 209 | 
            +
             | 
| 210 | 
            +
                Originally ported from here, but adapted to the N-d case.
         | 
| 211 | 
            +
                https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
         | 
| 212 | 
            +
                """
         | 
| 213 | 
            +
             | 
| 214 | 
            +
                def __init__(
         | 
| 215 | 
            +
                    self,
         | 
| 216 | 
            +
                    channels,
         | 
| 217 | 
            +
                    num_heads=1,
         | 
| 218 | 
            +
                    num_head_channels=-1,
         | 
| 219 | 
            +
                    use_checkpoint=False,
         | 
| 220 | 
            +
                    encoder_channels=None,
         | 
| 221 | 
            +
                ):
         | 
| 222 | 
            +
                    super().__init__()
         | 
| 223 | 
            +
                    self.channels = channels
         | 
| 224 | 
            +
                    if num_head_channels == -1:
         | 
| 225 | 
            +
                        self.num_heads = num_heads
         | 
| 226 | 
            +
                    else:
         | 
| 227 | 
            +
                        assert (
         | 
| 228 | 
            +
                            channels % num_head_channels == 0
         | 
| 229 | 
            +
                        ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
         | 
| 230 | 
            +
                        self.num_heads = channels // num_head_channels
         | 
| 231 | 
            +
                    self.use_checkpoint = use_checkpoint
         | 
| 232 | 
            +
                    self.norm = normalization(channels, swish=0.0)
         | 
| 233 | 
            +
                    self.qkv = conv_nd(1, channels, channels * 3, 1)
         | 
| 234 | 
            +
                    self.attention = QKVAttention(self.num_heads)
         | 
| 235 | 
            +
             | 
| 236 | 
            +
                    if encoder_channels is not None:
         | 
| 237 | 
            +
                        self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1)
         | 
| 238 | 
            +
                    self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
         | 
| 239 | 
            +
             | 
| 240 | 
            +
                def forward(self, x, encoder_out=None):
         | 
| 241 | 
            +
                    b, c, *spatial = x.shape
         | 
| 242 | 
            +
                    qkv = self.qkv(self.norm(x).view(b, c, -1))
         | 
| 243 | 
            +
                    if encoder_out is not None:
         | 
| 244 | 
            +
                        encoder_out = self.encoder_kv(encoder_out)
         | 
| 245 | 
            +
                        h = self.attention(qkv, encoder_out)
         | 
| 246 | 
            +
                    else:
         | 
| 247 | 
            +
                        h = self.attention(qkv)
         | 
| 248 | 
            +
                    h = self.proj_out(h)
         | 
| 249 | 
            +
                    return x + h.reshape(b, c, *spatial)
         | 
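The `encoder_channels` / `encoder_out` path above is what lets image features cross-attend to text tokens. A minimal sketch of the shapes involved (illustrative sizes; assumes the `glide_text2im` package is importable):

import torch as th

from glide_text2im.unet import AttentionBlock

# 64-channel 16x16 feature map; an optional text-encoder output with
# 128 channels over 7 tokens supplies extra keys/values.
block = AttentionBlock(channels=64, num_head_channels=32, encoder_channels=128)
x = th.randn(2, 64, 16, 16)
encoder_out = th.randn(2, 128, 7)
h = block(x, encoder_out)  # text keys/values are prepended inside QKVAttention
print(h.shape)             # torch.Size([2, 64, 16, 16])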
| 250 | 
            +
             | 
| 251 | 
            +
             | 
| 252 | 
            +
            class QKVAttention(nn.Module):
         | 
| 253 | 
            +
                """
         | 
| 254 | 
            +
    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
         | 
| 255 | 
            +
                """
         | 
| 256 | 
            +
             | 
| 257 | 
            +
                def __init__(self, n_heads):
         | 
| 258 | 
            +
                    super().__init__()
         | 
| 259 | 
            +
                    self.n_heads = n_heads
         | 
| 260 | 
            +
             | 
| 261 | 
            +
                def forward(self, qkv, encoder_kv=None):
         | 
| 262 | 
            +
                    """
         | 
| 263 | 
            +
                    Apply QKV attention.
         | 
| 264 | 
            +
             | 
| 265 | 
            +
                    :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
         | 
| 266 | 
            +
                    :return: an [N x (H * C) x T] tensor after attention.
         | 
| 267 | 
            +
                    """
         | 
| 268 | 
            +
                    bs, width, length = qkv.shape
         | 
| 269 | 
            +
                    assert width % (3 * self.n_heads) == 0
         | 
| 270 | 
            +
                    ch = width // (3 * self.n_heads)
         | 
| 271 | 
            +
                    q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
         | 
| 272 | 
            +
                    if encoder_kv is not None:
         | 
| 273 | 
            +
                        assert encoder_kv.shape[1] == self.n_heads * ch * 2
         | 
| 274 | 
            +
                        ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
         | 
| 275 | 
            +
                        k = th.cat([ek, k], dim=-1)
         | 
| 276 | 
            +
                        v = th.cat([ev, v], dim=-1)
         | 
| 277 | 
            +
                    scale = 1 / math.sqrt(math.sqrt(ch))
         | 
| 278 | 
            +
                    weight = th.einsum(
         | 
| 279 | 
            +
                        "bct,bcs->bts", q * scale, k * scale
         | 
| 280 | 
            +
                    )  # More stable with f16 than dividing afterwards
         | 
| 281 | 
            +
                    weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
         | 
| 282 | 
            +
                    a = th.einsum("bts,bcs->bct", weight, v)
         | 
| 283 | 
            +
                    return a.reshape(bs, -1, length)
         | 
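The packed layout expected by `QKVAttention` is worth spelling out: queries, keys, and values are concatenated along the channel axis as heads * 3 * channels_per_head. A small shape check (illustrative sizes; assumes the `glide_text2im` package is importable):

import torch as th

from glide_text2im.unet import QKVAttention

n_heads, ch, length, bs = 4, 16, 10, 2
attn = QKVAttention(n_heads)
qkv = th.randn(bs, n_heads * 3 * ch, length)  # [N x (H * 3 * C) x T]
out = attn(qkv)
print(out.shape)  # torch.Size([2, 64, 10]) == [N x (H * C) x T]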
| 284 | 
            +
             | 
| 285 | 
            +
             | 
| 286 | 
            +
            class UNetModel(nn.Module):
         | 
| 287 | 
            +
                """
         | 
| 288 | 
            +
                The full UNet model with attention and timestep embedding.
         | 
| 289 | 
            +
             | 
| 290 | 
            +
                :param in_channels: channels in the input Tensor.
         | 
| 291 | 
            +
                :param model_channels: base channel count for the model.
         | 
| 292 | 
            +
                :param out_channels: channels in the output Tensor.
         | 
| 293 | 
            +
                :param num_res_blocks: number of residual blocks per downsample.
         | 
| 294 | 
            +
                :param attention_resolutions: a collection of downsample rates at which
         | 
| 295 | 
            +
                    attention will take place. May be a set, list, or tuple.
         | 
| 296 | 
            +
                    For example, if this contains 4, then at 4x downsampling, attention
         | 
| 297 | 
            +
                    will be used.
         | 
| 298 | 
            +
                :param dropout: the dropout probability.
         | 
| 299 | 
            +
                :param channel_mult: channel multiplier for each level of the UNet.
         | 
| 300 | 
            +
                :param conv_resample: if True, use learned convolutions for upsampling and
         | 
| 301 | 
            +
                    downsampling.
         | 
| 302 | 
            +
                :param dims: determines if the signal is 1D, 2D, or 3D.
         | 
| 303 | 
            +
                :param num_classes: if specified (as an int), then this model will be
         | 
| 304 | 
            +
                    class-conditional with `num_classes` classes.
         | 
| 305 | 
            +
                :param use_checkpoint: use gradient checkpointing to reduce memory usage.
         | 
| 306 | 
            +
                :param num_heads: the number of attention heads in each attention layer.
         | 
| 307 | 
            +
    :param num_head_channels: if specified, ignore num_heads and instead use
         | 
| 308 | 
            +
                                           a fixed channel width per attention head.
         | 
| 309 | 
            +
                :param num_heads_upsample: works with num_heads to set a different number
         | 
| 310 | 
            +
                                           of heads for upsampling. Deprecated.
         | 
| 311 | 
            +
                :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
         | 
| 312 | 
            +
                :param resblock_updown: use residual blocks for up/downsampling.
         | 
| 313 | 
            +
                """
         | 
| 314 | 
            +
             | 
| 315 | 
            +
                def __init__(
         | 
| 316 | 
            +
                    self,
         | 
| 317 | 
            +
                    in_channels,
         | 
| 318 | 
            +
                    model_channels,
         | 
| 319 | 
            +
                    out_channels,
         | 
| 320 | 
            +
                    num_res_blocks,
         | 
| 321 | 
            +
                    attention_resolutions,
         | 
| 322 | 
            +
                    dropout=0,
         | 
| 323 | 
            +
                    channel_mult=(1, 2, 4, 8),
         | 
| 324 | 
            +
                    conv_resample=True,
         | 
| 325 | 
            +
                    dims=2,
         | 
| 326 | 
            +
                    num_classes=None,
         | 
| 327 | 
            +
                    use_checkpoint=False,
         | 
| 328 | 
            +
                    use_fp16=False,
         | 
| 329 | 
            +
                    num_heads=1,
         | 
| 330 | 
            +
                    num_head_channels=-1,
         | 
| 331 | 
            +
                    num_heads_upsample=-1,
         | 
| 332 | 
            +
                    use_scale_shift_norm=False,
         | 
| 333 | 
            +
                    resblock_updown=False,
         | 
| 334 | 
            +
                    encoder_channels=None,
         | 
| 335 | 
            +
                ):
         | 
| 336 | 
            +
                    super().__init__()
         | 
| 337 | 
            +
             | 
| 338 | 
            +
                    if num_heads_upsample == -1:
         | 
| 339 | 
            +
                        num_heads_upsample = num_heads
         | 
| 340 | 
            +
             | 
| 341 | 
            +
                    self.in_channels = in_channels
         | 
| 342 | 
            +
                    self.model_channels = model_channels
         | 
| 343 | 
            +
                    self.out_channels = out_channels
         | 
| 344 | 
            +
                    self.num_res_blocks = num_res_blocks
         | 
| 345 | 
            +
                    self.attention_resolutions = attention_resolutions
         | 
| 346 | 
            +
                    self.dropout = dropout
         | 
| 347 | 
            +
                    self.channel_mult = channel_mult
         | 
| 348 | 
            +
                    self.conv_resample = conv_resample
         | 
| 349 | 
            +
                    self.num_classes = num_classes
         | 
| 350 | 
            +
                    self.use_checkpoint = use_checkpoint
         | 
| 351 | 
            +
                    self.dtype = th.float16 if use_fp16 else th.float32
         | 
| 352 | 
            +
                    self.num_heads = num_heads
         | 
| 353 | 
            +
                    self.num_head_channels = num_head_channels
         | 
| 354 | 
            +
                    self.num_heads_upsample = num_heads_upsample
         | 
| 355 | 
            +
             | 
| 356 | 
            +
                    time_embed_dim = model_channels * 4
         | 
| 357 | 
            +
                    self.time_embed = nn.Sequential(
         | 
| 358 | 
            +
                        linear(model_channels, time_embed_dim),
         | 
| 359 | 
            +
                        nn.SiLU(),
         | 
| 360 | 
            +
                        linear(time_embed_dim, time_embed_dim),
         | 
| 361 | 
            +
                    )
         | 
| 362 | 
            +
             | 
| 363 | 
            +
                    if self.num_classes is not None:
         | 
| 364 | 
            +
                        self.label_emb = nn.Embedding(num_classes, time_embed_dim)
         | 
| 365 | 
            +
             | 
| 366 | 
            +
                    ch = input_ch = int(channel_mult[0] * model_channels)
         | 
| 367 | 
            +
                    self.input_blocks = nn.ModuleList(
         | 
| 368 | 
            +
                        [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
         | 
| 369 | 
            +
                    )
         | 
| 370 | 
            +
                    self._feature_size = ch
         | 
| 371 | 
            +
                    input_block_chans = [ch]
         | 
| 372 | 
            +
                    ds = 1
         | 
| 373 | 
            +
                    for level, mult in enumerate(channel_mult):
         | 
| 374 | 
            +
                        for _ in range(num_res_blocks):
         | 
| 375 | 
            +
                            layers = [
         | 
| 376 | 
            +
                                ResBlock(
         | 
| 377 | 
            +
                                    ch,
         | 
| 378 | 
            +
                                    time_embed_dim,
         | 
| 379 | 
            +
                                    dropout,
         | 
| 380 | 
            +
                                    out_channels=int(mult * model_channels),
         | 
| 381 | 
            +
                                    dims=dims,
         | 
| 382 | 
            +
                                    use_checkpoint=use_checkpoint,
         | 
| 383 | 
            +
                                    use_scale_shift_norm=use_scale_shift_norm,
         | 
| 384 | 
            +
                                )
         | 
| 385 | 
            +
                            ]
         | 
| 386 | 
            +
                            ch = int(mult * model_channels)
         | 
| 387 | 
            +
                            if ds in attention_resolutions:
         | 
| 388 | 
            +
                                layers.append(
         | 
| 389 | 
            +
                                    AttentionBlock(
         | 
| 390 | 
            +
                                        ch,
         | 
| 391 | 
            +
                                        use_checkpoint=use_checkpoint,
         | 
| 392 | 
            +
                                        num_heads=num_heads,
         | 
| 393 | 
            +
                                        num_head_channels=num_head_channels,
         | 
| 394 | 
            +
                                        encoder_channels=encoder_channels,
         | 
| 395 | 
            +
                                    )
         | 
| 396 | 
            +
                                )
         | 
| 397 | 
            +
                            self.input_blocks.append(TimestepEmbedSequential(*layers))
         | 
| 398 | 
            +
                            self._feature_size += ch
         | 
| 399 | 
            +
                            input_block_chans.append(ch)
         | 
| 400 | 
            +
                        if level != len(channel_mult) - 1:
         | 
| 401 | 
            +
                            out_ch = ch
         | 
| 402 | 
            +
                            self.input_blocks.append(
         | 
| 403 | 
            +
                                TimestepEmbedSequential(
         | 
| 404 | 
            +
                                    ResBlock(
         | 
| 405 | 
            +
                                        ch,
         | 
| 406 | 
            +
                                        time_embed_dim,
         | 
| 407 | 
            +
                                        dropout,
         | 
| 408 | 
            +
                                        out_channels=out_ch,
         | 
| 409 | 
            +
                                        dims=dims,
         | 
| 410 | 
            +
                                        use_checkpoint=use_checkpoint,
         | 
| 411 | 
            +
                                        use_scale_shift_norm=use_scale_shift_norm,
         | 
| 412 | 
            +
                                        down=True,
         | 
| 413 | 
            +
                                    )
         | 
| 414 | 
            +
                                    if resblock_updown
         | 
| 415 | 
            +
                                    else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
         | 
| 416 | 
            +
                                )
         | 
| 417 | 
            +
                            )
         | 
| 418 | 
            +
                            ch = out_ch
         | 
| 419 | 
            +
                            input_block_chans.append(ch)
         | 
| 420 | 
            +
                            ds *= 2
         | 
| 421 | 
            +
                            self._feature_size += ch
         | 
| 422 | 
            +
             | 
| 423 | 
            +
                    self.middle_block = TimestepEmbedSequential(
         | 
| 424 | 
            +
                        ResBlock(
         | 
| 425 | 
            +
                            ch,
         | 
| 426 | 
            +
                            time_embed_dim,
         | 
| 427 | 
            +
                            dropout,
         | 
| 428 | 
            +
                            dims=dims,
         | 
| 429 | 
            +
                            use_checkpoint=use_checkpoint,
         | 
| 430 | 
            +
                            use_scale_shift_norm=use_scale_shift_norm,
         | 
| 431 | 
            +
                        ),
         | 
| 432 | 
            +
                        AttentionBlock(
         | 
| 433 | 
            +
                            ch,
         | 
| 434 | 
            +
                            use_checkpoint=use_checkpoint,
         | 
| 435 | 
            +
                            num_heads=num_heads,
         | 
| 436 | 
            +
                            num_head_channels=num_head_channels,
         | 
| 437 | 
            +
                            encoder_channels=encoder_channels,
         | 
| 438 | 
            +
                        ),
         | 
| 439 | 
            +
                        ResBlock(
         | 
| 440 | 
            +
                            ch,
         | 
| 441 | 
            +
                            time_embed_dim,
         | 
| 442 | 
            +
                            dropout,
         | 
| 443 | 
            +
                            dims=dims,
         | 
| 444 | 
            +
                            use_checkpoint=use_checkpoint,
         | 
| 445 | 
            +
                            use_scale_shift_norm=use_scale_shift_norm,
         | 
| 446 | 
            +
                        ),
         | 
| 447 | 
            +
                    )
         | 
| 448 | 
            +
                    self._feature_size += ch
         | 
| 449 | 
            +
             | 
| 450 | 
            +
                    self.output_blocks = nn.ModuleList([])
         | 
| 451 | 
            +
                    for level, mult in list(enumerate(channel_mult))[::-1]:
         | 
| 452 | 
            +
                        for i in range(num_res_blocks + 1):
         | 
| 453 | 
            +
                            ich = input_block_chans.pop()
         | 
| 454 | 
            +
                            layers = [
         | 
| 455 | 
            +
                                ResBlock(
         | 
| 456 | 
            +
                                    ch + ich,
         | 
| 457 | 
            +
                                    time_embed_dim,
         | 
| 458 | 
            +
                                    dropout,
         | 
| 459 | 
            +
                                    out_channels=int(model_channels * mult),
         | 
| 460 | 
            +
                                    dims=dims,
         | 
| 461 | 
            +
                                    use_checkpoint=use_checkpoint,
         | 
| 462 | 
            +
                                    use_scale_shift_norm=use_scale_shift_norm,
         | 
| 463 | 
            +
                                )
         | 
| 464 | 
            +
                            ]
         | 
| 465 | 
            +
                            ch = int(model_channels * mult)
         | 
| 466 | 
            +
                            if ds in attention_resolutions:
         | 
| 467 | 
            +
                                layers.append(
         | 
| 468 | 
            +
                                    AttentionBlock(
         | 
| 469 | 
            +
                                        ch,
         | 
| 470 | 
            +
                                        use_checkpoint=use_checkpoint,
         | 
| 471 | 
            +
                                        num_heads=num_heads_upsample,
         | 
| 472 | 
            +
                                        num_head_channels=num_head_channels,
         | 
| 473 | 
            +
                                        encoder_channels=encoder_channels,
         | 
| 474 | 
            +
                                    )
         | 
| 475 | 
            +
                                )
         | 
| 476 | 
            +
                            if level and i == num_res_blocks:
         | 
| 477 | 
            +
                                out_ch = ch
         | 
| 478 | 
            +
                                layers.append(
         | 
| 479 | 
            +
                                    ResBlock(
         | 
| 480 | 
            +
                                        ch,
         | 
| 481 | 
            +
                                        time_embed_dim,
         | 
| 482 | 
            +
                                        dropout,
         | 
| 483 | 
            +
                                        out_channels=out_ch,
         | 
| 484 | 
            +
                                        dims=dims,
         | 
| 485 | 
            +
                                        use_checkpoint=use_checkpoint,
         | 
| 486 | 
            +
                                        use_scale_shift_norm=use_scale_shift_norm,
         | 
| 487 | 
            +
                                        up=True,
         | 
| 488 | 
            +
                                    )
         | 
| 489 | 
            +
                                    if resblock_updown
         | 
| 490 | 
            +
                                    else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
         | 
| 491 | 
            +
                                )
         | 
| 492 | 
            +
                                ds //= 2
         | 
| 493 | 
            +
                            self.output_blocks.append(TimestepEmbedSequential(*layers))
         | 
| 494 | 
            +
                            self._feature_size += ch
         | 
| 495 | 
            +
             | 
| 496 | 
            +
                    self.out = nn.Sequential(
         | 
| 497 | 
            +
                        normalization(ch, swish=1.0),
         | 
| 498 | 
            +
                        nn.Identity(),
         | 
| 499 | 
            +
                        zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)),
         | 
| 500 | 
            +
                    )
         | 
| 501 | 
            +
                    self.use_fp16 = use_fp16
         | 
| 502 | 
            +
             | 
| 503 | 
            +
                def convert_to_fp16(self):
         | 
| 504 | 
            +
                    """
         | 
| 505 | 
            +
                    Convert the torso of the model to float16.
         | 
| 506 | 
            +
                    """
         | 
| 507 | 
            +
                    self.input_blocks.apply(convert_module_to_f16)
         | 
| 508 | 
            +
                    self.middle_block.apply(convert_module_to_f16)
         | 
| 509 | 
            +
                    self.output_blocks.apply(convert_module_to_f16)
         | 
| 510 | 
            +
             | 
| 511 | 
            +
                def convert_to_fp32(self):
         | 
| 512 | 
            +
                    """
         | 
| 513 | 
            +
                    Convert the torso of the model to float32.
         | 
| 514 | 
            +
                    """
         | 
| 515 | 
            +
                    self.input_blocks.apply(convert_module_to_f32)
         | 
| 516 | 
            +
                    self.middle_block.apply(convert_module_to_f32)
         | 
| 517 | 
            +
                    self.output_blocks.apply(convert_module_to_f32)
         | 
| 518 | 
            +
             | 
| 519 | 
            +
                def forward(self, x, timesteps, y=None):
         | 
| 520 | 
            +
                    """
         | 
| 521 | 
            +
                    Apply the model to an input batch.
         | 
| 522 | 
            +
             | 
| 523 | 
            +
                    :param x: an [N x C x ...] Tensor of inputs.
         | 
| 524 | 
            +
                    :param timesteps: a 1-D batch of timesteps.
         | 
| 525 | 
            +
                    :param y: an [N] Tensor of labels, if class-conditional.
         | 
| 526 | 
            +
                    :return: an [N x C x ...] Tensor of outputs.
         | 
| 527 | 
            +
                    """
         | 
| 528 | 
            +
                    assert (y is not None) == (
         | 
| 529 | 
            +
                        self.num_classes is not None
         | 
| 530 | 
            +
                    ), "must specify y if and only if the model is class-conditional"
         | 
| 531 | 
            +
             | 
| 532 | 
            +
                    hs = []
         | 
| 533 | 
            +
                    emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
         | 
| 534 | 
            +
             | 
| 535 | 
            +
                    if self.num_classes is not None:
         | 
| 536 | 
            +
                        assert y.shape == (x.shape[0],)
         | 
| 537 | 
            +
                        emb = emb + self.label_emb(y)
         | 
| 538 | 
            +
             | 
| 539 | 
            +
                    h = x.type(self.dtype)
         | 
| 540 | 
            +
                    for module in self.input_blocks:
         | 
| 541 | 
            +
                        h = module(h, emb)
         | 
| 542 | 
            +
                        hs.append(h)
         | 
| 543 | 
            +
                    h = self.middle_block(h, emb)
         | 
| 544 | 
            +
                    for module in self.output_blocks:
         | 
| 545 | 
            +
                        h = th.cat([h, hs.pop()], dim=1)
         | 
| 546 | 
            +
                        h = module(h, emb)
         | 
| 547 | 
            +
                    h = h.type(x.dtype)
         | 
| 548 | 
            +
                    return self.out(h)
         | 
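To see the whole model end to end, here is a minimal sketch that builds a deliberately tiny `UNetModel` and runs one forward pass; every hyperparameter below is an illustrative assumption rather than the released GLIDE configuration, and the sketch assumes the `glide_text2im` package is importable.

import torch as th

from glide_text2im.unet import UNetModel

model = UNetModel(
    in_channels=3,
    model_channels=32,
    out_channels=6,              # e.g. 3 channels for the mean and 3 for the variance
    num_res_blocks=1,
    attention_resolutions=(4,),  # attend once the input has been downsampled 4x
    channel_mult=(1, 2, 4),
    num_head_channels=32,
)
x = th.randn(2, 3, 64, 64)     # noisy images
t = th.randint(0, 1000, (2,))  # diffusion timesteps
out = model(x, t)
print(out.shape)  # torch.Size([2, 6, 64, 64])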
| 549 | 
            +
             | 
| 550 | 
            +
            class SuperResUNetModel(UNetModel):
         | 
| 551 | 
            +
                """
         | 
| 552 | 
            +
                A UNetModel that performs super-resolution.
         | 
| 553 | 
            +
             | 
| 554 | 
            +
                Expects an extra kwarg `low_res` to condition on a low-resolution image.
         | 
| 555 | 
            +
                """
         | 
| 556 | 
            +
             | 
| 557 | 
            +
                def __init__(self, *args, **kwargs):
         | 
| 558 | 
            +
                    if "in_channels" in kwargs:
         | 
| 559 | 
            +
                        kwargs = dict(kwargs)
         | 
| 560 | 
            +
                        kwargs["in_channels"] = kwargs["in_channels"] * 2
         | 
| 561 | 
            +
                    else:
         | 
| 562 | 
            +
                        # Curse you, Python. Or really, just curse positional arguments :|.
         | 
| 563 | 
            +
                        args = list(args)
         | 
| 564 | 
            +
                        args[1] = args[1] * 2
         | 
| 565 | 
            +
                    super().__init__(*args, **kwargs)
         | 
| 566 | 
            +
             | 
| 567 | 
            +
                def forward(self, x, timesteps, low_res=None, **kwargs):
         | 
| 568 | 
            +
                    _, _, new_height, new_width = x.shape
         | 
| 569 | 
            +
                    upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
         | 
| 570 | 
            +
                    x = th.cat([x, upsampled], dim=1)
         | 
| 571 | 
            +
                    return super().forward(x, timesteps, **kwargs)
         | 
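The constructor above silently doubles `in_channels` because the bilinearly upsampled low-resolution image is concatenated onto the input. A minimal sketch (illustrative configuration and sizes, assuming the `glide_text2im` package is importable):

import torch as th

from glide_text2im.unet import SuperResUNetModel

model = SuperResUNetModel(
    in_channels=3,              # becomes 6 internally after concatenation
    model_channels=32,
    out_channels=6,
    num_res_blocks=1,
    attention_resolutions=(),
    channel_mult=(1, 2),
)
x = th.randn(1, 3, 64, 64)        # noisy sample at the target resolution
low_res = th.randn(1, 3, 32, 32)  # conditioning image, upsampled inside forward()
t = th.randint(0, 1000, (1,))
print(model(x, t, low_res=low_res).shape)  # torch.Size([1, 6, 64, 64])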
| 572 | 
            +
             | 
| 573 | 
            +
                
         | 
| 574 | 
            +
            class InpaintUNetModel(UNetModel):
         | 
| 575 | 
            +
                """
         | 
| 576 | 
            +
                A UNetModel which can perform inpainting.
         | 
| 577 | 
            +
                """
         | 
| 578 | 
            +
             | 
| 579 | 
            +
                def __init__(self, *args, **kwargs):
         | 
| 580 | 
            +
                    if "in_channels" in kwargs:
         | 
| 581 | 
            +
                        kwargs = dict(kwargs)
         | 
| 582 | 
            +
                        kwargs["in_channels"] = kwargs["in_channels"] * 2 + 1
         | 
| 583 | 
            +
                    else:
         | 
| 584 | 
            +
                        # Curse you, Python. Or really, just curse positional arguments :|.
         | 
| 585 | 
            +
                        args = list(args)
         | 
| 586 | 
            +
                        args[1] = args[1] * 2 + 1
         | 
| 587 | 
            +
                    super().__init__(*args, **kwargs)
         | 
| 588 | 
            +
             | 
| 589 | 
            +
                def forward(self, x, timesteps, inpaint_image=None, inpaint_mask=None, **kwargs):
         | 
| 590 | 
            +
                    if inpaint_image is None:
         | 
| 591 | 
            +
                        inpaint_image = th.zeros_like(x)
         | 
| 592 | 
            +
                    if inpaint_mask is None:
         | 
| 593 | 
            +
                        inpaint_mask = th.zeros_like(x[:, :1])
         | 
| 594 | 
            +
                    return super().forward(
         | 
| 595 | 
            +
                        th.cat([x, inpaint_image * inpaint_mask, inpaint_mask], dim=1),
         | 
| 596 | 
            +
                        timesteps,
         | 
| 597 | 
            +
                        **kwargs,
         | 
| 598 | 
            +
                    )
         | 
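Here the input grows to 2 * C + 1 channels: the noisy image, the inpaint image with the mask applied, and the single-channel mask itself. A minimal sketch (illustrative configuration; the mask convention in the comment is an assumption made for this example only):

import torch as th

from glide_text2im.unet import InpaintUNetModel

model = InpaintUNetModel(
    in_channels=3,              # becomes 3 * 2 + 1 = 7 internally
    model_channels=32,
    out_channels=6,
    num_res_blocks=1,
    attention_resolutions=(),
    channel_mult=(1, 2),
)
x = th.randn(1, 3, 64, 64)
inpaint_image = th.randn(1, 3, 64, 64)
inpaint_mask = th.ones(1, 1, 64, 64)  # binary mask over known pixels (assumed convention)
t = th.randint(0, 1000, (1,))
out = model(x, t, inpaint_image=inpaint_image, inpaint_mask=inpaint_mask)
print(out.shape)  # torch.Size([1, 6, 64, 64])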
| 599 | 
            +
             | 
| 600 | 
            +
             | 
| 601 | 
            +
            class SuperResInpaintUNetModel(UNetModel):
         | 
| 602 | 
            +
                """
         | 
| 603 | 
            +
                A UNetModel which can perform both upsampling and inpainting.
         | 
| 604 | 
            +
                """
         | 
| 605 | 
            +
             | 
| 606 | 
            +
                def __init__(self, *args, **kwargs):
         | 
| 607 | 
            +
                    if "in_channels" in kwargs:
         | 
| 608 | 
            +
                        kwargs = dict(kwargs)
         | 
| 609 | 
            +
                        kwargs["in_channels"] = kwargs["in_channels"] * 3 + 1
         | 
| 610 | 
            +
                    else:
         | 
| 611 | 
            +
                        # Curse you, Python. Or really, just curse positional arguments :|.
         | 
| 612 | 
            +
                        args = list(args)
         | 
| 613 | 
            +
                        args[1] = args[1] * 3 + 1
         | 
| 614 | 
            +
                    super().__init__(*args, **kwargs)
         | 
| 615 | 
            +
             | 
| 616 | 
            +
                def forward(
         | 
| 617 | 
            +
                    self,
         | 
| 618 | 
            +
                    x,
         | 
| 619 | 
            +
                    timesteps,
         | 
| 620 | 
            +
                    inpaint_image=None,
         | 
| 621 | 
            +
                    inpaint_mask=None,
         | 
| 622 | 
            +
                    low_res=None,
         | 
| 623 | 
            +
                    **kwargs,
         | 
| 624 | 
            +
                ):
         | 
| 625 | 
            +
                    if inpaint_image is None:
         | 
| 626 | 
            +
                        inpaint_image = th.zeros_like(x)
         | 
| 627 | 
            +
                    if inpaint_mask is None:
         | 
| 628 | 
            +
                        inpaint_mask = th.zeros_like(x[:, :1])
         | 
| 629 | 
            +
                    _, _, new_height, new_width = x.shape
         | 
| 630 | 
            +
                    upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
         | 
| 631 | 
            +
                    return super().forward(
         | 
| 632 | 
            +
                        th.cat([x, inpaint_image * inpaint_mask, inpaint_mask, upsampled], dim=1),
         | 
| 633 | 
            +
                        timesteps,
         | 
| 634 | 
            +
                        **kwargs,
         | 
| 635 | 
            +
                    )
         | 
    	
        glide_text2im/xf.py
    ADDED
    
    | @@ -0,0 +1,130 @@ | |
| 1 | 
            +
            """
         | 
| 2 | 
            +
            Transformer implementation adapted from CLIP ViT:
         | 
| 3 | 
            +
            https://github.com/openai/CLIP/blob/4c0275784d6d9da97ca1f47eaaee31de1867da91/clip/model.py
         | 
| 4 | 
            +
            """
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import math
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            import torch as th
         | 
| 9 | 
            +
            import torch.nn as nn
         | 
| 10 | 
            +
             | 
| 11 | 
            +
             | 
| 12 | 
            +
            def convert_module_to_f16(l):
         | 
| 13 | 
            +
                """
         | 
| 14 | 
            +
                Convert primitive modules to float16.
         | 
| 15 | 
            +
                """
         | 
| 16 | 
            +
                if isinstance(l, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
         | 
| 17 | 
            +
                    l.weight.data = l.weight.data.half()
         | 
| 18 | 
            +
                    if l.bias is not None:
         | 
| 19 | 
            +
                        l.bias.data = l.bias.data.half()
         | 
| 20 | 
            +
             | 
| 21 | 
            +
             | 
| 22 | 
            +
            class LayerNorm(nn.LayerNorm):
         | 
| 23 | 
            +
                """
         | 
| 24 | 
            +
                Implementation that supports fp16 inputs but fp32 gains/biases.
         | 
| 25 | 
            +
                """
         | 
| 26 | 
            +
             | 
| 27 | 
            +
                def forward(self, x: th.Tensor):
         | 
| 28 | 
            +
                    return super().forward(x.float()).to(x.dtype)
         | 
| 29 | 
            +
             | 
| 30 | 
            +
             | 
| 31 | 
            +
            class MultiheadAttention(nn.Module):
         | 
| 32 | 
            +
                def __init__(self, n_ctx, width, heads):
         | 
| 33 | 
            +
                    super().__init__()
         | 
| 34 | 
            +
                    self.n_ctx = n_ctx
         | 
| 35 | 
            +
                    self.width = width
         | 
| 36 | 
            +
                    self.heads = heads
         | 
| 37 | 
            +
                    self.c_qkv = nn.Linear(width, width * 3)
         | 
| 38 | 
            +
                    self.c_proj = nn.Linear(width, width)
         | 
| 39 | 
            +
                    self.attention = QKVMultiheadAttention(heads, n_ctx)
         | 
| 40 | 
            +
             | 
| 41 | 
            +
                def forward(self, x):
         | 
| 42 | 
            +
                    x = self.c_qkv(x)
         | 
| 43 | 
            +
                    x = self.attention(x)
         | 
| 44 | 
            +
                    x = self.c_proj(x)
         | 
| 45 | 
            +
                    return x
         | 
| 46 | 
            +
             | 
| 47 | 
            +
             | 
| 48 | 
            +
            class MLP(nn.Module):
         | 
| 49 | 
            +
                def __init__(self, width):
         | 
| 50 | 
            +
                    super().__init__()
         | 
| 51 | 
            +
                    self.width = width
         | 
| 52 | 
            +
                    self.c_fc = nn.Linear(width, width * 4)
         | 
| 53 | 
            +
                    self.c_proj = nn.Linear(width * 4, width)
         | 
| 54 | 
            +
                    self.gelu = nn.GELU()
         | 
| 55 | 
            +
             | 
| 56 | 
            +
                def forward(self, x):
         | 
| 57 | 
            +
                    return self.c_proj(self.gelu(self.c_fc(x)))
         | 
| 58 | 
            +
             | 
| 59 | 
            +
             | 
| 60 | 
            +
            class QKVMultiheadAttention(nn.Module):
         | 
| 61 | 
            +
                def __init__(self, n_heads: int, n_ctx: int):
         | 
| 62 | 
            +
                    super().__init__()
         | 
| 63 | 
            +
                    self.n_heads = n_heads
         | 
| 64 | 
            +
                    self.n_ctx = n_ctx
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                def forward(self, qkv):
         | 
| 67 | 
            +
                    bs, n_ctx, width = qkv.shape
         | 
| 68 | 
            +
                    attn_ch = width // self.n_heads // 3
         | 
| 69 | 
            +
                    scale = 1 / math.sqrt(math.sqrt(attn_ch))
         | 
| 70 | 
            +
                    qkv = qkv.view(bs, n_ctx, self.n_heads, -1)
         | 
| 71 | 
            +
                    q, k, v = th.split(qkv, attn_ch, dim=-1)
         | 
| 72 | 
            +
                    weight = th.einsum(
         | 
| 73 | 
            +
                        "bthc,bshc->bhts", q * scale, k * scale
         | 
| 74 | 
            +
                    )  # More stable with f16 than dividing afterwards
         | 
| 75 | 
            +
                    wdtype = weight.dtype
         | 
| 76 | 
            +
                    weight = th.softmax(weight.float(), dim=-1).type(wdtype)
         | 
| 77 | 
            +
                    return th.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
         | 
| 78 | 
            +
             | 
| 79 | 
            +
             | 
| 80 | 
            +
            class ResidualAttentionBlock(nn.Module):
         | 
| 81 | 
            +
                def __init__(
         | 
| 82 | 
            +
                    self,
         | 
| 83 | 
            +
                    n_ctx: int,
         | 
| 84 | 
            +
                    width: int,
         | 
| 85 | 
            +
                    heads: int,
         | 
| 86 | 
            +
                ):
         | 
| 87 | 
            +
                    super().__init__()
         | 
| 88 | 
            +
             | 
| 89 | 
            +
                    self.attn = MultiheadAttention(
         | 
| 90 | 
            +
                        n_ctx,
         | 
| 91 | 
            +
                        width,
         | 
| 92 | 
            +
                        heads,
         | 
| 93 | 
            +
                    )
         | 
| 94 | 
            +
                    self.ln_1 = LayerNorm(width)
         | 
| 95 | 
            +
                    self.mlp = MLP(width)
         | 
| 96 | 
            +
                    self.ln_2 = LayerNorm(width)
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                def forward(self, x: th.Tensor):
         | 
| 99 | 
            +
                    x = x + self.attn(self.ln_1(x))
         | 
| 100 | 
            +
                    x = x + self.mlp(self.ln_2(x))
         | 
| 101 | 
            +
                    return x
         | 
| 102 | 
            +
             | 
| 103 | 
            +
             | 
| 104 | 
            +
            class Transformer(nn.Module):
         | 
| 105 | 
            +
                def __init__(
         | 
| 106 | 
            +
                    self,
         | 
| 107 | 
            +
                    n_ctx: int,
         | 
| 108 | 
            +
                    width: int,
         | 
| 109 | 
            +
                    layers: int,
         | 
| 110 | 
            +
                    heads: int,
         | 
| 111 | 
            +
                ):
         | 
| 112 | 
            +
                    super().__init__()
         | 
| 113 | 
            +
                    self.n_ctx = n_ctx
         | 
| 114 | 
            +
                    self.width = width
         | 
| 115 | 
            +
                    self.layers = layers
         | 
| 116 | 
            +
                    self.resblocks = nn.ModuleList(
         | 
| 117 | 
            +
                        [
         | 
| 118 | 
            +
                            ResidualAttentionBlock(
         | 
| 119 | 
            +
                                n_ctx,
         | 
| 120 | 
            +
                                width,
         | 
| 121 | 
            +
                                heads,
         | 
| 122 | 
            +
                            )
         | 
| 123 | 
            +
                            for _ in range(layers)
         | 
| 124 | 
            +
                        ]
         | 
| 125 | 
            +
                    )
         | 
| 126 | 
            +
             | 
| 127 | 
            +
                def forward(self, x: th.Tensor):
         | 
| 128 | 
            +
                    for block in self.resblocks:
         | 
| 129 | 
            +
                        x = block(x)
         | 
| 130 | 
            +
                    return x
         | 
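The transformer above is a plain pre-norm residual stack over a fixed-width token sequence. A minimal shape check (illustrative sizes; assumes the `glide_text2im` package is importable):

import torch as th

from glide_text2im.xf import Transformer

xf = Transformer(n_ctx=16, width=64, layers=2, heads=4)
tokens = th.randn(2, 16, 64)  # [batch, tokens, width]
out = xf(tokens)
print(out.shape)              # torch.Size([2, 16, 64])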
    	
        model-card.md
    ADDED
    
    | @@ -0,0 +1,50 @@ | |
| 1 | 
            +
            # Overview
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            This card describes the diffusion model GLIDE (filtered) and noised CLIP model described in the paper [GLIDE: Towards
         | 
| 4 | 
            +
            Photorealistic Image Generation and Editing with Text-Guided Diffusion Models](https://arxiv.org/abs/2112.10741)
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            # Datasets
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            GLIDE (filtered) was trained on a filtered version of a dataset comprised of several hundred million text-image pairs
         | 
| 9 | 
            +
            collected from the internet. We constructed a set of filters intended to remove all images of people, violent objects, and some
         | 
| 10 | 
            +
            hate symbols (see Appendix F of the paper for details). The size of the dataset after filtering was approximately
         | 
| 11 | 
            +
            67M text-image pairs.
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            Our noised CLIP model was trained on the dataset described above, augmented with a filtered version of the dataset used
         | 
| 14 | 
            +
            to train the [original CLIP models](https://github.com/openai/clip). The total size of this augmented dataset is approximately 137M pairs.
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            # Performance
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            Qualitatively, we find that the generated images from GLIDE (filtered) often look semi-realistic, but the small size of the model hinders
         | 
| 19 | 
            +
            its ability to bind attributes to objects and perform compositional tasks. Because the dataset used to train GLIDE
         | 
| 20 | 
            +
            (filtered) has been preprocessed to remove images of people, this also limits its world knowledge, especially in regard
         | 
| 21 | 
            +
            to concepts that involve people.
         | 
| 22 | 
            +
            Finally, due to the dataset used to train GLIDE (filtered), the model has reduced capabilities to compose multiple objects in complex ways compared to models of a similar size trained on our internal dataset.
         | 
| 23 | 
            +
             | 
| 24 | 
            +
            We do not directly measure quantitative metrics for GLIDE (filtered). In particular, most of the evaluations we report for our other models are biased against GLIDE (filtered), since they use prompts that often require generations of people. Evaluating people-free models remains an open area of research.
         | 
| 25 | 
            +
             | 
| 26 | 
            +
            # Intended Use
         | 
| 27 | 
            +
             | 
| 28 | 
            +
            We release these models to help advance research in generative modeling. Due to the limitations and biases of GLIDE (filtered), we do not currently recommend it for commercial use.
         | 
| 29 | 
            +
             | 
| 30 | 
            +
            Functionally, these models are intended to be able to perform the following tasks for research purposes:
         | 
| 31 | 
            +
             * Generate images from natural language prompts
         | 
| 32 | 
            +
             * Iteratively edit and refine images using inpainting
         | 
| 33 | 
            +
             | 
| 34 | 
            +
            These models are explicitly not intended to generate images of people or other subjects we filtered for (see Appendix F of the paper for details).
         | 
| 35 | 
            +
             | 
| 36 | 
            +
            # Limitations
         | 
| 37 | 
            +
             | 
| 38 | 
            +
            Despite the dataset filtering applied before training, GLIDE (filtered) continues to exhibit biases that extend beyond those found in images of people.
         | 
| 39 | 
            +
            We explore some of these biases in our paper. For example:
         | 
| 40 | 
            +
             | 
| 41 | 
            +
              * It produces different outputs when asked to generate toys for boys and toys for girls.
         | 
| 42 | 
            +
              * It gravitates toward generating images of churches when asked to generate "a religious place",
         | 
| 43 | 
            +
                and this bias is amplified by classifier-free guidance.
         | 
| 44 | 
            +
              * It may have a greater propensity for generating hate symbols other than swastikas and confederate flags. Our filter
         | 
| 45 | 
            +
                for hate symbols focused specifically on these two cases, as we found few relevant images of hate symbols in our
         | 
| 46 | 
            +
                dataset. However, we also found that the model has diminished capabilities across a wider set of symbols.
         | 
| 47 | 
            +
             | 
| 48 | 
            +
            GLIDE (filtered) can fail to produce realistic outputs for complex prompts or for prompts that involve concepts that are
         | 
| 49 | 
            +
            not well-represented in its training data. While the data for the model was filtered to remove certain types of images,
         | 
| 50 | 
            +
            the data still exhibits biases toward Western-centric concepts.
         | 
    	
        notebooks/clip_guided.ipynb
    ADDED
    
    | @@ -0,0 +1,246 @@ | |
# %%
# Run this line in Colab to install the package if it is
# not already installed.
!pip install git+https://github.com/openai/glide-text2im

# %%
from PIL import Image
from IPython.display import display
import torch as th
import torch.nn as nn

from glide_text2im.clip.model_creation import create_clip_model
from glide_text2im.download import load_checkpoint
from glide_text2im.model_creation import (
    create_model_and_diffusion,
    model_and_diffusion_defaults,
    model_and_diffusion_defaults_upsampler,
)
from glide_text2im.tokenizer.simple_tokenizer import SimpleTokenizer

# %%
# This notebook supports both CPU and GPU.
# On CPU, generating one sample may take on the order of 20 minutes.
# On a GPU, it should be under a minute.

has_cuda = th.cuda.is_available()
device = th.device('cpu' if not has_cuda else 'cuda')

# %%
# Create base model.
options = model_and_diffusion_defaults()
options['use_fp16'] = has_cuda
options['timestep_respacing'] = '100' # use 100 diffusion steps for fast sampling
model, diffusion = create_model_and_diffusion(**options)
model.eval()
if has_cuda:
    model.convert_to_fp16()
model.to(device)
model.load_state_dict(load_checkpoint('base', device))
print('total base parameters', sum(x.numel() for x in model.parameters()))

# %%
# Create upsampler model.
options_up = model_and_diffusion_defaults_upsampler()
options_up['use_fp16'] = has_cuda
options_up['timestep_respacing'] = 'fast27' # use 27 diffusion steps for very fast sampling
model_up, diffusion_up = create_model_and_diffusion(**options_up)
model_up.eval()
if has_cuda:
    model_up.convert_to_fp16()
model_up.to(device)
model_up.load_state_dict(load_checkpoint('upsample', device))
print('total upsampler parameters', sum(x.numel() for x in model_up.parameters()))

# %%
# Create CLIP model.
clip_model = create_clip_model(device=device)
clip_model.image_encoder.load_state_dict(load_checkpoint('clip/image-enc', device))
clip_model.text_encoder.load_state_dict(load_checkpoint('clip/text-enc', device))

# %%
def show_images(batch: th.Tensor):
    """ Display a batch of images inline. """
    scaled = ((batch + 1)*127.5).round().clamp(0,255).to(th.uint8).cpu()
    reshaped = scaled.permute(2, 0, 3, 1).reshape([batch.shape[2], -1, 3])
    display(Image.fromarray(reshaped.numpy()))

# %%
# Sampling parameters
prompt = "an oil painting of a corgi"
batch_size = 1
guidance_scale = 3.0

# Tune this parameter to control the sharpness of 256x256 images.
# A value of 1.0 is sharper, but sometimes results in grainy artifacts.
upsample_temp = 0.997

# %%
##############################
# Sample from the base model #
##############################

# Create the text tokens to feed to the model.
tokens = model.tokenizer.encode(prompt)
tokens, mask = model.tokenizer.padded_tokens_and_mask(
    tokens, options['text_ctx']
)

# Pack the tokens together into model kwargs.
model_kwargs = dict(
    tokens=th.tensor([tokens] * batch_size, device=device),
    mask=th.tensor([mask] * batch_size, dtype=th.bool, device=device),
)

# Set up the guidance function for the CLIP model.
cond_fn = clip_model.cond_fn([prompt] * batch_size, guidance_scale)

# Sample from the base model.
model.del_cache()
samples = diffusion.p_sample_loop(
    model,
    (batch_size, 3, options["image_size"], options["image_size"]),
    device=device,
    clip_denoised=True,
    progress=True,
    model_kwargs=model_kwargs,
    cond_fn=cond_fn,
)
model.del_cache()

# Show the output
show_images(samples)

# %%
##############################
# Upsample the 64x64 samples #
##############################

tokens = model_up.tokenizer.encode(prompt)
tokens, mask = model_up.tokenizer.padded_tokens_and_mask(
    tokens, options_up['text_ctx']
)

# Create the model conditioning dict.
model_kwargs = dict(
    # Low-res image to upsample.
    low_res=((samples+1)*127.5).round()/127.5 - 1,

    # Text tokens
    tokens=th.tensor(
        [tokens] * batch_size, device=device
    ),
    mask=th.tensor(
        [mask] * batch_size,
        dtype=th.bool,
        device=device,
    ),
)

# Sample from the upsampler model.
model_up.del_cache()
up_shape = (batch_size, 3, options_up["image_size"], options_up["image_size"])
up_samples = diffusion_up.ddim_sample_loop(
    model_up,
    up_shape,
    noise=th.randn(up_shape, device=device) * upsample_temp,
    device=device,
    clip_denoised=True,
    progress=True,
    model_kwargs=model_kwargs,
    cond_fn=None,
)[:batch_size]
model_up.del_cache()

# Show the output
show_images(up_samples)

# Notebook metadata: Python 3 kernel, nbformat 4 (minor 2), Colab accelerator "GPU".
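The notebook above hands the sampler a `cond_fn` built by `clip_model.cond_fn(...)`. As a conceptual aid only, the sketch below shows how such a guidance function can be assembled; it is not the glide_text2im implementation, and `image_encoder` and `text_emb` are hypothetical stand-ins for a noise-aware CLIP image encoder and a precomputed caption embedding.

import torch as th

def make_clip_cond_fn(image_encoder, text_emb, guidance_scale):
    # Returns a cond_fn(x, t, **kwargs) whose output the diffusion sampler
    # uses to nudge each denoising step toward the caption.
    def cond_fn(x, t, **kwargs):
        with th.enable_grad():
            x_in = x.detach().requires_grad_(True)
            img_emb = image_encoder(x_in, t)            # embed the noisy image x_t
            score = (img_emb * text_emb).sum()          # similarity to the caption embedding
            grad = th.autograd.grad(score, x_in)[0]     # direction that increases similarity
        return grad * guidance_scale
    return cond_fn

In this sketch, a larger guidance_scale trades diversity for closer adherence to the caption, which mirrors how the `guidance_scale` parameter is used in the cells above.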
    	
        notebooks/grass.png
    ADDED
    
    	
        notebooks/inpaint.ipynb
    ADDED
    
@@ -0,0 +1,302 @@
# %%
# Run this line in Colab to install the package if it is
# not already installed.
!pip install git+https://github.com/openai/glide-text2im

# %%
from IPython.display import display
from PIL import Image
import numpy as np
import torch as th
import torch.nn.functional as F

from glide_text2im.download import load_checkpoint
from glide_text2im.model_creation import (
    create_model_and_diffusion,
    model_and_diffusion_defaults,
    model_and_diffusion_defaults_upsampler
)

# %%
# This notebook supports both CPU and GPU.
# On CPU, generating one sample may take on the order of 20 minutes.
# On a GPU, it should be under a minute.

has_cuda = th.cuda.is_available()
device = th.device('cpu' if not has_cuda else 'cuda')

# %%
# Create base model.
options = model_and_diffusion_defaults()
options['inpaint'] = True
options['use_fp16'] = has_cuda
options['timestep_respacing'] = '100' # use 100 diffusion steps for fast sampling
model, diffusion = create_model_and_diffusion(**options)
model.eval()
if has_cuda:
    model.convert_to_fp16()
model.to(device)
model.load_state_dict(load_checkpoint('base-inpaint', device))
print('total base parameters', sum(x.numel() for x in model.parameters()))

# %%
# Create upsampler model.
options_up = model_and_diffusion_defaults_upsampler()
options_up['inpaint'] = True
options_up['use_fp16'] = has_cuda
options_up['timestep_respacing'] = 'fast27' # use 27 diffusion steps for very fast sampling
model_up, diffusion_up = create_model_and_diffusion(**options_up)
model_up.eval()
if has_cuda:
    model_up.convert_to_fp16()
model_up.to(device)
model_up.load_state_dict(load_checkpoint('upsample-inpaint', device))
print('total upsampler parameters', sum(x.numel() for x in model_up.parameters()))

# %%
def show_images(batch: th.Tensor):
    """ Display a batch of images inline. """
    scaled = ((batch + 1)*127.5).round().clamp(0,255).to(th.uint8).cpu()
    reshaped = scaled.permute(2, 0, 3, 1).reshape([batch.shape[2], -1, 3])
    display(Image.fromarray(reshaped.numpy()))

def read_image(path: str, size: int = 256) -> th.Tensor:
    pil_img = Image.open(path).convert('RGB')
    pil_img = pil_img.resize((size, size), resample=Image.BICUBIC)
    img = np.array(pil_img)
    return th.from_numpy(img)[None].permute(0, 3, 1, 2).float() / 127.5 - 1

# %%
# Sampling parameters
prompt = "a corgi in a field"
batch_size = 1
guidance_scale = 5.0

# Tune this parameter to control the sharpness of 256x256 images.
# A value of 1.0 is sharper, but sometimes results in grainy artifacts.
upsample_temp = 0.997

# Source image we are inpainting
source_image_256 = read_image('grass.png', size=256)
source_image_64 = read_image('grass.png', size=64)

# The mask should always be a boolean 64x64 mask, and then we
# can upsample it for the second stage.
source_mask_64 = th.ones_like(source_image_64)[:, :1]
source_mask_64[:, :, 20:] = 0
source_mask_256 = F.interpolate(source_mask_64, (256, 256), mode='nearest')

# Visualize the image we are inpainting
show_images(source_image_256 * source_mask_256)

# %%
##############################
# Sample from the base model #
##############################

# Create the text tokens to feed to the model.
tokens = model.tokenizer.encode(prompt)
tokens, mask = model.tokenizer.padded_tokens_and_mask(
    tokens, options['text_ctx']
)

# Create the classifier-free guidance tokens (empty)
full_batch_size = batch_size * 2
uncond_tokens, uncond_mask = model.tokenizer.padded_tokens_and_mask(
    [], options['text_ctx']
)

# Pack the tokens together into model kwargs.
model_kwargs = dict(
    tokens=th.tensor(
        [tokens] * batch_size + [uncond_tokens] * batch_size, device=device
    ),
    mask=th.tensor(
        [mask] * batch_size + [uncond_mask] * batch_size,
        dtype=th.bool,
        device=device,
    ),

    # Masked inpainting image
    inpaint_image=(source_image_64 * source_mask_64).repeat(full_batch_size, 1, 1, 1).to(device),
    inpaint_mask=source_mask_64.repeat(full_batch_size, 1, 1, 1).to(device),
)

# Create a classifier-free guidance sampling function
def model_fn(x_t, ts, **kwargs):
    half = x_t[: len(x_t) // 2]
    combined = th.cat([half, half], dim=0)
    model_out = model(combined, ts, **kwargs)
    eps, rest = model_out[:, :3], model_out[:, 3:]
    cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)
    half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
    eps = th.cat([half_eps, half_eps], dim=0)
    return th.cat([eps, rest], dim=1)

def denoised_fn(x_start):
    # Force the model to have the exact right x_start predictions
    # for the part of the image which is known.
    return (
        x_start * (1 - model_kwargs['inpaint_mask'])
        + model_kwargs['inpaint_image'] * model_kwargs['inpaint_mask']
    )

# Sample from the base model.
model.del_cache()
samples = diffusion.p_sample_loop(
    model_fn,
    (full_batch_size, 3, options["image_size"], options["image_size"]),
    device=device,
    clip_denoised=True,
    progress=True,
    model_kwargs=model_kwargs,
    cond_fn=None,
    denoised_fn=denoised_fn,
)[:batch_size]
model.del_cache()

# Show the output
show_images(samples)

# %%
##############################
# Upsample the 64x64 samples #
##############################

tokens = model_up.tokenizer.encode(prompt)
tokens, mask = model_up.tokenizer.padded_tokens_and_mask(
    tokens, options_up['text_ctx']
)

# Create the model conditioning dict.
model_kwargs = dict(
    # Low-res image to upsample.
    low_res=((samples+1)*127.5).round()/127.5 - 1,

    # Text tokens
    tokens=th.tensor(
        [tokens] * batch_size, device=device
    ),
    mask=th.tensor(
        [mask] * batch_size,
        dtype=th.bool,
        device=device,
    ),

    # Masked inpainting image.
    inpaint_image=(source_image_256 * source_mask_256).repeat(batch_size, 1, 1, 1).to(device),
    inpaint_mask=source_mask_256.repeat(batch_size, 1, 1, 1).to(device),
)

def denoised_fn(x_start):
    # Force the model to have the exact right x_start predictions
    # for the part of the image which is known.
    return (
        x_start * (1 - model_kwargs['inpaint_mask'])
        + model_kwargs['inpaint_image'] * model_kwargs['inpaint_mask']
    )

# Sample from the upsampler model.
model_up.del_cache()
up_shape = (batch_size, 3, options_up["image_size"], options_up["image_size"])
up_samples = diffusion_up.p_sample_loop(
    model_up,
    up_shape,
    noise=th.randn(up_shape, device=device) * upsample_temp,
    device=device,
    clip_denoised=True,
    progress=True,
    model_kwargs=model_kwargs,
    cond_fn=None,
    denoised_fn=denoised_fn,
)[:batch_size]
model_up.del_cache()

# Show the output
show_images(up_samples)

# Notebook metadata: Python 3 kernel, nbformat 4 (minor 2), Colab accelerator "GPU".
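The inpainting notebook rests on two small ideas: classifier-free guidance over the noise prediction, and projecting the known pixels back into the model's prediction at every step. The standalone sketch below restates what `model_fn` and `denoised_fn` above compute; the helper names are local to this snippet and are not part of glide_text2im.

import torch as th

def classifier_free_guidance(cond_eps, uncond_eps, scale):
    # eps = eps_uncond + scale * (eps_cond - eps_uncond); scale = 1 recovers
    # the conditional model, larger values push samples harder toward the prompt.
    return uncond_eps + scale * (cond_eps - uncond_eps)

def project_known_pixels(x_start, known_image, known_mask):
    # Overwrite the model's x_start prediction with the reference image
    # wherever known_mask == 1, so only the masked-out region is generated.
    return x_start * (1 - known_mask) + known_image * known_mask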
    	
        notebooks/text2im.ipynb
    ADDED
    
@@ -0,0 +1,251 @@
# %%
# Run this line in Colab to install the package if it is
# not already installed.
!pip install git+https://github.com/openai/glide-text2im

# %%
from PIL import Image
from IPython.display import display
import torch as th

from glide_text2im.download import load_checkpoint
from glide_text2im.model_creation import (
    create_model_and_diffusion,
    model_and_diffusion_defaults,
    model_and_diffusion_defaults_upsampler
)

# %%
# This notebook supports both CPU and GPU.
# On CPU, generating one sample may take on the order of 20 minutes.
# On a GPU, it should be under a minute.

has_cuda = th.cuda.is_available()
device = th.device('cpu' if not has_cuda else 'cuda')

# %%
# Create base model.
options = model_and_diffusion_defaults()
options['use_fp16'] = has_cuda
options['timestep_respacing'] = '100' # use 100 diffusion steps for fast sampling
model, diffusion = create_model_and_diffusion(**options)
model.eval()
if has_cuda:
    model.convert_to_fp16()
model.to(device)
model.load_state_dict(load_checkpoint('base', device))
print('total base parameters', sum(x.numel() for x in model.parameters()))

# %%
# Create upsampler model.
options_up = model_and_diffusion_defaults_upsampler()
options_up['use_fp16'] = has_cuda
options_up['timestep_respacing'] = 'fast27' # use 27 diffusion steps for very fast sampling
model_up, diffusion_up = create_model_and_diffusion(**options_up)
model_up.eval()
if has_cuda:
    model_up.convert_to_fp16()
model_up.to(device)
model_up.load_state_dict(load_checkpoint('upsample', device))
print('total upsampler parameters', sum(x.numel() for x in model_up.parameters()))

# %%
def show_images(batch: th.Tensor):
    """ Display a batch of images inline. """
    scaled = ((batch + 1)*127.5).round().clamp(0,255).to(th.uint8).cpu()
    reshaped = scaled.permute(2, 0, 3, 1).reshape([batch.shape[2], -1, 3])
    display(Image.fromarray(reshaped.numpy()))

# %%
# Sampling parameters
prompt = "an oil painting of a corgi"
batch_size = 1
guidance_scale = 3.0

# Tune this parameter to control the sharpness of 256x256 images.
# A value of 1.0 is sharper, but sometimes results in grainy artifacts.
upsample_temp = 0.997

# %%
##############################
# Sample from the base model #
##############################

# Create the text tokens to feed to the model.
tokens = model.tokenizer.encode(prompt)
tokens, mask = model.tokenizer.padded_tokens_and_mask(
    tokens, options['text_ctx']
)

# Create the classifier-free guidance tokens (empty)
full_batch_size = batch_size * 2
uncond_tokens, uncond_mask = model.tokenizer.padded_tokens_and_mask(
    [], options['text_ctx']
)

# Pack the tokens together into model kwargs.
model_kwargs = dict(
    tokens=th.tensor(
        [tokens] * batch_size + [uncond_tokens] * batch_size, device=device
    ),
    mask=th.tensor(
        [mask] * batch_size + [uncond_mask] * batch_size,
        dtype=th.bool,
        device=device,
    ),
)

# Create a classifier-free guidance sampling function
def model_fn(x_t, ts, **kwargs):
    half = x_t[: len(x_t) // 2]
    combined = th.cat([half, half], dim=0)
         | 
| 151 | 
            +
                "    model_out = model(combined, ts, **kwargs)\n",
         | 
| 152 | 
            +
                "    eps, rest = model_out[:, :3], model_out[:, 3:]\n",
         | 
| 153 | 
            +
                "    cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)\n",
         | 
| 154 | 
            +
                "    half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)\n",
         | 
| 155 | 
            +
                "    eps = th.cat([half_eps, half_eps], dim=0)\n",
         | 
| 156 | 
            +
                "    return th.cat([eps, rest], dim=1)\n",
         | 
| 157 | 
            +
                "\n",
         | 
| 158 | 
            +
                "# Sample from the base model.\n",
         | 
| 159 | 
            +
                "model.del_cache()\n",
         | 
| 160 | 
            +
                "samples = diffusion.p_sample_loop(\n",
         | 
| 161 | 
            +
                "    model_fn,\n",
         | 
| 162 | 
            +
                "    (full_batch_size, 3, options[\"image_size\"], options[\"image_size\"]),\n",
         | 
| 163 | 
            +
                "    device=device,\n",
         | 
| 164 | 
            +
                "    clip_denoised=True,\n",
         | 
| 165 | 
            +
                "    progress=True,\n",
         | 
| 166 | 
            +
                "    model_kwargs=model_kwargs,\n",
         | 
| 167 | 
            +
                "    cond_fn=None,\n",
         | 
| 168 | 
            +
                ")[:batch_size]\n",
         | 
| 169 | 
            +
                "model.del_cache()\n",
         | 
| 170 | 
            +
                "\n",
         | 
| 171 | 
            +
                "# Show the output\n",
         | 
| 172 | 
            +
                "show_images(samples)"
         | 
| 173 | 
            +
               ]
         | 
| 174 | 
            +
              },
         | 
| 175 | 
            +
              {
         | 
| 176 | 
            +
               "cell_type": "code",
         | 
| 177 | 
            +
               "execution_count": null,
         | 
| 178 | 
            +
               "metadata": {},
         | 
| 179 | 
            +
               "outputs": [],
         | 
| 180 | 
            +
               "source": [
         | 
| 181 | 
            +
                "##############################\n",
         | 
| 182 | 
            +
                "# Upsample the 64x64 samples #\n",
         | 
| 183 | 
            +
                "##############################\n",
         | 
| 184 | 
            +
                "\n",
         | 
| 185 | 
            +
                "tokens = model_up.tokenizer.encode(prompt)\n",
         | 
| 186 | 
            +
                "tokens, mask = model_up.tokenizer.padded_tokens_and_mask(\n",
         | 
| 187 | 
            +
                "    tokens, options_up['text_ctx']\n",
         | 
| 188 | 
            +
                ")\n",
         | 
| 189 | 
            +
                "\n",
         | 
| 190 | 
            +
                "# Create the model conditioning dict.\n",
         | 
| 191 | 
            +
                "model_kwargs = dict(\n",
         | 
| 192 | 
            +
                "    # Low-res image to upsample.\n",
         | 
| 193 | 
            +
                "    low_res=((samples+1)*127.5).round()/127.5 - 1,\n",
         | 
| 194 | 
            +
                "\n",
         | 
| 195 | 
            +
                "    # Text tokens\n",
         | 
| 196 | 
            +
                "    tokens=th.tensor(\n",
         | 
| 197 | 
            +
                "        [tokens] * batch_size, device=device\n",
         | 
| 198 | 
            +
                "    ),\n",
         | 
| 199 | 
            +
                "    mask=th.tensor(\n",
         | 
| 200 | 
            +
                "        [mask] * batch_size,\n",
         | 
| 201 | 
            +
                "        dtype=th.bool,\n",
         | 
| 202 | 
            +
                "        device=device,\n",
         | 
| 203 | 
            +
                "    ),\n",
         | 
| 204 | 
            +
                ")\n",
         | 
| 205 | 
            +
                "\n",
         | 
| 206 | 
            +
                "# Sample from the base model.\n",
         | 
| 207 | 
            +
                "model_up.del_cache()\n",
         | 
| 208 | 
            +
                "up_shape = (batch_size, 3, options_up[\"image_size\"], options_up[\"image_size\"])\n",
         | 
| 209 | 
            +
                "up_samples = diffusion_up.ddim_sample_loop(\n",
         | 
| 210 | 
            +
                "    model_up,\n",
         | 
| 211 | 
            +
                "    up_shape,\n",
         | 
| 212 | 
            +
                "    noise=th.randn(up_shape, device=device) * upsample_temp,\n",
         | 
| 213 | 
            +
                "    device=device,\n",
         | 
| 214 | 
            +
                "    clip_denoised=True,\n",
         | 
| 215 | 
            +
                "    progress=True,\n",
         | 
| 216 | 
            +
                "    model_kwargs=model_kwargs,\n",
         | 
| 217 | 
            +
                "    cond_fn=None,\n",
         | 
| 218 | 
            +
                ")[:batch_size]\n",
         | 
| 219 | 
            +
                "model_up.del_cache()\n",
         | 
| 220 | 
            +
                "\n",
         | 
| 221 | 
            +
                "# Show the output\n",
         | 
| 222 | 
            +
                "show_images(up_samples)"
         | 
| 223 | 
            +
               ]
         | 
| 224 | 
            +
              }
         | 
| 225 | 
            +
             ],
         | 
| 226 | 
            +
             "metadata": {
         | 
| 227 | 
            +
              "interpreter": {
         | 
| 228 | 
            +
               "hash": "e7d6e62d90e7e85f9a0faa7f0b1d576302d7ae6108e9fe361594f8e1c8b05781"
         | 
| 229 | 
            +
              },
         | 
| 230 | 
            +
              "kernelspec": {
         | 
| 231 | 
            +
               "display_name": "Python 3",
         | 
| 232 | 
            +
               "language": "python",
         | 
| 233 | 
            +
               "name": "python3"
         | 
| 234 | 
            +
              },
         | 
| 235 | 
            +
              "language_info": {
         | 
| 236 | 
            +
               "codemirror_mode": {
         | 
| 237 | 
            +
                "name": "ipython",
         | 
| 238 | 
            +
                "version": 3
         | 
| 239 | 
            +
               },
         | 
| 240 | 
            +
               "file_extension": ".py",
         | 
| 241 | 
            +
               "mimetype": "text/x-python",
         | 
| 242 | 
            +
               "name": "python",
         | 
| 243 | 
            +
               "nbconvert_exporter": "python",
         | 
| 244 | 
            +
               "pygments_lexer": "ipython3",
         | 
| 245 | 
            +
               "version": "3.7.3"
         | 
| 246 | 
            +
              },
         | 
| 247 | 
            +
              "accelerator": "GPU"
         | 
| 248 | 
            +
             },
         | 
| 249 | 
            +
             "nbformat": 4,
         | 
| 250 | 
            +
             "nbformat_minor": 2
         | 
| 251 | 
            +
            }
         | 
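The show_images helper in the notebook above depends on IPython's inline display, so it only works inside a notebook. A minimal sketch of a file-saving variant for running the same sampling code as a plain script; save_images is a hypothetical helper, not part of glide_text2im:

# Hypothetical helper (not in glide_text2im): same scaling and tiling as
# show_images above, but writes the tiled batch to a PNG instead of
# displaying it inline.
import torch as th
from PIL import Image

def save_images(batch: th.Tensor, path: str = "samples.png") -> None:
    """Tile a batch of [-1, 1] images horizontally and save them to `path`."""
    scaled = ((batch + 1) * 127.5).round().clamp(0, 255).to(th.uint8).cpu()
    reshaped = scaled.permute(2, 0, 3, 1).reshape([batch.shape[2], -1, 3])
    Image.fromarray(reshaped.numpy()).save(path)

# Example usage after the sampling cells above:
# save_images(samples, "base_64x64.png")
# save_images(up_samples, "upsampled_256x256.png")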
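The model_fn wrapper in the sampling cell duplicates each noisy image so that a single forward pass yields both a conditional and an unconditional noise prediction, then extrapolates away from the unconditional one by guidance_scale. A self-contained toy illustration of that combination (random tensors stand in for real model outputs):

# Toy illustration of the classifier-free guidance arithmetic used in model_fn.
# The tensors here are random stand-ins, not real GLIDE outputs.
import torch as th

def guided_eps(cond_eps: th.Tensor, uncond_eps: th.Tensor, scale: float) -> th.Tensor:
    # scale = 0.0 -> purely unconditional, 1.0 -> purely conditional,
    # > 1.0 (e.g. guidance_scale = 3.0 above) -> over-emphasize the prompt.
    return uncond_eps + scale * (cond_eps - uncond_eps)

cond = th.randn(1, 3, 64, 64)
uncond = th.randn(1, 3, 64, 64)
for s in (0.0, 1.0, 3.0):
    diff = (guided_eps(cond, uncond, s) - cond).abs().max().item()
    print(f"scale={s}: max deviation from conditional eps = {diff:.3f}")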
    	
        requirements.txt
    ADDED
    
@@ -0,0 +1,4 @@
+git+https://github.com/openai/glide-text2im.git
+fastapi
+uvicorn
+regex
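requirements.txt pulls the GLIDE code directly from GitHub and adds fastapi and uvicorn, which suggests the Space exposes the model over HTTP. A minimal, hedged sketch of what such a wrapper could look like; the route names and the unimplemented body are illustrative assumptions, not taken from the Space's actual app.py:

# Minimal sketch of an HTTP wrapper consistent with the fastapi/uvicorn pins
# above. The routes are assumptions for illustration only; the real app.py
# may look quite different.
from fastapi import FastAPI

app = FastAPI()

@app.get("/healthz")
def healthz() -> dict:
    return {"status": "ok"}

@app.get("/generate")
def generate(prompt: str) -> dict:
    # A real service would run the base + upsampler sampling loop from the
    # notebook diff above and return the resulting image.
    return {"prompt": prompt, "status": "not implemented in this sketch"}

# Launched with something like: uvicorn app:app --host 0.0.0.0 --port 7860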
    	
        setup.py
    ADDED
    
@@ -0,0 +1,29 @@
+from setuptools import setup
+
+setup(
+    name="glide-text2im",
+    packages=[
+        "glide_text2im",
+        "glide_text2im.clip",
+        "glide_text2im.tokenizer",
+    ],
+    package_data={
+        "glide_text2im.tokenizer": [
+            "bpe_simple_vocab_16e6.txt.gz",
+            "encoder.json.gz",
+            "vocab.bpe.gz",
+        ],
+        "glide_text2im.clip": ["config.yaml"],
+    },
+    install_requires=[
+        "Pillow",
+        "attrs",
+        "torch",
+        "filelock",
+        "requests",
+        "tqdm",
+        "ftfy",
+        "regex",
+    ],
+    author="OpenAI",
+)
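setup.py bundles the tokenizer's BPE files and the CLIP config as package data, so they resolve relative to the installed package. A small sanity-check sketch, assuming the package was installed (for example via the git+https line in requirements.txt); the file names come from the package_data block above:

# Sanity check that the bundled package_data files are present after install.
import os
import glide_text2im.tokenizer as tokenizer_pkg

data_dir = os.path.dirname(tokenizer_pkg.__file__)
for name in ("bpe_simple_vocab_16e6.txt.gz", "encoder.json.gz", "vocab.bpe.gz"):
    path = os.path.join(data_dir, name)
    print(f"{name}: {'found' if os.path.exists(path) else 'MISSING'} ({path})")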
