# Provenance (residue from the Hugging Face file viewer):
# author JuanJoseMV, commit 8f96165 "add model logic implementation", 753 bytes.
import os
import argparse
import torch
from .make_model import make_model
# Fixed hyperparameters used to build the behaviour model via make_model().
# Keys are kept in their original order; argparse.Namespace exposes them as
# attributes (e.g. hparams.NUM_LABELS).
# NOTE(review): LABEL_WEIGHTS holds a single weight while NUM_LABELS is 3 —
# confirm how make_model consumes LABEL_WEIGHTS.
hparams_dict = dict(
    HF_MODEL_PATH='facebook/wav2vec2-large-xlsr-53',
    DATASET='recanvo',
    MAX_DURATION=4,
    SAMPLING_RATE=16_000,
    OUTPUT_HIDDEN_STATES=True,
    CLASSIFIER_NAME='multilevel',
    CLASSIFIER_PROJ_SIZE=256,
    NUM_LABELS=3,
    LABEL_WEIGHTS=[1.0],
    LOSS='cross-entropy',
    GPU_ID=0,
    RETURN_RAW_ARRAY=False,
)

# Attribute-style view of the same settings, as passed to make_model().
hparams = argparse.Namespace(**hparams_dict)
def get_behaviour_model(behaviour_model_path, device, model_hparams=None):
    """Build the behaviour model and load its trained weights.

    The architecture is created by ``make_model`` from ``model_hparams``
    (the module-level ``hparams`` by default). It is then populated from the
    state dict stored at ``<behaviour_model_path>/pytorch_model.bin``.

    Args:
        behaviour_model_path: Directory containing ``pytorch_model.bin``.
        device: Device (e.g. ``'cpu'``, ``'cuda:0'`` or a ``torch.device``)
            used both as ``map_location`` for the checkpoint and as the
            placement of the returned model.
        model_hparams: Optional ``argparse.Namespace`` that overrides the
            module-level ``hparams``. When ``None``, ``hparams`` is used.

    Returns:
        The model returned by ``make_model``, with the checkpoint weights
        loaded and moved to ``device``.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
        RuntimeError: From ``load_state_dict`` (strict mode) when the
            checkpoint keys do not match the model's parameters.
    """
    if model_hparams is None:
        model_hparams = hparams

    checkpoint_path = os.path.join(behaviour_model_path, 'pytorch_model.bin')
    # NOTE(security): torch.load unpickles the file. Only load checkpoints from
    # trusted locations (consider weights_only=True on torch >= 1.13).
    state_dict = torch.load(checkpoint_path, map_location=device)

    model = make_model(model_hparams)
    model.load_state_dict(state_dict)
    # load_state_dict copies values into the model's existing parameters, so
    # map_location alone does not move the model. Place it on the requested
    # device explicitly.
    model = model.to(device)
    return model