Asked by: Simon Asked: 4/22/2023 Last edited by: Simon Updated: 4/23/2023 Views: 434
Handling exceptions and CTRL+C in Python multiprocessing and multithreading
Q:
My code does the following:
- starts processes that collect data
- starts processes to test the model
- one thread is in charge of aggregating the data to train the model
- one thread aggregates the test results
I need to handle exceptions so that if the training thread fails everything terminates, but if the testing thread fails training keeps going.
Below is a MWE (no model, but the way I manage processes and threads is the same).
It works, but I cannot catch CTRL+C. If I press it I get a long wall of text, but the program keeps running.
I have read that I should set the main thread with `daemon = True`, but nothing changed. I tried putting a `try except` around the main `run` call, to no avail.
How can I gracefully end the program with CTRL+C?
And some other questions:
- Is there a better way to end the program when the training thread is done? I am using the `terminate` variable because `daemon = True` does not work. I have read about using `concurrent.futures`, but I could not find a simple guide.
- Is it better to use a single pool, as in my example, or two pools (one for the collecting processes and one for the training processes)? Would there be any difference?
- How can I make each thread catch the right exception? My code raises exceptions in `collect` and `test`, which are run by the processes. However, exceptions raised by `collect` must be caught by `run_train`, and exceptions raised by `test` must be caught by `run_test`.
import threading
import logging
import traceback
import torch
import time
from torch import multiprocessing as mp

try:
    mp.set_start_method('spawn')
except:
    pass

shandle = logging.StreamHandler()
log = logging.getLogger('rl')
log.propagate = False
log.addHandler(shandle)
log.setLevel(logging.INFO)

def collect(id, queue, data_collect):
    log.info('Collect %i started ...', id)
    try:
        while True:
            idx = queue.get()
            if idx is None:
                break
            data_collect[idx] = torch.rand(1)
            queue.task_done()
            # actually do something meaningful
    except Exception as e:
        log.error('Exception in collect process %i', id)
        traceback.print_exc()
        raise e

def test(id, queue, data_test):
    log.info('Test %i started ...', id)
    try:
        while True:
            idx = queue.get()
            if idx is None:
                break
            data_test[idx] = torch.rand(1)
            queue.task_done()
            # actually do something meaningful
    except Exception as e:
        log.error('Exception in test process %i', id)
        traceback.print_exc()
        raise e

def run():
    steps = 0
    num_collect_procs = 3
    num_test_procs = 2
    max_steps = 10
    terminate = False

    data_collect = torch.zeros(num_collect_procs).share_memory_()
    data_test = torch.zeros(num_test_procs).share_memory_()

    manager = mp.Manager()
    pool = mp.Pool()
    collect_queue = manager.JoinableQueue()
    test_queue = manager.JoinableQueue()

    # Start collection and testing processes
    for i in range(num_collect_procs):
        pool.apply_async(collect, args=(i, collect_queue, data_collect))
    for i in range(num_test_procs):
        pool.apply_async(test, args=(i, test_queue, data_test))

    # Define target function for the learning thread
    def run_train():
        nonlocal steps, terminate
        log.info('Training thread started ...')
        while steps < max_steps and not terminate:
            try:
                for idx in range(num_collect_procs):
                    collect_queue.put(idx)
                collect_queue.join()
                time.sleep(0.1)
                log.info('Training, %i %f', steps, data_collect.sum())
                steps += 1
            except:
                terminate = True
                for idx in range(num_collect_procs):
                    collect_queue.put(None)
                break
        log.info('Training done')

    # Define target function for the testing thread
    def run_test():
        nonlocal steps, terminate
        log.info('Testing thread started ...')
        while steps < max_steps and not terminate:
            try:
                for idx in range(num_test_procs):
                    test_queue.put(idx)
                test_queue.join()
                time.sleep(0.1)
                log.info('Testing, %i %f', steps, data_test.sum())
            except:
                for idx in range(num_test_procs):
                    test_queue.put(None)
                break
        log.info('Testing done')

    learning_thread = threading.Thread(target=run_train, name='train')
    learning_thread.start()
    testing_thread = threading.Thread(target=run_test, name='test')
    testing_thread.start()

    learning_thread.join()
    testing_thread.join()
    collect_queue.join()
    test_queue.join()

    pool.terminate()
    pool.join()

if __name__ == '__main__':
    run()
Answers:
0 votes
Simon
4/23/2023
#1
I did it using `concurrent.futures` and catching the exceptions in the processes with `done_callback`. The code is longer than I expected, so if anyone has a better solution I would be happy to use theirs.
The code is below. A few notes for anyone who wants to use it:
- I set the context to `spawn` because it is the only one supported on Windows, and I need the same code to work on both Ubuntu and Windows.
- Exceptions raised by a process cannot be read directly by a thread. However, if an exception is raised the `future` object considers the work done, so I use `done_callback` to clean up the queue and update the `terminate` variable to notify the threads to end (see the short sketch after this list).
- Another way to communicate between processes and threads is to read the shared data they use. In this example, I check for `nan`.
- The program ends only if an error shows up in `data_collect` or `run_train`. If something fails during testing, training keeps going.
- This also catches `CTRL+C`.
- In my example `collect` and `test` basically do the same thing (assign random values to a tensor), but this is just an example. In practice they will do different things.
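As a minimal, self-contained sketch of the `done_callback` mechanism mentioned above (separate from the full code below; `work` and `on_done` are illustrative names, not part of my actual code):

from concurrent.futures import ProcessPoolExecutor

def work(x):
    # stand-in for a worker that fails
    raise ValueError('boom')

def on_done(future):
    # called once the future completes; future.exception() returns
    # the worker's exception, or None if the worker succeeded
    exc = future.exception()
    if exc is not None:
        print('worker failed:', exc)

if __name__ == '__main__':
    with ProcessPoolExecutor(1) as pool:
        fut = pool.submit(work, 1)
        fut.add_done_callback(on_done)

The callback cannot re-raise into the thread that submitted the work; it can only record the failure (here by printing, in the full code by setting `terminate`).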
import logging
import traceback
import torch
import time
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import ProcessPoolExecutor
from torch import multiprocessing as mp

shandle = logging.StreamHandler()
log = logging.getLogger('rl')
log.propagate = False
log.addHandler(shandle)
log.setLevel(logging.INFO)

# to randomly raise errors
# play with them to see how the program behaves
error_prob_collect = 0.
error_prob_test = 0.5
error_prob_training = 0.
error_prob_testing = 0.

def collect(id, queue, data):
    log.info('Collect %i started ...', id)
    while True:
        try:
            idx = queue.get()
            if idx is None:
                break
            data[idx] = torch.rand(1)
            if torch.rand(1) < error_prob_collect:
                data[idx] = torch.nan  # the thread will read this
                raise Exception('unlucky training')
        except Exception as e:
            log.error('Exception in collect process %i', id)
            traceback.print_exc()
        finally:
            queue.task_done()
    log.info('Collect %i completed', id)

def test(id, queue, data):
    log.info('Test %i started ...', id)
    while True:
        try:
            idx = queue.get()
            if idx is None:
                break
            data[idx] = torch.rand(1)
            if torch.rand(1) < error_prob_test:
                data[idx] = torch.nan  # the thread will read this
                raise Exception('unlucky testing')
        except Exception as e:
            log.error('Exception in test process %i', id)
            traceback.print_exc()
        finally:
            queue.task_done()
    log.info('Test %i completed', id)

def run():
    steps = 0
    terminate = False
    num_collect_procs = 3
    num_test_procs = 2
    max_steps = 10

    data_collect = torch.zeros(num_collect_procs).share_memory_()
    data_test = torch.zeros(num_test_procs).share_memory_()

    ctx = mp.get_context('spawn')
    manager = mp.Manager()
    collect_queue = manager.JoinableQueue()
    test_queue = manager.JoinableQueue()
    collect_pool = ProcessPoolExecutor(num_collect_procs, mp_context=ctx)
    test_pool = ProcessPoolExecutor(num_test_procs, mp_context=ctx)

    def clear_collect(future):
        while not collect_queue.empty():
            try:
                collect_queue.task_done()
            except ValueError:
                break
        nonlocal terminate
        terminate = True

    for i in range(num_collect_procs):
        future = collect_pool.submit(collect, i, collect_queue, data_collect)
        future.add_done_callback(clear_collect)
    for i in range(num_test_procs):
        future = test_pool.submit(test, i, test_queue, data_test)

    def run_train():
        nonlocal steps, terminate
        log.info('Training thread started ...')
        while steps < max_steps:
            try:
                for idx in range(num_collect_procs):
                    collect_queue.put(idx)
                collect_queue.join()
                time.sleep(0.1)  # just to make printed output easier to read
                log.info('Training, %i %f', steps, data_collect.sum())
                steps += 1
                if torch.rand(1) < error_prob_training:
                    raise Exception('bad training')
            except:
                log.error(' Error in training thread, we are going to terminate everything')
                traceback.print_exc()
                break
            if data_collect.isnan().any():
                log.error(' Training thread knows that there was an error in data_collect, we are going to terminate everything')
                break
        log.info('Training ended')
        for i in range(num_collect_procs):
            collect_queue.put(None)
        terminate = True

    def run_test():
        nonlocal steps
        log.info('Testing thread started ...')
        while steps < max_steps and not terminate:
            try:
                for idx in range(num_test_procs):
                    test_queue.put(idx)
                test_queue.join()
                time.sleep(0.1)  # just to make printed output easier to read
                log.info('Testing, %i %f', steps, data_test.sum())
                if torch.rand(1) < error_prob_testing:
                    raise Exception('bad testing')
            except:
                log.error(' Error in testing thread, but we keep going')
                traceback.print_exc()
            if data_test.isnan().any():
                log.error(' Testing thread knows that there was an error in data_test, but we keep going')
        log.info('Testing ended')
        for i in range(num_test_procs):
            test_queue.put(None)

    training_thread = ThreadPoolExecutor(1)
    testing_thread = ThreadPoolExecutor(1)
    training_thread.submit(run_train)
    testing_thread.submit(run_test)

if __name__ == '__main__':
    run()
Comments
0 votes
Booboo
4/23/2023
By the way, it is only `run_train` that increments `steps`. So if each iteration of the `while steps < max_steps and not terminate` loop in `run_test` runs faster than `steps` is incremented, `run_test` can execute multiple iterations of that loop without `steps` having changed (and the same is true of my answer). Perhaps you should only run another iteration of `run_test` whenever an iteration of `run_train` completes (???).
0 votes
Simon
4/23/2023
@Booboo yes, I checked my other code: a testing step only happens when a training step has completed (that is, when `steps` is incremented), but training does not wait for testing (in my case training needs to run as fast as possible). This is just the simplest example I could make to get help with handling the errors.
1 vote
Booboo
4/23/2023
#2
To handle a Ctrl-c interrupt, you need to initialize the pool processes so that they ignore the interrupt. This is best accomplished with a pool initializer function, as in the code below. You did not specify your platform, but on Linux using spawn a stack trace may still get printed; everything should terminate, though.
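In isolation the pattern looks like this; a minimal sketch using only the standard library (separate from the full answer code below):

import signal
import time
from multiprocessing import Pool

def init_pool_processes():
    # each worker ignores SIGINT, so only the main process
    # receives the KeyboardInterrupt raised by Ctrl-c
    signal.signal(signal.SIGINT, signal.SIG_IGN)

def work(i):
    time.sleep(1)
    return i

if __name__ == '__main__':
    pool = Pool(2, initializer=init_pool_processes)
    try:
        print(pool.map(work, range(4)))
        pool.close()
        pool.join()
    except KeyboardInterrupt:
        pool.terminate()  # stop the workers immediately
        pool.join()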
Since you are already using a multiprocessing pool, I would get rid of the explicit queues and instead submit a range of indices with `map_async`, which is non-blocking. Since you appear to need a specific number of training and testing processes, two pools are required:
import logging
import traceback
import torch
import time
from torch import multiprocessing as mp
from functools import partial
import signal
import sys

try:
    mp.set_start_method('spawn')
except:
    pass

shandle = logging.StreamHandler()
log = logging.getLogger('rl')
log.propagate = False
log.addHandler(shandle)
log.setLevel(logging.INFO)

def init_pool_processes(*args):
    global terminate_event
    terminate_event = args[0]
    signal.signal(signal.SIGINT, signal.SIG_IGN)  # Ignore Ctrl-c

def collect(data_collect, idx):
    """
    This now processes a single idx
    """
    if terminate_event.is_set():
        return
    try:
        data_collect[idx] = torch.rand(1)
        # actually do something meaningful
        ...
    except Exception as e:
        log.error('Exception in collect for idx %i', idx)
        traceback.print_exc()
        raise e

def test(data_test, idx):
    """
    This now processes a single idx
    """
    if terminate_event.is_set():
        return
    try:
        data_test[idx] = torch.rand(1)
        # actually do something meaningful
        pass
    except Exception as e:
        log.error('Exception in test for idx %i', idx)
        traceback.print_exc()
        raise e

def run():
    steps = 0
    num_collect_procs = 3
    num_test_procs = 2
    max_steps = 10

    terminate_event = mp.Event()
    completion_event = mp.Event()  # Training has been completed

    data_collect = torch.zeros(num_collect_procs).share_memory_()
    data_test = torch.zeros(num_test_procs).share_memory_()

    # Since you want to control the number of processes used for
    # training and collecting, we need two pools:
    collect_pool = mp.Pool(num_collect_procs,
                           initializer=init_pool_processes,
                           initargs=(terminate_event,)
                           )
    test_pool = mp.Pool(num_test_procs,
                        initializer=init_pool_processes,
                        initargs=(terminate_event,)
                        )

    def collect_completion(result):
        nonlocal steps

        if isinstance(result, Exception):
            # Got an exception so do not submit any more tasks
            # and prevent any more tasks from being executed
            terminate_event.set()  # Force termination
            completion_event.set()  # Show we are done
        elif steps < max_steps:
            # Submit next collection of indices:
            log.info('Training, %i %f', steps, data_collect.sum())
            collect_pool.map_async(partial(collect, data_collect),
                                   range(num_collect_procs),
                                   callback=collect_completion,
                                   error_callback=collect_completion
                                   )
            time.sleep(.1)  # What is this for?
            if steps == 0:
                # kick off the test tasks
                test_completion(None)
                time.sleep(0.1)
            steps += 1
        else:
            completion_event.set()  # Show we are done

    def test_completion(result):
        nonlocal steps

        if not isinstance(result, Exception) and not terminate_event.is_set() and steps < max_steps:
            # Submit next collection of indices:
            log.info('Testing, %i %f', steps, data_test.sum())
            test_pool.map_async(partial(test, data_test),
                                range(num_test_procs),
                                callback=test_completion,
                                error_callback=test_completion
                                )

    try:
        # Start the ball rolling:
        collect_completion(None)
        # Wait for collection tasks to complete:
        completion_event.wait()
        # Now it is safe to wait for all submitted tasks to complete:
        collect_pool.close()
        collect_pool.join()
        test_pool.close()
        test_pool.join()
    except KeyboardInterrupt:
        pass
    except Exception as e:
        print(e)

if __name__ == '__main__':
    run()
Update
If you want a testing step to run only after a new training step has completed, then:
import logging
import traceback
import torch
import time
from torch import multiprocessing as mp
from functools import partial
import signal
import sys

try:
    mp.set_start_method('spawn')
except:
    pass

shandle = logging.StreamHandler()
log = logging.getLogger('rl')
log.propagate = False
log.addHandler(shandle)
log.setLevel(logging.INFO)

def init_pool_processes(*args):
    global terminate_event
    terminate_event = args[0]
    signal.signal(signal.SIGINT, signal.SIG_IGN)  # Ignore Ctrl-c

def collect(data_collect, idx):
    """
    This now processes a single idx
    """
    if terminate_event.is_set():
        return
    try:
        data_collect[idx] = torch.rand(1)
        # actually do something meaningful
        ...
    except Exception as e:
        log.error('Exception in collect for idx %i', idx)
        traceback.print_exc()
        raise e

def test(data_test, idx):
    """
    This now processes a single idx
    """
    if terminate_event.is_set():
        return
    try:
        data_test[idx] = torch.rand(1)
        # actually do something meaningful
        pass
    except Exception as e:
        log.error('Exception in test for idx %i', idx)
        traceback.print_exc()
        raise e

def run():
    steps = 0
    num_collect_procs = 3
    num_test_procs = 2
    max_steps = 10

    terminate_event = mp.Event()
    completion_event = mp.Event()  # Training has been completed
    stop_tests = False

    data_collect = torch.zeros(num_collect_procs).share_memory_()
    data_test = torch.zeros(num_test_procs).share_memory_()

    # Since you want to control the number of processes used for
    # training and collecting, we need two pools:
    collect_pool = mp.Pool(num_collect_procs,
                           initializer=init_pool_processes,
                           initargs=(terminate_event,)
                           )
    test_pool = mp.Pool(num_test_procs,
                        initializer=init_pool_processes,
                        initargs=(terminate_event,)
                        )

    def submit_next_collection(this_step):
        # Submit next collection of indices:
        log.info('Training, %i %f', steps, data_collect.sum())
        collect_pool.map_async(partial(collect, data_collect),
                               range(num_collect_procs),
                               callback=collect_completion,
                               error_callback=collect_completion
                               )

    def submit_next_test(this_step):
        # Submit next collection of indices:
        log.info('Testing, %i %f', steps, data_test.sum())
        test_pool.map_async(partial(test, data_test),
                            range(num_test_procs),
                            callback=test_completion,
                            error_callback=test_completion
                            )

    def collect_completion(result):
        nonlocal steps

        if isinstance(result, Exception):
            # Got an exception so do not submit any more tasks
            # and prevent any more tasks from being executed
            terminate_event.set()  # Force termination
            completion_event.set()  # Show we are done
            return
        if not stop_tests:
            submit_next_test(steps)
        steps += 1
        if steps < max_steps:
            submit_next_collection(steps)
            time.sleep(0.1)
        else:
            completion_event.set()  # Show we are done

    def test_completion(result):
        nonlocal stop_tests

        if isinstance(result, Exception):
            # No more testing:
            stop_tests = True

    try:
        # Start the ball rolling:
        submit_next_collection(0)
        # Wait for collection tasks to complete:
        completion_event.wait()
        # Now it is safe to wait for all submitted tasks to complete:
        collect_pool.close()
        collect_pool.join()
        test_pool.close()
        test_pool.join()
    except KeyboardInterrupt:
        pass
    except Exception as e:
        print(e)

if __name__ == '__main__':
    run()
Comments
0 votes
Simon
4/23/2023
Thanks, it works! I kept the `time.sleep`, because otherwise the terminal output goes by too fast and it is hard to see the error messages. And yes, on Ubuntu CTRL+C actually already worked; the problem was on Windows.
0 votes
Simon
4/23/2023
I have 2 questions. I have managed to make it work with `concurrent.futures` (see my own answer). Are there any pros/cons to using it? Also, in my original (much bigger) code I write the results from both threads (`run_train` and `run_test`) to a csv file. I use a lock from `threading.Lock()` to prevent concurrent writes. Is that enough? I create the lock at the beginning of the function and then do `with lock: csv_logger.log(...)` in the loop. So far I have not had any problems, but I wonder if that is just luck. Thanks again!
0 votes
Booboo
4/23/2023
I would normally give a longer description with my answer, but I had to move on to other things. Note that the threads and the worker functions `run_test` and `run_train` are not needed; everything is done in the callback functions: when a `map_async` call completes, it submits the next run to the pool (unless there was an exception). So I think this is simpler than what you have. In this program I cannot see much difference between `multiprocessing.Pool` or `concurrent.futures.ProcessPoolExecutor`.
0 votes
Booboo
4/23/2023
If you do the file I/O in the callback function, the function runs to completion before it is entered again, so no locking is required.
0 votes
Simon
4/23/2023
For my example you are right. The problem is that my real code does more things, and the structure I use in this example is necessary. For instance, I need `collect` and `test` to have a loop, because before the loop I instantiate some objects. That takes time, and I do not want to do it on every call. So I create the objects only once and then let `collect` run "forever". I will see how to adapt your solution, since it runs faster and is more elegant than mine. I am somewhat new to multiprocessing and multithreading, so your answer gave me a lot of ideas.