RetrievalQA streaming issue with chainlit UI

Asked by Tatva Joshi on 11/17/2023 · Last edited by cronoik on 11/20/2023 · Viewed 91 times

Q:

I have been trying to stream answers from the LLM to the Chainlit UI. I am using the langchain library and a RetrievalQA chain to combine the LLM, prompt, and vector store with a memory buffer. I first tried streaming an LLMChain on the CLI and then with the Chainlit UI, and it worked fine. But when I use the RetrievalQA chain, streaming only works on the CLI; the tokens are not streamed to the Chainlit UI.

Here is my code:

```python
from llama_cpp import Llama
from langchain.chains import LLMChain, QAWithSourcesChain
from langchain.prompts import PromptTemplate
from langchain.llms import LlamaCpp
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.llms import CTransformers
from langchain.chains import RetrievalQA
import chainlit as cl
from langchain.memory import ConversationBufferMemory
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.callbacks.streaming_aiter_final_only import AsyncFinalIteratorCallbackHandler
from langchain.callbacks.manager import AsyncCallbackManagerForRetrieverRun
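# StreamingStdOutCallbackHandler prints generated tokens to the terminal as they arrive;
# AsyncFinalIteratorCallbackHandler asynchronously yields only the tokens of the final answer.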

DB_FAISS_PATH = 'vectorstores/db_faiss'

custom_prompt_template="""
 Below is an instruction that describes a task. Write a response that appropriately completes the request from the given context.
    ### Chat History:
    {chat_history}
    ## Context:
    {context}
    ### Instruction:
    {question}
    ### Response:
    """
def set_custom_prompt():
    """
    Prompt template for QA retrieval for each vectorstore
    """
    prompt = PromptTemplate(template=custom_prompt_template,
                            input_variables=["chat_history","context","question"])
    return prompt
custom_prompt_llmchain="""
 Below is an instruction that describes a task. Write a response that appropriately completes the request from the given context.
    ### Instruction:
    {question}
    ### Response:
    """


def set_custom_prompt_1():
    """
    Prompt template for QA retrieval for each vectorstore
    """
    prompt = PromptTemplate(template=custom_prompt_llmchain,
                            input_variables=["question"])
    return prompt
#Loading the model
def load_llm():
    llm = CTransformers(
    model="H:\WPCS_OPS_MERGED\OPS_WPCS_Q4.gguf", 
    model_type="llama",
    callbacks=[AsyncFinalIteratorCallbackHandler()],
    config={
    "temperature":0.15,
    # "max_new_tokens":512,
    # verbose=True,
    # streaming=True,
    "context_length":1800,
    "top_k":30,
    "repetition_penalty":1.2}
)
    return llm
def create_chain():
    llm_chain = LLMChain(
            llm=load_llm(),
            prompt=set_custom_prompt_1(),
            verbose=True,
        )
    return llm_chain
def create_chain_qa():
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2",
                                       model_kwargs={'device': 'cpu'})
    db = FAISS.load_local(DB_FAISS_PATH, embeddings)
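    # RetrievalQA wraps the LLM in a 'stuff' combine-documents chain; the custom prompt and
    # the conversation memory are passed through to that inner chain via chain_type_kwargs.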
    qa_chain = RetrievalQA.from_chain_type(llm=load_llm(),
                                       chain_type='stuff',
                                       retriever=db.as_retriever(search_kwargs={'k': 3}),
                                       verbose=True,
                                       return_source_documents=False,
                                       chain_type_kwargs={
                                           'prompt': set_custom_prompt(),
                                           "verbose": True,
                                           "memory": ConversationBufferMemory(
                                               memory_key="chat_history",
                                               input_key="question",
                                               max_token_limit=150,
                                               return_messages=True)}
                                       )
    return qa_chain
# if __name__=="__main__":
#     chain=create_chain()
#     a=chain.run("How can I access the Service Status Page for HLA?")
@cl.on_chat_start
async def start():
    chain = create_chain_qa()
    msg = cl.Message(content="Starting the bot...")
    await msg.send()
    msg.content = "Hi, Welcome to AHS Bot. What is your query?"
    await msg.update()
    cl.user_session.set("chain", chain)
from langchain.callbacks.base import BaseCallbackHandler

class StreamHandler(BaseCallbackHandler):
    def __init__(self):
        self.msg = cl.Message(content="")

    async def on_llm_new_token(self, token: str, **kwargs):
        await self.msg.stream_token(token)

    async def on_llm_end(self, response: str, **kwargs):
        await self.msg.send()
        self.msg = cl.Message(content="")
@cl.on_message
async def main(message):
    chain = cl.user_session.get("chain") 
    cb = cl.AsyncLangchainCallbackHandler(
        stream_final_answer=True,
        answer_prefix_tokens=["FINAL", "ANSWER"]
    )
    # print(cb)
    # msg=cl.Message("")
    cb.answer_reached = True
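    # Setting answer_reached manually makes the Chainlit handler stream every token
    # instead of waiting until the answer_prefix_tokens appear in the output.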
    res = await chain.acall(message, callbacks=[cb])




    # # answer = res["result"]
    # async for part in res:
    #     if token := part.choices[0].text or "":
    #         await msg.stream_token(token)
    # await cl.Message(res["text"]).send()`

It would be great if someone could lend a hand or suggest a solution to this problem.

python langchain



A: No answers yet