diff --git a/CHANGELOG.md b/CHANGELOG.md index 54648571..ae82a55d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,11 +6,45 @@ The format is based on [Keep a Changelog](https://keepachangelog.com). ## Unreleased +Nothing notable unreleased. + +## 1.3.0 (2022-10-11) + +### Added + +* (flags) Added a new `absl.flags.set_default` function that updates the flag + default for a provided `FlagHolder`. This parallels the + `absl.flags.FlagValues.set_default` interface which takes a flag name. +* (flags) The following functions now also accept `FlagHolder` instance(s) in + addition to flag name(s) as their first positional argument: + - `flags.register_validator` + - `flags.validator` + - `flags.register_multi_flags_validator` + - `flags.multi_flags_validator` + - `flags.mark_flag_as_required` + - `flags.mark_flags_as_required` + - `flags.mark_flags_as_mutual_exclusive` + - `flags.mark_bool_flags_as_mutual_exclusive` + - `flags.declare_key_flag` + ### Changed * (testing) Assertions `assertRaisesWithPredicateMatch` and `assertRaisesWithLiteralMatch` now capture the raised `Exception` for further analysis when used as a context manager. +* (testing) TextAndXMLTestRunner now produces time duration values with + millisecond precision in XML test result output. +* (flags) Keyword access to `flag_name` arguments in the following functions + is deprecated. This parameter will be renamed in a future 2.0.0 release. + - `flags.register_validator` + - `flags.validator` + - `flags.register_multi_flags_validator` + - `flags.multi_flags_validator` + - `flags.mark_flag_as_required` + - `flags.mark_flags_as_required` + - `flags.mark_flags_as_mutual_exclusive` + - `flags.mark_bool_flags_as_mutual_exclusive` + - `flags.declare_key_flag` ## 1.2.0 (2022-07-18) diff --git a/absl/flags/__init__.py b/absl/flags/__init__.py index 45e64f33..6d8ba033 100644 --- a/absl/flags/__init__.py +++ b/absl/flags/__init__.py @@ -68,6 +68,8 @@ 'mark_flags_as_required', 'mark_flags_as_mutual_exclusive', 'mark_bool_flags_as_mutual_exclusive', + # Flag modifiers. + 'set_default', # Key flag related functions. 'declare_key_flag', 'adopt_module_key_flags', @@ -152,6 +154,9 @@ mark_flags_as_mutual_exclusive = _validators.mark_flags_as_mutual_exclusive mark_bool_flags_as_mutual_exclusive = _validators.mark_bool_flags_as_mutual_exclusive +# Flag modifiers. +set_default = _defines.set_default + # Key flag related functions. declare_key_flag = _defines.declare_key_flag adopt_module_key_flags = _defines.adopt_module_key_flags diff --git a/absl/flags/_argument_parser.py b/absl/flags/_argument_parser.py index 7a94c69b..2c4de9b1 100644 --- a/absl/flags/_argument_parser.py +++ b/absl/flags/_argument_parser.py @@ -147,7 +147,7 @@ class ArgumentSerializer(object): def serialize(self, value): """Returns a serialized string of the value.""" - return _helpers.str_or_unicode(value) + return str(value) class NumericParser(ArgumentParser): @@ -454,7 +454,7 @@ def __init__(self, list_sep): def serialize(self, value): """See base class.""" - return self.list_sep.join([_helpers.str_or_unicode(x) for x in value]) + return self.list_sep.join([str(x) for x in value]) class EnumClassListSerializer(ListSerializer): @@ -498,7 +498,7 @@ def serialize(self, value): # We need the returned value to be pure ascii or Unicodes so that # when the xml help is generated they are usefully encodable. - return _helpers.str_or_unicode(serialized_value) + return str(serialized_value) class EnumClassSerializer(ArgumentSerializer): @@ -514,7 +514,7 @@ def __init__(self, lowercase): def serialize(self, value): """Returns a serialized string of the Enum class value.""" - as_string = _helpers.str_or_unicode(value.name) + as_string = str(value.name) return as_string.lower() if self._lowercase else as_string diff --git a/absl/flags/_defines.py b/absl/flags/_defines.py index 12335e5d..dce53ea2 100644 --- a/absl/flags/_defines.py +++ b/absl/flags/_defines.py @@ -148,6 +148,23 @@ def DEFINE_flag( # pylint: disable=invalid-name fv, flag, ensure_non_none_value=ensure_non_none_value) +def set_default(flag_holder, value): + """Changes the default value of the provided flag object. + + The flag's current value is also updated if the flag is currently using + the default value, i.e. not specified in the command line, and not set + by FLAGS.name = value. + + Args: + flag_holder: FlagHolder, the flag to modify. + value: The new default value. + + Raises: + IllegalFlagValueError: Raised when value is not valid. + """ + flag_holder._flagvalues.set_default(flag_holder.name, value) # pylint: disable=protected-access + + def _internal_declare_key_flags(flag_names, flag_values=_flagvalues.FLAGS, key_flag_values=None): @@ -157,8 +174,7 @@ def _internal_declare_key_flags(flag_names, adopt_module_key_flags instead. Args: - flag_names: [str], a list of strings that are names of already-registered - Flag objects. + flag_names: [str], a list of names of already-registered Flag objects. flag_values: :class:`FlagValues`, the FlagValues instance with which the flags listed in flag_names have registered (the value of the flag_values argument from the ``DEFINE_*`` calls that defined those flags). This @@ -176,8 +192,7 @@ def _internal_declare_key_flags(flag_names, module = _helpers.get_calling_module() for flag_name in flag_names: - flag = flag_values[flag_name] - key_flag_values.register_key_flag_for_module(module, flag) + key_flag_values.register_key_flag_for_module(module, flag_values[flag_name]) def declare_key_flag(flag_name, flag_values=_flagvalues.FLAGS): @@ -194,9 +209,10 @@ def declare_key_flag(flag_name, flag_values=_flagvalues.FLAGS): flags.declare_key_flag('flag_1') Args: - flag_name: str, the name of an already declared flag. (Redeclaring flags as - key, including flags implicitly key because they were declared in this - module, is a no-op.) + flag_name: str | :class:`FlagHolder`, the name or holder of an already + declared flag. (Redeclaring flags as key, including flags implicitly key + because they were declared in this module, is a no-op.) + Positional-only parameter. flag_values: :class:`FlagValues`, the FlagValues instance in which the flag will be declared as a key flag. This should almost never need to be overridden. @@ -204,6 +220,7 @@ def declare_key_flag(flag_name, flag_values=_flagvalues.FLAGS): Raises: ValueError: Raised if flag_name not defined as a Python flag. """ + flag_name, flag_values = _flagvalues.resolve_flag_ref(flag_name, flag_values) if flag_name in _helpers.SPECIAL_FLAGS: # Take care of the special flags, e.g., --flagfile, --undefok. # These flags are defined in SPECIAL_FLAGS, and are treated diff --git a/absl/flags/_defines.pyi b/absl/flags/_defines.pyi index 0fbe9215..9bc8067a 100644 --- a/absl/flags/_defines.pyi +++ b/absl/flags/_defines.pyi @@ -650,8 +650,11 @@ def DEFINE_alias( ... +def set_default(flag_holder: _flagvalues.FlagHolder[_T], value: _T) -> None: + ... + -def declare_key_flag(flag_name: Text, +def declare_key_flag(flag_name: Union[Text, _flagvalues.FlagHolder], flag_values: _flagvalues.FlagValues = ...) -> None: ... diff --git a/absl/flags/_flag.py b/absl/flags/_flag.py index 28d92196..124f1371 100644 --- a/absl/flags/_flag.py +++ b/absl/flags/_flag.py @@ -153,7 +153,7 @@ def _get_parsed_value_as_string(self, value): return repr('true') else: return repr('false') - return repr(_helpers.str_or_unicode(value)) + return repr(str(value)) def parse(self, argument): """Parses string and sets flag value. diff --git a/absl/flags/_flagvalues.py b/absl/flags/_flagvalues.py index c52990c6..937dc6c2 100644 --- a/absl/flags/_flagvalues.py +++ b/absl/flags/_flagvalues.py @@ -412,11 +412,7 @@ def __setitem__(self, name, flag): fl = self._flags() if not isinstance(flag, _flag.Flag): raise _exceptions.IllegalFlagValueError(flag) - if str is bytes and isinstance(name, unicode): - # When using Python 2 with unicode_literals, allow it but encode it - # into the bytes type we require. - name = name.encode('utf-8') - if not isinstance(name, type('')): + if not isinstance(name, str): raise _exceptions.Error('Flag name must be a string') if not name: raise _exceptions.Error('Flag name cannot be empty') @@ -632,7 +628,7 @@ def __call__(self, argv, known_only=False): TypeError: Raised on passing wrong type of arguments. ValueError: Raised on flag value parsing error. """ - if _helpers.is_bytes_or_string(argv): + if isinstance(argv, (str, bytes)): raise TypeError( 'argv should be a tuple/list of strings, not bytes or string.') if not argv: @@ -1006,7 +1002,7 @@ def get_flag_value(self, name, default): # pylint: disable=invalid-name def _is_flag_file_directive(self, flag_string): """Checks whether flag_string contain a --flagfile= directive.""" - if isinstance(flag_string, type('')): + if isinstance(flag_string, str): if flag_string.startswith('--flagfile='): return 1 elif flag_string == '--flagfile': @@ -1388,3 +1384,35 @@ def default(self): def present(self): """Returns True if the flag was parsed from command-line flags.""" return bool(self._flagvalues[self._name].present) + + +def resolve_flag_ref(flag_ref, flag_values): + """Helper to validate and resolve a flag reference argument.""" + if isinstance(flag_ref, FlagHolder): + new_flag_values = flag_ref._flagvalues # pylint: disable=protected-access + if flag_values != FLAGS and flag_values != new_flag_values: + raise ValueError( + 'flag_values must not be customized when operating on a FlagHolder') + return flag_ref.name, new_flag_values + return flag_ref, flag_values + + +def resolve_flag_refs(flag_refs, flag_values): + """Helper to validate and resolve flag reference list arguments.""" + fv = None + names = [] + for ref in flag_refs: + if isinstance(ref, FlagHolder): + newfv = ref._flagvalues # pylint: disable=protected-access + name = ref.name + else: + newfv = flag_values + name = ref + if fv and fv != newfv: + raise ValueError( + 'multiple FlagValues instances used in invocation. ' + 'FlagHolders must be registered to the same FlagValues instance as ' + 'do flag names, if provided.') + fv = newfv + names.append(name) + return names, fv diff --git a/absl/flags/_helpers.py b/absl/flags/_helpers.py index cb0cfb20..ea02f2d1 100644 --- a/absl/flags/_helpers.py +++ b/absl/flags/_helpers.py @@ -32,8 +32,9 @@ _DEFAULT_HELP_WIDTH = 80 # Default width of help output. -_MIN_HELP_WIDTH = 40 # Minimal "sane" width of help output. We assume that any - # value below 40 is unreasonable. +# Minimal "sane" width of help output. We assume that any value below 40 is +# unreasonable. +_MIN_HELP_WIDTH = 40 # Define the allowed error rate in an input string to get suggestions. # @@ -125,32 +126,6 @@ def get_calling_module(): return get_calling_module_object_and_name().module_name -def str_or_unicode(value): - """Converts a value to a python string. - - Behavior of this function is intentionally different in Python2/3. - - In Python2, the given value is attempted to convert to a str (byte string). - If it contains non-ASCII characters, it is converted to a unicode instead. - - In Python3, the given value is always converted to a str (unicode string). - - This behavior reflects the (bad) practice in Python2 to try to represent - a string as str as long as it contains ASCII characters only. - - Args: - value: An object to be converted to a string. - - Returns: - A string representation of the given value. See the description above - for its type. - """ - try: - return str(value) - except UnicodeEncodeError: - return unicode(value) # Python3 should never come here - - def create_xml_dom_element(doc, name, value): """Returns an XML DOM element with name and text value. @@ -164,7 +139,7 @@ def create_xml_dom_element(doc, name, value): Returns: An instance of minidom.Element. """ - s = str_or_unicode(value) + s = str(value) if isinstance(value, bool): # Display boolean values as the C++ flag library does: no caps. s = s.lower() @@ -424,10 +399,3 @@ def doc_to_help(doc): doc = re.sub(r'(?<=\S)\n(?=\S)', ' ', doc, flags=re.M) return doc - - -def is_bytes_or_string(maybe_string): - if str is bytes: - return isinstance(maybe_string, basestring) - else: - return isinstance(maybe_string, (str, bytes)) diff --git a/absl/flags/_validators.py b/absl/flags/_validators.py index c4e11392..2161284a 100644 --- a/absl/flags/_validators.py +++ b/absl/flags/_validators.py @@ -51,7 +51,8 @@ def register_validator(flag_name, change of the corresponding flag's value. Args: - flag_name: str, name of the flag to be checked. + flag_name: str | FlagHolder, name or holder of the flag to be checked. + Positional-only parameter. checker: callable, a function to validate the flag. * input - A single positional argument: The value of the corresponding @@ -70,7 +71,10 @@ def register_validator(flag_name, Raises: AttributeError: Raised when flag_name is not registered as a valid flag name. + ValueError: Raised when flag_values is non-default and does not match the + FlagValues of the provided FlagHolder instance. """ + flag_name, flag_values = _flagvalues.resolve_flag_ref(flag_name, flag_values) v = _validators_classes.SingleFlagValidator(flag_name, checker, message) _add_validator(flag_values, v) @@ -88,7 +92,8 @@ def _CheckFoo(foo): See :func:`register_validator` for the specification of checker function. Args: - flag_name: str, name of the flag to be checked. + flag_name: str | FlagHolder, name or holder of the flag to be checked. + Positional-only parameter. message: str, error text to be shown to the user if checker returns False. If checker raises flags.ValidationError, message from the raised error will be shown. @@ -119,7 +124,8 @@ def register_multi_flags_validator(flag_names, change of the corresponding flag's value. Args: - flag_names: [str], a list of the flag names to be checked. + flag_names: [str | FlagHolder], a list of the flag names or holders to be + checked. Positional-only parameter. multi_flags_checker: callable, a function to validate the flag. * input - dict, with keys() being flag_names, and value for each key @@ -136,7 +142,13 @@ def register_multi_flags_validator(flag_names, Raises: AttributeError: Raised when a flag is not registered as a valid flag name. + ValueError: Raised when multiple FlagValues are used in the same + invocation. This can occur when FlagHolders have different `_flagvalues` + or when str-type flag_names entries are present and the `flag_values` + argument does not match that of provided FlagHolder(s). """ + flag_names, flag_values = _flagvalues.resolve_flag_refs( + flag_names, flag_values) v = _validators_classes.MultiFlagsValidator( flag_names, multi_flags_checker, message) _add_validator(flag_values, v) @@ -157,7 +169,8 @@ def _CheckFooBar(flags_dict): function. Args: - flag_names: [str], a list of the flag names to be checked. + flag_names: [str | FlagHolder], a list of the flag names or holders to be + checked. Positional-only parameter. message: str, error text to be shown to the user if checker returns False. If checker raises flags.ValidationError, message from the raised error will be shown. @@ -196,13 +209,17 @@ def mark_flag_as_required(flag_name, flag_values=_flagvalues.FLAGS): app.run() Args: - flag_name: str, name of the flag + flag_name: str | FlagHolder, name or holder of the flag. + Positional-only parameter. flag_values: flags.FlagValues, optional :class:`~absl.flags.FlagValues` instance where the flag is defined. Raises: AttributeError: Raised when flag_name is not registered as a valid flag name. + ValueError: Raised when flag_values is non-default and does not match the + FlagValues of the provided FlagHolder instance. """ + flag_name, flag_values = _flagvalues.resolve_flag_ref(flag_name, flag_values) if flag_values[flag_name].default is not None: warnings.warn( 'Flag --%s has a non-None default value; therefore, ' @@ -227,7 +244,7 @@ def mark_flags_as_required(flag_names, flag_values=_flagvalues.FLAGS): app.run() Args: - flag_names: Sequence[str], names of the flags. + flag_names: Sequence[str | FlagHolder], names or holders of the flags. flag_values: flags.FlagValues, optional FlagValues instance where the flags are defined. Raises: @@ -248,13 +265,22 @@ def mark_flags_as_mutual_exclusive(flag_names, required=False, includes multi flags with a default value of ``[]`` instead of None. Args: - flag_names: [str], names of the flags. + flag_names: [str | FlagHolder], names or holders of flags. + Positional-only parameter. required: bool. If true, exactly one of the flags must have a value other than None. Otherwise, at most one of the flags can have a value other than None, and it is valid for all of the flags to be None. flag_values: flags.FlagValues, optional FlagValues instance where the flags are defined. + + Raises: + ValueError: Raised when multiple FlagValues are used in the same + invocation. This can occur when FlagHolders have different `_flagvalues` + or when str-type flag_names entries are present and the `flag_values` + argument does not match that of provided FlagHolder(s). """ + flag_names, flag_values = _flagvalues.resolve_flag_refs( + flag_names, flag_values) for flag_name in flag_names: if flag_values[flag_name].default is not None: warnings.warn( @@ -280,12 +306,21 @@ def mark_bool_flags_as_mutual_exclusive(flag_names, required=False, """Ensures that only one flag among flag_names is True. Args: - flag_names: [str], names of the flags. + flag_names: [str | FlagHolder], names or holders of flags. + Positional-only parameter. required: bool. If true, exactly one flag must be True. Otherwise, at most one flag can be True, and it is valid for all flags to be False. flag_values: flags.FlagValues, optional FlagValues instance where the flags are defined. + + Raises: + ValueError: Raised when multiple FlagValues are used in the same + invocation. This can occur when FlagHolders have different `_flagvalues` + or when str-type flag_names entries are present and the `flag_values` + argument does not match that of provided FlagHolder(s). """ + flag_names, flag_values = _flagvalues.resolve_flag_refs( + flag_names, flag_values) for flag_name in flag_names: if not flag_values[flag_name].boolean: raise _exceptions.ValidationError( diff --git a/absl/flags/_validators_classes.py b/absl/flags/_validators_classes.py index 28814991..59100c8e 100644 --- a/absl/flags/_validators_classes.py +++ b/absl/flags/_validators_classes.py @@ -156,7 +156,7 @@ def _get_input_to_checker_function(self, flag_values): Args: flag_values: flags.FlagValues, the FlagValues instance to get flags from. Returns: - dict, with keys() being self.lag_names, and value for each key + dict, with keys() being self.flag_names, and value for each key being the value of the corresponding flag (string, boolean, etc). """ return dict([key, flag_values[key].value] for key in self.flag_names) diff --git a/absl/flags/tests/_helpers_test.py b/absl/flags/tests/_helpers_test.py index 2697d1c0..78b90518 100644 --- a/absl/flags/tests/_helpers_test.py +++ b/absl/flags/tests/_helpers_test.py @@ -150,20 +150,5 @@ def iteritems(self): sys.modules = orig_sys_modules -class IsBytesOrString(absltest.TestCase): - - def test_bytes(self): - self.assertTrue(_helpers.is_bytes_or_string(b'bytes')) - - def test_str(self): - self.assertTrue(_helpers.is_bytes_or_string('str')) - - def test_unicode(self): - self.assertTrue(_helpers.is_bytes_or_string(u'unicode')) - - def test_list(self): - self.assertFalse(_helpers.is_bytes_or_string(['str'])) - - if __name__ == '__main__': absltest.main() diff --git a/absl/flags/tests/_validators_test.py b/absl/flags/tests/_validators_test.py index 1cccf535..9aa328e0 100644 --- a/absl/flags/tests/_validators_test.py +++ b/absl/flags/tests/_validators_test.py @@ -55,6 +55,45 @@ def checker(x): self.assertEqual(2, self.flag_values.test_flag) self.assertEqual([None, 2], self.call_args) + def test_success_holder(self): + def checker(x): + self.call_args.append(x) + return True + + flag_holder = _defines.DEFINE_integer( + 'test_flag', None, 'Usual integer flag', flag_values=self.flag_values) + _validators.register_validator( + flag_holder, + checker, + message='Errors happen', + flag_values=self.flag_values) + + argv = ('./program',) + self.flag_values(argv) + self.assertIsNone(self.flag_values.test_flag) + self.flag_values.test_flag = 2 + self.assertEqual(2, self.flag_values.test_flag) + self.assertEqual([None, 2], self.call_args) + + def test_success_holder_infer_flagvalues(self): + def checker(x): + self.call_args.append(x) + return True + + flag_holder = _defines.DEFINE_integer( + 'test_flag', None, 'Usual integer flag', flag_values=self.flag_values) + _validators.register_validator( + flag_holder, + checker, + message='Errors happen') + + argv = ('./program',) + self.flag_values(argv) + self.assertIsNone(self.flag_values.test_flag) + self.flag_values.test_flag = 2 + self.assertEqual(2, self.flag_values.test_flag) + self.assertEqual([None, 2], self.call_args) + def test_default_value_not_used_success(self): def checker(x): self.call_args.append(x) @@ -218,6 +257,26 @@ def checker(x): self.assertTrue(checker(3)) self.assertEqual([None, 2, 3], self.call_args) + def test_mismatching_flagvalues(self): + + def checker(x): + self.call_args.append(x) + return True + + flag_holder = _defines.DEFINE_integer( + 'test_flag', + None, + 'Usual integer flag', + flag_values=_flagvalues.FlagValues()) + expected = ( + 'flag_values must not be customized when operating on a FlagHolder') + with self.assertRaisesWithLiteralMatch(ValueError, expected): + _validators.register_validator( + flag_holder, + checker, + message='Errors happen', + flag_values=self.flag_values) + class MultiFlagsValidatorTest(absltest.TestCase): """Test flags multi-flag validators.""" @@ -226,9 +285,9 @@ def setUp(self): super(MultiFlagsValidatorTest, self).setUp() self.flag_values = _flagvalues.FlagValues() self.call_args = [] - _defines.DEFINE_integer( + self.foo_holder = _defines.DEFINE_integer( 'foo', 1, 'Usual integer flag', flag_values=self.flag_values) - _defines.DEFINE_integer( + self.bar_holder = _defines.DEFINE_integer( 'bar', 2, 'Usual integer flag', flag_values=self.flag_values) def test_success(self): @@ -248,6 +307,55 @@ def checker(flags_dict): self.assertEqual([{'foo': 1, 'bar': 2}, {'foo': 3, 'bar': 2}], self.call_args) + def test_success_holder(self): + + def checker(flags_dict): + self.call_args.append(flags_dict) + return True + + _validators.register_multi_flags_validator( + [self.foo_holder, self.bar_holder], + checker, + flag_values=self.flag_values) + + argv = ('./program', '--bar=2') + self.flag_values(argv) + self.assertEqual(1, self.flag_values.foo) + self.assertEqual(2, self.flag_values.bar) + self.assertEqual([{'foo': 1, 'bar': 2}], self.call_args) + self.flag_values.foo = 3 + self.assertEqual(3, self.flag_values.foo) + self.assertEqual([{ + 'foo': 1, + 'bar': 2 + }, { + 'foo': 3, + 'bar': 2 + }], self.call_args) + + def test_success_holder_infer_flagvalues(self): + def checker(flags_dict): + self.call_args.append(flags_dict) + return True + + _validators.register_multi_flags_validator( + [self.foo_holder, self.bar_holder], checker) + + argv = ('./program', '--bar=2') + self.flag_values(argv) + self.assertEqual(1, self.flag_values.foo) + self.assertEqual(2, self.flag_values.bar) + self.assertEqual([{'foo': 1, 'bar': 2}], self.call_args) + self.flag_values.foo = 3 + self.assertEqual(3, self.flag_values.foo) + self.assertEqual([{ + 'foo': 1, + 'bar': 2 + }, { + 'foo': 3, + 'bar': 2 + }], self.call_args) + def test_validator_not_called_when_other_flag_is_changed(self): def checker(flags_dict): self.call_args.append(flags_dict) @@ -322,6 +430,30 @@ def checker(flags_dict): # pylint: disable=unused-variable self.assertEqual([{'foo': 1, 'bar': 2}, {'foo': 1, 'bar': 1}], self.call_args) + def test_mismatching_flagvalues(self): + + def checker(flags_dict): + self.call_args.append(flags_dict) + values = flags_dict.values() + # Make sure all the flags have different values. + return len(set(values)) == len(values) + + other_holder = _defines.DEFINE_integer( + 'other_flag', + 3, + 'Other integer flag', + flag_values=_flagvalues.FlagValues()) + expected = ( + 'multiple FlagValues instances used in invocation. ' + 'FlagHolders must be registered to the same FlagValues instance as ' + 'do flag names, if provided.') + with self.assertRaisesWithLiteralMatch(ValueError, expected): + _validators.register_multi_flags_validator( + [self.foo_holder, self.bar_holder, other_holder], + checker, + message='Errors happen', + flag_values=self.flag_values) + class MarkFlagsAsMutualExclusiveTest(absltest.TestCase): @@ -329,9 +461,9 @@ def setUp(self): super(MarkFlagsAsMutualExclusiveTest, self).setUp() self.flag_values = _flagvalues.FlagValues() - _defines.DEFINE_string( + self.flag_one_holder = _defines.DEFINE_string( 'flag_one', None, 'flag one', flag_values=self.flag_values) - _defines.DEFINE_string( + self.flag_two_holder = _defines.DEFINE_string( 'flag_two', None, 'flag two', flag_values=self.flag_values) _defines.DEFINE_string( 'flag_three', None, 'flag three', flag_values=self.flag_values) @@ -358,6 +490,24 @@ def test_no_flags_present(self): self.assertIsNone(self.flag_values.flag_one) self.assertIsNone(self.flag_values.flag_two) + def test_no_flags_present_holder(self): + self._mark_flags_as_mutually_exclusive( + [self.flag_one_holder, self.flag_two_holder], False) + argv = ('./program',) + + self.flag_values(argv) + self.assertIsNone(self.flag_values.flag_one) + self.assertIsNone(self.flag_values.flag_two) + + def test_no_flags_present_mixed(self): + self._mark_flags_as_mutually_exclusive([self.flag_one_holder, 'flag_two'], + False) + argv = ('./program',) + + self.flag_values(argv) + self.assertIsNone(self.flag_values.flag_one) + self.assertIsNone(self.flag_values.flag_two) + def test_no_flags_present_required(self): self._mark_flags_as_mutually_exclusive(['flag_one', 'flag_two'], True) argv = ('./program',) @@ -494,6 +644,20 @@ def test_flag_default_not_none_warning(self): self.assertIn('--flag_not_none has a non-None default value', str(caught_warnings[0].message)) + def test_multiple_flagvalues(self): + other_holder = _defines.DEFINE_boolean( + 'other_flagvalues', + False, + 'other ', + flag_values=_flagvalues.FlagValues()) + expected = ( + 'multiple FlagValues instances used in invocation. ' + 'FlagHolders must be registered to the same FlagValues instance as ' + 'do flag names, if provided.') + with self.assertRaisesWithLiteralMatch(ValueError, expected): + self._mark_flags_as_mutually_exclusive( + [self.flag_one_holder, other_holder], False) + class MarkBoolFlagsAsMutualExclusiveTest(absltest.TestCase): @@ -501,13 +665,13 @@ def setUp(self): super(MarkBoolFlagsAsMutualExclusiveTest, self).setUp() self.flag_values = _flagvalues.FlagValues() - _defines.DEFINE_boolean( + self.false_1_holder = _defines.DEFINE_boolean( 'false_1', False, 'default false 1', flag_values=self.flag_values) - _defines.DEFINE_boolean( + self.false_2_holder = _defines.DEFINE_boolean( 'false_2', False, 'default false 2', flag_values=self.flag_values) - _defines.DEFINE_boolean( + self.true_1_holder = _defines.DEFINE_boolean( 'true_1', True, 'default true 1', flag_values=self.flag_values) - _defines.DEFINE_integer( + self.non_bool_holder = _defines.DEFINE_integer( 'non_bool', None, 'non bool', flag_values=self.flag_values) def _mark_bool_flags_as_mutually_exclusive(self, flag_names, required): @@ -520,6 +684,20 @@ def test_no_flags_present(self): self.assertEqual(False, self.flag_values.false_1) self.assertEqual(False, self.flag_values.false_2) + def test_no_flags_present_holder(self): + self._mark_bool_flags_as_mutually_exclusive( + [self.false_1_holder, self.false_2_holder], False) + self.flag_values(('./program',)) + self.assertEqual(False, self.flag_values.false_1) + self.assertEqual(False, self.flag_values.false_2) + + def test_no_flags_present_mixed(self): + self._mark_bool_flags_as_mutually_exclusive( + [self.false_1_holder, 'false_2'], False) + self.flag_values(('./program',)) + self.assertEqual(False, self.flag_values.false_1) + self.assertEqual(False, self.flag_values.false_2) + def test_no_flags_present_required(self): self._mark_bool_flags_as_mutually_exclusive(['false_1', 'false_2'], True) argv = ('./program',) @@ -554,6 +732,17 @@ def test_non_bool_flag(self): self._mark_bool_flags_as_mutually_exclusive(['false_1', 'non_bool'], False) + def test_multiple_flagvalues(self): + other_bool_holder = _defines.DEFINE_boolean( + 'other_bool', False, 'other bool', flag_values=_flagvalues.FlagValues()) + expected = ( + 'multiple FlagValues instances used in invocation. ' + 'FlagHolders must be registered to the same FlagValues instance as ' + 'do flag names, if provided.') + with self.assertRaisesWithLiteralMatch(ValueError, expected): + self._mark_bool_flags_as_mutually_exclusive( + [self.false_1_holder, other_bool_holder], False) + class MarkFlagAsRequiredTest(absltest.TestCase): @@ -570,6 +759,22 @@ def test_success(self): self.flag_values(argv) self.assertEqual('value', self.flag_values.string_flag) + def test_success_holder(self): + holder = _defines.DEFINE_string( + 'string_flag', None, 'string flag', flag_values=self.flag_values) + _validators.mark_flag_as_required(holder, flag_values=self.flag_values) + argv = ('./program', '--string_flag=value') + self.flag_values(argv) + self.assertEqual('value', self.flag_values.string_flag) + + def test_success_holder_infer_flagvalues(self): + holder = _defines.DEFINE_string( + 'string_flag', None, 'string flag', flag_values=self.flag_values) + _validators.mark_flag_as_required(holder) + argv = ('./program', '--string_flag=value') + self.flag_values(argv) + self.assertEqual('value', self.flag_values.string_flag) + def test_catch_none_as_default(self): _defines.DEFINE_string( 'string_flag', None, 'string flag', flag_values=self.flag_values) @@ -608,6 +813,18 @@ def test_flag_default_not_none_warning(self): self.assertIn('--flag_not_none has a non-None default value', str(caught_warnings[0].message)) + def test_mismatching_flagvalues(self): + flag_holder = _defines.DEFINE_string( + 'string_flag', + 'value', + 'string flag', + flag_values=_flagvalues.FlagValues()) + expected = ( + 'flag_values must not be customized when operating on a FlagHolder') + with self.assertRaisesWithLiteralMatch(ValueError, expected): + _validators.mark_flag_as_required( + flag_holder, flag_values=self.flag_values) + class MarkFlagsAsRequiredTest(absltest.TestCase): @@ -627,6 +844,18 @@ def test_success(self): self.assertEqual('value_1', self.flag_values.string_flag_1) self.assertEqual('value_2', self.flag_values.string_flag_2) + def test_success_holders(self): + flag_1_holder = _defines.DEFINE_string( + 'string_flag_1', None, 'string flag 1', flag_values=self.flag_values) + flag_2_holder = _defines.DEFINE_string( + 'string_flag_2', None, 'string flag 2', flag_values=self.flag_values) + _validators.mark_flags_as_required([flag_1_holder, flag_2_holder], + flag_values=self.flag_values) + argv = ('./program', '--string_flag_1=value_1', '--string_flag_2=value_2') + self.flag_values(argv) + self.assertEqual('value_1', self.flag_values.string_flag_1) + self.assertEqual('value_2', self.flag_values.string_flag_2) + def test_catch_none_as_default(self): _defines.DEFINE_string( 'string_flag_1', None, 'string flag 1', flag_values=self.flag_values) diff --git a/absl/flags/tests/flags_test.py b/absl/flags/tests/flags_test.py index 8a42bc96..77ed307e 100644 --- a/absl/flags/tests/flags_test.py +++ b/absl/flags/tests/flags_test.py @@ -2483,6 +2483,71 @@ def test_flag_definition_via_setitem(self): flag_values['flag_name'] = 'flag_value' +class SetDefaultTest(absltest.TestCase): + + def setUp(self): + super().setUp() + self.flag_values = flags.FlagValues() + + def test_success(self): + int_holder = flags.DEFINE_integer( + 'an_int', 1, 'an int', flag_values=self.flag_values) + + flags.set_default(int_holder, 2) + self.flag_values.mark_as_parsed() + + self.assertEqual(int_holder.value, 2) + + def test_update_after_parse(self): + int_holder = flags.DEFINE_integer( + 'an_int', 1, 'an int', flag_values=self.flag_values) + + self.flag_values.mark_as_parsed() + flags.set_default(int_holder, 2) + + self.assertEqual(int_holder.value, 2) + + def test_overridden_by_explicit_assignment(self): + int_holder = flags.DEFINE_integer( + 'an_int', 1, 'an int', flag_values=self.flag_values) + + self.flag_values.mark_as_parsed() + self.flag_values.an_int = 3 + flags.set_default(int_holder, 2) + + self.assertEqual(int_holder.value, 3) + + def test_restores_back_to_none(self): + int_holder = flags.DEFINE_integer( + 'an_int', None, 'an int', flag_values=self.flag_values) + + self.flag_values.mark_as_parsed() + flags.set_default(int_holder, 3) + flags.set_default(int_holder, None) + + self.assertIsNone(int_holder.value) + + def test_failure_on_invalid_type(self): + int_holder = flags.DEFINE_integer( + 'an_int', 1, 'an int', flag_values=self.flag_values) + + self.flag_values.mark_as_parsed() + + with self.assertRaises(flags.IllegalFlagValueError): + flags.set_default(int_holder, 'a') + + def test_failure_on_type_protected_none_default(self): + int_holder = flags.DEFINE_integer( + 'an_int', 1, 'an int', flag_values=self.flag_values) + + self.flag_values.mark_as_parsed() + + flags.set_default(int_holder, None) # NOTE: should be a type failure + + with self.assertRaises(flags.IllegalFlagValueError): + _ = int_holder.value # Will also fail on later access. + + class KeyFlagsTest(absltest.TestCase): def setUp(self): @@ -2646,6 +2711,40 @@ def test_key_flags_with_non_default_flag_values_object(self): self._get_names_of_key_flags(main_module, fv), names_of_flags_defined_by_bar + ['flagfile', 'undefok']) + def test_key_flags_with_flagholders(self): + main_module = sys.argv[0] + + self.assertListEqual( + self._get_names_of_key_flags(main_module, self.flag_values), []) + self.assertListEqual( + self._get_names_of_defined_flags(main_module, self.flag_values), []) + + int_holder = flags.DEFINE_integer( + 'main_module_int_fg', + 1, + 'Integer flag in the main module.', + flag_values=self.flag_values) + + flags.declare_key_flag(int_holder, self.flag_values) + + self.assertCountEqual( + self.flag_values.get_flags_for_module(main_module), + self.flag_values.get_key_flags_for_module(main_module)) + + bool_holder = flags.DEFINE_boolean( + 'main_module_bool_fg', + False, + 'Boolean flag in the main module.', + flag_values=self.flag_values) + + flags.declare_key_flag(bool_holder) # omitted flag_values + + self.assertCountEqual( + self.flag_values.get_flags_for_module(main_module), + self.flag_values.get_key_flags_for_module(main_module)) + + self.assertLen(self.flag_values.get_flags_for_module(main_module), 2) + def test_main_module_help_with_key_flags(self): # Similar to test_main_module_help, but this time we make sure to # declare some key flags. diff --git a/absl/testing/absltest.py b/absl/testing/absltest.py index d1b0e590..9071f8f6 100644 --- a/absl/testing/absltest.py +++ b/absl/testing/absltest.py @@ -735,7 +735,7 @@ def enter_context(self, manager): (e.g. `TestCase.enter_context`), the context is exited after the test class's tearDownClass call. - Contexts are are exited in the reverse order of entering. They will always + Contexts are exited in the reverse order of entering. They will always be exited, regardless of test failure/success. This is useful to eliminate per-test boilerplate when context managers diff --git a/absl/testing/tests/xml_reporter_test.py b/absl/testing/tests/xml_reporter_test.py index 0261f641..c0d43a60 100644 --- a/absl/testing/tests/xml_reporter_test.py +++ b/absl/testing/tests/xml_reporter_test.py @@ -64,12 +64,12 @@ def xml_escaped_exception_type(exception_type): OUTPUT_STRING = '\n'.join([ r'<\?xml version="1.0"\?>', (''), + ' errors="%(errors)d" time="%(run_time).3f" timestamp="%(start_time)s">'), (''), (' %(message)s'), ' ', '', '', @@ -696,8 +696,8 @@ def test_suite_time(self): run_time = max(end_time1, end_time2) - min(start_time1, start_time2) timestamp = self._iso_timestamp(start_time1) expected_prefix = """ - - + + """ % (run_time, timestamp, run_time, timestamp) xml_output = self.xml_stream.getvalue() self.assertTrue( diff --git a/absl/testing/xml_reporter.py b/absl/testing/xml_reporter.py index 5996ce2a..591eb7ef 100644 --- a/absl/testing/xml_reporter.py +++ b/absl/testing/xml_reporter.py @@ -202,7 +202,7 @@ def print_xml_summary(self, stream): ('name', '%s' % self.name), ('status', '%s' % status), ('result', '%s' % result), - ('time', '%.1f' % self.run_time), + ('time', '%.3f' % self.run_time), ('classname', self.full_class_name), ('timestamp', _iso8601_timestamp(self.start_time)), ] @@ -263,7 +263,7 @@ def print_xml_summary(self, stream): ('tests', '%d' % overall_test_count), ('failures', '%d' % overall_failures), ('errors', '%d' % overall_errors), - ('time', '%.1f' % (self.overall_end_time - self.overall_start_time)), + ('time', '%.3f' % (self.overall_end_time - self.overall_start_time)), ('timestamp', _iso8601_timestamp(self.overall_start_time)), ] _print_xml_element_header('testsuites', overall_attributes, stream) @@ -285,7 +285,7 @@ def print_xml_summary(self, stream): ('tests', '%d' % len(suite)), ('failures', '%d' % failures), ('errors', '%d' % errors), - ('time', '%.1f' % (suite_end_time - suite_start_time)), + ('time', '%.3f' % (suite_end_time - suite_start_time)), ('timestamp', _iso8601_timestamp(suite_start_time)), ] _print_xml_element_header('testsuite', suite_attributes, stream) diff --git a/setup.py b/setup.py index 23fcac2c..f947fd7b 100644 --- a/setup.py +++ b/setup.py @@ -43,7 +43,7 @@ setuptools.setup( name='absl-py', - version='1.2.0', + version='1.3.0', description=( 'Abseil Python Common Libraries, ' 'see https://github.com/abseil/abseil-py.'),