diff --git a/lib/mongoid/association/eager_loadable.rb b/lib/mongoid/association/eager_loadable.rb index 5ae3b67bfd..84cfd14996 100644 --- a/lib/mongoid/association/eager_loadable.rb +++ b/lib/mongoid/association/eager_loadable.rb @@ -31,6 +31,9 @@ def preload(associations, docs) docs_map = {} queue = [ klass.to_s ] + # account for single-collection inheritance + queue.push(klass.root_class.to_s) if klass != klass.root_class + while klass = queue.shift if as = assoc_map.delete(klass) as.each do |assoc| diff --git a/lib/mongoid/traversable.rb b/lib/mongoid/traversable.rb index ee0ca440da..ce0433c2da 100644 --- a/lib/mongoid/traversable.rb +++ b/lib/mongoid/traversable.rb @@ -300,6 +300,18 @@ def hereditary? !!(Mongoid::Document > superclass) end + # Returns the root class of the STI tree that the current + # class participates in. If the class is not an STI subclass, this + # returns the class itself. + # + # @return [ Mongoid::Document ] the root of the STI tree + def root_class + root = self + root = root.superclass while root.hereditary? + + root + end + # When inheriting, we want to copy the fields from the parent class and # set the on the child to start, mimicking the behavior of the old # class_inheritable_accessor that was deprecated in Rails edge. diff --git a/spec/mongoid/association/eager_spec.rb b/spec/mongoid/association/eager_spec.rb index a194fc7411..2171f63040 100644 --- a/spec/mongoid/association/eager_spec.rb +++ b/spec/mongoid/association/eager_spec.rb @@ -14,14 +14,36 @@ Mongoid::Contextual::Mongo.new(criteria) end + let(:association_host) { Account } + let(:inclusions) do includes.map do |key| - Account.reflect_on_association(key) + association_host.reflect_on_association(key) end end let(:doc) { criteria.first } + context 'when root is an STI subclass' do + # Driver has_one Vehicle + # Vehicle belongs_to Driver + # Truck is a Vehicle + + before do + Driver.create!(vehicle: Truck.new) + end + + let(:criteria) { Truck.all } + let(:includes) { %i[ driver ] } + let(:association_host) { Truck } + + it 'preloads the driver' do + expect(doc.ivar(:driver)).to be false + context.preload(inclusions, [ doc ]) + expect(doc.ivar(:driver)).to be == Driver.first + end + end + context "when belongs_to" do let!(:account) do @@ -42,7 +64,7 @@ it "preloads the parent" do expect(doc.ivar(:person)).to be false context.preload(inclusions, [doc]) - expect(doc.ivar(:person)).to eq(doc.person) + expect(doc.ivar(:person)).to be == person end end