本站原创文章,转载请说明来自《老饼讲解-BP神经网络》bp.bbbdata.com
本文是笔者细扒matlab神经网络工具箱newgrnn的源码后,
去除冗余代码,重现的简版newgrnn代码,代码与newgrnn的结果完全一致。
通过本代码的学习,可以完全细节的了解广义回归神经网络的实现逻辑。
代码包含了三个函数:
testGrnnNet:测试用例主函数,直接运行时就是执行该函数。
1、数据:生成一个2输入1输出的训练数据,
2、用自写的函数构建一个广义回归网络,与用网络进行预测。
3、使用工具箱newgrnn训练一个广义回归网络。比较自写函数与工具箱训练结果是否一致(比较测试数据的输出结果是否一致)
trainGrnnNet:训练主函数。
训练一个广义回归神经网络。
predictGrnnNet:预测主函数。
传入需要预测的X,与网络的权重矩阵,即可得到预测结果。
1、训练好的网络参数
2、与工具箱结果的比较
从运行结果可以看到,自写代码与工具箱的结果一样,说明扒出的逻辑与工具箱的一致。
需要的(最小)输入输出
● 待求的网络参数
(1) W21:输入到隐层的权重
(2) B2:隐层阈值
(3) W32:隐层到输出层的权重
● 需要的输入
(1) X:训练数据的输入
(2) y:训练数据的输出
(3) spread:一个用于控制径向基的肥瘦的参数。
广义回归构建流程
广义神经网络,不需要任何训练,只需要把输入数据存成网络的参数即可:
W21(输入层-->隐层权重):用X作为w21即可
B2(隐层阈值):用spread生成,
W32(隐层-->输出层层权重):用y作为w32即可
广义回归预测流程
计算y的公式格式参考如下Demo:
1. 计算各个隐节点的输出:
即计算各个exp的值。
先计算它到各个W21(也即训练样本的X)的距离,再将距离乘以B2(即a),
再经激活函数转换(即套上径向基函数),这样就得到隐节点的输出。
2. 隐节点归一化
将各个隐节点/所有隐节点的和
3.得到最后输出:
W32(即训练数据的y)乘以归一化后的隐节点(即y乘以权重),求和后就是最终的输出。
matlab2014b亲测已跑通:
function testGrnnNet()
%本代码来自 bp.bbbdata.com
%本代码模仿matlab神经网络工具箱的newgrnn神经网络,用于训练《广义回归神经网络》,
%代码主旨用于教学,供大家学习理解newgrnn神经网络原理
%--------生成训练数据-------------------
x1 = 1:1:10;
X = [ x1; x1];
y = sin(X(1, :)) + X( 2, :);
test_x = [2 3]'; %测试数据
%---------参数预设----------------------
spread = 2; %扩展系数
%------调用自写Grnn函数获得广义回归神经网络-----------------
[w21,b2,w32] = trainGrnnNet( X,y,spread )
py = predictGrnnNet(w21,b2,w32,test_x) %模型预测
%------调用matlab神经网络工具箱训练广义回归神经网络
net=newgrnn(X,y,spread); % 用工具箱设计广义回归网络
pyByBox = sim(net, test_x) % 工具箱对测试数据的预测结果
% -------检查自写代码与工具箱的结果是否一致------------------------------
testResult = isequal( py, pyByBox);
disp(['testIsequal = ',num2str(testResult)]);
web('bp.bbbdata.com')
end
% 广义神经网络的生成函数
function [w21,b2,w32] = trainGrnnNet(X,y,spread)
%生成广义神经网络只要将输入输出存到w21,w32中,
%再用spread生成影响径向基宽度的b2就可以
w21 = X';
w32 = y;
b2 = ones( size(X,2), 1)*sqrt( -log(.5))/spread;
end
% 广义神经网络的预测函数
function y = predictGrnnNet(w21,b2,w32,X)
y = [];
for i = 1:size(X,2)
cur_x = X(:,i);
hv = b2.*sqrt(sum((ones(size(w21,1),1)*cur_x' - w21).^2,2)); % 计算隐节点的值
ha = exp(-(hv.*hv)); % 计算隐节点激活值
cur_y = w32*(ha./sum(ha)); % 计算输出
y = [y,cur_y];
end
end
阅读代码,调试代码,才能真正的理解神经网络算法的原理,希望本文能够帮助大家进一步掌握神经网络原理.
End