【LangGraph】複数LLMモデルをChainで実装する方法

当ページには広告が含まれています。

こんにちは、DXCEL WAVEの運営者(@dxcelwave)です!

こんな方におすすめ!
  • LangGraphを活用した開発に興味がある!
  • LangGraphで複数のLLMをChain方式で実装する方法が知りたい!
目次

【LangGraph】実装イメージ

当記事では上図のように2つのLLMをChainで繋げた処理をLangGraphで実装します。

model_1、model_2はそれぞれ次のような入出力をできるようにプログラミングします。

モデル名入力出力
model_1あるお題お題に対するダジャレ
model_2model_1の出力結果(ダジャレ)ダジャレの面白さ評価レビュー

【Python】LangGraphを用いた複数Chainの実装

それでは実際にPythonを記述しながらLangGraphでChainを構築していきます。

ライブラリ

はじめに、以下のライブラリを読み込みます。今回の例ではChatGPTのAPI_KEYを利用します。

# State
from typing_extensions       import TypedDict,Annotated

# LangChain
from IPython.display          import Image, display
from langchain_openai         import ChatOpenAI
from langchain.prompts        import PromptTemplate
from langchain.schema         import SystemMessage, HumanMessage,AIMessage

# LangGraph
from langgraph.graph.message  import add_messages
from langgraph.graph          import StateGraph, START, END

# ChatGPT API Key
API_KEY = "...."

Stateの作成

今回の例では、次のようなStateを用意します。

# ================================
# State
# ================================

class State(TypedDict,total=False):
    question:str                               # 質問内容
    model_1_output:str                         # モデル(1)の出力結果
    model_2_output:dict                        # モデル(2)の出力結果
    messages: Annotated[list, add_messages]    # LLM同士の会話履歴

ユーザーのお題を保持するquestion、各種モデルの出力結果であるmodel_1_outputmodel_2_output、すべての会話履歴を保持するmessagesを用いる仕様です。

LLMの作成

後述のNodeで実行するLLMを構築します。以下を記述しましょう。

# ================================
# LLM
# ================================

llm = ChatOpenAI(
        model       = "gpt-4.1",   
        temperature = 0,
        api_key     = API_KEY
        )

# 実行
# llm.invoke("こんにちは").content

# 出力イメージ
# こんにちは! 今日はどんなご用件でしょうか?お手伝いできることがあれば教えてください。

Nodeの作成

LangGraphの処理の中心となるNode(今回はmodel_1、model_2)を作成していきます。

model_1

次のようなコードを記述します。

def model_1(state: State) -> str:    
    # ============================
    # プロンプト
    # ============================
    text_prompt = \
    """
    お題:{question}
    """
    
    prompt = PromptTemplate.from_template(text_prompt).format(question=state["question"])
    
    input_messages = [
        SystemMessage(content = "いただいたお題をもとに面白い気の利いたダジャレを出力して下さい。出力はダジャレの部分だけ出力すること"),
        HumanMessage(content  = prompt)
        ]
    
    # ============================
    # 実行
    # ============================
    print("\n\n==== model_1 処理開始 =======")
    
    response = llm.invoke(input_messages)
    
    print(f"input: {state['question']}")
    print(f"output: {response.content}\n\n")
    # ============================
    # State更新
    # ============================
    print("State更新: \n")
    new_state =  {
            "question": state["question"],
            "model_1_output": response.content,
            "model_2_output": "",
            "messages":[response]
            }

    print(new_state)
    print("\n==== model_1 処理終了 =======")
    return new_state

model_1の動作は以下のコードで確認できます。

# ============================
# 実行
# ============================
initial_state =  {
        "question": "布団",
        }

model_1(initial_state)

# ============================
# 出力イメージ
# ============================

# {'question': '布団', 
#  'model_1_output': '布団が吹っ飛んだ!',
# }

model_2

次のようなコードを記述します。

def model_2(state: State) -> str:    
    # ============================
    # プロンプト
    # ============================
    text_prompt = \
    """
    ダジャレ:{content}
    """

    # 入力にはmodel_1の出力結果を渡す
    prompt = PromptTemplate.from_template(text_prompt).format(content=state["model_1_output"])
    
    input_messages = [
        SystemMessage(content = "あなたはダジャレが面白いか判断する役割を持ちます。0~10段階でどれだけダジャレが面白いか評価して下さい"),
        HumanMessage(content  = prompt)
        ]
    
    # ============================
    # 実行
    # ============================
    print("\n\n==== model_2 処理開始 =======")
    
    response = llm.invoke(input_messages)
    
    print(f"input: {state['model_1_output']}")
    print(f"output: {response.content}\n\n")
    # ============================
    # State更新
    # ============================
    print("State更新: \n")

    # 会話履歴を取得・更新
    state_message = state["messages"]
    state_message.append(response)
    
    new_state =  {
            "model_2_output": response.content,
            "messages":state_message
            }

    print(new_state)
    print("\n==== model_2 処理終了 =======")
    return new_state

model_2の動作は以下のコードで確認できます。

# ============================
# 実行
# ============================
initial_state =  {
        "question": "布団",
        "model_1_output":"布団が吹っ飛んだ!",
        "messages":[],
        }

model_2(initial_state)

# ============================
# 出力イメージ
# ============================

# {'model_2_output': '評価: 6/10 
#                     理由: 「布団が吹っ飛んだ!」は日本の定番ダジャレの一つで、語呂も良く、初めて聞く人にはクスッと笑える面白さがあります。
#                           ただし、あまりにも有名で使い古されているため、新鮮味や意外性は少なめです。そのため、平均よりやや上の6点と評価しました。}

Graphの作成

前述で作成したState, Nodeを用いてGraphを構築します。以下のコードを実行しましょう。

# ================================
# Graph
# ================================

graph = StateGraph(State)
graph.add_node("model_1", model_1)
graph.add_node("model_2", model_2)
graph.add_edge(START, "model_1")
graph.add_edge("model_1","model_2")
graph.add_edge("model_2",END)
app = graph.compile()

# 可視化
display(Image(app.get_graph().draw_mermaid_png()))

実行

最後に作成したGraphをもとに処理を実行してみましょう。

# ================================
# 実行
# ================================

initial_state =  {
        "question": "布団",
        }

app.invoke(initial_state)

【まとめ】LangGraphのコードすべて

今回紹介したLangGraphのコードを以下にまとめて記載します。

コード

# ================================
# ライブラリ
# ================================

# State
from typing_extensions       import TypedDict,Annotated

# LangChain
from IPython.display          import Image, display
from langchain_openai         import ChatOpenAI
from langchain.prompts        import PromptTemplate
from langchain.schema         import SystemMessage, HumanMessage,AIMessage

# LangGraph
from langgraph.graph.message  import add_messages
from langgraph.graph          import StateGraph, START, END

# ChatGPT API Key
API_KEY = "...."


# ================================
# State
# ================================

class State(TypedDict,total=False):
    question:str                               # 質問内容
    model_1_output:str                         # モデル(1)の出力結果
    model_2_output:dict                        # モデル(2)の出力結果
    messages: Annotated[list, add_messages]    # LLM同士の会話履歴

# ================================
# LLM
# ================================

llm = ChatOpenAI(
        model       = "gpt-4.1",   
        temperature = 0,
        api_key     = API_KEY
        )

# ================================
# Node
# ================================

def model_1(state: State) -> str:    
    # ============================
    # プロンプト
    # ============================
    text_prompt = \
    """
    お題:{question}
    """
    
    prompt = PromptTemplate.from_template(text_prompt).format(question=state["question"])
    
    input_messages = [
        SystemMessage(content = "いただいたお題をもとに面白い気の利いたダジャレを出力して下さい。出力はダジャレの部分だけ出力すること"),
        HumanMessage(content  = prompt)
        ]
    
    # ============================
    # 実行
    # ============================
    print("\n\n==== model_1 処理開始 =======")
    
    response = llm.invoke(input_messages)
    
    print(f"input: {state['question']}")
    print(f"output: {response.content}\n\n")
    # ============================
    # State更新
    # ============================
    print("State更新: \n")
    new_state =  {
            "question": state["question"],
            "model_1_output": response.content,
            "model_2_output": "",
            "messages":[response]
            }

    print(new_state)
    print("\n==== model_1 処理終了 =======")
    return new_state


def model_2(state: State) -> str:    
    # ============================
    # プロンプト
    # ============================
    text_prompt = \
    """
    ダジャレ:{content}
    """

    # 入力にはmodel_1の出力結果を渡す
    prompt = PromptTemplate.from_template(text_prompt).format(content=state["model_1_output"])
    
    input_messages = [
        SystemMessage(content = "あなたはダジャレが面白いか判断する役割を持ちます。0~10段階でどれだけダジャレが面白いか評価して下さい"),
        HumanMessage(content  = prompt)
        ]
    
    # ============================
    # 実行
    # ============================
    print("\n\n==== model_2 処理開始 =======")
    
    response = llm.invoke(input_messages)
    
    print(f"input: {state['model_1_output']}")
    print(f"output: {response.content}\n\n")
    # ============================
    # State更新
    # ============================
    print("State更新: \n")

    # 会話履歴を取得・更新
    state_message = state["messages"]
    state_message.append(response)
    
    new_state =  {
            "model_2_output": response.content,
            "messages":state_message
            }

    print(new_state)
    print("\n==== model_2 処理終了 =======")
    return new_state

# ================================
# Graph
# ================================

graph = StateGraph(State)
graph.add_node("model_1", model_1)
graph.add_node("model_2", model_2)
graph.add_edge(START, "model_1")
graph.add_edge("model_1","model_2")
graph.add_edge("model_2",END)
app = graph.compile()

# 可視化
display(Image(app.get_graph().draw_mermaid_png()))

実行用

# ================================
# 実行
# ================================

initial_state =  {
        "question": "布団",
        }

app.invoke(initial_state)

最後に

この記事が気に入ったら
フォローしてね!

本記事をシェア!
目次