19 #ifndef TESSERACT_LSTM_TFNETWORK_H_
20 #define TESSERACT_LSTM_TFNETWORK_H_
22 #ifdef INCLUDE_TENSORFLOW
30 #include "tensorflow/core/framework/graph.pb.h"
31 #include "tensorflow/core/public/session.h"
35 class TFNetwork :
public Network {
37 explicit TFNetwork(
const STRING& name);
38 virtual ~TFNetwork() =
default;
41 StaticShape InputShape()
const override {
return input_shape_; }
44 StaticShape OutputShape(
const StaticShape& input_shape)
const override {
48 STRING spec()
const override {
return spec_.
c_str(); }
55 int num_classes()
const {
return output_shape_.depth(); }
66 void Forward(
bool debug,
const NetworkIO& input,
67 const TransposedArray* input_transpose,
68 NetworkScratch* scratch, NetworkIO* output)
override;
73 bool Backward(
bool debug,
const NetworkIO& fwd_deltas,
74 NetworkScratch* scratch,
75 NetworkIO* back_deltas)
override {
76 tprintf(
"Must override Network::Backward for type %d\n", type_);
80 void DebugWeights()
override {
81 tprintf(
"Must override Network::DebugWeights for type %d\n", type_);
89 StaticShape input_shape_;
91 StaticShape output_shape_;
93 std::unique_ptr<tensorflow::Session> session_;
95 TFNetworkModel model_proto_;
100 #endif // ifdef INCLUDE_TENSORFLOW
#endif  // TESSERACT_LSTM_TFNETWORK_H_