diff --git a/websocietysimulator/tools/cache_interaction_tool.py b/websocietysimulator/tools/cache_interaction_tool.py index 5ef602b..475aec9 100644 --- a/websocietysimulator/tools/cache_interaction_tool.py +++ b/websocietysimulator/tools/cache_interaction_tool.py @@ -23,34 +23,22 @@ def __init__(self, data_dir: str, block_set_dir: Optional[str] = None): self.env_dir = os.path.join(data_dir, "lmdb_cache") os.makedirs(self.env_dir, exist_ok=True) - self.user_env = lmdb.open(os.path.join(self.env_dir, "users"), map_size=2 * 1024 * 1024 * 1024) - self.item_env = lmdb.open(os.path.join(self.env_dir, "items"), map_size=2 * 1024 * 1024 * 1024) - self.review_env = lmdb.open(os.path.join(self.env_dir, "reviews"), map_size=8 * 1024 * 1024 * 1024) + self.user_env = lmdb.open(os.path.join(self.env_dir, "users"), map_size=4 * 1024 * 1024 * 1024) + self.item_env = lmdb.open(os.path.join(self.env_dir, "items"), map_size=4 * 1024 * 1024 * 1024) + self.review_env = lmdb.open(os.path.join(self.env_dir, "reviews"), map_size=32 * 1024 * 1024 * 1024) # Load block set data if provided self.block_set_items = [] - self.block_set_pairs = set() # Store (user_id, item_id) pairs + self.block_set_pairs = set() if self.block_set_dir: logger.info(f"Loading block set data from {self.block_set_dir}") self.block_set_items = self._load_block_set() self.block_set_pairs = {(item['user_id'], item['item_id']) for item in self.block_set_items} - # Filter reviews based on block set - filtered_reviews = [] - with self.review_env.begin() as txn: - cursor = txn.cursor() - for key, value in cursor: - review = json.loads(value) - if (review['user_id'], review['item_id']) not in self.block_set_pairs: - filtered_reviews.append(review) - - logger.info(f"Filtered out {txn.stat()['entries'] - len(filtered_reviews)} reviews based on block set") - self.review_env = filtered_reviews - self._initialize_db() - def _load_block_set(self) -> List[Dict[str, str]]: - """Load block set files and return a list of user-item pairs.""" + def _load_block_set(self) -> List[dict]: + """Load all block set files from the block set directory.""" block_set_data = [] task_dir = os.path.join(self.block_set_dir, 'tasks') groundtruth_dir = os.path.join(self.block_set_dir, 'groundtruth') @@ -96,29 +84,30 @@ def _initialize_db(self): # Initialize reviews and their indices with self.review_env.begin(write=True) as txn: + filtered_count = 0 if not txn.stat()['entries']: for review in tqdm(self._iter_file('review.json')): + # 检查是否在block set中 + if (review['user_id'], review['item_id']) in self.block_set_pairs: + filtered_count += 1 + continue + # Store the review - txn.put( - review['review_id'].encode(), - json.dumps(review).encode() - ) - - # Update item reviews index (store only review_ids) - item_review_ids = json.loads(txn.get(f"item_{review['item_id']}".encode()) or '[]') - item_review_ids.append(review['review_id']) - txn.put( - f"item_{review['item_id']}".encode(), - json.dumps(item_review_ids).encode() - ) - - # Update user reviews index (store only review_ids) - user_review_ids = json.loads(txn.get(f"user_{review['user_id']}".encode()) or '[]') - user_review_ids.append(review['review_id']) - txn.put( - f"user_{review['user_id']}".encode(), - json.dumps(user_review_ids).encode() - ) + review_key = review['review_id'].encode() + txn.put(review_key, json.dumps(review).encode()) + + # Update item reviews index + item_key = f"item_{review['item_id']}".encode() + item_reviews = json.loads(txn.get(item_key) or '[]') + item_reviews.append(review['review_id']) + txn.put(item_key, json.dumps(item_reviews).encode()) + + # Update user reviews index + user_key = f"user_{review['user_id']}".encode() + user_reviews = json.loads(txn.get(user_key) or '[]') + user_reviews.append(review['review_id']) + txn.put(user_key, json.dumps(user_reviews).encode()) + logger.info(f"Filtered out {filtered_count} reviews based on block set") def _iter_file(self, filename: str) -> Iterator[Dict]: """Iterate through file line by line.""" diff --git a/websocietysimulator/tools/interaction_tool.py b/websocietysimulator/tools/interaction_tool.py index 7d4bb31..922a73a 100644 --- a/websocietysimulator/tools/interaction_tool.py +++ b/websocietysimulator/tools/interaction_tool.py @@ -61,7 +61,7 @@ def _load_data(self, filename: str) -> List[Dict]: with open(file_path, 'r', encoding='utf-8') as file: return [json.loads(line) for line in file] - def _load_block_set(self) -> List[str]: + def _load_block_set(self) -> List[dict]: """Load all block set files from the block set directory.""" block_set_data = [] task_dir = os.path.join(self.block_set_dir, 'tasks')