-
Notifications
You must be signed in to change notification settings - Fork 0
/
sigmoid_softmax_iris_bias_nn.py
101 lines (85 loc) · 3.2 KB
/
sigmoid_softmax_iris_bias_nn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import argparse
from sklearn import datasets
import numpy as np
from simple_utils import sigmoid, sigmoid_grad, softmax, cross_entropy_loss, cross_entropy_grad
import matplotlib.pyplot as plt
# Command-line interface: a single optional flag that controls plotting.
parser = argparse.ArgumentParser(description='Simple three layer neural net script with sigmoid activations')
parser.add_argument('--plot', action='store_true',
                    help='plot the loss and weights')
args = parser.parse_args()

# Load the Iris data and one-hot encode the integer class labels.
X, y = datasets.load_iris(return_X_y=True)
num_labels = len(np.unique(y))
y_onehot = np.identity(num_labels)[y]

# Fix the RNG seed so every run is deterministic (easier to debug, etc).
# NOTE: the four randn() calls below must stay in this order to reproduce
# the same initialization across runs.
np.random.seed(1)

input_dims = X.shape[1]
hidden_dims = 8
epochs = 5000

# Layer weights: input->hidden (W0) and hidden->output (W1).
W0 = np.random.randn(input_dims, hidden_dims)
W1 = np.random.randn(hidden_dims, num_labels)
# Per-layer bias rows, broadcast across the batch dimension.
B0 = np.random.randn(1, hidden_dims)
B1 = np.random.randn(1, num_labels)

# Copies of the initial weights; the training loop grows these into a
# full per-epoch weight history used for plotting.
W0s = np.copy(W0)
W1s = np.copy(W1)

# Per-epoch bookkeeping for the loss curve.
iz = []
losses = []
# Full-batch gradient-descent training loop.
# NOTE(review): there is no explicit learning rate — each step applies the
# raw summed gradient; preserved as written.
# Weight snapshots are accumulated in plain lists (O(1) append per epoch)
# instead of the previous per-epoch np.concatenate onto a growing array,
# which was O(epochs^2) in total copying. The final stacked arrays are
# identical to the old concatenate-then-reshape result.
w0_history = [W0s]  # epoch-0 entry is the initial-weight copy made above
w1_history = [W1s]
for i in range(epochs):
    # forward propagate
    a1 = sigmoid(X.dot(W0)+B0)
    a2 = softmax(a1.dot(W1)+B1)
    loss = cross_entropy_loss(y_onehot, a2)
    # backpropagation
    # how much we missed times nothing
    # the beauty of cross entropy
    l2_delta = cross_entropy_grad(y_onehot, a2)
    # how much did each l1 value contribute to the l2 loss
    # (according to the weights)?
    l1_loss = l2_delta.dot(W1.T)
    # in what direction is the target a1?
    # were we really sure? if so, don't change too much.
    l1_delta = l1_loss * sigmoid_grad(a1)
    # loss due to weights
    nabla_w1 = a1.T.dot(l2_delta)
    nabla_w0 = X.T.dot(l1_delta)
    # bias gradients: sum the deltas over the batch axis
    nabla_b1 = np.sum(l2_delta,0,keepdims=True)
    nabla_b0 = np.sum(l1_delta,0)
    if i == 0:
        # one-time shape sanity check (kept verbatim)
        print("BBBBBBBB0000000 l1d({}) nb0({}) b0({})".format(l1_delta.shape, nabla_b0.shape, B0.shape))
        print("BBBBBBBB1111111 l2d({}) nb1({}) b1({}) nb1N({})"
              .format(l2_delta.shape, nabla_b1.shape, B1.shape, np.sum(l2_delta,0).shape))
    # and update the weights
    W1 -= nabla_w1
    W0 -= nabla_w0
    B1 -= nabla_b1
    B0 -= nabla_b0
    iz.append(i)
    losses.append(loss)
    if i != 0:
        # snapshot the post-update weights for this epoch
        w0_history.append(W0.copy())
        w1_history.append(W1.copy())
# Assemble the histories once, after the loop:
# W0s becomes (input_dims, epochs, hidden_dims), matching the previous
# concatenate(axis=1) + reshape; W1s stays (hidden_dims, epochs*num_labels).
W0s = np.stack(w0_history, axis=1)
W1s = np.concatenate(w1_history, axis=1)
# Report the final network outputs and first-layer weights.
print("Final prediction ({})".format(a2))
print("Final W ({})".format(W0))

if args.plot:
    # Loss on the left axis, a subset of W0's weight trajectories on the right.
    fig, ax1 = plt.subplots()
    ax2 = ax1.twinx()
    # Plot W0s[i, :, j] for the first 3 input dims and first 4 hidden units
    # (same subset, plot order, and labels as the original hand-unrolled
    # calls, which this loop replaces).
    weight_handles = []
    for j in range(4):
        for i in range(3):
            handle, = ax2.plot(iz, W0s[i, :, j], label='W{}{}'.format(i, j))
            weight_handles.append(handle)
    lss, = ax1.plot(iz, losses, 'r-', label='loss', linewidth=3)
    # Legend shows the loss plus the first six weight curves, as before.
    plt.legend(handles=[lss] + weight_handles[:6])
    ax1.set_xlabel('Iterations')
    ax1.set_ylabel('Loss', color='r')
    ax2.set_ylabel('Weights')
    # (removed: a stray no-op `ax2` expression statement that did nothing)
    plt.show()