tesseract  4.00.00dev
maxpool.cpp
Go to the documentation of this file.
1 // File: maxpool.cpp
3 // Description: Standard Max-Pooling layer.
4 // Author: Ray Smith
5 // Created: Tue Mar 18 16:28:18 PST 2014
6 //
7 // (C) Copyright 2014, Google Inc.
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 // http://www.apache.org/licenses/LICENSE-2.0
12 // Unless required by applicable law or agreed to in writing, software
13 // distributed under the License is distributed on an "AS IS" BASIS,
14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 // See the License for the specific language governing permissions and
16 // limitations under the License.
18 
19 #include "maxpool.h"
20 #include "tprintf.h"
21 
22 namespace tesseract {
23 
24 Maxpool::Maxpool(const STRING& name, int ni, int x_scale, int y_scale)
25  : Reconfig(name, ni, x_scale, y_scale) {
26  type_ = NT_MAXPOOL;
27  no_ = ni;
28 }
29 
31 }
32 
33 // Reads from the given file. Returns false in case of error.
35  bool result = Reconfig::DeSerialize(fp);
36  no_ = ni_;
37  return result;
38 }
39 
40 // Runs forward propagation of activations on the input line.
41 // See NetworkCpp for a detailed discussion of the arguments.
42 void Maxpool::Forward(bool debug, const NetworkIO& input,
43  const TransposedArray* input_transpose,
44  NetworkScratch* scratch, NetworkIO* output) {
45  output->ResizeScaled(input, x_scale_, y_scale_, no_);
46  maxes_.ResizeNoInit(output->Width(), ni_);
47  back_map_ = input.stride_map();
48 
49  StrideMap::Index dest_index(output->stride_map());
50  do {
51  int out_t = dest_index.t();
52  StrideMap::Index src_index(input.stride_map(), dest_index.index(FD_BATCH),
53  dest_index.index(FD_HEIGHT) * y_scale_,
54  dest_index.index(FD_WIDTH) * x_scale_);
55  // Find the max input out of x_scale_ groups of y_scale_ inputs.
56  // Do it independently for each input dimension.
57  int* max_line = maxes_[out_t];
58  int in_t = src_index.t();
59  output->CopyTimeStepFrom(out_t, input, in_t);
60  for (int i = 0; i < ni_; ++i) {
61  max_line[i] = in_t;
62  }
63  for (int x = 0; x < x_scale_; ++x) {
64  for (int y = 0; y < y_scale_; ++y) {
65  StrideMap::Index src_xy(src_index);
66  if (src_xy.AddOffset(x, FD_WIDTH) && src_xy.AddOffset(y, FD_HEIGHT)) {
67  output->MaxpoolTimeStep(out_t, input, src_xy.t(), max_line);
68  }
69  }
70  }
71  } while (dest_index.Increment());
72 }
73 
74 // Runs backward propagation of errors on the deltas line.
75 // See NetworkCpp for a detailed discussion of the arguments.
76 bool Maxpool::Backward(bool debug, const NetworkIO& fwd_deltas,
77  NetworkScratch* scratch,
78  NetworkIO* back_deltas) {
79  back_deltas->ResizeToMap(fwd_deltas.int_mode(), back_map_, ni_);
80  back_deltas->MaxpoolBackward(fwd_deltas, maxes_);
81  return true;
82 }
83 
84 
85 } // namespace tesseract.
86 
Maxpool(const STRING &name, int ni, int x_scale, int y_scale)
Definition: maxpool.cpp:24
void CopyTimeStepFrom(int dest_t, const NetworkIO &src, int src_t)
Definition: networkio.cpp:383
bool AddOffset(int offset, FlexDimensions dimension)
Definition: stridemap.cpp:62
virtual ~Maxpool()
Definition: maxpool.cpp:30
int Width() const
Definition: networkio.h:107
void MaxpoolTimeStep(int dest_t, const NetworkIO &src, int src_t, int *max_line)
Definition: networkio.cpp:668
virtual bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas)
Definition: maxpool.cpp:76
StrideMap back_map_
Definition: reconfig.h:76
virtual bool DeSerialize(TFile *fp)
Definition: reconfig.cpp:62
void ResizeScaled(const NetworkIO &src, int x_scale, int y_scale, int num_features)
Definition: networkio.cpp:62
bool int_mode() const
Definition: networkio.h:127
void MaxpoolBackward(const NetworkIO &fwd, const GENERIC_2D_ARRAY< int > &maxes)
Definition: networkio.cpp:695
STRING
Definition: strngs.h:45
void ResizeNoInit(int size1, int size2)
Definition: matrix.h:86
NetworkType type_
Definition: network.h:285
const StrideMap & stride_map() const
Definition: networkio.h:133
virtual bool DeSerialize(TFile *fp)
Definition: maxpool.cpp:34
virtual void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output)
Definition: maxpool.cpp:42
void ResizeToMap(bool int_mode, const StrideMap &stride_map, int num_features)
Definition: networkio.cpp:45