36 5, 10, 16, 16, 16, 16, 16, 16, 16, 16,
44 if (
code == null_char) {
53 if (depth > 0 &&
prev !=
nullptr) {
55 prev->
Print(null_char, unicharset, depth - 1);
63 int null_char,
bool simple_text,
Dict* dict)
69 space_delimited_(true),
70 is_simple_text_(simple_text),
71 null_char_(null_char) {
77 double cert_offset,
double worst_dict_cert,
80 int width = output.
Width();
81 for (
int t = 0; t < width; ++t) {
82 ComputeTopN(output.
f(t), output.
NumFeatures(), kBeamWidths[0]);
83 DecodeStep(output.
f(t), t, dict_ratio, cert_offset, worst_dict_cert,
88 double dict_ratio,
double cert_offset,
89 double worst_dict_cert,
92 int width = output.
dim1();
93 for (
int t = 0; t < width; ++t) {
94 ComputeTopN(output[t], output.
dim2(), kBeamWidths[0]);
95 DecodeStep(output[t], t, dict_ratio, cert_offset, worst_dict_cert, charset);
105 ExtractBestPaths(&best_nodes, NULL);
108 int width = best_nodes.
size();
110 int label = best_nodes[t]->code;
111 if (label != null_char_) {
115 while (++t < width && !is_simple_text_ && best_nodes[t]->code == label) {
128 ExtractBestPaths(&best_nodes, NULL);
129 ExtractPathAsUnicharIds(best_nodes, unichar_ids, certs, ratings, xcoords);
131 DebugPath(unicharset, best_nodes);
132 DebugUnicharPath(unicharset, best_nodes, *unichar_ids, *certs, *ratings,
139 float scale_factor,
bool debug,
149 ExtractBestPaths(&best_nodes, &second_nodes);
151 DebugPath(unicharset, best_nodes);
152 ExtractPathAsUnicharIds(second_nodes, &unichar_ids, &certs, &ratings,
154 tprintf(
"\nSecond choice path:\n");
155 DebugUnicharPath(unicharset, second_nodes, unichar_ids, certs, ratings,
158 ExtractPathAsUnicharIds(best_nodes, &unichar_ids, &certs, &ratings, &xcoords);
159 int num_ids = unichar_ids.
size();
161 DebugUnicharPath(unicharset, best_nodes, unichar_ids, certs, ratings,
166 float prev_space_cert = 0.0f;
167 for (
int word_start = 0; word_start < num_ids; word_start = word_end) {
168 for (word_end = word_start + 1; word_end < num_ids; ++word_end) {
173 int index = xcoords[word_end];
174 if (best_nodes[index]->start_of_word)
break;
180 float space_cert = 0.0f;
181 if (word_end < num_ids && unichar_ids[word_end] ==
UNICHAR_SPACE)
182 space_cert = certs[word_end];
184 word_start > 0 && unichar_ids[word_start - 1] ==
UNICHAR_SPACE;
186 WERD_RES* word_res = InitializeWord(
187 leading_space, line_box, word_start, word_end,
188 MIN(space_cert, prev_space_cert), unicharset, xcoords, scale_factor);
189 for (
int i = word_start; i < word_end; ++i) {
190 BLOB_CHOICE_LIST* choices =
new BLOB_CHOICE_LIST;
191 BLOB_CHOICE_IT bc_it(choices);
193 unichar_ids[i], ratings[i], certs[i], -1, 1.0f,
195 int col = i - word_start;
197 bc_it.add_after_then_move(choice);
200 int index = xcoords[word_end - 1];
203 prev_space_cert = space_cert;
204 if (word_end < num_ids && unichar_ids[word_end] ==
UNICHAR_SPACE)
211 for (
int p = 0; p < beam_size_; ++p) {
212 for (
int d = 0; d < 2; ++d) {
213 for (
int c = 0; c <
NC_COUNT; ++c) {
216 if (beam_[p]->beams_[index].empty())
continue;
218 tprintf(
"Position %d: %s+%s beam\n", p, d ?
"Dict" :
"Non-Dict",
220 DebugBeamPos(unicharset, beam_[p]->beams_[index]);
227 void RecodeBeamSearch::DebugBeamPos(
const UNICHARSET& unicharset,
232 int heap_size = heap.
size();
233 for (
int i = 0; i < heap_size; ++i) {
236 if (null_best == NULL || null_best->
score < node->
score) null_best = node;
238 if (unichar_bests[node->
unichar_id] == NULL ||
244 for (
int u = 0;
u < unichar_bests.
size(); ++
u) {
245 if (unichar_bests[
u] != NULL) {
247 node.
Print(null_char_, unicharset, 1);
250 if (null_best != NULL) {
251 null_best->
Print(null_char_, unicharset, 1);
258 void RecodeBeamSearch::ExtractPathAsUnicharIds(
268 int width = best_nodes.
size();
270 double certainty = 0.0;
272 while (t < width && best_nodes[t]->unichar_id == INVALID_UNICHAR_ID) {
273 double cert = best_nodes[t++]->certainty;
274 if (cert < certainty) certainty = cert;
278 int unichar_id = best_nodes[t]->unichar_id;
280 best_nodes[t]->permuter !=
NO_PERM) {
283 if (certainty < certs->back()) certs->
back() = certainty;
284 ratings->
back() += rating;
291 double cert = best_nodes[t++]->certainty;
295 best_nodes[t - 1]->permuter ==
NO_PERM)) {
299 }
while (t < width && best_nodes[t]->duplicate);
302 }
else if (!certs->
empty()) {
303 if (certainty < certs->back()) certs->
back() = certainty;
304 ratings->
back() += rating;
312 WERD_RES* RecodeBeamSearch::InitializeWord(
bool leading_space,
313 const TBOX& line_box,
int word_start,
314 int word_end,
float space_certainty,
317 float scale_factor) {
320 C_BLOB_IT b_it(&blobs);
321 for (
int i = word_start; i < word_end; ++i) {
322 int min_half_width = xcoords[i + 1] - xcoords[i];
323 if (i > 0 && xcoords[i] - xcoords[i - 1] < min_half_width)
324 min_half_width = xcoords[i] - xcoords[i - 1];
325 if (min_half_width < 1) min_half_width = 1;
327 TBOX box(xcoords[i] - min_half_width, 0, xcoords[i] + min_half_width,
329 box.
scale(scale_factor);
331 box.set_top(line_box.
top());
335 WERD* word =
new WERD(&blobs, leading_space, NULL);
338 word_res->
uch_set = unicharset;
347 void RecodeBeamSearch::ComputeTopN(
const float* outputs,
int num_outputs,
353 for (
int i = 0; i < num_outputs; ++i) {
354 if (top_heap_.size() < top_n || outputs[i] > top_heap_.PeekTop().key) {
356 top_heap_.Push(&entry);
357 if (top_heap_.size() > top_n) top_heap_.Pop(&entry);
360 while (!top_heap_.empty()) {
362 top_heap_.Pop(&entry);
363 if (top_heap_.size() > 1) {
367 if (top_heap_.empty())
368 top_code_ = entry.
data;
370 second_code_ = entry.
data;
373 top_n_flags_[null_char_] =
TN_TOP2;
379 void RecodeBeamSearch::DecodeStep(
const float* outputs,
int t,
380 double dict_ratio,
double cert_offset,
381 double worst_dict_cert,
384 RecodeBeam* step = beam_[t];
390 dict_ratio, cert_offset, worst_dict_cert, step);
391 if (dict_ !=
nullptr) {
393 TN_TOP2, dict_ratio, cert_offset, worst_dict_cert, step);
396 RecodeBeam* prev = beam_[t - 1];
397 if (charset != NULL) {
399 for (
int i = prev->beams_[beam_index].size() - 1; i >= 0; --i) {
401 ExtractPath(&prev->beams_[beam_index].get(i).data, &path);
402 tprintf(
"Step %d: Dawg beam %d:\n", t, i);
403 DebugPath(charset, path);
406 for (
int i = prev->beams_[beam_index].size() - 1; i >= 0; --i) {
408 ExtractPath(&prev->beams_[beam_index].get(i).data, &path);
409 tprintf(
"Step %d: Non-Dawg beam %d:\n", t, i);
410 DebugPath(charset, path);
418 for (
int tn = 0; tn <
TN_COUNT && total_beam == 0; ++tn) {
420 for (
int index = 0; index <
kNumBeams; ++index) {
424 for (
int i = prev->beams_[index].size() - 1; i >= 0; --i) {
425 ContinueContext(&prev->beams_[index].get(i).data, index, outputs,
426 top_n, dict_ratio, cert_offset, worst_dict_cert,
430 for (
int index = 0; index <
kNumBeams; ++index) {
432 total_beam += step->beams_[index].size();
437 for (
int c = 0; c <
NC_COUNT; ++c) {
438 if (step->best_initial_dawgs_[c].code >= 0) {
439 int index =
BeamIndex(
true, static_cast<NodeContinuation>(c), 0);
441 PushHeapIfBetter(kBeamWidths[0], &step->best_initial_dawgs_[c],
452 void RecodeBeamSearch::ContinueContext(
const RecodeNode* prev,
int index,
453 const float* outputs,
456 double worst_dict_cert,
464 for (
int p = length - 1; p >= 0; --p, previous = previous->
prev) {
465 while (previous != NULL &&
466 (previous->duplicate || previous->code == null_char_)) {
467 previous = previous->prev;
469 prefix.
Set(p, previous->code);
470 full_code.
Set(p, previous->code);
472 if (prev !=
nullptr && !is_simple_text_) {
473 if (top_n_flags_[prev->
code] == top_n_flag) {
477 PushDupOrNoDawgIfBetter(length,
true, prev->
code, prev->
unichar_id,
478 cert, worst_dict_cert, dict_ratio, use_dawgs,
482 prev->
code != null_char_) {
484 outputs[null_char_]) +
486 PushDupOrNoDawgIfBetter(length,
true, prev->
code, prev->
unichar_id,
487 cert, worst_dict_cert, dict_ratio, use_dawgs,
492 if (prev->
code != null_char_ && length > 0 &&
493 top_n_flags_[null_char_] == top_n_flag) {
498 PushDupOrNoDawgIfBetter(length,
false, null_char_, INVALID_UNICHAR_ID,
499 cert, worst_dict_cert, dict_ratio, use_dawgs,
504 if (final_codes != NULL) {
505 for (
int i = 0; i < final_codes->
size(); ++i) {
506 int code = (*final_codes)[i];
507 if (top_n_flags_[code] != top_n_flag)
continue;
508 if (prev !=
nullptr && prev->
code == code && !is_simple_text_)
continue;
511 full_code.
Set(length, code);
514 if (length == 0 && code == null_char_) unichar_id = INVALID_UNICHAR_ID;
515 ContinueUnichar(code, unichar_id, cert, worst_dict_cert, dict_ratio,
517 if (top_n_flag ==
TN_TOP2 && code != null_char_) {
518 float prob = outputs[code] + outputs[null_char_];
520 prev->
code != null_char_ &&
521 ((prev->
code == top_code_ && code == second_code_) ||
522 (code == top_code_ && prev->
code == second_code_))) {
523 prob += outputs[prev->
code];
526 ContinueUnichar(code, unichar_id, cert, worst_dict_cert, dict_ratio,
532 if (next_codes != NULL) {
533 for (
int i = 0; i < next_codes->
size(); ++i) {
534 int code = (*next_codes)[i];
535 if (top_n_flags_[code] != top_n_flag)
continue;
536 if (prev !=
nullptr && prev->
code == code && !is_simple_text_)
continue;
538 PushDupOrNoDawgIfBetter(length + 1,
false, code, INVALID_UNICHAR_ID, cert,
539 worst_dict_cert, dict_ratio, use_dawgs,
541 if (top_n_flag ==
TN_TOP2 && code != null_char_) {
542 float prob = outputs[code] + outputs[null_char_];
544 prev->
code != null_char_ &&
545 ((prev->
code == top_code_ && code == second_code_) ||
546 (code == top_code_ && prev->
code == second_code_))) {
547 prob += outputs[prev->
code];
550 PushDupOrNoDawgIfBetter(length + 1,
false, code, INVALID_UNICHAR_ID,
551 cert, worst_dict_cert, dict_ratio, use_dawgs,
559 void RecodeBeamSearch::ContinueUnichar(
int code,
int unichar_id,
float cert,
560 float worst_dict_cert,
float dict_ratio,
565 if (cert > worst_dict_cert) {
566 ContinueDawg(code, unichar_id, cert, cont, prev, step);
570 PushHeapIfBetter(kBeamWidths[0], code, unichar_id,
TOP_CHOICE_PERM,
false,
571 false,
false,
false, cert * dict_ratio, prev,
nullptr,
573 if (dict_ !=
nullptr &&
579 float dawg_cert = cert;
593 dawg_cert *= dict_ratio;
594 PushInitialDawgIfBetter(code, unichar_id, permuter,
false,
false,
595 dawg_cert, cont, prev, step);
603 void RecodeBeamSearch::ContinueDawg(
int code,
int unichar_id,
float cert,
608 if (unichar_id == INVALID_UNICHAR_ID) {
609 PushHeapIfBetter(kBeamWidths[0], code, unichar_id,
NO_PERM,
false,
false,
610 false,
false, cert, prev,
nullptr, dawg_heap);
615 if (prev != NULL) score += prev->
score;
616 if (dawg_heap->
size() >= kBeamWidths[0] &&
618 nodawg_heap->
size() >= kBeamWidths[0] &&
625 while (uni_prev != NULL &&
627 uni_prev = uni_prev->
prev;
632 PushInitialDawgIfBetter(code, unichar_id, uni_prev->
permuter,
false,
633 false, cert, cont, prev, step);
634 PushHeapIfBetter(kBeamWidths[0], code, unichar_id, uni_prev->
permuter,
635 false,
false,
false,
false, cert, prev,
nullptr,
648 bool word_start =
false;
649 if (uni_prev == NULL) {
653 }
else if (uni_prev->
dawgs != NULL) {
663 PushHeapIfBetter(kBeamWidths[0], code, unichar_id, permuter,
false,
664 word_start, dawg_args.
valid_end,
false, cert, prev,
666 if (dawg_args.
valid_end && !space_delimited_) {
670 PushInitialDawgIfBetter(code, unichar_id, permuter, word_start,
true,
671 cert, cont, prev, step);
672 PushHeapIfBetter(kBeamWidths[0], code, unichar_id, permuter,
false,
673 word_start,
true,
false, cert, prev, NULL, nodawg_heap);
676 delete updated_dawgs;
683 void RecodeBeamSearch::PushInitialDawgIfBetter(
int code,
int unichar_id,
685 bool start,
bool end,
float cert,
689 RecodeNode* best_initial_dawg = &step->best_initial_dawgs_[cont];
691 if (prev != NULL) score += prev->
score;
692 if (best_initial_dawg->
code < 0 || score > best_initial_dawg->
score) {
695 RecodeNode node(code, unichar_id, permuter,
true, start, end,
false, cert,
696 score, prev, initial_dawgs,
697 ComputeCodeHash(code,
false, prev));
698 *best_initial_dawg = node;
706 void RecodeBeamSearch::PushDupOrNoDawgIfBetter(
707 int length,
bool dup,
int code,
int unichar_id,
float cert,
708 float worst_dict_cert,
float dict_ratio,
bool use_dawgs,
710 int index =
BeamIndex(use_dawgs, cont, length);
712 if (cert > worst_dict_cert) {
713 PushHeapIfBetter(kBeamWidths[length], code, unichar_id,
715 dup, cert, prev,
nullptr, &step->beams_[index]);
720 PushHeapIfBetter(kBeamWidths[length], code, unichar_id,
722 false, dup, cert, prev,
nullptr, &step->beams_[index]);
730 void RecodeBeamSearch::PushHeapIfBetter(
int max_size,
int code,
int unichar_id,
732 bool word_start,
bool end,
bool dup,
737 if (prev != NULL) score += prev->
score;
739 uinT64 hash = ComputeCodeHash(code, dup, prev);
740 RecodeNode node(code, unichar_id, permuter, dawg_start, word_start, end,
741 dup, cert, score, prev, d, hash);
742 if (UpdateHeapIfMatched(&node, heap))
return;
746 if (heap->
size() > max_size) heap->
Pop(&entry);
754 void RecodeBeamSearch::PushHeapIfBetter(
int max_size,
RecodeNode* node,
757 if (UpdateHeapIfMatched(node, heap)) {
763 if (heap->
size() > max_size) heap->
Pop(&entry);
769 bool RecodeBeamSearch::UpdateHeapIfMatched(
RecodeNode* new_node,
775 for (
int i = 0; i < nodes->
size(); ++i) {
784 (*nodes)[i].key = node.
score;
794 uinT64 RecodeBeamSearch::ComputeCodeHash(
int code,
bool dup,
797 if (!dup && code != null_char_) {
799 uinT64 carry = (((hash >> 32) * num_classes) >> 32);
811 void RecodeBeamSearch::ExtractBestPaths(
817 const RecodeBeam* last_beam = beam_[beam_size_ - 1];
818 for (
int c = 0; c <
NC_COUNT; ++c) {
821 for (
int is_dawg = 0; is_dawg < 2; ++is_dawg) {
822 int beam_index =
BeamIndex(is_dawg, cont, 0);
823 int heap_size = last_beam->beams_[beam_index].size();
824 for (
int h = 0; h < heap_size; ++h) {
825 const RecodeNode* node = &last_beam->beams_[beam_index].get(h).data;
830 while (dawg_node != NULL &&
831 (dawg_node->
unichar_id == INVALID_UNICHAR_ID ||
833 dawg_node = dawg_node->
prev;
834 if (dawg_node == NULL || (!dawg_node->
end_of_word &&
840 if (best_node == NULL || node->
score > best_node->
score) {
841 second_best_node = best_node;
843 }
else if (second_best_node == NULL ||
845 second_best_node = node;
850 if (second_nodes != NULL) ExtractPath(second_best_node, second_nodes);
851 ExtractPath(best_node, best_nodes);
856 void RecodeBeamSearch::ExtractPath(
859 while (node != NULL) {
867 void RecodeBeamSearch::DebugPath(
870 for (
int c = 0; c < path.
size(); ++c) {
873 node.
Print(null_char_, *unicharset, 1);
878 void RecodeBeamSearch::DebugUnicharPath(
883 int num_ids = unichar_ids.
size();
884 double total_rating = 0.0;
885 for (
int c = 0; c < num_ids; ++c) {
886 int coord = xcoords[c];
887 tprintf(
"%d %d=%s r=%g, c=%g, s=%d, e=%d, perm=%d\n", coord, unichar_ids[c],
889 certs[c], path[coord]->start_of_word, path[coord]->end_of_word,
890 path[coord]->permuter);
891 total_rating += ratings[c];
893 tprintf(
"Path total rating = %g\n", total_rating);
static bool IsDawgFromBeamsIndex(int index)
const UNICHARSET & getUnicharset() const
const Pair & get(int index) const
static C_BLOB * FakeBlob(const TBOX &box)
void ExtractBestPathAsUnicharIds(bool debug, const UNICHARSET *unicharset, GenericVector< int > *unichar_ids, GenericVector< float > *certs, GenericVector< float > *ratings, GenericVector< int > *xcoords) const
RecodeBeamSearch(const UnicharCompress &recoder, int null_char, bool simple_text, Dict *dict)
void DebugBeams(const UNICHARSET &unicharset) const
void init_to_size(int size, T t)
static const float kMinCertainty
const GenericVector< int > * GetNextCodes(const RecodedCharID &code) const
static const int kMaxCodeLen
const GenericVector< int > * GetFinalCodes(const RecodedCharID &code) const
const Pair & PeekTop() const
int def_letter_is_okay(void *void_dawg_args, UNICHAR_ID unichar_id, bool word_end) const
void Print(int null_char, const UNICHARSET &unicharset, int depth) const
void Reshuffle(Pair *pair)
static int LengthFromBeamsIndex(int index)
const char * string() const
void set_matrix_cell(int col, int row)
void ExtractBestPathAsLabels(GenericVector< int > *labels, GenericVector< int > *xcoords) const
static const int kNumBeams
const char * kNodeContNames[]
bool IsSpaceDelimitedLang() const
Returns true if the language is space-delimited (not CJ, or T).
GenericVector< Pair > * heap()
void ExtractBestPathAsWords(const TBOX &line_box, float scale_factor, bool debug, const UNICHARSET *unicharset, PointerVector< WERD_RES > *words)
void scale(const float f)
void FakeWordFromRatings(PermuterType permuter)
bool IsSpaceDelimited(UNICHAR_ID unichar_id) const
int DecodeUnichar(const RecodedCharID &code) const
DawgPositionVector * updated_dawgs
static NodeContinuation ContinuationFromBeamsIndex(int index)
static int BeamIndex(bool is_dawg, NodeContinuation cont, int length)
void Decode(const NetworkIO &output, double dict_ratio, double cert_offset, double worst_dict_cert, const UNICHARSET *charset)
void put(ICOORD pos, const T &thing)
void Set(int index, int value)
static float ProbToCertainty(float prob)
void default_dawgs(DawgPositionVector *anylength_dawgs, bool suppress_patterns) const
DawgPositionVector * active_dawgs
const UNICHARSET * uch_set
DawgPositionVector * dawgs
STRING debug_str(UNICHAR_ID id) const