Update modeling_mcqbert.py
Browse files- modeling_mcqbert.py +3 -7
modeling_mcqbert.py
CHANGED
|
@@ -24,23 +24,19 @@ class MCQStudentBert(BertModel):
|
|
| 24 |
def forward(self, input_ids, student_embeddings=None):
|
| 25 |
if self.config.integration_strategy is None:
|
| 26 |
# don't consider embeddings is no integration strategy (MCQBert)
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
input_embeddings = self.embeddings(input_ids)
|
| 30 |
-
combined_embeddings = input_embeddings + self.student_embedding_layer(student_embeddings).unsqueeze(1).repeat(1, input_embeddings.size(1), 1)
|
| 31 |
-
output = super().forward(inputs_embeds = combined_embeddings)
|
| 32 |
return self.classifier(output.last_hidden_state[:, 0, :])
|
| 33 |
|
| 34 |
elif self.config.integration_strategy == "cat":
|
| 35 |
# MCQStudentBertCat
|
| 36 |
output = super().forward(input_ids)
|
| 37 |
-
output_with_student_embedding = torch.cat((output.last_hidden_state[:, 0, :], self.student_embedding_layer(student_embeddings)), dim = 1)
|
| 38 |
return self.classifier(output_with_student_embedding)
|
| 39 |
|
| 40 |
elif self.config.integration_strategy == "sum":
|
| 41 |
# MCQStudentBertSum
|
| 42 |
input_embeddings = self.embeddings(input_ids)
|
| 43 |
-
combined_embeddings = input_embeddings + self.student_embedding_layer(student_embeddings).
|
| 44 |
output = super().forward(inputs_embeds = combined_embeddings)
|
| 45 |
return self.classifier(output.last_hidden_state[:, 0, :])
|
| 46 |
|
|
|
|
| 24 |
def forward(self, input_ids, student_embeddings=None):
|
| 25 |
if self.config.integration_strategy is None:
|
| 26 |
# don't consider embeddings is no integration strategy (MCQBert)
|
| 27 |
+
output = super().forward(input_ids)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
return self.classifier(output.last_hidden_state[:, 0, :])
|
| 29 |
|
| 30 |
elif self.config.integration_strategy == "cat":
|
| 31 |
# MCQStudentBertCat
|
| 32 |
output = super().forward(input_ids)
|
| 33 |
+
output_with_student_embedding = torch.cat((output.last_hidden_state[:, 0, :], self.student_embedding_layer(student_embeddings).unsqueeze(0)), dim = 1)
|
| 34 |
return self.classifier(output_with_student_embedding)
|
| 35 |
|
| 36 |
elif self.config.integration_strategy == "sum":
|
| 37 |
# MCQStudentBertSum
|
| 38 |
input_embeddings = self.embeddings(input_ids)
|
| 39 |
+
combined_embeddings = input_embeddings + self.student_embedding_layer(student_embeddings).repeat(1, input_embeddings.size(1), 1)
|
| 40 |
output = super().forward(inputs_embeds = combined_embeddings)
|
| 41 |
return self.classifier(output.last_hidden_state[:, 0, :])
|
| 42 |
|