rawag committed · verified
Commit 344b95e · 1 Parent(s): 18f2c0e

Upload train.ipynb

Browse files
Files changed (1):
  1. train.ipynb +179 -0
train.ipynb ADDED
@@ -0,0 +1,179 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "vscode": {
+ "languageId": "shellscript"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "pip install transformers torch librosa pandas"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from transformers import WhisperForAudioClassification\n",
+ "\n",
+ "# Load the pre-trained Whisper model; the classification head is newly\n",
+ "# initialized (num_labels defaults to 2, i.e. binary classification)\n",
+ "model = WhisperForAudioClassification.from_pretrained(\"openai/whisper-medium\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "\n",
+ "# Load the CSV file\n",
+ "df = pd.read_csv('dataset.csv')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from transformers import WhisperProcessor\n",
+ "\n",
+ "# Initialize the Whisper processor (used below to extract log-mel input features)\n",
+ "processor = WhisperProcessor.from_pretrained(\"openai/whisper-medium\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import librosa\n",
+ "import torch\n",
+ "\n",
+ "# Create a custom dataset class\n",
+ "class LispDataset(torch.utils.data.Dataset):\n",
+ "    def __init__(self, df):\n",
+ "        self.df = df\n",
+ "\n",
+ "    def __len__(self):\n",
+ "        return len(self.df)\n",
+ "\n",
+ "    def __getitem__(self, idx):\n",
+ "        row = self.df.iloc[idx]\n",
+ "\n",
+ "        # Load audio directly at Whisper's expected sample rate (16 kHz)\n",
+ "        audio, _ = librosa.load(row['file_path'], sr=16000)\n",
+ "\n",
+ "        # Let the Whisper processor extract log-mel features; it pads or\n",
+ "        # truncates to 30 s and returns a tensor of shape (1, 80, 3000),\n",
+ "        # matching the model's expected feature layout (a hand-rolled mel\n",
+ "        # spectrogram with hop_length=512 would not)\n",
+ "        inputs = processor(audio, sampling_rate=16000, return_tensors='pt')\n",
+ "\n",
+ "        # Return the keys the Trainer expects\n",
+ "        return {\n",
+ "            'input_features': inputs.input_features.squeeze(0),\n",
+ "            'labels': torch.tensor(row['label'], dtype=torch.long)\n",
+ "        }\n",
+ "\n",
+ "# Create the training dataset\n",
+ "train_dataset = LispDataset(df)"
+ ]
+ },
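+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Quick sanity check: each sample should carry log-mel features of shape\n",
+ "# (80, 3000), which is what Whisper expects per example\n",
+ "sample = train_dataset[0]\n",
+ "print(sample['input_features'].shape, sample['labels'])"
+ ]
+ },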
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from transformers import TrainingArguments\n",
+ "\n",
+ "# Training arguments (adjust learning rate as needed)\n",
+ "training_args = TrainingArguments(\n",
+ "    output_dir='./results',\n",
+ "    num_train_epochs=10,\n",
+ "    per_device_train_batch_size=2,\n",
+ "    learning_rate=5e-5,\n",
+ "    fp16=True,  # mixed precision requires a GPU; set fp16=False (and use_cpu=True) for CPU-only runs\n",
+ "    warmup_ratio=0.1,\n",
+ "    metric_for_best_model='accuracy',  # only takes effect with an eval set and load_best_model_at_end=True\n",
+ "    gradient_accumulation_steps=1  # no gradient accumulation\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from torch.optim import AdamW  # Import AdamW from PyTorch\n",
+ "\n",
+ "# Create the optimizer (adjust other hyperparameters as needed)\n",
+ "optimizer = AdamW(model.parameters(), lr=training_args.learning_rate)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import math\n",
+ "from transformers import get_linear_schedule_with_warmup\n",
+ "\n",
+ "# Linear warmup then linear decay, matching the warmup_ratio set above.\n",
+ "# Note: LambdaLR with lambda epoch: epoch // 30 would multiply the LR by\n",
+ "# zero for the first 30 scheduler steps, so the model would never train.\n",
+ "steps_per_epoch = math.ceil(len(train_dataset) / training_args.per_device_train_batch_size)\n",
+ "num_training_steps = steps_per_epoch * int(training_args.num_train_epochs)\n",
+ "\n",
+ "scheduler = get_linear_schedule_with_warmup(\n",
+ "    optimizer,\n",
+ "    num_warmup_steps=int(training_args.warmup_ratio * num_training_steps),\n",
+ "    num_training_steps=num_training_steps\n",
+ ")\n",
+ "\n",
+ "optimizertuple = (optimizer, scheduler)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from transformers import Trainer\n",
+ "\n",
+ "# Trainer instance\n",
+ "trainer = Trainer(\n",
+ "    model=model,\n",
+ "    args=training_args,\n",
+ "    train_dataset=train_dataset,\n",
+ "    optimizers=optimizertuple  # (optimizer, scheduler) tuple\n",
+ ")\n",
+ "\n",
+ "# Start training\n",
+ "trainer.train()"
+ ]
+ },
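+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Minimal inference sketch for the fine-tuned classifier;\n",
+ "# 'sample.wav' is a placeholder path, not part of the dataset above\n",
+ "audio, _ = librosa.load('sample.wav', sr=16000)\n",
+ "inputs = processor(audio, sampling_rate=16000, return_tensors='pt')\n",
+ "\n",
+ "model.eval()\n",
+ "with torch.no_grad():\n",
+ "    logits = model(input_features=inputs.input_features.to(model.device)).logits\n",
+ "\n",
+ "# Highest-scoring class index (0/1 for the default binary head)\n",
+ "print(int(logits.argmax(dim=-1)))"
+ ]
+ }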
+ ],
+ "metadata": {
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+ }