Zaixi's picture
1
dcacefd
#!/usr/bin/python
# -*- coding:utf-8 -*-
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
from torch_scatter import scatter_softmax, scatter_mean, scatter_sum, scatter_std
from tools import _unit_edges_from_block_edges
from radial_basis import RadialBasis
def stable_norm(input, *args, **kwargs):
    '''
    Compute the norm of a tensor by delegating to ``torch.norm``.

    :param input: tensor whose norm is computed
    :param args: extra positional arguments forwarded unchanged to ``torch.norm``
    :param kwargs: extra keyword arguments forwarded unchanged to ``torch.norm``
    :return: the result of ``torch.norm(input, *args, **kwargs)``

    NOTE: an earlier variant clamped ``|input|`` to at least 1e-10 (keeping the sign)
    before taking the norm. That code sat behind an unconditional ``return`` and was
    therefore never executed; it has been removed so the function states what it does.
    '''
    return torch.norm(input, *args, **kwargs)
class GET(nn.Module):
    '''Equivariant Adaptive Block Transformer'''

    def __init__(self, d_hidden, d_radial, n_channel, n_rbf, cutoff=7.0, d_edge=0, n_layers=4, n_head=4,
                 act_fn=nn.SiLU(), residual=True, dropout=0.1, z_requires_grad=True, pre_norm=False,
                 sparse_k=3):
        '''
        :param d_hidden: Number of hidden features
        :param d_radial: Number of features for calculating geometric relations
        :param n_channel: Number of channels of coordinates of each unit
        :param n_rbf: Dimension of RBF feature, 1 for not using rbf
        :param cutoff: cutoff for RBF
        :param d_edge: Number of features for the edge features
        :param n_layers: Number of layer
        :param n_head: Number of attention heads, forwarded to every GETLayer
        :param act_fn: Non-linearity
        :param residual: Use residual connections, we recommend not changing this one
        :param dropout: probability of dropout
        :param z_requires_grad: if False, the coordinate branch of the last FFN is built
            with z_requires_grad=False and sigma of the last layer norm is frozen
        :param pre_norm: if True, apply an EquivariantLayerNorm before the first layer
        :param sparse_k: forwarded as ``k`` to ``_unit_edges_from_block_edges``
        '''
        super().__init__()

        self.n_layers = n_layers
        self.pre_norm = pre_norm
        self.sparse_k = sparse_k

        if self.pre_norm:
            self.pre_layernorm = EquivariantLayerNorm(d_hidden, n_channel, n_rbf, cutoff, act_fn)

        # Sub-modules are registered by name (not in a ModuleList) so that the
        # parameter names layer_{i} / layernorm0_{i} / ffn_{i} / layernorm1_{i} stay stable.
        for i in range(n_layers):
            self.add_module(
                f'layer_{i}',
                GETLayer(d_hidden, d_radial, n_channel, n_rbf, cutoff, d_edge, n_head, act_fn, residual))
            self.add_module(f'layernorm0_{i}', EquivariantLayerNorm(d_hidden, n_channel, n_rbf, cutoff, act_fn))
            self.add_module(f'ffn_{i}', EquivariantFFN(
                d_hidden, 4 * d_hidden, d_hidden, n_channel,
                n_rbf, act_fn, residual, dropout,
                # only the last FFN may have its coordinate branch frozen
                z_requires_grad=z_requires_grad if i == n_layers - 1 else True
            ))
            self.add_module(f'layernorm1_{i}', EquivariantLayerNorm(d_hidden, n_channel, n_rbf, cutoff, act_fn))

        # Guard on n_layers so that an empty stack does not look up a missing module.
        if not z_requires_grad and n_layers > 0:
            getattr(self, f'layernorm1_{n_layers - 1}').sigma.requires_grad = False

    def recover_scale(self, Z, block_id, batch_id, record_scale):
        '''
        Undo the accumulated per-graph rescaling of the coordinates.

        :param Z: [N, n_channel, 3], coordinates after all layers
        :param block_id: [N], block index of each unit
        :param batch_id: [Nb], graph index of each block
        :param record_scale: [bs, n_channel, 1], product of all rescale factors applied so far
        :return: [N, n_channel, 3], coordinates with the scaling around each graph center divided out
        '''
        # only the integer index lookup is excluded from autograd
        with torch.no_grad():
            unit_batch_id = batch_id[block_id]
        Z_c = scatter_mean(Z, unit_batch_id, dim=0)  # [bs, n_channel, 3]
        Z_c = Z_c[unit_batch_id]  # [N, n_channel, 3]
        Z_centered = Z - Z_c
        return Z_c + Z_centered / record_scale[unit_batch_id]

    def forward(self, H, Z, block_id, batch_id, edges, edge_attr=None, cached_unit_edge_info=None):
        '''
        :param H: [N, hidden_size], invariant unit features
        :param Z: [N, n_channel, 3], unit coordinates
        :param block_id: [N], block index of each unit
        :param batch_id: [Nb], graph index of each block
        :param edges: [2, E], block-level edges
        :param edge_attr: block-level edge features, or None
        :param cached_unit_edge_info: precomputed output of ``_unit_edges_from_block_edges``;
            computed here when None
        :return: updated (H, Z)
        '''
        if cached_unit_edge_info is None:
            with torch.no_grad():
                # [Eu], Eu = \sum_{i, j \in E} n_i * n_j
                cached_unit_edge_info = _unit_edges_from_block_edges(block_id, edges.T, Z, k=self.sparse_k)

        # explicit int so the size tuple handed to torch.ones holds plain Python ints
        batch_size, n_channel = int(batch_id.max()) + 1, Z.shape[1]
        record_scale = torch.ones((batch_size, n_channel, 1), dtype=torch.float, device=Z.device)

        if self.pre_norm:
            H, Z, rescale = self.pre_layernorm(H, Z, block_id, batch_id)
            record_scale = record_scale * rescale

        for i in range(self.n_layers):
            # for attention visualization, a prefix can be set on the layer here:
            # getattr(self, f'layer_{i}').prefix = self.prefix + f'_layer{i}'
            H, Z = getattr(self, f'layer_{i}')(H, Z, block_id, edges, edge_attr, cached_unit_edge_info)
            H, Z, rescale = getattr(self, f'layernorm0_{i}')(H, Z, block_id, batch_id)
            record_scale = record_scale * rescale
            H, Z = getattr(self, f'ffn_{i}')(H, Z, block_id)
            H, Z, rescale = getattr(self, f'layernorm1_{i}')(H, Z, block_id, batch_id)
            record_scale = record_scale * rescale

        Z = self.recover_scale(Z, block_id, batch_id, record_scale)

        return H, Z
'''
Below is the implementation of the equivariant adaptive block message passing mechanism
'''
class GETLayer(nn.Module):
    '''
    Equivariant Adaptive Block Transformer layer
    '''

    def __init__(self, d_hidden, d_radial, n_channel, n_rbf, cutoff=7.0,
                 d_edge=0, n_head=4, act_fn=nn.SiLU(), residual=True):
        '''
        :param d_hidden: Number of hidden features (split evenly across heads)
        :param d_radial: Number of features for the attention / message space, must be divisible by n_head
        :param n_channel: Number of coordinate channels of each unit
        :param n_rbf: Dimension of RBF feature, must be divisible by n_head; RBF is only built when > 1
        :param cutoff: cutoff for RBF
        :param d_edge: Number of block-level edge features, 0 for no edge features
        :param n_head: Number of attention heads
        :param act_fn: Non-linearity
        :param residual: add the aggregated message to H instead of replacing H
        '''
        super(GETLayer, self).__init__()

        self.residual = residual
        self.reci_sqrt_d = 1 / math.sqrt(d_radial)
        self.epsilon = 1e-8
        self.n_rbf = n_rbf
        self.cutoff = cutoff

        self.n_head = n_head
        assert d_radial % self.n_head == 0, f'd_radial not compatible with n_head ({d_radial} and {self.n_head})'
        assert n_rbf % self.n_head == 0, f'n_rbf not compatible with n_head ({n_rbf} and {self.n_head})'
        d_hidden_head, d_radial_head = d_hidden // self.n_head, d_radial // self.n_head
        n_rbf_head = n_rbf // self.n_head

        self.linear_qk = nn.Linear(d_hidden_head, d_radial_head * 2, bias=False)
        self.linear_v = nn.Linear(d_hidden_head, d_radial_head)

        if n_rbf > 1:
            self.rbf = RadialBasis(num_radial=n_rbf, cutoff=cutoff)

        # The attention input is [H_q, H_k, D] plus the projected edge feature when
        # edge features exist, so the first Linear is one d_radial_head narrower
        # when d_edge == 0 (a fixed factor of 3 breaks the edge-free configuration).
        n_radial_inputs = 3 if d_edge != 0 else 2
        self.att_mlp = nn.Sequential(
            nn.Linear(d_radial_head * n_radial_inputs + n_channel * n_rbf_head, d_radial_head),
            act_fn,
            nn.Linear(d_radial_head, d_radial_head),
            act_fn
        )
        self.unit_att_linear = nn.Linear(d_radial_head, 1)
        self.block_att_linear = nn.Linear(d_radial_head, 1)

        if d_edge != 0:
            self.edge_linear = nn.Linear(d_edge, d_radial)

        self.node_mlp = nn.Sequential(
            nn.Linear(d_radial, d_hidden),
            act_fn,
            nn.Linear(d_hidden, d_hidden),
            act_fn
        )
        self.node_out_linear = nn.Linear(d_hidden, d_hidden)

        self.coord_mlp = nn.Sequential(
            nn.Linear(d_radial, d_hidden),
            act_fn,
            nn.Linear(d_hidden, n_head * n_channel),
            act_fn
        )

        self.unit_msg_mlp = nn.Sequential(
            nn.Linear(d_radial_head + n_channel * n_rbf_head, d_radial_head),
            act_fn,
            nn.Linear(d_radial_head, d_radial_head),
            act_fn
        )

        self.unit_msg_coord_mlp = nn.Sequential(
            nn.Linear(d_radial_head + n_channel * n_rbf_head, d_radial_head),
            act_fn,
            nn.Linear(d_radial_head, d_radial_head),
            act_fn
        )
        self.unit_msg_coord_linear = nn.Linear(d_radial_head, n_channel)

    def attention(self, H, Z, edges, edge_attr, cached_unit_edge_info):
        '''
        Compute unit-level attention (alpha) and block-level attention (beta).

        edge_attr must be given exactly when the layer was built with d_edge != 0,
        because edge_linear only exists in that case.

        :return: alpha [Eu, n_head, 1], beta [Em, n_head, 1], and (D, R, dZ) for reuse in the updates
        '''
        row, _ = edges
        (unit_row, unit_col), (block_edge_id, unit_edge_src_start, unit_edge_src_id) = cached_unit_edge_info

        # multi-head
        H = H.view(H.shape[0], self.n_head, -1)  # [N, n_head, hidden_size / n_head]

        # queries / keys are interleaved along the last dim of the joint projection
        H_qk = self.linear_qk(H)
        H_q, H_k = H_qk[..., 0::2][unit_row], H_qk[..., 1::2][unit_col]  # [Eu, n_head, d_radial / n_head]

        dZ = Z[unit_row] - Z[unit_col]  # [Eu, n_channel, 3]
        D = stable_norm(dZ, dim=-1)  # [Eu, n_channel]
        if self.n_rbf > 1:
            n_channel = D.shape[-1]
            D = self.rbf(D.view(-1)).view(D.shape[0], n_channel, self.n_head, -1)  # [Eu, n_channel, n_head, n_rbf / n_head]
            D = D.transpose(1, 2).reshape(D.shape[0], self.n_head, -1)  # [Eu, n_head, n_channel * n_rbf / n_head]
        else:
            D = D.unsqueeze(1).repeat(1, self.n_head, 1)  # [Eu, n_head, n_channel]

        if edge_attr is None:
            R_repr = torch.cat([H_q, H_k, D], dim=-1)  # [Eu, n_head, (d_radial * 2 + n_channel * n_rbf) / n_head]
        else:
            edge_attr = self.edge_linear(edge_attr).view(edge_attr.shape[0], self.n_head, -1)
            R_repr = torch.cat([H_q, H_k, D, edge_attr[block_edge_id]], dim=-1)
        R_repr = self.att_mlp(R_repr)  # [Eu, n_head, d_radial / n_head]

        R = self.unit_att_linear(R_repr).squeeze(-1)  # [Eu, n_head]
        # unit-level attention within block-level edges
        alpha = scatter_softmax(R, unit_edge_src_id, dim=0).unsqueeze(-1)  # [Eu, n_head, 1]

        # Directly using the mean of R is not reasonable as the value before softmax has
        # different scales in different pairs; max(R) - min(R) or max(R) - mean(R) are also
        # not reasonable as the lower bound would be 0 instead of -inf. So the block-level
        # score is computed from a pooling of the unit attention representation.
        beta = self.block_att_linear(scatter_mean(R_repr, block_edge_id, dim=0)).squeeze(-1)  # [Eb, n_head]
        beta = scatter_softmax(beta, row, dim=0)  # [Eb, n_head], block-level edge attention

        # for attention visualization:
        # pickle.dump((alpha, beta, edges, (unit_row, unit_col)), open(f'./attention/{self.prefix}.pkl', 'wb'))

        beta = beta[block_edge_id[unit_edge_src_start]].unsqueeze(-1)  # [Em, n_head, 1], Em = \sum_{i, j \in E} n_i

        return alpha, beta, (D, R, dZ)

    def invariant_update(self, H_v, H, alpha, beta, D, cached_unit_edge_info):
        '''
        Aggregate value messages with alpha (unit level) then beta (block level) and update H.
        '''
        (unit_row, unit_col), (_, unit_edge_src_start, unit_edge_src_id) = cached_unit_edge_info
        unit_agg_row = unit_row[unit_edge_src_start]

        # update invariant feature
        H_v = self.unit_msg_mlp(torch.cat([H_v[unit_col], D], dim=-1))  # [Eu, n_head, d_radial / n_head]
        H_agg = scatter_sum(alpha * H_v, unit_edge_src_id, dim=0)  # [Em, n_head, d_radial / n_head]
        H_agg = H_agg.view(H_agg.shape[0], -1)  # [Em, d_radial]
        H_agg = self.node_mlp(H_agg)  # [Em, hidden_size]
        H_agg = H_agg.view(H_agg.shape[0], self.n_head, -1)  # [Em, n_head, hidden_size / n_head]
        H_agg = scatter_sum(beta * H_agg, unit_agg_row, dim=0, dim_size=H.shape[0])  # [N, n_head, hidden_size / n_head]
        H_agg = H_agg.view(H_agg.shape[0], -1)  # [N, hidden_size]
        H_agg = self.node_out_linear(H_agg)

        H = H + H_agg if self.residual else H_agg
        return H

    def equivariant_update(self, H_v, Z, alpha, beta, D, dZ, cached_unit_edge_info):
        '''
        Aggregate coordinate differences weighted by invariant scalars and add them to Z.
        '''
        (unit_row, unit_col), (_, unit_edge_src_start, unit_edge_src_id) = cached_unit_edge_info
        unit_agg_row = unit_row[unit_edge_src_start]

        # update equivariant feature
        H_v = self.unit_msg_coord_mlp(torch.cat([H_v[unit_col], D], dim=-1))  # [Eu, n_head, d_radial / n_head]
        Z_agg = scatter_sum(
            (alpha * self.unit_msg_coord_linear(H_v)).unsqueeze(-1) * dZ.unsqueeze(1),
            unit_edge_src_id, dim=0)  # [Em, n_head, n_channel, 3]
        Z_H_agg = scatter_sum(alpha * H_v, unit_edge_src_id, dim=0)  # [Em, n_head, d_radial / n_head]
        Z_H_agg = self.coord_mlp(Z_H_agg.view(Z_H_agg.shape[0], -1))  # [Em, n_head * n_channel]
        Z_H_agg = Z_H_agg.view(Z_H_agg.shape[0], self.n_head, -1)  # [Em, n_head, n_channel]
        Z_agg = scatter_sum(
            (beta * Z_H_agg).unsqueeze(-1) * Z_agg, unit_agg_row,
            dim=0, dim_size=Z.shape[0])  # [N, n_head, n_channel, 3]
        Z_agg = Z_agg.sum(dim=1)  # [N, n_channel, 3]
        return Z + Z_agg

    def forward(self, H, Z, block_id, edges, edge_attr=None, cached_unit_edge_info=None):
        '''
        H: [N, hidden_size],
        Z: [N, n_channel, 3],
        block_id: [N],
        edges: [2, E], list of [n_row] and [n_col] where n_row == n_col == E, nodes from col are used to update nodes from row
        edge_attr: [E, d_edge], or None when the layer has no edge features
        cached_unit_edge_info: unit level (row, col), (block_edge_id, unit_edge_src_start, unit_edge_src_id) calculated from block edges
        '''
        if cached_unit_edge_info is None:
            with torch.no_grad():
                # [Eu], Eu = \sum_{i, j \in E} n_i * n_j
                cached_unit_edge_info = _unit_edges_from_block_edges(block_id, edges.T)

        alpha, beta, (D, R, dZ) = self.attention(H, Z, edges, edge_attr, cached_unit_edge_info)

        H_v = self.linear_v(H.view(H.shape[0], self.n_head, -1))  # [N, n_head, d_radial / n_head]

        H = self.invariant_update(H_v, H, alpha, beta, D, cached_unit_edge_info)
        Z = self.equivariant_update(H_v, Z, alpha, beta, D, dZ, cached_unit_edge_info)

        return H, Z
class EquivariantFFN(nn.Module):
    '''
    Per-unit feed-forward network that updates the invariant features H and rescales
    the coordinates Z relative to the center of the block each unit belongs to.
    '''

    def __init__(self, d_in, d_hidden, d_out, n_channel, n_rbf=16, act_fn=nn.SiLU(),
                 residual=True, dropout=0.1, constant=1, z_requires_grad=True, cutoff=7.0) -> None:
        '''
        :param d_in: Number of input features
        :param d_hidden: Number of hidden features of the MLPs
        :param d_out: Number of output features (must equal d_in when residual is True)
        :param n_channel: Number of coordinate channels of each unit
        :param n_rbf: Dimension of RBF feature
        :param act_fn: Non-linearity
        :param residual: add the update to H instead of replacing H
        :param dropout: probability of dropout
        :param constant: stored on the module; not used by the current forward pass
        :param z_requires_grad: if False, parameters of the coordinate MLP are frozen
        :param cutoff: cutoff for RBF (previously fixed to 7.0, which is still the default)
        '''
        super().__init__()
        self.constant = constant
        self.residual = residual
        self.n_rbf = n_rbf

        d_feat = d_in * 2 + n_channel * n_rbf  # [H, block mean of H, radial features]

        self.mlp_h = nn.Sequential(
            nn.Linear(d_feat, d_hidden),
            act_fn,
            nn.Dropout(dropout),
            nn.Linear(d_hidden, d_hidden),
            act_fn,
            nn.Dropout(dropout),
            nn.Linear(d_hidden, d_out),
            nn.Dropout(dropout)
        )

        self.mlp_z = nn.Sequential(
            nn.Linear(d_feat, d_hidden),
            act_fn,
            nn.Dropout(dropout),
            nn.Linear(d_hidden, d_hidden),
            act_fn,
            nn.Dropout(dropout),
            nn.Linear(d_hidden, n_channel),
            nn.Dropout(dropout)
        )

        if not z_requires_grad:
            for param in self.mlp_z.parameters():
                param.requires_grad = False

        self.rbf = RadialBasis(n_rbf, cutoff)

    def forward(self, H, Z, block_id):
        '''
        :param H: [N, d_in]
        :param Z: [N, n_channel, 3]
        :param block_id: [N]
        :return: updated H [N, d_out] and Z [N, n_channel, 3]
        '''
        radial, (Z_c, Z_o) = self._radial(Z, block_id)  # [N, n_channel * n_rbf], ([N, n_channel, 3], [N, n_channel, 3])
        H_c = scatter_mean(H, block_id, dim=0)[block_id]  # [N, d_in], mean feature of the unit's block
        inputs = torch.cat([H, H_c, radial], dim=-1)  # [N, d_in + d_in + n_channel * n_rbf]

        H_update = self.mlp_h(inputs)
        H = H + H_update if self.residual else H_update

        # one invariant scalar per channel scales the offset from the block center
        Z = Z_c + self.mlp_z(inputs).unsqueeze(-1) * Z_o

        return H, Z

    def _radial(self, Z, block_id):
        '''
        RBF-expand the distance of every unit channel to the center of its block.

        :return: radial [N, n_channel * n_rbf], (Z_c, Z_o) with the per-unit block
            center and the offset from it, both [N, n_channel, 3]
        '''
        Z_c = scatter_mean(Z, block_id, dim=0)  # [Nb, n_channel, 3]
        Z_c = Z_c[block_id]  # [N, n_channel, 3]
        Z_o = Z - Z_c  # [N, n_channel, 3], no translation
        D = stable_norm(Z_o, dim=-1)  # [N, n_channel]
        radial = self.rbf(D.view(-1)).view(D.shape[0], -1)  # [N, n_channel * n_rbf]
        return radial, (Z_c, Z_o)
class EquivariantLayerNorm(nn.Module):
    '''
    Layer normalization for (H, Z): coordinates are rescaled around the center of
    each graph, and the applied scale is fused into H before a standard LayerNorm.
    '''

    def __init__(self, d_hidden, n_channel, n_rbf=16, cutoff=7.0, act_fn=nn.SiLU()) -> None:
        '''
        :param d_hidden: Number of hidden features
        :param n_channel: Number of coordinate channels of each unit
        :param n_rbf: Dimension of RBF feature used to embed the rescale factor
        :param cutoff: cutoff for RBF
        :param act_fn: Non-linearity
        '''
        super().__init__()

        # invariant
        self.fuse_scale_ffn = nn.Sequential(
            nn.Linear(n_channel * n_rbf, d_hidden),
            act_fn,
            nn.Linear(d_hidden, d_hidden),
            act_fn
        )
        self.layernorm = nn.LayerNorm(d_hidden)

        # geometric
        sigma = torch.ones((1, n_channel, 1))
        self.sigma = nn.Parameter(sigma, requires_grad=True)
        self.rbf = RadialBasis(num_radial=n_rbf, cutoff=cutoff)

        # lower bound for the per-graph standard deviation (plain attribute, not a buffer)
        self.eps = 1e-8

    def forward(self, H, Z, block_id, batch_id):
        '''
        :param H: [N, d_hidden]
        :param Z: [N, n_channel, 3]
        :param block_id: [N], block index of each unit
        :param batch_id: [Nb], graph index of each block
        :return: normalized H, rescaled Z, and rescale [bs, n_channel, 1] (the factor applied to Z)
        '''
        # only the integer index bookkeeping is excluded from autograd
        with torch.no_grad():
            _, n_channel, n_axis = Z.shape
            unit_batch_id = batch_id[block_id]
            unit_axis_batch_id = unit_batch_id.unsqueeze(-1).repeat(1, n_axis).flatten()  # [N * 3]

        Z_c = scatter_mean(Z, unit_batch_id, dim=0)  # [bs, n_channel, 3]
        Z_c = Z_c[unit_batch_id]  # [N, n_channel, 3]
        Z_centered = Z - Z_c
        # the three axes are pooled together so one scale per graph and channel is obtained
        std = scatter_std(
            Z_centered.transpose(1, 2).reshape(-1, n_channel).contiguous(),
            unit_axis_batch_id, dim=0)  # [bs, n_channel]
        # a graph whose centered coordinates are all zero has zero std; clamp to avoid inf / NaN
        std = std.clamp(min=self.eps)
        rescale = (1 / std).unsqueeze(-1) * self.sigma  # [bs, n_channel, 1]
        Z = Z_c + Z_centered * rescale[unit_batch_id]

        rescale_rbf = self.rbf(rescale.view(-1)).view(rescale.shape[0], -1)  # [bs, n_channel * n_rbf]
        H = H + self.fuse_scale_ffn(rescale_rbf)[unit_batch_id]
        H = self.layernorm(H)

        return H, Z, rescale
class GETEncoder(nn.Module):
    '''
    Wraps a GET backbone and pools its unit features into block-level and
    graph-level representations.
    '''

    def __init__(self, hidden_size, radial_size, n_channel,
                 n_rbf=1, cutoff=7.0, edge_size=16, n_layers=3,
                 n_head=1, dropout=0.1,
                 z_requires_grad=True, stable=False) -> None:
        '''
        :param hidden_size: Number of hidden features
        :param radial_size: Number of features for calculating geometric relations
        :param n_channel: Number of coordinate channels of each unit
        :param n_rbf: Dimension of RBF feature
        :param cutoff: cutoff for RBF
        :param edge_size: Number of features for the edge features
        :param n_layers: Number of layers of the backbone
        :param n_head: Number of attention heads
        :param dropout: probability of dropout
        :param z_requires_grad: forwarded to the backbone
        :param stable: accepted for interface compatibility; not used here
        '''
        super().__init__()

        self.encoder = GET(
            d_hidden=hidden_size,
            d_radial=radial_size,
            n_channel=n_channel,
            n_rbf=n_rbf,
            cutoff=cutoff,
            d_edge=edge_size,
            n_layers=n_layers,
            n_head=n_head,
            dropout=dropout,
            z_requires_grad=z_requires_grad,
        )

    def forward(self, H, Z, block_id, batch_id, edges, edge_attr=None):
        '''
        :return: unit features, L2-normalized block representation [Nb, hidden],
            L2-normalized graph representation [bs, hidden], and the predicted coordinates
        '''
        unit_repr, pred_Z = self.encoder(H, Z, block_id, batch_id, edges, edge_attr)

        # sum-pool units into blocks, then blocks into graphs, normalizing after each pooling
        block_repr = F.normalize(scatter_sum(unit_repr, block_id, dim=0), dim=-1)  # [Nb, hidden]
        graph_repr = F.normalize(scatter_sum(block_repr, batch_id, dim=0), dim=-1)  # [bs, hidden]

        return unit_repr, block_repr, graph_repr, pred_Z
if __name__ == '__main__':
    # Smoke test: the invariant output H must not change, and the coordinate output Z
    # must transform accordingly, when each graph is moved by its own random
    # orthogonal transform and translation.
    d_hidden = 64
    d_radial = 16
    n_channel = 2
    d_edge = 16
    n_rbf = 16
    n_head = 4

    # fall back to CPU so the check also runs on machines without a GPU
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = GET(d_hidden, d_radial, n_channel, n_rbf, d_edge=d_edge, n_head=n_head)
    model.to(device)
    model.eval()

    block_id = torch.tensor([0, 0, 1, 1, 1, 1, 2, 2, 2, 3, 4, 4, 5, 6, 6, 6, 6, 7, 7], dtype=torch.long).to(device)
    batch_id = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1], dtype=torch.long).to(device)
    src_dst = torch.tensor([[0, 1], [2, 3], [1, 3], [2, 4], [3, 0], [3, 3], [5, 7], [7, 6], [5, 6], [6, 7]],
                           dtype=torch.long).to(device)
    src_dst = src_dst.T
    edge_attr = torch.randn(len(src_dst[0]), d_edge).to(device)

    n_unit = block_id.shape[0]
    H = torch.randn(n_unit, d_hidden, device=device)
    Z = torch.randn(n_unit, n_channel, 3, device=device)

    print(_unit_edges_from_block_edges(block_id, src_dst.T, Z, k=3))

    def _random_rigid_transform():
        # random orthogonal matrix with positive determinant, plus a random translation
        U, _, Vh = torch.linalg.svd(torch.randn(3, 3, device=device, dtype=torch.float))
        if torch.linalg.det(U) * torch.linalg.det(Vh) < 0:
            U[:, -1] = -U[:, -1]
        return U.mm(Vh), torch.randn(3, device=device)

    with torch.no_grad():
        H1, Z1 = model(H, Z, block_id, batch_id, src_dst, edge_attr)

        Q1, t1 = _random_rigid_transform()
        Q2, t2 = _random_rigid_transform()

        # each graph in the batch gets its own transform
        unit_batch_id = batch_id[block_id]
        Z[unit_batch_id == 0] = torch.matmul(Z[unit_batch_id == 0], Q1) + t1
        Z[unit_batch_id == 1] = torch.matmul(Z[unit_batch_id == 1], Q2) + t2

        H2, Z2 = model(H, Z, block_id, batch_id, src_dst, edge_attr)

        print(f'invariant feature: {torch.abs(H1 - H2).sum()}')
        Z1[unit_batch_id == 0] = torch.matmul(Z1[unit_batch_id == 0], Q1) + t1
        Z1[unit_batch_id == 1] = torch.matmul(Z1[unit_batch_id == 1], Q2) + t2
        print(f'equivariant feature: {torch.abs(Z1 - Z2).sum()}')