| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495 |
- using System;
namespace NeuralNetworkPCL
{
    /// <summary>
    /// Trains a <see cref="NeuralNetwork"/> on a <see cref="DataCollection"/> and reports
    /// progress through an <see cref="IHM"/> after every iteration.
    /// </summary>
    public class NeuralSystem
    {
        private readonly DataCollection data;
        private readonly NeuralNetwork network;
        private readonly IHM ihm;

        // Configuration (defaults; override through the setter methods before calling Run).
        private double learningRate = 0.3;
        private double maxError = 0.005;
        private int maxIterations = 10001;

        /// <summary>
        /// Builds the data collection and the network.
        /// </summary>
        /// <param name="_nbInputs">Number of network inputs.</param>
        /// <param name="_nbHidden">Number of hidden neurons.</param>
        /// <param name="_nbOutputs">Number of network outputs; also forwarded to <see cref="DataCollection"/>.</param>
        /// <param name="_data">Raw data lines, handed to <see cref="DataCollection"/> for parsing.</param>
        /// <param name="_trainingRatio">
        /// Forwarded to <see cref="DataCollection"/>; presumably the fraction of points used for
        /// training versus generalisation — TODO confirm against DataCollection.
        /// </param>
        /// <param name="_ihm">Receives one progress message per training iteration.</param>
        public NeuralSystem(int _nbInputs, int _nbHidden, int _nbOutputs, string[] _data, double _trainingRatio, IHM _ihm)
        {
            data = new DataCollection(_data, _nbOutputs, _trainingRatio);
            network = new NeuralNetwork(_nbInputs, _nbHidden, _nbOutputs);
            ihm = _ihm;
        }

        /// <summary>Sets the initial learning rate used by <see cref="Run"/>.</summary>
        public void LearningRate(double _rate)
        {
            learningRate = _rate;
        }

        /// <summary>Sets the total training error below which <see cref="Run"/> stops.</summary>
        public void MaximumError(double _error)
        {
            maxError = _error;
        }

        /// <summary>Sets the maximum number of training iterations performed by <see cref="Run"/>.</summary>
        public void MaximumIterations(int _iterations)
        {
            maxIterations = _iterations;
        }

        /// <summary>
        /// Runs online back-propagation training. Each iteration performs one pass over the
        /// training points (adjusting weights after each point), then measures the error on the
        /// generalisation points. Training stops when the iteration limit is reached, when the
        /// summed squared training error drops to <c>maxError</c> or below, or when the
        /// generalisation error increases compared with the previous iteration.
        /// The learning rate is halved whenever the training error does not decrease.
        /// </summary>
        public void Run()
        {
            // Decay a local copy so the configured learning rate is not permanently
            // overwritten; otherwise a second call to Run() would start with a reduced rate.
            double rate = learningRate;

            int iteration = 0;
            double totalError = double.PositiveInfinity;
            double totalGeneralisationError = double.PositiveInfinity;
            bool betterGeneralisation = true;

            while (iteration < maxIterations && totalError > maxError && betterGeneralisation)
            {
                double oldError = totalError;
                double oldGeneralisationError = totalGeneralisationError;
                totalError = 0;
                totalGeneralisationError = 0;

                // Training pass: the error is measured before each point's weight update,
                // then the new weights are computed by back-propagation.
                int pointCount = 0;
                foreach (DataPoint point in data.Points())
                {
                    totalError = AccumulateSquaredError(totalError, point, network.Evaluate(point));
                    network.AdjustWeights(point, rate);
                    pointCount++;
                }

                // Generalisation pass: evaluation only, no weight updates.
                foreach (DataPoint point in data.GeneralisationPoints())
                {
                    totalGeneralisationError = AccumulateSquaredError(totalGeneralisationError, point, network.Evaluate(point));
                }

                // Early stopping: a rising generalisation error ends training after this iteration.
                if (totalGeneralisationError > oldGeneralisationError)
                {
                    betterGeneralisation = false;
                }

                // Halve the rate when the training error did not improve.
                if (totalError >= oldError)
                {
                    rate = rate / 2.0;
                }

                // Report progress. "Mean" is the square root of the summed squared error divided
                // by the number of training points.
                ihm.PrintMsg("Iteration n°" + iteration + " - Total error : " + totalError + " - Gener Error : " + totalGeneralisationError + " - Rate : " + rate + " - Mean : " + String.Format("{0:0.00}", Math.Sqrt(totalError / pointCount)));
                iteration++;
            }
        }

        /// <summary>
        /// Adds, one output at a time, the squared differences between the expected outputs of
        /// <paramref name="point"/> and the network's <paramref name="outputs"/> to
        /// <paramref name="runningTotal"/>, and returns the new total. Terms are added
        /// sequentially to the running total so floating-point results match a hand-written
        /// accumulation loop exactly.
        /// </summary>
        private static double AccumulateSquaredError(double runningTotal, DataPoint point, double[] outputs)
        {
            for (int outNb = 0; outNb < outputs.Length; outNb++)
            {
                double error = point.Outputs[outNb] - outputs[outNb];
                runningTotal += (error * error);
            }
            return runningTotal;
        }
    }
}
|