import librosa
import pandas as pd
import torch
from transformers import WhisperForAudioClassification, WhisperProcessor

# Whisper encoder with an audio-classification head, fine-tuned below to detect a lisp.
model = WhisperForAudioClassification.from_pretrained("openai/whisper-medium")

# Processor bundling Whisper's feature extractor (log-mel spectrograms) and tokenizer.
processor = WhisperProcessor.from_pretrained("openai/whisper-medium")

# One row per clip: a 'file_path' column and a 'label' column (0 = normal, 1 = lisp).
df = pd.read_csv('dataset.csv')
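
# Optional sketch (label names are assumptions): the classification head above defaults to
# num_labels=2 with generic names (LABEL_0 / LABEL_1); explicit names could be attached at
# load time instead, e.g.:
# model = WhisperForAudioClassification.from_pretrained(
#     "openai/whisper-medium",
#     num_labels=2,
#     id2label={0: "normal", 1: "lisp"},
#     label2id={"normal": 0, "lisp": 1},
# )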

class LispDataset(torch.utils.data.Dataset):
    """Wraps the dataframe and turns each audio file into fixed-size input features."""

    def __init__(self, df):
        self.df = df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        audio_path = row['file_path']
        label = int(row['label'])  # 0 = normal, 1 = lisp

        # Load at the file's native rate, then resample to the 16 kHz Whisper expects.
        audio, original_sr = librosa.load(audio_path, sr=None)
        target_sr = 16000
        if original_sr != target_sr:
            audio = librosa.resample(audio, orig_sr=original_sr, target_sr=target_sr)

        # Hand-rolled log-mel features. Note that Whisper's own feature extractor uses
        # n_fft=400 and hop_length=160 with a different normalization; see the
        # processor-based alternative sketched after this class.
        mel_spectrogram = librosa.feature.melspectrogram(y=audio, sr=target_sr, n_mels=80, hop_length=512)
        mel_spectrogram_db = librosa.power_to_db(mel_spectrogram)

        # Pad or truncate the time axis to a fixed 3000 frames so every example has the same shape.
        max_len = 3000
        mel = torch.from_numpy(mel_spectrogram_db).float()
        if mel.shape[1] >= max_len:
            input_features = mel[:, :max_len]
        else:
            input_features = torch.nn.functional.pad(mel, (0, max_len - mel.shape[1]), mode='constant', value=0)

        return {'input_features': input_features, 'labels': label}

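# Alternative sketch (not wired in below): Whisper's own feature extractor, already loaded above
# as `processor`, computes input_features with the exact settings the model was pretrained with
# (n_fft=400, hop_length=160, 80 mel bins, padded/truncated to 30 s of audio). A __getitem__
# built on it could look like:
#
# def __getitem__(self, idx):
#     row = self.df.iloc[idx]
#     audio, _ = librosa.load(row['file_path'], sr=16000)
#     features = processor(audio, sampling_rate=16000, return_tensors="pt")
#     return {'input_features': features.input_features[0], 'labels': int(row['label'])}
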
train_dataset = LispDataset(df)

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=10,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=1,
    learning_rate=5e-5,
    warmup_ratio=0.1,                  # ignored once a custom scheduler is passed to the Trainer below
    fp16=False,                        # fp16 mixed precision needs a GPU; incompatible with use_cpu=True
    use_cpu=True,
    metric_for_best_model="accuracy",  # only takes effect with an eval dataset and a compute_metrics function
)
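
# Sketch (not wired into the Trainer below): to actually select a best model by accuracy, the
# Trainer would also need eval_dataset=..., an evaluation strategy, load_best_model_at_end=True,
# and a compute_metrics function along these lines:
#
# import numpy as np
#
# def compute_metrics(eval_pred):
#     logits, labels = eval_pred
#     predictions = np.argmax(logits, axis=-1)
#     return {"accuracy": float((predictions == labels).mean())}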

from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR

# The optimizer is built by hand so it can be handed to the Trainer together with a custom scheduler.
optimizer = AdamW(model.parameters(), lr=training_args.learning_rate)

# LambdaLR multiplies the base learning rate by lr_lambda(step). The Trainer calls scheduler.step()
# once per optimization step (not per epoch), so this is a per-step exponential decay.
lr_lambda = lambda step: 0.95 ** step
scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)

optimizertuple = (optimizer, scheduler)
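
# Alternative sketch (assumption): passing optimizers=(optimizer, None) instead would let the
# Trainer create its default linear warmup/decay schedule from warmup_ratio and the number of
# training steps, rather than using the LambdaLR above:
#
# trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset,
#                   optimizers=(optimizer, None))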

from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    optimizers=optimizertuple,
)

trainer.train()
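
# Sketch (paths are illustrative): persist the fine-tuned weights and the processor so they can be
# reloaded later with from_pretrained for inference.
# trainer.save_model("./results/final")
# processor.save_pretrained("./results/final")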


# Earlier experiments, kept for reference but disabled:

""" # Define a custom collate function to handle variable-length audio samples
def collate_fn(batch):
    # Pad audio samples to the same length
    input_lengths = [len(sample[0]) for sample in batch]
    max_length = max(input_lengths)
    padded_inputs = torch.nn.utils.rnn.pad_sequence([torch.tensor(sample[0]) for sample in batch], batch_first=True, padding_value=0)
    attention_mask = torch.tensor([[1] * length + [0] * (max_length - length) for length in input_lengths])

    return {
        "inputs": padded_inputs,
        "attention_mask": attention_mask,
        "labels": torch.tensor([sample[1] for sample in batch])
    }
"""

"""
def collate_fn(batch):
    # Pad audio samples to the same length
    input_lengths = [len(sample[0]) for sample in batch]
    max_length = max(input_lengths)
    padded_inputs = torch.nn.utils.rnn.pad_sequence([torch.tensor(sample[0]) for sample in batch], batch_first=True, padding_value=0)
    attention_mask = torch.tensor([[1] * length + [0] * (max_length - length) for length in input_lengths])

    # Convert each element in batch to a dictionary
    batch = [{'inputs': inp, 'attention_mask': mask, 'labels': sample[1]} for inp, mask, sample in zip(padded_inputs, attention_mask, batch)]
    print(batch)

    return batch """

"""
# train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)

# lambda2 = lambda epoch: 0.95 ** epoch

# Load the audio file
audio, original_sr = librosa.load("dataset/lisp/sample_01.wav", sr=44100)

# Target sample rate
target_sr = 16000

# Resample the audio
audio_resampled = librosa.resample(audio, orig_sr=original_sr, target_sr=target_sr) """

""" inputs = processor(
    audio_resampled, sampling_rate=target_sr, return_tensors="pt"
)

# Forward pass
with torch.no_grad():
    logits = model(**inputs).logits

# Predict the class (0 for normal, 1 for lisp)
predicted_class = torch.argmax(logits, dim=1).item() """