Tesseract
3.02
|
00001 // Copyright 2011 Google Inc. All Rights Reserved. 00002 // Author: rays@google.com (Ray Smith) 00003 // 00004 // Licensed under the Apache License, Version 2.0 (the "License"); 00005 // you may not use this file except in compliance with the License. 00006 // You may obtain a copy of the License at 00007 // http://www.apache.org/licenses/LICENSE-2.0 00008 // Unless required by applicable law or agreed to in writing, software 00009 // distributed under the License is distributed on an "AS IS" BASIS, 00010 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 00011 // See the License for the specific language governing permissions and 00012 // limitations under the License. 00013 // 00015 #include <ctime> 00016 00017 #include "errorcounter.h" 00018 00019 #include "fontinfo.h" 00020 #include "ndminx.h" 00021 #include "sampleiterator.h" 00022 #include "shapeclassifier.h" 00023 #include "shapetable.h" 00024 #include "trainingsample.h" 00025 #include "trainingsampleset.h" 00026 #include "unicity_table.h" 00027 00028 namespace tesseract { 00029 00030 // Tests a classifier, computing its error rate. 00031 // See errorcounter.h for description of arguments. 00032 // Iterates over the samples, calling the classifier in normal/silent mode. 00033 // If the classifier makes a CT_UNICHAR_TOPN_ERR error, and the appropriate 00034 // report_level is set (4 or greater), it will then call the classifier again 00035 // with a debug flag and a keep_this argument to find out what is going on. 00036 double ErrorCounter::ComputeErrorRate(ShapeClassifier* classifier, 00037 int report_level, CountTypes boosting_mode, 00038 const UnicityTable<FontInfo>& fontinfo_table, 00039 const GenericVector<Pix*>& page_images, SampleIterator* it, 00040 double* unichar_error, double* scaled_error, STRING* fonts_report) { 00041 int charsetsize = it->shape_table()->unicharset().size(); 00042 int shapesize = it->CompactCharsetSize(); 00043 int fontsize = it->sample_set()->NumFonts(); 00044 ErrorCounter counter(charsetsize, shapesize, fontsize); 00045 GenericVector<ShapeRating> results; 00046 00047 clock_t start = clock(); 00048 int total_samples = 0; 00049 double unscaled_error = 0.0; 00050 // Set a number of samples on which to run the classify debug mode. 00051 int error_samples = report_level > 3 ? report_level * report_level : 0; 00052 // Iterate over all the samples, accumulating errors. 00053 for (it->Begin(); !it->AtEnd(); it->Next()) { 00054 TrainingSample* mutable_sample = it->MutableSample(); 00055 int page_index = mutable_sample->page_num(); 00056 Pix* page_pix = 0 <= page_index && page_index < page_images.size() 00057 ? page_images[page_index] : NULL; 00058 // No debug, no keep this. 00059 classifier->ClassifySample(*mutable_sample, page_pix, 0, INVALID_UNICHAR_ID, 00060 &results); 00061 if (mutable_sample->class_id() == 0) { 00062 // This is junk so use the special counter. 00063 counter.AccumulateJunk(*it->shape_table(), results, mutable_sample); 00064 } else if (counter.AccumulateErrors(report_level > 3, boosting_mode, 00065 fontinfo_table, *it->shape_table(), 00066 results, mutable_sample) && 00067 error_samples > 0) { 00068 // Running debug, keep the correct answer, and debug the classifier. 00069 tprintf("Error on sample %d: Classifier debug output:\n", 00070 it->GlobalSampleIndex()); 00071 int keep_this = it->GetSparseClassID(); 00072 classifier->ClassifySample(*mutable_sample, page_pix, 1, keep_this, 00073 &results); 00074 --error_samples; 00075 } 00076 ++total_samples; 00077 } 00078 double total_time = 1.0 * (clock() - start) / CLOCKS_PER_SEC; 00079 // Create the appropriate error report. 00080 unscaled_error = counter.ReportErrors(report_level, boosting_mode, 00081 fontinfo_table, 00082 *it, unichar_error, fonts_report); 00083 if (scaled_error != NULL) *scaled_error = counter.scaled_error_; 00084 if (report_level > 1) { 00085 // It is useful to know the time in microseconds/char. 00086 tprintf("Errors computed in %.2fs at %.1f μs/char\n", 00087 total_time, 1000000.0 * total_time / total_samples); 00088 } 00089 return unscaled_error; 00090 } 00091 00092 // Constructor is private. Only anticipated use of ErrorCounter is via 00093 // the static ComputeErrorRate. 00094 ErrorCounter::ErrorCounter(int charsetsize, int shapesize, int fontsize) 00095 : scaled_error_(0.0), unichar_counts_(charsetsize, shapesize, 0) { 00096 Counts empty_counts; 00097 font_counts_.init_to_size(fontsize, empty_counts); 00098 } 00099 ErrorCounter::~ErrorCounter() { 00100 } 00101 00102 // Accumulates the errors from the classifier results on a single sample. 00103 // Returns true if debug is true and a CT_UNICHAR_TOPN_ERR error occurred. 00104 // boosting_mode selects the type of error to be used for boosting and the 00105 // is_error_ member of sample is set according to whether the required type 00106 // of error occurred. The font_table provides access to font properties 00107 // for error counting and shape_table is used to understand the relationship 00108 // between unichar_ids and shape_ids in the results 00109 bool ErrorCounter::AccumulateErrors(bool debug, CountTypes boosting_mode, 00110 const UnicityTable<FontInfo>& font_table, 00111 const ShapeTable& shape_table, 00112 const GenericVector<ShapeRating>& results, 00113 TrainingSample* sample) { 00114 int num_results = results.size(); 00115 int res_index = 0; 00116 bool debug_it = false; 00117 int font_id = sample->font_id(); 00118 int unichar_id = sample->class_id(); 00119 sample->set_is_error(false); 00120 if (num_results == 0) { 00121 // Reject. We count rejects as a separate category, but still mark the 00122 // sample as an error in case any training module wants to use that to 00123 // improve the classifier. 00124 sample->set_is_error(true); 00125 ++font_counts_[font_id].n[CT_REJECT]; 00126 } else if (shape_table.GetShape(results[0].shape_id). 00127 ContainsUnicharAndFont(unichar_id, font_id)) { 00128 ++font_counts_[font_id].n[CT_SHAPE_TOP_CORRECT]; 00129 // Unichar and font OK, but count if multiple unichars. 00130 if (shape_table.GetShape(results[0].shape_id).size() > 1) 00131 ++font_counts_[font_id].n[CT_OK_MULTI_UNICHAR]; 00132 } else { 00133 // This is a top shape error. 00134 ++font_counts_[font_id].n[CT_SHAPE_TOP_ERR]; 00135 // Check to see if any font in the top choice has attributes that match. 00136 bool attributes_match = false; 00137 uinT32 font_props = font_table.get(font_id).properties; 00138 const Shape& shape = shape_table.GetShape(results[0].shape_id); 00139 for (int c = 0; c < shape.size() && !attributes_match; ++c) { 00140 for (int f = 0; f < shape[c].font_ids.size(); ++f) { 00141 if (font_table.get(shape[c].font_ids[f]).properties == font_props) { 00142 attributes_match = true; 00143 break; 00144 } 00145 } 00146 } 00147 // TODO(rays) It is easy to add counters for individual font attributes 00148 // here if we want them. 00149 if (!attributes_match) 00150 ++font_counts_[font_id].n[CT_FONT_ATTR_ERR]; 00151 if (boosting_mode == CT_SHAPE_TOP_ERR) sample->set_is_error(true); 00152 // Find rank of correct unichar answer. (Ignoring the font.) 00153 while (res_index < num_results && 00154 !shape_table.GetShape(results[res_index].shape_id). 00155 ContainsUnichar(unichar_id)) { 00156 ++res_index; 00157 } 00158 if (res_index == 0) { 00159 // Unichar OK, but count if multiple unichars. 00160 if (shape_table.GetShape(results[res_index].shape_id).size() > 1) { 00161 ++font_counts_[font_id].n[CT_OK_MULTI_UNICHAR]; 00162 } 00163 } else { 00164 // Count maps from unichar id to shape id. 00165 if (num_results > 0) 00166 ++unichar_counts_(unichar_id, results[0].shape_id); 00167 // This is a unichar error. 00168 ++font_counts_[font_id].n[CT_UNICHAR_TOP1_ERR]; 00169 if (boosting_mode == CT_UNICHAR_TOP1_ERR) sample->set_is_error(true); 00170 if (res_index >= MIN(2, num_results)) { 00171 // It is also a 2nd choice unichar error. 00172 ++font_counts_[font_id].n[CT_UNICHAR_TOP2_ERR]; 00173 if (boosting_mode == CT_UNICHAR_TOP2_ERR) sample->set_is_error(true); 00174 } 00175 if (res_index >= num_results) { 00176 // It is also a top-n choice unichar error. 00177 ++font_counts_[font_id].n[CT_UNICHAR_TOPN_ERR]; 00178 if (boosting_mode == CT_UNICHAR_TOPN_ERR) sample->set_is_error(true); 00179 debug_it = debug; 00180 } 00181 } 00182 } 00183 // Compute mean number of return values and mean rank of correct answer. 00184 font_counts_[font_id].n[CT_NUM_RESULTS] += num_results; 00185 font_counts_[font_id].n[CT_RANK] += res_index; 00186 // If it was an error for boosting then sum the weight. 00187 if (sample->is_error()) { 00188 scaled_error_ += sample->weight(); 00189 } 00190 if (debug_it) { 00191 tprintf("%d results for char %s font %d :", 00192 num_results, shape_table.unicharset().id_to_unichar(unichar_id), 00193 font_id); 00194 for (int i = 0; i < num_results; ++i) { 00195 tprintf(" %.3f/%.3f:%s", 00196 results[i].rating, results[i].font, 00197 shape_table.DebugStr(results[i].shape_id).string()); 00198 } 00199 tprintf("\n"); 00200 return true; 00201 } 00202 return false; 00203 } 00204 00205 // Accumulates counts for junk. Counts only whether the junk was correctly 00206 // rejected or not. 00207 void ErrorCounter::AccumulateJunk(const ShapeTable& shape_table, 00208 const GenericVector<ShapeRating>& results, 00209 TrainingSample* sample) { 00210 // For junk we accept no answer, or an explicit shape answer matching the 00211 // class id of the sample. 00212 int num_results = results.size(); 00213 int font_id = sample->font_id(); 00214 int unichar_id = sample->class_id(); 00215 if (num_results > 0 && 00216 !shape_table.GetShape(results[0].shape_id).ContainsUnichar(unichar_id)) { 00217 // This is a junk error. 00218 ++font_counts_[font_id].n[CT_ACCEPTED_JUNK]; 00219 sample->set_is_error(true); 00220 // It counts as an error for boosting too so sum the weight. 00221 scaled_error_ += sample->weight(); 00222 } else { 00223 // Correctly rejected. 00224 ++font_counts_[font_id].n[CT_REJECTED_JUNK]; 00225 sample->set_is_error(false); 00226 } 00227 } 00228 00229 // Creates a report of the error rate. The report_level controls the detail 00230 // that is reported to stderr via tprintf: 00231 // 0 -> no output. 00232 // >=1 -> bottom-line error rate. 00233 // >=3 -> font-level error rate. 00234 // boosting_mode determines the return value. It selects which (un-weighted) 00235 // error rate to return. 00236 // The fontinfo_table from MasterTrainer provides the names of fonts. 00237 // The it determines the current subset of the training samples. 00238 // If not NULL, the top-choice unichar error rate is saved in unichar_error. 00239 // If not NULL, the report string is saved in fonts_report. 00240 // (Ignoring report_level). 00241 double ErrorCounter::ReportErrors(int report_level, CountTypes boosting_mode, 00242 const UnicityTable<FontInfo>& fontinfo_table, 00243 const SampleIterator& it, 00244 double* unichar_error, 00245 STRING* fonts_report) { 00246 // Compute totals over all the fonts and report individual font results 00247 // when required. 00248 Counts totals; 00249 int fontsize = font_counts_.size(); 00250 for (int f = 0; f < fontsize; ++f) { 00251 // Accumulate counts over fonts. 00252 totals += font_counts_[f]; 00253 STRING font_report; 00254 if (ReportString(font_counts_[f], &font_report)) { 00255 if (fonts_report != NULL) { 00256 *fonts_report += fontinfo_table.get(f).name; 00257 *fonts_report += ": "; 00258 *fonts_report += font_report; 00259 *fonts_report += "\n"; 00260 } 00261 if (report_level > 2) { 00262 // Report individual font error rates. 00263 tprintf("%s: %s\n", fontinfo_table.get(f).name, font_report.string()); 00264 } 00265 } 00266 } 00267 if (report_level > 0) { 00268 // Report the totals. 00269 STRING total_report; 00270 if (ReportString(totals, &total_report)) { 00271 tprintf("TOTAL Scaled Err=%.4g%%, %s\n", 00272 scaled_error_ * 100.0, total_report.string()); 00273 } 00274 // Report the worst substitution error only for now. 00275 if (totals.n[CT_UNICHAR_TOP1_ERR] > 0) { 00276 const UNICHARSET& unicharset = it.shape_table()->unicharset(); 00277 int charsetsize = unicharset.size(); 00278 int shapesize = it.CompactCharsetSize(); 00279 int worst_uni_id = 0; 00280 int worst_shape_id = 0; 00281 int worst_err = 0; 00282 for (int u = 0; u < charsetsize; ++u) { 00283 for (int s = 0; s < shapesize; ++s) { 00284 if (unichar_counts_(u, s) > worst_err) { 00285 worst_err = unichar_counts_(u, s); 00286 worst_uni_id = u; 00287 worst_shape_id = s; 00288 } 00289 } 00290 } 00291 if (worst_err > 0) { 00292 tprintf("Worst error = %d:%s -> %s with %d/%d=%.2f%% errors\n", 00293 worst_uni_id, unicharset.id_to_unichar(worst_uni_id), 00294 it.shape_table()->DebugStr(worst_shape_id).string(), 00295 worst_err, totals.n[CT_UNICHAR_TOP1_ERR], 00296 100.0 * worst_err / totals.n[CT_UNICHAR_TOP1_ERR]); 00297 } 00298 } 00299 } 00300 double rates[CT_SIZE]; 00301 if (!ComputeRates(totals, rates)) 00302 return 0.0; 00303 // Set output values if asked for. 00304 if (unichar_error != NULL) 00305 *unichar_error = rates[CT_UNICHAR_TOP1_ERR]; 00306 return rates[boosting_mode]; 00307 } 00308 00309 // Sets the report string to a combined human and machine-readable report 00310 // string of the error rates. 00311 // Returns false if there is no data, leaving report unchanged. 00312 bool ErrorCounter::ReportString(const Counts& counts, STRING* report) { 00313 // Compute the error rates. 00314 double rates[CT_SIZE]; 00315 if (!ComputeRates(counts, rates)) 00316 return false; 00317 // Using %.4g%%, the length of the output string should exactly match the 00318 // length of the format string, but in case of overflow, allow for +eddd 00319 // on each number. 00320 const int kMaxExtraLength = 5; // Length of +eddd. 00321 // Keep this format string and the snprintf in sync with the CountTypes enum. 00322 const char* format_str = "ShapeErr=%.4g%%, FontAttr=%.4g%%, " 00323 "Unichar=%.4g%%[1], %.4g%%[2], %.4g%%[n], " 00324 "Multi=%.4g%%, Rej=%.4g%%, " 00325 "Answers=%.3g, Rank=%.3g, " 00326 "OKjunk=%.4g%%, Badjunk=%.4g%%"; 00327 int max_str_len = strlen(format_str) + kMaxExtraLength * (CT_SIZE - 1) + 1; 00328 char* formatted_str = new char[max_str_len]; 00329 snprintf(formatted_str, max_str_len, format_str, 00330 rates[CT_SHAPE_TOP_ERR] * 100.0, 00331 rates[CT_FONT_ATTR_ERR] * 100.0, 00332 rates[CT_UNICHAR_TOP1_ERR] * 100.0, 00333 rates[CT_UNICHAR_TOP2_ERR] * 100.0, 00334 rates[CT_UNICHAR_TOPN_ERR] * 100.0, 00335 rates[CT_OK_MULTI_UNICHAR] * 100.0, 00336 rates[CT_REJECT] * 100.0, 00337 rates[CT_NUM_RESULTS], 00338 rates[CT_RANK], 00339 100.0 * rates[CT_REJECTED_JUNK], 00340 100.0 * rates[CT_ACCEPTED_JUNK]); 00341 *report = formatted_str; 00342 delete [] formatted_str; 00343 // Now append each field of counts with a tab in front so the result can 00344 // be loaded into a spreadsheet. 00345 for (int ct = 0; ct < CT_SIZE; ++ct) 00346 report->add_str_int("\t", counts.n[ct]); 00347 return true; 00348 } 00349 00350 // Computes the error rates and returns in rates which is an array of size 00351 // CT_SIZE. Returns false if there is no data, leaving rates unchanged. 00352 bool ErrorCounter::ComputeRates(const Counts& counts, double rates[CT_SIZE]) { 00353 int ok_samples = counts.n[CT_SHAPE_TOP_CORRECT] + counts.n[CT_SHAPE_TOP_ERR] + 00354 counts.n[CT_REJECT]; 00355 int junk_samples = counts.n[CT_REJECTED_JUNK] + counts.n[CT_ACCEPTED_JUNK]; 00356 if (ok_samples == 0 && junk_samples == 0) { 00357 // There is no data. 00358 return false; 00359 } 00360 // Compute rates for normal chars. 00361 double denominator = static_cast<double>(MAX(ok_samples, 1)); 00362 for (int ct = 0; ct <= CT_RANK; ++ct) 00363 rates[ct] = counts.n[ct] / denominator; 00364 // Compute rates for junk. 00365 denominator = static_cast<double>(MAX(junk_samples, 1)); 00366 for (int ct = CT_REJECTED_JUNK; ct <= CT_ACCEPTED_JUNK; ++ct) 00367 rates[ct] = counts.n[ct] / denominator; 00368 return true; 00369 } 00370 00371 ErrorCounter::Counts::Counts() { 00372 memset(n, 0, sizeof(n[0]) * CT_SIZE); 00373 } 00374 // Adds other into this for computing totals. 00375 void ErrorCounter::Counts::operator+=(const Counts& other) { 00376 for (int ct = 0; ct < CT_SIZE; ++ct) 00377 n[ct] += other.n[ct]; 00378 } 00379 00380 00381 } // namespace tesseract. 00382 00383 00384 00385 00386