From bc556086e282aad34e2161bbffdae8009fd559fb Mon Sep 17 00:00:00 2001 From: isaacyang Date: Wed, 22 Jan 2025 18:16:53 +0800 Subject: [PATCH] Fix: Eth hotplug use the right way to check cable before pinging --- providers/base/bin/eth_hotplugging.py | 167 +++++++- providers/base/tests/test_eth_hotplugging.py | 401 +++++++++++++++++++ 2 files changed, 546 insertions(+), 22 deletions(-) create mode 100644 providers/base/tests/test_eth_hotplugging.py diff --git a/providers/base/bin/eth_hotplugging.py b/providers/base/bin/eth_hotplugging.py index 7d57e84090..0dff69a8ec 100755 --- a/providers/base/bin/eth_hotplugging.py +++ b/providers/base/bin/eth_hotplugging.py @@ -7,7 +7,103 @@ """Check if hotplugging works on an ethernet port.""" import sys +import os import time +import glob +import yaml +import subprocess as sp + +from gateway_ping_test import perform_ping_test + +NETPLAN_CFG_PATHS = ("/etc/netplan", "/lib/netplan", "/run/netplan") + + +def netplan_renderer(): + """ + Check the renderer used by netplan on the system if it is networkd or + NetworkManager. + This function looks for the renderer used in the yaml files located in the + NETPLAN_CFG_PATHS directories, and returns the first renderer found. + If the renderer is not found, it defaults to "networkd". + If the netplan file is not found, it defaults to "NetworkManager". + """ + netplan_file_exist = False + for basedir in NETPLAN_CFG_PATHS: + if os.path.exists(basedir): + files = glob.glob(os.path.join(basedir, "*.yaml")) + for f in files: + netplan_file_exist = True + with open(f, "r") as file: + data = yaml.safe_load(file) + if "renderer" in data["network"]: + return data["network"]["renderer"] + if netplan_file_exist: + return "networkd" + return "NetworkManager" + + +def get_interface_info(interface, renderer): + """ + Get the interface information (state and gateway) from the renderer. + """ + if renderer == "networkd": + cmd = "networkctl status --no-pager --no-legend {}".format(interface) + key_map = {"State": "state", "Gateway": "gateway"} + elif renderer == "NetworkManager": + cmd = "nmcli device show {}".format(interface) + key_map = {"GENERAL.STATE": "state", "IP4.GATEWAY": "gateway"} + else: + raise ValueError("Unknown renderer: {}".format(renderer)) + + return _get_cmd_info(cmd, key_map, renderer) + + +def _get_cmd_info(cmd, key_map, renderer): + info = {} + try: + output = sp.check_output(cmd, shell=True) + for line in output.decode(sys.stdout.encoding).splitlines(): + # Skip lines that don't have a "key: value" format + if ":" not in line: + continue + key, val = line.strip().split(":", maxsplit=1) + key = key.strip() + val = val.strip() + if key in key_map: + info[key_map[key]] = val + except sp.CalledProcessError as e: + print("Error running {} command: {}".format(renderer, e)) + return info + + +def _check_routable_state(interface, renderer): + """ + Check if the interface is in a routable state depending on the renderer + """ + routable = False + state = "" + info = get_interface_info(interface, renderer) + state = info.get("state", "") + if renderer == "networkd": + routable = "routable" in state + elif renderer == "NetworkManager": + routable = "connected" in state and "disconnected" not in state + return (routable, state) + + +def wait_for_routable_state( + interface, renderer, do_routable=True, max_wait=30 +): + attempts = 0 + routable_msg = "routable" if do_routable else "NOT routable" + while attempts <= max_wait: + attempts += 1 + (routable, _) = _check_routable_state(interface, renderer) + if routable == do_routable: + print("Reached {} state".format(routable_msg)) + return + time.sleep(1) + raise SystemExit("Failed to reach {} state!".format(routable_msg)) def has_cable(iface): @@ -17,6 +113,46 @@ def has_cable(iface): return carrier.read()[0] == "1" +def wait_for_cable_state(iface, do_cable=True, max_wait=30): + """Wait for the cable state to be True or False.""" + attempts = 0 + cable_msg = "plugged" if do_cable else "unplugged" + while attempts <= max_wait: + attempts += 1 + if has_cable(iface) == do_cable: + print("Detected cable state: {}".format(cable_msg)) + return + time.sleep(1) + raise SystemExit("Failed to detect {}!".format(cable_msg)) + + +def help_wait_cable_and_routable_state(iface, do_check=True): + if do_check: + do_cable = True + do_routable = True + else: + do_cable = False + do_routable = False + + renderer = netplan_renderer() + print( + "Waiting for cable to get {}.".format( + "connected" if do_cable else "disconnected" + ) + ) + wait_for_cable_state(iface, do_cable, 60) + + print( + "Waiting for networkd/NetworkManager {}.".format( + "routable" if do_routable else "NOT routable" + ) + ) + wait_for_routable_state(iface, renderer, do_routable, 60) + + print("Cable {}!".format("connected" if do_cable else "disconnected")) + print("Network {}!".format("routable" if do_routable else "NOT routable")) + + def main(): """Entry point to the program.""" if len(sys.argv) != 2: @@ -37,28 +173,15 @@ def main(): print("After 15 seconds plug it back in.") print("Checkbox session may be interrupted but it should come back up.") input() - print("Waiting for cable to get disconnected.") - elapsed = 0 - while elapsed < 60: - if not has_cable(sys.argv[1]): - break - time.sleep(1) - print(".", flush=True, end="") - elapsed += 1 - else: - raise SystemExit("Failed to detect unplugging!") - print("Cable unplugged!") - print("Waiting for the cable to get connected.") - elapsed = 0 - while elapsed < 60: - if has_cable(sys.argv[1]): - break - time.sleep(1) - print(".", flush=True, end="") - elapsed += 1 - else: - raise SystemExit("Failed to detect plugging it back!") - print("Cable detected!") + + help_wait_cable_and_routable_state(iface, False) + + print("Please plug the cable back in.") + + help_wait_cable_and_routable_state(iface, True) + + print("Pinging gateway...") + perform_ping_test(iface) if __name__ == "__main__": diff --git a/providers/base/tests/test_eth_hotplugging.py b/providers/base/tests/test_eth_hotplugging.py new file mode 100644 index 0000000000..98d9488264 --- /dev/null +++ b/providers/base/tests/test_eth_hotplugging.py @@ -0,0 +1,401 @@ +#!/usr/bin/env python3 +# encoding: utf-8 +# Copyright 2025 Canonical Ltd. +# Written by: +# Isaac Yang +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License version 3, +# as published by the Free Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + + +import textwrap +import unittest +from unittest import TestCase +from unittest.mock import patch, MagicMock, mock_open, ANY +import datetime +import subprocess as sp +import io +import sys +from unittest.mock import call +from eth_hotplugging import ( + netplan_renderer, + get_interface_info, + _check_routable_state, + wait_for_routable_state, + has_cable, + wait_for_cable_state, + help_wait_cable_and_routable_state, + main, +) + + +class EthHotpluggingTests(TestCase): + @patch( + "builtins.open", + new_callable=mock_open, + read_data="network:\n renderer: networkd", + ) + @patch("os.path.exists", return_value=True) + @patch("glob.glob", return_value=["/etc/netplan/01-netcfg.yaml"]) + def test_renderer_networkd(self, mock_exists, mock_glob, mock_open): + renderer = netplan_renderer() + self.assertEqual(renderer, "networkd") + + @patch( + "builtins.open", + new_callable=mock_open, + read_data="network:\n abc: def", + ) + @patch("os.path.exists", return_value=True) + @patch("glob.glob", return_value=["/etc/netplan/01-netcfg.yaml"]) + def test_renderer_networkd_no_renderer( + self, mock_exists, mock_glob, mock_open + ): + renderer = netplan_renderer() + self.assertEqual(renderer, "networkd") + + @patch( + "builtins.open", + new_callable=mock_open, + read_data="network:\n renderer: NetworkManager", + ) + @patch("glob.glob", return_value=["/etc/netplan/01-netcfg.yaml"]) + @patch("os.path.exists", return_value=True) + def test_renderer_networkmanager(self, mock_exists, mock_glob, mock_open): + renderer = netplan_renderer() + self.assertEqual(renderer, "NetworkManager") + + @patch("glob.glob", return_value=[]) + @patch("os.path.exists", return_value=True) + def test_no_yaml_files(self, mock_exists, mock_glob): + renderer = netplan_renderer() + self.assertEqual(renderer, "NetworkManager") + + @patch("subprocess.check_output") + def test_get_interface_info_networkd(self, mock_check_output): + mock_check_output.return_value = ( + b"State: routable\nGateway: 192.168.1.1\nPath: pci-0000:02:00.0" + ) + interface = "eth0" + renderer = "networkd" + info = get_interface_info(interface, renderer) + self.assertEqual(info["state"], "routable") + self.assertEqual(info["gateway"], "192.168.1.1") + + @patch("subprocess.check_output") + def test_get_interface_info_networkd_any_name(self, mock_check_output): + mock_check_output.return_value = ( + b"State: routable\nGateway: 192.168.1.1 (TP-Link 123)\n" + b"Path: pci-0000:02:00.0" + ) + interface = "eth0" + renderer = "networkd" + info = get_interface_info(interface, renderer) + self.assertEqual(info["state"], "routable") + self.assertEqual(info["gateway"], "192.168.1.1 (TP-Link 123)") + + @patch("subprocess.check_output") + def test_get_interface_info_networkd_no_state(self, mock_check_output): + mock_check_output.return_value = ( + b"Some other info: value\nsome more info" + ) + interface = "eth0" + renderer = "networkd" + info = get_interface_info(interface, renderer) + self.assertNotIn("state", info) + self.assertNotIn("gateway", info) + + @patch("subprocess.check_output") + def test_get_interface_info_networkd_empty_output(self, mock_check_output): + mock_check_output.return_value = b"" + interface = "eth0" + renderer = "networkd" + info = get_interface_info(interface, renderer) + self.assertEqual(info, {}) + + @patch( + "subprocess.check_output", + side_effect=sp.CalledProcessError(1, "Command failed"), + ) + def test_get_interface_info_networkd_command_fails( + self, mock_check_output + ): + captured_output = io.StringIO() + sys.stdout = captured_output + interface = "eth0" + renderer = "networkd" + info = get_interface_info(interface, renderer) + sys.stdout = sys.__stdout__ + self.assertEqual(info, {}) + self.assertIn( + "Error running networkd command", captured_output.getvalue() + ) + + @patch("subprocess.check_output") + def test_get_interface_info_networkmanager(self, mock_check_output): + mock_check_output.return_value = ( + b"GENERAL.MTU: 1500\n" + b"GENERAL.STATE: 100 (connected)\n" + b"IP4.GATEWAY: 192.168.1.1" + ) + interface = "eth0" + renderer = "NetworkManager" + info = get_interface_info(interface, renderer) + self.assertEqual(info["state"], "100 (connected)") + self.assertEqual(info["gateway"], "192.168.1.1") + + @patch("subprocess.check_output") + def test_get_interface_info_networkmanager_unexpected_output( + self, mock_check_output + ): + mock_check_output.return_value = b"some unexpected output" + interface = "eth0" + renderer = "NetworkManager" + info = get_interface_info(interface, renderer) + self.assertEqual(info, {}) + + @patch( + "subprocess.check_output", + side_effect=sp.CalledProcessError(1, "Command failed"), + ) + def test_get_interface_info_networkmanager_command_fails( + self, mock_check_output + ): + captured_output = io.StringIO() + sys.stdout = captured_output + interface = "eth0" + renderer = "NetworkManager" + info = get_interface_info(interface, renderer) + sys.stdout = sys.__stdout__ + self.assertEqual(info, {}) + self.assertIn( + "Error running NetworkManager command", + captured_output.getvalue(), + ) + + def test_get_interface_info_unknown_renderer(self): + interface = "eth0" + renderer = "unknown" + with self.assertRaises(ValueError): + get_interface_info(interface, renderer) + + @patch( + "eth_hotplugging.get_interface_info", + return_value={"state": "routable"}, + ) + def test_check_routable_state_networkd(self, mock_get_interface_info): + renderer = "networkd" + self.assertTrue(_check_routable_state("eth0", renderer)) + + @patch( + "eth_hotplugging.get_interface_info", + return_value={"state": "connected"}, + ) + def test_check_routable_state_networkmanager(self, mock_get_interface_info): + renderer = "NetworkManager" + self.assertTrue(_check_routable_state("eth0", renderer)) + + @patch( + "eth_hotplugging._check_routable_state", + return_value=(True, "routable"), + ) + @patch("eth_hotplugging.time.sleep", new=MagicMock()) + def test_immediate_routable(self, mock_check_state): + captured_output = io.StringIO() + sys.stdout = captured_output + wait_for_routable_state("eth0", "networkd") + sys.stdout = sys.__stdout__ + mock_check_state.assert_called_once_with("eth0", "networkd") + self.assertIn("Reached routable state", captured_output.getvalue()) + + @patch( + "eth_hotplugging._check_routable_state", + side_effect=[ + (False, "configuring"), + (False, "configuring"), + (True, "routable"), + ], + ) + @patch("eth_hotplugging.time.sleep", new=MagicMock()) + def test_eventually_routable(self, mock_check_state): + captured_output = io.StringIO() + sys.stdout = captured_output + wait_for_routable_state("eth0", "networkd") + sys.stdout = sys.__stdout__ + self.assertIn("Reached routable state", captured_output.getvalue()) + + @patch( + "eth_hotplugging._check_routable_state", + return_value=(False, "configuring"), + ) + @patch("eth_hotplugging.time.sleep", new=MagicMock()) + def test_never_routable(self, mock_check_state): + with self.assertRaises(SystemExit) as cm: + wait_for_routable_state("eth0", "networkd", max_wait=3) + self.assertEqual( + str(cm.exception), "Failed to reach routable state!" + ) + + @patch( + "builtins.open", + new_callable=mock_open, + read_data="1", + ) + def test_has_cable_true(self, mock_open): + result = has_cable("eth0") + self.assertTrue(result) + mock_open.assert_called_once_with("/sys/class/net/eth0/carrier") + + @patch( + "builtins.open", + new_callable=mock_open, + read_data="0", + ) + def test_has_cable_false(self, mock_open): + result = has_cable("eth0") + self.assertFalse(result) + mock_open.assert_called_once_with("/sys/class/net/eth0/carrier") + + @patch( + "eth_hotplugging.has_cable", + return_value=True, + ) + @patch("eth_hotplugging.time.sleep", new=MagicMock()) + def test_immediate_cable(self, mock_has_cable): + captured_output = io.StringIO() + sys.stdout = captured_output + wait_for_cable_state("eth0", do_cable=True) + sys.stdout = sys.__stdout__ + mock_has_cable.assert_called_once_with("eth0") + self.assertIn( + "Detected cable state: plugged", captured_output.getvalue() + ) + + @patch( + "eth_hotplugging.has_cable", + side_effect=[ + False, + False, + True, + ], + ) + @patch("eth_hotplugging.time.sleep", new=MagicMock()) + def test_eventually_cable(self, mock_has_cable): + captured_output = io.StringIO() + sys.stdout = captured_output + wait_for_cable_state("eth0", do_cable=True, max_wait=3) + sys.stdout = sys.__stdout__ + self.assertIn( + "Detected cable state: plugged", captured_output.getvalue() + ) + + @patch( + "eth_hotplugging.has_cable", + return_value=False, + ) + @patch("eth_hotplugging.time.sleep", new=MagicMock()) + def test_never_cable(self, mock_has_cable): + with self.assertRaises(SystemExit) as cm: + wait_for_cable_state("eth0", do_cable=True, max_wait=3) + self.assertEqual(str(cm.exception), "Failed to detect plugged!") + + @patch("eth_hotplugging.netplan_renderer", return_value="networkd") + @patch("eth_hotplugging.has_cable", return_value=True) + @patch("eth_hotplugging.time.sleep", new=MagicMock()) + @patch( + "eth_hotplugging._check_routable_state", + return_value=(True, "routable"), + ) + def test_help_wait_cable_and_routable_state_true( + self, + mock_check_routable_state, + mock_has_cable, + mock_netplan_renderer, + ): + captured_output = io.StringIO() + sys.stdout = captured_output + help_wait_cable_and_routable_state("eth0", do_check=True) + sys.stdout = sys.__stdout__ + self.assertIn("Cable connected!", captured_output.getvalue()) + self.assertIn("Network routable!", captured_output.getvalue()) + + @patch("eth_hotplugging.netplan_renderer", return_value="NetworkManager") + @patch("eth_hotplugging.has_cable", return_value=False) + @patch("eth_hotplugging.time.sleep", new=MagicMock()) + @patch( + "eth_hotplugging._check_routable_state", + return_value=(False, "routable"), + ) + def test_help_wait_cable_and_routable_state_false( + self, + mock_check_routable_state, + mock_has_cable, + mock_netplan_renderer, + ): + captured_output = io.StringIO() + sys.stdout = captured_output + help_wait_cable_and_routable_state("eth0", do_check=False) + sys.stdout = sys.__stdout__ + self.assertIn("Cable disconnected!", captured_output.getvalue()) + self.assertIn("Network NOT routable!", captured_output.getvalue()) + + +class TestMain(TestCase): + @patch("eth_hotplugging.perform_ping_test") + @patch("eth_hotplugging.help_wait_cable_and_routable_state") + @patch("eth_hotplugging._check_routable_state") + @patch("eth_hotplugging.has_cable") + @patch("builtins.input", return_value="") + @patch("sys.argv", ["eth_hotplugging.py", "eth0"]) + @patch("builtins.print") + @patch("eth_hotplugging.time.sleep", new=MagicMock()) + def test_main_successful_execution( + self, + mock_print, + mock_input, + mock_has_cable, + mock_check_routable_state, + mock_help_wait, + mock_ping_test, + ): + mock_has_cable.return_value = True + mock_check_routable_state.return_value = (True, "routable") + + main() + + mock_ping_test.assert_called_once_with("eth0") + + @patch("sys.argv", ["eth_hotplugging.py"]) + def test_main_no_interface_argument(self): + with self.assertRaises(SystemExit) as cm: + main() + self.assertEqual( + str(cm.exception), + "Usage eth_hotplugging.py INTERFACE_NAME", + ) + + @patch("eth_hotplugging.has_cable", side_effect=FileNotFoundError) + @patch("builtins.input", return_value="") + @patch("sys.argv", ["eth_hotplugging.py", "eth0"]) + def test_main_raises_error_when_interface_not_found( + self, mock_input, mock_has_cable + ): + with self.assertRaises(SystemExit) as cm: + main() + self.assertIn( + "Could not check the cable for 'eth0'", + str(cm.exception), + ) + + +if __name__ == "__main__": + unittest.main()