Spaces:
Runtime error
Runtime error
| from typing import List | |
| import random | |
| import argparse | |
| from datasets import load_dataset | |
| from datasets import Dataset | |
| from sonicverse.constants import ROLE_ASSISTANT, ROLE_USER | |
| from sonicverse.modalities.document_gte import ( | |
| split_text_into_documents, | |
| ) | |
# Placeholder inside each instruction template; it is later replaced by one
# "<document>" marker per document chunk.
TEMP_TOKEN = "<<<TEMP-TOKEN>>>"

# Leading wording of every user instruction. Each full template is the prefix,
# a single space, then TEMP_TOKEN.
_PRETRAIN_PREFIXES = (
    "Repeat the content of the document",
    "Transcribe",
    "Provide a verbatim transcription of",
    "Write down exactly what is in",
    "Copy the text from",
    "Duplicate the content of",
    "Reproduce the text in",
    "Render the exact text from",
    "Echo the content of",
    "Mirror the text in",
    "Reflect the content of",
    "Transcribe the exact words from",
    "Write out the exact content of",
    "Provide a direct transcription of",
    "Give a word-for-word account of",
    "Reiterate the exact text of",
    "Replicate the content of",
    "Reprint the text from",
    "Rewrite the exact words from",
)

# Instruction templates sampled for the user turn of each example.
PRETRAIN_PHRASES = [f"{prefix} {TEMP_TOKEN}" for prefix in _PRETRAIN_PREFIXES]
def _write_convo(row, max_document_chunks) -> dict:
    """Turn one dataset row into a document-transcription training example.

    The row's text is chunked with ``split_text_into_documents``. The user turn
    is a randomly chosen ``PRETRAIN_PHRASES`` template in which ``TEMP_TOKEN``
    is replaced by one ``<document>`` marker per chunk. The assistant turn is
    the full original text.

    Args:
        row: Mapping with at least ``"title"`` and ``"text"`` keys.
        max_document_chunks: Largest number of chunks the text may split into.

    Returns:
        A dict with ``"id"`` (the stringified title), ``"documents"`` (the
        chunk list) and ``"messages"`` (user turn followed by assistant turn).

    Raises:
        ValueError: If the text yields no chunks, or more than
            ``max_document_chunks`` chunks. ``main`` catches this to skip the row.
    """
    docs = split_text_into_documents(row["text"])
    if not docs:
        # With zero chunks the prompt would contain no <document> marker and the
        # target would be the (empty) text -- not a useful example.
        raise ValueError("Document is empty")
    if len(docs) > max_document_chunks:
        raise ValueError("Document too long")

    # Only sample a template once the row is known to be usable.
    phrase = random.choice(PRETRAIN_PHRASES)
    return {
        "id": str(row["title"]),
        "documents": docs,
        "messages": [
            {
                "role": ROLE_USER,
                # One marker per chunk, concatenated with no separator.
                "content": phrase.replace(TEMP_TOKEN, "<document>" * len(docs)),
            },
            {
                "role": ROLE_ASSISTANT,
                "content": row["text"],
            },
        ],
    }
def main(args):
    """Build the Wikipedia transcription dataset and save it to disk.

    Rows of the ``graelo/wikipedia`` (``20230601.en``) train split are visited
    in random order and converted with ``_write_convo``. Rows that
    ``_write_convo`` rejects with ``ValueError`` are skipped. Generation stops
    once ``args.max_examples`` examples have been produced.

    Args:
        args: Namespace with ``output_folder``, ``max_examples`` and
            ``max_document_chunks`` attributes.
    """
    wiki_data = load_dataset("graelo/wikipedia", "20230601.en")["train"]
    # Shuffle the visiting order so a max_examples-sized prefix is a random sample.
    idxs = list(range(len(wiki_data)))
    random.shuffle(idxs)

    def gen():
        produced = 0
        for idx in idxs:
            # Check the cap before yielding, so max_examples <= 0 yields nothing.
            if produced >= args.max_examples:
                return
            # Keep the try minimal: only _write_convo's rejection should be
            # swallowed, not exceptions thrown into the generator at the yield.
            try:
                example = _write_convo(wiki_data[idx], args.max_document_chunks)
            except ValueError:
                continue
            yield example
            produced += 1

    ds = Dataset.from_generator(gen)
    ds.save_to_disk(args.output_folder)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Build a Wikipedia document-transcription pretraining dataset."
    )
    # Required: without it, save_to_disk(None) would fail only after the whole
    # (long) generation pass had already run.
    parser.add_argument(
        "-o",
        "--output_folder",
        type=str,
        required=True,
        help="Directory the generated dataset is saved to.",
    )
    parser.add_argument(
        "-n",
        "--max_examples",
        type=int,
        default=1_000_000,
        help="Maximum number of examples to generate.",
    )
    parser.add_argument(
        "-c",
        "--max_document_chunks",
        type=int,
        default=4,
        help="Skip articles that split into more than this many chunks.",
    )
    args = parser.parse_args()
    main(args)