Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the return type of Enumerable#sum/product for union elements #15314

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
28 changes: 28 additions & 0 deletions spec/std/enumerable_spec.cr
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
require "spec"
require "./spec_helper"
require "spec/helpers/iterate"

module SomeInterface; end
Expand Down Expand Up @@ -1364,6 +1365,19 @@ describe "Enumerable" do
it { [1, 2, 3].sum(4.5).should eq(10.5) }
it { (1..3).sum { |x| x * 2 }.should eq(12) }
it { (1..3).sum(1.5) { |x| x * 2 }.should eq(13.5) }
it { [1, 3_u64].sum(0_i32).should eq(4_u32) }
it { [1, 3].sum(0_u64).should eq(4_u64) }
it { [1, 10000000000_u64].sum(0_u64).should eq(10000000001) }
pending_wasm32 "raises if union types are summed", tags: %w[slow] do
assert_compile_error <<-CRYSTAL,
require "prelude"
[1, 10000000000_u64].sum
CRYSTAL
"`Enumerable#sum` and `#product` do not support Union " +
"types. Instead, use `Enumerable#sum(initial)` and " +
"`#product(initial)`, respectively, with an initial value " +
"of the intended type of the call."
end

it "uses additive_identity from type" do
typeof([1, 2, 3].sum).should eq(Int32)
Expand Down Expand Up @@ -1405,6 +1419,20 @@ describe "Enumerable" do
typeof([1.5, 2.5, 3.5].product).should eq(Float64)
typeof([1, 2, 3].product(&.to_f)).should eq(Float64)
end

it { [1, 3_u64].product(3_i32).should eq(9_u32) }
it { [1, 3].product(3_u64).should eq(9_u64) }
it { [1, 10000000000_u64].product(3_u64).should eq(30000000000_u64) }
pending_wasm32 "raises if union types are multiplied", tags: %w[slow] do
assert_compile_error <<-CRYSTAL,
require "prelude"
[1, 10000000000_u64].product
CRYSTAL
"`Enumerable#sum` and `#product` do not support Union " +
"types. Instead, use `Enumerable#sum(initial)` and " +
"`#product(initial)`, respectively, with an initial value " +
"of the intended type of the call."
end
end

describe "first" do
Expand Down
27 changes: 27 additions & 0 deletions spec/std/spec_helper.cr
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,33 @@ def compile_file(source_file, *, bin_name = "executable_file", flags = %w(), fil
end
end

def assert_compile_error(source, expected_error, *, flags = %w(), file = __FILE__, line = __LINE__)
# can't use backtick in interpreted code (#12241)
pending_interpreted! "Unable to compile Crystal code in interpreted code"

with_tempfile("source_file", file: file) do |source_file|
File.write(source_file, source)

bin_name = "executable_file"
with_temp_executable(bin_name, file: file) do |executable_file|
compiler = ENV["CRYSTAL_SPEC_COMPILER_BIN"]? || "bin/crystal"
args = ["build"] + flags + ["-o", executable_file, source_file]
output = IO::Memory.new
status = Process.run(compiler, args, env: {
"CRYSTAL_PATH" => Crystal::PATH,
"CRYSTAL_LIBRARY_PATH" => Crystal::LIBRARY_PATH,
"CRYSTAL_CACHE_DIR" => Crystal::CACHE_DIR,
"NO_COLOR" => "1",
}, output: output, error: output)

output.to_s.should contain(expected_error)

status.success?.should be_false
File.exists?(executable_file).should be_false
end
end
end

def compile_source(source, flags = %w(), file = __FILE__, &)
with_tempfile("source_file", file: file) do |source_file|
File.write(source_file, source)
Expand Down
8 changes: 4 additions & 4 deletions spec/support/wasm32.cr
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
require "spec"

{% if flag?(:wasm32) %}
def pending_wasm32(description = "assert", file = __FILE__, line = __LINE__, end_line = __END_LINE__, &block)
pending("#{description} [wasm32]", file, line, end_line)
def pending_wasm32(description = "assert", file = __FILE__, line = __LINE__, end_line = __END_LINE__, focus : Bool = false, tags : String | Enumerable(String) | Nil = nil, &block)
pending("#{description} [wasm32]", file, line, end_line, focus: focus, tags: tags)
end

def pending_wasm32(*, describe, file = __FILE__, line = __LINE__, end_line = __END_LINE__, &block)
pending_wasm32(describe, file, line, end_line) { }
end
{% else %}
def pending_wasm32(description = "assert", file = __FILE__, line = __LINE__, end_line = __END_LINE__, &block)
it(description, file, line, end_line, &block)
def pending_wasm32(description = "assert", file = __FILE__, line = __LINE__, end_line = __END_LINE__, focus : Bool = false, tags : String | Enumerable(String) | Nil = nil, &block)
it(description, file, line, end_line, focus: focus, tags: tags, &block)
end

def pending_wasm32(*, describe, file = __FILE__, line = __LINE__, end_line = __END_LINE__, &block)
Expand Down
34 changes: 22 additions & 12 deletions src/enumerable.cr
Original file line number Diff line number Diff line change
Expand Up @@ -1771,7 +1771,7 @@ module Enumerable(T)
end

private def additive_identity(reflect)
type = reflect.first
type = reflect.type
if type.responds_to? :additive_identity
type.additive_identity
else
Expand Down Expand Up @@ -1808,7 +1808,10 @@ module Enumerable(T)
# Expects all types returned from the block to respond to `#+` method.
#
# This method calls `.additive_identity` on the yielded type to determine the
# type of the sum value.
# type of the sum value. Hence, it can fail to compile if
# `.additive_identity` fails to determine a safe type, e.g., in case of
# union types. In such cases, use `sum(initial)` with an initial value of
# the expected type of the sum value.
#
# If the collection is empty, returns `additive_identity`.
#
Expand Down Expand Up @@ -1847,15 +1850,15 @@ module Enumerable(T)
# ```
#
# This method calls `.multiplicative_identity` on the element type to determine the
# type of the sum value.
# type of the product value.
#
# If the collection is empty, returns `multiplicative_identity`.
#
# ```
# ([] of Int32).product # => 1
# ```
def product
product Reflect(T).first.multiplicative_identity
product Reflect(T).type.multiplicative_identity
end

# Multiplies *initial* and all the elements in the collection
Expand Down Expand Up @@ -1886,16 +1889,19 @@ module Enumerable(T)
#
# Expects all types returned from the block to respond to `#*` method.
#
# This method calls `.multiplicative_identity` on the element type to determine the
# type of the sum value.
# This method calls `.multiplicative_identity` on the element type to
# determine the type of the product value. Hence, it can fail to compile if
# `.multiplicative_identity` fails to determine a safe type, e.g., in case
# of union types. In such cases, use `product(initial)` with an initial
# value of the expected type of the product value.
#
# If the collection is empty, returns `multiplicative_identity`.
#
# ```
# ([] of Int32).product { |x| x + 1 } # => 1
# ```
def product(& : T -> _)
product(Reflect(typeof(yield Enumerable.element_type(self))).first.multiplicative_identity) do |value|
product(Reflect(typeof(yield Enumerable.element_type(self))).type.multiplicative_identity) do |value|
yield value
end
end
Expand Down Expand Up @@ -2285,12 +2291,16 @@ module Enumerable(T)

# :nodoc:
private struct Reflect(X)
# For now it's just a way to implement `Enumerable#sum` in a way that the
# initial value given to it has the type of the first type in the union,
# if the type is a union.
def self.first
# For now, Reflect is used to reject union types in `#sum()` and
# `#product()` methods.
def self.type
{% if X.union? %}
{{X.union_types.first}}
{{
raise("`Enumerable#sum` and `#product` do not support Union " +
"types. Instead, use `Enumerable#sum(initial)` and " +
"`#product(initial)`, respectively, with an initial value " +
"of the intended type of the call.")
}}
{% else %}
X
{% end %}
Expand Down