	update requirements.
Files changed:
- .gitattributes +15 -0
- LICENSE +233 -0
- README.md +4 -4
- app.py +1066 -0
- model/__init__.py +0 -0
- model/dera.py +195 -0
- model/dit.py +1090 -0
- model/image_encoder.py +903 -0
- model/prompter.py +107 -0
- model/text_encoder.py +269 -0
- model/vae.py +809 -0
- pipeline/__init__.py +0 -0
- pipeline/i2v_pipeline.py +511 -0
- requirements.txt +148 -0
- samples/1_image1.png +3 -0
- samples/1_out.mp4 +3 -0
- samples/1_prompt.txt +1 -0
- samples/1_sketch1.jpg +3 -0
- samples/1_sketch2.jpg +3 -0
- samples/1_sketch3.jpg +3 -0
- samples/2_image1.jpg +3 -0
- samples/2_out.mp4 +3 -0
- samples/2_prompt.txt +1 -0
- samples/2_sketch1.jpg +3 -0
- samples/2_sketch2.jpg +3 -0
- samples/3_image1.png +3 -0
- samples/3_out.mp4 +3 -0
- samples/3_prompt.txt +1 -0
- samples/3_sketch1.jpg +3 -0
- samples/ToonComposer-Icon.png +3 -0
- samples/ToonComposer-Method.jpg +3 -0
- samples/ToonComposer-TLDR.jpg +3 -0
- scheduler/__init__.py +0 -0
- scheduler/flow_match.py +78 -0
- tooncomposer.py +234 -0
- util/model_util.py +241 -0
- util/optical_flow.py +140 -0
- util/stylesheets.py +0 -0
- util/training_util.py +317 -0
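
The samples/ entries above follow a simple numbering convention: each sample N ships a reference image (`N_image1.*`), one or more keyframe sketches (`N_sketchK.jpg`), a text prompt (`N_prompt.txt`), and a rendered output video (`N_out.mp4`). Below is a minimal sketch (not part of this commit; `collect_sample` is a hypothetical helper) of how those files could be grouped per sample, assuming only the file names listed above.

```python
# Hypothetical helper, not from the commit: groups the files of one numbered
# sample under the naming convention visible in the samples/ listing above.
import glob
import os


def collect_sample(sample_dir: str, index: int) -> dict:
    """Return the paths belonging to sample `index` in `sample_dir`."""
    return {
        "image": glob.glob(os.path.join(sample_dir, f"{index}_image1.*")),
        "sketches": sorted(glob.glob(os.path.join(sample_dir, f"{index}_sketch*.jpg"))),
        "prompt": os.path.join(sample_dir, f"{index}_prompt.txt"),
        "output": os.path.join(sample_dir, f"{index}_out.mp4"),
    }


if __name__ == "__main__":
    print(collect_sample("samples", 1))
```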
    	
.gitattributes CHANGED

@@ -33,3 +33,18 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+samples/1_out.mp4 filter=lfs diff=lfs merge=lfs -text
+samples/2_out.mp4 filter=lfs diff=lfs merge=lfs -text
+samples/3_out.mp4 filter=lfs diff=lfs merge=lfs -text
+samples/1_image1.png filter=lfs diff=lfs merge=lfs -text
+samples/3_image1.png filter=lfs diff=lfs merge=lfs -text
+samples/ToonComposer-Icon.png filter=lfs diff=lfs merge=lfs -text
+samples/1_sketch2.jpg filter=lfs diff=lfs merge=lfs -text
+samples/1_sketch3.jpg filter=lfs diff=lfs merge=lfs -text
+samples/2_image1.jpg filter=lfs diff=lfs merge=lfs -text
+samples/1_sketch1.jpg filter=lfs diff=lfs merge=lfs -text
+samples/2_sketch1.jpg filter=lfs diff=lfs merge=lfs -text
+samples/2_sketch2.jpg filter=lfs diff=lfs merge=lfs -text
+samples/3_sketch1.jpg filter=lfs diff=lfs merge=lfs -text
+samples/ToonComposer-Method.jpg filter=lfs diff=lfs merge=lfs -text
+samples/ToonComposer-TLDR.jpg filter=lfs diff=lfs merge=lfs -text
    	
LICENSE ADDED

@@ -0,0 +1,233 @@
Tencent is pleased to support the open source community by making ToonComposer available.

Copyright (C) 2025 Tencent. All rights reserved.

ToonComposer is licensed under the MIT License except for the third-party components listed below, which is licensed under different terms. ToonComposer does not impose any additional limitations beyond what is outlined in the respective licenses of these third-party components. Users must comply with all terms and conditions of original licenses of these third-party components and must ensure that the usage of the third party components adheres to all relevant laws and regulations.

For avoidance of doubts, ToonComposer refers to the inference code, parameters and weights made publicly available by Tencent in accordance with the MIT License in this repository.

Terms of the MIT License:
--------------------------------------------------------------------
Copyright (C) 2025 Tencent. All rights reserved.

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the " Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice (including the next paragraph) shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.


The ToonComposer model was developed by Tencent based on the following Open Models.
The ToonComposer inference code was developed by Tencent based on the code of the following Open Models. The below software in this distribution may have been modified by Tencent ("Tencent Modifications"). All Tencent Modifications are Copyright (C) Tencent.

Open Models Licensed under the Apache-2.0 License:

--------------------------------------------------------------------
1. Wan2.1
Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
The code of this model was modified by Tencent.

--------------------------------------------------------------------
Terms of the Apache-2.0 License:
--------------------------------------------------------------------
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
    	
README.md CHANGED

@@ -1,12 +1,12 @@
 ---
 title: ToonComposer
-emoji:
+emoji: 🎨
 colorFrom: gray
-colorTo:
+colorTo: yellow
 sdk: gradio
-sdk_version: 5.
+sdk_version: 5.25.2
 app_file: app.py
 pinned: false
 ---

-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
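
This front matter is the Hugging Face Spaces configuration: it selects the Gradio SDK, pins `sdk_version: 5.25.2`, and points the Space at `app_file: app.py`. For orientation only, here is a minimal sketch of what an `app_file` entry point looks like under that convention; the `greet`/`demo` names are hypothetical and have nothing to do with the actual ToonComposer app added below.

```python
# Illustrative only (not from this commit): the smallest app.py a Gradio
# Space could run. The real entry point added in this commit builds the
# full ToonComposer interface instead.
import gradio as gr


def greet(name: str) -> str:
    return f"Hello, {name}!"


demo = gr.Interface(fn=greet, inputs="text", outputs="text")

if __name__ == "__main__":
    # Spaces execute the file named in `app_file` and serve the launched app.
    demo.launch()
```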
    	
app.py ADDED

@@ -0,0 +1,1066 @@
import torch
import numpy as np
from PIL import Image
from tooncomposer import ToonComposer, get_base_model_paths
import argparse
import json
from util.training_util import extract_img_to_sketch
import os
import tempfile
import cv2
import gradio as gr
from einops import rearrange
from datetime import datetime
from typing import Optional, List, Dict
from huggingface_hub import snapshot_download

os.environ["GRADIO_TEMP_DIR"] = os.path.abspath(os.path.join(os.path.dirname(__file__), "gradio_cache"))

# -----------------------------------------------------------------------------
# Weights resolution and download helpers
# -----------------------------------------------------------------------------

WAN_REPO_ID = "Wan-AI/Wan2.1-I2V-14B-480P"
TOONCOMPOSER_REPO_ID = "TencentARC/ToonComposer"

def _path_is_dir_with_files(dir_path: str, required_files: List[str]) -> bool:
    if not dir_path or not os.path.isdir(dir_path):
        return False
    for f in required_files:
        if not os.path.exists(os.path.join(dir_path, f)):
            return False
    return True

def resolve_wan_model_root(preferred_dir: Optional[str] = None, hf_token: Optional[str] = None) -> str:
    """Return a directory containing Wan2.1-I2V-14B-480P weights.

    Resolution order:
    1) preferred_dir arg (if valid)
    2) WAN21_I2V_DIR env var (if valid)
    3) HF local cache (no download) via snapshot_download(local_files_only=True)
    4) HF download to cache via snapshot_download()
    """
    # Required filenames relative to the model root
    expected = get_base_model_paths("Wan2.1-I2V-14B-480P", format='dict', model_root=".")
    required_files = []
    required_files.extend([os.path.basename(p) for p in expected["dit"]])
    required_files.append(os.path.basename(expected["image_encoder"]))
    required_files.append(os.path.basename(expected["text_encoder"]))
    required_files.append(os.path.basename(expected["vae"]))

    # 1) preferred_dir arg
    if _path_is_dir_with_files(preferred_dir or "", required_files):
        return os.path.abspath(preferred_dir)

    # 2) environment variable
    env_dir = os.environ.get("WAN21_I2V_DIR")
    if _path_is_dir_with_files(env_dir or "", required_files):
        return os.path.abspath(env_dir)

    # 3) try local cache without network
    try:
        cached_dir = snapshot_download(repo_id=WAN_REPO_ID, local_files_only=True)
        return cached_dir
    except Exception:
        pass

    # 4) download (may be large)
    cached_dir = snapshot_download(repo_id=WAN_REPO_ID, token=hf_token)
    return cached_dir

def resolve_tooncomposer_repo_dir(preferred_dir: Optional[str] = None, hf_token: Optional[str] = None) -> str:
    """Return a directory containing ToonComposer repo with 480p/608p subdirs."""
    # Quick validity check: ensure either a subdir 480p or 608p exists with required files
    def has_resolution_dirs(base_dir: str) -> bool:
        if not base_dir or not os.path.isdir(base_dir):
            return False
        ok = False
        for res in ["480p", "608p"]:
            d = os.path.join(base_dir, res)
            if os.path.isdir(d):
                ckpt = os.path.join(d, "tooncomposer.ckpt")
                cfg = os.path.join(d, "config.json")
                if os.path.exists(ckpt) and os.path.exists(cfg):
                    ok = True
        return ok

    # 1) preferred_dir arg
    if has_resolution_dirs(preferred_dir or ""):
        return os.path.abspath(preferred_dir)

    # 2) environment variable
    env_dir = os.environ.get("TOONCOMPOSER_DIR")
    if has_resolution_dirs(env_dir or ""):
        return os.path.abspath(env_dir)

    # 3) try local cache first
    try:
        cached_dir = snapshot_download(repo_id=TOONCOMPOSER_REPO_ID, local_files_only=True)
        return cached_dir
    except Exception:
        pass

    # 4) download repo to cache
    cached_dir = snapshot_download(repo_id=TOONCOMPOSER_REPO_ID, token=hf_token)
    return cached_dir

def build_checkpoints_by_resolution(tooncomposer_base_dir: str) -> Dict[str, Dict[str, object]]:
    """Construct resolution mapping from a base repo dir that contains 480p/608p.

    The ToonComposer HF repo stores, inside each resolution dir:
      - tooncomposer.ckpt
      - config.json (model configuration)
    """
    mapping = {}
    # Known target sizes
    res_to_hw = {
        "480p": (480, 832),
        "608p": (608, 1088),
    }
    for res, (h, w) in res_to_hw.items():
        res_dir = os.path.join(tooncomposer_base_dir, res)
        mapping[res] = {
            "target_height": h,
            "target_width": w,
            "snapshot_args_path": os.path.join(res_dir, "config.json"),
            "checkpoint_path": os.path.join(res_dir, "tooncomposer.ckpt"),
        }
    return mapping

# Will be populated in main() after resolving ToonComposer repo directory
checkpoints_by_resolution = {}

def tensor2video(frames):
    frames = rearrange(frames, "C T H W -> T H W C")
    frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
    frames = [Image.fromarray(frame) for frame in frames]
    return frames

def _load_model_config(config_path: str) -> Dict[str, object]:
    with open(config_path, "r") as f:
        data = json.load(f)
    return data

def _merge_with_defaults(cfg: Dict[str, object]) -> Dict[str, object]:
    # Provide safe defaults for optional fields used at inference-time
    defaults = {
        "base_model_name": "Wan2.1-I2V-14B-480P",
        "learning_rate": 1e-5,
        "train_architecture": "lora",
        "lora_rank": 4,
        "lora_alpha": 4,
        "lora_target_modules": "q,k,v,o,ffn.0,ffn.2",
        "init_lora_weights": "kaiming",
        "use_gradient_checkpointing": True,
        "tiled": False,
        "tile_size_height": 34,
        "tile_size_width": 34,
        "tile_stride_height": 18,
        "tile_stride_width": 16,
        "output_path": "./",
        "use_local_lora": False,
        "use_dera": False,
        "dera_rank": None,
        "use_dera_spatial": True,
        "use_dera_temporal": True,
        "use_sequence_cond": True,
        "sequence_cond_mode": "sparse",
        "use_channel_cond": False,
        "use_sequence_cond_position_aware_residual": True,
        "use_sequence_cond_loss": False,
        "fast_dev": False,
        "max_num_cond_images": 1,
        "max_num_cond_sketches": 2,
        "visualize_attention": False,
        "random_spaced_cond_frames": False,
        "use_sketch_mask": True,
        "sketch_mask_ratio": 0.2,
        "no_first_sketch": False,
    }
    merged = defaults.copy()
    merged.update(cfg)
    return merged

def initialize_model(resolution="480p", fast_dev=False, device="cuda:0", dtype=torch.bfloat16,
                     wan_model_dir: Optional[str] = None, tooncomposer_dir: Optional[str] = None,
                     hf_token: Optional[str] = None):
    # Initialize model components
    if resolution not in checkpoints_by_resolution:
        raise ValueError(f"Resolution '{resolution}' is not available. Found: {list(checkpoints_by_resolution.keys())}")

    # 1) resolve config and checkpoint from ToonComposer repo (local or HF)
    snapshot_args_path = checkpoints_by_resolution[resolution]["snapshot_args_path"]
    checkpoint_path = checkpoints_by_resolution[resolution]["checkpoint_path"]

    # 2) load model config
    snapshot_args_raw = _load_model_config(snapshot_args_path)
    snapshot_args = _merge_with_defaults(snapshot_args_raw)
    snapshot_args["checkpoint_path"] = checkpoint_path

    # 3) resolve Wan2.1 model root
    snapshot_args["model_root"] = resolve_wan_model_root(preferred_dir=wan_model_dir, hf_token=hf_token)

    # Backward-compat fields
    if "training_max_frame_stride" not in snapshot_args:
        snapshot_args["training_max_frame_stride"] = 4
    snapshot_args["random_spaced_cond_frames"] = False
    args = argparse.Namespace(**snapshot_args)
    if not fast_dev:
        model = ToonComposer(
            base_model_name=args.base_model_name,
            model_root=args.model_root,
            learning_rate=args.learning_rate,
            train_architecture=args.train_architecture,
            lora_rank=args.lora_rank,
            lora_alpha=args.lora_alpha,
            lora_target_modules=args.lora_target_modules,
            init_lora_weights=args.init_lora_weights,
            use_gradient_checkpointing=args.use_gradient_checkpointing,
            checkpoint_path=args.checkpoint_path,
            tiled=args.tiled,
            tile_size=(args.tile_size_height, args.tile_size_width),
            tile_stride=(args.tile_stride_height, args.tile_stride_width),
            output_path=args.output_path,
            use_local_lora=args.use_local_lora,
            use_dera=args.use_dera,
            dera_rank=args.dera_rank,
            use_dera_spatial=args.use_dera_spatial,
            use_dera_temporal=args.use_dera_temporal,
            use_sequence_cond=args.use_sequence_cond,
            sequence_cond_mode=args.sequence_cond_mode,
            use_channel_cond=args.use_channel_cond,
            use_sequence_cond_position_aware_residual=args.use_sequence_cond_position_aware_residual,
            use_sequence_cond_loss=args.use_sequence_cond_loss,
            fast_dev=args.fast_dev,
            max_num_cond_images=args.max_num_cond_images,
            max_num_cond_sketches=args.max_num_cond_sketches,
            visualize_attention=args.visualize_attention,
            random_spaced_cond_frames=args.random_spaced_cond_frames,
            use_sketch_mask=args.use_sketch_mask,
            sketch_mask_ratio=args.sketch_mask_ratio,
            no_first_sketch=args.no_first_sketch,
        )
        model = model.to(device, dtype=dtype).eval()
    else:
        print("Fast dev mode. Models will not be loaded.")
        model = None
    print("Models initialized.")
    return model, device, dtype

# -----------------------------------------------------------------------------
# CLI args and global initialization
# -----------------------------------------------------------------------------
         | 
| 253 | 
            +
             | 
| 254 | 
            +
            def _parse_args():
         | 
| 255 | 
            +
                parser = argparse.ArgumentParser()
         | 
| 256 | 
            +
                parser.add_argument("--resolution", type=str, default=os.environ.get("TOONCOMPOSER_RESOLUTION", "480p"), choices=["480p", "608p"], help="Target resolution to load by default.")
         | 
| 257 | 
            +
                parser.add_argument("--device", type=str, default=os.environ.get("DEVICE", "cuda"))
         | 
| 258 | 
            +
                parser.add_argument("--dtype", type=str, default=os.environ.get("DTYPE", "bfloat16"), choices=["bfloat16", "float32"])
         | 
| 259 | 
            +
                parser.add_argument("--wan_model_dir", type=str, default=os.environ.get("WAN21_I2V_DIR"), help="Local directory containing Wan2.1 model files. If not provided, will try HF cache and download if needed.")
         | 
| 260 | 
            +
                parser.add_argument("--tooncomposer_dir", type=str, default=os.environ.get("TOONCOMPOSER_DIR"), help="Local directory containing ToonComposer weights with 480p/608p subdirectories. If not provided, will try HF cache and download if needed.")
         | 
| 261 | 
            +
                parser.add_argument("--hf_token", type=str, default=os.environ.get("HF_TOKEN"), help="Hugging Face token (if needed for gated models).")
         | 
| 262 | 
            +
                parser.add_argument("--fast_dev", action="store_true", help="Run in fast dev mode without loading heavy models.")
         | 
| 263 | 
            +
                return parser.parse_args()
         | 
| 264 | 
            +
             | 
| 265 | 
            +
            _cli_args = _parse_args()
         | 
| 266 | 
            +
             | 
| 267 | 
            +
            # Resolve ToonComposer repo dir and build resolution mapping
         | 
| 268 | 
            +
            _toon_dir = resolve_tooncomposer_repo_dir(preferred_dir=_cli_args.tooncomposer_dir, hf_token=_cli_args.hf_token)
         | 
| 269 | 
            +
            checkpoints_by_resolution = build_checkpoints_by_resolution(_toon_dir)
         | 
| 270 | 
            +
             | 
| 271 | 
            +
            _dtype_map = {
         | 
| 272 | 
            +
                "bfloat16": torch.bfloat16,
         | 
| 273 | 
            +
                "float32": torch.float32,
         | 
| 274 | 
            +
            }
         | 
| 275 | 
            +
            fast_dev = bool(_cli_args.fast_dev)
         | 
| 276 | 
            +
            model, device, dtype = initialize_model(
         | 
| 277 | 
            +
                resolution=_cli_args.resolution,
         | 
| 278 | 
            +
                fast_dev=fast_dev,
         | 
| 279 | 
            +
                device=_cli_args.device,
         | 
| 280 | 
            +
                dtype=_dtype_map[_cli_args.dtype],
         | 
| 281 | 
            +
                wan_model_dir=_cli_args.wan_model_dir,
         | 
| 282 | 
            +
                tooncomposer_dir=_cli_args.tooncomposer_dir,
         | 
| 283 | 
            +
                hf_token=_cli_args.hf_token,
         | 
| 284 | 
            +
            )
         | 
| 285 | 
            +
             | 
| 286 | 
            +
            def process_conditions(num_items, item_inputs, num_frames, is_sketch=False, target_height=480, target_width=832):
         | 
| 287 | 
            +
                """Process condition images/sketches into masked video tensor and mask"""
         | 
| 288 | 
            +
                # Create empty (zero-filled) video tensor and per-frame condition mask
         | 
| 289 | 
            +
                video = torch.zeros((1, 3, num_frames, target_height, target_width), device=device)
         | 
| 290 | 
            +
                mask = torch.zeros((1, num_frames), device=device)
         | 
| 291 | 
            +
                
         | 
| 292 | 
            +
                for i in range(num_items):
         | 
| 293 | 
            +
                    img, frame_idx = item_inputs[i]
         | 
| 294 | 
            +
                    if img is None or frame_idx is None:
         | 
| 295 | 
            +
                        continue
         | 
| 296 | 
            +
                        
         | 
| 297 | 
            +
                    # Convert PIL image to tensor
         | 
| 298 | 
            +
                    img_tensor = torch.from_numpy(np.array(img)).permute(2,0,1).float() / 127.5 - 1.0
         | 
| 299 | 
            +
                    if is_sketch:
         | 
| 300 | 
            +
                        img_tensor = -img_tensor
         | 
| 301 | 
            +
                    img_tensor = img_tensor.unsqueeze(0).to(device)
         | 
| 302 | 
            +
                    
         | 
| 303 | 
            +
                    # Resize to model's expected resolution while preserving aspect ratio
         | 
| 304 | 
            +
                    # Get original dimensions
         | 
| 305 | 
            +
                    _, _, h, w = img_tensor.shape
         | 
| 306 | 
            +
                    
         | 
| 307 | 
            +
                    # Resize based on short edge while maintaining aspect ratio
         | 
| 308 | 
            +
                    if h/w < target_height/target_width:  # Height is the short edge relative to the target aspect ratio
         | 
| 309 | 
            +
                        new_h = target_height
         | 
| 310 | 
            +
                        new_w = int(w * (new_h / h))
         | 
| 311 | 
            +
                    else:  # Width is the short edge
         | 
| 312 | 
            +
                        new_w = target_width
         | 
| 313 | 
            +
                        new_h = int(h * (new_w / w))
         | 
| 314 | 
            +
                        
         | 
| 315 | 
            +
                    # Resize with the calculated dimensions
         | 
| 316 | 
            +
                    img_tensor = torch.nn.functional.interpolate(img_tensor, size=(new_h, new_w), mode="bilinear")
         | 
| 317 | 
            +
                    
         | 
| 318 | 
            +
                    # Center crop to target resolution if needed
         | 
| 319 | 
            +
                    if new_h > target_height or new_w > target_width:
         | 
| 320 | 
            +
                        # Calculate starting positions for crop
         | 
| 321 | 
            +
                        start_h = max(0, (new_h - target_height) // 2)
         | 
| 322 | 
            +
                        start_w = max(0, (new_w - target_width) // 2)
         | 
| 323 | 
            +
                        # Crop
         | 
| 324 | 
            +
                        img_tensor = img_tensor[:, :, start_h:start_h+target_height, start_w:start_w+target_width]
         | 
| 325 | 
            +
                    
         | 
| 326 | 
            +
                    # Place in video tensor
         | 
| 327 | 
            +
                    frame_idx = min(max(int(frame_idx), 0), num_frames-1)
         | 
| 328 | 
            +
                    if is_sketch:
         | 
| 329 | 
            +
                        video[:, :, frame_idx] = img_tensor[:, :3]  # Handle RGBA sketches
         | 
| 330 | 
            +
                    else:
         | 
| 331 | 
            +
                        video[:, :, frame_idx] = img_tensor
         | 
| 332 | 
            +
                    mask[:, frame_idx] = 1.0
         | 
| 333 | 
            +
                return video, mask
         | 
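A hedged usage sketch for process_conditions (editorial note, not part of app.py): it takes (image, frame_index) pairs, resizes and center-crops each image to the target resolution, and returns a [1, 3, T, H, W] video tensor plus a [1, T] mask marking which frames carry conditions. It allocates on the module-level `device`, so a CUDA device is assumed here.

    from PIL import Image

    first_frame = Image.open("samples/1_image1.png").convert("RGB")   # sample asset shipped in this repo
    video, mask = process_conditions(
        num_items=1, item_inputs=[(first_frame, 0)], num_frames=61,
        target_height=480, target_width=832,
    )
    # video: [1, 3, 61, 480, 832] with frame 0 filled; mask: [1, 61] with mask[0, 0] == 1.0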
| 334 | 
            +
             | 
| 335 | 
            +
            def process_sketch_masks(num_sketch_masks, sketch_mask_inputs, num_frames, target_height=480, target_width=832):
         | 
| 336 | 
            +
                """Process sketch masks into a single tensor"""
         | 
| 337 | 
            +
                # Create empty tensor filled with 1s (1 means no mask, keep original)
         | 
| 338 | 
            +
                sketch_local_mask = torch.ones((1, 1, num_frames, target_height, target_width), device=device)
         | 
| 339 | 
            +
                
         | 
| 340 | 
            +
                for i in range(num_sketch_masks):
         | 
| 341 | 
            +
                    editor_value, frame_idx = sketch_mask_inputs[i]
         | 
| 342 | 
            +
                    if editor_value is None or frame_idx is None:
         | 
| 343 | 
            +
                        continue
         | 
| 344 | 
            +
                        
         | 
| 345 | 
            +
                    # For ImageMask, we need to extract the mask from the editor_value dictionary
         | 
| 346 | 
            +
                    # editor_value is a dict with 'background', 'layers', and 'composite' keys from ImageEditor
         | 
| 347 | 
            +
                    if isinstance(editor_value, dict):
         | 
| 348 | 
            +
                        if "composite" in editor_value and editor_value["composite"] is not None:
         | 
| 349 | 
            +
                            # The 'composite' is the image with mask drawn on it
         | 
| 350 | 
            +
                            # Since we're using ImageMask with fixed black brush, the black areas are the mask
         | 
| 351 | 
            +
                            # Use the first drawn layer as a binary mask (0=masked, 1=not masked)
         | 
| 352 | 
            +
                            # sketch = editor_value["background"]  # This is the sketch
         | 
| 353 | 
            +
                            mask = editor_value["layers"][0] if editor_value["layers"] else None  # This is the mask layer
         | 
| 354 | 
            +
                            if mask is not None:
         | 
| 355 | 
            +
                                # Convert mask to tensor and normalize
         | 
| 356 | 
            +
                                mask_array = np.array(mask)
         | 
| 357 | 
            +
                                mask_array = np.max(mask_array, axis=2)
         | 
| 358 | 
            +
                                
         | 
| 359 | 
            +
                                # Convert to tensor, normalize to [0, 1]
         | 
| 360 | 
            +
                                mask_tensor = torch.from_numpy(mask_array).float()
         | 
| 361 | 
            +
                                if mask_tensor.max() > 1.0:
         | 
| 362 | 
            +
                                    mask_tensor = mask_tensor / 255.0
         | 
| 363 | 
            +
                                
         | 
| 364 | 
            +
                                # Resize to model's expected resolution
         | 
| 365 | 
            +
                                mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0)  # [1, 1, h, w]
         | 
| 366 | 
            +
                                mask_tensor = torch.nn.functional.interpolate(mask_tensor, size=(target_height, target_width), mode="nearest")
         | 
| 367 | 
            +
                                
         | 
| 368 | 
            +
                                # Invert the mask: black (0) = masked area, white (1) = keep original
         | 
| 369 | 
            +
                                # We need to invert because in the UI black means "masked"
         | 
| 370 | 
            +
                                mask_tensor = 1.0 - mask_tensor
         | 
| 371 | 
            +
                                
         | 
| 372 | 
            +
                                # Place in sketch_local_mask tensor
         | 
| 373 | 
            +
                                frame_idx = min(max(int(frame_idx), 0), num_frames-1)
         | 
| 374 | 
            +
                                sketch_local_mask[:, :, frame_idx] = mask_tensor
         | 
| 375 | 
            +
                                
         | 
| 376 | 
            +
                sketch_mask_vis = torch.ones((1, 3, num_frames, target_height, target_width), device=device)
         | 
| 377 | 
            +
                for t in range(sketch_local_mask.shape[2]):
         | 
| 378 | 
            +
                    for c in range(3):
         | 
| 379 | 
            +
                        sketch_mask_vis[0, c, t, :, :] = torch.where(
         | 
| 380 | 
            +
                            sketch_local_mask[0, 0, t] > 0.5,
         | 
| 381 | 
            +
                            1.0,  # White for unmasked areas
         | 
| 382 | 
            +
                            -1.0  # Black for masked areas
         | 
| 383 | 
            +
                        )
         | 
| 384 | 
            +
                return sketch_local_mask
         | 
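For reference, a sketch of the editor_value dict this function expects from Gradio's ImageMask/ImageEditor component (editorial note; shapes and the brushed region are illustrative, and the module-level `device` is assumed):

    import numpy as np

    h, w = 480, 832
    layer = np.zeros((h, w, 4), dtype=np.uint8)        # RGBA layer; non-zero (brushed) pixels mark the mask
    layer[100:200, 300:500] = 255                      # hypothetical brushed region
    editor_value = {
        "background": np.full((h, w, 3), 255, np.uint8),   # the uploaded sketch
        "layers": [layer],
        "composite": np.full((h, w, 3), 255, np.uint8),
    }
    sketch_local_mask = process_sketch_masks(1, [(editor_value, 30)], num_frames=61)
    # Frame 30 is 0.0 inside the brushed region (masked) and 1.0 elsewhere; all other frames stay 1.0.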
| 385 | 
            +
             | 
| 386 | 
            +
             | 
| 387 | 
            +
            def invert_sketch(image):
         | 
| 388 | 
            +
                """Invert the colors of an image (black to white, white to black)"""
         | 
| 389 | 
            +
                if image is None:
         | 
| 390 | 
            +
                    return None
         | 
| 391 | 
            +
                
         | 
| 392 | 
            +
                # Handle input from ImageMask component (EditorValue dictionary)
         | 
| 393 | 
            +
                if isinstance(image, dict) and "background" in image:
         | 
| 394 | 
            +
                    # Extract the background image
         | 
| 395 | 
            +
                    bg_image = image["background"]
         | 
| 396 | 
            +
                    
         | 
| 397 | 
            +
                    # Invert the background
         | 
| 398 | 
            +
                    inverted_bg = invert_sketch_internal(bg_image)
         | 
| 399 | 
            +
                    
         | 
| 400 | 
            +
                    # Return updated editor value
         | 
| 401 | 
            +
                    return gr.update(value=inverted_bg)
         | 
| 402 | 
            +
                
         | 
| 403 | 
            +
                # Original function for regular images
         | 
| 404 | 
            +
                return invert_sketch_internal(image)
         | 
| 405 | 
            +
             | 
| 406 | 
            +
            def invert_sketch_internal(image):
         | 
| 407 | 
            +
                """Internal function to invert an image"""
         | 
| 408 | 
            +
                if image is None:
         | 
| 409 | 
            +
                    return None
         | 
| 410 | 
            +
                
         | 
| 411 | 
            +
                # Convert to PIL image if needed
         | 
| 412 | 
            +
                if isinstance(image, str):  # If it's a filepath
         | 
| 413 | 
            +
                    image = Image.open(image)
         | 
| 414 | 
            +
                elif isinstance(image, np.ndarray):
         | 
| 415 | 
            +
                    image = Image.fromarray(image)
         | 
| 416 | 
            +
                
         | 
| 417 | 
            +
                # Ensure it's a PIL image now
         | 
| 418 | 
            +
                if not isinstance(image, Image.Image):
         | 
| 419 | 
            +
                    try:
         | 
| 420 | 
            +
                        image = Image.fromarray(np.array(image))
         | 
| 421 | 
            +
                    except Exception:
         | 
| 422 | 
            +
                        print(f"Warning: Could not convert image of type {type(image)} to PIL Image")
         | 
| 423 | 
            +
                        return image
         | 
| 424 | 
            +
                
         | 
| 425 | 
            +
                # Invert the image
         | 
| 426 | 
            +
                inverted = Image.fromarray(255 - np.array(image))
         | 
| 427 | 
            +
                return inverted
         | 
| 428 | 
            +
             | 
| 429 | 
            +
            def create_blank_mask(canvas_width=832, canvas_height=480):
         | 
| 430 | 
            +
                """Create a blank white mask image"""
         | 
| 431 | 
            +
                return Image.new('RGB', (canvas_width, canvas_height), color='white')
         | 
| 432 | 
            +
             | 
| 433 | 
            +
            def create_mask_with_sketch(sketch, canvas_width=832, canvas_height=480):
         | 
| 434 | 
            +
                """Create a mask image with sketch as background"""
         | 
| 435 | 
            +
                if sketch is None:
         | 
| 436 | 
            +
                    return create_blank_mask(canvas_width, canvas_height)
         | 
| 437 | 
            +
                    
         | 
| 438 | 
            +
                # Convert sketch to PIL if needed
         | 
| 439 | 
            +
                if not isinstance(sketch, Image.Image):
         | 
| 440 | 
            +
                    sketch = Image.fromarray(np.array(sketch))
         | 
| 441 | 
            +
                
         | 
| 442 | 
            +
                # Resize sketch to fit the canvas
         | 
| 443 | 
            +
                sketch = sketch.resize((canvas_width, canvas_height))
         | 
| 444 | 
            +
                
         | 
| 445 | 
            +
                # Create a semi-transparent white layer over the sketch
         | 
| 446 | 
            +
                overlay = Image.new('RGBA', (canvas_width, canvas_height), (255, 255, 255, 128))
         | 
| 447 | 
            +
                
         | 
| 448 | 
            +
                # Ensure sketch has alpha channel
         | 
| 449 | 
            +
                if sketch.mode != 'RGBA':
         | 
| 450 | 
            +
                    sketch = sketch.convert('RGBA')
         | 
| 451 | 
            +
                
         | 
| 452 | 
            +
                # Overlay the semi-transparent white layer on the sketch
         | 
| 453 | 
            +
                result = Image.alpha_composite(sketch, overlay)
         | 
| 454 | 
            +
                
         | 
| 455 | 
            +
                # Convert back to RGB for Gradio
         | 
| 456 | 
            +
                return result.convert('RGB')
         | 
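A small usage sketch (editorial note; the file is one of this repo's sample assets, and using the result as a dimmed drawing background is an assumption based on the overlay logic above):

    from PIL import Image

    canvas_bg = create_mask_with_sketch(Image.open("samples/1_sketch1.jpg"))
    # canvas_bg.size == (832, 480): the sketch resized to the canvas under a 50%-opaque white overlay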
| 457 | 
            +
             | 
| 458 | 
            +
            def validate_inputs(num_frames, num_cond_images, num_cond_sketches, text_prompt, *args):
         | 
| 459 | 
            +
                """Validate user inputs and return error messages if any"""
         | 
| 460 | 
            +
                errors = []
         | 
| 461 | 
            +
                
         | 
| 462 | 
            +
                # Check text prompt
         | 
| 463 | 
            +
                if not text_prompt or text_prompt.strip() == "":
         | 
| 464 | 
            +
                    errors.append("❌ Text prompt is required. Please enter a description for your video.")
         | 
| 465 | 
            +
                
         | 
| 466 | 
            +
                # Check condition images
         | 
| 467 | 
            +
                cond_images_count = 0
         | 
| 468 | 
            +
                for i in range(int(num_cond_images)):
         | 
| 469 | 
            +
                    img = args[i*2]
         | 
| 470 | 
            +
                    frame_idx = args[i*2+1]
         | 
| 471 | 
            +
                    
         | 
| 472 | 
            +
                    if img is None:
         | 
| 473 | 
            +
                        errors.append(f"❌ Image #{i+1} is missing. Please upload an image or reduce the number of keyframe images.")
         | 
| 474 | 
            +
                    else:
         | 
| 475 | 
            +
                        cond_images_count += 1
         | 
| 476 | 
            +
                        
         | 
| 477 | 
            +
                    if frame_idx is not None and (frame_idx < 0 or frame_idx >= num_frames):
         | 
| 478 | 
            +
                        errors.append(f"❌ Frame index for Image #{i+1} is {frame_idx}, which is out of range. Must be between 0 and {num_frames-1}.")
         | 
| 479 | 
            +
                
         | 
| 480 | 
            +
                # Check condition sketches
         | 
| 481 | 
            +
                num_cond_sketches_index = 8  # Starting index for sketch inputs
         | 
| 482 | 
            +
                cond_sketches_count = 0
         | 
| 483 | 
            +
                sketch_frame_indices = []
         | 
| 484 | 
            +
                
         | 
| 485 | 
            +
                for i in range(int(num_cond_sketches)):
         | 
| 486 | 
            +
                    sketch_idx = num_cond_sketches_index + i*2
         | 
| 487 | 
            +
                    frame_idx_idx = num_cond_sketches_index + 1 + i*2
         | 
| 488 | 
            +
                    
         | 
| 489 | 
            +
                    if sketch_idx < len(args) and frame_idx_idx < len(args):
         | 
| 490 | 
            +
                        sketch = args[sketch_idx]
         | 
| 491 | 
            +
                        frame_idx = args[frame_idx_idx]
         | 
| 492 | 
            +
                        
         | 
| 493 | 
            +
                        # Check if sketch is provided
         | 
| 494 | 
            +
                        if sketch is None:
         | 
| 495 | 
            +
                            errors.append(f"❌ Sketch #{i+1} is missing. Please upload a sketch or reduce the number of keyframe sketches.")
         | 
| 496 | 
            +
                        else:
         | 
| 497 | 
            +
                            # For ImageMask components, check if background is provided
         | 
| 498 | 
            +
                            if isinstance(sketch, dict):
         | 
| 499 | 
            +
                                if "background" not in sketch or sketch["background"] is None:
         | 
| 500 | 
            +
                                    errors.append(f"❌ Sketch #{i+1} is missing. Please upload a sketch image.")
         | 
| 501 | 
            +
                                else:
         | 
| 502 | 
            +
                                    cond_sketches_count += 1
         | 
| 503 | 
            +
                            else:
         | 
| 504 | 
            +
                                cond_sketches_count += 1
         | 
| 505 | 
            +
                        
         | 
| 506 | 
            +
                        # Check frame index
         | 
| 507 | 
            +
                        if frame_idx is not None and (frame_idx < 0 or frame_idx >= num_frames):
         | 
| 508 | 
            +
                            errors.append(f"❌ Frame index for Sketch #{i+1} is {frame_idx}, which is out of range. Must be between 0 and {num_frames-1}.")
         | 
| 509 | 
            +
                        elif frame_idx is not None:
         | 
| 510 | 
            +
                            sketch_frame_indices.append(frame_idx)
         | 
| 511 | 
            +
                
         | 
| 512 | 
            +
                # Check for duplicate frame indices
         | 
| 513 | 
            +
                image_frame_indices = []
         | 
| 514 | 
            +
                for i in range(int(num_cond_images)):
         | 
| 515 | 
            +
                    frame_idx = args[i*2+1]
         | 
| 516 | 
            +
                    if frame_idx is not None:
         | 
| 517 | 
            +
                        image_frame_indices.append(frame_idx)
         | 
| 518 | 
            +
                
         | 
| 519 | 
            +
                all_frame_indices = image_frame_indices + sketch_frame_indices
         | 
| 520 | 
            +
                if len(all_frame_indices) != len(set(all_frame_indices)):
         | 
| 521 | 
            +
                    errors.append("❌ Duplicate frame indices detected. Each image and sketch must be placed at a different frame.")
         | 
| 522 | 
            +
                
         | 
| 523 | 
            +
                # Check minimum requirements
         | 
| 524 | 
            +
                if cond_images_count == 0:
         | 
| 525 | 
            +
                    errors.append("❌ At least one input image is required.")
         | 
| 526 | 
            +
                
         | 
| 527 | 
            +
                return errors
         | 
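The flat *args layout assumed above (four (image, frame_index) slots, then four (editor_value, frame_index) sketch slots starting at index 8) can be illustrated with a hedged sketch; all values are placeholders:

    from PIL import Image

    img1 = Image.new("RGB", (832, 480), "white")                  # placeholder keyframe image
    sketch1 = {"background": Image.new("RGB", (832, 480), "white"),
               "layers": [], "composite": None}                   # placeholder ImageMask value
    args = [img1, 0, None, 20, None, 40, None, 60,                # image slots 1-4
            sketch1, 30, None, 35, None, 40, None, 45]            # sketch slots 1-4
    errors = validate_inputs(61, 1, 1, "A short prompt.", *args)  # -> [] when inputs are consistent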
| 528 | 
            +
             | 
| 529 | 
            +
            def tooncomposer_inference(num_frames, num_cond_images, num_cond_sketches, text_prompt, cfg_scale, sequence_cond_residual_scale, resolution, *args):
         | 
| 530 | 
            +
                # Validate inputs first
         | 
| 531 | 
            +
                validation_errors = validate_inputs(num_frames, num_cond_images, num_cond_sketches, text_prompt, *args)
         | 
| 532 | 
            +
                
         | 
| 533 | 
            +
                if validation_errors:
         | 
| 534 | 
            +
                    error_message = "\n".join(validation_errors)
         | 
| 535 | 
            +
                    return gr.update(value=None), error_message
         | 
| 536 | 
            +
             | 
| 537 | 
            +
                try:
         | 
| 538 | 
            +
                    # Parse inputs
         | 
| 539 | 
            +
                    # Get the condition images
         | 
| 540 | 
            +
                    cond_images = []
         | 
| 541 | 
            +
                    for i in range(int(num_cond_images)):
         | 
| 542 | 
            +
                        img = args[i*2]
         | 
| 543 | 
            +
                        frame_idx = args[i*2+1]
         | 
| 544 | 
            +
                        if img is not None and frame_idx is not None:
         | 
| 545 | 
            +
                            cond_images.append((img, frame_idx))
         | 
| 546 | 
            +
                    
         | 
| 547 | 
            +
                    # Get num_cond_sketches
         | 
| 548 | 
            +
                    if num_cond_sketches is None:
         | 
| 549 | 
            +
                        num_cond_sketches = 0
         | 
| 550 | 
            +
                    else:
         | 
| 551 | 
            +
                        num_cond_sketches = int(num_cond_sketches)
         | 
| 552 | 
            +
                    
         | 
| 553 | 
            +
                    # Get condition sketches and masks
         | 
| 554 | 
            +
                    cond_sketches = []
         | 
| 555 | 
            +
                    sketch_masks = []
         | 
| 556 | 
            +
                    num_cond_sketches_index = 8  # Starting index for sketch inputs
         | 
| 557 | 
            +
                    
         | 
| 558 | 
            +
                    for i in range(num_cond_sketches):
         | 
| 559 | 
            +
                        sketch_idx = num_cond_sketches_index + i*2
         | 
| 560 | 
            +
                        frame_idx_idx = num_cond_sketches_index + 1 + i*2
         | 
| 561 | 
            +
                        
         | 
| 562 | 
            +
                        if sketch_idx < len(args) and frame_idx_idx < len(args):
         | 
| 563 | 
            +
                            editor_value = args[sketch_idx]
         | 
| 564 | 
            +
                            frame_idx = args[frame_idx_idx]
         | 
| 565 | 
            +
                            
         | 
| 566 | 
            +
                            if editor_value is not None and frame_idx is not None:
         | 
| 567 | 
            +
                                # Extract the sketch from the background of the editor value
         | 
| 568 | 
            +
                                if isinstance(editor_value, dict) and "background" in editor_value:
         | 
| 569 | 
            +
                                    sketch = editor_value["background"]
         | 
| 570 | 
            +
                                    if sketch is not None:
         | 
| 571 | 
            +
                                        cond_sketches.append((sketch, frame_idx))
         | 
| 572 | 
            +
                                        # Also add to sketch_masks for mask processing
         | 
| 573 | 
            +
                                        sketch_masks.append((editor_value, frame_idx))
         | 
| 574 | 
            +
                                else:
         | 
| 575 | 
            +
                                    # For regular image inputs (first sketch)
         | 
| 576 | 
            +
                                    if editor_value is not None:
         | 
| 577 | 
            +
                                        cond_sketches.append((editor_value, frame_idx))
         | 
| 578 | 
            +
                    
         | 
| 579 | 
            +
                    # Set target resolution based on selection
         | 
| 580 | 
            +
                    target_height, target_width = checkpoints_by_resolution[resolution]["target_height"], checkpoints_by_resolution[resolution]["target_width"]
         | 
| 581 | 
            +
                    
         | 
| 582 | 
            +
                    # Update model resolution
         | 
| 583 | 
            +
                    if not fast_dev:
         | 
| 584 | 
            +
                        model.update_height_width(target_height, target_width)
         | 
| 585 | 
            +
                    
         | 
| 586 | 
            +
                    # Process conditions
         | 
| 587 | 
            +
                    with torch.no_grad():
         | 
| 588 | 
            +
                        # Process image conditions
         | 
| 589 | 
            +
                        masked_cond_video, preserved_cond_mask = process_conditions(
         | 
| 590 | 
            +
                            num_cond_images, cond_images, num_frames, target_height=target_height, target_width=target_width
         | 
| 591 | 
            +
                        )
         | 
| 592 | 
            +
                        
         | 
| 593 | 
            +
                        # Process sketch conditions
         | 
| 594 | 
            +
                        masked_cond_sketch, preserved_sketch_mask = process_conditions(
         | 
| 595 | 
            +
                            len(cond_sketches), cond_sketches, num_frames, is_sketch=True, target_height=target_height, target_width=target_width
         | 
| 596 | 
            +
                        )
         | 
| 597 | 
            +
             | 
| 598 | 
            +
                        # Process sketch masks (if any)
         | 
| 599 | 
            +
                        sketch_local_mask = None
         | 
| 600 | 
            +
                        if len(sketch_masks) > 0:
         | 
| 601 | 
            +
                            sketch_local_mask = process_sketch_masks(
         | 
| 602 | 
            +
                                len(sketch_masks), sketch_masks, num_frames, target_height=target_height, target_width=target_width
         | 
| 603 | 
            +
                            )
         | 
| 604 | 
            +
                        else:
         | 
| 605 | 
            +
                            sketch_local_mask = torch.ones((1, 1, num_frames, target_height, target_width), device=device)
         | 
| 606 | 
            +
                         
         | 
| 607 | 
            +
                        if fast_dev:
         | 
| 608 | 
            +
                            print("Fast dev mode, returning dummy video")
         | 
| 609 | 
            +
                            # Create a simple dummy video for testing
         | 
| 610 | 
            +
                            temp_dir = tempfile.mkdtemp()
         | 
| 611 | 
            +
                            video_path = os.path.join(temp_dir, "dummy_video.mp4")
         | 
| 612 | 
            +
                            
         | 
| 613 | 
            +
                            # Create a simple test video
         | 
| 614 | 
            +
                            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
         | 
| 615 | 
            +
                            video_writer = cv2.VideoWriter(video_path, fourcc, 20.0, (target_width, target_height))
         | 
| 616 | 
            +
                            
         | 
| 617 | 
            +
                            for i in range(30):  # 30 frames
         | 
| 618 | 
            +
                                # Create a simple colored frame
         | 
| 619 | 
            +
                                frame = np.full((target_height, target_width, 3), (i * 8) % 255, dtype=np.uint8)
         | 
| 620 | 
            +
                                video_writer.write(frame)
         | 
| 621 | 
            +
                            
         | 
| 622 | 
            +
                            video_writer.release()
         | 
| 623 | 
            +
                            return video_path, "✅ Dummy video generated successfully in fast dev mode!"
         | 
| 624 | 
            +
                        
         | 
| 625 | 
            +
                        masked_cond_video = masked_cond_video.to(device=device, dtype=dtype)
         | 
| 626 | 
            +
                        preserved_cond_mask = preserved_cond_mask.to(device=device, dtype=dtype)
         | 
| 627 | 
            +
                        masked_cond_sketch = masked_cond_sketch.to(device=device, dtype=dtype)
         | 
| 628 | 
            +
                        preserved_sketch_mask = preserved_sketch_mask.to(device=device, dtype=dtype)
         | 
| 629 | 
            +
                        
         | 
| 630 | 
            +
                        with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(device).type):
         | 
| 631 | 
            +
                            # Generate video
         | 
| 632 | 
            +
                            model.pipe.device = device
         | 
| 633 | 
            +
                            generated_video = model.pipe(
         | 
| 634 | 
            +
                                prompt=[text_prompt],
         | 
| 635 | 
            +
                                negative_prompt=[model.negative_prompt],
         | 
| 636 | 
            +
                                input_image=None,
         | 
| 637 | 
            +
                                num_inference_steps=15,
         | 
| 638 | 
            +
                                num_frames=num_frames,
         | 
| 639 | 
            +
                                seed=42, tiled=True,
         | 
| 640 | 
            +
                                input_condition_video=masked_cond_video,
         | 
| 641 | 
            +
                                input_condition_preserved_mask=preserved_cond_mask,
         | 
| 642 | 
            +
                                input_condition_video_sketch=masked_cond_sketch,
         | 
| 643 | 
            +
                                input_condition_preserved_mask_sketch=preserved_sketch_mask,
         | 
| 644 | 
            +
                                sketch_local_mask=sketch_local_mask,
         | 
| 645 | 
            +
                                cfg_scale=cfg_scale,
         | 
| 646 | 
            +
                                sequence_cond_residual_scale=sequence_cond_residual_scale,
         | 
| 647 | 
            +
                                height=target_height,
         | 
| 648 | 
            +
                                width=target_width,
         | 
| 649 | 
            +
                            )
         | 
| 650 | 
            +
             | 
| 651 | 
            +
                        # Convert to PIL images
         | 
| 652 | 
            +
                        video_frames = model.pipe.tensor2video(generated_video[0].cpu())
         | 
| 653 | 
            +
                        
         | 
| 654 | 
            +
                        # Convert PIL images to an MP4 video
         | 
| 655 | 
            +
                        temp_dir = tempfile.mkdtemp()
         | 
| 656 | 
            +
                        video_path = os.path.join(temp_dir, "generated_video.mp4")
         | 
| 657 | 
            +
                        
         | 
| 658 | 
            +
                        width, height = video_frames[0].size
         | 
| 659 | 
            +
                        fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # Codec for MP4 video
         | 
| 660 | 
            +
                        video_writer = cv2.VideoWriter(video_path, fourcc, 20.0, (width, height))  # 20 fps
         | 
| 661 | 
            +
                        
         | 
| 662 | 
            +
                        for frame in video_frames:
         | 
| 663 | 
            +
                            # Convert PIL image to OpenCV BGR format
         | 
| 664 | 
            +
                            frame_bgr = cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR)
         | 
| 665 | 
            +
                            video_writer.write(frame_bgr)
         | 
| 666 | 
            +
                        
         | 
| 667 | 
            +
                        video_writer.release()
         | 
| 668 | 
            +
                        print(f"Generated video saved to {video_path}. Current time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
         | 
| 669 | 
            +
                        
         | 
| 670 | 
            +
                        return video_path, f"✅ Video generated successfully! (with {len(cond_images)} keyframe images, {len(cond_sketches)} keyframe sketches)"
         | 
| 671 | 
            +
                
         | 
| 672 | 
            +
                except Exception as e:
         | 
| 673 | 
            +
                    error_msg = f"❌ Error during generation: {str(e)}"
         | 
| 674 | 
            +
                    print(error_msg)
         | 
| 675 | 
            +
                    return gr.update(value=None), error_msg
         | 
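A hedged end-to-end sketch of calling the handler directly, outside Gradio (editorial note, not part of app.py; asset paths come from this repo's samples/, the flat argument layout mirrors validate_inputs above, and a loaded model plus CUDA device are assumed):

    from PIL import Image

    image = Image.open("samples/1_image1.png").convert("RGB")
    bg = Image.open("samples/1_sketch1.jpg").convert("RGB")
    sketch = {"background": bg, "layers": [], "composite": bg}    # "layers" would hold any brushed motion mask
    flat_args = [image, 0, None, 20, None, 40, None, 60,          # 4 image slots
                 sketch, 30, None, 35, None, 40, None, 45]        # 4 sketch slots
    video_path, status = tooncomposer_inference(
        61, 1, 1, open("samples/1_prompt.txt").read().strip(),
        7.5, 1.0, "480p", *flat_args,
    )
    print(status)   # success message with the keyframe counts, or an error string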
| 676 | 
            +
             | 
| 677 | 
            +
            def create_sample_gallery():
         | 
| 678 | 
            +
                """Create gallery items for samples"""
         | 
| 679 | 
            +
                import os
         | 
| 680 | 
            +
                
         | 
| 681 | 
            +
                gallery_items = []
         | 
| 682 | 
            +
                sample_info = [
         | 
| 683 | 
            +
                    {
         | 
| 684 | 
            +
                        "id": 1,
         | 
| 685 | 
            +
                        "title": "Sample 1",
         | 
| 686 | 
            +
                        "description": "Man playing with blue fish underwater (3 sketches)",
         | 
| 687 | 
            +
                        "preview": "samples/1_image1.png"
         | 
| 688 | 
            +
                    },
         | 
| 689 | 
            +
                    {
         | 
| 690 | 
            +
                        "id": 2, 
         | 
| 691 | 
            +
                        "title": "Sample 2",
         | 
| 692 | 
            +
                        "description": "Girl and boy planting a growing flower (2 sketches)",
         | 
| 693 | 
            +
                        "preview": "samples/2_image1.jpg"
         | 
| 694 | 
            +
                    },
         | 
| 695 | 
            +
                    {
         | 
| 696 | 
            +
                        "id": 3,
         | 
| 697 | 
            +
                        "title": "Sample 3", 
         | 
| 698 | 
            +
                        "description": "Ancient Chinese boy giving apple to elder (1 sketch)",
         | 
| 699 | 
            +
                        "preview": "samples/3_image1.png"
         | 
| 700 | 
            +
                    }
         | 
| 701 | 
            +
                ]
         | 
| 702 | 
            +
                
         | 
| 703 | 
            +
                for sample in sample_info:
         | 
| 704 | 
            +
                    if os.path.exists(sample["preview"]):
         | 
| 705 | 
            +
                        gallery_items.append((sample["preview"], f"{sample['title']}: {sample['description']}"))
         | 
| 706 | 
            +
                
         | 
| 707 | 
            +
                return gallery_items
         | 
| 708 | 
            +
             | 
| 709 | 
            +
            def handle_gallery_select(evt: gr.SelectData):
         | 
| 710 | 
            +
                """Handle gallery selection and load the corresponding sample"""
         | 
| 711 | 
            +
                sample_id = evt.index + 1  # Gallery index starts from 0, sample IDs start from 1
         | 
| 712 | 
            +
                return apply_sample_to_ui(sample_id)
         | 
| 713 | 
            +
             | 
| 714 | 
            +
            def load_sample_data(sample_id):
         | 
| 715 | 
            +
                """Load sample data based on the selected sample"""
         | 
| 716 | 
            +
                import os
         | 
| 717 | 
            +
                
         | 
| 718 | 
            +
                samples_dir = "samples"
         | 
| 719 | 
            +
                
         | 
| 720 | 
            +
                # Sample configurations
         | 
| 721 | 
            +
                sample_configs = {
         | 
| 722 | 
            +
                    1: {
         | 
| 723 | 
            +
                        "prompt": "Underwater scene: A shirtless man plays with a spiraling blue fish. A whale follows a bag in the man's hand, swimming in circles as the man uses the bag to lure the blue fish forward. Anime. High quality.",
         | 
| 724 | 
            +
                        "num_sketches": 3,
         | 
| 725 | 
            +
                        "image_frame": 0,
         | 
| 726 | 
            +
                        "sketch_frames": [20, 40, 60],
         | 
| 727 | 
            +
                        "num_frames": 61
         | 
| 728 | 
            +
                    },
         | 
| 729 | 
            +
                    2: {
         | 
| 730 | 
            +
                        "prompt": "A girl and a silver-haired boy plant a huge flower. As the camera slowly moves up, the huge flower continues to grow and bloom. Anime. High quality.",
         | 
| 731 | 
            +
                        "num_sketches": 2,
         | 
| 732 | 
            +
                        "image_frame": 0,
         | 
| 733 | 
            +
                        "sketch_frames": [30, 60],
         | 
| 734 | 
            +
                        "num_frames": 61
         | 
| 735 | 
            +
                    },
         | 
| 736 | 
            +
                    3: {
         | 
| 737 | 
            +
                        "prompt": "An ancient Chinese boy holds an apple and smiles as he gives it to an elderly man nearby. Anime. High quality.",
         | 
| 738 | 
            +
                        "num_sketches": 1,
         | 
| 739 | 
            +
                        "image_frame": 0,
         | 
| 740 | 
            +
                        "sketch_frames": [30],
         | 
| 741 | 
            +
                        "num_frames": 33
         | 
| 742 | 
            +
                    }
         | 
| 743 | 
            +
                }
         | 
| 744 | 
            +
                
         | 
| 745 | 
            +
                if sample_id not in sample_configs:
         | 
| 746 | 
            +
                    return None
         | 
| 747 | 
            +
                
         | 
| 748 | 
            +
                config = sample_configs[sample_id]
         | 
| 749 | 
            +
                
         | 
| 750 | 
            +
                # Load image
         | 
| 751 | 
            +
                image_path = os.path.join(samples_dir, f"{sample_id}_image1.png")
         | 
| 752 | 
            +
                if not os.path.exists(image_path):
         | 
| 753 | 
            +
                    image_path = os.path.join(samples_dir, f"{sample_id}_image1.jpg")
         | 
| 754 | 
            +
                
         | 
| 755 | 
            +
                # Load sketches
         | 
| 756 | 
            +
                sketches = []
         | 
| 757 | 
            +
                for i in range(config["num_sketches"]):
         | 
| 758 | 
            +
                    sketch_path = os.path.join(samples_dir, f"{sample_id}_sketch{i+1}.jpg")
         | 
| 759 | 
            +
                    if os.path.exists(sketch_path):
         | 
| 760 | 
            +
                        sketches.append(sketch_path)
         | 
| 761 | 
            +
                
         | 
| 762 | 
            +
                # Load output video
         | 
| 763 | 
            +
                output_path = os.path.join(samples_dir, f"{sample_id}_out.mp4")
         | 
| 764 | 
            +
                
         | 
| 765 | 
            +
                return {
         | 
| 766 | 
            +
                    "prompt": config["prompt"],
         | 
| 767 | 
            +
                    "image": image_path if os.path.exists(image_path) else None,
         | 
| 768 | 
            +
                    "sketches": sketches,
         | 
| 769 | 
            +
                    "image_frame": config["image_frame"],
         | 
| 770 | 
            +
                    "sketch_frames": config["sketch_frames"][:len(sketches)],
         | 
| 771 | 
            +
                    "output_video": output_path if os.path.exists(output_path) else None,
         | 
| 772 | 
            +
                    "num_sketches": len(sketches),
         | 
| 773 | 
            +
                    "num_frames": config["num_frames"]
         | 
| 774 | 
            +
                }
         | 
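The on-disk naming convention this loader assumes, matching the files shipped under samples/ (editorial note):

    # samples/{id}_image1.png|jpg  -> first keyframe image
    # samples/{id}_sketch{k}.jpg   -> k-th keyframe sketch (k = 1..num_sketches)
    # samples/{id}_out.mp4         -> reference output video for the sample
    # (samples/{id}_prompt.txt exists in the repo, but the prompt is inlined in sample_configs above)
    data = load_sample_data(1)
    print(data["num_sketches"], data["sketch_frames"])   # -> 3 [20, 40, 60] when all three sketches exist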
| 775 | 
            +
             | 
| 776 | 
            +
            def apply_sample_to_ui(sample_id):
         | 
| 777 | 
            +
                """Apply sample data to UI components"""
         | 
| 778 | 
            +
                sample_data = load_sample_data(sample_id)
         | 
| 779 | 
            +
                
         | 
| 780 | 
            +
                if not sample_data:
         | 
| 781 | 
            +
                    return [gr.update() for _ in range(20)]  # Return no updates if sample not found
         | 
| 782 | 
            +
                
         | 
| 783 | 
            +
                updates = [gr.update(value=sample_data["num_frames"])]
         | 
| 784 | 
            +
                
         | 
| 785 | 
            +
                # Update prompt
         | 
| 786 | 
            +
                updates.append(gr.update(value=sample_data["prompt"]))
         | 
| 787 | 
            +
                
         | 
| 788 | 
            +
                # Update number of sketches
         | 
| 789 | 
            +
                updates.append(gr.update(value=sample_data["num_sketches"]))
         | 
| 790 | 
            +
                
         | 
| 791 | 
            +
                # Update condition image
         | 
| 792 | 
            +
                updates.append(gr.update(value=sample_data["image"]))
         | 
| 793 | 
            +
                updates.append(gr.update(value=sample_data["image_frame"]))
         | 
| 794 | 
            +
                
         | 
| 795 | 
            +
                # Update sketches (up to 4)
         | 
| 796 | 
            +
                for i in range(4):
         | 
| 797 | 
            +
                    if i < len(sample_data["sketches"]):
         | 
| 798 | 
            +
                        # Load sketch image
         | 
| 799 | 
            +
                        sketch_img = Image.open(sample_data["sketches"][i])
         | 
| 800 | 
            +
                        # Create ImageMask format
         | 
| 801 | 
            +
                        sketch_dict = {
         | 
| 802 | 
            +
                            "background": sketch_img,
         | 
| 803 | 
            +
                            "layers": [],
         | 
| 804 | 
            +
                            "composite": sketch_img
         | 
| 805 | 
            +
                        }
         | 
| 806 | 
            +
                        updates.append(gr.update(value=sketch_dict))
         | 
| 807 | 
            +
                        updates.append(gr.update(value=sample_data["sketch_frames"][i]))
         | 
| 808 | 
            +
                    else:
         | 
| 809 | 
            +
                        updates.append(gr.update(value=None))
         | 
| 810 | 
            +
                        updates.append(gr.update(value=30))
         | 
| 811 | 
            +
                
         | 
| 812 | 
            +
                # Update output video
         | 
| 813 | 
            +
                updates.append(gr.update(value=sample_data["output_video"]))
         | 
| 814 | 
            +
                
         | 
| 815 | 
            +
                # Update status
         | 
| 816 | 
            +
                updates.append(gr.update(value=f"✅ Loaded Sample {sample_id}: {sample_data['prompt'][:50]}..."))
         | 
| 817 | 
            +
                
         | 
| 818 | 
            +
                return updates
         | 
| 819 | 
            +
             | 
| 820 | 
            +
            if __name__ == "__main__":
         | 
| 821 | 
            +
                from util.stylesheets import css, pre_js, banner_image
         | 
| 822 | 
            +
                with gr.Blocks(title="🎨 ToonComposer Demo", css=css, js=pre_js) as iface:
         | 
| 823 | 
            +
                    with gr.Row():
         | 
| 824 | 
            +
                        with gr.Column(scale=1):
         | 
| 825 | 
            +
                            gr.HTML(banner_image)
         | 
| 826 | 
            +
                        with gr.Column(scale=1):
         | 
| 827 | 
            +
                            gr.Markdown("""
         | 
| 828 | 
            +
                            💡 **Quick Guide**
         | 
| 829 | 
            +
                            1. Set the prompt, the number of target frames, and the number of keyframe images/sketches.
         | 
| 830 | 
            +
                            2. Upload a keyframe image as the first frame (with its frame index set to 0).
         | 
| 831 | 
            +
                            3. Upload sketches with optional motion masks for controlled generation at specified frame indices.
         | 
| 832 | 
            +
                            4. Click the *Generate* button to create your cartoon video.
         | 
| 833 | 
            +
                            """)
         | 
| 834 | 
            +
                    
         | 
| 835 | 
            +
                    max_num_frames = 61
         | 
| 836 | 
            +
                    cond_images_inputs = []
         | 
| 837 | 
            +
                    cond_sketches_inputs = []
         | 
| 838 | 
            +
                    with gr.Row():
         | 
| 839 | 
            +
                        with gr.Column(scale=1):
         | 
| 840 | 
            +
                            with gr.Accordion("Video Settings", open=True):
         | 
| 841 | 
            +
                                num_frames = gr.Slider(
         | 
| 842 | 
            +
                                    minimum=17, maximum=max_num_frames, value=max_num_frames, step=1, label="🎥 Number of Frames",
         | 
| 843 | 
            +
                                    info="Select the total number of frames for the generated video. Should be 4N+"
         | 
| 844 | 
            +
                                )
         | 
| 845 | 
            +
                                
         | 
| 846 | 
            +
                                resolution = gr.Radio(
         | 
| 847 | 
            +
                                    choices=["480p", "608p"],
         | 
| 848 | 
            +
                                    value="480p",
         | 
| 849 | 
            +
                                    label="🎥 Resolution",
         | 
| 850 | 
            +
                                    info="Select the resolution for the generated video."
         | 
| 851 | 
            +
                                )
         | 
| 852 | 
            +
                                
         | 
| 853 | 
            +
                                text_prompt = gr.Textbox(
         | 
| 854 | 
            +
                                    label="📝 Text Prompt",
         | 
| 855 | 
            +
                                    placeholder="Enter a description for the video.",
         | 
| 856 | 
            +
                                    info="Describe what you want to generate in the video.",
         | 
| 857 | 
            +
                                    lines=5
         | 
| 858 | 
            +
                                )
         | 
| 859 | 
            +
                                cfg_scale = gr.Slider(
         | 
| 860 | 
            +
                                    minimum=1.0, maximum=15.0, value=7.5, label="⚙️ CFG Scale",
         | 
| 861 | 
            +
                                    info="Adjust the classifier-free guidance scale for generation."
         | 
| 862 | 
            +
                                )
         | 
| 863 | 
            +
                                sequence_cond_residual_scale = gr.Slider(
         | 
| 864 | 
            +
                                    minimum=0.0, maximum=1.2, value=1.0, label="⚙️ Pos-aware Residual Scale",
         | 
| 865 | 
            +
                                    info="Adjust the residual scale for the position-aware sequence condition."
         | 
| 866 | 
            +
                                )
         | 
| 867 | 
            +
                    
         | 
| 868 | 
            +
                        with gr.Column(scale=3):
         | 
| 869 | 
            +
                            with gr.Accordion("Keyframe Image(s)", open=True):
         | 
| 870 | 
            +
                                num_cond_images = gr.Slider(
         | 
| 871 | 
            +
                                    minimum=1, maximum=4, value=1, step=1, label="🖼️ Number of Keyframe Images",
         | 
| 872 | 
            +
                                    info="Specify how many keyframe color images to use (max 4 images)."
         | 
| 873 | 
            +
                                )
         | 
| 874 | 
            +
                                for i in range(4):  # Max 4 condition images
         | 
| 875 | 
            +
                                    with gr.Tab(label=f"Image {i+1}", interactive=i==0) as tab:
         | 
| 876 | 
            +
                                        gr.Markdown("At least one image is required. \n Each image or sketch will be used to control the cartoon geneartion at the given frame index.")
         | 
| 877 | 
            +
                                        image_input = gr.Image(
         | 
| 878 | 
            +
                                            label=f"Image {i+1}", type="pil",
         | 
| 879 | 
            +
                                            placeholder=f"Upload a keyframe image {i+1}..."
         | 
| 880 | 
            +
                                        )
         | 
| 881 | 
            +
                                        frame_index_input = gr.Slider(
         | 
| 882 | 
            +
                                            label=f"Frame Index for Image #{i+1}", minimum=0, maximum=max_num_frames - 1,
         | 
| 883 | 
            +
                                            value=i * (max_num_frames-1) // 3, step=1, 
         | 
| 884 | 
            +
                                            info=f"Frame position for Image {i+1} (0 to {max_num_frames-1})"
         | 
| 885 | 
            +
                                        )
         | 
| 886 | 
            +
                                        cond_images_inputs.append((image_input, frame_index_input, tab))
         | 
| 887 | 
            +
                                        
         | 
| 888 | 
            +
                    
         | 
| 889 | 
            +
                        with gr.Column(scale=3):
         | 
| 890 | 
            +
                            with gr.Accordion("Keyframe Sketch(es)", open=True): 
         | 
| 891 | 
            +
                                num_cond_sketches = gr.Slider(
         | 
| 892 | 
            +
                                    minimum=1, maximum=4, value=1, step=1, label="✏️ Number of Keyframe Sketch(es)",
         | 
| 893 | 
            +
                                    info="Specify how many keyframe sketches to use (max 4 sketches)."
         | 
| 894 | 
            +
                                )
         | 
| 895 | 
            +
                                for i in range(4):  # Max 4 condition sketches
         | 
| 896 | 
            +
                                    with gr.Tab(label=f"Sketch {i + 1}", interactive=i==0) as tab:
         | 
| 897 | 
            +
                                        
         | 
| 898 | 
            +
                                        gr.Markdown("At least one sketch is required. \n You can optionally draw black areas using the brush tool to mark regions where motion can be generated freely.")
         | 
| 899 | 
            +
                                        
         | 
| 900 | 
            +
                                        # Use ImageMask which allows uploading an image and drawing a mask
         | 
| 901 | 
            +
                                        sketch_input = gr.ImageMask(
         | 
| 902 | 
            +
                                            label=f"Sketch {i + 1} with Motion Mask",
         | 
| 903 | 
            +
                                            type="pil",
         | 
| 904 | 
            +
                                            elem_id=f"sketch_mask_{i + 1}"
         | 
| 905 | 
            +
                                        )
         | 
| 906 | 
            +
                                        
         | 
| 907 | 
            +
                                        # All sketches have a frame index input
         | 
| 908 | 
            +
                                        _frame_index_input = gr.Slider(
         | 
| 909 | 
            +
                                            label=f"Frame Index for Sketch #{i + 1}", minimum=0, maximum=max_num_frames - 1,
         | 
| 910 | 
            +
                                            value=max_num_frames-1, step=1,
         | 
| 911 | 
            +
                                            info=f"Frame position for Sketch {i + 1} (0 to {max_num_frames-1})"
         | 
| 912 | 
            +
                                        )
         | 
| 913 | 
            +
                                        
         | 
| 914 | 
            +
                                        cond_sketches_inputs.append((sketch_input, _frame_index_input, tab))
         | 
| 915 | 
            +
                        
         | 
| 916 | 
            +
                    with gr.Row():
         | 
| 917 | 
            +
                        with gr.Column(scale=1):
         | 
| 918 | 
            +
                            # Sample Gallery Section
         | 
| 919 | 
            +
                            with gr.Accordion("🔍 Sample Gallery", open=True):
         | 
| 920 | 
            +
                                gr.Markdown("Click on any sample image below to load the sample inputs.")
         | 
| 921 | 
            +
                                sample_gallery = gr.Gallery(
         | 
| 922 | 
            +
                                    value=create_sample_gallery(),
         | 
| 923 | 
            +
                                    label="Sample Examples",
         | 
| 924 | 
            +
                                    show_label=False,
         | 
| 925 | 
            +
                                    elem_id="sample-gallery",
         | 
| 926 | 
            +
                                    columns=3,
         | 
| 927 | 
            +
                                    rows=1,
         | 
| 928 | 
            +
                                    height=200,
         | 
| 929 | 
            +
                                    allow_preview=True,
         | 
| 930 | 
            +
                                    object_fit="contain")
         | 
| 931 | 
            +
                                
         | 
| 932 | 
            +
                            with gr.Accordion("🛠️ Tools", open=False):
         | 
| 933 | 
            +
                                tool_input = gr.Image(
         | 
| 934 | 
            +
                                    label=f"Input Image", type="pil",
         | 
| 935 | 
            +
                                    placeholder=f"Upload an image."
         | 
| 936 | 
            +
                                )
         | 
| 937 | 
            +
                                invert_btn = gr.Button(f"Invert Colors")
         | 
| 938 | 
            +
                                invert_btn.click(
         | 
| 939 | 
            +
                                    fn=invert_sketch,
         | 
| 940 | 
            +
                                    inputs=[tool_input],
         | 
| 941 | 
            +
                                    outputs=[tool_input]
         | 
| 942 | 
            +
                                )
         | 
| 943 | 
            +
                                
         | 
| 944 | 
            +
                        with gr.Column(scale=1):
         | 
| 945 | 
            +
                            status_text = gr.Textbox(
         | 
| 946 | 
            +
                                label="📊 Status",
         | 
| 947 | 
            +
                                value="Ready to generate. Please check your inputs and click Run.",
         | 
| 948 | 
            +
                                interactive=False,
         | 
| 949 | 
            +
                                lines=5
         | 
| 950 | 
            +
                            )
         | 
| 951 | 
            +
                            
         | 
| 952 | 
            +
                            with gr.Accordion("🎬 Generated Video", open=True):
         | 
| 953 | 
            +
                                output_video = gr.Video(
         | 
| 954 | 
            +
                                    label="Video Output",
         | 
| 955 | 
            +
                                    show_label=True
         | 
| 956 | 
            +
                                )
         | 
| 957 | 
            +
                                run_button = gr.Button("🚀 Generate Video", variant="primary", size="lg")
         | 
| 958 | 
            +
             | 
| 959 | 
            +
                    def update_visibility(num_items, num_frames):
         | 
| 960 | 
            +
                        # Update visibility for columns
         | 
| 961 | 
            +
                        updates_images = []
         | 
| 962 | 
            +
                        updates_indices = []
         | 
| 963 | 
            +
                        for i in range(4):
         | 
| 964 | 
            +
                            is_visible = i < num_items
         | 
| 965 | 
            +
                            # is_visible = True
         | 
| 966 | 
            +
                            updates_images.append(gr.update(interactive=is_visible))
         | 
| 967 | 
            +
                            updates_indices.append(gr.update(
         | 
| 968 | 
            +
                                value=((num_frames - 1) // max(num_items, 1)) * (i + 1),
         | 
| 969 | 
            +
                                minimum=0, maximum=num_frames-1,
         | 
| 970 | 
            +
                            ))
         | 
| 971 | 
            +
                        return updates_images + updates_indices
         | 
| 972 | 
            +
                    
         | 
| 973 | 
            +
                    def update_visibility_images(num_items, num_frames):
         | 
| 974 | 
            +
                        # Update visibility for columns
         | 
| 975 | 
            +
                        updates_images = []
         | 
| 976 | 
            +
                        updates_indices = []
         | 
| 977 | 
            +
                        for i in range(4):
         | 
| 978 | 
            +
                            is_visible = i < num_items
         | 
| 979 | 
            +
                            updates_images.append(gr.update(interactive=is_visible))
         | 
| 980 | 
            +
                            updates_indices.append(gr.update(
         | 
| 981 | 
            +
                                value=((num_frames - 1) // max(num_items, 1)) * i,
         | 
| 982 | 
            +
                                minimum=0, maximum=num_frames-1,
         | 
| 983 | 
            +
                            ))
         | 
| 984 | 
            +
                        return updates_images + updates_indices
         | 
| 985 | 
            +
                    
         | 
| 986 | 
            +
                    def update_frame_ranges(num_items_images, num_items_sketches, num_frames):
         | 
| 987 | 
            +
                        """Update the maximum values for all frame index sliders"""
         | 
| 988 | 
            +
                        updates = []
         | 
| 989 | 
            +
                        for i in range(4):  # Images
         | 
| 990 | 
            +
                            updates.append(gr.update(
         | 
| 991 | 
            +
                                value=((num_frames - 1) // max(num_items_images, 1)) * i,
         | 
| 992 | 
            +
                                maximum=num_frames-1
         | 
| 993 | 
            +
                                ))
         | 
| 994 | 
            +
                        for i in range(4):  # Sketches
         | 
| 995 | 
            +
                            updates.append(gr.update(
         | 
| 996 | 
            +
                                value=((num_frames - 1) // max(num_items_sketches, 1)) * (i + 1),
         | 
| 997 | 
            +
                                maximum=num_frames-1))
         | 
| 998 | 
            +
                        return updates
         | 
| 999 | 
            +
                    
         | 
| 1000 | 
            +
                    num_cond_images.change(
         | 
| 1001 | 
            +
                        fn=update_visibility_images,
         | 
| 1002 | 
            +
                        inputs=[num_cond_images, num_frames],
         | 
| 1003 | 
            +
                        outputs=[tab for _, _, tab in cond_images_inputs] \
         | 
| 1004 | 
            +
                            + [frame_index_input for _, frame_index_input, _ in cond_images_inputs],
         | 
| 1005 | 
            +
                    )
         | 
| 1006 | 
            +
             | 
| 1007 | 
            +
                    num_cond_sketches.change(
         | 
| 1008 | 
            +
                        fn=update_visibility,
         | 
| 1009 | 
            +
                        inputs=[num_cond_sketches, num_frames],
         | 
| 1010 | 
            +
                        outputs=[tab for _, _, tab in cond_sketches_inputs] \
         | 
| 1011 | 
            +
                            + [frame_index_input for _, frame_index_input, _ in cond_sketches_inputs],
         | 
| 1012 | 
            +
                    )
         | 
| 1013 | 
            +
             | 
| 1014 | 
            +
                    num_frames.change(
         | 
| 1015 | 
            +
                        fn=update_frame_ranges,
         | 
| 1016 | 
            +
                        inputs=[num_cond_images, num_cond_sketches, num_frames],
         | 
| 1017 | 
            +
                        outputs=[frame_index_input for _, frame_index_input, _ in cond_images_inputs] + \
         | 
| 1018 | 
            +
                                [frame_index_input for _, frame_index_input, _ in cond_sketches_inputs]
         | 
| 1019 | 
            +
                    )
         | 
| 1020 | 
            +
             | 
| 1021 | 
            +
                    def update_resolution(resolution):
         | 
| 1022 | 
            +
                        model.update_height_width(checkpoints_by_resolution[resolution]["target_height"], checkpoints_by_resolution[resolution]["target_width"])
         | 
| 1023 | 
            +
                        model.load_tooncomposer_checkpoint(checkpoints_by_resolution[resolution]["checkpoint_path"])
         | 
| 1024 | 
            +
                        return gr.update(), gr.update()
         | 
| 1025 | 
            +
             | 
| 1026 | 
            +
                    resolution.change(
         | 
| 1027 | 
            +
                        fn=update_resolution,
         | 
| 1028 | 
            +
                        inputs=[resolution],
         | 
| 1029 | 
            +
                        outputs=[output_video, run_button]
         | 
| 1030 | 
            +
                    )
         | 
| 1031 | 
            +
                    
         | 
| 1032 | 
            +
                    sample_outputs = [
         | 
| 1033 | 
            +
                        num_frames, text_prompt, num_cond_sketches,
         | 
| 1034 | 
            +
                        cond_images_inputs[0][0], cond_images_inputs[0][1],  # Image 1
         | 
| 1035 | 
            +
                        cond_sketches_inputs[0][0], cond_sketches_inputs[0][1],  # Sketch 1
         | 
| 1036 | 
            +
                        cond_sketches_inputs[1][0], cond_sketches_inputs[1][1],  # Sketch 2
         | 
| 1037 | 
            +
                        cond_sketches_inputs[2][0], cond_sketches_inputs[2][1],  # Sketch 3
         | 
| 1038 | 
            +
                        cond_sketches_inputs[3][0], cond_sketches_inputs[3][1],  # Sketch 4
         | 
| 1039 | 
            +
                        output_video, status_text
         | 
| 1040 | 
            +
                    ]
         | 
| 1041 | 
            +
                    
         | 
| 1042 | 
            +
                    sample_gallery.select(
         | 
| 1043 | 
            +
                        fn=handle_gallery_select,
         | 
| 1044 | 
            +
                        outputs=sample_outputs
         | 
| 1045 | 
            +
                    )
         | 
| 1046 | 
            +
             | 
| 1047 | 
            +
                    inputs = [num_frames, num_cond_images, num_cond_sketches, text_prompt, cfg_scale, sequence_cond_residual_scale, resolution]
         | 
| 1048 | 
            +
                    run_button.click(
         | 
| 1049 | 
            +
                        fn=tooncomposer_inference,
         | 
| 1050 | 
            +
                        inputs=inputs,
         | 
| 1051 | 
            +
                        outputs=[output_video, status_text]
         | 
| 1052 | 
            +
                    )
         | 
| 1053 | 
            +
                    
         | 
| 1054 | 
            +
                    # Add condition image inputs
         | 
| 1055 | 
            +
                    for image_input, frame_index_input, _ in cond_images_inputs:
         | 
| 1056 | 
            +
                        inputs.append(image_input)
         | 
| 1057 | 
            +
                        inputs.append(frame_index_input)
         | 
| 1058 | 
            +
                        
         | 
| 1059 | 
            +
                    # Add sketch inputs (both regular and ImageMask)
         | 
| 1060 | 
            +
                    for sketch_input, frame_index_input, _ in cond_sketches_inputs:
         | 
| 1061 | 
            +
                        inputs.append(sketch_input)
         | 
| 1062 | 
            +
                        inputs.append(frame_index_input)
         | 
| 1063 | 
            +
                        
         | 
| 1064 | 
            +
                    iface.launch(server_port=7860, server_name="0.0.0.0",
         | 
| 1065 | 
            +
                                 allowed_paths=[os.path.abspath(os.path.join(os.path.dirname(__file__), "gradio_cache")), 
         | 
| 1066 | 
            +
                                               os.path.abspath(os.path.join(os.path.dirname(__file__), "samples"))])
         | 
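
The frame-index sliders above default to an even spacing across the clip: images use ((num_frames - 1) // max(num_items, 1)) * i and sketches use the same step times (i + 1), so the first image sits at frame 0 and the last sketch at the final frame. A minimal standalone sketch of that spacing rule (a hypothetical helper for illustration only, not part of app.py):

    # Hypothetical helper mirroring the defaults computed in
    # update_visibility_images / update_visibility / update_frame_ranges above.
    def default_frame_indices(num_frames: int, num_items: int, for_sketches: bool = False) -> list:
        step = (num_frames - 1) // max(num_items, 1)
        offset = 1 if for_sketches else 0  # images start at frame 0, sketches one step in
        return [step * (i + offset) for i in range(num_items)]

    print(default_frame_indices(61, 1))                     # [0]
    print(default_frame_indices(61, 2, for_sketches=True))  # [30, 60]
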
    	
        model/__init__.py
    ADDED
    
    | 
            File without changes
         | 
    	
        model/dera.py
    ADDED
    
    | @@ -0,0 +1,195 @@ | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import torch.nn as nn
         | 
| 3 | 
            +
            from einops import rearrange
         | 
| 4 | 
            +
            from .dit import flash_attention
         | 
| 5 | 
            +
            import torch.amp as amp
         | 
| 6 | 
            +
             | 
| 7 | 
            +
             | 
| 8 | 
            +
            class DeRAAttention(nn.Module):
         | 
| 9 | 
            +
             | 
| 10 | 
            +
                def __init__(self,
         | 
| 11 | 
            +
                             dim,
         | 
| 12 | 
            +
                             num_heads,
         | 
| 13 | 
            +
                             window_size=(-1, -1),
         | 
| 14 | 
            +
                             mode="spatial"):
         | 
| 15 | 
            +
                    assert dim % num_heads == 0
         | 
| 16 | 
            +
                    super().__init__()
         | 
| 17 | 
            +
                    self.dim = dim
         | 
| 18 | 
            +
                    self.num_heads = num_heads
         | 
| 19 | 
            +
                    self.head_dim = dim // num_heads
         | 
| 20 | 
            +
                    self.window_size = window_size
         | 
| 21 | 
            +
                    
         | 
| 22 | 
            +
                    self.q = nn.Linear(dim, dim)
         | 
| 23 | 
            +
                    self.k = nn.Linear(dim, dim)
         | 
| 24 | 
            +
                    self.v = nn.Linear(dim, dim)
         | 
| 25 | 
            +
                    self.o = nn.Linear(dim, dim)
         | 
| 26 | 
            +
                    self.visualize_attention = False
         | 
| 27 | 
            +
                    
         | 
| 28 | 
            +
                    if mode == 'spatial':
         | 
| 29 | 
            +
                        self.rope_apply = self.rope_apply_spatial
         | 
| 30 | 
            +
                    elif mode == 'temporal':
         | 
| 31 | 
            +
                        self.rope_apply = self.rope_apply_temporal
         | 
| 32 | 
            +
                    elif mode == 'spatial_temporal':
         | 
| 33 | 
            +
                        self.rope_apply = self.rope_apply_spatial_temporal
         | 
| 34 | 
            +
                    else:
         | 
| 35 | 
            +
                        raise ValueError("Invalid mode: {}".format(mode))
         | 
| 36 | 
            +
                
         | 
| 37 | 
            +
                @staticmethod
         | 
| 38 | 
            +
                @amp.autocast(enabled=False, device_type="cuda")
         | 
| 39 | 
            +
                def rope_apply_spatial(x, grid_size, freqs, sequence_cond_compressed_indices=None):
         | 
| 40 | 
            +
                    batch, _, n, c = x.size(0), x.size(1), x.size(2), x.size(3) // 2
         | 
| 41 | 
            +
                    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
         | 
| 42 | 
            +
                assert len(grid_size) == 2, "grid_size must be [h, w]"
         | 
| 43 | 
            +
                    h, w = grid_size[0], grid_size[1]
         | 
| 44 | 
            +
                    seq_len = h * w
         | 
| 45 | 
            +
                    x_i = torch.view_as_complex(x[:, :seq_len].to(torch.float64).reshape(
         | 
| 46 | 
            +
                        batch, seq_len, n, -1, 2))
         | 
| 47 | 
            +
                    freqs_i = torch.cat([
         | 
| 48 | 
            +
                        freqs[1][:h].view(1, h, 1, -1).expand(1, h, w, -1),
         | 
| 49 | 
            +
                        freqs[2][:w].view(1, 1, w, -1).expand(1, h, w, -1)
         | 
| 50 | 
            +
                    ], dim=-1).reshape(seq_len, 1, -1).unsqueeze(0).repeat(batch, 1, 1, 1)
         | 
| 51 | 
            +
                    freqs_i = torch.concat([freqs_i.new_ones(batch, seq_len, 1, c//3), freqs_i], dim=3)
         | 
| 52 | 
            +
                    x_i = torch.view_as_real(x_i * freqs_i).flatten(3)
         | 
| 53 | 
            +
                    return x_i.float()
         | 
| 54 | 
            +
                
         | 
| 55 | 
            +
                @staticmethod
         | 
| 56 | 
            +
                @amp.autocast(enabled=False, device_type="cuda")
         | 
| 57 | 
            +
                def rope_apply_temporal(x, grid_size, freqs, sequence_cond_compressed_indices=None):
         | 
| 58 | 
            +
                    batch, seq_len_actual, n, c = x.size(0), x.size(1), x.size(2), x.size(3) // 2
         | 
| 59 | 
            +
                    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
         | 
| 60 | 
            +
                    assert len(grid_size) == 1, "grid_size must be [t]"
         | 
| 61 | 
            +
                    seq_len = grid_size[0]
         | 
| 62 | 
            +
                    x_i = torch.view_as_complex(x[:, :seq_len].to(torch.float64).reshape(batch, seq_len, n, -1, 2))
         | 
| 63 | 
            +
                    freqs_i = torch.cat([
         | 
| 64 | 
            +
                        freqs[0][:seq_len].view(seq_len, 1, 1, -1)
         | 
| 65 | 
            +
                    ], dim=-1).reshape(seq_len, 1, -1).unsqueeze(0).repeat(batch, 1, 1, 1)
         | 
| 66 | 
            +
                    freqs_i = torch.concat([freqs_i, freqs_i.new_ones(batch, seq_len, 1, 2 * c//3)], dim=3)
         | 
| 67 | 
            +
                    x_i = torch.view_as_real(x_i * freqs_i).flatten(3)
         | 
| 68 | 
            +
                    if seq_len_actual > seq_len:
         | 
| 69 | 
            +
                        sequence_cond_seq_length = seq_len_actual - seq_len
         | 
| 70 | 
            +
                        if sequence_cond_seq_length == seq_len:
         | 
| 71 | 
            +
                            x_i_sequence_cond = torch.view_as_complex(x[:, seq_len:].to(torch.float64).reshape(batch, seq_len_actual - seq_len, n, -1, 2))
         | 
| 72 | 
            +
                            x_i_sequence_cond = torch.view_as_real(x_i_sequence_cond * freqs_i).flatten(3)
         | 
| 73 | 
            +
                        else:
         | 
| 74 | 
            +
                            sequence_cond_compressed_index = sequence_cond_compressed_indices[0]
         | 
| 75 | 
            +
                            sequence_cond_t_length = len(sequence_cond_compressed_index)
         | 
| 76 | 
            +
                            assert sequence_cond_t_length == sequence_cond_seq_length, "`sequence_cond_t_length` must be equal to `sequence_cond_seq_length`"
         | 
| 77 | 
            +
                            x_i_sequence_cond = torch.view_as_complex(x[:, seq_len:].to(torch.float64).reshape(batch, sequence_cond_seq_length, n, -1, 2))
         | 
| 78 | 
            +
                            freqs_i_sequence_cond = torch.cat([
         | 
| 79 | 
            +
                                freqs[0][sequence_cond_compressed_index].view(sequence_cond_t_length, 1, 1, -1),
         | 
| 80 | 
            +
                            ], dim=-1).reshape(sequence_cond_seq_length, 1, -1).unsqueeze(0).repeat(batch, 1, 1, 1)
         | 
| 81 | 
            +
                            freqs_i_sequence_cond = torch.concat([freqs_i_sequence_cond, freqs_i_sequence_cond.new_ones(batch, sequence_cond_t_length, 1, 2 * c//3)], dim=3)
         | 
| 82 | 
            +
                            x_i_sequence_cond = torch.view_as_real(x_i_sequence_cond * freqs_i_sequence_cond).flatten(3)
         | 
| 83 | 
            +
                        x_i = torch.cat([x_i, x_i_sequence_cond], dim=1)
         | 
| 84 | 
            +
                    
         | 
| 85 | 
            +
                    return x_i.float()
         | 
| 86 | 
            +
                
         | 
| 87 | 
            +
                @staticmethod
         | 
| 88 | 
            +
                @amp.autocast(enabled=False, device_type="cuda")
         | 
| 89 | 
            +
                def rope_apply_spatial_temporal(x, grid_sizes, freqs, sequence_cond_compressed_indices=None):
         | 
| 90 | 
            +
                    batch, seq_len_actual, n, c = x.size(0), x.size(1), x.size(2), x.size(3) // 2
         | 
| 91 | 
            +
                    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
         | 
| 92 | 
            +
                    assert len(grid_sizes) == 3, "grid_sizes must be ([f, h, w])"
         | 
| 93 | 
            +
                    f, h, w = grid_sizes[0], grid_sizes[1], grid_sizes[2]
         | 
| 94 | 
            +
                    seq_len = f * h * w
         | 
| 95 | 
            +
                    x_i = torch.view_as_complex(x[:, :seq_len].to(torch.float64).reshape(
         | 
| 96 | 
            +
                        batch, seq_len, n, -1, 2))
         | 
| 97 | 
            +
                    freqs_i = torch.cat([
         | 
| 98 | 
            +
                        freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
         | 
| 99 | 
            +
                        freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
         | 
| 100 | 
            +
                        freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
         | 
| 101 | 
            +
                    ], dim=-1).reshape(seq_len, 1, -1)
         | 
| 102 | 
            +
                    x_i = torch.view_as_real(x_i * freqs_i).flatten(3)
         | 
| 103 | 
            +
                    if seq_len_actual > seq_len:
         | 
| 104 | 
            +
                        sequence_cond_seq_length = seq_len_actual - seq_len
         | 
| 105 | 
            +
                        if sequence_cond_seq_length == seq_len:
         | 
| 106 | 
            +
                            x_i_sequence_cond = torch.view_as_complex(x[:, seq_len:].to(torch.float64).reshape(batch, seq_len_actual - seq_len, n, -1, 2))
         | 
| 107 | 
            +
                            x_i_sequence_cond = torch.view_as_real(x_i_sequence_cond * freqs_i).flatten(3)
         | 
| 108 | 
            +
                        else:
         | 
| 109 | 
            +
                            sequence_cond_compressed_index = sequence_cond_compressed_indices[0]
         | 
| 110 | 
            +
                            sequence_cond_t_length = len(sequence_cond_compressed_index)
         | 
| 111 | 
            +
                            assert sequence_cond_t_length * h * w == sequence_cond_seq_length, "`sequence_cond_t_length * h * w` must be equal to `sequence_cond_seq_length`"
         | 
| 112 | 
            +
                            x_i_sequence_cond = torch.view_as_complex(x[:, seq_len:].to(torch.float64).reshape(batch, sequence_cond_seq_length, n, -1, 2))
         | 
| 113 | 
            +
                            freqs_i_sequence_cond = torch.cat([
         | 
| 114 | 
            +
                                freqs[0][sequence_cond_compressed_index].view(sequence_cond_t_length, 1, 1, -1).expand(sequence_cond_t_length, h, w, -1),
         | 
| 115 | 
            +
                                freqs[1][:h].view(1, h, 1, -1).expand(sequence_cond_t_length, h, w, -1),
         | 
| 116 | 
            +
                                freqs[2][:w].view(1, 1, w, -1).expand(sequence_cond_t_length, h, w, -1)
         | 
| 117 | 
            +
                            ], dim=-1).reshape(sequence_cond_seq_length, 1, -1)
         | 
| 118 | 
            +
                            x_i_sequence_cond = torch.view_as_real(x_i_sequence_cond * freqs_i_sequence_cond).flatten(3)
         | 
| 119 | 
            +
                        x_i = torch.cat([x_i, x_i_sequence_cond], dim=1)
         | 
| 120 | 
            +
                    return x_i.float()
         | 
| 121 | 
            +
             | 
| 122 | 
            +
             | 
| 123 | 
            +
                def forward(self, x, seq_lens, grid_size, freqs, sequence_cond_compressed_indices):
         | 
| 124 | 
            +
                    b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
         | 
| 125 | 
            +
                    def qkv_fn(x):
         | 
| 126 | 
            +
                        q = self.q(x).view(b, s, n, d)
         | 
| 127 | 
            +
                        k = self.k(x).view(b, s, n, d)
         | 
| 128 | 
            +
                        v = self.v(x).view(b, s, n, d)
         | 
| 129 | 
            +
                        return q, k, v
         | 
| 130 | 
            +
             | 
| 131 | 
            +
                    q, k, v = qkv_fn(x)
         | 
| 132 | 
            +
                    q_rope = self.rope_apply(q, grid_size, freqs, sequence_cond_compressed_indices)
         | 
| 133 | 
            +
                    k_rope = self.rope_apply(k, grid_size, freqs, sequence_cond_compressed_indices)
         | 
| 134 | 
            +
                    if self.visualize_attention:
         | 
| 135 | 
            +
                        with torch.no_grad():
         | 
| 136 | 
            +
                    self._last_attn_maps = self._compute_attention_for_visualization(q_rope, k_rope)  # CPU tensor of shape [S, S]
         | 
| 137 | 
            +
                            self._last_grid_sizes = grid_size
         | 
| 138 | 
            +
                            self._last_seq_lens = seq_lens
         | 
| 139 | 
            +
                    x = flash_attention(
         | 
| 140 | 
            +
                        q=q_rope,
         | 
| 141 | 
            +
                        k=k_rope,
         | 
| 142 | 
            +
                        v=v,
         | 
| 143 | 
            +
                        k_lens=None,
         | 
| 144 | 
            +
                        window_size=self.window_size)
         | 
| 145 | 
            +
                    x = x.flatten(2)
         | 
| 146 | 
            +
                    x = self.o(x)
         | 
| 147 | 
            +
                    return x
         | 
| 148 | 
            +
             | 
| 149 | 
            +
             | 
| 150 | 
            +
            class DeRA(nn.Module):
         | 
| 151 | 
            +
                def __init__(self, dim, rank, use_spatial=True, use_temporal=True):
         | 
| 152 | 
            +
                    super(DeRA, self).__init__()
         | 
| 153 | 
            +
                    self.dim = dim
         | 
| 154 | 
            +
                    self.rank = rank
         | 
| 155 | 
            +
                    self.use_spatial = use_spatial
         | 
| 156 | 
            +
                    self.use_temporal = use_temporal
         | 
| 157 | 
            +
                    
         | 
| 158 | 
            +
                    if not use_spatial and not use_temporal:
         | 
| 159 | 
            +
                        self.attention_mode = "none"
         | 
| 160 | 
            +
                    else:
         | 
| 161 | 
            +
                        self.attention_mode = "spatial_temporal" if use_spatial and use_temporal else "spatial" if use_spatial else "temporal"
         | 
| 162 | 
            +
                    
         | 
| 163 | 
            +
                    self.spatial_down_proj = nn.Linear(self.dim, rank, bias=False)
         | 
| 164 | 
            +
                    self.spatial_up_proj = nn.Linear(rank, self.dim, bias=False)
         | 
| 165 | 
            +
                    self.spatial_up_proj.weight.data.zero_()
         | 
| 166 | 
            +
                    if self.attention_mode != "none":
         | 
| 167 | 
            +
                        self.spatial_attn = DeRAAttention(dim=rank, num_heads=4, window_size=(-1, -1),
         | 
| 168 | 
            +
                                                          mode=self.attention_mode)
         | 
| 169 | 
            +
                    else:
         | 
| 170 | 
            +
                        self.spatial_attn = None
         | 
| 171 | 
            +
                            
         | 
| 172 | 
            +
                def forward(self, x, seq_lens, grid_sizes, freqs, sequence_cond_compressed_indices):
         | 
| 173 | 
            +
                    _, actual_seq, _ = x.shape
         | 
| 174 | 
            +
                    if isinstance(grid_sizes, torch.Tensor):
         | 
| 175 | 
            +
                        grid_sizes = tuple(grid_sizes[0].tolist())
         | 
| 176 | 
            +
                        
         | 
| 177 | 
            +
                    if len(grid_sizes) != 3:
         | 
| 178 | 
            +
                        raise ValueError("`grid_sizes` should contain time, spatial height, and width dimensions")
         | 
| 179 | 
            +
                    _, orig_h, orig_w = grid_sizes
         | 
| 180 | 
            +
                    actual_t = actual_seq // (orig_h * orig_w)
         | 
| 181 | 
            +
                    
         | 
| 182 | 
            +
                    x_low = self.spatial_down_proj(x)
         | 
| 183 | 
            +
                    if self.attention_mode == "spatial":
         | 
| 184 | 
            +
                        x_low_spatial = rearrange(x_low, 'b (t h w) r -> (b t) (h w) r', t=actual_t, h=orig_h, w=orig_w)
         | 
| 185 | 
            +
                        x_low_spatial = self.spatial_attn(x_low_spatial, seq_lens, grid_sizes[1:], freqs, sequence_cond_compressed_indices)
         | 
| 186 | 
            +
                        x_low = rearrange(x_low_spatial, '(b t) (h w) r -> b (t h w) r', t=actual_t, h=orig_h, w=orig_w)
         | 
| 187 | 
            +
                    elif self.attention_mode == "temporal":
         | 
| 188 | 
            +
                        x_low_temporal = rearrange(x_low, 'b (t h w) r -> (b h w) t r', t=actual_t, h=orig_h, w=orig_w)
         | 
| 189 | 
            +
                        x_low_temporal = self.spatial_attn(x_low_temporal, seq_lens, grid_sizes[:1], freqs, sequence_cond_compressed_indices)
         | 
| 190 | 
            +
                        x_low = rearrange(x_low_temporal, '(b h w) t r -> b (t h w) r', t=actual_t, h=orig_h, w=orig_w)
         | 
| 191 | 
            +
                    elif self.attention_mode == "spatial_temporal":
         | 
| 192 | 
            +
                        x_low = self.spatial_attn(x_low, seq_lens, grid_sizes, freqs, sequence_cond_compressed_indices)
         | 
| 193 | 
            +
                    x_out = self.spatial_up_proj(x_low)
         | 
| 194 | 
            +
                    return x_out
         | 
| 195 | 
            +
                
         | 
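
DeRA above implements a decomposed low-rank adapter: tokens are projected down to `rank` channels, run through a lightweight spatial, temporal, or joint spatial-temporal attention at that reduced width, and projected back up through a linear layer whose weights are zero-initialized, so the adapter contributes nothing at initialization and can learn a residual on top of the base model's features. A minimal conceptual sketch of that zero-init low-rank residual pattern (plain PyTorch with standard multi-head attention standing in for the repo's flash_attention/RoPE path; names are illustrative, not the repo's API):

    import torch
    import torch.nn as nn

    class LowRankResidualAdapter(nn.Module):
        """Down-project -> attention at reduced width -> zero-init up-project."""
        def __init__(self, dim: int, rank: int, num_heads: int = 4):
            super().__init__()
            self.down = nn.Linear(dim, rank, bias=False)
            self.attn = nn.MultiheadAttention(rank, num_heads, batch_first=True)
            self.up = nn.Linear(rank, dim, bias=False)
            nn.init.zeros_(self.up.weight)  # adapter output starts at exactly zero

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # x: [batch, seq_len, dim]; the result is added to the base branch.
            h = self.down(x)
            h, _ = self.attn(h, h, h, need_weights=False)
            return self.up(h)

    x = torch.randn(2, 16, 128)
    adapter = LowRankResidualAdapter(dim=128, rank=32)
    print(adapter(x).abs().max().item())  # 0.0 before any training step

The zero-initialized up projection mirrors the `spatial_up_proj.weight.data.zero_()` call in DeRA and is the usual way to make an injected branch start as a no-op residual.
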
    	
        model/dit.py
    ADDED
    
    | @@ -0,0 +1,1090 @@ | |
| 1 | 
            +
            import math
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            import torch.amp as amp
         | 
| 5 | 
            +
            import torch.nn as nn
         | 
| 6 | 
            +
            from util.model_util import hash_state_dict_keys
         | 
| 7 | 
            +
            from einops import rearrange
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            try:
         | 
| 10 | 
            +
                import flash_attn_interface
         | 
| 11 | 
            +
                FLASH_ATTN_3_AVAILABLE = True
         | 
| 12 | 
            +
            except ModuleNotFoundError:
         | 
| 13 | 
            +
                FLASH_ATTN_3_AVAILABLE = False
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            try:
         | 
| 16 | 
            +
                import flash_attn
         | 
| 17 | 
            +
                FLASH_ATTN_2_AVAILABLE = True
         | 
| 18 | 
            +
            except ModuleNotFoundError:
         | 
| 19 | 
            +
                FLASH_ATTN_2_AVAILABLE = False
         | 
| 20 | 
            +
             | 
| 21 | 
            +
            try:
         | 
| 22 | 
            +
                from sageattention import sageattn
         | 
| 23 | 
            +
                SAGE_ATTN_AVAILABLE = True
         | 
| 24 | 
            +
            except ModuleNotFoundError:
         | 
| 25 | 
            +
                SAGE_ATTN_AVAILABLE = False
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            import warnings
         | 
| 28 | 
            +
             | 
| 29 | 
            +
             | 
| 30 | 
            +
            __all__ = ['WanModel']
         | 
| 31 | 
            +
             | 
| 32 | 
            +
             | 
| 33 | 
            +
            def flash_attention(
         | 
| 34 | 
            +
                q,
         | 
| 35 | 
            +
                k,
         | 
| 36 | 
            +
                v,
         | 
| 37 | 
            +
                q_lens=None,
         | 
| 38 | 
            +
                k_lens=None,
         | 
| 39 | 
            +
                dropout_p=0.,
         | 
| 40 | 
            +
                softmax_scale=None,
         | 
| 41 | 
            +
                q_scale=None,
         | 
| 42 | 
            +
                causal=False,
         | 
| 43 | 
            +
                window_size=(-1, -1),
         | 
| 44 | 
            +
                deterministic=False,
         | 
| 45 | 
            +
                dtype=torch.bfloat16,
         | 
| 46 | 
            +
                version=None,
         | 
| 47 | 
            +
            ):
         | 
| 48 | 
            +
                """
         | 
| 49 | 
            +
                q:              [B, Lq, Nq, C1].
         | 
| 50 | 
            +
                k:              [B, Lk, Nk, C1].
         | 
| 51 | 
            +
                v:              [B, Lk, Nk, C2]. Nq must be divisible by Nk.
         | 
| 52 | 
            +
                q_lens:         [B].
         | 
| 53 | 
            +
                k_lens:         [B].
         | 
| 54 | 
            +
                dropout_p:      float. Dropout probability.
         | 
| 55 | 
            +
                softmax_scale:  float. The scaling of QK^T before applying softmax.
         | 
| 56 | 
            +
                causal:         bool. Whether to apply causal attention mask.
         | 
| 57 | 
            +
    window_size:    (left, right). If not (-1, -1), apply sliding window local attention.
         | 
| 58 | 
            +
                deterministic:  bool. If True, slightly slower and uses more memory.
         | 
| 59 | 
            +
    dtype:          torch.dtype. Dtype to cast q/k/v to when they are not already float16/bfloat16.
         | 
| 60 | 
            +
                """
         | 
| 61 | 
            +
                half_dtypes = (torch.float16, torch.bfloat16)
         | 
| 62 | 
            +
                assert dtype in half_dtypes
         | 
| 63 | 
            +
                assert q.device.type == 'cuda' and q.size(-1) <= 256
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                # params
         | 
| 66 | 
            +
                b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                def half(x):
         | 
| 69 | 
            +
                    return x if x.dtype in half_dtypes else x.to(dtype)
         | 
| 70 | 
            +
             | 
| 71 | 
            +
                # preprocess query
         | 
| 72 | 
            +
                if q_lens is None:
         | 
| 73 | 
            +
                    q = half(q.flatten(0, 1))
         | 
| 74 | 
            +
                    q_lens = torch.tensor(
         | 
| 75 | 
            +
                        [lq] * b, dtype=torch.int32).to(
         | 
| 76 | 
            +
                            device=q.device, non_blocking=True)
         | 
| 77 | 
            +
                else:
         | 
| 78 | 
            +
                    q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))
         | 
| 79 | 
            +
             | 
| 80 | 
            +
                # preprocess key, value
         | 
| 81 | 
            +
                if k_lens is None:
         | 
| 82 | 
            +
                    k = half(k.flatten(0, 1))
         | 
| 83 | 
            +
                    v = half(v.flatten(0, 1))
         | 
| 84 | 
            +
                    k_lens = torch.tensor(
         | 
| 85 | 
            +
                        [lk] * b, dtype=torch.int32).to(
         | 
| 86 | 
            +
                            device=k.device, non_blocking=True)
         | 
| 87 | 
            +
                else:
         | 
| 88 | 
            +
                    k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
         | 
| 89 | 
            +
                    v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                q = q.to(v.dtype)
         | 
| 92 | 
            +
                k = k.to(v.dtype)
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                if q_scale is not None:
         | 
| 95 | 
            +
                    q = q * q_scale
         | 
| 96 | 
            +
             | 
| 97 | 
            +
                if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
         | 
| 98 | 
            +
                    warnings.warn(
         | 
| 99 | 
            +
                        'Flash attention 3 is not available, use flash attention 2 instead.'
         | 
| 100 | 
            +
                    )
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                # apply attention
         | 
| 103 | 
            +
                if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
         | 
| 104 | 
            +
                    # Note: dropout_p, window_size are not supported in FA3 now.
         | 
| 105 | 
            +
                    x = flash_attn_interface.flash_attn_varlen_func(
         | 
| 106 | 
            +
                        q=q,
         | 
| 107 | 
            +
                        k=k,
         | 
| 108 | 
            +
                        v=v,
         | 
| 109 | 
            +
                        cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
         | 
| 110 | 
            +
                            0, dtype=torch.int32).to(q.device, non_blocking=True),
         | 
| 111 | 
            +
                        cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
         | 
| 112 | 
            +
                            0, dtype=torch.int32).to(q.device, non_blocking=True),
         | 
| 113 | 
            +
                        seqused_q=None,
         | 
| 114 | 
            +
                        seqused_k=None,
         | 
| 115 | 
            +
                        max_seqlen_q=lq,
         | 
| 116 | 
            +
                        max_seqlen_k=lk,
         | 
| 117 | 
            +
                        softmax_scale=softmax_scale,
         | 
| 118 | 
            +
                        causal=causal,
         | 
| 119 | 
            +
                        deterministic=deterministic)[0].unflatten(0, (b, lq))
         | 
| 120 | 
            +
                elif FLASH_ATTN_2_AVAILABLE:
         | 
| 121 | 
            +
                    x = flash_attn.flash_attn_varlen_func(
         | 
| 122 | 
            +
                        q=q,
         | 
| 123 | 
            +
                        k=k,
         | 
| 124 | 
            +
                        v=v,
         | 
| 125 | 
            +
                        cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
         | 
| 126 | 
            +
                            0, dtype=torch.int32).to(q.device, non_blocking=True),
         | 
| 127 | 
            +
                        cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
         | 
| 128 | 
            +
                            0, dtype=torch.int32).to(q.device, non_blocking=True),
         | 
| 129 | 
            +
                        max_seqlen_q=lq,
         | 
| 130 | 
            +
                        max_seqlen_k=lk,
         | 
| 131 | 
            +
                        dropout_p=dropout_p,
         | 
| 132 | 
            +
                        softmax_scale=softmax_scale,
         | 
| 133 | 
            +
                        causal=causal,
         | 
| 134 | 
            +
                        window_size=window_size,
         | 
| 135 | 
            +
                        deterministic=deterministic).unflatten(0, (b, lq))
         | 
| 136 | 
            +
                elif SAGE_ATTN_AVAILABLE:
         | 
| 137 | 
            +
                    q = q.unsqueeze(0).transpose(1, 2).to(dtype)
         | 
| 138 | 
            +
                    k = k.unsqueeze(0).transpose(1, 2).to(dtype)
         | 
| 139 | 
            +
                    v = v.unsqueeze(0).transpose(1, 2).to(dtype)
         | 
| 140 | 
            +
                    x = sageattn(q, k, v, dropout_p=dropout_p, is_causal=causal)
         | 
| 141 | 
            +
                    x = x.transpose(1, 2).contiguous()
         | 
| 142 | 
            +
                else:
         | 
| 143 | 
            +
                    q = q.unsqueeze(0).transpose(1, 2).to(dtype)
         | 
| 144 | 
            +
                    k = k.unsqueeze(0).transpose(1, 2).to(dtype)
         | 
| 145 | 
            +
                    v = v.unsqueeze(0).transpose(1, 2).to(dtype)
         | 
| 146 | 
            +
                    x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
         | 
| 147 | 
            +
                    x = x.transpose(1, 2).contiguous()
         | 
| 148 | 
            +
             | 
| 149 | 
            +
                # output
         | 
| 150 | 
            +
                return x.type(out_dtype)
         | 
| 151 | 
            +
             | 
| 152 | 
            +
             | 
| 153 | 
            +
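# --- Example (editor's sketch, not part of the original model code) ------------------------
# `flash_attention` above packs the batch into one token stream and hands the varlen kernels
# cumulative sequence offsets. The helper below (a hypothetical name) shows what those
# `cu_seqlens_*` tensors look like for two sequences of length 5 and 3.
def _sketch_cu_seqlens():
    import torch
    q_lens = torch.tensor([5, 3], dtype=torch.int32)  # per-sample query lengths, [B]
    cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)
    # tensor([0, 5, 8]): sequence i occupies packed tokens [cu_seqlens_q[i], cu_seqlens_q[i+1])
    return cu_seqlens_q
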
def create_sdpa_mask(q, k, q_lens, k_lens, causal=False):
    b, lq, lk = q.size(0), q.size(1), k.size(1)
    if q_lens is None:
        q_lens = torch.tensor([lq] * b, dtype=torch.int32)
    if k_lens is None:
        k_lens = torch.tensor([lk] * b, dtype=torch.int32)
    attn_mask = torch.zeros((b, lq, lk), dtype=torch.bool)
    for i in range(b):
        q_len, k_len = q_lens[i], k_lens[i]
        attn_mask[i, q_len:, :] = True
        attn_mask[i, :, k_len:] = True

        if causal:
            causal_mask = torch.triu(torch.ones((lq, lk), dtype=torch.bool), diagonal=1)
            attn_mask[i, :, :] = torch.logical_or(attn_mask[i, :, :], causal_mask)

    attn_mask = attn_mask.logical_not().to(q.device, non_blocking=True)
    return attn_mask

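# --- Example (editor's sketch, not part of the original model code) ------------------------
# `create_sdpa_mask` builds a boolean mask where True marks positions that may attend, in the
# form expected by `attn_mask` of scaled_dot_product_attention. A toy padded case:
def _sketch_sdpa_mask():
    import torch
    q = torch.zeros(1, 4, 8)          # [B, Lq, C]; only the shapes matter here
    k = torch.zeros(1, 6, 8)          # [B, Lk, C]
    mask = create_sdpa_mask(q, k,
                            q_lens=torch.tensor([3]),
                            k_lens=torch.tensor([5]),
                            causal=False)
    # mask.shape == (1, 4, 6); query row 3 and key column 5 are False (padding is masked out)
    return mask
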
def attention(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    dropout_p=0.,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
    fa_version=None,
):
    if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
        return flash_attention(
            q=q,
            k=k,
            v=v,
            q_lens=q_lens,
            k_lens=k_lens,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            q_scale=q_scale,
            causal=causal,
            window_size=window_size,
            deterministic=deterministic,
            dtype=dtype,
            version=fa_version,
        )
    else:
        if q_lens is not None or k_lens is not None:
            warnings.warn('Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.')
        attn_mask = None

        q = q.transpose(1, 2).to(dtype)
        k = k.transpose(1, 2).to(dtype)
        v = v.transpose(1, 2).to(dtype)

        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)

        out = out.transpose(1, 2).contiguous()
        return out

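# --- Example (editor's sketch, not part of the original model code) ------------------------
# Typical call into the `attention` wrapper with toy shapes [B, L, num_heads, head_dim].
# With flash-attn installed the tensors must live on a CUDA device (see the assert in
# `flash_attention`); without it, the call falls through to the SDPA branch and also runs on CPU.
def _sketch_attention_call():
    import torch
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    q = torch.randn(2, 16, 4, 32, device=device)
    k = torch.randn(2, 16, 4, 32, device=device)
    v = torch.randn(2, 16, 4, 32, device=device)
    out = attention(q, k, v, dtype=torch.bfloat16)   # [2, 16, 4, 32]
    return out.shape
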
def sinusoidal_embedding_1d(dim, position):
    # preprocess
    assert dim % 2 == 0
    half = dim // 2
    position = position.type(torch.float64)

    # calculation
    sinusoid = torch.outer(
        position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
    x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
    return x


@amp.autocast(enabled=False, device_type="cuda")
def rope_params(max_seq_len, dim, theta=10000):
    assert dim % 2 == 0
    freqs = torch.outer(
        torch.arange(max_seq_len),
        1.0 / torch.pow(theta,
                        torch.arange(0, dim, 2).to(torch.float64).div(dim)))
    freqs = torch.polar(torch.ones_like(freqs), freqs)
    return freqs

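# --- Example (editor's sketch, not part of the original model code) ------------------------
# `rope_params(L, d)` returns one complex rotation per (position, frequency pair), i.e. a
# complex tensor of shape [L, d // 2]. `WanModel.__init__` below builds three such tables
# (temporal, height, width) and concatenates them along the last dimension.
def _sketch_rope_params_shape():
    import torch
    freqs_t = rope_params(1024, 44)   # [1024, 22]
    freqs_h = rope_params(1024, 42)   # [1024, 21]
    freqs_w = rope_params(1024, 42)   # [1024, 21]
    freqs = torch.cat([freqs_t, freqs_h, freqs_w], dim=1)   # [1024, 64] == head_dim // 2 for head_dim 128
    return freqs.shape
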
@amp.autocast(enabled=False, device_type="cuda")
def rope_apply(x, grid_sizes, freqs, sequence_cond_compressed_indices=None):
    batch, seq_len_actual, n, c = x.size(0), x.size(1), x.size(2), x.size(3) // 2
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
    output = []
    assert len(grid_sizes) == batch, "grid_sizes must have the same length as the batch size ([B, 3=[f, h, w]])"
    for i, (f, h, w) in enumerate(grid_sizes.tolist()):
        seq_len = f * h * w
        x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(
            seq_len, n, -1, 2))
        freqs_i = torch.cat([
            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
        ], dim=-1).reshape(seq_len, 1, -1)
        x_i = torch.view_as_real(x_i * freqs_i).flatten(2)

        if seq_len_actual > seq_len:
            sequence_cond_seq_length = seq_len_actual - seq_len
            if sequence_cond_seq_length == seq_len:
                x_i_sequence_cond = torch.view_as_complex(x[i, seq_len:].to(torch.float64).reshape(seq_len_actual - seq_len, n, -1, 2))
                x_i_sequence_cond = torch.view_as_real(x_i_sequence_cond * freqs_i).flatten(2)
            else:
                sequence_cond_compressed_index = sequence_cond_compressed_indices[i]
                sequence_cond_t_length = len(sequence_cond_compressed_index)
                assert sequence_cond_t_length * h * w == sequence_cond_seq_length, "`sequence_cond_t_length * h * w` must be equal to `sequence_cond_seq_length`"
                x_i_sequence_cond = torch.view_as_complex(x[i, seq_len:].to(torch.float64).reshape(sequence_cond_seq_length, n, -1, 2))
                freqs_i_sequence_cond = torch.cat([
                    freqs[0][sequence_cond_compressed_index].view(sequence_cond_t_length, 1, 1, -1).expand(sequence_cond_t_length, h, w, -1),
                    freqs[1][:h].view(1, h, 1, -1).expand(sequence_cond_t_length, h, w, -1),
                    freqs[2][:w].view(1, 1, w, -1).expand(sequence_cond_t_length, h, w, -1)
                ], dim=-1).reshape(sequence_cond_seq_length, 1, -1)
                x_i_sequence_cond = torch.view_as_real(x_i_sequence_cond * freqs_i_sequence_cond).flatten(2)
            x_i = torch.cat([x_i, x_i_sequence_cond])

        output.append(x_i)
    return torch.stack(output).float()

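# --- Example (editor's sketch, not part of the original model code) ------------------------
# `rope_apply` splits the head's complex-pair dimension c = head_dim // 2 into three groups
# for the (frame, height, width) axes: (c - 2*(c//3), c//3, c//3). For the default
# head_dim = 128 (dim=2048, num_heads=16) that is (22, 21, 21), matching the (44, 42, 42)
# real-valued split used when `WanModel.__init__` builds `self.freqs`.
def _sketch_rope_split():
    head_dim = 128
    c = head_dim // 2
    split = (c - 2 * (c // 3), c // 3, c // 3)
    assert split == (22, 21, 21) and sum(split) == c
    return split
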
class WanRMSNorm(nn.Module):

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return self._norm(x.float()).type_as(x) * self.weight

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)


class WanLayerNorm(nn.LayerNorm):

    def __init__(self, dim, eps=1e-6, elementwise_affine=False):
        super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)

    def forward(self, x):
        return super().forward(x.float()).type_as(x)


class WanSelfAttention(nn.Module):

    def __init__(self,
                 dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 eps=1e-6):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.eps = eps

        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
        self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
        self.visualize_attention = False

    def forward(self, x, seq_lens, grid_sizes, freqs, sequence_cond_compressed_indices):
        """
        Args:
            x:              [B, L, C].
            seq_lens:       [B].
            grid_sizes:     [B, 3=[f, h, w]].
            freqs:          [L, 2].
            sequence_cond_compressed_indices: [B, T_sequence_cond].

        `f` in `grid_sizes` can be less than the actual sequence length (L),
        which indicates that either a full in-context condition (when L = 2*f) or a
        sparse in-context condition (when f < L < 2*f and
        `sequence_cond_compressed_indices` is not None) is used.
        """
        b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim

        def qkv_fn(x):
            q = self.norm_q(self.q(x)).view(b, s, n, d)
            k = self.norm_k(self.k(x)).view(b, s, n, d)
            v = self.v(x).view(b, s, n, d)
            return q, k, v

        q, k, v = qkv_fn(x)

        q_rope = rope_apply(q, grid_sizes, freqs, sequence_cond_compressed_indices)
        k_rope = rope_apply(k, grid_sizes, freqs, sequence_cond_compressed_indices)

        if self.visualize_attention:
            with torch.no_grad():
                self._last_attn_maps = self._compute_attention_for_visualization(q_rope, k_rope)  # CPU tensor of [S, S]
                self._last_grid_sizes = grid_sizes
                self._last_seq_lens = seq_lens

        x = flash_attention(
            q=q_rope,
            k=k_rope,
            v=v,
            k_lens=None,
            window_size=self.window_size)

        # output
        x = x.flatten(2)
        x = self.o(x)
        return x

    def _compute_attention_for_visualization(self, q, k):
        """Compute attention maps for visualization purposes"""
        # b, _, n, d = q.shape
        print("Computing attention maps for visualization")
        # Reshape for attention computation
        q = q.permute(0, 2, 1, 3)  # [b, n, s, d]
        k = k.permute(0, 2, 1, 3)  # [b, n, s, d]
        # query: b, n, s, d
        print("q.shape=", q.shape)
        print("k.shape=", k.shape)
        attention_probs_list = []
        for i in range(0, q.shape[1], 20):
            print(f"Computing attention for head {i} to {i+20}")
            query_attention = q[-1][i : i + 20]
            key_attention = k[-1][i : i + 20]
            identity_matrix = torch.eye(
                query_attention.shape[-2],
                device=query_attention.device,
                dtype=query_attention.dtype,
            )  # shape=[s, s]
            attention_probs_temp = torch.nn.functional.scaled_dot_product_attention(
                query_attention,
                key_attention,
                identity_matrix,
                attn_mask=None,
                dropout_p=0.0,
                is_causal=False,
            )
            attention_probs_list.append(attention_probs_temp.detach().cpu())
            del (
                query_attention,
                key_attention,
                identity_matrix,
                attention_probs_temp,
            )
        attention_probs = torch.mean(torch.cat(attention_probs_list), dim=0).float().numpy()
        print("Attention maps computed. Shape=", attention_probs.shape)
        # Only keep attention maps, don't compute the output
        return attention_probs  # [s, s]

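# --- Example (editor's sketch, not part of the original model code) ------------------------
# Shape walk-through for WanSelfAttention on a toy 3D grid. The RoPE table is assembled the
# same way `WanModel.__init__` does below; the forward pass needs a CUDA device because
# `flash_attention` asserts one, so the call is guarded here.
def _sketch_self_attention():
    import torch
    dim, num_heads = 128, 4
    d = dim // num_heads                                  # head_dim = 32
    freqs = torch.cat([
        rope_params(64, d - 4 * (d // 6)),                # temporal
        rope_params(64, 2 * (d // 6)),                    # height
        rope_params(64, 2 * (d // 6)),                    # width
    ], dim=1)                                             # [64, d // 2]
    if not torch.cuda.is_available():
        return None
    attn = WanSelfAttention(dim, num_heads).cuda()
    grid_sizes = torch.tensor([[2, 4, 4]])                # f, h, w -> L = 2*4*4 = 32 tokens
    x = torch.randn(1, 32, dim, device='cuda')
    seq_lens = torch.tensor([32], device='cuda')
    out = attn(x, seq_lens, grid_sizes, freqs.cuda(), None)
    return out.shape                                      # torch.Size([1, 32, 128])
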
class WanT2VCrossAttention(WanSelfAttention):

    def forward(self, x, context, context_lens):
        """
        x:              [B, L1, C].
        context:        [B, L2, C].
        context_lens:   [B].
        """
        b, n, d = x.size(0), self.num_heads, self.head_dim

        # compute query, key, value
        q = self.norm_q(self.q(x)).view(b, -1, n, d)
        k = self.norm_k(self.k(context)).view(b, -1, n, d)
        v = self.v(context).view(b, -1, n, d)

        # compute attention
        x = flash_attention(q, k, v, k_lens=context_lens)

        # output
        x = x.flatten(2)
        x = self.o(x)
        return x


class WanI2VCrossAttention(WanSelfAttention):

    def __init__(self,
                 dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 eps=1e-6):
        super().__init__(dim, num_heads, window_size, qk_norm, eps)

        self.k_img = nn.Linear(dim, dim)
        self.v_img = nn.Linear(dim, dim)
        # self.alpha = nn.Parameter(torch.zeros((1, )))
        self.norm_k_img = WanRMSNorm(
            dim, eps=eps) if qk_norm else nn.Identity()

    def forward(self, x, context, context_lens):
        """
        x:              [B, L1, C].
        context:        [B, L2, C].
        context_lens:   [B].
        """
        context_img = context[:, :257]
        context = context[:, 257:]
        b, n, d = x.size(0), self.num_heads, self.head_dim

        # compute query, key, value
        q = self.norm_q(self.q(x)).view(b, -1, n, d)
        k = self.norm_k(self.k(context)).view(b, -1, n, d)
        v = self.v(context).view(b, -1, n, d)
        k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d)
        v_img = self.v_img(context_img).view(b, -1, n, d)
        img_x = flash_attention(q, k_img, v_img, k_lens=None)
        # compute attention
        x = flash_attention(q, k, v, k_lens=context_lens)

        # output
        x = x.flatten(2)
        img_x = img_x.flatten(2)
        x = x + img_x
        x = self.o(x)
        return x

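# --- Example (editor's sketch, not part of the original model code) ------------------------
# WanI2VCrossAttention assumes the first 257 context tokens are the projected CLIP image
# embeddings (produced by `MLPProj` / `self.img_emb` further below) and everything after them
# is text. The slicing it performs is equivalent to:
def _sketch_i2v_context_layout():
    import torch
    dim = 2048
    img_ctx = torch.randn(1, 257, dim)    # e.g. self.img_emb(clip_fea)
    txt_ctx = torch.randn(1, 512, dim)    # e.g. self.text_embedding(context)
    context = torch.cat([img_ctx, txt_ctx], dim=1)        # [1, 769, dim]
    context_img, context_txt = context[:, :257], context[:, 257:]
    return context_img.shape, context_txt.shape           # ([1, 257, dim], [1, 512, dim])
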
WANX_CROSSATTENTION_CLASSES = {
    't2v_cross_attn': WanT2VCrossAttention,
    'i2v_cross_attn': WanI2VCrossAttention,
}


class WanAttentionBlock(nn.Module):

    def __init__(self,
                 cross_attn_type,
                 dim,
                 ffn_dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=False,
                 eps=1e-6,
                 use_local_lora=False,
                 use_dera=False,
                 dera_rank=None,
                 use_dera_spatial=True,
                 use_dera_temporal=True):
        super().__init__()
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps

        # layers
        self.norm1 = WanLayerNorm(dim, eps)
        self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps)
        self.norm3 = WanLayerNorm(
            dim, eps,
            elementwise_affine=True) if cross_attn_norm else nn.Identity()
        self.cross_attn = WANX_CROSSATTENTION_CLASSES[cross_attn_type](
            dim, num_heads, (-1, -1), qk_norm, eps)
        self.norm2 = WanLayerNorm(dim, eps)
        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
            nn.Linear(ffn_dim, dim))

        # modulation
        self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)

        self.use_local_lora = use_local_lora
        if use_local_lora:
            from .local_lora import LocalLoRA
            self.local_lora = LocalLoRA(dim=dim, rank=64, kernel_size=(3, 3), stride=(1, 1))

        self.use_dera = use_dera
        if use_dera:
            from .dera import DeRA
            self.dera = DeRA(dim, rank=dera_rank, use_spatial=use_dera_spatial, use_temporal=use_dera_temporal)

    def forward(
        self,
        x,
        e,
        seq_lens,
        grid_sizes,
        freqs,
        context,
        context_lens,
        sequence_cond_compressed_indices,
        dera_freqs=None
    ):
        assert e.dtype == torch.float32
        with amp.autocast(dtype=torch.float32, device_type="cuda"):
            e = (self.modulation.to(dtype=e.dtype, device=e.device) + e).chunk(6, dim=1)
        assert e[0].dtype == torch.float32

        # self-attention
        x_self_attn_input = self.norm1(x).float() * (1 + e[1]) + e[0]
        y = self.self_attn(x_self_attn_input, seq_lens, grid_sizes, freqs, sequence_cond_compressed_indices)
        if self.use_local_lora:
            y = y + self.local_lora(x_self_attn_input, grid_sizes)

        if self.use_dera:
            y = y + self.dera(x_self_attn_input, seq_lens, grid_sizes, dera_freqs, sequence_cond_compressed_indices)

        with amp.autocast(dtype=torch.float32, device_type="cuda"):
            x = x + y * e[2]

        def cross_attn_ffn(x, context, context_lens, e):
            x = x + self.cross_attn(self.norm3(x), context, context_lens)
            y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
            with amp.autocast(dtype=torch.float32, device_type="cuda"):
                x = x + y * e[5]
            return x

        x = cross_attn_ffn(x, context, context_lens, e)
        return x

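# --- Example (editor's sketch, not part of the original model code) ------------------------
# The adaLN-style modulation in WanAttentionBlock.forward: the per-timestep embedding `e`
# ([B, 6, dim]) is added to the learned `modulation` parameter and chunked into six
# [B, 1, dim] tensors: (shift, scale, gate) for self-attention and (shift, scale, gate) for
# the FFN, applied as `norm(x) * (1 + scale) + shift` and `x + y * gate`.
def _sketch_modulation_chunks():
    import torch
    dim = 2048
    modulation = torch.randn(1, 6, dim) / dim**0.5
    e = torch.randn(2, 6, dim, dtype=torch.float32)
    shift_sa, scale_sa, gate_sa, shift_ffn, scale_ffn, gate_ffn = (modulation + e).chunk(6, dim=1)
    return shift_sa.shape                                 # torch.Size([2, 1, 2048])
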
class Head(nn.Module):

    def __init__(self, dim, out_dim, patch_size, eps=1e-6):
        super().__init__()
        self.dim = dim
        self.out_dim = out_dim
        self.patch_size = patch_size
        self.eps = eps

        # layers
        out_dim = math.prod(patch_size) * out_dim
        self.norm = WanLayerNorm(dim, eps)
        self.head = nn.Linear(dim, out_dim)

        # modulation
        self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)

    def forward(self, x, e):
        assert e.dtype == torch.float32
        with amp.autocast(dtype=torch.float32, device_type="cuda"):
            e = (self.modulation.to(dtype=e.dtype, device=e.device) + e.unsqueeze(1)).chunk(2, dim=1)
            x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
        return x


class MLPProj(torch.nn.Module):

    def __init__(self, in_dim, out_dim):
        super().__init__()

        self.proj = torch.nn.Sequential(
            torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim),
            torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim),
            torch.nn.LayerNorm(out_dim))

    def forward(self, image_embeds):
        clip_extra_context_tokens = self.proj(image_embeds)
        return clip_extra_context_tokens

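# --- Example (editor's sketch, not part of the original model code) ------------------------
# MLPProj maps CLIP image features to model-width context tokens; in the i2v path the model
# instantiates it as `MLPProj(1280, dim)` and its output becomes the first 257 context tokens
# consumed by WanI2VCrossAttention above.
def _sketch_mlp_proj():
    import torch
    proj = MLPProj(in_dim=1280, out_dim=2048)
    clip_fea = torch.randn(1, 257, 1280)
    img_ctx = proj(clip_fea)       # [1, 257, 2048]
    return img_ctx.shape
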
            class WanModel(nn.Module):
         | 
| 623 | 
            +
             | 
| 624 | 
            +
                def __init__(self,
         | 
| 625 | 
            +
                             model_type='t2v',
         | 
| 626 | 
            +
                             patch_size=(1, 2, 2),
         | 
| 627 | 
            +
                             text_len=512,
         | 
| 628 | 
            +
                             in_dim=16,
         | 
| 629 | 
            +
                             dim=2048,
         | 
| 630 | 
            +
                             ffn_dim=8192,
         | 
| 631 | 
            +
                             freq_dim=256,
         | 
| 632 | 
            +
                             text_dim=4096,
         | 
| 633 | 
            +
                             out_dim=16,
         | 
| 634 | 
            +
                             num_heads=16,
         | 
| 635 | 
            +
                             num_layers=32,
         | 
| 636 | 
            +
                             window_size=(-1, -1),
         | 
| 637 | 
            +
                             qk_norm=True,
         | 
| 638 | 
            +
                             cross_attn_norm=False,
         | 
| 639 | 
            +
                             eps=1e-6,
         | 
| 640 | 
            +
                             use_local_lora=False,
         | 
| 641 | 
            +
                             use_dera=False,
         | 
| 642 | 
            +
                             dera_rank=None,
         | 
| 643 | 
            +
                             use_dera_spatial=True,
         | 
| 644 | 
            +
                             use_dera_temporal=True,
         | 
| 645 | 
            +
                             use_sequence_cond=False,
         | 
| 646 | 
            +
                             sequence_cond_in_dim=None,
         | 
| 647 | 
            +
                             sequence_cond_mode=None,
         | 
| 648 | 
            +
                             use_channel_cond=False,
         | 
| 649 | 
            +
                             channel_cond_in_dim=None,
         | 
| 650 | 
            +
                             use_sequence_cond_position_aware_residual=False,
         | 
| 651 | 
            +
                             use_sequence_cond_loss=False
         | 
| 652 | 
            +
                             ):
         | 
| 653 | 
            +
                    super().__init__()
         | 
| 654 | 
            +
             | 
| 655 | 
            +
                    assert model_type in ['t2v', 'i2v']
         | 
| 656 | 
            +
                    self.model_type = model_type
         | 
| 657 | 
            +
             | 
| 658 | 
            +
                    self.patch_size = patch_size
         | 
| 659 | 
            +
                    self.text_len = text_len
         | 
| 660 | 
            +
                    self.in_dim = in_dim
         | 
| 661 | 
            +
                    self.dim = dim
         | 
| 662 | 
            +
                    self.ffn_dim = ffn_dim
         | 
| 663 | 
            +
                    self.freq_dim = freq_dim
         | 
| 664 | 
            +
                    self.text_dim = text_dim
         | 
| 665 | 
            +
                    self.out_dim = out_dim
         | 
| 666 | 
            +
                    self.num_heads = num_heads
         | 
| 667 | 
            +
                    self.num_layers = num_layers
         | 
| 668 | 
            +
                    self.window_size = window_size
         | 
| 669 | 
            +
                    self.qk_norm = qk_norm
         | 
| 670 | 
            +
                    self.cross_attn_norm = cross_attn_norm
         | 
| 671 | 
            +
                    self.eps = eps
         | 
| 672 | 
            +
                    
         | 
| 673 | 
            +
                    self.use_local_lora = use_local_lora
         | 
| 674 | 
            +
                    self.use_dera = use_dera
         | 
| 675 | 
            +
             | 
| 676 | 
            +
                    # embeddings
         | 
| 677 | 
            +
                    self.patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size)
         | 
| 678 | 
            +
                    self.text_embedding = nn.Sequential(
         | 
| 679 | 
            +
                        nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
         | 
| 680 | 
            +
                        nn.Linear(dim, dim))
         | 
| 681 | 
            +
             | 
| 682 | 
            +
                    self.time_embedding = nn.Sequential(
         | 
| 683 | 
            +
                        nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
         | 
| 684 | 
            +
                    self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
         | 
| 685 | 
            +
             | 
| 686 | 
            +
                    # blocks
         | 
| 687 | 
            +
                    cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
         | 
| 688 | 
            +
                    self.blocks = nn.ModuleList([
         | 
| 689 | 
            +
                        WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
         | 
| 690 | 
            +
                                           window_size, qk_norm, cross_attn_norm, eps, use_local_lora=use_local_lora, 
         | 
| 691 | 
            +
                                           use_dera=use_dera, dera_rank=dera_rank, use_dera_spatial=use_dera_spatial, use_dera_temporal=use_dera_temporal)
         | 
| 692 | 
            +
                        for _ in range(num_layers)
         | 
| 693 | 
            +
                    ])
         | 
| 694 | 
            +
             | 
| 695 | 
            +
                    # head
         | 
| 696 | 
            +
                    self.head = Head(dim, out_dim, patch_size, eps)
         | 
| 697 | 
            +
             | 
| 698 | 
            +
                    # buffers (don't use register_buffer otherwise dtype will be changed in to())
         | 
| 699 | 
            +
                    assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
         | 
| 700 | 
            +
                    d = dim // num_heads
         | 
| 701 | 
            +
                    self.freqs = torch.cat([
         | 
| 702 | 
            +
                        rope_params(1024, d - 4 * (d // 6)),
         | 
| 703 | 
            +
                        rope_params(1024, 2 * (d // 6)),
         | 
| 704 | 
            +
                        rope_params(1024, 2 * (d // 6))
         | 
| 705 | 
            +
                    ], dim=1)
         | 
| 706 | 
            +
                    
         | 
| 707 | 
            +
                    if self.use_dera:
         | 
| 708 | 
            +
                        dera_d = dera_rank // 4  # (18)
         | 
| 709 | 
            +
                        self.dera_freqs = torch.cat([
         | 
| 710 | 
            +
                            rope_params(1024, dera_d - 4 * (dera_d // 6)),
         | 
| 711 | 
            +
                            rope_params(1024, 2 * (dera_d // 6)),
         | 
| 712 | 
            +
                            rope_params(1024, 2 * (dera_d // 6))
         | 
| 713 | 
            +
                        ], dim=1)
         | 
| 714 | 
            +
                    else:
         | 
| 715 | 
            +
                        self.dera_freqs = None
         | 
| 716 | 
            +
             | 
| 717 | 
            +
                    if model_type == 'i2v':
         | 
| 718 | 
            +
                        self.img_emb = MLPProj(1280, dim)
         | 
| 719 | 
            +
                        
         | 
| 720 | 
            +
                    self.init_weights()
         | 
| 721 | 
            +
                    
         | 
| 722 | 
            +
                    self.use_sequence_cond = use_sequence_cond
         | 
| 723 | 
            +
                    self.sequence_cond_in_dim = sequence_cond_in_dim
         | 
| 724 | 
            +
                    self.sequence_cond_mode = sequence_cond_mode
         | 
| 725 | 
            +
                    if use_sequence_cond:
         | 
| 726 | 
            +
                        assert sequence_cond_in_dim is not None, "`sequence_cond_in_dim` must be provided when `use_sequence_cond` is True"
         | 
| 727 | 
            +
                        self.sequence_cond_patch_embedding = nn.Conv3d(sequence_cond_in_dim, dim, kernel_size=patch_size, stride=patch_size)
         | 
| 728 | 
            +
                        self.sequence_cond_identifier = nn.Parameter(torch.randn(1, 1, dim) / dim**0.5)
         | 
| 729 | 
            +
                    
         | 
| 730 | 
            +
                    self.use_channel_cond = use_channel_cond
         | 
| 731 | 
            +
                    self.channel_cond_in_dim = channel_cond_in_dim
         | 
| 732 | 
            +
                    if use_channel_cond:
         | 
| 733 | 
            +
                        assert channel_cond_in_dim is not None, "`channel_cond_in_dim` must be provided when `use_channel_cond` is True"
         | 
| 734 | 
            +
                    self.use_sequence_cond_position_aware_residual = use_sequence_cond_position_aware_residual
         | 
| 735 | 
            +
                    if use_sequence_cond_position_aware_residual:
         | 
| 736 | 
            +
                        self.sequence_cond_residual_proj = nn.Linear(dim, dim, bias=False)
         | 
| 737 | 
            +
                        self.sequence_cond_residual_proj.weight.data.zero_()
         | 
| 738 | 
            +
                    
         | 
| 739 | 
            +
                    self.use_sequence_cond_loss = use_sequence_cond_loss
         | 
| 740 | 
            +
                    if self.use_sequence_cond_loss:
         | 
| 741 | 
            +
                        self.sequence_latent_to_cond_proj = nn.Linear(dim, dim, bias=False)
         | 
| 742 | 
            +
                        self.sequence_latent_to_cond_proj.weight.data.zero_()
         | 
| 743 | 
            +
                        self.head_sequence_cond_out = nn.Linear(dim, math.prod(patch_size) * out_dim)
         | 
| 744 | 
            +
                
         | 
| 745 | 
            +
                def copy_sequence_cond_patch_embedding_weights(self):
         | 
| 746 | 
            +
                    size_patch_embedding = self.patch_embedding.weight.size(1)
         | 
| 747 | 
            +
                    size_sequence_cond_patch_embedding = self.sequence_cond_patch_embedding.weight.size(1)
         | 
| 748 | 
            +
                    self.sequence_cond_patch_embedding.weight.data = self.patch_embedding.weight.data[:, size_patch_embedding - size_sequence_cond_patch_embedding:, :, :, :].clone()
         | 
| 749 | 
            +
                    if self.patch_embedding.bias is not None:
         | 
| 750 | 
            +
                        self.sequence_cond_patch_embedding.bias.data = self.patch_embedding.bias.data.clone()
         | 
| 751 | 
            +
                
         | 
| 752 | 
            +
                def copy_patch_embedding_weights_for_channel_cond(self):
         | 
| 753 | 
            +
                    original_patch_in_channels = self.patch_embedding.in_channels
         | 
| 754 | 
            +
                    new_patch_embedding = nn.Conv3d(in_channels=original_patch_in_channels + self.channel_cond_in_dim, 
         | 
| 755 | 
            +
                                                    out_channels=self.dim, kernel_size=self.patch_size, stride=self.patch_size)
         | 
| 756 | 
            +
                    new_patch_embedding.weight.data[:, :original_patch_in_channels, :, :, :] = self.patch_embedding.weight.data.clone()
         | 
| 757 | 
            +
                    if self.patch_embedding.bias is not None:
         | 
| 758 | 
            +
                        new_patch_embedding.bias.data = self.patch_embedding.bias.data.clone()
         | 
| 759 | 
            +
                    del self.patch_embedding
         | 
| 760 | 
            +
                    self.patch_embedding = new_patch_embedding
         | 
| 761 | 
            +
             | 
| 762 | 
            +
                def forward(
         | 
| 763 | 
            +
                    self,
         | 
| 764 | 
            +
                    x,
         | 
| 765 | 
            +
                    timestep,
         | 
| 766 | 
            +
                    context,
         | 
| 767 | 
            +
                    seq_len,
         | 
| 768 | 
            +
                    clip_fea=None,
         | 
| 769 | 
            +
                    y=None,
         | 
| 770 | 
            +
                    use_gradient_checkpointing=False,
         | 
| 771 | 
            +
                    sequence_cond=None,
         | 
| 772 | 
            +
                    sequence_cond_compressed_indices=None,
         | 
| 773 | 
            +
                    channel_cond=None,
         | 
| 774 | 
            +
                    sequence_cond_residual_scale=1.0,
         | 
| 775 | 
            +
                    **kwargs,
         | 
| 776 | 
            +
                ):
         | 
| 777 | 
            +
                    """
         | 
| 778 | 
            +
                    x:              A list of videos each with shape [C, T, H, W].
         | 
| 779 | 
            +
                    timestep:       [B].
         | 
| 780 | 
            +
                    context:        A list of text embeddings each with shape [L, C].
         | 
| 781 | 
            +
                    sequence_cond: A list of conditional frames each with shape [C, T_sequence_cond, H, W].
         | 
| 782 | 
            +
                    sequence_cond_compressed_indices: [B, T_sequence_cond] Indices of the latent frames that the sparse sequence condition corresponds to, where T_sequence_cond < T. For sparse mode only.
         | 
| 783 | 
            +
                    
         | 
| 784 | 
            +
                    
         | 
| 785 | 
            +
                        Note:
         | 
| 786 | 
            +
                        sequence_cond will be injected into the model as additional tokens, i.e., concatenated along the sequence dimension.
         | 
| 787 | 
            +
                        channel_cond will be injected into the model along the input's channel dimension.
         | 
| 788 | 
            +
                            
         | 
| 789 | 
            +
                        Examples:
         | 
| 790 | 
            +
                            1) for extra cond case:
         | 
| 791 | 
            +
                                # given x: [B, C, T, H, W] ----> [B, L=T*H*W, C] --patch_embedding--> [B, L, D]
         | 
| 792 | 
            +
                                # sequence_cond: [B, C_sequence_cond, T_sequence_cond, H, W] ----> [B, L_sequence_cond=T_sequence_cond*H*W, C_sequence_cond] --sequence_cond_embedding--> [B, L_sequence_cond, D]
         | 
| 793 | 
            +
                                x = torch.concat([x, sequence_cond], dim=2) # Concat on sequence dimension after patch/extra cond embedding
         | 
| 794 | 
            +
                                # after concat, x: [B, L+L_sequence_cond, D]
         | 
| 795 | 
            +
                            2) for channel cond case:
         | 
| 796 | 
            +
                                given x: [B, C, T, H, W]
         | 
| 797 | 
            +
                                channel_cond: [B, C_CHANNEL_COND, T, H, W]
         | 
| 798 | 
            +
                                x = torch.concat([x, channel_cond], dim=1) # Concat on channel dimension before patch/extra cond embedding    
         | 
| 799 | 
            +
                                # x: [B, C + C_CHANNEL_COND, T, H, W] --patch_embedding(requires param copy and tuning)--> [B, L=T*H*W, D]
         | 
| 800 | 
            +
                    """
         | 
| 801 | 
            +
                    if self.model_type == 'i2v':
         | 
| 802 | 
            +
                        assert clip_fea is not None and y is not None
         | 
| 803 | 
            +
                    # params
         | 
| 804 | 
            +
                    device = x[0].device
         | 
| 805 | 
            +
                    if self.freqs.device != device:
         | 
| 806 | 
            +
                        self.freqs = self.freqs.to(device)
         | 
| 807 | 
            +
                    if self.dera_freqs is not None and self.dera_freqs.device != device:
         | 
| 808 | 
            +
                        self.dera_freqs = self.dera_freqs.to(device)
         | 
| 809 | 
            +
                        
         | 
| 810 | 
            +
                    if y is not None:
         | 
| 811 | 
            +
                        x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
         | 
| 812 | 
            +
                    
         | 
| 813 | 
            +
                    if channel_cond is not None:
         | 
| 814 | 
            +
                        assert self.use_channel_cond, "forward argument `channel_cond` is provided but model property `self.use_channel_cond` is False"
         | 
| 815 | 
            +
                        x = [torch.cat([u, v], dim=0) for u, v in zip(x, channel_cond)]
         | 
| 816 | 
            +
             | 
| 817 | 
            +
                    # embeddings
         | 
| 818 | 
            +
                    x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
         | 
| 819 | 
            +
                    grid_sizes = torch.stack(
         | 
| 820 | 
            +
                        [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
         | 
| 821 | 
            +
                    x = [u.flatten(2).transpose(1, 2) for u in x]
         | 
| 822 | 
            +
                    x = torch.cat(x, dim=0)
         | 
| 823 | 
            +
                    
         | 
| 824 | 
            +
                    if sequence_cond is not None:
         | 
| 825 | 
            +
                        assert self.use_sequence_cond, "forward argument `sequence_cond` is provided but model property `self.use_sequence_cond` is False"
         | 
| 826 | 
            +
                        sequence_cond = [self.sequence_cond_patch_embedding(u.unsqueeze(0)) for u in sequence_cond]
         | 
| 827 | 
            +
                        sequence_cond = [u.flatten(2).transpose(1, 2) + self.sequence_cond_identifier for u in sequence_cond]
         | 
| 828 | 
            +
                        sequence_cond = torch.concat(sequence_cond, dim=0)
         | 
| 829 | 
            +
                        
         | 
| 830 | 
            +
                        x = torch.concat([x, sequence_cond], dim=1)
         | 
| 831 | 
            +
                    
         | 
| 832 | 
            +
                    actual_seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
         | 
| 833 | 
            +
                    
         | 
| 834 | 
            +
                    # time embeddings
         | 
| 835 | 
            +
                    with amp.autocast(dtype=torch.float32, device_type="cuda"):
         | 
| 836 | 
            +
                        e = self.time_embedding(
         | 
| 837 | 
            +
                            sinusoidal_embedding_1d(self.freq_dim, timestep).float())
         | 
| 838 | 
            +
                        e0 = self.time_projection(e).unflatten(1, (6, self.dim))
         | 
| 839 | 
            +
                        assert e.dtype == torch.float32 and e0.dtype == torch.float32
         | 
| 840 | 
            +
             | 
| 841 | 
            +
                    # context
         | 
| 842 | 
            +
                    context_lens = None
         | 
| 843 | 
            +
                    context = self.text_embedding(
         | 
| 844 | 
            +
                        torch.stack([
         | 
| 845 | 
            +
                            torch.cat(
         | 
| 846 | 
            +
                                [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
         | 
| 847 | 
            +
                            for u in context
         | 
| 848 | 
            +
                        ]))
         | 
| 849 | 
            +
             | 
| 850 | 
            +
                    if clip_fea is not None:
         | 
| 851 | 
            +
                        context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
         | 
| 852 | 
            +
                        context = torch.concat([context_clip, context], dim=1)
         | 
| 853 | 
            +
             | 
| 854 | 
            +
                    # arguments
         | 
| 855 | 
            +
                    kwargs = dict(e=e0, seq_lens=actual_seq_lens, grid_sizes=grid_sizes,
         | 
| 856 | 
            +
                        freqs=self.freqs, context=context, context_lens=context_lens,
         | 
| 857 | 
            +
                        sequence_cond_compressed_indices=sequence_cond_compressed_indices, dera_freqs=self.dera_freqs)
         | 
| 858 | 
            +
                    
         | 
| 859 | 
            +
                    def create_custom_forward(module):
         | 
| 860 | 
            +
                        def custom_forward(*inputs, **kwargs):
         | 
| 861 | 
            +
                            return module(*inputs, **kwargs)
         | 
| 862 | 
            +
                        return custom_forward
         | 
| 863 | 
            +
             | 
| 864 | 
            +
                    for block_idx, block in enumerate(self.blocks):
         | 
| 865 | 
            +
                        if self.training and use_gradient_checkpointing:
         | 
| 866 | 
            +
                            x = torch.utils.checkpoint.checkpoint(
         | 
| 867 | 
            +
                                create_custom_forward(block),
         | 
| 868 | 
            +
                                x, **kwargs,
         | 
| 869 | 
            +
                                use_reentrant=False,
         | 
| 870 | 
            +
                            )
         | 
| 871 | 
            +
                        else:
         | 
| 872 | 
            +
                            x = block(x, **kwargs)
         | 
| 873 | 
            +
                            
         | 
| 874 | 
            +
                        if self.use_sequence_cond_loss and block_idx == len(self.blocks) - 3:
         | 
| 875 | 
            +
                            # At this block, the context length is extended from (N+C) to 2N, where C is the length of the sparse sequence cond.
         | 
| 876 | 
            +
                            x_ori = x[:, :seq_len, :]
         | 
| 877 | 
            +
                            x_ori_projected = self.sequence_latent_to_cond_proj(x_ori)
         | 
| 878 | 
            +
                            x_seq_cond = x[:, seq_len:, :]
         | 
| 879 | 
            +
                            seq_cond_length = len(sequence_cond_compressed_indices[0])
         | 
| 880 | 
            +
                            x_ori_projected = rearrange(x_ori_projected, 'b (t h w) c -> b c t h w', t=grid_sizes[0, 0], h=grid_sizes[0, 1], w=grid_sizes[0, 2])
         | 
| 881 | 
            +
                            x_seq_cond = rearrange(x_seq_cond, 'b (t h w) c -> b c t h w', t=seq_cond_length, h=grid_sizes[0, 1], w=grid_sizes[0, 2])
         | 
| 882 | 
            +
                            x_ori_projected[:, :, sequence_cond_compressed_indices[0], :, :] += x_seq_cond
         | 
| 883 | 
            +
                            x_ori_projected = rearrange(x_ori_projected, 'b c t h w -> b (t h w) c')
         | 
| 884 | 
            +
                            x = torch.concat([x_ori, x_ori_projected], dim=1)
         | 
| 885 | 
            +
                            # Let the later blocks generate sketches at the full sequence length
         | 
| 886 | 
            +
                            
         | 
| 887 | 
            +
                        if self.use_sequence_cond_position_aware_residual and block_idx < len(self.blocks) - 1:
         | 
| 888 | 
            +
                            # Apply the sequence condition position-aware residual for all blocks except the last one
         | 
| 889 | 
            +
                            x_ori = x[:, :seq_len, :]
         | 
| 890 | 
            +
                            x_seq_cond = x[:, seq_len:, :]
         | 
| 891 | 
            +
                            x_seq_cond_projected = self.sequence_cond_residual_proj(x_seq_cond)
         | 
| 892 | 
            +
                            assert x_ori.shape[0] == 1, "Only support batch size 1 for `sequence_cond_position_aware_residual`."
         | 
| 893 | 
            +
                            seq_cond_length = len(sequence_cond_compressed_indices[0])
         | 
| 894 | 
            +
                            x_ori = rearrange(x_ori, 'b (t h w) c -> b c t h w', t=grid_sizes[0, 0], h=grid_sizes[0, 1], w=grid_sizes[0, 2])
         | 
| 895 | 
            +
                            x_seq_cond_projected = rearrange(x_seq_cond_projected, 'b (t h w) c -> b c t h w', t=seq_cond_length, h=grid_sizes[0, 1], w=grid_sizes[0, 2])
         | 
| 896 | 
            +
                            
         | 
| 897 | 
            +
                            x_ori[:, :, sequence_cond_compressed_indices[0], :, :] = x_ori[:, :, sequence_cond_compressed_indices[0], :, :] + x_seq_cond_projected * sequence_cond_residual_scale
         | 
| 898 | 
            +
                            x_ori = rearrange(x_ori, 'b c t h w -> b (t h w) c')
         | 
| 899 | 
            +
                            x = torch.concat([x_ori, x_seq_cond], dim=1)
         | 
| 900 | 
            +
                            
         | 
| 901 | 
            +
                    if sequence_cond is not None:
         | 
| 902 | 
            +
                        if self.use_sequence_cond_loss:
         | 
| 903 | 
            +
                            sequence_cond_out = x[:, seq_len:, :]
         | 
| 904 | 
            +
                            sequence_cond_out = self.unpatchify(sequence_cond_out, grid_sizes)  # sequence_cond_grid_sizes
         | 
| 905 | 
            +
                            sequence_cond_out = torch.stack(sequence_cond_out).float()  # b, c, t, h, w
         | 
| 906 | 
            +
                        else:
         | 
| 907 | 
            +
                            sequence_cond_out = None
         | 
| 908 | 
            +
                        x = x[:, :seq_len, :]
         | 
| 909 | 
            +
                    # head
         | 
| 910 | 
            +
                    x = self.head(x, e)
         | 
| 911 | 
            +
             | 
| 912 | 
            +
                    # unpatchify
         | 
| 913 | 
            +
                    x = self.unpatchify(x, grid_sizes)
         | 
| 914 | 
            +
                    x = torch.stack(x).float()
         | 
| 915 | 
            +
                    if sequence_cond is not None and self.use_sequence_cond_loss:
         | 
| 916 | 
            +
                        return x, sequence_cond_out
         | 
| 917 | 
            +
                    return x
         | 
| 918 | 
            +
             | 
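A minimal sketch (hypothetical sizes, not the repository forward pass) of the token bookkeeping in the forward method above: sequence-cond tokens are appended after the video tokens, processed jointly by the blocks, and sliced off again before the output head.

```python
import torch

B, D = 1, 64                            # hypothetical batch and model dim
t, h, w = 8, 4, 4                       # latent token grid
cond_frames = [0, 3, 7]                 # sparsely conditioned frame indices
seq_len = t * h * w                     # number of video tokens
cond_len = len(cond_frames) * h * w     # number of sequence-cond tokens

x = torch.randn(B, seq_len, D)
seq_cond = torch.randn(B, cond_len, D)
tokens = torch.cat([x, seq_cond], dim=1)    # [B, seq_len + cond_len, D]

# ... transformer blocks attend over the joint token sequence ...

video_tokens = tokens[:, :seq_len]          # goes to the output head
cond_tokens = tokens[:, seq_len:]           # optional sketch-prediction branch
print(video_tokens.shape, cond_tokens.shape)
```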
| 919 | 
            +
                def unpatchify(self, x, grid_sizes):
         | 
| 920 | 
            +
                    c = self.out_dim
         | 
| 921 | 
            +
                    out = []
         | 
| 922 | 
            +
                    for u, v in zip(x, grid_sizes.tolist()):
         | 
| 923 | 
            +
                        u = u[:math.prod(v)].view(*v, *self.patch_size, c)
         | 
| 924 | 
            +
                        u = torch.einsum('fhwpqrc->cfphqwr', u)
         | 
| 925 | 
            +
                        u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
         | 
| 926 | 
            +
                        out.append(u)
         | 
| 927 | 
            +
                    return out
         | 
| 928 | 
            +
             | 
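Quick shape check for the unpatchify einsum above, using small hypothetical sizes: each token carries prod(patch_size) * out_dim values, and the einsum interleaves the token grid with the per-token patch before folding everything back into a video tensor.

```python
import math
import torch

c, patch = 4, (1, 2, 2)                           # out_dim and patch_size
grid = (3, 5, 5)                                  # (t, h, w) token grid
tokens = torch.randn(math.prod(grid), math.prod(patch) * c)

u = tokens.view(*grid, *patch, c)                 # [t, h, w, pt, ph, pw, c]
u = torch.einsum('fhwpqrc->cfphqwr', u)           # interleave grid and patch axes
video = u.reshape(c, *[g * p for g, p in zip(grid, patch)])
print(video.shape)                                # torch.Size([4, 3, 10, 10])
```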
| 929 | 
            +
                def init_weights(self):
         | 
| 930 | 
            +
                    for m in self.modules():
         | 
| 931 | 
            +
                        if isinstance(m, nn.Linear):
         | 
| 932 | 
            +
                            nn.init.xavier_uniform_(m.weight)
         | 
| 933 | 
            +
                            if m.bias is not None:
         | 
| 934 | 
            +
                                nn.init.zeros_(m.bias)
         | 
| 935 | 
            +
             | 
| 936 | 
            +
                    # init embeddings
         | 
| 937 | 
            +
                    nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
         | 
| 938 | 
            +
                    for m in self.text_embedding.modules():
         | 
| 939 | 
            +
                        if isinstance(m, nn.Linear):
         | 
| 940 | 
            +
                            nn.init.normal_(m.weight, std=.02)
         | 
| 941 | 
            +
                    for m in self.time_embedding.modules():
         | 
| 942 | 
            +
                        if isinstance(m, nn.Linear):
         | 
| 943 | 
            +
                            nn.init.normal_(m.weight, std=.02)
         | 
| 944 | 
            +
             | 
| 945 | 
            +
                    # init output layer
         | 
| 946 | 
            +
                    nn.init.zeros_(self.head.head.weight)
         | 
| 947 | 
            +
             | 
| 948 | 
            +
                @staticmethod
         | 
| 949 | 
            +
                def state_dict_converter():
         | 
| 950 | 
            +
                    return WanModelStateDictConverter()
         | 
| 951 | 
            +
                
         | 
| 952 | 
            +
                
         | 
| 953 | 
            +
            class WanModelStateDictConverter:
         | 
| 954 | 
            +
                def __init__(self):
         | 
| 955 | 
            +
                    pass
         | 
| 956 | 
            +
             | 
| 957 | 
            +
                def from_diffusers(self, state_dict):
         | 
| 958 | 
            +
                    rename_dict = {"blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight",
         | 
| 959 | 
            +
                        "blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight",
         | 
| 960 | 
            +
                        "blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias",
         | 
| 961 | 
            +
                        "blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight",
         | 
| 962 | 
            +
                        "blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias",
         | 
| 963 | 
            +
                        "blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight",
         | 
| 964 | 
            +
                        "blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias",
         | 
| 965 | 
            +
                        "blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight",
         | 
| 966 | 
            +
                        "blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias",
         | 
| 967 | 
            +
                        "blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight",
         | 
| 968 | 
            +
                        "blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight",
         | 
| 969 | 
            +
                        "blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight",
         | 
| 970 | 
            +
                        "blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias",
         | 
| 971 | 
            +
                        "blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight",
         | 
| 972 | 
            +
                        "blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias",
         | 
| 973 | 
            +
                        "blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight",
         | 
| 974 | 
            +
                        "blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias",
         | 
| 975 | 
            +
                        "blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight",
         | 
| 976 | 
            +
                        "blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias",
         | 
| 977 | 
            +
                        "blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight",
         | 
| 978 | 
            +
                        "blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias",
         | 
| 979 | 
            +
                        "blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight",
         | 
| 980 | 
            +
                        "blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias",
         | 
| 981 | 
            +
                        "blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight",
         | 
| 982 | 
            +
                        "blocks.0.norm2.bias": "blocks.0.norm3.bias",
         | 
| 983 | 
            +
                        "blocks.0.norm2.weight": "blocks.0.norm3.weight",
         | 
| 984 | 
            +
                        "blocks.0.scale_shift_table": "blocks.0.modulation",
         | 
| 985 | 
            +
                        "condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias",
         | 
| 986 | 
            +
                        "condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight",
         | 
| 987 | 
            +
                        "condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias",
         | 
| 988 | 
            +
                        "condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight",
         | 
| 989 | 
            +
                        "condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias",
         | 
| 990 | 
            +
                        "condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight",
         | 
| 991 | 
            +
                        "condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias",
         | 
| 992 | 
            +
                        "condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
         | 
| 993 | 
            +
                        "condition_embedder.time_proj.bias": "time_projection.1.bias",
         | 
| 994 | 
            +
                        "condition_embedder.time_proj.weight": "time_projection.1.weight",
         | 
| 995 | 
            +
                        "patch_embedding.bias": "patch_embedding.bias",
         | 
| 996 | 
            +
                        "patch_embedding.weight": "patch_embedding.weight",
         | 
| 997 | 
            +
                        "scale_shift_table": "head.modulation",
         | 
| 998 | 
            +
                        "proj_out.bias": "head.head.bias",
         | 
| 999 | 
            +
                        "proj_out.weight": "head.head.weight",
         | 
| 1000 | 
            +
                    }
         | 
| 1001 | 
            +
                    state_dict_ = {}
         | 
| 1002 | 
            +
                    for name, param in state_dict.items():
         | 
| 1003 | 
            +
                        if name in rename_dict:
         | 
| 1004 | 
            +
                            state_dict_[rename_dict[name]] = param
         | 
| 1005 | 
            +
                        else:
         | 
| 1006 | 
            +
                            name_ = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:])
         | 
| 1007 | 
            +
                            if name_ in rename_dict:
         | 
| 1008 | 
            +
                                name_ = rename_dict[name_]
         | 
| 1009 | 
            +
                                name_ = ".".join(name_.split(".")[:1] + [name.split(".")[1]] + name_.split(".")[2:])
         | 
| 1010 | 
            +
                                state_dict_[name_] = param
         | 
| 1011 | 
            +
                    if hash_state_dict_keys(state_dict) == "cb104773c6c2cb6df4f9529ad5c60d0b":
         | 
| 1012 | 
            +
                        config = {
         | 
| 1013 | 
            +
                            "model_type": "t2v",
         | 
| 1014 | 
            +
                            "patch_size": (1, 2, 2),
         | 
| 1015 | 
            +
                            "text_len": 512,
         | 
| 1016 | 
            +
                            "in_dim": 16,
         | 
| 1017 | 
            +
                            "dim": 5120,
         | 
| 1018 | 
            +
                            "ffn_dim": 13824,
         | 
| 1019 | 
            +
                            "freq_dim": 256,
         | 
| 1020 | 
            +
                            "text_dim": 4096,
         | 
| 1021 | 
            +
                            "out_dim": 16,
         | 
| 1022 | 
            +
                            "num_heads": 40,
         | 
| 1023 | 
            +
                            "num_layers": 40,
         | 
| 1024 | 
            +
                            "window_size": (-1, -1),
         | 
| 1025 | 
            +
                            "qk_norm": True,
         | 
| 1026 | 
            +
                            "cross_attn_norm": True,
         | 
| 1027 | 
            +
                            "eps": 1e-6,
         | 
| 1028 | 
            +
                        }
         | 
| 1029 | 
            +
                    else:
         | 
| 1030 | 
            +
                        config = {}
         | 
| 1031 | 
            +
                    return state_dict_, config
         | 
| 1032 | 
            +
                
         | 
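The fallback branch in from_diffusers generalizes the block-0 rename table to every block index by temporarily collapsing the index to 0. A small sketch with a single hypothetical key:

```python
# One hypothetical key, not the full rename table.
rename_dict = {"blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight"}

name = "blocks.17.attn1.to_q.weight"
# Collapse the block index to 0 to look up the block-0 template ...
name_ = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:])
template = rename_dict[name_]                     # "blocks.0.self_attn.q.weight"
# ... then restore the original block index in the renamed key.
renamed = ".".join(template.split(".")[:1] + [name.split(".")[1]] + template.split(".")[2:])
print(renamed)                                    # blocks.17.self_attn.q.weight
```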
| 1033 | 
            +
                def from_civitai(self, state_dict):
         | 
| 1034 | 
            +
                    if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814":
         | 
| 1035 | 
            +
                        config = {
         | 
| 1036 | 
            +
                            "model_type": "t2v",
         | 
| 1037 | 
            +
                            "patch_size": (1, 2, 2),
         | 
| 1038 | 
            +
                            "text_len": 512,
         | 
| 1039 | 
            +
                            "in_dim": 16,
         | 
| 1040 | 
            +
                            "dim": 1536,
         | 
| 1041 | 
            +
                            "ffn_dim": 8960,
         | 
| 1042 | 
            +
                            "freq_dim": 256,
         | 
| 1043 | 
            +
                            "text_dim": 4096,
         | 
| 1044 | 
            +
                            "out_dim": 16,
         | 
| 1045 | 
            +
                            "num_heads": 12,
         | 
| 1046 | 
            +
                            "num_layers": 30,
         | 
| 1047 | 
            +
                            "window_size": (-1, -1),
         | 
| 1048 | 
            +
                            "qk_norm": True,
         | 
| 1049 | 
            +
                            "cross_attn_norm": True,
         | 
| 1050 | 
            +
                            "eps": 1e-6,
         | 
| 1051 | 
            +
                        }
         | 
| 1052 | 
            +
                    elif hash_state_dict_keys(state_dict) == "aafcfd9672c3a2456dc46e1cb6e52c70":
         | 
| 1053 | 
            +
                        config = {
         | 
| 1054 | 
            +
                            "model_type": "t2v",
         | 
| 1055 | 
            +
                            "patch_size": (1, 2, 2),
         | 
| 1056 | 
            +
                            "text_len": 512,
         | 
| 1057 | 
            +
                            "in_dim": 16,
         | 
| 1058 | 
            +
                            "dim": 5120,
         | 
| 1059 | 
            +
                            "ffn_dim": 13824,
         | 
| 1060 | 
            +
                            "freq_dim": 256,
         | 
| 1061 | 
            +
                            "text_dim": 4096,
         | 
| 1062 | 
            +
                            "out_dim": 16,
         | 
| 1063 | 
            +
                            "num_heads": 40,
         | 
| 1064 | 
            +
                            "num_layers": 40,
         | 
| 1065 | 
            +
                            "window_size": (-1, -1),
         | 
| 1066 | 
            +
                            "qk_norm": True,
         | 
| 1067 | 
            +
                            "cross_attn_norm": True,
         | 
| 1068 | 
            +
                            "eps": 1e-6,
         | 
| 1069 | 
            +
                        }
         | 
| 1070 | 
            +
                    elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
         | 
| 1071 | 
            +
                        config = {
         | 
| 1072 | 
            +
                            "model_type": "i2v",
         | 
| 1073 | 
            +
                            "patch_size": (1, 2, 2),
         | 
| 1074 | 
            +
                            "text_len": 512,
         | 
| 1075 | 
            +
                            "in_dim": 36,
         | 
| 1076 | 
            +
                            "dim": 5120,
         | 
| 1077 | 
            +
                            "ffn_dim": 13824,
         | 
| 1078 | 
            +
                            "freq_dim": 256,
         | 
| 1079 | 
            +
                            "text_dim": 4096,
         | 
| 1080 | 
            +
                            "out_dim": 16,
         | 
| 1081 | 
            +
                            "num_heads": 40,
         | 
| 1082 | 
            +
                            "num_layers": 40,
         | 
| 1083 | 
            +
                            "window_size": (-1, -1),
         | 
| 1084 | 
            +
                            "qk_norm": True,
         | 
| 1085 | 
            +
                            "cross_attn_norm": True,
         | 
| 1086 | 
            +
                            "eps": 1e-6,
         | 
| 1087 | 
            +
                        }
         | 
| 1088 | 
            +
                    else:
         | 
| 1089 | 
            +
                        config = {}
         | 
| 1090 | 
            +
                    return state_dict, config
         | 
    	
        model/image_encoder.py
    ADDED
    
    | @@ -0,0 +1,903 @@ | |
| 1 | 
            +
            """
         | 
| 2 | 
            +
            Concise re-implementation of
         | 
| 3 | 
            +
            ``https://github.com/openai/CLIP'' and
         | 
| 4 | 
            +
            ``https://github.com/mlfoundations/open_clip''.
         | 
| 5 | 
            +
            """
         | 
| 6 | 
            +
            import math
         | 
| 7 | 
            +
            import torch
         | 
| 8 | 
            +
            import torch.nn as nn
         | 
| 9 | 
            +
            import torch.nn.functional as F
         | 
| 10 | 
            +
            import torchvision.transforms as T
         | 
| 11 | 
            +
            from .dit import flash_attention
         | 
| 12 | 
            +
             | 
| 13 | 
            +
             | 
| 14 | 
            +
            class SelfAttention(nn.Module):
         | 
| 15 | 
            +
             | 
| 16 | 
            +
                def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
         | 
| 17 | 
            +
                    assert dim % num_heads == 0
         | 
| 18 | 
            +
                    super().__init__()
         | 
| 19 | 
            +
                    self.dim = dim
         | 
| 20 | 
            +
                    self.num_heads = num_heads
         | 
| 21 | 
            +
                    self.head_dim = dim // num_heads
         | 
| 22 | 
            +
                    self.eps = eps
         | 
| 23 | 
            +
             | 
| 24 | 
            +
                    # layers
         | 
| 25 | 
            +
                    self.q = nn.Linear(dim, dim)
         | 
| 26 | 
            +
                    self.k = nn.Linear(dim, dim)
         | 
| 27 | 
            +
                    self.v = nn.Linear(dim, dim)
         | 
| 28 | 
            +
                    self.o = nn.Linear(dim, dim)
         | 
| 29 | 
            +
                    self.dropout = nn.Dropout(dropout)
         | 
| 30 | 
            +
             | 
| 31 | 
            +
                def forward(self, x, mask):
         | 
| 32 | 
            +
                    """
         | 
| 33 | 
            +
                    x:   [B, L, C].
         | 
| 34 | 
            +
                    """
         | 
| 35 | 
            +
                    b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                    # compute query, key, value
         | 
| 38 | 
            +
                    q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
         | 
| 39 | 
            +
                    k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
         | 
| 40 | 
            +
                    v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                    # compute attention
         | 
| 43 | 
            +
                    p = self.dropout.p if self.training else 0.0
         | 
| 44 | 
            +
                    x = F.scaled_dot_product_attention(q, k, v, mask, p)
         | 
| 45 | 
            +
                    x = x.permute(0, 2, 1, 3).reshape(b, s, c)
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                    # output
         | 
| 48 | 
            +
                    x = self.o(x)
         | 
| 49 | 
            +
                    x = self.dropout(x)
         | 
| 50 | 
            +
                    return x
         | 
| 51 | 
            +
             | 
| 52 | 
            +
             | 
| 53 | 
            +
            class AttentionBlock(nn.Module):
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
         | 
| 56 | 
            +
                    super().__init__()
         | 
| 57 | 
            +
                    self.dim = dim
         | 
| 58 | 
            +
                    self.num_heads = num_heads
         | 
| 59 | 
            +
                    self.post_norm = post_norm
         | 
| 60 | 
            +
                    self.eps = eps
         | 
| 61 | 
            +
             | 
| 62 | 
            +
                    # layers
         | 
| 63 | 
            +
                    self.attn = SelfAttention(dim, num_heads, dropout, eps)
         | 
| 64 | 
            +
                    self.norm1 = nn.LayerNorm(dim, eps=eps)
         | 
| 65 | 
            +
                    self.ffn = nn.Sequential(
         | 
| 66 | 
            +
                        nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim),
         | 
| 67 | 
            +
                        nn.Dropout(dropout))
         | 
| 68 | 
            +
                    self.norm2 = nn.LayerNorm(dim, eps=eps)
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                def forward(self, x, mask):
         | 
| 71 | 
            +
                    if self.post_norm:
         | 
| 72 | 
            +
                        x = self.norm1(x + self.attn(x, mask))
         | 
| 73 | 
            +
                        x = self.norm2(x + self.ffn(x))
         | 
| 74 | 
            +
                    else:
         | 
| 75 | 
            +
                        x = x + self.attn(self.norm1(x), mask)
         | 
| 76 | 
            +
                        x = x + self.ffn(self.norm2(x))
         | 
| 77 | 
            +
                    return x
         | 
| 78 | 
            +
             | 
| 79 | 
            +
             | 
| 80 | 
            +
            class XLMRoberta(nn.Module):
         | 
| 81 | 
            +
                """
         | 
| 82 | 
            +
                XLMRobertaModel with no pooler and no LM head.
         | 
| 83 | 
            +
                """
         | 
| 84 | 
            +
             | 
| 85 | 
            +
                def __init__(self,
         | 
| 86 | 
            +
                             vocab_size=250002,
         | 
| 87 | 
            +
                             max_seq_len=514,
         | 
| 88 | 
            +
                             type_size=1,
         | 
| 89 | 
            +
                             pad_id=1,
         | 
| 90 | 
            +
                             dim=1024,
         | 
| 91 | 
            +
                             num_heads=16,
         | 
| 92 | 
            +
                             num_layers=24,
         | 
| 93 | 
            +
                             post_norm=True,
         | 
| 94 | 
            +
                             dropout=0.1,
         | 
| 95 | 
            +
                             eps=1e-5):
         | 
| 96 | 
            +
                    super().__init__()
         | 
| 97 | 
            +
                    self.vocab_size = vocab_size
         | 
| 98 | 
            +
                    self.max_seq_len = max_seq_len
         | 
| 99 | 
            +
                    self.type_size = type_size
         | 
| 100 | 
            +
                    self.pad_id = pad_id
         | 
| 101 | 
            +
                    self.dim = dim
         | 
| 102 | 
            +
                    self.num_heads = num_heads
         | 
| 103 | 
            +
                    self.num_layers = num_layers
         | 
| 104 | 
            +
                    self.post_norm = post_norm
         | 
| 105 | 
            +
                    self.eps = eps
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                    # embeddings
         | 
| 108 | 
            +
                    self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
         | 
| 109 | 
            +
                    self.type_embedding = nn.Embedding(type_size, dim)
         | 
| 110 | 
            +
                    self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
         | 
| 111 | 
            +
                    self.dropout = nn.Dropout(dropout)
         | 
| 112 | 
            +
             | 
| 113 | 
            +
                    # blocks
         | 
| 114 | 
            +
                    self.blocks = nn.ModuleList([
         | 
| 115 | 
            +
                        AttentionBlock(dim, num_heads, post_norm, dropout, eps)
         | 
| 116 | 
            +
                        for _ in range(num_layers)
         | 
| 117 | 
            +
                    ])
         | 
| 118 | 
            +
             | 
| 119 | 
            +
                    # norm layer
         | 
| 120 | 
            +
                    self.norm = nn.LayerNorm(dim, eps=eps)
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                def forward(self, ids):
         | 
| 123 | 
            +
                    """
         | 
| 124 | 
            +
                    ids: [B, L] of torch.LongTensor.
         | 
| 125 | 
            +
                    """
         | 
| 126 | 
            +
                    b, s = ids.shape
         | 
| 127 | 
            +
                    mask = ids.ne(self.pad_id).long()
         | 
| 128 | 
            +
             | 
| 129 | 
            +
                    # embeddings
         | 
| 130 | 
            +
                    x = self.token_embedding(ids) + \
         | 
| 131 | 
            +
                        self.type_embedding(torch.zeros_like(ids)) + \
         | 
| 132 | 
            +
                        self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
         | 
| 133 | 
            +
                    if self.post_norm:
         | 
| 134 | 
            +
                        x = self.norm(x)
         | 
| 135 | 
            +
                    x = self.dropout(x)
         | 
| 136 | 
            +
             | 
| 137 | 
            +
                    # blocks
         | 
| 138 | 
            +
                    mask = torch.where(
         | 
| 139 | 
            +
                        mask.view(b, 1, 1, s).gt(0), 0.0,
         | 
| 140 | 
            +
                        torch.finfo(x.dtype).min)
         | 
| 141 | 
            +
                    for block in self.blocks:
         | 
| 142 | 
            +
                        x = block(x, mask)
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                    # output
         | 
| 145 | 
            +
                    if not self.post_norm:
         | 
| 146 | 
            +
                        x = self.norm(x)
         | 
| 147 | 
            +
                    return x
         | 
| 148 | 
            +
             | 
| 149 | 
            +
             | 
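The position ids in XLMRoberta.forward follow the RoBERTa convention: non-pad tokens are numbered from pad_id + 1 onward, while pad tokens keep pad_id, which is also the padding_idx of pos_embedding. A small numeric sketch:

```python
import torch

pad_id = 1
ids = torch.tensor([[5, 9, 23, pad_id, pad_id]])      # hypothetical token ids
mask = ids.ne(pad_id).long()                          # [[1, 1, 1, 0, 0]]
pos = pad_id + torch.cumsum(mask, dim=1) * mask       # [[2, 3, 4, 1, 1]]
print(pos)
```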
| 150 | 
            +
            def xlm_roberta_large(pretrained=False,
         | 
| 151 | 
            +
                                  return_tokenizer=False,
         | 
| 152 | 
            +
                                  device='cpu',
         | 
| 153 | 
            +
                                  **kwargs):
         | 
| 154 | 
            +
                """
         | 
| 155 | 
            +
                XLMRobertaLarge adapted from Huggingface.
         | 
| 156 | 
            +
                """
         | 
| 157 | 
            +
                # params
         | 
| 158 | 
            +
                cfg = dict(
         | 
| 159 | 
            +
                    vocab_size=250002,
         | 
| 160 | 
            +
                    max_seq_len=514,
         | 
| 161 | 
            +
                    type_size=1,
         | 
| 162 | 
            +
                    pad_id=1,
         | 
| 163 | 
            +
                    dim=1024,
         | 
| 164 | 
            +
                    num_heads=16,
         | 
| 165 | 
            +
                    num_layers=24,
         | 
| 166 | 
            +
                    post_norm=True,
         | 
| 167 | 
            +
                    dropout=0.1,
         | 
| 168 | 
            +
                    eps=1e-5)
         | 
| 169 | 
            +
                cfg.update(**kwargs)
         | 
| 170 | 
            +
             | 
| 171 | 
            +
                # init model
         | 
| 172 | 
            +
                if pretrained:
         | 
| 173 | 
            +
                    from sora import DOWNLOAD_TO_CACHE
         | 
| 174 | 
            +
             | 
| 175 | 
            +
                    # init a meta model
         | 
| 176 | 
            +
                    with torch.device('meta'):
         | 
| 177 | 
            +
                        model = XLMRoberta(**cfg)
         | 
| 178 | 
            +
             | 
| 179 | 
            +
                    # load checkpoint
         | 
| 180 | 
            +
                    model.load_state_dict(
         | 
| 181 | 
            +
                        torch.load(
         | 
| 182 | 
            +
                            DOWNLOAD_TO_CACHE('models/xlm_roberta/xlm_roberta_large.pth'),
         | 
| 183 | 
            +
                            map_location=device),
         | 
| 184 | 
            +
                        assign=True)
         | 
| 185 | 
            +
                else:
         | 
| 186 | 
            +
                    # init a model on device
         | 
| 187 | 
            +
                    with torch.device(device):
         | 
| 188 | 
            +
                        model = XLMRoberta(**cfg)
         | 
| 189 | 
            +
             | 
| 190 | 
            +
                # init tokenizer
         | 
| 191 | 
            +
                if return_tokenizer:
         | 
| 192 | 
            +
                    from sora.data import HuggingfaceTokenizer
         | 
| 193 | 
            +
                    tokenizer = HuggingfaceTokenizer(
         | 
| 194 | 
            +
                        name='xlm-roberta-large',
         | 
| 195 | 
            +
                    seq_len=model.max_seq_len,
         | 
| 196 | 
            +
                        clean='whitespace')
         | 
| 197 | 
            +
                    return model, tokenizer
         | 
| 198 | 
            +
                else:
         | 
| 199 | 
            +
                    return model
         | 
| 200 | 
            +
             | 
| 201 | 
            +
             | 
| 202 | 
            +
             | 
| 203 | 
            +
            def pos_interpolate(pos, seq_len):
         | 
| 204 | 
            +
                if pos.size(1) == seq_len:
         | 
| 205 | 
            +
                    return pos
         | 
| 206 | 
            +
                else:
         | 
| 207 | 
            +
                    src_grid = int(math.sqrt(pos.size(1)))
         | 
| 208 | 
            +
                    tar_grid = int(math.sqrt(seq_len))
         | 
| 209 | 
            +
                    n = pos.size(1) - src_grid * src_grid
         | 
| 210 | 
            +
                    return torch.cat([
         | 
| 211 | 
            +
                        pos[:, :n],
         | 
| 212 | 
            +
                        F.interpolate(
         | 
| 213 | 
            +
                            pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(
         | 
| 214 | 
            +
                                0, 3, 1, 2),
         | 
| 215 | 
            +
                            size=(tar_grid, tar_grid),
         | 
| 216 | 
            +
                            mode='bicubic',
         | 
| 217 | 
            +
                            align_corners=False).flatten(2).transpose(1, 2)
         | 
| 218 | 
            +
                    ],
         | 
| 219 | 
            +
                                     dim=1)
         | 
| 220 | 
            +
             | 
| 221 | 
            +
             | 
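pos_interpolate resizes a pretrained positional table to a new grid by bicubically interpolating only the grid part and passing the leading (e.g. class-token) embeddings through unchanged. A hedged sketch calling the function defined above, with hypothetical sizes:

```python
import torch

dim = 32
pos = torch.randn(1, 1 + 7 * 7, dim)       # class token + 7x7 grid (hypothetical)
resized = pos_interpolate(pos, 1 + 14 * 14)
print(resized.shape)                       # torch.Size([1, 197, 32])
```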
| 222 | 
            +
            class QuickGELU(nn.Module):
         | 
| 223 | 
            +
             | 
| 224 | 
            +
                def forward(self, x):
         | 
| 225 | 
            +
                    return x * torch.sigmoid(1.702 * x)
         | 
| 226 | 
            +
             | 
| 227 | 
            +
             | 
| 228 | 
            +
            class LayerNorm(nn.LayerNorm):
         | 
| 229 | 
            +
             | 
| 230 | 
            +
                def forward(self, x):
         | 
| 231 | 
            +
                    return super().forward(x.float()).type_as(x)
         | 
| 232 | 
            +
             | 
| 233 | 
            +
             | 
| 234 | 
            +
            class SelfAttention(nn.Module):
         | 
| 235 | 
            +
             | 
| 236 | 
            +
                def __init__(self,
         | 
| 237 | 
            +
                             dim,
         | 
| 238 | 
            +
                             num_heads,
         | 
| 239 | 
            +
                             causal=False,
         | 
| 240 | 
            +
                             attn_dropout=0.0,
         | 
| 241 | 
            +
                             proj_dropout=0.0):
         | 
| 242 | 
            +
                    assert dim % num_heads == 0
         | 
| 243 | 
            +
                    super().__init__()
         | 
| 244 | 
            +
                    self.dim = dim
         | 
| 245 | 
            +
                    self.num_heads = num_heads
         | 
| 246 | 
            +
                    self.head_dim = dim // num_heads
         | 
| 247 | 
            +
                    self.causal = causal
         | 
| 248 | 
            +
                    self.attn_dropout = attn_dropout
         | 
| 249 | 
            +
                    self.proj_dropout = proj_dropout
         | 
| 250 | 
            +
             | 
| 251 | 
            +
                    # layers
         | 
| 252 | 
            +
                    self.to_qkv = nn.Linear(dim, dim * 3)
         | 
| 253 | 
            +
                    self.proj = nn.Linear(dim, dim)
         | 
| 254 | 
            +
             | 
| 255 | 
            +
                def forward(self, x):
         | 
| 256 | 
            +
                    """
         | 
| 257 | 
            +
                    x:   [B, L, C].
         | 
| 258 | 
            +
                    """
         | 
| 259 | 
            +
                    b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
         | 
| 260 | 
            +
             | 
| 261 | 
            +
                    # compute query, key, value
         | 
| 262 | 
            +
                    q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)
         | 
| 263 | 
            +
             | 
| 264 | 
            +
                    # compute attention
         | 
| 265 | 
            +
                    p = self.attn_dropout if self.training else 0.0
         | 
| 266 | 
            +
                    x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2)
         | 
| 267 | 
            +
                    x = x.reshape(b, s, c)
         | 
| 268 | 
            +
             | 
| 269 | 
            +
                    # output
         | 
| 270 | 
            +
                    x = self.proj(x)
         | 
| 271 | 
            +
                    x = F.dropout(x, self.proj_dropout, self.training)
         | 
| 272 | 
            +
                    return x
         | 
| 273 | 
            +
             | 
| 274 | 
            +
             | 
| 275 | 
            +
            class SwiGLU(nn.Module):
         | 
| 276 | 
            +
             | 
| 277 | 
            +
                def __init__(self, dim, mid_dim):
         | 
| 278 | 
            +
                    super().__init__()
         | 
| 279 | 
            +
                    self.dim = dim
         | 
| 280 | 
            +
                    self.mid_dim = mid_dim
         | 
| 281 | 
            +
             | 
| 282 | 
            +
                    # layers
         | 
| 283 | 
            +
                    self.fc1 = nn.Linear(dim, mid_dim)
         | 
| 284 | 
            +
                    self.fc2 = nn.Linear(dim, mid_dim)
         | 
| 285 | 
            +
                    self.fc3 = nn.Linear(mid_dim, dim)
         | 
| 286 | 
            +
             | 
| 287 | 
            +
                def forward(self, x):
         | 
| 288 | 
            +
                    x = F.silu(self.fc1(x)) * self.fc2(x)
         | 
| 289 | 
            +
                    x = self.fc3(x)
         | 
| 290 | 
            +
                    return x
         | 
| 291 | 
            +
             | 
| 292 | 
            +
             | 
| 293 | 
            +
            class AttentionBlock(nn.Module):
         | 
| 294 | 
            +
             | 
| 295 | 
            +
                def __init__(self,
         | 
| 296 | 
            +
                             dim,
         | 
| 297 | 
            +
                             mlp_ratio,
         | 
| 298 | 
            +
                             num_heads,
         | 
| 299 | 
            +
                             post_norm=False,
         | 
| 300 | 
            +
                             causal=False,
         | 
| 301 | 
            +
                             activation='quick_gelu',
         | 
| 302 | 
            +
                             attn_dropout=0.0,
         | 
| 303 | 
            +
                             proj_dropout=0.0,
         | 
| 304 | 
            +
                             norm_eps=1e-5):
         | 
| 305 | 
            +
                    assert activation in ['quick_gelu', 'gelu', 'swi_glu']
         | 
| 306 | 
            +
                    super().__init__()
         | 
| 307 | 
            +
                    self.dim = dim
         | 
| 308 | 
            +
                    self.mlp_ratio = mlp_ratio
         | 
| 309 | 
            +
                    self.num_heads = num_heads
         | 
| 310 | 
            +
                    self.post_norm = post_norm
         | 
| 311 | 
            +
                    self.causal = causal
         | 
| 312 | 
            +
                    self.norm_eps = norm_eps
         | 
| 313 | 
            +
             | 
| 314 | 
            +
                    # layers
         | 
| 315 | 
            +
                    self.norm1 = LayerNorm(dim, eps=norm_eps)
         | 
| 316 | 
            +
                    self.attn = SelfAttention(dim, num_heads, causal, attn_dropout,
         | 
| 317 | 
            +
                                              proj_dropout)
         | 
| 318 | 
            +
                    self.norm2 = LayerNorm(dim, eps=norm_eps)
         | 
| 319 | 
            +
                    if activation == 'swi_glu':
         | 
| 320 | 
            +
                        self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
         | 
| 321 | 
            +
                    else:
         | 
| 322 | 
            +
                        self.mlp = nn.Sequential(
         | 
| 323 | 
            +
                            nn.Linear(dim, int(dim * mlp_ratio)),
         | 
| 324 | 
            +
                            QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
         | 
| 325 | 
            +
                            nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
         | 
| 326 | 
            +
             | 
| 327 | 
            +
                def forward(self, x):
         | 
| 328 | 
            +
                    if self.post_norm:
         | 
| 329 | 
            +
                        x = x + self.norm1(self.attn(x))
         | 
| 330 | 
            +
                        x = x + self.norm2(self.mlp(x))
         | 
| 331 | 
            +
                    else:
         | 
| 332 | 
            +
                        x = x + self.attn(self.norm1(x))
         | 
| 333 | 
            +
                        x = x + self.mlp(self.norm2(x))
         | 
| 334 | 
            +
                    return x
         | 
| 335 | 
            +
             | 
| 336 | 
            +
             | 
| 337 | 
            +
            class AttentionPool(nn.Module):
         | 
| 338 | 
            +
             | 
| 339 | 
            +
                def __init__(self,
         | 
| 340 | 
            +
                             dim,
         | 
| 341 | 
            +
                             mlp_ratio,
         | 
| 342 | 
            +
                             num_heads,
         | 
| 343 | 
            +
                             activation='gelu',
         | 
| 344 | 
            +
                             proj_dropout=0.0,
         | 
| 345 | 
            +
                             norm_eps=1e-5):
         | 
| 346 | 
            +
                    assert dim % num_heads == 0
         | 
| 347 | 
            +
                    super().__init__()
         | 
| 348 | 
            +
                    self.dim = dim
         | 
| 349 | 
            +
                    self.mlp_ratio = mlp_ratio
         | 
| 350 | 
            +
                    self.num_heads = num_heads
         | 
| 351 | 
            +
                    self.head_dim = dim // num_heads
         | 
| 352 | 
            +
                    self.proj_dropout = proj_dropout
         | 
| 353 | 
            +
                    self.norm_eps = norm_eps
         | 
| 354 | 
            +
             | 
| 355 | 
            +
                    # layers
         | 
| 356 | 
            +
                    gain = 1.0 / math.sqrt(dim)
         | 
| 357 | 
            +
                    self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
         | 
| 358 | 
            +
                    self.to_q = nn.Linear(dim, dim)
         | 
| 359 | 
            +
                    self.to_kv = nn.Linear(dim, dim * 2)
         | 
| 360 | 
            +
                    self.proj = nn.Linear(dim, dim)
         | 
| 361 | 
            +
                    self.norm = LayerNorm(dim, eps=norm_eps)
         | 
| 362 | 
            +
                    self.mlp = nn.Sequential(
         | 
| 363 | 
            +
                        nn.Linear(dim, int(dim * mlp_ratio)),
         | 
| 364 | 
            +
                        QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
         | 
| 365 | 
            +
                        nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
         | 
| 366 | 
            +
             | 
| 367 | 
            +
                def forward(self, x):
         | 
| 368 | 
            +
                    """
         | 
| 369 | 
            +
                    x:  [B, L, C].
         | 
| 370 | 
            +
                    """
         | 
| 371 | 
            +
                    b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
         | 
| 372 | 
            +
             | 
| 373 | 
            +
                    # compute query, key, value
         | 
| 374 | 
            +
                    q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
         | 
| 375 | 
            +
                    k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
         | 
| 376 | 
            +
             | 
| 377 | 
            +
                    # compute attention
         | 
| 378 | 
            +
                    x = flash_attention(q, k, v, version=2)
         | 
| 379 | 
            +
                    x = x.reshape(b, 1, c)
         | 
| 380 | 
            +
             | 
| 381 | 
            +
                    # output
         | 
| 382 | 
            +
                    x = self.proj(x)
         | 
| 383 | 
            +
                    x = F.dropout(x, self.proj_dropout, self.training)
         | 
| 384 | 
            +
             | 
| 385 | 
            +
                    # mlp
         | 
| 386 | 
            +
                    x = x + self.mlp(self.norm(x))
         | 
| 387 | 
            +
                    return x[:, 0]
         | 
| 388 | 
            +
             | 
| 389 | 
            +
             | 
| 390 | 
            +
            class VisionTransformer(nn.Module):
         | 
| 391 | 
            +
             | 
| 392 | 
            +
                def __init__(self,
         | 
| 393 | 
            +
                             image_size=224,
         | 
| 394 | 
            +
                             patch_size=16,
         | 
| 395 | 
            +
                             dim=768,
         | 
| 396 | 
            +
                             mlp_ratio=4,
         | 
| 397 | 
            +
                             out_dim=512,
         | 
| 398 | 
            +
                             num_heads=12,
         | 
| 399 | 
            +
                             num_layers=12,
         | 
| 400 | 
            +
                             pool_type='token',
         | 
| 401 | 
            +
                             pre_norm=True,
         | 
| 402 | 
            +
                             post_norm=False,
         | 
| 403 | 
            +
                             activation='quick_gelu',
         | 
| 404 | 
            +
                             attn_dropout=0.0,
         | 
| 405 | 
            +
                             proj_dropout=0.0,
         | 
| 406 | 
            +
                             embedding_dropout=0.0,
         | 
| 407 | 
            +
                             norm_eps=1e-5):
         | 
| 408 | 
            +
                    if image_size % patch_size != 0:
         | 
| 409 | 
            +
                        print(
         | 
| 410 | 
            +
                            '[WARNING] image_size is not divisible by patch_size',
         | 
| 411 | 
            +
                            flush=True)
         | 
| 412 | 
            +
                    assert pool_type in ('token', 'token_fc', 'attn_pool')
         | 
| 413 | 
            +
                    out_dim = out_dim or dim
         | 
| 414 | 
            +
                    super().__init__()
         | 
| 415 | 
            +
                    self.image_size = image_size
         | 
| 416 | 
            +
                    self.patch_size = patch_size
         | 
| 417 | 
            +
                    self.num_patches = (image_size // patch_size)**2
         | 
| 418 | 
            +
                    self.dim = dim
         | 
| 419 | 
            +
                    self.mlp_ratio = mlp_ratio
         | 
| 420 | 
            +
                    self.out_dim = out_dim
         | 
| 421 | 
            +
                    self.num_heads = num_heads
         | 
| 422 | 
            +
                    self.num_layers = num_layers
         | 
| 423 | 
            +
                    self.pool_type = pool_type
         | 
| 424 | 
            +
                    self.post_norm = post_norm
         | 
| 425 | 
            +
                    self.norm_eps = norm_eps
         | 
| 426 | 
            +
             | 
| 427 | 
            +
                    # embeddings
         | 
| 428 | 
            +
                    gain = 1.0 / math.sqrt(dim)
         | 
| 429 | 
            +
                    self.patch_embedding = nn.Conv2d(
         | 
| 430 | 
            +
                        3,
         | 
| 431 | 
            +
                        dim,
         | 
| 432 | 
            +
                        kernel_size=patch_size,
         | 
| 433 | 
            +
                        stride=patch_size,
         | 
| 434 | 
            +
                        bias=not pre_norm)
         | 
| 435 | 
            +
                    if pool_type in ('token', 'token_fc'):
         | 
| 436 | 
            +
                        self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
         | 
| 437 | 
            +
                    self.pos_embedding = nn.Parameter(gain * torch.randn(
         | 
| 438 | 
            +
                        1, self.num_patches +
         | 
| 439 | 
            +
                        (1 if pool_type in ('token', 'token_fc') else 0), dim))
         | 
| 440 | 
            +
                    self.dropout = nn.Dropout(embedding_dropout)
         | 
| 441 | 
            +
             | 
| 442 | 
            +
                    # transformer
         | 
| 443 | 
            +
                    self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
         | 
| 444 | 
            +
                    self.transformer = nn.Sequential(*[
         | 
| 445 | 
            +
                        AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,
         | 
| 446 | 
            +
                                       activation, attn_dropout, proj_dropout, norm_eps)
         | 
| 447 | 
            +
                        for _ in range(num_layers)
         | 
| 448 | 
            +
                    ])
         | 
| 449 | 
            +
                    self.post_norm = LayerNorm(dim, eps=norm_eps)
         | 
| 450 | 
            +
             | 
| 451 | 
            +
                    # head
         | 
| 452 | 
            +
                    if pool_type == 'token':
         | 
| 453 | 
            +
                        self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
         | 
| 454 | 
            +
                    elif pool_type == 'token_fc':
         | 
| 455 | 
            +
                        self.head = nn.Linear(dim, out_dim)
         | 
| 456 | 
            +
                    elif pool_type == 'attn_pool':
         | 
| 457 | 
            +
                        self.head = AttentionPool(dim, mlp_ratio, num_heads, activation,
         | 
| 458 | 
            +
                                                  proj_dropout, norm_eps)
         | 
| 459 | 
            +
             | 
| 460 | 
            +
                def forward(self, x, interpolation=False, use_31_block=False):  # use_31_block: return features from all but the last transformer block
         | 
| 461 | 
            +
                    b = x.size(0)
         | 
| 462 | 
            +
             | 
| 463 | 
            +
                    # embeddings
         | 
| 464 | 
            +
                    x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
         | 
| 465 | 
            +
                    if self.pool_type in ('token', 'token_fc'):
         | 
| 466 | 
            +
                        x = torch.cat([self.cls_embedding.expand(b, -1, -1).to(dtype=x.dtype, device=x.device), x], dim=1)
         | 
| 467 | 
            +
                    if interpolation:
         | 
| 468 | 
            +
                        e = pos_interpolate(self.pos_embedding, x.size(1))
         | 
| 469 | 
            +
                    else:
         | 
| 470 | 
            +
                        e = self.pos_embedding
         | 
| 471 | 
            +
                    e = e.to(dtype=x.dtype, device=x.device)
         | 
| 472 | 
            +
                    x = self.dropout(x + e)
         | 
| 473 | 
            +
                    if self.pre_norm is not None:
         | 
| 474 | 
            +
                        x = self.pre_norm(x)
         | 
| 475 | 
            +
             | 
| 476 | 
            +
                    # transformer
         | 
| 477 | 
            +
                    if use_31_block:
         | 
| 478 | 
            +
                        x = self.transformer[:-1](x)
         | 
| 479 | 
            +
                        return x
         | 
| 480 | 
            +
                    else:
         | 
| 481 | 
            +
                        x = self.transformer(x)
         | 
| 482 | 
            +
                        return x
         | 
| 483 | 
            +
             | 
| 484 | 
            +
             | 
| 485 | 
            +
            class CLIP(nn.Module):
         | 
| 486 | 
            +
             | 
| 487 | 
            +
                def __init__(self,
         | 
| 488 | 
            +
                             embed_dim=512,
         | 
| 489 | 
            +
                             image_size=224,
         | 
| 490 | 
            +
                             patch_size=16,
         | 
| 491 | 
            +
                             vision_dim=768,
         | 
| 492 | 
            +
                             vision_mlp_ratio=4,
         | 
| 493 | 
            +
                             vision_heads=12,
         | 
| 494 | 
            +
                             vision_layers=12,
         | 
| 495 | 
            +
                             vision_pool='token',
         | 
| 496 | 
            +
                             vision_pre_norm=True,
         | 
| 497 | 
            +
                             vision_post_norm=False,
         | 
| 498 | 
            +
                             vocab_size=49408,
         | 
| 499 | 
            +
                             text_len=77,
         | 
| 500 | 
            +
                             text_dim=512,
         | 
| 501 | 
            +
                             text_mlp_ratio=4,
         | 
| 502 | 
            +
                             text_heads=8,
         | 
| 503 | 
            +
                             text_layers=12,
         | 
| 504 | 
            +
                             text_causal=True,
         | 
| 505 | 
            +
                             text_pool='argmax',
         | 
| 506 | 
            +
                             text_head_bias=False,
         | 
| 507 | 
            +
                             logit_bias=None,
         | 
| 508 | 
            +
                             activation='quick_gelu',
         | 
| 509 | 
            +
                             attn_dropout=0.0,
         | 
| 510 | 
            +
                             proj_dropout=0.0,
         | 
| 511 | 
            +
                             embedding_dropout=0.0,
         | 
| 512 | 
            +
                             norm_eps=1e-5):
         | 
| 513 | 
            +
                    super().__init__()
         | 
| 514 | 
            +
                    self.embed_dim = embed_dim
         | 
| 515 | 
            +
                    self.image_size = image_size
         | 
| 516 | 
            +
                    self.patch_size = patch_size
         | 
| 517 | 
            +
                    self.vision_dim = vision_dim
         | 
| 518 | 
            +
                    self.vision_mlp_ratio = vision_mlp_ratio
         | 
| 519 | 
            +
                    self.vision_heads = vision_heads
         | 
| 520 | 
            +
                    self.vision_layers = vision_layers
         | 
| 521 | 
            +
                    self.vision_pool = vision_pool
         | 
| 522 | 
            +
                    self.vision_pre_norm = vision_pre_norm
         | 
| 523 | 
            +
                    self.vision_post_norm = vision_post_norm
         | 
| 524 | 
            +
                    self.vocab_size = vocab_size
         | 
| 525 | 
            +
                    self.text_len = text_len
         | 
| 526 | 
            +
                    self.text_dim = text_dim
         | 
| 527 | 
            +
                    self.text_mlp_ratio = text_mlp_ratio
         | 
| 528 | 
            +
                    self.text_heads = text_heads
         | 
| 529 | 
            +
                    self.text_layers = text_layers
         | 
| 530 | 
            +
                    self.text_causal = text_causal
         | 
| 531 | 
            +
                    self.text_pool = text_pool
         | 
| 532 | 
            +
                    self.text_head_bias = text_head_bias
         | 
| 533 | 
            +
                    self.norm_eps = norm_eps
         | 
| 534 | 
            +
             | 
| 535 | 
            +
                    # models
         | 
| 536 | 
            +
                    self.visual = VisionTransformer(
         | 
| 537 | 
            +
                        image_size=image_size,
         | 
| 538 | 
            +
                        patch_size=patch_size,
         | 
| 539 | 
            +
                        dim=vision_dim,
         | 
| 540 | 
            +
                        mlp_ratio=vision_mlp_ratio,
         | 
| 541 | 
            +
                        out_dim=embed_dim,
         | 
| 542 | 
            +
                        num_heads=vision_heads,
         | 
| 543 | 
            +
                        num_layers=vision_layers,
         | 
| 544 | 
            +
                        pool_type=vision_pool,
         | 
| 545 | 
            +
                        pre_norm=vision_pre_norm,
         | 
| 546 | 
            +
                        post_norm=vision_post_norm,
         | 
| 547 | 
            +
                        activation=activation,
         | 
| 548 | 
            +
                        attn_dropout=attn_dropout,
         | 
| 549 | 
            +
                        proj_dropout=proj_dropout,
         | 
| 550 | 
            +
                        embedding_dropout=embedding_dropout,
         | 
| 551 | 
            +
                        norm_eps=norm_eps)
         | 
| 552 | 
            +
                    self.textual = TextTransformer(
         | 
| 553 | 
            +
                        vocab_size=vocab_size,
         | 
| 554 | 
            +
                        text_len=text_len,
         | 
| 555 | 
            +
                        dim=text_dim,
         | 
| 556 | 
            +
                        mlp_ratio=text_mlp_ratio,
         | 
| 557 | 
            +
                        out_dim=embed_dim,
         | 
| 558 | 
            +
                        num_heads=text_heads,
         | 
| 559 | 
            +
                        num_layers=text_layers,
         | 
| 560 | 
            +
                        causal=text_causal,
         | 
| 561 | 
            +
                        pool_type=text_pool,
         | 
| 562 | 
            +
                        head_bias=text_head_bias,
         | 
| 563 | 
            +
                        activation=activation,
         | 
| 564 | 
            +
                        attn_dropout=attn_dropout,
         | 
| 565 | 
            +
                        proj_dropout=proj_dropout,
         | 
| 566 | 
            +
                        embedding_dropout=embedding_dropout,
         | 
| 567 | 
            +
                        norm_eps=norm_eps)
         | 
| 568 | 
            +
                    self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
         | 
| 569 | 
            +
                    if logit_bias is not None:
         | 
| 570 | 
            +
                        self.logit_bias = nn.Parameter(logit_bias * torch.ones([]))
         | 
| 571 | 
            +
             | 
| 572 | 
            +
                    # initialize weights
         | 
| 573 | 
            +
                    self.init_weights()
         | 
| 574 | 
            +
             | 
| 575 | 
            +
                def forward(self, imgs, txt_ids):
         | 
| 576 | 
            +
                    """
         | 
| 577 | 
            +
                    imgs:       [B, 3, H, W] of torch.float32.
         | 
| 578 | 
            +
                    - mean:     [0.48145466, 0.4578275, 0.40821073]
         | 
| 579 | 
            +
                    - std:      [0.26862954, 0.26130258, 0.27577711]
         | 
| 580 | 
            +
                    txt_ids:    [B, L] of torch.long. Encoded by data.CLIPTokenizer.
         | 
| 581 | 
            +
                    """
         | 
| 582 | 
            +
                    xi = self.visual(imgs)
         | 
| 583 | 
            +
                    xt = self.textual(txt_ids)
         | 
| 584 | 
            +
                    return xi, xt
         | 
| 585 | 
            +
             | 
| 586 | 
            +
                def init_weights(self):
         | 
| 587 | 
            +
                    # embeddings
         | 
| 588 | 
            +
                    nn.init.normal_(self.textual.token_embedding.weight, std=0.02)
         | 
| 589 | 
            +
                    nn.init.normal_(self.visual.patch_embedding.weight, std=0.1)
         | 
| 590 | 
            +
             | 
| 591 | 
            +
                    # attentions
         | 
| 592 | 
            +
                    for modality in ['visual', 'textual']:
         | 
| 593 | 
            +
                        dim = self.vision_dim if modality == 'visual' else self.text_dim
         | 
| 594 | 
            +
                        transformer = getattr(self, modality).transformer
         | 
| 595 | 
            +
                        proj_gain = (1.0 / math.sqrt(dim)) * (
         | 
| 596 | 
            +
                            1.0 / math.sqrt(2 * len(transformer)))
         | 
| 597 | 
            +
                        attn_gain = 1.0 / math.sqrt(dim)
         | 
| 598 | 
            +
                        mlp_gain = 1.0 / math.sqrt(2.0 * dim)
         | 
| 599 | 
            +
                        for block in transformer:
         | 
| 600 | 
            +
                            nn.init.normal_(block.attn.to_qkv.weight, std=attn_gain)
         | 
| 601 | 
            +
                            nn.init.normal_(block.attn.proj.weight, std=proj_gain)
         | 
| 602 | 
            +
                            nn.init.normal_(block.mlp[0].weight, std=mlp_gain)
         | 
| 603 | 
            +
                            nn.init.normal_(block.mlp[2].weight, std=proj_gain)
         | 
| 604 | 
            +
             | 
| 605 | 
            +
                def param_groups(self):
         | 
| 606 | 
            +
                    groups = [{
         | 
| 607 | 
            +
                        'params': [
         | 
| 608 | 
            +
                            p for n, p in self.named_parameters()
         | 
| 609 | 
            +
                            if 'norm' in n or n.endswith('bias')
         | 
| 610 | 
            +
                        ],
         | 
| 611 | 
            +
                        'weight_decay': 0.0
         | 
| 612 | 
            +
                    }, {
         | 
| 613 | 
            +
                        'params': [
         | 
| 614 | 
            +
                            p for n, p in self.named_parameters()
         | 
| 615 | 
            +
                            if not ('norm' in n or n.endswith('bias'))
         | 
| 616 | 
            +
                        ]
         | 
| 617 | 
            +
                    }]
         | 
| 618 | 
            +
                    return groups
         | 
| 619 | 
            +
             | 
| 620 | 
            +
             | 
| 621 | 
            +
            class XLMRobertaWithHead(XLMRoberta):
         | 
| 622 | 
            +
             | 
| 623 | 
            +
                def __init__(self, **kwargs):
         | 
| 624 | 
            +
                    self.out_dim = kwargs.pop('out_dim')
         | 
| 625 | 
            +
                    super().__init__(**kwargs)
         | 
| 626 | 
            +
             | 
| 627 | 
            +
                    # head
         | 
| 628 | 
            +
                    mid_dim = (self.dim + self.out_dim) // 2
         | 
| 629 | 
            +
                    self.head = nn.Sequential(
         | 
| 630 | 
            +
                        nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(),
         | 
| 631 | 
            +
                        nn.Linear(mid_dim, self.out_dim, bias=False))
         | 
| 632 | 
            +
             | 
| 633 | 
            +
                def forward(self, ids):
         | 
| 634 | 
            +
                    # xlm-roberta
         | 
| 635 | 
            +
                    x = super().forward(ids)
         | 
| 636 | 
            +
             | 
| 637 | 
            +
                    # average pooling
         | 
| 638 | 
            +
                    mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
         | 
| 639 | 
            +
                    x = (x * mask).sum(dim=1) / mask.sum(dim=1)
         | 
| 640 | 
            +
             | 
| 641 | 
            +
                    # head
         | 
| 642 | 
            +
                    x = self.head(x)
         | 
| 643 | 
            +
                    return x
         | 
| 644 | 
            +
             | 
| 645 | 
            +
             | 
| 646 | 
            +
            class XLMRobertaCLIP(nn.Module):
         | 
| 647 | 
            +
             | 
| 648 | 
            +
                def __init__(self,
         | 
| 649 | 
            +
                             embed_dim=1024,
         | 
| 650 | 
            +
                             image_size=224,
         | 
| 651 | 
            +
                             patch_size=14,
         | 
| 652 | 
            +
                             vision_dim=1280,
         | 
| 653 | 
            +
                             vision_mlp_ratio=4,
         | 
| 654 | 
            +
                             vision_heads=16,
         | 
| 655 | 
            +
                             vision_layers=32,
         | 
| 656 | 
            +
                             vision_pool='token',
         | 
| 657 | 
            +
                             vision_pre_norm=True,
         | 
| 658 | 
            +
                             vision_post_norm=False,
         | 
| 659 | 
            +
                             activation='gelu',
         | 
| 660 | 
            +
                             vocab_size=250002,
         | 
| 661 | 
            +
                             max_text_len=514,
         | 
| 662 | 
            +
                             type_size=1,
         | 
| 663 | 
            +
                             pad_id=1,
         | 
| 664 | 
            +
                             text_dim=1024,
         | 
| 665 | 
            +
                             text_heads=16,
         | 
| 666 | 
            +
                             text_layers=24,
         | 
| 667 | 
            +
                             text_post_norm=True,
         | 
| 668 | 
            +
                             text_dropout=0.1,
         | 
| 669 | 
            +
                             attn_dropout=0.0,
         | 
| 670 | 
            +
                             proj_dropout=0.0,
         | 
| 671 | 
            +
                             embedding_dropout=0.0,
         | 
| 672 | 
            +
                             norm_eps=1e-5):
         | 
| 673 | 
            +
                    super().__init__()
         | 
| 674 | 
            +
                    self.embed_dim = embed_dim
         | 
| 675 | 
            +
                    self.image_size = image_size
         | 
| 676 | 
            +
                    self.patch_size = patch_size
         | 
| 677 | 
            +
                    self.vision_dim = vision_dim
         | 
| 678 | 
            +
                    self.vision_mlp_ratio = vision_mlp_ratio
         | 
| 679 | 
            +
                    self.vision_heads = vision_heads
         | 
| 680 | 
            +
                    self.vision_layers = vision_layers
         | 
| 681 | 
            +
                    self.vision_pre_norm = vision_pre_norm
         | 
| 682 | 
            +
                    self.vision_post_norm = vision_post_norm
         | 
| 683 | 
            +
                    self.activation = activation
         | 
| 684 | 
            +
                    self.vocab_size = vocab_size
         | 
| 685 | 
            +
                    self.max_text_len = max_text_len
         | 
| 686 | 
            +
                    self.type_size = type_size
         | 
| 687 | 
            +
                    self.pad_id = pad_id
         | 
| 688 | 
            +
                    self.text_dim = text_dim
         | 
| 689 | 
            +
                    self.text_heads = text_heads
         | 
| 690 | 
            +
                    self.text_layers = text_layers
         | 
| 691 | 
            +
                    self.text_post_norm = text_post_norm
         | 
| 692 | 
            +
                    self.norm_eps = norm_eps
         | 
| 693 | 
            +
             | 
| 694 | 
            +
                    # models
         | 
| 695 | 
            +
                    self.visual = VisionTransformer(
         | 
| 696 | 
            +
                        image_size=image_size,
         | 
| 697 | 
            +
                        patch_size=patch_size,
         | 
| 698 | 
            +
                        dim=vision_dim,
         | 
| 699 | 
            +
                        mlp_ratio=vision_mlp_ratio,
         | 
| 700 | 
            +
                        out_dim=embed_dim,
         | 
| 701 | 
            +
                        num_heads=vision_heads,
         | 
| 702 | 
            +
                        num_layers=vision_layers,
         | 
| 703 | 
            +
                        pool_type=vision_pool,
         | 
| 704 | 
            +
                        pre_norm=vision_pre_norm,
         | 
| 705 | 
            +
                        post_norm=vision_post_norm,
         | 
| 706 | 
            +
                        activation=activation,
         | 
| 707 | 
            +
                        attn_dropout=attn_dropout,
         | 
| 708 | 
            +
                        proj_dropout=proj_dropout,
         | 
| 709 | 
            +
                        embedding_dropout=embedding_dropout,
         | 
| 710 | 
            +
                        norm_eps=norm_eps)
         | 
| 711 | 
            +
                    self.textual = None  # the text tower is dropped in this vision-only copy; forward() cannot run without it
         | 
| 712 | 
            +
                    self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
         | 
| 713 | 
            +
             | 
| 714 | 
            +
                def forward(self, imgs, txt_ids):
         | 
| 715 | 
            +
                    """
         | 
| 716 | 
            +
                    imgs:       [B, 3, H, W] of torch.float32.
         | 
| 717 | 
            +
                    - mean:     [0.48145466, 0.4578275, 0.40821073]
         | 
| 718 | 
            +
                    - std:      [0.26862954, 0.26130258, 0.27577711]
         | 
| 719 | 
            +
                    txt_ids:    [B, L] of torch.long.
         | 
| 720 | 
            +
                                Encoded by the XLM-Roberta tokenizer.
         | 
| 721 | 
            +
                    """
         | 
| 722 | 
            +
                    xi = self.visual(imgs)
         | 
| 723 | 
            +
                    xt = self.textual(txt_ids)
         | 
| 724 | 
            +
                    return xi, xt
         | 
| 725 | 
            +
             | 
| 726 | 
            +
                def param_groups(self):
         | 
| 727 | 
            +
                    groups = [{
         | 
| 728 | 
            +
                        'params': [
         | 
| 729 | 
            +
                            p for n, p in self.named_parameters()
         | 
| 730 | 
            +
                            if 'norm' in n or n.endswith('bias')
         | 
| 731 | 
            +
                        ],
         | 
| 732 | 
            +
                        'weight_decay': 0.0
         | 
| 733 | 
            +
                    }, {
         | 
| 734 | 
            +
                        'params': [
         | 
| 735 | 
            +
                            p for n, p in self.named_parameters()
         | 
| 736 | 
            +
                            if not ('norm' in n or n.endswith('bias'))
         | 
| 737 | 
            +
                        ]
         | 
| 738 | 
            +
                    }]
         | 
| 739 | 
            +
                    return groups
         | 
| 740 | 
            +
             | 
| 741 | 
            +
             | 
| 742 | 
            +
            def _clip(pretrained=False,
         | 
| 743 | 
            +
                      pretrained_name=None,
         | 
| 744 | 
            +
                      model_cls=CLIP,
         | 
| 745 | 
            +
                      return_transforms=False,
         | 
| 746 | 
            +
                      return_tokenizer=False,
         | 
| 747 | 
            +
                      tokenizer_padding='eos',
         | 
| 748 | 
            +
                      dtype=torch.float32,
         | 
| 749 | 
            +
                      device='cpu',
         | 
| 750 | 
            +
                      **kwargs):
         | 
| 751 | 
            +
                # init model
         | 
| 752 | 
            +
                if pretrained and pretrained_name:
         | 
| 753 | 
            +
                    from sora import BUCKET, DOWNLOAD_TO_CACHE
         | 
| 754 | 
            +
             | 
| 755 | 
            +
                    # init a meta model
         | 
| 756 | 
            +
                    with torch.device('meta'):
         | 
| 757 | 
            +
                        model = model_cls(**kwargs)
         | 
| 758 | 
            +
             | 
| 759 | 
            +
                    # checkpoint path
         | 
| 760 | 
            +
                    checkpoint = f'models/clip/{pretrained_name}'
         | 
| 761 | 
            +
                    if dtype in (torch.float16, torch.bfloat16):
         | 
| 762 | 
            +
                        suffix = '-' + {
         | 
| 763 | 
            +
                            torch.float16: 'fp16',
         | 
| 764 | 
            +
                            torch.bfloat16: 'bf16'
         | 
| 765 | 
            +
                        }[dtype]
         | 
| 766 | 
            +
                        if object_exists(BUCKET, f'{checkpoint}{suffix}.pth'):
         | 
| 767 | 
            +
                            checkpoint = f'{checkpoint}{suffix}'
         | 
| 768 | 
            +
                    checkpoint += '.pth'
         | 
| 769 | 
            +
             | 
| 770 | 
            +
                    # load
         | 
| 771 | 
            +
                    model.load_state_dict(
         | 
| 772 | 
            +
                        torch.load(DOWNLOAD_TO_CACHE(checkpoint), map_location=device),
         | 
| 773 | 
            +
                        assign=True,
         | 
| 774 | 
            +
                        strict=False)
         | 
| 775 | 
            +
                else:
         | 
| 776 | 
            +
                    # init a model on device
         | 
| 777 | 
            +
                    with torch.device(device):
         | 
| 778 | 
            +
                        model = model_cls(**kwargs)
         | 
| 779 | 
            +
             | 
| 780 | 
            +
                # set device
         | 
| 781 | 
            +
                output = (model,)
         | 
| 782 | 
            +
             | 
| 783 | 
            +
                # init transforms
         | 
| 784 | 
            +
                if return_transforms:
         | 
| 785 | 
            +
                    # mean and std
         | 
| 786 | 
            +
                    if 'siglip' in pretrained_name.lower():
         | 
| 787 | 
            +
                        mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
         | 
| 788 | 
            +
                    else:
         | 
| 789 | 
            +
                        mean = [0.48145466, 0.4578275, 0.40821073]
         | 
| 790 | 
            +
                        std = [0.26862954, 0.26130258, 0.27577711]
         | 
| 791 | 
            +
             | 
| 792 | 
            +
                    # transforms
         | 
| 793 | 
            +
                    transforms = T.Compose([
         | 
| 794 | 
            +
                        T.Resize((model.image_size, model.image_size),
         | 
| 795 | 
            +
                                 interpolation=T.InterpolationMode.BICUBIC),
         | 
| 796 | 
            +
                        T.ToTensor(),
         | 
| 797 | 
            +
                        T.Normalize(mean=mean, std=std)
         | 
| 798 | 
            +
                    ])
         | 
| 799 | 
            +
                    output += (transforms,)
         | 
| 800 | 
            +
             | 
| 801 | 
            +
                # init tokenizer
         | 
| 802 | 
            +
                if return_tokenizer:
         | 
| 803 | 
            +
                    from sora import data
         | 
| 804 | 
            +
                    if 'siglip' in pretrained_name.lower():
         | 
| 805 | 
            +
                        tokenizer = data.HuggingfaceTokenizer(
         | 
| 806 | 
            +
                            name=f'timm/{pretrained_name}',
         | 
| 807 | 
            +
                            seq_len=model.text_len,
         | 
| 808 | 
            +
                            clean='canonicalize')
         | 
| 809 | 
            +
                    elif 'xlm' in pretrained_name.lower():
         | 
| 810 | 
            +
                        tokenizer = data.HuggingfaceTokenizer(
         | 
| 811 | 
            +
                            name='xlm-roberta-large',
         | 
| 812 | 
            +
                            seq_len=model.max_text_len - 2,
         | 
| 813 | 
            +
                            clean='whitespace')
         | 
| 814 | 
            +
                    elif 'mba' in pretrained_name.lower():
         | 
| 815 | 
            +
                        tokenizer = data.HuggingfaceTokenizer(
         | 
| 816 | 
            +
                            name='facebook/xlm-roberta-xl',
         | 
| 817 | 
            +
                            seq_len=model.max_text_len - 2,
         | 
| 818 | 
            +
                            clean='whitespace')
         | 
| 819 | 
            +
                    else:
         | 
| 820 | 
            +
                        tokenizer = data.CLIPTokenizer(
         | 
| 821 | 
            +
                            seq_len=model.text_len, padding=tokenizer_padding)
         | 
| 822 | 
            +
                    output += (tokenizer,)
         | 
| 823 | 
            +
                return output[0] if len(output) == 1 else output
         | 
| 824 | 
            +
             | 
| 825 | 
            +
             | 
| 826 | 
            +
            def clip_xlm_roberta_vit_h_14(
         | 
| 827 | 
            +
                    pretrained=False,
         | 
| 828 | 
            +
                    pretrained_name='open-clip-xlm-roberta-large-vit-huge-14',
         | 
| 829 | 
            +
                    **kwargs):
         | 
| 830 | 
            +
                cfg = dict(
         | 
| 831 | 
            +
                    embed_dim=1024,
         | 
| 832 | 
            +
                    image_size=224,
         | 
| 833 | 
            +
                    patch_size=14,
         | 
| 834 | 
            +
                    vision_dim=1280,
         | 
| 835 | 
            +
                    vision_mlp_ratio=4,
         | 
| 836 | 
            +
                    vision_heads=16,
         | 
| 837 | 
            +
                    vision_layers=32,
         | 
| 838 | 
            +
                    vision_pool='token',
         | 
| 839 | 
            +
                    activation='gelu',
         | 
| 840 | 
            +
                    vocab_size=250002,
         | 
| 841 | 
            +
                    max_text_len=514,
         | 
| 842 | 
            +
                    type_size=1,
         | 
| 843 | 
            +
                    pad_id=1,
         | 
| 844 | 
            +
                    text_dim=1024,
         | 
| 845 | 
            +
                    text_heads=16,
         | 
| 846 | 
            +
                    text_layers=24,
         | 
| 847 | 
            +
                    text_post_norm=True,
         | 
| 848 | 
            +
                    text_dropout=0.1,
         | 
| 849 | 
            +
                    attn_dropout=0.0,
         | 
| 850 | 
            +
                    proj_dropout=0.0,
         | 
| 851 | 
            +
                    embedding_dropout=0.0)
         | 
| 852 | 
            +
                cfg.update(**kwargs)
         | 
| 853 | 
            +
                return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
         | 
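The factory above wires the Wan image-encoder configuration together. A minimal sketch of inspecting that configuration without downloading weights, assuming this file is importable as `model.image_encoder` (the package layout used by `model/prompter.py`):

```python
import torch
from model.image_encoder import clip_xlm_roberta_vit_h_14  # assumed module path

# `device` is forwarded through **kwargs into _clip(), so the model is built on
# the storage-less `meta` device: structure only, no weight allocation.
model = clip_xlm_roberta_vit_h_14(pretrained=False, device="meta")

n_params = sum(p.numel() for p in model.visual.parameters())
print(f"ViT-H/14 vision tower: {n_params / 1e6:.0f}M parameters")  # roughly 630M
```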
| 854 | 
            +
             | 
| 855 | 
            +
             | 
| 856 | 
            +
            class WanImageEncoder(torch.nn.Module):
         | 
| 857 | 
            +
             | 
| 858 | 
            +
                def __init__(self):
         | 
| 859 | 
            +
                    super().__init__()
         | 
| 860 | 
            +
                    # init model
         | 
| 861 | 
            +
                    self.model, self.transforms = clip_xlm_roberta_vit_h_14(
         | 
| 862 | 
            +
                        pretrained=False,
         | 
| 863 | 
            +
                        return_transforms=True,
         | 
| 864 | 
            +
                        return_tokenizer=False,
         | 
| 865 | 
            +
                        dtype=torch.float32,
         | 
| 866 | 
            +
                        device="cpu")
         | 
| 867 | 
            +
             | 
| 868 | 
            +
                def encode_image(self, videos):
         | 
| 869 | 
            +
                    # preprocess
         | 
| 870 | 
            +
                    size = (self.model.image_size,) * 2
         | 
| 871 | 
            +
                    videos = torch.cat([
         | 
| 872 | 
            +
                        F.interpolate(
         | 
| 873 | 
            +
                            u,
         | 
| 874 | 
            +
                            size=size,
         | 
| 875 | 
            +
                            mode='bicubic',
         | 
| 876 | 
            +
                            align_corners=False) for u in videos
         | 
| 877 | 
            +
                    ])
         | 
| 878 | 
            +
                    videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
         | 
| 879 | 
            +
             | 
| 880 | 
            +
                    # forward
         | 
| 881 | 
            +
                    out = self.model.visual(videos, use_31_block=True)
         | 
| 882 | 
            +
                    return out
         | 
| 883 | 
            +
                    
         | 
| 884 | 
            +
                @staticmethod
         | 
| 885 | 
            +
                def state_dict_converter():
         | 
| 886 | 
            +
                    return WanImageEncoderStateDictConverter()
         | 
| 887 | 
            +
                
         | 
| 888 | 
            +
                
         | 
| 889 | 
            +
            class WanImageEncoderStateDictConverter:
         | 
| 890 | 
            +
                def __init__(self):
         | 
| 891 | 
            +
                    pass
         | 
| 892 | 
            +
             | 
| 893 | 
            +
                def from_diffusers(self, state_dict):
         | 
| 894 | 
            +
                    return state_dict
         | 
| 895 | 
            +
                
         | 
| 896 | 
            +
                def from_civitai(self, state_dict):
         | 
| 897 | 
            +
                    state_dict_ = {}
         | 
| 898 | 
            +
                    for name, param in state_dict.items():
         | 
| 899 | 
            +
                        if name.startswith("textual."):
         | 
| 900 | 
            +
                            continue
         | 
| 901 | 
            +
                        name = "model." + name
         | 
| 902 | 
            +
                        state_dict_[name] = param
         | 
| 903 | 
            +
                    return state_dict_
         | 
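A minimal usage sketch for `WanImageEncoder` (not from the repository): it assumes the module path `model.image_encoder` and that the `flash_attention` helper defined earlier in this file falls back to a standard attention kernel on CPU. In real use, converted CLIP weights are loaded via `state_dict_converter()` before encoding.

```python
import torch
from model.image_encoder import WanImageEncoder  # assumed module path

encoder = WanImageEncoder().eval()  # randomly initialized; load converted weights for real use

# encode_image expects an iterable of [N, C, H, W] tensors scaled to [-1, 1];
# each element is resized to 224x224 and CLIP-normalized internally.
frames = torch.rand(1, 3, 256, 256) * 2 - 1

with torch.no_grad():
    tokens = encoder.encode_image([frames])

# Penultimate-block ViT-H/14 features: [1, 257, 1280] (CLS token + 16x16 patches).
print(tokens.shape)
```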
    	
        model/prompter.py
    ADDED
    
    | @@ -0,0 +1,107 @@ | |
| 1 | 
            +
            from diffsynth.prompters.base_prompter import BasePrompter
         | 
| 2 | 
            +
            from model.text_encoder import WanTextEncoder
         | 
| 3 | 
            +
            from transformers import AutoTokenizer
         | 
| 4 | 
            +
            import ftfy
         | 
| 5 | 
            +
            import html
         | 
| 6 | 
            +
            import string
         | 
| 7 | 
            +
            import regex as re
         | 
| 8 | 
            +
             | 
| 9 | 
            +
             | 
| 10 | 
            +
            def basic_clean(text):
         | 
| 11 | 
            +
                text = ftfy.fix_text(text)
         | 
| 12 | 
            +
                text = html.unescape(html.unescape(text))
         | 
| 13 | 
            +
                return text.strip()
         | 
| 14 | 
            +
             | 
| 15 | 
            +
             | 
| 16 | 
            +
            def whitespace_clean(text):
         | 
| 17 | 
            +
                text = re.sub(r'\s+', ' ', text)
         | 
| 18 | 
            +
                text = text.strip()
         | 
| 19 | 
            +
                return text
         | 
| 20 | 
            +
             | 
| 21 | 
            +
             | 
| 22 | 
            +
            def canonicalize(text, keep_punctuation_exact_string=None):
         | 
| 23 | 
            +
                text = text.replace('_', ' ')
         | 
| 24 | 
            +
                if keep_punctuation_exact_string:
         | 
| 25 | 
            +
                    text = keep_punctuation_exact_string.join(
         | 
| 26 | 
            +
                        part.translate(str.maketrans('', '', string.punctuation))
         | 
| 27 | 
            +
                        for part in text.split(keep_punctuation_exact_string))
         | 
| 28 | 
            +
                else:
         | 
| 29 | 
            +
                    text = text.translate(str.maketrans('', '', string.punctuation))
         | 
| 30 | 
            +
                text = text.lower()
         | 
| 31 | 
            +
                text = re.sub(r'\s+', ' ', text)
         | 
| 32 | 
            +
                return text.strip()
         | 
| 33 | 
            +
             | 
| 34 | 
            +
             | 
| 35 | 
            +
            class HuggingfaceTokenizer:
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                def __init__(self, name, seq_len=None, clean=None, **kwargs):
         | 
| 38 | 
            +
                    assert clean in (None, 'whitespace', 'lower', 'canonicalize')
         | 
| 39 | 
            +
                    self.name = name
         | 
| 40 | 
            +
                    self.seq_len = seq_len
         | 
| 41 | 
            +
                    self.clean = clean
         | 
| 42 | 
            +
             | 
| 43 | 
            +
                    # init tokenizer
         | 
| 44 | 
            +
                    self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
         | 
| 45 | 
            +
                    self.vocab_size = self.tokenizer.vocab_size
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                def __call__(self, sequence, **kwargs):
         | 
| 48 | 
            +
                    return_mask = kwargs.pop('return_mask', False)
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                    # arguments
         | 
| 51 | 
            +
                    _kwargs = {'return_tensors': 'pt'}
         | 
| 52 | 
            +
                    if self.seq_len is not None:
         | 
| 53 | 
            +
                        _kwargs.update({
         | 
| 54 | 
            +
                            'padding': 'max_length',
         | 
| 55 | 
            +
                            'truncation': True,
         | 
| 56 | 
            +
                            'max_length': self.seq_len
         | 
| 57 | 
            +
                        })
         | 
| 58 | 
            +
                    _kwargs.update(**kwargs)
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                    # tokenization
         | 
| 61 | 
            +
                    if isinstance(sequence, str):
         | 
| 62 | 
            +
                        sequence = [sequence]
         | 
| 63 | 
            +
                    if self.clean:
         | 
| 64 | 
            +
                        sequence = [self._clean(u) for u in sequence]
         | 
| 65 | 
            +
                    ids = self.tokenizer(sequence, **_kwargs)
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                    # output
         | 
| 68 | 
            +
                    if return_mask:
         | 
| 69 | 
            +
                        return ids.input_ids, ids.attention_mask
         | 
| 70 | 
            +
                    else:
         | 
| 71 | 
            +
                        return ids.input_ids
         | 
| 72 | 
            +
             | 
| 73 | 
            +
                def _clean(self, text):
         | 
| 74 | 
            +
                    if self.clean == 'whitespace':
         | 
| 75 | 
            +
                        text = whitespace_clean(basic_clean(text))
         | 
| 76 | 
            +
                    elif self.clean == 'lower':
         | 
| 77 | 
            +
                        text = whitespace_clean(basic_clean(text)).lower()
         | 
| 78 | 
            +
                    elif self.clean == 'canonicalize':
         | 
| 79 | 
            +
                        text = canonicalize(basic_clean(text))
         | 
| 80 | 
            +
                    return text
         | 
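A short sketch of the `HuggingfaceTokenizer` wrapper above (assumed usage; `xlm-roberta-large` is only an example checkpoint id, borrowed from the image-encoder code):

```python
from model.prompter import HuggingfaceTokenizer  # this file

tok = HuggingfaceTokenizer(name="xlm-roberta-large", seq_len=16, clean="whitespace")

# Returns padded/truncated token ids plus the attention mask as [1, 16] tensors.
ids, mask = tok("A  cat   jumps over\tthe lazy dog", return_mask=True)
print(ids.shape, int(mask.sum()))  # mask.sum() = number of non-padding tokens
```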
| 81 | 
            +
             | 
| 82 | 
            +
             | 
| 83 | 
            +
            class WanPrompter(BasePrompter):
         | 
| 84 | 
            +
             | 
| 85 | 
            +
                def __init__(self, tokenizer_path=None, text_len=512):
         | 
| 86 | 
            +
                    super().__init__()
         | 
| 87 | 
            +
                    self.text_len = text_len
         | 
| 88 | 
            +
                    self.text_encoder = None
         | 
| 89 | 
            +
                    self.fetch_tokenizer(tokenizer_path)
         | 
| 90 | 
            +
                    
         | 
| 91 | 
            +
                def fetch_tokenizer(self, tokenizer_path=None):
         | 
| 92 | 
            +
                    if tokenizer_path is not None:
         | 
| 93 | 
            +
                        self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=self.text_len, clean='whitespace')
         | 
| 94 | 
            +
             | 
| 95 | 
            +
                def fetch_models(self, text_encoder: WanTextEncoder = None):
         | 
| 96 | 
            +
                    self.text_encoder = text_encoder
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                def encode_prompt(self, prompt, positive=True, device="cuda"):
         | 
| 99 | 
            +
                    prompt = self.process_prompt(prompt, positive=positive)
         | 
| 100 | 
            +
                    
         | 
| 101 | 
            +
                    ids, mask = self.tokenizer(prompt, return_mask=True, add_special_tokens=True)
         | 
| 102 | 
            +
                    ids = ids.to(device)
         | 
| 103 | 
            +
                    mask = mask.to(device)
         | 
| 104 | 
            +
                    seq_lens = mask.gt(0).sum(dim=1).long()
         | 
| 105 | 
            +
                    prompt_emb = self.text_encoder(ids, mask)
         | 
| 106 | 
            +
                    prompt_emb = [u[:v] for u, v in zip(prompt_emb, seq_lens)]
         | 
| 107 | 
            +
                    return prompt_emb
         | 
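For orientation, here is a minimal usage sketch of the prompter added above. It is illustrative only: the module paths `model.prompter` / `model.text_encoder`, the tokenizer id `google/umt5-xxl`, and the CUDA device are assumptions rather than anything this commit pins down; `process_prompt` comes from `BasePrompter` earlier in the file.

# Hypothetical usage of WanPrompter (sketch; paths and device are assumptions).
import torch

from model.prompter import WanPrompter          # assumed module path
from model.text_encoder import WanTextEncoder   # assumed module path

prompter = WanPrompter(tokenizer_path="google/umt5-xxl", text_len=512)  # placeholder tokenizer path
text_encoder = WanTextEncoder().eval().to("cuda")  # weights would be loaded separately
prompter.fetch_models(text_encoder)

with torch.no_grad():
    # encode_prompt returns a list with one [seq_len_i, dim] tensor per prompt,
    # trimmed to that prompt's number of non-padding tokens.
    prompt_emb = prompter.encode_prompt(
        "a hand-drawn character turns around", device="cuda")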
    	
        model/text_encoder.py
    ADDED
    
@@ -0,0 +1,269 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def fp16_clamp(x):
    if x.dtype == torch.float16 and torch.isinf(x).any():
        clamp = torch.finfo(x.dtype).max - 1000
        x = torch.clamp(x, min=-clamp, max=clamp)
    return x


class GELU(nn.Module):

    def forward(self, x):
        return 0.5 * x * (1.0 + torch.tanh(
            math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))


class T5LayerNorm(nn.Module):

    def __init__(self, dim, eps=1e-6):
        super(T5LayerNorm, self).__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
                            self.eps)
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            x = x.type_as(self.weight)
        return self.weight * x


class T5Attention(nn.Module):

    def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
        assert dim_attn % num_heads == 0
        super(T5Attention, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.num_heads = num_heads
        self.head_dim = dim_attn // num_heads

        # layers
        self.q = nn.Linear(dim, dim_attn, bias=False)
        self.k = nn.Linear(dim, dim_attn, bias=False)
        self.v = nn.Linear(dim, dim_attn, bias=False)
        self.o = nn.Linear(dim_attn, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, context=None, mask=None, pos_bias=None):
        """
        x:          [B, L1, C].
        context:    [B, L2, C] or None.
        mask:       [B, L2] or [B, L1, L2] or None.
        """
        # check inputs
        context = x if context is None else context
        b, n, c = x.size(0), self.num_heads, self.head_dim

        # compute query, key, value
        q = self.q(x).view(b, -1, n, c)
        k = self.k(context).view(b, -1, n, c)
        v = self.v(context).view(b, -1, n, c)

        # attention bias
        attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
        if pos_bias is not None:
            attn_bias += pos_bias
        if mask is not None:
            assert mask.ndim in [2, 3]
            mask = mask.view(b, 1, 1,
                             -1) if mask.ndim == 2 else mask.unsqueeze(1)
            attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)

        # compute attention (T5 does not use scaling)
        attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
        attn = F.softmax(attn.float(), dim=-1).type_as(attn)
        x = torch.einsum('bnij,bjnc->binc', attn, v)

        # output
        x = x.reshape(b, -1, n * c)
        x = self.o(x)
        x = self.dropout(x)
        return x


class T5FeedForward(nn.Module):

    def __init__(self, dim, dim_ffn, dropout=0.1):
        super(T5FeedForward, self).__init__()
        self.dim = dim
        self.dim_ffn = dim_ffn

        # layers
        self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
        self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
        self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x) * self.gate(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


class T5SelfAttention(nn.Module):

    def __init__(self,
                 dim,
                 dim_attn,
                 dim_ffn,
                 num_heads,
                 num_buckets,
                 shared_pos=True,
                 dropout=0.1):
        super(T5SelfAttention, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        self.norm1 = T5LayerNorm(dim)
        self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm2 = T5LayerNorm(dim)
        self.ffn = T5FeedForward(dim, dim_ffn, dropout)
        self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
            num_buckets, num_heads, bidirectional=True)

    def forward(self, x, mask=None, pos_bias=None):
        e = pos_bias if self.shared_pos else self.pos_embedding(
            x.size(1), x.size(1))
        x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
        x = fp16_clamp(x + self.ffn(self.norm2(x)))
        return x


class T5RelativeEmbedding(nn.Module):

    def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
        super(T5RelativeEmbedding, self).__init__()
        self.num_buckets = num_buckets
        self.num_heads = num_heads
        self.bidirectional = bidirectional
        self.max_dist = max_dist

        # layers
        self.embedding = nn.Embedding(num_buckets, num_heads)

    def forward(self, lq, lk):
        device = self.embedding.weight.device
        # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
        #     torch.arange(lq).unsqueeze(1).to(device)
        rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \
            torch.arange(lq, device=device).unsqueeze(1)
        rel_pos = self._relative_position_bucket(rel_pos)
        rel_pos_embeds = self.embedding(rel_pos)
        rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(
            0)  # [1, N, Lq, Lk]
        return rel_pos_embeds.contiguous()

    def _relative_position_bucket(self, rel_pos):
        # preprocess
        if self.bidirectional:
            num_buckets = self.num_buckets // 2
            rel_buckets = (rel_pos > 0).long() * num_buckets
            rel_pos = torch.abs(rel_pos)
        else:
            num_buckets = self.num_buckets
            rel_buckets = 0
            rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))

        # embeddings for small and large positions
        max_exact = num_buckets // 2
        rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
                                     math.log(self.max_dist / max_exact) *
                                     (num_buckets - max_exact)).long()
        rel_pos_large = torch.min(
            rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
        rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
        return rel_buckets


def init_weights(m):
    if isinstance(m, T5LayerNorm):
        nn.init.ones_(m.weight)
    elif isinstance(m, T5FeedForward):
        nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
        nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
        nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
    elif isinstance(m, T5Attention):
        nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5)
        nn.init.normal_(m.k.weight, std=m.dim**-0.5)
        nn.init.normal_(m.v.weight, std=m.dim**-0.5)
        nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5)
    elif isinstance(m, T5RelativeEmbedding):
        nn.init.normal_(
            m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5)


class WanTextEncoder(torch.nn.Module):

    def __init__(self,
                 vocab=256384,
                 dim=4096,
                 dim_attn=4096,
                 dim_ffn=10240,
                 num_heads=64,
                 num_layers=24,
                 num_buckets=32,
                 shared_pos=False,
                 dropout=0.1):
        super(WanTextEncoder, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
            else nn.Embedding(vocab, dim)
        self.pos_embedding = T5RelativeEmbedding(
            num_buckets, num_heads, bidirectional=True) if shared_pos else None
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([
            T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
                            shared_pos, dropout) for _ in range(num_layers)
        ])
        self.norm = T5LayerNorm(dim)

        # initialize weights
        self.apply(init_weights)

    def forward(self, ids, mask=None):
        x = self.token_embedding(ids)
        x = self.dropout(x)
        e = self.pos_embedding(x.size(1),
                               x.size(1)) if self.shared_pos else None
        for block in self.blocks:
            x = block(x, mask, pos_bias=e)
        x = self.norm(x)
        x = self.dropout(x)
        return x

    @staticmethod
    def state_dict_converter():
        return WanTextEncoderStateDictConverter()


class WanTextEncoderStateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        return state_dict

    def from_civitai(self, state_dict):
        return state_dict
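As a quick sanity check of the encoder above, the following sketch runs a tiny, randomly initialized configuration through the forward pass. The small dimensions are assumptions chosen only to make the shapes easy to follow, not the shipped 4096-dim / 24-layer defaults, and the module path is assumed.

# Shape sketch for WanTextEncoder (tiny assumed config, random ids; CPU only).
import torch

from model.text_encoder import WanTextEncoder   # assumed module path

enc = WanTextEncoder(vocab=1000, dim=64, dim_attn=64, dim_ffn=128,
                     num_heads=4, num_layers=2, num_buckets=32).eval()
ids = torch.randint(0, 1000, (2, 16))           # [B, L] token ids
mask = torch.ones(2, 16, dtype=torch.long)      # [B, L]; 1 marks real tokens
with torch.no_grad():
    out = enc(ids, mask)                        # [B, L, dim] -> torch.Size([2, 16, 64])

With the default shared_pos=False, each T5SelfAttention block builds its own [1, num_heads, L, L] relative-position bias through T5RelativeEmbedding instead of reusing a shared one.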
    	
        model/vae.py
    ADDED
    
@@ -0,0 +1,809 @@
| 1 | 
            +
            from einops import rearrange, repeat
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            import torch.nn as nn
         | 
| 5 | 
            +
            import torch.nn.functional as F
         | 
| 6 | 
            +
            from tqdm import tqdm
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            CACHE_T = 2
         | 
| 9 | 
            +
             | 
| 10 | 
            +
             | 
| 11 | 
            +
            def check_is_instance(model, module_class):
         | 
| 12 | 
            +
                if isinstance(model, module_class):
         | 
| 13 | 
            +
                    return True
         | 
| 14 | 
            +
                if hasattr(model, "module") and isinstance(model.module, module_class):
         | 
| 15 | 
            +
                    return True
         | 
| 16 | 
            +
                return False
         | 
| 17 | 
            +
             | 
| 18 | 
            +
             | 
| 19 | 
            +
            def block_causal_mask(x, block_size):
         | 
| 20 | 
            +
                # params
         | 
| 21 | 
            +
                b, n, s, _, device = *x.size(), x.device
         | 
| 22 | 
            +
                assert s % block_size == 0
         | 
| 23 | 
            +
                num_blocks = s // block_size
         | 
| 24 | 
            +
             | 
| 25 | 
            +
                # build mask
         | 
| 26 | 
            +
                mask = torch.zeros(b, n, s, s, dtype=torch.bool, device=device)
         | 
| 27 | 
            +
                for i in range(num_blocks):
         | 
| 28 | 
            +
                    mask[:, :,
         | 
| 29 | 
            +
                         i * block_size:(i + 1) * block_size, :(i + 1) * block_size] = 1
         | 
| 30 | 
            +
                return mask
         | 
| 31 | 
            +
             | 
| 32 | 
            +
             | 
| 33 | 
            +
            class CausalConv3d(nn.Conv3d):
         | 
| 34 | 
            +
                """
         | 
| 35 | 
            +
                Causal 3d convolusion.
         | 
| 36 | 
            +
                """
         | 
| 37 | 
            +
             | 
| 38 | 
            +
                def __init__(self, *args, **kwargs):
         | 
| 39 | 
            +
                    super().__init__(*args, **kwargs)
         | 
| 40 | 
            +
                    self._padding = (self.padding[2], self.padding[2], self.padding[1],
         | 
| 41 | 
            +
                                     self.padding[1], 2 * self.padding[0], 0)
         | 
| 42 | 
            +
                    self.padding = (0, 0, 0)
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                def forward(self, x, cache_x=None):
         | 
| 45 | 
            +
                    padding = list(self._padding)
         | 
| 46 | 
            +
                    if cache_x is not None and self._padding[4] > 0:
         | 
| 47 | 
            +
                        cache_x = cache_x.to(x.device)
         | 
| 48 | 
            +
                        x = torch.cat([cache_x, x], dim=2)
         | 
| 49 | 
            +
                        padding[4] -= cache_x.shape[2]
         | 
| 50 | 
            +
                    x = F.pad(x, padding)
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                    return super().forward(x)
         | 
| 53 | 
            +
             | 
| 54 | 
            +
             | 
| 55 | 
            +
            class RMS_norm(nn.Module):
         | 
| 56 | 
            +
             | 
| 57 | 
            +
                def __init__(self, dim, channel_first=True, images=True, bias=False):
         | 
| 58 | 
            +
                    super().__init__()
         | 
| 59 | 
            +
                    broadcastable_dims = (1, 1, 1) if not images else (1, 1)
         | 
| 60 | 
            +
                    shape = (dim, *broadcastable_dims) if channel_first else (dim,)
         | 
| 61 | 
            +
             | 
| 62 | 
            +
                    self.channel_first = channel_first
         | 
| 63 | 
            +
                    self.scale = dim**0.5
         | 
| 64 | 
            +
                    self.gamma = nn.Parameter(torch.ones(shape))
         | 
| 65 | 
            +
                    self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                def forward(self, x):
         | 
| 68 | 
            +
                    return F.normalize(
         | 
| 69 | 
            +
                        x, dim=(1 if self.channel_first else
         | 
| 70 | 
            +
                                -1)) * self.scale * self.gamma + self.bias
         | 
| 71 | 
            +
             | 
| 72 | 
            +
             | 
| 73 | 
            +
            class Upsample(nn.Upsample):
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                def forward(self, x):
         | 
| 76 | 
            +
                    """
         | 
| 77 | 
            +
                    Fix bfloat16 support for nearest neighbor interpolation.
         | 
| 78 | 
            +
                    """
         | 
| 79 | 
            +
                    return super().forward(x.float()).type_as(x)
         | 
| 80 | 
            +
             | 
| 81 | 
            +
             | 
| 82 | 
            +
            class Resample(nn.Module):
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                def __init__(self, dim, mode):
         | 
| 85 | 
            +
                    assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
         | 
| 86 | 
            +
                                    'downsample3d')
         | 
| 87 | 
            +
                    super().__init__()
         | 
| 88 | 
            +
                    self.dim = dim
         | 
| 89 | 
            +
                    self.mode = mode
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                    # layers
         | 
| 92 | 
            +
                    if mode == 'upsample2d':
         | 
| 93 | 
            +
                        self.resample = nn.Sequential(
         | 
| 94 | 
            +
                            Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
         | 
| 95 | 
            +
                            nn.Conv2d(dim, dim // 2, 3, padding=1))
         | 
| 96 | 
            +
                    elif mode == 'upsample3d':
         | 
| 97 | 
            +
                        self.resample = nn.Sequential(
         | 
| 98 | 
            +
                            Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
         | 
| 99 | 
            +
                            nn.Conv2d(dim, dim // 2, 3, padding=1))
         | 
| 100 | 
            +
                        self.time_conv = CausalConv3d(dim,
         | 
| 101 | 
            +
                                                      dim * 2, (3, 1, 1),
         | 
| 102 | 
            +
                                                      padding=(1, 0, 0))
         | 
| 103 | 
            +
             | 
| 104 | 
            +
                    elif mode == 'downsample2d':
         | 
| 105 | 
            +
                        self.resample = nn.Sequential(
         | 
| 106 | 
            +
                            nn.ZeroPad2d((0, 1, 0, 1)),
         | 
| 107 | 
            +
                            nn.Conv2d(dim, dim, 3, stride=(2, 2)))
         | 
| 108 | 
            +
                    elif mode == 'downsample3d':
         | 
| 109 | 
            +
                        self.resample = nn.Sequential(
         | 
| 110 | 
            +
                            nn.ZeroPad2d((0, 1, 0, 1)),
         | 
| 111 | 
            +
                            nn.Conv2d(dim, dim, 3, stride=(2, 2)))
         | 
| 112 | 
            +
                        self.time_conv = CausalConv3d(dim,
         | 
| 113 | 
            +
                                                      dim, (3, 1, 1),
         | 
| 114 | 
            +
                                                      stride=(2, 1, 1),
         | 
| 115 | 
            +
                                                      padding=(0, 0, 0))
         | 
| 116 | 
            +
             | 
| 117 | 
            +
                    else:
         | 
| 118 | 
            +
                        self.resample = nn.Identity()
         | 
| 119 | 
            +
             | 
| 120 | 
            +
                def forward(self, x, feat_cache=None, feat_idx=[0]):
         | 
| 121 | 
            +
                    b, c, t, h, w = x.size()
         | 
| 122 | 
            +
                    if self.mode == 'upsample3d':
         | 
| 123 | 
            +
                        if feat_cache is not None:
         | 
| 124 | 
            +
                            idx = feat_idx[0]
         | 
| 125 | 
            +
                            if feat_cache[idx] is None:
         | 
| 126 | 
            +
                                feat_cache[idx] = 'Rep'
         | 
| 127 | 
            +
                                feat_idx[0] += 1
         | 
| 128 | 
            +
                            else:
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                                cache_x = x[:, :, -CACHE_T:, :, :].clone()
         | 
| 131 | 
            +
                                if cache_x.shape[2] < 2 and feat_cache[
         | 
| 132 | 
            +
                                        idx] is not None and feat_cache[idx] != 'Rep':
         | 
| 133 | 
            +
                                    # cache last frame of last two chunk
         | 
| 134 | 
            +
                                    cache_x = torch.cat([
         | 
| 135 | 
            +
                                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
         | 
| 136 | 
            +
                                            cache_x.device), cache_x
         | 
| 137 | 
            +
                                    ],
         | 
| 138 | 
            +
                                                        dim=2)
         | 
| 139 | 
            +
                                if cache_x.shape[2] < 2 and feat_cache[
         | 
| 140 | 
            +
                                        idx] is not None and feat_cache[idx] == 'Rep':
         | 
| 141 | 
            +
                                    cache_x = torch.cat([
         | 
| 142 | 
            +
                                        torch.zeros_like(cache_x).to(cache_x.device),
         | 
| 143 | 
            +
                                        cache_x
         | 
| 144 | 
            +
                                    ],
         | 
| 145 | 
            +
                                                        dim=2)
         | 
| 146 | 
            +
                                if feat_cache[idx] == 'Rep':
         | 
| 147 | 
            +
                                    x = self.time_conv(x)
         | 
| 148 | 
            +
                                else:
         | 
| 149 | 
            +
                                    x = self.time_conv(x, feat_cache[idx])
         | 
| 150 | 
            +
                                feat_cache[idx] = cache_x
         | 
| 151 | 
            +
                                feat_idx[0] += 1
         | 
| 152 | 
            +
             | 
| 153 | 
            +
                                x = x.reshape(b, 2, c, t, h, w)
         | 
| 154 | 
            +
                                x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
         | 
| 155 | 
            +
                                                3)
         | 
| 156 | 
            +
                                x = x.reshape(b, c, t * 2, h, w)
         | 
| 157 | 
            +
                    t = x.shape[2]
         | 
| 158 | 
            +
                    x = rearrange(x, 'b c t h w -> (b t) c h w')
         | 
| 159 | 
            +
                    x = self.resample(x)
         | 
| 160 | 
            +
                    x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
         | 
| 161 | 
            +
             | 
| 162 | 
            +
                    if self.mode == 'downsample3d':
         | 
| 163 | 
            +
                        if feat_cache is not None:
         | 
| 164 | 
            +
                            idx = feat_idx[0]
         | 
| 165 | 
            +
                            if feat_cache[idx] is None:
         | 
| 166 | 
            +
                                feat_cache[idx] = x.clone()
         | 
| 167 | 
            +
                                feat_idx[0] += 1
         | 
| 168 | 
            +
                            else:
         | 
| 169 | 
            +
                                cache_x = x[:, :, -1:, :, :].clone()
         | 
| 170 | 
            +
                                x = self.time_conv(
         | 
| 171 | 
            +
                                    torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
         | 
| 172 | 
            +
                                feat_cache[idx] = cache_x
         | 
| 173 | 
            +
                                feat_idx[0] += 1
         | 
| 174 | 
            +
                    return x
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                def init_weight(self, conv):
         | 
| 177 | 
            +
                    conv_weight = conv.weight
         | 
| 178 | 
            +
                    nn.init.zeros_(conv_weight)
         | 
| 179 | 
            +
                    c1, c2, t, h, w = conv_weight.size()
         | 
| 180 | 
            +
                    one_matrix = torch.eye(c1, c2)
         | 
| 181 | 
            +
                    init_matrix = one_matrix
         | 
| 182 | 
            +
                    nn.init.zeros_(conv_weight)
         | 
| 183 | 
            +
                    conv_weight.data[:, :, 1, 0, 0] = init_matrix
         | 
| 184 | 
            +
                    conv.weight.data.copy_(conv_weight)
         | 
| 185 | 
            +
                    nn.init.zeros_(conv.bias.data)
         | 
| 186 | 
            +
             | 
| 187 | 
            +
                def init_weight2(self, conv):
         | 
| 188 | 
            +
                    conv_weight = conv.weight.data
         | 
| 189 | 
            +
                    nn.init.zeros_(conv_weight)
         | 
| 190 | 
            +
                    c1, c2, t, h, w = conv_weight.size()
         | 
| 191 | 
            +
                    init_matrix = torch.eye(c1 // 2, c2)
         | 
| 192 | 
            +
                    conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
         | 
| 193 | 
            +
                    conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
         | 
| 194 | 
            +
                    conv.weight.data.copy_(conv_weight)
         | 
| 195 | 
            +
                    nn.init.zeros_(conv.bias.data)
         | 
| 196 | 
            +
             | 
| 197 | 
            +
             | 
| 198 | 
            +
            class ResidualBlock(nn.Module):
         | 
| 199 | 
            +
             | 
| 200 | 
            +
                def __init__(self, in_dim, out_dim, dropout=0.0):
         | 
| 201 | 
            +
                    super().__init__()
         | 
| 202 | 
            +
                    self.in_dim = in_dim
         | 
| 203 | 
            +
                    self.out_dim = out_dim
         | 
| 204 | 
            +
             | 
| 205 | 
            +
                    # layers
         | 
| 206 | 
            +
                    self.residual = nn.Sequential(
         | 
| 207 | 
            +
                        RMS_norm(in_dim, images=False), nn.SiLU(),
         | 
| 208 | 
            +
                        CausalConv3d(in_dim, out_dim, 3, padding=1),
         | 
| 209 | 
            +
                        RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
         | 
| 210 | 
            +
                        CausalConv3d(out_dim, out_dim, 3, padding=1))
         | 
| 211 | 
            +
                    self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
         | 
| 212 | 
            +
                        if in_dim != out_dim else nn.Identity()
         | 
| 213 | 
            +
             | 
| 214 | 
            +
                def forward(self, x, feat_cache=None, feat_idx=[0]):
         | 
| 215 | 
            +
                    h = self.shortcut(x)
         | 
| 216 | 
            +
                    for layer in self.residual:
         | 
| 217 | 
            +
                        if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
         | 
| 218 | 
            +
                            idx = feat_idx[0]
         | 
| 219 | 
            +
                            cache_x = x[:, :, -CACHE_T:, :, :].clone()
         | 
| 220 | 
            +
                            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
         | 
| 221 | 
            +
                                # cache last frame of last two chunk
         | 
| 222 | 
            +
                                cache_x = torch.cat([
         | 
| 223 | 
            +
                                    feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
         | 
| 224 | 
            +
                                        cache_x.device), cache_x
         | 
| 225 | 
            +
                                ],
         | 
| 226 | 
            +
                                                    dim=2)
         | 
| 227 | 
            +
                            x = layer(x, feat_cache[idx])
         | 
| 228 | 
            +
                            feat_cache[idx] = cache_x
         | 
| 229 | 
            +
                            feat_idx[0] += 1
         | 
| 230 | 
            +
                        else:
         | 
| 231 | 
            +
                            x = layer(x)
         | 
| 232 | 
            +
                    return x + h
         | 
| 233 | 
            +
             | 
| 234 | 
            +
             | 
| 235 | 
            +
            class AttentionBlock(nn.Module):
         | 
| 236 | 
            +
                """
         | 
| 237 | 
            +
                Causal self-attention with a single head.
         | 
| 238 | 
            +
                """
         | 
| 239 | 
            +
             | 
| 240 | 
            +
                def __init__(self, dim):
         | 
| 241 | 
            +
                    super().__init__()
         | 
| 242 | 
            +
                    self.dim = dim
         | 
| 243 | 
            +
             | 
| 244 | 
            +
                    # layers
         | 
| 245 | 
            +
                    self.norm = RMS_norm(dim)
         | 
| 246 | 
            +
                    self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
         | 
| 247 | 
            +
                    self.proj = nn.Conv2d(dim, dim, 1)
         | 
| 248 | 
            +
             | 
| 249 | 
            +
                    # zero out the last layer params
         | 
| 250 | 
            +
                    nn.init.zeros_(self.proj.weight)
         | 
| 251 | 
            +
             | 
| 252 | 
            +
                def forward(self, x):
         | 
| 253 | 
            +
                    identity = x
         | 
| 254 | 
            +
                    b, c, t, h, w = x.size()
         | 
| 255 | 
            +
                    x = rearrange(x, 'b c t h w -> (b t) c h w')
         | 
| 256 | 
            +
                    x = self.norm(x)
         | 
| 257 | 
            +
                    # compute query, key, value
         | 
| 258 | 
            +
                    q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(
         | 
| 259 | 
            +
                        0, 1, 3, 2).contiguous().chunk(3, dim=-1)
         | 
| 260 | 
            +
             | 
| 261 | 
            +
                    # apply attention
         | 
| 262 | 
            +
                    x = F.scaled_dot_product_attention(
         | 
| 263 | 
            +
                        q,
         | 
| 264 | 
            +
                        k,
         | 
| 265 | 
            +
                        v,
         | 
| 266 | 
            +
                        #attn_mask=block_causal_mask(q, block_size=h * w)
         | 
| 267 | 
            +
                    )
         | 
| 268 | 
            +
                    x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
         | 
| 269 | 
            +
             | 
| 270 | 
            +
                    # output
         | 
| 271 | 
            +
                    x = self.proj(x)
         | 
| 272 | 
            +
                    x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
         | 
| 273 | 
            +
                    return x + identity
         | 
| 274 | 
            +
             | 
| 275 | 
            +
             | 
| 276 | 
            +
            class Encoder3d(nn.Module):
         | 
| 277 | 
            +
             | 
| 278 | 
            +
                def __init__(self,
         | 
| 279 | 
            +
                             dim=128,
         | 
| 280 | 
            +
                             z_dim=4,
         | 
| 281 | 
            +
                             dim_mult=[1, 2, 4, 4],
         | 
| 282 | 
            +
                             num_res_blocks=2,
         | 
| 283 | 
            +
                             attn_scales=[],
         | 
| 284 | 
            +
                             temperal_downsample=[True, True, False],
         | 
| 285 | 
            +
                             dropout=0.0):
         | 
| 286 | 
            +
                    super().__init__()
         | 
| 287 | 
            +
                    self.dim = dim
         | 
| 288 | 
            +
                    self.z_dim = z_dim
         | 
| 289 | 
            +
                    self.dim_mult = dim_mult
         | 
| 290 | 
            +
                    self.num_res_blocks = num_res_blocks
         | 
| 291 | 
            +
                    self.attn_scales = attn_scales
         | 
| 292 | 
            +
                    self.temperal_downsample = temperal_downsample
         | 
| 293 | 
            +
             | 
| 294 | 
            +
                    # dimensions
         | 
| 295 | 
            +
                    dims = [dim * u for u in [1] + dim_mult]
         | 
| 296 | 
            +
                    scale = 1.0
         | 
| 297 | 
            +
             | 
| 298 | 
            +
                    # init block
         | 
| 299 | 
            +
                    self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
         | 
| 300 | 
            +
             | 
| 301 | 
            +
                    # downsample blocks
         | 
| 302 | 
            +
                    downsamples = []
         | 
| 303 | 
            +
                    for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
         | 
| 304 | 
            +
                        # residual (+attention) blocks
         | 
| 305 | 
            +
                        for _ in range(num_res_blocks):
         | 
| 306 | 
            +
                            downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
         | 
| 307 | 
            +
                            if scale in attn_scales:
         | 
| 308 | 
            +
                                downsamples.append(AttentionBlock(out_dim))
         | 
| 309 | 
            +
                            in_dim = out_dim
         | 
| 310 | 
            +
             | 
| 311 | 
            +
                        # downsample block
         | 
| 312 | 
            +
                        if i != len(dim_mult) - 1:
         | 
| 313 | 
            +
                            mode = 'downsample3d' if temperal_downsample[
         | 
| 314 | 
            +
                                i] else 'downsample2d'
         | 
| 315 | 
            +
                            downsamples.append(Resample(out_dim, mode=mode))
         | 
| 316 | 
            +
                            scale /= 2.0
         | 
| 317 | 
            +
                    self.downsamples = nn.Sequential(*downsamples)
         | 
| 318 | 
            +
             | 
| 319 | 
            +
                    # middle blocks
         | 
| 320 | 
            +
                    self.middle = nn.Sequential(ResidualBlock(out_dim, out_dim, dropout),
         | 
| 321 | 
            +
                                                AttentionBlock(out_dim),
         | 
| 322 | 
            +
                                                ResidualBlock(out_dim, out_dim, dropout))
         | 
| 323 | 
            +
             | 
| 324 | 
            +
                    # output blocks
         | 
| 325 | 
            +
                    self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(),
         | 
| 326 | 
            +
                                              CausalConv3d(out_dim, z_dim, 3, padding=1))
         | 
| 327 | 
            +
             | 
| 328 | 
            +
                def forward(self, x, feat_cache=None, feat_idx=[0]):
         | 
| 329 | 
            +
                    if feat_cache is not None:
         | 
| 330 | 
            +
                        idx = feat_idx[0]
         | 
| 331 | 
            +
                        cache_x = x[:, :, -CACHE_T:, :, :].clone()
         | 
| 332 | 
            +
                        if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
         | 
| 333 | 
            +
                        # keep the last frame of the previous chunk too, so the cache still spans two frames
         | 
| 334 | 
            +
                            cache_x = torch.cat([
         | 
| 335 | 
            +
                                feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
         | 
| 336 | 
            +
                                    cache_x.device), cache_x
         | 
| 337 | 
            +
                            ],
         | 
| 338 | 
            +
                                                dim=2)
         | 
| 339 | 
            +
                        x = self.conv1(x, feat_cache[idx])
         | 
| 340 | 
            +
                        feat_cache[idx] = cache_x
         | 
| 341 | 
            +
                        feat_idx[0] += 1
         | 
| 342 | 
            +
                    else:
         | 
| 343 | 
            +
                        x = self.conv1(x)
         | 
| 344 | 
            +
             | 
| 345 | 
            +
                    ## downsamples
         | 
| 346 | 
            +
                    for layer in self.downsamples:
         | 
| 347 | 
            +
                        if feat_cache is not None:
         | 
| 348 | 
            +
                            x = layer(x, feat_cache, feat_idx)
         | 
| 349 | 
            +
                        else:
         | 
| 350 | 
            +
                            x = layer(x)
         | 
| 351 | 
            +
             | 
| 352 | 
            +
                    ## middle
         | 
| 353 | 
            +
                    for layer in self.middle:
         | 
| 354 | 
            +
                        if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
         | 
| 355 | 
            +
                            x = layer(x, feat_cache, feat_idx)
         | 
| 356 | 
            +
                        else:
         | 
| 357 | 
            +
                            x = layer(x)
         | 
| 358 | 
            +
             | 
| 359 | 
            +
                    ## head
         | 
| 360 | 
            +
                    for layer in self.head:
         | 
| 361 | 
            +
                        if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
         | 
| 362 | 
            +
                            idx = feat_idx[0]
         | 
| 363 | 
            +
                            cache_x = x[:, :, -CACHE_T:, :, :].clone()
         | 
| 364 | 
            +
                            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
         | 
| 365 | 
            +
                            # keep the last frame of the previous chunk too, so the cache still spans two frames
         | 
| 366 | 
            +
                                cache_x = torch.cat([
         | 
| 367 | 
            +
                                    feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
         | 
| 368 | 
            +
                                        cache_x.device), cache_x
         | 
| 369 | 
            +
                                ],
         | 
| 370 | 
            +
                                                    dim=2)
         | 
| 371 | 
            +
                            x = layer(x, feat_cache[idx])
         | 
| 372 | 
            +
                            feat_cache[idx] = cache_x
         | 
| 373 | 
            +
                            feat_idx[0] += 1
         | 
| 374 | 
            +
                        else:
         | 
| 375 | 
            +
                            x = layer(x)
         | 
| 376 | 
            +
                    return x
         | 
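
The `feat_cache`/`feat_idx` bookkeeping above lets the encoder run over a long clip in short temporal chunks while every `CausalConv3d` still sees the tail of the previous chunk. A minimal sketch of the same rolling-cache idea, reduced to a toy 1D causal convolution; `causal_conv_chunked` and its arguments are illustrative names, not part of this repository:

    import torch
    import torch.nn.functional as F

    CACHE_T = 2  # frames carried over between chunks (kernel size 3 -> 2 frames of left context)

    def causal_conv_chunked(chunks, weight, cache=None):
        # chunks: list of [B, C_in, T_i] tensors, weight: [C_out, C_in, 3]
        outs = []
        for chunk in chunks:
            if cache is None:
                x = F.pad(chunk, (weight.shape[-1] - 1, 0))  # zero left-padding for the first chunk
            else:
                x = torch.cat([cache, chunk], dim=2)         # reuse cached frames as left context
            outs.append(F.conv1d(x, weight))
            cache = chunk[:, :, -CACHE_T:]                   # keep the last CACHE_T input frames
        return torch.cat(outs, dim=2), cache

As long as every chunk has at least `CACHE_T` frames, this produces the same output as convolving the whole sequence in one pass; the extra `cache_x.shape[2] < 2` branch in `forward` handles the shorter chunks.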
| 377 | 
            +
             | 
| 378 | 
            +
             | 
| 379 | 
            +
            class Decoder3d(nn.Module):
         | 
| 380 | 
            +
             | 
| 381 | 
            +
                def __init__(self,
         | 
| 382 | 
            +
                             dim=128,
         | 
| 383 | 
            +
                             z_dim=4,
         | 
| 384 | 
            +
                             dim_mult=[1, 2, 4, 4],
         | 
| 385 | 
            +
                             num_res_blocks=2,
         | 
| 386 | 
            +
                             attn_scales=[],
         | 
| 387 | 
            +
                             temperal_upsample=[False, True, True],
         | 
| 388 | 
            +
                             dropout=0.0):
         | 
| 389 | 
            +
                    super().__init__()
         | 
| 390 | 
            +
                    self.dim = dim
         | 
| 391 | 
            +
                    self.z_dim = z_dim
         | 
| 392 | 
            +
                    self.dim_mult = dim_mult
         | 
| 393 | 
            +
                    self.num_res_blocks = num_res_blocks
         | 
| 394 | 
            +
                    self.attn_scales = attn_scales
         | 
| 395 | 
            +
                    self.temperal_upsample = temperal_upsample
         | 
| 396 | 
            +
             | 
| 397 | 
            +
                    # dimensions
         | 
| 398 | 
            +
                    dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
         | 
| 399 | 
            +
                    scale = 1.0 / 2**(len(dim_mult) - 2)
         | 
| 400 | 
            +
             | 
| 401 | 
            +
                    # init block
         | 
| 402 | 
            +
                    self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
         | 
| 403 | 
            +
             | 
| 404 | 
            +
                    # middle blocks
         | 
| 405 | 
            +
                    self.middle = nn.Sequential(ResidualBlock(dims[0], dims[0], dropout),
         | 
| 406 | 
            +
                                                AttentionBlock(dims[0]),
         | 
| 407 | 
            +
                                                ResidualBlock(dims[0], dims[0], dropout))
         | 
| 408 | 
            +
             | 
| 409 | 
            +
                    # upsample blocks
         | 
| 410 | 
            +
                    upsamples = []
         | 
| 411 | 
            +
                    for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
         | 
| 412 | 
            +
                        # residual (+attention) blocks
         | 
| 413 | 
            +
                        if i == 1 or i == 2 or i == 3:
         | 
| 414 | 
            +
                            in_dim = in_dim // 2
         | 
| 415 | 
            +
                        for _ in range(num_res_blocks + 1):
         | 
| 416 | 
            +
                            upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
         | 
| 417 | 
            +
                            if scale in attn_scales:
         | 
| 418 | 
            +
                                upsamples.append(AttentionBlock(out_dim))
         | 
| 419 | 
            +
                            in_dim = out_dim
         | 
| 420 | 
            +
             | 
| 421 | 
            +
                        # upsample block
         | 
| 422 | 
            +
                        if i != len(dim_mult) - 1:
         | 
| 423 | 
            +
                            mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
         | 
| 424 | 
            +
                            upsamples.append(Resample(out_dim, mode=mode))
         | 
| 425 | 
            +
                            scale *= 2.0
         | 
| 426 | 
            +
                    self.upsamples = nn.Sequential(*upsamples)
         | 
| 427 | 
            +
             | 
| 428 | 
            +
                    # output blocks
         | 
| 429 | 
            +
                    self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(),
         | 
| 430 | 
            +
                                              CausalConv3d(out_dim, 3, 3, padding=1))
         | 
| 431 | 
            +
             | 
| 432 | 
            +
                def forward(self, x, feat_cache=None, feat_idx=[0]):
         | 
| 433 | 
            +
                    ## conv1
         | 
| 434 | 
            +
                    if feat_cache is not None:
         | 
| 435 | 
            +
                        idx = feat_idx[0]
         | 
| 436 | 
            +
                        cache_x = x[:, :, -CACHE_T:, :, :].clone()
         | 
| 437 | 
            +
                        if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
         | 
| 438 | 
            +
                        # keep the last frame of the previous chunk too, so the cache still spans two frames
         | 
| 439 | 
            +
                            cache_x = torch.cat([
         | 
| 440 | 
            +
                                feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
         | 
| 441 | 
            +
                                    cache_x.device), cache_x
         | 
| 442 | 
            +
                            ],
         | 
| 443 | 
            +
                                                dim=2)
         | 
| 444 | 
            +
                        x = self.conv1(x, feat_cache[idx])
         | 
| 445 | 
            +
                        feat_cache[idx] = cache_x
         | 
| 446 | 
            +
                        feat_idx[0] += 1
         | 
| 447 | 
            +
                    else:
         | 
| 448 | 
            +
                        x = self.conv1(x)
         | 
| 449 | 
            +
             | 
| 450 | 
            +
                    ## middle
         | 
| 451 | 
            +
                    for layer in self.middle:
         | 
| 452 | 
            +
                        if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
         | 
| 453 | 
            +
                            x = layer(x, feat_cache, feat_idx)
         | 
| 454 | 
            +
                        else:
         | 
| 455 | 
            +
                            x = layer(x)
         | 
| 456 | 
            +
             | 
| 457 | 
            +
                    ## upsamples
         | 
| 458 | 
            +
                    for layer in self.upsamples:
         | 
| 459 | 
            +
                        if feat_cache is not None:
         | 
| 460 | 
            +
                            x = layer(x, feat_cache, feat_idx)
         | 
| 461 | 
            +
                        else:
         | 
| 462 | 
            +
                            x = layer(x)
         | 
| 463 | 
            +
             | 
| 464 | 
            +
                    ## head
         | 
| 465 | 
            +
                    for layer in self.head:
         | 
| 466 | 
            +
                        if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
         | 
| 467 | 
            +
                            idx = feat_idx[0]
         | 
| 468 | 
            +
                            cache_x = x[:, :, -CACHE_T:, :, :].clone()
         | 
| 469 | 
            +
                            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
         | 
| 470 | 
            +
                            # keep the last frame of the previous chunk too, so the cache still spans two frames
         | 
| 471 | 
            +
                                cache_x = torch.cat([
         | 
| 472 | 
            +
                                    feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
         | 
| 473 | 
            +
                                        cache_x.device), cache_x
         | 
| 474 | 
            +
                                ],
         | 
| 475 | 
            +
                                                    dim=2)
         | 
| 476 | 
            +
                            x = layer(x, feat_cache[idx])
         | 
| 477 | 
            +
                            feat_cache[idx] = cache_x
         | 
| 478 | 
            +
                            feat_idx[0] += 1
         | 
| 479 | 
            +
                        else:
         | 
| 480 | 
            +
                            x = layer(x)
         | 
| 481 | 
            +
                    return x
         | 
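
With the default `temperal_upsample=[False, True, True]`, two of the `Resample` stages double the temporal length, so T latent frames decode to 4*T - 3 video frames (the first latent frame contributes one frame, every later one contributes four). A quick sanity check of that bookkeeping, mirroring the `out_T = T * 4 - 3` and `iter_ = 1 + (t - 1) // 4` formulas used later in this file; the helper names are illustrative:

    def n_latent_frames(n_video_frames: int) -> int:
        # encoder: one latent frame for frame 0, then one per 4 video frames
        return 1 + (n_video_frames - 1) // 4

    def n_video_frames(n_latent_frames: int) -> int:
        # decoder: inverse of the mapping above for well-formed clips (4k + 1 frames)
        return n_latent_frames * 4 - 3

    assert n_latent_frames(81) == 21 and n_video_frames(21) == 81
    assert n_latent_frames(1) == 1 and n_video_frames(1) == 1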
| 482 | 
            +
             | 
| 483 | 
            +
             | 
| 484 | 
            +
            def count_conv3d(model):
         | 
| 485 | 
            +
                count = 0
         | 
| 486 | 
            +
                for m in model.modules():
         | 
| 487 | 
            +
                    if check_is_instance(m, CausalConv3d):
         | 
| 488 | 
            +
                        count += 1
         | 
| 489 | 
            +
                return count
         | 
| 490 | 
            +
             | 
| 491 | 
            +
             | 
| 492 | 
            +
            class VideoVAE_(nn.Module):
         | 
| 493 | 
            +
             | 
| 494 | 
            +
                def __init__(self,
         | 
| 495 | 
            +
                             dim=96,
         | 
| 496 | 
            +
                             z_dim=16,
         | 
| 497 | 
            +
                             dim_mult=[1, 2, 4, 4],
         | 
| 498 | 
            +
                             num_res_blocks=2,
         | 
| 499 | 
            +
                             attn_scales=[],
         | 
| 500 | 
            +
                             temperal_downsample=[False, True, True],
         | 
| 501 | 
            +
                             dropout=0.0):
         | 
| 502 | 
            +
                    super().__init__()
         | 
| 503 | 
            +
                    self.dim = dim
         | 
| 504 | 
            +
                    self.z_dim = z_dim
         | 
| 505 | 
            +
                    self.dim_mult = dim_mult
         | 
| 506 | 
            +
                    self.num_res_blocks = num_res_blocks
         | 
| 507 | 
            +
                    self.attn_scales = attn_scales
         | 
| 508 | 
            +
                    self.temperal_downsample = temperal_downsample
         | 
| 509 | 
            +
                    self.temperal_upsample = temperal_downsample[::-1]
         | 
| 510 | 
            +
             | 
| 511 | 
            +
                    # modules
         | 
| 512 | 
            +
                    self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
         | 
| 513 | 
            +
                                             attn_scales, self.temperal_downsample, dropout)
         | 
| 514 | 
            +
                    self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
         | 
| 515 | 
            +
                    self.conv2 = CausalConv3d(z_dim, z_dim, 1)
         | 
| 516 | 
            +
                    self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
         | 
| 517 | 
            +
                                             attn_scales, self.temperal_upsample, dropout)
         | 
| 518 | 
            +
             | 
| 519 | 
            +
                def forward(self, x):
         | 
| 520 | 
            +
                    mu, log_var = self.encode(x)
         | 
| 521 | 
            +
                    z = self.reparameterize(mu, log_var)
         | 
| 522 | 
            +
                    x_recon = self.decode(z)
         | 
| 523 | 
            +
                    return x_recon, mu, log_var
         | 
| 524 | 
            +
             | 
| 525 | 
            +
                def encode(self, x, scale):
         | 
| 526 | 
            +
                    self.clear_cache()
         | 
| 527 | 
            +
                    ## cache
         | 
| 528 | 
            +
                    t = x.shape[2]
         | 
| 529 | 
            +
                    iter_ = 1 + (t - 1) // 4
         | 
| 530 | 
            +
             | 
| 531 | 
            +
                    for i in range(iter_):
         | 
| 532 | 
            +
                        self._enc_conv_idx = [0]
         | 
| 533 | 
            +
                        if i == 0:
         | 
| 534 | 
            +
                            out = self.encoder(x[:, :, :1, :, :],
         | 
| 535 | 
            +
                                               feat_cache=self._enc_feat_map,
         | 
| 536 | 
            +
                                               feat_idx=self._enc_conv_idx)
         | 
| 537 | 
            +
                        else:
         | 
| 538 | 
            +
                            out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
         | 
| 539 | 
            +
                                                feat_cache=self._enc_feat_map,
         | 
| 540 | 
            +
                                                feat_idx=self._enc_conv_idx)
         | 
| 541 | 
            +
                            out = torch.cat([out, out_], 2)
         | 
| 542 | 
            +
                    mu, log_var = self.conv1(out).chunk(2, dim=1)
         | 
| 543 | 
            +
                    if isinstance(scale[0], torch.Tensor):
         | 
| 544 | 
            +
                        scale = [s.to(dtype=mu.dtype, device=mu.device) for s in scale]
         | 
| 545 | 
            +
                        mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
         | 
| 546 | 
            +
                            1, self.z_dim, 1, 1, 1)
         | 
| 547 | 
            +
                    else:
         | 
| 548 | 
            +
                        scale = scale.to(dtype=mu.dtype, device=mu.device)
         | 
| 549 | 
            +
                        mu = (mu - scale[0]) * scale[1]
         | 
| 550 | 
            +
                    return mu
         | 
| 551 | 
            +
             | 
| 552 | 
            +
                def decode(self, z, scale):
         | 
| 553 | 
            +
                    self.clear_cache()
         | 
| 554 | 
            +
                    # z: [b,c,t,h,w]
         | 
| 555 | 
            +
                    if isinstance(scale[0], torch.Tensor):
         | 
| 556 | 
            +
                        scale = [s.to(dtype=z.dtype, device=z.device) for s in scale]
         | 
| 557 | 
            +
                        z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
         | 
| 558 | 
            +
                            1, self.z_dim, 1, 1, 1)
         | 
| 559 | 
            +
                    else:
         | 
| 560 | 
            +
                        scale = scale.to(dtype=z.dtype, device=z.device)
         | 
| 561 | 
            +
                        z = z / scale[1] + scale[0]
         | 
| 562 | 
            +
                    iter_ = z.shape[2]
         | 
| 563 | 
            +
                    x = self.conv2(z)
         | 
| 564 | 
            +
                    for i in range(iter_):
         | 
| 565 | 
            +
                        self._conv_idx = [0]
         | 
| 566 | 
            +
                        if i == 0:
         | 
| 567 | 
            +
                            out = self.decoder(x[:, :, i:i + 1, :, :],
         | 
| 568 | 
            +
                                               feat_cache=self._feat_map,
         | 
| 569 | 
            +
                                               feat_idx=self._conv_idx)
         | 
| 570 | 
            +
                        else:
         | 
| 571 | 
            +
                            out_ = self.decoder(x[:, :, i:i + 1, :, :],
         | 
| 572 | 
            +
                                                feat_cache=self._feat_map,
         | 
| 573 | 
            +
                                                feat_idx=self._conv_idx)
         | 
| 574 | 
            +
                        out = torch.cat([out, out_], 2)  # intermediate frames could be offloaded to CPU here to cap peak VRAM
         | 
| 575 | 
            +
                    return out
         | 
| 576 | 
            +
             | 
| 577 | 
            +
                def reparameterize(self, mu, log_var):
         | 
| 578 | 
            +
                    std = torch.exp(0.5 * log_var)
         | 
| 579 | 
            +
                    eps = torch.randn_like(std)
         | 
| 580 | 
            +
                    return eps * std + mu
         | 
| 581 | 
            +
             | 
| 582 | 
            +
                def sample(self, imgs, deterministic=False):
         | 
| 583 | 
            +
                    mu, log_var = self.encode(imgs)
         | 
| 584 | 
            +
                    if deterministic:
         | 
| 585 | 
            +
                        return mu
         | 
| 586 | 
            +
                    std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
         | 
| 587 | 
            +
                    return mu + std * torch.randn_like(std)
         | 
| 588 | 
            +
             | 
| 589 | 
            +
                def clear_cache(self):
         | 
| 590 | 
            +
                    self._conv_num = count_conv3d(self.decoder)
         | 
| 591 | 
            +
                    self._conv_idx = [0]
         | 
| 592 | 
            +
                    self._feat_map = [None] * self._conv_num
         | 
| 593 | 
            +
                    # cache encode
         | 
| 594 | 
            +
                    self._enc_conv_num = count_conv3d(self.encoder)
         | 
| 595 | 
            +
                    self._enc_conv_idx = [0]
         | 
| 596 | 
            +
                    self._enc_feat_map = [None] * self._enc_conv_num
         | 
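
`encode` above walks the input as one 1-frame chunk followed by 4-frame chunks, and `clear_cache` resets the per-convolution caches so the causal convolutions stay consistent across those chunk boundaries. A sketch of the slicing pattern, with an illustrative helper name:

    def temporal_chunks(t: int):
        # frame 0 goes alone, every later chunk holds 4 frames
        chunks = [(0, 1)]
        for i in range(1, 1 + (t - 1) // 4):
            chunks.append((1 + 4 * (i - 1), 1 + 4 * i))
        return chunks

    # a 17-frame clip is encoded as x[:, :, 0:1], x[:, :, 1:5], x[:, :, 5:9], x[:, :, 9:13], x[:, :, 13:17]
    assert temporal_chunks(17) == [(0, 1), (1, 5), (5, 9), (9, 13), (13, 17)]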
| 597 | 
            +
             | 
| 598 | 
            +
             | 
| 599 | 
            +
            class WanVideoVAE(nn.Module):
         | 
| 600 | 
            +
             | 
| 601 | 
            +
                def __init__(self, z_dim=16):
         | 
| 602 | 
            +
                    super().__init__()
         | 
| 603 | 
            +
             | 
| 604 | 
            +
                    mean = [
         | 
| 605 | 
            +
                        -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
         | 
| 606 | 
            +
                        0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
         | 
| 607 | 
            +
                    ]
         | 
| 608 | 
            +
                    std = [
         | 
| 609 | 
            +
                        2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
         | 
| 610 | 
            +
                        3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
         | 
| 611 | 
            +
                    ]
         | 
| 612 | 
            +
                    self.mean = torch.tensor(mean)
         | 
| 613 | 
            +
                    self.std = torch.tensor(std)
         | 
| 614 | 
            +
                    self.scale = [self.mean, 1.0 / self.std]
         | 
| 615 | 
            +
             | 
| 616 | 
            +
                    # init model
         | 
| 617 | 
            +
                    self.model = VideoVAE_(z_dim=z_dim).eval().requires_grad_(False)
         | 
| 618 | 
            +
                    self.upsampling_factor = 8
         | 
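
`self.scale` stores the per-channel latent statistics as `[mean, 1.0 / std]`; `VideoVAE_.encode` returns `(mu - mean) / std` and `decode` undoes it before running the decoder. A minimal sketch of that normalization for a latent of shape [B, 16, T, H, W]; the function names are illustrative:

    import torch

    def normalize_latent(mu, mean, std):
        # matches (mu - scale[0]) * scale[1] with scale = [mean, 1 / std]
        mean = mean.view(1, -1, 1, 1, 1).to(mu)
        std = std.view(1, -1, 1, 1, 1).to(mu)
        return (mu - mean) / std

    def denormalize_latent(z, mean, std):
        # matches z / scale[1] + scale[0]
        mean = mean.view(1, -1, 1, 1, 1).to(z)
        std = std.view(1, -1, 1, 1, 1).to(z)
        return z * std + mean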
| 619 | 
            +
             | 
| 620 | 
            +
             | 
| 621 | 
            +
                def build_1d_mask(self, length, left_bound, right_bound, border_width):
         | 
| 622 | 
            +
                    x = torch.ones((length,))
         | 
| 623 | 
            +
                    if not left_bound:
         | 
| 624 | 
            +
                        x[:border_width] = (torch.arange(border_width) + 1) / border_width
         | 
| 625 | 
            +
                    if not right_bound:
         | 
| 626 | 
            +
                        x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
         | 
| 627 | 
            +
                    return x
         | 
| 628 | 
            +
             | 
| 629 | 
            +
             | 
| 630 | 
            +
                def build_mask(self, data, is_bound, border_width):
         | 
| 631 | 
            +
                    _, _, _, H, W = data.shape
         | 
| 632 | 
            +
                    h = self.build_1d_mask(H, is_bound[0], is_bound[1], border_width[0])
         | 
| 633 | 
            +
                    w = self.build_1d_mask(W, is_bound[2], is_bound[3], border_width[1])
         | 
| 634 | 
            +
             | 
| 635 | 
            +
                    h = repeat(h, "H -> H W", H=H, W=W)
         | 
| 636 | 
            +
                    w = repeat(w, "W -> H W", H=H, W=W)
         | 
| 637 | 
            +
             | 
| 638 | 
            +
                    mask = torch.stack([h, w]).min(dim=0).values
         | 
| 639 | 
            +
                    mask = rearrange(mask, "H W -> 1 1 1 H W")
         | 
| 640 | 
            +
                    return mask
         | 
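
`build_mask` produces a feathering weight that ramps linearly from 1/border_width up to 1 along any tile edge that is not on the image boundary; overlapping tiles are blended with these weights and then renormalized by the accumulated `weight` tensor. The 1D ramp, for example:

    import torch

    # interior (non-boundary) left edge with border_width=4
    x = torch.ones(8)
    x[:4] = (torch.arange(4) + 1) / 4
    print(x)  # tensor([0.2500, 0.5000, 0.7500, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])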
| 641 | 
            +
             | 
| 642 | 
            +
             | 
| 643 | 
            +
                def tiled_decode(self, hidden_states, device, tile_size, tile_stride):
         | 
| 644 | 
            +
                    _, _, T, H, W = hidden_states.shape
         | 
| 645 | 
            +
                    size_h, size_w = tile_size
         | 
| 646 | 
            +
                    stride_h, stride_w = tile_stride
         | 
| 647 | 
            +
             | 
| 648 | 
            +
                    # Split tasks
         | 
| 649 | 
            +
                    tasks = []
         | 
| 650 | 
            +
                    for h in range(0, H, stride_h):
         | 
| 651 | 
            +
                        if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
         | 
| 652 | 
            +
                        for w in range(0, W, stride_w):
         | 
| 653 | 
            +
                            if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
         | 
| 654 | 
            +
                            h_, w_ = h + size_h, w + size_w
         | 
| 655 | 
            +
                            tasks.append((h, h_, w, w_))
         | 
| 656 | 
            +
             | 
| 657 | 
            +
                    data_device = "cpu"
         | 
| 658 | 
            +
                    computation_device = device
         | 
| 659 | 
            +
             | 
| 660 | 
            +
                    out_T = T * 4 - 3
         | 
| 661 | 
            +
                    weight = torch.zeros((1, 1, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
         | 
| 662 | 
            +
                    values = torch.zeros((1, 3, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
         | 
| 663 | 
            +
             | 
| 664 | 
            +
                    for h, h_, w, w_ in tqdm(tasks, desc="VAE decoding"):
         | 
| 665 | 
            +
                        hidden_states_batch = hidden_states[:, :, :, h:h_, w:w_].to(computation_device)
         | 
| 666 | 
            +
                        hidden_states_batch = self.model.decode(hidden_states_batch, self.scale).to(data_device)
         | 
| 667 | 
            +
             | 
| 668 | 
            +
                        mask = self.build_mask(
         | 
| 669 | 
            +
                            hidden_states_batch,
         | 
| 670 | 
            +
                            is_bound=(h==0, h_>=H, w==0, w_>=W),
         | 
| 671 | 
            +
                            border_width=((size_h - stride_h) * self.upsampling_factor, (size_w - stride_w) * self.upsampling_factor)
         | 
| 672 | 
            +
                        ).to(dtype=hidden_states.dtype, device=data_device)
         | 
| 673 | 
            +
             | 
| 674 | 
            +
                        target_h = h * self.upsampling_factor
         | 
| 675 | 
            +
                        target_w = w * self.upsampling_factor
         | 
| 676 | 
            +
                        values[
         | 
| 677 | 
            +
                            :,
         | 
| 678 | 
            +
                            :,
         | 
| 679 | 
            +
                            :,
         | 
| 680 | 
            +
                            target_h:target_h + hidden_states_batch.shape[3],
         | 
| 681 | 
            +
                            target_w:target_w + hidden_states_batch.shape[4],
         | 
| 682 | 
            +
                        ] += hidden_states_batch * mask
         | 
| 683 | 
            +
                        weight[
         | 
| 684 | 
            +
                            :,
         | 
| 685 | 
            +
                            :,
         | 
| 686 | 
            +
                            :,
         | 
| 687 | 
            +
                            target_h: target_h + hidden_states_batch.shape[3],
         | 
| 688 | 
            +
                            target_w: target_w + hidden_states_batch.shape[4],
         | 
| 689 | 
            +
                        ] += mask
         | 
| 690 | 
            +
                    values = values / weight
         | 
| 691 | 
            +
                    values = values.float().clamp_(-1, 1)
         | 
| 692 | 
            +
                    return values
         | 
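
The task loop above slides a (size_h, size_w) latent window with stride (stride_h, stride_w), decodes each tile on `device`, and accumulates the feathered results on the CPU. A rough sketch of the tile enumeration under the defaults that `decode` passes further below (tile_size=(34, 34), tile_stride=(18, 16)); `tile_tasks` is an illustrative name:

    def tile_tasks(H, W, size_h=34, size_w=34, stride_h=18, stride_w=16):
        tasks = []
        for h in range(0, H, stride_h):
            if h - stride_h >= 0 and h - stride_h + size_h >= H:
                continue
            for w in range(0, W, stride_w):
                if w - stride_w >= 0 and w - stride_w + size_w >= W:
                    continue
                tasks.append((h, h + size_h, w, w + size_w))
        return tasks

    # a 60x104 latent grid (480x832 pixels) is covered by 3 x 6 = 18 overlapping tiles
    print(len(tile_tasks(60, 104)))  # 18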
| 693 | 
            +
             | 
| 694 | 
            +
             | 
| 695 | 
            +
                def tiled_encode(self, video, device, tile_size, tile_stride):
         | 
| 696 | 
            +
                    _, _, T, H, W = video.shape
         | 
| 697 | 
            +
                    size_h, size_w = tile_size
         | 
| 698 | 
            +
                    stride_h, stride_w = tile_stride
         | 
| 699 | 
            +
             | 
| 700 | 
            +
                    # Split tasks
         | 
| 701 | 
            +
                    tasks = []
         | 
| 702 | 
            +
                    for h in range(0, H, stride_h):
         | 
| 703 | 
            +
                        if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
         | 
| 704 | 
            +
                        for w in range(0, W, stride_w):
         | 
| 705 | 
            +
                            if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
         | 
| 706 | 
            +
                            h_, w_ = h + size_h, w + size_w
         | 
| 707 | 
            +
                            tasks.append((h, h_, w, w_))
         | 
| 708 | 
            +
             | 
| 709 | 
            +
                    data_device = "cpu"
         | 
| 710 | 
            +
                    computation_device = device
         | 
| 711 | 
            +
             | 
| 712 | 
            +
                    out_T = (T + 3) // 4
         | 
| 713 | 
            +
                    weight = torch.zeros((1, 1, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
         | 
| 714 | 
            +
                    values = torch.zeros((1, 16, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
         | 
| 715 | 
            +
             | 
| 716 | 
            +
                    for h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"):
         | 
| 717 | 
            +
                        hidden_states_batch = video[:, :, :, h:h_, w:w_].to(computation_device)
         | 
| 718 | 
            +
                        hidden_states_batch = self.model.encode(hidden_states_batch, self.scale).to(data_device)
         | 
| 719 | 
            +
             | 
| 720 | 
            +
                        mask = self.build_mask(
         | 
| 721 | 
            +
                            hidden_states_batch,
         | 
| 722 | 
            +
                            is_bound=(h==0, h_>=H, w==0, w_>=W),
         | 
| 723 | 
            +
                            border_width=((size_h - stride_h) // self.upsampling_factor, (size_w - stride_w) // self.upsampling_factor)
         | 
| 724 | 
            +
                        ).to(dtype=video.dtype, device=data_device)
         | 
| 725 | 
            +
             | 
| 726 | 
            +
                        target_h = h // self.upsampling_factor
         | 
| 727 | 
            +
                        target_w = w // self.upsampling_factor
         | 
| 728 | 
            +
                        values[
         | 
| 729 | 
            +
                            :,
         | 
| 730 | 
            +
                            :,
         | 
| 731 | 
            +
                            :,
         | 
| 732 | 
            +
                            target_h:target_h + hidden_states_batch.shape[3],
         | 
| 733 | 
            +
                            target_w:target_w + hidden_states_batch.shape[4],
         | 
| 734 | 
            +
                        ] += hidden_states_batch * mask
         | 
| 735 | 
            +
                        weight[
         | 
| 736 | 
            +
                            :,
         | 
| 737 | 
            +
                            :,
         | 
| 738 | 
            +
                            :,
         | 
| 739 | 
            +
                            target_h: target_h + hidden_states_batch.shape[3],
         | 
| 740 | 
            +
                            target_w: target_w + hidden_states_batch.shape[4],
         | 
| 741 | 
            +
                        ] += mask
         | 
| 742 | 
            +
                    values = values / weight
         | 
| 743 | 
            +
                    values = values.float()
         | 
| 744 | 
            +
                    return values
         | 
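
`tiled_encode` mirrors `tiled_decode` in the other direction: a video of shape [1, 3, T, H, W] becomes a latent of shape [1, 16, (T + 3) // 4, H // 8, W // 8]. A small shape check under those assumptions; the helper name is illustrative:

    def latent_shape(T, H, W, z_dim=16, factor=8):
        return (1, z_dim, (T + 3) // 4, H // factor, W // factor)

    # 81 frames at 480x832 -> 21 latent frames at 60x104
    assert latent_shape(81, 480, 832) == (1, 16, 21, 60, 104)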
| 745 | 
            +
             | 
| 746 | 
            +
             | 
| 747 | 
            +
                def single_encode(self, video, device):
         | 
| 748 | 
            +
                    video = video.to(device)
         | 
| 749 | 
            +
                    x = self.model.encode(video, self.scale)
         | 
| 750 | 
            +
                    return x.float()
         | 
| 751 | 
            +
             | 
| 752 | 
            +
             | 
| 753 | 
            +
                def single_decode(self, hidden_state, device):
         | 
| 754 | 
            +
                    hidden_state = hidden_state.to(device)
         | 
| 755 | 
            +
                    video = self.model.decode(hidden_state, self.scale)
         | 
| 756 | 
            +
                    return video.float().clamp_(-1, 1)
         | 
| 757 | 
            +
             | 
| 758 | 
            +
             | 
| 759 | 
            +
                def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
         | 
| 760 | 
            +
             | 
| 761 | 
            +
                    videos = [video.to("cpu") for video in videos]
         | 
| 762 | 
            +
                    hidden_states = []
         | 
| 763 | 
            +
                    for video in videos:
         | 
| 764 | 
            +
                        video = video.unsqueeze(0)
         | 
| 765 | 
            +
                        if tiled:
         | 
| 766 | 
            +
                            tile_size = (tile_size[0] * 8, tile_size[1] * 8)
         | 
| 767 | 
            +
                            tile_stride = (tile_stride[0] * 8, tile_stride[1] * 8)
         | 
| 768 | 
            +
                            hidden_state = self.tiled_encode(video, device, tile_size, tile_stride)
         | 
| 769 | 
            +
                        else:
         | 
| 770 | 
            +
                            hidden_state = self.single_encode(video, device)
         | 
| 771 | 
            +
                        hidden_state = hidden_state.squeeze(0)
         | 
| 772 | 
            +
                        hidden_states.append(hidden_state)
         | 
| 773 | 
            +
                    hidden_states = torch.stack(hidden_states)
         | 
| 774 | 
            +
                    return hidden_states
         | 
| 775 | 
            +
             | 
| 776 | 
            +
             | 
| 777 | 
            +
                def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
         | 
| 778 | 
            +
                    hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states]
         | 
| 779 | 
            +
                    videos = []
         | 
| 780 | 
            +
                    for hidden_state in hidden_states:
         | 
| 781 | 
            +
                        hidden_state = hidden_state.unsqueeze(0)
         | 
| 782 | 
            +
                        if tiled:
         | 
| 783 | 
            +
                            video = self.tiled_decode(hidden_state, device, tile_size, tile_stride)
         | 
| 784 | 
            +
                        else:
         | 
| 785 | 
            +
                            video = self.single_decode(hidden_state, device)
         | 
| 786 | 
            +
                        video = video.squeeze(0)
         | 
| 787 | 
            +
                        videos.append(video)
         | 
| 788 | 
            +
                    videos = torch.stack(videos)
         | 
| 789 | 
            +
                    return videos
         | 
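
Taken together, `encode` and `decode` take a batch (list or stacked tensor) of clips in [-1, 1] with per-clip shape [C, T, H, W], keep the accumulation buffers on the CPU, and only move individual tiles to the compute device. A hedged usage sketch on random data; it assumes the module has been moved to the compute device and that real weights would normally be loaded first:

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    vae = WanVideoVAE().to(device)                 # untrained here; load converted weights in practice
    video = torch.rand(3, 17, 256, 256) * 2 - 1    # one clip in [-1, 1], shape [C, T, H, W]

    latents = vae.encode([video], device=device, tiled=True)
    print(latents.shape)                           # torch.Size([1, 16, 5, 32, 32])
    recon = vae.decode(latents, device=device, tiled=True)
    print(recon.shape)                             # torch.Size([1, 3, 17, 256, 256])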
| 790 | 
            +
             | 
| 791 | 
            +
             | 
| 792 | 
            +
                @staticmethod
         | 
| 793 | 
            +
                def state_dict_converter():
         | 
| 794 | 
            +
                    return WanVideoVAEStateDictConverter()
         | 
| 795 | 
            +
             | 
| 796 | 
            +
             | 
| 797 | 
            +
            class WanVideoVAEStateDictConverter:
         | 
| 798 | 
            +
             | 
| 799 | 
            +
                def __init__(self):
         | 
| 800 | 
            +
                    pass
         | 
| 801 | 
            +
             | 
| 802 | 
            +
                def from_civitai(self, state_dict):
         | 
| 803 | 
            +
                    state_dict_ = {}
         | 
| 804 | 
            +
                    if 'model_state' in state_dict:
         | 
| 805 | 
            +
                        state_dict = state_dict['model_state']
         | 
| 806 | 
            +
                    for name in state_dict:
         | 
| 807 | 
            +
                        state_dict_['model.' + name] = state_dict[name]
         | 
| 808 | 
            +
                    return state_dict_
         | 
| 809 | 
            +
                
         | 
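
The converter simply prefixes every parameter name with `model.` so that a raw Wan VAE checkpoint lines up with the `WanVideoVAE.model` submodule. A hedged loading sketch; the checkpoint filename is an assumption and the exact key layout depends on the weights you use:

    import torch

    vae = WanVideoVAE(z_dim=16)
    raw = torch.load("Wan2.1_VAE.pth", map_location="cpu")             # illustrative path
    converted = WanVideoVAE.state_dict_converter().from_civitai(raw)   # adds the "model." prefix
    vae.load_state_dict(converted)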
    	
        pipeline/__init__.py
    ADDED
    
    | 
            File without changes
         | 
    	
        pipeline/i2v_pipeline.py
    ADDED
    
    | @@ -0,0 +1,511 @@ | |
| 1 | 
            +
            from diffsynth import ModelManager
         | 
| 2 | 
            +
            from diffsynth.pipelines.base import BasePipeline
         | 
| 3 | 
            +
            from diffsynth.vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            from model.dit import WanModel
         | 
| 6 | 
            +
            from model.text_encoder import WanTextEncoder
         | 
| 7 | 
            +
            from model.vae import WanVideoVAE
         | 
| 8 | 
            +
            from model.image_encoder import WanImageEncoder
         | 
| 9 | 
            +
            from model.prompter import WanPrompter
         | 
| 10 | 
            +
            from scheduler.flow_match import FlowMatchScheduler
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            import torch, os
         | 
| 13 | 
            +
            from einops import rearrange, repeat
         | 
| 14 | 
            +
            import numpy as np
         | 
| 15 | 
            +
            import PIL.Image
         | 
| 16 | 
            +
            from tqdm import tqdm
         | 
| 17 | 
            +
            from safetensors import safe_open
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            from model.text_encoder import T5RelativeEmbedding, T5LayerNorm
         | 
| 20 | 
            +
            from model.dit import WanLayerNorm, WanRMSNorm, WanSelfAttention
         | 
| 21 | 
            +
            from model.vae import RMS_norm, CausalConv3d, Upsample
         | 
| 22 | 
            +
             | 
| 23 | 
            +
             | 
| 24 | 
            +
            def binary_tensor_to_indices(tensor):
         | 
| 25 | 
            +
                assert tensor.dim() == 2, "Input tensor must have shape [b, t]"
         | 
| 26 | 
            +
                indices = [(row == 1).nonzero(as_tuple=True)[0] for row in tensor]
         | 
| 27 | 
            +
                return indices
         | 
| 28 | 
            +
             | 
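
`binary_tensor_to_indices` turns a per-frame binary condition mask into one index tensor per batch element, e.g.:

    import torch

    mask = torch.tensor([[1, 0, 0, 1],
                         [0, 1, 0, 0]])
    print(binary_tensor_to_indices(mask))  # [tensor([0, 3]), tensor([1])]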
| 29 | 
            +
            def propagate_visualize_attention_arg(model, visualize_attention=False):
         | 
| 30 | 
            +
                    """
         | 
| 31 | 
            +
                    Recursively set the visualize_attention flag on the WanSelfAttention modules of blocks 0, 19 and 39
         | 
| 32 | 
            +
                    Intended for inference/visualization only
         | 
| 33 | 
            +
                    """
         | 
| 34 | 
            +
                    for name, module in model.named_modules():
         | 
| 35 | 
            +
                        if isinstance(module, WanSelfAttention):
         | 
| 36 | 
            +
                            if "blocks.0.self_attn" in name or "blocks.19.self_attn" in name or "blocks.39.self_attn" in name:
         | 
| 37 | 
            +
                                print(f"Set `visualize_attention` to {visualize_attention} for {name}")
         | 
| 38 | 
            +
                                module.visualize_attention = visualize_attention
         | 
| 39 | 
            +
             | 
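
The helper only flips the flag on the self-attention of blocks 0, 19 and 39, which is enough to dump attention maps from the start, middle and end of the DiT. A hedged usage sketch, assuming `pipe.dit` holds the loaded WanModel:

    # turn attention visualization on for one debug run, then switch it back off
    propagate_visualize_attention_arg(pipe.dit, visualize_attention=True)
    # ... run a single inference pass here ...
    propagate_visualize_attention_arg(pipe.dit, visualize_attention=False)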
| 40 | 
            +
            class WanVideoPipeline(BasePipeline):
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                def __init__(self, device="cuda", torch_dtype=torch.float16, tokenizer_path=None):
         | 
| 43 | 
            +
                    super().__init__(device=device, torch_dtype=torch_dtype)
         | 
| 44 | 
            +
                    self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
         | 
| 45 | 
            +
                    self.prompter = WanPrompter(tokenizer_path=tokenizer_path)
         | 
| 46 | 
            +
                    self.text_encoder: WanTextEncoder = None
         | 
| 47 | 
            +
                    self.image_encoder: WanImageEncoder = None
         | 
| 48 | 
            +
                    self.dit: WanModel = None
         | 
| 49 | 
            +
                    self.vae: WanVideoVAE = None
         | 
| 50 | 
            +
                    self.model_names = ['text_encoder', 'dit', 'vae']
         | 
| 51 | 
            +
                    self.height_division_factor = 16
         | 
| 52 | 
            +
                    self.width_division_factor = 16
         | 
| 53 | 
            +
             | 
| 54 | 
            +
             | 
| 55 | 
            +
                def enable_vram_management(self, num_persistent_param_in_dit=None):
         | 
| 56 | 
            +
                    dtype = next(iter(self.text_encoder.parameters())).dtype
         | 
| 57 | 
            +
                    enable_vram_management(
         | 
| 58 | 
            +
                        self.text_encoder,
         | 
| 59 | 
            +
                        module_map = {
         | 
| 60 | 
            +
                            torch.nn.Linear: AutoWrappedLinear,
         | 
| 61 | 
            +
                            torch.nn.Embedding: AutoWrappedModule,
         | 
| 62 | 
            +
                            T5RelativeEmbedding: AutoWrappedModule,
         | 
| 63 | 
            +
                            T5LayerNorm: AutoWrappedModule,
         | 
| 64 | 
            +
                        },
         | 
| 65 | 
            +
                        module_config = dict(
         | 
| 66 | 
            +
                            offload_dtype=dtype,
         | 
| 67 | 
            +
                            offload_device="cpu",
         | 
| 68 | 
            +
                            onload_dtype=dtype,
         | 
| 69 | 
            +
                            onload_device="cpu",
         | 
| 70 | 
            +
                            computation_dtype=self.torch_dtype,
         | 
| 71 | 
            +
                            computation_device=self.device,
         | 
| 72 | 
            +
                        ),
         | 
| 73 | 
            +
                    )
         | 
| 74 | 
            +
                    dtype = next(iter(self.dit.parameters())).dtype
         | 
| 75 | 
            +
                    enable_vram_management(
         | 
| 76 | 
            +
                        self.dit,
         | 
| 77 | 
            +
                        module_map = {
         | 
| 78 | 
            +
                            torch.nn.Linear: AutoWrappedLinear,
         | 
| 79 | 
            +
                            torch.nn.Conv3d: AutoWrappedModule,
         | 
| 80 | 
            +
                            torch.nn.LayerNorm: AutoWrappedModule,
         | 
| 81 | 
            +
                            WanLayerNorm: AutoWrappedModule,
         | 
| 82 | 
            +
                            WanRMSNorm: AutoWrappedModule,
         | 
| 83 | 
            +
                        },
         | 
| 84 | 
            +
                        module_config = dict(
         | 
| 85 | 
            +
                            offload_dtype=dtype,
         | 
| 86 | 
            +
                            offload_device="cpu",
         | 
| 87 | 
            +
                            onload_dtype=dtype,
         | 
| 88 | 
            +
                            onload_device=self.device,
         | 
| 89 | 
            +
                            computation_dtype=self.torch_dtype,
         | 
| 90 | 
            +
                            computation_device=self.device,
         | 
| 91 | 
            +
                        ),
         | 
| 92 | 
            +
                        max_num_param=num_persistent_param_in_dit,
         | 
| 93 | 
            +
                        overflow_module_config = dict(
         | 
| 94 | 
            +
                            offload_dtype=dtype,
         | 
| 95 | 
            +
                            offload_device="cpu",
         | 
| 96 | 
            +
                            onload_dtype=dtype,
         | 
| 97 | 
            +
                            onload_device="cpu",
         | 
| 98 | 
            +
                            computation_dtype=self.torch_dtype,
         | 
| 99 | 
            +
                            computation_device=self.device,
         | 
| 100 | 
            +
                        ),
         | 
| 101 | 
            +
                    )
         | 
| 102 | 
            +
                    dtype = next(iter(self.vae.parameters())).dtype
         | 
| 103 | 
            +
                    enable_vram_management(
         | 
| 104 | 
            +
                        self.vae,
         | 
| 105 | 
            +
                        module_map = {
         | 
| 106 | 
            +
                            torch.nn.Linear: AutoWrappedLinear,
         | 
| 107 | 
            +
                            torch.nn.Conv2d: AutoWrappedModule,
         | 
| 108 | 
            +
                            RMS_norm: AutoWrappedModule,
         | 
| 109 | 
            +
                            CausalConv3d: AutoWrappedModule,
         | 
| 110 | 
            +
                            Upsample: AutoWrappedModule,
         | 
| 111 | 
            +
                            torch.nn.SiLU: AutoWrappedModule,
         | 
| 112 | 
            +
                            torch.nn.Dropout: AutoWrappedModule,
         | 
| 113 | 
            +
                        },
         | 
| 114 | 
            +
                        module_config = dict(
         | 
| 115 | 
            +
                            offload_dtype=dtype,
         | 
| 116 | 
            +
                            offload_device="cpu",
         | 
| 117 | 
            +
                            onload_dtype=dtype,
         | 
| 118 | 
            +
                            onload_device=self.device,
         | 
| 119 | 
            +
                            computation_dtype=self.torch_dtype,
         | 
| 120 | 
            +
                            computation_device=self.device,
         | 
| 121 | 
            +
                        ),
         | 
| 122 | 
            +
                    )
         | 
| 123 | 
            +
                    if self.image_encoder is not None:
         | 
| 124 | 
            +
                        dtype = next(iter(self.image_encoder.parameters())).dtype
         | 
| 125 | 
            +
                        enable_vram_management(
         | 
| 126 | 
            +
                            self.image_encoder,
         | 
| 127 | 
            +
                            module_map = {
         | 
| 128 | 
            +
                                torch.nn.Linear: AutoWrappedLinear,
         | 
| 129 | 
            +
                                torch.nn.Conv2d: AutoWrappedModule,
         | 
| 130 | 
            +
                                torch.nn.LayerNorm: AutoWrappedModule,
         | 
| 131 | 
            +
                            },
         | 
| 132 | 
            +
                            module_config = dict(
         | 
| 133 | 
            +
                                offload_dtype=dtype,
         | 
| 134 | 
            +
                                offload_device="cpu",
         | 
| 135 | 
            +
                                onload_dtype=dtype,
         | 
| 136 | 
            +
                                onload_device="cpu",
         | 
| 137 | 
            +
                                computation_dtype=self.torch_dtype,
         | 
| 138 | 
            +
                                computation_device=self.device,
         | 
| 139 | 
            +
                            ),
         | 
| 140 | 
            +
                        )
         | 
| 141 | 
            +
                    self.enable_cpu_offload()
         | 
| 142 | 
            +
             | 
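
`enable_vram_management` wraps the large submodules so that weights stay offloaded (mostly on the CPU) and are streamed to the GPU for computation; `num_persistent_param_in_dit` caps how many DiT parameters remain resident on the GPU. A hedged usage sketch; the parameter budget is an illustrative value:

    pipe = WanVideoPipeline(device="cuda", torch_dtype=torch.bfloat16)
    # ... fetch/load the text encoder, DiT and VAE into `pipe` first ...
    pipe.enable_vram_management(num_persistent_param_in_dit=6 * 10**9)  # keep ~6B DiT params on the GPU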
    def fetch_models_from_model_manager(self, model_manager: ModelManager):
        text_encoder_model_and_path = model_manager.fetch_model("wan_video_text_encoder", require_model_path=True)
        if text_encoder_model_and_path is not None:
            self.text_encoder, tokenizer_path = text_encoder_model_and_path
            self.prompter.fetch_models(self.text_encoder)
            self.prompter.fetch_tokenizer(os.path.join(os.path.dirname(tokenizer_path), "google/umt5-xxl"))
        self.dit = model_manager.fetch_model("wan_video_dit")
        self.vae = model_manager.fetch_model("wan_video_vae")
        self.image_encoder = model_manager.fetch_model("wan_video_image_encoder")

    def _init_component_from_checkpoint_path(self, model_cls, state_dict_path, strict=True, config_dict=None):
        config = {}
        # Guard against `config_dict=None`: membership tests below would otherwise raise a TypeError.
        config_dict = config_dict or {}
        state_dict = self._load_state_dict(state_dict_path)
        if hasattr(model_cls, "state_dict_converter"):
            state_dict_converter = model_cls.state_dict_converter()
            state_dict = state_dict_converter.from_civitai(state_dict)
            if isinstance(state_dict, tuple):
                state_dict, config = state_dict
        config.update(config_dict)
        model = model_cls(**config)
        if "use_local_lora" in config_dict or "use_dera" in config_dict:
            strict = False
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=strict)
        print(f"Missing keys: {missing_keys}")
        print(f"Unexpected keys: {unexpected_keys}")
        return model

    def _load_state_dict(self, state_dict_paths):
        if isinstance(state_dict_paths, str):
            state_dict_paths = [state_dict_paths]
        state_dict = {}
        for state_dict_path in tqdm(state_dict_paths, desc="Reading file(s) from disk"):
            state_dict.update(self._load_single_file(state_dict_path))
        return state_dict

    def _load_single_file(self, file_path):
        if file_path.endswith(".safetensors"):
            return self._load_state_dict_from_safetensors(file_path)
        else:
            return torch.load(file_path, map_location='cpu')

    def _load_state_dict_from_safetensors(self, file_path, torch_dtype=None):
        state_dict = {}
        with safe_open(file_path, framework="pt", device="cpu") as f:
            for k in f.keys():
                state_dict[k] = f.get_tensor(k)
                if torch_dtype is not None:
                    state_dict[k] = state_dict[k].to(torch_dtype)
        return state_dict

    def initialize_dummy_dit(self, config):
        print("Initializing a dummy DIT model.")
        self.dit = WanModel(**config)
        print("Dummy DIT model is initialized.")

    def fetch_models_from_checkpoints(self, path_dict, config_dict=None):
        default_config = {"text_encoder": {}, "dit": {}, "vae": {}, "image_encoder": {}}
        config_dict = {**default_config, **(config_dict or {})}
        components = {
            "text_encoder": WanTextEncoder,
            "dit": WanModel,
            "vae": WanVideoVAE,
            "image_encoder": WanImageEncoder,
        }
        for name, model_cls in components.items():
            if name not in path_dict:
                print(f"Component {name} is not found in the checkpoint path dict. Skipping.")
                continue
            path = path_dict[name]
            config = config_dict.get(name, {})
            print(f"Loading {name} from {path} with config {config}.")
            setattr(self, name, self._init_component_from_checkpoint_path(model_cls, path, config_dict=config))
            print(f"Initialized {name} from checkpoint.")
        if "text_encoder" in path_dict:
            self.prompter.fetch_models(self.text_encoder)
            self.prompter.fetch_tokenizer(os.path.join(os.path.dirname(path_dict["text_encoder"]), "google/umt5-xxl"))
        print("Initialized prompter from checkpoint.")
        print("All components are initialized from checkpoints.")

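The checkpoint loader above only needs a name-to-path mapping plus optional per-component constructor kwargs. A minimal usage sketch follows; the checkpoint file names, the `use_dera` flag value, and the import path are placeholders chosen for illustration, not files or settings shipped with this commit.

# Hypothetical sketch: driving fetch_models_from_checkpoints() with placeholder paths.
import torch
from pipeline.i2v_pipeline import WanVideoPipeline  # assumed module path for the class defined in this file

pipe = WanVideoPipeline(device="cuda", torch_dtype=torch.bfloat16)
pipe.fetch_models_from_checkpoints(
    path_dict={
        "text_encoder": "checkpoints/wan_text_encoder.safetensors",   # tokenizer is looked up at checkpoints/google/umt5-xxl
        "dit": "checkpoints/wan_dit.safetensors",
        "vae": "checkpoints/wan_vae.safetensors",
        "image_encoder": "checkpoints/wan_image_encoder.safetensors",
    },
    config_dict={
        # Extra kwargs are forwarded to each component's constructor; flags such as
        # `use_dera` / `use_local_lora` also relax strict state-dict loading above.
        "dit": {"use_dera": True},
    },
)
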
    @staticmethod
    def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None):
        if device is None: device = model_manager.device
        if torch_dtype is None: torch_dtype = model_manager.torch_dtype
        pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
        pipe.fetch_models_from_model_manager(model_manager)
        return pipe

    def denoising_model(self):
        return self.dit

    def encode_prompt(self, prompt, positive=True):
        prompt_emb = self.prompter.encode_prompt(prompt, positive=positive)
        return {"context": prompt_emb}

            +
                def check_and_fix_image_or_video_tensor_input(self, _tensor):
         | 
| 251 | 
            +
                    assert isinstance(_tensor, torch.Tensor), "Input must be a tensor."
         | 
| 252 | 
            +
                    if _tensor.max() <= 255 and _tensor.max() > 1.0:
         | 
| 253 | 
            +
                        _tensor = _tensor.to(self.device) / 127.5 - 1
         | 
| 254 | 
            +
                        print("Input tensor is converted from [0, 255] to [-1, 1].")
         | 
| 255 | 
            +
                    elif _tensor.min() >= 0 and _tensor.max() <= 1:
         | 
| 256 | 
            +
                        _tensor = _tensor.to(self.device) * 2 - 1
         | 
| 257 | 
            +
                        print("Input tensor is converted from [0, 1] to [-1, 1].")
         | 
| 258 | 
            +
                    return _tensor
         | 
| 259 | 
            +
                
         | 
    def encode_video_with_mask(self, video, num_frames, height, width, condition_preserved_mask):
        with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
            video = video.to(self.device)
            y = self.vae.encode(video, device=self.device)
            msk = condition_preserved_mask
            assert msk is not None, "The mask must be provided for the masked video input."
            assert msk.dim() == 2, "The mask must be a 2D tensor in [b, t]."
            assert msk.shape[0] == video.shape[0], "The batch size of the mask must be the same as the input video."
            assert msk.shape[1] == num_frames, "The number of frames in the mask must be the same as the input video."
            msk = msk.to(self.device)
            msk = msk.unsqueeze(-1).unsqueeze(-1)
            msk = repeat(msk, 'b t 1 1 -> b t h w', h=height//8, w=width//8)
            msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
            msk = msk.view(video.shape[0], msk.shape[1] // 4, 4, height//8, width//8)  # b, t, c, h, w
            msk = msk.transpose(1, 2)  # b, c, t, h, w
            y = torch.concat([msk, y], dim=1)
        return y

    def encode_video_with_mask_sparse(self, video, height, width, condition_preserved_mask, sketch_local_mask=None):
        with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
            batch_size = video.shape[0]
            cond_indices = binary_tensor_to_indices(condition_preserved_mask)
            sequence_cond_compressed_indices = [(cond_index + 3) // 4 for cond_index in cond_indices]
            video = video.to(self.device)
            video_latent = self.vae.encode(video, device=self.device)
            video_latent = video_latent[:, :, sequence_cond_compressed_indices[0], :, :]
            msk = condition_preserved_mask.to(self.device)
            msk = msk.unsqueeze(-1).unsqueeze(-1)  # b, t, 1, 1
            msk = repeat(msk, 'b t 1 1 -> b t h w', h=height//8, w=width//8)
            msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
            msk = msk.view(batch_size, msk.shape[1] // 4, 4, height//8, width//8)  # b, t, 4, h//8, w//8
            msk = msk.transpose(1, 2)  # b, 4, t, h//8, w//8
            msk = msk[:, :, sequence_cond_compressed_indices[0], :, :]

            if sketch_local_mask is not None:
                sketch_local_mask = sketch_local_mask.to(self.device)
                if sketch_local_mask.shape[-2:] != (height//8, width//8):
                    sk_batch_t = sketch_local_mask.shape[0] * sketch_local_mask.shape[2]
                    sketch_local_mask_reshaped = sketch_local_mask.reshape(sk_batch_t, 1, sketch_local_mask.shape[3], sketch_local_mask.shape[4])
                    sketch_local_mask_resized = torch.nn.functional.interpolate(
                        sketch_local_mask_reshaped,
                        size=(height//8, width//8),
                        mode='nearest'
                    )
                    sketch_local_mask_resized = sketch_local_mask_resized.reshape(
                        sketch_local_mask.shape[0],
                        sketch_local_mask.shape[1],
                        sketch_local_mask.shape[2],
                        height//8, width//8
                    )
                else:
                    sketch_local_mask_resized = sketch_local_mask

                sketch_mask = sketch_local_mask_resized
                sketch_mask = torch.concat([torch.repeat_interleave(sketch_mask[:, :, 0:1], repeats=4, dim=2), sketch_mask[:, :, 1:]], dim=2)
                sketch_mask = sketch_mask.view(batch_size, sketch_mask.shape[1], sketch_mask.shape[2] // 4, 4, height//8, width//8)
                sketch_mask = sketch_mask.permute(0, 1, 3, 2, 4, 5)  # [b, 1, 4, t//4, h//8, w//8]
                sketch_mask = sketch_mask.view(batch_size, 4, sketch_mask.shape[3], height//8, width//8)  # [b, 4, t//4, h//8, w//8]
                sketch_mask = sketch_mask[:, :, sequence_cond_compressed_indices[0], :, :]  # [b, 4, len(indices), h//8, w//8]

                combined_latent = torch.cat([msk, video_latent, sketch_mask], dim=1)
            else:
                combined_latent = torch.concat([msk, video_latent], dim=1)

        return combined_latent, sequence_cond_compressed_indices  # b, c=(4+16+4=24), t, h, w when sketch_local_mask is provided

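The two methods above pack a per-frame keep-mask into the latent grid: the first frame is repeated 4 times so that the pixel-space frame count lines up with the VAE's 4x temporal compression, and groups of 4 frames are folded into a channel axis. The following standalone sketch reproduces that packing with toy values (the 81-frame, 480x832 setting mirrors this pipeline's defaults and is used here purely for illustration):

import torch
from einops import repeat

b, t, height, width = 1, 81, 480, 832              # defaults used by this pipeline
msk = torch.zeros(b, t)
msk[:, 0] = 1                                      # keep only the first frame as a condition
msk = repeat(msk.unsqueeze(-1).unsqueeze(-1), 'b t 1 1 -> b t h w', h=height // 8, w=width // 8)
# Repeat the first frame 4x so t maps onto the VAE's temporal compression (4 frames -> 1 latent frame).
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
msk = msk.view(b, msk.shape[1] // 4, 4, height // 8, width // 8).transpose(1, 2)
print(msk.shape)  # torch.Size([1, 4, 21, 60, 104]); 21 latent frames = (81 - 1) // 4 + 1
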
    def encode_image_or_masked_video(self, image_or_masked_video, num_frames, height, width, condition_preserved_mask=None):
        with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
            if isinstance(image_or_masked_video, PIL.Image.Image) or (isinstance(image_or_masked_video, torch.Tensor) and image_or_masked_video.dim() <= 4):
                if isinstance(image_or_masked_video, PIL.Image.Image):
                    image_or_masked_video = self.preprocess_image(image_or_masked_video.resize((width, height))).to(self.device)
                else:
                    if image_or_masked_video.dim() == 3:
                        image_or_masked_video = image_or_masked_video.unsqueeze(0)  # b=1, c, h, w
                    image_or_masked_video = image_or_masked_video.to(self.device)
                # Compute the batch size only after the input has been converted to a [b, c, h, w] tensor:
                # a PIL image has no `.shape`, and a 3D tensor only gains its batch dimension above.
                batch_size = image_or_masked_video.shape[0]
                y = self.vae.encode([torch.concat([image_or_masked_video.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image_or_masked_video.device)], dim=1)], device=self.device)
                msk_idx_to_be_zero = range(1, num_frames)
                clip_context = self.image_encoder.encode_image(image_or_masked_video.unsqueeze(1))  # needs to be [b, 1, c, h, w]
                msk = torch.ones(batch_size, num_frames, height//8, width//8, device=self.device)
                msk[:, msk_idx_to_be_zero] = 0
                msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
                msk = msk.view(batch_size, msk.shape[1] // 4, 4, height//8, width//8)
                msk = msk.transpose(1, 2)
            elif isinstance(image_or_masked_video, torch.Tensor) and image_or_masked_video.dim() == 5:
                batch_size = image_or_masked_video.shape[0]
                image_or_masked_video = image_or_masked_video.to(self.device)
                first_image = image_or_masked_video[:, :, 0, :, :].unsqueeze(1)
                clip_context = self.image_encoder.encode_image(first_image)
                y = self.vae.encode(image_or_masked_video, device=self.device)
                msk = condition_preserved_mask  # b, t
                assert msk is not None, "The mask must be provided for the masked video input."
                assert msk.dim() == 2, "The mask must be a 2D tensor in [b, t]."
                assert msk.shape[0] == batch_size, "The batch size of the mask must be the same as the input video."
                assert msk.shape[1] == num_frames, "The number of frames in the mask must be the same as the input video."
                msk = msk.to(self.device)
                msk = msk.unsqueeze(-1).unsqueeze(-1)  # b, t, 1, 1
                msk = repeat(msk, 'b t 1 1 -> b t h w', h=height//8, w=width//8)
                msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
                msk = msk.view(batch_size, msk.shape[1] // 4, 4, height//8, width//8)  # b, t, 4, h//8, w//8
                msk = msk.transpose(1, 2)  # b, 4, t, h//8, w//8
            else:
                raise ValueError("Input must be an image (PIL/Tensor in [b, c, h, w]) or a masked video (Tensor in [b, c, t, h, w]).")

        y = torch.concat([msk, y], dim=1)
        return {"clip_fea": clip_context, "y": y}

    def tensor2video(self, frames):
        frames = rearrange(frames, "C T H W -> T H W C")
        frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
        frames = [PIL.Image.fromarray(frame) for frame in frames]
        return frames

    def prepare_extra_input(self, latents=None):
        return {"seq_len": latents.shape[2] * latents.shape[3] * latents.shape[4] // 4}

    def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
            latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        return latents

    def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
            frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        return frames

                @torch.no_grad()
         | 
| 386 | 
            +
                def __call__(
         | 
| 387 | 
            +
                    self,
         | 
| 388 | 
            +
                    prompt,
         | 
| 389 | 
            +
                    negative_prompt="",
         | 
| 390 | 
            +
                    input_image=None,
         | 
| 391 | 
            +
                    input_video=None,
         | 
| 392 | 
            +
                    denoising_strength=1.0,
         | 
| 393 | 
            +
                    seed=None,
         | 
| 394 | 
            +
                    rand_device="cpu",
         | 
| 395 | 
            +
                    height=480,
         | 
| 396 | 
            +
                    width=832,
         | 
| 397 | 
            +
                    num_frames=81,
         | 
| 398 | 
            +
                    cfg_scale=5.0,
         | 
| 399 | 
            +
                    num_inference_steps=50,
         | 
| 400 | 
            +
                    sigma_shift=5.0,
         | 
| 401 | 
            +
                    tiled=True,
         | 
| 402 | 
            +
                    tile_size=(30, 52),
         | 
| 403 | 
            +
                    tile_stride=(15, 26),
         | 
| 404 | 
            +
                    progress_bar_cmd=tqdm,
         | 
| 405 | 
            +
                    # progress_bar_st=None,
         | 
| 406 | 
            +
                    input_condition_video=None,
         | 
| 407 | 
            +
                    input_condition_preserved_mask=None,
         | 
| 408 | 
            +
                    input_condition_video_sketch=None,
         | 
| 409 | 
            +
                    input_condition_preserved_mask_sketch=None,
         | 
| 410 | 
            +
                    sketch_local_mask=None,
         | 
| 411 | 
            +
                    visualize_attention=False,
         | 
| 412 | 
            +
                    output_path=None,
         | 
| 413 | 
            +
                    batch_idx=None,
         | 
| 414 | 
            +
                    sequence_cond_residual_scale=1.0,
         | 
| 415 | 
            +
                ):
         | 
| 416 | 
            +
                    height, width = self.check_resize_height_width(height, width)
         | 
| 417 | 
            +
                    if num_frames % 4 != 1:
         | 
| 418 | 
            +
                        num_frames = (num_frames + 2) // 4 * 4 + 1
         | 
| 419 | 
            +
                        print(f"Only `num_frames % 4 != 1` is acceptable. We round it up to {num_frames}.")
         | 
| 420 | 
            +
                    
         | 
| 421 | 
            +
                    tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
         | 
| 422 | 
            +
             | 
| 423 | 
            +
                    self.scheduler.set_timesteps(num_inference_steps, denoising_strength, shift=sigma_shift)
         | 
| 424 | 
            +
             | 
| 425 | 
            +
                    noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=torch.float32).to(self.device)
         | 
| 426 | 
            +
                    if input_video is not None:
         | 
| 427 | 
            +
                        self.load_models_to_device(['vae'])
         | 
| 428 | 
            +
                        input_video = self.preprocess_images(input_video)
         | 
| 429 | 
            +
                        input_video = torch.stack(input_video, dim=2)
         | 
| 430 | 
            +
                        latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=noise.dtype, device=noise.device)
         | 
| 431 | 
            +
                        latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
         | 
| 432 | 
            +
                    else:
         | 
| 433 | 
            +
                        latents = noise
         | 
| 434 | 
            +
                    
         | 
| 435 | 
            +
                    self.load_models_to_device(["text_encoder"])
         | 
| 436 | 
            +
                    prompt_emb_posi = self.encode_prompt(prompt, positive=True)
         | 
| 437 | 
            +
                    if cfg_scale != 1.0:
         | 
| 438 | 
            +
                        prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
         | 
| 439 | 
            +
                        
         | 
| 440 | 
            +
                    self.load_models_to_device(["image_encoder", "vae"])
         | 
| 441 | 
            +
                    if input_image is not None and self.image_encoder is not None:
         | 
| 442 | 
            +
                        image_emb = self.encode_image(input_image, num_frames, height, width)
         | 
| 443 | 
            +
                    elif input_condition_video is not None and self.image_encoder is not None:
         | 
| 444 | 
            +
                        assert input_condition_preserved_mask is not None, "`input_condition_preserved_mask` must not be None when `input_condition_video` is given."
         | 
| 445 | 
            +
                        image_emb = self.encode_image_or_masked_video(input_condition_video, num_frames, height, width, input_condition_preserved_mask)
         | 
| 446 | 
            +
                    else:
         | 
| 447 | 
            +
                        image_emb = {}
         | 
| 448 | 
            +
                        
         | 
| 449 | 
            +
                    # Extra input
         | 
| 450 | 
            +
                    extra_input = self.prepare_extra_input(latents)
         | 
| 451 | 
            +
                    if self.dit.use_sequence_cond:
         | 
| 452 | 
            +
                        assert input_condition_video_sketch is not None, "`input_condition_video_sketch` must not be None when `use_sequence_cond` is True."
         | 
| 453 | 
            +
                        assert input_condition_preserved_mask_sketch is not None, "`input_condition_preserved_mask_sketch` must not be None when `input_condition_video_sketch` is given."
         | 
| 454 | 
            +
                        
         | 
| 455 | 
            +
                        if self.dit.sequence_cond_mode == "sparse":
         | 
| 456 | 
            +
                            sequence_cond, sequence_cond_compressed_indices = self.encode_video_with_mask_sparse(input_condition_video_sketch, height, width, input_condition_preserved_mask_sketch, sketch_local_mask)
         | 
| 457 | 
            +
                            extra_input.update({"sequence_cond": sequence_cond,
         | 
| 458 | 
            +
                                                "sequence_cond_compressed_indices": sequence_cond_compressed_indices})
         | 
| 459 | 
            +
                        elif self.dit.sequence_cond_mode == "full":
         | 
| 460 | 
            +
                            sequence_cond = self.encode_video_with_mask(input_condition_video_sketch, num_frames, height, width, input_condition_preserved_mask_sketch)
         | 
| 461 | 
            +
                            extra_input.update({"sequence_cond": sequence_cond})
         | 
| 462 | 
            +
                        else:
         | 
| 463 | 
            +
                            raise ValueError(f"Invalid `sequence_cond_model`={self.dit.sequence_cond_mode} in the DIT model.")
         | 
| 464 | 
            +
                        
         | 
| 465 | 
            +
                    elif self.dit.use_channel_cond:
         | 
| 466 | 
            +
                        sequence_cond = self.encode_video_with_mask(input_condition_video_sketch, num_frames, height, width, input_condition_preserved_mask_sketch)
         | 
| 467 | 
            +
                        extra_input.update({"channel_cond": sequence_cond})
         | 
| 468 | 
            +
                        
         | 
| 469 | 
            +
                    self.load_models_to_device([])
         | 
| 470 | 
            +
                    
         | 
| 471 | 
            +
                    if sequence_cond_residual_scale != 1.0:
         | 
| 472 | 
            +
                        extra_input.update({"sequence_cond_residual_scale": sequence_cond_residual_scale})
         | 
| 473 | 
            +
             | 
| 474 | 
            +
                    # Denoise
         | 
| 475 | 
            +
                    self.load_models_to_device(["dit"])
         | 
| 476 | 
            +
                    with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
         | 
| 477 | 
            +
                        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
         | 
| 478 | 
            +
                            timestep = timestep.unsqueeze(0).to(dtype=torch.float32, device=self.device)
         | 
| 479 | 
            +
                            _should_visualize_attention = visualize_attention and (progress_id == len(self.scheduler.timesteps) - 1)
         | 
| 480 | 
            +
                            if _should_visualize_attention:
         | 
| 481 | 
            +
                                print(f"Visualizing attention maps (Step {progress_id + 1}/{len(self.scheduler.timesteps)}).")
         | 
| 482 | 
            +
                                propagate_visualize_attention_arg(self.dit, True)
         | 
| 483 | 
            +
                    
         | 
| 484 | 
            +
                            # Inference
         | 
| 485 | 
            +
                            noise_pred_posi = self.dit(latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input)
         | 
| 486 | 
            +
                            if isinstance(noise_pred_posi, tuple):
         | 
| 487 | 
            +
                                noise_pred_posi = noise_pred_posi[0]
         | 
| 488 | 
            +
                            if cfg_scale != 1.0:
         | 
| 489 | 
            +
                                noise_pred_nega = self.dit(latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input)
         | 
| 490 | 
            +
                                if isinstance(noise_pred_nega, tuple):
         | 
| 491 | 
            +
                                    noise_pred_nega = noise_pred_nega[0]
         | 
| 492 | 
            +
                                noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
         | 
| 493 | 
            +
                            else:
         | 
| 494 | 
            +
                                noise_pred = noise_pred_posi
         | 
| 495 | 
            +
             | 
| 496 | 
            +
                            # Scheduler
         | 
| 497 | 
            +
                            latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
         | 
| 498 | 
            +
                            
         | 
| 499 | 
            +
                            # If visualization is enabled, save the attention maps
         | 
| 500 | 
            +
                            if _should_visualize_attention:
         | 
| 501 | 
            +
                                print("Saving attention maps...")
         | 
| 502 | 
            +
                                from util.model_util import save_attention_maps
         | 
| 503 | 
            +
                                save_attention_maps(self.dit, output_path, batch_idx, timestep.squeeze().cpu().numpy().item())
         | 
| 504 | 
            +
                                propagate_visualize_attention_arg(self.dit, False)
         | 
| 505 | 
            +
             | 
| 506 | 
            +
                    # Decode
         | 
| 507 | 
            +
                    self.load_models_to_device(['vae'])
         | 
| 508 | 
            +
                    frames = self.decode_video(latents, **tiler_kwargs)
         | 
| 509 | 
            +
                    self.load_models_to_device([])
         | 
| 510 | 
            +
             | 
| 511 | 
            +
                    return frames
         | 
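Putting the pieces together, a minimal end-to-end call might look like the sketch below. It assumes `pipe` is a WanVideoPipeline whose models were loaded as shown earlier and whose DiT has `use_sequence_cond` enabled; the sketch tensor, keep-mask, prompt, and frame indices are placeholders, while the shape conventions ([b, t] masks, [b, c, t, h, w] sketch video) follow the assertions in the methods above.

import torch
from PIL import Image

num_frames, height, width = 81, 480, 832
sketch_video = torch.zeros(1, 3, num_frames, height, width)   # placeholder sketch frames in [-1, 1]
sketch_keep_mask = torch.zeros(1, num_frames)
sketch_keep_mask[:, [0, 40, 80]] = 1                          # frames that actually carry a sketch

frames = pipe(
    prompt="A hand-drawn character turns around. Anime. High quality.",
    input_image=Image.open("samples/1_image1.png"),           # colored keyframe shipped with this Space
    input_condition_video_sketch=sketch_video,
    input_condition_preserved_mask_sketch=sketch_keep_mask,
    num_frames=num_frames, height=height, width=width,
    cfg_scale=5.0, num_inference_steps=50, seed=0,
)
# Assuming the decoder returns one [C, T, H, W] item per batch element,
# tensor2video() converts it into a list of PIL frames for saving.
video_frames = pipe.tensor2video(frames[0])
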
    	
        requirements.txt
    ADDED
    
absl-py==2.2.2
accelerate==1.6.0
beartype==0.20.2
beautifulsoup4==4.13.4
braceexpand==0.1.7
cached-property==2.0.1
certifi==2025.1.31
charset-normalizer==3.4.1
click==8.1.8
clip==0.2.0
comm==0.2.3
contourpy==1.3.2
controlnet_aux==0.0.7
crcmod==1.7
cycler==0.12.1
datasets==3.5.0
debugpy==1.8.15
decorator==5.2.1
decord==0.6.0
deepspeed==0.16.7
diffsynth==1.1.7
diffusers==0.33.1
dill==0.3.8
docker-pycreds==0.4.0
dulwich==0.22.8
easydict==1.13
einops==0.8.1
exceptiongroup==1.2.2
executing==2.2.0
fairscale==0.4.13
fastapi==0.115.12
fastrlock==0.8.3
ffmpy==0.5.0
filelock==3.13.1
flash_attn==2.8.0.post2 --global-option="--no-build-isolation"
fonttools==4.57.0
frozenlist==1.6.0
fsspec==2024.12.0
ftfy==6.3.1
func_timeout==4.3.5
fuzzywuzzy==0.18.0
gitdb==4.0.12
GitPython==3.1.44
gradio==5.25.2
gradio_client==1.8.0
groovy==0.1.2
grpcio==1.71.0
h11==0.14.0
hjson==3.1.0
httpcore==1.0.8
httpx==0.28.1
huggingface-hub==0.30.2
idna==3.10
imageio==2.37.0
imageio-ffmpeg==0.6.0
importlib_metadata==8.6.1
ipykernel==6.30.0
ipython==8.37.0
jedi==0.19.2
Jinja2==3.1.4
joblib==1.4.2
kiwisolver==1.4.8
kornia==0.8.0
kornia_rs==0.1.8
lazy_loader==0.4
lightning==2.5.1
lightning-utilities==0.14.3
lpips==0.1.4
matplotlib==3.10.1
matplotlib-inline==0.1.7
mdurl==0.1.2
modelscope==1.25.0
moviepy==2.1.2
mpmath==1.3.0
msgpack==1.1.0
multidict==6.4.3
multiprocess==0.70.16
ninja==1.11.1.4
numpy==2.2.5
omegaconf==2.3.0
opencv-python==4.11.0.86
orjson==3.10.16
packaging==24.2
pandas==2.2.3
parso==0.8.4
peft==0.15.2
pexpect==4.9.0
pillow==10.4.0
platformdirs==4.3.7
proglog==0.1.11
prompt_toolkit==3.0.51
propcache==0.3.1
protobuf==5.29.4
psutil==7.0.0
ptyprocess==0.7.0
pure_eval==0.2.3
py-cpuinfo==9.0.0
pyarrow==19.0.1
pycryptodome==3.22.0
pydantic==2.11.3
pydantic_core==2.33.1
pydub==0.25.1
Pygments==2.19.1
pynvml==12.0.0
pyparsing==3.2.3
python-dateutil==2.9.0.post0
python-dotenv==1.1.0
python-multipart==0.0.20
pytorch-fid==0.3.0
pytorch-lightning==2.5.1
pytz==2025.2
PyYAML==6.0.2
pyzmq==27.0.0
regex==2024.11.6
requests==2.32.3
rich==14.0.0
ruff==0.11.6
safehttpx==0.1.6
safetensors==0.5.3
scikit-image==0.25.2
scikit-learn==1.6.1
scipy==1.15.2
semantic-version==2.10.0
sentencepiece==0.2.0
sentry-sdk==2.26.1
setproctitle==1.3.5
shellingham==1.5.4
simplejson==3.20.1
six==1.17.0
smmap==5.0.2
sniffio==1.3.1
soupsieve==2.7
stack-data==0.6.3
starlette==0.46.2
sympy==1.13.1
taming-transformers==0.0.1
tensorboard==2.19.0
tokenizers==0.20.3
torch==2.6.0
torchaudio==2.6.0
torchdiffeq==0.2.5
torchmetrics==1.7.1
torchsde==0.2.6
torchvision==0.21.0
tqdm==4.67.1
transformers==4.46.2
triton==3.2.0
xformers==0.0.29.post2
    	
samples/1_image1.png
ADDED
Git LFS Details

samples/1_out.mp4
ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:fa51ac0653a18dc20b9a6946aaa1a7923d58fe291e926908703c300a4d13c4a2
size 356550

samples/1_prompt.txt
ADDED
['Under the sea, a bare-chested man plays with a blue fish that swims in spirals. The whale circles around, following the bag held by the man, who uses the bag to lure the blue fish to swim forward. Anime. High quality.']

samples/1_sketch1.jpg
ADDED
Git LFS Details

samples/1_sketch2.jpg
ADDED
Git LFS Details

samples/1_sketch3.jpg
ADDED
Git LFS Details

samples/2_image1.jpg
ADDED
Git LFS Details

samples/2_out.mp4
ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:c9f28ab63b4fc5b07c0ed01f715ec671f8d839b8783dc8a432c7764bd35605f5
size 151565

samples/2_prompt.txt
ADDED
['A girl and a silver-haired boy plant a huge flower; as the camera slowly moves upward, the giant flower keeps growing larger and blooms. Anime. High quality.']

samples/2_sketch1.jpg
ADDED
Git LFS Details

samples/2_sketch2.jpg
ADDED
Git LFS Details

samples/3_image1.png
ADDED
Git LFS Details

samples/3_out.mp4
ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:fdb131043289d831f7c4e0d3dd4a21ecc3c4eecca1bf3ae539bb14414c439cde
size 87909

samples/3_prompt.txt
ADDED
['A boy in ancient China holds an apple and, with a smile, offers it to the elderly man beside him. Anime. High quality.']

samples/3_sketch1.jpg
ADDED
Git LFS Details

samples/ToonComposer-Icon.png
ADDED
Git LFS Details

samples/ToonComposer-Method.jpg
ADDED
Git LFS Details

samples/ToonComposer-TLDR.jpg
ADDED
Git LFS Details

scheduler/__init__.py
ADDED
File without changes
    	
        scheduler/flow_match.py
    ADDED
    
    | @@ -0,0 +1,78 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
             | 
| 3 | 
            +
             | 
| 4 | 
            +
            class FlowMatchScheduler():
         | 
| 5 | 
            +
             | 
| 6 | 
            +
                def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003/1.002, inverse_timesteps=False, extra_one_step=False, reverse_sigmas=False):
         | 
| 7 | 
            +
                    self.num_train_timesteps = num_train_timesteps
         | 
| 8 | 
            +
                    self.shift = shift
         | 
| 9 | 
            +
                    self.sigma_max = sigma_max
         | 
| 10 | 
            +
                    self.sigma_min = sigma_min
         | 
| 11 | 
            +
                    self.inverse_timesteps = inverse_timesteps
         | 
| 12 | 
            +
                    self.extra_one_step = extra_one_step
         | 
| 13 | 
            +
                    self.reverse_sigmas = reverse_sigmas
         | 
| 14 | 
            +
                    self.set_timesteps(num_inference_steps)
         | 
| 15 | 
            +
             | 
| 16 | 
            +
             | 
| 17 | 
            +
                def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None):
         | 
| 18 | 
            +
                    if shift is not None:
         | 
| 19 | 
            +
                        self.shift = shift
         | 
| 20 | 
            +
                    sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
         | 
| 21 | 
            +
                    if self.extra_one_step:
         | 
| 22 | 
            +
                        self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
         | 
| 23 | 
            +
                    else:
         | 
| 24 | 
            +
                        self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
         | 
| 25 | 
            +
                    if self.inverse_timesteps:
         | 
| 26 | 
            +
                        self.sigmas = torch.flip(self.sigmas, dims=[0])
         | 
| 27 | 
            +
                    self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
         | 
| 28 | 
            +
                    if self.reverse_sigmas:
         | 
| 29 | 
            +
                        self.sigmas = 1 - self.sigmas
         | 
| 30 | 
            +
                    self.timesteps = self.sigmas * self.num_train_timesteps
         | 
| 31 | 
            +
                    if training:
         | 
| 32 | 
            +
                        x = self.timesteps
         | 
| 33 | 
            +
                        y = torch.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2)
         | 
| 34 | 
            +
                        y_shifted = y - y.min()
         | 
| 35 | 
            +
                        bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
         | 
| 36 | 
            +
                        self.linear_timesteps_weights = bsmntw_weighing
         | 
| 37 | 
            +
             | 
| 38 | 
            +
             | 
| 39 | 
            +
                def step(self, model_output, timestep, sample, to_final=False):
         | 
| 40 | 
            +
                    if isinstance(timestep, torch.Tensor):
         | 
| 41 | 
            +
                        timestep = timestep.cpu()
         | 
| 42 | 
            +
                    timestep_id = torch.argmin((self.timesteps - timestep).abs())
         | 
| 43 | 
            +
                    sigma = self.sigmas[timestep_id]
         | 
| 44 | 
            +
                    if to_final or timestep_id + 1 >= len(self.timesteps):
         | 
| 45 | 
            +
                        sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0
         | 
| 46 | 
            +
                    else:
         | 
| 47 | 
            +
                        sigma_ = self.sigmas[timestep_id + 1]
         | 
| 48 | 
            +
                    prev_sample = sample + model_output * (sigma_ - sigma)
         | 
| 49 | 
            +
                    return prev_sample
         | 
| 50 | 
            +
                
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                def return_to_timestep(self, timestep, sample, sample_stablized):
         | 
| 53 | 
            +
                    if isinstance(timestep, torch.Tensor):
         | 
| 54 | 
            +
                        timestep = timestep.cpu()
         | 
| 55 | 
            +
                    timestep_id = torch.argmin((self.timesteps - timestep).abs())
         | 
| 56 | 
            +
                    sigma = self.sigmas[timestep_id]
         | 
| 57 | 
            +
                    model_output = (sample - sample_stablized) / sigma
         | 
| 58 | 
            +
                    return model_output
         | 
| 59 | 
            +
                
         | 
| 60 | 
            +
                
         | 
| 61 | 
            +
                def add_noise(self, original_samples, noise, timestep):
         | 
| 62 | 
            +
                    if isinstance(timestep, torch.Tensor):
         | 
| 63 | 
            +
                        timestep = timestep.cpu()
         | 
| 64 | 
            +
                    timestep_id = torch.argmin((self.timesteps - timestep).abs())
         | 
| 65 | 
            +
                    sigma = self.sigmas[timestep_id]
         | 
| 66 | 
            +
                    sample = (1 - sigma) * original_samples + sigma * noise
         | 
| 67 | 
            +
                    return sample
         | 
| 68 | 
            +
                
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                def training_target(self, sample, noise, timestep):
         | 
| 71 | 
            +
                    target = noise - sample
         | 
| 72 | 
            +
                    return target
         | 
| 73 | 
            +
                
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                def training_weight(self, timestep):
         | 
| 76 | 
            +
                    timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
         | 
| 77 | 
            +
                    weights = self.linear_timesteps_weights[timestep_id]
         | 
| 78 | 
            +
                    return weights
         | 
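The scheduler above is a rectified-flow formulation: `add_noise` interpolates `(1 - sigma) * sample + sigma * noise`, the training target is the velocity `noise - sample`, and `step` is a plain Euler update `x += v * (sigma_next - sigma)` over the shifted sigma schedule. A minimal self-contained sketch of that update rule, with illustrative schedule endpoints and a stand-in velocity function rather than the repo's pipeline:

```python
import torch

def euler_flow_matching_sample(velocity_model, shape, num_steps=15, shift=5.0):
    # Illustrative schedule: linspace from ~1 to ~0, then the same shift warp as set_timesteps above.
    sigmas = torch.linspace(1.0, 0.001, num_steps + 1)[:-1]
    sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
    x = torch.randn(shape)                            # start from pure noise (sigma close to 1)
    for i, sigma in enumerate(sigmas):
        v = velocity_model(x, sigma)                  # stand-in for the predicted velocity (noise - sample)
        sigma_next = sigmas[i + 1] if i + 1 < len(sigmas) else torch.tensor(0.0)
        x = x + v * (sigma_next - sigma)              # same Euler update as step() above
    return x

# Toy check: if the clean sample is all zeros, the true velocity is x / sigma and sampling recovers ~0.
out = euler_flow_matching_sample(lambda x, s: x / s, shape=(1, 4))
print(out.abs().max())
```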
    	
        tooncomposer.py
    ADDED
    
    | @@ -0,0 +1,234 @@ | |
| 1 | 
            +
            import os, torch, lightning, imageio
         | 
| 2 | 
            +
            from peft import LoraConfig, inject_adapter_in_model
         | 
| 3 | 
            +
            import numpy as np
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            from pipeline.i2v_pipeline import WanVideoPipeline
         | 
| 6 | 
            +
             | 
| 7 | 
            +
             | 
| 8 | 
            +
            os.environ["TOKENIZERS_PARALLELISM"] = "false"
         | 
| 9 | 
            +
            torch.set_float32_matmul_precision('medium')
         | 
| 10 | 
            +
             | 
| 11 | 
            +
             | 
| 12 | 
            +
            def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
         | 
| 13 | 
            +
                writer = imageio.get_writer(save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params)
         | 
| 14 | 
            +
                for frame in frames:
         | 
| 15 | 
            +
                    frame = np.array(frame)
         | 
| 16 | 
            +
                    writer.append_data(frame)
         | 
| 17 | 
            +
                writer.close()
         | 
| 18 | 
            +
                
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            def get_base_model_paths(base_model_name, format='dict', model_root="./weights"):
         | 
| 21 | 
            +
                    if base_model_name == "Wan2.1-I2V-14B-480P":
         | 
| 22 | 
            +
                        if format == 'list':
         | 
| 23 | 
            +
                            return [
         | 
| 24 | 
            +
                                [os.path.join(model_root, f"diffusion_pytorch_model-0000{_idx}-of-00007.safetensors") for _idx in range(1, 8)],
         | 
| 25 | 
            +
                                os.path.join(model_root, "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
         | 
| 26 | 
            +
                                os.path.join(model_root, "models_t5_umt5-xxl-enc-bf16.pth"),
         | 
| 27 | 
            +
                                os.path.join(model_root, "Wan2.1_VAE.pth")
         | 
| 28 | 
            +
                            ]
         | 
| 29 | 
            +
                        elif format == 'dict':
         | 
| 30 | 
            +
                            return {
         | 
| 31 | 
            +
                                "dit": [os.path.join(model_root, f"diffusion_pytorch_model-0000{_idx}-of-00007.safetensors") for _idx in range(1, 8)],
         | 
| 32 | 
            +
                                "image_encoder": os.path.join(model_root, "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
         | 
| 33 | 
            +
                                "text_encoder": os.path.join(model_root, "models_t5_umt5-xxl-enc-bf16.pth"),
         | 
| 34 | 
            +
                                "vae": os.path.join(model_root, "Wan2.1_VAE.pth")
         | 
| 35 | 
            +
                            }
         | 
| 36 | 
            +
                        else:
         | 
| 37 | 
            +
                            raise ValueError(f"Unsupported format: {format}")
         | 
| 38 | 
            +
                    else:
         | 
| 39 | 
            +
                        raise ValueError(f"Unsupported base model name: {base_model_name}")
         | 
| 40 | 
            +
             | 
| 41 | 
            +
             | 
| 42 | 
            +
            class ToonComposer(lightning.LightningModule):
         | 
| 43 | 
            +
                def __init__(self, base_model_name="Wan2.1-I2V-14B-480P", model_root=None, learning_rate=1e-5, lora_rank=4, lora_alpha=4, 
         | 
| 44 | 
            +
                             train_architecture=None, lora_target_modules="q,k,v,o,ffn.0,ffn.2", 
         | 
| 45 | 
            +
                             init_lora_weights="kaiming", use_gradient_checkpointing=True, 
         | 
| 46 | 
            +
                             checkpoint_path=None, video_condition_preservation_mode="first_and_last", 
         | 
| 47 | 
            +
                             tiled=False, tile_size=(34, 34), tile_stride=(18, 16), output_path=None,
         | 
| 48 | 
            +
                             use_local_lora=False, use_dera=False, dera_rank=None, use_dera_spatial=True, use_dera_temporal=True, use_sequence_cond=False, sequence_cond_mode="sparse",
         | 
| 49 | 
            +
                             use_channel_cond=False,
         | 
| 50 | 
            +
                             use_sequence_cond_position_aware_residual=False,
         | 
| 51 | 
            +
                             use_sequence_cond_loss=False, fast_dev=False,
         | 
| 52 | 
            +
                             max_num_cond_images=1, max_num_cond_sketches=2, visualize_attention=False,
         | 
| 53 | 
            +
                             random_spaced_cond_frames=False, use_sketch_mask=False, sketch_mask_ratio=0.2, no_first_sketch=False,
         | 
| 54 | 
            +
                             test_sampling_steps=15, test_sequence_cond_residual_scale=0.5, height=480, width=832):
         | 
| 55 | 
            +
                    super().__init__()
         | 
| 56 | 
            +
                    
         | 
| 57 | 
            +
                    self.pipe = WanVideoPipeline(device="cpu", torch_dtype=torch.bfloat16)
         | 
| 58 | 
            +
                    self.use_local_lora = use_local_lora
         | 
| 59 | 
            +
                    self.use_dera = use_dera
         | 
| 60 | 
            +
                    self.use_dera_spatial = use_dera_spatial
         | 
| 61 | 
            +
                    self.use_dera_temporal = use_dera_temporal
         | 
| 62 | 
            +
                    self.use_sequence_cond = use_sequence_cond
         | 
| 63 | 
            +
                    self.sequence_cond_mode = sequence_cond_mode
         | 
| 64 | 
            +
                    self.use_channel_cond = use_channel_cond
         | 
| 65 | 
            +
                    self.use_sequence_cond_position_aware_residual = use_sequence_cond_position_aware_residual
         | 
| 66 | 
            +
                    assert not (use_sequence_cond and use_channel_cond), "Cannot use both sequence condition and channel condition."
         | 
| 67 | 
            +
                    self.use_sequence_cond_loss = use_sequence_cond_loss
         | 
| 68 | 
            +
                    
         | 
| 69 | 
            +
                    self.max_num_cond_images = max_num_cond_images
         | 
| 70 | 
            +
                    self.max_num_cond_sketches = max_num_cond_sketches
         | 
| 71 | 
            +
                    
         | 
| 72 | 
            +
                    self.visualize_attention = visualize_attention
         | 
| 73 | 
            +
                    self.random_spaced_cond_frames = random_spaced_cond_frames
         | 
| 74 | 
            +
                    self.use_sketch_mask = use_sketch_mask
         | 
| 75 | 
            +
                    self.sketch_mask_ratio = sketch_mask_ratio
         | 
| 76 | 
            +
                    self.no_first_sketch = no_first_sketch
         | 
| 77 | 
            +
                    self.test_sampling_steps = test_sampling_steps
         | 
| 78 | 
            +
                    self.test_sequence_cond_residual_scale = test_sequence_cond_residual_scale
         | 
| 79 | 
            +
                    
         | 
| 80 | 
            +
                    self.height = height
         | 
| 81 | 
            +
                    self.width = width
         | 
| 82 | 
            +
                    
         | 
| 83 | 
            +
                    self.current_checkpoint_path = None
         | 
| 84 | 
            +
                    
         | 
| 85 | 
            +
                    paths = get_base_model_paths(base_model_name, format='dict', model_root=model_root)
         | 
| 86 | 
            +
                    if use_sequence_cond:
         | 
| 87 | 
            +
            assert sequence_cond_mode in ["sparse", "full"], f"Unsupported sequence condition mode: {sequence_cond_mode}"
         | 
| 88 | 
            +
                        if sequence_cond_mode == "sparse":
         | 
| 89 | 
            +
                            if use_sketch_mask:
         | 
| 90 | 
            +
                                sequence_cond_in_dim = 24
         | 
| 91 | 
            +
                            else:
         | 
| 92 | 
            +
                                sequence_cond_in_dim = 20
         | 
| 93 | 
            +
                        else:
         | 
| 94 | 
            +
                            sequence_cond_in_dim = 20
         | 
| 95 | 
            +
                        use_channel_cond = False
         | 
| 96 | 
            +
                        channel_cond_in_dim = None
         | 
| 97 | 
            +
                    elif use_channel_cond:
         | 
| 98 | 
            +
                        channel_cond_in_dim = 20
         | 
| 99 | 
            +
                        sequence_cond_in_dim = None
         | 
| 100 | 
            +
                        use_sequence_cond = False
         | 
| 101 | 
            +
                    
         | 
| 102 | 
            +
                    dit_config = {
         | 
| 103 | 
            +
                        "use_local_lora": use_local_lora,
         | 
| 104 | 
            +
                        "use_dera": use_dera,
         | 
| 105 | 
            +
                        "dera_rank": dera_rank,
         | 
| 106 | 
            +
                        "use_dera_spatial": use_dera_spatial,
         | 
| 107 | 
            +
                        "use_dera_temporal": use_dera_temporal,
         | 
| 108 | 
            +
                        "use_sequence_cond": use_sequence_cond,
         | 
| 109 | 
            +
                        "sequence_cond_mode": sequence_cond_mode,
         | 
| 110 | 
            +
                        "sequence_cond_in_dim": sequence_cond_in_dim,
         | 
| 111 | 
            +
                        "use_channel_cond": use_channel_cond,
         | 
| 112 | 
            +
                        "channel_cond_in_dim": channel_cond_in_dim,
         | 
| 113 | 
            +
                        "use_sequence_cond_position_aware_residual": use_sequence_cond_position_aware_residual,
         | 
| 114 | 
            +
                        "use_sequence_cond_loss": use_sequence_cond_loss
         | 
| 115 | 
            +
                    }
         | 
| 116 | 
            +
                    if fast_dev:
         | 
| 117 | 
            +
                        del paths["dit"]
         | 
| 118 | 
            +
                        dit_config.update({
         | 
| 119 | 
            +
                            "model_type": "i2v",
         | 
| 120 | 
            +
                            "patch_size": (1, 2, 2),
         | 
| 121 | 
            +
                            "text_len": 512,
         | 
| 122 | 
            +
                            "in_dim": 36,
         | 
| 123 | 
            +
                            "dim": 512,
         | 
| 124 | 
            +
                            "ffn_dim": 512,
         | 
| 125 | 
            +
                            "freq_dim": 256,
         | 
| 126 | 
            +
                            "text_dim": 4096,
         | 
| 127 | 
            +
                            "out_dim": 16,
         | 
| 128 | 
            +
                            "num_heads": 2,  # 40
         | 
| 129 | 
            +
                            "num_layers": 40,
         | 
| 130 | 
            +
                            "window_size": (-1, -1),
         | 
| 131 | 
            +
                            "qk_norm": True,
         | 
| 132 | 
            +
                            "cross_attn_norm": True,
         | 
| 133 | 
            +
                            "eps": 1e-6,
         | 
| 134 | 
            +
                        })
         | 
| 135 | 
            +
                        self.pipe.initialize_dummy_dit(dit_config)
         | 
| 136 | 
            +
                        
         | 
| 137 | 
            +
                    self.pipe.fetch_models_from_checkpoints(
         | 
| 138 | 
            +
                        paths,
         | 
| 139 | 
            +
                        config_dict={
         | 
| 140 | 
            +
                            "dit": dit_config
         | 
| 141 | 
            +
                        })
         | 
| 142 | 
            +
                    
         | 
| 143 | 
            +
                    if use_sequence_cond:
         | 
| 144 | 
            +
                        self.pipe.denoising_model().copy_sequence_cond_patch_embedding_weights()
         | 
| 145 | 
            +
                    elif use_channel_cond:
         | 
| 146 | 
            +
                        self.pipe.denoising_model().copy_patch_embedding_weights_for_channel_cond()
         | 
| 147 | 
            +
                    
         | 
| 148 | 
            +
                    self.freeze_parameters()
         | 
| 149 | 
            +
                    if train_architecture == "lora":
         | 
| 150 | 
            +
                        self.add_lora_to_model(
         | 
| 151 | 
            +
                            self.pipe.denoising_model(),
         | 
| 152 | 
            +
                            lora_rank=lora_rank,
         | 
| 153 | 
            +
                            lora_alpha=lora_alpha,
         | 
| 154 | 
            +
                            lora_target_modules=lora_target_modules,
         | 
| 155 | 
            +
                            init_lora_weights=init_lora_weights
         | 
| 156 | 
            +
                        )
         | 
| 157 | 
            +
                    elif train_architecture == "full":
         | 
| 158 | 
            +
                        self.pipe.denoising_model().requires_grad_(True)
         | 
| 159 | 
            +
                        
         | 
| 160 | 
            +
                    if checkpoint_path is not None:
         | 
| 161 | 
            +
                        self.load_tooncomposer_checkpoint(checkpoint_path)
         | 
| 162 | 
            +
                    
         | 
| 163 | 
            +
                    self.learning_rate = learning_rate
         | 
| 164 | 
            +
                    self.use_gradient_checkpointing = use_gradient_checkpointing
         | 
| 165 | 
            +
                    
         | 
| 166 | 
            +
                    self.pipe.scheduler.set_timesteps(1000, training=True)
         | 
| 167 | 
            +
                    self.vae_tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
         | 
| 168 | 
            +
                    self.video_condition_preservation_mode = video_condition_preservation_mode
         | 
| 169 | 
            +
                    self.negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"  
         | 
| 170 | 
            +
                    
         | 
| 171 | 
            +
                    if output_path is None:
         | 
| 172 | 
            +
                        output_path = "./"
         | 
| 173 | 
            +
                    self.output_path = output_path
         | 
| 174 | 
            +
                
         | 
| 175 | 
            +
                def load_tooncomposer_checkpoint(self, checkpoint_path):
         | 
| 176 | 
            +
                    if checkpoint_path == self.current_checkpoint_path:
         | 
| 177 | 
            +
                        print(f"Skipping loading checkpoint {checkpoint_path} because it is the same as the current checkpoint.")
         | 
| 178 | 
            +
                        return
         | 
| 179 | 
            +
                    self.current_checkpoint_path = checkpoint_path
         | 
| 180 | 
            +
                    self.load_patch_to_model(
         | 
| 181 | 
            +
                        self.pipe.denoising_model(),
         | 
| 182 | 
            +
                        checkpoint_path
         | 
| 183 | 
            +
                    )
         | 
| 184 | 
            +
                    
         | 
| 185 | 
            +
                def update_height_width(self, height, width):
         | 
| 186 | 
            +
                    self.height = height
         | 
| 187 | 
            +
                    self.width = width
         | 
| 188 | 
            +
                    
         | 
| 189 | 
            +
                def freeze_parameters(self):
         | 
| 190 | 
            +
                    self.pipe.requires_grad_(False)
         | 
| 191 | 
            +
                    self.pipe.eval()
         | 
| 192 | 
            +
                    self.pipe.denoising_model().train()
         | 
| 193 | 
            +
                    
         | 
| 194 | 
            +
                def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming"):
         | 
| 195 | 
            +
                    self.lora_alpha = lora_alpha
         | 
| 196 | 
            +
                    if init_lora_weights == "kaiming":
         | 
| 197 | 
            +
                        init_lora_weights = True
         | 
| 198 | 
            +
                        
         | 
| 199 | 
            +
                    lora_config = LoraConfig(
         | 
| 200 | 
            +
                        r=lora_rank,
         | 
| 201 | 
            +
                        lora_alpha=lora_alpha,
         | 
| 202 | 
            +
                        init_lora_weights=init_lora_weights,
         | 
| 203 | 
            +
                        target_modules=lora_target_modules.split(","),
         | 
| 204 | 
            +
                    )
         | 
| 205 | 
            +
                    model = inject_adapter_in_model(lora_config, model)
         | 
| 206 | 
            +
                    for param in model.parameters():
         | 
| 207 | 
            +
                        if param.requires_grad:
         | 
| 208 | 
            +
                            param.data = param.to(torch.float32)
         | 
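`add_lora_to_model` above delegates to peft's `inject_adapter_in_model`: the base weights stay frozen and only the injected LoRA A/B matrices require gradients (and are then upcast to fp32). A standalone sketch of the same pattern on a toy block; the module and dimension names are made up for illustration:

```python
import torch
from peft import LoraConfig, inject_adapter_in_model

class ToyBlock(torch.nn.Module):
    # Attention-style projections named like the default LoRA targets ("q", "k", "v", "o").
    def __init__(self, dim=64):
        super().__init__()
        self.q = torch.nn.Linear(dim, dim)
        self.k = torch.nn.Linear(dim, dim)
        self.v = torch.nn.Linear(dim, dim)
        self.o = torch.nn.Linear(dim, dim)

    def forward(self, x):
        return self.o(self.q(x) + self.k(x) + self.v(x))

model = ToyBlock()
model.requires_grad_(False)                      # freeze the base, as freeze_parameters() does
config = LoraConfig(r=4, lora_alpha=4, init_lora_weights=True,
                    target_modules="q,k,v,o".split(","))
model = inject_adapter_in_model(config, model)   # adds lora_A / lora_B to the targeted Linears

trainable = [n for n, p in model.named_parameters() if p.requires_grad]
print(len(trainable), trainable[:2])             # only LoRA parameters remain trainable
```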
| 209 | 
            +
                
         | 
| 210 | 
            +
                def load_patch_to_model(self, model, pretrained_path, state_dict_converter=None):
         | 
| 211 | 
            +
                    if pretrained_path is not None:
         | 
| 212 | 
            +
                        state_dict = torch.load(pretrained_path, map_location="cpu", weights_only=True)
         | 
| 213 | 
            +
                        self.loaded_global_step = 0
         | 
| 214 | 
            +
                        self.loaded_current_epoch = 0
         | 
| 215 | 
            +
                        if self.use_sketch_mask:
         | 
| 216 | 
            +
                            seq_cond_embed_weight = state_dict['sequence_cond_patch_embedding.weight']
         | 
| 217 | 
            +
                            current_in_channels = self.pipe.denoising_model().sequence_cond_patch_embedding.in_channels
         | 
| 218 | 
            +
                            if current_in_channels == 24 and seq_cond_embed_weight.shape[1] == 20:
         | 
| 219 | 
            +
                                new_weight = torch.zeros(
         | 
| 220 | 
            +
                                    seq_cond_embed_weight.shape[0],
         | 
| 221 | 
            +
                                    4,
         | 
| 222 | 
            +
                                    *seq_cond_embed_weight.shape[2:],
         | 
| 223 | 
            +
                                    dtype=seq_cond_embed_weight.dtype
         | 
| 224 | 
            +
                                )
         | 
| 225 | 
            +
                                state_dict['sequence_cond_patch_embedding.weight'] = torch.cat([
         | 
| 226 | 
            +
                                    seq_cond_embed_weight, new_weight], dim=1)
         | 
| 227 | 
            +
                        
         | 
| 228 | 
            +
                        if state_dict_converter is not None:
         | 
| 229 | 
            +
                            state_dict = state_dict_converter(state_dict)
         | 
| 230 | 
            +
                        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
         | 
| 231 | 
            +
                        all_keys = [i for i, _ in model.named_parameters()]
         | 
| 232 | 
            +
                        num_updated_keys = len(all_keys) - len(missing_keys)
         | 
| 233 | 
            +
                        num_unexpected_keys = len(unexpected_keys)
         | 
| 234 | 
            +
                        print(f"[Checkpoint] {num_updated_keys} parameters are loaded from {pretrained_path}. {num_unexpected_keys} parameters are unexpected.")
         | 
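One detail worth noting in `load_patch_to_model`: when `use_sketch_mask` widens the sequence-cond patch embedding from 20 to 24 input channels, the stored weight is zero-padded along the input-channel axis so an older checkpoint still loads and the new channels start with no influence. The same trick in isolation, on a hypothetical Conv3d standing in for the patch embedding (shapes are illustrative, not the real model's):

```python
import torch

old_weight = torch.randn(1536, 20, 1, 2, 2)               # weight as saved in the checkpoint
new_conv = torch.nn.Conv3d(24, 1536, kernel_size=(1, 2, 2))

pad = torch.zeros(old_weight.shape[0], 24 - old_weight.shape[1], *old_weight.shape[2:],
                  dtype=old_weight.dtype)                  # zeros for the 4 new input channels
padded = torch.cat([old_weight, pad], dim=1)
new_conv.weight.data.copy_(padded)
assert new_conv.weight.shape == (1536, 24, 1, 2, 2)
```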
    	
        util/model_util.py
    ADDED
    
    | @@ -0,0 +1,241 @@ | |
| 1 | 
            +
            import torch, os
         | 
| 2 | 
            +
            from safetensors import safe_open
         | 
| 3 | 
            +
            from contextlib import contextmanager
         | 
| 4 | 
            +
            import hashlib
         | 
| 5 | 
            +
            import matplotlib.pyplot as plt
         | 
| 6 | 
            +
            from matplotlib.colors import LinearSegmentedColormap
         | 
| 7 | 
            +
            import numpy as np
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            @contextmanager
         | 
| 10 | 
            +
def init_weights_on_device(device=torch.device("meta"), include_buffers: bool = False):
         | 
| 11 | 
            +
                
         | 
| 12 | 
            +
                old_register_parameter = torch.nn.Module.register_parameter
         | 
| 13 | 
            +
                if include_buffers:
         | 
| 14 | 
            +
                    old_register_buffer = torch.nn.Module.register_buffer
         | 
| 15 | 
            +
                
         | 
| 16 | 
            +
                def register_empty_parameter(module, name, param):
         | 
| 17 | 
            +
                    old_register_parameter(module, name, param)
         | 
| 18 | 
            +
                    if param is not None:
         | 
| 19 | 
            +
                        param_cls = type(module._parameters[name])
         | 
| 20 | 
            +
                        kwargs = module._parameters[name].__dict__
         | 
| 21 | 
            +
                        kwargs["requires_grad"] = param.requires_grad
         | 
| 22 | 
            +
                        module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
         | 
| 23 | 
            +
             | 
| 24 | 
            +
                def register_empty_buffer(module, name, buffer, persistent=True):
         | 
| 25 | 
            +
                    old_register_buffer(module, name, buffer, persistent=persistent)
         | 
| 26 | 
            +
                    if buffer is not None:
         | 
| 27 | 
            +
                        module._buffers[name] = module._buffers[name].to(device)
         | 
| 28 | 
            +
                        
         | 
| 29 | 
            +
                def patch_tensor_constructor(fn):
         | 
| 30 | 
            +
                    def wrapper(*args, **kwargs):
         | 
| 31 | 
            +
                        kwargs["device"] = device
         | 
| 32 | 
            +
                        return fn(*args, **kwargs)
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                    return wrapper
         | 
| 35 | 
            +
                
         | 
| 36 | 
            +
                if include_buffers:
         | 
| 37 | 
            +
                    tensor_constructors_to_patch = {
         | 
| 38 | 
            +
                        torch_function_name: getattr(torch, torch_function_name)
         | 
| 39 | 
            +
                        for torch_function_name in ["empty", "zeros", "ones", "full"]
         | 
| 40 | 
            +
                    }
         | 
| 41 | 
            +
                else:
         | 
| 42 | 
            +
                    tensor_constructors_to_patch = {}
         | 
| 43 | 
            +
                
         | 
| 44 | 
            +
                try:
         | 
| 45 | 
            +
                    torch.nn.Module.register_parameter = register_empty_parameter
         | 
| 46 | 
            +
                    if include_buffers:
         | 
| 47 | 
            +
                        torch.nn.Module.register_buffer = register_empty_buffer
         | 
| 48 | 
            +
                    for torch_function_name in tensor_constructors_to_patch.keys():
         | 
| 49 | 
            +
                        setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
         | 
| 50 | 
            +
                    yield
         | 
| 51 | 
            +
                finally:
         | 
| 52 | 
            +
                    torch.nn.Module.register_parameter = old_register_parameter
         | 
| 53 | 
            +
                    if include_buffers:
         | 
| 54 | 
            +
                        torch.nn.Module.register_buffer = old_register_buffer
         | 
| 55 | 
            +
                    for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
         | 
| 56 | 
            +
                        setattr(torch, torch_function_name, old_torch_function)
         | 
| 57 | 
            +
             | 
| 58 | 
            +
            def load_state_dict_from_folder(file_path, torch_dtype=None):
         | 
| 59 | 
            +
                state_dict = {}
         | 
| 60 | 
            +
                for file_name in os.listdir(file_path):
         | 
| 61 | 
            +
                    if "." in file_name and file_name.split(".")[-1] in [
         | 
| 62 | 
            +
                        "safetensors", "bin", "ckpt", "pth", "pt"
         | 
| 63 | 
            +
                    ]:
         | 
| 64 | 
            +
                        state_dict.update(load_state_dict(os.path.join(file_path, file_name), torch_dtype=torch_dtype))
         | 
| 65 | 
            +
                return state_dict
         | 
| 66 | 
            +
             | 
| 67 | 
            +
             | 
| 68 | 
            +
            def load_state_dict(file_path, torch_dtype=None):
         | 
| 69 | 
            +
                if file_path.endswith(".safetensors"):
         | 
| 70 | 
            +
                    return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
         | 
| 71 | 
            +
                else:
         | 
| 72 | 
            +
                    return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
         | 
| 73 | 
            +
             | 
| 74 | 
            +
             | 
| 75 | 
            +
            def load_state_dict_from_safetensors(file_path, torch_dtype=None):
         | 
| 76 | 
            +
                state_dict = {}
         | 
| 77 | 
            +
                with safe_open(file_path, framework="pt", device="cpu") as f:
         | 
| 78 | 
            +
                    for k in f.keys():
         | 
| 79 | 
            +
                        state_dict[k] = f.get_tensor(k)
         | 
| 80 | 
            +
                        if torch_dtype is not None:
         | 
| 81 | 
            +
                            state_dict[k] = state_dict[k].to(torch_dtype)
         | 
| 82 | 
            +
                return state_dict
         | 
| 83 | 
            +
             | 
| 84 | 
            +
             | 
| 85 | 
            +
            def load_state_dict_from_bin(file_path, torch_dtype=None):
         | 
| 86 | 
            +
                state_dict = torch.load(file_path, map_location="cpu", weights_only=True)
         | 
| 87 | 
            +
                if torch_dtype is not None:
         | 
| 88 | 
            +
                    for i in state_dict:
         | 
| 89 | 
            +
                        if isinstance(state_dict[i], torch.Tensor):
         | 
| 90 | 
            +
                            state_dict[i] = state_dict[i].to(torch_dtype)
         | 
| 91 | 
            +
                return state_dict
         | 
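`load_state_dict` dispatches on the file extension: `.safetensors` files are read tensor by tensor through `safe_open`, everything else goes through `torch.load(..., weights_only=True)`, with an optional dtype cast in both paths. A quick round-trip check (the file name is arbitrary):

```python
import torch
from safetensors.torch import save_file
from util.model_util import load_state_dict

# Write a tiny safetensors file, then read it back through the loader above with a dtype cast.
save_file({"layer.weight": torch.ones(2, 3)}, "demo.safetensors")
sd = load_state_dict("demo.safetensors", torch_dtype=torch.float16)
print(sd["layer.weight"].dtype)   # torch.float16
```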
| 92 | 
            +
             | 
| 93 | 
            +
             | 
| 94 | 
            +
            def search_for_embeddings(state_dict):
         | 
| 95 | 
            +
                embeddings = []
         | 
| 96 | 
            +
                for k in state_dict:
         | 
| 97 | 
            +
                    if isinstance(state_dict[k], torch.Tensor):
         | 
| 98 | 
            +
                        embeddings.append(state_dict[k])
         | 
| 99 | 
            +
                    elif isinstance(state_dict[k], dict):
         | 
| 100 | 
            +
                        embeddings += search_for_embeddings(state_dict[k])
         | 
| 101 | 
            +
                return embeddings
         | 
| 102 | 
            +
             | 
| 103 | 
            +
             | 
| 104 | 
            +
            def search_parameter(param, state_dict):
         | 
| 105 | 
            +
                for name, param_ in state_dict.items():
         | 
| 106 | 
            +
                    if param.numel() == param_.numel():
         | 
| 107 | 
            +
                        if param.shape == param_.shape:
         | 
| 108 | 
            +
                            if torch.dist(param, param_) < 1e-3:
         | 
| 109 | 
            +
                                return name
         | 
| 110 | 
            +
                        else:
         | 
| 111 | 
            +
                            if torch.dist(param.flatten(), param_.flatten()) < 1e-3:
         | 
| 112 | 
            +
                                return name
         | 
| 113 | 
            +
                return None
         | 
| 114 | 
            +
             | 
| 115 | 
            +
             | 
| 116 | 
            +
            def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
         | 
| 117 | 
            +
                matched_keys = set()
         | 
| 118 | 
            +
                with torch.no_grad():
         | 
| 119 | 
            +
                    for name in source_state_dict:
         | 
| 120 | 
            +
                        rename = search_parameter(source_state_dict[name], target_state_dict)
         | 
| 121 | 
            +
                        if rename is not None:
         | 
| 122 | 
            +
                            print(f'"{name}": "{rename}",')
         | 
| 123 | 
            +
                            matched_keys.add(rename)
         | 
| 124 | 
            +
                        elif split_qkv and len(source_state_dict[name].shape)>=1 and source_state_dict[name].shape[0]%3==0:
         | 
| 125 | 
            +
                            length = source_state_dict[name].shape[0] // 3
         | 
| 126 | 
            +
                            rename = []
         | 
| 127 | 
            +
                            for i in range(3):
         | 
| 128 | 
            +
                                rename.append(search_parameter(source_state_dict[name][i*length: i*length+length], target_state_dict))
         | 
| 129 | 
            +
                            if None not in rename:
         | 
| 130 | 
            +
                                print(f'"{name}": {rename},')
         | 
| 131 | 
            +
                                for rename_ in rename:
         | 
| 132 | 
            +
                                    matched_keys.add(rename_)
         | 
| 133 | 
            +
                for name in target_state_dict:
         | 
| 134 | 
            +
                    if name not in matched_keys:
         | 
| 135 | 
            +
                        print("Cannot find", name, target_state_dict[name].shape)
         | 
| 136 | 
            +
             | 
| 137 | 
            +
             | 
| 138 | 
            +
            def search_for_files(folder, extensions):
         | 
| 139 | 
            +
                files = []
         | 
| 140 | 
            +
                if os.path.isdir(folder):
         | 
| 141 | 
            +
                    for file in sorted(os.listdir(folder)):
         | 
| 142 | 
            +
                        files += search_for_files(os.path.join(folder, file), extensions)
         | 
| 143 | 
            +
                elif os.path.isfile(folder):
         | 
| 144 | 
            +
                    for extension in extensions:
         | 
| 145 | 
            +
                        if folder.endswith(extension):
         | 
| 146 | 
            +
                            files.append(folder)
         | 
| 147 | 
            +
                            break
         | 
| 148 | 
            +
                return files
         | 
| 149 | 
            +
             | 
| 150 | 
            +
             | 
| 151 | 
            +
            def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
         | 
| 152 | 
            +
                keys = []
         | 
| 153 | 
            +
                for key, value in state_dict.items():
         | 
| 154 | 
            +
                    if isinstance(key, str):
         | 
| 155 | 
            +
                        if isinstance(value, torch.Tensor):
         | 
| 156 | 
            +
                            if with_shape:
         | 
| 157 | 
            +
                                shape = "_".join(map(str, list(value.shape)))
         | 
| 158 | 
            +
                                keys.append(key + ":" + shape)
         | 
| 159 | 
            +
                            keys.append(key)
         | 
| 160 | 
            +
                        elif isinstance(value, dict):
         | 
| 161 | 
            +
                            keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
         | 
| 162 | 
            +
                keys.sort()
         | 
| 163 | 
            +
                keys_str = ",".join(keys)
         | 
| 164 | 
            +
                return keys_str
         | 
| 165 | 
            +
             | 
| 166 | 
            +
             | 
| 167 | 
            +
            def split_state_dict_with_prefix(state_dict):
         | 
| 168 | 
            +
                keys = sorted([key for key in state_dict if isinstance(key, str)])
         | 
| 169 | 
            +
                prefix_dict = {}
         | 
| 170 | 
            +
    for key in keys:
         | 
| 171 | 
            +
                    prefix = key if "." not in key else key.split(".")[0]
         | 
| 172 | 
            +
                    if prefix not in prefix_dict:
         | 
| 173 | 
            +
                        prefix_dict[prefix] = []
         | 
| 174 | 
            +
                    prefix_dict[prefix].append(key)
         | 
| 175 | 
            +
                state_dicts = []
         | 
| 176 | 
            +
                for prefix, keys in prefix_dict.items():
         | 
| 177 | 
            +
                    sub_state_dict = {key: state_dict[key] for key in keys}
         | 
| 178 | 
            +
                    state_dicts.append(sub_state_dict)
         | 
| 179 | 
            +
                return state_dicts
         | 
| 180 | 
            +
             | 
| 181 | 
            +
             | 
| 182 | 
            +
            def hash_state_dict_keys(state_dict, with_shape=True):
         | 
| 183 | 
            +
                keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
         | 
| 184 | 
            +
                keys_str = keys_str.encode(encoding="UTF-8")
         | 
| 185 | 
            +
                return hashlib.md5(keys_str).hexdigest()
         | 
| 186 | 
            +
             | 
| 187 | 
            +
             | 
| 188 | 
            +
            def save_attention_maps(model, output_path, batch_idx, timestep, layer_indices=None):
         | 
| 189 | 
            +
                """
         | 
| 190 | 
            +
                Visualize and save the attention maps from selected layers of the model
         | 
| 191 | 
            +
                
         | 
| 192 | 
            +
                Args:
         | 
| 193 | 
            +
                    model: The DiT model with attention maps stored
         | 
| 194 | 
            +
                    output_path: Directory to save visualizations
         | 
| 195 | 
            +
        batch_idx: Current batch index for file naming
        timestep: Current denoising timestep, used in the output file names
         | 
| 196 | 
            +
                    layer_indices: List of layer indices to visualize (if None, visualize all)
         | 
| 197 | 
            +
                """
         | 
| 198 | 
            +
                timestep = int(float(str(timestep)))
         | 
| 199 | 
            +
                os.makedirs(os.path.join(output_path, "attention_maps"), exist_ok=True)
         | 
| 200 | 
            +
                
         | 
| 201 | 
            +
                # If layer indices not specified, visualize all layers
         | 
| 202 | 
            +
                if layer_indices is None:
         | 
| 203 | 
            +
                    layer_indices = range(len(model.blocks))
         | 
| 204 | 
            +
                
         | 
| 205 | 
            +
                # Create a custom colormap (similar to the ones used in attention visualization papers)
         | 
| 206 | 
            +
                colors = [(0, 0, 0.5), (0, 0, 1), (0, 0.5, 1), (0, 1, 1), 
         | 
| 207 | 
            +
                          (0.5, 1, 0.5), (1, 1, 0), (1, 0.5, 0), (1, 0, 0), (0.5, 0, 0)]
         | 
| 208 | 
            +
                attention_cmap = LinearSegmentedColormap.from_list('attention_cmap', colors)
         | 
| 209 | 
            +
                
         | 
| 210 | 
            +
                for i in layer_indices:
         | 
| 211 | 
            +
                    if not hasattr(model.blocks[i].self_attn, '_last_attn_maps'):
         | 
| 212 | 
            +
                        continue
         | 
| 213 | 
            +
                        
         | 
| 214 | 
            +
                    attn_map = model.blocks[i].self_attn._last_attn_maps
         | 
| 215 | 
            +
                    grid_size = model.blocks[i].self_attn._last_grid_sizes
         | 
| 216 | 
            +
                    seq_len = model.blocks[i].self_attn._last_seq_lens
         | 
| 217 | 
            +
                    # attn_maps.shape=[s, s]
         | 
| 218 | 
            +
                    np.savez_compressed(os.path.join(output_path,
         | 
| 219 | 
            +
                            "attention_maps",
         | 
| 220 | 
            +
                            f"attn_maps_layer{i}_batch{batch_idx}_t{timestep}.npz"),
         | 
| 221 | 
            +
                                        attn_map=attn_map, grid_size=grid_size, seq_len=seq_len)
         | 
| 222 | 
            +
                    
         | 
| 223 | 
            +
                    print(f"Saving Layer {i}, Batch {batch_idx} attention maps")
         | 
| 224 | 
            +
                    attn_map -= attn_map.min()
         | 
| 225 | 
            +
                    attn_map /= attn_map.max()
         | 
| 226 | 
            +
                    plt.figure(figsize=(10, 8))
         | 
| 227 | 
            +
                    plt.imshow(attn_map ** 0.25, cmap=attention_cmap)
         | 
| 228 | 
            +
                    plt.colorbar(label='Attention Weight')
         | 
| 229 | 
            +
                    plt.title(f'Layer {i}, Batch {batch_idx} (Average)')
         | 
| 230 | 
            +
                    save_path = os.path.join(
         | 
| 231 | 
            +
                        output_path, 
         | 
| 232 | 
            +
                        "attention_maps", 
         | 
| 233 | 
            +
                        f"attn_map_layer{i}_average_batch{batch_idx}_t{timestep}.png"
         | 
| 234 | 
            +
                    )
         | 
| 235 | 
            +
                    plt.savefig(save_path, dpi=300, bbox_inches='tight')
         | 
| 236 | 
            +
                    plt.close()
         | 
| 237 | 
            +
             | 
| 238 | 
            +
                # Clean up the stored attention maps to free memory
         | 
| 239 | 
            +
                for i in layer_indices:
         | 
| 240 | 
            +
                    if hasattr(model.blocks[i].self_attn, '_last_attn_maps'):
         | 
| 241 | 
            +
                        del model.blocks[i].self_attn._last_attn_maps
         | 
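`init_weights_on_device` at the top of this file is the usual meta-device trick: while the context manager is active, newly registered parameters land on `torch.device("meta")`, so a large module can be constructed without allocating real memory and materialized only when actual weights arrive. A small usage sketch; the `to_empty` + `load_state_dict` step is one possible way to materialize, not necessarily what the pipeline does:

```python
import torch
from util.model_util import init_weights_on_device

with init_weights_on_device():
    big = torch.nn.Linear(4096, 4096)        # parameters are created on the meta device
print(big.weight.device)                     # meta

big = big.to_empty(device="cpu")             # allocate real (uninitialized) storage
state = {"weight": torch.zeros(4096, 4096), "bias": torch.zeros(4096)}
big.load_state_dict(state)                   # fill it from a real state dict
```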
    	
        util/optical_flow.py
    ADDED
    
    | @@ -0,0 +1,140 @@ | |
| 1 | 
            +
            import cv2
         | 
| 2 | 
            +
            import numpy as np
         | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            import torch.nn.functional as F
         | 
| 5 | 
            +
            from torchvision.models.optical_flow import Raft_Large_Weights, raft_large
         | 
| 6 | 
            +
            from typing import List, Tuple, Dict
         | 
| 7 | 
            +
            import argparse
         | 
| 8 | 
            +
            from pathlib import Path
         | 
| 9 | 
            +
            from sklearn.cluster import KMeans
         | 
| 10 | 
            +
            from tqdm import tqdm
         | 
| 11 | 
            +
            import os
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            os.environ['OPENBLAS_NUM_THREADS'] = '64'
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            class OpticalFlowAnalyzer:
         | 
| 16 | 
            +
                def __init__(self, device: str = 'cuda' if torch.cuda.is_available() else 'cpu'):
         | 
| 17 | 
            +
                    self.device = device
         | 
| 18 | 
            +
                    self.model = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=False).to(device)
         | 
| 19 | 
            +
                    self.model.eval()
         | 
| 20 | 
            +
                    
         | 
| 21 | 
            +
                def preprocess_frame(self, frame: np.ndarray) -> torch.Tensor:
         | 
| 22 | 
            +
                    """Preprocess a frame for RAFT model."""
         | 
| 23 | 
            +
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
         | 
| 24 | 
            +
                    frame = torch.from_numpy(frame).permute(2, 0, 1).float()
         | 
| 25 | 
            +
                    frame = frame.unsqueeze(0) / 255.0
         | 
| 26 | 
            +
                    return frame.to(self.device)
         | 
| 27 | 
            +
                
         | 
| 28 | 
            +
                def compute_optical_flow(self, frame1: np.ndarray, frame2: np.ndarray) -> np.ndarray:
         | 
| 29 | 
            +
                    """Compute optical flow between two consecutive frames."""
         | 
| 30 | 
            +
                    with torch.no_grad():
         | 
| 31 | 
            +
                        frame1_tensor = self.preprocess_frame(frame1)
         | 
| 32 | 
            +
                        frame2_tensor = self.preprocess_frame(frame2)
         | 
| 33 | 
            +
                        
         | 
| 34 | 
            +
                        flow = self.model(frame1_tensor, frame2_tensor)[-1]
         | 
| 35 | 
            +
                        flow = flow[0].permute(1, 2, 0).cpu().numpy()
         | 
| 36 | 
            +
                        
         | 
| 37 | 
            +
                    return flow
         | 
| 38 | 
            +
                
         | 
| 39 | 
            +
                def analyze_motion_regions(self, flow: np.ndarray, num_clusters: int = 3) -> Tuple[np.ndarray, Dict]:
         | 
| 40 | 
            +
                    """Cluster motion regions based on optical flow magnitude and direction."""
         | 
| 41 | 
            +
                    h, w = flow.shape[:2]
         | 
| 42 | 
            +
                    magnitude = np.sqrt(flow[..., 0]**2 + flow[..., 1]**2)
         | 
| 43 | 
            +
                    direction = np.arctan2(flow[..., 1], flow[..., 0])
         | 
| 44 | 
            +
                    
         | 
| 45 | 
            +
                    # Create feature matrix for clustering
         | 
| 46 | 
            +
                    features = np.zeros((h * w, 3))
         | 
| 47 | 
            +
                    features[:, 0] = magnitude.ravel()
         | 
| 48 | 
            +
                    features[:, 1] = np.cos(direction).ravel()
         | 
| 49 | 
            +
                    features[:, 2] = np.sin(direction).ravel()
         | 
| 50 | 
            +
                    
         | 
| 51 | 
            +
                    # Normalize features
         | 
| 52 | 
            +
                    features = (features - features.mean(axis=0)) / features.std(axis=0)
         | 
| 53 | 
            +
                    
         | 
| 54 | 
            +
                    # Perform clustering
         | 
| 55 | 
            +
                    kmeans = KMeans(n_clusters=num_clusters, random_state=42,)
         | 
| 56 | 
            +
                    labels = kmeans.fit_predict(features)
         | 
| 57 | 
            +
                    labels = labels.reshape(h, w)
         | 
| 58 | 
            +
                    
         | 
| 59 | 
            +
                    # Analyze clusters
         | 
| 60 | 
            +
                    cluster_stats = {}
         | 
| 61 | 
            +
                    for i in range(num_clusters):
         | 
| 62 | 
            +
                        cluster_mask = (labels == i)
         | 
| 63 | 
            +
                        cluster_magnitude = magnitude[cluster_mask]
         | 
| 64 | 
            +
                        cluster_stats[i] = {
         | 
| 65 | 
            +
                            'mean_magnitude': np.mean(cluster_magnitude),
         | 
| 66 | 
            +
                            'std_magnitude': np.std(cluster_magnitude),
         | 
| 67 | 
            +
                            'pixel_count': np.sum(cluster_mask),
         | 
| 68 | 
            +
                            'is_static': np.mean(cluster_magnitude) < 0.1  # Threshold for static regions
         | 
| 69 | 
            +
                        }
         | 
| 70 | 
            +
                    
         | 
| 71 | 
            +
                    return labels, cluster_stats
         | 
| 72 | 
            +
                
         | 
| 73 | 
            +
                def process_video(self, video_path: str, output_path: str = None) -> List[Tuple[np.ndarray, Dict]]:
         | 
| 74 | 
            +
                    """Process a video and return motion analysis results for each frame pair."""
         | 
| 75 | 
            +
                    cap = cv2.VideoCapture(video_path)
         | 
| 76 | 
            +
                    if not cap.isOpened():
         | 
| 77 | 
            +
                        raise ValueError(f"Could not open video: {video_path}")
         | 
| 78 | 
            +
                    
         | 
| 79 | 
            +
                    results = []
         | 
| 80 | 
            +
                    ret, prev_frame = cap.read()
         | 
| 81 | 
            +
                    if not ret:
         | 
| 82 | 
            +
                        raise ValueError("Could not read first frame")
         | 
| 83 | 
            +
                    
         | 
| 84 | 
            +
                    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
         | 
| 85 | 
            +
                    pbar = tqdm(total=total_frames-1, desc="Processing video")
         | 
| 86 | 
            +
                    
         | 
| 87 | 
            +
                    while True:
         | 
| 88 | 
            +
                        ret, curr_frame = cap.read()
         | 
| 89 | 
            +
                        if not ret:
         | 
| 90 | 
            +
                            break
         | 
| 91 | 
            +
                            
         | 
| 92 | 
            +
                        flow = self.compute_optical_flow(prev_frame, curr_frame)
         | 
| 93 | 
            +
                        labels, stats = self.analyze_motion_regions(flow)
         | 
| 94 | 
            +
                        
         | 
| 95 | 
            +
                        if output_path:
         | 
| 96 | 
            +
                            # Visualize results
         | 
| 97 | 
            +
                            vis_frame = curr_frame.copy()
         | 
| 98 | 
            +
                            for i, stat in stats.items():
         | 
| 99 | 
            +
                                if not stat['is_static']:
         | 
| 100 | 
            +
                                    mask = (labels == i).astype(np.uint8) * 255
         | 
| 101 | 
            +
                                    print("mask:",mask.shape)
         | 
| 102 | 
            +
                                    print("vis_frame:",vis_frame.shape)
         | 
| 103 | 
            +
                                    mask = np.expand_dims(mask, axis=-1).repeat(3, axis=-1)
         | 
| 104 | 
            +
                                    print("mask:",mask.shape)
         | 
| 105 | 
            +
                                    
         | 
| 106 | 
            +
                                    vis_frame[mask > 0] = cv2.addWeighted(vis_frame[mask > 0], 0.7, 255, 0.3, 0)
         | 
| 107 | 
            +
                            
         | 
| 108 | 
            +
                            cv2.imwrite(f"{output_path}/frame_{len(results):04d}.jpg", vis_frame)
         | 
| 109 | 
            +
                        
         | 
| 110 | 
            +
                        results.append((labels, stats))
         | 
| 111 | 
            +
                        prev_frame = curr_frame
         | 
| 112 | 
            +
                        pbar.update(1)
         | 
| 113 | 
            +
                    
         | 
| 114 | 
            +
                    cap.release()
         | 
| 115 | 
            +
                    pbar.close()
         | 
| 116 | 
            +
                    return results
         | 
| 117 | 
            +
             | 
| 118 | 
            +
            def main():
         | 
| 119 | 
            +
                parser = argparse.ArgumentParser(description='Analyze motion regions in a video using RAFT optical flow')
         | 
| 120 | 
            +
                parser.add_argument('--video', type=str, required=True, help='Path to input video')
         | 
| 121 | 
            +
                parser.add_argument('--output', type=str, help='Path to output directory for visualization')
         | 
| 122 | 
            +
                parser.add_argument('--clusters', type=int, default=3, help='Number of motion clusters')
         | 
| 123 | 
            +
                args = parser.parse_args()
         | 
| 124 | 
            +
                
         | 
| 125 | 
            +
                analyzer = OpticalFlowAnalyzer()
         | 
| 126 | 
            +
                results = analyzer.process_video(args.video, args.output)
         | 
| 127 | 
            +
                
         | 
| 128 | 
            +
                # Print summary statistics
         | 
| 129 | 
            +
                print("\nMotion Analysis Summary:")
         | 
| 130 | 
            +
                for i, (_, stats) in enumerate(results):
         | 
| 131 | 
            +
                    print(f"\nFrame {i+1}:")
         | 
| 132 | 
            +
                    for cluster_id, stat in stats.items():
         | 
| 133 | 
            +
                        motion_type = "Static" if stat['is_static'] else "Moving"
         | 
| 134 | 
            +
                        print(f"  Cluster {cluster_id} ({motion_type}):")
         | 
| 135 | 
            +
                        print(f"    Mean magnitude: {stat['mean_magnitude']:.4f}")
         | 
| 136 | 
            +
                        print(f"    Pixel count: {stat['pixel_count']}")
         | 
| 137 | 
            +
             | 
| 138 | 
            +
            if __name__ == "__main__":
         | 
| 139 | 
            +
                main()
         | 
| 140 | 
            +
             | 
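For reference, the analyzer can also be driven directly from Python instead of the CLI entry point above. The snippet below is a hypothetical usage sketch: the module path util.optical_flow, the placeholder input.mp4, and the flow_vis directory are assumptions, and the output directory must exist (or be created) because process_video writes its visualizations with cv2.imwrite.

    # Hypothetical usage sketch for util/optical_flow.py; paths are placeholders.
    import os
    from util.optical_flow import OpticalFlowAnalyzer

    os.makedirs("flow_vis", exist_ok=True)
    analyzer = OpticalFlowAnalyzer()
    results = analyzer.process_video("input.mp4", output_path="flow_vis", num_clusters=3)
    for pair_idx, (labels, stats) in enumerate(results):
        moving = [cid for cid, s in stats.items() if not s["is_static"]]
        print(f"frame pair {pair_idx}: moving clusters {moving}")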
    	
        util/stylesheets.py
    ADDED
    
The diff for this file is too large to render. See the raw diff.
    	
        util/training_util.py
    ADDED
    
@@ -0,0 +1,317 @@
from typing import Union
import torch
import random
import numpy as np
import cv2
import os


def create_random_mask(batch_size, num_frames, height, width, device, dtype, shape_type=None):
    """
    Create random masks for sketch frames.

    Args:
        batch_size: Batch size
        num_frames: Number of frames to mask
        height, width: Image dimensions
        device: Device for tensor
        dtype: Data type for tensor
        shape_type: Type of shape for masking ('square', 'circle', 'random'). If None, one is randomly selected per frame.

    Returns:
        Mask tensor of shape [b, 1, num_frames, height, width] where 0 indicates areas to mask (inverse of the drawn shapes).
    """
    # Initialize with ones (unmasked)
    masks = torch.ones(batch_size, 1, num_frames, height, width, device=device, dtype=dtype)

    for b in range(batch_size):
        for f in range(num_frames):
            # Randomly select a shape type per frame if not specified
            cur_shape_type = shape_type if shape_type is not None else random.choice(['square', 'circle', 'random'])

            # Create numpy mask for easier shape drawing
            mask = np.zeros((height, width), dtype=np.float32)

            if cur_shape_type == 'square':
                # Random squares
                num_squares = random.randint(1, 5)
                for _ in range(num_squares):
                    # Random square size (proportional to image dimensions)
                    max_size = min(height, width)
                    size = random.randint(max_size // 4, max_size)

                    # Random position
                    x = random.randint(0, width - size)
                    y = random.randint(0, height - size)

                    # Draw square
                    mask[y:y+size, x:x+size] = 1.0

            elif cur_shape_type == 'circle':
                # Random circles
                num_circles = random.randint(1, 5)
                for _ in range(num_circles):
                    # Random radius (proportional to image dimensions)
                    max_radius = min(height, width) // 2
                    radius = random.randint(max_radius // 4, max_radius)

                    # Random center
                    center_x = random.randint(radius, width - radius)
                    center_y = random.randint(radius, height - radius)

                    # Draw circle
                    cv2.circle(mask, (center_x, center_y), radius, 1.0, -1)

            elif cur_shape_type == 'random':
                # Create a connected random polygon with cv2
                num_points = random.randint(5, 16)
                points = []

                # Generate random points
                for _ in range(num_points):
                    x = random.randint(0, width - 1)
                    y = random.randint(0, height - 1)
                    points.append([x, y])

                # Convert to numpy array for cv2
                points = np.array(points, dtype=np.int32)

                # Draw filled polygon
                cv2.fillPoly(mask, [points], 1.0)

            # Convert numpy mask to tensor and subtract from ones (invert the mask)
            masks[b, 0, f] = 1.0 - torch.from_numpy(mask).to(device=device, dtype=dtype)

    return masks


@torch.no_grad()
def extract_img_to_sketch(_sketch_model, _img, model_name="random"):
    """
    Extract a sketch from an image with the sketch pool. Returns a sketch in [-1, 1].
    """
    orig_shape = (_img.shape[-2], _img.shape[-1])
    with torch.amp.autocast(dtype=torch.float32, device_type="cuda"):
        reshaped_img = torch.nn.functional.interpolate(_img, (2048, 2048))
        sketch = _sketch_model(reshaped_img, model_name=model_name)
        sketch = torch.nn.functional.interpolate(sketch, orig_shape)
    if sketch.shape[1] == 1:
        sketch = sketch.repeat(1, 3, 1, 1)
    return sketch


def video_to_frame_and_sketch(
    sketch_model,
    original_video,
    max_num_preserved_sketch_frames=2,
    max_num_preserved_image_frames=1,
    min_num_preserved_sketch_frames=2,
    min_num_preserved_image_frames=1,
    model_name=None,
    detach_image_and_sketch=False,
    equally_spaced_preserve_sketch=False,
    apply_sketch_mask=False,
    sketch_mask_ratio=0.2,
    sketch_mask_shape=None,
    no_first_sketch: Union[bool, float] = False,
    video_clip_names=None,
    is_flux_sketch_available=None,
    is_evaluation=False,
):
    """
    Build the conditional inputs (preserved keyframe images and keyframe sketches) for a training video.

    Args:
        sketch_model: torch.nn.Module, a sketch pool for extracting sketches from images
        original_video: torch.Tensor, shape=(batch_size, num_channels, num_frames, height, width)
        max_num_preserved_sketch_frames: int, maximum number of preserved sketch frames
        max_num_preserved_image_frames: int, maximum number of preserved image frames
        min_num_preserved_sketch_frames: int, minimum number of preserved sketch frames
        min_num_preserved_image_frames: int, minimum number of preserved image frames
        model_name: str, name of the sketch model. If None, randomly select from ["lineart", "lineart_anime", "anime2sketch"]. Default: None.
        detach_image_and_sketch: bool, whether to return the sketch condition in separate tensors instead of merging it into the image condition. Default: False.
        equally_spaced_preserve_sketch: bool, whether to preserve sketches at equally spaced intervals. Default: False.
        apply_sketch_mask: bool, whether to apply random masking to sketch frames. Default: False.
        sketch_mask_ratio: float, probability of masking the preserved sketch frames (0-1). Default: 0.2.
        sketch_mask_shape: str, shape type for masking ('square', 'circle', 'random'). If None, randomly selected. Default: None.
        no_first_sketch: bool or float, drop the first-frame sketch; a float is treated as the probability of dropping it. Default: False.
        video_clip_names: list of clip names used to locate pre-computed FLUX sketches. Default: None.
        is_flux_sketch_available: list of bools indicating whether pre-computed FLUX sketches exist for the clip. Default: None.
        is_evaluation: bool, disables random sketch masking and always uses FLUX sketches when available. Default: False.
    Returns:
        masked_condition_video: torch.Tensor, shape=(batch_size, num_channels, num_frames, height, width); preserved image (and, if not detached, sketch) frames, all other frames zero
        preserved_condition_mask: torch.Tensor, shape=(batch_size, num_frames); 1.0 for preserved image frames, -1.0 for preserved sketch frames when not detached
        masked_condition_video_sketch: torch.Tensor or None, sketch counterpart of masked_condition_video when detach_image_and_sketch=True
        preserved_condition_mask_sketch: torch.Tensor or None, sketch counterpart of preserved_condition_mask when detach_image_and_sketch=True
        full_sketch_frames: torch.Tensor, shape=(batch_size, num_channels, num_frames, height, width)
        sketch_local_mask: torch.Tensor, shape=(batch_size, 1, num_frames, height, width), or None if apply_sketch_mask=False
        cur_model_name: str, name of the sketch model actually used
    """
    video_shape = original_video.shape
    video_dtype = original_video.dtype
    video_device = original_video.device

    if min_num_preserved_sketch_frames is None or min_num_preserved_sketch_frames < 2:
        min_num_preserved_sketch_frames = 2  # Minimum num: 2 (the first and the last)
    num_preserved_sketch_frames = random.randint(min_num_preserved_sketch_frames, max_num_preserved_sketch_frames)
    num_preserved_sketch_frames = min(num_preserved_sketch_frames, video_shape[2])

    # Clips with pre-computed FLUX sketches only provide the first and last sketch frames
    if video_clip_names is not None and is_flux_sketch_available is not None:
        if is_flux_sketch_available[0]:
            num_preserved_sketch_frames = 2

    if isinstance(no_first_sketch, float):
        no_first_sketch = random.random() < no_first_sketch

    if equally_spaced_preserve_sketch:
        preserved_sketch_indices = torch.linspace(0, video_shape[2] - 1, num_preserved_sketch_frames).long().tolist()
        if no_first_sketch:
            preserved_sketch_indices = preserved_sketch_indices[1:]
    else:
        # Always include the first and last frames (unless the first sketch is dropped)
        if no_first_sketch:
            preserved_sketch_indices = [video_shape[2] - 1]
        else:
            preserved_sketch_indices = [0, video_shape[2] - 1]
        # If we need more frames than just first and last
        if num_preserved_sketch_frames > 2 and video_shape[2] > 4:
            # Candidate set excludes the first, the last, and their adjacent frames
            candidates = set(range(2, video_shape[2] - 2))

            # Determine how many additional frames to select
            additional_frames_needed = min(num_preserved_sketch_frames - 2, len(candidates))

            # Keep selecting frames until we have enough or run out of candidates
            additional_indices = []
            while len(additional_indices) < additional_frames_needed and candidates:
                # Select a random candidate
                idx = random.choice(list(candidates))
                additional_indices.append(idx)

                # Remove the selected index and its neighbors from the candidates
                candidates.remove(idx)
                candidates.discard(idx - 1)
                candidates.discard(idx + 1)

            preserved_sketch_indices.extend(additional_indices)
            preserved_sketch_indices.sort()

    # The indices to preserve have been determined.
    # The code below relies only on these indices, not on the number of preserved frames.
    preserved_image_indices = [0]
    if max_num_preserved_image_frames is not None and max_num_preserved_image_frames > 1:
        max_num_preserved_image_frames -= 1
        if min_num_preserved_image_frames is None or min_num_preserved_image_frames < 1:
            min_num_preserved_image_frames = 1
        min_num_preserved_image_frames -= 1
        other_indices = torch.tensor([i for i in range(video_shape[2]) if i not in preserved_sketch_indices])
        max_num_preserved_image_frames = min(max_num_preserved_image_frames, len(other_indices))
        min_num_preserved_image_frames = min(min_num_preserved_image_frames, max_num_preserved_image_frames)
        num_preserved_image_frames = random.randint(min_num_preserved_image_frames, max_num_preserved_image_frames)
        other_indices = other_indices[torch.randperm(len(other_indices))]
        if num_preserved_image_frames > 0:
            preserved_image_indices.extend(other_indices[:num_preserved_image_frames].tolist())

    preserved_condition_mask = torch.zeros(size=(video_shape[0], video_shape[2]), dtype=video_dtype, device=video_device)  # [b, t]
    masked_condition_video = torch.zeros_like(original_video)   # [b, c, t, h, w]
    full_sketch_frames = torch.zeros_like(original_video)  # [b, c, t, h, w]

    if detach_image_and_sketch:
        preserved_condition_mask_sketch = torch.zeros_like(preserved_condition_mask)
        masked_condition_video_sketch = torch.zeros_like(masked_condition_video)
        if 0 not in preserved_sketch_indices and not no_first_sketch:
            preserved_sketch_indices.append(0)
    else:
        preserved_condition_mask_sketch = None
        masked_condition_video_sketch = None

    for _idx in preserved_image_indices:
        preserved_condition_mask[:, _idx] = 1.0
        masked_condition_video[:, :, _idx, :, :] = original_video[:, :, _idx, :, :]

    # Set up sketch_local_mask if masking is applied
    sketch_local_mask = None

    if apply_sketch_mask:
        # Create a full-sized mask initialized to all ones (unmasked)
        sketch_local_mask = torch.ones(
            video_shape[0], video_shape[2], video_shape[3], video_shape[4],
            device=video_device,
            dtype=video_dtype
        ).unsqueeze(1)  # Add channel dimension to get [b, 1, t, h, w]

        if not is_evaluation and random.random() < sketch_mask_ratio:
            # For preserved frames, apply random masking
            for i, frame_idx in enumerate(preserved_sketch_indices):
                if i == 0:
                    # The first preserved sketch frame is never masked
                    continue
                # Create masks only for preserved frames
                frame_masks = create_random_mask(
                    batch_size=video_shape[0],
                    num_frames=1,  # Just one frame at a time
                    height=video_shape[3],
                    width=video_shape[4],
                    device=video_device,
                    dtype=video_dtype,
                    shape_type=sketch_mask_shape
                )

                # Set the mask for this preserved frame
                sketch_local_mask[:, :, frame_idx:frame_idx+1, :, :] = frame_masks

    # Produce sketches for the preserved frames.
    # Sketches are either 1) computed with the sketch pool or 2) loaded from the FLUX sketch directory.
    if is_flux_sketch_available is not None and is_flux_sketch_available[0]:
        should_use_flux_sketch = random.random() < 0.75 if not is_evaluation else True
    else:
        should_use_flux_sketch = False

    if should_use_flux_sketch:
        cur_model_name = "flux"
    elif model_name is None:
        cur_model_name = random.choice(["lineart", "lineart_anime", "anime2sketch"])
    else:
        cur_model_name = model_name

    for _idx in preserved_sketch_indices:
        sketch_frame = None
        if should_use_flux_sketch:
            # Load the pre-computed FLUX sketch
            sketch_path = f"/group/40005/gzhiwang/iclora/linearts/{video_clip_names[0]}/{_idx}.lineart.png"
            print(f"Loading flux sketch from {sketch_path}...")
            if os.path.exists(sketch_path):
                sketch_frame = cv2.imread(sketch_path)
                sketch_frame = cv2.cvtColor(sketch_frame, cv2.COLOR_BGR2RGB)
                # Resize to the video resolution
                sketch_frame = cv2.resize(sketch_frame, (video_shape[4], video_shape[3]))
                sketch_frame = torch.from_numpy(sketch_frame).to(video_device, dtype=video_dtype)
                # Normalize to [-1, 1]
                sketch_frame = sketch_frame / 255.0 * 2.0 - 1.0
                sketch_frame = sketch_frame.permute(2, 0, 1)
                sketch_frame = sketch_frame.unsqueeze(0)
            else:
                print(f"FLUX Sketch path {sketch_path} does not exist. Falling back to sketch pool.")
        if sketch_frame is None:
            # Compute the sketch with the sketch pool
            sketch_frame = extract_img_to_sketch(
                    sketch_model, original_video[:, :, _idx, :, :].float(),
                    model_name=cur_model_name).to(video_device, dtype=video_dtype)
        # Convert the white background (from the sketch pool or FLUX sketch files) to a black background (for training)
        sketch_frame = -torch.clip(sketch_frame, -1, 1)
        full_sketch_frames[:, :, _idx, :, :] = sketch_frame

    if len(preserved_sketch_indices) > 0:
        _mask_to_add = preserved_condition_mask_sketch if detach_image_and_sketch else preserved_condition_mask
        _video_to_add = masked_condition_video_sketch if detach_image_and_sketch else masked_condition_video
        if not detach_image_and_sketch:
            # The first-frame sketch is not merged into the image condition
            preserved_sketch_indices = preserved_sketch_indices[1:]

        # Apply masking to sketch frames if required
        if apply_sketch_mask and sketch_local_mask is not None:
            # sketch_local_mask: [b, 1, t, h, w]
            for _idx in preserved_sketch_indices:
                _mask_to_add[:, _idx] = 1.0 if detach_image_and_sketch else -1.0
                _video_to_add[:, :, _idx, :, :] = torch.where(sketch_local_mask[:, 0:1, _idx, :, :] == 0, -1.0, full_sketch_frames[:, :, _idx, :, :])
        else:
            for _idx in preserved_sketch_indices:
                _mask_to_add[:, _idx] = 1.0 if detach_image_and_sketch else -1.0
                _video_to_add[:, :, _idx, :, :] = full_sketch_frames[:, :, _idx, :, :]

    return masked_condition_video, preserved_condition_mask, masked_condition_video_sketch, preserved_condition_mask_sketch, full_sketch_frames, sketch_local_mask, cur_model_name
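As a rough illustration of how these helpers fit together at training time, the sketch below runs video_to_frame_and_sketch on a random clip with a stand-in sketch network. The real pipeline plugs in the repository's sketch pool, so the stub class, tensor sizes, and printed fields here are assumptions; extract_img_to_sketch also expects a CUDA-capable environment because of its autocast call.

    # Hypothetical usage sketch; _GraySketchStub stands in for the repo's sketch pool model.
    import torch
    from util.training_util import create_random_mask, video_to_frame_and_sketch

    class _GraySketchStub(torch.nn.Module):
        def forward(self, img, model_name=None):
            # Fake "sketch": average the RGB channels, staying in [-1, 1].
            return img.mean(dim=1, keepdim=True)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    video = torch.rand(1, 3, 16, 64, 64, device=device) * 2 - 1  # [b, c, t, h, w] in [-1, 1]
    (cond_video, cond_mask, cond_video_sketch, cond_mask_sketch,
     sketches, sketch_mask, used_model) = video_to_frame_and_sketch(
        _GraySketchStub().to(device), video,
        max_num_preserved_sketch_frames=3,
        apply_sketch_mask=True,
    )
    print(cond_video.shape, used_model, cond_mask[0].tolist())

    # create_random_mask can also be used on its own; zeros mark the masked pixels.
    masks = create_random_mask(1, 4, 64, 64, device=video.device, dtype=video.dtype)
    print(masks.shape)  # torch.Size([1, 1, 4, 64, 64])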