From 756ed51ae9553df328f5025b8eae53d22aeb93e8 Mon Sep 17 00:00:00 2001 From: Bowen XU <58252898+bowen-xu@users.noreply.github.com> Date: Tue, 14 Nov 2023 23:24:58 -0500 Subject: [PATCH] modify --- cpp/src/Group.cpp | 9 +++------ utils/Profiler.py | 18 ------------------ utils/draw_group.py | 2 -- 3 files changed, 3 insertions(+), 26 deletions(-) diff --git a/cpp/src/Group.cpp b/cpp/src/Group.cpp index bde49d4..6531a42 100644 --- a/cpp/src/Group.cpp +++ b/cpp/src/Group.cpp @@ -74,11 +74,7 @@ std::set &Group::activate(size_t i_column) /* 更新node的状态 */ for (auto &node : nodes) { - if (node->ts_update == this->ts - 1) - node->roll_state(); - else if (node->ts_update < this->ts - 1) - node->reset_state(); - node->ts_update = ts; + node->update(this->ts); } /* 激活结点。首先判断是否存在预激活的结点,如果存在,则激活那些结点;否则激活所有结点。 */ bool anticipated = std::any_of( @@ -272,7 +268,7 @@ void Group::learn(size_t i_column) if (nodes_selected.size() == 0) { - /* 如果没有任何激活的结点,则从现有的激活结点中,选取一个value最大的结点进行学习。 */ + /* 如果没有任何预测后激活的结点,则从现有的激活结点中,选取一个value最大的结点进行学习。 */ double value_max = 0.0; for (pNode &post_node : nodes) for (auto &[ante_node, ante_link] : post_node->ante_links->links) @@ -332,6 +328,7 @@ void Group::learn(size_t i_column) } } } + for (auto &post_node : this->buffer2) { // auto n = post_node->ante_links->links.size(); diff --git a/utils/Profiler.py b/utils/Profiler.py index 24bef2a..68d282a 100644 --- a/utils/Profiler.py +++ b/utils/Profiler.py @@ -7,24 +7,6 @@ def __init__(self, window=100) -> None: self.cnt_total = 0 self.cnt_correct = 0 self.pred_correct = np.full(window, False) - - # def observe1(self, pred, compare): - # '''''' - # if self.cnt_total < len(self.pred_correct): - # self.cnt_total += 1 - # if len(pred) > 0: - # pred = sorted(pred, key=lambda x: -x[1].e) - # result = chr(compare+65) == Interpreter.get(Column, pred[0][0]._column._id) - # else: - # result = False - # if result: - # self.cnt_correct += 1 - # if self.pred_correct[0]: - # self.cnt_correct -= 1 - # self.pred_correct[:-1] = self.pred_correct[1:] - # self.pred_correct[-1] = result - # acc = self.cnt_correct / self.cnt_total - # return acc def observe(self, pred: set, compare): '''''' diff --git a/utils/draw_group.py b/utils/draw_group.py index 227b025..ec4c826 100644 --- a/utils/draw_group.py +++ b/utils/draw_group.py @@ -1,8 +1,6 @@ try: - from ..narsese import Term, Statement, Truth from ..SequentialGroup import Group, Column, Node, Link, Bundle except ValueError as e: - from narsese import Term, Statement, Truth from SequentialGroup import Group, Column, Node, Link, Bundle from typing import List import networkx as nx