Add all references for training the model with the BioMistral process
Files changed:
- app.py (+33 -32)
- requirements.txt (+8 -2)
- spanish_medica_llm.py (+303 -1)
app.py
CHANGED
@@ -9,7 +9,7 @@ import sys
 
 import torch
 
-from spanish_medica_llm import run_training
+from spanish_medica_llm import run_training, run_training_process
 
 import gradio as gr
 
@@ -42,37 +42,38 @@ def train_model(*inputs):
     if "IS_SHARED_UI" in os.environ:
         raise gr.Error("This Space only works in duplicated instances")
 
-    args_general = argparse.Namespace(
-        image_captions_filename = True,
-        train_text_encoder = True,
-        #stop_text_encoder_training = stptxt,
-        save_n_steps = 0,
-        #pretrained_model_name_or_path = model_to_load,
-        instance_data_dir="instance_images",
-        #class_data_dir=class_data_dir,
-        output_dir="output_model",
-        instance_prompt="",
-        seed=42,
-        resolution=512,
-        mixed_precision="fp16",
-        train_batch_size=1,
-        gradient_accumulation_steps=1,
-        use_8bit_adam=True,
-        learning_rate=2e-6,
-        lr_scheduler="polynomial",
-        lr_warmup_steps = 0,
-        #max_train_steps=Training_Steps,
-    )
-    run_training(args_general)
-    torch.cuda.empty_cache()
-    #convert("output_model", "model.ckpt")
-    #shutil.rmtree('instance_images')
-    #shutil.make_archive("diffusers_model", 'zip', "output_model")
-    #with zipfile.ZipFile('diffusers_model.zip', 'w', zipfile.ZIP_DEFLATED) as zipf:
-    # zipdir('output_model/', zipf)
-    torch.cuda.empty_cache()
-    return [gr.update(visible=True, value=["diffusers_model.zip"]), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)]
-
+    # args_general = argparse.Namespace(
+    # image_captions_filename = True,
+    # train_text_encoder = True,
+    # #stop_text_encoder_training = stptxt,
+    # save_n_steps = 0,
+    # #pretrained_model_name_or_path = model_to_load,
+    # instance_data_dir="instance_images",
+    # #class_data_dir=class_data_dir,
+    # output_dir="output_model",
+    # instance_prompt="",
+    # seed=42,
+    # resolution=512,
+    # mixed_precision="fp16",
+    # train_batch_size=1,
+    # gradient_accumulation_steps=1,
+    # use_8bit_adam=True,
+    # learning_rate=2e-6,
+    # lr_scheduler="polynomial",
+    # lr_warmup_steps = 0,
+    # #max_train_steps=Training_Steps,
+    # )
+    # run_training(args_general)
+    # torch.cuda.empty_cache()
+    # #convert("output_model", "model.ckpt")
+    # #shutil.rmtree('instance_images')
+    # #shutil.make_archive("diffusers_model", 'zip', "output_model")
+    # #with zipfile.ZipFile('diffusers_model.zip', 'w', zipfile.ZIP_DEFLATED) as zipf:
+    # # zipdir('output_model/', zipf)
+    # torch.cuda.empty_cache()
+    # return [gr.update(visible=True, value=["diffusers_model.zip"]), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)]
+    run_training_process()
+    return f"Train Model Sucessful!!!"
 def stop_model(*input):
     return f"Model with Gradio!"
 
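With this change, train_model simply launches run_training_process() and returns a status string. A minimal sketch of how a Gradio UI could surface that string (hypothetical wiring; the Blocks layout, component names, and launch call are assumptions and are not part of this commit):

    import gradio as gr

    # Hypothetical layout; the real app.py builds its own interface.
    with gr.Blocks() as demo:
        status = gr.Textbox(label="Status")
        train_btn = gr.Button("Train model")
        stop_btn = gr.Button("Stop model")
        # train_model(*inputs) accepts arbitrary inputs, so an empty list works here.
        train_btn.click(train_model, inputs=[], outputs=status)
        stop_btn.click(stop_model, inputs=[], outputs=status)

    demo.launch()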
requirements.txt
CHANGED
@@ -1,2 +1,8 @@
-transformers
-torch
+transformers==4.38.0
+torch>=2.1.1+cu113
+trl @ git+https://github.com/huggingface/trl
+peft
+wandb
+accelerate
+datasets
+bitsandbytes
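Once the Space (or a local environment) has installed these pins, a quick sanity check that the stack resolved and sees a GPU could be (a minimal check, not part of the commit):

    # Prints installed versions and whether CUDA is visible.
    import torch, transformers, peft, datasets, accelerate, bitsandbytes, trl, wandb

    print("transformers", transformers.__version__)
    print("torch", torch.__version__, "cuda:", torch.cuda.is_available())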
spanish_medica_llm.py
CHANGED
@@ -6,8 +6,60 @@ from pathlib import Path
 from typing import Optional
 import subprocess
 import sys
+from datetime import datetime
+from dataclasses import dataclass, field
+from typing import Optional
+
+import numpy as np
 import torch
-import
+from datasets import load_dataset, concatenate_datasets
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    BitsAndBytesConfig,
+    TrainingArguments,
+    Trainer,
+    DataCollatorForLanguageModeling
+)
+
+from accelerate import FullyShardedDataParallelPlugin, Accelerator
+from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig
+from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+import wandb
+from trl import SFTTrainer
+
+
+CHAT_ML_TEMPLATE_Mistral_7B_Instruct = """
+{% if messages[0]['role'] == 'system' %}
+{% set loop_messages = messages[1:] %}
+{% set system_message = messages[0]['content'].strip() + '\n\n' %}
+{% else %}
+{% set loop_messages = messages %}
+{% set system_message = '' %}
+{% endif %}
+
+{{ bos_token }}
+{% for message in loop_messages %}
+{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
+{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
+{% endif %}
+
+{% if loop.index0 == 0 %}
+{% set content = system_message + message['content'] %}
+{% else %}
+{% set content = message['content'] %}
+{% endif %}
+
+{% if message['role'] == 'user' %}
+{{ '[INST] ' + content.strip() + ' [/INST]' }}
+{% elif message['role'] == 'assistant' %}
+{{ ' ' + content.strip() + ' ' + eos_token }}
+{% endif %}
+{% endfor %}
+"""
+
+
+
 
 def parse_args():
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
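The Jinja string added above mirrors the Mistral-7B-Instruct chat format: an optional system message is folded into the first user turn, user turns are wrapped in [INST] ... [/INST], and assistant turns end with the EOS token. A minimal sketch of how it renders once attached to a tokenizer, as loadSpanishTokenizer() in the next hunk does (the tokenizer choice and example messages here are illustrative only):

    from transformers import AutoTokenizer

    # Any tokenizer can render the template; this is the Spanish GPT-2 tokenizer used elsewhere in this module.
    tokenizer = AutoTokenizer.from_pretrained("DeepESP/gpt2-spanish-medium")
    tokenizer.chat_template = CHAT_ML_TEMPLATE_Mistral_7B_Instruct

    messages = [
        {"role": "system", "content": "Eres un asistente médico."},
        {"role": "user", "content": "¿Qué es la hipertensión arterial?"},
    ]
    # Yields, modulo surrounding whitespace, the tokenizer's BOS token followed by
    # "[INST] Eres un asistente médico.\n\n¿Qué es la hipertensión arterial? [/INST]"
    print(tokenizer.apply_chat_template(messages, tokenize=False))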
@@ -248,3 +300,253 @@ def run_training(args_imported):
     args_default = parse_args()
     #args = merge_args(args_default, args_imported)
     return(args)
+
+
+
+TOKEN_NAME = "DeepESP/gpt2-spanish-medium"
+TOKEN_MISTRAL_NAME = "mistralai/Mistral-7B-Instruct-v0.1"
+SPANISH_MEDICA_LLM_DATASET = "somosnlp/spanish_medica_llm"
+
+TOPIC_TYPE_DIAGNOSTIC = 'medical_diagnostic'
+TOPIC_TYPE_TRATAMIENT = 'medical_topic'
+FILTER_CRITERIA = [TOPIC_TYPE_DIAGNOSTIC, TOPIC_TYPE_TRATAMIENT]
+CONTEXT_LENGTH = 256 #Max of tokens
+
+MISTRAL_BASE_MODEL_ID = "BioMistral/BioMistral-7B"
+
+MICRO_BATCH_SIZE = 16 #32 For other GPU BIGGER THAN T4
+BATCH_SIZE = 64 #128 For other GPU BIGGER THAN T4
+GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
+
+PROJECT_NAME = "spanish-medica-llm"
+BASE_MODEL_NAME = "biomistral"
+run_name = BASE_MODEL_NAME + "-" + PROJECT_NAME
+output_dir = "./" + run_name
+
+HUB_MODEL_ID = 'somosnlp/spanish_medica_llm'
+MAX_TRAINING_STEPS = int(1500/2)
+
+def loadSpanishTokenizer():
+    """
+
+    """
+    #Load first the mistral used tokenizer
+    tokenizerMistrall = AutoTokenizer.from_pretrained(TOKEN_MISTRAL_NAME)
+
+    #Load second an spanish specialized tokenizer
+    tokenizer = AutoTokenizer.from_pretrained(
+        TOKEN_NAME,
+        eos_token = tokenizerMistrall.special_tokens_map['eos_token'],
+        bos_token = tokenizerMistrall.special_tokens_map['bos_token'],
+        unk_token = tokenizerMistrall.special_tokens_map['unk_token']
+    )
+    tokenizer.chat_template = CHAT_ML_TEMPLATE_Mistral_7B_Instruct
+
+    return tokenizer
+
+def tokenize(element, tokenizer):
+    outputs = tokenizer(
+        element["raw_text"],
+        truncation = True,
+        max_length = CONTEXT_LENGTH,
+        return_overflowing_tokens = True,
+        return_length = True,
+    )
+    input_batch = []
+    for length, input_ids in zip(outputs["length"], outputs["input_ids"]):
+        if length == CONTEXT_LENGTH:
+            input_batch.append(input_ids)
+    return {"input_ids": input_batch}
+
+def splitDatasetInTestValid(dataset):
+    """
+    """
+
+    if dataset == None or dataset['train'] == None:
+        return dataset
+    elif dataset['test'] == None:
+        return None
+    else:
+        test_eval = dataset['test'].train_test_split(test_size=0.001)
+        eval_dataset = test_eval['train']
+        test_dataset = test_eval['test']
+
+        return (dataset['train'], eval_dataset, test_dataset)
+
+def loadSpanishDataset():
+    spanishMedicaLllmDataset = load_dataset(SPANISH_MEDICA_LLM_DATASET, split="train")
+    spanishMedicaLllmDataset = spanishMedicaLllmDataset.filter(lambda example: example["topic_type"] not in FILTER_CRITERIA)
+    spanishMedicaLllmDataset = spanishMedicaLllmDataset.train_test_split(0.2, seed=203984)
+    return spanishMedicaLllmDataset
+
+##See Jupyter Notebook for change CONTEXT_LENGTH size
+
+def accelerateConfigModel():
+    """
+    Only with GPU support
+    RuntimeError: There are currently no available devices found, must be one of 'XPU', 'CUDA', or 'NPU'.
+    """
+    fsdp_plugin = FullyShardedDataParallelPlugin(
+        state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
+        optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False),
+    )
+
+    return Accelerator(fsdp_plugin=fsdp_plugin)
+
+def getTokenizedDataset(dataset, tokenizer):
+    if dataset == None or tokenizer == None:
+        return dataset
+
+    return dataset.map(
+        tokenize,
+        batched = True,
+        remove_columns = dataset["train"].column_names
+    )
+
+def loadBaseModel(base_model_id):
+
+    if base_model_id in [ "", None]:
+        return None
+    else:
+        bnb_config = BitsAndBytesConfig(
+            load_in_4bit = True,
+            bnb_4bit_quant_type = "nf4",
+            bnb_4bit_use_double_quant = True,
+            bnb_4bit_compute_dtype = torch.bfloat16
+        )
+
+        model = AutoModelForCausalLM.from_pretrained(
+            base_model_id,
+            quantization_config = bnb_config
+        )
+
+        model.gradient_checkpointing_enable()
+        model = prepare_model_for_kbit_training(model)
+
+        return model
+
+def print_trainable_parameters(model):
+    """
+    Prints the number of trainable parameters in the model.
+    """
+    trainable_params = 0
+    all_param = 0
+    for _, param in model.named_parameters():
+        all_param += param.numel()
+        if param.requires_grad:
+            trainable_params += param.numel()
+    print(
+        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
+    )
+
+def modelLoraConfigBioMistral(model):
+    """
+    r is the rank of the low-rank matrix used in the adapters, which thus controls
+    the number of parameters trained. A higher rank will allow for more expressivity, but there is a
+    compute tradeoff.
+    alpha is the scaling factor for the learned weights. The weight matrix is scaled by
+    alpha/r, and thus a higher value for alpha assigns more weight to the LoRA activations.
+    The values used in the QLoRA paper were r=64 and lora_alpha=16,
+    and these are said to generalize well, but we will use r=8 and lora_alpha=16 so that we have more emphasis on the new fine-tuned data while also reducing computational complexity.
+    """
+    if model == None:
+        return model
+    else:
+        config = LoraConfig(
+            r=8,
+            lora_alpha=16,
+            target_modules=[
+                "q_proj",
+                "k_proj",
+                "v_proj",
+                "o_proj",
+                "gate_proj",
+                "up_proj",
+                "down_proj",
+                "lm_head",
+            ],
+            bias="none",
+            lora_dropout=0.05, # Conventional
+            task_type="CAUSAL_LM",
+        )
+
+        model = get_peft_model(model, config)
+        print_trainable_parameters(model)
+
+        accelerator = accelerateConfigModel()
+        # Apply the accelerator. You can comment this out to remove the accelerator.
+        model = accelerator.prepare_model(model)
+        return (model)
+
+
+# A note on training. You can set the max_steps to be high initially, and examine at what step your
+# model's performance starts to degrade. There is where you'll find a sweet spot for how many steps
+# to perform. For example, say you start with 1000 steps, and find that at around 500 steps
+# the model starts overfitting - the validation loss goes up (bad) while the training
+# loss goes down significantly, meaning the model is learning the training set really well,
+# but is unable to generalize to new datapoints. Therefore, 500 steps would be your sweet spot,
+# so you would use the checkpoint-500 model repo in your output dir (biomistral-medqa-finetune)
+# as your final model in step 6 below.
+
+
+
+def configAndRunTraining(basemodel, dataset, eval_dataset, tokenizer):
+    if basemodel is None or dataset is None or tokenizer is None:
+        return None
+    else:
+        tokenizer.pad_token = tokenizer.eos_token
+        data_collator_pretrain = DataCollatorForLanguageModeling(tokenizer, mlm = False)
+
+        training_args = transformers.TrainingArguments(
+            output_dir=output_dir,
+            push_to_hub = True,
+            hub_private_repo = False,
+            hub_model_id = HUB_MODEL_ID,
+            warmup_steps =5,
+            per_device_train_batch_size = MICRO_BATCH_SIZE,
+            per_device_eval_batch_size=1,
+            #gradient_checkpointing=True,
+            gradient_accumulation_steps = GRADIENT_ACCUMULATION_STEPS,
+            max_steps = MAX_TRAINING_STEPS,
+            learning_rate = 2.5e-5, # Want about 10x smaller than the Mistral learning rate
+            logging_steps = 50,
+            optim="paged_adamw_8bit",
+            logging_dir="./logs", # Directory for storing logs
+            save_strategy = "steps", # Save the model checkpoint every logging step
+            save_steps = 50, # Save checkpoints every 50 steps
+            evaluation_strategy = "steps", # Evaluate the model every logging step
+            eval_steps = 50, # Evaluate and save checkpoints every 50 steps
+            do_eval = True, # Perform evaluation at the end of training
+            #report_to="wandb", # Comment this out if you don't want to use weights & baises
+            run_name=f"{run_name}-{datetime.now().strftime('%Y-%m-%d-%H-%M')}" , # Name of the W&B run (optional)
+            fp16=True, #Set for GPU T4 for more powerful GPU as G-100 or another change to false and bf16 parameter
+            bf16=False
+        )
+
+        trainer = transformers.Trainer(
+            model= basemodel,
+            train_dataset = dataset['train'],
+            eval_dataset = eval_dataset,
+            args = training_args,
+            data_collator = data_collator_pretrain
+        )
+
+        basemodel.config.use_cache = False # silence the warnings. Please re-enable for inference!
+        trainer.train()
+
+
+        trainer.push_to_hub()
+
+
+def run_training_process():
+
+    tokenizer = loadSpanishTokenizer()
+    medicalSpanishDataset = loadSpanishDataset()
+    train_dataset, eval_dataset, test_dataset = splitDatasetInTestValid(
+        getTokenizedDataset( medicalSpanishDataset, tokenizer)
+    )
+
+    base_model = loadBaseModel(MISTRAL_BASE_MODEL_ID)
+    base_model = modelLoraConfigBioMistral(base_model)
+
+    configAndRunTraining(base_model,train_dataset, eval_dataset, tokenizer)
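run_training_process() is the entry point that app.py now calls. After training, the code above disables use_cache and pushes checkpoints to the Hub; a minimal, hypothetical sketch of loading one of the resulting LoRA checkpoints back for generation (the checkpoint path is an assumption, the 4-bit config simply mirrors loadBaseModel(), and the base model's own tokenizer is used here for simplicity):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
    from peft import PeftModel

    # Assumed checkpoint path under output_dir ("./biomistral-spanish-medica-llm"); adjust to a real one.
    ADAPTER_PATH = "./biomistral-spanish-medica-llm/checkpoint-750"

    # Same 4-bit quantization setup as loadBaseModel() above.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    base = AutoModelForCausalLM.from_pretrained("BioMistral/BioMistral-7B", quantization_config=bnb_config)
    model = PeftModel.from_pretrained(base, ADAPTER_PATH)
    model.config.use_cache = True  # re-enable the cache for inference, as the training code notes

    tokenizer = AutoTokenizer.from_pretrained("BioMistral/BioMistral-7B")
    inputs = tokenizer("El paciente presenta fiebre y tos persistente.", return_tensors="pt").to(model.device)
    output = model.generate(**inputs, max_new_tokens=64)
    print(tokenizer.decode(output[0], skip_special_tokens=True))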