File size: 1,278 Bytes
e5e882e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from typing import Dict, Any, List
from langchain_groq import ChatGroq
from langchain.prompts import ChatPromptTemplate
from langchain.schema import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
from core.config import settings

class BaseAgent:
    """Base class for agents that run a prompt -> Groq chat model -> str chain.

    Call ``create_chain`` with a prompt template before using ``process``.
    """

    def __init__(self, role: str):
        """Create the agent and its Groq chat model client.

        Args:
            role: Free-form label for this agent; returned by ``get_role``.
        """
        self.role = role
        # Model configuration comes from the project-wide settings object.
        self.llm = ChatGroq(
            api_key=settings.GROQ_API_KEY,
            model_name=settings.MODEL_NAME,
            temperature=settings.TEMPERATURE
        )
        # Built by create_chain(); None means "not initialized yet".
        self.chain = None

    def create_chain(self, template: str):
        """Build the LangChain pipeline used by ``process``.

        The pipeline is ``{"input": <raw input>} | prompt | llm | StrOutputParser()``:
        the value passed to ``process`` is bound, unchanged, to the ``{input}``
        placeholder of *template*, and the model reply is parsed to a string.

        Args:
            template: Prompt text in ``ChatPromptTemplate.from_template`` syntax.
                It should reference ``{input}``; no other placeholder is
                supplied by this pipeline.
        """
        prompt = ChatPromptTemplate.from_template(template)
        self.chain = (
            # NOTE(review): the entire input (a dict, per process()'s signature)
            # is passed through as the single ``input`` variable, so the
            # template likely sees the dict's string form rather than its
            # individual keys — confirm this is intended.
            {"input": RunnablePassthrough()}
            | prompt
            | self.llm
            | StrOutputParser()
        )

    async def process(self, input_data: Dict[str, Any]) -> str:
        """Run the chain asynchronously on *input_data*.

        Args:
            input_data: Value forwarded as-is to the chain's ``ainvoke``.

        Returns:
            The chain's output (a string when built by ``create_chain``).

        Raises:
            ValueError: If ``create_chain`` has not been called yet.
        """
        # Explicit sentinel check: "not built yet" must not depend on the
        # truthiness of whatever chain object was assigned.
        if self.chain is None:
            raise ValueError("Chain not initialized. Call create_chain first.")
        return await self.chain.ainvoke(input_data)

    def get_role(self) -> str:
        """Return the role label this agent was constructed with."""
        return self.role