A few days ago I received a question from Plant Village, a team I’m collaborating with, about a problem that’s emerged with a mobile app they’re developing. It detects plant diseases and delivers good results when it’s pointed at leaves, but if you point it at a computer keyboard it thinks it’s a damaged crop.
This isn’t a surprising result to computer vision researchers, but it is a shock to most other people, so I want to explain why it’s happening, and what we can do about it.
As people, we’re used to being able to classify anything we see in the world around us, and we naturally expect machines to have the same ability. Most models are only trained to recognize a very limited set of objects though, such as the 1,000 categories of the original ImageNet competition. Crucially, the training process makes the assumption that every example the model sees is one of those objects, and the prediction must be within that set. There’s no option for the model to say “I don’t know”, and there’s no training data to help it learn that response. This is a simplification that makes sense within a research setting, but causes problems when we try to use the resulting models in the real world.
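To make the "no option to say I don't know" point concrete, here's a minimal sketch, assuming the model ends in the usual softmax layer. The labels and raw scores are made up for illustration. Whatever the input is, the outputs are forced to sum to one across the known classes, so one of them always wins.

```python
import math

def softmax(logits):
    """Turn raw scores into positive values that sum to 1."""
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical labels and raw scores, purely for illustration.
labels = ["healthy leaf", "leaf blight", "leaf rust"]
keyboard_logits = [-2.0, 1.5, -0.5]  # imagined scores for a photo of a keyboard

scores = softmax(keyboard_logits)
best = max(range(len(labels)), key=lambda i: scores[i])

# The scores always sum to 1, so some known class has to come out on top,
# even though the right answer is "none of these".
print(labels[best], scores[best])
```

Nothing in this output distinguishes a confident match from an input the model has never seen anything like.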
Back when I was at Jetpac, we had a lot of trouble convincing people that the ground-breaking AlexNet model was a big leap forward because every time we handed over a demo phone running the network, they would point it at their faces and it would predict something like “Oxygen mask” or “Seat belt”. This was because the ImageNet competition categories didn’t include any labels for people, but most of the photos with oxygen mask and seat belt labels included faces along with the objects. Another embarrassing mistake came when they would point it at a plate and it would predict “Toilet seat”! This was because there were no plates in the original categories, and the closest white circular object in appearance was a toilet seat.
I came to think of this as the “open world” versus “closed world” problem. Models were trained and evaluated assuming that there was only ever going to be a limited universe of objects presented to them, but as soon as they make it outside the lab that assumption breaks down and they are judged by users on their performance for any arbitrary object that’s put in front of them, whether or not it was in the training set.
So, What’s The Solution?
Unfortunately I don’t know of a simple fix for this problem, but there are some strategies that I’ve seen help. The most obvious start is to add an “Unknown” class to your training data. The bad news is that this just opens up a whole different set of issues:
- What examples should go into that class? There’s an almost limitless number of possible natural images, so how do you choose which to include?
- How many of each different type of object do you need in the unknown class?
- What should you do about unknown objects that look very similar to the classes you care about? For example, adding a dog breed that’s not in the ImageNet 1,000, but looks nearly identical to one that is, will likely force a lot of what would have been correct matches into the unknown bucket.
- What proportion of your training data should be made up of examples of the unknown class?
This last point actually touches on a much larger issue. The prediction values you get from image classification networks are not true probabilities. They implicitly assume that the odds of seeing any particular class are equal to how often that class shows up in the training data. If you take an animal classifier that includes a penguin class and use it in the Amazon jungle you’ll experience this problem, since (presumably) all of the penguin sightings will be false positives. Even with dog breeds in a US city, the rarer breeds show up a lot more often in the ImageNet training data than they will in a dog park, so they’ll be over-represented as false positives. The usual solution is to figure out what the prior probabilities are in the situation you’ll be facing in production, and then use those to apply calibration values to the network’s output to get something that’s closer to real probabilities.
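One common way to apply that kind of calibration is to rescale each output by the ratio of the production prior to the training prior and then renormalize. Here's a minimal sketch of that idea. The function name, labels, and all of the numbers are invented for illustration. It also assumes the network's scores roughly track the training distribution, which won't hold exactly in practice.

```python
def reweight_by_priors(scores, train_priors, deploy_priors):
    """Rescale classifier scores from training-set priors to deployment priors."""
    adjusted = [
        s * d / t for s, t, d in zip(scores, train_priors, deploy_priors)
    ]
    total = sum(adjusted)
    if total == 0:
        raise ValueError("all adjusted scores are zero")
    return [a / total for a in adjusted]

# Hypothetical three-class animal classifier, trained on balanced data.
labels = ["jaguar", "toucan", "penguin"]
train_priors = [1 / 3, 1 / 3, 1 / 3]

# Assumed odds of each animal actually turning up in the Amazon.
deploy_priors = [0.499, 0.499, 0.002]

# Made-up network output where "penguin" narrowly wins.
raw_scores = [0.30, 0.25, 0.45]

calibrated = reweight_by_priors(raw_scores, train_priors, deploy_priors)
for label, score in zip(labels, calibrated):
    print(label, score)
```

With these made-up numbers the penguin score collapses to well under one percent and "jaguar" becomes the top prediction. The hard part in a real application is estimating the production priors, not applying them.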
The main strategy that helps tackle the overall problem in real applications is constraining the model’s usage to situations where the assumptions about what objects will be present matches the training data. A straightforward way of doing this is through product design. You can create a user interface that directs people to focus their device on an object of interest before running the classifier, much like applications that ask you to take photographs of checks or other documents often do.
Getting a little more sophisticated, you can write a separate image classifier that tries to identify conditions that the main image classifier is not designed for. This is different from adding a single “Unknown” class, because it acts more like a cascade, or a filter before the detailed model. In the crop disease case, the operating environment is visually distinct enough that it might be fine to just train a model to distinguish between leaves and a random selection of other photos. Leaf photos are similar enough to each other that the gating model should at least be able to tell if the image is being taken in a type of scene that’s not supported. This gating model would be run before the full image classifier, and if it doesn’t detect something that looks like it could be a plant, it would bail out early with an error message indicating no crops were found.
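Here's a rough sketch of how that cascade could be wired together. Everything in it is hypothetical: the function names, the 0.5 threshold, and the two stand-in "models", which are trivial placeholders so the example runs on its own. In a real app they would be trained networks.

```python
def classify_with_gate(image, gate, classifier, gate_threshold=0.5):
    """Run a cheap plant/not-plant gate before the detailed disease model."""
    plant_score = gate(image)
    if plant_score < gate_threshold:
        # Bail out early instead of forcing a disease label onto a keyboard.
        return {"status": "no_crop_found",
                "message": "No crops were found in this photo."}
    scores = classifier(image)
    label = max(scores, key=scores.get)
    return {"status": "ok", "label": label, "score": scores[label]}

# Stand-in models: here an "image" is just a description string.
def fake_gate(image):
    """Pretend plant detector: returns a score for 'this looks like a plant'."""
    return 0.95 if "leaf" in image else 0.05

def fake_disease_model(image):
    """Pretend disease classifier: always returns the same made-up scores."""
    return {"healthy": 0.2, "leaf blight": 0.8}

keyboard_result = classify_with_gate("computer keyboard", fake_gate, fake_disease_model)
leaf_result = classify_with_gate("cassava leaf", fake_gate, fake_disease_model)
print(keyboard_result)
print(leaf_result)
```

The threshold is a product decision as much as a modeling one. Set it too high and users with real but unusual leaf photos get turned away. Set it too low and the keyboard gets a diagnosis again.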
Applications that ask you to capture images of credit cards or perform other kinds of OCR will often use a combination of on-screen directions and a model to detect blurriness or lack of alignment to guide users to take photos that can be successfully processed, and having an “are there leaves?” model is a simple version of this interface pattern.
This probably isn’t a very satisfying set of answers, but they’re a reflection of the messiness of user expectations once you take machine learning beyond constrained research problems. There’s a lot of common sense and external knowledge that goes into a person’s recognition of an object, and we don’t capture any of that in the classic image classification task. To get results that meet user expectations, we have to design a full system around our models that understands the world that they will be deployed in, and makes smart decisions based on more than just the model outputs.