diff --git a/websocietysimulator/simulator.py b/websocietysimulator/simulator.py index b6e4890..0df2bf9 100644 --- a/websocietysimulator/simulator.py +++ b/websocietysimulator/simulator.py @@ -179,13 +179,18 @@ def run_simulation(self, number_of_tasks: int = None, enable_threading: bool = F logger.info(f"Simulation finished for task {index}") else: # 多线程处理 - from threading import Lock - + from threading import Lock, Event + log_lock = Lock() + cancel_event = Event() # 添加取消事件标志 self.simulation_outputs = [None] * len(task_to_run) def process_task(task_index_tuple): index, task = task_index_tuple + # 检查是否已经被要求取消 + if cancel_event.is_set(): + return index, None + if isinstance(self.llm, list): agent = self.agent_class(llm=self.llm[index%len(self.llm)]) else: @@ -194,6 +199,10 @@ def process_task(task_index_tuple): agent.insert_task(task) try: + # 定期检查是否需要取消 + if cancel_event.is_set(): + return index, None + output = agent.workflow() result = { "task": task.to_dict(), @@ -204,11 +213,12 @@ def process_task(task_index_tuple): "task": task.to_dict(), "error": "Forward method not implemented by participant." } + except Exception as e: + return index, None with log_lock: logger.info(f"Simulation finished for task {index}") - self.simulation_outputs[index] = result return index, result # 确定线程数 @@ -220,14 +230,12 @@ def process_task(task_index_tuple): logger.info(f"Running with {max_workers} threads") with ThreadPoolExecutor(max_workers=max_workers) as executor: - # 提交所有任务 future_to_index = { executor.submit(process_task, (i, task)): i for i, task in enumerate(task_to_run) } try: - # 等待所有任务完成或达到时间限制 for future in as_completed(future_to_index, timeout=timeout_seconds): try: index, result = future.result() @@ -236,6 +244,14 @@ def process_task(task_index_tuple): logger.error(f"Task failed with error: {str(e)}") except TimeoutError: logger.error(f"Time limit ({time_limitation} minutes) reached.") + # 设置取消标志 + cancel_event.set() + # 强制取消所有任务 + for future in future_to_index: + future.cancel() + # 立即关闭执行器,不等待任务完成 + executor._threads.clear() + executor.shutdown(wait=False) raise TimeoutError logger.info("Simulation finished")