import argparse
import os

import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split, TensorDataset

from src.dataset import TokenizerDataset
# CustomBERTModel (instantiated in main) is assumed to be defined in src.bert
# alongside BERT; adjust this import if it lives elsewhere.
from src.bert import BERT, CustomBERTModel
from src.pretrainer import BERTFineTuneTrainer1
from src.vocab import Vocab

def preprocess_labels(label_csv_path):
    """Read the label CSV and return the 'last_hint_class' column as a LongTensor."""
    try:
        labels_df = pd.read_csv(label_csv_path)
        labels = labels_df['last_hint_class'].values.astype(int)
        return torch.tensor(labels, dtype=torch.long)
    except Exception as e:
        print(f"Error reading dataset file: {e}")
        return None
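
# For reference, a hypothetical example of the label CSV that preprocess_labels
# expects (only the 'last_hint_class' column is read; any other columns, such
# as the illustrative 'student_id' below, are ignored):
#
#   student_id,last_hint_class
#   s001,0
#   s002,1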

def preprocess_data(data_path, vocab, max_length=128):
    """Tokenize one sequence per line, pad/truncate to max_length, and stack tensors."""
    try:
        with open(data_path, 'r') as f:
            sequences = f.readlines()
    except Exception as e:
        print(f"Error reading data file: {e}")
        return None, None

    tokenized_sequences = []
    for sequence in sequences:
        sequence = sequence.strip()
        if not sequence:
            continue
        encoded = vocab.to_seq(sequence, seq_len=max_length)
        # Truncate first, then pad with the [PAD] id (0 if absent) up to max_length.
        encoded = encoded[:max_length]
        encoded += [vocab.vocab.get('[PAD]', 0)] * (max_length - len(encoded))
        segment_label = [0] * max_length
        tokenized_sequences.append({
            'input_ids': torch.tensor(encoded),
            'segment_label': torch.tensor(segment_label)
        })

    if not tokenized_sequences:
        print("No non-empty sequences found in data file.")
        return None, None

    input_ids = torch.stack([t['input_ids'] for t in tokenized_sequences], dim=0)
    segment_labels = torch.stack([t['segment_label'] for t in tokenized_sequences], dim=0)

    print(f"Input IDs shape: {input_ids.shape}")
    print(f"Segment labels shape: {segment_labels.shape}")
    return input_ids, segment_labels
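
# A minimal standalone sanity check for preprocess_data (sketch; the file
# paths below are hypothetical). Both returned tensors should have shape
# (num_sequences, max_length):
#
#   vocab = Vocab('vocab.txt')
#   vocab.load_vocab()
#   ids, segs = preprocess_data('sequences.txt', vocab, max_length=50)
#   assert ids.shape == segs.shape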

def custom_collate_fn(batch):
    """Collate (input_ids, segment_label, label) tuples from the TensorDataset
    into the dict format expected by BERTFineTuneTrainer1."""
    # TensorDataset yields positional tuples, not dicts, so index by position.
    inputs = torch.stack([item[0] for item in batch], dim=0)
    segment_labels = torch.stack([item[1] for item in batch], dim=0)
    labels = torch.stack([item[2] for item in batch], dim=0)
    return {
        'input': inputs,
        'label': labels,
        'segment_label': segment_labels
    }
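
# Sketch of what custom_collate_fn produces, using synthetic tensors with the
# shapes built above (shapes only; not part of the training pipeline):
#
#   batch = [(torch.zeros(50, dtype=torch.long),    # input_ids
#             torch.zeros(50, dtype=torch.long),    # segment_label
#             torch.tensor(1)) for _ in range(4)]   # label
#   out = custom_collate_fn(batch)
#   # out['input']: (4, 50), out['segment_label']: (4, 50), out['label']: (4,)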

def main(opt):
    # Set device to GPU if available, otherwise use CPU
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Load vocabulary
    vocab = Vocab(opt.vocab_file)
    vocab.load_vocab()

    # Preprocess data and labels (sequence length 50)
    input_ids, segment_labels = preprocess_data(opt.data_path, vocab, max_length=50)
    labels = preprocess_labels(opt.dataset)
    if input_ids is None or segment_labels is None or labels is None:
        print("Error in preprocessing data. Exiting.")
        return

    # Create TensorDataset and split 80/20 into train and validation sets
    dataset = TensorDataset(input_ids, segment_labels, labels)
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    # Create DataLoaders for training and validation
    train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=custom_collate_fn)
    val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=custom_collate_fn)

    # Initialize custom BERT model and move it to the device
    custom_model = CustomBERTModel(
        vocab_size=len(vocab.vocab),
        output_dim=2,
        pre_trained_model_path=opt.pre_trained_model_path
    ).to(device)

    # Initialize the fine-tuning trainer
    trainer = BERTFineTuneTrainer1(
        bert=custom_model,
        vocab_size=len(vocab.vocab),
        train_dataloader=train_dataloader,
        test_dataloader=val_dataloader,
        lr=1e-5,  # learning rate 10^-5 as specified
        num_labels=2,
        with_cuda=torch.cuda.is_available(),
        log_freq=10,
        workspace_name=opt.output_dir,
        log_folder_path=opt.log_folder_path
    )

    # Train the model
    trainer.train(epoch=20)

    # Save the model (torch.save on the module pickles the full object;
    # the class definition must be importable when reloading)
    os.makedirs(opt.output_dir, exist_ok=True)
    output_model_file = os.path.join(opt.output_dir, 'fine_tuned_model_3.pth')
    torch.save(custom_model, output_model_file)
    print(f'Model saved to {output_model_file}')
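
# To reload the saved model later (sketch: because torch.save above pickles
# the whole module, CustomBERTModel must be importable at load time):
#
#   model = torch.load(output_model_file, map_location='cpu')
#   model.eval()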

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Fine-tune BERT model.')
    parser.add_argument('--dataset', type=str, default='/home/jupyter/bert/dataset/hint_based/ratio_proportion_change_3/er/er_train.csv', help='Path to the label CSV file.')
    parser.add_argument('--data_path', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/_Aug23/gt/er.txt', help='Path to the input sequence file.')
    parser.add_argument('--output_dir', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/_Aug23/output/hint_classification', help='Directory to save the fine-tuned model.')
    parser.add_argument('--pre_trained_model_path', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/output/pretrain:1800ms:64hs:4l:8a:50s:64b:1000e:-5lr/bert_trained.seq_encoder.model.ep68', help='Path to the pre-trained BERT model.')
    parser.add_argument('--vocab_file', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/_Aug23/pretraining/vocab.txt', help='Path to the vocabulary file.')
    parser.add_argument('--log_folder_path', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/logs/oct', help='Path to the folder for saving logs.')
    opt = parser.parse_args()
    main(opt)
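
# Example invocation (the script filename is hypothetical; every flag has the
# default set above and can be omitted):
#
#   python finetune_hint_classifier.py \
#       --data_path /path/to/er.txt \
#       --dataset /path/to/er_train.csv \
#       --output_dir /path/to/output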