Skip to content

Commit

Permalink
#20 fix set_time_var, set_time_bnd_var, set_var
Browse files Browse the repository at this point in the history
  • Loading branch information
akorosov committed Mar 31, 2022
1 parent 45fbad6 commit e8500f9
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 28 deletions.
40 changes: 23 additions & 17 deletions geodataset/geodataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,36 +126,43 @@ def set_projection_variable(self):
pvar = self.createVariable(self.grid_mapping_variable, 'i1')
pvar.setncatts(self.get_grid_mapping_ncattrs())

def set_time_variables_dimensions(self, time_data, time_atts, time_bnds_data):
def set_time_variable(self, time_data, time_atts):
"""
set the temporal dimensions (time, nv)
and variables (time, time_bnds)
set the temporal dimensions: time
and variables: time
Parameters:
-----------
time_data : np.array
data for time variable
time_atts : dict
netcdf attributes for time variable
time_bnds_data : np.array
data for time_bnds variable
time_bnds_atts : dict
netcdf attributes for time_bnds variable
"""
# dimensions
self.createDimension('time', None)#time should be unlimited
self.createDimension('nv', 2)
# time should have units and a calendar attribute
ncatts = dict(**time_atts)
ncatts['calendar'] = time_atts.get('calendar', 'standard')
units = time_atts['units']
# time var
tvar = self.createVariable('time', 'f8', ('time',), zlib=True)
tvar.setncatts(ncatts)
tvar[:] = time_data
# time_bnds var - just needs units

def set_time_bnds_variable(self, time_atts, time_bnds_data):
"""
set the temporal dimension: nv
and variable: time_bnds
Parameters:
-----------
time_atts : dict
netcdf attributes for time variable
time_bnds_data : np.array
data for time_bnds variable
"""
self.createDimension('nv', 2)
tbvar = self.createVariable('time_bnds', 'f8', ('time', 'nv'), zlib=True)
tbvar.setncattr('units', units)
tbvar.setncattr('units', time_atts['units'])
tbvar[:] = time_bnds_data

def set_xy_dims(self, x, y):
Expand Down Expand Up @@ -200,7 +207,7 @@ def set_lonlat(self, lon, lat):
dst_var.setncattr('units', units)
dst_var[:] = data

def set_variable(self, vname, data, dims, atts, dtype='f4'):
def set_variable(self, vname, data, dims, atts, dtype=np.float32):
"""
set variable data and attributes
Expand All @@ -214,18 +221,17 @@ def set_variable(self, vname, data, dims, atts, dtype='f4'):
list of dimension names for the variable
atts : dict
netcdf attributes to set
dtype : str
netcdf data type for new variable (eg 'f4' or 'f8')
dtype : type
netcdf data type for new variable (eg np.float32 or np.double)
"""
type_converter = dict(f4=np.float32, f8=np.double)[dtype]
ncatts = {k:v for k,v in atts.items() if k != '_FillValue'}
kw = dict(zlib=True)# use compression
if '_FillValue' in atts:
# needs to be a keyword for createVariable and of right data type
kw['fill_value'] = type_converter(atts['_FillValue'])
kw['fill_value'] = dtype(atts['_FillValue'])
if 'missing_value' in atts:
# needs to be of right data type
ncatts['missing_value'] = type_converter(atts['missing_value'])
ncatts['missing_value'] = dtype(atts['missing_value'])
dst_var = self.createVariable(vname, dtype, dims, **kw)
ncatts['grid_mapping'] = self.grid_mapping_variable
dst_var.setncatts(ncatts)
Expand Down
38 changes: 27 additions & 11 deletions geodataset/tests/test_geodataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,6 @@ def test_set_xy_dims(self, **kwargs):
createVariable=DEFAULT,
)
def test_set_lonlat(self, **kwargs):

slon = (2,2)
slat = (3,3)
lon = np.random.normal(size=slon)
Expand Down Expand Up @@ -155,19 +154,36 @@ def test_set_lonlat(self, **kwargs):
createDimension=DEFAULT,
createVariable=DEFAULT,
)
def test_set_time_variables_dimensions(self, **kwargs):
def test_set_time_variable(self, **kwargs):
nc = GeoDatasetWrite()
nt = 3
time_inds = [1,2]
time = np.random.normal(size=(nt,))
time_bnds = np.random.normal(size=(nt,2))
time_atts = dict(a1='A1', a2='A2', units='units')

nc.set_time_variables_dimensions(time, time_atts, time_bnds)
self.assert_mock_has_calls(kwargs['createDimension'], [call('time', None), call('nv', 2)])
nc.set_time_variable(time, time_atts)
self.assert_mock_has_calls(kwargs['createDimension'], [call('time', None)])
req_calls = [
call('time', 'f8', ('time',), zlib=True), call().setncatts({'a1': 'A1', 'a2': 'A2', 'units': 'units', 'calendar': 'standard'}),
call('time', 'f8', ('time',), zlib=True),
call().setncatts({'a1': 'A1', 'a2': 'A2', 'units': 'units', 'calendar': 'standard'}),
call().__setitem__(slice(None, None, None), time),
]
self.assert_mock_has_calls(kwargs['createVariable'], req_calls)

@patch.multiple(GeoDatasetWrite,
__init__=MagicMock(return_value=None),
createDimension=DEFAULT,
createVariable=DEFAULT,
)
def test_set_time_bnds_variable(self, **kwargs):
nc = GeoDatasetWrite()
nt = 3
time = np.random.normal(size=(nt,))
time_bnds = np.random.normal(size=(nt,2))
time_atts = dict(a1='A1', a2='A2', units='units')

nc.set_time_bnds_variable(time_atts, time_bnds)
self.assert_mock_has_calls(kwargs['createDimension'], [call('nv', 2)])
req_calls = [
call('time_bnds', 'f8', ('time', 'nv'), zlib=True),
call().setncattr('units', 'units'),
call().__setitem__(slice(None, None, None), time_bnds),
Expand All @@ -186,12 +202,12 @@ def test_set_variable_1(self, f4, f8, **kwargs):
nc.grid_mapping_variable = 'gmn'
atts = dict(a1='A1', a2='A2', _FillValue='fv')
f4.return_value = 'fv4'
nc.set_variable('vname', 'data', 'dims', atts, dtype='f4')
nc.set_variable('vname', 'data', 'dims', atts, dtype=np.float32)
f4.assert_called_once_with('fv')
f8.assert_not_called()

req_calls = [
call('vname', 'f4', 'dims', fill_value='fv4', zlib=True),
call('vname', np.float32, 'dims', fill_value='fv4', zlib=True),
call().setncatts({'a1': 'A1', 'a2': 'A2', 'grid_mapping': 'gmn'}),
call().__setitem__(slice(None, None, None), 'data'),
]
Expand All @@ -210,11 +226,11 @@ def test_set_variable_2(self, f4, f8, **kwargs):
atts = dict(a1='A1', a2='A2', missing_value='fv')
f8.return_value = 'fv8'

nc.set_variable('vname', 'data', 'dims', atts, dtype='f8')
nc.set_variable('vname', 'data', 'dims', atts, dtype=np.double)
f8.assert_called_once_with('fv')
f4.assert_not_called()
req_calls = [
call('vname', 'f8', 'dims', zlib=True),
call('vname', np.double, 'dims', zlib=True),
call().setncatts({
'a1': 'A1',
'a2': 'A2',
Expand Down

0 comments on commit e8500f9

Please sign in to comment.