Skip to content

Commit

Permalink
iot2050-firmware-update: Fix bug in class destruction and refactor code
Browse files Browse the repository at this point in the history
This commit addresses two issues:

1. Fixed two bugs in destruction of classes. It was trying to release a
non-existent object, causing an error. This has been corrected.

2. Removed an unnecessary pattern from the Firmware class. This pattern
was not contributing to the functionality of the class and was causing
confusion. Its removal simplifies the class.

Signed-off-by: Li Hua Qian <[email protected]>
  • Loading branch information
huaqianli committed Apr 9, 2024
1 parent 1599b30 commit 9cf6905
Showing 1 changed file with 45 additions and 36 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env python3
#
# Copyright (c) Siemens AG, 2020-2023
# Copyright (c) Siemens AG, 2020-2024
#
# Authors:
# Chao Zeng <[email protected]>
Expand Down Expand Up @@ -142,22 +142,20 @@ class UpgradeError(Exception):
self.code = code

def __str__(self):
return repr(self)
return self.err


class Firmware():
"""The Firmware class represents flash base operations for all flashes"""
def __init__(self, firmware):
try:
self.firmware = open(firmware, "rb")
except (IOError, TypeError):
self.firmware = firmware
if not isinstance(firmware, io.IOBase):
raise UpgradeError("TypeError: firmware must be a file-like object!")
self.firmware = firmware

def __del__(self):
try:
if hasattr(self, 'firmware'):
self.firmware.close()
except Exception:
pass
del self.firmware

@abstractmethod
def write(self):
Expand Down Expand Up @@ -233,13 +231,10 @@ class MtdDevice():
firmware_content += bytearray([0xff] * padsize)

if not mtd_content == firmware_content:
#sys.stdout.flush()
self.__erase(mtd_dev, mtd_pos, mtd_erasesize)
os.lseek(mtd_dev, mtd_pos, os.SEEK_SET)
os.write(mtd_dev, firmware_content)
#else:
# print(".", end="")
# sys.stdout.flush()

mtd_pos += mtd_erasesize
file_size -= mtd_erasesize
os.close(mtd_dev)
Expand Down Expand Up @@ -332,9 +327,9 @@ class BootloaderFirmware(Firmware):

class EnvFirmware(Firmware):
"""The EnvFirmware class represents env partition operations"""
def __init__(self, firmware):
def __init__(self, firmware_path, firmware):
super().__init__(firmware)
self.firmware_path = firmware
self.firmware_path = firmware_path
self.mtd_device = MtdDevice()

mtd_num = 0
Expand All @@ -356,6 +351,9 @@ class EnvFirmware(Firmware):
if self.env_mtd_num and self.env_bk_mtd_num:
break

if mtd_num > 20:
raise UpgradeError("EnvFirmware: No env partition found")

def write(self):
"""A env firmware can write contents to the env partition"""
with tempfile.NamedTemporaryFile() as env_default_binary:
Expand All @@ -371,13 +369,13 @@ class EnvFirmware(Firmware):
firmware_size = os.path.getsize(env_default_binary.name)
self.firmware = open(env_default_binary.name, "rb")

while True:
if firmware_size <= 0:
break
firmware_size = self.mtd_device.write(
mtd_dev_path, mtd_size, mtd_erasesize,
env_default_binary, firmware_size
)
firmware_size = self.mtd_device.write(
mtd_dev_path, mtd_size, mtd_erasesize,
env_default_binary, firmware_size
)

if firmware_size > 0:
raise UpgradeError("Write env failed")
except subprocess.CalledProcessError as error:
print(error.stdout)
raise UpgradeError("EnvFirmware: Run mkenvimage failed")
Expand Down Expand Up @@ -408,25 +406,28 @@ class ForceUpdate():
self.firmware_type = firmware_type
self.interactor = interactor

def update(self):
if self.firmware_type == "uboot":
firmware_obj = BootloaderFirmware(self.firmware)
try:
self.firmware_obj = BootloaderFirmware(self.firmware)
except UpgradeError as e:
raise UpgradeError(e.err, ErrorCode.INVALID_FIRMWARE.value)
else:
raise UpgradeError("Unsupported firmware type!")
raise UpgradeError("Unsupported firmware type!", ErrorCode.INVALID_FIRMWARE.value)

def update(self):
print("===================================================")
print("IOT2050 firmware update started - DO NOT INTERRUPT!")
print("===================================================")

self.interactor.progress_bar(info="Updating {}".format(self.firmware_type))
firmware_obj.write()
self.firmware_obj.write()

firmware_md5 = hashlib.md5()
self.firmware.seek(0)
firmware_md5.update(self.firmware.read())

read_out_md5 = hashlib.md5()
read_out_md5.update(firmware_obj.read())
read_out_md5.update(self.firmware_obj.read())
self.interactor.progress_bar(start=False)

if firmware_md5.hexdigest() != read_out_md5.hexdigest():
Expand Down Expand Up @@ -472,11 +473,13 @@ class FirmwareUpdate():
print("\nPreserved env list: ")
for env in env_list:
print(env)
self.firmwares[firmware_type] = EnvFirmware(
env_path, env_binary = \
self.tarball.generate_env_firmware(env_list)
)
self.firmwares[firmware_type] = \
EnvFirmware(env_path, env_binary)
else:
self.firmwares[firmware_type] = EnvFirmware(
self.tarball.get_file_path(self.tarball.UBOOT_ENV_FILE),
self.tarball.get_file(self.tarball.UBOOT_ENV_FILE)
)
elif firmware_type == self.tarball.FIRMWARE_TYPES[2]:
Expand Down Expand Up @@ -557,7 +560,6 @@ class FirmwareUpdate():

self.firmwares[firmware_type].write()


self.firmwares[firmware_type].firmware.seek(0)
content = self.firmwares[firmware_type].firmware.read()
firmware_md5 = self.__get_md5_digest(content)
Expand Down Expand Up @@ -602,12 +604,14 @@ class FirmwareTarball(object):
# extract file path
self.extract_path = "/tmp"
self.firmware_tarball.seek(0)
self.extracted_files = []
with tarfile.open(fileobj=self.firmware_tarball) as f:
for member in f:
file_tarfileinfo = f.getmember(name=member.name)
file_tarfileinfo.uid = os.getuid()
file_tarfileinfo.gid = os.getgid()
f.extract(file_tarfileinfo, path=self.extract_path)
self.extracted_files.append(self.extract_path + "/" + member.name)

self._board_info = BoardInfo()
print("Current board: {}".format(self._board_info.board_name))
Expand All @@ -617,7 +621,7 @@ class FirmwareTarball(object):
# to access the fields.
try:
self._jsonobj = json.load(
open(self.get_file(self.CONF_JSON), "rb"),
self.get_file(self.CONF_JSON),
object_hook=lambda d: Namespace(**d)
)
except ValueError:
Expand All @@ -626,9 +630,8 @@ class FirmwareTarball(object):
self.firmware_names = dict.fromkeys(self.FIRMWARE_TYPES)

def __del__(self):
with tarfile.open(fileobj=self.firmware_tarball) as f:
for member in f:
os.remove(self.extract_path + "/" + member.name)
for file in self.extracted_files:
os.remove(file)

def __check_os(self, target_os, os_info) -> bool:
for tos in target_os:
Expand Down Expand Up @@ -704,6 +707,12 @@ class FirmwareTarball(object):
"""Get the file object of specified name"""
file = os.path.join(self.extract_path, name)

return open(file, 'rb')

def get_file_path(self, name):
"""Get the file object of specified name"""
file = os.path.join(self.extract_path, name)

return file

def __get_suggest_preserved_uboot_env(self):
Expand Down Expand Up @@ -773,7 +782,7 @@ class FirmwareTarball(object):
file.write(value)
file.write("\n")

return uboot_env_assemble_file
return uboot_env_assemble_file, open(uboot_env_assemble_file, 'rb')


class BoardInfo(object):
Expand Down Expand Up @@ -915,7 +924,7 @@ def main(argv):
group = parser.add_mutually_exclusive_group()
parser.add_argument('firmware', nargs='?', metavar='FIRMWARE',
type=argparse.FileType('rb'),
help='firmware tarball')
help='firmware or tarball')
group.add_argument('-f', '--force',
help='Force update, ignore all the checking',
action='store_true')
Expand Down

0 comments on commit 9cf6905

Please sign in to comment.