Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refine and fix bugs about firmware update tool and watchdog issue #533

Merged
merged 6 commits into from
Jun 5, 2024
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env python3
#
# Copyright (c) Siemens AG, 2020-2023
# Copyright (c) Siemens AG, 2020-2024
#
# Authors:
# Chao Zeng <[email protected]>
Expand Down Expand Up @@ -142,22 +142,20 @@ class UpgradeError(Exception):
self.code = code

def __str__(self):
return repr(self)
return self.err


class Firmware():
"""The Firmware class represents flash base operations for all flashes"""
def __init__(self, firmware):
try:
self.firmware = open(firmware, "rb")
except (IOError, TypeError):
self.firmware = firmware
if not isinstance(firmware, io.IOBase):
raise UpgradeError("TypeError: firmware must be a file-like object!")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: isinstance(firmware, io.IOBase) would be more accurate.

self.firmware = firmware

def __del__(self):
try:
if hasattr(self, 'firmware'):
self.firmware.close()
except Exception:
pass
del self.firmware

@abstractmethod
def write(self):
Expand Down Expand Up @@ -233,13 +231,10 @@ class MtdDevice():
firmware_content += bytearray([0xff] * padsize)

if not mtd_content == firmware_content:
#sys.stdout.flush()
self.__erase(mtd_dev, mtd_pos, mtd_erasesize)
os.lseek(mtd_dev, mtd_pos, os.SEEK_SET)
os.write(mtd_dev, firmware_content)
#else:
# print(".", end="")
# sys.stdout.flush()

mtd_pos += mtd_erasesize
file_size -= mtd_erasesize
os.close(mtd_dev)
Expand Down Expand Up @@ -332,15 +327,17 @@ class BootloaderFirmware(Firmware):

class EnvFirmware(Firmware):
"""The EnvFirmware class represents env partition operations"""
def __init__(self, firmware):
def __init__(self, firmware_path, firmware):
super().__init__(firmware)
self.firmware_path = firmware
self.firmware_path = firmware_path
self.mtd_device = MtdDevice()

mtd_num = 0
self.env_mtd_num = 0
self.env_bk_mtd_num = 0
while True:
# mtd_device is typically less than 20, if one mtd device can't be
# located in 20 rounds, jump out the loop.
for mtd_num in range(20):
try:
mtd_dev_path, mtd_size, mtd_erasesize, mtd_name = \
self.mtd_device.get_mtd_info(mtd_num)
Expand All @@ -351,10 +348,11 @@ class EnvFirmware(Firmware):
self.env_mtd_num = mtd_num
if "env.backup" == mtd_name:
self.env_bk_mtd_num = mtd_num
mtd_num +=1

if self.env_mtd_num and self.env_bk_mtd_num:
break
else:
raise UpgradeError("EnvFirmware: No env partition found")

def write(self):
"""A env firmware can write contents to the env partition"""
Expand All @@ -371,13 +369,13 @@ class EnvFirmware(Firmware):
firmware_size = os.path.getsize(env_default_binary.name)
self.firmware = open(env_default_binary.name, "rb")

while True:
if firmware_size <= 0:
break
firmware_size = self.mtd_device.write(
mtd_dev_path, mtd_size, mtd_erasesize,
env_default_binary, firmware_size
)
firmware_size = self.mtd_device.write(
mtd_dev_path, mtd_size, mtd_erasesize,
env_default_binary, firmware_size
)

if firmware_size > 0:
raise UpgradeError("Write env failed")
except subprocess.CalledProcessError as error:
print(error.stdout)
raise UpgradeError("EnvFirmware: Run mkenvimage failed")
Expand Down Expand Up @@ -408,25 +406,28 @@ class ForceUpdate():
self.firmware_type = firmware_type
self.interactor = interactor

def update(self):
if self.firmware_type == "uboot":
firmware_obj = BootloaderFirmware(self.firmware)
try:
self.firmware_obj = BootloaderFirmware(self.firmware)
except UpgradeError as e:
raise UpgradeError(e.err, ErrorCode.INVALID_FIRMWARE.value)
else:
raise UpgradeError("Unsupported firmware type!")
raise UpgradeError("Unsupported firmware type!", ErrorCode.INVALID_FIRMWARE.value)

def update(self):
print("===================================================")
print("IOT2050 firmware update started - DO NOT INTERRUPT!")
print("===================================================")

self.interactor.progress_bar(info="Updating {}".format(self.firmware_type))
firmware_obj.write()
self.firmware_obj.write()

firmware_md5 = hashlib.md5()
self.firmware.seek(0)
firmware_md5.update(self.firmware.read())

read_out_md5 = hashlib.md5()
read_out_md5.update(firmware_obj.read())
read_out_md5.update(self.firmware_obj.read())
self.interactor.progress_bar(start=False)

if firmware_md5.hexdigest() != read_out_md5.hexdigest():
Expand Down Expand Up @@ -472,11 +473,13 @@ class FirmwareUpdate():
print("\nPreserved env list: ")
for env in env_list:
print(env)
self.firmwares[firmware_type] = EnvFirmware(
env_path, env_binary = \
self.tarball.generate_env_firmware(env_list)
)
self.firmwares[firmware_type] = \
EnvFirmware(env_path, env_binary)
else:
self.firmwares[firmware_type] = EnvFirmware(
self.tarball.get_file_path(self.tarball.UBOOT_ENV_FILE),
self.tarball.get_file(self.tarball.UBOOT_ENV_FILE)
)
elif firmware_type == self.tarball.FIRMWARE_TYPES[2]:
Expand Down Expand Up @@ -557,7 +560,6 @@ class FirmwareUpdate():

self.firmwares[firmware_type].write()


self.firmwares[firmware_type].firmware.seek(0)
content = self.firmwares[firmware_type].firmware.read()
firmware_md5 = self.__get_md5_digest(content)
Expand Down Expand Up @@ -602,12 +604,14 @@ class FirmwareTarball(object):
# extract file path
self.extract_path = "/tmp"
self.firmware_tarball.seek(0)
self.extracted_files = []
with tarfile.open(fileobj=self.firmware_tarball) as f:
for member in f:
file_tarfileinfo = f.getmember(name=member.name)
file_tarfileinfo.uid = os.getuid()
file_tarfileinfo.gid = os.getgid()
f.extract(file_tarfileinfo, path=self.extract_path)
self.extracted_files.append(self.extract_path + "/" + member.name)

self._board_info = BoardInfo()
print("Current board: {}".format(self._board_info.board_name))
Expand All @@ -617,7 +621,7 @@ class FirmwareTarball(object):
# to access the fields.
try:
self._jsonobj = json.load(
open(self.get_file(self.CONF_JSON), "rb"),
self.get_file(self.CONF_JSON),
object_hook=lambda d: Namespace(**d)
)
except ValueError:
Expand All @@ -626,9 +630,8 @@ class FirmwareTarball(object):
self.firmware_names = dict.fromkeys(self.FIRMWARE_TYPES)

def __del__(self):
with tarfile.open(fileobj=self.firmware_tarball) as f:
for member in f:
os.remove(self.extract_path + "/" + member.name)
for file in self.extracted_files:
os.remove(file)

def __check_os(self, target_os, os_info) -> bool:
for tos in target_os:
Expand Down Expand Up @@ -704,6 +707,12 @@ class FirmwareTarball(object):
"""Get the file object of specified name"""
file = os.path.join(self.extract_path, name)

return open(file, 'rb')

def get_file_path(self, name):
"""Get the file object of specified name"""
file = os.path.join(self.extract_path, name)

return file

def __get_suggest_preserved_uboot_env(self):
Expand Down Expand Up @@ -773,7 +782,7 @@ class FirmwareTarball(object):
file.write(value)
file.write("\n")

return uboot_env_assemble_file
return uboot_env_assemble_file, open(uboot_env_assemble_file, 'rb')


class BoardInfo(object):
Expand Down Expand Up @@ -915,7 +924,7 @@ def main(argv):
group = parser.add_mutually_exclusive_group()
parser.add_argument('firmware', nargs='?', metavar='FIRMWARE',
type=argparse.FileType('rb'),
help='firmware tarball')
help='firmware or tarball')
group.add_argument('-f', '--force',
help='Force update, ignore all the checking',
action='store_true')
Expand Down
Loading