Upload 57 files

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- .gitattributes +3 -0
- README.md +1 -1
- app/hydit_app.py +170 -0
- app/lang/en.csv +22 -0
- app/lang/zh.csv +22 -0
- asset/Hunyuan_DiT_Tech_Report_05140553.pdf +3 -0
- asset/chinese elements understanding.png +3 -0
- asset/cover.png +0 -0
- asset/framework.png +0 -0
- asset/logo.png +0 -0
- asset/long text understanding.png +3 -0
- asset/mllm.png +0 -0
- asset/radar.png +0 -0
- dialoggen/dialoggen_demo.py +172 -0
- dialoggen/images/demo1.jpeg +0 -0
- dialoggen/images/demo2.jpeg +0 -0
- dialoggen/llava/__init__.py +1 -0
- dialoggen/llava/constants.py +13 -0
- dialoggen/llava/conversation.py +396 -0
- dialoggen/llava/mm_utils.py +247 -0
- dialoggen/llava/model/__init__.py +6 -0
- dialoggen/llava/model/apply_delta.py +48 -0
- dialoggen/llava/model/builder.py +167 -0
- dialoggen/llava/model/consolidate.py +29 -0
- dialoggen/llava/model/language_model/llava_llama.py +158 -0
- dialoggen/llava/model/language_model/llava_mistral.py +158 -0
- dialoggen/llava/model/language_model/llava_mpt.py +97 -0
- dialoggen/llava/model/llava_arch.py +368 -0
- dialoggen/llava/model/make_delta.py +52 -0
- dialoggen/llava/model/multimodal_encoder/builder.py +11 -0
- dialoggen/llava/model/multimodal_encoder/clip_encoder.py +88 -0
- dialoggen/llava/model/multimodal_projector/builder.py +51 -0
- dialoggen/llava/model/utils.py +20 -0
- dialoggen/llava/utils.py +126 -0
- en.csv +22 -0
- environment.yml +8 -0
- example_prompts.txt +28 -0
- hydit/__init__.py +0 -0
- hydit/config.py +67 -0
- hydit/constants.py +62 -0
- hydit/diffusion/__init__.py +0 -0
- hydit/diffusion/pipeline.py +830 -0
- hydit/inference.py +389 -0
- hydit/modules/__init__.py +0 -0
- hydit/modules/attn_layers.py +377 -0
- hydit/modules/embedders.py +111 -0
- hydit/modules/models.py +409 -0
- hydit/modules/norm_layers.py +68 -0
- hydit/modules/poolers.py +39 -0
- hydit/modules/posemb_layers.py +225 -0
    	
.gitattributes
CHANGED

@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+asset/chinese[[:space:]]elements[[:space:]]understanding.png filter=lfs diff=lfs merge=lfs -text
+asset/Hunyuan_DiT_Tech_Report_05140553.pdf filter=lfs diff=lfs merge=lfs -text
+asset/long[[:space:]]text[[:space:]]understanding.png filter=lfs diff=lfs merge=lfs -text
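
The [[:space:]] bracket expressions are how git lfs track escapes literal spaces in the tracked filenames. A quick sanity check of the pattern, translated to a Python regex (Python's re module has no POSIX classes), against the real path:

import re

# The attribute pattern uses POSIX "[[:space:]]" for each literal space; swap
# it for \s so Python's re can evaluate the same pattern.
pattern = "asset/chinese[[:space:]]elements[[:space:]]understanding.png"
regex = re.escape(pattern).replace(re.escape("[[:space:]]"), r"\s")
print(bool(re.fullmatch(regex, "asset/chinese elements understanding.png")))  # True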
    	
README.md
CHANGED

@@ -5,7 +5,7 @@ colorFrom: indigo
 colorTo: red
 sdk: gradio
 sdk_version: 4.31.1
-app_file: app.py
+app_file: app/hydit_app.py
 pinned: false
 ---

    	
app/hydit_app.py
ADDED

import gradio as gr
import pandas as pd
from pathlib import Path
from PIL import Image
import sys
sys.path.insert(0, str(Path(__file__).parent.parent))

from hydit.constants import SAMPLER_FACTORY
from sample_t2i import inferencer

ROOT = Path(__file__).parent.parent
SAMPLERS = list(SAMPLER_FACTORY.keys())
# (height, width) presets for the three aspect-ratio options
SIZES = {
    "square": (1024, 1024),
    "landscape": (768, 1280),
    "portrait": (1280, 768),
}

def get_strings(lang):
    # Load UI labels for the given language from a key,value CSV
    lang_file = Path(f"app/lang/{lang}.csv")
    strings = pd.read_csv(lang_file, header=0)
    strings = strings.set_index("key")['value'].to_dict()
    return strings


args, gen, enhancer = inferencer()
strings = get_strings("en")


def infer(
    prompt,
    negative_prompt,
    seed,
    cfg_scale,
    infer_steps,
    oriW, oriH,
    sampler,
    size,
    enhance
):
    # Optionally rewrite the prompt with the DialogGen enhancer first
    if enhance and enhancer is not None:
        success, enhanced_prompt = enhancer(prompt)
        if not success:
            fail_image = Image.open(ROOT / 'app/fail.png')
            return fail_image
    else:
        enhanced_prompt = None

    height, width = SIZES[size]
    results = gen.predict(prompt,
                          height=height,
                          width=width,
                          seed=seed,
                          enhanced_prompt=enhanced_prompt,
                          negative_prompt=negative_prompt,
                          infer_steps=infer_steps,
                          guidance_scale=cfg_scale,
                          batch_size=1,
                          src_size_cond=(oriW, oriH),
                          sampler=sampler,
                          )
    image = results['images'][0]
    return image


def ui():
    block = gr.Blocks()

    description = f"""
    # {strings['title']}

    ## {strings['desc']}

    """

    with block:
        with gr.Row():
            gr.Markdown(description)
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    size = gr.Radio(
                        label=strings['size'], choices=[
                            (strings['square'], 'square'),
                            (strings['landscape'], 'landscape'),
                            (strings['portrait'], 'portrait'),
                        ],
                        value="square"
                    )
                prompt = gr.Textbox(label=strings['prompt'], value=strings['default prompt'], lines=3)
                with gr.Row():
                    infer_steps = gr.Slider(
                        label=strings['infer steps'], minimum=1, maximum=200, value=100, step=1,
                    )
                    seed = gr.Number(
                        label=strings['seed'], minimum=-1, maximum=1_000_000_000, value=1, step=1, precision=0,
                    )
                    enhance = gr.Checkbox(
                        label=strings['enhance'], value=enhancer is not None, interactive=True,
                    )

                with gr.Accordion(
                    strings['accordion'], open=False
                ):
                    with gr.Row():
                        negative_prompt = gr.Textbox(label=strings['negative_prompt'],
                                                     value=gen.default_negative_prompt,
                                                     lines=2,
                                                     )
                    with gr.Row():
                        sampler = gr.Dropdown(SAMPLERS, label=strings['sampler'], value="ddpm")
                        cfg_scale = gr.Slider(
                            label=strings['cfg'], minimum=1.0, maximum=16.0, value=6.0, step=1
                        )
                        oriW = gr.Number(
                            label=strings['width cond'], minimum=1024, maximum=4096, value=1024, step=64, precision=0,
                            min_width=80,
                        )
                        oriH = gr.Number(
                            label=strings['height cond'], minimum=1024, maximum=4096, value=1024, step=64, precision=0,
                            min_width=80,
                        )
                with gr.Row():
                    advanced_button = gr.Button(strings['run'])
            with gr.Column():
                default_img = Image.open(ROOT / 'app/default.png')
                output_img = gr.Image(
                    label=strings['generated image'],
                    interactive=False,
                    format='png',
                    value=default_img,
                )
            advanced_button.click(
                fn=infer,
                inputs=[
                    prompt, negative_prompt, seed, cfg_scale, infer_steps,
                    oriW, oriH, sampler, size, enhance,
                ],
                outputs=output_img,
            )

        with gr.Row():
            gr.Examples([
                ['一只小猫'],
                ['现实主义风格,画面主要描述一个巴洛克风格的花瓶,带有金色的装饰边框,花瓶上盛开着各种色彩鲜艳的花,白色背景'],
                ['一只聪明的狐狸走在阔叶树林里, 旁边是一条小溪, 细节真实, 摄影'],
                ['飞流直下三千尺,疑是银河落九天'],
                ['一只长靴猫手持亮银色的宝剑,身着铠甲,眼神坚毅,站在一堆金币上,背景是暗色调的洞穴,图像上有金币的光影点缀。'],
                ['麻婆豆腐'],
                ['苏州园林'],
                ['一颗新鲜的草莓特写,红色的外表,表面布满许多种子,背景是淡绿色的叶子'],
                ['请画出“忽如一夜春风来 千树万树梨花开”'],
                ['请将“杞人忧天”的样子画出来'],
                ['枯藤老树昏鸦,小桥流水人家'],
                ['湖水清澈,天空湛蓝,阳光灿烂。一只优雅的白天鹅在湖边游泳。它周围有几只小鸭子,看起来非常可爱,整个画面给人一种宁静祥和的感觉。'],
                ['一朵鲜艳的红色玫瑰花,花瓣撒有一些水珠,晶莹剔透,特写镜头'],
                ['臭豆腐'],
                ['九寨沟'],
                ['俗语“鲤鱼跃龙门”'],
                ['风格是写实,画面主要描述一个亚洲戏曲艺术家正在表演,她穿着华丽的戏服,脸上戴着精致的面具,身姿优雅,背景是古色古香的舞台,镜头是近景'],
            ],
            [prompt],
            label=strings['examples']
            )
    return block


if __name__ == "__main__":
    interface = ui()
    interface.launch()
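
The core of the app above is the Button.click wiring: Gradio maps the inputs component list onto the positional parameters of fn in order and routes the return value to outputs. A stripped-down sketch of the same pattern under gradio 4.x (as pinned by sdk_version: 4.31.1), with a stub standing in for infer; the names here are illustrative only:

import gradio as gr

def fake_infer(prompt, steps):  # stand-in for infer(); not the real pipeline
    return f"would sample {steps} steps for: {prompt}"

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    steps = gr.Slider(1, 200, value=100, step=1, label="Steps")
    out = gr.Textbox(label="Result")
    # inputs map positionally onto fake_infer(prompt, steps); the return
    # value lands in `out`.
    gr.Button("Submit").click(fn=fake_infer, inputs=[prompt, steps], outputs=out)

if __name__ == "__main__":
    demo.launch()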
    	
app/lang/en.csv
ADDED

key,value
size,Size
sampler,Sampler
prompt,Prompt
default prompt,"A cute cat"
negative_prompt,Negative Prompt
seed,Seed
cfg,CFG Scale
infer steps,Sampling Steps
batch size,Batch Size
width cond,Width Cond
height cond,Height Cond
enhance,Prompt Enhancement
run,Submit
square,Square(1024x1024)
landscape,Landscape(1280x768)
portrait,Portrait(768x1280)
accordion,Advanced Options
generated image,HunYuanDiT Generated Image
examples,More Examples
title,Hunyuan-DiT
desc,A Powerful Multi-Resolution Diffusion Transformer with Fine-Grained Chinese Understanding
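
Each lang CSV is a flat key,value table that get_strings() in app/hydit_app.py collapses into a lookup dict. A standalone sketch of that step, using an inline two-row CSV instead of app/lang/en.csv:

import io
import pandas as pd

# Same pattern as get_strings(): read the key,value table, index by key,
# and take the value column as a plain dict.
csv_text = 'key,value\nprompt,Prompt\ndefault prompt,"A cute cat"\n'
strings = pd.read_csv(io.StringIO(csv_text), header=0).set_index("key")["value"].to_dict()
print(strings["default prompt"])  # -> A cute cat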
    	
app/lang/zh.csv
ADDED

key,value
size,尺寸
sampler,采样器
prompt,文本描述
default prompt,"一只可爱的猫"
negative_prompt,负向词
seed,种子
cfg,CFG系数
infer steps,采样步数
batch size,批大小
width cond,宽度条件
height cond,高度条件
enhance,文本增强
run,提交生成
square,方形(1024x1024)
portrait,竖屏(1280x768)
landscape,横屏(768x1280)
accordion,高级设置
generated image,HunYuanDiT 生成
examples,更多示例
title,混元-DiT
desc,具有细粒度中文理解的高性能多分辨率 Diffusion Transformer 模型
    	
asset/Hunyuan_DiT_Tech_Report_05140553.pdf
ADDED

version https://git-lfs.github.com/spec/v1
oid sha256:3f8514b002ba3bb4704575096683f65e09df06693a54bf3004f0b351138ab1e5
size 42132252
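
The PDF is stored as a Git LFS pointer: three "key value" lines standing in for the real 42 MB file. If you ever need the fields programmatically, a small sketch:

# Parse the pointer above into its version/oid/size fields; each line is
# "key value", split on the first space.
pointer = """version https://git-lfs.github.com/spec/v1
oid sha256:3f8514b002ba3bb4704575096683f65e09df06693a54bf3004f0b351138ab1e5
size 42132252
"""
fields = dict(line.split(" ", 1) for line in pointer.strip().splitlines())
print(fields["oid"], int(fields["size"]))  # sha256:3f85... 42132252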
    	
asset/chinese elements understanding.png
ADDED (Git LFS)

asset/cover.png
ADDED

asset/framework.png
ADDED

asset/logo.png
ADDED

asset/long text understanding.png
ADDED (Git LFS)

asset/mllm.png
ADDED

asset/radar.png
ADDED
    	
dialoggen/dialoggen_demo.py
ADDED

import argparse
import torch
import sys
import os
# Add the dialoggen directory under the current working directory to sys.path
sys.path.append(os.getcwd() + "/dialoggen")


from llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
    IMAGE_PLACEHOLDER,
)
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import (
    process_images,
    tokenizer_image_token,
    get_model_name_from_path,
)

import requests
from PIL import Image
from io import BytesIO
import re


def image_parser(image_file, sep=','):
    out = image_file.split(sep)
    return out


def load_image(image_file):
    if image_file.startswith("http") or image_file.startswith("https"):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert("RGB")
    else:
        image = Image.open(image_file).convert("RGB")
    return image


def load_images(image_files):
    out = []
    for image_file in image_files:
        image = load_image(image_file)
        out.append(image)
    return out


def init_dialoggen_model(model_path, model_base=None):
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        model_path, model_base, model_name, llava_type_model=True)
    return {"tokenizer": tokenizer,
            "model": model,
            "image_processor": image_processor}


def eval_model(models,
               query='详细描述一下这张图片',  # "Describe this image in detail"
               image_file=None,
               sep=',',
               temperature=0.2,
               top_p=None,
               num_beams=1,
               max_new_tokens=512,
               ):
    # Model
    disable_torch_init()

    qs = query
    image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
    if IMAGE_PLACEHOLDER in qs:
        if models["model"].config.mm_use_im_start_end:
            qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs)
        else:
            qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs)
    else:
        if models["model"].config.mm_use_im_start_end:
            qs = image_token_se + "\n" + qs
        else:
            qs = DEFAULT_IMAGE_TOKEN + "\n" + qs

    conv = conv_templates['llava_v1'].copy()
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    if image_file is not None:
        image_files = image_parser(image_file, sep=sep)
        images = load_images(image_files)
        image_sizes = [x.size for x in images]
        images_tensor = process_images(
            images,
            models["image_processor"],
            models["model"].config
        ).to(models["model"].device, dtype=torch.float16)
    else:
        # format the input the same way as the training data (no real image)
        image_sizes = [(1024, 1024)]
        images_tensor = torch.zeros(1, 5, 3, models["image_processor"].crop_size["height"], models["image_processor"].crop_size["width"])
        images_tensor = images_tensor.to(models["model"].device, dtype=torch.float16)

    input_ids = (
        tokenizer_image_token(prompt, models["tokenizer"], IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .cuda()
    )
    with torch.inference_mode():
        output_ids = models["model"].generate(
            input_ids,
            images=images_tensor,
            image_sizes=image_sizes,
            do_sample=True if temperature > 0 else False,
            temperature=temperature,
            top_p=top_p,
            num_beams=num_beams,
            max_new_tokens=max_new_tokens,
            use_cache=True,
        )

    outputs = models["tokenizer"].batch_decode(output_ids, skip_special_tokens=True)[0].strip()
    return outputs


def remove_prefix(text):
    if text.startswith("<画图>"):  # "<draw>" prefix: the model agreed to draw
        return text[len("<画图>"):], True
    elif text.startswith("对不起"):  # "sorry": the model refused to draw
        return "", False
    else:
        return text, True


class DialogGen(object):
    def __init__(self, model_path):
        self.models = init_dialoggen_model(model_path)
        # "First judge the user's intent; if it is a drawing request,
        # prefix the output with <画图>: {}"
        self.query_template = "请先判断用户的意图,若为画图则在输出前加入<画图>:{}"

    def __call__(self, prompt):
        enhanced_prompt = eval_model(
            models=self.models,
            query=self.query_template.format(prompt),
            image_file=None,
        )

        enhanced_prompt, compliance = remove_prefix(enhanced_prompt)
        if not compliance:
            return False, ""
        return True, enhanced_prompt


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, default='./ckpts/dialoggen')
    parser.add_argument('--prompt', type=str, default='画一只小猫')  # "draw a kitten"
    parser.add_argument('--image_file', type=str, default=None)  # e.g. 'images/demo1.jpeg'
    args = parser.parse_args()

    query = f"请先判断用户的意图,若为画图则在输出前加入<画图>:{args.prompt}"

    models = init_dialoggen_model(args.model_path)

    res = eval_model(models,
        query=query,
        image_file=args.image_file,
    )
    print(res)
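
A usage sketch for the enhancer class above, assuming the DialogGen checkpoint has already been downloaded to ./ckpts/dialoggen (the CLI default) and the interpreter runs from the repository root:

from dialoggen.dialoggen_demo import DialogGen

enhancer = DialogGen('./ckpts/dialoggen')        # loads the LLaVA-style model
success, enhanced_prompt = enhancer('画一只小猫')  # "draw a kitten"
if success:
    print(enhanced_prompt)  # a rewritten, more detailed drawing prompt
else:
    print("DialogGen judged this not to be a drawing request")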
    	
dialoggen/images/demo1.jpeg
ADDED

dialoggen/images/demo2.jpeg
ADDED
    	
dialoggen/llava/__init__.py
ADDED

from .model import LlavaLlamaForCausalLM
    	
dialoggen/llava/constants.py
ADDED

CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15

LOGDIR = "."

# Model Constants
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
IMAGE_PLACEHOLDER = "<image-placeholder>"
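
IMAGE_TOKEN_INDEX is deliberately negative so it can never collide with a real vocabulary id. A toy sketch of how the constant marks image positions in a prompt's token ids, in the spirit of llava.mm_utils.tokenizer_image_token but with a hypothetical stand-in tokenizer:

DEFAULT_IMAGE_TOKEN = "<image>"
IMAGE_TOKEN_INDEX = -200

def toy_tokenize(text):
    # Hypothetical whitespace tokenizer, for illustration only.
    return [hash(word) % 1000 for word in text.split()]

prompt = DEFAULT_IMAGE_TOKEN + "\nDescribe this picture"
ids = []
for i, chunk in enumerate(prompt.split(DEFAULT_IMAGE_TOKEN)):
    if i > 0:
        ids.append(IMAGE_TOKEN_INDEX)  # the model later swaps this slot for image features
    ids.extend(toy_tokenize(chunk))
print(ids)  # [-200, <ids for "Describe this picture">]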
    	
dialoggen/llava/conversation.py
ADDED

import dataclasses
from enum import auto, Enum
from typing import List, Tuple
import base64
from io import BytesIO
from PIL import Image


class SeparatorStyle(Enum):
    """Different separator style."""
    SINGLE = auto()
    TWO = auto()
    MPT = auto()
    PLAIN = auto()
    LLAMA_2 = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None
    version: str = "Unknown"

    skip_next: bool = False

    def get_prompt(self):
        messages = self.messages
        if len(messages) > 0 and type(messages[0][1]) is tuple:
            messages = self.messages.copy()
            init_role, init_msg = messages[0].copy()
            init_msg = init_msg[0].replace("<image>", "").strip()
            if 'mmtag' in self.version:
                messages[0] = (init_role, init_msg)
                messages.insert(0, (self.roles[0], "<Image><image></Image>"))
                messages.insert(1, (self.roles[1], "Received."))
            else:
                messages[0] = (init_role, "<image>\n" + init_msg)

        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.MPT:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + message + self.sep
                else:
                    ret += role
        elif self.sep_style == SeparatorStyle.LLAMA_2:
            wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
            wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
            ret = ""

            for i, (role, message) in enumerate(messages):
                if i == 0:
                    assert message, "first message should not be none"
                    assert role == self.roles[0], "first message should come from user"
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    if i == 0: message = wrap_sys(self.system) + message
                    if i % 2 == 0:
                        message = wrap_inst(message)
                        ret += self.sep + message
                    else:
                        ret += " " + message + " " + self.sep2
                else:
                    ret += ""
            ret = ret.lstrip(self.sep)
        elif self.sep_style == SeparatorStyle.PLAIN:
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += message + seps[i % 2]
                else:
                    ret += ""
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

        return ret

    def append_message(self, role, message):
        self.messages.append([role, message])

    def process_image(self, image, image_process_mode, return_pil=False, image_format='PNG', max_len=1344, min_len=672):
        if image_process_mode == "Pad":
            def expand2square(pil_img, background_color=(122, 116, 104)):
                width, height = pil_img.size
                if width == height:
                    return pil_img
                elif width > height:
                    result = Image.new(pil_img.mode, (width, width), background_color)
                    result.paste(pil_img, (0, (width - height) // 2))
                    return result
                else:
                    result = Image.new(pil_img.mode, (height, height), background_color)
                    result.paste(pil_img, ((height - width) // 2, 0))
                    return result
            image = expand2square(image)
        elif image_process_mode in ["Default", "Crop"]:
            pass
        elif image_process_mode == "Resize":
            image = image.resize((336, 336))
        else:
            raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
        if max(image.size) > max_len:
            max_hw, min_hw = max(image.size), min(image.size)
            aspect_ratio = max_hw / min_hw
            shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
            longest_edge = int(shortest_edge * aspect_ratio)
            W, H = image.size
            if H > W:
                H, W = longest_edge, shortest_edge
            else:
                H, W = shortest_edge, longest_edge
            image = image.resize((W, H))
        if return_pil:
            return image
        else:
            buffered = BytesIO()
            image.save(buffered, format=image_format)
            img_b64_str = base64.b64encode(buffered.getvalue()).decode()
            return img_b64_str

    def get_images(self, return_pil=False):
        images = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    msg, image, image_process_mode = msg
                    image = self.process_image(image, image_process_mode, return_pil=return_pil)
                    images.append(image)
        return images

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    msg, image, image_process_mode = msg
                    img_b64_str = self.process_image(
                        image, "Default", return_pil=False,
                        image_format='JPEG')
                    img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
                    msg = img_str + msg.replace('<image>', '').strip()
                    ret.append([msg, None])
                else:
                    ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version)

    def dict(self):
        if len(self.get_images()) > 0:
            return {
                "system": self.system,
                "roles": self.roles,
                "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
                "offset": self.offset,
                "sep": self.sep,
                "sep2": self.sep2,
            }
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }


conv_vicuna_v0 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
        ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
        ("Assistant",
            "Renewable energy sources are those that can be replenished naturally in a relatively "
            "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
            "Non-renewable energy sources, on the other hand, are finite and will eventually be "
            "depleted, such as coal, oil, and natural gas. Here are some key differences between "
         | 
| 222 | 
            +
                        "renewable and non-renewable energy sources:\n"
         | 
| 223 | 
            +
                        "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
         | 
| 224 | 
            +
                        "energy sources are finite and will eventually run out.\n"
         | 
| 225 | 
            +
                        "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
         | 
| 226 | 
            +
                        "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
         | 
| 227 | 
            +
                        "and other negative effects.\n"
         | 
| 228 | 
            +
                        "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
         | 
| 229 | 
            +
                        "have lower operational costs than non-renewable sources.\n"
         | 
| 230 | 
            +
                        "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
         | 
| 231 | 
            +
                        "locations than non-renewable sources.\n"
         | 
| 232 | 
            +
                        "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
         | 
| 233 | 
            +
                        "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
         | 
| 234 | 
            +
                        "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
         | 
| 235 | 
            +
                        "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
         | 
| 236 | 
            +
                ),
         | 
| 237 | 
            +
                offset=2,
         | 
| 238 | 
            +
                sep_style=SeparatorStyle.SINGLE,
         | 
| 239 | 
            +
                sep="###",
         | 
| 240 | 
            +
            )
         | 
| 241 | 
            +
             | 
| 242 | 
            +
            conv_vicuna_v1 = Conversation(
         | 
| 243 | 
            +
                system="A chat between a curious user and an artificial intelligence assistant. "
         | 
| 244 | 
            +
                "The assistant gives helpful, detailed, and polite answers to the user's questions.",
         | 
| 245 | 
            +
                roles=("USER", "ASSISTANT"),
         | 
| 246 | 
            +
                version="v1",
         | 
| 247 | 
            +
                messages=(),
         | 
| 248 | 
            +
                offset=0,
         | 
| 249 | 
            +
                sep_style=SeparatorStyle.TWO,
         | 
| 250 | 
            +
                sep=" ",
         | 
| 251 | 
            +
                sep2="</s>",
         | 
| 252 | 
            +
            )
         | 
| 253 | 
            +
             | 
| 254 | 
            +
            conv_llama_2 = Conversation(
         | 
| 255 | 
            +
                system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
         | 
| 256 | 
            +
             | 
| 257 | 
            +
            If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
         | 
| 258 | 
            +
                roles=("USER", "ASSISTANT"),
         | 
| 259 | 
            +
                version="llama_v2",
         | 
| 260 | 
            +
                messages=(),
         | 
| 261 | 
            +
                offset=0,
         | 
| 262 | 
            +
                sep_style=SeparatorStyle.LLAMA_2,
         | 
| 263 | 
            +
                sep="<s>",
         | 
| 264 | 
            +
                sep2="</s>",
         | 
| 265 | 
            +
            )
         | 
| 266 | 
            +
             | 
| 267 | 
            +
            conv_llava_llama_2 = Conversation(
         | 
| 268 | 
            +
                system="You are a helpful language and vision assistant. "
         | 
| 269 | 
            +
                       "You are able to understand the visual content that the user provides, "
         | 
| 270 | 
            +
                       "and assist the user with a variety of tasks using natural language.",
         | 
| 271 | 
            +
                roles=("USER", "ASSISTANT"),
         | 
| 272 | 
            +
                version="llama_v2",
         | 
| 273 | 
            +
                messages=(),
         | 
| 274 | 
            +
                offset=0,
         | 
| 275 | 
            +
                sep_style=SeparatorStyle.LLAMA_2,
         | 
| 276 | 
            +
                sep="<s>",
         | 
| 277 | 
            +
                sep2="</s>",
         | 
| 278 | 
            +
            )
         | 
| 279 | 
            +
             | 
| 280 | 
            +
            conv_mpt = Conversation(
         | 
| 281 | 
            +
                system="""<|im_start|>system
         | 
| 282 | 
            +
            A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
         | 
| 283 | 
            +
                roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
         | 
| 284 | 
            +
                version="mpt",
         | 
| 285 | 
            +
                messages=(),
         | 
| 286 | 
            +
                offset=0,
         | 
| 287 | 
            +
                sep_style=SeparatorStyle.MPT,
         | 
| 288 | 
            +
                sep="<|im_end|>",
         | 
| 289 | 
            +
            )
         | 
| 290 | 
            +
             | 
| 291 | 
            +
            conv_llava_plain = Conversation(
         | 
| 292 | 
            +
                system="",
         | 
| 293 | 
            +
                roles=("", ""),
         | 
| 294 | 
            +
                messages=(
         | 
| 295 | 
            +
                ),
         | 
| 296 | 
            +
                offset=0,
         | 
| 297 | 
            +
                sep_style=SeparatorStyle.PLAIN,
         | 
| 298 | 
            +
                sep="\n",
         | 
| 299 | 
            +
            )
         | 
| 300 | 
            +
             | 
| 301 | 
            +
            conv_llava_v0 = Conversation(
         | 
| 302 | 
            +
                system="A chat between a curious human and an artificial intelligence assistant. "
         | 
| 303 | 
            +
                       "The assistant gives helpful, detailed, and polite answers to the human's questions.",
         | 
| 304 | 
            +
                roles=("Human", "Assistant"),
         | 
| 305 | 
            +
                messages=(
         | 
| 306 | 
            +
                ),
         | 
| 307 | 
            +
                offset=0,
         | 
| 308 | 
            +
                sep_style=SeparatorStyle.SINGLE,
         | 
| 309 | 
            +
                sep="###",
         | 
| 310 | 
            +
            )
         | 
| 311 | 
            +
             | 
| 312 | 
            +
            conv_llava_v0_mmtag = Conversation(
         | 
| 313 | 
            +
                system="A chat between a curious user and an artificial intelligence assistant. "
         | 
| 314 | 
            +
                       "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
         | 
| 315 | 
            +
                       "The visual content will be provided with the following format: <Image>visual content</Image>.",
         | 
| 316 | 
            +
                roles=("Human", "Assistant"),
         | 
| 317 | 
            +
                messages=(
         | 
| 318 | 
            +
                ),
         | 
| 319 | 
            +
                offset=0,
         | 
| 320 | 
            +
                sep_style=SeparatorStyle.SINGLE,
         | 
| 321 | 
            +
                sep="###",
         | 
| 322 | 
            +
                version="v0_mmtag",
         | 
| 323 | 
            +
            )
         | 
| 324 | 
            +
             | 
| 325 | 
            +
            conv_llava_v1 = Conversation(
         | 
| 326 | 
            +
                system="A chat between a curious human and an artificial intelligence assistant. "
         | 
| 327 | 
            +
                       "The assistant gives helpful, detailed, and polite answers to the human's questions.",
         | 
| 328 | 
            +
                roles=("USER", "ASSISTANT"),
         | 
| 329 | 
            +
                version="v1",
         | 
| 330 | 
            +
                messages=(),
         | 
| 331 | 
            +
                offset=0,
         | 
| 332 | 
            +
                sep_style=SeparatorStyle.TWO,
         | 
| 333 | 
            +
                sep=" ",
         | 
| 334 | 
            +
                sep2="</s>",
         | 
| 335 | 
            +
            )
         | 
| 336 | 
            +
             | 
| 337 | 
            +
            conv_llava_v1_mmtag = Conversation(
         | 
| 338 | 
            +
                system="A chat between a curious user and an artificial intelligence assistant. "
         | 
| 339 | 
            +
                       "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
         | 
| 340 | 
            +
                       "The visual content will be provided with the following format: <Image>visual content</Image>.",
         | 
| 341 | 
            +
                roles=("USER", "ASSISTANT"),
         | 
| 342 | 
            +
                messages=(),
         | 
| 343 | 
            +
                offset=0,
         | 
| 344 | 
            +
                sep_style=SeparatorStyle.TWO,
         | 
| 345 | 
            +
                sep=" ",
         | 
| 346 | 
            +
                sep2="</s>",
         | 
| 347 | 
            +
                version="v1_mmtag",
         | 
| 348 | 
            +
            )
         | 
| 349 | 
            +
             | 
| 350 | 
            +
            conv_mistral_instruct = Conversation(
         | 
| 351 | 
            +
                system="",
         | 
| 352 | 
            +
                roles=("USER", "ASSISTANT"),
         | 
| 353 | 
            +
                version="llama_v2",
         | 
| 354 | 
            +
                messages=(),
         | 
| 355 | 
            +
                offset=0,
         | 
| 356 | 
            +
                sep_style=SeparatorStyle.LLAMA_2,
         | 
| 357 | 
            +
                sep="",
         | 
| 358 | 
            +
                sep2="</s>",
         | 
| 359 | 
            +
            )
         | 
| 360 | 
            +
             | 
| 361 | 
            +
            conv_chatml_direct = Conversation(
         | 
| 362 | 
            +
                system="""<|im_start|>system
         | 
| 363 | 
            +
            Answer the questions.""",
         | 
| 364 | 
            +
                roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
         | 
| 365 | 
            +
                version="mpt",
         | 
| 366 | 
            +
                messages=(),
         | 
| 367 | 
            +
                offset=0,
         | 
| 368 | 
            +
                sep_style=SeparatorStyle.MPT,
         | 
| 369 | 
            +
                sep="<|im_end|>",
         | 
| 370 | 
            +
            )
         | 
| 371 | 
            +
             | 
| 372 | 
            +
            default_conversation = conv_vicuna_v1
         | 
| 373 | 
            +
            conv_templates = {
         | 
| 374 | 
            +
                "default": conv_vicuna_v0,
         | 
| 375 | 
            +
                "v0": conv_vicuna_v0,
         | 
| 376 | 
            +
                "v1": conv_vicuna_v1,
         | 
| 377 | 
            +
                "vicuna_v1": conv_vicuna_v1,
         | 
| 378 | 
            +
                "llama_2": conv_llama_2,
         | 
| 379 | 
            +
                "mistral_instruct": conv_mistral_instruct,
         | 
| 380 | 
            +
                "chatml_direct": conv_chatml_direct,
         | 
| 381 | 
            +
                "mistral_direct": conv_chatml_direct,
         | 
| 382 | 
            +
             | 
| 383 | 
            +
                "plain": conv_llava_plain,
         | 
| 384 | 
            +
                "v0_plain": conv_llava_plain,
         | 
| 385 | 
            +
                "llava_v0": conv_llava_v0,
         | 
| 386 | 
            +
                "v0_mmtag": conv_llava_v0_mmtag,
         | 
| 387 | 
            +
                "llava_v1": conv_llava_v1,
         | 
| 388 | 
            +
                "v1_mmtag": conv_llava_v1_mmtag,
         | 
| 389 | 
            +
                "llava_llama_2": conv_llava_llama_2,
         | 
| 390 | 
            +
             | 
| 391 | 
            +
                "mpt": conv_mpt,
         | 
| 392 | 
            +
            }
         | 
| 393 | 
            +
             | 
| 394 | 
            +
             | 
| 395 | 
            +
            if __name__ == "__main__":
         | 
| 396 | 
            +
                print(default_conversation.get_prompt())
         | 
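
For orientation, a minimal usage sketch of the template registry above — hypothetical glue code, not part of the upload, assuming the `append_message` helper defined earlier in this class:

    # Hypothetical usage: build a vicuna_v1-style prompt for one user turn.
    conv = conv_templates["vicuna_v1"].copy()  # copy so the shared template stays pristine
    conv.append_message(conv.roles[0], "Describe this image. <image>")
    conv.append_message(conv.roles[1], None)   # empty slot the model will fill
    prompt = conv.get_prompt()                 # "... USER: Describe this image. <image> ASSISTANT:"
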
    	
dialoggen/llava/mm_utils.py
ADDED
@@ -0,0 +1,247 @@
from PIL import Image
from io import BytesIO
import base64
import torch
import math
import ast

from transformers import StoppingCriteria
from llava.constants import IMAGE_TOKEN_INDEX


def select_best_resolution(original_size, possible_resolutions):
    """
    Selects the best resolution from a list of possible resolutions based on the original size.

    Args:
        original_size (tuple): The original size of the image in the format (width, height).
        possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].

    Returns:
        tuple: The best fit resolution in the format (width, height).
    """
    original_width, original_height = original_size
    best_fit = None
    max_effective_resolution = 0
    min_wasted_resolution = float('inf')

    for width, height in possible_resolutions:
        scale = min(width / original_width, height / original_height)
        downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
        effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
        wasted_resolution = (width * height) - effective_resolution

        if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
            max_effective_resolution = effective_resolution
            min_wasted_resolution = wasted_resolution
            best_fit = (width, height)

    return best_fit


def resize_and_pad_image(image, target_resolution):
    """
    Resize and pad an image to a target resolution while maintaining aspect ratio.

    Args:
        image (PIL.Image.Image): The input image.
        target_resolution (tuple): The target resolution (width, height) of the image.

    Returns:
        PIL.Image.Image: The resized and padded image.
    """
    original_width, original_height = image.size
    target_width, target_height = target_resolution

    scale_w = target_width / original_width
    scale_h = target_height / original_height

    if scale_w < scale_h:
        new_width = target_width
        new_height = min(math.ceil(original_height * scale_w), target_height)
    else:
        new_height = target_height
        new_width = min(math.ceil(original_width * scale_h), target_width)

    # Resize the image
    resized_image = image.resize((new_width, new_height))

    new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
    paste_x = (target_width - new_width) // 2
    paste_y = (target_height - new_height) // 2
    new_image.paste(resized_image, (paste_x, paste_y))

    return new_image


def divide_to_patches(image, patch_size):
    """
    Divides an image into patches of a specified size.

    Args:
        image (PIL.Image.Image): The input image.
        patch_size (int): The size of each patch.

    Returns:
        list: A list of PIL.Image.Image objects representing the patches.
    """
    patches = []
    width, height = image.size
    for i in range(0, height, patch_size):
        for j in range(0, width, patch_size):
            box = (j, i, j + patch_size, i + patch_size)
            patch = image.crop(box)
            patches.append(patch)

    return patches


def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """
    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.

    Args:
        image_size (tuple): The size of the input image in the format (width, height).
        grid_pinpoints (str): A string representation of a list of possible resolutions.
        patch_size (int): The size of each image patch.

    Returns:
        tuple: The shape of the image patch grid in the format (width, height).
    """
    if type(grid_pinpoints) is list:
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)
    width, height = select_best_resolution(image_size, possible_resolutions)
    return width // patch_size, height // patch_size
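
# --- Illustrative aside (not part of the upstream LLaVA file): a hypothetical
# worked example of the selection logic above. For a 1000x600 image and
# candidate resolutions [(336, 672), (672, 336), (672, 672)]:
#   (336, 672): scale 0.336 -> 336x201, effective  67,536 px
#   (672, 336): scale 0.560 -> 560x336, effective 188,160 px
#   (672, 672): scale 0.672 -> 672x403, effective 270,816 px
# so select_best_resolution((1000, 600), ...) returns (672, 672), and with a
# 336-pixel patch, get_anyres_image_grid_shape(...) returns (2, 2).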


def process_anyres_image(image, processor, grid_pinpoints):
    """
    Process an image with variable resolutions.

    Args:
        image (PIL.Image.Image): The input image to be processed.
        processor: The image processor object.
        grid_pinpoints (str): A string representation of a list of possible resolutions.

    Returns:
        torch.Tensor: A tensor containing the processed image patches.
    """
    if type(grid_pinpoints) is list:
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)
    best_resolution = select_best_resolution(image.size, possible_resolutions)
    image_padded = resize_and_pad_image(image, best_resolution)

    patches = divide_to_patches(image_padded, processor.crop_size['height'])

    image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))

    image_patches = [image_original_resize] + patches
    image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
                     for image_patch in image_patches]
    return torch.stack(image_patches, dim=0)


def load_image_from_base64(image):
    return Image.open(BytesIO(base64.b64decode(image)))


def expand2square(pil_img, background_color):
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result


def process_images(images, image_processor, model_cfg):
    image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
    new_images = []
    if image_aspect_ratio == 'pad':
        for image in images:
            image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
            image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
            new_images.append(image)
    elif image_aspect_ratio == "anyres":
        for image in images:
            image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
            new_images.append(image)
    else:
        return image_processor(images, return_tensors='pt')['pixel_values']
    if all(x.shape == new_images[0].shape for x in new_images):
        new_images = torch.stack(new_images, dim=0)
    return new_images


def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]

    def insert_separator(X, sep):
        return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]

    input_ids = []
    offset = 0
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
        input_ids.extend(x[offset:])

    if return_tensors is not None:
        if return_tensors == 'pt':
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f'Unsupported tensor type: {return_tensors}')
    return input_ids
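
# --- Illustrative aside (not part of the upstream LLaVA file): a hypothetical
# trace of tokenizer_image_token. Assuming IMAGE_TOKEN_INDEX is -200 (per
# llava/constants.py) and a tokenizer with BOS id 1, the prompt
# "USER: <image>\nhi ASSISTANT:" is split on '<image>' into two chunks, each
# tokenized with its own BOS; the splice keeps one leading BOS, drops the
# duplicate BOS tokens, and inserts -200 where '<image>' stood:
#   [1] + ids("USER: ")[1:] + [-200] + ids("\nhi ASSISTANT:")[1:]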
            +
             | 
| 206 | 
            +
             | 
| 207 | 
            +
            def get_model_name_from_path(model_path):
         | 
| 208 | 
            +
                model_path = model_path.strip("/")
         | 
| 209 | 
            +
                model_paths = model_path.split("/")
         | 
| 210 | 
            +
                if model_paths[-1].startswith('checkpoint-'):
         | 
| 211 | 
            +
                    return model_paths[-2] + "_" + model_paths[-1]
         | 
| 212 | 
            +
                else:
         | 
| 213 | 
            +
                    return model_paths[-1]
         | 
| 214 | 
            +
             | 
| 215 | 
            +
            class KeywordsStoppingCriteria(StoppingCriteria):
         | 
| 216 | 
            +
                def __init__(self, keywords, tokenizer, input_ids):
         | 
| 217 | 
            +
                    self.keywords = keywords
         | 
| 218 | 
            +
                    self.keyword_ids = []
         | 
| 219 | 
            +
                    self.max_keyword_len = 0
         | 
| 220 | 
            +
                    for keyword in keywords:
         | 
| 221 | 
            +
                        cur_keyword_ids = tokenizer(keyword).input_ids
         | 
| 222 | 
            +
                        if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
         | 
| 223 | 
            +
                            cur_keyword_ids = cur_keyword_ids[1:]
         | 
| 224 | 
            +
                        if len(cur_keyword_ids) > self.max_keyword_len:
         | 
| 225 | 
            +
                            self.max_keyword_len = len(cur_keyword_ids)
         | 
| 226 | 
            +
                        self.keyword_ids.append(torch.tensor(cur_keyword_ids))
         | 
| 227 | 
            +
                    self.tokenizer = tokenizer
         | 
| 228 | 
            +
                    self.start_len = input_ids.shape[1]
         | 
| 229 | 
            +
             | 
| 230 | 
            +
                def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
         | 
| 231 | 
            +
                    offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
         | 
| 232 | 
            +
                    self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
         | 
| 233 | 
            +
                    for keyword_id in self.keyword_ids:
         | 
| 234 | 
            +
                        truncated_output_ids = output_ids[0, -keyword_id.shape[0]:]
         | 
| 235 | 
            +
                        if torch.equal(truncated_output_ids, keyword_id):
         | 
| 236 | 
            +
                            return True
         | 
| 237 | 
            +
                    outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
         | 
| 238 | 
            +
                    for keyword in self.keywords:
         | 
| 239 | 
            +
                        if keyword in outputs:
         | 
| 240 | 
            +
                            return True
         | 
| 241 | 
            +
                    return False
         | 
| 242 | 
            +
             | 
| 243 | 
            +
                def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
         | 
| 244 | 
            +
                    outputs = []
         | 
| 245 | 
            +
                    for i in range(output_ids.shape[0]):
         | 
| 246 | 
            +
                        outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
         | 
| 247 | 
            +
                    return all(outputs)
         | 
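
Before the model package, a minimal sketch of how the stopping criterion above is typically wired into generation — hypothetical glue code, not part of the upload; `model`, `tokenizer`, and `input_ids` are assumed to be set up elsewhere (e.g. via builder.py below and tokenizer_image_token):

    # Hypothetical usage, not part of the upload.
    stop_str = "</s>"  # sep2 of the vicuna_v1-style templates
    stopping = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)
    output_ids = model.generate(input_ids, max_new_tokens=512,
                                stopping_criteria=[stopping])
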
    	
dialoggen/llava/model/__init__.py
ADDED
@@ -0,0 +1,6 @@
try:
    from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig
    from .language_model.llava_mpt import LlavaMptForCausalLM, LlavaMptConfig
    from .language_model.llava_mistral import LlavaMistralForCausalLM, LlavaMistralConfig
except:
    pass
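
A side note on the bare `try`/`except` above (an observation, not upstream documentation): it keeps the package importable when an optional backend or its dependencies are missing, at the cost that a failed import only surfaces later, as a `NameError` at the `Llava*ForCausalLM` call sites in builder.py, rather than as an `ImportError` here.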
    	
dialoggen/llava/model/apply_delta.py
ADDED
@@ -0,0 +1,48 @@
"""
Usage:
python3 -m llava.model.apply_delta --base-model-path ~/model_weights/llama-7b --target-model-path ~/model_weights/vicuna-7b --delta-path lmsys/vicuna-7b-delta
"""
import argparse

import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from llava import LlavaLlamaForCausalLM


def apply_delta(base_model_path, target_model_path, delta_path):
    print("Loading base model")
    base = AutoModelForCausalLM.from_pretrained(
        base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)

    print("Loading delta")
    delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
    delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)

    print("Applying delta")
    for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
        if name not in base.state_dict():
            assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
            continue
        if param.data.shape == base.state_dict()[name].shape:
            param.data += base.state_dict()[name]
        else:
            assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \
                f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
            bparam = base.state_dict()[name]
            param.data[:bparam.shape[0], :bparam.shape[1]] += bparam

    print("Saving target model")
    delta.save_pretrained(target_model_path)
    delta_tokenizer.save_pretrained(target_model_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--base-model-path", type=str, required=True)
    parser.add_argument("--target-model-path", type=str, required=True)
    parser.add_argument("--delta-path", type=str, required=True)

    args = parser.parse_args()

    apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
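
The heart of this script is plain tensor addition: target weights = base weights + released delta, with a sliced variant for the resized `embed_tokens`/`lm_head` matrices. A toy illustration of the same arithmetic, not part of the upload:

    # target = base + delta, element-wise, mirroring `param.data += base.state_dict()[name]`.
    import torch
    base = torch.tensor([1.0, 2.0, 3.0])
    delta = torch.tensor([0.5, -1.0, 0.0])  # what gets distributed publicly
    delta += base                           # reconstruct the target in place
    assert torch.equal(delta, torch.tensor([1.5, 1.0, 3.0]))
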
    	
dialoggen/llava/model/builder.py
ADDED
@@ -0,0 +1,167 @@
#    Copyright 2023 Haotian Liu
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.


import os
import warnings
import shutil

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
import torch
from llava.model import *
from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN


def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, llava_type_model=True, **kwargs):
    kwargs = {"device_map": device_map, **kwargs}

    if device != "cuda":
        kwargs['device_map'] = {"": device}

    if load_8bit:
        kwargs['load_in_8bit'] = True
    elif load_4bit:
        kwargs['load_in_4bit'] = True
        kwargs['quantization_config'] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4'
        )
    else:
        kwargs['torch_dtype'] = torch.float16

    if use_flash_attn:
        kwargs['attn_implementation'] = 'flash_attention_2'

    if 'llava' in model_name.lower():
        # Load LLaVA model
        if 'lora' in model_name.lower() and model_base is None:
            warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.')
        if 'lora' in model_name.lower() and model_base is not None:
            from llava.model.language_model.llava_llama import LlavaConfig
            lora_cfg_pretrained = LlavaConfig.from_pretrained(model_path)
            tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
            print('Loading LLaVA from base model...')
            model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
            token_num, token_dim = model.lm_head.out_features, model.lm_head.in_features
            if model.lm_head.weight.shape[0] != token_num:
                model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
                model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))

            print('Loading additional LLaVA weights...')
            if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
                non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
            else:
                # this is probably from HF Hub
                from huggingface_hub import hf_hub_download
                def load_from_hf(repo_id, filename, subfolder=None):
                    cache_file = hf_hub_download(
                        repo_id=repo_id,
                        filename=filename,
                        subfolder=subfolder)
                    return torch.load(cache_file, map_location='cpu')
                non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
            non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
            if any(k.startswith('model.model.') for k in non_lora_trainables):
                non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
            model.load_state_dict(non_lora_trainables, strict=False)

            from peft import PeftModel
            print('Loading LoRA weights...')
            model = PeftModel.from_pretrained(model, model_path)
            print('Merging LoRA weights...')
            model = model.merge_and_unload()
            print('Model is loaded...')
        elif model_base is not None:
            # this may be mm projector only
            print('Loading LLaVA from base model...')
            if 'mpt' in model_name.lower():
                if not os.path.isfile(os.path.join(model_path, 'configuration_mpt.py')):
                    shutil.copyfile(os.path.join(model_base, 'configuration_mpt.py'), os.path.join(model_path, 'configuration_mpt.py'))
                tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
                cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
                model = LlavaMptForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
            else:
                tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
                cfg_pretrained = AutoConfig.from_pretrained(model_path)
                model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)

            mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
            mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
            model.load_state_dict(mm_projector_weights, strict=False)
        else:
            if 'mpt' in model_name.lower():
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
                model = LlavaMptForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
            elif 'mistral' in model_name.lower():
                tokenizer = AutoTokenizer.from_pretrained(model_path)
                model = LlavaMistralForCausalLM.from_pretrained(
                    model_path,
                    low_cpu_mem_usage=True,
                    **kwargs
                )
            else:
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
                model = LlavaLlamaForCausalLM.from_pretrained(
                    model_path,
                    low_cpu_mem_usage=True,
                    **kwargs
                )
    else:
        # Load language model
        if model_base is not None:
            # PEFT model
            from peft import PeftModel
            tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
            model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)
            print(f"Loading LoRA weights from {model_path}")
            model = PeftModel.from_pretrained(model, model_path)
            print(f"Merging weights")
            model = model.merge_and_unload()
            print('Convert to FP16...')
            model.to(torch.float16)
        else:
            use_fast = False
            if 'mpt' in model_name.lower():
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
            +
                            model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
         | 
| 140 | 
            +
                        else:
         | 
| 141 | 
            +
                            tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
         | 
| 142 | 
            +
                            model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                image_processor = None
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                if llava_type_model:
         | 
| 147 | 
            +
                    mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
         | 
| 148 | 
            +
                    mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
         | 
| 149 | 
            +
                    if mm_use_im_patch_token:
         | 
| 150 | 
            +
                        tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
         | 
| 151 | 
            +
                    if mm_use_im_start_end:
         | 
| 152 | 
            +
                        tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
         | 
| 153 | 
            +
                    model.resize_token_embeddings(len(tokenizer))
         | 
| 154 | 
            +
             | 
| 155 | 
            +
                    vision_tower = model.get_vision_tower()
         | 
| 156 | 
            +
                    if not vision_tower.is_loaded:
         | 
| 157 | 
            +
                        vision_tower.load_model(device_map=device_map)
         | 
| 158 | 
            +
                    if device_map != 'auto':
         | 
| 159 | 
            +
                        vision_tower.to(device=device_map, dtype=torch.float16)
         | 
| 160 | 
            +
                    image_processor = vision_tower.image_processor
         | 
| 161 | 
            +
             | 
| 162 | 
            +
                if hasattr(model.config, "max_sequence_length"):
         | 
| 163 | 
            +
                    context_len = model.config.max_sequence_length
         | 
| 164 | 
            +
                else:
         | 
| 165 | 
            +
                    context_len = 2048
         | 
| 166 | 
            +
             | 
| 167 | 
            +
                return tokenizer, model, image_processor, context_len
         | 
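A minimal usage sketch for the loader above, assuming the upstream LLaVA call signature and that get_model_name_from_path ships in the bundled mm_utils (the checkpoint path is a placeholder):

    from llava.model.builder import load_pretrained_model
    from llava.mm_utils import get_model_name_from_path

    model_path = "path/to/llava-style-checkpoint"  # placeholder
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        model_path=model_path,
        model_base=None,  # set to the base LM path for LoRA / projector-only checkpoints
        model_name=get_model_name_from_path(model_path),
    )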
    	
        dialoggen/llava/model/consolidate.py
    ADDED
    
@@ -0,0 +1,29 @@
"""
Usage:
python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate
"""
import argparse

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from llava.model import *
from llava.model.utils import auto_upgrade


def consolidate_ckpt(src_path, dst_path):
    print("Loading model")
    auto_upgrade(src_path)
    src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
    src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
    src_model.save_pretrained(dst_path)
    src_tokenizer.save_pretrained(dst_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--src", type=str, required=True)
    parser.add_argument("--dst", type=str, required=True)

    args = parser.parse_args()

    consolidate_ckpt(args.src, args.dst)
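The docstring's CLI invocation is the intended entry point; the function can equally be called directly. A sketch with placeholder paths (expanded manually, since from_pretrained does not expand '~'):

    import os.path
    from llava.model.consolidate import consolidate_ckpt

    consolidate_ckpt(os.path.expanduser("~/model_weights/llava-7b"),
                     os.path.expanduser("~/model_weights/llava-7b_consolidate"))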
    	
        dialoggen/llava/model/language_model/llava_llama.py
    ADDED
    
@@ -0,0 +1,158 @@
#    Copyright 2023 Haotian Liu
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.


from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn

from transformers import AutoConfig, AutoModelForCausalLM, \
                         LlamaConfig, LlamaModel, LlamaForCausalLM

from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerateOutput

from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM


class LlavaConfig(LlamaConfig):
    model_type = "llava_llama"


class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
    config_class = LlavaConfig

    def __init__(self, config: LlamaConfig):
        super(LlavaLlamaModel, self).__init__(config)


class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
    config_class = LlavaConfig

    def __init__(self, config):
        super(LlamaForCausalLM, self).__init__(config)
        self.model = LlavaLlamaModel(config)
        self.pretraining_tp = config.pretraining_tp
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:

        if inputs_embeds is None:
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                images,
                image_sizes
            )

        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        if images is not None:
            (
                inputs,
                position_ids,
                attention_mask,
                _,
                inputs_embeds,
                _
            ) = self.prepare_inputs_labels_for_multimodal(
                inputs,
                position_ids,
                attention_mask,
                None,
                None,
                images,
                image_sizes=image_sizes
            )
        else:
            inputs_embeds = self.get_model().embed_tokens(inputs)

        return super().generate(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
                                      inputs_embeds=None, **kwargs):
        images = kwargs.pop("images", None)
        image_sizes = kwargs.pop("image_sizes", None)
        inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if images is not None:
            inputs['images'] = images
        if image_sizes is not None:
            inputs['image_sizes'] = image_sizes
        return inputs

AutoConfig.register("llava_llama", LlavaConfig)
AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
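The two register() calls above are what let a checkpoint whose config declares model_type "llava_llama" load through the plain Auto classes once this module has been imported. A sketch with a placeholder path:

    import torch
    from transformers import AutoConfig, AutoModelForCausalLM
    import llava.model.language_model.llava_llama  # executes the register() calls

    ckpt = "path/to/llava-llama-checkpoint"  # placeholder
    assert AutoConfig.from_pretrained(ckpt).model_type == "llava_llama"
    model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16)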
    	
        dialoggen/llava/model/language_model/llava_mistral.py
    ADDED
    
@@ -0,0 +1,158 @@
#    Copyright 2023 Haotian Liu
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.


from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss

from transformers import AutoConfig, AutoModelForCausalLM, \
                         MistralConfig, MistralModel, MistralForCausalLM

from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerateOutput

from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM


class LlavaMistralConfig(MistralConfig):
    model_type = "llava_mistral"


class LlavaMistralModel(LlavaMetaModel, MistralModel):
    config_class = LlavaMistralConfig

    def __init__(self, config: MistralConfig):
        super(LlavaMistralModel, self).__init__(config)


class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM):
    config_class = LlavaMistralConfig

    def __init__(self, config):
        super(MistralForCausalLM, self).__init__(config)
        self.model = LlavaMistralModel(config)

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:

        if inputs_embeds is None:
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                images,
                image_sizes
            )

        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        if images is not None:
            (
                inputs,
                position_ids,
                attention_mask,
                _,
                inputs_embeds,
                _
            ) = self.prepare_inputs_labels_for_multimodal(
                inputs,
                position_ids,
                attention_mask,
                None,
                None,
                images,
                image_sizes=image_sizes
            )
        else:
            inputs_embeds = self.get_model().embed_tokens(inputs)

        return super().generate(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
                                      inputs_embeds=None, **kwargs):
        images = kwargs.pop("images", None)
        image_sizes = kwargs.pop("image_sizes", None)
        inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if images is not None:
            inputs['images'] = images
        if image_sizes is not None:
            inputs['image_sizes'] = image_sizes
        return inputs

AutoConfig.register("llava_mistral", LlavaMistralConfig)
AutoModelForCausalLM.register(LlavaMistralConfig, LlavaMistralForCausalLM)
    	
        dialoggen/llava/model/language_model/llava_mpt.py
    ADDED
    
@@ -0,0 +1,97 @@
#    Copyright 2023 Haotian Liu
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.


from typing import Optional, Tuple

import torch

from transformers import AutoConfig, AutoModelForCausalLM, \
                         MptConfig, MptForCausalLM, MptModel
from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM


class LlavaMptConfig(MptConfig):
    model_type = "llava_mpt"


class LlavaMptModel(LlavaMetaModel, MptModel):
    config_class = LlavaMptConfig

    def __init__(self, config: MptConfig):
        # MPT configs name the embedding width d_model; alias it to hidden_size
        # so the shared LLaVA code can stay model-agnostic.
        config.hidden_size = config.d_model
        super(LlavaMptModel, self).__init__(config)

    def embed_tokens(self, x):
        return self.wte(x)


class LlavaMptForCausalLM(MptForCausalLM, LlavaMetaForCausalLM):
    config_class = LlavaMptConfig
    supports_gradient_checkpointing = True

    def __init__(self, config):
        super(MptForCausalLM, self).__init__(config)

        self.transformer = LlavaMptModel(config)
        self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        return self.transformer

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, LlavaMptModel):
            module.gradient_checkpointing = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        images=None):

        # prepare_inputs_labels_for_multimodal (llava_arch.py) takes position_ids as
        # its second argument and returns six values; MPT has no position_ids, so pass
        # None and discard the returned value.
        input_ids, _, attention_mask, past_key_values, inputs_embeds, labels = \
            self.prepare_inputs_labels_for_multimodal(input_ids, None, attention_mask, past_key_values, labels, images)

        return super().forward(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        images = kwargs.pop("images", None)
        _inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        _inputs['images'] = images
        return _inputs


AutoConfig.register("llava_mpt", LlavaMptConfig)
AutoModelForCausalLM.register(LlavaMptConfig, LlavaMptForCausalLM)
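supports_gradient_checkpointing together with the _set_gradient_checkpointing hook tie this class into the standard transformers toggle; on transformers versions that still dispatch through the per-module hook, the call below walks the module tree and flips the flag on LlavaMptModel (the path is a placeholder):

    model = LlavaMptForCausalLM.from_pretrained("path/to/llava-mpt-checkpoint")
    model.gradient_checkpointing_enable()  # routed through _set_gradient_checkpointing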
    	
        dialoggen/llava/model/llava_arch.py
    ADDED
    
@@ -0,0 +1,368 @@
#    Copyright 2023 Haotian Liu
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.


from abc import ABC, abstractmethod

import torch
import torch.nn as nn

from .multimodal_encoder.builder import build_vision_tower
from .multimodal_projector.builder import build_vision_projector

from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN

from llava.mm_utils import get_anyres_image_grid_shape


class LlavaMetaModel:

    def __init__(self, config):
        super(LlavaMetaModel, self).__init__(config)

        if hasattr(config, "mm_vision_tower"):
            self.vision_tower = build_vision_tower(config, delay_load=True)
            self.mm_projector = build_vision_projector(config)

            if 'unpad' in getattr(config, 'mm_patch_merge_type', ''):
                self.image_newline = nn.Parameter(
                    torch.empty(config.hidden_size, dtype=self.dtype)
                )

    def get_vision_tower(self):
        vision_tower = getattr(self, 'vision_tower', None)
        if type(vision_tower) is list:
            vision_tower = vision_tower[0]
        return vision_tower

    def initialize_vision_modules(self, model_args, fsdp=None):
        vision_tower = model_args.vision_tower
        mm_vision_select_layer = model_args.mm_vision_select_layer
        mm_vision_select_feature = model_args.mm_vision_select_feature
        pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
        mm_patch_merge_type = model_args.mm_patch_merge_type

        self.config.mm_vision_tower = vision_tower

        if self.get_vision_tower() is None:
            vision_tower = build_vision_tower(model_args)

            if fsdp is not None and len(fsdp) > 0:
                self.vision_tower = [vision_tower]
            else:
                self.vision_tower = vision_tower
        else:
            if fsdp is not None and len(fsdp) > 0:
                vision_tower = self.vision_tower[0]
            else:
                vision_tower = self.vision_tower
            vision_tower.load_model()

        self.config.use_mm_proj = True
        self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
        self.config.mm_hidden_size = vision_tower.hidden_size
        self.config.mm_vision_select_layer = mm_vision_select_layer
        self.config.mm_vision_select_feature = mm_vision_select_feature
        self.config.mm_patch_merge_type = mm_patch_merge_type

        if getattr(self, 'mm_projector', None) is None:
            self.mm_projector = build_vision_projector(self.config)

            if 'unpad' in mm_patch_merge_type:
                embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
                self.image_newline = nn.Parameter(
                    torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
                )
        else:
            # In case it is frozen by LoRA
            for p in self.mm_projector.parameters():
                p.requires_grad = True

        if pretrain_mm_mlp_adapter is not None:
            mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
            def get_w(weights, keyword):
                return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}

            self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
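        # Illustrative get_w behavior (key names invented for the example): a dict like
        #   {'model.mm_projector.0.weight': W, 'model.mm_projector.2.bias': b}
        # becomes {'0.weight': W, '2.bias': b}, i.e. keys relative to the projector module.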


def unpad_image(tensor, original_size):
    """
    Unpads a PyTorch tensor of a padded and resized image.

    Args:
    tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
    original_size (tuple): The original size of the image (width, height).

    Returns:
    torch.Tensor: The unpadded image tensor.
    """
    original_width, original_height = original_size
    current_height, current_width = tensor.shape[1:]

    original_aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height

    if original_aspect_ratio > current_aspect_ratio:
        scale_factor = current_width / original_width
        new_height = int(original_height * scale_factor)
        padding = (current_height - new_height) // 2
        unpadded_tensor = tensor[:, padding:current_height - padding, :]
    else:
        scale_factor = current_height / original_height
        new_width = int(original_width * scale_factor)
        padding = (current_width - new_width) // 2
        unpadded_tensor = tensor[:, :, padding:current_width - padding]

    return unpadded_tensor
            +
             | 
| 130 | 
            +
             | 
| 131 | 
            +
            class LlavaMetaForCausalLM(ABC):
         | 
| 132 | 
            +
             | 
| 133 | 
            +
                @abstractmethod
         | 
| 134 | 
            +
                def get_model(self):
         | 
| 135 | 
            +
                    pass
         | 
| 136 | 
            +
             | 
| 137 | 
            +
                def get_vision_tower(self):
         | 
| 138 | 
            +
                    return self.get_model().get_vision_tower()
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                def encode_images(self, images):
         | 
| 141 | 
            +
                    image_features = self.get_model().get_vision_tower()(images)
         | 
| 142 | 
            +
                    image_features = self.get_model().mm_projector(image_features)
         | 
| 143 | 
            +
                    return image_features
         | 
| 144 | 
            +
             | 
| 145 | 
            +
                def prepare_inputs_labels_for_multimodal(
         | 
| 146 | 
            +
                    self, input_ids, position_ids, attention_mask, past_key_values, labels,
         | 
| 147 | 
            +
                    images, image_sizes=None
         | 
| 148 | 
            +
                ):
         | 
| 149 | 
            +
                    vision_tower = self.get_vision_tower()
         | 
| 150 | 
            +
                    if vision_tower is None or images is None or input_ids.shape[1] == 1:
         | 
| 151 | 
            +
                        return input_ids, position_ids, attention_mask, past_key_values, None, labels
         | 
| 152 | 
            +
             | 
| 153 | 
            +
                    if type(images) is list or images.ndim == 5:
         | 
| 154 | 
            +
                        if type(images) is list:
         | 
| 155 | 
            +
                            images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
         | 
| 156 | 
            +
                        concat_images = torch.cat([image for image in images], dim=0)
         | 
| 157 | 
            +
                        image_features = self.encode_images(concat_images)
         | 
| 158 | 
            +
                        split_sizes = [image.shape[0] for image in images]
         | 
| 159 | 
            +
                        image_features = torch.split(image_features, split_sizes, dim=0)
         | 
| 160 | 
            +
                        mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type', 'flat')
         | 
| 161 | 
            +
                        image_aspect_ratio = getattr(self.config, 'image_aspect_ratio', 'square')
         | 
| 162 | 
            +
                        if mm_patch_merge_type == 'flat':
         | 
| 163 | 
            +
                            image_features = [x.flatten(0, 1) for x in image_features]
         | 
| 164 | 
            +
                        elif mm_patch_merge_type.startswith('spatial'):
         | 
| 165 | 
            +
                            new_image_features = []
         | 
| 166 | 
            +
                            for image_idx, image_feature in enumerate(image_features):
         | 
| 167 | 
            +
                                if image_feature.shape[0] > 1:
         | 
| 168 | 
            +
                                    base_image_feature = image_feature[0]
         | 
| 169 | 
            +
                                    image_feature = image_feature[1:]
         | 
| 170 | 
            +
                                    height = width = self.get_vision_tower().num_patches_per_side
         | 
| 171 | 
            +
                                    assert height * width == base_image_feature.shape[0]
         | 
| 172 | 
            +
                                    if image_aspect_ratio == 'anyres':
         | 
| 173 | 
            +
                                        num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, self.get_vision_tower().config.image_size)
         | 
| 174 | 
            +
                                        image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
         | 
| 175 | 
            +
                                    else:
         | 
| 176 | 
            +
                                        raise NotImplementedError
         | 
| 177 | 
            +
                                    if 'unpad' in mm_patch_merge_type:
         | 
| 178 | 
            +
                                        image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
         | 
| 179 | 
            +
                                        image_feature = image_feature.flatten(1, 2).flatten(2, 3)
         | 
| 180 | 
            +
                                        image_feature = unpad_image(image_feature, image_sizes[image_idx])
         | 
| 181 | 
            +
                                        image_feature = torch.cat((
         | 
| 182 | 
            +
                                            image_feature,
         | 
| 183 | 
            +
                                            self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)
         | 
| 184 | 
            +
                                        ), dim=-1)
         | 
| 185 | 
            +
                                        image_feature = image_feature.flatten(1, 2).transpose(0, 1)
         | 
| 186 | 
            +
                                    else:
         | 
| 187 | 
            +
                                        image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
         | 
| 188 | 
            +
                                        image_feature = image_feature.flatten(0, 3)
         | 
| 189 | 
            +
                                    image_feature = torch.cat((base_image_feature, image_feature), dim=0)
         | 
| 190 | 
            +
                                else:
         | 
| 191 | 
            +
                                    image_feature = image_feature[0]
         | 
| 192 | 
            +
                                    if 'unpad' in mm_patch_merge_type:
         | 
| 193 | 
            +
                                        image_feature = torch.cat((
         | 
| 194 | 
            +
                                            image_feature,
         | 
| 195 | 
            +
                                            self.model.image_newline[None].to(image_feature.device)
         | 
| 196 | 
            +
                                        ), dim=0)
         | 
| 197 | 
            +
                                new_image_features.append(image_feature)
         | 
| 198 | 
            +
                            image_features = new_image_features
         | 
| 199 | 
            +
                        else:
         | 
| 200 | 
            +
                            raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
         | 
| 201 | 
            +
                    else:
         | 
| 202 | 
            +
                        image_features = self.encode_images(images)
         | 
| 203 | 
            +
             | 
| 204 | 
            +
                    # TODO: image start / end is not implemented here to support pretraining.
         | 
| 205 | 
            +
                    if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
         | 
| 206 | 
            +
                        raise NotImplementedError
         | 
| 207 | 
            +
             | 
| 208 | 
            +
                    # Let's just add dummy tensors if they do not exist,
         | 
| 209 | 
            +
                    # it is a headache to deal with None all the time.
         | 
| 210 | 
            +
                    # But it is not ideal, and if you have a better idea,
         | 
| 211 | 
            +
                    # please open an issue / submit a PR, thanks.
         | 
| 212 | 
            +
                    _labels = labels
         | 
| 213 | 
            +
                    _position_ids = position_ids
         | 
| 214 | 
            +
                    _attention_mask = attention_mask
         | 
| 215 | 
            +
                    if attention_mask is None:
         | 
| 216 | 
            +
                        attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
         | 
| 217 | 
            +
                    else:
         | 
| 218 | 
            +
                        attention_mask = attention_mask.bool()
         | 
| 219 | 
            +
                    if position_ids is None:
         | 
| 220 | 
            +
                        position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
         | 
| 221 | 
            +
                    if labels is None:
         | 
| 222 | 
            +
                        labels = torch.full_like(input_ids, IGNORE_INDEX)
         | 
| 223 | 
            +
             | 
| 224 | 
            +
                    # remove the padding using attention_mask -- FIXME
         | 
| 225 | 
            +
                    _input_ids = input_ids
         | 
| 226 | 
            +
                    input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
         | 
| 227 | 
            +
                    labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
         | 
| 228 | 
            +
             | 
| 229 | 
            +
                    new_input_embeds = []
         | 
| 230 | 
            +
                    new_labels = []
         | 
| 231 | 
            +
                    cur_image_idx = 0
         | 
| 232 | 
            +
                    for batch_idx, cur_input_ids in enumerate(input_ids):
         | 
| 233 | 
            +
                        num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
         | 
| 234 | 
            +
                        if num_images == 0:
         | 
| 235 | 
            +
                            cur_image_features = image_features[cur_image_idx]
         | 
| 236 | 
            +
                            cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
         | 
| 237 | 
            +
                            cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
         | 
| 238 | 
            +
                            new_input_embeds.append(cur_input_embeds)
         | 
| 239 | 
            +
                            new_labels.append(labels[batch_idx])
         | 
| 240 | 
            +
                            cur_image_idx += 1
         | 
| 241 | 
            +
                            continue
         | 
| 242 | 
            +
             | 
| 243 | 
            +
                        image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
         | 
| 244 | 
            +
                        cur_input_ids_noim = []
         | 
| 245 | 
            +
                        cur_labels = labels[batch_idx]
         | 
| 246 | 
            +
                        cur_labels_noim = []
         | 
| 247 | 
            +
                        for i in range(len(image_token_indices) - 1):
         | 
| 248 | 
            +
                            cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
         | 
| 249 | 
            +
                            cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
         | 
| 250 | 
            +
                        split_sizes = [x.shape[0] for x in cur_labels_noim]
         | 
| 251 | 
            +
                        cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
         | 
| 252 | 
            +
                        cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
         | 
| 253 | 
            +
                        cur_new_input_embeds = []
         | 
| 254 | 
            +
                        cur_new_labels = []
         | 
| 255 | 
            +
             | 
| 256 | 
            +
                        for i in range(num_images + 1):
         | 
| 257 | 
            +
                            cur_new_input_embeds.append(cur_input_embeds_no_im[i])
         | 
| 258 | 
            +
                            cur_new_labels.append(cur_labels_noim[i])
         | 
| 259 | 
            +
                            if i < num_images:
         | 
| 260 | 
            +
                                cur_image_features = image_features[cur_image_idx]
         | 
| 261 | 
            +
                                cur_image_idx += 1
         | 
| 262 | 
            +
                                cur_new_input_embeds.append(cur_image_features)
         | 
| 263 | 
            +
                                cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
         | 
| 264 | 
            +
             | 
| 265 | 
            +
                        cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
         | 
| 266 | 
            +
             | 
| 267 | 
            +
                        cur_new_input_embeds = torch.cat(cur_new_input_embeds)
         | 
| 268 | 
            +
                        cur_new_labels = torch.cat(cur_new_labels)
         | 
| 269 | 
            +
             | 
| 270 | 
            +
                        new_input_embeds.append(cur_new_input_embeds)
         | 
| 271 | 
            +
                        new_labels.append(cur_new_labels)
         | 
| 272 | 
            +
             | 
| 273 | 
            +
                    # Truncate sequences to max length as image embeddings can make the sequence longer
         | 
| 274 | 
            +
                    tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
         | 
| 275 | 
            +
                    if tokenizer_model_max_length is not None:
         | 
| 276 | 
            +
                        new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
         | 
| 277 | 
            +
                        new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
         | 
| 278 | 
            +
             | 
| 279 | 
            +
                    # Combine them
         | 
| 280 | 
            +
                    max_len = max(x.shape[0] for x in new_input_embeds)
         | 
| 281 | 
            +
                    batch_size = len(new_input_embeds)
         | 
| 282 | 
            +
             | 
| 283 | 
            +
                    new_input_embeds_padded = []
         | 
| 284 | 
            +
                    new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
         | 
| 285 | 
            +
                    attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
         | 
| 286 | 
            +
                    position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
         | 
| 287 | 
            +
             | 
| 288 | 
            +
                    for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
         | 
| 289 | 
            +
                        cur_len = cur_new_embed.shape[0]
         | 
| 290 | 
            +
                        if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
         | 
| 291 | 
            +
                            new_input_embeds_padded.append(torch.cat((
         | 
| 292 | 
            +
                                torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
         | 
| 293 | 
            +
                                cur_new_embed
         | 
| 294 | 
            +
                            ), dim=0))
         | 
| 295 | 
            +
                            if cur_len > 0:
         | 
| 296 | 
            +
                                new_labels_padded[i, -cur_len:] = cur_new_labels
         | 
| 297 | 
            +
                                attention_mask[i, -cur_len:] = True
         | 
| 298 | 
            +
                                position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
         | 
| 299 | 
            +
                        else:
         | 
| 300 | 
            +
                            new_input_embeds_padded.append(torch.cat((
         | 
| 301 | 
            +
                                cur_new_embed,
         | 
| 302 | 
            +
                                torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
         | 
| 303 | 
            +
                            ), dim=0))
         | 
| 304 | 
            +
                            if cur_len > 0:
         | 
| 305 | 
            +
                                new_labels_padded[i, :cur_len] = cur_new_labels
         | 
| 306 | 
            +
                                attention_mask[i, :cur_len] = True
         | 
| 307 | 
            +
                                position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
         | 
| 308 | 
            +
             | 
| 309 | 
            +
                    new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
         | 
| 310 | 
            +
             | 
| 311 | 
            +
                    if _labels is None:
         | 
| 312 | 
            +
                        new_labels = None
         | 
| 313 | 
            +
                    else:
         | 
| 314 | 
            +
                        new_labels = new_labels_padded
         | 
| 315 | 
            +
             | 
| 316 | 
            +
                    if _attention_mask is None:
         | 
| 317 | 
            +
                        attention_mask = None
         | 
| 318 | 
            +
                    else:
         | 
| 319 | 
            +
                        attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
         | 
| 320 | 
            +
             | 
| 321 | 
            +
                    if _position_ids is None:
         | 
| 322 | 
            +
                        position_ids = None
         | 
| 323 | 
            +
             | 
| 324 | 
            +
                    return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
         | 
| 325 | 
            +
             | 
| 326 | 
            +
                def initialize_vision_tokenizer(self, model_args, tokenizer):
         | 
| 327 | 
            +
                    if model_args.mm_use_im_patch_token:
         | 
| 328 | 
            +
                        tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
         | 
| 329 | 
            +
                        self.resize_token_embeddings(len(tokenizer))
         | 
| 330 | 
            +
             | 
| 331 | 
            +
                    if model_args.mm_use_im_start_end:
         | 
| 332 | 
            +
                        num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
         | 
| 333 | 
            +
                        self.resize_token_embeddings(len(tokenizer))
         | 
| 334 | 
            +
             | 
| 335 | 
            +
                        if num_new_tokens > 0:
         | 
| 336 | 
            +
                            input_embeddings = self.get_input_embeddings().weight.data
         | 
| 337 | 
            +
                            output_embeddings = self.get_output_embeddings().weight.data
         | 
| 338 | 
            +
             | 
| 339 | 
            +
                            input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
         | 
| 340 | 
            +
                                dim=0, keepdim=True)
         | 
| 341 | 
            +
                            output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
         | 
| 342 | 
            +
                                dim=0, keepdim=True)
         | 
| 343 | 
            +
             | 
| 344 | 
            +
                            input_embeddings[-num_new_tokens:] = input_embeddings_avg
         | 
| 345 | 
            +
                            output_embeddings[-num_new_tokens:] = output_embeddings_avg
         | 
| 346 | 
            +
             | 
| 347 | 
            +
                        if model_args.tune_mm_mlp_adapter:
         | 
| 348 | 
            +
                            for p in self.get_input_embeddings().parameters():
         | 
| 349 | 
            +
                                p.requires_grad = True
         | 
| 350 | 
            +
                            for p in self.get_output_embeddings().parameters():
         | 
| 351 | 
            +
                                p.requires_grad = False
         | 
| 352 | 
            +
             | 
| 353 | 
            +
                        if model_args.pretrain_mm_mlp_adapter:
         | 
| 354 | 
            +
                            mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
         | 
| 355 | 
            +
                            embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
         | 
| 356 | 
            +
                            assert num_new_tokens == 2
         | 
| 357 | 
            +
                            if input_embeddings.shape == embed_tokens_weight.shape:
         | 
| 358 | 
            +
                                input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
         | 
| 359 | 
            +
                            elif embed_tokens_weight.shape[0] == num_new_tokens:
         | 
| 360 | 
            +
                                input_embeddings[-num_new_tokens:] = embed_tokens_weight
         | 
| 361 | 
            +
                            else:
         | 
| 362 | 
            +
                                raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
         | 
| 363 | 
            +
                    elif model_args.mm_use_im_patch_token:
         | 
| 364 | 
            +
                        if model_args.tune_mm_mlp_adapter:
         | 
| 365 | 
            +
                            for p in self.get_input_embeddings().parameters():
         | 
| 366 | 
            +
                                p.requires_grad = False
         | 
| 367 | 
            +
                            for p in self.get_output_embeddings().parameters():
         | 
| 368 | 
            +
                                p.requires_grad = False
         | 
    	
        dialoggen/llava/model/make_delta.py
    ADDED
    
@@ -0,0 +1,52 @@
"""
Usage:
python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta
"""
import argparse

import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from llava.model.utils import auto_upgrade


def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
    print("Loading base model")
    base = AutoModelForCausalLM.from_pretrained(
        base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)

    print("Loading target model")
    auto_upgrade(target_model_path)
    target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)

    print("Calculating delta")
    for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
        if name not in base.state_dict():
            assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
            continue
        if param.data.shape == base.state_dict()[name].shape:
            param.data -= base.state_dict()[name]
        else:
            assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
            bparam = base.state_dict()[name]
            param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam

    print("Saving delta")
    if hub_repo_id:
        kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
    else:
        kwargs = {}
    target.save_pretrained(delta_path, **kwargs)
    target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
    target_tokenizer.save_pretrained(delta_path, **kwargs)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--base-model-path", type=str, required=True)
    parser.add_argument("--target-model-path", type=str, required=True)
    parser.add_argument("--delta-path", type=str, required=True)
    parser.add_argument("--hub-repo-id", type=str, default=None)
    args = parser.parse_args()

    make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)
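The per-parameter computation in make_delta above is just a subtraction in weight space, so the delta plus the base weights reconstructs the target exactly. A minimal self-contained sketch of that round trip on toy tensors (not real model weights):

    import torch

    base = torch.randn(4, 4)                   # stand-in for a base-model parameter
    target = base + 0.01 * torch.randn(4, 4)   # stand-in for the fine-tuned parameter

    delta = target - base                      # what make_delta stores
    recovered = base + delta                   # what applying the delta reconstructs
    assert torch.allclose(recovered, target)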
    	
        dialoggen/llava/model/multimodal_encoder/builder.py
    ADDED
    
@@ -0,0 +1,11 @@
import os
from .clip_encoder import CLIPVisionTower


def build_vision_tower(vision_tower_cfg, **kwargs):
    vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
    is_absolute_path_exists = os.path.exists(vision_tower)
    if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower:
        return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)

    raise ValueError(f'Unknown vision tower: {vision_tower}')
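A minimal sketch of driving this factory with a bare namespace config. The checkpoint name and attribute values are illustrative (the CLIP tower commonly paired with LLaVA-style models); with delay_load=True only the model config is fetched from the Hub, not the weights:

    from types import SimpleNamespace
    from llava.model.multimodal_encoder.builder import build_vision_tower

    cfg = SimpleNamespace(
        mm_vision_tower='openai/clip-vit-large-patch14-336',
        mm_vision_select_layer=-2,
        mm_vision_select_feature='patch',
    )
    tower = build_vision_tower(cfg, delay_load=True)  # defers weight loading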
    	
        dialoggen/llava/model/multimodal_encoder/clip_encoder.py
    ADDED
    
@@ -0,0 +1,88 @@
import torch
import torch.nn as nn

from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig


class CLIPVisionTower(nn.Module):
    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False

        self.vision_tower_name = vision_tower
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')

        if not delay_load:
            self.load_model()
        elif getattr(args, 'unfreeze_mm_vision_tower', False):
            self.load_model()
        else:
            self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)

    def load_model(self, device_map=None):
        if self.is_loaded:
            print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
            return

        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
        self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
        self.vision_tower.requires_grad_(False)

        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        image_features = image_forward_outs.hidden_states[self.select_layer]
        if self.select_feature == 'patch':
            image_features = image_features[:, 1:]
        elif self.select_feature == 'cls_patch':
            image_features = image_features
        else:
            raise ValueError(f'Unexpected select feature: {self.select_feature}')
        return image_features

    @torch.no_grad()
    def forward(self, images):
        if type(images) is list:
            image_features = []
            for image in images:
                image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
                image_feature = self.feature_select(image_forward_out).to(image.dtype)
                image_features.append(image_feature)
        else:
            image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
            image_features = self.feature_select(image_forward_outs).to(images.dtype)

        return image_features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

    @property
    def config(self):
        if self.is_loaded:
            return self.vision_tower.config
        else:
            return self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches_per_side(self):
        return self.config.image_size // self.config.patch_size

    @property
    def num_patches(self):
        return (self.config.image_size // self.config.patch_size) ** 2
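For concreteness, with the 336px ViT-L/14 geometry used as an example above, the two patch-count properties work out as follows; and because select_feature == 'patch' drops the CLS token, forward returns this many feature vectors per image:

    image_size, patch_size = 336, 14
    num_patches_per_side = image_size // patch_size   # 24
    num_patches = num_patches_per_side ** 2           # 576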
    	
        dialoggen/llava/model/multimodal_projector/builder.py
    ADDED
    
@@ -0,0 +1,51 @@
import torch
import torch.nn as nn
import re


class IdentityMap(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x

    @property
    def config(self):
        return {"mm_projector_type": 'identity'}


class SimpleResBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.pre_norm = nn.LayerNorm(channels)

        self.proj = nn.Sequential(
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels)
        )

    def forward(self, x):
        x = self.pre_norm(x)
        return x + self.proj(x)


def build_vision_projector(config, delay_load=False, **kwargs):
    projector_type = getattr(config, 'mm_projector_type', 'linear')

    if projector_type == 'linear':
        return nn.Linear(config.mm_hidden_size, config.hidden_size)

    mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
    if mlp_gelu_match:
        mlp_depth = int(mlp_gelu_match.group(1))
        modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
        for _ in range(1, mlp_depth):
            modules.append(nn.GELU())
            modules.append(nn.Linear(config.hidden_size, config.hidden_size))
        return nn.Sequential(*modules)

    if projector_type == 'identity':
        return IdentityMap()

    raise ValueError(f'Unknown projector type: {projector_type}')
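As an illustration of what the 'mlp2x_gelu' spelling expands to, here is a minimal sketch with hypothetical dimensions (1024-d vision features projected into a 4096-d language model):

    import torch
    from types import SimpleNamespace
    from llava.model.multimodal_projector.builder import build_vision_projector

    cfg = SimpleNamespace(mm_projector_type='mlp2x_gelu', mm_hidden_size=1024, hidden_size=4096)
    projector = build_vision_projector(cfg)
    # Equivalent to nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 4096))

    tokens = torch.randn(1, 576, 1024)       # one image's patch features
    print(projector(tokens).shape)           # torch.Size([1, 576, 4096])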
    	
        dialoggen/llava/model/utils.py
    ADDED
    
@@ -0,0 +1,20 @@
from transformers import AutoConfig


def auto_upgrade(config):
    cfg = AutoConfig.from_pretrained(config)
    if 'llava' in config and 'llava' not in cfg.model_type:
        assert cfg.model_type == 'llama'
        print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
        print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
        confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
        if confirm.lower() in ["y", "yes"]:
            print("Upgrading checkpoint...")
            assert len(cfg.architectures) == 1
            setattr(cfg.__class__, "model_type", "llava")
            cfg.architectures[0] = 'LlavaLlamaForCausalLM'
            cfg.save_pretrained(config)
            print("Checkpoint upgraded.")
        else:
            print("Checkpoint upgrade aborted.")
            exit(1)
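Usage is a single call. The path below is hypothetical; it must be a local checkpoint directory whose path contains 'llava' but whose saved config still reports the old 'llama' model_type:

    from llava.model.utils import auto_upgrade

    auto_upgrade('./checkpoints/llava-7b-v0')  # prompts for confirmation, then rewrites the saved config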
    	
        dialoggen/llava/utils.py
    ADDED
    
@@ -0,0 +1,126 @@
import datetime
import logging
import logging.handlers
import os
import sys

import requests

from llava.constants import LOGDIR

server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."

handler = None


def build_logger(logger_name, logger_filename):
    global handler

    formatter = logging.Formatter(
        fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Set the format of root handlers
    if not logging.getLogger().handlers:
        logging.basicConfig(level=logging.INFO)
    logging.getLogger().handlers[0].setFormatter(formatter)

    # Redirect stdout and stderr to loggers
    stdout_logger = logging.getLogger("stdout")
    stdout_logger.setLevel(logging.INFO)
    sl = StreamToLogger(stdout_logger, logging.INFO)
    sys.stdout = sl

    stderr_logger = logging.getLogger("stderr")
    stderr_logger.setLevel(logging.ERROR)
    sl = StreamToLogger(stderr_logger, logging.ERROR)
    sys.stderr = sl

    # Get logger
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.INFO)

    # Add a file handler for all loggers
    if handler is None:
        os.makedirs(LOGDIR, exist_ok=True)
        filename = os.path.join(LOGDIR, logger_filename)
        handler = logging.handlers.TimedRotatingFileHandler(
            filename, when='D', utc=True, encoding='UTF-8')
        handler.setFormatter(formatter)

        for name, item in logging.root.manager.loggerDict.items():
            if isinstance(item, logging.Logger):
                item.addHandler(handler)

    return logger


class StreamToLogger(object):
    """
    Fake file-like stream object that redirects writes to a logger instance.
    """
    def __init__(self, logger, log_level=logging.INFO):
        self.terminal = sys.stdout
        self.logger = logger
        self.log_level = log_level
        self.linebuf = ''

    def __getattr__(self, attr):
        return getattr(self.terminal, attr)

    def write(self, buf):
        temp_linebuf = self.linebuf + buf
        self.linebuf = ''
        for line in temp_linebuf.splitlines(True):
            # From the io.TextIOWrapper docs:
            #   On output, if newline is None, any '\n' characters written
            #   are translated to the system default line separator.
            # By default sys.stdout.write() expects '\n' newlines and then
            # translates them so this is still cross platform.
            if line[-1] == '\n':
                self.logger.log(self.log_level, line.rstrip())
            else:
                self.linebuf += line

    def flush(self):
        if self.linebuf != '':
            self.logger.log(self.log_level, self.linebuf.rstrip())
        self.linebuf = ''


def disable_torch_init():
    """
    Disable the redundant torch default initialization to accelerate model creation.
    """
    import torch
    setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
    setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)


def violates_moderation(text):
    """
    Check whether the text violates OpenAI moderation API.
    """
    url = "https://api.openai.com/v1/moderations"
    headers = {"Content-Type": "application/json",
               "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
    text = text.replace("\n", "")
    data = "{" + '"input": ' + f'"{text}"' + "}"
    data = data.encode("utf-8")
    try:
        ret = requests.post(url, headers=headers, data=data, timeout=5)
        flagged = ret.json()["results"][0]["flagged"]
    except requests.exceptions.RequestException:
        flagged = False
    except KeyError:
        flagged = False

    return flagged


def pretty_print_semaphore(semaphore):
    if semaphore is None:
        return "None"
    return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
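A minimal standalone sketch of the StreamToLogger redirection above (the same wiring build_logger performs, minus the rotating file handler). Note that basicConfig's default handler writes to stderr, so routing stdout through the logger does not recurse:

    import logging
    import sys
    from llava.utils import StreamToLogger

    logging.basicConfig(level=logging.INFO,
                        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s")
    sys.stdout = StreamToLogger(logging.getLogger("stdout"), logging.INFO)
    print("this print is emitted as a log record")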
    	
en.csv ADDED
@@ -0,0 +1,22 @@
+key,value
+size,Size
+sampler,Sampler
+prompt,Prompt
+default prompt,"A cute cat"
+negative_prompt,Negative Prompt
+seed,Seed
+cfg,CFG Scale
+infer steps,Sampling Steps
+batch size,Batch Size
+width cond,Width Cond
+height cond,Height Cond
+enhance,Prompt Enhancement
+run,Submit
+square,Square(1024x1024)
+landscape,Landscape(1280x768)
+portrait,Portrait(768x1280)
+accordion,Advanced Options
+generated image,HunYuanDiT Generated Image
+examples,More Examples
+title,Hunyuan-DiT
+desc,A Powerful Multi-Resolution Diffusion Transformer with Fine-Grained Chinese Understanding
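`en.csv` holds the key/value strings used as English UI labels, presumably for the demo app. A minimal sketch of how such a file can be read into a lookup dict (an illustrative loader, not the app's own code):

```python
# Illustrative loader for a key,value language file such as en.csv.
import csv


def load_lang(path: str) -> dict:
    with open(path, newline="", encoding="utf-8") as f:
        return {row["key"]: row["value"] for row in csv.DictReader(f)}


labels = load_lang("en.csv")
print(labels["title"])  # Hunyuan-DiT
```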
    	
environment.yml ADDED
@@ -0,0 +1,8 @@
+name: HunyuanDiT
+channels:
+  - pytorch
+  - nvidia
+dependencies:
+  - python=3.8.12
+  - pytorch=1.13.1
+  - pip
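The environment above can be created with `conda env create -f environment.yml` and activated with `conda activate HunyuanDiT` (the `name:` field); any remaining pip dependencies would then be installed into it.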
    	
example_prompts.txt ADDED
@@ -0,0 +1,28 @@
+一只聪明的狐狸走在阔叶树林里, 旁边是一条小溪, 细节真实, 摄影
+湖水清澈,天空湛蓝,阳光灿烂。一只优雅的白天鹅在湖边游泳。它周围有几只小鸭子,看起来非常可爱,整个画面给人一种宁静祥和的感觉。
+太阳微微升起,花园里的玫瑰花瓣上露珠晶莹剔透,一只瓢虫正在爬向露珠,背景是清晨的花园,微距镜头
+一位女明星,中国人,头发是黑色,衣服是纯白色短袖,人物风格清新,城市背景
+后印象主义风格,一条古老的石板路上面散落着金黄色的树叶。路旁的风车在静谧地转动,后面竖着两个风车。背景是一片向日葵田,蓝天上飘着几朵白云
+一幅细致的油画描绘了一只年轻獾轻轻嗅着一朵明亮的黄色玫瑰时错综复杂的皮毛。背景是一棵大树干的粗糙纹理,獾的爪子轻轻地挖进树皮。在柔和的背景中,一个宁静的瀑布倾泻而下,它的水在绿色植物中闪烁着蓝色。
+渔舟唱晚
+请将杞人忧天的样子画出来
+一只长靴猫手持亮银色的宝剑,身着铠甲,眼神坚毅,站在一堆金币上,背景是暗色调的洞穴,图像上有金币的光影点缀。
+插画风格,一只狐狸和一只刺猬坐在水边的石头上,刺猬手里拿着一杯茶,狐狸旁边放着一个玻璃杯。周围是茂密的绿色植物和树木,阳光透过树叶洒在水面上,画面宁静温馨。
+泥塑风格,一座五彩斑斓的花园在画面中展现,各种各样的花朵,绿色的叶子和一只正在嬉戏的小猫形成了一幅生动的图像,背景是蓝天和白云
+枯藤老树昏鸦,小桥流水人家
+一张细致的照片捕捉到了一尊雕像的形象,这尊雕像酷似一位古代法老,头上出人意料地戴着一副青铜蒸汽朋克护目镜。这座雕像穿着复古时髦,一件清爽的白色T恤和一件合身的黑色皮夹克,与传统的头饰形成鲜明对比。背景是简单的纯色,突出了雕像的非传统服装和蒸汽朋克眼镜的复杂细节。
+一朵鲜艳的红色玫瑰花,花瓣撒有一些水珠,晶莹剔透,特写镜头,
+一只可爱的猫, 细节真实, 摄影
+飞流直下三千尺,疑是银河落九天
+成语“鲤鱼跃龙门”
+一颗新鲜的草莓特写,红色的外表,表面布满许多种子,背景是淡绿色的叶子
+九寨沟
+摄影风格,在画面中心是一盘热气腾腾的麻婆豆腐,豆腐呈白色,上面撒着一层红色的辣酱,有些许绿色的葱花点缀,背景是深色木质餐桌,桌子上放有辣椒和葱花作为点缀。
+一位年轻女子站在春季的火车站月台上。她身着蓝灰色长风衣,白色衬衫。她的深棕色头发扎成低马尾,几缕碎发随风飘扬。她的眼神充满期待,阳光洒在她温暖的脸庞上。
+一只优雅的白鹤在湖边静静地站立,它的身体纯白色,翅膀轻轻展开,背景是湖面和远处的山脉
+国画风格,苏州园林中的小桥流水,周围是郁郁葱葱的树,池塘里有几朵绽放的荷花,背景是宁静的江南水乡
+现实主义风格,画面主要描述一个巴洛克风格的花瓶,带有金色的装饰边框,花瓶上盛开着各种色彩鲜艳的花,白色背景
+醉后不知天在水,满船清梦压星河
+长城
+一个亚洲中年男士在夕阳下的公园长椅上静坐。他穿着一件深蓝色的针织毛衣和灰色裤子。他的头发略显花白,手中拿着一本敞开的书。面带微笑,眼神温和,周围是落日余晖和四周的绿树。
+风格是写实,画面主要描述一个亚洲戏曲艺术家正在表演,她穿着华丽的戏服,脸上戴着精致的面具,身姿优雅,背景是古色古香的舞台,镜头是近景
    	
hydit/__init__.py ADDED
(empty file)
    	
hydit/config.py ADDED
@@ -0,0 +1,67 @@
+import argparse
+
+from .constants import *
+from .modules.models import HUNYUAN_DIT_CONFIG
+
+
+def get_args(default_args=None):
+    parser = argparse.ArgumentParser()
+
+    # Basic
+    parser.add_argument("--prompt", type=str, default="一只小猫", help="The prompt for generating images.")
+    parser.add_argument("--model-root", type=str, default="ckpts", help="Model root path.")
+    parser.add_argument("--image-size", type=int, nargs='+', default=[1024, 1024],
+                        help='Image size (h, w). If a single value is provided, the image will be treated as '
+                             '(value, value).')
+    parser.add_argument("--infer-mode", type=str, choices=["fa", "torch", "trt"], default="torch",
+                        help="Inference mode")
+
+    # HunYuan-DiT
+    parser.add_argument("--model", type=str, choices=list(HUNYUAN_DIT_CONFIG.keys()), default='DiT-g/2')
+    parser.add_argument("--norm", type=str, default="layer", help="Normalization layer type")
+    parser.add_argument("--load-key", type=str, choices=["ema", "module"], default="ema", help="Load model key for HunYuanDiT checkpoint.")
+    parser.add_argument('--size-cond', type=int, nargs='+', default=[1024, 1024],
+                        help="Size condition used in sampling. Two values are required for height and width. "
+                             "If a single value is provided, the image will be treated as (value, value).")
+    parser.add_argument("--cfg-scale", type=float, default=6.0, help="Classifier-free guidance scale.")
+
+    # Prompt enhancement
+    parser.add_argument("--enhance", action="store_true", help="Enhance prompt with DialogGen.")
+    parser.add_argument("--no-enhance", dest="enhance", action="store_false")
+    parser.set_defaults(enhance=True)
+
+    # Diffusion
+    parser.add_argument("--learn-sigma", action="store_true", help="Learn extra channels for sigma.")
+    parser.add_argument("--no-learn-sigma", dest="learn_sigma", action="store_false")
+    parser.set_defaults(learn_sigma=True)
+    parser.add_argument("--predict-type", type=str, choices=list(PREDICT_TYPE), default="v_prediction",
+                        help="Diffusion predict type")
+    parser.add_argument("--noise-schedule", type=str, choices=list(NOISE_SCHEDULES), default="scaled_linear",
+                        help="Noise schedule")
+    parser.add_argument("--beta-start", type=float, default=0.00085, help="Beta start value")
+    parser.add_argument("--beta-end", type=float, default=0.03, help="Beta end value")
+
+    # Text condition
+    parser.add_argument("--text-states-dim", type=int, default=1024, help="Hidden size of CLIP text encoder.")
+    parser.add_argument("--text-len", type=int, default=77, help="Token length of CLIP text encoder output.")
+    parser.add_argument("--text-states-dim-t5", type=int, default=2048, help="Hidden size of T5 text encoder.")
+    parser.add_argument("--text-len-t5", type=int, default=256, help="Token length of T5 text encoder output.")
+    parser.add_argument("--negative", type=str, default=None, help="Negative prompt.")
+
+    # Acceleration
+    parser.add_argument("--use_fp16", action="store_true", help="Use FP16 precision.")
+    parser.add_argument("--no-fp16", dest="use_fp16", action="store_false")
+    parser.set_defaults(use_fp16=True)
+
+    # Sampling
+    parser.add_argument("--batch-size", type=int, default=1, help="Per-GPU batch size")
+    parser.add_argument("--sampler", type=str, choices=SAMPLER_FACTORY, default="ddpm", help="Diffusion sampler")
+    parser.add_argument("--infer-steps", type=int, default=100, help="Inference steps")
+    parser.add_argument('--seed', type=int, default=42, help="A seed for all the prompts.")
+
+    # App
+    parser.add_argument("--lang", type=str, default="zh", choices=["zh", "en"], help="Language")
+
+    args = parser.parse_args(default_args)
+
+    return args
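Because `get_args` takes an optional `default_args` list, callers can build a config without touching `sys.argv`. A minimal sketch, assuming the repo and its dependencies (torch, etc.) are importable:

```python
# Illustrative only: construct the inference config programmatically.
from hydit.config import get_args

args = get_args([
    "--image-size", "1280", "768",   # parsed with nargs='+', so an (h, w) pair
    "--sampler", "ddim",
    "--infer-steps", "50",
    "--no-enhance",                  # flip the enhance=True default
])
print(args.image_size, args.sampler, args.enhance)  # [1280, 768] ddim False
```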
    	
hydit/constants.py ADDED
@@ -0,0 +1,62 @@
+# =======================================================
+NOISE_SCHEDULES = {
+    "linear",
+    "scaled_linear",
+    "squaredcos_cap_v2",
+}
+
+PREDICT_TYPE = {
+    "epsilon",
+    "sample",
+    "v_prediction",
+}
+
+# =======================================================
+NEGATIVE_PROMPT = '错误的眼睛,糟糕的人脸,毁容,糟糕的艺术,变形,多余的肢体,模糊的颜色,模糊,重复,病态,残缺,'  # "wrong eyes, bad face, disfigured, bad art, deformed, extra limbs, blurry colors, blurry, duplicate, morbid, mutilated,"
+
+
+# =======================================================
+# Constants about models
+# =======================================================
+
+SAMPLER_FACTORY = {
+    'ddpm': {
+        'scheduler': 'DDPMScheduler',
+        'name': 'DDPM',
+        'kwargs': {
+            'steps_offset': 1,
+            'clip_sample': False,
+            'clip_sample_range': 1.0,
+            'beta_schedule': 'scaled_linear',
+            'beta_start': 0.00085,
+            'beta_end': 0.03,
+            'prediction_type': 'v_prediction',
+        }
+    },
+    'ddim': {
+        'scheduler': 'DDIMScheduler',
+        'name': 'DDIM',
+        'kwargs': {
+            'steps_offset': 1,
+            'clip_sample': False,
+            'clip_sample_range': 1.0,
+            'beta_schedule': 'scaled_linear',
+            'beta_start': 0.00085,
+            'beta_end': 0.03,
+            'prediction_type': 'v_prediction',
+        }
+    },
+    'dpmms': {
+        'scheduler': 'DPMSolverMultistepScheduler',
+        'name': 'DPMMS',
+        'kwargs': {
+            'beta_schedule': 'scaled_linear',
+            'beta_start': 0.00085,
+            'beta_end': 0.03,
+            'prediction_type': 'v_prediction',
+            'trained_betas': None,
+            'solver_order': 2,
+            'algorithm_type': 'dpmsolver++',
+        }
+    },
+}
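Each `SAMPLER_FACTORY` entry names a `diffusers` scheduler class plus the constructor kwargs to build it with. A hedged sketch of how an entry can be resolved into a live scheduler (illustrative; not necessarily how the repo's inference code wires it up):

```python
# Illustrative: turn a SAMPLER_FACTORY entry into a diffusers scheduler instance.
import diffusers

from hydit.constants import SAMPLER_FACTORY

cfg = SAMPLER_FACTORY["ddim"]
scheduler_cls = getattr(diffusers, cfg["scheduler"])  # e.g. diffusers.DDIMScheduler
scheduler = scheduler_cls(**cfg["kwargs"])
print(cfg["name"], type(scheduler).__name__)          # DDIM DDIMScheduler
```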
    	
hydit/diffusion/__init__.py ADDED
(empty file)
    	
hydit/diffusion/pipeline.py ADDED
@@ -0,0 +1,830 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import PIL
+import numpy as np
+import torch
+import torchvision.transforms as T
+from diffusers.configuration_utils import FrozenDict
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+    PIL_INTERPOLATION,
+    deprecate,
+    logging,
+    replace_example_docstring,
+)
+from diffusers.utils.torch_utils import randn_tensor
+from transformers import BertModel, BertTokenizer
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from ..modules.models import HunYuanDiT
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```py
+        >>> import requests
+        >>> import torch
+        >>> from PIL import Image
+        >>> from io import BytesIO
+
+        >>> from diffusers import StableDiffusionImg2ImgPipeline
+
+        >>> device = "cuda"
+        >>> model_id_or_path = "runwayml/stable-diffusion-v1-5"
+        >>> pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
+        >>> pipe = pipe.to(device)
+
+        >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+
+        >>> response = requests.get(url)
+        >>> init_image = Image.open(BytesIO(response.content)).convert("RGB")
+        >>> init_image = init_image.resize((768, 512))
+
+        >>> prompt = "A fantasy landscape, trending on artstation"
+
+        >>> images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images
+        >>> images[0].save("fantasy_landscape.png")
+        ```
+"""
+
+
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+    """
+    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4.
+    """
+    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+    # rescale the results from guidance (fixes overexposure)
+    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+    return noise_cfg
+
+
+def preprocess(image):
+    deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead"
+    deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False)
+    if isinstance(image, torch.Tensor):
+        return image
+    elif isinstance(image, PIL.Image.Image):
+        image = [image]
+
+    if isinstance(image[0], PIL.Image.Image):
+        w, h = image[0].size
+        w, h = (x - x % 8 for x in (w, h))  # resize to integer multiple of 8
+
+        image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
+        image = np.concatenate(image, axis=0)
+        image = np.array(image).astype(np.float32) / 255.0
+        image = image.transpose(0, 3, 1, 2)
+        image = 2.0 * image - 1.0
+        image = torch.from_numpy(image)
+    elif isinstance(image[0], torch.Tensor):
+        image = torch.cat(image, dim=0)
+    return image
+
+
+class StableDiffusionPipeline(
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+):
+    r"""
+    Pipeline for text-to-image generation with HunYuanDiT, adapted from the Stable Diffusion pipeline.
+
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+    implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+    The pipeline also inherits the following loading methods:
+        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+        - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+        - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+
+    Args:
+        vae ([`AutoencoderKL`]):
+            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
+        text_encoder (Optional[`~transformers.BertModel`, `~transformers.CLIPTextModel`]):
+            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+        tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]):
+            A `BertTokenizer` or `CLIPTokenizer` to tokenize text.
+        unet (Optional[`HunYuanDiT`, `UNet2DConditionModel`]):
+            A `HunYuanDiT` or `UNet2DConditionModel` to denoise the encoded image latents.
+        scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+        safety_checker ([`StableDiffusionSafetyChecker`]):
+            Classification module that estimates whether generated images could be considered offensive or harmful.
+            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
+            about a model's potential harms.
+        feature_extractor ([`~transformers.CLIPImageProcessor`]):
+            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
+    """
+    model_cpu_offload_seq = "text_encoder->unet->vae"
+    _optional_components = ["safety_checker", "feature_extractor"]
+    _exclude_from_cpu_offload = ["safety_checker"]
+
+    def __init__(
+            self,
+            vae: AutoencoderKL,
+            text_encoder: Union[BertModel, CLIPTextModel],
+            tokenizer: Union[BertTokenizer, CLIPTokenizer],
+            unet: Union[HunYuanDiT, UNet2DConditionModel],
+            scheduler: KarrasDiffusionSchedulers,
+            safety_checker: StableDiffusionSafetyChecker,
+            feature_extractor: CLIPImageProcessor,
+            requires_safety_checker: bool = True,
+            progress_bar_config: Dict[str, Any] = None,
+            embedder_t5=None,
+            infer_mode='torch',
+    ):
+        super().__init__()
+
+        # ========================================================
+        self.embedder_t5 = embedder_t5
+        self.infer_mode = infer_mode
+
+        # ========================================================
+        if progress_bar_config is None:
+            progress_bar_config = {}
+        if not hasattr(self, '_progress_bar_config'):
+            self._progress_bar_config = {}
+        self._progress_bar_config.update(progress_bar_config)
+        # ========================================================
+
+        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+                "to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
+                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+                " file"
+            )
+            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(scheduler.config)
+            new_config["steps_offset"] = 1
+            scheduler._internal_dict = FrozenDict(new_config)
+
+        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+            )
+            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(scheduler.config)
+            new_config["clip_sample"] = False
+            scheduler._internal_dict = FrozenDict(new_config)
+
+        if safety_checker is None and requires_safety_checker:
+            logger.warning(
+                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+                " that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered"
+                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+                " strongly recommend keeping the safety filter enabled in all public facing circumstances, disabling"
+                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+            )
+
+        if safety_checker is not None and feature_extractor is None:
+            raise ValueError(
+                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+            )
+
+        self.register_modules(
+            vae=vae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            unet=unet,
+            scheduler=scheduler,
+            safety_checker=safety_checker,
+            feature_extractor=feature_extractor,
+        )
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+    def _encode_prompt(
+            self,
+            prompt,
+            device,
+            num_images_per_prompt,
+            do_classifier_free_guidance,
+            negative_prompt=None,
+            prompt_embeds: Optional[torch.FloatTensor] = None,
+            negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+            lora_scale: Optional[float] = None,
+    ):
+        deprecation_message = "`_encode_prompt()` is deprecated and will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
+        deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
+
+        prompt_embeds_tuple = self.encode_prompt(
+            prompt=prompt,
+            device=device,
+            num_images_per_prompt=num_images_per_prompt,
+            do_classifier_free_guidance=do_classifier_free_guidance,
+            negative_prompt=negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            lora_scale=lora_scale,
+        )
+
+        # concatenate for backwards comp
+        prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
+
+        return prompt_embeds
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
+    def encode_prompt(
+            self,
+            prompt,
+            device,
+            num_images_per_prompt,
+            do_classifier_free_guidance,
+            negative_prompt=None,
+            prompt_embeds: Optional[torch.FloatTensor] = None,
+            negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+            lora_scale: Optional[float] = None,
+            embedder=None,
+    ):
+        r"""
+        Encodes the prompt into text encoder hidden states.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                prompt to be encoded
+            device: (`torch.device`):
+                torch device
+            num_images_per_prompt (`int`):
+                number of images that should be generated per prompt
+            do_classifier_free_guidance (`bool`):
+                whether to use classifier free guidance or not
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            lora_scale (`float`, *optional*):
+                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+            embedder:
+                T5 embedder (including text encoder and tokenizer)
+        """
+        if embedder is None:
+            text_encoder = self.text_encoder
+            tokenizer = self.tokenizer
+            max_length = self.tokenizer.model_max_length
+        else:
+            text_encoder = embedder.model
+            tokenizer = embedder.tokenizer
+            max_length = embedder.max_length
+
+        # set lora scale so that monkey patched LoRA
+        # function of text encoder can correctly access it
+        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+            self._lora_scale = lora_scale
+
+            # dynamically adjust the LoRA scale
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        if prompt_embeds is None:
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+            text_inputs = tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=max_length,
+                truncation=True,
+                return_attention_mask=True,
+                return_tensors="pt",
+            )
+            text_input_ids = text_inputs.input_ids
+            untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                    text_input_ids, untruncated_ids
+            ):
+                removed_text = tokenizer.batch_decode(
+                    untruncated_ids[:, tokenizer.model_max_length - 1 : -1]
+                )
+                logger.warning(
+                    "The following part of your input was truncated because CLIP can only handle sequences up to"
+                    f" {tokenizer.model_max_length} tokens: {removed_text}"
+                )
+
+            attention_mask = text_inputs.attention_mask.to(device)
+            prompt_embeds = text_encoder(
+                text_input_ids.to(device),
+                attention_mask=attention_mask,
+            )
+            prompt_embeds = prompt_embeds[0]
+            attention_mask = attention_mask.repeat(num_images_per_prompt, 1)
+        else:
+            attention_mask = None
+
+        if text_encoder is not None:
+            prompt_embeds_dtype = text_encoder.dtype
+        elif self.unet is not None:
+            prompt_embeds_dtype = self.unet.dtype
+        else:
+            prompt_embeds_dtype = prompt_embeds.dtype
+
+        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+        # get unconditional embeddings for classifier free guidance
+        if do_classifier_free_guidance and negative_prompt_embeds is None:
+            uncond_tokens: List[str]
+            if negative_prompt is None:
+                uncond_tokens = [""] * batch_size
+            elif prompt is not None and type(prompt) is not type(negative_prompt):
+                raise TypeError(
+                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+                    f" {type(prompt)}."
+                )
+            elif isinstance(negative_prompt, str):
+                uncond_tokens = [negative_prompt]
+            elif batch_size != len(negative_prompt):
+                raise ValueError(
+                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                    " the batch size of `prompt`."
+                )
+            else:
+                uncond_tokens = negative_prompt
+
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer)
         | 
| 399 | 
            +
             | 
| 400 | 
            +
                        max_length = prompt_embeds.shape[1]
         | 
| 401 | 
            +
                        uncond_input = tokenizer(
         | 
| 402 | 
            +
                            uncond_tokens,
         | 
| 403 | 
            +
                            padding="max_length",
         | 
| 404 | 
            +
                            max_length=max_length,
         | 
| 405 | 
            +
                            truncation=True,
         | 
| 406 | 
            +
                            return_tensors="pt",
         | 
| 407 | 
            +
                        )
         | 
| 408 | 
            +
             | 
| 409 | 
            +
                        uncond_attention_mask = uncond_input.attention_mask.to(device)
         | 
| 410 | 
            +
                        negative_prompt_embeds = text_encoder(
         | 
| 411 | 
            +
                            uncond_input.input_ids.to(device),
         | 
| 412 | 
            +
                            attention_mask=uncond_attention_mask,
         | 
| 413 | 
            +
                        )
         | 
| 414 | 
            +
                        negative_prompt_embeds = negative_prompt_embeds[0]
         | 
| 415 | 
            +
                        uncond_attention_mask = uncond_attention_mask.repeat(num_images_per_prompt, 1)
         | 
| 416 | 
            +
                    else:
         | 
| 417 | 
            +
                        uncond_attention_mask = None
         | 
| 418 | 
            +
             | 
| 419 | 
            +
                    if do_classifier_free_guidance:
         | 
| 420 | 
            +
                        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
         | 
| 421 | 
            +
                        seq_len = negative_prompt_embeds.shape[1]
         | 
| 422 | 
            +
             | 
| 423 | 
            +
                        negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
         | 
| 424 | 
            +
             | 
| 425 | 
            +
                        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
         | 
| 426 | 
            +
                        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
         | 
| 427 | 
            +
             | 
| 428 | 
            +
                    return prompt_embeds, negative_prompt_embeds, attention_mask, uncond_attention_mask
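        # Shape walk-through (illustrative numbers, not from the original file):
        # one prompt with num_images_per_prompt=2, max_length=77 and hidden size
        # 1024 gives prompt_embeds of shape (1, 77, 1024), repeated and reshaped
        # to (2, 77, 1024); the attention masks are tiled along the batch
        # dimension in the same way.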

    def _convert_to_rgb(self, image):
        return image.convert('RGB')

    def image_transform(self, image_size=224):
        transform = T.Compose([
            T.Resize((image_size, image_size), interpolation=T.InterpolationMode.BICUBIC),
            self._convert_to_rgb,
            T.ToTensor(),
            T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ])
        return transform
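        # The Normalize constants above are the standard CLIP image mean/std.
        # Example (illustrative): self.image_transform(224)(pil_image) returns a
        # (3, 224, 224) float tensor ready for a CLIP-style image encoder.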

    def encode_img(self, img, device, do_classifier_free_guidance):
        # print('len', len(img))
        # print('img', img.size)
        img = img[0]    # TODO: support batch processing
        image_preprocess = self.image_transform(224)
        img_for_clip = image_preprocess(img)
        # print('img_for_clip', img_for_clip.shape)
        img_for_clip = img_for_clip.unsqueeze(0)
        img_clip_embedding = self.img_encoder(img_for_clip.to(device)).to(dtype=torch.float16)
        # print('img_clip_embedding_1_type', img_clip_embedding.dtype)
        # Initialize to None so the return below cannot raise a NameError when
        # classifier-free guidance is disabled; with guidance, an all-zero tensor
        # serves as the unconditional ("null") image embedding.
        negative_img_clip_embedding = None
        if do_classifier_free_guidance:
            negative_img_clip_embedding = torch.zeros_like(img_clip_embedding)
        return img_clip_embedding, negative_img_clip_embedding

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
    def run_safety_checker(self, image, device, dtype):
        if self.safety_checker is None:
            has_nsfw_concept = None
        else:
            if torch.is_tensor(image):
                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
            else:
                feature_extractor_input = self.image_processor.numpy_to_pil(image)
            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        return image, has_nsfw_concept

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
        deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
        deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)

        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents, return_dict=False)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs
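        # Example (illustrative): DDIMScheduler.step() accepts both `eta` and
        # `generator`, so both keys are forwarded; for a scheduler whose step()
        # accepts neither, an empty dict is returned.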

    def check_inputs(
            self,
            prompt,
            height,
            width,
            callback_steps,
            negative_prompt=None,
            prompt_embeds=None,
            negative_prompt_embeds=None,
    ):
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if (callback_steps is None) or (
                callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

    def get_timesteps(self, num_inference_steps, strength, device):
        # get the original timestep using init_timestep
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]

        return timesteps, num_inference_steps - t_start
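        # Worked example (illustrative): num_inference_steps=50 and strength=0.6
        # give init_timestep=30 and t_start=20, so with a first-order scheduler
        # only the last 30 of the 50 timesteps are run.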

    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents
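        # Shape example (illustrative): batch_size=2, num_channels_latents=4 and
        # a 1024x1024 target with vae_scale_factor=8 yield latents of shape
        # (2, 4, 128, 128).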

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
            self,
            height: int,
            width: int,
            prompt: Union[str, List[str]] = None,
            num_inference_steps: Optional[int] = 50,
            guidance_scale: Optional[float] = 7.5,
            negative_prompt: Optional[Union[str, List[str]]] = None,
            num_images_per_prompt: Optional[int] = 1,
            eta: Optional[float] = 0.0,
            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
            latents: Optional[torch.FloatTensor] = None,
            prompt_embeds: Optional[torch.FloatTensor] = None,
            prompt_embeds_t5: Optional[torch.FloatTensor] = None,
            negative_prompt_embeds: Optional[torch.FloatTensor] = None,
            negative_prompt_embeds_t5: Optional[torch.FloatTensor] = None,
            output_type: Optional[str] = "pil",
            return_dict: bool = True,
            callback: Optional[Callable[[int, int, torch.FloatTensor, torch.FloatTensor], None]] = None,
            callback_steps: int = 1,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            guidance_rescale: float = 0.0,
            image_meta_size: Optional[torch.LongTensor] = None,
            style: Optional[torch.LongTensor] = None,
            progress: bool = True,
            use_fp16: bool = False,
            freqs_cis_img: Optional[tuple] = None,
            learn_sigma: bool = True,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            height (`int`):
                The height in pixels of the generated image.
            width (`int`):
                The width in pixels of the generated image.
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what not to include in image generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function called every `callback_steps` steps during inference with the following arguments:
                `callback(step: int, timestep: int, latents: torch.FloatTensor, pred_x0: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
                otherwise a `tuple` is returned where the first element is a list with the generated images and the
                second element is a list of `bool`s indicating whether the corresponding generated image contains
                "not-safe-for-work" (nsfw) content.
        """
        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        text_encoder_lora_scale = (
            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
        )

        prompt_embeds, negative_prompt_embeds, attention_mask, uncond_attention_mask = \
            self.encode_prompt(prompt,
                               device,
                               num_images_per_prompt,
                               do_classifier_free_guidance,
                               negative_prompt,
                               prompt_embeds=prompt_embeds,
                               negative_prompt_embeds=negative_prompt_embeds,
                               lora_scale=text_encoder_lora_scale,
                               )
        prompt_embeds_t5, negative_prompt_embeds_t5, attention_mask_t5, uncond_attention_mask_t5 = \
            self.encode_prompt(prompt,
                               device,
                               num_images_per_prompt,
                               do_classifier_free_guidance,
                               negative_prompt,
                               prompt_embeds=prompt_embeds_t5,
                               negative_prompt_embeds=negative_prompt_embeds_t5,
                               lora_scale=text_encoder_lora_scale,
                               embedder=self.embedder_t5,
                               )

        # For classifier free guidance, we concatenate the unconditional and
        # text embeddings into a single batch to avoid doing two forward passes
        if do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
            attention_mask = torch.cat([uncond_attention_mask, attention_mask])
            prompt_embeds_t5 = torch.cat([negative_prompt_embeds_t5, prompt_embeds_t5])
            attention_mask_t5 = torch.cat([uncond_attention_mask_t5, attention_mask_t5])
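            # After concatenation the effective batch is doubled: unconditional
            # entries come first, text-conditioned entries second, which matches
            # the chunk(2) split applied to the noise prediction below.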

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(batch_size * num_images_per_prompt,
                                       num_channels_latents,
                                       height,
                                       width,
                                       prompt_embeds.dtype,
                                       device,
                                       generator,
                                       latents,
                                       )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
                # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
                t_expand = torch.tensor([t] * latent_model_input.shape[0], device=latent_model_input.device)

                if use_fp16:
                    latent_model_input = latent_model_input.half()
                    t_expand = t_expand.half()
                    prompt_embeds = prompt_embeds.half()
                    ims = image_meta_size.half() if image_meta_size is not None else None
                else:
                    ims = image_meta_size if image_meta_size is not None else None

                # predict the noise residual
                if self.infer_mode in ["fa", "torch"]:
                    noise_pred = self.unet(
                        latent_model_input,
                        t_expand,
                        encoder_hidden_states=prompt_embeds,
                        text_embedding_mask=attention_mask,
                        encoder_hidden_states_t5=prompt_embeds_t5,
                        text_embedding_mask_t5=attention_mask_t5,
                        image_meta_size=ims,
                        style=style,
                        cos_cis_img=freqs_cis_img[0],
                        sin_cis_img=freqs_cis_img[1],
                        return_dict=False,
                    )
                elif self.infer_mode == "trt":
                    raise NotImplementedError("TensorRT model is not supported yet.")
                else:
                    raise ValueError("[ERROR] Invalid inference mode! Please check your config file.")
                if learn_sigma:
                    # the model predicts noise and variance jointly; keep only the noise half
                    noise_pred, _ = noise_pred.chunk(2, dim=1)

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
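                    # Equivalently: (1 - w) * uncond + w * text with w = guidance_scale;
                    # w > 1 extrapolates away from the unconditional prediction.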

                if do_classifier_free_guidance and guidance_rescale > 0.0:
                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

                # compute the previous noisy sample x_t -> x_t-1
                results = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=True)
                latents = results.prev_sample
                pred_x0 = results.pred_original_sample if hasattr(results, 'pred_original_sample') else None

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents, pred_x0)

        if not output_type == "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
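

# Minimal usage sketch (illustrative; `pipe` and the rotary-embedding tuple
# `freqs_cis_img` are assumed to be constructed elsewhere, e.g. by
# hydit/inference.py):
#
#     output = pipe(height=1024, width=1024,
#                   prompt="a watercolor painting of a lighthouse",
#                   guidance_scale=7.5, num_inference_steps=50,
#                   freqs_cis_img=freqs_cis_img, use_fp16=True)
#     image = output.images[0]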
        hydit/inference.py
    ADDED
    
@@ -0,0 +1,389 @@
import random
import time
from pathlib import Path

import numpy as np
import torch

# For reproducibility
# torch.backends.cudnn.benchmark = False
# torch.backends.cudnn.deterministic = True

from diffusers import schedulers
from diffusers.models import AutoencoderKL
from loguru import logger
from transformers import BertModel, BertTokenizer
from transformers.modeling_utils import logger as tf_logger

from .constants import SAMPLER_FACTORY, NEGATIVE_PROMPT
from .diffusion.pipeline import StableDiffusionPipeline
from .modules.models import HunYuanDiT, HUNYUAN_DIT_CONFIG
from .modules.posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop
from .modules.text_encoder import MT5Embedder
from .utils.tools import set_seeds


class Resolution:
    def __init__(self, width, height):
        self.width = width
        self.height = height

    def __str__(self):
        return f'{self.height}x{self.width}'


class ResolutionGroup:
    def __init__(self):
        self.data = [
            Resolution(768, 768),   # 1:1
            Resolution(1024, 1024), # 1:1
            Resolution(1280, 1280), # 1:1
            Resolution(1024, 768),  # 4:3
            Resolution(1152, 864),  # 4:3
            Resolution(1280, 960),  # 4:3
            Resolution(768, 1024),  # 3:4
            Resolution(864, 1152),  # 3:4
            Resolution(960, 1280),  # 3:4
            Resolution(1280, 768),  # 16:9
            Resolution(768, 1280),  # 9:16
        ]
        self.supported_sizes = set([(r.width, r.height) for r in self.data])

    def is_valid(self, width, height):
        return (width, height) in self.supported_sizes
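        # Example (illustrative): ResolutionGroup().is_valid(1024, 768) is True,
        # while an unsupported size such as (512, 512) is False.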


STANDARD_RATIO = np.array([
    1.0,        # 1:1
    4.0 / 3.0,  # 4:3
    3.0 / 4.0,  # 3:4
    16.0 / 9.0, # 16:9
    9.0 / 16.0, # 9:16
])
STANDARD_SHAPE = [
    [(768, 768), (1024, 1024), (1280, 1280)],   # 1:1
    [(1024, 768), (1152, 864), (1280, 960)],    # 4:3
    [(768, 1024), (864, 1152), (960, 1280)],    # 3:4
    [(1280, 768)],                              # 16:9
    [(768, 1280)],                              # 9:16
]
STANDARD_AREA = [
    np.array([w * h for w, h in shapes])
    for shapes in STANDARD_SHAPE
]


def get_standard_shape(target_width, target_height):
    """
    Map an arbitrary image size to the closest standard size: first match the
    aspect ratio, then the area within that ratio group.
    """
    target_ratio = target_width / target_height
    closest_ratio_idx = np.argmin(np.abs(STANDARD_RATIO - target_ratio))
    closest_area_idx = np.argmin(np.abs(STANDARD_AREA[closest_ratio_idx] - target_width * target_height))
    width, height = STANDARD_SHAPE[closest_ratio_idx][closest_area_idx]
    return width, height
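

# Worked example (illustrative): a 1000x700 request has aspect ratio ~1.43,
# closest to 4:3; within the 4:3 group its area (700000) is nearest to
# 1024*768, so get_standard_shape(1000, 700) returns (1024, 768).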


def _to_tuple(val):
    if isinstance(val, (list, tuple)):
        if len(val) == 1:
            val = (val[0], val[0])
        elif len(val) == 2:
            val = tuple(val)
        else:
            raise ValueError(f"Invalid value: {val}")
    elif isinstance(val, (int, float)):
        val = (val, val)
    else:
        raise ValueError(f"Invalid value: {val}")
    return val
         | 
| 100 | 
            +
             | 
| 101 | 
            +
             | 
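For reference, the normalization behaves as follows; note that the single-element branch keeps a list rather than a tuple, which is harmless here because callers only index into the result:

# _to_tuple(1024)        -> (1024, 1024)
# _to_tuple((768, 1280)) -> (768, 1280)
# _to_tuple([1024])      -> [1024, 1024]   (a list, per the len == 1 branch)
# _to_tuple("1024")      -> ValueError
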
def get_pipeline(args, vae, text_encoder, tokenizer, model, device, rank,
                 embedder_t5, infer_mode, sampler=None):
    """
    Get the scheduler and pipeline for sampling. Both the sampler and the
    pipeline are based on diffusers, with some modifications.

    Returns
    -------
    pipeline: StableDiffusionPipeline
    sampler_name: str
    """
    sampler = sampler or args.sampler

    # Load sampler from factory
    kwargs = SAMPLER_FACTORY[sampler]['kwargs']
    scheduler = SAMPLER_FACTORY[sampler]['scheduler']

    # Update sampler according to the arguments
    kwargs['beta_schedule'] = args.noise_schedule
    kwargs['beta_start'] = args.beta_start
    kwargs['beta_end'] = args.beta_end
    kwargs['prediction_type'] = args.predict_type

    # Build scheduler according to the sampler.
    scheduler_class = getattr(schedulers, scheduler)
    scheduler = scheduler_class(**kwargs)

    # Set timesteps for inference steps.
    scheduler.set_timesteps(args.infer_steps, device)

    # Only enable the progress bar on rank 0
    progress_bar_config = {} if rank == 0 else {'disable': True}

    pipeline = StableDiffusionPipeline(vae=vae,
                                       text_encoder=text_encoder,
                                       tokenizer=tokenizer,
                                       unet=model,
                                       scheduler=scheduler,
                                       feature_extractor=None,
                                       safety_checker=None,
                                       requires_safety_checker=False,
                                       progress_bar_config=progress_bar_config,
                                       embedder_t5=embedder_t5,
                                       infer_mode=infer_mode,
                                       )

    pipeline = pipeline.to(device)

    return pipeline, sampler

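SAMPLER_FACTORY is defined elsewhere in this upload and is not shown in this hunk. Purely as an illustration of the shape get_pipeline expects (the key name and kwargs below are assumptions, not the real table), an entry maps a sampler name to a diffusers scheduler class name plus its constructor kwargs:

# Hypothetical SAMPLER_FACTORY entry (illustrative only):
SAMPLER_FACTORY_EXAMPLE = {
    'ddpm': {
        'scheduler': 'DDPMScheduler',             # resolved via getattr(schedulers, ...)
        'kwargs': {'num_train_timesteps': 1000},  # beta_* and prediction_type are
    },                                            # overwritten from args in get_pipeline
}
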
class End2End(object):
    def __init__(self, args, models_root_path):
        self.args = args

        # Check arguments
        t2i_root_path = Path(models_root_path) / "t2i"
        self.root = t2i_root_path
        logger.info(f"Got text-to-image model root path: {t2i_root_path}")

        # Set device and disable gradients
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        torch.set_grad_enabled(False)
        # Disable BertModel logging checkpoint info
        tf_logger.setLevel('ERROR')

        # ========================================================================
        model_dir = self.root / "model"

        # ========================================================================
        logger.info("Loading CLIP Text Encoder...")
        text_encoder_path = self.root / "clip_text_encoder"
        self.clip_text_encoder = BertModel.from_pretrained(str(text_encoder_path), False, revision=None).to(self.device)
        logger.info("Loading CLIP Text Encoder finished")

        # ========================================================================
        logger.info("Loading CLIP Tokenizer...")
        tokenizer_path = self.root / "tokenizer"
        self.tokenizer = BertTokenizer.from_pretrained(str(tokenizer_path))
        logger.info("Loading CLIP Tokenizer finished")

        # ========================================================================
        logger.info("Loading T5 Text Encoder and T5 Tokenizer...")
        t5_text_encoder_path = self.root / 'mt5'
        embedder_t5 = MT5Embedder(t5_text_encoder_path, torch_dtype=torch.float16, max_length=256)
        self.embedder_t5 = embedder_t5
        logger.info("Loading t5_text_encoder and t5_tokenizer finished")

        # ========================================================================
        logger.info("Loading VAE...")
        vae_path = self.root / "sdxl-vae-fp16-fix"
        self.vae = AutoencoderKL.from_pretrained(str(vae_path)).to(self.device)
        logger.info("Loading VAE finished")

        # ========================================================================
        # Create the model structure and load the checkpoint
        logger.info("Building HunYuan-DiT model...")
        model_config = HUNYUAN_DIT_CONFIG[self.args.model]
        self.patch_size = model_config['patch_size']
        self.head_size = model_config['hidden_size'] // model_config['num_heads']
        self.resolutions, self.freqs_cis_img = self.standard_shapes()   # Used for TensorRT models
        self.image_size = _to_tuple(self.args.image_size)
        latent_size = (self.image_size[0] // 8, self.image_size[1] // 8)

        self.infer_mode = self.args.infer_mode
        if self.infer_mode in ['fa', 'torch']:
            model_path = model_dir / f"pytorch_model_{self.args.load_key}.pt"
            if not model_path.exists():
                raise ValueError(f"model_path does not exist: {model_path}")
            # Build the model structure
            self.model = HunYuanDiT(self.args,
                                    input_size=latent_size,
                                    **model_config,
                                    log_fn=logger.info,
                                    ).half().to(self.device)    # Force fp16
            # Load the model checkpoint
            logger.info(f"Loading model checkpoint {model_path}...")
            state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
            self.model.load_state_dict(state_dict)
            self.model.eval()
        elif self.infer_mode == 'trt':
            raise NotImplementedError("TensorRT model is not supported yet.")
        else:
            raise ValueError(f"Unknown infer_mode: {self.infer_mode}")

        # ========================================================================
        # Build the inference pipeline. We use a customized StableDiffusionPipeline.
        logger.info("Loading inference pipeline...")
        self.pipeline, self.sampler = self.load_sampler()
        logger.info('Loading pipeline finished')

        # ========================================================================
        self.default_negative_prompt = NEGATIVE_PROMPT
        logger.info("==================================================")
        logger.info("                Model is ready.                   ")
        logger.info("==================================================")

    def load_sampler(self, sampler=None):
        pipeline, sampler = get_pipeline(self.args,
                                         self.vae,
                                         self.clip_text_encoder,
                                         self.tokenizer,
                                         self.model,
                                         device=self.device,
                                         rank=0,
                                         embedder_t5=self.embedder_t5,
                                         infer_mode=self.infer_mode,
                                         sampler=sampler,
                                         )
        return pipeline, sampler

    def calc_rope(self, height, width):
        th = height // 8 // self.patch_size
        tw = width // 8 // self.patch_size
        base_size = 512 // 8 // self.patch_size
        start, stop = get_fill_resize_and_crop((th, tw), base_size)
        sub_args = [start, stop, (th, tw)]
        rope = get_2d_rotary_pos_embed(self.head_size, *sub_args)
        return rope

    def standard_shapes(self):
        resolutions = ResolutionGroup()
        freqs_cis_img = {}
        for reso in resolutions.data:
            freqs_cis_img[str(reso)] = self.calc_rope(reso.height, reso.width)
        return resolutions, freqs_cis_img

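To see the token-grid arithmetic in calc_rope, assume patch_size = 2 (the actual value comes from HUNYUAN_DIT_CONFIG, so treat this as an assumption): the VAE downsamples by 8 and the patch embedding by patch_size, so a 1024x1024 image yields a 64x64 token grid positioned against a 32x32 base grid:

# Worked example for calc_rope (assuming patch_size = 2):
#   th = 1024 // 8 // 2 = 64      image height, in tokens
#   tw = 1024 // 8 // 2 = 64      image width, in tokens
#   base_size = 512 // 8 // 2 = 32
# get_fill_resize_and_crop((64, 64), 32) positions the token grid relative to
# the base grid; get_2d_rotary_pos_embed then builds per-position rotations
# of dimension head_size.
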
    def predict(self,
                user_prompt,
                height=1024,
                width=1024,
                seed=None,
                enhanced_prompt=None,
                negative_prompt=None,
                infer_steps=100,
                guidance_scale=6,
                batch_size=1,
                src_size_cond=(1024, 1024),
                sampler=None,
                ):
        # ========================================================================
        # Arguments: seed
        # ========================================================================
        if seed is None:
            seed = random.randint(0, 1_000_000)
        if not isinstance(seed, int):
            raise TypeError(f"`seed` must be an integer, but got {type(seed)}")
        generator = set_seeds(seed)

        # ========================================================================
        # Arguments: target_width, target_height
        # ========================================================================
        if width <= 0 or height <= 0:
            raise ValueError(f"`height` and `width` must be positive integers, got height={height}, width={width}")
        logger.info(f"Input (height, width) = ({height}, {width})")
        if self.infer_mode in ['fa', 'torch']:
            # We must force height and width to be integers aligned to 16.
            target_height = int((height // 16) * 16)
            target_width = int((width // 16) * 16)
            logger.info(f"Align to 16: (height, width) = ({target_height}, {target_width})")
        elif self.infer_mode == 'trt':
            target_width, target_height = get_standard_shape(width, height)
            logger.info(f"Align to standard shape: (height, width) = ({target_height}, {target_width})")
        else:
            raise ValueError(f"Unknown infer_mode: {self.infer_mode}")

        # ========================================================================
        # Arguments: prompt, enhanced_prompt, negative_prompt
        # ========================================================================
        if not isinstance(user_prompt, str):
            raise TypeError(f"`user_prompt` must be a string, but got {type(user_prompt)}")
        user_prompt = user_prompt.strip()
        prompt = user_prompt

        if enhanced_prompt is not None:
            if not isinstance(enhanced_prompt, str):
                raise TypeError(f"`enhanced_prompt` must be a string, but got {type(enhanced_prompt)}")
            enhanced_prompt = enhanced_prompt.strip()
            prompt = enhanced_prompt

        # negative prompt
        if negative_prompt is None or negative_prompt == '':
            negative_prompt = self.default_negative_prompt
        if not isinstance(negative_prompt, str):
            raise TypeError(f"`negative_prompt` must be a string, but got {type(negative_prompt)}")

        # ========================================================================
        # Arguments: style. (A fixed argument. Don't change it.)
        # ========================================================================
        style = torch.as_tensor([0, 0] * batch_size, device=self.device)

        # ========================================================================
        # Inner arguments: image_meta_size (Please refer to SDXL.)
        # ========================================================================
        if isinstance(src_size_cond, int):
            src_size_cond = [src_size_cond, src_size_cond]
        if not isinstance(src_size_cond, (list, tuple)):
            raise TypeError(f"`src_size_cond` must be a list or tuple, but got {type(src_size_cond)}")
        if len(src_size_cond) != 2:
            raise ValueError(f"`src_size_cond` must be a tuple of 2 integers, but got {len(src_size_cond)}")
        size_cond = list(src_size_cond) + [target_width, target_height, 0, 0]
        image_meta_size = torch.as_tensor([size_cond] * 2 * batch_size, device=self.device)

        # ========================================================================
        start_time = time.time()
        logger.debug(f"""
                       prompt: {user_prompt}
              enhanced prompt: {enhanced_prompt}
                         seed: {seed}
              (height, width): {(target_height, target_width)}
              negative_prompt: {negative_prompt}
                   batch_size: {batch_size}
               guidance_scale: {guidance_scale}
                  infer_steps: {infer_steps}
              image_meta_size: {size_cond}
        """)
        reso = f'{target_height}x{target_width}'
        if reso in self.freqs_cis_img:
            freqs_cis_img = self.freqs_cis_img[reso]
        else:
            freqs_cis_img = self.calc_rope(target_height, target_width)

        if sampler is not None and sampler != self.sampler:
            self.pipeline, self.sampler = self.load_sampler(sampler)

        samples = self.pipeline(
            height=target_height,
            width=target_width,
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_images_per_prompt=batch_size,
            guidance_scale=guidance_scale,
            num_inference_steps=infer_steps,
            image_meta_size=image_meta_size,
            style=style,
            return_dict=False,
            generator=generator,
            freqs_cis_img=freqs_cis_img,
            use_fp16=self.args.use_fp16,
            learn_sigma=self.args.learn_sigma,
        )[0]
        gen_time = time.time() - start_time
        logger.debug(f"Success, time: {gen_time}")

        return {
            'images': samples,
            'seed': seed,
        }
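An end-to-end usage sketch for the class above. Here gen_args is a hypothetical stand-in for the argument namespace built by hydit/config.py, and models_root_path must contain the t2i/ tree loaded in __init__:

# Sketch: driving End2End.predict (gen_args is hypothetical; see hydit/config.py).
gen = End2End(gen_args, models_root_path="ckpts")
result = gen.predict("A photo of a cat",
                     height=1024, width=768,   # aligned down to multiples of 16
                     infer_steps=50, guidance_scale=6, seed=42)
images, used_seed = result['images'], result['seed']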
    	
hydit/modules/__init__.py
ADDED
File without changes

hydit/modules/attn_layers.py
ADDED
@@ -0,0 +1,377 @@
import torch
import torch.nn as nn
from typing import Tuple, Union, Optional

try:
    import flash_attn
    if hasattr(flash_attn, '__version__') and int(flash_attn.__version__[0]) == 2:
        from flash_attn.flash_attn_interface import flash_attn_kvpacked_func
        from flash_attn.modules.mha import FlashSelfAttention, FlashCrossAttention
    else:
        from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func
        from flash_attn.modules.mha import FlashSelfAttention, FlashCrossAttention
except Exception as e:
    print(f'flash_attn import failed: {e}')

def reshape_for_broadcast(freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], x: torch.Tensor, head_first=False):
    """
    Reshape a frequency tensor for broadcasting with another tensor.

    This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
    for the purpose of broadcasting during element-wise operations.

    Args:
        freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
        x (torch.Tensor): Target tensor for broadcasting compatibility.
        head_first (bool): Whether the head dimension comes first (after the batch dim) or not.

    Returns:
        torch.Tensor: Reshaped frequency tensor.

    Raises:
        AssertionError: If the frequency tensor doesn't match the expected shape.
        AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
    """
    ndim = x.ndim
    assert 0 <= 1 < ndim   # x must have at least two dimensions

    if isinstance(freqs_cis, tuple):
        # freqs_cis: (cos, sin) in real space
        if head_first:
            assert freqs_cis[0].shape == (x.shape[-2], x.shape[-1]), f'freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}'
            shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        else:
            assert freqs_cis[0].shape == (x.shape[1], x.shape[-1]), f'freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}'
            shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
    else:
        # freqs_cis: values in complex space
        if head_first:
            assert freqs_cis.shape == (x.shape[-2], x.shape[-1]), f'freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}'
            shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        else:
            assert freqs_cis.shape == (x.shape[1], x.shape[-1]), f'freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}'
            shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return freqs_cis.view(*shape)

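Shape-wise, for x of shape [B, S, H, D] with head_first=False, a [S, D] cos/sin pair is viewed as [1, S, 1, D] so it broadcasts over batch and heads. A quick check (using the imports at the top of this file):

# Broadcast-shape check for the real-valued (cos, sin) path.
B, S, H, D = 2, 16, 4, 8
x = torch.randn(B, S, H, D)
cos_b, sin_b = reshape_for_broadcast((torch.ones(S, D), torch.zeros(S, D)), x)
assert cos_b.shape == (1, S, 1, D) and sin_b.shape == (1, S, 1, D)
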
def rotate_half(x):
    x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
    return torch.stack([-x_imag, x_real], dim=-1).flatten(3)

def apply_rotary_emb(
        xq: torch.Tensor,
        xk: Optional[torch.Tensor],
        freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
        head_first: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor.

    This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
    frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
    is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
    returned as real tensors.

    Args:
        xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
        xk (Optional[torch.Tensor]): Key tensor to apply rotary embeddings. [B, S, H, D]
        freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Precomputed frequency tensor for complex exponentials.
        head_first (bool): Whether the head dimension comes first (after the batch dim) or not.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
    """
    xk_out = None
    if isinstance(freqs_cis, tuple):
        cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)    # [S, D]
        cos, sin = cos.to(xq.device), sin.to(xq.device)
        xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
        if xk is not None:
            xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
    else:
        xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  # [B, S, H, D//2]
        freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device)   # [S, D//2] --> [1, S, 1, D//2]
        xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
        if xk is not None:
            xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))  # [B, S, H, D//2]
            xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)

    return xq_out, xk_out

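A minimal self-check of the (cos, sin) path: with zero rotation angles (cos = 1, sin = 0) the rotary embedding must be the identity, which makes a cheap sanity test:

# Sanity check: RoPE with zero angles leaves q and k unchanged.
B, S, H, D = 1, 4, 2, 8
q, k = torch.randn(B, S, H, D), torch.randn(B, S, H, D)
cos, sin = torch.ones(S, D), torch.zeros(S, D)
q_out, k_out = apply_rotary_emb(q, k, (cos, sin))
assert torch.allclose(q_out, q) and torch.allclose(k_out, k)
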
class FlashSelfMHAModified(nn.Module):
    """
    Self-attention with QK normalization, backed by flash attention.
    """
    def __init__(self,
                 dim,
                 num_heads,
                 qkv_bias=True,
                 qk_norm=False,
                 attn_drop=0.0,
                 proj_drop=0.0,
                 device=None,
                 dtype=None,
                 norm_layer=nn.LayerNorm,
                 ):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        assert self.dim % num_heads == 0, "self.dim must be divisible by num_heads"
        self.head_dim = self.dim // num_heads
        assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"

        self.Wqkv = nn.Linear(dim, 3 * dim, bias=qkv_bias, **factory_kwargs)
        # TODO: eps should be 1 / 65530 if using fp16
        self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
        self.inner_attn = FlashSelfAttention(attention_dropout=attn_drop)
        self.out_proj = nn.Linear(dim, dim, bias=qkv_bias, **factory_kwargs)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, freqs_cis_img=None):
        """
        Parameters
        ----------
        x: torch.Tensor
            (batch, seqlen, hidden_dim) (where hidden_dim = num_heads * head_dim)
        freqs_cis_img: torch.Tensor
            (batch, hidden_dim // 2), RoPE for image
        """
        b, s, d = x.shape

        qkv = self.Wqkv(x)
        qkv = qkv.view(b, s, 3, self.num_heads, self.head_dim)  # [b, s, 3, h, d]
        q, k, v = qkv.unbind(dim=2)  # [b, s, h, d]
        q = self.q_norm(q).half()    # [b, s, h, d]
        k = self.k_norm(k).half()

        # Apply RoPE if needed
        if freqs_cis_img is not None:
            qq, kk = apply_rotary_emb(q, k, freqs_cis_img)
            assert qq.shape == q.shape and kk.shape == k.shape, f'qq: {qq.shape}, q: {q.shape}, kk: {kk.shape}, k: {k.shape}'
            q, k = qq, kk

        qkv = torch.stack([q, k, v], dim=2)  # [b, s, 3, h, d]
        context = self.inner_attn(qkv)
        out = self.out_proj(context.view(b, s, d))
        out = self.proj_drop(out)

        out_tuple = (out,)

        return out_tuple

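A usage sketch for the flash-attention variants. This is only runnable on a CUDA device with flash-attn installed; the .half() casts in forward exist because the flash kernels accept only fp16/bf16 inputs:

# Sketch: FlashSelfMHAModified forward pass (requires CUDA + flash-attn).
mha = FlashSelfMHAModified(dim=1024, num_heads=16, qk_norm=True).half().cuda()
x = torch.randn(2, 64, 1024, dtype=torch.float16, device='cuda')
(out,) = mha(x)                 # forward returns a 1-tuple
assert out.shape == x.shape
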
class FlashCrossMHAModified(nn.Module):
    """
    Cross-attention with QK normalization, backed by flash attention.
    """
    def __init__(self,
                 qdim,
                 kdim,
                 num_heads,
                 qkv_bias=True,
                 qk_norm=False,
                 attn_drop=0.0,
                 proj_drop=0.0,
                 device=None,
                 dtype=None,
                 norm_layer=nn.LayerNorm,
                 ):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.qdim = qdim
        self.kdim = kdim
        self.num_heads = num_heads
        assert self.qdim % num_heads == 0, "self.qdim must be divisible by num_heads"
        self.head_dim = self.qdim // num_heads
        assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"

        self.scale = self.head_dim ** -0.5

        self.q_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
        self.kv_proj = nn.Linear(kdim, 2 * qdim, bias=qkv_bias, **factory_kwargs)

        # TODO: eps should be 1 / 65530 if using fp16
        self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()

        self.inner_attn = FlashCrossAttention(attention_dropout=attn_drop)
        self.out_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, y, freqs_cis_img=None):
        """
        Parameters
        ----------
        x: torch.Tensor
            (batch, seqlen1, hidden_dim) (where hidden_dim = num_heads * head_dim)
        y: torch.Tensor
            (batch, seqlen2, hidden_dim2)
        freqs_cis_img: torch.Tensor
            (batch, hidden_dim // num_heads), RoPE for image
        """
        b, s1, _ = x.shape     # [b, s1, D]
        _, s2, _ = y.shape     # [b, s2, 1024]

        q = self.q_proj(x).view(b, s1, self.num_heads, self.head_dim)       # [b, s1, h, d]
        kv = self.kv_proj(y).view(b, s2, 2, self.num_heads, self.head_dim)  # [b, s2, 2, h, d]
        k, v = kv.unbind(dim=2)                 # [b, s2, h, d]
        q = self.q_norm(q).half()               # [b, s1, h, d]
        k = self.k_norm(k).half()               # [b, s2, h, d]

        # Apply RoPE to the image queries if needed
        if freqs_cis_img is not None:
            qq, _ = apply_rotary_emb(q, None, freqs_cis_img)
            assert qq.shape == q.shape, f'qq: {qq.shape}, q: {q.shape}'
            q = qq                              # [b, s1, h, d]
        kv = torch.stack([k, v], dim=2)         # [b, s2, 2, h, d]
        context = self.inner_attn(q, kv)        # [b, s1, h, d]
        context = context.view(b, s1, -1)       # [b, s1, D]

        out = self.out_proj(context)
        out = self.proj_drop(out)

        out_tuple = (out,)

        return out_tuple

| 245 | 
            +
            class CrossAttention(nn.Module):
         | 
| 246 | 
            +
                """
         | 
| 247 | 
            +
                Use QK Normalization.
         | 
| 248 | 
            +
                """
         | 
| 249 | 
            +
                def __init__(self,
         | 
| 250 | 
            +
                             qdim,
         | 
| 251 | 
            +
                             kdim,
         | 
| 252 | 
            +
                             num_heads,
         | 
| 253 | 
            +
                             qkv_bias=True,
         | 
| 254 | 
            +
                             qk_norm=False,
         | 
| 255 | 
            +
                             attn_drop=0.0,
         | 
| 256 | 
            +
                             proj_drop=0.0,
         | 
| 257 | 
            +
                             device=None,
         | 
| 258 | 
            +
                             dtype=None,
         | 
| 259 | 
            +
                             norm_layer=nn.LayerNorm,
         | 
| 260 | 
            +
                             ):
         | 
| 261 | 
            +
                    factory_kwargs = {'device': device, 'dtype': dtype}
         | 
| 262 | 
            +
                    super().__init__()
         | 
| 263 | 
            +
                    self.qdim = qdim
         | 
| 264 | 
            +
                    self.kdim = kdim
         | 
| 265 | 
            +
                    self.num_heads = num_heads
         | 
| 266 | 
            +
                    assert self.qdim % num_heads == 0, "self.qdim must be divisible by num_heads"
         | 
| 267 | 
            +
                    self.head_dim = self.qdim // num_heads
         | 
| 268 | 
            +
                    assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
         | 
| 269 | 
            +
                    self.scale = self.head_dim ** -0.5
         | 
| 270 | 
            +
             | 
| 271 | 
            +
                    self.q_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
         | 
| 272 | 
            +
                    self.kv_proj = nn.Linear(kdim, 2 * qdim, bias=qkv_bias, **factory_kwargs)
         | 
| 273 | 
            +
             | 
| 274 | 
            +
                    # TODO: eps should be 1 / 65530 if using fp16
         | 
| 275 | 
            +
                    self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
         | 
| 276 | 
            +
                    self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
         | 
| 277 | 
            +
                    self.attn_drop = nn.Dropout(attn_drop)
         | 
| 278 | 
            +
                    self.out_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
         | 
| 279 | 
            +
                    self.proj_drop = nn.Dropout(proj_drop)
         | 
| 280 | 
            +
             | 
| 281 | 
            +
                def forward(self, x, y, freqs_cis_img=None):
         | 
| 282 | 
            +
                    """
         | 
| 283 | 
            +
                    Parameters
         | 
| 284 | 
            +
                    ----------
         | 
| 285 | 
            +
                    x: torch.Tensor
         | 
| 286 | 
            +
                        (batch, seqlen1, hidden_dim) (where hidden_dim = num heads * head dim)
         | 
| 287 | 
            +
                    y: torch.Tensor
         | 
| 288 | 
            +
                        (batch, seqlen2, hidden_dim2)
         | 
| 289 | 
            +
                    freqs_cis_img: torch.Tensor
         | 
| 290 | 
            +
                        (batch, hidden_dim // 2), RoPE for image
         | 
| 291 | 
            +
                    """
         | 
| 292 | 
            +
                    b, s1, c = x.shape     # [b, s1, D]
         | 
| 293 | 
            +
                    _, s2, c = y.shape     # [b, s2, 1024]
         | 
| 294 | 
            +
             | 
| 295 | 
            +
                    q = self.q_proj(x).view(b, s1, self.num_heads, self.head_dim)   # [b, s1, h, d]
         | 
| 296 | 
            +
                    kv = self.kv_proj(y).view(b, s2, 2, self.num_heads, self.head_dim)    # [b, s2, 2, h, d]
         | 
| 297 | 
            +
                    k, v = kv.unbind(dim=2) # [b, s, h, d]
         | 
| 298 | 
            +
                    q = self.q_norm(q)
         | 
| 299 | 
            +
                    k = self.k_norm(k)
         | 
| 300 | 
            +
             | 
| 301 | 
            +
                    # Apply RoPE if needed
         | 
| 302 | 
            +
                    if freqs_cis_img is not None:
         | 
| 303 | 
            +
                        qq, _ = apply_rotary_emb(q, None, freqs_cis_img)
         | 
| 304 | 
            +
                        assert qq.shape == q.shape, f'qq: {qq.shape}, q: {q.shape}'
         | 
| 305 | 
            +
                        q = qq
         | 
| 306 | 
            +
             | 
| 307 | 
            +
                    q = q * self.scale
         | 
| 308 | 
            +
                    q = q.transpose(-2, -3).contiguous()        # q ->  B, L1, H, C - B, H, L1, C
         | 
| 309 | 
            +
                k = k.permute(0, 2, 3, 1).contiguous()      # k: [B, L2, H, C] -> [B, H, C, L2]
         | 
| 310 | 
            +
                    attn = q @ k                                # attn -> B, H, L1, L2
         | 
| 311 | 
            +
                    attn = attn.softmax(dim=-1)                 # attn -> B, H, L1, L2
         | 
| 312 | 
            +
                    attn = self.attn_drop(attn)
         | 
| 313 | 
            +
                x = attn @ v.transpose(-2, -3)              # v: [B, L2, H, C] -> [B, H, L2, C];  x: [B, H, L1, C]
         | 
| 314 | 
            +
                context = x.transpose(1, 2)                 # context: [B, H, L1, C] -> [B, L1, H, C]
         | 
| 315 | 
            +
             | 
| 316 | 
            +
                    context = context.contiguous().view(b, s1, -1)
         | 
| 317 | 
            +
             | 
| 318 | 
            +
                out = self.out_proj(context)  # out: [B, L1, qdim]
         | 
| 319 | 
            +
                    out = self.proj_drop(out)
         | 
| 320 | 
            +
             | 
| 321 | 
            +
                    out_tuple = (out,)
         | 
| 322 | 
            +
             | 
| 323 | 
            +
                    return out_tuple
         | 
| 324 | 
            +
             | 
| 325 | 
            +
             | 
| 326 | 
            +
            class Attention(nn.Module):
         | 
| 327 | 
            +
                """
         | 
| 328 | 
            +
            Some layer names are renamed to align with the flash-attention implementation.
         | 
| 329 | 
            +
                """
         | 
| 330 | 
            +
                def __init__(self, dim, num_heads, qkv_bias=True, qk_norm=False, attn_drop=0., proj_drop=0.,
         | 
| 331 | 
            +
                             norm_layer=nn.LayerNorm,
         | 
| 332 | 
            +
                             ):
         | 
| 333 | 
            +
                    super().__init__()
         | 
| 334 | 
            +
                    self.dim = dim
         | 
| 335 | 
            +
                    self.num_heads = num_heads
         | 
| 336 | 
            +
                    assert self.dim % num_heads == 0, 'dim should be divisible by num_heads'
         | 
| 337 | 
            +
                    self.head_dim = self.dim // num_heads
         | 
| 338 | 
            +
                    # This assertion is aligned with flash attention
         | 
| 339 | 
            +
                assert self.head_dim % 8 == 0 and self.head_dim <= 128, "head_dim must be <= 128 and divisible by 8"
         | 
| 340 | 
            +
                    self.scale = self.head_dim ** -0.5
         | 
| 341 | 
            +
             | 
| 342 | 
            +
                    # qkv --> Wqkv
         | 
| 343 | 
            +
                    self.Wqkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         | 
| 344 | 
            +
                    # TODO: eps should be 1 / 65530 if using fp16
         | 
| 345 | 
            +
                    self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
         | 
| 346 | 
            +
                    self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
         | 
| 347 | 
            +
                    self.attn_drop = nn.Dropout(attn_drop)
         | 
| 348 | 
            +
                    self.out_proj = nn.Linear(dim, dim)
         | 
| 349 | 
            +
                    self.proj_drop = nn.Dropout(proj_drop)
         | 
| 350 | 
            +
             | 
| 351 | 
            +
                def forward(self, x, freqs_cis_img=None):
         | 
| 352 | 
            +
                    B, N, C = x.shape
         | 
| 353 | 
            +
                    qkv = self.Wqkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)   # [3, b, h, s, d]
         | 
| 354 | 
            +
                    q, k, v = qkv.unbind(0)     # [b, h, s, d]
         | 
| 355 | 
            +
                    q = self.q_norm(q)          # [b, h, s, d]
         | 
| 356 | 
            +
                    k = self.k_norm(k)          # [b, h, s, d]
         | 
| 357 | 
            +
             | 
| 358 | 
            +
                    # Apply RoPE if needed
         | 
| 359 | 
            +
                    if freqs_cis_img is not None:
         | 
| 360 | 
            +
                        qq, kk = apply_rotary_emb(q, k, freqs_cis_img, head_first=True)
         | 
| 361 | 
            +
                        assert qq.shape == q.shape and kk.shape == k.shape, \
         | 
| 362 | 
            +
                            f'qq: {qq.shape}, q: {q.shape}, kk: {kk.shape}, k: {k.shape}'
         | 
| 363 | 
            +
                        q, k = qq, kk
         | 
| 364 | 
            +
             | 
| 365 | 
            +
                    q = q * self.scale
         | 
| 366 | 
            +
                    attn = q @ k.transpose(-2, -1)              # [b, h, s, d] @ [b, h, d, s]
         | 
| 367 | 
            +
                    attn = attn.softmax(dim=-1)                 # [b, h, s, s]
         | 
| 368 | 
            +
                    attn = self.attn_drop(attn)
         | 
| 369 | 
            +
                    x = attn @ v                                # [b, h, s, d]
         | 
| 370 | 
            +
             | 
| 371 | 
            +
                x = x.transpose(1, 2).reshape(B, N, C)      # [b, h, s, d] -> [b, s, h*d]
         | 
| 372 | 
            +
                    x = self.out_proj(x)
         | 
| 373 | 
            +
                    x = self.proj_drop(x)
         | 
| 374 | 
            +
             | 
| 375 | 
            +
                    out_tuple = (x,)
         | 
| 376 | 
            +
             | 
| 377 | 
            +
                    return out_tuple
         | 
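A minimal shape check for the pure-PyTorch Attention block above (a sketch for illustration; it assumes the hydit package added by this upload is importable and that torch is installed):

    import torch
    from hydit.modules.attn_layers import Attention

    # dim=1152 with num_heads=16 gives head_dim=72, which satisfies the
    # flash-attention constraints asserted in __init__ (divisible by 8, <= 128).
    attn = Attention(dim=1152, num_heads=16, qk_norm=True)
    x = torch.randn(2, 256, 1152)   # [batch, tokens, hidden]
    out, = attn(x)                  # freqs_cis_img=None skips the RoPE branch
    assert out.shape == x.shape

The one-element return tuple is presumably there so the flash and non-flash variants share a single calling convention; the transformer blocks index the result with [0] either way.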
    	
        hydit/modules/embedders.py
    ADDED
    
    | @@ -0,0 +1,111 @@ | |
| 1 | 
            +
            import math
         | 
| 2 | 
            +
            import torch
         | 
| 3 | 
            +
            import torch.nn as nn
         | 
| 4 | 
            +
            from einops import repeat
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            from timm.models.layers import to_2tuple
         | 
| 7 | 
            +
             | 
| 8 | 
            +
             | 
| 9 | 
            +
            class PatchEmbed(nn.Module):
         | 
| 10 | 
            +
                """ 2D Image to Patch Embedding
         | 
| 11 | 
            +
             | 
| 12 | 
            +
                Image to Patch Embedding using Conv2d
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            A convolution-based approach to patchifying a 2D image with embedding projection.
         | 
| 15 | 
            +
             | 
| 16 | 
            +
                Based on the impl in https://github.com/google-research/vision_transformer
         | 
| 17 | 
            +
             | 
| 18 | 
            +
                Hacked together by / Copyright 2020 Ross Wightman
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            The _assert calls in the forward function are removed for compatibility with multi-resolution images.
         | 
| 21 | 
            +
                """
         | 
| 22 | 
            +
                def __init__(
         | 
| 23 | 
            +
                        self,
         | 
| 24 | 
            +
                        img_size=224,
         | 
| 25 | 
            +
                        patch_size=16,
         | 
| 26 | 
            +
                        in_chans=3,
         | 
| 27 | 
            +
                        embed_dim=768,
         | 
| 28 | 
            +
                        norm_layer=None,
         | 
| 29 | 
            +
                        flatten=True,
         | 
| 30 | 
            +
                        bias=True,
         | 
| 31 | 
            +
                ):
         | 
| 32 | 
            +
                    super().__init__()
         | 
| 33 | 
            +
                    if isinstance(img_size, int):
         | 
| 34 | 
            +
                        img_size = to_2tuple(img_size)
         | 
| 35 | 
            +
                    elif isinstance(img_size, (tuple, list)) and len(img_size) == 2:
         | 
| 36 | 
            +
                        img_size = tuple(img_size)
         | 
| 37 | 
            +
                    else:
         | 
| 38 | 
            +
                        raise ValueError(f"img_size must be int or tuple/list of length 2. Got {img_size}")
         | 
| 39 | 
            +
                    patch_size = to_2tuple(patch_size)
         | 
| 40 | 
            +
                    self.img_size = img_size
         | 
| 41 | 
            +
                    self.patch_size = patch_size
         | 
| 42 | 
            +
                    self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
         | 
| 43 | 
            +
                    self.num_patches = self.grid_size[0] * self.grid_size[1]
         | 
| 44 | 
            +
                    self.flatten = flatten
         | 
| 45 | 
            +
             | 
| 46 | 
            +
                    self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
         | 
| 47 | 
            +
                    self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
         | 
| 48 | 
            +
             | 
| 49 | 
            +
                def update_image_size(self, img_size):
         | 
| 50 | 
            +
                    self.img_size = img_size
         | 
| 51 | 
            +
                    self.grid_size = (img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1])
         | 
| 52 | 
            +
                    self.num_patches = self.grid_size[0] * self.grid_size[1]
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                def forward(self, x):
         | 
| 55 | 
            +
                    # B, C, H, W = x.shape
         | 
| 56 | 
            +
                    # _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
         | 
| 57 | 
            +
                    # _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
         | 
| 58 | 
            +
                    x = self.proj(x)
         | 
| 59 | 
            +
                    if self.flatten:
         | 
| 60 | 
            +
                        x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
         | 
| 61 | 
            +
                    x = self.norm(x)
         | 
| 62 | 
            +
                    return x
         | 
| 63 | 
            +
             | 
| 64 | 
            +
             | 
| 65 | 
            +
            def timestep_embedding(t, dim, max_period=10000, repeat_only=False):
         | 
| 66 | 
            +
                """
         | 
| 67 | 
            +
                Create sinusoidal timestep embeddings.
         | 
| 68 | 
            +
                :param t: a 1-D Tensor of N indices, one per batch element.
         | 
| 69 | 
            +
                                  These may be fractional.
         | 
| 70 | 
            +
                :param dim: the dimension of the output.
         | 
| 71 | 
            +
                :param max_period: controls the minimum frequency of the embeddings.
         | 
| 72 | 
            +
                :return: an (N, D) Tensor of positional embeddings.
         | 
| 73 | 
            +
                """
         | 
| 74 | 
            +
                # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
         | 
| 75 | 
            +
                if not repeat_only:
         | 
| 76 | 
            +
                    half = dim // 2
         | 
| 77 | 
            +
                    freqs = torch.exp(
         | 
| 78 | 
            +
                        -math.log(max_period)
         | 
| 79 | 
            +
                        * torch.arange(start=0, end=half, dtype=torch.float32)
         | 
| 80 | 
            +
                        / half
         | 
| 81 | 
            +
                ).to(device=t.device)   # size: [dim/2], an exponentially decaying frequency curve
         | 
| 82 | 
            +
                    args = t[:, None].float() * freqs[None]
         | 
| 83 | 
            +
                    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
         | 
| 84 | 
            +
                    if dim % 2:
         | 
| 85 | 
            +
                        embedding = torch.cat(
         | 
| 86 | 
            +
                            [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
         | 
| 87 | 
            +
                        )
         | 
| 88 | 
            +
                else:
         | 
| 89 | 
            +
                    embedding = repeat(t, "b -> b d", d=dim)
         | 
| 90 | 
            +
                return embedding
         | 
| 91 | 
            +
             | 
| 92 | 
            +
             | 
| 93 | 
            +
            class TimestepEmbedder(nn.Module):
         | 
| 94 | 
            +
                """
         | 
| 95 | 
            +
                Embeds scalar timesteps into vector representations.
         | 
| 96 | 
            +
                """
         | 
| 97 | 
            +
                def __init__(self, hidden_size, frequency_embedding_size=256, out_size=None):
         | 
| 98 | 
            +
                    super().__init__()
         | 
| 99 | 
            +
                    if out_size is None:
         | 
| 100 | 
            +
                        out_size = hidden_size
         | 
| 101 | 
            +
                    self.mlp = nn.Sequential(
         | 
| 102 | 
            +
                        nn.Linear(frequency_embedding_size, hidden_size, bias=True),
         | 
| 103 | 
            +
                        nn.SiLU(),
         | 
| 104 | 
            +
                        nn.Linear(hidden_size, out_size, bias=True),
         | 
| 105 | 
            +
                    )
         | 
| 106 | 
            +
                    self.frequency_embedding_size = frequency_embedding_size
         | 
| 107 | 
            +
             | 
| 108 | 
            +
                def forward(self, t):
         | 
| 109 | 
            +
                    t_freq = timestep_embedding(t, self.frequency_embedding_size).type(self.mlp[0].weight.dtype)
         | 
| 110 | 
            +
                    t_emb = self.mlp(t_freq)
         | 
| 111 | 
            +
                    return t_emb
         | 
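A quick sanity check of the sinusoidal timestep_embedding above (a sketch; same import assumption as before):

    import torch
    from hydit.modules.embedders import timestep_embedding

    t = torch.tensor([0.0, 10.0, 999.0])
    emb = timestep_embedding(t, dim=256)
    print(emb.shape)    # torch.Size([3, 256])
    # cos comes first in the concatenation, so at t=0 the first half of the
    # embedding is all ones and the second (sin) half is all zeros.
    assert torch.allclose(emb[0, :128], torch.ones(128))
    assert torch.allclose(emb[0, 128:], torch.zeros(128))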
    	
        hydit/modules/models.py
    ADDED
    
    | @@ -0,0 +1,409 @@ | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import torch.nn as nn
         | 
| 3 | 
            +
            import torch.nn.functional as F
         | 
| 4 | 
            +
            from diffusers.configuration_utils import ConfigMixin, register_to_config
         | 
| 5 | 
            +
            from diffusers.models import ModelMixin
         | 
| 6 | 
            +
            from timm.models.vision_transformer import Mlp
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            from .attn_layers import Attention, FlashCrossMHAModified, FlashSelfMHAModified, CrossAttention
         | 
| 9 | 
            +
            from .embedders import TimestepEmbedder, PatchEmbed, timestep_embedding
         | 
| 10 | 
            +
            from .norm_layers import RMSNorm
         | 
| 11 | 
            +
            from .poolers import AttentionPool
         | 
| 12 | 
            +
             | 
| 13 | 
            +
             | 
| 14 | 
            +
            def modulate(x, shift, scale):
         | 
| 15 | 
            +
                return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
         | 
| 16 | 
            +
             | 
| 17 | 
            +
             | 
| 18 | 
            +
            class FP32_Layernorm(nn.LayerNorm):
         | 
| 19 | 
            +
                def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         | 
| 20 | 
            +
                    origin_dtype = inputs.dtype
         | 
| 21 | 
            +
                    return F.layer_norm(inputs.float(), self.normalized_shape, self.weight.float(), self.bias.float(),
         | 
| 22 | 
            +
                                        self.eps).to(origin_dtype)
         | 
| 23 | 
            +
             | 
| 24 | 
            +
             | 
| 25 | 
            +
            class FP32_SiLU(nn.SiLU):
         | 
| 26 | 
            +
                def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         | 
| 27 | 
            +
                    return torch.nn.functional.silu(inputs.float(), inplace=False).to(inputs.dtype)
         | 
| 28 | 
            +
             | 
| 29 | 
            +
             | 
| 30 | 
            +
            class HunYuanDiTBlock(nn.Module):
         | 
| 31 | 
            +
                """
         | 
| 32 | 
            +
                A HunYuanDiT block with `add` conditioning.
         | 
| 33 | 
            +
                """
         | 
| 34 | 
            +
                def __init__(self,
         | 
| 35 | 
            +
                             hidden_size,
         | 
| 36 | 
            +
                             c_emb_size,
         | 
| 37 | 
            +
                             num_heads,
         | 
| 38 | 
            +
                             mlp_ratio=4.0,
         | 
| 39 | 
            +
                             text_states_dim=1024,
         | 
| 40 | 
            +
                             use_flash_attn=False,
         | 
| 41 | 
            +
                             qk_norm=False,
         | 
| 42 | 
            +
                             norm_type="layer",
         | 
| 43 | 
            +
                             skip=False,
         | 
| 44 | 
            +
                             ):
         | 
| 45 | 
            +
                    super().__init__()
         | 
| 46 | 
            +
                    self.use_flash_attn = use_flash_attn
         | 
| 47 | 
            +
                    use_ele_affine = True
         | 
| 48 | 
            +
             | 
| 49 | 
            +
                    if norm_type == "layer":
         | 
| 50 | 
            +
                        norm_layer = FP32_Layernorm
         | 
| 51 | 
            +
                    elif norm_type == "rms":
         | 
| 52 | 
            +
                        norm_layer = RMSNorm
         | 
| 53 | 
            +
                    else:
         | 
| 54 | 
            +
                        raise ValueError(f"Unknown norm_type: {norm_type}")
         | 
| 55 | 
            +
             | 
| 56 | 
            +
                    # ========================= Self-Attention =========================
         | 
| 57 | 
            +
                    self.norm1 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6)
         | 
| 58 | 
            +
                    if use_flash_attn:
         | 
| 59 | 
            +
                        self.attn1 = FlashSelfMHAModified(hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=qk_norm)
         | 
| 60 | 
            +
                    else:
         | 
| 61 | 
            +
                        self.attn1 = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=qk_norm)
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                    # ========================= FFN =========================
         | 
| 64 | 
            +
                    self.norm2 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6)
         | 
| 65 | 
            +
                    mlp_hidden_dim = int(hidden_size * mlp_ratio)
         | 
| 66 | 
            +
                    approx_gelu = lambda: nn.GELU(approximate="tanh")
         | 
| 67 | 
            +
                    self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                    # ========================= Add =========================
         | 
| 70 | 
            +
                    # Simply use add like SDXL.
         | 
| 71 | 
            +
                    self.default_modulation = nn.Sequential(
         | 
| 72 | 
            +
                        FP32_SiLU(),
         | 
| 73 | 
            +
                        nn.Linear(c_emb_size, hidden_size, bias=True)
         | 
| 74 | 
            +
                    )
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                    # ========================= Cross-Attention =========================
         | 
| 77 | 
            +
                    if use_flash_attn:
         | 
| 78 | 
            +
                        self.attn2 = FlashCrossMHAModified(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=True,
         | 
| 79 | 
            +
                                                           qk_norm=qk_norm)
         | 
| 80 | 
            +
                    else:
         | 
| 81 | 
            +
                        self.attn2 = CrossAttention(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=True,
         | 
| 82 | 
            +
                                                    qk_norm=qk_norm)
         | 
| 83 | 
            +
                    self.norm3 = norm_layer(hidden_size, elementwise_affine=True, eps=1e-6)
         | 
| 84 | 
            +
             | 
| 85 | 
            +
                    # ========================= Skip Connection =========================
         | 
| 86 | 
            +
                    if skip:
         | 
| 87 | 
            +
                        self.skip_norm = norm_layer(2 * hidden_size, elementwise_affine=True, eps=1e-6)
         | 
| 88 | 
            +
                        self.skip_linear = nn.Linear(2 * hidden_size, hidden_size)
         | 
| 89 | 
            +
                    else:
         | 
| 90 | 
            +
                        self.skip_linear = None
         | 
| 91 | 
            +
             | 
| 92 | 
            +
                def forward(self, x, c=None, text_states=None, freq_cis_img=None, skip=None):
         | 
| 93 | 
            +
                    # Long Skip Connection
         | 
| 94 | 
            +
                    if self.skip_linear is not None:
         | 
| 95 | 
            +
                        cat = torch.cat([x, skip], dim=-1)
         | 
| 96 | 
            +
                        cat = self.skip_norm(cat)
         | 
| 97 | 
            +
                        x = self.skip_linear(cat)
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                    # Self-Attention
         | 
| 100 | 
            +
                    shift_msa = self.default_modulation(c).unsqueeze(dim=1)
         | 
| 101 | 
            +
                    attn_inputs = (
         | 
| 102 | 
            +
                        self.norm1(x) + shift_msa, freq_cis_img,
         | 
| 103 | 
            +
                    )
         | 
| 104 | 
            +
                    x = x + self.attn1(*attn_inputs)[0]
         | 
| 105 | 
            +
             | 
| 106 | 
            +
                    # Cross-Attention
         | 
| 107 | 
            +
                    cross_inputs = (
         | 
| 108 | 
            +
                        self.norm3(x), text_states, freq_cis_img
         | 
| 109 | 
            +
                    )
         | 
| 110 | 
            +
                    x = x + self.attn2(*cross_inputs)[0]
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                    # FFN Layer
         | 
| 113 | 
            +
                    mlp_inputs = self.norm2(x)
         | 
| 114 | 
            +
                    x = x + self.mlp(mlp_inputs)
         | 
| 115 | 
            +
             | 
| 116 | 
            +
                    return x
         | 
| 117 | 
            +
             | 
| 118 | 
            +
             | 
| 119 | 
            +
            class FinalLayer(nn.Module):
         | 
| 120 | 
            +
                """
         | 
| 121 | 
            +
                The final layer of HunYuanDiT.
         | 
| 122 | 
            +
                """
         | 
| 123 | 
            +
                def __init__(self, final_hidden_size, c_emb_size, patch_size, out_channels):
         | 
| 124 | 
            +
                    super().__init__()
         | 
| 125 | 
            +
                    self.norm_final = nn.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6)
         | 
| 126 | 
            +
                    self.linear = nn.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True)
         | 
| 127 | 
            +
                    self.adaLN_modulation = nn.Sequential(
         | 
| 128 | 
            +
                        FP32_SiLU(),
         | 
| 129 | 
            +
                        nn.Linear(c_emb_size, 2 * final_hidden_size, bias=True)
         | 
| 130 | 
            +
                    )
         | 
| 131 | 
            +
             | 
| 132 | 
            +
                def forward(self, x, c):
         | 
| 133 | 
            +
                    shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
         | 
| 134 | 
            +
                    x = modulate(self.norm_final(x), shift, scale)
         | 
| 135 | 
            +
                    x = self.linear(x)
         | 
| 136 | 
            +
                    return x
         | 
| 137 | 
            +
             | 
| 138 | 
            +
             | 
| 139 | 
            +
            class HunYuanDiT(ModelMixin, ConfigMixin):
         | 
| 140 | 
            +
                """
         | 
| 141 | 
            +
                HunYuanDiT: Diffusion model with a Transformer backbone.
         | 
| 142 | 
            +
             | 
| 143 | 
            +
            Inherits ModelMixin and ConfigMixin to be compatible with the diffusers StableDiffusionPipeline sampler.
         | 
| 144 | 
            +
             | 
| 145 | 
            +
                Parameters
         | 
| 146 | 
            +
                ----------
         | 
| 147 | 
            +
                args: argparse.Namespace
         | 
| 148 | 
            +
                    The arguments parsed by argparse.
         | 
| 149 | 
            +
                input_size: tuple
         | 
| 150 | 
            +
                    The size of the input image.
         | 
| 151 | 
            +
                patch_size: int
         | 
| 152 | 
            +
                    The size of the patch.
         | 
| 153 | 
            +
                in_channels: int
         | 
| 154 | 
            +
                    The number of input channels.
         | 
| 155 | 
            +
                hidden_size: int
         | 
| 156 | 
            +
                    The hidden size of the transformer backbone.
         | 
| 157 | 
            +
                depth: int
         | 
| 158 | 
            +
                    The number of transformer blocks.
         | 
| 159 | 
            +
                num_heads: int
         | 
| 160 | 
            +
                    The number of attention heads.
         | 
| 161 | 
            +
                mlp_ratio: float
         | 
| 162 | 
            +
                The ratio of the MLP hidden size to the transformer hidden size.
         | 
| 163 | 
            +
                log_fn: callable
         | 
| 164 | 
            +
                    The logging function.
         | 
| 165 | 
            +
                """
         | 
| 166 | 
            +
                @register_to_config
         | 
| 167 | 
            +
                def __init__(
         | 
| 168 | 
            +
                        self, args,
         | 
| 169 | 
            +
                        input_size=(32, 32),
         | 
| 170 | 
            +
                        patch_size=2,
         | 
| 171 | 
            +
                        in_channels=4,
         | 
| 172 | 
            +
                        hidden_size=1152,
         | 
| 173 | 
            +
                        depth=28,
         | 
| 174 | 
            +
                        num_heads=16,
         | 
| 175 | 
            +
                        mlp_ratio=4.0,
         | 
| 176 | 
            +
                        log_fn=print,
         | 
| 177 | 
            +
                ):
         | 
| 178 | 
            +
                    super().__init__()
         | 
| 179 | 
            +
                    self.args = args
         | 
| 180 | 
            +
                    self.log_fn = log_fn
         | 
| 181 | 
            +
                    self.depth = depth
         | 
| 182 | 
            +
                    self.learn_sigma = args.learn_sigma
         | 
| 183 | 
            +
                    self.in_channels = in_channels
         | 
| 184 | 
            +
                    self.out_channels = in_channels * 2 if args.learn_sigma else in_channels
         | 
| 185 | 
            +
                    self.patch_size = patch_size
         | 
| 186 | 
            +
                    self.num_heads = num_heads
         | 
| 187 | 
            +
                    self.hidden_size = hidden_size
         | 
| 188 | 
            +
                    self.text_states_dim = args.text_states_dim
         | 
| 189 | 
            +
                    self.text_states_dim_t5 = args.text_states_dim_t5
         | 
| 190 | 
            +
                    self.text_len = args.text_len
         | 
| 191 | 
            +
                    self.text_len_t5 = args.text_len_t5
         | 
| 192 | 
            +
                    self.norm = args.norm
         | 
| 193 | 
            +
             | 
| 194 | 
            +
                    use_flash_attn = args.infer_mode == 'fa'
         | 
| 195 | 
            +
                    if use_flash_attn:
         | 
| 196 | 
            +
                    log_fn("    Enable Flash Attention.")
         | 
| 197 | 
            +
                    qk_norm = True  # See http://arxiv.org/abs/2302.05442 for details.
         | 
| 198 | 
            +
             | 
| 199 | 
            +
                    self.mlp_t5 = nn.Sequential(
         | 
| 200 | 
            +
                        nn.Linear(self.text_states_dim_t5, self.text_states_dim_t5 * 4, bias=True),
         | 
| 201 | 
            +
                        FP32_SiLU(),
         | 
| 202 | 
            +
                        nn.Linear(self.text_states_dim_t5 * 4, self.text_states_dim, bias=True),
         | 
| 203 | 
            +
                    )
         | 
| 204 | 
            +
                # Learnable padding embeddings, used in forward to replace masked text tokens
         | 
| 205 | 
            +
                    self.text_embedding_padding = nn.Parameter(
         | 
| 206 | 
            +
                        torch.randn(self.text_len + self.text_len_t5, self.text_states_dim, dtype=torch.float32))
         | 
| 207 | 
            +
             | 
| 208 | 
            +
                    # Attention pooling
         | 
| 209 | 
            +
                    self.pooler = AttentionPool(self.text_len_t5, self.text_states_dim_t5, num_heads=8, output_dim=1024)
         | 
| 210 | 
            +
             | 
| 211 | 
            +
                    # Here we use a default learned embedder layer for future extension.
         | 
| 212 | 
            +
                    self.style_embedder = nn.Embedding(1, hidden_size)
         | 
| 213 | 
            +
             | 
| 214 | 
            +
                    # Image size and crop size conditions
         | 
| 215 | 
            +
                    self.extra_in_dim = 256 * 6 + hidden_size
         | 
| 216 | 
            +
             | 
| 217 | 
            +
                    # Text embedding for `add`
         | 
| 218 | 
            +
                    self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size)
         | 
| 219 | 
            +
                    self.t_embedder = TimestepEmbedder(hidden_size)
         | 
| 220 | 
            +
                    self.extra_in_dim += 1024
         | 
| 221 | 
            +
                    self.extra_embedder = nn.Sequential(
         | 
| 222 | 
            +
                        nn.Linear(self.extra_in_dim, hidden_size * 4),
         | 
| 223 | 
            +
                        FP32_SiLU(),
         | 
| 224 | 
            +
                        nn.Linear(hidden_size * 4, hidden_size, bias=True),
         | 
| 225 | 
            +
                    )
         | 
| 226 | 
            +
             | 
| 227 | 
            +
                    # Image embedding
         | 
| 228 | 
            +
                    num_patches = self.x_embedder.num_patches
         | 
| 229 | 
            +
                    log_fn(f"    Number of tokens: {num_patches}")
         | 
| 230 | 
            +
             | 
| 231 | 
            +
                # HunYuanDiT blocks
         | 
| 232 | 
            +
                    self.blocks = nn.ModuleList([
         | 
| 233 | 
            +
                        HunYuanDiTBlock(hidden_size=hidden_size,
         | 
| 234 | 
            +
                                        c_emb_size=hidden_size,
         | 
| 235 | 
            +
                                        num_heads=num_heads,
         | 
| 236 | 
            +
                                        mlp_ratio=mlp_ratio,
         | 
| 237 | 
            +
                                        text_states_dim=self.text_states_dim,
         | 
| 238 | 
            +
                                        use_flash_attn=use_flash_attn,
         | 
| 239 | 
            +
                                        qk_norm=qk_norm,
         | 
| 240 | 
            +
                                        norm_type=self.norm,
         | 
| 241 | 
            +
                                        skip=layer > depth // 2,
         | 
| 242 | 
            +
                                        )
         | 
| 243 | 
            +
                        for layer in range(depth)
         | 
| 244 | 
            +
                    ])
         | 
| 245 | 
            +
             | 
| 246 | 
            +
                    self.final_layer = FinalLayer(hidden_size, hidden_size, patch_size, self.out_channels)
         | 
| 247 | 
            +
                    self.unpatchify_channels = self.out_channels
         | 
| 248 | 
            +
             | 
| 249 | 
            +
                    self.initialize_weights()
         | 
| 250 | 
            +
             | 
| 251 | 
            +
                def forward(self,
         | 
| 252 | 
            +
                            x,
         | 
| 253 | 
            +
                            t,
         | 
| 254 | 
            +
                            encoder_hidden_states=None,
         | 
| 255 | 
            +
                            text_embedding_mask=None,
         | 
| 256 | 
            +
                            encoder_hidden_states_t5=None,
         | 
| 257 | 
            +
                            text_embedding_mask_t5=None,
         | 
| 258 | 
            +
                            image_meta_size=None,
         | 
| 259 | 
            +
                            style=None,
         | 
| 260 | 
            +
                            cos_cis_img=None,
         | 
| 261 | 
            +
                            sin_cis_img=None,
         | 
| 262 | 
            +
                            return_dict=True,
         | 
| 263 | 
            +
                            ):
         | 
| 264 | 
            +
                    """
         | 
| 265 | 
            +
                    Forward pass of the encoder.
         | 
| 266 | 
            +
             | 
| 267 | 
            +
                    Parameters
         | 
| 268 | 
            +
                    ----------
         | 
| 269 | 
            +
                    x: torch.Tensor
         | 
| 270 | 
            +
                        (B, D, H, W)
         | 
| 271 | 
            +
                    t: torch.Tensor
         | 
| 272 | 
            +
                        (B)
         | 
| 273 | 
            +
                    encoder_hidden_states: torch.Tensor
         | 
| 274 | 
            +
                        CLIP text embedding, (B, L_clip, D)
         | 
| 275 | 
            +
                    text_embedding_mask: torch.Tensor
         | 
| 276 | 
            +
                        CLIP text embedding mask, (B, L_clip)
         | 
| 277 | 
            +
                    encoder_hidden_states_t5: torch.Tensor
         | 
| 278 | 
            +
                        T5 text embedding, (B, L_t5, D)
         | 
| 279 | 
            +
                    text_embedding_mask_t5: torch.Tensor
         | 
| 280 | 
            +
                        T5 text embedding mask, (B, L_t5)
         | 
| 281 | 
            +
                    image_meta_size: torch.Tensor
         | 
| 282 | 
            +
                        (B, 6)
         | 
| 283 | 
            +
                    style: torch.Tensor
         | 
| 284 | 
            +
                        (B)
         | 
| 285 | 
            +
                    cos_cis_img: torch.Tensor
         | 
| 286 | 
            +
                    sin_cis_img: torch.Tensor
         | 
| 287 | 
            +
                    return_dict: bool
         | 
| 288 | 
            +
                        Whether to return a dictionary.
         | 
| 289 | 
            +
                    """
         | 
| 290 | 
            +
             | 
| 291 | 
            +
                    text_states = encoder_hidden_states                     # 2,77,1024
         | 
| 292 | 
            +
                    text_states_t5 = encoder_hidden_states_t5               # 2,256,2048
         | 
| 293 | 
            +
                    text_states_mask = text_embedding_mask.bool()           # 2,77
         | 
| 294 | 
            +
                    text_states_t5_mask = text_embedding_mask_t5.bool()     # 2,256
         | 
| 295 | 
            +
                    b_t5, l_t5, c_t5 = text_states_t5.shape
         | 
| 296 | 
            +
                    text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5))
         | 
| 297 | 
            +
                    text_states = torch.cat([text_states, text_states_t5.view(b_t5, l_t5, -1)], dim=1)  # 2,205,1024
         | 
| 298 | 
            +
                    clip_t5_mask = torch.cat([text_states_mask, text_states_t5_mask], dim=-1)
         | 
| 299 | 
            +
             | 
| 300 | 
            +
                # clip_t5_mask: [B, L_clip + L_t5]
         | 
| 301 | 
            +
                    text_states = torch.where(clip_t5_mask.unsqueeze(2), text_states, self.text_embedding_padding.to(text_states))
         | 
| 302 | 
            +
             | 
| 303 | 
            +
                    _, _, oh, ow = x.shape
         | 
| 304 | 
            +
                    th, tw = oh // self.patch_size, ow // self.patch_size
         | 
| 305 | 
            +
             | 
| 306 | 
            +
                    # ========================= Build time and image embedding =========================
         | 
| 307 | 
            +
                    t = self.t_embedder(t)
         | 
| 308 | 
            +
                    x = self.x_embedder(x)
         | 
| 309 | 
            +
             | 
| 310 | 
            +
                # Get image RoPE embeddings according to the input resolution.
         | 
| 311 | 
            +
                    freqs_cis_img = (cos_cis_img, sin_cis_img)
         | 
| 312 | 
            +
             | 
| 313 | 
            +
                    # ========================= Concatenate all extra vectors =========================
         | 
| 314 | 
            +
                    # Build text tokens with pooling
         | 
| 315 | 
            +
                    extra_vec = self.pooler(encoder_hidden_states_t5)
         | 
| 316 | 
            +
             | 
| 317 | 
            +
                    # Build image meta size tokens
         | 
| 318 | 
            +
                    image_meta_size = timestep_embedding(image_meta_size.view(-1), 256)   # [B * 6, 256]
         | 
| 319 | 
            +
                    if self.args.use_fp16:
         | 
| 320 | 
            +
                        image_meta_size = image_meta_size.half()
         | 
| 321 | 
            +
                    image_meta_size = image_meta_size.view(-1, 6 * 256)
         | 
| 322 | 
            +
                    extra_vec = torch.cat([extra_vec, image_meta_size], dim=1)  # [B, D + 6 * 256]
         | 
| 323 | 
            +
             | 
| 324 | 
            +
                    # Build style tokens
         | 
| 325 | 
            +
                    style_embedding = self.style_embedder(style)
         | 
| 326 | 
            +
                    extra_vec = torch.cat([extra_vec, style_embedding], dim=1)
         | 
| 327 | 
            +
             | 
| 328 | 
            +
                    # Concatenate all extra vectors
         | 
| 329 | 
            +
                    c = t + self.extra_embedder(extra_vec)  # [B, D]
         | 
| 330 | 
            +
             | 
| 331 | 
            +
                    # ========================= Forward pass through HunYuanDiT blocks =========================
         | 
| 332 | 
            +
                    skips = []
         | 
| 333 | 
            +
                    for layer, block in enumerate(self.blocks):
         | 
| 334 | 
            +
                        if layer > self.depth // 2:
         | 
| 335 | 
            +
                            skip = skips.pop()
         | 
| 336 | 
            +
                            x = block(x, c, text_states, freqs_cis_img, skip)   # (N, L, D)
         | 
| 337 | 
            +
                        else:
         | 
| 338 | 
            +
                            x = block(x, c, text_states, freqs_cis_img)         # (N, L, D)
         | 
| 339 | 
            +
             | 
| 340 | 
            +
                        if layer < (self.depth // 2 - 1):
         | 
| 341 | 
            +
                            skips.append(x)
         | 
| 342 | 
            +
             | 
| 343 | 
            +
                    # ========================= Final layer =========================
         | 
| 344 | 
            +
                    x = self.final_layer(x, c)                              # (N, L, patch_size ** 2 * out_channels)
         | 
| 345 | 
            +
                    x = self.unpatchify(x, th, tw)                          # (N, out_channels, H, W)
         | 
| 346 | 
            +
             | 
| 347 | 
            +
                    if return_dict:
         | 
| 348 | 
            +
                        return {'x': x}
         | 
| 349 | 
            +
                    return x
         | 
| 350 | 
            +
             | 
| 351 | 
            +
                def initialize_weights(self):
         | 
| 352 | 
            +
                    # Initialize transformer layers:
         | 
| 353 | 
            +
                    def _basic_init(module):
         | 
| 354 | 
            +
                        if isinstance(module, nn.Linear):
         | 
| 355 | 
            +
                            torch.nn.init.xavier_uniform_(module.weight)
         | 
| 356 | 
            +
                            if module.bias is not None:
         | 
| 357 | 
            +
                                nn.init.constant_(module.bias, 0)
         | 
| 358 | 
            +
                    self.apply(_basic_init)
         | 
| 359 | 
            +
             | 
| 360 | 
            +
                    # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
         | 
| 361 | 
            +
                    w = self.x_embedder.proj.weight.data
         | 
| 362 | 
            +
                    nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
         | 
| 363 | 
            +
                    nn.init.constant_(self.x_embedder.proj.bias, 0)
         | 
| 364 | 
            +
             | 
| 365 | 
            +
                # Initialize the extra-condition embedder:
         | 
| 366 | 
            +
                    nn.init.normal_(self.extra_embedder[0].weight, std=0.02)
         | 
| 367 | 
            +
                    nn.init.normal_(self.extra_embedder[2].weight, std=0.02)
         | 
| 368 | 
            +
             | 
| 369 | 
            +
                    # Initialize timestep embedding MLP:
         | 
| 370 | 
            +
                    nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
         | 
| 371 | 
            +
                    nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
         | 
| 372 | 
            +
             | 
| 373 | 
            +
                    # Zero-out adaLN modulation layers in HunYuanDiT blocks:
         | 
| 374 | 
            +
                    for block in self.blocks:
         | 
| 375 | 
            +
                        nn.init.constant_(block.default_modulation[-1].weight, 0)
         | 
| 376 | 
            +
                        nn.init.constant_(block.default_modulation[-1].bias, 0)
         | 
| 377 | 
            +
             | 
| 378 | 
            +
                    # Zero-out output layers:
         | 
| 379 | 
            +
                    nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
         | 
| 380 | 
            +
                    nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
         | 
| 381 | 
            +
                    nn.init.constant_(self.final_layer.linear.weight, 0)
         | 
| 382 | 
            +
                    nn.init.constant_(self.final_layer.linear.bias, 0)
         | 
| 383 | 
            +
             | 
| 384 | 
            +
                def unpatchify(self, x, h, w):
         | 
| 385 | 
            +
                    """
         | 
| 386 | 
            +
                    x: (N, T, patch_size**2 * C)
         | 
| 387 | 
            +
                imgs: (N, C, H, W)
         | 
| 388 | 
            +
                    """
         | 
| 389 | 
            +
                    c = self.unpatchify_channels
         | 
| 390 | 
            +
                    p = self.x_embedder.patch_size[0]
         | 
| 391 | 
            +
                    # h = w = int(x.shape[1] ** 0.5)
         | 
| 392 | 
            +
                    assert h * w == x.shape[1]
         | 
| 393 | 
            +
             | 
| 394 | 
            +
                    x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
         | 
| 395 | 
            +
                    x = torch.einsum('nhwpqc->nchpwq', x)
         | 
| 396 | 
            +
                    imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
         | 
| 397 | 
            +
                    return imgs
         | 
| 398 | 
            +
             | 
| 399 | 
            +
             | 
| 400 | 
            +
            #################################################################################
         | 
| 401 | 
            +
            #                            HunYuanDiT Configs                                 #
         | 
| 402 | 
            +
            #################################################################################
         | 
| 403 | 
            +
             | 
| 404 | 
            +
            HUNYUAN_DIT_CONFIG = {
         | 
| 405 | 
            +
                'DiT-g/2': {'depth': 40, 'hidden_size': 1408, 'patch_size': 2, 'num_heads': 16, 'mlp_ratio': 4.3637},
         | 
| 406 | 
            +
                'DiT-XL/2': {'depth': 28, 'hidden_size': 1152, 'patch_size': 2, 'num_heads': 16},
         | 
| 407 | 
            +
                'DiT-L/2': {'depth': 24, 'hidden_size': 1024, 'patch_size': 2, 'num_heads': 16},
         | 
| 408 | 
            +
                'DiT-B/2': {'depth': 12, 'hidden_size': 768, 'patch_size': 2, 'num_heads': 12},
         | 
| 409 | 
            +
            }
         | 
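The long-skip bookkeeping in HunYuanDiT.forward above is easy to misread, so here is the push/pop schedule in isolation (a sketch that mirrors the loop, not code from the upload), using depth=28 as in DiT-XL/2:

    depth = 28
    skips = []
    for layer in range(depth):
        if layer > depth // 2:        # blocks constructed with skip=True
            src = skips.pop()
            print(f"layer {layer:2d} consumes the output of layer {src:2d}")
        if layer < depth // 2 - 1:    # early blocks stash their outputs
            skips.append(layer)

Layers 0-12 push and layers 15-27 pop in LIFO order, so block i is paired with block depth - 1 - i in the U-Net style; the two middle blocks (13 and 14) neither push nor pop.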
    	
        hydit/modules/norm_layers.py
    ADDED
    
    | @@ -0,0 +1,68 @@ | |
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    def __init__(self, dim: int, elementwise_affine=True, eps: float = 1e-6):
        """
        Initialize the RMSNorm normalization layer.

        Args:
            dim (int): The dimension of the input tensor.
            elementwise_affine (bool, optional): If True, add a learnable per-element
                scaling parameter. Default is True.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (nn.Parameter): Learnable scaling parameter; only present when
                elementwise_affine is True.
        """
        super().__init__()
        self.eps = eps
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """
        Apply the RMSNorm normalization to the input tensor.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The normalized tensor.
        """
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """
        Forward pass through the RMSNorm layer.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after applying RMSNorm.
        """
        # Normalize in float32 for stability, then cast back to the input dtype.
        output = self._norm(x.float()).type_as(x)
        if hasattr(self, "weight"):
            output = output * self.weight
        return output


class GroupNorm32(nn.GroupNorm):
    def __init__(self, num_groups, num_channels, eps=1e-5, dtype=None):
        super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps, dtype=dtype)

    def forward(self, x):
        # Cast the result back to the input dtype (e.g. when running in fp16).
        y = super().forward(x).to(x.dtype)
        return y


def normalization(channels, dtype=None):
    """
    Make a standard normalization layer.
    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    return GroupNorm32(num_channels=channels, num_groups=32, dtype=dtype)
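
A minimal usage sketch for the layers above (the tensor shapes are illustrative, not fixed by this file): RMSNorm normalizes over the last dimension, while normalization() builds a 32-group GroupNorm, so its channel count must be divisible by 32.

import torch

rms = RMSNorm(dim=64)
tokens = torch.randn(2, 16, 64)            # (batch, sequence, channels)
print(rms(tokens).shape)                   # torch.Size([2, 16, 64])

gn = normalization(channels=64)            # 64 channels split into 32 groups
feats = torch.randn(2, 64, 8, 8)           # (batch, channels, height, width)
print(gn(feats).shape)                     # torch.Size([2, 64, 8, 8])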
    	
        hydit/modules/poolers.py
    ADDED
    
@@ -0,0 +1,39 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class AttentionPool(nn.Module):
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        x = x.permute(1, 0, 2)  # NLC -> LNC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (L+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (L+1)NC
        x, _ = F.multi_head_attention_forward(
            query=x[:1], key=x, value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )
        return x.squeeze(0)
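
A minimal usage sketch for AttentionPool: the sequence mean is prepended as a query token, it attends over the full sequence, and the attended result is projected to one vector per sample. The sizes below are illustrative assumptions; note that spacial_dim must equal the input sequence length, since the positional embedding has spacial_dim + 1 entries (one extra for the prepended mean token).

import torch

pool = AttentionPool(spacial_dim=77, embed_dim=1024, num_heads=8, output_dim=1024)
x = torch.randn(4, 77, 1024)               # (batch, sequence length, embed_dim)
print(pool(x).shape)                       # torch.Size([4, 1024])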
    	
        hydit/modules/posemb_layers.py
    ADDED
    
@@ -0,0 +1,225 @@
import torch
import numpy as np
from typing import Union


def _to_tuple(x):
    if isinstance(x, int):
        return x, x
    else:
        return x


def get_fill_resize_and_crop(src, tgt):    # src: the source resolution; tgt: the base resolution
    th, tw = _to_tuple(tgt)
    h, w = _to_tuple(src)

    tr = th / tw        # aspect ratio of the base resolution
    r = h / w           # aspect ratio of the target resolution

    # resize
    if r > tr:
        resize_height = th
        resize_width = int(round(th / h * w))
    else:
        resize_width = tw
        resize_height = int(round(tw / w * h))    # resize the target resolution down to fit the base resolution

    crop_top = int(round((th - resize_height) / 2.0))
    crop_left = int(round((tw - resize_width) / 2.0))

    return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)

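
# Illustrative check of get_fill_resize_and_crop (an editor's sketch, safe to
# remove): fitting a 64x128 source grid into a 32x32 base grid preserves the
# aspect ratio, resizing to 16x32 and centering it vertically in the base grid.
assert get_fill_resize_and_crop((64, 128), 32) == ((8, 0), (24, 32))
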
def get_meshgrid(start, *args):
    if len(args) == 0:
        # start is grid_size
        num = _to_tuple(start)
        start = (0, 0)
        stop = num
    elif len(args) == 1:
        # start is start, args[0] is stop, step is 1
        start = _to_tuple(start)
        stop = _to_tuple(args[0])
        num = (stop[0] - start[0], stop[1] - start[1])
    elif len(args) == 2:
        # start is start, args[0] is stop, args[1] is num
        start = _to_tuple(start)       # top-left corner, e.g. (12, 0)
        stop = _to_tuple(args[0])      # bottom-right corner, e.g. (20, 32)
        num = _to_tuple(args[1])       # target size, e.g. (32, 124)
    else:
        raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")

    grid_h = np.linspace(start[0], stop[0], num[0], endpoint=False, dtype=np.float32)  # e.g. 32 samples between 12 and 20, or 124 samples between 0 and 32
    grid_w = np.linspace(start[1], stop[1], num[1], endpoint=False, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)   # [2, H, W]
    return grid


#################################################################################
#                   Sine/Cosine Positional Embedding Functions                  #
#################################################################################
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py

def get_2d_sincos_pos_embed(embed_dim, start, *args, cls_token=False, extra_tokens=0):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [extra_tokens+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid = get_meshgrid(start, *args)   # [2, H, W]

    grid = grid.reshape([2, 1, *grid.shape[1:]])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)    # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (W,H)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)   # (M, D/2)
    emb_cos = np.cos(out)   # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb

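
# Illustrative shape check for the sine/cosine embeddings above (an editor's
# sketch, safe to remove): a 4x4 grid with embed_dim=16 yields one 16-dim
# vector per grid position.
assert get_2d_sincos_pos_embed(16, 4).shape == (16, 16)  # (H*W, D)
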
#################################################################################
#                   Rotary Positional Embedding Functions                       #
#################################################################################
# https://github.com/facebookresearch/llama/blob/main/llama/model.py#L443

def get_2d_rotary_pos_embed(embed_dim, start, *args, use_real=True):
    """
    This is a 2d version of precompute_freqs_cis, which is a RoPE for image tokens with 2d structure.

    Parameters
    ----------
    embed_dim: int
        embedding dimension size
    start: int or tuple of int
        If len(args) == 0, start is num; if len(args) == 1, start is start, args[0] is stop, step is 1;
        if len(args) == 2, start is start, args[0] is stop, args[1] is num.
    use_real: bool
        If True, return real part and imaginary part separately. Otherwise, return complex numbers.

    Returns
    -------
    pos_embed: torch.Tensor
        [HW, D/2] complex tensor, or a (cos, sin) tuple of [HW, D] real tensors if use_real is True.
    """
    grid = get_meshgrid(start, *args)   # [2, H, W]
    grid = grid.reshape([2, 1, *grid.shape[1:]])   # a sampling grid whose resolution matches the target resolution
    pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
    return pos_embed


def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
    assert embed_dim % 4 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real)  # (H*W, D/4)
    emb_w = get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real)  # (H*W, D/4)

    if use_real:
        cos = torch.cat([emb_h[0], emb_w[0]], dim=1)    # (H*W, D/2)
        sin = torch.cat([emb_h[1], emb_w[1]], dim=1)    # (H*W, D/2)
        return cos, sin
    else:
        emb = torch.cat([emb_h, emb_w], dim=1)    # (H*W, D/2)
        return emb


def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False):
    """
    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
    and the position indices 'pos'. The 'theta' parameter scales the frequencies.
    The returned tensor contains complex values in complex64 data type.

    Args:
        dim (int): Dimension of the frequency tensor.
        pos (np.ndarray, int): Position indices for the frequency tensor. [S] or scalar
        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
        use_real (bool, optional): If True, return real part and imaginary part separately.
                                   Otherwise, return complex numbers.

    Returns:
        torch.Tensor: Precomputed frequency tensor with complex exponentials. [S, D/2]

    """
    if isinstance(pos, int):
        pos = np.arange(pos)
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))  # [D/2]
    t = torch.from_numpy(pos).to(freqs.device)  # type: ignore  # [S]
    freqs = torch.outer(t, freqs).float()  # type: ignore   # [S, D/2]
    if use_real:
        freqs_cos = freqs.cos().repeat_interleave(2, dim=1)  # [S, D]
        freqs_sin = freqs.sin().repeat_interleave(2, dim=1)  # [S, D]
        return freqs_cos, freqs_sin
    else:
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64     # [S, D/2]
        return freqs_cis

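
# Illustrative shape check for the rotary embeddings above (an editor's
# sketch, safe to remove): a 4x4 grid with per-head dimension 8 yields cos/sin
# tables with one row per grid position.
_cos, _sin = get_2d_rotary_pos_embed(8, 4, use_real=True)
assert _cos.shape == (16, 8) and _sin.shape == (16, 8)  # (H*W, head_dim)
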
def calc_sizes(rope_img, patch_size, th, tw):
    """Compute the grid sizes used to build the RoPE tables."""
    if rope_img == 'extend':
        # Extend mode: use the target latent grid directly.
        sub_args = [(th, tw)]
    elif rope_img.startswith('base'):
        # Based on a single base size; other sizes are obtained by interpolation.
        base_size = int(rope_img[4:]) // 8 // patch_size            # e.g. with 512 as the base, other resolutions are interpolated relative to it
        start, stop = get_fill_resize_and_crop((th, tw), base_size)   # top-left and bottom-right corners of the crop inside the base grid (e.g. 32x32)
        sub_args = [start, stop, (th, tw)]
    else:
        raise ValueError(f"Unknown rope_img: {rope_img}")
    return sub_args


def init_image_posemb(rope_img,
                      resolutions,
                      patch_size,
                      hidden_size,
                      num_heads,
                      log_fn,
                      rope_real=True,
                      ):
    freqs_cis_img = {}
    for reso in resolutions:
        th, tw = reso.height // 8 // patch_size, reso.width // 8 // patch_size
        sub_args = calc_sizes(rope_img, patch_size, th, tw)      # [top-left, bottom-right, target (h, w)]: the crop corners inside the base grid (e.g. 32x32)
        freqs_cis_img[str(reso)] = get_2d_rotary_pos_embed(hidden_size // num_heads, *sub_args, use_real=rope_real)
        log_fn(f"    Using image RoPE ({rope_img}) ({'real' if rope_real else 'complex'}): {sub_args} | ({reso}) "
               f"{freqs_cis_img[str(reso)][0].shape if rope_real else freqs_cis_img[str(reso)].shape}")
    return freqs_cis_img
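
A minimal usage sketch for init_image_posemb. The Resolution namedtuple is a stand-in assumption for illustration (the repository supplies its own resolution objects with height and width fields), and hidden_size/num_heads are illustrative model dimensions; 'base512' follows the 'base<pixels>' string format parsed by calc_sizes.

from collections import namedtuple

Resolution = namedtuple("Resolution", ["height", "width"])
reso = Resolution(1024, 1024)
freqs_cis_img = init_image_posemb(
    rope_img="base512",                    # interpolate relative to a 512-pixel base grid
    resolutions=[reso],
    patch_size=2,
    hidden_size=1408,
    num_heads=16,
    log_fn=print,
    rope_real=True,
)
cos, sin = freqs_cis_img[str(reso)]
print(cos.shape)                           # torch.Size([4096, 88]): 64*64 latent tokens, head dim 88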
 
			

