Skip to content

Commit

Permalink
update
Browse files: browse the repository at this point in the history
  • Loading branch information
Yuwei Yan committed Feb 5, 2025
1 parent acdbd32 commit b780ae9
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 41 deletions.
46 changes: 22 additions & 24 deletions websocietysimulator/tools/cache_interaction_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,32 +16,30 @@ def __init__(self, data_dir: str, block_set_dir: Optional[str] = None):
block_set_dir: Optional path to the directory containing block set files.
"""
logger.info(f"Initializing CacheInteractionTool with data directory: {data_dir}")
self.data_dir = data_dir
self.block_set_dir = block_set_dir

# Set up LMDB environments for caching
self.env_dir = os.path.join(data_dir, "lmdb_cache")
os.makedirs(self.env_dir, exist_ok=True)
env_dir = os.path.join(data_dir, "lmdb_cache")
os.makedirs(env_dir, exist_ok=True)

self.user_env = lmdb.open(os.path.join(self.env_dir, "users"), map_size=4 * 1024 * 1024 * 1024)
self.item_env = lmdb.open(os.path.join(self.env_dir, "items"), map_size=4 * 1024 * 1024 * 1024)
self.review_env = lmdb.open(os.path.join(self.env_dir, "reviews"), map_size=32 * 1024 * 1024 * 1024)
self.user_env = lmdb.open(os.path.join(env_dir, "users"), map_size=4 * 1024 * 1024 * 1024)
self.item_env = lmdb.open(os.path.join(env_dir, "items"), map_size=4 * 1024 * 1024 * 1024)
self.review_env = lmdb.open(os.path.join(env_dir, "reviews"), map_size=32 * 1024 * 1024 * 1024)

# Load block set data if provided
self.block_set_items = []
self.block_set_pairs = set()
if self.block_set_dir:
logger.info(f"Loading block set data from {self.block_set_dir}")
self.block_set_items = self._load_block_set()
self.block_set_pairs = {(item['user_id'], item['item_id']) for item in self.block_set_items}
block_set_items = []
block_set_pairs = set()
if block_set_dir:
logger.info(f"Loading block set data from {block_set_dir}")
block_set_items = self._load_block_set(block_set_dir)
block_set_pairs = {(item['user_id'], item['item_id']) for item in block_set_items}

self._initialize_db()
self._initialize_db(data_dir, block_set_pairs)

def _load_block_set(self) -> List[dict]:
def _load_block_set(self, block_set_dir: str) -> List[dict]:
"""Load all block set files from the block set directory."""
block_set_data = []
task_dir = os.path.join(self.block_set_dir, 'tasks')
groundtruth_dir = os.path.join(self.block_set_dir, 'groundtruth')
task_dir = os.path.join(block_set_dir, 'tasks')
groundtruth_dir = os.path.join(block_set_dir, 'groundtruth')

for filename in os.listdir(task_dir):
if filename.startswith('task_') and filename.endswith('.json'):
Expand All @@ -60,13 +58,13 @@ def _load_block_set(self) -> List[dict]:
block_set_data.append({'user_id': task_data['user_id'], 'item_id': item})
return block_set_data

def _initialize_db(self):
def _initialize_db(self, data_dir: str, block_set_pairs: set):
"""Initialize the LMDB databases with data if they are empty."""
# Initialize users
with self.user_env.begin(write=True) as txn:
if not txn.stat()['entries']:
with txn.cursor() as cursor:
for user in tqdm(self._iter_file('user.json')):
for user in tqdm(self._iter_file(data_dir, 'user.json')):
cursor.put(
user['user_id'].encode(),
json.dumps(user).encode()
Expand All @@ -76,7 +74,7 @@ def _initialize_db(self):
with self.item_env.begin(write=True) as txn:
if not txn.stat()['entries']:
with txn.cursor() as cursor:
for item in tqdm(self._iter_file('item.json')):
for item in tqdm(self._iter_file(data_dir, 'item.json')):
cursor.put(
item['item_id'].encode(),
json.dumps(item).encode()
Expand All @@ -86,9 +84,9 @@ def _initialize_db(self):
with self.review_env.begin(write=True) as txn:
filtered_count = 0
if not txn.stat()['entries']:
for review in tqdm(self._iter_file('review.json')):
for review in tqdm(self._iter_file(data_dir, 'review.json')):
# Check whether this review is in the block set
if (review['user_id'], review['item_id']) in self.block_set_pairs:
if (review['user_id'], review['item_id']) in block_set_pairs:
filtered_count += 1
continue

Expand All @@ -109,9 +107,9 @@ def _initialize_db(self):
txn.put(user_key, json.dumps(user_reviews).encode())
logger.info(f"Filtered out {filtered_count} reviews based on block set")

def _iter_file(self, filename: str) -> Iterator[Dict]:
def _iter_file(self, data_dir: str, filename: str) -> Iterator[Dict]:
"""Iterate through file line by line."""
file_path = os.path.join(self.data_dir, filename)
file_path = os.path.join(data_dir, filename)
with open(file_path, 'r', encoding='utf-8') as file:
for line in file:
yield json.loads(line)
Expand Down
32 changes: 15 additions & 17 deletions websocietysimulator/tools/interaction_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,30 +14,28 @@ def __init__(self, data_dir: str, block_set_dir: Optional[str] = None):
data_dir: Path to the directory containing Yelp dataset files.
"""
logger.info(f"Initializing InteractionTool with data directory: {data_dir}")
self.data_dir = data_dir
self.block_set_dir = block_set_dir
# Build dictionaries keyed by ID from the loaded records for O(1) lookup
logger.info(f"Loading item data from {os.path.join(data_dir, 'item.json')}")
self.item_data = {item['item_id']: item for item in self._load_data('item.json')}
self.item_data = {item['item_id']: item for item in self._load_data(data_dir, 'item.json')}
logger.info(f"Loading user data from {os.path.join(data_dir, 'user.json')}")
self.user_data = {user['user_id']: user for user in self._load_data('user.json')}
self.user_data = {user['user_id']: user for user in self._load_data(data_dir, 'user.json')}

# Create review indices
logger.info(f"Loading review data from {os.path.join(data_dir, 'review.json')}")
reviews = self._load_data('review.json')
reviews = self._load_data(data_dir, 'review.json')
# Load ground truth data if available
self.block_set_items = []
self.block_set_pairs = set() # 新增:用于存储(user_id, item_id)对
if self.block_set_dir:
logger.info(f"Loading block set data from {self.block_set_dir}")
self.block_set_items = self._load_block_set()
block_set_items = []
block_set_pairs = set() # 新增:用于存储(user_id, item_id)对
if block_set_dir:
logger.info(f"Loading block set data from {block_set_dir}")
block_set_items = self._load_block_set(block_set_dir)
# Convert the block set data into a set of (user_id, item_id) pairs
self.block_set_pairs = {(item['user_id'], item['item_id']) for item in self.block_set_items}
block_set_pairs = {(item['user_id'], item['item_id']) for item in block_set_items}

# Filter the review data, removing reviews that are in the block set
filtered_reviews = []
for review in reviews:
if (review['user_id'], review['item_id']) not in self.block_set_pairs:
if (review['user_id'], review['item_id']) not in block_set_pairs:
filtered_reviews.append(review)

logger.info(f"Filtered out {len(reviews) - len(filtered_reviews)} reviews based on block set")
Expand All @@ -55,17 +53,17 @@ def __init__(self, data_dir: str, block_set_dir: Optional[str] = None):
# Index by user_id
self.user_reviews.setdefault(review['user_id'], []).append(review)

def _load_data(self, filename: str) -> List[Dict]:
def _load_data(self, data_dir: str, filename: str) -> List[Dict]:
"""Load data as a list of dictionaries."""
file_path = os.path.join(self.data_dir, filename)
file_path = os.path.join(data_dir, filename)
with open(file_path, 'r', encoding='utf-8') as file:
return [json.loads(line) for line in file]

def _load_block_set(self) -> List[dict]:
def _load_block_set(self, block_set_dir: str) -> List[dict]:
"""Load all block set files from the block set directory."""
block_set_data = []
task_dir = os.path.join(self.block_set_dir, 'tasks')
groundtruth_dir = os.path.join(self.block_set_dir, 'groundtruth')
task_dir = os.path.join(block_set_dir, 'tasks')
groundtruth_dir = os.path.join(block_set_dir, 'groundtruth')

for filename in os.listdir(task_dir):
if filename.startswith('task_') and filename.endswith('.json'):
Expand Down

0 comments on commit b780ae9

Please sign in to comment.