from typing import List, Tuple, TypedDict
from re import sub

import torch
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer, logging
from transformers import AutoModelForQuestionAnswering, DPRReaderTokenizer, DPRReader
from transformers import QuestionAnsweringPipeline
from transformers import AutoTokenizer, PegasusXForConditionalGeneration, PegasusTokenizerFast

cuda = torch.cuda.is_available()
max_answer_len = 8
logging.set_verbosity_error()

def summarize_text(tokenizer: PegasusTokenizerFast, model: PegasusXForConditionalGeneration,
                   input_texts: List[str]) -> List[str]:
    """Summarize a batch of texts with a Pegasus-X model."""
    inputs = tokenizer(input_texts, padding=True,
                       return_tensors='pt', truncation=True)
    if cuda:
        inputs = inputs.to(0)
        # Restrict scaled dot-product attention to the flash kernel on GPU.
        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
            summary_ids = model.generate(inputs["input_ids"])
    else:
        summary_ids = model.generate(inputs["input_ids"])
    summaries = tokenizer.batch_decode(summary_ids, skip_special_tokens=True,
                                       clean_up_tokenization_spaces=False)
    return summaries

def get_summarizer(model_id="seonglae/resrer-pegasus-x") -> Tuple[PegasusTokenizerFast, PegasusXForConditionalGeneration]:
    tokenizer = PegasusTokenizerFast.from_pretrained(model_id)
    model = PegasusXForConditionalGeneration.from_pretrained(model_id)
    if cuda:
        model = model.to(0)
        model = torch.compile(model)
    return tokenizer, model
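
# Usage sketch for the summarizer pair above (the input string is illustrative,
# not from the original module):
#   tokenizer, model = get_summarizer()
#   summaries = summarize_text(tokenizer, model, ["<long passage to compress>"])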

class AnswerInfo(TypedDict):
    score: float
    start: int
    end: int
    answer: str

def ask_reader(tokenizer: AutoTokenizer, model: AutoModelForQuestionAnswering,
               questions: List[str], ctxs: List[str]) -> List[AnswerInfo]:
    """Run extractive QA over question/context pairs and clean the answer strings."""
    if cuda:
        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
            pipeline = QuestionAnsweringPipeline(
                model=model, tokenizer=tokenizer, device='cuda', max_answer_len=max_answer_len)
            answer_infos: List[AnswerInfo] = pipeline(
                question=questions, context=ctxs)
    else:
        pipeline = QuestionAnsweringPipeline(
            model=model, tokenizer=tokenizer, device='cpu', max_answer_len=max_answer_len)
        answer_infos = pipeline(
            question=questions, context=ctxs)
    # The pipeline returns a single dict for a single pair; normalize to a list.
    if not isinstance(answer_infos, list):
        answer_infos = [answer_infos]
    # Strip punctuation that the span extractor sometimes includes in the answer.
    for answer_info in answer_infos:
        answer_info['answer'] = sub(r'[.\(\)"\',]', '', answer_info['answer'])
    return answer_infos

def get_reader(model_id="facebook/dpr-reader-single-nq-base"):
    tokenizer = DPRReaderTokenizer.from_pretrained(model_id)
    model = DPRReader.from_pretrained(model_id)
    if cuda:
        model = model.to(0)
    return tokenizer, model
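
# Usage sketch for the reader pair above (question/context are made-up examples):
#   tokenizer, model = get_reader()
#   infos = ask_reader(tokenizer, model, ["Who wrote Hamlet?"],
#                      ["Hamlet is a tragedy written by William Shakespeare."])
#   infos[0]['answer']  # e.g. 'William Shakespeare'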

def encode_dpr_question(tokenizer: DPRQuestionEncoderTokenizer, model: DPRQuestionEncoder,
                        questions: List[str]) -> torch.FloatTensor:
    """Encode questions with a DPR question encoder.
    https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DPRQuestionEncoder

    Args:
        tokenizer: tokenizer matching the encoder checkpoint
        model: a loaded DPRQuestionEncoder
        questions (List[str]): question strings to encode
    """
    batch_dict = tokenizer(questions, return_tensors="pt",
                           padding=True, truncation=True)
    if cuda:
        batch_dict = batch_dict.to(0)
        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
            embeddings: torch.FloatTensor = model(**batch_dict).pooler_output
    else:
        embeddings = model(**batch_dict).pooler_output
    return embeddings

def get_dpr_encoder(model_id="facebook/dpr-question_encoder-single-nq-base") -> Tuple[DPRQuestionEncoderTokenizer, DPRQuestionEncoder]:
    """Load a DPR question encoder and its tokenizer.
    https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DPRQuestionEncoder

    Args:
        model_id (str, optional): defaults to the NQ checkpoint; an alternative is
            "facebook/dpr-question_encoder-multiset-base".
    """
    tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(model_id)
    model = DPRQuestionEncoder.from_pretrained(model_id)
    if cuda:
        model = model.to(0)
    return tokenizer, model
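

if __name__ == "__main__":
    # Minimal end-to-end sketch, not part of the original module; the question
    # and context strings below are illustrative assumptions, and the checkpoints
    # are downloaded on first use.
    question = "Who wrote Hamlet?"
    context = "Hamlet is a tragedy written by William Shakespeare between 1599 and 1601."

    # 1. Embed the question for dense retrieval.
    enc_tokenizer, enc_model = get_dpr_encoder()
    embeddings = encode_dpr_question(enc_tokenizer, enc_model, [question])
    print(embeddings.shape)  # torch.Size([1, 768]) for the base DPR encoder

    # 2. Extract an answer span from the context.
    reader_tokenizer, reader_model = get_reader()
    infos = ask_reader(reader_tokenizer, reader_model, [question], [context])
    print(infos[0]['answer'])

    # 3. Summarize the context.
    sum_tokenizer, sum_model = get_summarizer()
    print(summarize_text(sum_tokenizer, sum_model, [context]))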