diff --git a/websocietysimulator/tools/cache_interaction_tool.py b/websocietysimulator/tools/cache_interaction_tool.py index 475aec9..e526b63 100644 --- a/websocietysimulator/tools/cache_interaction_tool.py +++ b/websocietysimulator/tools/cache_interaction_tool.py @@ -16,32 +16,30 @@ def __init__(self, data_dir: str, block_set_dir: Optional[str] = None): block_set_dir: Optional path to the directory containing block set files. """ logger.info(f"Initializing CacheInteractionTool with data directory: {data_dir}") - self.data_dir = data_dir - self.block_set_dir = block_set_dir # Set up LMDB environments for caching - self.env_dir = os.path.join(data_dir, "lmdb_cache") - os.makedirs(self.env_dir, exist_ok=True) + env_dir = os.path.join(data_dir, "lmdb_cache") + os.makedirs(env_dir, exist_ok=True) - self.user_env = lmdb.open(os.path.join(self.env_dir, "users"), map_size=4 * 1024 * 1024 * 1024) - self.item_env = lmdb.open(os.path.join(self.env_dir, "items"), map_size=4 * 1024 * 1024 * 1024) - self.review_env = lmdb.open(os.path.join(self.env_dir, "reviews"), map_size=32 * 1024 * 1024 * 1024) + self.user_env = lmdb.open(os.path.join(env_dir, "users"), map_size=4 * 1024 * 1024 * 1024) + self.item_env = lmdb.open(os.path.join(env_dir, "items"), map_size=4 * 1024 * 1024 * 1024) + self.review_env = lmdb.open(os.path.join(env_dir, "reviews"), map_size=32 * 1024 * 1024 * 1024) # Load block set data if provided - self.block_set_items = [] - self.block_set_pairs = set() - if self.block_set_dir: - logger.info(f"Loading block set data from {self.block_set_dir}") - self.block_set_items = self._load_block_set() - self.block_set_pairs = {(item['user_id'], item['item_id']) for item in self.block_set_items} + block_set_items = [] + block_set_pairs = set() + if block_set_dir: + logger.info(f"Loading block set data from {block_set_dir}") + block_set_items = self._load_block_set(block_set_dir) + block_set_pairs = {(item['user_id'], item['item_id']) for item in block_set_items} - self._initialize_db() + self._initialize_db(data_dir, block_set_pairs) - def _load_block_set(self) -> List[dict]: + def _load_block_set(self, block_set_dir: str) -> List[dict]: """Load all block set files from the block set directory.""" block_set_data = [] - task_dir = os.path.join(self.block_set_dir, 'tasks') - groundtruth_dir = os.path.join(self.block_set_dir, 'groundtruth') + task_dir = os.path.join(block_set_dir, 'tasks') + groundtruth_dir = os.path.join(block_set_dir, 'groundtruth') for filename in os.listdir(task_dir): if filename.startswith('task_') and filename.endswith('.json'): @@ -60,13 +58,13 @@ def _load_block_set(self) -> List[dict]: block_set_data.append({'user_id': task_data['user_id'], 'item_id': item}) return block_set_data - def _initialize_db(self): + def _initialize_db(self, data_dir: str, block_set_pairs: set): """Initialize the LMDB databases with data if they are empty.""" # Initialize users with self.user_env.begin(write=True) as txn: if not txn.stat()['entries']: with txn.cursor() as cursor: - for user in tqdm(self._iter_file('user.json')): + for user in tqdm(self._iter_file(data_dir, 'user.json')): cursor.put( user['user_id'].encode(), json.dumps(user).encode() @@ -76,7 +74,7 @@ def _initialize_db(self): with self.item_env.begin(write=True) as txn: if not txn.stat()['entries']: with txn.cursor() as cursor: - for item in tqdm(self._iter_file('item.json')): + for item in tqdm(self._iter_file(data_dir, 'item.json')): cursor.put( item['item_id'].encode(), json.dumps(item).encode() @@ -86,9 +84,9 @@ def _initialize_db(self): with self.review_env.begin(write=True) as txn: filtered_count = 0 if not txn.stat()['entries']: - for review in tqdm(self._iter_file('review.json')): + for review in tqdm(self._iter_file(data_dir, 'review.json')): # 检查是否在block set中 - if (review['user_id'], review['item_id']) in self.block_set_pairs: + if (review['user_id'], review['item_id']) in block_set_pairs: filtered_count += 1 continue @@ -109,9 +107,9 @@ def _initialize_db(self): txn.put(user_key, json.dumps(user_reviews).encode()) logger.info(f"Filtered out {filtered_count} reviews based on block set") - def _iter_file(self, filename: str) -> Iterator[Dict]: + def _iter_file(self, data_dir: str, filename: str) -> Iterator[Dict]: """Iterate through file line by line.""" - file_path = os.path.join(self.data_dir, filename) + file_path = os.path.join(data_dir, filename) with open(file_path, 'r', encoding='utf-8') as file: for line in file: yield json.loads(line) diff --git a/websocietysimulator/tools/interaction_tool.py b/websocietysimulator/tools/interaction_tool.py index 922a73a..0ec89a1 100644 --- a/websocietysimulator/tools/interaction_tool.py +++ b/websocietysimulator/tools/interaction_tool.py @@ -14,30 +14,28 @@ def __init__(self, data_dir: str, block_set_dir: Optional[str] = None): data_dir: Path to the directory containing Yelp dataset files. """ logger.info(f"Initializing InteractionTool with data directory: {data_dir}") - self.data_dir = data_dir - self.block_set_dir = block_set_dir # Convert DataFrames to dictionaries for O(1) lookup logger.info(f"Loading item data from {os.path.join(data_dir, 'item.json')}") - self.item_data = {item['item_id']: item for item in self._load_data('item.json')} + self.item_data = {item['item_id']: item for item in self._load_data(data_dir, 'item.json')} logger.info(f"Loading user data from {os.path.join(data_dir, 'user.json')}") - self.user_data = {user['user_id']: user for user in self._load_data('user.json')} + self.user_data = {user['user_id']: user for user in self._load_data(data_dir, 'user.json')} # Create review indices logger.info(f"Loading review data from {os.path.join(data_dir, 'review.json')}") - reviews = self._load_data('review.json') + reviews = self._load_data(data_dir, 'review.json') # Load ground truth data if available - self.block_set_items = [] - self.block_set_pairs = set() # 新增:用于存储(user_id, item_id)对 - if self.block_set_dir: - logger.info(f"Loading block set data from {self.block_set_dir}") - self.block_set_items = self._load_block_set() + block_set_items = [] + block_set_pairs = set() # 新增:用于存储(user_id, item_id)对 + if block_set_dir: + logger.info(f"Loading block set data from {block_set_dir}") + block_set_items = self._load_block_set(block_set_dir) # 将block set数据转换为(user_id, item_id)对的集合 - self.block_set_pairs = {(item['user_id'], item['item_id']) for item in self.block_set_items} + block_set_pairs = {(item['user_id'], item['item_id']) for item in block_set_items} # 过滤review数据,移除block set中的评论 filtered_reviews = [] for review in reviews: - if (review['user_id'], review['item_id']) not in self.block_set_pairs: + if (review['user_id'], review['item_id']) not in block_set_pairs: filtered_reviews.append(review) logger.info(f"Filtered out {len(reviews) - len(filtered_reviews)} reviews based on block set") @@ -55,17 +53,17 @@ def __init__(self, data_dir: str, block_set_dir: Optional[str] = None): # Index by user_id self.user_reviews.setdefault(review['user_id'], []).append(review) - def _load_data(self, filename: str) -> List[Dict]: + def _load_data(self, data_dir: str, filename: str) -> List[Dict]: """Load data as a list of dictionaries.""" - file_path = os.path.join(self.data_dir, filename) + file_path = os.path.join(data_dir, filename) with open(file_path, 'r', encoding='utf-8') as file: return [json.loads(line) for line in file] - def _load_block_set(self) -> List[dict]: + def _load_block_set(self, block_set_dir: str) -> List[dict]: """Load all block set files from the block set directory.""" block_set_data = [] - task_dir = os.path.join(self.block_set_dir, 'tasks') - groundtruth_dir = os.path.join(self.block_set_dir, 'groundtruth') + task_dir = os.path.join(block_set_dir, 'tasks') + groundtruth_dir = os.path.join(block_set_dir, 'groundtruth') for filename in os.listdir(task_dir): if filename.startswith('task_') and filename.endswith('.json'):