Spaces:
Runtime error
Runtime error
tech-envision
commited on
Commit
·
38d63de
1
Parent(s):
23550aa
Add conversation reset command
Browse files- bot/discord_bot.py +22 -0
- src/db.py +26 -1
bot/discord_bot.py
CHANGED
|
@@ -4,9 +4,11 @@ from __future__ import annotations
|
|
| 4 |
|
| 5 |
|
| 6 |
import discord
|
|
|
|
| 7 |
from discord.ext import commands
|
| 8 |
|
| 9 |
from src.chat import ChatSession
|
|
|
|
| 10 |
from src.log import get_logger
|
| 11 |
|
| 12 |
from .config import DEFAULT_SESSION, DEFAULT_USER_PREFIX, DISCORD_TOKEN
|
|
@@ -21,6 +23,10 @@ class LLMDiscordBot(commands.Bot):
|
|
| 21 |
intents = intents or discord.Intents.all()
|
| 22 |
super().__init__(command_prefix=None, intents=intents)
|
| 23 |
self._log = get_logger(self.__class__.__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
async def on_ready(self) -> None: # noqa: D401
|
| 26 |
self._log.info("Logged in as %s (%s)", self.user, self.user.id)
|
|
@@ -44,6 +50,22 @@ class LLMDiscordBot(commands.Bot):
|
|
| 44 |
if reply:
|
| 45 |
await message.reply(reply, mention_author=False)
|
| 46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
def run_bot(token: str | None = None) -> None:
|
| 49 |
"""Run the Discord bot using the provided token."""
|
|
|
|
| 4 |
|
| 5 |
|
| 6 |
import discord
|
| 7 |
+
from discord import app_commands
|
| 8 |
from discord.ext import commands
|
| 9 |
|
| 10 |
from src.chat import ChatSession
|
| 11 |
+
from src.db import reset_history as db_reset_history
|
| 12 |
from src.log import get_logger
|
| 13 |
|
| 14 |
from .config import DEFAULT_SESSION, DEFAULT_USER_PREFIX, DISCORD_TOKEN
|
|
|
|
| 23 |
intents = intents or discord.Intents.all()
|
| 24 |
super().__init__(command_prefix=None, intents=intents)
|
| 25 |
self._log = get_logger(self.__class__.__name__)
|
| 26 |
+
self.tree.add_command(self.reset_conversation)
|
| 27 |
+
|
| 28 |
+
async def setup_hook(self) -> None: # noqa: D401
|
| 29 |
+
await self.tree.sync()
|
| 30 |
|
| 31 |
async def on_ready(self) -> None: # noqa: D401
|
| 32 |
self._log.info("Logged in as %s (%s)", self.user, self.user.id)
|
|
|
|
| 50 |
if reply:
|
| 51 |
await message.reply(reply, mention_author=False)
|
| 52 |
|
| 53 |
+
@app_commands.command(
|
| 54 |
+
name="reset",
|
| 55 |
+
description="Reset conversation history for this channel.",
|
| 56 |
+
)
|
| 57 |
+
async def reset_conversation(self, interaction: discord.Interaction) -> None:
|
| 58 |
+
"""Delete all messages stored for the user and channel."""
|
| 59 |
+
|
| 60 |
+
user_id = f"{DEFAULT_USER_PREFIX}{interaction.user.id}"
|
| 61 |
+
session_id = f"{DEFAULT_SESSION}_{interaction.channel_id}"
|
| 62 |
+
deleted = db_reset_history(user_id, session_id)
|
| 63 |
+
if deleted:
|
| 64 |
+
msg = f"Conversation history cleared ({deleted} messages removed)."
|
| 65 |
+
else:
|
| 66 |
+
msg = "No conversation history found for this channel."
|
| 67 |
+
await interaction.response.send_message(msg, ephemeral=True)
|
| 68 |
+
|
| 69 |
|
| 70 |
def run_bot(token: str | None = None) -> None:
|
| 71 |
"""Run the Discord bot using the provided token."""
|
src/db.py
CHANGED
|
@@ -46,7 +46,13 @@ class Message(BaseModel):
|
|
| 46 |
created_at = DateTimeField(default=datetime.utcnow)
|
| 47 |
|
| 48 |
|
| 49 |
-
__all__ = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
|
| 52 |
def init_db() -> None:
|
|
@@ -54,3 +60,22 @@ def init_db() -> None:
|
|
| 54 |
if _db.is_closed():
|
| 55 |
_db.connect()
|
| 56 |
_db.create_tables([User, Conversation, Message])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
created_at = DateTimeField(default=datetime.utcnow)
|
| 47 |
|
| 48 |
|
| 49 |
+
__all__ = [
|
| 50 |
+
"_db",
|
| 51 |
+
"User",
|
| 52 |
+
"Conversation",
|
| 53 |
+
"Message",
|
| 54 |
+
"reset_history",
|
| 55 |
+
]
|
| 56 |
|
| 57 |
|
| 58 |
def init_db() -> None:
|
|
|
|
| 60 |
if _db.is_closed():
|
| 61 |
_db.connect()
|
| 62 |
_db.create_tables([User, Conversation, Message])
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def reset_history(username: str, session_name: str) -> int:
|
| 66 |
+
"""Delete all messages for the given user and session."""
|
| 67 |
+
|
| 68 |
+
init_db()
|
| 69 |
+
try:
|
| 70 |
+
user = User.get(User.username == username)
|
| 71 |
+
conv = Conversation.get(
|
| 72 |
+
Conversation.user == user, Conversation.session_name == session_name
|
| 73 |
+
)
|
| 74 |
+
except (User.DoesNotExist, Conversation.DoesNotExist):
|
| 75 |
+
return 0
|
| 76 |
+
|
| 77 |
+
deleted = Message.delete().where(Message.conversation == conv).execute()
|
| 78 |
+
conv.delete_instance()
|
| 79 |
+
if not Conversation.select().where(Conversation.user == user).exists():
|
| 80 |
+
user.delete_instance()
|
| 81 |
+
return deleted
|