The code accompanying this post is available at pbt-batch-invariance.
Recently, a Thinking Machines blog post discussed why nondeterminism in large language models is a problem. The post argues that batch-invariance in matrix multiplication, RMSNorm, and attention is crucial for deterministic inference. In their repo, the `test_batch_invariance.py` file shows a simple test for batch-invariance of matrix multiplication, with a random draw of PyTorch tensors (basically, using `torch.randn`).
This testing seemed interesting enough, but I wanted to do something more rigorous: use property-based testing to test for batch-invariance. The Hypothesis library allows for more sophisticated testing: you define an input domain (which can be quite complex) and a property (or properties) that should hold. Hypothesis then generates random inputs from that domain and checks that the properties hold. Another good thing about Hypothesis is that, if it finds a counterexample, it will attempt to shrink it to a minimal example.
For example, one property is that a list sorted via a function `my_sort()` should be non-decreasing. In Hypothesis, we would write this as:
```python
from hypothesis import given, strategies as st

@given(st.lists(st.integers()))
def test_sorted_list(lst):
    sorted_lst = my_sort(lst)
    for i in range(len(sorted_lst) - 1):
        assert sorted_lst[i] <= sorted_lst[i + 1]
```
This test generates random lists of integers, sorts them via `my_sort()`, and checks that the sorted list is non-decreasing. We are going to try to do something similar for batch-invariance.
The property that the Thinking Machines repo tests is essentially the following: given two tensors \( a \) and \( b \), test that \( a[:1] @ b = (a @ b)[:1] \). The left-hand side is \( a[:1] \) matrix-multiplied by \( b \), while the right-hand side is the first row of \( a @ b \). The dimensions are the same, and because of how matrix multiplication works, the two are mathematically equivalent. However, as the Thinking Machines blog post argues, in practice they are not equivalent, because of how kernel operations are implemented. As currently written, the test constructs \( a \) and \( b \) of fixed size, filled with a fixed sequence of numbers from a linear space, and then checks this property by computing the difference between the two sides of the equality.
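A minimal sketch of that fixed-size check (with smaller, hypothetical dimensions for illustration; this is not the repo's exact code):

```python
import torch

# Hypothetical sizes; the actual test uses larger, fixed dimensions.
B, D = 32, 64

# Fill a and b with a fixed linear sequence of values.
a = torch.linspace(-100.0, 100.0, B * D).reshape(B, D)
b = torch.linspace(-100.0, 100.0, D * D).reshape(D, D)

out_sliced = a[:1] @ b   # multiply only the first row of a by b
out_full = (a @ b)[:1]   # take the first row of the full product

# On a batch-invariant kernel this difference would be exactly zero.
diff = (out_sliced - out_full).abs().max()
print(diff)
```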
It would be better to define a general property and let Hypothesis generate random tensors to test it against. First, instead of taking only the first row, a more general input strategy takes an arbitrary slice of rows \( m \) to \( n \) (exclusive), keeping the slice a strict sub-batch. Second, the sizes of the tensors can be random, and the elements of the tensors are random floats within a given range. So we write the following input strategy:
You can do something very similar for RMSNorm (root-mean-square normalization). Now the input strategy is:
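A sketch of the RMSNorm strategy, drawing an input matrix, a scale vector `gamma`, and a row slice (again, names and bounds are assumptions):

```python
import torch
from hypothesis import strategies as st

@st.composite
def rmsnorm_inputs(draw):
    """Draw a batch x (rows x dim), a scale vector gamma (dim,), and a row slice [m, n)."""
    rows = draw(st.integers(min_value=2, max_value=16))
    dim = draw(st.integers(min_value=1, max_value=16))
    elems = st.floats(min_value=-100.0, max_value=100.0,
                      allow_nan=False, allow_infinity=False, width=32)
    x = torch.tensor(
        draw(st.lists(elems, min_size=rows * dim, max_size=rows * dim))
    ).reshape(rows, dim)
    gamma = torch.tensor(draw(st.lists(elems, min_size=dim, max_size=dim)))
    # Slice boundaries: 0 <= m < n < rows.
    m = draw(st.integers(min_value=0, max_value=rows - 2))
    n = draw(st.integers(min_value=m + 1, max_value=rows - 1))
    return x, gamma, m, n
```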
Lastly, you can do the same thing with attention (scaled dot-product attention). Now the input strategy is:
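A sketch of the attention strategy: query, key, and value tensors of the same shape, plus a slice of query rows (in non-causal attention each query row's output is independent of the other query rows, so slicing queries is the batch dimension here). Names and bounds are again my own:

```python
import torch
from hypothesis import strategies as st

@st.composite
def attn_inputs(draw):
    """Draw q, k, v tensors (seq x head_dim) and a query-row slice [m, n)."""
    seq = draw(st.integers(min_value=2, max_value=8))
    head_dim = draw(st.integers(min_value=1, max_value=8))
    elems = st.floats(min_value=-10.0, max_value=10.0,
                      allow_nan=False, allow_infinity=False, width=32)

    def tensor(rows, cols):
        return torch.tensor(
            draw(st.lists(elems, min_size=rows * cols, max_size=rows * cols))
        ).reshape(rows, cols)

    q, k, v = tensor(seq, head_dim), tensor(seq, head_dim), tensor(seq, head_dim)
    # Slice boundaries: 0 <= m < n < seq.
    m = draw(st.integers(min_value=0, max_value=seq - 2))
    n = draw(st.integers(min_value=m + 1, max_value=seq - 1))
    return q, k, v, m, n
```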
The built-in implementations use reduction across the batch dimension. For matrix multiplication, we use the `@` operator. For RMSNorm, we use the definition: `x * torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True)) * gamma`. For attention, we use the `torch.nn.functional.scaled_dot_product_attention` function.
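Concretely, the batched implementations might look like the following sketch (the function and parameter names are mine; the attention version adds a leading batch dimension of 1 for the fused kernel and then drops it):

```python
import torch
import torch.nn.functional as F

def matmul_batched(a, b):
    # Plain batched matrix multiply via the @ operator.
    return a @ b

def rmsnorm_batched(x, gamma):
    # RMSNorm: scale each row by the reciprocal root-mean-square of its elements.
    return x * torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True)) * gamma

def attn_batched(q, k, v):
    # Built-in fused scaled dot-product attention; unsqueeze to (1, seq, dim)
    # for the kernel, then squeeze the batch dimension back out.
    return F.scaled_dot_product_attention(q[None], k[None], v[None])[0]
```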
The rowwise implementations break up the computation across rows (batches) in order to enforce batch-invariance. They are essentially the same as the batched implementations, but compute each row separately and then stack the results. This is obviously much slower, and is not the same as the batch-invariant kernels implemented in the Thinking Machines post.
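The rowwise variants can be sketched like this (again, the names are mine, and this is a straightforward loop-and-stack version rather than an optimized kernel):

```python
import torch
import torch.nn.functional as F

def matmul_rowwise(a, b):
    # Multiply one row of a at a time, then stack, so each row's
    # reduction is independent of the batch size.
    return torch.stack([a[i] @ b for i in range(a.shape[0])])

def rmsnorm_rowwise(x, gamma):
    # Normalize each row on its own, then stack.
    return torch.stack([
        x[i] * torch.rsqrt(torch.mean(x[i] ** 2, dim=-1, keepdim=True)) * gamma
        for i in range(x.shape[0])
    ])

def attn_rowwise(q, k, v):
    # Attend with one query row at a time; keys and values stay whole.
    rows = [
        F.scaled_dot_product_attention(q[i:i + 1][None], k[None], v[None])[0, 0]
        for i in range(q.shape[0])
    ]
    return torch.stack(rows)
```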
The full outputs of the test runs are in the `test_outputs` folder, including the counterexamples that Hypothesis found.
On a CPU, the results are:
test_batch_invariance.py::test_matmul[matmul_batched] FAILED
test_batch_invariance.py::test_matmul[matmul_rowwise] PASSED
test_batch_invariance.py::test_rmsnorm[rmsnorm_batched] PASSED
test_batch_invariance.py::test_rmsnorm[rmsnorm_rowwise] PASSED
test_batch_invariance.py::test_attn[attn_batched] PASSED
test_batch_invariance.py::test_attn[attn_rowwise] PASSED
Here, on the CPU, only `matmul_batched` failed; the other `*_batched` versions passed. This is likely due to how the CPU kernels are implemented.
The output of the tests on a GPU is:
test_batch_invariance.py::test_matmul[matmul_batched] FAILED
test_batch_invariance.py::test_matmul[matmul_rowwise] PASSED
test_batch_invariance.py::test_rmsnorm[rmsnorm_batched] FAILED
test_batch_invariance.py::test_rmsnorm[rmsnorm_rowwise] PASSED
test_batch_invariance.py::test_attn[attn_batched] FAILED
test_batch_invariance.py::test_attn[attn_rowwise] PASSED
Here, on the GPU, all `*_batched` versions failed and all `*_rowwise` versions passed. As the original blog post argues, the way GPU kernels handle batching causes nondeterminism.