First commit
Browse files- LICENSE.txt +97 -0
 - README.md +63 -0
 - demo.ipynb +610 -0
 - dnnlib/__init__.py +24 -0
 - dnnlib/submission/__init__.py +8 -0
 - dnnlib/submission/internal/__init__.py +7 -0
 - dnnlib/submission/internal/local.py +22 -0
 - dnnlib/submission/run_context.py +110 -0
 - dnnlib/submission/submit.py +369 -0
 - dnnlib/tflib/__init__.py +20 -0
 - dnnlib/tflib/autosummary.py +193 -0
 - dnnlib/tflib/custom_ops.py +181 -0
 - dnnlib/tflib/network.py +825 -0
 - dnnlib/tflib/ops/__init__.py +9 -0
 - dnnlib/tflib/ops/fused_bias_act.cu +220 -0
 - dnnlib/tflib/ops/fused_bias_act.py +211 -0
 - dnnlib/tflib/ops/upfirdn_2d.cu +359 -0
 - dnnlib/tflib/ops/upfirdn_2d.py +418 -0
 - dnnlib/tflib/optimizer.py +372 -0
 - dnnlib/tflib/tfutil.py +264 -0
 - dnnlib/util.py +472 -0
 - gallery/gallery.md +15 -0
 - gallery/gl-mosaics1.png +3 -0
 - gallery/gl-mosaics10.png +3 -0
 - gallery/gl-mosaics2.png +3 -0
 - gallery/gl-mosaics3.png +3 -0
 - gallery/gl-mosaics4.png +3 -0
 - gallery/gl-mosaics5.png +3 -0
 - gallery/gl-mosaics6.png +3 -0
 - gallery/gl-mosaics7.png +3 -0
 - gallery/gl-mosaics8.png +3 -0
 - gallery/gl-mosaics9.png +3 -0
 - generate.py +700 -0
 - imgs/calligraphyv2.PNG +3 -0
 - imgs/calligraphyv3.png +3 -0
 - imgs/calligraphyv4.png +3 -0
 - imgs/calligraphyv5.png +3 -0
 - imgs/mosaic.png +3 -0
 - imgs/mosaicsv2.png +3 -0
 - imgs/mosaicsv3.png +3 -0
 - imgs/mosaicsv4.png +3 -0
 - models.py +142 -0
 - rasm.py +146 -0
 - requirements.txt +32 -0
 - utils.py +165 -0
 - video.gif +3 -0
 
    	
        LICENSE.txt
    ADDED
    
    | 
         @@ -0,0 +1,97 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            Copyright (c) 2020, NVIDIA Corporation. All rights reserved.
         
     | 
| 2 | 
         
            +
             
     | 
| 3 | 
         
            +
             
     | 
| 4 | 
         
            +
            NVIDIA Source Code License for StyleGAN2 with Adaptive Discriminator Augmentation (ADA)
         
     | 
| 5 | 
         
            +
             
     | 
| 6 | 
         
            +
             
     | 
| 7 | 
         
            +
            =======================================================================
         
     | 
| 8 | 
         
            +
             
     | 
| 9 | 
         
            +
            1. Definitions
         
     | 
| 10 | 
         
            +
             
     | 
| 11 | 
         
            +
            "Licensor" means any person or entity that distributes its Work.
         
     | 
| 12 | 
         
            +
             
     | 
| 13 | 
         
            +
            "Software" means the original work of authorship made available under
         
     | 
| 14 | 
         
            +
            this License.
         
     | 
| 15 | 
         
            +
             
     | 
| 16 | 
         
            +
            "Work" means the Software and any additions to or derivative works of
         
     | 
| 17 | 
         
            +
            the Software that are made available under this License.
         
     | 
| 18 | 
         
            +
             
     | 
| 19 | 
         
            +
            The terms "reproduce," "reproduction," "derivative works," and
         
     | 
| 20 | 
         
            +
            "distribution" have the meaning as provided under U.S. copyright law;
         
     | 
| 21 | 
         
            +
            provided, however, that for the purposes of this License, derivative
         
     | 
| 22 | 
         
            +
            works shall not include works that remain separable from, or merely
         
     | 
| 23 | 
         
            +
            link (or bind by name) to the interfaces of, the Work.
         
     | 
| 24 | 
         
            +
             
     | 
| 25 | 
         
            +
            Works, including the Software, are "made available" under this License
         
     | 
| 26 | 
         
            +
            by including in or with the Work either (a) a copyright notice
         
     | 
| 27 | 
         
            +
            referencing the applicability of this License to the Work, or (b) a
         
     | 
| 28 | 
         
            +
            copy of this License.
         
     | 
| 29 | 
         
            +
             
     | 
| 30 | 
         
            +
            2. License Grants
         
     | 
| 31 | 
         
            +
             
     | 
| 32 | 
         
            +
                2.1 Copyright Grant. Subject to the terms and conditions of this
         
     | 
| 33 | 
         
            +
                License, each Licensor grants to you a perpetual, worldwide,
         
     | 
| 34 | 
         
            +
                non-exclusive, royalty-free, copyright license to reproduce,
         
     | 
| 35 | 
         
            +
                prepare derivative works of, publicly display, publicly perform,
         
     | 
| 36 | 
         
            +
                sublicense and distribute its Work and any resulting derivative
         
     | 
| 37 | 
         
            +
                works in any form.
         
     | 
| 38 | 
         
            +
             
     | 
| 39 | 
         
            +
            3. Limitations
         
     | 
| 40 | 
         
            +
             
     | 
| 41 | 
         
            +
                3.1 Redistribution. You may reproduce or distribute the Work only
         
     | 
| 42 | 
         
            +
                if (a) you do so under this License, (b) you include a complete
         
     | 
| 43 | 
         
            +
                copy of this License with your distribution, and (c) you retain
         
     | 
| 44 | 
         
            +
                without modification any copyright, patent, trademark, or
         
     | 
| 45 | 
         
            +
                attribution notices that are present in the Work.
         
     | 
| 46 | 
         
            +
             
     | 
| 47 | 
         
            +
                3.2 Derivative Works. You may specify that additional or different
         
     | 
| 48 | 
         
            +
                terms apply to the use, reproduction, and distribution of your
         
     | 
| 49 | 
         
            +
                derivative works of the Work ("Your Terms") only if (a) Your Terms
         
     | 
| 50 | 
         
            +
                provide that the use limitation in Section 3.3 applies to your
         
     | 
| 51 | 
         
            +
                derivative works, and (b) you identify the specific derivative
         
     | 
| 52 | 
         
            +
                works that are subject to Your Terms. Notwithstanding Your Terms,
         
     | 
| 53 | 
         
            +
                this License (including the redistribution requirements in Section
         
     | 
| 54 | 
         
            +
                3.1) will continue to apply to the Work itself.
         
     | 
| 55 | 
         
            +
             
     | 
| 56 | 
         
            +
                3.3 Use Limitation. The Work and any derivative works thereof only
         
     | 
| 57 | 
         
            +
                may be used or intended for use non-commercially. Notwithstanding
         
     | 
| 58 | 
         
            +
                the foregoing, NVIDIA and its affiliates may use the Work and any
         
     | 
| 59 | 
         
            +
                derivative works commercially. As used herein, "non-commercially"
         
     | 
| 60 | 
         
            +
                means for research or evaluation purposes only.
         
     | 
| 61 | 
         
            +
             
     | 
| 62 | 
         
            +
                3.4 Patent Claims. If you bring or threaten to bring a patent claim
         
     | 
| 63 | 
         
            +
                against any Licensor (including any claim, cross-claim or
         
     | 
| 64 | 
         
            +
                counterclaim in a lawsuit) to enforce any patents that you allege
         
     | 
| 65 | 
         
            +
                are infringed by any Work, then your rights under this License from
         
     | 
| 66 | 
         
            +
                such Licensor (including the grant in Section 2.1) will terminate
         
     | 
| 67 | 
         
            +
                immediately.
         
     | 
| 68 | 
         
            +
             
     | 
| 69 | 
         
            +
                3.5 Trademarks. This License does not grant any rights to use any
         
     | 
| 70 | 
         
            +
                Licensor’s or its affiliates’ names, logos, or trademarks, except
         
     | 
| 71 | 
         
            +
                as necessary to reproduce the notices described in this License.
         
     | 
| 72 | 
         
            +
             
     | 
| 73 | 
         
            +
                3.6 Termination. If you violate any term of this License, then your
         
     | 
| 74 | 
         
            +
                rights under this License (including the grant in Section 2.1) will
         
     | 
| 75 | 
         
            +
                terminate immediately.
         
     | 
| 76 | 
         
            +
             
     | 
| 77 | 
         
            +
            4. Disclaimer of Warranty.
         
     | 
| 78 | 
         
            +
             
     | 
| 79 | 
         
            +
            THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
         
     | 
| 80 | 
         
            +
            KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
         
     | 
| 81 | 
         
            +
            MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
         
     | 
| 82 | 
         
            +
            NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
         
     | 
| 83 | 
         
            +
            THIS LICENSE.
         
     | 
| 84 | 
         
            +
             
     | 
| 85 | 
         
            +
            5. Limitation of Liability.
         
     | 
| 86 | 
         
            +
             
     | 
| 87 | 
         
            +
            EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
         
     | 
| 88 | 
         
            +
            THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
         
     | 
| 89 | 
         
            +
            SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
         
     | 
| 90 | 
         
            +
            INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
         
     | 
| 91 | 
         
            +
            OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
         
     | 
| 92 | 
         
            +
            (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
         
     | 
| 93 | 
         
            +
            LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
         
     | 
| 94 | 
         
            +
            COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
         
     | 
| 95 | 
         
            +
            THE POSSIBILITY OF SUCH DAMAGES.
         
     | 
| 96 | 
         
            +
             
     | 
| 97 | 
         
            +
            =======================================================================
         
     | 
    	
        README.md
    ADDED
    
    | 
         @@ -0,0 +1,63 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            ## rasm
         
     | 
| 2 | 
         
            +
            Arabic art using GANs. We currently have two models for generating calligraphy and mosaics.  
         
     | 
| 3 | 
         
            +
             
     | 
| 4 | 
         
            +
            ## Notebooks 
         
     | 
| 5 | 
         
            +
             
     | 
| 6 | 
         
            +
            <table class="tg">
         
     | 
| 7 | 
         
            +
              <tr>
         
     | 
| 8 | 
         
            +
                <th class="tg-yw4l"><b>Name</b></th>
         
     | 
| 9 | 
         
            +
                <th class="tg-yw4l"><b>Notebook</b></th>
         
     | 
| 10 | 
         
            +
              </tr>
         
     | 
| 11 | 
         
            +
              <tr>
         
     | 
| 12 | 
         
            +
                <td class="tg-yw4l">Visualization</td>
         
     | 
| 13 | 
         
            +
                <td class="tg-yw4l"><a href="https://colab.research.google.com/github/ARBML/rasm/blob/master/demo.ipynb">
         
     | 
| 14 | 
         
            +
              <img src="https://colab.research.google.com/assets/colab-badge.svg" width = '100px' >
         
     | 
| 15 | 
         
            +
            </a></td>
         
     | 
| 16 | 
         
            +
              </tr>
         
     | 
| 17 | 
         
            +
            </table>
         
     | 
| 18 | 
         
            +
             
     | 
| 19 | 
         
            +
            ## Visualization 
         
     | 
| 20 | 
         
            +
            A set of functions for vis, interpolation and animation. Mostly tested in colab notebooks. 
         
     | 
| 21 | 
         
            +
             
     | 
| 22 | 
         
            +
            ### Load Model 
         
     | 
| 23 | 
         
            +
            ```python 
         
     | 
| 24 | 
         
            +
            from rasm import Rasm
         
     | 
| 25 | 
         
            +
            model = Rasm(mode = 'calligraphy')
         
     | 
| 26 | 
         
            +
            model = Rasm(mode = 'mosaics')
         
     | 
| 27 | 
         
            +
            ```
         
     | 
| 28 | 
         
            +
             
     | 
| 29 | 
         
            +
            ### Generate random 
         
     | 
| 30 | 
         
            +
            ```python 
         
     | 
| 31 | 
         
            +
            model.generate_randomly()
         
     | 
| 32 | 
         
            +
            ```
         
     | 
| 33 | 
         
            +
             
     | 
| 34 | 
         
            +
            ### Generate grid 
         
     | 
| 35 | 
         
            +
            ```python 
         
     | 
| 36 | 
         
            +
            model.generate_grid()
         
     | 
| 37 | 
         
            +
            ```
         
     | 
| 38 | 
         
            +
             
     | 
| 39 | 
         
            +
            ### Generate animation 
         
     | 
| 40 | 
         
            +
            ```python
         
     | 
| 41 | 
         
            +
            model.generate_animation(size = 2, steps = 20)
         
     | 
| 42 | 
         
            +
            ```
         
     | 
| 43 | 
         
            +
             
     | 
| 44 | 
         
            +
            
         
     | 
| 45 | 
         
            +
             
     | 
| 46 | 
         
            +
            ## Sample Models 
         
     | 
| 47 | 
         
            +
             
     | 
| 48 | 
         
            +
            ### Mosaics 
         
     | 
| 49 | 
         
            +
            
         
     | 
| 50 | 
         
            +
            
         
     | 
| 51 | 
         
            +
            
         
     | 
| 52 | 
         
            +
            
         
     | 
| 53 | 
         
            +
            ### Calligraphy 
         
     | 
| 54 | 
         
            +
            
         
     | 
| 55 | 
         
            +
            
         
     | 
| 56 | 
         
            +
            
         
     | 
| 57 | 
         
            +
            
         
     | 
| 58 | 
         
            +
             
     | 
| 59 | 
         
            +
             
     | 
| 60 | 
         
            +
            ## References 
         
     | 
| 61 | 
         
            +
            - Gan-surgery: https://github.com/aydao/stylegan2-surgery
         
     | 
| 62 | 
         
            +
            - WikiArt model: https://github.com/pbaylies/stylegan2 
         
     | 
| 63 | 
         
            +
            - Starter-Notebook: https://github.com/Hephyrius/Stylegan2-Ada-Google-Colab-Starter-Notebook/
         
     | 
    	
        demo.ipynb
    ADDED
    
    | 
         @@ -0,0 +1,610 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            {
         
     | 
| 2 | 
         
            +
              "nbformat": 4,
         
     | 
| 3 | 
         
            +
              "nbformat_minor": 0,
         
     | 
| 4 | 
         
            +
              "metadata": {
         
     | 
| 5 | 
         
            +
                "colab": {
         
     | 
| 6 | 
         
            +
                  "name": "SGAN Vis.ipynb",
         
     | 
| 7 | 
         
            +
                  "provenance": [],
         
     | 
| 8 | 
         
            +
                  "machine_shape": "hm"
         
     | 
| 9 | 
         
            +
                },
         
     | 
| 10 | 
         
            +
                "kernelspec": {
         
     | 
| 11 | 
         
            +
                  "name": "python3",
         
     | 
| 12 | 
         
            +
                  "display_name": "Python 3"
         
     | 
| 13 | 
         
            +
                },
         
     | 
| 14 | 
         
            +
                "accelerator": "GPU",
         
     | 
| 15 | 
         
            +
                "widgets": {
         
     | 
| 16 | 
         
            +
                  "application/vnd.jupyter.widget-state+json": {
         
     | 
| 17 | 
         
            +
                    "edfc5ae9a3924ee6a04811ab3dec1656": {
         
     | 
| 18 | 
         
            +
                      "model_module": "@jupyter-widgets/controls",
         
     | 
| 19 | 
         
            +
                      "model_name": "VBoxModel",
         
     | 
| 20 | 
         
            +
                      "state": {
         
     | 
| 21 | 
         
            +
                        "_view_name": "VBoxView",
         
     | 
| 22 | 
         
            +
                        "_dom_classes": [],
         
     | 
| 23 | 
         
            +
                        "_model_name": "VBoxModel",
         
     | 
| 24 | 
         
            +
                        "_view_module": "@jupyter-widgets/controls",
         
     | 
| 25 | 
         
            +
                        "_model_module_version": "1.5.0",
         
     | 
| 26 | 
         
            +
                        "_view_count": null,
         
     | 
| 27 | 
         
            +
                        "_view_module_version": "1.5.0",
         
     | 
| 28 | 
         
            +
                        "box_style": "",
         
     | 
| 29 | 
         
            +
                        "layout": "IPY_MODEL_86ea189295a54dd0aa2a771a95b12f00",
         
     | 
| 30 | 
         
            +
                        "_model_module": "@jupyter-widgets/controls",
         
     | 
| 31 | 
         
            +
                        "children": [
         
     | 
| 32 | 
         
            +
                          "IPY_MODEL_61e1782fb995412fbef8e59c570e8823",
         
     | 
| 33 | 
         
            +
                          "IPY_MODEL_a03932d2929c4225b7735d0831a18566"
         
     | 
| 34 | 
         
            +
                        ]
         
     | 
| 35 | 
         
            +
                      }
         
     | 
| 36 | 
         
            +
                    },
         
     | 
| 37 | 
         
            +
                    "86ea189295a54dd0aa2a771a95b12f00": {
         
     | 
| 38 | 
         
            +
                      "model_module": "@jupyter-widgets/base",
         
     | 
| 39 | 
         
            +
                      "model_name": "LayoutModel",
         
     | 
| 40 | 
         
            +
                      "state": {
         
     | 
| 41 | 
         
            +
                        "_view_name": "LayoutView",
         
     | 
| 42 | 
         
            +
                        "grid_template_rows": null,
         
     | 
| 43 | 
         
            +
                        "right": null,
         
     | 
| 44 | 
         
            +
                        "justify_content": null,
         
     | 
| 45 | 
         
            +
                        "_view_module": "@jupyter-widgets/base",
         
     | 
| 46 | 
         
            +
                        "overflow": null,
         
     | 
| 47 | 
         
            +
                        "_model_module_version": "1.2.0",
         
     | 
| 48 | 
         
            +
                        "_view_count": null,
         
     | 
| 49 | 
         
            +
                        "flex_flow": null,
         
     | 
| 50 | 
         
            +
                        "width": null,
         
     | 
| 51 | 
         
            +
                        "min_width": null,
         
     | 
| 52 | 
         
            +
                        "border": null,
         
     | 
| 53 | 
         
            +
                        "align_items": null,
         
     | 
| 54 | 
         
            +
                        "bottom": null,
         
     | 
| 55 | 
         
            +
                        "_model_module": "@jupyter-widgets/base",
         
     | 
| 56 | 
         
            +
                        "top": null,
         
     | 
| 57 | 
         
            +
                        "grid_column": null,
         
     | 
| 58 | 
         
            +
                        "overflow_y": null,
         
     | 
| 59 | 
         
            +
                        "overflow_x": null,
         
     | 
| 60 | 
         
            +
                        "grid_auto_flow": null,
         
     | 
| 61 | 
         
            +
                        "grid_area": null,
         
     | 
| 62 | 
         
            +
                        "grid_template_columns": null,
         
     | 
| 63 | 
         
            +
                        "flex": null,
         
     | 
| 64 | 
         
            +
                        "_model_name": "LayoutModel",
         
     | 
| 65 | 
         
            +
                        "justify_items": null,
         
     | 
| 66 | 
         
            +
                        "grid_row": null,
         
     | 
| 67 | 
         
            +
                        "max_height": null,
         
     | 
| 68 | 
         
            +
                        "align_content": null,
         
     | 
| 69 | 
         
            +
                        "visibility": null,
         
     | 
| 70 | 
         
            +
                        "align_self": null,
         
     | 
| 71 | 
         
            +
                        "height": null,
         
     | 
| 72 | 
         
            +
                        "min_height": null,
         
     | 
| 73 | 
         
            +
                        "padding": null,
         
     | 
| 74 | 
         
            +
                        "grid_auto_rows": null,
         
     | 
| 75 | 
         
            +
                        "grid_gap": null,
         
     | 
| 76 | 
         
            +
                        "max_width": null,
         
     | 
| 77 | 
         
            +
                        "order": null,
         
     | 
| 78 | 
         
            +
                        "_view_module_version": "1.2.0",
         
     | 
| 79 | 
         
            +
                        "grid_template_areas": null,
         
     | 
| 80 | 
         
            +
                        "object_position": null,
         
     | 
| 81 | 
         
            +
                        "object_fit": null,
         
     | 
| 82 | 
         
            +
                        "grid_auto_columns": null,
         
     | 
| 83 | 
         
            +
                        "margin": null,
         
     | 
| 84 | 
         
            +
                        "display": null,
         
     | 
| 85 | 
         
            +
                        "left": null
         
     | 
| 86 | 
         
            +
                      }
         
     | 
| 87 | 
         
            +
                    },
         
     | 
| 88 | 
         
            +
                    "61e1782fb995412fbef8e59c570e8823": {
         
     | 
| 89 | 
         
            +
                      "model_module": "@jupyter-widgets/controls",
         
     | 
| 90 | 
         
            +
                      "model_name": "HTMLModel",
         
     | 
| 91 | 
         
            +
                      "state": {
         
     | 
| 92 | 
         
            +
                        "_view_name": "HTMLView",
         
     | 
| 93 | 
         
            +
                        "style": "IPY_MODEL_c5b6fad0fd114c90ac0f23d91546846f",
         
     | 
| 94 | 
         
            +
                        "_dom_classes": [],
         
     | 
| 95 | 
         
            +
                        "description": "",
         
     | 
| 96 | 
         
            +
                        "_model_name": "HTMLModel",
         
     | 
| 97 | 
         
            +
                        "placeholder": "",
         
     | 
| 98 | 
         
            +
                        "_view_module": "@jupyter-widgets/controls",
         
     | 
| 99 | 
         
            +
                        "_model_module_version": "1.5.0",
         
     | 
| 100 | 
         
            +
                        "value": "Generating images: 1",
         
     | 
| 101 | 
         
            +
                        "_view_count": null,
         
     | 
| 102 | 
         
            +
                        "_view_module_version": "1.5.0",
         
     | 
| 103 | 
         
            +
                        "description_tooltip": null,
         
     | 
| 104 | 
         
            +
                        "_model_module": "@jupyter-widgets/controls",
         
     | 
| 105 | 
         
            +
                        "layout": "IPY_MODEL_7dea244035a34d0f8658e75a12b5cd89"
         
     | 
| 106 | 
         
            +
                      }
         
     | 
| 107 | 
         
            +
                    },
         
     | 
| 108 | 
         
            +
                    "a03932d2929c4225b7735d0831a18566": {
         
     | 
| 109 | 
         
            +
                      "model_module": "@jupyter-widgets/controls",
         
     | 
| 110 | 
         
            +
                      "model_name": "IntProgressModel",
         
     | 
| 111 | 
         
            +
                      "state": {
         
     | 
| 112 | 
         
            +
                        "_view_name": "ProgressView",
         
     | 
| 113 | 
         
            +
                        "style": "IPY_MODEL_13d6a5cf5103499ebfb40e2e9c520d27",
         
     | 
| 114 | 
         
            +
                        "_dom_classes": [],
         
     | 
| 115 | 
         
            +
                        "description": "",
         
     | 
| 116 | 
         
            +
                        "_model_name": "IntProgressModel",
         
     | 
| 117 | 
         
            +
                        "bar_style": "success",
         
     | 
| 118 | 
         
            +
                        "max": 1,
         
     | 
| 119 | 
         
            +
                        "_view_module": "@jupyter-widgets/controls",
         
     | 
| 120 | 
         
            +
                        "_model_module_version": "1.5.0",
         
     | 
| 121 | 
         
            +
                        "value": 1,
         
     | 
| 122 | 
         
            +
                        "_view_count": null,
         
     | 
| 123 | 
         
            +
                        "_view_module_version": "1.5.0",
         
     | 
| 124 | 
         
            +
                        "orientation": "horizontal",
         
     | 
| 125 | 
         
            +
                        "min": 0,
         
     | 
| 126 | 
         
            +
                        "description_tooltip": null,
         
     | 
| 127 | 
         
            +
                        "_model_module": "@jupyter-widgets/controls",
         
     | 
| 128 | 
         
            +
                        "layout": "IPY_MODEL_653de804a191499ca46a8eb6b91e7a6b"
         
     | 
| 129 | 
         
            +
                      }
         
     | 
| 130 | 
         
            +
                    },
         
     | 
| 131 | 
         
            +
                    "c5b6fad0fd114c90ac0f23d91546846f": {
         
     | 
| 132 | 
         
            +
                      "model_module": "@jupyter-widgets/controls",
         
     | 
| 133 | 
         
            +
                      "model_name": "DescriptionStyleModel",
         
     | 
| 134 | 
         
            +
                      "state": {
         
     | 
| 135 | 
         
            +
                        "_view_name": "StyleView",
         
     | 
| 136 | 
         
            +
                        "_model_name": "DescriptionStyleModel",
         
     | 
| 137 | 
         
            +
                        "description_width": "",
         
     | 
| 138 | 
         
            +
                        "_view_module": "@jupyter-widgets/base",
         
     | 
| 139 | 
         
            +
                        "_model_module_version": "1.5.0",
         
     | 
| 140 | 
         
            +
                        "_view_count": null,
         
     | 
| 141 | 
         
            +
                        "_view_module_version": "1.2.0",
         
     | 
| 142 | 
         
            +
                        "_model_module": "@jupyter-widgets/controls"
         
     | 
| 143 | 
         
            +
                      }
         
     | 
| 144 | 
         
            +
                    },
         
     | 
| 145 | 
         
            +
                    "7dea244035a34d0f8658e75a12b5cd89": {
         
     | 
| 146 | 
         
            +
                      "model_module": "@jupyter-widgets/base",
         
     | 
| 147 | 
         
            +
                      "model_name": "LayoutModel",
         
     | 
| 148 | 
         
            +
                      "state": {
         
     | 
| 149 | 
         
            +
                        "_view_name": "LayoutView",
         
     | 
| 150 | 
         
            +
                        "grid_template_rows": null,
         
     | 
| 151 | 
         
            +
                        "right": null,
         
     | 
| 152 | 
         
            +
                        "justify_content": null,
         
     | 
| 153 | 
         
            +
                        "_view_module": "@jupyter-widgets/base",
         
     | 
| 154 | 
         
            +
                        "overflow": null,
         
     | 
| 155 | 
         
            +
                        "_model_module_version": "1.2.0",
         
     | 
| 156 | 
         
            +
                        "_view_count": null,
         
     | 
| 157 | 
         
            +
                        "flex_flow": null,
         
     | 
| 158 | 
         
            +
                        "width": null,
         
     | 
| 159 | 
         
            +
                        "min_width": null,
         
     | 
| 160 | 
         
            +
                        "border": null,
         
     | 
| 161 | 
         
            +
                        "align_items": null,
         
     | 
| 162 | 
         
            +
                        "bottom": null,
         
     | 
| 163 | 
         
            +
                        "_model_module": "@jupyter-widgets/base",
         
     | 
| 164 | 
         
            +
                        "top": null,
         
     | 
| 165 | 
         
            +
                        "grid_column": null,
         
     | 
| 166 | 
         
            +
                        "overflow_y": null,
         
     | 
| 167 | 
         
            +
                        "overflow_x": null,
         
     | 
| 168 | 
         
            +
                        "grid_auto_flow": null,
         
     | 
| 169 | 
         
            +
                        "grid_area": null,
         
     | 
| 170 | 
         
            +
                        "grid_template_columns": null,
         
     | 
| 171 | 
         
            +
                        "flex": null,
         
     | 
| 172 | 
         
            +
                        "_model_name": "LayoutModel",
         
     | 
| 173 | 
         
            +
                        "justify_items": null,
         
     | 
| 174 | 
         
            +
                        "grid_row": null,
         
     | 
| 175 | 
         
            +
                        "max_height": null,
         
     | 
| 176 | 
         
            +
                        "align_content": null,
         
     | 
| 177 | 
         
            +
                        "visibility": null,
         
     | 
| 178 | 
         
            +
                        "align_self": null,
         
     | 
| 179 | 
         
            +
                        "height": null,
         
     | 
| 180 | 
         
            +
                        "min_height": null,
         
     | 
| 181 | 
         
            +
                        "padding": null,
         
     | 
| 182 | 
         
            +
                        "grid_auto_rows": null,
         
     | 
| 183 | 
         
            +
                        "grid_gap": null,
         
     | 
| 184 | 
         
            +
                        "max_width": null,
         
     | 
| 185 | 
         
            +
                        "order": null,
         
     | 
| 186 | 
         
            +
                        "_view_module_version": "1.2.0",
         
     | 
| 187 | 
         
            +
                        "grid_template_areas": null,
         
     | 
| 188 | 
         
            +
                        "object_position": null,
         
     | 
| 189 | 
         
            +
                        "object_fit": null,
         
     | 
| 190 | 
         
            +
                        "grid_auto_columns": null,
         
     | 
| 191 | 
         
            +
                        "margin": null,
         
     | 
| 192 | 
         
            +
                        "display": null,
         
     | 
| 193 | 
         
            +
                        "left": null
         
     | 
| 194 | 
         
            +
                      }
         
     | 
| 195 | 
         
            +
                    },
         
     | 
| 196 | 
         
            +
                    "13d6a5cf5103499ebfb40e2e9c520d27": {
         
     | 
| 197 | 
         
            +
                      "model_module": "@jupyter-widgets/controls",
         
     | 
| 198 | 
         
            +
                      "model_name": "ProgressStyleModel",
         
     | 
| 199 | 
         
            +
                      "state": {
         
     | 
| 200 | 
         
            +
                        "_view_name": "StyleView",
         
     | 
| 201 | 
         
            +
                        "_model_name": "ProgressStyleModel",
         
     | 
| 202 | 
         
            +
                        "description_width": "",
         
     | 
| 203 | 
         
            +
                        "_view_module": "@jupyter-widgets/base",
         
     | 
| 204 | 
         
            +
                        "_model_module_version": "1.5.0",
         
     | 
| 205 | 
         
            +
                        "_view_count": null,
         
     | 
| 206 | 
         
            +
                        "_view_module_version": "1.2.0",
         
     | 
| 207 | 
         
            +
                        "bar_color": null,
         
     | 
| 208 | 
         
            +
                        "_model_module": "@jupyter-widgets/controls"
         
     | 
| 209 | 
         
            +
                      }
         
     | 
| 210 | 
         
            +
                    },
         
     | 
| 211 | 
         
            +
                    "653de804a191499ca46a8eb6b91e7a6b": {
         
     | 
| 212 | 
         
            +
                      "model_module": "@jupyter-widgets/base",
         
     | 
| 213 | 
         
            +
                      "model_name": "LayoutModel",
         
     | 
| 214 | 
         
            +
                      "state": {
         
     | 
| 215 | 
         
            +
                        "_view_name": "LayoutView",
         
     | 
| 216 | 
         
            +
                        "grid_template_rows": null,
         
     | 
| 217 | 
         
            +
                        "right": null,
         
     | 
| 218 | 
         
            +
                        "justify_content": null,
         
     | 
| 219 | 
         
            +
                        "_view_module": "@jupyter-widgets/base",
         
     | 
| 220 | 
         
            +
                        "overflow": null,
         
     | 
| 221 | 
         
            +
                        "_model_module_version": "1.2.0",
         
     | 
| 222 | 
         
            +
                        "_view_count": null,
         
     | 
| 223 | 
         
            +
                        "flex_flow": null,
         
     | 
| 224 | 
         
            +
                        "width": null,
         
     | 
| 225 | 
         
            +
                        "min_width": null,
         
     | 
| 226 | 
         
            +
                        "border": null,
         
     | 
| 227 | 
         
            +
                        "align_items": null,
         
     | 
| 228 | 
         
            +
                        "bottom": null,
         
     | 
| 229 | 
         
            +
                        "_model_module": "@jupyter-widgets/base",
         
     | 
| 230 | 
         
            +
                        "top": null,
         
     | 
| 231 | 
         
            +
                        "grid_column": null,
         
     | 
| 232 | 
         
            +
                        "overflow_y": null,
         
     | 
| 233 | 
         
            +
                        "overflow_x": null,
         
     | 
| 234 | 
         
            +
                        "grid_auto_flow": null,
         
     | 
| 235 | 
         
            +
                        "grid_area": null,
         
     | 
| 236 | 
         
            +
                        "grid_template_columns": null,
         
     | 
| 237 | 
         
            +
                        "flex": null,
         
     | 
| 238 | 
         
            +
                        "_model_name": "LayoutModel",
         
     | 
| 239 | 
         
            +
                        "justify_items": null,
         
     | 
| 240 | 
         
            +
                        "grid_row": null,
         
     | 
| 241 | 
         
            +
                        "max_height": null,
         
     | 
| 242 | 
         
            +
                        "align_content": null,
         
     | 
| 243 | 
         
            +
                        "visibility": null,
         
     | 
| 244 | 
         
            +
                        "align_self": null,
         
     | 
| 245 | 
         
            +
                        "height": null,
         
     | 
| 246 | 
         
            +
                        "min_height": null,
         
     | 
| 247 | 
         
            +
                        "padding": null,
         
     | 
| 248 | 
         
            +
                        "grid_auto_rows": null,
         
     | 
| 249 | 
         
            +
                        "grid_gap": null,
         
     | 
| 250 | 
         
            +
                        "max_width": null,
         
     | 
| 251 | 
         
            +
                        "order": null,
         
     | 
| 252 | 
         
            +
                        "_view_module_version": "1.2.0",
         
     | 
| 253 | 
         
            +
                        "grid_template_areas": null,
         
     | 
| 254 | 
         
            +
                        "object_position": null,
         
     | 
| 255 | 
         
            +
                        "object_fit": null,
         
     | 
| 256 | 
         
            +
                        "grid_auto_columns": null,
         
     | 
| 257 | 
         
            +
                        "margin": null,
         
     | 
| 258 | 
         
            +
                        "display": null,
         
     | 
| 259 | 
         
            +
                        "left": null
         
     | 
| 260 | 
         
            +
                      }
         
     | 
| 261 | 
         
            +
                    },
         
     | 
| 262 | 
         
            +
                    "ee14e110bec24331816d2243cdde5e37": {
         
     | 
| 263 | 
         
            +
                      "model_module": "@jupyter-widgets/controls",
         
     | 
| 264 | 
         
            +
                      "model_name": "VBoxModel",
         
     | 
| 265 | 
         
            +
                      "state": {
         
     | 
| 266 | 
         
            +
                        "_view_name": "VBoxView",
         
     | 
| 267 | 
         
            +
                        "_dom_classes": [],
         
     | 
| 268 | 
         
            +
                        "_model_name": "VBoxModel",
         
     | 
| 269 | 
         
            +
                        "_view_module": "@jupyter-widgets/controls",
         
     | 
| 270 | 
         
            +
                        "_model_module_version": "1.5.0",
         
     | 
| 271 | 
         
            +
                        "_view_count": null,
         
     | 
| 272 | 
         
            +
                        "_view_module_version": "1.5.0",
         
     | 
| 273 | 
         
            +
                        "box_style": "",
         
     | 
| 274 | 
         
            +
                        "layout": "IPY_MODEL_c3fbec3f54b54293a667dfe43c5aed56",
         
     | 
| 275 | 
         
            +
                        "_model_module": "@jupyter-widgets/controls",
         
     | 
| 276 | 
         
            +
                        "children": [
         
     | 
| 277 | 
         
            +
                          "IPY_MODEL_93630fd11e7f47fa9ef23b1057ec1dd9",
         
     | 
| 278 | 
         
            +
                          "IPY_MODEL_aa9410600ce04b5592f4d2a7b4531656"
         
     | 
| 279 | 
         
            +
                        ]
         
     | 
| 280 | 
         
            +
                      }
         
     | 
| 281 | 
         
            +
                    },
         
     | 
| 282 | 
         
            +
                    "c3fbec3f54b54293a667dfe43c5aed56": {
         
     | 
| 283 | 
         
            +
                      "model_module": "@jupyter-widgets/base",
         
     | 
| 284 | 
         
            +
                      "model_name": "LayoutModel",
         
     | 
| 285 | 
         
            +
                      "state": {
         
     | 
| 286 | 
         
            +
                        "_view_name": "LayoutView",
         
     | 
| 287 | 
         
            +
                        "grid_template_rows": null,
         
     | 
| 288 | 
         
            +
                        "right": null,
         
     | 
| 289 | 
         
            +
                        "justify_content": null,
         
     | 
| 290 | 
         
            +
                        "_view_module": "@jupyter-widgets/base",
         
     | 
| 291 | 
         
            +
                        "overflow": null,
         
     | 
| 292 | 
         
            +
                        "_model_module_version": "1.2.0",
         
     | 
| 293 | 
         
            +
                        "_view_count": null,
         
     | 
| 294 | 
         
            +
                        "flex_flow": null,
         
     | 
| 295 | 
         
            +
                        "width": null,
         
     | 
| 296 | 
         
            +
                        "min_width": null,
         
     | 
| 297 | 
         
            +
                        "border": null,
         
     | 
| 298 | 
         
            +
                        "align_items": null,
         
     | 
| 299 | 
         
            +
                        "bottom": null,
         
     | 
| 300 | 
         
            +
                        "_model_module": "@jupyter-widgets/base",
         
     | 
| 301 | 
         
            +
                        "top": null,
         
     | 
| 302 | 
         
            +
                        "grid_column": null,
         
     | 
| 303 | 
         
            +
                        "overflow_y": null,
         
     | 
| 304 | 
         
            +
                        "overflow_x": null,
         
     | 
| 305 | 
         
            +
                        "grid_auto_flow": null,
         
     | 
| 306 | 
         
            +
                        "grid_area": null,
         
     | 
| 307 | 
         
            +
                        "grid_template_columns": null,
         
     | 
| 308 | 
         
            +
                        "flex": null,
         
     | 
| 309 | 
         
            +
                        "_model_name": "LayoutModel",
         
     | 
| 310 | 
         
            +
                        "justify_items": null,
         
     | 
| 311 | 
         
            +
                        "grid_row": null,
         
     | 
| 312 | 
         
            +
                        "max_height": null,
         
     | 
| 313 | 
         
            +
                        "align_content": null,
         
     | 
| 314 | 
         
            +
                        "visibility": null,
         
     | 
| 315 | 
         
            +
                        "align_self": null,
         
     | 
| 316 | 
         
            +
                        "height": null,
         
     | 
| 317 | 
         
            +
                        "min_height": null,
         
     | 
| 318 | 
         
            +
                        "padding": null,
         
     | 
| 319 | 
         
            +
                        "grid_auto_rows": null,
         
     | 
| 320 | 
         
            +
                        "grid_gap": null,
         
     | 
| 321 | 
         
            +
                        "max_width": null,
         
     | 
| 322 | 
         
            +
                        "order": null,
         
     | 
| 323 | 
         
            +
                        "_view_module_version": "1.2.0",
         
     | 
| 324 | 
         
            +
                        "grid_template_areas": null,
         
     | 
| 325 | 
         
            +
                        "object_position": null,
         
     | 
| 326 | 
         
            +
                        "object_fit": null,
         
     | 
| 327 | 
         
            +
                        "grid_auto_columns": null,
         
     | 
| 328 | 
         
            +
                        "margin": null,
         
     | 
| 329 | 
         
            +
                        "display": null,
         
     | 
| 330 | 
         
            +
                        "left": null
         
     | 
| 331 | 
         
            +
                      }
         
     | 
| 332 | 
         
            +
                    },
         
     | 
| 333 | 
         
            +
                    "93630fd11e7f47fa9ef23b1057ec1dd9": {
         
     | 
| 334 | 
         
            +
                      "model_module": "@jupyter-widgets/controls",
         
     | 
| 335 | 
         
            +
                      "model_name": "HTMLModel",
         
     | 
| 336 | 
         
            +
                      "state": {
         
     | 
| 337 | 
         
            +
                        "_view_name": "HTMLView",
         
     | 
| 338 | 
         
            +
                        "style": "IPY_MODEL_af65853e39c04163a6173f438cc0ae71",
         
     | 
| 339 | 
         
            +
                        "_dom_classes": [],
         
     | 
| 340 | 
         
            +
                        "description": "",
         
     | 
| 341 | 
         
            +
                        "_model_name": "HTMLModel",
         
     | 
| 342 | 
         
            +
                        "placeholder": "",
         
     | 
| 343 | 
         
            +
                        "_view_module": "@jupyter-widgets/controls",
         
     | 
| 344 | 
         
            +
                        "_model_module_version": "1.5.0",
         
     | 
| 345 | 
         
            +
                        "value": "Generating images: 9",
         
     | 
| 346 | 
         
            +
                        "_view_count": null,
         
     | 
| 347 | 
         
            +
                        "_view_module_version": "1.5.0",
         
     | 
| 348 | 
         
            +
                        "description_tooltip": null,
         
     | 
| 349 | 
         
            +
                        "_model_module": "@jupyter-widgets/controls",
         
     | 
| 350 | 
         
            +
                        "layout": "IPY_MODEL_2bb901e39be5406bb1f495030180732b"
         
     | 
| 351 | 
         
            +
                      }
         
     | 
| 352 | 
         
            +
                    },
         
     | 
| 353 | 
         
            +
                    "aa9410600ce04b5592f4d2a7b4531656": {
         
     | 
| 354 | 
         
            +
                      "model_module": "@jupyter-widgets/controls",
         
     | 
| 355 | 
         
            +
                      "model_name": "IntProgressModel",
         
     | 
| 356 | 
         
            +
                      "state": {
         
     | 
| 357 | 
         
            +
                        "_view_name": "ProgressView",
         
     | 
| 358 | 
         
            +
                        "style": "IPY_MODEL_f2b4a35ed75343aab26390a8c26b6e0f",
         
     | 
| 359 | 
         
            +
                        "_dom_classes": [],
         
     | 
| 360 | 
         
            +
                        "description": "",
         
     | 
| 361 | 
         
            +
                        "_model_name": "IntProgressModel",
         
     | 
| 362 | 
         
            +
                        "bar_style": "success",
         
     | 
| 363 | 
         
            +
                        "max": 9,
         
     | 
| 364 | 
         
            +
                        "_view_module": "@jupyter-widgets/controls",
         
     | 
| 365 | 
         
            +
                        "_model_module_version": "1.5.0",
         
     | 
| 366 | 
         
            +
                        "value": 9,
         
     | 
| 367 | 
         
            +
                        "_view_count": null,
         
     | 
| 368 | 
         
            +
                        "_view_module_version": "1.5.0",
         
     | 
| 369 | 
         
            +
                        "orientation": "horizontal",
         
     | 
| 370 | 
         
            +
                        "min": 0,
         
     | 
| 371 | 
         
            +
                        "description_tooltip": null,
         
     | 
| 372 | 
         
            +
                        "_model_module": "@jupyter-widgets/controls",
         
     | 
| 373 | 
         
            +
                        "layout": "IPY_MODEL_df1e1c500f9c42e09fd96d7cb6cca316"
         
     | 
| 374 | 
         
            +
                      }
         
     | 
| 375 | 
         
            +
                    },
         
     | 
| 376 | 
         
            +
                    "af65853e39c04163a6173f438cc0ae71": {
         
     | 
| 377 | 
         
            +
                      "model_module": "@jupyter-widgets/controls",
         
     | 
| 378 | 
         
            +
                      "model_name": "DescriptionStyleModel",
         
     | 
| 379 | 
         
            +
                      "state": {
         
     | 
| 380 | 
         
            +
                        "_view_name": "StyleView",
         
     | 
| 381 | 
         
            +
                        "_model_name": "DescriptionStyleModel",
         
     | 
| 382 | 
         
            +
                        "description_width": "",
         
     | 
| 383 | 
         
            +
                        "_view_module": "@jupyter-widgets/base",
         
     | 
| 384 | 
         
            +
                        "_model_module_version": "1.5.0",
         
     | 
| 385 | 
         
            +
                        "_view_count": null,
         
     | 
| 386 | 
         
            +
                        "_view_module_version": "1.2.0",
         
     | 
| 387 | 
         
            +
                        "_model_module": "@jupyter-widgets/controls"
         
     | 
| 388 | 
         
            +
                      }
         
     | 
| 389 | 
         
            +
                    },
         
     | 
| 390 | 
         
            +
                    "2bb901e39be5406bb1f495030180732b": {
         
     | 
| 391 | 
         
            +
                      "model_module": "@jupyter-widgets/base",
         
     | 
| 392 | 
         
            +
                      "model_name": "LayoutModel",
         
     | 
| 393 | 
         
            +
                      "state": {
         
     | 
| 394 | 
         
            +
                        "_view_name": "LayoutView",
         
     | 
| 395 | 
         
            +
                        "grid_template_rows": null,
         
     | 
| 396 | 
         
            +
                        "right": null,
         
     | 
| 397 | 
         
            +
                        "justify_content": null,
         
     | 
| 398 | 
         
            +
                        "_view_module": "@jupyter-widgets/base",
         
     | 
| 399 | 
         
            +
                        "overflow": null,
         
     | 
| 400 | 
         
            +
                        "_model_module_version": "1.2.0",
         
     | 
| 401 | 
         
            +
                        "_view_count": null,
         
     | 
| 402 | 
         
            +
                        "flex_flow": null,
         
     | 
| 403 | 
         
            +
                        "width": null,
         
     | 
| 404 | 
         
            +
                        "min_width": null,
         
     | 
| 405 | 
         
            +
                        "border": null,
         
     | 
| 406 | 
         
            +
                        "align_items": null,
         
     | 
| 407 | 
         
            +
                        "bottom": null,
         
     | 
| 408 | 
         
            +
                        "_model_module": "@jupyter-widgets/base",
         
     | 
| 409 | 
         
            +
                        "top": null,
         
     | 
| 410 | 
         
            +
                        "grid_column": null,
         
     | 
| 411 | 
         
            +
                        "overflow_y": null,
         
     | 
| 412 | 
         
            +
                        "overflow_x": null,
         
     | 
| 413 | 
         
            +
                        "grid_auto_flow": null,
         
     | 
| 414 | 
         
            +
                        "grid_area": null,
         
     | 
| 415 | 
         
            +
                        "grid_template_columns": null,
         
     | 
| 416 | 
         
            +
                        "flex": null,
         
     | 
| 417 | 
         
            +
                        "_model_name": "LayoutModel",
         
     | 
| 418 | 
         
            +
                        "justify_items": null,
         
     | 
| 419 | 
         
            +
                        "grid_row": null,
         
     | 
| 420 | 
         
            +
                        "max_height": null,
         
     | 
| 421 | 
         
            +
                        "align_content": null,
         
     | 
| 422 | 
         
            +
                        "visibility": null,
         
     | 
| 423 | 
         
            +
                        "align_self": null,
         
     | 
| 424 | 
         
            +
                        "height": null,
         
     | 
| 425 | 
         
            +
                        "min_height": null,
         
     | 
| 426 | 
         
            +
                        "padding": null,
         
     | 
| 427 | 
         
            +
                        "grid_auto_rows": null,
         
     | 
| 428 | 
         
            +
                        "grid_gap": null,
         
     | 
| 429 | 
         
            +
                        "max_width": null,
         
     | 
| 430 | 
         
            +
                        "order": null,
         
     | 
| 431 | 
         
            +
                        "_view_module_version": "1.2.0",
         
     | 
| 432 | 
         
            +
                        "grid_template_areas": null,
         
     | 
| 433 | 
         
            +
                        "object_position": null,
         
     | 
| 434 | 
         
            +
                        "object_fit": null,
         
     | 
| 435 | 
         
            +
                        "grid_auto_columns": null,
         
     | 
| 436 | 
         
            +
                        "margin": null,
         
     | 
| 437 | 
         
            +
                        "display": null,
         
     | 
| 438 | 
         
            +
                        "left": null
         
     | 
| 439 | 
         
            +
                      }
         
     | 
| 440 | 
         
            +
                    },
         
     | 
| 441 | 
         
            +
                    "f2b4a35ed75343aab26390a8c26b6e0f": {
         
     | 
| 442 | 
         
            +
                      "model_module": "@jupyter-widgets/controls",
         
     | 
| 443 | 
         
            +
                      "model_name": "ProgressStyleModel",
         
     | 
| 444 | 
         
            +
                      "state": {
         
     | 
| 445 | 
         
            +
                        "_view_name": "StyleView",
         
     | 
| 446 | 
         
            +
                        "_model_name": "ProgressStyleModel",
         
     | 
| 447 | 
         
            +
                        "description_width": "",
         
     | 
| 448 | 
         
            +
                        "_view_module": "@jupyter-widgets/base",
         
     | 
| 449 | 
         
            +
                        "_model_module_version": "1.5.0",
         
     | 
| 450 | 
         
            +
                        "_view_count": null,
         
     | 
| 451 | 
         
            +
                        "_view_module_version": "1.2.0",
         
     | 
| 452 | 
         
            +
                        "bar_color": null,
         
     | 
| 453 | 
         
            +
                        "_model_module": "@jupyter-widgets/controls"
         
     | 
| 454 | 
         
            +
                      }
         
     | 
| 455 | 
         
            +
                    },
         
     | 
| 456 | 
         
            +
                    "df1e1c500f9c42e09fd96d7cb6cca316": {
         
     | 
| 457 | 
         
            +
                      "model_module": "@jupyter-widgets/base",
         
     | 
| 458 | 
         
            +
                      "model_name": "LayoutModel",
         
     | 
| 459 | 
         
            +
                      "state": {
         
     | 
| 460 | 
         
            +
                        "_view_name": "LayoutView",
         
     | 
| 461 | 
         
            +
                        "grid_template_rows": null,
         
     | 
| 462 | 
         
            +
                        "right": null,
         
     | 
| 463 | 
         
            +
                        "justify_content": null,
         
     | 
| 464 | 
         
            +
                        "_view_module": "@jupyter-widgets/base",
         
     | 
| 465 | 
         
            +
                        "overflow": null,
         
     | 
| 466 | 
         
            +
                        "_model_module_version": "1.2.0",
         
     | 
| 467 | 
         
            +
                        "_view_count": null,
         
     | 
| 468 | 
         
            +
                        "flex_flow": null,
         
     | 
| 469 | 
         
            +
                        "width": null,
         
     | 
| 470 | 
         
            +
                        "min_width": null,
         
     | 
| 471 | 
         
            +
                        "border": null,
         
     | 
| 472 | 
         
            +
                        "align_items": null,
         
     | 
| 473 | 
         
            +
                        "bottom": null,
         
     | 
| 474 | 
         
            +
                        "_model_module": "@jupyter-widgets/base",
         
     | 
| 475 | 
         
            +
                        "top": null,
         
     | 
| 476 | 
         
            +
                        "grid_column": null,
         
     | 
| 477 | 
         
            +
                        "overflow_y": null,
         
     | 
| 478 | 
         
            +
                        "overflow_x": null,
         
     | 
| 479 | 
         
            +
                        "grid_auto_flow": null,
         
     | 
| 480 | 
         
            +
                        "grid_area": null,
         
     | 
| 481 | 
         
            +
                        "grid_template_columns": null,
         
     | 
| 482 | 
         
            +
                        "flex": null,
         
     | 
| 483 | 
         
            +
                        "_model_name": "LayoutModel",
         
     | 
| 484 | 
         
            +
                        "justify_items": null,
         
     | 
| 485 | 
         
            +
                        "grid_row": null,
         
     | 
| 486 | 
         
            +
                        "max_height": null,
         
     | 
| 487 | 
         
            +
                        "align_content": null,
         
     | 
| 488 | 
         
            +
                        "visibility": null,
         
     | 
| 489 | 
         
            +
                        "align_self": null,
         
     | 
| 490 | 
         
            +
                        "height": null,
         
     | 
| 491 | 
         
            +
                        "min_height": null,
         
     | 
| 492 | 
         
            +
                        "padding": null,
         
     | 
| 493 | 
         
            +
                        "grid_auto_rows": null,
         
     | 
| 494 | 
         
            +
                        "grid_gap": null,
         
     | 
| 495 | 
         
            +
                        "max_width": null,
         
     | 
| 496 | 
         
            +
                        "order": null,
         
     | 
| 497 | 
         
            +
                        "_view_module_version": "1.2.0",
         
     | 
| 498 | 
         
            +
                        "grid_template_areas": null,
         
     | 
| 499 | 
         
            +
                        "object_position": null,
         
     | 
| 500 | 
         
            +
                        "object_fit": null,
         
     | 
| 501 | 
         
            +
                        "grid_auto_columns": null,
         
     | 
| 502 | 
         
            +
                        "margin": null,
         
     | 
| 503 | 
         
            +
                        "display": null,
         
     | 
| 504 | 
         
            +
                        "left": null
         
     | 
| 505 | 
         
            +
                      }
         
     | 
| 506 | 
         
            +
                    }
         
     | 
| 507 | 
         
            +
                  }
         
     | 
| 508 | 
         
            +
                }
         
     | 
| 509 | 
         
            +
              },
         
     | 
| 510 | 
         
            +
              "cells": [
         
     | 
| 511 | 
         
            +
                {
         
     | 
| 512 | 
         
            +
                  "cell_type": "code",
         
     | 
| 513 | 
         
            +
                  "metadata": {
         
     | 
| 514 | 
         
            +
                    "id": "QvYDzQccgMg_"
         
     | 
| 515 | 
         
            +
                  },
         
     | 
| 516 | 
         
            +
                  "source": [
         
     | 
| 517 | 
         
            +
                    "%tensorflow_version 1.x\r\n",
         
     | 
| 518 | 
         
            +
                    "import tensorflow as tf"
         
     | 
| 519 | 
         
            +
                  ],
         
     | 
| 520 | 
         
            +
                  "execution_count": null,
         
     | 
| 521 | 
         
            +
                  "outputs": []
         
     | 
| 522 | 
         
            +
                },
         
     | 
| 523 | 
         
            +
                {
         
     | 
| 524 | 
         
            +
                  "cell_type": "code",
         
     | 
| 525 | 
         
            +
                  "metadata": {
         
     | 
| 526 | 
         
            +
                    "colab": {
         
     | 
| 527 | 
         
            +
                      "base_uri": "https://localhost:8080/"
         
     | 
| 528 | 
         
            +
                    },
         
     | 
| 529 | 
         
            +
                    "id": "y8VaukPJgclY",
         
     | 
| 530 | 
         
            +
                    "outputId": "56ac601b-2cba-427e-c9bb-860d583c1cf6"
         
     | 
| 531 | 
         
            +
                  },
         
     | 
| 532 | 
         
            +
                  "source": [
         
     | 
| 533 | 
         
            +
                    "%cd /content\n",
         
     | 
| 534 | 
         
            +
                    "!rm -rf /content/rasm\n",
         
     | 
| 535 | 
         
            +
                    "!git clone https://github.com/ARBML/rasm\n",
         
     | 
| 536 | 
         
            +
                    "%cd rasm"
         
     | 
| 537 | 
         
            +
                  ],
         
     | 
| 538 | 
         
            +
                  "execution_count": null,
         
     | 
| 539 | 
         
            +
                  "outputs": []
         
     | 
| 540 | 
         
            +
                },
         
     | 
| 541 | 
         
            +
                {
         
     | 
| 542 | 
         
            +
                  "cell_type": "code",
         
     | 
| 543 | 
         
            +
                  "metadata": {
         
     | 
| 544 | 
         
            +
                    "colab": {
         
     | 
| 545 | 
         
            +
                      "base_uri": "https://localhost:8080/"
         
     | 
| 546 | 
         
            +
                    },
         
     | 
| 547 | 
         
            +
                    "id": "BPyug4mhnEEz",
         
     | 
| 548 | 
         
            +
                    "outputId": "ad0b269c-23f8-4376-8830-9f9d0541b6c8"
         
     | 
| 549 | 
         
            +
                  },
         
     | 
| 550 | 
         
            +
                  "source": [
         
     | 
| 551 | 
         
            +
                    "from rasm import Rasm\n",
         
     | 
| 552 | 
         
            +
                    "model = Rasm(mode = 'calligraphy')"
         
     | 
| 553 | 
         
            +
                  ],
         
     | 
| 554 | 
         
            +
                  "execution_count": null,
         
     | 
| 555 | 
         
            +
                  "outputs": []
         
     | 
| 556 | 
         
            +
                },
         
     | 
| 557 | 
         
            +
                {
         
     | 
| 558 | 
         
            +
                  "cell_type": "code",
         
     | 
| 559 | 
         
            +
                  "metadata": {
         
     | 
| 560 | 
         
            +
                    "colab": {
         
     | 
| 561 | 
         
            +
                      "base_uri": "https://localhost:8080/",
         
     | 
| 562 | 
         
            +
                      "height": 919,
         
     | 
| 563 | 
         
            +
                      "referenced_widgets": [
         
     | 
| 564 | 
         
            +
                        "edfc5ae9a3924ee6a04811ab3dec1656",
         
     | 
| 565 | 
         
            +
                        "86ea189295a54dd0aa2a771a95b12f00",
         
     | 
| 566 | 
         
            +
                        "61e1782fb995412fbef8e59c570e8823",
         
     | 
| 567 | 
         
            +
                        "a03932d2929c4225b7735d0831a18566",
         
     | 
| 568 | 
         
            +
                        "c5b6fad0fd114c90ac0f23d91546846f",
         
     | 
| 569 | 
         
            +
                        "7dea244035a34d0f8658e75a12b5cd89",
         
     | 
| 570 | 
         
            +
                        "13d6a5cf5103499ebfb40e2e9c520d27",
         
     | 
| 571 | 
         
            +
                        "653de804a191499ca46a8eb6b91e7a6b"
         
     | 
| 572 | 
         
            +
                      ]
         
     | 
| 573 | 
         
            +
                    },
         
     | 
| 574 | 
         
            +
                    "id": "e5mniebmwJiy",
         
     | 
| 575 | 
         
            +
                    "outputId": "cf409647-73e5-479e-d711-c3f914721ee5"
         
     | 
| 576 | 
         
            +
                  },
         
     | 
| 577 | 
         
            +
                  "source": [
         
     | 
| 578 | 
         
            +
                    "model.generate_randomly()"
         
     | 
| 579 | 
         
            +
                  ],
         
     | 
| 580 | 
         
            +
                  "execution_count": null,
         
     | 
| 581 | 
         
            +
                  "outputs": []
         
     | 
| 582 | 
         
            +
                },
         
     | 
| 583 | 
         
            +
                {
         
     | 
| 584 | 
         
            +
                  "cell_type": "code",
         
     | 
| 585 | 
         
            +
                  "metadata": {
         
     | 
| 586 | 
         
            +
                    "colab": {
         
     | 
| 587 | 
         
            +
                      "base_uri": "https://localhost:8080/",
         
     | 
| 588 | 
         
            +
                      "height": 919,
         
     | 
| 589 | 
         
            +
                      "referenced_widgets": [
         
     | 
| 590 | 
         
            +
                        "ee14e110bec24331816d2243cdde5e37",
         
     | 
| 591 | 
         
            +
                        "c3fbec3f54b54293a667dfe43c5aed56",
         
     | 
| 592 | 
         
            +
                        "93630fd11e7f47fa9ef23b1057ec1dd9",
         
     | 
| 593 | 
         
            +
                        "aa9410600ce04b5592f4d2a7b4531656",
         
     | 
| 594 | 
         
            +
                        "af65853e39c04163a6173f438cc0ae71",
         
     | 
| 595 | 
         
            +
                        "2bb901e39be5406bb1f495030180732b",
         
     | 
| 596 | 
         
            +
                        "f2b4a35ed75343aab26390a8c26b6e0f",
         
     | 
| 597 | 
         
            +
                        "df1e1c500f9c42e09fd96d7cb6cca316"
         
     | 
| 598 | 
         
            +
                      ]
         
     | 
| 599 | 
         
            +
                    },
         
     | 
| 600 | 
         
            +
                    "id": "F2RGfy_9wRFS",
         
     | 
| 601 | 
         
            +
                    "outputId": "2b5fa5af-401f-44b0-f2a6-13da08507396"
         
     | 
| 602 | 
         
            +
                  },
         
     | 
| 603 | 
         
            +
                  "source": [
         
     | 
| 604 | 
         
            +
                    "model.generate_grid()"
         
     | 
| 605 | 
         
            +
                  ],
         
     | 
| 606 | 
         
            +
                  "execution_count": null,
         
     | 
| 607 | 
         
            +
                  "outputs": []
         
     | 
| 608 | 
         
            +
                }
         
     | 
| 609 | 
         
            +
              ]
         
     | 
| 610 | 
         
            +
            }
         
     | 
    	
        dnnlib/__init__.py
    ADDED
    
    | 
         @@ -0,0 +1,24 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            # Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
         
     | 
| 2 | 
         
            +
            #
         
     | 
| 3 | 
         
            +
            # NVIDIA CORPORATION and its licensors retain all intellectual property
         
     | 
| 4 | 
         
            +
            # and proprietary rights in and to this software, related documentation
         
     | 
| 5 | 
         
            +
            # and any modifications thereto.  Any use, reproduction, disclosure or
         
     | 
| 6 | 
         
            +
            # distribution of this software and related documentation without an express
         
     | 
| 7 | 
         
            +
            # license agreement from NVIDIA CORPORATION is strictly prohibited.
         
     | 
| 8 | 
         
            +
             
     | 
| 9 | 
         
            +
            from . import submission
         
     | 
| 10 | 
         
            +
             
     | 
| 11 | 
         
            +
            from .submission.run_context import RunContext
         
     | 
| 12 | 
         
            +
             
     | 
| 13 | 
         
            +
            from .submission.submit import SubmitTarget
         
     | 
| 14 | 
         
            +
            from .submission.submit import PathType
         
     | 
| 15 | 
         
            +
            from .submission.submit import SubmitConfig
         
     | 
| 16 | 
         
            +
            from .submission.submit import submit_run
         
     | 
| 17 | 
         
            +
            from .submission.submit import submit_diagnostic
         
     | 
| 18 | 
         
            +
            from .submission.submit import get_path_from_template
         
     | 
| 19 | 
         
            +
            from .submission.submit import convert_path
         
     | 
| 20 | 
         
            +
            from .submission.submit import make_run_dir_path
         
     | 
| 21 | 
         
            +
             
     | 
| 22 | 
         
            +
            from .util import EasyDict
         
     | 
| 23 | 
         
            +
             
     | 
| 24 | 
         
            +
# Package-level handle to the active SubmitConfig. It is assigned by the
# submission machinery for the duration of the user's run function and is
# None at all other times, so callers must check before use.
submit_config: SubmitConfig = None # Package level variable for SubmitConfig which is only valid when inside the run function.
         
     | 
    	
        dnnlib/submission/__init__.py
    ADDED
    
    | 
         @@ -0,0 +1,8 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
         
     | 
| 2 | 
         
            +
            #
         
     | 
| 3 | 
         
            +
            # This work is made available under the Nvidia Source Code License-NC.
         
     | 
| 4 | 
         
            +
            # To view a copy of this license, visit
         
     | 
| 5 | 
         
            +
            # https://nvlabs.github.io/stylegan2/license.html
         
     | 
| 6 | 
         
            +
             
     | 
| 7 | 
         
            +
            from . import run_context
         
     | 
| 8 | 
         
            +
            from . import submit
         
     | 
    	
        dnnlib/submission/internal/__init__.py
    ADDED
    
    | 
         @@ -0,0 +1,7 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
         
     | 
| 2 | 
         
            +
            #
         
     | 
| 3 | 
         
            +
            # This work is made available under the Nvidia Source Code License-NC.
         
     | 
| 4 | 
         
            +
            # To view a copy of this license, visit
         
     | 
| 5 | 
         
            +
            # https://nvlabs.github.io/stylegan2/license.html
         
     | 
| 6 | 
         
            +
             
     | 
| 7 | 
         
            +
            from . import local
         
     | 
    	
        dnnlib/submission/internal/local.py
    ADDED
    
    | 
         @@ -0,0 +1,22 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
         
     | 
| 2 | 
         
            +
            #
         
     | 
| 3 | 
         
            +
            # This work is made available under the Nvidia Source Code License-NC.
         
     | 
| 4 | 
         
            +
            # To view a copy of this license, visit
         
     | 
| 5 | 
         
            +
            # https://nvlabs.github.io/stylegan2/license.html
         
     | 
| 6 | 
         
            +
             
     | 
| 7 | 
         
            +
            class TargetOptions():
         
     | 
| 8 | 
         
            +
                def __init__(self):
         
     | 
| 9 | 
         
            +
                    self.do_not_copy_source_files = False
         
     | 
| 10 | 
         
            +
             
     | 
| 11 | 
         
            +
            class Target():
         
     | 
| 12 | 
         
            +
                def __init__(self):
         
     | 
| 13 | 
         
            +
                    pass
         
     | 
| 14 | 
         
            +
             
     | 
| 15 | 
         
            +
                def finalize_submit_config(self, submit_config, host_run_dir):
         
     | 
| 16 | 
         
            +
                    # print ('Local submit ', end='', flush=True)
         
     | 
| 17 | 
         
            +
                    submit_config.run_dir = host_run_dir
         
     | 
| 18 | 
         
            +
             
     | 
| 19 | 
         
            +
                def submit(self, submit_config, host_run_dir):
         
     | 
| 20 | 
         
            +
                    from ..submit import run_wrapper, convert_path
         
     | 
| 21 | 
         
            +
                    # print('- run_dir: %s' % convert_path(submit_config.run_dir), flush=True)
         
     | 
| 22 | 
         
            +
                    return run_wrapper(submit_config)
         
     | 
    	
        dnnlib/submission/run_context.py
    ADDED
    
    | 
         @@ -0,0 +1,110 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
         
     | 
| 2 | 
         
            +
            #
         
     | 
| 3 | 
         
            +
            # This work is made available under the Nvidia Source Code License-NC.
         
     | 
| 4 | 
         
            +
            # To view a copy of this license, visit
         
     | 
| 5 | 
         
            +
            # https://nvlabs.github.io/stylegan2/license.html
         
     | 
| 6 | 
         
            +
             
     | 
| 7 | 
         
            +
            """Helpers for managing the run/training loop."""
         
     | 
| 8 | 
         
            +
             
     | 
| 9 | 
         
            +
            import datetime
         
     | 
| 10 | 
         
            +
            import json
         
     | 
| 11 | 
         
            +
            import os
         
     | 
| 12 | 
         
            +
            import pprint
         
     | 
| 13 | 
         
            +
            import time
         
     | 
| 14 | 
         
            +
            import types
         
     | 
| 15 | 
         
            +
             
     | 
| 16 | 
         
            +
            from typing import Any
         
     | 
| 17 | 
         
            +
             
     | 
| 18 | 
         
            +
            from . import submit
         
     | 
| 19 | 
         
            +
             
     | 
| 20 | 
         
            +
            # Singleton RunContext
         
     | 
| 21 | 
         
            +
            _run_context = None
         
     | 
| 22 | 
         
            +
             
     | 
| 23 | 
         
            +
class RunContext(object):
    """Helper class for managing the run/training loop.

    The context will hide the implementation details of a basic run/training loop.
    It will set things up properly, tell if run should be stopped, and then cleans up.
    User should call update periodically and use should_stop to determine if run should be stopped.

    Only one RunContext may be alive at a time (enforced via the module-global
    ``_run_context`` singleton). Creating the context writes a ``run.txt``
    summary file into ``submit_config.run_dir``; closing it rewrites that file
    with a ``stop_time`` entry.

    Args:
        submit_config: The SubmitConfig that is used for the current run.
        config_module: (deprecated) The whole config module that is used for the current run.
    """

    def __init__(self, submit_config: submit.SubmitConfig, config_module: types.ModuleType = None):
        global _run_context
        # Only a single RunContext can be alive
        assert _run_context is None
        _run_context = self
        self.submit_config = submit_config
        self.should_stop_flag = False  # latched stop request; set in update(), never cleared
        self.has_closed = False
        self.start_time = time.time()
        self.last_update_time = time.time()
        self.last_update_interval = 0.0
        # Assigned but never read within this class; presumably for external
        # progress-monitoring tooling — confirm against callers.
        self.progress_monitor_file_path = None

        # vestigial config_module support just prints a warning
        if config_module is not None:
            print("RunContext.config_module parameter support has been removed.")

        # write out details about the run to a text file
        self.run_txt_data = {"task_name": submit_config.task_name, "host_name": submit_config.host_name, "start_time": datetime.datetime.now().isoformat(sep=" ")}
        with open(os.path.join(submit_config.run_dir, "run.txt"), "w") as f:
            pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False)

    def __enter__(self) -> "RunContext":
        # Context-manager support: ``with RunContext(cfg) as ctx:`` closes on exit.
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        self.close()

    def update(self, loss: Any = 0, cur_epoch: Any = 0, max_epoch: Any = None) -> None:
        """Do general housekeeping and keep the state of the context up-to-date.
        Should be called often enough but not in a tight loop.

        Note: ``loss``, ``cur_epoch`` and ``max_epoch`` are accepted for API
        compatibility but are not used by this implementation.
        """
        assert not self.has_closed

        self.last_update_interval = time.time() - self.last_update_time
        self.last_update_time = time.time()

        # An ``abort.txt`` file dropped into the run dir requests a graceful
        # stop; once seen, the flag stays set.
        if os.path.exists(os.path.join(self.submit_config.run_dir, "abort.txt")):
            self.should_stop_flag = True

    def should_stop(self) -> bool:
        """Tell whether a stopping condition has been triggered one way or another."""
        return self.should_stop_flag

    def get_time_since_start(self) -> float:
        """How much time has passed since the creation of the context."""
        return time.time() - self.start_time

    def get_time_since_last_update(self) -> float:
        """How much time has passed since the last call to update."""
        return time.time() - self.last_update_time

    def get_last_update_interval(self) -> float:
        """How much time passed between the previous two calls to update."""
        return self.last_update_interval

    def close(self) -> None:
        """Close the context and clean up.
        Should only be called once (subsequent calls are no-ops)."""
        if not self.has_closed:
            # update the run.txt with stopping time
            self.run_txt_data["stop_time"] = datetime.datetime.now().isoformat(sep=" ")
            with open(os.path.join(self.submit_config.run_dir, "run.txt"), "w") as f:
                pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False)
            self.has_closed = True

            # detach the global singleton
            global _run_context
            if _run_context is self:
                _run_context = None

    @staticmethod
    def get():
        # Return the live singleton, or lazily create one from the global
        # dnnlib.submit_config (presumably populated by submit.run_wrapper —
        # confirm against dnnlib/__init__.py).
        import dnnlib
        if _run_context is not None:
            return _run_context
        return RunContext(dnnlib.submit_config)
     | 
    	
        dnnlib/submission/submit.py
    ADDED
    
    | 
         @@ -0,0 +1,369 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
         
     | 
| 2 | 
         
            +
            #
         
     | 
| 3 | 
         
            +
            # This work is made available under the Nvidia Source Code License-NC.
         
     | 
| 4 | 
         
            +
            # To view a copy of this license, visit
         
     | 
| 5 | 
         
            +
            # https://nvlabs.github.io/stylegan2/license.html
         
     | 
| 6 | 
         
            +
             
     | 
| 7 | 
         
            +
            """Submit a function to be run either locally or in a computing cluster."""
         
     | 
| 8 | 
         
            +
             
     | 
| 9 | 
         
            +
            import copy
         
     | 
| 10 | 
         
            +
            import inspect
         
     | 
| 11 | 
         
            +
            import os
         
     | 
| 12 | 
         
            +
            import pathlib
         
     | 
| 13 | 
         
            +
            import pickle
         
     | 
| 14 | 
         
            +
            import platform
         
     | 
| 15 | 
         
            +
            import pprint
         
     | 
| 16 | 
         
            +
            import re
         
     | 
| 17 | 
         
            +
            import shutil
         
     | 
| 18 | 
         
            +
            import sys
         
     | 
| 19 | 
         
            +
            import time
         
     | 
| 20 | 
         
            +
            import traceback
         
     | 
| 21 | 
         
            +
             
     | 
| 22 | 
         
            +
            from enum import Enum
         
     | 
| 23 | 
         
            +
             
     | 
| 24 | 
         
            +
            from .. import util
         
     | 
| 25 | 
         
            +
            from ..util import EasyDict
         
     | 
| 26 | 
         
            +
             
     | 
| 27 | 
         
            +
            from . import internal
         
     | 
| 28 | 
         
            +
             
     | 
| 29 | 
         
            +
class SubmitTarget(Enum):
    """The target where the function should be run.

    LOCAL: Run it locally.
    """
    LOCAL = 1
    # NOTE(review): DIAGNOSTIC's purpose is not shown in this file; presumably
    # a diagnostics-only submit mode — confirm before relying on it.
    DIAGNOSTIC = 17
         
     | 
| 36 | 
         
            +
             
     | 
| 37 | 
         
            +
             
     | 
| 38 | 
         
            +
class PathType(Enum):
    """Determines in which format should a path be formatted.

    WINDOWS: Format with Windows style.
    LINUX: Format with Linux/Posix style.
    AUTO: Use current OS type to select either WINDOWS or LINUX.
    """
    WINDOWS = 1
    LINUX = 2
    # AUTO is resolved to a concrete flavor by get_path_from_template().
    AUTO = 3
         
     | 
| 48 | 
         
            +
             
     | 
| 49 | 
         
            +
             
     | 
| 50 | 
         
            +
            class PlatformExtras:
         
     | 
| 51 | 
         
            +
                """A mixed bag of values used by dnnlib heuristics.
         
     | 
| 52 | 
         
            +
             
     | 
| 53 | 
         
            +
                Attributes:
         
     | 
| 54 | 
         
            +
             
     | 
| 55 | 
         
            +
                    data_reader_buffer_size: Used by DataReader to size internal shared memory buffers.
         
     | 
| 56 | 
         
            +
                    data_reader_process_count: Number of worker processes to spawn (zero for single thread operation)
         
     | 
| 57 | 
         
            +
                """
         
     | 
| 58 | 
         
            +
                def __init__(self):
         
     | 
| 59 | 
         
            +
                    self.data_reader_buffer_size = 1<<30    # 1 GB
         
     | 
| 60 | 
         
            +
                    self.data_reader_process_count = 0      # single threaded default
         
     | 
| 61 | 
         
            +
             
     | 
| 62 | 
         
            +
             
     | 
| 63 | 
         
            +
            _user_name_override = None
         
     | 
| 64 | 
         
            +
             
     | 
| 65 | 
         
            +
class SubmitConfig(util.EasyDict):
    """Strongly typed config dict needed to submit runs.

    Attributes:
        run_dir_root: Path to the run dir root. Can be optionally templated with tags. Needs to always be run through get_path_from_template.
        run_desc: Description of the run. Will be used in the run dir and task name.
        run_dir_ignore: List of file patterns used to ignore files when copying files to the run dir.
        run_dir_extra_files: List of (abs_path, rel_path) tuples of file paths. rel_path root will be the src directory inside the run dir.
        submit_target: Submit target enum value. Used to select where the run is actually launched.
        num_gpus: Number of GPUs used/requested for the run.
        print_info: Whether to print debug information when submitting.
        local.do_not_copy_source_files: Do not copy source files from the working directory to the run dir.
        run_id: Automatically populated value during submit.
        run_name: Automatically populated value during submit.
        run_dir: Automatically populated value during submit.
        run_func_name: Automatically populated value during submit.
        run_func_kwargs: Automatically populated value during submit.
        user_name: Automatically populated value during submit. Can be set by the user which will then override the automatic value.
        task_name: Automatically populated value during submit.
        host_name: Automatically populated value during submit.
        platform_extras: Automatically populated values during submit.  Used by various dnnlib libraries such as the DataReader class.
    """

    def __init__(self):
        super().__init__()

        # run (set these)
        self.run_dir_root = ""  # should always be passed through get_path_from_template
        self.run_desc = ""
        self.run_dir_ignore = ["__pycache__", "*.pyproj", "*.sln", "*.suo", ".cache", ".idea", ".vs", ".vscode", "_cudacache"]
        self.run_dir_extra_files = []

        # submit (set these)
        self.submit_target = SubmitTarget.LOCAL
        self.num_gpus = 1
        self.print_info = False
        # NOTE(review): presumably toggles NVIDIA nvprof profiling — not read
        # in this file; confirm against the runner.
        self.nvprof = False
        # Options for the LOCAL submit target (see internal.local.TargetOptions).
        self.local = internal.local.TargetOptions()
        # NOTE(review): dataset descriptors consumed elsewhere — confirm schema.
        self.datasets = []

        # (automatically populated)
        self.run_id = None
        self.run_name = None
        self.run_dir = None
        self.run_func_name = None
        self.run_func_kwargs = None
        self.user_name = None
        self.task_name = None
        self.host_name = "localhost"
        self.platform_extras = PlatformExtras()
     | 
| 115 | 
         
            +
             
     | 
| 116 | 
         
            +
             
     | 
| 117 | 
         
            +
            def get_path_from_template(path_template: str, path_type: PathType = PathType.AUTO) -> str:
         
     | 
| 118 | 
         
            +
                """Replace tags in the given path template and return either Windows or Linux formatted path."""
         
     | 
| 119 | 
         
            +
                # automatically select path type depending on running OS
         
     | 
| 120 | 
         
            +
                if path_type == PathType.AUTO:
         
     | 
| 121 | 
         
            +
                    if platform.system() == "Windows":
         
     | 
| 122 | 
         
            +
                        path_type = PathType.WINDOWS
         
     | 
| 123 | 
         
            +
                    elif platform.system() == "Linux":
         
     | 
| 124 | 
         
            +
                        path_type = PathType.LINUX
         
     | 
| 125 | 
         
            +
                    else:
         
     | 
| 126 | 
         
            +
                        raise RuntimeError("Unknown platform")
         
     | 
| 127 | 
         
            +
             
     | 
| 128 | 
         
            +
                path_template = path_template.replace("<USERNAME>", get_user_name())
         
     | 
| 129 | 
         
            +
             
     | 
| 130 | 
         
            +
                # return correctly formatted path
         
     | 
| 131 | 
         
            +
                if path_type == PathType.WINDOWS:
         
     | 
| 132 | 
         
            +
                    return str(pathlib.PureWindowsPath(path_template))
         
     | 
| 133 | 
         
            +
                elif path_type == PathType.LINUX:
         
     | 
| 134 | 
         
            +
                    return str(pathlib.PurePosixPath(path_template))
         
     | 
| 135 | 
         
            +
                else:
         
     | 
| 136 | 
         
            +
                    raise RuntimeError("Unknown platform")
         
     | 
| 137 | 
         
            +
             
     | 
| 138 | 
         
            +
             
     | 
| 139 | 
         
            +
            def get_template_from_path(path: str) -> str:
         
     | 
| 140 | 
         
            +
                """Convert a normal path back to its template representation."""
         
     | 
| 141 | 
         
            +
                path = path.replace("\\", "/")
         
     | 
| 142 | 
         
            +
                return path
         
     | 
| 143 | 
         
            +
             
     | 
| 144 | 
         
            +
             
     | 
| 145 | 
         
            +
            def convert_path(path: str, path_type: PathType = PathType.AUTO) -> str:
         
     | 
| 146 | 
         
            +
                """Convert a normal path to template and the convert it back to a normal path with given path type."""
         
     | 
| 147 | 
         
            +
                path_template = get_template_from_path(path)
         
     | 
| 148 | 
         
            +
                path = get_path_from_template(path_template, path_type)
         
     | 
| 149 | 
         
            +
                return path
         
     | 
| 150 | 
         
            +
             
     | 
| 151 | 
         
            +
             
     | 
| 152 | 
         
            +
            def set_user_name_override(name: str) -> None:
         
     | 
| 153 | 
         
            +
                """Set the global username override value."""
         
     | 
| 154 | 
         
            +
                global _user_name_override
         
     | 
| 155 | 
         
            +
                _user_name_override = name
         
     | 
| 156 | 
         
            +
             
     | 
| 157 | 
         
            +
             
     | 
| 158 | 
         
            +
            def get_user_name():
         
     | 
| 159 | 
         
            +
                """Get the current user name."""
         
     | 
| 160 | 
         
            +
                if _user_name_override is not None:
         
     | 
| 161 | 
         
            +
                    return _user_name_override
         
     | 
| 162 | 
         
            +
                elif platform.system() == "Windows":
         
     | 
| 163 | 
         
            +
                    return os.getlogin()
         
     | 
| 164 | 
         
            +
                elif platform.system() == "Linux":
         
     | 
| 165 | 
         
            +
                    try:
         
     | 
| 166 | 
         
            +
                        import pwd
         
     | 
| 167 | 
         
            +
                        return pwd.getpwuid(os.geteuid()).pw_name
         
     | 
| 168 | 
         
            +
                    except:
         
     | 
| 169 | 
         
            +
                        return "unknown"
         
     | 
| 170 | 
         
            +
                else:
         
     | 
| 171 | 
         
            +
                    raise RuntimeError("Unknown platform")
         
     | 
| 172 | 
         
            +
             
     | 
| 173 | 
         
            +
             
     | 
| 174 | 
         
            +
            def make_run_dir_path(*paths):
         
     | 
| 175 | 
         
            +
                """Make a path/filename that resides under the current submit run_dir.
         
     | 
| 176 | 
         
            +
             
     | 
| 177 | 
         
            +
                Args:
         
     | 
| 178 | 
         
            +
                    *paths: Path components to be passed to os.path.join
         
     | 
| 179 | 
         
            +
             
     | 
| 180 | 
         
            +
                Returns:
         
     | 
| 181 | 
         
            +
                    A file/dirname rooted at submit_config.run_dir.  If there's no
         
     | 
| 182 | 
         
            +
                    submit_config or run_dir, the base directory is the current
         
     | 
| 183 | 
         
            +
                    working directory.
         
     | 
| 184 | 
         
            +
             
     | 
| 185 | 
         
            +
                E.g., `os.path.join(dnnlib.submit_config.run_dir, "output.txt"))`
         
     | 
| 186 | 
         
            +
                """
         
     | 
| 187 | 
         
            +
                import dnnlib
         
     | 
| 188 | 
         
            +
                if (dnnlib.submit_config is None) or (dnnlib.submit_config.run_dir is None):
         
     | 
| 189 | 
         
            +
                    return os.path.join(os.getcwd(), *paths)
         
     | 
| 190 | 
         
            +
                return os.path.join(dnnlib.submit_config.run_dir, *paths)
         
     | 
| 191 | 
         
            +
             
     | 
| 192 | 
         
            +
             
     | 
| 193 | 
         
            +
            def _create_run_dir_local(submit_config: SubmitConfig) -> str:
         
     | 
| 194 | 
         
            +
                """Create a new run dir with increasing ID number at the start."""
         
     | 
| 195 | 
         
            +
                run_dir_root = get_path_from_template(submit_config.run_dir_root, PathType.AUTO)
         
     | 
| 196 | 
         
            +
             
     | 
| 197 | 
         
            +
                if not os.path.exists(run_dir_root):
         
     | 
| 198 | 
         
            +
                    os.makedirs(run_dir_root)
         
     | 
| 199 | 
         
            +
             
     | 
| 200 | 
         
            +
                submit_config.run_id = _get_next_run_id_local(run_dir_root)
         
     | 
| 201 | 
         
            +
                submit_config.run_name = "{0:05d}-{1}".format(submit_config.run_id, submit_config.run_desc)
         
     | 
| 202 | 
         
            +
                run_dir = os.path.join(run_dir_root, submit_config.run_name)
         
     | 
| 203 | 
         
            +
             
     | 
| 204 | 
         
            +
                if os.path.exists(run_dir):
         
     | 
| 205 | 
         
            +
                    raise RuntimeError("The run dir already exists! ({0})".format(run_dir))
         
     | 
| 206 | 
         
            +
             
     | 
| 207 | 
         
            +
                os.makedirs(run_dir)
         
     | 
| 208 | 
         
            +
             
     | 
| 209 | 
         
            +
                return run_dir
         
     | 
| 210 | 
         
            +
             
     | 
| 211 | 
         
            +
             
     | 
| 212 | 
         
            +
            def _get_next_run_id_local(run_dir_root: str) -> int:
         
     | 
| 213 | 
         
            +
                """Reads all directory names in a given directory (non-recursive) and returns the next (increasing) run id. Assumes IDs are numbers at the start of the directory names."""
         
     | 
| 214 | 
         
            +
                dir_names = [d for d in os.listdir(run_dir_root) if os.path.isdir(os.path.join(run_dir_root, d))]
         
     | 
| 215 | 
         
            +
                r = re.compile("^\\d+")  # match one or more digits at the start of the string
         
     | 
| 216 | 
         
            +
                run_id = 0
         
     | 
| 217 | 
         
            +
             
     | 
| 218 | 
         
            +
                for dir_name in dir_names:
         
     | 
| 219 | 
         
            +
                    m = r.match(dir_name)
         
     | 
| 220 | 
         
            +
             
     | 
| 221 | 
         
            +
                    if m is not None:
         
     | 
| 222 | 
         
            +
                        i = int(m.group())
         
     | 
| 223 | 
         
            +
                        run_id = max(run_id, i + 1)
         
     | 
| 224 | 
         
            +
             
     | 
| 225 | 
         
            +
                return run_id
         
     | 
| 226 | 
         
            +
             
     | 
| 227 | 
         
            +
             
     | 
| 228 | 
         
            +
            def _populate_run_dir(submit_config: SubmitConfig, run_dir: str) -> None:
         
     | 
| 229 | 
         
            +
                """Copy all necessary files into the run dir. Assumes that the dir exists, is local, and is writable."""
         
     | 
| 230 | 
         
            +
                pickle.dump(submit_config, open(os.path.join(run_dir, "submit_config.pkl"), "wb"))
         
     | 
| 231 | 
         
            +
                with open(os.path.join(run_dir, "submit_config.txt"), "w") as f:
         
     | 
| 232 | 
         
            +
                    pprint.pprint(submit_config, stream=f, indent=4, width=200, compact=False)
         
     | 
| 233 | 
         
            +
             
     | 
| 234 | 
         
            +
                if (submit_config.submit_target == SubmitTarget.LOCAL) and submit_config.local.do_not_copy_source_files:
         
     | 
| 235 | 
         
            +
                    return
         
     | 
| 236 | 
         
            +
             
     | 
| 237 | 
         
            +
                files = []
         
     | 
| 238 | 
         
            +
             
     | 
| 239 | 
         
            +
                run_func_module_dir_path = util.get_module_dir_by_obj_name(submit_config.run_func_name)
         
     | 
| 240 | 
         
            +
                assert '.' in submit_config.run_func_name
         
     | 
| 241 | 
         
            +
                for _idx in range(submit_config.run_func_name.count('.') - 1):
         
     | 
| 242 | 
         
            +
                    run_func_module_dir_path = os.path.dirname(run_func_module_dir_path)
         
     | 
| 243 | 
         
            +
                files += util.list_dir_recursively_with_ignore(run_func_module_dir_path, ignores=submit_config.run_dir_ignore, add_base_to_relative=False)
         
     | 
| 244 | 
         
            +
             
     | 
| 245 | 
         
            +
                dnnlib_module_dir_path = util.get_module_dir_by_obj_name("dnnlib")
         
     | 
| 246 | 
         
            +
                files += util.list_dir_recursively_with_ignore(dnnlib_module_dir_path, ignores=submit_config.run_dir_ignore, add_base_to_relative=True)
         
     | 
| 247 | 
         
            +
             
     | 
| 248 | 
         
            +
                files += submit_config.run_dir_extra_files
         
     | 
| 249 | 
         
            +
             
     | 
| 250 | 
         
            +
                files = [(f[0], os.path.join(run_dir, "src", f[1])) for f in files]
         
     | 
| 251 | 
         
            +
                files += [(os.path.join(dnnlib_module_dir_path, "submission", "internal", "run.py"), os.path.join(run_dir, "run.py"))]
         
     | 
| 252 | 
         
            +
             
     | 
| 253 | 
         
            +
                util.copy_files_and_create_dirs(files)
         
     | 
| 254 | 
         
            +
             
     | 
| 255 | 
         
            +
             
     | 
| 256 | 
         
            +
             
     | 
| 257 | 
         
            +
def run_wrapper(submit_config: SubmitConfig) -> SubmitConfig:
    """Wrap the actual run function call for handling logging, exceptions, typing, etc.

    Resolves submit_config.run_func_name to a callable, invokes it with
    submit_config.run_func_kwargs, and manages log redirection, the global
    dnnlib.submit_config, the _finished.txt marker, and error propagation.

    Returns:
        The submit_config used for the run. (Return annotation corrected from
        the original `-> None`, which contradicted the final return statement.)
    """
    is_local = submit_config.submit_target == SubmitTarget.LOCAL

    # when running locally, redirect stderr to stdout, log stdout to a file, and force flushing
    if is_local:
        logger = util.Logger(file_name=os.path.join(submit_config.run_dir, "log.txt"), file_mode="w", should_flush=True)
    else:  # when running in a cluster, redirect stderr to stdout, and just force flushing (log writing is handled by run.sh)
        logger = util.Logger(file_name=None, should_flush=True)

    # Publish the active config globally so the run function (and RunContext)
    # can find it via dnnlib.submit_config.
    import dnnlib
    dnnlib.submit_config = submit_config

    exit_with_errcode = False
    try:
        # print("dnnlib: Running {0}() on {1}...".format(submit_config.run_func_name, submit_config.host_name))
        start_time = time.time()

        # Resolve the run function by its dotted name; pass submit_config
        # through only when the function declares a parameter for it.
        run_func_obj = util.get_obj_by_name(submit_config.run_func_name)
        assert callable(run_func_obj)
        sig = inspect.signature(run_func_obj)
        if 'submit_config' in sig.parameters:
            run_func_obj(submit_config=submit_config, **submit_config.run_func_kwargs)
        else:
            run_func_obj(**submit_config.run_func_kwargs)

        # print("dnnlib: Finished {0}() in {1}.".format(submit_config.run_func_name, util.format_time(time.time() - start_time)))
    except:
        # NOTE(review): deliberately a bare except — even SystemExit /
        # KeyboardInterrupt from the run function must reach the handling below.
        if is_local:
            raise
        else:
            traceback.print_exc()

            # Best effort: copy the log next to the run dir root as
            # "<run_name>-error.txt" so cluster failures are easy to spot.
            try:
                log_src = os.path.join(submit_config.run_dir, "log.txt")
                log_dst = os.path.join(get_path_from_template(submit_config.run_dir_root), "{0}-error.txt".format(submit_config.run_name))
                shutil.copyfile(log_src, log_dst)
            except:
                print("Failing hard, check stack trace")

            # Defer sys.exit(1) to happen after we close the logs and create a _finished.txt
            exit_with_errcode = True
    finally:
        # Mark the run as finished regardless of success, except in diagnostic mode
        # (which has no run dir to write into).
        if submit_config.submit_target != SubmitTarget.DIAGNOSTIC:
            open(os.path.join(submit_config.run_dir, "_finished.txt"), "w").close()

    dnnlib.RunContext.get().close()
    dnnlib.submit_config = None
    logger.close()

    # If we hit an error, get out of the script now and signal the error
    # to whatever process that started this script.
    if exit_with_errcode:
        sys.exit(1)

    return submit_config
         
     | 
| 313 | 
         
            +
             
     | 
| 314 | 
         
            +
             
     | 
| 315 | 
         
            +
            def submit_run(submit_config: SubmitConfig, run_func_name: str, **run_func_kwargs) -> None:
         
     | 
| 316 | 
         
            +
                """Create a run dir, gather files related to the run, copy files to the run dir, and launch the run in appropriate place."""
         
     | 
| 317 | 
         
            +
                submit_config = copy.deepcopy(submit_config)
         
     | 
| 318 | 
         
            +
             
     | 
| 319 | 
         
            +
                submit_target = submit_config.submit_target
         
     | 
| 320 | 
         
            +
                farm = None
         
     | 
| 321 | 
         
            +
                if submit_target == SubmitTarget.LOCAL:
         
     | 
| 322 | 
         
            +
                    farm = internal.local.Target()
         
     | 
| 323 | 
         
            +
                assert farm is not None # unknown target
         
     | 
| 324 | 
         
            +
             
     | 
| 325 | 
         
            +
                # Disallow submitting jobs with zero num_gpus.
         
     | 
| 326 | 
         
            +
                if (submit_config.num_gpus is None) or (submit_config.num_gpus == 0):
         
     | 
| 327 | 
         
            +
                    raise RuntimeError("submit_config.num_gpus must be set to a non-zero value")
         
     | 
| 328 | 
         
            +
             
     | 
| 329 | 
         
            +
                if submit_config.user_name is None:
         
     | 
| 330 | 
         
            +
                    submit_config.user_name = get_user_name()
         
     | 
| 331 | 
         
            +
             
     | 
| 332 | 
         
            +
                submit_config.run_func_name = run_func_name
         
     | 
| 333 | 
         
            +
                submit_config.run_func_kwargs = run_func_kwargs
         
     | 
| 334 | 
         
            +
             
     | 
| 335 | 
         
            +
                #--------------------------------------------------------------------
         
     | 
| 336 | 
         
            +
                # Prepare submission by populating the run dir
         
     | 
| 337 | 
         
            +
                #--------------------------------------------------------------------
         
     | 
| 338 | 
         
            +
                host_run_dir = _create_run_dir_local(submit_config)
         
     | 
| 339 | 
         
            +
             
     | 
| 340 | 
         
            +
                submit_config.task_name = "{0}-{1:05d}-{2}".format(submit_config.user_name, submit_config.run_id, submit_config.run_desc)
         
     | 
| 341 | 
         
            +
                docker_valid_name_regex = "^[a-zA-Z0-9][a-zA-Z0-9_.-]+$"
         
     | 
| 342 | 
         
            +
                if not re.match(docker_valid_name_regex, submit_config.task_name):
         
     | 
| 343 | 
         
            +
                    raise RuntimeError("Invalid task name.  Probable reason: unacceptable characters in your submit_config.run_desc.  Task name must be accepted by the following regex: " + docker_valid_name_regex + ", got " + submit_config.task_name)
         
     | 
| 344 | 
         
            +
             
     | 
| 345 | 
         
            +
                # Farm specific preparations for a submit
         
     | 
| 346 | 
         
            +
                farm.finalize_submit_config(submit_config, host_run_dir)
         
     | 
| 347 | 
         
            +
                _populate_run_dir(submit_config, host_run_dir)
         
     | 
| 348 | 
         
            +
                return farm.submit(submit_config, host_run_dir)
         
     | 
| 349 | 
         
            +
             
     | 
| 350 | 
         
            +
            def submit_diagnostic(submit_config: SubmitConfig, run_func_name: str, **run_func_kwargs) -> None:
         
     | 
| 351 | 
         
            +
                """Launch a run without creating a run directory."""
         
     | 
| 352 | 
         
            +
                submit_config = copy.deepcopy(submit_config)
         
     | 
| 353 | 
         
            +
             
     | 
| 354 | 
         
            +
                submit_target = submit_config.submit_target
         
     | 
| 355 | 
         
            +
                farm = None
         
     | 
| 356 | 
         
            +
                if submit_target == SubmitTarget.LOCAL or submit_target == SubmitTarget.DIAGNOSTIC:
         
     | 
| 357 | 
         
            +
                    farm = internal.local.Target()
         
     | 
| 358 | 
         
            +
                assert farm is not None # unknown target
         
     | 
| 359 | 
         
            +
             
     | 
| 360 | 
         
            +
                if submit_config.user_name is None:
         
     | 
| 361 | 
         
            +
                    submit_config.user_name = get_user_name()
         
     | 
| 362 | 
         
            +
             
     | 
| 363 | 
         
            +
                submit_config.run_func_name = run_func_name
         
     | 
| 364 | 
         
            +
                submit_config.run_func_kwargs = run_func_kwargs
         
     | 
| 365 | 
         
            +
             
     | 
| 366 | 
         
            +
                host_run_dir = ""
         
     | 
| 367 | 
         
            +
                # Farm specific preparations for a submit
         
     | 
| 368 | 
         
            +
                farm.finalize_submit_config(submit_config, host_run_dir)
         
     | 
| 369 | 
         
            +
                return farm.submit(submit_config, host_run_dir)
         
     | 
    	
        dnnlib/tflib/__init__.py
    ADDED
    
    | 
         @@ -0,0 +1,20 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            # Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
         
     | 
| 2 | 
         
            +
            #
         
     | 
| 3 | 
         
            +
            # NVIDIA CORPORATION and its licensors retain all intellectual property
         
     | 
| 4 | 
         
            +
            # and proprietary rights in and to this software, related documentation
         
     | 
| 5 | 
         
            +
            # and any modifications thereto.  Any use, reproduction, disclosure or
         
     | 
| 6 | 
         
            +
            # distribution of this software and related documentation without an express
         
     | 
| 7 | 
         
            +
            # license agreement from NVIDIA CORPORATION is strictly prohibited.
         
     | 
| 8 | 
         
            +
             
     | 
| 9 | 
         
            +
            from . import autosummary
         
     | 
| 10 | 
         
            +
            from . import network
         
     | 
| 11 | 
         
            +
            from . import optimizer
         
     | 
| 12 | 
         
            +
            from . import tfutil
         
     | 
| 13 | 
         
            +
            from . import custom_ops
         
     | 
| 14 | 
         
            +
             
     | 
| 15 | 
         
            +
            from .tfutil import *
         
     | 
| 16 | 
         
            +
            from .network import Network
         
     | 
| 17 | 
         
            +
             
     | 
| 18 | 
         
            +
            from .optimizer import Optimizer
         
     | 
| 19 | 
         
            +
             
     | 
| 20 | 
         
            +
            from .custom_ops import get_plugin
         
     | 
    	
        dnnlib/tflib/autosummary.py
    ADDED
    
    | 
         @@ -0,0 +1,193 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            # Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
         
     | 
| 2 | 
         
            +
            #
         
     | 
| 3 | 
         
            +
            # NVIDIA CORPORATION and its licensors retain all intellectual property
         
     | 
| 4 | 
         
            +
            # and proprietary rights in and to this software, related documentation
         
     | 
| 5 | 
         
            +
            # and any modifications thereto.  Any use, reproduction, disclosure or
         
     | 
| 6 | 
         
            +
            # distribution of this software and related documentation without an express
         
     | 
| 7 | 
         
            +
            # license agreement from NVIDIA CORPORATION is strictly prohibited.
         
     | 
| 8 | 
         
            +
             
     | 
| 9 | 
         
            +
            """Helper for adding automatically tracked values to Tensorboard.
         
     | 
| 10 | 
         
            +
             
     | 
| 11 | 
         
            +
            Autosummary creates an identity op that internally keeps track of the input
         
     | 
| 12 | 
         
            +
            values and automatically shows up in TensorBoard. The reported value
         
     | 
| 13 | 
         
            +
            represents an average over input components. The average is accumulated
         
     | 
| 14 | 
         
            +
            constantly over time and flushed when save_summaries() is called.
         
     | 
| 15 | 
         
            +
             
     | 
| 16 | 
         
            +
            Notes:
         
     | 
| 17 | 
         
            +
            - The output tensor must be used as an input for something else in the
         
     | 
| 18 | 
         
            +
              graph. Otherwise, the autosummary op will not get executed, and the average
         
     | 
| 19 | 
         
            +
              value will not get accumulated.
         
     | 
| 20 | 
         
            +
            - It is perfectly fine to include autosummaries with the same name in
         
     | 
| 21 | 
         
            +
              several places throughout the graph, even if they are executed concurrently.
         
     | 
| 22 | 
         
            +
            - It is ok to also pass in a python scalar or numpy array. In this case, it
         
     | 
| 23 | 
         
            +
              is added to the average immediately.
         
     | 
| 24 | 
         
            +
            """
         
     | 
| 25 | 
         
            +
             
     | 
| 26 | 
         
            +
            from collections import OrderedDict
         
     | 
| 27 | 
         
            +
            import numpy as np
         
     | 
| 28 | 
         
            +
            import tensorflow as tf
         
     | 
| 29 | 
         
            +
            from tensorboard import summary as summary_lib
         
     | 
| 30 | 
         
            +
            from tensorboard.plugins.custom_scalar import layout_pb2
         
     | 
| 31 | 
         
            +
             
     | 
| 32 | 
         
            +
            from . import tfutil
         
     | 
| 33 | 
         
            +
            from .tfutil import TfExpression
         
     | 
| 34 | 
         
            +
            from .tfutil import TfExpressionEx
         
     | 
| 35 | 
         
            +
             
     | 
| 36 | 
         
            +
            # Enable "Custom scalars" tab in TensorBoard for advanced formatting.
         
     | 
| 37 | 
         
            +
            # Disabled by default to reduce tfevents file size.
         
     | 
| 38 | 
         
            +
            enable_custom_scalars = False
         
     | 
| 39 | 
         
            +
             
     | 
| 40 | 
         
            +
            _dtype = tf.float64
         
     | 
| 41 | 
         
            +
            _vars = OrderedDict()  # name => [var, ...]
         
     | 
| 42 | 
         
            +
            _immediate = OrderedDict()  # name => update_op, update_value
         
     | 
| 43 | 
         
            +
            _finalized = False
         
     | 
| 44 | 
         
            +
            _merge_op = None
         
     | 
| 45 | 
         
            +
             
     | 
| 46 | 
         
            +
             
     | 
| 47 | 
         
            +
def _create_var(name: str, value_expr: TfExpression) -> TfExpression:
    """Internal helper for creating autosummary accumulators.

    Creates a non-trainable 3-vector variable holding the running moments
    [count, sum(x), sum(x**2)] for `name`, registers it in the module-level
    `_vars` dict, and returns the op that folds `value_expr` into it.
    Calling this repeatedly with the same name appends additional accumulator
    variables; finalize_autosummaries() later sums them all.
    """
    assert not _finalized  # accumulators must not be added once the report graph is built
    name_id = name.replace("/", "_")  # '/' would open nested TF name scopes
    v = tf.cast(value_expr, _dtype)

    if v.shape.is_fully_defined():
        # Static shape: element count is known at graph-construction time.
        size = np.prod(v.shape.as_list())
        size_expr = tf.constant(size, dtype=_dtype)
    else:
        # Dynamic shape: compute the element count at run time.
        size = None
        size_expr = tf.reduce_prod(tf.cast(tf.shape(v), _dtype))

    if size == 1:
        # Statically-known scalar: skip the reductions, just square the value.
        if v.shape.ndims != 0:
            v = tf.reshape(v, [])
        v = [size_expr, v, tf.square(v)]
    else:
        v = [size_expr, tf.reduce_sum(v), tf.reduce_sum(tf.square(v))]
    # Drop the entire contribution when the sum is non-finite (NaN/Inf),
    # so a single bad minibatch cannot poison the running average.
    v = tf.cond(tf.is_finite(v[1]), lambda: tf.stack(v), lambda: tf.zeros(3, dtype=_dtype))

    # control_dependencies(None) detaches the variable from any surrounding
    # dependency context so it can be created from inside arbitrary scopes.
    with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.control_dependencies(None):
        var = tf.Variable(tf.zeros(3, dtype=_dtype), trainable=False)  # [sum(1), sum(x), sum(x**2)]
    # First update initializes the variable; later updates accumulate into it.
    update_op = tf.cond(tf.is_variable_initialized(var), lambda: tf.assign_add(var, v), lambda: tf.assign(var, v))

    if name in _vars:
        _vars[name].append(var)
    else:
        _vars[name] = [var]
    return update_op
         
     | 
| 77 | 
         
            +
             
     | 
| 78 | 
         
            +
             
     | 
| 79 | 
         
            +
def autosummary(name: str, value: TfExpressionEx, passthru: TfExpressionEx = None, condition: TfExpressionEx = True) -> TfExpressionEx:
    """Create a new autosummary.

    Args:
        name:      Name to use in TensorBoard
        value:     TensorFlow expression or python value to track
        passthru:  Optionally return this TF node without modifications but tack an autosummary update side-effect to this node.
        condition: Boolean (TF expression for graph inputs, python value for
                   immediate inputs); the update is applied only when true.

    Example use of the passthru mechanism:

    n = autosummary('l2loss', loss, passthru=n)

    This is a shorthand for the following code:

    with tf.control_dependencies([autosummary('l2loss', loss)]):
        n = tf.identity(n)
    """
    tfutil.assert_tf_initialized()
    name_id = name.replace("/", "_")  # '/' would open nested TF name scopes

    if tfutil.is_tf_expression(value):
        # Graph input: build the accumulator update as a side effect on the
        # same device as the value, gated by `condition`.
        with tf.name_scope("summary_" + name_id), tf.device(value.device):
            condition = tf.convert_to_tensor(condition, name='condition')
            update_op = tf.cond(condition, lambda: tf.group(_create_var(name, value)), tf.no_op)
            # Returning identity(value-or-passthru) with a control dependency
            # forces the update to run whenever the returned node is evaluated.
            with tf.control_dependencies([update_op]):
                return tf.identity(value if passthru is None else passthru)

    else:  # python scalar or numpy array
        assert not tfutil.is_tf_expression(passthru)
        assert not tfutil.is_tf_expression(condition)
        if condition:
            # Lazily build (once per name) a placeholder-fed update op,
            # then run it immediately with the given value.
            if name not in _immediate:
                with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.device(None), tf.control_dependencies(None):
                    update_value = tf.placeholder(_dtype)
                    update_op = _create_var(name, update_value)
                    _immediate[name] = update_op, update_value
            update_op, update_value = _immediate[name]
            tfutil.run(update_op, {update_value: value})
        return value if passthru is None else passthru
         
     | 
| 118 | 
         
            +
             
     | 
| 119 | 
         
            +
             
     | 
| 120 | 
         
            +
def finalize_autosummaries() -> None:
    """Create the necessary ops to include autosummaries in TensorBoard report.
    Note: This should be done only once per graph.

    Returns the custom-scalar layout protobuf when `enable_custom_scalars`
    is set, otherwise None. (Note: the annotated return type `None` in the
    signature predates the layout return value.)
    """
    global _finalized
    tfutil.assert_tf_initialized()

    if _finalized:
        return None

    _finalized = True
    _vars_list = [var for vars_list in _vars.values() for var in vars_list]
    # NOTE: flattened list rebuilt inline below to keep code byte-identical.
    tfutil.init_uninitialized_vars([var for vars_list in _vars.values() for var in vars_list])

    # Create summary ops.
    with tf.device(None), tf.control_dependencies(None):
        for name, vars_list in _vars.items():
            name_id = name.replace("/", "_")
            with tfutil.absolute_name_scope("Autosummary/" + name_id):
                # Combine all accumulators registered under this name and
                # normalize by the total element count (moments[0]).
                moments = tf.add_n(vars_list)
                moments /= moments[0]
                with tf.control_dependencies([moments]):  # read before resetting
                    reset_ops = [tf.assign(var, tf.zeros(3, dtype=_dtype)) for var in vars_list]
                    with tf.name_scope(None), tf.control_dependencies(reset_ops):  # reset before reporting
                        # moments is now [1, mean(x), mean(x**2)].
                        mean = moments[1]
                        std = tf.sqrt(moments[2] - tf.square(moments[1]))
                        tf.summary.scalar(name, mean)
                        if enable_custom_scalars:
                            # Extra series consumed by the margin-chart layout below.
                            tf.summary.scalar("xCustomScalars/" + name + "/margin_lo", mean - std)
                            tf.summary.scalar("xCustomScalars/" + name + "/margin_hi", mean + std)

    # Setup layout for custom scalars.
    layout = None
    if enable_custom_scalars:
        # Group series as category/chart/series based on the '/'-separated name.
        cat_dict = OrderedDict()
        for series_name in sorted(_vars.keys()):
            p = series_name.split("/")
            cat = p[0] if len(p) >= 2 else ""
            chart = "/".join(p[1:-1]) if len(p) >= 3 else p[-1]
            if cat not in cat_dict:
                cat_dict[cat] = OrderedDict()
            if chart not in cat_dict[cat]:
                cat_dict[cat][chart] = []
            cat_dict[cat][chart].append(series_name)
        categories = []
        for cat_name, chart_dict in cat_dict.items():
            charts = []
            for chart_name, series_names in chart_dict.items():
                series = []
                for series_name in series_names:
                    # Each series shows its mean with a +/- std margin band.
                    series.append(layout_pb2.MarginChartContent.Series(
                        value=series_name,
                        lower="xCustomScalars/" + series_name + "/margin_lo",
                        upper="xCustomScalars/" + series_name + "/margin_hi"))
                margin = layout_pb2.MarginChartContent(series=series)
                charts.append(layout_pb2.Chart(title=chart_name, margin=margin))
            categories.append(layout_pb2.Category(title=cat_name, chart=charts))
        layout = summary_lib.custom_scalar_pb(layout_pb2.Layout(category=categories))
    return layout
         
     | 
| 178 | 
         
            +
             
     | 
| 179 | 
         
            +
def save_summaries(file_writer, global_step=None):
    """Call FileWriter.add_summary() with all summaries in the default graph,
    automatically finalizing and merging them on the first call.

    Args:
        file_writer: tf.summary.FileWriter (or compatible) to write into.
        global_step: Optional step number recorded with the summary.
    """
    global _merge_op
    tfutil.assert_tf_initialized()

    # First call: build the reporting ops once and cache the merged op.
    if _merge_op is None:
        layout = finalize_autosummaries()
        if layout is not None:
            # Custom-scalar layout is itself written as a summary record.
            file_writer.add_summary(layout)
        with tf.device(None), tf.control_dependencies(None):
            _merge_op = tf.summary.merge_all()

    # Evaluating the merged op reads (and, via the reset_ops wired in
    # finalize_autosummaries, flushes) the accumulated averages.
    file_writer.add_summary(_merge_op.eval(), global_step)
         
     | 
    	
        dnnlib/tflib/custom_ops.py
    ADDED
    
    | 
         @@ -0,0 +1,181 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            # Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
         
     | 
| 2 | 
         
            +
            #
         
     | 
| 3 | 
         
            +
            # NVIDIA CORPORATION and its licensors retain all intellectual property
         
     | 
| 4 | 
         
            +
            # and proprietary rights in and to this software, related documentation
         
     | 
| 5 | 
         
            +
            # and any modifications thereto.  Any use, reproduction, disclosure or
         
     | 
| 6 | 
         
            +
            # distribution of this software and related documentation without an express
         
     | 
| 7 | 
         
            +
            # license agreement from NVIDIA CORPORATION is strictly prohibited.
         
     | 
| 8 | 
         
            +
             
     | 
| 9 | 
         
            +
            """TensorFlow custom ops builder.
         
     | 
| 10 | 
         
            +
            """
         
     | 
| 11 | 
         
            +
             
     | 
| 12 | 
         
            +
            import glob
         
     | 
| 13 | 
         
            +
            import os
         
     | 
| 14 | 
         
            +
            import re
         
     | 
| 15 | 
         
            +
            import uuid
         
     | 
| 16 | 
         
            +
            import hashlib
         
     | 
| 17 | 
         
            +
            import tempfile
         
     | 
| 18 | 
         
            +
            import shutil
         
     | 
| 19 | 
         
            +
            import tensorflow as tf
         
     | 
| 20 | 
         
            +
            from tensorflow.python.client import device_lib # pylint: disable=no-name-in-module
         
     | 
| 21 | 
         
            +
             
     | 
| 22 | 
         
            +
            from .. import util
         
     | 
| 23 | 
         
            +
             
     | 
| 24 | 
         
            +
#----------------------------------------------------------------------------
# Global options.

cuda_cache_path = None  # Optional override for the compiled-plugin cache directory (default-selection logic lives in get_plugin; not shown here).
cuda_cache_version_tag = 'v1'  # Bump to invalidate previously cached binaries.
do_not_hash_included_headers = True # Speed up compilation by assuming that headers included by the CUDA code never change.
verbose = False # Print status messages to stdout.
         
     | 
| 31 | 
         
            +
             
     | 
| 32 | 
         
            +
            #----------------------------------------------------------------------------
         
     | 
| 33 | 
         
            +
            # Internal helper funcs.
         
     | 
| 34 | 
         
            +
             
     | 
| 35 | 
         
            +
            def _find_compiler_bindir():
         
     | 
| 36 | 
         
            +
                hostx64_paths = sorted(glob.glob('C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True)
         
     | 
| 37 | 
         
            +
                if hostx64_paths != []:
         
     | 
| 38 | 
         
            +
                    return hostx64_paths[0]
         
     | 
| 39 | 
         
            +
                hostx64_paths = sorted(glob.glob('C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True)
         
     | 
| 40 | 
         
            +
                if hostx64_paths != []:
         
     | 
| 41 | 
         
            +
                    return hostx64_paths[0]
         
     | 
| 42 | 
         
            +
                hostx64_paths = sorted(glob.glob('C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True)
         
     | 
| 43 | 
         
            +
                if hostx64_paths != []:
         
     | 
| 44 | 
         
            +
                    return hostx64_paths[0]
         
     | 
| 45 | 
         
            +
                vc_bin_dir = 'C:/Program Files (x86)/Microsoft Visual Studio 14.0/vc/bin'
         
     | 
| 46 | 
         
            +
                if os.path.isdir(vc_bin_dir):
         
     | 
| 47 | 
         
            +
                    return vc_bin_dir
         
     | 
| 48 | 
         
            +
                return None
         
     | 
| 49 | 
         
            +
             
     | 
| 50 | 
         
            +
            def _get_compute_cap(device):
         
     | 
| 51 | 
         
            +
                caps_str = device.physical_device_desc
         
     | 
| 52 | 
         
            +
                m = re.search('compute capability: (\\d+).(\\d+)', caps_str)
         
     | 
| 53 | 
         
            +
                major = m.group(1)
         
     | 
| 54 | 
         
            +
                minor = m.group(2)
         
     | 
| 55 | 
         
            +
                return (major, minor)
         
     | 
| 56 | 
         
            +
             
     | 
| 57 | 
         
            +
            def _get_cuda_gpu_arch_string():
         
     | 
| 58 | 
         
            +
                gpus = [x for x in device_lib.list_local_devices() if x.device_type == 'GPU']
         
     | 
| 59 | 
         
            +
                if len(gpus) == 0:
         
     | 
| 60 | 
         
            +
                    raise RuntimeError('No GPU devices found')
         
     | 
| 61 | 
         
            +
                (major, minor) = _get_compute_cap(gpus[0])
         
     | 
| 62 | 
         
            +
                return 'sm_%s%s' % (major, minor)
         
     | 
| 63 | 
         
            +
             
     | 
| 64 | 
         
            +
            def _run_cmd(cmd):
         
     | 
| 65 | 
         
            +
                with os.popen(cmd) as pipe:
         
     | 
| 66 | 
         
            +
                    output = pipe.read()
         
     | 
| 67 | 
         
            +
                    status = pipe.close()
         
     | 
| 68 | 
         
            +
                if status is not None:
         
     | 
| 69 | 
         
            +
                    raise RuntimeError('NVCC returned an error. See below for full command line and output log:\n\n%s\n\n%s' % (cmd, output))
         
     | 
| 70 | 
         
            +
             
     | 
| 71 | 
         
            +
            def _prepare_nvcc_cli(opts):
         
     | 
| 72 | 
         
            +
                cmd = 'nvcc  --std=c++11 -DNDEBUG ' + opts.strip()
         
     | 
| 73 | 
         
            +
                cmd += ' --disable-warnings'
         
     | 
| 74 | 
         
            +
                cmd += ' --include-path "%s"' % tf.sysconfig.get_include()
         
     | 
| 75 | 
         
            +
                cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'protobuf_archive', 'src')
         
     | 
| 76 | 
         
            +
                cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'com_google_absl')
         
     | 
| 77 | 
         
            +
                cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'eigen_archive')
         
     | 
| 78 | 
         
            +
             
     | 
| 79 | 
         
            +
                compiler_bindir = _find_compiler_bindir()
         
     | 
| 80 | 
         
            +
                if compiler_bindir is None:
         
     | 
| 81 | 
         
            +
                    # Require that _find_compiler_bindir succeeds on Windows.  Allow
         
     | 
| 82 | 
         
            +
                    # nvcc to use whatever is the default on Linux.
         
     | 
| 83 | 
         
            +
                    if os.name == 'nt':
         
     | 
| 84 | 
         
            +
                        raise RuntimeError('Could not find MSVC/GCC/CLANG installation on this computer. Check compiler_bindir_search_path list in "%s".' % __file__)
         
     | 
| 85 | 
         
            +
                else:
         
     | 
| 86 | 
         
            +
                    cmd += ' --compiler-bindir "%s"' % compiler_bindir
         
     | 
| 87 | 
         
            +
                cmd += ' 2>&1'
         
     | 
| 88 | 
         
            +
                return cmd
         
     | 
| 89 | 
         
            +
             
     | 
| 90 | 
         
            +
#----------------------------------------------------------------------------
# Main entry point.

_plugin_cache = dict()  # cuda_file path => cached result of get_plugin(), so each .cu is processed at most once per process
         
     | 
| 94 | 
         
            +
             
     | 
| 95 | 
         
            +
            def get_plugin(cuda_file, extra_nvcc_options=[]):
         
     | 
| 96 | 
         
            +
                cuda_file_base = os.path.basename(cuda_file)
         
     | 
| 97 | 
         
            +
                cuda_file_name, cuda_file_ext = os.path.splitext(cuda_file_base)
         
     | 
| 98 | 
         
            +
             
     | 
| 99 | 
         
            +
                # Already in cache?
         
     | 
| 100 | 
         
            +
                if cuda_file in _plugin_cache:
         
     | 
| 101 | 
         
            +
                    return _plugin_cache[cuda_file]
         
     | 
| 102 | 
         
            +
             
     | 
| 103 | 
         
            +
                # Setup plugin.
         
     | 
| 104 | 
         
            +
                if verbose:
         
     | 
| 105 | 
         
            +
                    print('Setting up TensorFlow plugin "%s": ' % cuda_file_base, end='', flush=True)
         
     | 
| 106 | 
         
            +
                try:
         
     | 
| 107 | 
         
            +
                    # Hash CUDA source.
         
     | 
| 108 | 
         
            +
                    md5 = hashlib.md5()
         
     | 
| 109 | 
         
            +
                    with open(cuda_file, 'rb') as f:
         
     | 
| 110 | 
         
            +
                        md5.update(f.read())
         
     | 
| 111 | 
         
            +
                    md5.update(b'\n')
         
     | 
| 112 | 
         
            +
             
     | 
| 113 | 
         
            +
                    # Hash headers included by the CUDA code by running it through the preprocessor.
         
     | 
| 114 | 
         
            +
                    if not do_not_hash_included_headers:
         
     | 
| 115 | 
         
            +
                        if verbose:
         
     | 
| 116 | 
         
            +
                            print('Preprocessing... ', end='', flush=True)
         
     | 
| 117 | 
         
            +
                        with tempfile.TemporaryDirectory() as tmp_dir:
         
     | 
| 118 | 
         
            +
                            tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + cuda_file_ext)
         
     | 
| 119 | 
         
            +
                            _run_cmd(_prepare_nvcc_cli('"%s" --preprocess -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir)))
         
     | 
| 120 | 
         
            +
                            with open(tmp_file, 'rb') as f:
         
     | 
| 121 | 
         
            +
                                bad_file_str = ('"' + cuda_file.replace('\\', '/') + '"').encode('utf-8') # __FILE__ in error check macros
         
     | 
| 122 | 
         
            +
                                good_file_str = ('"' + cuda_file_base + '"').encode('utf-8')
         
     | 
| 123 | 
         
            +
                                for ln in f:
         
     | 
| 124 | 
         
            +
                                    if not ln.startswith(b'# ') and not ln.startswith(b'#line '): # ignore line number pragmas
         
     | 
| 125 | 
         
            +
                                        ln = ln.replace(bad_file_str, good_file_str)
         
     | 
| 126 | 
         
            +
                                        md5.update(ln)
         
     | 
| 127 | 
         
            +
                                md5.update(b'\n')
         
     | 
| 128 | 
         
            +
             
     | 
| 129 | 
         
            +
                    # Select compiler options.
         
     | 
| 130 | 
         
            +
                    compile_opts = ''
         
     | 
| 131 | 
         
            +
                    if os.name == 'nt':
         
     | 
| 132 | 
         
            +
                        compile_opts += '"%s"' % os.path.join(tf.sysconfig.get_lib(), 'python', '_pywrap_tensorflow_internal.lib')
         
     | 
| 133 | 
         
            +
                    elif os.name == 'posix':
         
     | 
| 134 | 
         
            +
                        compile_opts += f' --compiler-options \'-fPIC\''
         
     | 
| 135 | 
         
            +
                        compile_opts += f' --compiler-options \'{" ".join(tf.sysconfig.get_compile_flags())}\''
         
     | 
| 136 | 
         
            +
                        compile_opts += f' --linker-options \'{" ".join(tf.sysconfig.get_link_flags())}\''
         
     | 
| 137 | 
         
            +
                    else:
         
     | 
| 138 | 
         
            +
                        assert False # not Windows or Linux, w00t?
         
     | 
| 139 | 
         
            +
                    compile_opts += f' --gpu-architecture={_get_cuda_gpu_arch_string()}'
         
     | 
| 140 | 
         
            +
                    compile_opts += ' --use_fast_math'
         
     | 
| 141 | 
         
            +
                    for opt in extra_nvcc_options:
         
     | 
| 142 | 
         
            +
                        compile_opts += ' ' + opt
         
     | 
| 143 | 
         
            +
                    nvcc_cmd = _prepare_nvcc_cli(compile_opts)
         
     | 
| 144 | 
         
            +
             
     | 
| 145 | 
         
            +
                    # Hash build configuration.
         
     | 
| 146 | 
         
            +
                    md5.update(('nvcc_cmd: ' + nvcc_cmd).encode('utf-8') + b'\n')
         
     | 
| 147 | 
         
            +
                    md5.update(('tf.VERSION: ' + tf.VERSION).encode('utf-8') + b'\n')
         
     | 
| 148 | 
         
            +
                    md5.update(('cuda_cache_version_tag: ' + cuda_cache_version_tag).encode('utf-8') + b'\n')
         
     | 
| 149 | 
         
            +
             
     | 
| 150 | 
         
            +
                    # Compile if not already compiled.
         
     | 
| 151 | 
         
            +
                    cache_dir = util.make_cache_dir_path('tflib-cudacache') if cuda_cache_path is None else cuda_cache_path
         
     | 
| 152 | 
         
            +
                    bin_file_ext = '.dll' if os.name == 'nt' else '.so'
         
     | 
| 153 | 
         
            +
                    bin_file = os.path.join(cache_dir, cuda_file_name + '_' + md5.hexdigest() + bin_file_ext)
         
     | 
| 154 | 
         
            +
                    if not os.path.isfile(bin_file):
         
     | 
| 155 | 
         
            +
                        if verbose:
         
     | 
| 156 | 
         
            +
                            print('Compiling... ', end='', flush=True)
         
     | 
| 157 | 
         
            +
                        with tempfile.TemporaryDirectory() as tmp_dir:
         
     | 
| 158 | 
         
            +
                            tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + bin_file_ext)
         
     | 
| 159 | 
         
            +
                            _run_cmd(nvcc_cmd + ' "%s" --shared -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir))
         
     | 
| 160 | 
         
            +
                            os.makedirs(cache_dir, exist_ok=True)
         
     | 
| 161 | 
         
            +
                            intermediate_file = os.path.join(cache_dir, cuda_file_name + '_' + uuid.uuid4().hex + '_tmp' + bin_file_ext)
         
     | 
| 162 | 
         
            +
                            shutil.copyfile(tmp_file, intermediate_file)
         
     | 
| 163 | 
         
            +
                            os.rename(intermediate_file, bin_file) # atomic
         
     | 
| 164 | 
         
            +
             
     | 
| 165 | 
         
            +
                    # Load.
         
     | 
| 166 | 
         
            +
                    if verbose:
         
     | 
| 167 | 
         
            +
                        print('Loading... ', end='', flush=True)
         
     | 
| 168 | 
         
            +
                    plugin = tf.load_op_library(bin_file)
         
     | 
| 169 | 
         
            +
             
     | 
| 170 | 
         
            +
                    # Add to cache.
         
     | 
| 171 | 
         
            +
                    _plugin_cache[cuda_file] = plugin
         
     | 
| 172 | 
         
            +
                    if verbose:
         
     | 
| 173 | 
         
            +
                        print('Done.', flush=True)
         
     | 
| 174 | 
         
            +
                    return plugin
         
     | 
| 175 | 
         
            +
             
     | 
| 176 | 
         
            +
                except:
         
     | 
| 177 | 
         
            +
                    if verbose:
         
     | 
| 178 | 
         
            +
                        print('Failed!', flush=True)
         
     | 
| 179 | 
         
            +
                    raise
         
     | 
| 180 | 
         
            +
             
     | 
| 181 | 
         
            +
            #----------------------------------------------------------------------------
         
     | 
    	
        dnnlib/tflib/network.py
    ADDED
    
    | 
         @@ -0,0 +1,825 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            # Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
         
     | 
| 2 | 
         
            +
            #
         
     | 
| 3 | 
         
            +
            # NVIDIA CORPORATION and its licensors retain all intellectual property
         
     | 
| 4 | 
         
            +
            # and proprietary rights in and to this software, related documentation
         
     | 
| 5 | 
         
            +
            # and any modifications thereto.  Any use, reproduction, disclosure or
         
     | 
| 6 | 
         
            +
            # distribution of this software and related documentation without an express
         
     | 
| 7 | 
         
            +
            # license agreement from NVIDIA CORPORATION is strictly prohibited.
         
     | 
| 8 | 
         
            +
             
     | 
| 9 | 
         
            +
            """Helper for managing networks."""
         
     | 
| 10 | 
         
            +
             
     | 
| 11 | 
         
            +
            import types
         
     | 
| 12 | 
         
            +
            import inspect
         
     | 
| 13 | 
         
            +
            import re
         
     | 
| 14 | 
         
            +
            import uuid
         
     | 
| 15 | 
         
            +
            import sys
         
     | 
| 16 | 
         
            +
            import copy
         
     | 
| 17 | 
         
            +
            import numpy as np
         
     | 
| 18 | 
         
            +
            import tensorflow as tf
         
     | 
| 19 | 
         
            +
             
     | 
| 20 | 
         
            +
            from collections import OrderedDict
         
     | 
| 21 | 
         
            +
            from typing import Any, List, Tuple, Union, Callable
         
     | 
| 22 | 
         
            +
             
     | 
| 23 | 
         
            +
            from . import tfutil
         
     | 
| 24 | 
         
            +
            from .. import util
         
     | 
| 25 | 
         
            +
             
     | 
| 26 | 
         
            +
            from .tfutil import TfExpression, TfExpressionEx
         
     | 
| 27 | 
         
            +
             
     | 
| 28 | 
         
            +
            # pylint: disable=protected-access
         
     | 
| 29 | 
         
            +
            # pylint: disable=attribute-defined-outside-init
         
     | 
| 30 | 
         
            +
            # pylint: disable=too-many-public-methods
         
     | 
| 31 | 
         
            +
             
     | 
| 32 | 
         
            +
            _import_handlers = []  # Custom import handlers for dealing with legacy data in pickle import.
         
     | 
| 33 | 
         
            +
            _import_module_src = dict()  # Source code for temporary modules created during pickle import.
         
     | 
| 34 | 
         
            +
             
     | 
| 35 | 
         
            +
             
     | 
| 36 | 
         
            +
            def import_handler(handler_func):
         
     | 
| 37 | 
         
            +
                """Function decorator for declaring custom import handlers."""
         
     | 
| 38 | 
         
            +
                _import_handlers.append(handler_func)
         
     | 
| 39 | 
         
            +
                return handler_func
         
     | 
| 40 | 
         
            +
             
     | 
| 41 | 
         
            +
             
     | 
| 42 | 
         
            +
            class Network:
         
     | 
| 43 | 
         
            +
                """Generic network abstraction.
         
     | 
| 44 | 
         
            +
             
     | 
| 45 | 
         
            +
                Acts as a convenience wrapper for a parameterized network construction
         
     | 
| 46 | 
         
            +
                function, providing several utility methods and convenient access to
         
     | 
| 47 | 
         
            +
                the inputs/outputs/weights.
         
     | 
| 48 | 
         
            +
             
     | 
| 49 | 
         
            +
                Network objects can be safely pickled and unpickled for long-term
         
     | 
| 50 | 
         
            +
                archival purposes. The pickling works reliably as long as the underlying
         
     | 
| 51 | 
         
            +
                network construction function is defined in a standalone Python module
         
     | 
| 52 | 
         
            +
                that has no side effects or application-specific imports.
         
     | 
| 53 | 
         
            +
             
     | 
| 54 | 
         
            +
                Args:
         
     | 
| 55 | 
         
            +
                    name: Network name. Used to select TensorFlow name and variable scopes. Defaults to build func name if None.
         
     | 
| 56 | 
         
            +
                    func_name: Fully qualified name of the underlying network construction function, or a top-level function object.
         
     | 
| 57 | 
         
            +
                    static_kwargs: Keyword arguments to be passed in to the network construction function.
         
     | 
| 58 | 
         
            +
                """
         
     | 
| 59 | 
         
            +
             
     | 
| 60 | 
         
            +
                def __init__(self, name: str = None, func_name: Any = None, **static_kwargs):
         
     | 
| 61 | 
         
            +
                    # Locate the user-specified build function.
         
     | 
| 62 | 
         
            +
                    assert isinstance(func_name, str) or util.is_top_level_function(func_name)
         
     | 
| 63 | 
         
            +
                    if util.is_top_level_function(func_name):
         
     | 
| 64 | 
         
            +
                        func_name = util.get_top_level_function_name(func_name)
         
     | 
| 65 | 
         
            +
                    module, func_name = util.get_module_from_obj_name(func_name)
         
     | 
| 66 | 
         
            +
                    func = util.get_obj_from_module(module, func_name)
         
     | 
| 67 | 
         
            +
             
     | 
| 68 | 
         
            +
                    # Dig up source code for the module containing the build function.
         
     | 
| 69 | 
         
            +
                    module_src = _import_module_src.get(module, None)
         
     | 
| 70 | 
         
            +
                    if module_src is None:
         
     | 
| 71 | 
         
            +
                        module_src = inspect.getsource(module)
         
     | 
| 72 | 
         
            +
             
     | 
| 73 | 
         
            +
                    # Initialize fields.
         
     | 
| 74 | 
         
            +
                    self._init_fields(name=(name or func_name), static_kwargs=static_kwargs, build_func=func, build_func_name=func_name, build_module_src=module_src)
         
     | 
| 75 | 
         
            +
             
     | 
| 76 | 
         
            +
    def _init_fields(self, name: str, static_kwargs: dict, build_func: Callable, build_func_name: str, build_module_src: str) -> None:
        """Validate constructor arguments and set up all instance fields.

        Args:
            name: User-specified network name; must be non-empty and match the
                allowed character set below.
            static_kwargs: Pickleable keyword arguments forwarded to build_func.
            build_func: Callable that constructs the network graph.
            build_func_name: Name of build_func within its module (kept for pickling).
            build_module_src: Full source text of the module defining build_func
                (kept for pickling).
        """
        tfutil.assert_tf_initialized()
        assert isinstance(name, str)
        assert len(name) >= 1
        # NOTE(review): in this raw string the class contains "\\-", i.e. a literal
        # backslash plus "-", so backslashes are accepted in names — confirm whether
        # "\-" (hyphen only) was intended.
        assert re.fullmatch(r"[A-Za-z0-9_.\\-]*", name)
        assert isinstance(static_kwargs, dict)
        assert util.is_pickleable(static_kwargs)
        assert callable(build_func)
        assert isinstance(build_func_name, str)
        assert isinstance(build_module_src, str)

        # Choose TensorFlow name scope.
        # Entering tf.name_scope(None) resets to the graph root so the uniquified
        # scope name is absolute, not nested under the caller's scope.
        with tf.name_scope(None):
            scope = tf.get_default_graph().unique_name(name, mark_as_used=True)

        # Query current TensorFlow device.
        # A throwaway no-op is created just to read back the device string that
        # TensorFlow would assign under the current device context.
        with tfutil.absolute_name_scope(scope), tf.control_dependencies(None):
            device = tf.no_op(name="_QueryDevice").device

        # Immutable state.
        self._name                  = name
        self._scope                 = scope
        self._device                = device
        self._static_kwargs         = util.EasyDict(copy.deepcopy(static_kwargs))  # deep copy: caller mutations must not leak in
        self._build_func            = build_func
        self._build_func_name       = build_func_name
        self._build_module_src      = build_module_src

        # State before _init_graph().
        self._var_inits             = dict()    # var_name => initial_value, set to None by _init_graph()
        self._all_inits_known       = False     # Do we know for sure that _var_inits covers all the variables?
        self._components            = None      # subnet_name => Network, None if the components are not known yet

        # Initialized by _init_graph().
        self._input_templates       = None
        self._output_templates      = None
        self._own_vars              = None

        # Cached values initialized the respective methods.
        self._input_shapes          = None
        self._output_shapes         = None
        self._input_names           = None
        self._output_names          = None
        self._vars                  = None
        self._trainables            = None
        self._var_global_to_local   = None
        self._run_cache             = dict()
        # NOTE(review): this variable is created in the *caller's* current scope and
        # device, not under the network's own scope override above, so it is not part
        # of _own_vars — confirm that is intentional (it appears to be a local
        # addition on top of the upstream implementation).
        self.epochs = tf.Variable(0., dtype=tf.float32, name='epochs')
         
     | 
| 124 | 
         
            +
             
     | 
| 125 | 
         
            +
    def _init_graph(self) -> None:
        """Build the template graph: input placeholders, build_func outputs, and
        the variable registry. Runs once, lazily; afterwards _var_inits is None."""
        assert self._var_inits is not None
        assert self._input_templates is None
        assert self._output_templates is None
        assert self._own_vars is None

        # Initialize components.
        if self._components is None:
            self._components = util.EasyDict()

        # Choose build func kwargs.
        build_kwargs = dict(self.static_kwargs)
        build_kwargs["is_template_graph"] = True
        build_kwargs["components"] = self._components

        # Override scope and device, and ignore surrounding control dependencies.
        with tfutil.absolute_variable_scope(self.scope, reuse=False), tfutil.absolute_name_scope(self.scope), tf.device(self.device), tf.control_dependencies(None):
            assert tf.get_variable_scope().name == self.scope
            assert tf.get_default_graph().get_name_scope() == self.scope

            # Create input templates.
            # One float32 placeholder per required positional parameter of the
            # build func (parameters with defaults are treated as options, not inputs).
            self._input_templates = []
            for param in inspect.signature(self._build_func).parameters.values():
                if param.kind == param.POSITIONAL_OR_KEYWORD and param.default is param.empty:
                    self._input_templates.append(tf.placeholder(tf.float32, name=param.name))

            # Call build func.
            out_expr = self._build_func(*self._input_templates, **build_kwargs)

        # Collect output templates and variables.
        assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)
        self._output_templates = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr)
        # Local variable name = global name with "<scope>/" prefix and ":0" suffix stripped.
        self._own_vars = OrderedDict((var.name[len(self.scope) + 1:].split(":")[0], var) for var in tf.global_variables(self.scope + "/"))

        # Check for errors.
        if len(self._input_templates) == 0:
            raise ValueError("Network build func did not list any inputs.")
        if len(self._output_templates) == 0:
            raise ValueError("Network build func did not return any outputs.")
        if any(not tfutil.is_tf_expression(t) for t in self._output_templates):
            raise ValueError("Network outputs must be TensorFlow expressions.")
        if any(t.shape.ndims is None for t in self._input_templates):
            raise ValueError("Network input shapes not defined. Please call x.set_shape() for each input.")
        if any(t.shape.ndims is None for t in self._output_templates):
            raise ValueError("Network output shapes not defined. Please call x.set_shape() where applicable.")
        if any(not isinstance(comp, Network) for comp in self._components.values()):
            raise ValueError("Components of a Network must be Networks themselves.")
        if len(self._components) != len(set(comp.name for comp in self._components.values())):
            raise ValueError("Components of a Network must have unique names.")

        # Initialize variables.
        # Pending inits recorded before the graph existed are applied first;
        # any variable without a recorded init gets its default initializer run,
        # unless _all_inits_known promised there would be none.
        if len(self._var_inits):
            tfutil.set_vars({self._get_vars()[name]: value for name, value in self._var_inits.items() if name in self._get_vars()})
        remaining_inits = [var.initializer for name, var in self._own_vars.items() if name not in self._var_inits]
        if self._all_inits_known:
            assert len(remaining_inits) == 0
        else:
            tfutil.run(remaining_inits)
        self._var_inits = None
         
     | 
| 184 | 
         
            +
             
     | 
| 185 | 
         
            +
                @property
         
     | 
| 186 | 
         
            +
                def name(self):
         
     | 
| 187 | 
         
            +
                    """User-specified name string."""
         
     | 
| 188 | 
         
            +
                    return self._name
         
     | 
| 189 | 
         
            +
             
     | 
| 190 | 
         
            +
                @property
         
     | 
| 191 | 
         
            +
                def scope(self):
         
     | 
| 192 | 
         
            +
                    """Unique TensorFlow scope containing template graph and variables, derived from the user-specified name."""
         
     | 
| 193 | 
         
            +
                    return self._scope
         
     | 
| 194 | 
         
            +
             
     | 
| 195 | 
         
            +
                @property
         
     | 
| 196 | 
         
            +
                def device(self):
         
     | 
| 197 | 
         
            +
                    """Name of the TensorFlow device that the weights of this network reside on. Determined by the current device at construction time."""
         
     | 
| 198 | 
         
            +
                    return self._device
         
     | 
| 199 | 
         
            +
             
     | 
| 200 | 
         
            +
                @property
         
     | 
| 201 | 
         
            +
                def static_kwargs(self):
         
     | 
| 202 | 
         
            +
                    """EasyDict of arguments passed to the user-supplied build func."""
         
     | 
| 203 | 
         
            +
                    return copy.deepcopy(self._static_kwargs)
         
     | 
| 204 | 
         
            +
             
     | 
| 205 | 
         
            +
                @property
         
     | 
| 206 | 
         
            +
                def components(self):
         
     | 
| 207 | 
         
            +
                    """EasyDict of sub-networks created by the build func."""
         
     | 
| 208 | 
         
            +
                    return copy.copy(self._get_components())
         
     | 
| 209 | 
         
            +
             
     | 
| 210 | 
         
            +
                def _get_components(self):
         
     | 
| 211 | 
         
            +
                    if self._components is None:
         
     | 
| 212 | 
         
            +
                        self._init_graph()
         
     | 
| 213 | 
         
            +
                        assert self._components is not None
         
     | 
| 214 | 
         
            +
                    return self._components
         
     | 
| 215 | 
         
            +
             
     | 
| 216 | 
         
            +
                @property
         
     | 
| 217 | 
         
            +
                def input_shapes(self):
         
     | 
| 218 | 
         
            +
                    """List of input tensor shapes, including minibatch dimension."""
         
     | 
| 219 | 
         
            +
                    if self._input_shapes is None:
         
     | 
| 220 | 
         
            +
                        self._input_shapes = [t.shape.as_list() for t in self.input_templates]
         
     | 
| 221 | 
         
            +
                    return copy.deepcopy(self._input_shapes)
         
     | 
| 222 | 
         
            +
             
     | 
| 223 | 
         
            +
                @property
         
     | 
| 224 | 
         
            +
                def output_shapes(self):
         
     | 
| 225 | 
         
            +
                    """List of output tensor shapes, including minibatch dimension."""
         
     | 
| 226 | 
         
            +
                    if self._output_shapes is None:
         
     | 
| 227 | 
         
            +
                        self._output_shapes = [t.shape.as_list() for t in self.output_templates]
         
     | 
| 228 | 
         
            +
                    return copy.deepcopy(self._output_shapes)
         
     | 
| 229 | 
         
            +
             
     | 
| 230 | 
         
            +
                @property
         
     | 
| 231 | 
         
            +
                def input_shape(self):
         
     | 
| 232 | 
         
            +
                    """Short-hand for input_shapes[0]."""
         
     | 
| 233 | 
         
            +
                    return self.input_shapes[0]
         
     | 
| 234 | 
         
            +
             
     | 
| 235 | 
         
            +
                @property
         
     | 
| 236 | 
         
            +
                def output_shape(self):
         
     | 
| 237 | 
         
            +
                    """Short-hand for output_shapes[0]."""
         
     | 
| 238 | 
         
            +
                    return self.output_shapes[0]
         
     | 
| 239 | 
         
            +
             
     | 
| 240 | 
         
            +
                @property
         
     | 
| 241 | 
         
            +
                def num_inputs(self):
         
     | 
| 242 | 
         
            +
                    """Number of input tensors."""
         
     | 
| 243 | 
         
            +
                    return len(self.input_shapes)
         
     | 
| 244 | 
         
            +
             
     | 
| 245 | 
         
            +
                @property
         
     | 
| 246 | 
         
            +
                def num_outputs(self):
         
     | 
| 247 | 
         
            +
                    """Number of output tensors."""
         
     | 
| 248 | 
         
            +
                    return len(self.output_shapes)
         
     | 
| 249 | 
         
            +
             
     | 
| 250 | 
         
            +
                @property
         
     | 
| 251 | 
         
            +
                def input_names(self):
         
     | 
| 252 | 
         
            +
                    """Name string for each input."""
         
     | 
| 253 | 
         
            +
                    if self._input_names is None:
         
     | 
| 254 | 
         
            +
                        self._input_names = [t.name.split("/")[-1].split(":")[0] for t in self.input_templates]
         
     | 
| 255 | 
         
            +
                    return copy.copy(self._input_names)
         
     | 
| 256 | 
         
            +
             
     | 
| 257 | 
         
            +
                @property
         
     | 
| 258 | 
         
            +
                def output_names(self):
         
     | 
| 259 | 
         
            +
                    """Name string for each output."""
         
     | 
| 260 | 
         
            +
                    if self._output_names is None:
         
     | 
| 261 | 
         
            +
                        self._output_names = [t.name.split("/")[-1].split(":")[0] for t in self.output_templates]
         
     | 
| 262 | 
         
            +
                    return copy.copy(self._output_names)
         
     | 
| 263 | 
         
            +
             
     | 
| 264 | 
         
            +
                @property
         
     | 
| 265 | 
         
            +
                def input_templates(self):
         
     | 
| 266 | 
         
            +
                    """Input placeholders in the template graph."""
         
     | 
| 267 | 
         
            +
                    if self._input_templates is None:
         
     | 
| 268 | 
         
            +
                        self._init_graph()
         
     | 
| 269 | 
         
            +
                        assert self._input_templates is not None
         
     | 
| 270 | 
         
            +
                    return copy.copy(self._input_templates)
         
     | 
| 271 | 
         
            +
             
     | 
| 272 | 
         
            +
                @property
         
     | 
| 273 | 
         
            +
                def output_templates(self):
         
     | 
| 274 | 
         
            +
                    """Output tensors in the template graph."""
         
     | 
| 275 | 
         
            +
                    if self._output_templates is None:
         
     | 
| 276 | 
         
            +
                        self._init_graph()
         
     | 
| 277 | 
         
            +
                        assert self._output_templates is not None
         
     | 
| 278 | 
         
            +
                    return copy.copy(self._output_templates)
         
     | 
| 279 | 
         
            +
             
     | 
| 280 | 
         
            +
                @property
         
     | 
| 281 | 
         
            +
                def own_vars(self):
         
     | 
| 282 | 
         
            +
                    """Variables defined by this network (local_name => var), excluding sub-networks."""
         
     | 
| 283 | 
         
            +
                    return copy.copy(self._get_own_vars())
         
     | 
| 284 | 
         
            +
             
     | 
| 285 | 
         
            +
                def _get_own_vars(self):
         
     | 
| 286 | 
         
            +
                    if self._own_vars is None:
         
     | 
| 287 | 
         
            +
                        self._init_graph()
         
     | 
| 288 | 
         
            +
                        assert self._own_vars is not None
         
     | 
| 289 | 
         
            +
                    return self._own_vars
         
     | 
| 290 | 
         
            +
             
     | 
| 291 | 
         
            +
                @property
         
     | 
| 292 | 
         
            +
                def vars(self):
         
     | 
| 293 | 
         
            +
                    """All variables (local_name => var)."""
         
     | 
| 294 | 
         
            +
                    return copy.copy(self._get_vars())
         
     | 
| 295 | 
         
            +
             
     | 
| 296 | 
         
            +
                def _get_vars(self):
         
     | 
| 297 | 
         
            +
                    if self._vars is None:
         
     | 
| 298 | 
         
            +
                        self._vars = OrderedDict(self._get_own_vars())
         
     | 
| 299 | 
         
            +
                        for comp in self._get_components().values():
         
     | 
| 300 | 
         
            +
                            self._vars.update((comp.name + "/" + name, var) for name, var in comp._get_vars().items())
         
     | 
| 301 | 
         
            +
                    return self._vars
         
     | 
| 302 | 
         
            +
             
     | 
| 303 | 
         
            +
                @property
         
     | 
| 304 | 
         
            +
                def trainables(self):
         
     | 
| 305 | 
         
            +
                    """All trainable variables (local_name => var)."""
         
     | 
| 306 | 
         
            +
                    return copy.copy(self._get_trainables())
         
     | 
| 307 | 
         
            +
             
     | 
| 308 | 
         
            +
                def _get_trainables(self):
         
     | 
| 309 | 
         
            +
                    if self._trainables is None:
         
     | 
| 310 | 
         
            +
                        self._trainables = OrderedDict((name, var) for name, var in self.vars.items() if var.trainable)
         
     | 
| 311 | 
         
            +
                    return self._trainables
         
     | 
| 312 | 
         
            +
             
     | 
| 313 | 
         
            +
                @property
         
     | 
| 314 | 
         
            +
                def var_global_to_local(self):
         
     | 
| 315 | 
         
            +
                    """Mapping from variable global names to local names."""
         
     | 
| 316 | 
         
            +
                    return copy.copy(self._get_var_global_to_local())
         
     | 
| 317 | 
         
            +
             
     | 
| 318 | 
         
            +
                def _get_var_global_to_local(self):
         
     | 
| 319 | 
         
            +
                    if self._var_global_to_local is None:
         
     | 
| 320 | 
         
            +
                        self._var_global_to_local = OrderedDict((var.name.split(":")[0], name) for name, var in self.vars.items())
         
     | 
| 321 | 
         
            +
                    return self._var_global_to_local
         
     | 
| 322 | 
         
            +
             
     | 
| 323 | 
         
            +
                def reset_own_vars(self) -> None:
         
     | 
| 324 | 
         
            +
                    """Re-initialize all variables of this network, excluding sub-networks."""
         
     | 
| 325 | 
         
            +
                    if self._var_inits is None or self._components is None:
         
     | 
| 326 | 
         
            +
                        tfutil.run([var.initializer for var in self._get_own_vars().values()])
         
     | 
| 327 | 
         
            +
                    else:
         
     | 
| 328 | 
         
            +
                        self._var_inits.clear()
         
     | 
| 329 | 
         
            +
                        self._all_inits_known = False
         
     | 
| 330 | 
         
            +
             
     | 
| 331 | 
         
            +
                def reset_vars(self) -> None:
         
     | 
| 332 | 
         
            +
                    """Re-initialize all variables of this network, including sub-networks."""
         
     | 
| 333 | 
         
            +
                    if self._var_inits is None:
         
     | 
| 334 | 
         
            +
                        tfutil.run([var.initializer for var in self._get_vars().values()])
         
     | 
| 335 | 
         
            +
                    else:
         
     | 
| 336 | 
         
            +
                        self._var_inits.clear()
         
     | 
| 337 | 
         
            +
                        self._all_inits_known = False
         
     | 
| 338 | 
         
            +
                        if self._components is not None:
         
     | 
| 339 | 
         
            +
                            for comp in self._components.values():
         
     | 
| 340 | 
         
            +
                                comp.reset_vars()
         
     | 
| 341 | 
         
            +
             
     | 
| 342 | 
         
            +
                def reset_trainables(self) -> None:
         
     | 
| 343 | 
         
            +
                    """Re-initialize all trainable variables of this network, including sub-networks."""
         
     | 
| 344 | 
         
            +
                    tfutil.run([var.initializer for var in self._get_trainables().values()])
         
     | 
| 345 | 
         
            +
             
     | 
| 346 | 
         
            +
    def get_output_for(self, *in_expr: TfExpression, return_as_list: bool = False, **dynamic_kwargs) -> Union[TfExpression, List[TfExpression]]:
        """Construct TensorFlow expression(s) for the output(s) of this network, given the input expression(s).
        The graph is placed on the current TensorFlow device."""
        # One expression per declared input; at least one must be non-None so the
        # minibatch size can be inferred below.
        assert len(in_expr) == self.num_inputs
        assert not all(expr is None for expr in in_expr)
        self._get_vars()  # ensure that all variables have been created

        # Choose build func kwargs.
        # Per-call dynamic_kwargs override the stored static ones.
        build_kwargs = dict(self.static_kwargs)
        build_kwargs.update(dynamic_kwargs)
        build_kwargs["is_template_graph"] = False
        build_kwargs["components"] = self._components

        # Build TensorFlow graph to evaluate the network.
        # reuse=True: share the template graph's variables rather than creating new ones.
        with tfutil.absolute_variable_scope(self.scope, reuse=True), tf.name_scope(self.name):
            assert tf.get_variable_scope().name == self.scope
            valid_inputs = [expr for expr in in_expr if expr is not None]
            final_inputs = []
            for expr, name, shape in zip(in_expr, self.input_names, self.input_shapes):
                if expr is not None:
                    # Rename the tensor to the canonical input name.
                    expr = tf.identity(expr, name=name)
                else:
                    # Substitute zeros for a missing input, borrowing the minibatch
                    # size from the first non-None input.
                    expr = tf.zeros([tf.shape(valid_inputs[0])[0]] + shape[1:], name=name)
                final_inputs.append(expr)
            out_expr = self._build_func(*final_inputs, **build_kwargs)

        # Propagate input shapes back to the user-specified expressions.
        for expr, final in zip(in_expr, final_inputs):
            if isinstance(expr, tf.Tensor):
                expr.set_shape(final.shape)

        # Express outputs in the desired format.
        assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)
        if return_as_list:
            out_expr = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr)
        return out_expr
         
     | 
| 382 | 
         
            +
             
     | 
| 383 | 
         
            +
                def get_var_local_name(self, var_or_global_name: Union[TfExpression, str]) -> str:
         
     | 
| 384 | 
         
            +
                    """Get the local name of a given variable, without any surrounding name scopes."""
         
     | 
| 385 | 
         
            +
                    assert tfutil.is_tf_expression(var_or_global_name) or isinstance(var_or_global_name, str)
         
     | 
| 386 | 
         
            +
                    global_name = var_or_global_name if isinstance(var_or_global_name, str) else var_or_global_name.name
         
     | 
| 387 | 
         
            +
                    return self._get_var_global_to_local()[global_name]
         
     | 
| 388 | 
         
            +
             
     | 
| 389 | 
         
            +
                def find_var(self, var_or_local_name: Union[TfExpression, str]) -> TfExpression:
         
     | 
| 390 | 
         
            +
                    """Find variable by local or global name."""
         
     | 
| 391 | 
         
            +
                    assert tfutil.is_tf_expression(var_or_local_name) or isinstance(var_or_local_name, str)
         
     | 
| 392 | 
         
            +
                    return self._get_vars()[var_or_local_name] if isinstance(var_or_local_name, str) else var_or_local_name
         
     | 
| 393 | 
         
            +
             
     | 
| 394 | 
         
            +
                def get_var(self, var_or_local_name: Union[TfExpression, str]) -> np.ndarray:
         
     | 
| 395 | 
         
            +
                    """Get the value of a given variable as NumPy array.
         
     | 
| 396 | 
         
            +
                    Note: This method is very inefficient -- prefer to use tflib.run(list_of_vars) whenever possible."""
         
     | 
| 397 | 
         
            +
                    return self.find_var(var_or_local_name).eval()
         
     | 
| 398 | 
         
            +
             
     | 
| 399 | 
         
            +
                def set_var(self, var_or_local_name: Union[TfExpression, str], new_value: Union[int, float, np.ndarray]) -> None:
         
     | 
| 400 | 
         
            +
                    """Set the value of a given variable based on the given NumPy array.
         
     | 
| 401 | 
         
            +
                    Note: This method is very inefficient -- prefer to use tflib.set_vars() whenever possible."""
         
     | 
| 402 | 
         
            +
                    tfutil.set_vars({self.find_var(var_or_local_name): new_value})
         
     | 
| 403 | 
         
            +
             
     | 
| 404 | 
         
            +
                def __getstate__(self) -> dict:
         
     | 
| 405 | 
         
            +
                    """Pickle export."""
         
     | 
| 406 | 
         
            +
                    state = dict()
         
     | 
| 407 | 
         
            +
                    state["version"]            = 5
         
     | 
| 408 | 
         
            +
                    state["name"]               = self.name
         
     | 
| 409 | 
         
            +
                    state["static_kwargs"]      = dict(self.static_kwargs)
         
     | 
| 410 | 
         
            +
                    state["components"]         = dict(self.components)
         
     | 
| 411 | 
         
            +
                    state["build_module_src"]   = self._build_module_src
         
     | 
| 412 | 
         
            +
                    state["build_func_name"]    = self._build_func_name
         
     | 
| 413 | 
         
            +
                    state["variables"]          = list(zip(self._get_own_vars().keys(), tfutil.run(list(self._get_own_vars().values()))))
         
     | 
| 414 | 
         
            +
                    state["input_shapes"]       = self.input_shapes
         
     | 
| 415 | 
         
            +
                    state["output_shapes"]      = self.output_shapes
         
     | 
| 416 | 
         
            +
                    state["input_names"]        = self.input_names
         
     | 
| 417 | 
         
            +
                    state["output_names"]       = self.output_names
         
     | 
| 418 | 
         
            +
                    return state
         
     | 
| 419 | 
         
            +
             
     | 
| 420 | 
         
            +
    def __setstate__(self, state: dict) -> None:
        """Pickle import.

        Reconstructs the network from a pickled state dict: re-creates the
        build module from the stored source code and records variable values
        as pending initializers to be applied when the graph is built.
        """

        # Execute custom import handlers (e.g. legacy-format converters).
        # Each handler may rewrite the state dict before it is interpreted.
        for handler in _import_handlers:
            state = handler(state)

        # Get basic fields. Only state formats 2..5 are supported.
        assert state["version"] in [2, 3, 4, 5]
        name = state["name"]
        static_kwargs = state["static_kwargs"]
        build_module_src = state["build_module_src"]
        build_func_name = state["build_func_name"]

        # Create temporary module from the imported source code so the build
        # function can be re-created even if its defining file is absent.
        # NOTE(review): this exec()s pickled source -- only load trusted pickles.
        module_name = "_tflib_network_import_" + uuid.uuid4().hex
        module = types.ModuleType(module_name)
        sys.modules[module_name] = module
        _import_module_src[module] = build_module_src
        exec(build_module_src, module.__dict__) # pylint: disable=exec-used
        exec  # (comment anchor) build_func is resolved from the freshly exec'd module
        build_func = util.get_obj_from_module(module, build_func_name)

        # Initialize fields. Variable values become deferred inits; shape/name
        # fields fall back to None for states older than the current version.
        self._init_fields(name=name, static_kwargs=static_kwargs, build_func=build_func, build_func_name=build_func_name, build_module_src=build_module_src)
        self._var_inits.update(copy.deepcopy(state["variables"]))
        self._all_inits_known   = True
        self._components        = util.EasyDict(state.get("components", {}))
        self._input_shapes      = copy.deepcopy(state.get("input_shapes", None))
        self._output_shapes     = copy.deepcopy(state.get("output_shapes", None))
        self._input_names       = copy.deepcopy(state.get("input_names", None))
        self._output_names      = copy.deepcopy(state.get("output_names", None))
         
     | 
| 451 | 
         
            +
             
     | 
| 452 | 
         
            +
                def clone(self, name: str = None, **new_static_kwargs) -> "Network":
         
     | 
| 453 | 
         
            +
                    """Create a clone of this network with its own copy of the variables."""
         
     | 
| 454 | 
         
            +
                    static_kwargs = dict(self.static_kwargs)
         
     | 
| 455 | 
         
            +
                    static_kwargs.update(new_static_kwargs)
         
     | 
| 456 | 
         
            +
                    net = object.__new__(Network)
         
     | 
| 457 | 
         
            +
                    net._init_fields(name=(name or self.name), static_kwargs=static_kwargs, build_func=self._build_func, build_func_name=self._build_func_name, build_module_src=self._build_module_src)
         
     | 
| 458 | 
         
            +
                    net.copy_vars_from(self)
         
     | 
| 459 | 
         
            +
                    return net
         
     | 
| 460 | 
         
            +
             
     | 
| 461 | 
         
            +
    def copy_own_vars_from(self, src_net: "Network") -> None:
        """Copy the values of all variables from the given network, excluding sub-networks.

        Handles every combination of initialization state: either network may
        still hold deferred initial values (`_var_inits`) instead of live TF
        variables, in which case values are staged rather than assigned.
        """

        # Source has unknown variables or unknown components => init now.
        if (src_net._var_inits is not None and not src_net._all_inits_known) or src_net._components is None:
            src_net._get_vars()

        # Both networks are inited => copy directly via a single run + assign.
        if src_net._var_inits is None and self._var_inits is None:
            # Only copy variables that exist in both networks.
            names = [name for name in self._get_own_vars().keys() if name in src_net._get_own_vars()]
            tfutil.set_vars(tfutil.run({self._get_vars()[name]: src_net._get_vars()[name] for name in names}))
            return

        # Read from source: live variable values if inited, pending inits otherwise.
        if src_net._var_inits is None:
            value_dict = tfutil.run(src_net._get_own_vars())
        else:
            value_dict = src_net._var_inits

        # Write to destination: assign if inited, otherwise stage as inits.
        if self._var_inits is None:
            tfutil.set_vars({self._get_vars()[name]: value for name, value in value_dict.items() if name in self._get_vars()})
        else:
            self._var_inits.update(value_dict)
         
     | 
| 485 | 
         
            +
             
     | 
| 486 | 
         
            +
    def copy_vars_from(self, src_net: "Network") -> None:
        """Copy the values of all variables from the given network, including sub-networks.

        Like copy_own_vars_from(), but recurses into matching components so the
        full variable tree is transferred.
        """

        # Source has unknown variables or unknown components => init now.
        if (src_net._var_inits is not None and not src_net._all_inits_known) or src_net._components is None:
            src_net._get_vars()

        # Source is inited, but destination components have not been created yet => set as initial values.
        if src_net._var_inits is None and self._components is None:
            self._var_inits.update(tfutil.run(src_net._get_vars()))
            return

        # Destination has unknown components => init now.
        if self._components is None:
            self._get_vars()

        # Both networks are inited => copy all shared variables in one pass.
        if src_net._var_inits is None and self._var_inits is None:
            names = [name for name in self._get_vars().keys() if name in src_net._get_vars()]
            tfutil.set_vars(tfutil.run({self._get_vars()[name]: src_net._get_vars()[name] for name in names}))
            return

        # Copy recursively, component by component (own vars first, then any
        # sub-network that exists under the same name in both networks).
        self.copy_own_vars_from(src_net)
        for name, src_comp in src_net._components.items():
            if name in self._components:
                self._components[name].copy_vars_from(src_comp)
         
     | 
| 513 | 
         
            +
             
     | 
| 514 | 
         
            +
                def copy_trainables_from(self, src_net: "Network") -> None:
         
     | 
| 515 | 
         
            +
                    """Copy the values of all trainable variables from the given network, including sub-networks."""
         
     | 
| 516 | 
         
            +
                    names = [name for name in self._get_trainables().keys() if name in src_net._get_trainables()]
         
     | 
| 517 | 
         
            +
                    tfutil.set_vars(tfutil.run({self._get_vars()[name]: src_net._get_vars()[name] for name in names}))
         
     | 
| 518 | 
         
            +
             
     | 
| 519 | 
         
            +
                def copy_compatible_trainables_from(self, src_net: "Network") -> None:
         
     | 
| 520 | 
         
            +
                    """Copy the compatible values of all trainable variables from the given network, including sub-networks"""
         
     | 
| 521 | 
         
            +
                    names = []
         
     | 
| 522 | 
         
            +
                    for name in self.trainables.keys():
         
     | 
| 523 | 
         
            +
                        if name not in src_net.trainables:
         
     | 
| 524 | 
         
            +
                            print("Not restoring (not present):     {}".format(name))
         
     | 
| 525 | 
         
            +
                        elif self.trainables[name].shape != src_net.trainables[name].shape:
         
     | 
| 526 | 
         
            +
                            print("Not restoring (different shape): {}".format(name))
         
     | 
| 527 | 
         
            +
             
     | 
| 528 | 
         
            +
                        if name in src_net.trainables and self.trainables[name].shape == src_net.trainables[name].shape:
         
     | 
| 529 | 
         
            +
                            names.append(name)
         
     | 
| 530 | 
         
            +
             
     | 
| 531 | 
         
            +
                    tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names}))
         
     | 
| 532 | 
         
            +
             
     | 
| 533 | 
         
            +
    def apply_swa(self, src_net, epoch):
        """Perform stochastic weight averaging on the compatible values of all trainable variables from the given network, including sub-networks.

        Only trainables present in both networks with identical shapes are
        blended; mismatches are reported and skipped.
        """
        names = []
        for name in self.trainables.keys():
            if name not in src_net.trainables:
                print("Not restoring (not present):     {}".format(name))
            elif self.trainables[name].shape != src_net.trainables[name].shape:
                print("Not restoring (different shape): {}".format(name))

            if name in src_net.trainables and self.trainables[name].shape == src_net.trainables[name].shape:
                names.append(name)

        # Blend the two weight sets.
        # NOTE(review): with these factors the incoming src_net weights receive
        # weight epoch/(epoch+1) and the running average only 1/(epoch+1) --
        # the reverse of the textbook SWA running-mean update. Confirm the
        # caller's definition of `epoch` before trusting the variable names.
        scale_new_data = 1.0 - 1.0 / (epoch + 1)
        scale_moving_average = (1.0 - scale_new_data)
        tfutil.set_vars(tfutil.run({self.vars[name]: (src_net.vars[name] * scale_new_data + self.vars[name] * scale_moving_average) for name in names}))
         
     | 
| 548 | 
         
            +
             
     | 
| 549 | 
         
            +
                def convert(self, new_func_name: str, new_name: str = None, **new_static_kwargs) -> "Network":
         
     | 
| 550 | 
         
            +
                    """Create new network with the given parameters, and copy all variables from this network."""
         
     | 
| 551 | 
         
            +
                    if new_name is None:
         
     | 
| 552 | 
         
            +
                        new_name = self.name
         
     | 
| 553 | 
         
            +
                    static_kwargs = dict(self.static_kwargs)
         
     | 
| 554 | 
         
            +
                    static_kwargs.update(new_static_kwargs)
         
     | 
| 555 | 
         
            +
                    net = Network(name=new_name, func_name=new_func_name, **static_kwargs)
         
     | 
| 556 | 
         
            +
                    net.copy_vars_from(self)
         
     | 
| 557 | 
         
            +
                    return net
         
     | 
| 558 | 
         
            +
             
     | 
| 559 | 
         
            +
                def setup_as_moving_average_of(self, src_net: "Network", beta: TfExpressionEx = 0.99, beta_nontrainable: TfExpressionEx = 0.0) -> tf.Operation:
         
     | 
| 560 | 
         
            +
                    """Construct a TensorFlow op that updates the variables of this network
         
     | 
| 561 | 
         
            +
                    to be slightly closer to those of the given network."""
         
     | 
| 562 | 
         
            +
                    with tfutil.absolute_name_scope(self.scope + "/_MovingAvg"):
         
     | 
| 563 | 
         
            +
                        ops = []
         
     | 
| 564 | 
         
            +
                        for name, var in self._get_vars().items():
         
     | 
| 565 | 
         
            +
                            if name in src_net._get_vars():
         
     | 
| 566 | 
         
            +
                                cur_beta = beta if var.trainable else beta_nontrainable
         
     | 
| 567 | 
         
            +
                                new_value = tfutil.lerp(src_net._get_vars()[name], var, cur_beta)
         
     | 
| 568 | 
         
            +
                                ops.append(var.assign(new_value))
         
     | 
| 569 | 
         
            +
                        return tf.group(*ops)
         
     | 
| 570 | 
         
            +
             
     | 
| 571 | 
         
            +
                def update_epochs(self, epochs: TfExpressionEx = 0) -> tf.Operation:
         
     | 
| 572 | 
         
            +
                    """Construct a TensorFlow op that updates the epoch counter of this network."""
         
     | 
| 573 | 
         
            +
                    with tfutil.absolute_name_scope(self.scope + "/_Epochs"):
         
     | 
| 574 | 
         
            +
                        op = self.epochs.assign(epochs)
         
     | 
| 575 | 
         
            +
                        return op
         
     | 
| 576 | 
         
            +
             
     | 
| 577 | 
         
            +
    def run(self,
            *in_arrays: Tuple[Union[np.ndarray, None], ...],
            input_transform: dict = None,
            output_transform: dict = None,
            return_as_list: bool = False,
            print_progress: bool = False,
            minibatch_size: int = None,
            num_gpus: int = 1,
            assume_frozen: bool = False,
            custom_inputs: Any = None,
            **dynamic_kwargs) -> Union[np.ndarray, Tuple[np.ndarray, ...], List[np.ndarray]]:
        """Run this network for the given NumPy array(s), and return the output(s) as NumPy array(s).

        Args:
            input_transform:    A dict specifying a custom transformation to be applied to the input tensor(s) before evaluating the network.
                                The dict must contain a 'func' field that points to a top-level function. The function is called with the input
                                TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
            output_transform:   A dict specifying a custom transformation to be applied to the output tensor(s) after evaluating the network.
                                The dict must contain a 'func' field that points to a top-level function. The function is called with the output
                                TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
            return_as_list:     True = return a list of NumPy arrays, False = return a single NumPy array, or a tuple if there are multiple outputs.
            print_progress:     Print progress to the console? Useful for very large input arrays.
            minibatch_size:     Maximum minibatch size to use, None = disable batching.
            num_gpus:           Number of GPUs to use.
            assume_frozen:      Improve multi-GPU performance by assuming that the trainable parameters will remain unchanged between calls.
            custom_inputs:      Allow to use another tensor as input instead of default placeholders.
            dynamic_kwargs:     Additional keyword arguments to be passed into the network build function.
        """
        assert len(in_arrays) == self.num_inputs
        assert not all(arr is None for arr in in_arrays)
        assert input_transform is None or util.is_top_level_function(input_transform["func"])
        assert output_transform is None or util.is_top_level_function(output_transform["func"])
        output_transform, dynamic_kwargs = _handle_legacy_output_transforms(output_transform, dynamic_kwargs)
        num_items = in_arrays[0].shape[0]
        if minibatch_size is None:
            minibatch_size = num_items

        # Construct unique hash key from all arguments that affect the TensorFlow graph.
        # Callables are keyed by their top-level function name so the key is repr()-able.
        key = dict(input_transform=input_transform, output_transform=output_transform, num_gpus=num_gpus, assume_frozen=assume_frozen, dynamic_kwargs=dynamic_kwargs)
        def unwind_key(obj):
            # Recursively convert the key dict into a deterministic, hashable repr.
            if isinstance(obj, dict):
                return [(key, unwind_key(value)) for key, value in sorted(obj.items())]
            if callable(obj):
                return util.get_top_level_function_name(obj)
            return obj
        key = repr(unwind_key(key))

        # Build graph once per unique key; subsequent calls reuse the cached expressions.
        if key not in self._run_cache:
            with tfutil.absolute_name_scope(self.scope + "/_Run"), tf.control_dependencies(None):
                if custom_inputs is not None:
                    # Caller-supplied input builders replace the default placeholders.
                    with tf.device("/gpu:0"):
                        in_expr = [input_builder(name) for input_builder, name in zip(custom_inputs, self.input_names)]
                        in_split = list(zip(*[tf.split(x, num_gpus) for x in in_expr]))
                else:
                    with tf.device("/cpu:0"):
                        in_expr = [tf.placeholder(tf.float32, name=name) for name in self.input_names]
                        # Split each input along the batch axis, one chunk per GPU.
                        in_split = list(zip(*[tf.split(x, num_gpus) for x in in_expr]))

                out_split = []
                for gpu in range(num_gpus):
                    with tf.device(self.device if num_gpus == 1 else "/gpu:%d" % gpu):
                        # assume_frozen clones the network so each GPU gets its own copy.
                        net_gpu = self.clone() if assume_frozen else self
                        in_gpu = in_split[gpu]

                        if input_transform is not None:
                            in_kwargs = dict(input_transform)
                            in_gpu = in_kwargs.pop("func")(*in_gpu, **in_kwargs)
                            in_gpu = [in_gpu] if tfutil.is_tf_expression(in_gpu) else list(in_gpu)

                        assert len(in_gpu) == self.num_inputs
                        out_gpu = net_gpu.get_output_for(*in_gpu, return_as_list=True, **dynamic_kwargs)

                        if output_transform is not None:
                            out_kwargs = dict(output_transform)
                            out_gpu = out_kwargs.pop("func")(*out_gpu, **out_kwargs)
                            out_gpu = [out_gpu] if tfutil.is_tf_expression(out_gpu) else list(out_gpu)

                        assert len(out_gpu) == self.num_outputs
                        out_split.append(out_gpu)

                with tf.device("/cpu:0"):
                    # Re-assemble the per-GPU output chunks along the batch axis.
                    out_expr = [tf.concat(outputs, axis=0) for outputs in zip(*out_split)]
                    self._run_cache[key] = in_expr, out_expr

        # Run minibatches. Output arrays are preallocated from the static shapes.
        in_expr, out_expr = self._run_cache[key]
        out_arrays = [np.empty([num_items] + expr.shape.as_list()[1:], expr.dtype.name) for expr in out_expr]

        for mb_begin in range(0, num_items, minibatch_size):
            if print_progress:
                print("\r%d / %d" % (mb_begin, num_items), end="")

            mb_end = min(mb_begin + minibatch_size, num_items)
            mb_num = mb_end - mb_begin
            # None inputs are fed as zeros of the declared input shape.
            mb_in = [src[mb_begin : mb_end] if src is not None else np.zeros([mb_num] + shape[1:]) for src, shape in zip(in_arrays, self.input_shapes)]
            mb_out = tf.get_default_session().run(out_expr, dict(zip(in_expr, mb_in)))

            for dst, src in zip(out_arrays, mb_out):
                dst[mb_begin: mb_end] = src

        # Done.
        if print_progress:
            print("\r%d / %d" % (num_items, num_items))

        if not return_as_list:
            out_arrays = out_arrays[0] if len(out_arrays) == 1 else tuple(out_arrays)
        return out_arrays
         
     | 
| 685 | 
         
            +
             
     | 
| 686 | 
         
            +
                def list_ops(self) -> List[TfExpression]:
         
     | 
| 687 | 
         
            +
                    _ = self.output_templates  # ensure that the template graph has been created
         
     | 
| 688 | 
         
            +
                    include_prefix = self.scope + "/"
         
     | 
| 689 | 
         
            +
                    exclude_prefix = include_prefix + "_"
         
     | 
| 690 | 
         
            +
                    ops = tf.get_default_graph().get_operations()
         
     | 
| 691 | 
         
            +
                    ops = [op for op in ops if op.name.startswith(include_prefix)]
         
     | 
| 692 | 
         
            +
                    ops = [op for op in ops if not op.name.startswith(exclude_prefix)]
         
     | 
| 693 | 
         
            +
                    return ops
         
     | 
| 694 | 
         
            +
             
     | 
| 695 | 
         
            +
    def list_layers(self) -> List[Tuple[str, TfExpression, List[TfExpression]]]:
        """Returns a list of (layer_name, output_expr, trainable_vars) tuples corresponding to
        individual layers of the network. Mainly intended to be used for reporting.

        Layer boundaries are discovered by recursively walking the name-scope
        hierarchy of the ops from list_ops() together with the network's variables.
        """
        layers = []

        def recurse(scope, parent_ops, parent_vars, level):
            # Nothing left to classify under this scope.
            if len(parent_ops) == 0 and len(parent_vars) == 0:
                return

            # Ignore specific patterns.
            if any(p in scope for p in ["/Shape", "/strided_slice", "/Cast", "/concat", "/Assign"]):
                return

            # Filter ops and vars by scope. Ops carry fully-qualified graph names,
            # while vars are keyed by names relative to the network scope — hence
            # the two different prefixes.
            global_prefix = scope + "/"
            local_prefix = global_prefix[len(self.scope) + 1:]
            cur_ops = [op for op in parent_ops if op.name.startswith(global_prefix) or op.name == global_prefix[:-1]]
            cur_vars = [(name, var) for name, var in parent_vars if name.startswith(local_prefix) or name == local_prefix[:-1]]
            if not cur_ops and not cur_vars:
                return

            # Filter out all ops related to variables (everything nested under a
            # Variable* op's own name scope).
            for var in [op for op in cur_ops if op.type.startswith("Variable")]:
                var_prefix = var.name + "/"
                cur_ops = [op for op in cur_ops if not op.name.startswith(var_prefix)]

            # Scope does not contain ops as immediate children => recurse deeper.
            # Identity/Cast/Transpose are not counted as "real" direct children.
            contains_direct_ops = any("/" not in op.name[len(global_prefix):] and op.type not in ["Identity", "Cast", "Transpose"] for op in cur_ops)
            if (level == 0 or not contains_direct_ops) and (len(cur_ops) != 0 or len(cur_vars) != 0):
                visited = set()  # child scope tokens already recursed into
                for rel_name in [op.name[len(global_prefix):] for op in cur_ops] + [name[len(local_prefix):] for name, _var in cur_vars]:
                    token = rel_name.split("/")[0]
                    if token not in visited:
                        recurse(global_prefix + token, cur_ops, cur_vars, level + 1)
                        visited.add(token)
                return

            # Report layer: the last op's first output serves as the layer's output
            # expression, falling back to the last variable when the scope has no ops.
            layer_name = scope[len(self.scope) + 1:]
            layer_output = cur_ops[-1].outputs[0] if cur_ops else cur_vars[-1][1]
            layer_trainables = [var for _name, var in cur_vars if var.trainable]
            layers.append((layer_name, layer_output, layer_trainables))

        recurse(self.scope, self.list_ops(), list(self._get_vars().items()), 0)
        return layers
         
     | 
| 740 | 
         
            +
                
         
     | 
| 741 | 
         
            +
                def print_layers(self, title: str = None, hide_layers_with_no_params: bool = False) -> None:
         
     | 
| 742 | 
         
            +
                    """Print a summary table of the network structure."""
         
     | 
| 743 | 
         
            +
                    rows = [[title if title is not None else self.name, "Params", "OutputShape", "WeightShape"]]
         
     | 
| 744 | 
         
            +
                    rows += [["---"] * 4]
         
     | 
| 745 | 
         
            +
                    total_params = 0
         
     | 
| 746 | 
         
            +
             
     | 
| 747 | 
         
            +
                    for layer_name, layer_output, layer_trainables in self.list_layers():
         
     | 
| 748 | 
         
            +
                        num_params = sum(int(np.prod(var.shape.as_list())) for var in layer_trainables)
         
     | 
| 749 | 
         
            +
                        weights = [var for var in layer_trainables if var.name.endswith("/weight:0")]
         
     | 
| 750 | 
         
            +
                        weights.sort(key=lambda x: len(x.name))
         
     | 
| 751 | 
         
            +
                        if len(weights) == 0 and len(layer_trainables) == 1:
         
     | 
| 752 | 
         
            +
                            weights = layer_trainables
         
     | 
| 753 | 
         
            +
                        total_params += num_params
         
     | 
| 754 | 
         
            +
             
     | 
| 755 | 
         
            +
                        if not hide_layers_with_no_params or num_params != 0:
         
     | 
| 756 | 
         
            +
                            num_params_str = str(num_params) if num_params > 0 else "-"
         
     | 
| 757 | 
         
            +
                            output_shape_str = str(layer_output.shape)
         
     | 
| 758 | 
         
            +
                            weight_shape_str = str(weights[0].shape) if len(weights) >= 1 else "-"
         
     | 
| 759 | 
         
            +
                            rows += [[layer_name, num_params_str, output_shape_str, weight_shape_str]]
         
     | 
| 760 | 
         
            +
             
     | 
| 761 | 
         
            +
                    rows += [["---"] * 4]
         
     | 
| 762 | 
         
            +
                    rows += [["Total", str(total_params), "", ""]]
         
     | 
| 763 | 
         
            +
             
     | 
| 764 | 
         
            +
                    widths = [max(len(cell) for cell in column) for column in zip(*rows)]
         
     | 
| 765 | 
         
            +
                    print()
         
     | 
| 766 | 
         
            +
                    for row in rows:
         
     | 
| 767 | 
         
            +
                        print("  ".join(cell + " " * (width - len(cell)) for cell, width in zip(row, widths)))
         
     | 
| 768 | 
         
            +
                    print()
         
     | 
| 769 | 
         
            +
             
     | 
| 770 | 
         
            +
                def setup_weight_histograms(self, title: str = None) -> None:
         
     | 
| 771 | 
         
            +
                    """Construct summary ops to include histograms of all trainable parameters in TensorBoard."""
         
     | 
| 772 | 
         
            +
                    if title is None:
         
     | 
| 773 | 
         
            +
                        title = self.name
         
     | 
| 774 | 
         
            +
             
     | 
| 775 | 
         
            +
                    with tf.name_scope(None), tf.device(None), tf.control_dependencies(None):
         
     | 
| 776 | 
         
            +
                        for local_name, var in self._get_trainables().items():
         
     | 
| 777 | 
         
            +
                            if "/" in local_name:
         
     | 
| 778 | 
         
            +
                                p = local_name.split("/")
         
     | 
| 779 | 
         
            +
                                name = title + "_" + p[-1] + "/" + "_".join(p[:-1])
         
     | 
| 780 | 
         
            +
                            else:
         
     | 
| 781 | 
         
            +
                                name = title + "_toplevel/" + local_name
         
     | 
| 782 | 
         
            +
             
     | 
| 783 | 
         
            +
                            tf.summary.histogram(name, var)
         
     | 
| 784 | 
         
            +
             
     | 
| 785 | 
         
            +
            #----------------------------------------------------------------------------
         
     | 
| 786 | 
         
            +
            # Backwards-compatible emulation of legacy output transformation in Network.run().
         
     | 
| 787 | 
         
            +
             
     | 
| 788 | 
         
            +
            _print_legacy_warning = True
         
     | 
| 789 | 
         
            +
             
     | 
| 790 | 
         
            +
            def _handle_legacy_output_transforms(output_transform, dynamic_kwargs):
         
     | 
| 791 | 
         
            +
                global _print_legacy_warning
         
     | 
| 792 | 
         
            +
                legacy_kwargs = ["out_mul", "out_add", "out_shrink", "out_dtype"]
         
     | 
| 793 | 
         
            +
                if not any(kwarg in dynamic_kwargs for kwarg in legacy_kwargs):
         
     | 
| 794 | 
         
            +
                    return output_transform, dynamic_kwargs
         
     | 
| 795 | 
         
            +
             
     | 
| 796 | 
         
            +
                if _print_legacy_warning:
         
     | 
| 797 | 
         
            +
                    _print_legacy_warning = False
         
     | 
| 798 | 
         
            +
                    print()
         
     | 
| 799 | 
         
            +
                    print("WARNING: Old-style output transformations in Network.run() are deprecated.")
         
     | 
| 800 | 
         
            +
                    print("Consider using 'output_transform=dict(func=tflib.convert_images_to_uint8)'")
         
     | 
| 801 | 
         
            +
                    print("instead of 'out_mul=127.5, out_add=127.5, out_dtype=np.uint8'.")
         
     | 
| 802 | 
         
            +
                    print()
         
     | 
| 803 | 
         
            +
                assert output_transform is None
         
     | 
| 804 | 
         
            +
             
     | 
| 805 | 
         
            +
                new_kwargs = dict(dynamic_kwargs)
         
     | 
| 806 | 
         
            +
                new_transform = {kwarg: new_kwargs.pop(kwarg) for kwarg in legacy_kwargs if kwarg in dynamic_kwargs}
         
     | 
| 807 | 
         
            +
                new_transform["func"] = _legacy_output_transform_func
         
     | 
| 808 | 
         
            +
                return new_transform, new_kwargs
         
     | 
| 809 | 
         
            +
             
     | 
| 810 | 
         
            +
            def _legacy_output_transform_func(*expr, out_mul=1.0, out_add=0.0, out_shrink=1, out_dtype=None):
         
     | 
| 811 | 
         
            +
                if out_mul != 1.0:
         
     | 
| 812 | 
         
            +
                    expr = [x * out_mul for x in expr]
         
     | 
| 813 | 
         
            +
             
     | 
| 814 | 
         
            +
                if out_add != 0.0:
         
     | 
| 815 | 
         
            +
                    expr = [x + out_add for x in expr]
         
     | 
| 816 | 
         
            +
             
     | 
| 817 | 
         
            +
                if out_shrink > 1:
         
     | 
| 818 | 
         
            +
                    ksize = [1, 1, out_shrink, out_shrink]
         
     | 
| 819 | 
         
            +
                    expr = [tf.nn.avg_pool(x, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW") for x in expr]
         
     | 
| 820 | 
         
            +
             
     | 
| 821 | 
         
            +
                if out_dtype is not None:
         
     | 
| 822 | 
         
            +
                    if tf.as_dtype(out_dtype).is_integer:
         
     | 
| 823 | 
         
            +
                        expr = [tf.round(x) for x in expr]
         
     | 
| 824 | 
         
            +
                    expr = [tf.saturate_cast(x, out_dtype) for x in expr]
         
     | 
| 825 | 
         
            +
                return expr
         
     | 
    	
        dnnlib/tflib/ops/__init__.py
    ADDED
    
    | 
         @@ -0,0 +1,9 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            # Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
         
     | 
| 2 | 
         
            +
            #
         
     | 
| 3 | 
         
            +
            # NVIDIA CORPORATION and its licensors retain all intellectual property
         
     | 
| 4 | 
         
            +
            # and proprietary rights in and to this software, related documentation
         
     | 
| 5 | 
         
            +
            # and any modifications thereto.  Any use, reproduction, disclosure or
         
     | 
| 6 | 
         
            +
            # distribution of this software and related documentation without an express
         
     | 
| 7 | 
         
            +
            # license agreement from NVIDIA CORPORATION is strictly prohibited.
         
     | 
| 8 | 
         
            +
             
     | 
| 9 | 
         
            +
            # empty
         
     | 
    	
        dnnlib/tflib/ops/fused_bias_act.cu
    ADDED
    
    | 
         @@ -0,0 +1,220 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            // Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
         
     | 
| 2 | 
         
            +
            //
         
     | 
| 3 | 
         
            +
            // NVIDIA CORPORATION and its licensors retain all intellectual property
         
     | 
| 4 | 
         
            +
            // and proprietary rights in and to this software, related documentation
         
     | 
| 5 | 
         
            +
            // and any modifications thereto.  Any use, reproduction, disclosure or
         
     | 
| 6 | 
         
            +
            // distribution of this software and related documentation without an express
         
     | 
| 7 | 
         
            +
            // license agreement from NVIDIA CORPORATION is strictly prohibited.
         
     | 
| 8 | 
         
            +
             
     | 
| 9 | 
         
            +
            #define EIGEN_USE_GPU
         
     | 
| 10 | 
         
            +
            #define __CUDA_INCLUDE_COMPILER_INTERNAL_HEADERS__
         
     | 
| 11 | 
         
            +
            #include "tensorflow/core/framework/op.h"
         
     | 
| 12 | 
         
            +
            #include "tensorflow/core/framework/op_kernel.h"
         
     | 
| 13 | 
         
            +
            #include "tensorflow/core/framework/shape_inference.h"
         
     | 
| 14 | 
         
            +
            #include <stdio.h>
         
     | 
| 15 | 
         
            +
             
     | 
| 16 | 
         
            +
            using namespace tensorflow;
         
     | 
| 17 | 
         
            +
            using namespace tensorflow::shape_inference;
         
     | 
| 18 | 
         
            +
             
     | 
| 19 | 
         
            +
// Evaluate CUDA_CALL; if it fails, abort the op via OP_REQUIRES with an Internal
// error carrying the CUDA error name. do/while(false) makes the macro statement-safe.
#define OP_CHECK_CUDA_ERROR(CTX, CUDA_CALL) do { cudaError_t err = CUDA_CALL; OP_REQUIRES(CTX, err == cudaSuccess, errors::Internal(cudaGetErrorName(err))); } while (false)
         
     | 
| 20 | 
         
            +
             
     | 
| 21 | 
         
            +
            //------------------------------------------------------------------------
         
     | 
| 22 | 
         
            +
            // CUDA kernel.
         
     | 
| 23 | 
         
            +
             
     | 
| 24 | 
         
            +
// Parameter bundle passed by value to FusedBiasActKernel.
template <class T>
struct FusedBiasActKernelParams
{
    const T*    x;      // [sizeX] input tensor
    const T*    b;      // [sizeB] or NULL; bias, indexed as b[(xi / stepB) % sizeB]
    const T*    xref;   // [sizeX] or NULL; reference input used by gradient variants
    const T*    yref;   // [sizeX] or NULL; reference output used by gradient variants
    T*          y;      // [sizeX] output tensor

    int         grad;   // gradient order; combined with 'act' as (act * 10 + grad) in the kernel switch
    int         axis;   // NOTE(review): not referenced in the visible kernel portion -- presumably the bias axis; confirm
    int         act;    // activation selector: 1=linear, 2=relu, 3=lrelu, 4=tanh, 5=sigmoid, 6=elu, 7=selu, 8=softplus, 9=swish
    float       alpha;  // activation parameter (negative slope for lrelu)
    float       gain;   // output gain; when nonzero the kernel reconstructs yy = yref / gain
    float       clamp;  // NOTE(review): not referenced in the visible kernel portion -- presumably output clamping; confirm

    int         sizeX;  // total element count of x / y
    int         sizeB;  // element count of the bias vector b
    int         stepB;  // stride between consecutive applications of the same bias element
    int         loopX;  // maximum number of elements processed per thread
};
         
     | 
| 45 | 
         
            +
             
     | 
| 46 | 
         
            +
            template <class T>
         
     | 
| 47 | 
         
            +
            static __global__ void FusedBiasActKernel(const FusedBiasActKernelParams<T> p)
         
     | 
| 48 | 
         
            +
            {
         
     | 
| 49 | 
         
            +
                const float expRange        = 80.0f;
         
     | 
| 50 | 
         
            +
                const float halfExpRange    = 40.0f;
         
     | 
| 51 | 
         
            +
                const float seluScale       = 1.0507009873554804934193349852946f;
         
     | 
| 52 | 
         
            +
                const float seluAlpha       = 1.6732632423543772848170429916717f;
         
     | 
| 53 | 
         
            +
             
     | 
| 54 | 
         
            +
                // Loop over elements.
         
     | 
| 55 | 
         
            +
                int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
         
     | 
| 56 | 
         
            +
                for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
         
     | 
| 57 | 
         
            +
                {
         
     | 
| 58 | 
         
            +
                    // Load and apply bias.
         
     | 
| 59 | 
         
            +
                    float x = (float)p.x[xi];
         
     | 
| 60 | 
         
            +
                    if (p.b)
         
     | 
| 61 | 
         
            +
                        x += (float)p.b[(xi / p.stepB) % p.sizeB];
         
     | 
| 62 | 
         
            +
                    float xref = (p.xref) ? (float)p.xref[xi] : 0.0f;
         
     | 
| 63 | 
         
            +
                    float yref = (p.yref) ? (float)p.yref[xi] : 0.0f;
         
     | 
| 64 | 
         
            +
                    float yy = (p.gain != 0.0f) ? yref / p.gain : 0.0f;
         
     | 
| 65 | 
         
            +
             
     | 
| 66 | 
         
            +
                    // Evaluate activation func.
         
     | 
| 67 | 
         
            +
                    float y;
         
     | 
| 68 | 
         
            +
                    switch (p.act * 10 + p.grad)
         
     | 
| 69 | 
         
            +
                    {
         
     | 
| 70 | 
         
            +
                        // linear
         
     | 
| 71 | 
         
            +
                        default:
         
     | 
| 72 | 
         
            +
                        case 10: y = x; break;
         
     | 
| 73 | 
         
            +
                        case 11: y = x; break;
         
     | 
| 74 | 
         
            +
                        case 12: y = 0.0f; break;
         
     | 
| 75 | 
         
            +
             
     | 
| 76 | 
         
            +
                        // relu
         
     | 
| 77 | 
         
            +
                        case 20: y = (x > 0.0f) ? x : 0.0f; break;
         
     | 
| 78 | 
         
            +
                        case 21: y = (yy > 0.0f) ? x : 0.0f; break;
         
     | 
| 79 | 
         
            +
                        case 22: y = 0.0f; break;
         
     | 
| 80 | 
         
            +
             
     | 
| 81 | 
         
            +
                        // lrelu
         
     | 
| 82 | 
         
            +
                        case 30: y = (x > 0.0f) ? x : x * p.alpha; break;
         
     | 
| 83 | 
         
            +
                        case 31: y = (yy > 0.0f) ? x : x * p.alpha; break;
         
     | 
| 84 | 
         
            +
                        case 32: y = 0.0f; break;
         
     | 
| 85 | 
         
            +
             
     | 
| 86 | 
         
            +
                        // tanh
         
     | 
| 87 | 
         
            +
                        case 40: { float c = expf(x); float d = 1.0f / c; y = (x < -expRange) ? -1.0f : (x > expRange) ? 1.0f : (c - d) / (c + d); } break;
         
     | 
| 88 | 
         
            +
                        case 41: y = x * (1.0f - yy * yy); break;
         
     | 
| 89 | 
         
            +
                        case 42: y = x * (1.0f - yy * yy) * (-2.0f * yy); break;
         
     | 
| 90 | 
         
            +
             
     | 
| 91 | 
         
            +
                        // sigmoid
         
     | 
| 92 | 
         
            +
                        case 50: y = (x < -expRange) ? 0.0f : 1.0f / (expf(-x) + 1.0f); break;
         
     | 
| 93 | 
         
            +
                        case 51: y = x * yy * (1.0f - yy); break;
         
     | 
| 94 | 
         
            +
                        case 52: y = x * yy * (1.0f - yy) * (1.0f - 2.0f * yy); break;
         
     | 
| 95 | 
         
            +
             
     | 
| 96 | 
         
            +
                        // elu
         
     | 
| 97 | 
         
            +
                        case 60: y = (x >= 0.0f) ? x : expf(x) - 1.0f; break;
         
     | 
| 98 | 
         
            +
                        case 61: y = (yy >= 0.0f) ? x : x * (yy + 1.0f); break;
         
     | 
| 99 | 
         
            +
                        case 62: y = (yy >= 0.0f) ? 0.0f : x * (yy + 1.0f); break;
         
     | 
| 100 | 
         
            +
             
     | 
| 101 | 
         
            +
                        // selu
         
     | 
| 102 | 
         
            +
                        case 70: y = (x >= 0.0f) ? seluScale * x : (seluScale * seluAlpha) * (expf(x) - 1.0f); break;
         
     | 
| 103 | 
         
            +
                        case 71: y = (yy >= 0.0f) ? x * seluScale : x * (yy + seluScale * seluAlpha); break;
         
     | 
| 104 | 
         
            +
                        case 72: y = (yy >= 0.0f) ? 0.0f : x * (yy + seluScale * seluAlpha); break;
         
     | 
| 105 | 
         
            +
             
     | 
| 106 | 
         
            +
                        // softplus
         
     | 
| 107 | 
         
            +
                        case 80: y = (x > expRange) ? x : logf(expf(x) + 1.0f); break;
         
     | 
| 108 | 
         
            +
                        case 81: y = x * (1.0f - expf(-yy)); break;
         
     | 
| 109 | 
         
            +
                        case 82: { float c = expf(-yy); y = x * c * (1.0f - c); } break;
         
     | 
| 110 | 
         
            +
             
     | 
| 111 | 
         
            +
                        // swish
         
     | 
| 112 | 
         
            +
                        case 90: y = (x < -expRange) ? 0.0f : x / (expf(-x) + 1.0f); break;
         
     | 
| 113 | 
         
            +
                        case 91:
         
     | 
| 114 | 
         
            +
                        case 92:
         
     | 
| 115 | 
         
            +
                            {
         
     | 
| 116 | 
         
            +
                                float c = expf(xref);
         
     | 
| 117 | 
         
            +
                                float d = c + 1.0f;
         
     | 
| 118 | 
         
            +
                                if (p.grad == 1)
         
     | 
| 119 | 
         
            +
                                    y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
         
     | 
| 120 | 
         
            +
                                else
         
     | 
| 121 | 
         
            +
                                    y = (xref > halfExpRange) ? 0.0f : x * c * (xref * (2.0f - d) + 2.0f * d) / (d * d * d);
         
     | 
| 122 | 
         
            +
                                yref = (xref < -expRange) ? 0.0f : xref / (expf(-xref) + 1.0f) * p.gain;
         
     | 
| 123 | 
         
            +
                            }
         
     | 
| 124 | 
         
            +
                            break;
         
     | 
| 125 | 
         
            +
                    }
         
     | 
| 126 | 
         
            +
             
     | 
| 127 | 
         
            +
                    // Apply gain.
         
     | 
| 128 | 
         
            +
                    y *= p.gain;
         
     | 
| 129 | 
         
            +
             
     | 
| 130 | 
         
            +
                    // Clamp.
         
     | 
| 131 | 
         
            +
                    if (p.clamp >= 0.0f)
         
     | 
| 132 | 
         
            +
                    {
         
     | 
| 133 | 
         
            +
                        if (p.grad == 0)
         
     | 
| 134 | 
         
            +
                            y = (fabsf(y) < p.clamp) ? y : (y >= 0.0f) ? p.clamp : -p.clamp;
         
     | 
| 135 | 
         
            +
                        else
         
     | 
| 136 | 
         
            +
                            y = (fabsf(yref) < p.clamp) ? y : 0.0f;
         
     | 
| 137 | 
         
            +
                    }
         
     | 
| 138 | 
         
            +
             
     | 
| 139 | 
         
            +
                    // Store.
         
     | 
| 140 | 
         
            +
                    p.y[xi] = (T)y;
         
     | 
| 141 | 
         
            +
                }
         
     | 
| 142 | 
         
            +
            }
         
     | 
| 143 | 
         
            +
             
     | 
| 144 | 
         
            +
            //------------------------------------------------------------------------
         
     | 
| 145 | 
         
            +
            // TensorFlow op.
         
     | 
| 146 | 
         
            +
             
     | 
| 147 | 
         
            +
// TensorFlow OpKernel wrapping FusedBiasActKernel<T>: reads the op attributes
// once at construction, then validates inputs and launches the CUDA kernel on
// every Compute() call.
template <class T>
struct FusedBiasActOp : public OpKernel
{
    FusedBiasActKernelParams<T> m_attribs; // attribute values shared by every launch of this op instance

    FusedBiasActOp(OpKernelConstruction* ctx) : OpKernel(ctx)
    {
        // Zero-initialize so fields not populated from attributes stay well-defined.
        memset(&m_attribs, 0, sizeof(m_attribs));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("grad",    &m_attribs.grad));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("axis",    &m_attribs.axis));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("act",     &m_attribs.act));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha",   &m_attribs.alpha));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("gain",    &m_attribs.gain));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("clamp",   &m_attribs.clamp));
        // Reject obviously invalid attribute values at graph-construction time.
        OP_REQUIRES(ctx, m_attribs.grad >= 0, errors::InvalidArgument("grad must be non-negative"));
        OP_REQUIRES(ctx, m_attribs.axis >= 0, errors::InvalidArgument("axis must be non-negative"));
        OP_REQUIRES(ctx, m_attribs.act >= 0, errors::InvalidArgument("act must be non-negative"));
    }

    void Compute(OpKernelContext* ctx)
    {
        // Per-call copy of the attribute params; tensor pointers are filled in below.
        FusedBiasActKernelParams<T> p = m_attribs;
        cudaStream_t stream = ctx->eigen_device<Eigen::GpuDevice>().stream();

        const Tensor& x     = ctx->input(0); // [...]
        const Tensor& b     = ctx->input(1); // [sizeB] or [0]
        const Tensor& xref  = ctx->input(2); // x.shape or [0]
        const Tensor& yref  = ctx->input(3); // x.shape or [0]
        p.x = x.flat<T>().data();
        // Optional inputs arrive as empty tensors; map those to NULL for the kernel.
        p.b = (b.NumElements()) ? b.flat<T>().data() : NULL;
        p.xref = (xref.NumElements()) ? xref.flat<T>().data() : NULL;
        p.yref = (yref.NumElements()) ? yref.flat<T>().data() : NULL;
        // Shape validation: b must be rank-1 and, when non-empty, match x along `axis`;
        // xref/yref must be empty or have exactly as many elements as x.
        OP_REQUIRES(ctx, b.NumElements() == 0 || m_attribs.axis < x.dims(), errors::InvalidArgument("axis out of bounds"));
        OP_REQUIRES(ctx, b.dims() == 1, errors::InvalidArgument("b must have rank 1"));
        OP_REQUIRES(ctx, b.NumElements() == 0 || b.NumElements() == x.dim_size(m_attribs.axis), errors::InvalidArgument("b has wrong number of elements"));
        OP_REQUIRES(ctx, xref.NumElements() == 0 || xref.NumElements() == x.NumElements(), errors::InvalidArgument("xref has wrong number of elements"));
        OP_REQUIRES(ctx, yref.NumElements() == 0 || yref.NumElements() == x.NumElements(), errors::InvalidArgument("yref has wrong number of elements"));
        // The kernel params use 32-bit element counts/indices.
        OP_REQUIRES(ctx, x.NumElements() <= kint32max, errors::InvalidArgument("x is too large"));

        p.sizeX = (int)x.NumElements();
        p.sizeB = (int)b.NumElements();
        // stepB = product of the dimensions after `axis`, i.e. the number of
        // consecutive elements of x that share one bias element.
        p.stepB = 1;
        for (int i = m_attribs.axis + 1; i < x.dims(); i++)
            p.stepB *= (int)x.dim_size(i);

        Tensor* y = NULL; // x.shape
        OP_REQUIRES_OK(ctx, ctx->allocate_output(0, x.shape(), &y));
        p.y = y->flat<T>().data();

        // Launch configuration: 128 threads per block; grid sized so that
        // gridSize * blockSize * loopX covers all sizeX elements.
        p.loopX = 4;
        int blockSize = 4 * 32;
        int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
        void* args[] = {&p};
        OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel((void*)FusedBiasActKernel<T>, gridSize, blockSize, args, 0, stream));
    }
};
         
     | 
| 203 | 
         
            +
             
     | 
| 204 | 
         
            +
// Op schema: inputs x (activations), b (bias vector or empty), and xref/yref
// (reference tensors used by the gradient passes, or empty). Output y has
// x's shape. Note clamp defaults to -1.0: the kernel only clamps when
// p.clamp >= 0, so a negative value disables clamping.
REGISTER_OP("FusedBiasAct")
    .Input      ("x: T")
    .Input      ("b: T")
    .Input      ("xref: T")
    .Input      ("yref: T")
    .Output     ("y: T")
    .Attr       ("T: {float, half}")
    .Attr       ("grad: int = 0")
    .Attr       ("axis: int = 1")
    .Attr       ("act: int = 0")
    .Attr       ("alpha: float = 0.0")
    .Attr       ("gain: float = 1.0")
    .Attr       ("clamp: float = -1.0");
// GPU kernel registrations for both supported dtypes.
REGISTER_KERNEL_BUILDER(Name("FusedBiasAct").Device(DEVICE_GPU).TypeConstraint<float>("T"), FusedBiasActOp<float>);
REGISTER_KERNEL_BUILDER(Name("FusedBiasAct").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"), FusedBiasActOp<Eigen::half>);
         
     | 
| 219 | 
         
            +
             
     | 
| 220 | 
         
            +
            //------------------------------------------------------------------------
         
     | 
    	
        dnnlib/tflib/ops/fused_bias_act.py
    ADDED
    
    | 
         @@ -0,0 +1,211 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            # Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
         
     | 
| 2 | 
         
            +
            #
         
     | 
| 3 | 
         
            +
            # NVIDIA CORPORATION and its licensors retain all intellectual property
         
     | 
| 4 | 
         
            +
            # and proprietary rights in and to this software, related documentation
         
     | 
| 5 | 
         
            +
            # and any modifications thereto.  Any use, reproduction, disclosure or
         
     | 
| 6 | 
         
            +
            # distribution of this software and related documentation without an express
         
     | 
| 7 | 
         
            +
            # license agreement from NVIDIA CORPORATION is strictly prohibited.
         
     | 
| 8 | 
         
            +
             
     | 
| 9 | 
         
            +
            """Custom TensorFlow ops for efficient bias and activation."""
         
     | 
| 10 | 
         
            +
             
     | 
| 11 | 
         
            +
            import os
         
     | 
| 12 | 
         
            +
            import numpy as np
         
     | 
| 13 | 
         
            +
            import tensorflow as tf
         
     | 
| 14 | 
         
            +
            from .. import custom_ops
         
     | 
| 15 | 
         
            +
            from ...util import EasyDict
         
     | 
| 16 | 
         
            +
             
     | 
| 17 | 
         
            +
            def _get_plugin():
         
     | 
| 18 | 
         
            +
                return custom_ops.get_plugin(os.path.splitext(__file__)[0] + '.cu')
         
     | 
| 19 | 
         
            +
             
     | 
| 20 | 
         
            +
            #----------------------------------------------------------------------------
         
     | 
| 21 | 
         
            +
             
     | 
| 22 | 
         
            +
# Table of supported activation functions and their properties:
#   func          -- reference TensorFlow expression for the forward pass.
#   def_alpha     -- default value of the `alpha` shape parameter (None if unused).
#   def_gain      -- default output scaling factor applied after the activation.
#   cuda_idx      -- index identifying the activation in the CUDA kernel's case table.
#   ref           -- which tensor the CUDA gradient op receives as reference: 'x' or 'y'.
#   zero_2nd_grad -- presumably True when the second derivative is identically zero — TODO confirm against the gradient registration.
activation_funcs = {
    'linear':   EasyDict(func=lambda x, **_:        x,                          def_alpha=None, def_gain=1.0,           cuda_idx=1, ref='y', zero_2nd_grad=True),
    'relu':     EasyDict(func=lambda x, **_:        tf.nn.relu(x),              def_alpha=None, def_gain=np.sqrt(2),    cuda_idx=2, ref='y', zero_2nd_grad=True),
    'lrelu':    EasyDict(func=lambda x, alpha, **_: tf.nn.leaky_relu(x, alpha), def_alpha=0.2,  def_gain=np.sqrt(2),    cuda_idx=3, ref='y', zero_2nd_grad=True),
    'tanh':     EasyDict(func=lambda x, **_:        tf.nn.tanh(x),              def_alpha=None, def_gain=1.0,           cuda_idx=4, ref='y', zero_2nd_grad=False),
    'sigmoid':  EasyDict(func=lambda x, **_:        tf.nn.sigmoid(x),           def_alpha=None, def_gain=1.0,           cuda_idx=5, ref='y', zero_2nd_grad=False),
    'elu':      EasyDict(func=lambda x, **_:        tf.nn.elu(x),               def_alpha=None, def_gain=1.0,           cuda_idx=6, ref='y', zero_2nd_grad=False),
    'selu':     EasyDict(func=lambda x, **_:        tf.nn.selu(x),              def_alpha=None, def_gain=1.0,           cuda_idx=7, ref='y', zero_2nd_grad=False),
    'softplus': EasyDict(func=lambda x, **_:        tf.nn.softplus(x),          def_alpha=None, def_gain=1.0,           cuda_idx=8, ref='y', zero_2nd_grad=False),
    'swish':    EasyDict(func=lambda x, **_:        tf.nn.sigmoid(x) * x,       def_alpha=None, def_gain=np.sqrt(2),    cuda_idx=9, ref='x', zero_2nd_grad=False),
}
         
     | 
| 33 | 
         
            +
             
     | 
| 34 | 
         
            +
            #----------------------------------------------------------------------------
         
     | 
| 35 | 
         
            +
             
     | 
| 36 | 
         
            +
            def fused_bias_act(x, b=None, axis=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
         
     | 
| 37 | 
         
            +
                r"""Fused bias and activation function.
         
     | 
| 38 | 
         
            +
             
     | 
| 39 | 
         
            +
                Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
         
     | 
| 40 | 
         
            +
                and scales the result by `gain`. Each of the steps is optional. In most cases,
         
     | 
| 41 | 
         
            +
                the fused op is considerably more efficient than performing the same calculation
         
     | 
| 42 | 
         
            +
                using standard TensorFlow ops. It supports first and second order gradients,
         
     | 
| 43 | 
         
            +
                but not third order gradients.
         
     | 
| 44 | 
         
            +
             
     | 
| 45 | 
         
            +
                Args:
         
     | 
| 46 | 
         
            +
                    x:      Input activation tensor. Can have any shape, but if `b` is defined, the
         
     | 
| 47 | 
         
            +
                            dimension corresponding to `axis`, as well as the rank, must be known.
         
     | 
| 48 | 
         
            +
                    b:      Bias vector, or `None` to disable. Must be a 1D tensor of the same type
         
     | 
| 49 | 
         
            +
                            as `x`. The shape must be known, and it must match the dimension of `x`
         
     | 
| 50 | 
         
            +
                            corresponding to `axis`.
         
     | 
| 51 | 
         
            +
                    axis:   The dimension in `x` corresponding to the elements of `b`.
         
     | 
| 52 | 
         
            +
                            The value of `axis` is ignored if `b` is not specified.
         
     | 
| 53 | 
         
            +
                    act:    Name of the activation function to evaluate, or `"linear"` to disable.
         
     | 
| 54 | 
         
            +
                            Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
         
     | 
| 55 | 
         
            +
                            See `activation_funcs` for a full list. `None` is not allowed.
         
     | 
| 56 | 
         
            +
                    alpha:  Shape parameter for the activation function, or `None` to use the default.
         
     | 
| 57 | 
         
            +
                    gain:   Scaling factor for the output tensor, or `None` to use default.
         
     | 
| 58 | 
         
            +
                            See `activation_funcs` for the default scaling of each activation function.
         
     | 
| 59 | 
         
            +
                            If unsure, consider specifying `1.0`.
         
     | 
| 60 | 
         
            +
                    clamp:  Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
         
     | 
| 61 | 
         
            +
                            the clamping (default).
         
     | 
| 62 | 
         
            +
                    impl:   Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
         
     | 
| 63 | 
         
            +
             
     | 
| 64 | 
         
            +
                Returns:
         
     | 
| 65 | 
         
            +
                    Tensor of the same shape and datatype as `x`.
         
     | 
| 66 | 
         
            +
                """
         
     | 
| 67 | 
         
            +
             
     | 
| 68 | 
         
            +
                impl_dict = {
         
     | 
| 69 | 
         
            +
                    'ref':  _fused_bias_act_ref,
         
     | 
| 70 | 
         
            +
                    'cuda': _fused_bias_act_cuda,
         
     | 
| 71 | 
         
            +
                }
         
     | 
| 72 | 
         
            +
                return impl_dict[impl](x=x, b=b, axis=axis, act=act, alpha=alpha, gain=gain, clamp=clamp)
         
     | 
| 73 | 
         
            +
             
     | 
| 74 | 
         
            +
            #----------------------------------------------------------------------------
         
     | 
| 75 | 
         
            +
             
     | 
| 76 | 
         
            +
            def _fused_bias_act_ref(x, b, axis, act, alpha, gain, clamp):
         
     | 
| 77 | 
         
            +
                """Slow reference implementation of `fused_bias_act()` using standard TensorFlow ops."""
         
     | 
| 78 | 
         
            +
             
     | 
| 79 | 
         
            +
                # Validate arguments.
         
     | 
| 80 | 
         
            +
                x = tf.convert_to_tensor(x)
         
     | 
| 81 | 
         
            +
                b = tf.convert_to_tensor(b) if b is not None else tf.constant([], dtype=x.dtype)
         
     | 
| 82 | 
         
            +
                act_spec = activation_funcs[act]
         
     | 
| 83 | 
         
            +
                assert b.shape.rank == 1 and (b.shape[0] == 0 or b.shape[0] == x.shape[axis])
         
     | 
| 84 | 
         
            +
                assert b.shape[0] == 0 or 0 <= axis < x.shape.rank
         
     | 
| 85 | 
         
            +
                if alpha is None:
         
     | 
| 86 | 
         
            +
                    alpha = act_spec.def_alpha
         
     | 
| 87 | 
         
            +
                if gain is None:
         
     | 
| 88 | 
         
            +
                    gain = act_spec.def_gain
         
     | 
| 89 | 
         
            +
             
     | 
| 90 | 
         
            +
                # Add bias.
         
     | 
| 91 | 
         
            +
                if b.shape[0] != 0:
         
     | 
| 92 | 
         
            +
                    x += tf.reshape(b, [-1 if i == axis else 1 for i in range(x.shape.rank)])
         
     | 
| 93 | 
         
            +
             
     | 
| 94 | 
         
            +
                # Evaluate activation function.
         
     | 
| 95 | 
         
            +
                x = act_spec.func(x, alpha=alpha)
         
     | 
| 96 | 
         
            +
             
     | 
| 97 | 
         
            +
                # Scale by gain.
         
     | 
| 98 | 
         
            +
                if gain != 1:
         
     | 
| 99 | 
         
            +
                    x *= gain
         
     | 
| 100 | 
         
            +
             
     | 
| 101 | 
         
            +
                # Clamp.
         
     | 
| 102 | 
         
            +
                if clamp is not None:
         
     | 
| 103 | 
         
            +
                    clamp = np.asarray(clamp, dtype=x.dtype.name)
         
     | 
| 104 | 
         
            +
                    assert clamp.shape == () and clamp >= 0
         
     | 
| 105 | 
         
            +
                    x = tf.clip_by_value(x, -clamp, clamp)
         
     | 
| 106 | 
         
            +
                return x
         
     | 
| 107 | 
         
            +
             
     | 
| 108 | 
         
            +
            #----------------------------------------------------------------------------
         
     | 
| 109 | 
         
            +
             
     | 
| 110 | 
         
            +
            def _fused_bias_act_cuda(x, b, axis, act, alpha, gain, clamp):
         
     | 
| 111 | 
         
            +
                """Fast CUDA implementation of `fused_bias_act()` using custom ops."""
         
     | 
| 112 | 
         
            +
             
     | 
| 113 | 
         
            +
                # Validate arguments.
         
     | 
| 114 | 
         
            +
                x = tf.convert_to_tensor(x)
         
     | 
| 115 | 
         
            +
                empty_tensor = tf.constant([], dtype=x.dtype)
         
     | 
| 116 | 
         
            +
                b = tf.convert_to_tensor(b) if b is not None else empty_tensor
         
     | 
| 117 | 
         
            +
                act_spec = activation_funcs[act]
         
     | 
| 118 | 
         
            +
                assert b.shape.rank == 1 and (b.shape[0] == 0 or b.shape[0] == x.shape[axis])
         
     | 
| 119 | 
         
            +
                assert b.shape[0] == 0 or 0 <= axis < x.shape.rank
         
     | 
| 120 | 
         
            +
                if alpha is None:
         
     | 
| 121 | 
         
            +
                    alpha = act_spec.def_alpha
         
     | 
| 122 | 
         
            +
                if gain is None:
         
     | 
| 123 | 
         
            +
                    gain = act_spec.def_gain
         
     | 
| 124 | 
         
            +
             
     | 
| 125 | 
         
            +
                # Special cases.
         
     | 
| 126 | 
         
            +
                if act == 'linear' and b is None and gain == 1.0:
         
     | 
| 127 | 
         
            +
                    return x
         
     | 
| 128 | 
         
            +
                if act_spec.cuda_idx is None:
         
     | 
| 129 | 
         
            +
                    return _fused_bias_act_ref(x=x, b=b, axis=axis, act=act, alpha=alpha, gain=gain, clamp=clamp)
         
     | 
| 130 | 
         
            +
             
     | 
| 131 | 
         
            +
                # CUDA op.
         
     | 
| 132 | 
         
            +
                cuda_op = _get_plugin().fused_bias_act
         
     | 
| 133 | 
         
            +
                cuda_kwargs = dict(axis=int(axis), act=int(act_spec.cuda_idx), gain=float(gain))
         
     | 
| 134 | 
         
            +
                if alpha is not None:
         
     | 
| 135 | 
         
            +
                    cuda_kwargs['alpha'] = float(alpha)
         
     | 
| 136 | 
         
            +
                if clamp is not None:
         
     | 
| 137 | 
         
            +
                    clamp = np.asarray(clamp, dtype=x.dtype.name)
         
     | 
| 138 | 
         
            +
                    assert clamp.shape == () and clamp >= 0
         
     | 
| 139 | 
         
            +
                    cuda_kwargs['clamp'] = float(clamp.astype(np.float32))
         
     | 
| 140 | 
         
            +
                def ref(tensor, name):
         
     | 
| 141 | 
         
            +
                    return tensor if act_spec.ref == name else empty_tensor
         
     | 
| 142 | 
         
            +
             
     | 
| 143 | 
         
            +
                # Forward pass: y = func(x, b).
         
     | 
| 144 | 
         
            +
                def func_y(x, b):
         
     | 
| 145 | 
         
            +
                    y = cuda_op(x=x, b=b, xref=empty_tensor, yref=empty_tensor, grad=0, **cuda_kwargs)
         
     | 
| 146 | 
         
            +
                    y.set_shape(x.shape)
         
     | 
| 147 | 
         
            +
                    return y
         
     | 
| 148 | 
         
            +
             
     | 
| 149 | 
         
            +
                # Backward pass: dx, db = grad(dy, x, y)
         
     | 
| 150 | 
         
            +
                def grad_dx(dy, x, y):
         
     | 
| 151 | 
         
            +
                    dx = cuda_op(x=dy, b=empty_tensor, xref=ref(x,'x'), yref=ref(y,'y'), grad=1, **cuda_kwargs)
         
     | 
| 152 | 
         
            +
                    dx.set_shape(x.shape)
         
     | 
| 153 | 
         
            +
                    return dx
         
     | 
| 154 | 
         
            +
                def grad_db(dx):
         
     | 
| 155 | 
         
            +
                    if b.shape[0] == 0:
         
     | 
| 156 | 
         
            +
                        return empty_tensor
         
     | 
| 157 | 
         
            +
                    db = dx
         
     | 
| 158 | 
         
            +
                    if axis < x.shape.rank - 1:
         
     | 
| 159 | 
         
            +
                        db = tf.reduce_sum(db, list(range(axis + 1, x.shape.rank)))
         
     | 
| 160 | 
         
            +
                    if axis > 0:
         
     | 
| 161 | 
         
            +
                        db = tf.reduce_sum(db, list(range(axis)))
         
     | 
| 162 | 
         
            +
                    db.set_shape(b.shape)
         
     | 
| 163 | 
         
            +
                    return db
         
     | 
| 164 | 
         
            +
             
     | 
| 165 | 
         
            +
                # Second order gradients: d_dy, d_x = grad2(d_dx, d_db, x, y)
         
     | 
| 166 | 
         
            +
                def grad2_d_dy(d_dx, d_db, x, y):
         
     | 
| 167 | 
         
            +
                    d_dy = cuda_op(x=d_dx, b=d_db, xref=ref(x,'x'), yref=ref(y,'y'), grad=1, **cuda_kwargs)
         
     | 
| 168 | 
         
            +
                    d_dy.set_shape(x.shape)
         
     | 
| 169 | 
         
            +
                    return d_dy
         
     | 
| 170 | 
         
            +
                def grad2_d_x(d_dx, d_db, x, y):
         
     | 
| 171 | 
         
            +
                    d_x = cuda_op(x=d_dx, b=d_db, xref=ref(x,'x'), yref=ref(y,'y'), grad=2, **cuda_kwargs)
         
     | 
| 172 | 
         
            +
                    d_x.set_shape(x.shape)
         
     | 
| 173 | 
         
            +
                    return d_x
         
     | 
| 174 | 
         
            +
             
     | 
| 175 | 
         
            +
                # Fast version for piecewise-linear activation funcs.
         
     | 
| 176 | 
         
            +
                @tf.custom_gradient
         
     | 
| 177 | 
         
            +
                def func_zero_2nd_grad(x, b):
         
     | 
| 178 | 
         
            +
                    y = func_y(x, b)
         
     | 
| 179 | 
         
            +
                    @tf.custom_gradient
         
     | 
| 180 | 
         
            +
                    def grad(dy):
         
     | 
| 181 | 
         
            +
                        dx = grad_dx(dy, x, y)
         
     | 
| 182 | 
         
            +
                        db = grad_db(dx)
         
     | 
| 183 | 
         
            +
                        def grad2(d_dx, d_db):
         
     | 
| 184 | 
         
            +
                            d_dy = grad2_d_dy(d_dx, d_db, x, y)
         
     | 
| 185 | 
         
            +
                            return d_dy
         
     | 
| 186 | 
         
            +
                        return (dx, db), grad2
         
     | 
| 187 | 
         
            +
                    return y, grad
         
     | 
| 188 | 
         
            +
             
     | 
| 189 | 
         
            +
                # Slow version for general activation funcs.
         
     | 
| 190 | 
         
            +
                @tf.custom_gradient
         
     | 
| 191 | 
         
            +
                def func_nonzero_2nd_grad(x, b):
         
     | 
| 192 | 
         
            +
                    y = func_y(x, b)
         
     | 
| 193 | 
         
            +
                    def grad_wrap(dy):
         
     | 
| 194 | 
         
            +
                        @tf.custom_gradient
         
     | 
| 195 | 
         
            +
                        def grad_impl(dy, x):
         
     | 
| 196 | 
         
            +
                            dx = grad_dx(dy, x, y)
         
     | 
| 197 | 
         
            +
                            db = grad_db(dx)
         
     | 
| 198 | 
         
            +
                            def grad2(d_dx, d_db):
         
     | 
| 199 | 
         
            +
                                d_dy = grad2_d_dy(d_dx, d_db, x, y)
         
     | 
| 200 | 
         
            +
                                d_x = grad2_d_x(d_dx, d_db, x, y)
         
     | 
| 201 | 
         
            +
                                return d_dy, d_x
         
     | 
| 202 | 
         
            +
                            return (dx, db), grad2
         
     | 
| 203 | 
         
            +
                        return grad_impl(dy, x)
         
     | 
| 204 | 
         
            +
                    return y, grad_wrap
         
     | 
| 205 | 
         
            +
             
     | 
| 206 | 
         
            +
                # Which version to use?
         
     | 
| 207 | 
         
            +
                if act_spec.zero_2nd_grad:
         
     | 
| 208 | 
         
            +
                    return func_zero_2nd_grad(x, b)
         
     | 
| 209 | 
         
            +
                return func_nonzero_2nd_grad(x, b)
         
     | 
| 210 | 
         
            +
             
     | 
| 211 | 
         
            +
            #----------------------------------------------------------------------------
         
     | 
    	
        dnnlib/tflib/ops/upfirdn_2d.cu
    ADDED
    
    | 
         @@ -0,0 +1,359 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            // Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
         
     | 
| 2 | 
         
            +
            //
         
     | 
| 3 | 
         
            +
            // NVIDIA CORPORATION and its licensors retain all intellectual property
         
     | 
| 4 | 
         
            +
            // and proprietary rights in and to this software, related documentation
         
     | 
| 5 | 
         
            +
            // and any modifications thereto.  Any use, reproduction, disclosure or
         
     | 
| 6 | 
         
            +
            // distribution of this software and related documentation without an express
         
     | 
| 7 | 
         
            +
            // license agreement from NVIDIA CORPORATION is strictly prohibited.
         
     | 
| 8 | 
         
            +
             
     | 
| 9 | 
         
            +
            #define EIGEN_USE_GPU
         
     | 
| 10 | 
         
            +
            #define __CUDA_INCLUDE_COMPILER_INTERNAL_HEADERS__
         
     | 
| 11 | 
         
            +
            #include "tensorflow/core/framework/op.h"
         
     | 
| 12 | 
         
            +
            #include "tensorflow/core/framework/op_kernel.h"
         
     | 
| 13 | 
         
            +
            #include "tensorflow/core/framework/shape_inference.h"
         
     | 
| 14 | 
         
            +
            #include <stdio.h>
         
     | 
| 15 | 
         
            +
             
     | 
| 16 | 
         
            +
            using namespace tensorflow;
         
     | 
| 17 | 
         
            +
            using namespace tensorflow::shape_inference;
         
     | 
| 18 | 
         
            +
             
     | 
| 19 | 
         
            +
            //------------------------------------------------------------------------
         
     | 
| 20 | 
         
            +
            // Helpers.
         
     | 
| 21 | 
         
            +
             
     | 
| 22 | 
         
            +
            #define OP_CHECK_CUDA_ERROR(CTX, CUDA_CALL) do { cudaError_t err = CUDA_CALL; OP_REQUIRES(CTX, err == cudaSuccess, errors::Internal(cudaGetErrorName(err))); } while (false)
         
     | 
| 23 | 
         
            +
             
     | 
| 24 | 
         
            +
            static __host__ __device__ __forceinline__ int floorDiv(int a, int b)
         
     | 
| 25 | 
         
            +
            {
         
     | 
| 26 | 
         
            +
                int t = 1 - a / b;
         
     | 
| 27 | 
         
            +
                return (a + t * b) / b - t;
         
     | 
| 28 | 
         
            +
            }
         
     | 
| 29 | 
         
            +
             
     | 
| 30 | 
         
            +
            //------------------------------------------------------------------------
         
     | 
| 31 | 
         
            +
            // CUDA kernel params.
         
     | 
| 32 | 
         
            +
             
     | 
| 33 | 
         
            +
            template <class T>
         
     | 
| 34 | 
         
            +
            struct UpFirDn2DKernelParams
         
     | 
| 35 | 
         
            +
            {
         
     | 
| 36 | 
         
            +
                const T*    x;          // [majorDim, inH, inW, minorDim]
         
     | 
| 37 | 
         
            +
                const T*    k;          // [kernelH, kernelW]
         
     | 
| 38 | 
         
            +
                T*          y;          // [majorDim, outH, outW, minorDim]
         
     | 
| 39 | 
         
            +
             
     | 
| 40 | 
         
            +
                int         upx;
         
     | 
| 41 | 
         
            +
                int         upy;
         
     | 
| 42 | 
         
            +
                int         downx;
         
     | 
| 43 | 
         
            +
                int         downy;
         
     | 
| 44 | 
         
            +
                int         padx0;
         
     | 
| 45 | 
         
            +
                int         padx1;
         
     | 
| 46 | 
         
            +
                int         pady0;
         
     | 
| 47 | 
         
            +
                int         pady1;
         
     | 
| 48 | 
         
            +
             
     | 
| 49 | 
         
            +
                int         majorDim;
         
     | 
| 50 | 
         
            +
                int         inH;
         
     | 
| 51 | 
         
            +
                int         inW;
         
     | 
| 52 | 
         
            +
                int         minorDim;
         
     | 
| 53 | 
         
            +
                int         kernelH;
         
     | 
| 54 | 
         
            +
                int         kernelW;
         
     | 
| 55 | 
         
            +
                int         outH;
         
     | 
| 56 | 
         
            +
                int         outW;
         
     | 
| 57 | 
         
            +
                int         loopMajor;
         
     | 
| 58 | 
         
            +
                int         loopX;
         
     | 
| 59 | 
         
            +
            };
         
     | 
| 60 | 
         
            +
             
     | 
| 61 | 
         
            +
            //------------------------------------------------------------------------
         
     | 
| 62 | 
         
            +
            // General CUDA implementation for large filter kernels.
         
     | 
| 63 | 
         
            +
             
     | 
| 64 | 
         
            +
            template <class T>
         
     | 
| 65 | 
         
            +
            static __global__ void UpFirDn2DKernel_large(const UpFirDn2DKernelParams<T> p)
         
     | 
| 66 | 
         
            +
            {
         
     | 
| 67 | 
         
            +
                // Calculate thread index.
         
     | 
| 68 | 
         
            +
                int minorIdx = blockIdx.x * blockDim.x + threadIdx.x;
         
     | 
| 69 | 
         
            +
                int outY = minorIdx / p.minorDim;
         
     | 
| 70 | 
         
            +
                minorIdx -= outY * p.minorDim;
         
     | 
| 71 | 
         
            +
                int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
         
     | 
| 72 | 
         
            +
                int majorIdxBase = blockIdx.z * p.loopMajor;
         
     | 
| 73 | 
         
            +
                if (outXBase >= p.outW || outY >= p.outH || majorIdxBase >= p.majorDim)
         
     | 
| 74 | 
         
            +
                    return;
         
     | 
| 75 | 
         
            +
             
     | 
| 76 | 
         
            +
                // Setup Y receptive field.
         
     | 
| 77 | 
         
            +
                int midY = outY * p.downy + p.upy - 1 - p.pady0;
         
     | 
| 78 | 
         
            +
                int inY = min(max(floorDiv(midY, p.upy), 0), p.inH);
         
     | 
| 79 | 
         
            +
                int h = min(max(floorDiv(midY + p.kernelH, p.upy), 0), p.inH) - inY;
         
     | 
| 80 | 
         
            +
                int kernelY = midY + p.kernelH - (inY + 1) * p.upy;
         
     | 
| 81 | 
         
            +
             
     | 
| 82 | 
         
            +
                // Loop over majorDim and outX.
         
     | 
| 83 | 
         
            +
                for (int loopMajor = 0, majorIdx = majorIdxBase; loopMajor < p.loopMajor && majorIdx < p.majorDim; loopMajor++, majorIdx++)
         
     | 
| 84 | 
         
            +
                for (int loopX = 0, outX = outXBase; loopX < p.loopX && outX < p.outW; loopX++, outX += blockDim.y)
         
     | 
| 85 | 
         
            +
                {
         
     | 
| 86 | 
         
            +
                    // Setup X receptive field.
         
     | 
| 87 | 
         
            +
                    int midX = outX * p.downx + p.upx - 1 - p.padx0;
         
     | 
| 88 | 
         
            +
                    int inX = min(max(floorDiv(midX, p.upx), 0), p.inW);
         
     | 
| 89 | 
         
            +
                    int w = min(max(floorDiv(midX + p.kernelW, p.upx), 0), p.inW) - inX;
         
     | 
| 90 | 
         
            +
                    int kernelX = midX + p.kernelW - (inX + 1) * p.upx;
         
     | 
| 91 | 
         
            +
             
     | 
| 92 | 
         
            +
                    // Initialize pointers.
         
     | 
| 93 | 
         
            +
                    const T* xp = &p.x[((majorIdx * p.inH + inY) * p.inW + inX) * p.minorDim + minorIdx];
         
     | 
| 94 | 
         
            +
                    const T* kp = &p.k[kernelY * p.kernelW + kernelX];
         
     | 
| 95 | 
         
            +
                    int xpx = p.minorDim;
         
     | 
| 96 | 
         
            +
                    int kpx = -p.upx;
         
     | 
| 97 | 
         
            +
                    int xpy = p.inW * p.minorDim;
         
     | 
| 98 | 
         
            +
                    int kpy = -p.upy * p.kernelW;
         
     | 
| 99 | 
         
            +
             
     | 
| 100 | 
         
            +
                    // Inner loop.
         
     | 
| 101 | 
         
            +
                    float v = 0.0f;
         
     | 
| 102 | 
         
            +
                    for (int y = 0; y < h; y++)
         
     | 
| 103 | 
         
            +
                    {
         
     | 
| 104 | 
         
            +
                        for (int x = 0; x < w; x++)
         
     | 
| 105 | 
         
            +
                        {
         
     | 
| 106 | 
         
            +
                            v += (float)(*xp) * (float)(*kp);
         
     | 
| 107 | 
         
            +
                            xp += xpx;
         
     | 
| 108 | 
         
            +
                            kp += kpx;
         
     | 
| 109 | 
         
            +
                        }
         
     | 
| 110 | 
         
            +
                        xp += xpy - w * xpx;
         
     | 
| 111 | 
         
            +
                        kp += kpy - w * kpx;
         
     | 
| 112 | 
         
            +
                    }
         
     | 
| 113 | 
         
            +
             
     | 
| 114 | 
         
            +
                    // Store result.
         
     | 
| 115 | 
         
            +
                    p.y[((majorIdx * p.outH + outY) * p.outW + outX) * p.minorDim + minorIdx] = (T)v;
         
     | 
| 116 | 
         
            +
                }
         
     | 
| 117 | 
         
            +
            }
         
     | 
| 118 | 
         
            +
             
     | 
| 119 | 
         
            +
            //------------------------------------------------------------------------
         
     | 
| 120 | 
         
            +
            // Specialized CUDA implementation for small filter kernels.
         
     | 
| 121 | 
         
            +
             
     | 
| 122 | 
         
            +
            template <class T, int upx, int upy, int downx, int downy, int kernelW, int kernelH, int tileOutW, int tileOutH>
         
     | 
| 123 | 
         
            +
            static __global__ void UpFirDn2DKernel_small(const UpFirDn2DKernelParams<T> p)
         
     | 
| 124 | 
         
            +
            {
         
     | 
| 125 | 
         
            +
                //assert(kernelW % upx == 0);
         
     | 
| 126 | 
         
            +
                //assert(kernelH % upy == 0);
         
     | 
| 127 | 
         
            +
                const int tileInW = ((tileOutW - 1) * downx + kernelW - 1) / upx + 1;
         
     | 
| 128 | 
         
            +
                const int tileInH = ((tileOutH - 1) * downy + kernelH - 1) / upy + 1;
         
     | 
| 129 | 
         
            +
                __shared__ volatile float sk[kernelH][kernelW];
         
     | 
| 130 | 
         
            +
                __shared__ volatile float sx[tileInH][tileInW];
         
     | 
| 131 | 
         
            +
             
     | 
| 132 | 
         
            +
                // Calculate tile index.
         
     | 
| 133 | 
         
            +
                int minorIdx = blockIdx.x;
         
     | 
| 134 | 
         
            +
                int tileOutY = minorIdx / p.minorDim;
         
     | 
| 135 | 
         
            +
                minorIdx -= tileOutY * p.minorDim;
         
     | 
| 136 | 
         
            +
                tileOutY *= tileOutH;
         
     | 
| 137 | 
         
            +
                int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
         
     | 
| 138 | 
         
            +
                int majorIdxBase = blockIdx.z * p.loopMajor;
         
     | 
| 139 | 
         
            +
                if (tileOutXBase >= p.outW | tileOutY >= p.outH | majorIdxBase >= p.majorDim)
         
     | 
| 140 | 
         
            +
                    return;
         
     | 
| 141 | 
         
            +
             
     | 
| 142 | 
         
            +
                // Load filter kernel (flipped).
         
     | 
| 143 | 
         
            +
                for (int tapIdx = threadIdx.x; tapIdx < kernelH * kernelW; tapIdx += blockDim.x)
         
     | 
| 144 | 
         
            +
                {
         
     | 
| 145 | 
         
            +
                    int ky = tapIdx / kernelW;
         
     | 
| 146 | 
         
            +
                    int kx = tapIdx - ky * kernelW;
         
     | 
| 147 | 
         
            +
                    float v = 0.0f;
         
     | 
| 148 | 
         
            +
                    if (kx < p.kernelW & ky < p.kernelH)
         
     | 
| 149 | 
         
            +
                        v = (float)p.k[(p.kernelH - 1 - ky) * p.kernelW + (p.kernelW - 1 - kx)];
         
     | 
| 150 | 
         
            +
                    sk[ky][kx] = v;
         
     | 
| 151 | 
         
            +
                }
         
     | 
| 152 | 
         
            +
             
     | 
| 153 | 
         
            +
                // Loop over majorDim and outX.
         
     | 
| 154 | 
         
            +
                for (int loopMajor = 0, majorIdx = majorIdxBase; loopMajor < p.loopMajor & majorIdx < p.majorDim; loopMajor++, majorIdx++)
         
     | 
| 155 | 
         
            +
                for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outW; loopX++, tileOutX += tileOutW)
         
     | 
| 156 | 
         
            +
                {
         
     | 
| 157 | 
         
            +
                    // Load input pixels.
         
     | 
| 158 | 
         
            +
                    int tileMidX = tileOutX * downx + upx - 1 - p.padx0;
         
     | 
| 159 | 
         
            +
                    int tileMidY = tileOutY * downy + upy - 1 - p.pady0;
         
     | 
| 160 | 
         
            +
                    int tileInX = floorDiv(tileMidX, upx);
         
     | 
| 161 | 
         
            +
                    int tileInY = floorDiv(tileMidY, upy);
         
     | 
| 162 | 
         
            +
                    __syncthreads();
         
     | 
| 163 | 
         
            +
                    for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW; inIdx += blockDim.x)
         
     | 
| 164 | 
         
            +
                    {
         
     | 
| 165 | 
         
            +
                        int relInY = inIdx / tileInW;
         
     | 
| 166 | 
         
            +
                        int relInX = inIdx - relInY * tileInW;
         
     | 
| 167 | 
         
            +
                        int inX = relInX + tileInX;
         
     | 
| 168 | 
         
            +
                        int inY = relInY + tileInY;
         
     | 
| 169 | 
         
            +
                        float v = 0.0f;
         
     | 
| 170 | 
         
            +
                        if (inX >= 0 & inY >= 0 & inX < p.inW & inY < p.inH)
         
     | 
| 171 | 
         
            +
                            v = (float)p.x[((majorIdx * p.inH + inY) * p.inW + inX) * p.minorDim + minorIdx];
         
     | 
| 172 | 
         
            +
                        sx[relInY][relInX] = v;
         
     | 
| 173 | 
         
            +
                    }
         
     | 
| 174 | 
         
            +
             
     | 
| 175 | 
         
            +
                    // Loop over output pixels.
         
     | 
| 176 | 
         
            +
                    __syncthreads();
         
     | 
| 177 | 
         
            +
                    for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW; outIdx += blockDim.x)
         
     | 
| 178 | 
         
            +
                    {
         
     | 
| 179 | 
         
            +
                        int relOutY = outIdx / tileOutW;
         
     | 
| 180 | 
         
            +
                        int relOutX = outIdx - relOutY * tileOutW;
         
     | 
| 181 | 
         
            +
                        int outX = relOutX + tileOutX;
         
     | 
| 182 | 
         
            +
                        int outY = relOutY + tileOutY;
         
     | 
| 183 | 
         
            +
             
     | 
| 184 | 
         
            +
                        // Setup receptive field.
         
     | 
| 185 | 
         
            +
                        int midX = tileMidX + relOutX * downx;
         
     | 
| 186 | 
         
            +
                        int midY = tileMidY + relOutY * downy;
         
     | 
| 187 | 
         
            +
                        int inX = floorDiv(midX, upx);
         
     | 
| 188 | 
         
            +
                        int inY = floorDiv(midY, upy);
         
     | 
| 189 | 
         
            +
                        int relInX = inX - tileInX;
         
     | 
| 190 | 
         
            +
                        int relInY = inY - tileInY;
         
     | 
| 191 | 
         
            +
                        int kernelX = (inX + 1) * upx - midX - 1; // flipped
         
     | 
| 192 | 
         
            +
                        int kernelY = (inY + 1) * upy - midY - 1; // flipped
         
     | 
| 193 | 
         
            +
             
     | 
| 194 | 
         
            +
                        // Inner loop.
         
     | 
| 195 | 
         
            +
                        float v = 0.0f;
         
     | 
| 196 | 
         
            +
                        #pragma unroll
         
     | 
| 197 | 
         
            +
                        for (int y = 0; y < kernelH / upy; y++)
         
     | 
| 198 | 
         
            +
                            #pragma unroll
         
     | 
| 199 | 
         
            +
                            for (int x = 0; x < kernelW / upx; x++)
         
     | 
| 200 | 
         
            +
                                v += sx[relInY + y][relInX + x] * sk[kernelY + y * upy][kernelX + x * upx];
         
     | 
| 201 | 
         
            +
             
     | 
| 202 | 
         
            +
                        // Store result.
         
     | 
| 203 | 
         
            +
                        if (outX < p.outW & outY < p.outH)
         
     | 
| 204 | 
         
            +
                            p.y[((majorIdx * p.outH + outY) * p.outW + outX) * p.minorDim + minorIdx] = (T)v;
         
     | 
| 205 | 
         
            +
                    }
         
     | 
| 206 | 
         
            +
                }
         
     | 
| 207 | 
         
            +
            }
         
     | 
| 208 | 
         
            +
             
     | 
| 209 | 
         
            +
            //------------------------------------------------------------------------
         
     | 
| 210 | 
         
            +
            // TensorFlow op.
         
     | 
| 211 | 
         
            +
             
     | 
| 212 | 
         
            +
template <class T>
struct UpFirDn2DOp : public OpKernel
{
    // Up/down-sampling factors and padding captured once at graph-construction
    // time; the per-invocation geometry fields (inH, inW, outW, ...) are filled
    // in from the input tensors inside Compute().
    UpFirDn2DKernelParams<T> m_attribs;

    // Reads and validates the op attributes (upx/upy, downx/downy, pad*).
    UpFirDn2DOp(OpKernelConstruction* ctx) : OpKernel(ctx)
    {
        memset(&m_attribs, 0, sizeof(m_attribs));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("upx", &m_attribs.upx));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("upy", &m_attribs.upy));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("downx", &m_attribs.downx));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("downy", &m_attribs.downy));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("padx0", &m_attribs.padx0));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("padx1", &m_attribs.padx1));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("pady0", &m_attribs.pady0));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("pady1", &m_attribs.pady1));
        OP_REQUIRES(ctx, m_attribs.upx >= 1 && m_attribs.upy >= 1, errors::InvalidArgument("upx and upy must be at least 1x1"));
        OP_REQUIRES(ctx, m_attribs.downx >= 1 && m_attribs.downy >= 1, errors::InvalidArgument("downx and downy must be at least 1x1"));
    }

    // Validates the inputs, allocates the output, selects the best-matching
    // specialized CUDA kernel (falling back to the generic "large" kernel),
    // and launches it on the device stream.
    void Compute(OpKernelContext* ctx)
    {
        UpFirDn2DKernelParams<T> p = m_attribs;
        cudaStream_t stream = ctx->eigen_device<Eigen::GpuDevice>().stream();

        const Tensor& x = ctx->input(0); // [majorDim, inH, inW, minorDim]
        const Tensor& k = ctx->input(1); // [kernelH, kernelW]
        p.x = x.flat<T>().data();
        p.k = k.flat<T>().data();
        OP_REQUIRES(ctx, x.dims() == 4, errors::InvalidArgument("input must have rank 4"));
        OP_REQUIRES(ctx, k.dims() == 2, errors::InvalidArgument("kernel must have rank 2"));
        // Element counts must fit in 32 bits because the kernels index with int.
        OP_REQUIRES(ctx, x.NumElements() <= kint32max, errors::InvalidArgument("input too large"));
        OP_REQUIRES(ctx, k.NumElements() <= kint32max, errors::InvalidArgument("kernel too large"))

        p.majorDim  = (int)x.dim_size(0);
        p.inH       = (int)x.dim_size(1);
        p.inW       = (int)x.dim_size(2);
        p.minorDim  = (int)x.dim_size(3);
        p.kernelH   = (int)k.dim_size(0);
        p.kernelW   = (int)k.dim_size(1);
        OP_REQUIRES(ctx, p.kernelW >= 1 && p.kernelH >= 1, errors::InvalidArgument("kernel must be at least 1x1"));

        // Output size of upsample -> pad -> FIR filter -> downsample.
        p.outW = (p.inW * p.upx + p.padx0 + p.padx1 - p.kernelW + p.downx) / p.downx;
        p.outH = (p.inH * p.upy + p.pady0 + p.pady1 - p.kernelH + p.downy) / p.downy;
        OP_REQUIRES(ctx, p.outW >= 1 && p.outH >= 1, errors::InvalidArgument("output must be at least 1x1"));

        Tensor* y = NULL; // [majorDim, outH, outW, minorDim]
        TensorShape ys;
        ys.AddDim(p.majorDim);
        ys.AddDim(p.outH);
        ys.AddDim(p.outW);
        ys.AddDim(p.minorDim);
        OP_REQUIRES_OK(ctx, ctx->allocate_output(0, ys, &y));
        p.y = y->flat<T>().data();
        OP_REQUIRES(ctx, y->NumElements() <= kint32max, errors::InvalidArgument("output too large"));

        // Choose CUDA kernel to use. The checks run top to bottom and each
        // match overwrites the previous choice, so within a group the later,
        // tighter kernel-size bound wins (e.g. kernelW <= 3 overrides <= 7).
        void* cudaKernel = (void*)UpFirDn2DKernel_large<T>;
        int tileOutW = -1;
        int tileOutH = -1;

        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 7  && p.kernelH <= 7 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 7,7,  64,16>; tileOutW = 64;  tileOutH = 16; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 6  && p.kernelH <= 6 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 6,6,  64,16>; tileOutW = 64;  tileOutH = 16; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 5  && p.kernelH <= 5 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 5,5,  64,16>; tileOutW = 64;  tileOutH = 16; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 4  && p.kernelH <= 4 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 4,4,  64,16>; tileOutW = 64;  tileOutH = 16; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 3  && p.kernelH <= 3 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 3,3,  64,16>; tileOutW = 64;  tileOutH = 16; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 24 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 24,1, 128,8>; tileOutW = 128; tileOutH = 8;  }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 20 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 20,1, 128,8>; tileOutW = 128; tileOutH = 8;  }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 16 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 16,1, 128,8>; tileOutW = 128; tileOutH = 8;  }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 12 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 12,1, 128,8>; tileOutW = 128; tileOutH = 8;  }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 8  && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 8,1,  128,8>; tileOutW = 128; tileOutH = 8;  }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 24) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 1,24, 32,32>; tileOutW = 32;  tileOutH = 32; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 20) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 1,20, 32,32>; tileOutW = 32;  tileOutH = 32; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 16) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 1,16, 32,32>; tileOutW = 32;  tileOutH = 32; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 12) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 1,12, 32,32>; tileOutW = 32;  tileOutH = 32; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 8 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 1,8,  32,32>; tileOutW = 32;  tileOutH = 32; }

        if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 8  && p.kernelH <= 8 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 8,8,  64,16>; tileOutW = 64;  tileOutH = 16; }
        if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 6  && p.kernelH <= 6 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 6,6,  64,16>; tileOutW = 64;  tileOutH = 16; }
        if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 4  && p.kernelH <= 4 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 4,4,  64,16>; tileOutW = 64;  tileOutH = 16; }
        if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 2  && p.kernelH <= 2 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 2,2,  64,16>; tileOutW = 64;  tileOutH = 16; }
        if (p.upx == 2 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 24 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,1, 1,1, 24,1, 128,8>; tileOutW = 128; tileOutH = 8;  }
        if (p.upx == 2 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 20 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,1, 1,1, 20,1, 128,8>; tileOutW = 128; tileOutH = 8;  }
        if (p.upx == 2 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 16 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,1, 1,1, 16,1, 128,8>; tileOutW = 128; tileOutH = 8;  }
        if (p.upx == 2 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 12 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,1, 1,1, 12,1, 128,8>; tileOutW = 128; tileOutH = 8;  }
        if (p.upx == 2 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 8  && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,1, 1,1, 8,1,  128,8>; tileOutW = 128; tileOutH = 8;  }
        if (p.upx == 1 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 24) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,2, 1,1, 1,24, 32,32>; tileOutW = 32;  tileOutH = 32; }
        if (p.upx == 1 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 20) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,2, 1,1, 1,20, 32,32>; tileOutW = 32;  tileOutH = 32; }
        if (p.upx == 1 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 16) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,2, 1,1, 1,16, 32,32>; tileOutW = 32;  tileOutH = 32; }
        if (p.upx == 1 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 12) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,2, 1,1, 1,12, 32,32>; tileOutW = 32;  tileOutH = 32; }
        if (p.upx == 1 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1  && p.kernelH <= 8 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,2, 1,1, 1,8,  32,32>; tileOutW = 32;  tileOutH = 32; }

        if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 8  && p.kernelH <= 8 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 8,8,  32,8 >; tileOutW = 32;  tileOutH = 8;  }
        if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 6  && p.kernelH <= 6 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 6,6,  32,8 >; tileOutW = 32;  tileOutH = 8;  }
        if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 4  && p.kernelH <= 4 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 4,4,  32,8 >; tileOutW = 32;  tileOutH = 8;  }
        if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 2  && p.kernelH <= 2 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 2,2,  32,8 >; tileOutW = 32;  tileOutH = 8;  }
        if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 1 && p.kernelW <= 24 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,1, 24,1, 64,8 >; tileOutW = 64;  tileOutH = 8;  }
        if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 1 && p.kernelW <= 20 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,1, 20,1, 64,8 >; tileOutW = 64;  tileOutH = 8;  }
        if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 1 && p.kernelW <= 16 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,1, 16,1, 64,8 >; tileOutW = 64;  tileOutH = 8;  }
        if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 1 && p.kernelW <= 12 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,1, 12,1, 64,8 >; tileOutW = 64;  tileOutH = 8;  }
        if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 1 && p.kernelW <= 8  && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,1, 8,1,  64,8 >; tileOutW = 64;  tileOutH = 8;  }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 2 && p.kernelW <= 1  && p.kernelH <= 24) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,2, 1,24, 32,16>; tileOutW = 32;  tileOutH = 16; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 2 && p.kernelW <= 1  && p.kernelH <= 20) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,2, 1,20, 32,16>; tileOutW = 32;  tileOutH = 16; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 2 && p.kernelW <= 1  && p.kernelH <= 16) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,2, 1,16, 32,16>; tileOutW = 32;  tileOutH = 16; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 2 && p.kernelW <= 1  && p.kernelH <= 12) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,2, 1,12, 32,16>; tileOutW = 32;  tileOutH = 16; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 2 && p.kernelW <= 1  && p.kernelH <= 8 ) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,2, 1,8,  32,16>; tileOutW = 32;  tileOutH = 16; }

        // Choose launch params. loopMajor caps the z grid dimension at 16384
        // by making each block iterate over several major indices if needed.
        dim3 blockSize;
        dim3 gridSize;
        if (tileOutW > 0 && tileOutH > 0) // small
        {
            p.loopMajor = (p.majorDim - 1) / 16384 + 1;
            p.loopX = 1;
            blockSize = dim3(32 * 8, 1, 1);
            gridSize = dim3(((p.outH - 1) / tileOutH + 1) * p.minorDim, (p.outW - 1) / (p.loopX * tileOutW) + 1, (p.majorDim - 1) / p.loopMajor + 1);
        }
        else // large
        {
            p.loopMajor = (p.majorDim - 1) / 16384 + 1;
            p.loopX = 4;
            blockSize = dim3(4, 32, 1);
            gridSize = dim3((p.outH * p.minorDim - 1) / blockSize.x + 1, (p.outW - 1) / (p.loopX * blockSize.y) + 1, (p.majorDim - 1) / p.loopMajor + 1);
        }

        // Launch CUDA kernel. Params are passed by pointer to the single struct.
        void* args[] = {&p};
        OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel(cudaKernel, gridSize, blockSize, args, 0, stream));
    }
};
         
     | 
| 342 | 
         
            +
             
     | 
| 343 | 
         
            +
// Register the op's interface with TensorFlow: two inputs (image batch x and
// 2-D FIR kernel k), one output, and the sampling/padding attributes with
// identity defaults (up = down = 1, pad = 0).
REGISTER_OP("UpFirDn2D")
    .Input      ("x: T")
    .Input      ("k: T")
    .Output     ("y: T")
    .Attr       ("T: {float, half}")
    .Attr       ("upx: int = 1")
    .Attr       ("upy: int = 1")
    .Attr       ("downx: int = 1")
    .Attr       ("downy: int = 1")
    .Attr       ("padx0: int = 0")
    .Attr       ("padx1: int = 0")
    .Attr       ("pady0: int = 0")
    .Attr       ("pady1: int = 0");
// GPU-only kernels for the two supported dtypes; there is no CPU registration.
REGISTER_KERNEL_BUILDER(Name("UpFirDn2D").Device(DEVICE_GPU).TypeConstraint<float>("T"), UpFirDn2DOp<float>);
REGISTER_KERNEL_BUILDER(Name("UpFirDn2D").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"), UpFirDn2DOp<Eigen::half>);
         
     | 
| 358 | 
         
            +
             
     | 
| 359 | 
         
            +
            //------------------------------------------------------------------------
         
     | 
    	
        dnnlib/tflib/ops/upfirdn_2d.py
    ADDED
    
    | 
         @@ -0,0 +1,418 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            # Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
         
     | 
| 2 | 
         
            +
            #
         
     | 
| 3 | 
         
            +
            # NVIDIA CORPORATION and its licensors retain all intellectual property
         
     | 
| 4 | 
         
            +
            # and proprietary rights in and to this software, related documentation
         
     | 
| 5 | 
         
            +
            # and any modifications thereto.  Any use, reproduction, disclosure or
         
     | 
| 6 | 
         
            +
            # distribution of this software and related documentation without an express
         
     | 
| 7 | 
         
            +
            # license agreement from NVIDIA CORPORATION is strictly prohibited.
         
     | 
| 8 | 
         
            +
             
     | 
| 9 | 
         
            +
            """Custom TensorFlow ops for efficient resampling of 2D images."""
         
     | 
| 10 | 
         
            +
             
     | 
| 11 | 
         
            +
            import os
         
     | 
| 12 | 
         
            +
            import numpy as np
         
     | 
| 13 | 
         
            +
            import tensorflow as tf
         
     | 
| 14 | 
         
            +
            from .. import custom_ops
         
     | 
| 15 | 
         
            +
             
     | 
def _get_plugin():
    """Load/compile the CUDA plugin that sits next to this module (`upfirdn_2d.cu`)."""
    cu_path = os.path.splitext(__file__)[0] + '.cu'
    return custom_ops.get_plugin(cu_path)
| 18 | 
         
            +
             
     | 
| 19 | 
         
            +
            #----------------------------------------------------------------------------
         
     | 
| 20 | 
         
            +
             
     | 
def upfirdn_2d(x, k, upx=1, upy=1, downx=1, downy=1, padx0=0, padx1=0, pady0=0, pady1=0, impl='cuda'):
    r"""Pad, upsample, FIR filter, and downsample a batch of 2D images.

    Each image in the `[majorDim, inH, inW, minorDim]` batch is processed as
    follows, batched across `majorDim` and `minorDim`:

    1. Upsample by inserting zeros after every pixel (`upx`, `upy`).

    2. Zero-pad by the given number of pixels on each side (`padx0`, `padx1`,
       `pady0`, `pady1`); a negative value crops the image instead.

    3. Convolve with the 2D FIR filter `k`, keeping only output pixels whose
       filter footprint lies entirely within the (padded) input.

    4. Downsample by keeping every `downx`-th / `downy`-th pixel.

    This mirrors the behavior of scipy.signal.upfirdn(). The fused op is
    considerably more efficient than the same calculation chained out of
    standard TensorFlow ops, and it supports gradients of arbitrary order.

    Args:
        x:      Input tensor of the shape `[majorDim, inH, inW, minorDim]`.
        k:      2D FIR filter of the shape `[firH, firW]`.
        upx:    Integer upsampling factor along the X-axis (default: 1).
        upy:    Integer upsampling factor along the Y-axis (default: 1).
        downx:  Integer downsampling factor along the X-axis (default: 1).
        downy:  Integer downsampling factor along the Y-axis (default: 1).
        padx0:  Number of pixels to pad on the left side (default: 0).
        padx1:  Number of pixels to pad on the right side (default: 0).
        pady0:  Number of pixels to pad on the top side (default: 0).
        pady1:  Number of pixels to pad on the bottom side (default: 0).
        impl:   Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).

    Returns:
        Tensor of the shape `[majorDim, outH, outW, minorDim]`, and same datatype as `x`.
    """

    # Look up the requested implementation; an unknown name raises KeyError,
    # matching the documented `"ref"` / `"cuda"` contract.
    dispatch = {
        'ref':  _upfirdn_2d_ref,
        'cuda': _upfirdn_2d_cuda,
    }
    chosen = dispatch[impl]
    return chosen(x=x, k=k, upx=upx, upy=upy, downx=downx, downy=downy,
                  padx0=padx0, padx1=padx1, pady0=pady0, pady1=pady1)
| 65 | 
         
            +
             
     | 
| 66 | 
         
            +
            #----------------------------------------------------------------------------
         
     | 
| 67 | 
         
            +
             
     | 
def _upfirdn_2d_ref(x, k, upx, upy, downx, downy, padx0, padx1, pady0, pady1):
    """Slow reference implementation of `upfirdn_2d()` using standard TensorFlow ops.

    Requires the H/W dimensions of `x` to be statically known (uses
    `x.shape[...].value`, i.e. TF1-style static shapes). All up/down factors
    and paddings must be Python ints.
    """

    x = tf.convert_to_tensor(x)
    k = np.asarray(k, dtype=np.float32)
    assert x.shape.rank == 4
    # Static H/W are needed below to compute reshape/crop sizes at graph-build time.
    inH = x.shape[1].value
    inW = x.shape[2].value
    minorDim = _shape(x, 3)
    kernelH, kernelW = k.shape
    assert inW >= 1 and inH >= 1
    assert kernelW >= 1 and kernelH >= 1
    assert isinstance(upx, int) and isinstance(upy, int)
    assert isinstance(downx, int) and isinstance(downy, int)
    assert isinstance(padx0, int) and isinstance(padx1, int)
    assert isinstance(pady0, int) and isinstance(pady1, int)

    # Upsample (insert zeros).
    # Splitting each spatial axis into (pixel, 1) and padding the size-1 axis
    # with up-1 zeros interleaves zeros after every pixel.
    x = tf.reshape(x, [-1, inH, 1, inW, 1, minorDim])
    x = tf.pad(x, [[0, 0], [0, 0], [0, upy - 1], [0, 0], [0, upx - 1], [0, 0]])
    x = tf.reshape(x, [-1, inH * upy, inW * upx, minorDim])

    # Pad (crop if negative).
    # tf.pad only accepts non-negative amounts, so positive padding is applied
    # here and negative padding is realized as slicing (cropping) afterwards.
    x = tf.pad(x, [[0, 0], [max(pady0, 0), max(pady1, 0)], [max(padx0, 0), max(padx1, 0)], [0, 0]])
    x = x[:, max(-pady0, 0) : x.shape[1].value - max(-pady1, 0), max(-padx0, 0) : x.shape[2].value - max(-padx1, 0), :]

    # Convolve with filter.
    # Channels are folded into the batch so a single-channel conv applies the
    # same filter to every (major, minor) pair; k is flipped on both axes so
    # tf.nn.conv2d (cross-correlation) performs a true convolution.
    x = tf.transpose(x, [0, 3, 1, 2])
    x = tf.reshape(x, [-1, 1, inH * upy + pady0 + pady1, inW * upx + padx0 + padx1])
    w = tf.constant(k[::-1, ::-1, np.newaxis, np.newaxis], dtype=x.dtype)
    x = tf.nn.conv2d(x, w, strides=[1,1,1,1], padding='VALID', data_format='NCHW')
    x = tf.reshape(x, [-1, minorDim, inH * upy + pady0 + pady1 - kernelH + 1, inW * upx + padx0 + padx1 - kernelW + 1])
    x = tf.transpose(x, [0, 2, 3, 1])

    # Downsample (throw away pixels).
    return x[:, ::downy, ::downx, :]
| 104 | 
         
            +
             
     | 
| 105 | 
         
            +
            #----------------------------------------------------------------------------
         
     | 
| 106 | 
         
            +
             
     | 
def _upfirdn_2d_cuda(x, k, upx, upy, downx, downy, padx0, padx1, pady0, pady1):
    """Fast CUDA implementation of `upfirdn_2d()` using custom ops.

    Wraps the compiled `up_fir_dn2d` kernel in nested `tf.custom_gradient`s so
    the op supports gradients of arbitrary order: the gradient of an upfirdn
    is itself an upfirdn with swapped up/down factors, a flipped filter, and
    adjusted padding.
    """

    x = tf.convert_to_tensor(x)
    k = np.asarray(k, dtype=np.float32)
    majorDim, inH, inW, minorDim = x.shape.as_list()
    kernelH, kernelW = k.shape
    assert inW >= 1 and inH >= 1
    assert kernelW >= 1 and kernelH >= 1
    assert isinstance(upx, int) and isinstance(upy, int)
    assert isinstance(downx, int) and isinstance(downy, int)
    assert isinstance(padx0, int) and isinstance(padx1, int)
    assert isinstance(pady0, int) and isinstance(pady1, int)

    # Output size after upsample -> pad -> VALID filter -> downsample.
    outW = (inW * upx + padx0 + padx1 - kernelW) // downx + 1
    outH = (inH * upy + pady0 + pady1 - kernelH) // downy + 1
    assert outW >= 1 and outH >= 1

    cuda_op = _get_plugin().up_fir_dn2d
    kc = tf.constant(k, dtype=x.dtype)
    # Gradient pass uses the spatially flipped filter...
    gkc = tf.constant(k[::-1, ::-1], dtype=x.dtype)
    # ...and padding chosen so that the backward upfirdn maps the [outH, outW]
    # gradient back to the exact [inH, inW] input extent.
    gpadx0 = kernelW - padx0 - 1
    gpady0 = kernelH - pady0 - 1
    gpadx1 = inW * upx - outW * downx + padx0 - upx + 1
    gpady1 = inH * upy - outH * downy + pady0 - upy + 1

    @tf.custom_gradient
    def func(x):
        y = cuda_op(x=x, k=kc, upx=int(upx), upy=int(upy), downx=int(downx), downy=int(downy), padx0=int(padx0), padx1=int(padx1), pady0=int(pady0), pady1=int(pady1))
        # The custom op does not infer shapes; pin them so downstream graph
        # construction sees static dimensions.
        y.set_shape([majorDim, outH, outW, minorDim])
        # Inner custom_gradient: the gradient-of-the-gradient is the forward
        # op again (`func`), which is what enables arbitrary-order gradients.
        @tf.custom_gradient
        def grad(dy):
            dx = cuda_op(x=dy, k=gkc, upx=int(downx), upy=int(downy), downx=int(upx), downy=int(upy), padx0=int(gpadx0), padx1=int(gpadx1), pady0=int(gpady0), pady1=int(gpady1))
            dx.set_shape([majorDim, inH, inW, minorDim])
            return dx, func
        return y, grad
    return func(x)
| 144 | 
         
            +
             
     | 
| 145 | 
         
            +
            #----------------------------------------------------------------------------
         
     | 
| 146 | 
         
            +
             
     | 
def filter_2d(x, k, gain=1, padding=0, data_format='NCHW', impl='cuda'):
    r"""Filter a batch of 2D images with the given FIR filter.

    Every image in the `[N, C, H, W]` or `[N, H, W, C]` batch is convolved
    with `k`. The filter is normalized so that a constant input is scaled by
    the specified `gain`; pixels outside the image are treated as zero.

    Args:
        x:            Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        k:            FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
        gain:         Scaling factor for signal magnitude (default: 1.0).
        padding:      Number of pixels to pad or crop the output on each side (default: 0).
        data_format:  `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
        impl:         Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).

    Returns:
        Tensor of the same shape and datatype as `x`.
    """

    assert isinstance(padding, int)
    kernel = _FilterKernel(k=k, gain=gain)
    # Only square kernels are supported.
    assert kernel.w == kernel.h
    # Symmetric "same"-style padding, shifted by the caller-requested padding.
    lead = padding + kernel.w // 2
    trail = padding + (kernel.w - 1) // 2
    return _simple_upfirdn_2d(x, kernel, pad0=lead, pad1=trail, data_format=data_format, impl=impl)
| 173 | 
         
            +
             
     | 
| 174 | 
         
            +
            #----------------------------------------------------------------------------
         
     | 
| 175 | 
         
            +
             
     | 
def upsample_2d(x, k=None, factor=2, gain=1, padding=0, data_format='NCHW', impl='cuda'):
    r"""Upsample a batch of 2D images with the given filter.

    Every image in the `[N, C, H, W]` or `[N, H, W, C]` batch is upsampled by
    `factor` and filtered with `k`. The filter is normalized so that a
    constant input is scaled by the specified `gain`. Pixels outside the
    image are treated as zero, and the filter is zero-padded so that its
    shape is a multiple of the upsampling factor.

    Args:
        x:            Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        k:            FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
                      The default is `[1] * factor`, which corresponds to
                      nearest-neighbor upsampling.
        factor:       Integer upsampling factor (default: 2).
        gain:         Scaling factor for signal magnitude (default: 1.0).
        padding:      Number of pixels to pad or crop the output on each side (default: 0).
        data_format:  `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
        impl:         Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).

    Returns:
        Tensor of the shape `[N, C, H * factor, W * factor]` or
        `[N, H * factor, W * factor, C]`, and same datatype as `x`.
    """

    assert isinstance(factor, int) and factor >= 1
    assert isinstance(padding, int)
    if k is None:
        k = [1] * factor
    # Gain is multiplied by factor^2 to compensate for the zeros inserted
    # when upsampling.
    kernel = _FilterKernel(k, gain * (factor ** 2))
    assert kernel.w == kernel.h
    lead = (kernel.w + factor - 1) // 2 + padding
    trail = (kernel.w - factor) // 2 + padding
    return _simple_upfirdn_2d(x, kernel, up=factor, pad0=lead, pad1=trail, data_format=data_format, impl=impl)
| 208 | 
         
            +
             
     | 
| 209 | 
         
            +
            #----------------------------------------------------------------------------
         
     | 
| 210 | 
         
            +
             
     | 
def downsample_2d(x, k=None, factor=2, gain=1, padding=0, data_format='NCHW', impl='cuda'):
    r"""Downsample a batch of 2D images with the given filter.

    Every image in the `[N, C, H, W]` or `[N, H, W, C]` batch is filtered
    with `k` and downsampled by `factor`. The filter is normalized so that a
    constant input is scaled by the specified `gain`. Pixels outside the
    image are treated as zero, and the filter is zero-padded so that its
    shape is a multiple of the downsampling factor.

    Args:
        x:            Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        k:            FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
                      The default is `[1] * factor`, which corresponds to
                      average pooling.
        factor:       Integer downsampling factor (default: 2).
        gain:         Scaling factor for signal magnitude (default: 1.0).
        padding:      Number of pixels to pad or crop the output on each side (default: 0).
        data_format:  `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
        impl:         Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).

    Returns:
        Tensor of the shape `[N, C, H // factor, W // factor]` or
        `[N, H // factor, W // factor, C]`, and same datatype as `x`.
    """

    assert isinstance(factor, int) and factor >= 1
    assert isinstance(padding, int)
    if k is None:
        k = [1] * factor
    kernel = _FilterKernel(k, gain)
    assert kernel.w == kernel.h
    # Padding is expressed in output pixels here, hence the factor scaling.
    lead = (kernel.w - factor + 1) // 2 + padding * factor
    trail = (kernel.w - factor) // 2 + padding * factor
    return _simple_upfirdn_2d(x, kernel, down=factor, pad0=lead, pad1=trail, data_format=data_format, impl=impl)
| 242 | 
         
            +
             
     | 
| 243 | 
         
            +
            #----------------------------------------------------------------------------
         
     | 
| 244 | 
         
            +
             
     | 
def upsample_conv_2d(x, w, k=None, factor=2, gain=1, padding=0, data_format='NCHW', impl='cuda'):
    r"""Fused `upsample_2d()` followed by `tf.nn.conv2d()`.

    Padding is performed only once at the beginning, not between the operations.
    The fused op is considerably more efficient than performing the same calculation
    using standard TensorFlow ops. It supports gradients of arbitrary order.

    Args:
        x:            Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        w:            Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`.
                      Grouped convolution can be performed by
                      `inChannels = numChannels(x) // numGroups`.
        k:            FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
                      The default is `[1] * factor`, which corresponds to nearest-neighbor
                      upsampling.
        factor:       Integer upsampling factor (default: 2).
        gain:         Scaling factor for signal magnitude (default: 1.0).
        padding:      Number of pixels to pad or crop the output on each side (default: 0).
        data_format:  `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
        impl:         Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).

    Returns:
        Tensor of the shape `[N, C, H * factor, W * factor]` or
        `[N, H * factor, W * factor, C]`, and same datatype as `x`.
    """

    assert isinstance(factor, int) and factor >= 1
    assert isinstance(padding, int)

    # Check weight shape. Requires the weight shape to be statically known;
    # only square spatial kernels are supported (cw == ch).
    w = tf.convert_to_tensor(w)
    ch, cw, _inC, _outC = w.shape.as_list()
    inC = _shape(w, 2)
    outC = _shape(w, 3)
    assert cw == ch

    # Fast path for 1x1 convolution.
    # A 1x1 conv commutes with resampling, so convolve first (on the smaller
    # tensor) and then upsample.
    if cw == 1 and ch == 1:
        x = tf.nn.conv2d(x, w, data_format=data_format, strides=[1,1,1,1], padding='VALID')
        x = upsample_2d(x, k, factor=factor, gain=gain, padding=padding, data_format=data_format, impl=impl)
        return x

    # Setup filter kernel.
    # Gain is multiplied by factor^2 to compensate for the zeros inserted by
    # the transposed-convolution upsampling below.
    k = _FilterKernel(k if k is not None else [1] * factor, gain * (factor ** 2))
    assert k.w == k.h

    # Determine data dimensions.
    # num_groups > 1 means `w` holds one [ch, cw, inC, ...] slice per group.
    if data_format == 'NCHW':
        stride = [1, 1, factor, factor]
        output_shape = [_shape(x, 0), outC, (_shape(x, 2) - 1) * factor + ch, (_shape(x, 3) - 1) * factor + cw]
        num_groups = _shape(x, 1) // inC
    else:
        stride = [1, factor, factor, 1]
        output_shape = [_shape(x, 0), (_shape(x, 1) - 1) * factor + ch, (_shape(x, 2) - 1) * factor + cw, outC]
        num_groups = _shape(x, 3) // inC

    # Transpose weights.
    # conv2d_transpose expects [H, W, out, in]; the spatial flip ([::-1, ::-1])
    # plus the in/out swap converts the forward-conv weights to that layout
    # while grouping is preserved via the intermediate 5D reshape.
    w = tf.reshape(w, [ch, cw, inC, num_groups, -1])
    w = tf.transpose(w[::-1, ::-1], [0, 1, 4, 3, 2])
    w = tf.reshape(w, [ch, cw, -1, num_groups * inC])

    # Execute.
    # Transposed conv performs the upsample+conv in one fused step; the final
    # upfirdn applies the FIR filter with padding chosen to realize the
    # documented output size.
    x = tf.nn.conv2d_transpose(x, w, output_shape=output_shape, strides=stride, padding='VALID', data_format=data_format)
    pad0 = (k.w + factor - cw) // 2 + padding
    pad1 = (k.w - factor - cw + 3) // 2 + padding
    return _simple_upfirdn_2d(x, k, pad0=pad0, pad1=pad1, data_format=data_format, impl=impl)
| 310 | 
         
            +
             
     | 
| 311 | 
         
            +
            #----------------------------------------------------------------------------
         
     | 
| 312 | 
         
            +
             
     | 
| 313 | 
         
            +
            def conv_downsample_2d(x, w, k=None, factor=2, gain=1, padding=0, data_format='NCHW', impl='cuda'):
         
     | 
| 314 | 
         
            +
                r"""Fused `tf.nn.conv2d()` followed by `downsample_2d()`.
         
     | 
| 315 | 
         
            +
             
     | 
| 316 | 
         
            +
                Padding is performed only once at the beginning, not between the operations.
         
     | 
| 317 | 
         
            +
                The fused op is considerably more efficient than performing the same calculation
         
     | 
| 318 | 
         
            +
                using standard TensorFlow ops. It supports gradients of arbitrary order.
         
     | 
| 319 | 
         
            +
             
     | 
| 320 | 
         
            +
                Args:
         
     | 
| 321 | 
         
            +
                    x:            Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
         
     | 
| 322 | 
         
            +
                    w:            Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`.
         
     | 
| 323 | 
         
            +
                                  Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
         
     | 
| 324 | 
         
            +
                    k:            FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
         
     | 
| 325 | 
         
            +
                                  The default is `[1] * factor`, which corresponds to average pooling.
         
     | 
| 326 | 
         
            +
                    factor:       Integer downsampling factor (default: 2).
         
     | 
| 327 | 
         
            +
                    gain:         Scaling factor for signal magnitude (default: 1.0).
         
     | 
| 328 | 
         
            +
                    padding:      Number of pixels to pad or crop the output on each side (default: 0).
         
     | 
| 329 | 
         
            +
                    data_format:  `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
         
     | 
| 330 | 
         
            +
                    impl:         Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
         
     | 
| 331 | 
         
            +
             
     | 
| 332 | 
         
            +
                Returns:
         
     | 
| 333 | 
         
            +
                    Tensor of the shape `[N, C, H // factor, W // factor]` or
         
     | 
| 334 | 
         
            +
                    `[N, H // factor, W // factor, C]`, and same datatype as `x`.
         
     | 
| 335 | 
         
            +
                """
         
     | 
| 336 | 
         
            +
             
     | 
| 337 | 
         
            +
                assert isinstance(factor, int) and factor >= 1
         
     | 
| 338 | 
         
            +
                assert isinstance(padding, int)
         
     | 
| 339 | 
         
            +
             
     | 
| 340 | 
         
            +
                # Check weight shape.
         
     | 
| 341 | 
         
            +
                w = tf.convert_to_tensor(w)
         
     | 
| 342 | 
         
            +
                ch, cw, _inC, _outC = w.shape.as_list()
         
     | 
| 343 | 
         
            +
                assert cw == ch
         
     | 
| 344 | 
         
            +
             
     | 
| 345 | 
         
            +
                # Fast path for 1x1 convolution.
         
     | 
| 346 | 
         
            +
                if cw == 1 and ch == 1:
         
     | 
| 347 | 
         
            +
                    x = downsample_2d(x, k, factor=factor, gain=gain, padding=padding, data_format=data_format, impl=impl)
         
     | 
| 348 | 
         
            +
                    x = tf.nn.conv2d(x, w, data_format=data_format, strides=[1,1,1,1], padding='VALID')
         
     | 
| 349 | 
         
            +
                    return x
         
     | 
| 350 | 
         
            +
             
     | 
| 351 | 
         
            +
                # Setup filter kernel.
         
     | 
| 352 | 
         
            +
                k = _FilterKernel(k if k is not None else [1] * factor, gain)
         
     | 
| 353 | 
         
            +
                assert k.w == k.h
         
     | 
| 354 | 
         
            +
             
     | 
| 355 | 
         
            +
                # Determine stride.
         
     | 
| 356 | 
         
            +
                if data_format == 'NCHW':
         
     | 
| 357 | 
         
            +
                    s = [1, 1, factor, factor]
         
     | 
| 358 | 
         
            +
                else:
         
     | 
| 359 | 
         
            +
                    s = [1, factor, factor, 1]
         
     | 
| 360 | 
         
            +
             
     | 
| 361 | 
         
            +
                # Execute.
         
     | 
| 362 | 
         
            +
                pad0 = (k.w - factor + cw) // 2 + padding * factor
         
     | 
| 363 | 
         
            +
                pad1 = (k.w - factor + cw - 1) // 2 + padding * factor
         
     | 
| 364 | 
         
            +
                x = _simple_upfirdn_2d(x, k, pad0=pad0, pad1=pad1, data_format=data_format, impl=impl)
         
     | 
| 365 | 
         
            +
                return tf.nn.conv2d(x, w, strides=s, padding='VALID', data_format=data_format)
         
     | 
| 366 | 
         
            +
             
     | 
| 367 | 
         
            +
            #----------------------------------------------------------------------------
         
     | 
| 368 | 
         
            +
            # Internal helpers.
         
     | 
| 369 | 
         
            +
             
     | 
| 370 | 
         
            +
            class _FilterKernel:
         
     | 
| 371 | 
         
            +
                def __init__(self, k, gain=1):
         
     | 
| 372 | 
         
            +
                    k = np.asarray(k, dtype=np.float32)
         
     | 
| 373 | 
         
            +
                    k /= np.sum(k)
         
     | 
| 374 | 
         
            +
             
     | 
| 375 | 
         
            +
                    # Separable.
         
     | 
| 376 | 
         
            +
                    if k.ndim == 1 and k.size >= 8:
         
     | 
| 377 | 
         
            +
                        self.w = k.size
         
     | 
| 378 | 
         
            +
                        self.h = k.size
         
     | 
| 379 | 
         
            +
                        self.kx = k[np.newaxis, :]
         
     | 
| 380 | 
         
            +
                        self.ky = k[:, np.newaxis] * gain
         
     | 
| 381 | 
         
            +
                        self.kxy = None
         
     | 
| 382 | 
         
            +
             
     | 
| 383 | 
         
            +
                    # Non-separable.
         
     | 
| 384 | 
         
            +
                    else:
         
     | 
| 385 | 
         
            +
                        if k.ndim == 1:
         
     | 
| 386 | 
         
            +
                            k = np.outer(k, k)
         
     | 
| 387 | 
         
            +
                        assert k.ndim == 2
         
     | 
| 388 | 
         
            +
                        self.w = k.shape[1]
         
     | 
| 389 | 
         
            +
                        self.h = k.shape[0]
         
     | 
| 390 | 
         
            +
                        self.kx = None
         
     | 
| 391 | 
         
            +
                        self.ky = None
         
     | 
| 392 | 
         
            +
                        self.kxy = k * gain
         
     | 
| 393 | 
         
            +
             
     | 
| 394 | 
         
            +
            def _simple_upfirdn_2d(x, k, up=1, down=1, pad0=0, pad1=0, data_format='NCHW', impl='cuda'):
         
     | 
| 395 | 
         
            +
                assert isinstance(k, _FilterKernel)
         
     | 
| 396 | 
         
            +
                assert data_format in ['NCHW', 'NHWC']
         
     | 
| 397 | 
         
            +
                assert x.shape.rank == 4
         
     | 
| 398 | 
         
            +
                y = x
         
     | 
| 399 | 
         
            +
                if data_format == 'NCHW':
         
     | 
| 400 | 
         
            +
                    y = tf.reshape(y, [-1, _shape(y, 2), _shape(y, 3), 1])
         
     | 
| 401 | 
         
            +
                if k.kx is not None:
         
     | 
| 402 | 
         
            +
                    y = upfirdn_2d(y, k.kx, upx=up, downx=down, padx0=pad0, padx1=pad1, impl=impl)
         
     | 
| 403 | 
         
            +
                if k.ky is not None:
         
     | 
| 404 | 
         
            +
                    y = upfirdn_2d(y, k.ky, upy=up, downy=down, pady0=pad0, pady1=pad1, impl=impl)
         
     | 
| 405 | 
         
            +
                if k.kxy is not None:
         
     | 
| 406 | 
         
            +
                    y = upfirdn_2d(y, k.kxy, upx=up, upy=up, downx=down, downy=down, padx0=pad0, padx1=pad1, pady0=pad0, pady1=pad1, impl=impl)
         
     | 
| 407 | 
         
            +
                if data_format == 'NCHW':
         
     | 
| 408 | 
         
            +
                    y = tf.reshape(y, [-1, _shape(x, 1), _shape(y, 1), _shape(y, 2)])
         
     | 
| 409 | 
         
            +
                return y
         
     | 
| 410 | 
         
            +
             
     | 
| 411 | 
         
            +
            def _shape(tf_expr, dim_idx):
         
     | 
| 412 | 
         
            +
                if tf_expr.shape.rank is not None:
         
     | 
| 413 | 
         
            +
                    dim = tf_expr.shape[dim_idx].value
         
     | 
| 414 | 
         
            +
                    if dim is not None:
         
     | 
| 415 | 
         
            +
                        return dim
         
     | 
| 416 | 
         
            +
                return tf.shape(tf_expr)[dim_idx]
         
     | 
| 417 | 
         
            +
             
     | 
| 418 | 
         
            +
            #----------------------------------------------------------------------------
         
     | 
    	
        dnnlib/tflib/optimizer.py
    ADDED
    
    | 
         @@ -0,0 +1,372 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            # Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
         
     | 
| 2 | 
         
            +
            #
         
     | 
| 3 | 
         
            +
            # NVIDIA CORPORATION and its licensors retain all intellectual property
         
     | 
| 4 | 
         
            +
            # and proprietary rights in and to this software, related documentation
         
     | 
| 5 | 
         
            +
            # and any modifications thereto.  Any use, reproduction, disclosure or
         
     | 
| 6 | 
         
            +
            # distribution of this software and related documentation without an express
         
     | 
| 7 | 
         
            +
            # license agreement from NVIDIA CORPORATION is strictly prohibited.
         
     | 
| 8 | 
         
            +
             
     | 
| 9 | 
         
            +
            """Helper wrapper for a Tensorflow optimizer."""
         
     | 
| 10 | 
         
            +
             
     | 
| 11 | 
         
            +
            import platform
         
     | 
| 12 | 
         
            +
            import numpy as np
         
     | 
| 13 | 
         
            +
            import tensorflow as tf
         
     | 
| 14 | 
         
            +
             
     | 
| 15 | 
         
            +
            from collections import OrderedDict
         
     | 
| 16 | 
         
            +
            from typing import List, Union
         
     | 
| 17 | 
         
            +
             
     | 
| 18 | 
         
            +
            from . import autosummary
         
     | 
| 19 | 
         
            +
            from . import tfutil
         
     | 
| 20 | 
         
            +
            from .. import util
         
     | 
| 21 | 
         
            +
             
     | 
| 22 | 
         
            +
            from .tfutil import TfExpression, TfExpressionEx
         
     | 
| 23 | 
         
            +
             
     | 
| 24 | 
         
            +
            _collective_ops_warning_printed = False
         
     | 
| 25 | 
         
            +
            _collective_ops_group_key       = 831766147
         
     | 
| 26 | 
         
            +
            _collective_ops_instance_key    = 436340067
         
     | 
| 27 | 
         
            +
             
     | 
| 28 | 
         
            +
            class Optimizer:
         
     | 
| 29 | 
         
            +
                """A Wrapper for tf.train.Optimizer.
         
     | 
| 30 | 
         
            +
             
     | 
| 31 | 
         
            +
                Automatically takes care of:
         
     | 
| 32 | 
         
            +
                - Gradient averaging for multi-GPU training.
         
     | 
| 33 | 
         
            +
                - Gradient accumulation for arbitrarily large minibatches.
         
     | 
| 34 | 
         
            +
                - Dynamic loss scaling and typecasts for FP16 training.
         
     | 
| 35 | 
         
            +
                - Ignoring corrupted gradients that contain NaNs/Infs.
         
     | 
| 36 | 
         
            +
                - Reporting statistics.
         
     | 
| 37 | 
         
            +
                - Well-chosen default settings.
         
     | 
| 38 | 
         
            +
                """
         
     | 
| 39 | 
         
            +
             
     | 
| 40 | 
         
            +
                def __init__(self,
         
     | 
| 41 | 
         
            +
                    name:                   str             = "Train",                  # Name string that will appear in TensorFlow graph.
         
     | 
| 42 | 
         
            +
                    tf_optimizer:           str             = "tf.train.AdamOptimizer", # Underlying optimizer class.
         
     | 
| 43 | 
         
            +
                    learning_rate:          TfExpressionEx  = 0.001,                    # Learning rate. Can vary over time.
         
     | 
| 44 | 
         
            +
                    minibatch_multiplier:   TfExpressionEx  = None,                     # Treat N consecutive minibatches as one by accumulating gradients.
         
     | 
| 45 | 
         
            +
                    share:                  "Optimizer"     = None,                     # Share internal state with a previously created optimizer?
         
     | 
| 46 | 
         
            +
                    use_loss_scaling:       bool            = False,                    # Enable dynamic loss scaling for robust mixed-precision training?
         
     | 
| 47 | 
         
            +
                    loss_scaling_init:      float           = 64.0,                     # Log2 of initial loss scaling factor.
         
     | 
| 48 | 
         
            +
                    loss_scaling_inc:       float           = 0.0005,                   # Log2 of per-minibatch loss scaling increment when there is no overflow.
         
     | 
| 49 | 
         
            +
                    loss_scaling_dec:       float           = 1.0,                      # Log2 of per-minibatch loss scaling decrement when there is an overflow.
         
     | 
| 50 | 
         
            +
                    report_mem_usage:       bool            = False,                    # Report fine-grained memory usage statistics in TensorBoard?
         
     | 
| 51 | 
         
            +
                    **kwargs):
         
     | 
| 52 | 
         
            +
             
     | 
| 53 | 
         
            +
                    # Public fields.
         
     | 
| 54 | 
         
            +
                    self.name                   = name
         
     | 
| 55 | 
         
            +
                    self.learning_rate          = learning_rate
         
     | 
| 56 | 
         
            +
                    self.minibatch_multiplier   = minibatch_multiplier
         
     | 
| 57 | 
         
            +
                    self.id                     = self.name.replace("/", ".")
         
     | 
| 58 | 
         
            +
                    self.scope                  = tf.get_default_graph().unique_name(self.id)
         
     | 
| 59 | 
         
            +
                    self.optimizer_class        = util.get_obj_by_name(tf_optimizer)
         
     | 
| 60 | 
         
            +
                    self.optimizer_kwargs       = dict(kwargs)
         
     | 
| 61 | 
         
            +
                    self.use_loss_scaling       = use_loss_scaling
         
     | 
| 62 | 
         
            +
                    self.loss_scaling_init      = loss_scaling_init
         
     | 
| 63 | 
         
            +
                    self.loss_scaling_inc       = loss_scaling_inc
         
     | 
| 64 | 
         
            +
                    self.loss_scaling_dec       = loss_scaling_dec
         
     | 
| 65 | 
         
            +
             
     | 
| 66 | 
         
            +
                    # Private fields.
         
     | 
| 67 | 
         
            +
                    self._updates_applied       = False
         
     | 
| 68 | 
         
            +
                    self._devices               = OrderedDict() # device_name => EasyDict()
         
     | 
| 69 | 
         
            +
                    self._shared_optimizers     = OrderedDict() # device_name => optimizer_class
         
     | 
| 70 | 
         
            +
                    self._gradient_shapes       = None          # [shape, ...]
         
     | 
| 71 | 
         
            +
                    self._report_mem_usage      = report_mem_usage
         
     | 
| 72 | 
         
            +
             
     | 
| 73 | 
         
            +
                    # Validate arguments.
         
     | 
| 74 | 
         
            +
                    assert callable(self.optimizer_class)
         
     | 
| 75 | 
         
            +
             
     | 
| 76 | 
         
            +
                    # Share internal state if requested.
         
     | 
| 77 | 
         
            +
                    if share is not None:
         
     | 
| 78 | 
         
            +
                        assert isinstance(share, Optimizer)
         
     | 
| 79 | 
         
            +
                        assert self.optimizer_class is share.optimizer_class
         
     | 
| 80 | 
         
            +
                        assert self.learning_rate is share.learning_rate
         
     | 
| 81 | 
         
            +
                        assert self.optimizer_kwargs == share.optimizer_kwargs
         
     | 
| 82 | 
         
            +
                        self._shared_optimizers = share._shared_optimizers # pylint: disable=protected-access
         
     | 
| 83 | 
         
            +
             
     | 
| 84 | 
         
            +
                def _get_device(self, device_name: str):
         
     | 
| 85 | 
         
            +
                    """Get internal state for the given TensorFlow device."""
         
     | 
| 86 | 
         
            +
                    tfutil.assert_tf_initialized()
         
     | 
| 87 | 
         
            +
                    if device_name in self._devices:
         
     | 
| 88 | 
         
            +
                        return self._devices[device_name]
         
     | 
| 89 | 
         
            +
             
     | 
| 90 | 
         
            +
                    # Initialize fields.
         
     | 
| 91 | 
         
            +
                    device = util.EasyDict()
         
     | 
| 92 | 
         
            +
                    device.name             = device_name
         
     | 
| 93 | 
         
            +
                    device.optimizer        = None          # Underlying optimizer:     optimizer_class
         
     | 
| 94 | 
         
            +
                    device.loss_scaling_var = None          # Log2 of loss scaling:     tf.Variable
         
     | 
| 95 | 
         
            +
                    device.grad_raw         = OrderedDict() # Raw gradients:            var => [grad, ...]
         
     | 
| 96 | 
         
            +
                    device.grad_clean       = OrderedDict() # Clean gradients:          var => grad
         
     | 
| 97 | 
         
            +
                    device.grad_acc_vars    = OrderedDict() # Accumulation sums:        var => tf.Variable
         
     | 
| 98 | 
         
            +
                    device.grad_acc_count   = None          # Accumulation counter:     tf.Variable
         
     | 
| 99 | 
         
            +
                    device.grad_acc         = OrderedDict() # Accumulated gradients:    var => grad
         
     | 
| 100 | 
         
            +
             
     | 
| 101 | 
         
            +
                    # Setup TensorFlow objects.
         
     | 
| 102 | 
         
            +
                    with tfutil.absolute_name_scope(self.scope + "/Devices"), tf.device(device_name), tf.control_dependencies(None):
         
     | 
| 103 | 
         
            +
                        if device_name not in self._shared_optimizers:
         
     | 
| 104 | 
         
            +
                            optimizer_name = self.scope.replace("/", "_") + "_opt%d" % len(self._shared_optimizers)
         
     | 
| 105 | 
         
            +
                            self._shared_optimizers[device_name] = self.optimizer_class(name=optimizer_name, learning_rate=self.learning_rate, **self.optimizer_kwargs)
         
     | 
| 106 | 
         
            +
                        device.optimizer = self._shared_optimizers[device_name]
         
     | 
| 107 | 
         
            +
                        if self.use_loss_scaling:
         
     | 
| 108 | 
         
            +
                            device.loss_scaling_var = tf.Variable(np.float32(self.loss_scaling_init), trainable=False, name="loss_scaling_var")
         
     | 
| 109 | 
         
            +
             
     | 
| 110 | 
         
            +
                    # Register device.
         
     | 
| 111 | 
         
            +
                    self._devices[device_name] = device
         
     | 
| 112 | 
         
            +
                    return device
         
     | 
| 113 | 
         
            +
             
     | 
| 114 | 
         
            +
    def register_gradients(self, loss: TfExpression, trainable_vars: Union[List, dict]) -> None:
        """Register the gradients of the given loss function with respect to the given variables.
        Intended to be called once per GPU.

        Args:
            loss: Scalar loss tensor; its `.device` selects which per-device state to use.
            trainable_vars: List of tf.Variables, or a dict (e.g. Network.trainables) whose
                values are the variables. Must all live on the same device as `loss`.

        Raises:
            AssertionError: If TF is uninitialized, apply_updates() was already called,
                the variables live on a different device than the loss, or the variable
                shapes differ from a previous call on another GPU.
        """
        tfutil.assert_tf_initialized()
        assert not self._updates_applied
        device = self._get_device(loss.device)

        # Validate trainables.
        if isinstance(trainable_vars, dict):
            trainable_vars = list(trainable_vars.values())  # allow passing in Network.trainables as vars
        assert isinstance(trainable_vars, list) and len(trainable_vars) >= 1
        assert all(tfutil.is_tf_expression(expr) for expr in trainable_vars + [loss])
        assert all(var.device == device.name for var in trainable_vars)

        # Validate shapes: all GPUs must register the same set of variable shapes,
        # recorded on the first call and checked on every subsequent one.
        if self._gradient_shapes is None:
            self._gradient_shapes = [var.shape.as_list() for var in trainable_vars]
        assert len(trainable_vars) == len(self._gradient_shapes)
        assert all(var.shape.as_list() == var_shape for var, var_shape in zip(trainable_vars, self._gradient_shapes))

        # Report memory usage if requested. Only done once (flag is cleared), and the
        # summary op is made a dependency of the gradient computation below so it
        # actually executes. tf.contrib may be unavailable => best-effort.
        deps = [loss]
        if self._report_mem_usage:
            self._report_mem_usage = False
            try:
                with tf.name_scope(self.id + '_mem'), tf.device(device.name), tf.control_dependencies([loss]):
                    deps.append(autosummary.autosummary(self.id + "/mem_usage_gb", tf.contrib.memory_stats.BytesInUse() / 2**30))
            except tf.errors.NotFoundError:
                pass

        # Compute gradients. The loss is cast to fp32 and (optionally) loss-scaled
        # before differentiation; the scale is undone later in apply_updates().
        with tf.name_scope(self.id + "_grad"), tf.device(device.name), tf.control_dependencies(deps):
            loss = self.apply_loss_scaling(tf.cast(loss, tf.float32))
            gate = tf.train.Optimizer.GATE_NONE  # disable gating to reduce memory usage
            grad_list = device.optimizer.compute_gradients(loss=loss, var_list=trainable_vars, gate_gradients=gate)

        # Register gradients: accumulate per-variable lists so several losses may
        # contribute gradients to the same variable on the same device.
        for grad, var in grad_list:
            if var not in device.grad_raw:
                device.grad_raw[var] = []
            device.grad_raw[var].append(grad)
     | 
| 156 | 
         
            +
    def apply_updates(self, allow_no_op: bool = False) -> tf.Operation:
        """Construct training op to update the registered variables based on their gradients.

        Pipeline per device: clean/sum raw gradients -> (multi-GPU) all-reduce across
        devices -> optionally accumulate over several minibatches -> apply via the
        underlying optimizer, skipping the step entirely if any gradient is non-finite.

        Args:
            allow_no_op: If True and no gradients were registered, return tf.no_op()
                instead of failing.

        Returns:
            A single grouped tf.Operation named "TrainingOp" under self.scope.
        """
        tfutil.assert_tf_initialized()
        assert not self._updates_applied
        self._updates_applied = True  # one-shot: further register_gradients() calls are rejected
        all_ops = []

        # Check for no-op.
        if allow_no_op and len(self._devices) == 0:
            with tfutil.absolute_name_scope(self.scope):
                return tf.no_op(name='TrainingOp')

        # Clean up gradients.
        for device_idx, device in enumerate(self._devices.values()):
            with tfutil.absolute_name_scope(self.scope + "/Clean%d" % device_idx), tf.device(device.name):
                for var, grad in device.grad_raw.items():

                    # Filter out disconnected gradients and convert to float32.
                    grad = [g for g in grad if g is not None]
                    grad = [tf.cast(g, tf.float32) for g in grad]

                    # Sum within the device.
                    if len(grad) == 0:
                        grad = tf.zeros(var.shape)  # No gradients => zero.
                    elif len(grad) == 1:
                        grad = grad[0]              # Single gradient => use as is.
                    else:
                        grad = tf.add_n(grad)       # Multiple gradients => sum.

                    # Scale as needed: average over registered losses and devices,
                    # over accumulation rounds, and undo any loss scaling applied
                    # in register_gradients().
                    scale = 1.0 / len(device.grad_raw[var]) / len(self._devices)
                    scale = tf.constant(scale, dtype=tf.float32, name="scale")
                    if self.minibatch_multiplier is not None:
                        scale /= tf.cast(self.minibatch_multiplier, tf.float32)
                    scale = self.undo_loss_scaling(scale)
                    device.grad_clean[var] = grad * scale

        # Sum gradients across devices.
        if len(self._devices) > 1:
            with tfutil.absolute_name_scope(self.scope + "/Broadcast"), tf.device(None):
                if platform.system() == "Windows":    # Windows => NCCL ops are not available.
                    self._broadcast_fallback()
                elif tf.VERSION.startswith("1.15."):  # TF 1.15 => NCCL ops are broken: https://github.com/tensorflow/tensorflow/issues/41539
                    self._broadcast_fallback()
                else:                                 # Otherwise => NCCL ops are safe to use.
                    self._broadcast_nccl()

        # Apply updates separately on each device.
        for device_idx, device in enumerate(self._devices.values()):
            with tfutil.absolute_name_scope(self.scope + "/Apply%d" % device_idx), tf.device(device.name):
                # pylint: disable=cell-var-from-loop

                # Accumulate gradients over time. With no multiplier, every call is
                # an "apply" round (acc_ok is constant True); otherwise gradients are
                # summed into persistent variables and only applied every N-th call.
                if self.minibatch_multiplier is None:
                    acc_ok = tf.constant(True, name='acc_ok')
                    device.grad_acc = OrderedDict(device.grad_clean)
                else:
                    # Create variables.
                    with tf.control_dependencies(None):
                        for var in device.grad_clean.keys():
                            device.grad_acc_vars[var] = tf.Variable(tf.zeros(var.shape), trainable=False, name="grad_acc_var")
                        device.grad_acc_count = tf.Variable(tf.zeros([]), trainable=False, name="grad_acc_count")

                    # Track counter: reset when the accumulation round completes,
                    # increment otherwise.
                    count_cur = device.grad_acc_count + 1.0
                    count_inc_op = lambda: tf.assign(device.grad_acc_count, count_cur)
                    count_reset_op = lambda: tf.assign(device.grad_acc_count, tf.zeros([]))
                    acc_ok = (count_cur >= tf.cast(self.minibatch_multiplier, tf.float32))
                    all_ops.append(tf.cond(acc_ok, count_reset_op, count_inc_op))

                    # Track gradients: grad_acc holds the running sum *including* the
                    # current gradient; the backing variable is reset on apply rounds.
                    for var, grad in device.grad_clean.items():
                        acc_var = device.grad_acc_vars[var]
                        acc_cur = acc_var + grad
                        device.grad_acc[var] = acc_cur
                        with tf.control_dependencies([acc_cur]):
                            acc_inc_op = lambda: tf.assign(acc_var, acc_cur)
                            acc_reset_op = lambda: tf.assign(acc_var, tf.zeros(var.shape))
                            all_ops.append(tf.cond(acc_ok, acc_reset_op, acc_inc_op))

                # No overflow => apply gradients. The step is skipped entirely if any
                # accumulated gradient contains NaN/Inf (fp16 overflow detection).
                all_ok = tf.reduce_all(tf.stack([acc_ok] + [tf.reduce_all(tf.is_finite(g)) for g in device.grad_acc.values()]))
                apply_op = lambda: device.optimizer.apply_gradients([(tf.cast(grad, var.dtype), var) for var, grad in device.grad_acc.items()])
                all_ops.append(tf.cond(all_ok, apply_op, tf.no_op))

                # Adjust loss scaling: grow slowly on success, shrink on overflow.
                # Only evaluated on apply rounds (acc_ok).
                if self.use_loss_scaling:
                    ls_inc_op = lambda: tf.assign_add(device.loss_scaling_var, self.loss_scaling_inc)
                    ls_dec_op = lambda: tf.assign_sub(device.loss_scaling_var, self.loss_scaling_dec)
                    ls_update_op = lambda: tf.group(tf.cond(all_ok, ls_inc_op, ls_dec_op))
                    all_ops.append(tf.cond(acc_ok, ls_update_op, tf.no_op))

                # Last device => report statistics.
                if device_idx == len(self._devices) - 1:
                    all_ops.append(autosummary.autosummary(self.id + "/learning_rate", tf.convert_to_tensor(self.learning_rate)))
                    all_ops.append(autosummary.autosummary(self.id + "/overflow_frequency", tf.where(all_ok, 0, 1), condition=acc_ok))
                    if self.use_loss_scaling:
                        all_ops.append(autosummary.autosummary(self.id + "/loss_scaling_log2", device.loss_scaling_var))

        # Initialize variables.
        self.reset_optimizer_state()
        if self.use_loss_scaling:
            tfutil.init_uninitialized_vars([device.loss_scaling_var for device in self._devices.values()])
        if self.minibatch_multiplier is not None:
            tfutil.run([var.initializer for device in self._devices.values() for var in list(device.grad_acc_vars.values()) + [device.grad_acc_count]])

        # Group everything into a single op.
        with tfutil.absolute_name_scope(self.scope):
            return tf.group(*all_ops, name="TrainingOp")
     | 
| 266 | 
         
            +
                def reset_optimizer_state(self) -> None:
         
     | 
| 267 | 
         
            +
                    """Reset internal state of the underlying optimizer."""
         
     | 
| 268 | 
         
            +
                    tfutil.assert_tf_initialized()
         
     | 
| 269 | 
         
            +
                    tfutil.run([var.initializer for device in self._devices.values() for var in device.optimizer.variables()])
         
     | 
| 270 | 
         
            +
             
     | 
| 271 | 
         
            +
                def get_loss_scaling_var(self, device: str) -> Union[tf.Variable, None]:
         
     | 
| 272 | 
         
            +
                    """Get or create variable representing log2 of the current dynamic loss scaling factor."""
         
     | 
| 273 | 
         
            +
                    return self._get_device(device).loss_scaling_var
         
     | 
| 274 | 
         
            +
             
     | 
| 275 | 
         
            +
                def apply_loss_scaling(self, value: TfExpression) -> TfExpression:
         
     | 
| 276 | 
         
            +
                    """Apply dynamic loss scaling for the given expression."""
         
     | 
| 277 | 
         
            +
                    assert tfutil.is_tf_expression(value)
         
     | 
| 278 | 
         
            +
                    if not self.use_loss_scaling:
         
     | 
| 279 | 
         
            +
                        return value
         
     | 
| 280 | 
         
            +
                    return value * tfutil.exp2(self.get_loss_scaling_var(value.device))
         
     | 
| 281 | 
         
            +
             
     | 
| 282 | 
         
            +
                def undo_loss_scaling(self, value: TfExpression) -> TfExpression:
         
     | 
| 283 | 
         
            +
                    """Undo the effect of dynamic loss scaling for the given expression."""
         
     | 
| 284 | 
         
            +
                    assert tfutil.is_tf_expression(value)
         
     | 
| 285 | 
         
            +
                    if not self.use_loss_scaling:
         
     | 
| 286 | 
         
            +
                        return value
         
     | 
| 287 | 
         
            +
                    return value * tfutil.exp2(-self.get_loss_scaling_var(value.device)) # pylint: disable=invalid-unary-operand-type
         
     | 
| 288 | 
         
            +
             
     | 
| 289 | 
         
            +
    def _broadcast_nccl(self):
        """Sum gradients across devices using NCCL ops (fast path).

        Walks the devices' grad_clean dicts in lockstep (they hold the same
        variables in the same order — enforced by the shape check in
        register_gradients) and replaces each per-device gradient with the
        all-reduced sum. Zero-element variables are skipped since NCCL has
        nothing to exchange for them.
        """
        from tensorflow.python.ops import nccl_ops # pylint: disable=no-name-in-module
        for all_vars in zip(*[device.grad_clean.keys() for device in self._devices.values()]):
            if any(x.shape.num_elements() > 0 for x in all_vars):
                all_grads = [device.grad_clean[var] for device, var in zip(self._devices.values(), all_vars)]
                all_grads = nccl_ops.all_sum(all_grads)
                for device, var, grad in zip(self._devices.values(), all_vars, all_grads):
                    device.grad_clean[var] = grad
     | 
| 299 | 
         
            +
    def _broadcast_fallback(self):
        """Sum gradients across devices using TensorFlow collective ops (slow fallback path).

        Used where NCCL is unavailable (Windows) or broken (TF 1.15). Per device,
        all gradients are flattened and concatenated into one vector, all-reduced
        with collective_ops, then sliced back into the original shapes. The module
        global _collective_ops_instance_key is bumped per call so each constructed
        reduction gets a unique collective instance.
        """
        from tensorflow.python.ops import collective_ops # pylint: disable=no-name-in-module
        global _collective_ops_warning_printed, _collective_ops_group_key, _collective_ops_instance_key
        # Nothing to exchange if every gradient is empty.
        if all(x.shape.num_elements() == 0 for device in self._devices.values() for x in device.grad_clean.values()):
            return
        if not _collective_ops_warning_printed:
            print("------------------------------------------------------------------------")
            print("WARNING: Using slow fallback implementation for inter-GPU communication.")
            print("Please use TensorFlow 1.14 on Linux for optimal training performance.")
            print("------------------------------------------------------------------------")
            _collective_ops_warning_printed = True
        for device in self._devices.values():
            with tf.device(device.name):
                # Flatten and concatenate all gradients into a single 1-D tensor.
                combo = [tf.reshape(x, [x.shape.num_elements()]) for x in device.grad_clean.values()]
                combo = tf.concat(combo, axis=0)
                combo = collective_ops.all_reduce(combo, merge_op='Add', final_op='Id',
                    group_size=len(self._devices), group_key=_collective_ops_group_key,
                    instance_key=_collective_ops_instance_key)
                # Slice the reduced vector back into per-variable gradients.
                cur_ofs = 0
                for var, grad_old in device.grad_clean.items():
                    grad_new = tf.reshape(combo[cur_ofs : cur_ofs + grad_old.shape.num_elements()], grad_old.shape)
                    cur_ofs += grad_old.shape.num_elements()
                    device.grad_clean[var] = grad_new
        _collective_ops_instance_key += 1
     | 
| 325 | 
         
            +
             
     | 
| 326 | 
         
            +
            class SimpleAdam:
         
     | 
| 327 | 
         
            +
                """Simplified version of tf.train.AdamOptimizer that behaves identically when used with dnnlib.tflib.Optimizer."""
         
     | 
| 328 | 
         
            +
             
     | 
| 329 | 
         
            +
                def __init__(self, name="Adam", learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8):
         
     | 
| 330 | 
         
            +
                    self.name = name
         
     | 
| 331 | 
         
            +
                    self.learning_rate = learning_rate
         
     | 
| 332 | 
         
            +
                    self.beta1 = beta1
         
     | 
| 333 | 
         
            +
                    self.beta2 = beta2
         
     | 
| 334 | 
         
            +
                    self.epsilon = epsilon
         
     | 
| 335 | 
         
            +
                    self.all_state_vars = []
         
     | 
| 336 | 
         
            +
             
     | 
| 337 | 
         
            +
                def variables(self):
         
     | 
| 338 | 
         
            +
                    return self.all_state_vars
         
     | 
| 339 | 
         
            +
             
     | 
| 340 | 
         
            +
                def compute_gradients(self, loss, var_list, gate_gradients=tf.train.Optimizer.GATE_NONE):
         
     | 
| 341 | 
         
            +
                    assert gate_gradients == tf.train.Optimizer.GATE_NONE
         
     | 
| 342 | 
         
            +
                    return list(zip(tf.gradients(loss, var_list), var_list))
         
     | 
| 343 | 
         
            +
             
     | 
| 344 | 
         
            +
                def apply_gradients(self, grads_and_vars):
         
     | 
| 345 | 
         
            +
                    with tf.name_scope(self.name):
         
     | 
| 346 | 
         
            +
                        state_vars = []
         
     | 
| 347 | 
         
            +
                        update_ops = []
         
     | 
| 348 | 
         
            +
             
     | 
| 349 | 
         
            +
                        # Adjust learning rate to deal with startup bias.
         
     | 
| 350 | 
         
            +
                        with tf.control_dependencies(None):
         
     | 
| 351 | 
         
            +
                            b1pow_var = tf.Variable(dtype=tf.float32, initial_value=1, trainable=False)
         
     | 
| 352 | 
         
            +
                            b2pow_var = tf.Variable(dtype=tf.float32, initial_value=1, trainable=False)
         
     | 
| 353 | 
         
            +
                            state_vars += [b1pow_var, b2pow_var]
         
     | 
| 354 | 
         
            +
                        b1pow_new = b1pow_var * self.beta1
         
     | 
| 355 | 
         
            +
                        b2pow_new = b2pow_var * self.beta2
         
     | 
| 356 | 
         
            +
                        update_ops += [tf.assign(b1pow_var, b1pow_new), tf.assign(b2pow_var, b2pow_new)]
         
     | 
| 357 | 
         
            +
                        lr_new = self.learning_rate * tf.sqrt(1 - b2pow_new) / (1 - b1pow_new)
         
     | 
| 358 | 
         
            +
             
     | 
| 359 | 
         
            +
                        # Construct ops to update each variable.
         
     | 
| 360 | 
         
            +
                        for grad, var in grads_and_vars:
         
     | 
| 361 | 
         
            +
                            with tf.control_dependencies(None):
         
     | 
| 362 | 
         
            +
                                m_var = tf.Variable(dtype=tf.float32, initial_value=tf.zeros_like(var), trainable=False)
         
     | 
| 363 | 
         
            +
                                v_var = tf.Variable(dtype=tf.float32, initial_value=tf.zeros_like(var), trainable=False)
         
     | 
| 364 | 
         
            +
                                state_vars += [m_var, v_var]
         
     | 
| 365 | 
         
            +
                            m_new = self.beta1 * m_var + (1 - self.beta1) * grad
         
     | 
| 366 | 
         
            +
                            v_new = self.beta2 * v_var + (1 - self.beta2) * tf.square(grad)
         
     | 
| 367 | 
         
            +
                            var_delta = lr_new * m_new / (tf.sqrt(v_new) + self.epsilon)
         
     | 
| 368 | 
         
            +
                            update_ops += [tf.assign(m_var, m_new), tf.assign(v_var, v_new), tf.assign_sub(var, var_delta)]
         
     | 
| 369 | 
         
            +
             
     | 
| 370 | 
         
            +
                        # Group everything together.
         
     | 
| 371 | 
         
            +
                        self.all_state_vars += state_vars
         
     | 
| 372 | 
         
            +
                        return tf.group(*update_ops)
         
     | 
    	
        dnnlib/tflib/tfutil.py
    ADDED
    
    | 
         @@ -0,0 +1,264 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            # Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
         
     | 
| 2 | 
         
            +
            #
         
     | 
| 3 | 
         
            +
            # NVIDIA CORPORATION and its licensors retain all intellectual property
         
     | 
| 4 | 
         
            +
            # and proprietary rights in and to this software, related documentation
         
     | 
| 5 | 
         
            +
            # and any modifications thereto.  Any use, reproduction, disclosure or
         
     | 
| 6 | 
         
            +
            # distribution of this software and related documentation without an express
         
     | 
| 7 | 
         
            +
            # license agreement from NVIDIA CORPORATION is strictly prohibited.
         
     | 
| 8 | 
         
            +
             
     | 
| 9 | 
         
            +
            """Miscellaneous helper utils for Tensorflow."""
         
     | 
| 10 | 
         
            +
             
     | 
| 11 | 
         
            +
            import os
         
     | 
| 12 | 
         
            +
            import numpy as np
         
     | 
| 13 | 
         
            +
            import tensorflow as tf
         
     | 
| 14 | 
         
            +
             
     | 
| 15 | 
         
            +
            # Silence deprecation warnings from TensorFlow 1.13 onwards
         
     | 
| 16 | 
         
            +
            import logging
         
     | 
| 17 | 
         
            +
            logging.getLogger('tensorflow').setLevel(logging.ERROR)
         
     | 
| 18 | 
         
            +
            import tensorflow.contrib   # requires TensorFlow 1.x!
         
     | 
| 19 | 
         
            +
            tf.contrib = tensorflow.contrib
         
     | 
| 20 | 
         
            +
             
     | 
| 21 | 
         
            +
            from typing import Any, Iterable, List, Union
         
     | 
| 22 | 
         
            +
             
     | 
| 23 | 
         
            +
# Type aliases used throughout dnnlib.tflib to annotate TF-related helpers.
TfExpression = Union[tf.Tensor, tf.Variable, tf.Operation]
"""A type that represents a valid Tensorflow expression."""

TfExpressionEx = Union[TfExpression, int, float, np.ndarray]
"""A type that can be converted to a valid Tensorflow expression."""
         
     | 
| 28 | 
         
            +
             
     | 
| 29 | 
         
            +
             
     | 
| 30 | 
         
            +
            def run(*args, **kwargs) -> Any:
         
     | 
| 31 | 
         
            +
                """Run the specified ops in the default session."""
         
     | 
| 32 | 
         
            +
                assert_tf_initialized()
         
     | 
| 33 | 
         
            +
                return tf.get_default_session().run(*args, **kwargs)
         
     | 
| 34 | 
         
            +
             
     | 
| 35 | 
         
            +
             
     | 
| 36 | 
         
            +
            def is_tf_expression(x: Any) -> bool:
         
     | 
| 37 | 
         
            +
                """Check whether the input is a valid Tensorflow expression, i.e., Tensorflow Tensor, Variable, or Operation."""
         
     | 
| 38 | 
         
            +
                return isinstance(x, (tf.Tensor, tf.Variable, tf.Operation))
         
     | 
| 39 | 
         
            +
             
     | 
| 40 | 
         
            +
             
     | 
| 41 | 
         
            +
            def shape_to_list(shape: Iterable[tf.Dimension]) -> List[Union[int, None]]:
         
     | 
| 42 | 
         
            +
                """Convert a Tensorflow shape to a list of ints. Retained for backwards compatibility -- use TensorShape.as_list() in new code."""
         
     | 
| 43 | 
         
            +
                return [dim.value for dim in shape]
         
     | 
| 44 | 
         
            +
             
     | 
| 45 | 
         
            +
             
     | 
| 46 | 
         
            +
            def flatten(x: TfExpressionEx) -> TfExpression:
         
     | 
| 47 | 
         
            +
                """Shortcut function for flattening a tensor."""
         
     | 
| 48 | 
         
            +
                with tf.name_scope("Flatten"):
         
     | 
| 49 | 
         
            +
                    return tf.reshape(x, [-1])
         
     | 
| 50 | 
         
            +
             
     | 
| 51 | 
         
            +
             
     | 
| 52 | 
         
            +
            def log2(x: TfExpressionEx) -> TfExpression:
         
     | 
| 53 | 
         
            +
                """Logarithm in base 2."""
         
     | 
| 54 | 
         
            +
                with tf.name_scope("Log2"):
         
     | 
| 55 | 
         
            +
                    return tf.log(x) * np.float32(1.0 / np.log(2.0))
         
     | 
| 56 | 
         
            +
             
     | 
| 57 | 
         
            +
             
     | 
| 58 | 
         
            +
            def exp2(x: TfExpressionEx) -> TfExpression:
         
     | 
| 59 | 
         
            +
                """Exponent in base 2."""
         
     | 
| 60 | 
         
            +
                with tf.name_scope("Exp2"):
         
     | 
| 61 | 
         
            +
                    return tf.exp(x * np.float32(np.log(2.0)))
         
     | 
| 62 | 
         
            +
             
     | 
| 63 | 
         
            +
             
     | 
| 64 | 
         
            +
def erfinv(y: TfExpressionEx) -> TfExpression:
    """Inverse of the error function."""
    # Imported lazily: special_math lives in a private TensorFlow namespace,
    # which pylint cannot resolve -- hence the disable below.
    # pylint: disable=no-name-in-module
    from tensorflow.python.ops.distributions import special_math
    return special_math.erfinv(y)
         
     | 
| 69 | 
         
            +
             
     | 
| 70 | 
         
            +
             
     | 
| 71 | 
         
            +
            def lerp(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpressionEx:
         
     | 
| 72 | 
         
            +
                """Linear interpolation."""
         
     | 
| 73 | 
         
            +
                with tf.name_scope("Lerp"):
         
     | 
| 74 | 
         
            +
                    return a + (b - a) * t
         
     | 
| 75 | 
         
            +
             
     | 
| 76 | 
         
            +
             
     | 
| 77 | 
         
            +
            def lerp_clip(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpression:
         
     | 
| 78 | 
         
            +
                """Linear interpolation with clip."""
         
     | 
| 79 | 
         
            +
                with tf.name_scope("LerpClip"):
         
     | 
| 80 | 
         
            +
                    return a + (b - a) * tf.clip_by_value(t, 0.0, 1.0)
         
     | 
| 81 | 
         
            +
             
     | 
| 82 | 
         
            +
             
     | 
| 83 | 
         
            +
            def absolute_name_scope(scope: str) -> tf.name_scope:
         
     | 
| 84 | 
         
            +
                """Forcefully enter the specified name scope, ignoring any surrounding scopes."""
         
     | 
| 85 | 
         
            +
                return tf.name_scope(scope + "/")
         
     | 
| 86 | 
         
            +
             
     | 
| 87 | 
         
            +
             
     | 
| 88 | 
         
            +
def absolute_variable_scope(scope: str, **kwargs) -> tf.variable_scope:
    """Forcefully enter the specified variable scope, ignoring any surrounding scopes."""
    # Passing a VariableScope object (rather than a str) makes tf.variable_scope
    # treat the name as absolute; auxiliary_name_scope=False leaves the
    # surrounding name scope untouched.
    return tf.variable_scope(tf.VariableScope(name=scope, **kwargs), auxiliary_name_scope=False)
         
     | 
| 91 | 
         
            +
             
     | 
| 92 | 
         
            +
             
     | 
| 93 | 
         
            +
            def _sanitize_tf_config(config_dict: dict = None) -> dict:
         
     | 
| 94 | 
         
            +
                # Defaults.
         
     | 
| 95 | 
         
            +
                cfg = dict()
         
     | 
| 96 | 
         
            +
                cfg["rnd.np_random_seed"]               = None      # Random seed for NumPy. None = keep as is.
         
     | 
| 97 | 
         
            +
                cfg["rnd.tf_random_seed"]               = "auto"    # Random seed for TensorFlow. 'auto' = derive from NumPy random state. None = keep as is.
         
     | 
| 98 | 
         
            +
                cfg["env.TF_CPP_MIN_LOG_LEVEL"]         = "1"       # 0 = Print all available debug info from TensorFlow. 1 = Print warnings and errors, but disable debug info.
         
     | 
| 99 | 
         
            +
                cfg["env.HDF5_USE_FILE_LOCKING"]        = "FALSE"   # Disable HDF5 file locking to avoid concurrency issues with network shares.
         
     | 
| 100 | 
         
            +
                cfg["graph_options.place_pruned_graph"] = True      # False = Check that all ops are available on the designated device. True = Skip the check for ops that are not used.
         
     | 
| 101 | 
         
            +
                cfg["gpu_options.allow_growth"]         = True      # False = Allocate all GPU memory at the beginning. True = Allocate only as much GPU memory as needed.
         
     | 
| 102 | 
         
            +
             
     | 
| 103 | 
         
            +
                # Remove defaults for environment variables that are already set.
         
     | 
| 104 | 
         
            +
                for key in list(cfg):
         
     | 
| 105 | 
         
            +
                    fields = key.split(".")
         
     | 
| 106 | 
         
            +
                    if fields[0] == "env":
         
     | 
| 107 | 
         
            +
                        assert len(fields) == 2
         
     | 
| 108 | 
         
            +
                        if fields[1] in os.environ:
         
     | 
| 109 | 
         
            +
                            del cfg[key]
         
     | 
| 110 | 
         
            +
             
     | 
| 111 | 
         
            +
                # User overrides.
         
     | 
| 112 | 
         
            +
                if config_dict is not None:
         
     | 
| 113 | 
         
            +
                    cfg.update(config_dict)
         
     | 
| 114 | 
         
            +
                return cfg
         
     | 
| 115 | 
         
            +
             
     | 
| 116 | 
         
            +
             
     | 
| 117 | 
         
            +
            def init_tf(config_dict: dict = None) -> None:
         
     | 
| 118 | 
         
            +
                """Initialize TensorFlow session using good default settings."""
         
     | 
| 119 | 
         
            +
                # Skip if already initialized.
         
     | 
| 120 | 
         
            +
                if tf.get_default_session() is not None:
         
     | 
| 121 | 
         
            +
                    return
         
     | 
| 122 | 
         
            +
             
     | 
| 123 | 
         
            +
                # Setup config dict and random seeds.
         
     | 
| 124 | 
         
            +
                cfg = _sanitize_tf_config(config_dict)
         
     | 
| 125 | 
         
            +
                np_random_seed = cfg["rnd.np_random_seed"]
         
     | 
| 126 | 
         
            +
                if np_random_seed is not None:
         
     | 
| 127 | 
         
            +
                    np.random.seed(np_random_seed)
         
     | 
| 128 | 
         
            +
                tf_random_seed = cfg["rnd.tf_random_seed"]
         
     | 
| 129 | 
         
            +
                if tf_random_seed == "auto":
         
     | 
| 130 | 
         
            +
                    tf_random_seed = np.random.randint(1 << 31)
         
     | 
| 131 | 
         
            +
                if tf_random_seed is not None:
         
     | 
| 132 | 
         
            +
                    tf.set_random_seed(tf_random_seed)
         
     | 
| 133 | 
         
            +
             
     | 
| 134 | 
         
            +
                # Setup environment variables.
         
     | 
| 135 | 
         
            +
                for key, value in cfg.items():
         
     | 
| 136 | 
         
            +
                    fields = key.split(".")
         
     | 
| 137 | 
         
            +
                    if fields[0] == "env":
         
     | 
| 138 | 
         
            +
                        assert len(fields) == 2
         
     | 
| 139 | 
         
            +
                        os.environ[fields[1]] = str(value)
         
     | 
| 140 | 
         
            +
             
     | 
| 141 | 
         
            +
                # Create default TensorFlow session.
         
     | 
| 142 | 
         
            +
                create_session(cfg, force_as_default=True)
         
     | 
| 143 | 
         
            +
             
     | 
| 144 | 
         
            +
             
     | 
| 145 | 
         
            +
            def assert_tf_initialized():
         
     | 
| 146 | 
         
            +
                """Check that TensorFlow session has been initialized."""
         
     | 
| 147 | 
         
            +
                if tf.get_default_session() is None:
         
     | 
| 148 | 
         
            +
                    raise RuntimeError("No default TensorFlow session found. Please call dnnlib.tflib.init_tf().")
         
     | 
| 149 | 
         
            +
             
     | 
| 150 | 
         
            +
             
     | 
| 151 | 
         
            +
def create_session(config_dict: dict = None, force_as_default: bool = False) -> tf.Session:
    """Create tf.Session based on config dict.

    Config keys of the form "a.b" (other than "rnd.*" and "env.*") are applied
    to the corresponding nested field of tf.ConfigProto, e.g.
    "gpu_options.allow_growth" -> config_proto.gpu_options.allow_growth.

    Args:
        config_dict: Optional overrides merged on top of _sanitize_tf_config() defaults.
        force_as_default: If True, permanently install the new session as the
            default session of the calling thread.
    """
    # Setup TensorFlow config proto.
    cfg = _sanitize_tf_config(config_dict)
    config_proto = tf.ConfigProto()
    for key, value in cfg.items():
        fields = key.split(".")
        if fields[0] not in ["rnd", "env"]:
            # Walk down the nested proto fields named by the dotted key.
            obj = config_proto
            for field in fields[:-1]:
                obj = getattr(obj, field)
            setattr(obj, fields[-1], value)

    # Create session.
    session = tf.Session(config=config_proto)
    if force_as_default:
        # pylint: disable=protected-access
        # Enter the session's as_default() context manager once and never exit
        # it, making this the default session indefinitely. enforce_nesting is
        # disabled so the non-nested __enter__ is allowed.
        session._default_session = session.as_default()
        session._default_session.enforce_nesting = False
        session._default_session.__enter__()
    return session
         
     | 
| 172 | 
         
            +
             
     | 
| 173 | 
         
            +
             
     | 
| 174 | 
         
            +
def init_uninitialized_vars(target_vars: List[tf.Variable] = None) -> None:
    """Initialize all tf.Variables that have not already been initialized.

    Equivalent to the following, but more efficient and does not bloat the tf graph:
    tf.variables_initializer(tf.report_uninitialized_variables()).run()
    """
    assert_tf_initialized()
    if target_vars is None:
        target_vars = tf.global_variables()

    test_vars = []  # variables whose init status still needs to be queried
    test_ops = []   # corresponding tf.is_variable_initialized() ops

    with tf.control_dependencies(None):  # ignore surrounding control_dependencies
        for var in target_vars:
            assert is_tf_expression(var)

            try:
                # If an IsVariableInitialized op already exists for this variable,
                # it was handled by a previous call -- skip it.
                tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/IsVariableInitialized:0"))
            except KeyError:
                # Op does not exist => variable may be uninitialized.
                test_vars.append(var)

                # Create the check op next to the variable in the graph namespace.
                with absolute_name_scope(var.name.split(":")[0]):
                    test_ops.append(tf.is_variable_initialized(var))

    # Run all checks in a single session call, then initialize only the
    # variables that reported as uninitialized.
    init_vars = [var for var, inited in zip(test_vars, run(test_ops)) if not inited]
    run([var.initializer for var in init_vars])
         
     | 
| 202 | 
         
            +
             
     | 
| 203 | 
         
            +
             
     | 
| 204 | 
         
            +
def set_vars(var_to_value_dict: dict) -> None:
    """Set the values of given tf.Variables.

    Equivalent to the following, but more efficient and does not bloat the tf graph:
    tflib.run([tf.assign(var, value) for var, value in var_to_value_dict.items()])
    """
    assert_tf_initialized()
    ops = []
    feed_dict = {}

    for var, value in var_to_value_dict.items():
        assert is_tf_expression(var)

        try:
            setter = tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/setter:0"))  # look for existing op
        except KeyError:
            with absolute_name_scope(var.name.split(":")[0]):
                with tf.control_dependencies(None):  # ignore surrounding control_dependencies
                    setter = tf.assign(var, tf.placeholder(var.dtype, var.shape, "new_value"), name="setter")  # create new setter

        # Feed the new value through the setter's placeholder input.
        ops.append(setter)
        feed_dict[setter.op.inputs[1]] = value

    # Execute all assignments in a single session call.
    run(ops, feed_dict)
         
     | 
| 228 | 
         
            +
             
     | 
| 229 | 
         
            +
             
     | 
| 230 | 
         
            +
            def create_var_with_large_initial_value(initial_value: np.ndarray, *args, **kwargs):
         
     | 
| 231 | 
         
            +
                """Create tf.Variable with large initial value without bloating the tf graph."""
         
     | 
| 232 | 
         
            +
                assert_tf_initialized()
         
     | 
| 233 | 
         
            +
                assert isinstance(initial_value, np.ndarray)
         
     | 
| 234 | 
         
            +
                zeros = tf.zeros(initial_value.shape, initial_value.dtype)
         
     | 
| 235 | 
         
            +
                var = tf.Variable(zeros, *args, **kwargs)
         
     | 
| 236 | 
         
            +
                set_vars({var: initial_value})
         
     | 
| 237 | 
         
            +
                return var
         
     | 
| 238 | 
         
            +
             
     | 
| 239 | 
         
            +
             
     | 
| 240 | 
         
            +
            def convert_images_from_uint8(images, drange=[-1,1], nhwc_to_nchw=False):
         
     | 
| 241 | 
         
            +
                """Convert a minibatch of images from uint8 to float32 with configurable dynamic range.
         
     | 
| 242 | 
         
            +
                Can be used as an input transformation for Network.run().
         
     | 
| 243 | 
         
            +
                """
         
     | 
| 244 | 
         
            +
                images = tf.cast(images, tf.float32)
         
     | 
| 245 | 
         
            +
                if nhwc_to_nchw:
         
     | 
| 246 | 
         
            +
                    images = tf.transpose(images, [0, 3, 1, 2])
         
     | 
| 247 | 
         
            +
                return images * ((drange[1] - drange[0]) / 255) + drange[0]
         
     | 
| 248 | 
         
            +
             
     | 
| 249 | 
         
            +
             
     | 
| 250 | 
         
            +
            def convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False, shrink=1, uint8_cast=True):
         
     | 
| 251 | 
         
            +
                """Convert a minibatch of images from float32 to uint8 with configurable dynamic range.
         
     | 
| 252 | 
         
            +
                Can be used as an output transformation for Network.run().
         
     | 
| 253 | 
         
            +
                """
         
     | 
| 254 | 
         
            +
                images = tf.cast(images, tf.float32)
         
     | 
| 255 | 
         
            +
                if shrink > 1:
         
     | 
| 256 | 
         
            +
                    ksize = [1, 1, shrink, shrink]
         
     | 
| 257 | 
         
            +
                    images = tf.nn.avg_pool(images, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW")
         
     | 
| 258 | 
         
            +
                if nchw_to_nhwc:
         
     | 
| 259 | 
         
            +
                    images = tf.transpose(images, [0, 2, 3, 1])
         
     | 
| 260 | 
         
            +
                scale = 255 / (drange[1] - drange[0])
         
     | 
| 261 | 
         
            +
                images = images * scale + (0.5 - drange[0] * scale)
         
     | 
| 262 | 
         
            +
                if uint8_cast:
         
     | 
| 263 | 
         
            +
                    images = tf.saturate_cast(images, tf.uint8)
         
     | 
| 264 | 
         
            +
                return images
         
     | 
    	
        dnnlib/util.py
    ADDED
    
    | 
         @@ -0,0 +1,472 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            # Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
         
     | 
| 2 | 
         
            +
            #
         
     | 
| 3 | 
         
            +
            # NVIDIA CORPORATION and its licensors retain all intellectual property
         
     | 
| 4 | 
         
            +
            # and proprietary rights in and to this software, related documentation
         
     | 
| 5 | 
         
            +
            # and any modifications thereto.  Any use, reproduction, disclosure or
         
     | 
| 6 | 
         
            +
            # distribution of this software and related documentation without an express
         
     | 
| 7 | 
         
            +
            # license agreement from NVIDIA CORPORATION is strictly prohibited.
         
     | 
| 8 | 
         
            +
             
     | 
| 9 | 
         
            +
            """Miscellaneous utility classes and functions."""
         
     | 
| 10 | 
         
            +
             
     | 
| 11 | 
         
            +
            import ctypes
         
     | 
| 12 | 
         
            +
            import fnmatch
         
     | 
| 13 | 
         
            +
            import importlib
         
     | 
| 14 | 
         
            +
            import inspect
         
     | 
| 15 | 
         
            +
            import numpy as np
         
     | 
| 16 | 
         
            +
            import os
         
     | 
| 17 | 
         
            +
            import shutil
         
     | 
| 18 | 
         
            +
            import sys
         
     | 
| 19 | 
         
            +
            import types
         
     | 
| 20 | 
         
            +
            import io
         
     | 
| 21 | 
         
            +
            import pickle
         
     | 
| 22 | 
         
            +
            import re
         
     | 
| 23 | 
         
            +
            import requests
         
     | 
| 24 | 
         
            +
            import html
         
     | 
| 25 | 
         
            +
            import hashlib
         
     | 
| 26 | 
         
            +
            import glob
         
     | 
| 27 | 
         
            +
            import tempfile
         
     | 
| 28 | 
         
            +
            import urllib
         
     | 
| 29 | 
         
            +
            import urllib.request
         
     | 
| 30 | 
         
            +
            import uuid
         
     | 
| 31 | 
         
            +
             
     | 
| 32 | 
         
            +
            from distutils.util import strtobool
         
     | 
| 33 | 
         
            +
            from typing import Any, List, Tuple, Union
         
     | 
| 34 | 
         
            +
             
     | 
| 35 | 
         
            +
             
     | 
| 36 | 
         
            +
            # Util classes
         
     | 
| 37 | 
         
            +
            # ------------------------------------------------------------------------------------------
         
     | 
| 38 | 
         
            +
             
     | 
| 39 | 
         
            +
             
     | 
class EasyDict(dict):
    """Convenience class that behaves like a dict but allows access with the attribute syntax."""

    def __getattr__(self, name: str) -> Any:
        # Attribute reads fall through to dict lookup; unknown keys surface
        # as AttributeError so hasattr()/getattr() semantics stay correct.
        if name in self:
            return self[name]
        raise AttributeError(name)

    def __setattr__(self, name: str, value: Any) -> None:
        # Attribute writes always go into the underlying dict.
        self[name] = value

    def __delattr__(self, name: str) -> None:
        # Deleting an attribute removes the corresponding key.
        del self[name]
| 54 | 
         
            +
             
     | 
| 55 | 
         
            +
             
     | 
class Logger(object):
    """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""

    def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
        # Optional mirror file; stays None when no file_name is given.
        self.file = None

        if file_name is not None:
            self.file = open(file_name, file_mode)

        self.should_flush = should_flush
        # Save the real streams so close() can restore them later.
        self.stdout = sys.stdout
        self.stderr = sys.stderr

        # Install this object as both stdout and stderr; write() forwards
        # everything (including stderr traffic) to the saved stdout.
        sys.stdout = self
        sys.stderr = self

    def __enter__(self) -> "Logger":
        # Context-manager support: `with Logger(...) as log:` restores the
        # streams via close() on exit.
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        self.close()

    def write(self, text: str) -> None:
        """Write text to stdout (and a file) and optionally flush."""
        if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
            return

        if self.file is not None:
            self.file.write(text)

        self.stdout.write(text)

        if self.should_flush:
            self.flush()

    def flush(self) -> None:
        """Flush written text to both stdout and a file, if open."""
        if self.file is not None:
            self.file.flush()

        self.stdout.flush()

    def close(self) -> None:
        """Flush, close possible files, and remove stdout/stderr mirroring."""
        self.flush()

        # if using multiple loggers, prevent closing in wrong order
        if sys.stdout is self:
            sys.stdout = self.stdout
        if sys.stderr is self:
            sys.stderr = self.stderr

        if self.file is not None:
            self.file.close()
| 110 | 
         
            +
             
     | 
| 111 | 
         
            +
             
     | 
# Cache directories
# ------------------------------------------------------------------------------------------

# Explicit override set via set_cache_dir(); None means "use the defaults".
_dnnlib_cache_dir = None

def set_cache_dir(path: str) -> None:
    """Override the cache directory used by make_cache_dir_path()."""
    global _dnnlib_cache_dir
    _dnnlib_cache_dir = path

def make_cache_dir_path(*paths: str) -> str:
    """Return a path inside the dnnlib cache directory.

    The base directory is chosen in priority order: explicit override,
    DNNLIB_CACHE_DIR env var, '.cache/dnnlib' under HOME or USERPROFILE,
    and finally the system temp directory.
    """
    if _dnnlib_cache_dir is not None:
        base = _dnnlib_cache_dir
    elif 'DNNLIB_CACHE_DIR' in os.environ:
        base = os.environ['DNNLIB_CACHE_DIR']
    elif 'HOME' in os.environ:
        base = os.path.join(os.environ['HOME'], '.cache', 'dnnlib')
    elif 'USERPROFILE' in os.environ:
        base = os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib')
    else:
        base = os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib')
    return os.path.join(base, *paths)
| 131 | 
         
            +
             
     | 
| 132 | 
         
            +
            # Small util functions
         
     | 
| 133 | 
         
            +
            # ------------------------------------------------------------------------------------------
         
     | 
| 134 | 
         
            +
             
     | 
| 135 | 
         
            +
             
     | 
def format_time(seconds: Union[int, float]) -> str:
    """Convert the seconds to human readable string with days, hours, minutes and seconds."""
    # Round to the nearest whole second before bucketing.
    total = int(np.rint(seconds))

    # Pick the coarsest applicable unit, largest first.
    if total >= 24 * 60 * 60:
        return "{0}d {1:02}h {2:02}m".format(total // (24 * 60 * 60), (total // (60 * 60)) % 24, (total // 60) % 60)
    if total >= 60 * 60:
        return "{0}h {1:02}m {2:02}s".format(total // (60 * 60), (total // 60) % 60, total % 60)
    if total >= 60:
        return "{0}m {1:02}s".format(total // 60, total % 60)
    return "{0}s".format(total)
| 148 | 
         
            +
             
     | 
| 149 | 
         
            +
             
     | 
# Accepted answers, mirroring distutils.util.strtobool() which this code
# previously relied on (distutils is deprecated per PEP 632 and removed in
# Python 3.12).
_TRUE_ANSWERS = {"y", "yes", "t", "true", "on", "1"}
_FALSE_ANSWERS = {"n", "no", "f", "false", "off", "0"}

def ask_yes_no(question: str) -> bool:
    """Ask the user the question until the user inputs a valid answer.

    Returns True for y/yes/t/true/on/1 and False for n/no/f/false/off/0
    (case-insensitive); any other input re-asks the question.
    """
    while True:
        print("{0} [y/n]".format(question))
        answer = input().lower()
        if answer in _TRUE_ANSWERS:
            return True
        if answer in _FALSE_ANSWERS:
            return False
| 158 | 
         
            +
             
     | 
| 159 | 
         
            +
             
     | 
def tuple_product(t: Tuple) -> Any:
    """Calculate the product of the tuple elements.

    Returns 1 for an empty tuple.
    """
    product = 1
    for element in t:
        product = product * element
    return product
| 168 | 
         
            +
             
     | 
| 169 | 
         
            +
             
     | 
# Lookup table mapping dtype name strings to ctypes types of the same width.
_str_to_ctype = {
    "uint8": ctypes.c_ubyte,
    "uint16": ctypes.c_uint16,
    "uint32": ctypes.c_uint32,
    "uint64": ctypes.c_uint64,
    "int8": ctypes.c_byte,
    "int16": ctypes.c_int16,
    "int32": ctypes.c_int32,
    "int64": ctypes.c_int64,
    "float32": ctypes.c_float,
    "float64": ctypes.c_double
}


def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
    """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
    # Extract a plain type-name string from whatever was passed in.
    if isinstance(type_obj, str):
        type_str = type_obj
    elif hasattr(type_obj, "__name__"):
        type_str = type_obj.__name__
    elif hasattr(type_obj, "name"):
        type_str = type_obj.name
    else:
        raise RuntimeError("Cannot infer type name from input")

    assert type_str in _str_to_ctype

    dtype = np.dtype(type_str)
    ctype = _str_to_ctype[type_str]

    # Sanity check: numpy and ctypes must agree on the element size.
    assert dtype.itemsize == ctypes.sizeof(ctype)

    return dtype, ctype
| 205 | 
         
            +
             
     | 
| 206 | 
         
            +
             
     | 
def is_pickleable(obj: Any) -> bool:
    """Return True if obj can be serialized with pickle, False otherwise."""
    try:
        with io.BytesIO() as stream:
            pickle.dump(obj, stream)
        return True
    except Exception:
        # Narrowed from a bare `except:` so that KeyboardInterrupt/SystemExit
        # still propagate; pickling failures raise Exception subclasses
        # (PicklingError, TypeError, AttributeError, ...).
        return False
| 214 | 
         
            +
             
     | 
| 215 | 
         
            +
             
     | 
| 216 | 
         
            +
            # Functionality to import modules/objects by name, and call functions by name
         
     | 
| 217 | 
         
            +
            # ------------------------------------------------------------------------------------------
         
     | 
| 218 | 
         
            +
             
     | 
def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
    """Searches for the underlying module behind the name to some python object.
    Returns the module and the object name (original name with module part removed)."""

    # allow convenience shorthands, substitute them by full names
    # NOTE: dots are escaped; the previous unescaped "^np." also matched
    # names like "npx.foo" because "." matches any character.
    obj_name = re.sub(r"^np\.", "numpy.", obj_name)
    obj_name = re.sub(r"^tf\.", "tensorflow.", obj_name)

    # list alternatives for (module_name, local_obj_name), longest module prefix first
    parts = obj_name.split(".")
    name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]

    # try each alternative in turn
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name) # may raise ImportError
            get_obj_from_module(module, local_obj_name) # may raise AttributeError
            return module, local_obj_name
        except Exception:
            # best-effort probing: any failure just means "try the next split"
            # (narrowed from a bare except so KeyboardInterrupt propagates)
            pass

    # maybe some of the modules themselves contain errors?
    for module_name, _local_obj_name in name_pairs:
        try:
            importlib.import_module(module_name) # may raise ImportError
        except ImportError:
            # re-raise only errors caused by the module body itself,
            # not plain "module does not exist" failures
            if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
                raise

    # maybe the requested attribute is missing?
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name) # may raise ImportError
            get_obj_from_module(module, local_obj_name) # may raise AttributeError, surfacing the missing attribute
        except ImportError:
            pass

    # we are out of luck, but we have no idea why
    raise ImportError(obj_name)
| 258 | 
         
            +
             
     | 
| 259 | 
         
            +
             
     | 
def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
    """Traverses the object name and returns the last (rightmost) python object."""
    # Empty name means the module itself was requested.
    if obj_name == '':
        return module
    target = module
    for attr_name in obj_name.split("."):
        target = getattr(target, attr_name)
    return target
| 268 | 
         
            +
             
     | 
| 269 | 
         
            +
             
     | 
def get_obj_by_name(name: str) -> Any:
    """Finds the python object with the given name."""
    # Split the dotted name into its module and the remainder, then walk
    # the remainder down from the module.
    module, local_name = get_module_from_obj_name(name)
    return get_obj_from_module(module, local_name)
| 274 | 
         
            +
             
     | 
| 275 | 
         
            +
             
     | 
def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
    """Finds the python object with the given name and calls it as a function."""
    assert func_name is not None
    func = get_obj_by_name(func_name)
    assert callable(func)
    return func(*args, **kwargs)
| 282 | 
         
            +
             
     | 
| 283 | 
         
            +
             
     | 
def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
    """Finds the python class with the given name and constructs it with the given arguments."""
    # Constructing a class is just calling the class object, so delegate.
    return call_func_by_name(*args, func_name=class_name, **kwargs)
| 287 | 
         
            +
             
     | 
| 288 | 
         
            +
             
     | 
| 289 | 
         
            +
            def get_module_dir_by_obj_name(obj_name: str) -> str:
         
     | 
| 290 | 
         
            +
                """Get the directory path of the module containing the given object name."""
         
     | 
| 291 | 
         
            +
                module, _ = get_module_from_obj_name(obj_name)
         
     | 
| 292 | 
         
            +
                return os.path.dirname(inspect.getfile(module))
         
     | 
| 293 | 
         
            +
             
     | 
| 294 | 
         
            +
             
     | 
| 295 | 
         
            +
            def is_top_level_function(obj: Any) -> bool:
         
     | 
| 296 | 
         
            +
                """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
         
     | 
| 297 | 
         
            +
                return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
         
     | 
| 298 | 
         
            +
             
     | 
| 299 | 
         
            +
             
     | 
| 300 | 
         
            +
            def get_top_level_function_name(obj: Any) -> str:
         
     | 
| 301 | 
         
            +
                """Return the fully-qualified name of a top-level function."""
         
     | 
| 302 | 
         
            +
                assert is_top_level_function(obj)
         
     | 
| 303 | 
         
            +
                module = obj.__module__
         
     | 
| 304 | 
         
            +
                if module == '__main__':
         
     | 
| 305 | 
         
            +
                    module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
         
     | 
| 306 | 
         
            +
                return module + "." + obj.__name__
         
     | 
| 307 | 
         
            +
             
     | 
| 308 | 
         
            +
             
     | 
| 309 | 
         
            +
            # File system helpers
         
     | 
| 310 | 
         
            +
            # ------------------------------------------------------------------------------------------
         
     | 
| 311 | 
         
            +
             
     | 
| 312 | 
         
            +
            def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
         
     | 
| 313 | 
         
            +
                """List all files recursively in a given directory while ignoring given file and directory names.
         
     | 
| 314 | 
         
            +
                Returns list of tuples containing both absolute and relative paths."""
         
     | 
| 315 | 
         
            +
                assert os.path.isdir(dir_path)
         
     | 
| 316 | 
         
            +
                base_name = os.path.basename(os.path.normpath(dir_path))
         
     | 
| 317 | 
         
            +
             
     | 
| 318 | 
         
            +
                if ignores is None:
         
     | 
| 319 | 
         
            +
                    ignores = []
         
     | 
| 320 | 
         
            +
             
     | 
| 321 | 
         
            +
                result = []
         
     | 
| 322 | 
         
            +
             
     | 
| 323 | 
         
            +
                for root, dirs, files in os.walk(dir_path, topdown=True):
         
     | 
| 324 | 
         
            +
                    for ignore_ in ignores:
         
     | 
| 325 | 
         
            +
                        dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
         
     | 
| 326 | 
         
            +
             
     | 
| 327 | 
         
            +
                        # dirs need to be edited in-place
         
     | 
| 328 | 
         
            +
                        for d in dirs_to_remove:
         
     | 
| 329 | 
         
            +
                            dirs.remove(d)
         
     | 
| 330 | 
         
            +
             
     | 
| 331 | 
         
            +
                        files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
         
     | 
| 332 | 
         
            +
             
     | 
| 333 | 
         
            +
                    absolute_paths = [os.path.join(root, f) for f in files]
         
     | 
| 334 | 
         
            +
                    relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
         
     | 
| 335 | 
         
            +
             
     | 
| 336 | 
         
            +
                    if add_base_to_relative:
         
     | 
| 337 | 
         
            +
                        relative_paths = [os.path.join(base_name, p) for p in relative_paths]
         
     | 
| 338 | 
         
            +
             
     | 
| 339 | 
         
            +
                    assert len(absolute_paths) == len(relative_paths)
         
     | 
| 340 | 
         
            +
                    result += zip(absolute_paths, relative_paths)
         
     | 
| 341 | 
         
            +
             
     | 
| 342 | 
         
            +
                return result
         
     | 
| 343 | 
         
            +
             
     | 
| 344 | 
         
            +
             
     | 
| 345 | 
         
            +
            def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
         
     | 
| 346 | 
         
            +
                """Takes in a list of tuples of (src, dst) paths and copies files.
         
     | 
| 347 | 
         
            +
                Will create all necessary directories."""
         
     | 
| 348 | 
         
            +
                for file in files:
         
     | 
| 349 | 
         
            +
                    target_dir_name = os.path.dirname(file[1])
         
     | 
| 350 | 
         
            +
             
     | 
| 351 | 
         
            +
                    # will create all intermediate-level directories
         
     | 
| 352 | 
         
            +
                    if not os.path.exists(target_dir_name):
         
     | 
| 353 | 
         
            +
                        os.makedirs(target_dir_name)
         
     | 
| 354 | 
         
            +
             
     | 
| 355 | 
         
            +
                    shutil.copyfile(file[0], file[1])
         
     | 
| 356 | 
         
            +
             
     | 
| 357 | 
         
            +
             
     | 
| 358 | 
         
            +
            # URL helpers
         
     | 
| 359 | 
         
            +
            # ------------------------------------------------------------------------------------------
         
     | 
| 360 | 
         
            +
             
     | 
| 361 | 
         
            +
            def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
         
     | 
| 362 | 
         
            +
                """Determine whether the given object is a valid URL string."""
         
     | 
| 363 | 
         
            +
                if not isinstance(obj, str) or not "://" in obj:
         
     | 
| 364 | 
         
            +
                    return False
         
     | 
| 365 | 
         
            +
                if allow_file_urls and obj.startswith('file://'):
         
     | 
| 366 | 
         
            +
                    return True
         
     | 
| 367 | 
         
            +
                try:
         
     | 
| 368 | 
         
            +
                    res = requests.compat.urlparse(obj)
         
     | 
| 369 | 
         
            +
                    if not res.scheme or not res.netloc or not "." in res.netloc:
         
     | 
| 370 | 
         
            +
                        return False
         
     | 
| 371 | 
         
            +
                    res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
         
     | 
| 372 | 
         
            +
                    if not res.scheme or not res.netloc or not "." in res.netloc:
         
     | 
| 373 | 
         
            +
                        return False
         
     | 
| 374 | 
         
            +
                except:
         
     | 
| 375 | 
         
            +
                    return False
         
     | 
| 376 | 
         
            +
                return True
         
     | 
| 377 | 
         
            +
             
     | 
| 378 | 
         
            +
             
     | 
| 379 | 
         
            +
def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
    """Download the given URL and return a binary-mode file object to access the data.

    Args:
        url:             URL to fetch; plain local paths and 'file://' URLs are opened directly.
        cache_dir:       Directory for the download cache; defaults to make_cache_dir_path('downloads').
        num_attempts:    Maximum number of download attempts before re-raising the last error.
        verbose:         Print download progress to stdout.
        return_filename: Return the cached file's path instead of a file object (requires cache=True).
        cache:           Look up and populate the on-disk download cache.

    Returns:
        A binary file object (or, if return_filename is True, a filesystem path string).
    """
    assert num_attempts >= 1
    # return_filename needs an on-disk file to point at, so it requires caching.
    assert not (return_filename and (not cache))

    # Doesn't look like an URL scheme so interpret it as a local filename.
    if not re.match('^[a-z]+://', url):
        return url if return_filename else open(url, "rb")

    # Handle file URLs.  This code handles unusual file:// patterns that
    # arise on Windows:
    #
    # file:///c:/foo.txt
    #
    # which would translate to a local '/c:/foo.txt' filename that's
    # invalid.  Drop the forward slash for such pathnames.
    #
    # If you touch this code path, you should test it on both Linux and
    # Windows.
    #
    # Some internet resources suggest using urllib.request.url2pathname() but
    # but that converts forward slashes to backslashes and this causes
    # its own set of problems.
    if url.startswith('file://'):
        filename = urllib.parse.urlparse(url).path
        # '/c:/foo.txt' style path from a Windows file:// URL -- drop the leading slash.
        if re.match(r'^/[a-zA-Z]:', filename):
            filename = filename[1:]
        return filename if return_filename else open(filename, "rb")

    assert is_url(url)

    # Lookup from cache.
    if cache_dir is None:
        cache_dir = make_cache_dir_path('downloads')

    # Cache entries are keyed by the MD5 of the URL, with the sanitized
    # original filename appended for readability.
    url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
    if cache:
        cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
        if len(cache_files) == 1:
            filename = cache_files[0]
            return filename if return_filename else open(filename, "rb")

    # Download.
    url_name = None
    url_data = None
    with requests.Session() as session:
        if verbose:
            print("Downloading %s ..." % url, end="", flush=True)
        # attempts_left counts down to 0; on the final attempt errors are re-raised.
        for attempts_left in reversed(range(num_attempts)):
            try:
                with session.get(url) as res:
                    res.raise_for_status()
                    if len(res.content) == 0:
                        raise IOError("No data received")

                    # Small responses may be Google Drive interstitial pages
                    # rather than the actual payload -- inspect them.
                    if len(res.content) < 8192:
                        content_str = res.content.decode("utf-8")
                        if "download_warning" in res.headers.get("Set-Cookie", ""):
                            # Follow the confirmation link embedded in the nag page
                            # and retry via the IOError below.
                            links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
                            if len(links) == 1:
                                url = requests.compat.urljoin(url, links[0])
                                raise IOError("Google Drive virus checker nag")
                        if "Google Drive - Quota exceeded" in content_str:
                            raise IOError("Google Drive download quota exceeded -- please try again later")

                    # Prefer the server-provided filename for the cache entry, if any.
                    match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
                    url_name = match[1] if match else url
                    url_data = res.content
                    if verbose:
                        print(" done")
                    break
            except:
                if not attempts_left:
                    if verbose:
                        print(" failed")
                    raise
                if verbose:
                    print(".", end="", flush=True)

    # Save to cache.
    if cache:
        safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
        cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
        # Write to a uniquely-named temp file first so concurrent downloads
        # of the same URL cannot corrupt each other.
        temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
        os.makedirs(cache_dir, exist_ok=True)
        with open(temp_file, "wb") as f:
            f.write(url_data)
        os.replace(temp_file, cache_file) # atomic
        if return_filename:
            return cache_file

    # Return data as file object.
    assert not return_filename
    return io.BytesIO(url_data)
         
     | 
    	
        gallery/gallery.md
    ADDED
    
    | 
         @@ -0,0 +1,15 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            ## Gallery
         
     | 
| 2 | 
         
            +
             
     | 
| 3 | 
         
            +
            ### Mosaics 
         
     | 
| 4 | 
         
            +
            
         
     | 
| 5 | 
         
            +
            
         
     | 
| 6 | 
         
            +
            
         
     | 
| 7 | 
         
            +
            
         
     | 
| 8 | 
         
            +
            
         
     | 
| 9 | 
         
            +
            
         
     | 
| 10 | 
         
            +
            
         
     | 
| 11 | 
         
            +
            
         
     | 
| 12 | 
         
            +
            
         
     | 
| 13 | 
         
            +
            
         
     | 
| 14 | 
         
            +
             
     | 
| 15 | 
         
            +
             
     | 
    	
        gallery/gl-mosaics1.png
    ADDED
    
    
											 
									 | 
									
								
											Git LFS Details
  | 
									
    	
        gallery/gl-mosaics10.png
    ADDED
    
    
											 
									 | 
									
								
											Git LFS Details
  | 
									
    	
        gallery/gl-mosaics2.png
    ADDED
    
    
											 
									 | 
									
								
											Git LFS Details
  | 
									
    	
        gallery/gl-mosaics3.png
    ADDED
    
    
											 
									 | 
									
								
											Git LFS Details
  | 
									
    	
        gallery/gl-mosaics4.png
    ADDED
    
    
											 
									 | 
									
								
											Git LFS Details
  | 
									
    	
        gallery/gl-mosaics5.png
    ADDED
    
    
											 
									 | 
									
								
											Git LFS Details
  | 
									
    	
        gallery/gl-mosaics6.png
    ADDED
    
    
											 
									 | 
									
								
											Git LFS Details
  | 
									
    	
        gallery/gl-mosaics7.png
    ADDED
    
    
											 
									 | 
									
								
											Git LFS Details
  | 
									
    	
        gallery/gl-mosaics8.png
    ADDED
    
    
											 
									 | 
									
								
											Git LFS Details
  | 
									
    	
        gallery/gl-mosaics9.png
    ADDED
    
    
											 
									 | 
									
								
											Git LFS Details
  | 
									
    	
        generate.py
    ADDED
    
    | 
         @@ -0,0 +1,700 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            # Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
         
     | 
| 2 | 
         
            +
            #
         
     | 
| 3 | 
         
            +
            # NVIDIA CORPORATION and its licensors retain all intellectual property
         
     | 
| 4 | 
         
            +
            # and proprietary rights in and to this software, related documentation
         
     | 
| 5 | 
         
            +
            # and any modifications thereto.  Any use, reproduction, disclosure or
         
     | 
| 6 | 
         
            +
            # distribution of this software and related documentation without an express
         
     | 
| 7 | 
         
            +
            # license agreement from NVIDIA CORPORATION is strictly prohibited.
         
     | 
| 8 | 
         
            +
             
     | 
| 9 | 
         
            +
             
     | 
| 10 | 
         
            +
            """Generate images using pretrained network pickle."""
         
     | 
| 11 | 
         
            +
             
     | 
| 12 | 
         
            +
            import argparse
         
     | 
| 13 | 
         
            +
            import sys
         
     | 
| 14 | 
         
            +
            import os
         
     | 
| 15 | 
         
            +
            import subprocess
         
     | 
| 16 | 
         
            +
            import pickle
         
     | 
| 17 | 
         
            +
            import re
         
     | 
| 18 | 
         
            +
             
     | 
| 19 | 
         
            +
            import scipy
         
     | 
| 20 | 
         
            +
            import numpy as np
         
     | 
| 21 | 
         
            +
            import PIL.Image
         
     | 
| 22 | 
         
            +
             
     | 
| 23 | 
         
            +
            import dnnlib
         
     | 
| 24 | 
         
            +
            import dnnlib.tflib as tflib
         
     | 
| 25 | 
         
            +
             
     | 
| 26 | 
         
            +
            os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = "hide"
         
     | 
| 27 | 
         
            +
            import moviepy.editor
         
     | 
| 28 | 
         
            +
            from opensimplex import OpenSimplex
         
     | 
| 29 | 
         
            +
             
     | 
| 30 | 
         
            +
            import warnings # mostly numpy warnings for me
         
     | 
| 31 | 
         
            +
            warnings.filterwarnings('ignore', category=FutureWarning)
         
     | 
| 32 | 
         
            +
            warnings.filterwarnings('ignore', category=DeprecationWarning)
         
     | 
| 33 | 
         
            +
             
     | 
| 34 | 
         
            +
            #----------------------------------------------------------------------------
         
     | 
| 35 | 
         
            +
             
     | 
def create_image_grid(images, grid_size=None):
    '''
    Tile a batch of images into a single grid image.

    Args:
        images (np.array): images to place on the grid, shaped either
            [num, height, width] (no channel axis) or [num, height, width, channels]
        grid_size (tuple(int, int)): size of grid (grid_w, grid_h); if None,
            a near-square layout is chosen automatically
    Returns:
        grid (np.array): image grid of size grid_size
    '''
    # Some sanity check:
    assert images.ndim == 3 or images.ndim == 4
    num, img_h, img_w = images.shape[0], images.shape[1], images.shape[2]
    if grid_size is not None:
        grid_w, grid_h = tuple(grid_size)
    else:
        # Near-square layout: width is ceil(sqrt(n)); height as many rows as needed.
        grid_w = max(int(np.ceil(np.sqrt(num))), 1)
        grid_h = max((num - 1) // grid_w + 1, 1)
    # Allocate the canvas. Use shape[3:] (not shape[-1:]): for a 3-dim batch
    # shape[-1:] would append a spurious axis of length img_w, producing a
    # wrongly-shaped grid and a broadcasting error on assignment below.
    grid = np.zeros(
        [grid_h * img_h, grid_w * img_w] + list(images.shape[3:]), dtype=images.dtype
    )
    for idx in range(num):
        x = (idx % grid_w) * img_w
        y = (idx // grid_w) * img_h
        grid[y : y + img_h, x : x + img_w, ...] = images[idx]
    return grid
| 61 | 
         
            +
             
     | 
| 62 | 
         
            +
            #----------------------------------------------------------------------------
         
     | 
| 63 | 
         
            +
             
     | 
def generate_images(network_pkl, seeds, truncation_psi, outdir, class_idx=None, dlatents_npz=None, grid=False):
    """Generate images either from a saved dlatents .npz file or from random seeds.

    Args:
        network_pkl: path/URL of the pickled (G, D, Gs) networks.
        seeds: iterable of integer seeds; one PNG is written per seed.
        truncation_psi: truncation-trick strength, or None to use the model default.
        outdir: output directory for PNGs (created if missing).
        class_idx: optional class index for conditional models (one-hot into the label).
        dlatents_npz: optional .npz file with key 'dlatents'; if given, seeds are ignored.
        grid: if True, additionally tile all seed images into outdir/grid.png.
    """
    tflib.init_tf()
    print('Loading networks from "%s"...' % network_pkl)
    with dnnlib.util.open_url(network_pkl) as fp:
        # NOTE(review): unpickling executes arbitrary code — only load trusted pickles.
        _G, _D, Gs = pickle.load(fp)

    os.makedirs(outdir, exist_ok=True)

    # Render images for a given dlatent vector.
    if dlatents_npz is not None:
        print(f'Generating images from dlatents file "{dlatents_npz}"')
        dlatents = np.load(dlatents_npz)['dlatents']
        max_l = 2 * int(np.log2(Gs.output_shape[-1]) - 1)  # max_l=18 for 1024x1024 models
        if dlatents.shape[1:] != (max_l, 512):  # [N, max_l, 512]
            # Infer the resolution the dlatents were made for, from layer count.
            actual_size = int(2**(dlatents.shape[1]//2+1))
            print(f'''Mismatch of loaded dlatents and network! dlatents was created with network of size: {actual_size}\n
           {network_pkl} is of size {Gs.output_shape[-1]}''')
            sys.exit(1)
        imgs = Gs.components.synthesis.run(dlatents, output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True))
        for i, img in enumerate(imgs):
            fname = f'{outdir}/dlatent{i:02d}.png'
            # (message is printed just before the file is actually written)
            print (f'Saved {fname}')
            PIL.Image.fromarray(img, 'RGB').save(fname)
        return

    # Render images for dlatents initialized from random seeds.
    Gs_kwargs = {
        'output_transform': dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True),
        'randomize_noise': False
    }
    if truncation_psi is not None:
        Gs_kwargs['truncation_psi'] = truncation_psi

    noise_vars = [var for name, var in Gs.components.synthesis.vars.items() if name.startswith('noise')]
    # Conditional label; stays all-zero for unconditional models.
    label = np.zeros([1] + Gs.input_shapes[1][1:])
    if class_idx is not None:
        label[:, class_idx] = 1

    images = []
    for seed_idx, seed in enumerate(seeds):
        print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
        rnd = np.random.RandomState(seed)
        z = rnd.randn(1, *Gs.input_shape[1:]) # [minibatch, component]
        # Also seed the per-layer noise inputs so output is fully deterministic per seed.
        tflib.set_vars({var: rnd.randn(*var.shape.as_list()) for var in noise_vars}) # [height, width]
        image = Gs.run(z, label, **Gs_kwargs) # [minibatch, height, width, channel]
        images.append(image[0])
        PIL.Image.fromarray(image[0], 'RGB').save(f'{outdir}/seed{seed:04d}.png')

    # If user wants to save a grid of the generated images
    if grid:
        print('Generating image grid...')
        PIL.Image.fromarray(create_image_grid(np.array(images)), 'RGB').save(f'{outdir}/grid.png')
| 116 | 
         
            +
             
     | 
| 117 | 
         
            +
            #----------------------------------------------------------------------------
         
     | 
| 118 | 
         
            +
             
     | 
def truncation_traversal(network_pkl,npys,outdir,class_idx=None, seed=[0],start=-1.0,stop=1.0,increment=0.1,framerate=24):
    """Render frames sweeping truncation_psi from `start` to `stop`, then encode a video.

    Args:
        network_pkl: path/URL of the pickled (G, D, Gs) networks.
        npys: unused; kept for call-site compatibility.
        outdir: output directory for frames and the final mp4.
        class_idx: optional class index for conditional models.
        seed: list whose contents seed the latent; seed[0] is used in the filename.
        start, stop, increment: truncation_psi sweep range and step.
        framerate: fps of the encoded video.
    """
    tflib.init_tf()
    print('Loading networks from "%s"...' % network_pkl)
    with dnnlib.util.open_url(network_pkl) as fp:
        # NOTE(review): unpickling executes arbitrary code — only load trusted pickles.
        _G, _D, Gs = pickle.load(fp)

    os.makedirs(outdir, exist_ok=True)

    Gs_kwargs = {
        'output_transform': dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True),
        'randomize_noise': False
    }

    noise_vars = [var for name, var in Gs.components.synthesis.vars.items() if name.startswith('noise')]
    label = np.zeros([1] + Gs.input_shapes[1][1:])
    if class_idx is not None:
        label[:, class_idx] = 1

    # Derive the frame count up front. The previous `while trunc <= stop` with
    # `trunc += increment` accumulated floating-point error, which could skip
    # the final `stop` frame (e.g. -1.0 + 20*0.1 accumulates to > 1.0).
    num_frames = int(round((stop - start) / increment)) + 1
    for count in range(1, num_frames + 1):
        trunc = start + (count - 1) * increment
        Gs_kwargs['truncation_psi'] = trunc
        print('Generating truncation %0.2f' % trunc)

        # Re-seed each frame so only truncation_psi varies between frames.
        rnd = np.random.RandomState(seed)
        z = rnd.randn(1, *Gs.input_shape[1:]) # [minibatch, component]
        tflib.set_vars({var: rnd.randn(*var.shape.as_list()) for var in noise_vars}) # [height, width]
        image = Gs.run(z, label, **Gs_kwargs) # [minibatch, height, width, channel]
        PIL.Image.fromarray(image[0], 'RGB').save(f'{outdir}/frame{count:05d}.png')

    cmd="ffmpeg -y -r {} -i {}/frame%05d.png -vcodec libx264 -pix_fmt yuv420p {}/truncation-traversal-seed{}-start{}-stop{}.mp4".format(framerate,outdir,outdir,seed[0],start,stop)
    subprocess.call(cmd, shell=True)
| 157 | 
         
            +
             
     | 
| 158 | 
         
            +
            #----------------------------------------------------------------------------
         
     | 
| 159 | 
         
            +
             
     | 
def valmap(value, istart, istop, ostart, ostop):
    """Linearly remap `value` from range [istart, istop] onto [ostart, ostop]."""
    fraction = (value - istart) / (istop - istart)
    return ostart + (ostop - ostart) * fraction
| 162 | 
         
            +
             
     | 
class OSN():
    """Samples OpenSimplex noise along a circle, yielding a smooth closed loop.

    `get_val(0)` and `get_val(2*pi)` land on the same noise coordinate, so a
    full sweep of the angle produces a seamlessly looping value sequence.
    """
    min = -1
    max = 1

    def __init__(self, seed, diameter):
        # One independent noise field per instance; `diameter` sets how far
        # the circular walk travels through noise space (loop "excursion").
        self.tmp = OpenSimplex(seed)
        self.d = diameter
        self.x = 0
        self.y = 0

    def get_val(self, angle):
        # Map the unit circle onto a circle of size `d` anchored at (x, y),
        # then sample 2D noise at that point.
        self.xoff = valmap(np.cos(angle), -1, 1, self.x, self.x + self.d)
        self.yoff = valmap(np.sin(angle), -1, 1, self.y, self.y + self.d)
        return self.tmp.noise2d(self.xoff, self.yoff)
| 177 | 
         
            +
             
     | 
def get_noiseloop(endpoints, nf, d, start_seed, dim=512):
    """Build a closed loop of `nf` latent vectors from OpenSimplex noise.

    Args:
        endpoints: unused; kept for call-site compatibility.
        nf (int): number of frames (latents) in the loop.
        d (float): noise-circle diameter; larger values give wider excursions.
        start_seed (int): base seed; latent component i uses seed start_seed + i.
        dim (int): latent dimensionality (default 512, matching StyleGAN2).

    Returns:
        list of np.ndarray of shape [1, dim]; the sequence loops seamlessly
        (frame nf would coincide with frame 0).
    """
    # One independent noise loop per latent component.
    features = [OSN(i + start_seed, d) for i in range(dim)]

    inc = (np.pi * 2) / nf
    zs = []
    for f in range(nf):
        # np.zeros instead of np.random.randn: every entry is overwritten
        # below, and the original randn call needlessly consumed (and made the
        # surroundings depend on) global NumPy RNG state.
        z = np.zeros((1, dim))
        for i in range(dim):
            z[0, i] = features[i].get_val(inc * f)
        zs.append(z)

    return zs
| 192 | 
         
            +
             
     | 
def line_interpolate(zs, steps):
    """Piecewise-linear interpolation between consecutive latents.

    Emits `steps` points per segment, including each segment's start point but
    excluding its end — so the very last keyframe in `zs` is never emitted.
    Returns a flat list of len(zs)-1 segments * steps points.
    """
    out = []
    for a, b in zip(zs, zs[1:]):
        for k in range(steps):
            t = k / float(steps)
            out.append(b * t + a * (1 - t))
    return out
| 200 | 
         
            +
             
     | 
def generate_zs_from_seeds(seeds, Gs):
    """Map each integer seed to a deterministic z latent for network `Gs`.

    Returns a list of arrays shaped [1, *Gs.input_shape[1:]] — i.e.
    [minibatch, component] — one per seed.
    """
    return [
        np.random.RandomState(seed).randn(1, *Gs.input_shape[1:])
        for seed in seeds
    ]
| 208 | 
         
            +
             
     | 
def convertZtoW(latent, truncation_psi=0.7, truncation_cutoff=9):
    # Map a z latent through the mapping network to w (dlatent) space and
    # apply the truncation trick.
    # NOTE(review): relies on a module-level `Gs` set by generate_latent_walk;
    # calling this before a network is loaded raises NameError.
    # NOTE(review): `truncation_cutoff` is accepted but never used — truncation
    # is applied uniformly to all layers regardless of its value.
    dlatent = Gs.components.mapping.run(latent, None) # [seed, layer, component]
    dlatent_avg = Gs.get_var('dlatent_avg') # [component]
    # Truncation trick: pull dlatents toward the average dlatent, trading
    # diversity for fidelity (psi=1 keeps them unchanged, psi=0 collapses to avg).
    dlatent = dlatent_avg + (dlatent - dlatent_avg) * truncation_psi

    return dlatent
| 215 | 
         
            +
             
     | 
def generate_latent_images(zs, truncation_psi, outdir, save_npy, prefix, vidname, framerate):
    """Render one frame per z latent and stitch the frames into an mp4.

    Args:
        zs: sequence of latents; each entry is a list or np.ndarray of 512 values.
        truncation_psi: scalar, or a per-frame list of the same length as `zs`.
        outdir: output directory; frames go to `{outdir}/frames`.
        save_npy: if True, also save each latent as .npy next to the run dir.
        prefix: filename prefix for frames and npy files.
        vidname: name fragment for the output video file.
        framerate: fps of the encoded video.

    NOTE(review): relies on module-level `Gs` and `noise_vars`, which are
    initialised by generate_latent_walk — this function is not standalone.
    """
    Gs_kwargs = {
        'output_transform': dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True),
        'randomize_noise': False
    }

    # Allow a per-frame truncation schedule or a single scalar for all frames.
    if not isinstance(truncation_psi, list):
        truncation_psi = [truncation_psi] * len(zs)

    for z_idx, z in enumerate(zs):
        if isinstance(z, list):
            z = np.array(z).reshape(1, 512)
        elif isinstance(z, np.ndarray):
            # Fix: ndarray.reshape returns a new array — the original discarded
            # the result, so flat (512,)-shaped latents reached Gs.run with the
            # wrong rank.
            z = z.reshape(1, 512)
        print('Generating image for step %d/%d ...' % (z_idx, len(zs)))
        Gs_kwargs['truncation_psi'] = truncation_psi[z_idx]
        noise_rnd = np.random.RandomState(1) # fix noise
        tflib.set_vars({var: noise_rnd.randn(*var.shape.as_list()) for var in noise_vars}) # [height, width]
        images = Gs.run(z, None, **Gs_kwargs) # [minibatch, height, width, channel]
        PIL.Image.fromarray(images[0], 'RGB').save(f'{outdir}/frames/{prefix}{z_idx:05d}.png')
        if save_npy:
            np.save(dnnlib.make_run_dir_path('%s%05d.npy' % (prefix, z_idx)), z)

    cmd="ffmpeg -y -r {} -i {}/frames/{}%05d.png -vcodec libx264 -pix_fmt yuv420p {}/walk-{}-{}fps.mp4".format(framerate,outdir,prefix,outdir,vidname,framerate)
    subprocess.call(cmd, shell=True)
| 241 | 
         
            +
             
     | 
def generate_images_in_w_space(ws, truncation_psi,outdir,save_npy,prefix,vidname,framerate):
    """Render one frame per w (dlatent) and stitch the frames into an mp4.

    NOTE(review): depends on module-level `Gs` and `noise_vars`, which are
    initialised by generate_latent_walk — this function is not standalone.
    """
    synth_kwargs = {
        'output_transform': dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True),
        'randomize_noise': False,
        'truncation_psi': truncation_psi
    }

    total = len(ws)
    for idx, dlatent in enumerate(ws):
        print('Generating image for step %d/%d ...' % (idx, total))
        # Fixed noise across all frames so only the latent walk animates.
        noise_rnd = np.random.RandomState(1)
        tflib.set_vars({var: noise_rnd.randn(*var.shape.as_list()) for var in noise_vars})
        frames = Gs.components.synthesis.run(dlatent, **synth_kwargs)  # [minibatch, height, width, channel]
        PIL.Image.fromarray(frames[0], 'RGB').save(f'{outdir}/frames/{prefix}{idx:05d}.png')
        if save_npy:
            np.save(dnnlib.make_run_dir_path('%s%05d.npy' % (prefix, idx)), dlatent)

    cmd = "ffmpeg -y -r {} -i {}/frames/{}%05d.png -vcodec libx264 -pix_fmt yuv420p {}/walk-{}-{}fps.mp4".format(framerate,outdir,prefix,outdir,vidname,framerate)
    subprocess.call(cmd, shell=True)
| 261 | 
         
            +
             
     | 
| 262 | 
         
            +
            def generate_latent_walk(network_pkl, truncation_psi, outdir, walk_type, frames, seeds, npys, save_vector, diameter=2.0, start_seed=0, framerate=24 ):
         
     | 
| 263 | 
         
            +
                global _G, _D, Gs, noise_vars
         
     | 
| 264 | 
         
            +
                tflib.init_tf()
         
     | 
| 265 | 
         
            +
             
     | 
| 266 | 
         
            +
                print('Loading networks from "%s"...' % network_pkl)
         
     | 
| 267 | 
         
            +
                with dnnlib.util.open_url(network_pkl) as fp:
         
     | 
| 268 | 
         
            +
                    _G, _D, Gs = pickle.load(fp)
         
     | 
| 269 | 
         
            +
             
     | 
| 270 | 
         
            +
                os.makedirs(outdir, exist_ok=True)
         
     | 
| 271 | 
         
            +
                os.makedirs(outdir+"/frames", exist_ok=True)
         
     | 
| 272 | 
         
            +
             
     | 
| 273 | 
         
            +
                # Render images for dlatents initialized from random seeds.
         
     | 
| 274 | 
         
            +
                Gs_kwargs = {
         
     | 
| 275 | 
         
            +
                    'output_transform': dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True),
         
     | 
| 276 | 
         
            +
                    'randomize_noise': False,
         
     | 
| 277 | 
         
            +
                    'truncation_psi': truncation_psi
         
     | 
| 278 | 
         
            +
                }
         
     | 
| 279 | 
         
            +
             
     | 
| 280 | 
         
            +
                noise_vars = [var for name, var in Gs.components.synthesis.vars.items() if name.startswith('noise')]
         
     | 
| 281 | 
         
            +
                zs = []
         
     | 
| 282 | 
         
            +
                ws =[]
         
     | 
| 283 | 
         
            +
             
     | 
| 284 | 
         
            +
             
     | 
| 285 | 
         
            +
                # npys specified, let's work with these instead of seeds
         
     | 
| 286 | 
         
            +
                # npys must be saved as W's (arrays of 18x512)
         
     | 
| 287 | 
         
            +
                if npys and (len(npys) > 0):
         
     | 
| 288 | 
         
            +
                    ws = npys
         
     | 
| 289 | 
         
            +
             
     | 
| 290 | 
         
            +
             
     | 
| 291 | 
         
            +
                wt = walk_type.split('-')
         
     | 
| 292 | 
         
            +
             
     | 
| 293 | 
         
            +
                if wt[0] == 'line':
         
     | 
| 294 | 
         
            +
                    if seeds and (len(seeds) > 0):
         
     | 
| 295 | 
         
            +
                        zs = generate_zs_from_seeds(seeds,Gs)
         
     | 
| 296 | 
         
            +
             
     | 
| 297 | 
         
            +
                    if ws == []:
         
     | 
| 298 | 
         
            +
                        number_of_steps = int(frames/(len(zs)-1))+1
         
     | 
| 299 | 
         
            +
                    else:
         
     | 
| 300 | 
         
            +
                        number_of_steps = int(frames/(len(ws)-1))+1
         
     | 
| 301 | 
         
            +
             
     | 
| 302 | 
         
            +
                    if (len(wt)>1 and wt[1] == 'w'):
         
     | 
| 303 | 
         
            +
                      if ws == []:
         
     | 
| 304 | 
         
            +
                        for i in range(len(zs)):
         
     | 
| 305 | 
         
            +
                          ws.append(convertZtoW(zs[i]))
         
     | 
| 306 | 
         
            +
             
     | 
| 307 | 
         
            +
                      points = line_interpolate(ws,number_of_steps)
         
     | 
| 308 | 
         
            +
                      zpoints = line_interpolate(zs,number_of_steps)
         
     | 
| 309 | 
         
            +
             
     | 
| 310 | 
         
            +
                    else:
         
     | 
| 311 | 
         
            +
                      points = line_interpolate(zs,number_of_steps)
         
     | 
| 312 | 
         
            +
             
     | 
| 313 | 
         
            +
             
     | 
| 314 | 
         
            +
                # from Gene Kogan
         
     | 
| 315 | 
         
            +
                elif wt[0] == 'bspline':
         
     | 
| 316 | 
         
            +
                    # bspline in w doesnt work yet
         
     | 
| 317 | 
         
            +
                    # if (len(walk_type)>1 and walk_type[1] == 'w'):
         
     | 
| 318 | 
         
            +
                    #   ws = []
         
     | 
| 319 | 
         
            +
                    #   for i in range(len(zs)):
         
     | 
| 320 | 
         
            +
                    #     ws.append(convertZtoW(zs[i]))
         
     | 
| 321 | 
         
            +
             
     | 
| 322 | 
         
            +
                    #   print(ws[0].shape)
         
     | 
| 323 | 
         
            +
                    #   w = []
         
     | 
| 324 | 
         
            +
                    #   for i in range(len(ws)):
         
     | 
| 325 | 
         
            +
                    #     w.append(np.asarray(ws[i]).reshape(512,18))
         
     | 
| 326 | 
         
            +
                    #   points = get_latent_interpolation_bspline(ws,frames,3, 20, shuffle=False)
         
     | 
| 327 | 
         
            +
                    # else:
         
     | 
| 328 | 
         
            +
                      z = []
         
     | 
| 329 | 
         
            +
                      for i in range(len(zs)):
         
     | 
| 330 | 
         
            +
                        z.append(np.asarray(zs[i]).reshape(512))
         
     | 
| 331 | 
         
            +
                      points = get_latent_interpolation_bspline(z,frames,3, 20, shuffle=False)
         
     | 
| 332 | 
         
            +
             
     | 
| 333 | 
         
            +
                # from Dan Shiffman: https://editor.p5js.org/dvs/sketches/Gb0xavYAR
         
     | 
| 334 | 
         
            +
                elif wt[0] == 'noiseloop':
         
     | 
| 335 | 
         
            +
                    points = get_noiseloop(None,frames,diameter,start_seed)
         
     | 
| 336 | 
         
            +
             
     | 
| 337 | 
         
            +
                if (wt[0] == 'line' and len(wt)>1 and wt[1] == 'w'):
         
     | 
| 338 | 
         
            +
                  # print(points[0][:,:,1])
         
     | 
| 339 | 
         
            +
                  # print(zpoints[0][:,1])
         
     | 
| 340 | 
         
            +
                  # ws = []
         
     | 
| 341 | 
         
            +
                  # for i in enumerate(len(points)):
         
     | 
| 342 | 
         
            +
                  #   ws.append(convertZtoW(points[i]))
         
     | 
| 343 | 
         
            +
                    #added for npys
         
     | 
| 344 | 
         
            +
                    if seeds:
         
     | 
| 345 | 
         
            +
                        seed_out = 'w-' + wt[0] + ('-'.join([str(seed) for seed in seeds]))
         
     | 
| 346 | 
         
            +
                    else:
         
     | 
| 347 | 
         
            +
                        seed_out = 'w-' + wt[0] + '-dlatents'
         
     | 
| 348 | 
         
            +
             
     | 
| 349 | 
         
            +
                    generate_images_in_w_space(points, truncation_psi,outdir,save_vector,'frame', seed_out, framerate)
         
     | 
| 350 | 
         
            +
                elif (len(wt)>1 and wt[1] == 'w'):
         
     | 
| 351 | 
         
            +
                  print('%s is not currently supported in w space, please change your interpolation type' % (wt[0]))
         
     | 
| 352 | 
         
            +
                else:
         
     | 
| 353 | 
         
            +
                    if(len(wt)>1):
         
     | 
| 354 | 
         
            +
                        seed_out = 'z-' + wt[0] + ('-'.join([str(seed) for seed in seeds]))
         
     | 
| 355 | 
         
            +
                    else:
         
     | 
| 356 | 
         
            +
                        seed_out = 'z-' + walk_type + '-seed' +str(start_seed)
         
     | 
| 357 | 
         
            +
                    generate_latent_images(points, truncation_psi, outdir, save_vector,'frame', seed_out, framerate)
         
     | 
| 358 | 
         
            +
             
     | 
| 359 | 
         
            +
            #----------------------------------------------------------------------------
         
     | 
| 360 | 
         
            +
             
     | 
| 361 | 
         
            +
            def generate_neighbors(network_pkl, seeds, npys, diameter, truncation_psi, num_samples, save_vector, outdir):
         
     | 
| 362 | 
         
            +
                global _G, _D, Gs, noise_vars
         
     | 
| 363 | 
         
            +
                tflib.init_tf()
         
     | 
| 364 | 
         
            +
                print('Loading networks from "%s"...' % network_pkl)
         
     | 
| 365 | 
         
            +
                with dnnlib.util.open_url(network_pkl) as fp:
         
     | 
| 366 | 
         
            +
                    _G, _D, Gs = pickle.load(fp)
         
     | 
| 367 | 
         
            +
             
     | 
| 368 | 
         
            +
                os.makedirs(outdir, exist_ok=True)
         
     | 
| 369 | 
         
            +
             
     | 
| 370 | 
         
            +
                # Render images for dlatents initialized from random seeds.
         
     | 
| 371 | 
         
            +
                Gs_kwargs = {
         
     | 
| 372 | 
         
            +
                    'output_transform': dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True),
         
     | 
| 373 | 
         
            +
                    'randomize_noise': False,
         
     | 
| 374 | 
         
            +
                    'truncation_psi': truncation_psi
         
     | 
| 375 | 
         
            +
                }
         
     | 
| 376 | 
         
            +
             
     | 
| 377 | 
         
            +
                noise_vars = [var for name, var in Gs.components.synthesis.vars.items() if name.startswith('noise')]
         
     | 
| 378 | 
         
            +
             
     | 
| 379 | 
         
            +
                for seed_idx, seed in enumerate(seeds):
         
     | 
| 380 | 
         
            +
                    print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx+1, len(seeds)))
         
     | 
| 381 | 
         
            +
                    rnd = np.random.RandomState(seed)
         
     | 
| 382 | 
         
            +
             
     | 
| 383 | 
         
            +
                    og_z = rnd.randn(1, *Gs.input_shape[1:]) # [minibatch, component]
         
     | 
| 384 | 
         
            +
                    tflib.set_vars({var: rnd.randn(*var.shape.as_list()) for var in noise_vars}) # [height, width]
         
     | 
| 385 | 
         
            +
                    images = Gs.run(og_z, None, **Gs_kwargs) # [minibatch, height, width, channel]
         
     | 
| 386 | 
         
            +
                    # PIL.Image.fromarray(images[0], 'RGB').save(dnnlib.make_run_dir_path('seed%04d.png' % seed))
         
     | 
| 387 | 
         
            +
                    PIL.Image.fromarray(images[0], 'RGB').save(f'{outdir}/seed{seed:05d}.png')
         
     | 
| 388 | 
         
            +
             
     | 
| 389 | 
         
            +
                    zs = []
         
     | 
| 390 | 
         
            +
                    z_prefix = 'seed%04d_neighbor' % seed
         
     | 
| 391 | 
         
            +
             
     | 
| 392 | 
         
            +
                    for s in range(num_samples):
         
     | 
| 393 | 
         
            +
                        random = np.random.uniform(-diameter,diameter,[1,512])
         
     | 
| 394 | 
         
            +
            #             zs.append(np.clip((og_z+random),-1,1))
         
     | 
| 395 | 
         
            +
                        new_z = np.clip(np.add(og_z,random),-1,1)
         
     | 
| 396 | 
         
            +
                        images = Gs.run(new_z, None, **Gs_kwargs) # [minibatch, height, width, channel]
         
     | 
| 397 | 
         
            +
                        # PIL.Image.fromarray(images[0], 'RGB').save(dnnlib.make_run_dir_path('%s%04d.png' % (z_prefix,s)))
         
     | 
| 398 | 
         
            +
                        PIL.Image.fromarray(images[0], 'RGB').save(f'{outdir}/{z_prefix}{s:05d}.png')
         
     | 
| 399 | 
         
            +
                        # generate_latent_images(zs, truncation_psi, save_vector, z_prefix)
         
     | 
| 400 | 
         
            +
                        if save_vector:
         
     | 
| 401 | 
         
            +
                            np.save(dnnlib.make_run_dir_path('%s%05d.npy' % (z_prefix,s)), new_z)
         
     | 
| 402 | 
         
            +
             
     | 
| 403 | 
         
            +
             
     | 
| 404 | 
         
            +
             
     | 
| 405 | 
         
            +
            #----------------------------------------------------------------------------
         
     | 
| 406 | 
         
            +
             
     | 
| 407 | 
         
            +
            def lerp_video(network_pkl,                # Path to pretrained model pkl file
         
     | 
| 408 | 
         
            +
                           seeds,                      # Random seeds
         
     | 
| 409 | 
         
            +
                           grid_w=None,                # Number of columns
         
     | 
| 410 | 
         
            +
                           grid_h=None,                # Number of rows
         
     | 
| 411 | 
         
            +
                           truncation_psi=1.0,         # Truncation trick
         
     | 
| 412 | 
         
            +
                           outdir='out',               # Output dir
         
     | 
| 413 | 
         
            +
                           slowdown=1,                 # Slowdown of the video (power of 2)
         
     | 
| 414 | 
         
            +
                           duration_sec=30.0,          # Duration of video in seconds
         
     | 
| 415 | 
         
            +
                           smoothing_sec=3.0,
         
     | 
| 416 | 
         
            +
                           mp4_fps=30,
         
     | 
| 417 | 
         
            +
                           mp4_codec="libx264",
         
     | 
| 418 | 
         
            +
                           mp4_bitrate="16M"):
         
     | 
| 419 | 
         
            +
                # Sanity check regarding slowdown
         
     | 
| 420 | 
         
            +
                message = 'slowdown must be a power of 2 (1, 2, 4, 8, ...) and greater than 0!'
         
     | 
| 421 | 
         
            +
                assert slowdown & (slowdown - 1) == 0 and slowdown > 0, message
         
     | 
| 422 | 
         
            +
                # Initialize TensorFlow and create outdir
         
     | 
| 423 | 
         
            +
                tflib.init_tf()
         
     | 
| 424 | 
         
            +
                os.makedirs(outdir, exist_ok=True)
         
     | 
| 425 | 
         
            +
                # Total duration of video and number of frames to generate
         
     | 
| 426 | 
         
            +
                num_frames = int(np.rint(duration_sec * mp4_fps))
         
     | 
| 427 | 
         
            +
                total_duration = duration_sec * slowdown
         
     | 
| 428 | 
         
            +
             
     | 
| 429 | 
         
            +
                print(f'Loading network from {network_pkl}...')
         
     | 
| 430 | 
         
            +
                with dnnlib.util.open_url(network_pkl) as fp:
         
     | 
| 431 | 
         
            +
                    _G, _D, Gs = pickle.load(fp)
         
     | 
| 432 | 
         
            +
             
     | 
| 433 | 
         
            +
                print("Generating latent vectors...")
         
     | 
| 434 | 
         
            +
                # If there's more than one seed provided and the shape isn't specified
         
     | 
| 435 | 
         
            +
                if grid_w == grid_h == None and len(seeds) >= 1:
         
     | 
| 436 | 
         
            +
                    # number of images according to the seeds provided
         
     | 
| 437 | 
         
            +
                    num = len(seeds)
         
     | 
| 438 | 
         
            +
                    # Get the grid width and height according to num:
         
     | 
| 439 | 
         
            +
                    grid_w = max(int(np.ceil(np.sqrt(num))), 1)
         
     | 
| 440 | 
         
            +
                    grid_h = max((num - 1) // grid_w + 1, 1)
         
     | 
| 441 | 
         
            +
                    grid_size = [grid_w, grid_h]
         
     | 
| 442 | 
         
            +
                    # [frame, image, channel, component]:
         
     | 
| 443 | 
         
            +
                    shape = [num_frames] + Gs.input_shape[1:]
         
     | 
| 444 | 
         
            +
                    # Get the latents:
         
     | 
| 445 | 
         
            +
                    all_latents = np.stack([np.random.RandomState(seed).randn(*shape).astype(np.float32) for seed in seeds], axis=1)
         
     | 
| 446 | 
         
            +
                # If only one seed is provided and the shape is specified
         
     | 
| 447 | 
         
            +
                elif None not in (grid_w, grid_h) and len(seeds) == 1:
         
     | 
| 448 | 
         
            +
                    # Otherwise, the user gives one seed and the grid width and height:
         
     | 
| 449 | 
         
            +
                    grid_size = [grid_w, grid_h]
         
     | 
| 450 | 
         
            +
                    # [frame, image, channel, component]:
         
     | 
| 451 | 
         
            +
                    shape = [num_frames, np.prod(grid_size)] + Gs.input_shape[1:]
         
     | 
| 452 | 
         
            +
                    # Get the latents with the random state:
         
     | 
| 453 | 
         
            +
                    random_state = np.random.RandomState(seeds)
         
     | 
| 454 | 
         
            +
                    all_latents = random_state.randn(*shape).astype(np.float32)
         
     | 
| 455 | 
         
            +
                else:
         
     | 
| 456 | 
         
            +
                    print("Error: wrong combination of arguments! Please provide \
         
     | 
| 457 | 
         
            +
                            either one seed and the grid width and height, or a \
         
     | 
| 458 | 
         
            +
                            list of seeds to use.")
         
     | 
| 459 | 
         
            +
                    sys.exit(1)
         
     | 
| 460 | 
         
            +
             
     | 
| 461 | 
         
            +
                all_latents = scipy.ndimage.gaussian_filter(
         
     | 
| 462 | 
         
            +
                    all_latents,
         
     | 
| 463 | 
         
            +
                    [smoothing_sec * mp4_fps] + [0] * len(Gs.input_shape),
         
     | 
| 464 | 
         
            +
                    mode="wrap"
         
     | 
| 465 | 
         
            +
                )
         
     | 
| 466 | 
         
            +
                all_latents /= np.sqrt(np.mean(np.square(all_latents)))
         
     | 
| 467 | 
         
            +
                # Name of the final mp4 video
         
     | 
| 468 | 
         
            +
                mp4 = f"{grid_w}x{grid_h}-lerp-{slowdown}xslowdown.mp4"
         
     | 
| 469 | 
         
            +
             
     | 
| 470 | 
         
            +
                # Aux function to slowdown the video by 2x
         
     | 
| 471 | 
         
            +
                def double_slowdown(latents, duration_sec, num_frames):
         
     | 
| 472 | 
         
            +
                    # Make an empty latent vector with double the amount of frames
         
     | 
| 473 | 
         
            +
                    z = np.empty(np.multiply(latents.shape, [2, 1, 1]), dtype=np.float32)
         
     | 
| 474 | 
         
            +
                    # Populate it
         
     | 
| 475 | 
         
            +
                    for i in range(len(latents)):
         
     | 
| 476 | 
         
            +
                        z[2*i] = latents[i]
         
     | 
| 477 | 
         
            +
                    # Interpolate in the odd frames
         
     | 
| 478 | 
         
            +
                    for i in range(1, len(z), 2):
         
     | 
| 479 | 
         
            +
                        # For the last frame, we loop to the first one
         
     | 
| 480 | 
         
            +
                        if i == len(z) - 1:
         
     | 
| 481 | 
         
            +
                            z[i] = (z[0] + z[i-1]) / 2
         
     | 
| 482 | 
         
            +
                        else:
         
     | 
| 483 | 
         
            +
                            z[i] = (z[i-1] + z[i+1]) / 2
         
     | 
| 484 | 
         
            +
                    # We also need to double the duration_sec and num_frames
         
     | 
| 485 | 
         
            +
                    duration_sec *= 2
         
     | 
| 486 | 
         
            +
                    num_frames *= 2
         
     | 
| 487 | 
         
            +
                    # Return the new latents, and the two previous quantities
         
     | 
| 488 | 
         
            +
                    return z, duration_sec, num_frames
         
     | 
| 489 | 
         
            +
             
     | 
| 490 | 
         
            +
                while slowdown > 1:
         
     | 
| 491 | 
         
            +
                    all_latents, duration_sec, num_frames = double_slowdown(all_latents, duration_sec, num_frames)
         
     | 
| 492 | 
         
            +
                    slowdown //= 2
         
     | 
| 493 | 
         
            +
             
     | 
| 494 | 
         
            +
                # Define the kwargs for the Generator:
         
     | 
| 495 | 
         
            +
                Gs_kwargs = dnnlib.EasyDict()
         
     | 
| 496 | 
         
            +
                Gs_kwargs.output_transform = dict(func=tflib.convert_images_to_uint8,
         
     | 
| 497 | 
         
            +
                                                  nchw_to_nhwc=True)
         
     | 
| 498 | 
         
            +
                Gs_kwargs.randomize_noise = False
         
     | 
| 499 | 
         
            +
                if truncation_psi is not None:
         
     | 
| 500 | 
         
            +
                    Gs_kwargs.truncation_psi = truncation_psi
         
     | 
| 501 | 
         
            +
             
     | 
| 502 | 
         
            +
                # Aux function: Frame generation func for moviepy.
         
     | 
| 503 | 
         
            +
                def make_frame(t):
         
     | 
| 504 | 
         
            +
                    frame_idx = int(np.clip(np.round(t * mp4_fps), 0, num_frames - 1))
         
     | 
| 505 | 
         
            +
                    latents = all_latents[frame_idx]
         
     | 
| 506 | 
         
            +
                    # Get the images (with labels = None)
         
     | 
| 507 | 
         
            +
                    images = Gs.run(latents, None, **Gs_kwargs)
         
     | 
| 508 | 
         
            +
                    # Generate the grid for this timestamp:
         
     | 
| 509 | 
         
            +
                    grid = create_image_grid(images, grid_size)
         
     | 
| 510 | 
         
            +
                    # grayscale => RGB
         
     | 
| 511 | 
         
            +
                    if grid.shape[2] == 1:
         
     | 
| 512 | 
         
            +
                        grid = grid.repeat(3, 2)
         
     | 
| 513 | 
         
            +
                    return grid
         
     | 
| 514 | 
         
            +
             
     | 
| 515 | 
         
            +
                # Generate video using make_frame:
         
     | 
| 516 | 
         
            +
                print(f'Generating interpolation video of length: {total_duration} seconds...')
         
     | 
| 517 | 
         
            +
                videoclip = moviepy.editor.VideoClip(make_frame, duration=duration_sec)
         
     | 
| 518 | 
         
            +
                videoclip.write_videofile(os.path.join(outdir, mp4),
         
     | 
| 519 | 
         
            +
                                          fps=mp4_fps,
         
     | 
| 520 | 
         
            +
                                          codec=mp4_codec,
         
     | 
| 521 | 
         
            +
                                          bitrate=mp4_bitrate)
         
     | 
| 522 | 
         
            +
             
     | 
| 523 | 
         
            +
            #----------------------------------------------------------------------------
         
     | 
| 524 | 
         
            +
             
     | 
| 525 | 
         
            +
            def _parse_num_range(s):
         
     | 
| 526 | 
         
            +
                '''Accept either a comma separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.'''
         
     | 
| 527 | 
         
            +
             
     | 
| 528 | 
         
            +
                range_re = re.compile(r'^(\d+)-(\d+)$')
         
     | 
| 529 | 
         
            +
                m = range_re.match(s)
         
     | 
| 530 | 
         
            +
                if m:
         
     | 
| 531 | 
         
            +
                    return range(int(m.group(1)), int(m.group(2))+1)
         
     | 
| 532 | 
         
            +
                vals = s.split(',')
         
     | 
| 533 | 
         
            +
                return [int(x) for x in vals]
         
     | 
| 534 | 
         
            +
             
     | 
| 535 | 
         
            +
            # My extended version of this helper function:
         
     | 
| 536 | 
         
            +
            def _parse_num_range_ext(s):
         
     | 
| 537 | 
         
            +
                '''
         
     | 
| 538 | 
         
            +
                Input:
         
     | 
| 539 | 
         
            +
                    s (str): Comma separated string of numbers 'a,b,c', a range 'a-c', or
         
     | 
| 540 | 
         
            +
                             even a combination of both 'a,b-c', 'a-b,c', 'a,b-c,d,e-f,...'
         
     | 
| 541 | 
         
            +
                Output:
         
     | 
| 542 | 
         
            +
                    nums (list): Ordered list of ascending ints in s, with repeating values
         
     | 
| 543 | 
         
            +
                                 deleted (can be modified to not do either of this)
         
     | 
| 544 | 
         
            +
                '''
         
     | 
| 545 | 
         
            +
                # Sanity check 0:
         
     | 
| 546 | 
         
            +
                # In case there's a space between the numbers (impossible due to argparse,
         
     | 
| 547 | 
         
            +
                # but hey, I am that paranoid):
         
     | 
| 548 | 
         
            +
                s = s.replace(' ', '')
         
     | 
| 549 | 
         
            +
                # Split w.r.t comma
         
     | 
| 550 | 
         
            +
                str_list = s.split(',')
         
     | 
| 551 | 
         
            +
                nums = []
         
     | 
| 552 | 
         
            +
                for el in str_list:
         
     | 
| 553 | 
         
            +
                    if '-' in el:
         
     | 
| 554 | 
         
            +
                        # The range will be 'a-b', so we wish to find both a and b using re:
         
     | 
| 555 | 
         
            +
                        range_re = re.compile(r'^(\d+)-(\d+)$')
         
     | 
| 556 | 
         
            +
                        match = range_re.match(el)
         
     | 
| 557 | 
         
            +
                        # We get the two numbers:
         
     | 
| 558 | 
         
            +
                        a = int(match.group(1))
         
     | 
| 559 | 
         
            +
                        b = int(match.group(2))
         
     | 
| 560 | 
         
            +
                        # Sanity check 1: accept 'a-b' or 'b-a', with a<=b:
         
     | 
| 561 | 
         
            +
                        if a <= b: r = [n for n in range(a, b + 1)]
         
     | 
| 562 | 
         
            +
                        else: r = [n for n in range(b, a + 1)]
         
     | 
| 563 | 
         
            +
                        # Use extend since r will also be an array:
         
     | 
| 564 | 
         
            +
                        nums.extend(r)
         
     | 
| 565 | 
         
            +
                    else:
         
     | 
| 566 | 
         
            +
                        # It's a single number, so just append it:
         
     | 
| 567 | 
         
            +
                        nums.append(int(el))
         
     | 
| 568 | 
         
            +
                # Sanity check 2: delete repeating numbers:
         
     | 
| 569 | 
         
            +
                nums = list(set(nums))
         
     | 
| 570 | 
         
            +
                # Return the numbers in ascending order:
         
     | 
| 571 | 
         
            +
                return sorted(nums)
         
     | 
| 572 | 
         
            +
             
     | 
| 573 | 
         
            +
            #----------------------------------------------------------------------------
         
     | 
| 574 | 
         
            +
             
     | 
| 575 | 
         
            +
            def _parse_npy_files(files):
         
     | 
| 576 | 
         
            +
                '''Accept a comma separated list of npy files and return a list of z vectors.'''
         
     | 
| 577 | 
         
            +
             
     | 
| 578 | 
         
            +
                zs =[]
         
     | 
| 579 | 
         
            +
             
     | 
| 580 | 
         
            +
                file_list = files.split(",")
         
     | 
| 581 | 
         
            +
             
     | 
| 582 | 
         
            +
             
     | 
| 583 | 
         
            +
                for f in file_list:
         
     | 
| 584 | 
         
            +
                    # load numpy array
         
     | 
| 585 | 
         
            +
                    arr = np.load(f)
         
     | 
| 586 | 
         
            +
                    # check if it's actually npz:
         
     | 
| 587 | 
         
            +
                    if 'dlatents' in arr:
         
     | 
| 588 | 
         
            +
                        arr = arr['dlatents']
         
     | 
| 589 | 
         
            +
                    zs.append(arr)
         
     | 
| 590 | 
         
            +
             
     | 
| 591 | 
         
            +
             
     | 
| 592 | 
         
            +
             
     | 
| 593 | 
         
            +
                return zs
         
     | 
| 594 | 
         
            +
             
     | 
| 595 | 
         
            +
            #----------------------------------------------------------------------------
         
     | 
| 596 | 
         
            +
             
     | 
| 597 | 
         
            +
# Epilog text shown by `--help`; argparse substitutes %(prog)s with the
# script name (requires RawDescriptionHelpFormatter to keep the layout).
_examples = '''examples:

  # Generate curated MetFaces images without truncation (Fig.10 left)
  python %(prog)s --outdir=out --trunc=1 --seeds=85,265,297,849 \\
      --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metfaces.pkl

  # Generate uncurated MetFaces images with truncation (Fig.12 upper left)
  python %(prog)s --outdir=out --trunc=0.7 --seeds=600-605 \\
      --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metfaces.pkl

  # Generate class conditional CIFAR-10 images (Fig.17 left, Car)
  python %(prog)s --outdir=out --trunc=1 --seeds=0-35 --class=1 \\
      --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/cifar10.pkl

  # Render image from projected latent vector
  python %(prog)s --outdir=out --dlatents=out/dlatents.npz \\
      --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/ffhq.pkl
'''
         
     | 
| 615 | 
         
            +
             
     | 
| 616 | 
         
            +
            #----------------------------------------------------------------------------
         
     | 
| 617 | 
         
            +
             
     | 
| 618 | 
         
            +
            def main():
         
     | 
| 619 | 
         
            +
                parser = argparse.ArgumentParser(
         
     | 
| 620 | 
         
            +
                    description='Generate images using pretrained network pickle.',
         
     | 
| 621 | 
         
            +
                    epilog=_examples,
         
     | 
| 622 | 
         
            +
                    formatter_class=argparse.RawDescriptionHelpFormatter
         
     | 
| 623 | 
         
            +
                )
         
     | 
| 624 | 
         
            +
             
     | 
| 625 | 
         
            +
                subparsers = parser.add_subparsers(help='Sub-commands', dest='command')
         
     | 
| 626 | 
         
            +
             
     | 
| 627 | 
         
            +
                parser_generate_images = subparsers.add_parser('generate-images', help='Generate images')
         
     | 
| 628 | 
         
            +
                parser_generate_images.add_argument('--network', help='Network pickle filename', dest='network_pkl', required=True)
         
     | 
| 629 | 
         
            +
                parser_generate_images.add_argument('--seeds', type=_parse_num_range, help='List of random seeds', dest='seeds', required=True)
         
     | 
| 630 | 
         
            +
                parser_generate_images.add_argument('--truncation-psi', type=float, help='Truncation psi (default: %(default)s)', dest='truncation_psi', default=0.5)
         
     | 
| 631 | 
         
            +
                parser_generate_images.add_argument('--class', dest='class_idx', type=int, help='Class label (default: unconditional)')
         
     | 
| 632 | 
         
            +
                parser_generate_images.add_argument('--create-grid', action='store_true', help='Add flag to save the generated images in a grid', dest='grid')
         
     | 
| 633 | 
         
            +
                parser_generate_images.add_argument('--outdir', help='Root directory for run results (default: %(default)s)', default='out', metavar='DIR')
         
     | 
| 634 | 
         
            +
                parser_generate_images.set_defaults(func=generate_images)
         
     | 
| 635 | 
         
            +
             
     | 
| 636 | 
         
            +
                parser_truncation_traversal = subparsers.add_parser('truncation-traversal', help='Generate truncation walk')
         
     | 
| 637 | 
         
            +
                parser_truncation_traversal.add_argument('--network', help='Network pickle filename', dest='network_pkl', required=True)
         
     | 
| 638 | 
         
            +
                parser_truncation_traversal.add_argument('--seed', type=_parse_num_range, help='Singular seed value')
         
     | 
| 639 | 
         
            +
                parser_truncation_traversal.add_argument('--npys', type=_parse_npy_files, help='List of .npy files')
         
     | 
| 640 | 
         
            +
                parser_truncation_traversal.add_argument('--fps', type=int, help='Starting value',default=24,dest='framerate')
         
     | 
| 641 | 
         
            +
                parser_truncation_traversal.add_argument('--start', type=float, help='Starting value')
         
     | 
| 642 | 
         
            +
                parser_truncation_traversal.add_argument('--stop', type=float, help='Stopping value')
         
     | 
| 643 | 
         
            +
                parser_truncation_traversal.add_argument('--increment', type=float, help='Incrementing value')
         
     | 
| 644 | 
         
            +
                parser_truncation_traversal.add_argument('--outdir', help='Root directory for run results (default: %(default)s)', default='out', metavar='DIR')
         
     | 
| 645 | 
         
            +
                parser_truncation_traversal.set_defaults(func=truncation_traversal)
         
     | 
| 646 | 
         
            +
             
     | 
| 647 | 
         
            +
                parser_generate_latent_walk = subparsers.add_parser('generate-latent-walk', help='Generate latent walk')
         
     | 
| 648 | 
         
            +
                parser_generate_latent_walk.add_argument('--network', help='Network pickle filename', dest='network_pkl', required=True)
         
     | 
| 649 | 
         
            +
                parser_generate_latent_walk.add_argument('--trunc', type=float, help='Truncation psi (default: %(default)s)', dest='truncation_psi', default=0.5)
         
     | 
| 650 | 
         
            +
                parser_generate_latent_walk.add_argument('--walk-type', help='Type of walk (default: %(default)s)', default='line')
         
     | 
| 651 | 
         
            +
                parser_generate_latent_walk.add_argument('--frames', type=int, help='Frame count (default: %(default)s', default=240)
         
     | 
| 652 | 
         
            +
                parser_generate_latent_walk.add_argument('--fps', type=int, help='Starting value',default=24,dest='framerate')
         
     | 
| 653 | 
         
            +
                parser_generate_latent_walk.add_argument('--seeds', type=_parse_num_range, help='List of random seeds')
         
     | 
| 654 | 
         
            +
                parser_generate_latent_walk.add_argument('--npys', type=_parse_npy_files, help='List of .npy files')
         
     | 
| 655 | 
         
            +
                parser_generate_latent_walk.add_argument('--save_vector', dest='save_vector', action='store_true', help='also save vector in .npy format')
         
     | 
| 656 | 
         
            +
                parser_generate_latent_walk.add_argument('--diameter', type=float, help='diameter of noise loop', default=2.0)
         
     | 
| 657 | 
         
            +
                parser_generate_latent_walk.add_argument('--start_seed', type=int, help='random seed to start noise loop from', default=0)
         
     | 
| 658 | 
         
            +
                parser_generate_latent_walk.add_argument('--outdir', help='Root directory for run results (default: %(default)s)', default='out', metavar='DIR')
         
     | 
| 659 | 
         
            +
                parser_generate_latent_walk.set_defaults(func=generate_latent_walk)
         
     | 
| 660 | 
         
            +
             
     | 
| 661 | 
         
            +
                parser_generate_neighbors = subparsers.add_parser('generate-neighbors', help='Generate random neighbors of a seed')
         
     | 
| 662 | 
         
            +
                parser_generate_neighbors.add_argument('--network', help='Network pickle filename', dest='network_pkl', required=True)
         
     | 
| 663 | 
         
            +
                parser_generate_neighbors.add_argument('--seeds', type=_parse_num_range, help='List of random seeds')
         
     | 
| 664 | 
         
            +
                parser_generate_neighbors.add_argument('--npys', type=_parse_npy_files, help='List of .npy files')
         
     | 
| 665 | 
         
            +
                parser_generate_neighbors.add_argument('--diameter', type=float, help='distance around seed to sample from', default=0.1)
         
     | 
| 666 | 
         
            +
                parser_generate_neighbors.add_argument('--save_vector', dest='save_vector', action='store_true', help='also save vector in .npy format')
         
     | 
| 667 | 
         
            +
                parser_generate_neighbors.add_argument('--num_samples', type=int, help='How many neighbors to generate (default: %(default)s', default=25)
         
     | 
| 668 | 
         
            +
                parser_generate_neighbors.add_argument('--trunc', type=float, help='Truncation psi (default: %(default)s)', dest='truncation_psi', default=0.5)
         
     | 
| 669 | 
         
            +
                parser_generate_neighbors.add_argument('--outdir', help='Root directory for run results (default: %(default)s)', default='out', metavar='DIR')
         
     | 
| 670 | 
         
            +
                parser_generate_neighbors.set_defaults(func=generate_neighbors)
         
     | 
| 671 | 
         
            +
             
     | 
| 672 | 
         
            +
                parser_lerp_video = subparsers.add_parser('lerp-video', help='Generate interpolation video (lerp) between random vectors')
         
     | 
| 673 | 
         
            +
                parser_lerp_video.add_argument('--network', help='Path to network pickle filename', dest='network_pkl', required=True)
         
     | 
| 674 | 
         
            +
                parser_lerp_video.add_argument('--seeds', type=_parse_num_range_ext, help='List of random seeds', dest='seeds', required=True)
         
     | 
| 675 | 
         
            +
                parser_lerp_video.add_argument('--grid-w', type=int, help='Video grid width/columns (default: %(default)s)', default=None, dest='grid_w')
         
     | 
| 676 | 
         
            +
                parser_lerp_video.add_argument('--grid-h', type=int, help='Video grid height/rows (default: %(default)s)', default=None, dest='grid_h')
         
     | 
| 677 | 
         
            +
                parser_lerp_video.add_argument('--trunc', type=float, help='Truncation psi (default: %(default)s)', default=1.0, dest='truncation_psi')
         
     | 
| 678 | 
         
            +
                parser_lerp_video.add_argument('--slowdown', type=int, help='Slowdown the video by this amount; must be a power of 2 (default: %(default)s)', default=1, dest='slowdown')
         
     | 
| 679 | 
         
            +
                parser_lerp_video.add_argument('--duration-sec', type=float, help='Duration of video (default: %(default)s)', default=30.0, dest='duration_sec')
         
     | 
| 680 | 
         
            +
                parser_lerp_video.add_argument('--fps', type=int, help='FPS of generated video (default: %(default)s)', default=30, dest='mp4_fps')
         
     | 
| 681 | 
         
            +
                parser_lerp_video.add_argument('--outdir', help='Root directory for run results (default: %(default)s)', default='out', metavar='DIR')
         
     | 
| 682 | 
         
            +
                parser_lerp_video.set_defaults(func=lerp_video)
         
     | 
| 683 | 
         
            +
             
     | 
| 684 | 
         
            +
                args = parser.parse_args()
         
     | 
| 685 | 
         
            +
                kwargs = vars(args)
         
     | 
| 686 | 
         
            +
                subcmd = kwargs.pop('command')
         
     | 
| 687 | 
         
            +
             
     | 
| 688 | 
         
            +
                if subcmd is None:
         
     | 
| 689 | 
         
            +
                    print('Error: missing subcommand.  Re-run with --help for usage.')
         
     | 
| 690 | 
         
            +
                    sys.exit(1)
         
     | 
| 691 | 
         
            +
             
     | 
| 692 | 
         
            +
                func = kwargs.pop('func')
         
     | 
| 693 | 
         
            +
                func(**kwargs)
         
     | 
| 694 | 
         
            +
             
     | 
| 695 | 
         
            +
            #----------------------------------------------------------------------------
         
     | 
| 696 | 
         
            +
             
     | 
| 697 | 
         
            +
            if __name__ == "__main__":
         
     | 
| 698 | 
         
            +
                main()
         
     | 
| 699 | 
         
            +
             
     | 
| 700 | 
         
            +
            #----------------------------------------------------------------------------
         
     | 
    	
        imgs/calligraphyv2.PNG
    ADDED
    
    | 
											 | 
									
								
											Git LFS Details
  | 
									
    	
        imgs/calligraphyv3.png
    ADDED
    
    
											 
									 | 
									
								
											Git LFS Details
  | 
									
    	
        imgs/calligraphyv4.png
    ADDED
    
    
											 
									 | 
									
								
											Git LFS Details
  | 
									
    	
        imgs/calligraphyv5.png
    ADDED
    
    
											 
									 | 
									
								
											Git LFS Details
  | 
									
    	
        imgs/mosaic.png
    ADDED
    
    
											 
									 | 
									
								
											Git LFS Details
  | 
									
    	
        imgs/mosaicsv2.png
    ADDED
    
    
											 
									 | 
									
								
											Git LFS Details
  | 
									
    	
        imgs/mosaicsv3.png
    ADDED
    
    
											 
									 | 
									
								
											Git LFS Details
  | 
									
    	
        imgs/mosaicsv4.png
    ADDED
    
    
											 
									 | 
									
								
											Git LFS Details
  | 
									
    	
        models.py
    ADDED
    
    | 
         @@ -0,0 +1,142 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            import argparse
         
     | 
| 2 | 
         
            +
            import copy
         
     | 
| 3 | 
         
            +
             
     | 
| 4 | 
         
            +
            import warnings
         
     | 
| 5 | 
         
            +
            import tensorflow as tf
         
     | 
| 6 | 
         
            +
            tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
         
     | 
| 7 | 
         
            +
            import warnings
         
     | 
| 8 | 
         
            +
            warnings.filterwarnings('ignore', category=FutureWarning)
         
     | 
| 9 | 
         
            +
            warnings.filterwarnings('ignore', category=DeprecationWarning)
         
     | 
| 10 | 
         
            +
            import sys, getopt, os
         
     | 
| 11 | 
         
            +
             
     | 
| 12 | 
         
            +
            import numpy as np
         
     | 
| 13 | 
         
            +
            import dnnlib
         
     | 
| 14 | 
         
            +
            from dnnlib import EasyDict
         
     | 
| 15 | 
         
            +
            import dnnlib.tflib as tflib
         
     | 
| 16 | 
         
            +
            from dnnlib.tflib import tfutil
         
     | 
| 17 | 
         
            +
            from dnnlib.tflib.autosummary import autosummary
         
     | 
| 18 | 
         
            +
             
     | 
| 19 | 
         
            +
            from training import misc
         
     | 
| 20 | 
         
            +
            import pickle
         
     | 
| 21 | 
         
            +
            import argparse
         
     | 
| 22 | 
         
            +
             
     | 
| 23 | 
         
            +
def create_model(config_id = 'config-f', gamma = None, height = 512, width = 512, cond = None, label_size = 0):
    """Build a StyleGAN2 configuration and submit a diagnostic run that
    creates an initial (untrained) network pickle.

    Mirrors the config-a..config-f option tables from the StyleGAN2
    training script, then calls ``dnnlib.submit_diagnostic`` with the
    assembled generator/discriminator arguments and the requested
    resolution / label size.  Returns the expected filename of the
    resulting initial pickle.
    """
    train     = EasyDict(run_func_name='training.diagnostic.create_initial_pkl') # Options for training loop.
    G         = EasyDict(func_name='training.networks_stylegan2.G_main')       # Options for generator network.
    D         = EasyDict(func_name='training.networks_stylegan2.D_stylegan2')  # Options for discriminator network.
    D_loss    = EasyDict(func_name='training.loss.D_logistic_r1')              # Options for discriminator loss.
    sched     = EasyDict()                                                     # Options for TrainingSchedule.
    sc        = dnnlib.SubmitConfig()                                          # Options for dnnlib.submit_run().
    tf_config = {'rnd.np_random_seed': 1000}                                   # Options for tflib.init_tf().

    # NOTE(review): D_loss, sched and dataset_args are configured below but
    # never passed to submit_diagnostic — presumably leftovers from the full
    # training config; verify before relying on them.
    sched.minibatch_size_base = 192
    sched.minibatch_gpu_base = 3
    D_loss.gamma = 10
    desc = 'stylegan2'

    dataset_args = EasyDict() # (tfrecord_dir=dataset)

    if cond:
        desc += '-cond'; dataset_args.max_label_size = 'full' # conditioned on full label

    desc += '-' + config_id

    # Configs A-E: Shrink networks to match original StyleGAN.
    if config_id != 'config-f':
        G.fmap_base = D.fmap_base = 8 << 10

    # Config E: Set gamma to 100 and override G & D architecture.
    if config_id.startswith('config-e'):
        D_loss.gamma = 100
        if 'Gorig'   in config_id: G.architecture = 'orig'
        if 'Gskip'   in config_id: G.architecture = 'skip' # (default)
        if 'Gresnet' in config_id: G.architecture = 'resnet'
        if 'Dorig'   in config_id: D.architecture = 'orig'
        if 'Dskip'   in config_id: D.architecture = 'skip'
        if 'Dresnet' in config_id: D.architecture = 'resnet' # (default)

    # Configs A-D: Enable progressive growing and switch to networks that support it.
    if config_id in ['config-a', 'config-b', 'config-c', 'config-d']:
        sched.lod_initial_resolution = 8
        sched.G_lrate_base = sched.D_lrate_base = 0.001
        sched.G_lrate_dict = sched.D_lrate_dict = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}
        sched.minibatch_size_base = 32 # (default)
        sched.minibatch_size_dict = {8: 256, 16: 128, 32: 64, 64: 32}
        sched.minibatch_gpu_base = 4 # (default)
        sched.minibatch_gpu_dict = {8: 32, 16: 16, 32: 8, 64: 4}
        G.synthesis_func = 'G_synthesis_stylegan_revised'
        D.func_name = 'training.networks_stylegan2.D_stylegan'

    # Configs A-C: Disable path length regularization.
    # NOTE(review): G_loss is assigned but never used afterwards — confirm
    # whether it was meant to be forwarded to the run.
    if config_id in ['config-a', 'config-b', 'config-c']:
        G_loss = EasyDict(func_name='training.loss.G_logistic_ns')

    # Configs A-B: Disable lazy regularization.
    if config_id in ['config-a', 'config-b']:
        train.lazy_regularization = False

    # Config A: Switch to original StyleGAN networks.
    if config_id == 'config-a':
        G = EasyDict(func_name='training.networks_stylegan.G_style')
        D = EasyDict(func_name='training.networks_stylegan.D_basic')

    if gamma is not None:
        D_loss.gamma = gamma

    # Non-square resolutions are passed as separate height/width options.
    G.update(resolution_h=height)
    G.update(resolution_w=width)
    D.update(resolution_h=height)
    D.update(resolution_w=width)

    sc.submit_target = dnnlib.SubmitTarget.DIAGNOSTIC
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)
    # [EDITED]
    kwargs.update(G_args=G, D_args=D, tf_config=tf_config, config_id=config_id,
        resolution_h=height, resolution_w=width, label_size = label_size)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_diagnostic(**kwargs)
    # NOTE(review): the returned filename hard-codes 'config-f' regardless of
    # config_id — presumably matches what create_initial_pkl writes; verify
    # for non-default configs.
    return f'network-initial-config-f-{height}x{width}-{label_size}.pkl'
         
     | 
| 101 | 
         
            +
             
     | 
| 102 | 
         
            +
            def _str_to_bool(v):
         
     | 
| 103 | 
         
            +
                if isinstance(v, bool):
         
     | 
| 104 | 
         
            +
                    return v
         
     | 
| 105 | 
         
            +
                if v.lower() in ('yes', 'true', 't', 'y', '1'):
         
     | 
| 106 | 
         
            +
                    return True
         
     | 
| 107 | 
         
            +
                elif v.lower() in ('no', 'false', 'f', 'n', '0'):
         
     | 
| 108 | 
         
            +
                    return False
         
     | 
| 109 | 
         
            +
                else:
         
     | 
| 110 | 
         
            +
                    raise argparse.ArgumentTypeError('Boolean value expected.')
         
     | 
| 111 | 
         
            +
             
     | 
| 112 | 
         
            +
            def _parse_comma_sep(s):
         
     | 
| 113 | 
         
            +
                if s is None or s.lower() == 'none' or s == '':
         
     | 
| 114 | 
         
            +
                    return []
         
     | 
| 115 | 
         
            +
                return s.split(',')
         
     | 
| 116 | 
         
            +
             
     | 
| 117 | 
         
            +
            def copy_weights(source_pkl, target_pkl, output_pkl):
         
     | 
| 118 | 
         
            +
             
     | 
| 119 | 
         
            +
                tflib.init_tf()
         
     | 
| 120 | 
         
            +
             
     | 
| 121 | 
         
            +
                with tf.Session():
         
     | 
| 122 | 
         
            +
                    with tf.device('/gpu:0'):
         
     | 
| 123 | 
         
            +
             
     | 
| 124 | 
         
            +
                        sourceG, sourceD, sourceGs = pickle.load(open(source_pkl, 'rb'))
         
     | 
| 125 | 
         
            +
                        targetG, targetD, targetGs = pickle.load(open(target_pkl, 'rb'))
         
     | 
| 126 | 
         
            +
                        
         
     | 
| 127 | 
         
            +
                        # print('Source:')
         
     | 
| 128 | 
         
            +
                        # sourceG.print_layers()
         
     | 
| 129 | 
         
            +
                        # sourceD.print_layers() 
         
     | 
| 130 | 
         
            +
                        # sourceGs.print_layers()
         
     | 
| 131 | 
         
            +
                        
         
     | 
| 132 | 
         
            +
                        # print('Target:')
         
     | 
| 133 | 
         
            +
                        # targetG.print_layers()
         
     | 
| 134 | 
         
            +
                        # targetD.print_layers() 
         
     | 
| 135 | 
         
            +
                        # targetGs.print_layers()
         
     | 
| 136 | 
         
            +
                        
         
     | 
| 137 | 
         
            +
                        targetG.copy_compatible_trainables_from(sourceG)
         
     | 
| 138 | 
         
            +
                        targetD.copy_compatible_trainables_from(sourceD)
         
     | 
| 139 | 
         
            +
                        targetGs.copy_compatible_trainables_from(sourceGs)
         
     | 
| 140 | 
         
            +
             
     | 
| 141 | 
         
            +
                        with open(os.path.join('./', output_pkl), 'wb') as file:
         
     | 
| 142 | 
         
            +
                            pickle.dump((targetG, targetD, targetGs), file, protocol=pickle.HIGHEST_PROTOCOL)
         
     | 
    	
        rasm.py
    ADDED
    
    | 
         @@ -0,0 +1,146 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
             
     | 
| 2 | 
         
            +
            from utils import download_url
         
     | 
| 3 | 
         
            +
            import argparse
         
     | 
| 4 | 
         
            +
            import numpy as np
         
     | 
| 5 | 
         
            +
            import PIL.Image
         
     | 
| 6 | 
         
            +
            import dnnlib
         
     | 
| 7 | 
         
            +
            import dnnlib.tflib as tflib
         
     | 
| 8 | 
         
            +
            import re
         
     | 
| 9 | 
         
            +
            import sys
         
     | 
| 10 | 
         
            +
            from io import BytesIO
         
     | 
| 11 | 
         
            +
            import IPython.display
         
     | 
| 12 | 
         
            +
            from math import ceil
         
     | 
| 13 | 
         
            +
            from PIL import Image, ImageDraw
         
     | 
| 14 | 
         
            +
            import os
         
     | 
| 15 | 
         
            +
            import pickle
         
     | 
| 16 | 
         
            +
            from utils import log_progress, imshow, create_image_grid, show_animation
         
     | 
| 17 | 
         
            +
            import imageio
         
     | 
| 18 | 
         
            +
            import glob
         
     | 
| 19 | 
         
            +
            import gdown 
         
     | 
| 20 | 
         
            +
             
     | 
| 21 | 
         
            +
            class Rasm:
         
     | 
| 22 | 
         
            +
             
     | 
| 23 | 
         
            +
                def __init__(self, mode = 'calligraphy'):
         
     | 
| 24 | 
         
            +
             
     | 
| 25 | 
         
            +
                    if mode == 'calligraphy':
         
     | 
| 26 | 
         
            +
                        url = 'https://drive.google.com/uc?id=138fdURGxdkOwZq7IWvnrGLcfo5VI8O1R'
         
     | 
| 27 | 
         
            +
             
     | 
| 28 | 
         
            +
                    else:
         
     | 
| 29 | 
         
            +
                        url = 'https://drive.google.com/uc?id=13h-alXGI0hbNOJy1qbmeoroXZSPBHEG2'
         
     | 
| 30 | 
         
            +
             
     | 
| 31 | 
         
            +
                    output = 'model.pkl'
         
     | 
| 32 | 
         
            +
                    print('Downloading networks from "%s"...' %url)
         
     | 
| 33 | 
         
            +
                    gdown.download(url, output, quiet=False)
         
     | 
| 34 | 
         
            +
                    dnnlib.tflib.init_tf()
         
     | 
| 35 | 
         
            +
                    with dnnlib.util.open_url(output) as fp:
         
     | 
| 36 | 
         
            +
                        self._G, self._D, self.Gs = pickle.load(fp)
         
     | 
| 37 | 
         
            +
                    self.noise_vars = [var for name, var in self.Gs.components.synthesis.vars.items() if name.startswith('noise')]
         
     | 
| 38 | 
         
            +
                
         
     | 
| 39 | 
         
            +
                # Generates a list of images, based on a list of latent vectors (Z), and a list (or a single constant) of truncation_psi's.
         
     | 
| 40 | 
         
            +
                def generate_images_in_w_space(self, dlatents, truncation_psi):
         
     | 
| 41 | 
         
            +
                    Gs_kwargs = dnnlib.EasyDict()
         
     | 
| 42 | 
         
            +
                    Gs_kwargs.output_transform = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
         
     | 
| 43 | 
         
            +
                    Gs_kwargs.randomize_noise = False
         
     | 
| 44 | 
         
            +
                    Gs_kwargs.truncation_psi = truncation_psi
         
     | 
| 45 | 
         
            +
                    # dlatent_avg = self.Gs.get_var('dlatent_avg') # [component]
         
     | 
| 46 | 
         
            +
             
     | 
| 47 | 
         
            +
                    imgs = []
         
     | 
| 48 | 
         
            +
                    for _, dlatent in log_progress(enumerate(dlatents), name = "Generating images"):
         
     | 
| 49 | 
         
            +
                        #row_dlatents = (dlatent[np.newaxis] - dlatent_avg) * np.reshape(truncation_psi, [-1, 1, 1]) + dlatent_avg
         
     | 
| 50 | 
         
            +
                        # dl = (dlatent-dlatent_avg)*truncation_psi   + dlatent_avg
         
     | 
| 51 | 
         
            +
                        row_images = self.Gs.components.synthesis.run(dlatent,  **Gs_kwargs)
         
     | 
| 52 | 
         
            +
                        imgs.append(PIL.Image.fromarray(row_images[0], 'RGB'))
         
     | 
| 53 | 
         
            +
                    return imgs       
         
     | 
| 54 | 
         
            +
             
     | 
| 55 | 
         
            +
                def generate_images(self, zs, truncation_psi, class_idx = None):
         
     | 
| 56 | 
         
            +
                    Gs_kwargs = dnnlib.EasyDict()
         
     | 
| 57 | 
         
            +
                    Gs_kwargs.output_transform = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
         
     | 
| 58 | 
         
            +
                    Gs_kwargs.randomize_noise = False
         
     | 
| 59 | 
         
            +
                    if not isinstance(truncation_psi, list):
         
     | 
| 60 | 
         
            +
                        truncation_psi = [truncation_psi] * len(zs)
         
     | 
| 61 | 
         
            +
                        
         
     | 
| 62 | 
         
            +
                    imgs = []
         
     | 
| 63 | 
         
            +
                    label = np.zeros([1] + self.Gs.input_shapes[1][1:])
         
     | 
| 64 | 
         
            +
                    if class_idx is not None:
         
     | 
| 65 | 
         
            +
                        label[:, class_idx] = 1
         
     | 
| 66 | 
         
            +
                    else:
         
     | 
| 67 | 
         
            +
                        label = None
         
     | 
| 68 | 
         
            +
                    for z_idx, z in log_progress(enumerate(zs), size = len(zs), name = "Generating images"):
         
     | 
| 69 | 
         
            +
                        Gs_kwargs.truncation_psi = truncation_psi[z_idx]
         
     | 
| 70 | 
         
            +
                        noise_rnd = np.random.RandomState(1) # fix noise
         
     | 
| 71 | 
         
            +
                        tflib.set_vars({var: noise_rnd.randn(*var.shape.as_list()) for var in self.noise_vars}) # [height, width]
         
     | 
| 72 | 
         
            +
                        images = self.Gs.run(z, label, **Gs_kwargs) # [minibatch, height, width, channel]
         
     | 
| 73 | 
         
            +
                        imgs.append(PIL.Image.fromarray(images[0], 'RGB'))
         
     | 
| 74 | 
         
            +
                    return imgs
         
     | 
| 75 | 
         
            +
                
         
     | 
| 76 | 
         
            +
                def generate_from_zs(self, zs, truncation_psi = 0.5):
         
     | 
| 77 | 
         
            +
                    Gs_kwargs = dnnlib.EasyDict()
         
     | 
| 78 | 
         
            +
                    Gs_kwargs.output_transform = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
         
     | 
| 79 | 
         
            +
                    Gs_kwargs.randomize_noise = False
         
     | 
| 80 | 
         
            +
                    if not isinstance(truncation_psi, list):
         
     | 
| 81 | 
         
            +
                        truncation_psi = [truncation_psi] * len(zs)
         
     | 
| 82 | 
         
            +
                        
         
     | 
| 83 | 
         
            +
                    for z_idx, z in log_progress(enumerate(zs), size = len(zs), name = "Generating images"):
         
     | 
| 84 | 
         
            +
                        Gs_kwargs.truncation_psi = truncation_psi[z_idx]
         
     | 
| 85 | 
         
            +
                        noise_rnd = np.random.RandomState(1) # fix noise
         
     | 
| 86 | 
         
            +
                        tflib.set_vars({var: noise_rnd.randn(*var.shape.as_list()) for var in self.noise_vars}) # [height, width]
         
     | 
| 87 | 
         
            +
                        images = self.Gs.run(z, None, **Gs_kwargs) # [minibatch, height, width, channel]
         
     | 
| 88 | 
         
            +
                        img = PIL.Image.fromarray(images[0], 'RGB')
         
     | 
| 89 | 
         
            +
                        imshow(img)
         
     | 
| 90 | 
         
            +
             
     | 
| 91 | 
         
            +
                def generate_random_zs(self, size):
         
     | 
| 92 | 
         
            +
                    seeds = np.random.randint(2**32, size=size)
         
     | 
| 93 | 
         
            +
                    zs = []
         
     | 
| 94 | 
         
            +
                    for _, seed in enumerate(seeds):
         
     | 
| 95 | 
         
            +
                        rnd = np.random.RandomState(seed)
         
     | 
| 96 | 
         
            +
                        z = rnd.randn(1, *self.Gs.input_shape[1:]) # [minibatch, component]
         
     | 
| 97 | 
         
            +
                        zs.append(z)
         
     | 
| 98 | 
         
            +
                    return zs
         
     | 
| 99 | 
         
            +
             
     | 
| 100 | 
         
            +
             
     | 
| 101 | 
         
            +
                def generate_zs_from_seeds(self, seeds):
         
     | 
| 102 | 
         
            +
                    zs = []
         
     | 
| 103 | 
         
            +
                    for _, seed in enumerate(seeds):
         
     | 
| 104 | 
         
            +
                        rnd = np.random.RandomState(seed)
         
     | 
| 105 | 
         
            +
                        z = rnd.randn(1, *self.Gs.input_shape[1:]) # [minibatch, component]
         
     | 
| 106 | 
         
            +
                        zs.append(z)
         
     | 
| 107 | 
         
            +
                    return zs
         
     | 
| 108 | 
         
            +
             
     | 
| 109 | 
         
            +
                # Generates a list of images, based on a list of seed for latent vectors (Z), and a list (or a single constant) of truncation_psi's.
         
     | 
| 110 | 
         
            +
                def generate_images_from_seeds(self, seeds, truncation_psi):
         
     | 
| 111 | 
         
            +
                    return imshow(self.generate_images(self.generate_zs_from_seeds(seeds), truncation_psi)[0])
         
     | 
| 112 | 
         
            +
             
     | 
| 113 | 
         
            +
                def generate_randomly(self, truncation_psi = 0.5):
         
     | 
| 114 | 
         
            +
                    return self.generate_images_from_seeds(np.random.randint(4294967295, size=1), truncation_psi=truncation_psi)
         
     | 
| 115 | 
         
            +
             
     | 
| 116 | 
         
            +
                def generate_grid(self, truncation_psi = 0.7): 
         
     | 
| 117 | 
         
            +
                  seeds = np.random.randint((2**32 - 1), size=9)
         
     | 
| 118 | 
         
            +
                  return create_image_grid(self.generate_images(self.generate_zs_from_seeds(seeds), truncation_psi), 0.7 , 3)
         
     | 
| 119 | 
         
            +
                
         
     | 
| 120 | 
         
            +
                def generate_animation(self, size = 9, steps = 10, trunc_psi = 0.5):
         
     | 
| 121 | 
         
            +
                  seeds = list(np.random.randint((2**32) - 1, size=size))
         
     | 
| 122 | 
         
            +
                  seeds = seeds + [seeds[0]]
         
     | 
| 123 | 
         
            +
                  zs = self.generate_zs_from_seeds(seeds)
         
     | 
| 124 | 
         
            +
             
     | 
| 125 | 
         
            +
                  imgs = self.generate_images(self.interpolate(zs, steps = steps), trunc_psi)
         
     | 
| 126 | 
         
            +
                  movie_name = 'animation.mp4'
         
     | 
| 127 | 
         
            +
                  with imageio.get_writer(movie_name, mode='I') as writer:
         
     | 
| 128 | 
         
            +
                    for image in log_progress(list(imgs), name = "Creating animation"):
         
     | 
| 129 | 
         
            +
                        writer.append_data(np.array(image))
         
     | 
| 130 | 
         
            +
                  return show_animation(movie_name)
         
     | 
| 131 | 
         
            +
             
     | 
| 132 | 
         
            +
                def convertZtoW(self, latent, truncation_psi=0.7, truncation_cutoff=9):
         
     | 
| 133 | 
         
            +
                    dlatent = self.Gs.components.mapping.run(latent, None) # [seed, layer, component]
         
     | 
| 134 | 
         
            +
                    dlatent_avg = self.Gs.get_var('dlatent_avg') # [component]
         
     | 
| 135 | 
         
            +
                    for i in range(truncation_cutoff):
         
     | 
| 136 | 
         
            +
                        dlatent[0][i] = (dlatent[0][i]-dlatent_avg)*truncation_psi + dlatent_avg
         
     | 
| 137 | 
         
            +
                        
         
     | 
| 138 | 
         
            +
                    return dlatent
         
     | 
| 139 | 
         
            +
             
     | 
| 140 | 
         
            +
                def interpolate(self, zs, steps = 10):
         
     | 
| 141 | 
         
            +
                    out = []
         
     | 
| 142 | 
         
            +
                    for i in range(len(zs)-1):
         
     | 
| 143 | 
         
            +
                        for index in range(steps):
         
     | 
| 144 | 
         
            +
                            fraction = index/float(steps) 
         
     | 
| 145 | 
         
            +
                            out.append(zs[i+1]*fraction + zs[i]*(1-fraction))
         
     | 
| 146 | 
         
            +
                    return out
         
     | 
    	
        requirements.txt
    ADDED
    
    | 
         @@ -0,0 +1,32 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            absl-py==0.7.0
         
     | 
| 2 | 
         
            +
            astor==0.7.1
         
     | 
| 3 | 
         
            +
            certifi==2018.11.29
         
     | 
| 4 | 
         
            +
            chardet==3.0.4
         
     | 
| 5 | 
         
            +
            Click==7.0
         
     | 
| 6 | 
         
            +
            Flask==1.0.2
         
     | 
| 7 | 
         
            +
            Flask-Cors==3.0.7
         
     | 
| 8 | 
         
            +
            gast==0.2.2
         
     | 
| 9 | 
         
            +
            gevent==1.4.0
         
     | 
| 10 | 
         
            +
            greenlet==0.4.15
         
     | 
| 11 | 
         
            +
            grpcio==1.19.0
         
     | 
| 12 | 
         
            +
            h5py==2.9.0
         
     | 
| 13 | 
         
            +
            idna==2.8
         
     | 
| 14 | 
         
            +
            itsdangerous==1.1.0
         
     | 
| 15 | 
         
            +
            Jinja2==2.10
         
     | 
| 16 | 
         
            +
            Keras-Applications==1.0.7
         
     | 
| 17 | 
         
            +
            Keras-Preprocessing==1.0.9
         
     | 
| 18 | 
         
            +
            Markdown==3.0.1
         
     | 
| 19 | 
         
            +
            MarkupSafe==1.1.1
         
     | 
| 20 | 
         
            +
            mock==2.0.0
         
     | 
| 21 | 
         
            +
            numpy==1.16.2
         
     | 
| 22 | 
         
            +
            pbr==5.1.2
         
     | 
| 23 | 
         
            +
            Pillow==5.4.1
         
     | 
| 24 | 
         
            +
            protobuf==3.6.1
         
     | 
| 25 | 
         
            +
            requests==2.21.0
         
     | 
| 26 | 
         
            +
            six==1.12.0
         
     | 
| 27 | 
         
            +
            tensorflow-gpu==1.15.0
         
     | 
| 28 | 
         
            +
            termcolor==1.1.0
         
     | 
| 29 | 
         
            +
            urllib3==1.24.1
         
     | 
| 30 | 
         
            +
            Werkzeug==0.14.1
         
     | 
| 31 | 
         
            +
            wget==3.2
         
     | 
| 32 | 
         
            +
            runway-python
         
     | 
    	
        utils.py
    ADDED
    
    | 
         @@ -0,0 +1,165 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            import glob
         
     | 
| 2 | 
         
            +
            import os 
         
     | 
| 3 | 
         
            +
            from PIL import Image
         
     | 
| 4 | 
         
            +
            import urllib.request
         
     | 
| 5 | 
         
            +
            from tqdm import tqdm
         
     | 
| 6 | 
         
            +
             
     | 
| 7 | 
         
            +
            import numpy as np
         
     | 
| 8 | 
         
            +
            import PIL.Image
         
     | 
| 9 | 
         
            +
            import sys
         
     | 
| 10 | 
         
            +
            from io import BytesIO
         
     | 
| 11 | 
         
            +
            import IPython.display
         
     | 
| 12 | 
         
            +
            import numpy as np
         
     | 
| 13 | 
         
            +
            from math import ceil
         
     | 
| 14 | 
         
            +
            from PIL import Image, ImageDraw
         
     | 
| 15 | 
         
            +
            import os
         
     | 
| 16 | 
         
            +
             
     | 
| 17 | 
         
            +
            from IPython.display import HTML
         
     | 
| 18 | 
         
            +
            from base64 import b64encode
         
     | 
| 19 | 
         
            +
            import imageio
         
     | 
| 20 | 
         
            +
             
     | 
| 21 | 
         
            +
            def show_animation(movie_name):
         
     | 
| 22 | 
         
            +
              mp4 = open(movie_name,'rb').read()
         
     | 
| 23 | 
         
            +
              data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
         
     | 
| 24 | 
         
            +
              return HTML("""
         
     | 
| 25 | 
         
            +
              <video width=400 controls>
         
     | 
| 26 | 
         
            +
                    <source src="%s" type="video/mp4">
         
     | 
| 27 | 
         
            +
              </video>
         
     | 
| 28 | 
         
            +
              """ % data_url)
         
     | 
| 29 | 
         
            +
             
     | 
| 30 | 
         
            +
            def imshow(a, format='png', jpeg_fallback=True):
         
     | 
| 31 | 
         
            +
                    a = np.asarray(a, dtype=np.uint8)
         
     | 
| 32 | 
         
            +
                    str_file = BytesIO()
         
     | 
| 33 | 
         
            +
                    PIL.Image.fromarray(a).save(str_file, format)
         
     | 
| 34 | 
         
            +
                    im_data = str_file.getvalue()
         
     | 
| 35 | 
         
            +
                    try:
         
     | 
| 36 | 
         
            +
                        disp = IPython.display.display(IPython.display.Image(im_data))
         
     | 
| 37 | 
         
            +
                    except IOError:
         
     | 
| 38 | 
         
            +
                        if jpeg_fallback and format != 'jpeg':
         
     | 
| 39 | 
         
            +
                            print ('Warning: image was too large to display in format "{}"; '
         
     | 
| 40 | 
         
            +
                                    'trying jpeg instead.').format(format)
         
     | 
| 41 | 
         
            +
                            return imshow(a, format='jpeg')
         
     | 
| 42 | 
         
            +
                        else:
         
     | 
| 43 | 
         
            +
                            raise
         
     | 
| 44 | 
         
            +
                    return disp
         
     | 
| 45 | 
         
            +
             
     | 
| 46 | 
         
            +
                    
         
     | 
| 47 | 
         
            +
            def clamp(x, minimum, maximum):
         
     | 
| 48 | 
         
            +
                return max(minimum, min(x, maximum))
         
     | 
| 49 | 
         
            +
             
     | 
| 50 | 
         
            +
            def create_image_grid(images, scale=0.25, rows=1):
         
     | 
| 51 | 
         
            +
                w,h = images[0].size
         
     | 
| 52 | 
         
            +
                w = int(w*scale)
         
     | 
| 53 | 
         
            +
                h = int(h*scale)
         
     | 
| 54 | 
         
            +
                height = rows*h
         
     | 
| 55 | 
         
            +
                cols = ceil(len(images) / rows)
         
     | 
| 56 | 
         
            +
                width = cols*w
         
     | 
| 57 | 
         
            +
                canvas = PIL.Image.new('RGBA', (width,height), 'white')
         
     | 
| 58 | 
         
            +
                for i,img in enumerate(images):
         
     | 
| 59 | 
         
            +
                    img = img.resize((w,h), PIL.Image.ANTIALIAS)
         
     | 
| 60 | 
         
            +
                    canvas.paste(img, (w*(i % cols), h*(i // cols))) 
         
     | 
| 61 | 
         
            +
                return canvas
         
     | 
| 62 | 
         
            +
             
     | 
| 63 | 
         
            +
            def find_latest_pkl(path):
         
     | 
| 64 | 
         
            +
              curr_best = 0
         
     | 
| 65 | 
         
            +
              latest_pkl = ''
         
     | 
| 66 | 
         
            +
              for pkl in glob.glob(f'{path}/*.pkl'):
         
     | 
| 67 | 
         
            +
                ckpt_number = int(pkl.split('-')[-1][:-4])
         
     | 
| 68 | 
         
            +
                if curr_best < ckpt_number:
         
     | 
| 69 | 
         
            +
                  curr_best = ckpt_number
         
     | 
| 70 | 
         
            +
                  latest_pkl = pkl
         
     | 
| 71 | 
         
            +
              return latest_pkl
         
     | 
| 72 | 
         
            +
             
     | 
| 73 | 
         
            +
            def resize(path, dim = (512, 512)):
         
     | 
| 74 | 
         
            +
              dirs = os.listdir(path)
         
     | 
| 75 | 
         
            +
              out_path = f'{path}/{dim[0]}x{dim[1]}'
         
     | 
| 76 | 
         
            +
              os.makedirs(out_path, exist_ok=True)
         
     | 
| 77 | 
         
            +
              for item in log_progress(dirs):
         
     | 
| 78 | 
         
            +
                img_path = f'{path}/{item}'
         
     | 
| 79 | 
         
            +
                if os.path.isfile(img_path):
         
     | 
| 80 | 
         
            +
                    im = Image.open(img_path)
         
     | 
| 81 | 
         
            +
                    imResize = im.resize(dim, Image.ANTIALIAS).convert('RGB')
         
     | 
| 82 | 
         
            +
                    imResize.save(f'{out_path}/{item}', 'JPEG', quality=90)
         
     | 
| 83 | 
         
            +
              return out_path
         
     | 
| 84 | 
         
            +
             
     | 
| 85 | 
         
            +
            def resize_dirs(path, out_dir, dim = (512, 512)):
         
     | 
| 86 | 
         
            +
              sub_dirs = os.listdir(path)
         
     | 
| 87 | 
         
            +
              for sub_dir in sub_dirs:
         
     | 
| 88 | 
         
            +
                out_path = f'{out_dir}/{sub_dir}'
         
     | 
| 89 | 
         
            +
                os.makedirs(out_path, exist_ok=True)
         
     | 
| 90 | 
         
            +
                for item in log_progress(os.listdir(f'{path}/{sub_dir}/')[:10]):
         
     | 
| 91 | 
         
            +
                    img_path = f'{path}/{sub_dir}/{item}'
         
     | 
| 92 | 
         
            +
                    if os.path.isfile(img_path):
         
     | 
| 93 | 
         
            +
                        im = Image.open(img_path)
         
     | 
| 94 | 
         
            +
                        imResize = im.resize(dim, Image.ANTIALIAS).convert('RGB')
         
     | 
| 95 | 
         
            +
                        imResize.save(f'{out_path}/{item}', 'JPEG', quality=90)
         
     | 
| 96 | 
         
            +
              return out_dir
         
     | 
| 97 | 
         
            +
             
     | 
| 98 | 
         
            +
            class DownloadProgressBar(tqdm):
         
     | 
| 99 | 
         
            +
                def update_to(self, b=1, bsize=1, tsize=None):
         
     | 
| 100 | 
         
            +
                    if tsize is not None:
         
     | 
| 101 | 
         
            +
                        self.total = tsize
         
     | 
| 102 | 
         
            +
                    self.update(b * bsize - self.n)
         
     | 
| 103 | 
         
            +
             
     | 
| 104 | 
         
            +
            # https://stackoverflow.com/a/53877507
         
     | 
| 105 | 
         
            +
            def download_url(url, output_path):
         
     | 
| 106 | 
         
            +
                with DownloadProgressBar(unit='B', unit_scale=True,
         
     | 
| 107 | 
         
            +
                                         miniters=1, desc=url.split('/')[-1]) as t:
         
     | 
| 108 | 
         
            +
                    urllib.request.urlretrieve(url, filename=output_path, reporthook=t.update_to)
         
     | 
| 109 | 
         
            +
             
     | 
| 110 | 
         
            +
            # Taken from https://github.com/alexanderkuk/log-progress
         
     | 
| 111 | 
         
            +
            def log_progress(sequence, every=1, size=None, name='Items'):
         
     | 
| 112 | 
         
            +
                from ipywidgets import IntProgress, HTML, VBox
         
     | 
| 113 | 
         
            +
                from IPython.display import display
         
     | 
| 114 | 
         
            +
             
     | 
| 115 | 
         
            +
                is_iterator = False
         
     | 
| 116 | 
         
            +
                if size is None:
         
     | 
| 117 | 
         
            +
                    try:
         
     | 
| 118 | 
         
            +
                        size = len(sequence)
         
     | 
| 119 | 
         
            +
                    except TypeError:
         
     | 
| 120 | 
         
            +
                        is_iterator = True
         
     | 
| 121 | 
         
            +
                if size is not None:
         
     | 
| 122 | 
         
            +
                    if every is None:
         
     | 
| 123 | 
         
            +
                        if size <= 200:
         
     | 
| 124 | 
         
            +
                            every = 1
         
     | 
| 125 | 
         
            +
                        else:
         
     | 
| 126 | 
         
            +
                            every = int(size / 200)     # every 0.5%
         
     | 
| 127 | 
         
            +
                else:
         
     | 
| 128 | 
         
            +
                    assert every is not None, 'sequence is iterator, set every'
         
     | 
| 129 | 
         
            +
             
     | 
| 130 | 
         
            +
                if is_iterator:
         
     | 
| 131 | 
         
            +
                    progress = IntProgress(min=0, max=1, value=1)
         
     | 
| 132 | 
         
            +
                    progress.bar_style = 'info'
         
     | 
| 133 | 
         
            +
                else:
         
     | 
| 134 | 
         
            +
                    progress = IntProgress(min=0, max=size, value=0)
         
     | 
| 135 | 
         
            +
                label = HTML()
         
     | 
| 136 | 
         
            +
                box = VBox(children=[label, progress])
         
     | 
| 137 | 
         
            +
                display(box)
         
     | 
| 138 | 
         
            +
             
     | 
| 139 | 
         
            +
                index = 0
         
     | 
| 140 | 
         
            +
                try:
         
     | 
| 141 | 
         
            +
                    for index, record in enumerate(sequence, 1):
         
     | 
| 142 | 
         
            +
                        if index == 1 or index % every == 0:
         
     | 
| 143 | 
         
            +
                            if is_iterator:
         
     | 
| 144 | 
         
            +
                                label.value = '{name}: {index} / ?'.format(
         
     | 
| 145 | 
         
            +
                                    name=name,
         
     | 
| 146 | 
         
            +
                                    index=index
         
     | 
| 147 | 
         
            +
                                )
         
     | 
| 148 | 
         
            +
                            else:
         
     | 
| 149 | 
         
            +
                                progress.value = index
         
     | 
| 150 | 
         
            +
                                label.value = u'{name}: {index} / {size}'.format(
         
     | 
| 151 | 
         
            +
                                    name=name,
         
     | 
| 152 | 
         
            +
                                    index=index,
         
     | 
| 153 | 
         
            +
                                    size=size
         
     | 
| 154 | 
         
            +
                                )
         
     | 
| 155 | 
         
            +
                        yield record
         
     | 
| 156 | 
         
            +
                except:
         
     | 
| 157 | 
         
            +
                    progress.bar_style = 'danger'
         
     | 
| 158 | 
         
            +
                    raise
         
     | 
| 159 | 
         
            +
                else:
         
     | 
| 160 | 
         
            +
                    progress.bar_style = 'success'
         
     | 
| 161 | 
         
            +
                    progress.value = index
         
     | 
| 162 | 
         
            +
                    label.value = "{name}: {index}".format(
         
     | 
| 163 | 
         
            +
                        name=name,
         
     | 
| 164 | 
         
            +
                        index=str(index or '?')
         
     | 
| 165 | 
         
            +
                    )
         
     | 
    	
        video.gif
    ADDED
    
    
											 
									 | 
									
								
											Git LFS Details
  |