19 #ifdef INCLUDE_TENSORFLOW 23 #include "allheaders.h" 27 using tensorflow::Status;
28 using tensorflow::Tensor;
29 using tensorflow::TensorShape;
// Parses the serialized model in proto_str into model_proto_ and, if that
// succeeds, initializes the network from it via InitFromProto().
// Returns 0 if proto_str cannot be parsed; otherwise returns InitFromProto()'s
// result.
35 int TFNetwork::InitFromProtoStr(
const string& proto_str) {
// Bail out before touching shapes or the session if the proto is invalid.
36 if (!model_proto_.ParseFromString(proto_str))
return 0;
37 return InitFromProto();
// Fragment of the serialization path (the enclosing function header and the
// declarations of proto_str / data are on lines not visible here): encode
// model_proto_ to a byte string, then copy those bytes into the `data` buffer.
// NOTE(review): the bool result of SerializeToString is ignored here.
45 model_proto_.SerializeToString(&proto_str);
// NOTE(review): assumes `data` was already sized to proto_str.size() on a line
// not shown. If the proto serializes to zero bytes, &data[0] indexes an empty
// buffer; confirm this is safe for the container type used.
48 memcpy(&data[0], proto_str.data(), proto_str.size());
// Fragment of the deserialization path (the enclosing function header, the
// filling of `data`, and the body of the failure branch are on lines not
// visible here): parse model_proto_ from the raw bytes in `data`, then
// rebuild shapes and the TensorFlow session via InitFromProto().
// NOTE(review): &data[0] on an empty buffer should be confirmed safe for the
// container type used.
58 if (!model_proto_.ParseFromArray(&data[0], data.
size())) {
61 return InitFromProto();
// Runs one forward pass through the TensorFlow session:
//  - copies `input` into a DT_FLOAT tensor fed as model_proto_.image_input(),
//  - optionally feeds the image width/height as 1-element int64 tensors when
//    the model names those inputs,
//  - runs the session to fetch model_proto_.output_layer(),
//  - copies the fetched tensor into `output`.
// debug, input_transpose and scratch are not referenced in the lines visible
// here (some original lines are missing from this view).
66 void TFNetwork::Forward(
bool debug,
const NetworkIO& input,
67 const TransposedArray* input_transpose,
68 NetworkScratch* scratch, NetworkIO* output) {
// (feed name, tensor) pairs passed to Session::Run.
69 std::vector<std::pair<string, Tensor>> tf_inputs;
70 int depth = input_shape_.depth();
73 const StrideMap& stride_map = input.stride_map();
// `shape` is built on a line not visible here, presumably from stride_map and
// depth; TODO confirm.
77 Tensor input_tensor(tensorflow::DT_FLOAT, shape);
79 auto eigen_tensor = input_tensor.flat<
float>();
// Copies input.Width() * depth floats.
// NOTE(review): assumes input_tensor holds at least that many elements;
// confirm against the shape construction not shown here.
80 memcpy(eigen_tensor.data(), input.f(0),
81 input.Width() * depth *
sizeof(input.f(0)[0]));
83 tf_inputs.emplace_back(model_proto_.image_input(), input_tensor);
// If the model declares a width input, feed the actual FD_WIDTH size from
// the stride map as a single int64 value.
90 if (!model_proto_.image_widths().empty()) {
91 TensorShape size_shape{1};
92 Tensor width_tensor(tensorflow::DT_INT64, size_shape);
93 auto eigen_wtensor = width_tensor.flat<int64>();
94 *eigen_wtensor.data() = stride_map.Size(
FD_WIDTH);
95 tf_inputs.emplace_back(model_proto_.image_widths(), width_tensor);
// Same for the height input, using the FD_HEIGHT size.
97 if (!model_proto_.image_heights().empty()) {
98 TensorShape size_shape{1};
99 Tensor height_tensor(tensorflow::DT_INT64, size_shape);
100 auto eigen_htensor = height_tensor.flat<int64>();
101 *eigen_htensor.data() = stride_map.Size(
FD_HEIGHT);
102 tf_inputs.emplace_back(model_proto_.image_heights(), height_tensor);
// Fetch only the configured output layer; no extra target nodes are run.
104 std::vector<string> target_layers = {model_proto_.output_layer()};
105 std::vector<Tensor> outputs;
106 Status s = session_->Run(tf_inputs, target_layers, {}, &outputs);
107 if (!s.ok())
tprintf(
"session->Run failed:%s\n", s.error_message().c_str());
// NOTE(review): in the visible lines a failed Run is only logged. Unless a
// guard on a hidden line checks s.ok() and outputs.size() (for example
// ASSERT_HOST), outputs[0] below would index an empty vector. Confirm.
110 const Tensor& output_tensor = outputs[0];
// Reads dims 0..2, so the output tensor is expected to have rank >= 3. The
// names suggest (batch, steps, depth); output_batch is not used in the
// visible lines.
113 int output_batch = output_tensor.shape().dim_size(0);
114 int output_steps = output_tensor.shape().dim_size(1);
115 int output_depth = output_tensor.shape().dim_size(2);
117 ASSERT_HOST(output_depth == output_shape_.depth());
// Only output_steps * output_depth floats are copied, so any batch entries
// beyond the first are dropped. Presumably batch == 1; TODO confirm.
118 output->Resize2d(
false, output_steps, output_depth);
119 auto eigen_output = output_tensor.flat<
float>();
120 memcpy(output->f(0), eigen_output.data(),
121 output_steps * output_depth *
sizeof(output->f(0)[0]));
// Configures network metadata and input/output shapes from model_proto_, then
// creates a fresh TensorFlow session and loads model_proto_.graph() into it.
// Returns model_proto_.global_step() if the graph loads. On failure it logs
// the TensorFlow error; the value returned in that case is on a line not
// visible here (TODO confirm).
124 int TFNetwork::InitFromProto() {
125 spec_ = model_proto_.spec();
// Negative x/y sizes from the proto are clamped to 0 (presumably meaning
// "variable size"; TODO confirm).
126 input_shape_.SetShape(
127 model_proto_.batch_size(), std::max(0, model_proto_.y_size()),
128 std::max(0, model_proto_.x_size()), model_proto_.depth());
// Output shape: height 1, width 0, depth = number of classes.
129 output_shape_.SetShape(model_proto_.batch_size(), 1, 0,
130 model_proto_.num_classes());
131 output_shape_.set_loss_type(model_proto_.using_ctc() ?
LT_CTC :
LT_SOFTMAX);
// ni_ is taken from the input height (not depth), and no_ from the output
// depth.
132 ni_ = input_shape_.height();
133 no_ = output_shape_.depth();
// Replaces any previously owned session with a new one for this graph.
136 tensorflow::SessionOptions options;
137 session_.reset(NewSession(options));
138 Status s = session_->Create(model_proto_.graph());
139 if (s.ok())
return model_proto_.global_step();
140 tprintf(
"Session_->Create returned '%s'\n", s.error_message().c_str());
146 #endif // ifdef INCLUDE_TENSORFLOW
// NOTE(review): the lines below are bare signatures with no bodies and no
// terminating semicolons. They look like extraction residue listing helpers
// declared elsewhere (e.g. tprintf, called above), not definitions in this
// file; confirm and remove them if so.
void resize_no_init(int size)
bool DeSerialize(bool swap, FILE *fp)
bool Serialize(FILE *fp, const char *data, size_t n)
bool Serialize(FILE *fp) const
virtual bool Serialize(TFile *fp) const
DLLSYM void tprintf(const char *format,...)
bool DeSerialize(FILE *fp, char *data, size_t n)