include/backpropagationalgo.h

Go to the documentation of this file.
00001 /********************************************************************************
00002  *  Neural Network Framework.                                                   *
00003  *  Copyright (C) 2005-2008 Gianluca Massera <emmegian@yahoo.it>                *
00004  *                                                                              *
00005  *  This program is free software; you can redistribute it and/or modify        *
00006  *  it under the terms of the GNU General Public License as published by        *
00007  *  the Free Software Foundation; either version 2 of the License, or           *
00008  *  (at your option) any later version.                                         *
00009  *                                                                              *
00010  *  This program is distributed in the hope that it will be useful,             *
00011  *  but WITHOUT ANY WARRANTY; without even the implied warranty of              *
00012  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the               *
00013  *  GNU General Public License for more details.                                *
00014  *                                                                              *
00015  *  You should have received a copy of the GNU General Public License           *
00016  *  along with this program; if not, write to the Free Software                 *
00017  *  Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA  *
00018  ********************************************************************************/
00019 
00020 #ifndef BACKPROPAGATIONALGO_H
00021 #define BACKPROPAGATIONALGO_H
00022 
00026 #include "types.h"
00027 #include "learningalgorithm.h"
00028 #include <map>
00029 
00030 namespace nnfw {
00031 
00032 class AbstractModifier;
00033 
// Learning algorithm derived from LearningAlgorithm. Going by its name and its
// members (learning rate, momentum, per-cluster deltas) it implements
// back-propagation; the implementation itself is not part of this header.
00037 class NNFW_API BackPropagationAlgo : public LearningAlgorithm {
00038 public:
00041 
          // Constructor.
          //   n_n          : the neural network this algorithm is attached to
          //   update_order : sequence of updatable objects (stored by name in the
          //                  private member update_order -- presumably; the
          //                  constructor body is not visible here)
          //   l_r          : presumably the initial learning rate (default 0.1f)
00048     BackPropagationAlgo( BaseNeuralNet *n_n, UpdatableVec update_order, Real l_r = 0.1f );
00049 
          // Destructor.
          // NOTE(review): not declared virtual here; whether LearningAlgorithm
          // already provides a virtual destructor cannot be seen from this file.
00051     ~BackPropagationAlgo( );
00052 
00054 
00056 
          // Associates the vector ti with the Cluster 'output'.
          // Presumably ti is the desired (teaching) output used by the next call
          // to learn() -- confirm against the implementation.
00060     void setTeachingInput( Cluster* output, const RealVec& ti );
00061 
          // Runs a learning step (no arguments; defined out of line).
00062     virtual void learn();
00063 
          // Runs a learning step on the given Pattern (defined out of line).
00065     virtual void learn( const Pattern& );
00066 
          // Returns a Real computed from the given Pattern; by its name, the
          // mean squared error (defined out of line).
00068     virtual Real calculateMSE( const Pattern& );
00069 
          // Sets the learning rate.
00071     void setRate( Real newrate ) {
00072         learn_rate = newrate;
00073     };
00074 
          // Returns the current learning rate.
00076     Real rate() const {
00077         return learn_rate;
00078     };
00079 
          // Sets the momentum value.
00081     void setMomentum( Real newmom ) {
00082         momentumv = newmom;
00083     };
00084 
          // Returns the current momentum value.
00086     Real momentum() const {
00087         return momentumv;
00088     };
00089 
          // Enables the use of momentum (defined out of line).
00091     void enableMomentum();
00092 
          // Disables the use of momentum by clearing the useMomentum flag.
00094     void disableMomentum() {
00095         useMomentum = false;
00096     };
00097 
          // Returns a reference to a RealVec associated with the given Cluster;
          // by its name, the error computed for that Cluster.
          // NOTE(review): lifetime/validity of the returned reference is decided
          // by the implementation, which is not visible here.
00119     const RealVec& getError( Cluster* );
00121 
00122 private:
          // learning rate
00124     Real learn_rate;
          // momentum value
00126     Real momentumv;
          // on/off switch for momentum; a boolean flag (disableMomentum() assigns
          // false to it), hence declared bool rather than Real
00128     bool useMomentum;
          // update order passed at construction (see constructor parameter of the
          // same name)
00130     UpdatableVec update_order;
00131 
          // Per-Cluster bookkeeping record: the Cluster, its modifier, an output
          // flag, delta vectors and the incoming linkers with their modifiers.
00133     class NNFW_API cluster_deltas {
00134     public:
00135         Cluster* cluster;
00136         AbstractModifier* modcluster;
00137         bool isOutput;
00138         RealVec deltas_outputs;
00139         RealVec deltas_inputs;
00140         RealVec last_deltas_inputs;
00141         LinkerVec incoming_linkers_vec;
00142         VectorData<AbstractModifier*> incoming_modlinkers;
00143         VectorData<RealVec> incoming_last_outputs;
00144     };
          // maps a Cluster to an int; presumably its index in cluster_deltas_vec
00146     std::map<Cluster*, int> mapIndex;
          // the cluster_deltas records
00148     VectorData<cluster_deltas> cluster_deltas_vec;
00149     // --- propagate delta through the net
00150     void propagDeltas();
00151     // --- add a Cluster into the structures above
00152     void addCluster( Cluster*, bool );
00153     // --- add a Linker into the structures above
00154     void addLinker( Linker* );
00155 
00156 };
00157 
00158 }
00159 
00160 #endif
00161 
BerliOS Developer Logo Valid XHTML 1.0 Transitional Valid CSS!