Tesseract  3.02
tesseract-ocr/classify/errorcounter.cpp
Go to the documentation of this file.
00001 // Copyright 2011 Google Inc. All Rights Reserved.
00002 // Author: rays@google.com (Ray Smith)
00003 //
00004 // Licensed under the Apache License, Version 2.0 (the "License");
00005 // you may not use this file except in compliance with the License.
00006 // You may obtain a copy of the License at
00007 // http://www.apache.org/licenses/LICENSE-2.0
00008 // Unless required by applicable law or agreed to in writing, software
00009 // distributed under the License is distributed on an "AS IS" BASIS,
00010 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
00011 // See the License for the specific language governing permissions and
00012 // limitations under the License.
00013 //
00015 #include <ctime>
00016 
00017 #include "errorcounter.h"
00018 
00019 #include "fontinfo.h"
00020 #include "ndminx.h"
00021 #include "sampleiterator.h"
00022 #include "shapeclassifier.h"
00023 #include "shapetable.h"
00024 #include "trainingsample.h"
00025 #include "trainingsampleset.h"
00026 #include "unicity_table.h"
00027 
00028 namespace tesseract {
00029 
00030 // Tests a classifier, computing its error rate.
00031 // See errorcounter.h for description of arguments.
00032 // Iterates over the samples, calling the classifier in normal/silent mode.
00033 // If the classifier makes a CT_UNICHAR_TOPN_ERR error, and the appropriate
00034 // report_level is set (4 or greater), it will then call the classifier again
00035 // with a debug flag and a keep_this argument to find out what is going on.
00036 double ErrorCounter::ComputeErrorRate(ShapeClassifier* classifier,
00037     int report_level, CountTypes boosting_mode,
00038     const UnicityTable<FontInfo>& fontinfo_table,
00039     const GenericVector<Pix*>& page_images, SampleIterator* it,
00040     double* unichar_error,  double* scaled_error, STRING* fonts_report) {
00041   int charsetsize = it->shape_table()->unicharset().size();
00042   int shapesize = it->CompactCharsetSize();
00043   int fontsize = it->sample_set()->NumFonts();
00044   ErrorCounter counter(charsetsize, shapesize, fontsize);
00045   GenericVector<ShapeRating> results;
00046 
00047   clock_t start = clock();
00048   int total_samples = 0;
00049   double unscaled_error = 0.0;
00050   // Set a number of samples on which to run the classify debug mode.
00051   int error_samples = report_level > 3 ? report_level * report_level : 0;
00052   // Iterate over all the samples, accumulating errors.
00053   for (it->Begin(); !it->AtEnd(); it->Next()) {
00054     TrainingSample* mutable_sample = it->MutableSample();
00055     int page_index = mutable_sample->page_num();
00056     Pix* page_pix = 0 <= page_index && page_index < page_images.size()
00057                   ? page_images[page_index] : NULL;
00058     // No debug, no keep this.
00059     classifier->ClassifySample(*mutable_sample, page_pix, 0, INVALID_UNICHAR_ID,
00060                                &results);
00061     if (mutable_sample->class_id() == 0) {
00062       // This is junk so use the special counter.
00063       counter.AccumulateJunk(*it->shape_table(), results, mutable_sample);
00064     } else if (counter.AccumulateErrors(report_level > 3, boosting_mode,
00065                                         fontinfo_table, *it->shape_table(),
00066                                         results, mutable_sample) &&
00067                error_samples > 0) {
00068       // Running debug, keep the correct answer, and debug the classifier.
00069       tprintf("Error on sample %d: Classifier debug output:\n",
00070               it->GlobalSampleIndex());
00071       int keep_this = it->GetSparseClassID();
00072       classifier->ClassifySample(*mutable_sample, page_pix, 1, keep_this,
00073                                  &results);
00074       --error_samples;
00075     }
00076     ++total_samples;
00077   }
00078   double total_time = 1.0 * (clock() - start) / CLOCKS_PER_SEC;
00079   // Create the appropriate error report.
00080   unscaled_error = counter.ReportErrors(report_level, boosting_mode,
00081                                         fontinfo_table,
00082                                         *it, unichar_error, fonts_report);
00083   if (scaled_error != NULL) *scaled_error = counter.scaled_error_;
00084   if (report_level > 1) {
00085     // It is useful to know the time in microseconds/char.
00086     tprintf("Errors computed in %.2fs at %.1f μs/char\n",
00087             total_time, 1000000.0 * total_time / total_samples);
00088   }
00089   return unscaled_error;
00090 }
00091 
00092 // Constructor is private. Only anticipated use of ErrorCounter is via
00093 // the static ComputeErrorRate.
00094 ErrorCounter::ErrorCounter(int charsetsize, int shapesize, int fontsize)
00095   : scaled_error_(0.0), unichar_counts_(charsetsize, shapesize, 0) {
00096   Counts empty_counts;
00097   font_counts_.init_to_size(fontsize, empty_counts);
00098 }
00099 ErrorCounter::~ErrorCounter() {
00100 }
00101 
00102 // Accumulates the errors from the classifier results on a single sample.
00103 // Returns true if debug is true and a CT_UNICHAR_TOPN_ERR error occurred.
00104 // boosting_mode selects the type of error to be used for boosting and the
00105 // is_error_ member of sample is set according to whether the required type
00106 // of error occurred. The font_table provides access to font properties
00107 // for error counting and shape_table is used to understand the relationship
00108 // between unichar_ids and shape_ids in the results
00109 bool ErrorCounter::AccumulateErrors(bool debug, CountTypes boosting_mode,
00110                                     const UnicityTable<FontInfo>& font_table,
00111                                     const ShapeTable& shape_table,
00112                                     const GenericVector<ShapeRating>& results,
00113                                     TrainingSample* sample) {
00114   int num_results = results.size();
00115   int res_index = 0;
00116   bool debug_it = false;
00117   int font_id = sample->font_id();
00118   int unichar_id = sample->class_id();
00119   sample->set_is_error(false);
00120   if (num_results == 0) {
00121     // Reject. We count rejects as a separate category, but still mark the
00122     // sample as an error in case any training module wants to use that to
00123     // improve the classifier.
00124     sample->set_is_error(true);
00125     ++font_counts_[font_id].n[CT_REJECT];
00126   } else if (shape_table.GetShape(results[0].shape_id).
00127           ContainsUnicharAndFont(unichar_id, font_id)) {
00128     ++font_counts_[font_id].n[CT_SHAPE_TOP_CORRECT];
00129     // Unichar and font OK, but count if multiple unichars.
00130     if (shape_table.GetShape(results[0].shape_id).size() > 1)
00131       ++font_counts_[font_id].n[CT_OK_MULTI_UNICHAR];
00132   } else {
00133     // This is a top shape error.
00134     ++font_counts_[font_id].n[CT_SHAPE_TOP_ERR];
00135     // Check to see if any font in the top choice has attributes that match.
00136     bool attributes_match = false;
00137     uinT32 font_props = font_table.get(font_id).properties;
00138     const Shape& shape = shape_table.GetShape(results[0].shape_id);
00139     for (int c = 0; c < shape.size() && !attributes_match; ++c) {
00140       for (int f = 0; f < shape[c].font_ids.size(); ++f) {
00141         if (font_table.get(shape[c].font_ids[f]).properties == font_props) {
00142           attributes_match = true;
00143           break;
00144         }
00145       }
00146     }
00147     // TODO(rays) It is easy to add counters for individual font attributes
00148     // here if we want them.
00149     if (!attributes_match)
00150       ++font_counts_[font_id].n[CT_FONT_ATTR_ERR];
00151     if (boosting_mode == CT_SHAPE_TOP_ERR) sample->set_is_error(true);
00152     // Find rank of correct unichar answer. (Ignoring the font.)
00153     while (res_index < num_results &&
00154            !shape_table.GetShape(results[res_index].shape_id).
00155                 ContainsUnichar(unichar_id)) {
00156       ++res_index;
00157     }
00158     if (res_index == 0) {
00159       // Unichar OK, but count if multiple unichars.
00160       if (shape_table.GetShape(results[res_index].shape_id).size() > 1) {
00161         ++font_counts_[font_id].n[CT_OK_MULTI_UNICHAR];
00162       }
00163     } else {
00164       // Count maps from unichar id to shape id.
00165       if (num_results > 0)
00166         ++unichar_counts_(unichar_id, results[0].shape_id);
00167       // This is a unichar error.
00168       ++font_counts_[font_id].n[CT_UNICHAR_TOP1_ERR];
00169       if (boosting_mode == CT_UNICHAR_TOP1_ERR) sample->set_is_error(true);
00170       if (res_index >= MIN(2, num_results)) {
00171         // It is also a 2nd choice unichar error.
00172         ++font_counts_[font_id].n[CT_UNICHAR_TOP2_ERR];
00173         if (boosting_mode == CT_UNICHAR_TOP2_ERR) sample->set_is_error(true);
00174       }
00175       if (res_index >= num_results) {
00176         // It is also a top-n choice unichar error.
00177         ++font_counts_[font_id].n[CT_UNICHAR_TOPN_ERR];
00178         if (boosting_mode == CT_UNICHAR_TOPN_ERR) sample->set_is_error(true);
00179         debug_it = debug;
00180       }
00181     }
00182   }
00183   // Compute mean number of return values and mean rank of correct answer.
00184   font_counts_[font_id].n[CT_NUM_RESULTS] += num_results;
00185   font_counts_[font_id].n[CT_RANK] += res_index;
00186   // If it was an error for boosting then sum the weight.
00187   if (sample->is_error()) {
00188     scaled_error_ += sample->weight();
00189   }
00190   if (debug_it) {
00191     tprintf("%d results for char %s font %d :",
00192             num_results, shape_table.unicharset().id_to_unichar(unichar_id),
00193             font_id);
00194     for (int i = 0; i < num_results; ++i) {
00195       tprintf(" %.3f/%.3f:%s",
00196               results[i].rating, results[i].font,
00197               shape_table.DebugStr(results[i].shape_id).string());
00198     }
00199     tprintf("\n");
00200     return true;
00201   }
00202   return false;
00203 }
00204 
00205 // Accumulates counts for junk. Counts only whether the junk was correctly
00206 // rejected or not.
00207 void ErrorCounter::AccumulateJunk(const ShapeTable& shape_table,
00208                                   const GenericVector<ShapeRating>& results,
00209                                   TrainingSample* sample) {
00210   // For junk we accept no answer, or an explicit shape answer matching the
00211   // class id of the sample.
00212   int num_results = results.size();
00213   int font_id = sample->font_id();
00214   int unichar_id = sample->class_id();
00215   if (num_results > 0 &&
00216       !shape_table.GetShape(results[0].shape_id).ContainsUnichar(unichar_id)) {
00217     // This is a junk error.
00218     ++font_counts_[font_id].n[CT_ACCEPTED_JUNK];
00219     sample->set_is_error(true);
00220     // It counts as an error for boosting too so sum the weight.
00221     scaled_error_ += sample->weight();
00222   } else {
00223     // Correctly rejected.
00224     ++font_counts_[font_id].n[CT_REJECTED_JUNK];
00225     sample->set_is_error(false);
00226   }
00227 }
00228 
00229 // Creates a report of the error rate. The report_level controls the detail
00230 // that is reported to stderr via tprintf:
00231 // 0   -> no output.
00232 // >=1 -> bottom-line error rate.
00233 // >=3 -> font-level error rate.
00234 // boosting_mode determines the return value. It selects which (un-weighted)
00235 // error rate to return.
00236 // The fontinfo_table from MasterTrainer provides the names of fonts.
00237 // The it determines the current subset of the training samples.
00238 // If not NULL, the top-choice unichar error rate is saved in unichar_error.
00239 // If not NULL, the report string is saved in fonts_report.
00240 // (Ignoring report_level).
00241 double ErrorCounter::ReportErrors(int report_level, CountTypes boosting_mode,
00242                                   const UnicityTable<FontInfo>& fontinfo_table,
00243                                   const SampleIterator& it,
00244                                   double* unichar_error,
00245                                   STRING* fonts_report) {
00246   // Compute totals over all the fonts and report individual font results
00247   // when required.
00248   Counts totals;
00249   int fontsize = font_counts_.size();
00250   for (int f = 0; f < fontsize; ++f) {
00251     // Accumulate counts over fonts.
00252     totals += font_counts_[f];
00253     STRING font_report;
00254     if (ReportString(font_counts_[f], &font_report)) {
00255       if (fonts_report != NULL) {
00256         *fonts_report += fontinfo_table.get(f).name;
00257         *fonts_report += ": ";
00258         *fonts_report += font_report;
00259         *fonts_report += "\n";
00260       }
00261       if (report_level > 2) {
00262         // Report individual font error rates.
00263         tprintf("%s: %s\n", fontinfo_table.get(f).name, font_report.string());
00264       }
00265     }
00266   }
00267   if (report_level > 0) {
00268     // Report the totals.
00269     STRING total_report;
00270     if (ReportString(totals, &total_report)) {
00271       tprintf("TOTAL Scaled Err=%.4g%%, %s\n",
00272               scaled_error_ * 100.0, total_report.string());
00273     }
00274     // Report the worst substitution error only for now.
00275     if (totals.n[CT_UNICHAR_TOP1_ERR] > 0) {
00276       const UNICHARSET& unicharset = it.shape_table()->unicharset();
00277       int charsetsize = unicharset.size();
00278       int shapesize = it.CompactCharsetSize();
00279       int worst_uni_id = 0;
00280       int worst_shape_id = 0;
00281       int worst_err = 0;
00282       for (int u = 0; u < charsetsize; ++u) {
00283         for (int s = 0; s < shapesize; ++s) {
00284           if (unichar_counts_(u, s) > worst_err) {
00285             worst_err = unichar_counts_(u, s);
00286             worst_uni_id = u;
00287             worst_shape_id = s;
00288           }
00289         }
00290       }
00291       if (worst_err > 0) {
00292         tprintf("Worst error = %d:%s -> %s with %d/%d=%.2f%% errors\n",
00293                 worst_uni_id, unicharset.id_to_unichar(worst_uni_id),
00294                 it.shape_table()->DebugStr(worst_shape_id).string(),
00295                 worst_err, totals.n[CT_UNICHAR_TOP1_ERR],
00296                 100.0 * worst_err / totals.n[CT_UNICHAR_TOP1_ERR]);
00297       }
00298     }
00299   }
00300   double rates[CT_SIZE];
00301   if (!ComputeRates(totals, rates))
00302     return 0.0;
00303   // Set output values if asked for.
00304   if (unichar_error != NULL)
00305     *unichar_error = rates[CT_UNICHAR_TOP1_ERR];
00306   return rates[boosting_mode];
00307 }
00308 
00309 // Sets the report string to a combined human and machine-readable report
00310 // string of the error rates.
00311 // Returns false if there is no data, leaving report unchanged.
00312 bool ErrorCounter::ReportString(const Counts& counts, STRING* report) {
00313   // Compute the error rates.
00314   double rates[CT_SIZE];
00315   if (!ComputeRates(counts, rates))
00316     return false;
00317   // Using %.4g%%, the length of the output string should exactly match the
00318   // length of the format string, but in case of overflow, allow for +eddd
00319   // on each number.
00320   const int kMaxExtraLength = 5;  // Length of +eddd.
00321   // Keep this format string and the snprintf in sync with the CountTypes enum.
00322   const char* format_str = "ShapeErr=%.4g%%, FontAttr=%.4g%%, "
00323                            "Unichar=%.4g%%[1], %.4g%%[2], %.4g%%[n], "
00324                            "Multi=%.4g%%, Rej=%.4g%%, "
00325                            "Answers=%.3g, Rank=%.3g, "
00326                            "OKjunk=%.4g%%, Badjunk=%.4g%%";
00327   int max_str_len = strlen(format_str) + kMaxExtraLength * (CT_SIZE - 1) + 1;
00328   char* formatted_str = new char[max_str_len];
00329   snprintf(formatted_str, max_str_len, format_str,
00330            rates[CT_SHAPE_TOP_ERR] * 100.0,
00331            rates[CT_FONT_ATTR_ERR] * 100.0,
00332            rates[CT_UNICHAR_TOP1_ERR] * 100.0,
00333            rates[CT_UNICHAR_TOP2_ERR] * 100.0,
00334            rates[CT_UNICHAR_TOPN_ERR] * 100.0,
00335            rates[CT_OK_MULTI_UNICHAR] * 100.0,
00336            rates[CT_REJECT] * 100.0,
00337            rates[CT_NUM_RESULTS],
00338            rates[CT_RANK],
00339            100.0 * rates[CT_REJECTED_JUNK],
00340            100.0 * rates[CT_ACCEPTED_JUNK]);
00341   *report = formatted_str;
00342   delete [] formatted_str;
00343   // Now append each field of counts with a tab in front so the result can
00344   // be loaded into a spreadsheet.
00345   for (int ct = 0; ct < CT_SIZE; ++ct)
00346     report->add_str_int("\t", counts.n[ct]);
00347   return true;
00348 }
00349 
00350 // Computes the error rates and returns in rates which is an array of size
00351 // CT_SIZE. Returns false if there is no data, leaving rates unchanged.
00352 bool ErrorCounter::ComputeRates(const Counts& counts, double rates[CT_SIZE]) {
00353   int ok_samples = counts.n[CT_SHAPE_TOP_CORRECT] + counts.n[CT_SHAPE_TOP_ERR] +
00354       counts.n[CT_REJECT];
00355   int junk_samples = counts.n[CT_REJECTED_JUNK] + counts.n[CT_ACCEPTED_JUNK];
00356   if (ok_samples == 0 && junk_samples == 0) {
00357     // There is no data.
00358     return false;
00359   }
00360   // Compute rates for normal chars.
00361   double denominator = static_cast<double>(MAX(ok_samples, 1));
00362   for (int ct = 0; ct <= CT_RANK; ++ct)
00363     rates[ct] = counts.n[ct] / denominator;
00364   // Compute rates for junk.
00365   denominator = static_cast<double>(MAX(junk_samples, 1));
00366   for (int ct = CT_REJECTED_JUNK; ct <= CT_ACCEPTED_JUNK; ++ct)
00367     rates[ct] = counts.n[ct] / denominator;
00368   return true;
00369 }
00370 
00371 ErrorCounter::Counts::Counts() {
00372   memset(n, 0, sizeof(n[0]) * CT_SIZE);
00373 }
00374 // Adds other into this for computing totals.
00375 void ErrorCounter::Counts::operator+=(const Counts& other) {
00376   for (int ct = 0; ct < CT_SIZE; ++ct)
00377     n[ct] += other.n[ct];
00378 }
00379 
00380 
00381 }  // namespace tesseract.
00382 
00383 
00384 
00385 
00386