Skip to content

Commit b335abf

Browse files
committed
feat: enhance rate calculation to use minimum rates across all determinant keys
- Merge base_rate and calculate_period_rates functions into a single calculate_rates function - Modify rate calculation to use minimum rates from all determinant keys instead of first key only - Extract report generation into a separate function for better code organization - Retain all original logging functionality - Add input validation and proper error handling
1 parent 3141083 commit b335abf

File tree

10 files changed

+433
-112
lines changed

10 files changed

+433
-112
lines changed

AeroViz/dataProcess/Optical/_absorption.py

+2
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@ def _absCoe(df, instru, specified_band: list):
55

66
band_AE33 = np.array([370, 470, 520, 590, 660, 880, 950])
77
band_BC1054 = np.array([370, 430, 470, 525, 565, 590, 660, 700, 880, 950])
8+
band_MA350 = np.array([375, 470, 528, 625, 880])
89

910
MAE_AE33 = np.array([18.47, 14.54, 13.14, 11.58, 10.35, 7.77, 7.19]) * 1e-3
1011
MAE_BC1054 = np.array([18.48, 15.90, 14.55, 13.02, 12.10, 11.59, 10.36, 9.77, 7.77, 7.20]) * 1e-3
12+
MAE_MA350 = np.array([24.069, 19.070, 17.028, 14.091, 10.120]) * 1e-3
1113

1214
band = band_AE33 if instru == 'AE33' else band_BC1054
1315
MAE = MAE_AE33 if instru == 'AE33' else MAE_BC1054

AeroViz/plot/templates/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from .ammonium_rich import ammonium_rich
22
from .contour import *
3-
from .corr_matrix import corr_matrix
3+
from .corr_matrix import corr_matrix, cross_corr_matrix
44
from .diurnal_pattern import *
55
from .koschmieder import *
66
from .metal_heatmap import metal_heatmaps, process_data_with_two_df

AeroViz/plot/templates/corr_matrix.py

+168-2
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,20 @@
99

1010
from AeroViz.plot.utils import *
1111

12-
__all__ = ['corr_matrix']
12+
__all__ = ['corr_matrix', 'cross_corr_matrix']
1313

1414

1515
@set_figure
1616
def corr_matrix(data: pd.DataFrame,
1717
cmap: str = "RdBu",
1818
ax: Axes | None = None,
19+
items_order: list = None, # 新增參數用於指定順序
1920
**kwargs
2021
) -> tuple[Figure, Axes]:
2122
fig, ax = plt.subplots(**kwargs.get('fig_kws', {})) if ax is None else (ax.get_figure(), ax)
2223

2324
_corr = data.corr()
25+
breakpoint()
2426
corr = pd.melt(_corr.reset_index(), id_vars='index')
2527
corr.columns = ['x', 'y', 'value']
2628

@@ -94,8 +96,172 @@ def value_to_color(val):
9496
label='p < 0.05'
9597
)
9698

97-
ax.legend(handles=[point2], labels=['p < 0.05'], bbox_to_anchor=(0.05, 1, 0.1, 0.05))
99+
ax.legend(handles=[point2], labels=['p < 0.05'], bbox_to_anchor=(0.02, 1, 0.05, 0.05))
98100

99101
plt.show()
100102

101103
return fig, ax
104+
105+
106+
@set_figure(figsize=(6, 6))
107+
def cross_corr_matrix(data1: pd.DataFrame,
108+
data2: pd.DataFrame,
109+
cmap: str = "RdBu",
110+
ax: Axes | None = None,
111+
items_order: list = None, # 新增參數用於指定順序
112+
**kwargs
113+
) -> tuple[Figure, Axes]:
114+
"""
115+
Create a correlation matrix between two different DataFrames.
116+
117+
Parameters:
118+
-----------
119+
data1 : pd.DataFrame
120+
First DataFrame
121+
data2 : pd.DataFrame
122+
Second DataFrame
123+
cmap : str, optional
124+
Color map for the correlation matrix
125+
ax : Axes, optional
126+
Matplotlib axes to plot on
127+
items_order : list, optional
128+
List specifying the order of items to display
129+
**kwargs : dict
130+
Additional keyword arguments
131+
"""
132+
if ax is None:
133+
fig_kws = kwargs.get('fig_kws', {})
134+
default_figsize = fig_kws.get('figsize', (8, 8))
135+
fig = plt.figure(figsize=default_figsize)
136+
ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
137+
else:
138+
fig = ax.get_figure()
139+
140+
# 如果沒有指定順序,使用原始列名順序
141+
if items_order is None:
142+
x_labels = list(data1.columns)
143+
y_labels = list(data2.columns)
144+
else:
145+
# 使用指定順序,但只包含實際存在於數據中的列
146+
x_labels = [item for item in items_order if item in data1.columns]
147+
y_labels = [item for item in items_order if item in data2.columns]
148+
149+
# Calculate cross-correlation between the two DataFrames
150+
correlations = []
151+
p_values_list = []
152+
153+
for col1 in x_labels: # 使用指定順序的列名
154+
for col2 in y_labels:
155+
try:
156+
mask = ~(np.isnan(data1[col1]) | np.isnan(data2[col2]))
157+
if mask.sum() > 2:
158+
corr, p_val = pearsonr(data1[col1][mask], data2[col2][mask])
159+
else:
160+
corr, p_val = np.nan, np.nan
161+
except Exception as e:
162+
print(f"Error calculating correlation for {col1} and {col2}: {str(e)}")
163+
corr, p_val = np.nan, np.nan
164+
165+
correlations.append({
166+
'x': col1,
167+
'y': col2,
168+
'value': corr
169+
})
170+
if p_val is not None and p_val < 0.05:
171+
p_values_list.append({
172+
'x': col1,
173+
'y': col2,
174+
'value': p_val
175+
})
176+
177+
corr = pd.DataFrame(correlations)
178+
p_values = pd.DataFrame(p_values_list)
179+
180+
# Create mapping using the specified order
181+
x_to_num = {label: i for i, label in enumerate(x_labels)}
182+
y_to_num = {label: i for i, label in enumerate(y_labels)}
183+
184+
# 調整標籤顯示
185+
ax.set_xticks([x_to_num[v] for v in x_labels])
186+
ax.set_xticklabels(x_labels, rotation=45, ha='right')
187+
ax.set_yticks([y_to_num[v] for v in y_labels])
188+
ax.set_yticklabels(y_labels)
189+
190+
ax.grid(False, 'major')
191+
ax.grid(True, 'minor')
192+
ax.set_xticks([t + 0.5 for t in ax.get_xticks()], minor=True)
193+
ax.set_yticks([t + 0.5 for t in ax.get_yticks()], minor=True)
194+
195+
ax.set_xlim([-0.5, max([v for v in x_to_num.values()]) + 0.5])
196+
ax.set_ylim([-0.5, max([v for v in y_to_num.values()]) + 0.5])
197+
198+
# Color mapping
199+
n_colors = 256
200+
palette = sns.color_palette(cmap, n_colors=n_colors)
201+
color_min, color_max = [-1, 1]
202+
203+
def value_to_color(val):
204+
if pd.isna(val):
205+
return (1, 1, 1)
206+
val_position = float((val - color_min)) / (color_max - color_min)
207+
val_position = np.clip(val_position, 0, 1)
208+
ind = int(val_position * (n_colors - 1))
209+
return palette[ind]
210+
211+
# Plot correlation squares
212+
x_coords = corr['x'].map(x_to_num)
213+
y_coords = corr['y'].map(y_to_num)
214+
sizes = corr['value'].abs().fillna(0) * 70
215+
colors = [value_to_color(val) for val in corr['value']]
216+
217+
point = ax.scatter(
218+
x=x_coords,
219+
y=y_coords,
220+
s=sizes,
221+
c=colors,
222+
marker='s',
223+
label='$R^{2}$'
224+
)
225+
226+
# 調整顏色軸的位置和大小
227+
cax = fig.add_axes([0.91, 0.1, 0.02, 0.8])
228+
axes_image = plt.cm.ScalarMappable(cmap=colormaps[cmap])
229+
cbar = plt.colorbar(mappable=axes_image, cax=cax, label=r'$R^{2}$')
230+
cbar.set_ticks([0, 0.25, 0.5, 0.75, 1])
231+
cbar.set_ticklabels(np.linspace(-1, 1, 5))
232+
233+
# Plot significance markers
234+
if not p_values.empty:
235+
point2 = ax.scatter(
236+
x=p_values['x'].map(x_to_num),
237+
y=p_values['y'].map(y_to_num),
238+
s=10,
239+
marker='*',
240+
color='k',
241+
label='p < 0.05'
242+
)
243+
ax.legend(handles=[point2], labels=['p < 0.05'],
244+
bbox_to_anchor=(0.005, 1.04), loc='upper left')
245+
246+
# Add labels
247+
ax.set_xlabel('NZ', labelpad=10)
248+
ax.set_ylabel('FS', labelpad=10)
249+
250+
plt.show()
251+
252+
return fig, ax
253+
254+
255+
if __name__ == '__main__':
256+
import pandas as pd
257+
from pandas import to_numeric
258+
259+
df_NZ = pd.read_csv('/Users/chanchihyu/Desktop/NZ_minion_202402-202411.csv', parse_dates=True, index_col=0)
260+
df_FS = pd.read_csv('/Users/chanchihyu/Desktop/FS_minion_202402-202411.csv', parse_dates=True, index_col=0)
261+
262+
items = ['Ext', 'Sca', 'Abs', 'PNC', 'PSC', 'PVC', 'SO2', 'NO', 'NOx', 'NO2', 'CO', 'O3', 'THC', 'NMHC', 'CH4',
263+
'PM10', 'PM2.5', 'WS', 'AT', 'RH',
264+
'OC', 'EC', 'Na+', 'NH4+', 'NO3-', 'SO42-', 'Al', 'Si', 'Ca', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Cu', 'Zn']
265+
df_NZ = df_NZ.apply(to_numeric, errors='coerce')
266+
267+
corr_matrix(df_NZ[items], items_order=items)

AeroViz/plot/templates/metal_heatmap.py

+15-6
Original file line numberDiff line numberDiff line change
@@ -117,11 +117,11 @@ def normalize_and_split(df, df2):
117117
return df, df2
118118

119119

120-
@set_figure(figsize=(12, 3), fs=6)
120+
@set_figure(figsize=(6, 3), fs=8, fw='normal')
121121
def metal_heatmaps(df,
122122
process=True,
123-
major_freq='24h',
124-
minor_freq='12h',
123+
major_freq='10d',
124+
minor_freq='1d',
125125
cmap='jet',
126126
ax: Axes | None = None,
127127
**kwargs
@@ -131,7 +131,7 @@ def metal_heatmaps(df,
131131

132132
fig, ax = plt.subplots(**kwargs.get('fig_kws', {})) if ax is None else (ax.get_figure(), ax)
133133

134-
sns.heatmap(df.T, vmin=None, vmax=3, cmap=cmap, xticklabels=False, yticklabels=True,
134+
sns.heatmap(df.T, vmin=None, vmax=3, cmap=cmap, xticklabels=True, yticklabels=True,
135135
cbar_kws={'label': 'Z score', "pad": 0.02})
136136
ax.grid(color='gray', linestyle='-', linewidth=0.3)
137137

@@ -142,14 +142,23 @@ def metal_heatmaps(df,
142142
# Set the major and minor ticks
143143
ax.set_xticks(ticks=[df.index.get_loc(t) for t in major_tick])
144144
ax.set_xticks(ticks=[df.index.get_loc(t) for t in minor_tick], minor=True)
145-
ax.set_xticklabels(major_tick.strftime('%F'))
145+
ax.set_xticklabels(major_tick.strftime('%F'), rotation=0)
146146
ax.tick_params(axis='y', rotation=0)
147147

148148
ax.set(xlabel='',
149-
ylabel='',
149+
ylabel='Trace metals',
150150
title=kwargs.get('title', None)
151151
)
152152

153+
if kwargs.get('savefig'):
154+
plt.savefig(kwargs.get('savefig'), dpi=600)
155+
153156
plt.show()
154157

155158
return fig, ax
159+
160+
161+
if __name__ == '__main__':
162+
fig, ax = plt.subplots(1, 1, figsize=(6, 6))
163+
plt.title('text', font={'weight': 'bold'})
164+
plt.show()

0 commit comments

Comments
 (0)