// Tesseract 3.02
00001 /********************************************************************** 00002 * File: charclassifier.cpp 00003 * Description: Implementation of Convolutional-NeuralNet Character Classifier 00004 * Author: Ahmad Abdulkader 00005 * Created: 2007 00006 * 00007 * (C) Copyright 2008, Google Inc. 00008 ** Licensed under the Apache License, Version 2.0 (the "License"); 00009 ** you may not use this file except in compliance with the License. 00010 ** You may obtain a copy of the License at 00011 ** http://www.apache.org/licenses/LICENSE-2.0 00012 ** Unless required by applicable law or agreed to in writing, software 00013 ** distributed under the License is distributed on an "AS IS" BASIS, 00014 ** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 00015 ** See the License for the specific language governing permissions and 00016 ** limitations under the License. 00017 * 00018 **********************************************************************/ 00019 00020 #include <algorithm> 00021 #include <stdio.h> 00022 #include <stdlib.h> 00023 #include <string> 00024 #include <vector> 00025 #include <wctype.h> 00026 00027 #include "char_set.h" 00028 #include "classifier_base.h" 00029 #include "const.h" 00030 #include "conv_net_classifier.h" 00031 #include "cube_utils.h" 00032 #include "feature_base.h" 00033 #include "feature_bmp.h" 00034 #include "tess_lang_model.h" 00035 00036 namespace tesseract { 00037 00038 ConvNetCharClassifier::ConvNetCharClassifier(CharSet *char_set, 00039 TuningParams *params, 00040 FeatureBase *feat_extract) 00041 : CharClassifier(char_set, params, feat_extract) { 00042 char_net_ = NULL; 00043 net_input_ = NULL; 00044 net_output_ = NULL; 00045 } 00046 00047 ConvNetCharClassifier::~ConvNetCharClassifier() { 00048 if (char_net_ != NULL) { 00049 delete char_net_; 00050 char_net_ = NULL; 00051 } 00052 00053 if (net_input_ != NULL) { 00054 delete []net_input_; 00055 net_input_ = NULL; 00056 } 00057 00058 if (net_output_ != NULL) { 
00059 delete []net_output_; 00060 net_output_ = NULL; 00061 } 00062 } 00063 00064 // The main training function. Given a sample and a class ID the classifier 00065 // updates its parameters according to its learning algorithm. This function 00066 // is currently not implemented. TODO(ahmadab): implement end-2-end training 00067 bool ConvNetCharClassifier::Train(CharSamp *char_samp, int ClassID) { 00068 return false; 00069 } 00070 00071 // A secondary function needed for training. Allows the trainer to set the 00072 // value of any train-time paramter. This function is currently not 00073 // implemented. TODO(ahmadab): implement end-2-end training 00074 bool ConvNetCharClassifier::SetLearnParam(char *var_name, float val) { 00075 // TODO(ahmadab): implementation of parameter initializing. 00076 return false; 00077 } 00078 00079 // Folds the output of the NeuralNet using the loaded folding sets 00080 void ConvNetCharClassifier::Fold() { 00081 // in case insensitive mode 00082 if (case_sensitive_ == false) { 00083 int class_cnt = char_set_->ClassCount(); 00084 // fold case 00085 for (int class_id = 0; class_id < class_cnt; class_id++) { 00086 // get class string 00087 const char_32 *str32 = char_set_->ClassString(class_id); 00088 // get the upper case form of the string 00089 string_32 upper_form32 = str32; 00090 for (int ch = 0; ch < upper_form32.length(); ch++) { 00091 if (iswalpha(static_cast<int>(upper_form32[ch])) != 0) { 00092 upper_form32[ch] = towupper(upper_form32[ch]); 00093 } 00094 } 00095 00096 // find out the upperform class-id if any 00097 int upper_class_id = 00098 char_set_->ClassID(reinterpret_cast<const char_32 *>( 00099 upper_form32.c_str())); 00100 if (upper_class_id != -1 && class_id != upper_class_id) { 00101 float max_out = MAX(net_output_[class_id], net_output_[upper_class_id]); 00102 net_output_[class_id] = max_out; 00103 net_output_[upper_class_id] = max_out; 00104 } 00105 } 00106 } 00107 00108 // The folding sets specify how groups of classes 
should be folded 00109 // Folding involved assigning a min-activation to all the members 00110 // of the folding set. The min-activation is a fraction of the max-activation 00111 // of the members of the folding set 00112 for (int fold_set = 0; fold_set < fold_set_cnt_; fold_set++) { 00113 if (fold_set_len_[fold_set] == 0) 00114 continue; 00115 float max_prob = net_output_[fold_sets_[fold_set][0]]; 00116 for (int ch = 1; ch < fold_set_len_[fold_set]; ch++) { 00117 if (net_output_[fold_sets_[fold_set][ch]] > max_prob) { 00118 max_prob = net_output_[fold_sets_[fold_set][ch]]; 00119 } 00120 } 00121 for (int ch = 0; ch < fold_set_len_[fold_set]; ch++) { 00122 net_output_[fold_sets_[fold_set][ch]] = MAX(max_prob * kFoldingRatio, 00123 net_output_[fold_sets_[fold_set][ch]]); 00124 } 00125 } 00126 } 00127 00128 // Compute the features of specified charsamp and feedforward the 00129 // specified nets 00130 bool ConvNetCharClassifier::RunNets(CharSamp *char_samp) { 00131 if (char_net_ == NULL) { 00132 fprintf(stderr, "Cube ERROR (ConvNetCharClassifier::RunNets): " 00133 "NeuralNet is NULL\n"); 00134 return false; 00135 } 00136 int feat_cnt = char_net_->in_cnt(); 00137 int class_cnt = char_set_->ClassCount(); 00138 00139 // allocate i/p and o/p buffers if needed 00140 if (net_input_ == NULL) { 00141 net_input_ = new float[feat_cnt]; 00142 if (net_input_ == NULL) { 00143 fprintf(stderr, "Cube ERROR (ConvNetCharClassifier::RunNets): " 00144 "unable to allocate memory for input nodes\n"); 00145 return false; 00146 } 00147 00148 net_output_ = new float[class_cnt]; 00149 if (net_output_ == NULL) { 00150 fprintf(stderr, "Cube ERROR (ConvNetCharClassifier::RunNets): " 00151 "unable to allocate memory for output nodes\n"); 00152 return false; 00153 } 00154 } 00155 00156 // compute input features 00157 if (feat_extract_->ComputeFeatures(char_samp, net_input_) == false) { 00158 fprintf(stderr, "Cube ERROR (ConvNetCharClassifier::RunNets): " 00159 "unable to compute features\n"); 00160 
return false; 00161 } 00162 00163 if (char_net_ != NULL) { 00164 if (char_net_->FeedForward(net_input_, net_output_) == false) { 00165 fprintf(stderr, "Cube ERROR (ConvNetCharClassifier::RunNets): " 00166 "unable to run feed-forward\n"); 00167 return false; 00168 } 00169 } else { 00170 return false; 00171 } 00172 Fold(); 00173 return true; 00174 } 00175 00176 // return the cost of being a char 00177 int ConvNetCharClassifier::CharCost(CharSamp *char_samp) { 00178 if (RunNets(char_samp) == false) { 00179 return 0; 00180 } 00181 return CubeUtils::Prob2Cost(1.0f - net_output_[0]); 00182 } 00183 00184 // classifies a charsamp and returns an alternate list 00185 // of chars sorted by char costs 00186 CharAltList *ConvNetCharClassifier::Classify(CharSamp *char_samp) { 00187 // run the needed nets 00188 if (RunNets(char_samp) == false) { 00189 return NULL; 00190 } 00191 00192 int class_cnt = char_set_->ClassCount(); 00193 00194 // create an altlist 00195 CharAltList *alt_list = new CharAltList(char_set_, class_cnt); 00196 if (alt_list == NULL) { 00197 fprintf(stderr, "Cube WARNING (ConvNetCharClassifier::Classify): " 00198 "returning emtpy CharAltList\n"); 00199 return NULL; 00200 } 00201 00202 for (int out = 1; out < class_cnt; out++) { 00203 int cost = CubeUtils::Prob2Cost(net_output_[out]); 00204 alt_list->Insert(out, cost); 00205 } 00206 00207 return alt_list; 00208 } 00209 00210 // Set an external net (for training purposes) 00211 void ConvNetCharClassifier::SetNet(tesseract::NeuralNet *char_net) { 00212 if (char_net_ != NULL) { 00213 delete char_net_; 00214 char_net_ = NULL; 00215 } 00216 char_net_ = char_net; 00217 } 00218 00219 // This function will return true if the file does not exist. 
00220 // But will fail if the it did not pass the sanity checks 00221 bool ConvNetCharClassifier::LoadFoldingSets(const string &data_file_path, 00222 const string &lang, 00223 LangModel *lang_mod) { 00224 fold_set_cnt_ = 0; 00225 string fold_file_name; 00226 fold_file_name = data_file_path + lang; 00227 fold_file_name += ".cube.fold"; 00228 00229 // folding sets are optional 00230 FILE *fp = fopen(fold_file_name.c_str(), "rb"); 00231 if (fp == NULL) { 00232 return true; 00233 } 00234 fclose(fp); 00235 00236 string fold_sets_str; 00237 if (!CubeUtils::ReadFileToString(fold_file_name.c_str(), 00238 &fold_sets_str)) { 00239 return false; 00240 } 00241 00242 // split into lines 00243 vector<string> str_vec; 00244 CubeUtils::SplitStringUsing(fold_sets_str, "\r\n", &str_vec); 00245 fold_set_cnt_ = str_vec.size(); 00246 00247 fold_sets_ = new int *[fold_set_cnt_]; 00248 if (fold_sets_ == NULL) { 00249 return false; 00250 } 00251 fold_set_len_ = new int[fold_set_cnt_]; 00252 if (fold_set_len_ == NULL) { 00253 fold_set_cnt_ = 0; 00254 return false; 00255 } 00256 00257 for (int fold_set = 0; fold_set < fold_set_cnt_; fold_set++) { 00258 reinterpret_cast<TessLangModel *>(lang_mod)->RemoveInvalidCharacters( 00259 &str_vec[fold_set]); 00260 00261 // if all or all but one character are invalid, invalidate this set 00262 if (str_vec[fold_set].length() <= 1) { 00263 fprintf(stderr, "Cube WARNING (ConvNetCharClassifier::LoadFoldingSets): " 00264 "invalidating folding set %d\n", fold_set); 00265 fold_set_len_[fold_set] = 0; 00266 fold_sets_[fold_set] = NULL; 00267 continue; 00268 } 00269 00270 string_32 str32; 00271 CubeUtils::UTF8ToUTF32(str_vec[fold_set].c_str(), &str32); 00272 fold_set_len_[fold_set] = str32.length(); 00273 fold_sets_[fold_set] = new int[fold_set_len_[fold_set]]; 00274 if (fold_sets_[fold_set] == NULL) { 00275 fprintf(stderr, "Cube ERROR (ConvNetCharClassifier::LoadFoldingSets): " 00276 "could not allocate folding set\n"); 00277 fold_set_cnt_ = fold_set; 00278 
return false; 00279 } 00280 for (int ch = 0; ch < fold_set_len_[fold_set]; ch++) { 00281 fold_sets_[fold_set][ch] = char_set_->ClassID(str32[ch]); 00282 } 00283 } 00284 return true; 00285 } 00286 00287 // Init the classifier provided a data-path and a language string 00288 bool ConvNetCharClassifier::Init(const string &data_file_path, 00289 const string &lang, 00290 LangModel *lang_mod) { 00291 if (init_) { 00292 return true; 00293 } 00294 00295 // load the nets if any. This function will return true if the net file 00296 // does not exist. But will fail if the net did not pass the sanity checks 00297 if (!LoadNets(data_file_path, lang)) { 00298 return false; 00299 } 00300 00301 // load the folding sets if any. This function will return true if the 00302 // file does not exist. But will fail if the it did not pass the sanity checks 00303 if (!LoadFoldingSets(data_file_path, lang, lang_mod)) { 00304 return false; 00305 } 00306 00307 init_ = true; 00308 return true; 00309 } 00310 00311 // Load the classifier's Neural Nets 00312 // This function will return true if the net file does not exist. 
00313 // But will fail if the net did not pass the sanity checks 00314 bool ConvNetCharClassifier::LoadNets(const string &data_file_path, 00315 const string &lang) { 00316 string char_net_file; 00317 00318 // add the lang identifier 00319 char_net_file = data_file_path + lang; 00320 char_net_file += ".cube.nn"; 00321 00322 // neural network is optional 00323 FILE *fp = fopen(char_net_file.c_str(), "rb"); 00324 if (fp == NULL) { 00325 return true; 00326 } 00327 fclose(fp); 00328 00329 // load main net 00330 char_net_ = tesseract::NeuralNet::FromFile(char_net_file.c_str()); 00331 if (char_net_ == NULL) { 00332 fprintf(stderr, "Cube ERROR (ConvNetCharClassifier::LoadNets): " 00333 "could not load %s\n", char_net_file.c_str()); 00334 return false; 00335 } 00336 00337 // validate net 00338 if (char_net_->in_cnt()!= feat_extract_->FeatureCnt()) { 00339 fprintf(stderr, "Cube ERROR (ConvNetCharClassifier::LoadNets): " 00340 "could not validate net %s\n", char_net_file.c_str()); 00341 return false; 00342 } 00343 00344 // alloc net i/o buffers 00345 int feat_cnt = char_net_->in_cnt(); 00346 int class_cnt = char_set_->ClassCount(); 00347 00348 if (char_net_->out_cnt() != class_cnt) { 00349 fprintf(stderr, "Cube ERROR (ConvNetCharClassifier::LoadNets): " 00350 "output count (%d) and class count (%d) are not equal\n", 00351 char_net_->out_cnt(), class_cnt); 00352 return false; 00353 } 00354 00355 // allocate i/p and o/p buffers if needed 00356 if (net_input_ == NULL) { 00357 net_input_ = new float[feat_cnt]; 00358 if (net_input_ == NULL) { 00359 return false; 00360 } 00361 00362 net_output_ = new float[class_cnt]; 00363 if (net_output_ == NULL) { 00364 return false; 00365 } 00366 } 00367 00368 return true; 00369 } 00370 } // tesseract