使用 FastAPI 的回调异步函数 (websockets)

Callback async function with FastAPI (websockets)

提问人:Julian S. 提问时间:9/27/2023 更新时间:9/27/2023 访问量:128

问:

免責聲明:我是 python 的新手,所以如果(某些部分)问题听起来很愚蠢,我深表歉意。


我正在尝试制作一个利用 Stable Diffusion Pipelines 的 python API。

生成结果图像的过程大约需要 10 秒钟。出于用户体验目的,我尝试通过 websockets (byest/base64) 将当前状态(参见示例 gif)中继回给用户。

示例预览

Pipeline 每 n(例如 10 个)步骤调用一个函数(假设生成过程每次需要 100 个步骤)callback

该进程通过 FastAPI 的“/ws”Websocket 端点启动和处理

如何在同步映像生成过程中使用异步 websocket 函数?

我尝试使用 asyncio 在单独的线程中启动回调函数

  • asyncio.create_task()
  • asyncio.get_event_loop().run_in_executor()

但我没有得到它的工作:/

这是我当前的代码:

注意:回调函数由 Pipeline 内部调用,参数为:、、step: inttimestep: intlatents: torch.FloatTensor

@router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    prompt = await websocket.receive_text()
    if(prompt == ""):
        return await websocket.send_text("Error: Empty prompt")

    def callbackCaller(iteration, t, latents):
        asyncio.get_running_loop().run_in_executor(None, pipelineCallback, iteration, t, latents)

    async def pipelineCallback(iteration, t, latents):
        print("called task...")
        with torch.no_grad():
            latents = 1 / 0.18215 * latents
            image = stableDiffusionPipeline.vae.decode(latents).sample

            image = (image / 2 + 0.5).clamp(0, 1)

            # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
            image = image.cpu().permute(0, 2, 3, 1).float().numpy()

            # convert to PIL Images
            images = stableDiffusionPipeline.numpy_to_pil(image)

            # check if batch size is 1
            image = mergeImageIfNecessary(images)

            # convert to bytes
            buffer = io.BytesIO()
            image.save(buffer, format="PNG")
            buffer.seek(0) 

            # send image to client
            await websocket.send_bytes(buffer.read())
            

    try: 
        torch.Generator("cuda").manual_seed(0)

        with torch.autocast("cuda"):
            images = stableDiffusionPipeline(
                prompt=prompt,
                width=512,
                height=512,
                num_inference_steps=100,
                guidance_scale=8,
                callback=functools.partial(await pipelineCallback),
                callback_steps=10,
                num_images_per_prompt=1  # batch size
            ).images

        image = mergeImageIfNecessary(images)

        gc.collect()
        torch.cuda.empty_cache()

        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        buffer.seek(0)

        image = buffer.read()

        if(image == None):
            return await websocket.send_text("Error: Image generation failed.")
        
        buffer = io.BytesIO()
        image.save("test.png")
        image.save(buffer, format="PNG")
        buffer.seek(0)
        image = buffer.read()

        
        await websocket.send_bytes(image)
        print("closing socket...")
        await websocket.close()
    except Exception as e:
        traceback.print_exc()
        await websocket.send_text(f"Error: {str(e)}")
    finally:
       pass

内部回调函数类型

callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None

我已经阅读了无数的stackoverflow问题和文档,但我无法让它工作。

我不太明白这是/可以工作的。

我尝试使用几个异步函数在单独的线程中响应用户。我希望存档异步处理,这样就不会抛出错误。

错误:RuntimeWarning: coroutine 'websocket_endpoint.<locals>.pipelineCallback' was never awaited

我还试图简单地等待管道参数中的函数。callback

错误:TypeError: object function can't be used in 'await' expression

python async-await 回调 fastapi stable-diffusion

评论

0赞 Chris 9/27/2023
回答了你的问题吗?
0赞 Chris 9/27/2023
你可能会发现这个,以及这个和这个很有帮助
0赞 Julian S. 9/28/2023
@Chris,感谢您的快速回答。我实现了您提供的第一个解决方案:它似乎在第一次通话时工作正常。在函数的第二次迭代中,它将被调用并执行 websocket 调用。它不会在执行时将数据发送到客户端。整个过程完成后,所有“排队???”消息都会发送到客户端。我还没有找到通过管道回调按“同步”顺序发送所有响应的解决方案。你知道为什么会这样吗?loop.run_in_executor(None, lambda: asyncio.run(pipelineCallback(...)))
0赞 Chris 9/28/2023
请看一下这个答案
0赞 Julian S. 9/29/2023
@Chris,感谢您对我的耐心等待,:),我不确定我的 websockets 方法是否正确。我之前实际上已经尝试过 StreamingResponses(带有 FastAPI 文档中的基本示例),但我没有让它与我的用例一起使用,因此我切换到 websockets,现在也证明这很棘手。在我继续这个项目之前,我想确定一种具体的方式,但我还不确定是哪一种......我知道这有很多问题要问,而且不在我最初问题的范围之内,但您能推荐一下我的项目的最佳方法是什么吗?最好的问候 ~朱利安

答: 暂无答案