From 6fe65d263e4a3a5aac9dd452c19d762d05cb1992 Mon Sep 17 00:00:00 2001 From: Nikita Sokolov Date: Mon, 23 Sep 2024 17:13:45 +0300 Subject: [PATCH] default_discovery_dir should use YT username https://github.com/ytsaurus/ytsaurus-spyt/issues/20 --- Pull Request resolved: https://github.com/ytsaurus/ytsaurus-spyt/pull/22 commit_hash:dc2b9acb71855a0871e6b7b3d4f7899748f3d89f --- spyt-package/src/main/python/spyt/client.py | 6 +++--- spyt-package/src/main/python/spyt/utils.py | 14 +++++++------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/spyt-package/src/main/python/spyt/client.py b/spyt-package/src/main/python/spyt/client.py index 1ee1d596..1f58a503 100644 --- a/spyt-package/src/main/python/spyt/client.py +++ b/spyt-package/src/main/python/spyt/client.py @@ -127,9 +127,9 @@ def _create_spark_session(do_create_spark_session): stop(spark, exception) -def get_spark_discovery(discovery_path, conf): +def get_spark_discovery(discovery_path, conf, client=None): discovery_path = discovery_path or conf.get("discovery_path") or conf.get( - "discovery_dir") or default_discovery_dir() + "discovery_dir") or default_discovery_dir(client=client) return SparkDiscovery(discovery_path=discovery_path) @@ -150,7 +150,7 @@ def _configure_client_mode(spark_conf, local_conf, client=None, spyt_version=None): - discovery = get_spark_discovery(discovery_path, local_conf) + discovery = get_spark_discovery(discovery_path, local_conf, client=client) master = get_spark_master(discovery, rest=False, yt_client=client) set_conf(spark_conf, base_spark_conf(client=client, discovery=discovery)) spark_conf.set("spark.master", master) diff --git a/spyt-package/src/main/python/spyt/utils.py b/spyt-package/src/main/python/spyt/utils.py index 76e61c2f..dd10c28f 100644 --- a/spyt-package/src/main/python/spyt/utils.py +++ b/spyt-package/src/main/python/spyt/utils.py @@ -1,5 +1,4 @@ import argparse -import getpass import logging import os import re @@ -218,8 +217,8 @@ def scala_buffer_to_list(buffer): return [buffer.apply(i) for i in range(buffer.length())] -def default_user(): - return os.getenv("YT_USER") or getpass.getuser() +def default_user(client=None): + return os.getenv("YT_USER") or get_user_name(client=client) def default_token(): @@ -256,8 +255,9 @@ def set_conf(conf, dict_conf): conf.set(key, value) -def default_discovery_dir(): - return os.getenv("SPARK_YT_DISCOVERY_DIR") or YPath("//home").join(os.getenv("USER")).join("spark-tmp") +def default_discovery_dir(client=None): + return os.getenv("SPARK_YT_DISCOVERY_DIR") \ + or YPath("//tmp").join(default_user(client=client)).join("spark-tmp") def default_proxy(): @@ -281,11 +281,11 @@ def get_default_arg_parser(**kwargs): return parser -def parse_args(parser=None, parser_arguments=None, raw_args=None): +def parse_args(parser=None, parser_arguments=None, raw_args=None, client=None): parser_arguments = parser_arguments or {} parser = parser or get_default_arg_parser(**parser_arguments) args, unknown_args = parser.parse_known_args(args=raw_args) - args.discovery_path = args.discovery_path or args.discovery_dir or default_discovery_dir() + args.discovery_path = args.discovery_path or args.discovery_dir or default_discovery_dir(client=client) return args, unknown_args