46 int append_index,
int net_flags,
47 float weight_range,
TRand* randomizer,
50 Series* bottom_series =
nullptr;
52 if (append_index >= 0) {
55 auto* series = static_cast<Series*>(*network);
56 Series* top_series =
nullptr;
57 series->
SplitAt(append_index, &bottom_series, &top_series);
58 if (bottom_series ==
nullptr || top_series ==
nullptr) {
59 tprintf(
"Yikes! Splitting current network failed!!\n");
62 input_shape = bottom_series->
OutputShape(input_shape);
65 char* str_ptr = &network_spec[0];
67 if (*network ==
nullptr)
return false;
69 (*network)->InitWeights(weight_range, randomizer);
70 (*network)->SetupNeedsBackprop(
false);
71 if (bottom_series !=
nullptr) {
73 *network = bottom_series;
80 static void SkipWhitespace(
char** str) {
81 while (**str ==
' ' || **str ==
'\t' || **str ==
'\n') ++*str;
91 return ParseSeries(input_shape,
nullptr, str);
93 if (input_shape.
depth() == 0) {
95 return ParseInput(str);
99 return ParseParallel(input_shape, str);
101 return ParseR(input_shape, str);
103 return ParseS(input_shape, str);
105 return ParseC(input_shape, str);
107 return ParseM(input_shape, str);
109 return ParseLSTM(input_shape, str);
111 return ParseFullyConnected(input_shape, str);
113 return ParseOutput(input_shape, str);
115 tprintf(
"Invalid network spec:%s\n", *str);
123 Network* NetworkBuilder::ParseInput(
char** str) {
126 int batch, height, width, depth;
128 sscanf(*str,
"%d,%d,%d,%d%n", &batch, &height, &width, &depth, &length);
130 shape.
SetShape(batch, height, width, depth);
132 if (num_converted != 4 && num_converted != 5) {
133 tprintf(
"Must specify an input layer as the first layer, not %s!!\n", *str);
137 Input* input =
new Input(
"Input", shape);
141 if (**str ==
'[')
return ParseSeries(shape, input, str);
146 Network* NetworkBuilder::ParseSeries(
const StaticShape& input_shape,
147 Input* input_layer,
char** str) {
148 StaticShape shape = input_shape;
149 Series* series =
new Series(
"Series");
151 if (input_layer !=
nullptr) {
152 series->AddToStack(input_layer);
153 shape = input_layer->OutputShape(shape);
155 Network* network =
nullptr;
156 while (**str !=
'\0' && **str !=
']' &&
158 shape = network->OutputShape(shape);
159 series->AddToStack(network);
162 tprintf(
"Missing ] at end of [Series]!\n");
171 Network* NetworkBuilder::ParseParallel(
const StaticShape& input_shape,
173 Parallel* parallel =
new Parallel(
"Parallel",
NT_PARALLEL);
175 Network* network =
nullptr;
176 while (**str !=
'\0' && **str !=
')' &&
178 parallel->AddToStack(network);
181 tprintf(
"Missing ) at end of (Parallel)!\n");
190 Network* NetworkBuilder::ParseR(
const StaticShape& input_shape,
char** str) {
191 char dir = (*str)[1];
192 if (dir ==
'x' || dir ==
'y') {
197 if (network ==
nullptr)
return nullptr;
200 rev->SetNetwork(network);
203 int replicas = strtol(*str + 1, str, 10);
205 tprintf(
"Invalid R spec!:%s\n", *str);
208 Parallel* parallel =
new Parallel(
"Replicated",
NT_REPLICATED);
209 char* str_copy = *str;
210 for (
int i = 0; i < replicas; ++i) {
213 if (network ==
nullptr) {
214 tprintf(
"Invalid replicated network!\n");
218 parallel->AddToStack(network);
225 Network* NetworkBuilder::ParseS(
const StaticShape& input_shape,
char** str) {
226 int y = strtol(*str + 1, str, 10);
228 int x = strtol(*str + 1, str, 10);
229 if (y <= 0 || x <= 0) {
230 tprintf(
"Invalid S spec!:%s\n", *str);
233 return new Reconfig(
"Reconfig", input_shape.depth(), x, y);
234 }
else if (**str ==
'(') {
236 tprintf(
"Generic reshape not yet implemented!!\n");
239 tprintf(
"Invalid S spec!:%s\n", *str);
266 Network* NetworkBuilder::ParseC(
const StaticShape& input_shape,
char** str) {
269 tprintf(
"Invalid nonlinearity on C-spec!: %s\n", *str);
272 int y = 0, x = 0, d = 0;
273 if ((y = strtol(*str + 2, str, 10)) <= 0 || **str !=
',' ||
274 (x = strtol(*str + 1, str, 10)) <= 0 || **str !=
',' ||
275 (d = strtol(*str + 1, str, 10)) <= 0) {
276 tprintf(
"Invalid C spec!:%s\n", *str);
279 if (x == 1 && y == 1) {
282 return new FullyConnected(
"Conv1x1", input_shape.depth(), d,
type);
284 Series* series =
new Series(
"ConvSeries");
286 new Convolve(
"Convolve", input_shape.depth(), x / 2, y / 2);
287 series->AddToStack(convolve);
288 StaticShape fc_input = convolve->OutputShape(input_shape);
289 series->AddToStack(
new FullyConnected(
"ConvNL", fc_input.depth(), d,
type));
294 Network* NetworkBuilder::ParseM(
const StaticShape& input_shape,
char** str) {
296 if ((*str)[1] !=
'p' || (y = strtol(*str + 2, str, 10)) <= 0 ||
297 **str !=
',' || (x = strtol(*str + 1, str, 10)) <= 0) {
298 tprintf(
"Invalid Mp spec!:%s\n", *str);
301 return new Maxpool(
"Maxpool", input_shape.depth(), x, y);
305 Network* NetworkBuilder::ParseLSTM(
const StaticShape& input_shape,
char** str) {
308 char* spec_start = *str;
309 int chars_consumed = 1;
311 char key = (*str)[chars_consumed], dir =
'f', dim =
'x';
314 num_outputs = num_softmax_outputs_;
316 }
else if (key ==
'E') {
318 num_outputs = num_softmax_outputs_;
320 }
else if (key ==
'2' && (((*str)[2] ==
'x' && (*str)[3] ==
'y') ||
321 ((*str)[2] ==
'y' && (*str)[3] ==
'x'))) {
325 }
else if (key ==
'f' || key ==
'r' || key ==
'b') {
328 if (dim !=
'x' && dim !=
'y') {
329 tprintf(
"Invalid dimension (x|y) in L Spec!:%s\n", *str);
333 if ((*str)[chars_consumed] ==
's') {
338 tprintf(
"Invalid direction (f|r|b) in L Spec!:%s\n", *str);
341 int num_states = strtol(*str + chars_consumed, str, 10);
342 if (num_states <= 0) {
343 tprintf(
"Invalid number of states in L Spec!:%s\n", *str);
346 Network* lstm =
nullptr;
348 lstm = BuildLSTMXYQuad(input_shape.depth(), num_states);
350 if (num_outputs == 0) num_outputs = num_states;
351 STRING name(spec_start, *str - spec_start);
352 lstm =
new LSTM(name, input_shape.depth(), num_states, num_outputs,
false,
356 rev->SetNetwork(lstm);
362 parallel->AddToStack(
new LSTM(name, input_shape.depth(), num_states,
363 num_outputs,
false,
type));
364 parallel->AddToStack(lstm);
370 rev->SetNetwork(lstm);
377 Network* NetworkBuilder::BuildLSTMXYQuad(
int num_inputs,
int num_states) {
379 parallel->AddToStack(
new LSTM(
"L2DLTRDown", num_inputs, num_states,
381 Reversed* rev =
new Reversed(
"L2DLTRXRev",
NT_XREVERSED);
382 rev->SetNetwork(
new LSTM(
"L2DRTLDown", num_inputs, num_states, num_states,
384 parallel->AddToStack(rev);
387 new LSTM(
"L2DRTLUp", num_inputs, num_states, num_states,
true,
NT_LSTM));
388 Reversed* rev2 =
new Reversed(
"L2DXRevU",
NT_XREVERSED);
389 rev2->SetNetwork(rev);
390 parallel->AddToStack(rev2);
392 rev->SetNetwork(
new LSTM(
"L2DLTRDown", num_inputs, num_states, num_states,
394 parallel->AddToStack(rev);
399 static Network* BuildFullyConnected(
const StaticShape& input_shape,
402 if (input_shape.height() == 0 || input_shape.width() == 0) {
403 tprintf(
"Fully connected requires positive height and width, had %d,%d\n",
404 input_shape.height(), input_shape.width());
407 int input_size = input_shape.height() * input_shape.width();
408 int input_depth = input_size * input_shape.depth();
409 Network* fc =
new FullyConnected(name, input_depth, depth,
type);
410 if (input_size > 1) {
411 Series* series =
new Series(
"FCSeries");
412 series->AddToStack(
new Reconfig(
"FCReconfig", input_shape.depth(),
413 input_shape.width(), input_shape.height()));
414 series->AddToStack(fc);
421 Network* NetworkBuilder::ParseFullyConnected(
const StaticShape& input_shape,
423 char* spec_start = *str;
426 tprintf(
"Invalid nonlinearity on F-spec!: %s\n", *str);
429 int depth = strtol(*str + 2, str, 10);
431 tprintf(
"Invalid F spec!:%s\n", *str);
434 STRING name(spec_start, *str - spec_start);
435 return BuildFullyConnected(input_shape,
type, name, depth);
439 Network* NetworkBuilder::ParseOutput(
const StaticShape& input_shape,
441 char dims_ch = (*str)[1];
442 if (dims_ch !=
'0' && dims_ch !=
'1' && dims_ch !=
'2') {
443 tprintf(
"Invalid dims (2|1|0) in output spec!:%s\n", *str);
446 char type_ch = (*str)[2];
447 if (type_ch !=
'l' && type_ch !=
's' && type_ch !=
'c') {
448 tprintf(
"Invalid output type (l|s|c) in output spec!:%s\n", *str);
451 int depth = strtol(*str + 3, str, 10);
452 if (depth != num_softmax_outputs_) {
453 tprintf(
"Warning: given outputs %d not equal to unicharset of %d.\n", depth,
454 num_softmax_outputs_);
455 depth = num_softmax_outputs_;
460 else if (type_ch ==
's')
462 if (dims_ch ==
'0') {
464 return BuildFullyConnected(input_shape,
type,
"Output", depth);
465 }
else if (dims_ch ==
'2') {
467 return new FullyConnected(
"Output2d", input_shape.depth(), depth,
type);
470 if (input_shape.height() == 0) {
471 tprintf(
"Fully connected requires fixed height!\n");
474 int input_size = input_shape.height();
475 int input_depth = input_size * input_shape.depth();
476 Network* fc =
new FullyConnected(
"Output", input_depth, depth,
type);
477 if (input_size > 1) {
478 Series* series =
new Series(
"FCSeries");
479 series->AddToStack(
new Reconfig(
"FCReconfig", input_shape.depth(), 1,
480 input_shape.height()));
481 series->AddToStack(fc);