我打算開發類神經網路，透過訓練來解讀股價走勢。
以下程式碼改寫自別人的範例，
是我蹲馬步、練基本功的初步成果：
最基本的神經網路——感知器（Perceptron）。
import random, math
import string
from sympy import *
def hardlims(inputM, weightM, bias):
    """Symmetric hard-limit activation: return -1 if the net input is negative, else 1.

    The net input is ``n = weightM * inputM - bias``.

    Args:
        inputM: Input column.  Either a sympy ``Matrix`` (``NN`` passes
            ``Matrix(3, 1, x)``) or a plain list/tuple.
        weightM: Weight row.  Either a sympy ``Matrix`` (``NN`` passes
            ``Matrix(1, 3, [w1, w2, w3])``) or a plain list/tuple with the
            same length as ``inputM``.
        bias: Threshold subtracted from the weighted sum.

    Returns:
        int: ``-1`` when ``n < 0``, otherwise ``1`` (so ``n == 0`` maps to 1).

    Raises:
        ValueError: if plain sequences of different lengths are given.
    """
    if isinstance(inputM, (list, tuple)) and isinstance(weightM, (list, tuple)):
        # Plain-Python path: an ordinary dot product, no sympy required.
        if len(inputM) != len(weightM):
            raise ValueError(
                'inputM and weightM must have the same length (%d != %d)'
                % (len(inputM), len(weightM)))
        n = sum(w * x for w, x in zip(weightM, inputM))
    else:
        # Matrix path: (1 x k) * (k x 1) yields a 1 x 1 matrix; take its entry.
        n = (weightM * inputM)[0]
        # sympy numbers expose evalf(); evaluate to a numeric Float.
        if hasattr(n, 'evalf'):
            n = n.evalf()
    n = n - bias
    # One if/else so every n produces a result.
    return -1 if n < 0 else 1
def sum_array(a):
    """Return the sum of the numbers in iterable *a* (``0`` when *a* is empty).

    Kept as a named wrapper around the built-in ``sum`` for existing callers.
    """
    return sum(a)
class NN:
    """Three-input perceptron whose weights are averaged over a sliding history.

    ``w1``, ``w2`` and ``w3`` each hold a 10-entry history of values for one
    input weight.  The weight actually used for classification is the mean
    of that history.  ``bias`` is subtracted from the weighted sum inside
    ``hardlims``.  ``learn`` never updates it, so with the initial value of 0
    the learned decision boundary passes through the origin.
    """

    def __init__(self):
        # Ten random starting values per weight, each drawn from [0.0, 1.0).
        self.w1 = [random.random() for _ in range(10)]
        self.w2 = [random.random() for _ in range(10)]
        self.w3 = [random.random() for _ in range(10)]
        self.bias = 0

    def _weights(self):
        """Return the effective weights ``[w1, w2, w3]``: the mean of each history."""
        return [sum_array(history) / len(history)
                for history in (self.w1, self.w2, self.w3)]

    def learn(self, input, target, max_iterations=10000):
        """Train with the perceptron rule until one epoch has (near) zero error.

        Each epoch starts from the averaged weights.  It then applies
        ``w += error * learningRate * x`` per sample, with a fresh random
        ``learningRate`` in [0, 0.1) for every sample.

        Args:
            input: Sequence of training samples, each a 3-element sequence.
                (The name shadows the built-in ``input``; it is kept so that
                keyword callers keep working.)
            target: Desired output for each sample, same length as ``input``.
            max_iterations: Maximum number of epochs before giving up, or
                ``None`` for no limit.  Prevents an endless loop when the
                samples cannot be separated (the bias is not learned).

        Returns:
            bool: ``True`` once an epoch ends with total absolute error
            <= 1e-6.  ``False`` if ``max_iterations`` epochs ran without
            reaching that.

        Raises:
            ValueError: if ``input`` and ``target`` differ in length.
        """
        if len(input) != len(target):
            raise ValueError(
                'input and target must have the same length (%d != %d)'
                % (len(input), len(target)))
        epsilon = 1e-6
        iterations = 0
        errorAbsSum = 1  # any value > epsilon so that the first epoch runs
        while abs(errorAbsSum) > epsilon:
            if max_iterations is not None and iterations >= max_iterations:
                return False
            iterations += 1
            errorAbsSum = 0
            tmp1, tmp2, tmp3 = self._weights()
            for x, t in zip(input, target):
                activation = hardlims(Matrix(3, 1, x),
                                      Matrix(1, 3, [tmp1, tmp2, tmp3]),
                                      self.bias)
                error = t - activation
                learningRate = random.random() / 10
                tmp1 += error * learningRate * x[0]
                tmp2 += error * learningRate * x[1]
                tmp3 += error * learningRate * x[2]
                errorAbsSum += abs(error)
                print('Learning:w1=%f w2=%f w3=%f learningRate=%f error=%f activation=%f'
                      % (tmp1, tmp2, tmp3, learningRate, error, activation))
            if abs(errorAbsSum) > epsilon:
                # Slide each history window: drop the oldest value and append
                # this epoch's result.  This is skipped on the converged epoch.
                # With -1/1 targets that epoch made no updates, so its weights
                # already equal the history mean.  Appending them would shift
                # the mean, and test() would no longer use the weights that
                # just classified every sample correctly.
                del self.w1[0], self.w2[0], self.w3[0]
                self.w1.append(tmp1)
                self.w2.append(tmp2)
                self.w3.append(tmp3)
            print('Iterations: %d \n w1=%f w2=%f w3=%f' % (iterations, tmp1, tmp2, tmp3))
        return True

    def test(self, x):
        """Classify one 3-element input with the averaged weights.

        Prints the result together with the weights used.

        Returns:
            int: the ``hardlims`` activation (-1 or 1).
        """
        tmp1, tmp2, tmp3 = self._weights()
        activation = hardlims(Matrix(3, 1, x),
                              Matrix(1, 3, [tmp1, tmp2, tmp3]),
                              self.bias)
        print('test(%s):result=%d w1=%f w2=%f w3=%f' % (x, activation, tmp1, tmp2, tmp3))
        return activation
if __name__ == '__main__':
    # Three labelled training samples for the perceptron.
    training_inputs = [[-1, -1, -1], [1, 1, 1], [1, -1, -1]]
    training_targets = [-1, 1, 1]

    perceptron = NN()
    perceptron.learn(training_inputs, training_targets)

    # Probe every +/-1 combination of three inputs, ordered from
    # [1, 1, 1] down to [-1, -1, -1].
    signs = (1, -1)
    for probe in [[a, b, c] for a in signs for b in signs for c in signs]:
        perceptron.test(probe)
驗收訓練成果（以下為執行輸出）：
test([1, 1, 1]):result=1 w1=0.663924 w2=0.276574 w3=0.333518
test([1, 1, -1]):result=1 w1=0.663924 w2=0.276574 w3=0.333518
test([1, -1, 1]):result=1 w1=0.663924 w2=0.276574 w3=0.333518
test([1, -1, -1]):result=1 w1=0.663924 w2=0.276574 w3=0.333518
test([-1, 1, 1]):result=-1 w1=0.663924 w2=0.276574 w3=0.333518
test([-1, 1, -1]):result=-1 w1=0.663924 w2=0.276574 w3=0.333518
test([-1, -1, 1]):result=-1 w1=0.663924 w2=0.276574 w3=0.333518
test([-1, -1, -1]):result=-1 w1=0.663924 w2=0.276574 w3=0.333518
縮排有點亂，
有誰能指點一下嗎？