I’ve spent most of the last couple of years worrying about the GEMM function because it’s the heart of deep learning calculations. The trouble is, I’m not very good at matrix math! I struggled through the courses I took in high school and college, barely getting a passing grade, confident that I’d never need anything so esoteric ever again. Right out of college I started working on 3D graphics engines where matrices were everywhere, and they’ve been an essential tool in my work ever since.
I managed to develop decent intuitions for 3D transformations and their 4×4 matrix representations, but not having a solid grounding in the theory left me very prone to mistakes when I moved on to more general calculations. I screwed up the first version of all my diagrams in a previous blog post, and most recently had to make a breaking API change to the open-source gemmlowp library, all because I’d messed up the ordering of the matrix multiplies.
The best way I know to fix something in my own mind is to try to explain it to somebody else, so here are my notes on the areas I found most confusing about matrix math as an engineer. I hope they’ll be helpful, and I look forward to getting corrections on anything I screwed up!
Row versus Column Major
The root of a lot of my difficulties is that there are two competing ways to store matrix values in RAM. Coming from an image-processing world, when I see a 2D array of values my ingrained assumption is that it’s stored like letters on a page: starting in the top-left corner, moving from left to right, and jumping down at the end of each row. For example, if you have a matrix that you’d draw like this:
 0  1  2 
 3  4  5 
 6  7  8 
You would store it in memory in this order:
 0  1  2  3  4  5  6  7  8 
This is known as ‘row major’, because each row is stored in adjacent locations in memory.
The alternative storage scheme is ‘column major’, and the same matrix would be laid out like this in memory:
 0  3  6  1  4  7  2  5  8 
What this means is that we’re walking down the columns to get adjacent memory locations instead.
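To make the two layouts concrete, here is a small C++ sketch that computes the memory offset of element (row, col) under each convention and lays out the 0 to 8 matrix above both ways. The function names are illustrative only, not taken from any library.

```cpp
#include <cassert>
#include <vector>

// Offset of element (row, col) in a row-major rows x cols matrix:
// each row occupies `cols` adjacent values.
int RowMajorIndex(int row, int col, int rows, int cols) {
  assert(row >= 0 && row < rows && col >= 0 && col < cols);
  return row * cols + col;
}

// Offset of element (row, col) in a column-major rows x cols matrix:
// each column occupies `rows` adjacent values.
int ColumnMajorIndex(int row, int col, int rows, int cols) {
  assert(row >= 0 && row < rows && col >= 0 && col < cols);
  return col * rows + row;
}

// Stores the matrix whose element (row, col) has the value row * cols + col
// (the 0..8 example above when rows == cols == 3) in the chosen order.
std::vector<int> LayOut(int rows, int cols, bool column_major) {
  std::vector<int> memory(rows * cols);
  for (int row = 0; row < rows; ++row) {
    for (int col = 0; col < cols; ++col) {
      const int value = row * cols + col;
      const int offset = column_major ? ColumnMajorIndex(row, col, rows, cols)
                                      : RowMajorIndex(row, col, rows, cols);
      memory[offset] = value;
    }
  }
  return memory;
}
```

`LayOut(3, 3, false)` yields 0 1 2 3 4 5 6 7 8 and `LayOut(3, 3, true)` yields 0 3 6 1 4 7 2 5 8, matching the two orders listed above. The math is identical in both cases; only the offset formula changes.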
It’s important to know that both of these are just ways of storing the same matrix in memory; they’re an implementation detail that should not affect the underlying math, or the way you diagram your equations. There’s also no widespread agreement on which storage order is the ‘correct’ or canonical one (C and C++ 2D arrays are row-major, for example, while Fortran and the reference BLAS routines are column-major), so you’ll need to pay attention to the convention the code you’re interacting with expects.
Transpose
One thing you may notice about row versus column ordering is that if you screw it up and pass a square row-major matrix into a library that expects column-major, or vice versa, it will be read as the transpose of the actual matrix. Visually, you transpose a matrix by flipping all the columns into rows, like this:
 0  1  2  3 
 4  5  6  7 
 8  9 10 11 
Original
 0  4  8 
 1  5  9 
 2  6 10 
 3  7 11 
Transpose
Another way of visualizing this is drawing a line through the top-left to bottom-right corners, and flipping everything around on that diagonal. You’ll often see the transpose of a matrix indicated by adding an apostrophe to the name, so that the transpose of a matrix A is A’.
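Here’s a short C++ sketch of both ideas, assuming the matrix you start with is stored row-major (the helper names are hypothetical). `TransposeRowMajor` builds the transpose explicitly, while `ReadColumnMajor` shows that simply reinterpreting the same row-major buffer as column-major, with the row and column counts swapped, already gives you the transpose without moving any data.

```cpp
#include <cassert>
#include <vector>

// Builds the transpose of a row-major rows x cols matrix. The result is a
// row-major cols x rows matrix in which element (r, c) has moved to (c, r).
std::vector<int> TransposeRowMajor(const std::vector<int>& m, int rows, int cols) {
  assert(static_cast<int>(m.size()) == rows * cols);
  std::vector<int> t(m.size());
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) {
      t[c * rows + r] = m[r * cols + c];
    }
  }
  return t;
}

// Returns element (r, c) of a buffer interpreted as a column-major
// rows x cols matrix, i.e. what a column-major library "sees".
int ReadColumnMajor(const std::vector<int>& m, int rows, int cols, int r, int c) {
  assert(r >= 0 && r < rows && c >= 0 && c < cols);
  return m[c * rows + r];
}
```

For the 3×4 matrix above stored row-major as 0 to 11, `TransposeRowMajor(original, 3, 4)` produces 0 4 8 1 5 9 2 6 10 3 7 11, which is the transposed diagram read row by row, and `ReadColumnMajor(original, 4, 3, r, c)` returns the same value as element (r, c) of that transpose.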
Argument Ordering
If you multiply two numbers, A * B is always the same as B * A. This is not true for matrix multiplications! Indeed, you can only multiply two matrices together at all if the number of columns in the left-hand argument is equal to the number of rows in the right-hand argument. Even if they’re both the same square size, and so can potentially be swapped, the result will generally depend on the order.
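A small worked example, with matrices chosen here purely for illustration, shows the order dependence even for two 2×2 matrices. Multiplying by B on the right swaps A’s columns, while multiplying by B on the left swaps A’s rows:

```latex
A = \begin{pmatrix} 1 & 2 \\ 3 & 4 \end{pmatrix}, \qquad
B = \begin{pmatrix} 0 & 1 \\ 1 & 0 \end{pmatrix}

AB = \begin{pmatrix} 2 & 1 \\ 4 & 3 \end{pmatrix}
\neq
BA = \begin{pmatrix} 3 & 4 \\ 1 & 2 \end{pmatrix}
```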
There is one mathematical identity that crops up a lot in practice with transposes. If you have the standard GEMM equation of C = A * B, then C’ = B’ * A’. In words, if you swap the order of the two input matrices and transpose both of them, then multiplying them will give the transpose of the result you’d get in the original untransposed order.
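The identity is easy to check numerically. This sketch uses a minimal, hypothetical row-major `Matrix` type with naive `Multiply` and `Transpose` helpers. It’s meant for experimenting by hand, not as an efficient implementation.

```cpp
#include <cassert>
#include <vector>

// Minimal row-major matrix, defined here only for illustration.
struct Matrix {
  int rows;
  int cols;
  std::vector<int> values;  // element (r, c) lives at r * cols + c
  int at(int r, int c) const { return values[r * cols + c]; }
};

// Naive C = A * B. Requires A.cols == B.rows; the result is A.rows x B.cols.
Matrix Multiply(const Matrix& a, const Matrix& b) {
  assert(a.cols == b.rows);
  Matrix c{a.rows, b.cols, std::vector<int>(a.rows * b.cols, 0)};
  for (int i = 0; i < a.rows; ++i)
    for (int j = 0; j < b.cols; ++j)
      for (int l = 0; l < a.cols; ++l)
        c.values[i * c.cols + j] += a.at(i, l) * b.at(l, j);
  return c;
}

// Returns M', a cols x rows matrix with element (r, c) moved to (c, r).
Matrix Transpose(const Matrix& m) {
  Matrix t{m.cols, m.rows, std::vector<int>(m.values.size())};
  for (int r = 0; r < m.rows; ++r)
    for (int c = 0; c < m.cols; ++c)
      t.values[c * t.cols + r] = m.at(r, c);
  return t;
}

bool Equal(const Matrix& x, const Matrix& y) {
  return x.rows == y.rows && x.cols == y.cols && x.values == y.values;
}
```

With A = [1 2 3; 4 5 6] (2×3) and B = [7 8; 9 10; 11 12] (3×2), `Multiply(a, b)` gives 58 64 139 154, its transpose equals `Multiply(Transpose(b), Transpose(a))`, and `Multiply(b, a)` is a different 3×3 matrix altogether.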
#Errors % 2 == 0
What really led me into danger is that all three of storage order, transposing, and input order have effects that can mimic and cancel each other out. It’s like Jim Blinn’s old quote about all correct graphics programs having an even number of sign errors, except that there are three different ways to screw things up instead of one!
For example, what I realized last week was that I was working in a new code base that assumes row-major order, but gemmlowp assumes column-major. I had been in a hurry and couldn’t figure out why my unit tests weren’t working, so I ended up swapping the input argument order. Since C’ = B’ * A’, the storage-order error was canceled out by the argument-order error! It made for very confusing code, so thankfully a coworker slapped me round the back of the head (very politely) when he ran across it, and I revisited it and figured out my errors.
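Here’s a sketch of exactly that cancellation. `ColumnMajorGemm` is a hypothetical naive routine standing in for a column-major library like gemmlowp; it is not gemmlowp’s API. Handing it row-major buffers means it sees A’ and B’, so swapping the arguments makes it compute B’ * A’ = (A * B)’, and a column-major (A * B)’ occupies memory in exactly the same order as a row-major A * B.

```cpp
#include <cassert>
#include <vector>

// Hypothetical naive column-major GEMM: C (m x n) = A (m x k) * B (k x n),
// where element (i, j) of a matrix with leading dimension ld is at i + j * ld.
void ColumnMajorGemm(int m, int n, int k, const int* a, int lda,
                     const int* b, int ldb, int* c, int ldc) {
  assert(lda >= m && ldb >= k && ldc >= m);
  for (int j = 0; j < n; ++j) {
    for (int i = 0; i < m; ++i) {
      int total = 0;
      for (int l = 0; l < k; ++l) total += a[i + l * lda] * b[l + j * ldb];
      c[i + j * ldc] = total;
    }
  }
}

// Computes C = A * B for row-major A (m x k) and B (k x n) using the
// column-major routine. Read column-major, the buffers hold B' (n x k, ld n)
// and A' (k x m, ld k); passing them in swapped order yields (A * B)' in
// column-major order, which is byte-for-byte A * B in row-major order.
std::vector<int> RowMajorGemmViaColumnMajor(int m, int n, int k,
                                            const std::vector<int>& a,
                                            const std::vector<int>& b) {
  assert(static_cast<int>(a.size()) == m * k);
  assert(static_cast<int>(b.size()) == k * n);
  std::vector<int> c(m * n);
  ColumnMajorGemm(n, m, k, b.data(), n, a.data(), k, c.data(), n);
  return c;
}
```

The answer comes out right, which is what makes this so easy to miss: nothing fails, but the call site shows the arguments in the “wrong” order, and only the storage-order mismatch justifies it.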
Because I know I’m so prone to these kinds of errors, I’ve forced myself to slow down when I’m tackling this kind of problem, and start off by working through a couple of examples on pen and paper. I find working visually with the diagram at the top of this post in mind has helped me immensely. Once I’ve got those examples straight, I’ll turn them into unit tests, with the hand calculations in the comments. You can see an example in gemmlowp/test/test.cc:698:

// Runs a small set of hand-calculated data through the implementation.
void TestWithSmallData() {
  const int m = 4;
  const int n = 2;
  const int k = 3;
  // Matrix A (LHS) is:
  //  7 10 13 16
  //  8 11 14 17
  //  9 12 15 18
  const uint8_t a_data[] = {7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18};
  // Matrix B (RHS) is:
  //  1  3  5
  //  2  4  6
  const uint8_t b_data[] = {1, 2, 3, 4, 5, 6};
  // Here are the results we expect, from hand calculations:
  // (1 * 7) + (3 * 8) + (5 * 9) = 76
  // (2 * 7) + (4 * 8) + (6 * 9) = 100
  // (1 * 10) + (3 * 11) + (5 * 12) = 103
  // (2 * 10) + (4 * 11) + (6 * 12) = 136
  // (1 * 13) + (3 * 14) + (5 * 15) = 130
  // (2 * 13) + (4 * 14) + (6 * 15) = 172
  // (1 * 16) + (3 * 17) + (5 * 18) = 157
  // (2 * 16) + (4 * 17) + (6 * 18) = 208
  // That means matrix C should be:
  //  76 103 130 157
  // 100 136 172 208
  const uint8_t expected_data[] = {76, 100, 103, 136, 130, 172, 157, 208};

  const int c_count = m * n;
  std::unique_ptr<uint8_t[]> output_data(new uint8_t[c_count]);

  const bool is_a_transposed = true;
  const bool is_b_transposed = true;
  const bool is_c_transposed = true;
  const int lda = k;
  const int ldb = n;
  const int ldc = n;

  const int a_offset = 0;
  const int b_offset = 0;
  const int c_offset = 0;
  const int c_mult = 1;
  const int c_shift = 0;

  gemmlowp::eight_bit_int_gemm::EightBitIntGemm(
      is_a_transposed, is_b_transposed, is_c_transposed, m, n, k, a_data,
      a_offset, lda, b_data, b_offset, ldb, output_data.get(), c_offset,
      c_mult, c_shift, ldc, eight_bit_int_gemm::BitDepthSetting::A8B8);

  ResultStats stats;
  GetResultStats(output_data.get(), expected_data, c_count, &stats);

  ResultStatsBounds bounds;
  const bool good = CheckResultStatsBounds(stats, bounds);
  printf("TestWithSmallData: %s\n", good ? "PASS" : "FAIL");
  ReportResultStats(stats, bounds);
  Check(good);
}
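The bytes in this test can be read two ways, which is the C’ = B’ * A’ identity in action. Read column-major, they are the matrices drawn in the comments, and the hand calculations compute C = B * A (a 2×3 matrix times a 3×4 one). Read row-major, which is what setting the transpose flags asks for, they are A’ (4×3) and B’ (3×2), and A’ * B’ = (B * A)’, whose row-major bytes are the same eight values. This sketch checks that both readings produce `expected_data`, using a hypothetical `StridedGemm` helper with explicit strides rather than gemmlowp itself.

```cpp
#include <cassert>
#include <vector>

// Hypothetical naive GEMM with explicit strides, written only for this check.
// Computes C (m x n) = LHS (m x k) * RHS (k x n), where element (row, col) of
// each matrix lives at row * row_stride + col * col_stride.
std::vector<int> StridedGemm(int m, int n, int k,
                             const std::vector<int>& lhs, int lhs_row_stride,
                             int lhs_col_stride, const std::vector<int>& rhs,
                             int rhs_row_stride, int rhs_col_stride,
                             int c_row_stride, int c_col_stride) {
  assert(m > 0 && n > 0 && k > 0);
  std::vector<int> c(m * n, 0);
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      int total = 0;
      for (int l = 0; l < k; ++l) {
        total += lhs[i * lhs_row_stride + l * lhs_col_stride] *
                 rhs[l * rhs_row_stride + j * rhs_col_stride];
      }
      c[i * c_row_stride + j * c_col_stride] = total;
    }
  }
  return c;
}

// Returns true if both readings of the test's input bytes give expected_data.
bool BothReadingsMatch() {
  const std::vector<int> a_data = {7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18};
  const std::vector<int> b_data = {1, 2, 3, 4, 5, 6};
  const std::vector<int> expected = {76, 100, 103, 136, 130, 172, 157, 208};
  // Column-major, as drawn in the comments: C (2x4) = B (2x3) * A (3x4).
  const bool as_drawn =
      StridedGemm(2, 4, 3, b_data, 1, 2, a_data, 1, 3, 1, 2) == expected;
  // Row-major, as the transpose flags request: A' (4x3) * B' (3x2) = C' (4x2).
  const bool as_row_major =
      StridedGemm(4, 2, 3, a_data, 3, 1, b_data, 2, 1, 2, 1) == expected;
  return as_drawn && as_row_major;
}
```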
My other key tool is keeping around a simple reference GEMM function that’s unoptimized, but that I can easily step through and add logging statements to. Since a lot of my work involves cutting corners with precision to increase speed, it’s important that I have an understandable implementation that I can play with, and compare against more complex versions. You can see my version for gemmlowp at test.cc:36:

#include <cassert>
#include <cstdint>

void ReferenceEightBitIntGemm(bool transpose_a, bool transpose_b,
                              bool transpose_c, int m, int n, int k,
                              const uint8_t* a, int32_t a_offset, int lda,
                              const uint8_t* b, int32_t b_offset, int ldb,
                              uint8_t* c, int32_t c_offset, int32_t c_mult_int,
                              int32_t c_shift, int ldc) {
  assert((c_shift >= 0) && (c_shift <= 32));

  assert(a != nullptr);
  assert(b != nullptr);
  assert(c != nullptr);

  // Each matrix is column-major unless its transpose flag is set, in which
  // case it is read as row-major.
  int a_i_stride;
  int a_l_stride;
  if (transpose_a) {
    a_i_stride = lda;
    a_l_stride = 1;
  } else {
    a_i_stride = 1;
    a_l_stride = lda;
  }
  int b_j_stride;
  int b_l_stride;
  if (transpose_b) {
    b_j_stride = 1;
    b_l_stride = ldb;
  } else {
    b_j_stride = ldb;
    b_l_stride = 1;
  }
  int c_i_stride;
  int c_j_stride;
  if (transpose_c) {
    c_i_stride = ldc;
    c_j_stride = 1;
  } else {
    c_i_stride = 1;
    c_j_stride = ldc;
  }
  int i, j, l;

  // Adds half of the divisor before the right shift, so the result rounds.
  const std::int32_t kRoundingTerm = (c_shift < 1) ? 0 : (1 << (c_shift - 1));

  for (j = 0; j < n; j++) {
    for (i = 0; i < m; i++) {
      int32_t total = 0;
      for (l = 0; l < k; l++) {
        const int a_index = i * a_i_stride + l * a_l_stride;
        const uint8_t a_as_byte = a[a_index];
        const int32_t a_as_int = static_cast<int32_t>(a_as_byte) + a_offset;
        const int b_index = j * b_j_stride + l * b_l_stride;
        const uint8_t b_as_byte = b[b_index];
        const int32_t b_as_int = static_cast<int32_t>(b_as_byte) + b_offset;
        const int32_t mult_as_int = a_as_int * b_as_int;
        total += mult_as_int;
      }
      int32_t output =
          (((total + c_offset) * c_mult_int) + kRoundingTerm) >> c_shift;
      // Clamp to the uint8_t output range.
      if (output > 255) {
        output = 255;
      }
      if (output < 0) {
        output = 0;
      }
      const int c_index = i * c_i_stride + j * c_j_stride;
      c[c_index] = static_cast<uint8_t>(output);
    }
  }
}
This includes some eight-bit-specific code, but the structure is common across all my reference versions, with three nested loops across the m, n, and k dimensions of the matrices. It also leaves out the alpha and beta arguments that standard BLAS-style GEMM interfaces take, which scale the product and the existing contents of C (C = alpha * A * B + beta * C). This particular code assumes column-major order by default, but if any of the transpose flags is set to true, the corresponding matrix is treated as row-major.
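For comparison, here’s a minimal float sketch with the same loop structure and the same stride choices as the reference above, plus alpha and beta, so that it computes C = alpha * A * B + beta * C. The name `ReferenceFloatGemm` and its argument order are made up for this example; they are not gemmlowp’s or BLAS’s actual signature.

```cpp
#include <cassert>

// Hypothetical float reference GEMM: C = alpha * A * B + beta * C, with A
// m x k, B k x n and C m x n. Each matrix is column-major unless its
// transpose flag is set, in which case it is read as row-major.
void ReferenceFloatGemm(bool transpose_a, bool transpose_b, bool transpose_c,
                        int m, int n, int k, float alpha,
                        const float* a, int lda, const float* b, int ldb,
                        float beta, float* c, int ldc) {
  assert(a != nullptr && b != nullptr && c != nullptr);
  const int a_i_stride = transpose_a ? lda : 1;
  const int a_l_stride = transpose_a ? 1 : lda;
  const int b_j_stride = transpose_b ? 1 : ldb;
  const int b_l_stride = transpose_b ? ldb : 1;
  const int c_i_stride = transpose_c ? ldc : 1;
  const int c_j_stride = transpose_c ? 1 : ldc;
  for (int j = 0; j < n; ++j) {
    for (int i = 0; i < m; ++i) {
      float total = 0.0f;
      for (int l = 0; l < k; ++l) {
        total += a[i * a_i_stride + l * a_l_stride] *
                 b[j * b_j_stride + l * b_l_stride];
      }
      float& out = c[i * c_i_stride + j * c_j_stride];
      // As in reference BLAS, beta == 0 means C's old contents are ignored,
      // so they may be uninitialized.
      out = (beta == 0.0f) ? alpha * total : alpha * total + beta * out;
    }
  }
}
```

With alpha = 2, beta = 1, B set to the identity, and every element of C starting at 10, the result is 2 * A + 10 in each position, which makes it easy to see both scaling terms at work.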
Leading Dimensions
The final pieces of the puzzle for me were the lda, ldb, and ldc arguments. These left me confused initially because I struggled to find a clear explanation of what they meant. I finally realized that they are the number of values you move forward in memory when you reach the end of a row (in row-major order) or a column (in column-major order). They’re strides that give a lot of flexibility when you only want to work with smaller tiles inside a larger matrix, since they let you skip over values you want to ignore. If you’re not dealing with sub-tiles, then they’ll be the number of columns for a row-major matrix, and the number of rows for a column-major one.
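A minimal sketch of a leading dimension at work, assuming a row-major parent matrix (`CopyTile` is a hypothetical helper): the tile starts at a pointer offset into the parent, and the parent’s full row length is passed as the leading dimension, so each step down a row of the tile skips the values outside it.

```cpp
#include <cassert>
#include <vector>

// Copies a rows x cols tile out of a larger row-major matrix. `src` points at
// the tile's top-left element and `src_ld` is the parent's leading dimension
// (its full row length), so moving down one tile row jumps src_ld values.
// The copy is packed, so its own leading dimension is just `cols`.
std::vector<int> CopyTile(const int* src, int rows, int cols, int src_ld) {
  assert(src_ld >= cols);
  std::vector<int> packed(rows * cols);
  const int packed_ld = cols;
  for (int r = 0; r < rows; ++r)
    for (int c = 0; c < cols; ++c)
      packed[r * packed_ld + c] = src[r * src_ld + c];
  return packed;
}
```

For a 4×4 matrix holding 0 to 15, `CopyTile(parent.data() + 1 * 4 + 2, 2, 2, 4)` picks out the 2×2 tile 6 7 / 10 11. Passing the whole matrix with a leading dimension of 4 just copies it unchanged, since an untiled row-major matrix’s leading dimension is its number of columns.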
Anyway, I hope these notes help any other lost souls who are struggling with matrices. I’m now feeling a lot more confident, and I wish I’d taken the time to study them more carefully before. They’re very powerful tools, and actually a lot of fun once I moved past some of my confusion!