| import os | |
| import argparse | |
| import torch | |
| from .make_model import make_model | |
# Fixed hyperparameters handed to ``make_model`` when the behaviour model is
# rebuilt for loading. The values are kept both as a plain dict and as an
# attribute-style ``argparse.Namespace`` (``hparams``) built from that dict.
# NOTE(review): the meaning/units of individual entries (e.g. MAX_DURATION,
# LABEL_WEIGHTS vs. NUM_LABELS) are defined by make_model -- confirm there.
hparams_dict = dict(
    HF_MODEL_PATH='facebook/wav2vec2-large-xlsr-53',
    DATASET='recanvo',
    MAX_DURATION=4,
    SAMPLING_RATE=16_000,
    OUTPUT_HIDDEN_STATES=True,
    CLASSIFIER_NAME='multilevel',
    CLASSIFIER_PROJ_SIZE=256,
    NUM_LABELS=3,
    LABEL_WEIGHTS=[1.0],
    LOSS='cross-entropy',
    GPU_ID=0,
    RETURN_RAW_ARRAY=False,
)

# Namespace view of the same settings; it holds its own copy of the mapping.
hparams = argparse.Namespace(**hparams_dict)
def get_behaviour_model(behaviour_model_path, device):
    """Build the model from the module-level ``hparams`` and load saved weights.

    Args:
        behaviour_model_path: Directory that contains ``pytorch_model.bin``.
        device: Forwarded to ``torch.load`` as ``map_location``.

    Returns:
        The object produced by ``make_model(hparams)`` after
        ``load_state_dict`` has been called on it with the loaded checkpoint.
    """
    checkpoint_file = os.path.join(behaviour_model_path, 'pytorch_model.bin')
    # NOTE(review): torch.load may unpickle arbitrary objects depending on the
    # PyTorch version/defaults -- only point this at trusted checkpoints.
    weights = torch.load(checkpoint_file, map_location=device)

    net = make_model(hparams)
    # NOTE(review): the model is not explicitly moved to `device` here;
    # presumably make_model or the caller handles placement -- confirm.
    net.load_state_dict(weights)
    return net