"""
Export the static token-embedding tables and tokenizers from sentence-transformers and
model2vec models as .npy / .npy.zst files at several precisions and Matryoshka
dimensions, along with a README of quantization-loss stats for each model.
"""

from dataclasses import dataclass
from pathlib import Path
from textwrap import dedent, indent
import io
import shutil

import numpy as np
import torch
from model2vec import StaticModel
from sentence_transformers import SentenceTransformer
from tokenizers import Encoding, Tokenizer
from torch.nn import EmbeddingBag
from zstandard import ZstdCompressor

models_path = Path("models")


@dataclass
class ModelCard:
    """Metadata for a model whose static embeddings and tokenizer are exported."""

    owner: str
    repo: str
    matroyshka_dims: list[int]
    description: str
    license: str

    def name(self) -> str:
        return f"{self.owner}/{self.repo}"

    def path(self) -> Path:
        return models_path / self.owner / self.repo

    def get_description(self) -> str:
        return dedent(self.description).strip()


def zst_compress_file(input: Path) -> None:
    """Compress a file with zstandard, writing a sibling `<name>.zst` file."""
    cctx = ZstdCompressor()
    output = input.parent / f"{input.name}.zst"
    print(f"Compressing {output}")
    with open(input, "rb") as fin, open(output, "wb") as fout:
        cctx.copy_stream(fin, fout)


def save_data(path: Path, tensor: torch.Tensor) -> None:
    """Writes out the static embeddings to a .npy and .npy.zst file"""
    buffer = io.BytesIO()

    if tensor.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        # numpy has no fp8 dtype, so reinterpret the raw bytes as uint8 before saving.
        np.save(buffer, tensor.detach().view(torch.uint8).numpy())
    else:
        np.save(buffer, tensor.detach().numpy())

    print(f"Saving {path}")
    with open(path, "wb") as outfile:
        outfile.write(buffer.getvalue())

    zst_compress_file(path)
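

# The fp8 tensors are stored as a uint8 view, so a consumer has to reinterpret the
# bytes after loading. A minimal sketch of that round trip, assuming the .npy.zst
# layout produced by save_data/zst_compress_file above (this helper is illustrative
# and is not used by the exporter):
#
#     from zstandard import ZstdDecompressor
#
#     def load_fp8_npy_zst(path, dtype=torch.float8_e5m2):
#         buffer = io.BytesIO()
#         with open(path, "rb") as fin:
#             ZstdDecompressor().copy_stream(fin, buffer)
#         buffer.seek(0)
#         return torch.from_numpy(np.load(buffer)).view(dtype)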


def quantization_loss_mse(tensor: torch.Tensor, dtype: torch.dtype) -> float:
    """
    Compute the reconstruction loss when converting embeddings to a datatype and back,
    using the mean squared error, which punishes big errors more than small ones.
    """
    roundtrip = tensor.detach().to(dtype).to(tensor.dtype)
    return torch.mean((tensor - roundtrip) ** 2).item()


def quantization_loss_mae(tensor: torch.Tensor, dtype: torch.dtype) -> float:
    """
    Compute the reconstruction loss when converting embeddings to a datatype and back,
    using the mean absolute error, which is less sensitive to outliers than MSE.
    """
    roundtrip = tensor.detach().to(dtype).to(tensor.dtype)
    return torch.mean(torch.abs(tensor - roundtrip)).item()


def quantization_loss_cosine(tensor: torch.Tensor, dtype: torch.dtype) -> float:
    """
    Compute the reconstruction loss when converting embeddings to a datatype and back,
    using cosine similarity. This measures whether the embedding directions are
    preserved after quantization, independent of their magnitudes.
    """
    roundtrip = tensor.detach().to(dtype).to(tensor.dtype)

    if tensor.ndim == 1:
        orig = tensor.unsqueeze(0)
        recon = roundtrip.unsqueeze(0)
    else:
        orig = tensor.view(tensor.shape[0], -1)
        recon = roundtrip.view(roundtrip.shape[0], -1)

    cos = torch.nn.functional.cosine_similarity(orig, recon, dim=1)
    return cos.mean().item()
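

# Context on the two fp8 variants being compared (standard OCP FP8 formats): e5m2 keeps
# 5 exponent / 2 mantissa bits (wider range, less precision), while e4m3fn keeps
# 4 exponent / 3 mantissa bits (max magnitude ~448, more precision, no infinities).
# Illustrative usage of the metrics above, not executed by this script:
#
#     x = torch.randn(1_000, 256)
#     quantization_loss_cosine(x, torch.float16)
#     quantization_loss_mse(x, torch.float8_e5m2)
#     quantization_loss_mae(x, torch.float8_e4m3fn)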


def export_embeddings(model_card: ModelCard, embeddings: torch.Tensor) -> None:
    """Truncate the embeddings to each Matryoshka dimension and save at every precision."""
    vocab_size, dimensions = embeddings.shape

    assert (
        embeddings.dtype == torch.float32
    ), f"The embeddings {embeddings.dtype} are assumed to be float32."

    for dim in model_card.matroyshka_dims:
        assert (
            dim <= dimensions
        ), f"The Matryoshka dimension {dim} is bigger than the model's dimensions of {dimensions}"

        # Matryoshka-trained embeddings keep their most important information in the
        # leading dimensions, so truncation is a simple column slice.
        truncated = embeddings[:, :dim]
        assert truncated.shape == torch.Size([vocab_size, dim])

        save_data(model_card.path() / f"fp32.d{dim}.npy", truncated)
        save_data(
            model_card.path() / f"fp16.d{dim}.npy",
            truncated.to(dtype=torch.float16),
        )
        save_data(
            model_card.path() / f"fp8_e5m2.d{dim}.npy",
            truncated.to(dtype=torch.float8_e5m2),
        )
        save_data(
            model_card.path() / f"fp8_e4m3.d{dim}.npy",
            truncated.to(dtype=torch.float8_e4m3fn),
        )


def normalized_mean_pooling(x: torch.Tensor) -> torch.Tensor:
    """Mean pool token embeddings into one vector and L2-normalize it to unit length."""
    pooled = x.mean(dim=0)
    normalized = torch.nn.functional.normalize(pooled, dim=0)
    return normalized


def export_readme(
    model_card: ModelCard,
    embeddings: torch.Tensor,
    tokenizer: Tokenizer,
) -> None:
    vocab_size, dimensions = embeddings.shape
    norms = torch.norm(embeddings, dim=1)

    phrases = [
        "The committee approved the proposal after hours of heated discussion and several last-minute amendments.",
        "When training large neural networks, careful tuning of hyperparameters can significantly affect performance and stability.",
        "Despite the heavy rain, the concert continued as planned and the crowd stayed enthusiastic until the final encore.",
        "In ancient mythology, heroes often embarked on perilous journeys to discover hidden truths about themselves and their world.",
        "The new smartphone model features an improved camera system, faster processing, and extended battery life compared to its predecessor.",
        "He tried to explain the concept using simple analogies, but the underlying mathematics remained difficult to grasp for most listeners.",
        "After weeks of negotiations, the two countries signed a historic trade agreement aimed at reducing tariffs and boosting cooperation.",
        "She paused for a moment before answering, choosing her words carefully to avoid misunderstanding in such a delicate situation.",
        "The detective pieced together the timeline of events, realizing that the key witness had provided a contradictory statement.",
        "Remote work has changed the way teams collaborate, with online tools replacing traditional office routines and in-person meetings.",
    ]

    cosine_similarity = {
        torch.float16: [],
        torch.float8_e4m3fn: [],
        torch.float8_e5m2: [],
    }

    for phrase in phrases:
        encoding: Encoding = tokenizer.encode(phrase)
        embedded_phrase = embeddings[torch.tensor(encoding.ids, dtype=torch.long)]
        pooling_unquantized = normalized_mean_pooling(embedded_phrase)

        for dtype in cosine_similarity:
            # Round-trip through the quantized dtype, but pool in float32 space.
            pooling_roundtrip = normalized_mean_pooling(
                embedded_phrase.to(dtype).to(torch.float32)
            )
            cosine = torch.dot(pooling_unquantized, pooling_roundtrip).item()
            cosine_similarity[dtype].append(cosine)

    avg_cosine_similarity = {
        dtype: sum(values) / len(values) for dtype, values in cosine_similarity.items()
    }

    tokenizer_examples = []
    for text in [
        "This is an example of encoding",
        "The quick brown fox jumps over the lazy dog.",
        "Curaçao, naïve fiancé, jalapeño, déjà vu.",
        "Привет, как дела?",
        "Бързата кафява лисица прескача мързеливото куче.",
        "Γρήγορη καφέ αλεπού πηδάει πάνω από τον τεμπέλη σκύλο.",
        "اللغة العربية جميلة وغنية بالتاريخ.",
        "مرحبا بالعالم!",
        "Simplified: 快速的棕色狐狸跳过懒狗。",
        "Traditional: 快速的棕色狐狸跳過懶狗。",
        "素早い茶色の狐が怠け者の犬を飛び越える。",
        "コンピュータープログラミング",
        "빠른 갈색 여우가 게으른 개를 뛰어넘습니다.",
        "तेज़ भूरी लोमड़ी आलसी कुत्ते के ऊपर कूदती है।",
        "দ্রুত বাদামী শিয়াল অলস কুকুরের উপর দিয়ে লাফ দেয়।",
        "வேகமான பழுப்பு நரி சோம்பேறி நாயின் மேல் குதிக்கிறது.",
        "สุนัขจิ้งจอกสีน้ำตาลกระโดดข้ามสุนัขขี้เกียจ.",
        "ብሩክ ቡናማ ቀበሮ ሰነፍ ውሻን ተዘልሏል።",
        "Hello 世界 مرحبا 🌍",
        "123, αβγ, абв, العربية, 中文, हिन्दी.",
    ]:
        encoding = tokenizer.encode(text)
        tokens = [f"`{token}`" for token in encoding.tokens]

        tokenizer_examples.append(f"**Input:** {text}<br/>")
        tokenizer_examples.append(f"**Tokens:** {' '.join(tokens)}")
        tokenizer_examples.append("")

    tokenizer_output = "\n".join(tokenizer_examples)

    with (model_card.path() / "README.md").open("wt") as file:
        # The prefix must match the indentation of the f-string body below so that
        # multi-line substitutions survive the dedent().
        prefix = " " * 16

        file.write(
            dedent(
                f"""
                # [{model_card.name()}](https://huggingface.co/{model_card.name()})

                License: [{model_card.license}](https://choosealicense.com/licenses/{model_card.license}/)

                {indent(model_card.get_description(), prefix).strip()}

                ## Model Stats

                Stats that describe the embedding tensor's shape and value distribution.

                | item          | metric     | value |
                | ------------- | ---------- | ----- |
                | vocab         | size       | {vocab_size:,.0f} |
                | embedding     | dimensions | {dimensions:,.0f} |
                | vector length | mean       | {norms.mean().item():.2f} |
                | vector length | median     | {norms.median().item():.2f} |
                | vector length | stddev     | {norms.std().item():.2f} |
                | values        | mean       | {embeddings.mean().item():.2f} |
                | values        | median     | {embeddings.median().item():.2f} |
                | values        | stddev     | {embeddings.std().item():.2f} |

                ## Mean Pooled Quantization Loss

                This test round-trips the vectors through quantization, but performs the
                mean pooling arithmetic in float32 space. The quantized and unquantized
                mean pooled vectors are then compared via cosine similarity, which shows
                how much the meaning of a pooled vector changes due to quantization.

                | Precision | Cosine Similarity |
                | --------- | ----------------- |
                | fp16      | {avg_cosine_similarity[torch.float16]:.5f} |
                | fp8 e4m3  | {avg_cosine_similarity[torch.float8_e4m3fn]:.5f} |
                | fp8 e5m2  | {avg_cosine_similarity[torch.float8_e5m2]:.5f} |

                ## Quantization Loss Per Vector

                While the embedding vectors will ultimately be mean pooled together, it's
                still useful to look at the loss per vector in the embedding table to see
                which quantization strategies retain the most vector meaning.

                - **Cosine Similarity** — measures how well the *direction* of embedding vectors
                  is preserved after quantization, independent of scale. This is especially
                  relevant when embeddings are used for similarity search or retrieval.
                - **MSE (Mean Squared Error)** — emphasizes large errors by squaring the
                  differences. Useful for detecting whether any values are badly distorted.
                - **MAE (Mean Absolute Error)** — the average absolute difference between
                  original and quantized values. Easier to interpret, less sensitive to outliers.

                | Precision | Metric | Value |
                | --------- | ------ | ----- |
                | fp16      | cosine similarity | {quantization_loss_cosine(embeddings, torch.float16):.5f} |
                | fp8 e4m3  | cosine similarity | {quantization_loss_cosine(embeddings, torch.float8_e4m3fn):.5f} |
                | fp8 e5m2  | cosine similarity | {quantization_loss_cosine(embeddings, torch.float8_e5m2):.5f} |
                | fp16      | MSE | {quantization_loss_mse(embeddings, torch.float16):.5f} |
                | fp8 e4m3  | MSE | {quantization_loss_mse(embeddings, torch.float8_e4m3fn):.5f} |
                | fp8 e5m2  | MSE | {quantization_loss_mse(embeddings, torch.float8_e5m2):.5f} |
                | fp16      | MAE | {quantization_loss_mae(embeddings, torch.float16):.5f} |
                | fp8 e4m3  | MAE | {quantization_loss_mae(embeddings, torch.float8_e4m3fn):.5f} |
                | fp8 e5m2  | MAE | {quantization_loss_mae(embeddings, torch.float8_e5m2):.5f} |

                ## Tokenizer Examples

                {indent(tokenizer_output, prefix).strip()}
                """
            ).strip()
        )


def export_tokenizer(model_card: ModelCard, tokenizer: Tokenizer) -> None:
    """Save tokenizer.json and a zst-compressed copy of it."""
    tokenizer_path = model_card.path() / "tokenizer.json"
    print(f"Exporting tokenizer: {tokenizer_path}")
    tokenizer.save(str(tokenizer_path))
    zst_compress_file(tokenizer_path)


def export_sentence_transformers(model_card: ModelCard) -> None:
    """Extract the embeddings and tokenizer from SentenceTransformers"""
    print("Processing", model_card.name())

    model = SentenceTransformer(model_card.name(), device="cpu")
    # The static model's first module wraps a torch EmbeddingBag holding the table.
    embedding_bag: EmbeddingBag = model[0].embedding
    model_card.path().mkdir(exist_ok=True, parents=True)
    embeddings = embedding_bag.weight.detach()

    export_embeddings(model_card, embeddings)
    export_tokenizer(model_card, model.tokenizer)
    export_readme(model_card, embeddings, model.tokenizer)


def export_model2vec(model_card: ModelCard) -> None:
    """Extract the embeddings and tokenizer from model2vec"""
    print("Processing", model_card.name())

    model = StaticModel.from_pretrained(model_card.name())
    model_card.path().mkdir(exist_ok=True, parents=True)
    embeddings = torch.from_numpy(model.embedding)

    export_embeddings(model_card, embeddings)
    export_tokenizer(model_card, model.tokenizer)
    export_readme(model_card, embeddings, model.tokenizer)


def main() -> None:
    sentence_transformers_models = [
        ModelCard(
            owner="sentence-transformers",
            repo="static-similarity-mrl-multilingual-v1",
            description="""
            Multilingual similarity embeddings that were trained with a Matryoshka loss,
            which allows for more effective truncation of the embedding vectors. It was
            trained on multilingual datasets from a variety of domains.

            It's a general purpose model that can be used for semantic textual similarity,
            paraphrase mining, text classification, clustering, and more.
            """,
            matroyshka_dims=[32, 64, 128, 256, 512, 1024],
            license="apache-2.0",
        ),
        ModelCard(
            owner="sentence-transformers",
            repo="static-retrieval-mrl-en-v1",
            description="""
            English-only uncased similarity embeddings that were trained with a Matryoshka
            loss, which allows for more effective truncation of the embedding vectors. It
            was trained on monolingual datasets from a variety of domains, and was designed
            specifically for similarity retrieval.
            """,
            matroyshka_dims=[32, 64, 128, 256, 512, 1024],
            license="apache-2.0",
        ),
    ]

    model2vec_models = [
        ModelCard(
            owner="minishlab",
            repo="potion-multilingual-128M",
            matroyshka_dims=[32, 64, 128, 256],
            description="""
            A multilingual embedder. The details on how it was trained are a bit scant, as
            there is no source code for it. However, it likely has an architecture close
            to the potion-retrieval-32M model, but trained on Common Crawl data.

            The 128M refers to the number of parameters in the embeddings:

            256 dimensions * 500,353 vocab.
            """,
            license="mit",
        ),
        ModelCard(
            owner="minishlab",
            repo="potion-retrieval-32M",
            matroyshka_dims=[32, 64, 128, 256, 512],
            description="""
            The token embeddings from a monolingual English 32M parameter model that was
            distilled from embeddings initialized from the multi-domain
            [BAAI/bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5).

            The 32M refers to the number of parameters in the embeddings:

            512 dimensions * 63,091 vocab.
            """,
            license="mit",
        ),
    ]

    if models_path.exists():
        print(f"Removing the old models folder: {models_path}")
        shutil.rmtree(models_path)
    models_path.mkdir()

    for model_card in sentence_transformers_models:
        export_sentence_transformers(model_card)

    for model_card in model2vec_models:
        export_model2vec(model_card)


if __name__ == "__main__":
    main()
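

# Downstream consumption sketch (not executed here): the exported artifacts only need
# `tokenizers`, numpy, and torch to use. The paths and the d256 choice below are
# assumptions based on the file names written by export_embeddings/export_tokenizer
# after a prior run of this script:
#
#     tok = Tokenizer.from_file("models/minishlab/potion-retrieval-32M/tokenizer.json")
#     table = torch.from_numpy(
#         np.load("models/minishlab/potion-retrieval-32M/fp32.d256.npy")
#     )
#     ids = tok.encode("an example query").ids
#     query_vector = normalized_mean_pooling(table[torch.tensor(ids, dtype=torch.long)])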