Skip to content

Commit

Permalink
Merge Converter::Predict() into StartPrediction.
Browse files Browse the repository at this point in the history
#codehealth

PiperOrigin-RevId: 699945786
  • Loading branch information
hiroyuki-komatsu committed Nov 25, 2024
1 parent a81137f commit 51945d7
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 35 deletions.
56 changes: 25 additions & 31 deletions src/converter/converter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,27 @@ void Converter::MaybeSetConsumedKeySizeToSegment(size_t consumed_key_size,
}

namespace {
bool ValidateConversionRequestForPrediction(const ConversionRequest &request) {
switch (request.request_type()) {
case ConversionRequest::CONVERSION:
// Conversion request is not for prediction.
return false;
case ConversionRequest::PREDICTION:
case ConversionRequest::SUGGESTION:
// Typical use case.
return true;
case ConversionRequest::PARTIAL_PREDICTION:
case ConversionRequest::PARTIAL_SUGGESTION: {
// Partial prediction/suggestion request is applicable only if the
// cursor is in the middle of the composer.
const size_t cursor = request.composer().GetCursor();
return cursor != 0 || cursor != request.composer().GetLength();
}
default:
ABSL_UNREACHABLE();
}
}

std::string GetPredictionKey(const ConversionRequest &request) {
switch (request.request_type()) {
case ConversionRequest::PREDICTION:
Expand All @@ -391,8 +412,10 @@ std::string GetPredictionKey(const ConversionRequest &request) {
}
} // namespace

bool Converter::Predict(const ConversionRequest &request,
Segments *segments) const {
bool Converter::StartPrediction(const ConversionRequest &request,
Segments *segments) const {
DCHECK(ValidateConversionRequestForPrediction(request));

const std::string key = GetPredictionKey(request);
if (ShouldSetKeyForPrediction(key, *segments)) {
SetKey(segments, key);
Expand Down Expand Up @@ -426,35 +449,6 @@ bool Converter::Predict(const ConversionRequest &request,
return IsValidSegments(request, *segments);
}

namespace {
bool ValidateConversionRequestForPrediction(const ConversionRequest &request) {
switch (request.request_type()) {
case ConversionRequest::CONVERSION:
// Conversion request is not for prediction.
return false;
case ConversionRequest::PREDICTION:
case ConversionRequest::SUGGESTION:
// Typical use case.
return true;
case ConversionRequest::PARTIAL_PREDICTION:
case ConversionRequest::PARTIAL_SUGGESTION: {
// Partial prediction/suggestion request is applicable only if the
// cursor is in the middle of the composer.
const size_t cursor = request.composer().GetCursor();
return cursor != 0 || cursor != request.composer().GetLength();
}
default:
ABSL_UNREACHABLE();
}
}
} // namespace

bool Converter::StartPrediction(const ConversionRequest &request,
Segments *segments) const {
DCHECK(ValidateConversionRequestForPrediction(request));
return Predict(request, segments);
}

void Converter::FinishConversion(const ConversionRequest &request,
Segments *segments) const {
CommitUsageStats(segments, segments->history_segments_size(),
Expand Down
3 changes: 0 additions & 3 deletions src/converter/converter.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,6 @@ class Converter final : public ConverterInterface {
bool GetLastConnectivePart(absl::string_view preceding_text, std::string *key,
std::string *value, uint16_t *id) const;

ABSL_MUST_USE_RESULT bool Predict(const ConversionRequest &request,
Segments *segments) const;

const dictionary::PosMatcher *pos_matcher_ = nullptr;
const dictionary::SuppressionDictionary *suppression_dictionary_;
std::unique_ptr<prediction::PredictorInterface> predictor_;
Expand Down
2 changes: 1 addition & 1 deletion src/converter/converter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1182,7 +1182,7 @@ TEST_F(ConverterTest, PredictSetKey) {
.SetComposer(composer)
.SetRequestType(ConversionRequest::PREDICTION)
.Build();
ASSERT_TRUE(converter->Predict(request, &segments));
ASSERT_TRUE(converter->StartPrediction(request, &segments));

ASSERT_EQ(segments.conversion_segments_size(), 1);
EXPECT_EQ(segments.conversion_segment(0).key(), kPredictionKey);
Expand Down

0 comments on commit 51945d7

Please sign in to comment.