Tesseract  3.02
tesseract-ocr/cube/hybrid_neural_net_classifier.cpp
/**********************************************************************
 * File:        hybrid_neural_net_classifier.cpp
 * Description: Implementation of the Hybrid Neural Net Character Classifier
 * Author:      Ahmad Abdulkader
 * Created:     2007
 *
 * (C) Copyright 2008, Google Inc.
 ** Licensed under the Apache License, Version 2.0 (the "License");
 ** you may not use this file except in compliance with the License.
 ** You may obtain a copy of the License at
 ** http://www.apache.org/licenses/LICENSE-2.0
 ** Unless required by applicable law or agreed to in writing, software
 ** distributed under the License is distributed on an "AS IS" BASIS,
 ** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 ** See the License for the specific language governing permissions and
 ** limitations under the License.
 *
 **********************************************************************/

#include <algorithm>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>  // for memset()
#include <string>
#include <vector>
#include <wctype.h>

#include "classifier_base.h"
#include "char_set.h"
#include "const.h"
#include "conv_net_classifier.h"
#include "cube_utils.h"
#include "feature_base.h"
#include "feature_bmp.h"
#include "hybrid_neural_net_classifier.h"
#include "tess_lang_model.h"

namespace tesseract {

HybridNeuralNetCharClassifier::HybridNeuralNetCharClassifier(
    CharSet *char_set,
    TuningParams *params,
    FeatureBase *feat_extract)
    : CharClassifier(char_set, params, feat_extract) {
  net_input_ = NULL;
  net_output_ = NULL;
}

HybridNeuralNetCharClassifier::~HybridNeuralNetCharClassifier() {
  for (int net_idx = 0; net_idx < nets_.size(); net_idx++) {
    if (nets_[net_idx] != NULL) {
      delete nets_[net_idx];
    }
  }
  nets_.clear();

  if (net_input_ != NULL) {
    delete []net_input_;
    net_input_ = NULL;
  }

  if (net_output_ != NULL) {
    delete []net_output_;
    net_output_ = NULL;
  }
}

// The main training function. Given a sample and a class ID, the classifier
// updates its parameters according to its learning algorithm. This function
// is currently not implemented. TODO(ahmadab): implement end-2-end training
bool HybridNeuralNetCharClassifier::Train(CharSamp *char_samp, int ClassID) {
  return false;
}

// A secondary function needed for training. Allows the trainer to set the
// value of any train-time parameter. This function is currently not
// implemented. TODO(ahmadab): implement end-2-end training
bool HybridNeuralNetCharClassifier::SetLearnParam(char *var_name, float val) {
  // TODO(ahmadab): implementation of parameter initialization.
  return false;
}

// Folds the output of the NeuralNet using the loaded folding sets
void HybridNeuralNetCharClassifier::Fold() {
  // in case-insensitive mode, give upper- and lower-case forms the same output
  if (case_sensitive_ == false) {
    int class_cnt = char_set_->ClassCount();
    // fold case
    for (int class_id = 0; class_id < class_cnt; class_id++) {
      // get class string
      const char_32 *str32 = char_set_->ClassString(class_id);
      // get the upper case form of the string
      string_32 upper_form32 = str32;
      for (int ch = 0; ch < upper_form32.length(); ch++) {
        if (iswalpha(static_cast<int>(upper_form32[ch])) != 0) {
          upper_form32[ch] = towupper(upper_form32[ch]);
        }
      }

      // find the class-id of the upper-case form, if any
      int upper_class_id =
          char_set_->ClassID(reinterpret_cast<const char_32 *>(
              upper_form32.c_str()));
      if (upper_class_id != -1 && class_id != upper_class_id) {
        float max_out = MAX(net_output_[class_id], net_output_[upper_class_id]);
        net_output_[class_id] = max_out;
        net_output_[upper_class_id] = max_out;
      }
    }
  }

  // The folding sets specify how groups of classes should be folded.
  // Folding involves assigning a min-activation to all the members of the
  // folding set. The min-activation is a fraction of the max-activation
  // of the members of the folding set.
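  // Illustrative example (values hypothetical): for a folding set {A, B, C}
  // with activations {0.80, 0.10, 0.02} and kFoldingRatio assumed to be 0.75,
  // the loop below would raise B and C to 0.60 while A keeps its 0.80.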
  for (int fold_set = 0; fold_set < fold_set_cnt_; fold_set++) {
    // skip folding sets that were invalidated at load time
    if (fold_set_len_[fold_set] < 1 || fold_sets_[fold_set] == NULL) {
      continue;
    }
    float max_prob = net_output_[fold_sets_[fold_set][0]];

    for (int ch = 1; ch < fold_set_len_[fold_set]; ch++) {
      if (net_output_[fold_sets_[fold_set][ch]] > max_prob) {
        max_prob = net_output_[fold_sets_[fold_set][ch]];
      }
    }
    for (int ch = 0; ch < fold_set_len_[fold_set]; ch++) {
      net_output_[fold_sets_[fold_set][ch]] = MAX(max_prob * kFoldingRatio,
          net_output_[fold_sets_[fold_set][ch]]);
    }
  }
}

// Computes the features of the specified charsamp and feeds them forward
// through the loaded nets
bool HybridNeuralNetCharClassifier::RunNets(CharSamp *char_samp) {
  int feat_cnt = feat_extract_->FeatureCnt();
  int class_cnt = char_set_->ClassCount();

  // allocate input and output buffers if needed
  if (net_input_ == NULL) {
    net_input_ = new float[feat_cnt];
    if (net_input_ == NULL) {
      return false;
    }

    net_output_ = new float[class_cnt];
    if (net_output_ == NULL) {
      return false;
    }
  }

  // compute input features
  if (feat_extract_->ComputeFeatures(char_samp, net_input_) == false) {
    return false;
  }

  // go through all the nets
  memset(net_output_, 0, class_cnt * sizeof(*net_output_));
  float *inputs = net_input_;
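  // Each net consumes its own consecutive slice of the feature vector and
  // contributes its outputs to the combined result, weighted by net_wgts_.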
  for (int net_idx = 0; net_idx < nets_.size(); net_idx++) {
    // run this net on its slice of the input
    vector<float> net_out(class_cnt, 0.0);
    if (!nets_[net_idx]->FeedForward(inputs, &net_out[0])) {
      return false;
    }
    // accumulate the weighted output values
    for (int class_idx = 0; class_idx < class_cnt; class_idx++) {
      net_output_[class_idx] += (net_out[class_idx] * net_wgts_[net_idx]);
    }
    // advance the input pointer past this net's slice
    inputs += nets_[net_idx]->in_cnt();
  }

  Fold();

  return true;
}

// Returns the cost of the sample being a character
int HybridNeuralNetCharClassifier::CharCost(CharSamp *char_samp) {
  // it is by design that a character cost is equal to zero
  // when no nets are present. This is the case during training.
  if (RunNets(char_samp) == false) {
    return 0;
  }

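  // net_output_[0] is treated as the activation of the non-character (reject)
  // class, so 1.0 - net_output_[0] approximates the probability that the
  // sample is a valid character.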
  return CubeUtils::Prob2Cost(1.0f - net_output_[0]);
}

// classifies a charsamp and returns an alternate list
// of chars sorted by char costs
CharAltList *HybridNeuralNetCharClassifier::Classify(CharSamp *char_samp) {
  // run the needed nets
  if (RunNets(char_samp) == false) {
    return NULL;
  }

  int class_cnt = char_set_->ClassCount();

  // create an altlist
  CharAltList *alt_list = new CharAltList(char_set_, class_cnt);
  if (alt_list == NULL) {
    return NULL;
  }

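  // class 0 is skipped: it is the output consumed by CharCost() above rather
  // than a character alternative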
  for (int out = 1; out < class_cnt; out++) {
    int cost = CubeUtils::Prob2Cost(net_output_[out]);
    alt_list->Insert(out, cost);
  }

  return alt_list;
}

// Sets an external net (for training purposes). Currently a no-op for the
// hybrid classifier.
void HybridNeuralNetCharClassifier::SetNet(tesseract::NeuralNet *char_net) {
}

// Loads the folding sets.
// Returns true on success or if the (optional) folding file does not exist;
// returns false if an error is encountered while reading or parsing it.
bool HybridNeuralNetCharClassifier::LoadFoldingSets(
    const string &data_file_path, const string &lang, LangModel *lang_mod) {
  fold_set_cnt_ = 0;
  string fold_file_name;
  fold_file_name = data_file_path + lang;
  fold_file_name += ".cube.fold";
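
  // The .cube.fold file is a plain-text file with one folding set per line;
  // each line lists, in UTF-8, the characters whose activations should be
  // folded together.
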
  // folding sets are optional
  FILE *fp = fopen(fold_file_name.c_str(), "rb");
  if (fp == NULL) {
    return true;
  }
  fclose(fp);

  string fold_sets_str;
  if (!CubeUtils::ReadFileToString(fold_file_name.c_str(),
                                   &fold_sets_str)) {
    return false;
  }

  // split into lines
  vector<string> str_vec;
  CubeUtils::SplitStringUsing(fold_sets_str, "\r\n", &str_vec);
  fold_set_cnt_ = str_vec.size();
  fold_sets_ = new int *[fold_set_cnt_];
  if (fold_sets_ == NULL) {
    return false;
  }
  fold_set_len_ = new int[fold_set_cnt_];
  if (fold_set_len_ == NULL) {
    fold_set_cnt_ = 0;
    return false;
  }

  for (int fold_set = 0; fold_set < fold_set_cnt_; fold_set++) {
    reinterpret_cast<TessLangModel *>(lang_mod)->RemoveInvalidCharacters(
        &str_vec[fold_set]);

    // if all or all but one character are invalid, invalidate this set
    if (str_vec[fold_set].length() <= 1) {
      fprintf(stderr, "Cube WARNING (HybridNeuralNetCharClassifier::"
              "LoadFoldingSets): invalidating folding set %d\n", fold_set);
      fold_set_len_[fold_set] = 0;
      fold_sets_[fold_set] = NULL;
      continue;
    }

    string_32 str32;
    CubeUtils::UTF8ToUTF32(str_vec[fold_set].c_str(), &str32);
    fold_set_len_[fold_set] = str32.length();
    fold_sets_[fold_set] = new int[fold_set_len_[fold_set]];
    if (fold_sets_[fold_set] == NULL) {
      fprintf(stderr, "Cube ERROR (HybridNeuralNetCharClassifier::"
              "LoadFoldingSets): could not allocate folding set\n");
      fold_set_cnt_ = fold_set;
      return false;
    }
    for (int ch = 0; ch < fold_set_len_[fold_set]; ch++) {
      fold_sets_[fold_set][ch] = char_set_->ClassID(str32[ch]);
    }
  }
  return true;
}

// Inits the classifier given a data-path and a language string
bool HybridNeuralNetCharClassifier::Init(const string &data_file_path,
                                         const string &lang,
                                         LangModel *lang_mod) {
  if (init_ == true) {
    return true;
  }

  // load the nets, if any. This will return true if the net file does not
  // exist, but will fail if the nets do not pass the sanity checks
  if (!LoadNets(data_file_path, lang)) {
    return false;
  }

  // load the folding sets, if any. This will return true if the folding file
  // does not exist, but will fail if it does not pass the sanity checks
  if (!LoadFoldingSets(data_file_path, lang, lang_mod)) {
    return false;
  }

  init_ = true;
  return true;
}

// Loads the classifier's neural nets.
// Returns true if the net file does not exist (the nets are optional), but
// fails if the nets do not pass the sanity checks.
bool HybridNeuralNetCharClassifier::LoadNets(const string &data_file_path,
                                             const string &lang) {
  string hybrid_net_file;
  string junk_net_file;

  // add the lang identifier
  hybrid_net_file = data_file_path + lang;
  hybrid_net_file += ".cube.hybrid";

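  // The .cube.hybrid file lists one net per line, in the form
  // "<net-file-name> <weight>". The nets are fed consecutive slices of the
  // feature vector, so their input sizes must add up to FeatureCnt().
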
  // the neural nets are optional
  FILE *fp = fopen(hybrid_net_file.c_str(), "rb");
  if (fp == NULL) {
    return true;
  }
  fclose(fp);

  string str;
  if (!CubeUtils::ReadFileToString(hybrid_net_file.c_str(), &str)) {
    return false;
  }

  // split into lines
  vector<string> str_vec;
  CubeUtils::SplitStringUsing(str, "\r\n", &str_vec);
  if (str_vec.size() <= 0) {
    return false;
  }

  // create and add the nets
  nets_.resize(str_vec.size(), NULL);
  net_wgts_.resize(str_vec.size(), 0);
  int total_input_size = 0;
  for (int net_idx = 0; net_idx < str_vec.size(); net_idx++) {
    // parse the string
    vector<string> tokens_vec;
    CubeUtils::SplitStringUsing(str_vec[net_idx], " \t", &tokens_vec);
    // has to be 2 tokens: net name and net weight
    if (tokens_vec.size() != 2) {
      return false;
    }
    // load the net
    string net_file_name = data_file_path + tokens_vec[0];
    nets_[net_idx] = tesseract::NeuralNet::FromFile(net_file_name.c_str());
    if (nets_[net_idx] == NULL) {
      return false;
    }
    // parse the net weight and validate it
    net_wgts_[net_idx] = atof(tokens_vec[1].c_str());
    if (net_wgts_[net_idx] < 0.0) {
      return false;
    }
    total_input_size += nets_[net_idx]->in_cnt();
  }
  // validate the total input count against the feature count
  if (total_input_size != feat_extract_->FeatureCnt()) {
    return false;
  }
  // success
  return true;
}
}  // namespace tesseract