/**********************************************************************
 * File:        hybrid_neural_net_classifier.cpp
 * Description: Implementation of the Hybrid Neural Net Character Classifier
 * Author:      Ahmad Abdulkader
 * Created:     2007
 *
 * (C) Copyright 2008, Google Inc.
 ** Licensed under the Apache License, Version 2.0 (the "License");
 ** you may not use this file except in compliance with the License.
 ** You may obtain a copy of the License at
 ** http://www.apache.org/licenses/LICENSE-2.0
 ** Unless required by applicable law or agreed to in writing, software
 ** distributed under the License is distributed on an "AS IS" BASIS,
 ** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 ** See the License for the specific language governing permissions and
 ** limitations under the License.
 *
 **********************************************************************/

#include <algorithm>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <string>
#include <vector>
#include <wctype.h>

#include "classifier_base.h"
#include "char_set.h"
#include "const.h"
#include "conv_net_classifier.h"
#include "cube_utils.h"
#include "feature_base.h"
#include "feature_bmp.h"
#include "hybrid_neural_net_classifier.h"
#include "tess_lang_model.h"

namespace tesseract {

HybridNeuralNetCharClassifier::HybridNeuralNetCharClassifier(
    CharSet *char_set,
    TuningParams *params,
    FeatureBase *feat_extract)
    : CharClassifier(char_set, params, feat_extract) {
  net_input_ = NULL;
  net_output_ = NULL;
}

HybridNeuralNetCharClassifier::~HybridNeuralNetCharClassifier() {
  for (int net_idx = 0; net_idx < nets_.size(); net_idx++) {
    if (nets_[net_idx] != NULL) {
      delete nets_[net_idx];
    }
  }
  nets_.clear();

  if (net_input_ != NULL) {
    delete []net_input_;
    net_input_ = NULL;
  }

  if (net_output_ != NULL) {
    delete []net_output_;
    net_output_ = NULL;
  }
}

// The main training function. Given a sample and a class ID the classifier
// updates its parameters according to its learning algorithm. This function
// is currently not implemented. TODO(ahmadab): implement end-2-end training
bool HybridNeuralNetCharClassifier::Train(CharSamp *char_samp, int ClassID) {
  return false;
}

// A secondary function needed for training. Allows the trainer to set the
// value of any train-time parameter. This function is currently not
// implemented. TODO(ahmadab): implement end-2-end training
bool HybridNeuralNetCharClassifier::SetLearnParam(char *var_name, float val) {
  // TODO(ahmadab): implementation of parameter initialization.
  return false;
}

// Folds the output of the NeuralNet using the loaded folding sets
void HybridNeuralNetCharClassifier::Fold() {
  // in case-insensitive mode
  if (case_sensitive_ == false) {
    int class_cnt = char_set_->ClassCount();
    // fold case
    for (int class_id = 0; class_id < class_cnt; class_id++) {
      // get class string
      const char_32 *str32 = char_set_->ClassString(class_id);
      // get the upper case form of the string
      string_32 upper_form32 = str32;
      for (int ch = 0; ch < upper_form32.length(); ch++) {
        if (iswalpha(static_cast<int>(upper_form32[ch])) != 0) {
          upper_form32[ch] = towupper(upper_form32[ch]);
        }
      }

      // find out the class-id of the upper-case form, if any
      int upper_class_id =
          char_set_->ClassID(reinterpret_cast<const char_32 *>(
              upper_form32.c_str()));
      if (upper_class_id != -1 && class_id != upper_class_id) {
        float max_out = MAX(net_output_[class_id],
                            net_output_[upper_class_id]);
        net_output_[class_id] = max_out;
        net_output_[upper_class_id] = max_out;
      }
    }
  }

  // The folding sets specify how groups of classes should be folded.
  // Folding involves assigning a min-activation to all the members of the
  // folding set. The min-activation is a fraction of the max-activation
  // of the members of the folding set.
  for (int fold_set = 0; fold_set < fold_set_cnt_; fold_set++) {
    float max_prob = net_output_[fold_sets_[fold_set][0]];

    for (int ch = 1; ch < fold_set_len_[fold_set]; ch++) {
      if (net_output_[fold_sets_[fold_set][ch]] > max_prob) {
        max_prob = net_output_[fold_sets_[fold_set][ch]];
      }
    }
    for (int ch = 0; ch < fold_set_len_[fold_set]; ch++) {
      net_output_[fold_sets_[fold_set][ch]] = MAX(max_prob * kFoldingRatio,
          net_output_[fold_sets_[fold_set][ch]]);
    }
  }
}

// compute the features of the specified charsamp and
// feedforward the specified nets
bool HybridNeuralNetCharClassifier::RunNets(CharSamp *char_samp) {
  int feat_cnt = feat_extract_->FeatureCnt();
  int class_cnt = char_set_->ClassCount();

  // allocate input and output buffers if needed
  if (net_input_ == NULL) {
    net_input_ = new float[feat_cnt];
    if (net_input_ == NULL) {
      return false;
    }

    net_output_ = new float[class_cnt];
    if (net_output_ == NULL) {
      return false;
    }
  }

  // compute input features
  if (feat_extract_->ComputeFeatures(char_samp, net_input_) == false) {
    return false;
  }

  // go through all the nets: each net consumes the next in_cnt() slice of
  // the feature vector, and its output is accumulated into net_output_
  // weighted by the corresponding net_wgts_ entry
  memset(net_output_, 0, class_cnt * sizeof(*net_output_));
  float *inputs = net_input_;
  for (int net_idx = 0; net_idx < nets_.size(); net_idx++) {
    // run each net
    vector<float> net_out(class_cnt, 0.0);
    if (!nets_[net_idx]->FeedForward(inputs, &net_out[0])) {
      return false;
    }
    // add the weighted output values
    for (int class_idx = 0; class_idx < class_cnt; class_idx++) {
      net_output_[class_idx] += (net_out[class_idx] * net_wgts_[net_idx]);
    }
    // advance the inputs pointer to the next net's slice
    inputs += nets_[net_idx]->in_cnt();
  }

  Fold();

  return true;
}

// return the cost of being a char
int HybridNeuralNetCharClassifier::CharCost(CharSamp *char_samp) {
  // it is by design that a character cost is equal to zero
  // when no nets are present. This is the case during training.
  if (RunNets(char_samp) == false) {
    return 0;
  }

  return CubeUtils::Prob2Cost(1.0f - net_output_[0]);
}

// classifies a charsamp and returns an alternate list
// of chars sorted by char costs
CharAltList *HybridNeuralNetCharClassifier::Classify(CharSamp *char_samp) {
  // run the needed nets
  if (RunNets(char_samp) == false) {
    return NULL;
  }

  int class_cnt = char_set_->ClassCount();

  // create an altlist
  CharAltList *alt_list = new CharAltList(char_set_, class_cnt);
  if (alt_list == NULL) {
    return NULL;
  }

  // output 0 (the activation used by CharCost() above) is not inserted
  for (int out = 1; out < class_cnt; out++) {
    int cost = CubeUtils::Prob2Cost(net_output_[out]);
    alt_list->Insert(out, cost);
  }

  return alt_list;
}

// set an external net (for training purposes)
void HybridNeuralNetCharClassifier::SetNet(tesseract::NeuralNet *char_net) {
}

// Load folding sets.
// This function returns true on success or when the folding file does not
// exist (folding sets are optional); it returns false if an error is
// encountered.
bool HybridNeuralNetCharClassifier::LoadFoldingSets(
    const string &data_file_path, const string &lang, LangModel *lang_mod) {
  fold_set_cnt_ = 0;
  string fold_file_name;
  fold_file_name = data_file_path + lang;
  fold_file_name += ".cube.fold";

  // folding sets are optional
  FILE *fp = fopen(fold_file_name.c_str(), "rb");
  if (fp == NULL) {
    return true;
  }
  fclose(fp);

  string fold_sets_str;
  if (!CubeUtils::ReadFileToString(fold_file_name.c_str(),
                                   &fold_sets_str)) {
    return false;
  }

  // split into lines
  vector<string> str_vec;
  CubeUtils::SplitStringUsing(fold_sets_str, "\r\n", &str_vec);
  fold_set_cnt_ = str_vec.size();
  fold_sets_ = new int *[fold_set_cnt_];
  if (fold_sets_ == NULL) {
    return false;
  }
  fold_set_len_ = new int[fold_set_cnt_];
  if (fold_set_len_ == NULL) {
    fold_set_cnt_ = 0;
    return false;
  }

  for (int fold_set = 0; fold_set < fold_set_cnt_; fold_set++) {
    reinterpret_cast<TessLangModel *>(lang_mod)->RemoveInvalidCharacters(
        &str_vec[fold_set]);

    // if all or all but one character are invalid, invalidate this set
    if (str_vec[fold_set].length() <= 1) {
      fprintf(stderr, "Cube WARNING (HybridNeuralNetCharClassifier::"
              "LoadFoldingSets): invalidating folding set %d\n", fold_set);
      fold_set_len_[fold_set] = 0;
      fold_sets_[fold_set] = NULL;
      continue;
    }

    string_32 str32;
    CubeUtils::UTF8ToUTF32(str_vec[fold_set].c_str(), &str32);
    fold_set_len_[fold_set] = str32.length();
    fold_sets_[fold_set] = new int[fold_set_len_[fold_set]];
    if (fold_sets_[fold_set] == NULL) {
      fprintf(stderr, "Cube ERROR (HybridNeuralNetCharClassifier::"
              "LoadFoldingSets): could not allocate folding set\n");
      fold_set_cnt_ = fold_set;
      return false;
    }
    for (int ch = 0; ch < fold_set_len_[fold_set]; ch++) {
      fold_sets_[fold_set][ch] = char_set_->ClassID(str32[ch]);
    }
  }
  return true;
}

// Init the classifier provided a data-path and a language string
bool HybridNeuralNetCharClassifier::Init(const string &data_file_path,
                                         const string &lang,
                                         LangModel *lang_mod) {
  if (init_ == true) {
    return true;
  }

  // load the nets, if any. This function will return true if the net file
  // does not exist, but will fail if the nets did not pass the sanity checks
  if (!LoadNets(data_file_path, lang)) {
    return false;
  }

  // load the folding sets, if any. This function will return true if the
  // file does not exist, but will fail if it did not pass the sanity checks
  if (!LoadFoldingSets(data_file_path, lang, lang_mod)) {
    return false;
  }

  init_ = true;
  return true;
}

// Load the classifier's Neural Nets.
// This function will return true if the net file does not exist,
// but will fail if the nets did not pass the sanity checks.
bool HybridNeuralNetCharClassifier::LoadNets(const string &data_file_path,
                                             const string &lang) {
  string hybrid_net_file;
  string junk_net_file;

  // add the lang identifier
  hybrid_net_file = data_file_path + lang;
  hybrid_net_file += ".cube.hybrid";

  // the neural network is optional
  FILE *fp = fopen(hybrid_net_file.c_str(), "rb");
  if (fp == NULL) {
    return true;
  }
  fclose(fp);

  string str;
  if (!CubeUtils::ReadFileToString(hybrid_net_file.c_str(), &str)) {
    return false;
  }

  // split into lines
  vector<string> str_vec;
  CubeUtils::SplitStringUsing(str, "\r\n", &str_vec);
  if (str_vec.size() <= 0) {
    return false;
  }

  // create and add the nets
  nets_.resize(str_vec.size(), NULL);
  net_wgts_.resize(str_vec.size(), 0);
  int total_input_size = 0;
  for (int net_idx = 0; net_idx < str_vec.size(); net_idx++) {
    // parse the line describing this net
    vector<string> tokens_vec;
    CubeUtils::SplitStringUsing(str_vec[net_idx], " \t", &tokens_vec);
    // has to be 2 tokens: the net file name and its combination weight
    if (tokens_vec.size() != 2) {
      return false;
    }
    // load the net
    string net_file_name = data_file_path + tokens_vec[0];
    nets_[net_idx] = tesseract::NeuralNet::FromFile(net_file_name.c_str());
    if (nets_[net_idx] == NULL) {
      return false;
    }
    // parse the net's weight and validate it
    net_wgts_[net_idx] = atof(tokens_vec[1].c_str());
    if (net_wgts_[net_idx] < 0.0) {
      return false;
    }
    total_input_size += nets_[net_idx]->in_cnt();
  }
  // the combined input count of all nets must match the feature count
  if (total_input_size != feat_extract_->FeatureCnt()) {
    return false;
  }
  // success
  return true;
}
}  // namespace tesseract
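
// ---------------------------------------------------------------------------
// Usage sketch (illustrative only; not part of the original file and kept
// inside #if 0 so it is not compiled). It shows the call sequence implied by
// the code above: construct the classifier with a CharSet, TuningParams and a
// FeatureBase-derived feature extractor, call Init() so the optional
// <lang>.cube.hybrid net list and <lang>.cube.fold folding sets are loaded
// from the data path, then call Classify() on a CharSamp. How char_set,
// params, feat_extract, lang_mod and char_samp are created is assumed here
// and not shown in this file; the data path and language are placeholders.
#if 0
static void ExampleHybridClassifierUse(tesseract::CharSet *char_set,
                                       tesseract::TuningParams *params,
                                       tesseract::FeatureBase *feat_extract,
                                       tesseract::LangModel *lang_mod,
                                       tesseract::CharSamp *char_samp) {
  tesseract::HybridNeuralNetCharClassifier classifier(char_set, params,
                                                      feat_extract);
  // Both the .cube.hybrid and .cube.fold files are optional; Init() succeeds
  // when they are absent and fails only if a present file is malformed.
  if (!classifier.Init("tessdata/", "eng", lang_mod)) {
    return;
  }
  // Classify() returns a list of (class id, cost) alternates, or NULL if the
  // nets could not be run; the caller is assumed to own the returned list.
  tesseract::CharAltList *alt_list = classifier.Classify(char_samp);
  if (alt_list != NULL) {
    // ... consume the alternates ...
    delete alt_list;
  }
}
#endif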