|
9 | 9 |
|
10 | 10 | from AeroViz.plot.utils import *
|
11 | 11 |
|
12 |
| -__all__ = ['corr_matrix'] |
| 12 | +__all__ = ['corr_matrix', 'cross_corr_matrix'] |
13 | 13 |
|
14 | 14 |
|
15 | 15 | @set_figure
|
16 | 16 | def corr_matrix(data: pd.DataFrame,
|
17 | 17 | cmap: str = "RdBu",
|
18 | 18 | ax: Axes | None = None,
|
| 19 | + items_order: list = None, # 新增參數用於指定順序 |
19 | 20 | **kwargs
|
20 | 21 | ) -> tuple[Figure, Axes]:
|
21 | 22 | fig, ax = plt.subplots(**kwargs.get('fig_kws', {})) if ax is None else (ax.get_figure(), ax)
|
22 | 23 |
|
23 | 24 | _corr = data.corr()
|
| 25 | + breakpoint() |
24 | 26 | corr = pd.melt(_corr.reset_index(), id_vars='index')
|
25 | 27 | corr.columns = ['x', 'y', 'value']
|
26 | 28 |
|
@@ -94,8 +96,172 @@ def value_to_color(val):
|
94 | 96 | label='p < 0.05'
|
95 | 97 | )
|
96 | 98 |
|
97 |
| - ax.legend(handles=[point2], labels=['p < 0.05'], bbox_to_anchor=(0.05, 1, 0.1, 0.05)) |
| 99 | + ax.legend(handles=[point2], labels=['p < 0.05'], bbox_to_anchor=(0.02, 1, 0.05, 0.05)) |
98 | 100 |
|
99 | 101 | plt.show()
|
100 | 102 |
|
101 | 103 | return fig, ax
|
| 104 | + |
| 105 | + |
| 106 | +@set_figure(figsize=(6, 6)) |
| 107 | +def cross_corr_matrix(data1: pd.DataFrame, |
| 108 | + data2: pd.DataFrame, |
| 109 | + cmap: str = "RdBu", |
| 110 | + ax: Axes | None = None, |
| 111 | + items_order: list = None, # 新增參數用於指定順序 |
| 112 | + **kwargs |
| 113 | + ) -> tuple[Figure, Axes]: |
| 114 | + """ |
| 115 | + Create a correlation matrix between two different DataFrames. |
| 116 | +
|
| 117 | + Parameters: |
| 118 | + ----------- |
| 119 | + data1 : pd.DataFrame |
| 120 | + First DataFrame |
| 121 | + data2 : pd.DataFrame |
| 122 | + Second DataFrame |
| 123 | + cmap : str, optional |
| 124 | + Color map for the correlation matrix |
| 125 | + ax : Axes, optional |
| 126 | + Matplotlib axes to plot on |
| 127 | + items_order : list, optional |
| 128 | + List specifying the order of items to display |
| 129 | + **kwargs : dict |
| 130 | + Additional keyword arguments |
| 131 | + """ |
| 132 | + if ax is None: |
| 133 | + fig_kws = kwargs.get('fig_kws', {}) |
| 134 | + default_figsize = fig_kws.get('figsize', (8, 8)) |
| 135 | + fig = plt.figure(figsize=default_figsize) |
| 136 | + ax = fig.add_axes([0.1, 0.1, 0.8, 0.8]) |
| 137 | + else: |
| 138 | + fig = ax.get_figure() |
| 139 | + |
| 140 | + # 如果沒有指定順序,使用原始列名順序 |
| 141 | + if items_order is None: |
| 142 | + x_labels = list(data1.columns) |
| 143 | + y_labels = list(data2.columns) |
| 144 | + else: |
| 145 | + # 使用指定順序,但只包含實際存在於數據中的列 |
| 146 | + x_labels = [item for item in items_order if item in data1.columns] |
| 147 | + y_labels = [item for item in items_order if item in data2.columns] |
| 148 | + |
| 149 | + # Calculate cross-correlation between the two DataFrames |
| 150 | + correlations = [] |
| 151 | + p_values_list = [] |
| 152 | + |
| 153 | + for col1 in x_labels: # 使用指定順序的列名 |
| 154 | + for col2 in y_labels: |
| 155 | + try: |
| 156 | + mask = ~(np.isnan(data1[col1]) | np.isnan(data2[col2])) |
| 157 | + if mask.sum() > 2: |
| 158 | + corr, p_val = pearsonr(data1[col1][mask], data2[col2][mask]) |
| 159 | + else: |
| 160 | + corr, p_val = np.nan, np.nan |
| 161 | + except Exception as e: |
| 162 | + print(f"Error calculating correlation for {col1} and {col2}: {str(e)}") |
| 163 | + corr, p_val = np.nan, np.nan |
| 164 | + |
| 165 | + correlations.append({ |
| 166 | + 'x': col1, |
| 167 | + 'y': col2, |
| 168 | + 'value': corr |
| 169 | + }) |
| 170 | + if p_val is not None and p_val < 0.05: |
| 171 | + p_values_list.append({ |
| 172 | + 'x': col1, |
| 173 | + 'y': col2, |
| 174 | + 'value': p_val |
| 175 | + }) |
| 176 | + |
| 177 | + corr = pd.DataFrame(correlations) |
| 178 | + p_values = pd.DataFrame(p_values_list) |
| 179 | + |
| 180 | + # Create mapping using the specified order |
| 181 | + x_to_num = {label: i for i, label in enumerate(x_labels)} |
| 182 | + y_to_num = {label: i for i, label in enumerate(y_labels)} |
| 183 | + |
| 184 | + # 調整標籤顯示 |
| 185 | + ax.set_xticks([x_to_num[v] for v in x_labels]) |
| 186 | + ax.set_xticklabels(x_labels, rotation=45, ha='right') |
| 187 | + ax.set_yticks([y_to_num[v] for v in y_labels]) |
| 188 | + ax.set_yticklabels(y_labels) |
| 189 | + |
| 190 | + ax.grid(False, 'major') |
| 191 | + ax.grid(True, 'minor') |
| 192 | + ax.set_xticks([t + 0.5 for t in ax.get_xticks()], minor=True) |
| 193 | + ax.set_yticks([t + 0.5 for t in ax.get_yticks()], minor=True) |
| 194 | + |
| 195 | + ax.set_xlim([-0.5, max([v for v in x_to_num.values()]) + 0.5]) |
| 196 | + ax.set_ylim([-0.5, max([v for v in y_to_num.values()]) + 0.5]) |
| 197 | + |
| 198 | + # Color mapping |
| 199 | + n_colors = 256 |
| 200 | + palette = sns.color_palette(cmap, n_colors=n_colors) |
| 201 | + color_min, color_max = [-1, 1] |
| 202 | + |
| 203 | + def value_to_color(val): |
| 204 | + if pd.isna(val): |
| 205 | + return (1, 1, 1) |
| 206 | + val_position = float((val - color_min)) / (color_max - color_min) |
| 207 | + val_position = np.clip(val_position, 0, 1) |
| 208 | + ind = int(val_position * (n_colors - 1)) |
| 209 | + return palette[ind] |
| 210 | + |
| 211 | + # Plot correlation squares |
| 212 | + x_coords = corr['x'].map(x_to_num) |
| 213 | + y_coords = corr['y'].map(y_to_num) |
| 214 | + sizes = corr['value'].abs().fillna(0) * 70 |
| 215 | + colors = [value_to_color(val) for val in corr['value']] |
| 216 | + |
| 217 | + point = ax.scatter( |
| 218 | + x=x_coords, |
| 219 | + y=y_coords, |
| 220 | + s=sizes, |
| 221 | + c=colors, |
| 222 | + marker='s', |
| 223 | + label='$R^{2}$' |
| 224 | + ) |
| 225 | + |
| 226 | + # 調整顏色軸的位置和大小 |
| 227 | + cax = fig.add_axes([0.91, 0.1, 0.02, 0.8]) |
| 228 | + axes_image = plt.cm.ScalarMappable(cmap=colormaps[cmap]) |
| 229 | + cbar = plt.colorbar(mappable=axes_image, cax=cax, label=r'$R^{2}$') |
| 230 | + cbar.set_ticks([0, 0.25, 0.5, 0.75, 1]) |
| 231 | + cbar.set_ticklabels(np.linspace(-1, 1, 5)) |
| 232 | + |
| 233 | + # Plot significance markers |
| 234 | + if not p_values.empty: |
| 235 | + point2 = ax.scatter( |
| 236 | + x=p_values['x'].map(x_to_num), |
| 237 | + y=p_values['y'].map(y_to_num), |
| 238 | + s=10, |
| 239 | + marker='*', |
| 240 | + color='k', |
| 241 | + label='p < 0.05' |
| 242 | + ) |
| 243 | + ax.legend(handles=[point2], labels=['p < 0.05'], |
| 244 | + bbox_to_anchor=(0.005, 1.04), loc='upper left') |
| 245 | + |
| 246 | + # Add labels |
| 247 | + ax.set_xlabel('NZ', labelpad=10) |
| 248 | + ax.set_ylabel('FS', labelpad=10) |
| 249 | + |
| 250 | + plt.show() |
| 251 | + |
| 252 | + return fig, ax |
| 253 | + |
| 254 | + |
| 255 | +if __name__ == '__main__': |
| 256 | + import pandas as pd |
| 257 | + from pandas import to_numeric |
| 258 | + |
| 259 | + df_NZ = pd.read_csv('/Users/chanchihyu/Desktop/NZ_minion_202402-202411.csv', parse_dates=True, index_col=0) |
| 260 | + df_FS = pd.read_csv('/Users/chanchihyu/Desktop/FS_minion_202402-202411.csv', parse_dates=True, index_col=0) |
| 261 | + |
| 262 | + items = ['Ext', 'Sca', 'Abs', 'PNC', 'PSC', 'PVC', 'SO2', 'NO', 'NOx', 'NO2', 'CO', 'O3', 'THC', 'NMHC', 'CH4', |
| 263 | + 'PM10', 'PM2.5', 'WS', 'AT', 'RH', |
| 264 | + 'OC', 'EC', 'Na+', 'NH4+', 'NO3-', 'SO42-', 'Al', 'Si', 'Ca', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Cu', 'Zn'] |
| 265 | + df_NZ = df_NZ.apply(to_numeric, errors='coerce') |
| 266 | + |
| 267 | + corr_matrix(df_NZ[items], items_order=items) |
0 commit comments