RISC-V MCU中文社区

【分享】 使用matlab搭建BP从零搭建BP神经网络完成鸢尾花数据集分类

发表于 全国大学生集成电路创新创业大赛 2023-05-26 10:11:23
0
1341
0

小组名:啊啊对对对队
报名编号:CICC2400
iris_training.mat文件如下
链接:https://pan.baidu.com/s/14vb1c0noPB4YKCCdOCsofA?pwd=ozmz
提取码:ozmz
不赘述,正确率96.7%

  1. load("iris_training.mat")
  2. X = [iristraining(1:30,1:4);iristraining(41:70,1:4);iristraining(81:110,1:4)];
  3. D = [iristraining(1:30,5);iristraining(41:70,5);iristraining(81:110,5)];
  4. X_test = [iristraining(31:40,1:4);iristraining(71:80,1:4);iristraining(111:120,1:4)];
  5. D_test = [iristraining(31:40,5);iristraining(71:80,5);iristraining(111:120,5)];
  6. %%Bp neural network
  7. %三层bp神经网络,四维特征,三分类问题
  8. %最后三个节点的输出分别代表三个类别,当输入类别为i是,当且仅当第i个神经元输出大于0.5时,认为判别正确
  9. %%bp神经网络共分为三层,输入层,中间层,输出层,其中输入层4节点,中间层4节点,输出层3节点
  10. w1 = 2*rand(5,4)-1; %连接权重
  11. w2 = 2*rand(5,4)-1;
  12. w3 = 2*rand(5,3)-1;
  13. s = 0.01; %误差
  14. a = 0.05; %学习率
  15. d_correct = 0.9;
  16. d_wrong = 0.1;
  17. err = 1;
  18. gen = 0;
  19. %BP训练
  20. while err>s && gen<1000
  21. gen = gen + 1
  22. err = 0;
  23. for i = 1:90
  24. %前馈
  25. for m = 1:4
  26. x(m) = logsig(w1(1,m)*X(i,1) + w1(2,m)*X(i,2) + w1(3,m)*X(i,3) + w1(4,m)*X(i,4) - w1(5,m)); %输入层
  27. end
  28. for m = 1:4
  29. y(m) = logsig(w2(1,m)*x(1) + w2(2,m)*x(2) + w2(3,m)*x(3) + w2(4,m)*x(4) - w2(5,m)); %中间层
  30. end
  31. for m = 1:3
  32. z(m) = logsig(w3(1,m)*y(1) + w3(2,m)*y(2) + w3(3,m)*y(3) + w3(4,m)*y(4) - w3(5,m)); %输出层
  33. end
  34. for m =1:3
  35. if (D(i)+1)==m
  36. err = err + (z(m)-d_correct)^2;
  37. else
  38. err = err + (z(m)-d_wrong)^2;
  39. end
  40. end
  41. %反馈
  42. for m = 1:3
  43. delta_w3(m) = z(m)*(1-z(m))*(z(m)-(d_wrong+(d_correct-d_wrong)*((D(i)+1)==m)));
  44. end
  45. for m = 1:4
  46. delta_w2(m) = y(m)*(1-y(m))*(delta_w3(1)*w3(m,1) + delta_w3(2)*w3(m,2) + delta_w3(3)*w3(m,3));
  47. end
  48. for m = 1:4
  49. delta_w1(m) = x(m)*(1-x(m))*(delta_w2(1)*w2(m,1) + delta_w2(2)*w2(m,2) + delta_w2(3)*w2(m,3) + delta_w2(4)*w2(m,4));
  50. end
  51. %更新网络权重
  52. w1 = w1 - a*[X(i,:) -1]'*delta_w1;
  53. w2 = w2 - a*[x -1]'*delta_w2;
  54. w3 = w3 - a*[y -1]'*delta_w3;
  55. end
  56. err = err/270;
  57. end
  58. %测试效果
  59. correct = 0;
  60. for i = 1:30
  61. corr = 1;
  62. %前馈
  63. for m = 1:4
  64. x(m) = logsig(w1(1,m)*X_test(i,1) + w1(2,m)*X_test(i,2) + w1(3,m)*X_test(i,3) + w1(4,m)*X_test(i,4) - w1(5,m)); %输入层
  65. end
  66. for m = 1:4
  67. y(m) = logsig(w2(1,m)*x(1) + w2(2,m)*x(2) + w2(3,m)*x(3) + w2(4,m)*x(4) - w2(5,m)); %中间层
  68. end
  69. for m = 1:3
  70. z(m) = logsig(w3(1,m)*y(1) + w3(2,m)*y(2) + w3(3,m)*y(3) + w3(4,m)*y(4) - w3(5,m)); %输出层
  71. end
  72. for m = 1:3
  73. if m==(D_test(i)+1) && z(m)<0.5
  74. corr = 0;
  75. end
  76. if m~=(D_test(i)+1) && z(m)>0.5
  77. corr = 0;
  78. end
  79. end
  80. if corr==1
  81. correct = correct + 1;
  82. end
  83. end
  84. correct = correct/30;
  85. fprintf("correction is %d%%",correct*100)
喜欢0
用户评论
铭………

铭……… 实名认证

凭来去

积分
问答
粉丝
关注
  • RV-STAR 开发板
  • RISC-V处理器设计系列课程
  • 培养RISC-V大学土壤 共建RISC-V教育生态
RV-STAR 开发板