From 154ca177b4c675c277e1be5ca7e25535b8b675f4 Mon Sep 17 00:00:00 2001 From: Yzy <2154597198@qq.com> Date: Mon, 19 Feb 2024 15:14:51 +0800 Subject: [PATCH] Update model.py and utils.py --- src/gnnwr/models.py | 11 +++++------ src/gnnwr/utils.py | 45 +++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 48 insertions(+), 8 deletions(-) diff --git a/src/gnnwr/models.py b/src/gnnwr/models.py index 928f335..b4774b5 100644 --- a/src/gnnwr/models.py +++ b/src/gnnwr/models.py @@ -7,7 +7,7 @@ import torch.optim as optim import warnings from sklearn.metrics import r2_score -from torch.utils.tensorboard import SummaryWriter # 用于保存训练过程 +from torch.utils.tensorboard import SummaryWriter # to save the process of the model from tqdm import trange from collections import OrderedDict import logging @@ -15,7 +15,6 @@ from .utils import OLS, DIAGNOSIS -# 23.6.8_TODO: 寻找合适的优化器 考虑SGD+学习率调整 输出权重 class GNNWR: r""" GNNWR(Geographically neural network coefficiented regression) is a model to address spatial non-stationarity in various domains with complex geographical processes, @@ -526,7 +525,7 @@ def predict(self, dataset): def predict_coef(self, dataset): """ - predict the spatial coefficient of the dataset + predict the spatial coefficient of the independent variable Parameters ---------- @@ -557,7 +556,7 @@ def predict_coef(self, dataset): def load_model(self, path, use_dict=False, map_location=None): """ - load the model + load the model from the path Parameters ---------- @@ -634,7 +633,7 @@ def add_graph(self): def result(self, path=None, use_dict=False, map_location=None): """ - print the result of the model, including the model structure, optimizer,the result of test dataset + print the result of the model, including the model name, regression fomula and the result of test dataset Parameters ---------- @@ -700,7 +699,7 @@ def result(self, path=None, use_dict=False, map_location=None): def reg_result(self, filename=None, model_path=None, use_dict=False, only_return=False, map_location=None): """ - save the regression result of the model, including the coefficient of each argument, the bias, the predicted result + save the regression result of the model, including the coefficient of each argument, the bias and the predicted result Parameters ---------- diff --git a/src/gnnwr/utils.py b/src/gnnwr/utils.py index 6173792..a8cb190 100644 --- a/src/gnnwr/utils.py +++ b/src/gnnwr/utils.py @@ -32,7 +32,9 @@ def __init__(self, dataset, xName: list, yName: list): class DIAGNOSIS: """ `DIAGNOSIS` is the class to calculate the diagnoses of the result of GNNWR/GTNNWR. - + These diagnoses include F1-test, F2-test, F3-test, AIC, AICc, R2, Adjust_R2, RMSE (Root Mean Square Error). + The explanation of these diagnoses can be found in the paper + `Geographically neural network weighted regression for the accurate estimation of spatial non-stationarity `. :param weight: output of the neural network :param x_data: the independent variables :param y_data: the dependent variables @@ -70,7 +72,6 @@ def __init__(self, weight, x_data, y_data, y_pred): self.f3_dict_2 = None def hat(self): """ - :return: hat matrix """ return self.__hat @@ -174,6 +175,15 @@ def RMSE(self): class Visualize: + """ + `Visualize` is the class to visualize the data and the result of GNNWR/GTNNWR. + It based on the `folium` package and use GaoDe map as the background. And it can display the dataset, the coefficients heatmap, and the dot map, + which helps to understand the spatial distribution of the data and the result of GNNWR/GTNNWR better. + + :param data: the input data + :param lon_lat_columns: the columns of longitude and latitude + :param zoom: the zoom of the map + """ def __init__(self, data, lon_lat_columns=None, zoom=4): self.__raw_data = data self.__tiles = 'https://wprd01.is.autonavi.com/appmaptile?x={x}&y={y}&z={z}&lang=en&size=1&scl=1&style=7' @@ -205,6 +215,15 @@ def __init__(self, data, lon_lat_columns=None, zoom=4): raise ValueError("given data is not instance of GNNWR") def display_dataset(self, name="all", y_column=None, colors=None, steps=20, vmin=None, vmax=None): + """ + Display the dataset on the map, including the train, valid, test dataset. + + :param name: the name of the dataset, including 'all', 'train', 'valid', 'test' + :param y_column: the column of the displayed variable + :param colors: the list of colors, if not given, the default color is used + :param steps: the steps of the colors + + """ if colors is None: colors = [] if y_column is None: @@ -241,6 +260,15 @@ def display_dataset(self, name="all", y_column=None, colors=None, steps=20, vmin return res def coefs_heatmap(self, data_column, colors=None, steps=20, vmin=None, vmax=None): + """ + Display the heatmap of the coefficients of the result of GNNWR/GTNNWR. + + :param data_column: the column of the displayed variable + :param colors: the list of colors, if not given, the default color is used + :param steps: the steps of the colors + :param vmin: the minimum value of the displayed variable, if not given, the minimum value of the variable is used + :param vmax: the maximum value of the displayed variable, if not given, the maximum value of the variable is used + """ if colors is None: colors = [] res = folium.Map(location=[self.__center_lat, self.__center_lon], zoom_start=self.__zoom, tiles=self.__tiles, @@ -261,6 +289,19 @@ def coefs_heatmap(self, data_column, colors=None, steps=20, vmin=None, vmax=None return res def dot_map(self, data, lon_column, lat_column, y_column, zoom=4, colors=None, steps=20, vmin=None, vmax=None): + """ + Display the data by dot map, the color of the dot represents the value of the variable. + + :param data: the input data + :param lon_column: the column of longitude + :param lat_column: the column of latitude + :param y_column: the column of the displayed variable + :param zoom: the zoom of the map + :param colors: the list of colors, if not given, the default color is used + :param steps: the steps of the colors + :param vmin: the minimum value of the displayed variable, if not given, the minimum value of the variable is used + :param vmax: the maximum value of the displayed variable, if not given, the maximum value of the variable is used + """ if colors is None: colors = [] center_lon = data[lon_column].mean()