MLX

MLX 20/80: with 20% of the effort, I hope to cover 80% of what MLX offers. I don't stand by this claim, though.
(pyvenv-activate "~/.venv/machine-learning")
import mlx.core as mx

a = mx.array([0, 1, 2, 3, 4, 5])
print(a.shape)
print(a.dtype)

b = mx.array([10, 11.0, 12.0, 13.0, 14.0, 15.0])
print(b.shape)
print(b.dtype)
(6,)
mlx.core.int32
(6,)
mlx.core.float32
c = a + b
print(c)
mx.eval(c)
print(c)
array([10, 12, 14, 16, 18, 20], dtype=float32)
array([10, 12, 14, 16, 18, 20], dtype=float32)
print(mx.arange(10))
print(mx.random.normal((1, 10)))
array([0, 1, 2, ..., 7, 8, 9], dtype=int32)
array([[0.492185, -0.329351, -1.04443, ..., 0.454991, -0.884072, -0.39394]], dtype=float32)

Utility Functions

import time
def timeit(f):
    start = time.perf_counter()
    ret = f()
    end = time.perf_counter()
    return ret, end - start

Testing this…

ret, elapsed = timeit(lambda : 1 + 1)
print(ret)
print(elapsed)
2
4.000000000004e-06

Unified Memory

The CPU and GPU use the same RAM. We don’t need to move data when planning operations between the CPU and the GPU.

It’s essential to understand what the CPU and the GPU are each good at, and not so good at, before we choose how to distribute our computation. Here’s an example from the Unified Memory page of the MLX documentation.

Here’s a function that operates on its arguments, and takes as parameters the devices on which to run its constituent operations.

def fun(a, b, d1, d2): # d1, d2 -> devices
    x = mx.matmul(a, b, stream=d1)
    for _ in range(500):
        b = mx.exp(b, stream=d2)
    return x, b

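We create some sample inputs `a` and `b`, then time the function with two different device assignments. MLX evaluates lazily, so `timeit` as written mostly measures the time to build the computation graph; the arrays are only computed afterwards, when `print` forces evaluation.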

a = mx.random.uniform(shape=(4096, 512))
b = mx.random.uniform(shape=(512, 4))
ret, elapsed = timeit(lambda: fun(a, b, mx.gpu, mx.gpu))
print(ret)
print(elapsed)

ret, elapsed = timeit(lambda: fun(a, b, mx.gpu, mx.cpu))
print(ret)
print(elapsed)
(array([[134.262, 129.237, 130.571, 130.649],
       [130.853, 129.764, 131.11, 131.915],
       [125.345, 126.757, 126.702, 126.284],
       ...,
       [129.413, 126.892, 128.287, 128.288],
       [132.635, 127.151, 128.288, 132.74],
       [131.024, 122.929, 128.609, 130.531]], dtype=float32), array([[inf, inf, inf, inf],
       [inf, inf, inf, inf],
       [inf, inf, inf, inf],
       ...,
       [inf, inf, inf, inf],
       [inf, inf, inf, inf],
       [inf, inf, inf, inf]], dtype=float32))
0.00019970801076851785
(array([[134.262, 129.237, 130.571, 130.649],
       [130.853, 129.764, 131.11, 131.915],
       [125.345, 126.757, 126.702, 126.284],
       ...,
       [129.413, 126.892, 128.287, 128.288],
       [132.635, 127.151, 128.288, 132.74],
       [131.024, 122.929, 128.609, 130.531]], dtype=float32), array([[inf, inf, inf, inf],
       [inf, inf, inf, inf],
       [inf, inf, inf, inf],
       ...,
       [inf, inf, inf, inf],
       [inf, inf, inf, inf],
       [inf, inf, inf, inf]], dtype=float32))
0.00015633300063200295

Saving and Loading Arrays

MLX can save and load arrays in multiple formats, which makes it easy to interoperate with other libraries, tools, and applications. The supported formats and the corresponding functions in the `mlx.core` package are:

| Format | Save function | Load function |
|---|---|---|
| NumPy | `save` | `load` |
| NumPy Archive | `savez` and `savez_compressed` | `load` |
| Safetensors | `save_safetensors` | `load` |
| GGUF | `save_gguf` | `load` |

As you can see, the `load` function works for all formats: it infers the format from the suffix of the file being read, or it can take an optional `format` argument that names the format explicitly.

GPT

This follows Andrej Karpathy’s video “Let’s build GPT: from scratch, in code, spelled out”.

We’ll use the dataset from the video, the Tiny Shakespeare text. Download it first.

if [ ! -f "input.txt" ]; then
    wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
fi

We read the text into memory, check a few quick details, and create an encoder and a decoder that map characters to integers and back.

with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()
print("length of dataset in characters: ", len(text))
length of dataset in characters:  1115394
chars = sorted(list(set(text)))
vocab_size = len(chars)

stoi = { ch : i for i, ch in enumerate(chars) }
itos = { i : ch for i, ch in enumerate(chars) }

encode = lambda s: [stoi[c] for c in s] # string -> list[integer]
decode = lambda l: ''.join([itos[i] for i in l]) # list[integer] -> string

Let’s quickly check our code

print(''.join(chars))
print(vocab_size)

print(encode("Hello, World!"))
print(decode(encode("How are you doing today?")))

 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65
[20, 43, 50, 50, 53, 6, 1, 35, 53, 56, 50, 42, 2]
How are you doing today?

We have a custom encoder/decoder pair here. For real-world use-cases, libraries like sentencepiece and tiktoken are appropriate.
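
One limitation of a character table like this is that it only knows the characters present when the vocabulary was built. The sketch below shows this on a toy corpus; the corpus string is made up for illustration and is not the Tiny Shakespeare text.

```python
# Toy corpus, assumed for illustration only.
corpus = "hello world"
chars = sorted(set(corpus))
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for ch, i in stoi.items()}

def encode(s):
    # string -> list of integer ids
    return [stoi[c] for c in s]

def decode(ids):
    # list of integer ids -> string
    return "".join(itos[i] for i in ids)

# Encoding followed by decoding returns the original string.
round_trip = decode(encode("hello"))
print(round_trip)

# A character that was never seen has no id, so the lookup fails.
try:
    encode("Hello")  # capital "H" is not in the toy corpus
    unseen_ok = True
except KeyError:
    unseen_ok = False
print(unseen_ok)
```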

We’ll next encode the entire text into an `mlx.core.array` object

import mlx.core as mx
data = mx.array(encode(text))

Quick check

print(data.shape, data.dtype)
print(data[:1000])
(1115394,) mlx.core.int32
array([18, 47, 56, ..., 8, 0, 0], dtype=int32)

Splitting into train and validation sets

n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]
block_size = 8 # The maximum context length for predictions

Quick check, and a look at how the training inputs line up with their target values

print(train_data[:block_size+1])

x = train_data[:block_size]
y = train_data[1:block_size+1]

for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f"When input is {context}, the target is {target}")
array([18, 47, 56, ..., 15, 47, 58], dtype=int32)
When input is array([18], dtype=int32), the target is array(47, dtype=int32)
When input is array([18, 47], dtype=int32), the target is array(56, dtype=int32)
When input is array([18, 47, 56], dtype=int32), the target is array(57, dtype=int32)
When input is array([18, 47, 56, 57], dtype=int32), the target is array(58, dtype=int32)
When input is array([18, 47, 56, 57, 58], dtype=int32), the target is array(1, dtype=int32)
When input is array([18, 47, 56, 57, 58, 1], dtype=int32), the target is array(15, dtype=int32)
When input is array([18, 47, 56, ..., 58, 1, 15], dtype=int32), the target is array(47, dtype=int32)
When input is array([18, 47, 56, ..., 1, 15, 47], dtype=int32), the target is array(58, dtype=int32)

Next, the batch size. We pass data in batches to make efficient use of CPU/GPU resources, which can run many computations of the same kind in parallel.

batch_size = 4 # Number of independent sequences we will process in parallel
mx.random.seed(1337)

def get_batch(data):
    # Generate a small batch of data of inputs x and targets y
    ix = mx.random.randint(0, len(data) - block_size, [batch_size])
    x = mx.stack([data[i:i+block_size] for i in ix.tolist()])
    y = mx.stack([data[i+1:i+block_size+1] for i in ix.tolist()])
    return x, y

Quick check

x, y = get_batch(train_data)
print("inputs")
print(x.shape)
print(x)
print("targets")
print(y.shape)
print(y)

for b in range(batch_size): # batch dimension
    for t in range(block_size): # time dimension
        context = x[b, :t+1]
        target = y[b, t]
        print(f"when input is {context.tolist()}, target is {target}")
inputs
(4, 8)
array([[53, 1, 51, ..., 43, 50, 44],
       [32, 53, 1, ..., 39, 58, 1],
       [53, 59, 1, ..., 50, 50, 1],
       [23, 17, 10, ..., 39, 52, 1]], dtype=int32)
targets
(4, 8)
array([[1, 51, 63, ..., 50, 44, 0],
       [53, 1, 61, ..., 58, 1, 61],
       [59, 1, 58, ..., 50, 1, 51],
       [17, 10, 0, ..., 52, 1, 52]], dtype=int32)
when input is [53], target is array(1, dtype=int32)
when input is [53, 1], target is array(51, dtype=int32)
when input is [53, 1, 51], target is array(63, dtype=int32)
when input is [53, 1, 51, 63], target is array(57, dtype=int32)
when input is [53, 1, 51, 63, 57], target is array(43, dtype=int32)
when input is [53, 1, 51, 63, 57, 43], target is array(50, dtype=int32)
when input is [53, 1, 51, 63, 57, 43, 50], target is array(44, dtype=int32)
when input is [53, 1, 51, 63, 57, 43, 50, 44], target is array(0, dtype=int32)
when input is [32], target is array(53, dtype=int32)
when input is [32, 53], target is array(1, dtype=int32)
when input is [32, 53, 1], target is array(61, dtype=int32)
when input is [32, 53, 1, 61], target is array(46, dtype=int32)
when input is [32, 53, 1, 61, 46], target is array(39, dtype=int32)
when input is [32, 53, 1, 61, 46, 39], target is array(58, dtype=int32)
when input is [32, 53, 1, 61, 46, 39, 58], target is array(1, dtype=int32)
when input is [32, 53, 1, 61, 46, 39, 58, 1], target is array(61, dtype=int32)
when input is [53], target is array(59, dtype=int32)
when input is [53, 59], target is array(1, dtype=int32)
when input is [53, 59, 1], target is array(58, dtype=int32)
when input is [53, 59, 1, 58], target is array(43, dtype=int32)
when input is [53, 59, 1, 58, 43], target is array(50, dtype=int32)
when input is [53, 59, 1, 58, 43, 50], target is array(50, dtype=int32)
when input is [53, 59, 1, 58, 43, 50, 50], target is array(1, dtype=int32)
when input is [53, 59, 1, 58, 43, 50, 50, 1], target is array(51, dtype=int32)
when input is [23], target is array(17, dtype=int32)
when input is [23, 17], target is array(10, dtype=int32)
when input is [23, 17, 10], target is array(0, dtype=int32)
when input is [23, 17, 10, 0], target is array(15, dtype=int32)
when input is [23, 17, 10, 0, 15], target is array(39, dtype=int32)
when input is [23, 17, 10, 0, 15, 39], target is array(52, dtype=int32)
when input is [23, 17, 10, 0, 15, 39, 52], target is array(1, dtype=int32)
when input is [23, 17, 10, 0, 15, 39, 52, 1], target is array(52, dtype=int32)
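
To make the batch layout concrete without MLX, here is a plain-Python sketch of the same sampling scheme on a toy sequence. The list `data` and the standard `random` module are stand-ins chosen for illustration; they are not what the code above uses.

```python
import random

block_size = 8   # context length, as in the text
batch_size = 4   # sequences per batch, as in the text
data = list(range(100))  # toy stand-in for the encoded text

random.seed(1337)
# Valid start offsets leave room for block_size inputs plus one shifted target.
starts = [random.randrange(0, len(data) - block_size) for _ in range(batch_size)]
x = [data[i:i + block_size] for i in starts]
y = [data[i + 1:i + block_size + 1] for i in starts]

# Every prefix of every row is one (context, target) training example.
examples = [
    (x[b][:t + 1], y[b][t])
    for b in range(batch_size)
    for t in range(block_size)
]
print(len(examples))  # 32
```

Each batch therefore carries `batch_size * block_size` training examples, and every target is simply the element that follows its context in the original sequence.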

Let’s next implement the `BigramLanguageModel`

import mlx.nn as nn

class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def __call__(self, idx, targets):
        # idx and targets are both (B, T) arrays of integers
        logits = self.token_embedding_table(idx) # (B, T, C)
        # Per-token loss of shape (B, T); cross_entropy defaults to reduction="none"
        loss = nn.losses.cross_entropy(logits, targets)
        return logits, loss

m = BigramLanguageModel(vocab_size)
logits, loss = m(x, y)
print(loss)
print(logits.shape)
array([[4.24595, 4.1331, 4.40312, ..., 4.31248, 4.05547, 4.20361],
       [4.10041, 4.24595, 4.21379, ..., 4.32175, 4.28335, 4.21379],
       [4.11496, 4.26998, 4.01405, ..., 4.19023, 4.3109, 4.1331],
       [4.0343, 4.34374, 4.10972, ..., 4.19232, 4.04217, 4.41227]], dtype=float32)
(4, 8, 65)
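
The per-token losses printed above sit roughly between 4.0 and 4.4. A useful reference point is the loss a model would get by assigning equal probability to each of the 65 characters, which is ln(65). The sketch below only does that arithmetic and does not touch the model.

```python
import math

vocab_size = 65
# Cross-entropy of a uniform distribution over vocab_size classes.
uniform_loss = -math.log(1.0 / vocab_size)
print(round(uniform_loss, 3))  # 4.174
```

Values close to this number are what one would expect from a model that has not learned anything about the data yet.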