import torch
from torch import nn
import torch.nn.functional as F
import numpy as np


class TriangularCausalMask():
    def __init__(self, B, L, device="cpu"):
        mask_shape = [B, 1, L, L]
        with torch.no_grad():
            self._mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device)
    @property
    def mask(self):
        return self._mask


class ProbMask():
    def __init__(self, B, H, L, index, scores, device="cpu"):
        _mask = torch.ones(L, scores.shape[-1], dtype=torch.bool).to(device).triu(1)
        _mask_ex = _mask[None, None, :].expand(B, H, L, scores.shape[-1])
        indicator = _mask_ex[torch.arange(B)[:, None, None],
                             torch.arange(H)[None, :, None],
                             index, :].to(device)
        self._mask = indicator.view(scores.shape).to(device)
    @property
    def mask(self):
        return self._mask
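

# Illustrative usage (a sketch, not part of the original module): both mask
# classes are meant to be applied to raw attention scores before the softmax,
# as in Informer-style attention. `scores` below is a hypothetical
# [B, H, L, L] tensor of query-key dot products.
#
#     attn_mask = TriangularCausalMask(B, L, device=scores.device)
#     scores.masked_fill_(attn_mask.mask, -np.inf)  # [B, 1, L, L] broadcasts over heads
#     attn = torch.softmax(scores, dim=-1)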


def generate_continuous_mask(B, T, C=None, n=5, l=0.1):
    if C:
        res = torch.full((B, T, C), True, dtype=torch.bool)
    else:
        res = torch.full((B, T), True, dtype=torch.bool)
    if isinstance(n, float):
        n = int(n * T)
    n = max(min(n, T), 1)
    if isinstance(l, float):
        l = int(l * T)
    l = max(l, 1)
    for i in range(B):
        for _ in range(n):
            t = np.random.randint(T - l + 1)
            if C:
                # For a continuous span of timestamps, mask a random subset of channels
                num_channels_to_mask = np.random.randint(1, C + 1)  # how many channels to mask
                index = np.random.choice(C, num_channels_to_mask, replace=False)  # which channels to mask
                res[i, t:t + l, index] = False
            else:
                # For a continuous span of timestamps, mask all channels
                res[i, t:t + l] = False
    return res
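

# Illustrative shapes (a sketch; values are arbitrary): True means "keep",
# False means "masked". With C=None each of the n spans hides all channels;
# with C set, each span hides a random subset of channels.
#
#     m2 = generate_continuous_mask(B=2, T=100, n=5, l=0.1)  # [2, 100] bool
#     m3 = generate_continuous_mask(B=2, T=100, C=7)         # [2, 100, 7] bool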


def expand_tensor(input_tensor, third_dim_size):
    # Expand a 2-D [B, T] mask to a 3-D [B, T, third_dim_size] boolean tensor
    expanded_tensor = input_tensor.unsqueeze(2).expand(-1, -1, third_dim_size)
    return expanded_tensor.bool()


def geom_noise_mask_single(L, lm, masking_ratio):
    """
    Randomly create a boolean mask of length `L`, consisting of subsequences of average length `lm`, masking with 0s a
    `masking_ratio` proportion of the sequence. The lengths of the masked subsequences and of the intervals between
    them follow geometric distributions.
    Args:
        L: length of mask and sequence to be masked
        lm: average length of masking subsequences (streaks of 0s)
        masking_ratio: proportion of L to be masked
    Returns:
        (L,) boolean numpy array intended to mask ('drop') with 0s a sequence of length L
    """
    keep_mask = np.ones(L, dtype=bool)
    p_m = 1 / lm  # probability of each masking sequence stopping: parameter of the geometric distribution
    p_u = p_m * masking_ratio / (
            1 - masking_ratio)  # probability of each unmasked sequence stopping: parameter of the geometric distribution
    p = [p_m, p_u]
    # Start in state 0 (masking) with probability masking_ratio
    state = int(np.random.rand() > masking_ratio)  # state 0 means masking, state 1 means not masking
    for i in range(L):
        keep_mask[i] = state  # the state happens to coincide with the keep-mask value (0 = masked, 1 = kept)
        if np.random.rand() < p[state]:
            state = 1 - state
    return keep_mask
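

# Sanity check on the stopping probabilities above: a geometric streak with stop
# probability p has expected length 1/p, so masked streaks average lm = 1/p_m and
# unmasked streaks average 1/p_u = lm * (1 - r) / r for r = masking_ratio. The
# expected masked fraction is then lm / (lm + lm * (1 - r) / r) = r, as requested.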


def generate_geometric_mask(B, T, C=None, p=0.75, l=3):
    if C:
        mask = np.ones((B, T, C), dtype=bool)
    else:
        mask = np.ones((B, T), dtype=bool)
    for i in range(B):
        if C:
            for c in range(C):
                mask[i, :, c] = geom_noise_mask_single(T, l, p)
        else:
            mask[i, :] = geom_noise_mask_single(T, l, p)
    return torch.from_numpy(mask).to(torch.bool)


def generate_binomial_mask(B, T, C=None, p=0.5):
    if C:
        return torch.from_numpy(np.random.binomial(1, 1 - p, size=(B, T, C))).to(torch.bool)
    else:
        return torch.from_numpy(np.random.binomial(1, 1 - p, size=(B, T))).to(torch.bool)
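

# Note on conventions (added for clarity): in both generators above, `p` is the
# target masked fraction and the returned mask uses True = keep / False = masked;
# e.g. generate_binomial_mask keeps each position independently with probability 1 - p.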


# def mask_function(x, mask_type, p):
#     if mask_type == 'binomial':
#         mask = generate_binomial_mask(x.size(0), x.size(1), p=p)
#         mask = expand_tensor(mask, x.size(2)).to(x.device)
#     elif mask_type == 'channel_binomial':
#         mask = generate_binomial_mask(x.size(0), x.size(1), x.size(2), p=p).to(x.device)
#     elif mask_type == 'continuous':
#         mask = generate_continuous_mask(x.size(0), x.size(1), n=p).to(x.device)
#         mask = expand_tensor(mask, x.size(2)).to(x.device)
#     elif mask_type == 'channel_continuous':
#         mask = generate_continuous_mask(x.size(0), x.size(1), x.size(2), n=p).to(x.device)
#     elif mask_type == 'all_true':
#         mask = x.new_full((x.size(0), x.size(1)), True, dtype=torch.bool)
#     elif mask_type == 'all_false':
#         mask = x.new_full((x.size(0), x.size(1)), False, dtype=torch.bool)
#     elif mask_type == 'mask_last':
#         mask = x.new_full((x.size(0), x.size(1)), True, dtype=torch.bool)
#         idx = int(x.size(1) * (1 - p))
#         mask[:, idx:] = False
#         mask = expand_tensor(mask, x.size(2)).to(x.device)
#     else:
#         raise ValueError(f'\'{mask_type}\' is a wrong argument for mask function!')
#
#     x = mask * x
#
#     return x, mask


# def mask_function(x, mask_type, mask_ratio, lm):
#     if mask_type == 'binomial':
#         mask = generate_binomial_mask(x.size(0), x.size(1), p=mask_ratio).to(x.device)
#         mask = expand_tensor(mask, x.shape[-1])
#     elif mask_type == 'channel_binomial':
#         mask = generate_binomial_mask(x.size(0), x.size(1), x.size(2), p=mask_ratio).to(x.device)
#     elif mask_type == 'continuous':
#         mask = generate_geometric_mask(x.size(0), x.size(1), p=mask_ratio, l=lm).to(x.device)
#         mask = expand_tensor(mask, x.shape[-1])
#     elif mask_type == 'channel_continuous':
#         mask = generate_geometric_mask(x.size(0), x.size(1), x.size(2), p=mask_ratio, l=lm).to(x.device)
#     elif mask_type == 'all_true':
#         mask = x.new_full((x.size(0), x.size(1)), True, dtype=torch.bool)
#     elif mask_type == 'all_false':
#         mask = x.new_full((x.size(0), x.size(1)), False, dtype=torch.bool)
#     elif mask_type == 'mask_last':
#         mask = x.new_full((x.size(0), x.size(1), x.size(2)), True, dtype=torch.bool)
#         idx = int(x.shape[1] * mask_ratio)
#         mask[:, -idx:, :] = False
#     else:
#         raise ValueError(f'\'{mask_type}\' is a wrong argument for mask function!')
#
#     x = mask * x
#
#     return x, mask  # x: [b, s, c]; mask: True = unmasked, False = masked


def patch_mask(x, mask_ratio, patch_len=12, stride=12):
    px = x.clone().permute(0, 2, 1)
    padding_patch_layer = nn.ReplicationPad1d((0, stride))
    px = padding_patch_layer(px)
    px = px.unfold(dimension=-1, size=patch_len, step=stride)
    px = torch.reshape(px, (px.shape[0] * px.shape[1], px.shape[2], px.shape[3]))
    mask = generate_binomial_mask(px.size(0), px.size(1), p=mask_ratio).to(x.device)
    return mask
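

# Shape walk-through (illustrative, for the defaults patch_len=12, stride=12):
# x [b, s, c] -> permute -> [b, c, s] -> pad -> [b, c, s + stride]
# -> unfold -> [b, c, num_patch, patch_len] -> reshape -> [b * c, num_patch, patch_len].
# The returned mask is [b * c, num_patch] with True = keep the whole patch.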


def mask_function(x, args):
    b, s, c = x.shape
    if args.masked_rule == 'binomial':
        mask = generate_binomial_mask(x.size(0), x.size(1), p=args.mask_rate).to(x.device)
        mask = expand_tensor(mask, x.shape[-1])
    elif args.masked_rule == 'channel_binomial':
        mask = generate_binomial_mask(x.size(0), x.size(1), x.size(2), p=args.mask_rate).to(x.device)
    elif args.masked_rule == 'continuous':
        mask = generate_geometric_mask(x.size(0), x.size(1), p=args.mask_rate, l=args.lm).to(x.device)
        mask = expand_tensor(mask, x.shape[-1])
    elif args.masked_rule == 'channel_continuous':
        mask = generate_geometric_mask(x.size(0), x.size(1), x.size(2), p=args.mask_rate, l=args.lm).to(x.device)
    elif args.masked_rule == 'mask_last':
        mask = x.new_full((x.size(0), x.size(1), x.size(2)), True, dtype=torch.bool)
        idx = int(x.shape[1] * args.mask_rate)
        mask[:, s - idx:, :] = False  # use s - idx rather than -idx so idx == 0 masks nothing
    elif args.masked_rule == 'mask_patch':
        mask = patch_mask(x, args.mask_rate, args.patch_len, args.stride)
        mask = expand_tensor(mask, args.patch_len)
        mask = mask.reshape(b, c, -1)[:, :, :s].permute(0, 2, 1)
    else:
        raise ValueError(f'\'{args.masked_rule}\' is a wrong argument for mask function!')
    x = mask * x
    return x, mask  # x: [b, s, c]; mask: True = unmasked, False = masked
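

# Illustrative call (a sketch; `SimpleNamespace` stands in for the project's real
# config object, and only the attributes this function reads are populated):
#
#     from types import SimpleNamespace
#     args = SimpleNamespace(masked_rule='binomial', mask_rate=0.25,
#                            lm=3, patch_len=12, stride=12)
#     x = torch.randn(8, 96, 7)                # [b, s, c]
#     x_masked, mask = mask_function(x, args)  # mask: [8, 96, 7] bool, ~25% False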


# MAE Masking
def random_masking(xb, mask_ratio=0.75):
    bs, L, nvars, D = xb.shape  # xb: [bs x num_patch x n_vars x patch_len]
    x = xb.clone()
    len_keep = int(L * (1 - mask_ratio))
    noise = torch.rand(bs, L, nvars, device=xb.device)  # noise in [0, 1], bs x L x nvars
    # sort noise for each sample
    ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
    ids_restore = torch.argsort(ids_shuffle, dim=1)  # ids_restore: [bs x L x nvars]
    # keep the first subset
    ids_keep = ids_shuffle[:, :len_keep, :]  # ids_keep: [bs x len_keep x nvars]
    x_kept = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, 1, D))  # x_kept: [bs x len_keep x nvars x patch_len]
    # removed x
    x_removed = torch.zeros(bs, L - len_keep, nvars, D, device=xb.device)  # x_removed: [bs x (L-len_keep) x nvars x patch_len]
    x_ = torch.cat([x_kept, x_removed], dim=1)  # x_: [bs x L x nvars x patch_len]
    # combine the kept part and the removed one
    x_masked = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, 1, D))  # x_masked: [bs x num_patch x nvars x patch_len]
    # generate the binary mask: 0 is keep, 1 is remove
    mask = torch.ones([bs, L, nvars], device=x.device)  # mask: [bs x num_patch x nvars]
    mask[:, :len_keep, :] = 0
    # unshuffle to get the binary mask
    mask = torch.gather(mask, dim=1, index=ids_restore)  # [bs x num_patch x nvars]
    mask = mask.permute(0, 2, 1)
    mask = mask.reshape(-1, L)  # [bs * nvars x num_patch]
    return x_masked, x_kept, mask, ids_restore


# MAE Masking
def random_masking_v2(xb, mask_ratio=0.75):
    bs, L, D = xb.shape  # xb: [bs x L x d_model] (here L plays the role of n_vars)
    x = xb.clone()
    len_keep = int(L * (1 - mask_ratio))
    noise = torch.rand(bs, L, device=xb.device)  # noise in [0, 1], bs x L
    # sort noise for each sample
    ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
    ids_restore = torch.argsort(ids_shuffle, dim=1)  # ids_restore: [bs x L]
    # keep the first subset
    ids_keep = ids_shuffle[:, :len_keep]  # ids_keep: [bs x len_keep]
    x_kept = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))  # x_kept: [bs x len_keep x d_model]
    # removed x
    x_removed = torch.zeros(bs, L - len_keep, D, device=xb.device)  # x_removed: [bs x (L-len_keep) x d_model]
    x_ = torch.cat([x_kept, x_removed], dim=1)  # x_: [bs x L x d_model]
    # combine the kept part and the removed one
    x_masked = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, D))  # x_masked: [bs x L x d_model]
    # generate the binary mask: 0 is keep, 1 is remove
    mask = torch.ones([bs, L], device=x.device)  # mask: [bs x L]
    mask[:, :len_keep] = 0
    # unshuffle to get the binary mask
    mask = torch.gather(mask, dim=1, index=ids_restore)  # [bs x L]
    return x_masked, x_kept, mask, ids_restore
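

# Minimal smoke test (illustrative; guarded so that importing this module stays
# side-effect free):
if __name__ == "__main__":
    xb = torch.randn(4, 16, 7, 12)  # [bs x num_patch x n_vars x patch_len]
    x_masked, x_kept, mask, ids_restore = random_masking(xb, mask_ratio=0.75)
    # Expected: (4, 16, 7, 12) (4, 4, 7, 12) (28, 16)
    print(x_masked.shape, x_kept.shape, mask.shape)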