Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuwei Yan committed Jan 7, 2025
1 parent 526c27e commit 2acbd11
Showing 1 changed file with 21 additions and 5 deletions.
26 changes: 21 additions & 5 deletions websocietysimulator/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,13 +179,18 @@ def run_simulation(self, number_of_tasks: int = None, enable_threading: bool = F
logger.info(f"Simulation finished for task {index}")
else:
# 多线程处理
from threading import Lock

from threading import Lock, Event
log_lock = Lock()
cancel_event = Event() # 添加取消事件标志
self.simulation_outputs = [None] * len(task_to_run)

def process_task(task_index_tuple):
index, task = task_index_tuple
# 检查是否已经被要求取消
if cancel_event.is_set():
return index, None

if isinstance(self.llm, list):
agent = self.agent_class(llm=self.llm[index%len(self.llm)])
else:
Expand All @@ -194,6 +199,10 @@ def process_task(task_index_tuple):
agent.insert_task(task)

try:
# 定期检查是否需要取消
if cancel_event.is_set():
return index, None

output = agent.workflow()
result = {
"task": task.to_dict(),
Expand All @@ -204,11 +213,12 @@ def process_task(task_index_tuple):
"task": task.to_dict(),
"error": "Forward method not implemented by participant."
}
except Exception as e:
return index, None

with log_lock:
logger.info(f"Simulation finished for task {index}")

self.simulation_outputs[index] = result
return index, result

# 确定线程数
Expand All @@ -220,14 +230,12 @@ def process_task(task_index_tuple):
logger.info(f"Running with {max_workers} threads")

with ThreadPoolExecutor(max_workers=max_workers) as executor:
# 提交所有任务
future_to_index = {
executor.submit(process_task, (i, task)): i
for i, task in enumerate(task_to_run)
}

try:
# 等待所有任务完成或达到时间限制
for future in as_completed(future_to_index, timeout=timeout_seconds):
try:
index, result = future.result()
Expand All @@ -236,6 +244,14 @@ def process_task(task_index_tuple):
logger.error(f"Task failed with error: {str(e)}")
except TimeoutError:
logger.error(f"Time limit ({time_limitation} minutes) reached.")
# 设置取消标志
cancel_event.set()
# 强制取消所有任务
for future in future_to_index:
future.cancel()
# 立即关闭执行器,不等待任务完成
executor._threads.clear()
executor.shutdown(wait=False)
raise TimeoutError

logger.info("Simulation finished")
Expand Down

0 comments on commit 2acbd11

Please sign in to comment.