Tesseract  3.02
tesseract-ocr/cube/beam_search.cpp
Go to the documentation of this file.
00001 /**********************************************************************
00002  * File:        beam_search.cpp
00003  * Description: Class to implement Beam Word Search Algorithm
00004  * Author:    Ahmad Abdulkader
00005  * Created:   2007
00006  *
00007  * (C) Copyright 2008, Google Inc.
00008  ** Licensed under the Apache License, Version 2.0 (the "License");
00009  ** you may not use this file except in compliance with the License.
00010  ** You may obtain a copy of the License at
00011  ** http://www.apache.org/licenses/LICENSE-2.0
00012  ** Unless required by applicable law or agreed to in writing, software
00013  ** distributed under the License is distributed on an "AS IS" BASIS,
00014  ** WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
00015  ** See the License for the specific language governing permissions and
00016  ** limitations under the License.
00017  *
00018  **********************************************************************/
00019 
00020 #include <algorithm>
00021 
00022 #include "beam_search.h"
00023 #include "tesseractclass.h"
00024 
00025 namespace tesseract {
00026 
00027 BeamSearch::BeamSearch(CubeRecoContext *cntxt, bool word_mode) {
00028   cntxt_ = cntxt;
00029   seg_pt_cnt_ = 0;
00030   col_cnt_ = 1;
00031   col_ = NULL;
00032   word_mode_ = word_mode;
00033 }
00034 
00035 // Cleanup the lattice corresponding to the last search
00036 void BeamSearch::Cleanup() {
00037   if (col_ != NULL) {
00038     for (int col = 0; col < col_cnt_; col++) {
00039       if (col_[col])
00040         delete col_[col];
00041     }
00042     delete []col_;
00043   }
00044   col_ = NULL;
00045 }
00046 
00047 BeamSearch::~BeamSearch() {
00048   Cleanup();
00049 }
00050 
00051 // Creates a set of children nodes emerging from a parent node based on
00052 // the character alternate list and the language model.
00053 void BeamSearch::CreateChildren(SearchColumn *out_col, LangModel *lang_mod,
00054                                 SearchNode *parent_node,
00055                                 LangModEdge *lm_parent_edge,
00056                                 CharAltList *char_alt_list, int extra_cost) {
00057   // get all the edges from this parent
00058   int edge_cnt;
00059   LangModEdge **lm_edges = lang_mod->GetEdges(char_alt_list,
00060                                               lm_parent_edge, &edge_cnt);
00061   if (lm_edges) {
00062     // add them to the ending column with the appropriate parent
00063     for (int edge = 0; edge < edge_cnt; edge++) {
00064       // add a node to the column if the current column is not the
00065       // last one, or if the lang model edge indicates it is valid EOW
00066       if (!cntxt_->NoisyInput() && out_col->ColIdx() >= seg_pt_cnt_ &&
00067           !lm_edges[edge]->IsEOW()) {
00068         // free edge since no object is going to own it
00069         delete lm_edges[edge];
00070         continue;
00071       }
00072 
00073       // compute the recognition cost of this node
00074       int recognition_cost =  MIN_PROB_COST;
00075       if (char_alt_list && char_alt_list->AltCount() > 0) {
00076         recognition_cost = MAX(0, char_alt_list->ClassCost(
00077             lm_edges[edge]->ClassID()));
00078         // Add the no space cost. This should zero in word mode
00079         recognition_cost += extra_cost;
00080       }
00081 
00082       // Note that the edge will be freed inside the column if
00083       // AddNode is called
00084       if (recognition_cost >= 0) {
00085         out_col->AddNode(lm_edges[edge], recognition_cost, parent_node,
00086                          cntxt_);
00087       } else {
00088         delete lm_edges[edge];
00089       }
00090     }  // edge
00091     // free edge array
00092     delete []lm_edges;
00093   }  // lm_edges
00094 }
00095 
00096 // Performs a beam seach in the specified search using the specified
00097 // language model; returns an alternate list of possible words as a result.
00098 WordAltList * BeamSearch::Search(SearchObject *srch_obj, LangModel *lang_mod) {
00099   // verifications
00100   if (!lang_mod)
00101     lang_mod = cntxt_->LangMod();
00102   if (!lang_mod) {
00103     fprintf(stderr, "Cube ERROR (BeamSearch::Search): could not construct "
00104             "LangModel\n");
00105     return NULL;
00106   }
00107 
00108   // free existing state
00109   Cleanup();
00110 
00111   // get seg pt count
00112   seg_pt_cnt_ = srch_obj->SegPtCnt();
00113   if (seg_pt_cnt_ < 0) {
00114     return NULL;
00115   }
00116   col_cnt_ = seg_pt_cnt_ + 1;
00117 
00118   // disregard suspicious cases
00119   if (seg_pt_cnt_ > 128) {
00120     fprintf(stderr, "Cube ERROR (BeamSearch::Search): segment point count is "
00121             "suspiciously high; bailing out\n");
00122     return NULL;
00123   }
00124 
00125   // alloc memory for columns
00126   col_ = new SearchColumn *[col_cnt_];
00127   if (!col_) {
00128     fprintf(stderr, "Cube ERROR (BeamSearch::Search): could not construct "
00129             "SearchColumn array\n");
00130     return NULL;
00131   }
00132   memset(col_, 0, col_cnt_ * sizeof(*col_));
00133 
00134   // for all possible segments
00135   for (int end_seg = 1; end_seg <= (seg_pt_cnt_ + 1); end_seg++) {
00136     // create a search column
00137     col_[end_seg - 1] = new SearchColumn(end_seg - 1,
00138                                          cntxt_->Params()->BeamWidth());
00139     if (!col_[end_seg - 1]) {
00140       fprintf(stderr, "Cube ERROR (BeamSearch::Search): could not construct "
00141               "SearchColumn for column %d\n", end_seg - 1);
00142       return NULL;
00143     }
00144 
00145     // for all possible start segments
00146     int init_seg = MAX(0, end_seg - cntxt_->Params()->MaxSegPerChar());
00147     for (int strt_seg = init_seg; strt_seg < end_seg; strt_seg++) {
00148       int parent_nodes_cnt;
00149       SearchNode **parent_nodes;
00150 
00151       // for the root segment, we do not have a parent
00152       if (strt_seg == 0) {
00153         parent_nodes_cnt = 1;
00154         parent_nodes = NULL;
00155       } else {
00156         // for all the existing nodes in the starting column
00157         parent_nodes_cnt = col_[strt_seg - 1]->NodeCount();
00158         parent_nodes = col_[strt_seg - 1]->Nodes();
00159       }
00160 
00161       // run the shape recognizer
00162       CharAltList *char_alt_list = srch_obj->RecognizeSegment(strt_seg - 1,
00163                                                               end_seg - 1);
00164       // for all the possible parents
00165       for (int parent_idx = 0; parent_idx < parent_nodes_cnt; parent_idx++) {
00166         // point to the parent node
00167         SearchNode *parent_node = !parent_nodes ? NULL
00168             : parent_nodes[parent_idx];
00169         LangModEdge *lm_parent_edge = !parent_node ? lang_mod->Root()
00170             : parent_node->LangModelEdge();
00171 
00172         // compute the cost of not having spaces within the segment range
00173         int contig_cost = srch_obj->NoSpaceCost(strt_seg - 1, end_seg - 1);
00174 
00175         // In phrase mode, compute the cost of not having a space before
00176         // this character
00177         int no_space_cost = 0;
00178         if (!word_mode_ && strt_seg > 0) {
00179           no_space_cost = srch_obj->NoSpaceCost(strt_seg - 1);
00180         }
00181 
00182         // if the no space cost is low enough
00183         if ((contig_cost + no_space_cost) < MIN_PROB_COST) {
00184           // Add the children nodes
00185           CreateChildren(col_[end_seg - 1], lang_mod, parent_node,
00186                          lm_parent_edge, char_alt_list,
00187                          contig_cost + no_space_cost);
00188         }
00189 
00190         // In phrase mode and if not starting at the root
00191         if (!word_mode_ && strt_seg > 0) {  // parent_node must be non-NULL
00192           // consider starting a new word for nodes that are valid EOW
00193           if (parent_node->LangModelEdge()->IsEOW()) {
00194             // get the space cost
00195             int space_cost = srch_obj->SpaceCost(strt_seg - 1);
00196             // if the space cost is low enough
00197             if ((contig_cost + space_cost) < MIN_PROB_COST) {
00198               // Restart the language model and add nodes as children to the
00199               // space node.
00200               CreateChildren(col_[end_seg - 1], lang_mod, parent_node, NULL,
00201                              char_alt_list, contig_cost + space_cost);
00202             }
00203           }
00204         }
00205       }  // parent
00206     }  // strt_seg
00207 
00208     // prune the column nodes
00209     col_[end_seg - 1]->Prune();
00210 
00211     // Free the column hash table. No longer needed
00212     col_[end_seg - 1]->FreeHashTable();
00213   }  // end_seg
00214 
00215   WordAltList *alt_list = CreateWordAltList(srch_obj);
00216   return alt_list;
00217 }
00218 
00219 // Creates a Word alternate list from the results in the lattice.
00220 WordAltList *BeamSearch::CreateWordAltList(SearchObject *srch_obj) {
00221   // create an alternate list of all the nodes in the last column
00222   int node_cnt = col_[col_cnt_ - 1]->NodeCount();
00223   SearchNode **srch_nodes = col_[col_cnt_ - 1]->Nodes();
00224   CharBigrams *bigrams = cntxt_->Bigrams();
00225   WordUnigrams *word_unigrams = cntxt_->WordUnigramsObj();
00226 
00227   // Save the index of the best-cost node before the alt list is
00228   // sorted, so that we can retrieve it from the node list when backtracking.
00229   best_presorted_node_idx_ = 0;
00230   int best_cost = -1;
00231 
00232   if (node_cnt <= 0)
00233     return NULL;
00234 
00235   // start creating the word alternate list
00236   WordAltList *alt_list = new WordAltList(node_cnt + 1);
00237   for (int node_idx = 0; node_idx < node_cnt; node_idx++) {
00238     // recognition cost
00239     int recognition_cost = srch_nodes[node_idx]->BestCost();
00240     // compute the size cost of the alternate
00241     char_32 *ch_buff = NULL;
00242     int size_cost = SizeCost(srch_obj, srch_nodes[node_idx], &ch_buff);
00243     // accumulate other costs
00244     if (ch_buff) {
00245       int cost = 0;
00246       // char bigram cost
00247       int bigram_cost = !bigrams ? 0 :
00248           bigrams->Cost(ch_buff, cntxt_->CharacterSet());
00249       // word unigram cost
00250       int unigram_cost = !word_unigrams ? 0 :
00251           word_unigrams->Cost(ch_buff, cntxt_->LangMod(),
00252                               cntxt_->CharacterSet());
00253       // overall cost
00254       cost = static_cast<int>(
00255           (size_cost * cntxt_->Params()->SizeWgt()) +
00256           (bigram_cost * cntxt_->Params()->CharBigramWgt()) +
00257           (unigram_cost * cntxt_->Params()->WordUnigramWgt()) +
00258           (recognition_cost * cntxt_->Params()->RecoWgt()));
00259 
00260       // insert into word alt list
00261       alt_list->Insert(ch_buff, cost,
00262                        static_cast<void *>(srch_nodes[node_idx]));
00263       // Note that strict < is necessary because WordAltList::Sort()
00264       // uses it in a bubble sort to swap entries.
00265       if (best_cost < 0 || cost < best_cost) {
00266         best_presorted_node_idx_ = node_idx;
00267         best_cost = cost;
00268       }
00269       delete []ch_buff;
00270     }
00271   }
00272 
00273   // sort the alternates based on cost
00274   alt_list->Sort();
00275   return alt_list;
00276 }
00277 
00278 // Returns the lattice column corresponding to the specified column index.
00279 SearchColumn *BeamSearch::Column(int col) const {
00280   if (col < 0 || col >= col_cnt_ || !col_)
00281     return NULL;
00282   return col_[col];
00283 }
00284 
00285 // Returns the best node in the last column of last performed search.
00286 SearchNode *BeamSearch::BestNode() const {
00287   if (col_cnt_ < 1 || !col_ || !col_[col_cnt_ - 1])
00288     return NULL;
00289 
00290   int node_cnt = col_[col_cnt_ - 1]->NodeCount();
00291   SearchNode **srch_nodes = col_[col_cnt_ - 1]->Nodes();
00292   if (node_cnt < 1 || !srch_nodes || !srch_nodes[0])
00293     return NULL;
00294   return srch_nodes[0];
00295 }
00296 
00297 // Returns the string corresponding to the specified alt.
00298 char_32 *BeamSearch::Alt(int alt) const {
00299   // get the last column of the lattice
00300   if (col_cnt_ <= 0)
00301     return NULL;
00302 
00303   SearchColumn *srch_col = col_[col_cnt_ - 1];
00304   if (!srch_col)
00305     return NULL;
00306 
00307   // point to the last node in the selected path
00308   if (alt >= srch_col->NodeCount() || srch_col->Nodes() == NULL) {
00309     return NULL;
00310   }
00311 
00312   SearchNode *srch_node = srch_col->Nodes()[alt];
00313   if (!srch_node)
00314     return  NULL;
00315 
00316   // get string
00317   char_32 *str32 = srch_node->PathString();
00318   if (!str32)
00319     return NULL;
00320 
00321   return str32;
00322 }
00323 
00324 // Backtracks from the specified node index and returns the corresponding
00325 // character mapped segments and character count. Optional return
00326 // arguments are the char_32 result string and character bounding
00327 // boxes, if non-NULL values are passed in.
00328 CharSamp **BeamSearch::BackTrack(SearchObject *srch_obj, int node_index,
00329                                  int *char_cnt, char_32 **str32,
00330                                  Boxa **char_boxes) const {
00331   // get the last column of the lattice
00332   if (col_cnt_ <= 0)
00333     return NULL;
00334   SearchColumn *srch_col = col_[col_cnt_ - 1];
00335   if (!srch_col)
00336     return NULL;
00337 
00338   // point to the last node in the selected path
00339   if (node_index >= srch_col->NodeCount() || !srch_col->Nodes())
00340     return NULL;
00341 
00342   SearchNode *srch_node = srch_col->Nodes()[node_index];
00343   if (!srch_node)
00344     return NULL;
00345   return BackTrack(srch_obj, srch_node, char_cnt, str32, char_boxes);
00346 }
00347 
00348 // Backtracks from the specified node index and returns the corresponding
00349 // character mapped segments and character count. Optional return
00350 // arguments are the char_32 result string and character bounding
00351 // boxes, if non-NULL values are passed in.
00352 CharSamp **BeamSearch::BackTrack(SearchObject *srch_obj, SearchNode *srch_node,
00353                                  int *char_cnt, char_32 **str32,
00354                                  Boxa **char_boxes) const {
00355   if (!srch_node)
00356     return NULL;
00357 
00358   if (str32) {
00359     if (*str32)
00360       delete [](*str32);  // clear existing value
00361     *str32 = srch_node->PathString();
00362     if (!*str32)
00363       return NULL;
00364   }
00365 
00366   if (char_boxes && *char_boxes) {
00367     boxaDestroy(char_boxes);  // clear existing value
00368   }
00369 
00370   CharSamp **chars;
00371   chars = SplitByNode(srch_obj, srch_node, char_cnt, char_boxes);
00372   if (!chars && str32)
00373     delete []*str32;
00374   return chars;
00375 }
00376 
00377 // Backtracks from the given lattice node and return the corresponding
00378 // char mapped segments and character count. The character bounding
00379 // boxes are optional return arguments, if non-NULL values are passed in.
00380 CharSamp **BeamSearch::SplitByNode(SearchObject *srch_obj,
00381                                    SearchNode *srch_node,
00382                                    int *char_cnt,
00383                                    Boxa **char_boxes) const {
00384   // Count the characters (could be less than the path length when in
00385   // phrase mode)
00386   *char_cnt = 0;
00387   SearchNode *node = srch_node;
00388   while (node) {
00389     node = node->ParentNode();
00390     (*char_cnt)++;
00391   }
00392 
00393   if (*char_cnt == 0)
00394     return NULL;
00395 
00396   // Allocate box array
00397   if (char_boxes) {
00398     if (*char_boxes)
00399       boxaDestroy(char_boxes);  // clear existing value
00400     *char_boxes = boxaCreate(*char_cnt);
00401     if (*char_boxes == NULL)
00402       return NULL;
00403   }
00404 
00405   // Allocate memory for CharSamp array.
00406   CharSamp **chars = new CharSamp *[*char_cnt];
00407   if (!chars) {
00408     if (char_boxes)
00409       boxaDestroy(char_boxes);
00410     return NULL;
00411   }
00412 
00413   int ch_idx = *char_cnt - 1;
00414   int seg_pt_cnt = srch_obj->SegPtCnt();
00415   bool success=true;
00416   while (srch_node && ch_idx >= 0) {
00417     // Parent node (could be null)
00418     SearchNode *parent_node = srch_node->ParentNode();
00419 
00420     // Get the seg pts corresponding to the search node
00421     int st_col = !parent_node ? 0 : parent_node->ColIdx() + 1;
00422     int st_seg_pt = st_col <= 0 ? -1 : st_col - 1;
00423     int end_col = srch_node->ColIdx();
00424     int end_seg_pt = end_col >= seg_pt_cnt ? seg_pt_cnt : end_col;
00425 
00426     // Get a char sample corresponding to the segmentation points
00427     CharSamp *samp = srch_obj->CharSample(st_seg_pt, end_seg_pt);
00428     if (!samp) {
00429       success = false;
00430       break;
00431     }
00432     samp->SetLabel(srch_node->NodeString());
00433     chars[ch_idx] = samp;
00434     if (char_boxes) {
00435       // Create the corresponding character bounding box
00436       Box *char_box = boxCreate(samp->Left(), samp->Top(),
00437                                 samp->Width(), samp->Height());
00438       if (!char_box) {
00439         success = false;
00440         break;
00441       }
00442       boxaAddBox(*char_boxes, char_box, L_INSERT);
00443     }
00444     srch_node = parent_node;
00445     ch_idx--;
00446   }
00447   if (!success) {
00448     delete []chars;
00449     if (char_boxes)
00450       boxaDestroy(char_boxes);
00451     return NULL;
00452   }
00453 
00454   // Reverse the order of boxes.
00455   if (char_boxes) {
00456     int char_boxa_size = boxaGetCount(*char_boxes);
00457     int limit = char_boxa_size / 2;
00458     for (int i = 0; i < limit; ++i) {
00459       int box1_idx = i;
00460       int box2_idx = char_boxa_size - 1 - i;
00461       Box *box1 = boxaGetBox(*char_boxes, box1_idx, L_CLONE);
00462       Box *box2 = boxaGetBox(*char_boxes, box2_idx, L_CLONE);
00463       boxaReplaceBox(*char_boxes, box2_idx, box1);
00464       boxaReplaceBox(*char_boxes, box1_idx, box2);
00465     }
00466   }
00467   return chars;
00468 }
00469 
00470 // Returns the size cost of a string for a lattice path that
00471 // ends at the specified lattice node.
00472 int BeamSearch::SizeCost(SearchObject *srch_obj, SearchNode *node,
00473                          char_32 **str32) const {
00474   CharSamp **chars = NULL;
00475   int char_cnt = 0;
00476   if (!node)
00477     return 0;
00478   // Backtrack to get string and character segmentation
00479   chars = BackTrack(srch_obj, node, &char_cnt, str32, NULL);
00480   if (!chars)
00481     return WORST_COST;
00482   int size_cost = (cntxt_->SizeModel() == NULL) ? 0 :
00483       cntxt_->SizeModel()->Cost(chars, char_cnt);
00484   delete []chars;
00485   return size_cost;
00486 }
}  // namespace tesseract