from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq
from trl import SFTTrainer, SFTConfig
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import torch

# -------------------------
# Load dataset
# -------------------------
dataset = load_dataset("mavilov/convos", split="train")

# -------------------------
# Load model and tokenizer
# -------------------------
model_id = "swiss-ai/Apertus-8B-2509"

model_kwargs = {}
if torch.backends.mps.is_available():
    print("⚡ Using Apple MPS backend (Metal)")
    model_kwargs = {
        "dtype": torch.float16,
        "device_map": {"": "mps"},      # force load directly on MPS
        "offload_folder": "./offload",
        "low_cpu_mem_usage": True,      # avoid meta tensors
    }
elif torch.cuda.is_available():
    print("⚡ Using CUDA with bitsandbytes quantization")
    from transformers import BitsAndBytesConfig

    bnb_config = BitsAndBytesConfig(
        load_in_8bit=True,
        llm_int8_threshold=6.0
    )
    model_kwargs["quantization_config"] = bnb_config
    model_kwargs["device_map"] = "auto"
else:
    print("⚠️ No GPU/MPS detected, running on CPU (very slow)")
    model_kwargs = {
        "dtype": torch.float32,
        "device_map": {"": "cpu"},
        "low_cpu_mem_usage": True,
    }

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

# Load model safely
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    **model_kwargs
)
model.config.use_cache = False   # caching is incompatible with training / gradient checkpointing
model.config.pretraining_tp = 1

# If the base model was loaded in 8-bit (the CUDA/bitsandbytes path above),
# prepare it for k-bit training before attaching LoRA adapters.
if getattr(model, "is_loaded_in_8bit", False):
    model = prepare_model_for_kbit_training(model)

# -------------------------
# Attach LoRA adapters
# -------------------------
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)
model = get_peft_model(model, lora_config)

# -------------------------
# Preprocess / tokenize dataset
# -------------------------
def tokenize_fn(example):
    tokenized = tokenizer(
        example["text"],
        truncation=True,
        max_length=2048
    )
    # For causal LM fine-tuning the labels are the input ids themselves.
    tokenized["labels"] = tokenized["input_ids"].copy()
    return tokenized

# Drop the raw columns so only input_ids / attention_mask / labels reach the collator.
dataset = dataset.map(tokenize_fn, batched=True, remove_columns=dataset.column_names)

# -------------------------
# Data collator with dynamic padding
# -------------------------
data_collator = DataCollatorForSeq2Seq(tokenizer, padding="longest")

# -------------------------
# Training configuration
# -------------------------
training_args = SFTConfig(
    output_dir="./results",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    num_train_epochs=3,
    logging_steps=10,
    report_to="tensorboard",
    bf16=False,
)

# -------------------------
# Initialize trainer
# -------------------------
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    args=training_args,
    data_collator=data_collator,
    processing_class=tokenizer,   # so checkpoints also include the tokenizer
)

# -------------------------
# Start training
# -------------------------
trainer.train()
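
# -------------------------
# Save the LoRA adapter (optional follow-up)
# -------------------------
# A minimal sketch, not part of the original script: persist the trained LoRA
# adapter and the tokenizer so they can later be re-attached to the base model.
# The output path "./results/lora-adapter" is an assumption.
trainer.save_model("./results/lora-adapter")        # writes only the PEFT adapter weights
tokenizer.save_pretrained("./results/lora-adapter")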