From 3147cb9c335559d81a9a6c34f1ee0a3111e3b8c5 Mon Sep 17 00:00:00 2001
From: Li Hua Qian <huaqian.li@siemens.com>
Date: Wed, 27 Mar 2024 11:47:18 +0800
Subject: [PATCH] iot2050-firmware-update: Fix bug in class destruction and
 refactor code

This commit addresses two issues:

1. Fixed two bugs in destruction of classes. It was trying to release a
non-existent object, causing an error. This has been corrected.

2. Removed an unnecessary pattern from the Firmware class. This pattern
was not contributing to the functionality of the class and was causing
confusion. Its removal simplifies the class.

Signed-off-by: Li Hua Qian <huaqian.li@siemens.com>
---
 .../files/iot2050-firmware-update.tmpl        | 81 ++++++++++---------
 1 file changed, 45 insertions(+), 36 deletions(-)

diff --git a/recipes-app/iot2050-firmware-update/files/iot2050-firmware-update.tmpl b/recipes-app/iot2050-firmware-update/files/iot2050-firmware-update.tmpl
index 2124ef022..759f70d9e 100755
--- a/recipes-app/iot2050-firmware-update/files/iot2050-firmware-update.tmpl
+++ b/recipes-app/iot2050-firmware-update/files/iot2050-firmware-update.tmpl
@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
 #
-# Copyright (c) Siemens AG, 2020-2023
+# Copyright (c) Siemens AG, 2020-2024
 #
 # Authors:
 #  Chao Zeng <chao.zeng@siemens.com>
@@ -142,22 +142,20 @@ class UpgradeError(Exception):
         self.code = code
 
     def __str__(self):
-        return repr(self)
+        return self.err
 
 
 class Firmware():
     """The Firmware class represents flash base operations for all flashes"""
     def __init__(self, firmware):
-        try:
-            self.firmware = open(firmware, "rb")
-        except (IOError, TypeError):
-            self.firmware = firmware
+        if not isinstance(firmware, io.IOBase):
+            raise UpgradeError("TypeError: firmware must be a file-like object!")
+        self.firmware = firmware
 
     def __del__(self):
-        try:
+        if hasattr(self, 'firmware'):
             self.firmware.close()
-        except Exception:
-            pass
+            del self.firmware
 
     @abstractmethod
     def write(self):
@@ -233,13 +231,10 @@ class MtdDevice():
                 firmware_content += bytearray([0xff] * padsize)
 
                 if not mtd_content == firmware_content:
-                    #sys.stdout.flush()
                     self.__erase(mtd_dev, mtd_pos, mtd_erasesize)
                     os.lseek(mtd_dev, mtd_pos, os.SEEK_SET)
                     os.write(mtd_dev, firmware_content)
-                #else:
-                #    print(".", end="")
-                #    sys.stdout.flush()
+
                 mtd_pos += mtd_erasesize
                 file_size -= mtd_erasesize
             os.close(mtd_dev)
@@ -332,9 +327,9 @@ class BootloaderFirmware(Firmware):
 
 class EnvFirmware(Firmware):
     """The EnvFirmware class represents env partition operations"""
-    def __init__(self, firmware):
+    def __init__(self, firmware_path, firmware):
         super().__init__(firmware)
-        self.firmware_path = firmware
+        self.firmware_path = firmware_path
         self.mtd_device = MtdDevice()
 
         mtd_num = 0
@@ -356,6 +351,9 @@ class EnvFirmware(Firmware):
             if self.env_mtd_num and self.env_bk_mtd_num:
                 break
 
+            if mtd_num > 20:
+                raise UpgradeError("EnvFirmware: No env partition found")
+
     def write(self):
         """A env firmware can write contents to the env partition"""
         with tempfile.NamedTemporaryFile() as env_default_binary:
@@ -371,13 +369,13 @@ class EnvFirmware(Firmware):
                     firmware_size = os.path.getsize(env_default_binary.name)
                     self.firmware = open(env_default_binary.name, "rb")
 
-                    while True:
-                        if firmware_size <= 0:
-                            break
-                        firmware_size = self.mtd_device.write(
-                            mtd_dev_path, mtd_size, mtd_erasesize,
-                            env_default_binary, firmware_size
-                        )
+                    firmware_size = self.mtd_device.write(
+                        mtd_dev_path, mtd_size, mtd_erasesize,
+                        env_default_binary, firmware_size
+                    )
+
+                    if firmware_size > 0:
+                        raise UpgradeError("Write env failed")
                 except subprocess.CalledProcessError as error:
                     print(error.stdout)
                     raise UpgradeError("EnvFirmware: Run mkenvimage failed")
@@ -408,25 +406,28 @@ class ForceUpdate():
         self.firmware_type = firmware_type
         self.interactor = interactor
 
-    def update(self):
         if self.firmware_type == "uboot":
-            firmware_obj = BootloaderFirmware(self.firmware)
+            try:
+                self.firmware_obj = BootloaderFirmware(self.firmware)
+            except UpgradeError as e:
+                raise UpgradeError(e.err, ErrorCode.INVALID_FIRMWARE.value)
         else:
-            raise UpgradeError("Unsupported firmware type!")
+            raise UpgradeError("Unsupported firmware type!", ErrorCode.INVALID_FIRMWARE.value)
 
+    def update(self):
         print("===================================================")
         print("IOT2050 firmware update started - DO NOT INTERRUPT!")
         print("===================================================")
 
         self.interactor.progress_bar(info="Updating {}".format(self.firmware_type))
-        firmware_obj.write()
+        self.firmware_obj.write()
 
         firmware_md5 = hashlib.md5()
         self.firmware.seek(0)
         firmware_md5.update(self.firmware.read())
 
         read_out_md5  = hashlib.md5()
-        read_out_md5.update(firmware_obj.read())
+        read_out_md5.update(self.firmware_obj.read())
         self.interactor.progress_bar(start=False)
 
         if firmware_md5.hexdigest() != read_out_md5.hexdigest():
@@ -472,11 +473,13 @@ class FirmwareUpdate():
                         print("\nPreserved env list: ")
                         for env in env_list:
                             print(env)
-                        self.firmwares[firmware_type] = EnvFirmware(
+                        env_path, env_binary = \
                             self.tarball.generate_env_firmware(env_list)
-                        )
+                        self.firmwares[firmware_type] = \
+                            EnvFirmware(env_path, env_binary)
                     else:
                         self.firmwares[firmware_type] = EnvFirmware(
+                            self.tarball.get_file_path(self.tarball.UBOOT_ENV_FILE),
                             self.tarball.get_file(self.tarball.UBOOT_ENV_FILE)
                         )
                 elif firmware_type ==  self.tarball.FIRMWARE_TYPES[2]:
@@ -557,7 +560,6 @@ class FirmwareUpdate():
 
                 self.firmwares[firmware_type].write()
 
-
                 self.firmwares[firmware_type].firmware.seek(0)
                 content = self.firmwares[firmware_type].firmware.read()
                 firmware_md5 = self.__get_md5_digest(content)
@@ -602,12 +604,14 @@ class FirmwareTarball(object):
         # extract file path
         self.extract_path = "/tmp"
         self.firmware_tarball.seek(0)
+        self.extracted_files = []
         with tarfile.open(fileobj=self.firmware_tarball) as f:
             for member in f:
                 file_tarfileinfo = f.getmember(name=member.name)
                 file_tarfileinfo.uid = os.getuid()
                 file_tarfileinfo.gid = os.getgid()
                 f.extract(file_tarfileinfo, path=self.extract_path)
+                self.extracted_files.append(self.extract_path + "/" + member.name)
 
         self._board_info = BoardInfo()
         print("Current board: {}".format(self._board_info.board_name))
@@ -617,7 +621,7 @@ class FirmwareTarball(object):
         # to access the fields.
         try:
             self._jsonobj = json.load(
-                open(self.get_file(self.CONF_JSON), "rb"),
+                self.get_file(self.CONF_JSON),
                 object_hook=lambda d: Namespace(**d)
             )
         except ValueError:
@@ -626,9 +630,8 @@ class FirmwareTarball(object):
         self.firmware_names = dict.fromkeys(self.FIRMWARE_TYPES)
 
     def __del__(self):
-        with tarfile.open(fileobj=self.firmware_tarball) as f:
-            for member in f:
-                os.remove(self.extract_path + "/" + member.name)
+        for file in self.extracted_files:
+            os.remove(file)
 
     def __check_os(self, target_os, os_info) -> bool:
         for tos in target_os:
@@ -704,6 +707,12 @@ class FirmwareTarball(object):
         """Get the file object of specified name"""
         file = os.path.join(self.extract_path, name)
 
+        return open(file, 'rb')
+
+    def get_file_path(self, name):
+        """Get the file object of specified name"""
+        file = os.path.join(self.extract_path, name)
+
         return file
 
     def __get_suggest_preserved_uboot_env(self):
@@ -773,7 +782,7 @@ class FirmwareTarball(object):
                 file.write(value)
                 file.write("\n")
 
-        return uboot_env_assemble_file
+        return uboot_env_assemble_file, open(uboot_env_assemble_file, 'rb')
 
 
 class BoardInfo(object):
@@ -915,7 +924,7 @@ def main(argv):
     group = parser.add_mutually_exclusive_group()
     parser.add_argument('firmware', nargs='?', metavar='FIRMWARE',
                         type=argparse.FileType('rb'),
-                        help='firmware tarball')
+                        help='firmware or tarball')
     group.add_argument('-f', '--force',
                         help='Force update, ignore all the checking',
                         action='store_true')