diff --git a/devito/ir/iet/nodes.py b/devito/ir/iet/nodes.py
index 86e432ad71..23f293ad1e 100644
--- a/devito/ir/iet/nodes.py
+++ b/devito/ir/iet/nodes.py
@@ -16,7 +16,7 @@
                                Forward, WithLock, PrefetchUpdate, detect_io)
 from devito.symbolics import ListInitializer, CallFromPointer, ccode
 from devito.tools import (Signer, as_tuple, filter_ordered, filter_sorted, flatten,
-                          ctypes_to_cstr, OrderedSet)
+                          ctypes_to_cstr)
 from devito.types.basic import (AbstractFunction, AbstractSymbol, Basic, Indexed,
                                 Symbol)
 from devito.types.object import AbstractObject, LocalObject
@@ -1438,20 +1438,7 @@ def DummyExpr(*args, init=False):
 
 # Nodes required for distributed-memory halo exchange
 
-class HaloMixin:
-
-    def __repr__(self):
-        fstrings = []
-        for f in self.fmapper.keys():
-            loc_indices = OrderedSet(*(self.fmapper[f].loc_indices.values()))
-            loc_indices_str = str(list(loc_indices)) if loc_indices else ""
-            fstrings.append("%s%s" % (f.name, loc_indices_str))
-
-        functions = ",".join(fstrings)
-        return "<%s(%s)>" % (self.__class__.__name__, functions)
-
-
-class HaloSpot(HaloMixin, Node):
+class HaloSpot(Node):
 
     """
     A halo exchange operation (e.g., send, recv, wait, ...) required to
@@ -1508,6 +1495,9 @@ def body(self):
     def functions(self):
         return tuple(self.fmapper)
 
+    def __repr__(self):
+        funcs = self.halo_scheme.__reprfuncs__()
+        return "<%s(%s)>" % (self.__class__.__name__, funcs)
 
 # Utility classes
 
diff --git a/devito/mpi/halo_scheme.py b/devito/mpi/halo_scheme.py
index f1bbb79003..d80127fae2 100644
--- a/devito/mpi/halo_scheme.py
+++ b/devito/mpi/halo_scheme.py
@@ -9,10 +9,9 @@
 from devito import configuration
 from devito.data import CORE, OWNED, LEFT, CENTER, RIGHT
 from devito.ir.support import Forward, Scope
-from devito.ir.iet.nodes import HaloMixin
 from devito.symbolics.manipulation import _uxreplace_registry
 from devito.tools import (Reconstructable, Tag, as_tuple, filter_ordered, flatten,
-                          frozendict, is_integer, filter_sorted)
+                          frozendict, is_integer, filter_sorted, OrderedSet)
 from devito.types import Grid
 
 __all__ = ['HaloScheme', 'HaloSchemeEntry', 'HaloSchemeException', 'HaloTouch']
@@ -36,8 +35,8 @@ class HaloSchemeEntry(Reconstructable):
     def __init__(self, loc_indices, loc_dirs, halos, dims):
         self.loc_indices = frozendict(loc_indices)
         self.loc_dirs = frozendict(loc_dirs)
-        self.halos = halos
-        self.dims = dims
+        self.halos = frozenset(halos)
+        self.dims = frozenset(dims)
 
     def __eq__(self, other):
         if not isinstance(other, HaloSchemeEntry):
@@ -48,10 +47,10 @@ def __eq__(self, other):
                 self.dims == other.dims)
 
     def __hash__(self):
-        return hash((frozenset(self.loc_indices.items()),
-                     frozenset(self.loc_dirs.items()),
-                     frozenset(self.halos),
-                     frozenset(self.dims)))
+        return hash((tuple(self.loc_indices.items()),
+                     tuple(self.loc_dirs.items()),
+                     self.halos,
+                     self.dims))
 
     def __repr__(self):
         return (f"HaloSchemeEntry(loc_indices={self.loc_indices}, "
@@ -63,7 +62,7 @@ def __repr__(self):
 OMapper = namedtuple('OMapper', 'core owned')
 
 
-class HaloScheme(HaloMixin):
+class HaloScheme():
 
     """
     A HaloScheme describes a set of halo exchanges through a mapper:
@@ -121,6 +120,18 @@ def __init__(self, exprs, ispace):
             self._honored[i.root] = frozenset([(ltk, rtk)])
         self._honored = frozendict(self._honored)
 
+    def __reprfuncs__(self):
+        fstrings = []
+        for f in self.fmapper.keys():
+            loc_indices = OrderedSet(*(self.fmapper[f].loc_indices.values()))
+            loc_indices_str = str(list(loc_indices)) if loc_indices else ""
+            fstrings.append("%s%s" % (f.name, loc_indices_str))
+
+        return ",".join(fstrings)
+
+    def __repr__(self):
+        return "<%s(%s)>" % (self.__class__.__name__, self.__reprfuncs__())
+
     def __eq__(self, other):
         return (isinstance(other, HaloScheme) and
                 self._mapper == other._mapper and
diff --git a/devito/passes/iet/mpi.py b/devito/passes/iet/mpi.py
index 204fcc2cae..030dfacfad 100644
--- a/devito/passes/iet/mpi.py
+++ b/devito/passes/iet/mpi.py
@@ -12,7 +12,7 @@
 from devito.mpi.reduction_scheme import DistReduce
 from devito.mpi.routines import HaloExchangeBuilder, ReductionBuilder
 from devito.passes.iet.engine import iet_pass
-from devito.tools import generator, frozendict
+from devito.tools import generator
 
 __all__ = ['mpiize']
 
@@ -94,7 +94,6 @@ def _hoist_invariant(iet):
                 continue
 
             for f, v in hs1.fmapper.items():
-
                 if f not in hs0.functions:
                     continue
 
@@ -114,7 +113,7 @@ def _hoist_invariant(iet):
                         else:
                             raw_loc_indices[d] = v
 
-                    hse = hse._rebuild(loc_indices=frozendict(raw_loc_indices))
+                    hse = hse._rebuild(loc_indices=raw_loc_indices)
                     hs1.halo_scheme.fmapper[f] = hse
 
                     hsmapper[hs1] = hsmapper.get(hs1, hs1.halo_scheme).drop(f)
@@ -348,20 +347,27 @@ def _filter_iter_mapper(iet):
     Given an IET, return a mapper from Iterations to the HaloSpots.
     Additionally, filter out Iterations that are not of interest.
     """
-    iter_mapper = MapNodes(Iteration, HaloSpot, 'immediate').visit(iet)
-    iter_mapper = {k: [hs for hs in v if not hs.halo_scheme.is_void]
-                   for k, v in iter_mapper.items()}
-    iter_mapper = {k: v for k, v in iter_mapper.items() if k is not None}
-    iter_mapper = {k: v for k, v in iter_mapper.items() if len(v) > 1}
+    iter_mapper = {}
+    for k, v in MapNodes(Iteration, HaloSpot, 'immediate').visit(iet).items():
+        filtered_hs = [hs for hs in v if not hs.halo_scheme.is_void]
+        if k is not None and len(filtered_hs) > 1:
+            iter_mapper[k] = filtered_hs
 
     return iter_mapper
 
 
 def _make_cond_mapper(iet):
-    cond_mapper = MapHaloSpots().visit(iet)
-    return {hs: {i for i in v if i.is_Conditional and
-                 not isinstance(i.condition, GuardFactorEq)}
-            for hs, v in cond_mapper.items()}
+
+    cond_mapper = {}
+    for hs, v in MapHaloSpots().visit(iet).items():
+        conditionals = set()
+        for i in v:
+            if i.is_Conditional and not isinstance(i.condition, GuardFactorEq):
+                conditionals.add(i)
+
+        cond_mapper[hs] = conditionals
+
+    return cond_mapper
 
 
 def _check_control_flow(hs0, hs1, cond_mapper):
diff --git a/examples/seismic/tutorials/09_viscoelastic.ipynb b/examples/seismic/tutorials/09_viscoelastic.ipynb
index 5d7cabdd75..0d1d524613 100644
--- a/examples/seismic/tutorials/09_viscoelastic.ipynb
+++ b/examples/seismic/tutorials/09_viscoelastic.ipynb
@@ -457,7 +457,7 @@
    "source": [
     "# References\n",
     "\n",
-    "[1] Johan O. A. Roberston, *et.al.* (1994). \"Viscoelatic finite-difference modeling\" GEOPHYSICS, 59(9), 1444-1456.\n",
+    "[1] Johan O. A. Roberston, *et.al.* (1994). \"Viscoelastic finite-difference modeling\" GEOPHYSICS, 59(9), 1444-1456.\n",
     "\n",
     "\n",
     "[2] https://janth.home.xs4all.nl/Software/fdelmodcManual.pdf"
diff --git a/examples/seismic/viscoelastic/operators.py b/examples/seismic/viscoelastic/operators.py
index fdf1110004..9e0269665a 100644
--- a/examples/seismic/viscoelastic/operators.py
+++ b/examples/seismic/viscoelastic/operators.py
@@ -64,4 +64,4 @@ def ForwardOperator(model, geometry, space_order=4, save=False, **kwargs):
 
     # Substitute spacing terms to reduce flops
     return Operator([u_v, u_r, u_t] + src_rec_expr, subs=model.spacing_map,
-                    name='ViscoElForward', **kwargs)
+                    name='ViscoIsoElasticForward', **kwargs)
diff --git a/tests/test_mpi.py b/tests/test_mpi.py
index 7d28dfe7bc..05a5b4b82d 100644
--- a/tests/test_mpi.py
+++ b/tests/test_mpi.py
@@ -1008,183 +1008,6 @@ def test_avoid_haloupdate_if_distr_but_sequential(self, mode):
         calls = FindNodes(Call).visit(op)
         assert len(calls) == 0
 
-    @pytest.fixture
-    def setup(self):
-        shape = (2,)
-        so = 2
-        tn = 30
-
-        grid = Grid(shape=shape)
-
-        # Velocity and pressure fields
-        v = TimeFunction(name='v', grid=grid, space_order=so)
-        tau = TimeFunction(name='tau', grid=grid, space_order=so)
-
-        # First order elastic-like dependencies equations
-        pde_v = v.dt - (tau.dx)
-        pde_tau = tau.dt - ((v.forward).dx)
-        u_v = Eq(v.forward, solve(pde_v, v.forward))
-        u_tau = Eq(tau.forward, solve(pde_tau, tau.forward))
-
-        # Receiver
-        rec = SparseTimeFunction(name="rec", grid=grid, npoint=1, nt=tn)
-        rec.coordinates.data[:, 0] = np.linspace(0., shape[0], num=1)
-
-        return grid, v, tau, u_v, u_tau, rec
-
-    @pytest.mark.parallel(mode=1)
-    def test_issue_2448_I(self, mode, setup):
-        _, v, tau, u_v, u_tau, rec = setup
-
-        rec_term0 = rec.interpolate(expr=v)
-
-        op0 = Operator([u_v, u_tau, rec_term0])
-
-        calls = [i for i in FindNodes(Call).visit(op0) if isinstance(i, HaloUpdateCall)]
-
-        assert len(calls) == 3
-        assert len(FindNodes(HaloUpdateCall).visit(op0.body.body[1].body[1].body[0])) == 1
-        assert len(FindNodes(HaloUpdateCall).visit(op0.body.body[1].body[1].body[1])) == 2
-        assert calls[0].arguments[0] is v
-        assert calls[1].arguments[0] is tau
-        assert calls[2].arguments[0] is v
-
-    @pytest.mark.parallel(mode=1)
-    def test_issue_2448_II(self, mode, setup):
-        _, v, tau, u_v, u_tau, rec = setup
-
-        rec_term1 = rec.interpolate(expr=v.forward)
-
-        op1 = Operator([u_v, u_tau, rec_term1])
-
-        calls = [i for i in FindNodes(Call).visit(op1) if isinstance(i, HaloUpdateCall)]
-
-        assert len(calls) == 3
-        assert len(FindNodes(HaloUpdateCall).visit(op1.body.body[1].body[0])) == 0
-        assert len(FindNodes(HaloUpdateCall).visit(op1.body.body[1].body[1])) == 3
-        assert calls[0].arguments[0] is tau
-        assert calls[1].arguments[0] is v
-        assert calls[2].arguments[0] is v
-
-    @pytest.mark.parallel(mode=1)
-    def test_issue_2448_III(self, mode, setup):
-        grid, v, tau, u_v, u_tau, rec = setup
-
-        # Additional velocity and pressure fields
-        v2 = TimeFunction(name='v2', grid=grid, space_order=2)
-        tau2 = TimeFunction(name='tau2', grid=grid, space_order=2)
-
-        # First order elastic-like dependencies equations
-        pde_v2 = v2.dt - (tau2.dx)
-        pde_tau2 = tau2.dt - ((v2.forward).dx)
-        u_v2 = Eq(v2.forward, solve(pde_v2, v2.forward))
-        u_tau2 = Eq(tau2.forward, solve(pde_tau2, tau2.forward))
-
-        # Receiver
-        rec2 = SparseTimeFunction(name="rec2", grid=grid, npoint=1, nt=30)
-        rec2.coordinates.data[:, 0] = np.linspace(0., grid.shape[0], num=1)
-
-        rec_term0 = rec.interpolate(expr=v)
-        rec_term2 = rec2.interpolate(expr=v2)
-
-        op2 = Operator([u_v, u_v2, u_tau, u_tau2, rec_term0, rec_term2])
-
-        calls = [i for i in FindNodes(Call).visit(op2) if isinstance(i, HaloUpdateCall)]
-
-        assert len(calls) == 5
-        assert len(FindNodes(HaloUpdateCall).visit(op2.body.body[1].body[1].body[0])) == 2
-        assert len(FindNodes(HaloUpdateCall).visit(op2.body.body[1].body[1].body[1])) == 3
-        assert calls[0].arguments[0] is v
-        assert calls[1].arguments[0] is v2
-        assert calls[2].arguments[0] is tau
-        assert calls[2].arguments[1] is tau2
-        assert calls[3].arguments[0] is v
-        assert calls[4].arguments[0] is v2
-
-    @pytest.mark.parallel(mode=1)
-    def test_issue_2448_IV(self, mode, setup):
-        grid, v, tau, u_v, u_tau, rec = setup
-
-        # Additional velocity and pressure fields
-        v2 = TimeFunction(name='v2', grid=grid, space_order=2)
-        tau2 = TimeFunction(name='tau2', grid=grid, space_order=2)
-
-        # First order elastic-like dependencies equations
-        pde_v2 = v2.dt - (tau2.dx)
-        pde_tau2 = tau2.dt - ((v2.forward).dx)
-        u_v2 = Eq(v2.forward, solve(pde_v2, v2.forward))
-        u_tau2 = Eq(tau2.forward, solve(pde_tau2, tau2.forward))
-
-        # Receiver
-        rec2 = SparseTimeFunction(name="rec2", grid=grid, npoint=1, nt=30)
-        rec2.coordinates.data[:, 0] = np.linspace(0., grid.shape[0], num=1)
-
-        rec_term0 = rec.interpolate(expr=v)
-        rec_term3 = rec2.interpolate(expr=v2.forward)
-
-        op3 = Operator([u_v, u_v2, u_tau, u_tau2, rec_term0, rec_term3])
-
-        calls = [i for i in FindNodes(Call).visit(op3) if isinstance(i, HaloUpdateCall)]
-
-        assert len(calls) == 5
-        assert len(FindNodes(HaloUpdateCall).visit(op3.body.body[1].body[1].body[0])) == 1
-        assert len(FindNodes(HaloUpdateCall).visit(op3.body.body[1].body[1].body[1])) == 4
-        assert calls[0].arguments[0] is v
-        assert calls[1].arguments[0] is tau
-        assert calls[1].arguments[1] is tau2
-        assert calls[2].arguments[0] is v
-        assert calls[3].arguments[0] is v2
-        assert calls[4].arguments[0] is v2
-
-    @pytest.mark.parallel(mode=1)
-    def test_issue_2448_backward(self, mode):
-        '''
-        Similar to test_issue_2448, but with backward instead of forward
-        so that the hoisted halo has different starting point
-        '''
-        shape = (2,)
-        so = 2
-
-        grid = Grid(shape=shape)
-        t = grid.stepping_dim
-
-        tn = 7
-
-        # Velocity and pressure fields
-        v = TimeFunction(name='v', grid=grid, space_order=so)
-        v.data_with_halo[0, :] = 1.
-        v.data_with_halo[1, :] = 3.
-
-        tau = TimeFunction(name='tau', grid=grid, space_order=so)
-        tau.data_with_halo[:] = 1.
-
-        # First order elastic-like dependencies equations
-        pde_v = v.dt - (tau.dx)
-        pde_tau = tau.dt - ((v.backward).dx)
-
-        u_v = Eq(v.backward, solve(pde_v, v))
-        u_tau = Eq(tau.backward, solve(pde_tau, tau))
-
-        # Test two variants of receiver interpolation
-        nrec = 1
-        rec = SparseTimeFunction(name="rec", grid=grid, npoint=nrec, nt=tn)
-        rec.coordinates.data[:, 0] = np.linspace(0., shape[0], num=nrec)
-
-        # Test receiver interpolation 0, here we have a halo exchange hoisted
-        op0 = Operator([u_v] + [u_tau] + rec.interpolate(expr=v))
-
-        calls = [i for i in FindNodes(Call).visit(op0)
-                 if isinstance(i, HaloUpdateCall)]
-
-        # The correct we want
-        assert len(calls) == 3
-        assert len(FindNodes(HaloUpdateCall).visit(op0.body.body[1].body[1].body[0])) == 1
-        assert len(FindNodes(HaloUpdateCall).visit(op0.body.body[1].body[1].body[1])) == 2
-        assert calls[0].arguments[0] is v
-        assert calls[0].arguments[3].args[0] is t.symbolic_max
-        assert calls[1].arguments[0] is tau
-        assert calls[2].arguments[0] is v
-
     @pytest.mark.parallel(mode=1)
     def test_avoid_haloupdate_with_subdims(self, mode):
         grid = Grid(shape=(4,))
@@ -1400,7 +1223,7 @@ def test_avoid_fullmode_if_crossloop_dep(self, mode):
         assert np.all(f.data[:] == 2.)
 
     @pytest.mark.parallel(mode=2)
-    def test_avoid_halopudate_if_flowdep_along_other_dim(self, mode):
+    def test_avoid_haloupdate_if_flowdep_along_other_dim(self, mode):
         grid = Grid(shape=(10,))
         x = grid.dimensions[0]
         t = grid.stepping_dim
@@ -1513,8 +1336,10 @@ def test_merge_haloupdate_if_diff_locindices_v1(self, mode):
 
         * the second and third Eqs cannot be fused in the same loop
 
-        In the IET we end up with *one* HaloSpots, placed right before the
-        second Eq. The third Eq will seamlessy find its halo up-to-date.
+        In the IET we end up with *two* HaloSpots, one placed before the
+        time loop, and one placed before the second Eq. The third Eq,
+        reading from f[t0], will seamlessy find its halo up-to-date,
+        due to the f[t1] being updated in the previous time iteration.
         """
         grid = Grid(shape=(10,))
         x = grid.dimensions[0]
@@ -1545,6 +1370,7 @@ def test_merge_haloupdate_if_diff_locindices_v1(self, mode):
         assert len(FindNodes(HaloUpdateCall).visit(op.body.body[1].body[2])) == 0
 
         op.apply(time_M=1)
+
         glb_pos_map = f.grid.distributor.glb_pos_map
         R = 1e-07  # Can't use np.all due to rounding error at the tails
         if LEFT in glb_pos_map[x]:
@@ -2911,7 +2737,7 @@ def test_adjoint_F_no_omp(self, mode):
         self.run_adjoint_F(3)
 
 
-class TestElastic:
+class TestElasticLike:
 
     @pytest.mark.parallel(mode=[(1, 'diag')])
     def test_elastic_structure(self, mode):
@@ -2927,7 +2753,6 @@ def test_elastic_structure(self, mode):
         mu = Function(name='mu', grid=grid)
         ro = Function(name='b', grid=grid)
 
-        # The receiver
         rec = SparseTimeFunction(name="rec", grid=grid, npoint=1, nt=10)
         rec_term = rec.interpolate(expr=v[0] + v[1])
 
@@ -2945,7 +2770,6 @@ def test_elastic_structure(self, mode):
 
         calls = [i for i in FindNodes(Call).visit(op) if isinstance(i, HaloUpdateCall)]
 
-        # The correct we want
         assert len(calls) == 5
 
         assert len(FindNodes(HaloUpdateCall).visit(op.body.body[1].body[1].body[0])) == 1
@@ -2960,12 +2784,188 @@ def test_elastic_structure(self, mode):
         assert calls[4].arguments[0] is v[0]
         assert calls[4].arguments[1] is v[1]
 
+    @pytest.fixture
+    def setup(self):
+        """
+        This fixture sets up the grid, fields, elastic-like
+        equations and receivers for test_issue_2448_*.
+        """
+        shape = (2,)
+        so = 2
+        tn = 30
+
+        grid = Grid(shape=shape)
+
+        # Velocity and pressure fields
+        v = TimeFunction(name='v', grid=grid, space_order=so)
+        tau = TimeFunction(name='tau', grid=grid, space_order=so)
+
+        # First order elastic-like dependencies equations
+        pde_v = v.dt - (tau.dx)
+        pde_tau = tau.dt - ((v.forward).dx)
+        u_v = Eq(v.forward, solve(pde_v, v.forward))
+        u_tau = Eq(tau.forward, solve(pde_tau, tau.forward))
+
+        rec = SparseTimeFunction(name="rec", grid=grid, npoint=1, nt=tn)
+        rec.coordinates.data[:, 0] = np.linspace(0., shape[0], num=1)
+
+        return grid, v, tau, u_v, u_tau, rec
+
+    @pytest.mark.parallel(mode=1)
+    def test_issue_2448_v0(self, mode, setup):
+        _, v, tau, u_v, u_tau, rec = setup
+
+        rec_term0 = rec.interpolate(expr=v)
+
+        op0 = Operator([u_v, u_tau, rec_term0])
+
+        calls = [i for i in FindNodes(Call).visit(op0) if isinstance(i, HaloUpdateCall)]
+
+        assert len(calls) == 3
+        assert len(FindNodes(HaloUpdateCall).visit(op0.body.body[1].body[1].body[0])) == 1
+        assert len(FindNodes(HaloUpdateCall).visit(op0.body.body[1].body[1].body[1])) == 2
+        assert calls[0].arguments[0] is v
+        assert calls[1].arguments[0] is tau
+        assert calls[2].arguments[0] is v
+
+    @pytest.mark.parallel(mode=1)
+    def test_issue_2448_v1(self, mode, setup):
+        _, v, tau, u_v, u_tau, rec = setup
+
+        rec_term1 = rec.interpolate(expr=v.forward)
+
+        op1 = Operator([u_v, u_tau, rec_term1])
+
+        calls = [i for i in FindNodes(Call).visit(op1) if isinstance(i, HaloUpdateCall)]
+
+        assert len(calls) == 3
+        assert len(FindNodes(HaloUpdateCall).visit(op1.body.body[1].body[0])) == 0
+        assert len(FindNodes(HaloUpdateCall).visit(op1.body.body[1].body[1])) == 3
+        assert calls[0].arguments[0] is tau
+        assert calls[1].arguments[0] is v
+        assert calls[2].arguments[0] is v
+
+    @pytest.mark.parallel(mode=1)
+    def test_issue_2448_v2(self, mode, setup):
+        grid, v, tau, u_v, u_tau, rec = setup
+
+        # Additional velocity and pressure fields
+        v2 = TimeFunction(name='v2', grid=grid, space_order=2)
+        tau2 = TimeFunction(name='tau2', grid=grid, space_order=2)
+
+        # First order elastic-like dependencies equations
+        pde_v2 = v2.dt - (tau2.dx)
+        pde_tau2 = tau2.dt - ((v2.forward).dx)
+        u_v2 = Eq(v2.forward, solve(pde_v2, v2.forward))
+        u_tau2 = Eq(tau2.forward, solve(pde_tau2, tau2.forward))
+
+        rec2 = SparseTimeFunction(name="rec2", grid=grid, npoint=1, nt=30)
+        rec2.coordinates.data[:, 0] = np.linspace(0., grid.shape[0], num=1)
+
+        rec_term0 = rec.interpolate(expr=v)
+        rec_term2 = rec2.interpolate(expr=v2)
+
+        op2 = Operator([u_v, u_v2, u_tau, u_tau2, rec_term0, rec_term2])
+
+        calls = [i for i in FindNodes(Call).visit(op2) if isinstance(i, HaloUpdateCall)]
+
+        assert len(calls) == 5
+        assert len(FindNodes(HaloUpdateCall).visit(op2.body.body[1].body[1].body[0])) == 2
+        assert len(FindNodes(HaloUpdateCall).visit(op2.body.body[1].body[1].body[1])) == 3
+        assert calls[0].arguments[0] is v
+        assert calls[1].arguments[0] is v2
+        assert calls[2].arguments[0] is tau
+        assert calls[2].arguments[1] is tau2
+        assert calls[3].arguments[0] is v
+        assert calls[4].arguments[0] is v2
+
+    @pytest.mark.parallel(mode=1)
+    def test_issue_2448_v3(self, mode, setup):
+        grid, v, tau, u_v, u_tau, rec = setup
+
+        # Additional velocity and pressure fields
+        v2 = TimeFunction(name='v2', grid=grid, space_order=2)
+        tau2 = TimeFunction(name='tau2', grid=grid, space_order=2)
+
+        # First order elastic-like dependencies equations
+        pde_v2 = v2.dt - (tau2.dx)
+        pde_tau2 = tau2.dt - ((v2.forward).dx)
+        u_v2 = Eq(v2.forward, solve(pde_v2, v2.forward))
+        u_tau2 = Eq(tau2.forward, solve(pde_tau2, tau2.forward))
+
+        rec2 = SparseTimeFunction(name="rec2", grid=grid, npoint=1, nt=30)
+        rec2.coordinates.data[:, 0] = np.linspace(0., grid.shape[0], num=1)
+
+        rec_term0 = rec.interpolate(expr=v)
+        rec_term3 = rec2.interpolate(expr=v2.forward)
+
+        op3 = Operator([u_v, u_v2, u_tau, u_tau2, rec_term0, rec_term3])
+
+        calls = [i for i in FindNodes(Call).visit(op3) if isinstance(i, HaloUpdateCall)]
+
+        assert len(calls) == 5
+        assert len(FindNodes(HaloUpdateCall).visit(op3.body.body[1].body[1].body[0])) == 1
+        assert len(FindNodes(HaloUpdateCall).visit(op3.body.body[1].body[1].body[1])) == 4
+        assert calls[0].arguments[0] is v
+        assert calls[1].arguments[0] is tau
+        assert calls[1].arguments[1] is tau2
+        assert calls[2].arguments[0] is v
+        assert calls[3].arguments[0] is v2
+        assert calls[4].arguments[0] is v2
+
+    @pytest.mark.parallel(mode=1)
+    def test_issue_2448_backward(self, mode):
+        '''
+        Similar to test_issue_2448, but with backward instead of forward
+        so that the hoisted halo has different starting point
+        '''
+        shape = (2,)
+        so = 2
+
+        grid = Grid(shape=shape)
+        t = grid.stepping_dim
+
+        tn = 7
+
+        # Velocity and pressure fields
+        v = TimeFunction(name='v', grid=grid, space_order=so)
+        v.data_with_halo[0, :] = 1.
+        v.data_with_halo[1, :] = 3.
+
+        tau = TimeFunction(name='tau', grid=grid, space_order=so)
+        tau.data_with_halo[:] = 1.
+
+        # First order elastic-like dependencies equations
+        pde_v = v.dt - (tau.dx)
+        pde_tau = tau.dt - ((v.backward).dx)
+
+        u_v = Eq(v.backward, solve(pde_v, v))
+        u_tau = Eq(tau.backward, solve(pde_tau, tau))
+
+        # Test two variants of receiver interpolation
+        nrec = 1
+        rec = SparseTimeFunction(name="rec", grid=grid, npoint=nrec, nt=tn)
+        rec.coordinates.data[:, 0] = np.linspace(0., shape[0], num=nrec)
+
+        # Test receiver interpolation 0, here we have a halo exchange hoisted
+        op0 = Operator([u_v] + [u_tau] + rec.interpolate(expr=v))
+
+        calls = [i for i in FindNodes(Call).visit(op0)
+                 if isinstance(i, HaloUpdateCall)]
+
+        assert len(calls) == 3
+        assert len(FindNodes(HaloUpdateCall).visit(op0.body.body[1].body[1].body[0])) == 1
+        assert len(FindNodes(HaloUpdateCall).visit(op0.body.body[1].body[1].body[1])) == 2
+        assert calls[0].arguments[0] is v
+        assert calls[0].arguments[3].args[0] is t.symbolic_max
+        assert calls[1].arguments[0] is tau
+        assert calls[2].arguments[0] is v
+
 
 class TestTTIOp:
 
     @pytest.mark.parallel(mode=1)
     def test_halo_structure(self, mode):
-
         solver = TestTTI().tti_operator(opt='advanced', space_order=8)
         op = solver.op_fwd(save=False)