Tesseract  3.02
tesseract-ocr/classify/sampleiterator.cpp
Go to the documentation of this file.
00001 // Copyright 2011 Google Inc. All Rights Reserved.
00002 // Author: rays@google.com (Ray Smith)
00003 //
00004 // Licensed under the Apache License, Version 2.0 (the "License");
00005 // you may not use this file except in compliance with the License.
00006 // You may obtain a copy of the License at
00007 // http://www.apache.org/licenses/LICENSE-2.0
00008 // Unless required by applicable law or agreed to in writing, software
00009 // distributed under the License is distributed on an "AS IS" BASIS,
00010 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
00011 // See the License for the specific language governing permissions and
00012 // limitations under the License.
00013 //
00015 
00016 #include "sampleiterator.h"
00017 
00018 #include "indexmapbidi.h"
00019 #include "shapetable.h"
00020 #include "trainingsample.h"
00021 #include "trainingsampleset.h"
00022 
00023 namespace tesseract {
00024 
00025 // ================== SampleIterator Implementation =================
00026 
00027 SampleIterator::SampleIterator()
00028   : charset_map_(NULL),
00029     shape_table_(NULL),
00030     sample_set_(NULL),
00031     randomize_(false),
00032     owned_shape_table_(NULL) {
00033   num_shapes_ = 0;
00034   Begin();
00035 }
00036 
00037 SampleIterator::~SampleIterator() {
00038   Clear();
00039 }
00040 
00041 void SampleIterator::Clear() {
00042   delete owned_shape_table_;
00043   owned_shape_table_ = NULL;
00044 }
00045 
00046 // See class comment for arguments.
00047 void SampleIterator::Init(const IndexMapBiDi* charset_map,
00048                           const ShapeTable* shape_table,
00049                           bool randomize,
00050                           TrainingSampleSet* sample_set) {
00051   Clear();
00052   charset_map_ = charset_map;
00053   shape_table_ = shape_table;
00054   sample_set_ = sample_set;
00055   randomize_ = randomize;
00056   if (shape_table_ == NULL && charset_map_ != NULL) {
00057     // The caller wishes to iterate by class. The easiest way to do this
00058     // is to create a dummy shape_table_ that we will own.
00059     int num_fonts = sample_set_->NumFonts();
00060     owned_shape_table_ = new ShapeTable(sample_set_->unicharset());
00061     int charsetsize = sample_set_->unicharset().size();
00062     for (int c = 0; c < charsetsize; ++c) {
00063       // We always add a shape for each character to keep the index in sync
00064       // with the unichar_id.
00065       int shape_id = owned_shape_table_->AddShape(c, 0);
00066       for (int f = 1; f < num_fonts; ++f) {
00067         if (sample_set_->NumClassSamples(f, c, true) > 0) {
00068           owned_shape_table_->AddToShape(shape_id, c, f);
00069         }
00070       }
00071     }
00072     shape_table_ = owned_shape_table_;
00073   }
00074   if (shape_table_ != NULL) {
00075     num_shapes_ = shape_table_->NumShapes();
00076   } else {
00077     num_shapes_ = randomize ? sample_set_->num_samples()
00078                             : sample_set_->num_raw_samples();
00079   }
00080   Begin();
00081 }
00082 
00083 // Iterator functions designed for use with a simple for loop:
00084 // for (it.Begin(); !it.AtEnd(); it.Next()) {
00085 //   const TrainingSample& sample = it.GetSample();
00086 // }
00087 void SampleIterator::Begin() {
00088   shape_index_ = -1;
00089   shape_char_index_ = 0;
00090   num_shape_chars_ = 0;
00091   shape_font_index_ = 0;
00092   num_shape_fonts_ = 0;
00093   sample_index_ = 0;
00094   num_samples_ = 0;
00095   // Find the first indexable sample.
00096   Next();
00097 }
00098 
00099 bool SampleIterator::AtEnd() const {
00100   return shape_index_ >= num_shapes_;
00101 }
00102 
00103 const TrainingSample& SampleIterator::GetSample() const {
00104   if (shape_table_ != NULL) {
00105     const UnicharAndFonts* shape_entry = GetShapeEntry();
00106     int char_id = shape_entry->unichar_id;
00107     int font_id = shape_entry->font_ids[shape_font_index_];
00108     return *sample_set_->GetSample(font_id, char_id, sample_index_);
00109   } else {
00110     return *sample_set_->GetSample(shape_index_);
00111   }
00112 }
00113 
00114 TrainingSample* SampleIterator::MutableSample() const {
00115   if (shape_table_ != NULL) {
00116     const UnicharAndFonts* shape_entry = GetShapeEntry();
00117     int char_id = shape_entry->unichar_id;
00118     int font_id = shape_entry->font_ids[shape_font_index_];
00119     return sample_set_->MutableSample(font_id, char_id, sample_index_);
00120   } else {
00121     return sample_set_->mutable_sample(shape_index_);
00122   }
00123 }
00124 
00125 // Returns the total index (from the original set of samples) of the current
00126 // sample.
00127 int SampleIterator::GlobalSampleIndex() const {
00128   if (shape_table_ != NULL) {
00129     const UnicharAndFonts* shape_entry = GetShapeEntry();
00130     int char_id = shape_entry->unichar_id;
00131     int font_id = shape_entry->font_ids[shape_font_index_];
00132     return sample_set_->GlobalSampleIndex(font_id, char_id, sample_index_);
00133   } else {
00134     return shape_index_;
00135   }
00136 }
00137 
00138 // Returns the index of the current sample in compact charset space, so
00139 // in a 2-class problem between x and y, the returned indices will all be
00140 // 0 or 1, and have nothing to do with the unichar_ids.
00141 // If the charset_map_ is NULL, then this is equal to GetSparseClassID().
00142 int SampleIterator::GetCompactClassID() const {
00143   return charset_map_ != NULL ? charset_map_->SparseToCompact(shape_index_)
00144                               : GetSparseClassID();
00145 }
00146 // Returns the index of the current sample in sparse charset space, so
00147 // in a 2-class problem between x and y, the returned indices will all be
00148 // x or y, where x and y may be unichar_ids (no shape_table_) or shape_ids
00149 // with a shape_table_.
00150 int SampleIterator::GetSparseClassID() const {
00151   return shape_table_ != NULL ? shape_index_ : GetSample().class_id();
00152 }
00153 
00154 // Moves on to the next indexable sample. If the end is reached, leaves
00155 // the state such that AtEnd() is true.
00156 void SampleIterator::Next() {
00157   if (shape_table_ != NULL) {
00158     // Next sample in this class/font combination.
00159     ++sample_index_;
00160     if (sample_index_ < num_samples_)
00161       return;
00162     // Next font in this class in this shape.
00163     sample_index_ = 0;
00164     do {
00165       ++shape_font_index_;
00166       if (shape_font_index_ >= num_shape_fonts_) {
00167         // Next unichar in this shape.
00168         shape_font_index_ = 0;
00169         ++shape_char_index_;
00170         if (shape_char_index_ >= num_shape_chars_) {
00171           // Find the next shape that is mapped in the charset_map_.
00172           shape_char_index_ = 0;
00173           do {
00174             ++shape_index_;
00175           } while (shape_index_ < num_shapes_ &&
00176                    charset_map_ != NULL &&
00177                    charset_map_->SparseToCompact(shape_index_) < 0);
00178           if (shape_index_ >= num_shapes_)
00179             return;  // The end.
00180           num_shape_chars_ = shape_table_->GetShape(shape_index_).size();
00181         }
00182       }
00183       const UnicharAndFonts* shape_entry = GetShapeEntry();
00184       num_shape_fonts_ = shape_entry->font_ids.size();
00185       int char_id = shape_entry->unichar_id;
00186       int font_id = shape_entry->font_ids[shape_font_index_];
00187       num_samples_ = sample_set_->NumClassSamples(font_id, char_id, randomize_);
00188     } while (num_samples_ == 0);
00189   } else {
00190     // We are just iterating over the samples.
00191     ++shape_index_;
00192   }
00193 }
00194 
00195 // Returns the size of the compact charset space.
00196 int SampleIterator::CompactCharsetSize() const {
00197   return charset_map_ != NULL ? charset_map_->CompactSize()
00198                               : SparseCharsetSize();
00199 }
00200 
00201 // Returns the size of the sparse charset space.
00202 int SampleIterator::SparseCharsetSize() const {
00203   return charset_map_ != NULL
00204       ? charset_map_->SparseSize()
00205       : (shape_table_ != NULL ? shape_table_->NumShapes()
00206                               : sample_set_->charsetsize());
00207 }
00208 
00209 // Apply the supplied feature_space/feature_map transform to all samples
00210 // accessed by this iterator.
00211 void SampleIterator::MapSampleFeatures(const IntFeatureMap& feature_map) {
00212   for (Begin(); !AtEnd(); Next()) {
00213     TrainingSample* sample = MutableSample();
00214     sample->MapFeatures(feature_map);
00215   }
00216 }
00217 
00218 // Adjust the weights of all the samples to be uniform in the given charset.
00219 // Returns the number of samples in the iterator.
00220 int SampleIterator::UniformSamples() {
00221   int num_good_samples = 0;
00222   for (Begin(); !AtEnd(); Next()) {
00223     TrainingSample* sample = MutableSample();
00224     sample->set_weight(1.0);
00225     ++num_good_samples;
00226   }
00227   NormalizeSamples();
00228   return num_good_samples;
00229 }
00230 
00231 // Normalize the weights of all the samples in the charset_map so they sum
00232 // to 1. Returns the minimum assigned sample weight.
00233 double SampleIterator::NormalizeSamples() {
00234   double total_weight = 0.0;
00235   int sample_count = 0;
00236   for (Begin(); !AtEnd(); Next()) {
00237     const TrainingSample& sample = GetSample();
00238     total_weight += sample.weight();
00239     ++sample_count;
00240   }
00241   // Normalize samples.
00242   double min_assigned_sample_weight = 1.0;
00243   if (total_weight > 0.0) {
00244     for (Begin(); !AtEnd(); Next()) {
00245       TrainingSample* sample = MutableSample();
00246       double weight = sample->weight() / total_weight;
00247       if (weight < min_assigned_sample_weight)
00248         min_assigned_sample_weight = weight;
00249       sample->set_weight(weight);
00250     }
00251   }
00252   return min_assigned_sample_weight;
00253 }
00254 
00255 // Helper returns the current UnicharAndFont shape_entry.
00256 const UnicharAndFonts* SampleIterator::GetShapeEntry() const {
00257   const Shape& shape = shape_table_->GetShape(shape_index_);
00258   return &shape[shape_char_index_];
00259 }
00260 
00261 }  // namespace tesseract.
00262