Skip to content

Commit

Permalink
Add auto and multi-device training to CLI
Browse files Browse the repository at this point in the history
  • Loading branch information
mittagessen committed Oct 18, 2024
1 parent 8ac1163 commit a9b812e
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions kraken/ketos/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,18 @@ def message(msg, **styles):


def to_ptl_device(device: str) -> Tuple[str, Optional[List[int]]]:
    """
    Splits a device specification string into an accelerator name and a
    device selection.

    Accepted forms:

    * ``auto`` -> ``('auto', 'auto')``
    * ``cpu`` / ``mps`` -> ``(name, 'auto')``
    * ``<type>:<idx>[,<type>:<idx>...]`` where ``<type>`` starts with one of
      ``tpu``, ``cuda``, ``hpu``, ``ipu`` -> ``(type, [idx, ...])``; ``cuda``
      is reported as ``gpu``.

    Surrounding whitespace around the whole string, around each
    comma-separated entry, and around the type/index parts is ignored.

    Args:
        device: Device specification string as described above.

    Returns:
        A tuple ``(accelerator, devices)``. ``devices`` is the string
        ``'auto'`` for ``auto``/``cpu``/``mps`` and a list of integer device
        indices otherwise. (The declared return annotation does not cover the
        ``'auto'`` string case; it is kept unchanged for interface
        compatibility.)

    Raises:
        Exception: If the specification is malformed (unknown device type,
            missing or non-integer index) or mixes several device types.
    """
    spec = device.strip()
    if spec == 'auto':
        return 'auto', 'auto'

    entries = [entry.strip() for entry in spec.split(',')]
    if entries[0] in ('cpu', 'mps'):
        # Only the first entry is inspected for index-less device types.
        return entries[0], 'auto'

    if entries[0].startswith(('tpu', 'cuda', 'hpu', 'ipu')):
        parsed = []
        for entry in entries:
            name, sep, index = (part.strip() for part in entry.partition(':'))
            # Reject entries such as 'cuda', 'cuda:' or ':0' with the
            # function's own error instead of leaking an IndexError.
            if not sep or not name or not index:
                raise Exception(f'Invalid device {device} specified')
            parsed.append((name, index))

        if len({name for name, _ in parsed}) > 1:
            raise Exception('Can only use a single type of device at a time.')

        try:
            indices = [int(index) for _, index in parsed]
        except ValueError:
            raise Exception(f'Invalid device {device} specified') from None

        dev = parsed[0][0]
        # 'cuda' is renamed to 'gpu' -- presumably the accelerator name the
        # training framework expects; confirm against the caller.
        if dev == 'cuda':
            dev = 'gpu'
        return dev, indices

    raise Exception(f'Invalid device {device} specified')

0 comments on commit a9b812e

Please sign in to comment.