Tesseract  3.02
tesseract-ocr/classify/mastertrainer.h
Go to the documentation of this file.
00001 // Copyright 2010 Google Inc. All Rights Reserved.
00002 // Author: rays@google.com (Ray Smith)
00004 // File:        mastertrainer.h
00005 // Description: Trainer to build the MasterClassifier.
00006 // Author:      Ray Smith
00007 // Created:     Wed Nov 03 18:07:01 PDT 2010
00008 //
00009 // (C) Copyright 2010, Google Inc.
00010 // Licensed under the Apache License, Version 2.0 (the "License");
00011 // you may not use this file except in compliance with the License.
00012 // You may obtain a copy of the License at
00013 // http://www.apache.org/licenses/LICENSE-2.0
00014 // Unless required by applicable law or agreed to in writing, software
00015 // distributed under the License is distributed on an "AS IS" BASIS,
00016 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
00017 // See the License for the specific language governing permissions and
00018 // limitations under the License.
00019 //
00021 
00022 #ifndef TESSERACT_TRAINING_MASTERTRAINER_H__
00023 #define TESSERACT_TRAINING_MASTERTRAINER_H__
00024 
00028 #include "classify.h"
00029 #include "cluster.h"
00030 #include "intfx.h"
00031 #include "elst.h"
00032 #include "featdefs.h"
00033 #include "fontinfo.h"
00034 #include "indexmapbidi.h"
00035 #include "intfeaturespace.h"
00036 #include "intfeaturemap.h"
00037 #include "intmatcher.h"
00038 #include "params.h"
00039 #include "shapetable.h"
00040 #include "trainingsample.h"
00041 #include "trainingsampleset.h"
00042 #include "unicharset.h"
00043 
00044 namespace tesseract {
00045 
00046 class ShapeClassifier;
00047 
// Plain record used by shape clustering: the indices of two shapes together
// with the distance measured between them.
struct ShapeDist {
  // Zero-initialises both shape indices and the distance.
  ShapeDist() {
    shape1 = 0;
    shape2 = 0;
    distance = 0.0f;
  }
  // Records the pair (s1, s2) at distance dist.
  ShapeDist(int s1, int s2, float dist) {
    shape1 = s1;
    shape2 = s2;
    distance = dist;
  }

  // Ordering by distance alone, so that sorting a collection of ShapeDist
  // puts the closest pair first (ascending distance).
  bool operator<(const ShapeDist& other) const {
    return other.distance > distance;
  }

  int shape1;       // Index of the first shape.
  int shape2;       // Index of the second shape.
  float distance;   // Distance between shape1 and shape2.
};
00063 
00064 // Class to encapsulate training processes that use the TrainingSampleSet.
00065 // Initially supports shape clustering and mftrainining.
00066 // Other important features of the MasterTrainer are conditioning the data
00067 // by outlier elimination, replication with perturbation, and serialization.
00068 class MasterTrainer {
00069  public:
00070   MasterTrainer(NormalizationMode norm_mode, bool shape_analysis,
00071                 bool replicate_samples, int debug_level);
00072   ~MasterTrainer();
00073 
00074   // Writes to the given file. Returns false in case of error.
00075   bool Serialize(FILE* fp) const;
00076   // Reads from the given file. Returns false in case of error.
00077   // If swap is true, assumes a big/little-endian swap is needed.
00078   bool DeSerialize(bool swap, FILE* fp);
00079 
00080   // Loads an initial unicharset, or sets one up if the file cannot be read.
00081   void LoadUnicharset(const char* filename);
00082 
00083   // Sets the feature space definition.
00084   void SetFeatureSpace(const IntFeatureSpace& fs) {
00085     feature_space_ = fs;
00086     feature_map_.Init(fs);
00087   }
00088 
00089   // Reads the samples and their features from the given file,
00090   // adding them to the trainer with the font_id from the content of the file.
00091   // If verification, then these are verification samples, not training.
00092   void ReadTrainingSamples(FILE  *fp,
00093                            const FEATURE_DEFS_STRUCT& feature_defs,
00094                            bool verification);
00095 
00096   // Adds the given single sample to the trainer, setting the classid
00097   // appropriately from the given unichar_str.
00098   void AddSample(bool verification, const char* unichar_str,
00099                  TrainingSample* sample);
00100 
00101   // Loads all pages from the given tif filename and append to page_images_.
00102   // Must be called after ReadTrainingSamples, as the current number of images
00103   // is used as an offset for page numbers in the samples.
00104   void LoadPageImages(const char* filename);
00105 
00106   // Cleans up the samples after initial load from the tr files, and prior to
00107   // saving the MasterTrainer:
00108   // Remaps fragmented chars if running shape anaylsis.
00109   // Sets up the samples appropriately for class/fontwise access.
00110   // Deletes outlier samples.
00111   void PostLoadCleanup();
00112 
00113   // Gets the samples ready for training. Use after both
00114   // ReadTrainingSamples+PostLoadCleanup or DeSerialize.
00115   // Re-indexes the features and computes canonical and cloud features.
00116   void PreTrainingSetup();
00117 
00118   // Sets up the master_shapes_ table, which tells which fonts should stay
00119   // together until they get to a leaf node classifier.
00120   void SetupMasterShapes();
00121 
00122   // Adds the junk_samples_ to the main samples_ set. Junk samples are initially
00123   // fragments and n-grams (all incorrectly segmented characters).
00124   // Various training functions may result in incorrectly segmented characters
00125   // being added to the unicharset of the main samples, perhaps because they
00126   // form a "radical" decomposition of some (Indic) grapheme, or because they
00127   // just look the same as a real character (like rn/m)
00128   // This function moves all the junk samples, to the main samples_ set, but
00129   // desirable junk, being any sample for which the unichar already exists in
00130   // the samples_ unicharset gets the unichar-ids re-indexed to match, but
00131   // anything else gets re-marked as unichar_id 0 (space character) to identify
00132   // it as junk to the error counter.
00133   void IncludeJunk();
00134 
00135   // Replicates the samples and perturbs them if the enable_replication_ flag
00136   // is set. MUST be used after the last call to OrganizeByFontAndClass on
00137   // the training samples, ie after IncludeJunk if it is going to be used, as
00138   // OrganizeByFontAndClass will eat the replicated samples into the regular
00139   // samples.
00140   void ReplicateAndRandomizeSamplesIfRequired();
00141 
00142   // Loads the basic font properties file into fontinfo_table_.
00143   // Returns false on failure.
00144   bool LoadFontInfo(const char* filename);
00145 
00146   // Loads the xheight font properties file into xheights_.
00147   // Returns false on failure.
00148   bool LoadXHeights(const char* filename);
00149 
00150   // Reads spacing stats from filename and adds them to fontinfo_table.
00151   // Returns false on failure.
00152   bool AddSpacingInfo(const char *filename);
00153 
00154   // Returns the font id corresponding to the given font name.
00155   // Returns -1 if the font cannot be found.
00156   int GetFontInfoId(const char* font_name);
00157   // Returns the font_id of the closest matching font name to the given
00158   // filename. It is assumed that a substring of the filename will match
00159   // one of the fonts. If more than one is matched, the longest is returned.
00160   int GetBestMatchingFontInfoId(const char* filename);
00161 
00162   // Sets up a flat shapetable with one shape per class/font combination.
00163   void SetupFlatShapeTable(ShapeTable* shape_table);
00164 
00165   // Sets up a Clusterer for mftraining on a single shape_id.
00166   // Call FreeClusterer on the return value after use.
00167   CLUSTERER* SetupForClustering(const ShapeTable& shape_table,
00168                                 const FEATURE_DEFS_STRUCT& feature_defs,
00169                                 int shape_id, int* num_samples);
00170 
00171   // Writes the given float_classes (produced by SetupForFloat2Int) as inttemp
00172   // to the given inttemp_file, and the corresponding pffmtable.
00173   // The unicharset is the original encoding of graphemes, and shape_set should
00174   // match the size of the shape_table, and may possibly be totally fake.
00175   void WriteInttempAndPFFMTable(const UNICHARSET& unicharset,
00176                                 const UNICHARSET& shape_set,
00177                                 const ShapeTable& shape_table,
00178                                 CLASS_STRUCT* float_classes,
00179                                 const char* inttemp_file,
00180                                 const char* pffmtable_file);
00181 
00182   const UNICHARSET& unicharset() const {
00183     return samples_.unicharset();
00184   }
00185   TrainingSampleSet* GetSamples() {
00186     return &samples_;
00187   }
00188   const ShapeTable& master_shapes() const {
00189     return master_shapes_;
00190   }
00191 
00192   // Generates debug output relating to the canonical distance between the
00193   // two given UTF8 grapheme strings.
00194   void DebugCanonical(const char* unichar_str1, const char* unichar_str2);
00195   #ifndef GRAPHICS_DISABLED
00196   // Debugging for cloud/canonical features.
00197   // Displays a Features window containing:
00198   // If unichar_str2 is in the unicharset, and canonical_font is non-negative,
00199   // displays the canonical features of the char/font combination in red.
00200   // If unichar_str1 is in the unicharset, and cloud_font is non-negative,
00201   // displays the cloud feature of the char/font combination in green.
00202   // The canonical features are drawn first to show which ones have no
00203   // matches in the cloud features.
00204   // Until the features window is destroyed, each click in the features window
00205   // will display the samples that have that feature in a separate window.
00206   void DisplaySamples(const char* unichar_str1, int cloud_font,
00207                       const char* unichar_str2, int canonical_font);
00208   #endif  // GRAPHICS_DISABLED
00209 
00210   // Tests the given test_classifier on the internal samples.
00211   // See TestClassifier for details.
00212   void TestClassifierOnSamples(int report_level,
00213                                bool replicate_samples,
00214                                ShapeClassifier* test_classifier,
00215                                STRING* report_string);
00216   // Tests the given test_classifier on the given samples
00217   // report_levels:
00218   // 0 = no output.
00219   // 1 = bottom-line error rate.
00220   // 2 = bottom-line error rate + time.
00221   // 3 = font-level error rate + time.
00222   // 4 = list of all errors + short classifier debug output on 16 errors.
00223   // 5 = list of all errors + short classifier debug output on 25 errors.
00224   // If replicate_samples is true, then the test is run on an extended test
00225   // sample including replicated and systematically perturbed samples.
00226   // If report_string is non-NULL, a summary of the results for each font
00227   // is appended to the report_string.
00228   double TestClassifier(int report_level,
00229                         bool replicate_samples,
00230                         TrainingSampleSet* samples,
00231                         ShapeClassifier* test_classifier,
00232                         STRING* report_string);
00233 
00234   // Returns the average (in some sense) distance between the two given
00235   // shapes, which may contain multiple fonts and/or unichars.
00236   // This function is public to facilitate testing.
00237   float ShapeDistance(const ShapeTable& shapes, int s1, int s2);
00238 
00239  private:
00240   // Replaces samples that are always fragmented with the corresponding
00241   // fragment samples.
00242   void ReplaceFragmentedSamples();
00243 
00244   // Runs a hierarchical agglomerative clustering to merge shapes in the given
00245   // shape_table, while satisfying the given constraints:
00246   // * End with at least min_shapes left in shape_table,
00247   // * No shape shall have more than max_shape_unichars in it,
00248   // * Don't merge shapes where the distance between them exceeds max_dist.
00249   void ClusterShapes(int min_shapes, int max_shape_unichars,
00250                      float max_dist, ShapeTable* shape_table);
00251 
00252  private:
00253   NormalizationMode norm_mode_;
00254   // Character set we are training for.
00255   UNICHARSET unicharset_;
00256   // Original feature space. Subspace mapping is contained in feature_map_.
00257   IntFeatureSpace feature_space_;
00258   TrainingSampleSet samples_;
00259   TrainingSampleSet junk_samples_;
00260   TrainingSampleSet verify_samples_;
00261   // Master shape table defines what fonts stay together until the leaves.
00262   ShapeTable master_shapes_;
00263   // Flat shape table has each unichar/font id pair in a separate shape.
00264   ShapeTable flat_shapes_;
00265   // Font metrics gathered from multiple files.
00266   UnicityTable<FontInfo> fontinfo_table_;
00267   // Array of xheights indexed by font ids in fontinfo_table_;
00268   GenericVector<int> xheights_;
00269 
00270   // Non-serialized data initialized by other means or used temporarily
00271   // during loading of training samples.
00272   // Number of different class labels in unicharset_.
00273   int charsetsize_;
00274   // Flag to indicate that we are running shape analysis and need fragments
00275   // fixing.
00276   bool enable_shape_anaylsis_;
00277   // Flag to indicate that sample replication is required.
00278   bool enable_replication_;
00279   // Flag to indicate that junk should be included in samples_.
00280   bool include_junk_;
00281   // Array of classids of fragments that replace the correctly segmented chars.
00282   int* fragments_;
00283   // Classid of previous correctly segmented sample that was added.
00284   int prev_unichar_id_;
00285   // Debug output control.
00286   int debug_level_;
00287   // Feature map used to construct reduced feature spaces for compact
00288   // classifiers.
00289   IntFeatureMap feature_map_;
00290   // Vector of Pix pointers used for classifiers that need the image.
00291   // Indexed by page_num_ in the samples.
00292   // These images are owned by the trainer and need to be pixDestroyed.
00293   GenericVector<Pix*> page_images_;
00294 };
00295 
00296 }  // namespace tesseract.
00297 
00298 #endif