diff --git a/nmigen_soc/csr/bus.py b/nmigen_soc/csr/bus.py
index 0f444ee..1a63770 100644
--- a/nmigen_soc/csr/bus.py
+++ b/nmigen_soc/csr/bus.py
@@ -293,7 +293,9 @@ def elaborate(self, platform):
                             m.d.sync += shadow_en.eq(self.bus.r_stb << chunk_offset)
 
                         if elem.access.writable():
-                            if chunk_addr == elem_end - 1:
+                            # Write when chunk_addr is data_width / 8 addresses below elem_end,
+                            # i.e. the start of the last chunk aligned to the bus data width.
+                            if chunk_addr == elem_end - (1 << (log2_int(self.bus.data_width) - 3)):
                                 # Delay by 1 cycle, avoiding combinatorial paths through
                                 # the CSR bus and into CSR registers.
                                 m.d.sync += elem.w_stb.eq(self.bus.w_stb)
diff --git a/nmigen_soc/test/test_csr_bus.py b/nmigen_soc/test/test_csr_bus.py
index 84e115f..614f057 100644
--- a/nmigen_soc/test/test_csr_bus.py
+++ b/nmigen_soc/test/test_csr_bus.py
@@ -7,6 +7,7 @@
 
 from ..csr.bus import *
 from ..memory import MemoryMap
+from .utils import MockRegister
 
 
 class ElementTestCase(unittest.TestCase):
@@ -306,6 +307,61 @@ def sim_test():
         with sim.write_vcd(vcd_file=open("test.vcd", "w")):
             sim.run()
 
+class MultiplexerAlignedWideTestCase(unittest.TestCase):  # Simulates a 32-bit Multiplexer (alignment=2) with two 32-bit registers.
+    def setUp(self):  # Build a fresh device under test for each test method.
+        self.dut = Multiplexer(addr_width=16, data_width=32, alignment=2)
+
+    def test_sim(self):  # Write both registers over the bus, then read the first one back.
+        bus = self.dut.bus
+
+        elem_1 = MockRegister("elem1", 32)  # first register; the test addresses it at bus address 0
+        elem_1_rw = elem_1.element
+        self.dut.add(elem_1_rw)
+        elem_2 = MockRegister("elem2", 32)  # second register; the test addresses it at bus address 4
+        elem_2_rw = elem_2.element
+        self.dut.add(elem_2_rw)
+
+        def sim_test():  # Simulation process: drives the CSR bus and checks the element-side signals.
+            # Write 0x01020304 to 0x00
+            yield bus.addr.eq(0)
+            yield bus.w_stb.eq(1)
+            yield bus.w_data.eq(0x01020304)
+            yield
+            self.assertEqual((yield elem_1_rw.w_stb), 0)  # element strobe is not expected yet at this point
+            yield bus.w_stb.eq(0)
+            yield
+            self.assertEqual((yield elem_1_rw.w_stb), 1)  # strobe reaches the element after the bus strobe
+            self.assertEqual((yield elem_1_rw.w_data), 0x01020304)
+            yield
+            self.assertEqual((yield elem_1_rw.w_stb), 0)  # strobe must drop again once bus.w_stb is low
+
+            # Write 0x11223344 to 0x04
+            yield bus.addr.eq(4)
+            yield bus.w_stb.eq(1)
+            yield bus.w_data.eq(0x11223344)
+            yield
+            self.assertEqual((yield elem_2_rw.w_stb), 0)  # same delayed-strobe check for the second register
+            yield bus.w_stb.eq(0)
+            yield
+            self.assertEqual((yield elem_2_rw.w_stb), 1)
+            self.assertEqual((yield elem_2_rw.w_data), 0x11223344)
+
+            # Read from 0x00
+            yield bus.addr.eq(0)
+            yield bus.r_stb.eq(1)
+            yield
+            self.assertEqual((yield elem_1_rw.r_stb), 1)  # read strobe is expected on the element here
+            yield bus.r_stb.eq(0)
+            yield
+            self.assertEqual((yield bus.r_data), 0x01020304)  # value written to address 0 earlier in this test
+
+        m = Module()  # top-level module holding the multiplexer and both mock registers
+        m.submodules += self.dut, elem_1, elem_2
+        sim = Simulator(m)
+        sim.add_clock(1e-6)
+        sim.add_sync_process(sim_test)
+        with sim.write_vcd(vcd_file=open("test.vcd", "w")):  # NOTE(review): this file handle is never explicitly closed
+            sim.run()
 
 class DecoderTestCase(unittest.TestCase):
     def setUp(self):
diff --git a/nmigen_soc/test/test_csr_wishbone.py b/nmigen_soc/test/test_csr_wishbone.py
index 3911b96..b6afbcb 100644
--- a/nmigen_soc/test/test_csr_wishbone.py
+++ b/nmigen_soc/test/test_csr_wishbone.py
@@ -6,27 +6,7 @@
 
 from .. import csr
 from ..csr.wishbone import *
-
-
-class MockRegister(Elaboratable):
-    def __init__(self, width):
-        self.element = csr.Element(width, "rw")
-        self.r_count = Signal(8)
-        self.w_count = Signal(8)
-        self.data    = Signal(width)
-
-    def elaborate(self, platform):
-        m = Module()
-
-        with m.If(self.element.r_stb):
-            m.d.sync += self.r_count.eq(self.r_count + 1)
-        m.d.comb += self.element.r_data.eq(self.data)
-
-        with m.If(self.element.w_stb):
-            m.d.sync += self.w_count.eq(self.w_count + 1)
-            m.d.sync += self.data.eq(self.element.w_data)
-
-        return m
+from .utils import MockRegister
 
 
 class WishboneCSRBridgeTestCase(unittest.TestCase):
@@ -42,9 +22,9 @@ def test_wrong_csr_bus_data_width(self):
 
     def test_narrow(self):
         mux   = csr.Multiplexer(addr_width=10, data_width=8)
-        reg_1 = MockRegister(8)
+        reg_1 = MockRegister("reg_1", 8)
         mux.add(reg_1.element)
-        reg_2 = MockRegister(16)
+        reg_2 = MockRegister("reg_2", 16)
         mux.add(reg_2.element)
         dut   = WishboneCSRBridge(mux.bus)
 
@@ -143,7 +123,7 @@ def sim_test():
 
     def test_wide(self):
         mux = csr.Multiplexer(addr_width=10, data_width=8)
-        reg = MockRegister(32)
+        reg = MockRegister("reg", 32)
         mux.add(reg.element)
         dut = WishboneCSRBridge(mux.bus, data_width=32)
 
diff --git a/nmigen_soc/test/utils.py b/nmigen_soc/test/utils.py
new file mode 100644
index 0000000..69c3da2
--- /dev/null
+++ b/nmigen_soc/test/utils.py
@@ -0,0 +1,22 @@
+from nmigen import *
+from .. import csr
+
+class MockRegister(Elaboratable):  # Test helper: a "rw" csr.Element backed by a data signal, with access counters.
+    def __init__(self, name, width):  # name: passed to csr.Element as its name; width: element and data width
+        self.element = csr.Element(width, "rw", name=name)  # the CSR element that tests add to a multiplexer
+        self.r_count = Signal(8)  # 8-bit counter, incremented in the sync domain while element.r_stb is set
+        self.w_count = Signal(8)  # 8-bit counter, incremented in the sync domain while element.w_stb is set
+        self.data    = Signal(width)  # stored value: drives element.r_data, loaded from element.w_data
+
+    def elaborate(self, platform):  # Build and return the register logic; `platform` is not used.
+        m = Module()
+
+        with m.If(self.element.r_stb):  # count read strobes
+            m.d.sync += self.r_count.eq(self.r_count + 1)
+        m.d.comb += self.element.r_data.eq(self.data)  # r_data follows `data` combinationally, regardless of r_stb
+
+        with m.If(self.element.w_stb):  # on a write strobe: count it and latch the written value
+            m.d.sync += self.w_count.eq(self.w_count + 1)
+            m.d.sync += self.data.eq(self.element.w_data)
+
+        return m