Skip to content

Commit

Permalink
Copy ov_config for seq2seq models
Browse files Browse the repository at this point in the history
  • Loading branch information
helena-intel committed Oct 25, 2023
1 parent ecad239 commit 26f2ce1
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions optimum/intel/openvino/modeling_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def __init__(
self.decoder_with_past = None
enable_compilation = kwargs.get("compile", True)
encoder_cache_dir = Path(self.model_save_dir).joinpath("encoder_cache")
ov_encoder_config = self.ov_config
ov_encoder_config = {**self.ov_config}

if "CACHE_DIR" not in ov_encoder_config.keys() and not str(self.model_save_dir).startswith(gettempdir()):
ov_encoder_config["CACHE_DIR"] = str(encoder_cache_dir)
Expand All @@ -213,7 +213,7 @@ def __init__(
)

decoder_cache_dir = Path(self.model_save_dir).joinpath("decoder_cache")
ov_decoder_config = self.ov_config
ov_decoder_config = {**self.ov_config}

if "CACHE_DIR" not in ov_decoder_config.keys() and not str(self.model_save_dir).startswith(gettempdir()):
ov_decoder_config["CACHE_DIR"] = str(decoder_cache_dir)
Expand All @@ -222,7 +222,7 @@ def __init__(

if self.use_cache:
decoder_past_cache_dir = Path(self.model_save_dir).joinpath("decoder_past_cache")
ov_decoder_past_config = self.ov_config
ov_decoder_past_config = {**self.ov_config}

if "CACHE_DIR" not in ov_decoder_past_config.keys() and not str(self.model_save_dir).startswith(
gettempdir()
Expand Down

0 comments on commit 26f2ce1

Please sign in to comment.