Solving XOR using a Neural Network

Building your own neural network from scratch can be challenging and confusing. It can often be insightful to look at something simple to help grasp the concepts. In this tutorial we'll be doing exactly that. We will work on creating a neural network to solve the simple XOR problem in Python.

The Challenge

When given two binary inputs $x_1$ and $x_2$, we want to output a 1 if exactly one of them is a 1. Otherwise we want to output a 0. This is represented by the truth table below:

$x_1$   $x_2$   Output
0       0       0
0       1       1
1       0       1
1       1       0

Network Architecture

If we plot the XOR problem, we can see that the data is not linearly separable. In other words, we can't separate the blue dots from the red dots using a straight line.



This means that our neural network will need to have at least one hidden layer with a non-linear activation function to classify all inputs correctly.

The sigmoid function is a commonly used activation function which we will use for the purpose of this tutorial:

sigmoid($x$) = $\sigma(x)$ = $\frac{1}{1 + e^{-x}}$
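If you are coding along in Python, the sigmoid can be written as a small helper. The sketch below uses NumPy, which the later code sketches in this post also assume:

```python
import numpy as np

def sigmoid(x):
    """Sigmoid activation: squashes any real value into the range (0, 1)."""
    return 1.0 / (1.0 + np.exp(-x))
```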


We design the neural network taking the following points into consideration:

  • From experience, we know that the XOR problem can be written as a combination of the "AND" function and the "OR" function: $A$ XOR $B$ = ($A$ OR $B$) AND NOT ($A$ AND $B$). This allows us to engineer our network to have 2 nodes in the hidden layer.
  • Our network has to output one value between 0 and 1, so we will have one node in the output layer (we could have two nodes representing the probability of 0 and 1 respectively but let's keep it simple for now).
  • In general adding a bias term is a good idea because it lets a neural network learn a function that does not have to go through the origin.

The architecture of our network now looks as follows:



Our goal is to learn the weights in the network that will produce the correct output given our inputs ($x_1$ and $x_2$).

Learning Weights

We start by doing some setup and initializing our weights as small random numbers sampled uniformly between 0 and 1. We'll use a seed value so that the random values on my side are the same as yours. Let's use seed 27 because it's my birth date.
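A minimal sketch of this setup is shown below. The array shapes, the use of NumPy's uniform sampler and the order of the random calls are assumptions on my part, so the exact numbers you get may not match the weight values printed further down:

```python
import numpy as np

# Seed the generator so runs are repeatable. Note: the shapes and call order
# below are assumptions, so the exact numbers you get may differ from the
# values printed later in this post.
np.random.seed(27)

# Layer 1 weights: one row per hidden node, columns are [bias, x1, x2],
# i.e. (w1, w2, w3) for h1 and (w4, w5, w6) for h2.
weights_1 = np.random.uniform(size=(2, 3))

# Layer 2 weights: [bias, h1, h2], i.e. (w7, w8, w9).
weights_2 = np.random.uniform(size=3)

# The four XOR input rows and their expected outputs.
X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
y = np.array([0, 1, 1, 0])
```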

What happens if we do a forward pass right now, without any training? Let's try this with input values $x_1$ and $x_2$ both set to zero.

Weights $w_1$-$w_6$ correspond to the layer 1 weights and weights $w_7$-$w_9$ correspond to the layer 2 weights. With the above code, they should have been initialized as follows:

$\color{red}{w_1}$ $\color{red}{w_2}$ $\color{red}{w_3}$ $\color{blue}{w_4}$ $\color{blue}{w_5}$ $\color{blue}{w_6}$
0.42572141 0.73539729 0.38338077 0.81458374 0.8680032 0.97945663
$\color{green}{w_7}$ $\color{green}{w_8}$ $\color{green}{w_9}$
0.89319435 0.20971517 0.74182765

$\begin{align} h_1 & = \sigma(1 \cdot \color{red}{w_1} + x_1 \cdot \color{red}{w_2} + x_2 \cdot \color{red}{w_3}) \\\\ & = \sigma(1 \cdot 0.42572141 + 0 \cdot 0.73539729 + 0 \cdot 0.38338077) \\\\ & = \sigma(0.42572141) \\\\ & = 0.60485152 \end{align}$

$\begin{align} h_2 & = \sigma(1 \cdot \color{blue}{w_4} + x_1 \cdot \color{blue}{w_5} + x_2 \cdot \color{blue}{w_6}) \\\\ & = \sigma(1 \cdot 0.81458374 + 0 \cdot 0.8680032 + 0 \cdot 0.97945663) \\\\ & = \sigma(0.81458374) \\\\ & = 0.69308541 \end{align}$

$\begin{align} \hat{y} & = \sigma(1 \cdot \color{green}{w_7} + h_1 \cdot \color{green}{w_8} + h_2 \cdot \color{green}{w_9}) \\\\ & = \sigma(1 \cdot 0.89319435 + 0.60485152 \cdot 0.20971517 + 0.69308541 \cdot 0.74182765) \\\\ & = \sigma(1.53419081) \\\\ & = 0.82261865 \end{align}$

So when $x_1$ and $x_2$ are both zero, the output is 0.82261865.

Similarly, we can do the same for the other input rows:

$x_1$   $x_2$   $\hat{y}$       Expected Output
0       0       0.82261865      0
0       1       0.84215617      1
1       0       0.84269835      1
1       1       0.85314043      0

We obtained these values by doing a forward pass on our network. The code for this forward pass is shown below. You can paste this code directly below the previous code segment, then run it to make sure you obtain the same output.
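Below is a sketch of such a forward pass, building on the setup and the sigmoid helper from earlier (the function and variable names are illustrative, not necessarily those of the original code):

```python
def forward_pass(x1, x2, weights_1, weights_2):
    """Propagate a single (x1, x2) pair through the network."""
    inputs = np.array([1.0, x1, x2])       # prepend the bias input
    h = sigmoid(weights_1 @ inputs)        # hidden activations h1 and h2
    hidden = np.array([1.0, h[0], h[1]])   # prepend the bias for layer 2
    y_hat = sigmoid(weights_2 @ hidden)    # network output
    return h, y_hat

# Print the network output for every input row.
for x1, x2 in X:
    _, y_hat = forward_pass(x1, x2, weights_1, weights_2)
    print(x1, x2, y_hat)
```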

These initial forward pass results are bad, but there is hope! We can update the weights in our neural network using the backpropagation algorithm. This algorithm works by taking our estimates and comparing them to our expected output values. We then adjust our weights according to the difference. In practice we can use something called the Mean Squared Error (MSE) to get a measure of the error that our network makes when generating predictions. There are other ways of measuring the error but we will stick to MSE for its simplicity.

MSE = $\frac{1}{n} \sum_{i=1}^n (y_i-\hat{y}_i)^2$


We can see that the error grows larger when the difference between the expected output ($y$) and the calculated output ($\hat{y}$) is larger. It makes sense that we are taking the difference between the produced and expected values, but why is there a squared term in there?

The squared term serves 2 purposes:
  1. It keeps everything we add to our sum positive so that our error makes sense. If we didn't keep it positive then predicting a 1 as a 0 would give an error of 1 and predicting a 0 as a 1 would give us a -1. Adding these together would tell us that we had an error of 0, which is not what we want. Of course, we could have just used an absolute value.
  2. More importantly, we want to penalise predictions which are further away from the expected values non-linearly. For example, if we expect a 0 and get a 0.1 this might be ok, but if we predicted a 0.5 for that same value it would be a problem. Without squaring, the second error is only 5 times worse. If we square the numbers then we have errors of 0.01 and 0.25, so the second prediction is now 25 times worse.

What we want to do is obtain the weights that minimize this mean squared error. We can do this by using a technique called gradient descent.
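To make the error concrete, here is a quick check of the MSE for the untrained network, reusing the forward_pass sketch from above. With the initial predictions in the table it comes out to roughly 0.36; your value will differ if your initial weights differ:

```python
# MSE of the untrained network over the four XOR rows.
predictions = np.array([forward_pass(x1, x2, weights_1, weights_2)[1] for x1, x2 in X])
mse = np.mean((y - predictions) ** 2)
print(mse)  # roughly 0.36 with the initial predictions shown above; yours may differ
```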

Gradient Descent Example

Using gradient descent, we want to minimize the error with respect to the weights. Let's update weight $w_7$ as an example:

$w_7$ = $w_7$ - $\alpha \frac{\partial E}{\partial w_7}$,

where $\alpha$ is our learning rate and $E$ is the MSE. For now, let's look at the error for the case where both $x_1$ and $x_2$ are zero to make things simpler:

$E = (y_1-\hat{y}_1)^2$


If we observe how $w_7$ affects the error, we see that it contributes directly to $\hat{y}$.



Let's make a quick substitution to make things easier:

$\begin{align} \hat{y} & = \sigma(1 \cdot \color{green}{w_7} + h_1 \cdot \color{green}{w_8} + h_2 \cdot \color{green}{w_9}) \\\\ & = \sigma(net_3) \\\\ \end{align}$
where $net_3 = 1 \cdot \color{green}{w_7} + h_1 \cdot \color{green}{w_8} + h_2 \cdot \color{green}{w_9}$

We see that:
  • $E$ is a function of $\hat{y}$.
  • $\hat{y}$ is a function of $net_3$.
  • $net_3$ is a function of $w_7$.
This means that we have to use the chain rule to differentiate $E$ with respect to $w_7$. Before we start, we note that:

$\sigma'(x) = \sigma(x) \cdot (1-\sigma(x))$,

where $\sigma'(x)$ denotes the derivative of the sigmoid function.
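For completeness, this identity follows directly from differentiating the definition of the sigmoid:

$\begin{align} \sigma'(x) & = \frac{d}{dx}\left(\frac{1}{1 + e^{-x}}\right) \\\\ & = \frac{e^{-x}}{(1 + e^{-x})^2} \\\\ & = \frac{1}{1 + e^{-x}} \cdot \frac{e^{-x}}{1 + e^{-x}} \\\\ & = \sigma(x) \cdot (1 - \sigma(x)) \end{align}$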

Using the chain rule:

$E = (y_1-\hat{y}_1)^2$

$\frac{\partial E}{\partial w_7} = \frac{\partial E}{\partial \hat{y}} \frac{\partial \hat{y}}{\partial net_3} \frac{\partial net_3}{\partial w_7}$

$\frac{\partial E}{\partial \hat{y}} = 2(\hat{y} - y) = 2(0.82261865 - 0) = 1.6452373$

$\frac{\partial \hat{y}}{\partial net_3} = \hat{y} \cdot (1 - \hat{y}) = 0.82261865 (1 - 0.82261865) = 0.1459172067$

$\frac{\partial net_3}{\partial w_7} = 1$

$\frac{\partial E}{\partial w_7} = (1.6452373)(0.1459172067)(1) = 0.2400684312$

Now we know how to update weight $w_7$:

$w_7 = w_7 - \alpha \frac{\partial E}{\partial w_7}$

It is important to note that we do not update this weight immediately, otherwise we could influence our other weight updates. Instead we memorise the amount that our weight needs to be updated by. We do this by making use of a temporary matrix. We must also do the same calculations for weights $w_8$ and $w_9$ with the small difference that these weights are not attached to bias nodes:

$\frac{\partial E}{\partial w_8} = \frac{\partial E}{\partial \hat{y}} \frac{\partial \hat{y}}{\partial net_3} \frac{\partial net_3}{\partial w_8}$

$\frac{\partial net_3}{\partial w_8} = h_1 = 0.60485152$

$\frac{\partial E}{\partial w_8} = (1.6452373)(0.1459172067)(0.60485152) = 0.1452057555$

$\frac{\partial E}{\partial w_9} = \frac{\partial E}{\partial \hat{y}} \frac{\partial \hat{y}}{\partial net_3} \frac{\partial net_3}{\partial w_9}$

$\frac{\partial net_3}{\partial w_9} = h_2 = 0.69308541$

$\frac{\partial E}{\partial w_9} = (1.6452373)(0.1459172067)(0.69308541) = 0.166387927$


The code for this is as follows:
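A sketch of this step is given below. It reuses the forward_pass helper from earlier and collects the three layer 2 gradients in a temporary array instead of updating the weights immediately (the names are illustrative, not necessarily those of the original code):

```python
def layer2_gradients(x1, x2, target, weights_1, weights_2):
    """Gradients of the error with respect to the layer 2 weights (w7, w8, w9)
    for a single training example."""
    h, y_hat = forward_pass(x1, x2, weights_1, weights_2)
    dE_dyhat = 2.0 * (y_hat - target)        # dE/dy_hat
    dyhat_dnet3 = y_hat * (1.0 - y_hat)      # sigmoid derivative at the output
    hidden = np.array([1.0, h[0], h[1]])     # dnet3/dw7, dnet3/dw8, dnet3/dw9
    return dE_dyhat * dyhat_dnet3 * hidden   # [dE/dw7, dE/dw8, dE/dw9]

# Gradient for the first input row (x1 = x2 = 0, expected output 0).
print(layer2_gradients(0, 0, 0, weights_1, weights_2))
```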
It gets a bit more tricky for weights $w_1$ - $w_6$. We apply the same principles but we have to backpropagate further. We will work through an example using weight $w_1$.



Again we will make a quick substitution to make things easier:

$\begin{align} h_1 & = \sigma(1 \cdot \color{red}{w_1} + x_1 \cdot \color{red}{w_2} + x_2 \cdot \color{red}{w_3}) \\\\ & = \sigma(net_1) \\\\ \end{align}$
where $net_1 = 1 \cdot \color{red}{w_1} + x_1 \cdot \color{red}{w_2} + x_2 \cdot \color{red}{w_3}$

We see that:
  • $E$ is a function of $\hat{y}$.
  • $\hat{y}$ is a function of $net_3$.
  • $net_3$ is a function of $h_1$.
  • $h_1$ is a function of $net_1$.
  • $net_1$ is a function of $w_1$.

So now using the chain rule:

$ \frac{\partial E}{ \partial w_1} = \frac{\partial E}{\partial \hat{y}} \frac{\partial \hat{y}}{\partial net_3} \frac{\partial net_3}{\partial h_1} \frac{\partial h_1}{\partial net_1} \frac{\partial net_1}{\partial w_1}$

We've already calculated $\frac{\partial E}{\partial \hat{y}}$ and $\frac{\partial \hat{y}}{\partial net_3}$, so we only need the remaining three terms.

$\frac{\partial net_3}{\partial h_1} = w_8 = 0.20971517 $

$\begin{align} \frac{\partial h_1}{\partial net_1} & = \frac{\partial \sigma(net_1)}{\partial net_1} \\\\ & = \sigma(net_1) (1 - \sigma(net_1)) \\\\ & = h_1 (1 - h_1) \\\\ & = 0.60485152 (1 - 0.60485152) \\\\ & = 0.2390061588 \\\\ \end{align}$

$\frac{\partial net_1}{\partial w_1} = 1$

$ \begin{align} \frac{\partial E}{\partial w_1} &= (1.6452373)(0.1459172067)(0.20971517)(0.2390061588)(1) \\\\ &= 0.012033 \end{align}$

We repeat the above procedure for weights $w_2$ and $w_3$, noting that:

$\frac{\partial net_1}{\partial w_2} = x_1 = 0$

$\frac{\partial net_1}{\partial w_3} = x_2 = 0$

This means that:

$\frac{\partial E}{\partial w_2} = 0$

$\frac{\partial E}{\partial w_3} = 0$

For weights $w_4-w_6$ there is a slight change. Since these weights contribute to $h_2$ and not $h_1$, we need to differentiate $net_3$ with respect to $h_2$ instead. When this is complete, we should obtain the following values:
$\frac{\partial E}{\partial w_1}$   0.012033
$\frac{\partial E}{\partial w_2}$   0
$\frac{\partial E}{\partial w_3}$   0
$\frac{\partial E}{\partial w_4}$   0.03788282
$\frac{\partial E}{\partial w_5}$   0
$\frac{\partial E}{\partial w_6}$   0

These calculations must be repeated for each input value of $x_1$ and $x_2$ and then the weights should be updated accordingly. The code is shown below:
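Below is a sketch of what such a training loop might look like. It sums the gradients over all four input rows before applying a single update per epoch; the learning rate and the number of epochs are assumptions on my part and may need tuning:

```python
def train(X, y, weights_1, weights_2, alpha=0.5, epochs=10000):
    """Batch gradient descent with backpropagation over the full XOR data set.

    The learning rate and epoch count are assumptions and may need tuning.
    """
    for _ in range(epochs):
        grad_1 = np.zeros_like(weights_1)   # temporary matrices that collect the
        grad_2 = np.zeros_like(weights_2)   # updates before they are applied
        for (x1, x2), target in zip(X, y):
            h, y_hat = forward_pass(x1, x2, weights_1, weights_2)
            inputs = np.array([1.0, x1, x2])
            hidden = np.array([1.0, h[0], h[1]])

            # Output layer: dE/dnet3, then gradients for w7-w9.
            delta_out = 2.0 * (y_hat - target) * y_hat * (1.0 - y_hat)
            grad_2 += delta_out * hidden

            # Hidden layer: backpropagate through w8 and w9 to get dE/dnet1 and
            # dE/dnet2, then gradients for w1-w6.
            delta_hidden = delta_out * weights_2[1:] * h * (1.0 - h)
            grad_1 += np.outer(delta_hidden, inputs)

        weights_1 -= alpha * grad_1
        weights_2 -= alpha * grad_2
    return weights_1, weights_2

weights_1, weights_2 = train(X, y, weights_1, weights_2)
for x1, x2 in X:
    _, y_hat = forward_pass(x1, x2, weights_1, weights_2)
    print(x1, x2, round(float(y_hat), 3))
```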
If we run this, we will see that our code successfully manages to learn the XOR function!

Analysis of Results

It is interesting to note the weights that the network has learnt. Add this code to the end of the previous code:
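A sketch of that code is shown below, reusing the helpers defined earlier. The exact activations you get will differ slightly from the values printed in the tables, depending on your initialization and training settings:

```python
# Print the hidden node activations for every input row to see what
# functions h1 and h2 have learnt.
print("x1 x2    h1     h2")
for x1, x2 in X:
    h, _ = forward_pass(x1, x2, weights_1, weights_2)
    print(x1, x2, round(float(h[0]), 3), round(float(h[1]), 3))
```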
Running the code results in the following output being added:

The weights have caused the network to learn $h_1$ as an $AND$ function, and $h_2$ as an $OR$ function.
$x_1$   $x_2$   $h_1$   $x_1$ $AND$ $x_2$
0       0       0.001   0
0       1       0.093   0
1       0       0.094   0
1       1       0.879   1

$x_1$   $x_2$   $h_2$   $x_1$ $OR$ $x_2$
0       0       0.065   0
0       1       0.969   1
1       0       0.971   1
1       1       1       1

Let's review the function learnt for weights $w_7$-$w_9$ independently from the rest of the network. We do this by having $h_1$ and $h_2$ take on binary values instead of their actual values. Add this to the end of the previous code.
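A sketch of that code, which feeds binary values for $h_1$ and $h_2$ straight into the output node and bypasses layer 1 (again assuming the helpers and weight arrays defined earlier):

```python
# Evaluate the output node in isolation by feeding it binary values
# for h1 and h2 instead of the activations produced by layer 1.
print("h1 h2   y_hat")
for h1 in (0, 1):
    for h2 in (0, 1):
        y_hat = sigmoid(weights_2 @ np.array([1.0, h1, h2]))
        print(h1, h2, round(float(y_hat), 3))
```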
This should add the following output:

The weights $w_7$-$w_9$ have learnt the function: $NOT(h_1)$ $AND$ $h_2$
$h_1$   $NOT(h_1)$   $h_2$   $\hat{y}$   $NOT(h_1)$ $AND$ $h_2$
0       1            0       0.02        0
0       1            1       0.99        1
1       0            0       0           0
1       0            1       0.011       0

So putting this all together we have:

$ \begin{align} \hat{y} &= NOT(h_1) \, AND \, h_2 \\\\ &= NOT(x_1 \, AND \, x_2) \, AND \, (x_1 \, OR \, x_2) \end{align}$

This represents the $XOR$ function!