Tesseract
3.02
|
00001 /********************************************************************** 00002 * File: conv_net_classifier.h 00003 * Description: Declaration of Convolutional-NeuralNet Character Classifier 00004 * Author: Ahmad Abdulkader 00005 * Created: 2007 00006 * 00007 * (C) Copyright 2008, Google Inc. 00008 ** Licensed under the Apache License, Version 2.0 (the "License"); 00009 ** you may not use this file except in compliance with the License. 00010 ** You may obtain a copy of the License at 00011 ** http://www.apache.org/licenses/LICENSE-2.0 00012 ** Unless required by applicable law or agreed to in writing, software 00013 ** distributed under the License is distributed on an "AS IS" BASIS, 00014 ** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 00015 ** See the License for the specific language governing permissions and 00016 ** limitations under the License. 00017 * 00018 **********************************************************************/ 00019 00020 #ifndef HYBRID_NEURAL_NET_CLASSIFIER_H 00021 #define HYBRID_NEURAL_NET_CLASSIFIER_H 00022 00023 #include <string> 00024 #include <vector> 00025 00026 #include "char_samp.h" 00027 #include "char_altlist.h" 00028 #include "char_set.h" 00029 #include "classifier_base.h" 00030 #include "feature_base.h" 00031 #include "lang_model.h" 00032 #include "neural_net.h" 00033 #include "tuning_params.h" 00034 00035 namespace tesseract { 00036 00037 // Folding Ratio is the ratio of the max-activation of members of a folding 00038 // set that is used to compute the min-activation of the rest of the set 00039 // static const float kFoldingRatio = 0.75; // see conv_net_classifier.h 00040 00041 class HybridNeuralNetCharClassifier : public CharClassifier { 00042 public: 00043 HybridNeuralNetCharClassifier(CharSet *char_set, TuningParams *params, 00044 FeatureBase *feat_extract); 00045 virtual ~HybridNeuralNetCharClassifier(); 00046 // The main training function. Given a sample and a class ID the classifier 00047 // updates its parameters according to its learning algorithm. This function 00048 // is currently not implemented. TODO(ahmadab): implement end-2-end training 00049 virtual bool Train(CharSamp *char_samp, int ClassID); 00050 // A secondary function needed for training. Allows the trainer to set the 00051 // value of any train-time paramter. This function is currently not 00052 // implemented. TODO(ahmadab): implement end-2-end training 00053 virtual bool SetLearnParam(char *var_name, float val); 00054 // Externally sets the Neural Net used by the classifier. Used for training 00055 void SetNet(tesseract::NeuralNet *net); 00056 00057 // Classifies an input charsamp and return a CharAltList object containing 00058 // the possible candidates and corresponding scores 00059 virtual CharAltList *Classify(CharSamp *char_samp); 00060 // Computes the cost of a specific charsamp being a character (versus a 00061 // non-character: part-of-a-character OR more-than-one-character) 00062 virtual int CharCost(CharSamp *char_samp); 00063 00064 private: 00065 // Neural Net object used for classification 00066 vector<tesseract::NeuralNet *> nets_; 00067 vector<float> net_wgts_; 00068 00069 // data buffers used to hold Neural Net inputs and outputs 00070 float *net_input_; 00071 float *net_output_; 00072 00073 // Init the classifier provided a data-path and a language string 00074 virtual bool Init(const string &data_file_path, const string &lang, 00075 LangModel *lang_mod); 00076 // Loads the NeuralNets needed for the classifier 00077 bool LoadNets(const string &data_file_path, const string &lang); 00078 // Load folding sets 00079 // This function returns true on success or if the file can't be read, 00080 // returns false if an error is encountered. 00081 virtual bool LoadFoldingSets(const string &data_file_path, 00082 const string &lang, 00083 LangModel *lang_mod); 00084 // Folds the output of the NeuralNet using the loaded folding sets 00085 virtual void Fold(); 00086 // Scales the input char_samp and feeds it to the NeuralNet as input 00087 bool RunNets(CharSamp *char_samp); 00088 }; 00089 } 00090 #endif // HYBRID_NEURAL_NET_CLASSIFIER_H