// NOTE(review): this view of the header appears to be a lossy extraction.
// Original line numbers (20, 21, 23, ...) are fused into the text, several
// preprocessor directives share one line, and lines are missing between the
// numbered ones. In particular, nothing visible here provides:
//   - a `public:` / `private:` specifier (without one, every member of a
//     `class` is private),
//   - a declaration of `spec_` (it is used by spec() below),
//   - a body for OutputShape(),
//   - the closing `};` of the class.
// Restore the file from version control rather than hand-editing it. The
// comments below describe only what the visible lines show.
//
// TFNetwork derives from Network. It holds a TensorFlow session (session_,
// owned via std::unique_ptr) and a TFNetworkModel (model_proto_), plus the
// input/output StaticShapes reported by the accessors below.
20 #ifndef TESSERACT_LSTM_TFNETWORK_H_ 21 #define TESSERACT_LSTM_TFNETWORK_H_ 23 #ifdef INCLUDE_TENSORFLOW 30 #include "tfnetwork.proto.h" 31 #include "third_party/tensorflow/core/framework/graph.pb.h" 32 #include "third_party/tensorflow/core/public/session.h" 36 class TFNetwork :
public Network {
// Explicit single-argument constructor taking a name. Only the declaration
// is visible here; what is done with `name` (presumably forwarded to the
// Network base) must be confirmed in the implementation file.
38 explicit TFNetwork(
const STRING& name);
// Returns input_shape_ by value.
42 virtual StaticShape InputShape()
const {
return input_shape_; }
// Takes an input shape and returns a StaticShape.
// NOTE(review): the body of this function is missing from this view. The
// next visible line already starts spec(). Confirm the real file has a body
// here (presumably one that returns output_shape_) before relying on it.
45 virtual StaticShape OutputShape(
const StaticShape& input_shape)
const {
// Returns spec_ converted to STRING via spec_.c_str().
// NOTE(review): spec_ is not declared anywhere in this view.
49 virtual STRING spec()
const {
return spec_.
c_str(); }
// Initializes *this from `proto_str` and returns an int status/result.
// Presumably `proto_str` is a serialized TFNetworkModel; the meaning of the
// return value is not visible here — confirm in the implementation.
53 int InitFromProtoStr(
const string& proto_str);
// Returns the depth of output_shape_ as the number of classes.
56 int num_classes()
const {
return output_shape_.depth(); }
// Writes *this to `fp`; returns bool (presumably false on error — confirm).
60 virtual bool Serialize(TFile* fp)
const;
// Reads *this from `fp`; returns bool (presumably false on error — confirm).
63 virtual bool DeSerialize(TFile* fp);
// Runs the network on `input`, writing results to `output`. `debug`,
// `input_transpose` and `scratch` presumably follow the base Network::Forward
// contract — confirm against network.h.
67 virtual void Forward(
bool debug,
const NetworkIO& input,
68 const TransposedArray* input_transpose,
69 NetworkScratch* scratch, NetworkIO* output);
// Shape returned by InputShape().
77 StaticShape input_shape_;
// Shape whose depth is reported by num_classes().
79 StaticShape output_shape_;
// TensorFlow session owned by this object (released with it via unique_ptr).
81 std::unique_ptr<tensorflow::Session> session_;
// Model proto, presumably declared in the included "tfnetwork.proto.h".
83 TFNetworkModel model_proto_;
88 #endif // ifdef INCLUDE_TENSORFLOW 90 #endif // TESSERACT_LSTM_TFNETWORK_H_
const char * c_str() const