include/learningalgorithm.h

Go to the documentation of this file.
00001 /********************************************************************************
00002  *  Neural Network Framework.                                                   *
00003  *  Copyright (C) 2005-2008 Gianluca Massera <emmegian@yahoo.it>                *
00004  *                                                                              *
00005  *  This program is free software; you can redistribute it and/or modify        *
00006  *  it under the terms of the GNU General Public License as published by        *
00007  *  the Free Software Foundation; either version 2 of the License, or           *
00008  *  (at your option) any later version.                                         *
00009  *                                                                              *
00010  *  This program is distributed in the hope that it will be useful,             *
00011  *  but WITHOUT ANY WARRANTY; without even the implied warranty of              *
00012  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the               *
00013  *  GNU General Public License for more details.                                *
00014  *                                                                              *
00015  *  You should have received a copy of the GNU General Public License           *
00016  *  along with this program; if not, write to the Free Software                 *
00017  *  Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA  *
00018  ********************************************************************************/
00019 
00020 #ifndef LEARNINGALGORITHM_H
00021 #define LEARNINGALGORITHM_H
00022 
00026 #include "types.h"
00027 #include "neuralnet.h"
00028 #include <map>
00029 #include <cmath>
00030 
00031 namespace nnfw {
00032 
00033 class BaseNeuralNet;
00034 
00064 class NNFW_API Pattern {
00065 public:
00068     class PatternInfo {
00069     public:
00070         RealVec inputs;
00071         RealVec outputs;
00072     };
00074 
00077     Pattern() : pinfo(), empty() { /*nothing to do*/ };
00079     ~Pattern() { /*nothing to do*/ };
00080 
00082 
00085     void setInputsOf( Cluster*, const RealVec& );
00087     void setOutputsOf( Cluster*, const RealVec& );
00089     void setInputsOutputsOf( Cluster*, const RealVec& inputs, const RealVec& outputs );
00090 
00092     const RealVec& inputsOf( Cluster* ) const;
00094     const RealVec& outputsOf( Cluster* ) const;
00095 
00098     PatternInfo& operator[]( Cluster* );
00099 
00101 private:
00102     mutable std::map<Cluster*, PatternInfo> pinfo;
00103     RealVec empty;
00104 };
00105 
00117 class NNFW_API PatternSet : public VectorData<Pattern> {
00118 public:
00121 
00123     PatternSet() : VectorData<Pattern>() { };
00124 
00126     PatternSet( u_int size ) : VectorData<Pattern>( size ) { };
00127 
00129     PatternSet( u_int size, Pattern& pat ) : VectorData<Pattern>( size, pat ) { };
00130     
00132     PatternSet( PatternSet& src, u_int idS, u_int idE ) : VectorData<Pattern>( src, idS, idE ) { };
00133 
00139     PatternSet( const PatternSet& src ) : VectorData<Pattern>( src ) { };
00141 };
00142 
00147 class NNFW_API LearningAlgorithm {
00148 public:
00151 
00153     LearningAlgorithm( BaseNeuralNet* net );
00155     virtual ~LearningAlgorithm();
00156 
00158 
00160 
00162     BaseNeuralNet* net() {
00163         return netp;
00164     };
00165 
00167     virtual void learn() = 0;
00168 
00170     virtual void learn( const Pattern& ) = 0;
00171 
00173     virtual void learnOnSet( const PatternSet& set ) {
00174         for( int i=0; i<(int)set.size(); i++ ) {
00175             learn( set[i] );
00176         }
00177     };
00178 
00180     virtual Real calculateMSE( const Pattern& ) = 0;
00181     
00183     virtual Real calculateMSEOnSet( const PatternSet& set ) {
00184         Real mseacc = 0.0;
00185         int dim = (int)set.size();
00186         for( int i=0; i<dim; i++ ) {
00187             mseacc += calculateMSE( set[i] );
00188         }
00189         return mseacc/dim;
00190     };
00191 
00193     Real calculateRMSD( const Pattern& p ) {
00194         return sqrt( calculateMSE( p ) );
00195     };
00196 
00198     Real calculateRMSDOnSet( const PatternSet& p ) {
00199         return sqrt( calculateMSEOnSet( p ) );
00200     };
00201 
00203 
00204 private:
00205     BaseNeuralNet* netp;
00206 };
00207 
00208 }
00209 
00210 #endif
00211 
BerliOS Developer Logo Valid XHTML 1.0 Transitional Valid CSS!