00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020 #ifndef LEARNINGALGORITHM_H
00021 #define LEARNINGALGORITHM_H
00022
00026 #include "types.h"
00027 #include "neuralnet.h"
00028 #include <map>
00029 #include <cmath>
00030
00031 namespace nnfw {
00032
00033 class BaseNeuralNet;
00034
00064 class NNFW_API Pattern {
00065 public:
00068 class PatternInfo {
00069 public:
00070 RealVec inputs;
00071 RealVec outputs;
00072 };
00074
00077 Pattern() : pinfo(), empty() { };
00079 ~Pattern() { };
00080
00082
00085 void setInputsOf( Cluster*, const RealVec& );
00087 void setOutputsOf( Cluster*, const RealVec& );
00089 void setInputsOutputsOf( Cluster*, const RealVec& inputs, const RealVec& outputs );
00090
00092 const RealVec& inputsOf( Cluster* ) const;
00094 const RealVec& outputsOf( Cluster* ) const;
00095
00098 PatternInfo& operator[]( Cluster* );
00099
00101 private:
00102 mutable std::map<Cluster*, PatternInfo> pinfo;
00103 RealVec empty;
00104 };
00105
00117 class NNFW_API PatternSet : public VectorData<Pattern> {
00118 public:
00121
00123 PatternSet() : VectorData<Pattern>() { };
00124
00126 PatternSet( u_int size ) : VectorData<Pattern>( size ) { };
00127
00129 PatternSet( u_int size, Pattern& pat ) : VectorData<Pattern>( size, pat ) { };
00130
00132 PatternSet( PatternSet& src, u_int idS, u_int idE ) : VectorData<Pattern>( src, idS, idE ) { };
00133
00139 PatternSet( const PatternSet& src ) : VectorData<Pattern>( src ) { };
00141 };
00142
00147 class NNFW_API LearningAlgorithm {
00148 public:
00151
00153 LearningAlgorithm( BaseNeuralNet* net );
00155 virtual ~LearningAlgorithm();
00156
00158
00160
00162 BaseNeuralNet* net() {
00163 return netp;
00164 };
00165
00167 virtual void learn() = 0;
00168
00170 virtual void learn( const Pattern& ) = 0;
00171
00173 virtual void learnOnSet( const PatternSet& set ) {
00174 for( int i=0; i<(int)set.size(); i++ ) {
00175 learn( set[i] );
00176 }
00177 };
00178
00180 virtual Real calculateMSE( const Pattern& ) = 0;
00181
00183 virtual Real calculateMSEOnSet( const PatternSet& set ) {
00184 Real mseacc = 0.0;
00185 int dim = (int)set.size();
00186 for( int i=0; i<dim; i++ ) {
00187 mseacc += calculateMSE( set[i] );
00188 }
00189 return mseacc/dim;
00190 };
00191
00193 Real calculateRMSD( const Pattern& p ) {
00194 return sqrt( calculateMSE( p ) );
00195 };
00196
00198 Real calculateRMSDOnSet( const PatternSet& p ) {
00199 return sqrt( calculateMSEOnSet( p ) );
00200 };
00201
00203
00204 private:
00205 BaseNeuralNet* netp;
00206 };
00207
00208 }
00209
00210 #endif
00211