Update README.md
Browse files
README.md
CHANGED
|
@@ -2999,14 +2999,14 @@ document_tokens = tokenizer(documents, padding=True, truncation=True, return_te
|
|
| 2999 |
# Compute token embeddings
|
| 3000 |
with torch.no_grad():
|
| 3001 |
query_embeddings = model(**query_tokens)[0][:, 0]
|
| 3002 |
-
|
| 3003 |
|
| 3004 |
|
| 3005 |
# normalize embeddings
|
| 3006 |
query_embeddings = torch.nn.functional.normalize(query_embeddings, p=2, dim=1)
|
| 3007 |
-
|
| 3008 |
|
| 3009 |
-
scores = torch.mm(query_embeddings,
|
| 3010 |
for query, query_scores in zip(queries, scores):
|
| 3011 |
doc_score_pairs = list(zip(documents, query_scores))
|
| 3012 |
doc_score_pairs = sorted(doc_score_pairs, key=lambda x: x[1], reverse=True)
|
|
|
|
| 2999 |
# Compute token embeddings
|
| 3000 |
with torch.no_grad():
|
| 3001 |
query_embeddings = model(**query_tokens)[0][:, 0]
|
| 3002 |
+
document_embeddings = model(**document_tokens)[0][:, 0]
|
| 3003 |
|
| 3004 |
|
| 3005 |
# normalize embeddings
|
| 3006 |
query_embeddings = torch.nn.functional.normalize(query_embeddings, p=2, dim=1)
|
| 3007 |
+
document_embeddings = torch.nn.functional.normalize(document_embeddings, p=2, dim=1)
|
| 3008 |
|
| 3009 |
+
scores = torch.mm(query_embeddings, document_embeddings.transpose(0, 1))
|
| 3010 |
for query, query_scores in zip(queries, scores):
|
| 3011 |
doc_score_pairs = list(zip(documents, query_scores))
|
| 3012 |
doc_score_pairs = sorted(doc_score_pairs, key=lambda x: x[1], reverse=True)
|