喜欢0次
小组名:啊啊对对对队
报名编号:CICC2400
iris_training.mat文件如下
链接:https://pan.baidu.com/s/14vb1c0noPB4YKCCdOCsofA?pwd=ozmz
提取码:ozmz
不赘述,正确率96.7%
load("iris_training.mat")
X = [iristraining(1:30,1:4);iristraining(41:70,1:4);iristraining(81:110,1:4)];
D = [iristraining(1:30,5);iristraining(41:70,5);iristraining(81:110,5)];
X_test = [iristraining(31:40,1:4);iristraining(71:80,1:4);iristraining(111:120,1:4)];
D_test = [iristraining(31:40,5);iristraining(71:80,5);iristraining(111:120,5)];
%%Bp neural network
%三层bp神经网络,四维特征,三分类问题
%最后三个节点的输出分别代表三个类别,当输入类别为i是,当且仅当第i个神经元输出大于0.5时,认为判别正确
%%bp神经网络共分为三层,输入层,中间层,输出层,其中输入层4节点,中间层4节点,输出层3节点
w1 = 2*rand(5,4)-1; %连接权重
w2 = 2*rand(5,4)-1;
w3 = 2*rand(5,3)-1;
s = 0.01; %误差
a = 0.05; %学习率
d_correct = 0.9;
d_wrong = 0.1;
err = 1;
gen = 0;
%BP训练
while err>s && gen<1000
gen = gen + 1
err = 0;
for i = 1:90
%前馈
for m = 1:4
x(m) = logsig(w1(1,m)*X(i,1) + w1(2,m)*X(i,2) + w1(3,m)*X(i,3) + w1(4,m)*X(i,4) - w1(5,m)); %输入层
end
for m = 1:4
y(m) = logsig(w2(1,m)*x(1) + w2(2,m)*x(2) + w2(3,m)*x(3) + w2(4,m)*x(4) - w2(5,m)); %中间层
end
for m = 1:3
z(m) = logsig(w3(1,m)*y(1) + w3(2,m)*y(2) + w3(3,m)*y(3) + w3(4,m)*y(4) - w3(5,m)); %输出层
end
for m =1:3
if (D(i)+1)==m
err = err + (z(m)-d_correct)^2;
else
err = err + (z(m)-d_wrong)^2;
end
end
%反馈
for m = 1:3
delta_w3(m) = z(m)*(1-z(m))*(z(m)-(d_wrong+(d_correct-d_wrong)*((D(i)+1)==m)));
end
for m = 1:4
delta_w2(m) = y(m)*(1-y(m))*(delta_w3(1)*w3(m,1) + delta_w3(2)*w3(m,2) + delta_w3(3)*w3(m,3));
end
for m = 1:4
delta_w1(m) = x(m)*(1-x(m))*(delta_w2(1)*w2(m,1) + delta_w2(2)*w2(m,2) + delta_w2(3)*w2(m,3) + delta_w2(4)*w2(m,4));
end
%更新网络权重
w1 = w1 - a*[X(i,:) -1]'*delta_w1;
w2 = w2 - a*[x -1]'*delta_w2;
w3 = w3 - a*[y -1]'*delta_w3;
end
err = err/270;
end
%测试效果
correct = 0;
for i = 1:30
corr = 1;
%前馈
for m = 1:4
x(m) = logsig(w1(1,m)*X_test(i,1) + w1(2,m)*X_test(i,2) + w1(3,m)*X_test(i,3) + w1(4,m)*X_test(i,4) - w1(5,m)); %输入层
end
for m = 1:4
y(m) = logsig(w2(1,m)*x(1) + w2(2,m)*x(2) + w2(3,m)*x(3) + w2(4,m)*x(4) - w2(5,m)); %中间层
end
for m = 1:3
z(m) = logsig(w3(1,m)*y(1) + w3(2,m)*y(2) + w3(3,m)*y(3) + w3(4,m)*y(4) - w3(5,m)); %输出层
end
for m = 1:3
if m==(D_test(i)+1) && z(m)<0.5
corr = 0;
end
if m~=(D_test(i)+1) && z(m)>0.5
corr = 0;
end
end
if corr==1
correct = correct + 1;
end
end
correct = correct/30;
fprintf("correction is %d%%",correct*100)