Tesseract
3.02
|
00001 /********************************************************************** 00002 * File: tess_lang_model.cpp 00003 * Description: Implementation of the Tesseract Language Model Class 00004 * Author: Ahmad Abdulkader 00005 * Created: 2008 00006 * 00007 * (C) Copyright 2008, Google Inc. 00008 ** Licensed under the Apache License, Version 2.0 (the "License"); 00009 ** you may not use this file except in compliance with the License. 00010 ** You may obtain a copy of the License at 00011 ** http://www.apache.org/licenses/LICENSE-2.0 00012 ** Unless required by applicable law or agreed to in writing, software 00013 ** distributed under the License is distributed on an "AS IS" BASIS, 00014 ** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 00015 ** See the License for the specific language governing permissions and 00016 ** limitations under the License. 00017 * 00018 **********************************************************************/ 00019 00020 // The TessLangModel class abstracts the Tesseract language model. It inherits 00021 // from the LangModel class. The Tesseract language model encompasses several 00022 // Dawgs (words from training data, punctuation, numbers, document words). 00023 // On top of this Cube adds an OOD state machine 00024 // The class provides methods to traverse the language model in a generative 00025 // fashion. Given any node in the DAWG, the language model can generate a list 00026 // of children (or fan-out) edges 00027 00028 #include <string> 00029 #include <vector> 00030 00031 #include "char_samp.h" 00032 #include "cube_utils.h" 00033 #include "dict.h" 00034 #include "tesseractclass.h" 00035 #include "tess_lang_model.h" 00036 #include "tessdatamanager.h" 00037 #include "unicharset.h" 00038 00039 namespace tesseract { 00040 // max fan-out (used for preallocation). Initialized here, but modified by 00041 // constructor 00042 int TessLangModel::max_edge_ = 4096; 00043 00044 // Language model extra State machines 00045 const Dawg *TessLangModel::ood_dawg_ = reinterpret_cast<Dawg *>(DAWG_OOD); 00046 const Dawg *TessLangModel::number_dawg_ = reinterpret_cast<Dawg *>(DAWG_NUMBER); 00047 00048 // number state machine 00049 const int TessLangModel::num_state_machine_[kStateCnt][kNumLiteralCnt] = { 00050 {0, 1, 1, NUM_TRM, NUM_TRM}, 00051 {NUM_TRM, 1, 1, 3, 2}, 00052 {NUM_TRM, NUM_TRM, 1, NUM_TRM, 2}, 00053 {NUM_TRM, NUM_TRM, 3, NUM_TRM, 2}, 00054 }; 00055 const int TessLangModel::num_max_repeat_[kStateCnt] = {3, 32, 8, 3}; 00056 00057 // thresholds and penalties 00058 int TessLangModel::max_ood_shape_cost_ = CubeUtils::Prob2Cost(1e-4); 00059 00060 TessLangModel::TessLangModel(const string &lm_params, 00061 const string &data_file_path, 00062 bool load_system_dawg, 00063 TessdataManager *tessdata_manager, 00064 CubeRecoContext *cntxt) { 00065 cntxt_ = cntxt; 00066 has_case_ = cntxt_->HasCase(); 00067 // Load the rest of the language model elements from file 00068 LoadLangModelElements(lm_params); 00069 // Load word_dawgs_ if needed. 00070 if (tessdata_manager->SeekToStart(TESSDATA_CUBE_UNICHARSET)) { 00071 word_dawgs_ = new DawgVector(); 00072 if (load_system_dawg && 00073 tessdata_manager->SeekToStart(TESSDATA_CUBE_SYSTEM_DAWG)) { 00074 // The last parameter to the Dawg constructor (the debug level) is set to 00075 // false, until Cube has a way to express its preferred debug level. 00076 *word_dawgs_ += new SquishedDawg(tessdata_manager->GetDataFilePtr(), 00077 DAWG_TYPE_WORD, 00078 cntxt_->Lang().c_str(), 00079 SYSTEM_DAWG_PERM, false); 00080 } 00081 } else { 00082 word_dawgs_ = NULL; 00083 } 00084 } 00085 00086 // Cleanup an edge array 00087 void TessLangModel::FreeEdges(int edge_cnt, LangModEdge **edge_array) { 00088 if (edge_array != NULL) { 00089 for (int edge_idx = 0; edge_idx < edge_cnt; edge_idx++) { 00090 if (edge_array[edge_idx] != NULL) { 00091 delete edge_array[edge_idx]; 00092 } 00093 } 00094 delete []edge_array; 00095 } 00096 } 00097 00098 // Determines if a sequence of 32-bit chars is valid in this language model 00099 // starting from the specified edge. If the eow_flag is ON, also checks for 00100 // a valid EndOfWord. If final_edge is not NULL, returns a pointer to the last 00101 // edge 00102 bool TessLangModel::IsValidSequence(LangModEdge *edge, 00103 const char_32 *sequence, 00104 bool eow_flag, 00105 LangModEdge **final_edge) { 00106 // get the edges emerging from this edge 00107 int edge_cnt = 0; 00108 LangModEdge **edge_array = GetEdges(NULL, edge, &edge_cnt); 00109 00110 // find the 1st char in the sequence in the children 00111 for (int edge_idx = 0; edge_idx < edge_cnt; edge_idx++) { 00112 // found a match 00113 if (sequence[0] == edge_array[edge_idx]->EdgeString()[0]) { 00114 // if this is the last char 00115 if (sequence[1] == 0) { 00116 // succeed if we are in prefix mode or this is a terminal edge 00117 if (eow_flag == false || edge_array[edge_idx]->IsEOW()) { 00118 if (final_edge != NULL) { 00119 (*final_edge) = edge_array[edge_idx]; 00120 edge_array[edge_idx] = NULL; 00121 } 00122 00123 FreeEdges(edge_cnt, edge_array); 00124 return true; 00125 } 00126 } else { 00127 // not the last char continue checking 00128 if (IsValidSequence(edge_array[edge_idx], sequence + 1, eow_flag, 00129 final_edge) == true) { 00130 FreeEdges(edge_cnt, edge_array); 00131 return true; 00132 } 00133 } 00134 } 00135 } 00136 00137 FreeEdges(edge_cnt, edge_array); 00138 return false; 00139 } 00140 00141 // Determines if a sequence of 32-bit chars is valid in this language model 00142 // starting from the root. If the eow_flag is ON, also checks for 00143 // a valid EndOfWord. If final_edge is not NULL, returns a pointer to the last 00144 // edge 00145 bool TessLangModel::IsValidSequence(const char_32 *sequence, bool eow_flag, 00146 LangModEdge **final_edge) { 00147 if (final_edge != NULL) { 00148 (*final_edge) = NULL; 00149 } 00150 00151 return IsValidSequence(NULL, sequence, eow_flag, final_edge); 00152 } 00153 00154 bool TessLangModel::IsLeadingPunc(const char_32 ch) { 00155 return lead_punc_.find(ch) != string::npos; 00156 } 00157 00158 bool TessLangModel::IsTrailingPunc(const char_32 ch) { 00159 return trail_punc_.find(ch) != string::npos; 00160 } 00161 00162 bool TessLangModel::IsDigit(const char_32 ch) { 00163 return digits_.find(ch) != string::npos; 00164 } 00165 00166 // The general fan-out generation function. Returns the list of edges 00167 // fanning-out of the specified edge and their count. If an AltList is 00168 // specified, only the class-ids with a minimum cost are considered 00169 LangModEdge ** TessLangModel::GetEdges(CharAltList *alt_list, 00170 LangModEdge *lang_mod_edge, 00171 int *edge_cnt) { 00172 TessLangModEdge *tess_lm_edge = 00173 reinterpret_cast<TessLangModEdge *>(lang_mod_edge); 00174 LangModEdge **edge_array = NULL; 00175 (*edge_cnt) = 0; 00176 00177 // if we are starting from the root, we'll instantiate every DAWG 00178 // and get the all the edges that emerge from the root 00179 if (tess_lm_edge == NULL) { 00180 // get DAWG count from Tesseract 00181 int dawg_cnt = NumDawgs(); 00182 // preallocate the edge buffer 00183 (*edge_cnt) = dawg_cnt * max_edge_; 00184 edge_array = new LangModEdge *[(*edge_cnt)]; 00185 if (edge_array == NULL) { 00186 return NULL; 00187 } 00188 00189 for (int dawg_idx = (*edge_cnt) = 0; dawg_idx < dawg_cnt; dawg_idx++) { 00190 const Dawg *curr_dawg = GetDawg(dawg_idx); 00191 // Only look through word Dawgs (since there is a special way of 00192 // handling numbers and punctuation). 00193 if (curr_dawg->type() == DAWG_TYPE_WORD) { 00194 (*edge_cnt) += FanOut(alt_list, curr_dawg, 0, 0, NULL, true, 00195 edge_array + (*edge_cnt)); 00196 } 00197 } // dawg 00198 00199 (*edge_cnt) += FanOut(alt_list, number_dawg_, 0, 0, NULL, true, 00200 edge_array + (*edge_cnt)); 00201 00202 // OOD: it is intentionally not added to the list to make sure it comes 00203 // at the end 00204 (*edge_cnt) += FanOut(alt_list, ood_dawg_, 0, 0, NULL, true, 00205 edge_array + (*edge_cnt)); 00206 00207 // set the root flag for all root edges 00208 for (int edge_idx = 0; edge_idx < (*edge_cnt); edge_idx++) { 00209 edge_array[edge_idx]->SetRoot(true); 00210 } 00211 } else { // not starting at the root 00212 // preallocate the edge buffer 00213 (*edge_cnt) = max_edge_; 00214 // allocate memory for edges 00215 edge_array = new LangModEdge *[(*edge_cnt)]; 00216 if (edge_array == NULL) { 00217 return NULL; 00218 } 00219 00220 // get the FanOut edges from the root of each dawg 00221 (*edge_cnt) = FanOut(alt_list, 00222 tess_lm_edge->GetDawg(), 00223 tess_lm_edge->EndEdge(), tess_lm_edge->EdgeMask(), 00224 tess_lm_edge->EdgeString(), false, edge_array); 00225 } 00226 return edge_array; 00227 } 00228 00229 // generate edges from an NULL terminated string 00230 // (used for punctuation, operators and digits) 00231 int TessLangModel::Edges(const char *strng, const Dawg *dawg, 00232 EDGE_REF edge_ref, EDGE_REF edge_mask, 00233 LangModEdge **edge_array) { 00234 int edge_idx, 00235 edge_cnt = 0; 00236 00237 for (edge_idx = 0; strng[edge_idx] != 0; edge_idx++) { 00238 int class_id = cntxt_->CharacterSet()->ClassID((char_32)strng[edge_idx]); 00239 if (class_id != INVALID_UNICHAR_ID) { 00240 // create an edge object 00241 edge_array[edge_cnt] = new TessLangModEdge(cntxt_, dawg, edge_ref, 00242 class_id); 00243 if (edge_array[edge_cnt] == NULL) { 00244 return 0; 00245 } 00246 00247 reinterpret_cast<TessLangModEdge *>(edge_array[edge_cnt])-> 00248 SetEdgeMask(edge_mask); 00249 edge_cnt++; 00250 } 00251 } 00252 00253 return edge_cnt; 00254 } 00255 00256 // generate OOD edges 00257 int TessLangModel::OODEdges(CharAltList *alt_list, EDGE_REF edge_ref, 00258 EDGE_REF edge_ref_mask, LangModEdge **edge_array) { 00259 int class_cnt = cntxt_->CharacterSet()->ClassCount(); 00260 int edge_cnt = 0; 00261 for (int class_id = 0; class_id < class_cnt; class_id++) { 00262 // produce an OOD edge only if the cost of the char is low enough 00263 if ((alt_list == NULL || 00264 alt_list->ClassCost(class_id) <= max_ood_shape_cost_)) { 00265 // create an edge object 00266 edge_array[edge_cnt] = new TessLangModEdge(cntxt_, class_id); 00267 if (edge_array[edge_cnt] == NULL) { 00268 return 0; 00269 } 00270 00271 edge_cnt++; 00272 } 00273 } 00274 00275 return edge_cnt; 00276 } 00277 00278 // computes and returns the edges that fan out of an edge ref 00279 int TessLangModel::FanOut(CharAltList *alt_list, const Dawg *dawg, 00280 EDGE_REF edge_ref, EDGE_REF edge_mask, 00281 const char_32 *str, bool root_flag, 00282 LangModEdge **edge_array) { 00283 int edge_cnt = 0; 00284 NODE_REF next_node = NO_EDGE; 00285 00286 // OOD 00287 if (dawg == reinterpret_cast<Dawg *>(DAWG_OOD)) { 00288 if (ood_enabled_ == true) { 00289 return OODEdges(alt_list, edge_ref, edge_mask, edge_array); 00290 } else { 00291 return 0; 00292 } 00293 } else if (dawg == reinterpret_cast<Dawg *>(DAWG_NUMBER)) { 00294 // Number 00295 if (numeric_enabled_ == true) { 00296 return NumberEdges(edge_ref, edge_array); 00297 } else { 00298 return 0; 00299 } 00300 } else if (IsTrailingPuncEdge(edge_mask)) { 00301 // a TRAILING PUNC MASK, generate more trailing punctuation and return 00302 if (punc_enabled_ == true) { 00303 EDGE_REF trail_cnt = TrailingPuncCount(edge_mask); 00304 return Edges(trail_punc_.c_str(), dawg, edge_ref, 00305 TrailingPuncEdgeMask(trail_cnt + 1), edge_array); 00306 } else { 00307 return 0; 00308 } 00309 } else if (root_flag == true || edge_ref == 0) { 00310 // Root, generate leading punctuation and continue 00311 if (root_flag) { 00312 if (punc_enabled_ == true) { 00313 edge_cnt += Edges(lead_punc_.c_str(), dawg, 0, LEAD_PUNC_EDGE_REF_MASK, 00314 edge_array); 00315 } 00316 } 00317 next_node = 0; 00318 } else { 00319 // a node in the main trie 00320 bool eow_flag = (dawg->end_of_word(edge_ref) != 0); 00321 00322 // for EOW 00323 if (eow_flag == true) { 00324 // generate trailing punctuation 00325 if (punc_enabled_ == true) { 00326 edge_cnt += Edges(trail_punc_.c_str(), dawg, edge_ref, 00327 TrailingPuncEdgeMask((EDGE_REF)1), edge_array); 00328 // generate a hyphen and go back to the root 00329 edge_cnt += Edges("-/", dawg, 0, 0, edge_array + edge_cnt); 00330 } 00331 } 00332 00333 // advance node 00334 next_node = dawg->next_node(edge_ref); 00335 if (next_node == 0 || next_node == NO_EDGE) { 00336 return edge_cnt; 00337 } 00338 } 00339 00340 // now get all the emerging edges if word list is enabled 00341 if (word_list_enabled_ == true && next_node != NO_EDGE) { 00342 // create child edges 00343 int child_edge_cnt = 00344 TessLangModEdge::CreateChildren(cntxt_, dawg, next_node, 00345 edge_array + edge_cnt); 00346 int strt_cnt = edge_cnt; 00347 00348 // set the edge mask 00349 for (int child = 0; child < child_edge_cnt; child++) { 00350 reinterpret_cast<TessLangModEdge *>(edge_array[edge_cnt++])-> 00351 SetEdgeMask(edge_mask); 00352 } 00353 00354 // if we are at the root, create upper case forms of these edges if possible 00355 if (root_flag == true) { 00356 for (int child = 0; child < child_edge_cnt; child++) { 00357 TessLangModEdge *child_edge = 00358 reinterpret_cast<TessLangModEdge *>(edge_array[strt_cnt + child]); 00359 00360 if (has_case_ == true) { 00361 const char_32 *edge_str = child_edge->EdgeString(); 00362 if (edge_str != NULL && islower(edge_str[0]) != 0 && 00363 edge_str[1] == 0) { 00364 int class_id = 00365 cntxt_->CharacterSet()->ClassID(toupper(edge_str[0])); 00366 if (class_id != INVALID_UNICHAR_ID) { 00367 // generate an upper case edge for lower case chars 00368 edge_array[edge_cnt] = new TessLangModEdge(cntxt_, dawg, 00369 child_edge->StartEdge(), child_edge->EndEdge(), class_id); 00370 00371 if (edge_array[edge_cnt] != NULL) { 00372 reinterpret_cast<TessLangModEdge *>(edge_array[edge_cnt])-> 00373 SetEdgeMask(edge_mask); 00374 edge_cnt++; 00375 } 00376 } 00377 } 00378 } 00379 } 00380 } 00381 } 00382 return edge_cnt; 00383 } 00384 00385 // Generate the edges fanning-out from an edge in the number state machine 00386 int TessLangModel::NumberEdges(EDGE_REF edge_ref, LangModEdge **edge_array) { 00387 EDGE_REF new_state, 00388 state; 00389 00390 int repeat_cnt, 00391 new_repeat_cnt; 00392 00393 state = ((edge_ref & NUMBER_STATE_MASK) >> NUMBER_STATE_SHIFT); 00394 repeat_cnt = ((edge_ref & NUMBER_REPEAT_MASK) >> NUMBER_REPEAT_SHIFT); 00395 00396 if (state < 0 || state >= kStateCnt) { 00397 return 0; 00398 } 00399 00400 // go thru all valid transitions from the state 00401 int edge_cnt = 0; 00402 00403 EDGE_REF new_edge_ref; 00404 00405 for (int lit = 0; lit < kNumLiteralCnt; lit++) { 00406 // move to the new state 00407 new_state = num_state_machine_[state][lit]; 00408 if (new_state == NUM_TRM) { 00409 continue; 00410 } 00411 00412 if (new_state == state) { 00413 new_repeat_cnt = repeat_cnt + 1; 00414 } else { 00415 new_repeat_cnt = 1; 00416 } 00417 00418 // not allowed to repeat beyond this 00419 if (new_repeat_cnt > num_max_repeat_[state]) { 00420 continue; 00421 } 00422 00423 new_edge_ref = (new_state << NUMBER_STATE_SHIFT) | 00424 (lit << NUMBER_LITERAL_SHIFT) | 00425 (new_repeat_cnt << NUMBER_REPEAT_SHIFT); 00426 00427 edge_cnt += Edges(literal_str_[lit]->c_str(), number_dawg_, 00428 new_edge_ref, 0, edge_array + edge_cnt); 00429 } 00430 00431 return edge_cnt; 00432 } 00433 00434 // Loads Language model elements from contents of the <lang>.cube.lm file 00435 bool TessLangModel::LoadLangModelElements(const string &lm_params) { 00436 bool success = true; 00437 // split into lines, each corresponding to a token type below 00438 vector<string> str_vec; 00439 CubeUtils::SplitStringUsing(lm_params, "\r\n", &str_vec); 00440 for (int entry = 0; entry < str_vec.size(); entry++) { 00441 vector<string> tokens; 00442 // should be only two tokens: type and value 00443 CubeUtils::SplitStringUsing(str_vec[entry], "=", &tokens); 00444 if (tokens.size() != 2) 00445 success = false; 00446 if (tokens[0] == "LeadPunc") { 00447 lead_punc_ = tokens[1]; 00448 } else if (tokens[0] == "TrailPunc") { 00449 trail_punc_ = tokens[1]; 00450 } else if (tokens[0] == "NumLeadPunc") { 00451 num_lead_punc_ = tokens[1]; 00452 } else if (tokens[0] == "NumTrailPunc") { 00453 num_trail_punc_ = tokens[1]; 00454 } else if (tokens[0] == "Operators") { 00455 operators_ = tokens[1]; 00456 } else if (tokens[0] == "Digits") { 00457 digits_ = tokens[1]; 00458 } else if (tokens[0] == "Alphas") { 00459 alphas_ = tokens[1]; 00460 } else { 00461 success = false; 00462 } 00463 } 00464 00465 RemoveInvalidCharacters(&num_lead_punc_); 00466 RemoveInvalidCharacters(&num_trail_punc_); 00467 RemoveInvalidCharacters(&digits_); 00468 RemoveInvalidCharacters(&operators_); 00469 RemoveInvalidCharacters(&alphas_); 00470 00471 // form the array of literal strings needed for number state machine 00472 // It is essential that the literal strings go in the order below 00473 literal_str_[0] = &num_lead_punc_; 00474 literal_str_[1] = &num_trail_punc_; 00475 literal_str_[2] = &digits_; 00476 literal_str_[3] = &operators_; 00477 literal_str_[4] = &alphas_; 00478 00479 return success; 00480 } 00481 00482 void TessLangModel::RemoveInvalidCharacters(string *lm_str) { 00483 CharSet *char_set = cntxt_->CharacterSet(); 00484 tesseract::string_32 lm_str32; 00485 CubeUtils::UTF8ToUTF32(lm_str->c_str(), &lm_str32); 00486 00487 int len = CubeUtils::StrLen(lm_str32.c_str()); 00488 char_32 *clean_str32 = new char_32[len + 1]; 00489 if (!clean_str32) 00490 return; 00491 int clean_len = 0; 00492 for (int i = 0; i < len; ++i) { 00493 int class_id = char_set->ClassID((char_32)lm_str32[i]); 00494 if (class_id != INVALID_UNICHAR_ID) { 00495 clean_str32[clean_len] = lm_str32[i]; 00496 ++clean_len; 00497 } 00498 } 00499 clean_str32[clean_len] = 0; 00500 if (clean_len < len) { 00501 lm_str->clear(); 00502 CubeUtils::UTF32ToUTF8(clean_str32, lm_str); 00503 } 00504 delete [] clean_str32; 00505 } 00506 00507 int TessLangModel::NumDawgs() const { 00508 return (word_dawgs_ != NULL) ? 00509 word_dawgs_->size() : cntxt_->TesseractObject()->getDict().NumDawgs(); 00510 } 00511 00512 // Returns the dawgs with the given index from either the dawgs 00513 // stored by the Tesseract object, or the word_dawgs_. 00514 const Dawg *TessLangModel::GetDawg(int index) const { 00515 if (word_dawgs_ != NULL) { 00516 ASSERT_HOST(index < word_dawgs_->size()); 00517 return (*word_dawgs_)[index]; 00518 } else { 00519 ASSERT_HOST(index < cntxt_->TesseractObject()->getDict().NumDawgs()); 00520 return cntxt_->TesseractObject()->getDict().GetDawg(index); 00521 } 00522 } 00523 }