
I trained a simple CNN on the MNIST database of handwritten digits to 99% accuracy. I'm feeding in a bunch of handwritten digits, and non-digits from a document.

I want the CNN to report errors, so I set a threshold of 90% certainty below which my algorithm assumes that what it's looking at is not a digit.
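For reference, the thresholding rule described above amounts to something like the following sketch. It assumes `probs` is the softmax output of `model.predict` for a 10-class Keras model. The `REJECT` marker and the function name are illustrative only, not part of Keras.

```python
import numpy as np

REJECT = -1  # hypothetical marker meaning "not a digit"

def classify_or_reject(probs, threshold=0.9):
    """Return the argmax label per row, or REJECT when the top softmax
    probability is below `threshold`.

    probs: array of shape (n_samples, 10), e.g. the output of model.predict(x)
    for a Keras model whose last layer is a softmax.
    """
    labels = probs.argmax(axis=1)
    confidence = probs.max(axis=1)
    return np.where(confidence >= threshold, labels, REJECT)
```

The difficulty described next is that a non-digit can still get a top probability near 1.0, so it passes this check.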

My problem is that the CNN is 100% certain of many incorrect guesses. In the example below, the CNN reports 100% certainty that it's a 0. How do I make it report failure?

[Image: handwritten digit input that does not look like a digit]

My thoughts on this: Maybe the CNN is not really 100% certain that this is a zero. Maybe it just thinks that it can't be anything else, and it's being forced to choose (because of normalisation on the output vector). Is there any way I can get insight into what the CNN "thought" before I forced it to choose?
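The intuition about normalisation can be checked directly: softmax depends only on the differences between the pre-softmax scores (logits), so information about their absolute level is discarded. A toy illustration with made-up logits follows. In Keras, one way to get at the logits is to make the final layer `Dense(10)` with no activation and apply `tf.nn.softmax` yourself. Whether the raw scores actually separate non-digits from digits for a given model is an empirical question this sketch does not answer.

```python
import numpy as np

def softmax(logits):
    """Row-wise softmax; subtracting the max keeps exp() from overflowing."""
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

# Two hypothetical logit vectors with identical differences between classes,
# but very different absolute levels.
high = np.array([12.0, 2.0, 2.0])
low = np.array([-3.0, -13.0, -13.0])

print(softmax(high))  # same probabilities...
print(softmax(low))   # ...as this one
print(high.max(), low.max())  # but the raw "max logit" differs: 12.0 vs -3.0
```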

PS: I'm using Keras on TensorFlow with Python.

Edit

Because someone asked. Here is the context of my problem:

This came from me applying a heuristic algorithm for segmentation of sequences of connected digits. In the image above, the left part is actually a 4, and the right is the curve bit of a 2 without the base. The algorithm is supposed to step through segment cuts, and when it finds a confident match, remove that cut and continue moving along the sequence. It works really well for some cases, but of course it's totally reliant on being able to tell if what it's looking at is not a good match for a digit. Here's an example of where it kind of did okay.

[Image: various input images for handwritten digit recognition]

My next best option is to do inference on all permutations and maximise combined score. That's more expensive.
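The exhaustive alternative can be sketched as follows. It assumes a hypothetical `segment_score(a, b)` that runs the classifier on the image slice between x-positions `a` and `b` and returns a log-probability (e.g. the log of the top softmax value). Every subset of candidate cuts is tried, so the cost grows as 2^K for K candidate cuts, which is why it is more expensive. Caching means each distinct slice is classified only once.

```python
from functools import lru_cache
from itertools import combinations

def best_segmentation(cuts, width, segment_score):
    """Exhaustively try every subset of candidate cut positions.

    cuts: sorted candidate x-positions strictly inside (0, width).
    segment_score(a, b): hypothetical function returning a log-probability
    for the slice [a, b), e.g. log of the max softmax output.
    Returns (best total score, chosen boundaries including 0 and width).
    """
    score = lru_cache(maxsize=None)(segment_score)  # classify each slice once
    best = (float("-inf"), ())
    for k in range(len(cuts) + 1):
        for chosen in combinations(cuts, k):
            bounds = (0, *chosen, width)
            total = sum(score(a, b) for a, b in zip(bounds, bounds[1:]))
            if total > best[0]:
                best = (total, bounds)
    return best
```

One design caveat: summing log-probabilities adds a negative term per segment, so this objective tends to favour fewer segments. A per-segment bonus or some normalisation may be needed, and would have to be tuned.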

nbro
Alexander Soare

6 Answers


The concept you are looking for is called epistemic uncertainty, also known as model uncertainty. You want the model to produce meaningful calibrated probabilities that quantify the real confidence of the model.

This is generally not possible with standard neural networks, as they simply do not have this property; for this you need a Bayesian Neural Network (BNN). This kind of network learns a distribution over its weights instead of scalar (point-wise) weights, which allows it to encode model uncertainty, because the resulting output distribution is calibrated and has the properties you want.

This problem is also called out-of-distribution (OOD) detection, and again it can be done with BNNs, but unfortunately training a full BNN is intractable, so we use approximations.

As a reference, one of these approximations is Deep Ensembles, which trains several instances of a model on the same dataset and then averages their softmax probabilities; it has good out-of-distribution detection properties. Check the paper here, in particular Section 3.5, which shows OOD results based on the entropy of the ensemble probabilities.
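A minimal sketch of the ensemble scoring step, assuming you have already trained several copies of the same Keras model (e.g. with different random initialisations) and collected each member's softmax output, for instance with `model.predict(x)`. The function names and the entropy threshold are hypothetical placeholders. The threshold would need tuning on held-out data.

```python
import numpy as np

def ensemble_mean(member_probs):
    """Average the softmax outputs of M ensemble members.

    member_probs: list of arrays, each of shape (n_samples, n_classes), e.g.
    [m.predict(x) for m in models] where `models` is a hypothetical list of
    independently trained Keras models.
    """
    return np.mean(np.stack(member_probs, axis=0), axis=0)

def predictive_entropy(probs, eps=1e-12):
    """Shannon entropy (in nats) of each row; eps avoids log(0)."""
    return -np.sum(probs * np.log(probs + eps), axis=1)

def flag_non_digits(member_probs, entropy_threshold=1.0):
    """Return (predicted labels, boolean mask of inputs to reject).

    entropy_threshold is an arbitrary placeholder; it would need tuning on
    held-out data that includes some non-digit examples.
    """
    mean_probs = ensemble_mean(member_probs)
    entropy = predictive_entropy(mean_probs)
    return mean_probs.argmax(axis=1), entropy > entropy_threshold
```

Averaging before taking the entropy means that disagreement between members raises the score even when each member is individually confident. For 10 MNIST classes the entropy ranges from 0 to ln(10), about 2.30 nats, which gives a sense of scale for the threshold.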

Dr. Snoopy

Your classifier is specifically learning the ways in which 0s are different from other digits, not what it really means for a digit to be a zero.

Philosophically, you could say the model appears to have some powerful understanding when restricted to a tightly controlled domain, but that facade is lifted as soon as you throw any sort of wrench in the works.

Mathematically, you could say that the model is simply optimizing a classification metric for data drawn from a specific distribution, and when you give it data from a different distribution, all bets are off.

The go-to answer is to collect or generate data like the data you expect the model to deal with (in practice, the effort required to do so can vary dramatically depending upon the application). In this case, that could involve drawing a bunch of random scribbles and adding them to your training data set. At this point you must ask: how do I label them? You will want a new "other" or "non-digit" class so that your model can learn to categorize these scribbles separately from digits. After retraining, your model should now better deal with these cases.
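A rough sketch of the "random scribbles plus a non-digit class" idea, assuming `x` holds MNIST-shaped images of shape (n, 28, 28) already scaled to floats in [0, 1] (the raw arrays from `keras.datasets.mnist.load_data()` are uint8 and would need dividing by 255.0) and `y` holds integer labels 0-9. The random-walk scribble generator and the label value 10 are hypothetical choices, not a standard recipe.

```python
import numpy as np

NON_DIGIT = 10  # hypothetical label for the extra "non-digit" class

def random_scribble(rng, size=28, steps=60):
    """Draw a random-walk 'scribble' as a size x size float image in [0, 1]."""
    img = np.zeros((size, size), dtype=np.float32)
    r, c = rng.integers(4, size - 4, size=2)
    for _ in range(steps):
        dr, dc = rng.integers(-1, 2, size=2)  # move -1, 0 or +1 on each axis
        r = int(np.clip(r + dr, 1, size - 2))
        c = int(np.clip(c + dc, 1, size - 2))
        img[r - 1:r + 2, c - 1:c + 2] = 1.0  # stamp a 3x3 "pen"
    return img

def add_non_digit_class(x, y, n_scribbles, seed=0):
    """Append n_scribbles scribble images labelled NON_DIGIT, then shuffle."""
    rng = np.random.default_rng(seed)
    scribbles = np.stack([random_scribble(rng) for _ in range(n_scribbles)])
    x_aug = np.concatenate([x, scribbles.astype(x.dtype)], axis=0)
    y_aug = np.concatenate([y, np.full(n_scribbles, NON_DIGIT, dtype=y.dtype)])
    perm = rng.permutation(len(x_aug))
    return x_aug[perm], y_aug[perm]
```

The classifier's output layer then needs 11 units instead of 10, e.g. `Dense(11, activation="softmax")` trained with `sparse_categorical_crossentropy` on labels 0-10.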

However, you may then ask, but what if I gave it color images of digits? Or color images of farm animals? Maybe pigs will be classified as zeros because they are round. This problem is a fundamental property of the way deep learning is orchestrated. Your model is not capable of higher-order logic, which means it can seem to go from being very intelligent to very dumb by just throwing the slightest curve ball at it. For now, all deep learning does is recognize patterns in data that allow it to minimize some loss function.

Deep learning is a fantastic tool, but not an all-powerful omnitool. Bear in mind its limitations and use it where appropriate, and it will serve you well.


Broken assumptions

Generalization relies on making strong assumptions (no free lunch, etc.). If you break your assumptions, then you're not going to have a good time. A key assumption of a standard digit-recognition classifier trained on a dataset like MNIST is that you're classifying pictures that actually contain a single digit. If your real data contains pictures that have non-digits, then your real data is not similar to the training data but is conceptually very, very different.

If that's a problem (as in this case), then one way to treat it is to explicitly break that assumption and train a model that not only recognizes digits 0-9 but also recognizes whether there's a digit at all, and is able to answer "that's not a digit": an 11-class classifier instead of a 10-class one. MNIST training data alone is not sufficient for that, but you can use some kind of 'distractor' data to provide the not-a-digit examples. For example, you could use a dataset of letters (perhaps omitting I, l, O and B) transformed to look similar to MNIST data.
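A rough sketch of the "transformed to look similar to MNIST data" step. It follows the documented MNIST convention (each glyph size-normalised to fit a 20x20 box with its aspect ratio preserved, then centred by centre of mass in a 28x28 image), but uses a plain nearest-neighbour resize as a simplification; MNIST's own normalisation was anti-aliased. `mnist_style` is a hypothetical helper. The same function could be applied to the segmented fragments from the question so that digits, distractors and real inputs share one format.

```python
import numpy as np

def mnist_style(img, box=20, canvas=28):
    """Roughly mimic MNIST-style normalisation for a 2D grayscale array
    (ink > 0, background 0): crop to the ink's bounding box, resize so the
    longer side is `box` pixels (nearest neighbour, aspect ratio kept), and
    paste into a canvas x canvas image with the centre of mass near the middle.
    """
    ink = img > 0
    out = np.zeros((canvas, canvas), dtype=np.float32)
    if not ink.any():
        return out  # nothing to normalise
    rows = np.where(ink.any(axis=1))[0]
    cols = np.where(ink.any(axis=0))[0]
    crop = img[rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1].astype(np.float32)

    h, w = crop.shape
    scale = box / max(h, w)
    nh, nw = max(1, round(h * scale)), max(1, round(w * scale))
    # Nearest-neighbour resize via integer index arrays
    ri = np.minimum((np.arange(nh) / scale).astype(int), h - 1)
    ci = np.minimum((np.arange(nw) / scale).astype(int), w - 1)
    small = crop[np.ix_(ri, ci)]

    total = small.sum()
    ys, xs = np.indices(small.shape)
    if total > 0:
        cy, cx = (ys * small).sum() / total, (xs * small).sum() / total
    else:
        cy, cx = (nh - 1) / 2, (nw - 1) / 2
    # Shift so the centre of mass lands near the canvas centre, staying in bounds
    top = min(max(int(round(canvas / 2 - cy)), 0), canvas - nh)
    left = min(max(int(round(canvas / 2 - cx)), 0), canvas - nw)
    out[top:top + nh, left:left + nw] = small
    return out
```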

Peteris

Apollys,

That's a very well thought out response. Particularly, the philosophical discussion of the essence of "0-ness."

I haven't actually performed this experiment, so caveat emptor... I wonder how well an "other" class would actually work. The ways in which "other" differs from "digit" have infinite variability (or at least their only limitation is the cardinality of the input layer).

The NN decides whether something is more of one class or more of a different class. If there isn't an essence in common among other "non-digits", I don't believe it will do well at identifying "other" as the catch-all for everything that has low confidence level of classification.

This approach still doesn't identify what it is to be "not-digit". It identifies how all the things that are "other" differ from the other labeled inputs -- probably poorly, depending on the variability of the "non-digit" labeled data. (i.e. is it numerically exhaustive, many times over, of all random scribbles?) Thoughts?

Rich Chase

I'm an amateur with neural networks, but I will illustrate my understanding of how this problem arises.

First, let's see how a trivial neural network classifies 2D input into two classes:

[Figure: a simple network's decision boundary between two classes of 2D points]

But for a complex neural network, the input space is much larger, and the sample data points are much more tightly clustered, with big chunks of empty space between them:

[Figure: clustered data points separated by large empty regions]

The neural network then has no information about how to classify data in the empty space, so something like this is possible:

[Figures: two different decision boundaries passing through the empty space]

By the traditional ways of measuring the quality of a neural network, both of these would be considered good, as they classify the training classes themselves correctly.

Then, what happens if we try to classify these data points?

[Figure: new data points falling in the empty space between the clusters]

Really, the neural network has no data it could fall back on, so it just outputs what seems to us like random nonsense.
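The argument can be reproduced with a toy example. The data and the hand-picked weights below are hypothetical (nothing is trained): two linear boundaries that both classify the training clusters perfectly, but that disagree, each with near-certainty, about a point in the empty space.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# Hypothetical toy data: two tight clusters with a lot of empty space around them
class_a = np.array([[-2.0, -2.0], [-2.2, -1.8], [-1.8, -2.1]])
class_b = np.array([[2.0, 2.0], [2.1, 1.9], [1.9, 2.2]])

# Two hand-picked linear classifiers, P(class b | x) = sigmoid(w . x)
w_vertical = np.array([3.0, 0.0])    # boundary is the line x = 0
w_horizontal = np.array([0.0, 3.0])  # boundary is the line y = 0

def train_accuracy(w):
    correct = (sigmoid(class_a @ w) < 0.5).sum() + (sigmoid(class_b @ w) > 0.5).sum()
    return correct / (len(class_a) + len(class_b))

# A point far from both clusters, in the "empty space"
query = np.array([5.0, -5.0])
for name, w in [("vertical", w_vertical), ("horizontal", w_horizontal)]:
    p_b = sigmoid(query @ w)
    print(name, train_accuracy(w), round(float(p_b), 4))
```

Both boundaries score 100% on the training points, yet the first assigns the query to class b with probability close to 1, and the second assigns it to class a with probability close to 1.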

Euphoric

In your particular case, you could add an eleventh category to your training data: "not a digit".

Then train your model with a bunch of images of incorrectly segmented digits, in addition to the normal digit examples. This way the model will learn to tell apart real digits from incorrectly segmented ones.

However, even after doing that, there will be an infinite number of random-looking images that will be classified as digits. They're just far away from the examples of "not a digit" you provided.
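One hedged way to synthesise "incorrectly segmented" training examples from MNIST itself: paste two random digits side by side and crop a 28-pixel-wide window that straddles the join. This loosely mimics the bad cuts described in the question (part of a 4 next to part of a 2). The shift range and the label value 10 are assumptions. If the segmentation algorithm can log its actual bad cuts, those would match the real inputs more closely.

```python
import numpy as np

NON_DIGIT = 10  # hypothetical label for the extra class

def missegmented_examples(x, n, rng, min_shift=8, max_shift=20):
    """Build n 'bad cut' images by pasting two random digits side by side
    and cropping a 28-pixel-wide window that straddles the join.

    x: array of shape (num_images, 28, 28), e.g. MNIST training images.
    Returns (images of shape (n, 28, 28), labels all equal to NON_DIGIT).
    """
    out = np.empty((n, 28, 28), dtype=x.dtype)
    for k in range(n):
        i, j = rng.integers(0, len(x), size=2)
        pair = np.concatenate([x[i], x[j]], axis=1)  # shape (28, 56)
        shift = rng.integers(min_shift, max_shift + 1)
        out[k] = pair[:, shift:shift + 28]  # window overlaps both digits
    return out, np.full(n, NON_DIGIT, dtype=np.int64)
```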

jpa