Spaces:
Runtime error
Runtime error
fix batching v2
Browse files
utils.py
CHANGED
|
@@ -211,6 +211,17 @@ def tokenize(
|
|
| 211 |
)
|
| 212 |
|
| 213 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
@torch.inference_mode()
|
| 215 |
def batch_embed(
|
| 216 |
ds: datasets.IterableDataset,
|
|
@@ -308,18 +319,20 @@ def batch_embed(
|
|
| 308 |
ds,
|
| 309 |
batch_size=inference_bs,
|
| 310 |
shuffle=False,
|
| 311 |
-
num_workers=
|
| 312 |
pin_memory=True,
|
| 313 |
drop_last=False,
|
| 314 |
):
|
| 315 |
-
|
| 316 |
-
|
|
|
|
|
|
|
| 317 |
t_ids = torch.zeros_like(ids)
|
| 318 |
|
| 319 |
outputs = model(input_ids=ids, attention_mask=mask, token_type_ids=t_ids)
|
| 320 |
|
| 321 |
embeds.extend(mean_pooling(outputs[0], mask).cpu().tolist())
|
| 322 |
-
texts.extend(
|
| 323 |
|
| 324 |
current_count += ids.shape[0]
|
| 325 |
|
|
|
|
| 211 |
)
|
| 212 |
|
| 213 |
|
| 214 |
+
def collate_fn(examples, tokenizer=None, padding=None, device=None):
    """Collate a list of per-example dicts into a single batch dict.

    Values are grouped by key in example order. The ``"input_ids"`` and
    ``"attention_mask"`` columns are converted to ``torch.long`` tensors
    placed on ``device``; every other column is returned as a plain list.

    No padding is performed here, so the tensor columns must already have
    the same length in every example.

    Args:
        examples: Sequence of dicts, one per example. The keys of the first
            example define the columns of the batch.
        tokenizer: Unused; kept so existing call sites keep working.
        padding: Unused; kept so existing call sites keep working.
        device: Forwarded to ``torch.tensor`` for the tensor columns.

    Returns:
        A dict mapping each key to a tensor (for the two tensor columns) or
        to a list of the per-example values. An empty ``examples`` yields an
        empty dict.

    Raises:
        KeyError: If a later example contains a key the first one lacks.
    """
    # An empty batch has no first example to take the column names from.
    if not examples:
        return {}

    tensor_keys = ("input_ids", "attention_mask")

    batch = {k: [] for k in examples[0]}
    for example in examples:
        for k, v in example.items():
            batch[k].append(v)

    return {
        k: torch.tensor(v, dtype=torch.long, device=device) if k in tensor_keys else v
        for k, v in batch.items()
    }
|
| 224 |
+
|
| 225 |
@torch.inference_mode()
|
| 226 |
def batch_embed(
|
| 227 |
ds: datasets.IterableDataset,
|
|
|
|
| 319 |
ds,
|
| 320 |
batch_size=inference_bs,
|
| 321 |
shuffle=False,
|
| 322 |
+
num_workers=1,
|
| 323 |
pin_memory=True,
|
| 324 |
drop_last=False,
|
| 325 |
):
|
| 326 |
+
batch = collate_fn(batch, device=device)
|
| 327 |
+
ids = batch["input_ids"]
|
| 328 |
+
mask = batch["attention_mask"]
|
| 329 |
+
|
| 330 |
t_ids = torch.zeros_like(ids)
|
| 331 |
|
| 332 |
outputs = model(input_ids=ids, attention_mask=mask, token_type_ids=t_ids)
|
| 333 |
|
| 334 |
embeds.extend(mean_pooling(outputs[0], mask).cpu().tolist())
|
| 335 |
+
texts.extend([b[column_name] for b in batch])
|
| 336 |
|
| 337 |
current_count += ids.shape[0]
|
| 338 |
|