diff --git a/spec/lucky/memoize_spec.cr b/spec/lucky/memoize_spec.cr index f1dd88899..e6c734c13 100644 --- a/spec/lucky/memoize_spec.cr +++ b/spec/lucky/memoize_spec.cr @@ -8,6 +8,8 @@ private class ObjectWithMemoizedMethods getter times_method_1_called = 0 getter times_method_2_called = 0 getter times_method_3_called = 0 + getter times_method_4_called = 0 + getter times_method_5_called = 0 memoize def method_1 : String @times_method_1_called += 1 @@ -23,6 +25,16 @@ private class ObjectWithMemoizedMethods @times_method_3_called += 1 arg_a + ", " + arg_b end + + memoize def method_4? : Bool + @times_method_4_called += 1 + true + end + + memoize def method_5! : String + @times_method_5_called += 1 + "Boom!" + end end describe "memoizations" do @@ -81,4 +93,36 @@ describe "memoizations" do object.method_3("arg-a", "arg-b").should eq("arg-a, arg-b") object.times_method_3_called.should eq 3 end + + it "works with predicate methods" do + object = ObjectWithMemoizedMethods.new + + object.method_4?.should eq(true) + object.method_4?.should eq(true) + object.method_4?.should eq(true) + object.times_method_4_called.should eq(1) + end + + it "Works with bang methods" do + object = ObjectWithMemoizedMethods.new + + object.method_5!.should eq("Boom!") + object.method_5!.should eq("Boom!") + object.method_5!.should eq("Boom!") + object.times_method_5_called.should eq(1) + end + + it "calls uncached with predicate and bang methods" do + object = ObjectWithMemoizedMethods.new + + object.method_4__uncached?.should eq(true) + object.method_4__uncached?.should eq(true) + object.method_4__uncached?.should eq(true) + object.times_method_4_called.should eq(3) + + object.method_5__uncached!.should eq("Boom!") + object.method_5__uncached!.should eq("Boom!") + object.method_5__uncached!.should eq("Boom!") + object.times_method_5_called.should eq(3) + end end diff --git a/src/lucky/memoizable.cr b/src/lucky/memoizable.cr index 0dae503e4..b78452184 100644 --- a/src/lucky/memoizable.cr +++ b/src/lucky/memoizable.cr @@ -25,7 +25,22 @@ module Lucky::Memoizable raise "All arguments must have an explicit type restriction for memoized methods" if method_def.args.any? &.restriction.is_a?(Nop) %} - @__memoized_{{method_def.name}} : Tuple( + {% + special_ending = nil + safe_method_name = method_def.name + %} + + {% + if method_def.name.ends_with?('?') + special_ending = "?" + safe_method_name = method_def.name.tr("?", "") + elsif method_def.name.ends_with?('!') + special_ending = "!" + safe_method_name = method_def.name.tr("!", "") + end + %} + + @__memoized_{{safe_method_name}} : Tuple( {{ method_def.return_type }}, {% for arg in method_def.args %} {{ arg.restriction }}, @@ -33,7 +48,7 @@ module Lucky::Memoizable )? # Returns uncached value - def {{ method_def.name }}__uncached( + def {{ safe_method_name }}__uncached{% if special_ending %}{{ special_ending.id }}{% end %}( {% for arg in method_def.args %} {{ arg.name }} : {{ arg.restriction }}, {% end %} @@ -44,7 +59,7 @@ module Lucky::Memoizable # Checks the passed arguments against the memoized args # and runs the method body if it is the very first call # or the arguments do not match - def {{ method_def.name }}__tuple_cached( + def {{ safe_method_name }}__tuple_cached{% if special_ending %}{{ special_ending.id }}{% end %}( {% for arg in method_def.args %} {{ arg.name }} : {{ arg.restriction }}, {% end %} @@ -55,10 +70,10 @@ module Lucky::Memoizable {% end %} ) {% for arg, index in method_def.args %} - @__memoized_{{ method_def.name }} = nil if {{arg.name}} != @__memoized_{{ method_def.name }}.try &.at({{index}} + 1) + @__memoized_{{ safe_method_name }} = nil if {{arg.name}} != @__memoized_{{ safe_method_name }}.try &.at({{index}} + 1) {% end %} - @__memoized_{{ method_def.name }} ||= -> do - result = {{ method_def.name }}__uncached( + @__memoized_{{ safe_method_name }} ||= -> do + result = {{ safe_method_name }}__uncached{% if special_ending %}{{ special_ending.id }}{% end %}( {% for arg in method_def.args %} {{arg.name}}, {% end %} @@ -79,7 +94,7 @@ module Lucky::Memoizable {{ arg.name }} : {{ arg.restriction }}{% if has_default %} = {{ arg.default_value }}{% end %}, {% end %} ) : {{ method_def.return_type }} - {{ method_def.name }}__tuple_cached( + {{ safe_method_name }}__tuple_cached{% if special_ending %}{{ special_ending.id }}{% end %}( {% for arg in method_def.args %} {{arg.name}}, {% end %}