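The class below is a small inference wrapper around an early-exit (PABEE-style) sequence classifier: it tokenizes raw sentence pairs with the GLUE helpers from `transformers`, configures the exit criterion (patience-based or confidence-based), and returns per-class probabilities or predicted labels.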
```python
import numpy as np
import torch
from transformers import glue_convert_examples_to_features as convert_examples_to_features
from transformers import InputExample


class MyClassifier:
    def __init__(self, model, tokenizer, label_list, output_mode,
                 exit_type, exit_value, model_type='albert', max_length=128):
        self.model = model
        self.model.eval()  # inference only
        self.model_type = model_type
        self.tokenizer = tokenizer
        self.label_list = label_list
        self.output_mode = output_mode
        self.max_length = max_length
        self.exit_type = exit_type    # 'patience' or 'confi'
        self.exit_value = exit_value  # patience count or confidence threshold
        self.count = 0                # running example counter, used for guids
        self.reset_status(mode='all', stats=True)
        if exit_type == 'patience':
            self.set_patience(patience=exit_value)
        elif exit_type == 'confi':
            self.set_threshold(confidence_threshold=exit_value)
    def tokenize(self, input_, idx):
        """Convert one (text_a, text_b) pair into model-ready tensors."""
        guid = f"dev_{idx}"
        text_a = input_[0]
        # "<none>" marks single-sentence inputs with no second segment
        text_b = None if input_[1] == "<none>" else input_[1]
        examples = [InputExample(guid=guid, text_a=text_a, text_b=text_b, label=None)]
        features = convert_examples_to_features(
            examples,
            self.tokenizer,
            label_list=self.label_list,
            max_length=self.max_length,
            output_mode=self.output_mode,
        )
        all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
        all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
        all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
        return all_input_ids, all_attention_mask, all_token_type_ids
    @property
    def backbone(self):
        # The underlying encoder that exposes the early-exit controls.
        return self.model.albert if self.model_type == 'albert' else self.model.bert

    def set_threshold(self, confidence_threshold):
        self.backbone.set_confi_threshold(confidence_threshold)

    def set_patience(self, patience):
        self.backbone.set_patience(patience)

    def set_exit_position(self, exit_pos):
        # Only the ALBERT backbone supports forcing a fixed exit layer.
        if self.model_type == 'albert':
            self.model.albert.set_exit_pos(exit_pos)

    def reset_status(self, mode, stats=False):
        self.backbone.set_mode(mode)
        if stats:
            self.backbone.reset_stats()

    def get_exit_number(self):
        return self.backbone.config.num_hidden_layers

    def get_current_exit(self):
        return self.backbone.current_exit_layer
    # TODO: adjust the prediction algorithm to produce the final predictions
    def get_pred(self, input_):
        # get_prob returns shape (num_inputs, num_labels); argmax over the
        # last axis yields one label index per input
        return self.get_prob(input_).argmax(axis=-1)
    def get_prob(self, input_):
        """Class probabilities for each input, exiting early according to
        the configured patience/confidence criterion."""
        self.reset_status(mode=self.exit_type, stats=False)  # enable early exit
        ret = []
        for sent in input_:
            self.count += 1
            batch = self.tokenize(sent, idx=self.count)
            inputs = {"input_ids": batch[0], "attention_mask": batch[1], "token_type_ids": batch[2]}
            with torch.no_grad():
                logits = self.model(**inputs)[0]  # logits from the exit layer
            ret.append(torch.softmax(logits, dim=1)[0].cpu().numpy())
        return np.array(ret)
    def get_prob_time(self, input_, exit_position):
        """Like get_prob, but force every input to exit at a fixed layer,
        e.g. to time individual exit positions."""
        self.reset_status(mode='exact', stats=False)  # exit at a fixed layer
        self.set_exit_position(exit_position)
        ret = []
        for sent in input_:
            self.count += 1
            batch = self.tokenize(sent, idx=self.count)
            inputs = {"input_ids": batch[0], "attention_mask": batch[1], "token_type_ids": batch[2]}
            with torch.no_grad():
                logits = self.model(**inputs)[0]
            ret.append(torch.softmax(logits, dim=1)[0].cpu().numpy())
        return np.array(ret)
```
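A minimal usage sketch, assuming an early-exit ALBERT checkpoint whose `albert` backbone exposes the `set_patience`/`set_confi_threshold`/`set_mode` hooks called above. The `AlbertWithEarlyExit` class, its module, the checkpoint path, and the SST-2-style task settings are illustrative assumptions, not part of the code above:

```python
from transformers import AlbertTokenizer

# Hypothetical model class: stands in for whatever early-exit ALBERT
# implementation provides the hooks used by MyClassifier above.
from modeling_albert_exit import AlbertWithEarlyExit  # placeholder import

checkpoint = "path/to/early-exit-albert-sst2"  # placeholder path
model = AlbertWithEarlyExit.from_pretrained(checkpoint)
tokenizer = AlbertTokenizer.from_pretrained(checkpoint)

clf = MyClassifier(
    model=model,
    tokenizer=tokenizer,
    label_list=["0", "1"],        # SST-2-style binary labels
    output_mode="classification",
    exit_type="patience",         # exit once enough consecutive layers agree
    exit_value=3,
)

sentences = [("a gripping , well-acted film", "<none>")]  # "<none>" = no second segment
probs = clf.get_prob(sentences)   # shape (1, 2): per-class probabilities
preds = clf.get_pred(sentences)   # predicted label indices
print(probs, preds, clf.get_current_exit())
```

With `exit_type='confi'` the wrapper instead calls `set_threshold`, so `exit_value` becomes the softmax confidence the model must reach before stopping early.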