tesseract 4.00.00dev
tfnetwork.cpp
// File:        tfnetwork.cpp
// Description: Encapsulation of an entire tensorflow graph as a
//              Tesseract Network.
// Author:      Ray Smith
// Created:     Fri Feb 26 09:35:29 PST 2016
//
// (C) Copyright 2016, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef INCLUDE_TENSORFLOW

#include "tfnetwork.h"

#include "allheaders.h"
#include "input.h"
#include "networkscratch.h"

using tensorflow::Status;
using tensorflow::Tensor;
using tensorflow::TensorShape;

namespace tesseract {

TFNetwork::TFNetwork(const STRING& name) : Network(NT_TENSORFLOW, name, 0, 0) {}

TFNetwork::~TFNetwork() {}

int TFNetwork::InitFromProtoStr(const string& proto_str) {
  if (!model_proto_.ParseFromString(proto_str)) return 0;
  return InitFromProto();
}

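// Illustrative only (not part of the original file): a minimal sketch of how
// application code might construct a TFNetwork and initialize it from a
// TFNetworkModel proto stored on disk. The file name, network name and helper
// function are hypothetical; only InitFromProtoStr above is real.
#if 0
#include <fstream>
#include <iterator>
#include <string>

static int LoadExampleTFNetwork() {
  std::ifstream in("model.pb", std::ios::binary);  // hypothetical path
  std::string blob((std::istreambuf_iterator<char>(in)),
                   std::istreambuf_iterator<char>());
  TFNetwork net("tf-line-recognizer");  // hypothetical network name
  // Returns the proto's global_step on success, 0 on failure.
  return net.InitFromProtoStr(blob);
}
#endif
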
// Writes to the given file. Returns false in case of error.
// Should be overridden by subclasses, but called by their Serialize.
bool TFNetwork::Serialize(TFile* fp) const {
  if (!Network::Serialize(fp)) return false;
  string proto_str;
  model_proto_.SerializeToString(&proto_str);
  GenericVector<char> data;
  data.resize_no_init(proto_str.size());
  memcpy(&data[0], proto_str.data(), proto_str.size());
  if (!data.Serialize(fp)) return false;
  return true;
}

// Reads from the given file. Returns false in case of error.
// Should be overridden by subclasses, but NOT called by their DeSerialize.
bool TFNetwork::DeSerialize(TFile* fp) {
  GenericVector<char> data;
  if (!data.DeSerialize(fp)) return false;
  if (!model_proto_.ParseFromArray(&data[0], data.size())) {
    return false;
  }
  return InitFromProto();
}

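// Sketch of the stream layout the two functions above agree on, assuming
// GenericVector<char>::Serialize writes its element count followed by the
// raw elements (see genericvector.h):
//   <base Network header, written by Network::Serialize>
//   <int: byte length of the serialized TFNetworkModel proto>
//   <that many bytes of proto, parsed back via ParseFromArray>
// Per the comments above, the base header is written here by Serialize but is
// consumed elsewhere before DeSerialize is called, so DeSerialize starts
// directly at the length-prefixed proto bytes.
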
// Runs forward propagation of activations on the input line.
// See Network for a detailed discussion of the arguments.
void TFNetwork::Forward(bool debug, const NetworkIO& input,
                        const TransposedArray* input_transpose,
                        NetworkScratch* scratch, NetworkIO* output) {
  std::vector<std::pair<string, Tensor>> tf_inputs;
  int depth = input_shape_.depth();
  ASSERT_HOST(depth == input.NumFeatures());
  // TODO(rays) Allow batching. For now batch_size = 1.
  const StrideMap& stride_map = input.stride_map();
  // TF requires a tensor of shape float[batch, height, width, depth].
  TensorShape shape{1, stride_map.Size(FD_HEIGHT), stride_map.Size(FD_WIDTH),
                    depth};
  Tensor input_tensor(tensorflow::DT_FLOAT, shape);
  // The flat() member gives a 1d array, with a data() member to get the data.
  auto eigen_tensor = input_tensor.flat<float>();
  memcpy(eigen_tensor.data(), input.f(0),
         input.Width() * depth * sizeof(input.f(0)[0]));
  // Add the tensor to the vector of inputs.
  tf_inputs.emplace_back(model_proto_.image_input(), input_tensor);

  // Provide tensors giving the width and/or height of the image if they are
  // required. Some tf ops require a separate tensor with knowledge of the
  // size of the input as they cannot obtain it from the input tensor. This is
  // usually true in the case of ops that process a batch of variable-sized
  // objects.
  if (!model_proto_.image_widths().empty()) {
    TensorShape size_shape{1};
    Tensor width_tensor(tensorflow::DT_INT64, size_shape);
    auto eigen_wtensor = width_tensor.flat<int64>();
    *eigen_wtensor.data() = stride_map.Size(FD_WIDTH);
    tf_inputs.emplace_back(model_proto_.image_widths(), width_tensor);
  }
  if (!model_proto_.image_heights().empty()) {
    TensorShape size_shape{1};
    Tensor height_tensor(tensorflow::DT_INT64, size_shape);
    auto eigen_htensor = height_tensor.flat<int64>();
    *eigen_htensor.data() = stride_map.Size(FD_HEIGHT);
    tf_inputs.emplace_back(model_proto_.image_heights(), height_tensor);
  }
  std::vector<string> target_layers = {model_proto_.output_layer()};
  std::vector<Tensor> outputs;
  Status s = session_->Run(tf_inputs, target_layers, {}, &outputs);
  if (!s.ok()) tprintf("session->Run failed:%s\n", s.error_message().c_str());
  ASSERT_HOST(s.ok());
  ASSERT_HOST(outputs.size() == 1);
  const Tensor& output_tensor = outputs[0];
  // Check the dimensions of the output.
  ASSERT_HOST(output_tensor.shape().dims() == 3);
  int output_batch = output_tensor.shape().dim_size(0);
  int output_steps = output_tensor.shape().dim_size(1);
  int output_depth = output_tensor.shape().dim_size(2);
  ASSERT_HOST(output_batch == 1);
  ASSERT_HOST(output_depth == output_shape_.depth());
  output->Resize2d(false, output_steps, output_depth);
  auto eigen_output = output_tensor.flat<float>();
  memcpy(output->f(0), eigen_output.data(),
         output_steps * output_depth * sizeof(output->f(0)[0]));
}
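
// Worked shape example (illustrative numbers only): for a 48x600 single-
// channel line image, the input tensor built above is float[1, 48, 600, 1].
// The node named by output_layer() is expected to yield a rank-3 tensor
// [batch=1, steps, num_classes]; e.g. a graph that reduces the width by a
// factor of 4 with 111 classes would return [1, 150, 111], and the final
// memcpy copies 150 * 111 floats into the NetworkIO resized to 150x111.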

int TFNetwork::InitFromProto() {
  spec_ = model_proto_.spec();
  input_shape_.SetShape(
      model_proto_.batch_size(), std::max(0, model_proto_.y_size()),
      std::max(0, model_proto_.x_size()), model_proto_.depth());
  output_shape_.SetShape(model_proto_.batch_size(), 1, 0,
                         model_proto_.num_classes());
  output_shape_.set_loss_type(model_proto_.using_ctc() ? LT_CTC : LT_SOFTMAX);
  ni_ = input_shape_.height();
  no_ = output_shape_.depth();
  // Initialize the session_ with the graph. Since we can't get the graph
  // back from the session_, we have to keep the proto as well
  tensorflow::SessionOptions options;
  session_.reset(NewSession(options));
  Status s = session_->Create(model_proto_.graph());
  if (s.ok()) return model_proto_.global_step();
  tprintf("Session_->Create returned '%s'\n", s.error_message().c_str());
  return 0;
}
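
// For reference, the TFNetworkModel fields this class consumes (all of them
// appear above or in Forward): graph, global_step, spec, batch_size, y_size,
// x_size, depth, num_classes, using_ctc, image_input, image_widths,
// image_heights, output_layer. See tfnetwork.proto for the message itself.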

}  // namespace tesseract

#endif  // ifdef INCLUDE_TENSORFLOW