# utils/masking.py

import numpy as np
import torch
from torch import nn

class TriangularCausalMask:
    """Boolean causal mask: True strictly above the diagonal marks the
    future positions an attention score matrix must not attend to."""

    def __init__(self, B, L, device="cpu"):
        mask_shape = [B, 1, L, L]
        with torch.no_grad():
            self._mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device)

    @property
    def mask(self):
        return self._mask

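# A minimal usage sketch (the attention code below is an assumption about the
# caller, not part of this module): True entries are the positions to block.
#
#   scores = torch.einsum("bhld,bhsd->bhls", q, k)          # [B, H, L, L]
#   attn_mask = TriangularCausalMask(B, L, device=scores.device)
#   scores = scores.masked_fill(attn_mask.mask, float("-inf"))
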
class ProbMask:
    """Causal mask for ProbSparse-style attention (as in Informer): the full
    triangular mask is built once, then only the rows of the top-u queries
    selected by `index` are gathered so the result matches `scores`."""

    def __init__(self, B, H, L, index, scores, device="cpu"):
        _mask = torch.ones(L, scores.shape[-1], dtype=torch.bool).to(device).triu(1)
        _mask_ex = _mask[None, None, :].expand(B, H, L, scores.shape[-1])
        indicator = _mask_ex[torch.arange(B)[:, None, None],
                             torch.arange(H)[None, :, None],
                             index, :].to(device)
        self._mask = indicator.view(scores.shape).to(device)

    @property
    def mask(self):
        return self._mask

def generate_continuous_mask(B, T, C=None, n=5, l=0.1):
    if C:
        res = torch.full((B, T, C), True, dtype=torch.bool)
    else:
        res = torch.full((B, T), True, dtype=torch.bool)
    if isinstance(n, float):
        n = int(n * T)
    n = max(min(n, T), 1)
    if isinstance(l, float):
        l = int(l * T)
    l = max(l, 1)
    for i in range(B):
        for _ in range(n):
            t = np.random.randint(T - l + 1)
            if C:
                # For a continuous span of timestamps, mask a random subset of channels.
                num_channels_to_mask = np.random.randint(1, C + 1)  # how many channels to mask
                index = np.random.choice(C, num_channels_to_mask, replace=False)  # which channels to mask
                res[i, t:t + l, index] = False
            else:
                # For a continuous span of timestamps, mask all channels.
                res[i, t:t + l] = False
    return res

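# Usage sketch (hypothetical shapes): each call draws `n` spans of length `l`
# per sample; with C given, only a random subset of channels is masked per span.
#
#   m = generate_continuous_mask(B=32, T=96, C=7, n=5, l=0.1)  # [32, 96, 7] bool
#   m = generate_continuous_mask(B=32, T=96, n=0.05, l=12)     # [32, 96] bool
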
def expand_tensor(input_tensor, third_dim_size):
    # Expand a [B, T] mask to a [B, T, third_dim_size] boolean tensor so it
    # broadcasts over the channel dimension.
    expanded_tensor = input_tensor.unsqueeze(2).expand(-1, -1, third_dim_size)
    return expanded_tensor.bool()

def geom_noise_mask_single(L, lm, masking_ratio):
    """
    Randomly create a boolean mask of length `L`, consisting of subsequences of average length `lm`,
    masking (with 0s) a `masking_ratio` proportion of the sequence. The lengths of the masked
    subsequences and of the unmasked intervals between them follow geometric distributions.
    Args:
        L: length of mask and sequence to be masked
        lm: average length of masking subsequences (streaks of 0s)
        masking_ratio: proportion of L to be masked
    Returns:
        (L,) boolean numpy array intended to mask ('drop') with 0s a sequence of length L
    """
    keep_mask = np.ones(L, dtype=bool)
    p_m = 1 / lm  # probability of each masked subsequence stopping; parameter of the geometric distribution
    # Setting p_u = p_m * r / (1 - r) gives unmasked runs an expected length of
    # lm * (1 - r) / r, so masked time covers lm / (lm + lm * (1 - r) / r) = r
    # of the sequence on average.
    p_u = p_m * masking_ratio / (1 - masking_ratio)  # probability of each unmasked subsequence stopping
    p = [p_m, p_u]
    # Start in state 0 (masking) with probability masking_ratio
    state = int(np.random.rand() > masking_ratio)  # state 0 means masking, 1 means not masking
    for i in range(L):
        keep_mask[i] = state  # the state index coincides with the mask value it produces
        if np.random.rand() < p[state]:
            state = 1 - state
    return keep_mask

def generate_geometric_mask(B, T, C=None, p=0.75, l=3):
    # `p` is the target masking ratio; `l` is the average masked run length.
    if C:
        mask = np.ones((B, T, C), dtype=bool)
    else:
        mask = np.ones((B, T), dtype=bool)
    for i in range(B):
        if C:
            for c in range(C):
                mask[i, :, c] = geom_noise_mask_single(T, l, p)
        else:
            mask[i, :] = geom_noise_mask_single(T, l, p)
    return torch.from_numpy(mask).to(torch.bool)

def generate_binomial_mask(B, T, C=None, p=0.5):
    # Keep each position independently with probability 1 - p (True = keep).
    if C:
        return torch.from_numpy(np.random.binomial(1, 1 - p, size=(B, T, C))).to(torch.bool)
    else:
        return torch.from_numpy(np.random.binomial(1, 1 - p, size=(B, T))).to(torch.bool)

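# Usage sketch (hypothetical shapes): both generators return boolean masks in
# which True means "keep". The geometric variant masks contiguous runs of
# average length `l`; the binomial variant masks positions i.i.d. with prob `p`.
#
#   gm = generate_geometric_mask(B=32, T=96, p=0.25, l=3)   # [32, 96] bool
#   bm = generate_binomial_mask(B=32, T=96, C=7, p=0.25)    # [32, 96, 7] bool
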
def patch_mask(x, mask_ratio, patch_len=12, stride=12):
    # Split x: [b, s, c] into patches and draw one binomial mask per patch.
    px = x.clone().permute(0, 2, 1)  # [b, c, s]
    padding_patch_layer = nn.ReplicationPad1d((0, stride))
    px = padding_patch_layer(px)
    px = px.unfold(dimension=-1, size=patch_len, step=stride)  # [b, c, num_patch, patch_len]
    px = torch.reshape(px, (px.shape[0] * px.shape[1], px.shape[2], px.shape[3]))  # [b * c, num_patch, patch_len]
    mask = generate_binomial_mask(px.size(0), px.size(1), p=mask_ratio).to(x.device)  # [b * c, num_patch]
    return mask

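# Usage sketch (hypothetical shapes): the mask is per patch, one row per
# flattened (sample, channel) pair; the 'mask_patch' branch of mask_function
# below expands it back to [b, s, c].
#
#   x = torch.randn(32, 96, 7)          # [b, s, c]
#   pm = patch_mask(x, mask_ratio=0.4)  # [32 * 7, 9] bool (9 patches of length 12)
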
def mask_function(x, args):
    b, s, c = x.shape
    if args.masked_rule == 'binomial':
        mask = generate_binomial_mask(x.size(0), x.size(1), p=args.mask_rate).to(x.device)
        mask = expand_tensor(mask, x.shape[-1])
    elif args.masked_rule == 'channel_binomial':
        mask = generate_binomial_mask(x.size(0), x.size(1), x.size(2), p=args.mask_rate).to(x.device)
    elif args.masked_rule == 'continuous':
        mask = generate_geometric_mask(x.size(0), x.size(1), p=args.mask_rate, l=args.lm).to(x.device)
        mask = expand_tensor(mask, x.shape[-1])
    elif args.masked_rule == 'channel_continuous':
        mask = generate_geometric_mask(x.size(0), x.size(1), x.size(2), p=args.mask_rate, l=args.lm).to(x.device)
    elif args.masked_rule == 'mask_last':
        mask = x.new_full((x.size(0), x.size(1), x.size(2)), True, dtype=torch.bool)
        idx = int(x.shape[1] * args.mask_rate)
        mask[:, -idx:, :] = False
    elif args.masked_rule == 'mask_patch':
        mask = patch_mask(x, args.mask_rate, args.patch_len, args.stride)
        mask = expand_tensor(mask, args.patch_len)  # [b * c, num_patch, patch_len]
        mask = mask.reshape(b, c, -1)[:, :, :s].permute(0, 2, 1)  # drop padding, back to [b, s, c]
    else:
        raise ValueError(f"'{args.masked_rule}' is not a valid masking rule!")
    x = mask * x
    return x, mask  # x: [b, s, c]; mask: True = unmasked, False = masked

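# Usage sketch with a hypothetical argument namespace (in the training code
# `args` comes from the experiment config; the field names below just mirror
# the attributes read above):
#
#   from argparse import Namespace
#   args = Namespace(masked_rule='channel_continuous', mask_rate=0.25,
#                    lm=3, patch_len=12, stride=12)
#   x = torch.randn(32, 96, 7)
#   x_masked, mask = mask_function(x, args)  # x_masked, mask: [32, 96, 7]
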
# MAE Masking
def random_masking(xb, mask_ratio=0.75):
    bs, L, nvars, D = xb.shape  # xb: [bs x num_patch x n_vars x patch_len]
    x = xb.clone()
    len_keep = int(L * (1 - mask_ratio))

    noise = torch.rand(bs, L, nvars, device=xb.device)  # noise in [0, 1], bs x L x nvars

    # sort noise for each sample
    ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
    ids_restore = torch.argsort(ids_shuffle, dim=1)  # ids_restore: [bs x L x nvars]

    # keep the first subset
    ids_keep = ids_shuffle[:, :len_keep, :]  # ids_keep: [bs x len_keep x nvars]
    x_kept = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, 1, D))  # x_kept: [bs x len_keep x nvars x patch_len]

    # removed x
    x_removed = torch.zeros(bs, L - len_keep, nvars, D, device=xb.device)  # x_removed: [bs x (L-len_keep) x nvars x patch_len]
    x_ = torch.cat([x_kept, x_removed], dim=1)  # x_: [bs x L x nvars x patch_len]

    # combine the kept part and the removed one
    x_masked = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, 1, D))  # x_masked: [bs x num_patch x nvars x patch_len]

    # generate the binary mask: 0 is keep, 1 is remove
    mask = torch.ones([bs, L, nvars], device=x.device)  # mask: [bs x num_patch x nvars]
    mask[:, :len_keep, :] = 0
    # unshuffle to get the binary mask
    mask = torch.gather(mask, dim=1, index=ids_restore)  # [bs x num_patch x nvars]
    mask = mask.permute(0, 2, 1)
    mask = mask.reshape(-1, L)  # [bs * nvars x num_patch]
    return x_masked, x_kept, mask, ids_restore

# MAE Masking
def random_masking_v2(xb, mask_ratio=0.75):
    bs, L, D = xb.shape  # xb: [bs x n_vars x d_model]
    x = xb.clone()
    len_keep = int(L * (1 - mask_ratio))

    noise = torch.rand(bs, L, device=xb.device)  # noise in [0, 1], bs x L

    # sort noise for each sample
    ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
    ids_restore = torch.argsort(ids_shuffle, dim=1)  # ids_restore: [bs x L]

    # keep the first subset
    ids_keep = ids_shuffle[:, :len_keep]  # ids_keep: [bs x len_keep]
    x_kept = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))  # x_kept: [bs x len_keep x d_model]

    # removed x
    x_removed = torch.zeros(bs, L - len_keep, D, device=xb.device)  # x_removed: [bs x (L-len_keep) x d_model]
    x_ = torch.cat([x_kept, x_removed], dim=1)  # x_: [bs x L x d_model]

    # combine the kept part and the removed one
    x_masked = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, D))  # x_masked: [bs x L x d_model]

    # generate the binary mask: 0 is keep, 1 is remove
    mask = torch.ones([bs, L], device=x.device)  # mask: [bs x L]
    mask[:, :len_keep] = 0
    # unshuffle to get the binary mask
    mask = torch.gather(mask, dim=1, index=ids_restore)  # [bs x L]
    return x_masked, x_kept, mask, ids_restore

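# A minimal smoke test of the MAE-style masking above. The shapes are
# illustrative assumptions (8 samples, 16 patches, 7 variables, patch length
# 12), not values required by any caller of this module.
if __name__ == "__main__":
    xb = torch.randn(8, 16, 7, 12)  # [bs x num_patch x n_vars x patch_len]
    x_masked, x_kept, mask, ids_restore = random_masking(xb, mask_ratio=0.75)
    print(x_masked.shape)  # torch.Size([8, 16, 7, 12])
    print(x_kept.shape)    # torch.Size([8, 4, 7, 12])  -> 16 * (1 - 0.75) = 4 patches kept
    print(mask.shape)      # torch.Size([56, 16])       -> bs * n_vars rows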