diff --git a/libs/models/tsom_plus_som.py b/libs/models/tsom_plus_som.py index 10e0609..bd3276a 100644 --- a/libs/models/tsom_plus_som.py +++ b/libs/models/tsom_plus_som.py @@ -5,13 +5,13 @@ class TSOMPlusSOM: - def __init__(self, member_features, index_members_of_group, params_tsom, params_som): + def __init__(self, member_features, group_features, params_tsom, params_som): self.params_tsom = params_tsom self.params_som = params_som self.params_tsom['X'] = member_features - self.index_members_of_group = index_members_of_group # グループ数の確認 - self.group_num = len(self.index_members_of_group) + self.group_features = group_features # グループ数の確認 + self.group_num = len(self.group_features) def fit(self, tsom_epoch_num, kernel_width, som_epoch_num): self._fit_1st_TSOM(tsom_epoch_num) @@ -25,14 +25,23 @@ def _fit_1st_TSOM(self, tsom_epoch_num): def _fit_KDE(self, kernel_width): # 学習した後の潜在空間からKDEで確率分布を作る prob_data = np.zeros((self.group_num, self.tsom.K1)) # group数*ノード数 # グループごとにKDEを適用 - for i in range(self.group_num): - Dist = dist.cdist(self.tsom.Zeta1, self.tsom.Z1[self.index_members_of_group[i], :], - 'sqeuclidean') # KxNの距離行列を計算 - H = np.exp(-Dist / (2 * kernel_width * kernel_width)) # KxNの学習量行列を計算 - prob = np.sum(H, axis=1) - prob_sum = np.sum(prob) - prob = prob / prob_sum - prob_data[i, :] = prob + if isinstance(self.group_features, np.ndarray) and self.group_features.ndim == 2: + # group_featuresがbag of membersで与えられた時の処理 + distance = dist.cdist(self.tsom.Zeta1, self.tsom.Z1, 'sqeuclidean') # K1 x num_members + H = np.exp(-0.5 * distance / (kernel_width * kernel_width)) # KxN + prob_data = self.group_features @ H.T # num_group x K1 + prob_data = prob_data / prob_data.sum(axis=1)[:, None] + else: + # group_featuresがlist of listsもしくはlist of arraysで与えられた時の処理 + for i in range(self.group_num): + Dist = dist.cdist(self.tsom.Zeta1, + self.tsom.Z1[self.group_features[i], :], + 'sqeuclidean') # KxNの距離行列を計算 + H = np.exp(-Dist / (2 * kernel_width * kernel_width)) # KxNの学習量行列を計算 + prob = np.sum(H, axis=1) + prob_sum = np.sum(prob) + prob = prob / prob_sum + prob_data[i, :] = prob self.params_som['X'] = prob_data self.params_som['metric'] = "KLdivergence" diff --git a/tests/plus_TSOM/test_plusTSOM.py b/tests/plus_TSOM/test_plusTSOM.py new file mode 100644 index 0000000..d18c58a --- /dev/null +++ b/tests/plus_TSOM/test_plusTSOM.py @@ -0,0 +1,140 @@ +import unittest + +import numpy as np + +from libs.models.tsom_plus_som import TSOMPlusSOM +from tests.plus_TSOM.plus_TSOM_watanabe import TSOMPlusSOMWatanabe + + +class TestTSOMPlusSOM(unittest.TestCase): + def create_artficial_data(self,n_samples,n_features,n_groups,n_samples_per_group): + x = np.random.normal(0.0,1.0,(n_samples,n_features)) + if isinstance(n_samples_per_group,int): + n_samples_per_group = np.ones(n_groups,int) * n_samples_per_group + index_members_of_group = [] + for n_samples_in_the_group in n_samples_per_group: + index_members_of_group.append(np.random.randint(0,n_samples,n_samples_in_the_group)) + return x, index_members_of_group + + def test_plusTSOM_ishida_vs_test_plusTSOM_watanabe(self): + seed = 100 + np.random.seed(seed) + n_samples = 1000 + n_groups = 10 # group数 + n_features = 3 # 各メンバーの特徴数 + n_samples_per_group = np.random.randint(1,30,n_groups) # 各グループにメンバーに何人いるのか + member_features,index_members_of_group = self.create_artficial_data(n_samples, + n_features, + n_groups, + n_samples_per_group) + # 1stTSOMの初期値 + Z1 = np.random.rand(n_samples, 2) * 2.0 - 1.0 + Z2 = np.random.rand(n_features, 2) * 2.0 - 1.0 + init_TSOM = [Z1, Z2] + init_SOM = np.random.rand(n_groups, 2) * 2.0 - 1.0 + + + params_tsom = {'latent_dim': [2, 2], + 'resolution': [10, 10], + 'SIGMA_MAX': [1.0, 1.0], + 'SIGMA_MIN': [0.1, 0.1], + 'TAU': [50, 50], + 'init': init_TSOM} + params_som = {'latent_dim': 2, + 'resolution': 10, + 'sigma_max': 2.0, + 'sigma_min': 0.5, + 'tau': 50, + 'init': init_SOM} + tsom_epoch_num = 50 + som_epoch_num = 50 + kernel_width = 0.3 + + htsom_ishida = TSOMPlusSOM(member_features=member_features, + group_features=index_members_of_group, + params_tsom=params_tsom, + params_som=params_som) + htsom_watanabe = TSOMPlusSOMWatanabe(member_features=member_features, + index_members_of_group=index_members_of_group, + params_tsom=params_tsom, + params_som=params_som) + + htsom_ishida.fit(tsom_epoch_num=tsom_epoch_num, + kernel_width=kernel_width, + som_epoch_num=som_epoch_num) + htsom_watanabe.fit(tsom_epoch_num=tsom_epoch_num, + kernel_width=kernel_width, + som_epoch_num=som_epoch_num) + + np.testing.assert_allclose(htsom_ishida.tsom.history['y'], htsom_watanabe.tsom.history['y']) + np.testing.assert_allclose(htsom_ishida.tsom.history['z1'], htsom_watanabe.tsom.history['z1']) + np.testing.assert_allclose(htsom_ishida.tsom.history['z2'], htsom_watanabe.tsom.history['z2']) + np.testing.assert_allclose(htsom_ishida.params_som['X'], htsom_watanabe.params_som['X']) + np.testing.assert_allclose(htsom_ishida.som.history['y'], htsom_watanabe.som.history['y']) + np.testing.assert_allclose(htsom_ishida.som.history['z'], htsom_watanabe.som.history['z']) + + def _transform_list_to_bag(self,list_of_indexes,num_members): + bag_of_members = np.empty((0,num_members)) + for indexes in list_of_indexes: + one_hot_vectors = np.eye(num_members)[indexes] + one_bag = one_hot_vectors.sum(axis=0)[None,:] + bag_of_members=np.append(bag_of_members,one_bag,axis=0) + return bag_of_members + def test_matching_index_member_as_list_or_bag(self): + seed = 100 + np.random.seed(seed) + n_members = 100 + n_groups = 10 # group数 + n_features = 3 # 各メンバーの特徴数 + n_samples_per_group = np.random.randint(1,50,n_groups) # 各グループにメンバーに何人いるのか + member_features,index_members_of_group = self.create_artficial_data(n_members, + n_features, + n_groups, + n_samples_per_group) + bag_of_members = self._transform_list_to_bag(index_members_of_group, n_members) + + Z1 = np.random.rand(n_members, 2) * 2.0 - 1.0 + Z2 = np.random.rand(n_features, 2) * 2.0 - 1.0 + init_TSOM = [Z1, Z2] + init_SOM = np.random.rand(n_groups, 2) * 2.0 - 1.0 + params_tsom = {'latent_dim': [2, 2], + 'resolution': [10, 10], + 'SIGMA_MAX': [1.0, 1.0], + 'SIGMA_MIN': [0.1, 0.1], + 'TAU': [50, 50], + 'init': init_TSOM} + params_som = {'latent_dim': 2, + 'resolution': 10, + 'sigma_max': 2.0, + 'sigma_min': 0.5, + 'tau': 50, + 'init': init_SOM} + tsom_epoch_num = 50 + som_epoch_num = 50 + kernel_width = 0.3 + + tsom_plus_som_input_list = TSOMPlusSOM(member_features=member_features, + group_features=index_members_of_group, + params_tsom=params_tsom, + params_som=params_som) + tsom_plus_som_input_bag = TSOMPlusSOM(member_features=member_features, + group_features=bag_of_members, + params_tsom=params_tsom, + params_som=params_som) + + tsom_plus_som_input_list.fit(tsom_epoch_num=tsom_epoch_num, + kernel_width=kernel_width, + som_epoch_num=som_epoch_num) + tsom_plus_som_input_bag.fit(tsom_epoch_num=tsom_epoch_num, + kernel_width=kernel_width, + som_epoch_num=som_epoch_num) + + np.testing.assert_allclose(tsom_plus_som_input_list.tsom.history['y'], tsom_plus_som_input_bag.tsom.history['y']) + np.testing.assert_allclose(tsom_plus_som_input_list.tsom.history['z1'], tsom_plus_som_input_bag.tsom.history['z1']) + np.testing.assert_allclose(tsom_plus_som_input_list.tsom.history['z2'], tsom_plus_som_input_bag.tsom.history['z2']) + np.testing.assert_allclose(tsom_plus_som_input_list.params_som['X'], tsom_plus_som_input_bag.params_som['X']) + np.testing.assert_allclose(tsom_plus_som_input_list.som.history['y'], tsom_plus_som_input_bag.som.history['y']) + np.testing.assert_allclose(tsom_plus_som_input_list.som.history['z'], tsom_plus_som_input_bag.som.history['z']) + +if __name__ == "__main__": + unittest.main() diff --git a/tutorials/TSOM_plus_SOM/fit_TSOM_plus_SOM.py b/tutorials/TSOM_plus_SOM/fit_TSOM_plus_SOM.py index 5bfbe65..20a01ba 100644 --- a/tutorials/TSOM_plus_SOM/fit_TSOM_plus_SOM.py +++ b/tutorials/TSOM_plus_SOM/fit_TSOM_plus_SOM.py @@ -2,7 +2,7 @@ import numpy as np from libs.datasets.artificial.kura_tsom import load_kura_tsom import matplotlib.pyplot as plt -from libs.models.TSOMPlusSOM import TSOMPlusSOM +from libs.models.tsom_plus_som import TSOMPlusSOM from mpl_toolkits.mplot3d import Axes3D from libs.visualization.som.Grad_norm import Grad_Norm @@ -27,11 +27,13 @@ input_data[int(2 * i), :, :] = group1 input_data[int(2 * i + 1), :, :] = group2 -# fig = plt.figure() -# ax = fig.add_subplot(1, 1, 1, projection="3d") -# for i in range(group_num): -# ax.scatter(input_data[i, :, 0], input_data[i, :, 1], input_data[i, :, 2]) -# plt.show() +#観測データの描画 +fig = plt.figure() +ax = fig.add_subplot(1, 1, 1, projection="3d") +for i in range(group_num): + ax.scatter(input_data[i, :, 0], input_data[i, :, 1], input_data[i, :, 2],label="group"+str(i+1)) + plt.legend(bbox_to_anchor=(1, 1), loc='upper right', borderaxespad=0, fontsize=10) +plt.show() input_data = input_data.reshape(-1, 3) # グループラベルの作成 @@ -55,7 +57,7 @@ # +型階層TSOMのclass読み込み # group_label以降の変数ははlatent_dim,resolution,sigma_max,sigma_min,tauでSOMとTSOMでまとめている tsom_plus_som = TSOMPlusSOM(member_features=input_data, - index_members_of_group=group_label, + group_features=group_label, params_tsom=params_tsom, params_som=params_som)