diff --git a/device_smi/base.py b/device_smi/base.py index 03ca194..5293515 100644 --- a/device_smi/base.py +++ b/device_smi/base.py @@ -1,5 +1,6 @@ import subprocess from abc import abstractmethod +from typing import Optional, Callable class BaseDevice: @@ -67,7 +68,7 @@ def __repr__(self): return self.__str__() -def _run(args, line_start: str = None) -> str: +def _run(args, line_start: Optional[str] = None) -> str: result = subprocess.run( args, stdout=subprocess.PIPE, diff --git a/device_smi/cpu.py b/device_smi/cpu.py index f82ac39..a7afa00 100644 --- a/device_smi/cpu.py +++ b/device_smi/cpu.py @@ -21,6 +21,7 @@ def __init__(self, cls): command_result = _run(["wmic", "cpu", "get", "manufacturer,name,numberofcores,numberoflogicalprocessors", "/format:csv"]).strip() command_result = re.sub(r'\n+', '\n', command_result) # windows uses \n\n result = command_result.split("\n")[1].split(",") + cpu_count = command_result.count('\n') model = result[2].strip() cpu_cores = int(result[3]) @@ -30,6 +31,7 @@ def __init__(self, cls): command_result = _run(["wmic", "os", "get", "TotalVisibleMemorySize", "/Value", "/format:csv"]).strip() command_result = re.sub(r'\n+', '\n', command_result) result = command_result.split("\n")[1].split(",") + mem_total = int(result[1]) elif platform.system().lower() == 'darwin': model = (_run(["sysctl", "-n", "machdep.cpu.brand_string"]).replace("Apple", "").strip()) diff --git a/device_smi/device.py b/device_smi/device.py index 3813032..62b11cd 100644 --- a/device_smi/device.py +++ b/device_smi/device.py @@ -21,6 +21,7 @@ class Device: def __init__(self, device): # CPU/GPU Device + self.memory_total = None self.type = None self.features = None self.vendor = None @@ -86,9 +87,6 @@ def info(self): ) return self - def memory_total(self): - return self.memory_total - def memory_used(self) -> int: return self.device.metrics().memory_used diff --git a/tests/cpu.py b/tests/cpu.py index 94bb492..e4a3e67 100644 --- a/tests/cpu.py +++ b/tests/cpu.py @@ -10,7 +10,7 @@ assert i not in dev.model, f"{i} should be removed in model" assert dev.vendor in "amd, intel, apple", f"check vendor: {dev.vendor}" -assert dev.memory_total() > 10, f"wrong memory size: {dev.memory_total()}" +assert dev.memory_total > 10, f"wrong memory size: {dev.memory_total}" assert dev.features is not None memory_used = dev.memory_used()