| """ | |
| Domain Dataset Module for Cross-Domain Uncertainty Quantification | |
| This module provides functionality for loading and managing datasets from different domains | |
| for evaluating uncertainty quantification methods across domains. | |
| """ | |
| import os | |
| import json | |
| import pandas as pd | |
| import numpy as np | |
| from typing import List, Dict, Any, Union, Optional, Tuple | |
| from datasets import load_dataset | |


class DomainDataset:
    """Base class for domain-specific datasets."""

    def __init__(self, name: str, domain: str):
        """
        Initialize the domain dataset.

        Args:
            name: Name of the dataset
            domain: Domain category (e.g., 'medical', 'legal', 'general')
        """
        self.name = name
        self.domain = domain
        self.data = None

    def load(self) -> None:
        """Load the dataset."""
        raise NotImplementedError("Subclasses must implement this method")

    def get_samples(self, n: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        Get samples from the dataset.

        Args:
            n: Number of samples to return (None for all)

        Returns:
            List of samples with prompts and expected outputs
        """
        raise NotImplementedError("Subclasses must implement this method")

    def get_prompt_template(self) -> str:
        """
        Get the prompt template for this domain.

        Returns:
            Prompt template string
        """
        raise NotImplementedError("Subclasses must implement this method")
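
# Note (illustrative, matching the concrete implementations below): each subclass
# returns samples from get_samples() as plain dicts of the form
#     {'domain': '<domain>', 'question': '<question text>', 'prompt': '<formatted prompt>'}
# so downstream uncertainty-quantification code can iterate over every domain uniformly.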


class MedicalQADataset(DomainDataset):
    """Dataset for medical question answering."""

    def __init__(self, data_path: Optional[str] = None):
        """
        Initialize the medical QA dataset.

        Args:
            data_path: Path to the dataset file (None to use the default source)
        """
        super().__init__("medical_qa", "medical")
        self.data_path = data_path

    def load(self) -> None:
        """Load the medical QA dataset."""
        if self.data_path and os.path.exists(self.data_path):
            # Load from a local file if one is available
            if self.data_path.endswith('.csv'):
                self.data = pd.read_csv(self.data_path)
            elif self.data_path.endswith('.json'):
                with open(self.data_path, 'r') as f:
                    # Normalize JSON records into a DataFrame so the
                    # column-based access in get_samples() works
                    self.data = pd.DataFrame(json.load(f))
            else:
                raise ValueError(f"Unsupported file format: {self.data_path}")
        else:
            # Use a sample of the MedMCQA dataset from Hugging Face
            try:
                dataset = load_dataset("medmcqa", split="train[:100]")
                self.data = dataset.to_pandas()
            except Exception as e:
                # Fall back to synthetic data if dataset loading fails
                print(f"Failed to load MedMCQA dataset: {e}")
                self.data = self._create_synthetic_data()

    def _create_synthetic_data(self) -> pd.DataFrame:
        """Create synthetic medical QA data for testing."""
        questions = [
            "What are the common symptoms of myocardial infarction?",
            "How does insulin regulate blood glucose levels?",
            "What is the mechanism of action for ACE inhibitors?",
            "What are the diagnostic criteria for rheumatoid arthritis?",
            "How does the SARS-CoV-2 virus enter human cells?",
            "What are the main side effects of chemotherapy?",
            "How does the blood-brain barrier function?",
            "What is the pathophysiology of type 2 diabetes?",
            "How do vaccines create immunity?",
            "What are the stages of chronic kidney disease?"
        ]
        # Create a dataframe with questions only (answers would be generated by LLMs)
        return pd.DataFrame({
            'question': questions,
            'domain': ['medical'] * len(questions)
        })

    def get_samples(self, n: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        Get samples from the medical QA dataset.

        Args:
            n: Number of samples to return (None for all)

        Returns:
            List of samples with prompts
        """
        if self.data is None:
            self.load()

        if 'question' in self.data.columns:
            questions = self.data['question'].tolist()
        elif 'question_text' in self.data.columns:
            questions = self.data['question_text'].tolist()
        else:
            raise ValueError("Dataset does not contain a question column")

        if n is not None:
            questions = questions[:n]

        # Create samples with formatted prompts
        samples = []
        for question in questions:
            prompt = self.get_prompt_template().format(question=question)
            samples.append({
                'domain': 'medical',
                'question': question,
                'prompt': prompt
            })
        return samples

    def get_prompt_template(self) -> str:
        """
        Get the prompt template for the medical domain.

        Returns:
            Prompt template string
        """
        return "You are a medical expert. Please answer the following medical question accurately and concisely:\n\n{question}"


class LegalQADataset(DomainDataset):
    """Dataset for legal question answering."""

    def __init__(self, data_path: Optional[str] = None):
        """
        Initialize the legal QA dataset.

        Args:
            data_path: Path to the dataset file (None to use the default source)
        """
        super().__init__("legal_qa", "legal")
        self.data_path = data_path

    def load(self) -> None:
        """Load the legal QA dataset."""
        if self.data_path and os.path.exists(self.data_path):
            # Load from a local file if one is available
            if self.data_path.endswith('.csv'):
                self.data = pd.read_csv(self.data_path)
            elif self.data_path.endswith('.json'):
                with open(self.data_path, 'r') as f:
                    # Normalize JSON records into a DataFrame so the
                    # column-based access in get_samples() works
                    self.data = pd.DataFrame(json.load(f))
            else:
                raise ValueError(f"Unsupported file format: {self.data_path}")
        else:
            # Use synthetic data for the legal domain
            self.data = self._create_synthetic_data()

    def _create_synthetic_data(self) -> pd.DataFrame:
        """Create synthetic legal QA data for testing."""
        questions = [
            "What constitutes a breach of contract?",
            "How is intellectual property protected under international law?",
            "What are the elements of negligence in tort law?",
            "How does the doctrine of stare decisis function in common law systems?",
            "What rights are protected under the Fourth Amendment?",
            "What is the difference between a patent and a copyright?",
            "How does arbitration differ from litigation?",
            "What constitutes insider trading under securities law?",
            "What are the legal requirements for a valid will?",
            "How does diplomatic immunity work under international law?"
        ]
        # Create a dataframe with questions only
        return pd.DataFrame({
            'question': questions,
            'domain': ['legal'] * len(questions)
        })

    def get_samples(self, n: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        Get samples from the legal QA dataset.

        Args:
            n: Number of samples to return (None for all)

        Returns:
            List of samples with prompts
        """
        if self.data is None:
            self.load()

        questions = self.data['question'].tolist()
        if n is not None:
            questions = questions[:n]

        # Create samples with formatted prompts
        samples = []
        for question in questions:
            prompt = self.get_prompt_template().format(question=question)
            samples.append({
                'domain': 'legal',
                'question': question,
                'prompt': prompt
            })
        return samples

    def get_prompt_template(self) -> str:
        """
        Get the prompt template for the legal domain.

        Returns:
            Prompt template string
        """
        return "You are a legal expert. Please answer the following legal question accurately and concisely:\n\n{question}"


class GeneralKnowledgeDataset(DomainDataset):
    """Dataset for general knowledge question answering."""

    def __init__(self, data_path: Optional[str] = None):
        """
        Initialize the general knowledge dataset.

        Args:
            data_path: Path to the dataset file (None to use the default source)
        """
        super().__init__("general_knowledge", "general")
        self.data_path = data_path

    def load(self) -> None:
        """Load the general knowledge dataset."""
        if self.data_path and os.path.exists(self.data_path):
            # Load from a local file if one is available
            if self.data_path.endswith('.csv'):
                self.data = pd.read_csv(self.data_path)
            elif self.data_path.endswith('.json'):
                with open(self.data_path, 'r') as f:
                    # Normalize JSON records into a DataFrame so the
                    # column-based access in get_samples() works
                    self.data = pd.DataFrame(json.load(f))
            else:
                raise ValueError(f"Unsupported file format: {self.data_path}")
        else:
            # Use a sample of the TriviaQA dataset from Hugging Face
            try:
                dataset = load_dataset("trivia_qa", "unfiltered", split="train[:100]")
                self.data = dataset.to_pandas()
            except Exception as e:
                # Fall back to synthetic data if dataset loading fails
                print(f"Failed to load TriviaQA dataset: {e}")
                self.data = self._create_synthetic_data()

    def _create_synthetic_data(self) -> pd.DataFrame:
        """Create synthetic general knowledge data for testing."""
        questions = [
            "What is the capital of France?",
            "Who wrote the novel '1984'?",
            "What is the chemical symbol for gold?",
            "Which planet is known as the Red Planet?",
            "Who painted the Mona Lisa?",
            "What is the largest ocean on Earth?",
            "What year did World War II end?",
            "What is the tallest mountain in the world?",
            "Who was the first person to step on the moon?",
            "What is the speed of light in a vacuum?"
        ]
        # Create a dataframe with questions only
        return pd.DataFrame({
            'question': questions,
            'domain': ['general'] * len(questions)
        })

    def get_samples(self, n: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        Get samples from the general knowledge dataset.

        Args:
            n: Number of samples to return (None for all)

        Returns:
            List of samples with prompts
        """
        if self.data is None:
            self.load()

        if 'question' in self.data.columns:
            questions = self.data['question'].tolist()
        elif 'question_text' in self.data.columns:
            questions = self.data['question_text'].tolist()
        else:
            raise ValueError("Dataset does not contain a question column")

        if n is not None:
            questions = questions[:n]

        # Create samples with formatted prompts
        samples = []
        for question in questions:
            prompt = self.get_prompt_template().format(question=question)
            samples.append({
                'domain': 'general',
                'question': question,
                'prompt': prompt
            })
        return samples

    def get_prompt_template(self) -> str:
        """
        Get the prompt template for the general knowledge domain.

        Returns:
            Prompt template string
        """
        return "Please answer the following general knowledge question accurately and concisely:\n\n{question}"


# Factory function to create domain datasets
def create_domain_dataset(domain: str, data_path: Optional[str] = None) -> DomainDataset:
    """
    Create a domain dataset based on the specified domain.

    Args:
        domain: Domain category ('medical', 'legal', 'general')
        data_path: Path to the dataset file (None to use the default source)

    Returns:
        Domain dataset instance
    """
    if domain == "medical":
        return MedicalQADataset(data_path)
    elif domain == "legal":
        return LegalQADataset(data_path)
    elif domain == "general":
        return GeneralKnowledgeDataset(data_path)
    else:
        raise ValueError(f"Unsupported domain: {domain}")
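

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the documented API): assumes
# the module is run directly. The Hugging Face downloads for the medical and
# general domains may require network access; without it, the synthetic
# fallback questions are used instead.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    for demo_domain in ("medical", "legal", "general"):
        demo_dataset = create_domain_dataset(demo_domain)
        demo_samples = demo_dataset.get_samples(n=2)
        print(f"[{demo_domain}] loaded {len(demo_samples)} sample(s)")
        for demo_sample in demo_samples:
            print(f"  Q: {demo_sample['question']}")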