Skip to content

Commit

Permalink
Add strict typing.
Browse files Browse the repository at this point in the history
  • Loading branch information
cmeyer committed Dec 6, 2024
1 parent 7d6033e commit e8e0824
Show file tree
Hide file tree
Showing 10 changed files with 237 additions and 171 deletions.
1 change: 1 addition & 0 deletions meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ requirements:
- nionutils >=4.12,<5.0
- niondata >=15.7,<16.0
- numpy >=2.0,<3.0
- pytz
- tifffile

test:
Expand Down
9 changes: 9 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Global options:

[mypy]
ignore_missing_imports = True
follow_imports = silent
strict = True
warn_redundant_casts = False
warn_unused_ignores = False
packages = nionswift_plugin.DM_IO, nionswift_plugin.TIFF_IO
15 changes: 8 additions & 7 deletions nionswift_plugin/DM_IO/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# standard libraries
import gettext
import pathlib
import typing

# third party libraries
from nion.data import DataAndMetadata
Expand All @@ -18,20 +19,20 @@

class DM3IODelegate(object):

def __init__(self, api):
def __init__(self, api: typing.Any) -> None:
self.__api = api
self.io_handler_id = "dm-io-handler"
self.io_handler_name = _("DigitalMicrograph Files")
self.io_handler_extensions = ["dm3", "dm4"]

def read_data_and_metadata(self, extension, file_path):
def read_data_and_metadata(self, extension: str, file_path: str) -> DataAndMetadata.DataAndMetadata:
with open(file_path, "rb", buffering=8 * 1024 * 1024) as f:
return dm3_image_utils.load_image(f)

def can_write_data_and_metadata(self, data_and_metadata, extension):
def can_write_data_and_metadata(self, data_and_metadata: DataAndMetadata.DataAndMetadata, extension: str) -> bool:
return extension.lower() in self.io_handler_extensions

def write_data_and_metadata(self, data_and_metadata, file_path_str: str, extension):
def write_data_and_metadata_stream(self, data_and_metadata: DataAndMetadata.DataAndMetadata, file_path_str: str) -> None:
file_path = pathlib.Path(file_path_str)
data = data_and_metadata.data
data_descriptor = data_and_metadata.data_descriptor
Expand All @@ -58,7 +59,7 @@ def write_data_and_metadata(self, data_and_metadata, file_path_str: str, extensi
dm3_image_utils.save_image(xdata, f, 4 if file_path.suffix == ".dm4" else 3)


def load_image(file_path):
def load_image(file_path: str) -> DataAndMetadata.DataAndMetadata:
with open(file_path, "rb", buffering=8 * 1024 * 1024) as f:
return dm3_image_utils.load_image(f)

Expand All @@ -68,13 +69,13 @@ class DM3IOExtension(object):
# required for Swift to recognize this as an extension class.
extension_id = "nion.swift.extensions.dm3"

def __init__(self, api_broker):
def __init__(self, api_broker: typing.Any) -> None:
# grab the api object.
api = api_broker.get_api(version="1", ui_version="1")
# be sure to keep a reference or it will be closed immediately.
self.__io_handler_ref = api.create_data_and_metadata_io_handler(DM3IODelegate(api))

def close(self):
def close(self) -> None:
# close will be called when the extension is unloaded. in turn, close any references so they get closed. this
# is not strictly necessary since the references will be deleted naturally when this object is deleted.
self.__io_handler_ref.close()
Expand Down
55 changes: 29 additions & 26 deletions nionswift_plugin/DM_IO/dm3_image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,26 @@
# datratypes in describing the data.
# from .parse_dm3 import *

import array
import copy
import datetime
import pprint
import numpy
import typing

import numpy
import numpy.typing
import pytz

from nion.data import Calibration
from nion.data import DataAndMetadata

from . import parse_dm3


def str_to_utf16_bytes(s):
def str_to_utf16_bytes(s: str) -> bytes:
return s.encode('utf-16')

def get_datetime_from_timestamp_str(timestamp_str):
def get_datetime_from_timestamp_str(timestamp_str: str) -> typing.Optional[datetime.datetime]:
if len(timestamp_str) in (23, 26):
return datetime.datetime.strptime(timestamp_str, "%Y-%m-%dT%H:%M:%S.%f")
elif len(timestamp_str) == 19:
Expand Down Expand Up @@ -66,19 +70,20 @@ def get_datetime_from_timestamp_str(timestamp_str):
}


def imagedatadict_to_ndarray(imdict):
def imagedatadict_to_ndarray(imdict: dict[str, typing.Any]) -> numpy.typing.NDArray[typing.Any]:
"""
Converts the ImageData dictionary, imdict, to an nd image.
"""
arr = imdict['Data']
im = None
if isinstance(arr, parse_dm3.array.array):
if isinstance(arr, array.array):
im = numpy.asarray(arr, dtype=arr.typecode)
elif isinstance(arr, parse_dm3.structarray):
t = tuple(arr.typecodes)
t = typing.cast(tuple[str, str], tuple(arr.typecodes))
im = numpy.frombuffer(
arr.raw_data,
typing.cast(typing.Any, arr.raw_data), # huh?
dtype=structarray_to_np_map[t])
assert im is not None
# print "Image has dmimagetype", imdict["DataType"], "numpy type is", im.dtype
assert dm_image_dtypes[imdict["DataType"]][1] == im.dtype
assert imdict['PixelDepth'] == im.dtype.itemsize
Expand All @@ -89,7 +94,7 @@ def imagedatadict_to_ndarray(imdict):
return im


def platform_independent_char(dtype):
def platform_independent_char(dtype: typing.Any) -> str: # ugh dtype
# windows and linux/macos treat dtype.char differently.
# on linux/macos where 'l' has size 8, ints of size 4 are reported as 'i'
# on windows where 'l' has size 4, ints of size 4 are reported as 'l'
Expand All @@ -98,16 +103,16 @@ def platform_independent_char(dtype):
if dtype.char == 'l' and dtype.itemsize == 8: return 'q'
if dtype.char == 'L' and dtype.itemsize == 4: return 'I'
if dtype.char == 'L' and dtype.itemsize == 8: return 'Q'
return dtype.char
return typing.cast(str, dtype.char)


def ndarray_to_imagedatadict(nparr):
def ndarray_to_imagedatadict(nparr: numpy.typing.NDArray[typing.Any]) -> dict[str, typing.Any]:
"""
Convert the numpy array nparr into a suitable ImageList entry dictionary.
Returns a dictionary with the appropriate Data, DataType, PixelDepth
to be inserted into a dm3 tag dictionary and written to a file.
"""
ret = {}
ret = dict[str, typing.Any]()
dm_type = None
for k, v in iter(dm_image_dtypes.items()):
if v[1] == nparr.dtype.type:
Expand All @@ -125,7 +130,7 @@ def ndarray_to_imagedatadict(nparr):
rgba_image[:,:,3] = 255
rgb_view = rgba_image.view(numpy.int32).reshape(rgba_image.shape[:-1]) # squash the color into uint32
ret["Dimensions"] = list(rgb_view.shape[::-1])
ret["Data"] = parse_dm3.array.array(platform_independent_char(rgb_view.dtype), rgb_view.flatten())
ret["Data"] = array.array(platform_independent_char(rgb_view.dtype), rgb_view.flatten())
else:
ret["DataType"] = dm_type
ret["PixelDepth"] = nparr.dtype.itemsize
Expand All @@ -135,11 +140,11 @@ def ndarray_to_imagedatadict(nparr):
ret["Data"] = parse_dm3.structarray(types)
ret["Data"].raw_data = bytes(numpy.asarray(nparr).data)
else:
ret["Data"] = parse_dm3.array.array(platform_independent_char(nparr.dtype), numpy.asarray(nparr).flatten())
ret["Data"] = array.array(platform_independent_char(nparr.dtype), numpy.asarray(nparr).flatten())
return ret


def display_keys(tag: typing.Dict) -> None:
def display_keys(tag: dict[str, typing.Any]) -> None:
tag_copy = copy.deepcopy(tag)
for image_data in tag_copy.get("ImageList", list()):
image_data.get("ImageData", dict()).pop("Data", None)
Expand All @@ -148,7 +153,7 @@ def display_keys(tag: typing.Dict) -> None:
pprint.pprint(tag_copy)


def fix_strings(d):
def fix_strings(d: typing.Any) -> typing.Any:
if isinstance(d, dict):
r = dict()
for k, v in d.items():
Expand All @@ -162,7 +167,7 @@ def fix_strings(d):
for v in d:
l.append(fix_strings(v))
return l
elif isinstance(d, parse_dm3.array.array):
elif isinstance(d, array.array):
if d.typecode == 'H':
return d.tobytes().decode("utf-16")
else:
Expand All @@ -183,23 +188,23 @@ def load_image(file: typing.BinaryIO) -> DataAndMetadata.DataAndMetadata:
img_index = -1
image_tags = dmtag['ImageList'][img_index]
data = imagedatadict_to_ndarray(image_tags['ImageData'])
calibrations = []
calibrations = list[tuple[float, float, str]]()
calibration_tags = image_tags['ImageData'].get('Calibrations', dict())
for dimension in calibration_tags.get('Dimension', list()):
origin, scale, units = dimension.get('Origin', 0.0), dimension.get('Scale', 1.0), dimension.get('Units', str())
calibrations.append((-origin * scale, scale, units))
calibrations = tuple(reversed(calibrations))
calibrations = list(reversed(calibrations))
if len(data.shape) == 3 and data.dtype != numpy.uint8:
if image_tags['ImageTags'].get('Meta Data', dict()).get("Format", str()).lower() in ("spectrum", "spectrum image"):
if data.shape[1] == 1:
data = numpy.squeeze(data, 1)
data = numpy.moveaxis(data, 0, 1)
data_descriptor = DataAndMetadata.DataDescriptor(False, 1, 1)
calibrations = (calibrations[2], calibrations[0])
calibrations = [calibrations[2], calibrations[0]]
else:
data = numpy.moveaxis(data, 0, 2)
data_descriptor = DataAndMetadata.DataDescriptor(False, 2, 1)
calibrations = tuple(calibrations[1:]) + (calibrations[0],)
calibrations = list(calibrations[1:]) + [calibrations[0]]
else:
data_descriptor = DataAndMetadata.DataDescriptor(False, 1, 2)
elif len(data.shape) == 4 and data.dtype != numpy.uint8:
Expand All @@ -215,8 +220,7 @@ def load_image(file: typing.BinaryIO) -> DataAndMetadata.DataAndMetadata:
timestamp = None
timezone = None
timezone_offset = None
title = image_tags.get('Name')
properties = dict()
properties = dict[str, typing.Any]()
if 'ImageTags' in image_tags:
voltage = image_tags['ImageTags'].get('ImageScanned', dict()).get('EHT', dict())
if voltage:
Expand Down Expand Up @@ -286,12 +290,12 @@ def save_image(xdata: DataAndMetadata.DataAndMetadata, file: typing.BinaryIO, fi
data_descriptor = DataAndMetadata.DataDescriptor(False, 2, 1)
needs_slice = True
data_dict = ndarray_to_imagedatadict(data)
ret = {}
ret = dict[str, typing.Any]()
ret["ImageList"] = [{"ImageData": data_dict}]
if dimensional_calibrations and len(dimensional_calibrations) == len(data.shape):
dimension_list = data_dict.setdefault("Calibrations", dict()).setdefault("Dimension", list())
for dimensional_calibration in reversed(dimensional_calibrations):
dimension = dict()
dimension = dict[str, typing.Any]()
if dimensional_calibration.scale != 0.0:
origin = -dimensional_calibration.offset / dimensional_calibration.scale
else:
Expand All @@ -313,7 +317,6 @@ def save_image(xdata: DataAndMetadata.DataAndMetadata, file: typing.BinaryIO, fi
timezone_str = None
if timezone_str is None and timezone:
try:
import pytz
tz = pytz.timezone(timezone)
timezone_str = tz.tzname(modified)
except:
Expand All @@ -332,7 +335,7 @@ def save_image(xdata: DataAndMetadata.DataAndMetadata, file: typing.BinaryIO, fi
ret["DocumentObjectList"] = [{"ImageSource": 0, "AnnotationType": 20}]
# finally some display options
ret["Image Behavior"] = {"ViewDisplayID": 8}
dm_metadata = copy.deepcopy(metadata)
dm_metadata = dict(metadata)
if metadata.get("hardware_source", dict()).get("signal_type", "").lower() == "eels":
if len(data.shape) == 1 or (len(data.shape) == 2 and data.shape[0] == 1):
dm_metadata.setdefault("Meta Data", dict())["Format"] = "Spectrum"
Expand Down
Loading

0 comments on commit e8e0824

Please sign in to comment.