using System;

namespace NeuralNetworkPCL
{
    /// <summary>
    /// Drives the training of a <see cref="NeuralNetwork"/> on the points of a
    /// <see cref="DataCollection"/>. The network's weights are adjusted after every
    /// training point. The learning rate is halved whenever an iteration does not
    /// reduce the training error. Training stops as soon as the error on the
    /// generalisation points rises.
    /// </summary>
    public class NeuralSystem
    {
        private readonly DataCollection data;
        private readonly NeuralNetwork network;
        private readonly IHM ihm;

        // Configuration (can be overridden through the setter methods below).
        private double learningRate = 0.3;
        private double maxError = 0.005;
        private int maxIterations = 10001;

        /// <summary>
        /// Builds the data collection and the network to train.
        /// </summary>
        /// <param name="_nbInputs">Number of input neurons.</param>
        /// <param name="_nbHidden">Number of hidden neurons.</param>
        /// <param name="_nbOutputs">Number of output neurons (also passed to the data collection).</param>
        /// <param name="_data">Raw data lines handed to <see cref="DataCollection"/>.</param>
        /// <param name="_trainingRatio">Ratio handed to <see cref="DataCollection"/>; presumably the share of points used for training — confirm in DataCollection.</param>
        /// <param name="_ihm">Sink for progress messages; must not be null.</param>
        /// <exception cref="ArgumentNullException"><paramref name="_ihm"/> is null.</exception>
        public NeuralSystem(int _nbInputs, int _nbHidden, int _nbOutputs, String[] _data, double _trainingRatio, IHM _ihm)
        {
            // Fail fast here instead of with a NullReferenceException inside Run().
            if (_ihm == null)
            {
                throw new ArgumentNullException("_ihm");
            }

            data = new DataCollection(_data, _nbOutputs, _trainingRatio);
            network = new NeuralNetwork(_nbInputs, _nbHidden, _nbOutputs);
            ihm = _ihm;
        }

        /// <summary>Sets the initial learning rate used by <see cref="Run"/>.</summary>
        public void LearningRate(double _rate)
        {
            learningRate = _rate;
        }

        /// <summary>Sets the total training error at or below which <see cref="Run"/> stops.</summary>
        public void MaximumError(double _error)
        {
            maxError = _error;
        }

        /// <summary>Sets the maximum number of iterations performed by <see cref="Run"/>.</summary>
        public void MaximumIterations(int _iterations)
        {
            maxIterations = _iterations;
        }

        /// <summary>
        /// Trains the network until one of these happens:
        /// <list type="bullet">
        /// <item>the iteration limit is reached;</item>
        /// <item>the summed squared training error drops to the maximum error or below;</item>
        /// <item>the summed squared error on the generalisation points increases.</item>
        /// </list>
        /// A progress line is printed after every iteration.
        /// </summary>
        /// <remarks>
        /// The learning rate is halved in place, so a later call to Run starts from the
        /// reduced rate. When training stops because the generalisation error rose, the
        /// weights from that last iteration are kept. No earlier state is restored.
        /// </remarks>
        public void Run()
        {
            int i = 0;
            double totalError = Double.PositiveInfinity;
            double totalGeneralisationError = Double.PositiveInfinity;
            bool betterGeneralisation = true;

            while (i < maxIterations && totalError > maxError && betterGeneralisation)
            {
                double oldError = totalError;
                double oldGeneralisationError = totalGeneralisationError;

                // Training pass: evaluate each point, then update the weights by backpropagation.
                int pointCount;
                totalError = TrainOnePass(out pointCount);

                // Generalisation: evaluate only, no weight update.
                totalGeneralisationError = GeneralisationError();

                // Early stopping once the generalisation error starts to rise.
                if (totalGeneralisationError > oldGeneralisationError)
                {
                    betterGeneralisation = false;
                }

                // Change the rate: halve it when the training error did not decrease.
                if (totalError >= oldError)
                {
                    learningRate = learningRate / 2.0;
                }

                // Root of the mean per-point squared error. Guard against an empty
                // training set, which would otherwise print NaN.
                double mean = pointCount > 0 ? Math.Sqrt(totalError / pointCount) : 0.0;

                // Report and increment.
                ihm.PrintMsg("Iteration n°" + i + " - Total error : " + totalError + " - Gener Error : " + totalGeneralisationError + " - Rate : " + learningRate + " - Mean : " + String.Format("{0:0.00}", mean));
                i++;
            }
        }

        /// <summary>
        /// Performs one pass over <c>data.Points()</c>. Each point's squared error is
        /// measured before that point's weight update.
        /// </summary>
        /// <param name="pointCount">Number of points actually trained on in this pass.</param>
        /// <returns>Sum of squared output errors over the pass.</returns>
        private double TrainOnePass(out int pointCount)
        {
            double sum = 0;
            pointCount = 0;
            foreach (DataPoint point in data.Points())
            {
                sum = AccumulateSquaredError(point, sum);
                network.AdjustWeights(point, learningRate);
                pointCount++;
            }
            return sum;
        }

        /// <summary>Sum of squared output errors over <c>data.GeneralisationPoints()</c>, without training.</summary>
        private double GeneralisationError()
        {
            double sum = 0;
            foreach (DataPoint point in data.GeneralisationPoints())
            {
                sum = AccumulateSquaredError(point, sum);
            }
            return sum;
        }

        /// <summary>
        /// Evaluates <paramref name="point"/> and adds each output's squared error to
        /// <paramref name="total"/> one term at a time. This keeps the floating-point
        /// summation order the same as a single running accumulator.
        /// </summary>
        private double AccumulateSquaredError(DataPoint point, double total)
        {
            double[] outputs = network.Evaluate(point);
            for (int outNb = 0; outNb < outputs.Length; outNb++)
            {
                double error = point.Outputs[outNb] - outputs[outNb];
                total += error * error;
            }
            return total;
        }
    }
}