From b7d3775555e4357442dd15101201c3c93205962f Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Wed, 20 Nov 2024 13:34:02 +0900 Subject: [PATCH] needs to split with `-` Signed-off-by: Masaki Kozuki --- thunder/__init__.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index 54c94855d..d921f313f 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -780,14 +780,17 @@ def wrapped(*args, **kwargs): def check_storage_aliases(cache_entry, args): if cache_entry.vanilla_tensor_args: if alias_tensor_indices_str := _alias_tensor_of_args_kwargs(*args): - alias_tensor_indices = alias_tensor_indices_str - alias_tensor_indices = {int(i) for i in alias_tensor_indices_str.split(",")} - vanilla_tensor_args = cache_entry.vanilla_tensor_args - check( - not vanilla_tensor_args & alias_tensor_indices, - lambda: f"It seems that {vanilla_tensor_args} are {alias_tensor_indices=} share their storage and some of them are modified in-place", - NotImplementedError, - ) + for alias_tensor_group in alias_tensor_indices_str.split("-"): + alias_tensor_indices = {int(i) for i in alias_tensor_group.split(",")} + vanilla_tensor_args = cache_entry.vanilla_tensor_args + check( + not vanilla_tensor_args & alias_tensor_indices, + lambda: ( + f"It seems that {vanilla_tensor_args} are {alias_tensor_indices=} share " + "their storage and some of them are modified in-place" + ), + NotImplementedError, + ) def maybe_connect_to_autograd(cache_entry, result): if cache_entry.backward_fn: