# NOTE(review): removed non-Python paste artifacts ("Spaces:", "Build error") that broke the file.
import re

import torch
def question_generation_sampling(
    g1_model,
    g1_tokenizer,
    g2_model,
    g2_tokenizer,
    context,
    num_questions,
    device,
):
    """Sample multiple-choice questions from ``context`` via a two-stage pipeline.

    Stage G.1 (``g1_model``) samples "question <sep> answer" strings; stage G.2
    (``g2_model``) samples distractor options for each valid pair.  Because
    some G.1 samples do not split into exactly one question and one answer,
    up to ``int(num_questions * 1.5)`` generations are attempted.

    Args:
        g1_model, g1_tokenizer: question+answer generation model and its
            tokenizer (assumed T5-style, with pad/eos/sep tokens — TODO confirm).
        g2_model, g2_tokenizer: distractor generation model and its tokenizer.
        context: source passage the questions are grounded in.
        num_questions: target number of valid questions.
        device: torch device the input ids are moved to.

    Returns:
        List of dicts ``{'question': str, 'options': list[str]}`` where
        ``options[0]`` is the generated answer.  May contain fewer than
        ``num_questions`` items if too many samples were invalid.
    """
    qa_input_ids = prepare_qa_input(
        g1_tokenizer,
        context=context,
        device=device,
    )
    # Oversample: sometimes the generated question+answer is invalid.
    max_repeated_sampling = int(num_questions * 1.5)
    # T5 sentinel tokens (e.g. "<extra_id_0>") act as separators in G.2 output.
    # Raw string: "\S" in a plain literal is an invalid escape sequence.
    sentinel_pattern = re.compile(r"<extra\S+>")
    num_valid_questions = 0
    questions = []
    for _ in range(max_repeated_sampling):
        # Stage G.1: question + answer generation.
        outputs = g1_model.generate(
            qa_input_ids,
            max_new_tokens=128,
            do_sample=True,
        )
        question_answer = g1_tokenizer.decode(outputs[0], skip_special_tokens=False)
        question_answer = question_answer.replace(g1_tokenizer.pad_token, "").replace(g1_tokenizer.eos_token, "")
        question_answer_split = question_answer.split(g1_tokenizer.sep_token)
        if len(question_answer_split) != 2:
            # Invalid sample: not exactly "question <sep> answer" — retry.
            continue
        # Valid question + answer output.
        num_valid_questions += 1
        question = question_answer_split[0].strip()
        answer = question_answer_split[1].strip()
        # Stage G.2: distractor generation.
        distractor_input_ids = prepare_distractor_input(
            g2_tokenizer,
            context=context,
            question=question,
            answer=answer,
            device=device,
            separator=g2_tokenizer.sep_token,
        )
        outputs = g2_model.generate(
            distractor_input_ids,
            max_new_tokens=128,
            do_sample=True,
        )
        distractors = g2_tokenizer.decode(outputs[0], skip_special_tokens=False)
        distractors = distractors.replace(g2_tokenizer.pad_token, "").replace(g2_tokenizer.eos_token, "")
        # Normalize sentinel tokens to the separator, then split into options.
        distractors = sentinel_pattern.sub(g2_tokenizer.sep_token, distractors)
        distractors = [y.strip() for y in distractors.split(g2_tokenizer.sep_token)]
        options = [answer] + distractors
        # Pad to at least 4 options by repeating the last one.
        while len(options) < 4:
            options.append(options[-1])
        question_item = {
            'question': question,
            'options': options,
        }
        questions.append(question_item)
        if num_valid_questions == num_questions:
            break
    return questions
def prepare_qa_input(t5_tokenizer, context, device):
    """Tokenize *context* for stage G.1 (question+answer generation).

    input: context
    output: question <sep> answer
    """
    encoded = t5_tokenizer([context], return_tensors="pt")
    return encoded.input_ids.to(device)
def prepare_distractor_input(t5_tokenizer, context, question, answer, device, separator='<sep>'):
    """Tokenize the stage-G.2 prompt for distractor generation.

    input: question <sep> answer <sep> article
    output: distractor1 <sep> distractor2 <sep> distractor3
    """
    prompt = f"{question} {separator} {answer} {separator} {context}"
    encoded = t5_tokenizer([prompt], return_tensors="pt")
    return encoded.input_ids.to(device)