Linear Regression

Let’s look at the hello world of machine learning algorithms.

In this post, we will look at a very simple machine learning algorithm that is the machine learning equivalent of the “hello world” program in programming languages.

[Image source: Regression analysis - Wikipedia]

So what is linear regression? If you are from a statistics background, chances are you already know what it means. In statistics, linear regression is a tool used to find the relationship between a dependent variable and an independent variable. It’s that simple. Consider the equation below:

$$y_h = mx + c$$

This is a linear equation that captures the relationship between the variables y and x. Here y is the dependent variable and x is the independent variable, while m (the slope) and c (the intercept) are the parameters of the line. The subscript h indicates that this is our hypothesis function: the value the model predicts for a given x.
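The hypothesis function y_h = mx + c translates directly into code. This is only an illustrative sketch; the function name `hypothesis` is our own choice, not something defined in the article.

```python
def hypothesis(x: float, m: float, c: float) -> float:
    """Predicted value y_h for input x under the line y_h = m*x + c."""
    return m * x + c


# Example: a line with slope 2 and intercept 1, evaluated at x = 3.
print(hypothesis(3.0, m=2.0, c=1.0))  # 7.0
```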

Suppose you are given a set of data points and you suspect they are linearly related, that is, some equation y = mx + c can represent the relationship between the dependent and independent variables, if not exactly then approximately. In that case, it makes sense to apply linear regression and check whether it gives a good fit.

In the above picture, there is a set of data points with a line fitted through them. Notice that the line does not pass through every point; in fact, it passes through very few of them. However, it captures the overall trend of the data, which is roughly linear.

Now how do we find this line? More mathematically, how do we find the parameters m and c for the linear equation we discussed above?

It’s simple. We start with a random line and calculate the sum of the squared distances of the points from this line, where each distance is measured vertically, as the difference between the point’s actual y value and the value y_h the line predicts for it. We then try to minimize this sum, i.e. we find the parameters for which this sum is lowest. Why should the sum be at its minimum for the best parameters? To answer that, we first have to ask what the best parameters are.

The best parameters are the ones that give a line which (1) passes through or near as many of the points as possible and (2) stays as close as possible to the points it does not pass through.

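If you think about these two requirements, you will see that minimizing the sum of the squared distances satisfies both: a point lying on the line adds nothing to the sum, while a point far from the line adds a large squared term, so the smallest sum comes from a line that stays close to all the points.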
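To make this concrete, here is a small sketch that computes the sum of squared vertical distances for a random line and for a line that follows the trend of the data. The data points and the helper name `sum_squared_errors` are made up for illustration.

```python
import random


def sum_squared_errors(xs, ys, m, c):
    """Sum of squared vertical distances between each point and the line y = m*x + c."""
    return sum((y - (m * x + c)) ** 2 for x, y in zip(xs, ys))


# Hypothetical, roughly linear data (made up for illustration).
xs = [1.0, 2.0, 3.0, 4.0, 5.0]
ys = [2.1, 3.9, 6.2, 7.8, 10.1]

# Start from a random line; it is usually far from the data, so its sum is large.
random.seed(0)
m0, c0 = random.uniform(-5, 5), random.uniform(-5, 5)
print("random line:", round(sum_squared_errors(xs, ys, m0, c0), 3))

# A line that follows the trend (slope about 2) gives a much smaller sum.
print("better line:", round(sum_squared_errors(xs, ys, 2.0, 0.0), 3))  # 0.11
```

Minimizing this sum means searching for the m and c that make the second kind of number as small as possible.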

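In particular, we will minimize the average of the squared distances, known as the mean squared error. We take the average so that the value of the loss does not simply grow with the number of data points; dividing by the number of points does not change which line is best.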
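The effect of averaging is easy to see by duplicating a dataset: the sum of squared distances doubles, while the average stays the same. The data and the line below are arbitrary values chosen for illustration.

```python
def sum_squared_errors(xs, ys, m, c):
    """Sum of squared vertical distances from the points to the line y = m*x + c."""
    return sum((y - (m * x + c)) ** 2 for x, y in zip(xs, ys))


def mean_squared_error(xs, ys, m, c):
    """Average of the squared vertical distances (the mean squared error)."""
    return sum_squared_errors(xs, ys, m, c) / len(xs)


xs = [1.0, 2.0, 3.0, 4.0]
ys = [3.0, 5.0, 4.0, 8.0]
m, c = 1.5, 1.0

# Duplicating every point doubles the sum but leaves the average unchanged.
xs2, ys2 = xs * 2, ys * 2
print(sum_squared_errors(xs, ys, m, c), sum_squared_errors(xs2, ys2, m, c))  # 4.5 9.0
print(mean_squared_error(xs, ys, m, c), mean_squared_error(xs2, ys2, m, c))  # 1.125 1.125
```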

There are geometric and linear-algebraic reasons for using squared distances instead of absolute distances. We won’t cover them in detail here; if you are interested, you can refer to them here.

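We write this metric as a function of the parameters m and c and call it the loss function:

$$J(m, c) = \frac{1}{n} \sum_{i=1}^{n} \left( y_i - (m x_i + c) \right)^2$$

where n is the number of data points and (x_i, y_i) is the i-th data point.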

The goal of linear regression is then to make the output of this loss function (often also called the cost function) as small as possible, i.e. to find the values of m and c that minimize it.
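To see what “minimizing the loss” means in the most literal way, the sketch below tries every (m, c) pair on a coarse grid and keeps the one with the lowest mean squared error. This brute-force search is only an illustration, not the method the article goes on to use; the data (points lying exactly on y = 2x + 1) and the grid range are assumptions chosen so the answer is easy to check.

```python
def mean_squared_error(xs, ys, m, c):
    """Average squared vertical distance from the points to the line y = m*x + c."""
    return sum((y - (m * x + c)) ** 2 for x, y in zip(xs, ys)) / len(xs)


# Hypothetical data lying exactly on y = 2x + 1.
xs = [0.0, 1.0, 2.0, 3.0, 4.0]
ys = [1.0, 3.0, 5.0, 7.0, 9.0]

# Candidate values -5.0, -4.9, ..., 5.0 for both m and c.
grid = [i / 10 for i in range(-50, 51)]

# Evaluate the loss for every pair and keep the smallest (tuples compare by loss first).
best_loss, best_m, best_c = min(
    (mean_squared_error(xs, ys, m, c), m, c) for m in grid for c in grid
)
print(best_m, best_c, best_loss)  # 2.0 1.0 0.0
```

A grid search like this quickly becomes impractical as the grid gets finer or the number of parameters grows, which is why better minimization methods are needed.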

If everything so far makes sense, let’s move on to the process of minimizing this function.

If you have some experience with maths, you might know that problems like this often have an analytical solution: you transform the problem into a well-known form and then compute its solution directly.

But we won’t do that here, because for linear regression the analytical solution involves matrix inversion, which becomes computationally expensive in multivariate linear regression, i.e. regression with more than one independent variable x.
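For contrast, here is a minimal sketch of the analytical route the article sets aside. For least squares, the closed-form solution is given by the normal equation (XᵀX)θ = Xᵀy, where X has one row per data point plus a leading column of ones (for the intercept c), and θ collects the parameters. The sketch assumes NumPy and uses `numpy.linalg.solve` rather than forming the inverse explicitly; the function name `fit_normal_equation` is our own.

```python
import numpy as np


def fit_normal_equation(X, y):
    """Least-squares parameters via the normal equation (X^T X) theta = X^T y.

    X is an (n, d) array of independent variables. A column of ones is
    prepended, so theta[0] is the intercept c and theta[1:] are the slopes.
    """
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float)
    X_b = np.column_stack([np.ones(len(X)), X])  # add the intercept column
    # Solving the linear system is cheaper and more stable than inverting X^T X.
    return np.linalg.solve(X_b.T @ X_b, X_b.T @ y)


# Single-variable example: points lying on y = 2x + 1.
x = np.array([[0.0], [1.0], [2.0], [3.0], [4.0]])
y = np.array([1.0, 3.0, 5.0, 7.0, 9.0])
print(fit_normal_equation(x, y))  # approximately [1. 2.], i.e. c = 1, m = 2
```

With d independent variables, the system being solved is (d + 1) × (d + 1), and a dense solver takes roughly cubic time in d, which is the kind of cost that grows painful when there are many features.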

In the next article, we will see how to solve this minimization problem.
