diff --git a/.gitignore b/.gitignore index 3bd6415..a65e81c 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,6 @@ __pycache__ current_version.txt MEMORY_LOADED memory.json -*.pkl \ No newline at end of file +*.pkl +memories/ +models/ \ No newline at end of file diff --git a/bot.py b/bot.py index ffb0623..c69338b 100644 --- a/bot.py +++ b/bot.py @@ -26,6 +26,10 @@ analyzer = SentimentIntensityAnalyzer() print(splashtext) # you can use https://patorjk.com/software/taag/ for 3d text or just remove this entirely +os.makedirs("memories", exist_ok=True) +os.makedirs("models", exist_ok=True) + + def download_json(): locales_dir = "locales" response = requests.get(f"{VERSION_URL}/goob/locales/{LOCALE}.json") @@ -121,20 +125,18 @@ def register_name(NAME): register_name(NAME) def save_markov_model(model, filename='markov_model.pkl'): - with open(filename, 'wb') as f: + model_file = f"models/{filename}" + with open(model_file, "wb") as f: pickle.dump(model, f) - print(f"Markov model saved to {filename}.") + print(f"{GREEN}Markov model saved to {model_file}{RESET}") -def load_markov_model(filename='markov_model.pkl'): - - try: - with open(filename, 'rb') as f: - model = pickle.load(f) - print(f"{GREEN}{get_translation(LOCALE, 'model_loaded')} {filename}.{RESET}") - return model - except FileNotFoundError: - print(f"{RED}{filename} {get_translation(LOCALE, 'not_found')}{RESET}") - return None +def load_markov_model(server_id=None): + if server_id: + filename = f"markov_model_{server_id}.pkl" + else: + filename = "markov_model.pkl" + + model_file = f"models/{filename}" def get_latest_version_info(): @@ -173,7 +175,7 @@ def generate_sha256_of_current_file(): latest_version = "0.0.0" -local_version = "0.14.8.3" +local_version = "rewrite/seperate-memories" os.environ['gooberlocal_version'] = local_version @@ -226,39 +228,50 @@ def get_file_info(file_path): nltk.download('punkt') -def load_memory(): +def load_memory(server_id=None): + if server_id: + memory_file = f"memories/memory_{server_id}.json" + else: + memory_file = "memories/memory.json" + data = [] - - # load data from MEMORY_FILE try: - with open(MEMORY_FILE, "r") as f: + with open(memory_file, "r") as f: data = json.load(f) except FileNotFoundError: pass - - if not os.path.exists(MEMORY_LOADED_FILE): - try: - with open(DEFAULT_DATASET_FILE, "r") as f: - default_data = json.load(f) - data.extend(default_data) - except FileNotFoundError: - pass - with open(MEMORY_LOADED_FILE, "w") as f: - f.write("Data loaded") + except json.JSONDecodeError: + print(f"{RED}Error decoding memory file {memory_file}{RESET}") + return data -def save_memory(memory): - with open(MEMORY_FILE, "w") as f: +def save_memory(memory, server_id=None): + if server_id: + memory_file = f"memories/memory_{server_id}.json" + else: + memory_file = "memories/memory.json" + + with open(memory_file, "w") as f: json.dump(memory, f, indent=4) -def train_markov_model(memory, additional_data=None): +def train_markov_model(memory, additional_data=None, server_id=None): if not memory: return None + text = "\n".join(memory) if additional_data: text += "\n" + "\n".join(additional_data) - model = markovify.NewlineText(text, state_size=2) - return model + + try: + model = markovify.NewlineText(text, state_size=2) + if server_id: + model_filename = f"markov_model_{server_id}.pkl" + save_markov_model(model, model_filename) + return model + except Exception as e: + print(f"{RED}Error training model: {e}{RESET}") + return None + #this doesnt work and im extremely pissed and mad def append_mentions_to_18digit_integer(message): pattern = r'\b\d{18}\b' @@ -381,42 +394,144 @@ async def send_message(ctx, message=None, embed=None, file=None, edit=False, mes sent_message = await ctx.send(file=file) return sent_message -@bot.hybrid_command(description=f"{get_translation(LOCALE, 'command_desc_retrain')}") -async def retrain(ctx): +@bot.hybrid_command(description="Retrain Markov models for servers") +@app_commands.choices(option=[ + app_commands.Choice(name="Retrain current server", value="current"), + app_commands.Choice(name="Retrain all servers", value="all"), + app_commands.Choice(name="Select servers to retrain", value="select") +]) +async def retrain_models(ctx, option: app_commands.Choice[str]): if ctx.author.id != ownerid: - return + return await ctx.send("You don't have permission to use this command.", ephemeral=True) - message_ref = await send_message(ctx, f"{get_translation(LOCALE, 'command_markov_retrain')}") + if option.value == "current": + server_id = ctx.guild.id if ctx.guild else "DM" + await retrain_single_server(ctx, server_id) + + elif option.value == "all": + await retrain_all_servers(ctx) + + elif option.value == "select": + await show_server_selection(ctx) + +async def retrain_single_server(ctx, server_id): + memory_file = f"memories/memory_{server_id}.json" + model_file = f"models/markov_model_{server_id}.pkl" + try: - with open(MEMORY_FILE, 'r') as f: + with open(memory_file, 'r') as f: memory = json.load(f) except FileNotFoundError: - await send_message(ctx, f"{get_translation(LOCALE, 'command_markov_memory_not_found')}") - return - except json.JSONDecodeError: - await send_message(ctx, f"{get_translation(LOCALE, 'command_markov_memory_is_corrupt')}") - return - data_size = len(memory) - processed_data = 0 - processing_message_ref = await send_message(ctx, f"{get_translation(LOCALE, 'command_markov_retraining').format(processed_data=processed_data, data_size=data_size)}") - start_time = time.time() - for i, data in enumerate(memory): - processed_data += 1 - if processed_data % 1000 == 0 or processed_data == data_size: - await send_message(ctx, f"{get_translation(LOCALE, 'command_markov_retraining').format(processed_data=processed_data, data_size=data_size)}", edit=True, message_reference=processing_message_ref) - - global markov_model + return await ctx.send(f"No memory data found for server {server_id}", ephemeral=True) - markov_model = train_markov_model(memory) - save_markov_model(markov_model) + processing_msg = await ctx.send(f"Retraining model for server {server_id}...") + + model = train_markov_model(memory, server_id=server_id) + + if model: + await processing_msg.edit(content=f"Successfully retrained model for server {server_id}") + else: + await processing_msg.edit(content=f"Failed to retrain model for server {server_id}") - await send_message(ctx, f"{get_translation(LOCALE, 'command_markov_retrain_successful').format(data_size=data_size)}", edit=True, message_reference=processing_message_ref) +async def retrain_all_servers(ctx): + memory_files = [f for f in os.listdir("memories/") if f.startswith("memory_") and f.endswith(".json")] + + if not memory_files: + return await ctx.send("No server memory files found to retrain.", ephemeral=True) + + progress_msg = await ctx.send(f"Retraining models for {len(memory_files)} servers...") + success_count = 0 + + for mem_file in memory_files: + try: + server_id = mem_file.replace("memory_", "").replace(".json", "") + with open(f"memories/{mem_file}", 'r') as f: + memory = json.load(f) + + model = train_markov_model(memory, server_id=server_id) + if model: + success_count += 1 + + if success_count % 5 == 0: + await progress_msg.edit(content=f"Retraining in progress... {success_count}/{len(memory_files)} completed") + + except Exception as e: + print(f"Error retraining {mem_file}: {e}") + + await progress_msg.edit(content=f"Retraining complete successfully retrained {success_count}/{len(memory_files)} servers") + +async def show_server_selection(ctx): + memory_files = [f for f in os.listdir("memories/") if f.startswith("memory_") and f.endswith(".json")] + + if not memory_files: + return await ctx.send("No server memory files found.", ephemeral=True) + options = [] + for mem_file in memory_files: + server_id = mem_file.replace("memory_", "").replace(".json", "") + server_name = f"Server {server_id}" + if server_id != "DM": + guild = bot.get_guild(int(server_id)) + if guild: + server_name = guild.name + + options.append(discord.SelectOption(label=server_name, value=server_id)) + select_menus = [] + for i in range(0, len(options), 25): + chunk = options[i:i+25] + + select = discord.ui.Select( + placeholder=f"Select servers to retrain ({i+1}-{min(i+25, len(options))})", + min_values=1, + max_values=len(chunk), + options=chunk + ) + + select_menus.append(select) + view = discord.ui.View() + for menu in select_menus: + menu.callback = lambda interaction, m=menu: handle_server_selection(interaction, m) + view.add_item(menu) + + await ctx.send("Select which servers to retrain:", view=view) + +async def handle_server_selection(interaction, select_menu): + await interaction.response.defer() + + selected_servers = select_menu.values + if not selected_servers: + return await interaction.followup.send("No servers selected.", ephemeral=True) + + progress_msg = await interaction.followup.send(f"Retraining {len(selected_servers)} selected servers...") + success_count = 0 + + for server_id in selected_servers: + try: + memory_file = f"memories/memory_{server_id}.json" + with open(memory_file, 'r') as f: + memory = json.load(f) + + model = train_markov_model(memory, server_id=server_id) + if model: + success_count += 1 + if success_count % 5 == 0: + await progress_msg.edit(content=f"Retraining in progress... {success_count}/{len(selected_servers)} completed") + + except Exception as e: + print(f"Error retraining {server_id}: {e}") + + await progress_msg.edit(content=f"Retraining complete Successfully retrained {success_count}/{len(selected_servers)} selected servers") @bot.hybrid_command(description=f"{get_translation(LOCALE, 'command_desc_talk')}") async def talk(ctx, sentence_size: int = 5): + server_id = ctx.guild.id if ctx.guild else "DM" + markov_model = load_markov_model(server_id) + if not markov_model: - await send_message(ctx, f"{get_translation(LOCALE, 'command_talk_insufficent_text')}") - return + memory = load_memory(server_id) + markov_model = train_markov_model(memory, server_id=server_id) + if not markov_model: + await send_message(ctx, f"{get_translation(LOCALE, 'command_talk_insufficent_text')}") + return response = None for _ in range(20): @@ -447,6 +562,7 @@ async def talk(ctx, sentence_size: int = 5): else: await send_message(ctx, f"{get_translation(LOCALE, 'command_talk_generation_fail')}") + def improve_sentence_coherence(sentence): sentence = sentence.replace(" i ", " I ") @@ -492,8 +608,6 @@ async def help(ctx): @bot.event async def on_message(message): - global memory, markov_model, last_random_talk_time - if message.author.bot: return @@ -508,22 +622,29 @@ async def on_message(message): if profanity.contains_profanity(message.content): return - if message.content: - if not USERTRAIN_ENABLED: - return + if message.content and USERTRAIN_ENABLED: + server_id = message.guild.id if message.guild else "DM" + memory = load_memory(server_id) + formatted_message = append_mentions_to_18digit_integer(message.content) cleaned_message = preprocess_message(formatted_message) + if cleaned_message: memory.append(cleaned_message) - save_memory(memory) + save_memory(memory, server_id) - # process any commands in the message await bot.process_commands(message) @bot.event async def on_interaction(interaction): - print(f"{get_translation(LOCALE, 'command_ran_s').format(interaction=interaction)}{interaction.data['name']}") + try: + if interaction.type == discord.InteractionType.application_command: + command_name = interaction.data.get('name', 'unknown') + print(f"{get_translation(LOCALE, 'command_ran_s').format(interaction=interaction)}{command_name}") + except Exception as e: + print(f"{RED}Error handling interaction: {e}{RESET}") + traceback.print_exc() @bot.check async def block_blacklisted(ctx): @@ -598,10 +719,13 @@ async def stats(ctx): async def mem(ctx): if showmemenabled != "true": return - memory = load_memory() - memory_text = json.dumps(memory, indent=4) - with open(MEMORY_FILE, "r") as f: - await send_message(ctx, file=discord.File(f, MEMORY_FILE)) + server_id = ctx.guild.id if ctx.guild else "DM" + memory_file = f"memories/memory_{server_id}.json" if server_id else "memories/memory.json" + try: + with open(memory_file, "r") as f: + await send_message(ctx, file=discord.File(f, memory_file)) + except FileNotFoundError: + await send_message(ctx, f"No memory file found at {memory_file}") def improve_sentence_coherence(sentence): sentence = sentence.replace(" i ", " I ")