Upload 3 files
Browse files- convert.py +66 -0
 - model.py +203 -0
 - test.py +143 -0
 
    	
        convert.py
    ADDED
    
    | 
         @@ -0,0 +1,66 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            import argparse
         
     | 
| 2 | 
         
            +
            import os
         
     | 
| 3 | 
         
            +
            from pathlib import Path
         
     | 
| 4 | 
         
            +
             
     | 
| 5 | 
         
            +
            import numpy
         
     | 
| 6 | 
         
            +
            from transformers import AutoModel, AutoConfig
         
     | 
| 7 | 
         
            +
             
     | 
| 8 | 
         
            +
             
     | 
| 9 | 
         
            +
            def replace_key(key: str) -> str:
         
     | 
| 10 | 
         
            +
                key = key.replace(".layer.", ".layers.")
         
     | 
| 11 | 
         
            +
                key = key.replace(".self.key.", ".key_proj.")
         
     | 
| 12 | 
         
            +
                key = key.replace(".self.query.", ".query_proj.")
         
     | 
| 13 | 
         
            +
                key = key.replace(".self.value.", ".value_proj.")
         
     | 
| 14 | 
         
            +
                key = key.replace(".attention.output.dense.", ".attention.out_proj.")
         
     | 
| 15 | 
         
            +
                key = key.replace(".attention.output.LayerNorm.", ".ln1.")
         
     | 
| 16 | 
         
            +
                key = key.replace(".output.LayerNorm.", ".ln2.")
         
     | 
| 17 | 
         
            +
                key = key.replace(".intermediate.dense.", ".linear1.")
         
     | 
| 18 | 
         
            +
                key = key.replace(".output.dense.", ".linear2.")
         
     | 
| 19 | 
         
            +
                key = key.replace(".LayerNorm.", ".norm.")
         
     | 
| 20 | 
         
            +
                key = key.replace("pooler.dense.", "pooler.")
         
     | 
| 21 | 
         
            +
                return key
         
     | 
| 22 | 
         
            +
             
     | 
| 23 | 
         
            +
             
     | 
| 24 | 
         
            +
            def convert(bert_model: str, mlx_model: str) -> None:
         
     | 
| 25 | 
         
            +
                # Load model and its configuration
         
     | 
| 26 | 
         
            +
                model = AutoModel.from_pretrained(bert_model)
         
     | 
| 27 | 
         
            +
                config = AutoConfig.from_pretrained(bert_model)
         
     | 
| 28 | 
         
            +
                
         
     | 
| 29 | 
         
            +
                # Create output directory if it doesn't exist
         
     | 
| 30 | 
         
            +
                output_dir = os.path.dirname(mlx_model)
         
     | 
| 31 | 
         
            +
                if output_dir and not os.path.exists(output_dir):
         
     | 
| 32 | 
         
            +
                    os.makedirs(output_dir)
         
     | 
| 33 | 
         
            +
                
         
     | 
| 34 | 
         
            +
                # Save config as well
         
     | 
| 35 | 
         
            +
                config_path = os.path.join(output_dir, "config.json")
         
     | 
| 36 | 
         
            +
                with open(config_path, "w") as f:
         
     | 
| 37 | 
         
            +
                    f.write(config.to_json_string())
         
     | 
| 38 | 
         
            +
                    
         
     | 
| 39 | 
         
            +
                print(f"Saved model config to {config_path}")
         
     | 
| 40 | 
         
            +
                
         
     | 
| 41 | 
         
            +
                # Save the tensors
         
     | 
| 42 | 
         
            +
                tensors = {
         
     | 
| 43 | 
         
            +
                    replace_key(key): tensor.numpy() for key, tensor in model.state_dict().items()
         
     | 
| 44 | 
         
            +
                }
         
     | 
| 45 | 
         
            +
                numpy.savez(mlx_model, **tensors)
         
     | 
| 46 | 
         
            +
                print(f"Saved model weights to {mlx_model}")
         
     | 
| 47 | 
         
            +
                print(f"Model vocab size: {config.vocab_size}, hidden size: {config.hidden_size}")
         
     | 
| 48 | 
         
            +
             
     | 
| 49 | 
         
            +
             
     | 
| 50 | 
         
            +
            if __name__ == "__main__":
         
     | 
| 51 | 
         
            +
                parser = argparse.ArgumentParser(description="Convert BERT weights to MLX.")
         
     | 
| 52 | 
         
            +
                parser.add_argument(
         
     | 
| 53 | 
         
            +
                    "--bert-model",
         
     | 
| 54 | 
         
            +
                    type=str,
         
     | 
| 55 | 
         
            +
                    default="bert-base-uncased",
         
     | 
| 56 | 
         
            +
                    help="The huggingface name of the BERT model to save. Any BERT-like model can be specified.",
         
     | 
| 57 | 
         
            +
                )
         
     | 
| 58 | 
         
            +
                parser.add_argument(
         
     | 
| 59 | 
         
            +
                    "--mlx-model",
         
     | 
| 60 | 
         
            +
                    type=str,
         
     | 
| 61 | 
         
            +
                    default="weights/bert-base-uncased.npz",
         
     | 
| 62 | 
         
            +
                    help="The output path for the MLX BERT weights.",
         
     | 
| 63 | 
         
            +
                )
         
     | 
| 64 | 
         
            +
                args = parser.parse_args()
         
     | 
| 65 | 
         
            +
             
     | 
| 66 | 
         
            +
                convert(args.bert_model, args.mlx_model)
         
     | 
    	
        model.py
    ADDED
    
    | 
         @@ -0,0 +1,203 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            import argparse
         
     | 
| 2 | 
         
            +
            import json
         
     | 
| 3 | 
         
            +
            from dataclasses import dataclass
         
     | 
| 4 | 
         
            +
            from pathlib import Path
         
     | 
| 5 | 
         
            +
            from typing import List, Optional, Tuple
         
     | 
| 6 | 
         
            +
             
     | 
| 7 | 
         
            +
            import mlx.core as mx
         
     | 
| 8 | 
         
            +
            import mlx.nn as nn
         
     | 
| 9 | 
         
            +
            import numpy
         
     | 
| 10 | 
         
            +
            import numpy as np
         
     | 
| 11 | 
         
            +
            from mlx.utils import tree_unflatten
         
     | 
| 12 | 
         
            +
            from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase
         
     | 
| 13 | 
         
            +
             
     | 
| 14 | 
         
            +
             
     | 
| 15 | 
         
            +
            class TransformerEncoderLayer(nn.Module):
         
     | 
| 16 | 
         
            +
                """
         
     | 
| 17 | 
         
            +
                A transformer encoder layer with (the original BERT) post-normalization.
         
     | 
| 18 | 
         
            +
                """
         
     | 
| 19 | 
         
            +
             
     | 
| 20 | 
         
            +
                def __init__(
         
     | 
| 21 | 
         
            +
                    self,
         
     | 
| 22 | 
         
            +
                    dims: int,
         
     | 
| 23 | 
         
            +
                    num_heads: int,
         
     | 
| 24 | 
         
            +
                    mlp_dims: Optional[int] = None,
         
     | 
| 25 | 
         
            +
                    layer_norm_eps: float = 1e-12,
         
     | 
| 26 | 
         
            +
                ):
         
     | 
| 27 | 
         
            +
                    super().__init__()
         
     | 
| 28 | 
         
            +
                    mlp_dims = mlp_dims or dims * 4
         
     | 
| 29 | 
         
            +
                    self.attention = nn.MultiHeadAttention(dims, num_heads, bias=True)
         
     | 
| 30 | 
         
            +
                    self.ln1 = nn.LayerNorm(dims, eps=layer_norm_eps)
         
     | 
| 31 | 
         
            +
                    self.ln2 = nn.LayerNorm(dims, eps=layer_norm_eps)
         
     | 
| 32 | 
         
            +
                    self.linear1 = nn.Linear(dims, mlp_dims)
         
     | 
| 33 | 
         
            +
                    self.linear2 = nn.Linear(mlp_dims, dims)
         
     | 
| 34 | 
         
            +
                    self.gelu = nn.GELU()
         
     | 
| 35 | 
         
            +
             
     | 
| 36 | 
         
            +
                def __call__(self, x, mask):
         
     | 
| 37 | 
         
            +
                    attention_out = self.attention(x, x, x, mask)
         
     | 
| 38 | 
         
            +
                    add_and_norm = self.ln1(x + attention_out)
         
     | 
| 39 | 
         
            +
             
     | 
| 40 | 
         
            +
                    ff = self.linear1(add_and_norm)
         
     | 
| 41 | 
         
            +
                    ff_gelu = self.gelu(ff)
         
     | 
| 42 | 
         
            +
                    ff_out = self.linear2(ff_gelu)
         
     | 
| 43 | 
         
            +
                    x = self.ln2(ff_out + add_and_norm)
         
     | 
| 44 | 
         
            +
             
     | 
| 45 | 
         
            +
                    return x
         
     | 
| 46 | 
         
            +
             
     | 
| 47 | 
         
            +
             
     | 
| 48 | 
         
            +
            class TransformerEncoder(nn.Module):
         
     | 
| 49 | 
         
            +
                def __init__(
         
     | 
| 50 | 
         
            +
                    self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
         
     | 
| 51 | 
         
            +
                ):
         
     | 
| 52 | 
         
            +
                    super().__init__()
         
     | 
| 53 | 
         
            +
                    self.layers = [
         
     | 
| 54 | 
         
            +
                        TransformerEncoderLayer(dims, num_heads, mlp_dims)
         
     | 
| 55 | 
         
            +
                        for i in range(num_layers)
         
     | 
| 56 | 
         
            +
                    ]
         
     | 
| 57 | 
         
            +
             
     | 
| 58 | 
         
            +
                def __call__(self, x, mask):
         
     | 
| 59 | 
         
            +
                    for layer in self.layers:
         
     | 
| 60 | 
         
            +
                        x = layer(x, mask)
         
     | 
| 61 | 
         
            +
             
     | 
| 62 | 
         
            +
                    return x
         
     | 
| 63 | 
         
            +
             
     | 
| 64 | 
         
            +
             
     | 
| 65 | 
         
            +
            class BertEmbeddings(nn.Module):
         
     | 
| 66 | 
         
            +
                def __init__(self, config):
         
     | 
| 67 | 
         
            +
                    super().__init__()
         
     | 
| 68 | 
         
            +
                    self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
         
     | 
| 69 | 
         
            +
                    self.token_type_embeddings = nn.Embedding(
         
     | 
| 70 | 
         
            +
                        config.type_vocab_size, config.hidden_size
         
     | 
| 71 | 
         
            +
                    )
         
     | 
| 72 | 
         
            +
                    self.position_embeddings = nn.Embedding(
         
     | 
| 73 | 
         
            +
                        config.max_position_embeddings, config.hidden_size
         
     | 
| 74 | 
         
            +
                    )
         
     | 
| 75 | 
         
            +
                    self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         
     | 
| 76 | 
         
            +
             
     | 
| 77 | 
         
            +
                def __call__(
         
     | 
| 78 | 
         
            +
                    self, input_ids: mx.array, token_type_ids: mx.array = None
         
     | 
| 79 | 
         
            +
                ) -> mx.array:
         
     | 
| 80 | 
         
            +
                    words = self.word_embeddings(input_ids)
         
     | 
| 81 | 
         
            +
                    position = self.position_embeddings(
         
     | 
| 82 | 
         
            +
                        mx.broadcast_to(mx.arange(input_ids.shape[1]), input_ids.shape)
         
     | 
| 83 | 
         
            +
                    )
         
     | 
| 84 | 
         
            +
             
     | 
| 85 | 
         
            +
                    if token_type_ids is None:
         
     | 
| 86 | 
         
            +
                        # If token_type_ids is not provided, default to zeros
         
     | 
| 87 | 
         
            +
                        token_type_ids = mx.zeros_like(input_ids)
         
     | 
| 88 | 
         
            +
             
     | 
| 89 | 
         
            +
                    token_types = self.token_type_embeddings(token_type_ids)
         
     | 
| 90 | 
         
            +
             
     | 
| 91 | 
         
            +
                    embeddings = position + words + token_types
         
     | 
| 92 | 
         
            +
                    return self.norm(embeddings)
         
     | 
| 93 | 
         
            +
             
     | 
| 94 | 
         
            +
             
     | 
| 95 | 
         
            +
            class Bert(nn.Module):
         
     | 
| 96 | 
         
            +
                def __init__(self, config):
         
     | 
| 97 | 
         
            +
                    super().__init__()
         
     | 
| 98 | 
         
            +
                    self.embeddings = BertEmbeddings(config)
         
     | 
| 99 | 
         
            +
                    self.encoder = TransformerEncoder(
         
     | 
| 100 | 
         
            +
                        num_layers=config.num_hidden_layers,
         
     | 
| 101 | 
         
            +
                        dims=config.hidden_size,
         
     | 
| 102 | 
         
            +
                        num_heads=config.num_attention_heads,
         
     | 
| 103 | 
         
            +
                        mlp_dims=config.intermediate_size,
         
     | 
| 104 | 
         
            +
                    )
         
     | 
| 105 | 
         
            +
                    self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
         
     | 
| 106 | 
         
            +
             
     | 
| 107 | 
         
            +
                def __call__(
         
     | 
| 108 | 
         
            +
                    self,
         
     | 
| 109 | 
         
            +
                    input_ids: mx.array,
         
     | 
| 110 | 
         
            +
                    token_type_ids: mx.array = None,
         
     | 
| 111 | 
         
            +
                    attention_mask: mx.array = None,
         
     | 
| 112 | 
         
            +
                ) -> Tuple[mx.array, mx.array]:
         
     | 
| 113 | 
         
            +
                    x = self.embeddings(input_ids, token_type_ids)
         
     | 
| 114 | 
         
            +
             
     | 
| 115 | 
         
            +
                    if attention_mask is not None:
         
     | 
| 116 | 
         
            +
                        # convert 0's to -infs, 1's to 0's, and make it broadcastable
         
     | 
| 117 | 
         
            +
                        attention_mask = mx.log(attention_mask)
         
     | 
| 118 | 
         
            +
                        attention_mask = mx.expand_dims(attention_mask, (1, 2))
         
     | 
| 119 | 
         
            +
             
     | 
| 120 | 
         
            +
                    y = self.encoder(x, attention_mask)
         
     | 
| 121 | 
         
            +
                    return y, mx.tanh(self.pooler(y[:, 0]))
         
     | 
| 122 | 
         
            +
             
     | 
| 123 | 
         
            +
             
     | 
| 124 | 
         
            +
            def load_model(
         
     | 
| 125 | 
         
            +
                bert_model: str, weights_path: str
         
     | 
| 126 | 
         
            +
            ) -> Tuple[Bert, PreTrainedTokenizerBase]:
         
     | 
| 127 | 
         
            +
                if not Path(weights_path).exists():
         
     | 
| 128 | 
         
            +
                    raise ValueError(f"No model weights found in {weights_path}")
         
     | 
| 129 | 
         
            +
                
         
     | 
| 130 | 
         
            +
                # First check if there's a local config
         
     | 
| 131 | 
         
            +
                config_path = Path(weights_path).parent / "config.json"
         
     | 
| 132 | 
         
            +
                if config_path.exists():
         
     | 
| 133 | 
         
            +
                    with open(config_path, "r") as f:
         
     | 
| 134 | 
         
            +
                        config_dict = json.load(f)
         
     | 
| 135 | 
         
            +
                    config = AutoConfig.for_model(**config_dict)
         
     | 
| 136 | 
         
            +
                    print(f"Loaded local config from {config_path}")
         
     | 
| 137 | 
         
            +
                else:
         
     | 
| 138 | 
         
            +
                    # If no local config, use the HuggingFace one
         
     | 
| 139 | 
         
            +
                    config = AutoConfig.from_pretrained(bert_model)
         
     | 
| 140 | 
         
            +
                    print(f"Loaded config from HuggingFace for {bert_model}")
         
     | 
| 141 | 
         
            +
             
     | 
| 142 | 
         
            +
                # Create and update the model
         
     | 
| 143 | 
         
            +
                print(f"Creating model with vocab_size={config.vocab_size}, hidden_size={config.hidden_size}")
         
     | 
| 144 | 
         
            +
                model = Bert(config)
         
     | 
| 145 | 
         
            +
                model.load_weights(weights_path)
         
     | 
| 146 | 
         
            +
             
     | 
| 147 | 
         
            +
                tokenizer = AutoTokenizer.from_pretrained(bert_model)
         
     | 
| 148 | 
         
            +
             
     | 
| 149 | 
         
            +
                return model, tokenizer
         
     | 
| 150 | 
         
            +
             
     | 
| 151 | 
         
            +
             
     | 
| 152 | 
         
            +
            def run(bert_model: str, mlx_model: str, batch: List[str]):
         
     | 
| 153 | 
         
            +
                import time
         
     | 
| 154 | 
         
            +
                
         
     | 
| 155 | 
         
            +
                # Time model loading
         
     | 
| 156 | 
         
            +
                load_start = time.time()
         
     | 
| 157 | 
         
            +
                model, tokenizer = load_model(bert_model, mlx_model)
         
     | 
| 158 | 
         
            +
                load_time = time.time() - load_start
         
     | 
| 159 | 
         
            +
                print(f"[MLX] Model loaded in {load_time:.2f} seconds")
         
     | 
| 160 | 
         
            +
                
         
     | 
| 161 | 
         
            +
                # Time tokenization
         
     | 
| 162 | 
         
            +
                print(f"[MLX] Tokenizing batch of {len(batch)} sentences")
         
     | 
| 163 | 
         
            +
                token_start = time.time()
         
     | 
| 164 | 
         
            +
                tokens = tokenizer(batch, return_tensors="np", padding=True)
         
     | 
| 165 | 
         
            +
                token_time = time.time() - token_start
         
     | 
| 166 | 
         
            +
                print(f"[MLX] Tokenization completed in {token_time:.4f} seconds")
         
     | 
| 167 | 
         
            +
                
         
     | 
| 168 | 
         
            +
                print(f"[MLX] Tokens shape: input_ids={tokens['input_ids'].shape}")
         
     | 
| 169 | 
         
            +
                tokens = {key: mx.array(v) for key, v in tokens.items()}
         
     | 
| 170 | 
         
            +
                
         
     | 
| 171 | 
         
            +
                # Time inference
         
     | 
| 172 | 
         
            +
                print(f"[MLX] Running model inference")
         
     | 
| 173 | 
         
            +
                infer_start = time.time()
         
     | 
| 174 | 
         
            +
                output, pooled = model(**tokens)
         
     | 
| 175 | 
         
            +
                mx.eval(output, pooled)  # Force evaluation of lazy arrays
         
     | 
| 176 | 
         
            +
                infer_time = time.time() - infer_start
         
     | 
| 177 | 
         
            +
                print(f"[MLX] Inference completed in {infer_time:.4f} seconds")
         
     | 
| 178 | 
         
            +
                
         
     | 
| 179 | 
         
            +
                return output, pooled
         
     | 
| 180 | 
         
            +
             
     | 
| 181 | 
         
            +
             
     | 
| 182 | 
         
            +
            if __name__ == "__main__":
         
     | 
| 183 | 
         
            +
                parser = argparse.ArgumentParser(description="Run the BERT model using MLX.")
         
     | 
| 184 | 
         
            +
                parser.add_argument(
         
     | 
| 185 | 
         
            +
                    "--bert-model",
         
     | 
| 186 | 
         
            +
                    type=str,
         
     | 
| 187 | 
         
            +
                    default="bert-base-uncased",
         
     | 
| 188 | 
         
            +
                    help="The huggingface name of the BERT model to save.",
         
     | 
| 189 | 
         
            +
                )
         
     | 
| 190 | 
         
            +
                parser.add_argument(
         
     | 
| 191 | 
         
            +
                    "--mlx-model",
         
     | 
| 192 | 
         
            +
                    type=str,
         
     | 
| 193 | 
         
            +
                    default="weights/bert-base-uncased.npz",
         
     | 
| 194 | 
         
            +
                    help="The path of the stored MLX BERT weights (npz file).",
         
     | 
| 195 | 
         
            +
                )
         
     | 
| 196 | 
         
            +
                parser.add_argument(
         
     | 
| 197 | 
         
            +
                    "--text",
         
     | 
| 198 | 
         
            +
                    type=str,
         
     | 
| 199 | 
         
            +
                    default="This is an example of BERT working in MLX",
         
     | 
| 200 | 
         
            +
                    help="The text to generate embeddings for.",
         
     | 
| 201 | 
         
            +
                )
         
     | 
| 202 | 
         
            +
                args = parser.parse_args()
         
     | 
| 203 | 
         
            +
                run(args.bert_model, args.mlx_model, args.text)
         
     | 
    	
        test.py
    ADDED
    
    | 
         @@ -0,0 +1,143 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            import argparse
         
     | 
| 2 | 
         
            +
            import time
         
     | 
| 3 | 
         
            +
            from typing import List
         
     | 
| 4 | 
         
            +
             
     | 
| 5 | 
         
            +
            import model
         
     | 
| 6 | 
         
            +
            import numpy as np
         
     | 
| 7 | 
         
            +
            import mlx.core as mx
         
     | 
| 8 | 
         
            +
            from transformers import AutoModel, AutoTokenizer
         
     | 
| 9 | 
         
            +
             
     | 
| 10 | 
         
            +
             
     | 
| 11 | 
         
            +
            def run_torch(bert_model: str, batch: List[str]):
         
     | 
| 12 | 
         
            +
                print(f"\n[PyTorch] Loading model and tokenizer: {bert_model}")
         
     | 
| 13 | 
         
            +
                start_time = time.time()
         
     | 
| 14 | 
         
            +
                tokenizer = AutoTokenizer.from_pretrained(bert_model)
         
     | 
| 15 | 
         
            +
                torch_model = AutoModel.from_pretrained(bert_model)
         
     | 
| 16 | 
         
            +
                load_time = time.time() - start_time
         
     | 
| 17 | 
         
            +
                print(f"[PyTorch] Model loaded in {load_time:.2f} seconds")
         
     | 
| 18 | 
         
            +
                
         
     | 
| 19 | 
         
            +
                print(f"[PyTorch] Tokenizing batch of {len(batch)} sentences")
         
     | 
| 20 | 
         
            +
                torch_tokens = tokenizer(batch, return_tensors="pt", padding=True)
         
     | 
| 21 | 
         
            +
                
         
     | 
| 22 | 
         
            +
                print(f"[PyTorch] Running model inference")
         
     | 
| 23 | 
         
            +
                inference_start = time.time()
         
     | 
| 24 | 
         
            +
                torch_forward = torch_model(**torch_tokens)
         
     | 
| 25 | 
         
            +
                inference_time = time.time() - inference_start
         
     | 
| 26 | 
         
            +
                print(f"[PyTorch] Inference completed in {inference_time:.4f} seconds")
         
     | 
| 27 | 
         
            +
                
         
     | 
| 28 | 
         
            +
                torch_output = torch_forward.last_hidden_state.detach().numpy()
         
     | 
| 29 | 
         
            +
                torch_pooled = torch_forward.pooler_output.detach().numpy()
         
     | 
| 30 | 
         
            +
                
         
     | 
| 31 | 
         
            +
                print(f"[PyTorch] Output shape: {torch_output.shape}")
         
     | 
| 32 | 
         
            +
                print(f"[PyTorch] Pooled output shape: {torch_pooled.shape}")
         
     | 
| 33 | 
         
            +
                
         
     | 
| 34 | 
         
            +
                # Print a small sample of the output to verify sensible values
         
     | 
| 35 | 
         
            +
                print(f"[PyTorch] Sample of output (first token, first 5 values): {torch_output[0, 0, :5]}")
         
     | 
| 36 | 
         
            +
                print(f"[PyTorch] Sample of pooled output (first 5 values): {torch_pooled[0, :5]}")
         
     | 
| 37 | 
         
            +
                
         
     | 
| 38 | 
         
            +
                return torch_output, torch_pooled
         
     | 
| 39 | 
         
            +
             
     | 
| 40 | 
         
            +
             
     | 
| 41 | 
         
            +
            def run_mlx(bert_model: str, mlx_model: str, batch: List[str]):
         
     | 
| 42 | 
         
            +
                print(f"\n[MLX] Loading model and tokenizer with weights from: {mlx_model}")
         
     | 
| 43 | 
         
            +
                start_time = time.time()
         
     | 
| 44 | 
         
            +
                mlx_output, mlx_pooled = model.run(bert_model, mlx_model, batch)
         
     | 
| 45 | 
         
            +
                load_and_run_time = time.time() - start_time
         
     | 
| 46 | 
         
            +
                print(f"[MLX] Model loaded and run in {load_and_run_time:.2f} seconds")
         
     | 
| 47 | 
         
            +
                
         
     | 
| 48 | 
         
            +
                # Convert from MLX arrays to numpy for comparison
         
     | 
| 49 | 
         
            +
                # The correct way to convert MLX arrays to numpy
         
     | 
| 50 | 
         
            +
                mlx_output_np = np.array(mlx_output)
         
     | 
| 51 | 
         
            +
                mlx_pooled_np = np.array(mlx_pooled)
         
     | 
| 52 | 
         
            +
                
         
     | 
| 53 | 
         
            +
                print(f"[MLX] Output shape: {mlx_output_np.shape}")
         
     | 
| 54 | 
         
            +
                print(f"[MLX] Pooled output shape: {mlx_pooled_np.shape}")
         
     | 
| 55 | 
         
            +
                
         
     | 
| 56 | 
         
            +
                # Print a small sample of the output to verify sensible values
         
     | 
| 57 | 
         
            +
                print(f"[MLX] Sample of output (first token, first 5 values): {mlx_output_np[0, 0, :5]}")
         
     | 
| 58 | 
         
            +
                print(f"[MLX] Sample of pooled output (first 5 values): {mlx_pooled_np[0, :5]}")
         
     | 
| 59 | 
         
            +
                
         
     | 
| 60 | 
         
            +
                return mlx_output_np, mlx_pooled_np
         
     | 
| 61 | 
         
            +
             
     | 
| 62 | 
         
            +
             
     | 
| 63 | 
         
            +
            def compare_outputs(torch_output, torch_pooled, mlx_output, mlx_pooled):
         
     | 
| 64 | 
         
            +
                print("\n[Comparison] Comparing PyTorch and MLX outputs")
         
     | 
| 65 | 
         
            +
                
         
     | 
| 66 | 
         
            +
                # Check shapes
         
     | 
| 67 | 
         
            +
                print(f"[Comparison] Shape match - Output: {torch_output.shape == mlx_output.shape}")
         
     | 
| 68 | 
         
            +
                print(f"[Comparison] Shape match - Pooled: {torch_pooled.shape == mlx_pooled.shape}")
         
     | 
| 69 | 
         
            +
                
         
     | 
| 70 | 
         
            +
                # Calculate differences
         
     | 
| 71 | 
         
            +
                output_max_diff = np.max(np.abs(torch_output - mlx_output))
         
     | 
| 72 | 
         
            +
                output_mean_diff = np.mean(np.abs(torch_output - mlx_output))
         
     | 
| 73 | 
         
            +
                pooled_max_diff = np.max(np.abs(torch_pooled - mlx_pooled))
         
     | 
| 74 | 
         
            +
                pooled_mean_diff = np.mean(np.abs(torch_pooled - mlx_pooled))
         
     | 
| 75 | 
         
            +
                
         
     | 
| 76 | 
         
            +
                print(f"[Comparison] Output - Max absolute difference: {output_max_diff:.6f}")
         
     | 
| 77 | 
         
            +
                print(f"[Comparison] Output - Mean absolute difference: {output_mean_diff:.6f}")
         
     | 
| 78 | 
         
            +
                print(f"[Comparison] Pooled - Max absolute difference: {pooled_max_diff:.6f}")
         
     | 
| 79 | 
         
            +
                print(f"[Comparison] Pooled - Mean absolute difference: {pooled_mean_diff:.6f}")
         
     | 
| 80 | 
         
            +
                
         
     | 
| 81 | 
         
            +
                # Detailed comparison of first few values from first sentence
         
     | 
| 82 | 
         
            +
                print("\n[Comparison] Detailed comparison of first 5 values from first output token:")
         
     | 
| 83 | 
         
            +
                for i in range(5):
         
     | 
| 84 | 
         
            +
                    torch_val = torch_output[0, 0, i]
         
     | 
| 85 | 
         
            +
                    mlx_val = mlx_output[0, 0, i]
         
     | 
| 86 | 
         
            +
                    diff = abs(torch_val - mlx_val)
         
     | 
| 87 | 
         
            +
                    print(f"Index {i}: PyTorch={torch_val:.6f}, MLX={mlx_val:.6f}, Diff={diff:.6f}")
         
     | 
| 88 | 
         
            +
                
         
     | 
| 89 | 
         
            +
                # Check if outputs are close
         
     | 
| 90 | 
         
            +
                outputs_close = np.allclose(torch_output, mlx_output, rtol=1e-4, atol=1e-4)
         
     | 
| 91 | 
         
            +
                pooled_close = np.allclose(torch_pooled, mlx_pooled, rtol=1e-4, atol=1e-4)
         
     | 
| 92 | 
         
            +
                
         
     | 
| 93 | 
         
            +
                print(f"\n[Comparison] Outputs match within tolerance: {outputs_close}")
         
     | 
| 94 | 
         
            +
                print(f"[Comparison] Pooled outputs match within tolerance: {pooled_close}")
         
     | 
| 95 | 
         
            +
                
         
     | 
| 96 | 
         
            +
                return outputs_close and pooled_close
         
     | 
| 97 | 
         
            +
             
     | 
| 98 | 
         
            +
             
     | 
| 99 | 
         
            +
            if __name__ == "__main__":
         
     | 
| 100 | 
         
            +
                parser = argparse.ArgumentParser(
         
     | 
| 101 | 
         
            +
                    description="Run a BERT-like model for a batch of text and compare PyTorch and MLX outputs."
         
     | 
| 102 | 
         
            +
                )
         
     | 
| 103 | 
         
            +
                parser.add_argument(
         
     | 
| 104 | 
         
            +
                    "--bert-model",
         
     | 
| 105 | 
         
            +
                    type=str,
         
     | 
| 106 | 
         
            +
                    default="bert-base-uncased",
         
     | 
| 107 | 
         
            +
                    help="The model identifier for a BERT-like model from Hugging Face Transformers.",
         
     | 
| 108 | 
         
            +
                )
         
     | 
| 109 | 
         
            +
                parser.add_argument(
         
     | 
| 110 | 
         
            +
                    "--mlx-model",
         
     | 
| 111 | 
         
            +
                    type=str,
         
     | 
| 112 | 
         
            +
                    default="weights/bert-base-uncased.npz",
         
     | 
| 113 | 
         
            +
                    help="The path of the stored MLX BERT weights (npz file).",
         
     | 
| 114 | 
         
            +
                )
         
     | 
| 115 | 
         
            +
                parser.add_argument(
         
     | 
| 116 | 
         
            +
                    "--text",
         
     | 
| 117 | 
         
            +
                    nargs="+",
         
     | 
| 118 | 
         
            +
                    default=["This is an example of BERT working in MLX."],
         
     | 
| 119 | 
         
            +
                    help="A batch of texts to process. Multiple texts should be separated by spaces.",
         
     | 
| 120 | 
         
            +
                )
         
     | 
| 121 | 
         
            +
                parser.add_argument(
         
     | 
| 122 | 
         
            +
                    "--verbose",
         
     | 
| 123 | 
         
            +
                    action="store_true",
         
     | 
| 124 | 
         
            +
                    help="Print detailed information about the model execution.",
         
     | 
| 125 | 
         
            +
                )
         
     | 
| 126 | 
         
            +
             
     | 
| 127 | 
         
            +
                args = parser.parse_args()
         
     | 
| 128 | 
         
            +
                
         
     | 
| 129 | 
         
            +
                print(f"Testing BERT model: {args.bert_model}")
         
     | 
| 130 | 
         
            +
                print(f"MLX weights: {args.mlx_model}")
         
     | 
| 131 | 
         
            +
                print(f"Input text: {args.text}")
         
     | 
| 132 | 
         
            +
                
         
     | 
| 133 | 
         
            +
                # Run both implementations
         
     | 
| 134 | 
         
            +
                torch_output, torch_pooled = run_torch(args.bert_model, args.text)
         
     | 
| 135 | 
         
            +
                mlx_output, mlx_pooled = run_mlx(args.bert_model, args.mlx_model, args.text)
         
     | 
| 136 | 
         
            +
                
         
     | 
| 137 | 
         
            +
                # Compare outputs
         
     | 
| 138 | 
         
            +
                all_match = compare_outputs(torch_output, torch_pooled, mlx_output, mlx_pooled)
         
     | 
| 139 | 
         
            +
                
         
     | 
| 140 | 
         
            +
                if all_match:
         
     | 
| 141 | 
         
            +
                    print("\n✅ TEST PASSED: PyTorch and MLX implementations produce equivalent results!")
         
     | 
| 142 | 
         
            +
                else:
         
     | 
| 143 | 
         
            +
                    print("\n❌ TEST FAILED: PyTorch and MLX implementations produce different results.")
         
     |