Skip to content

Commit

Permalink
Merge branch 'main' of github.com:zjuwss/gnnwr
Browse files Browse the repository at this point in the history
  • Loading branch information
Y-nuclear committed Feb 19, 2024
2 parents 154ca17 + 6326ba8 commit a10be37
Showing 1 changed file with 2 additions and 31 deletions.
33 changes: 2 additions & 31 deletions src/gnnwr/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,15 +126,13 @@ def scale(self, scale_fn=None, scale_params=None):
self.x_scale_info = {"min": x_scale_params.data_min_, "max": x_scale_params.data_max_}
self.x_data = x_scale_params.transform(pd.DataFrame(self.x_data, columns=self.x))
self.y_scale_info = {"min": y_scale_params.data_min_, "max": y_scale_params.data_max_}
self.y_data = y_scale_params.transform(pd.DataFrame(self.y_data, columns=self.y))
elif scale_fn == "standard_scale":
self.scale_fn = "standard_scale"
x_scale_params = scale_params[0]
y_scale_params = scale_params[1]
self.x_scale_info = {"mean": x_scale_params.mean_, "var": x_scale_params.var_}
self.x_data = x_scale_params.transform(pd.DataFrame(self.x_data, columns=self.x))
self.y_scale_info = {"mean": y_scale_params.mean_, "var": y_scale_params.var_}
self.y_data = y_scale_params.transform(pd.DataFrame(self.y_data, columns=self.y))

self.getScaledDataframe()

Expand All @@ -160,16 +158,12 @@ def scale2(self, scale_fn, scale_params):
self.scale_fn = "minmax_scale"
x_scale_params = scale_params[0]
y_scale_params = scale_params[1]
# self.x_data = self.x_data * (x_scale_params["max"] - x_scale_params["min"]) + x_scale_params["min"]
self.x_data = (self.x_data - x_scale_params["min"]) / (x_scale_params["max"] - x_scale_params["min"])
self.y_data = (self.y_data - y_scale_params["min"]) / (y_scale_params["max"] - y_scale_params["min"])
elif scale_fn == "standard_scale":
self.scale_fn = "standard_scale"
x_scale_params = scale_params[0]
y_scale_params = scale_params[1]
# self.x_data = self.x_data * np.sqrt(x_scale_params["var"]) + x_scale_params["mean"]
self.x_data = (self.x_data - x_scale_params['mean']) / np.sqrt(x_scale_params["var"])
self.y_data = (self.y_data - y_scale_params['mean']) / np.sqrt(y_scale_params["var"])

self.getScaledDataframe()

Expand All @@ -184,7 +178,7 @@ def getScaledDataframe(self):
scaledData = np.concatenate((self.x_data, self.y_data), axis=1)
self.scaledDataframe = pd.DataFrame(scaledData, columns=columns)

def rescale(self, x, y):
def rescale(self, x):
"""
rescale the data with the scale function and scale parameters
Expand All @@ -204,35 +198,12 @@ def rescale(self, x, y):
"""
if self.scale_fn == "minmax_scale":
x = np.multiply(x, self.x_scale_info["max"] - self.x_scale_info["min"]) + self.x_scale_info["min"]
y = np.multiply(y, self.y_scale_info["max"] - self.y_scale_info["min"]) + self.y_scale_info["min"]
elif self.scale_fn == "standard_scale":
x = np.multiply(x, np.sqrt(self.x_scale_info["var"])) + self.x_scale_info["mean"]
y = np.multiply(y, np.sqrt(self.y_scale_info["var"])) + self.y_scale_info["mean"]
else:
raise ValueError("invalid process_fn")
return x, y

def rescale_y(self, y):
"""
rescale the dependent variable data
Parameters
----------
y: numpy.ndarray
dependent variable data
return x

Returns
-------
y: numpy.ndarray
rescaled dependent variable data
"""
if self.scale_fn == "minmax_scale":
y = np.multiply(y, self.y_scale_info["max"] - self.y_scale_info["min"]) + self.y_scale_info["min"]
elif self.scale_fn == "standard_scale":
y = np.multiply(y, np.sqrt(self.y_scale_info["var"])) + self.y_scale_info["mean"]
else:
raise ValueError("invalid process_fn")
return y

def save(self, dirname):
"""
Expand Down

0 comments on commit a10be37

Please sign in to comment.