|  | import argparse | 
					
						
						|  | import time | 
					
						
						|  | from typing import List | 
					
						
						|  |  | 
					
						
						|  | import model | 
					
						
						|  | import numpy as np | 
					
						
						|  | import mlx.core as mx | 
					
						
						|  | from transformers import AutoModel, AutoTokenizer | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						def run_torch(bert_model: str, batch: List[str]):
						    """Run the Hugging Face PyTorch model on ``batch`` and return NumPy outputs.

						    Loads the tokenizer and model for ``bert_model`` via ``AutoTokenizer`` /
						    ``AutoModel``, tokenizes ``batch`` with padding, runs a single forward pass
						    and prints load/inference timings, output shapes and a small value sample.

						    Args:
						        bert_model: Hugging Face model identifier (or local path) passed to
						            ``from_pretrained``.
						        batch: Sentences to encode. Must be non-empty: the sample printing
						            indexes element ``0``.

						    Returns:
						        Tuple ``(last_hidden_state, pooler_output)`` converted to NumPy arrays.
						        The ``[0, 0, :5]`` sample indexing assumes the first is at least rank 3
						        (presumably batch x tokens x hidden - confirm for non-BERT models).
						    """
						    import torch  # local: only needed for the no-grad context below

						    print(f"\n[PyTorch] Loading model and tokenizer: {bert_model}")
						    start_time = time.perf_counter()
						    tokenizer = AutoTokenizer.from_pretrained(bert_model)
						    torch_model = AutoModel.from_pretrained(bert_model)
						    load_time = time.perf_counter() - start_time
						    print(f"[PyTorch] Model loaded in {load_time:.2f} seconds")

						    print(f"[PyTorch] Tokenizing batch of {len(batch)} sentences")
						    torch_tokens = tokenizer(batch, return_tensors="pt", padding=True)

						    print("[PyTorch] Running model inference")
						    inference_start = time.perf_counter()
						    # Inference only: skip autograd graph construction so the reported time
						    # and memory reflect a plain forward pass.
						    with torch.no_grad():
						        torch_forward = torch_model(**torch_tokens)
						    inference_time = time.perf_counter() - inference_start
						    print(f"[PyTorch] Inference completed in {inference_time:.4f} seconds")

						    torch_output = torch_forward.last_hidden_state.detach().numpy()
						    torch_pooled = torch_forward.pooler_output.detach().numpy()

						    print(f"[PyTorch] Output shape: {torch_output.shape}")
						    print(f"[PyTorch] Pooled output shape: {torch_pooled.shape}")

						    print(f"[PyTorch] Sample of output (first token, first 5 values): {torch_output[0, 0, :5]}")
						    print(f"[PyTorch] Sample of pooled output (first 5 values): {torch_pooled[0, :5]}")

						    return torch_output, torch_pooled
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						def run_mlx(bert_model: str, mlx_model: str, batch: List[str]):
						    """Run the MLX implementation on ``batch`` and return NumPy outputs.

						    Delegates loading and the forward pass to ``model.run`` from the local
						    ``model`` module (not shown here). Times that call and prints output
						    shapes and a small value sample.

						    Args:
						        bert_model: Hugging Face model identifier, forwarded to ``model.run``
						            (presumably used for the tokenizer/config - confirm in ``model``).
						        mlx_model: Path of the stored MLX weights, forwarded to ``model.run``.
						        batch: Sentences to encode. Must be non-empty: the sample printing
						            indexes element ``0``.

						    Returns:
						        Tuple ``(output, pooled)`` converted with ``np.array``.
						    """
						    print(f"\n[MLX] Loading model and tokenizer with weights from: {mlx_model}")
						    start_time = time.perf_counter()
						    mlx_output, mlx_pooled = model.run(bert_model, mlx_model, batch)
						    # MLX builds its compute graph lazily; force evaluation before stopping the
						    # timer so the reported time includes the actual forward pass, not just
						    # graph construction. mx.eval ignores leaves that are not MLX arrays.
						    mx.eval(mlx_output, mlx_pooled)
						    load_and_run_time = time.perf_counter() - start_time
						    print(f"[MLX] Model loaded and run in {load_and_run_time:.2f} seconds")

						    mlx_output_np = np.array(mlx_output)
						    mlx_pooled_np = np.array(mlx_pooled)

						    print(f"[MLX] Output shape: {mlx_output_np.shape}")
						    print(f"[MLX] Pooled output shape: {mlx_pooled_np.shape}")

						    print(f"[MLX] Sample of output (first token, first 5 values): {mlx_output_np[0, 0, :5]}")
						    print(f"[MLX] Sample of pooled output (first 5 values): {mlx_pooled_np[0, :5]}")

						    return mlx_output_np, mlx_pooled_np
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						def compare_outputs(torch_output, torch_pooled, mlx_output, mlx_pooled, rtol=1e-4, atol=1e-4):
						    """Compare PyTorch and MLX outputs and report whether they match.

						    Prints shape checks, max/mean absolute differences, a per-value view of
						    the first output token, and the ``np.allclose`` verdicts.

						    Args:
						        torch_output: Sequence output from PyTorch (NumPy array, rank >= 3).
						        torch_pooled: Pooled output from PyTorch (NumPy array, rank >= 2).
						        mlx_output: Sequence output from MLX, same shape as ``torch_output``.
						        mlx_pooled: Pooled output from MLX, same shape as ``torch_pooled``.
						        rtol: Relative tolerance passed to ``np.allclose``.
						        atol: Absolute tolerance passed to ``np.allclose``.

						    Returns:
						        ``True`` only if both pairs have identical shapes and are element-wise
						        close within the tolerances; ``False`` otherwise.
						    """
						    print("\n[Comparison] Comparing PyTorch and MLX outputs")

						    output_shape_match = torch_output.shape == mlx_output.shape
						    pooled_shape_match = torch_pooled.shape == mlx_pooled.shape
						    print(f"[Comparison] Shape match - Output: {output_shape_match}")
						    print(f"[Comparison] Shape match - Pooled: {pooled_shape_match}")
						    if not (output_shape_match and pooled_shape_match):
						        # Without this guard NumPy broadcasting could make mismatched shapes
						        # compare as "close" (or raise a confusing error in the subtraction).
						        print("[Comparison] Shape mismatch - skipping value comparison")
						        return False

						    # Compute each absolute-difference array once and reuse it.
						    output_abs_diff = np.abs(torch_output - mlx_output)
						    pooled_abs_diff = np.abs(torch_pooled - mlx_pooled)
						    output_max_diff = np.max(output_abs_diff)
						    output_mean_diff = np.mean(output_abs_diff)
						    pooled_max_diff = np.max(pooled_abs_diff)
						    pooled_mean_diff = np.mean(pooled_abs_diff)

						    print(f"[Comparison] Output - Max absolute difference: {output_max_diff:.6f}")
						    print(f"[Comparison] Output - Mean absolute difference: {output_mean_diff:.6f}")
						    print(f"[Comparison] Pooled - Max absolute difference: {pooled_max_diff:.6f}")
						    print(f"[Comparison] Pooled - Mean absolute difference: {pooled_mean_diff:.6f}")

						    print("\n[Comparison] Detailed comparison of first 5 values from first output token:")
						    # Bound by the last dimension so tiny hidden sizes do not raise IndexError.
						    for i in range(min(5, torch_output.shape[-1])):
						        torch_val = torch_output[0, 0, i]
						        mlx_val = mlx_output[0, 0, i]
						        diff = abs(torch_val - mlx_val)
						        print(f"Index {i}: PyTorch={torch_val:.6f}, MLX={mlx_val:.6f}, Diff={diff:.6f}")

						    outputs_close = np.allclose(torch_output, mlx_output, rtol=rtol, atol=atol)
						    pooled_close = np.allclose(torch_pooled, mlx_pooled, rtol=rtol, atol=atol)

						    print(f"\n[Comparison] Outputs match within tolerance: {outputs_close}")
						    print(f"[Comparison] Pooled outputs match within tolerance: {pooled_close}")

						    return bool(outputs_close and pooled_close)
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						if __name__ == "__main__":
						    # Command-line entry point: run both implementations on the same batch,
						    # compare them, and exit non-zero when they disagree.
						    parser = argparse.ArgumentParser(
						        description="Run a BERT-like model for a batch of text and compare PyTorch and MLX outputs."
						    )
						    parser.add_argument(
						        "--bert-model",
						        type=str,
						        default="bert-base-uncased",
						        help="The model identifier for a BERT-like model from Hugging Face Transformers.",
						    )
						    parser.add_argument(
						        "--mlx-model",
						        type=str,
						        default="weights/bert-base-uncased.npz",
						        help="The path of the stored MLX BERT weights (npz file).",
						    )
						    parser.add_argument(
						        "--text",
						        nargs="+",
						        default=["This is an example of BERT working in MLX."],
						        help="A batch of texts to process. Pass each text as a separate (quoted) argument.",
						    )
						    # NOTE(review): --verbose is parsed but not consulted anywhere below; kept so
						    # existing command lines keep working.
						    parser.add_argument(
						        "--verbose",
						        action="store_true",
						        help="Print detailed information about the model execution.",
						    )

						    args = parser.parse_args()

						    print(f"Testing BERT model: {args.bert_model}")
						    print(f"MLX weights: {args.mlx_model}")
						    print(f"Input text: {args.text}")

						    torch_output, torch_pooled = run_torch(args.bert_model, args.text)
						    mlx_output, mlx_pooled = run_mlx(args.bert_model, args.mlx_model, args.text)

						    all_match = compare_outputs(torch_output, torch_pooled, mlx_output, mlx_pooled)

						    if all_match:
						        print("\n✅ TEST PASSED: PyTorch and MLX implementations produce equivalent results!")
						    else:
						        print("\n❌ TEST FAILED: PyTorch and MLX implementations produce different results.")
						        # Non-zero exit status so shell scripts / CI can detect the mismatch.
						        raise SystemExit(1)
					
						
						|  |  |