Skip to content

Commit

Permalink
Changed MPI->multivariate for computing each SPI to match the bivaria…
Browse files Browse the repository at this point in the history
…te signature. Also fixed the pdist signature (it was adjacency)
  • Loading branch information
olivercliff committed Dec 16, 2021
1 parent 5e58cee commit ba78e67
Show file tree
Hide file tree
Showing 9 changed files with 44 additions and 27 deletions.
21 changes: 19 additions & 2 deletions demos/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,31 @@

import seaborn as sns

calc = Calculator(dataset=load_dataset('forex'))
# Load one of our stored datasets
dataset = load_dataset('forex')

# visualize the dataset as a heat map (also called a temporal raster plot or carpet plot)
plt.pcolormesh(dataset.to_numpy(squeeze=True),cmap='coolwarm',vmin=-2,vmax=2)
plt.show()

# Instantiate the calculator (inputting the dataset)
calc = Calculator(dataset=dataset)

# Compute all SPIs (this may take a while)
calc.compute()

# Now we can access all of the matrices through the calc.table property.
# This property is an N x (S*N) pandas DataFrame, where N is the number of processes in the dataset and S is the number of SPIs
print(calc.table)

# We can use this to compute the correlation between all of the methods on this dataset...
corrmat = calc.table.stack().corr(method='spearman').abs()

# ...and plot this correlation matrix
sns.set(font_scale=0.5)
g = sns.clustermap( corrmat.fillna(0), mask=corrmat.isna(),
center=0.0,
cmap='RdYlBu_r',
xticklabels=1, yticklabels=1 )
plt.setp(g.ax_heatmap.xaxis.get_majorticklabels(), rotation=45, ha='right')
plt.show()
plt.show()
6 changes: 3 additions & 3 deletions pyspi/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def bivariate(self,data,i=None,j=None):
raise NotImplementedError("Method not yet overloaded.")

@parse_multivariate
def mpi(self,data):
def multivariate(self,data):
""" Compute the dependency statistics for the entire multivariate dataset
"""
A = np.empty((data.n_processes,data.n_processes))
Expand Down Expand Up @@ -135,8 +135,8 @@ def ispositive(self):
return False

@parse_multivariate
def mpi(self,data):
A = super(undirected,self).mpi(data)
def multivariate(self,data):
A = super(undirected,self).multivariate(data)

li = np.tril_indices(data.n_processes,-1)
A[li] = A.T[li]
Expand Down
12 changes: 6 additions & 6 deletions pyspi/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

""" TODO: use the multivariate class (formerly named MPI) for each entry in the table
"""
class MPI():
class multivariate():
def __init__(self, procnames, S=None):
if S is None:
S = np.full((len(procnames),len(procnames)),np.nan)
Expand Down Expand Up @@ -158,14 +158,14 @@ def compute(self,replication=None):
replication = 0

pbar = tqdm(self.spis.keys())
for m in pbar:
pbar.set_description(f'Processing [{self._name}: {m}]')
for spi in pbar:
pbar.set_description(f'Processing [{self._name}: {spi}]')
start_time = time.time()
try:
self._table[m] = self._spis[m].mpi(self.dataset)
self._table[spi] = self._spis[spi].multivariate(self.dataset)
except Exception as err:
warnings.warn(f'Caught {type(err)} for SPI "{self._statnames[m]}": {err}')
self._table[m] = np.NaN
warnings.warn(f'Caught {type(err)} for SPI "{spi}": {err}')
self._table[spi] = np.NaN
pbar.close()

def rmmin(self):
Expand Down
2 changes: 1 addition & 1 deletion pyspi/statistics/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def _from_cache(self,data):
return mycov

@parse_multivariate
def mpi(self,data):
def multivariate(self,data):
mycov = self._from_cache(data)
matrix = getattr(mycov,self._kind+'_')
np.fill_diagonal(matrix,np.nan)
Expand Down
2 changes: 1 addition & 1 deletion pyspi/statistics/causal.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def _from_cache(self,data):
return ccmf

@parse_multivariate
def mpi(self,data):
def multivariate(self,data):
ccmf = self._from_cache(data)

if self._statistic == 'mean':
Expand Down
2 changes: 1 addition & 1 deletion pyspi/statistics/distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __init__(self,metric='euclidean',**kwargs):
self.name += f'_{metric}'

@parse_multivariate
def adjacency(self,data):
def multivariate(self,data):
return pairwise_distances(data.to_numpy(squeeze=True),metric=self._metric)

""" TODO: include optional kernels in each method
Expand Down
8 changes: 4 additions & 4 deletions pyspi/statistics/spectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def _get_cache(self,data):
return res, freq

@parse_multivariate
def mpi(self, data):
def multivariate(self, data):
adj_freq, freq = self._get_cache(data)
freq_id = np.where((freq >= self._fmin) * (freq <= self._fmax))[0]

Expand Down Expand Up @@ -299,7 +299,7 @@ def _get_statistic(self,C):
# self.name = self.name + paramstr

# @parse_multivariate
# def mpi(self,data):
# def multivariate(self,data):
# # This should be changed to conditioning on all, rather than averaging all conditionals
# if not hasattr(data,'pcoh'):
# z = np.squeeze(data.to_numpy())
Expand Down Expand Up @@ -372,7 +372,7 @@ def _get_cache(self,data):
return F, freq

@parse_multivariate
def mpi(self,data):
def multivariate(self,data):
try:
F, freq = self._get_cache(data)
freq_id = np.where((freq >= self._fmin) * (freq <= self._fmax))[0]
Expand All @@ -399,7 +399,7 @@ def __init__(self,orth=False,log=False,absolute=False):
self.name += '_abs'

@parse_multivariate
def mpi(self, data):
def multivariate(self, data):
z = np.moveaxis(data.to_numpy(),2,0)
adj = np.squeeze(mnec.envelope_correlation(z,orthogonalize=self._orth,log=self._log,absolute=self._absolute))
np.fill_diagonal(adj,np.nan)
Expand Down
4 changes: 2 additions & 2 deletions pyspi/statistics/wavelet.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def _get_cache(self,data):
return conn, freq_id

@parse_multivariate
def mpi(self, data):
def multivariate(self, data):
adj_freq, freq_id = self._get_cache(data)
try:
adj = self._statfn(adj_freq[...,freq_id,:], axis=(2,3))
Expand Down Expand Up @@ -189,7 +189,7 @@ def _get_cache(self,data):
return psi, freq_id

@parse_multivariate
def mpi(self, data):
def multivariate(self, data):
adj_freq, freq_id = self._get_cache(data)
adj = self._statfn(np.real(adj_freq[...,freq_id]), axis=(2,3))

Expand Down
14 changes: 7 additions & 7 deletions test/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_yaml():
assert calc.n_statistics == len(calc._statistics), (
'Property not equal to number of statistics')

def test_mpi():
def test_multivariate():
# Load in all base statistics from the YAML file

data = get_data()
Expand Down Expand Up @@ -81,14 +81,14 @@ def test_mpi():
if any([m.name == e for e in excuse_stochastic]):
continue

m.mpi(get_more_data())
m.multivariate(get_more_data())

scratch_adj = m.mpi(data.to_numpy())
adj = m.mpi(data)
scratch_adj = m.multivariate(data.to_numpy())
adj = m.multivariate(data)
assert np.allclose(adj,scratch_adj,rtol=1e-1,atol=1e-2,equal_nan=True), (
f'{m.name} ({m.humanname}) mpi output changed between cached and strach computations: {adj} != {scratch_adj}')

recomp_adj = m.mpi(data)
recomp_adj = m.multivariate(data)
assert np.allclose(adj,recomp_adj,rtol=1e-1,atol=1e-2,equal_nan=True), (
f'{m.name} ({m.humanname}) mpi output changed when recomputing.')

Expand All @@ -111,7 +111,7 @@ def test_mpi():
assert t_s == pytest.approx(new_t_s,rel=1e-1,abs=1e-2), (
f'{m.name} ({m.humanname}) Bivariate output from cache mismatch results from scratch for computation ({j},{i}): {t_s} != {new_t_s}')
except NotImplementedError:
a = m.mpi(p[[i,j]])
a = m.multivariate(p[[i,j]])
s_t, t_s = a[0,1], a[1,0]

if not math.isfinite(s_t):
Expand Down Expand Up @@ -206,7 +206,7 @@ def test_group():
test_yaml()
test_load()
test_group()
test_mpi()
test_multivariate()

# This was a bit tricky to implement so just ensuring it passes a test from the creator's website
test_ccm()

0 comments on commit ba78e67

Please sign in to comment.