From dc3673a171c8e3b266c59b90f2f6ac64e2b957fd Mon Sep 17 00:00:00 2001
From: Qin-sx
Date: Tue, 19 Nov 2024 20:22:45 +0800
Subject: [PATCH] Added an RFC file for the memory stats APIs

new file: rfcs/APIs/20241119_api_design_for_reset_peak_memory_stats_reset_max_memory_allocated_memory_stats.md
---
 ...reset_max_memory_allocated_memory_stats.md | 631 ++++++++++++++++++
 1 file changed, 631 insertions(+)
 create mode 100644 rfcs/APIs/20241119_api_design_for_reset_peak_memory_stats_reset_max_memory_allocated_memory_stats.md

diff --git a/rfcs/APIs/20241119_api_design_for_reset_peak_memory_stats_reset_max_memory_allocated_memory_stats.md b/rfcs/APIs/20241119_api_design_for_reset_peak_memory_stats_reset_max_memory_allocated_memory_stats.md
new file mode 100644
index 000000000..eaf615ec5
--- /dev/null
+++ b/rfcs/APIs/20241119_api_design_for_reset_peak_memory_stats_reset_max_memory_allocated_memory_stats.md
@@ -0,0 +1,631 @@
+# paddle.device.cuda.reset_peak_memory_stats / paddle.device.cuda.reset_max_memory_allocated / paddle.device.cuda.memory_stats Design Doc
+
+| API name | paddle.device.cuda.reset_peak_memory_stats / paddle.device.cuda.reset_max_memory_allocated / paddle.device.cuda.memory_stats |
+| ------------------------------------------------------------------ | ---------------------------------------------------------------------------------------------------------------------------- |
+| Author | Qin-sx |
+| Submission date | 2024-11-19 |
+| Version | V1.0 |
+| Paddle version dependency | develop |
+| File name | 20241119_api_design_for_reset_peak_memory_stats_reset_max_memory_allocated_memory_stats.md |
+
+# 1. Overview
+
+## 1. Background
+
+https://github.com/PaddlePaddle/community/blob/master/hackathon/hackathon_7th/%E3%80%90Hackathon%207th%E3%80%91%E4%B8%AA%E4%BA%BA%E6%8C%91%E6%88%98%E8%B5%9B%E2%80%94%E6%A1%86%E6%9E%B6%E5%BC%80%E5%8F%91%E4%BB%BB%E5%8A%A1%E5%90%88%E9%9B%86.md#no21-%E4%B8%BA-paddle-%E6%96%B0%E5%A2%9E-reset_peak_memory_statsreset_max_memory_allocatedmemory_stats-api
+
+## 2. Goals
+
+Add the following three capabilities for CUDA tensor types to the `paddle.device.cuda` package:
+
+1. **Reset the peak statistics of the CUDA memory allocator**: add a new API `reset_peak_memory_stats` under `paddle.device.cuda` that resets the peak statistics tracked by the CUDA memory allocator.
+
+2. **Reset the starting point for tracking maximum GPU memory usage**: add a new API `reset_max_memory_allocated` under `paddle.device.cuda` that resets the starting point for tracking the maximum GPU memory occupied by tensors on a given device.
+
+3. **Query the CUDA memory allocator statistics**: add a new API `memory_stats` under `paddle.device.cuda` that returns a dictionary of CUDA memory allocator statistics for a given device.
+
+## 3. Significance
+
+Adding the `paddle.device.cuda.reset_peak_memory_stats`, `paddle.device.cuda.reset_max_memory_allocated`, and `paddle.device.cuda.memory_stats` methods enriches the Paddle API surface.
+
+# 2. Current State in Paddle
+
+Paddle currently provides several APIs for CUDA device-side memory information, including:
+
+- `max_memory_allocated`: returns the peak amount of GPU memory allocated to tensors on the given device. [API docs](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/device/cuda/max_memory_allocated_cn.html#cn-api-paddle-device-cuda-max-memory-allocated)
+- `max_memory_reserved`: returns the peak amount of GPU memory managed by the allocator on the given device. [API docs](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/device/cuda/max_memory_reserved_cn.html#cn-api-paddle-device-cuda-max-memory-reserved)
+- `memory_allocated`: returns the amount of GPU memory currently allocated to tensors on the given device. [API docs](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/device/cuda/memory_allocated_cn.html#cn-api-paddle-device-cuda-memory-allocated)
+- `memory_reserved`: returns the amount of GPU memory currently managed by the allocator on the given device. [API docs](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/device/cuda/memory_reserved_cn.html#cn-api-paddle-device-cuda-memory-reserved)
+
+However, Paddle does not yet provide an API such as `memory_stats` that exposes more detailed memory usage statistics, nor any API to reset the tracked peak statistics or `max_memory_allocated`. To round out its memory management tooling, Paddle should add the three APIs `reset_peak_memory_stats`, `reset_max_memory_allocated`, and `memory_stats`. They will let developers monitor and control memory usage more precisely, as sketched below.
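+
+A hypothetical usage sketch of the three proposed APIs (the reset and stats functions do not exist yet; their names follow this design, and the `"memory.allocated.current"`-style keys are placeholders, not a settled format):
+
+```python
+import paddle
+
+# Phase 1: a large tensor establishes a memory peak.
+x = paddle.ones([1024, 1024], dtype="float32")  # ~4 MiB on GPU 0
+del x
+print(paddle.device.cuda.max_memory_allocated(0))  # existing API: phase-1 peak
+
+# Proposed: reset the peak so the next phase is measured independently.
+paddle.device.cuda.reset_peak_memory_stats(0)
+
+# Phase 2: a smaller allocation now defines the new peak.
+y = paddle.ones([256, 256], dtype="float32")
+print(paddle.device.cuda.max_memory_allocated(0))  # peak of phase 2 only
+
+# Proposed: inspect detailed allocator statistics as a dictionary.
+stats = paddle.device.cuda.memory_stats(0)
+print(stats.get("memory.allocated.current"))  # placeholder key
+```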
+
+# 3. Industry Survey
+
+## PyTorch
+
+### Implementation of `cuda.reset_peak_memory_stats`
+
+#### Python interface
+
+The `reset_peak_memory_stats` function is implemented mainly via `torch._C._cuda_resetPeakMemoryStats`.
+
+```python
+def reset_peak_memory_stats(device: Union[Device, int] = None) -> None:
+    r"""Reset the "peak" stats tracked by the CUDA memory allocator.
+
+    See :func:`~torch.cuda.memory_stats` for details. Peak stats correspond to the
+    `"peak"` key in each individual stat dict.
+
+    Args:
+        device (torch.device or int, optional): selected device. Returns
+            statistic for the current device, given by :func:`~torch.cuda.current_device`,
+            if :attr:`device` is ``None`` (default).
+
+    .. note::
+        See :ref:`cuda-memory-management` for more details about GPU memory
+        management.
+    """
+    device = _get_device_index(device, optional=True)
+    return torch._C._cuda_resetPeakMemoryStats(device)
+```
+
+#### C++ implementation
+
+`torch._C._cuda_resetPeakMemoryStats` lives in `torch/csrc/cuda/Module.cpp` and is registered into the Python module through the Python C API.
+
+```C++
+static struct PyMethodDef _THCPModule_methods[] = {
+    {"_cuda_resetPeakMemoryStats",
+     THCPModule_resetPeakMemoryStats,
+     METH_O,
+     nullptr},
+    // others...
+};
+```
+
+`THCPModule_resetPeakMemoryStats` is implemented mainly by calling `c10::cuda::CUDACachingAllocator::resetPeakStats`.
+
+```C++
+PyObject* THCPModule_resetPeakMemoryStats(PyObject* _unused, PyObject* arg) {
+  HANDLE_TH_ERRORS
+  TORCH_CHECK(
+      THPUtils_checkLong(arg), "invalid argument to reset_peak_memory_stats");
+  const auto device_index = THPUtils_unpackDeviceIndex(arg);
+  c10::cuda::CUDACachingAllocator::resetPeakStats(device_index);
+  END_HANDLE_TH_ERRORS
+  Py_RETURN_NONE;
+}
+```
+
+`c10::cuda::CUDACachingAllocator::resetPeakStats` mainly calls the `resetPeakStats` method of the `NativeCachingAllocator` class.
+
+```C++
+C10_CUDA_API extern std::atomic<CUDAAllocator*> allocator;
+
+inline CUDAAllocator* get() {
+  return allocator.load();
+}
+
+inline void resetPeakStats(c10::DeviceIndex device) {
+  return get()->resetPeakStats(device);
+}
+```
+
+`NativeCachingAllocator::resetPeakStats` mainly calls the `resetPeakStats` method of the `DeviceCachingAllocator` class.
+
+```C++
+class NativeCachingAllocator : public CUDAAllocator {
+ public:
+  std::vector<std::unique_ptr<DeviceCachingAllocator>> device_allocator;
+
+  void resetPeakStats(c10::DeviceIndex device) override {
+    assertValidDevice(device);
+    device_allocator[device]->resetPeakStats();
+  }
+};
+```
+
+`DeviceCachingAllocator::resetPeakStats` resets the `peak` value of every relevant statistic in `DeviceStats` to its `current` value.
+
+```C++
+class DeviceCachingAllocator {
+ private:
+  // lock around all operations
+  mutable std::recursive_mutex mutex;
+
+  // device statistics
+  DeviceStats stats;
+
+  void resetPeakStats() {
+    std::lock_guard<std::recursive_mutex> lock(mutex);
+
+    for (const auto statType :
+         c10::irange(static_cast<size_t>(StatType::NUM_TYPES))) {
+      stats.allocation[statType].reset_peak();
+      stats.segment[statType].reset_peak();
+      stats.active[statType].reset_peak();
+      stats.inactive_split[statType].reset_peak();
+      stats.allocated_bytes[statType].reset_peak();
+      stats.reserved_bytes[statType].reset_peak();
+      stats.active_bytes[statType].reset_peak();
+      stats.inactive_split_bytes[statType].reset_peak();
+      stats.requested_bytes[statType].reset_peak();
+    }
+    stats.oversize_allocations.reset_peak();
+    stats.oversize_segments.reset_peak();
+  }
+};
+```
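+
+For context, the per-metric bookkeeping that `reset_peak()` operates on is roughly the following (a simplified sketch, not the verbatim `Stat` struct from `c10`):
+
+```C++
+#include <cstdint>
+
+// Simplified model of one allocator statistic (e.g. allocated_bytes for a pool).
+struct Stat {
+  int64_t current = 0;    // value right now
+  int64_t peak = 0;       // maximum value observed since the last reset
+  int64_t allocated = 0;  // historical total increase
+  int64_t freed = 0;      // historical total decrease
+
+  void increase(int64_t amount) {
+    current += amount;
+    allocated += amount;
+    if (current > peak) peak = current;  // track the high-water mark
+  }
+
+  void decrease(int64_t amount) {
+    current -= amount;
+    freed += amount;
+  }
+
+  // Resetting the peak simply re-bases it on the current value;
+  // current/allocated/freed are left untouched.
+  void reset_peak() { peak = current; }
+};
+```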
+
+### Implementation of `cuda.reset_max_memory_allocated`
+
+#### Python interface
+
+In PyTorch, the `reset_max_memory_allocated` function is implemented by calling `reset_peak_memory_stats`, i.e. it resets all peak memory statistics.
+
+```python
+def reset_max_memory_allocated(device: Union[Device, int] = None) -> None:
+    r"""Reset the starting point in tracking maximum GPU memory occupied by tensors for a given device.
+
+    See :func:`~torch.cuda.max_memory_allocated` for details.
+
+    Args:
+        device (torch.device or int, optional): selected device. Returns
+            statistic for the current device, given by :func:`~torch.cuda.current_device`,
+            if :attr:`device` is ``None`` (default).
+
+    .. warning::
+        This function now calls :func:`~torch.cuda.reset_peak_memory_stats`, which resets
+        /all/ peak memory stats.
+
+    .. note::
+        See :ref:`cuda-memory-management` for more details about GPU memory
+        management.
+    """
+    warnings.warn(
+        "torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, "
+        "which resets /all/ peak memory stats.",
+        FutureWarning,
+    )
+    return reset_peak_memory_stats(device=device)
+```
+
+### Implementation of `cuda.memory_stats`
+
+#### Python interface
+
+The `memory_stats` function collects the memory-management information mainly through `memory_stats_as_nested_dict`.
+
+```python
+def memory_stats(device: Union[Device, int] = None) -> Dict[str, Any]:
+    r"""Return a dictionary of CUDA memory allocator statistics for a given device.
+
+    The return value of this function is a dictionary of statistics, each of
+    which is a non-negative integer.
+
+    Core statistics:
+
+    - ``"allocated.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
+      number of allocation requests received by the memory allocator.
+    - ``"allocated_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
+      amount of allocated memory.
+    - ``"segment.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
+      number of reserved segments from ``cudaMalloc()``.
+    - ``"reserved_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
+      amount of reserved memory.
+    - ``"active.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
+      number of active memory blocks.
+    - ``"active_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
+      amount of active memory.
+    - ``"inactive_split.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
+      number of inactive, non-releasable memory blocks.
+    - ``"inactive_split_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
+      amount of inactive, non-releasable memory.
+
+    For these core statistics, values are broken down as follows.
+
+    Pool type:
+
+    - ``all``: combined statistics across all memory pools.
+    - ``large_pool``: statistics for the large allocation pool
+      (as of October 2019, for size >= 1MB allocations).
+    - ``small_pool``: statistics for the small allocation pool
+      (as of October 2019, for size < 1MB allocations).
+
+    Metric type:
+
+    - ``current``: current value of this metric.
+    - ``peak``: maximum value of this metric.
+    - ``allocated``: historical total increase in this metric.
+    - ``freed``: historical total decrease in this metric.
+
+    In addition to the core statistics, we also provide some simple event
+    counters:
+
+    - ``"num_alloc_retries"``: number of failed ``cudaMalloc`` calls that
+      result in a cache flush and retry.
+    - ``"num_ooms"``: number of out-of-memory errors thrown.
+    - ``"num_sync_all_streams"``: number of ``synchronize_and_free_events`` calls.
+    - ``"num_device_alloc"``: number of CUDA allocation calls. This includes both
+      cuMemMap and cudaMalloc.
+    - ``"num_device_free"``: number of CUDA free calls. This includes both cuMemUnmap
+      and cudaFree.
+
+    The caching allocator can be configured via ENV to not split blocks larger than a
+    defined size (see Memory Management section of the Cuda Semantics documentation).
+    This helps avoid memory fragmentation but may have a performance
+    penalty. Additional outputs to assist with tuning and evaluating impact:
+
+    - ``"max_split_size"``: blocks above this size will not be split.
+    - ``"oversize_allocations.{current,peak,allocated,freed}"``:
+      number of over-size allocation requests received by the memory allocator.
+    - ``"oversize_segments.{current,peak,allocated,freed}"``:
+      number of over-size reserved segments from ``cudaMalloc()``.
+
+    The caching allocator can be configured via ENV to round memory allocations in order
+    to reduce fragmentation. Sometimes the overhead from rounding can be higher than
+    the fragmentation it helps reduce. The following stat can be used to check if
+    rounding adds too much overhead:
+
+    - ``"requested_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
+      memory requested by client code, compare this with allocated_bytes to check if
+      allocation rounding adds too much overhead.
+
+    Args:
+        device (torch.device or int, optional): selected device. Returns
+            statistics for the current device, given by :func:`~torch.cuda.current_device`,
+            if :attr:`device` is ``None`` (default).
+
+    .. note::
+        See :ref:`cuda-memory-management` for more details about GPU memory
+        management.
+
+    .. note::
+        With :ref:`backend:cudaMallocAsync`, some stats are not
+        meaningful, and are always reported as zero.
+    """
+    result = []
+
+    def _recurse_add_to_result(prefix, obj):
+        if isinstance(obj, dict):
+            if len(prefix) > 0:
+                prefix += "."
+            for k, v in obj.items():
+                _recurse_add_to_result(prefix + k, v)
+        else:
+            result.append((prefix, obj))
+
+    stats = memory_stats_as_nested_dict(device=device)
+    _recurse_add_to_result("", stats)
+    result.sort()
+
+    return collections.OrderedDict(result)
+```
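+
+The `_recurse_add_to_result` helper above flattens the nested statistics dictionary into dotted keys. A minimal standalone demonstration of that transformation (illustrative values, not real measurements):
+
+```python
+import collections
+
+def flatten(prefix, obj, result):
+    # Walk nested dicts, joining keys with "."; leaves become (key, value) pairs.
+    if isinstance(obj, dict):
+        if len(prefix) > 0:
+            prefix += "."
+        for k, v in obj.items():
+            flatten(prefix + k, v, result)
+    else:
+        result.append((prefix, obj))
+
+nested = {"allocated_bytes": {"all": {"current": 512, "peak": 2048}}}
+result = []
+flatten("", nested, result)
+result.sort()
+print(collections.OrderedDict(result))
+# OrderedDict([('allocated_bytes.all.current', 512), ('allocated_bytes.all.peak', 2048)])
+```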
+
+`memory_stats_as_nested_dict` is in turn implemented via `torch._C._cuda_memoryStats`.
+
+```python
+def memory_stats_as_nested_dict(device: Union[Device, int] = None) -> Dict[str, Any]:
+    r"""Return the result of :func:`~torch.cuda.memory_stats` as a nested dictionary."""
+    if not is_initialized():
+        return {}
+    device = _get_device_index(device, optional=True)
+    return torch._C._cuda_memoryStats(device)
+```
+
+#### C++ implementation
+
+`torch._C._cuda_memoryStats` lives in `torch/csrc/cuda/Module.cpp` and is registered into the Python module through the Python C API.
+
+```C++
+static struct PyMethodDef _THCPModule_methods[] = {
+    {"_cuda_memoryStats", THCPModule_memoryStats, METH_O, nullptr},
+    // others...
+};
+```
+
+`THCPModule_memoryStats` mainly packs the `DeviceStats` information into a dictionary and returns it.
+
+```C++
+PyObject* THCPModule_memoryStats(PyObject* _unused, PyObject* arg) {
+  HANDLE_TH_ERRORS
+  TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to memory_allocated");
+  const auto device_index = THPUtils_unpackDeviceIndex(arg);
+
+  using c10::CachingDeviceAllocator::DeviceStats;
+  using c10::CachingDeviceAllocator::Stat;
+  using c10::CachingDeviceAllocator::StatArray;
+  using c10::CachingDeviceAllocator::StatType;
+
+  const auto statToDict = [](const Stat& stat) {
+    py::dict dict;
+
+    dict["current"] = stat.current;
+    dict["peak"] = stat.peak;
+    dict["allocated"] = stat.allocated;
+    dict["freed"] = stat.freed;
+    return dict;
+  };
+
+  const auto statArrayToDict = [=](const StatArray& statArray) {
+    const std::array<const char*, static_cast<size_t>(StatType::NUM_TYPES)>
+        statTypeNames = {"all", "small_pool", "large_pool"};
+    py::dict dict;
+    for (const auto i : c10::irange(statTypeNames.size())) {
+      dict[statTypeNames[i]] = statToDict(statArray[i]);
+    }
+    return dict;
+  };
+
+  const DeviceStats stats =
+      c10::cuda::CUDACachingAllocator::getDeviceStats(device_index);
+
+  py::dict result;
+  result["num_alloc_retries"] = stats.num_alloc_retries;
+  result["num_ooms"] = stats.num_ooms;
+  result["max_split_size"] = stats.max_split_size;
+  result["num_sync_all_streams"] = stats.num_sync_all_streams;
+  result["num_device_alloc"] = stats.num_device_alloc;
+  result["num_device_free"] = stats.num_device_free;
+  result["allocation"] = statArrayToDict(stats.allocation);
+  result["segment"] = statArrayToDict(stats.segment);
+  result["active"] = statArrayToDict(stats.active);
+  result["inactive_split"] = statArrayToDict(stats.inactive_split);
+  result["allocated_bytes"] = statArrayToDict(stats.allocated_bytes);
+  result["reserved_bytes"] = statArrayToDict(stats.reserved_bytes);
+  result["active_bytes"] = statArrayToDict(stats.active_bytes);
+  result["inactive_split_bytes"] = statArrayToDict(stats.inactive_split_bytes);
+  result["requested_bytes"] = statArrayToDict(stats.requested_bytes);
+  result["oversize_allocations"] = statToDict(stats.oversize_allocations);
+  result["oversize_segments"] = statToDict(stats.oversize_segments);
+
+  return result.release().ptr();
+  END_HANDLE_TH_ERRORS
+}
+```
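+
+Putting the pieces together, a typical end-user call of the PyTorch API looks like this (the dotted key names are documented in the docstring quoted above):
+
+```python
+import torch
+
+if torch.cuda.is_available():
+    x = torch.ones(1024, 1024, device="cuda")
+    stats = torch.cuda.memory_stats(0)
+    # Flattened, dotted keys as produced by _recurse_add_to_result:
+    print(stats["allocated_bytes.all.current"])
+    print(stats["allocated_bytes.all.peak"])
+```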
+
+## TensorFlow
+
+TensorFlow's main function for GPU memory usage information is `reset_memory_stats`.
+
+### Implementation of `reset_memory_stats`
+
+#### Python interface
+
+```python
+@tf_export('config.experimental.reset_memory_stats')
+def reset_memory_stats(device):
+  """Resets the tracked memory stats for the chosen device.
+
+  This function sets the tracked peak memory for a device to the device's
+  current memory usage. This allows you to measure the peak memory usage for a
+  specific part of your program. For example:
+
+  >>> if tf.config.list_physical_devices('GPU'):
+  ...   # Sets the peak memory to the current memory.
+  ...   tf.config.experimental.reset_memory_stats('GPU:0')
+  ...   # Creates the first peak memory usage.
+  ...   x1 = tf.ones(1000 * 1000, dtype=tf.float64)
+  ...   del x1 # Frees the memory referenced by `x1`.
+  ...   peak1 = tf.config.experimental.get_memory_info('GPU:0')['peak']
+  ...   # Sets the peak memory to the current memory again.
+  ...   tf.config.experimental.reset_memory_stats('GPU:0')
+  ...   # Creates the second peak memory usage.
+  ...   x2 = tf.ones(1000 * 1000, dtype=tf.float32)
+  ...   del x2
+  ...   peak2 = tf.config.experimental.get_memory_info('GPU:0')['peak']
+  ...   assert peak2 < peak1  # tf.float32 consumes less memory than tf.float64.
+
+  Currently only supports GPU and TPU. If called on a CPU device, an exception
+  will be raised.
+
+  Args:
+    device: Device string to reset the memory stats, e.g. `"GPU:0"`, `"TPU:0"`.
+      See https://www.tensorflow.org/api_docs/python/tf/device for specifying
+      device strings.
+
+  Raises:
+    ValueError: No device found with the device name, like '"nonexistent"'.
+    ValueError: Invalid device name, like '"GPU"', '"CPU:GPU"', '"CPU:"'.
+    ValueError: Multiple devices matched with the device name.
+    ValueError: Memory statistics not tracked or clearing memory statistics not
+      supported, like '"CPU:0"'.
+  """
+  context.context().reset_memory_stats(device)
+```
+
+The `reset_memory_stats` function is implemented mainly by the `reset_memory_stats` method of the `Context` class.
+
+```python
+class Context:
+  def reset_memory_stats(self, dev):
+    """Resets the tracked memory stats for the device."""
+    self._initialize_physical_devices()
+    self.ensure_initialized()
+    pywrap_tfe.TFE_ResetMemoryStats(self._context_handle, dev)
+```
+
+which in turn delegates to the `TFE_ResetMemoryStats` function.
+
+#### C++ implementation
+
+`TFE_ResetMemoryStats` is registered into the Python module via PYBIND11.
+
+```C++
+PYBIND11_MODULE(_pywrap_tfe, m) {
+  m.def("TFE_ResetMemoryStats", [](py::handle& ctx, const char* device_name) {
+    tensorflow::Device* matched_device =
+        tensorflow::GetMatchedDevice(ctx, device_name);
+
+    tensorflow::AllocatorAttributes attrs;
+    tensorflow::Allocator* allocator = matched_device->GetAllocator(attrs);
+
+    if (!allocator->ClearStats()) {
+      tensorflow::ThrowValueError(
+          absl::StrFormat("Cannot reset memory stats for device '%s'",
+                          device_name)
+              .c_str());
+    }
+  });
+  // others...
+}
+```
+
+`TFE_ResetMemoryStats` is implemented mainly through the `ClearStats` method of `tensorflow::Allocator`.
+
+# 4. Comparison
+
+Since the function structure in PyTorch is a better match for Paddle's requirements, the implementation can follow PyTorch's approach.
+
+# 5. Design and Implementation
+
+## Naming and parameters
+
+API `paddle.device.cuda.reset_peak_memory_stats(device: _CudaPlaceLike | None = None) -> None`
+
+paddle.device.cuda.reset_peak_memory_stats
+----------------------
+Parameters
+- device (_CudaPlaceLike, optional) - the device name or device index.
+
+Returns
+- None.
+
+API `paddle.device.cuda.reset_max_memory_allocated(device: _CudaPlaceLike | None = None) -> None`
+
+paddle.device.cuda.reset_max_memory_allocated
+----------------------
+Parameters
+- device (_CudaPlaceLike, optional) - the device name or device index.
+
+Returns
+- None.
+
+API `paddle.device.cuda.memory_stats(device: _CudaPlaceLike | None = None) -> Dict[str, Any]`
+
+paddle.device.cuda.memory_stats
+----------------------
+Parameters
+- device (_CudaPlaceLike, optional) - the device name or device index.
+
+Returns
+- Dict - a dictionary holding the device's memory-management statistics.
+
+## Low-level design
+
+### Implementation of `cuda.reset_peak_memory_stats`
+
+Design a `ResetPeakMemoryStats` function that resets, for every relevant `StatBase` stored in the `stat_map_` member of the singleton `StatRegistry` class, the peak value to the current value (a fuller sketch follows at the end of this subsection).
+
+```C++
+void ResetPeakMemoryStats(int dev_id) {
+  // TODO: for each stat registered for device `dev_id` in StatRegistry,
+  // set its peak value to its current value.
+}
+```
+
+The `Stat` class may need to be modified, for example by adding a member function directly (as below) or by adding a friend function.
+
+```C++
+template <typename T>
+class Stat : public StatBase {
+ public:
+  // New function: set peak_value_ to current_value_.
+  void ResetPeakValue() {
+    int64_t current_value = GetCurrentValue();
+    peak_value_.store(current_value, std::memory_order_relaxed);
+  }
+};
+```
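+
+A minimal sketch of how the pieces could fit together. It assumes (1) `StatBase` gains a virtual `ResetPeakValue()` implemented by `Stat` as above, and (2) `StatRegistry` keeps its existing string-keyed `stat_map_` and gains a small dispatch helper. The singleton accessor and the `"STAT_gpu_..."` key format below are placeholders that must be aligned with the real definitions in `paddle/fluid/memory/stats.h`:
+
+```C++
+#include <string>
+#include <unordered_map>
+
+// Assumed minimal interface; the real classes live in paddle/fluid/memory/stats.h.
+class StatBase {
+ public:
+  virtual ~StatBase() = default;
+  virtual void ResetPeakValue() = 0;  // new virtual proposed by this RFC
+};
+
+class StatRegistry {
+ public:
+  // Placeholder singleton accessor.
+  static StatRegistry* GetInstance() {
+    static StatRegistry registry;
+    return &registry;
+  }
+
+  // New helper: reset the peak of one registered stat, if present.
+  void ResetPeakValue(const std::string& stat_name) {
+    auto it = stat_map_.find(stat_name);
+    if (it != stat_map_.end()) {
+      it->second->ResetPeakValue();
+    }
+  }
+
+ private:
+  std::unordered_map<std::string, StatBase*> stat_map_;
+};
+
+// Proposed entry point: reset the peaks of all stats tracked for one device.
+void ResetPeakMemoryStats(int dev_id) {
+  for (const char* stat_type : {"Allocated", "Reserved"}) {
+    // Placeholder key format; must match how stats are actually registered.
+    std::string key = "STAT_gpu_" + std::to_string(dev_id) + "_" + stat_type;
+    StatRegistry::GetInstance()->ResetPeakValue(key);
+  }
+}
+```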
+
+### Implementation of `cuda.reset_max_memory_allocated`
+
+Design a `ResetMaxMemoryAllocated` function that resets the peak values of the "Allocated" entries stored in `stat_map_` to their current values.
+
+```C++
+void ResetMaxMemoryAllocated(int dev_id) {
+  // TODO: reset only the "Allocated" stats of device `dev_id`.
+}
+```
+
+An alternative: follow PyTorch and do not design a low-level `ResetMaxMemoryAllocated` function at all; `reset_max_memory_allocated` would simply call `ResetPeakMemoryStats`.
+
+### Implementation of `cuda.memory_stats`
+
+Design a `MemoryStats` function that gathers all memory information into an unordered map, for example:
+
+```C++
+std::unordered_map<std::string, int64_t> stat_info_map = {
+    {"memory.allocated.peak", 0},
+    {"memory.allocated.current", 0},
+    {"memory.reserved.peak", 0},
+    {"memory.reserved.current", 0},
+};
+```
+
+A function may need to be added to expose `stat_map_`:
+
+```C++
+class StatRegistry {
+ public:
+  const std::unordered_map<std::string, StatBase*>& GetStatMap() const {
+    return stat_map_;
+  }
+};
+```
+
+Then iterate over `stat_map_` and read the current and peak values of each stat.
+
+An alternative: decide in the outer Python function which statistics to collect; for example, if only the "allocated" and "reserved" statistics are tracked, the existing accessors can be called directly for their peak and current values.
+
+## API implementation plan
+
+Register the C++ functions into the Python module via PYBIND11, for example:
+
+```C++
+PYBIND11_MODULE(libpaddle, m) {
+  m.def("device_reset_peak_memory_stats", memory::ResetPeakMemoryStats);
+  m.def("device_reset_max_memory_allocated", memory::ResetMaxMemoryAllocated);
+  m.def("device_memory_stats", [](int dev_id) {
+    py::dict result;
+    // MemoryStats returns the stat_info_map sketched above.
+    for (const auto& pair : memory::MemoryStats(dev_id)) {
+      result[pair.first.c_str()] = pair.second;
+    }
+    return result;
+  });
+}
+```
+
+# 6. Testing and Acceptance Criteria
+
+## Unit tests
+
+A minimal test sketch is given in the appendix.
+
+### `reset_peak_memory_stats`
+
+#### Functional tests
+  - Allocate a certain amount of memory (e.g. create large tensors) and record the peak memory usage.
+  - Call the `reset_peak_memory_stats` function.
+  - Allocate memory again and make sure the new peak memory usage is computed afresh from the reset point.
+
+#### Reset verification
+  - After the reset, the reported peak memory usage should equal the current memory usage rather than the previous peak.
+
+### `reset_max_memory_allocated`
+
+#### Functional tests
+  - Allocate memory and record the maximum allocated amount.
+  - Call the `reset_max_memory_allocated` function.
+  - Allocate memory again and verify that the maximum allocated amount is tracked from the reset point onward.
+
+#### Reset verification
+  - After the reset, the reported maximum allocated amount should equal the current allocated amount.
+
+### `memory_stats`
+
+#### Data correctness tests
+  - Allocate and free memory blocks of different sizes, call `memory_stats`, and verify that the returned data accurately reflects the current memory usage.
+  - Check that each field in the returned dictionary or data structure is correct, e.g. used memory, free memory, and so on.
+
+# 7. Impact
+
+## Open questions
+1. Whether the `Stat` class should be modified.
+2. Whether the `StatRegistry` class currently registers only `Allocated` and `Reserved`.
+3. Whether to follow PyTorch and register more memory statistics labels.
+
+## Impact on secondary developers
+
+If the `Stat` class is modified, the functions newly added to it will be exposed to secondary developers.
+
+# 8. Timeline
+
+1. Submit the polished RFC document by 2024/11/30.
+2. Submit the first version of the code by 2024/12/08.
+3. Submit the refined code and API documentation by 2024/12/14.
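+
+# Appendix: unit-test sketch
+
+A minimal unit-test sketch for the acceptance criteria in Section 6, assuming the proposed APIs land under `paddle.device.cuda` with the signatures from Section 5; the `"memory.allocated.current"` key is a placeholder for whatever key set the final `memory_stats` design settles on:
+
+```python
+import unittest
+
+import paddle
+
+
+@unittest.skipUnless(paddle.device.is_compiled_with_cuda(), "requires CUDA")
+class TestMemoryStatsAPIs(unittest.TestCase):
+    def test_reset_peak_memory_stats(self):
+        x = paddle.ones([1024, 1024], dtype="float32")
+        del x
+        paddle.device.cuda.reset_peak_memory_stats(0)  # proposed API
+        # After the reset, the peak should equal the current usage.
+        self.assertEqual(
+            paddle.device.cuda.max_memory_allocated(0),
+            paddle.device.cuda.memory_allocated(0),
+        )
+
+    def test_memory_stats(self):
+        x = paddle.ones([1024, 1024], dtype="float32")
+        stats = paddle.device.cuda.memory_stats(0)  # proposed API
+        # Placeholder key; the final key set is an open design question.
+        self.assertGreaterEqual(stats["memory.allocated.current"], 0)
+        del x
+
+
+if __name__ == "__main__":
+    unittest.main()
+```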