我打算開發類神經網路，透過訓練來解讀股價走勢。
以下程式碼改寫自別人的範例，
是我「蹲馬步」練基本功的初步成果：
最基本的神經網路——感知器（Perceptron）
import random, math
import string
from sympy import *
def hardlims(inputM, weightM, bias):
    """Symmetric hard-limit transfer function for a single neuron.

    Computes the net input ``n = weightM * inputM - bias`` and returns 1 when
    ``n >= 0`` and -1 otherwise.

    :param inputM: input vector, either a sympy column Matrix (k x 1) or a
        plain list/tuple of k numbers.
    :param weightM: weight vector, either a sympy row Matrix (1 x k) or a
        plain list/tuple of k numbers.
    :param bias: threshold subtracted from the weighted sum.
    :returns: the int ``1`` or ``-1``.
    :raises ValueError: if two plain sequences of different lengths are given.
    """
    if isinstance(inputM, (list, tuple)) and isinstance(weightM, (list, tuple)):
        # Plain sequences: an ordinary dot product, no sympy required.
        if len(inputM) != len(weightM):
            raise ValueError('inputM and weightM must have the same length')
        n = sum(w * x for w, x in zip(weightM, inputM)) - bias
    else:
        # Matrix path (this is how NN calls it): a (1 x k) row times a (k x 1)
        # column is a 1 x 1 matrix; [0] takes its single element and evalf()
        # turns it into a number.
        n = (weightM * inputM)[0].evalf() - bias
    # A single if/else guarantees a return value for every n; a float NaN
    # compares False and therefore maps to -1.
    return 1 if n >= 0 else -1
def sum_array(a):
    """Return the sum of the numbers in iterable ``a`` (0 when it is empty).

    Kept as a thin wrapper for existing callers; the built-in ``sum`` already
    starts from 0 and adds the items left to right, exactly like a manual
    accumulator loop.
    """
    return sum(a)
class NN:
    """A three-input perceptron whose weights are averaged over a sliding window.

    Each weight (w1, w2, w3) is stored as a list of recent values; the weight
    actually used for classification is the mean of that list. After every
    training epoch the oldest value in each list is dropped and the epoch's
    final weight is appended. The bias stays at 0 (it is never updated), so
    the decision boundary always passes through the origin.
    """

    def __init__(self, window=10):
        """Create the three weight windows.

        :param window: how many values each weight window holds (10 in the
            original code); every value starts as ``random.random()``.
        :raises ValueError: if ``window`` is less than 1, because the mean of
            an empty window is undefined.
        """
        if window < 1:
            raise ValueError('window must be at least 1')
        self.w1 = [random.random() for _ in range(window)]
        self.w2 = [random.random() for _ in range(window)]
        self.w3 = [random.random() for _ in range(window)]
        self.bias = 0

    def _mean_weights(self):
        """Return the effective weights (w1, w2, w3): the mean of each window."""
        return (sum_array(self.w1) / len(self.w1),
                sum_array(self.w2) / len(self.w2),
                sum_array(self.w3) / len(self.w3))

    def learn(self, input, target, max_iterations=None):
        """Train until one full epoch over the samples produces no errors.

        :param input: sequence of 3-element samples, e.g. ``[[-1, -1, -1], ...]``.
            (The name shadows the builtin but is kept for keyword callers.)
        :param target: desired output (-1 or 1) for each sample.
        :param max_iterations: optional cap on the number of epochs. ``None``
            (the default) loops until convergence, which may never happen
            when the data cannot be separated by a boundary through the origin.
        :returns: the number of epochs that were run.
        :raises ValueError: if ``input`` and ``target`` differ in length.
        """
        samples = list(input)
        targets = list(target)
        if len(samples) != len(targets):
            raise ValueError('input and target must have the same length')
        errorAbsSum = 1
        epsilon = 1e-6
        iterations = 0
        # Defined up front so the final report works even if no epoch runs.
        tmp1, tmp2, tmp3 = self._mean_weights()
        while errorAbsSum > epsilon:
            if max_iterations is not None and iterations >= max_iterations:
                break
            iterations += 1
            errorAbsSum = 0
            # Each epoch starts from the current window means.
            tmp1, tmp2, tmp3 = self._mean_weights()
            for x, t in zip(samples, targets):
                activation = hardlims(Matrix(3, 1, x),
                                      Matrix(1, 3, [tmp1, tmp2, tmp3]),
                                      self.bias)
                error = t - activation
                # A fresh random learning rate in [0, 0.1) for every sample.
                learningRate = random.random() / 10
                tmp1 += error * learningRate * x[0]
                tmp2 += error * learningRate * x[1]
                tmp3 += error * learningRate * x[2]
                errorAbsSum += abs(error)
                print('Learning:w1=%f w2=%f w3=%f learningRate=%f error=%f activation=%f' % (tmp1, tmp2, tmp3, learningRate, error, activation))
            # Slide each window: drop its oldest value, append this epoch's result.
            for history, value in ((self.w1, tmp1), (self.w2, tmp2), (self.w3, tmp3)):
                del history[0]
                history.append(value)
        if errorAbsSum > epsilon:
            print('Stopped after %d iterations without converging' % iterations)
        print('Iterations: %d \n w1=%f w2=%f w3=%f' % (iterations, tmp1, tmp2, tmp3))
        return iterations

    def test(self, x):
        """Classify one 3-element sample ``x`` with the mean weights and print it.

        :returns: the activation (1 or -1) produced by ``hardlims``.
        """
        tmp1, tmp2, tmp3 = self._mean_weights()
        activation = hardlims(Matrix(3, 1, x),
                              Matrix(1, 3, [tmp1, tmp2, tmp3]),
                              self.bias)
        print('test(%s):result=%d w1=%f w2=%f w3=%f' % (x, activation, tmp1, tmp2, tmp3))
        return activation
if __name__ == '__main__':
    # Training set: three +/-1 samples and the desired output for each.
    samples = [[-1, -1, -1], [1, 1, 1], [1, -1, -1]]
    targets = [-1, 1, 1]
    nn = NN()
    nn.learn(samples, targets)
    # Evaluate all 2**3 possible +/-1 input vectors, in the order
    # [1, 1, 1], [1, 1, -1], ..., [-1, -1, -1].
    for x in [[a, b, c] for a in (1, -1) for b in (1, -1) for c in (1, -1)]:
        nn.test(x)
驗收訓練成果（執行結果）：
test([1, 1, 1]):result=1 w1=0.663924 w2=0.276574 w3=0.333518
test([1, 1, -1]):result=1 w1=0.663924 w2=0.276574 w3=0.333518
test([1, -1, 1]):result=1 w1=0.663924 w2=0.276574 w3=0.333518
test([1, -1, -1]):result=1 w1=0.663924 w2=0.276574 w3=0.333518
test([-1, 1, 1]):result=-1 w1=0.663924 w2=0.276574 w3=0.333518
test([-1, 1, -1]):result=-1 w1=0.663924 w2=0.276574 w3=0.333518
test([-1, -1, 1]):result=-1 w1=0.663924 w2=0.276574 w3=0.333518
test([-1, -1, -1]):result=-1 w1=0.663924 w2=0.276574 w3=0.333518
程式碼貼上後縮排有點亂，
有誰能指點一下？