How Do Machines Learn To Make Predictions?

Malay Haldar
7 min read · Jul 16, 2018


The craze behind machine learning is fueled by its ability to make predictions that come in handy when running a business. Making predictions, a.k.a. computing probabilities. This post tries to provide an intuitive understanding of how machines learn to spit out probabilities.

Suppose you run a bubble tea shop, and want to create a machine learning model to predict if customers will like pieces of mock coconut in their bubble tea.

Step #1: One day you make a note of all the customers who ordered mock coconut pieces, and whether they liked it or not (by rummaging through the garbage can later and checking how many of them actually finished all those white little cubes of industrial waste).

This allows you to make a chart like the one below, where 0 means the customer didn’t like the mock coconut pieces, and 1 means they really liked them.

Let’s start with the base case, where you know nothing about the customers. If you had to predict the probability that a future customer will like mock coconut pieces, what would it be? Based on the data so far, the best guess would be 4/10.
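If you prefer code to charts, here is a minimal Python sketch of that base case. The individual like/dislike labels below are assumed for illustration; all we really know from the data is that 4 of the 10 customers liked the mock coconut pieces.

```python
# Assumed labels for the 10 recorded customers: 1 = liked, 0 = didn't.
# Only the overall count (4 likes out of 10) comes from the data above.
likes = {"A": 1, "B": 0, "C": 0, "D": 1, "E": 0,
         "F": 1, "G": 0, "H": 0, "I": 1, "J": 0}

base_rate = sum(likes.values()) / len(likes)
print(base_rate)  # 0.4, i.e. 4 likes out of 10 customers
```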

Step #2: Now let’s see how we tackle this in the machine learning world. There, the model tries to predict a value for the coconut preference so that the gap between the predicted preference and the actual preference over the recorded examples is minimized. We measure this gap, or error, between the model’s prediction and the actual preference of the customers as (actual_preference - predicted_preference)².

The reason behind measuring the error in this particular way will become clear in a minute. First let’s see an example of the error in prediction if the model predicts a value of 0.7:

When the actual preference is 1 (a like), the error becomes (1 - x)², where x is the predicted value: (1 - 0.7) x (1 - 0.7) = 0.09. Similarly, when the actual preference is 0 (a dislike), the error becomes x², which evaluates to 0.7 x 0.7 = 0.49.

In general, the total error can be written as: num_likes*(1 - x)² + num_dislikes*x². In our example, num_likes, the number of customers who like mock coconut pieces, is 4, and num_dislikes is 6. What should be the value of x so that the total error is minimized? The figure below plots the errors for likes, dislikes, and their sum as you sweep through different prediction values along the x-axis:

The beauty of stating the error in prediction as (actual_preference - predicted_preference)² is that you get this bowl-shaped total error, known as a smooth convex function, with a very clear minimum. Note that the total error is minimized at 0.4, or 4/10, which is the probability of a customer liking mock coconut pieces that we computed before. Happy coincidence? Not at all! With a bit of calculus, you can derive that the function num_likes*(1 - x)² + num_dislikes*x² is minimized at x = num_likes / (num_likes + num_dislikes), which is exactly the probability of a customer liking pieces of mock coconut.
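If you want that bit of calculus spelled out, take the derivative of the total error with respect to x and set it to zero (writing L for num_likes and D for num_dislikes):

```latex
\[
E(x) = L\,(1 - x)^2 + D\,x^2,
\qquad
\frac{dE}{dx} = -2L\,(1 - x) + 2D\,x = 0
\;\Longrightarrow\;
x = \frac{L}{L + D}.
\]
```

With L = 4 and D = 6, that lands at x = 4/10 = 0.4, the same probability we computed by hand.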

This forms the core of the intuition: the point at which the total error in predicting the likes and dislikes is minimized also happens to be the point at which the prediction matches the probability of likes.
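As a quick numerical sanity check, here is a small Python sketch that sweeps candidate prediction values the same way the figure above does (the plotting itself is left out):

```python
import numpy as np

num_likes, num_dislikes = 4, 6

# Sweep candidate predictions between 0 and 1, mirroring the x-axis of the figure.
x = np.linspace(0, 1, 1001)
error_likes = num_likes * (1 - x) ** 2
error_dislikes = num_dislikes * x ** 2
total_error = error_likes + error_dislikes

print(x[np.argmin(total_error)])  # 0.4 == num_likes / (num_likes + num_dislikes)
```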

Step #3: Let’s make it a little more interesting now. Suppose in addition to recording whether the customers liked or disliked the coconut pieces, you record additional attributes, like whether they ordered a small or a large drink to begin with. Now your data may look like:

To improve your prediction, you can divide your customers into two groups: those who order a small drink (A, D, E and G), and those who order a large one (B, C, F, H, I, J). Now instead of making a single prediction, you can make two different predictions based on the size of the drink. For customers who order a small drink, the probability of liking the mock coconut pieces can be adjusted to 2/4 = 0.5, and for those who order a large drink, to 2/6 = 0.33.

Once the information about the drink size has been used to create the two groups of customers, you know nothing more about them, and the problem reduces to two instances of the example we discussed before. So the same intuition applies again, this time to each of the two groups. Only now, instead of a simple probability of liking mock coconut, you can speak of a more nuanced and accurate prediction: the probability of liking mock coconut pieces given the size of the drink. With some calculus, you can prove that the error of prediction is minimized at the probability of likes, even when you have multiple attributes for each customer, and even when some of the attributes are continuous values, like the height of the customer.
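Continuing the earlier sketch, grouping by drink size and recomputing the like rate within each group looks like this. The drink sizes come from the data above; the individual like/dislike labels are still assumed, chosen to be consistent with the 2/4 and 2/6 rates.

```python
# Assumed per-customer labels, consistent with 2 likes in the small group (A, D, E, G)
# and 2 likes in the large group (B, C, F, H, I, J).
likes = {"A": 1, "B": 0, "C": 0, "D": 1, "E": 0,
         "F": 1, "G": 0, "H": 0, "I": 1, "J": 0}
drink_size = {"A": "small", "D": "small", "E": "small", "G": "small",
              "B": "large", "C": "large", "F": "large",
              "H": "large", "I": "large", "J": "large"}

# Group customers by drink size and compute the like rate within each group.
groups = {}
for customer, size in drink_size.items():
    groups.setdefault(size, []).append(likes[customer])

for size, labels in sorted(groups.items()):
    print(size, sum(labels) / len(labels))
# large 0.333...  (2 likes out of 6)
# small 0.5       (2 likes out of 4)
```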

When this basic mechanism is repeated over millions of customers, and you plug in hundreds of customer attributes, the quality of the prediction starts to appear like intelligence.

But intuitive as it is, in practice you would rarely minimize (predicted_preference - actual_preference)², because squaring makes the error grow large rather quickly as the predicted_preference moves farther from the actual_preference. The total error then gets dominated by the outliers in your data, the few erratic customers who defy the usual pattern. With the focus of the model consumed by a handful of erratic customers, your prediction quality for regular customers may suffer.

Step #4: To avoid getting bogged down by your oddball customers, you need something more levelheaded. Enter logistic loss, the hero of machine learning. Instead of representing your error on predicting likes as (1 - x)², you use log(1 + e^-x). For the error on predicting dislikes, instead of x² you use log(1 + e^x). Applying it to our first example where we knew nothing about the customers, we compare how the total error looks using the logistic loss function against the simple squaring of differences.

We plot the error over a slightly extended range of [-2, 2], since we are interested in how the error behaves when the prediction diverges from the actual recorded preferences. Notice how sharply the error rises in the violet curve when using the square; outside this range it rises even more dramatically. The error computed using logistic loss, shown in green, still maintains the bowl shape, but its curvature is much gentler. So even if the prediction strays away from 0 or 1, the error doesn’t skyrocket.
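Here is a sketch of that comparison in code (again leaving out the plotting), using the same 4 likes and 6 dislikes:

```python
import numpy as np

num_likes, num_dislikes = 4, 6
x = np.linspace(-2, 2, 2001)

# Total error under the squared loss from Step #2.
squared_total = num_likes * (1 - x) ** 2 + num_dislikes * x ** 2

# Total error under logistic loss: log(1 + e^-x) per like, log(1 + e^x) per dislike.
logistic_total = num_likes * np.log1p(np.exp(-x)) + num_dislikes * np.log1p(np.exp(x))

print(x[np.argmin(squared_total)])   # 0.4, as before
print(x[np.argmin(logistic_total)])  # ~ -0.405 == log(4/6), no longer 0.4
```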

But here is the glitch: if you look closely, the low point of the green curve is not at 0.4 anymore! That is because the error represented as logistic loss is not minimized at the probability of likes. This leads to a frequently asked question: how does minimizing logistic loss end up computing probabilities? Strictly speaking, the question is not stated correctly. You don’t compute probabilities by minimizing logistic loss alone.

This is not the catastrophe it looks like at first glance. The total error given by logistic loss, num_likes*log(1 + e^-x) + num_dislikes*log(1 + e^x), is minimized when x = log(num_likes/num_dislikes). This can be rearranged to say that the error is minimized when x = log(probability of likes / (1 - probability of likes)).

To extract the probability of likes buried in that mess, all you need to do is apply the sigmoid function on both sides. The sigmoid function is just the inverse of log(y/(1-y)), so probability of likes equals sigmoid(x), where x is the value that minimizes the error in prediction measured by the logistic loss function.
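In code, the round trip from minimizer to probability looks like this:

```python
import math

num_likes, num_dislikes = 4, 6

# The prediction that minimizes the total logistic loss is the log-odds.
x_min = math.log(num_likes / num_dislikes)   # ~ -0.405

# Applying the sigmoid undoes log(p / (1 - p)) and recovers the probability.
probability_of_like = 1 / (1 + math.exp(-x_min))
print(probability_of_like)  # ~0.4, the same 4/10 as before
```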

There! Done! That’s the heart of the intuition behind how machines end up solving such a diverse set of problems, from deciding whether a tumor can be cancerous or not, to whether users will like a post or not. Anytime you can go out in the world and collect examples of “something or not”, you can come back to your desk and try building a model to predict that “something”.

It doesn’t stop here. If you want to go beyond the framework of “something or not”, to “either this, or that, or that, or that..”, you can graduate from logistic loss to its elder sibling with an equally descriptive name: softmax loss. The basic mechanism still remains the same: the point at which the total prediction error is at its minimum is also the point at which applying softmax to the predictions recovers the probabilities.
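As a tiny illustration of that claim (with made-up counts for a hypothetical three-way topping choice), the predictions that minimize the softmax loss turn out to be the log-counts, and pushing them through softmax gives back the empirical proportions:

```python
import numpy as np

# Hypothetical counts for a three-way choice: how many customers preferred
# mock coconut, tapioca pearls, or grass jelly. The options and counts are made up.
counts = np.array([4.0, 5.0, 1.0])

# The predictions minimizing the total softmax (cross-entropy) loss are the
# log-counts, up to an additive constant.
logits = np.log(counts)

# Applying softmax to those predictions recovers the empirical probabilities.
probabilities = np.exp(logits) / np.exp(logits).sum()
print(probabilities)          # [0.4 0.5 0.1]
print(counts / counts.sum())  # the same proportions
```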

See also:

How Machines Learn To Doubt?

How Do Neural Networks Work?

How much training data do you need?

