Spaces:
Runtime error
Runtime error
fix collator
Browse files
utils.py
CHANGED
|
@@ -2,6 +2,7 @@ import os
|
|
| 2 |
import time
|
| 3 |
import shutil
|
| 4 |
from pathlib import Path
|
|
|
|
| 5 |
from typing import Union, Dict, List
|
| 6 |
|
| 7 |
import torch
|
|
@@ -321,6 +322,7 @@ def batch_embed(
|
|
| 321 |
|
| 322 |
start_time = time.time()
|
| 323 |
|
|
|
|
| 324 |
for batch in DataLoader(
|
| 325 |
ds,
|
| 326 |
batch_size=inference_bs,
|
|
@@ -328,8 +330,8 @@ def batch_embed(
|
|
| 328 |
num_workers=1,
|
| 329 |
pin_memory=True,
|
| 330 |
drop_last=False,
|
|
|
|
| 331 |
):
|
| 332 |
-
batch = collate_fn(batch, device=device)
|
| 333 |
ids = batch["input_ids"]
|
| 334 |
mask = batch["attention_mask"]
|
| 335 |
|
|
|
|
| 2 |
import time
|
| 3 |
import shutil
|
| 4 |
from pathlib import Path
|
| 5 |
+
from functools import partial
|
| 6 |
from typing import Union, Dict, List
|
| 7 |
|
| 8 |
import torch
|
|
|
|
| 322 |
|
| 323 |
start_time = time.time()
|
| 324 |
|
| 325 |
+
|
| 326 |
for batch in DataLoader(
|
| 327 |
ds,
|
| 328 |
batch_size=inference_bs,
|
|
|
|
| 330 |
num_workers=1,
|
| 331 |
pin_memory=True,
|
| 332 |
drop_last=False,
|
| 333 |
+
collate_fn=partial(collate_fn, device=device)
|
| 334 |
):
|
|
|
|
| 335 |
ids = batch["input_ids"]
|
| 336 |
mask = batch["attention_mask"]
|
| 337 |
|