diff --git a/recipes-app/iot2050-firmware-update/files/iot2050-firmware-update.tmpl b/recipes-app/iot2050-firmware-update/files/iot2050-firmware-update.tmpl index 2124ef022..4714a1f5f 100755 --- a/recipes-app/iot2050-firmware-update/files/iot2050-firmware-update.tmpl +++ b/recipes-app/iot2050-firmware-update/files/iot2050-firmware-update.tmpl @@ -1,6 +1,6 @@ #!/usr/bin/env python3 # -# Copyright (c) Siemens AG, 2020-2023 +# Copyright (c) Siemens AG, 2020-2024 # # Authors: # Chao Zeng @@ -142,22 +142,20 @@ class UpgradeError(Exception): self.code = code def __str__(self): - return repr(self) + return self.err class Firmware(): """The Firmware class represents flash base operations for all flashes""" def __init__(self, firmware): - try: - self.firmware = open(firmware, "rb") - except (IOError, TypeError): - self.firmware = firmware + if not isinstance(firmware, io.IOBase): + raise UpgradeError("TypeError: firmware must be a file-like object!") + self.firmware = firmware def __del__(self): - try: + if hasattr(self, 'firmware'): self.firmware.close() - except Exception: - pass + del self.firmware @abstractmethod def write(self): @@ -233,13 +231,10 @@ class MtdDevice(): firmware_content += bytearray([0xff] * padsize) if not mtd_content == firmware_content: - #sys.stdout.flush() self.__erase(mtd_dev, mtd_pos, mtd_erasesize) os.lseek(mtd_dev, mtd_pos, os.SEEK_SET) os.write(mtd_dev, firmware_content) - #else: - # print(".", end="") - # sys.stdout.flush() + mtd_pos += mtd_erasesize file_size -= mtd_erasesize os.close(mtd_dev) @@ -332,15 +327,17 @@ class BootloaderFirmware(Firmware): class EnvFirmware(Firmware): """The EnvFirmware class represents env partition operations""" - def __init__(self, firmware): + def __init__(self, firmware_path, firmware): super().__init__(firmware) - self.firmware_path = firmware + self.firmware_path = firmware_path self.mtd_device = MtdDevice() mtd_num = 0 self.env_mtd_num = 0 self.env_bk_mtd_num = 0 - while True: + # mtd_device is typically less than 20, if one mtd device can't be + # located in 20 rounds, jump out the loop. + for mtd_num in range(20): try: mtd_dev_path, mtd_size, mtd_erasesize, mtd_name = \ self.mtd_device.get_mtd_info(mtd_num) @@ -351,10 +348,11 @@ class EnvFirmware(Firmware): self.env_mtd_num = mtd_num if "env.backup" == mtd_name: self.env_bk_mtd_num = mtd_num - mtd_num +=1 if self.env_mtd_num and self.env_bk_mtd_num: break + else: + raise UpgradeError("EnvFirmware: No env partition found") def write(self): """A env firmware can write contents to the env partition""" @@ -371,13 +369,13 @@ class EnvFirmware(Firmware): firmware_size = os.path.getsize(env_default_binary.name) self.firmware = open(env_default_binary.name, "rb") - while True: - if firmware_size <= 0: - break - firmware_size = self.mtd_device.write( - mtd_dev_path, mtd_size, mtd_erasesize, - env_default_binary, firmware_size - ) + firmware_size = self.mtd_device.write( + mtd_dev_path, mtd_size, mtd_erasesize, + env_default_binary, firmware_size + ) + + if firmware_size > 0: + raise UpgradeError("Write env failed") except subprocess.CalledProcessError as error: print(error.stdout) raise UpgradeError("EnvFirmware: Run mkenvimage failed") @@ -408,25 +406,28 @@ class ForceUpdate(): self.firmware_type = firmware_type self.interactor = interactor - def update(self): if self.firmware_type == "uboot": - firmware_obj = BootloaderFirmware(self.firmware) + try: + self.firmware_obj = BootloaderFirmware(self.firmware) + except UpgradeError as e: + raise UpgradeError(e.err, ErrorCode.INVALID_FIRMWARE.value) else: - raise UpgradeError("Unsupported firmware type!") + raise UpgradeError("Unsupported firmware type!", ErrorCode.INVALID_FIRMWARE.value) + def update(self): print("===================================================") print("IOT2050 firmware update started - DO NOT INTERRUPT!") print("===================================================") self.interactor.progress_bar(info="Updating {}".format(self.firmware_type)) - firmware_obj.write() + self.firmware_obj.write() firmware_md5 = hashlib.md5() self.firmware.seek(0) firmware_md5.update(self.firmware.read()) read_out_md5 = hashlib.md5() - read_out_md5.update(firmware_obj.read()) + read_out_md5.update(self.firmware_obj.read()) self.interactor.progress_bar(start=False) if firmware_md5.hexdigest() != read_out_md5.hexdigest(): @@ -472,11 +473,13 @@ class FirmwareUpdate(): print("\nPreserved env list: ") for env in env_list: print(env) - self.firmwares[firmware_type] = EnvFirmware( + env_path, env_binary = \ self.tarball.generate_env_firmware(env_list) - ) + self.firmwares[firmware_type] = \ + EnvFirmware(env_path, env_binary) else: self.firmwares[firmware_type] = EnvFirmware( + self.tarball.get_file_path(self.tarball.UBOOT_ENV_FILE), self.tarball.get_file(self.tarball.UBOOT_ENV_FILE) ) elif firmware_type == self.tarball.FIRMWARE_TYPES[2]: @@ -557,7 +560,6 @@ class FirmwareUpdate(): self.firmwares[firmware_type].write() - self.firmwares[firmware_type].firmware.seek(0) content = self.firmwares[firmware_type].firmware.read() firmware_md5 = self.__get_md5_digest(content) @@ -602,12 +604,14 @@ class FirmwareTarball(object): # extract file path self.extract_path = "/tmp" self.firmware_tarball.seek(0) + self.extracted_files = [] with tarfile.open(fileobj=self.firmware_tarball) as f: for member in f: file_tarfileinfo = f.getmember(name=member.name) file_tarfileinfo.uid = os.getuid() file_tarfileinfo.gid = os.getgid() f.extract(file_tarfileinfo, path=self.extract_path) + self.extracted_files.append(self.extract_path + "/" + member.name) self._board_info = BoardInfo() print("Current board: {}".format(self._board_info.board_name)) @@ -617,7 +621,7 @@ class FirmwareTarball(object): # to access the fields. try: self._jsonobj = json.load( - open(self.get_file(self.CONF_JSON), "rb"), + self.get_file(self.CONF_JSON), object_hook=lambda d: Namespace(**d) ) except ValueError: @@ -626,9 +630,8 @@ class FirmwareTarball(object): self.firmware_names = dict.fromkeys(self.FIRMWARE_TYPES) def __del__(self): - with tarfile.open(fileobj=self.firmware_tarball) as f: - for member in f: - os.remove(self.extract_path + "/" + member.name) + for file in self.extracted_files: + os.remove(file) def __check_os(self, target_os, os_info) -> bool: for tos in target_os: @@ -704,6 +707,12 @@ class FirmwareTarball(object): """Get the file object of specified name""" file = os.path.join(self.extract_path, name) + return open(file, 'rb') + + def get_file_path(self, name): + """Get the file object of specified name""" + file = os.path.join(self.extract_path, name) + return file def __get_suggest_preserved_uboot_env(self): @@ -773,7 +782,7 @@ class FirmwareTarball(object): file.write(value) file.write("\n") - return uboot_env_assemble_file + return uboot_env_assemble_file, open(uboot_env_assemble_file, 'rb') class BoardInfo(object): @@ -915,7 +924,7 @@ def main(argv): group = parser.add_mutually_exclusive_group() parser.add_argument('firmware', nargs='?', metavar='FIRMWARE', type=argparse.FileType('rb'), - help='firmware tarball') + help='firmware or tarball') group.add_argument('-f', '--force', help='Force update, ignore all the checking', action='store_true')