Tesseract  3.02
tesseract-ocr/cube/hybrid_neural_net_classifier.h
Go to the documentation of this file.
00001 /**********************************************************************
00002  * File:        hybrid_neural_net_classifier.h
00003  * Description: Declaration of the Hybrid NeuralNet Character Classifier
00004  * Author:    Ahmad Abdulkader
00005  * Created:   2007
00006  *
00007  * (C) Copyright 2008, Google Inc.
00008  ** Licensed under the Apache License, Version 2.0 (the "License");
00009  ** you may not use this file except in compliance with the License.
00010  ** You may obtain a copy of the License at
00011  ** http://www.apache.org/licenses/LICENSE-2.0
00012  ** Unless required by applicable law or agreed to in writing, software
00013  ** distributed under the License is distributed on an "AS IS" BASIS,
00014  ** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
00015  ** See the License for the specific language governing permissions and
00016  ** limitations under the License.
00017  *
00018  **********************************************************************/
00019 
00020 #ifndef HYBRID_NEURAL_NET_CLASSIFIER_H
00021 #define HYBRID_NEURAL_NET_CLASSIFIER_H
00022 
00023 #include <string>
00024 #include <vector>
00025 
00026 #include "char_samp.h"
00027 #include "char_altlist.h"
00028 #include "char_set.h"
00029 #include "classifier_base.h"
00030 #include "feature_base.h"
00031 #include "lang_model.h"
00032 #include "neural_net.h"
00033 #include "tuning_params.h"
00034 
00035 namespace tesseract {
00036 
00037 // Folding Ratio is the ratio of the max-activation of members of a folding
00038 // set that is used to compute the min-activation of the rest of the set
00039 // static const float kFoldingRatio = 0.75;  // see conv_net_classifier.h
00040 
00041 class HybridNeuralNetCharClassifier : public CharClassifier {
00042  public:
00043   HybridNeuralNetCharClassifier(CharSet *char_set, TuningParams *params,
00044       FeatureBase *feat_extract);
00045   virtual ~HybridNeuralNetCharClassifier();
00046   // The main training function. Given a sample and a class ID the classifier
00047   // updates its parameters according to its learning algorithm. This function
00048   // is currently not implemented. TODO(ahmadab): implement end-2-end training
00049   virtual bool Train(CharSamp *char_samp, int ClassID);
00050   // A secondary function needed for training. Allows the trainer to set the
00051   // value of any train-time paramter. This function is currently not
00052   // implemented. TODO(ahmadab): implement end-2-end training
00053   virtual bool SetLearnParam(char *var_name, float val);
00054   // Externally sets the Neural Net used by the classifier. Used for training
00055   void SetNet(tesseract::NeuralNet *net);
00056 
00057   // Classifies an input charsamp and return a CharAltList object containing
00058   // the possible candidates and corresponding scores
00059   virtual CharAltList *Classify(CharSamp *char_samp);
00060   // Computes the cost of a specific charsamp being a character (versus a
00061   // non-character: part-of-a-character OR more-than-one-character)
00062   virtual int CharCost(CharSamp *char_samp);
00063 
00064  private:
00065   // Neural Net object used for classification
00066   vector<tesseract::NeuralNet *> nets_;
00067   vector<float> net_wgts_;
00068 
00069   // data buffers used to hold Neural Net inputs and outputs
00070   float *net_input_;
00071   float *net_output_;
00072 
00073   // Init the classifier provided a data-path and a language string
00074   virtual bool Init(const string &data_file_path, const string &lang,
00075                     LangModel *lang_mod);
00076   // Loads the NeuralNets needed for the classifier
00077   bool LoadNets(const string &data_file_path, const string &lang);
00078   // Load folding sets
00079   // This function returns true on success or if the file can't be read,
00080   // returns false if an error is encountered.
00081   virtual bool LoadFoldingSets(const string &data_file_path,
00082                                const string &lang,
00083                                LangModel *lang_mod);
00084   // Folds the output of the NeuralNet using the loaded folding sets
00085   virtual void Fold();
00086   // Scales the input char_samp and feeds it to the NeuralNet as input
00087   bool RunNets(CharSamp *char_samp);
00088 };
00089 }
00090 #endif  // HYBRID_NEURAL_NET_CLASSIFIER_H