ConversationChainに独自の入力値を持つSystemMessagePromptTemplateを利用してみた

introduction

はじめまして!今年の7月からInsight Edge開発チームに加わった塚越です。 ChatGPT関連のPoCに携わっています!開発だけでなく、分析の要素も経験もさせていただき、とても楽しく取り組んでいます。 また、Insight Edgeではコワーキングスペースの利用が可能なので、家の近くで集中できる環境を手に入れ、快適な日々を過ごしています。

今回は、LLMを使用したチャットボットへの会話履歴の実装方法について、試行錯誤した経験を共有したいと思います。

目次:

概要

チャットボットへの会話履歴の実装方法は色々ありますが、今回はLangChainから以下のモジュールを採用しました。

  • ChatPromptTemplate.from_messages
  • SystemMessagePromptTemplate
  • ConversationTokenBufferMemory(こちらはConversationBufferMemoryでも他のものでも可)
  • ConversationChain

そしてSystemMessagePromptTemplateに入力値を組み込んだ独自のシステムテンプレートを実装することが今回の目標です。 本記事では扱っていませんが、ベクトルストアから取得した文書をシステムテンプレートに組み込みたかったことが発端です。

公式ドキュメントを参照した個々の実装は容易ですが、そのまま組み合わせてもエラーが発生しました。

ConversationChainではプロンプトに利用する入力値を変えられないという点でつまづき、SystemMessagePromptTemplateに入力値を組み込んだまま実装する方法を試行錯誤しました。

上手くいった方法は 実装 にてご紹介し、つまづいた点については ハマりどころ にてご紹介します。

環境

  • Python 3.10.12
  • langchain 0.0.225

実装

環境変数

以下の .env ファイルを同階層に配置してください。

OPENAI_API_KEY=(ご自身のキーを記載してください)

なお、実際の案件では下記AzureChatOpenAIを利用しています。

from langchain.chat_models import AzureChatOpenAI # langchain 0.0.225

使い方の詳細は割愛します。

上手くいった方法

以下の点に注意することで上手くいきました。

  • ConversationChainで利用するプロンプトの入力値 historyinput を公式ドキュメント通りに利用すること。
  • システムテンプレートに独自の入力値を持たせたい場合は、Pythonの文字列フォーマットを使って値を入れてからSystemMessagePromptTemplateで利用すること。

参考①:ChatPromptTemplateとSystemMessagePromptTemplate.from_templateの利用

参考②:Memory機能を持たせたChatPromptTemplateの利用

参考③:ConversationTokenBufferMemoryの利用

参考④:ConversationChainの利用

from langchain.llms import OpenAI
from langchain.chains import ConversationChain
from langchain.memory import ConversationTokenBufferMemory
# 下記は参考サイトをベースに自分で取捨選択して記載
from langchain.prompts import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
)
# 環境変数の読み込み
from dotenv import load_dotenv
load_dotenv()
def answer_query(query, input_language, output_language):
    # ChatPromptTemplateとSystemMessagePromptTemplate.from_templateの利用(参考①)
    # Memory機能を持たせたChatPromptTemplateの利用(参考②)
    template="You are a helpful assistant that translates {input_language} to {output_language}."
    template_filled=template.format(input_language=input_language, output_language=output_language)

    chat_prompt = ChatPromptTemplate.from_messages([
    SystemMessagePromptTemplate.from_template(template_filled),
    MessagesPlaceholder(variable_name="history"),
    HumanMessagePromptTemplate.from_template("{input}")])

    # ConversationTokenBufferMemoryの利用(参考③)
    llm = OpenAI()
    # トークン数に応じて会話履歴の抽出数が調整されることを max_token_limit で確認する
    memory = ConversationTokenBufferMemory(llm=llm, max_token_limit=30, return_messages=True)
    memory.save_context({"input": "hi"}, {"output": "こんにちは"})
    memory.save_context({"input": "not much you"}, {"output": "私はあまりあなたを知りません"})

    # ConversationChainの利用(参考④)
    conversation = ConversationChain(prompt=chat_prompt, llm=llm, verbose=True, memory=memory)
    # LLMへ質問
    return conversation.predict(input=query)

# 質問を入力
answer=answer_query("Hi there!", input_language="English", output_language="日本語")
# 回答を表示
print(answer)

出力内容


> Entering new  chain...
Prompt after formatting:
System: You are a helpful assistant that translates English to 日本語.
Human: not much you
AI: 私はあまりあなたを知りません
Human: Hi there!

> Finished chain.

AI: やあ!

これらの出力から下記を確認し、意図した挙動が実現できたことを確認しました。

  • システムテンプレートに"English"、"日本語"という入力値が反映されている。
  • 会話履歴が反映されいる(verbose=Trueによって出力可能)。
  • max_token_limit=30 により、会話履歴の抽出数が調整されている(「hi」「こんにちは」の会話が削除されている)。

ハマりどころ

ポイント1:ConversationChainで利用するプロンプトの入力値を変えることができなかった

参考②のLLMChainではプロンプトの入力値を変更しても問題なく動作しました。 しかし、ConversationChainでは名前を変更できず、ここにハマりました。

実際に試した上手くいかない例1

上手くいった方法 からプロンプトの入力値 historyinput を別の名前に変更したところ、期待される値が見つからない旨のエラーが発生しました。

(省略)
    MessagesPlaceholder(variable_name="chat_history"), 
    HumanMessagePromptTemplate.from_template("{human_input}")])
(省略)
    # LLMへ質問
    return conversation.predict(human_input=query)

# 質問を入力
answer=answer_query("Hi there!", input_language="English", output_language="日本語")
# 回答を表示
print(answer)

ポイント2:入力値付きSystemMessagePromptTemplateをConversationChainに指定することができなかった

上手くいった方法 では template を単純な文字列 template_filled に変換して利用しています。 他のモジュールでは、複数の入力値を持つプロンプトテンプレートの扱い方が見受けられます。 今回の目標条件下での応用を試みましたが、入力値の挿入タイミングの特定が困難であり、適切な実装方法を見つけることができず、この点で大きくつまづきました。

以下は、複数の入力値を持ったままプロンプトテンプレートを扱うLLMChainの例です。

参考:LLMChainの入力例

from langchain.chains import LLMChain
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate
# Multiple inputs example
template = """Tell me a {adjective} joke about {subject}."""
prompt = PromptTemplate(template=template, input_variables=["adjective", "subject"])
llm_chain = LLMChain(prompt=prompt, llm=OpenAI(temperature=0))

print(llm_chain.predict(adjective="sad", subject="ducks"))

出力結果

Q: What did the duck say when his friend died?
A: Quack, quack, goodbye.

これをComversationChainとSystemMessagePromptTemplateを使って試行錯誤した記録は以下の通りです。

実際に試した上手くいかない例2

from langchain.llms import OpenAI
from langchain.chains import ConversationChain
from langchain.memory import ConversationTokenBufferMemory
from langchain.prompts import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
)
# 環境変数の読み込み
from dotenv import load_dotenv
load_dotenv()
def answer_query(query, input_language, output_language):
    template="You are a helpful assistant that translates {input_language} to {output_language}."

    chat_prompt = ChatPromptTemplate.from_messages([
    SystemMessagePromptTemplate.from_template(template),
    MessagesPlaceholder(variable_name="history"), 
    HumanMessagePromptTemplate.from_template("{input}")]) 

    llm = OpenAI()
    memory = ConversationTokenBufferMemory(llm=llm, max_token_limit=30, return_messages=True) 
    memory.save_context({"input": "hi"}, {"output": "こんにちは"})
    memory.save_context({"input": "not much you"}, {"output": "私はあまりあなたを知りません"})

    conversation = ConversationChain(prompt=chat_prompt, llm=llm, verbose=True, memory=memory)
    return conversation.predict(input_language=input_language, output_language=output_language, input=query)

answer=answer_query("Hi there!", input_language="English", output_language="日本語")
print(answer)

下記の2行を conversation = ConversationChain(prompt=chat_prompt, llm=llm, verbose=True, memory=memory) の直前に追記してみたり・・・もちろんエラーが発生します・・・。

    from langchain.prompts import PromptTemplate
    chat_prompt = PromptTemplate(template=chat_prompt, input_variables=["input_language", "output_language"])

紹介はしませんが、思いつくままに試してたくさん失敗しました。

まとめ

SystemMessagePromptTemplateに入力値を組み込みながら、ConversationChainで利用する方法は見つかりませんでした。 代わりに、Pythonの文字列フォーマットを使って値を入れてからSystemMessagePromptTemplateを利用する方法でなんとか対応できました。

感想

ConversationChainではなくLLM Chainを使ったり、ChatPromptTemplateではなくPromptTemplateを使ったりする選択肢もありました。 しかし、その変更が性能に影響を与える可能性があったので、できるところまで挑戦し続けました。

LangChainで色々なモジュールを組み合わせる際は、自分で一から試行錯誤するよりも、実装例を探してその通りに実施する方がスムーズだと感じました。 モジュールの細部を完全に理解すれば、自分で解決策を見つけることも可能かと思いますが、それはそれでかなりの労力が必要そうだったので、まずは色々と試行錯誤してみました。

今回は会話履歴についてご紹介しましたが、QAや要約タスクに関しても精度の向上を目指して日々努力しています。 機会があれば、その成果もシェアしたいと思います!