import torch
import torch.nn as nn
from transformers import (
    AutoConfig,
    AutoModel,
    PreTrainedModel,
    T5ForConditionalGeneration,
)

class ReactionT5Yield(nn.Module):
    """Yield-regression head on top of a T5 encoder-decoder.

    The decoder is run for a single step from the decoder start token; its
    hidden state is concatenated with the encoder's first-token state and
    passed through a small MLP that regresses a scalar yield.
    """

    def __init__(self, cfg, config_path=None, pretrained=False):
        super().__init__()
        self.cfg = cfg
        # Load the backbone config from the hub, or from a locally saved file.
        if config_path is None:
            self.config = AutoConfig.from_pretrained(
                self.cfg.pretrained_model_name_or_path, output_hidden_states=True
            )
        else:
            self.config = torch.load(config_path, weights_only=False)
        # Either start from pretrained weights or initialize from the config.
        if pretrained:
            self.model = AutoModel.from_pretrained(
                self.cfg.pretrained_model_name_or_path
            )
        else:
            self.model = AutoModel.from_config(self.config)
        self.model.resize_token_embeddings(len(self.cfg.tokenizer))
        # Regression head: two parallel projections (decoder / encoder states)
        # whose concatenation is mapped down to a single scalar.
        self.fc_dropout1 = nn.Dropout(self.cfg.fc_dropout)
        self.fc1 = nn.Linear(self.config.hidden_size, self.config.hidden_size // 2)
        self.fc_dropout2 = nn.Dropout(self.cfg.fc_dropout)
        self.fc2 = nn.Linear(self.config.hidden_size, self.config.hidden_size // 2)
        self.fc3 = nn.Linear(self.config.hidden_size // 2 * 2, self.config.hidden_size)
        self.fc4 = nn.Linear(self.config.hidden_size, self.config.hidden_size)
        self.fc5 = nn.Linear(self.config.hidden_size, 1)
        self._init_weights(self.fc1)
        self._init_weights(self.fc2)
        self._init_weights(self.fc3)
        self._init_weights(self.fc4)
        self._init_weights(self.fc5)

    def _init_weights(self, module):
        # Small-std normal init for linear/embedding weights, zeroed biases.
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.01)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.01)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, inputs):
        # Encode the tokenized reaction, then run the decoder for a single
        # step starting from the decoder start token.
        encoder_outputs = self.model.encoder(**inputs)
        encoder_hidden_states = encoder_outputs[0]
        outputs = self.model.decoder(
            input_ids=torch.full(
                (inputs["input_ids"].size(0), 1),
                self.config.decoder_start_token_id,
                dtype=torch.long,
                device=inputs["input_ids"].device,
            ),
            encoder_hidden_states=encoder_hidden_states,
        )
        last_hidden_states = outputs[0]
        # Project the single decoder state and the encoder's first-token
        # state, then regress the yield from their concatenation.
        output1 = self.fc1(
            self.fc_dropout1(last_hidden_states).view(-1, self.config.hidden_size)
        )
        output2 = self.fc2(
            encoder_hidden_states[:, 0, :].view(-1, self.config.hidden_size)
        )
        output = self.fc3(self.fc_dropout2(torch.hstack((output1, output2))))
        output = self.fc4(output)
        output = self.fc5(output)
        return output

    def generate_embedding(self, inputs):
        # Same encoder/decoder pass as forward(), but return the concatenated
        # projected states as a fixed-size reaction embedding instead of the
        # final scalar.
        encoder_outputs = self.model.encoder(**inputs)
        encoder_hidden_states = encoder_outputs[0]
        outputs = self.model.decoder(
            input_ids=torch.full(
                (inputs["input_ids"].size(0), 1),
                self.config.decoder_start_token_id,
                dtype=torch.long,
                device=inputs["input_ids"].device,
            ),
            encoder_hidden_states=encoder_hidden_states,
        )
        last_hidden_states = outputs[0]
        output1 = self.fc1(
            self.fc_dropout1(last_hidden_states).view(-1, self.config.hidden_size)
        )
        output2 = self.fc2(
            encoder_hidden_states[:, 0, :].view(-1, self.config.hidden_size)
        )
        return torch.hstack((output1, output2))
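

# Illustrative usage sketch (not part of the original module): builds a cfg
# object exposing the attributes __init__ above reads and runs one inference
# pass. The "t5-base" backbone, the SimpleNamespace cfg, the dropout rate,
# and the reaction string are placeholder assumptions, not the project's
# actual training setup.
def _example_yield_inference():
    from types import SimpleNamespace

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("t5-base")  # assumed backbone
    cfg = SimpleNamespace(
        pretrained_model_name_or_path="t5-base",  # assumed backbone
        tokenizer=tokenizer,
        fc_dropout=0.1,  # assumed dropout rate
    )
    model = ReactionT5Yield(cfg, pretrained=True).eval()
    inputs = tokenizer("CCO.CC(=O)O>>CCOC(C)=O", return_tensors="pt")
    with torch.no_grad():
        pred = model(dict(inputs))  # shape (1, 1): one scalar yield per input
    return pred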


class ReactionT5Yield2(PreTrainedModel):
    """Hub-compatible variant of the yield model.

    Subclassing PreTrainedModel lets the head be saved and loaded with the
    standard save_pretrained/from_pretrained machinery. The prediction is
    rescaled to a 0-100 yield range in forward().
    """

    config_class = AutoConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.model = T5ForConditionalGeneration.from_pretrained(
            self.config._name_or_path
        )
        self.model.resize_token_embeddings(self.config.vocab_size)
        # Same MLP head as ReactionT5Yield, without the dropout layers.
        self.fc1 = nn.Linear(self.config.hidden_size, self.config.hidden_size // 2)
        self.fc2 = nn.Linear(self.config.hidden_size, self.config.hidden_size // 2)
        self.fc3 = nn.Linear(self.config.hidden_size // 2 * 2, self.config.hidden_size)
        self.fc4 = nn.Linear(self.config.hidden_size, self.config.hidden_size)
        self.fc5 = nn.Linear(self.config.hidden_size, 1)
        self._init_weights(self.fc1)
        self._init_weights(self.fc2)
        self._init_weights(self.fc3)
        self._init_weights(self.fc4)
        self._init_weights(self.fc5)

    def _init_weights(self, module):
        # Small-std normal init for linear/embedding weights, zeroed biases.
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.01)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.01)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, inputs):
        # Encode the tokenized reaction, then run the decoder for a single
        # step starting from the decoder start token.
        encoder_outputs = self.model.encoder(**inputs)
        encoder_hidden_states = encoder_outputs[0]
        outputs = self.model.decoder(
            input_ids=torch.full(
                (inputs["input_ids"].size(0), 1),
                self.config.decoder_start_token_id,
                dtype=torch.long,
                device=inputs["input_ids"].device,
            ),
            encoder_hidden_states=encoder_hidden_states,
        )
        last_hidden_states = outputs[0]
        output1 = self.fc1(last_hidden_states.view(-1, self.config.hidden_size))
        output2 = self.fc2(
            encoder_hidden_states[:, 0, :].view(-1, self.config.hidden_size)
        )
        output = self.fc3(torch.hstack((output1, output2)))
        output = self.fc4(output)
        output = self.fc5(output)
        # Rescale the regressed value to a 0-100 yield range.
        return output * 100

    def generate_embedding(self, inputs):
        # Same pass as forward(), but return the concatenated projected
        # states as a fixed-size reaction embedding instead of the scalar.
        encoder_outputs = self.model.encoder(**inputs)
        encoder_hidden_states = encoder_outputs[0]
        outputs = self.model.decoder(
            input_ids=torch.full(
                (inputs["input_ids"].size(0), 1),
                self.config.decoder_start_token_id,
                dtype=torch.long,
                device=inputs["input_ids"].device,
            ),
            encoder_hidden_states=encoder_hidden_states,
        )
        last_hidden_states = outputs[0]
        output1 = self.fc1(last_hidden_states.view(-1, self.config.hidden_size))
        output2 = self.fc2(
            encoder_hidden_states[:, 0, :].view(-1, self.config.hidden_size)
        )
        return torch.hstack((output1, output2))
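

# Illustrative usage sketch (not part of the original module): instantiates
# ReactionT5Yield2 directly from a backbone config. The "t5-base" name and
# the reaction string are placeholder assumptions; a saved checkpoint of this
# class could instead be loaded via ReactionT5Yield2.from_pretrained(...).
def _example_yield2_inference():
    from transformers import AutoTokenizer

    config = AutoConfig.from_pretrained("t5-base")  # assumed backbone
    tokenizer = AutoTokenizer.from_pretrained("t5-base")
    model = ReactionT5Yield2(config).eval()
    inputs = tokenizer("CCO.CC(=O)O>>CCOC(C)=O", return_tensors="pt")
    with torch.no_grad():
        pred = model(dict(inputs))  # already scaled to a 0-100 yield range
    return pred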