Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuwei Yan committed Jan 31, 2025
1 parent ca42d85 commit 1cf21e5
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 39 deletions.
65 changes: 27 additions & 38 deletions websocietysimulator/tools/cache_interaction_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,34 +23,22 @@ def __init__(self, data_dir: str, block_set_dir: Optional[str] = None):
self.env_dir = os.path.join(data_dir, "lmdb_cache")
os.makedirs(self.env_dir, exist_ok=True)

self.user_env = lmdb.open(os.path.join(self.env_dir, "users"), map_size=2 * 1024 * 1024 * 1024)
self.item_env = lmdb.open(os.path.join(self.env_dir, "items"), map_size=2 * 1024 * 1024 * 1024)
self.review_env = lmdb.open(os.path.join(self.env_dir, "reviews"), map_size=8 * 1024 * 1024 * 1024)
self.user_env = lmdb.open(os.path.join(self.env_dir, "users"), map_size=4 * 1024 * 1024 * 1024)
self.item_env = lmdb.open(os.path.join(self.env_dir, "items"), map_size=4 * 1024 * 1024 * 1024)
self.review_env = lmdb.open(os.path.join(self.env_dir, "reviews"), map_size=32 * 1024 * 1024 * 1024)

# Load block set data if provided
self.block_set_items = []
self.block_set_pairs = set() # Store (user_id, item_id) pairs
self.block_set_pairs = set()
if self.block_set_dir:
logger.info(f"Loading block set data from {self.block_set_dir}")
self.block_set_items = self._load_block_set()
self.block_set_pairs = {(item['user_id'], item['item_id']) for item in self.block_set_items}

# Filter reviews based on block set
filtered_reviews = []
with self.review_env.begin() as txn:
cursor = txn.cursor()
for key, value in cursor:
review = json.loads(value)
if (review['user_id'], review['item_id']) not in self.block_set_pairs:
filtered_reviews.append(review)

logger.info(f"Filtered out {txn.stat()['entries'] - len(filtered_reviews)} reviews based on block set")
self.review_env = filtered_reviews

self._initialize_db()

def _load_block_set(self) -> List[Dict[str, str]]:
"""Load block set files and return a list of user-item pairs."""
def _load_block_set(self) -> List[dict]:
"""Load all block set files from the block set directory."""
block_set_data = []
task_dir = os.path.join(self.block_set_dir, 'tasks')
groundtruth_dir = os.path.join(self.block_set_dir, 'groundtruth')
Expand Down Expand Up @@ -96,29 +84,30 @@ def _initialize_db(self):

# Initialize reviews and their indices
with self.review_env.begin(write=True) as txn:
filtered_count = 0
if not txn.stat()['entries']:
for review in tqdm(self._iter_file('review.json')):
# 检查是否在block set中
if (review['user_id'], review['item_id']) in self.block_set_pairs:
filtered_count += 1
continue

# Store the review
txn.put(
review['review_id'].encode(),
json.dumps(review).encode()
)

# Update item reviews index (store only review_ids)
item_review_ids = json.loads(txn.get(f"item_{review['item_id']}".encode()) or '[]')
item_review_ids.append(review['review_id'])
txn.put(
f"item_{review['item_id']}".encode(),
json.dumps(item_review_ids).encode()
)

# Update user reviews index (store only review_ids)
user_review_ids = json.loads(txn.get(f"user_{review['user_id']}".encode()) or '[]')
user_review_ids.append(review['review_id'])
txn.put(
f"user_{review['user_id']}".encode(),
json.dumps(user_review_ids).encode()
)
review_key = review['review_id'].encode()
txn.put(review_key, json.dumps(review).encode())

# Update item reviews index
item_key = f"item_{review['item_id']}".encode()
item_reviews = json.loads(txn.get(item_key) or '[]')
item_reviews.append(review['review_id'])
txn.put(item_key, json.dumps(item_reviews).encode())

# Update user reviews index
user_key = f"user_{review['user_id']}".encode()
user_reviews = json.loads(txn.get(user_key) or '[]')
user_reviews.append(review['review_id'])
txn.put(user_key, json.dumps(user_reviews).encode())
logger.info(f"Filtered out {filtered_count} reviews based on block set")

def _iter_file(self, filename: str) -> Iterator[Dict]:
"""Iterate through file line by line."""
Expand Down
2 changes: 1 addition & 1 deletion websocietysimulator/tools/interaction_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def _load_data(self, filename: str) -> List[Dict]:
with open(file_path, 'r', encoding='utf-8') as file:
return [json.loads(line) for line in file]

def _load_block_set(self) -> List[str]:
def _load_block_set(self) -> List[dict]:
"""Load all block set files from the block set directory."""
block_set_data = []
task_dir = os.path.join(self.block_set_dir, 'tasks')
Expand Down

0 comments on commit 1cf21e5

Please sign in to comment.