Tesseract
3.02
|
00001 /********************************************************************** 00002 * File: beam_search.cpp 00003 * Description: Class to implement Beam Word Search Algorithm 00004 * Author: Ahmad Abdulkader 00005 * Created: 2007 00006 * 00007 * (C) Copyright 2008, Google Inc. 00008 ** Licensed under the Apache License, Version 2.0 (the "License"); 00009 ** you may not use this file except in compliance with the License. 00010 ** You may obtain a copy of the License at 00011 ** http://www.apache.org/licenses/LICENSE-2.0 00012 ** Unless required by applicable law or agreed to in writing, software 00013 ** distributed under the License is distributed on an "AS IS" BASIS, 00014 ** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 00015 ** See the License for the specific language governing permissions and 00016 ** limitations under the License. 00017 * 00018 **********************************************************************/ 00019 00020 #include <algorithm> 00021 00022 #include "beam_search.h" 00023 #include "tesseractclass.h" 00024 00025 namespace tesseract { 00026 00027 BeamSearch::BeamSearch(CubeRecoContext *cntxt, bool word_mode) { 00028 cntxt_ = cntxt; 00029 seg_pt_cnt_ = 0; 00030 col_cnt_ = 1; 00031 col_ = NULL; 00032 word_mode_ = word_mode; 00033 } 00034 00035 // Cleanup the lattice corresponding to the last search 00036 void BeamSearch::Cleanup() { 00037 if (col_ != NULL) { 00038 for (int col = 0; col < col_cnt_; col++) { 00039 if (col_[col]) 00040 delete col_[col]; 00041 } 00042 delete []col_; 00043 } 00044 col_ = NULL; 00045 } 00046 00047 BeamSearch::~BeamSearch() { 00048 Cleanup(); 00049 } 00050 00051 // Creates a set of children nodes emerging from a parent node based on 00052 // the character alternate list and the language model. 00053 void BeamSearch::CreateChildren(SearchColumn *out_col, LangModel *lang_mod, 00054 SearchNode *parent_node, 00055 LangModEdge *lm_parent_edge, 00056 CharAltList *char_alt_list, int extra_cost) { 00057 // get all the edges from this parent 00058 int edge_cnt; 00059 LangModEdge **lm_edges = lang_mod->GetEdges(char_alt_list, 00060 lm_parent_edge, &edge_cnt); 00061 if (lm_edges) { 00062 // add them to the ending column with the appropriate parent 00063 for (int edge = 0; edge < edge_cnt; edge++) { 00064 // add a node to the column if the current column is not the 00065 // last one, or if the lang model edge indicates it is valid EOW 00066 if (!cntxt_->NoisyInput() && out_col->ColIdx() >= seg_pt_cnt_ && 00067 !lm_edges[edge]->IsEOW()) { 00068 // free edge since no object is going to own it 00069 delete lm_edges[edge]; 00070 continue; 00071 } 00072 00073 // compute the recognition cost of this node 00074 int recognition_cost = MIN_PROB_COST; 00075 if (char_alt_list && char_alt_list->AltCount() > 0) { 00076 recognition_cost = MAX(0, char_alt_list->ClassCost( 00077 lm_edges[edge]->ClassID())); 00078 // Add the no space cost. This should zero in word mode 00079 recognition_cost += extra_cost; 00080 } 00081 00082 // Note that the edge will be freed inside the column if 00083 // AddNode is called 00084 if (recognition_cost >= 0) { 00085 out_col->AddNode(lm_edges[edge], recognition_cost, parent_node, 00086 cntxt_); 00087 } else { 00088 delete lm_edges[edge]; 00089 } 00090 } // edge 00091 // free edge array 00092 delete []lm_edges; 00093 } // lm_edges 00094 } 00095 00096 // Performs a beam seach in the specified search using the specified 00097 // language model; returns an alternate list of possible words as a result. 00098 WordAltList * BeamSearch::Search(SearchObject *srch_obj, LangModel *lang_mod) { 00099 // verifications 00100 if (!lang_mod) 00101 lang_mod = cntxt_->LangMod(); 00102 if (!lang_mod) { 00103 fprintf(stderr, "Cube ERROR (BeamSearch::Search): could not construct " 00104 "LangModel\n"); 00105 return NULL; 00106 } 00107 00108 // free existing state 00109 Cleanup(); 00110 00111 // get seg pt count 00112 seg_pt_cnt_ = srch_obj->SegPtCnt(); 00113 if (seg_pt_cnt_ < 0) { 00114 return NULL; 00115 } 00116 col_cnt_ = seg_pt_cnt_ + 1; 00117 00118 // disregard suspicious cases 00119 if (seg_pt_cnt_ > 128) { 00120 fprintf(stderr, "Cube ERROR (BeamSearch::Search): segment point count is " 00121 "suspiciously high; bailing out\n"); 00122 return NULL; 00123 } 00124 00125 // alloc memory for columns 00126 col_ = new SearchColumn *[col_cnt_]; 00127 if (!col_) { 00128 fprintf(stderr, "Cube ERROR (BeamSearch::Search): could not construct " 00129 "SearchColumn array\n"); 00130 return NULL; 00131 } 00132 memset(col_, 0, col_cnt_ * sizeof(*col_)); 00133 00134 // for all possible segments 00135 for (int end_seg = 1; end_seg <= (seg_pt_cnt_ + 1); end_seg++) { 00136 // create a search column 00137 col_[end_seg - 1] = new SearchColumn(end_seg - 1, 00138 cntxt_->Params()->BeamWidth()); 00139 if (!col_[end_seg - 1]) { 00140 fprintf(stderr, "Cube ERROR (BeamSearch::Search): could not construct " 00141 "SearchColumn for column %d\n", end_seg - 1); 00142 return NULL; 00143 } 00144 00145 // for all possible start segments 00146 int init_seg = MAX(0, end_seg - cntxt_->Params()->MaxSegPerChar()); 00147 for (int strt_seg = init_seg; strt_seg < end_seg; strt_seg++) { 00148 int parent_nodes_cnt; 00149 SearchNode **parent_nodes; 00150 00151 // for the root segment, we do not have a parent 00152 if (strt_seg == 0) { 00153 parent_nodes_cnt = 1; 00154 parent_nodes = NULL; 00155 } else { 00156 // for all the existing nodes in the starting column 00157 parent_nodes_cnt = col_[strt_seg - 1]->NodeCount(); 00158 parent_nodes = col_[strt_seg - 1]->Nodes(); 00159 } 00160 00161 // run the shape recognizer 00162 CharAltList *char_alt_list = srch_obj->RecognizeSegment(strt_seg - 1, 00163 end_seg - 1); 00164 // for all the possible parents 00165 for (int parent_idx = 0; parent_idx < parent_nodes_cnt; parent_idx++) { 00166 // point to the parent node 00167 SearchNode *parent_node = !parent_nodes ? NULL 00168 : parent_nodes[parent_idx]; 00169 LangModEdge *lm_parent_edge = !parent_node ? lang_mod->Root() 00170 : parent_node->LangModelEdge(); 00171 00172 // compute the cost of not having spaces within the segment range 00173 int contig_cost = srch_obj->NoSpaceCost(strt_seg - 1, end_seg - 1); 00174 00175 // In phrase mode, compute the cost of not having a space before 00176 // this character 00177 int no_space_cost = 0; 00178 if (!word_mode_ && strt_seg > 0) { 00179 no_space_cost = srch_obj->NoSpaceCost(strt_seg - 1); 00180 } 00181 00182 // if the no space cost is low enough 00183 if ((contig_cost + no_space_cost) < MIN_PROB_COST) { 00184 // Add the children nodes 00185 CreateChildren(col_[end_seg - 1], lang_mod, parent_node, 00186 lm_parent_edge, char_alt_list, 00187 contig_cost + no_space_cost); 00188 } 00189 00190 // In phrase mode and if not starting at the root 00191 if (!word_mode_ && strt_seg > 0) { // parent_node must be non-NULL 00192 // consider starting a new word for nodes that are valid EOW 00193 if (parent_node->LangModelEdge()->IsEOW()) { 00194 // get the space cost 00195 int space_cost = srch_obj->SpaceCost(strt_seg - 1); 00196 // if the space cost is low enough 00197 if ((contig_cost + space_cost) < MIN_PROB_COST) { 00198 // Restart the language model and add nodes as children to the 00199 // space node. 00200 CreateChildren(col_[end_seg - 1], lang_mod, parent_node, NULL, 00201 char_alt_list, contig_cost + space_cost); 00202 } 00203 } 00204 } 00205 } // parent 00206 } // strt_seg 00207 00208 // prune the column nodes 00209 col_[end_seg - 1]->Prune(); 00210 00211 // Free the column hash table. No longer needed 00212 col_[end_seg - 1]->FreeHashTable(); 00213 } // end_seg 00214 00215 WordAltList *alt_list = CreateWordAltList(srch_obj); 00216 return alt_list; 00217 } 00218 00219 // Creates a Word alternate list from the results in the lattice. 00220 WordAltList *BeamSearch::CreateWordAltList(SearchObject *srch_obj) { 00221 // create an alternate list of all the nodes in the last column 00222 int node_cnt = col_[col_cnt_ - 1]->NodeCount(); 00223 SearchNode **srch_nodes = col_[col_cnt_ - 1]->Nodes(); 00224 CharBigrams *bigrams = cntxt_->Bigrams(); 00225 WordUnigrams *word_unigrams = cntxt_->WordUnigramsObj(); 00226 00227 // Save the index of the best-cost node before the alt list is 00228 // sorted, so that we can retrieve it from the node list when backtracking. 00229 best_presorted_node_idx_ = 0; 00230 int best_cost = -1; 00231 00232 if (node_cnt <= 0) 00233 return NULL; 00234 00235 // start creating the word alternate list 00236 WordAltList *alt_list = new WordAltList(node_cnt + 1); 00237 for (int node_idx = 0; node_idx < node_cnt; node_idx++) { 00238 // recognition cost 00239 int recognition_cost = srch_nodes[node_idx]->BestCost(); 00240 // compute the size cost of the alternate 00241 char_32 *ch_buff = NULL; 00242 int size_cost = SizeCost(srch_obj, srch_nodes[node_idx], &ch_buff); 00243 // accumulate other costs 00244 if (ch_buff) { 00245 int cost = 0; 00246 // char bigram cost 00247 int bigram_cost = !bigrams ? 0 : 00248 bigrams->Cost(ch_buff, cntxt_->CharacterSet()); 00249 // word unigram cost 00250 int unigram_cost = !word_unigrams ? 0 : 00251 word_unigrams->Cost(ch_buff, cntxt_->LangMod(), 00252 cntxt_->CharacterSet()); 00253 // overall cost 00254 cost = static_cast<int>( 00255 (size_cost * cntxt_->Params()->SizeWgt()) + 00256 (bigram_cost * cntxt_->Params()->CharBigramWgt()) + 00257 (unigram_cost * cntxt_->Params()->WordUnigramWgt()) + 00258 (recognition_cost * cntxt_->Params()->RecoWgt())); 00259 00260 // insert into word alt list 00261 alt_list->Insert(ch_buff, cost, 00262 static_cast<void *>(srch_nodes[node_idx])); 00263 // Note that strict < is necessary because WordAltList::Sort() 00264 // uses it in a bubble sort to swap entries. 00265 if (best_cost < 0 || cost < best_cost) { 00266 best_presorted_node_idx_ = node_idx; 00267 best_cost = cost; 00268 } 00269 delete []ch_buff; 00270 } 00271 } 00272 00273 // sort the alternates based on cost 00274 alt_list->Sort(); 00275 return alt_list; 00276 } 00277 00278 // Returns the lattice column corresponding to the specified column index. 00279 SearchColumn *BeamSearch::Column(int col) const { 00280 if (col < 0 || col >= col_cnt_ || !col_) 00281 return NULL; 00282 return col_[col]; 00283 } 00284 00285 // Returns the best node in the last column of last performed search. 00286 SearchNode *BeamSearch::BestNode() const { 00287 if (col_cnt_ < 1 || !col_ || !col_[col_cnt_ - 1]) 00288 return NULL; 00289 00290 int node_cnt = col_[col_cnt_ - 1]->NodeCount(); 00291 SearchNode **srch_nodes = col_[col_cnt_ - 1]->Nodes(); 00292 if (node_cnt < 1 || !srch_nodes || !srch_nodes[0]) 00293 return NULL; 00294 return srch_nodes[0]; 00295 } 00296 00297 // Returns the string corresponding to the specified alt. 00298 char_32 *BeamSearch::Alt(int alt) const { 00299 // get the last column of the lattice 00300 if (col_cnt_ <= 0) 00301 return NULL; 00302 00303 SearchColumn *srch_col = col_[col_cnt_ - 1]; 00304 if (!srch_col) 00305 return NULL; 00306 00307 // point to the last node in the selected path 00308 if (alt >= srch_col->NodeCount() || srch_col->Nodes() == NULL) { 00309 return NULL; 00310 } 00311 00312 SearchNode *srch_node = srch_col->Nodes()[alt]; 00313 if (!srch_node) 00314 return NULL; 00315 00316 // get string 00317 char_32 *str32 = srch_node->PathString(); 00318 if (!str32) 00319 return NULL; 00320 00321 return str32; 00322 } 00323 00324 // Backtracks from the specified node index and returns the corresponding 00325 // character mapped segments and character count. Optional return 00326 // arguments are the char_32 result string and character bounding 00327 // boxes, if non-NULL values are passed in. 00328 CharSamp **BeamSearch::BackTrack(SearchObject *srch_obj, int node_index, 00329 int *char_cnt, char_32 **str32, 00330 Boxa **char_boxes) const { 00331 // get the last column of the lattice 00332 if (col_cnt_ <= 0) 00333 return NULL; 00334 SearchColumn *srch_col = col_[col_cnt_ - 1]; 00335 if (!srch_col) 00336 return NULL; 00337 00338 // point to the last node in the selected path 00339 if (node_index >= srch_col->NodeCount() || !srch_col->Nodes()) 00340 return NULL; 00341 00342 SearchNode *srch_node = srch_col->Nodes()[node_index]; 00343 if (!srch_node) 00344 return NULL; 00345 return BackTrack(srch_obj, srch_node, char_cnt, str32, char_boxes); 00346 } 00347 00348 // Backtracks from the specified node index and returns the corresponding 00349 // character mapped segments and character count. Optional return 00350 // arguments are the char_32 result string and character bounding 00351 // boxes, if non-NULL values are passed in. 00352 CharSamp **BeamSearch::BackTrack(SearchObject *srch_obj, SearchNode *srch_node, 00353 int *char_cnt, char_32 **str32, 00354 Boxa **char_boxes) const { 00355 if (!srch_node) 00356 return NULL; 00357 00358 if (str32) { 00359 if (*str32) 00360 delete [](*str32); // clear existing value 00361 *str32 = srch_node->PathString(); 00362 if (!*str32) 00363 return NULL; 00364 } 00365 00366 if (char_boxes && *char_boxes) { 00367 boxaDestroy(char_boxes); // clear existing value 00368 } 00369 00370 CharSamp **chars; 00371 chars = SplitByNode(srch_obj, srch_node, char_cnt, char_boxes); 00372 if (!chars && str32) 00373 delete []*str32; 00374 return chars; 00375 } 00376 00377 // Backtracks from the given lattice node and return the corresponding 00378 // char mapped segments and character count. The character bounding 00379 // boxes are optional return arguments, if non-NULL values are passed in. 00380 CharSamp **BeamSearch::SplitByNode(SearchObject *srch_obj, 00381 SearchNode *srch_node, 00382 int *char_cnt, 00383 Boxa **char_boxes) const { 00384 // Count the characters (could be less than the path length when in 00385 // phrase mode) 00386 *char_cnt = 0; 00387 SearchNode *node = srch_node; 00388 while (node) { 00389 node = node->ParentNode(); 00390 (*char_cnt)++; 00391 } 00392 00393 if (*char_cnt == 0) 00394 return NULL; 00395 00396 // Allocate box array 00397 if (char_boxes) { 00398 if (*char_boxes) 00399 boxaDestroy(char_boxes); // clear existing value 00400 *char_boxes = boxaCreate(*char_cnt); 00401 if (*char_boxes == NULL) 00402 return NULL; 00403 } 00404 00405 // Allocate memory for CharSamp array. 00406 CharSamp **chars = new CharSamp *[*char_cnt]; 00407 if (!chars) { 00408 if (char_boxes) 00409 boxaDestroy(char_boxes); 00410 return NULL; 00411 } 00412 00413 int ch_idx = *char_cnt - 1; 00414 int seg_pt_cnt = srch_obj->SegPtCnt(); 00415 bool success=true; 00416 while (srch_node && ch_idx >= 0) { 00417 // Parent node (could be null) 00418 SearchNode *parent_node = srch_node->ParentNode(); 00419 00420 // Get the seg pts corresponding to the search node 00421 int st_col = !parent_node ? 0 : parent_node->ColIdx() + 1; 00422 int st_seg_pt = st_col <= 0 ? -1 : st_col - 1; 00423 int end_col = srch_node->ColIdx(); 00424 int end_seg_pt = end_col >= seg_pt_cnt ? seg_pt_cnt : end_col; 00425 00426 // Get a char sample corresponding to the segmentation points 00427 CharSamp *samp = srch_obj->CharSample(st_seg_pt, end_seg_pt); 00428 if (!samp) { 00429 success = false; 00430 break; 00431 } 00432 samp->SetLabel(srch_node->NodeString()); 00433 chars[ch_idx] = samp; 00434 if (char_boxes) { 00435 // Create the corresponding character bounding box 00436 Box *char_box = boxCreate(samp->Left(), samp->Top(), 00437 samp->Width(), samp->Height()); 00438 if (!char_box) { 00439 success = false; 00440 break; 00441 } 00442 boxaAddBox(*char_boxes, char_box, L_INSERT); 00443 } 00444 srch_node = parent_node; 00445 ch_idx--; 00446 } 00447 if (!success) { 00448 delete []chars; 00449 if (char_boxes) 00450 boxaDestroy(char_boxes); 00451 return NULL; 00452 } 00453 00454 // Reverse the order of boxes. 00455 if (char_boxes) { 00456 int char_boxa_size = boxaGetCount(*char_boxes); 00457 int limit = char_boxa_size / 2; 00458 for (int i = 0; i < limit; ++i) { 00459 int box1_idx = i; 00460 int box2_idx = char_boxa_size - 1 - i; 00461 Box *box1 = boxaGetBox(*char_boxes, box1_idx, L_CLONE); 00462 Box *box2 = boxaGetBox(*char_boxes, box2_idx, L_CLONE); 00463 boxaReplaceBox(*char_boxes, box2_idx, box1); 00464 boxaReplaceBox(*char_boxes, box1_idx, box2); 00465 } 00466 } 00467 return chars; 00468 } 00469 00470 // Returns the size cost of a string for a lattice path that 00471 // ends at the specified lattice node. 00472 int BeamSearch::SizeCost(SearchObject *srch_obj, SearchNode *node, 00473 char_32 **str32) const { 00474 CharSamp **chars = NULL; 00475 int char_cnt = 0; 00476 if (!node) 00477 return 0; 00478 // Backtrack to get string and character segmentation 00479 chars = BackTrack(srch_obj, node, &char_cnt, str32, NULL); 00480 if (!chars) 00481 return WORST_COST; 00482 int size_cost = (cntxt_->SizeModel() == NULL) ? 0 : 00483 cntxt_->SizeModel()->Cost(chars, char_cnt); 00484 delete []chars; 00485 return size_cost; 00486 } 00487 } // namespace tesesract