// Tesseract 3.02
00001 // Copyright 2008 Google Inc. 00002 // All Rights Reserved. 00003 // Author: ahmadab@google.com (Ahmad Abdulkader) 00004 // 00005 // neural_net.h: Declarations of a class for an object that 00006 // represents an arbitrary network of neurons 00007 // 00008 00009 #ifndef NEURAL_NET_H 00010 #define NEURAL_NET_H 00011 00012 #include <string> 00013 #include <vector> 00014 #include "neuron.h" 00015 #include "input_file_buffer.h" 00016 00017 namespace tesseract { 00018 00019 // Minimum input range below which we set the input weight to zero 00020 static const float kMinInputRange = 1e-6f; 00021 00022 class NeuralNet { 00023 public: 00024 NeuralNet(); 00025 virtual ~NeuralNet(); 00026 // create a net object from a file. Uses stdio 00027 static NeuralNet *FromFile(const string file_name); 00028 // create a net object from an input buffer 00029 static NeuralNet *FromInputBuffer(InputFileBuffer *ib); 00030 // Different flavors of feed forward function 00031 template <typename Type> bool FeedForward(const Type *inputs, 00032 Type *outputs); 00033 // Compute the output of a specific output node. 
00034 // This function is useful for application that are interested in a single 00035 // output of the net and do not want to waste time on the rest 00036 template <typename Type> bool GetNetOutput(const Type *inputs, 00037 int output_id, 00038 Type *output); 00039 // Accessor functions 00040 int in_cnt() const { return in_cnt_; } 00041 int out_cnt() const { return out_cnt_; } 00042 00043 protected: 00044 struct Node; 00045 // A node-weight pair 00046 struct WeightedNode { 00047 Node *input_node; 00048 float input_weight; 00049 }; 00050 // node struct used for fast feedforward in 00051 // Read only nets 00052 struct Node { 00053 float out; 00054 float bias; 00055 int fan_in_cnt; 00056 WeightedNode *inputs; 00057 }; 00058 // Read-Only flag (no training: On by default) 00059 // will presumeably be set to false by 00060 // the inherting TrainableNeuralNet class 00061 bool read_only_; 00062 // input count 00063 int in_cnt_; 00064 // output count 00065 int out_cnt_; 00066 // Total neuron count (including inputs) 00067 int neuron_cnt_; 00068 // count of unique weights 00069 int wts_cnt_; 00070 // Neuron vector 00071 Neuron *neurons_; 00072 // size of allocated weight chunk (in weights) 00073 // This is basically the size of the biggest network 00074 // that I have trained. 
However, the class will allow 00075 // a bigger sized net if desired 00076 static const int kWgtChunkSize = 0x10000; 00077 // Magic number expected at the beginning of the NN 00078 // binary file 00079 static const unsigned int kNetSignature = 0xFEFEABD0; 00080 // count of allocated wgts in the last chunk 00081 int alloc_wgt_cnt_; 00082 // vector of weights buffers 00083 vector<vector<float> *>wts_vec_; 00084 // Is the net an auto-encoder type 00085 bool auto_encoder_; 00086 // vector of input max values 00087 vector<float> inputs_max_; 00088 // vector of input min values 00089 vector<float> inputs_min_; 00090 // vector of input mean values 00091 vector<float> inputs_mean_; 00092 // vector of input standard deviation values 00093 vector<float> inputs_std_dev_; 00094 // vector of input offsets used by fast read-only 00095 // feedforward function 00096 vector<Node> fast_nodes_; 00097 // Network Initialization function 00098 void Init(); 00099 // Clears all neurons 00100 void Clear() { 00101 for (int node = 0; node < neuron_cnt_; node++) { 00102 neurons_[node].Clear(); 00103 } 00104 } 00105 // Reads the net from an input buffer 00106 template<class ReadBuffType> bool ReadBinary(ReadBuffType *input_buff) { 00107 // Init vars 00108 Init(); 00109 // is this an autoencoder 00110 unsigned int read_val; 00111 unsigned int auto_encode; 00112 // read and verify signature 00113 if (input_buff->Read(&read_val, sizeof(read_val)) != sizeof(read_val)) { 00114 return false; 00115 } 00116 if (read_val != kNetSignature) { 00117 return false; 00118 } 00119 if (input_buff->Read(&auto_encode, sizeof(auto_encode)) != 00120 sizeof(auto_encode)) { 00121 return false; 00122 } 00123 auto_encoder_ = auto_encode; 00124 // read and validate total # of nodes 00125 if (input_buff->Read(&read_val, sizeof(read_val)) != sizeof(read_val)) { 00126 return false; 00127 } 00128 neuron_cnt_ = read_val; 00129 if (neuron_cnt_ <= 0) { 00130 return false; 00131 } 00132 // set the size of the neurons vector 
00133 neurons_ = new Neuron[neuron_cnt_]; 00134 if (neurons_ == NULL) { 00135 return false; 00136 } 00137 // read & validate inputs 00138 if (input_buff->Read(&read_val, sizeof(read_val)) != sizeof(read_val)) { 00139 return false; 00140 } 00141 in_cnt_ = read_val; 00142 if (in_cnt_ <= 0) { 00143 return false; 00144 } 00145 // read outputs 00146 if (input_buff->Read(&read_val, sizeof(read_val)) != sizeof(read_val)) { 00147 return false; 00148 } 00149 out_cnt_ = read_val; 00150 if (out_cnt_ <= 0) { 00151 return false; 00152 } 00153 // set neuron ids and types 00154 for (int idx = 0; idx < neuron_cnt_; idx++) { 00155 neurons_[idx].set_id(idx); 00156 // input type 00157 if (idx < in_cnt_) { 00158 neurons_[idx].set_node_type(Neuron::Input); 00159 } else if (idx >= (neuron_cnt_ - out_cnt_)) { 00160 neurons_[idx].set_node_type(Neuron::Output); 00161 } else { 00162 neurons_[idx].set_node_type(Neuron::Hidden); 00163 } 00164 } 00165 // read the connections 00166 for (int node_idx = 0; node_idx < neuron_cnt_; node_idx++) { 00167 // read fanout 00168 if (input_buff->Read(&read_val, sizeof(read_val)) != sizeof(read_val)) { 00169 return false; 00170 } 00171 // read the neuron's info 00172 int fan_out_cnt = read_val; 00173 for (int fan_out_idx = 0; fan_out_idx < fan_out_cnt; fan_out_idx++) { 00174 // read the neuron id 00175 if (input_buff->Read(&read_val, sizeof(read_val)) != sizeof(read_val)) { 00176 return false; 00177 } 00178 // create the connection 00179 if (!SetConnection(node_idx, read_val)) { 00180 return false; 00181 } 00182 } 00183 } 00184 // read all the neurons' fan-in connections 00185 for (int node_idx = 0; node_idx < neuron_cnt_; node_idx++) { 00186 // read 00187 if (!neurons_[node_idx].ReadBinary(input_buff)) { 00188 return false; 00189 } 00190 } 00191 // size input stats vector to expected input size 00192 inputs_mean_.resize(in_cnt_); 00193 inputs_std_dev_.resize(in_cnt_); 00194 inputs_min_.resize(in_cnt_); 00195 inputs_max_.resize(in_cnt_); 00196 // read stats 
00197 if (input_buff->Read(&(inputs_mean_.front()), 00198 sizeof(inputs_mean_[0]) * in_cnt_) != 00199 sizeof(inputs_mean_[0]) * in_cnt_) { 00200 return false; 00201 } 00202 if (input_buff->Read(&(inputs_std_dev_.front()), 00203 sizeof(inputs_std_dev_[0]) * in_cnt_) != 00204 sizeof(inputs_std_dev_[0]) * in_cnt_) { 00205 return false; 00206 } 00207 if (input_buff->Read(&(inputs_min_.front()), 00208 sizeof(inputs_min_[0]) * in_cnt_) != 00209 sizeof(inputs_min_[0]) * in_cnt_) { 00210 return false; 00211 } 00212 if (input_buff->Read(&(inputs_max_.front()), 00213 sizeof(inputs_max_[0]) * in_cnt_) != 00214 sizeof(inputs_max_[0]) * in_cnt_) { 00215 return false; 00216 } 00217 // create a readonly version for fast feedforward 00218 if (read_only_) { 00219 return CreateFastNet(); 00220 } 00221 return true; 00222 } 00223 00224 // creates a connection between two nodes 00225 bool SetConnection(int from, int to); 00226 // Create a read only version of the net that 00227 // has faster feedforward performance 00228 bool CreateFastNet(); 00229 // internal function to allocate a new set of weights 00230 // Centralized weight allocation attempts to increase 00231 // weights locality of reference making it more cache friendly 00232 float *AllocWgt(int wgt_cnt); 00233 // different flavors read-only feedforward function 00234 template <typename Type> bool FastFeedForward(const Type *inputs, 00235 Type *outputs); 00236 // Compute the output of a specific output node. 00237 // This function is useful for application that are interested in a single 00238 // output of the net and do not want to waste time on the rest 00239 // This is the fast-read-only version of this function 00240 template <typename Type> bool FastGetNetOutput(const Type *inputs, 00241 int output_id, 00242 Type *output); 00243 }; 00244 } 00245 00246 #endif // NEURAL_NET_H__