Tesseract  3.02
tesseract-ocr/cube/tess_lang_model.cpp
Go to the documentation of this file.
00001 /**********************************************************************
00002  * File:        tess_lang_model.cpp
00003  * Description: Implementation of the Tesseract Language Model Class
00004  * Author:    Ahmad Abdulkader
00005  * Created:   2008
00006  *
00007  * (C) Copyright 2008, Google Inc.
00008  ** Licensed under the Apache License, Version 2.0 (the "License");
00009  ** you may not use this file except in compliance with the License.
00010  ** You may obtain a copy of the License at
00011  ** http://www.apache.org/licenses/LICENSE-2.0
00012  ** Unless required by applicable law or agreed to in writing, software
00013  ** distributed under the License is distributed on an "AS IS" BASIS,
00014  ** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
00015  ** See the License for the specific language governing permissions and
00016  ** limitations under the License.
00017  *
00018  **********************************************************************/
00019 
00020 // The TessLangModel class abstracts the Tesseract language model. It inherits
00021 // from the LangModel class. The Tesseract language model encompasses several
00022 // Dawgs (words from training data, punctuation, numbers, document words).
00023 // On top of this Cube adds an OOD state machine
00024 // The class provides methods to traverse the language model in a generative
00025 // fashion. Given any node in the DAWG, the language model can generate a list
00026 // of children (or fan-out) edges
00027 
00028 #include <string>
00029 #include <vector>
00030 
00031 #include "char_samp.h"
00032 #include "cube_utils.h"
00033 #include "dict.h"
00034 #include "tesseractclass.h"
00035 #include "tess_lang_model.h"
00036 #include "tessdatamanager.h"
00037 #include "unicharset.h"
00038 
00039 namespace tesseract {
00040 // max fan-out (used for preallocation). Initialized here, but modified by
00041 // constructor
00042 int TessLangModel::max_edge_ = 4096;
00043 
00044 // Language model extra State machines
00045 const Dawg *TessLangModel::ood_dawg_ = reinterpret_cast<Dawg *>(DAWG_OOD);
00046 const Dawg *TessLangModel::number_dawg_ = reinterpret_cast<Dawg *>(DAWG_NUMBER);
00047 
00048 // number state machine
00049 const int TessLangModel::num_state_machine_[kStateCnt][kNumLiteralCnt] = {
00050   {0, 1, 1, NUM_TRM, NUM_TRM},
00051   {NUM_TRM, 1, 1, 3, 2},
00052   {NUM_TRM, NUM_TRM, 1, NUM_TRM, 2},
00053   {NUM_TRM, NUM_TRM, 3, NUM_TRM, 2},
00054 };
00055 const int TessLangModel::num_max_repeat_[kStateCnt] = {3, 32, 8, 3};
00056 
00057 // thresholds and penalties
00058 int TessLangModel::max_ood_shape_cost_ = CubeUtils::Prob2Cost(1e-4);
00059 
00060 TessLangModel::TessLangModel(const string &lm_params,
00061                              const string &data_file_path,
00062                              bool load_system_dawg,
00063                              TessdataManager *tessdata_manager,
00064                              CubeRecoContext *cntxt) {
00065   cntxt_ = cntxt;
00066   has_case_ = cntxt_->HasCase();
00067   // Load the rest of the language model elements from file
00068   LoadLangModelElements(lm_params);
00069   // Load word_dawgs_ if needed.
00070   if (tessdata_manager->SeekToStart(TESSDATA_CUBE_UNICHARSET)) {
00071     word_dawgs_ = new DawgVector();
00072     if (load_system_dawg &&
00073         tessdata_manager->SeekToStart(TESSDATA_CUBE_SYSTEM_DAWG)) {
00074       // The last parameter to the Dawg constructor (the debug level) is set to
00075       // false, until Cube has a way to express its preferred debug level.
00076       *word_dawgs_ +=  new SquishedDawg(tessdata_manager->GetDataFilePtr(),
00077                                         DAWG_TYPE_WORD,
00078                                         cntxt_->Lang().c_str(),
00079                                         SYSTEM_DAWG_PERM, false);
00080     }
00081   } else {
00082     word_dawgs_ = NULL;
00083   }
00084 }
00085 
00086 // Cleanup an edge array
00087 void TessLangModel::FreeEdges(int edge_cnt, LangModEdge **edge_array) {
00088   if (edge_array != NULL) {
00089     for (int edge_idx = 0; edge_idx < edge_cnt; edge_idx++) {
00090       if (edge_array[edge_idx] != NULL) {
00091         delete edge_array[edge_idx];
00092       }
00093     }
00094     delete []edge_array;
00095   }
00096 }
00097 
00098 // Determines if a sequence of 32-bit chars is valid in this language model
00099 // starting from the specified edge. If the eow_flag is ON, also checks for
00100 // a valid EndOfWord. If final_edge is not NULL, returns a pointer to the last
00101 // edge
00102 bool TessLangModel::IsValidSequence(LangModEdge *edge,
00103                                     const char_32 *sequence,
00104                                     bool eow_flag,
00105                                     LangModEdge **final_edge) {
00106   // get the edges emerging from this edge
00107   int edge_cnt = 0;
00108   LangModEdge **edge_array = GetEdges(NULL, edge, &edge_cnt);
00109 
00110   // find the 1st char in the sequence in the children
00111   for (int edge_idx = 0; edge_idx < edge_cnt; edge_idx++) {
00112     // found a match
00113     if (sequence[0] == edge_array[edge_idx]->EdgeString()[0]) {
00114       // if this is the last char
00115       if (sequence[1] == 0) {
00116         // succeed if we are in prefix mode or this is a terminal edge
00117         if (eow_flag == false || edge_array[edge_idx]->IsEOW()) {
00118           if (final_edge != NULL) {
00119             (*final_edge) = edge_array[edge_idx];
00120             edge_array[edge_idx] = NULL;
00121           }
00122 
00123           FreeEdges(edge_cnt, edge_array);
00124           return true;
00125         }
00126       } else {
00127         // not the last char continue checking
00128         if (IsValidSequence(edge_array[edge_idx], sequence + 1, eow_flag,
00129                             final_edge) == true) {
00130           FreeEdges(edge_cnt, edge_array);
00131           return true;
00132         }
00133       }
00134     }
00135   }
00136 
00137   FreeEdges(edge_cnt, edge_array);
00138   return false;
00139 }
00140 
00141 // Determines if a sequence of 32-bit chars is valid in this language model
00142 // starting from the root. If the eow_flag is ON, also checks for
00143 // a valid EndOfWord. If final_edge is not NULL, returns a pointer to the last
00144 // edge
00145 bool TessLangModel::IsValidSequence(const char_32 *sequence, bool eow_flag,
00146                                     LangModEdge **final_edge) {
00147   if (final_edge != NULL) {
00148     (*final_edge) = NULL;
00149   }
00150 
00151   return IsValidSequence(NULL, sequence, eow_flag, final_edge);
00152 }
00153 
00154 bool TessLangModel::IsLeadingPunc(const char_32 ch) {
00155   return lead_punc_.find(ch) != string::npos;
00156 }
00157 
00158 bool TessLangModel::IsTrailingPunc(const char_32 ch) {
00159   return trail_punc_.find(ch) != string::npos;
00160 }
00161 
00162 bool TessLangModel::IsDigit(const char_32 ch) {
00163   return digits_.find(ch) != string::npos;
00164 }
00165 
00166 // The general fan-out generation function. Returns the list of edges
00167 // fanning-out of the specified edge and their count. If an AltList is
00168 // specified, only the class-ids with a minimum cost are considered
00169 LangModEdge ** TessLangModel::GetEdges(CharAltList *alt_list,
00170                                        LangModEdge *lang_mod_edge,
00171                                        int *edge_cnt) {
00172   TessLangModEdge *tess_lm_edge =
00173       reinterpret_cast<TessLangModEdge *>(lang_mod_edge);
00174   LangModEdge **edge_array = NULL;
00175   (*edge_cnt) = 0;
00176 
00177   // if we are starting from the root, we'll instantiate every DAWG
00178   // and get the all the edges that emerge from the root
00179   if (tess_lm_edge == NULL) {
00180     // get DAWG count from Tesseract
00181     int dawg_cnt = NumDawgs();
00182     // preallocate the edge buffer
00183     (*edge_cnt) = dawg_cnt * max_edge_;
00184     edge_array = new LangModEdge *[(*edge_cnt)];
00185     if (edge_array == NULL) {
00186       return NULL;
00187     }
00188 
00189     for (int dawg_idx = (*edge_cnt) = 0; dawg_idx < dawg_cnt; dawg_idx++) {
00190       const Dawg *curr_dawg = GetDawg(dawg_idx);
00191       // Only look through word Dawgs (since there is a special way of
00192       // handling numbers and punctuation).
00193       if (curr_dawg->type() == DAWG_TYPE_WORD) {
00194         (*edge_cnt) += FanOut(alt_list, curr_dawg, 0, 0, NULL, true,
00195                               edge_array + (*edge_cnt));
00196       }
00197     }  // dawg
00198 
00199     (*edge_cnt) += FanOut(alt_list, number_dawg_, 0, 0, NULL, true,
00200                           edge_array + (*edge_cnt));
00201 
00202     // OOD: it is intentionally not added to the list to make sure it comes
00203     // at the end
00204     (*edge_cnt) += FanOut(alt_list, ood_dawg_, 0, 0, NULL, true,
00205                           edge_array + (*edge_cnt));
00206 
00207     // set the root flag for all root edges
00208     for (int edge_idx = 0; edge_idx < (*edge_cnt); edge_idx++) {
00209       edge_array[edge_idx]->SetRoot(true);
00210     }
00211   } else {  // not starting at the root
00212     // preallocate the edge buffer
00213     (*edge_cnt) = max_edge_;
00214     // allocate memory for edges
00215     edge_array = new LangModEdge *[(*edge_cnt)];
00216     if (edge_array == NULL) {
00217       return NULL;
00218     }
00219 
00220     // get the FanOut edges from the root of each dawg
00221     (*edge_cnt) = FanOut(alt_list,
00222                          tess_lm_edge->GetDawg(),
00223                          tess_lm_edge->EndEdge(), tess_lm_edge->EdgeMask(),
00224                          tess_lm_edge->EdgeString(), false, edge_array);
00225   }
00226   return edge_array;
00227 }
00228 
00229 // generate edges from an NULL terminated string
00230 // (used for punctuation, operators and digits)
00231 int TessLangModel::Edges(const char *strng, const Dawg *dawg,
00232                          EDGE_REF edge_ref, EDGE_REF edge_mask,
00233                          LangModEdge **edge_array) {
00234   int edge_idx,
00235     edge_cnt = 0;
00236 
00237   for (edge_idx = 0; strng[edge_idx] != 0; edge_idx++) {
00238     int class_id = cntxt_->CharacterSet()->ClassID((char_32)strng[edge_idx]);
00239     if (class_id != INVALID_UNICHAR_ID) {
00240       // create an edge object
00241       edge_array[edge_cnt] = new TessLangModEdge(cntxt_, dawg, edge_ref,
00242                                                  class_id);
00243       if (edge_array[edge_cnt] == NULL) {
00244         return 0;
00245       }
00246 
00247       reinterpret_cast<TessLangModEdge *>(edge_array[edge_cnt])->
00248           SetEdgeMask(edge_mask);
00249       edge_cnt++;
00250     }
00251   }
00252 
00253   return edge_cnt;
00254 }
00255 
00256 // generate OOD edges
00257 int TessLangModel::OODEdges(CharAltList *alt_list, EDGE_REF edge_ref,
00258                             EDGE_REF edge_ref_mask, LangModEdge **edge_array) {
00259   int class_cnt = cntxt_->CharacterSet()->ClassCount();
00260   int edge_cnt = 0;
00261   for (int class_id = 0; class_id < class_cnt; class_id++) {
00262     // produce an OOD edge only if the cost of the char is low enough
00263     if ((alt_list == NULL ||
00264          alt_list->ClassCost(class_id) <= max_ood_shape_cost_)) {
00265       // create an edge object
00266       edge_array[edge_cnt] = new TessLangModEdge(cntxt_, class_id);
00267       if (edge_array[edge_cnt] == NULL) {
00268         return 0;
00269       }
00270 
00271       edge_cnt++;
00272     }
00273   }
00274 
00275   return edge_cnt;
00276 }
00277 
00278 // computes and returns the edges that fan out of an edge ref
00279 int TessLangModel::FanOut(CharAltList *alt_list, const Dawg *dawg,
00280                           EDGE_REF edge_ref, EDGE_REF edge_mask,
00281                           const char_32 *str, bool root_flag,
00282                           LangModEdge **edge_array) {
00283   int edge_cnt = 0;
00284   NODE_REF next_node = NO_EDGE;
00285 
00286   // OOD
00287   if (dawg == reinterpret_cast<Dawg *>(DAWG_OOD)) {
00288     if (ood_enabled_ == true) {
00289       return OODEdges(alt_list, edge_ref, edge_mask, edge_array);
00290     } else {
00291       return 0;
00292     }
00293   } else if (dawg == reinterpret_cast<Dawg *>(DAWG_NUMBER)) {
00294     // Number
00295     if (numeric_enabled_ == true) {
00296       return NumberEdges(edge_ref, edge_array);
00297     } else {
00298       return 0;
00299     }
00300   } else if (IsTrailingPuncEdge(edge_mask)) {
00301     // a TRAILING PUNC MASK, generate more trailing punctuation and return
00302     if (punc_enabled_ == true) {
00303       EDGE_REF trail_cnt = TrailingPuncCount(edge_mask);
00304       return Edges(trail_punc_.c_str(), dawg, edge_ref,
00305                    TrailingPuncEdgeMask(trail_cnt + 1), edge_array);
00306     } else {
00307       return 0;
00308     }
00309   } else if (root_flag == true || edge_ref == 0) {
00310     // Root, generate leading punctuation and continue
00311     if (root_flag) {
00312       if (punc_enabled_ == true) {
00313         edge_cnt += Edges(lead_punc_.c_str(), dawg, 0, LEAD_PUNC_EDGE_REF_MASK,
00314                           edge_array);
00315       }
00316     }
00317     next_node = 0;
00318   } else {
00319     // a node in the main trie
00320     bool eow_flag = (dawg->end_of_word(edge_ref) != 0);
00321 
00322     // for EOW
00323     if (eow_flag == true) {
00324       // generate trailing punctuation
00325       if (punc_enabled_ == true) {
00326         edge_cnt += Edges(trail_punc_.c_str(), dawg, edge_ref,
00327                           TrailingPuncEdgeMask((EDGE_REF)1), edge_array);
00328         // generate a hyphen and go back to the root
00329         edge_cnt += Edges("-/", dawg, 0, 0, edge_array + edge_cnt);
00330       }
00331     }
00332 
00333     // advance node
00334     next_node = dawg->next_node(edge_ref);
00335     if (next_node == 0 || next_node == NO_EDGE) {
00336       return edge_cnt;
00337     }
00338   }
00339 
00340   // now get all the emerging edges if word list is enabled
00341   if (word_list_enabled_ == true && next_node != NO_EDGE) {
00342     // create child edges
00343     int child_edge_cnt =
00344       TessLangModEdge::CreateChildren(cntxt_, dawg, next_node,
00345                                       edge_array + edge_cnt);
00346     int strt_cnt = edge_cnt;
00347 
00348     // set the edge mask
00349     for (int child = 0; child < child_edge_cnt; child++) {
00350       reinterpret_cast<TessLangModEdge *>(edge_array[edge_cnt++])->
00351           SetEdgeMask(edge_mask);
00352     }
00353 
00354     // if we are at the root, create upper case forms of these edges if possible
00355     if (root_flag == true) {
00356       for (int child = 0; child < child_edge_cnt; child++) {
00357         TessLangModEdge *child_edge =
00358             reinterpret_cast<TessLangModEdge *>(edge_array[strt_cnt + child]);
00359 
00360         if (has_case_ == true) {
00361           const char_32 *edge_str = child_edge->EdgeString();
00362           if (edge_str != NULL && islower(edge_str[0]) != 0 &&
00363               edge_str[1] == 0) {
00364             int class_id =
00365                 cntxt_->CharacterSet()->ClassID(toupper(edge_str[0]));
00366             if (class_id != INVALID_UNICHAR_ID) {
00367               // generate an upper case edge for lower case chars
00368               edge_array[edge_cnt] = new TessLangModEdge(cntxt_, dawg,
00369                   child_edge->StartEdge(), child_edge->EndEdge(), class_id);
00370 
00371               if (edge_array[edge_cnt] != NULL) {
00372                 reinterpret_cast<TessLangModEdge *>(edge_array[edge_cnt])->
00373                     SetEdgeMask(edge_mask);
00374                 edge_cnt++;
00375               }
00376             }
00377           }
00378         }
00379       }
00380     }
00381   }
00382   return edge_cnt;
00383 }
00384 
00385 // Generate the edges fanning-out from an edge in the number state machine
00386 int TessLangModel::NumberEdges(EDGE_REF edge_ref, LangModEdge **edge_array) {
00387   EDGE_REF new_state,
00388     state;
00389 
00390   int repeat_cnt,
00391     new_repeat_cnt;
00392 
00393   state = ((edge_ref & NUMBER_STATE_MASK) >> NUMBER_STATE_SHIFT);
00394   repeat_cnt = ((edge_ref & NUMBER_REPEAT_MASK) >> NUMBER_REPEAT_SHIFT);
00395 
00396   if (state < 0 || state >= kStateCnt) {
00397     return 0;
00398   }
00399 
00400   // go thru all valid transitions from the state
00401   int edge_cnt = 0;
00402 
00403   EDGE_REF new_edge_ref;
00404 
00405   for (int lit = 0; lit < kNumLiteralCnt; lit++) {
00406     // move to the new state
00407     new_state = num_state_machine_[state][lit];
00408     if (new_state == NUM_TRM) {
00409       continue;
00410     }
00411 
00412     if (new_state == state) {
00413       new_repeat_cnt = repeat_cnt + 1;
00414     } else {
00415       new_repeat_cnt = 1;
00416     }
00417 
00418     // not allowed to repeat beyond this
00419     if (new_repeat_cnt > num_max_repeat_[state]) {
00420       continue;
00421     }
00422 
00423     new_edge_ref = (new_state << NUMBER_STATE_SHIFT) |
00424         (lit << NUMBER_LITERAL_SHIFT) |
00425         (new_repeat_cnt << NUMBER_REPEAT_SHIFT);
00426 
00427     edge_cnt += Edges(literal_str_[lit]->c_str(), number_dawg_,
00428                       new_edge_ref, 0, edge_array + edge_cnt);
00429   }
00430 
00431   return edge_cnt;
00432 }
00433 
00434 // Loads Language model elements from contents of the <lang>.cube.lm file
00435 bool TessLangModel::LoadLangModelElements(const string &lm_params) {
00436   bool success = true;
00437   // split into lines, each corresponding to a token type below
00438   vector<string> str_vec;
00439   CubeUtils::SplitStringUsing(lm_params, "\r\n", &str_vec);
00440   for (int entry = 0; entry < str_vec.size(); entry++) {
00441     vector<string> tokens;
00442     // should be only two tokens: type and value
00443     CubeUtils::SplitStringUsing(str_vec[entry], "=", &tokens);
00444     if (tokens.size() != 2)
00445       success = false;
00446     if (tokens[0] == "LeadPunc") {
00447       lead_punc_ = tokens[1];
00448     } else if (tokens[0] == "TrailPunc") {
00449       trail_punc_ = tokens[1];
00450     } else if (tokens[0] == "NumLeadPunc") {
00451       num_lead_punc_ = tokens[1];
00452     } else if (tokens[0] == "NumTrailPunc") {
00453       num_trail_punc_ = tokens[1];
00454     } else if (tokens[0] == "Operators") {
00455       operators_ = tokens[1];
00456     } else if (tokens[0] == "Digits") {
00457       digits_ = tokens[1];
00458     } else if (tokens[0] == "Alphas") {
00459       alphas_ = tokens[1];
00460     } else {
00461       success = false;
00462     }
00463   }
00464 
00465   RemoveInvalidCharacters(&num_lead_punc_);
00466   RemoveInvalidCharacters(&num_trail_punc_);
00467   RemoveInvalidCharacters(&digits_);
00468   RemoveInvalidCharacters(&operators_);
00469   RemoveInvalidCharacters(&alphas_);
00470 
00471   // form the array of literal strings needed for number state machine
00472   // It is essential that the literal strings go in the order below
00473   literal_str_[0] = &num_lead_punc_;
00474   literal_str_[1] = &num_trail_punc_;
00475   literal_str_[2] = &digits_;
00476   literal_str_[3] = &operators_;
00477   literal_str_[4] = &alphas_;
00478 
00479   return success;
00480 }
00481 
00482 void TessLangModel::RemoveInvalidCharacters(string *lm_str) {
00483   CharSet *char_set = cntxt_->CharacterSet();
00484   tesseract::string_32 lm_str32;
00485   CubeUtils::UTF8ToUTF32(lm_str->c_str(), &lm_str32);
00486 
00487   int len = CubeUtils::StrLen(lm_str32.c_str());
00488   char_32 *clean_str32 = new char_32[len + 1];
00489   if (!clean_str32)
00490     return;
00491   int clean_len = 0;
00492   for (int i = 0; i < len; ++i) {
00493     int class_id = char_set->ClassID((char_32)lm_str32[i]);
00494     if (class_id != INVALID_UNICHAR_ID) {
00495       clean_str32[clean_len] = lm_str32[i];
00496       ++clean_len;
00497     }
00498   }
00499   clean_str32[clean_len] = 0;
00500   if (clean_len < len) {
00501     lm_str->clear();
00502     CubeUtils::UTF32ToUTF8(clean_str32, lm_str);
00503   }
00504   delete [] clean_str32;
00505 }
00506 
00507 int TessLangModel::NumDawgs() const {
00508   return (word_dawgs_ != NULL) ?
00509       word_dawgs_->size() : cntxt_->TesseractObject()->getDict().NumDawgs();
00510 }
00511 
00512 // Returns the dawgs with the given index from either the dawgs
00513 // stored by the Tesseract object, or the word_dawgs_.
00514 const Dawg *TessLangModel::GetDawg(int index) const {
00515   if (word_dawgs_ != NULL) {
00516     ASSERT_HOST(index < word_dawgs_->size());
00517     return (*word_dawgs_)[index];
00518   } else {
00519     ASSERT_HOST(index < cntxt_->TesseractObject()->getDict().NumDawgs());
00520     return cntxt_->TesseractObject()->getDict().GetDawg(index);
00521   }
00522 }
00523 }