00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020 #ifndef BACKPROPAGATIONALGO_H
00021 #define BACKPROPAGATIONALGO_H
00022
00026 #include "types.h"
00027 #include "learningalgorithm.h"
00028 #include <map>
00029
00030 namespace nnfw {
00031
00032 class AbstractModifier;
00033
00037 class NNFW_API BackPropagationAlgo : public LearningAlgorithm {
00038 public:
00041
00048 BackPropagationAlgo( BaseNeuralNet *n_n, UpdatableVec update_order, Real l_r = 0.1f );
00049
00051 ~BackPropagationAlgo( );
00052
00054
00056
00060 void setTeachingInput( Cluster* output, const RealVec& ti );
00061
00062 virtual void learn();
00063
00065 virtual void learn( const Pattern& );
00066
00068 virtual Real calculateMSE( const Pattern& );
00069
00071 void setRate( Real newrate ) {
00072 learn_rate = newrate;
00073 };
00074
00076 Real rate() {
00077 return learn_rate;
00078 };
00079
00081 void setMomentum( Real newmom ) {
00082 momentumv = newmom;
00083 };
00084
00086 Real momentum() {
00087 return momentumv;
00088 };
00089
00091 void enableMomentum();
00092
00094 void disableMomentum() {
00095 useMomentum = false;
00096 };
00097
00119 const RealVec& getError( Cluster* );
00121
00122 private:
00124 Real learn_rate;
00126 Real momentumv;
00128 Real useMomentum;
00130 UpdatableVec update_order;
00131
00133 class NNFW_API cluster_deltas {
00134 public:
00135 Cluster* cluster;
00136 AbstractModifier* modcluster;
00137 bool isOutput;
00138 RealVec deltas_outputs;
00139 RealVec deltas_inputs;
00140 RealVec last_deltas_inputs;
00141 LinkerVec incoming_linkers_vec;
00142 VectorData<AbstractModifier*> incoming_modlinkers;
00143 VectorData<RealVec> incoming_last_outputs;
00144 };
00146 std::map<Cluster*, int> mapIndex;
00148 VectorData<cluster_deltas> cluster_deltas_vec;
00149
00150 void propagDeltas();
00151
00152 void addCluster( Cluster*, bool );
00153
00154 void addLinker( Linker* );
00155
00156 };
00157
00158 }
00159
00160 #endif
00161