20 #ifndef TESSERACT_LSTM_TFNETWORK_H_ 21 #define TESSERACT_LSTM_TFNETWORK_H_ 23 #ifdef INCLUDE_TENSORFLOW 30 #include "tfnetwork.proto.h" 31 #include "third_party/tensorflow/core/framework/graph.pb.h" 32 #include "third_party/tensorflow/core/public/session.h" 36 class TFNetwork :
public Network {
38 explicit TFNetwork(
const STRING& name);
39 virtual ~TFNetwork() =
default;
42 StaticShape InputShape()
const override {
return input_shape_; }
45 StaticShape OutputShape(
const StaticShape& input_shape)
const override {
49 STRING spec()
const override {
return spec_.
c_str(); }
53 int InitFromProtoStr(
const string& proto_str);
56 int num_classes()
const {
return output_shape_.depth(); }
67 void Forward(
bool debug,
const NetworkIO& input,
68 const TransposedArray* input_transpose,
69 NetworkScratch* scratch, NetworkIO* output)
override;
77 StaticShape input_shape_;
79 StaticShape output_shape_;
81 std::unique_ptr<tensorflow::Session> session_;
83 TFNetworkModel model_proto_;
88 #endif // ifdef INCLUDE_TENSORFLOW 90 #endif // TESSERACT_TENSORFLOW_TFNETWORK_H_ bool Serialize(FILE *fp, const char *data, size_t n)
const char * c_str() const
bool DeSerialize(FILE *fp, char *data, size_t n)