Spaces:
Sleeping
Sleeping
| from .cftfm import BilevelEncoder, HierEncoder, GET | |
| from .gnn import GNN_graphpred, MLP, WeightGNN | |
def get_encoder(config, device):
    """Instantiate the encoder selected by ``config.name``.

    Supported names and the class each one maps to:

    * ``'tf'``     -> ``HierEncoder`` (constructed without ``device``)
    * ``'hierGT'`` -> ``BilevelEncoder``
    * ``'GET'``    -> ``GET``

    Args:
        config: Object exposing ``name`` plus the hyper-parameters
            ``hidden_channels``, ``edge_channels``, ``key_channels``,
            ``num_heads``, ``num_interactions``, ``knn`` and ``cutoff``.
        device: Forwarded as the ``device`` keyword to ``BilevelEncoder``
            and ``GET``; not passed to ``HierEncoder``.

    Returns:
        The newly constructed encoder instance.

    Raises:
        NotImplementedError: If ``config.name`` is not one of the
            supported names.
    """
    # Resolve the class first so an unknown name fails before any other
    # config attribute is touched.
    if config.name == 'tf':
        encoder_cls = HierEncoder
        pass_device = False
    elif config.name == 'hierGT':
        encoder_cls = BilevelEncoder
        pass_device = True
    elif config.name == 'GET':
        encoder_cls = GET
        pass_device = True
    else:
        raise NotImplementedError('Unknown encoder: %s' % config.name)

    # Hyper-parameters common to every encoder; note that the config field
    # ``knn`` is handed over under the keyword ``k``.
    kwargs = {
        'hidden_channels': config.hidden_channels,
        'edge_channels': config.edge_channels,
        'key_channels': config.key_channels,
        'num_heads': config.num_heads,
        'num_interactions': config.num_interactions,
        'k': config.knn,
        'cutoff': config.cutoff,
    }
    if pass_device:
        kwargs['device'] = device

    return encoder_cls(**kwargs)