Zaixi's picture
1
dcacefd
from .cftfm import BilevelEncoder, HierEncoder, GET
from .gnn import GNN_graphpred, MLP, WeightGNN
def get_encoder(config, device):
    """Build the encoder selected by ``config.name``.

    Supported names and the class each one constructs:

    * ``'tf'``     -> ``HierEncoder``   (``device`` is NOT forwarded)
    * ``'hierGT'`` -> ``BilevelEncoder`` (``device`` is forwarded)
    * ``'GET'``    -> ``GET``            (``device`` is forwarded)

    Args:
        config: Object exposing ``name`` plus ``hidden_channels``,
            ``edge_channels``, ``key_channels``, ``num_heads``,
            ``num_interactions``, ``knn`` and ``cutoff`` attributes.
            ``knn`` is passed on as the ``k`` keyword.
        device: Passed through unchanged as the ``device`` keyword to the
            ``'hierGT'`` and ``'GET'`` encoders; presumably a torch device
            or device string -- confirm against callers.

    Returns:
        The newly constructed encoder instance.

    Raises:
        NotImplementedError: If ``config.name`` is not one of the supported
            names. No other ``config`` attribute is read in that case.
    """
    encoder_name = config.name

    # Reject unknown names up front, before any other config field is touched.
    if encoder_name not in ('tf', 'hierGT', 'GET'):
        raise NotImplementedError('Unknown encoder: %s' % encoder_name)

    # Hyper-parameters common to every supported encoder.
    common_kwargs = dict(
        hidden_channels=config.hidden_channels,
        edge_channels=config.edge_channels,
        key_channels=config.key_channels,
        num_heads=config.num_heads,
        num_interactions=config.num_interactions,
        k=config.knn,
        cutoff=config.cutoff,
    )

    if encoder_name == 'tf':
        # HierEncoder is constructed without a device argument.
        return HierEncoder(**common_kwargs)

    encoder_cls = BilevelEncoder if encoder_name == 'hierGT' else GET
    return encoder_cls(**common_kwargs, device=device)