Spaces:
Paused
Paused
| from typing import Dict, List, TypedDict, Sequence | |
| from langgraph.graph import StateGraph, END | |
| from langchain.schema import StrOutputParser | |
| from langchain.schema.runnable import RunnablePassthrough | |
| from langchain_community.tools.tavily_search import TavilySearchResults | |
| import models | |
| import prompts | |
| import json | |
| from operator import itemgetter | |
| from langgraph.errors import GraphRecursionError | |
| ####################################### | |
| ### Research Team Components ### | |
| ####################################### | |
class ResearchState(TypedDict):
    """State schema shared by every node of the research-team subgraph."""

    # Ordered names of the nodes visited so far (each node appends itself).
    workflow: List[str]
    # Subject being researched.
    topic: str
    # Findings keyed by source, e.g. "qdrant_results", "web_search_results".
    research_data: Dict[str, str]
    # Routing decision written by research_supervisor.
    next: str
    # Status message the research supervisor sends upward.
    message_to_manager: str
    # Instruction received from the overall manager.
    message_from_manager: str
| # | |
# Research Chains and Tools
| # | |
# Retrieval-augmented answer over the Qdrant store.
# Input:  {"topic": <str>}
# Output: {"response": <LLM answer str>, "context": <retrieved documents>}
# The assign() step adds the retrieved documents next to "topic". Starting the
# pipeline with a Runnable (rather than a dict literal) also matters: a
# `{...} | {...}` between two dict literals would be a plain dict merge in
# Python, not a chain.
qdrant_research_chain = (
    RunnablePassthrough.assign(
        context=itemgetter("topic") | models.compression_retriever
    )
    | {
        "response": prompts.research_query_prompt | models.gpt4o_mini | StrOutputParser(),
        "context": itemgetter("context"),
    }
)
# Web-search branch.
#   query_chain   : LLM writes a search query from whatever variables
#                   prompts.search_query_prompt consumes (web_search passes
#                   "topic" and "qdrant_results").
#   tavily_simple : runs Tavily (top 3 hits) on {"query": ...} and has the LLM
#                   summarise the hits via prompts.tavily_prompt.
#   tavily_chain  : query_chain feeding tavily_simple; returns a plain string.
tavily_tool = TavilySearchResults(max_results=3)
query_chain = prompts.search_query_prompt | models.gpt4o_mini | StrOutputParser()
tavily_simple = (
    {"tav_results": tavily_tool}
    | prompts.tavily_prompt
    | models.gpt4o_mini
    | StrOutputParser()
)
tavily_chain = {"query": query_chain} | tavily_simple
# Research-team supervisor: prompt -> GPT-4o -> raw text. The text is parsed
# line-by-line by research_supervisor() for its routing decision.
research_supervisor_chain = (
    prompts.research_supervisor_prompt
    | models.gpt4o
    | StrOutputParser()
)
| # | |
# Research Node Defs
| # | |
def query_qdrant(state: ResearchState) -> ResearchState:
    """Research node: answer the topic using the Qdrant retrieval chain.

    Only the chain's ``"response"`` text is kept, under
    ``research_data["qdrant_results"]``. The full chain output, including the
    retrieved ``"context"``, is printed for debugging. The node name is
    appended to ``workflow`` and the (mutated) state is returned.
    """
    answer = qdrant_research_chain.invoke({"topic": state["topic"]})
    print(answer)
    state["research_data"]["qdrant_results"] = answer["response"]

    trail = state['workflow']
    trail.append("query_qdrant")
    print(trail)
    return state
def web_search(state: ResearchState) -> ResearchState:
    """Research node: run the Tavily web-search chain for the topic.

    Any earlier Qdrant answer is passed along so the query generator can build
    on it. When there is none, a fixed placeholder sentence is used instead.
    The summarised search result is stored under
    ``research_data["web_search_results"]``.
    """
    subject = state["topic"]
    findings = state["research_data"]
    prior_answer = findings.get("qdrant_results", "No previous results available.")

    summary = tavily_chain.invoke({"topic": subject, "qdrant_results": prior_answer})
    print(summary)
    findings["web_search_results"] = summary

    trail = state['workflow']
    trail.append("web_search")
    print(trail)
    return state
def research_supervisor(state):
    """Research-team supervisor node: choose the next research step.

    The supervisor LLM reply is parsed line by line for two labelled fields:

    * ``Next Action: <value>``                -> ``state['next']`` (routing key
      for the research subgraph's conditional edge)
    * ``Message to project manager: <text>``  -> ``state['message_to_manager']``

    Lines are whitespace-stripped before matching and the space after the
    colon is optional. Indented or padded model output, which is common, is
    therefore still recognised instead of silently leaving a stale routing
    value behind. If a label is absent, the corresponding state field keeps
    its previous value.
    """
    supervisor_result = research_supervisor_chain.invoke({
        "message_from_manager": state["message_from_manager"],
        "collected_data": state["research_data"],
        "topic": state['topic'],
    })
    print(supervisor_result)

    next_label = "Next Action:"
    message_label = "Message to project manager:"
    for raw_line in supervisor_result.splitlines():
        line = raw_line.strip()
        if line.startswith(next_label):
            state['next'] = line[len(next_label):].strip()
        elif line.startswith(message_label):
            state['message_to_manager'] = line[len(message_label):].strip()

    state['workflow'].append("research_supervisor")
    print(state['workflow'])
    return state
def research_end(state):
    """Terminal node of the research subgraph: log the visit and return state."""
    trail = state['workflow']
    trail.append("research_end")
    print(trail)
    return state
| ####################################### | |
| ### Writing Team Components ### | |
| ####################################### | |
class WritingState(TypedDict):
    """State schema shared by every node of the writing-team subgraph."""

    # Ordered names of the nodes visited so far (each node appends itself).
    workflow: List[str]
    # Subject of the post.
    topic: str
    # Research findings handed over from the research team.
    research_data: Dict[str, str]
    # Every draft produced so far, newest last. Nodes call .append() on it,
    # so it is annotated as a mutable List rather than a read-only Sequence.
    draft_posts: List[str]
    # Draft accepted by post_review (empty until one is accepted).
    final_post: str
    # Routing decision written by writing_supervisor.
    next: str
    # Status message the writing supervisor sends upward.
    message_to_manager: str
    # Instruction received from the overall manager.
    message_from_manager: str
    # Reviewer feedback on the latest draft.
    review_comments: str
    # NOTE(review): not read or written by any node in this module.
    style_checked: bool
| # | |
| # Writing Chains | |
| # | |
# Writing-team chains: each is prompt -> chat model -> plain string.
# Drafting uses models.gpt4o_mini; supervision, editing and review use
# models.gpt4o.
writing_supervisor_chain = prompts.writing_supervisor_prompt | models.gpt4o | StrOutputParser()
post_creation_chain = prompts.post_creation_prompt | models.gpt4o_mini | StrOutputParser()
post_editor_chain = prompts.post_editor_prompt | models.gpt4o | StrOutputParser()
post_review_chain = prompts.post_review_prompt | models.gpt4o | StrOutputParser()
| # | |
| # Writing Node Defs | |
| # | |
def post_creation(state):
    """Writing node: produce a new draft post.

    The drafting chain sees the topic, the research data, all previous drafts
    and the latest review comments. Its output is appended to
    ``draft_posts``.
    """
    chain_input = {
        "topic": state['topic'],
        "drafts": state['draft_posts'],
        "collected_data": state["research_data"],
        "review_comments": state['review_comments'],
    }
    new_draft = post_creation_chain.invoke(chain_input)
    print(new_draft)
    state['draft_posts'].append(new_draft)

    trail = state['workflow']
    trail.append("post_creation")
    print(trail)
    return state
def post_editor(state):
    """Writing node: revise the newest draft against the style guide.

    The latest entry of ``draft_posts`` is sent with
    ``prompts.style_guide_text`` and the current review comments to the
    editor chain. The edited text is appended as a new draft.
    """
    latest = state['draft_posts'][-1]
    edited = post_editor_chain.invoke({
        "current_draft": latest,
        "styleguide": prompts.style_guide_text,
        "review_comments": state['review_comments'],
    })
    print(edited)
    state['draft_posts'].append(edited)

    trail = state['workflow']
    trail.append("post_editor")
    print(trail)
    return state
def post_review(state):
    """Writing node: have the reviewer LLM grade the newest draft.

    The reviewer chain is expected to reply with a JSON object containing the
    keys ``"Comments on current draft"`` and ``"Draft Acceptable"``. Chat
    models frequently wrap JSON in Markdown code fences or add a sentence
    around it, so only the outermost ``{...}`` span of the reply is parsed.

    Effects on ``state``:
      * ``review_comments`` is overwritten with the reviewer's comments.
      * ``final_post`` is set to the newest draft when the reviewer accepts it
        (``"yes"`` in any letter case, or a JSON ``true``).
      * ``"post_review"`` is appended to ``workflow``.

    Raises:
        json.JSONDecodeError: if the reply contains no parseable JSON object.
        KeyError: if either expected key is missing from the parsed object.
    """
    print("post_review node")
    current_draft = state['draft_posts'][-1]
    results = post_review_chain.invoke(
        {"current_draft": current_draft, "styleguide": prompts.style_guide_text}
    )
    print(results)

    payload = results.strip()
    # Drop Markdown fences / surrounding prose: keep the outermost {...} span.
    start, end = payload.find("{"), payload.rfind("}")
    if start != -1 and end > start:
        payload = payload[start:end + 1]
    data = json.loads(payload)

    state['review_comments'] = data["Comments on current draft"]
    verdict = data["Draft Acceptable"]
    if verdict is True or str(verdict).strip().lower() == "yes":
        state['final_post'] = current_draft

    state['workflow'].append("post_review")
    print(state['workflow'])
    return state
def writing_end(state):
    """Terminal node of the writing subgraph: log the visit and return state."""
    print("writing_end node")
    trail = state['workflow']
    trail.append("writing_end")
    print(trail)
    return state
def writing_supervisor(state):
    """Writing-team supervisor node: decide whether to draft again or finish.

    The supervisor LLM reply is parsed line by line for two labelled fields:

    * ``Next Action: <value>``                -> ``state['next']`` (routing key
      for the writing subgraph's conditional edge)
    * ``Message to project manager: <text>``  -> ``state['message_to_manager']``

    Lines are whitespace-stripped before matching and the space after the
    colon is optional, so indented or padded model output is still
    recognised. If a label is absent, the corresponding state field keeps its
    previous value.
    """
    print("writing_supervisor node")
    supervisor_result = writing_supervisor_chain.invoke({
        "message_from_manager": state['message_from_manager'],
        "topic": state['topic'],
        "drafts": state['draft_posts'],
        "final_draft": state['final_post'],
        "review_comments": state['review_comments'],
    })
    print(supervisor_result)

    next_label = "Next Action:"
    message_label = "Message to project manager:"
    for raw_line in supervisor_result.splitlines():
        line = raw_line.strip()
        if line.startswith(next_label):
            state['next'] = line[len(next_label):].strip()
        elif line.startswith(message_label):
            state['message_to_manager'] = line[len(message_label):].strip()

    state['workflow'].append("writing_supervisor")
    print(state['workflow'])
    return state
| ####################################### | |
| ### Overarching Graph Components ### | |
| ####################################### | |
class State(TypedDict):
    """State schema of the top-level graph (manager plus both team subgraphs)."""

    # Ordered names of the nodes visited so far (each node appends itself).
    workflow: List[str]
    # Topic extracted from the user's request by the overall supervisor.
    topic: str
    # Research findings keyed by source.
    research_data: Dict[str, str]
    # Every draft produced so far, newest last. Nodes call .append() on it,
    # so it is annotated as a mutable List rather than a read-only Sequence.
    draft_posts: List[str]
    # Draft accepted by the reviewer; returned to the caller.
    final_post: str
    # Routing key used inside the team subgraphs.
    next: str
    # The user's original request.
    user_input: str
    # Status message from a team supervisor to the overall manager.
    message_to_manager: str
    # Instruction from the overall manager to a team supervisor.
    message_from_manager: str
    # Team that most recently handed control back to the manager.
    last_active_team: str
    # Routing key for the top-level conditional edge.
    next_team: str
    # Reviewer feedback on the latest draft.
    review_comments: str
| # | |
| # Complete Graph Chains | |
| # | |
# Overall manager: prompt -> GPT-4o -> raw text. The text is parsed
# line-by-line by overall_supervisor() for topic, routing and instructions.
overall_supervisor_chain = (
    prompts.overall_supervisor_prompt
    | models.gpt4o
    | StrOutputParser()
)
| # | |
| # Complete Graph Node defs | |
| # | |
def overall_supervisor(state):
    """Top-level manager node: pick the next team and brief its supervisor.

    Before calling the LLM, the team dispatched on the previous pass (the
    previous ``next_team``) is recorded in ``last_active_team``. Each team
    subgraph routes straight back here, so that team is the one that just
    finished. Nothing else in this module writes ``last_active_team``;
    without this step the prompt would always receive an empty value.

    The manager LLM reply is parsed line by line for three labelled fields:

    * ``Next Action: <team>``           -> ``state['next_team']``
    * ``Extracted Topic: <topic>``      -> ``state['topic']``
    * ``Message to supervisor: <text>`` -> ``state['message_from_manager']``

    Lines are whitespace-stripped before matching and the space after the
    colon is optional. If a label is absent, the field keeps its value.
    """
    previous_team = state.get('next_team')
    # Initial state may hold "" (or a non-string placeholder); only real team
    # names are recorded.
    if isinstance(previous_team, str) and previous_team:
        state['last_active_team'] = previous_team

    supervisor_result = overall_supervisor_chain.invoke({
        "query": state["user_input"],
        "message_to_manager": state['message_to_manager'],
        "last_active_team": state['last_active_team'],
        "final_post": state['final_post'],
    })
    print(supervisor_result)

    next_label = "Next Action:"
    topic_label = "Extracted Topic:"
    message_label = "Message to supervisor:"
    for raw_line in supervisor_result.splitlines():
        line = raw_line.strip()
        if line.startswith(next_label):
            state['next_team'] = line[len(next_label):].strip()
        elif line.startswith(topic_label):
            state['topic'] = line[len(topic_label):].strip()
        elif line.startswith(message_label):
            state['message_from_manager'] = line[len(message_label):].strip()

    state['workflow'].append("overall_supervisor")
    print(state['workflow'])
    return state
| ####################################### | |
| ### Graph structures ### | |
| ####################################### | |
| # | |
# Research Graph Nodes
| # | |
def _build_research_graph():
    """Assemble the (uncompiled) research-team StateGraph.

    Entry point is ``research_supervisor``. It routes on ``state["next"]`` to
    ``query_qdrant`` or ``web_search`` (both report back to the supervisor),
    or to ``research_end`` on ``"FINISH"``.
    """
    graph = StateGraph(ResearchState)

    # Research Graph Nodes
    for node_name, node_fn in (
        ("query_qdrant", query_qdrant),
        ("web_search", web_search),
        ("research_supervisor", research_supervisor),
        ("research_end", research_end),
    ):
        graph.add_node(node_name, node_fn)

    # Research Graph Edges
    graph.set_entry_point("research_supervisor")
    for worker in ("query_qdrant", "web_search"):
        graph.add_edge(worker, "research_supervisor")
    graph.add_conditional_edges(
        "research_supervisor",
        lambda x: x["next"],
        {"query_qdrant": "query_qdrant", "web_search": "web_search", "FINISH": "research_end"},
    )
    return graph


research_graph = _build_research_graph()
research_graph_comp = research_graph.compile()
| # | |
| # Writing Graph Nodes | |
| # | |
def _build_writing_graph():
    """Assemble the (uncompiled) writing-team StateGraph.

    Entry point is ``writing_supervisor``. ``"NEW DRAFT"`` runs the pipeline
    post_creation -> post_editor -> post_review -> writing_supervisor.
    ``"FINISH"`` goes to ``writing_end``.
    """
    graph = StateGraph(WritingState)

    # Writing Graph Nodes
    for node_name, node_fn in (
        ("post_creation", post_creation),
        ("post_editor", post_editor),
        ("post_review", post_review),
        ("writing_supervisor", writing_supervisor),
        ("writing_end", writing_end),
    ):
        graph.add_node(node_name, node_fn)

    # Writing Graph Edges
    graph.set_entry_point("writing_supervisor")
    for source, target in (
        ("post_creation", "post_editor"),
        ("post_editor", "post_review"),
        ("post_review", "writing_supervisor"),
    ):
        graph.add_edge(source, target)
    graph.add_conditional_edges(
        "writing_supervisor",
        lambda x: x["next"],
        {"NEW DRAFT": "post_creation", "FINISH": "writing_end"},
    )
    return graph


writing_graph = _build_writing_graph()
writing_graph_comp = writing_graph.compile()
| # | |
| # Complete Graph Nodes | |
| # | |
def _build_overall_graph():
    """Assemble the (uncompiled) top-level StateGraph.

    ``overall_supervisor`` is the entry point and routes on
    ``state["next_team"]`` to one of the compiled team subgraphs (each of
    which returns to the supervisor) or to ``END`` on ``"FINISH"``.
    """
    graph = StateGraph(State)

    # Complete Graph Nodes (the team subgraphs are embedded as compiled graphs)
    graph.add_node("overall_supervisor", overall_supervisor)
    graph.add_node("research_team_graph", research_graph_comp)
    graph.add_node("writing_team_graph", writing_graph_comp)

    # Complete Graph Edges
    graph.set_entry_point("overall_supervisor")
    for team_node in ("research_team_graph", "writing_team_graph"):
        graph.add_edge(team_node, "overall_supervisor")
    graph.add_conditional_edges(
        "overall_supervisor",
        lambda x: x["next_team"],
        {"research_team": "research_team_graph",
         "writing_team": "writing_team_graph",
         "FINISH": END},
    )
    return graph


overall_graph = _build_overall_graph()
app = overall_graph.compile()
| ####################################### | |
| ### Run method ### | |
| ####################################### | |
def getSocialMediaPost(userInput: str, recursion_limit: int = 40) -> str:
    """Run the full manager/research/writing graph for one user request.

    Args:
        userInput: The user's free-text request.
        recursion_limit: LangGraph step budget for the run. Defaults to 40,
            the value that was previously hard-coded.

    Returns:
        The accepted post (``final_post``), or ``""`` if the run ended without
        an accepted draft. Returns the literal string ``"Recursion Error"``
        when the step budget is exhausted.
    """
    initial_state = State(
        workflow=[],
        topic="",
        research_data={},
        draft_posts=[],
        # String-typed fields start as "" (not []) to match the State schema,
        # so the function always returns a str.
        final_post="",
        next="",
        next_team="",
        user_input=userInput,
        message_to_manager="",
        message_from_manager="",
        last_active_team="",
        review_comments="",
    )
    try:
        # Single invocation, inside the try, so the (LLM-billed) graph runs
        # once and a GraphRecursionError always maps to the sentinel result.
        results = app.invoke(initial_state, {"recursion_limit": recursion_limit})
    except GraphRecursionError:
        return "Recursion Error"
    return results['final_post']