Spaces:
Runtime error
Runtime error
| # Copyright 2020 Erik Härkönen. All rights reserved. | |
| # This file is licensed to you under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. You may obtain a copy | |
| # of the License at http://www.apache.org/licenses/LICENSE-2.0 | |
| # Unless required by applicable law or agreed to in writing, software distributed under | |
| # the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS | |
| # OF ANY KIND, either express or implied. See the License for the specific language | |
| # governing permissions and limitations under the License. | |
| # An interactive glumpy (OpenGL) + tkinter viewer for interacting with principal components. | |
| # Requires OpenGL and CUDA support for rendering. | |
| import torch | |
| import numpy as np | |
| import tkinter as tk | |
| from tkinter import ttk | |
| from types import SimpleNamespace | |
| import matplotlib.pyplot as plt | |
| from pathlib import Path | |
| from os import makedirs | |
| from models import get_instrumented_model | |
| from config import Config | |
| from decomposition import get_or_compute | |
| from torch.nn.functional import interpolate | |
| from TkTorchWindow import TorchImageView | |
| from functools import partial | |
| from platform import system | |
| from PIL import Image | |
| from utils import pad_frames, prettify_name | |
| import pickle | |
# Platform flags, used for platform specific UI tweaks
is_windows, is_linux, is_mac = (
    tag in system() for tag in ('Windows', 'Linux', 'Darwin')
)

# Parse the input parameters into the shared config object
args = Config().from_args()

# The viewer renders on the GPU, so refuse to start without CUDA
assert torch.cuda.is_available(), 'Interactive mode requires CUDA'
| # Use syntax from paper | |
def get_edit_name(idx, s, e, name=None):
    """Return an edit label of the form ``E(idx, range): name``.

    The range is rendered as ``s-e`` when ``e > s`` and as just ``s``
    otherwise. The ``: name`` suffix is appended only for a truthy ``name``.
    """
    layers = f'{s}-{e}' if e > s else s
    suffix = f': {name}' if name else ''
    return f'E({idx}, {layers}){suffix}'
| # Load or compute PCA basis vectors | |
def load_components(class_name, inst):
    """Load (or compute) the decomposition for ``class_name`` and move it to the GPU.

    ``get_or_compute`` returns the path of an ``.npz`` dump; its activation
    (``act_*``) and latent (``lat_*``) components, means and stdevs are
    converted to float CUDA tensors and stored in the global ``components``.

    Side effects: rebinds the globals ``components`` and ``use_named_latents``
    and sets ``state.component_class`` (which invalidates the draw cache).
    """
    global components, state, use_named_latents

    config = args.from_dict({ 'output_class': class_name })
    dump_name = get_or_compute(config, inst)

    # The context manager closes the archive even if a key is missing,
    # which a trailing data.close() would not do.
    with np.load(dump_name, allow_pickle=False) as data:
        X_comp = data['act_comp']
        X_mean = data['act_mean']
        X_stdev = data['act_stdev']
        Z_comp = data['lat_comp']
        Z_mean = data['lat_mean']
        Z_stdev = data['lat_stdev']

    n_comp = X_comp.shape[0]

    def to_gpu(array):
        # Transfer to GPU as float32
        return torch.from_numpy(array).cuda().float()

    components = SimpleNamespace(
        X_comp = to_gpu(X_comp),
        X_mean = to_gpu(X_mean),
        X_stdev = to_gpu(X_stdev),
        Z_comp = to_gpu(Z_comp),
        Z_stdev = to_gpu(Z_stdev),
        Z_mean = to_gpu(Z_mean),
        names = [f'Component {i}' for i in range(n_comp)],
        latent_types = [model.latent_space_name()]*n_comp,
        ranges = [(0, model.get_max_latents())]*n_comp,
    )

    state.component_class = class_name # invalidates cache
    use_named_latents = False
    print('Loaded components for', class_name, 'from', dump_name)
| # Load previously exported named components from | |
| # directory specified with '--inputs=path/to/comp' | |
def load_named_components(path, class_name):
    """Load previously exported named directions (``*.pkl``) from ``path``.

    Only dumps whose ``model_name`` and ``output_class`` match the current
    model and ``class_name``, and whose ``latent_space`` matches the model's
    latent space, are used. The selected directions are stored in the global
    ``components`` and ``use_named_latents`` is set to True.

    Raises:
        RuntimeError: if no dump in ``path`` is usable.
    """
    global components, state, use_named_latents

    import glob
    # glob order depends on the file system; sort so that the slider order
    # is the same on every run.
    matches = sorted(glob.glob(f'{path}/*.pkl'))

    selected = []
    for dump_path in matches:
        # SECURITY: pickle.load executes arbitrary code embedded in the file;
        # only point '--inputs' at directories you trust.
        with open(dump_path, 'rb') as f:
            data = pickle.load(f)
        if data['model_name'] != model_name or data['output_class'] != class_name:
            continue

        if data['latent_space'] != model.latent_space_name():
            print('Skipping', dump_path, '(wrong latent space)')
            continue

        selected.append(data)
        print('Using', dump_path)

    if not selected:
        raise RuntimeError('No valid components in given path.')

    comp_dict = { k : [] for k in ['X_comp', 'Z_comp', 'X_stdev', 'Z_stdev', 'names', 'types', 'layer_names', 'ranges', 'latent_types'] }
    components = SimpleNamespace(**comp_dict)

    for d in selected:
        s = d['edit_start']
        e = d['edit_end']
        title = get_edit_name(d['component_index'], s, e - 1, d['name']) # show inclusive
        components.X_comp.append(torch.from_numpy(d['act_comp']).cuda())
        components.Z_comp.append(torch.from_numpy(d['lat_comp']).cuda())
        components.X_stdev.append(d['act_stdev'])
        components.Z_stdev.append(d['lat_stdev'])
        components.names.append(title)
        components.types.append(d['edit_type'])
        components.layer_names.append(d['decomposition']['layer']) # only for act
        components.ranges.append((s, e)) # stored with exclusive end
        components.latent_types.append(d['latent_space']) # W or Z

    use_named_latents = True
    print('Loaded named components')
def setup_model():
    """Instantiate the instrumented model on CUDA and load its components.

    Reads the model, layer and output class from ``args``, disables autograd,
    enables cuDNN benchmarking, and then loads either the named components
    (when ``args.inputs`` is set) or the computed decomposition.
    """
    global model, inst, layer_name, model_name, feat_shape, args, class_name

    model_name, layer_name, class_name = args.model, args.layer, args.output_class

    # Inference only: no gradients needed, let cuDNN pick the fastest kernels
    torch.autograd.set_grad_enabled(False)
    torch.backends.cudnn.benchmark = True

    # Load model
    device = torch.device('cuda')
    inst = get_instrumented_model(model_name, class_name, layer_name, device, use_w=args.use_w)
    model = inst.model

    feat_shape = inst.feature_shape[layer_name]
    sample_dims = np.prod(feat_shape)  # computed but not used below

    # Initialize the component set
    if not args.inputs:
        load_components(class_name, inst)
    else:
        load_named_components(args.inputs, class_name)
| # Project tensor 'X' onto orthonormal basis 'comp', return coordinates | |
def project_ortho(X, comp):
    """Return the coordinates of ``X`` along each basis vector in ``comp``.

    ``comp`` holds N basis vectors along its first dimension. Each is
    flattened and dotted with the flattened ``X``. The result has shape
    ``[N, 1, ..., 1]`` with one trailing singleton per dimension of ``X``.
    """
    n_basis = comp.shape[0]
    flat_basis = comp.reshape(n_basis, -1)
    flat_x = X.reshape(-1)
    coords = (flat_basis * flat_x).sum(dim=1)
    out_shape = [n_basis] + [1] * X.ndim
    return coords.reshape(out_shape)
def zero_sliders():
    """Set every component slider variable in ``ui_state`` to 0.0."""
    for slider_var in ui_state.sliders:
        slider_var.set(0.0)
def reset_sliders(zero_on_failure=True):
    """Reset the sliders, optionally zeroing them when projection is unavailable.

    Both slider offsets in ``state`` are always cleared. Projecting the
    current latent/activation onto the components is guarded by a local
    ``enabled`` flag that is hard-coded to False, so at present this function
    only zeroes the sliders (when ``zero_on_failure`` is true) and returns.
    """
    global ui_state

    mode = ui_state.mode.get()

    # Projection needs all of:
    #  - an orthogonal basis (otherwise a least-norm problem must be solved)
    #  - batch size 1 (one set of sliders cannot describe several latents)
    #  - principal components (named latents are an unsupported format)
    is_ortho = mode != 'latent' or model.latent_space_name() != 'Z'
    is_single = state.z.shape[0] == 1
    is_pcs = not use_named_latents

    state.lat_slider_offset = 0
    state.act_slider_offset = 0

    enabled = False
    can_project = enabled and is_ortho and is_single and is_pcs
    if not can_project:
        if zero_on_failure:
            zero_sliders()
        return

    if mode != 'activation':
        val, mean = state.z, components.Z_mean
        comp, stdev = components.Z_comp, components.Z_stdev
    else:
        val, mean = state.base_act, components.X_mean
        comp, stdev = components.X_comp, components.X_stdev

    n_sliders = len(ui_state.sliders)
    coords = project_ortho(val - mean, comp)
    # Part representable by sliders
    offset = torch.sum(coords[:n_sliders] * comp[:n_sliders], dim=0)
    scaled_coords = (coords.view(-1) / stdev).detach().cpu().numpy()

    if mode != 'activation':
        state.lat_slider_offset = offset
    else:
        state.act_slider_offset = offset

    for i, slider_var in enumerate(ui_state.sliders):
        slider_var.set(round(scaled_coords[i], ndigits=1))
def setup_ui():
    """Create the image window and the toolbar with all of its controls.

    Rebinds the globals ``root``, ``toolbar``, ``ui_state``, ``app`` and
    ``canvas``. The toolbar contains, top to bottom: the layer range sliders
    (only for computed components), a scrollable list of component sliders
    (at most 70), a reset button, truncation and batch size sliders, the
    class name and seed entries, and a resample button.
    """
    global root, toolbar, ui_state, app, canvas

    root = tk.Tk()
    window_scale = 1.0
    app = TorchImageView(root, width=int(window_scale*1024), height=int(window_scale*1024), show_fps=False)
    app.pack(fill=tk.BOTH, expand=tk.YES)
    root.protocol("WM_DELETE_WINDOW", shutdown)
    root.title('GANspace')

    toolbar = tk.Toplevel(root)
    toolbar.protocol("WM_DELETE_WINDOW", shutdown)
    toolbar.geometry("215x800+0+0")
    toolbar.title('')

    N_COMPONENTS = min(70, len(components.names))
    ui_state = SimpleNamespace(
        sliders = [tk.DoubleVar(value=0.0) for _ in range(N_COMPONENTS)],
        scales = [],
        truncation = tk.DoubleVar(value=0.9),
        outclass = tk.StringVar(value=class_name),
        random_seed = tk.StringVar(value='0'),
        mode = tk.StringVar(value='latent'),
        batch_size = tk.IntVar(value=1), # how many images to show in window
        edit_layer_start = tk.IntVar(value=0),
        edit_layer_end = tk.IntVar(value=model.get_max_latents() - 1),
        slider_max_val = 10.0
    )

    # Z vs activation mode button (currently disabled)
    #tk.Radiobutton(toolbar, text=f"Latent ({model.latent_space_name()})", variable=ui_state.mode, command=reset_sliders, value='latent').pack(fill="x")
    #tk.Radiobutton(toolbar, text="Activation", variable=ui_state.mode, command=reset_sliders, value='activation').pack(fill="x")

    # Choose range where latents are modified; the callbacks keep start <= end
    def set_min(val):
        ui_state.edit_layer_start.set(min(int(val), ui_state.edit_layer_end.get()))
    def set_max(val):
        ui_state.edit_layer_end.set(max(int(val), ui_state.edit_layer_start.get()))
    max_latent_idx = model.get_max_latents() - 1

    if not use_named_latents:
        # pack() returns None, so the widgets are deliberately not bound to names
        tk.Scale(toolbar, command=set_min, variable=ui_state.edit_layer_start,
            label='Layer start', from_=0, to=max_latent_idx, orient=tk.HORIZONTAL).pack(fill="x")
        tk.Scale(toolbar, command=set_max, variable=ui_state.edit_layer_end,
            label='Layer end', from_=0, to=max_latent_idx, orient=tk.HORIZONTAL).pack(fill="x")

    # Scrollable list of components
    outer_frame = tk.Frame(toolbar, borderwidth=2, relief=tk.SUNKEN)
    canvas = tk.Canvas(outer_frame, highlightthickness=0, borderwidth=0)
    frame = tk.Frame(canvas)
    vsb = tk.Scrollbar(outer_frame, orient="vertical", command=canvas.yview)
    canvas.configure(yscrollcommand=vsb.set)

    vsb.pack(side="right", fill="y")
    canvas.pack(side="left", fill="both", expand=True)
    canvas.create_window((4,4), window=frame, anchor="nw")

    def onCanvasConfigure(event):
        # Keep the embedded frame as wide as the canvas and refresh the scroll area
        canvas.itemconfigure("all", width=event.width)
        canvas.configure(scrollregion=canvas.bbox("all"))
    canvas.bind("<Configure>", onCanvasConfigure)

    def on_scroll(event):
        # <Button-5> or a negative wheel delta scrolls down, anything else up
        delta = 1 if (event.num == 5 or event.delta < 0) else -1
        canvas.yview_scroll(delta, "units")

    canvas.bind_all("<Button-4>", on_scroll)
    canvas.bind_all("<Button-5>", on_scroll)
    canvas.bind_all("<MouseWheel>", on_scroll)
    canvas.bind_all("<Key>", lambda event : handle_keypress(event.keysym_num))

    # Sliders and buttons
    for i in range(N_COMPONENTS):
        inner = tk.Frame(frame, borderwidth=1, background="#aaaaaa")
        slider = tk.Scale(inner, variable=ui_state.sliders[i], from_=-ui_state.slider_max_val,
            to=ui_state.slider_max_val, resolution=0.1, orient=tk.HORIZONTAL, label=components.names[i])
        slider.pack(fill=tk.X, side=tk.LEFT, expand=True)
        ui_state.scales.append(slider) # for changing label later
        if not use_named_latents:
            tk.Button(inner, text="Save", command=partial(export_direction, i, inner)).pack(fill=tk.Y, side=tk.RIGHT)
        inner.pack(fill=tk.X)

    outer_frame.pack(fill="both", expand=True, pady=0)

    tk.Button(toolbar, text="Reset", command=reset_sliders).pack(anchor=tk.CENTER, fill=tk.X, padx=4, pady=4)

    tk.Scale(toolbar, variable=ui_state.truncation, from_=0.01, to=1.0,
        resolution=0.01, orient=tk.HORIZONTAL, label='Truncation').pack(fill="x")

    tk.Scale(toolbar, variable=ui_state.batch_size, from_=1, to=9,
        resolution=1, orient=tk.HORIZONTAL, label='Batch size').pack(fill="x")

    # Output class
    frame = tk.Frame(toolbar)
    tk.Label(frame, text="Class name").pack(fill="x", side="left")
    tk.Entry(frame, textvariable=ui_state.outclass).pack(fill="x", side="right", expand=True, padx=5)
    frame.pack(fill=tk.X, pady=3)

    # Random seed
    def update_seed():
        # Non-numeric seed text is silently ignored
        seed_str = ui_state.random_seed.get()
        if seed_str.isdigit():
            resample_latent(int(seed_str))
    frame = tk.Frame(toolbar)
    tk.Label(frame, text="Seed").pack(fill="x", side="left")
    tk.Entry(frame, textvariable=ui_state.random_seed, width=12).pack(fill="x", side="left", expand=True, padx=2)
    tk.Button(frame, text="Update", command=update_seed).pack(fill="y", side="right", padx=3)
    frame.pack(fill=tk.X, pady=3)

    # Get new latent or new components
    tk.Button(toolbar, text="Resample latent", command=partial(resample_latent, None, False)).pack(anchor=tk.CENTER, fill=tk.X, padx=4, pady=4)
| #tk.Button(toolbar, text="Recompute", command=recompute_components).pack(anchor=tk.CENTER, fill=tk.X) | |
| # App state | |
state = SimpleNamespace(
    # current latent(s)
    z=None,
    # part of the latent that is explained by the sliders
    lat_slider_offset=0,
    # part of the activation that is explained by the sliders
    act_slider_offset=0,
    # name of the image class the current PCs belong to
    component_class=None,
    # latent z_i is generated from seed+i
    seed=0,
    # activation of the considered layer given z
    base_act=None,
)
def resample_latent(seed=None, only_style=False):
    """Draw new latent(s) for the current class, batch size and truncation.

    Args:
        seed: seed of the first latent; a random one is drawn when None.
            Batch element ``i`` uses ``seed + i``.
        only_style: accepted but not used by this function.

    Returns early, without changing anything, when the model exposes
    ``is_valid_class`` and rejects the class typed into the UI.
    """
    class_name = ui_state.outclass.get()
    if class_name.isnumeric():
        class_name = int(class_name)

    if hasattr(model, 'is_valid_class') and not model.is_valid_class(class_name):
        return

    model.set_output_class(class_name)

    B = ui_state.batch_size.get()
    if seed is None:
        state.seed = np.random.randint(np.iinfo(np.int32).max - B)
    else:
        state.seed = seed
    ui_state.random_seed.set(str(state.seed))

    # Use consecutive seeds along batch dimension (for easier reproducibility)
    trunc = ui_state.truncation.get()
    latents = []
    for i in range(B):
        latents.append(model.sample_latent(1, seed=state.seed + i, truncation=trunc))

    state.z = torch.cat(latents).clone().detach() # make leaf node
    assert state.z.is_leaf, 'Latent is not leaf node!'

    if hasattr(model, 'truncation'):
        model.truncation = ui_state.truncation.get()

    if B > 1:
        print(f'Seeds: {state.seed} -> {state.seed + B - 1}')
    else:
        print(f'Seed: {state.seed}')

    # Record the activation of the considered layer for the new latent(s)
    torch.manual_seed(state.seed)
    model.partial_forward(state.z, layer_name)
    state.base_act = inst.retained_features()[layer_name]

    reset_sliders(zero_on_failure=False)

    # Remove focus from text entry
    canvas.focus_set()
| # Used to recompute after changing class of conditional model | |
def recompute_components():
    """Reload the components for the class currently typed into the UI.

    Does nothing when the model exposes ``is_valid_class`` and rejects the
    class. The class is applied to the model only if it has
    ``set_output_class``.
    """
    class_name = ui_state.outclass.get()
    if class_name.isnumeric():
        class_name = int(class_name)

    if hasattr(model, 'is_valid_class') and not model.is_valid_class(class_name):
        return

    if hasattr(model, 'set_output_class'):
        model.set_output_class(class_name)

    load_components(class_name, inst)
| # Used to detect parameter changes for lazy recomputation | |
class ParamCache():
    """Remembers the last value of named parameters to detect changes.

    Used for lazy recomputation: callers pass the current parameters to
    ``update`` and only redo expensive work when it returns True.
    """

    def update(self, **kwargs):
        """Store the given values and report whether any of them changed.

        A parameter counts as changed when it has never been stored before,
        or when the new value is a different object whose pickled bytes
        differ from those of the stored value.

        Returns:
            True if at least one parameter changed, otherwise False.
        """
        stored = vars(self)
        dirty = False
        for argname, val in kwargs.items():
            # Test real membership instead of defaulting to 0, so a first
            # value of 0 is still stored and reported as a change.
            if argname in stored:
                current = stored[argname]
                # Check pointer, then value
                if current is val or pickle.dumps(current) == pickle.dumps(val):
                    continue
            setattr(self, argname, val)
            dirty = True
        return dirty


cache = ParamCache()
def l2norm(t):
    """Return the L2 norm of each item along dim 0, as a ``[N, 1]`` tensor."""
    per_item = t.view(t.shape[0], -1)
    return torch.norm(per_item, p=2, dim=1, keepdim=True)
def apply_edit(z0, delta):
    """Return ``z0`` with the edit ``delta`` added to it."""
    edited = z0 + delta
    return edited
def reposition_toolbar():
    """Dock the toolbar to the left edge of the main window, matching its height."""
    # Geometry strings have the form 'WxH+X+Y'
    size, X, Y = root.winfo_geometry().split('+')
    _width, H = size.split('x')
    toolbar_W = toolbar.winfo_geometry().split('x')[0]

    # On Linux, shift up by 30 px for the window title bar
    offset_y = -30 if is_linux else 0
    left = int(X) - int(toolbar_W)
    top = int(Y) + offset_y
    toolbar.geometry(f'{toolbar_W}x{H}+{left}+{top}')
    toolbar.update()
def on_draw():
    """Re-render the image if any input changed, then draw it.

    The model is only evaluated when ``cache.update`` reports a change in
    the slider values, component class, mode, latent or layer range.
    The last rendered image (global ``img``) is drawn on every call.
    """
    global img

    slider_vals = np.array([var.get() for var in ui_state.sliders], dtype=np.float32)

    mode = ui_state.mode.get()
    latent_start = ui_state.edit_layer_start.get()
    latent_end = ui_state.edit_layer_end.get() + 1 # save as exclusive, show as inclusive

    # Run model sparingly
    params_changed = cache.update(coords=slider_vals, comp=state.component_class,
        mode=mode, z=state.z, s=latent_start, e=latent_end)

    if params_changed:
        with torch.no_grad():
            z_base = state.z - state.lat_slider_offset
            z_deltas = [0.0]*model.get_max_latents()  # one offset per latent index
            z_delta_global = 0.0                      # offset applied to all of them

            act_deltas = {}
            if torch.is_tensor(state.act_slider_offset):
                act_deltas[layer_name] = -state.act_slider_offset

            for space in components.latent_types:
                assert space == model.latent_space_name(), \
                    'Cannot mix latent spaces (for now)'

            for c, coord in enumerate(slider_vals):
                if coord == 0:
                    continue

                edit_mode = components.types[c] if use_named_latents else mode

                # Activation offset
                if edit_mode in ('activation', 'both'):
                    delta = components.X_comp[c] * components.X_stdev[c] * coord
                    name = components.layer_names[c] if use_named_latents else layer_name
                    act_deltas[name] = act_deltas.get(name, 0.0) + delta

                # Latent offset
                if edit_mode in ('latent', 'both'):
                    delta = components.Z_comp[c] * components.Z_stdev[c] * coord
                    if use_named_latents:
                        edit_range = components.ranges[c]
                    else:
                        edit_range = (latent_start, latent_end)
                    full_range = (edit_range == (0, model.get_max_latents()))

                    # Single or multiple offsets?
                    if full_range:
                        z_delta_global = z_delta_global + delta
                    else:
                        for l in range(*edit_range):
                            z_deltas[l] = z_deltas[l] + delta

            # Apply activation deltas
            inst.remove_edits()
            for layer, delta in act_deltas.items():
                inst.edit_layer(layer, offset=delta)

            # Evaluate; a list of latents is passed only if some index has its own offset
            has_offsets = any(torch.is_tensor(t) for t in z_deltas)
            z_final = apply_edit(z_base, z_delta_global)
            if has_offsets:
                z_final = [apply_edit(z_final, d) for d in z_deltas]
            img = model.forward(z_final).clamp(0.0, 1.0)

    app.draw(img)
| # Save necessary data to disk for later loading | |
def export_direction(idx, button_frame):
    """Ask for a name and save direction ``idx`` (plus example images) to disk.

    Exactly one slider, the one at ``idx``, must be non-zero; its value is
    stored as the usable sigma range. A popup next to ``button_frame`` asks
    for the edit name. The parameters are pickled to
    ``out/directions/<ident>.pkl`` next to this file; example images go to
    ``out/directions/<ident>/``.

    Args:
        idx: index of the component/slider to export.
        button_frame: widget next to which the popup is placed.
    """
    name = tk.StringVar(value='')
    num_strips = tk.IntVar(value=0)  # no UI control at present, so stays 0
    strip_width = tk.IntVar(value=5)

    slider_values = np.array([s.get() for s in ui_state.sliders])
    slider_value = slider_values[idx]
    if (slider_values != 0).sum() > 1:
        print('Please modify only one slider')
        return
    elif slider_value == 0:
        print('Modify selected slider to set usable range (currently 0)')
        return

    popup = tk.Toplevel(root)
    popup.geometry("200x200+0+0")
    tk.Label(popup, text="Edit name").pack()
    tk.Entry(popup, textvariable=name).pack(pady=5)
    # tk.Scale(popup, from_=0, to=30, variable=num_strips,
    #    resolution=1, orient=tk.HORIZONTAL, length=200, label='Image strips to export').pack()
    # tk.Scale(popup, from_=3, to=15, variable=strip_width,
    #    resolution=1, orient=tk.HORIZONTAL, length=200, label='Image strip width').pack()
    tk.Button(popup, text='OK', command=popup.quit).pack()

    canceled = False
    def on_close():
        nonlocal canceled
        canceled = True
        popup.quit()

    popup.protocol("WM_DELETE_WINDOW", on_close)
    x = button_frame.winfo_rootx()
    y = button_frame.winfo_rooty()
    w = int(button_frame.winfo_geometry().split('x')[0])
    popup.geometry('%dx%d+%d+%d' % (180, 90, x + w, y))
    popup.mainloop()  # blocks until OK or window close calls popup.quit()
    popup.destroy()

    # Update slider name
    # NOTE(review): this happens even when the popup was canceled - confirm intended
    label = get_edit_name(idx, ui_state.edit_layer_start.get(),
        ui_state.edit_layer_end.get(), name.get())
    ui_state.scales[idx].config(label=label)

    if canceled:
        return

    params = {
        'name': name.get(),
        'sigma_range': slider_value,
        'component_index': idx,
        'act_comp': components.X_comp[idx].detach().cpu().numpy(),
        'lat_comp': components.Z_comp[idx].detach().cpu().numpy(), # either Z or W
        'latent_space': model.latent_space_name(),
        'act_stdev': components.X_stdev[idx].item(),
        'lat_stdev': components.Z_stdev[idx].item(),
        'model_name': model_name,
        'output_class': ui_state.outclass.get(), # applied onto
        'decomposition': {
            'name': args.estimator,
            'components': args.components,
            'samples': args.n,
            'layer': args.layer,
            'class_name': state.component_class # computed from
        },
        'edit_type': ui_state.mode.get(),
        'truncation': ui_state.truncation.get(),
        'edit_start': ui_state.edit_layer_start.get(),
        'edit_end': ui_state.edit_layer_end.get() + 1, # show as inclusive, save as exclusive
        'example_seed': state.seed,
    }

    edit_mode_str = params['edit_type']
    if edit_mode_str == 'latent':
        edit_mode_str = model.latent_space_name().lower()

    comp_class = state.component_class
    appl_class = params['output_class']
    if comp_class != appl_class:
        comp_class = f'{comp_class}_onto_{appl_class}'

    file_ident = "{model}-{name}-{cls}-{est}-{mode}-{layer}-comp{idx}-range{start}-{end}".format(
        model=model_name,
        name=prettify_name(params['name']),
        cls=comp_class,
        est=args.estimator,
        mode=edit_mode_str,
        layer=args.layer,
        idx=idx,
        start=params['edit_start'],
        end=params['edit_end'],
    )

    out_dir = Path(__file__).parent / 'out' / 'directions'
    makedirs(out_dir / file_ident, exist_ok=True)

    with open(out_dir / f"{file_ident}.pkl", 'wb') as outfile:
        pickle.dump(params, outfile)

    print(f'Direction "{name.get()}" saved as "{file_ident}.pkl"')

    # Round the number of examples up to a whole number of batches
    batch_size = ui_state.batch_size.get()
    len_padded = ((num_strips.get() - 1) // batch_size + 1) * batch_size
    orig_seed = state.seed

    reset_sliders()

    # Limit max resolution
    max_H = 512
    ratio = min(1.0, max_H / inst.output_shape[2])

    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
    # and is available in old and new Pillow releases.
    resample_filter = Image.LANCZOS

    # Loop invariant: identifier without dots for the per-image file names
    name_nodots = file_ident.replace('.', '_')

    strips = [[] for _ in range(len_padded)]
    for b in range(0, len_padded, batch_size):
        # Resample
        resample_latent((orig_seed + b) % np.iinfo(np.int32).max)

        sigmas = np.linspace(slider_value, -slider_value, strip_width.get(), dtype=np.float32)
        for sid, sigma in enumerate(sigmas):
            ui_state.sliders[idx].set(sigma)

            # Advance and show results on screen
            on_draw()
            root.update()
            app.update()

            batch_res = (255*img).byte().permute(0, 2, 3, 1).detach().cpu().numpy()

            for i, data in enumerate(batch_res):
                # Save individual
                outname = out_dir / file_ident / f"{name_nodots}_ex{b+i}_{sid}.png"
                im = Image.fromarray(data)
                im = im.resize((int(ratio*im.size[0]), int(ratio*im.size[1])), resample_filter)
                im.save(outname)
                strips[b+i].append(data)

    for i, strip in enumerate(strips[:num_strips.get()]):
        print(f'Saving strip {i + 1}/{num_strips.get()}', end='\r', flush=True)
        data = np.hstack(pad_frames(strip))
        im = Image.fromarray(data)
        im = im.resize((int(ratio*im.size[0]), int(ratio*im.size[1])), resample_filter)
        im.save(out_dir / file_ident / f"{file_ident}_ex{i}.png")

    # Reset to original state
    resample_latent(orig_seed)
    ui_state.sliders[idx].set(slider_value)
| # Shared by glumpy and tkinter | |
def handle_keypress(code):
    """Dispatch a key code: ESC requests shutdown, HOME resets the sliders.

    All other codes, including 114 ('R'), are ignored.
    """
    if code == 65307: # ESC
        shutdown()
        return
    if code == 65360: # HOME
        reset_sliders()
        return
    # 114 (R) used to reset the sliders; it is deliberately a no-op now
def shutdown():
    """Request application exit by raising the global ``pending_close`` flag."""
    global pending_close
    pending_close = True
def on_key_release(symbol, modifiers):
    """Key-release callback: forward ``symbol`` to ``handle_keypress``.

    ``modifiers`` is accepted but ignored.
    """
    handle_keypress(symbol)
if __name__ == '__main__':
    # Build the model, the UI and an initial latent
    setup_model()
    setup_ui()
    resample_latent()

    # Manual event loop; shutdown() sets pending_close to end it
    pending_close = False
    while not pending_close:
        root.update()
        app.update()
        on_draw()
        reposition_toolbar()

    root.destroy()