forked from gooberinc/goober
added translations + finnish 1
This commit is contained in:
parent
59c7494675
commit
01ba29c944
12 changed files with 591 additions and 521 deletions
323
bot.py
323
bot.py
|
@ -21,28 +21,123 @@ from better_profanity import profanity
|
|||
from config import *
|
||||
import traceback
|
||||
import shutil
|
||||
|
||||
print(splashtext) # you can use https://patorjk.com/software/taag/ for 3d text or just remove this entirely
|
||||
|
||||
def download_json():
|
||||
response = requests.get(f"{VERSION_URL}/goob/locales/{LOCALE}.json")
|
||||
if response.status_code == 200:
|
||||
locales_dir = "locales"
|
||||
if not os.path.exists(locales_dir):
|
||||
os.makedirs(locales_dir)
|
||||
file_path = os.path.join(locales_dir, f"{LOCALE}.json")
|
||||
if os.path.exists(file_path):
|
||||
return
|
||||
else:
|
||||
with open(file_path, "w", encoding="utf-8") as file:
|
||||
file.write(response.text)
|
||||
|
||||
if not os.path.exists(os.path.join(locales_dir, "en.json")):
|
||||
response = requests.get(f"{VERSION_URL}/goob/locales/en.json")
|
||||
if response.status_code == 200:
|
||||
with open(os.path.join(locales_dir, "en.json"), "w", encoding="utf-8") as file:
|
||||
file.write(response.text)
|
||||
|
||||
download_json()
|
||||
def load_translations():
|
||||
translations = {}
|
||||
translations_dir = os.path.join(os.path.dirname(__file__), "locales")
|
||||
|
||||
for filename in os.listdir(translations_dir):
|
||||
if filename.endswith(".json"):
|
||||
lang_code = filename.replace(".json", "")
|
||||
with open(os.path.join(translations_dir, filename), "r", encoding="utf-8") as f:
|
||||
translations[lang_code] = json.load(f)
|
||||
|
||||
return translations
|
||||
|
||||
translations = load_translations()
|
||||
|
||||
def get_translation(lang: str, key: str):
|
||||
lang_translations = translations.get(lang, translations["en"])
|
||||
if key not in lang_translations:
|
||||
print(f"{RED}Missing key: {key} in language {lang}{RESET}")
|
||||
return lang_translations.get(key, key)
|
||||
|
||||
|
||||
|
||||
def is_name_available(NAME):
|
||||
if os.getenv("gooberTOKEN"):
|
||||
return
|
||||
try:
|
||||
response = requests.post(f"{VERSION_URL}/check-if-available", json={"name": NAME}, headers={"Content-Type": "application/json"})
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
return data.get("available", False)
|
||||
else:
|
||||
print(f"{get_translation(LOCALE, 'name_check')}", response.json())
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"{get_translation(LOCALE, 'name_check2')}", e)
|
||||
return False
|
||||
|
||||
def register_name(NAME):
|
||||
try:
|
||||
if ALIVEPING == False:
|
||||
return
|
||||
# check if the name is avaliable
|
||||
if not is_name_available(NAME):
|
||||
if os.getenv("gooberTOKEN"):
|
||||
return
|
||||
print(f"{RED}{get_translation(LOCALE, 'name_taken')}{RESET}")
|
||||
quit()
|
||||
|
||||
# if it is register it
|
||||
response = requests.post(f"{VERSION_URL}/register", json={"name": NAME}, headers={"Content-Type": "application/json"})
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
token = data.get("token")
|
||||
|
||||
if not os.getenv("gooberTOKEN"):
|
||||
print(f"{GREEN}{get_translation(LOCALE, 'add_token').format(token=token)} gooberTOKEN=<your_token>.{RESET}")
|
||||
quit()
|
||||
else:
|
||||
print(f"{GREEN}{RESET}")
|
||||
|
||||
return token
|
||||
else:
|
||||
print(f"{RED}{get_translation(LOCALE, 'token_exists').format()}{RESET}", response.json())
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"{RED}{get_translation(LOCALE, 'registration_error').format()}{RESET}", e)
|
||||
return None
|
||||
|
||||
register_name(NAME)
|
||||
|
||||
def save_markov_model(model, filename='markov_model.pkl'):
|
||||
with open(filename, 'wb') as f:
|
||||
pickle.dump(model, f)
|
||||
print(f"Markov model saved to {filename}.")
|
||||
|
||||
|
||||
def backup_current_version():
|
||||
if os.path.exists(LOCAL_VERSION_FILE):
|
||||
shutil.copy(LOCAL_VERSION_FILE, LOCAL_VERSION_FILE + ".bak")
|
||||
print(f"Backup created: {LOCAL_VERSION_FILE}.bak")
|
||||
print(f"{GREEN}{get_translation(LOCALE, 'version_backup')} {LOCAL_VERSION_FILE}.bak{RESET}")
|
||||
else:
|
||||
print(f"Error: {LOCAL_VERSION_FILE} not found for backup.")
|
||||
print(f"{RED}{get_translation(LOCALE, 'backup_error').format(LOCAL_VERSION_FILE=LOCAL_VERSION_FILE)} {RESET}")
|
||||
|
||||
def load_markov_model(filename='markov_model.pkl'):
|
||||
|
||||
try:
|
||||
with open(filename, 'rb') as f:
|
||||
model = pickle.load(f)
|
||||
print(f"Markov model loaded from {filename}.")
|
||||
print(f"{GREEN}{get_translation(LOCALE, 'model_loaded')} {filename}.{RESET}")
|
||||
return model
|
||||
except FileNotFoundError:
|
||||
print(f"Error: {filename} not found.")
|
||||
print(f"{RED}{filename} {get_translation(LOCALE, 'not_found')}{RESET}")
|
||||
return None
|
||||
|
||||
def get_latest_version_info():
|
||||
|
@ -53,10 +148,10 @@ def get_latest_version_info():
|
|||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
print(f"Error: Unable to fetch version info. Status code {response.status_code}")
|
||||
print(f"{RED}{get_translation(LOCALE, 'version_error')} {response.status_code}{RESET}")
|
||||
return None
|
||||
except requests.RequestException as e:
|
||||
print(f"Error: Unable to connect to the update server. {e}")
|
||||
print(f"{RED}{get_translation(LOCALE, 'version_error')} {e}{RESET}")
|
||||
return None
|
||||
|
||||
async def load_cogs_from_folder(bot, folder_name="cogs"):
|
||||
|
@ -65,9 +160,9 @@ async def load_cogs_from_folder(bot, folder_name="cogs"):
|
|||
cog_name = filename[:-3]
|
||||
try:
|
||||
await bot.load_extension(f"{folder_name}.{cog_name}")
|
||||
print(f"Loaded cog: {cog_name}")
|
||||
print(f"{GREEN}{get_translation(LOCALE, 'loaded_cog')} {cog_name}{RESET}")
|
||||
except Exception as e:
|
||||
print(f"Failed to load cog {cog_name}: {e}")
|
||||
print(f"{RED}{get_translation(LOCALE, 'cog_fail')} {cog_name} {e}{RESET}")
|
||||
traceback.print_exc()
|
||||
|
||||
currenthash = ""
|
||||
|
@ -97,14 +192,14 @@ def check_for_update():
|
|||
|
||||
latest_version_info = get_latest_version_info()
|
||||
if not latest_version_info:
|
||||
print("Could not fetch update information.")
|
||||
print(f"{get_translation(LOCALE, 'fetch_update_fail')}")
|
||||
return None, None
|
||||
|
||||
latest_version = latest_version_info.get("version")
|
||||
download_url = latest_version_info.get("download_url")
|
||||
|
||||
if not latest_version or not download_url:
|
||||
print("Error: Invalid version information received from server.")
|
||||
print(f"{RED}{get_translation(LOCALE, 'invalid_server')}{RESET}")
|
||||
return None, None
|
||||
|
||||
local_version = get_local_version()
|
||||
|
@ -112,28 +207,28 @@ def check_for_update():
|
|||
gooberhash = latest_version_info.get("hash")
|
||||
if gooberhash == currenthash:
|
||||
if local_version < latest_version:
|
||||
print(f"{YELLOW}New version available: {latest_version} (Current: {local_version}){RESET}")
|
||||
print(f"Check {VERSION_URL}/goob/changes.txt to check out the changelog\n\n")
|
||||
print(f"{YELLOW}{get_translation(LOCALE, 'new_version')}{RESET}")
|
||||
print(f"{YELLOW}{get_translation(LOCALE, 'changelog').format(VERSION_URL=VERSION_URL)}")
|
||||
elif local_version > latest_version:
|
||||
if IGNOREWARNING == False:
|
||||
print(f"\n{RED}The version: {local_version} isnt valid!")
|
||||
print(f"{RED}If this is intended then ignore this message, else press Y to pull a valid version from the server regardless of the version of goober currently running")
|
||||
print(f"The current version will be backed up to current_version.bak..{RESET}\n\n")
|
||||
userinp = input("(Y or any other key to ignore....)\n")
|
||||
print(f"\n{RED}{get_translation(LOCALE, 'invalid_version').format(local_version=local_version)}")
|
||||
print(f"{get_translation(LOCALE, 'invalid_version2')}")
|
||||
print(f"{get_translation(LOCALE, 'invalid_version3')}{RESET}\n\n")
|
||||
userinp = input(f"{get_translation(LOCALE, 'input')}\n")
|
||||
if userinp.lower() == "y":
|
||||
backup_current_version()
|
||||
with open(LOCAL_VERSION_FILE, "w") as f:
|
||||
f.write(latest_version)
|
||||
else:
|
||||
print(f"{RED}You've modified {LOCAL_VERSION_FILE}")
|
||||
print(f"IGNOREWARNING is set to false..{RESET}")
|
||||
print(f"{RED}{get_translation(LOCALE, 'modification_ignored')} {LOCAL_VERSION_FILE}")
|
||||
print(f"{get_translation(LOCALE, 'modification_ignored2')}{RESET}")
|
||||
else:
|
||||
print(f"{GREEN}You're using the latest version: {local_version}{RESET}")
|
||||
print(f"Check {VERSION_URL}/goob/changes.txt to check out the changelog\n\n")
|
||||
print(f"{GREEN}{get_translation(LOCALE, 'latest_version')} {local_version}{RESET}")
|
||||
print(f"{get_translation(LOCALE, 'latest_version2').format(VERSION_URL=VERSION_URL)}\n\n")
|
||||
else:
|
||||
print(f"{YELLOW}Goober has been modified! Skipping server checks entirely...")
|
||||
print(f"Reported Version: {local_version}{RESET}")
|
||||
print(f"Current Hash: {currenthash}")
|
||||
print(f"{YELLOW}{get_translation(LOCALE, 'modification_warning')}")
|
||||
print(f"{YELLOW}{get_translation(LOCALE, 'reported_version')} {local_version}{RESET}")
|
||||
print(f"{DEBUG}{get_translation(LOCALE, 'current_hash')} {currenthash}{RESET}")
|
||||
|
||||
|
||||
check_for_update()
|
||||
|
@ -147,12 +242,8 @@ def get_file_info(file_path):
|
|||
except Exception as e:
|
||||
return {"error": str(e)}
|
||||
|
||||
|
||||
nltk.download('punkt')
|
||||
|
||||
|
||||
|
||||
|
||||
def load_memory():
|
||||
data = []
|
||||
|
||||
|
@ -206,7 +297,7 @@ bot = commands.Bot(command_prefix=PREFIX, intents=intents)
|
|||
memory = load_memory()
|
||||
markov_model = load_markov_model()
|
||||
if not markov_model:
|
||||
print("No saved Markov model found. Starting from scratch.")
|
||||
print(f"{get_translation(LOCALE, 'no_model')}")
|
||||
memory = load_memory()
|
||||
markov_model = train_markov_model(memory)
|
||||
|
||||
|
@ -216,48 +307,51 @@ used_words = set()
|
|||
slash_commands_enabled = False
|
||||
@bot.event
|
||||
async def on_ready():
|
||||
|
||||
folder_name = "cogs"
|
||||
if not os.path.exists(folder_name):
|
||||
os.makedirs(folder_name)
|
||||
print(f"Folder '{folder_name}' created.")
|
||||
print(f"{GREEN}{get_translation(LOCALE, 'folder_created').format(folder_name=folder_name)}{RESET}")
|
||||
else:
|
||||
print(f"Folder '{folder_name}' already exists. skipping...")
|
||||
print(f"{DEBUG}{get_translation(LOCALE, 'folder_exists').format(folder_name=folder_name)}{RESET}")
|
||||
markov_model = train_markov_model(memory)
|
||||
await load_cogs_from_folder(bot)
|
||||
global slash_commands_enabled
|
||||
print(f"Logged in as {bot.user}")
|
||||
print(f"{GREEN}{get_translation(LOCALE, 'logged_in')} {bot.user}{RESET}")
|
||||
try:
|
||||
synced = await bot.tree.sync()
|
||||
print(f"Synced {len(synced)} commands.")
|
||||
print(f"{GREEN}{get_translation(LOCALE, 'synced_commands')} {len(synced)} {get_translation(LOCALE, 'synced_commands2')} {RESET}")
|
||||
slash_commands_enabled = True
|
||||
ping_server()
|
||||
print(f"{GREEN}{get_translation(LOCALE, 'started').format()}{RESET}")
|
||||
except Exception as e:
|
||||
print(f"Failed to sync commands: {e}")
|
||||
quit
|
||||
post_message.start()
|
||||
print(f"{RED}{get_translation(LOCALE, 'fail_commands_sync')} {e}{RESET}")
|
||||
traceback.print_exc()
|
||||
quit()
|
||||
if not song:
|
||||
return
|
||||
await bot.change_presence(activity=discord.Activity(type=discord.ActivityType.listening, name=f"{song}"))
|
||||
|
||||
def ping_server():
|
||||
if ALIVEPING == "false":
|
||||
print("Pinging is disabled! Not telling the server im on...")
|
||||
print(f"{YELLOW}{get_translation(LOCALE, 'pinging_disabled')}{RESET}")
|
||||
return
|
||||
file_info = get_file_info(MEMORY_FILE)
|
||||
payload = {
|
||||
"name": NAME,
|
||||
"memory_file_info": file_info,
|
||||
"version": local_version,
|
||||
"slash_commands": slash_commands_enabled
|
||||
"slash_commands": slash_commands_enabled,
|
||||
"token": gooberTOKEN
|
||||
}
|
||||
try:
|
||||
response = requests.post(VERSION_URL+"/ping", json=payload)
|
||||
if response.status_code == 200:
|
||||
print("Sent alive ping to goober central!")
|
||||
print(f"{GREEN}{get_translation(LOCALE, 'goober_ping_success')}{RESET}")
|
||||
else:
|
||||
print(f"Failed to send data. Server returned status code: {response.status_code}")
|
||||
print(f"{RED}{get_translation(LOCALE, 'goober_ping_fail')} {response.status_code}{RESET}")
|
||||
except Exception as e:
|
||||
print(f"An error occurred while sending data: {str(e)}")
|
||||
print(f"{RED}{get_translation(LOCALE, 'goober_ping_fail2')} {str(e)}{RESET}")
|
||||
|
||||
|
||||
positive_gifs = os.getenv("POSITIVE_GIFS").split(',')
|
||||
|
@ -265,7 +359,7 @@ positive_gifs = os.getenv("POSITIVE_GIFS").split(',')
|
|||
def is_positive(sentence):
|
||||
blob = TextBlob(sentence)
|
||||
sentiment_score = blob.sentiment.polarity
|
||||
print(sentiment_score)
|
||||
print(f"{DEBUG}{get_translation(LOCALE, 'sentence_positivity')} {sentiment_score}{RESET}")
|
||||
return sentiment_score > 0.1
|
||||
|
||||
|
||||
|
@ -275,7 +369,7 @@ async def send_message(ctx, message=None, embed=None, file=None, edit=False, mes
|
|||
# Editing the existing message
|
||||
await message_reference.edit(content=message, embed=embed)
|
||||
except Exception as e:
|
||||
await ctx.send(f"Failed to edit message: {e}")
|
||||
await ctx.send(f"{RED}{get_translation(LOCALE, 'edit_fail')} {e}{RESET}")
|
||||
else:
|
||||
if hasattr(ctx, "respond"):
|
||||
# For slash command contexts
|
||||
|
@ -297,44 +391,41 @@ async def send_message(ctx, message=None, embed=None, file=None, edit=False, mes
|
|||
sent_message = await ctx.send(file=file)
|
||||
return sent_message
|
||||
|
||||
|
||||
|
||||
|
||||
@bot.hybrid_command(description="Retrains the Markov model manually.")
|
||||
@bot.hybrid_command(description=f"{get_translation(LOCALE, 'command_desc_retrain')}")
|
||||
async def retrain(ctx):
|
||||
if ctx.author.id != ownerid:
|
||||
return
|
||||
|
||||
message_ref = await send_message(ctx, "Retraining the Markov model... Please wait.")
|
||||
message_ref = await send_message(ctx, f"{get_translation(LOCALE, 'command_markov_retrain')}")
|
||||
try:
|
||||
with open(MEMORY_FILE, 'r') as f:
|
||||
memory = json.load(f)
|
||||
except FileNotFoundError:
|
||||
await send_message(ctx, "Error: memory file not found!")
|
||||
await send_message(ctx, f"{get_translation(LOCALE, 'command_memory_not_found')}")
|
||||
return
|
||||
except json.JSONDecodeError:
|
||||
await send_message(ctx, "Error: memory file is corrupted!")
|
||||
await send_message(ctx, f"{get_translation(LOCALE, 'command_memory_is_corrupt')}")
|
||||
return
|
||||
data_size = len(memory)
|
||||
processed_data = 0
|
||||
processing_message_ref = await send_message(ctx, f"Processing {processed_data}/{data_size} data points...")
|
||||
processing_message_ref = await send_message(ctx, f"{get_translation(LOCALE, 'command_markov_retraining')}")
|
||||
start_time = time.time()
|
||||
for i, data in enumerate(memory):
|
||||
processed_data += 1
|
||||
if processed_data % 1000 == 0 or processed_data == data_size:
|
||||
await send_message(ctx, f"Processing {processed_data}/{data_size} data points...", edit=True, message_reference=processing_message_ref)
|
||||
await send_message(ctx, f"{get_translation(LOCALE, 'command_markov_retraining')}", edit=True, message_reference=processing_message_ref)
|
||||
|
||||
global markov_model
|
||||
|
||||
markov_model = train_markov_model(memory)
|
||||
save_markov_model(markov_model)
|
||||
|
||||
await send_message(ctx, f"Markov model retrained successfully using {data_size} data points!", edit=True, message_reference=processing_message_ref)
|
||||
await send_message(ctx, f"{get_translation(LOCALE, 'command_markov_retrain_successful')}", edit=True, message_reference=processing_message_ref)
|
||||
|
||||
@bot.hybrid_command(description="talks n like stuf")
|
||||
@bot.hybrid_command(description=f"{get_translation(LOCALE, 'command_desc_talk')}")
|
||||
async def talk(ctx):
|
||||
if not markov_model:
|
||||
await send_message(ctx, "I need to learn more from messages before I can talk.")
|
||||
await send_message(ctx, f"{get_translation(LOCALE, 'command_talk_insufficent_text')}")
|
||||
return
|
||||
|
||||
response = None
|
||||
|
@ -356,7 +447,7 @@ async def talk(ctx):
|
|||
combined_message = coherent_response
|
||||
await send_message(ctx, combined_message)
|
||||
else:
|
||||
await send_message(ctx, "I have nothing to say right now!")
|
||||
await send_message(ctx, f"{get_translation(LOCALE, 'command_talk_generation_fail')}")
|
||||
|
||||
def improve_sentence_coherence(sentence):
|
||||
|
||||
|
@ -373,27 +464,27 @@ def rephrase_for_coherence(sentence):
|
|||
bot.help_command = None
|
||||
|
||||
|
||||
@bot.hybrid_command(description="help")
|
||||
@bot.hybrid_command(description=f"{get_translation(LOCALE, 'command_desc_help')}")
|
||||
async def help(ctx):
|
||||
embed = discord.Embed(
|
||||
title="Bot Help",
|
||||
description="List of commands grouped by category.",
|
||||
title=f"{get_translation(LOCALE, 'command_help_embed_title')}",
|
||||
description=f"{get_translation(LOCALE, 'command_help_embed_desc')}",
|
||||
color=discord.Color.blue()
|
||||
)
|
||||
|
||||
command_categories = {
|
||||
"General": ["mem", "talk", "about", "ping"],
|
||||
"Administration": ["stats", "retrain"]
|
||||
f"{get_translation(LOCALE, 'command_help_categories_general')}": ["mem", "talk", "about", "ping"],
|
||||
f"{get_translation(LOCALE, 'command_help_categories_admin')}": ["stats", "retrain"]
|
||||
}
|
||||
|
||||
custom_commands = []
|
||||
for cog_name, cog in bot.cogs.items():
|
||||
for command in cog.get_commands():
|
||||
if command.name not in command_categories["General"] and command.name not in command_categories["Administration"]:
|
||||
if command.name not in command_categories[f"{get_translation(LOCALE, 'command_help_categories_general')}"] and command.name not in command_categories["Administration"]:
|
||||
custom_commands.append(command.name)
|
||||
|
||||
if custom_commands:
|
||||
embed.add_field(name="Custom Commands", value="\n".join([f"{PREFIX}{command}" for command in custom_commands]), inline=False)
|
||||
embed.add_field(name=f"{get_translation(LOCALE, 'command_help_categories_custom')}", value="\n".join([f"{PREFIX}{command}" for command in custom_commands]), inline=False)
|
||||
|
||||
for category, commands_list in command_categories.items():
|
||||
commands_in_category = "\n".join([f"{PREFIX}{command}" for command in commands_list])
|
||||
|
@ -401,9 +492,6 @@ async def help(ctx):
|
|||
|
||||
await send_message(ctx, embed=embed)
|
||||
|
||||
|
||||
|
||||
|
||||
@bot.event
|
||||
async def on_message(message):
|
||||
global memory, markov_model, last_random_talk_time
|
||||
|
@ -414,13 +502,8 @@ async def on_message(message):
|
|||
if str(message.author.id) in BLACKLISTED_USERS:
|
||||
return
|
||||
|
||||
random_talk_channels = [random_talk_channel_id2, random_talk_channel_id1]
|
||||
cooldowns = {
|
||||
random_talk_channel_id2: 1,
|
||||
}
|
||||
default_cooldown = 10800
|
||||
|
||||
if message.content.startswith((f"{PREFIX}talk", f"{PREFIX}mem", f"{PREFIX}help", f"{PREFIX}stats", f"{PREFIX}")):
|
||||
print(f"{get_translation(LOCALE, 'command_ran').format(message=message)}")
|
||||
await bot.process_commands(message)
|
||||
return
|
||||
|
||||
|
@ -436,50 +519,10 @@ async def on_message(message):
|
|||
memory.append(cleaned_message)
|
||||
save_memory(memory)
|
||||
|
||||
|
||||
cooldown_time = cooldowns.get(message.channel.id, default_cooldown)
|
||||
if message.reference and message.reference.message_id:
|
||||
replied_message = await message.channel.fetch_message(message.reference.message_id)
|
||||
if replied_message.author == bot.user:
|
||||
print("Bot is replying to a message directed at it!")
|
||||
response = None
|
||||
for _ in range(10):
|
||||
response = markov_model.make_sentence(tries=100)
|
||||
if response and response not in generated_sentences:
|
||||
response = improve_sentence_coherence(response)
|
||||
generated_sentences.add(response)
|
||||
break
|
||||
if response:
|
||||
await message.channel.send(response)
|
||||
return
|
||||
|
||||
# random chance for bot to talk
|
||||
random_chance = random.randint(0, 20)
|
||||
|
||||
# talk randomly only in the specified channels
|
||||
if message.channel.id in random_talk_channels and random_chance >= 10:
|
||||
current_time = time.time()
|
||||
print(f"Random chance: {random_chance}, Time passed: {current_time - last_random_talk_time}")
|
||||
|
||||
if current_time - last_random_talk_time >= cooldown_time:
|
||||
print("Bot is talking randomly!")
|
||||
last_random_talk_time = current_time
|
||||
|
||||
response = None
|
||||
for _ in range(10):
|
||||
response = markov_model.make_sentence(tries=100)
|
||||
if response and response not in generated_sentences:
|
||||
response = improve_sentence_coherence(response)
|
||||
generated_sentences.add(response)
|
||||
break
|
||||
|
||||
if response:
|
||||
await message.channel.send(response)
|
||||
|
||||
# process any commands in the message
|
||||
await bot.process_commands(message)
|
||||
|
||||
@bot.hybrid_command(description="ping")
|
||||
@bot.hybrid_command(description=f"{get_translation(LOCALE, 'command_desc_ping')}")
|
||||
async def ping(ctx):
|
||||
await ctx.defer()
|
||||
latency = round(bot.latency * 1000)
|
||||
|
@ -488,34 +531,33 @@ async def ping(ctx):
|
|||
title="Pong!!",
|
||||
description=(
|
||||
f"{PING_LINE}\n"
|
||||
f"`Bot Latency: {latency}ms`\n"
|
||||
f"`{get_translation(LOCALE, 'command_ping_embed_desc')}: {latency}ms`\n"
|
||||
),
|
||||
color=discord.Color.blue()
|
||||
)
|
||||
LOLembed.set_footer(text=f"Requested by {ctx.author.name}", icon_url=ctx.author.avatar.url)
|
||||
LOLembed.set_footer(text=f"{get_translation(LOCALE, 'command_ping_footer')} {ctx.author.name}", icon_url=ctx.author.avatar.url)
|
||||
|
||||
await ctx.send(embed=LOLembed)
|
||||
|
||||
@bot.hybrid_command(description="about bot")
|
||||
@bot.hybrid_command(description=f"{get_translation(LOCALE, 'command_about_desc')}")
|
||||
async def about(ctx):
|
||||
print("-------UPDATING VERSION INFO-------\n\n")
|
||||
print("-----------------------------------\n\n")
|
||||
try:
|
||||
check_for_update()
|
||||
except Exception as e:
|
||||
pass
|
||||
print("-----------------------------------")
|
||||
embed = discord.Embed(title="About me", description="", color=discord.Color.blue())
|
||||
embed.add_field(name="Name", value=f"{NAME}", inline=False)
|
||||
embed.add_field(name="Version", value=f"Local: {local_version} \nLatest: {latest_version}", inline=False)
|
||||
embed = discord.Embed(title=f"{get_translation(LOCALE, 'command_about_embed_title')}", description="", color=discord.Color.blue())
|
||||
embed.add_field(name=f"{get_translation(LOCALE, 'command_about_embed_field1')}", value=f"{NAME}", inline=False)
|
||||
embed.add_field(name=f"{get_translation(LOCALE, 'command_about_embed_field2name')}", value=f"f{get_translation(LOCALE, 'command_about_embed_field2value').format(local_version=local_version, latest_version=latest_version)}", inline=False)
|
||||
|
||||
await send_message(ctx, embed=embed)
|
||||
|
||||
|
||||
@bot.hybrid_command(description="stats")
|
||||
async def stats(ctx):
|
||||
if ctx.author.id != ownerid:
|
||||
return
|
||||
print("-------UPDATING VERSION INFO-------\n\n")
|
||||
print("-----------------------------------\n\n")
|
||||
try:
|
||||
check_for_update()
|
||||
except Exception as e:
|
||||
|
@ -526,49 +568,24 @@ async def stats(ctx):
|
|||
|
||||
with open(memory_file, 'r') as file:
|
||||
line_count = sum(1 for _ in file)
|
||||
|
||||
embed = discord.Embed(title="Bot Stats", description="Data about the the bot's memory.", color=discord.Color.blue())
|
||||
embed.add_field(name="File Stats", value=f"Size: {file_size} bytes\nLines: {line_count}", inline=False)
|
||||
embed.add_field(name="Version", value=f"Local: {local_version} \nLatest: {latest_version}", inline=False)
|
||||
embed.add_field(name="Variable Info", value=f"Name: {NAME} \nPrefix: {PREFIX} \nOwner ID: {ownerid} \nCooldown: {cooldown_time} \nPing line: {PING_LINE} \nMemory Sharing Enabled: {showmemenabled} \nUser Training Enabled: {USERTRAIN_ENABLED} \nLast Random TT: {last_random_talk_time} \nSong: {song} \nSplashtext: ```{splashtext}```", inline=False)
|
||||
embed = discord.Embed(title=f"{get_translation(LOCALE, 'command_stats_embed_title')}", description=f"{get_translation(LOCALE, 'command_stats_embed_desc')}", color=discord.Color.blue())
|
||||
embed.add_field(name=f"{get_translation(LOCALE, 'command_stats_embed_field1name')}", value=f"{get_translation(LOCALE, 'command_stats_embed_field1value').format(file_size=file_size, line_count=line_count)}", inline=False)
|
||||
embed.add_field(name=f"{get_translation(LOCALE, 'command_stats_embed_field2name')}", value=f"{get_translation(LOCALE, 'command_stats_embed_field2value').format(local_version=local_version, latest_version=latest_version)}", inline=False)
|
||||
embed.add_field(name=f"{get_translation(LOCALE, 'command_stats_embed_field3name')}", value=f"{get_translation(LOCALE, 'command_stats_embed_field3value').format(NAME=NAME, PREFIX=PREFIX, ownerid=ownerid, cooldown_time=cooldown_time, PING_LINE=PING_LINE, showmemenabled=showmemenabled, USERTRAIN_ENABLED=USERTRAIN_ENABLED, last_random_talk_time=last_random_talk_time, song=song, splashtext=splashtext)}", inline=False)
|
||||
|
||||
await send_message(ctx, embed=embed)
|
||||
|
||||
|
||||
|
||||
@bot.hybrid_command()
|
||||
async def mem(ctx):
|
||||
if showmemenabled != "true":
|
||||
return
|
||||
memory = load_memory()
|
||||
memory_text = json.dumps(memory, indent=4)
|
||||
|
||||
if len(memory_text) > 1024:
|
||||
with open(MEMORY_FILE, "r") as f:
|
||||
await send_message(ctx, file=discord.File(f, MEMORY_FILE))
|
||||
else:
|
||||
embed = discord.Embed(title="Memory Contents", description="The bot's memory.", color=discord.Color.blue())
|
||||
embed.add_field(name="Memory Data", value=f"```json\n{memory_text}\n```", inline=False)
|
||||
await send_message(ctx, embed=embed)
|
||||
|
||||
with open(MEMORY_FILE, "r") as f:
|
||||
await send_message(ctx, file=discord.File(f, MEMORY_FILE))
|
||||
|
||||
def improve_sentence_coherence(sentence):
|
||||
sentence = sentence.replace(" i ", " I ")
|
||||
return sentence
|
||||
|
||||
@tasks.loop(minutes=60)
|
||||
async def post_message():
|
||||
channel_id = hourlyspeak
|
||||
channel = bot.get_channel(channel_id)
|
||||
if channel and markov_model:
|
||||
response = None
|
||||
for _ in range(20):
|
||||
response = markov_model.make_sentence(tries=100)
|
||||
if response and response not in generated_sentences:
|
||||
generated_sentences.add(response)
|
||||
break
|
||||
|
||||
if response:
|
||||
await channel.send(response)
|
||||
|
||||
bot.run(TOKEN)
|
||||
bot.run(TOKEN)
|
|
@ -1,340 +1,340 @@
|
|||
import discord
|
||||
from discord.ext import commands
|
||||
import os
|
||||
from typing import List, TypedDict
|
||||
import numpy as np
|
||||
import json
|
||||
from time import strftime, localtime
|
||||
import pickle
|
||||
import functools
|
||||
import re
|
||||
import time
|
||||
import asyncio
|
||||
|
||||
ready: bool = True
|
||||
MODEL_MATCH_STRING = "[0-9]{2}_[0-9]{2}_[0-9]{4}-[0-9]{2}_[0-9]{2}"
|
||||
|
||||
try:
|
||||
import tensorflow as tf
|
||||
from tensorflow import keras
|
||||
from keras.preprocessing.text import Tokenizer
|
||||
from keras_preprocessing.sequence import pad_sequences
|
||||
from keras.models import Sequential
|
||||
from keras.layers import Embedding, LSTM, Dense
|
||||
from keras.models import load_model
|
||||
from keras.backend import clear_session
|
||||
tf.config.optimizer.set_jit(True)
|
||||
except ImportError:
|
||||
print("ERROR: Failed to import Tensorflow. Here is a list of required dependencies:",(
|
||||
"tensorflow==2.10.0"
|
||||
"(for Nvidia users: tensorflow-gpu==2.10.0)"
|
||||
"(for macOS: tensorflow-metal==0.6.0, tensorflow-macos==2.10.0)"
|
||||
"numpy~=1.23"
|
||||
))
|
||||
ready = False
|
||||
|
||||
class TFCallback(keras.callbacks.Callback):
|
||||
def __init__(self,bot, progress_embed:discord.Embed, message):
|
||||
self.embed:discord.Embed = progress_embed
|
||||
self.bot:commands.Bot = bot
|
||||
self.message = message
|
||||
self.times:List[int] = [time.time()]
|
||||
|
||||
def on_train_begin(self, logs=None):
|
||||
pass
|
||||
|
||||
async def send_message(self,message:str, description:str, **kwargs):
|
||||
if "epoch" in kwargs:
|
||||
self.times.append(time.time())
|
||||
average_epoch_time:int = np.average(np.diff(np.array(self.times)))
|
||||
description = f"ETA: {round(average_epoch_time)}s"
|
||||
self.embed.add_field(name=f"<t:{round(time.time())}:t> - {message}",value=description,inline=False)
|
||||
await self.message.edit(embed=self.embed)
|
||||
|
||||
def on_train_end(self,logs=None):
|
||||
self.bot.loop.create_task(self.send_message("Training stopped", "training has been stopped."))
|
||||
|
||||
def on_epoch_begin(self, epoch, logs=None):
|
||||
self.bot.loop.create_task(self.send_message(f"Starting epoch {epoch}","This might take a while", epoch=True))
|
||||
|
||||
def on_epoch_end(self, epoch, logs=None):
|
||||
self.bot.loop.create_task(self.send_message(f"Epoch {epoch} ended",f"Accuracy: {round(logs.get('accuracy',0.0),4)}"))
|
||||
|
||||
|
||||
class Ai:
|
||||
def __init__(self):
|
||||
model_path = settings.get("model_path")
|
||||
if model_path:
|
||||
self.__load_model(model_path)
|
||||
self.is_loaded = model_path is not None
|
||||
self.batch_size = 64
|
||||
|
||||
def get_model_name_from_path(self,path:str):
|
||||
match:re.Match = re.search(MODEL_MATCH_STRING, path)
|
||||
return path[match.start():][:match.end()]
|
||||
|
||||
def generate_model_name(self) -> str:
|
||||
return strftime('%d_%m_%Y-%H_%M', localtime())
|
||||
|
||||
def generate_model_abs_path(self, name:str):
|
||||
name = name or self.generate_model_name()
|
||||
return os.path.join(".","models",self.generate_model_name(),"model.h5")
|
||||
|
||||
def generate_tokenizer_abs_path(self, name:str):
|
||||
name = name or self.generate_model_name()
|
||||
return os.path.join(".","models",name,"tokenizer.pkl")
|
||||
|
||||
def generate_info_abs_path(self,name:str):
|
||||
name = name or self.generate_model_name()
|
||||
return os.path.join(".","models",name,"info.json")
|
||||
|
||||
|
||||
def save_model(self,model, tokenizer, history, _name:str=None):
|
||||
name:str = _name or self.generate_model_name()
|
||||
os.makedirs(os.path.join(".","models",name), exist_ok=True)
|
||||
|
||||
with open(self.generate_info_abs_path(name),"w") as f:
|
||||
json.dump(history.history,f)
|
||||
|
||||
with open(self.generate_tokenizer_abs_path(name), "wb") as f:
|
||||
pickle.dump(tokenizer,f)
|
||||
|
||||
model.save(self.generate_model_abs_path(name))
|
||||
|
||||
|
||||
def __load_model(self, model_path:str):
|
||||
clear_session()
|
||||
self.model = load_model(os.path.join(model_path,"model.h5"))
|
||||
|
||||
model_name:str = self.get_model_name_from_path(model_path)
|
||||
|
||||
try:
|
||||
with open(self.generate_tokenizer_abs_path(model_name),"rb") as f:
|
||||
self.tokenizer = pickle.load(f)
|
||||
except FileNotFoundError:
|
||||
print("Failed to load tokenizer for model... Using default")
|
||||
self.tokenizer = Tokenizer()
|
||||
|
||||
with open("memory.json","r") as f:
|
||||
self.tokenizer.fit_on_sequences(json.load(f))
|
||||
self.is_loaded = True
|
||||
|
||||
def reload_model(self):
|
||||
clear_session()
|
||||
model_path:str = settings.get("model_path")
|
||||
if model_path:
|
||||
self.model = self.__load_model(model_path)
|
||||
self.is_loaded = True
|
||||
|
||||
async def run_async(self,func,bot,*args,**kwargs):
|
||||
func = functools.partial(func,*args,**kwargs)
|
||||
return await bot.loop.run_in_executor(None,func)
|
||||
|
||||
class Learning(Ai):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def __generate_labels_and_inputs(self,memory: List[str], tokenizer=None) -> tuple:
|
||||
if not tokenizer:
|
||||
tokenizer = Tokenizer()
|
||||
tokenizer.fit_on_texts(memory)
|
||||
sequences = tokenizer.texts_to_sequences(memory)
|
||||
|
||||
x = []
|
||||
y = []
|
||||
for seq in sequences:
|
||||
for i in range(1, len(seq)):
|
||||
x.append(seq[:i])
|
||||
y.append(seq[i])
|
||||
|
||||
return x,y, tokenizer
|
||||
|
||||
def create_model(self,memory: list, iters:int=2):
|
||||
memory = memory[:2000]
|
||||
X,y,tokenizer = self.__generate_labels_and_inputs(memory)
|
||||
maxlen:int = max([len(x) for x in X])
|
||||
x_pad = pad_sequences(X, maxlen=maxlen, padding="pre")
|
||||
|
||||
y = np.array(y)
|
||||
|
||||
model = Sequential()
|
||||
model.add(Embedding(input_dim=VOCAB_SIZE,output_dim=128,input_length=maxlen))
|
||||
model.add(LSTM(64))
|
||||
model.add(Dense(VOCAB_SIZE, activation="softmax"))
|
||||
|
||||
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
|
||||
history = model.fit(x_pad, y, epochs=iters, batch_size=64, callbacks=[tf_callback])
|
||||
self.save_model(model, tokenizer, history)
|
||||
return
|
||||
|
||||
|
||||
def add_training(self,memory: List[str], iters:int=2):
|
||||
tokenizer_path = os.path.join(settings.get("model_path"),"tokenizer.pkl")
|
||||
with open(tokenizer_path, "rb") as f:
|
||||
tokenizer = pickle.load(f)
|
||||
|
||||
X,y,_ = self.__generate_labels_and_inputs(memory, tokenizer)
|
||||
|
||||
maxlen:int = max([len(x) for x in X])
|
||||
x_pad = pad_sequences(X, maxlen=maxlen, padding="pre")
|
||||
y = np.array(y)
|
||||
|
||||
history = self.model.fit(x_pad,y, epochs=iters, validation_data=(x_pad,y), batch_size=64, callbacks=[tf_callback]) # Ideally, validation data would be seperate from the actual data
|
||||
self.save_model(self.model,tokenizer,history,self.get_model_name_from_path(settings.get("model_path")))
|
||||
return
|
||||
|
||||
class Generation(Ai):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def generate_sentence(self, word_amount:int, seed:str):
|
||||
if not self.is_loaded:
|
||||
return False
|
||||
for _ in range(word_amount):
|
||||
token_list = self.tokenizer.texts_to_sequences([seed])[0]
|
||||
token_list = pad_sequences([token_list], maxlen=self.model.layers[0].input_shape[1], padding="pre")
|
||||
|
||||
output_word = "" # Sometimes model fails to predict the word, so using a fallback
|
||||
|
||||
predicted_probs = self.model.predict(token_list, verbose=0)
|
||||
predicted_word_index = np.argmax(predicted_probs, axis=-1)[0]
|
||||
|
||||
for word, index in self.tokenizer.word_index.items():
|
||||
if index == predicted_word_index:
|
||||
output_word = word
|
||||
break
|
||||
|
||||
seed += " " + output_word
|
||||
return seed
|
||||
|
||||
|
||||
VOCAB_SIZE = 100_000
|
||||
SETTINGS_TYPE = TypedDict("SETTINGS_TYPE", {
|
||||
"model_path":str, # path to the base folder of the model, aka .../models/05-01-2025-22_31/
|
||||
"tokenizer_path":str,
|
||||
})
|
||||
|
||||
tf_callback:TFCallback
|
||||
model_dropdown_items = []
|
||||
settings: SETTINGS_TYPE = {}
|
||||
|
||||
target_message:int
|
||||
learning:Learning
|
||||
generation: Generation
|
||||
|
||||
class Settings:
|
||||
def __init__(self):
|
||||
self.settings_path:str = os.path.join(".","models","settings.json")
|
||||
|
||||
def load(self):
|
||||
global settings
|
||||
try:
|
||||
with open(self.settings_path,"r") as f:
|
||||
settings = json.load(f)
|
||||
except FileNotFoundError:
|
||||
with open(self.settings_path,"w") as f:
|
||||
json.dump({},f)
|
||||
|
||||
def change_model(self,new_model_base_path:str):
|
||||
global settings
|
||||
new_model_path = os.path.join(".","models",new_model_base_path)
|
||||
|
||||
with open(self.settings_path,"r") as f:
|
||||
settings = json.load(f)
|
||||
|
||||
settings["model_path"] = new_model_path
|
||||
|
||||
with open(self.settings_path, "w") as f:
|
||||
json.dump(settings,f)
|
||||
|
||||
|
||||
class Dropdown(discord.ui.Select):
|
||||
def __init__(self, items:List[str]):
|
||||
global model_dropdown_items
|
||||
model_dropdown_items = []
|
||||
|
||||
for item in items:
|
||||
model_dropdown_items.append(
|
||||
discord.SelectOption(label=item)
|
||||
)
|
||||
|
||||
super().__init__(placeholder="Select model", options=model_dropdown_items)
|
||||
|
||||
async def callback(self, interaction: discord.Interaction):
|
||||
if int(interaction.user.id) != int(os.getenv("ownerid")):
|
||||
await interaction.message.channel.send("KILL YOURSELF")
|
||||
Settings().change_model(self.values[0])
|
||||
await interaction.message.channel.send(f"Changed model to {self.values[0]}")
|
||||
|
||||
class DropdownView(discord.ui.View):
|
||||
def __init__(self, timeout, models):
|
||||
super().__init__(timeout=timeout)
|
||||
self.add_item(Dropdown(models))
|
||||
|
||||
|
||||
class Tf(commands.Cog):
|
||||
def __init__(self,bot):
|
||||
global learning, generation, ready
|
||||
os.makedirs(os.path.join(".","models"),exist_ok=True)
|
||||
Settings().load()
|
||||
self.bot = bot
|
||||
learning = Learning()
|
||||
generation = Generation()
|
||||
|
||||
@commands.command()
|
||||
async def start(self,ctx):
|
||||
await ctx.defer()
|
||||
await ctx.send("hi")
|
||||
|
||||
@commands.command()
|
||||
async def generate(self,ctx,seed:str,word_amount:int=5):
|
||||
await ctx.defer()
|
||||
await ctx.send(generation.generate_sentence(word_amount,seed))
|
||||
|
||||
@commands.command()
|
||||
async def create(self,ctx:commands.Context, epochs:int=3):
|
||||
global tf_callback
|
||||
await ctx.defer()
|
||||
with open("memory.json","r") as f:
|
||||
memory:List[str] = json.load(f)
|
||||
await ctx.send("Initializing tensorflow")
|
||||
embed = discord.Embed(title="Creating a model...", description="Progress of creating a model")
|
||||
embed.set_footer(text="Note: Progress tracking might report delayed / wrong data, since the function is run asynchronously")
|
||||
target_message:discord.Message = await ctx.send(embed=embed)
|
||||
|
||||
tf_callback = TFCallback(self.bot,embed,target_message)
|
||||
await learning.run_async(learning.create_model,self.bot,memory,epochs)
|
||||
embed = target_message.embeds[0]
|
||||
embed.add_field(name=f"<t:{round(time.time())}:t> Finished",value="Model saved.")
|
||||
await target_message.edit(embed=embed)
|
||||
|
||||
|
||||
@commands.command()
|
||||
async def train(self,ctx, epochs:int=2):
|
||||
global tf_callback
|
||||
|
||||
await ctx.defer()
|
||||
with open("memory.json","r") as f:
|
||||
memory:List[str] = json.load(f)
|
||||
|
||||
embed = discord.Embed(title="Training model...", description="Progress of training model")
|
||||
target_message = await ctx.send(embed=embed)
|
||||
tf_callback = TFCallback(self.bot,embed,target_message)
|
||||
|
||||
await learning.run_async(learning.add_training,self.bot,memory,epochs)
|
||||
await ctx.send("Finished!")
|
||||
|
||||
@commands.command()
|
||||
async def change(self,ctx,model:str=None):
|
||||
embed = discord.Embed(title="Change model",description="Which model would you like to use?")
|
||||
if model is None:
|
||||
models:List[str] = os.listdir(os.path.join(".","models"))
|
||||
models = [folder for folder in models if re.match(MODEL_MATCH_STRING,folder)]
|
||||
if len(models) == 0:
|
||||
models = ["No models available."]
|
||||
await ctx.send(embed=embed,view=DropdownView(90,models))
|
||||
learning.reload_model()
|
||||
generation.reload_model()
|
||||
|
||||
async def setup(bot):
|
||||
await bot.add_cog(Tf(bot))
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
import os
|
||||
from typing import List, TypedDict
|
||||
import numpy as np
|
||||
import json
|
||||
from time import strftime, localtime
|
||||
import pickle
|
||||
import functools
|
||||
import re
|
||||
import time
|
||||
import asyncio
|
||||
|
||||
ready: bool = True
|
||||
MODEL_MATCH_STRING = "[0-9]{2}_[0-9]{2}_[0-9]{4}-[0-9]{2}_[0-9]{2}"
|
||||
|
||||
try:
|
||||
import tensorflow as tf
|
||||
from tensorflow import keras
|
||||
from keras.preprocessing.text import Tokenizer
|
||||
from keras_preprocessing.sequence import pad_sequences
|
||||
from keras.models import Sequential
|
||||
from keras.layers import Embedding, LSTM, Dense
|
||||
from keras.models import load_model
|
||||
from keras.backend import clear_session
|
||||
tf.config.optimizer.set_jit(True)
|
||||
except ImportError:
|
||||
print("ERROR: Failed to import Tensorflow. Here is a list of required dependencies:",(
|
||||
"tensorflow==2.10.0"
|
||||
"(for Nvidia users: tensorflow-gpu==2.10.0)"
|
||||
"(for macOS: tensorflow-metal==0.6.0, tensorflow-macos==2.10.0)"
|
||||
"numpy~=1.23"
|
||||
))
|
||||
ready = False
|
||||
|
||||
class TFCallback(keras.callbacks.Callback):
|
||||
def __init__(self,bot, progress_embed:discord.Embed, message):
|
||||
self.embed:discord.Embed = progress_embed
|
||||
self.bot:commands.Bot = bot
|
||||
self.message = message
|
||||
self.times:List[int] = [time.time()]
|
||||
|
||||
def on_train_begin(self, logs=None):
|
||||
pass
|
||||
|
||||
async def send_message(self,message:str, description:str, **kwargs):
|
||||
if "epoch" in kwargs:
|
||||
self.times.append(time.time())
|
||||
average_epoch_time:int = np.average(np.diff(np.array(self.times)))
|
||||
description = f"ETA: {round(average_epoch_time)}s"
|
||||
self.embed.add_field(name=f"<t:{round(time.time())}:t> - {message}",value=description,inline=False)
|
||||
await self.message.edit(embed=self.embed)
|
||||
|
||||
def on_train_end(self,logs=None):
|
||||
self.bot.loop.create_task(self.send_message("Training stopped", "training has been stopped."))
|
||||
|
||||
def on_epoch_begin(self, epoch, logs=None):
|
||||
self.bot.loop.create_task(self.send_message(f"Starting epoch {epoch}","This might take a while", epoch=True))
|
||||
|
||||
def on_epoch_end(self, epoch, logs=None):
|
||||
self.bot.loop.create_task(self.send_message(f"Epoch {epoch} ended",f"Accuracy: {round(logs.get('accuracy',0.0),4)}"))
|
||||
|
||||
|
||||
class Ai:
|
||||
def __init__(self):
|
||||
model_path = settings.get("model_path")
|
||||
if model_path:
|
||||
self.__load_model(model_path)
|
||||
self.is_loaded = model_path is not None
|
||||
self.batch_size = 64
|
||||
|
||||
def get_model_name_from_path(self,path:str):
|
||||
match:re.Match = re.search(MODEL_MATCH_STRING, path)
|
||||
return path[match.start():][:match.end()]
|
||||
|
||||
def generate_model_name(self) -> str:
|
||||
return strftime('%d_%m_%Y-%H_%M', localtime())
|
||||
|
||||
def generate_model_abs_path(self, name:str):
|
||||
name = name or self.generate_model_name()
|
||||
return os.path.join(".","models",self.generate_model_name(),"model.h5")
|
||||
|
||||
def generate_tokenizer_abs_path(self, name:str):
|
||||
name = name or self.generate_model_name()
|
||||
return os.path.join(".","models",name,"tokenizer.pkl")
|
||||
|
||||
def generate_info_abs_path(self,name:str):
|
||||
name = name or self.generate_model_name()
|
||||
return os.path.join(".","models",name,"info.json")
|
||||
|
||||
|
||||
def save_model(self,model, tokenizer, history, _name:str=None):
|
||||
name:str = _name or self.generate_model_name()
|
||||
os.makedirs(os.path.join(".","models",name), exist_ok=True)
|
||||
|
||||
with open(self.generate_info_abs_path(name),"w") as f:
|
||||
json.dump(history.history,f)
|
||||
|
||||
with open(self.generate_tokenizer_abs_path(name), "wb") as f:
|
||||
pickle.dump(tokenizer,f)
|
||||
|
||||
model.save(self.generate_model_abs_path(name))
|
||||
|
||||
|
||||
def __load_model(self, model_path:str):
|
||||
clear_session()
|
||||
self.model = load_model(os.path.join(model_path,"model.h5"))
|
||||
|
||||
model_name:str = self.get_model_name_from_path(model_path)
|
||||
|
||||
try:
|
||||
with open(self.generate_tokenizer_abs_path(model_name),"rb") as f:
|
||||
self.tokenizer = pickle.load(f)
|
||||
except FileNotFoundError:
|
||||
print("Failed to load tokenizer for model... Using default")
|
||||
self.tokenizer = Tokenizer()
|
||||
|
||||
with open("memory.json","r") as f:
|
||||
self.tokenizer.fit_on_sequences(json.load(f))
|
||||
self.is_loaded = True
|
||||
|
||||
def reload_model(self):
|
||||
clear_session()
|
||||
model_path:str = settings.get("model_path")
|
||||
if model_path:
|
||||
self.model = self.__load_model(model_path)
|
||||
self.is_loaded = True
|
||||
|
||||
async def run_async(self,func,bot,*args,**kwargs):
|
||||
func = functools.partial(func,*args,**kwargs)
|
||||
return await bot.loop.run_in_executor(None,func)
|
||||
|
||||
class Learning(Ai):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def __generate_labels_and_inputs(self,memory: List[str], tokenizer=None) -> tuple:
|
||||
if not tokenizer:
|
||||
tokenizer = Tokenizer()
|
||||
tokenizer.fit_on_texts(memory)
|
||||
sequences = tokenizer.texts_to_sequences(memory)
|
||||
|
||||
x = []
|
||||
y = []
|
||||
for seq in sequences:
|
||||
for i in range(1, len(seq)):
|
||||
x.append(seq[:i])
|
||||
y.append(seq[i])
|
||||
|
||||
return x,y, tokenizer
|
||||
|
||||
def create_model(self,memory: list, iters:int=2):
|
||||
memory = memory[:2000]
|
||||
X,y,tokenizer = self.__generate_labels_and_inputs(memory)
|
||||
maxlen:int = max([len(x) for x in X])
|
||||
x_pad = pad_sequences(X, maxlen=maxlen, padding="pre")
|
||||
|
||||
y = np.array(y)
|
||||
|
||||
model = Sequential()
|
||||
model.add(Embedding(input_dim=VOCAB_SIZE,output_dim=128,input_length=maxlen))
|
||||
model.add(LSTM(64))
|
||||
model.add(Dense(VOCAB_SIZE, activation="softmax"))
|
||||
|
||||
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
|
||||
history = model.fit(x_pad, y, epochs=iters, batch_size=64, callbacks=[tf_callback])
|
||||
self.save_model(model, tokenizer, history)
|
||||
return
|
||||
|
||||
|
||||
def add_training(self,memory: List[str], iters:int=2):
|
||||
tokenizer_path = os.path.join(settings.get("model_path"),"tokenizer.pkl")
|
||||
with open(tokenizer_path, "rb") as f:
|
||||
tokenizer = pickle.load(f)
|
||||
|
||||
X,y,_ = self.__generate_labels_and_inputs(memory, tokenizer)
|
||||
|
||||
maxlen:int = max([len(x) for x in X])
|
||||
x_pad = pad_sequences(X, maxlen=maxlen, padding="pre")
|
||||
y = np.array(y)
|
||||
|
||||
history = self.model.fit(x_pad,y, epochs=iters, validation_data=(x_pad,y), batch_size=64, callbacks=[tf_callback]) # Ideally, validation data would be seperate from the actual data
|
||||
self.save_model(self.model,tokenizer,history,self.get_model_name_from_path(settings.get("model_path")))
|
||||
return
|
||||
|
||||
class Generation(Ai):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def generate_sentence(self, word_amount:int, seed:str):
|
||||
if not self.is_loaded:
|
||||
return False
|
||||
for _ in range(word_amount):
|
||||
token_list = self.tokenizer.texts_to_sequences([seed])[0]
|
||||
token_list = pad_sequences([token_list], maxlen=self.model.layers[0].input_shape[1], padding="pre")
|
||||
|
||||
output_word = "" # Sometimes model fails to predict the word, so using a fallback
|
||||
|
||||
predicted_probs = self.model.predict(token_list, verbose=0)
|
||||
predicted_word_index = np.argmax(predicted_probs, axis=-1)[0]
|
||||
|
||||
for word, index in self.tokenizer.word_index.items():
|
||||
if index == predicted_word_index:
|
||||
output_word = word
|
||||
break
|
||||
|
||||
seed += " " + output_word
|
||||
return seed
|
||||
|
||||
|
||||
VOCAB_SIZE = 100_000
|
||||
SETTINGS_TYPE = TypedDict("SETTINGS_TYPE", {
|
||||
"model_path":str, # path to the base folder of the model, aka .../models/05-01-2025-22_31/
|
||||
"tokenizer_path":str,
|
||||
})
|
||||
|
||||
tf_callback:TFCallback
|
||||
model_dropdown_items = []
|
||||
settings: SETTINGS_TYPE = {}
|
||||
|
||||
target_message:int
|
||||
learning:Learning
|
||||
generation: Generation
|
||||
|
||||
class Settings:
|
||||
def __init__(self):
|
||||
self.settings_path:str = os.path.join(".","models","settings.json")
|
||||
|
||||
def load(self):
|
||||
global settings
|
||||
try:
|
||||
with open(self.settings_path,"r") as f:
|
||||
settings = json.load(f)
|
||||
except FileNotFoundError:
|
||||
with open(self.settings_path,"w") as f:
|
||||
json.dump({},f)
|
||||
|
||||
def change_model(self,new_model_base_path:str):
|
||||
global settings
|
||||
new_model_path = os.path.join(".","models",new_model_base_path)
|
||||
|
||||
with open(self.settings_path,"r") as f:
|
||||
settings = json.load(f)
|
||||
|
||||
settings["model_path"] = new_model_path
|
||||
|
||||
with open(self.settings_path, "w") as f:
|
||||
json.dump(settings,f)
|
||||
|
||||
|
||||
class Dropdown(discord.ui.Select):
|
||||
def __init__(self, items:List[str]):
|
||||
global model_dropdown_items
|
||||
model_dropdown_items = []
|
||||
|
||||
for item in items:
|
||||
model_dropdown_items.append(
|
||||
discord.SelectOption(label=item)
|
||||
)
|
||||
|
||||
super().__init__(placeholder="Select model", options=model_dropdown_items)
|
||||
|
||||
async def callback(self, interaction: discord.Interaction):
|
||||
if int(interaction.user.id) != int(os.getenv("ownerid")):
|
||||
await interaction.message.channel.send("KILL YOURSELF")
|
||||
Settings().change_model(self.values[0])
|
||||
await interaction.message.channel.send(f"Changed model to {self.values[0]}")
|
||||
|
||||
class DropdownView(discord.ui.View):
|
||||
def __init__(self, timeout, models):
|
||||
super().__init__(timeout=timeout)
|
||||
self.add_item(Dropdown(models))
|
||||
|
||||
|
||||
class Tf(commands.Cog):
|
||||
def __init__(self,bot):
|
||||
global learning, generation, ready
|
||||
os.makedirs(os.path.join(".","models"),exist_ok=True)
|
||||
Settings().load()
|
||||
self.bot = bot
|
||||
learning = Learning()
|
||||
generation = Generation()
|
||||
|
||||
@commands.command()
|
||||
async def start(self,ctx):
|
||||
await ctx.defer()
|
||||
await ctx.send("hi")
|
||||
|
||||
@commands.command()
|
||||
async def generate(self,ctx,seed:str,word_amount:int=5):
|
||||
await ctx.defer()
|
||||
await ctx.send(generation.generate_sentence(word_amount,seed))
|
||||
|
||||
@commands.command()
|
||||
async def create(self,ctx:commands.Context, epochs:int=3):
|
||||
global tf_callback
|
||||
await ctx.defer()
|
||||
with open("memory.json","r") as f:
|
||||
memory:List[str] = json.load(f)
|
||||
await ctx.send("Initializing tensorflow")
|
||||
embed = discord.Embed(title="Creating a model...", description="Progress of creating a model")
|
||||
embed.set_footer(text="Note: Progress tracking might report delayed / wrong data, since the function is run asynchronously")
|
||||
target_message:discord.Message = await ctx.send(embed=embed)
|
||||
|
||||
tf_callback = TFCallback(self.bot,embed,target_message)
|
||||
await learning.run_async(learning.create_model,self.bot,memory,epochs)
|
||||
embed = target_message.embeds[0]
|
||||
embed.add_field(name=f"<t:{round(time.time())}:t> Finished",value="Model saved.")
|
||||
await target_message.edit(embed=embed)
|
||||
|
||||
|
||||
@commands.command()
|
||||
async def train(self,ctx, epochs:int=2):
|
||||
global tf_callback
|
||||
|
||||
await ctx.defer()
|
||||
with open("memory.json","r") as f:
|
||||
memory:List[str] = json.load(f)
|
||||
|
||||
embed = discord.Embed(title="Training model...", description="Progress of training model")
|
||||
target_message = await ctx.send(embed=embed)
|
||||
tf_callback = TFCallback(self.bot,embed,target_message)
|
||||
|
||||
await learning.run_async(learning.add_training,self.bot,memory,epochs)
|
||||
await ctx.send("Finished!")
|
||||
|
||||
@commands.command()
|
||||
async def change(self,ctx,model:str=None):
|
||||
embed = discord.Embed(title="Change model",description="Which model would you like to use?")
|
||||
if model is None:
|
||||
models:List[str] = os.listdir(os.path.join(".","models"))
|
||||
models = [folder for folder in models if re.match(MODEL_MATCH_STRING,folder)]
|
||||
if len(models) == 0:
|
||||
models = ["No models available."]
|
||||
await ctx.send(embed=embed,view=DropdownView(90,models))
|
||||
learning.reload_model()
|
||||
generation.reload_model()
|
||||
|
||||
async def setup(bot):
|
||||
await bot.add_cog(Tf(bot))
|
30
config.py
30
config.py
|
@ -3,9 +3,6 @@ from dotenv import load_dotenv
|
|||
import platform
|
||||
import random
|
||||
|
||||
def print_multicolored(text):
|
||||
print(text)
|
||||
|
||||
load_dotenv()
|
||||
VERSION_URL = "https://goober.whatdidyouexpect.eu"
|
||||
UPDATE_URL = VERSION_URL+"/latest_version.json"
|
||||
|
@ -15,6 +12,8 @@ PREFIX = os.getenv("BOT_PREFIX")
|
|||
hourlyspeak = int(os.getenv("hourlyspeak"))
|
||||
PING_LINE = os.getenv("PING_LINE")
|
||||
random_talk_channel_id1 = int(os.getenv("rnd_talk_channel1"))
|
||||
LOCALE = os.getenv("locale")
|
||||
gooberTOKEN = os.getenv("gooberTOKEN")
|
||||
random_talk_channel_id2 = int(os.getenv("rnd_talk_channel2"))
|
||||
cooldown_time = os.getenv("cooldown")
|
||||
splashtext = os.getenv("splashtext")
|
||||
|
@ -36,29 +35,4 @@ GREEN = "\033[32mSuccess: "
|
|||
YELLOW = "\033[33mWarning: "
|
||||
DEBUG = "\033[1;30mDebug: "
|
||||
RESET = "\033[0m"
|
||||
multicolorsplash = False
|
||||
|
||||
|
||||
def apply_multicolor(text, chunk_size=3):
|
||||
if multicolorsplash == False:
|
||||
return
|
||||
colors = [
|
||||
"\033[38;5;196m", # Red
|
||||
"\033[38;5;202m", # Orange
|
||||
"\033[38;5;220m", # Yellow
|
||||
"\033[38;5;46m", # Green
|
||||
"\033[38;5;21m", # Blue
|
||||
"\033[38;5;93m", # Indigo
|
||||
"\033[38;5;201m", # Violet
|
||||
]
|
||||
|
||||
chunks = [text[i:i+chunk_size] for i in range(0, len(text), chunk_size)]
|
||||
|
||||
colored_text = ""
|
||||
for chunk in chunks:
|
||||
color = random.choice(colors)
|
||||
colored_text += f"{color}{chunk}\033[0m"
|
||||
|
||||
return colored_text
|
||||
splashtext = apply_multicolor(splashtext)
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@ showmemenabled="true"
|
|||
NAME="an instance of goober"
|
||||
locale=fi
|
||||
ALIVEPING="true"
|
||||
gooberTOKEN=
|
||||
song="War Without Reason"
|
||||
POSITIVE_GIFS="https://media.discordapp.net/attachments/821047460151427135/1181371808566493184/jjpQGeno.gif, https://tenor.com/view/chill-guy-my-new-character-gif-2777893510283028272,https://tenor.com/view/goodnight-goodnight-friends-weezer-weezer-goodnight-gif-7322052181075806988"
|
||||
splashtext="
|
78
locales/fi.json
Normal file
78
locales/fi.json
Normal file
|
@ -0,0 +1,78 @@
|
|||
{
|
||||
"modification_warning": "Gooberia on muokattu! Ohitetaan palvelimen tarkistus kokonaan...",
|
||||
"reported_version": "Ilmoitettu versio:",
|
||||
"current_hash": "Tämänhetkinen hash:",
|
||||
"not_found": "ei löytynyt!",
|
||||
"version_error": "Versiotietojen saanti epäonnistui.. Tilakoodi:",
|
||||
"loaded_cog": "Ladatut cogit:",
|
||||
"cog_fail": "Cogin lataus epäonnistui kohteelle:",
|
||||
"no_model": "Olemassaolevaa markov-mallia ei löydetty. Aloitetaan alusta.",
|
||||
"folder_created": "Kansio '{folder_name}' luotu.",
|
||||
"folder_exists": "Kansio '{folder_name}' on jo olemassa...",
|
||||
"logged_in": "Kirjauduttiin sisään käyttäjänä",
|
||||
"synced_commands": "Synkronoitiin",
|
||||
"synced_commands2": "komennot!",
|
||||
"fail_commands_sync": "Komentojen synkronointi epäonnistui:",
|
||||
"started": "Goober on käynnistynyt!",
|
||||
"name_check": "Nimen saatavuuden tarkistus epäonnistui:",
|
||||
"name_taken": "Nimi on jo käytössä. Valitse toinen nimi.",
|
||||
"name_check2": "Virhe tapahtui nimen saatavuuden tarkistamisessa:",
|
||||
"add_token": "Token: {token}\nLisää tämä .env-tiedostoosi nimellä",
|
||||
"token_exists": "Token on jo olemassa .env-tiedostossa. Jatketaan määritetyllä tokenilla.",
|
||||
"registration_error": "Virhe rekisteröinnissä:",
|
||||
"version_backup": "Varmuuskopio luotu:",
|
||||
"backup_error": "Virhe: {LOCAL_VERSION_FILE}-tiedostoa ei löytynyt varmuuskopiota varten.",
|
||||
"model_loaded": "Markov-malli ladattu",
|
||||
"fetch_update_fail": "Päivitystietojen hankkiminen epäonnistui.",
|
||||
"invalid_server": "Virhe: Palvelin antoi virheellisen versiotietoraportin.",
|
||||
"new_version": "Uusi versio saatavilla: {latest_version} (Tämänhetkinen: {local_version})",
|
||||
"changelog": "Mene osoitteeseen {VERSION_URL}/goob/changes.txt katsotakseen muutoslokin\n\n",
|
||||
"invalid_version": "Versio: {local_version} on virheellinen!",
|
||||
"invalid_version2": "Jos tämä on tahallista, voit jättää tämän viestin huomiotta. Jos tämä ei ole tahallista, paina Y-näppäintä hankkiaksesi kelvollisen version, riippumatta Gooberin tämänhetkisestä versiosta.",
|
||||
"invalid_version3": "Tämänhetkinen versio varmuuskopioidaan kohteeseen current_version.bak..",
|
||||
"input": "(Y:tä tai mitä vaan muuta näppäintä jättää tämän huomioimatta....)",
|
||||
"modification_ignored": "Olet muokannut",
|
||||
"modification_ignored2": "IGNOREWARNING on asetettu false:ksi..",
|
||||
"latest_version": "Käytät uusinta versiota:",
|
||||
"latest_version2": "Tarkista {VERSION_URL}/goob/changes.txt katsotakseen muutosloki",
|
||||
"pinging_disabled": "Pingaus on poistettu käytöstä! En kerro palvelimelle, että olen päällä...",
|
||||
"goober_ping_success": "Lähetettiin olemassaolo ping goober centraliin!",
|
||||
"goober_ping_fail": "Tiedon lähetys epäonnistui. Palvelin antoi tilakoodin:",
|
||||
"goober_ping_fail2": "Tiedon lähettämisen aikana tapahtui virhe:",
|
||||
"sentence_positivity": "Lauseen positiivisuus on:",
|
||||
"command_edit_fail": "Viestin muokkaus epäonnistui:",
|
||||
"command_desc_retrain": "Uudelleenkouluttaa markov-mallin manuaalisesti.",
|
||||
"command_markov_retrain": "Uudelleenkoulutetaan markov-mallia... Odota.",
|
||||
"command_markov_memory_not_found": "Virhe: muistitiedostoa ei löytynyt!",
|
||||
"command_markov_memory_is_corrupt": "Virhe: muistitiedosto on korruptoitu!",
|
||||
"command_markov_retraining": "Käsitellään {processed_data}/{data_size} datapisteestä...",
|
||||
"command_markov_retrain_successful": "Markov-malli koulutettiin uudestaan {data_size} datapisteellä!",
|
||||
"command_desc_talk":"puhuu ja sillei",
|
||||
"command_talk_insufficent_text": "Minun pitää oppia lisää viesteistä ennen kun puhun.",
|
||||
"command_talk_generation_fail": "Minulla ei ole mitään sanottavaa!",
|
||||
"command_desc_help": "auta",
|
||||
"command_help_embed_title": "Botin apu",
|
||||
"command_help_embed_desc": "Komennot ryhmitelty kategorioilla",
|
||||
"command_help_categories_general": "Yleiset",
|
||||
"command_help_categories_admin": "Ylläpito",
|
||||
"command_help_categories_custom": "Mukautetut komennot",
|
||||
"command_ran": "Tietoa: {message.author.name} suoritti {message.content}",
|
||||
"command_desc_ping": "ping",
|
||||
"command_ping_embed_desc": "Botin viive:",
|
||||
"command_ping_footer": "Pyytäjä: ",
|
||||
"command_about_desc": "tietoa",
|
||||
"command_about_embed_title": "Tietoa minusta",
|
||||
"command_about_field1": "Nimi",
|
||||
"command_about_field2name": "Versio",
|
||||
"command_about_field2value": "Paikallinen: {local_version} \nUusin: {latest_version}",
|
||||
"command_desc_stats": "statistiikat",
|
||||
"command_stats_embed_title": "Botin statistiikat",
|
||||
"command_stats_embed_desc": "Tietoa botin muistista.",
|
||||
"command_stats_embed_field1name": "Tiedostostatistiikat",
|
||||
"command_stats_embed_field1value": "Koko: {file_size} tavua\nLinjoja: {line_count}",
|
||||
"command_stats_embed_field2name": "Versio",
|
||||
"command_stats_embed_field2value": "Paikallinen: {local_version} \nUusin: {latest_version}",
|
||||
"command_stats_embed_field3name": "Muuttajainformaatio",
|
||||
"command_stats_embed_field3value": "Nimi: {NAME} \nEtuliite: {PREFIX} \nOmistajan ID: {ownerid} \nJäähtymisaika: {cooldown_time} \nPing-linja: {PING_LINE} \nMuistin jako päällä: {showmemenabled} \nOppiminen käyttäjistä: {USERTRAIN_ENABLED} \nViimeisin satunnainen TT: {last_random_talk_time} \nLaulu: {song} \nRoisketeksti: ```{splashtext}```"
|
||||
}
|
||||
|
Loading…
Add table
Add a link
Reference in a new issue