Spaces:
Build error
Build error
| import flax | |
| import dill as pickle | |
| import os | |
| import builtins | |
| from jax._src.lib import xla_client | |
| import tensorflow as tf | |
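| # HACK: the bfloat16 type reports `builtins` as its module, so it is exposed | |
| # on `builtins` below, presumably so that dill/pickle can resolve it by name. | |
| # See https://github.com/google/jax/issues/8505 | |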
| builtins.bfloat16 = xla_client.bfloat16 | |
def pickle_dump(obj, filename):
    """Serialize ``obj`` and write the resulting bytes to ``filename``.

    The file is opened through ``tf.io.gfile.GFile`` in binary write mode and
    the object is serialized with the module-level ``pickle`` alias.

    Args:
        obj: Object to serialize.
        filename (str): Destination path.
    """
    with tf.io.gfile.GFile(filename, "wb") as sink:
        serialized = pickle.dumps(obj)
        sink.write(serialized)
def pickle_load(filename):
    """Read ``filename`` and return the object deserialized from its bytes.

    The file is opened through ``tf.io.gfile.GFile`` in binary read mode and
    its whole content is passed to the module-level ``pickle`` alias.

    Args:
        filename (str): Path of the pickled file.

    Returns:
        The deserialized object.
    """
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with tf.io.gfile.GFile(filename, 'rb') as source:
        raw_bytes = source.read()
        return pickle.loads(raw_bytes)
def save_checkpoint(ckpt_dir, state_G, state_D, params_ema_G, pl_mean, config, step, epoch, fid_score=None, keep=2):
    """
    Saves a checkpoint and prunes the oldest ones.

    The checkpoint is written to ``<ckpt_dir>/ckpt_<step>.pickle``. Afterwards,
    every file matching ``<ckpt_dir>/*.pickle`` counts towards ``keep``; if
    there are more than ``keep`` such files, the ones with the oldest
    modification time are removed until only ``keep`` remain.

    Args:
        ckpt_dir (str): Path to the directory, where checkpoints are saved.
        state_G (train_state.TrainState): Generator state (unreplicated before saving).
        state_D (train_state.TrainState): Discriminator state (unreplicated before saving).
        params_ema_G (frozen_dict.FrozenDict): Parameters of the ema generator (stored as passed).
        pl_mean (array): Moving average of the path length (generator regularization),
            unreplicated before saving.
        config (argparse.Namespace): Configuration.
        step (int): Current step.
        epoch (int): Current epoch.
        fid_score (float): FID score corresponding to the checkpoint.
        keep (int): Number of checkpoints to keep. A value <= 0 removes every
            ``*.pickle`` file in ``ckpt_dir``, including the one just written.
    """
    state_dict = {
        'state_G': flax.jax_utils.unreplicate(state_G),
        'state_D': flax.jax_utils.unreplicate(state_D),
        'params_ema_G': params_ema_G,
        'pl_mean': flax.jax_utils.unreplicate(pl_mean),
        'config': config,
        'fid_score': fid_score,
        'step': step,
        'epoch': epoch,
    }
    pickle_dump(state_dict, os.path.join(ckpt_dir, f'ckpt_{step}.pickle'))

    ckpts = tf.io.gfile.glob(os.path.join(ckpt_dir, '*.pickle'))
    num_excess = len(ckpts) - keep
    if num_excess > 0:
        # Remove *all* surplus checkpoints, oldest first, rather than a single
        # one, so the directory converges to `keep` files even if it started
        # out with more (e.g. after `keep` was lowered between runs).
        oldest_first = sorted(ckpts, key=lambda path: tf.io.gfile.stat(path).mtime_nsec)
        for stale_ckpt in oldest_first[:num_excess]:
            tf.io.gfile.remove(stale_ckpt)
def load_checkpoint(filename):
    """
    Loads a checkpoint from a pickled file.

    Args:
        filename (str): Path to the checkpoint file.

    Returns:
        (dict): Checkpoint, exactly as returned by ``pickle_load``.
    """
    return pickle_load(filename)
def get_latest_checkpoint(ckpt_dir):
    """
    Returns the path of the most recently modified checkpoint.

    Every file matching ``<ckpt_dir>/*.pickle`` is considered; the one with
    the largest modification time (``mtime_nsec``) is returned.

    Args:
        ckpt_dir (str): Path to the directory, where checkpoints are saved.

    Returns:
        (str): Path to latest checkpoint, or None if no ``*.pickle`` file exists.
    """
    pattern = os.path.join(ckpt_dir, '*.pickle')
    candidates = list(tf.io.gfile.glob(pattern))
    if not candidates:
        return None

    # Stable ascending sort by modification time; the newest file ends up last.
    candidates.sort(key=lambda path: tf.io.gfile.stat(path).mtime_nsec)
    return candidates[-1]