An Engineer’s Guide to GEMM

[Diagram: GEMM matrix multiplication]

I’ve spent most of the last couple of years worrying about the GEMM function because it’s the heart of deep learning calculations. The trouble is, I’m not very good at matrix math! I struggled through the courses I took in high school and college, barely getting a passing grade, confident that I’d never need anything so esoteric ever again. Right out of college I started working on 3D graphics engines where matrices were everywhere, and they’ve been an essential tool in my work ever since.

I managed to develop decent intuitions for 3D transformations and their 4×4 matrix representations, but not having a solid grounding in the theory left me very prone to mistakes when I moved on to more general calculations. I screwed up the first version of all my diagrams in a previous blog post, and most recently had to make a breaking API change to the open-source gemmlowp library, all because I’d messed up the ordering of the matrix multiplies.

The best way I know to fix something in my own mind is to try to explain it to somebody else, so here are my notes on the areas I found most confusing about matrix math as an engineer. I hope they’ll be helpful, and I look forward to getting corrections on anything I screwed up!

Row versus Column Major

The root of a lot of my difficulties is the two competing ways you can store matrix values in RAM. Coming from an image-processing world, when I see a 2D array of values my ingrained assumption is that it's stored like letters on a page, starting in the top-left corner, and moving from left to right and jumping down at the end of the row. For example, if you have a matrix that you'd draw like this:
| 0 | 1 | 2 |
| 3 | 4 | 5 |
| 6 | 7 | 8 |

You would store it in memory in this order:
| 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 |

This is known as ‘row major’, because each row is stored in adjacent locations in memory.
The alternative storage scheme is ‘column major’, and the same matrix would be laid out like this in memory:
| 0 | 3 | 6 | 1 | 4 | 7 | 2 | 5 | 8 |

What this means is that we’re walking down the columns to get adjacent memory locations instead.

It's important to know that both of these are just ways of storing the same matrix in memory: they're an implementation detail that should not affect the underlying math, or the way you diagram your equations. There's also no widespread agreement on a 'correct' or canonical storage order, so you'll need to pay attention to the convention expected by the code you're interacting with.
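
To make the difference concrete, here's a minimal sketch of the two indexing schemes (the helper names are mine, not from any particular library):

// Addressing element (row, col) of a matrix under each storage order.
inline int RowMajorOffset(int row, int col, int num_cols) {
  return (row * num_cols) + col;  // rows are contiguous in memory
}

inline int ColMajorOffset(int row, int col, int num_rows) {
  return (col * num_rows) + row;  // columns are contiguous in memory
}

For the three-by-three matrix above, RowMajorOffset(2, 0, 3) is 6, which picks the value 6 out of the row-major layout, while ColMajorOffset(2, 0, 3) is 2, which picks that same 6 out of the column-major layout.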

Transpose

One thing you may notice about row versus column ordering is that if you screw it up and pass a square row-major matrix into a library that expects column major, or vice-versa, it will be read as the transpose of the actual matrix. Visually, you transpose a matrix by flipping all the columns into rows, like this:

| 0 | 1 | 2  | 3  |
| 4 | 5 | 6  | 7  |
| 8 | 9 | 10 | 11 |

Original

| 0 | 4 | 8  |
| 1 | 5 | 9  |
| 2 | 6 | 10 |
| 3 | 7 | 11 |

Transpose

Another way of visualizing this is drawing a line through the top-left to bottom-right corners, and flipping everything around on that diagonal. You’ll often see the transpose of a matrix indicated by adding an apostrophe to the name, so that the transpose of a matrix A is A’.
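
A handy consequence of all this is that you can transpose a matrix without moving any data, just by reinterpreting its storage order. Here's a small sketch using the matrices above (the array and the comments are my own):

// The 3x4 'Original' matrix above, stored row-major.
const int data[12] = {0, 1, 2,  3,
                      4, 5, 6,  7,
                      8, 9, 10, 11};

// Read as row-major 3x4, element (row, col) is data[(row * 4) + col].
// Read as column-major 4x3, element (row, col) is data[(col * 4) + row],
// which is exactly the 'Transpose' drawn above. For example, transpose
// element (3, 2) is data[(2 * 4) + 3], which is 11.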

Argument Ordering

If you multiply two numbers, A * B is always the same as B * A. This is not true for matrix multiplication! Indeed, you can only multiply two matrices together at all if the number of columns in the left-hand argument is equal to the number of rows in the right-hand argument. Even if they're both square and the same size, and so can potentially be swapped, the result will still depend on the order.

There is one mathematical identity that crops up a lot in practice with transposes. If you have the standard GEMM equation of C = A * B, then C’ = B’ * A’. In words, if you swap the order of the two input matrices and transpose both of them, then multiplying them will give the transpose of the result you’d get in the original untransposed order.
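
You can check this identity for yourself with a few lines of throwaway code. Here's a quick sketch of mine (a naive row-major multiply, nothing to do with gemmlowp):

#include <cassert>
#include <vector>

// Naive row-major matrix multiply: a is m x k, b is k x n, result is m x n.
std::vector<int> MatMul(const std::vector<int>& a, const std::vector<int>& b,
                        int m, int k, int n) {
  std::vector<int> c(m * n, 0);
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      for (int l = 0; l < k; ++l) {
        c[(i * n) + j] += a[(i * k) + l] * b[(l * n) + j];
      }
    }
  }
  return c;
}

// Returns the transpose of a row-major rows x cols matrix.
std::vector<int> Transpose(const std::vector<int>& a, int rows, int cols) {
  std::vector<int> t(rows * cols);
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) {
      t[(c * rows) + r] = a[(r * cols) + c];
    }
  }
  return t;
}

int main() {
  const std::vector<int> a = {1, 2, 3, 4, 5, 6};     // 2x3
  const std::vector<int> b = {7, 8, 9, 10, 11, 12};  // 3x2
  const auto c = MatMul(a, b, 2, 3, 2);              // C = A * B, 2x2
  // B' is 2x3 and A' is 3x2, so B' * A' is also 2x2.
  const auto c_check = MatMul(Transpose(b, 3, 2), Transpose(a, 2, 3), 2, 3, 2);
  assert(Transpose(c, 2, 2) == c_check);             // C' == B' * A'
  return 0;
}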

#Errors % 2 == 0

What really led me into danger is that all three of storage order, transposing, and input order have effects that can mimic and cancel each other out. It’s like Jim Blinn’s old quote about all correct graphics programs having an even number of sign errors, except that there are three different ways to screw things up instead of one!

For example, what I realized last week was that I was working in a new code base that assumes row-major order, while gemmlowp assumes column major. I had been in a hurry and couldn't figure out why my unit tests weren't working, so I ended up swapping the input argument order. Since C' = B' * A', the storage order error was canceled out by the argument order error! It made for very confusing code, so thankfully a co-worker slapped me round the back of the head (very politely) when he ran across it, and I revisited it and figured out my errors.

Because I know I'm so prone to these kinds of errors, I've forced myself to slow down when I'm tackling this kind of problem, and to start off by working through a couple of examples on pen and paper. I find that working visually, with the diagram at the top of this post in mind, has helped me immensely. Once I've got those examples straight, I'll turn them into unit tests, with the hand calculations in the comments. You can see an example in gemmlowp/test/test.cc:698:


// Runs a small set of hand-calculated data through the implementation.
void TestWithSmallData() {
  const int m = 4;
  const int n = 2;
  const int k = 3;
  // Matrix A (LHS) is:
  // | 7 | 10 | 13 | 16 |
  // | 8 | 11 | 14 | 17 |
  // | 9 | 12 | 15 | 18 |
  const uint8_t a_data[] = {7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18};
  // Matrix B (RHS) is:
  // | 1 | 3 | 5 |
  // | 2 | 4 | 6 |
  const uint8_t b_data[] = {1, 2, 3, 4, 5, 6};
  // Here are the results we expect, from hand calculations:
  // (1 * 7) + (3 * 8) + (5 * 9) = 76
  // (2 * 7) + (4 * 8) + (6 * 9) = 100
  // (1 * 10) + (3 * 11) + (5 * 12) = 103
  // (2 * 10) + (4 * 11) + (6 * 12) = 136
  // (1 * 13) + (3 * 14) + (5 * 15) = 130
  // (2 * 13) + (4 * 14) + (6 * 15) = 172
  // (1 * 16) + (3 * 17) + (5 * 18) = 157
  // (2 * 16) + (4 * 17) + (6 * 18) = 208
  // That means matrix C should be:
  // | 76  | 103 | 130 | 157 |
  // | 100 | 136 | 172 | 208 |
  const uint8_t expected_data[] = {76, 100, 103, 136, 130, 172, 157, 208};
  const int c_count = m * n;
  std::unique_ptr<uint8_t[]> output_data(new uint8_t[c_count]);
  const bool is_a_transposed = true;
  const bool is_b_transposed = true;
  const bool is_c_transposed = true;
  const int lda = k;
  const int ldb = n;
  const int ldc = n;
  const int a_offset = 0;
  const int b_offset = 0;
  const int c_offset = 0;
  const int c_mult = 1;
  const int c_shift = 0;
  gemmlowp::eight_bit_int_gemm::EightBitIntGemm(
      is_a_transposed, is_b_transposed, is_c_transposed, m, n, k, a_data,
      a_offset, lda, b_data, b_offset, ldb, output_data.get(), c_offset,
      c_mult, c_shift, ldc, eight_bit_int_gemm::BitDepthSetting::A8B8);
  ResultStats stats;
  GetResultStats(output_data.get(), expected_data, c_count, &stats);
  ResultStatsBounds bounds;
  const bool good = CheckResultStatsBounds(stats, bounds);
  printf("TestWithSmallData: %s\n", good ? "PASS" : "FAIL");
  ReportResultStats(stats, bounds);
  Check(good);
}

My other key tool is keeping around a simple reference GEMM function that’s unoptimized, but that I can easily step through and add logging statements to. Since a lot of my work involves cutting corners with precision to increase speed, it’s important that I have an understandable implementation that I can play with, and compare against more complex versions. You can see my version for gemmlowp at test.cc:36:


void ReferenceEightBitIntGemm(bool transpose_a, bool transpose_b,
                              bool transpose_c, int m, int n, int k,
                              const uint8_t* a, int32_t a_offset, int lda,
                              const uint8_t* b, int32_t b_offset, int ldb,
                              uint8_t* c, int32_t c_offset, int32_t c_mult_int,
                              int32_t c_shift, int ldc) {
  assert((c_shift >= 0) && (c_shift <= 32));
  assert(a != nullptr);
  assert(b != nullptr);
  assert(c != nullptr);
  int a_i_stride;
  int a_l_stride;
  if (transpose_a) {
    a_i_stride = lda;
    a_l_stride = 1;
  } else {
    a_i_stride = 1;
    a_l_stride = lda;
  }
  int b_j_stride;
  int b_l_stride;
  if (transpose_b) {
    b_j_stride = 1;
    b_l_stride = ldb;
  } else {
    b_j_stride = ldb;
    b_l_stride = 1;
  }
  int c_i_stride;
  int c_j_stride;
  if (transpose_c) {
    c_i_stride = ldc;
    c_j_stride = 1;
  } else {
    c_i_stride = 1;
    c_j_stride = ldc;
  }
  int i, j, l;
  const int32_t kRoundingTerm = (c_shift < 1) ? 0 : (1 << (c_shift - 1));
  for (j = 0; j < n; j++) {
    for (i = 0; i < m; i++) {
      int32_t total = 0;
      for (l = 0; l < k; l++) {
        const int a_index = i * a_i_stride + l * a_l_stride;
        const uint8_t a_as_byte = a[a_index];
        const int32_t a_as_int = static_cast<int32_t>(a_as_byte) + a_offset;
        const int b_index = j * b_j_stride + l * b_l_stride;
        const uint8_t b_as_byte = b[b_index];
        const int32_t b_as_int = static_cast<int32_t>(b_as_byte) + b_offset;
        const int32_t mult_as_int = a_as_int * b_as_int;
        total += mult_as_int;
      }
      int32_t output =
          (((total + c_offset) * c_mult_int) + kRoundingTerm) >> c_shift;
      if (output > 255) {
        output = 255;
      }
      if (output < 0) {
        output = 0;
      }
      const int c_index = i * c_i_stride + j * c_j_stride;
      c[c_index] = static_cast<uint8_t>(output);
    }
  }
}

This includes some eight-bit specific code, but the structure is common across all my reference versions, with three nested loops across the m, n, and k dimensions of the matrices. It also doesn’t include some of the standard arguments like alpha or beta. This particular code assumes column major by default, but if any of the transpose flags are set to true, then the matrix is treated as row major.
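
For comparison, the full BLAS-style GEMM computes C = (alpha * A * B) + (beta * C). Here's a minimal floating-point sketch of that structure (my own simplification: column major only, no transpose flags):

// A naive single-precision GEMM: C = alpha * A * B + beta * C.
// All matrices are column-major, so element (row, col) lives at
// (col * leading_dimension) + row.
void NaiveSgemm(int m, int n, int k, float alpha, const float* a, int lda,
                const float* b, int ldb, float beta, float* c, int ldc) {
  for (int j = 0; j < n; j++) {
    for (int i = 0; i < m; i++) {
      float total = 0.0f;
      for (int l = 0; l < k; l++) {
        total += a[(l * lda) + i] * b[(j * ldb) + l];
      }
      c[(j * ldc) + i] = (alpha * total) + (beta * c[(j * ldc) + i]);
    }
  }
}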

Leading Dimensions

The final pieces of the puzzle for me were the lda, ldb, and ldc arguments. These left me confused initially because I struggled to find a clear explanation of what they meant. I finally realized that they are the number of values you move forward in memory when you reach the end of a row (in row-major order) or column (in column major). They're strides that give a lot of flexibility when you only want to work with smaller tiles inside a larger matrix, since they let you skip over values you want to ignore. If you're not dealing with sub-tiles, then they'll be the number of columns for a row-major matrix, and the number of rows for a column-major one.
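
Here's a small sketch of how that works for a tile inside a larger column-major matrix (the sizes and helper name are my own):

// The parent matrix is column-major with 100 rows and 50 columns, so
// moving one column forward in memory always means skipping 100 values,
// even if the tile we care about is only 4 rows by 3 columns.
const int kParentRows = 100;  // this is the leading dimension (lda)
float parent[100 * 50];

// Element (row, col) of the tile whose top-left corner sits at
// (tile_row0, tile_col0) in the parent matrix.
float TileAt(int tile_row0, int tile_col0, int row, int col) {
  return parent[((tile_col0 + col) * kParentRows) + (tile_row0 + row)];
}

Passing 'parent + (tile_col0 * kParentRows) + tile_row0' as the matrix pointer, with kParentRows as lda, gives a GEMM routine that same view of the tile.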

Anyway, I hope these notes help any other lost souls who are struggling with matrices. I’m now feeling a lot more confident, and I wish I’d taken the time to study them more carefully before. They’re very powerful tools, and actually a lot of fun once I moved past some of my confusion!

12 responses

  1. Hi,

    Can you please throw some light on why the multiplication operation on the 8-bit inputs is done by blasting the inputs to int32 first? I am assuming variables like "a_offset" are used to align quantized 8-bit values to the actual zero value in the floating-point domain; is conversion to int32 before multiplication done to avoid overflow caused by adding this offset value?

    Thanks

    • The reference implementation expands to 32-bit to keep the code simple, but the various assembler versions in gemmlowp itself operate on 8×8 multiplies, producing 16-bit results, that are summed into a 32-bit accumulator. The offset parts are handled in the O(n^2) outer loops, leaving the O(n^3) inner loop as pure eight-bit multiplies. Different SIMD instruction sets between x86 and various ARM NEON flavors make the exact bit depth choices at each stage flexible (e.g. sometimes it makes sense to promote the 8-bit values to 16-bit to do a faster 16×16 multiply and accumulate directly into 32-bit) but this is the general idea.

      • Thanks for the clarification. I have one more doubt, which may be a bit naive!
        I am going through the 8-bit quantization discussed by you in various posts and the code on GitHub, and I understand that we add offset values to make the quantized values align to the actual 0 value, so that the special properties of zero hold true. However, I am not able to find out where and how we are handling overflows when we add the offset value (the location of zero on the quantized scale), as the variable already holds values in range [0, 255] and might overflow.

  2. A very nice way to learn GEMM is to follow the tutorial by Ulrich Drepper in "What every programmer should know about memory" (https://people.freebsd.org/~lstewart/articles/cpumemory.pdf). The article shows how to start with some "naive code" (page 49) and add optimisations until you get to the final code in Appendix A.1. OpenBLAS has some dgemm and sgemm assembly kernels here: https://github.com/xianyi/OpenBLAS/tree/develop/kernel/x86_64 (different versions optimised for Sandy Bridge, Haswell, Bulldozer, Piledriver, etc). On the other hand, not all optimisations for floating-point matrix multiplication will necessarily translate to integer matrix multiplication (but apparently, the matrix transpose does, and possibly the cache optimisations as well).

  3. I'm confused about what gemm does.

    TestWithSmallData is actually computing B*A instead of A*B.

    A shape is 3 * 4.
    B shape is 2 * 3.
    B * A shape is 2 * 4.

    // Matrix A (LHS) is:
    // | 7 | 10 | 13 | 16 |
    // | 8 | 11 | 14 | 17 |
    // | 9 | 12 | 15 | 18 |

    // Matrix B (RHS) is:
    // | 1 | 3 | 5 |
    // | 2 | 4 | 6 |
