diff --git a/Makefile b/Makefile index 1ae4728eb..af7b1844f 100644 --- a/Makefile +++ b/Makefile @@ -31,7 +31,7 @@ endif @echo "make deb_bionic - Generate bionic deb package" @echo "make itest_bionic - Run tests and integration checks" @echo "make _itest_bionic - Run only integration checks" - @echo "make deb_jammy - Generate bionic deb package" + @echo "make deb_jammy - Generate jammy deb package" @echo "make itest_jammy - Run tests and integration checks" @echo "make _itest_jammy - Run only integration checks" @echo "make release - Prepare debian info for new release" diff --git a/bin/tronfig b/bin/tronfig index 057b0ed3b..5f50a83e7 100755 --- a/bin/tronfig +++ b/bin/tronfig @@ -199,7 +199,6 @@ if __name__ == "__main__": client = Client(args.server) if args.print_config: - # TODO: use maybe_encode() content = client.config(args.source)["config"] if type(content) is not bytes: content = content.encode("utf8") diff --git a/docs/source/jobs.rst b/docs/source/jobs.rst index ae8199f0b..1d4f7b00c 100644 --- a/docs/source/jobs.rst +++ b/docs/source/jobs.rst @@ -209,7 +209,7 @@ Optional Fields after this duration. **trigger_downstreams** (bool or dict) - Upon successfull completion of an action, will emit a trigger for every + Upon successful completion of an action, will emit a trigger for every item in the dictionary. When set to ``true``, a default dict of ``{shortdate: "{shortdate}"}`` is assumed. Emitted triggers will be in form: ``....``. See @@ -220,7 +220,7 @@ Optional Fields have been emitted by upstream actions. Unlike with ``requires`` attribute, dependent actions don't have to belong to the same job. ``triggered_by`` template may contain any pattern allowed in ``command`` attribute. - See :ref:`shortdate` for an explantion of shortdate + See :ref:`shortdate` for an explanation of shortdate Example: diff --git a/itest.sh b/itest.sh index a16b10bda..815c6b3f3 100755 --- a/itest.sh +++ b/itest.sh @@ -60,14 +60,3 @@ fi kill -SIGTERM $TRON_PID wait $TRON_PID || true - -/opt/venvs/tron/bin/python - < start time {}".format(ts, int(os.environ['TRON_START_TIME']))) -assert ts > int(os.environ['TRON_START_TIME']) -EOF diff --git a/requirements-dev-minimal.txt b/requirements-dev-minimal.txt index eddbadb2b..6df62cd55 100644 --- a/requirements-dev-minimal.txt +++ b/requirements-dev-minimal.txt @@ -8,6 +8,7 @@ pylint pytest pytest-asyncio requirements-tools +types-pytz types-PyYAML types-requests<2.31.0.7 # newer types-requests requires urllib3>=2 types-simplejson diff --git a/requirements-dev.txt b/requirements-dev.txt index 2860357cd..ee1dd3374 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -29,6 +29,7 @@ requirements-tools==1.2.1 toml==0.10.2 tomli==2.0.1 tomlkit==0.11.6 +types-pytz==2024.2.0.20240913 types-PyYAML==6.0.12 types-requests==2.31.0.5 types-simplejson==3.19.0.20240310 diff --git a/tests/serialize/runstate/dynamodb_state_store_test.py b/tests/serialize/runstate/dynamodb_state_store_test.py index 4609f497e..a6d282455 100644 --- a/tests/serialize/runstate/dynamodb_state_store_test.py +++ b/tests/serialize/runstate/dynamodb_state_store_test.py @@ -1,4 +1,4 @@ -import pickle +import json from unittest import mock import boto3 @@ -100,12 +100,29 @@ def store(): @pytest.fixture def small_object(): - yield pickle.dumps({"this": "data"}) + yield { + "job_name": "example_job", + "run_num": 1, + "run_time": None, + "node_name": "example_node", + "runs": [], + "cleanup_run": None, + "manual": False, + } @pytest.fixture def large_object(): - yield pickle.dumps([i for i in range(1000000)]) + yield { + "job_name": "example_job", + "run_num": 1, + "run_time": None, + "node_name": "example_node", + "runs": [], + "cleanup_run": None, + "manual": False, + "large_data": [i for i in range(1_000_000)], + } @pytest.mark.usefixtures("store", "small_object", "large_object") @@ -113,11 +130,11 @@ class TestDynamoDBStateStore: def test_save(self, store, small_object, large_object): key_value_pairs = [ ( - store.build_key("DynamoDBTest", "two"), + store.build_key("job_state", "two"), small_object, ), ( - store.build_key("DynamoDBTest2", "four"), + store.build_key("job_run_state", "four"), small_object, ), ] @@ -126,21 +143,27 @@ def test_save(self, store, small_object, large_object): assert store.save_errors == 0 keys = [ - store.build_key("DynamoDBTest", "two"), - store.build_key("DynamoDBTest2", "four"), + store.build_key("job_state", "two"), + store.build_key("job_run_state", "four"), ] vals = store.restore(keys) for key, value in key_value_pairs: assert_equal(vals[key], value) + for key in keys: + item = store.table.get_item(Key={"key": key, "index": 0}) + assert "Item" in item + assert "json_val" in item["Item"] + assert_equal(json.loads(item["Item"]["json_val"]), small_object) + def test_delete_if_val_is_none(self, store, small_object, large_object): key_value_pairs = [ ( - store.build_key("DynamoDBTest", "two"), + store.build_key("job_state", "two"), small_object, ), ( - store.build_key("DynamoDBTest2", "four"), + store.build_key("job_run_state", "four"), small_object, ), ] @@ -149,7 +172,7 @@ def test_delete_if_val_is_none(self, store, small_object, large_object): delete = [ ( - store.build_key("DynamoDBTest", "two"), + store.build_key("job_state", "two"), None, ), ] @@ -159,8 +182,8 @@ def test_delete_if_val_is_none(self, store, small_object, large_object): assert store.save_errors == 0 # Try to restore both, we should just get one back keys = [ - store.build_key("DynamoDBTest", "two"), - store.build_key("DynamoDBTest2", "four"), + store.build_key("job_state", "two"), + store.build_key("job_run_state", "four"), ] vals = store.restore(keys) assert vals == {keys[1]: small_object} @@ -168,7 +191,7 @@ def test_delete_if_val_is_none(self, store, small_object, large_object): def test_save_more_than_4KB(self, store, small_object, large_object): key_value_pairs = [ ( - store.build_key("DynamoDBTest", "two"), + store.build_key("job_state", "two"), large_object, ), ] @@ -176,14 +199,14 @@ def test_save_more_than_4KB(self, store, small_object, large_object): store._consume_save_queue() assert store.save_errors == 0 - keys = [store.build_key("DynamoDBTest", "two")] + keys = [store.build_key("job_state", "two")] vals = store.restore(keys) for key, value in key_value_pairs: assert_equal(vals[key], value) def test_restore_more_than_4KB(self, store, small_object, large_object): - keys = [store.build_key("thing", i) for i in range(3)] - value = pickle.loads(large_object) + keys = [store.build_key("job_state", i) for i in range(3)] + value = large_object pairs = zip(keys, (value for i in range(len(keys)))) store.save(pairs) store._consume_save_queue() @@ -191,11 +214,11 @@ def test_restore_more_than_4KB(self, store, small_object, large_object): assert store.save_errors == 0 vals = store.restore(keys) for key in keys: - assert_equal(pickle.dumps(vals[key]), large_object) + assert_equal(vals[key], large_object) def test_restore(self, store, small_object, large_object): - keys = [store.build_key("thing", i) for i in range(3)] - value = pickle.loads(small_object) + keys = [store.build_key("job_state", i) for i in range(3)] + value = small_object pairs = zip(keys, (value for i in range(len(keys)))) store.save(pairs) store._consume_save_queue() @@ -203,11 +226,11 @@ def test_restore(self, store, small_object, large_object): assert store.save_errors == 0 vals = store.restore(keys) for key in keys: - assert_equal(pickle.dumps(vals[key]), small_object) + assert_equal(vals[key], small_object) def test_delete_item(self, store, small_object, large_object): - keys = [store.build_key("thing", i) for i in range(3)] - value = pickle.loads(large_object) + keys = [store.build_key("job_state", i) for i in range(3)] + value = large_object pairs = list(zip(keys, (value for i in range(len(keys))))) store.save(pairs) @@ -222,8 +245,8 @@ def test_retry_saving(self, store, small_object, large_object): "moto.dynamodb2.responses.DynamoHandler.transact_write_items", side_effect=KeyError("foo"), ) as mock_failed_write: - keys = [store.build_key("thing", i) for i in range(1)] - value = pickle.loads(small_object) + keys = [store.build_key("job_state", i) for i in range(1)] + value = small_object pairs = zip(keys, (value for i in range(len(keys)))) try: store.save(pairs) @@ -236,7 +259,7 @@ def test_retry_reading(self, store, small_object, large_object): store.name: [ { "index": {"N": "0"}, - "key": {"S": "thing 0"}, + "key": {"S": "job_state 0"}, }, ], }, @@ -246,15 +269,15 @@ def test_retry_reading(self, store, small_object, large_object): "Keys": [ { "index": {"N": "0"}, - "key": {"S": "thing 0"}, + "key": {"S": "job_state 0"}, } ], }, }, "ResponseMetadata": {}, } - keys = [store.build_key("thing", i) for i in range(1)] - value = pickle.loads(small_object) + keys = [store.build_key("job_state", i) for i in range(1)] + value = small_object pairs = zip(keys, (value for i in range(len(keys)))) store.save(pairs) with mock.patch.object( @@ -269,7 +292,7 @@ def test_retry_reading(self, store, small_object, large_object): def test_restore_exception_propagation(self, store, small_object): # This test is to ensure that restore propagates exceptions upwards: see DAR-2328 - keys = [store.build_key("thing", i) for i in range(3)] + keys = [store.build_key("job_state", i) for i in range(3)] mock_future = mock.MagicMock() mock_future.result.side_effect = Exception("mocked exception") diff --git a/tests/serialize/runstate/statemanager_test.py b/tests/serialize/runstate/statemanager_test.py index 8e35b5040..7010b3bff 100644 --- a/tests/serialize/runstate/statemanager_test.py +++ b/tests/serialize/runstate/statemanager_test.py @@ -19,9 +19,7 @@ from tron.serialize.runstate.statemanager import PersistenceStoreError from tron.serialize.runstate.statemanager import PersistentStateManager from tron.serialize.runstate.statemanager import StateChangeWatcher -from tron.serialize.runstate.statemanager import StateMetadata from tron.serialize.runstate.statemanager import StateSaveBuffer -from tron.serialize.runstate.statemanager import VersionMismatchError class TestPersistenceManagerFactory(TestCase): @@ -42,24 +40,6 @@ def test_from_config_shelve(self): shutil.rmtree(tmpdir) -class TestStateMetadata(TestCase): - def test_validate_metadata(self): - metadata = {"version": (0, 5, 2)} - StateMetadata.validate_metadata(metadata) - - def test_validate_metadata_no_state_data(self): - metadata = None - StateMetadata.validate_metadata(metadata) - - def test_validate_metadata_mismatch(self): - metadata = {"version": (200, 1, 1)} - assert_raises( - VersionMismatchError, - StateMetadata.validate_metadata, - metadata, - ) - - class TestStateSaveBuffer(TestCase): @setup def setup_buffer(self): @@ -103,15 +83,7 @@ def test_keys_for_items(self): def test_restore(self): job_names = ["one", "two"] - with mock.patch.object( - self.manager, - "_restore_metadata", - autospec=True, - ) as mock_restore_metadata, mock.patch.object( - self.manager, - "_restore_dicts", - autospec=True, - ) as mock_restore_dicts, mock.patch.object( + with mock.patch.object(self.manager, "_restore_dicts", autospec=True,) as mock_restore_dicts, mock.patch.object( self.manager, "_restore_runs_for_job", autospect=True, @@ -125,7 +97,6 @@ def test_restore(self): ] restored_state = self.manager.restore(job_names) - mock_restore_metadata.assert_called_once_with() assert mock_restore_dicts.call_args_list == [ mock.call(runstate.JOB_STATE, job_names), ] @@ -280,19 +251,6 @@ def test_save_job(self): mock_job.state_data, ) - @mock.patch( - "tron.serialize.runstate.statemanager.StateMetadata", - autospec=None, - ) - def test_save_metadata(self, mock_state_metadata): - self.watcher.save_metadata() - meta_data = mock_state_metadata.return_value - self.watcher.state_manager.save.assert_called_with( - runstate.MCP_STATE, - meta_data.name, - meta_data.state_data, - ) - def test_shutdown(self): self.watcher.shutdown() assert not self.watcher.state_manager.enabled diff --git a/tests/utils/crontab_test.py b/tests/utils/crontab_test.py index ee0281713..aa80b0a22 100644 --- a/tests/utils/crontab_test.py +++ b/tests/utils/crontab_test.py @@ -43,6 +43,30 @@ def test_parse(self, mock_dow, mock_month, mock_monthday, mock_hour, mock_min): assert_equal(actual["months"], mock_month.return_value) assert_equal(actual["weekdays"], mock_dow.return_value) + def test_full_crontab_line(self): + line = "*/15 0 1,15 * 1-5" + expected = { + "minutes": [0, 15, 30, 45], + "hours": [0], + "monthdays": [1, 15], + "months": None, + "weekdays": [1, 2, 3, 4, 5], + "ordinals": None, + } + assert_equal(crontab.parse_crontab(line), expected) + + def test_full_crontab_line_with_last(self): + line = "0 0 L * *" + expected = { + "minutes": [0], + "hours": [0], + "monthdays": ["LAST"], + "months": None, + "weekdays": None, + "ordinals": None, + } + assert_equal(crontab.parse_crontab(line), expected) + class TestMinuteFieldParser(TestCase): @setup @@ -108,5 +132,35 @@ def test_parse_last(self): assert_equal(self.parser.parse("5, 6, L"), expected) +class TestComplexExpressions(TestCase): + @setup + def setup_parser(self): + self.parser = crontab.MinuteFieldParser() + + def test_complex_expression(self): + expected = [0, 10, 20, 30, 40, 50, 55] + assert_equal(self.parser.parse("*/10,55"), expected) + + +class TestInvalidInputs(TestCase): + @setup + def setup_parser(self): + self.parser = crontab.MinuteFieldParser() + + def test_invalid_expression(self): + with assert_raises(ValueError): + self.parser.parse("61") + + +class TestBoundaryValues(TestCase): + @setup + def setup_parser(self): + self.parser = crontab.MinuteFieldParser() + + def test_boundary_values(self): + assert_equal(self.parser.parse("0"), [0]) + assert_equal(self.parser.parse("59"), [59]) + + if __name__ == "__main__": run() diff --git a/tools/migration/migrate_state.py b/tools/migration/migrate_state.py index 5a0f7c114..23e9f53f1 100644 --- a/tools/migration/migrate_state.py +++ b/tools/migration/migrate_state.py @@ -103,7 +103,6 @@ def convert_state(opts): job_states = source_manager.restore( job_names, - skip_validation=True, ) source_manager.cleanup() diff --git a/tron/actioncommand.py b/tron/actioncommand.py index f34e07076..db6e5aba4 100644 --- a/tron/actioncommand.py +++ b/tron/actioncommand.py @@ -1,12 +1,15 @@ +import json import logging import os from io import StringIO from shlex import quote +from typing import Optional from tron.config import schema from tron.serialize import filehandler from tron.utils import timeutils from tron.utils.observer import Observable +from tron.utils.persistable import Persistable from tron.utils.state import Machine log = logging.getLogger(__name__) @@ -144,7 +147,8 @@ def clear(self): self.buffers.clear() -class NoActionRunnerFactory: +# TODO: TRON-2304 - Cleanup NoActionRunnerFactory +class NoActionRunnerFactory(Persistable): """Action runner factory that does not wrap the action run command.""" @classmethod @@ -156,8 +160,12 @@ def build_stop_action_command(cls, _id, _command): """It is not possible to stop action commands without a runner.""" raise NotImplementedError("An action_runner is required to stop.") + @staticmethod + def to_json(): + return None -class SubprocessActionRunnerFactory: + +class SubprocessActionRunnerFactory(Persistable): """Run actions by wrapping them in `action_runner.py`.""" runner_exec_name = "action_runner.py" @@ -195,6 +203,22 @@ def __eq__(self, other): def __ne__(self, other): return not self == other + @staticmethod + def to_json(state_data: dict) -> Optional[str]: + try: + return json.dumps( + { + "status_path": state_data["status_path"], + "exec_path": state_data["exec_path"], + } + ) + except KeyError: + log.exception("Missing key in state_data:") + raise + except Exception: + log.exception("Error serializing SubprocessActionRunnerFactory to JSON:") + raise + def create_action_runner_factory_from_config(config): """A factory-factory method which returns a callable that can be used to diff --git a/tron/api/resource.py b/tron/api/resource.py index 663f4ed10..0460c910f 100644 --- a/tron/api/resource.py +++ b/tron/api/resource.py @@ -183,7 +183,9 @@ def getChild(self, action_name, _): if not action_name: return self - action_name = maybe_decode(action_name) + action_name = maybe_decode( + action_name + ) # TODO: TRON-2293 maybe_decode is a relic of Python2->Python3 migration. Remove it. if action_name in self.job_run.action_runs: action_run = self.job_run.action_runs[action_name] return ActionRunResource(action_run, self.job_run) @@ -231,7 +233,9 @@ def getChild(self, run_id, _): if not run_id: return self - run_id = maybe_decode(run_id) + run_id = maybe_decode( + run_id + ) # TODO: TRON-2293 maybe_decode is a relic of Python2->Python3 migration. Remove it. run = self.get_run_from_identifier(run_id) if run: return JobRunResource(run, self.job_scheduler) @@ -297,7 +301,7 @@ def getChild(self, name, request): if not name: return self - name = maybe_decode(name) + name = maybe_decode(name) # TODO: TRON-2293 maybe_decode is a relic of Python2->Python3 migration. Remove it. return resource_from_collection(self.job_collection, name, JobResource) def get_data( diff --git a/tron/commands/display.py b/tron/commands/display.py index c64813260..2ab076327 100644 --- a/tron/commands/display.py +++ b/tron/commands/display.py @@ -443,6 +443,8 @@ def view_with_less(content, color=True): cmd.append("-r") less_proc = subprocess.Popen(cmd, stdin=subprocess.PIPE) - less_proc.stdin.write(maybe_encode(content)) + less_proc.stdin.write( + maybe_encode(content) + ) # TODO: TRON-2293 maybe_encode is a relic of Python2->Python3 migration. Remove it. less_proc.stdin.close() less_proc.wait() diff --git a/tron/config/manager.py b/tron/config/manager.py index 21b6e19ca..7717f039a 100644 --- a/tron/config/manager.py +++ b/tron/config/manager.py @@ -32,7 +32,9 @@ def read(path): def write_raw(path, content): with open(path, "w") as fh: - fh.write(maybe_decode(content)) + fh.write( + maybe_decode(content) + ) # TODO: TRON-2293 maybe_decode is a relic of Python2->Python3 migration. Remove it. def read_raw(path) -> str: @@ -41,7 +43,9 @@ def read_raw(path) -> str: def hash_digest(content): - return hashlib.sha1(maybe_encode(content)).hexdigest() + return hashlib.sha1( + maybe_encode(content) + ).hexdigest() # TODO: TRON-2293 maybe_encode is a relic of Python2->Python3 migration. Remove it. class ManifestFile: diff --git a/tron/core/action.py b/tron/core/action.py index becbafcca..a0ba04d82 100644 --- a/tron/core/action.py +++ b/tron/core/action.py @@ -1,4 +1,5 @@ import datetime +import json import logging from dataclasses import dataclass from dataclasses import field @@ -14,12 +15,13 @@ from tron.config.schema import ConfigProjectedSAVolume from tron.config.schema import ConfigSecretVolume from tron.config.schema import ConfigTopologySpreadConstraints +from tron.utils.persistable import Persistable log = logging.getLogger(__name__) @dataclass -class ActionCommandConfig: +class ActionCommandConfig(Persistable): """A configurable data object for one try of an Action.""" command: str @@ -30,7 +32,6 @@ class ActionCommandConfig: cap_drop: List[str] = field(default_factory=list) constraints: set = field(default_factory=set) docker_image: Optional[str] = None - # XXX: we can get rid of docker_parameters once we're off of Mesos docker_parameters: set = field(default_factory=set) env: dict = field(default_factory=dict) secret_env: dict = field(default_factory=dict) @@ -53,6 +54,50 @@ def state_data(self): def copy(self): return ActionCommandConfig(**self.state_data) + @staticmethod + def to_json(state_data: dict) -> Optional[str]: + """Serialize the ActionCommandConfig instance to a JSON string.""" + + def serialize_namedtuple(obj): + if isinstance(obj, tuple) and hasattr(obj, "_fields"): + return obj._asdict() + return obj + + try: + return json.dumps( + { + "command": state_data["command"], + "cpus": state_data["cpus"], + "mem": state_data["mem"], + "disk": state_data["disk"], + "cap_add": state_data["cap_add"], + "cap_drop": state_data["cap_drop"], + "constraints": list(state_data["constraints"]), + "docker_image": state_data["docker_image"], + "docker_parameters": list(state_data["docker_parameters"]), + "env": state_data["env"], + "secret_env": state_data["secret_env"], + "secret_volumes": [serialize_namedtuple(volume) for volume in state_data["secret_volumes"]], + "projected_sa_volumes": [ + serialize_namedtuple(volume) for volume in state_data["projected_sa_volumes"] + ], + "field_selector_env": state_data["field_selector_env"], + "extra_volumes": list(state_data["extra_volumes"]), + "node_selectors": state_data["node_selectors"], + "node_affinities": [serialize_namedtuple(affinity) for affinity in state_data["node_affinities"]], + "labels": state_data["labels"], + "annotations": state_data["annotations"], + "service_account_name": state_data["service_account_name"], + "ports": state_data["ports"], + } + ) + except KeyError: + log.exception("Missing key in state_data:") + raise + except Exception: + log.exception("Error serializing ActionCommandConfig to JSON:") + raise + @dataclass class Action: diff --git a/tron/core/actionrun.py b/tron/core/actionrun.py index b6a07df69..9249abcb2 100644 --- a/tron/core/actionrun.py +++ b/tron/core/actionrun.py @@ -2,6 +2,7 @@ tron.core.actionrun """ import datetime +import json import logging import os from dataclasses import dataclass @@ -13,6 +14,7 @@ from typing import Union from twisted.internet import reactor +from twisted.internet.base import DelayedCall from tron import command_context from tron import node @@ -21,9 +23,12 @@ from tron.actioncommand import SubprocessActionRunnerFactory from tron.bin.action_runner import build_environment from tron.bin.action_runner import build_labels +from tron.command_context import CommandContext +from tron.config import schema from tron.config.config_utils import StringFormatter from tron.config.schema import ExecutorTypes from tron.core import action +from tron.core.action import ActionCommandConfig from tron.eventbus import EventBus from tron.kubernetes import KubernetesClusterRepository from tron.kubernetes import KubernetesTask @@ -35,6 +40,7 @@ from tron.utils import timeutils from tron.utils.observer import Observable from tron.utils.observer import Observer +from tron.utils.persistable import Persistable from tron.utils.state import Machine @@ -53,7 +59,9 @@ class ActionRunFactory: def build_action_run_collection(cls, job_run, action_runner): """Create an ActionRunCollection from an ActionGraph and JobRun.""" action_run_map = { - maybe_decode(name): cls.build_run_for_action( + maybe_decode( + name + ): cls.build_run_for_action( # TODO: TRON-2293 maybe_decode is a relic of Python2->Python3 migration. Remove it. job_run, action_inst, action_runner, @@ -79,6 +87,7 @@ def action_run_collection_from_state( ), ) + # TODO: TRON-2293 maybe_decode is a relic of Python2->Python3 migration. Remove it. action_run_map = {maybe_decode(action_run.action_name): action_run for action_run in action_runs} return ActionRunCollection(job_run.action_graph, action_run_map) @@ -135,7 +144,7 @@ def action_run_from_state(cls, job_run, state_data, cleanup=False): @dataclass -class ActionRunAttempt: +class ActionRunAttempt(Persistable): """Stores state about one try of an action run.""" command_config: action.ActionCommandConfig @@ -165,6 +174,28 @@ def state_data(self): state_data[field.name] = getattr(self, field.name) return state_data + @staticmethod + def to_json(state_data: dict) -> Optional[str]: + """Serialize the ActionRunAttempt instance to a JSON string.""" + try: + return json.dumps( + { + "command_config": ActionCommandConfig.to_json(state_data["command_config"]), + "start_time": state_data["start_time"].isoformat() if state_data["start_time"] else None, + "end_time": state_data["end_time"].isoformat() if state_data["end_time"] else None, + "rendered_command": state_data["rendered_command"], + "exit_status": state_data["exit_status"], + "mesos_task_id": state_data["mesos_task_id"], + "kubernetes_task_id": state_data["kubernetes_task_id"], + } + ) + except KeyError: + log.exception("Missing key in state_data:") + raise + except Exception: + log.exception("Error serializing ActionRunAttempt to JSON:") + raise + @classmethod def from_state(cls, state_data): # it's possible that we've rolled back to an older Tron version that doesn't support data that we've persisted @@ -182,7 +213,7 @@ def from_state(cls, state_data): return cls(**valid_actionrun_attempt_entries_from_state) -class ActionRun(Observable): +class ActionRun(Observable, Persistable): """Base class for tracking the state of a single run of an Action. ActionRun's state machine is observed by a parent JobRun. @@ -279,32 +310,36 @@ class ActionRun(Observable): # TODO: create a class for ActionRunId, JobRunId, Etc def __init__( self, - job_run_id, - name, - node, - command_config, - parent_context=None, - output_path=None, - cleanup=False, - start_time=None, - end_time=None, - run_state=SCHEDULED, - exit_status=None, - attempts=None, - action_runner=None, - retries_remaining=None, - retries_delay=None, - machine=None, - executor=None, - trigger_downstreams=None, - triggered_by=None, - on_upstream_rerun=None, - trigger_timeout_timestamp=None, - original_command=None, + job_run_id: str, + name: str, + node: node.Node, + command_config: action.ActionCommandConfig, + parent_context: Optional[CommandContext] = None, + output_path: Optional[filehandler.OutputPath] = None, + cleanup: bool = False, + start_time: Optional[datetime.datetime] = None, + end_time: Optional[datetime.datetime] = None, + run_state: str = SCHEDULED, + exit_status: Optional[int] = None, + attempts: Optional[List[ActionRunAttempt]] = None, + action_runner: Optional[Union[NoActionRunnerFactory, SubprocessActionRunnerFactory]] = None, + retries_remaining: Optional[int] = None, + retries_delay: Optional[datetime.timedelta] = None, + machine: Optional[Machine] = None, + executor: Optional[str] = None, + trigger_downstreams: Optional[Union[bool, dict]] = None, + triggered_by: Optional[List[str]] = None, + on_upstream_rerun: Optional[schema.ActionOnRerun] = None, + trigger_timeout_timestamp: Optional[float] = None, + original_command: Optional[str] = None, ): super().__init__() - self.job_run_id = maybe_decode(job_run_id) - self.action_name = maybe_decode(name) + self.job_run_id = maybe_decode( + job_run_id + ) # TODO: TRON-2293 maybe_decode is a relic of Python2->Python3 migration. Remove it. + self.action_name = maybe_decode( + name + ) # TODO: TRON-2293 maybe_decode is a relic of Python2->Python3 migration. Remove it. self.node = node self.start_time = start_time self.end_time = end_time @@ -333,7 +368,7 @@ def __init__( self.trigger_timeout_call = None self.action_command = None - self.in_delay = None + self.in_delay = None # type: Optional[DelayedCall] @property def state(self): @@ -378,7 +413,9 @@ def attempts_from_state(cls, state_data, command_config): if "attempts" in state_data: attempts = [ActionRunAttempt.from_state(a) for a in state_data["attempts"]] else: - rendered_command = maybe_decode(state_data.get("rendered_command")) + rendered_command = maybe_decode( + state_data.get("rendered_command") + ) # TODO: TRON-2293 maybe_decode is a relic of Python2->Python3 migration. Remove it. exit_statuses = state_data.get("exit_statuses", []) # If the action has started, add an attempt for the final try if state_data.get("start_time"): @@ -625,8 +662,6 @@ def triggers_to_emit(self) -> List[str]: templates = ["shortdate.{shortdate}"] elif isinstance(self.trigger_downstreams, dict): templates = [f"{k}.{v}" for k, v in self.trigger_downstreams.items()] - else: - log.error(f"{self} trigger_downstreams must be true or dict") return [self.render_template(trig) for trig in templates] @@ -702,6 +737,44 @@ def state_data(self): "trigger_timeout_timestamp": self.trigger_timeout_timestamp, } + @staticmethod + def to_json(state_data: dict) -> Optional[str]: + """Serialize the ActionRun instance to a JSON string.""" + action_runner = state_data.get("action_runner") + if action_runner is None: + action_runner_json = NoActionRunnerFactory.to_json() + else: + action_runner_json = SubprocessActionRunnerFactory.to_json(action_runner) + + try: + return json.dumps( + { + "job_run_id": state_data["job_run_id"], + "action_name": state_data["action_name"], + "state": state_data["state"], + "original_command": state_data["original_command"], + "start_time": state_data["start_time"].isoformat() if state_data["start_time"] else None, + "end_time": state_data["end_time"].isoformat() if state_data["end_time"] else None, + "node_name": state_data["node_name"], + "exit_status": state_data["exit_status"], + "attempts": [ActionRunAttempt.to_json(attempt) for attempt in state_data["attempts"]], + "retries_remaining": state_data["retries_remaining"], + "retries_delay": state_data["retries_delay"], + "action_runner": action_runner_json, + "executor": state_data["executor"], + "trigger_downstreams": state_data["trigger_downstreams"], + "triggered_by": state_data["triggered_by"], + "on_upstream_rerun": state_data["on_upstream_rerun"], + "trigger_timeout_timestamp": state_data["trigger_timeout_timestamp"], + } + ) + except KeyError: + log.exception("Missing key in state_data:") + raise + except Exception: + log.exception("Error serializing ActionRun to JSON:") + raise + def render_template(self, template): """Render our configured command using the command context.""" return StringFormatter(self.context).format(template) @@ -1233,6 +1306,16 @@ def recover(self) -> Optional[KubernetesTask]: last_attempt = self.attempts[-1] + if last_attempt.rendered_command is None: + log.error(f"{self} rendered_command is None, cannot recover") + self.fail(exitcode.EXIT_INVALID_COMMAND) + return None + + if last_attempt.command_config.docker_image is None: + log.error(f"{self} docker_image is None, cannot recover") + self.fail(exitcode.EXIT_KUBERNETES_TASK_INVALID) + return None + log.info(f"{self} recovering Kubernetes run") task = k8s_cluster.create_task( diff --git a/tron/core/job.py b/tron/core/job.py index 95de2a047..df8171b61 100644 --- a/tron/core/job.py +++ b/tron/core/job.py @@ -1,13 +1,25 @@ +import datetime +import json import logging +from typing import Any +from typing import Dict +from typing import Optional +from typing import TypeVar from tron import command_context from tron import node +from tron.actioncommand import SubprocessActionRunnerFactory from tron.core import jobrun +from tron.core.actiongraph import ActionGraph from tron.core.actionrun import ActionRun +from tron.core.jobrun import JobRunCollection +from tron.node import NodePool +from tron.scheduler import GeneralScheduler from tron.serialize import filehandler from tron.utils import maybe_decode from tron.utils.observer import Observable from tron.utils.observer import Observer +from tron.utils.persistable import Persistable class Error(Exception): @@ -24,8 +36,10 @@ class InvalidStartStateError(Error): log = logging.getLogger(__name__) +T = TypeVar("T", bound="Job") -class Job(Observable, Observer): + +class Job(Observable, Observer, Persistable): """A configurable data object. Job uses JobRunCollection to manage its runs, and ActionGraph to manage its @@ -61,29 +75,30 @@ class Job(Observable, Observer): "run_limit", ] - # TODO: use config object def __init__( self, - name, - scheduler, - queueing=True, - all_nodes=False, - monitoring=None, - node_pool=None, - enabled=True, - action_graph=None, - run_collection=None, - parent_context=None, - output_path=None, - allow_overlap=None, - action_runner=None, - max_runtime=None, - time_zone=None, - expected_runtime=None, - run_limit=None, + name: str, + scheduler: GeneralScheduler, + queueing: bool = True, + all_nodes: bool = False, + monitoring: Optional[Dict[str, Any]] = None, + node_pool: Optional[NodePool] = None, + enabled: bool = True, + action_graph: Optional[ActionGraph] = None, + run_collection: Optional[JobRunCollection] = None, + parent_context: Optional[command_context.CommandContext] = None, + output_path: Optional[filehandler.OutputPath] = None, + allow_overlap: Optional[bool] = None, + action_runner: Optional[SubprocessActionRunnerFactory] = None, + max_runtime: Optional[datetime.timedelta] = None, + time_zone: Optional[datetime.tzinfo] = None, + expected_runtime: Optional[datetime.timedelta] = None, + run_limit: Optional[int] = None, ): super().__init__() - self.name = maybe_decode(name) + self.name = maybe_decode( + name + ) # TODO: TRON-2293 maybe_decode is a relic of Python2->Python3 migration. Remove it. self.monitoring = monitoring self.action_graph = action_graph self.scheduler = scheduler @@ -107,6 +122,15 @@ def __init__( self.run_limit = run_limit log.info(f"{self} created") + @staticmethod + def to_json(state_data: dict) -> Optional[str]: + """Serialize the Job instance to a JSON string.""" + try: + return json.dumps(state_data) + except Exception: + log.exception("Error serializing Job to JSON:") + raise + @classmethod def from_config( cls, diff --git a/tron/core/job_collection.py b/tron/core/job_collection.py index 16b3b8404..dbc3fc8d5 100644 --- a/tron/core/job_collection.py +++ b/tron/core/job_collection.py @@ -40,7 +40,9 @@ def reconfigure_filter(config): else: return config.namespace == namespace_to_reconfigure - # NOTE: as this is a generator expression, we will only go through job configs and build a scheduler for them once something iterates over us (i.e, once `self.state_watcher.watch_all()` is called) + # NOTE: as this is a generator expression, we will only go through job configs + # and build a scheduler for them once something iterates over us (i.e, once + # `self.state_watcher.watch_all()` is called) seq = (factory.build(config) for config in job_configs.values() if reconfigure_filter(config)) return map_to_job_and_schedule(filter(self.add, seq)) diff --git a/tron/core/job_scheduler.py b/tron/core/job_scheduler.py index f3d26b344..3b8262c1c 100644 --- a/tron/core/job_scheduler.py +++ b/tron/core/job_scheduler.py @@ -20,7 +20,7 @@ class JobScheduler(Observer): x seconds into the future. """ - def __init__(self, job): + def __init__(self, job: Job): self.job = job self.watch(job) diff --git a/tron/core/jobgraph.py b/tron/core/jobgraph.py index 4c28fd89d..82147bbd9 100644 --- a/tron/core/jobgraph.py +++ b/tron/core/jobgraph.py @@ -1,6 +1,8 @@ from collections import defaultdict from collections import namedtuple +from typing import Optional +from tron.config.config_parse import ConfigContainer from tron.core.action import Action from tron.core.actiongraph import ActionGraph from tron.utils import maybe_decode @@ -13,7 +15,7 @@ class JobGraph: cross-job dependencies (aka triggers) """ - def __init__(self, config_container, should_validate_missing_dependency=False): + def __init__(self, config_container: ConfigContainer, should_validate_missing_dependency: Optional[bool] = False): """Build an adjacency list and a reverse adjacency list for the graph, and store all the actions as well as which actions belong to which job """ @@ -93,7 +95,9 @@ def get_action_graph_for_job(self, job_name): return ActionGraph(job_action_map, required_actions, required_triggers) def _save_action(self, action_name, job_name, config): - action_name = maybe_decode(action_name) + action_name = maybe_decode( + action_name + ) # TODO: TRON-2293 maybe_decode is a relic of Python2->Python3 migration. Remove it. full_name = f"{job_name}.{action_name}" self.action_map[full_name] = Action.from_config(config) self._actions_for_job[job_name].append(full_name) diff --git a/tron/core/jobrun.py b/tron/core/jobrun.py index b34d23ca9..061ef53c4 100644 --- a/tron/core/jobrun.py +++ b/tron/core/jobrun.py @@ -1,6 +1,7 @@ """ Classes to manage job runs. """ +import datetime import json import logging import time @@ -12,6 +13,7 @@ from tron import node from tron.core.actiongraph import ActionGraph from tron.core.actionrun import ActionRun +from tron.core.actionrun import ActionRunCollection from tron.core.actionrun import ActionRunFactory from tron.serialize import filehandler from tron.utils import maybe_decode @@ -20,6 +22,7 @@ from tron.utils import timeutils from tron.utils.observer import Observable from tron.utils.observer import Observer +from tron.utils.persistable import Persistable log = logging.getLogger(__name__) state_logger = logging.getLogger(f"{__name__}.state_changes") @@ -29,11 +32,11 @@ class Error(Exception): pass -def get_job_run_id(job_name, run_num): +def get_job_run_id(job_name: str, run_num: int) -> str: return f"{job_name}.{run_num}" -class JobRun(Observable, Observer): +class JobRun(Observable, Observer, Persistable): """A JobRun is an execution of a Job. It has a list of ActionRuns and is responsible for starting ActionRuns in the correct order and managing their dependencies. @@ -48,18 +51,20 @@ class JobRun(Observable, Observer): # TODO: use config object def __init__( self, - job_name, - run_num, - run_time, - node, - output_path=None, - base_context=None, - action_runs=None, + job_name: str, + run_num: int, + run_time: datetime.datetime, + node: node.Node, + output_path: Optional[filehandler.OutputPath] = None, + base_context: Optional[command_context.CommandContext] = None, + action_runs: Optional[ActionRunCollection] = None, action_graph: Optional[ActionGraph] = None, - manual=None, + manual: Optional[bool] = None, ): super().__init__() - self.job_name = maybe_decode(job_name) + self.job_name = maybe_decode( + job_name + ) # TODO: TRON-2293 - maybe_decode is a relic of Python2->Python3 migration. Remove it. self.run_num = run_num self.run_time = run_time self.node = node @@ -75,6 +80,28 @@ def __init__( self.context = command_context.build_context(self, base_context) + @staticmethod + def to_json(state_data: dict) -> Optional[str]: + """Serialize the JobRun instance to a JSON string.""" + try: + return json.dumps( + { + "job_name": state_data["job_name"], + "run_num": state_data["run_num"], + "run_time": state_data["run_time"].isoformat() if state_data["run_time"] else None, + "node_name": state_data["node_name"], + "runs": [ActionRun.to_json(run) for run in state_data["runs"]], + "cleanup_run": ActionRun.to_json(state_data["cleanup_run"]) if state_data["cleanup_run"] else None, + "manual": state_data["manual"], + } + ) + except KeyError: + log.exception("Missing key in state_data:") + raise + except Exception: + log.exception("Error serializing JobRun to JSON:") + raise + @property def id(self): return get_job_run_id(self.job_name, self.run_num) diff --git a/tron/mcp.py b/tron/mcp.py index e6b316c10..cf42f52e5 100644 --- a/tron/mcp.py +++ b/tron/mcp.py @@ -192,7 +192,6 @@ def restore_state(self, action_runner): log.info( f"Tron completed restoring state for the jobs. Time elapsed since Tron started {time.time() - self.boot_time}" ) - self.state_watcher.save_metadata() def __str__(self): return "MCP" diff --git a/tron/prom_metrics.py b/tron/prom_metrics.py index 4a11bb8fd..d6eef803c 100644 --- a/tron/prom_metrics.py +++ b/tron/prom_metrics.py @@ -1,6 +1,12 @@ +from prometheus_client import Counter from prometheus_client import Gauge tron_cpu_gauge = Gauge("tron_k8s_cpus", "Total number of CPUs allocated to Tron-launched containers") tron_memory_gauge = Gauge("tron_k8s_mem", "Total amount of memory allocated to Tron-launched containers (in megabytes)") tron_disk_gauge = Gauge("tron_k8s_disk", "Total amount of disk allocated to Tron-launched containers (in megabytes)") + +json_serialization_errors_counter = Counter( + "json_serialization_errors_total", + "Total number of errors encountered while serializing state_data as JSON. These errors occur before writing to DynamoDB.", +) diff --git a/tron/serialize/filehandler.py b/tron/serialize/filehandler.py index 6c3dd8190..f7941ba4f 100644 --- a/tron/serialize/filehandler.py +++ b/tron/serialize/filehandler.py @@ -69,7 +69,9 @@ def write(self, content): return self.last_accessed = time.time() - self._fh.write(maybe_encode(content)) + self._fh.write( + maybe_encode(content) + ) # TODO: TRON-2293 maybe_encode is a relic of Python2->Python3 migration. Remove it. self.manager.update(self) def __enter__(self): diff --git a/tron/serialize/runstate/__init__.py b/tron/serialize/runstate/__init__.py index 7b16b0033..4f31a1807 100644 --- a/tron/serialize/runstate/__init__.py +++ b/tron/serialize/runstate/__init__.py @@ -1,5 +1,4 @@ # State types JOB_STATE = "job_state" JOB_RUN_STATE = "job_run_state" -MCP_STATE = "mcp_state" MESOS_STATE = "mesos_state" diff --git a/tron/serialize/runstate/dynamodb_state_store.py b/tron/serialize/runstate/dynamodb_state_store.py index 4706596a2..34b3a6976 100644 --- a/tron/serialize/runstate/dynamodb_state_store.py +++ b/tron/serialize/runstate/dynamodb_state_store.py @@ -8,18 +8,34 @@ import time from collections import defaultdict from collections import OrderedDict +from typing import Any from typing import DefaultDict +from typing import Dict from typing import List +from typing import Literal +from typing import Optional from typing import Sequence +from typing import Tuple from typing import TypeVar import boto3 # type: ignore +import tron.prom_metrics as prom_metrics +from tron.core.job import Job +from tron.core.jobrun import JobRun from tron.metrics import timer +from tron.serialize import runstate -OBJECT_SIZE = 400000 +# Max DynamoDB object size is 400KB. Since we save two copies of the object (pickled and JSON), +# we need to consider this max size applies to the entire item, so we use a max size of 200KB +# for each version. +# +# In testing I could get away with 201_000 for both partitions so this should be enough overhead +# to contain other attributes like object name and number of partitions. +OBJECT_SIZE = 200_000 # TODO: TRON-2240 - consider swapping back to 400_000 now that we've removed pickles MAX_SAVE_QUEUE = 500 MAX_ATTEMPTS = 10 +MAX_TRANSACT_WRITE_ITEMS = 100 log = logging.getLogger(__name__) T = TypeVar("T") @@ -46,7 +62,7 @@ def build_key(self, type, iden) -> str: def restore(self, keys) -> dict: """ - Fetch all under the same parition key(keys). + Fetch all under the same parition key(s). ret: """ first_items = self._get_first_partitions(keys) @@ -106,6 +122,7 @@ def _get_first_partitions(self, keys: list): new_keys = [{"key": {"S": key}, "index": {"N": "0"}} for key in keys] return self._get_items(new_keys) + # TODO: Check max partitions as JSON is larger def _get_remaining_partitions(self, items: list): """Get items in the remaining partitions: N = 1 and beyond""" keys_for_remaining_items = [] @@ -142,7 +159,12 @@ def save(self, key_value_pairs) -> None: time.sleep(5) continue with self.save_lock: - self.save_queue[key] = val + if val is None: + self.save_queue[key] = (val, None) + else: + state_type = self.get_type_from_key(key) + json_val = self._serialize_item(state_type, val) + self.save_queue[key] = (val, json_val) break def _consume_save_queue(self): @@ -152,19 +174,21 @@ def _consume_save_queue(self): for _ in range(qlen): try: with self.save_lock: - key, val = self.save_queue.popitem(last=False) + key, (val, json_val) = self.save_queue.popitem(last=False) # Remove all previous data with the same partition key # TODO: only remove excess partitions if new data has fewer self._delete_item(key) if val is not None: - self[key] = pickle.dumps(val) + self[key] = (pickle.dumps(val), json_val) # reset errors count if we can successfully save saved += 1 except Exception as e: - error = "tron_dynamodb_save_failure: failed to save key" f'"{key}" to dynamodb:\n{repr(e)}' + error = "tron_dynamodb_save_failure: failed to save key " f'"{key}" to dynamodb:\n{repr(e)}' log.error(error) + # Add items back to the queue if we failed to save. While we roll out and test TRON-2237 we will only re-add the Pickle. + # TODO: TRON-2239 - Pass JSON back to the save queue with self.save_lock: - self.save_queue[key] = val + self.save_queue[key] = (val, None) duration = time.time() - start log.info(f"saved {saved} items in {duration}s") @@ -173,6 +197,23 @@ def _consume_save_queue(self): else: self.save_errors = 0 + def get_type_from_key(self, key: str) -> str: + return key.split()[0] + + # TODO: TRON-2305 - In an ideal world, we wouldn't be passing around state/state_data dicts. It would be a lot nicer to have regular objects here + def _serialize_item(self, key: Literal[runstate.JOB_STATE, runstate.JOB_RUN_STATE], state: Dict[str, Any]) -> Optional[str]: # type: ignore + try: + if key == runstate.JOB_STATE: + return Job.to_json(state) + elif key == runstate.JOB_RUN_STATE: + return JobRun.to_json(state) + else: + raise ValueError(f"Unknown type: key {key}") + except Exception: + log.exception(f"Serialization error for key {key}") + prom_metrics.json_serialization_errors_counter.inc() + return None + def _save_loop(self): while True: if self.stopping: @@ -189,19 +230,25 @@ def _save_loop(self): log.error("too many dynamodb errors in a row, crashing") os.exit(1) - def __setitem__(self, key: str, val: bytes) -> None: + def __setitem__(self, key: str, value: Tuple[bytes, str]) -> None: """ - Partition the item and write up to 10 partitions atomically. - Retry up to 3 times on failure + Partition the item and write up to MAX_TRANSACT_WRITE_ITEMS + partitions atomically. Retry up to 3 times on failure. - Examine the size of `val`, and splice it into - different parts under 400KB with different sort keys, - and save them under the same partition key built. + Examine the size of `pickled_val` and `json_val`, and + splice them into different parts based on `OBJECT_SIZE` + with different sort keys, and save them under the same + partition key built. """ start = time.time() - num_partitions = math.ceil(len(val) / OBJECT_SIZE) + + pickled_val, json_val = value + num_partitions = math.ceil(len(pickled_val) / OBJECT_SIZE) + num_json_val_partitions = math.ceil(len(json_val) / OBJECT_SIZE) if json_val else 0 items = [] - for index in range(num_partitions): + + max_partitions = max(num_partitions, num_json_val_partitions) + for index in range(max_partitions): item = { "Put": { "Item": { @@ -212,7 +259,9 @@ def __setitem__(self, key: str, val: bytes) -> None: "N": str(index), }, "val": { - "B": val[index * OBJECT_SIZE : min(index * OBJECT_SIZE + OBJECT_SIZE, len(val))], + "B": pickled_val[ + index * OBJECT_SIZE : min(index * OBJECT_SIZE + OBJECT_SIZE, len(pickled_val)) + ], }, "num_partitions": { "N": str(num_partitions), @@ -221,10 +270,19 @@ def __setitem__(self, key: str, val: bytes) -> None: "TableName": self.name, }, } + + if json_val: + item["Put"]["Item"]["json_val"] = { + "S": json_val[index * OBJECT_SIZE : min(index * OBJECT_SIZE + OBJECT_SIZE, len(json_val))] + } + item["Put"]["Item"]["num_json_val_partitions"] = { + "N": str(num_json_val_partitions), + } + count = 0 items.append(item) - # Only up to 10 items are allowed per transactions - while len(items) == 10 or index == num_partitions - 1: + + while len(items) == MAX_TRANSACT_WRITE_ITEMS or index == max_partitions - 1: try: self.client.transact_write_items(TransactItems=items) items = [] @@ -236,6 +294,7 @@ def __setitem__(self, key: str, val: bytes) -> None: name="tron.dynamodb.setitem", delta=time.time() - start, ) + log.error(f"Failed to save partition for key: {key}, error: {repr(e)}") raise e else: log.warning(f"Got error while saving {key}, trying again: {repr(e)}") @@ -244,6 +303,7 @@ def __setitem__(self, key: str, val: bytes) -> None: delta=time.time() - start, ) + # TODO: TRON-2238 - Is this ok if we just use the max number of partitions? def _delete_item(self, key: str) -> None: start = time.time() try: @@ -261,9 +321,10 @@ def _delete_item(self, key: str) -> None: delta=time.time() - start, ) + # TODO: TRON-2238 - Get max partitions between pickle and json def _get_num_of_partitions(self, key: str) -> int: """ - Return how many parts is the item partitioned into + Return the number of partitions an item is divided into. """ try: partition = self.table.get_item( diff --git a/tron/serialize/runstate/shelvestore.py b/tron/serialize/runstate/shelvestore.py index 22d1bf8ac..4e275b7f5 100644 --- a/tron/serialize/runstate/shelvestore.py +++ b/tron/serialize/runstate/shelvestore.py @@ -12,6 +12,7 @@ log = logging.getLogger(__name__) +# TODO: TRON-2293 This class does some Python 2 and Python 3 handling shenanigans. It should be cleaned up. class Py2Shelf(shelve.Shelf): def __init__(self, filename, flag="c", protocol=2, writeback=False): db = bsddb3.hashopen(filename, flag) @@ -52,8 +53,12 @@ class ShelveKey: __slots__ = ["type", "iden"] def __init__(self, type, iden): - self.type = maybe_decode(type) - self.iden = maybe_decode(iden) + self.type = maybe_decode( + type + ) # TODO: TRON-2293 maybe_decode is a relic of Python2->Python3 migration. Remove it. + self.iden = maybe_decode( + iden + ) # TODO: TRON-2293 maybe_decode is a relic of Python2->Python3 migration. Remove it. @property def key(self): diff --git a/tron/serialize/runstate/statemanager.py b/tron/serialize/runstate/statemanager.py index 5114261a2..f37b4882b 100644 --- a/tron/serialize/runstate/statemanager.py +++ b/tron/serialize/runstate/statemanager.py @@ -6,6 +6,7 @@ import time from contextlib import contextmanager from typing import Dict +from typing import List from tron.config import schema from tron.core import job @@ -53,42 +54,6 @@ def from_config(cls, persistence_config): return PersistentStateManager(store, buffer) -class StateMetadata: - """A data object for saving state metadata. Conforms to the same - RunState interface as Jobs and Services. - """ - - name = "StateMetadata" - - # State schema version, only first component counts, - # for backwards compatibility - version = (0, 7, 0, 0) - - def __init__(self): - self.state_data = { - "version": self.version, - "create_time": time.time(), - } - - @classmethod - def validate_metadata(cls, metadata): - """Raises an exception if the metadata version is newer then - StateMetadata.version - """ - if not metadata: - return - - if metadata["version"][0] > cls.version[0]: - msg = "State version %s, expected <= %s" - raise VersionMismatchError( - msg - % ( - metadata["version"], - cls.version, - ), - ) - - class StateSaveBuffer: """Buffer calls to save, and perform the saves when buffer reaches buffer size. This buffer will only store one state_data for each key. @@ -138,23 +103,17 @@ def __init__(self, persistence_impl, buffer): self.enabled = True self._buffer = buffer self._impl = persistence_impl - self.metadata_key = self._impl.build_key( - runstate.MCP_STATE, - StateMetadata.name, - ) - def restore(self, job_names, skip_validation=False): + def restore(self, job_names): """Return the most recent serialized state.""" log.debug("Restoring state.") - if not skip_validation: - self._restore_metadata() # First, restore the jobs themselves jobs = self._restore_dicts(runstate.JOB_STATE, job_names) # jobs should be a dictionary that contains job name and number of runs # {'MASTER.k8s': {'run_nums':[0], 'enabled': True}, 'MASTER.cits_test_frequent_1': {'run_nums': [1,0], 'enabled': True}} - # second, restore the runs for each of the jobs restored above + # Second, restore the runs for each of the jobs restored above with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor: # start the threads and mark each future with it's job name # this is useful so that we can index the job name later to add the runs to the jobs dictionary @@ -190,16 +149,12 @@ def _restore_runs_for_job(self, job_name, job_state): runs.sort(key=lambda x: x["run_num"], reverse=True) return runs - def _restore_metadata(self): - metadata = self._impl.restore([self.metadata_key]) - StateMetadata.validate_metadata(metadata.get(self.metadata_key)) - def _keys_for_items(self, item_type, names): """Returns a dict of item to the key for that item.""" keys = (self._impl.build_key(item_type, name) for name in names) return dict(zip(keys, names)) - def _restore_dicts(self, item_type, items) -> Dict[str, dict]: + def _restore_dicts(self, item_type: str, items: List[str]) -> Dict[str, dict]: """Return a dict mapping of the items name to its state data.""" key_to_item_map = self._keys_for_items(item_type, items) key_to_state_map = self._impl.restore(key_to_item_map.keys()) @@ -323,9 +278,6 @@ def delete_job_run(self, job_run): def save_frameworks(self, clusters): self._save_object(runstate.MESOS_STATE, clusters) - def save_metadata(self): - self._save_object(runstate.MCP_STATE, StateMetadata()) - def _save_object(self, state_type, obj): self.state_manager.save(state_type, obj.name, obj.state_data) diff --git a/tron/serialize/runstate/yamlstore.py b/tron/serialize/runstate/yamlstore.py index af770b908..4c8d760c8 100644 --- a/tron/serialize/runstate/yamlstore.py +++ b/tron/serialize/runstate/yamlstore.py @@ -15,7 +15,6 @@ TYPE_MAPPING = { runstate.JOB_STATE: "jobs", - runstate.MCP_STATE: runstate.MCP_STATE, } diff --git a/tron/utils/__init__.py b/tron/utils/__init__.py index f39abcc13..45225d948 100644 --- a/tron/utils/__init__.py +++ b/tron/utils/__init__.py @@ -7,12 +7,14 @@ log = logging.getLogger(__name__) +# TODO: TRON-2293 maybe_decode is a relic of Python2->Python3 migration. Remove it. def maybe_decode(maybe_string): if type(maybe_string) is bytes: return maybe_string.decode() return maybe_string +# TODO: TRON-2293 maybe_encode is a relic of Python2->Python3 migration. Remove it. def maybe_encode(maybe_bytes): if type(maybe_bytes) is not bytes: return maybe_bytes.encode() diff --git a/tron/utils/crontab.py b/tron/utils/crontab.py index 953bf9748..1ee5bd07e 100644 --- a/tron/utils/crontab.py +++ b/tron/utils/crontab.py @@ -2,6 +2,11 @@ import calendar import itertools import re +from typing import List +from typing import Optional +from typing import Set +from typing import Tuple +from typing import Union PREDEFINED_SCHEDULE = { "@yearly": "0 0 1 1 *", @@ -14,7 +19,7 @@ } -def convert_predefined(line): +def convert_predefined(line: str) -> str: if not line.startswith("@"): return line @@ -23,11 +28,12 @@ def convert_predefined(line): return PREDEFINED_SCHEDULE[line] +# TODO: TRON-1761 - Fix cron validation. The pattern is not working as expected. class FieldParser: """Parse and validate a field in a crontab entry.""" - name = None - bounds = None + name: str = "" + bounds: Tuple[int, int] = (0, 0) range_pattern = re.compile( r""" (?P\d+|\*) # Initial value @@ -37,34 +43,35 @@ class FieldParser: re.VERBOSE, ) - def normalize(self, source): + def normalize(self, source: str) -> str: return source.strip() - def get_groups(self, source): + def get_groups(self, source: str) -> List[str]: return source.split(",") - def parse(self, source): + def parse(self, source: str) -> Optional[Union[List[int], List[Union[int, str]]]]: if source == "*": return None - groups = [self.get_values(group) for group in self.get_groups(source)] - groups = set(itertools.chain.from_iterable(groups)) - has_last = False - if "LAST" in groups: - has_last = True + groups: Set[Union[int, str]] = set( + itertools.chain.from_iterable(self.get_values(group) for group in self.get_groups(source)) + ) + has_last = "LAST" in groups + if has_last: groups.remove("LAST") - groups = sorted(groups) + sorted_groups: List[Union[int, str]] = sorted(groups, key=lambda x: (isinstance(x, str), x)) if has_last: - groups.append("LAST") - return groups + sorted_groups.append("LAST") + + return sorted_groups - def get_match_groups(self, source): + def get_match_groups(self, source: str) -> dict: match = self.range_pattern.match(source) if not match: raise ValueError("Unknown expression: %s" % source) return match.groupdict() - def get_values(self, source): + def get_values(self, source: str) -> List[Union[int, str]]: source = self.normalize(source) match_groups = self.get_match_groups(source) step = 1 @@ -74,7 +81,7 @@ def get_values(self, source): step = self.validate_bounds(match_groups["step"]) return self.get_range(min_value, max_value, step) - def get_value_range(self, match_groups): + def get_value_range(self, match_groups: dict) -> Tuple[int, int]: if match_groups["min"] == "*": return self.bounds @@ -86,7 +93,7 @@ def get_value_range(self, match_groups): return min_value, min_value + 1 - def get_range(self, min_value, max_value, step): + def get_range(self, min_value: int, max_value: int, step: int) -> List[Union[int, str]]: if min_value < max_value: return list(range(min_value, max_value, step)) @@ -94,12 +101,12 @@ def get_range(self, min_value, max_value, step): diff = (max_bound - min_value) + (max_value - min_bound) return [(min_value + i) % max_bound for i in list(range(0, diff, step))] - def validate_bounds(self, value): + def validate_bounds(self, value: str) -> int: min_value, max_value = self.bounds - value = int(value) - if not min_value <= value < max_value: - raise ValueError(f"{self.name} value out of range: {value}") - return value + int_value = int(value) + if not min_value <= int_value < max_value: + raise ValueError(f"{self.name} value out of range: {int_value}") + return int_value class MinuteFieldParser(FieldParser): @@ -116,7 +123,7 @@ class MonthdayFieldParser(FieldParser): name = "monthdays" bounds = (1, 32) - def get_values(self, source): + def get_values(self, source: str) -> List[Union[int, str]]: # Handle special case for last day of month source = self.normalize(source) if source == "L": @@ -130,7 +137,7 @@ class MonthFieldParser(FieldParser): bounds = (1, 13) month_names = calendar.month_abbr[1:] - def normalize(self, month): + def normalize(self, month: str) -> str: month = super().normalize(month) month = month.lower() for month_num, month_name in enumerate(self.month_names, start=1): @@ -143,7 +150,7 @@ class WeekdayFieldParser(FieldParser): bounds = (0, 7) day_names = ["sun", "mon", "tue", "wed", "thu", "fri", "sat"] - def normalize(self, day_of_week): + def normalize(self, day_of_week: str) -> str: day_of_week = super().normalize(day_of_week) day_of_week = day_of_week.lower() for dow_num, dow_name in enumerate(self.day_names): @@ -159,7 +166,7 @@ def normalize(self, day_of_week): # TODO: support L (for dow), W, # -def parse_crontab(line): +def parse_crontab(line: str) -> dict: line = convert_predefined(line) minutes, hours, dom, months, dow = line.split(None, 4) diff --git a/tron/utils/persistable.py b/tron/utils/persistable.py new file mode 100644 index 000000000..620956a2a --- /dev/null +++ b/tron/utils/persistable.py @@ -0,0 +1,12 @@ +from abc import ABC +from abc import abstractmethod +from typing import Any +from typing import Dict +from typing import Optional + + +class Persistable(ABC): + @staticmethod + @abstractmethod + def to_json(state_data: Dict[Any, Any]) -> Optional[str]: + pass