// NeuralSystem.cs
  using System;

  namespace NeuralNetworkPCL
  {
      /// <summary>
      /// Trains a <c>NeuralNetwork</c> on the points of a <c>DataCollection</c> by
      /// backpropagation, and reports per-iteration progress through an <c>IHM</c>.
      /// </summary>
      public class NeuralSystem
      {
          private readonly DataCollection data;
          private readonly NeuralNetwork network;
          private readonly IHM ihm;

          // Training configuration; adjustable through the setter methods below.
          private double learningRate = 0.3;
          private double maxError = 0.005;
          private int maxIterations = 10001;

          /// <summary>
          /// Builds the data collection and the network to be trained.
          /// </summary>
          /// <param name="_nbInputs">Input count passed to the <c>NeuralNetwork</c> constructor.</param>
          /// <param name="_nbHidden">Hidden-layer count passed to the <c>NeuralNetwork</c> constructor.</param>
          /// <param name="_nbOutputs">Output count passed to both the <c>DataCollection</c> and <c>NeuralNetwork</c> constructors.</param>
          /// <param name="_data">Raw data lines handed to <c>DataCollection</c> for parsing.</param>
          /// <param name="_trainingRatio">
          /// Passed to <c>DataCollection</c>; presumably the share of points used for training
          /// versus generalisation (TODO confirm in <c>DataCollection</c>).
          /// </param>
          /// <param name="_ihm">Receives the progress messages printed by <see cref="Run"/>; must not be null.</param>
          /// <exception cref="ArgumentNullException"><paramref name="_ihm"/> is null.</exception>
          public NeuralSystem(int _nbInputs, int _nbHidden, int _nbOutputs, string[] _data, double _trainingRatio, IHM _ihm)
          {
              // Fail fast here rather than with a NullReferenceException in the middle of Run().
              if (_ihm == null)
              {
                  throw new ArgumentNullException("_ihm");
              }

              data = new DataCollection(_data, _nbOutputs, _trainingRatio);
              network = new NeuralNetwork(_nbInputs, _nbHidden, _nbOutputs);
              ihm = _ihm;
          }

          /// <summary>
          /// Sets the learning rate passed to each weight adjustment.
          /// Note that <see cref="Run"/> halves this value in place whenever the training error
          /// does not decrease.
          /// </summary>
          /// <param name="_rate">New learning rate.</param>
          public void LearningRate(double _rate)
          {
              learningRate = _rate;
          }

          /// <summary>
          /// Sets the stopping threshold: training ends once the summed squared training error
          /// is no longer greater than this value.
          /// </summary>
          /// <param name="_error">New error threshold.</param>
          public void MaximumError(double _error)
          {
              maxError = _error;
          }

          /// <summary>
          /// Sets the maximum number of training iterations (epochs) performed by <see cref="Run"/>.
          /// </summary>
          /// <param name="_iterations">New iteration cap.</param>
          public void MaximumIterations(int _iterations)
          {
              maxIterations = _iterations;
          }

          /// <summary>
          /// Runs training epochs until one of these stop conditions is met:
          /// the iteration cap is reached, the summed squared training error drops to or below
          /// the maximum error, or the summed squared error on the generalisation points rises
          /// compared with the previous epoch.
          /// </summary>
          /// <remarks>
          /// Whenever the training error does not decrease, the configured learning rate is halved
          /// in place, so a later call to <c>Run</c> starts from the decayed value.
          /// When training stops because the generalisation error rose, the weights from that last
          /// epoch are kept: there is no rollback to the previous, better-generalising weights.
          /// </remarks>
          public void Run()
          {
              double totalError = double.PositiveInfinity;
              double totalGeneralisationError = double.PositiveInfinity;
              bool betterGeneralisation = true;

              for (int i = 0; i < maxIterations && totalError > maxError && betterGeneralisation; i++)
              {
                  double oldError = totalError;
                  double oldGeneralisationError = totalGeneralisationError;

                  // Training pass: measure each point's error, then update the weights by backpropagation.
                  totalError = 0;
                  foreach (DataPoint point in data.Points())
                  {
                      double[] outputs = network.Evaluate(point);
                      totalError = AccumulateSquaredError(totalError, point, outputs);
                      network.AdjustWeights(point, learningRate);
                  }

                  // Generalisation pass: evaluation only, no weight updates.
                  totalGeneralisationError = 0;
                  foreach (DataPoint point in data.GeneralisationPoints())
                  {
                      double[] outputs = network.Evaluate(point);
                      totalGeneralisationError = AccumulateSquaredError(totalGeneralisationError, point, outputs);
                  }

                  // Early stopping: a rising generalisation error ends training after this iteration.
                  if (totalGeneralisationError > oldGeneralisationError)
                  {
                      betterGeneralisation = false;
                  }

                  // No progress on the training set: halve the learning rate.
                  if (totalError >= oldError)
                  {
                      learningRate = learningRate / 2.0;
                  }

                  // Progress report. "Mean" is sqrt(totalError / number of training points),
                  // i.e. an RMS-style figure per training point.
                  string mean = string.Format("{0:0.00}", Math.Sqrt(totalError / data.Points().Length));
                  ihm.PrintMsg("Iteration n°" + i + " - Total error : " + totalError + " - Gener Error : " + totalGeneralisationError + " - Rate : " + learningRate + " - Mean : " + mean);
              }
          }

          /// <summary>
          /// Adds the squared differences between the expected outputs of <paramref name="point"/>
          /// and the network's <paramref name="outputs"/> to <paramref name="total"/>.
          /// </summary>
          /// <param name="total">Running total to add to.</param>
          /// <param name="point">Data point supplying the expected outputs.</param>
          /// <param name="outputs">Network outputs for <paramref name="point"/>; their length drives the loop.</param>
          /// <returns>The updated running total.</returns>
          private static double AccumulateSquaredError(double total, DataPoint point, double[] outputs)
          {
              // Add each term straight into the running total (instead of summing per point first)
              // so the floating-point result matches the original inline accumulation bit for bit.
              for (int outNb = 0; outNb < outputs.Length; outNb++)
              {
                  double error = point.Outputs[outNb] - outputs[outNb];
                  total += error * error;
              }
              return total;
          }
      }
  }