Skip to content

Commit

Permalink
fix #377 (#379)
Browse files Browse the repository at this point in the history
  • Loading branch information
wyz5864 authored Oct 30, 2023
1 parent 90729be commit 2dab8b3
Showing 1 changed file with 24 additions and 2 deletions.
26 changes: 24 additions & 2 deletions dipu/tests/python/unittests/test_amp_init_dtype_multithread.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,32 @@
class TestInitAMPDtypeMultiThread(TestCase):
NUM_THREADS = 10
TIMEOUT = 5
DTYPES = [torch.int32, torch.int64, torch.float16, torch.float32]
DTYPES = [torch.bfloat16, torch.float16, torch.float32, torch.float64]

def _run_multithread_test(self, f, args=(), kwargs={}):
threads = [Thread(target=f, args=args, kwargs=kwargs) for _ in range(self.NUM_THREADS)]

class PropagatingThread(Thread):
'''Helper class to propagate exception from child
thread to main thread on join.
Reference: https://stackoverflow.com/a/31614591/5602957
Reference: https://github.com/pytorch/pytorch/blob/c263bd43e8e8502d4726643bc6fd046f0130ac0e/test/test_autograd.py#L10221-L10239
'''

def run(self):
self.exception = None
try:
self.ret = super().run()
except Exception as e:
self.exception = e

def join(self, timeout=None):
super().join(timeout)
if self.exception:
raise self.exception from self.exception
return self.ret

threads = [PropagatingThread(target=f, args=args, kwargs=kwargs) for _ in range(self.NUM_THREADS)]
[t.start() for t in threads]
[t.join(self.TIMEOUT) for t in threads]
self.assertTrue(all(not t.is_alive() for t in threads))
Expand Down

0 comments on commit 2dab8b3

Please sign in to comment.