Skip to content

Commit

Permalink
Update model.py and utils.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Y-nuclear committed Feb 19, 2024
1 parent bc31cdc commit 154ca17
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 8 deletions.
11 changes: 5 additions & 6 deletions src/gnnwr/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,14 @@
import torch.optim as optim
import warnings
from sklearn.metrics import r2_score
from torch.utils.tensorboard import SummaryWriter # 用于保存训练过程
from torch.utils.tensorboard import SummaryWriter # to save the process of the model
from tqdm import trange
from collections import OrderedDict
import logging
from .networks import SWNN, STPNN, STNN_SPNN
from .utils import OLS, DIAGNOSIS


# 23.6.8_TODO: 寻找合适的优化器 考虑SGD+学习率调整 输出权重
class GNNWR:
r"""
GNNWR(Geographically neural network coefficiented regression) is a model to address spatial non-stationarity in various domains with complex geographical processes,
Expand Down Expand Up @@ -526,7 +525,7 @@ def predict(self, dataset):

def predict_coef(self, dataset):
"""
predict the spatial coefficient of the dataset
predict the spatial coefficient of the independent variable
Parameters
----------
Expand Down Expand Up @@ -557,7 +556,7 @@ def predict_coef(self, dataset):

def load_model(self, path, use_dict=False, map_location=None):
"""
load the model
load the model from the path
Parameters
----------
Expand Down Expand Up @@ -634,7 +633,7 @@ def add_graph(self):

def result(self, path=None, use_dict=False, map_location=None):
"""
print the result of the model, including the model structure, optimizer,the result of test dataset
print the result of the model, including the model name, regression fomula and the result of test dataset
Parameters
----------
Expand Down Expand Up @@ -700,7 +699,7 @@ def result(self, path=None, use_dict=False, map_location=None):

def reg_result(self, filename=None, model_path=None, use_dict=False, only_return=False, map_location=None):
"""
save the regression result of the model, including the coefficient of each argument, the bias, the predicted result
save the regression result of the model, including the coefficient of each argument, the bias and the predicted result
Parameters
----------
Expand Down
45 changes: 43 additions & 2 deletions src/gnnwr/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ def __init__(self, dataset, xName: list, yName: list):
class DIAGNOSIS:
"""
`DIAGNOSIS` is the class to calculate the diagnoses of the result of GNNWR/GTNNWR.
These diagnoses include F1-test, F2-test, F3-test, AIC, AICc, R2, Adjust_R2, RMSE (Root Mean Square Error).
The explanation of these diagnoses can be found in the paper
`Geographically neural network weighted regression for the accurate estimation of spatial non-stationarity <https://doi.org/10.1080/13658816.2019.1707834>`.
:param weight: output of the neural network
:param x_data: the independent variables
:param y_data: the dependent variables
Expand Down Expand Up @@ -70,7 +72,6 @@ def __init__(self, weight, x_data, y_data, y_pred):
self.f3_dict_2 = None
def hat(self):
"""
:return: hat matrix
"""
return self.__hat
Expand Down Expand Up @@ -174,6 +175,15 @@ def RMSE(self):


class Visualize:
"""
`Visualize` is the class to visualize the data and the result of GNNWR/GTNNWR.
It based on the `folium` package and use GaoDe map as the background. And it can display the dataset, the coefficients heatmap, and the dot map,
which helps to understand the spatial distribution of the data and the result of GNNWR/GTNNWR better.
:param data: the input data
:param lon_lat_columns: the columns of longitude and latitude
:param zoom: the zoom of the map
"""
def __init__(self, data, lon_lat_columns=None, zoom=4):
self.__raw_data = data
self.__tiles = 'https://wprd01.is.autonavi.com/appmaptile?x={x}&y={y}&z={z}&lang=en&size=1&scl=1&style=7'
Expand Down Expand Up @@ -205,6 +215,15 @@ def __init__(self, data, lon_lat_columns=None, zoom=4):
raise ValueError("given data is not instance of GNNWR")

def display_dataset(self, name="all", y_column=None, colors=None, steps=20, vmin=None, vmax=None):
"""
Display the dataset on the map, including the train, valid, test dataset.
:param name: the name of the dataset, including 'all', 'train', 'valid', 'test'
:param y_column: the column of the displayed variable
:param colors: the list of colors, if not given, the default color is used
:param steps: the steps of the colors
"""
if colors is None:
colors = []
if y_column is None:
Expand Down Expand Up @@ -241,6 +260,15 @@ def display_dataset(self, name="all", y_column=None, colors=None, steps=20, vmin
return res

def coefs_heatmap(self, data_column, colors=None, steps=20, vmin=None, vmax=None):
"""
Display the heatmap of the coefficients of the result of GNNWR/GTNNWR.
:param data_column: the column of the displayed variable
:param colors: the list of colors, if not given, the default color is used
:param steps: the steps of the colors
:param vmin: the minimum value of the displayed variable, if not given, the minimum value of the variable is used
:param vmax: the maximum value of the displayed variable, if not given, the maximum value of the variable is used
"""
if colors is None:
colors = []
res = folium.Map(location=[self.__center_lat, self.__center_lon], zoom_start=self.__zoom, tiles=self.__tiles,
Expand All @@ -261,6 +289,19 @@ def coefs_heatmap(self, data_column, colors=None, steps=20, vmin=None, vmax=None
return res

def dot_map(self, data, lon_column, lat_column, y_column, zoom=4, colors=None, steps=20, vmin=None, vmax=None):
"""
Display the data by dot map, the color of the dot represents the value of the variable.
:param data: the input data
:param lon_column: the column of longitude
:param lat_column: the column of latitude
:param y_column: the column of the displayed variable
:param zoom: the zoom of the map
:param colors: the list of colors, if not given, the default color is used
:param steps: the steps of the colors
:param vmin: the minimum value of the displayed variable, if not given, the minimum value of the variable is used
:param vmax: the maximum value of the displayed variable, if not given, the maximum value of the variable is used
"""
if colors is None:
colors = []
center_lon = data[lon_column].mean()
Expand Down

0 comments on commit 154ca17

Please sign in to comment.