老饼讲解-神经网络
BP神经网络
公式提取
打印BP表达式Demo
作者 : 老饼 日期 : 2022-06-09 05:22:59 更新 : 2022-06-29 01:31:59
本站原创文章,转载请说明来自《老饼讲解-BP神经网络》bp.bbbdata.com




布署生产,或者写报告等需要,有时我们需要把神经网络的表达式打印出来,

自己对着权重阈值慢慢拼凑无疑是非常繁琐和低效率的,这就需要我们程序化,

本文提供一个打印三层BP神经网络模型表达式的Demo代码。



  01. 代码结构  

代码主要包含了三个函数:


printBpModel:打印表达式的功能函数。


只支持三层BP神经网络,网络可以是多输入,多输出。


testPrintBpModel:示范如何调用printBpModel进行表达式打印。


在本示范中,先训练一个网络,然后把网络的权值阈值传到printBpModel中打印表达式。




 02. 代码效果 


运行代码后打印的表达式格式如下:







 三.代码  

matlab2014b亲测已跑通:

%本代码展示如何打印三层BP神经网络的数学表达式
%转载请说明来自 《老饼讲解神经网络》 bp.bbbdata.com
function testPrintBpModel()
%------------ 原始数据------------------
X= [linspace(-3,3,10);linspace(-2,1.5,10)];                        % 生成输入数据
y = [10*sin(X(1,:))+0.2*X(1,:).^2;10*sin(X(2,:))+0.2*X(1,:).^2];   % 生成y
setdemorandstream(88);                                      
%------------网络训练 ---------------------------
net            = newff(X,y,3,{'tansig','purelin'},'trainlm');
[net,tr,net_y] = train(net,X,y);             % 调用matlab神经网络工具箱自带的train函数训练网络y,net返回
% 提取权重WB
W12 = net.iw{1,1};                           % 第1层(输入层)到第2层(隐层)的权值
B2  = net.b{1};                              % 第2层(隐层)的神经元阈值
W23 = net.lw{2,1};                           % 第2层(输入层)到第3层(输出层)的权值
B3  = net.b{2};                              % 第3层(输出层)的神经元阈值

%------------调用printBpModel打印网络表达式 ---------------------------
fcell = printBpModel(W12,B2,W23,B3);         
end

function fcell =printBpModel(W12,B2,W23,B3)
% 打印三层BP神经网络的数学表达式
[hn,xn] = size(W12);          % 隐节点个数,输入个数
[yn,~]  = size(W23);          % 输出个数
fcell   = cell(hn+2,yn);      % 初始化表达式cell
for t = 1:yn
    fcell{1,t}   = ['y',num2str(t),'='];
    %------逐个隐节点拼装------
    for i=1:hn
        % ----tansig头----
        cur_tansig = [num2str(W23(t,i)),'*tansig('];
        if(i>1 && W23(t,i)>=0)
            cur_tansig =['+',cur_tansig] ;
        end
        
        %--逐个x拼装-----
        for j = 1: xn
            sign_str = '';
            if(j>1 && W12(i,j)>0 )
                sign_str='+';
            end
            cur_tansig = [cur_tansig,sign_str,num2str(W12(i,j)),'*x',num2str(j)];
        end
        
        % 拼装阈值
        sign_str = '';
        if( B2(i)>0 )
            sign_str='+';
        end
        cur_tansig = [cur_tansig,sign_str,num2str(B2(i)),')'];
        fcell{i+1,t}= cur_tansig;
    end
    % ----拼装输出层阈值------
    sign_str = '';
    if( B3(t)>=0 )
        sign_str='+';
    end
    fcell{hn+2,t}=[sign_str,num2str(B3(t))];
end
% ----------打印---------
[rn,cn] = size(fcell);
for i = 1:cn
    for j =1:rn
    disp(fcell{j,i})
    end
    disp(' ')
end
end








 End 




联系小饼