fix download wikipedia
specify number to embed, to skip
utils.py CHANGED
@@ -1,4 +1,5 @@
 import os
+import re
 import time
 import shutil
 from pathlib import Path
@@ -107,33 +108,48 @@ def load_hf_dataset(ds_name: str, ds_config: str = None, ds_split: str = "train"
     if ds_config == "":
         ds_config = None
 
-
-
+    if ds_name == "wikipedia":
+        pattern = re.compile(r"[^a-zA-Z0-9]")
+        folder = Path("/data") / pattern.sub("", ds_name+ds_config)
+        files = list(map(str, folder.glob("chunk_")))
+
+        return load_dataset("parquet", data_files=files, split="train")
+
+    ds = load_dataset(ds_name, ds_config, split=ds_split)
 
     return ds
 
-def download_wikipedia(ds_name, ds_config):
+def download_wikipedia(ds_name, ds_config, num2skip, num2embed):
     ds = load_dataset(ds_name, ds_config, streaming=True, split="train")
 
     def gen():
-
-
+        if num2embed > 0:
+
+            for example in ds.skip(num2skip).take(num2embed):
+                yield {"text": example["text"]}
+        else:
+            for example in ds.skip(num2skip):
+                yield {"text": example["text"]}
 
     ds2 = Dataset.from_generator(gen)
 
-    chunk_size =
+    chunk_size = 20_000
 
     filenames = []
 
-
+    pattern = re.compile(r"[^a-zA-Z0-9]")
+
+    folder = Path("/data") / pattern.sub("", ds_name+ds_config)
+
+    folder.mkdir(exist_ok=True, parents=True)
 
     for chunk_num, start_idx in enumerate(range(0, len(ds2), chunk_size)):
         end_idx = min(start_idx + chunk_size, len(ds2))
 
         temp = ds2.select(range(start_idx, end_idx))
 
-        temp.to_parquet(f"
-        filenames.append(f"
+        temp.to_parquet(str(folder / f"chunk_{chunk_num}"))
+        filenames.append(str(folder / f"chunk_{chunk_num}"))
 
     return load_dataset("parquet", data_files=filenames, split="train")
 
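The core of the fix is the windowed generator: on a streaming dataset, skip() and take() are lazy, so only the requested slice is ever downloaded, and Dataset.from_generator then materializes that slice. A minimal sketch of the pattern, using an illustrative wikipedia config and a hypothetical window:

from datasets import Dataset, load_dataset

# stream the dump instead of downloading it in full
ds = load_dataset("wikipedia", "20220301.en", streaming=True, split="train")

num2skip, num2embed = 1_000, 100  # hypothetical window


def gen():
    # skip()/take() are lazy on a streaming dataset, so only the
    # requested 100 examples are actually fetched
    for example in ds.skip(num2skip).take(num2embed):
        yield {"text": example["text"]}


subset = Dataset.from_generator(gen)  # materializes the window to Arrow
print(len(subset))  # 100

Materializing through Dataset.from_generator is what makes the later len(ds2) and select() chunking possible; a streaming dataset has no length.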
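One caveat for the cache probe added to load_hf_dataset: Path.glob applies no implicit wildcard, so folder.glob("chunk_") matches only a file named exactly chunk_, not the chunk_0, chunk_1, ... files the writer produces. The sketch below assumes the intended pattern is chunk_*; the dataset values are illustrative:

import re
from pathlib import Path

from datasets import load_dataset

ds_name, ds_config = "wikipedia", "20220301.en"  # illustrative

# same folder-name sanitization as the commit
pattern = re.compile(r"[^a-zA-Z0-9]")
folder = Path("/data") / pattern.sub("", ds_name + ds_config)

# "chunk_*" (note the wildcard) matches every cached parquet chunk;
# sorting keeps the file order deterministic across runs
files = sorted(str(p) for p in folder.glob("chunk_*"))
if files:
    ds = load_dataset("parquet", data_files=files, split="train")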