diff --git a/tests/test_config_overrides.py b/tests/test_config_overrides.py index bc3f606..08366ce 100755 --- a/tests/test_config_overrides.py +++ b/tests/test_config_overrides.py @@ -63,6 +63,30 @@ def test_update_from_args_controller_host(self): self.assertEqual(self.config.get("controller", "host"), "foo") self.assertEqual(self.config.getint("controller", "port"), 2149) + parser = zeekclient.cli.create_parser() + args = parser.parse_args(["--controller", "127.0.0.1"]) + self.config.update_from_args(args) + self.assertEqual(self.config.get("controller", "host"), "127.0.0.1") + self.assertEqual(self.config.getint("controller", "port"), 2149) + + parser = zeekclient.cli.create_parser() + args = parser.parse_args(["--controller", "127.0.0.1:"]) + self.config.update_from_args(args) + self.assertEqual(self.config.get("controller", "host"), "127.0.0.1") + self.assertEqual(self.config.getint("controller", "port"), 2149) + + parser = zeekclient.cli.create_parser() + args = parser.parse_args(["--controller", "[fe80::1]"]) + self.config.update_from_args(args) + self.assertEqual(self.config.get("controller", "host"), "[fe80::1]") + self.assertEqual(self.config.getint("controller", "port"), 2149) + + parser = zeekclient.cli.create_parser() + args = parser.parse_args(["--controller", "[fe80::1]:"]) + self.config.update_from_args(args) + self.assertEqual(self.config.get("controller", "host"), "[fe80::1]") + self.assertEqual(self.config.getint("controller", "port"), 2149) + def test_update_from_args_controller_port(self): parser = zeekclient.cli.create_parser() args = parser.parse_args(["--controller", ":2222"]) @@ -76,3 +100,15 @@ def test_update_from_args_controller_hostport(self): self.config.update_from_args(args) self.assertEqual(self.config.get("controller", "host"), "foo") self.assertEqual(self.config.getint("controller", "port"), 2222) + + parser = zeekclient.cli.create_parser() + args = parser.parse_args(["--controller", "127.0.0.1:2222"]) + self.config.update_from_args(args) + self.assertEqual(self.config.get("controller", "host"), "127.0.0.1") + self.assertEqual(self.config.getint("controller", "port"), 2222) + + parser = zeekclient.cli.create_parser() + args = parser.parse_args(["--controller", "[fe80::1]:2222"]) + self.config.update_from_args(args) + self.assertEqual(self.config.get("controller", "host"), "[fe80::1]") + self.assertEqual(self.config.getint("controller", "port"), 2222) diff --git a/zeekclient/config.py b/zeekclient/config.py index 9200250..461f699 100644 --- a/zeekclient/config.py +++ b/zeekclient/config.py @@ -118,16 +118,30 @@ def update_from_args(self, args): # The `--controller` argument is a shortcut for two `--set` arguments that # set controller host and port, so update these manually: if args.controller: - host_port = args.controller.split(":", 1) - if len(host_port) != 2 or not host_port[1]: + if ":" not in args.controller: + host = args.controller + port = "" + else: + (host, _, port) = args.controller.rpartition(":") + if host.count(":") >= 1: + # We likely have an IPv6 address + if not host.startswith("["): + raise ValueError( + "IPv6 addresses must be surrounded by brackets: []:" + ) + if port.endswith("]"): + host = host + ":" + port + port = "" + + if port == "": # It's just a hostname - self.set("controller", "host", host_port[0]) - elif not host_port[0]: + self.set("controller", "host", host) + elif host == "": # It's just a port (as ":") - self.set("controller", "port", host_port[1]) + self.set("controller", "port", port) else: - self.set("controller", "host", host_port[0]) - self.set("controller", "port", host_port[1]) + self.set("controller", "host", host) + self.set("controller", "port", port) # --verbose/-v/-vvv etc set a numeric verbosity level: if args.verbose: