
Handling exceptions and CTRL+C in python multiprocessing and multithread

Asked by: Simon, 4/22/2023 · Last edited by: Simon · Updated: 4/23/2023 · Views: 434

Q:

My code does the following:

  • starts processes that collect data
  • starts processes that test the model
  • one thread aggregates the data to train the model
  • one thread aggregates the test results

I need to handle exceptions so that if the training thread fails everything terminates, but if the testing thread fails training continues.

Below is an MWE (there is no model, but the way I manage processes and threads is the same).

It works, but I cannot catch CTRL+C: if I press it I get a long wall of text, but the program keeps running.
I have read that I should set the main thread's daemon = True, but nothing changed. I tried putting a try/except around the main run() call, to no avail.

How can I end the program gracefully with CTRL+C?

And some other questions:

  • Is there a better way to end the program when the training thread ends? I am using the terminate variable because daemon = True does not work. I have read about using concurrent.futures, but I could not find a simple guide.
  • Is it better to use one pool, as in my example, or two pools (one for the collecting processes and one for the training processes)? Would there be any difference?
  • How can I make each thread catch the right exception? My code raises exceptions in collect and test, which are run by the processes, but exceptions raised by collect must be caught by run_train, while exceptions raised by test must be caught by run_test. (A minimal sketch of how worker exceptions reach a waiting thread follows this list.)
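
For background, with concurrent.futures an exception raised inside a pool worker is stored on the returned future and re-raised in whichever thread calls result() on it, so each submitting thread can catch the failures of its own workers. A minimal sketch of that mechanism (worker and the failing input are made up for illustration):

import concurrent.futures

def worker(x):
    # hypothetical task that fails for one specific input
    if x == 2:
        raise ValueError('worker for input %d failed' % x)
    return x * x

if __name__ == '__main__':
    with concurrent.futures.ProcessPoolExecutor(2) as pool:
        futures = [pool.submit(worker, x) for x in range(4)]
        for f in futures:
            try:
                print(f.result())  # re-raises the worker's exception here
            except ValueError as e:
                print('caught in the submitting thread:', e)

Here is the MWE: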
import threading
import logging
import traceback
import torch
import time
from torch import multiprocessing as mp
try:
    mp.set_start_method('spawn')
except RuntimeError:
    # the start method may already have been set
    pass

shandle = logging.StreamHandler()
log = logging.getLogger('rl')
log.propagate = False
log.addHandler(shandle)
log.setLevel(logging.INFO)


def collect(id, queue, data_collect):
    log.info('Collect %i started ...', id)
    try:
        while True:
            idx = queue.get()
            if idx is None:
                break
            data_collect[idx] = torch.rand(1)
            queue.task_done()
            # actually do something meaningful
    except Exception as e:
        log.error('Exception in collect process %i', id)
        traceback.print_exc()
        raise e


def test(id, queue, data_test):
    log.info('Test %i started ...', id)
    try:
        while True:
            idx = queue.get()
            if idx is None:
                break
            data_test[idx] = torch.rand(1)
            queue.task_done()
            # actually do something meaningful
    except Exception as e:
        log.error('Exception in test process %i', id)
        traceback.print_exc()
        raise e


def run():
    steps = 0
    num_collect_procs = 3
    num_test_procs = 2
    max_steps = 10
    terminate = False

    data_collect = torch.zeros(num_collect_procs).share_memory_()
    data_test = torch.zeros(num_test_procs).share_memory_()

    manager = mp.Manager()
    pool = mp.Pool()
    collect_queue = manager.JoinableQueue()
    test_queue = manager.JoinableQueue()

    # Start collection and testing processes
    for i in range(num_collect_procs):
        pool.apply_async(collect, args=(i, collect_queue, data_collect))

    for i in range(num_test_procs):
        pool.apply_async(test, args=(i, test_queue, data_test))

    # Define target function for the learning thread
    def run_train():
        nonlocal steps, terminate
        log.info('Training thread started ...')
        while steps < max_steps and not terminate:
            try:
                for idx in range(num_collect_procs):
                    collect_queue.put(idx)
                collect_queue.join()
                time.sleep(0.1)
                log.info('Training, %i %f', steps, data_collect.sum())
                steps += 1
            except Exception:
                terminate = True
                for idx in range(num_collect_procs):
                    collect_queue.put(None)
                break
        log.info('Training done')


    # Define target function for the testing thread
    def run_test():
        nonlocal steps, terminate
        log.info('Testing thread started ...')
        while steps < max_steps and not terminate:
            try:
                for idx in range(num_test_procs):
                    test_queue.put(idx)
                test_queue.join()
                time.sleep(0.1)
                log.info('Testing, %i %f', steps, data_test.sum())
            except Exception:
                for idx in range(num_test_procs):
                    test_queue.put(None)
                break
        log.info('Testing done')

    learning_thread = threading.Thread(target=run_train, name='train')
    learning_thread.start()

    testing_thread = threading.Thread(target=run_test, name='test')
    testing_thread.start()

    learning_thread.join()
    testing_thread.join()

    collect_queue.join()
    test_queue.join()
    pool.terminate()
    pool.join()


if __name__ == '__main__':
    run()
Tags: python, threading, exception, multiprocessing

Comments

0 · user2357112 · 4/22/2023
Why do you create two locks and then never use either of them?
0 · Simon · 4/22/2023
@user2357112 My mistake, I copy-pasted it from my code and forgot to remove them.

Answers:

0 · Simon · 4/23/2023 · #1

I did it using concurrent.futures and catching the exceptions from the processes with a done_callback. The code is longer than I expected, so if anyone has a better solution I will gladly use theirs.

The code is below. A few things to be aware of if anyone wants to use it:

  • I set the context to spawn because it is the only one supported on Windows, and I need the same code to work on both Ubuntu and Windows.
  • Exceptions raised by the processes cannot be read directly by the threads. However, if an exception is raised, the future object considers the job done, so I use a done_callback to clean up the queue and update the terminate variable to notify the threads to end.
  • Another way to communicate between processes and threads is to read the shared data they use. In this example, I check for nan.
  • The program ends only if an error occurs in data_collect or run_train. If something fails in testing, training keeps going.
  • This also catches CTRL+C.
  • In my example, collect and test basically do the same thing (assign random values to a tensor), but this is just an example. In practice they would do different things.
import logging
import traceback
import torch
import time
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import ProcessPoolExecutor
from torch import multiprocessing as mp

shandle = logging.StreamHandler()
log = logging.getLogger('rl')
log.propagate = False
log.addHandler(shandle)
log.setLevel(logging.INFO)

# to randomly raise errors
# play with them to see how the program behaves
error_prob_collect = 0.
error_prob_test = 0.5
error_prob_training = 0.
error_prob_testing = 0.

def collect(id, queue, data):
    log.info('Collect %i started ...', id)
    while True:
        try:
            idx = queue.get()
            if idx is None:
                break
            data[idx] = torch.rand(1)
            if torch.rand(1) < error_prob_collect:
                data[idx] = torch.nan # the thread will read this
                raise Exception('unlucky training')
        except Exception as e:
            log.error('Exception in collect process %i', id)
            traceback.print_exc()
        finally:
            queue.task_done()
    log.info('Collect %i completed', id)


def test(id, queue, data):
    log.info('Test %i started ...', id)
    while True:
        try:
            idx = queue.get()
            if idx is None:
                break
            data[idx] = torch.rand(1)
            if torch.rand(1) < error_prob_test:
                data[idx] = torch.nan # the thread will read this
                raise Exception('unlucky testing')
        except Exception as e:
            log.error('Exception in test process %i', id)
            traceback.print_exc()
        finally:
            queue.task_done()
    log.info('Test %i completed', id)


def run():
    steps = 0
    terminate = False
    num_collect_procs = 3
    num_test_procs = 2
    max_steps = 10

    data_collect = torch.zeros(num_collect_procs).share_memory_()
    data_test = torch.zeros(num_test_procs).share_memory_()

    ctx = mp.get_context('spawn')
    manager = mp.Manager()
    collect_queue = manager.JoinableQueue()
    test_queue = manager.JoinableQueue()

    collect_pool = ProcessPoolExecutor(num_collect_procs, mp_context=ctx)
    test_pool = ProcessPoolExecutor(num_test_procs, mp_context=ctx)

    def clear_collect(future):
        while not collect_queue.empty():
            try:
                collect_queue.task_done()
            except ValueError:
                break
        nonlocal terminate
        terminate = True

    for i in range(num_collect_procs):
        future = collect_pool.submit(collect, i, collect_queue, data_collect)
        future.add_done_callback(clear_collect)

    for i in range(num_test_procs):
        future = test_pool.submit(test, i, test_queue, data_test)

    def run_train():
        nonlocal steps, terminate
        log.info('Training thread started ...')
        while steps < max_steps:
            try:
                for idx in range(num_collect_procs):
                    collect_queue.put(idx)
                collect_queue.join()
                time.sleep(0.1) # just to make printed output easier to read
                log.info('Training, %i %f', steps, data_collect.sum())
                steps += 1
                if torch.rand(1) < error_prob_training:
                    raise Exception('bad training')
            except:
                log.error('   Error in training thread, we are going to terminate everything')
                traceback.print_exc()
                break
            if data_collect.isnan().any():
                log.error('   Training thread knows that there was an error in data_collect, we are going to terminate everything')
                break
        log.info('Training ended')
        for i in range(num_collect_procs):
            collect_queue.put(None)
        terminate = True

    def run_test():
        nonlocal steps
        log.info('Testing thread started ...')
        while steps < max_steps and not terminate:
            try:
                for idx in range(num_test_procs):
                    test_queue.put(idx)
                test_queue.join()
                time.sleep(0.1) # just to make printed output easier to read
                log.info('Testing, %i %f', steps, data_test.sum())
                if torch.rand(1) < error_prob_testing:
                    raise Exception('bad testing')
            except:
                log.error('   Error in testing thread, but we keep going')
                traceback.print_exc()
            if data_collect.isnan().any():
                log.error('   Testing thread knows that there was an error in data_test, but we keep going')
        log.info('Testing ended')
        for i in range(num_test_procs):
            test_queue.put(None)

    training_thread = ThreadPoolExecutor(1)
    testing_thread = ThreadPoolExecutor(1)
    training_thread.submit(run_train)
    testing_thread.submit(run_test)
    # wait for both threads to finish before returning
    training_thread.shutdown(wait=True)
    testing_thread.shutdown(wait=True)


if __name__ == '__main__':
    run()

Comments

0 · Booboo · 4/23/2023
By the way, only run_train increments steps. So if each iteration of run_test's while steps < max_steps and not terminate loop runs faster than steps is incremented, run_test can execute several iterations of that loop without steps changing (the same is true of my answer). Perhaps you should only run another iteration of run_test whenever an iteration of run_train completes (???).
0 · Simon · 4/23/2023
@Booboo Yes, I checked my other code: a test step happens only when a training step is completed, but training does not wait for testing (in my case training needs to run as fast as possible). This was just the simplest example I could make to get help with handling errors.
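
A minimal sketch of the pacing Booboo suggests, using a threading.Event handshake so the tester runs at most one iteration per completed training step while the trainer never waits (the function bodies are stand-ins, not taken from the code above):

import threading
import time

step_done = threading.Event()   # set by the trainer after each step
stop = threading.Event()        # set when training is finished

def run_train():
    for step in range(10):
        time.sleep(0.05)        # stand-in for one training step
        step_done.set()         # allow the tester to run one iteration
    stop.set()
    step_done.set()             # release the tester so it can exit

def run_test():
    while not stop.is_set():
        step_done.wait()        # block until a new training step completes
        step_done.clear()
        if stop.is_set():
            break
        time.sleep(0.01)        # stand-in for one test iteration

trainer = threading.Thread(target=run_train)
tester = threading.Thread(target=run_test)
trainer.start(); tester.start()
trainer.join(); tester.join()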
1 · Booboo · 4/23/2023 · #2

To handle Ctrl-C interrupts you need to initialize the pool processes to ignore the interrupt. This is best accomplished with a pool initializer function, as in the code below. You did not specify your platform, but on Linux using spawn a stack trace may still get printed, although everything should terminate.

Since you are already using a multiprocessing pool, I would get rid of the explicit queues and instead submit ranges of indices with map_async, which is non-blocking. Since you appear to need a specific number of training and testing processes, two pools are required:

import logging
import traceback
import torch
import time
from torch import multiprocessing as mp
from functools import partial
import signal
import sys

try:
    mp.set_start_method('spawn')
except:
    pass

shandle = logging.StreamHandler()
log = logging.getLogger('rl')
log.propagate = False
log.addHandler(shandle)
log.setLevel(logging.INFO)

def init_pool_processes(*args):
    global terminate_event

    terminate_event = args[0]
    signal.signal(signal.SIGINT, signal.SIG_IGN) # Ignore Ctrl-c

def collect(data_collect, idx):
    """
    This now processes a single idx
    """
    if terminate_event.is_set():
        return
    try:
        data_collect[idx] = torch.rand(1)
        # actually do something meaningful
        ...
    except Exception as e:
        log.error('Exception in collect for idx %i', idx)
        traceback.print_exc()
        raise e


def test(data_test, idx):
    """
    This now processes a single idx
    """
    if terminate_event.is_set():
        return
    try:
        data_test[idx] = torch.rand(1)
        # actually do something meaningful
        pass
    except Exception as e:
        log.error('Exception in test for idx %i', idx)
        traceback.print_exc()
        raise e

def run():
    steps = 0
    num_collect_procs = 3
    num_test_procs = 2
    max_steps = 10
    terminate_event = mp.Event()
    completion_event = mp.Event() # Training has been completed

    data_collect = torch.zeros(num_collect_procs).share_memory_()
    data_test = torch.zeros(num_test_procs).share_memory_()

    # Since you want to control the number of processes used for
    # training and collecting, we need two pools:
    collect_pool = mp.Pool(num_collect_procs,
                           initializer=init_pool_processes,
                           initargs=(terminate_event,)
                           )
    test_pool = mp.Pool(num_test_procs,
                        initializer=init_pool_processes,
                        initargs=(terminate_event,)
                        )

    def collect_completion(result):
        nonlocal steps

        if isinstance(result, Exception):
            # Got an exception so do not submit any more tasks
            # and prevent any more tasks from being
            terminate_event.set() # Force termination
            completion_event.set() # Show we are done
        elif steps < max_steps:
            # Submit next collection of indices:
            log.info('Training, %i %f', steps, data_collect.sum())
            collect_pool.map_async(partial(collect, data_collect),
                                   range(num_collect_procs),
                                   callback=collect_completion,
                                   error_callback=collect_completion
                                   )
            time.sleep(.1) # What is this for?
            if steps == 0:
                # kick of test tasks
                test_completion(None)
            time.sleep(0.1)
            steps += 1
        else:
            completion_event.set() # Show we are done

    def test_completion(result):
        nonlocal steps

        if not isinstance(result, Exception) and not terminate_event.is_set() and steps < max_steps:
            # Submit next collection of indices:
            log.info('Testing, %i %f', steps, data_test.sum())
            test_pool.map_async(partial(test, data_test),
                                range(num_test_procs),
                                callback=test_completion,
                                error_callback=test_completion
                                )

    try:
        # Start the ball of rolling:
        collect_completion(None)

        # Wait for collection tasks to complete:
        completion_event.wait()

        # Now it is safe to wait for all submitted tasks to complete:
        collect_pool.close()
        collect_pool.join()
        test_pool.close()
        test_pool.join()
    except KeyboardInterrupt:
        pass
    except Exception as e:
        print(e)

if __name__ == '__main__':
    run()

Update

If you want a test step to run only after a new training step has completed, then:

import logging
import traceback
import torch
import time
from torch import multiprocessing as mp
from functools import partial
import signal
import sys

try:
    mp.set_start_method('spawn')
except:
    pass

shandle = logging.StreamHandler()
log = logging.getLogger('rl')
log.propagate = False
log.addHandler(shandle)
log.setLevel(logging.INFO)

def init_pool_processes(*args):
    global terminate_event

    terminate_event = args[0]
    signal.signal(signal.SIGINT, signal.SIG_IGN) # Ignore Ctrl-c

def collect(data_collect, idx):
    """
    This now processes a single idx
    """
    if terminate_event.is_set():
        return
    try:
        data_collect[idx] = torch.rand(1)
        # actually do something meaningful
        ...
    except Exception as e:
        log.error('Exception in collect for idx %i', idx)
        traceback.print_exc()
        raise e


def test(data_test, idx):
    """
    This now processes a single idx
    """
    if terminate_event.is_set():
        return
    try:
        data_test[idx] = torch.rand(1)
        # actually do something meaningful
        pass
    except Exception as e:
        log.error('Exception in test for idx %i', idx)
        traceback.print_exc()
        raise e

def run():
    steps = 0
    num_collect_procs = 3
    num_test_procs = 2
    max_steps = 10
    terminate_event = mp.Event()
    completion_event = mp.Event() # Training has been completed
    stop_tests = False

    data_collect = torch.zeros(num_collect_procs).share_memory_()
    data_test = torch.zeros(num_test_procs).share_memory_()

    # Since you want to control the number of processes used for
    # training and collecting, we need two pools:
    collect_pool = mp.Pool(num_collect_procs,
                           initializer=init_pool_processes,
                           initargs=(terminate_event,)
                           )
    test_pool = mp.Pool(num_test_procs,
                        initializer=init_pool_processes,
                        initargs=(terminate_event,)
                        )

    def submit_next_collection(this_step):
        # Submit next collection of indices:
        log.info('Training, %i %f', this_step, data_collect.sum())
        collect_pool.map_async(partial(collect, data_collect),
                               range(num_collect_procs),
                               callback=collect_completion,
                               error_callback=collect_completion
                               )

    def submit_next_test(this_step):
        # Submit next collection of indices:
        log.info('Testing, %i %f', this_step, data_test.sum())
        test_pool.map_async(partial(test, data_test),
                            range(num_test_procs),
                            callback=test_completion,
                            error_callback=test_completion
                            )


    def collect_completion(result):
        nonlocal steps

        if isinstance(result, Exception):
            # Got an exception so do not submit any more tasks
            # and prevent any more tasks from being
            terminate_event.set() # Force termination
            completion_event.set() # Show we are done
            return

        if not stop_tests:
            submit_next_test(steps)

        steps += 1
        if steps < max_steps:
            submit_next_collection(steps)
            time.sleep(0.1)
        else:
            completion_event.set() # Show we are done

    def test_completion(result):
        nonlocal stop_tests

        if isinstance(result, Exception):
            # No more testing:
            stop_tests = True

    try:
        # Start the ball of rolling:
        submit_next_collection(0)

        # Wait for collection tasks to complete:
        completion_event.wait()

        # Now it is safe to wait for all submitted tasks to complete:
        collect_pool.close()
        collect_pool.join()
        test_pool.close()
        test_pool.join()
    except KeyboardInterrupt:
        pass
    except Exception as e:
        print(e)

if __name__ == '__main__':
    run()

Comments

0 · Simon · 4/23/2023
Thanks, it works! The time.sleep is there because otherwise the terminal output scrolls too fast and it is hard to see the error messages. And yes, on Ubuntu CTRL+C actually already worked; the problem was on Windows.
0 · Simon · 4/23/2023
I have 2 questions. I managed to get it working with concurrent.futures (see my own answer). Are there any pros/cons to using it? In my original (much bigger) code I also write results from both threads (run_train and run_test) to a csv file. I use a lock from threading.Lock() to prevent concurrent writes. Is that enough? I create one lock at the beginning of the function and then do with lock: csv_logger.log(...) in the loops. So far I haven't had any problems, but I wonder if that is just luck. Thanks again!
0 · Booboo · 4/23/2023
I usually give a longer description with my answers, but I had to move on to other things. Note that no threads and no worker functions run_test and run_train are needed; everything is done in the callback functions: when a map_async call completes, it submits the next run to the pool (unless there was an exception). So I think this is simpler than what you have. In this program I don't see a big difference between multiprocessing.Pool or concurrent.futures.ProcessPoolExecutor.
0 · Booboo · 4/23/2023
If you do the file I/O in the callback functions, a callback runs to completion before being re-entered, so no locking is needed.
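
To make the locking option from the comments above concrete, here is a minimal sketch of guarding csv writes from two threads with a single threading.Lock (the path and row contents are made up; csv_logger from the question is replaced by the standard csv module):

import csv
import threading

log_lock = threading.Lock()   # one lock shared by run_train and run_test

def log_row(path, row):
    # Append one row; the lock keeps the two threads from interleaving writes.
    with log_lock:
        with open(path, 'a', newline='') as f:
            csv.writer(f).writerow(row)

# e.g. from run_train: log_row('results.csv', ['train', steps, value])
# e.g. from run_test:  log_row('results.csv', ['test', steps, value])

If the writes are moved into the map_async callbacks instead, the lock can be dropped for the reason Booboo gives.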
0 · Simon · 4/23/2023
You are right for my example. The problem is that my real code does some things for which the structure I used in this example is necessary. For example, I need collect and test to have a loop, because before the loop I instantiate some objects. That takes time and I don't want to do it on every call. So I create the objects only once and let collect run 'forever'. I will see how to adapt your solution, since it runs faster and is more elegant than mine. I am a bit new to multiproc and multithread, so your answer gave me lots of ideas.
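
One way to keep the map_async structure while still paying the setup cost only once, per the comment above, is to build the expensive objects in the pool initializer and keep them in a module-level global, so every task in that worker process reuses them. A minimal sketch (make_expensive_model and its contents are hypothetical):

from multiprocessing import Pool

def make_expensive_model():
    # stand-in for slow, one-time initialization
    return {'weights': [0.0] * 1000}

def init_worker():
    # runs once per worker process, not once per task
    global model
    model = make_expensive_model()

def collect_one(idx):
    # every task reuses the per-process 'model' built by init_worker
    return idx, len(model['weights'])

if __name__ == '__main__':
    with Pool(3, initializer=init_worker) as pool:
        print(pool.map(collect_one, range(6)))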