46 int append_index,
int net_flags,
47 float weight_range,
TRand* randomizer,
50 Series* bottom_series = NULL;
52 if (append_index >= 0) {
57 series->
SplitAt(append_index, &bottom_series, &top_series);
58 if (bottom_series == NULL || top_series == NULL) {
59 tprintf(
"Yikes! Splitting current network failed!!\n");
62 input_shape = bottom_series->
OutputShape(input_shape);
65 char* str_ptr = &network_spec[0];
67 if (*network == NULL)
return false;
69 (*network)->InitWeights(weight_range, randomizer);
70 (*network)->SetupNeedsBackprop(
false);
71 if (bottom_series != NULL) {
73 *network = bottom_series;
80 static void SkipWhitespace(
char** str) {
81 while (**str ==
' ' || **str ==
'\t' || **str ==
'\n') ++*str;
91 return ParseSeries(input_shape,
nullptr, str);
93 if (input_shape.
depth() == 0) {
95 return ParseInput(str);
99 return ParseParallel(input_shape, str);
101 return ParseR(input_shape, str);
103 return ParseS(input_shape, str);
105 return ParseC(input_shape, str);
107 return ParseM(input_shape, str);
109 return ParseLSTM(input_shape, str);
111 return ParseFullyConnected(input_shape, str);
113 return ParseOutput(input_shape, str);
115 tprintf(
"Invalid network spec:%s\n", *str);
123 Network* NetworkBuilder::ParseInput(
char** str) {
126 int batch, height, width, depth;
128 sscanf(*str,
"%d,%d,%d,%d%n", &batch, &height, &width, &depth, &length);
130 shape.
SetShape(batch, height, width, depth);
132 if (num_converted != 4 && num_converted != 5) {
133 tprintf(
"Must specify an input layer as the first layer, not %s!!\n", *str);
141 if (**str ==
'[')
return ParseSeries(shape, input, str);
147 Input* input_layer,
char** str) {
151 if (input_layer !=
nullptr) {
156 while (**str !=
'\0' && **str !=
']' &&
162 tprintf(
"Missing ] at end of [Series]!\n");
176 while (**str !=
'\0' && **str !=
')' &&
181 tprintf(
"Missing ) at end of (Parallel)!\n");
191 char dir = (*str)[1];
192 if (dir ==
'x' || dir ==
'y') {
197 if (network ==
nullptr)
return nullptr;
203 int replicas = strtol(*str + 1, str, 10);
205 tprintf(
"Invalid R spec!:%s\n", *str);
209 char* str_copy = *str;
210 for (
int i = 0; i < replicas; ++i) {
213 if (network == NULL) {
214 tprintf(
"Invalid replicated network!\n");
226 int y = strtol(*str + 1, str, 10);
228 int x = strtol(*str + 1, str, 10);
229 if (y <= 0 || x <= 0) {
230 tprintf(
"Invalid S spec!:%s\n", *str);
234 }
else if (**str ==
'(') {
236 tprintf(
"Generic reshape not yet implemented!!\n");
239 tprintf(
"Invalid S spec!:%s\n", *str);
269 tprintf(
"Invalid nonlinearity on C-spec!: %s\n", *str);
272 int y = 0, x = 0, d = 0;
273 if ((y = strtol(*str + 2, str, 10)) <= 0 || **str !=
',' ||
274 (x = strtol(*str + 1, str, 10)) <= 0 || **str !=
',' ||
275 (d = strtol(*str + 1, str, 10)) <= 0) {
276 tprintf(
"Invalid C spec!:%s\n", *str);
279 if (x == 1 && y == 1) {
287 series->AddToStack(convolve);
296 if ((*str)[1] !=
'p' || (y = strtol(*str + 2, str, 10)) <= 0 ||
297 **str !=
',' || (x = strtol(*str + 1, str, 10)) <= 0) {
298 tprintf(
"Invalid Mp spec!:%s\n", *str);
301 return new Maxpool(
"Maxpool", input_shape.
depth(), x, y);
308 char* spec_start = *str;
309 int chars_consumed = 1;
311 char key = (*str)[chars_consumed], dir =
'f', dim =
'x';
314 num_outputs = num_softmax_outputs_;
316 }
else if (key ==
'E') {
318 num_outputs = num_softmax_outputs_;
320 }
else if (key ==
'2' && (((*str)[2] ==
'x' && (*str)[3] ==
'y') ||
321 ((*str)[2] ==
'y' && (*str)[3] ==
'x'))) {
325 }
else if (key ==
'f' || key ==
'r' || key ==
'b') {
328 if (dim !=
'x' && dim !=
'y') {
329 tprintf(
"Invalid dimension (x|y) in L Spec!:%s\n", *str);
333 if ((*str)[chars_consumed] ==
's') {
338 tprintf(
"Invalid direction (f|r|b) in L Spec!:%s\n", *str);
341 int num_states = strtol(*str + chars_consumed, str, 10);
342 if (num_states <= 0) {
343 tprintf(
"Invalid number of states in L Spec!:%s\n", *str);
348 lstm = BuildLSTMXYQuad(input_shape.
depth(), num_states);
350 if (num_outputs == 0) num_outputs = num_states;
351 STRING name(spec_start, *str - spec_start);
352 lstm =
new LSTM(name, input_shape.
depth(), num_states, num_outputs,
false,
363 num_outputs,
false, type));
377 Network* NetworkBuilder::BuildLSTMXYQuad(
int num_inputs,
int num_states) {
379 parallel->
AddToStack(
new LSTM(
"L2DLTRDown", num_inputs, num_states,
382 rev->
SetNetwork(
new LSTM(
"L2DRTLDown", num_inputs, num_states, num_states,
387 new LSTM(
"L2DRTLUp", num_inputs, num_states, num_states,
true,
NT_LSTM));
392 rev->
SetNetwork(
new LSTM(
"L2DLTRDown", num_inputs, num_states, num_states,
402 if (input_shape.
height() == 0 || input_shape.
width() == 0) {
403 tprintf(
"Fully connected requires positive height and width, had %d,%d\n",
407 int input_size = input_shape.
height() * input_shape.
width();
408 int input_depth = input_size * input_shape.
depth();
410 if (input_size > 1) {
423 char* spec_start = *str;
426 tprintf(
"Invalid nonlinearity on F-spec!: %s\n", *str);
429 int depth = strtol(*str + 1, str, 10);
431 tprintf(
"Invalid F spec!:%s\n", *str);
434 STRING name(spec_start, *str - spec_start);
435 return BuildFullyConnected(input_shape, type, name, depth);
441 char dims_ch = (*str)[1];
442 if (dims_ch !=
'0' && dims_ch !=
'1' && dims_ch !=
'2') {
443 tprintf(
"Invalid dims (2|1|0) in output spec!:%s\n", *str);
446 char type_ch = (*str)[2];
447 if (type_ch !=
'l' && type_ch !=
's' && type_ch !=
'c') {
448 tprintf(
"Invalid output type (l|s|c) in output spec!:%s\n", *str);
451 int depth = strtol(*str + 3, str, 10);
452 if (depth != num_softmax_outputs_) {
453 tprintf(
"Warning: given outputs %d not equal to unicharset of %d.\n", depth,
454 num_softmax_outputs_);
455 depth = num_softmax_outputs_;
460 else if (type_ch ==
's')
462 if (dims_ch ==
'0') {
464 return BuildFullyConnected(input_shape, type,
"Output", depth);
465 }
else if (dims_ch ==
'2') {
470 if (input_shape.
height() == 0) {
471 tprintf(
"Fully connected requires fixed height!\n");
474 int input_size = input_shape.
height();
475 int input_depth = input_size * input_shape.
depth();
477 if (input_size > 1) {
virtual void CacheXScaleFactor(int factor)
virtual StaticShape OutputShape(const StaticShape &input_shape) const
void SplitAt(int last_start, Series **start, Series **end)
virtual void SetNetworkFlags(uinT32 flags)
void SetNetwork(Network *network)
void AppendSeries(Network *src)
virtual StaticShape OutputShape(const StaticShape &input_shape) const
static bool InitNetwork(int num_outputs, STRING network_spec, int append_index, int net_flags, float weight_range, TRand *randomizer, Network **network)
virtual StaticShape OutputShape(const StaticShape &input_shape) const
void SetShape(int batch, int height, int width, int depth)
Network * BuildFromString(const StaticShape &input_shape, char **str)
virtual void AddToStack(Network *network)