- 大家好,我是同学小张,日常分享AI知识和实战案例
- 欢迎 点赞 + 关注 👏,持续学习,持续干货输出。
- 一起交流💬,一起进步💪。
- 微信公众号也可搜【同学小张】 🙏
本站文章一览:
书接上文(【AI Agent系列】【LangGraph】0. 快速上手:协同LangChain,LangGraph帮你用图结构轻松构建多智能体应用),前面我们了解了 LangGraph 的概念和基本构造方法,今天我们来看下 LangGraph 构造中的进阶用法:给边加个条件 - 条件分支(Conditional edges)。
LangGraph 构造的是个图的数据结构,有节点(node) 和边(edge),那它的边也可以是带条件的。如何给边加入条件呢?可以通过 add_conditional_edges
函数添加带条件的边。
1. 完整代码及运行
废话不多说,先上完整代码,和运行结果。先跑起来看看效果再说。
from langchain_openai import ChatOpenAI from langchain_core.messages import HumanMessage, BaseMessage from langgraph.graph import END, MessageGraph import json from langchain_core.messages import ToolMessage from langchain_core.tools import tool from langchain_core.utils.function_calling import convert_to_openai_tool from typing import List @tool def multiply(first_number: int, second_number: int): """Multiplies two numbers together.""" return first_number * second_number model = ChatOpenAI(temperature=0) model_with_tools = model.bind(tools=[convert_to_openai_tool(multiply)]) graph = MessageGraph() def invoke_model(state: List[BaseMessage]): return model_with_tools.invoke(state) graph.add_node("oracle", invoke_model) def invoke_tool(state: List[BaseMessage]): tool_calls = state[-1].additional_kwargs.get("tool_calls", []) multiply_call = None for tool_call in tool_calls: if tool_call.get("function").get("name") == "multiply": multiply_call = tool_call if multiply_call is None: raise Exception("No adder input found.") res = multiply.invoke( json.loads(multiply_call.get("function").get("arguments")) ) return ToolMessage( tool_call_id=multiply_call.get("id"), content=res ) graph.add_node("multiply", invoke_tool) graph.add_edge("multiply", END) graph.set_entry_point("oracle") def router(state: List[BaseMessage]): tool_calls = state[-1].additional_kwargs.get("tool_calls", []) if len(tool_calls): return "multiply" else: return "end" graph.add_conditional_edges("oracle", router, { "multiply": "multiply", "end": END, }) runnable = graph.compile() response = runnable.invoke(HumanMessage("What is 123 * 456?")) print(response)
运行结果如下:
2. 代码详解
下面对上面的代码进行详细解释。
2.1 add_conditional_edges
首先,我们知道了可以通过 add_conditional_edges
来对边进行条件添加。这部分代码如下:
graph.add_conditional_edges("oracle", router, { "multiply": "multiply", "end": END, })
add_conditional_edges
接收三个参数:
- 第一个为这条边的第一个node的名称
- 第二个为这条边的条件
- 第三个为条件返回结果的映射(根据条件结果映射到相应的node)
如上面的代码,意思就是往 “oracle” node上添加边,这个node有两条边,一条是往“multiply” node上走,一条是往“END”上走。怎么决定往哪个方向去:条件是 router(后面解释),如果 router 返回的是“multiply”,则往“multiply”方向走,如果 router 返回的是 “end”,则走“END”。
来看下这个函数的源码:
def add_conditional_edges( self, start_key: str, condition: Callable[..., str], conditional_edge_mapping: Optional[Dict[str, str]] = None, ) -> None: if self.compiled: logger.warning( "Adding an edge to a graph that has already been compiled. This will " "not be reflected in the compiled graph." ) if start_key not in self.nodes: raise ValueError(f"Need to add_node `{start_key}` first") if iscoroutinefunction(condition): raise ValueError("Condition cannot be a coroutine function") if conditional_edge_mapping and set( conditional_edge_mapping.values() ).difference([END]).difference(self.nodes): raise ValueError( f"Missing nodes which are in conditional edge mapping. Mapping " f"contains possible destinations: " f"{list(conditional_edge_mapping.values())}. Possible nodes are " f"{list(self.nodes.keys())}." ) self.branches[start_key].append(Branch(condition, conditional_edge_mapping))
重点是这一句:self.branches[start_key].append(Branch(condition, conditional_edge_mapping))
,给当前node添加分支Branch。
2.2 条件 router
条件代码如下:判断执行结果中是否有 tool_calls 参数,如果有,则返回"multiply",没有,则返回“end”。
def router(state: List[BaseMessage]): tool_calls = state[-1].additional_kwargs.get("tool_calls", []) if len(tool_calls): return "multiply" else: return "end"
2.3 各node的定义
(1)起始node:oracle
@tool def multiply(first_number: int, second_number: int): """Multiplies two numbers together.""" return first_number * second_number model = ChatOpenAI(temperature=0) model_with_tools = model.bind(tools=[convert_to_openai_tool(multiply)]) graph = MessageGraph() def invoke_model(state: List[BaseMessage]): return model_with_tools.invoke(state) graph.add_node("oracle", invoke_model)
这个node是一个带有Tools 的 ChatOpenAI。在LangChain中使用Tools的详细教程请看这篇文章:【AI大模型应用开发】【LangChain系列】5. 实战LangChain的智能体Agents模块。简单解释就是:这个node的执行结果,将返回是否应该使用绑定的Tools。
(2)multiply
def invoke_tool(state: List[BaseMessage]): tool_calls = state[-1].additional_kwargs.get("tool_calls", []) multiply_call = None for tool_call in tool_calls: if tool_call.get("function").get("name") == "multiply": multiply_call = tool_call if multiply_call is None: raise Exception("No adder input found.") res = multiply.invoke( json.loads(multiply_call.get("function").get("arguments")) ) return ToolMessage( tool_call_id=multiply_call.get("id"), content=res ) graph.add_node("multiply", invoke_tool)
这个node的作用就是执行Tools。
2.4 总体流程
如果觉得本文对你有帮助,麻烦点个赞和关注呗 ~~~
- 大家好,我是 同学小张,日常分享AI知识和实战案例
- 欢迎 点赞 + 关注 👏,持续学习,持续干货输出。
- 一起交流💬,一起进步💪。
- 微信公众号也可搜【同学小张】 🙏
本站文章一览: