import csv
import json
import os
import random

import pandas as pd

# Label value used to mark positions that must not contribute to the loss.
_IGNORE_INDEX = -100
# Number of examples included in the preview DataFrame.
_PREVIEW_ROWS = 5


def _validate_csv(file_path, options):
    """Return an error message for an invalid CSV dataset, or None if valid."""
    separator = options.get("csv_separator", ",")
    prompt_col = options.get("csv_prompt_col", "prompt")
    completion_col = options.get("csv_completion_col", "completion")

    df = pd.read_csv(file_path, sep=separator)

    if prompt_col not in df.columns:
        return f"Prompt column '{prompt_col}' not found in CSV file"
    if completion_col not in df.columns:
        return f"Completion column '{completion_col}' not found in CSV file"

    # A header-only file has nothing to train on; reject it here, in line
    # with the JSONL branch, instead of letting processing see zero rows.
    if df.empty:
        return "CSV file is empty"

    if df[prompt_col].isnull().any():
        return "CSV file contains empty prompt values"
    if df[completion_col].isnull().any():
        return "CSV file contains empty completion values"
    return None


def _validate_jsonl(file_path, options):
    """Return an error message for an invalid JSONL dataset, or None if valid."""
    prompt_key = options.get("jsonl_prompt_key", "prompt")
    completion_key = options.get("jsonl_completion_key", "completion")

    record_count = 0
    with open(file_path, 'r', encoding='utf-8') as f:
        # Report the physical line number so the message still points at the
        # right place when the file contains blank lines.
        for line_number, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                continue

            try:
                data = json.loads(line)
            except json.JSONDecodeError as e:
                return f"Invalid JSON at line {line_number}: {e}"
            record_count += 1

            if not isinstance(data, dict):
                return f"Line {line_number} is not a JSON object"
            if prompt_key not in data:
                return f"Prompt key '{prompt_key}' not found in JSONL at line {line_number}"
            if completion_key not in data:
                return f"Completion key '{completion_key}' not found in JSONL at line {line_number}"
            if not data[prompt_key] or not isinstance(data[prompt_key], str):
                return f"Invalid prompt value at line {line_number}"
            if not data[completion_key] or not isinstance(data[completion_key], str):
                return f"Invalid completion value at line {line_number}"

    if record_count == 0:
        return "JSONL file is empty"
    return None


def _validate_text(file_path, options):
    """Return an error message for an invalid plain-text dataset, or None if valid."""
    separator = options.get("text_separator", "###")

    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()

    parts = content.split(separator)

    # At least "prompt<sep>completion<sep>" is required, i.e. three parts.
    if len(parts) < 3:
        return f"Text file doesn't contain enough sections separated by '{separator}'"

    # An odd part count is required: prompt/completion pairs followed by the
    # text after the final separator (which processing ignores).
    if len(parts) % 2 == 0:
        return f"Text file has an invalid number of sections separated by '{separator}'"
    return None


def validate_dataset(file_path, options):
    """
    Validates that a dataset file can be processed with the given options.

    Args:
        file_path: Path to the dataset file
        options: Dictionary of processing options. "format" selects the
            parser ("csv", "jsonl" or "plain text", case-insensitive); the
            remaining keys are format-specific and all have defaults.

    Returns:
        Tuple of (is_valid, message). Any exception raised while reading or
        parsing the file is reported as an invalid dataset rather than
        propagated.
    """
    if not os.path.exists(file_path):
        return False, f"File not found: {file_path}"

    file_format = options.get("format", "").lower()

    try:
        if file_format == "csv":
            error = _validate_csv(file_path, options)
        elif file_format == "jsonl":
            error = _validate_jsonl(file_path, options)
        elif file_format == "plain text":
            error = _validate_text(file_path, options)
        else:
            error = f"Unsupported format: {file_format}"
    except Exception as e:
        # Validation boundary: surface I/O and parser failures as a message.
        return False, f"Error validating dataset: {str(e)}"

    if error is not None:
        return False, error
    return True, "Dataset is valid"


def process_dataset(file_path, options):
    """
    Processes a dataset file according to the given options.

    Args:
        file_path: Path to the dataset file
        options: Dictionary of processing options

    Returns:
        Tuple of (processed_data, stats, preview)

    Raises:
        ValueError: If options["format"] is not a supported format.
    """
    file_format = options.get("format", "").lower()

    if file_format == "csv":
        return _process_csv(file_path, options)
    if file_format == "jsonl":
        return _process_jsonl(file_path, options)
    if file_format == "plain text":
        return _process_text(file_path, options)
    raise ValueError(f"Unsupported format: {file_format}")


def _build_stats(data, format_name):
    """Summarise prompt/completion pairs; averages are 0.0 for an empty list."""
    count = len(data)
    if count:
        avg_prompt = sum(len(item["prompt"]) for item in data) / count
        avg_completion = sum(len(item["completion"]) for item in data) / count
    else:
        # Avoid ZeroDivisionError when the file yielded no usable examples.
        avg_prompt = 0.0
        avg_completion = 0.0

    return {
        "num_examples": count,
        "avg_prompt_length": avg_prompt,
        "avg_completion_length": avg_completion,
        "format": format_name
    }


def _build_preview(data):
    """Build a DataFrame holding the first few prompt/completion pairs."""
    return pd.DataFrame(data[:_PREVIEW_ROWS], columns=["prompt", "completion"])


def _process_csv(file_path, options):
    """Process a CSV dataset file into (data, stats, preview)."""
    separator = options.get("csv_separator", ",")
    prompt_col = options.get("csv_prompt_col", "prompt")
    completion_col = options.get("csv_completion_col", "completion")

    df = pd.read_csv(file_path, sep=separator)

    # Iterate column-wise rather than with iterrows(): iterrows() builds one
    # Series per row, which upcasts mixed int/float rows (turning 1 into
    # "1.0") and is much slower.
    data = [
        {"prompt": str(prompt), "completion": str(completion)}
        for prompt, completion in zip(df[prompt_col].tolist(), df[completion_col].tolist())
    ]

    stats = _build_stats(data, "csv")

    # The preview keeps the caller's original column names.
    preview = df[[prompt_col, completion_col]].head(_PREVIEW_ROWS)

    return data, stats, preview


def _process_jsonl(file_path, options):
    """Process a JSONL dataset file into (data, stats, preview)."""
    prompt_key = options.get("jsonl_prompt_key", "prompt")
    completion_key = options.get("jsonl_completion_key", "completion")

    data = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue

            item = json.loads(line)
            data.append({
                "prompt": item[prompt_key],
                "completion": item[completion_key]
            })

    return data, _build_stats(data, "jsonl"), _build_preview(data)


def _process_text(file_path, options):
    """Process a plain text dataset file into (data, stats, preview)."""
    separator = options.get("text_separator", "###")

    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()

    parts = content.split(separator)

    # Sections alternate prompt, completion, prompt, completion, ...; a
    # trailing unpaired section is ignored and blank pairs are skipped.
    data = []
    for prompt_part, completion_part in zip(parts[0::2], parts[1::2]):
        prompt = prompt_part.strip()
        completion = completion_part.strip()
        if prompt and completion:
            data.append({
                "prompt": prompt,
                "completion": completion
            })

    return data, _build_stats(data, "text"), _build_preview(data)


def format_for_training(dataset, tokenizer, max_length=512):
    """
    Formats a processed dataset for training with Gemma.

    Each example is tokenized as "<prompt><eos><completion><eos>", padded or
    truncated to max_length. Labels copy the input ids for the completion and
    are set to -100 for prompt tokens and padding positions so that those
    positions are ignored in the loss calculation.

    Args:
        dataset: List of prompt/completion pairs
        tokenizer: Tokenizer for the model. It must be callable with
            (text, max_length=..., padding=..., truncation=...) and
            (text, add_special_tokens=False), return "input_ids" and
            "attention_mask", and expose an eos_token attribute.
        max_length: Maximum sequence length

    Returns:
        Dictionary of training data with "input_ids", "attention_mask" and
        "labels", each a list with one entry per example.
    """
    input_ids = []
    labels = []
    attention_mask = []

    eos = tokenizer.eos_token

    for item in dataset:
        prompt = item["prompt"]
        completion = item["completion"]

        # Format as the model expects
        full_text = f"{prompt}{eos}{completion}{eos}"

        encoded = tokenizer(full_text, max_length=max_length, padding="max_length", truncation=True)
        ids = encoded["input_ids"]
        mask = encoded["attention_mask"]

        input_ids.append(ids)
        attention_mask.append(mask)

        # NOTE(review): if the tokenizer prepends special tokens (e.g. BOS)
        # to full_text, this count is short by that many tokens - confirm
        # against the tokenizer in use.
        prompt_encoded = tokenizer(f"{prompt}{eos}", add_special_tokens=False)
        prompt_length = len(prompt_encoded["input_ids"])

        # With left padding the prompt starts after the leading pad tokens,
        # so shift the masked region by the number of leading zeros in the
        # attention mask (0 for right padding).
        pad_offset = next((i for i, flag in enumerate(mask) if flag), len(mask))
        masked_prefix = pad_offset + prompt_length

        label = [_IGNORE_INDEX] * masked_prefix + list(ids[masked_prefix:])

        # Force the label row to exactly max_length entries.
        label = label[:max_length] + [_IGNORE_INDEX] * (max_length - len(label))

        # Padding positions must not contribute to the loss either.
        for position, flag in enumerate(mask[:max_length]):
            if not flag:
                label[position] = _IGNORE_INDEX

        labels.append(label)

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }


def create_train_val_split(dataset, val_size=0.1, seed=42):
    """
    Splits a dataset into training and validation sets.

    The input list is left untouched and the global random state is not
    reseeded; a copy is shuffled with a private generator seeded by `seed`.

    Args:
        dataset: List of examples
        val_size: Fraction of examples to use for validation. At least one
            example is always assigned to validation when the dataset is
            non-empty.
        seed: Random seed for reproducibility

    Returns:
        Tuple of (train_dataset, val_dataset)
    """
    shuffled = list(dataset)
    random.Random(seed).shuffle(shuffled)

    val_count = max(1, int(len(shuffled) * val_size))

    val_dataset = shuffled[:val_count]
    train_dataset = shuffled[val_count:]

    return train_dataset, val_dataset