Why is it so difficult to retrain neural networks and get the same results?

Photo by Ian Sane

Last week I had a question from a colleague about reproducibility in TensorFlow, specifically in the 1.14 era. He wanted to be able to run the same training code multiple times and get exactly the same results, which on the surface doesn’t seem like an unreasonable expectation. Machine learning training is fundamentally a series of arithmetic operations applied repeatedly, so what makes getting the same results every time so hard? I had the same question when we first started TensorFlow, and I was lucky enough to learn some of the answers from the numerical programming experts on the team, so I want to share a bit of what I discovered.

There are good guides to achieving reproducibility out there, but they don’t usually explain why all the steps involved are necessary, or why training becomes so slow when you do apply them. One reason reproducibility is so hard is that every single calculation in the process has the potential to change the end results, which means every step is a potential weak link. This means you have to worry about everything from random seeds (which are actually fairly easy to make reproducible, but can be hard to locate in a big code base) to code in external libraries. CuDNN doesn’t guarantee determinism by default, for example, and neither do some operations like reductions in TensorFlow’s CPU code.
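To make the “seeds are the easy part” point concrete, here’s the idea in miniature in plain C, using the standard library’s rand() rather than a framework RNG (the same principle applies to a framework’s seeds, just spread across many more places):

#include <stdio.h>
#include <stdlib.h>

int main(void) {
    // With a fixed seed, the pseudo-random sequence is identical on
    // every run (for a given C library implementation).
    srand(42);
    for (int i = 0; i < 3; i++) printf("%d ", rand());
    printf("\n");

    // Reseeding replays exactly the same sequence.
    srand(42);
    for (int i = 0; i < 3; i++) printf("%d ", rand());
    printf("\n");
}

The hard part in a real training code base isn’t the call itself, it’s finding every library that keeps its own generator and making sure each one gets seeded.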

It was the code execution part that confused me the most. I could understand the need to set random seeds correctly, but why would any numerical library with exactly the same inputs sometimes produce different outputs? It seemed like there must be some horrible design flaw to allow that to happen! Thankfully my teammates helped me wrap my head around what was going on.

The key thing I was missing was timing. To get the best performance, numerical code needs to run on multiple cores, whether on the CPU or the GPU. The important part to understand is that how long each core takes to complete its work is not deterministic. Lots of external factors, from the presence of data in the cache to interruptions from multi-tasking, can affect the timing. This means that the order of operations can change. Your high school math class might have taught you that x + y + z will produce the same result as z + y + x, but in the imperfect world of floating point numbers that ain’t necessarily so. To illustrate this, I’ve created a short example program in the Godbolt Compiler Explorer.

#include <stdio.h>

float add(float a, float b) {
    return a + b;
}

int main(void) {
    // 0.00000005 is just below half the gap between 1.0f and the next
    // representable float, so adding it to 1.0f on its own has no effect.
    float x = 0.00000005f;
    float y = 0.00000005f;
    float z = 1.0f;

    // Evaluates left to right as (x + y) + z: the two small values
    // combine first, and their sum is big enough to survive.
    float result0 = x + y + z;
    // Evaluates left to right as (z + x) + y: each small value is
    // added to 1.0f on its own and rounded away.
    float result1 = z + x + y;

    printf("%.9g\n", result0);
    printf("%.9g\n", result1);

    // The same groupings made explicit with a helper function.
    float result2 = add(add(x, y), z);  // (x + y) + z
    float result3 = add(x, add(y, z));  // x + (y + z)

    printf("%.9g\n", result2);
    printf("%.9g\n", result3);

    return 0;
}

You might not guess exactly what the result will be, but most people’s expectations are that result0, result1, result2, and result3 should all be the same, since the only difference is in the order of the additions. If you run the program in Godbolt though, you’ll see the following output:

1.00000012
1
1.00000012
1

So what’s going on? The short answer is that floating point numbers only have so much precision, and if you try to add a very small number to a much larger one, there’s a limit below which the small addend simply gets lost: the sum rounds back to the larger number, so the addition has no effect. In this example, I’ve set things up so that 0.00000005 is below that limit for 1.0. If you do the 1.0 + 0.00000005 operation first, there’s no change to the result, because the small value is rounded away. However, if you do 0.00000005 + 0.00000005 first, this produces an intermediate sum of 0.0000001, which is large enough to survive the rounding when added to 1.0, and so it does affect the result.
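If you want to see exactly where that limit sits, here’s a minimal C sketch (the printed values assume standard IEEE 754 single precision). The gap between 1.0f and the next representable float is FLT_EPSILON, and an addend below roughly half that gap is rounded away:

#include <float.h>
#include <stdio.h>

int main(void) {
    // FLT_EPSILON is the gap between 1.0f and the next float up.
    printf("%.9g\n", FLT_EPSILON);           // 1.1920929e-07
    printf("%.9g\n", FLT_EPSILON / 2.0f);    // 5.96046448e-08

    // 0.00000005 is below half the gap, so the sum rounds back to 1.0.
    printf("%.9g\n", 1.0f + 0.00000005f);    // 1
    // 0.0000001 is above it, so the sum rounds up to the next float.
    printf("%.9g\n", 1.0f + 0.0000001f);     // 1.00000012
}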

This might seem like an artificial example, but most of the compute-intensive operations inside neural networks boil down to a long series of multiply-adds. Convolutions and fully-connected calculations will be split across multiple cores either on a GPU or CPU. The intermediate results from each core are often accumulated together, either in memory or registers. As the timing of the results being returned from each core varies, the order of addition will change, just as in the code above. Neural networks may require trillions of operations to be executed in any given training run, so it’s almost certain that these kinds of edge cases will occur and the results will vary.
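You don’t need actual threads to see what that reordering does; summing the same values in two fixed orders is enough. Here’s a toy sketch, reusing the values from the example above:

#include <stdio.h>

#define N 9

int main(void) {
    // One large value and eight small ones.
    float data[N];
    data[0] = 1.0f;
    for (int i = 1; i < N; i++) data[i] = 0.00000005f;

    // Forward order: the large value is accumulated first, so every
    // small addend is individually rounded away.
    float forward = 0.0f;
    for (int i = 0; i < N; i++) forward += data[i];

    // Reverse order: the small values accumulate to about 0.0000004
    // before the large value arrives, so they survive the rounding.
    float reverse = 0.0f;
    for (int i = N - 1; i >= 0; i--) reverse += data[i];

    printf("%.9g\n", forward);  // 1
    printf("%.9g\n", reverse);  // 1.00000036
}

In a real multi-core reduction, which of these orders (or the many in between) you get depends on which core’s partial result shows up first.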

So far I’ve been assuming we’re running on a single machine with exactly the same hardware for each run, but as you can imagine, not only will the timing vary between platforms, but even things like cache sizes may affect the dimensions of the tiles used to optimize matrix multiplies across cores, which adds a lot more opportunities for differences. You might be tempted to increase the size of the floating point representation from 32 bits to 64 to address the issue, but this only reduces the probability of non-determinism; it doesn’t eliminate it entirely, and it will also have a big impact on performance.
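As a quick illustration that 64 bits only moves the cliff rather than removing it, the same trick works with double; the magic values just get smaller (here 0.00000000000000006, chosen to sit just under double’s rounding threshold for 1.0):

#include <stdio.h>

int main(void) {
    double x = 0.00000000000000006;  // 6e-17
    double y = 0.00000000000000006;
    double z = 1.0;

    printf("%.17g\n", (x + y) + z);  // 1.0000000000000002
    printf("%.17g\n", (z + x) + y);  // 1
}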

Properly addressing these problems requires the writers of base math functions like GEMM to add extra constraints to their code, which can be a lot of work to get right, since testing anything timing-related and probabilistic is so complex. Since all the operations you use in your network need to be modified and checked, the work usually requires a dedicated team to fix and verify all the barriers to reproducibility. These requirements also conflict with many of the optimizations that reduce latency on multi-core systems, so the functions become slower. In one of the guides mentioned above, the total time for a training run went from 28 to 105 minutes once all the modifications needed to ensure reproducibility were made.
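To give a flavor of what those constraints look like, here’s a toy sketch of the approach in C: every worker’s partial result goes into a fixed slot, and the slots are always combined in the same order, no matter which worker finished first. (The workers run sequentially here for simplicity, and the structure is my own illustration rather than any particular library’s; a real deterministic GEMM has to enforce the same discipline across threads or GPU blocks, and the extra synchronization is part of why it’s slower.)

#include <stdio.h>

#define WORKERS 4
#define CHUNK 2
#define N (WORKERS * CHUNK)

float deterministic_sum(const float *data) {
    // Each worker sums its own chunk into a slot indexed by worker id.
    float partial[WORKERS];
    for (int w = 0; w < WORKERS; w++) {
        partial[w] = 0.0f;
        for (int i = 0; i < CHUNK; i++)
            partial[w] += data[w * CHUNK + i];
    }
    // The combine order is fixed (0, 1, 2, 3), regardless of timing.
    float total = 0.0f;
    for (int w = 0; w < WORKERS; w++)
        total += partial[w];
    return total;
}

int main(void) {
    float data[N] = {1.0f,        0.00000005f, 0.00000005f, 0.00000005f,
                     0.00000005f, 0.00000005f, 0.00000005f, 0.00000005f};
    printf("%.9g\n", deterministic_sum(data));
}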

I’m writing this post because I find it fascinating that our systems have become so complex that an algorithm like matrix multiply, which can be described in a few lines of pseudo-code, can have implementations that produce such surprising and unexpected results. I’m barely scratching the surface of all the complexities of floating point math here, but many other causes of similar issues are emerging as we do more and more calculations on massive distributed networks of computers. Honestly, that’s one reason I enjoy working on embedded systems so much these days: I have a lot more confidence in my mental model of how those chips work. It does feel like it’s no longer possible for any one person to have a full and deep understanding of the whole stack on any modern platform, even just for software. I love digging into the causes of weird and surprising properties like non-determinism in training because it helps me understand more about what I don’t know!

3 responses

  1. I was misguided and stupid enough to write automagically distributed and deterministic training and inference in the DSSTNE framework back in 2014-2015. It was not well-received by the AI community and I was told that my focus on bitwise reproducibility and large models was pedantic and silly by the high muckety muck sorts of AI in industry. I find a lot of AI people don’t like engineers, especially the ones that get AI despite that being an extremely lucrative combination of skills the past decade.

    I left AI this year (with zero to negative interest in ever returning to it), but a friend sent me this article. It would have been trivial to make TF deterministic if that had been a 1st class feature from day one. Too bad the guys at DeepMind said my code made their eyes bleed. Their loss. I did the same thing for Molecular Dynamics previously, it’s kind of my (useless) superpower. Very useless it would seem.

    Great article though, maybe they’ll finally understand reproducible AI matters, doubly so in production when something goes wrong? Nah. Race conditions are fun, amIRight?

  2. Pingback: The Problem of Reproducability in Neural Networks – Curated SQL

  3. Hi,

    You might find our paper Sources of Irreproducibility interesting: https://arxiv.org/abs/2204.07610 Here we try to give a broad overview of what can cause irreproducible results. You will recognize at least some of the points you make. I would be very interested in getting your feedback.

    Odd Erik
