Spaces:
Runtime error
Runtime error
| from functools import partial | |
| import numpy as np | |
def decode(id_to_something, tokenizer=None, data_args=None):
    """Turn a mapping of ids to texts or token ids into stripped text.

    The first value of ``id_to_something`` decides how every value is
    handled:

    * ``str``: the string is stripped.
    * ``list`` whose first item is a ``str``: each string is stripped.
    * ``list`` whose first item is neither ``int`` nor ``str``: the value is
      treated as a batch of token-id sequences and goes through
      ``tokenizer.batch_decode``; the decoded strings are stripped.
    * anything else (for example a list of ints): the value is treated as a
      single token-id sequence and goes through ``tokenizer.decode``; the
      decoded string is stripped.

    Args:
        id_to_something: Mapping from an id to one of the value kinds above.
            All values are assumed to be of the same kind as the first one.
            If the first value is an empty list it cannot be classified and
            an ``IndexError`` is raised.
        tokenizer: Object providing ``decode``, ``batch_decode`` and
            ``pad_token_id``. Only used for token-id values.
        data_args: ``None``, a dict, or an object whose ``vars()`` is used as
            a dict. Only consulted for token-id values. Unless its
            ``"ignore_pad_token_for_loss"`` entry is falsy (it defaults to
            ``True``), ``-100`` entries are swapped for
            ``tokenizer.pad_token_id`` before decoding.

    Returns:
        A new dict with the same ids mapped to a stripped string or a list of
        stripped strings. An empty mapping yields an empty dict.
    """
    # Nothing to classify or decode; avoids StopIteration from next() below.
    if not id_to_something:
        return {}

    sample = next(iter(id_to_something.values()))

    # Already-decoded text only needs whitespace stripping.
    if isinstance(sample, str):
        return {id_: text.strip() for id_, text in id_to_something.items()}

    is_batch = isinstance(sample, list) and not isinstance(sample[0], int)
    if is_batch and isinstance(sample[0], str):
        return {
            id_: [text.strip() for text in texts]
            for id_, texts in id_to_something.items()
        }

    # From here on the values are token ids; data_args is only needed now.
    if data_args is None:
        data_args = {}
    elif not isinstance(data_args, dict):
        data_args = vars(data_args)
    replace_padding = data_args.get("ignore_pad_token_for_loss", True)
    decode_kwargs = {
        "skip_special_tokens": True,
        "clean_up_tokenization_spaces": True,
    }

    id_to_text = {}
    for id_, token_ids in id_to_something.items():
        if is_batch:
            if replace_padding:
                # -100 cannot be decoded. Build a new list so the caller's
                # data is left untouched.
                token_ids = [
                    _replace_padding(ids, tokenizer.pad_token_id)
                    for ids in token_ids
                ]
            decoded = tokenizer.batch_decode(token_ids, **decode_kwargs)
            id_to_text[id_] = [text.strip() for text in decoded]
        else:
            if replace_padding:
                # -100 cannot be decoded.
                token_ids = _replace_padding(token_ids, tokenizer.pad_token_id)
            id_to_text[id_] = tokenizer.decode(token_ids, **decode_kwargs).strip()
    return id_to_text
| def _replace_padding(token_ids: np.array, pad_token_id): | |
| return np.where(token_ids != -100, token_ids, pad_token_id) | |