From 333030294a820dfe024cd63ab824b798409b9e24 Mon Sep 17 00:00:00 2001 From: Toshiyuki Hanaoka Date: Mon, 13 Nov 2023 10:52:46 +0000 Subject: [PATCH] Support inner segment info. Candidates with inner segment info can be generated by mobile prediction. PiperOrigin-RevId: 581895193 --- src/converter/segments.h | 1 + src/rewriter/user_segment_history_rewriter.cc | 56 +++++++++-- src/rewriter/user_segment_history_rewriter.h | 2 + .../user_segment_history_rewriter_test.cc | 98 +++++++++++++++++++ 4 files changed, 149 insertions(+), 8 deletions(-) diff --git a/src/converter/segments.h b/src/converter/segments.h index 45e287de7..f95fac39e 100644 --- a/src/converter/segments.h +++ b/src/converter/segments.h @@ -277,6 +277,7 @@ class Segment final { absl::string_view GetContentValue() const; absl::string_view GetFunctionalKey() const; absl::string_view GetFunctionalValue() const; + size_t GetIndex() const { return index_; } private: const Candidate *candidate_; diff --git a/src/rewriter/user_segment_history_rewriter.cc b/src/rewriter/user_segment_history_rewriter.cc index c9497a194..75f0a6ca7 100644 --- a/src/rewriter/user_segment_history_rewriter.cc +++ b/src/rewriter/user_segment_history_rewriter.cc @@ -607,10 +607,6 @@ void UserSegmentHistoryRewriter::RememberNumberPreference( void UserSegmentHistoryRewriter::RememberFirstCandidate( const Segments &segments, size_t segment_index) { const Segment &seg = segments.segment(segment_index); - if (seg.candidates_size() <= 1) { - return; - } - const Segment::Candidate &candidate = seg.candidate(0); // http://b/issue?id=3156109 @@ -710,6 +706,49 @@ bool UserSegmentHistoryRewriter::IsAvailable(const ConversionRequest &request, return true; } +// Returns segments for learning. +// Inner segments boundary will be expanded. +Segments UserSegmentHistoryRewriter::MakeLearningSegmentsForTesting( + const Segments &segments) { + Segments ret; + for (size_t i = 0; i < segments.segments_size(); ++i) { + const Segment &segment = segments.segment(i); + const Segment::Candidate &candidate = segment.candidate(0); + if (candidate.inner_segment_boundary.size() <= 1) { + // No inner segment info + Segment *seg = ret.add_segment(); + *seg = segment; + continue; + } + for (Segment::Candidate::InnerSegmentIterator iter(&candidate); + !iter.Done(); iter.Next()) { + size_t index = iter.GetIndex(); + absl::string_view key = iter.GetKey(); + + Segment *seg = ret.add_segment(); + seg->set_segment_type(segment.segment_type()); + seg->set_key(key); + seg->clear_candidates(); + + Segment::Candidate *cand = seg->add_candidate(); + cand->attributes = candidate.attributes; + cand->key = key; + cand->content_key = iter.GetContentKey(); + cand->value = iter.GetValue(); + cand->content_value = iter.GetContentValue(); + // Fill IDs for the first and last inner segment. + if (index == 0) { + cand->lid = candidate.lid; + cand->rid = candidate.lid; + } else if (index == candidate.inner_segment_boundary.size() - 1) { + cand->lid = candidate.rid; + cand->rid = candidate.rid; + } + } + } + return ret; +} + void UserSegmentHistoryRewriter::Finish(const ConversionRequest &request, Segments *segments) { if (request.request_type() != ConversionRequest::CONVERSION) { @@ -725,9 +764,10 @@ void UserSegmentHistoryRewriter::Finish(const ConversionRequest &request, return; } - for (size_t i = segments->history_segments_size(); - i < segments->segments_size(); ++i) { - const Segment &segment = segments->segment(i); + const Segments target_segments = MakeLearningSegmentsForTesting(*segments); + for (size_t i = target_segments.history_segments_size(); + i < target_segments.segments_size(); ++i) { + const Segment &segment = target_segments.segment(i); if (segment.candidates_size() <= 0 || segment.segment_type() != Segment::FIXED_VALUE || segment.candidate(0).attributes & @@ -739,7 +779,7 @@ void UserSegmentHistoryRewriter::Finish(const ConversionRequest &request, continue; } InsertTriggerKey(segment); - RememberFirstCandidate(*segments, i); + RememberFirstCandidate(target_segments, i); } // update usage stats here usage_stats::UsageStats::SetInteger("UserSegmentHistoryEntrySize", diff --git a/src/rewriter/user_segment_history_rewriter.h b/src/rewriter/user_segment_history_rewriter.h index e520a9257..f842836fa 100644 --- a/src/rewriter/user_segment_history_rewriter.h +++ b/src/rewriter/user_segment_history_rewriter.h @@ -59,6 +59,8 @@ class UserSegmentHistoryRewriter : public RewriterInterface { bool Reload() override; void Clear() override; + static Segments MakeLearningSegmentsForTesting(const Segments &segments); + private: struct Score { constexpr void Update(const Score other) { diff --git a/src/rewriter/user_segment_history_rewriter_test.cc b/src/rewriter/user_segment_history_rewriter_test.cc index 1004d8770..41c5461c3 100644 --- a/src/rewriter/user_segment_history_rewriter_test.cc +++ b/src/rewriter/user_segment_history_rewriter_test.cc @@ -29,6 +29,7 @@ #include "rewriter/user_segment_history_rewriter.h" +#include #include #include #include @@ -1661,5 +1662,102 @@ TEST_F(UserSegmentHistoryRewriterTest, AnnotationAfterLearning) { } } +TEST_F(UserSegmentHistoryRewriterTest, SupportInnerSegmentsOnLearning) { + Segments segments; + std::unique_ptr rewriter( + CreateUserSegmentHistoryRewriter()); + + { + segments.Clear(); + InitSegments(&segments, 1, 2); + constexpr absl::string_view kKey = "わたしのなまえはなかのです"; + constexpr absl::string_view kValue = "私の名前は中野です"; + segments.mutable_segment(0)->set_key(kKey); + Segment::Candidate *candidate = + segments.mutable_segment(0)->mutable_candidate(1); + + candidate->value = kValue; + candidate->content_value = kValue; + candidate->key = kKey; + candidate->content_key = kKey; + // "わたしの, 私の", "わたし, 私" + candidate->PushBackInnerSegmentBoundary(12, 6, 9, 3); + // "なまえは, 名前は", "なまえ, 名前" + candidate->PushBackInnerSegmentBoundary(12, 9, 9, 6); + // "なかのです, 中野です", "なかの, 中野" + candidate->PushBackInnerSegmentBoundary(15, 12, 9, 6); + candidate->lid = 10; + candidate->rid = 20; + + segments.mutable_segment(0)->move_candidate(1, 0); + segments.mutable_segment(0)->mutable_candidate(0)->attributes |= + Segment::Candidate::RERANKED; + segments.mutable_segment(0)->set_segment_type(Segment::FIXED_VALUE); + + { + const Segments learning_segments = + UserSegmentHistoryRewriter::MakeLearningSegmentsForTesting(segments); + EXPECT_EQ(learning_segments.segments_size(), 3); + EXPECT_EQ(learning_segments.segment(0).key(), "わたしの"); + EXPECT_EQ(learning_segments.segment(0).candidate(0).key, "わたしの"); + EXPECT_EQ(learning_segments.segment(0).candidate(0).value, "私の"); + EXPECT_EQ(learning_segments.segment(0).candidate(0).content_key, + "わたし"); + EXPECT_EQ(learning_segments.segment(0).candidate(0).content_value, "私"); + EXPECT_EQ(learning_segments.segment(0).candidate(0).lid, 10); + EXPECT_EQ(learning_segments.segment(0).candidate(0).rid, 10); + EXPECT_EQ(learning_segments.segment(0).segment_type(), + Segment::FIXED_VALUE); + + EXPECT_EQ(learning_segments.segment(1).key(), "なまえは"); + EXPECT_EQ(learning_segments.segment(1).candidate(0).key, "なまえは"); + EXPECT_EQ(learning_segments.segment(1).candidate(0).value, "名前は"); + EXPECT_EQ(learning_segments.segment(1).candidate(0).content_key, + "なまえ"); + EXPECT_EQ(learning_segments.segment(1).candidate(0).content_value, + "名前"); + EXPECT_EQ(learning_segments.segment(1).candidate(0).lid, 0); + EXPECT_EQ(learning_segments.segment(1).candidate(0).rid, 0); + EXPECT_EQ(learning_segments.segment(1).segment_type(), + Segment::FIXED_VALUE); + + EXPECT_EQ(learning_segments.segment(2).key(), "なかのです"); + EXPECT_EQ(learning_segments.segment(2).candidate(0).key, "なかのです"); + EXPECT_EQ(learning_segments.segment(2).candidate(0).value, "中野です"); + EXPECT_EQ(learning_segments.segment(2).candidate(0).content_key, + "なかの"); + EXPECT_EQ(learning_segments.segment(2).candidate(0).content_value, + "中野"); + EXPECT_EQ(learning_segments.segment(2).candidate(0).lid, 20); + EXPECT_EQ(learning_segments.segment(2).candidate(0).rid, 20); + EXPECT_EQ(learning_segments.segment(2).segment_type(), + Segment::FIXED_VALUE); + } + + rewriter->Finish(request_, &segments); + } + + { + segments.Clear(); + InitSegments(&segments, 1, 2); + segments.mutable_segment(0)->set_key("なかの"); + Segment::Candidate *candidate = + segments.mutable_segment(0)->mutable_candidate(0); + candidate->value = "中埜"; + candidate->content_value = "中埜"; + candidate->content_key = "なかの"; + candidate->content_key = "なかの"; + + candidate = segments.mutable_segment(0)->mutable_candidate(1); + candidate->value = "中野"; + candidate->content_value = "中野"; + candidate->content_key = "なかの"; + candidate->content_key = "なかの"; + + EXPECT_TRUE(rewriter->Rewrite(request_, &segments)); + EXPECT_EQ(segments.segment(0).candidate(0).value, "中野"); + } +} + } // namespace } // namespace mozc