Skip to content

Commit

Permalink
Use CheckResizeSegmentsRequest for CalculatorRewriter.
Browse files Browse the repository at this point in the history
This CL stops using `parent_converter` in `ConverterInterface::ResizeSegment`.

#codehealth

PiperOrigin-RevId: 706981284
  • Loading branch information
hiroyuki-komatsu committed Dec 17, 2024
1 parent 86cf31d commit 9761c65
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 91 deletions.
16 changes: 16 additions & 0 deletions src/converter/converter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2293,4 +2293,20 @@ TEST_F(ConverterTest, ResizeSegmentsRequest) {
EXPECT_EQ(Util::CharsLen(segments.conversion_segment(2).key()), 2);
}
}

TEST_F(ConverterTest, IntegrationWithCalculatorRewriter) {
std::unique_ptr<EngineInterface> engine =
MockDataEngineFactory::Create().value();
ConverterInterface *converter = engine->GetConverter();

{
Segments segments;
const ConversionRequest convreq =
ConversionRequestBuilder().SetKey("1+1=").Build();
ASSERT_TRUE(converter->StartConversion(convreq, &segments));
EXPECT_EQ(segments.conversion_segments_size(), 1);
EXPECT_EQ(segments.conversion_segment(0).candidate(0).value, "2");
}
}

} // namespace mozc
4 changes: 0 additions & 4 deletions src/rewriter/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -751,7 +751,6 @@ mozc_cc_library(
":rewriter_interface",
"//base:japanese_util",
"//base:util",
"//converter:converter_interface",
"//converter:segments",
"//protocol:commands_cc_proto",
"//protocol:config_cc_proto",
Expand All @@ -773,11 +772,8 @@ mozc_cc_test(
":calculator_rewriter",
":rewriter_interface",
"//config:config_handler",
"//converter:converter_mock",
"//converter:segments",
"//converter:segments_matchers",
"//engine",
"//engine:mock_data_engine_factory",
"//protocol:commands_cc_proto",
"//protocol:config_cc_proto",
"//request:conversion_request",
Expand Down
89 changes: 46 additions & 43 deletions src/rewriter/calculator_rewriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>

#include "absl/log/check.h"
Expand All @@ -40,7 +42,6 @@
#include "absl/strings/string_view.h"
#include "base/japanese_util.h"
#include "base/util.h"
#include "converter/converter_interface.h"
#include "converter/segments.h"
#include "protocol/commands.pb.h"
#include "protocol/config.pb.h"
Expand All @@ -50,25 +51,50 @@

namespace mozc {

CalculatorRewriter::CalculatorRewriter(
const ConverterInterface *parent_converter)
: parent_converter_(parent_converter) {
DCHECK(parent_converter_);
}

int CalculatorRewriter::capability(const ConversionRequest &request) const {
if (request.request().mixed_conversion()) {
return RewriterInterface::ALL;
}
return RewriterInterface::CONVERSION;
}

std::optional<RewriterInterface::ResizeSegmentsRequest>
CalculatorRewriter::CheckResizeSegmentsRequest(const ConversionRequest &request,
const Segments &segments) const {
if (!request.config().use_calculator()) {
return std::nullopt;
}

const CalculatorInterface *calculator = CalculatorFactory::GetCalculator();

const size_t segments_size = segments.conversion_segments_size();
if (segments_size <= 1) {
return std::nullopt;
}

// Merge keys of all conversion segments and try calculation.
std::string merged_key;
for (const Segment &segment : segments.conversion_segments()) {
merged_key += segment.key();
}
// The decision to calculate and calculation itself are both done by the
// calculator.
std::string result;
if (!calculator->CalculateString(merged_key, &result)) {
return std::nullopt;
}

// Merge all conversion segments.
const uint8_t key_size = static_cast<uint8_t>(Util::CharsLen(merged_key));
ResizeSegmentsRequest resize_request = {
.segment_index = 0,
.segment_sizes = { key_size, 0, 0, 0, 0, 0, 0, 0 },
};
return resize_request;
}

// Rewrites candidates when conversion segments of |segments| represents an
// expression that can be calculated. In such case, if |segments| consists
// of multiple segments, it merges them by calling ConverterInterface::
// ResizeSegment(), otherwise do calculation and insertion.
// TODO(tok): It currently calculates same expression twice, if |segments| is
// a valid expression.
// expression that can be calculated.
bool CalculatorRewriter::Rewrite(const ConversionRequest &request,
Segments *segments) const {
if (!request.config().use_calculator()) {
Expand All @@ -78,50 +104,27 @@ bool CalculatorRewriter::Rewrite(const ConversionRequest &request,
CalculatorInterface *calculator = CalculatorFactory::GetCalculator();

const size_t segments_size = segments->conversion_segments_size();
if (segments_size == 0) {
if (segments_size != 1) {
return false;
}

// If |segments| has only one conversion segment, try calculation and insert
// the result on success.
if (segments_size == 1) {
const std::string &key = segments->conversion_segment(0).key();
std::string result;
if (key.empty()) {
return false;
}
if (!calculator->CalculateString(key, &result)) {
return false;
}
// Insert the result.
if (!InsertCandidate(result, 0, segments->mutable_conversion_segment(0))) {
return false;
}
return true;
const std::string &key = segments->conversion_segment(0).key();
if (key.empty()) {
return false;
}

// Merge keys of all conversion segments and try calculation.
std::string merged_key;
for (const Segment &segment : segments->conversion_segments()) {
merged_key += segment.key();
}
// The decision to calculate and calculation itself are both done by the
// calculator.
std::string result;
if (!calculator->CalculateString(merged_key, &result)) {
if (!calculator->CalculateString(key, &result)) {
return false;
}

// Merge all conversion segments.
int offset = Util::CharsLen(merged_key) -
Util::CharsLen(segments->conversion_segment(0).key());
// ConverterInterface::ResizeSegment() calls Rewriter::Rewrite(), so
// CalculatorRewriter::Rewrite() is recursively called with merged
// conversion segment.
if (!parent_converter_->ResizeSegment(segments, request, 0, offset)) {
LOG(ERROR) << "Failed to merge conversion segments";
// Insert the result.
if (!InsertCandidate(result, 0, segments->mutable_conversion_segment(0))) {
return false;
}

return true;
}

Expand Down
9 changes: 5 additions & 4 deletions src/rewriter/calculator_rewriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#define MOZC_REWRITER_CALCULATOR_REWRITER_H_

#include <cstddef>
#include <optional>

#include "absl/strings/string_view.h"
#include "converter/segments.h"
Expand All @@ -48,10 +49,12 @@ class CalculatorRewriter : public RewriterInterface {
public:
friend class CalculatorRewriterTest;

explicit CalculatorRewriter(const ConverterInterface *parent_converter);

int capability(const ConversionRequest &request) const override;

std::optional<ResizeSegmentsRequest> CheckResizeSegmentsRequest(
const ConversionRequest &request,
const Segments &segments) const override;

bool Rewrite(const ConversionRequest &request,
Segments *segments) const override;

Expand All @@ -61,8 +64,6 @@ class CalculatorRewriter : public RewriterInterface {
// insertion is failed.
bool InsertCandidate(absl::string_view value, size_t insert_pos,
Segment *segment) const;

const ConverterInterface *parent_converter_;
};

} // namespace mozc
Expand Down
65 changes: 26 additions & 39 deletions src/rewriter/calculator_rewriter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,15 @@
#include "rewriter/calculator_rewriter.h"

#include <cstddef>
#include <memory>
#include <optional>
#include <string>

#include "absl/log/check.h"
#include "absl/strings/match.h"
#include "absl/strings/string_view.h"
#include "config/config_handler.h"
#include "converter/converter_mock.h"
#include "converter/segments.h"
#include "converter/segments_matchers.h"
#include "engine/engine.h"
#include "engine/mock_data_engine_factory.h"
#include "protocol/commands.pb.h"
#include "protocol/config.pb.h"
#include "request/conversion_request.h"
Expand Down Expand Up @@ -133,12 +130,10 @@ class CalculatorRewriterTest : public testing::TestWithTempUserProfile {

private:
CalculatorMock calculator_mock_;
MockConverter mock_converter_;
};

TEST_F(CalculatorRewriterTest, InsertCandidateTest) {
MockConverter converter;
CalculatorRewriter calculator_rewriter(&converter);
CalculatorRewriter calculator_rewriter;

{
Segment segment;
Expand Down Expand Up @@ -172,8 +167,7 @@ TEST_F(CalculatorRewriterTest, BasicTest) {
// Pretend "key" is calculated to "value".
calculator_mock().SetCalculatePair("key", "value", true);

MockConverter converter;
CalculatorRewriter calculator_rewriter(&converter);
CalculatorRewriter calculator_rewriter;
const int counter_at_first = calculator_mock().calculation_counter();

Segments segments;
Expand All @@ -197,11 +191,7 @@ TEST_F(CalculatorRewriterTest, SeparatedSegmentsTest) {
// Pretend "1+1=" is calculated to "2".
calculator_mock().SetCalculatePair("1+1=", "2", true);

// Since this test depends on the actual implementation of
// Converter::ResizeSegments(), we cannot use converter mock here. However,
// the test itself is independent of data.
std::unique_ptr<Engine> engine = MockDataEngineFactory::Create().value();
CalculatorRewriter calculator_rewriter(engine->GetConverter());
CalculatorRewriter calculator_rewriter;

// Push back separated segments.
Segments segments;
Expand All @@ -211,27 +201,21 @@ TEST_F(CalculatorRewriterTest, SeparatedSegmentsTest) {
AddSegment("=", "=", &segments);

const ConversionRequest convreq = ConvReq(config_, request_);
calculator_rewriter.Rewrite(convreq, &segments);
EXPECT_EQ(segments.segments_size(), 1); // merged

int index = GetIndexOfCalculatedCandidate(segments);
EXPECT_NE(index, -1);

// Secondary result with expression (description: "1+1=2");
EXPECT_TRUE(
ContainsCalculatedResult(segments.segment(0).candidate(index + 1)));
ASSERT_FALSE(calculator_rewriter.Rewrite(convreq, &segments));

EXPECT_EQ("2", segments.segment(0).candidate(index).value);
EXPECT_EQ("1+1=2", segments.segment(0).candidate(index + 1).value);
std::optional<RewriterInterface::ResizeSegmentsRequest> resize_request =
calculator_rewriter.CheckResizeSegmentsRequest(convreq, segments);
ASSERT_TRUE(resize_request.has_value());
EXPECT_EQ(resize_request->segment_index, 0);
EXPECT_EQ(resize_request->segment_sizes[0], 4);
}

// CalculatorRewriter should convert an expression starting with '='.
TEST_F(CalculatorRewriterTest, ExpressionStartingWithEqualTest) {
// Pretend "=1+1" is calculated to "2".
calculator_mock().SetCalculatePair("=1+1", "2", true);

MockConverter converter;
CalculatorRewriter calculator_rewriter(&converter);
CalculatorRewriter calculator_rewriter;
const ConversionRequest request;

Segments segments;
Expand All @@ -255,8 +239,7 @@ TEST_F(CalculatorRewriterTest, DescriptionCheckTest) {
// Pretend kExpression is calculated to "3"
calculator_mock().SetCalculatePair(kExpression, "3", true);

MockConverter converter;
CalculatorRewriter calculator_rewriter(&converter);
CalculatorRewriter calculator_rewriter;

Segments segments;
AddSegment(kExpression, kExpression, &segments);
Expand All @@ -273,11 +256,7 @@ TEST_F(CalculatorRewriterTest, DescriptionCheckTest) {
TEST_F(CalculatorRewriterTest, ConfigTest) {
calculator_mock().SetCalculatePair("1+1=", "2", true);

// Since this test depends on the actual implementation of
// Converter::ResizeSegments(), we cannot use converter mock here. However,
// the test itself is independent of data.
std::unique_ptr<Engine> engine = MockDataEngineFactory::Create().value();
CalculatorRewriter calculator_rewriter(engine->GetConverter());
CalculatorRewriter calculator_rewriter;
{
Segments segments;
AddSegment("1", "1", &segments);
Expand All @@ -286,7 +265,13 @@ TEST_F(CalculatorRewriterTest, ConfigTest) {
AddSegment("=", "=", &segments);
config_.set_use_calculator(true);
const ConversionRequest convreq = ConvReq(config_, request_);
EXPECT_TRUE(calculator_rewriter.Rewrite(convreq, &segments));
ASSERT_FALSE(calculator_rewriter.Rewrite(convreq, &segments));

std::optional<RewriterInterface::ResizeSegmentsRequest> resize_request =
calculator_rewriter.CheckResizeSegmentsRequest(convreq, segments);
ASSERT_TRUE(resize_request.has_value());
EXPECT_EQ(resize_request->segment_index, 0);
EXPECT_EQ(resize_request->segment_sizes[0], 4);
}

{
Expand All @@ -298,12 +283,15 @@ TEST_F(CalculatorRewriterTest, ConfigTest) {
config_.set_use_calculator(false);
const ConversionRequest convreq = ConvReq(config_, request_);
EXPECT_FALSE(calculator_rewriter.Rewrite(convreq, &segments));

std::optional<RewriterInterface::ResizeSegmentsRequest> resize_request =
calculator_rewriter.CheckResizeSegmentsRequest(convreq, segments);
EXPECT_FALSE(resize_request.has_value());
}
}

TEST_F(CalculatorRewriterTest, MobileEnvironmentTest) {
MockConverter converter;
CalculatorRewriter rewriter(&converter);
CalculatorRewriter rewriter;
{
request_.set_mixed_conversion(true);
const ConversionRequest convreq = ConvReq(config_, request_);
Expand All @@ -317,8 +305,7 @@ TEST_F(CalculatorRewriterTest, MobileEnvironmentTest) {
}

TEST_F(CalculatorRewriterTest, EmptyKeyTest) {
MockConverter converter;
CalculatorRewriter calculator_rewriter(&converter);
CalculatorRewriter calculator_rewriter;
{
Segments segments;
AddSegment("", "1", &segments);
Expand Down
2 changes: 1 addition & 1 deletion src/rewriter/rewriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ Rewriter::Rewriter(const engine::Modules &modules,
AddRewriter(std::make_unique<IvsVariantsRewriter>());
AddRewriter(std::make_unique<EmojiRewriter>(*data_manager));
AddRewriter(EmoticonRewriter::CreateFromDataManager(*data_manager));
AddRewriter(std::make_unique<CalculatorRewriter>(&parent_converter));
AddRewriter(std::make_unique<CalculatorRewriter>());
AddRewriter(
std::make_unique<SymbolRewriter>(&parent_converter, data_manager));
AddRewriter(std::make_unique<UnicodeRewriter>(&parent_converter));
Expand Down

0 comments on commit 9761c65

Please sign in to comment.