Skip to content

Commit

Permalink
fix(binaryread): raise/handle EOFError, deprecate vartype=str (#2226)
Browse files Browse the repository at this point in the history
This fixes issues while reading some binaryfiles with auto precision detection, and also modernizes a few aspects of flopy.utils.binaryfile left-over from python2.

There are two changes to flopy.utils.binaryfile.binaryread():

* Raises EOFError if attempting to read data beyond the end-of-file
* Deprecate vartype=str, since bytes is the the return type with Python3

Other refactors:

* Simplify conventional ASCII range checks by converting bytes to a list of int, then check if bytes are within range
* Remove checks if bytes are not str, and use .encode("ascii") where appropriate
  • Loading branch information
mwtoews authored Jun 13, 2024
1 parent c69990a commit e2a85a3
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 102 deletions.
20 changes: 19 additions & 1 deletion autotest/test_binaryfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,30 @@ def test_binaryread(example_data_path):
np.testing.assert_array_equal(res, np.array([1, 1], np.int32))
res = flopy.utils.binaryfile.binaryread(fp, np.float32, 2)
np.testing.assert_array_equal(res, np.array([10, 10], np.float32))
res = flopy.utils.binaryfile.binaryread(fp, str)
res = flopy.utils.binaryfile.binaryread(fp, bytes)
assert res == b" HEAD"
res = flopy.utils.binaryfile.binaryread(fp, np.int32)
assert res == 20


def test_binaryread_misc(tmp_path):
# Check deprecated warning
file = tmp_path / "data.file"
file.write_bytes(b" data")
with file.open("rb") as fp:
with pytest.deprecated_call(match="vartype=str is deprecated"):
res = flopy.utils.binaryfile.binaryread(fp, str, charlen=5)
assert res == b" data"
# Test exceptions with a small file with 1 byte
file.write_bytes(b"\x00")
with file.open("rb") as fp:
with pytest.raises(EOFError):
flopy.utils.binaryfile.binaryread(fp, bytes, charlen=6)
with file.open("rb") as fp:
with pytest.raises(EOFError):
flopy.utils.binaryfile.binaryread(fp, np.int32)


def test_deprecated_binaryread_struct(example_data_path):
# similar to test_binaryread(), but check the calls are deprecated
pth = example_data_path / "freyberg" / "freyberg.githds"
Expand Down
12 changes: 12 additions & 0 deletions autotest/test_cellbudgetfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,18 @@ def test_cellbudgetfile_build_index_mf6(example_data_path):
)


def test_cellbudgetfile_imeth_5(example_data_path):
pth = example_data_path / "preserve_unitnums/testsfr2.ghb.cbc"
with CellBudgetFile(pth) as cbc:
pass
# check a few components
pd.testing.assert_index_equal(
cbc.headers.index, pd.Index(np.arange(12, dtype=np.int64) * 156 + 64)
)
assert cbc.headers.text.unique().tolist() == ["HEAD DEP BOUNDS"]
assert cbc.headers.imeth.unique().tolist() == [5]


@pytest.fixture
def zonbud_model_path(example_data_path):
return example_data_path / "zonbud_examples"
Expand Down
178 changes: 77 additions & 101 deletions flopy/utils/binaryfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,10 @@ class BinaryHeader(Header):
Parameters
----------
bintype : str
Type of file being opened. Accepted values are 'head' and 'ucn'.
precision : str
Precision of floating point data in the file.
bintype : str, default None
Type of file being opened. Accepted values are 'head' and 'ucn'.
precision : str, default 'single'
Precision of floating point data in the file.
"""

Expand Down Expand Up @@ -313,32 +313,47 @@ def binaryread_struct(file, vartype, shape=(1,), charlen=16):

def binaryread(file, vartype, shape=(1,), charlen=16):
"""
Read text, a scalar value, or an array of values from a binary file.
Read character bytes, scalar or array values from a binary file.
Parameters
----------
file : file object
is an open file object
vartype : type
is the return variable type: str, numpy.int32, numpy.float32,
or numpy.float64
is the return variable type: bytes, numpy.int32,
numpy.float32, or numpy.float64. Using str is deprecated since
bytes is preferred.
shape : tuple, default (1,)
is the shape of the returned array (shape(1, ) returns a single
value) for example, shape = (nlay, nrow, ncol)
charlen : int, default 16
is the length of the text string. Note that string arrays
cannot be returned, only multi-character strings. Shape has no
affect on strings.
is the length character bytes. Note that arrays of bytes
cannot be returned, only multi-character bytes. Shape has no
affect on bytes.
Raises
------
EOFError
"""

# read a string variable of length charlen
if vartype == str:
# handle a hang-over from python2
warnings.warn(
"vartype=str is deprecated; use vartype=bytes instead.",
DeprecationWarning,
)
vartype = bytes
if vartype == bytes:
# read character bytes of length charlen
result = file.read(charlen)
if len(result) < charlen:
raise EOFError
else:
# find the number of values
nval = np.prod(shape)
result = np.fromfile(file, vartype, nval)
if result.size < nval:
raise EOFError
if nval != 1:
result = np.reshape(result, shape)
return result
Expand All @@ -364,23 +379,18 @@ def get_headfile_precision(filename: Union[str, os.PathLike]):
Parameters
----------
filename : str or PathLike
Path of binary MODFLOW file to determine precision.
Path of binary MODFLOW file to determine precision.
Returns
-------
result : str
Result will be unknown, single, or double
str
Result will be unknown, single, or double
"""

# Set default result if neither single or double works
result = "unknown"

# Create string containing set of ascii characters
asciiset = " "
for i in range(33, 127):
asciiset += chr(i)

# Open file, and check filesize to ensure this is not an empty file
f = open(filename, "rb")
f.seek(0, 2)
Expand All @@ -399,15 +409,12 @@ def get_headfile_precision(filename: Union[str, os.PathLike]):
("text", "S16"),
]
hdr = binaryread(f, vartype)
text = hdr[0][4]
try:
text = text.decode()
for t in text:
if t.upper() not in asciiset:
raise Exception()
charbytes = list(hdr[0][4])
if min(charbytes) >= 32 and max(charbytes) <= 126:
# check if bytes are within conventional ASCII range
result = "single"
success = True
except:
else:
success = False

# next try double
Expand All @@ -421,14 +428,10 @@ def get_headfile_precision(filename: Union[str, os.PathLike]):
("text", "S16"),
]
hdr = binaryread(f, vartype)
text = hdr[0][4]
try:
text = text.decode()
for t in text:
if t.upper() not in asciiset:
raise Exception()
charbytes = list(hdr[0][4])
if min(charbytes) >= 32 and max(charbytes) <= 126:
result = "double"
except:
else:
f.close()
raise ValueError(
f"Could not determine the precision of the headfile {filename}"
Expand Down Expand Up @@ -1171,7 +1174,7 @@ def _set_precision(self, precision="single"):

try:
self._build_index()
except BudgetIndexError:
except (BudgetIndexError, EOFError):
success = False
self.__reset()

Expand Down Expand Up @@ -1201,20 +1204,14 @@ def _build_index(self):
Build the ordered dictionary, which maps the header information
to the position in the binary file.
"""
asciiset = " "
for i in range(33, 127):
asciiset += chr(i)

# read first record
header = self._get_header()
nrow = header["nrow"]
ncol = header["ncol"]
text = header["text"]
if isinstance(text, bytes):
text = text.decode()
text = header["text"].decode("ascii").strip()
if nrow < 0 or ncol < 0:
raise Exception("negative nrow, ncol")
if not text.endswith("FLOW-JA-FACE"):
if text != "FLOW-JA-FACE":
self.nrow = nrow
self.ncol = ncol
self.nlay = np.abs(header["nlay"])
Expand Down Expand Up @@ -1242,17 +1239,14 @@ def _build_index(self):
self.kstpkper.append(kstpkper)
if header["text"] not in self.textlist:
# check the precision of the file using text records
try:
tlist = [header["text"], header["modelnam"]]
for text in tlist:
if isinstance(text, bytes):
text = text.decode()
for t in text:
if t.upper() not in asciiset:
raise Exception()

except:
raise BudgetIndexError("Improper precision")
tlist = [header["text"], header["modelnam"]]
for text in tlist:
if len(text) == 0:
continue
charbytes = list(text)
if min(charbytes) < 32 or max(charbytes) > 126:
# not in conventional ASCII range
raise BudgetIndexError("Improper precision")
self.textlist.append(header["text"])
self.imethlist.append(header["imeth"])
if header["paknam"] not in self.paknamlist_from:
Expand All @@ -1279,23 +1273,15 @@ def _build_index(self):
"paknam2",
]:
s = header[itxt]
if isinstance(s, bytes):
s = s.decode()
print(f"{itxt}: {s}")
print("file position: ", ipos)
if (
header["imeth"].item() != 5
and header["imeth"].item() != 6
and header["imeth"].item() != 7
):
if header["imeth"].item() not in {5, 6, 7}:
print("")

# set the nrow, ncol, and nlay if they have not been set
if self.nrow == 0:
text = header["text"]
if isinstance(text, bytes):
text = text.decode()
if not text.endswith("FLOW-JA-FACE"):
text = header["text"].decode("ascii").strip()
if text != "FLOW-JA-FACE":
self.nrow = header["nrow"]
self.ncol = header["ncol"]
self.nlay = np.abs(header["nlay"])
Expand Down Expand Up @@ -1350,51 +1336,47 @@ def _skip_record(self, header):
nrow = header["nrow"]
ncol = header["ncol"]
imeth = header["imeth"]
realtype_nbytes = self.realtype(1).nbytes
if imeth == 0:
nbytes = nrow * ncol * nlay * self.realtype(1).nbytes
nbytes = nrow * ncol * nlay * realtype_nbytes
elif imeth == 1:
nbytes = nrow * ncol * nlay * self.realtype(1).nbytes
nbytes = nrow * ncol * nlay * realtype_nbytes
elif imeth == 2:
nlist = binaryread(self.file, np.int32)[0]
nbytes = nlist * (np.int32(1).nbytes + self.realtype(1).nbytes)
nbytes = nlist * (4 + realtype_nbytes)
elif imeth == 3:
nbytes = nrow * ncol * self.realtype(1).nbytes
nbytes += nrow * ncol * np.int32(1).nbytes
nbytes = nrow * ncol * realtype_nbytes + (nrow * ncol * 4)
elif imeth == 4:
nbytes = nrow * ncol * self.realtype(1).nbytes
nbytes = nrow * ncol * realtype_nbytes
elif imeth == 5:
nauxp1 = binaryread(self.file, np.int32)[0]
naux = nauxp1 - 1

for i in range(naux):
temp = binaryread(self.file, str, charlen=16)
naux_nbytes = naux * 16
if naux_nbytes:
check = self.file.seek(naux_nbytes, 1)
if check < naux_nbytes:
raise EOFError
nlist = binaryread(self.file, np.int32)[0]
if self.verbose:
print("naux: ", naux)
print("nlist: ", nlist)
print("")
nbytes = nlist * (
np.int32(1).nbytes
+ self.realtype(1).nbytes
+ naux * self.realtype(1).nbytes
)
nbytes = nlist * (4 + realtype_nbytes + naux * realtype_nbytes)
elif imeth == 6:
# read rest of list data
nauxp1 = binaryread(self.file, np.int32)[0]
naux = nauxp1 - 1

for i in range(naux):
temp = binaryread(self.file, str, charlen=16)
naux_nbytes = naux * 16
if naux_nbytes:
check = self.file.seek(naux_nbytes, 1)
if check < naux_nbytes:
raise EOFError
nlist = binaryread(self.file, np.int32)[0]
if self.verbose:
print("naux: ", naux)
print("nlist: ", nlist)
print("")
nbytes = nlist * (
np.int32(1).nbytes * 2
+ self.realtype(1).nbytes
+ naux * self.realtype(1).nbytes
)
nbytes = nlist * (4 * 2 + realtype_nbytes + naux * realtype_nbytes)
else:
raise Exception(f"invalid method code {imeth}")
if nbytes != 0:
Expand All @@ -1418,10 +1400,10 @@ def _get_header(self):
for name in temp.dtype.names:
header2[name] = temp[name]
if header2["imeth"].item() == 6:
header2["modelnam"] = binaryread(self.file, str, charlen=16)
header2["paknam"] = binaryread(self.file, str, charlen=16)
header2["modelnam2"] = binaryread(self.file, str, charlen=16)
header2["paknam2"] = binaryread(self.file, str, charlen=16)
header2["modelnam"] = binaryread(self.file, bytes, charlen=16)
header2["paknam"] = binaryread(self.file, bytes, charlen=16)
header2["modelnam2"] = binaryread(self.file, bytes, charlen=16)
header2["paknam2"] = binaryread(self.file, bytes, charlen=16)
else:
header2 = np.array(
[(0, 0.0, 0.0, 0.0, "", "", "", "")], dtype=self.header2_dtype
Expand Down Expand Up @@ -1951,9 +1933,7 @@ def get_record(self, idx, full3D=False):
self.file.seek(ipos, 0)
imeth = header["imeth"][0]

t = header["text"][0]
if isinstance(t, bytes):
t = t.decode("utf-8")
t = header["text"][0].decode("ascii")
s = f"Returning {t.strip()} as "

nlay = abs(header["nlay"][0])
Expand Down Expand Up @@ -2039,10 +2019,8 @@ def get_record(self, idx, full3D=False):
naux = nauxp1 - 1
l = [("node", np.int32), ("q", self.realtype)]
for i in range(naux):
auxname = binaryread(self.file, str, charlen=16)
if not isinstance(auxname, str):
auxname = auxname.decode()
l.append((auxname.strip(), self.realtype))
auxname = binaryread(self.file, bytes, charlen=16)
l.append((auxname.decode("ascii").strip(), self.realtype))
dtype = np.dtype(l)
nlist = binaryread(self.file, np.int32)[0]
data = binaryread(self.file, dtype, shape=(nlist,))
Expand All @@ -2064,10 +2042,8 @@ def get_record(self, idx, full3D=False):
naux = nauxp1 - 1
l = [("node", np.int32), ("node2", np.int32), ("q", self.realtype)]
for i in range(naux):
auxname = binaryread(self.file, str, charlen=16)
if not isinstance(auxname, str):
auxname = auxname.decode()
l.append((auxname.strip(), self.realtype))
auxname = binaryread(self.file, bytes, charlen=16)
l.append((auxname.decode("ascii").strip(), self.realtype))
dtype = np.dtype(l)
nlist = binaryread(self.file, np.int32)[0]
data = binaryread(self.file, dtype, shape=(nlist,))
Expand Down

0 comments on commit e2a85a3

Please sign in to comment.