import torch
from sklearn.mixture import GaussianMixture
def get_noise_sampler(sample_type='gau'):
    """Return a callable that draws noise with the same shape as a tensor.

    Args:
        sample_type: Name of the noise distribution.
            ``'gau'``        -- i.i.d. standard normal noise (``torch.randn_like``).
            ``'gau_offset'`` -- the sum of two independent ``torch.randn_like``
                                draws.
            ``'gmm'``        -- reserved; no sampler is implemented for it yet.

    Returns:
        A function ``sampler(latent)`` that returns a noise tensor with the
        same shape, dtype and device as ``latent``.

    Raises:
        NotImplementedError: If ``sample_type`` is ``'gmm'``.
        ValueError: If ``sample_type`` is not a recognised name.
    """

    def gaussian(latent):
        # Standard normal noise matching latent's shape, dtype and device.
        return torch.randn_like(latent)

    def gaussian_offset(latent):
        # NOTE(review): this adds two independent full-resolution N(0, 1)
        # draws, which is just Gaussian noise with variance 2 rather than a
        # low-frequency / per-channel offset. Confirm this is the intended
        # 'offset' behaviour before relying on it.
        return torch.randn_like(latent) + torch.randn_like(latent)

    if sample_type == 'gau':
        return gaussian
    if sample_type == 'gau_offset':
        return gaussian_offset
    if sample_type == 'gmm':
        # Fail loudly instead of handing back None: no GMM sampler exists here.
        raise NotImplementedError("sample_type='gmm' is not implemented yet")
    raise ValueError(
        f"Unknown sample_type {sample_type!r}; "
        "expected 'gau', 'gau_offset' or 'gmm'"
    )
# Script entry point: nothing is run when this module is executed directly.
if __name__ == "__main__":
    pass