I’ve spent most of the last couple of years worrying about the GEMM function because it’s the heart of deep learning calculations. The trouble is, I’m not very good at matrix math! I struggled through the courses I took in high school and college, barely getting a passing grade, confident that I’d never need anything so esoteric ever again. Right out of college I started working on 3D graphics engines where matrices were everywhere, and they’ve been an essential tool in my work ever since.
I managed to develop decent intuitions for 3D transformations and their 4×4 matrix representations, but not having a solid grounding in the theory left me very prone to mistakes when I moved on to more general calculations. I screwed up the first version of all my diagrams in a previous blog post, and most recently had to make a breaking API change to the open-source gemmlowp library, all because I’d messed up the ordering of the matrix multiplies.
The best way I know to fix something in my own mind is to try to explain it to somebody else, so here are my notes on the areas I found most confusing about matrix math as an engineer. I hope they’ll be helpful, and I look forward to getting corrections on anything I screwed up!
Row versus Column Major
The root of a lot of my difficulties is the two competing ways you can store matrix values in RAM. Coming from an image-processing world, when I see a 2D array of values my ingrained assumption is that it’s stored like letters on a page: starting in the top-left corner, moving from left to right, and jumping down to the start of the next row at the end of each row. For example, if you have a matrix that you’d draw like this:
| 0 | 1 | 2 |
| 3 | 4 | 5 |
| 6 | 7 | 8 |
You would store it in memory in this order:
| 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
This is known as ‘row major’, because each row is stored in adjacent locations in memory.
The alternative storage scheme is ‘column major’, and the same matrix would be laid out like this in memory:
| 0 | 3 | 6 | 1 | 4 | 7 | 2 | 5 | 8 |
What this means is that we’re walking down the columns to get adjacent memory locations instead.
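To make the two layouts concrete, here’s a small C++ sketch that computes where element (row, col) lands in each scheme and flattens the 3×3 example above both ways. The helper names are my own, hypothetical ones, not from any library.

```cpp
#include <cstddef>
#include <vector>

// Offset of element (row, col) in a matrix stored row major:
// each row occupies `cols` adjacent slots in memory.
std::size_t RowMajorIndex(std::size_t row, std::size_t col, std::size_t cols) {
  return row * cols + col;
}

// Offset of element (row, col) in a matrix stored column major:
// each column occupies `rows` adjacent slots in memory.
std::size_t ColMajorIndex(std::size_t row, std::size_t col, std::size_t rows) {
  return col * rows + row;
}

// Writes a logical rows x cols matrix, described by value_at(row, col),
// into a flat buffer using the requested storage order.
template <typename F>
std::vector<int> Flatten(std::size_t rows, std::size_t cols, bool row_major,
                         F value_at) {
  std::vector<int> buffer(rows * cols);
  for (std::size_t r = 0; r < rows; ++r) {
    for (std::size_t c = 0; c < cols; ++c) {
      const std::size_t offset =
          row_major ? RowMajorIndex(r, c, cols) : ColMajorIndex(r, c, rows);
      buffer[offset] = value_at(r, c);
    }
  }
  return buffer;
}
```

Flattening the matrix whose element (r, c) is `r * 3 + c` gives `0 1 2 3 4 5 6 7 8` in row-major order and `0 3 6 1 4 7 2 5 8` in column-major order, matching the two memory diagrams.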
It’s important to know that both of these are just ways of storing the same matrix in memory. They’re an implementation detail that should not affect the underlying math or the way you diagram your equations. There’s also no widespread agreement on which storage order is the ‘correct’ or canonical one, so you’ll need to pay attention to the convention the code you’re interacting with expects.
One thing you may notice about row versus column ordering is that if you screw it up and pass a square row-major matrix into a library that expects column major, or vice versa, it will be read as the transpose of the actual matrix. Visually, you transpose a matrix by turning each of its rows into a column, like this:
| 0 | 1 | 2 | 3 |
| 4 | 5 | 6 | 7 |
| 8 | 9 |10 |11 |
becomes:
| 0 | 4 | 8 |
| 1 | 5 | 9 |
| 2 | 6 |10 |
| 3 | 7 |11 |
Another way of visualizing this is drawing a line through the top-left to bottom-right corners, and flipping everything around on that diagonal. You’ll often see the transpose of a matrix indicated by adding an apostrophe to the name, so that the transpose of a matrix A is A’.
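Here’s a short sketch of both ideas: an explicit transpose of a row-major matrix, and what happens when a square row-major buffer is read with column-major indexing. The function names are hypothetical, chosen for illustration.

```cpp
#include <cstddef>
#include <vector>

// Returns the transpose of a rows x cols matrix stored row major.
// The result is a cols x rows matrix, also stored row major.
std::vector<int> TransposeRowMajor(const std::vector<int>& m,
                                   std::size_t rows, std::size_t cols) {
  std::vector<int> t(rows * cols);
  for (std::size_t r = 0; r < rows; ++r)
    for (std::size_t c = 0; c < cols; ++c)
      t[c * rows + r] = m[r * cols + c];  // element (r, c) moves to (c, r)
  return t;
}

// Reads element (row, col) of an n x n buffer assuming column-major order.
// If the buffer was actually written row major, this returns element
// (col, row) of the intended matrix, i.e. you see its transpose.
int ReadColMajor(const std::vector<int>& buf, std::size_t n,
                 std::size_t row, std::size_t col) {
  return buf[col * n + row];
}
```

Transposing the 3×4 matrix above with `TransposeRowMajor` yields `0 4 8 1 5 9 2 6 10 3 7 11`, which is the 4×3 result read row by row.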
If you multiply two numbers, A * B is always the same as B * A. This is not true for matrix multiplications! Indeed, you can only multiply two matrices together at all if the number of columns in the left-hand argument is equal to the number of rows in the right-hand argument. Even if they’re both the same square size, and so can potentially be swapped, the result will generally depend on the order.
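A tiny 2×2 case is enough to see the order dependence. In this sketch (`Mat2` and `Multiply` are illustrative names of my own), B is a permutation matrix: multiplying by it on the right swaps A’s columns, while multiplying on the left swaps A’s rows, so A * B and B * A differ.

```cpp
#include <array>

using Mat2 = std::array<std::array<int, 2>, 2>;

// Plain 2x2 matrix product: out[i][j] = sum over k of a[i][k] * b[k][j].
Mat2 Multiply(const Mat2& a, const Mat2& b) {
  Mat2 out{};  // value-initialized, so every entry starts at zero
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j)
      for (int k = 0; k < 2; ++k)
        out[i][j] += a[i][k] * b[k][j];
  return out;
}
```

With A = |1 2| |3 4| and B = |0 1| |1 0|, A * B is |2 1| |4 3| (columns swapped) while B * A is |3 4| |1 2| (rows swapped).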
There is one mathematical identity that crops up a lot in practice with transposes. If you have the standard GEMM equation of C = A * B, then C’ = B’ * A’. In words, if you swap the order of the two input matrices and transpose both of them, then multiplying them will give the transpose of the result you’d get in the original untransposed order.
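If you want to check the identity rather than take it on trust, it falls out of the definition of matrix multiplication one element at a time. Writing A’ as $A^\top$, and assuming A is m×k and B is k×n so that C is m×n:

```latex
(C^\top)_{ij} = C_{ji} = \sum_{p=1}^{k} A_{jp} B_{pi} = \sum_{p=1}^{k} (B^\top)_{ip} (A^\top)_{pj} = (B^\top A^\top)_{ij}
```

The shapes line up too: $B^\top$ is n×k and $A^\top$ is k×m, so their product is n×m, the shape of $C^\top$.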
#Errors % 2 == 0
What really led me into danger is that all three of storage order, transposing, and input order have effects that can mimic and cancel each other out. It’s like Jim Blinn’s old quote about all correct graphics programs having an even number of sign errors, except that there are three different ways to screw things up instead of one!
For example, what I realized last week was that I was working in a new code base that assumes row-major order, but gemmlowp assumes column major. I had been in a hurry and couldn’t figure out why my unit tests weren’t working, so I ended up swapping the input argument order. Since C’ = B’ * A’, the storage order error was canceled out by the argument order error! It made for very confusing code, so thankfully a co-worker slapped me round the back of the head (very politely) when he ran across it, and I revisited it and figured out my errors.
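To see exactly why the two mistakes cancel, here’s a sketch with a hypothetical column-major reference routine (not gemmlowp’s API). A row-major buffer read as column major is the transpose, so the routine sees A’ and B’; swapping the arguments (and the m and n dimensions with them) makes it compute B’ * A’ = C’, and C’ stored column major is byte-for-byte C stored row major. If you ever rely on this deliberately, it deserves a comment saying so.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical column-major GEMM: C (m x n) = A (m x k) * B (k x n), where
// element (r, c) of a matrix with R rows lives at offset c * R + r.
void GemmColMajor(std::size_t m, std::size_t n, std::size_t k,
                  const std::vector<int>& a, const std::vector<int>& b,
                  std::vector<int>& c) {
  for (std::size_t i = 0; i < m; ++i)
    for (std::size_t j = 0; j < n; ++j) {
      int sum = 0;
      for (std::size_t p = 0; p < k; ++p) sum += a[p * m + i] * b[j * k + p];
      c[j * m + i] = sum;
    }
}

// Multiplies row-major A (m x k) by row-major B (k x n) into row-major
// C (m x n) using only the column-major routine: the swapped arguments
// cancel out the mismatched storage order.
std::vector<int> GemmRowMajorViaColMajor(std::size_t m, std::size_t n,
                                         std::size_t k,
                                         const std::vector<int>& a,
                                         const std::vector<int>& b) {
  std::vector<int> c(m * n);
  GemmColMajor(n, m, k, b, a, c);  // computes B' * A' = C' (n x m, col major)
  return c;
}
```

For row-major A = `1 2 3 / 4 5 6` and B = `7 8 / 9 10 / 11 12`, this returns `58 64 139 154`, which is the correct 2×2 product in row-major order.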
Because I know I’m so prone to these kinds of errors, I’ve forced myself to slow down when I’m tackling this kind of problem, and start off by working through a couple of examples on pen and paper. I find working visually with the diagram at the top of this post in mind has helped me immensely. Once I’ve got those examples straight, I’ll turn them into unit tests, with the hand calculations in the comments. You can see an example in gemmlowp/test/test.cc:698.
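The pattern looks roughly like the following. This is a hypothetical illustration, not the gemmlowp test: a naive row-major GEMM stands in for the code under test, and the pen-and-paper working sits in the comments. Non-square shapes make it more likely that a transposed or swapped argument shows up as a wrong size or wrong values instead of hiding.

```cpp
#include <cstddef>
#include <vector>

// Naive row-major GEMM standing in for the code under test:
// C (m x n) = A (m x k) * B (k x n).
std::vector<int> NaiveGemmRowMajor(std::size_t m, std::size_t n, std::size_t k,
                                   const std::vector<int>& a,
                                   const std::vector<int>& b) {
  std::vector<int> c(m * n, 0);
  for (std::size_t i = 0; i < m; ++i)
    for (std::size_t j = 0; j < n; ++j)
      for (std::size_t p = 0; p < k; ++p)
        c[i * n + j] += a[i * k + p] * b[p * n + j];
  return c;
}

// A hand-worked example turned into a test.
bool HandWorkedGemmTestPasses() {
  // A is 2x3:        B is 3x1:
  // | 1 | 2 | 3 |    | 1 |
  // | 4 | 5 | 6 |    | 0 |
  //                  | 2 |
  const std::vector<int> a = {1, 2, 3, 4, 5, 6};
  const std::vector<int> b = {1, 0, 2};
  // C is 2x1:
  // C(0,0) = 1*1 + 2*0 + 3*2 = 7
  // C(1,0) = 4*1 + 5*0 + 6*2 = 16
  const std::vector<int> expected = {7, 16};
  return NaiveGemmRowMajor(2, 1, 3, a, b) == expected;
}
```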
My other key tool is keeping around a simple reference GEMM function that’s unoptimized, but that I can easily step through and add logging statements to. Since a lot of my work involves cutting corners with precision to increase speed, it’s important that I have an understandable implementation that I can play with, and compare against more complex versions. You can see my version for gemmlowp at test.cc:36.
That version includes some eight-bit-specific code, but its structure is common across all my reference versions: three nested loops across the m, n, and k dimensions of the matrices. It also omits some of the standard GEMM arguments, like alpha and beta. It assumes column-major storage by default, but if a matrix’s transpose flag is set to true, that matrix is treated as row major instead.
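As a rough sketch of that structure, here’s a float-only reference GEMM with the same conventions: three nested loops over m, n, and k, no alpha or beta, column major by default, and a per-matrix transpose flag that switches that input to row-major reads. The names and signature are my own assumptions, not gemmlowp’s.

```cpp
#include <cstddef>
#include <vector>

// Reads element (row, col) of a rows x cols matrix. The buffer is column
// major by default; when `transposed` is true it is read as row major.
float At(const std::vector<float>& m, std::size_t rows, std::size_t cols,
         bool transposed, std::size_t row, std::size_t col) {
  return transposed ? m[row * cols + col] : m[col * rows + row];
}

// Unoptimized reference GEMM: C (m x n) = A (m x k) * B (k x n).
// C is always written column major and must already hold m * n values.
void ReferenceGemm(bool transpose_a, bool transpose_b,
                   std::size_t m, std::size_t n, std::size_t k,
                   const std::vector<float>& a, const std::vector<float>& b,
                   std::vector<float>& c) {
  for (std::size_t i = 0; i < m; ++i) {      // rows of C
    for (std::size_t j = 0; j < n; ++j) {    // columns of C
      float sum = 0.0f;
      for (std::size_t p = 0; p < k; ++p) {  // shared inner dimension
        sum += At(a, m, k, transpose_a, i, p) * At(b, k, n, transpose_b, p, j);
      }
      c[j * m + i] = sum;  // an easy spot to add logging while debugging
    }
  }
}
```

Keeping all the indexing in one small `At` helper means there’s a single place to check when the storage order is in doubt.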
The final pieces of the puzzle for me were the lda, ldb, and ldc arguments. These confused me initially because I struggled to find a clear explanation of what they meant. I finally realized that they are the number of values you move forward in memory to get from the start of one row to the start of the next (in row-major order), or from the start of one column to the start of the next (in column-major order). They’re strides that give a lot of flexibility when you only want to work with smaller tiles inside a larger matrix, since they let you skip over values you want to ignore. If you’re not dealing with sub-tiles, they’ll be the number of columns for a row-major matrix, and the number of rows for a column-major one.
Anyway, I hope these notes help any other lost souls who are struggling with matrices. I’m now feeling a lot more confident, and I wish I’d taken the time to study them more carefully before. They’re very powerful tools, and actually a lot of fun once I moved past some of my confusion!