tesseract  4.00.00dev
static_shape.h
Go to the documentation of this file.
1 // File: static_shape.h
3 // Description: Defines the size of the 4-d tensor input/output from a network.
4 // Author: Ray Smith
5 // Created: Fri Oct 14 09:07:31 PST 2016
6 //
7 // (C) Copyright 2016, Google Inc.
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 // http://www.apache.org/licenses/LICENSE-2.0
12 // Unless required by applicable law or agreed to in writing, software
13 // distributed under the License is distributed on an "AS IS" BASIS,
14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 // See the License for the specific language governing permissions and
16 // limitations under the License.
18 #ifndef TESSERACT_LSTM_STATIC_SHAPE_H_
19 #define TESSERACT_LSTM_STATIC_SHAPE_H_
20 
21 #include "tprintf.h"
22 
23 namespace tesseract {
24 
// Selects the loss function used while training and, equivalently, how the
// network output is decoded at runtime.
// The values below are spelled out explicitly; they are the same values the
// compiler would assign implicitly (0, 1, 2, 3).
// NOTE(review): if these values are ever persisted (e.g. in saved models),
// do not renumber or reorder them — confirm against the serialization code.
enum LossType {
  LT_NONE = 0,      // Undefined.
  LT_CTC = 1,       // Softmax with standard CTC for training/decoding.
  LT_SOFTMAX = 2,   // Outputs sum to 1 in fixed positions.
  LT_LOGISTIC = 3,  // Logistic outputs with independent values.
};
33 
34 // Simple class to hold the tensor shape that is known at network build time
35 // and the LossType of the loss function.
36 class StaticShape {
37  public:
39  : batch_(0), height_(0), width_(0), depth_(0), loss_type_(LT_NONE) {}
40  int batch() const { return batch_; }
41  void set_batch(int value) { batch_ = value; }
42  int height() const { return height_; }
43  void set_height(int value) { height_ = value; }
44  int width() const { return width_; }
45  void set_width(int value) { width_ = value; }
46  int depth() const { return depth_; }
47  void set_depth(int value) { depth_ = value; }
48  LossType loss_type() const { return loss_type_; }
49  void set_loss_type(LossType value) { loss_type_ = value; }
50  void SetShape(int batch, int height, int width, int depth) {
51  batch_ = batch;
52  height_ = height;
53  width_ = width;
54  depth_ = depth;
55  }
56 
57  void Print() const {
58  tprintf("Batch=%d, Height=%d, Width=%d, Depth=%d, loss=%d\n", batch_,
59  height_, width_, depth_, loss_type_);
60  }
61 
62  private:
63  // Size of the 4-D tensor input/output to a network. A value of zero is
64  // allowed for all except depth_ and means to be determined at runtime, and
65  // regarded as variable.
66  // Number of elements in a batch, or number of frames in a video stream.
67  int batch_;
68  // Height of the image.
69  int height_;
70  // Width of the image.
71  int width_;
72  // Depth of the image. (Number of "nodes").
73  int depth_;
74  // How to train/interpret the output.
75  LossType loss_type_;
76 };
77 
78 } // namespace tesseract
79 
80 #endif // TESSERACT_LSTM_STATIC_SHAPE_H_
LossType loss_type() const
Definition: static_shape.h:48
void set_batch(int value)
Definition: static_shape.h:41
#define tprintf(...)
Definition: tprintf.h:31
void set_loss_type(LossType value)
Definition: static_shape.h:49
void SetShape(int batch, int height, int width, int depth)
Definition: static_shape.h:50
void set_width(int value)
Definition: static_shape.h:45
void set_height(int value)
Definition: static_shape.h:43
void set_depth(int value)
Definition: static_shape.h:47