AI创想

标题: LangGraph使用 [打印本页]

作者: 米落枫    时间: 2 小时前
标题: LangGraph使用
作者:大得369
  1. pip install --upgrade langgraph langchain-openai
复制代码
  1. from langchain_core.prompts import ChatPromptTemplate
  2. from langchain_core.output_parsers import StrOutputParser
  3. from langchain_openai import ChatOpenAI
  4. from langchain_core.tools import tool
  5. from langgraph.graph import StateGraph, END
  6. from typing import TypedDict
  7. import os
  8. # ===================== 配置 =====================
  9. os.environ["DEEPSEEK_API_KEY"]="sk-"
  10. llm = ChatOpenAI(model="deepseek-v4-pro",
  11.     api_key=os.environ["DEEPSEEK_API_KEY"],
  12.     base_url="https://api.deepseek.com",
  13.     temperature=0.2,
  14.     max_tokens=4096,
  15.     streaming=True
  16. )
  17. parser = StrOutputParser()# ===================== MCP 工具 =====================
  18. @tool
  19. def mcp_create_file(file_path: str, content: str) -> str:
  20.     """MCP 文件创建工具"""
  21.     try:
  22.         with open(file_path, "w", encoding="utf-8") as f:
  23.             f.write(content)return f"[MCP] 文件已保存:{file_path}"
  24.     except Exception as e:
  25.         return f"[MCP] 错误:{str(e)}"# ===================== 定义状态(LangGraph 核心) =====================
  26. class WorkState(TypedDict):
  27.     user_task: str
  28.     analyse_result: str
  29.     code_result: str
  30.     optimize_result: str
  31.     summary_result: str
  32. # ===================== 提示词 =====================
  33. prompt_analyse = ChatPromptTemplate.from_messages([("system", "资深后端工程师,精准拆解开发需求,梳理实现思路"),
  34.     ("user", "需求内容:{user_task}")])
  35. chain_analyse = prompt_analyse | llm | parser
  36. prompt_code = ChatPromptTemplate.from_messages([("system", "根据需求思路编写完整可运行代码,规范整洁"),
  37.     ("user", "需求思路:{analyse_result}")])
  38. chain_code = prompt_code | llm | parser
  39. prompt_optimize = ChatPromptTemplate.from_messages([("system", "优化代码性能、可读性、异常处理"),
  40.     ("user", "原始代码:{code_result}")])
  41. chain_optimize = prompt_optimize | llm | parser
  42. prompt_summary = ChatPromptTemplate.from_messages([("system", "总结功能,并自动保存代码"),
  43.     ("user", "优化后代码:{optimize_result}")])
  44. chain_summary = prompt_summary | llm | parser
  45. # ===================== 流式输出 =====================
  46. def stream_output(chain, input_data):
  47.     result =""forchunkin chain.stream(input_data):
  48.         print(chunk, end="", flush=True)
  49.         result += chunk
  50.     print("\n")return result
  51. # ===================== LangGraph 节点函数 =====================
  52. def node_analyse(state: WorkState):
  53.     print("===== 1.需求解析中 =====\n")
  54.     res = stream_output(chain_analyse, {"user_task": state["user_task"]})return{"analyse_result": res}
  55. def node_code(state: WorkState):
  56.     print("===== 2.生成代码中 =====\n")
  57.     res = stream_output(chain_code, {"analyse_result": state["analyse_result"]})return{"code_result": res}
  58. def node_optimize(state: WorkState):
  59.     print("===== 3.代码优化中 =====\n")
  60.     res = stream_output(chain_optimize, {"code_result": state["code_result"]})return{"optimize_result": res}
  61. def node_summary(state: WorkState):
  62.     print("===== 4.最终总结 + 自动保存文件 =====\n")
  63.     res = stream_output(chain_summary, {"optimize_result": state["optimize_result"]})# MCP 保存
  64.     print(mcp_create_file.invoke({"file_path":"output_code2.html", "content": state["optimize_result"]}))return{"summary_result": res}# ===================== 构建 LangGraph 流程图 =====================
  65. builder = StateGraph(WorkState)# 添加节点
  66. builder.add_node("analyse", node_analyse)
  67. builder.add_node("code", node_code)
  68. builder.add_node("optimize", node_optimize)
  69. builder.add_node("summary", node_summary)# 流程连线
  70. builder.set_entry_point("analyse")
  71. builder.add_edge("analyse", "code")
  72. builder.add_edge("code", "optimize")
  73. builder.add_edge("optimize", "summary")
  74. builder.add_edge("summary", END)# 编译图
  75. graph = builder.compile()# ===================== 运行 =====================if __name__ =="__main__":
  76.     task ="写个漂亮的登录页面,html"
  77.     result = graph.invoke({"user_task": task})
  78.     print("\n✅ LangGraph 工作流执行完成!")
复制代码
带记忆的简单智能体(闲聊直接对话,开发需求自动走工作流):
import datetime
import json
import os
from typing import Dict, List, TypedDict

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph

# ===================== Configuration =====================
# Keep an already-exported DEEPSEEK_API_KEY; "sk-" is only a placeholder to
# replace with your own key (an unconditional assignment would overwrite it).
os.environ.setdefault("DEEPSEEK_API_KEY", "sk-")

# The OpenAI chat client is pointed at DeepSeek's base_url.
# NOTE(review): confirm "deepseek-v4-pro" is a model name your account can use.
llm = ChatOpenAI(
    model="deepseek-v4-pro",
    api_key=os.environ["DEEPSEEK_API_KEY"],
    base_url="https://api.deepseek.com",
    temperature=0.2,
    max_tokens=2048,
    streaming=True,
)
# Converts the model's message chunks into plain strings.
parser = StrOutputParser()

# Lightweight chain used for casual chat (messages not routed to the workflow).
chat_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "你是友好AI助手,简洁自然聊天,记住上下文"),
        ("user", "历史上下文:{memory}\n用户:{query}"),
    ]
)
chat_chain = chat_prompt | llm | parser

# ===================== Local memory (persisted to disk) =====================
MEMORY_PATH = "agent_memory.json"


def load_memory() -> List[Dict]:
    """Return the saved conversation records stored in MEMORY_PATH.

    Returns an empty list when the file does not exist yet or does not hold
    valid JSON (e.g. an empty or half-written file). A damaged memory file
    therefore degrades to "no history" instead of crashing every turn.
    """
    try:
        with open(MEMORY_PATH, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        return []
    except json.JSONDecodeError as e:
        print(f"[memory] ignoring unreadable {MEMORY_PATH}: {e}")
        return []

def save_memory(task: str, result: str):
    """Append one {"time", "task", "result"} record to the history in MEMORY_PATH.

    The whole history is rewritten on every call. It is written to a temporary
    file first and then swapped in with os.replace. An interrupted write then
    leaves the previous history intact instead of a truncated JSON file.
    """
    memory = load_memory()
    memory.append(
        {
            "time": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "task": task,
            "result": result,
        }
    )
    tmp_path = MEMORY_PATH + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(memory, f, ensure_ascii=False, indent=2)
    os.replace(tmp_path, MEMORY_PATH)

def get_memory_prompt(limit: int = 5) -> str:
    """Format the most recent `limit` memory records as prompt context.

    Each record becomes one "用户:<task> 助手:<result>" line. Returns
    "无历史记录" ("no history") when there is nothing to show.
    """
    memory = load_memory()
    # memory[-0:] would be the whole list, so a non-positive limit means none.
    recent = memory[-limit:] if limit > 0 else []
    if not recent:
        return "无历史记录"
    return "\n".join(f"用户:{i['task']} 助手:{i['result']}" for i in recent)

# Keywords that route a message to the development workflow.
DEV_KEYWORDS = frozenset(
    {"写代码", "编写", "html", "css", "js", "python", "java", "页面",
     "接口", "脚本", "功能", "程序", "搭建", "开发", "前端", "后端"}
)


def is_dev_task(text: str) -> bool:
    """Heuristically decide whether `text` is a programming/development request.

    True when any keyword in DEV_KEYWORDS occurs as a substring of the
    lower-cased text, so "HTML" matches "html" (and "json" also matches "js").
    """
    lowered = text.lower()  # lower-case once, not once per keyword
    return any(keyword in lowered for keyword in DEV_KEYWORDS)


# ===================== MCP file tool =====================
@tool
def mcp_create_file(file_path: str, content: str) -> str:
    """创建并保存代码文件到本地"""
    # The docstring above is kept verbatim: @tool uses it as the tool
    # description shown to the model.
    # Writes `content` to `file_path` as UTF-8 (overwriting an existing file)
    # and reports the outcome as a status string instead of raising.
    try:
        with open(file_path, "w", encoding="utf-8") as f:
            f.write(content)
    except Exception as e:  # best-effort tool: report the failure, don't crash
        return f"[MCP] 错误:{e}"
    return f"[MCP] 已保存:{file_path}"


# ===================== LangGraph state =====================
# Shared LangGraph state: the request, the formatted chat history, and the
# text produced by each pipeline stage (each node fills in one key).
WorkState = TypedDict(
    "WorkState",
    {
        "user_task": str,        # current user request
        "memory": str,           # history text passed into the analyse prompt
        "analyse_result": str,   # output of the "analyse" node
        "code_result": str,      # output of the "code" node
        "optimize_result": str,  # output of the "optimize" node
        "summary_result": str,   # output of the "summary" node
    },
)

# ===================== Workflow prompts =====================
def _build_chain(system_text, user_template):
    """Return (prompt, prompt | llm | parser) for one system + user message pair."""
    prompt = ChatPromptTemplate.from_messages(
        [("system", system_text), ("user", user_template)]
    )
    return prompt, prompt | llm | parser


# Stage 1: analyse the request, with the chat history injected via {memory}.
prompt_analyse, chain_analyse = _build_chain(
    "你有记忆,能记住历史对话。\n{memory}\n资深工程师,拆解需求", "需求:{user_task}"
)
# Stage 2: emit code only, from the analysis.
prompt_code, chain_code = _build_chain(
    "直接输出完整干净代码,无多余解释", "{analyse_result}"
)
# Stage 3: optimise the generated code.
prompt_optimize, chain_optimize = _build_chain(
    "优化代码,保持可运行,直接输出", "{code_result}"
)
# Stage 4: summarise the result (the file itself is written by the summary node).
prompt_summary, chain_summary = _build_chain(
    "总结结果,并保存文件", "{optimize_result}"
)

# ===================== Streaming output =====================
def stream_output(chain, input_data):
    """Stream `chain` on `input_data`, echo each chunk, and return the full text.

    Chunks are printed as they arrive (flushed so partial output shows up
    immediately), followed by a blank line. Assumes string chunks, as
    produced by the `| parser` (StrOutputParser) chains in this script.
    """
    chunks = []
    for chunk in chain.stream(input_data):
        print(chunk, end="", flush=True)
        chunks.append(chunk)  # join once at the end instead of quadratic +=
    print("\n")
    return "".join(chunks)

# ===================== LangGraph nodes =====================
def node_analyse(state: WorkState):
    """Stage 1: stream the analysis of state["user_task"], given state["memory"].

    Returns a partial state update containing only "analyse_result".
    """
    print("===== 1. 分析需求 =====")
    res = stream_output(
        chain_analyse,
        {"user_task": state["user_task"], "memory": state["memory"]},
    )
    return {"analyse_result": res}

def node_code(state: WorkState):
    """Stage 2: stream code generated from state["analyse_result"].

    Returns a partial state update containing only "code_result".
    """
    print("===== 2. 生成代码 =====")
    res = stream_output(chain_code, {"analyse_result": state["analyse_result"]})
    return {"code_result": res}

def node_optimize(state: WorkState):
    """Stage 3: stream an optimised version of state["code_result"].

    Returns a partial state update containing only "optimize_result".
    """
    print("===== 3. 优化代码 =====")
    res = stream_output(chain_optimize, {"code_result": state["code_result"]})
    return {"optimize_result": res}

def _extract_single_code_block(text):
    """Return the body of the only Markdown fenced code block in `text`.

    If `text` does not contain exactly one fenced block, it is returned
    unchanged, so no content is ever dropped.
    """
    fence = "```"
    if text.count(fence) != 2:
        return text
    start = text.index(fence) + len(fence)
    end = text.index(fence, start)
    inner = text[start:end]
    newline = inner.find("\n")
    if newline == -1:
        # ```...``` on a single line is inline code, not a block.
        return text
    # Skip the info string (e.g. "html") on the opening-fence line.
    return inner[newline + 1:]


def node_summary(state: WorkState):
    """Stage 4: stream a summary of the optimised code, then save that code.

    The code is written to output.html through the MCP tool. When the model
    wrapped it in a single Markdown fence, only the fenced code is saved, so
    the file does not start with a "```html" line.
    Returns a partial state update containing only "summary_result".
    """
    print("===== 4. 总结 + 保存 =====")
    res = stream_output(chain_summary, {"optimize_result": state["optimize_result"]})
    # NOTE(review): the file name is always output.html, even when the request
    # was for e.g. a Python script; consider deriving the extension.
    code = _extract_single_code_block(state["optimize_result"])
    print(mcp_create_file.invoke({"file_path": "output.html", "content": code}))
    return {"summary_result": res}


# ===================== Build the flow graph =====================
builder = StateGraph(WorkState)
builder.add_node("analyse", node_analyse)
builder.add_node("code", node_code)
builder.add_node("optimize", node_optimize)
builder.add_node("summary", node_summary)

# Linear pipeline: analyse -> code -> optimize -> summary -> END.
builder.set_entry_point("analyse")
builder.add_edge("analyse", "code")
builder.add_edge("code", "optimize")
builder.add_edge("optimize", "summary")
builder.add_edge("summary", END)

graph = builder.compile()

# ===================== Interactive loop (routes chat vs. dev requests) =====================
# The guard must be on its own line; when it is glued to the comment above,
# the indented body below becomes an IndentationError.
if __name__ == "__main__":
    print("✅ AI 编程助手已启动(输入 exit 退出)")
    print("✅ 闲聊自动聊天,开发需求自动走工作流\n")
    while True:
        try:
            user_input = input("你:")
        except (EOFError, KeyboardInterrupt):
            # Ctrl+D / Ctrl+C at the prompt: leave the same way as "exit".
            print("\n再见!")
            break
        if user_input.strip().lower() in ["exit", "quit", "q"]:
            print("再见!")
            break
        if not user_input.strip():
            continue  # ignore blank lines instead of sending empty requests

        mem_text = get_memory_prompt()
        if is_dev_task(user_input):
            # Development request: run the full analyse/code/optimize/summary graph.
            res_data = graph.invoke({"user_task": user_input, "memory": mem_text})
            final_reply = res_data["summary_result"]
        else:
            # Casual chat: a single streamed reply from the chat chain.
            print("助手:", end="", flush=True)
            final_reply = stream_output(chat_chain, {"memory": mem_text, "query": user_input})

        save_memory(user_input, final_reply)
        # NOTE(review): this line is cut off in the original post; a blank
        # separator line between turns is assumed.
        print()




欢迎光临 AI创想 (http://llms-ai.com/) Powered by Discuz! X3.4