Skip to content

Commit

Permalink
Merge pull request #7 from jzuhone/fix_thermal_models
Browse files Browse the repository at this point in the history
  • Loading branch information
jzuhone authored Jul 7, 2021
2 parents cd8dba2 + 5e415dc commit 5163aad
Show file tree
Hide file tree
Showing 40 changed files with 15,672 additions and 935 deletions.
3 changes: 1 addition & 2 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
include versioneer.py
include acispy/_version.py
include acispy/tests/*
10 changes: 9 additions & 1 deletion acispy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,14 @@
HistogramPlot, make_dateplots, DummyDatePlot
from acispy.thermal_models import SimulateECSRun, \
ThermalModelRunner, ThermalModelFromLoad, \
ThermalModelFromRun, SimulateSingleObs
ThermalModelFromRun, SimulateSingleState
from acispy.load_review import ACISLoadReview


def test(*args, **kwargs):
"""
Run py.test unit tests.
"""
import testr
return testr.test(*args, **kwargs)

18 changes: 9 additions & 9 deletions acispy/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,29 +69,29 @@ def _determine_field(self, field):
if field not in self._checked_fields:
if isinstance(field, tuple):
if len(field) != 2:
raise RuntimeError("Invalid field specification {}!".format(field))
raise RuntimeError(f"Invalid field specification {format}!")
fd = (field[0].lower(), field[1].lower())
if fd in self.fields:
checked_field = fd
else:
raise RuntimeError("Cannot find field {}!".format(field))
raise RuntimeError(f"Cannot find field {field}!")
elif isinstance(field, str):
fd = field.lower()
candidates = []
for ftype in self.fields.types:
if (ftype, fd) in self.fields:
candidates.append((ftype, fd))
if len(candidates) > 1:
msg = "Multiple field types for field name %s!\n" % field
msg = f"Multiple field types for field name {field}!\n"
for c in candidates:
msg += " {}\n".format(c)
msg += f" {c}\n"
raise RuntimeError(msg)
elif len(candidates) == 0:
raise RuntimeError("Cannot find field {}!".format(field))
raise RuntimeError(f"Cannot find field {field}!")
else:
checked_field = candidates[0]
else:
raise RuntimeError("Invalid field specification {}!".format(field))
raise RuntimeError(f"Invalid field specification {field}!")
else:
checked_field = field
return checked_field
Expand Down Expand Up @@ -132,7 +132,7 @@ def write_hdf5(self, filename, overwrite=True):
import h5py
import os
if os.path.exists(filename) and not overwrite:
raise IOError("The file %s already exists and overwrite=False!!" % filename)
raise IOError(f"The file {filename} already exists and overwrite=False!!")
f = h5py.File(filename, "w")
if not self.msids._is_empty:
gmsids = f.create_group("msids")
Expand Down Expand Up @@ -348,8 +348,8 @@ def write_msids(self, filename, fields, mask=None, overwrite=False):
"but '%s', '%s' does not have the same " % field +
"set of times as '%s', '%s'!" % (fields[0][0], fields[0][1]))
data = dict(("_".join(k), self[k].value[mask]) for k in fields)
data["times"] = self.times(*fields[0]).value[mask]
data["dates"] = self.dates(*fields[0])[mask]
data["time"] = self.times(*fields[0]).value[mask]
data["date"] = self.dates(*fields[0])[mask]
Table(data).write(filename, format='ascii', overwrite=overwrite)

def write_states(self, filename, overwrite=False):
Expand Down
8 changes: 4 additions & 4 deletions acispy/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ def __init__(self, dfield, ofields):
self.ofields = ofields

def __str__(self):
return ("Derived field {} depends on the following ".format(self.dfield) +
"fields which are not found in this Dataset: {} ".format(self.ofields))
return (f"Derived field {self.dfield} depends on the following "
f"fields which are not found in this Dataset: {self.ofields} ")


class OutputFieldFunction(object):
Expand Down Expand Up @@ -162,9 +162,9 @@ def _simpos(ds):
if "earth_solid_angle" in dset.msids.derived_msids:
def _earth_solid_angle(ds):
# Collect individual MSIDs for use in calc_earth_vis()
ephem_xyzs = [ds["msids", "orbitephem0_{}".format(x)]
ephem_xyzs = [ds["msids", f"orbitephem0_{x}"]
for x in "xyz"]
aoattqt_1234s = [ds["msids","aoattqt{}".format(x)]
aoattqt_1234s = [ds["msids",f"aoattqt{x}"]
for x in range(1, 5)]
ephems = np.array([x.value for x in ephem_xyzs]).transpose()
q_atts = np.array([x.value for x in aoattqt_1234s]).transpose()
Expand Down
6 changes: 3 additions & 3 deletions acispy/load_review.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,10 @@ def __init__(self, load_name, get_msids=True, tl_file=None):
self.load_name = find_load(load_name)
self.load_letter = self.load_name[-1]
self.load_week = self.load_name[:7]
self.load_year = "20%s" % self.load_week[5:7]
self.load_year = f"20{self.load_week[5:7]}"
self.next_year = str(int(self.load_year)+1)
loaddir = os.path.join(lr_root, self.load_year, self.load_week)
oflsdir = os.path.join(loaddir, "ofls%s" % self.load_letter.lower())
oflsdir = os.path.join(loaddir, f"ofls{self.load_letter.lower()}")
self.load_file = os.path.join(oflsdir, lr_file)
self.events = defaultdict(dict)
self.start_status = self._get_start_status()
Expand Down Expand Up @@ -290,7 +290,7 @@ def __getattr__(self, item):
if item in self.events:
return LoadReviewEvent(item, self.events[item])
else:
raise AttributeError("'LoadReview' object has no attribute '%s'" % item)
raise AttributeError(f"'LoadReview' object has no attribute '{item}'")

def _add_annotations(self, plot, annotations, tbegin, tend):
for i, line in enumerate(plot.ax.lines):
Expand Down
10 changes: 5 additions & 5 deletions acispy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,27 +67,27 @@ def from_xija(cls, model, components, interp_times=None, masks=None):
def from_load_page(cls, load, components, time_range=None):
components = [comp.lower() for comp in components]
load = find_load(load)
mylog.info("Reading model data from the %s load." % load)
mylog.info(f"Reading model data from the {load} load.")
components = ensure_list(components)
if "fptemp_11" in components:
components.append("earth_solid_angle")
data = {}
for comp in components:
if comp == "earth_solid_angle":
url = "http://cxc.cfa.harvard.edu/acis/FP_thermPredic/"
url += "%s/ofls%s/earth_solid_angles.dat" % (load[:-1].upper(), load[-1].lower())
url += f"{load[:-1].upper()}/ofls{load[-1].lower()}/earth_solid_angles.dat"
table_key = comp
else:
c = comp_map[comp].upper()
table_key = "fptemp" if comp == "fptemp_11" else comp
url = "http://cxc.cfa.harvard.edu/acis/%s_thermPredic/" % c
url += "%s/ofls%s/temperatures.dat" % (load[:-1].upper(), load[-1].lower())
url = f"http://cxc.cfa.harvard.edu/acis/{c}_thermPredic/"
url += f"{load[:-1].upper()}/ofls{load[-1].lower()}/temperatures.dat"
u = requests.get(url)
if not u.ok:
if table_key == "earth_solid_angle":
mylog.warning("Could not find the earth solid angles file. Skipping.")
else:
mylog.warning("Could not find the model page for '%s'. Skipping." % comp)
mylog.warning(f"Could not find the model page for '{comp}'. Skipping.")
continue
table = ascii.read(u.text)
if time_range is None:
Expand Down
2 changes: 1 addition & 1 deletion acispy/msids.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def from_mit_file(cls, filename, tbegin=None, tend=None):
bmask = masks["bilevels"]
bilevels = np.char.strip(table["bilevels"], "b")[bmask]
for i in range(8):
key = "1stat%dst" % (7-i)
key = f"1stat{7-i}dst"
table[key] = np.array(["BAD"]*bmask.size)
table[key][bmask] = np.array([b[i] for b in bilevels])
times[key] = times["bilevels"]
Expand Down
46 changes: 32 additions & 14 deletions acispy/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import numpy as np
from astropy.units import Quantity

datefmt = "%Y-%m-%d %H:%M:%S.%f"

drawstyles = {"simpos": "steps",
"pitch": "steps",
"ccd_count": "steps",
Expand Down Expand Up @@ -195,6 +197,8 @@ def get_figure(plot, fig, subplot, figsize):
if hasattr(plot, "ax2"):
ax2 = plot.ax2
lines2 = plot.lines2
for axis in ['top', 'bottom', 'left', 'right']:
ax.spines[axis].set_linewidth(2)
return fig, ax, lines, ax2, lines2


Expand Down Expand Up @@ -235,10 +239,12 @@ def __init__(self, dates, values, fmt='-b', lw=2, fontsize=18, ls='-',
x = np.asarray(dates)
y = np.asarray(values)
if color is None:
color = "C{}".format(len(lines))
color = f"C{len(lines)}"
ticklocs, fig, ax = plot_cxctime(x, y, fmt=fmt, fig=fig, ax=ax,
lw=lw, ls=ls, color=color, **kwargs)
super(CustomDatePlot, self).__init__(fig, ax, lines, ax2, lines2)
self.ax.tick_params(which="major", width=2, length=6)
self.ax.tick_params(which="minor", width=2, length=3)
self.lines.append(ax.lines[-1])
self.ax.set_xlabel("Date", fontdict={"size": fontsize})
fontProperties = font_manager.FontProperties(size=fontsize)
Expand Down Expand Up @@ -287,6 +293,8 @@ def plot_right(self, dates, values, fmt='-b', lw=2, fontsize=18,
y = np.asarray(values)
plot_cxctime(x, y, fmt=fmt, fig=self.fig,
ax=self.ax2, ls=ls, color=color, lw=lw, **kwargs)
self.ax2.tick_params(which="major", width=2, length=6)
self.ax2.tick_params(which="minor", width=2, length=3)
fontProperties = font_manager.FontProperties(size=fontsize)
for label in self.ax2.get_xticklabels():
label.set_fontproperties(fontProperties)
Expand All @@ -303,9 +311,9 @@ def set_xlim(self, xmin, xmax):
>>> p.set_xlim("2016:050:12:45:47.324", "2016:056:22:32:01.123")
"""
if not isinstance(xmin, datetime):
xmin = datetime.strptime(DateTime(xmin).iso, "%Y-%m-%d %H:%M:%S.%f")
xmin = datetime.strptime(DateTime(xmin).iso, datefmt)
if not isinstance(xmax, datetime):
xmax = datetime.strptime(DateTime(xmax).iso, "%Y-%m-%d %H:%M:%S.%f")
xmax = datetime.strptime(DateTime(xmax).iso, datefmt)
self.ax.set_xlim(xmin, xmax)

def add_hline(self, y, lw=2, ls='-', color='green',
Expand Down Expand Up @@ -333,11 +341,11 @@ def add_hline(self, y, lw=2, ls='-', color='green',
if xmin is None:
xmin = 0
else:
xmin = datetime.strptime(DateTime(xmin).iso, "%Y-%m-%d %H:%M:%S.%f")
xmin = datetime.strptime(DateTime(xmin).iso, datefmt)
if xmax is None:
xmax = 1
else:
xmax = datetime.strptime(DateTime(xmax).iso, "%Y-%m-%d %H:%M:%S.%f")
xmax = datetime.strptime(DateTime(xmax).iso, datefmt)
self.ax.axhline(y=y, lw=lw, ls=ls, color=color, xmin=xmin,
xmax=xmax, label='_nolegend_', **kwargs)

Expand All @@ -363,7 +371,7 @@ def add_vline(self, time, lw=2, ls='solid', color='green', **kwargs):
--------
>>> p.add_vline("2016:101:12:36:10.102", lw=3, ls='dashed', color='red')
"""
time = datetime.strptime(DateTime(time).iso, "%Y-%m-%d %H:%M:%S.%f")
time = datetime.strptime(DateTime(time).iso, datefmt)
self.ax.axvline(x=time, lw=lw, ls=ls, color=color, **kwargs,
label='_nolegend_')

Expand Down Expand Up @@ -393,7 +401,7 @@ def add_text(self, time, y, text, fontsize=18, color='black',
>>> dp.add_text("2016:101:12:36:10.102", 35., "Something happened here!",
... fontsize=15, color='magenta')
"""
time = datetime.strptime(DateTime(time).iso, "%Y-%m-%d %H:%M:%S.%f")
time = datetime.strptime(DateTime(time).iso, datefmt)
self.ax.text(time, y, text, fontsize=fontsize, color=color,
rotation=rotation, **kwargs)

Expand Down Expand Up @@ -701,17 +709,20 @@ def __init__(self, ds, fields, field2=None, fmt='-b', lw=2, ls='-',
self.ax.set_ylim(ymin, ymax)
if self.num_fields > 0:
units = ds.fields[self.fields[0]].units
ulabel = unit_labels.get(units, units)
if self.num_fields > 1:
if units == '':
ylabel = ''
else:
ylabel = '%s (%s)' % (units_map[units], unit_labels.get(units, units))
ylabel = f"{units_map[units]} ({ulabel})"
self.set_ylabel(ylabel)
else:
ylabel = ds.fields[self.fields[0]].display_name
if units != '':
ylabel += ' (%s)' % unit_labels.get(units, units)
ylabel += f" ({ulabel})"
self.set_ylabel(ylabel)
self.ax.tick_params(which="major", width=2, length=6)
self.ax.tick_params(which="minor", width=2, length=3)
if field2 is not None:
field2 = ds._determine_field(field2)
self.field2 = field2
Expand All @@ -720,6 +731,8 @@ def __init__(self, ds, fields, field2=None, fmt='-b', lw=2, ls='-',
self.ax2 = self.ax.twinx()
self.ax2.set_zorder(-10)
self.ax.patch.set_visible(False)
self.ax2.tick_params(which="major", width=2, length=6)
self.ax2.tick_params(which="minor", width=2, length=3)
drawstyle = drawstyles.get(fd2, None)
state_codes = ds.state_codes.get(field2, None)
if not plot_bad:
Expand Down Expand Up @@ -759,8 +772,9 @@ def __init__(self, ds, fields, field2=None, fmt='-b', lw=2, ls='-',
self.ax2.set_ylim(ymin2, ymax2)
units2 = ds.fields[field2].units
ylabel2 = ds.fields[field2].display_name
ulabel2 = unit_labels.get(units2, units2)
if units2 != '':
ylabel2 += ' (%s)' % unit_labels.get(units2, units2)
ylabel2 += f" ({ulabel2})"
self.set_ylabel2(ylabel2)
else:
self.field2 = None
Expand Down Expand Up @@ -1087,8 +1101,9 @@ def _annotate_plot(self, fontsize, density, cumulative):
label.set_fontproperties(fontProperties)
for label in self.ax.get_yticklabels():
label.set_fontproperties(fontProperties)
ulabel = unit_labels.get(self.unit, self.unit)
if self.unit != '':
self.xlabel += ' (%s)' % unit_labels.get(self.unit, self.unit)
self.xlabel += f" ({ulabel})"
self.ax.set_xlabel(self.xlabel, fontsize=18)
if density:
self.ylabel = "Fraction of Time"
Expand Down Expand Up @@ -1128,10 +1143,12 @@ def _annotate_plot(self, fontsize):
label.set_fontproperties(fontProperties)
for label in self.ax.get_yticklabels():
label.set_fontproperties(fontProperties)
uxlabel = unit_labels.get(self.xunit, self.xunit)
uylabel = unit_labels.get(self.yunit, self.yunit)
if self.xunit != '':
self.xlabel += ' (%s)' % unit_labels.get(self.xunit, self.xunit)
self.xlabel += f" ({uxlabel})"
if self.yunit != '':
self.ylabel += ' (%s)' % unit_labels.get(self.yunit, self.yunit)
self.ylabel += f" ({uylabel})"
self.set_xlabel(self.xlabel)
self.set_ylabel(self.ylabel)
return fontProperties
Expand Down Expand Up @@ -1302,7 +1319,8 @@ def __init__(self, ds, x_field, y_field, c_field=None,
clabel = self.ds.fields[c_field].display_name
cunit = self.ds.fields[c_field].units
if cunit != '':
clabel += ' (%s)' % unit_labels.get(cunit, cunit)
uclabel = unit_labels.get(cunit, cunit)
clabel += f" ({uclabel})"
divider = make_axes_locatable(self.ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
cb = plt.colorbar(self.pp, cax=cax)
Expand Down
36 changes: 14 additions & 22 deletions acispy/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ def __init__(self, table):
new_table = OrderedDict()
if isinstance(table, np.ndarray):
state_names = list(table.dtype.names)
if "date" in state_names:
table = rf.append_fields(
table, ['time'],
[date2secs(table["date"])],
usemask=False
)
if "tstart" not in state_names:
table = rf.append_fields(
table, ["tstart", "tstop"],
Expand All @@ -46,7 +52,10 @@ def __init__(self, table):
table["tstart"] = date2secs(table["datestart"])
table["tstop"] = date2secs(table["datestop"])
state_names += ["tstart", "tstop"]
times = Quantity([table["tstart"], table["tstop"]], "s")
if "tstart" in state_names:
times = Quantity([table["tstart"], table["tstop"]], "s")
else:
times = Quantity(table["time"], "s")
for k in state_names:
v = np.asarray(table[k])
if k == "trans_keys" and v.dtype.char == "O":
Expand Down Expand Up @@ -78,21 +87,11 @@ def from_kadi_states(cls, tstart, tstop, state_keys=None):
merge_identical=True).as_array()
return cls(t)

@classmethod
def from_database(cls, tstart, tstop, state_keys=None, server=None):
from Chandra.cmd_states import fetch_states
tstart = get_time(tstart)
tstop = get_time(tstop)
if state_keys is not None:
state_keys = ensure_list(state_keys)
t = fetch_states(tstart, tstop, vals=state_keys, server=server)
return cls(t)

@classmethod
def from_load_page(cls, load, comp="DPA"):
load = find_load(load)
url = "http://cxc.cfa.harvard.edu/acis/%s_thermPredic/" % comp
url += "%s/ofls%s/states.dat" % (load[:-1].upper(), load[-1].lower())
url = f"http://cxc.cfa.harvard.edu/acis/{comp}_thermPredic/"
url += f"{load[:-1].upper()}/ofls{load[-1].lower()}/states.dat"
u = requests.get(url)
t = ascii.read(u.text)
table = dict((k, t[k].data) for k in t.keys())
Expand All @@ -111,16 +110,9 @@ def from_load_file(cls, states_file):
return cls(table)

@classmethod
def from_commands(cls, tstart, tstop, cmds=None, state_keys=None):
from kadi import commands
def from_commands(cls, cmds, state_keys=None):
from kadi.commands import states
tstart = get_time(tstart)
tstop = get_time(tstop)
if cmds is None:
cmds = commands.get_cmds(tstart, tstop)
continuity = states.get_continuity(tstart, state_keys)
t = states.get_states(cmds=cmds, continuity=continuity,
state_keys=state_keys,
t = states.get_states(cmds=cmds, state_keys=state_keys,
merge_identical=True).as_array()
return cls(t)

Expand Down
Loading

0 comments on commit 5163aad

Please sign in to comment.