老饼讲解-神经网络 机器学习 神经网络 深度学习
LVQ神经网络

【代码】使用matlab训练一个LVQ神经网络

作者 : 老饼 发表日期 : 2023-03-25 00:18:40 更新日期 : 2023-05-12 03:37:11
本站原创文章,转载请说明来自《老饼讲解-BP神经网络》www.bbbdata.com



LVQ神经网络是用于样本分类的一个常用算法

本文先简单回顾LVQ神经网络是什么,

然后展示如何用matlab工具箱来训练一个LVQ神经网络




  01. LVQ神经网络简要回顾  



本节回顾LVQ神经网络的思想和关键知识



     LVQ神经网络简介    


 LVQ神经网络是什么
LVQ用于解决分类问题,
它先对每个类别都初始化一些类别判别中心点
然后通过训练来调整这些类别判别中心的位置
使它们能较好地识别训练样本
 
这样,来了新样本,只要判断新样本离哪个聚类中心点近
就判断样本属于该聚类中心点所代表的类别

  LVQ神经网络的拓扑表示
 
LVQ一般用一个三层神经网络来表示,
它的拓扑结构如下:

 
 
其中,每个隐层节点代表着一个类别判别中心,
它与输入层的权重就是它的位置,
它的输出层的连接代表着它是哪个类别的判别中心
 例如,某个隐节点的输入权重为[0.3 0.5],输出权重为[0 1]
则代表它的位置为[0.3,0.5], 是类别1的判别中心







  02. 如何使用matlab训练一个LVQ神经网络   



本节讲解如何用matlab工具箱来训练一个LVQ神经网络



    matlab工具箱实现LVQ的代码   


下面以一个例子,讲述如何用matlab工具箱实现LVQ神经网络
代码如下:
%代码说明:matlab工具箱训练一个LVQ神经网络
%来自《老饼讲解神经网络》www.bbbdata.com ,matlab版本:2018a
%数据准备
clear all ;close all 
rand('seed',70)
P = [-3 -2 -2  0  0.5  -0.5  0 +2 +2 +3; ...
    0 +1 -1 +2 +1 -1 -2 +1 -1  0];                    % 输入数据
Tc = [1 1 1 2 2 2 2 1 1 1];                           % 输出类别
T = ind2vec(Tc);                                      % 将输出转为one-hot编码(代表类别的01向量)

%网络训练
net = newlvq(P,4,[0.5 ,0.5],0.01,'learnlv1');         % 建立一个LVQ神经网络,用lvq1规则训练
net = train(net,P,T);                                 % 训练神经网络
%预测
Y = sim(net,P);                                       % 预测(one-hot形式)
Yc = vec2ind(Y);                                      % 将one-hot编码形式转回类别编号形式
% 提取出各个类别的判别中心                            
c       = net.iw{1,1};                                % 中心
c_class = net.lw{2,1};                                % 中心所属类别
c       = [vec2ind(c_class)',c]                       % 添加中心的类别标签


% -------绘制结果-----------------
figure
% 绘制原始数据
subplot(2,1,1)
plot(P(1,Tc==1),P(2,Tc==1),'o','MarkerEdgeColor','k','MarkerFaceColor','b','MarkerSize',10)
hold on 
plot(P(1,Tc==2),P(2,Tc==2),'o','MarkerEdgeColor','k','MarkerFaceColor','g','MarkerSize',10)
legend('类别1','类别2')
title('原始数据类别')
% 绘制预测结果
subplot(2,1,2)
plot(P(1,Yc==1),P(2,Yc==1),'o','MarkerEdgeColor','k','MarkerFaceColor','b','MarkerSize',10)
hold on 
plot(P(1,Yc==2),P(2,Yc==2),'o','MarkerEdgeColor','k','MarkerFaceColor','g','MarkerSize',10)
hold on 
% 绘制网络的隐节点(类别判别中心)
plot(c(:,2),c(:,3),'o','MarkerEdgeColor','k','MarkerFaceColor','y','MarkerSize',10)
for i = 1: size(c,1)
text(c(i,2)-0.050,c(i,3)+0.02,num2str(c(i,1)))
end
title('LVQ预测类别')



    关键代码解说    


其中,核心代码为
net = newlvq(P,4,[0.5 ,0.5],0.01,'learnlv1');
它用于构建一个LVQ神经网络,
其中,
👉 1. P是训练数据的输入                                                   
 👉 2. 4代表我们使用4个隐节点                                           
  也就使用4个类别判别中心                    
 👉 3. [0.5,0.5]代表上述4个隐节点的类别分配比例               
  也就是类别1、类别2的判别中心各2个          
 
 👉 4. 0.01是学习率                                                               
  👉 5. 'learnlv1'则指定了训练方法                                         





      训练结果    


运行上述代码,得到结果如下
  
可见,训练后的LVQ神经网络的预测类别与真实样本一致,
它已经可以准确的对训练样本进行分类

 关于类别判别中心的位置
上述代码的还打印了类别判别中心的信息,如下
 
其中,每一行代表一个判别中心
第一列表示是哪一个类别的判别中心,     
第2、3列表示判别中心的坐标                 









 End 




联系老饼