from transformers import WhisperForAudioClassification
# Load pre-trained Whisper model
model = WhisperForAudioClassification.from_pretrained("openai/whisper-medium")
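# Note: the classification head size comes from the checkpoint's config. For a binary
# lisp / no-lisp task one would typically set it explicitly, e.g. (sketch, assuming
# two classes):
# model = WhisperForAudioClassification.from_pretrained("openai/whisper-medium", num_labels=2)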
import pandas as pd
# Load the CSV file
df = pd.read_csv('dataset.csv')
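# The CSV is expected to provide a 'file_path' column (path to each audio clip) and a
# 'label' column (0 = normal speech, 1 = lisp), as read in LispDataset below.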
from transformers import WhisperProcessor
# Initialize the Whisper processor
processor = WhisperProcessor.from_pretrained("openai/whisper-medium")
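# The processor bundles Whisper's feature extractor and tokenizer. The dataset below
# builds log-mel features manually with librosa; an alternative (a sketch, not used
# here) would be to let the feature extractor produce the features directly:
# features = processor(audio, sampling_rate=16000, return_tensors="pt").input_features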
import librosa
import torch
# Custom dataset: loads an audio clip, resamples it to 16 kHz and returns a padded
# log-mel spectrogram plus its label for WhisperForAudioClassification.
class LispDataset(torch.utils.data.Dataset):
    def __init__(self, df):
        self.df = df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        audio_path = row['file_path']
        label = int(row['label'])
        audio, original_sr = librosa.load(audio_path, sr=44100)
        # Resample to the target sample rate (if needed)
        target_sr = 16000
        if original_sr != target_sr:
            audio = librosa.resample(audio, orig_sr=original_sr, target_sr=target_sr)
        # Extract mel features and convert to decibels
        mel_spectrogram = librosa.feature.melspectrogram(y=audio, sr=target_sr, n_mels=80, hop_length=512)
        mel_spectrogram_db = librosa.power_to_db(mel_spectrogram)
        # Truncate, then pad the time axis to a fixed length
        max_len = 3000  # Replace with your desired maximum length
        mel_spectrogram_db = mel_spectrogram_db[:, :max_len]
        pad_width = (0, max_len - mel_spectrogram_db.shape[1])  # (left, right) padding on the time axis
        input_features = torch.nn.functional.pad(
            torch.from_numpy(mel_spectrogram_db).float(),
            pad_width, mode='constant', value=0)
        # Return the keys the model and Trainer expect
        return {'input_features': input_features, 'labels': label}
# Build the training dataset (the Trainer creates its own DataLoader internally)
train_dataset = LispDataset(df)
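# Optional sanity check (assumes dataset.csv points at valid audio files): one item
# should yield an (n_mels, max_len) = (80, 3000) feature matrix and an integer label.
sample = train_dataset[0]
print(sample['input_features'].shape, sample['labels'])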
from transformers import TrainingArguments
# Training arguments (adjust learning rate as needed)
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=10,
    per_device_train_batch_size=2,
    learning_rate=5e-5,
    fp16=False,  # mixed precision requires a CUDA device; this run stays on the CPU
    use_cpu=True,
    warmup_ratio=0.1,
    metric_for_best_model="accuracy",
    gradient_accumulation_steps=1  # no gradient accumulation
)
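# metric_for_best_model="accuracy" is only used once evaluation is configured, i.e. an
# eval_dataset plus a compute_metrics function passed to the Trainer (typically with
# load_best_model_at_end=True). A minimal sketch of such a function, assuming an eval
# split prepared the same way as train_dataset:
# import numpy as np
# def compute_metrics(eval_pred):
#     logits, labels = eval_pred
#     predictions = np.argmax(logits, axis=-1)
#     return {"accuracy": float((predictions == labels).mean())}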
from torch.optim import AdamW # Import AdamW from PyTorch
# Create the optimizer (adjust other hyperparameters as needed)
optimizer = AdamW(model.parameters(), lr=training_args.learning_rate)
from torch.optim.lr_scheduler import LambdaLR
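# LambdaLR multiplies the optimizer's base learning rate by the factor returned by the
# lambda. With the integer division below that factor is 0 for the first 30 scheduler
# steps and then grows by 1 every further 30 steps (the Trainer steps the scheduler
# once per optimization step, not per epoch).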
lambda1 = lambda epoch: epoch // 30
scheduler = LambdaLR(optimizer, lr_lambda=[lambda1])
optimizertuple = (optimizer, scheduler)
from transformers import Trainer
# Trainer instance
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    optimizers=optimizertuple,  # (optimizer, lr_scheduler) pair
)
# Start training
trainer.train()
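# Persist the fine-tuned model and processor so they can be reloaded for inference
# (the output path below is just an example; adjust as needed).
trainer.save_model("./results/final")
processor.save_pretrained("./results/final")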
# import soundfile as sf
""" # Define a custom collate function to handle variable-length audio samples
def collate_fn(batch):
# Pad audio samples to the same length
input_lengths = [len(sample[0]) for sample in batch]
max_length = max(input_lengths)
padded_inputs = torch.nn.utils.rnn.pad_sequence([torch.tensor(sample[0]) for sample in batch], batch_first=True, padding_value=0)
attention_mask = torch.tensor([[1] * length + [0] * (max_length - length) for length in input_lengths])
return {
"inputs": padded_inputs,
"attention_mask": attention_mask,
"labels": torch.tensor([sample[1] for sample in batch])
}
"""
"""
def collate_fn(batch):
# Pad audio samples to the same length
input_lengths = [len(sample[0]) for sample in batch]
max_length = max(input_lengths)
padded_inputs = torch.nn.utils.rnn.pad_sequence([torch.tensor(sample[0]) for sample in batch], batch_first=True, padding_value=0)
attention_mask = torch.tensor([[1] * length + [0] * (max_length - length) for length in input_lengths])
# Convert each element in batch to a dictionary
batch = [{'inputs': padded_inputs, 'attention_mask': attention_mask, 'labels': label} for inp, mask, label in zip(padded_inputs, attention_mask, batch)]
print (batch)
return batch """
"""
# train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
# lambda2 = lambda epoch: 0.95 ** epoch
# Load the audio file
audio, original_sr = librosa.load("dataset/lisp/sample_01.wav", sr=44100)
# Target sample rate
target_sr = 16000
# Resample the audio
audio_resampled = librosa.resample(audio, orig_sr=original_sr, target_sr=target_sr) """
""" inputs = processor(
audio_resampled, sampling_rate=target_sr, return_tensors="pt"
)
# Forward pass
with torch.no_grad():
logits = model(**inputs).logits
# Predict the class (0 for normal, 1 for lisp)
predicted_class = torch.argmax(logits, dim=1).item() """