I Tested Apple’s New MLX Framework Against Torch on M2 Air

Tim Cvetko


On Tuesday, Apple’s AI team released “MLX”, a new machine learning framework designed specifically for Apple Silicon chips. The design of MLX was inspired by frameworks like NumPy, PyTorch, JAX, and ArrayFire.

Is MLX really faster than Torch on Mac?

As I own an M2 MacBook Air and regularly train ML models locally, I decided to put this hypothesis to the test by training the standard BERT transformer model with both MLX and PyTorch. The results are staggering!

Who should read this?

Who is this blog post useful for? Mac (M1, M2, M3) owners who are looking for a faster ML framework for training and inference.

How advanced is this post? Anybody previously acquainted with ML terms should be able to follow along.

Quick Guide to MLX

MLX is an array framework for machine learning on Apple silicon, brought to you by Apple machine learning research.

MLX is very Torch-like in its syntax. MLX has higher-level packages like mlx.nn and mlx.optimizers with APIs that closely follow PyTorch to simplify building more complex models.

MLX has a Python API that closely follows NumPy. MLX also has a fully featured C++ API, which closely mirrors the Python API. To install the Python package:

pip install mlx

The MLX examples repo has a variety of examples, including:

Large-scale text generation with LLaMA and finetuning with LoRA.

Generating images with Stable Diffusion.

Speech recognition with OpenAI’s Whisper.

Here’s THE Thing

The goal of this experiment was to compare MLX against PyTorch using the standard ~400 MB BERT model from Hugging Face. Here’s what I wanted to measure:

Training time [sec]

Inference time [sec]

CPU Utilization during training [%]
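Measuring these times fairly needs some care, because MLX evaluates lazily and PyTorch queues GPU work asynchronously on the MPS backend, so a naive timer can stop before the computation has actually run. Below is a minimal, framework-agnostic sketch of a timing helper; it is not the code used for this experiment. It assumes you pass a `sync` callback that forces pending work to finish, for example `mx.eval` for MLX outputs or `lambda _: torch.mps.synchronize()` for PyTorch on MPS. The name `time_call` and the warm-up/repeat defaults are illustrative assumptions.

```python
import statistics
import time
from typing import Any, Callable, Optional


def time_call(
    fn: Callable[[], Any],
    sync: Optional[Callable[[Any], Any]] = None,
    warmup: int = 2,
    repeats: int = 5,
) -> float:
    """Return the median wall-clock time of fn() in seconds.

    If given, sync(output) is called inside the timed region so that
    lazily evaluated or asynchronously queued work finishes first.
    """
    if repeats < 1:
        raise ValueError("repeats must be at least 1")
    # Untimed warm-up runs absorb one-off costs such as graph or kernel setup.
    for _ in range(warmup):
        out = fn()
        if sync is not None:
            sync(out)
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        out = fn()
        if sync is not None:
            sync(out)
        timings.append(time.perf_counter() - start)
    # The median is less sensitive than the mean to background noise.
    return statistics.median(timings)
```

For an MLX model `model` and input array `x` (both hypothetical here), you might call `time_call(lambda: model(x), sync=mx.eval)`; the warm-up runs keep one-off setup costs out of the measurement.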

Model Setup

The MLX model was implemented separately to account for the differences in syntax, but it was initialized with the same weights as the PyTorch model.

Training Comparison

For MLX, the training process involved loading a pre-trained BERT model, updating its weights with weights converted from the PyTorch model, and evaluating the performance using synthetic data.

Training time was measured across different batch sizes to show how MLX handles an increasing training workload.
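The batch-size sweep can be sketched as a loop that times a fixed number of training steps per batch size. This is a hypothetical helper, not the author's script: `sweep_batch_sizes`, `train_step`, and `make_batch` are placeholder names, and the step count is an arbitrary choice. For MLX, `train_step` would typically compute gradients with `nn.value_and_grad(model, loss_fn)`, apply them with `optimizer.update(model, grads)`, and then call `mx.eval(model.parameters(), optimizer.state)` so that the step really finishes before the timer stops.

```python
import time
from typing import Callable, Dict, Iterable


def sweep_batch_sizes(
    train_step: Callable[[object], None],
    make_batch: Callable[[int], object],
    batch_sizes: Iterable[int],
    steps: int = 10,
) -> Dict[int, float]:
    """Time `steps` training steps for each batch size.

    Returns a mapping from batch size to total seconds. `train_step` is
    expected to block until its work is done (e.g. by calling mx.eval).
    """
    results: Dict[int, float] = {}
    for bs in batch_sizes:
        batch = make_batch(bs)  # build the batch outside the timed region
        train_step(batch)       # one untimed warm-up step per batch size
        start = time.perf_counter()
        for _ in range(steps):
            train_step(batch)
        results[bs] = time.perf_counter() - start
    return results
```

Reusing one batch per size keeps data loading out of the measurement, so the numbers reflect only the framework's compute.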

Inference

For MLX, the inference process involved loading a pre-trained BERT model (MLXBertModel) and feeding it synthetic input data. The inputs were randomly generated sequences of input_ids, token_type_ids, and attention_mask, mimicking the structure of typical BERT inputs. These synthetic inputs were passed through MLXBertModel and the inference time was recorded, as a measure of how efficiently MLX handles BERT workloads.

On the PyTorch side, the inference setup mirrored that of MLX. A pre-trained BERT model, sourced from the Hugging Face Transformers library, was loaded into a PyTorch environment. Similarly, synthetic input data, comprising input_ids, token_type_ids, and attention_mask tensors, was generated randomly. The PyTorch BERT model processed this synthetic input data, and the inference time was recorded.
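A minimal sketch of generating such synthetic BERT inputs with the standard library, assuming a sequence length of 128 and the 30,522-token vocabulary of bert-base-uncased (the post does not state either value). For simplicity it randomizes only `input_ids` and uses all-zero `token_type_ids` and an all-ones `attention_mask`; the original experiment may have generated these differently. The function name `make_synthetic_bert_inputs` is hypothetical.

```python
import random
from typing import Dict, List


def make_synthetic_bert_inputs(
    batch_size: int,
    seq_len: int = 128,
    vocab_size: int = 30522,
    seed: int = 0,
) -> Dict[str, List[List[int]]]:
    """Random BERT-style inputs as nested Python lists."""
    rng = random.Random(seed)  # seeded so every run sees the same data
    input_ids = [
        [rng.randrange(vocab_size) for _ in range(seq_len)]
        for _ in range(batch_size)
    ]
    # Single-segment inputs: every token belongs to segment 0.
    token_type_ids = [[0] * seq_len for _ in range(batch_size)]
    # No padding in synthetic data, so every position is attended to.
    attention_mask = [[1] * seq_len for _ in range(batch_size)]
    return {
        "input_ids": input_ids,
        "token_type_ids": token_type_ids,
        "attention_mask": attention_mask,
    }
```

Because the output is plain nested lists, the same batch can be converted with `mx.array(inputs["input_ids"])` for MLX and `torch.tensor(inputs["input_ids"])` for PyTorch, so both models see identical data.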

CPU Utilization

During both MLX and PyTorch BERT model training, I captured CPU utilization with the psutil library.
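One way to capture CPU utilization during training is to sample `psutil.cpu_percent(interval=...)` in a background thread while the training loop runs. The sketch below is an assumption about how this could be done, not the author's code; the class name `CpuSampler` and the 0.5 s interval are illustrative. Note that `psutil.cpu_percent` reports system-wide utilization, so other running applications affect the reading; `psutil.Process().cpu_percent(interval=...)` gives a per-process figure instead.

```python
import threading
from typing import Callable, List, Optional


class CpuSampler:
    """Sample CPU utilization (%) in a background thread.

    By default uses psutil.cpu_percent(interval=...), which blocks for
    `interval` seconds and returns utilization over that window.
    """

    def __init__(self, interval: float = 0.5,
                 sample_fn: Optional[Callable[[float], float]] = None):
        if sample_fn is None:
            import psutil  # imported lazily so a fake sampler can be injected

            def sample_fn(iv: float) -> float:
                return psutil.cpu_percent(interval=iv)

        self.interval = interval
        self.sample_fn = sample_fn
        self.samples: List[float] = []
        self._stop = threading.Event()
        self._thread: Optional[threading.Thread] = None

    def _run(self) -> None:
        # Keep sampling until the context manager exits.
        while not self._stop.is_set():
            self.samples.append(self.sample_fn(self.interval))

    def __enter__(self) -> "CpuSampler":
        self._stop.clear()
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()
        return self

    def __exit__(self, *exc) -> None:
        self._stop.set()
        if self._thread is not None:
            self._thread.join()

    def mean(self) -> float:
        """Average utilization over all samples (0.0 if none were taken)."""
        return sum(self.samples) / len(self.samples) if self.samples else 0.0
```

Wrap the training loop in `with CpuSampler() as sampler:` and read `sampler.mean()` afterwards to get the average utilization for that run.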

Conclusion

According to my final notes, these are the final results of the comparison experiment.

On the Apple M2 chip, MLX outperformed PyTorch by a margin of 8% in CPU utilization.

These results suggest that the new framework from Apple’s research team is definitely promising.

Hey, thanks for reading!

Thanks for getting to the end of this article. My name is Tim, and I love to explain ML research papers and ML applications, with an emphasis on business use cases.
