Asked by: Julian S. · Asked: 9/27/2023 · Updated: 9/27/2023 · Views: 128
Callback async function with FastAPI (websockets)
Q:
Disclaimer: I'm new to Python, so I apologize if (parts of) this question sound silly.

I'm trying to build a Python API that uses Stable Diffusion pipelines.

Generating the resulting image takes about 10 seconds. For UX purposes, I'm trying to relay the current state (see example gif) back to the user via websockets (bytes/base64).

The pipeline calls a `callback` function every n (e.g. 10) steps (assume the generation process takes 100 steps each run).

The process is started and handled through FastAPI's "/ws" websocket endpoint.

How can I use an async websocket function from within the synchronous image generation process?
I tried starting the callback function in a separate thread using asyncio:

asyncio.create_task()
asyncio.get_event_loop().run_in_executor()

but I couldn't get it to work :/

Here is my current code:
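A common way to bridge a synchronous callback and an async send is asyncio.run_coroutine_threadsafe: capture the running loop inside the endpoint, run the blocking pipeline in an executor, and have the sync callback schedule the coroutine on the captured loop instead of awaiting it. A minimal, dependency-free sketch of that pattern — fake_pipeline and send_progress are hypothetical stand-ins for the Stable Diffusion call and websocket.send_bytes:

```python
import asyncio

received = []  # collects what "the client" would receive

def fake_pipeline(callback):
    # Stand-in for the blocking pipeline: synchronous, runs in a worker
    # thread, invokes the callback every 10 steps.
    for step in range(0, 30, 10):
        callback(step, step, f"latents-{step}")

async def send_progress(step, timestep, latents):
    # In the real endpoint this would be: await websocket.send_bytes(...)
    received.append((step, latents))

async def main():
    loop = asyncio.get_running_loop()

    def callback(step, timestep, latents):
        # Runs in the worker thread: do NOT await here; schedule the
        # coroutine on the main event loop thread-safely instead.
        asyncio.run_coroutine_threadsafe(
            send_progress(step, timestep, latents), loop
        )

    # Keep the event loop free while the pipeline blocks a worker thread.
    await loop.run_in_executor(None, fake_pipeline, callback)
    await asyncio.sleep(0.1)  # let any still-pending sends complete

asyncio.run(main())
print(received)
```

Because the callbacks are all scheduled from one thread, they execute on the loop in order, so the client sees the progress frames in sequence.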
Note: the callback function is called internally by the pipeline with the arguments: step: int, timestep: int, latents: torch.FloatTensor
```python
@router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    prompt = await websocket.receive_text()
    if prompt == "":
        return await websocket.send_text("Error: Empty prompt")

    def callbackCaller(iteration, t, latents):
        asyncio.get_running_loop().run_in_executor(None, pipelineCallback, iteration, t, latents)

    async def pipelineCallback(iteration, t, latents):
        print("called task...")
        with torch.no_grad():
            latents = 1 / 0.18215 * latents
            image = stableDiffusionPipeline.vae.decode(latents).sample
            image = (image / 2 + 0.5).clamp(0, 1)
            # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
            image = image.cpu().permute(0, 2, 3, 1).float().numpy()
            # convert to PIL Images
            images = stableDiffusionPipeline.numpy_to_pil(image)
            # check if batch size is 1
            image = mergeImageIfNecessary(images)
            # convert to bytes
            buffer = io.BytesIO()
            image.save(buffer, format="PNG")
            buffer.seek(0)
            # send image to client
            await websocket.send_bytes(buffer.read())

    try:
        torch.Generator("cuda").manual_seed(0)
        with torch.autocast("cuda"):
            images = stableDiffusionPipeline(
                prompt=prompt,
                width=512,
                height=512,
                num_inference_steps=100,
                guidance_scale=8,
                callback=functools.partial(await pipelineCallback),
                callback_steps=10,
                num_images_per_prompt=1  # batch size
            ).images
        image = mergeImageIfNecessary(images)
        gc.collect()
        torch.cuda.empty_cache()
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        buffer.seek(0)
        image = buffer.read()
        if image is None:
            return await websocket.send_text("Error: Image generation failed.")
        buffer = io.BytesIO()
        image.save("test.png")
        image.save(buffer, format="PNG")
        buffer.seek(0)
        image = buffer.read()
        await websocket.send_bytes(image)
        print("closing socket...")
        await websocket.close()
    except Exception as e:
        traceback.print_exc()
        await websocket.send_text(f"Error: {str(e)}")
    finally:
        pass
```
The pipeline's internal callback type:

callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None
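Since that type is a plain synchronous callable, another dependency-free option is to let the callback hand each intermediate result to the event loop through an asyncio.Queue via loop.call_soon_threadsafe, and drain the queue in the endpoint. A runnable simulation of that design — pipeline here is a hypothetical stand-in for the blocking stableDiffusionPipeline call:

```python
import asyncio

async def main():
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue = asyncio.Queue()
    DONE = object()  # sentinel marking the end of generation

    def pipeline(callback, callback_steps=10, num_inference_steps=30):
        # Stand-in for the blocking pipeline; calls the sync callback
        # every callback_steps steps with fake "PNG" bytes.
        for step in range(0, num_inference_steps, callback_steps):
            callback(step, step, b"fake-png-bytes")

    def callback(step, timestep, latents):
        # Thread-safe handoff to the event loop: no awaiting needed here.
        loop.call_soon_threadsafe(queue.put_nowait, (step, latents))

    async def run_pipeline():
        await loop.run_in_executor(None, pipeline, callback)
        queue.put_nowait(DONE)  # already on the loop thread here

    task = asyncio.create_task(run_pipeline())

    sent = []
    while (item := await queue.get()) is not DONE:
        # In the real endpoint: await websocket.send_bytes(item[1])
        sent.append(item)
    await task
    return sent

sent = asyncio.run(main())
print(sent)
```

The advantage over firing coroutines from the callback is that all websocket sends happen in one place (the drain loop), so they can never interleave or race on the connection.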
I've read countless Stack Overflow questions and docs, but I can't get it to work.

I don't fully understand how this works / could work. I tried using several async functions to respond to the user from a separate thread, hoping to achieve asynchronous handling so that no error would be thrown.
Error: RuntimeWarning: coroutine 'websocket_endpoint.<locals>.pipelineCallback' was never awaited
I also tried simply awaiting the function passed as the pipeline's callback parameter.
Error: TypeError: object function can't be used in 'await' expression
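Both errors follow from the same mechanics: calling an async def function only creates a coroutine object, which triggers the "was never awaited" warning if it is discarded without being awaited; and awaiting the function object itself (without calling it) raises exactly the TypeError above, because only coroutine objects are awaitable. A small demonstration, using a hypothetical pipeline_callback:

```python
import asyncio

async def pipeline_callback(step, timestep, latents):
    return step

# Calling a coroutine function does not run it; it returns a coroutine
# object that must be awaited (hence "was never awaited" if it isn't).
coro = pipeline_callback(0, 0, None)
is_coro = asyncio.iscoroutine(coro)
coro.close()  # close it explicitly to suppress the warning in this demo

async def try_await_function():
    try:
        await pipeline_callback  # note: the function itself, no call
        return None
    except TypeError as e:
        return type(e).__name__

err = asyncio.run(try_await_function())
print(is_coro, err)
```

This is why the callback handed to the pipeline must stay synchronous: the fix is not to await inside it, but to schedule the coroutine onto the running event loop from the callback.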
A: No answers yet
Comments:
loop.run_in_executor(None, lambda: asyncio.run(pipelineCallback(...)))