| from torch.utils.data import Dataset | |
| import json | |
| default_get = lambda key : lambda data: data[key] | |
| class PromptDataset(Dataset): | |
| def __init__(self,data_dir,*keys,**projections) -> None: | |
| self.data = [] | |
| for d in data_dir: | |
| list_contents = {key:default_get(key)(d) for key in keys if key in d.keys()} | |
| dict_contents = {projection:projections[projection](d) for projection in projections.keys()} | |
| self.data.append({**list_contents,**dict_contents}) | |
| def __getitem__(self, index) -> dict: | |
| return self.data[index] | |
| def __len__(self): | |
| return len(self.data) | |
| class FileDataset(PromptDataset): | |
| def __init__(self,data_dir,*keys,**projections) -> None: | |
| with open(data_dir,'r',encoding='utf-8') as file: | |
| data_dir = json.load(file) | |
| if not keys: | |
| keys = data_dir[0].keys() | |
| self.data = [] | |
| for d in data_dir: | |
| list_contents = {key:default_get(key)(d) for key in keys if key in d.keys()} | |
| dict_contents = {projection:projections[projection](d) for projection in projections.keys()} | |
| self.data.append({**list_contents,**dict_contents}) | |
| def __getitem__(self, index) -> dict: | |
| return self.data[index] | |
| def __len__(self): | |
| return len(self.data) |