diff --git a/geodataset/geodataset.py b/geodataset/geodataset.py index cda6f6c..cc39c1c 100644 --- a/geodataset/geodataset.py +++ b/geodataset/geodataset.py @@ -126,10 +126,10 @@ def set_projection_variable(self): pvar = self.createVariable(self.grid_mapping_variable, 'i1') pvar.setncatts(self.get_grid_mapping_ncattrs()) - def set_time_variables_dimensions(self, time_data, time_atts, time_bnds_data): + def set_time_variable(self, time_data, time_atts): """ - set the temporal dimensions (time, nv) - and variables (time, time_bnds) + set the temporal dimensions: time + and variables: time Parameters: ----------- @@ -137,25 +137,32 @@ def set_time_variables_dimensions(self, time_data, time_atts, time_bnds_data): data for time variable time_atts : dict netcdf attributes for time variable - time_bnds_data : np.array - data for time_bnds variable - time_bnds_atts : dict - netcdf attributes for time_bnds variable """ # dimensions self.createDimension('time', None)#time should be unlimited - self.createDimension('nv', 2) # time should have units and a calendar attribute ncatts = dict(**time_atts) ncatts['calendar'] = time_atts.get('calendar', 'standard') - units = time_atts['units'] # time var tvar = self.createVariable('time', 'f8', ('time',), zlib=True) tvar.setncatts(ncatts) tvar[:] = time_data - # time_bnds var - just needs units + + def set_time_bnds_variable(self, time_atts, time_bnds_data): + """ + set the temporal dimension: nv + and variable: time_bnds + + Parameters: + ----------- + time_atts : dict + netcdf attributes for time variable + time_bnds_data : np.array + data for time_bnds variable + """ + self.createDimension('nv', 2) tbvar = self.createVariable('time_bnds', 'f8', ('time', 'nv'), zlib=True) - tbvar.setncattr('units', units) + tbvar.setncattr('units', time_atts['units']) tbvar[:] = time_bnds_data def set_xy_dims(self, x, y): @@ -200,7 +207,7 @@ def set_lonlat(self, lon, lat): dst_var.setncattr('units', units) dst_var[:] = data - def set_variable(self, vname, data, dims, atts, dtype='f4'): + def set_variable(self, vname, data, dims, atts, dtype=np.float32): """ set variable data and attributes @@ -214,18 +221,17 @@ def set_variable(self, vname, data, dims, atts, dtype='f4'): list of dimension names for the variable atts : dict netcdf attributes to set - dtype : str - netcdf data type for new variable (eg 'f4' or 'f8') + dtype : type + netcdf data type for new variable (eg np.float32 or np.double) """ - type_converter = dict(f4=np.float32, f8=np.double)[dtype] ncatts = {k:v for k,v in atts.items() if k != '_FillValue'} kw = dict(zlib=True)# use compression if '_FillValue' in atts: # needs to be a keyword for createVariable and of right data type - kw['fill_value'] = type_converter(atts['_FillValue']) + kw['fill_value'] = dtype(atts['_FillValue']) if 'missing_value' in atts: # needs to be of right data type - ncatts['missing_value'] = type_converter(atts['missing_value']) + ncatts['missing_value'] = dtype(atts['missing_value']) dst_var = self.createVariable(vname, dtype, dims, **kw) ncatts['grid_mapping'] = self.grid_mapping_variable dst_var.setncatts(ncatts) diff --git a/geodataset/tests/test_geodataset.py b/geodataset/tests/test_geodataset.py index d36936c..dbde097 100644 --- a/geodataset/tests/test_geodataset.py +++ b/geodataset/tests/test_geodataset.py @@ -127,7 +127,6 @@ def test_set_xy_dims(self, **kwargs): createVariable=DEFAULT, ) def test_set_lonlat(self, **kwargs): - slon = (2,2) slat = (3,3) lon = np.random.normal(size=slon) @@ -155,19 +154,36 @@ def test_set_lonlat(self, **kwargs): createDimension=DEFAULT, createVariable=DEFAULT, ) - def test_set_time_variables_dimensions(self, **kwargs): + def test_set_time_variable(self, **kwargs): nc = GeoDatasetWrite() nt = 3 - time_inds = [1,2] time = np.random.normal(size=(nt,)) - time_bnds = np.random.normal(size=(nt,2)) time_atts = dict(a1='A1', a2='A2', units='units') - nc.set_time_variables_dimensions(time, time_atts, time_bnds) - self.assert_mock_has_calls(kwargs['createDimension'], [call('time', None), call('nv', 2)]) + nc.set_time_variable(time, time_atts) + self.assert_mock_has_calls(kwargs['createDimension'], [call('time', None)]) req_calls = [ - call('time', 'f8', ('time',), zlib=True), call().setncatts({'a1': 'A1', 'a2': 'A2', 'units': 'units', 'calendar': 'standard'}), + call('time', 'f8', ('time',), zlib=True), + call().setncatts({'a1': 'A1', 'a2': 'A2', 'units': 'units', 'calendar': 'standard'}), call().__setitem__(slice(None, None, None), time), + ] + self.assert_mock_has_calls(kwargs['createVariable'], req_calls) + + @patch.multiple(GeoDatasetWrite, + __init__=MagicMock(return_value=None), + createDimension=DEFAULT, + createVariable=DEFAULT, + ) + def test_set_time_bnds_variable(self, **kwargs): + nc = GeoDatasetWrite() + nt = 3 + time = np.random.normal(size=(nt,)) + time_bnds = np.random.normal(size=(nt,2)) + time_atts = dict(a1='A1', a2='A2', units='units') + + nc.set_time_bnds_variable(time_atts, time_bnds) + self.assert_mock_has_calls(kwargs['createDimension'], [call('nv', 2)]) + req_calls = [ call('time_bnds', 'f8', ('time', 'nv'), zlib=True), call().setncattr('units', 'units'), call().__setitem__(slice(None, None, None), time_bnds), @@ -186,12 +202,12 @@ def test_set_variable_1(self, f4, f8, **kwargs): nc.grid_mapping_variable = 'gmn' atts = dict(a1='A1', a2='A2', _FillValue='fv') f4.return_value = 'fv4' - nc.set_variable('vname', 'data', 'dims', atts, dtype='f4') + nc.set_variable('vname', 'data', 'dims', atts, dtype=np.float32) f4.assert_called_once_with('fv') f8.assert_not_called() req_calls = [ - call('vname', 'f4', 'dims', fill_value='fv4', zlib=True), + call('vname', np.float32, 'dims', fill_value='fv4', zlib=True), call().setncatts({'a1': 'A1', 'a2': 'A2', 'grid_mapping': 'gmn'}), call().__setitem__(slice(None, None, None), 'data'), ] @@ -210,11 +226,11 @@ def test_set_variable_2(self, f4, f8, **kwargs): atts = dict(a1='A1', a2='A2', missing_value='fv') f8.return_value = 'fv8' - nc.set_variable('vname', 'data', 'dims', atts, dtype='f8') + nc.set_variable('vname', 'data', 'dims', atts, dtype=np.double) f8.assert_called_once_with('fv') f4.assert_not_called() req_calls = [ - call('vname', 'f8', 'dims', zlib=True), + call('vname', np.double, 'dims', zlib=True), call().setncatts({ 'a1': 'A1', 'a2': 'A2',