// Tesseract 3.02
00001 // Copyright 2011 Google Inc. All Rights Reserved. 00002 // Author: rays@google.com (Ray Smith) 00003 // 00004 // Licensed under the Apache License, Version 2.0 (the "License"); 00005 // you may not use this file except in compliance with the License. 00006 // You may obtain a copy of the License at 00007 // http://www.apache.org/licenses/LICENSE-2.0 00008 // Unless required by applicable law or agreed to in writing, software 00009 // distributed under the License is distributed on an "AS IS" BASIS, 00010 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 00011 // See the License for the specific language governing permissions and 00012 // limitations under the License. 00013 // 00015 00016 #include "sampleiterator.h" 00017 00018 #include "indexmapbidi.h" 00019 #include "shapetable.h" 00020 #include "trainingsample.h" 00021 #include "trainingsampleset.h" 00022 00023 namespace tesseract { 00024 00025 // ================== SampleIterator Implementation ================= 00026 00027 SampleIterator::SampleIterator() 00028 : charset_map_(NULL), 00029 shape_table_(NULL), 00030 sample_set_(NULL), 00031 randomize_(false), 00032 owned_shape_table_(NULL) { 00033 num_shapes_ = 0; 00034 Begin(); 00035 } 00036 00037 SampleIterator::~SampleIterator() { 00038 Clear(); 00039 } 00040 00041 void SampleIterator::Clear() { 00042 delete owned_shape_table_; 00043 owned_shape_table_ = NULL; 00044 } 00045 00046 // See class comment for arguments. 00047 void SampleIterator::Init(const IndexMapBiDi* charset_map, 00048 const ShapeTable* shape_table, 00049 bool randomize, 00050 TrainingSampleSet* sample_set) { 00051 Clear(); 00052 charset_map_ = charset_map; 00053 shape_table_ = shape_table; 00054 sample_set_ = sample_set; 00055 randomize_ = randomize; 00056 if (shape_table_ == NULL && charset_map_ != NULL) { 00057 // The caller wishes to iterate by class. The easiest way to do this 00058 // is to create a dummy shape_table_ that we will own. 
00059 int num_fonts = sample_set_->NumFonts(); 00060 owned_shape_table_ = new ShapeTable(sample_set_->unicharset()); 00061 int charsetsize = sample_set_->unicharset().size(); 00062 for (int c = 0; c < charsetsize; ++c) { 00063 // We always add a shape for each character to keep the index in sync 00064 // with the unichar_id. 00065 int shape_id = owned_shape_table_->AddShape(c, 0); 00066 for (int f = 1; f < num_fonts; ++f) { 00067 if (sample_set_->NumClassSamples(f, c, true) > 0) { 00068 owned_shape_table_->AddToShape(shape_id, c, f); 00069 } 00070 } 00071 } 00072 shape_table_ = owned_shape_table_; 00073 } 00074 if (shape_table_ != NULL) { 00075 num_shapes_ = shape_table_->NumShapes(); 00076 } else { 00077 num_shapes_ = randomize ? sample_set_->num_samples() 00078 : sample_set_->num_raw_samples(); 00079 } 00080 Begin(); 00081 } 00082 00083 // Iterator functions designed for use with a simple for loop: 00084 // for (it.Begin(); !it.AtEnd(); it.Next()) { 00085 // const TrainingSample& sample = it.GetSample(); 00086 // } 00087 void SampleIterator::Begin() { 00088 shape_index_ = -1; 00089 shape_char_index_ = 0; 00090 num_shape_chars_ = 0; 00091 shape_font_index_ = 0; 00092 num_shape_fonts_ = 0; 00093 sample_index_ = 0; 00094 num_samples_ = 0; 00095 // Find the first indexable sample. 
00096 Next(); 00097 } 00098 00099 bool SampleIterator::AtEnd() const { 00100 return shape_index_ >= num_shapes_; 00101 } 00102 00103 const TrainingSample& SampleIterator::GetSample() const { 00104 if (shape_table_ != NULL) { 00105 const UnicharAndFonts* shape_entry = GetShapeEntry(); 00106 int char_id = shape_entry->unichar_id; 00107 int font_id = shape_entry->font_ids[shape_font_index_]; 00108 return *sample_set_->GetSample(font_id, char_id, sample_index_); 00109 } else { 00110 return *sample_set_->GetSample(shape_index_); 00111 } 00112 } 00113 00114 TrainingSample* SampleIterator::MutableSample() const { 00115 if (shape_table_ != NULL) { 00116 const UnicharAndFonts* shape_entry = GetShapeEntry(); 00117 int char_id = shape_entry->unichar_id; 00118 int font_id = shape_entry->font_ids[shape_font_index_]; 00119 return sample_set_->MutableSample(font_id, char_id, sample_index_); 00120 } else { 00121 return sample_set_->mutable_sample(shape_index_); 00122 } 00123 } 00124 00125 // Returns the total index (from the original set of samples) of the current 00126 // sample. 00127 int SampleIterator::GlobalSampleIndex() const { 00128 if (shape_table_ != NULL) { 00129 const UnicharAndFonts* shape_entry = GetShapeEntry(); 00130 int char_id = shape_entry->unichar_id; 00131 int font_id = shape_entry->font_ids[shape_font_index_]; 00132 return sample_set_->GlobalSampleIndex(font_id, char_id, sample_index_); 00133 } else { 00134 return shape_index_; 00135 } 00136 } 00137 00138 // Returns the index of the current sample in compact charset space, so 00139 // in a 2-class problem between x and y, the returned indices will all be 00140 // 0 or 1, and have nothing to do with the unichar_ids. 00141 // If the charset_map_ is NULL, then this is equal to GetSparseClassID(). 00142 int SampleIterator::GetCompactClassID() const { 00143 return charset_map_ != NULL ? 
charset_map_->SparseToCompact(shape_index_) 00144 : GetSparseClassID(); 00145 } 00146 // Returns the index of the current sample in sparse charset space, so 00147 // in a 2-class problem between x and y, the returned indices will all be 00148 // x or y, where x and y may be unichar_ids (no shape_table_) or shape_ids 00149 // with a shape_table_. 00150 int SampleIterator::GetSparseClassID() const { 00151 return shape_table_ != NULL ? shape_index_ : GetSample().class_id(); 00152 } 00153 00154 // Moves on to the next indexable sample. If the end is reached, leaves 00155 // the state such that AtEnd() is true. 00156 void SampleIterator::Next() { 00157 if (shape_table_ != NULL) { 00158 // Next sample in this class/font combination. 00159 ++sample_index_; 00160 if (sample_index_ < num_samples_) 00161 return; 00162 // Next font in this class in this shape. 00163 sample_index_ = 0; 00164 do { 00165 ++shape_font_index_; 00166 if (shape_font_index_ >= num_shape_fonts_) { 00167 // Next unichar in this shape. 00168 shape_font_index_ = 0; 00169 ++shape_char_index_; 00170 if (shape_char_index_ >= num_shape_chars_) { 00171 // Find the next shape that is mapped in the charset_map_. 00172 shape_char_index_ = 0; 00173 do { 00174 ++shape_index_; 00175 } while (shape_index_ < num_shapes_ && 00176 charset_map_ != NULL && 00177 charset_map_->SparseToCompact(shape_index_) < 0); 00178 if (shape_index_ >= num_shapes_) 00179 return; // The end. 00180 num_shape_chars_ = shape_table_->GetShape(shape_index_).size(); 00181 } 00182 } 00183 const UnicharAndFonts* shape_entry = GetShapeEntry(); 00184 num_shape_fonts_ = shape_entry->font_ids.size(); 00185 int char_id = shape_entry->unichar_id; 00186 int font_id = shape_entry->font_ids[shape_font_index_]; 00187 num_samples_ = sample_set_->NumClassSamples(font_id, char_id, randomize_); 00188 } while (num_samples_ == 0); 00189 } else { 00190 // We are just iterating over the samples. 
00191 ++shape_index_; 00192 } 00193 } 00194 00195 // Returns the size of the compact charset space. 00196 int SampleIterator::CompactCharsetSize() const { 00197 return charset_map_ != NULL ? charset_map_->CompactSize() 00198 : SparseCharsetSize(); 00199 } 00200 00201 // Returns the size of the sparse charset space. 00202 int SampleIterator::SparseCharsetSize() const { 00203 return charset_map_ != NULL 00204 ? charset_map_->SparseSize() 00205 : (shape_table_ != NULL ? shape_table_->NumShapes() 00206 : sample_set_->charsetsize()); 00207 } 00208 00209 // Apply the supplied feature_space/feature_map transform to all samples 00210 // accessed by this iterator. 00211 void SampleIterator::MapSampleFeatures(const IntFeatureMap& feature_map) { 00212 for (Begin(); !AtEnd(); Next()) { 00213 TrainingSample* sample = MutableSample(); 00214 sample->MapFeatures(feature_map); 00215 } 00216 } 00217 00218 // Adjust the weights of all the samples to be uniform in the given charset. 00219 // Returns the number of samples in the iterator. 00220 int SampleIterator::UniformSamples() { 00221 int num_good_samples = 0; 00222 for (Begin(); !AtEnd(); Next()) { 00223 TrainingSample* sample = MutableSample(); 00224 sample->set_weight(1.0); 00225 ++num_good_samples; 00226 } 00227 NormalizeSamples(); 00228 return num_good_samples; 00229 } 00230 00231 // Normalize the weights of all the samples in the charset_map so they sum 00232 // to 1. Returns the minimum assigned sample weight. 00233 double SampleIterator::NormalizeSamples() { 00234 double total_weight = 0.0; 00235 int sample_count = 0; 00236 for (Begin(); !AtEnd(); Next()) { 00237 const TrainingSample& sample = GetSample(); 00238 total_weight += sample.weight(); 00239 ++sample_count; 00240 } 00241 // Normalize samples. 
00242 double min_assigned_sample_weight = 1.0; 00243 if (total_weight > 0.0) { 00244 for (Begin(); !AtEnd(); Next()) { 00245 TrainingSample* sample = MutableSample(); 00246 double weight = sample->weight() / total_weight; 00247 if (weight < min_assigned_sample_weight) 00248 min_assigned_sample_weight = weight; 00249 sample->set_weight(weight); 00250 } 00251 } 00252 return min_assigned_sample_weight; 00253 } 00254 00255 // Helper returns the current UnicharAndFont shape_entry. 00256 const UnicharAndFonts* SampleIterator::GetShapeEntry() const { 00257 const Shape& shape = shape_table_->GetShape(shape_index_); 00258 return &shape[shape_char_index_]; 00259 } 00260 00261 } // namespace tesseract. 00262