Spaces:
Build error
Build error
| # Copyright 2024 the LlamaFactory team. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import json | |
| import os | |
| from dataclasses import dataclass | |
| from typing import Any, Dict, List, Literal, Optional, Sequence | |
| from transformers.utils import cached_file | |
| from ..extras.constants import DATA_CONFIG | |
| from ..extras.misc import use_modelscope | |
@dataclass
class DatasetAttr:
    r"""
    Dataset attributes.

    Holds the loading source, the dataset identifier and the column / tag names of one dataset.
    The class is a dataclass so that the annotated fields below produce an ``__init__``
    accepting ``load_from`` and ``dataset_name`` (required) plus the optional settings.
    """

    # basic configs
    load_from: Literal["hf_hub", "ms_hub", "script", "file"]
    dataset_name: str
    formatting: Literal["alpaca", "sharegpt"] = "alpaca"
    ranking: bool = False
    # extra configs
    subset: Optional[str] = None
    split: str = "train"
    folder: Optional[str] = None
    num_samples: Optional[int] = None
    # common columns
    system: Optional[str] = None
    tools: Optional[str] = None
    images: Optional[str] = None
    # rlhf columns
    chosen: Optional[str] = None
    rejected: Optional[str] = None
    kto_tag: Optional[str] = None
    # alpaca columns
    prompt: Optional[str] = "instruction"
    query: Optional[str] = "input"
    response: Optional[str] = "output"
    history: Optional[str] = None
    # sharegpt columns
    messages: Optional[str] = "conversations"
    # sharegpt tags
    role_tag: Optional[str] = "from"
    content_tag: Optional[str] = "value"
    user_tag: Optional[str] = "human"
    assistant_tag: Optional[str] = "gpt"
    observation_tag: Optional[str] = "observation"
    function_tag: Optional[str] = "function_call"
    system_tag: Optional[str] = "system"

    def __repr__(self) -> str:
        r"""
        Returns the dataset name as the textual representation.
        """
        # Defined explicitly, so the dataclass machinery keeps it instead of generating its own.
        return self.dataset_name

    def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None:
        r"""
        Sets the attribute ``key`` to ``obj[key]``, or to ``default`` if ``key`` is absent from ``obj``.

        Note that a missing key always overwrites the current value with ``default``
        (``None`` unless given), it does not keep the previous value.
        """
        setattr(self, key, obj.get(key, default))
def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> List["DatasetAttr"]:
    r"""
    Gets the attributes of the datasets.

    Args:
        dataset_names: names of the datasets to resolve; ``None`` is treated as an empty list.
        dataset_dir: ``"ONLINE"`` to skip the config file, ``"REMOTE:<repo_id>"`` to obtain the
            config file through ``cached_file`` from a dataset repo, or otherwise a directory
            that contains the config file.

    Returns:
        One ``DatasetAttr`` per requested name, in the order the names were given.

    Raises:
        ValueError: if the config file cannot be read while at least one dataset was requested,
            or if a requested name is not defined in the config file.
    """
    if dataset_names is None:
        dataset_names = []

    if dataset_dir == "ONLINE":
        dataset_info = None
    else:
        if dataset_dir.startswith("REMOTE:"):
            repo_id = dataset_dir[len("REMOTE:") :]
            config_path = cached_file(path_or_repo_id=repo_id, filename=DATA_CONFIG, repo_type="dataset")
        else:
            config_path = os.path.join(dataset_dir, DATA_CONFIG)

        try:
            # JSON is UTF-8 text; do not rely on the platform's locale encoding.
            with open(config_path, "r", encoding="utf-8") as f:
                dataset_info = json.load(f)
        except Exception as err:
            # An unreadable config only matters when some dataset actually has to be looked up.
            if len(dataset_names) != 0:
                raise ValueError("Cannot open {} due to {}.".format(config_path, str(err))) from err

            dataset_info = None

    dataset_list: List["DatasetAttr"] = []
    for name in dataset_names:
        if dataset_info is None:  # no config available: treat the name as a hub dataset id
            load_from = "ms_hub" if use_modelscope() else "hf_hub"
            dataset_attr = DatasetAttr(load_from, dataset_name=name)
            dataset_list.append(dataset_attr)
            continue

        if name not in dataset_info:
            raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))

        info = dataset_info[name]
        has_hf_url = "hf_hub_url" in info
        has_ms_url = "ms_hub_url" in info

        if has_hf_url or has_ms_url:
            # Use the ModelScope url when it is requested and available, or when it is the only url.
            if (use_modelscope() and has_ms_url) or (not has_hf_url):
                dataset_attr = DatasetAttr("ms_hub", dataset_name=info["ms_hub_url"])
            else:
                dataset_attr = DatasetAttr("hf_hub", dataset_name=info["hf_hub_url"])
        elif "script_url" in info:
            dataset_attr = DatasetAttr("script", dataset_name=info["script_url"])
        else:
            dataset_attr = DatasetAttr("file", dataset_name=info["file_name"])

        dataset_attr.set_attr("formatting", info, default="alpaca")
        dataset_attr.set_attr("ranking", info, default=False)
        dataset_attr.set_attr("subset", info)
        dataset_attr.set_attr("split", info, default="train")
        dataset_attr.set_attr("folder", info)
        dataset_attr.set_attr("num_samples", info)

        if "columns" in info:
            column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"]
            if dataset_attr.formatting == "alpaca":
                column_names.extend(["prompt", "query", "response", "history"])
            else:
                column_names.extend(["messages"])

            # set_attr is called without a default, so a column missing from the mapping becomes None.
            for column_name in column_names:
                dataset_attr.set_attr(column_name, info["columns"])

        if dataset_attr.formatting == "sharegpt" and "tags" in info:
            tag_names = (
                "role_tag",
                "content_tag",
                "user_tag",
                "assistant_tag",
                "observation_tag",
                "function_tag",
                "system_tag",
            )
            for tag in tag_names:
                dataset_attr.set_attr(tag, info["tags"])

        dataset_list.append(dataset_attr)

    return dataset_list