Machine Learning with k-Nearest Neighbors
Hey guys, this is Matt Heads 101. Today I'm going to talk about probably the simplest machine learning algorithm there is: k-nearest neighbors.
This algorithm is inspired by the way we move through life: as we gain experience, we can take on new challenges and deal with new situations by comparing our current situation to ones we've been in before and doing what worked well for us in the past.
To give a more concrete example of what you can do with k-nearest neighbors, I've taken the MNIST handwriting dataset, which contains 60,000 handwritten digits, and built a small web app that lets me draw a digit and tells me which number it thinks I drew. I can go through and draw digits, and you can see it gets them right most of the time; it makes mistakes occasionally.
So the question I'm going to answer in this video is: how does something like this work using k-nearest neighbors? How do we take this idea of comparing something new to previous experiences and make it concrete, so we can implement it in a program like this one that recognizes digits I draw on my computer screen?
I want to start with a really simple example. We're given a grid of data points, some red and some blue. These are literally points on a grid; you can think of each one as an x-coordinate and a y-coordinate with a color attached, and there's clearly a pattern to how the points are arranged.
If you look at it as a human, with all of your intelligence, you can clearly tell there's a circle of points in the middle that are a different color than the ones outside. You can spot that pattern right away, but it's hard for a computer to just learn patterns like this; it has no intuition.
So how do we get the computer to learn a pattern like this? The more concrete question is: how would we get the computer to color in this image, so that if we pointed at one of the white spots where there's no data point and asked, "What color is that?", it could answer? If you point into the middle of the circle, a human can tell you pretty much for sure what color it should be.
So how do we get the computer to be just as confident and do the same thing? How do we get it to generalize from these data points to the rest of the picture, the parts that aren't covered by data points?
I'm going to propose perhaps the most straightforward and obvious solution to this problem. Any image can be thought of as a grid of pixels: each pixel has a distinct color, and from far enough away it looks like a picture.
So here I've just drawn a grid. We haven't decided the color of any pixel yet, and we want to start coloring in the image, deciding what color each pixel should be. We pick a pixel that hasn't been colored, and to decide what color it should be, we look at the data points close to it.
In this case, maybe we've been given some blue data points that happen to be to the left of the pixel and some red data points that happen to be to the right. To decide the pixel's color, we just find the data point closest to it, here a blue one right by the pixel's corner, and make the pixel the same color as that data point.
As I've written here, we find the distance between the pixel and all the data points using the basic distance formula, the one based on the Pythagorean theorem that you probably learned in middle school if you live in the US.
It's really easy to implement in code, because we have coordinates for all the pixels and coordinates for all the data points. We decide each pixel's color by finding the data point closest to that pixel and making the pixel that color.
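Here's a minimal sketch of that k = 1 coloring procedure in Python; the toy `points` list and the function names are made up for illustration, not taken from any real app:

```python
import math

# Hypothetical toy data: (x, y) coordinates, each with a color label.
points = [(2.0, 3.0, "blue"), (5.0, 5.0, "red"), (8.0, 1.0, "blue")]

def nearest_color(px, py, points):
    """Return the color of the single data point closest to pixel (px, py)."""
    best_color, best_dist = None, float("inf")
    for x, y, color in points:
        d = math.sqrt((px - x) ** 2 + (py - y) ** 2)  # Pythagorean distance
        if d < best_dist:
            best_dist, best_color = d, color
    return best_color

# Color every pixel of a 10x10 image by its single nearest neighbor (k = 1).
grid = [[nearest_color(px, py, points) for px in range(10)] for py in range(10)]
```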
If you go through the whole image with this technique, you get something like this. It's not exactly the circle I described earlier, it's kind of lumpy, but it's actually a pretty good generalization: the whole outside area is blue and the whole inside area is red.
This technique, where you look at all the data points around a pixel and pick the closest one, is an example of k-nearest neighbors with k equal to one. We picked the one data point closest to the pixel, and it was the only source we used to decide what color the pixel should be.
You can also do k-nearest neighbors with multiple data points. Say k is equal to three. To decide a pixel's color, I'd look at the three data points closest to that pixel and either pick the majority color or average the colors and use that.
For instance, if I average the five closest data points to each pixel, I get something that looks more like this. You can see there are sections that are somewhere between completely red and completely blue, because the five data points closest to them aren't all the same color.
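A sketch of the majority-vote version, again with made-up names; the same k nearest points could instead be averaged to get the blended picture above:

```python
from collections import Counter

def knn_color(px, py, points, k=5):
    """Pick the majority color among the k data points closest to (px, py)."""
    # Sorting by squared distance gives the same order as sorting by distance,
    # so we can skip the square root here.
    nearest = sorted(points, key=lambda p: (px - p[0]) ** 2 + (py - p[1]) ** 2)[:k]
    colors = [color for _, _, color in nearest]
    return Counter(colors).most_common(1)[0][0]  # most common color wins

# To blend instead of vote, represent colors as RGB tuples and average them
# component-wise across the k nearest points.
```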
But using a higher value of k does a lot more than make our picture look interesting; it actually helps us deal with outliers. Consider this dataset, which is similar to the set of points we had before, except I've taken one red point and put it in a cluster of blue points, so it's in the wrong place.
Looking at it, we can tell visually that it's an outlier and should probably be ignored; whoever entered or recorded this data might just have mislabeled it. But if we used k equal to one, so we only looked at one neighbor for each pixel, then every pixel closer to this red data point than to any of the blue data points would be mislabeled as red.
If instead we looked at the five data points closest to each pixel and picked the majority color (if blue occurs more often than red, we pick blue; otherwise we pick red), then the stray red data point would be ignored, because it's surrounded by blue.
No matter where we stand, we might be closest to the red point, but there are four blue points that are also close, so the majority rules. Using a higher value of k helps you deal with noise in the data, whether data points are mislabeled or, going back to the handwriting example, one digit just looks really strange: someone really messed up their zero, say, and it kind of looks like a six. You can ignore outliers like that by using a higher value of k. Of course, the value of k that works best depends on your problem. If there's no noise in the data, using k equal to one will give you the best results,
but if there's noise, a higher value of k might help you out. Now I'm almost ready to talk about how we might do something like handwriting recognition, but first I need to say a little more about distance. When all of our data points were two-dimensional, we could just use the distance formula from middle school or high school geometry to figure out how far apart they were.
But how do you figure out the distance between two digits that people have written? How do you figure out distances between things that aren't just two-dimensional data points? That's something we need to address. First, let's review the two-dimensional distance formula, where each point is just a pair of x and y.
Say this green point here is (x1, y1) and the red point is (x2, y2). We find their distance using the basic distance formula: square the difference in x, square the difference in y, add those two squares together, and take the square root; that's the distance. Geometrically it works out, and you can prove this matches our intuitive definition of distance.
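Written out, the two-dimensional distance formula is:

```latex
d = \sqrt{(x_2 - x_1)^2 + (y_2 - y_1)^2}
```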
It turns out it's not too difficult to extend this distance formula to more dimensions. If you look here, you see that the x's are their own little term in the equation and the y's are their own little term, and you just add them together right before taking the square root.
So if your data points have more than two components, say you also add a z-axis and a w-axis, you can extend the formula: keep the terms from the two-dimensional distance and add one more term for each additional coordinate.
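For points with coordinates (x, y, z, w, ...), the extended formula just keeps adding squared-difference terms under the same square root:

```latex
d = \sqrt{(x_2 - x_1)^2 + (y_2 - y_1)^2 + (z_2 - z_1)^2 + (w_2 - w_1)^2 + \cdots}
```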
This distance formula has a lot of nice properties, the most interesting being that it's actually correct in three dimensions: if you go in and work it out, it really does measure distance in 3D. And it keeps its nice properties in more dimensions, even though those aren't physically defined; even if there's no actual fifty-dimensional space, the formula still behaves well.
So how do we apply this to handwriting recognition? Well, a handwritten digit is really just a picture. If I draw an eight on the screen, it's just a bunch of pixels. In the case of the MNIST handwriting dataset, the images are 28 by 28, so they're pretty small: there are only 784 pixels in each handwritten image.
So really, an image is just a bunch of numbers, 784 of them in this case. That means we can take the distance between two images using the same distance formula, just with 784 terms instead of two.
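As a sketch, assuming each image is a NumPy array of pixel intensities (the function name here is mine, not from the demo):

```python
import numpy as np

def image_distance(a, b):
    """Euclidean distance between two 28x28 grayscale images.

    Flattening each image gives a vector of 784 pixel values, so this is the
    same distance formula as before, just with 784 terms under the root.
    """
    a, b = a.astype(float).ravel(), b.astype(float).ravel()
    return np.sqrt(np.sum((a - b) ** 2))
```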
In the handwriting demo I showed at the beginning of the video, there are about 500 images of each digit taken from this database: 500 zeroes, 500 ones, 500 twos, and so on. When you draw a new image, the app goes through everything in its database, looks at the distance between each stored image and the new thing you've drawn, and draws a conclusion from that.
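Putting the pieces together, here's a sketch of how a demo like this might classify a new drawing; the array names and shapes are my assumptions, not the actual app's code:

```python
import numpy as np
from collections import Counter

def classify_digit(drawn, images, labels, k=5):
    """Guess a drawn digit from the majority label of its k nearest images.

    drawn:  shape (784,)   -- the flattened image the user just drew
    images: shape (N, 784) -- the flattened training digits
    labels: shape (N,)     -- the digit (0-9) each training image shows
    """
    dists = np.sqrt(((images - drawn) ** 2).sum(axis=1))  # distance to every stored image
    nearest = np.argsort(dists)[:k]                       # indices of the k closest
    return Counter(labels[nearest].tolist()).most_common(1)[0][0]  # majority vote
```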
The important part is that you can see how it all comes together: we have this idea of classifying a data point based on its closest neighbors, and we can measure the distance between basically anything using the extended distance formula.
And that's pretty much it; that's how this works. I think it's pretty straightforward. I hope you learned something. Thanks for watching, subscribe, and goodbye.